├── requirements.txt
├── scripts
│   ├── images
│   │   ├── lilly.jpg
│   │   ├── sunflower.JPG
│   │   ├── cow_in_beach.jpg
│   │   └── test_image.jpg
│   ├── run.py
│   ├── run_multimodal.py
│   ├── run_multimodal.py.orig
│   └── run_xla.py
├── tokenizer
│   ├── tokenizer.model
│   └── gemma3_cleaned_262144_v2.spiece.model
├── gemma
│   ├── __init__.py
│   ├── siglip_vision
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── preprocessor.py
│   │   ├── pan_and_scan.py
│   │   └── siglip_vision_model.py
│   ├── tokenizer.py
│   ├── gemma3_preprocessor.py
│   ├── config.py
│   ├── gemma3_model.py
│   ├── model_xla.py
│   ├── xla_model_parallel.py
│   └── model.py
├── CONTRIBUTING.md
├── docker
│   ├── Dockerfile
│   ├── xla.Dockerfile
│   └── xla_gpu.Dockerfile
├── setup.py
├── .gitignore
├── README.md
└── LICENSE
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==2.2.3
2 | pillow==11.1.0
3 | sentencepiece==0.2.0
4 | torch==2.6.0
5 | absl-py==2.1.0
--------------------------------------------------------------------------------
/scripts/images/lilly.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/gemma_pytorch/HEAD/scripts/images/lilly.jpg
--------------------------------------------------------------------------------
/tokenizer/tokenizer.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/gemma_pytorch/HEAD/tokenizer/tokenizer.model
--------------------------------------------------------------------------------
/scripts/images/sunflower.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/gemma_pytorch/HEAD/scripts/images/sunflower.JPG
--------------------------------------------------------------------------------
/scripts/images/cow_in_beach.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/gemma_pytorch/HEAD/scripts/images/cow_in_beach.jpg
--------------------------------------------------------------------------------
/scripts/images/test_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/gemma_pytorch/HEAD/scripts/images/test_image.jpg
--------------------------------------------------------------------------------
/tokenizer/gemma3_cleaned_262144_v2.spiece.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/gemma_pytorch/HEAD/tokenizer/gemma3_cleaned_262144_v2.spiece.model
--------------------------------------------------------------------------------
/gemma/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/gemma/siglip_vision/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to contribute
2 |
3 | We'd love to accept your patches and contributions to this project.
4 |
5 | ## Before you begin
6 |
7 | ### Sign our Contributor License Agreement
8 |
9 | Contributions to this project must be accompanied by a
10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
11 | You (or your employer) retain the copyright to your contribution; this simply
12 | gives us permission to use and redistribute your contributions as part of the
13 | project.
14 |
15 | If you or your current employer have already signed the Google CLA (even if it
16 | was for a different project), you probably don't need to do it again.
17 |
18 | Visit <https://cla.developers.google.com/> to see your current agreements or to
19 | sign a new one.
20 |
21 | ### Review our community guidelines
22 |
23 | This project follows
24 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
25 |
26 | ## Contribution process
27 |
28 | ### Code reviews
29 |
30 | All submissions, including submissions by project members, require review. We
31 | use GitHub pull requests for this purpose. Consult
32 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
33 | information on using pull requests.
34 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Use PyTorch CUDA base image
16 | FROM pytorch/pytorch:2.1.2-cuda11.8-cudnn8-runtime
17 |
18 | USER root
19 |
20 | # Install tools.
21 | ENV DEBIAN_FRONTEND=noninteractive
22 | RUN apt-get update
23 | RUN apt-get install -y --no-install-recommends apt-utils
24 | RUN apt-get install -y --no-install-recommends curl
25 | RUN apt-get install -y --no-install-recommends wget
26 | RUN apt-get install -y --no-install-recommends git
27 |
28 | # Install libraries.
29 | ENV PIP_ROOT_USER_ACTION=ignore
30 | RUN python -m pip install --upgrade pip
31 | RUN pip install numpy==1.24.4
32 | RUN pip install sentencepiece==0.1.99
33 |
34 | # Install from source.
35 | COPY . /workspace/gemma/
36 | WORKDIR /workspace/gemma/
37 | RUN pip install -e .
38 |
--------------------------------------------------------------------------------
/docker/xla.Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Use pytorch/xla base image
16 | FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_20231128
17 |
18 | USER root
19 |
20 | # Install tools.
21 | ENV DEBIAN_FRONTEND=noninteractive
22 | RUN apt-get update
23 | RUN apt-get install -y --no-install-recommends apt-utils
24 | RUN apt-get install -y --no-install-recommends curl
25 | RUN apt-get install -y --no-install-recommends wget
26 | RUN apt-get install -y --no-install-recommends git
27 |
28 | # Install libraries.
29 | ENV PIP_ROOT_USER_ACTION=ignore
30 | RUN python3 -m pip install --upgrade pip
31 | RUN pip install numpy==1.24.4
32 | RUN pip install sentencepiece==0.1.99
33 |
34 | # Install from source.
35 | COPY . /workspace/gemma/
36 | WORKDIR /workspace/gemma/
37 | RUN pip install -e .
38 |
--------------------------------------------------------------------------------
/docker/xla_gpu.Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Use pytorch/xla base image
16 | FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8_20231128
17 |
18 | USER root
19 |
20 | # Install tools.
21 | ENV DEBIAN_FRONTEND=noninteractive
22 | RUN apt-get update
23 | RUN apt-get install -y --no-install-recommends apt-utils
24 | RUN apt-get install -y --no-install-recommends curl
25 | RUN apt-get install -y --no-install-recommends wget
26 | RUN apt-get install -y --no-install-recommends git
27 |
28 | # Install libraries.
29 | ENV PIP_ROOT_USER_ACTION=ignore
30 | RUN python3 -m pip install --upgrade pip
31 | RUN pip uninstall -y torch
32 | RUN pip install torch==2.1.1
33 | RUN pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.8/torch_xla-2.2.0rc1-cp38-cp38-linux_x86_64.whl
34 | RUN pip install numpy==1.24.4
35 | RUN pip install sentencepiece==0.1.99
36 |
37 | # Install from source.
38 | COPY . /workspace/gemma/
39 | WORKDIR /workspace/gemma/
40 | RUN pip install -e .
41 |
--------------------------------------------------------------------------------
/gemma/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | from typing import List, Optional
16 |
17 | import sentencepiece
18 |
19 | def _assert_file_exists(model_path: str):
20 | assert os.path.isfile(model_path), model_path
21 |
22 | _BEGIN_IMAGE_TOKEN = 255999
23 | _END_IMAGE_TOKEN = 256000
24 |
25 | class Tokenizer:
26 |
27 | def __init__(self, model_path: Optional[str]):
28 | _assert_file_exists(model_path)
29 | self.sp_model = sentencepiece.SentencePieceProcessor()
30 | self.sp_model.Load(model_path)
31 |
32 | # BOS / EOS token IDs.
33 | self.n_words: int = self.sp_model.GetPieceSize()
34 | self.bos_id: int = self.sp_model.bos_id()
35 | self.eos_id: int = self.sp_model.eos_id()
36 | self.pad_id: int = self.sp_model.pad_id()
37 | self.boi_id: int = _BEGIN_IMAGE_TOKEN
38 | self.eoi_id: int = _END_IMAGE_TOKEN
39 | self.image_token_placeholder_id: int = self.sp_model.pad_id()
40 |
41 | def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]:
42 | """Converts a string into a list of tokens."""
43 | assert isinstance(s, str)
44 | t = self.sp_model.EncodeAsIds(s)
45 | if bos:
46 | t = [self.bos_id] + t
47 | if eos:
48 | t = t + [self.eos_id]
49 | return t
50 |
51 | def decode(self, t: List[int]) -> str:
52 | """Converts a list of tokens into a string."""
53 | return self.sp_model.DecodeIds(t)
54 |
--------------------------------------------------------------------------------
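A quick usage sketch for the `Tokenizer` class above (editorial addition, not part of the repository); the `tokenizer/tokenizer.model` path assumes the tokenizer file from this repo has been downloaded locally:

```python
# Sketch: round-trip a prompt through the Gemma SentencePiece tokenizer.
from gemma.tokenizer import Tokenizer

tokenizer = Tokenizer("tokenizer/tokenizer.model")  # assumed local path

ids = tokenizer.encode("What are large language models?", bos=True)
print(ids)                    # [bos_id, ...subword ids...]
print(tokenizer.decode(ids))  # control tokens such as BOS are not rendered as text
print(tokenizer.n_words, tokenizer.bos_id, tokenizer.eos_id, tokenizer.pad_id)
```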
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import io
16 | import os
17 | from typing import List
18 |
19 | import setuptools
20 |
21 | ROOT_DIR = os.path.dirname(__file__)
22 |
23 |
24 | def get_path(*filepath) -> str:
25 | return os.path.join(ROOT_DIR, *filepath)
26 |
27 |
28 | def read_readme() -> str:
29 | """Read the README file."""
30 | return io.open(get_path("README.md"), "r", encoding="utf-8").read()
31 |
32 |
33 | def get_requirements() -> List[str]:
34 | """Get Python package dependencies from requirements.txt."""
35 | with open(get_path("requirements.txt")) as f:
36 | requirements = f.read().strip().split("\n")
37 | return requirements
38 |
39 |
40 | setuptools.setup(
41 | name="gemma",
42 | version="0.1",
43 | author="Gemma contributors",
44 | license="Apache 2.0",
45 | description=("Gemma model implementation"),
46 | long_description=read_readme(),
47 | long_description_content_type="text/markdown",
48 | classifiers=[
49 | "Programming Language :: Python :: 3.8",
50 | "Programming Language :: Python :: 3.9",
51 | "Programming Language :: Python :: 3.10",
52 | "Programming Language :: Python :: 3.11",
53 | "License :: OSI Approved :: Apache Software License",
54 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
55 | ],
56 | packages=setuptools.find_packages(exclude=("benchmarks", "docs",
57 | "examples", "tests")),
58 | python_requires=">=3.11",
59 | install_requires=get_requirements(),
60 | )
61 |
--------------------------------------------------------------------------------
/gemma/siglip_vision/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Siglip vision model config."""
16 |
17 | import dataclasses
18 | from . import preprocessor
19 |
20 |
21 | # https://developers.googleblog.com/en/gemma-explained-paligemma-architecture/
22 | @dataclasses.dataclass
23 | class SiglipVisionModelConfig:
24 | """Model config for the vision model of Gemma 3 and PaliGemma."""
25 | # The number of transformer encoder blocks in the siglip encoder model.
26 | num_hidden_layers: int = 27
27 | # The dimension of the embedding.
28 | embedding_dim: int = 1152
29 | # Whether to use bias in the 2D conv embedding layer.
30 | embedding_use_bias: bool = True
31 | # The number of channels in the input images.
32 | input_channels: int = 3
33 | # The input image size.
34 | image_size: int = preprocessor.DEFAULT_IMAGE_SIZE
35 | # Kernel size of 2D convolution layer.
36 | conv2d_patch_size: int = 14
37 | # The number of attention heads used in the attention layers of the model.
38 | num_attention_heads: int = 16
39 | # The number of head dimensions.
40 | head_dim: int = 72
41 | # Clarify: is num_key_value same as num_query_groups?
42 | num_key_value_heads: int = 16
43 | # The number of query groups for implementing attention.
44 | num_query_groups: int = 16
45 | # Clarify: usecase of this field is not clear.
46 | qkv_use_bias: bool = True
47 | # Clarify: usecase of this field is not clear.
48 | output_proj_use_bias: bool = True
49 | # The dimension of the MLP representations.
50 | intermediate_size: int = 4304
51 | # The epsilon used by the layer normalization layers.
52 | layer_norm_eps: float = 1e-6
53 | # Clarify: identify if the dtype varies for the siglip vision model.
54 | dtype: str = 'bfloat16'
55 | # Whether a quantized version of the model is used.
56 | quant: bool = False
57 | # The sequence length of the encoding.
58 | encoding_sequence_length: int = 256
59 |
60 |
61 | def get_siglip_vision_model_config() -> SiglipVisionModelConfig:
62 | """Returns the default model config for the vision model of Gemma 3 and PaliGemma."""
63 | return SiglipVisionModelConfig()
64 |
65 |
--------------------------------------------------------------------------------
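For intuition, a small arithmetic sketch (editorial, not part of the repository) of what the defaults above imply for the vision tower's sequence length:

```python
# With the defaults above: a 896x896 input and 14x14 patches give a 64x64
# patch grid, i.e. 4096 patch embeddings of width 1152. The 4x4 average
# pooling applied later in the vision model reduces 64x64 to 16x16 = 256
# tokens, which matches encoding_sequence_length.
image_size, patch_size = 896, 14
grid = image_size // patch_size        # 64
num_patches = grid * grid              # 4096
pooled_seq_len = (grid // 4) ** 2      # 256
print(grid, num_patches, pooled_seq_len)
```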
/gemma/siglip_vision/preprocessor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Preprocessor for Siglip vision model.
15 |
16 | No neural network is used in the following functions. These are heuristic based.
17 | """
18 |
19 | from collections.abc import Sequence
20 |
21 | from PIL import Image
22 | import torch
23 | import numpy as np # Import NumPy
24 |
25 | _IMAGE_MEAN = [0.5, 0.5, 0.5] # equivalent to 127.5/255
26 | _IMAGE_STD = [0.5, 0.5, 0.5] # equivalent to 127.5/255
27 | DEFAULT_IMAGE_SIZE = 896
28 |
29 |
30 | def preprocess_images_for_siglip_vision(
31 | images: Sequence[Image.Image], image_size=DEFAULT_IMAGE_SIZE
32 | ) -> Sequence[torch.Tensor]:
33 | """Preprocesses a list of PIL images for the Siglip vision model using PyTorch, NumPy and PIL.
34 |
35 | Args:
36 | images: A sequence of PIL Image objects.
37 | image_size: The target size for resizing the images.
38 |
39 | Returns:
40 | A sequence of torch.Tensor objects, each of shape (C, H, W).
41 | """
42 | processed_images = []
43 |
44 | mean_tensor = torch.tensor(_IMAGE_MEAN, dtype=torch.float32).reshape(3, 1, 1)
45 | std_tensor = torch.tensor(_IMAGE_STD, dtype=torch.float32).reshape(3, 1, 1)
46 |
47 | for image in images:
48 | # Resize image
49 | image = image.resize((image_size, image_size), Image.Resampling.BILINEAR)
50 |
51 | # Convert to NumPy and ensure float32 type
52 | image_np = np.array(image, dtype=np.float32) / 255.0 # Normalize to [0,1]
53 |
54 | # Convert to PyTorch tensor and rearrange channels
55 | image_tensor = torch.from_numpy(image_np).permute(2, 0, 1) # (H, W, C) → (C, H, W)
56 |
57 | # Normalize
58 | image_tensor = (image_tensor - mean_tensor) / std_tensor
59 |
60 | # Clip the values to [-1, 1]
61 | image_tensor = torch.clamp(image_tensor, -1, 1)
62 |
63 | processed_images.append(image_tensor)
64 |
65 | return processed_images
66 |
67 |
68 | # Example usage:
69 | # Assuming you have a list of PIL images called 'pil_images'
70 | # pil_images = [Image.open("image1.jpg"), Image.open("image2.png")]
71 | # processed_tensors = preprocess_images_for_siglip_vision(pil_images)
72 | # for tensor in processed_tensors: print(tensor.shape)
73 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
162 | # VSCode
163 | .vscode/*
164 | !.vscode/extensions.json
165 |
166 | # DS Store
167 | .DS_Store
168 |
169 | # Results
170 | *.csv
171 |
172 | # Python pickle files
173 | *.pkl
174 |
175 | # Sphinx documentation
176 | _build/
177 |
--------------------------------------------------------------------------------
/scripts/run.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import contextlib
17 | import random
18 |
19 | import numpy as np
20 | import torch
21 | from absl import app, flags
22 |
23 | from gemma import config
24 | from gemma import model as gemma_model
25 |
26 | # Define flags
27 | FLAGS = flags.FLAGS
28 |
29 | flags.DEFINE_string('ckpt', None, 'Path to the checkpoint file.', required=True)
30 | flags.DEFINE_string('variant', '2b', 'Model variant.')
31 | flags.DEFINE_string('device', 'cpu', 'Device to run the model on.')
32 | flags.DEFINE_integer('output_len', 10, 'Length of the output sequence.')
33 | flags.DEFINE_integer('seed', 12345, 'Random seed.')
34 | flags.DEFINE_boolean('quant', False, 'Whether to use quantization.')
35 | flags.DEFINE_string('prompt', 'What are large language models?', 'Input prompt for the model.')
36 |
37 | # Define valid text only model variants
38 | _VALID_MODEL_VARIANTS = ['2b', '2b-v2', '7b', '9b', '27b', '1b']
39 |
40 | # Define valid devices
41 | _VALID_DEVICES = ['cpu', 'cuda']
42 |
43 | # Validator function for the 'variant' flag
44 | def validate_variant(variant):
45 | if variant not in _VALID_MODEL_VARIANTS:
46 | raise ValueError(f'Invalid variant: {variant}. Valid variants are: {_VALID_MODEL_VARIANTS}')
47 | return True
48 |
49 | # Validator function for the 'device' flag
50 | def validate_device(device):
51 | if device not in _VALID_DEVICES:
52 | raise ValueError(f'Invalid device: {device}. Valid devices are: {_VALID_DEVICES}')
53 | return True
54 |
55 | # Register the validator for the 'variant' flag
56 | flags.register_validator('variant', validate_variant, message='Invalid model variant.')
57 |
58 | # Register the validator for the 'device' flag
59 | flags.register_validator('device', validate_device, message='Invalid device.')
60 |
61 | @contextlib.contextmanager
62 | def _set_default_tensor_type(dtype: torch.dtype):
63 | """Sets the default torch dtype to the given dtype."""
64 | torch.set_default_dtype(dtype)
65 | yield
66 | torch.set_default_dtype(torch.float)
67 |
68 | def main(_):
69 | # Construct the model config.
70 | model_config = config.get_model_config(FLAGS.variant)
71 | model_config.dtype = "float32"
72 | model_config.quant = FLAGS.quant
73 |
74 | # Seed random.
75 | random.seed(FLAGS.seed)
76 | np.random.seed(FLAGS.seed)
77 | torch.manual_seed(FLAGS.seed)
78 |
79 | # Create the model and load the weights.
80 | device = torch.device(FLAGS.device)
81 | with _set_default_tensor_type(model_config.get_dtype()):
82 | model = gemma_model.GemmaForCausalLM(model_config)
83 | model.load_weights(FLAGS.ckpt)
84 | model = model.to(device).eval()
85 | print("Model loading done")
86 |
87 | # Generate the response.
88 | result = model.generate(FLAGS.prompt, device, output_len=FLAGS.output_len)
89 |
90 | # Print the prompts and results.
91 | print('======================================')
92 | print(f'PROMPT: {FLAGS.prompt}')
93 | print(f'RESULT: {result}')
94 | print('======================================')
95 |
96 | if __name__ == "__main__":
97 | app.run(main)
98 |
99 |
100 | # How to run this script:
101 |
102 | # Example command (replace with your actual paths and values):
103 | # python scripts/run.py --device=cpu --ckpt=/path/to/your/pytorch_checkpoint/model.ckpt --output_len=2 --prompt="The name of the capital of Italy is"
104 | # Important:
105 | # - Replace '/path/to/your/pytorch_checkpoint/model.ckpt' with the actual path to your checkpoint file.
106 | # - Choose the correct --variant (model size).
107 | # - Use --device=cuda if you have a GPU; otherwise, use --device=cpu.
--------------------------------------------------------------------------------
/gemma/siglip_vision/pan_and_scan.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Pan and scan image cropping implementation."""
15 |
16 | from collections.abc import Sequence
17 |
18 | import numpy as np
19 | from PIL import Image
20 |
21 |
22 | def pan_and_scan(
23 | img: Image.Image,
24 | *,
25 | min_crop_size: int = 256,
26 | max_num_crops: int = 4,
27 | ) -> Sequence[Image.Image]:
28 | return _pan_and_scan_os(
29 | img,
30 | min_crop_size=min_crop_size,
31 | max_num_crops=max_num_crops,
32 | )[0]
33 |
34 |
35 | def pan_and_scan_os_with_crop_positions(
36 | img: Image.Image,
37 | *,
38 | min_crop_size: int = 256,
39 | max_num_crops: int = 4,
40 | ) -> tuple[Sequence[Image.Image], Sequence[tuple[int, int, int, int]]]:
41 | return _pan_and_scan_os(
42 | img,
43 | min_crop_size=min_crop_size,
44 | max_num_crops=max_num_crops,
45 | )
46 |
47 |
48 | def _pan_and_scan_os(
49 | img: Image.Image,
50 | *,
51 | min_crop_size: int,
52 | max_num_crops: int,
53 | ) -> tuple[Sequence[Image.Image], Sequence[tuple[int, int, int, int]]]:
54 | """Pan and scan an image for open source.
55 |
56 | If the image is landscape, the crops are made horizontally and if the image is
57 | portrait, the crops are made vertically. The longer side of the image is split
58 | into [2 - max_num_crops] crops.
59 |
60 | Args:
61 | img: PIL Image object.
62 | min_crop_size: The minimum size of each crop.
63 | max_num_crops: The maximum desired number of crops to be generated.
64 |
65 | Returns:
66 | List of cropped PIL Image objects and a list of crop positions.
67 | """
68 | w, h = img.size
69 |
70 | # Square or landscape image.
71 | if w >= h:
72 | if w / h < 1.5:
73 | return [img], [(0, 0, h, w)]
74 |
75 | # Select ideal number of crops close to the image aspect ratio and such that
76 | # crop_size > min_crop_size.
77 | num_crops_w = int(np.floor(w / h + 0.5)) # Half round up rounding.
78 | num_crops_w = min(
79 | int(np.floor(w / min_crop_size)),
80 | num_crops_w,
81 | )
82 |
83 | # Make sure the number of crops is in range [2, max_num_crops].
84 | num_crops_w = max(2, num_crops_w)
85 | num_crops_w = min(max_num_crops, num_crops_w)
86 | num_crops_h = 1
87 |
88 | # Portrait image.
89 | else:
90 | if h / w < 1.5:
91 | return [img], [(0, 0, h, w)]
92 |
93 | num_crops_h = int(np.floor(h / w + 0.5))
94 | num_crops_h = min(int(np.floor(h / min_crop_size)), num_crops_h)
95 | num_crops_h = max(2, num_crops_h)
96 | num_crops_h = min(max_num_crops, num_crops_h)
97 | num_crops_w = 1
98 |
99 | crop_size_w = int(np.ceil(w / num_crops_w))
100 | crop_size_h = int(np.ceil(h / num_crops_h))
101 |
102 | # Don't apply pan and scan if crop size is too small.
103 | if min(crop_size_w, crop_size_h) < min_crop_size:
104 | return [img], [(0, 0, h, w)]
105 |
106 | crop_positions_w = [crop_size_w * i for i in range(num_crops_w)]
107 | crop_positions_h = [crop_size_h * i for i in range(num_crops_h)]
108 |
109 | # Generate crops.
110 | crops = []
111 | crop_positions = []
112 | for pos_h in crop_positions_h:
113 | for pos_w in crop_positions_w:
114 | crops.append(
115 | img.crop((
116 | pos_w,
117 | pos_h,
118 | pos_w + crop_size_w,
119 | pos_h + crop_size_h,
120 | ))
121 | )
122 | crop_positions.append(
123 | (pos_h, pos_w, pos_h + crop_size_h, pos_w + crop_size_w)
124 | )
125 |
126 | return crops, crop_positions
127 |
--------------------------------------------------------------------------------
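A brief behavioural sketch of the cropping rule above (editorial addition; the image sizes are invented for illustration):

```python
# Sketch: pan-and-scan on two synthetic PIL images.
from PIL import Image

from gemma.siglip_vision.pan_and_scan import pan_and_scan_os_with_crop_positions

# 1200x400 landscape: aspect ratio 3.0 >= 1.5, so the width is split into
# round(1200/400) = 3 crops (bounded by floor(1200/256) and [2, max_num_crops]).
wide = Image.new("RGB", (1200, 400))
crops, positions = pan_and_scan_os_with_crop_positions(wide)
print(len(crops), positions)  # 3 crops of 400x400

# 512x512 square: ratio < 1.5, so the image is returned unchanged.
square = Image.new("RGB", (512, 512))
crops, positions = pan_and_scan_os_with_crop_positions(square)
print(len(crops), positions)  # 1 crop, position (0, 0, 512, 512)
```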
/scripts/run_multimodal.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import contextlib
17 | import random
18 |
19 | from absl import app
20 | from absl import flags
21 | import numpy as np
22 | from PIL import Image
23 | import torch
24 |
25 | from gemma import config
26 | from gemma import gemma3_model
27 |
28 | # Define flags
29 | FLAGS = flags.FLAGS
30 |
31 | _CKPT = flags.DEFINE_string(
32 | 'ckpt', None, 'Path to the checkpoint file.', required=True
33 | )
34 | _VARIANT = flags.DEFINE_string('variant', '4b', 'Model variant.')
35 | _DEVICE = flags.DEFINE_string('device', 'cpu', 'Device to run the model on.')
36 | _OUTPUT_LEN = flags.DEFINE_integer(
37 | 'output_len', 10, 'Length of the output sequence.'
38 | )
39 | _SEED = flags.DEFINE_integer('seed', 12345, 'Random seed.')
40 | _QUANT = flags.DEFINE_boolean('quant', False, 'Whether to use quantization.')
41 |
42 | # Define valid multimodal model variants
43 | _VALID_MODEL_VARIANTS = ['4b', '12b', '27b_v3']
44 |
45 | # Define valid devices
46 | _VALID_DEVICES = ['cpu', 'cuda']
47 |
48 |
49 | # Validator function for the 'variant' flag
50 | def validate_variant(variant):
51 | if variant not in _VALID_MODEL_VARIANTS:
52 | raise ValueError(
53 | f'Invalid variant: {variant}. Valid variants are:'
54 | f' {_VALID_MODEL_VARIANTS}'
55 | )
56 | return True
57 |
58 |
59 | # Validator function for the 'device' flag
60 | def validate_device(device):
61 | if device not in _VALID_DEVICES:
62 | raise ValueError(
63 | f'Invalid device: {device}. Valid devices are: {_VALID_DEVICES}'
64 | )
65 | return True
66 |
67 |
68 | # Register the validator for the 'variant' flag
69 | flags.register_validator(
70 | 'variant', validate_variant, message='Invalid model variant.'
71 | )
72 |
73 | # Register the validator for the 'device' flag
74 | flags.register_validator('device', validate_device, message='Invalid device.')
75 |
76 |
77 | @contextlib.contextmanager
78 | def _set_default_tensor_type(dtype: torch.dtype):
79 | """Sets the default torch dtype to the given dtype."""
80 | torch.set_default_dtype(dtype)
81 | yield
82 | torch.set_default_dtype(torch.float)
83 |
84 |
85 | def main(_):
86 | # Construct the model config.
87 | model_config = config.get_model_config(_VARIANT.value)
88 | model_config.dtype = 'float32'
89 | model_config.quant = _QUANT.value
90 | image_paths = {"cow_in_beach": "scripts/images/cow_in_beach.jpg",
91 | "lilly": "scripts/images/lilly.jpg",
92 | "sunflower": "scripts/images/sunflower.JPG",
93 | 'golden_test_image': (
94 | 'scripts/images/test_image.jpg'
95 | ),
96 | }
97 |
98 | image = {}
99 | for key in image_paths:
100 | try:
101 | image[key] = Image.open(image_paths[key]) # Open local file
102 | image[key].show()
103 | except IOError as e:
104 | print(f"Error loading image: {e}")
105 | exit()
106 |
107 | # Seed random.
108 | random.seed(_SEED.value)
109 | np.random.seed(_SEED.value)
110 | torch.manual_seed(_SEED.value)
111 |
112 | # Create the model and load the weights.
113 | device = torch.device(_DEVICE.value)
114 | with _set_default_tensor_type(model_config.get_dtype()):
115 | model = gemma3_model.Gemma3ForMultimodalLM(model_config)
116 | model.load_state_dict(torch.load(_CKPT.value)['model_state_dict'])
117 | # model.load_weights(_CKPT.value)
118 | model = model.to(device).eval()
119 | print('Model loading done')
120 |
121 | # Generate text only.
122 | result = model.generate(
123 | [
124 | [
125 | '<start_of_turn>user The capital of Italy'
126 | ' is?<end_of_turn>\n<start_of_turn>model'
127 | ],
128 | [
129 | '<start_of_turn>user What is your'
130 | ' purpose?<end_of_turn>\n<start_of_turn>model'
131 | ],
132 | ],
133 | device,
134 | output_len=_OUTPUT_LEN.value,
135 | )
136 |
137 | # Print the results.
138 | print('======================================')
139 | print(f'Text only RESULT: {result}')
140 | print('======================================')
141 |
142 | # Generate golden Gemax test image.
143 | result = model.generate(
144 | [[
145 | '<start_of_turn>user\n',
146 | image['golden_test_image'],
147 | 'Caption this image. <end_of_turn>\n<start_of_turn>model',
148 | ]],
149 | device,
150 | output_len=_OUTPUT_LEN.value,
151 | )
152 |
153 | # Print the result.
154 | print('======================================')
155 | print(f'Golden test image RESULT: {result}')
156 | print('======================================')
157 |
158 | # Generate text and image.
159 | result = model.generate(
160 | [[
161 | '<start_of_turn>user\n',
162 | image['cow_in_beach'],
163 | (
164 | 'The name of the animal in the image is'
165 | ' <end_of_turn>\n<start_of_turn>model'
166 | ),
167 | ]],
168 | device,
169 | output_len=_OUTPUT_LEN.value,
170 | )
171 |
172 | # Print the result.
173 | print('======================================')
174 | print(f'Single image RESULT: {result}')
175 | print('======================================')
176 |
177 | # Generate interleave text and multiple images.
178 | result = model.generate(
179 | [[
180 | '<start_of_turn>user\nThis image',
181 | image['lilly'],
182 | 'and this image',
183 | image['sunflower'],
184 | 'are similar because? <end_of_turn>\n<start_of_turn>model',
185 | ]],
186 | device,
187 | output_len=_OUTPUT_LEN.value,
188 | )
189 |
190 | # Print the result.
191 | print('======================================')
192 | print(f'Interleave images RESULT: {result}')
193 | print('======================================')
194 |
195 |
196 | if __name__ == '__main__':
197 | app.run(main)
198 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Gemma in PyTorch
2 |
3 | **Gemma** is a family of lightweight, state-of-the-art open models built from research and technology used to create Google Gemini models. They include both text-only and multimodal decoder-only large language models, with open weights, pre-trained variants, and instruction-tuned variants. For more details, please check out the following links:
4 |
5 | * [Gemma on Google AI](https://ai.google.dev/gemma)
6 | * [Gemma on Kaggle](https://www.kaggle.com/models/google/gemma-3)
7 | * [Gemma on Vertex AI Model Garden](https://pantheon.corp.google.com/vertex-ai/publishers/google/model-garden/gemma3)
8 |
9 | This is the official PyTorch implementation of Gemma models. We provide model and inference implementations using both PyTorch and PyTorch/XLA, and support running inference on CPU, GPU and TPU.
10 |
11 | ## Updates
12 |
13 | * [March 12th, 2025 🔥] Support Gemma v3. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma-3/pytorch) and [Hugging Face](https://huggingface.co/models?other=gemma_torch)
14 |
15 | * [June 26th, 2024] Support Gemma v2. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma-2/pytorch) and Hugging Face
16 |
17 | * [April 9th, 2024] Support CodeGemma. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/codegemma/pytorch) and [Hugging Face](https://huggingface.co/collections/google/codegemma-release-66152ac7b683e2667abdee11)
18 |
19 | * [April 5, 2024] Support Gemma v1.1. You can find the v1.1 checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/pyTorch) and [Hugging Face](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b).
20 |
21 | ## Download Gemma model checkpoint
22 |
23 | You can find the model checkpoints on Kaggle:
24 |
25 | - [Gemma 3](https://www.kaggle.com/models/google/gemma-3/pyTorch)
26 | - [Gemma 2](https://www.kaggle.com/models/google/gemma-2/pyTorch)
27 | - [Gemma](https://www.kaggle.com/models/google/gemma/pyTorch)
28 |
29 | Alternatively, you can find the model checkpoints on the Hugging Face Hub [here](https://huggingface.co/models?other=gemma_torch). To download the models, go to the model repository of the model of interest, click the `Files and versions` tab, and download the model and tokenizer files. For programmatic downloading, if you have `huggingface_hub` installed, you can also run:
30 |
31 | ```
32 | huggingface-cli download google/gemma-3-4b-it-pytorch
33 | ```
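The same download can be done from Python with `huggingface_hub` (a minimal sketch; swap in the repo id of the variant you want):

```python
from huggingface_hub import snapshot_download

# Downloads the checkpoint and tokenizer files into the local Hugging Face
# cache and returns the directory they were stored in.
ckpt_dir = snapshot_download("google/gemma-3-4b-it-pytorch")
print(ckpt_dir)
```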
34 |
35 | The following model sizes are available:
36 |
37 | - **Gemma 3**:
38 | - **Text only**: 1b
39 | - **Multimodal**: 4b, 12b, 27b_v3
40 | - **Gemma 2**:
41 | - **Text only**: 2b-v2, 9b, 27b
42 | - **Gemma**:
43 | - **Text only**: 2b, 7b
44 |
45 |
46 | Set `VARIANT` to the model size of the checkpoint you downloaded, and `CKPT_PATH` to its local path:
47 |
48 | ```
49 | VARIANT=<1b, 2b, 2b-v2, 4b, 7b, 9b, 12b, 27b, 27b_v3>
50 | CKPT_PATH=
51 | ```
52 |
53 | ## Try it free on Colab
54 |
55 | Follow the steps at
56 | [https://ai.google.dev/gemma/docs/pytorch_gemma](https://ai.google.dev/gemma/docs/pytorch_gemma).
57 |
58 | ## Try it out with PyTorch
59 |
60 | Prerequisite: make sure you have setup docker permission properly as a non-root user.
61 |
62 | ```bash
63 | sudo usermod -aG docker $USER
64 | newgrp docker
65 | ```
66 |
67 | ### Build the docker image.
68 |
69 | ```bash
70 | DOCKER_URI=gemma:${USER}
71 |
72 | docker build -f docker/Dockerfile ./ -t ${DOCKER_URI}
73 | ```
74 |
75 | ### Run Gemma inference on CPU.
76 |
77 | > NOTE: This is a multimodal example. Use a multimodal variant.
78 |
79 | ```bash
80 | docker run -t --rm \
81 | -v ${CKPT_PATH}:/tmp/ckpt \
82 | ${DOCKER_URI} \
83 | python scripts/run_multimodal.py \
84 | --ckpt=/tmp/ckpt \
85 | --variant="${VARIANT}" \
86 | # add `--quant` for the int8 quantized model.
87 | ```
88 |
89 | ### Run Gemma inference on GPU.
90 |
91 | > NOTE: This is a multimodal example. Use a multimodal variant.
92 |
93 | ```bash
94 | docker run -t --rm \
95 | --gpus all \
96 | -v ${CKPT_PATH}:/tmp/ckpt \
97 | ${DOCKER_URI} \
98 | python scripts/run_multimodal.py \
99 | --device=cuda \
100 | --ckpt=/tmp/ckpt \
101 | --variant="${VARIANT}"
102 | # add `--quant` for the int8 quantized model.
103 | ```
104 |
105 | ## Try it out with PyTorch/XLA
106 |
107 | ### Build the docker image (CPU, TPU).
108 |
109 | ```bash
110 | DOCKER_URI=gemma_xla:${USER}
111 |
112 | docker build -f docker/xla.Dockerfile ./ -t ${DOCKER_URI}
113 | ```
114 |
115 | ### Build the docker image (GPU).
116 |
117 | ```bash
118 | DOCKER_URI=gemma_xla_gpu:${USER}
119 |
120 | docker build -f docker/xla_gpu.Dockerfile ./ -t ${DOCKER_URI}
121 | ```
122 |
123 | ### Run Gemma inference on CPU.
124 |
125 | > NOTE: This is a multimodal example. Use a multimodal variant.
126 |
127 | ```bash
128 | docker run -t --rm \
129 | --shm-size 4gb \
130 | -e PJRT_DEVICE=CPU \
131 | -v ${CKPT_PATH}:/tmp/ckpt \
132 | ${DOCKER_URI} \
133 | python scripts/run_xla.py \
134 | --ckpt=/tmp/ckpt \
135 | --variant="${VARIANT}" \
136 | # add `--quant` for the int8 quantized model.
137 | ```
138 |
139 | ### Run Gemma inference on TPU.
140 |
141 | Note: be sure to use the docker container built from `xla.Dockerfile`.
142 |
143 | ```bash
144 | docker run -t --rm \
145 | --shm-size 4gb \
146 | -e PJRT_DEVICE=TPU \
147 | -v ${CKPT_PATH}:/tmp/ckpt \
148 | ${DOCKER_URI} \
149 | python scripts/run_xla.py \
150 | --ckpt=/tmp/ckpt \
151 | --variant="${VARIANT}" \
152 | # add `--quant` for the int8 quantized model.
153 | ```
154 |
155 | ### Run Gemma inference on GPU.
156 |
157 | Note: be sure to use the docker container built from `xla_gpu.Dockerfile`.
158 |
159 | ```bash
160 | docker run -t --rm --privileged \
161 | --shm-size=16g --net=host --gpus all \
162 | -e USE_CUDA=1 \
163 | -e PJRT_DEVICE=CUDA \
164 | -v ${CKPT_PATH}:/tmp/ckpt \
165 | ${DOCKER_URI} \
166 | python scripts/run_xla.py \
167 | --ckpt=/tmp/ckpt \
168 | --variant="${VARIANT}" \
169 | # add `--quant` for the int8 quantized model.
170 | ```
171 |
172 | ### Tokenizer Notes
173 |
174 | 99 unused tokens are reserved in the pretrained tokenizer model to assist with more efficient training/fine-tuning. Unused tokens are in the string format of `<unused[0-98]>` with token id range of `[7-104]`.
175 |
176 | ```
177 | "<unused0>": 7,
178 | "<unused1>": 8,
179 | "<unused2>": 9,
180 | ...
181 | "<unused98>": 104,
182 | ```
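The reserved ids can be inspected directly with `sentencepiece` (a quick sketch, assuming the tokenizer model has been downloaded to `tokenizer/tokenizer.model`):

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("tokenizer/tokenizer.model")
for token_id in (7, 8, 104):
    print(token_id, sp.IdToPiece(token_id))  # e.g. 7 -> <unused0>
```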
183 |
184 | ## Disclaimer
185 |
186 | This is not an officially supported Google product.
187 |
--------------------------------------------------------------------------------
/scripts/run_multimodal.py.orig:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import contextlib
17 | import random
18 |
19 | import numpy as np
20 | import torch
21 | from absl import app
22 | from absl import flags
23 | from google3.pyglib import gfile
24 | from PIL import Image
25 |
26 | from google3.third_party.open_models_release.gemma_pytorch.gemma import config
27 | from google3.third_party.open_models_release.gemma_pytorch.gemma import model as gemma_model
28 | from google3.third_party.open_models_release.gemma_pytorch.gemma import gemma3_model
29 |
30 | # Define flags
31 | FLAGS = flags.FLAGS
32 |
33 | _CKPT = flags.DEFINE_string('ckpt', None, 'Path to the checkpoint file.', required=True)
34 | _VARIANT = flags.DEFINE_string('variant', '4b', 'Model variant.')
35 | _DEVICE = flags.DEFINE_string('device', 'cpu', 'Device to run the model on.')
36 | _OUTPUT_LEN = flags.DEFINE_integer('output_len', 10, 'Length of the output sequence.')
37 | _SEED = flags.DEFINE_integer('seed', 12345, 'Random seed.')
38 | _QUANT = flags.DEFINE_boolean('quant', False, 'Whether to use quantization.')
39 |
40 | # Define valid model variants
41 | _VALID_MODEL_VARIANTS = ['4b', '12b', '27b_v3']
42 |
43 | # Define valid devices
44 | _VALID_DEVICES = ['cpu', 'cuda']
45 |
46 | # Validator function for the 'variant' flag
47 | def validate_variant(variant):
48 | if variant not in _VALID_MODEL_VARIANTS:
49 | raise ValueError(f'Invalid variant: {variant}. Valid variants are: {_VALID_MODEL_VARIANTS}')
50 | return True
51 |
52 | # Validator function for the 'device' flag
53 | def validate_device(device):
54 | if device not in _VALID_DEVICES:
55 | raise ValueError(f'Invalid device: {device}. Valid devices are: {_VALID_DEVICES}')
56 | return True
57 |
58 | # Register the validator for the 'variant' flag
59 | flags.register_validator('variant', validate_variant, message='Invalid model variant.')
60 |
61 | # Register the validator for the 'device' flag
62 | flags.register_validator('device', validate_device, message='Invalid device.')
63 |
64 | @contextlib.contextmanager
65 | def _set_default_tensor_type(dtype: torch.dtype):
66 | """Sets the default torch dtype to the given dtype."""
67 | torch.set_default_dtype(dtype)
68 | yield
69 | torch.set_default_dtype(torch.float)
70 |
71 | def main(_):
72 | # Construct the model config.
73 | model_config = config.get_model_config(_VARIANT.value)
74 | model_config.dtype = "float32" if _DEVICE.value == "cpu" else "float16"
75 | model_config.quant = _QUANT.value
76 | image_paths = {"cow_in_beach": "/cns/pd-d/home/omst/test_images/cow_in_beach.jpg",
77 | "lilly": "/cns/pd-d/home/omst/test_images/lilly.jpg",
78 | "sunflower": "/cns/pd-d/home/omst/test_images/sunflower.JPG",
79 | 'golden_test_image': (
80 | '/cns/ge-d/home/mriviere/gemma_logits/test_image.jpg'
81 | ),
82 | }
83 | image = {}
84 | for key in image_paths:
85 | try:
86 | with gfile.Open(image_paths[key], 'rb') as f:
87 | image[key] = Image.open(f)
88 | image[key].show()
89 | except gfile.FileError as e:
90 | print(f"Error loading image: {e}")
91 | exit()
92 |
93 | # Seed random.
94 | random.seed(_SEED.value)
95 | np.random.seed(_SEED.value)
96 | torch.manual_seed(_SEED.value)
97 |
98 | # Create the model and load the weights.
99 | device = torch.device(_DEVICE.value)
100 | with _set_default_tensor_type(model_config.get_dtype()):
101 | model = gemma3_model.Gemma3ForMultimodalLM(model_config)
102 | model.load_state_dict(torch.load(_CKPT.value)['model_state_dict'])
103 | # model.load_weights(_CKPT.value)
104 | model = model.to(device).eval()
105 | print("Model loading done")
106 |
107 | # Generate text only.
108 | result = model.generate([["<start_of_turn>user The capital of Italy is?<end_of_turn>\n<start_of_turn>model"], ["<start_of_turn>user What is your purpose?<end_of_turn>\n<start_of_turn>model"]], device, output_len=_OUTPUT_LEN.value)
109 |
110 | # Print the results.
111 | print('======================================')
112 | print(f'Text only RESULT: {result}')
113 | print('======================================')
114 |
115 | # Generate golden Gemax test image.
116 | result = model.generate(
117 | [['<start_of_turn>user\n', image['golden_test_image'], 'Caption this image. <end_of_turn>\n<start_of_turn>model']],
118 | device,
119 | output_len=_OUTPUT_LEN.value,
120 | )
121 |
122 | # Print the result.
123 | print('======================================')
124 | print(f'Golden test image RESULT: {result}')
125 | print('======================================')
126 |
127 | # Generate text and image.
128 | result = model.generate([['<start_of_turn>user\n', image["cow_in_beach"], "The name of the animal in the image is <end_of_turn>\n<start_of_turn>model"]], device, output_len=_OUTPUT_LEN.value)
129 |
130 |
131 | # Print the result.
132 | print('======================================')
133 | print(f'Single image RESULT: {result}')
134 | print('======================================')
135 |
136 | # Generate interleave text and multiple images.
137 | result = model.generate([["<start_of_turn>user\nThis image", image["lilly"], "and this image", image["sunflower"], "are similar because? <end_of_turn>\n<start_of_turn>model"]], device, output_len=_OUTPUT_LEN.value)
138 |
139 | # Print the result.
140 | print('======================================')
141 | print(f'Interleave images RESULT: {result}')
142 | print('======================================')
143 |
144 | if __name__ == "__main__":
145 | app.run(main)
146 |
147 |
148 | # How to run this script:
149 |
150 | # Example command (replace with your actual paths and values):
151 | # blaze run third_party/open_models_release/gemma_pytorch/scripts:run_multimodal -- --device=cpu --ckpt=/usr/local/google/home/imayank/Desktop/gemma_files/ckpts/gemma-3.0-4b-pt/mm/model5.ckpt --output_len=5
152 | # Important:
153 | # - Replace '/path/to/your/pytorch_checkpoint/model.ckpt' with the actual path to your checkpoint file.
154 | # - Choose the correct --variant (model size).
155 | # - Use --device=cuda if you have a GPU; otherwise, use --device=cpu.
--------------------------------------------------------------------------------
/gemma/siglip_vision/siglip_vision_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Siglip vision model for Gemma 3 and PaliGemma."""
15 |
16 | import torch
17 | from torch import nn
18 | import torch.nn.functional as F
19 |
20 | from . import config as siglip_vision_config
21 | SiglipVisionModelConfig = siglip_vision_config.SiglipVisionModelConfig
22 |
23 | class AveragePool2D(nn.Module):
24 | """Applies 4x4 average pooling and reshaping."""
25 | def __init__(self, config):
26 | super().__init__()
27 | self.config = config
28 |
29 | def forward(self, x):
30 | """Applies 4x4 average pooling and reshaping."""
31 | batch_size, seq_len, channels = x.shape
32 | width = int(seq_len**0.5)
33 | if width * width != seq_len:
34 | raise ValueError(
35 | f"Sequence length {seq_len} is not a perfect square. Cannot reshape to a square image."
36 | )
37 | # Bx(64^2)x1152 -> Bx1152x(64^2) -> Bx1152x64x64
38 | x = x.transpose(1, 2).reshape(batch_size, channels, width, width)
39 | # Bx1152x64x64-> Bx1152x16x16
40 | x = F.avg_pool2d(x, kernel_size=4, stride=4)
41 | # Bx1152x16x16 -> Bx1152x256 -> Bx256x1152
42 | x = x.flatten(2).transpose(1, 2)
43 | return x
44 |
45 | # Siglip Attention
46 | class SiglipAttention(nn.Module):
47 | """Siglip attention module."""
48 |
49 | def __init__(self, dim, num_heads, head_dim):
50 | super().__init__()
51 | self.dim = dim
52 | self.num_heads = num_heads
53 | self.head_dim = head_dim
54 |
55 | # Key, Query, Value projections
56 | self.k_proj = nn.Linear(dim, num_heads * head_dim, bias=True)
57 | self.q_proj = nn.Linear(dim, num_heads * head_dim, bias=True)
58 | self.v_proj = nn.Linear(dim, num_heads * head_dim, bias=True)
59 |
60 | # Output projection
61 | self.o_proj = nn.Linear(num_heads * head_dim, dim, bias=True)
62 |
63 | def forward(self, x):
64 | batch_size, seq_len, _ = x.size()
65 |
66 | # Project inputs to key, query, value
67 | k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
68 | q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
69 | v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
70 |
71 | # Transpose for multi-head attention
72 | k = k.transpose(1, 2) # (batch_size, num_heads, seq_len, head_dim)
73 | q = q.transpose(1, 2) # (batch_size, num_heads, seq_len, head_dim)
74 | v = v.transpose(1, 2) # (batch_size, num_heads, seq_len, head_dim)
75 |
76 | # Scaled dot-product attention
77 | scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)
78 | attn_weights = F.softmax(scores, dim=-1)
79 | attn_output = torch.matmul(attn_weights, v)
80 |
81 | # Transpose back to (batch_size, seq_len, num_heads, head_dim)
82 | attn_output = attn_output.transpose(1, 2).contiguous()
83 | attn_output = attn_output.view(
84 | batch_size, seq_len, self.num_heads * self.head_dim
85 | )
86 |
87 | # Apply output projection
88 | output = self.o_proj(attn_output)
89 |
90 | return output
91 |
92 |
93 | class SiglipMLP(nn.Module):
94 | """Siglip MLP module."""
95 | def __init__(self, hidden_size, intermediate_size):
96 | super().__init__()
97 | self.fc1 = nn.Linear(hidden_size, intermediate_size)
98 | self.fc2 = nn.Linear(intermediate_size, hidden_size)
99 |
100 | def gelu_tanh(self, x):
101 | return (
102 | 0.5
103 | * x
104 | * (
105 | 1
106 | + torch.tanh(
107 | torch.sqrt(torch.tensor(2.0 / torch.pi, device=x.device))
108 | * (x + 0.044715 * torch.pow(x, 3))
109 | )
110 | )
111 | )
112 |
113 | def forward(self, x):
114 | x = self.fc1(x)
115 | x = self.gelu_tanh(x)
116 | x = self.fc2(x)
117 | return x
118 |
119 |
120 | class SiglipEncoderBlock(nn.Module):
121 | """Encoder block (Transformer layer) for siglip vision model."""
122 |
123 | def __init__(self, config: SiglipVisionModelConfig):
124 | super().__init__()
125 | self.self_attn = SiglipAttention(
126 | config.embedding_dim, config.num_attention_heads, config.head_dim
127 | )
128 | # SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_0/LayerNorm_0
129 | self.layer_norm1 = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps)
130 | self.mlp = SiglipMLP(config.embedding_dim, config.intermediate_size)
131 | # SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_0/LayerNorm_1
132 | self.layer_norm2 = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps)
133 |
134 | def forward(self, x):
135 | # Pre-LN
136 | residual = x
137 | # SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_0/LayerNorm_0
138 | x = self.layer_norm1(x)
139 | x = self.self_attn(x)
140 | x = x + residual # Residual connection *after* LayerNorm
141 |
142 | residual = x
143 | # SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_0/LayerNorm_1
144 | x = self.layer_norm2(x)
145 | x = self.mlp(x)
146 | x = x + residual # Residual connection *after* LayerNorm
147 | return x
148 |
149 |
150 | # https://developers.googleblog.com/en/gemma-explained-paligemma-architecture/
151 | class SiglipVisionModel(nn.Module):
152 | """Signlip vision model for gemma 3 and paligemma."""
153 |
154 | def __init__(self, config: SiglipVisionModelConfig):
155 | super().__init__()
156 |
157 | # SigLiPFromPatches_0/siglip_encoder/embedding
158 | self.patch_embedding = nn.Conv2d(
159 | in_channels=config.input_channels,
160 | out_channels=config.embedding_dim,
161 | kernel_size=config.conv2d_patch_size,
162 | stride=config.conv2d_patch_size,
163 | padding=0,
164 | bias=config.embedding_use_bias,
165 | )
166 | self.num_patches = (config.image_size // config.conv2d_patch_size) ** 2
167 | self.num_positions = self.num_patches
168 | # SigLiPFromPatches_0/siglip_encoder
169 | self.position_embedding = nn.Embedding(
170 | self.num_positions, config.embedding_dim
171 | )
172 | self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
173 |
174 | # SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_i
175 | self.encoder_blocks = nn.ModuleList(
176 | SiglipEncoderBlock(config=config)
177 | for _ in range(config.num_hidden_layers)
178 | )
179 | # SigLiPFromPatches_0/siglip_encoder/Transformer/encoder_norm
180 | self.final_norm = nn.LayerNorm(config.embedding_dim, config.layer_norm_eps)
181 | self.avg_pool = AveragePool2D(config)
182 | self.config = config
183 |
184 | # pixel_values=Bxconfig.input_channels x config.image_size x config.image_size
185 | @torch.inference_mode
186 | def forward(
187 | self,
188 | pixel_values: torch.Tensor,
189 | ) -> torch.Tensor:
190 | # Embed the image according to SiglipVisionEmbeddings.
191 | x = self.patch_embedding(pixel_values)
192 | # (batch_size,channels,height,width)->(batch_size, height*width, channels)
193 | x = x.flatten(2).transpose(1, 2)
194 |
195 | position_ids = self.position_ids.to(pixel_values.device)
196 | x = x + self.position_embedding(position_ids)
197 |
198 | for block in self.encoder_blocks:
199 | x = block(x) # batch_size, height*width, embedding_dim (1152)
200 | x = self.final_norm(x)
201 |
202 | # siglip exit https://source.corp.google.com/piper///depot/google3/third_party/py/gemma/multimodal/vision.py;l=220
203 | return self.avg_pool(x)
204 |
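As a quick sanity check of the shape comments above, here is a minimal sketch. It assumes the default config from get_siglip_vision_model_config uses 896x896 inputs, 14x14 patches, and a 1152-dim embedding, so the 64x64 patch tokens are pooled down to 256 tokens:

import torch

from gemma.siglip_vision import config as siglip_vision_config
from gemma.siglip_vision import siglip_vision_model

cfg = siglip_vision_config.get_siglip_vision_model_config()
vision = siglip_vision_model.SiglipVisionModel(cfg).eval()

# One dummy image, shaped (B, C, H, W) as forward() expects.
pixels = torch.zeros(1, cfg.input_channels, cfg.image_size, cfg.image_size)
tokens = vision(pixels)
# Under the assumed defaults this prints torch.Size([1, 256, 1152]).
print(tokens.shape)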
--------------------------------------------------------------------------------
/gemma/gemma3_preprocessor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Preprocessor for Gemma3 input."""
15 | import token
16 |
17 | from typing import Union, Any, Sequence
18 |
19 | import torch
20 | from absl import app
21 | from PIL import Image
22 | from .siglip_vision import preprocessor as siglip_vision_preprocessor
23 | from .siglip_vision import pan_and_scan
24 | from . import tokenizer
25 | from . import config as gemma_config
26 |
27 | CROPPED_IMAGE_PREFIX = "here is the original image"
28 | CROPPED_IMAGE_FILLER = "and here are some crops to help you see better"
29 |
30 |
31 | def gemma3_input_preprocessor(
32 | raw_user_prompt: Sequence[Union[Image.Image, str]],
33 | ) -> Sequence[Union[torch.Tensor, str]]:
34 | """Preprocessor for Gemma3 input.
35 |
36 | Args:
37 | raw_user_prompt: A list of images or strings, as provided by the user.
38 |
39 | Returns:
40 | A list of preprocessed images or strings.
41 | """
42 | preprocessed_input: list[Union[torch.Tensor, str]] = []
43 | for element in raw_user_prompt:
44 | if isinstance(element, Image.Image):
45 | cropped_images = pan_and_scan.pan_and_scan(element)
46 | preprocessed_images_cropped = siglip_vision_preprocessor.preprocess_images_for_siglip_vision(cropped_images)
47 | preprocessed_images_uncropped = siglip_vision_preprocessor.preprocess_images_for_siglip_vision([element])
48 | if len(preprocessed_images_cropped) == 1:
49 | preprocessed_input.append(preprocessed_images_uncropped[0])
50 | elif len(preprocessed_images_cropped) > 1:
51 | preprocessed_input.append(CROPPED_IMAGE_PREFIX)
52 | preprocessed_input.append(preprocessed_images_uncropped[0])
53 | preprocessed_input.append(CROPPED_IMAGE_FILLER)
54 | preprocessed_input.extend(preprocessed_images_cropped)
55 | else:
56 | raise ValueError("No images found in the input.")
57 | else:
58 | preprocessed_input.append(element)
59 |
60 | return preprocessed_input
61 |
62 |
63 | def gemma3_batch_input_preprocessor(raw_input: Sequence[Sequence[Union[Image.Image, str]]]):
64 | """Preprocessor for Gemma3 batch input.
65 | """
66 | preprocessed_input: list[Sequence[Union[torch.Tensor, str]]] = []
67 | for element in raw_input:
68 | preprocessed_input.append(gemma3_input_preprocessor(element))
69 | return preprocessed_input
70 |
71 |
72 | def tokenize_raw_input(
73 | tokenizer_obj: tokenizer.Tokenizer,
74 | raw_input: Sequence[Sequence[Union[str, Image.Image]]],
75 | config: gemma_config.GemmaConfig,
76 | output_len: int,
77 | device: Any,
78 | ) -> dict[str, Any]:
79 | """
80 | Converts a preprocessed batch of interleaved text and image inputs into
81 | token IDs and an image batch suitable for gemma3 model.
82 |
83 | Args:
84 | preprocessed_batch: List of lists containing strings and torch.Tensor images.
85 | image_token_id: Token ID to represent image placeholders.
86 | max_image_tokens: Number of tokens reserved for each image.
87 | image_size: Expected size of images (C, H, W).
88 |
89 | Returns:
90 | user_input_token_ids: Batch of token IDs with shape (B, L), where L is the max sequence length.
91 | image_batch: Batch of images with shape (B, N, C, H, W), where N is the max number of images.
92 | """
93 | if config.vision_config is None:
94 | raise ValueError('vision_config must be provided for Gemma3.')
95 |
96 | preprocessed_batch = gemma3_batch_input_preprocessor(raw_input)
97 |
98 | # Initialize lists to store token IDs and image tensors
99 | all_token_ids = []
100 | all_images = []
101 | prompt_lengths = []
102 |
103 | max_prompt_len = 0
104 | min_prompt_len = float("inf")
105 | max_num_images = 0
106 | # Iterate over each user prompt in the batch
107 | for prompt in preprocessed_batch:
108 | token_ids = []
109 | images = []
110 | token_ids.append(tokenizer_obj.bos_id)
111 | # Process each element in the prompt
112 | for element in prompt:
113 | if isinstance(element, str):
114 | # Tokenize text and add to token_ids
115 | tokens = tokenizer_obj.encode(element, bos=False, eos=False)
116 | token_ids.extend(tokens)
117 | elif isinstance(element, torch.Tensor):
118 | # Prepend a double newline and the begin-of-image (boi) token.
119 | token_ids.extend(tokenizer_obj.encode("\n\n", bos=False, eos=False))
120 | token_ids.append(tokenizer_obj.boi_id)
121 | # Add image placeholder tokens
122 | token_ids.extend(
123 | [tokenizer_obj.image_token_placeholder_id]
124 | * config.vision_config.encoding_sequence_length
125 | )
126 | # Append the end-of-image (eoi) token and a double newline.
127 | token_ids.append(tokenizer_obj.eoi_id)
128 | token_ids.extend(tokenizer_obj.encode("\n\n", bos=False, eos=False))
129 | # Store the image tensor
130 | images.append(element)
131 | else:
132 | raise ValueError(
133 | "Unsupported type in prompt. Expected str or torch.Tensor."
134 | )
135 | curr_prompt_len = len(token_ids)
136 | prompt_lengths.append(curr_prompt_len)
137 |
138 | max_prompt_len = max(max_prompt_len, curr_prompt_len)
139 | min_prompt_len = min(min_prompt_len, curr_prompt_len)
140 | max_num_images = max(max_num_images, len(images))
141 |
142 | all_token_ids.append(token_ids)
143 | all_images.append(images)
144 |
145 | max_seq_len = max_prompt_len + output_len
146 |
147 | # Pad token IDs to the maximum sequence length
148 | user_input_token_ids = []
149 | for token_ids in all_token_ids:
150 | pad_length = max_seq_len - len(token_ids)
151 | padded_token_ids = token_ids + [tokenizer_obj.pad_id] * pad_length
152 | user_input_token_ids.append(padded_token_ids)
153 |
154 | # Pad images to the maximum number of images in the batch
155 | image_batch = []
156 | image_presence_mask = []
157 | for images in all_images:
158 | # Check if all images within the current sublist have the same shape
159 | if images: # Check if the sublist is not empty
160 | first_shape = images[0].shape
161 | for img in images:
162 | assert img.shape == first_shape, "Images within a sublist must have the same shape."
163 | pad_length = max_num_images - len(images)
164 | padded_images = images.copy()  # Copy so the original list is not altered.
165 | presence_mask = [True] * len(images)
166 |
167 | if pad_length > 0:
168 | # Create a list of zero tensors for padding
169 | padding = [
170 | torch.zeros(
171 | (
172 | config.vision_config.input_channels,
173 | config.vision_config.image_size,
174 | config.vision_config.image_size,
175 | ), device=device
176 | )
177 | for _ in range(pad_length)
178 | ]
179 | padded_images.extend(padding)
180 | presence_mask.extend([False] * pad_length)
181 | image_batch.append(padded_images)
182 | image_presence_mask.append(presence_mask)
183 |
184 | # Convert lists to tensors
185 | user_input_token_ids = torch.tensor(user_input_token_ids, dtype=torch.long, device=device)
186 | if max_num_images > 0:
187 | image_batch = torch.stack([torch.stack(images) for images in image_batch]).to(
188 | device
189 | )
190 | image_presence_mask = torch.tensor(image_presence_mask, dtype=torch.bool, device=device)
191 | else:
192 | image_batch = None
193 | image_presence_mask = None
194 |
195 | # Prepare the output dictionary
196 | output_dict = {
197 | "user_input_token_ids": user_input_token_ids,
198 | "image_batch": image_batch,
199 | "batch_size": len(preprocessed_batch),
200 | "min_prompt_len": min_prompt_len,
201 | "max_prompt_len": max_prompt_len,
202 | "max_seq_len": max_seq_len,
203 | "image_presence_mask": image_presence_mask,
204 | }
205 |
206 | return output_dict
207 |
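A minimal usage sketch (the image path and the '4b' variant are illustrative; any variant with a populated vision_config works). Each image in a prompt becomes a double newline, the begin-of-image (boi) token, encoding_sequence_length placeholder tokens, the end-of-image (eoi) token, and another double newline, and images are zero-padded per prompt up to the batch maximum:

from PIL import Image

from gemma import config as gemma_config
from gemma import gemma3_preprocessor
from gemma import tokenizer

cfg = gemma_config.get_model_config('4b')
tok = tokenizer.Tokenizer(cfg.tokenizer)

# One prompt mixing text and an image.
prompts = [["Caption this image:", Image.open("scripts/images/test_image.jpg")]]
batch = gemma3_preprocessor.tokenize_raw_input(
    tok, prompts, cfg, output_len=16, device="cpu"
)
print(batch["user_input_token_ids"].shape)  # (B, max_prompt_len + output_len)
print(batch["image_batch"].shape)           # (B, N, C, H, W)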
--------------------------------------------------------------------------------
/scripts/run_xla.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import argparse
15 | import contextlib
16 | import os
17 | import random
18 | import socket
19 | import sys
20 | from typing import List, Union
21 |
22 | import numpy as np
23 | import torch
24 | import torch.multiprocessing
25 |
26 | from gemma.config import GemmaConfig, get_model_config
27 | from gemma.model_xla import GemmaForCausalLM
28 | from gemma.tokenizer import Tokenizer
29 | import gemma.xla_model_parallel as xla_model_parallel
30 |
31 | USE_CUDA = os.environ.get('USE_CUDA', False)
32 | if not USE_CUDA:
33 | import torch_xla.core.xla_model as xm
34 | import torch_xla.distributed.xla_multiprocessing as xmp
35 | else:
36 | # Choose an available port.
37 | with contextlib.closing(socket.socket(socket.AF_INET,
38 | socket.SOCK_STREAM)) as s:
39 | s.bind(('', 0))
40 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
41 | MASTER_PORT = str(s.getsockname()[1])
42 |
43 |
44 | @contextlib.contextmanager
45 | def _set_default_tensor_type(dtype: torch.dtype):
46 | """Sets the default torch dtype to the given dtype."""
47 | torch.set_default_dtype(dtype)
48 | yield
49 | torch.set_default_dtype(torch.float)
50 |
51 |
52 | def generate(
53 | i: int,
54 | model_config: GemmaConfig,
55 | ckpt_path: str,
56 | prompts: List[str],
57 | output_lens: List[int],
58 | temperatures: Union[List[float], None],
59 | top_ps: List[float],
60 | top_ks: List[int],
61 | seed: int
62 | ):
63 | random.seed(seed)
64 | np.random.seed(seed)
65 | torch.manual_seed(seed)
66 | if USE_CUDA:
67 | os.environ['MASTER_ADDR'] = '127.0.0.1'
68 | os.environ['MASTER_PORT'] = MASTER_PORT
69 | if not torch.distributed.is_initialized():
70 | torch.distributed.init_process_group(
71 | "nccl",
72 | rank=int(os.environ.get("RANK", 0)),
73 | world_size=int(os.environ.get("WORLD_SIZE", 1)))
74 | xla_model_parallel.set_g_group()
75 | local_rank = int(os.environ.get("LOCAL_RANK", 0))
76 | device = torch.device("cuda", local_rank)
77 | torch.cuda.set_device(local_rank)
78 | else:
79 | device = xm.xla_device()
80 | xm.set_rng_state(seed, device)
81 |
82 | rank = xla_model_parallel.get_model_parallel_rank()
83 | world_size = xla_model_parallel.get_model_parallel_world_size()
84 | if rank > 0:
85 | sys.stdout = open(os.devnull, 'w')
86 |
87 | # build, load and compile model.
88 | with _set_default_tensor_type(model_config.get_dtype()):
89 | model = GemmaForCausalLM(model_config, world_size, rank, device)
90 | model.load_weights(ckpt_path)
91 | model = model.to(device).eval()
92 |
93 | # create tokenizer.
94 | tokenizer = Tokenizer(model_config.tokenizer)
95 |
96 | prompt_tokens = [tokenizer.encode(prompt) for prompt in prompts]
97 | min_prompt_len = min(len(p) for p in prompt_tokens)
98 |
99 | batch_size = len(prompts)
100 | if temperatures is not None:
101 | assert batch_size == len(temperatures)
102 | assert batch_size == len(top_ps)
103 | assert batch_size == len(top_ks)
104 | max_seq_len = max([len(p) + o for p, o in zip(prompt_tokens, output_lens)])
105 | assert max_seq_len <= model_config.max_position_embeddings
106 | if model_config.num_key_value_heads < world_size:
107 | assert world_size % model_config.num_key_value_heads == 0
108 | n_local_heads = 1
109 | else:
110 | assert model_config.num_key_value_heads % world_size == 0
111 | n_local_heads = model_config.num_key_value_heads // world_size
112 |
113 | # build KV caches
114 | kv_caches = []
115 | for _ in range(model_config.num_hidden_layers):
116 | k_cache = torch.zeros(
117 | size=(batch_size, max_seq_len, n_local_heads,
118 | model_config.head_dim),
119 | dtype=model_config.get_dtype(),
120 | device=device,
121 | )
122 | v_cache = torch.zeros(
123 | size=(batch_size, max_seq_len, n_local_heads,
124 | model_config.head_dim),
125 | dtype=model_config.get_dtype(),
126 | device=device,
127 | )
128 | kv_caches.append((k_cache, v_cache))
129 |
130 | # prepare inputs
131 | token_ids_tensor = torch.full((batch_size, max_seq_len),
132 | tokenizer.pad_id,
133 | dtype=torch.int64)
134 | input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
135 | tokenizer.pad_id,
136 | dtype=torch.int64)
137 | for i, p in enumerate(prompt_tokens):
138 | token_ids_tensor[i, :len(p)] = torch.tensor(p)
139 | input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
140 | p[:min_prompt_len])
141 | token_ids_tensor = token_ids_tensor.to(device)
142 | prompt_mask_tensor = token_ids_tensor != tokenizer.pad_id
143 | input_token_ids_tensor = input_token_ids_tensor.to(device)
144 | input_positions_tensor = torch.arange(0, min_prompt_len,
145 | dtype=torch.int64).to(device)
146 | mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
147 | -2.3819763e38).to(torch.float)
148 | mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
149 | curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
150 | output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device)
151 | temperatures_tensor = None if not temperatures else torch.FloatTensor(temperatures).to(device)
152 | top_ps_tensor = torch.FloatTensor(top_ps).to(device)
153 | top_ks_tensor = torch.LongTensor(top_ks).to(device)
154 | output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device)
155 | if not USE_CUDA:
156 | xm.mark_step()
157 |
158 | # Prefill up to min_prompt_len tokens, then treat other prefill as decode and ignore output.
159 | for i in range(max_seq_len - min_prompt_len):
160 | next_token_ids, _ = model(
161 | input_token_ids=input_token_ids_tensor,
162 | input_positions=input_positions_tensor,
163 | kv_write_indices=None,
164 | kv_caches=kv_caches,
165 | mask=curr_mask_tensor,
166 | output_positions=output_positions_tensor,
167 | temperatures=temperatures_tensor,
168 | top_ps=top_ps_tensor,
169 | top_ks=top_ks_tensor,
170 | )
171 | curr_prompt_mask = prompt_mask_tensor.index_select(
172 | 1, output_index).squeeze(dim=1)
173 | curr_token_ids = token_ids_tensor.index_select(
174 | 1, output_index).squeeze(dim=1)
175 | output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
176 | next_token_ids).unsqueeze(dim=1)
177 | token_ids_tensor.index_copy_(1, output_index, output_token_ids)
178 |
179 | input_token_ids_tensor = output_token_ids
180 | input_positions_tensor = output_index.unsqueeze(dim=-1)
181 | curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
182 | output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(device)
183 | output_index = output_index + 1
184 | if not USE_CUDA:
185 | xm.mark_step()
186 |
187 | # Detokenization.
188 | token_ids = token_ids_tensor.tolist()
189 | results = []
190 | for i, tokens in enumerate(token_ids):
191 | trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i]) +
192 | output_lens[i]]
193 | if tokenizer.eos_id in trimmed_output:
194 | eos_index = trimmed_output.index(tokenizer.eos_id)
195 | trimmed_output = trimmed_output[:eos_index]
196 | results.append(tokenizer.decode(trimmed_output))
197 |
198 | for prompt, result in zip(prompts, results):
199 | print('======================================')
200 | print(f'PROMPT: {prompt}')
201 | print(f'RESULT: {result}')
202 | print('======================================')
203 |
204 |
205 | def main(args):
206 | model_config = get_model_config(args.variant)
207 | model_config.quant = args.quant
208 |
209 | prompts = [args.prompt]
210 | n = len(prompts)
211 | output_lengths = [args.output_len] * n
212 | temperatures = [0.95] * n
213 | top_ps = [1.0] * n
214 | top_ks = [100] * n
215 |
216 | if USE_CUDA:
217 | os.environ['MASTER_ADDR'] = '127.0.0.1'
218 | os.environ['MASTER_PORT'] = MASTER_PORT
219 | if not torch.distributed.is_initialized():
220 | torch.distributed.init_process_group(
221 | "nccl",
222 | rank=int(os.environ.get("RANK", 0)),
223 | world_size=int(os.environ.get("WORLD_SIZE", 1)))
224 | xla_model_parallel.set_g_group()
225 | torch.multiprocessing.spawn(
226 | generate,
227 | args=(
228 | model_config,
229 | args.ckpt,
230 | prompts,
231 | output_lengths,
232 | temperatures,
233 | top_ps,
234 | top_ks,
235 | args.seed,
236 | ),
237 | )
238 | else:
239 | xmp.spawn(
240 | generate,
241 | args=(
242 | model_config,
243 | args.ckpt,
244 | prompts,
245 | output_lengths,
246 | temperatures,
247 | top_ps,
248 | top_ks,
249 | args.seed,
250 | ),
251 | )
252 |
253 |
254 | if __name__ == '__main__':
255 | parser = argparse.ArgumentParser()
256 | parser.add_argument("--ckpt", type=str, required=True)
257 | parser.add_argument("--variant",
258 | type=str,
259 | default="2b",
260 | choices=["2b", "2b-v2", "7b", "9b", "27b"])
261 | parser.add_argument("--output_len", type=int, default=4)
262 | parser.add_argument("--seed", type=int, default=12345)
263 | parser.add_argument("--quant", action='store_true')
264 | parser.add_argument("--prompt", type=str, default="The meaning of life is")
265 | args = parser.parse_args()
266 |
267 | main(args)
268 |
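For reference, a typical launch might look something like USE_CUDA=1 python scripts/run_xla.py --ckpt <checkpoint> --variant 2b --prompt "The meaning of life is" on GPU, while omitting USE_CUDA runs the same command on the available XLA devices via xmp.spawn (checkpoint paths are placeholders).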
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/gemma/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Gemma model config."""
16 |
17 | import dataclasses
18 | import enum
19 | import os
20 | import torch
21 | from typing import Optional, Sequence
22 | from .siglip_vision import config as siglip_vision_config
23 |
24 |
25 | # Keep a mapping from dtype strings to the supported torch dtypes.
26 | _STR_DTYPE_TO_TORCH_DTYPE = dict({
27 | 'float16': torch.float16,
28 | 'float': torch.float32,
29 | 'float32': torch.float32,
30 | 'bfloat16': torch.bfloat16,
31 | })
32 |
33 |
34 | class AttentionType(enum.Enum):
35 | GLOBAL = 1
36 | LOCAL_SLIDING = 2
37 |
38 |
39 | class Architecture(enum.Enum):
40 | GEMMA_1 = 1
41 | GEMMA_2 = 2
42 | GEMMA_3 = 3
43 |
44 |
45 | @dataclasses.dataclass
46 | class GemmaConfig:
47 | # The architecture of the model.
48 | architecture: Architecture = Architecture.GEMMA_1
49 | # The number of tokens in the vocabulary.
50 | vocab_size: int = 256000
51 | # The maximum sequence length that this model might ever be used with.
52 | max_position_embeddings: int = 8192
53 | # The number of blocks in the model.
54 | num_hidden_layers: int = 28
55 | # The number of attention heads used in the attention layers of the model.
56 | num_attention_heads: int = 16
57 | # The number of key-value heads for implementing attention.
58 | num_key_value_heads: int = 16
59 | # The hidden size of the model.
60 | hidden_size: int = 3072
61 | # The dimension of the MLP representations.
62 | intermediate_size: int = 24576
63 | # The dimension of each attention head.
64 | head_dim: int = 256
65 | # The epsilon used by the rms normalization layers.
66 | rms_norm_eps: float = 1e-6
67 | # The dtype of the weights.
68 | dtype: str = 'bfloat16'
69 | # Whether a quantized version of the model is used.
70 | quant: bool = False
71 | # The path to the model tokenizer.
72 | tokenizer: Optional[str] = (
73 | 'tokenizer/tokenizer.model'
74 | )
75 | # The types of attention used in the layers of the model.
76 | attn_types: Optional[Sequence[AttentionType]] = None
77 | # The size of the sliding window used for local attention.
78 | sliding_window_size: Optional[int] = None
79 | # If provided, the final logits are softcapped to this value.
80 | final_logit_softcapping: Optional[float] = None
81 | # If provided, the attention logits are softcapped to this value.
82 | attn_logit_softcapping: Optional[float] = None
83 | # If provided, the query vector is normalized using the
84 | # inverse square root of this value instead of head_dim.
85 | query_pre_attn_scalar: Optional[int] = None
86 | # Whether to use pre mlp normalization.
87 | use_pre_ffw_norm: bool = False
88 | # Whether to use post mlp normalization.
89 | use_post_ffw_norm: bool = False
90 | # The wave length of the rotary embedding.
91 | rope_wave_length: dict[AttentionType, int] | None = None
92 | # Whether to use QK normalization in the attention blocks.
93 | use_qk_norm: bool = False
94 | # Vision model config.
95 | vision_config: siglip_vision_config.SiglipVisionModelConfig | None = None
96 | # The factor by which the rope wave length is divided for global layers.
97 | rope_scaling_factor: int | None = None
98 |
99 | def get_dtype(self) -> Optional[torch.dtype]:
100 | """Gets the torch dtype from the config dtype string."""
101 | return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None)
102 |
103 |
104 | def get_config_for_7b(dtype: str = 'bfloat16') -> GemmaConfig:
105 | return GemmaConfig(dtype=dtype)
106 |
107 |
108 | def get_config_for_2b(dtype: str = 'bfloat16') -> GemmaConfig:
109 | return GemmaConfig(
110 | dtype=dtype,
111 | num_hidden_layers=18,
112 | num_attention_heads=8,
113 | num_key_value_heads=1,
114 | hidden_size=2048,
115 | intermediate_size=16384,
116 | )
117 |
118 |
119 | def get_config_for_2b_v2(dtype: str = 'bfloat16') -> GemmaConfig:
120 | return GemmaConfig(
121 | dtype=dtype,
122 | architecture=Architecture.GEMMA_2,
123 | num_hidden_layers=26,
124 | num_attention_heads=8,
125 | num_key_value_heads=4,
126 | hidden_size=2304,
127 | intermediate_size=9216,
128 | use_pre_ffw_norm=True,
129 | use_post_ffw_norm=True,
130 | final_logit_softcapping=30.0,
131 | attn_logit_softcapping=50.0,
132 | head_dim=256,
133 | attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL] * 13,
134 | sliding_window_size=4096,
135 | )
136 |
137 |
138 | def get_config_for_9b(dtype: str = 'bfloat16') -> GemmaConfig:
139 | return GemmaConfig(
140 | dtype=dtype,
141 | architecture=Architecture.GEMMA_2,
142 | num_hidden_layers=42,
143 | num_attention_heads=16,
144 | num_key_value_heads=8,
145 | hidden_size=3584,
146 | intermediate_size=14336,
147 | use_pre_ffw_norm=True,
148 | use_post_ffw_norm=True,
149 | final_logit_softcapping=30.0,
150 | attn_logit_softcapping=50.0,
151 | head_dim=256,
152 | attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL] * 21,
153 | sliding_window_size=4096,
154 | )
155 |
156 |
157 | def get_config_for_27b(dtype: str = 'bfloat16') -> GemmaConfig:
158 | return GemmaConfig(
159 | dtype=dtype,
160 | architecture=Architecture.GEMMA_2,
161 | num_hidden_layers=46,
162 | num_attention_heads=32,
163 | num_key_value_heads=16,
164 | hidden_size=4608,
165 | intermediate_size=36864,
166 | use_pre_ffw_norm=True,
167 | use_post_ffw_norm=True,
168 | final_logit_softcapping=30.0,
169 | attn_logit_softcapping=50.0,
170 | head_dim=128,
171 | attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL] * 23,
172 | sliding_window_size=4096,
173 | query_pre_attn_scalar=144, # hidden_size / num_attention_heads
174 | )
175 |
176 |
177 | def get_config_for_1b(dtype: str) -> GemmaConfig:
178 | return GemmaConfig(
179 | dtype=dtype,
180 | architecture=Architecture.GEMMA_3,
181 | num_hidden_layers=26,
182 | num_attention_heads=4,
183 | num_key_value_heads=1,
184 | hidden_size=1152,
185 | intermediate_size=6912,
186 | use_pre_ffw_norm=True,
187 | use_post_ffw_norm=True,
188 | head_dim=256,
189 | attn_types=(
190 | AttentionType.LOCAL_SLIDING,
191 | AttentionType.LOCAL_SLIDING,
192 | AttentionType.LOCAL_SLIDING,
193 | AttentionType.LOCAL_SLIDING,
194 | AttentionType.LOCAL_SLIDING,
195 | AttentionType.GLOBAL,
196 | ),
197 | sliding_window_size=512,
198 | rope_wave_length={
199 | AttentionType.LOCAL_SLIDING: 10_000,
200 | AttentionType.GLOBAL: 1_000_000,
201 | },
202 | vocab_size=262_144,
203 | max_position_embeddings=32_768,
204 | tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model',
205 | use_qk_norm=True,
206 | vision_config=None,
207 | )
208 |
209 |
210 | def get_config_for_4b(dtype: str) -> GemmaConfig:
211 | return GemmaConfig(
212 | dtype=dtype,
213 | architecture=Architecture.GEMMA_3,
214 | num_hidden_layers=34,
215 | num_attention_heads=8,
216 | num_key_value_heads=4,
217 | hidden_size=2560,
218 | intermediate_size=10240,
219 | use_pre_ffw_norm=True,
220 | use_post_ffw_norm=True,
221 | head_dim=256,
222 | attn_types=(
223 | AttentionType.LOCAL_SLIDING,
224 | AttentionType.LOCAL_SLIDING,
225 | AttentionType.LOCAL_SLIDING,
226 | AttentionType.LOCAL_SLIDING,
227 | AttentionType.LOCAL_SLIDING,
228 | AttentionType.GLOBAL,
229 | ),
230 | sliding_window_size=1024,
231 | rope_wave_length={
232 | AttentionType.LOCAL_SLIDING: 10_000,
233 | AttentionType.GLOBAL: 1_000_000,
234 | },
235 | vocab_size=262_144,
236 | tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model',
237 | use_qk_norm=True,
238 | vision_config=siglip_vision_config.get_siglip_vision_model_config(),
239 | rope_scaling_factor=8,
240 | )
241 |
242 |
243 | def get_config_for_12b(dtype: str) -> GemmaConfig:
244 | return GemmaConfig(
245 | dtype=dtype,
246 | architecture=Architecture.GEMMA_3,
247 | num_hidden_layers=48,
248 | num_attention_heads=16,
249 | num_key_value_heads=8,
250 | hidden_size=3840,
251 | intermediate_size=3840 * 8 // 2,
252 | use_pre_ffw_norm=True,
253 | use_post_ffw_norm=True,
254 | head_dim=256,
255 | attn_types=(
256 | AttentionType.LOCAL_SLIDING,
257 | AttentionType.LOCAL_SLIDING,
258 | AttentionType.LOCAL_SLIDING,
259 | AttentionType.LOCAL_SLIDING,
260 | AttentionType.LOCAL_SLIDING,
261 | AttentionType.GLOBAL,
262 | ),
263 | sliding_window_size=1024,
264 | rope_wave_length={
265 | AttentionType.LOCAL_SLIDING: 10_000,
266 | AttentionType.GLOBAL: 1_000_000,
267 | },
268 | vocab_size=262_144,
269 | max_position_embeddings=131_072,
270 | tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model',
271 | use_qk_norm=True,
272 | vision_config=siglip_vision_config.get_siglip_vision_model_config(),
273 | rope_scaling_factor=8,
274 | )
275 |
276 |
277 | def get_config_for_27b_v3(dtype: str) -> GemmaConfig:
278 | return GemmaConfig(
279 | dtype=dtype,
280 | architecture=Architecture.GEMMA_3,
281 | num_hidden_layers=62,
282 | num_attention_heads=32,
283 | num_key_value_heads=16,
284 | hidden_size=5376,
285 | intermediate_size=5376 * 8 // 2,
286 | use_pre_ffw_norm=True,
287 | use_post_ffw_norm=True,
288 | head_dim=128,
289 | query_pre_attn_scalar=5376 // 32,
290 | attn_types=(
291 | AttentionType.LOCAL_SLIDING,
292 | AttentionType.LOCAL_SLIDING,
293 | AttentionType.LOCAL_SLIDING,
294 | AttentionType.LOCAL_SLIDING,
295 | AttentionType.LOCAL_SLIDING,
296 | AttentionType.GLOBAL,
297 | ),
298 | sliding_window_size=1024,
299 | rope_wave_length={
300 | AttentionType.LOCAL_SLIDING: 10_000,
301 | AttentionType.GLOBAL: 1_000_000,
302 | },
303 | vocab_size=262_144,
304 | max_position_embeddings=131_072,
305 | tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model',
306 | use_qk_norm=True,
307 | vision_config=siglip_vision_config.get_siglip_vision_model_config(),
308 | rope_scaling_factor=8,
309 | )
310 |
311 |
312 | def get_model_config(variant: str, dtype: str = 'bfloat16') -> GemmaConfig:
313 | """Gets the GemmaConfig for the diresired variant and dtype."""
314 | # Gemma1 variants
315 | if variant == '7b':
316 | return get_config_for_7b(dtype)
317 | elif variant == '2b':
318 | return get_config_for_2b(dtype)
319 | # Gemma2 variants
320 | elif variant == '2b-v2':
321 | return get_config_for_2b_v2(dtype)
322 | elif variant == '9b':
323 | return get_config_for_9b(dtype)
324 | elif variant == '27b':
325 | return get_config_for_27b(dtype)
326 | # Gemma3 variants
327 | elif variant == '1b':
328 | return get_config_for_1b(dtype)
329 | elif variant == '4b':
330 | return get_config_for_4b(dtype)
331 | elif variant == '12b':
332 | return get_config_for_12b(dtype)
333 | elif variant == '27b_v3':
334 | return get_config_for_27b_v3(dtype)
335 | # Invalid variants
336 | else:
337 | raise ValueError(
338 | f'Invalid variant {variant}. Supported variants are "1b", "2b", '
339 | '"2b-v2", "4b",, "7b", "9b" "12b", "27b", and "27b_v3".'
340 | )
341 |
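For reference, a small sketch of fetching and inspecting a variant config (the values follow the definitions above):

import torch

from gemma import config as gemma_config

cfg = gemma_config.get_model_config('12b', dtype='bfloat16')
assert cfg.architecture == gemma_config.Architecture.GEMMA_3
assert cfg.num_hidden_layers == 48
assert cfg.get_dtype() == torch.bfloat16
# A 6-entry pattern: five local sliding-window entries followed by one global.
print(cfg.attn_types, cfg.sliding_window_size)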
--------------------------------------------------------------------------------
/gemma/gemma3_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Inference-only Gemma 3 multimodal model implementation."""
15 |
16 | import torch
17 | import os
18 | import json
19 | import gc
20 | from torch import nn
21 | from PIL import Image
22 | from typing import Any, List, Sequence, Tuple, Union
23 |
24 | from . import model as gemma_model
25 | from . import config as gemma_config
26 | from . import gemma3_preprocessor
27 | from . import tokenizer
28 | from .siglip_vision import siglip_vision_model
29 |
30 | class Gemma3ForMultimodalLM(nn.Module):
31 | """Gemma3 model for multimodal causal LM."""
32 | def __init__(
33 | self,
34 | config: gemma_config.GemmaConfig,
35 | ):
36 | super().__init__()
37 | self.dtype = config.get_dtype()
38 | assert config.architecture == gemma_config.Architecture.GEMMA_3
39 | self.config = config
40 | max_seq_len = config.max_position_embeddings
41 | head_dim = config.head_dim
42 | vocab_size = config.vocab_size
43 | self.tokenizer = tokenizer.Tokenizer(config.tokenizer)
44 | self.text_token_embedder = gemma_model.Embedding(vocab_size, config.hidden_size, config.quant)
45 | self.model = gemma_model.GemmaModel(config)
46 | self.sampler = gemma_model.Sampler(vocab_size, config)
47 |
48 | if config.vision_config is None:
49 | raise ValueError('vision_config must be provided for Gemma3.')
50 | self.siglip_vision_model = siglip_vision_model.SiglipVisionModel(config.vision_config)
51 | # transformer/embedder/mm_soft_embedding_norm
52 | self.mm_soft_embedding_norm = gemma_model.RMSNorm(config.vision_config.embedding_dim,
53 | eps = config.rms_norm_eps)
54 | # transformer/embedder/mm_input_projection
55 | self.mm_input_projection = gemma_model.Linear(config.vision_config.embedding_dim, config.hidden_size, config.quant)
56 |
57 | if config.rope_wave_length is None:
58 | raise ValueError('rope_wave_length must be provided for Gemma3.')
59 | rope_lengths = config.rope_wave_length
60 | defaults = {
61 | gemma_config.AttentionType.LOCAL_SLIDING: 10_000,
62 | gemma_config.AttentionType.GLOBAL: 10_000,
63 | }
64 | self._register_freqs_cis('local_freqs_cis', head_dim, max_seq_len, theta=rope_lengths.get(
65 | gemma_config.AttentionType.LOCAL_SLIDING, defaults[gemma_config.AttentionType.LOCAL_SLIDING]
66 | ))
67 | self._register_freqs_cis('global_freqs_cis', head_dim, max_seq_len, theta=rope_lengths.get(
68 | gemma_config.AttentionType.GLOBAL, defaults[gemma_config.AttentionType.GLOBAL]
69 | ), rope_scaling_factor=config.rope_scaling_factor)
70 |
71 | def _register_freqs_cis(
72 | self, name: str, head_dim: int, max_seq_len: int, theta: int = 10_000, rope_scaling_factor: int = 1
73 | ):
74 | self.register_buffer(
75 | name, gemma_model.precompute_freqs_cis(head_dim, max_seq_len * 2, theta=theta, rope_scaling_factor=rope_scaling_factor)
76 | )
77 |
78 | @torch.no_grad()
79 | def forward(self,
80 | input_token_ids: torch.Tensor, # B x L
81 | image_patches: torch.Tensor, # B x N x C x H x W (3x896x896)
82 | image_presence_mask: torch.Tensor, # B x N
83 | input_positions: torch.Tensor,
84 | kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
85 | mask: torch.Tensor,
86 | output_positions: torch.Tensor,
87 | temperatures: Union[torch.Tensor, None],
88 | top_ps: torch.Tensor,
89 | top_ks: torch.Tensor,
90 | local_mask: torch.Tensor | None = None,
91 | **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
92 | freqs_cis = {}
93 | freqs_cis[gemma_config.AttentionType.LOCAL_SLIDING] = (
94 | self.local_freqs_cis.index_select(0, input_positions)
95 | )
96 | freqs_cis[gemma_config.AttentionType.GLOBAL] = (
97 | self.global_freqs_cis.index_select(0, input_positions)
98 | )
99 | hidden_states = self.text_token_embedder(input_token_ids)
100 | normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype, device=hidden_states.device)
101 | hidden_states = hidden_states * normalizer
102 | if image_patches is not None and self.config.vision_config is not None:
103 | # the input has images
104 | B, N, C, H, W = image_patches.shape
105 | # Flatten and Pass to SiglipVisionModel, and apply SiglipVisionModel Exit
106 | flattened_input = image_patches.reshape(B * N, C, H, W) # (B*N)xCxHxW
107 | image_embeddings = self.siglip_vision_model(flattened_input) # (B*N)xUxD
108 | image_embeddings = self.mm_soft_embedding_norm(image_embeddings) # (B*N) x U x D
109 | image_embeddings = self.mm_input_projection(image_embeddings) # (B*N) x U x model_dim
110 | hidden_states = self.populate_image_embeddings(
111 | hidden_states.clone(),
112 | image_embeddings.clone(),
113 | input_token_ids.clone(),
114 | image_presence_mask.clone(),
115 | )
116 |
117 | kv_write_indices = input_positions
118 |
119 | hidden_states = self.model(
120 | hidden_states=hidden_states,
121 | freqs_cis=freqs_cis,
122 | kv_write_indices=kv_write_indices,
123 | kv_caches=kv_caches,
124 | mask=mask,
125 | local_mask=local_mask,
126 | )
127 | embedder_weight = self.text_token_embedder.weight
128 | if self.config.quant:
129 | embedder_weight = (
130 | embedder_weight * self.text_token_embedder.weight_scaler.unsqueeze(-1))
131 |
132 | next_tokens, logits = self.sampler(
133 | embedding=embedder_weight,
134 | hidden_states=hidden_states,
135 | output_positions=output_positions,
136 | temperatures=temperatures,
137 | top_ps=top_ps,
138 | top_ks=top_ks,
139 | )
140 | return next_tokens, logits
141 |
142 | def populate_image_embeddings(self,
143 | hidden_states: torch.Tensor, # B x L x model_dim
144 | image_embeddings: torch.Tensor, # (B*N) x U x model_dim
145 | input_token_ids: torch.Tensor, # B x L
146 | image_presence_mask: torch.Tensor, # B x N
147 | ):
148 | batch_size, seq_len, model_dim = hidden_states.shape
149 | # Step 1 of 2: Fetch valid image embeddings
150 | # flatten indices of valid image embeddings
151 | valid_image_embeddings_indices = torch.nonzero(image_presence_mask.flatten(), as_tuple=False).squeeze()
152 | # num_valid_images x model_dim
153 | valid_image_embeddings = image_embeddings.index_select(0, valid_image_embeddings_indices)
154 |
155 | # Step 2 of 2: Replace image embeddings at right places.
156 | image_placeholder_mask = input_token_ids == self.tokenizer.image_token_placeholder_id
157 | image_placeholder_indices = image_placeholder_mask.flatten().nonzero(as_tuple=False).squeeze()
158 |
159 | hidden_states = hidden_states.reshape(-1, self.config.hidden_size)
160 | hidden_states[image_placeholder_indices] = valid_image_embeddings.reshape(-1, self.config.hidden_size)
161 | return hidden_states.reshape(batch_size, seq_len, model_dim).contiguous()
162 |
163 | def create_attention_mask(self, input_ids: torch.Tensor, sequence_length: int):
164 | batch_size = input_ids.shape[0]
165 | causal_mask = torch.tril(torch.ones((batch_size, 1, sequence_length, sequence_length), dtype=torch.bool, device=input_ids.device))
166 | image_token_mask = input_ids == self.tokenizer.image_token_placeholder_id
167 | # Pad the mask to the left with 0. This is to make sure the boundary
168 | # detection works correctly. Boundary (starting index of image patch) is
169 | # detected when the value changes from 0 to 1.
170 | padded_mask = nn.functional.pad(image_token_mask, (1, 0), value=0)
171 | # Find the boundary (starting index) of the image tokens patch.
172 | boundary = padded_mask[:, 1:] > padded_mask[:, :-1]
173 | # Number the boundary.
174 | # boundary:
175 | # [[False, False, True, False, False],
176 | # [False, True, False, True, False]]
177 | # numbered_boundary:
178 | # [[0, 0, 1, 1, 1],
179 | # [0, 1, 1, 2, 2]]
180 | numbered_boundary = torch.cumsum(boundary, dim=-1)
181 |
182 | # image_token_mask:
183 | # [[False, False, True, True, False],
184 | # [True, True, False, True, True]]
185 | # numbered_boundary:
186 | # [[0, 0, 1, 1, 1],
187 | # [1, 1, 1, 2, 2]]
188 | # q_block_indices:
189 | # [[0, 0, 1, 1, 0],
190 | # [1, 1, 0, 2, 2]]
191 | q_block_indices = image_token_mask * numbered_boundary
192 | kv_block_indices = q_block_indices
193 | # Test the equality of vertical and horizontal numbered patches
194 | # to create the bidirectional mask.
195 | bidirectional_mask = torch.logical_and(
196 | kv_block_indices[:, None, :] == q_block_indices.unsqueeze(-1),
197 | q_block_indices.unsqueeze(-1) > 0,
198 | )
199 | attention_mask = torch.logical_or(causal_mask, bidirectional_mask.unsqueeze(1))
200 | # The upper triangular matrix's diagonal is shifted by sliding window size
201 | # before doing logical 'and' with attention mask. This is to make sure the
202 | # local attention is within the sliding window.
203 | local_mask = torch.logical_and(
204 | attention_mask,
205 | torch.triu(torch.ones((1, 1, sequence_length, sequence_length), dtype=torch.bool, device=input_ids.device), diagonal=-(self.config.sliding_window_size-1))
206 | )
207 | return attention_mask, local_mask
208 |
209 | def generate(
210 | self,
211 | prompts: Sequence[Sequence[Union[str, Image.Image]]],
212 | device: Any,
213 | output_len: int = 100,
214 | temperature: Union[float, None] = 1.0,
215 | top_p: float = 0.95,
216 | top_k: int = 64,
217 | ) -> Sequence[str]:
218 | """Generates responses for given prompts using Gemma model."""
219 | # Inference only.
220 | processing_result = gemma3_preprocessor.tokenize_raw_input(
221 | self.tokenizer, prompts, self.config, output_len, device
222 | )
223 | batch_size = processing_result["batch_size"]
224 | user_input_token_ids = processing_result["user_input_token_ids"]
225 | image_batch = processing_result["image_batch"]
226 | min_prompt_len = processing_result["min_prompt_len"]
227 | max_prompt_len = processing_result["max_prompt_len"]
228 | total_seq_len = processing_result["max_seq_len"]
229 | image_presence_mask = processing_result["image_presence_mask"]
230 |
231 | # Create attention mask.
232 | min_dtype = torch.finfo(self.dtype).min
233 | if self.config.sliding_window_size is None:
234 | raise ValueError('Gemma 3 model requires sliding_window_size.')
235 | boolean_mask, local_boolean_mask = self.create_attention_mask(user_input_token_ids, total_seq_len)
236 | mask_tensor = torch.where(boolean_mask, 0, torch.tensor(min_dtype, dtype=torch.float32, device=device)).contiguous()
237 | local_mask_tensor = torch.where(local_boolean_mask, 0, torch.tensor(min_dtype, dtype=torch.float32, device=device)).contiguous()
238 |
239 | kv_caches = []
240 | for _ in range(self.config.num_hidden_layers):
241 | size = (batch_size, total_seq_len, self.config.num_key_value_heads,
242 | self.config.head_dim)
243 | dtype = self.config.get_dtype()
244 | k_cache = torch.zeros(size=size, dtype=dtype, device=device)
245 | v_cache = torch.zeros(size=size, dtype=dtype, device=device)
246 | kv_caches.append((k_cache, v_cache))
247 |
248 | input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
249 | self.tokenizer.pad_id,
250 | dtype=torch.int64, device=device)
251 | token_ids_tensor = user_input_token_ids.to(device)
252 | for i in range(batch_size):
253 | p = user_input_token_ids[i]
254 | input_token_ids_tensor[i, :min_prompt_len] = p[:min_prompt_len]
255 |
256 | input_positions_tensor = torch.arange(0, min_prompt_len, dtype=torch.int64, device=device)
257 | prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id
258 | curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
259 | curr_local_mask_tensor = local_mask_tensor.index_select(2, input_positions_tensor)
260 | output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device)
261 | temperatures_tensor = None if not temperature else torch.FloatTensor(
262 | [temperature] * batch_size).to(device)
263 | top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
264 | top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
265 | output_index = torch.tensor(min_prompt_len, dtype=torch.int64, device=device)
266 |
267 | # Prefill up to min_prompt_len tokens, then treat the remaining prefill
268 | # steps as decode steps and ignore their output.
269 | for i in range(total_seq_len - min_prompt_len):
270 | next_token_ids, _ = self(
271 | input_token_ids=input_token_ids_tensor,
272 | image_patches=image_batch,
273 | image_presence_mask=image_presence_mask,
274 | input_positions=input_positions_tensor,
275 | kv_caches=kv_caches,
276 | mask=curr_mask_tensor,
277 | output_positions=output_positions_tensor,
278 | temperatures=temperatures_tensor,
279 | top_ps=top_ps_tensor,
280 | top_ks=top_ks_tensor,
281 | local_mask=curr_local_mask_tensor,
282 | )
283 | curr_prompt_mask = prompt_mask_tensor.index_select(
284 | 1, output_index).squeeze(dim=1)
285 | curr_token_ids = token_ids_tensor.index_select(
286 | 1, output_index).squeeze(dim=1)
287 | output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
288 | next_token_ids).unsqueeze(dim=1)
289 | token_ids_tensor.index_copy_(1, output_index, output_token_ids)
290 |
291 | input_token_ids_tensor = output_token_ids
292 | input_positions_tensor = output_index.unsqueeze(dim=-1)
293 | curr_mask_tensor = mask_tensor.index_select(2,
294 | input_positions_tensor)
295 | curr_local_mask_tensor = local_mask_tensor.index_select(
296 | 2, input_positions_tensor
297 | ) if local_mask_tensor is not None else None
298 | output_positions_tensor = torch.tensor(0, dtype=torch.int64, device=device)
299 | output_index = output_index + 1
300 | image_batch = None
301 | image_presence_mask = None
302 |
303 | # Detokenization.
304 | token_ids = token_ids_tensor.tolist()
305 | results = []
306 | for i, tokens in enumerate(token_ids):
307 | output = tokens
308 | if self.tokenizer.eos_id in output:
309 | eos_index = output.index(self.tokenizer.eos_id)
310 | output = output[:eos_index]
311 | results.append(self.tokenizer.decode(output))
312 |
313 | return results
314 |
315 | def load_weights(self, model_path: str):
316 | if os.path.isfile(model_path):
317 | self.load_state_dict(
318 | torch.load(
319 | model_path, mmap=True, weights_only=True,
320 | )['model_state_dict'],
321 | strict=False,
322 | )
323 | else:
324 | index_path = os.path.join(model_path, 'pytorch_model.bin.index.json')
325 | with open(index_path, "r", encoding="utf-8") as f:
326 | index = json.load(f)
327 | shard_files = list(set(index["weight_map"].values()))
328 | for shard_file in shard_files:
329 | shard_path = os.path.join(model_path, shard_file)
330 | state_dict = torch.load(shard_path, map_location="cpu", weights_only=True)
331 | self.load_state_dict(state_dict, strict=False)
332 | del state_dict # Save memory.
333 | gc.collect()
334 |
--------------------------------------------------------------------------------
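Note: the create_attention_mask helper above combines a global causal mask with a local sliding-window mask: the causal mask lets each position attend to every earlier position, while the local mask additionally restricts attention to the most recent sliding_window_size positions. A minimal, self-contained sketch of that banding trick (seq_len and window are illustrative values, not taken from the model config):

    import torch

    seq_len, window = 8, 3  # illustrative values only

    # Global (causal) mask: position i may attend to any position j <= i.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

    # Shifting the upper-triangular diagonal by -(window - 1) keeps a band of
    # width `window` ending at the diagonal, mirroring the torch.triu call in
    # create_attention_mask.
    band = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool),
                      diagonal=-(window - 1))
    local = causal & band

    # Every row of `local` keeps at most `window` positions, ending at i itself.
    assert int(local[-1].sum()) == window
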
/gemma/model_xla.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Inference-only Gemma model implementation."""
16 |
17 | import json
18 | import gc
19 | import os
20 | import re
21 | import torch
22 | from torch import nn
23 | import torch.nn.functional as F
24 | from typing import List, Mapping, Optional, Tuple, Union
25 |
26 | from google3.third_party.open_models_release.gemma_pytorch.gemma import config as gemma_config
27 | from google3.third_party.open_models_release.gemma_pytorch.gemma.xla_model_parallel import (
28 | ColumnParallelLinear,
29 | ParallelEmbedding,
30 | RowParallelLinear,
31 | reduce_from_model_parallel_region,
32 | scatter_to_model_parallel_region,
33 | )
34 |
35 |
36 | class Sampler(nn.Module):
37 |
38 | def __init__(self, vocab_size: int, world_size: int, rank: int,
39 | config: gemma_config.GemmaConfig) -> None:
40 | super().__init__()
41 | self.vocab_size = vocab_size
42 | self.world_size = world_size
43 | self.rank = rank
44 | self.config = config
45 |
46 | @torch.no_grad()
47 | def forward(
48 | self,
49 | embedding: torch.Tensor,
50 | hidden_states: torch.Tensor,
51 | output_positions: torch.Tensor,
52 | temperatures: Union[torch.Tensor, None],
53 | top_ps: torch.Tensor,
54 | top_ks: torch.Tensor,
55 | embedding_bias: Optional[torch.Tensor] = None,
56 | ) -> Tuple[torch.Tensor, torch.Tensor]:
57 | # Select the last element for each sequence.
58 | # (batch_size, input_len, hidden_size) -> (batch_size, hidden_size)
59 | hidden_states = hidden_states.index_select(
60 | 1, output_positions).squeeze(dim=1)
61 |
62 | hidden_states_parallel = scatter_to_model_parallel_region(
63 | hidden_states,
64 | groups=None,
65 | world_size=self.world_size,
66 | rank=self.rank)
67 | hidden_states_parallel = torch.matmul(hidden_states_parallel,
68 | embedding.t())
69 | logits = reduce_from_model_parallel_region(
70 | hidden_states_parallel,
71 | groups=None,
72 | world_size=self.world_size,
73 | rank=self.rank,
74 | )
75 | if embedding_bias is not None:
76 | logits += embedding_bias
77 | if self.config.final_logit_softcapping is not None:
78 | logits = logits / self.config.final_logit_softcapping
79 | logits = torch.tanh(logits)
80 | logits = logits * self.config.final_logit_softcapping
81 |
82 | if temperatures is None:
83 | return torch.argmax(logits, dim=-1).squeeze(dim=-1), logits
84 |
85 | # Apply temperature scaling.
86 | logits.div_(temperatures.unsqueeze(dim=1))
87 |
88 | # Calculate probabilities with softmax.
89 | probs = torch.softmax(logits, dim=-1, dtype=torch.float)
90 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
91 |
92 | # Apply top-p, top-k.
93 | probs_sum = torch.cumsum(probs_sort, dim=-1)
94 | top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
95 | probs_sort = torch.where(top_ps_mask, 0, probs_sort)
96 |
97 | top_ks_mask = torch.arange(probs_idx.shape[-1],
98 | device=probs_idx.device)
99 | top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
100 | top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
101 | probs_sort = torch.where(top_ks_mask, 0, probs_sort)
102 |
103 | # Re-normalization.
104 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
105 | probs = torch.gather(probs_sort,
106 | dim=-1,
107 | index=torch.argsort(probs_idx, dim=-1))
108 |
109 | next_token_ids = torch.multinomial(probs,
110 | num_samples=1,
111 | replacement=True).squeeze(dim=-1)
112 | return next_token_ids, logits
113 |
114 |
115 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
116 | """Precomputes the frequency cis."""
117 | freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
118 | t = torch.arange(end, device=freqs.device) # type: ignore
119 | freqs = torch.outer(t, freqs).float() # type: ignore
120 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
121 | return freqs_cis
122 |
123 |
124 | def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
125 | """Applies the rotary embedding to the query and key tensors."""
126 | x_ = torch.view_as_complex(
127 | torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),
128 | dim=-1))
129 | x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
130 | x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
131 | x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
132 | -1).transpose(1, 2)
133 | return x_out
134 |
135 |
136 | class RMSNorm(torch.nn.Module):
137 |
138 | def __init__(
139 | self,
140 | dim: int,
141 | eps: float = 1e-6,
142 | add_unit_offset: bool = True,
143 | ):
144 | super().__init__()
145 | self.eps = eps
146 | self.add_unit_offset = add_unit_offset
147 | self.weight = nn.Parameter(torch.ones(dim))
148 |
149 | def _norm(self, x):
150 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
151 |
152 | def forward(self, x):
153 | # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
154 | # See https://github.com/huggingface/transformers/pull/29402
155 | output = self._norm(x.float())
156 | if self.add_unit_offset:
157 | output = output * (1 + self.weight.float())
158 | else:
159 | output = output * self.weight.float()
160 | return output.type_as(x)
161 |
162 |
163 | class GemmaMLP(nn.Module):
164 |
165 | def __init__(
166 | self,
167 | hidden_size: int,
168 | intermediate_size: int,
169 | world_size: int,
170 | rank: int,
171 | quant: bool,
172 | ):
173 | super().__init__()
174 | self.hidden_size = hidden_size
175 | self.intermediate_size = intermediate_size
176 |
177 | def init_method(x):
178 | return x
179 |
180 | self.gate_proj = ColumnParallelLinear(
181 | hidden_size,
182 | intermediate_size,
183 | bias=False,
184 | gather_output=False,
185 | init_method=init_method,
186 | world_size=world_size,
187 | rank=rank,
188 | quant=quant,
189 | )
190 |
191 | self.up_proj = ColumnParallelLinear(
192 | hidden_size,
193 | intermediate_size,
194 | bias=False,
195 | gather_output=False,
196 | init_method=init_method,
197 | world_size=world_size,
198 | rank=rank,
199 | quant=quant,
200 | )
201 |
202 | self.down_proj = RowParallelLinear(
203 | intermediate_size,
204 | hidden_size,
205 | bias=False,
206 | input_is_parallel=True,
207 | init_method=init_method,
208 | world_size=world_size,
209 | rank=rank,
210 | quant=quant,
211 | )
212 |
213 | def forward(self, x):
214 | gate = self.gate_proj(x)
215 | gate = F.gelu(gate, approximate="tanh")
216 | up = self.up_proj(x)
217 | fuse = gate * up
218 | outputs = self.down_proj(fuse)
219 | return outputs
220 |
221 |
222 | class GemmaAttention(nn.Module):
223 |
224 | def __init__(
225 | self,
226 | hidden_size: int,
227 | num_heads: int,
228 | num_kv_heads: int,
229 | attn_logit_softcapping: Optional[float],
230 | query_pre_attn_scalar: Optional[int],
231 | head_dim: int,
232 | world_size: int,
233 | rank: int,
234 | quant: bool,
235 | attn_type: gemma_config.AttentionType,
236 | sliding_window_size: Optional[int] = None,
237 | ):
238 | super().__init__()
239 | self.rank = rank
240 |
241 | def init_method(x):
242 | return x
243 |
244 | self.total_num_heads = num_heads
245 | assert self.total_num_heads % world_size == 0
246 | self.num_heads = self.total_num_heads // world_size # head per shard
247 |
248 | if num_kv_heads < world_size:
249 | assert world_size % num_kv_heads == 0
250 | self.total_num_kv_heads = world_size
251 | else:
252 | assert num_kv_heads % world_size == 0
253 | self.total_num_kv_heads = num_kv_heads
254 | self.num_kv_heads = self.total_num_kv_heads // world_size # kv head per shard
255 |
256 | assert self.num_heads % self.num_kv_heads == 0
257 | self.num_queries_per_kv = self.num_heads // self.num_kv_heads
258 |
259 | self.hidden_size = hidden_size
260 | self.head_dim = head_dim
261 |
262 | self.q_size = self.num_heads * self.head_dim
263 | self.kv_size = self.num_kv_heads * self.head_dim
264 |
265 | if query_pre_attn_scalar is not None:
266 | self.scaling = query_pre_attn_scalar**-0.5
267 | else:
268 | self.scaling = self.head_dim**-0.5
269 |
270 | self.qkv_proj = ColumnParallelLinear(
271 | self.hidden_size,
272 | (self.total_num_heads + 2 * self.total_num_kv_heads) *
273 | self.head_dim,
274 | bias=False,
275 | gather_output=False,
276 | init_method=init_method,
277 | world_size=world_size,
278 | rank=rank,
279 | quant=quant,
280 | )
281 |
282 | self.o_proj = RowParallelLinear(
283 | self.total_num_heads * self.head_dim,
284 | self.hidden_size,
285 | bias=False,
286 | input_is_parallel=True,
287 | init_method=init_method,
288 | world_size=world_size,
289 | rank=rank,
290 | quant=quant,
291 | )
292 |
293 | self.attn_type = attn_type
294 | self.sliding_window_size = sliding_window_size
295 | self.attn_logit_softcapping = attn_logit_softcapping
296 |
297 | def forward(
298 | self,
299 | hidden_states: torch.Tensor,
300 | freqs_cis: torch.Tensor,
301 | kv_write_indices: torch.Tensor,
302 | kv_cache: Tuple[torch.Tensor, torch.Tensor],
303 | mask: torch.Tensor,
304 | ) -> torch.Tensor:
305 | hidden_states_shape = hidden_states.shape
306 | assert len(hidden_states_shape) == 3
307 |
308 | batch_size, input_len, _ = hidden_states_shape
309 |
310 | qkv = self.qkv_proj(hidden_states)
311 | xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
312 | dim=-1)
313 |
314 | xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
315 | xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
316 | xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)
317 |
318 | # Positional embedding.
319 | xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
320 | xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)
321 |
322 | # Write new kv cache.
323 | # [batch_size, input_len, n_local_kv_heads, head_dim]
324 | k_cache, v_cache = kv_cache
325 | k_cache.index_copy_(1, kv_write_indices, xk)
326 | v_cache.index_copy_(1, kv_write_indices, xv)
327 |
328 | key = k_cache
329 | value = v_cache
330 | if self.num_kv_heads != self.num_heads:
331 | # [batch_size, max_seq_len, n_local_heads, head_dim]
332 | key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
333 | value = torch.repeat_interleave(value,
334 | self.num_queries_per_kv,
335 | dim=2)
336 |
337 | # [batch_size, n_local_heads, input_len, head_dim]
338 | q = xq.transpose(1, 2)
339 | # [batch_size, n_local_heads, max_seq_len, head_dim]
340 | k = key.transpose(1, 2)
341 | v = value.transpose(1, 2)
342 |
343 | # [batch_size, n_local_heads, input_len, max_seq_len]
344 | q.mul_(self.scaling)
345 | scores = torch.matmul(q, k.transpose(2, 3))
346 | if (
347 | self.attn_type == gemma_config.AttentionType.LOCAL_SLIDING
348 | and self.sliding_window_size is not None
349 | ):
350 | all_ones = torch.ones_like(mask)
351 | sliding_mask = torch.triu(
352 | all_ones, -1 * self.sliding_window_size + 1
353 | ) * torch.tril(all_ones, self.sliding_window_size - 1)
354 | mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)
355 | if self.attn_logit_softcapping is not None:
356 | scores = scores / self.attn_logit_softcapping
357 | scores = torch.tanh(scores)
358 | scores = scores * self.attn_logit_softcapping
359 | scores = scores + mask
360 | scores = F.softmax(scores.float(), dim=-1).type_as(q)
361 |
362 | # [batch_size, n_local_heads, input_len, head_dim]
363 | output = torch.matmul(scores, v)
364 |
365 | # [batch_size, input_len, hidden_dim]
366 | output = (output.transpose(1, 2).contiguous().view(
367 | batch_size, input_len, -1))
368 | output = self.o_proj(output)
369 | return output
370 |
371 |
372 | class GemmaDecoderLayer(nn.Module):
373 |
374 | def __init__(
375 | self,
376 | config: gemma_config.GemmaConfig,
377 | world_size: int,
378 | rank: int,
379 | ):
380 | super().__init__()
381 | self.rank = rank
382 | self.self_attn = GemmaAttention(
383 | hidden_size=config.hidden_size,
384 | num_heads=config.num_attention_heads,
385 | num_kv_heads=config.num_key_value_heads,
386 | attn_logit_softcapping=config.attn_logit_softcapping,
387 | query_pre_attn_scalar=config.query_pre_attn_scalar,
388 | head_dim=config.head_dim,
389 | world_size=world_size,
390 | rank=rank,
391 | quant=config.quant,
392 | attn_type=gemma_config.AttentionType.GLOBAL,
393 | )
394 | self.mlp = GemmaMLP(
395 | hidden_size=config.hidden_size,
396 | intermediate_size=config.intermediate_size,
397 | world_size=world_size,
398 | rank=rank,
399 | quant=config.quant,
400 | )
401 | self.input_layernorm = RMSNorm(config.hidden_size,
402 | eps=config.rms_norm_eps)
403 | self.post_attention_layernorm = RMSNorm(config.hidden_size,
404 | eps=config.rms_norm_eps)
405 |
406 | def forward(
407 | self,
408 | hidden_states: torch.Tensor,
409 | freqs_cis: torch.Tensor,
410 | kv_write_indices: torch.Tensor,
411 | kv_cache: Tuple[torch.Tensor, torch.Tensor],
412 | mask: torch.Tensor,
413 | ) -> torch.Tensor:
414 | # Self Attention
415 | residual = hidden_states
416 | hidden_states = self.input_layernorm(hidden_states)
417 | hidden_states = self.self_attn(
418 | hidden_states=hidden_states,
419 | freqs_cis=freqs_cis,
420 | kv_write_indices=kv_write_indices,
421 | kv_cache=kv_cache,
422 | mask=mask,
423 | )
424 | hidden_states = residual + hidden_states
425 |
426 | # MLP
427 | residual = hidden_states
428 | hidden_states = self.post_attention_layernorm(hidden_states)
429 | hidden_states = self.mlp(hidden_states)
430 | hidden_states = residual + hidden_states
431 |
432 | return hidden_states
433 |
434 |
435 | class Gemma2DecoderLayer(nn.Module):
436 |
437 | def __init__(
438 | self,
439 | config: gemma_config.GemmaConfig,
440 | attn_type: gemma_config.AttentionType,
441 | world_size: int,
442 | rank: int,
443 | ):
444 | super().__init__()
445 | self.rank = rank
446 | self.self_attn = GemmaAttention(
447 | hidden_size=config.hidden_size,
448 | num_heads=config.num_attention_heads,
449 | num_kv_heads=config.num_key_value_heads,
450 | attn_logit_softcapping=config.attn_logit_softcapping,
451 | query_pre_attn_scalar=config.query_pre_attn_scalar,
452 | head_dim=config.head_dim,
453 | world_size=world_size,
454 | rank=rank,
455 | quant=config.quant,
456 | attn_type=attn_type,
457 | sliding_window_size=config.sliding_window_size,
458 | )
459 | self.mlp = GemmaMLP(
460 | hidden_size=config.hidden_size,
461 | intermediate_size=config.intermediate_size,
462 | world_size=world_size,
463 | rank=rank,
464 | quant=config.quant,
465 | )
466 | self.input_layernorm = RMSNorm(config.hidden_size,
467 | eps=config.rms_norm_eps)
468 | self.post_attention_layernorm = RMSNorm(config.hidden_size,
469 | eps=config.rms_norm_eps)
470 | self.pre_feedforward_layernorm = (
471 | RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
472 | if config.use_pre_ffw_norm
473 | else None
474 | )
475 | self.post_feedforward_layernorm = (
476 | RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
477 | if config.use_post_ffw_norm
478 | else None
479 | )
480 |
481 | def forward(
482 | self,
483 | hidden_states: torch.Tensor,
484 | freqs_cis: torch.Tensor,
485 | kv_write_indices: torch.Tensor,
486 | kv_cache: Tuple[torch.Tensor, torch.Tensor],
487 | mask: torch.Tensor,
488 | ) -> torch.Tensor:
489 | # Self Attention
490 | residual = hidden_states
491 | hidden_states = self.input_layernorm(hidden_states)
492 | hidden_states = self.self_attn(
493 | hidden_states=hidden_states,
494 | freqs_cis=freqs_cis,
495 | kv_write_indices=kv_write_indices,
496 | kv_cache=kv_cache,
497 | mask=mask,
498 | )
499 | hidden_states = self.post_attention_layernorm(hidden_states)
500 | hidden_states = residual + hidden_states
501 |
502 | # MLP
503 | residual = hidden_states
504 | if self.pre_feedforward_layernorm is not None:
505 | hidden_states = self.pre_feedforward_layernorm(hidden_states)
506 | hidden_states = self.mlp(hidden_states)
507 | if self.post_feedforward_layernorm is not None:
508 | hidden_states = self.post_feedforward_layernorm(hidden_states)
509 | hidden_states = residual + hidden_states
510 |
511 | return hidden_states
512 |
513 |
514 | class GemmaModel(nn.Module):
515 |
516 | def __init__(
517 | self,
518 | config: gemma_config.GemmaConfig,
519 | world_size: int,
520 | rank: int
521 | ):
522 | super().__init__()
523 | self.config = config
524 | self.rank = rank
525 | self.vocab_size = config.vocab_size
526 |
527 | self.layers = nn.ModuleList()
528 | for i in range(config.num_hidden_layers):
529 | if config.architecture == gemma_config.Architecture.GEMMA_1:
530 | self.layers.append(GemmaDecoderLayer(config, world_size, rank))
531 | elif config.architecture == gemma_config.Architecture.GEMMA_2:
532 | attn_type = (
533 | config.attn_types[i]
534 | if config.attn_types is not None
535 | else gemma_config.AttentionType.GLOBAL
536 | )
537 | self.layers.append(
538 | Gemma2DecoderLayer(config, attn_type, world_size, rank))
539 | else:
540 | raise ValueError(f'Unknown architecture: {config.architecture}')
541 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
542 |
543 | def forward(
544 | self,
545 | hidden_states: torch.Tensor,
546 | freqs_cis: torch.Tensor,
547 | kv_write_indices: torch.Tensor,
548 | kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
549 | mask: torch.Tensor,
550 | ) -> torch.Tensor:
551 | for i in range(len(self.layers)):
552 | layer = self.layers[i]
553 | hidden_states = layer(
554 | hidden_states=hidden_states,
555 | freqs_cis=freqs_cis,
556 | kv_write_indices=kv_write_indices,
557 | kv_cache=kv_caches[i],
558 | mask=mask,
559 | )
560 | hidden_states = self.norm(hidden_states)
561 | return hidden_states
562 |
563 |
564 | class GemmaForCausalLM(nn.Module):
565 |
566 | def __init__(
567 | self,
568 | config: gemma_config.GemmaConfig,
569 | world_size: int,
570 | rank: int,
571 | device: torch.device,
572 | ):
573 | super().__init__()
574 | self.config = config
575 | self.world_size = world_size
576 | self.rank = rank
577 | self.device = device
578 |
579 | assert config.num_attention_heads % world_size == 0
580 | assert config.hidden_size % config.num_attention_heads == 0
581 |
582 | max_seq_len = config.max_position_embeddings
583 | head_dim = config.head_dim
584 | vocab_size = config.vocab_size
585 |
586 | def init_method(x):
587 | return x
588 |
589 | self.embedder = ParallelEmbedding(
590 | vocab_size,
591 | config.hidden_size,
592 | init_method=init_method,
593 | world_size=world_size,
594 | rank=rank,
595 | quant=config.quant,
596 | )
597 | self.model = GemmaModel(config, world_size, rank)
598 | self.sampler = Sampler(vocab_size, world_size, rank, config)
599 |
600 | rope_theta = getattr(config, 'rope_theta', 10000)
601 | # [head_dim * 2, ] -> complex -> two dim (real, imaginary) implicitly
602 | freqs_cis = precompute_freqs_cis(head_dim,
603 | max_seq_len * 2,
604 | theta=rope_theta)
605 | self.register_buffer('freqs_cis', freqs_cis)
606 |
607 | @torch.no_grad()
608 | def forward(
609 | self,
610 | input_token_ids: torch.Tensor,
611 | input_positions: torch.Tensor,
612 | kv_write_indices: torch.Tensor,
613 | kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
614 | mask: torch.Tensor,
615 | output_positions: torch.Tensor,
616 | temperatures: Union[torch.Tensor, None],
617 | top_ps: torch.Tensor,
618 | top_ks: torch.Tensor,
619 | **kwargs,
620 | ) -> Tuple[torch.Tensor, torch.Tensor]:
621 | freqs_cis = self.freqs_cis.index_select(0, input_positions)
622 | kv_write_indices = input_positions
623 |
624 | hidden_states = self.embedder(input_token_ids)
625 | # Gemma normalizes the embedding by sqrt(hidden_size).
626 | # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
627 | # See https://github.com/huggingface/transformers/pull/29402
628 | normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
629 | hidden_states = hidden_states * normalizer
630 | # hidden_states should be [batch_size, input_len, hidden_size]
631 |
632 | hidden_states = self.model(
633 | hidden_states=hidden_states,
634 | freqs_cis=freqs_cis,
635 | kv_write_indices=kv_write_indices,
636 | kv_caches=kv_caches,
637 | mask=mask,
638 | )
639 | embedder_weight = self.embedder.weight
640 | if self.config.quant:
641 | embedder_weight = (
642 | embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))
643 | next_tokens, logits = self.sampler(
644 | embedding=embedder_weight,
645 | hidden_states=hidden_states,
646 | output_positions=output_positions,
647 | temperatures=temperatures,
648 | top_ps=top_ps,
649 | top_ks=top_ks,
650 | )
651 | return next_tokens, logits
652 |
653 | def _load_weights(self, model_state_dict: Mapping[str, torch.Tensor]):
654 | num_attn_heads = self.config.num_attention_heads
655 | num_kv_heads = self.config.num_key_value_heads
656 | head_dim = self.config.head_dim
657 | hidden_size = self.config.hidden_size
658 |
659 | def split(tensor: torch.Tensor, axis: int) -> torch.Tensor:
660 | axis_len = tensor.shape[axis]
661 | split_len = axis_len // self.world_size
662 | split_start = split_len * self.rank
663 | split_end = split_start + split_len
664 | tensor = torch.moveaxis(tensor, axis, 0)
665 | tensor = tensor[split_start:split_end, ...]
666 | tensor = torch.moveaxis(tensor, 0, axis)
667 | return tensor
668 |
669 | for k, v in model_state_dict.items():
670 | if k == 'freqs_cis':
671 | continue
672 | if (k == 'model.norm.weight'
673 | or k.endswith('_layernorm.weight')
674 | or k.endswith('weight_scaler')):
675 | pass
676 | elif (k == 'embedder.weight' or re.fullmatch(
677 | r'model.layers.\d+.mlp.down_proj.weight', k)):
678 | v = split(v, 1)
679 | elif (re.fullmatch(r'model.layers.\d+.mlp.gate_proj.weight', k)
680 | or re.fullmatch(r'model.layers.\d+.mlp.up_proj.weight', k)):
681 | v = split(v, 0)
682 | elif re.fullmatch(r'model.layers.\d+.self_attn.qkv_proj.weight',
683 | k):
684 | if num_kv_heads <= num_attn_heads:
685 | # If num_kv_heads > self.world_size, we still want 1
686 | # replica.
687 | num_replicas = max(self.world_size // num_kv_heads, 1)
688 | v = v.reshape(num_attn_heads + num_kv_heads * 2, head_dim,
689 | hidden_size)
690 | query = v[:num_attn_heads, ...]
691 | key = v[num_attn_heads:num_attn_heads + num_kv_heads,
692 | ...].repeat(num_replicas, 1, 1)
693 | value = v[-num_kv_heads:, ...].repeat(num_replicas, 1, 1)
694 | v = torch.cat(
695 | (split(query, 0), split(key, 0), split(value, 0)),
696 | dim=0)
697 | else:
698 | v = v.reshape(3, num_attn_heads, head_dim, hidden_size)
699 | v = split(v, 1)
700 | v = v.reshape(-1, hidden_size)
701 | elif re.fullmatch(r'model.layers.\d+.self_attn.o_proj.weight', k):
702 | v = v.reshape(hidden_size, num_attn_heads, head_dim)
703 | v = split(v, 1)
704 | v = v.reshape(hidden_size, -1)
705 | else:
706 | raise ValueError(f'Unrecognized key: {k}')
707 | self.state_dict()[k].copy_(v)
708 |
709 | def load_weights(self, model_path: str):
710 | if os.path.isfile(model_path):
711 | checkpoint = torch.load(model_path, weights_only=True)
712 | model_state_dict = checkpoint['model_state_dict']
713 | self._load_weights(model_state_dict)
714 | else:
715 | index_path = os.path.join(model_path, 'pytorch_model.bin.index.json')
716 | with open(index_path, "r", encoding="utf-8") as f:
717 | index = json.load(f)
718 | shard_files = list(set(index["weight_map"].values()))
719 | for shard_file in shard_files:
720 | shard_path = os.path.join(model_path, shard_file)
721 | state_dict = torch.load(shard_path, map_location="cpu", weights_only=True)
722 | self._load_weights(state_dict)
723 | del state_dict # Save memory.
724 | gc.collect()
725 |
--------------------------------------------------------------------------------
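Note: the Sampler above (like the one in gemma3_model.py) applies temperature scaling, then top-p and top-k filtering on the sorted probabilities, renormalizes, and samples with torch.multinomial. A standalone sketch of that filtering step on made-up logits (batch size, vocabulary size, and thresholds are illustrative only):

    import torch

    torch.manual_seed(0)
    logits = torch.randn(2, 10)              # (batch, vocab), made-up values
    temperature, top_p, top_k = 0.7, 0.95, 4

    probs = torch.softmax(logits / temperature, dim=-1, dtype=torch.float)
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

    # Top-p: zero out tokens once the cumulative mass before them exceeds top_p.
    cum = torch.cumsum(probs_sort, dim=-1)
    probs_sort = torch.where(cum - probs_sort > top_p, 0.0, probs_sort)

    # Top-k: keep only the k highest-probability tokens.
    ranks = torch.arange(probs_sort.shape[-1]).expand_as(probs_sort)
    probs_sort = torch.where(ranks >= top_k, 0.0, probs_sort)

    # Renormalize, undo the sort, and sample one token id per sequence.
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    probs = torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
    next_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
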
/gemma/xla_model_parallel.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from copy import deepcopy
16 | from dataclasses import dataclass
17 | import os
18 | from typing import Callable, List, Optional, Tuple
19 |
20 | import torch
21 | import torch.ao.quantization.fx._decomposed
22 | import torch.distributed as dist
23 | import torch.distributed._functional_collectives as fc
24 | import torch.distributed.distributed_c10d as c10d
25 | import torch.nn.functional as F
26 | import torch.nn.init as init
27 | from torch.nn.parameter import Parameter
28 |
29 | EPS = torch.finfo(torch.float32).eps
30 |
31 | USE_CUDA = os.environ.get('USE_CUDA', False)
32 | if not USE_CUDA:
33 | import torch_xla.core.xla_model as xm
34 |
35 | TAG = None
36 | RANKSET = None
37 | GROUP_SIZE = None
38 |
39 |
40 | def set_g_group():
41 | global TAG
42 | global RANKSET
43 | global GROUP_SIZE
44 |
45 | assert USE_CUDA, "This hack is only for PyTorch non-XLA CUDA paths, i.e., eager and inductor."
46 | TAG, RANKSET, GROUP_SIZE = fc._expand_group(c10d._get_default_group())
47 |
48 |
49 | @dataclass
50 | class TensorQConfig:
51 | dtype: torch.dtype = torch.int8
52 | axis: int = -1
53 | quant_min: int = -128
54 | quant_max: int = 127
55 | symmetric_quant: bool = True
56 |
57 |
58 | def _find_per_channel_min_max(x: torch.Tensor, axis: int):
59 | x_dim = x.size()
60 | new_axis_list = list(range(len(x_dim)))
61 | new_axis_list[axis] = 0
62 | new_axis_list[0] = axis
63 | y = x.permute(new_axis_list)
64 | y = torch.flatten(y, start_dim=1)
65 | return torch.aminmax(y, dim=1)
66 |
67 |
68 | def _find_qparams(x: torch.Tensor, qconfig: TensorQConfig):
69 | # Only support per-channel symmetric quant to int8 now
70 | axis = qconfig.axis
71 | dtype = qconfig.dtype
72 | symmetric_quant = qconfig.symmetric_quant
73 | quant_min = qconfig.quant_min
74 | quant_max = qconfig.quant_max
75 | assert axis >= 0 and axis < len(x.shape)
76 | assert dtype == torch.int8
77 | min_val, max_val = _find_per_channel_min_max(x, axis)
78 | min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
79 | max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
80 | scale = torch.ones(min_val_neg.size(), dtype=torch.float32)
81 | if symmetric_quant:
82 | max_val_pos = torch.max(-min_val_neg, max_val_pos)
83 | scale = max_val_pos / (float(quant_max - quant_min) / 2)
84 | eps = torch.zeros_like(scale).fill_(EPS)
85 | scale = torch.max(scale, eps)
86 | return scale, None
87 | else:
88 | assert symmetric_quant
89 |
90 |
91 | def _quantize_to_dtype(
92 | x: torch.Tensor,
93 | qconfig: TensorQConfig,
94 | scale: torch.Tensor,
95 | zero_point: Optional[torch.Tensor] = None,
96 | ):
97 | if zero_point is None:
98 | zero_point = torch.zeros_like(scale)
99 | return torch.ops.quantized_decomposed.quantize_per_channel(
100 | x,
101 | scale,
102 | zero_point,
103 | qconfig.axis,
104 | qconfig.quant_min,
105 | qconfig.quant_max,
106 | qconfig.dtype,
107 | )
108 |
109 |
110 | def quantize_tensor(x: torch.Tensor, qconfig: TensorQConfig):
111 | scale, zp = _find_qparams(x, qconfig)
112 | x_int = _quantize_to_dtype(x, qconfig, scale, zp)
113 | return x_int, scale, zp
114 |
115 |
116 | def get_model_parallel_rank():
117 | if USE_CUDA:
118 | return dist.get_rank()
119 | return xm.get_ordinal()
120 |
121 |
122 | def get_model_parallel_world_size():
123 | if USE_CUDA:
124 | return dist.get_world_size()
125 | return xm.xrt_world_size()
126 |
127 |
128 | def get_model_parallel_group():
129 | return None
130 |
131 |
132 | class _CopyToModelParallelRegion(torch.autograd.Function):
133 | """Pass the input to the model parallel region."""
134 |
135 | @staticmethod
136 | def forward(ctx, input_, groups, world_size, rank): # type: ignore
137 | ctx.groups, ctx.world_size, ctx.rank = groups, world_size, rank
138 | return input_
139 |
140 | @staticmethod
141 | def backward(ctx, grad_output): # type: ignore
142 | groups, world_size, rank = ctx.groups, ctx.world_size, ctx.rank
143 | return my_reduce(grad_output, groups, world_size, rank)
144 |
145 |
146 | class _ReduceFromModelParallelRegion(torch.autograd.Function):
147 | """All-redcue the input from the model parallel region."""
148 |
149 | @staticmethod
150 | def forward(ctx, input_, groups, world_size, rank): # type: ignore
151 | return my_reduce(input_, groups, world_size, rank)
152 |
153 | @staticmethod
154 | def backward(ctx, grad_output): # type: ignore
155 | return grad_output
156 |
157 |
158 | class _ScatterToModelParallelRegion(torch.autograd.Function):
159 | """Split the input and keep only the corresponding chuck to the rank."""
160 |
161 | @staticmethod
162 | def forward(ctx, input_, groups, world_size, rank): # type: ignore
163 | ctx.groups, ctx.world_size, ctx.rank = groups, world_size, rank
164 | return my_split(input_, groups, world_size, rank)
165 |
166 | @staticmethod
167 | def backward(ctx, grad_output): # type: ignore
168 | groups, world_size, rank = ctx.groups, ctx.world_size, ctx.rank
169 | return my_gather(grad_output, groups, world_size, rank)
170 |
171 |
172 | class _GatherFromModelParallelRegion(torch.autograd.Function):
173 | """Gather the input from model parallel region and concatinate."""
174 |
175 | @staticmethod
176 | def forward(ctx, input_, groups, world_size, rank): # type: ignore
177 | ctx.groups, ctx.world_size, ctx.rank = groups, world_size, rank
178 | return my_gather(input_, groups, world_size, rank)
179 |
180 | @staticmethod
181 | def backward(ctx, grad_output): # type: ignore
182 | groups, world_size, rank = ctx.groups, ctx.world_size, ctx.rank
183 | return my_split(grad_output, groups, world_size, rank)
184 |
185 |
186 | # -----------------
187 | # Helper functions.
188 | # -----------------
189 |
190 |
191 | def copy_to_model_parallel_region(input_: torch.Tensor, groups, world_size,
192 | rank) -> torch.Tensor:
193 | return _CopyToModelParallelRegion.apply(input_, groups, world_size, rank)
194 |
195 |
196 | def reduce_from_model_parallel_region(input_: torch.Tensor, groups, world_size,
197 | rank) -> torch.Tensor:
198 | return _ReduceFromModelParallelRegion.apply(input_, groups, world_size,
199 | rank)
200 |
201 |
202 | def scatter_to_model_parallel_region(input_: torch.Tensor, groups, world_size,
203 | rank) -> torch.Tensor:
204 | return _ScatterToModelParallelRegion.apply(input_, groups, world_size,
205 | rank)
206 |
207 |
208 | def gather_from_model_parallel_region(input_: torch.Tensor, groups, world_size,
209 | rank) -> torch.Tensor:
210 | return _GatherFromModelParallelRegion.apply(input_, groups, world_size,
211 | rank)
212 |
213 |
214 | def ensure_divisibility(numerator: int, denominator: int) -> None:
215 | """Ensure that numerator is divisible by the denominator."""
216 | assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
217 |
218 |
219 | def divide_and_check_no_remainder(numerator: int, denominator: int) -> int:
220 | """Ensure that numerator is divisible by the denominator and return
221 | the division value."""
222 | ensure_divisibility(numerator, denominator)
223 | return numerator // denominator
224 |
225 |
226 | def split_tensor_along_last_dim(
227 | tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False
228 | ) -> Tuple[torch.Tensor, ...]:
229 | """Split a tensor along its last dimension.
230 | Arguments:
231 | tensor: input tensor.
232 | num_partitions: number of partitions to split the tensor into.
233 | contiguous_split_chunks: If True, make each chunk contiguous
234 | in memory.
235 | """
236 | # Get the size and dimension.
237 | last_dim = tensor.dim() - 1
238 | last_dim_size = divide_and_check_no_remainder(tensor.size()[last_dim], num_partitions)
239 | # Split.
240 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
241 | # Note: torch.split does not create contiguous tensors by default.
242 | if contiguous_split_chunks:
243 | return tuple(chunk.contiguous() for chunk in tensor_list)
244 |
245 | return tensor_list
246 |
247 | # Below copied from fairscale/nn/model_parallel/layers.py
248 |
249 |
250 | def my_reduce(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
251 | """All-reduce the the input tensor across model parallel group."""
252 | # Bypass the function if we are using only 1 GPU.
253 | if world_size == 1:
254 | return input_
255 |
256 | # All-reduce.
257 | if USE_CUDA:
258 | input_ = torch.ops.c10d_functional.all_reduce(input_, "sum", TAG,
259 | RANKSET, GROUP_SIZE)
260 | else:
261 | input_ = xm.all_reduce(xm.REDUCE_SUM, input_, groups=groups)
262 |
263 | return input_
264 |
265 |
266 | def my_split(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
267 | """Split the tensor along its last dimension and keep the
268 |
269 | corresponding slice.
270 | """
271 | # Bypass the function if we are using only 1 GPU.
272 | if world_size == 1:
273 | return input_
274 |
275 | # Split along last dimension.
276 | input_list = split_tensor_along_last_dim(input_, world_size)
277 |
278 | # Note: torch.split does not create contiguous tensors by default.
279 | output = input_list[rank].contiguous()
280 |
281 | return output
282 |
283 |
284 | def my_gather(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
285 | """Gather tensors and concatinate along the last dimension."""
286 | # Bypass the function if we are using only 1 GPU.
287 | if world_size == 1:
288 | return input_
289 |
290 | if USE_CUDA:
291 | last_dim = input_.dim() - 1
292 |
293 | # Using all_reduce to achieve all_gather as torch.ops.c10d_functional.all_gather_into_tensor
294 | # is buggy in 16 bits.
295 | size = input_.size(last_dim)
296 | padding = [0] * (2 * input_.dim())
297 | ordinal = rank
298 | left, right = ordinal, world_size - 1 - ordinal
299 | idx = input_.dim() - 1 - last_dim
300 | padding[2 * idx] = left * size
301 | padding[2 * idx + 1] = right * size
302 | output = torch.ops.c10d_functional.all_reduce(F.pad(input_,
303 | padding), "sum",
304 | TAG, RANKSET, GROUP_SIZE)
305 | else:
306 | output = xm.all_gather(input_, dim=-1, groups=groups)
307 |
308 | return output
309 |
310 |
311 | def _initialize_affine_weight(
312 | weight: torch.Tensor,
313 | out_features: int,
314 | in_features: int,
315 | per_partition_size: int,
316 | partition_dim: int,
317 | init_method: Callable[[torch.Tensor], torch.Tensor],
318 | world_size: int,
319 | rank: int,
320 | stride: int = 1,
321 | return_master_weight: bool = False,
322 | ) -> Optional[torch.Tensor]:
323 | """Initialize affine weight for model parallel.
324 |
325 | Build the master weight on all processes and scatter
326 | the relevant chunk.
327 | """
328 |
329 | # If we only use 1 process for model parallelism, bypass scatter.
330 | if world_size == 1:
331 | init_method(weight)
332 | if return_master_weight:
333 | return weight
334 | return None
335 |
336 | # Initialize master weight
337 | master_weight = torch.empty(out_features,
338 | in_features,
339 | dtype=weight.dtype,
340 | requires_grad=False)
341 | init_method(master_weight)
342 |
343 | # Split and copy
344 | per_partition_per_stride_size = divide_and_check_no_remainder(
345 | per_partition_size, stride)
346 | weight_list = torch.split(master_weight,
347 | per_partition_per_stride_size,
348 | dim=partition_dim)
349 | my_weight_list = weight_list[rank::world_size]
350 |
351 | with torch.no_grad():
352 | torch.cat(my_weight_list, dim=partition_dim, out=weight)
353 | if return_master_weight:
354 | return master_weight
355 | return None
356 |
357 |
358 | class ParallelEmbedding(torch.nn.Module):
359 | """Embedding parallelized in the embedding dimension.
360 |
361 | This is mainly adapted from torch.nn.Embedding and all the default
362 | values are kept.
363 | Arguments:
364 | num_embeddings: vocabulary size.
365 | embedding_dim: size of hidden state.
366 | init_method: method to initialize weights.
367 | """
368 |
369 | def __init__(
370 | self,
371 | num_embeddings: int,
372 | embedding_dim: int,
373 | padding_idx: Optional[int] = None,
374 | max_norm: Optional[float] = None,
375 | norm_type: float = 2.0,
376 | scale_grad_by_freq: bool = False,
377 | sparse: bool = False,
378 | init_method: Callable[[torch.Tensor],
379 | torch.Tensor] = init.xavier_normal_,
380 | keep_master_weight_for_test: bool = False,
381 | world_size: Optional[int] = None,
382 | rank: Optional[int] = None,
383 | groups: Optional[List] = None,
384 | quant: bool = False,
385 | ) -> None:
386 | super(ParallelEmbedding, self).__init__()
387 |
388 | if world_size is None:
389 | self.groups = get_model_parallel_group()
390 | self.world_size = get_model_parallel_world_size()
391 | self.rank = get_model_parallel_rank()
392 | else:
393 | self.groups = groups
394 | self.world_size = world_size
395 | self.rank = rank
396 |
397 | # Keep the input dimensions.
398 | self.num_embeddings = num_embeddings
399 | self.embedding_dim = embedding_dim
400 | self.padding_idx = padding_idx
401 | self.max_norm = max_norm
402 | self.norm_type = norm_type
403 | self.scale_grad_by_freq = scale_grad_by_freq
404 | self.sparse = sparse
405 | self._weight = None
406 | self.quant = quant
407 | # Divide the weight matrix along the embedding dimension.
408 | self.embedding_dim_per_partition = divide_and_check_no_remainder(
409 | self.embedding_dim, self.world_size)
410 |
411 | # Allocate weights.
412 | if quant:
413 | self.weight = Parameter(
414 | torch.empty(
415 | (self.num_embeddings, self.embedding_dim_per_partition),
416 | dtype=torch.int8,
417 | ),
418 | requires_grad=False,
419 | )
420 | self.weight_scaler = Parameter(torch.Tensor(self.num_embeddings))
421 | else:
422 | self.weight = Parameter(
423 | torch.Tensor(self.num_embeddings,
424 | self.embedding_dim_per_partition))
425 |
426 | # And initialize.
427 | _initialize_affine_weight(
428 | self.weight,
429 | self.num_embeddings,
430 | self.embedding_dim,
431 | self.embedding_dim_per_partition,
432 | 1,
433 | init_method,
434 | self.world_size,
435 | self.rank,
436 | stride=1,
437 | return_master_weight=False,
438 | )
439 |
440 | def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
441 | input_parallel = copy_to_model_parallel_region(input_, self.groups,
442 | self.world_size,
443 | self.rank)
444 | # PyTorch eager and inductor do not accept negative values in the input to embedding
445 | # layers. Take the modulus to avoid this error.
446 | if USE_CUDA:
447 | input_parallel = torch.remainder(input_parallel,
448 | self.weight.shape[0])
449 | weight = self.weight
450 | if self.quant:
451 | weight = weight * self.weight_scaler.unsqueeze(-1)
452 | output_parallel = F.embedding(
453 | input_parallel,
454 | weight,
455 | self.padding_idx,
456 | self.max_norm,
457 | self.norm_type,
458 | self.scale_grad_by_freq,
459 | self.sparse,
460 | )
461 | output = gather_from_model_parallel_region(output_parallel,
462 | self.groups,
463 | self.world_size, self.rank)
464 | return output
465 |
466 |
467 | class ColumnParallelLinear(torch.nn.Module):
468 | """Linear layer with column parallelism.
469 |
470 | The linear layer is defined as Y = XA + b. A is parallelized along
471 | its second dimension as A = [A_1, ..., A_p].
472 |
473 | Arguments:
474 | in_features: first dimension of matrix A.
475 | out_features: second dimension of matrix A.
476 | bias: If true, add bias
477 | gather_output: If true, call all-gather on the output and make Y available
478 | to all GPUs; otherwise, every GPU will have its own output, which is Y_i = XA_i.
479 | init_method: method to initialize weights. Note that bias is always set to
480 | zero.
481 | stride: For the strided linear layers.
482 | keep_master_weight_for_test: This was added for testing and should be set
483 | to False. It returns the master weights used for initialization.
484 | """
485 |
486 | def __init__(
487 | self,
488 | in_features: int,
489 | out_features: int,
490 | bias: bool = True,
491 | gather_output: bool = True,
492 | init_method: Callable[[torch.Tensor],
493 | torch.Tensor] = init.xavier_normal_,
494 | stride: int = 1,
495 | keep_master_weight_for_test: bool = False,
496 | world_size: Optional[int] = None,
497 | rank: Optional[int] = None,
498 | groups: Optional[List] = None,
499 | quant: bool = False,
500 | ) -> None:
501 | super(ColumnParallelLinear, self).__init__()
502 |
503 | if world_size is None:
504 | self.groups = get_model_parallel_group()
505 | self.world_size = get_model_parallel_world_size()
506 | self.rank = get_model_parallel_rank()
507 | else:
508 | self.groups = groups
509 | self.world_size = world_size
510 | self.rank = rank
511 |
512 | # Keep input parameters
513 | self.in_features = in_features
514 | self.out_features = out_features
515 | self.gather_output = gather_output
516 | self.quant = quant
517 | # Divide the weight matrix along the last dimension.
518 | self.output_size_per_partition = divide_and_check_no_remainder(
519 | out_features, self.world_size)
520 |
521 | # Parameters.
522 | # Note: torch.nn.functional.linear performs XA^T + b and as a result
523 | # we allocate the transpose.
524 | if quant:
525 | self.weight = Parameter(
526 | torch.empty(
527 | (self.output_size_per_partition, self.in_features),
528 | dtype=torch.int8,
529 | ),
530 | requires_grad=False,
531 | )
532 | self.weight_scaler = Parameter(
533 | torch.Tensor(self.output_size_per_partition))
534 | else:
535 | self.weight = Parameter(
536 | torch.Tensor(self.output_size_per_partition, self.in_features))
537 | if bias:
538 | self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
539 | # Always initialize bias to zero.
540 | with torch.no_grad():
541 | self.bias.zero_()
542 | else:
543 | self.register_parameter('bias', None)
544 |
545 | # Initialize weight.
546 | self.master_weight = _initialize_affine_weight(
547 | self.weight,
548 | self.out_features,
549 | self.in_features,
550 | self.output_size_per_partition,
551 | 0,
552 | init_method,
553 | self.world_size,
554 | self.rank,
555 | stride=stride,
556 | return_master_weight=keep_master_weight_for_test,
557 | )
558 |
559 | def get_master_weight(self) -> torch.Tensor:
560 | return gather_from_model_parallel_region(
561 | self.weight.data.transpose(0, 1),
562 | self.groups,
563 | self.world_size,
564 | self.rank,
565 | ).transpose_(0, 1)
566 |
567 | def set_quantize(self):
568 | assert not self.quant
569 | self.weight = Parameter(
570 | torch.empty((self.output_size_per_partition, self.in_features),
571 | dtype=torch.int8),
572 | requires_grad=False,
573 | )
574 | self.weight_scaler = Parameter(
575 | torch.Tensor(self.output_size_per_partition))
576 | self.quant = True
577 |
578 | def quantize(self):
579 | assert not self.quant
580 | fp_w = deepcopy(self.weight.data)
581 | orig_dtype = fp_w.dtype
582 | fp_w = fp_w.to(torch.float32)
583 | self.weight = Parameter(
584 | torch.empty((self.output_size_per_partition, self.in_features),
585 | dtype=torch.int8),
586 | requires_grad=False,
587 | )
588 | self.weight_scaler = Parameter(
589 | torch.Tensor(self.output_size_per_partition))
590 | qconfig = TensorQConfig(axis=0)
591 | self.weight.data, scale, zero_point = quantize_tensor(fp_w, qconfig)
592 | self.weight_scaler.data = scale.to(orig_dtype)
593 | self.quant = True
594 |
595 | def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
596 | # Set up backprop all-reduce.
597 | input_parallel = copy_to_model_parallel_region(input_, self.groups,
598 | self.world_size,
599 | self.rank)
600 | # Matrix multiply.
601 | if self.quant and USE_CUDA:
602 | # GPUs do not support mixed int8 bf16 computation. Scale int8 weights to bf16 before linear.
603 | scaled_weight = self.weight * self.weight_scaler
604 | output_parallel = F.linear(input_parallel, scaled_weight, self.bias)
605 | elif self.quant:
606 | output_parallel = F.linear(input_parallel, self.weight, self.bias)
607 | output_parallel = output_parallel * self.weight_scaler
608 | else:
609 | output_parallel = F.linear(input_parallel, self.weight, self.bias)
610 | if self.gather_output:
611 | # All-gather across the partitions.
612 | output = gather_from_model_parallel_region(output_parallel,
613 | self.groups,
614 | self.world_size,
615 | self.rank)
616 | else:
617 | output = output_parallel
618 | return output
619 |
620 |
621 | class RowParallelLinear(torch.nn.Module):
622 | """Linear layer with row parallelism.
623 |
624 | The linear layer is defined as Y = XA + b. A is parallelized along
625 | its first dimension and X along its second dimension as:
626 | - -
627 | | A_1 |
628 | | . |
629 | A = | . | X = [X_1, ..., X_p]
630 | | . |
631 | | A_p |
632 | - -
633 | Arguments:
634 | in_features: first dimension of matrix A.
635 | out_features: second dimension of matrix A.
636 | bias: If true, add bias. Note that bias is not parallelized.
637 | input_is_parallel: If true, we assume that the input is already split
638 | across the GPUs and we do not split again.
639 | init_method: method to initialize weights. Note that bias is always set to
640 | zero.
641 | stride: For the strided linear layers.
642 | keep_master_weight_for_test: This was added for testing and should be set
643 | to False. It returns the master weights used for initialization.
644 | """
645 |
646 | def __init__(
647 | self,
648 | in_features: int,
649 | out_features: int,
650 | bias: bool = True,
651 | input_is_parallel: bool = False,
652 | init_method: Callable[[torch.Tensor],
653 | torch.Tensor] = init.xavier_normal_,
654 | stride: int = 1,
655 | keep_master_weight_for_test: bool = False,
656 | world_size: Optional[int] = None,
657 | rank: Optional[int] = None,
658 | groups: Optional[List] = None,
659 | quant: bool = False,
660 | ):
661 | super(RowParallelLinear, self).__init__()
662 |
663 | if world_size is None:
664 | self.groups = get_model_parallel_group()
665 | self.world_size = get_model_parallel_world_size()
666 | self.rank = get_model_parallel_rank()
667 | else:
668 | self.groups = groups
669 | self.world_size = world_size
670 | self.rank = rank
671 |
672 | # Keep input parameters
673 | self.in_features = in_features
674 | self.out_features = out_features
675 | self.input_is_parallel = input_is_parallel
676 | self.quant = quant
677 | # Divide the weight matrix along the last dimension.
678 | self.input_size_per_partition = divide_and_check_no_remainder(
679 | in_features, self.world_size)
680 |
681 | # Parameters.
682 | # Note: torch.nn.functional.linear performs XA^T + b and as a result
683 | # we allocate the transpose.
684 | if quant:
685 | self.weight = Parameter(
686 | torch.empty(
687 | (self.out_features, self.input_size_per_partition),
688 | dtype=torch.int8,
689 | ),
690 | requires_grad=False,
691 | )
692 | self.weight_scaler = Parameter(torch.Tensor(self.out_features))
693 | else:
694 | self.weight = Parameter(
695 | torch.Tensor(self.out_features, self.input_size_per_partition))
696 | if bias:
697 | self.bias = Parameter(torch.Tensor(self.out_features))
698 | # Always initialize bias to zero.
699 | with torch.no_grad():
700 | self.bias.zero_()
701 | else:
702 | self.register_parameter('bias', None)
703 |
704 | # Initialize weight.
705 | self.master_weight = _initialize_affine_weight(
706 | self.weight,
707 | self.out_features,
708 | self.in_features,
709 | self.input_size_per_partition,
710 | 1,
711 | init_method,
712 | self.world_size,
713 | self.rank,
714 | stride=stride,
715 | return_master_weight=keep_master_weight_for_test,
716 | )
717 |
718 | def get_master_weight(self) -> torch.Tensor:
719 | return gather_from_model_parallel_region(self.weight.data, self.groups,
720 | self.world_size, self.rank)
721 |
722 | def set_quantize(self):
723 | assert not self.quant
724 | self.weight = Parameter(
725 | torch.empty((self.out_features, self.input_size_per_partition),
726 | dtype=torch.int8),
727 | requires_grad=False,
728 | )
729 | self.weight_scaler = Parameter(torch.Tensor(self.out_features))
730 | self.quant = True
731 |
732 | def quantize(self):
733 | assert not self.quant
734 | fp_w = deepcopy(self.weight.data)
735 | orig_dtype = fp_w.dtype
736 | fp_w = fp_w.to(torch.float32)
737 | self.weight = Parameter(
738 | torch.empty((self.out_features, self.input_size_per_partition),
739 | dtype=torch.int8),
740 | requires_grad=False,
741 | )
742 | self.weight_scaler = Parameter(torch.Tensor(self.out_features))
743 | qconfig = TensorQConfig(axis=0)
744 | self.weight.data, scale, zero_point = quantize_tensor(fp_w, qconfig)
745 | self.weight_scaler.data = scale.to(orig_dtype)
746 | self.quant = True
747 |
748 | def forward(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore
749 | # Set up backprop all-reduce.
750 | if self.input_is_parallel:
751 | input_parallel = input_
752 | else:
753 | input_parallel = scatter_to_model_parallel_region(
754 | input_, self.groups, self.world_size, self.rank)
755 | # Matrix multiply.
756 | if self.quant and USE_CUDA:
757 | # GPUs do not support mixed int8 bf16 computation. Scale int8 weights to bf16 before linear.
758 | scaled_weight = self.weight * self.weight_scaler
759 | output_parallel = F.linear(input_parallel, scaled_weight, self.bias)
760 | elif self.quant:
761 | output_parallel = F.linear(input_parallel, self.weight, self.bias)
762 | output_parallel = output_parallel * self.weight_scaler
763 | else:
764 | output_parallel = F.linear(input_parallel, self.weight)
765 | # All-reduce across all the partitions.
766 | output_ = reduce_from_model_parallel_region(output_parallel,
767 | self.groups,
768 | self.world_size, self.rank)
769 | if self.bias is not None:
770 | output = output_ + self.bias
771 | else:
772 | output = output_
773 | return output
774 |
--------------------------------------------------------------------------------
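Note: ColumnParallelLinear shards its weight along the output dimension and RowParallelLinear along the input dimension, so a column-parallel layer followed by a row-parallel layer (the pattern GemmaMLP and GemmaAttention use) leaves each rank with a partial result that the final all-reduce sums back to the full output. A single-process sketch of that algebra with two simulated shards (hypothetical sizes, no distributed backend involved):

    import torch

    torch.manual_seed(0)
    x = torch.randn(1, 6)      # (batch, hidden)
    w1 = torch.randn(8, 6)     # column-parallel weight, split along dim 0 (output)
    w2 = torch.randn(6, 8)     # row-parallel weight, split along dim 1 (input)

    # Reference: the unsharded forward pass (F.linear computes x @ W.T).
    ref = (x @ w1.t()) @ w2.t()

    # Simulate world_size == 2: each "rank" owns half of w1's rows and the
    # matching half of w2's columns.
    partials = []
    for rank in range(2):
        w1_r = w1[rank * 4:(rank + 1) * 4]        # (4, 6)
        w2_r = w2[:, rank * 4:(rank + 1) * 4]     # (6, 4)
        h_r = x @ w1_r.t()                        # column-parallel output, stays sharded
        partials.append(h_r @ w2_r.t())           # row-parallel output, a partial sum

    # The all-reduce inside RowParallelLinear.forward corresponds to this sum.
    out = partials[0] + partials[1]
    assert torch.allclose(out, ref, atol=1e-5)
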
/gemma/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Inference-only Gemma model implementation."""
15 |
16 | import json
17 | import gc
18 | import os
19 | import torch
20 | from torch import nn
21 | import torch.nn.functional as F
22 | from typing import Any, List, Optional, Sequence, Tuple, Union, Mapping
23 |
24 | from gemma import config as gemma_config
25 | from gemma import tokenizer
26 |
27 |
28 | class Sampler(nn.Module):
29 |
30 | def __init__(self, vocab_size: int, config: gemma_config.GemmaConfig):
31 | super().__init__()
32 | self.vocab_size = vocab_size
33 | self.config = config
34 |
35 | @torch.no_grad()
36 | def forward(
37 | self,
38 | embedding: torch.Tensor,
39 | hidden_states: torch.Tensor,
40 | output_positions: torch.Tensor,
41 | temperatures: Union[torch.Tensor, None],
42 | top_ps: torch.Tensor,
43 | top_ks: torch.Tensor,
44 | embedding_bias: Optional[torch.Tensor] = None,
45 | ) -> Tuple[torch.Tensor, torch.Tensor]:
46 | # Select the last element for each sequence.
47 | # (batch_size, input_len, hidden_size) -> (batch_size, hidden_size)
48 | hidden_states = hidden_states.index_select(
49 | 1, output_positions).squeeze(dim=1)
50 | logits = torch.matmul(hidden_states, embedding.t())
51 | if embedding_bias is not None:
52 | logits += embedding_bias
53 | if self.config.final_logit_softcapping is not None:
54 | logits = logits / self.config.final_logit_softcapping
55 | logits = torch.tanh(logits)
56 | logits = logits * self.config.final_logit_softcapping
57 |
58 | if temperatures is None:
59 | return torch.argmax(logits, dim=-1).squeeze(dim=-1), logits
60 |
61 | # Apply temperature scaling.
62 | logits.div_(temperatures.unsqueeze(dim=1))
63 |
64 | # Calculate probabilities with softmax.
65 | probs = torch.softmax(logits, dim=-1, dtype=torch.float)
66 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
67 |
68 | # Apply top-p, top-k.
69 | probs_sum = torch.cumsum(probs_sort, dim=-1)
70 | top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
71 | probs_sort = torch.where(top_ps_mask, 0, probs_sort)
72 |
73 | top_ks_mask = torch.arange(probs_idx.shape[-1],
74 | device=probs_idx.device)
75 | top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
76 | top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
77 | probs_sort = torch.where(top_ks_mask, 0, probs_sort)
78 |
79 | # Re-normalization.
80 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
81 | probs = torch.gather(probs_sort,
82 | dim=-1,
83 | index=torch.argsort(probs_idx, dim=-1))
84 |
85 | next_token_ids = torch.multinomial(probs,
86 | num_samples=1,
87 | replacement=True).squeeze(dim=-1)
88 | return next_token_ids, logits
89 |
90 |
91 | def precompute_freqs_cis(dim: int,
92 | end: int,
93 | theta: float = 10000.0,
94 | rope_scaling_factor:int = 1) -> torch.Tensor:
95 | """Precomputes the frequency cis."""
96 | freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
97 | freqs = freqs/rope_scaling_factor
98 | t = torch.arange(end, device=freqs.device)
99 | freqs = torch.outer(t, freqs).float()
100 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
101 | return freqs_cis
102 |
103 |
104 | def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
105 |     """Applies the rotary embedding to the input tensor (used for both queries and keys)."""
106 | x_ = torch.view_as_complex(
107 | torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),
108 | dim=-1))
109 | x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
110 | x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
111 | x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
112 | -1).transpose(1, 2)
113 | return x_out
114 |
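# Shape sketch (illustrative): for head_dim D and max position P,
# precompute_freqs_cis(D, P) yields a complex64 table of shape [P, D // 2];
# apply_rotary_emb takes x of shape [batch, seq_len, num_heads, D] plus the
# freqs_cis rows for the current positions and returns a tensor of the same shape.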
115 |
116 | class Linear(nn.Module):
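    """Checkpoint-loaded linear layer with optional int8 weight quantization.

    With `quant=True` the weight is stored as int8 and rescaled at forward time
    by a per-output-channel `weight_scaler`.
    """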
117 |
118 | def __init__(self, in_features: int, out_features: int, quant: bool):
119 | super().__init__()
120 | if quant:
121 | self.weight = nn.Parameter(
122 | torch.empty((out_features, in_features), dtype=torch.int8),
123 | requires_grad=False,
124 | )
125 | self.weight_scaler = nn.Parameter(torch.Tensor(out_features))
126 | else:
127 | self.weight = nn.Parameter(
128 | torch.empty((out_features, in_features)),
129 | requires_grad=False,
130 | )
131 | self.quant = quant
132 |
133 | def forward(self, x):
134 | weight = self.weight
135 | if self.quant:
136 | weight = weight * self.weight_scaler.unsqueeze(-1)
137 | output = F.linear(x, weight)
138 | return output
139 |
140 |
141 | class Embedding(nn.Module):
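    """Token embedding table with the same optional int8 quantization scheme as Linear."""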
142 |
143 | def __init__(self, num_embeddings: int, embedding_dim: int, quant: bool):
144 | super().__init__()
145 | if quant:
146 | self.weight = nn.Parameter(
147 | torch.empty((num_embeddings, embedding_dim), dtype=torch.int8),
148 | requires_grad=False,
149 | )
150 | self.weight_scaler = nn.Parameter(torch.Tensor(num_embeddings))
151 | else:
152 | self.weight = nn.Parameter(
153 | torch.empty((num_embeddings, embedding_dim)),
154 | requires_grad=False,
155 | )
156 | self.quant = quant
157 |
158 | def forward(self, x):
159 | weight = self.weight
160 | if self.quant:
161 | weight = weight * self.weight_scaler.unsqueeze(-1)
162 | output = F.embedding(x, weight)
163 | return output
164 |
165 |
166 | class RMSNorm(torch.nn.Module):
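    """Root-mean-square layer norm.

    With `add_unit_offset=True` the effective scale is (1 + weight), so the
    zero-initialized weight corresponds to an identity scale.
    """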
167 |
168 | def __init__(
169 | self,
170 | dim: int,
171 | eps: float = 1e-6,
172 | add_unit_offset: bool = True,
173 | ):
174 | super().__init__()
175 | self.eps = eps
176 | self.add_unit_offset = add_unit_offset
177 | self.weight = nn.Parameter(torch.zeros(dim))
178 |
179 | def _norm(self, x):
180 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
181 |
182 | def forward(self, x):
183 | # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
184 | # See https://github.com/huggingface/transformers/pull/29402
185 | output = self._norm(x.float())
186 | if self.add_unit_offset:
187 | output = output * (1 + self.weight.float())
188 | else:
189 | output = output * self.weight.float()
190 | return output.type_as(x)
191 |
192 |
193 | class GemmaMLP(nn.Module):
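    """Gated (GeGLU-style) MLP: down_proj(gelu(gate_proj(x)) * up_proj(x))."""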
194 |
195 | def __init__(
196 | self,
197 | hidden_size: int,
198 | intermediate_size: int,
199 | quant: bool,
200 | ):
201 | super().__init__()
202 | self.gate_proj = Linear(hidden_size, intermediate_size, quant)
203 | self.up_proj = Linear(hidden_size, intermediate_size, quant)
204 | self.down_proj = Linear(intermediate_size, hidden_size, quant)
205 |
206 | def forward(self, x):
207 | gate = self.gate_proj(x)
208 | gate = F.gelu(gate, approximate="tanh")
209 | up = self.up_proj(x)
210 | fuse = gate * up
211 | outputs = self.down_proj(fuse)
212 | return outputs
213 |
214 |
215 | class GemmaAttention(nn.Module):
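    """Multi-head attention with grouped-query KV heads and an external KV cache.

    Supports optional query/key RMSNorm, attention-logit softcapping, and a
    sliding-window mask for LOCAL_SLIDING layers.
    """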
216 |
217 | def __init__(
218 | self,
219 | config: gemma_config.GemmaConfig,
220 | attn_type: gemma_config.AttentionType,
221 | ):
222 | super().__init__()
223 |
224 | self.num_heads = config.num_attention_heads
225 | self.num_kv_heads = config.num_key_value_heads
226 |
227 | assert self.num_heads % self.num_kv_heads == 0
228 | self.num_queries_per_kv = self.num_heads // self.num_kv_heads
229 |
230 | self.hidden_size = config.hidden_size
231 | self.head_dim = config.head_dim
232 |
233 | self.q_size = self.num_heads * self.head_dim
234 | self.kv_size = self.num_kv_heads * self.head_dim
235 |
236 | if config.query_pre_attn_scalar is not None:
237 | self.scaling = config.query_pre_attn_scalar**-0.5
238 | else:
239 | self.scaling = self.head_dim**-0.5
240 |
241 | self.qkv_proj = Linear(
242 | self.hidden_size,
243 | (self.num_heads + 2 * self.num_kv_heads) * self.head_dim,
244 | quant=config.quant)
245 | self.o_proj = Linear(
246 | self.num_heads * self.head_dim, self.hidden_size, quant=config.quant
247 | )
248 | self.query_norm = (
249 | RMSNorm(self.head_dim, eps=config.rms_norm_eps)
250 | if config.use_qk_norm
251 | else None
252 | )
253 | self.key_norm = (
254 | RMSNorm(self.head_dim, eps=config.rms_norm_eps)
255 | if config.use_qk_norm
256 | else None
257 | )
258 |
259 | self.attn_type = attn_type
260 | self.sliding_window_size = config.sliding_window_size
261 | self.attn_logit_softcapping = config.attn_logit_softcapping
262 |
263 | def forward(
264 | self,
265 | hidden_states: torch.Tensor,
266 | freqs_cis: torch.Tensor,
267 | kv_write_indices: torch.Tensor,
268 | kv_cache: Tuple[torch.Tensor, torch.Tensor],
269 | mask: torch.Tensor,
270 |         local_mask: Optional[torch.Tensor] = None,
271 | ) -> torch.Tensor:
272 | hidden_states_shape = hidden_states.shape
273 | assert len(hidden_states_shape) == 3
274 |
275 | batch_size, input_len, _ = hidden_states_shape
276 |
277 | qkv = self.qkv_proj(hidden_states)
278 | xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
279 | dim=-1)
280 |
281 | xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
282 | xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
283 | xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)
284 |
285 | if self.query_norm is not None and self.key_norm is not None:
286 | xq = self.query_norm(xq)
287 | xk = self.key_norm(xk)
288 |
289 | # Positional embedding.
290 | xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
291 | xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)
292 |
293 | # Write new kv cache.
294 | # [batch_size, input_len, n_local_kv_heads, head_dim]
295 | k_cache, v_cache = kv_cache
296 | k_cache.index_copy_(1, kv_write_indices, xk)
297 | v_cache.index_copy_(1, kv_write_indices, xv)
298 |
299 | key = k_cache
300 | value = v_cache
301 | if self.num_kv_heads != self.num_heads:
302 | # [batch_size, max_seq_len, n_local_heads, head_dim]
303 | key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
304 | value = torch.repeat_interleave(value,
305 | self.num_queries_per_kv,
306 | dim=2)
307 |
308 | # [batch_size, n_local_heads, input_len, head_dim]
309 | q = xq.transpose(1, 2)
310 | # [batch_size, n_local_heads, max_seq_len, head_dim]
311 | k = key.transpose(1, 2)
312 | v = value.transpose(1, 2)
313 |
314 | # [batch_size, n_local_heads, input_len, max_seq_len]
315 | q.mul_(self.scaling)
316 | scores = torch.matmul(q, k.transpose(2, 3))
317 | if (
318 | self.attn_type == gemma_config.AttentionType.LOCAL_SLIDING
319 | and self.sliding_window_size is not None
320 | and local_mask is not None
321 | ):
322 | mask = local_mask
323 |
324 | if self.attn_logit_softcapping is not None:
325 | scores = scores / self.attn_logit_softcapping
326 | scores = torch.tanh(scores)
327 | scores = scores * self.attn_logit_softcapping
328 |
329 | scores = scores + mask
330 | scores = F.softmax(scores.float(), dim=-1).type_as(q)
331 |
332 | # [batch_size, n_local_heads, input_len, head_dim]
333 | output = torch.matmul(scores, v)
334 |
335 | # [batch_size, input_len, hidden_dim]
336 | output = (output.transpose(1, 2).contiguous().view(
337 | batch_size, input_len, -1))
338 | output = self.o_proj(output)
339 | return output
340 |
341 |
342 | class GemmaDecoderLayer(nn.Module):
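    """Gemma 1 decoder block: pre-norm global attention followed by a pre-norm MLP."""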
343 |
344 | def __init__(
345 | self,
346 | config: gemma_config.GemmaConfig,
347 | ):
348 | super().__init__()
349 | self.attn_type = gemma_config.AttentionType.GLOBAL
350 | self.self_attn = GemmaAttention(
351 | config=config,
352 | attn_type=self.attn_type)
353 | self.mlp = GemmaMLP(
354 | hidden_size=config.hidden_size,
355 | intermediate_size=config.intermediate_size,
356 | quant=config.quant,
357 | )
358 | self.input_layernorm = RMSNorm(config.hidden_size,
359 | eps=config.rms_norm_eps)
360 | self.post_attention_layernorm = RMSNorm(config.hidden_size,
361 | eps=config.rms_norm_eps)
362 |
363 | # TODO(imayank): Decouple Gemma versions into separate files.
364 | def forward(
365 | self,
366 | hidden_states: torch.Tensor,
367 | freqs_cis: torch.Tensor,
368 | kv_write_indices: torch.Tensor,
369 | kv_cache: Tuple[torch.Tensor, torch.Tensor],
370 | mask: torch.Tensor,
371 |         local_mask: Optional[torch.Tensor],  # Unused; kept for interface parity.
372 | ) -> torch.Tensor:
373 | # Self Attention
374 | residual = hidden_states
375 | hidden_states = self.input_layernorm(hidden_states)
376 | hidden_states = self.self_attn(
377 | hidden_states=hidden_states,
378 | freqs_cis=freqs_cis,
379 | kv_write_indices=kv_write_indices,
380 | kv_cache=kv_cache,
381 | mask=mask,
382 | )
383 | hidden_states = residual + hidden_states
384 |
385 | # MLP
386 | residual = hidden_states
387 | hidden_states = self.post_attention_layernorm(hidden_states)
388 | hidden_states = self.mlp(hidden_states)
389 | hidden_states = residual + hidden_states
390 |
391 | return hidden_states
392 |
393 |
394 | class Gemma2DecoderLayer(nn.Module):
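    """Gemma 2/3 decoder block with post-attention and optional pre/post-FFW norms."""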
395 | def __init__(
396 | self,
397 | config: gemma_config.GemmaConfig,
398 | attn_type: gemma_config.AttentionType,
399 | ):
400 | super().__init__()
401 | self.attn_type = attn_type
402 | self.self_attn = GemmaAttention(
403 | config=config,
404 | attn_type=self.attn_type,
405 | )
406 | self.mlp = GemmaMLP(
407 | hidden_size=config.hidden_size,
408 | intermediate_size=config.intermediate_size,
409 | quant=config.quant,
410 | )
411 | self.input_layernorm = RMSNorm(config.hidden_size,
412 | eps=config.rms_norm_eps)
413 | self.post_attention_layernorm = RMSNorm(config.hidden_size,
414 | eps=config.rms_norm_eps)
415 | self.pre_feedforward_layernorm = (
416 | RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
417 | if config.use_pre_ffw_norm
418 | else None
419 | )
420 | self.post_feedforward_layernorm = (
421 | RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
422 | if config.use_post_ffw_norm
423 | else None
424 | )
425 |
426 | def forward(
427 | self,
428 | hidden_states: torch.Tensor,
429 | freqs_cis: torch.Tensor,
430 | kv_write_indices: torch.Tensor,
431 | kv_cache: Tuple[torch.Tensor, torch.Tensor],
432 | mask: torch.Tensor,
433 |         local_mask: Optional[torch.Tensor],
434 | ) -> torch.Tensor:
435 | # Self Attention
436 | residual = hidden_states
437 | hidden_states = self.input_layernorm(hidden_states)
438 | hidden_states = self.self_attn(
439 | hidden_states=hidden_states,
440 | freqs_cis=freqs_cis,
441 | kv_write_indices=kv_write_indices,
442 | kv_cache=kv_cache,
443 | mask=mask,
444 | local_mask=local_mask,
445 | )
446 | hidden_states = self.post_attention_layernorm(hidden_states)
447 | hidden_states = residual + hidden_states
448 |
449 | # MLP
450 | residual = hidden_states
451 | if self.pre_feedforward_layernorm is not None:
452 | hidden_states = self.pre_feedforward_layernorm(hidden_states)
453 | hidden_states = self.mlp(hidden_states)
454 | if self.post_feedforward_layernorm is not None:
455 | hidden_states = self.post_feedforward_layernorm(hidden_states)
456 | hidden_states = residual + hidden_states
457 |
458 | return hidden_states
459 |
460 |
461 | class GemmaModel(nn.Module):
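    """Decoder stack followed by a final RMSNorm.

    Gemma 1 uses global attention in every layer; Gemma 2/3 cycle through
    `config.attn_types` (e.g. alternating local-sliding and global layers).
    """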
462 |
463 | def __init__(self, config: gemma_config.GemmaConfig):
464 | super().__init__()
465 | self.config = config
466 | self.vocab_size = config.vocab_size
467 |
468 | self.layers = nn.ModuleList()
469 | for i in range(config.num_hidden_layers):
470 | if config.architecture == gemma_config.Architecture.GEMMA_1:
471 | self.layers.append(GemmaDecoderLayer(config))
472 | elif config.architecture in (
473 | gemma_config.Architecture.GEMMA_2,
474 | gemma_config.Architecture.GEMMA_3,
475 | ):
476 | attn_type = (
477 | config.attn_types[i % len(config.attn_types)]
478 | if config.attn_types is not None
479 | else gemma_config.AttentionType.GLOBAL
480 | )
481 | self.layers.append(Gemma2DecoderLayer(config, attn_type))
482 | else:
483 | raise ValueError(f'Unknown architecture: {config.architecture}')
484 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
485 |
486 | def forward(
487 | self,
488 | hidden_states: torch.Tensor,
489 | freqs_cis: Mapping[gemma_config.AttentionType, torch.Tensor],
490 | kv_write_indices: torch.Tensor,
491 | kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
492 | mask: torch.Tensor,
493 |         local_mask: Optional[torch.Tensor],
494 | ) -> torch.Tensor:
495 | for i in range(len(self.layers)):
496 | layer = self.layers[i]
497 | hidden_states = layer(
498 | hidden_states=hidden_states,
499 | freqs_cis=freqs_cis.get(layer.attn_type),
500 | kv_write_indices=kv_write_indices,
501 | kv_cache=kv_caches[i],
502 | mask=mask,
503 | local_mask=local_mask,
504 | )
505 | hidden_states = self.norm(hidden_states)
506 | return hidden_states
507 |
508 |
509 | class GemmaForCausalLM(nn.Module):
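    """Causal LM wrapper: tokenizer, embedder, decoder stack, rotary tables, and sampler.

    Gemma 3 registers separate rotary tables for local sliding-window and global
    attention; earlier architectures share a single table.
    """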
510 |
511 | def __init__(
512 | self,
513 | config: gemma_config.GemmaConfig,
514 | ):
515 | super().__init__()
516 | self.config = config
517 | assert config.hidden_size % config.num_attention_heads == 0
518 |
519 | max_seq_len = config.max_position_embeddings
520 | head_dim = config.head_dim
521 | vocab_size = config.vocab_size
522 |
523 | self.tokenizer = tokenizer.Tokenizer(config.tokenizer)
524 | self.embedder = Embedding(vocab_size, config.hidden_size, config.quant)
525 | self.model = GemmaModel(config)
526 | self.sampler = Sampler(vocab_size, config)
527 |
528 | # Pre-compute rotary embedding table.
529 | if config.architecture == gemma_config.Architecture.GEMMA_3:
530 | if config.rope_wave_length is None:
531 | raise ValueError('rope_wave_length must be provided for Gemma3.')
532 |
533 | rope_lengths = config.rope_wave_length
534 | defaults = {
535 | gemma_config.AttentionType.LOCAL_SLIDING: 10_000,
536 | gemma_config.AttentionType.GLOBAL: 10_000,
537 | }
538 |
539 | for attn_type, name in [
540 | (gemma_config.AttentionType.LOCAL_SLIDING, 'local_freqs_cis'),
541 | (gemma_config.AttentionType.GLOBAL, 'global_freqs_cis'),
542 | ]:
543 | theta = rope_lengths.get(
544 | attn_type, defaults[attn_type]
545 | )
546 | self._register_freqs_cis(name, head_dim, max_seq_len, theta=theta)
547 |
548 | else:
549 | self._register_freqs_cis('freqs_cis', head_dim, max_seq_len)
550 |
551 | def _register_freqs_cis(
552 | self, name: str, head_dim: int, max_seq_len: int, theta: int = 10_000
553 | ):
554 | self.register_buffer(
555 | name, precompute_freqs_cis(head_dim, max_seq_len * 2, theta=theta)
556 | )
557 |
558 | @torch.no_grad()
559 | def forward(
560 | self,
561 | input_token_ids: torch.Tensor,
562 | input_positions: torch.Tensor,
563 | kv_write_indices: torch.Tensor,
564 | kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
565 | mask: torch.Tensor,
566 | output_positions: torch.Tensor,
567 | temperatures: Union[torch.Tensor, None],
568 | top_ps: torch.Tensor,
569 | top_ks: torch.Tensor,
570 |         local_mask: Optional[torch.Tensor] = None,
571 | **kwargs,
572 | ) -> Tuple[torch.Tensor, torch.Tensor]:
573 | freqs_cis = {}
574 |
575 | if self.config.architecture == gemma_config.Architecture.GEMMA_3:
576 | freqs_cis[gemma_config.AttentionType.LOCAL_SLIDING] = (
577 | self.local_freqs_cis.index_select(0, input_positions)
578 | )
579 | freqs_cis[gemma_config.AttentionType.GLOBAL] = (
580 | self.global_freqs_cis.index_select(0, input_positions)
581 | )
582 | else:
583 | freqs_cis[gemma_config.AttentionType.LOCAL_SLIDING] = (
584 | self.freqs_cis.index_select(0, input_positions)
585 | )
586 | freqs_cis[gemma_config.AttentionType.GLOBAL] = (
587 | self.freqs_cis.index_select(0, input_positions)
588 | )
589 |
590 | kv_write_indices = input_positions
591 |
592 | # [batch_size, input_len, hidden_size]
593 | hidden_states = self.embedder(input_token_ids)
594 | # Gemma normalizes the embedding by sqrt(hidden_size).
595 | # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
596 | # See https://github.com/huggingface/transformers/pull/29402
597 | normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype, device=hidden_states.device)
598 | hidden_states = hidden_states * normalizer
599 |
600 | hidden_states = self.model(
601 | hidden_states=hidden_states,
602 | freqs_cis=freqs_cis,
603 | kv_write_indices=kv_write_indices,
604 | kv_caches=kv_caches,
605 | mask=mask,
606 | local_mask=local_mask,
607 | )
608 | embedder_weight = self.embedder.weight
609 | if self.config.quant:
610 | embedder_weight = (
611 | embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))
612 | next_tokens, logits = self.sampler(
613 | embedding=embedder_weight,
614 | hidden_states=hidden_states,
615 | output_positions=output_positions,
616 | temperatures=temperatures,
617 | top_ps=top_ps,
618 | top_ks=top_ks,
619 | )
620 | return next_tokens, logits
621 |
622 | def generate(
623 | self,
624 | prompts: Union[str, Sequence[str]],
625 | device: Any,
626 | output_len: int = 100,
627 | temperature: Union[float, None] = 1.0,
628 | top_p: float = 0.95,
629 | top_k: int = 64,
630 | ) -> Union[str, Sequence[str]]:
631 |         """Generates responses for the given prompts using the Gemma model."""
632 | # If a single prompt is provided, treat it as a batch of 1.
633 | is_str_prompt = isinstance(prompts, str)
634 | if is_str_prompt:
635 | prompts = [prompts]
636 |
637 | batch_size = len(prompts)
638 | prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts]
639 | min_prompt_len = min(len(p) for p in prompt_tokens)
640 | max_prompt_len = max(len(p) for p in prompt_tokens)
641 | max_seq_len = max_prompt_len + output_len
642 | assert max_seq_len <= self.config.max_position_embeddings
643 |
644 |         # Build the per-layer KV caches.
645 | kv_caches = []
646 | for _ in range(self.config.num_hidden_layers):
647 | size = (batch_size, max_seq_len, self.config.num_key_value_heads,
648 | self.config.head_dim)
649 | dtype = self.config.get_dtype()
650 | k_cache = torch.zeros(size=size, dtype=dtype, device=device)
651 | v_cache = torch.zeros(size=size, dtype=dtype, device=device)
652 | kv_caches.append((k_cache, v_cache))
653 |
654 |         # Prepare input, position, and mask tensors.
655 | token_ids_tensor = torch.full((batch_size, max_seq_len),
656 | self.tokenizer.pad_id, dtype=torch.int64)
657 | input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
658 | self.tokenizer.pad_id,
659 | dtype=torch.int64)
660 | for i, p in enumerate(prompt_tokens):
661 | token_ids_tensor[i, :len(p)] = torch.tensor(p)
662 | input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
663 | p[:min_prompt_len])
664 | token_ids_tensor = token_ids_tensor.to(device)
665 | input_token_ids_tensor = input_token_ids_tensor.to(device)
666 | prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id
667 | input_positions_tensor = torch.arange(0, min_prompt_len,
668 | dtype=torch.int64).to(device)
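        # Causal mask: -2.3819763e38 is a large finite negative value standing in
        # for -inf, so masked positions receive ~zero probability after softmax.
        # The local mask additionally blocks positions beyond the sliding window.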
669 | mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
670 | -2.3819763e38).to(torch.float)
671 | mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
672 | local_mask_tensor = mask_tensor + torch.tril(
673 | torch.full((1, 1, max_seq_len, max_seq_len), -2.3819763e38, device=device),
674 | diagonal=-self.config.sliding_window_size,
675 | ) if self.config.sliding_window_size else None
676 | curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
677 | curr_local_mask_tensor = local_mask_tensor.index_select(
678 | 2, input_positions_tensor
679 | ) if local_mask_tensor is not None else None
680 | output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device)
681 | temperatures_tensor = None if not temperature else torch.FloatTensor(
682 | [temperature] * batch_size).to(device)
683 | top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
684 | top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
685 | output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(
686 | device)
687 |
688 |         # Prefill up to min_prompt_len tokens, then decode one token at a time;
689 |         # positions still covered by a prompt keep the prompt token (see below).
690 | for i in range(max_seq_len - min_prompt_len):
691 | next_token_ids, _ = self(
692 | input_token_ids=input_token_ids_tensor,
693 | input_positions=input_positions_tensor,
694 | kv_write_indices=None,
695 | kv_caches=kv_caches,
696 | mask=curr_mask_tensor,
697 | output_positions=output_positions_tensor,
698 | temperatures=temperatures_tensor,
699 | top_ps=top_ps_tensor,
700 | top_ks=top_ks_tensor,
701 | local_mask=curr_local_mask_tensor,
702 | )
703 |
704 | curr_prompt_mask = prompt_mask_tensor.index_select(
705 | 1, output_index).squeeze(dim=1)
706 | curr_token_ids = token_ids_tensor.index_select(
707 | 1, output_index).squeeze(dim=1)
708 | output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
709 | next_token_ids).unsqueeze(dim=1)
710 | token_ids_tensor.index_copy_(1, output_index, output_token_ids)
711 |
712 | input_token_ids_tensor = output_token_ids
713 | input_positions_tensor = output_index.unsqueeze(dim=-1)
714 | curr_mask_tensor = mask_tensor.index_select(2,
715 | input_positions_tensor)
716 | curr_local_mask_tensor = local_mask_tensor.index_select(
717 | 2, input_positions_tensor
718 | ) if local_mask_tensor is not None else None
719 | output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(
720 | device)
721 | output_index = output_index + 1
722 |
723 | # Detokenization.
724 | token_ids = token_ids_tensor.tolist()
725 | results = []
726 | for i, tokens in enumerate(token_ids):
727 | trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i])
728 | + output_len]
729 | if self.tokenizer.eos_id in trimmed_output:
730 | eos_index = trimmed_output.index(self.tokenizer.eos_id)
731 | trimmed_output = trimmed_output[:eos_index]
732 | results.append(self.tokenizer.decode(trimmed_output))
733 |
734 | # If a string was provided as input, return a string as output.
735 | return results[0] if is_str_prompt else results
736 |
737 | def load_weights(self, model_path: str):
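        """Loads either a single checkpoint file or a sharded checkpoint via its index JSON."""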
738 | if os.path.isfile(model_path):
739 | self.load_state_dict(
740 | torch.load(
741 | model_path, mmap=True, weights_only=True,
742 | )['model_state_dict'],
743 | strict=False,
744 | )
745 | else:
746 | index_path = os.path.join(model_path, 'pytorch_model.bin.index.json')
747 | with open(index_path, "r", encoding="utf-8") as f:
748 | index = json.load(f)
749 | shard_files = list(set(index["weight_map"].values()))
750 | for shard_file in shard_files:
751 | shard_path = os.path.join(model_path, shard_file)
752 | state_dict = torch.load(shard_path, map_location="cpu", weights_only=True)
753 | self.load_state_dict(state_dict, strict=False)
754 | del state_dict # Save memory.
755 | gc.collect()
756 |
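
if __name__ == '__main__':
    # Minimal usage sketch (illustrative only, not part of the library).
    # Assumptions: `gemma_config.GemmaConfig()` provides usable defaults, the
    # SentencePiece model referenced by `config.tokenizer` exists on disk, and
    # CKPT_PATH (a placeholder) points to a compatible checkpoint.
    CKPT_PATH = '/path/to/gemma/checkpoint.ckpt'
    config = gemma_config.GemmaConfig()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GemmaForCausalLM(config)
    model.load_weights(CKPT_PATH)
    model = model.to(device).eval()
    print(model.generate('Why is the sky blue?', device, output_len=32))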
--------------------------------------------------------------------------------