├── .github
│   └── workflows
│       └── unit_tests.yaml
├── .gitignore
├── .gitmodules
├── .pylintrc
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── benchmarks
│   ├── __init__.py
│   ├── analyze_sharegpt.py
│   ├── basic_ops.py
│   ├── mixtral_offline.sh
│   └── summary.md
├── default_shardings
│   ├── gemma.yaml
│   ├── llama-blockwise-quant.yaml
│   ├── llama.yaml
│   └── mixtral.yaml
├── docker
│   └── jetstream-pytorch-server
│       ├── Dockerfile
│       ├── README.md
│       └── jetstream_pytorch_server_entrypoint.sh
├── docs
│   ├── add_a_new_model.md
│   └── add_hf_checkpoint_conversion.md
├── install_everything.sh
├── install_everything_gpu.sh
├── jetstream_pt
│   ├── __init__.py
│   ├── attention_kernel.py
│   ├── cache_manager.py
│   ├── cli.py
│   ├── config.py
│   ├── engine.py
│   ├── environment.py
│   ├── fetch_models.py
│   ├── gcs_to_cns.sh
│   ├── hf_tokenizer.py
│   ├── layers.py
│   ├── model_base.py
│   ├── page_attention_manager.py
│   ├── quantize.py
│   ├── quantize_model.py
│   ├── ray_engine.py
│   ├── ray_worker.py
│   ├── third_party
│   │   ├── gemma
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── model.py
│   │   │   ├── model_original.py
│   │   │   └── tokenizer.py
│   │   ├── llama
│   │   │   ├── LICENSE
│   │   │   ├── __init__.py
│   │   │   ├── generation_original.py
│   │   │   ├── model_args.py
│   │   │   ├── model_exportable.py
│   │   │   ├── model_original.py
│   │   │   ├── tokenizer.model
│   │   │   └── tokenizer.py
│   │   └── mixtral
│   │       ├── __init__.py
│   │       ├── config.py
│   │       ├── model.py
│   │       ├── model_original.py
│   │       └── tokenizer.model
│   └── torchjax.py
├── kuberay
│   ├── image
│   │   └── Dockerfile
│   └── manifests
│       ├── ray-cluster.tpu-v4-multihost.yaml
│       ├── ray-cluster.tpu-v4-singlehost.yaml
│       ├── ray-cluster.tpu-v5-multihost.yaml
│       └── ray-cluster.tpu-v5-singlehost.yaml
├── mlperf
│   ├── README.md
│   ├── backend.py
│   ├── benchmark_run.sh
│   ├── dataset.py
│   ├── install.sh
│   ├── main.py
│   ├── mlperf.conf
│   ├── start_server.sh
│   ├── user.conf
│   └── warmup.py
├── poetry.lock
├── pyproject.toml
├── run_interactive_disaggregated.py
├── run_interactive_multiple_host.py
├── run_ray_serve_interleave.py
├── run_server_with_ray.py
├── scripts
│   ├── create_empty_sharding_map.py
│   ├── jax_experiments.py
│   └── validate_hf_ckpt_conversion.py
└── tests
    ├── .pylintrc
    ├── __init__.py
    ├── helpers.py
    ├── test_attention_kernal.py
    ├── test_engine.py
    ├── test_hf_names.py
    ├── test_jax_torch.py
    ├── test_kv_cache_manager.py
    ├── test_llama_e2e.py
    ├── test_model_impl.py
    ├── test_page_attention.py
    └── test_quantization.py
/.github/workflows/unit_tests.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
16 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
17 |
18 | name: Unit Tests
19 |
20 | on:
21 | pull_request:
22 | push:
23 | branches: [ "main" ]
24 | workflow_dispatch:
25 | schedule:
26 | # Run the job every 12 hours
27 | - cron: '0 */12 * * *'
28 |
29 | jobs:
30 | py:
31 | name: "Python type/lint/format checks"
32 | strategy:
33 | matrix:
34 | os: [ubuntu-20.04]
35 | python-version: ['3.10']
36 | runs-on: ${{ matrix.os }}
37 | steps:
38 | - name: Checkout
39 | uses: actions/checkout@v4
40 | - name: Setup Python
41 | uses: actions/setup-python@v4
42 | with:
43 | python-version: ${{ matrix.python-version }}
44 | - name: Install Dependencies
45 | run: |
46 | pip install pytype
47 | pip install pylint
48 | pip install pyink
49 | source install_everything.sh
50 | # - name: Typecheck the code with pytype
51 | # run: |
52 | # pytype --jobs auto --disable import-error --disable module-attr jetstream_pt/
53 | - name: Analysing the code with pylint
54 | run: |
55 | pylint --indent-string=' ' jetstream_pt/ benchmarks/
56 | - name: Format check with pyink
57 | run: |
58 | pyink --pyink-indentation 2 --line-length 80 --check --verbose --extend-exclude=deps .
59 |
60 | cpu:
61 | name: "jetstream_pt unit tests"
62 | strategy:
63 | matrix:
64 | os: [ubuntu-20.04]
65 | python-version: ['3.10']
66 | runs-on: ${{ matrix.os }}
67 | steps:
68 | - name: Checkout
69 | uses: actions/checkout@v4
70 | - name: Setup Python
71 | uses: actions/setup-python@v4
72 | with:
73 | python-version: ${{ matrix.python-version }}
74 | - name: Install Dependencies
75 | run: |
76 | source install_everything.sh
77 | - name: Run all unit tests for jetstream_pt (/tests)
78 | run: |
79 | JAX_PLATFORMS=cpu coverage run -m unittest -v
80 | - name: Create test coverage report
81 | run: |
82 | coverage report -m
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # vscode
156 | .vscode/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "deps/JetStream"]
2 | path = deps/JetStream
3 | url = https://github.com/google/JetStream.git
4 | [submodule "deps/xla"]
5 | path = deps/xla
6 | url = https://github.com/pytorch/xla.git
7 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MESSAGES CONTROL]
2 | disable=C0114,R0801,R0903,R0913,R0917,E1102,W0613,R1711,too-many-locals
3 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to contribute
2 |
3 | We'd love to accept your patches and contributions to this project.
4 |
5 | ## Before you begin
6 |
7 | ### Sign our Contributor License Agreement
8 |
9 | Contributions to this project must be accompanied by a
10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
11 | You (or your employer) retain the copyright to your contribution; this simply
12 | gives us permission to use and redistribute your contributions as part of the
13 | project.
14 |
15 | If you or your current employer have already signed the Google CLA (even if it
16 | was for a different project), you probably don't need to do it again.
17 |
18 | Visit <https://cla.developers.google.com/> to see your current agreements or to
19 | sign a new one.
20 |
21 | ### Review our community guidelines
22 |
23 | This project follows
24 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
25 |
26 | ## Contribution process
27 |
28 | ### Code reviews
29 |
30 | All submissions, including submissions by project members, require review. We
31 | use GitHub pull requests for this purpose. Consult
32 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
33 | information on using pull requests.
34 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Jetstream-PyTorch
2 | JetStream Engine implementation in PyTorch
3 |
4 | # Latest Release:
5 |
6 | The latest release version is tagged with `jetstream-v0.2.3`. If you are running the release version,
7 | please follow the README of that version here:
8 | https://github.com/google/jetstream-pytorch/blob/jetstream-v0.2.3/README.md
9 |
10 | Command-line flags might have changed between the release version and HEAD.
11 |
12 | # Outline
13 |
14 | 1. Ssh to Cloud TPU VM (using v5e-8 TPU VM)
15 | a. Create a Cloud TPU VM if you haven’t
16 | 2. Download jetstream-pytorch github repo
17 | 3. Run the server
18 | 4. Run benchmarks
19 | 5. Typical Errors
20 |
21 | # Ssh to Cloud TPU VM (using v5e-8 TPU VM)
22 |
23 | ```bash
24 | gcloud compute config-ssh
25 | gcloud compute tpus tpu-vm ssh "your-tpu-vm" --project "your-project" --zone "your-project-zone"
26 | ```
27 | ## Create a Cloud TPU VM in a GCP project if you haven’t
28 | Follow the steps in
29 | * https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm
30 |
31 | # Clone repo and install dependencies
32 |
33 | ## 1. Get the jetstream-pytorch code
34 | ```bash
35 | git clone https://github.com/google/jetstream-pytorch.git
36 | git -C jetstream-pytorch checkout jetstream-v0.2.4
37 | ```
38 |
39 | (optional) Create a virtual env using `venv` or `conda` and activate it.
40 |
41 | ## 2. Run installation script:
42 |
43 | ```bash
44 | cd jetstream-pytorch
45 | source install_everything.sh
46 | ```
47 |
48 |
49 | # Run jetstream pytorch
50 |
51 | ## List out supported models
52 |
53 | ```
54 | jpt list
55 | ```
56 |
57 | This will print out the list of supported models and variants:
58 |
59 | ```
60 | meta-llama/Llama-2-7b-chat-hf
61 | meta-llama/Llama-2-7b-hf
62 | meta-llama/Llama-2-13b-chat-hf
63 | meta-llama/Llama-2-13b-hf
64 | meta-llama/Llama-2-70b-hf
65 | meta-llama/Llama-2-70b-chat-hf
66 | meta-llama/Meta-Llama-3-8B
67 | meta-llama/Meta-Llama-3-8B-Instruct
68 | meta-llama/Meta-Llama-3-70B
69 | meta-llama/Meta-Llama-3-70B-Instruct
70 | meta-llama/Llama-3.1-8B
71 | meta-llama/Llama-3.1-8B-Instruct
72 | meta-llama/Llama-3.2-1B
73 | meta-llama/Llama-3.2-1B-Instruct
74 | meta-llama/Llama-3.3-70B
75 | meta-llama/Llama-3.3-70B-Instruct
76 | google/gemma-2b
77 | google/gemma-2b-it
78 | google/gemma-7b
79 | google/gemma-7b-it
80 | mistralai/Mixtral-8x7B-v0.1
81 | mistralai/Mixtral-8x7B-Instruct-v0.1
82 | ```
83 |
84 | To run the jetstream-pytorch server with one model:
85 |
86 | ```
87 | jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct
88 | ```
89 |
90 | If it's the first time you run this model, it will download weights from
91 | HuggingFace.
92 |
93 | HuggingFace's Llama3 weights are gated, so you need to either run
94 | `huggingface-cli login` to set your token, or pass your hf_token explicitly.
95 |
96 | To pass the HF token explicitly, add the `--hf_token` flag:
97 | ```
98 | jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct --hf_token=...
99 | ```
100 |
101 | To log in using the HuggingFace Hub CLI, run:
102 |
103 | ```
104 | pip install -U "huggingface_hub[cli]"
105 | huggingface-cli login
106 | ```
107 | Then follow its prompt.
108 |
109 | After the weights are downloaded,
110 | the `--hf_token` flag will no longer be required on subsequent runs.
111 |
112 | To run this model with `int8` quantization, add `--quantize_weights=1`.
113 | Quantization will be done on the fly as the weights load.
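For example, to serve the Llama 3 model above with on-the-fly int8 quantization:

```bash
jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct --quantize_weights=1
```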
114 |
115 | Weights downloaded from HuggingFace will be stored by default in the `checkpoints` folder
116 | in the directory where `jpt` is executed.
117 |
118 | You can change where the weights are stored with `--working_dir` flag.
119 |
120 | If you wish to use your own checkpoint, place it inside
121 | the `checkpoints/<org>/<model>/hf_original` dir (or the corresponding subdir in `--working_dir`). For example,
122 | Llama 2 7B checkpoints would be at `checkpoints/meta-llama/Llama-2-7b-hf/hf_original/*.safetensors`. You can replace these files with modified
123 | weights in HuggingFace format.
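For reference, a sketch of the expected layout (the shard file names below are only illustrative):

```
checkpoints/
└── meta-llama/
    └── Llama-2-7b-hf/
        └── hf_original/
            ├── model-00001-of-00002.safetensors
            └── model-00002-of-00002.safetensors
```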
124 |
125 | ## Send one request
126 |
127 | Jetstream-pytorch uses gRPC for handling requests. The script below demonstrates how to
128 | send a gRPC request in Python. You can also use other gRPC clients.
129 |
130 | ```python
131 | import requests
132 | import os
133 | import grpc
134 |
135 | from jetstream.core.proto import jetstream_pb2
136 | from jetstream.core.proto import jetstream_pb2_grpc
137 |
138 | prompt = "What are the top 5 languages?"
139 |
140 | channel = grpc.insecure_channel("localhost:8888")
141 | stub = jetstream_pb2_grpc.OrchestratorStub(channel)
142 |
143 | request = jetstream_pb2.DecodeRequest(
144 | text_content=jetstream_pb2.DecodeRequest.TextContent(
145 | text=prompt
146 | ),
147 | priority=0,
148 | max_tokens=2000,
149 | )
150 |
151 | response = stub.Decode(request)
152 | output = []
153 | for resp in response:
154 | output.extend(resp.stream_content.samples[0].text)
155 |
156 | text_output = "".join(output)
157 | print(f"Prompt: {prompt}")
158 | print(f"Response: {text_output}")
159 | ```
160 |
161 |
162 | # Run the server with ray
163 | Below are the steps to run the server with Ray:
164 | 1. SSH to a multi-host Cloud TPU VM (e.g. a v5e-16 TPU VM)
165 | 2. Follow steps 2 to 5 in the Outline
166 | 3. Set up the Ray cluster
167 | 4. Run the server with Ray
168 |
169 | ## Setup Ray Cluster
170 | Log in to the host 0 VM and start the Ray head with the command below:
171 |
172 | ```bash
173 |
174 | ray start --head
175 |
176 | ```
177 |
178 | Log in to the other host VMs and join the Ray cluster with the command below:
179 |
180 | ```bash
181 |
182 | ray start --address='$ip:$port'
183 |
184 | ```
185 |
186 | Note: Get the IP address and port from the Ray head's startup output.
187 |
188 | ## Run server with ray
189 |
190 | Here is an example of running the server with Ray for the Llama2 7B model:
191 |
192 | ```bash
193 | export DISABLE_XLA2_PJRT_TEST="true"
194 | python run_server_with_ray.py --tpu_chips=16 --num_hosts=4 --worker_chips=4 --model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
195 | ```
196 |
197 | # Run benchmark
198 | Start the server and then go to the deps/JetStream folder (downloaded during `install_everything.sh`)
199 |
200 | ```bash
201 | cd deps/JetStream
202 | wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
203 | export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
204 | python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs --warmup-mode=sampled --model=$model_name
205 | ```
206 | Please look at `deps/JetStream/benchmarks/README.md` for more information.
207 |
208 |
209 |
210 | ## Run server with Ray Serve
211 |
212 | ### Prerequisites
213 |
214 | If running on GKE:
215 |
216 | 1. Follow instructions on [this link](https://github.com/GoogleCloudPlatform/ai-on-gke/tree/main/ray-on-gke/guides/tpu) to setup a GKE cluster and the TPU webhook.
217 | 2. Follow instructions
218 | [here](https://cloud.google.com/kubernetes-engine/docs/how-to/persistent-volumes/cloud-storage-fuse-csi-driver)
219 | to enable GCSFuse for your cluster. This will be needed to store the
220 | converted weights.
221 | 3. Deploy one of the sample Kuberay cluster configurations:
222 | ```bash
223 | kubectl apply -f kuberay/manifests/ray-cluster.tpu-v4-singlehost.yaml
224 | ```
225 | or
226 | ```bash
227 | kubectl apply -f kuberay/manifests/ray-cluster.tpu-v4-multihost.yaml
228 | ```
229 |
230 |
231 | ### Start a Ray Serve deployment
232 |
233 | Single-host (Llama2 7B):
234 |
235 | ```bash
236 | export RAY_ADDRESS=http://localhost:8265
237 |
238 | kubectl port-forward svc/example-cluster-kuberay-head-svc 8265:8265 &
239 |
240 | ray job submit --runtime-env-json='{"working_dir": "."}' -- python run_ray_serve_interleave.py --tpu_chips=4 --num_hosts=1 --size=7b --model_name=llama-2 --batch_size=32 --max_cache_length=2048 --tokenizer_path=/llama/tokenizer.model --checkpoint_path=/llama/ckpt --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"
241 | ```
242 |
243 | Multi-host (Llama2 70B):
244 |
245 | ```bash
246 | export RAY_ADDRESS=http://localhost:8265
247 |
248 | kubectl port-forward svc/example-cluster-kuberay-head-svc 8265:8265 &
249 |
250 | ray job submit --runtime-env-json='{"working_dir": "."}' -- python run_ray_serve_interleave.py --tpu_chips=8 --num_hosts=2 --size=70b --model_name=llama-2 --batch_size=8 --max_cache_length=2048 --tokenizer_path=/llama/tokenizer.model --checkpoint_path=/llama/ckpt --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"
251 | ```
252 |
253 | ### Sending an inference request
254 |
255 | Port-forward to port 8888 for gRPC:
256 | ```
257 | kubectl port-forward svc/example-cluster-kuberay-head-svc 8888:8888 &
258 | ```
259 |
260 | Sample python script:
261 |
262 | ```python
263 | import requests
264 | import os
265 | import grpc
266 |
267 | from jetstream.core.proto import jetstream_pb2
268 | from jetstream.core.proto import jetstream_pb2_grpc
269 |
270 | prompt = "What are the top 5 languages?"
271 |
272 | channel = grpc.insecure_channel("localhost:8888")
273 | stub = jetstream_pb2_grpc.OrchestratorStub(channel)
274 |
275 | request = jetstream_pb2.DecodeRequest(
276 | text_content=jetstream_pb2.DecodeRequest.TextContent(
277 | text=prompt
278 | ),
279 | priority=0,
280 | max_tokens=2000,
281 | )
282 |
283 | response = stub.Decode(request)
284 | output = []
285 | for resp in response:
286 | output.extend(resp.stream_content.samples[0].text)
287 |
288 | text_output = "".join(output)
289 | print(f"Prompt: {prompt}")
290 | print(f"Response: {text_output}")
291 | ```
292 |
293 |
294 |
295 | # Typical Errors
296 |
297 | ## Unexpected keyword argument 'device'
298 |
299 | Fix:
300 | * Uninstall jax and jaxlib dependencies
301 | * Reinstall using `source install_everything.sh`
302 |
303 | ## Out of memory
304 |
305 | Fix:
306 | * Use smaller batch size
307 | * Use quantization
308 |
309 | # Links
310 |
311 | ## JetStream
312 | * https://github.com/google/JetStream
313 |
314 | ## MaxText
315 | * https://github.com/google/maxtext
316 |
317 |
318 |
--------------------------------------------------------------------------------
/benchmarks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/benchmarks/analyze_sharegpt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import json
16 |
17 | CUTOFF_INPUT = 1024
18 | CUTOFF_OUTPUT = 1024
19 |
20 |
21 | # pylint: disable-next=all
22 | def do_simulation(
23 | sharegpt_path, prefill_bucket_size_to_ms, system_time_per_decode_token_ms
24 | ):
25 | def next_power_of_2(x):
26 | return 1 if x == 0 else 2 ** (x - 1).bit_length()
27 |
28 | def tokens_in_input_str(s):
29 | return_val = int(1.3 * len(s.split()))
30 | # print(f"{s=} -> {return_val=}")
31 | return return_val
32 |
33 | convo_numbers = []
34 | # Please update with your own data file path
35 |
36 | with open(sharegpt_path, "r", encoding="utf-8") as f:
37 | loaded_share_gpt = json.load(f)
38 | for example in loaded_share_gpt:
39 | if len(example["conversations"]) < 2:
40 | continue
41 | input_tokens = tokens_in_input_str(example["conversations"][0]["value"])
42 | output_tokens = tokens_in_input_str(example["conversations"][1]["value"])
43 | convo_numbers.append((input_tokens, output_tokens))
44 |
45 | num_convos = len(convo_numbers)
46 | kept_convos = [
47 | c for c in convo_numbers if c[0] <= CUTOFF_INPUT and c[1] <= CUTOFF_OUTPUT
48 | ]
49 |
50 | mean_input = sum(c[0] for c in kept_convos) / len(kept_convos)
51 | mean_output = sum(c[1] for c in kept_convos) / len(kept_convos)
52 |
53 | print(
54 | f"""Total {num_convos=} but only kept {kept_convos=}.
55 | Out of kept, {mean_input=}, {mean_output=}"""
56 | )
57 |
58 | total_prefill_system_ms = 0
59 | total_generate_system_ms = 0
60 |
61 | for convo in kept_convos:
62 | input_tok, output_tok = convo
63 | bucket = max(128, next_power_of_2(input_tok))
64 | generate_system_ms = output_tok * system_time_per_decode_token_ms
65 | prefill_system_ms = prefill_bucket_size_to_ms[bucket]
66 |
67 | print(
68 | f"{convo=} {bucket=}, {prefill_system_ms=:.2f}, {generate_system_ms=:.2f}"
69 | )
70 |
71 | total_prefill_system_ms += prefill_system_ms
72 | total_generate_system_ms += generate_system_ms
73 |
74 | total_time_ms = total_prefill_system_ms + total_generate_system_ms
75 | input_tokens = sum(c[0] for c in kept_convos)
76 |
77 | output_tokens = sum(c[1] for c in kept_convos)
78 | print(
79 | f"""Output tokens {output_tokens} in {total_time_ms/1000:.2f} seconds,
80 | for {output_tokens/(total_time_ms/1000):.2f} out tok/s"""
81 | )
82 |
83 | total_prefill_sec = total_prefill_system_ms / 1000
84 | total_generate_sec = total_generate_system_ms / 1000
85 |
86 | print(
87 | f"""Total time {total_time_ms/1000:.2f} seconds,
88 | split {total_prefill_sec=:.2f} seconds and {total_generate_sec=:.2f} seconds"""
89 | )
90 |
91 | idealized_prefill_sec = (
92 | 1.1 * input_tokens / 1024 * prefill_bucket_size_to_ms[1024] / 1000
93 | )
94 |
95 | prefill_savings_sec = total_prefill_sec - idealized_prefill_sec
96 |
97 | idealized_generate_sec = (
98 | total_generate_sec / 2
99 | ) # (Roughly save 75% on KV cache high cost on the rest)
100 | generate_savings_sec = total_generate_sec - idealized_generate_sec
101 |
102 | print(
103 | f"""we think prefill will take {total_prefill_sec=:.2f},
104 | we could get it to {idealized_prefill_sec=:.2f} so we'd
105 | save {prefill_savings_sec=:.2f} seconds """
106 | )
107 | print(
108 | f"""with sparsity we could go from {total_generate_sec=:.2f},
109 | we could get it to {idealized_generate_sec=:.2f} so we'd save
110 | {generate_savings_sec=:.2f} seconds """
111 | )
112 |
113 | idealized_overall_time = idealized_generate_sec + idealized_prefill_sec
114 |
115 | print(
116 | f"""Idealized out tokens {output_tokens} in {idealized_overall_time:.2f} seconds,
117 | for {output_tokens/idealized_overall_time:.2f} out tok/s"""
118 | )
119 | print("prfill", prefill_bucket_size_to_ms)
120 | print("decode step", system_time_per_decode_token_ms)
121 |
--------------------------------------------------------------------------------
/benchmarks/basic_ops.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from typing import Tuple, Callable, List, Optional
3 | import time
4 | import dataclasses
5 |
6 | import numpy as np
7 |
8 | import jax
9 | import jax.numpy as jnp
10 | from jax.experimental import mesh_utils, shard_map
11 | from jax.sharding import PositionalSharding
12 |
13 |
14 | from jax.sharding import Mesh
15 | from jax.sharding import PartitionSpec
16 | from jax.sharding import NamedSharding
17 |
18 | devices = jax.devices()
19 | P = PartitionSpec
20 |
21 | devices = mesh_utils.create_device_mesh((len(devices),))
22 | mesh = Mesh(devices, axis_names=("x",))
23 | # y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
24 |
25 | L = 1 << 15
26 |
27 |
28 | @dataclasses.dataclass
29 | class BenchmarkCase:
30 | """BenchmarkCase."""
31 |
32 | name: str
33 | function: Callable
34 | args_shape: List[Tuple]
35 | args_sharding: List[PartitionSpec]
36 | profiler_output: Optional[str] = None
37 |
38 |
39 | start_key = jax.random.key(0)
40 |
41 |
42 | def _new_arg(shape, dtype):
43 | global start_key # pylint: disable=all
44 | start_key, _ = jax.random.split(start_key)
45 | with jax.default_device(jax.devices("cpu")[0]):
46 | if dtype == jnp.int8.dtype:
47 | return jax.random.randint(start_key, shape, 0, 100, dtype=dtype)
48 | else:
49 | return jax.random.normal(start_key, shape, dtype=dtype) + 1
50 |
51 |
52 | def _new_args(case, dtype):
53 | args = []
54 | for shape, sharding in zip(case.args_shape, case.args_sharding):
55 | arg = _new_arg(shape, dtype)
56 | if sharding is not None:
57 | arg = jax.device_put(arg, NamedSharding(mesh, sharding))
58 | args.append(arg)
59 | return args
60 |
61 |
62 | def _run_case(case, warmup=2, runtimes=5, dtype=jnp.bfloat16.dtype):
63 | for _ in range(warmup):
64 | args = _new_args(case, dtype)
65 | case.function(*args)
66 |
67 | stamps = []
68 | for i in range(runtimes):
69 | args = _new_args(case, dtype)
70 | jax.block_until_ready(args)
71 | if case.profiler_output is not None and i == (runtimes - 1):
72 | jax.profiler.start_trace(case.profiler_output)
73 | start = time.perf_counter()
74 | jax.block_until_ready(case.function(*args))
75 | end = time.perf_counter()
76 | if case.profiler_output is not None and i == (runtimes - 1):
77 | jax.profiler.stop_trace()
78 | stamps.append(end - start)
79 | return sum(stamps) / runtimes
80 |
81 |
82 | def _llama_ffn(x, w1, w2, w3):
83 | w1_res = jax.nn.silu((x @ w1).astype(jnp.bfloat16.dtype))
84 | w3_res = x @ w3
85 | res = (w1_res * w3_res) @ w2
86 | return res
87 |
88 |
89 | @jax.jit
90 | @functools.partial(
91 | shard_map.shard_map,
92 | mesh=mesh,
93 | in_specs=(P(), P(None, "x"), P("x"), P(None, "x")),
94 | out_specs=(P()),
95 | )
96 | def _llama_ffn_shmap(x, w1, w2, w3):
97 | for _ in range(3):
98 | x = _llama_ffn(x, w1, w2, w3)
99 | x = jax.lax.psum(x, "x")
100 | return x
101 |
102 |
103 | @jax.jit
104 | def _llama_ffn_spmd(x, w1, w2, w3):
105 | for _ in range(3):
106 | x = _llama_ffn(x, w1, w2, w3)
107 | x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
108 | return x
109 |
110 |
111 | dim = 4096
112 | multiple_of = 256
113 | # hidden_dim = 4 * dim
114 | # hidden_dim = int(2 * hidden_dim / 3)
115 | # hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
116 | hidden_dim = 11008
117 | BATCH = 1024
118 |
119 |
120 | @jax.jit
121 | @functools.partial(
122 | shard_map.shard_map,
123 | mesh=mesh,
124 | in_specs=(P("x"),),
125 | out_specs=(P()),
126 | check_rep=False,
127 | )
128 | def _all_gather(x):
129 | return jax.lax.all_gather(x, "x")
130 |
131 |
132 | @jax.jit
133 | @functools.partial(
134 | shard_map.shard_map, mesh=mesh, in_specs=(P("x"),), out_specs=(P())
135 | )
136 | def _all_reduce(x):
137 | return jax.lax.psum(x, "x")
138 |
139 |
140 | allcases = [
141 | BenchmarkCase(
142 | name="Matmul replicated",
143 | function=jax.jit(jnp.matmul),
144 | args_shape=((L, L), (L, L)),
145 | args_sharding=(P(), P()), # replicated
146 | ),
147 | BenchmarkCase(
148 | name="Matmul sharded colrow",
149 | function=jax.jit(jnp.matmul),
150 | args_shape=((L, L), (L, L)),
151 | args_sharding=(P(None, "x"), P("x")), # replicated
152 | ),
153 | BenchmarkCase(
154 | name="matmul sharded rowcol",
155 | function=jax.jit(jnp.matmul),
156 | args_shape=((L, L), (L, L)),
157 | args_sharding=(P("x"), P("x", None)), # replicated
158 | ),
159 | BenchmarkCase(
160 | name="all_gather",
161 | function=_all_gather,
162 | args_shape=((L, L),),
163 | args_sharding=(P("x"),), # replicated
164 | ),
165 | BenchmarkCase(
166 | name="all_reduce",
167 | function=_all_reduce,
168 | args_shape=((L, L),),
169 | args_sharding=(P("x"),), # replicated
170 | ),
171 | BenchmarkCase(
172 | name="Llama 3xffn shardmap",
173 | function=_llama_ffn_shmap,
174 | args_shape=(
175 | (BATCH, dim),
176 | (dim, hidden_dim),
177 | (hidden_dim, dim),
178 | (dim, hidden_dim),
179 | ),
180 | args_sharding=(P(), P(None, "x"), P("x"), P(None, "x")),
181 | ),
182 | BenchmarkCase(
183 | name="Llama 3xffn gspmd",
184 | function=_llama_ffn_spmd,
185 | args_shape=(
186 | (BATCH, dim),
187 | (dim, hidden_dim),
188 | (hidden_dim, dim),
189 | (dim, hidden_dim),
190 | ),
191 | args_sharding=(P(), P(None, "x"), P("x"), P(None, "x")),
192 | ),
193 | ]
194 |
195 |
196 | def _run_call_cases(cases):
197 | for dtype in (jnp.bfloat16.dtype, jnp.int8.dtype):
198 | for case in cases:
199 | avg = _run_case(case, dtype=dtype)
200 | dtype_size = 2 if dtype == jnp.bfloat16.dtype else 1
201 | input_sizes = tuple(
202 | [
203 | f"{np.prod(size) * dtype_size / (1<<20) :.6} MiB"
204 | for size in case.args_shape
205 | ]
206 | )
207 | print(
208 | f"{dtype} \t {case.name}: \t{avg * 1000 :.6} ms \t sizes: {input_sizes}"
209 | )
210 |
211 |
212 | def main():
213 | print("Number of devices: ", len(devices))
214 | _run_call_cases(allcases)
215 |
216 |
217 | if __name__ == "__main__":
218 | main()
219 |
--------------------------------------------------------------------------------
/benchmarks/mixtral_offline.sh:
--------------------------------------------------------------------------------
1 | CACHE_LENGTH=1024
2 | INPUT_SIZE=512
3 | OUTPUT_SIZE=1024
4 | BATCH_SIZE=512
5 | CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/
6 |
7 | pushd ..
8 | python -m benchmarks.run_offline \
9 | --model_name=mixtral \
10 | --batch_size=$BATCH_SIZE \
11 | --max_cache_length=$CACHE_LENGTH \
12 | --max_decode_length=$OUTPUT_SIZE \
13 | --context_length=$INPUT_SIZE \
14 | --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \
15 | --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \
16 | --quantize_weights=1 \
17 | --quantize_type=int8_per_channel \
18 | --quantize_kv_cache=1 \
19 | --profiling_output=/mnt/disks/hanq/mixtral-profiles
20 | popd
--------------------------------------------------------------------------------
/benchmarks/summary.md:
--------------------------------------------------------------------------------
1 | # Benchmark results of various models
2 |
3 |
4 | ## Llama 3 - 8B
5 |
6 | Date | Device | dtype | batch size | cache length |max input length |max output length| throughput (token/s)
7 | ----| ------- | ------ |---------- | -------------|-----------------|------------------|----------------------
8 | 2024-04-24 | TPU v5e-8 | bfloat16 | 128 | 2048 | 1024 | 1024 | 8249
9 | 2024-04-24 | TPU v5e-8 | int8 | 256 | 2048 | 1024 | 1024 | 10873
10 | 2024-07-29 | TPU v5e-8 | int8 | 256 | 2048 | 1024 | 1024 | 8471.54
11 |
12 | **NOTE (2024-07-29):** It looks like we have had a regression in the past 3 months. We are working on fixing it.
13 |
14 |
15 | ## Gemma - 7B
16 |
17 | Date | Device | dtype | batch size | cache length |max input length |max output length| throughput (token/s)
18 | ----| ------- | ------ |---------- | -------------|-----------------|------------------|----------------------
19 | 2024-05-10 | TPU v5e-8 | bfloat16 | 96 | 2048 | 1024 | 1024 | 3236
20 | 2024-05-10 | TPU v5e-8 | int8 | 128 | 2048 | 1024 | 1024 | 4695
21 |
22 | ## Gemma - 2B
23 |
24 | Date | Device | dtype | batch size | cache length |max input length |max output length| throughput (token/s)
25 | ----| ------- | ------ |---------- | -------------|-----------------|------------------|----------------------
26 | 2024-05-14 | TPU v5e-8 | bfloat16 | 512 | 2048 | 1024 | 1024 | 8700
27 | 2024-05-14 | TPU v5e-8 | int8 | 1024 | 2048 | 1024 | 1024 | 8746
28 | 2024-06-13 | TPU v5e-1 | bfloat16 | 1024 | 2048 | 1024 | 1024 | 4249
29 |
30 |
31 | **NOTE:** Gemma 2B uses the `--shard_on_batch` flag, so it's data parallel instead
32 | of model parallel.
33 |
34 |
35 | ## Llama 2 - 7B
36 |
37 | Date | Device | dtype | batch size | cache length |max input length |max output length| throughput (token/s)
38 | ----| ------- | ------ |---------- | -------------|-----------------|------------------|----------------------
39 | 2024-03-28 | TPU v5e-8 | bfloat16 | 96 | 2048 | 1024 | 1024 | 3663
40 | 2024-03-28 | TPU v5e-8 | int8 | 96 | 2048 | 1024 | 1024 | 4783
41 |
42 | ## Llama 2 - 13B
43 |
44 | Date | Device | dtype | batch size | cache length |max input length |max output length| throughput (token/s)
45 | ----| ------- | ------ |---------- | -------------|-----------------|------------------|----------------------
46 | 2024-03-28 | TPU v5e-8 | bfloat16 | 48 | 2048 | 1024 | 1024 | 2056
47 | 2024-03-28 | TPU v5e-8 | int8 | 96 | 2048 | 1024 | 1024 | 3458
48 | 2024-03-28 | TPU v5e-8 | bfloat16 | 80 | 1280 | 1024 | 1024 | 2911
49 | 2024-03-28 | TPU v5e-8 | int8 | 96 | 1280 | 1024 | 1024 | 3938
50 |
51 | **NOTE:** When the cache length is less than the sum of max input length and max output length,
52 | we employ *rolling window attention*.
53 |
54 |
55 | # Instructions to reproduce:
56 |
57 | Please refer to [README.md](README.md) for instructions on how to get the model weights.
58 |
59 | **NOTE** Different weights can produce different benchmark results (due to generating
60 | different sentence lengths). For Llama, we used the `-chat` versions of the weights.
61 | For Gemma, we used the `-it` (instruction finetuned) versions of the weights.
62 |
63 | ## Run the server
64 | NOTE: the `--platform=tpu=8` flag needs to specify the number of TPU devices (which is 4 for v4-8 and 8 for v5e-8).
65 |
66 | ```bash
67 | python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8 --model=$model_name
68 | ```
69 | Now you can send gRPC requests to it.
70 |
71 | # Run benchmark
72 | go to the deps/JetStream folder (downloaded during `install_everything.sh`)
73 |
74 | ```bash
75 | cd deps/JetStream
76 | wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
77 | export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
78 | python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs --warmup-first=True
79 | ```
80 | Please look at `deps/JetStream/benchmarks/README.md` for more information.
81 |
--------------------------------------------------------------------------------
/default_shardings/gemma.yaml:
--------------------------------------------------------------------------------
1 |
2 | # Sharding config for gemma
3 | # "replicated" to signify "replicated".
4 | # Integer signify axis to shard: 0 <= shard axis < rank
5 |
6 | freqs_cis : -1 # torch.complex64 (16384, 128)
7 | layers.*.self_attn.o_proj.weight: 1
8 | layers.*.self_attn.wq.weight : 0 # -1, 1] # torch.float32 (2048, 2048)
9 | layers.*.self_attn.wk.weight : 0 # -1, 1] # torch.float32 (256, 2048)
10 | layers.*.self_attn.wv.weight : 0 # -1, 1] # torch.float32 (256, 2048)
11 | layers.*.mlp.gate_proj.weight : 0 # -1, 1] # torch.float32 (16384, 2048)
12 | layers.*.mlp.gate_proj.bias : 0 # -1] # torch.float32 (16384,)
13 | layers.*.mlp.up_proj.weight : 0 # -1, 1] # torch.float32 (16384, 2048)
14 | layers.*.mlp.up_proj.bias : 0 # -1] # torch.float32 (16384,)
15 | layers.*.mlp.down_proj.weight : 1 # 1, -1] # torch.float32 (2048, 16384)
16 | layers.*.mlp.down_proj.bias : -1 # torch.float32 (2048,)
17 | layers.*.input_layernorm.weight : -1 # torch.float32 (2048,)
18 | layers.*.post_attention_layernorm.weight : -1 # torch.float32 (2048,)
19 | norm.weight : -1 # torch.float32 (2048,)
20 | embedder.weight : 1 # # 1, -1] # torch.float32 (256000, 2048)
21 | embedder.weight_scaler : 0
22 | layers.*.self_attn.o_proj.weight_scaler: 0
23 | layers.*.self_attn.wq.weight_scaler : 0
24 | layers.*.self_attn.wk.weight_scaler : 0
25 | layers.*.self_attn.wv.weight_scaler : 0
26 | layers.*.mlp.gate_proj.weight_scaler : 0
27 | layers.*.mlp.up_proj.weight_scaler : 0
28 | layers.*.mlp.down_proj.weight_scaler : 0
29 |
--------------------------------------------------------------------------------
/default_shardings/llama-blockwise-quant.yaml:
--------------------------------------------------------------------------------
1 |
2 | # Sharding config for llama-2 (With blockwise quantized linear layers)
3 | # Sharding should either be an int between 0 and rank - 1
4 | # signifying the axis to shard or -1 / null signifying replicated
5 |
6 |
7 | freqs_cis : -1 # torch.complex64 (2048, 64)
8 | tok_embeddings.weight : 1 # torch.int8 (32000, 4096)
9 | tok_embeddings.weight_scaler : 0 # torch.bfloat16 (4096,)
10 | layers.*.attention.wo.weight : 2 # torch.int8 (32, 128, 4096)
11 | layers.*.attention.wo.weight_scaler : 1 # torch.bfloat16 (32, 4096)
12 | layers.*.attention.wq.weight : 0 # torch.int8 (32, 128, 4096)
13 | layers.*.attention.wq.weight_scaler : 0 # torch.bfloat16 (32, 4096)
14 | layers.*.attention.wk.weight : 0 # torch.int8 (32, 128, 4096)
15 | layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (32, 4096)
16 | layers.*.attention.wv.weight : 0 # torch.int8 (32, 128, 4096)
17 | layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (32, 4096)
18 | layers.*.feed_forward.w1.weight : 0 # torch.int8 (32, 128, 11008)
19 | layers.*.feed_forward.w1.weight_scaler : 0 # torch.bfloat16 (32, 11008)
20 | layers.*.feed_forward.w2.weight : 2 # torch.int8 (86, 128, 4096)
21 | layers.*.feed_forward.w2.weight_scaler : 1 # torch.bfloat16 (86, 4096)
22 | layers.*.feed_forward.w3.weight : 0 # torch.int8 (32, 128, 11008)
23 | layers.*.feed_forward.w3.weight_scaler : 0 # torch.bfloat16 (32, 11008)
24 | layers.*.attention_norm.weight : -1 # torch.float32 (4096,)
25 | layers.*.ffn_norm.weight : -1 # torch.float32 (4096,)
26 | norm.weight : -1 # torch.float32 (4096,)
27 | output.weight : 0 # torch.int8 (32, 128, 32000)
28 | output.weight_scaler : 0 # torch.float32 (32, 32000)
29 |
--------------------------------------------------------------------------------
/default_shardings/llama.yaml:
--------------------------------------------------------------------------------
1 |
2 | # Sharding config for llama-2
3 | # Sharding should either be an int between 0 and rank - 1
4 | # signifying the axis to shard or -1 / null signifying replicated
5 |
6 |
7 | freqs_cis : -1 # torch.complex64 (2048, 64)
8 | tok_embeddings.weight : 1 # torch.float32 (vocab_size, 4096)
9 | tok_embeddings.weight_scaler : 0 # torch.bfloat16 (4096,)
10 | layers.*.attention.wo.weight : 1 # torch.int8 (4096, 4096)
11 | layers.*.attention.wo.weight_scaler : 0 # torch.bfloat16 (4096,)
12 | layers.*.attention.wq.weight : 0 # torch.int8 (4096, 4096)
13 | layers.*.attention.wq.weight_scaler : 0 # torch.bfloat16 (4096,)
14 | layers.*.attention.wk.weight : 0 # torch.int8 (4096, 4096)
15 | layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (4096,)
16 | layers.*.attention.wv.weight : 0 # torch.int8 (4096, 4096)
17 | layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (4096,)
18 | layers.*.feed_forward.w1.weight : 0 # torch.float32 (11008, 4096)
19 | layers.*.feed_forward.w1.weight_scaler : 0 # torch.bfloat16 (4096,)
20 | layers.*.feed_forward.w2.weight : 1 # torch.float32 (4096, 11008)
21 | layers.*.feed_forward.w2.weight_scaler : 0 # torch.bfloat16 (11008,)
22 | layers.*.feed_forward.w3.weight : 0 # torch.float32 (11008, 4096)
23 | layers.*.feed_forward.w3.weight_scaler : 0 # torch.bfloat16 (4096,)
24 | layers.*.attention_norm.weight : -1 # torch.float32 (4096,)
25 | layers.*.ffn_norm.weight : -1 # torch.float32 (4096,)
26 | norm.weight : -1 # torch.float32 (4096,)
27 | output.weight : 0 # torch.float32 (vocab_size, 4096)
28 | output.weight_scaler : 0 # torch.float32 (4096,)
29 |
--------------------------------------------------------------------------------
/default_shardings/mixtral.yaml:
--------------------------------------------------------------------------------
1 |
2 | # Sharding config for mixtral
3 | # Sharding should either be an int between 0 and rank - 1
4 | # signifying the axis to shard or -1 / null signifying replicated
5 |
6 |
7 | freqs_cis : -1 # torch.complex64 (2048, 64)
8 | tok_embeddings.weight : 1 # torch.float32 (vocab_size, 4096)
9 | tok_embeddings.weight_scaler : 0 # torch.bfloat16 (4096,)
10 | layers.*.attention.wo.weight : 1 # torch.int8 (4096, 4096)
11 | layers.*.attention.wo.weight_scaler : -1 # torch.bfloat16 (4096,)
12 | layers.*.attention.wq.weight : 0 # torch.int8 (4096, 4096)
13 | layers.*.attention.wq.weight_scaler : 0 # torch.bfloat16 (4096,)
14 | layers.*.attention.wk.weight : 0 # torch.int8 (4096, 4096)
15 | layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (4096,)
16 | layers.*.attention.wv.weight : 0 # torch.int8 (4096, 4096)
17 | layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (4096,)
18 | layers.*.attention.wqkv.weight : 0 # torch.int8 (4096, 4096)
19 | layers.*.attention.wqkv.weight_scaler : 0 # torch.bfloat16 (4096,)
20 | layers.*.block_sparse_moe.gate.weight: -1
21 | layers.*.block_sparse_moe.gate.weight_scaler: -1
22 | layers.*.block_sparse_moe.cond_ffn.w1: 1
23 | layers.*.block_sparse_moe.cond_ffn.w1_scaler: 1
24 | layers.*.block_sparse_moe.cond_ffn.w2: 2
25 | layers.*.block_sparse_moe.cond_ffn.w2_scaler: -1
26 | layers.*.block_sparse_moe.cond_ffn.w3: 1
27 | layers.*.block_sparse_moe.cond_ffn.w3_scaler: 1
28 | layers.*.ffn_norm.weight : -1 # torch.float32 (4096,)
29 | layers.*.attention_norm.weight : -1 # torch.float32 (4096,)
30 | norm.weight : -1 # torch.float32 (4096,)
31 | output.weight : 0 # torch.float32 (vocab_size, 4096)
32 | output.weight_scaler : 0 # torch.float32 (4096,)
33 |
--------------------------------------------------------------------------------
/docker/jetstream-pytorch-server/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Ubuntu:22.04
16 | # Use Ubuntu 22.04 from Docker Hub.
17 | # https://hub.docker.com/_/ubuntu/tags?page=1&name=22.04
18 | FROM ubuntu:22.04
19 |
20 | ENV DEBIAN_FRONTEND=noninteractive
21 | ENV PYTORCH_JETSTREAM_VERSION=main
22 |
23 | RUN apt -y update && apt install -y --no-install-recommends \
24 | ca-certificates \
25 | git \
26 | python3.10 \
27 | python3-pip
28 |
29 | RUN python3 -m pip install --upgrade pip
30 |
31 | RUN update-alternatives --install \
32 | /usr/bin/python3 python3 /usr/bin/python3.10 1
33 |
34 | RUN git clone https://github.com/AI-Hypercomputer/jetstream-pytorch.git && \
35 | cd /jetstream-pytorch && \
36 | git checkout ${PYTORCH_JETSTREAM_VERSION} && \
37 | bash install_everything.sh
38 |
39 | RUN pip install -U jax[tpu]==0.4.34 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
40 |
41 | COPY jetstream_pytorch_server_entrypoint.sh /usr/bin/
42 |
43 | RUN chmod +x /usr/bin/jetstream_pytorch_server_entrypoint.sh
44 |
45 | ENTRYPOINT ["/usr/bin/jetstream_pytorch_server_entrypoint.sh"]
--------------------------------------------------------------------------------
/docker/jetstream-pytorch-server/README.md:
--------------------------------------------------------------------------------
1 | ## Build and upload JetStream PyTorch Server image
2 |
3 | These instructions are to build the JetStream PyTorch Server image, which calls an entrypoint script that invokes the [JetStream](https://github.com/AI-Hypercomputer/JetStream) inference server with the JetStream-PyTorch framework.
4 |
5 | ```
6 | docker build -t jetstream-pytorch-server .
7 | docker tag jetstream-pytorch-server us-docker.pkg.dev/${PROJECT_ID}/jetstream/jetstream-pytorch-server:latest
8 | docker push us-docker.pkg.dev/${PROJECT_ID}/jetstream/jetstream-pytorch-server:latest
9 | ```
10 |
11 | If you would like to change the version of JetStream-PyTorch the image is built off of, change the `PYTORCH_JETSTREAM_VERSION` environment variable:
12 | ```
13 | ENV PYTORCH_JETSTREAM_VERSION=
14 | ```
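
To try the image locally, a `docker run` invocation might look like the sketch below. This is only a sketch: the TPU access flags and the token mount path are assumptions based on the entrypoint script, which reads `/huggingface/HUGGINGFACE_TOKEN` and forwards its arguments to `jpt serve`:

```bash
docker run --privileged --net=host \
  -v /path/to/hf-token-dir:/huggingface \
  us-docker.pkg.dev/${PROJECT_ID}/jetstream/jetstream-pytorch-server:latest \
  --model_id meta-llama/Meta-Llama-3-8B-Instruct
```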
--------------------------------------------------------------------------------
/docker/jetstream-pytorch-server/jetstream_pytorch_server_entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Copyright 2024 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | export HUGGINGFACE_TOKEN_DIR="/huggingface"
18 |
19 | cd /jetstream-pytorch
20 | huggingface-cli login --token $(cat ${HUGGINGFACE_TOKEN_DIR}/HUGGINGFACE_TOKEN)
21 | jpt serve $@
--------------------------------------------------------------------------------
/docs/add_hf_checkpoint_conversion.md:
--------------------------------------------------------------------------------
1 | # Guide on adding HuggingFace checkpoint conversion support
2 |
3 | ## Prerequisites:
4 | * The model implementation has been added in JetStream-pt.
5 | * Checkpoint conversion from at least one format is already supported (or no conversion is needed for the checkpoint).
6 |
7 | Please check this [guide](https://github.com/google/jetstream-pytorch/blob/main/docs/add_a_new_model.md) for adding a new model.
8 |
9 | ## Use case:
10 | The user has a checkpoint for the same model architecture in another format (e.g. HF format for the LLaMA model) and wants JetStream-pt to support this checkpoint format.
11 |
12 | ## Guide
13 |
14 | Converting a public checkpoint to JetStream-pt format is mostly about finding the weight key mapping between the public checkpoint and the JetStream model implementation. Besides the name mapping, the layout of the weights might differ among checkpoint formats (e.g. weights interleaved differently due to differences in the Rotary Embedding implementation). These differences are model and checkpoint format specific.
15 |
16 | **Note** The model code and checkpoint format can differ from model to model; the following is a general guide, and specific models may require additional effort to support checkpoint conversion.
17 |
18 | The checkpoint conversion logic lives in the checkpoint conversion script (`convert_checkpoints.py`, run in Step 6 below).
19 |
20 | ### Step 1 Find the HuggingFace checkpoint you want to convert
21 | In this example, let’s use meta-llama/Llama-2-7b-hf (7B).
22 |
23 | You can download the checkpoint to a local folder using:
24 | `huggingface-cli download meta-llama/Llama-2-7b-hf --local-dir Llama-2-7b-hf`
25 |
26 |
27 | **Note** You may need to go to the HuggingFace website and sign an agreement to get permission to download the model.
28 |
29 | ### Step 2 Inspect the weight names in the checkpoint:
30 |
31 | Usually there is a model.safetensors.index.json file in the checkpoint. [example](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/model.safetensors.index.json)
32 |
33 | Alternatively, you can load the weights locally and inspect the model key names (usually the checkpoint is in safetensors format, and it’s sharded).
34 |
35 | Example script:
36 | ```Python
37 | import glob
38 | import os
39 | import torch
40 | from safetensors import safe_open
41 |
42 | checkpoint_folder = "/mnt/disks/lsiyuan/llama_weight/Meta-Llama-3-8B-Instruct"
43 |
44 | safetensor_files = glob.glob(os.path.join(checkpoint_folder, "*.safetensors"))
45 |
46 | for st_f in safetensor_files:
47 | with safe_open(st_f, framework="pt", device="cpu") as f:
48 | for key in f.keys():
49 | weight_tensor = f.get_tensor(key)
50 | print(f"Weight name {key}, Shape: {weight_tensor.shape}, dtype: {weight_tensor.dtype}")
51 | ```
52 |
53 | This gives the following output:
54 |
55 | ```
56 | lm_head.weight torch.Size([32000, 4096]) x torch.float16
57 | model.norm.weight torch.Size([4096]) x torch.float16
58 | model.embed_tokens.weight torch.Size([32000, 4096]) x torch.float16
59 | model.layers.0.input_layernorm.weight torch.Size([4096]) x torch.float16
60 | model.layers.0.mlp.down_proj.weight torch.Size([4096, 11008]) x torch.float16
61 | model.layers.0.mlp.gate_proj.weight torch.Size([11008, 4096]) x torch.float16
62 | model.layers.0.mlp.up_proj.weight torch.Size([11008, 4096]) x torch.float16
63 | model.layers.0.post_attention_layernorm.weight torch.Size([4096]) x torch.float16
64 | model.layers.0.self_attn.k_proj.weight torch.Size([4096, 4096]) x torch.float16
65 | model.layers.0.self_attn.o_proj.weight torch.Size([4096, 4096]) x torch.float16
66 | model.layers.0.self_attn.q_proj.weight torch.Size([4096, 4096]) x torch.float16
67 | model.layers.0.self_attn.rotary_emb.inv_freq torch.Size([64]) x torch.float32
68 | model.layers.0.self_attn.v_proj.weight torch.Size([4096, 4096]) x torch.float16
69 | … # Duplicated name for model.layers.x
70 | ```
71 |
72 | If it’s hard to tell which layer the weight is for, the HF model class can be checked in the checkpoint config file [example](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L4). Then we can find the model code in the transformer repo by searching the model class name [model code](https://github.com/huggingface/transformers/blob/bdf36dcd48106a4a0278ed7f3cc26cd65ab7b066/src/transformers/models/llama/modeling_llama.py#L1084)
73 |
74 |
75 | ### Step 3 Inspect the weight names in JetStream-pt model implementation:
76 |
77 | Run the model in JetStream using benchmarks/run_offline.py. The weight names, shapes, and dtypes will be printed in the log (omitting layers other than layer 0, whose names are duplicated).
78 |
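A typical invocation mirrors the flags used in `benchmarks/mixtral_offline.sh`; the model name, checkpoint path, and tokenizer path below are placeholders you need to set:

```bash
python -m benchmarks.run_offline \
  --model_name=$model_name \
  --batch_size=64 \
  --max_cache_length=2048 \
  --checkpoint_path=$output_ckpt_dir \
  --tokenizer_path=$tokenizer_path
```
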
79 | Example output:
80 |
81 | ```
82 | Name: freqs_cis, shape: (2048, 64) x complex64
83 | Name: tok_embeddings.weight, shape: (32000, 4096) x bfloat16
84 | Name: layers.0.attention.wo.weight, shape: (4096, 4096) x bfloat16
85 | Name: layers.0.attention.wq.weight, shape: (4096, 4096) x bfloat16
86 | Name: layers.0.attention.wk.weight, shape: (4096, 4096) x bfloat16
87 | Name: layers.0.attention.wv.weight, shape: (4096, 4096) x bfloat16
88 | Name: layers.0.feed_forward.w1.weight, shape: (11008, 4096) x bfloat16
89 | Name: layers.0.feed_forward.w2.weight, shape: (4096, 11008) x bfloat16
90 | Name: layers.0.feed_forward.w3.weight, shape: (11008, 4096) x bfloat16
91 | Name: layers.0.attention_norm.weight, shape: (4096,) x bfloat16
92 | Name: layers.0.ffn_norm.weight, shape: (4096,) x bfloat16
93 | Name: norm.weight, shape: (4096,) x bfloat16
94 | Name: output.weight, shape: (32000, 4096) x bfloat16
95 | ```
96 |
97 | If it’s hard to tell what a weight is for, check the model implementation under jetstream_pt/third_party to find out the meaning of the weight.
98 |
99 | ### Step 4 By comparing the weight names, or diving into the model code, we can find out the mapping:
100 |
101 | In this example:
102 |
103 | * HF `lm_head.weight` -> JetStream-pt `output.weight`
104 | * HF `model.norm.weight` -> JetStream-pt `norm.weight`
105 | * HF `model.embed_tokens.weight` -> JetStream-pt `tok_embeddings.weight`
106 | * HF `model.layers.X.input_layernorm.weight` -> `layers.X.attention_norm.weight`
107 | * HF `model.layers.X.post_attention_layernorm.weight` -> `layers.X.ffn_norm.weight`
108 | * HF `model.layers.X.self_attn.{q/k/v/o}_proj.weight` -> `layers.X.attention.w{q/k/v/o}.weight`
109 | * HF `model.layers.X.mlp.gate_proj.weight` -> `layers.X.feed_forward.w1.weight`
110 | * HF `model.layers.X.mlp.down_proj.weight` -> `layers.X.feed_forward.w2.weight`
111 | * HF `model.layers.X.mlp.up_proj.weight` -> `layers.X.feed_forward.w3.weight`
112 | `freqs_cis` is a special case: in JetStream PyTorch, the weight is pre-computed during weight loading, so there is no need to map the HuggingFace freq weight over.
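
As a reference, the mapping above could be expressed in code roughly as follows. This is a minimal sketch, not the actual `convert_checkpoints` implementation; `hf_to_jetstream_name` is a hypothetical helper.

```python
import re

# Weights with a one-to-one name mapping.
_EXACT = {
    "lm_head.weight": "output.weight",
    "model.norm.weight": "norm.weight",
    "model.embed_tokens.weight": "tok_embeddings.weight",
}

# Per-layer weights: the HF prefix "model.layers.N." maps to "layers.N.".
_PER_LAYER = {
    "input_layernorm.weight": "attention_norm.weight",
    "post_attention_layernorm.weight": "ffn_norm.weight",
    "self_attn.q_proj.weight": "attention.wq.weight",
    "self_attn.k_proj.weight": "attention.wk.weight",
    "self_attn.v_proj.weight": "attention.wv.weight",
    "self_attn.o_proj.weight": "attention.wo.weight",
    "mlp.gate_proj.weight": "feed_forward.w1.weight",
    "mlp.down_proj.weight": "feed_forward.w2.weight",
    "mlp.up_proj.weight": "feed_forward.w3.weight",
}


def hf_to_jetstream_name(hf_name):
  """Returns the JetStream-pt name for an HF weight name, or None to drop it."""
  if hf_name in _EXACT:
    return _EXACT[hf_name]
  match = re.match(r"model\.layers\.(\d+)\.(.+)", hf_name)
  if match and match.group(2) in _PER_LAYER:
    return f"layers.{match.group(1)}.{_PER_LAYER[match.group(2)]}"
  # e.g. rotary_emb.inv_freq: freqs_cis is recomputed at load time.
  return None
```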
113 |
114 | ### Step 5 Validate the converted checkpoint:
115 |
116 | If there is a checkpoint in an already supported format, convert that checkpoint first and use it as golden data to compare with the checkpoint converted from the new format.
117 |
118 | Write a small script, or reuse the [script](https://github.com/google/jetstream-pytorch/blob/main/scripts/validate_hf_ckpt_conversion.py) to compare the 2 converted checkpoints.
119 |
120 | Fix the difference between 2 converted checkpoints if there is any. (This will be model and checkpoint format specific)
121 |
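As a starting point, the comparison itself can be as simple as the sketch below, which assumes both converted checkpoints have already been loaded into memory as dicts of name -> tensor (how to load them depends on the output format chosen during conversion):

```
import torch

# Minimal sketch: diff two converted checkpoints loaded as dicts of name -> tensor.
def compare_state_dicts(golden, candidate, atol=1e-2, rtol=1e-2):
  missing = set(golden) - set(candidate)
  extra = set(candidate) - set(golden)
  if missing or extra:
    print("missing keys:", sorted(missing))
    print("unexpected keys:", sorted(extra))
  for name in sorted(set(golden) & set(candidate)):
    g = golden[name].to(torch.float32)
    c = candidate[name].to(torch.float32)
    if g.shape != c.shape:
      print(f"{name}: shape mismatch {tuple(g.shape)} vs {tuple(c.shape)}")
    elif not torch.allclose(g, c, atol=atol, rtol=rtol):
      print(f"{name}: max abs diff {(g - c).abs().max().item()}")
```
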
122 | ### Step 6 End-to-end validation: From checkpoint conversion to serving
123 |
124 | Example:
125 |
126 | ```
127 | export input_ckpt_dir=/mnt/disks/lsiyuan/llama_weight/7B-FT-chat
128 | export output_ckpt_dir=/mnt/disks/lsiyuan/llama_weight/hf_llama_2_7b_converted_bf16_2
129 | export model_name="llama"
130 | export from_hf=True
export quantize_weights=False  # defaults from jetstream_pt/config.py; set True to quantize during conversion
export quantize_type="int8_per_channel"
131 | python -m convert_checkpoints --model_name=$model_name \
132 | --input_checkpoint_dir=$input_ckpt_dir \
133 | --output_checkpoint_dir=$output_ckpt_dir \
134 | --quantize_weights=$quantize_weights \
135 | --quantize_type=$quantize_type \
136 | --from_hf=True
137 | ```
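
Once the conversion finishes, finish the end-to-end check by serving the converted checkpoint: point `--checkpoint_path` and `--tokenizer_path` (defined in jetstream_pt/config.py) at the new output directory and tokenizer, run one of the interactive or server entry points at the repo root, and spot-check that the generated text is coherent.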
--------------------------------------------------------------------------------
/install_everything.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Uninstall existing packages that may conflict
16 | pip show jax && pip uninstall -y jax
17 | pip show jaxlib && pip uninstall -y jaxlib
18 | pip show libtpu-nightly && pip uninstall -y libtpu-nightly
19 | pip show tensorflow && pip uninstall -y tensorflow
20 | pip show ray && pip uninstall -y ray
21 | pip show flax && pip uninstall -y flax
22 | pip show keras && pip uninstall -y keras
23 | pip show tensorboard && pip uninstall -y tensorboard
24 | pip show tensorflow-text && pip uninstall -y tensorflow-text
25 | pip show torch_xla2 && pip uninstall -y torch_xla2
26 |
27 | pip install flax
28 | pip install tensorflow-text
29 | pip install tensorflow
30 | pip install huggingface_hub
31 | pip install transformers
32 |
33 | pip install ray[default]==2.33.0
34 | # torch cpu
35 | pip install torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu
36 | pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage
37 | pip install safetensors colorama coverage humanize
38 |
39 | git submodule update --init --recursive
40 | pip show google-jetstream && pip uninstall -y google-jetstream
41 | pip show torch_xla2 && pip uninstall -y torch_xla2
42 | pip install -e .
43 | pip install -U jax[tpu]==0.4.37 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
44 | pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu
45 |
--------------------------------------------------------------------------------
/install_everything_gpu.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Uninstall existing packages that may conflict
16 | pip show jax && pip uninstall -y jax
17 | pip show jaxlib && pip uninstall -y jaxlib
18 | pip show libtpu-nightly && pip uninstall -y libtpu-nightly
19 | pip show tensorflow && pip uninstall -y tensorflow
20 | pip show ray && pip uninstall -y ray
21 | pip show flax && pip uninstall -y flax
22 | pip show keras && pip uninstall -y keras
23 | pip show tensorboard && pip uninstall -y tensorboard
24 | pip show tensorflow-text && pip uninstall -y tensorflow-text
25 | pip show torch_xla2 && pip uninstall -y torch_xla2
26 |
27 | pip install flax==0.8.4
28 | pip install tensorflow-text
29 | pip install tensorflow
30 | pip install transformers
31 |
32 | pip install ray[default]==2.22.0
33 | # misc deps; the torch cpu wheel is installed at the end
34 | pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage
35 | pip install safetensors colorama coverage humanize
36 |
37 | git submodule update --init --recursive
38 | pip show google-jetstream && pip uninstall -y google-jetstream
39 | pip show torch_xla2 && pip uninstall -y torch_xla2
40 | pip install -e .
41 | pip install -U jax[cuda12]==0.4.30
42 | pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu
43 |
--------------------------------------------------------------------------------
/jetstream_pt/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from jetstream_pt.engine import create_pytorch_engine
16 |
17 | __all__ = ["create_pytorch_engine"]
18 |
--------------------------------------------------------------------------------
/jetstream_pt/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 | from absl import flags
18 | import jax
19 | from jetstream_pt.environment import QuantizationConfig
20 |
21 | FLAGS = flags.FLAGS
22 |
23 | flags.DEFINE_string("tokenizer_path", None, "The tokenizer model path")
24 | flags.DEFINE_string("model_name", None, "model type")
25 | flags.DEFINE_string("checkpoint_path", None, "Directory for .pth checkpoints")
26 | flags.DEFINE_bool("bf16_enable", True, "Whether to enable bf16")
27 | flags.DEFINE_integer("context_length", 1024, "The context length")
28 | flags.DEFINE_integer("batch_size", 32, "The batch size")
29 | flags.DEFINE_string("size", "tiny", "size of model")
30 | flags.DEFINE_integer("max_cache_length", 1024, "kv_cache_quantize")
31 | flags.DEFINE_integer("max_decode_length", 1024, "max length of generated text")
32 | flags.DEFINE_string("sharding_config", "", "config file for sharding")
33 | flags.DEFINE_bool(
34 | "shard_on_batch",
35 | False,
36 | "whether to shard on batch dimension"
37 | "If set true, sharding_config will be ignored.",
38 | )
39 | flags.DEFINE_string("profiling_output", "", "The profiling output")
40 |
41 | # Quantization related flags
42 | flags.DEFINE_bool("quantize_weights", False, "weight quantization")
43 | flags.DEFINE_bool(
44 | "quantize_activation",
45 | False,
46 | "Quantize Q,K,V projection and FeedForward activation. Defaults to False",
47 | )
48 | flags.DEFINE_string(
49 | "quantize_type", "int8_per_channel", "Type of quantization."
50 | )
51 | flags.DEFINE_bool(
52 | "quantize_kv_cache", None, "defaults to the same value as quantize_weights"
53 | )
54 | flags.DEFINE_multi_string(
55 | "quantize_exclude_layers",
56 | None,
57 | "List of layer names to exclude from quantization",
58 | )
59 |
60 | _VALID_QUANTIZATION_TYPE = {
61 | "int8_per_channel",
62 | "int4_per_channel",
63 | "int8_blockwise",
64 | "int4_blockwise",
65 | }
66 |
67 | flags.register_validator(
68 | "quantize_type",
69 | lambda value: value in _VALID_QUANTIZATION_TYPE,
70 | f"quantize_type is invalid, supported quantization types are {_VALID_QUANTIZATION_TYPE}",
71 | )
72 | flags.DEFINE_bool(
73 | "profiling_prefill",
74 | False,
75 | "Whether to profile the prefill, "
76 | "if set to false, profile generate function only",
77 | required=False,
78 | )
79 | flags.DEFINE_bool(
80 | "ragged_mha",
81 | False,
82 | "Whether to enable Ragged multi head attention",
83 | required=False,
84 | )
85 | flags.DEFINE_integer(
86 | "starting_position",
87 | 512,
88 | "The starting position of decoding, "
89 | "for performance tuning and debugging only",
90 | required=False,
91 | )
92 | flags.DEFINE_bool(
93 | "ring_buffer",
94 | True,
95 | "Whether to enable ring buffer",
96 | required=False,
97 | )
98 | flags.DEFINE_bool(
99 | "flash_attention",
100 | False,
101 | "Whether to enable flas attention. Only takes effect at test mode",
102 | required=False,
103 | )
104 | flags.DEFINE_bool(
105 | "generate_cache_stacked",
106 | False,
107 | "Whether to stack the generate cache to the layer dimension. Only takes effect at test mode",
108 | required=False,
109 | )
110 | flags.DEFINE_bool(
111 | "new_cache_stacked",
112 | False,
113 | "Whether to stack the generate cache to the layer dimension. Only takes effect at test mode",
114 | required=False,
115 | )
116 | flags.DEFINE_bool(
117 | "lazy_cache_update",
118 | False,
119 | "Whether to update the cache during attention or delayed until all the layers are done. "
120 | "Only takes effect at test mode",
121 | required=False,
122 | )
123 | flags.DEFINE_float(
124 | "temperature",
125 | 1.0,
126 | "temperature parameter for scaling probability."
127 | "Only invoked when sampling algorithm is set to"
128 | "weighted or topk",
129 | )
130 | flags.DEFINE_string(
131 | "sampling_algorithm",
132 | "greedy",
133 | "sampling algorithm to use. Options:"
134 | "('greedy', 'weighted', 'neucleus', 'topk')",
135 | )
136 | flags.DEFINE_float(
137 | "nucleus_topp",
138 | 0.0,
139 | "restricting to p probability mass before sampling",
140 | )
141 | flags.DEFINE_integer(
142 | "topk",
143 | 0,
144 | "size of top k used when sampling next token",
145 | )
146 |
147 | flags.DEFINE_integer(
148 | "paged_attention_total_num_pages",
149 | 0,
150 | "total number of pages per layer for page attention",
151 | )
152 |
153 | flags.DEFINE_integer(
154 | "paged_attention_page_size",
155 | 64,
156 | "page size per page",
157 | )
158 | flags.DEFINE_string(
159 | "internal_jax_compilation_cache_dir",
160 | "~/jax_cache",
161 | "Jax compilation cache directory",
162 | )
163 | flags.DEFINE_integer(
164 | "internal_jax_persistent_cache_min_entry_size_bytes",
165 | 0,
166 | "Minimum size (in bytes) of an entry that will be cached in the persistent compilation cache",
167 | )
168 | flags.DEFINE_integer(
169 | "internal_jax_persistent_cache_min_compile_time_secs",
170 | 1,
171 | "Minimum compilation time for a computation to be written to persistent cache",
172 | )
173 |
174 |
175 | def create_quantization_config_from_flags():
176 | """Create Quantization Config from cmd flags"""
177 | config = QuantizationConfig()
178 | quantize_weights = FLAGS.quantize_weights
179 | quantize_type = FLAGS.quantize_type
180 | if not quantize_weights:
181 | return config
182 | config.enable_weight_quantization = True
183 |
184 | config.num_bits_weight = 8 if "int8" in quantize_type else 4
185 | config.is_blockwise_weight = "blockwise" in quantize_type
186 |
187 | config.enable_activation_quantization = FLAGS.quantize_activation
188 | config.exclude_layers = FLAGS.quantize_exclude_layers
189 | config.enable_kv_quantization = (
190 | FLAGS.quantize_kv_cache
191 | if FLAGS.quantize_kv_cache is not None
192 | else FLAGS.quantize_weights
193 | )
194 | return config
195 |
196 |
197 | def set_jax_compilation_cache_config():
198 | """Sets the jax compilation cache configuration"""
199 | jax.config.update(
200 | "jax_compilation_cache_dir",
201 | os.path.expanduser(FLAGS.internal_jax_compilation_cache_dir),
202 | )
203 | jax.config.update(
204 | "jax_persistent_cache_min_entry_size_bytes",
205 | FLAGS.internal_jax_persistent_cache_min_entry_size_bytes,
206 | )
207 | jax.config.update(
208 | "jax_persistent_cache_min_compile_time_secs",
209 | FLAGS.internal_jax_persistent_cache_min_compile_time_secs,
210 | )
211 |
--------------------------------------------------------------------------------
/jetstream_pt/fetch_models.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import glob
3 | import os
4 | from typing import Optional
5 | from requests.exceptions import HTTPError
6 | from huggingface_hub import snapshot_download
7 | from absl import flags
8 | import torch
9 | from safetensors import safe_open
10 | from jetstream_pt.environment import (
11 | JetEngineEnvironmentData,
12 | )
13 | from jetstream_pt.third_party.llama import model_exportable as llama_model
14 | from jetstream_pt.third_party.mixtral import model as mixtral_model
15 | from jetstream_pt.third_party.gemma import model as gemma_model
16 |
17 | FLAGS = flags.FLAGS
18 |
19 | flags.DEFINE_string(
20 | "working_dir",
21 | "checkpoints",
22 | "Directory to store downloaded/converted weights",
23 | )
24 | flags.DEFINE_string("hf_token", "", "huggingface token")
25 | flags.DEFINE_bool(
26 | "internal_use_random_weights",
27 | False,
28 | "Use random weights instead of HF weights. Testing only.",
29 | )
30 |
31 | flags.DEFINE_bool(
32 | "internal_use_tiny_model",
33 | False,
34 | "Use tiny config instead of real config of HF weights. Testing only.",
35 | )
36 |
37 | flags.DEFINE_integer(
38 | "override_max_cache_length",
39 | -1,
40 | "Size of cache, defaults to input + output length",
41 | )
42 |
43 |
44 | @dataclasses.dataclass
45 | class ModelInfo:
46 | """Model information."""
47 |
48 | model_class: torch.nn.Module
49 | # information needed to allocate cache
50 | num_layers: int
51 | # number of kv heads
52 | num_kv_heads: int
53 |
54 | head_dim: int
55 |   n_reps: int  # repetition factor for GQA
56 |
57 |
58 | _llama2_7 = ModelInfo(llama_model.Transformer, 32, 32, 128, 1)
59 | _llama2_13 = ModelInfo(llama_model.Transformer, 40, 40, 128, 1)
60 | _llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 8)
61 | _llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128, 4)
62 | _llama3_70 = _llama2_70
63 | _llama3_1_8b = _llama3_8
64 | _llama3_2_1b = ModelInfo(llama_model.Transformer, 16, 8, 64, 4)
65 | _llama3_3_70b = _llama2_70
66 |
67 | _mixtral_87 = ModelInfo(mixtral_model.Transformer, 32, 8, 128, 4)
68 |
69 | _gemma_2b = ModelInfo(gemma_model.GemmaModel, 18, 1, 256, 8)
70 | _gemma_7b = ModelInfo(gemma_model.GemmaModel, 28, 16, 256, 1)
71 |
72 |
73 | model_id_to_class = {
74 | "meta-llama/Llama-2-7b-chat-hf": _llama2_7,
75 | "meta-llama/Llama-2-7b-hf": _llama2_7,
76 | "meta-llama/Llama-2-13b-chat-hf": _llama2_13,
77 | "meta-llama/Llama-2-13b-hf": _llama2_13,
78 | "meta-llama/Llama-2-70b-hf": _llama2_70,
79 | "meta-llama/Llama-2-70b-chat-hf": _llama2_70,
80 | "meta-llama/Meta-Llama-3-8B": _llama3_8,
81 | "meta-llama/Meta-Llama-3-8B-Instruct": _llama3_8,
82 | "meta-llama/Meta-Llama-3-70B": _llama3_70,
83 | "meta-llama/Meta-Llama-3-70B-Instruct": _llama3_70,
84 | "meta-llama/Llama-3.1-8B": _llama3_1_8b,
85 | "meta-llama/Llama-3.1-8B-Instruct": _llama3_1_8b,
86 | "meta-llama/Llama-3.2-1B": _llama3_2_1b,
87 | "meta-llama/Llama-3.2-1B-Instruct": _llama3_2_1b,
88 | "meta-llama/Llama-3.3-70B": _llama3_3_70b,
89 | "meta-llama/Llama-3.3-70B-Instruct": _llama3_3_70b,
90 | "google/gemma-2b": _gemma_2b,
91 | "google/gemma-2b-it": _gemma_2b,
92 | "google/gemma-7b": _gemma_7b,
93 | "google/gemma-7b-it": _gemma_7b,
94 | "mistralai/Mixtral-8x7B-v0.1": _mixtral_87,
95 | "mistralai/Mixtral-8x7B-Instruct-v0.1": _mixtral_87,
96 | }
97 |
98 |
99 | def _model_dir(repo_id):
100 | """Model dir structure:
101 |
102 | working_dir/
103 | repo_id/
104 | hf_original/
105 | converted_bfloat/
106 | converted_int8/
107 | """
108 | return os.path.join(FLAGS.working_dir, repo_id)
109 |
110 |
111 | def _hf_dir(repo_id):
112 | """Dir to hf repo"""
113 | return os.path.join(_model_dir(repo_id), "hf_original")
114 |
115 |
116 | def _int_dir(repo_id):
117 | return os.path.join(_model_dir(repo_id), "converted_int8")
118 |
119 |
120 | def construct_env_data_from_model_id(
121 | repo_id,
122 | batch_size,
123 | input_length,
124 | output_length,
125 | ):
126 | """Create Environment from model id and options"""
127 | tokenizer_path = os.path.join(_hf_dir(repo_id), "tokenizer.model")
128 | checkpoint_path = _hf_dir(repo_id)
129 | checkpoint_format = "safetensors"
130 |
131 | shard_on_batch = False
132 |
133 | max_cache_length = (
134 | FLAGS.override_max_cache_length
135 | if FLAGS.override_max_cache_length > 0
136 | else input_length + output_length
137 | )
138 |
139 | model_info = model_id_to_class.get(repo_id)
140 | env_data = JetEngineEnvironmentData(
141 | tokenizer_path=tokenizer_path,
142 | checkpoint_path=checkpoint_path,
143 | checkpoint_format=checkpoint_format,
144 | batch_size=batch_size,
145 | max_decode_length=output_length,
146 | max_input_sequence_length=input_length,
147 | cache_sequence_length=max_cache_length,
148 | bf16_enable=True,
149 | sharding_config_path="",
150 | shard_on_batch=shard_on_batch,
151 | n_reps=model_info.n_reps,
152 | )
153 | env_data.cache_shape = (
154 | batch_size,
155 | model_info.num_kv_heads,
156 | max_cache_length,
157 | model_info.head_dim,
158 | )
159 | env_data.num_layers = model_info.num_layers
160 | return env_data
161 |
162 |
163 | def _load_weights(directory):
164 | safetensors_files = glob.glob(os.path.join(directory, "*.safetensors"))
165 | state_dict = {}
166 | for file_path in safetensors_files:
167 | with safe_open(file_path, framework="pt") as f:
168 | for key in f.keys():
169 | state_dict[key] = f.get_tensor(key).to(torch.bfloat16)
170 | # Load the state_dict into the model
171 | if not state_dict:
172 | raise AssertionError(
173 | f"Tried to load weights from {directory}, but couldn't find any."
174 | )
175 | return state_dict
176 |
177 |
178 | def _make_random_model_weights(model):
179 | result = {}
180 | for key, val in model.state_dict().items():
181 | new_weights = torch.rand(val.shape, dtype=val.dtype, device="cpu")
182 | result[key] = new_weights
183 | return result
184 |
185 |
186 | def instantiate_model_from_repo_id(
187 | repo_id,
188 | env,
189 | ):
190 | """Create model instance by hf model id.+"""
191 | model_dir = _hf_dir(repo_id)
192 | if not FLAGS.internal_use_random_weights and (
193 | not os.path.exists(model_dir)
194 | or not glob.glob(os.path.join(model_dir, "*.safetensors"))
195 | ):
196 |     # no weights have been downloaded yet
197 | _hf_download(repo_id, model_dir, FLAGS.hf_token)
198 | model_info = model_id_to_class.get(repo_id)
199 | assert model_info is not None
200 |
201 | env.device = "meta"
202 | model = model_info.model_class.from_hf_model_id(
203 | repo_id, env, FLAGS.internal_use_tiny_model
204 | )
205 | if FLAGS.internal_use_random_weights or FLAGS.internal_use_tiny_model:
206 | weights = _make_random_model_weights(model)
207 | else:
208 | weights = _load_weights(model_dir)
209 | weights = model.convert_hf_weights(weights)
210 | model.load_state_dict(weights, assign=True, strict=False)
211 |
212 | return model
213 | ## QQ do i need to set the weights onto the model?
214 |
215 |
216 | def _hf_download(
217 | repo_id: str, dest_directory: str, hf_token: Optional[str] = None
218 | ) -> None:
219 | os.makedirs(dest_directory, exist_ok=True)
220 | try:
221 | if not hf_token:
222 | hf_token = os.environ.get("HF_TOKEN")
223 | if not hf_token:
224 | # NOTE: setting true allows hf to read from the config folder.
225 | hf_token = True
226 | snapshot_download(
227 | repo_id,
228 | local_dir=dest_directory,
229 | local_dir_use_symlinks=False,
230 | token=hf_token,
231 | allow_patterns=[
232 | "model*.safetensors",
233 | "*.json",
234 | "*.model",
235 | ],
236 | )
237 | except HTTPError as e:
238 | if e.response.status_code == 401:
239 | print(
240 | "Please use huggingface-cli login to authenticate "
241 | "to download private checkpoints."
242 | )
243 | print("OR, pass `hf_token=...` explicitly.")
244 | raise e
245 |
--------------------------------------------------------------------------------
/jetstream_pt/gcs_to_cns.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2024 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | set -e
17 |
18 | LOG_FILE_IN_GCS=$1
19 | filename=$(basename $LOG_FILE_IN_GCS)
20 | output_file=$(date "+%Y-%m-%d-%H:%M:%S")_${filename}
21 |
22 | CNS_PATH=/cns/pi-d/home/${USER}/tensorboard/multislice/
23 | fileutil mkdir -p ${CNS_PATH}
24 | /google/data/ro/projects/cloud/bigstore/mpm/fileutil_bs/stable/bin/fileutil_bs cp /bigstore/${LOG_FILE_IN_GCS} ${CNS_PATH}/$output_file
25 | echo file to put into xprof: ${CNS_PATH}/$output_file
26 |
--------------------------------------------------------------------------------
/jetstream_pt/hf_tokenizer.py:
--------------------------------------------------------------------------------
1 | from jetstream.engine import tokenizer_api, token_utils
2 |
3 |
4 | class HFTokenizerAdapter(tokenizer_api.Tokenizer):
5 | """Implementation of Tokenizer interface backed by HF tokenizer."""
6 |
7 | def __init__(self, tokenizer):
8 | self.tokenizer = tokenizer
9 |
10 | def encode(self, s: str, **kwargs):
11 | """Tokenize a string.
12 | Args:
13 | s: String to tokenize.
14 | **kwargs: Additional keyword arguments.
15 | Returns:
16 | tokens: Tokenized into integers.
17 | true_length: Actual length of the non-padded sequence
18 | if padding is used.
19 | """
20 | res = self.tokenizer.encode(s, add_special_tokens=False)
21 | return token_utils.pad_tokens(
22 | res, self.bos_id, self.pad_id, jax_padding=True
23 | )
24 |
25 | def decode(self, token_ids: list[int], **kwargs) -> str:
26 | """Processess input token ids to generate a string.
27 | Args:
28 | token_ids: List of token ids.
29 | **kwargs: Additional keyword arguments.
30 | Returns:
31 | str: String generated from the token ids.
32 | """
33 | return self.tokenizer.decode(token_ids)
34 |
35 | @property
36 | def pad_id(self) -> int:
37 | """ID of the pad token."""
38 | return self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else 0
39 |
40 | @property
41 | def eos_id(self) -> int:
42 | """ID of EOS token."""
43 | return self.tokenizer.eos_token_id
44 |
45 | @property
46 | def bos_id(self) -> int:
47 | """ID of BOS token."""
48 | return self.tokenizer.bos_token_id
49 |
50 | @property
51 | def stop_tokens(self) -> set[int]:
52 | """ID of the stop token."""
53 | return {self.eos_id}
54 |
--------------------------------------------------------------------------------
/jetstream_pt/model_base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import itertools
3 | from typing import Dict, Any, Optional
4 | import dataclasses
5 | from collections import defaultdict
6 | import torch
7 |
8 |
9 | def _get_hf_name(module, key):
10 | if hasattr(module, "attr_to_property") and key in module.attr_to_property:
11 | return module.attr_to_property[key].huggingface_name
12 | return None
13 |
14 |
15 | def _gather_names(module, myprefix, hf_prefix, result):
16 | for key, _ in itertools.chain(
17 | module.named_parameters(recurse=False),
18 | module.named_buffers(recurse=False),
19 | ):
20 | hf_name = _get_hf_name(module, key) or key
21 | result[hf_prefix + hf_name] = myprefix + key
22 |
23 | for name, child in module.named_children():
24 | hf_name = _get_hf_name(module, name) or name
25 | _gather_names(
26 | child, myprefix + name + ".", hf_prefix + hf_name + ".", result
27 | )
28 |
29 |
30 | def _gather_sharding_axis(module, myprefix, result):
31 | if hasattr(module, "attr_to_property"):
32 | for key, val in module.attr_to_property.items():
33 | if val.sharding_axis is not None:
34 | result[myprefix + key] = val.sharding_axis
35 |
36 | for name, child in module.named_children():
37 | _gather_sharding_axis(child, myprefix + name + ".", result)
38 |
39 |
40 | @dataclasses.dataclass
41 | class AttrProperty:
42 | """Attributes attached to model weights."""
43 |
44 | huggingface_name: Optional[str] = None
45 | sharding_axis: Optional[int] = None
46 |
47 |
48 | class ModuleBase(torch.nn.Module, metaclass=abc.ABCMeta):
49 | """nn Module that allows attaching properties.
50 |
51 | This class currently serves 2 goals:
52 |     1. Allow the model to specify alternative names for submodules / weights;
53 |        this is needed so that it can *also* load HuggingFace checkpoints
54 |        without needing massive rewrites.
55 |
56 | 2. Allow model to attach information to weights, such as sharding config.
57 |
58 | Quantization config could be another thing to attach, but right now it's not used
59 | this way.
60 | """
61 |
62 | attr_to_property: Dict[str, Any]
63 |
64 | def __init__(self):
65 | super().__init__()
66 | self.attr_to_property = defaultdict(AttrProperty)
67 |
68 | def get_hf_names_to_real_name(self):
69 | """Return a dict of attr names to it's hf name."""
70 | result = {}
71 | _gather_names(self, "", "", result)
72 | return result
73 |
74 | def get_sharding_annotations(self):
75 | """Return a dict of attr names to it's sharding dim."""
76 | result = {}
77 | _gather_sharding_axis(self, "", result)
78 | return result
79 |
80 | def hf_name(self, orig_name, hf_name):
81 | """Set it's alternative name for a attribute or submodule."""
82 | self.attr_to_property[orig_name].huggingface_name = hf_name
83 |
84 | def annotate_sharding(self, name, axis):
85 | """Set sharding name for a attribute or submodule."""
86 | self.attr_to_property[name].sharding_axis = axis
87 |
88 | def convert_hf_weights(
89 | self, hf_weights: Dict[str, torch.Tensor]
90 | ) -> Dict[str, torch.Tensor]:
91 | """Load state_dict with hg weights."""
92 | weights = {}
93 | updated_keys = self.get_hf_names_to_real_name()
94 | for name, updated in updated_keys.items():
95 | if name in hf_weights:
96 | weights[updated] = hf_weights[name]
97 |
98 | for name in list(weights.keys()):
99 | if "inv_freq" in name:
100 | weights.pop(name)
101 | if hasattr(self, "freqs_cis"):
102 | weights["freqs_cis"] = self.freqs_cis
103 | return weights
104 |
--------------------------------------------------------------------------------
/jetstream_pt/page_attention_manager.py:
--------------------------------------------------------------------------------
1 | import queue
2 | import functools
3 | from typing import List, Tuple
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | import jax.sharding as jsharding
8 | import numpy as np
9 |
10 |
11 | class PageAttentionManager:
12 | """Manages page blocks.
13 |
14 |   This manager maintains a main list of free page blocks and supports the following features:
15 |   1. Reserve pages for prefill insert and decode.
16 |   2. Free page resources for slots after decode; freed page indices go back to the free list.
17 |   3. Get page index metadata for all the slots.
18 |   4. Transform and insert prefill caches into decode caches.
19 | """
20 |
21 | def __init__(
22 | self,
23 | batch_size: int,
24 | paged_attention_total_num_pages: int,
25 | paged_attention_page_size: int,
26 | max_pages_per_sequence: int,
27 | ):
28 | self.unused_pages = queue.Queue()
29 | self.batch_size = batch_size
30 | self.page_indices = np.full(
31 | (batch_size, max_pages_per_sequence),
32 | paged_attention_total_num_pages - 1,
33 | dtype=np.int32,
34 | )
35 | self.lengths = np.zeros(batch_size, dtype=np.int32)
36 | self.paged_attention_page_size = paged_attention_page_size
37 | self.max_pages_per_sequence = max_pages_per_sequence
38 | for i in range(paged_attention_total_num_pages):
39 | self.unused_pages.put(i, block=False)
40 |
41 | # pylint: disable-next=all
42 | def reserve_pages_insert(self, slot: int, seq_len: int):
43 | self.lengths[slot] = seq_len
44 | num_pages = (
45 | seq_len // self.paged_attention_page_size
46 | if seq_len % self.paged_attention_page_size == 0
47 | else seq_len // self.paged_attention_page_size + 1
48 | )
49 |
50 | indices = [self.unused_pages.get(block=False) for _ in range(num_pages)]
51 | self.page_indices[slot, :num_pages] = indices
52 | return num_pages, self.page_indices[slot, :num_pages]
53 |
54 | # pylint: disable-next=all
55 | def reserve_pages_decode(self, slot: int, seq_len: int):
56 | if seq_len > 0 and seq_len % self.paged_attention_page_size == 0:
57 | index = self.unused_pages.get(block=False)
58 | num_pages = seq_len // self.paged_attention_page_size
59 | self.page_indices[slot, num_pages] = index
60 |
61 | # pylint: disable-next=all
62 | def fill_new_pages(self, lens):
63 | for slot in range(self.batch_size):
64 | self.reserve_pages_decode(slot, lens[slot])
65 |
66 | # pylint: disable-next=all
67 | def prefill_cache_padding(
68 | self,
69 | caches: List[Tuple[jax.Array, jax.Array]],
70 | seq_len: int,
71 | num_pages: int,
72 | ) -> List[Tuple[jax.Array, jax.Array]]:
73 |
74 | pad_width = num_pages * self.paged_attention_page_size - seq_len
75 | if pad_width == 0:
76 | return caches
77 |
78 | return [
79 | (self.pad_sequences(k, pad_width), self.pad_sequences(v, pad_width))
80 | for k, v in caches
81 | ]
82 |
83 | def insert_prefill_cache(
84 | self,
85 | prefill_caches: List[Tuple[jax.Array, jax.Array]],
86 | decode_caches: List[Tuple[jax.Array, jax.Array]],
87 | update_indexes: jax.Array,
88 | tep_kv: jax.Array,
89 | sharding: jsharding.Sharding,
90 | ) -> List[Tuple[jax.Array, jax.Array]]:
91 | """Insert prefill caches to decode caches.
92 |
93 | Args:
94 | prefill_caches: List of Tuple K, V. For each K, V:
95 | [batch_size, num_heads, seq_len, head_dim] jax.Array.
96 | decode_caches: List of Tuple K, V. For each K, V:
97 | [num_heads, paged_attention_total_num_pages, paged_attention_page_size, head_dim] jax.Array.
98 | update_indexes: Page indexes for insertion.
99 |       tep_kv: Temporary KV buffer of shape
100 |         [kv_heads, num_pages * paged_attention_page_size, dim].
101 | sharding: Decode cache sharding.
102 |
103 |
104 | Returns:
105 | Decode cache. List of Tuple K, V. For each K, V:
106 | [num_heads, paged_attention_total_num_pages, paged_attention_page_size, head_dim] jax.Array.
107 | """
108 |     # Squeeze out the cache batch dimension
109 | # [kv_heads, seq_len, dim]
110 | squeezed_caches = [
111 | (jnp.squeeze(k, axis=0), jnp.squeeze(v, axis=0))
112 | for k, v in prefill_caches
113 | ]
114 | tmp_caches = [
115 | (
116 | tep_kv.at[:, : k.shape[1], :].set(k),
117 | tep_kv.at[:, : v.shape[1], :].set(v),
118 | )
119 | for k, v in squeezed_caches
120 | ]
121 | kv_heads, _, dim = tmp_caches[0][0].shape
122 | # [kv_heads, num_pages, paged_attention_page_size, dim]
123 | paged_caches = [
124 | (
125 | jnp.reshape(k, (kv_heads, -1, self.paged_attention_page_size, dim)),
126 | jnp.reshape(v, (kv_heads, -1, self.paged_attention_page_size, dim)),
127 | )
128 | for k, v in tmp_caches
129 | ]
130 |
131 | @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
132 | def insert(cache, new_entry):
133 | res = cache.at[:, update_indexes, :, :].set(new_entry)
134 | res = jax.lax.with_sharding_constraint(res, sharding)
135 | return res
136 |
137 | caches = [
138 | (insert(k, newk), insert(v, newv))
139 | for (k, v), (newk, newv) in zip(decode_caches, paged_caches)
140 | ]
141 |
142 | return caches
143 |
144 | # pylint: disable-next=all
145 | def get_page_token_indices(self, lens):
146 | # assert lens.shape == (
147 | # self.batch_size,
148 | # 1,
149 | # ), f"len shape: {lens.shape} not equals batch size: {self.batch_size, 1}"
150 | update_page_indices = []
151 | token_scale_indices = []
152 | batch_slots = []
153 | offset = 0
154 |
155 | for slot in range(self.batch_size):
156 | seq_len = lens[slot]
157 | if seq_len == 0:
158 | continue
159 | num_pages = seq_len // self.paged_attention_page_size + 1
160 | token_pos = seq_len % self.paged_attention_page_size
161 | page_index = self.page_indices[slot, num_pages - 1]
162 |
163 | update_page_indices.append(page_index)
164 | token_scale_indices.append(offset + token_pos)
165 | batch_slots.append(slot)
166 | offset += self.paged_attention_page_size
167 | self.lengths = np.where(lens == 0, 0, lens + 1)
168 | update_page_indices = np.asarray(update_page_indices)
169 | token_scale_indices = np.asarray(token_scale_indices)
170 | batch_slots = np.asarray(batch_slots)
171 | return np.stack(
172 | (
173 | update_page_indices,
174 | token_scale_indices,
175 | batch_slots,
176 | )
177 | )
178 |
179 | # pylint: disable-next=all
180 | def get_compress_kv_cache(
181 | self,
182 | decode_caches: List[Tuple[jax.Array, jax.Array]],
183 | slot: int,
184 | ) -> List[Tuple[jax.Array, jax.Array]]:
185 | lens = self.lengths[slot]
186 | indices = self.page_indices[slot]
187 | return [
188 | (
189 | self._compress_cache(k, lens, indices),
190 | self._compress_cache(v, lens, indices),
191 | )
192 | for k, v in decode_caches
193 | ]
194 |
195 | def _compress_cache(self, cache: jax.Array, lens: int, indices: jax.Array):
196 | head, _, _, dim = cache.shape
197 | selected_cache = cache[:, indices, :, :]
198 | selected_cache = selected_cache.reshape((head, -1, dim))
199 | selected_cache = selected_cache[:, 0:lens, :]
200 | return selected_cache
201 |
202 | # pylint: disable-next=all
203 | def pad_sequences(self, array, pad_width=10):
204 | padding_config = [
205 | (0, 0),
206 | (0, 0),
207 | (0, pad_width),
208 | (0, 0),
209 | ] # Pad only seq_len and dim
210 | padded_array = jnp.pad(array, padding_config, mode="constant")
211 | return padded_array
212 |
213 | # pylint: disable-next=all
214 | def free_pages_resource(self, slot):
215 | for i in range(self.max_pages_per_sequence):
216 | index = self.page_indices[slot, i]
217 | if index < 0:
218 | break
219 | self.unused_pages.put(index, block=False)
220 |
221 | self.page_indices = self.page_indices.at[slot, :].set(jnp.asarray([0]))
222 | return None
223 |
--------------------------------------------------------------------------------
/jetstream_pt/quantize.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Tuple, Union
16 |
17 | import jax
18 | import jax.numpy as jnp
19 | import torch
20 |
21 | EPS = 1e-5
22 |
23 |
24 | def quantize_tensor(
25 | w: torch.Tensor,
26 | reduce_axis: Union[Tuple[int], int],
27 | n_bit: int = 8,
28 | symmetric: bool = True,
29 | block_size: int = -1,
30 | ):
31 | """
32 | Quantize weight tensor w along 'reduce_axis'.
33 |
34 | Args:
35 | w: weight tensor to be quantized.
36 |     reduce_axis: axes along which to quantize.
37 | n_bit: Quantize to n_bit bits. (Use int8 container for n_bits < 8).
38 | symmetric: Whether quantization is symmetric.
39 | block_size: Blocksize for blockwise quantization. -1 for per-channel quant.
40 |
41 | Return:
42 | w_q: Quantized weight in int8 container
43 |     scale: scale factors for the quantized tensor
44 | zero_point: zero_point for quantized tensor, None if symmetric quantization
45 | """
46 |
47 | assert 0 < n_bit <= 8, "Quantization bits must be between [1, 8]."
48 | if isinstance(reduce_axis, int):
49 | reduce_axis = (reduce_axis,)
50 |
51 | if block_size > 0:
52 | axis = reduce_axis[0]
53 | w_shape = w.shape
54 | assert w_shape[axis] % block_size == 0
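    # Split the quantization axis into (num_blocks, block_size) so scales/zero points are computed per block.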
55 | w = w.reshape(w_shape[:axis] + (-1, block_size) + w_shape[axis + 1 :])
56 | reduce_axis = axis + 1
57 |
58 | max_int = 2 ** (n_bit - 1) - 1
59 | min_int = -(2 ** (n_bit - 1))
60 | if not symmetric:
61 | max_val = w.amax(dim=reduce_axis, keepdim=True)
62 | min_val = w.amin(dim=reduce_axis, keepdim=True)
63 | scales = (max_val - min_val).clamp(min=EPS) / float(max_int - min_int)
64 | zero_point = min_int - min_val / scales
65 | else:
66 | max_val = w.abs().amax(dim=reduce_axis, keepdim=True)
67 | max_val = max_val.clamp(min=EPS)
68 | scales = max_val / max_int
69 | zero_point = 0
70 |
71 | w = torch.clamp(
72 | torch.round(w * (1.0 / scales) + zero_point), min_int, max_int
73 | ).to(torch.int8)
74 |
75 | return w, scales, zero_point if not symmetric else None
76 |
77 |
78 | def dequantize_tensor(w, scale, zero_point=None):
79 | """Dequantize tensor quantized by quantize_tensor."""
80 | if zero_point is not None:
81 | return (w - zero_point) * scale
82 |
83 | return w * scale
84 |
85 |
86 | def load_q_weight_helper(w_q, scale, zp=None, block_size=-1):
87 | """Helper function to update the shape of quantized weight to match
88 | what quantized linear layer expects."""
89 | if block_size < 0:
90 | w_q = w_q.to(torch.int8)
91 | if zp is not None:
92 | zp = (zp * scale).squeeze(-1).to(torch.bfloat16)
93 | scale = scale.squeeze(-1).to(torch.bfloat16)
94 | else:
95 | w_q = w_q.permute(1, 2, 0).to(torch.int8)
96 | if zp is not None:
97 | zp = (zp * scale).transpose(1, 0).squeeze(-1).to(torch.bfloat16)
98 | scale = scale.transpose(1, 0).squeeze(-1).to(torch.bfloat16)
99 | return w_q, scale, zp
100 |
101 |
102 | def blockwise_jax_kernel(inputs, weight, weight_scaler, zero_point):
103 | """Blockwise Matmul kernel impl in JAX using einsum"""
104 | weight = weight.astype(jnp.int8)
105 | block_size = weight.shape[1]
106 | inputs_shape = inputs.shape
107 | inputs_new_shape = inputs_shape[:-1] + (
108 | inputs_shape[-1] // block_size,
109 | block_size,
110 | )
111 | inputs = inputs.reshape(inputs_new_shape)
112 | out = jnp.einsum("scz,bdsc->bdsz", weight, inputs)
113 | out = jnp.einsum("bdsz,sz->bdz", out, weight_scaler)
114 | if zero_point is not None:
115 | zp_out = jnp.einsum("bdsc,sz->bdz", inputs, zero_point)
116 | out = out - zp_out
117 | return out
118 |
119 |
120 | def blockwise_jax_kernel_dot_general(inputs, weight, weight_scaler, zero_point):
121 | """Blockwise Matmul kernel impl in JAX using dot general"""
122 | inputs_shape = inputs.shape
123 | block_size = weight.shape[2]
124 | bs = inputs_shape[0]
125 | inputs_new_shape = inputs_shape[:-1] + (
126 | inputs_shape[-1] // block_size,
127 | block_size,
128 | )
129 | inputs = inputs.reshape(inputs_new_shape)
130 | inputs = jax.lax.collapse(inputs, 0, 2)
131 | out = jax.lax.dot_general(
132 | inputs, weight, dimension_numbers=([(2), (2)], [(1), (0)])
133 | )
134 | out = jax.lax.dot_general(
135 | out, weight_scaler, dimension_numbers=([(0), (0)], [(2), (1)])
136 | )
137 | out = jax.lax.transpose(out, [1, 0])
138 | out = out.reshape((bs, -1) + out.shape[1:])
139 | return out
140 |
141 |
142 | def blockwise_jax_kernel_einsum_flatten(
143 | inputs, weight, weight_scaler, zero_point
144 | ):
145 | """Blockwise Matmul kernel impl in JAX using einsum, with operands flattened"""
146 | weight = weight.astype(jnp.int8)
147 | block_size = weight.shape[1]
148 | inputs_shape = inputs.shape
149 | bs = inputs_shape[0]
150 | inputs_new_shape = inputs_shape[:-1] + (
151 | inputs_shape[-1] // block_size,
152 | block_size,
153 | )
154 | inputs = inputs.reshape(inputs_new_shape)
155 | inputs = jax.lax.collapse(inputs, 0, 2)
156 | out = jnp.einsum("scz,bsc->bsz", weight, inputs)
157 | out = jnp.einsum("bsz,sz->bz", out, weight_scaler)
158 | out = out.reshape((bs, -1) + out.shape[1:])
159 | return out
160 |
--------------------------------------------------------------------------------
/jetstream_pt/quantize_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .environment import QuantizationConfig
3 | from .layers import (
4 | create_quantized_from_nn_linear,
5 | create_quantized_from_nn_embedding,
6 | AttentionKernel,
7 | Int8KVAttentionKernel,
8 | )
9 |
10 |
11 | def quantize_model(float_model, config: QuantizationConfig):
12 | """Apply quantization to linear layers."""
13 | exclude_mods = None
14 | if config.exclude_layers:
15 | exclude_mods = [
16 | module
17 | for name, module in float_model.named_modules()
18 | if name in config.exclude_layers
19 | ]
20 |
21 | def quantize_nn_mod(float_model):
22 | for name, mod in float_model.named_modules():
23 | new_mod = None
24 | if config.exclude_layers and mod in exclude_mods:
25 | continue
26 | if hasattr(mod, "get_quantized_version"):
27 | new_mod = mod.get_quantized_version()
28 | elif isinstance(mod, torch.nn.Linear):
29 | new_mod = create_quantized_from_nn_linear(mod, config)
30 | elif isinstance(mod, torch.nn.Embedding):
31 | new_mod = create_quantized_from_nn_embedding(mod, config)
32 |
33 | if new_mod:
34 | setattr(float_model, name, new_mod)
35 |
36 | if config.enable_kv_quantization:
37 | for name, mod in float_model.__dict__.items():
38 | if isinstance(mod, AttentionKernel):
39 | new_mod = Int8KVAttentionKernel(mod.env, mod.layer_id)
40 | setattr(float_model, name, new_mod)
41 |
42 | float_model.apply(quantize_nn_mod)
43 | return float_model
44 |
--------------------------------------------------------------------------------
/jetstream_pt/third_party/gemma/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/jetstream_pt/third_party/gemma/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Gemma model config."""
16 |
17 | import dataclasses
18 | import torch
19 | from typing import Optional
20 |
21 |
22 | # Keep a mapping from dtype strings to the supported torch dtypes.
23 | _STR_DTYPE_TO_TORCH_DTYPE = dict(
24 | {
25 | "float16": torch.float16,
26 | "float": torch.float32,
27 | "float32": torch.float32,
28 | "bfloat16": torch.bfloat16,
29 | }
30 | )
31 |
32 |
33 | @dataclasses.dataclass
34 | class GemmaConfig:
35 | # The number of tokens in the vocabulary.
36 | vocab_size: int = 256000
37 | # The maximum sequence length that this model might ever be used with.
38 | max_position_embeddings: int = 8192
39 | # The number of blocks in the model.
40 | num_hidden_layers: int = 28
41 | # The number of attention heads used in the attention layers of the model.
42 | num_attention_heads: int = 16
43 | # The number of key-value heads for implementing attention.
44 | num_key_value_heads: int = 16
45 | # The hidden size of the model.
46 | hidden_size: int = 3072
47 | # The dimension of the MLP representations.
48 | intermediate_size: int = 24576
49 | # The number of head dimensions.
50 | head_dim: int = 256
51 | # The epsilon used by the rms normalization layers.
52 | rms_norm_eps: float = 1e-6
53 | # The dtype of the weights.
54 | dtype: str = "bfloat16"
55 | # Whether a quantized version of the model is used.
56 | quant: bool = False
57 | # The path to the model tokenizer.
58 | tokenizer: Optional[str] = "tokenizer/tokenizer.model"
59 |
60 | device: str = "meta"
61 |
62 | def get_dtype(self) -> Optional[torch.dtype]:
63 | """Gets the torch dtype from the config dtype string."""
64 | return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None)
65 |
66 |
67 | def get_config_for_7b() -> GemmaConfig:
68 | return GemmaConfig()
69 |
70 |
71 | def get_config_for_2b() -> GemmaConfig:
72 | return GemmaConfig(
73 | num_hidden_layers=18,
74 | num_attention_heads=8,
75 | num_key_value_heads=1,
76 | hidden_size=2048,
77 | intermediate_size=16384,
78 | )
79 |
80 |
81 | def get_model_config(variant: str) -> GemmaConfig:
82 | if variant == "7b":
83 | return get_config_for_7b()
84 | elif variant == "2b":
85 | return get_config_for_2b()
86 |   raise ValueError(
87 |       f'Invalid variant {variant}. Supported variants are "2b" and "7b".'
88 | )
89 |
--------------------------------------------------------------------------------
/jetstream_pt/third_party/gemma/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | from typing import List, Optional
16 |
17 | from sentencepiece import SentencePieceProcessor
18 |
19 |
20 | class Tokenizer:
21 |
22 | def __init__(self, model_path: Optional[str]):
23 | # Reload tokenizer.
24 | assert os.path.isfile(model_path), model_path
25 | self.sp_model = SentencePieceProcessor(model_file=model_path)
26 |
27 | # BOS / EOS token IDs.
28 | self.n_words: int = self.sp_model.vocab_size()
29 | self.bos_id: int = self.sp_model.bos_id()
30 | self.eos_id: int = self.sp_model.eos_id()
31 | self.pad_id: int = self.sp_model.pad_id()
32 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
33 |
34 | def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]:
35 | """Converts a string into a list of tokens."""
36 | assert isinstance(s, str)
37 | t = self.sp_model.encode(s)
38 | if bos:
39 | t = [self.bos_id] + t
40 | if eos:
41 | t = t + [self.eos_id]
42 | return t
43 |
44 | def decode(self, t: List[int]) -> str:
45 | """Converts a list of tokens into a string."""
46 | return self.sp_model.decode(t)
47 |
--------------------------------------------------------------------------------
/jetstream_pt/third_party/llama/LICENSE:
--------------------------------------------------------------------------------
1 | LLAMA 2 COMMUNITY LICENSE AGREEMENT
2 | Llama 2 Version Release Date: July 18, 2023
3 |
4 | "Agreement" means the terms and conditions for use, reproduction, distribution and
5 | modification of the Llama Materials set forth herein.
6 |
7 | "Documentation" means the specifications, manuals and documentation
8 | accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and-
9 | libraries/llama-downloads/.
10 |
11 | "Licensee" or "you" means you, or your employer or any other person or entity (if
12 | you are entering into this Agreement on such person or entity's behalf), of the age
13 | required under applicable laws, rules or regulations to provide legal consent and that
14 | has legal authority to bind your employer or such other person or entity if you are
15 | entering in this Agreement on their behalf.
16 |
17 | "Llama 2" means the foundational large language models and software and
18 | algorithms, including machine-learning model code, trained model weights,
19 | inference-enabling code, training-enabling code, fine-tuning enabling code and other
20 | elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and-
21 | libraries/llama-downloads/.
22 |
23 | "Llama Materials" means, collectively, Meta's proprietary Llama 2 and
24 | Documentation (and any portion thereof) made available under this Agreement.
25 |
26 | "Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you
27 | are an entity, your principal place of business is in the EEA or Switzerland) and Meta
28 | Platforms, Inc. (if you are located outside of the EEA or Switzerland).
29 |
30 | By clicking "I Accept" below or by using or distributing any portion or element of the
31 | Llama Materials, you agree to be bound by this Agreement.
32 |
33 | 1. License Rights and Redistribution.
34 |
35 | a. Grant of Rights. You are granted a non-exclusive, worldwide, non-
36 | transferable and royalty-free limited license under Meta's intellectual property or
37 | other rights owned by Meta embodied in the Llama Materials to use, reproduce,
38 | distribute, copy, create derivative works of, and make modifications to the Llama
39 | Materials.
40 |
41 | b. Redistribution and Use.
42 |
43 | i. If you distribute or make the Llama Materials, or any derivative works
44 | thereof, available to a third party, you shall provide a copy of this Agreement to such
45 | third party.
46 | ii. If you receive Llama Materials, or any derivative works thereof, from
47 | a Licensee as part of an integrated end user product, then Section 2 of this
48 | Agreement will not apply to you.
49 |
50 | iii. You must retain in all copies of the Llama Materials that you
51 | distribute the following attribution notice within a "Notice" text file distributed as a
52 | part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License,
53 | Copyright (c) Meta Platforms, Inc. All Rights Reserved."
54 |
55 | iv. Your use of the Llama Materials must comply with applicable laws
56 | and regulations (including trade compliance laws and regulations) and adhere to the
57 | Acceptable Use Policy for the Llama Materials (available at
58 | https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into
59 | this Agreement.
60 |
61 | v. You will not use the Llama Materials or any output or results of the
62 | Llama Materials to improve any other large language model (excluding Llama 2 or
63 | derivative works thereof).
64 |
65 | 2. Additional Commercial Terms. If, on the Llama 2 version release date, the
66 | monthly active users of the products or services made available by or for Licensee,
67 | or Licensee's affiliates, is greater than 700 million monthly active users in the
68 | preceding calendar month, you must request a license from Meta, which Meta may
69 | grant to you in its sole discretion, and you are not authorized to exercise any of the
70 | rights under this Agreement unless or until Meta otherwise expressly grants you
71 | such rights.
72 |
73 | 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE
74 | LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE
75 | PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
76 | EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY
77 | WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR
78 | FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE
79 | FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING
80 | THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR
81 | USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS.
82 |
83 | 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE
84 | LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT,
85 | NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS
86 | AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL,
87 | CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN
88 | IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF
89 | ANY OF THE FOREGOING.
90 |
91 | 5. Intellectual Property.
92 |
93 | a. No trademark licenses are granted under this Agreement, and in
94 | connection with the Llama Materials, neither Meta nor Licensee may use any name
95 | or mark owned by or associated with the other or any of its affiliates, except as
96 | required for reasonable and customary use in describing and redistributing the
97 | Llama Materials.
98 |
99 | b. Subject to Meta's ownership of Llama Materials and derivatives made by or
100 | for Meta, with respect to any derivative works and modifications of the Llama
101 | Materials that are made by you, as between you and Meta, you are and will be the
102 | owner of such derivative works and modifications.
103 |
104 | c. If you institute litigation or other proceedings against Meta or any entity
105 | (including a cross-claim or counterclaim in a lawsuit) alleging that the Llama
106 | Materials or Llama 2 outputs or results, or any portion of any of the foregoing,
107 | constitutes an infringement of intellectual property or other rights owned or licensable
108 | by you, then any licenses granted to you under this Agreement shall terminate as of
109 | the date such litigation or claim is filed or instituted. You will indemnify and hold
110 | harmless Meta from and against any claim by any third party arising out of or related
111 | to your use or distribution of the Llama Materials.
112 |
113 | 6. Term and Termination. The term of this Agreement will commence upon your
114 | acceptance of this Agreement or access to the Llama Materials and will continue in
115 | full force and effect until terminated in accordance with the terms and conditions
116 | herein. Meta may terminate this Agreement if you are in breach of any term or
117 | condition of this Agreement. Upon termination of this Agreement, you shall delete
118 | and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the
119 | termination of this Agreement.
120 |
121 | 7. Governing Law and Jurisdiction. This Agreement will be governed and
122 | construed under the laws of the State of California without regard to choice of law
123 | principles, and the UN Convention on Contracts for the International Sale of Goods
124 | does not apply to this Agreement. The courts of California shall have exclusive
125 | jurisdiction of any dispute arising out of this Agreement.
126 |
127 |
--------------------------------------------------------------------------------
/jetstream_pt/third_party/llama/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/jetstream-pytorch/f4d775358c3eadcf34d4ab0ba244cb1167f7a628/jetstream_pt/third_party/llama/__init__.py
--------------------------------------------------------------------------------
/jetstream_pt/third_party/llama/model_args.py:
--------------------------------------------------------------------------------
1 | # pylint: disable-all
2 | """The original Llama2 model."""
3 |
4 | import dataclasses
5 | from typing import Optional
6 |
7 |
8 | @dataclasses.dataclass
9 | class RopeScalingArgs:
10 | """Rope scaling configuration parameters."""
11 |
12 | factor: float = 8.0
13 | low_freq_factor: float = 1.0
14 | high_freq_factor: float = 4.0
15 | original_max_position_embeddings: int = 8192
16 |
17 |
18 | @dataclasses.dataclass
19 | class ModelArgs:
20 | """Model configuration parameters."""
21 |
22 | dim: int = -1
23 | n_layers: int = -1
24 | n_heads: int = -1
25 | n_kv_heads: Optional[int] = None
26 | vocab_size: int = -1 # defined later by tokenizer
27 | multiple_of: int = (
28 | 256 # make SwiGLU hidden layer size multiple of large power of 2
29 | )
30 | ffn_dim_multiplier: Optional[float] = None
31 | norm_eps: float = 1e-5
32 |
33 | max_batch_size: int = -1
34 | max_seq_len: int = -1
35 |
36 | bf16_enable: bool = False
37 |   head_dim: int = -1
38 |   infer_length: int = 0
39 |   device: str = "cpu"
40 |
41 | rope_theta: float = 10000.0
42 |   rope_scaling_args: Optional[RopeScalingArgs] = None
43 |
44 |
45 | def get_arg(
46 | model_name: str,
47 | seqlen,
48 | batch_size,
49 | bf16_enable: bool = False,
50 | ) -> ModelArgs:
51 | """Gets model args."""
52 |
53 | data = {}
54 | if model_name == "llama-2-tiny":
55 | data = {
56 | "dim": 128,
57 | "vocab_size": 32000,
58 | "multiple_of": 32,
59 | "n_heads": 64,
60 | "n_kv_heads": 8,
61 | "n_layers": 3,
62 | "norm_eps": 1e-05,
63 | }
64 | elif model_name == "llama-2-7b":
65 | data = {
66 | "dim": 4096,
67 | "vocab_size": 32000,
68 | "multiple_of": 256,
69 | "n_heads": 32,
70 | "n_layers": 32,
71 | "norm_eps": 1e-05,
72 | }
73 | elif model_name == "llama-2-13b":
74 | data = {
75 | "dim": 5120,
76 | "vocab_size": 32000,
77 | "multiple_of": 256,
78 | "n_heads": 40,
79 | "n_layers": 40,
80 | "norm_eps": 1e-05,
81 | }
82 | elif model_name == "llama-2-70b":
83 | data = {
84 | "dim": 8192,
85 | "vocab_size": 32000,
86 | "multiple_of": 4096,
87 | "ffn_dim_multiplier": 1.3,
88 | "n_heads": 64,
89 | "n_kv_heads": 8,
90 | "n_layers": 80,
91 | "norm_eps": 1e-05,
92 | }
93 | elif model_name == "llama-3-8b":
94 | data = {
95 | "dim": 4096,
96 | "vocab_size": 128256,
97 | "multiple_of": 1024,
98 | "ffn_dim_multiplier": 1.3,
99 | "n_layers": 32,
100 | "n_heads": 32,
101 | "n_kv_heads": 8,
102 | "norm_eps": 1e-05,
103 | "rope_theta": 500000.0,
104 | }
105 | elif model_name == "llama-3-70b":
106 | data = {
107 | "dim": 8192,
108 | "ffn_dim_multiplier": 1.3,
109 | "multiple_of": 4096,
110 | "n_heads": 64,
111 | "n_kv_heads": 8,
112 | "n_layers": 80,
113 | "norm_eps": 1e-05,
114 | "vocab_size": 128256,
115 | "rope_theta": 500000.0,
116 | }
117 | elif model_name == "llama-3.1-8b":
118 | data = {
119 | "dim": 4096,
120 | "vocab_size": 128256,
121 | "multiple_of": 1024,
122 | "ffn_dim_multiplier": 1.3,
123 | "n_layers": 32,
124 | "n_heads": 32,
125 | "n_kv_heads": 8,
126 | "norm_eps": 1e-05,
127 | "rope_theta": 500000.0,
128 | "rope_scaling_args": RopeScalingArgs(
129 | factor=8.0,
130 | low_freq_factor=1.0,
131 | high_freq_factor=4.0,
132 | original_max_position_embeddings=8192,
133 | ),
134 | }
135 | elif model_name == "llama-3.2-1b":
136 | data = {
137 | "dim": 2048,
138 | "vocab_size": 128256,
139 | "multiple_of": 1024,
140 | "ffn_dim_multiplier": 1.5,
141 | "n_layers": 16,
142 | "n_heads": 32,
143 | "n_kv_heads": 8,
144 | "norm_eps": 1e-05,
145 | "rope_theta": 500000.0,
146 | "rope_scaling_args": RopeScalingArgs(
147 | factor=32.0,
148 | low_freq_factor=1.0,
149 | high_freq_factor=4.0,
150 | original_max_position_embeddings=8192,
151 | ),
152 | }
153 | elif model_name == "llama-3.3-70b":
154 | data = {
155 | "dim": 8192,
156 | "vocab_size": 128256,
157 | "multiple_of": 1024,
158 | "ffn_dim_multiplier": 1.3,
159 | "n_layers": 80,
160 | "n_heads": 64,
161 | "n_kv_heads": 8,
162 | "norm_eps": 1e-05,
163 | "rope_theta": 500000.0,
164 | "rope_scaling_args": RopeScalingArgs(
165 | factor=8.0,
166 | low_freq_factor=1.0,
167 | high_freq_factor=4.0,
168 | original_max_position_embeddings=8192,
169 | ),
170 | }
171 |
172 | return ModelArgs(
173 | max_seq_len=seqlen,
174 | max_batch_size=batch_size,
175 | bf16_enable=bf16_enable,
176 | **data,
177 | )
178 |
179 |
180 | def get_model_args(model_name, context_length, batch_size, bf16_enable):
181 | model_args = get_arg(
182 | model_name=model_name,
183 | seqlen=context_length,
184 | batch_size=batch_size,
185 | bf16_enable=bf16_enable,
186 | )
187 | model_args.n_kv_heads = (
188 | model_args.n_heads
189 | if model_args.n_kv_heads is None
190 | else model_args.n_kv_heads
191 | )
192 | model_args.head_dim = model_args.dim // model_args.n_heads
193 | return model_args
194 |
--------------------------------------------------------------------------------
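
For reference, a minimal usage sketch of the helpers above (a hedged example, not part of the repository; it assumes `jetstream_pt.third_party.llama.model_args` is importable from the repo root as laid out in the tree):

```
# Hedged sketch: exercise get_model_args() from the file above.
from jetstream_pt.third_party.llama.model_args import get_model_args

# llama-2-7b config with a 1024-token context and batch size 8.
args = get_model_args(
    "llama-2-7b", context_length=1024, batch_size=8, bf16_enable=True
)

print(args.dim, args.n_heads)  # 4096 32
print(args.head_dim)           # dim // n_heads == 128
print(args.n_kv_heads)         # falls back to n_heads (32) when unset
```
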
/jetstream_pt/third_party/llama/tokenizer.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/jetstream-pytorch/f4d775358c3eadcf34d4ab0ba244cb1167f7a628/jetstream_pt/third_party/llama/tokenizer.model
--------------------------------------------------------------------------------
/jetstream_pt/third_party/llama/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3 |
4 | import os
5 | from typing import List
6 |
7 | from sentencepiece import SentencePieceProcessor
8 |
9 |
10 | """Only use decode to do accuacy varification"""
11 |
12 |
13 | class Tokenizer:
14 | """tokenizing and encoding/decoding text using SentencePiece."""
15 |
16 | def __init__(self, model_path: str):
17 | """
18 | Initializes the Tokenizer with a SentencePiece model.
19 |
20 | Args:
21 | model_path (str): The path to the SentencePiece model file.
22 | """
23 | # reload tokenizer
24 | print(f"model_path: {model_path}")
25 | assert os.path.isfile(model_path), model_path
26 | self.sp_model = SentencePieceProcessor(model_file=model_path)
27 |
28 | # BOS / EOS token IDs
29 | self.n_words: int = self.sp_model.vocab_size()
30 | self.bos_id: int = self.sp_model.bos_id()
31 | self.eos_id: int = self.sp_model.eos_id()
32 | self.pad_id: int = self.sp_model.pad_id()
33 |
34 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
35 |
36 | def decode(self, t: List[int]) -> str:
37 | """
38 | Decodes a list of token IDs into a string.
39 |
40 | Args:
41 | t (List[int]): The list of token IDs to be decoded.
42 |
43 | Returns:
44 | str: The decoded string.
45 | """
46 | return self.sp_model.decode(t)
47 |
--------------------------------------------------------------------------------
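
A hedged decode-only usage sketch for the tokenizer above (assumes `sentencepiece` is installed and that the bundled `tokenizer.model` referenced in this directory is present on disk):

```
# Hedged sketch: decode-only use of the SentencePiece Tokenizer above.
from jetstream_pt.third_party.llama.tokenizer import Tokenizer

tok = Tokenizer(model_path="jetstream_pt/third_party/llama/tokenizer.model")
print(tok.n_words, tok.bos_id, tok.eos_id)  # vocab size and special token ids
print(tok.decode([1, 306, 4658]))           # arbitrary example token ids -> text
```
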
/jetstream_pt/third_party/mixtral/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/jetstream-pytorch/f4d775358c3eadcf34d4ab0ba244cb1167f7a628/jetstream_pt/third_party/mixtral/__init__.py
--------------------------------------------------------------------------------
/jetstream_pt/third_party/mixtral/config.py:
--------------------------------------------------------------------------------
1 | # pylint: disable-all
2 | # Copyright 2024 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Mixtral model config
17 | import dataclasses
18 | from dataclasses import dataclass
19 |
20 |
21 | def find_multiple(n: int, k: int) -> int:
22 | if n % k == 0:
23 | return n
24 | return n + k - (n % k)
25 |
26 |
27 | @dataclass
28 | class ModelArgs:
29 | block_size: int = 2048
30 | vocab_size: int = 32000
31 | n_layer: int = 32
32 | n_head: int = 32
33 | dim: int = 4096
34 | intermediate_size: int = None
35 | n_local_heads: int = -1
36 | head_dim: int = 64
37 | rope_base: float = 10000
38 | norm_eps: float = 1e-5
39 | num_experts: int = 8
40 | num_activated_experts: int = 2
41 | device: str = "meta"
42 |
43 | def __post_init__(self):
44 | if self.n_local_heads == -1:
45 | self.n_local_heads = self.n_head
46 | if self.intermediate_size is None:
47 | hidden_dim = 4 * self.dim
48 | n_hidden = int(2 * hidden_dim / 3)
49 | self.intermediate_size = find_multiple(n_hidden, 256)
50 | self.head_dim = self.dim // self.n_head
51 |
52 | @classmethod
53 | def from_name(cls, name: str):
54 | if name in transformer_configs:
55 | return cls(**transformer_configs[name])
56 | # fuzzy search
57 | config = [
58 | config
59 | for config in transformer_configs
60 | if config in str(name).upper() or config in str(name)
61 | ]
62 | assert len(config) == 1, name
63 | return cls(**transformer_configs[config[0]])
64 |
65 |
66 | transformer_configs = {
67 | "Mixtral-8x7B-v0.1": dict(
68 | block_size=32768,
69 | n_layer=32,
70 | n_head=32,
71 | n_local_heads=8,
72 | dim=4096,
73 | intermediate_size=14336,
74 | rope_base=1000000.0,
75 | num_experts=8,
76 | num_activated_experts=2,
77 | ),
78 | "Mixtral-tiny": dict(
79 | block_size=128,
80 | n_layer=3,
81 | n_head=32,
82 | n_local_heads=8,
83 | dim=128,
84 | intermediate_size=None,
85 | rope_base=1000000.0,
86 | num_experts=8,
87 | num_activated_experts=2,
88 | ),
89 | }
90 |
--------------------------------------------------------------------------------
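
A short hedged sketch of how the config above resolves named and derived fields (pure Python, no accelerator needed; the module path is assumed from the tree):

```
# Hedged sketch: resolve Mixtral configs by name.
from jetstream_pt.third_party.mixtral.config import ModelArgs

args = ModelArgs.from_name("Mixtral-8x7B-v0.1")
print(args.dim, args.n_local_heads, args.head_dim)  # 4096 8 128 (head_dim = dim // n_head)

# "Mixtral-tiny" leaves intermediate_size unset, so __post_init__ derives it:
tiny = ModelArgs.from_name("Mixtral-tiny")
print(tiny.intermediate_size)  # find_multiple(int(2 * 4 * 128 / 3), 256) == 512
```
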
/jetstream_pt/third_party/mixtral/tokenizer.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Hypercomputer/jetstream-pytorch/f4d775358c3eadcf34d4ab0ba244cb1167f7a628/jetstream_pt/third_party/mixtral/tokenizer.model
--------------------------------------------------------------------------------
/jetstream_pt/torchjax.py:
--------------------------------------------------------------------------------
1 | """This file will serve as proxy APIs for torch_xla2 API.
2 |
3 | It serves 2 purposes:
4 |
5 | 1. torch_xla2 APIs are not
6 | stable yet, and changes of it means lots of code edits throughout
7 | this repo. So future changes of torch_xla2 API we only need to edit
8 | this one file.
9 |
10 | 2. We can iterate API look and feel in this file and the influence
11 | how it looks like in torch_xla2.
12 | """
13 |
14 | import torch
15 | from torch.utils import _pytree as pytree
16 |
17 | import torch_xla2
18 | import torch_xla2.interop
19 |
20 | call_jax = torch_xla2.interop.call_jax
21 | call_torch = torch_xla2.interop.call_torch
22 |
23 |
24 | def to_torch(tensors):
25 | """Wrap a jax Array into XLATensor."""
26 | return torch_xla2.default_env().j2t_iso(tensors)
27 |
28 |
29 | def from_torch_with_copy(tensors):
30 | """Convert torch tensor to Jax Array."""
31 |
32 | def convert_tensor(t):
33 | if isinstance(t, torch_xla2.tensor.XLATensor2):
34 | return t.jax()
35 | return torch_xla2.tensor.t2j(t)
36 |
37 | return pytree.tree_map_only(torch.Tensor, convert_tensor, tensors)
38 |
39 |
40 | def from_torch(tensors):
41 | """Unwrap a XLATensor into jax Array.
42 |
43 | Will raise if passed in a torch.Tensor that is not XLATensor
44 | """
45 | return torch_xla2.default_env().t2j_iso(tensors)
46 |
--------------------------------------------------------------------------------
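
A hedged round-trip sketch of the proxy API above (assumes `jax`, `torch`, and `torch_xla2` are installed, e.g. via `install_everything.sh`):

```
# Hedged sketch: move arrays across the torch <-> jax boundary via torchjax.
import jax.numpy as jnp
from jetstream_pt import torchjax

x = jnp.ones((2, 3))              # a jax Array
xt = torchjax.to_torch(x)         # wrapped as an XLATensor-backed torch tensor
x_back = torchjax.from_torch(xt)  # unwrapped back to a jax Array
print(x_back.shape)               # (2, 3)
```
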
/kuberay/image/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM rayproject/ray:2.32.0-py310
2 |
3 | RUN pip install flax==0.8.3
4 | RUN pip install jax[tpu]==0.4.30 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
5 | RUN pip install tensorflow-text
6 | RUN pip install tensorflow
7 |
8 | RUN pip install torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu
9 | RUN pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage
10 | RUN pip install safetensors colorama coverage humanize
11 |
12 | RUN git clone https://github.com/google/jetstream-pytorch
13 | WORKDIR jetstream-pytorch
14 |
15 | RUN git submodule update --init --recursive
16 | RUN pip install -e .
17 |
--------------------------------------------------------------------------------
/kuberay/manifests/ray-cluster.tpu-v4-multihost.yaml:
--------------------------------------------------------------------------------
1 | # This template contains a Kuberay cluster using a 2x2x2 TPU v4 PodSlice.
2 | # To get access to TPU resources, please follow instructions in this link:
3 | # https://cloud.google.com/kubernetes-engine/docs/how-to/tpus
4 | apiVersion: ray.io/v1
5 | kind: RayCluster
6 | metadata:
7 | name: example-cluster-kuberay
8 | spec:
9 | headGroupSpec:
10 | rayStartParams:
11 | {}
12 | template:
13 | spec:
14 | imagePullSecrets:
15 | []
16 | serviceAccountName: ray-ksa
17 | containers:
18 | - volumeMounts:
19 | - name: gcs-fuse-checkpoint
20 | mountPath: /llama
21 | readOnly: true
22 | - mountPath: /tmp/ray
23 | name: ray-logs
24 | name: ray-head
25 | image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240729
26 | imagePullPolicy: IfNotPresent
27 | resources:
28 | limits:
29 | cpu: "4"
30 | ephemeral-storage: 30Gi
31 | memory: 40G
32 | requests:
33 | cpu: "4"
34 | ephemeral-storage: 30Gi
35 | memory: 40G
36 | securityContext:
37 | {}
38 | env:
39 | - name: JAX_PLATFORMS
40 | value: "cpu"
41 | - name: RAY_memory_monitor_refresh_ms
42 | value: "0"
43 | - name: RAY_GRAFANA_IFRAME_HOST
44 | value: http://${grafana_host}
45 | - name: RAY_GRAFANA_HOST
46 | value: http://grafana:80
47 | - name: RAY_PROMETHEUS_HOST
48 | value: http://frontend:9090
49 | ports:
50 | - containerPort: 6379
51 | name: gcs
52 | - containerPort: 8265
53 | name: dashboard
54 | - containerPort: 10001
55 | name: client
56 | - containerPort: 8000
57 | name: serve
58 | - containerPort: 8471
59 | name: slicebuilder
60 | - containerPort: 8081
61 | name: mxla
62 | - containerPort: 8888
63 | name: grpc
64 | volumes:
65 | - emptyDir: {}
66 | name: ray-logs
67 | - name: gcs-fuse-checkpoint
68 | csi:
69 | driver: gcsfuse.csi.storage.gke.io
70 | readOnly: true
71 | volumeAttributes:
72 | bucketName: ricliu-llama2-70b-chat
73 | mountOptions: "implicit-dirs"
74 | metadata:
75 | annotations:
76 | gke-gcsfuse/volumes: "true"
77 | labels:
78 | cloud.google.com/gke-ray-node-type: head
79 | app.kubernetes.io/name: kuberay
80 | app.kubernetes.io/instance: example-cluster
81 |
82 | workerGroupSpecs:
83 | - rayStartParams:
84 | {}
85 | replicas: 1
86 | minReplicas: 1
87 | maxReplicas: 1
88 | numOfHosts: 2
89 | groupName: workergroup
90 | template:
91 | spec:
92 | imagePullSecrets:
93 | []
94 | serviceAccountName: ray-ksa
95 | containers:
96 | - volumeMounts:
97 | - mountPath: /tmp/ray
98 | name: ray-logs
99 | - name: gcs-fuse-checkpoint
100 | mountPath: /llama
101 | readOnly: true
102 | name: ray-worker
103 | image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240729
104 | imagePullPolicy: IfNotPresent
105 | resources:
106 | limits:
107 | cpu: "8"
108 | ephemeral-storage: 30Gi
109 | google.com/tpu: "4"
110 | memory: 200G
111 | requests:
112 | cpu: "8"
113 | ephemeral-storage: 30Gi
114 | google.com/tpu: "4"
115 | memory: 200G
116 | securityContext:
117 | {}
118 | env:
119 | - name: JAX_PLATFORMS
120 | value: "cpu"
121 | ports:
122 | null
123 | volumes:
124 | - emptyDir: {}
125 | name: ray-logs
126 | - name: gcs-fuse-checkpoint
127 | csi:
128 | driver: gcsfuse.csi.storage.gke.io
129 | readOnly: true
130 | volumeAttributes:
131 | bucketName: ricliu-llama2-70b-chat
132 | mountOptions: "implicit-dirs"
133 | nodeSelector:
134 | cloud.google.com/gke-tpu-accelerator: tpu-v4-podslice
135 | cloud.google.com/gke-tpu-topology: 2x2x2
136 | iam.gke.io/gke-metadata-server-enabled: "true"
137 | metadata:
138 | annotations:
139 | gke-gcsfuse/volumes: "true"
140 | labels:
141 | cloud.google.com/gke-ray-node-type: worker
142 | app.kubernetes.io/name: kuberay
143 | app.kubernetes.io/instance: example-cluster
144 |
145 |
--------------------------------------------------------------------------------
/kuberay/manifests/ray-cluster.tpu-v4-singlehost.yaml:
--------------------------------------------------------------------------------
1 | # This template contains a Kuberay cluster using a 2x2x1 TPU v4 PodSlice.
2 | # To get access to TPU resources, please follow instructions in this link:
3 | # https://cloud.google.com/kubernetes-engine/docs/how-to/tpus
4 | apiVersion: ray.io/v1
5 | kind: RayCluster
6 | metadata:
7 | name: example-cluster-kuberay
8 | spec:
9 | headGroupSpec:
10 | rayStartParams:
11 | {}
12 | template:
13 | spec:
14 | imagePullSecrets:
15 | []
16 | serviceAccountName: ray-ksa
17 | containers:
18 | - volumeMounts:
19 | - name: gcs-fuse-checkpoint
20 | mountPath: /llama
21 | readOnly: true
22 | - mountPath: /tmp/ray
23 | name: ray-logs
24 | name: ray-head
25 | image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240729
26 | imagePullPolicy: IfNotPresent
27 | resources:
28 | limits:
29 | cpu: "4"
30 | ephemeral-storage: 30Gi
31 | memory: 40G
32 | requests:
33 | cpu: "4"
34 | ephemeral-storage: 30Gi
35 | memory: 40G
36 | securityContext:
37 | {}
38 | env:
39 | - name: JAX_PLATFORMS
40 | value: "cpu"
41 | - name: RAY_memory_monitor_refresh_ms
42 | value: "0"
43 | - name: RAY_GRAFANA_IFRAME_HOST
44 | value: http://${grafana_host}
45 | - name: RAY_GRAFANA_HOST
46 | value: http://grafana:80
47 | - name: RAY_PROMETHEUS_HOST
48 | value: http://frontend:9090
49 | ports:
50 | - containerPort: 6379
51 | name: gcs
52 | - containerPort: 8265
53 | name: dashboard
54 | - containerPort: 10001
55 | name: client
56 | - containerPort: 8000
57 | name: serve
58 | - containerPort: 8888
59 | name: grpc
60 | volumes:
61 | - emptyDir: {}
62 | name: ray-logs
63 | - name: gcs-fuse-checkpoint
64 | csi:
65 | driver: gcsfuse.csi.storage.gke.io
66 | readOnly: true
67 | volumeAttributes:
68 | bucketName: ricliu-llama2
69 | mountOptions: "implicit-dirs"
70 | metadata:
71 | annotations:
72 | gke-gcsfuse/volumes: "true"
73 | labels:
74 | cloud.google.com/gke-ray-node-type: head
75 | app.kubernetes.io/name: kuberay
76 | app.kubernetes.io/instance: example-cluster
77 |
78 | workerGroupSpecs:
79 | - rayStartParams:
80 | {}
81 | replicas: 1
82 | minReplicas: 1
83 | maxReplicas: 1
84 | numOfHosts: 1
85 | groupName: workergroup
86 | template:
87 | spec:
88 | imagePullSecrets:
89 | []
90 | serviceAccountName: ray-ksa
91 | containers:
92 | - volumeMounts:
93 | - mountPath: /tmp/ray
94 | name: ray-logs
95 | - name: gcs-fuse-checkpoint
96 | mountPath: /llama
97 | readOnly: true
98 | name: ray-worker
99 | image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240729
100 | imagePullPolicy: IfNotPresent
101 | resources:
102 | limits:
103 | cpu: "8"
104 | ephemeral-storage: 30Gi
105 | google.com/tpu: "4"
106 | memory: 200G
107 | requests:
108 | cpu: "8"
109 | ephemeral-storage: 30Gi
110 | google.com/tpu: "4"
111 | memory: 200G
112 | securityContext:
113 | {}
114 | env:
115 | - name: JAX_PLATFORMS
116 | value: "cpu"
117 | ports:
118 | null
119 | volumes:
120 | - emptyDir: {}
121 | name: ray-logs
122 | - name: gcs-fuse-checkpoint
123 | csi:
124 | driver: gcsfuse.csi.storage.gke.io
125 | readOnly: true
126 | volumeAttributes:
127 | bucketName: ricliu-llama2
128 | mountOptions: "implicit-dirs"
129 | nodeSelector:
130 | cloud.google.com/gke-tpu-accelerator: tpu-v4-podslice
131 | cloud.google.com/gke-tpu-topology: 2x2x1
132 | iam.gke.io/gke-metadata-server-enabled: "true"
133 | metadata:
134 | annotations:
135 | gke-gcsfuse/volumes: "true"
136 | labels:
137 | cloud.google.com/gke-ray-node-type: worker
138 | app.kubernetes.io/name: kuberay
139 | app.kubernetes.io/instance: example-cluster
140 |
141 |
--------------------------------------------------------------------------------
/kuberay/manifests/ray-cluster.tpu-v5-multihost.yaml:
--------------------------------------------------------------------------------
1 | # This template contains a Kuberay cluster using a 2x4 multi-host TPU v5e (v5-lite) PodSlice.
2 | # To get access to TPU resources, please follow instructions in this link:
3 | # https://cloud.google.com/kubernetes-engine/docs/how-to/tpus
4 | apiVersion: ray.io/v1
5 | kind: RayCluster
6 | metadata:
7 | name: example-cluster-kuberay
8 | spec:
9 | headGroupSpec:
10 | rayStartParams:
11 | {}
12 | template:
13 | spec:
14 | imagePullSecrets:
15 | []
16 | serviceAccountName: ray-ksa
17 | containers:
18 | - volumeMounts:
19 | - name: gcs-fuse-checkpoint
20 | mountPath: /llama
21 | readOnly: true
22 | - mountPath: /tmp/ray
23 | name: ray-logs
24 | name: ray-head
25 | image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240729
26 | imagePullPolicy: IfNotPresent
27 | resources:
28 | limits:
29 | cpu: "4"
30 | ephemeral-storage: 30Gi
31 | memory: 40G
32 | requests:
33 | cpu: "4"
34 | ephemeral-storage: 30Gi
35 | memory: 40G
36 | securityContext:
37 | {}
38 | env:
39 | - name: JAX_PLATFORMS
40 | value: "cpu"
41 | - name: RAY_memory_monitor_refresh_ms
42 | value: "0"
43 | - name: RAY_GRAFANA_IFRAME_HOST
44 | value: http://${grafana_host}
45 | - name: RAY_GRAFANA_HOST
46 | value: http://grafana:80
47 | - name: RAY_PROMETHEUS_HOST
48 | value: http://frontend:9090
49 | ports:
50 | - containerPort: 6379
51 | name: gcs
52 | - containerPort: 8265
53 | name: dashboard
54 | - containerPort: 10001
55 | name: client
56 | - containerPort: 8000
57 | name: serve
58 | - containerPort: 8471
59 | name: slicebuilder
60 | - containerPort: 8081
61 | name: mxla
62 | - containerPort: 8888
63 | name: grpc
64 | volumes:
65 | - emptyDir: {}
66 | name: ray-logs
67 | - name: gcs-fuse-checkpoint
68 | csi:
69 | driver: gcsfuse.csi.storage.gke.io
70 | readOnly: true
71 | volumeAttributes:
72 | bucketName: ricliu-llama2-70b-chat
73 | mountOptions: "implicit-dirs"
74 | metadata:
75 | annotations:
76 | gke-gcsfuse/volumes: "true"
77 | labels:
78 | cloud.google.com/gke-ray-node-type: head
79 | app.kubernetes.io/name: kuberay
80 | app.kubernetes.io/instance: example-cluster
81 |
82 | workerGroupSpecs:
83 | - rayStartParams:
84 | {}
85 | replicas: 1
86 | minReplicas: 1
87 | maxReplicas: 1
88 | numOfHosts: 2
89 | groupName: workergroup
90 | template:
91 | spec:
92 | imagePullSecrets:
93 | []
94 | serviceAccountName: ray-ksa
95 | containers:
96 | - volumeMounts:
97 | - mountPath: /tmp/ray
98 | name: ray-logs
99 | - name: gcs-fuse-checkpoint
100 | mountPath: /llama
101 | readOnly: true
102 | name: ray-worker
103 | image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240729
104 | imagePullPolicy: IfNotPresent
105 | resources:
106 | limits:
107 | cpu: "8"
108 | ephemeral-storage: 30Gi
109 | google.com/tpu: "4"
110 | memory: 180G
111 | requests:
112 | cpu: "8"
113 | ephemeral-storage: 30Gi
114 | google.com/tpu: "4"
115 | memory: 180G
116 | securityContext:
117 | {}
118 | env:
119 | - name: JAX_PLATFORMS
120 | value: "cpu"
121 | ports:
122 | null
123 | volumes:
124 | - emptyDir: {}
125 | name: ray-logs
126 | - name: gcs-fuse-checkpoint
127 | csi:
128 | driver: gcsfuse.csi.storage.gke.io
129 | readOnly: true
130 | volumeAttributes:
131 | bucketName: ricliu-llama2-70b-chat
132 | mountOptions: "implicit-dirs"
133 | nodeSelector:
134 | iam.gke.io/gke-metadata-server-enabled: "true"
135 | cloud.google.com/gke-tpu-topology: 2x4
136 | cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
137 | metadata:
138 | annotations:
139 | gke-gcsfuse/volumes: "true"
140 | labels:
141 | cloud.google.com/gke-ray-node-type: worker
142 | app.kubernetes.io/name: kuberay
143 | app.kubernetes.io/instance: example-cluster
144 |
145 |
--------------------------------------------------------------------------------
/kuberay/manifests/ray-cluster.tpu-v5-singlehost.yaml:
--------------------------------------------------------------------------------
1 | # This template contains a Kuberay cluster using a 2x4 single-host TPU v5e (v5-lite) PodSlice.
2 | # To get access to TPU resources, please follow instructions in this link:
3 | # https://cloud.google.com/kubernetes-engine/docs/how-to/tpus
4 | apiVersion: ray.io/v1
5 | kind: RayCluster
6 | metadata:
7 | name: example-cluster-kuberay
8 | spec:
9 | headGroupSpec:
10 | rayStartParams:
11 | {}
12 | template:
13 | spec:
14 | imagePullSecrets:
15 | []
16 | serviceAccountName: ray-ksa
17 | containers:
18 | - volumeMounts:
19 | - name: gcs-fuse-checkpoint
20 | mountPath: /llama
21 | readOnly: true
22 | - mountPath: /tmp/ray
23 | name: ray-logs
24 | name: ray-head
25 | image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240729
26 | imagePullPolicy: IfNotPresent
27 | resources:
28 | limits:
29 | cpu: "4"
30 | ephemeral-storage: 30Gi
31 | memory: 40G
32 | requests:
33 | cpu: "4"
34 | ephemeral-storage: 30Gi
35 | memory: 40G
36 | securityContext:
37 | {}
38 | env:
39 | - name: JAX_PLATFORMS
40 | value: "cpu"
41 | - name: RAY_memory_monitor_refresh_ms
42 | value: "0"
43 | - name: RAY_GRAFANA_IFRAME_HOST
44 | value: http://${grafana_host}
45 | - name: RAY_GRAFANA_HOST
46 | value: http://grafana:80
47 | - name: RAY_PROMETHEUS_HOST
48 | value: http://frontend:9090
49 | ports:
50 | - containerPort: 6379
51 | name: gcs
52 | - containerPort: 8265
53 | name: dashboard
54 | - containerPort: 10001
55 | name: client
56 | - containerPort: 8000
57 | name: serve
58 | - containerPort: 8888
59 | name: grpc
60 | volumes:
61 | - emptyDir: {}
62 | name: ray-logs
63 | - name: gcs-fuse-checkpoint
64 | csi:
65 | driver: gcsfuse.csi.storage.gke.io
66 | readOnly: true
67 | volumeAttributes:
68 | bucketName: ricliu-llama2
69 | mountOptions: "implicit-dirs"
70 | metadata:
71 | annotations:
72 | gke-gcsfuse/volumes: "true"
73 | labels:
74 | cloud.google.com/gke-ray-node-type: head
75 | app.kubernetes.io/name: kuberay
76 | app.kubernetes.io/instance: example-cluster
77 |
78 | workerGroupSpecs:
79 | - rayStartParams:
80 | {}
81 | replicas: 1
82 | minReplicas: 1
83 | maxReplicas: 1
84 | numOfHosts: 1
85 | groupName: workergroup
86 | template:
87 | spec:
88 | imagePullSecrets:
89 | []
90 | serviceAccountName: ray-ksa
91 | containers:
92 | - volumeMounts:
93 | - mountPath: /tmp/ray
94 | name: ray-logs
95 | - name: gcs-fuse-checkpoint
96 | mountPath: /llama
97 | readOnly: true
98 | name: ray-worker
99 | image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240729
100 | imagePullPolicy: IfNotPresent
101 | resources:
102 | limits:
103 | cpu: "8"
104 | ephemeral-storage: 30Gi
105 | google.com/tpu: "8"
106 | memory: 200G
107 | requests:
108 | cpu: "8"
109 | ephemeral-storage: 30Gi
110 | google.com/tpu: "8"
111 | memory: 200G
112 | securityContext:
113 | {}
114 | env:
115 | - name: JAX_PLATFORMS
116 | value: "cpu"
117 | ports:
118 | null
119 | volumes:
120 | - emptyDir: {}
121 | name: ray-logs
122 | - name: gcs-fuse-checkpoint
123 | csi:
124 | driver: gcsfuse.csi.storage.gke.io
125 | readOnly: true
126 | volumeAttributes:
127 | bucketName: ricliu-llama2
128 | mountOptions: "implicit-dirs"
129 | nodeSelector:
130 | cloud.google.com/gke-tpu-topology: 2x4
131 | cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
132 | iam.gke.io/gke-metadata-server-enabled: "true"
133 | metadata:
134 | annotations:
135 | gke-gcsfuse/volumes: "true"
136 | labels:
137 | cloud.google.com/gke-ray-node-type: worker
138 | app.kubernetes.io/name: kuberay
139 | app.kubernetes.io/instance: example-cluster
140 |
141 |
--------------------------------------------------------------------------------
/mlperf/README.md:
--------------------------------------------------------------------------------
1 | # Run MLPerf tests
2 |
3 | NOTE: currently this has only been tested with Mixtral,
4 | and only with the offline benchmark.
5 |
6 | # How to run
7 |
8 | ### 1. Install
9 |
10 | ```
11 | ./install.sh
12 | ```
13 |
14 | ### 2. Start server
15 |
16 | ```
17 | ./start_server.sh
18 | ```
19 |
20 | ### 3. Warm up the server
21 |
22 | ```
23 | python warmup.py
24 | ```
25 |
26 | ### 4. Run the benchmark (currently offline mode only)
27 |
28 | ```
29 | ./benchmark_run.sh
30 | ```
31 |
32 |
--------------------------------------------------------------------------------
/mlperf/benchmark_run.sh:
--------------------------------------------------------------------------------
1 | BASEDIR=mlperf
2 | API_URL=0.0.0.0:9000
3 | USER_CONFIG=$BASEDIR/user.conf
4 | DATA_DISK_DIR=$BASEDIR/data
5 | TOTAL_SAMPLE_COUNT=1000
6 | DATASET_PATH=$BASEDIR/data/mixtral_15k_data.pkl
7 |
8 | # HF model id
9 | TOKENIZER_PATH="mistralai/Mixtral-8x7B-Instruct-v0.1"
10 |
11 | LOADGEN_RUN_TYPE=offline-performance
12 | OUTPUT_LOG_ID=${MODEL_NAME}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TIMESTAMP}
13 | OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID}
14 |
15 | mkdir -p ${OUTPUT_LOG_DIR} && cp ../${USER_CONFIG} ${OUTPUT_LOG_DIR}
16 |
17 | pushd ..
18 | python -m mlperf.main \
19 | --api-url ${API_URL} \
20 | --scenario Offline \
21 | --input-mode tokenized \
22 | --output-mode tokenized \
23 | --log-pred-outputs \
24 | --mlperf-conf $BASEDIR/mlperf.conf \
25 | --user-conf ${USER_CONFIG} \
26 | --audit-conf no-audit \
27 | --total-sample-count ${TOTAL_SAMPLE_COUNT} \
28 | --dataset-path ${DATASET_PATH} \
29 | --tokenizer-path ${TOKENIZER_PATH} \
30 | --log-interval 1000 \
31 | --output-log-dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_accuracy_log.log
32 | popd
--------------------------------------------------------------------------------
/mlperf/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 | import os
17 |
18 | import pandas as pd
19 |
20 | logging.basicConfig(level=logging.INFO)
21 | log = logging.getLogger("dataset.py")
22 |
23 |
24 | class Dataset:
25 |
26 | def __init__(
27 | self,
28 | dataset_path: str,
29 | input_mode: str,
30 | total_sample_count: int = 15000,
31 | perf_count_override: int = None,
32 | ):
33 | if not os.path.isfile(dataset_path):
34 |       log.warning(
35 | "Processed pickle file {} not found. Please check that the path is correct".format(
36 | dataset_path
37 | )
38 | )
39 | self.dataset_path = dataset_path
40 |
41 | self._input_mode = validate_sample_mode(input_mode)
42 | self.load_processed_dataset()
43 |
44 | self.total_sample_count = min(len(self.input_ids_strs), total_sample_count)
45 | self.perf_count = perf_count_override or self.total_sample_count
46 |
47 | @property
48 | def input_ids_strs(self):
49 | return self._input_ids_strs
50 |
51 | @property
52 | def input_texts(self):
53 | return self._input_texts
54 |
55 | @property
56 | def input_token_lengths(self):
57 | return self._input_token_lengths
58 |
59 | @property
60 | def inputs(self):
61 | return self._inputs
62 |
63 | @property
64 | def inputs_with_token_lengths(self):
65 | return self._inputs_with_token_lengths
66 |
67 | @property
68 | def input_datasets(self):
69 | return self._input_datasets
70 |
71 | def load_processed_dataset(self):
72 | processed_data = pd.read_pickle(self.dataset_path)
73 | # processed_data = processed_data[processed_data["dataset"] == "MBXP"]
74 | # processed_data = processed_data.reset_index(drop=True)
75 |
76 | self._input_ids_strs = []
77 | for input_ids in processed_data["tok_input"]:
78 | input_ids_str = ",".join([str(input_id) for input_id in input_ids])
79 | self._input_ids_strs.append(input_ids_str)
80 |
81 | self._input_texts = []
82 | for input_text in processed_data["input"]:
83 | self._input_texts.append(input_text)
84 |
85 | self._input_token_lengths = []
86 | for token_length in processed_data["tok_input_len"]:
87 | self._input_token_lengths.append(token_length)
88 |
89 | log.info(f"input_mode is {self._input_mode}")
90 | self._inputs = (
91 | self._input_ids_strs
92 | if self._input_mode == "tokenized"
93 | else self._input_texts
94 | )
95 | log.info(f"example sample input is {self._inputs[0]}")
96 | self._inputs_with_token_lengths = [
97 | (input_ids_str_or_input_text, token_length)
98 | for input_ids_str_or_input_text, token_length in zip(
99 | self._inputs, self._input_token_lengths
100 | )
101 | ]
102 |
103 | self._input_datasets = []
104 | for dataset in processed_data["dataset"]:
105 | self._input_datasets.append(dataset)
106 | log.info(
107 | f"example sample input dataset is {self._input_datasets[0]} and total {len(self._input_datasets)}"
108 | )
109 |
110 | def LoadSamplesToRam(self, sample_list):
111 | pass
112 |
113 | def UnloadSamplesFromRam(self, sample_list):
114 | pass
115 |
116 | def __del__(self):
117 | pass
118 |
119 |
120 | SAMPLE_MODE_CHOICES = ["tokenized", "text"]
121 |
122 |
123 | def validate_sample_mode(sample_mode: str) -> str:
124 | if sample_mode not in SAMPLE_MODE_CHOICES:
125 | raise ValueError(
126 | "The sample_mode should be set to either `tokenized` or `text`."
127 | )
128 | return sample_mode
129 |
--------------------------------------------------------------------------------
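
A hedged sketch of how the `Dataset` class above is typically instantiated for the Mixtral run (assumes the pickle downloaded by `mlperf/install.sh`, with the path matching `DATASET_PATH` in `benchmark_run.sh`, run from the repo root):

```
# Hedged sketch: load the processed MLPerf Mixtral dataset.
from mlperf.dataset import Dataset

ds = Dataset(
    dataset_path="mlperf/data/mixtral_15k_data.pkl",
    input_mode="tokenized",  # "tokenized" or "text"; anything else raises ValueError
    total_sample_count=1000,
)
print(ds.total_sample_count, ds.perf_count)
print(ds.inputs_with_token_lengths[0])  # (comma-joined token ids, token length)
```
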
/mlperf/install.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | DATA_DISK_DIR=data
4 |
5 | mkdir -p $DATA_DISK_DIR
6 |
7 | pip install -U "huggingface_hub[cli]"
8 | pip install \
9 | transformers \
10 | nltk==3.8.1 \
11 | evaluate==0.4.0 \
12 | absl-py==1.4.0 \
13 | rouge-score==0.1.2 \
14 | sentencepiece==0.1.99 \
15 | accelerate==0.21.0
16 |
17 | # install loadgen
18 | pip install mlperf-loadgen
19 |
20 |
21 | pushd $DATA_DISK_DIR
22 |
23 | # model weights
24 | gcloud storage cp gs://sixiang_gcp/mixtral-instruct-quantized ./ --recursive
25 | # NOTE: uncomment only the model you need so you don't download unnecessary weights
26 | # gcloud storage cp gs://sixiang_gcp/llama2-70b/llama2-70b/ ./ --recursive
27 |
28 | # Get mixtral data
29 | wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl
30 | mv mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl mixtral_15k_data.pkl
31 | wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl
32 | mv mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl mixtral_15k_calibration_data.pkl
33 |
34 | # Get llama70b data
35 | gcloud storage cp \
36 | gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl \
37 | processed-calibration-data.pkl
38 | gcloud storage cp \
39 | gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl \
40 | processed-data.pkl
41 | popd
42 |
--------------------------------------------------------------------------------
/mlperf/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 |
17 | import gc
18 | import logging
19 | import os
20 | import sys
21 |
22 | from . import backend
23 |
24 | import mlperf_loadgen as lg
25 |
26 | _MLPERF_ID = "mixtral-8x7b"
27 |
28 | sys.path.insert(0, os.getcwd())
29 |
30 | logging.basicConfig(level=logging.INFO)
31 | log = logging.getLogger("main.py")
32 |
33 |
34 | def get_args():
35 | parser = argparse.ArgumentParser()
36 | parser.add_argument(
37 | "--scenario",
38 | type=str,
39 | choices=["Offline", "Server"],
40 | default="Offline",
41 | help="Scenario",
42 | )
43 | parser.add_argument(
44 | "--api-url", type=str, default=None, help="SAX published model path."
45 | )
46 | parser.add_argument("--dataset-path", type=str, default=None, help="")
47 | parser.add_argument("--tokenizer-path", type=str, default=None, help="")
48 | parser.add_argument(
49 | "--accuracy", action="store_true", help="Run accuracy mode"
50 | )
51 | parser.add_argument("--is-stream", action="store_true", help="")
52 | parser.add_argument(
53 | "--input-mode",
54 | type=str,
55 | choices=["text", "tokenized"],
56 | default="tokenized",
57 | )
58 | parser.add_argument(
59 | "--output-mode",
60 | type=str,
61 | choices=["text", "tokenized"],
62 | default="tokenized",
63 | )
64 | parser.add_argument(
65 | "--max-output-len", type=int, default=1024, help="Maximum output len"
66 | )
67 | parser.add_argument(
68 | "--audit-conf",
69 | type=str,
70 | default="audit.conf",
71 | help="audit config for LoadGen settings during compliance runs",
72 | )
73 | parser.add_argument(
74 | "--mlperf-conf",
75 | type=str,
76 | default="mlperf.conf",
77 | help="mlperf rules config",
78 | )
79 | parser.add_argument(
80 | "--user-conf",
81 | type=str,
82 | default="user.conf",
83 | help="user config for user LoadGen settings such as target QPS",
84 | )
85 | parser.add_argument(
86 | "--total-sample-count",
87 | type=int,
88 | default=15000,
89 | help="Number of samples to use in benchmark.",
90 | )
91 | parser.add_argument(
92 | "--perf-count-override",
93 | type=int,
94 | default=None,
95 | help="Overwrite number of samples to use in benchmark.",
96 | )
97 | parser.add_argument(
98 | "--output-log-dir",
99 | type=str,
100 | default="output-logs",
101 | help="Where logs are saved.",
102 | )
103 | parser.add_argument(
104 | "--enable-log-trace",
105 | action="store_true",
106 | help="Enable log tracing. This file can become quite large",
107 | )
108 | parser.add_argument(
109 | "--num-client-threads",
110 | type=int,
111 | default=200,
112 | help="Number of client threads to use",
113 | )
114 | parser.add_argument("--batch-size-exp", type=int, default=6, help="")
115 | parser.add_argument("--log-pred-outputs", action="store_true", help="")
116 | parser.add_argument(
117 | "--log-interval",
118 | type=int,
119 | default=1000,
120 | help="Logging interval in seconds",
121 | )
122 | parser.add_argument(
123 | "--user-conf-override-path",
124 | type=str,
125 | default="",
126 | help="When given overrides the default user.conf path",
127 | )
128 |
129 | args = parser.parse_args()
130 | return args
131 |
132 |
133 | scenario_map = {
134 | "offline": lg.TestScenario.Offline,
135 | "server": lg.TestScenario.Server,
136 | }
137 |
138 |
139 | def main():
140 | args = get_args()
141 |
142 | settings = lg.TestSettings()
143 | settings.scenario = scenario_map[args.scenario.lower()]
144 | if args.user_conf_override_path:
145 | user_conf = args.user_conf_override_path
146 | else:
147 | user_conf = args.user_conf
148 |
149 | settings.FromConfig(args.mlperf_conf, _MLPERF_ID, args.scenario)
150 | settings.FromConfig(user_conf, _MLPERF_ID, args.scenario)
151 | log.info("Mlperf config: %s", args.mlperf_conf)
152 | log.info("User config: %s", user_conf)
153 |
154 | if args.accuracy:
155 | settings.mode = lg.TestMode.AccuracyOnly
156 | log.warning(
157 | "Accuracy run will generate the accuracy logs, but the evaluation of the log is not completed yet"
158 | )
159 | else:
160 | settings.mode = lg.TestMode.PerformanceOnly
161 | settings.print_timestamps = True
162 |
163 | settings.use_token_latencies = True
164 |
165 | os.makedirs(args.output_log_dir, exist_ok=True)
166 | log_output_settings = lg.LogOutputSettings()
167 | log_output_settings.outdir = args.output_log_dir
168 | log_output_settings.copy_summary_to_stdout = True
169 | log_settings = lg.LogSettings()
170 | log_settings.log_output = log_output_settings
171 | log_settings.enable_trace = args.enable_log_trace
172 |
173 | sut = backend.SUT(
174 | scenario=args.scenario.lower(),
175 | api_url=args.api_url,
176 | is_stream=args.is_stream,
177 | input_mode=args.input_mode,
178 | output_mode=args.output_mode,
179 | max_output_len=args.max_output_len,
180 | dataset_path=args.dataset_path,
181 | total_sample_count=args.total_sample_count,
182 | tokenizer_path=args.tokenizer_path,
183 | perf_count_override=args.perf_count_override,
184 | num_client_threads=args.num_client_threads,
185 | log_interval=args.log_interval,
186 | batch_size_exp=args.batch_size_exp,
187 | pred_outputs_log_path=os.path.join(
188 | args.output_log_dir, "pred_outputs_logger.json"
189 | )
190 | if args.log_pred_outputs
191 | else None,
192 | )
193 |
194 | lgSUT = sut.sut # lg.ConstructSUT(sut.issue_queries, sut.flush_queries)
195 | log.info("Starting Benchmark run")
196 | lg.StartTestWithLogSettings(
197 | lgSUT, sut.qsl, settings, log_settings, args.audit_conf
198 | )
199 |
200 | log.info("Run Completed!")
201 |
202 | log.info("Destroying SUT...")
203 | lg.DestroySUT(lgSUT)
204 |
205 | log.info("Destroying QSL...")
206 | lg.DestroyQSL(sut.qsl)
207 |
208 |
209 | if __name__ == "__main__":
210 | # Disable garbage collection to avoid stalls when running tests.
211 | gc.disable()
212 | main()
213 |
--------------------------------------------------------------------------------
/mlperf/mlperf.conf:
--------------------------------------------------------------------------------
1 | # The format of this config file is 'key = value'.
2 | # The key has the format 'model.scenario.key'. Value is mostly int64_t.
3 | # Model may be '*' as a wildcard. In that case the value applies to all models.
4 | # All times are in milliseconds.
5 |
6 | # Set performance_sample_count for each model.
7 | # User can optionally set this to higher values in user.conf.
8 | resnet50.*.performance_sample_count_override = 1024
9 | ssd-mobilenet.*.performance_sample_count_override = 256
10 | retinanet.*.performance_sample_count_override = 64
11 | bert.*.performance_sample_count_override = 10833
12 | dlrm.*.performance_sample_count_override = 204800
13 | dlrm-v2.*.performance_sample_count_override = 204800
14 | rnnt.*.performance_sample_count_override = 2513
15 | gptj.*.performance_sample_count_override = 13368
16 | llama2-70b.*.performance_sample_count_override = 24576
17 | stable-diffusion-xl.*.performance_sample_count_override = 5000
18 | # set to 0 to let the entire sample set be the performance sample
19 | 3d-unet.*.performance_sample_count_override = 0
20 |
21 | # Set seeds. The seeds will be distributed two weeks before the submission.
22 | *.*.qsl_rng_seed = 3066443479025735752
23 | *.*.sample_index_rng_seed = 10688027786191513374
24 | *.*.schedule_rng_seed = 14962580496156340209
25 | # Set seeds for TEST_05. The seeds will be distributed two weeks before the submission.
26 | *.*.test05_qsl_rng_seed = 16799458546791641818
27 | *.*.test05_sample_index_rng_seed = 5453809927556429288
28 | *.*.test05_schedule_rng_seed = 5435552105434836064
29 |
30 |
31 | *.SingleStream.target_latency_percentile = 90
32 | *.SingleStream.min_duration = 600000
33 |
34 | *.MultiStream.target_latency_percentile = 99
35 | *.MultiStream.samples_per_query = 8
36 | *.MultiStream.min_duration = 600000
37 | *.MultiStream.min_query_count = 662
38 | retinanet.MultiStream.target_latency = 528
39 |
40 | # 3D-UNet uses equal issue mode because it has non-uniform inputs
41 | 3d-unet.*.sample_concatenate_permutation = 1
42 |
43 | # LLM benchmarks have non-uniform inputs and outputs, and use equal issue mode for all latency scenario
44 | gptj.*.sample_concatenate_permutation = 1
45 | llama2-70b.*.sample_concatenate_permutation = 1
46 | mixtral-8x7B.*.sample_concatenate_permutation = 1
47 |
48 | *.Server.target_latency = 10
49 | *.Server.target_latency_percentile = 99
50 | *.Server.target_duration = 0
51 | *.Server.min_duration = 600000
52 | resnet50.Server.target_latency = 15
53 | retinanet.Server.target_latency = 100
54 | bert.Server.target_latency = 130
55 | dlrm.Server.target_latency = 60
56 | dlrm-v2.Server.target_latency = 60
57 | rnnt.Server.target_latency = 1000
58 | gptj.Server.target_latency = 20000
59 | stable-diffusion-xl.Server.target_latency = 20000
60 | # Llama2-70b benchmark measures token latencies
61 | llama2-70b.*.use_token_latencies = 1
62 | mixtral-8x7b.*.use_token_latencies = 1
63 | # gptj benchmark infers token latencies
64 | gptj.*.infer_token_latencies = 1
65 | gptj.*.token_latency_scaling_factor = 69
66 | # Only ttft and tpot are tracked for the llama2-70b & mixtral-8x7B benchmark therefore target_latency = 0
67 | llama2-70b.Server.target_latency = 0
68 | llama2-70b.Server.ttft_latency = 2000
69 | llama2-70b.Server.tpot_latency = 200
70 |
71 | mixtral-8x7b.Server.target_latency = 0
72 | mixtral-8x7b.Server.ttft_latency = 2000
73 | mixtral-8x7b.Server.tpot_latency = 200
74 |
75 | *.Offline.target_latency_percentile = 90
76 | *.Offline.min_duration = 600000
77 |
78 | # In Offline scenario, we always have one query. But LoadGen maps this to
79 | # min_sample_count internally in Offline scenario. If the dataset size is larger
80 | # than 24576 we limit the min_query_count to 24576 and otherwise we use
81 | # the dataset size as the limit
82 |
83 | resnet50.Offline.min_query_count = 24576
84 | retinanet.Offline.min_query_count = 24576
85 | dlrm-v2.Offline.min_query_count = 24576
86 | bert.Offline.min_query_count = 10833
87 | gptj.Offline.min_query_count = 13368
88 | rnnt.Offline.min_query_count = 2513
89 | 3d-unet.Offline.min_query_count = 43
90 | stable-diffusion-xl.Offline.min_query_count = 5000
91 | llama2-70b.Offline.min_query_count = 1000
92 | mixtral-8x7b.Offline.min_query_count = 1000
93 |
94 | # These fields should be defined and overridden by user.conf.
95 | *.SingleStream.target_latency = 10
96 | *.MultiStream.target_latency = 80
97 | *.Server.target_qps = 1.0
98 | *.Offline.target_qps = 4.0
99 |
--------------------------------------------------------------------------------
/mlperf/start_server.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CACHE_LENGTH=3072
4 | INPUT_SIZE=512
5 | OUTPUT_SIZE=512
6 | CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/
7 |
8 | pushd ..
9 | python run_server.py \
10 | --model_name=mixtral \
11 | --batch_size=128 \
12 | --max_cache_length=$CACHE_LENGTH \
13 | --max_decode_length=$OUTPUT_SIZE \
14 | --context_length=$INPUT_SIZE \
15 | --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \
16 | --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \
17 | --quantize_weights=1 \
18 | --quantize_type=int8_per_channel \
19 | --quantize_kv_cache=1
20 | popd
--------------------------------------------------------------------------------
/mlperf/user.conf:
--------------------------------------------------------------------------------
1 | mixtral-8x7b.Server.target_qps = 1.8
2 | mixtral-8x7b.Offline.target_qps = 4.0
3 |
4 |
--------------------------------------------------------------------------------
/mlperf/warmup.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | from dataclasses import dataclass, field
4 | from datetime import datetime
5 | import json
6 | import random
7 | import time
8 | from typing import Any, AsyncGenerator, Optional
9 | import os
10 |
11 |
12 | import grpc
13 | from jetstream.core.proto import jetstream_pb2
14 | from jetstream.core.proto import jetstream_pb2_grpc
15 | from jetstream.engine.token_utils import load_vocab
16 | from jetstream.third_party.llama3 import llama3_tokenizer
17 | import numpy as np
18 | from tqdm.asyncio import tqdm # pytype: disable=pyi-error
19 | import pandas
20 |
21 |
22 | @dataclass
23 | class InputRequest:
24 | prompt: str = ""
25 | prompt_len: int = 0
26 | output: str = ""
27 | output_len: int = 0
28 | sample_idx: int = -1
29 |
30 |
31 | @dataclass
32 | class RequestFuncOutput:
33 | input_request: Optional[InputRequest] = None
34 | generated_token_list: list[str] = field(default_factory=list)
35 | generated_text: str = ""
36 | success: bool = False
37 | latency: float = 0
38 | ttft: float = 0
39 | prompt_len: int = 0
40 |
41 | # Flatten the structure and return only the necessary results
42 | def to_dict(self):
43 | return {
44 | "prompt": self.input_request.prompt,
45 | "original_output": self.input_request.output,
46 | "generated_text": self.generated_text,
47 | "success": self.success,
48 | "latency": self.latency,
49 | "prompt_len": self.prompt_len,
50 | "sample_idx": self.input_request.sample_idx,
51 | }
52 |
53 |
54 | async def grpc_async_request(
55 | api_url: str, request: Any
56 | ) -> tuple[list[str], float, float]:
57 | """Send grpc synchronous request since the current grpc server is sync."""
58 | options = [("grpc.keepalive_timeout_ms", 10000)]
59 | async with grpc.aio.insecure_channel(api_url, options=options) as channel:
60 | stub = jetstream_pb2_grpc.OrchestratorStub(channel)
61 | print("Making request")
62 | ttft = 0
63 | token_list = []
64 | request_start_time = time.perf_counter()
65 | response = stub.Decode(request)
66 | async for resp in response:
67 | if ttft == 0:
68 | ttft = time.perf_counter() - request_start_time
69 | token_list.extend(resp.stream_content.samples[0].token_ids)
70 | latency = time.perf_counter() - request_start_time
71 | print("Done request: ", latency)
72 | return token_list, ttft, latency
73 |
74 |
75 | async def send_request(
76 | api_url: str,
77 | tokenizer: Any,
78 | input_request: InputRequest,
79 | pbar: tqdm,
80 | session_cache: str,
81 | priority: int,
82 | ) -> RequestFuncOutput:
83 | """Send the request to JetStream server."""
84 | # Tokenization on client side following MLPerf standard.
85 | token_ids = np.random.randint(0, 1000, input_request.request_len)
86 | request = jetstream_pb2.DecodeRequest(
87 | session_cache=session_cache,
88 | token_content=jetstream_pb2.DecodeRequest.TokenContent(
89 | token_ids=token_ids
90 | ),
91 | priority=priority,
92 | max_tokens=input_request.output_len,
93 | )
94 | output = RequestFuncOutput()
95 | output.input_request = input_request
96 | output.prompt_len = input_request.prompt_len
97 | generated_token_list, ttft, latency = await grpc_async_request(
98 | api_url, request
99 | )
100 | output.ttft = ttft
101 | output.latency = latency
102 | output.generated_token_list = generated_token_list
103 | # generated_token_list is a list of token ids, decode it to generated_text.
104 | output.generated_text = ""
105 | output.success = True
106 | if pbar:
107 | pbar.update(1)
108 | return output
109 |
110 |
111 | async def benchmark(
112 | api_url: str,
113 | max_length: int,
114 | tokenizer: Any = None,
115 | request_rate: float = 0,
116 | disable_tqdm: bool = False,
117 | session_cache: str = "",
118 | priority: int = 100,
119 | ):
120 | """Benchmark the online serving performance."""
121 |
122 | print(f"Traffic request rate: {request_rate}")
123 |
124 | benchmark_start_time = time.perf_counter()
125 | tasks = []
126 | interesting_buckets = [
127 | 4,
128 | 8,
129 | 16,
130 | 32,
131 | 64,
132 | 128,
133 | 256,
134 | 512,
135 | 1024,
136 | 2048,
137 | ]
138 |
139 | for length in interesting_buckets:
140 | if length > max_length:
141 | break
142 | request = InputRequest()
143 | request.request_len = length
144 | print("send request of length", request.request_len)
145 | tasks.append(
146 | asyncio.create_task(
147 | send_request(
148 | api_url=api_url,
149 | tokenizer=None,
150 | input_request=request,
151 | pbar=None,
152 | session_cache=session_cache,
153 | priority=priority,
154 | )
155 | )
156 | )
157 | outputs = await asyncio.gather(*tasks)
158 |
159 | benchmark_duration = time.perf_counter() - benchmark_start_time
160 | return benchmark_duration, outputs
161 |
162 |
163 | def main(args: argparse.Namespace):
164 | print(args)
165 | random.seed(args.seed)
166 | np.random.seed(args.seed)
167 | api_url = f"{args.server}:{args.port}"
168 |
169 | benchmark_result, request_outputs = asyncio.run(
170 | benchmark(api_url=api_url, max_length=args.max_length)
171 | )
172 | print("DURATION:", benchmark_result)
173 |
174 |
175 | if __name__ == "__main__":
176 |
177 | parser = argparse.ArgumentParser(
178 | description="Benchmark the online serving throughput."
179 | )
180 | parser.add_argument(
181 | "--server",
182 | type=str,
183 | default="0.0.0.0",
184 | help="Server address.",
185 | )
186 | parser.add_argument("--seed", type=int, default=0)
187 |
188 | parser.add_argument("--port", type=str, default=9000)
189 | parser.add_argument("--max-length", type=int, default=512)
190 |
191 | parsed_args = parser.parse_args()
192 | main(parsed_args)
193 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | version = "0.2.2"
7 | name = "jetstream_pt"
8 | dependencies = [
9 | "absl-py",
10 | "flatbuffers",
11 | "flax",
12 | "sentencepiece",
13 | "pytest",
14 | "google-jetstream",
15 | "google-cloud-storage",
16 | "safetensors",
17 | "torch_xla2 @ {root:uri}/deps/xla/experimental/torch_xla2",
18 | "google-jetstream @ {root:uri}/deps/JetStream",
19 | ]
20 |
21 |
22 | requires-python = ">=3.10"
23 | license = {file = "LICENSE"}
24 |
25 | [project.scripts]
26 | jpt = "jetstream_pt.cli:main"
27 |
28 | [tool.hatch.metadata]
29 | allow-direct-references = true
30 |
--------------------------------------------------------------------------------
/run_interactive_disaggregated.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import random
17 | import time
18 |
19 | from typing import List
20 | from absl import app
21 | from absl import flags
22 |
23 | import jax
24 |
25 | from jetstream.engine import token_utils
26 | from jetstream_pt import ray_engine
27 |
28 | FLAGS = flags.FLAGS
29 |
30 | _TOKENIZER_PATH = flags.DEFINE_string(
31 | "tokenizer_path",
32 | "tokenizer.model",
33 | "The tokenizer model path",
34 | required=False,
35 | )
36 | _CKPT_PATH = flags.DEFINE_string(
37 | "checkpoint_path", None, "Directory for .pth checkpoints", required=False
38 | )
39 | _BF16_ENABLE = flags.DEFINE_bool(
40 | "bf16_enable", False, "Whether to enable bf16", required=False
41 | )
42 | _CONTEXT_LENGTH = flags.DEFINE_integer(
43 | "context_length", 1024, "The context length", required=False
44 | )
45 | _BATCH_SIZE = flags.DEFINE_integer(
46 | "batch_size", 32, "The batch size", required=False
47 | )
48 | _PROFILING_OUTPUT = flags.DEFINE_string(
49 | "profiling_output",
50 | "",
51 | "The profiling output",
52 | required=False,
53 | )
54 |
55 | _SIZE = flags.DEFINE_string("size", "tiny", "size of model")
56 |
57 | _QUANTIZE_WEIGHTS = flags.DEFINE_bool(
58 | "quantize_weights", False, "weight quantization"
59 | )
60 | _QUANTIZE_KV_CACHE = flags.DEFINE_bool(
61 | "quantize_kv_cache", False, "kv_cache_quantize"
62 | )
63 | _MAX_CACHE_LENGTH = flags.DEFINE_integer(
64 | "max_cache_length", 1024, "kv_cache_quantize"
65 | )
66 |
67 | _MODEL_NAME = flags.DEFINE_string(
68 | "model_name", None, "model type", required=False
69 | )
70 |
71 | _SHARDING_CONFIG = flags.DEFINE_string(
72 | "sharding_config", "", "config file for sharding"
73 | )
74 |
75 |
76 | _IS_DISAGGREGATED = flags.DEFINE_bool(
77 | "is_disaggregated", False, "Disaggregated serving if it's True"
78 | )
79 |
80 | _NUM_HOSTS = flags.DEFINE_integer(
81 | "num_hosts", 4, "Number of TPU host", required=False
82 | )
83 |
84 | _DECODE_POD_SLICE_NAME = flags.DEFINE_string(
85 | "decode_pod_slice_name", "", "Decode pod slice name"
86 | )
87 |
88 |
89 | def create_disaggregated_engines():
90 | """create a pytorch engine"""
91 | # jax.config.update("jax_default_prng_impl", "unsafe_rbg")
92 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
93 |
94 | start = time.perf_counter()
95 | prefill_engine_list, decode_engine_list = (
96 | ray_engine.create_pytorch_ray_engine(
97 | model_name=_MODEL_NAME.value,
98 | tokenizer_path=_TOKENIZER_PATH.value,
99 | ckpt_path=_CKPT_PATH.value,
100 | bf16_enable=True,
101 | param_size=_SIZE.value,
102 | context_length=_CONTEXT_LENGTH.value,
103 | batch_size=_BATCH_SIZE.value,
104 | quantize_weights=_QUANTIZE_WEIGHTS.value,
105 | quantize_kv=_QUANTIZE_KV_CACHE.value,
106 | max_cache_length=_MAX_CACHE_LENGTH.value,
107 | sharding_config=_SHARDING_CONFIG.value,
108 | is_disaggregated=_IS_DISAGGREGATED.value,
109 | num_hosts=_NUM_HOSTS.value,
110 | decode_pod_slice_name=_DECODE_POD_SLICE_NAME.value,
111 | )
112 | )
113 |
114 | print("Initialize engine", time.perf_counter() - start)
115 | return (prefill_engine_list[0], decode_engine_list[0])
116 |
117 |
118 | # pylint: disable-next=all
119 | def main(argv):
120 |
121 | print("start the test")
122 | prefill_engine, decode_engine = create_disaggregated_engines()
123 |
124 | start = time.perf_counter()
125 | prefill_engine.load_params()
126 | decode_engine.load_params()
127 | print("Load params ", time.perf_counter() - start)
128 |
129 | metadata = prefill_engine.get_tokenizer()
130 | vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
131 | stop_tokens = [vocab.eos_id, vocab.pad_id]
132 | max_output_length = 1024
133 |
134 | if _PROFILING_OUTPUT.value:
135 | jax.profiler.start_trace(_PROFILING_OUTPUT.value)
136 |
137 | decode_engine.init_decode_state()
138 | prompts: List[str] = [
139 | "I believe the meaning of life is",
140 | # pylint: disable-next=all
141 | "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.",
142 | # pylint: disable-next=all
143 | "[INST] <>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]",
144 | # pylint: disable-next=all
145 | "[INST] <>\nYou are an AI assistant that helps people find information. Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST",
146 | # pylint: disable-next=all
147 | "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]",
148 | ]
149 | for prompt in prompts:
150 | slot = random.randint(0, _BATCH_SIZE.value - 1)
151 | tokens, true_length = token_utils.tokenize_and_pad(
152 | prompt, vocab, is_bos=True, jax_padding=False
153 | )
154 | print(f"---- Input prompts are: {prompt}")
155 | print(f"---- Encoded tokens are: {tokens}")
156 |
157 | print(
158 | # pylint: disable-next=all
159 | f"---- Do prefill in prefill engine pod_slice_name: {prefill_engine.pod_slice_name}"
160 | )
161 | prefill_result, _ = prefill_engine.prefill(
162 | params=None, padded_tokens=tokens, true_length=true_length
163 | )
164 | print(
165 | # pylint: disable-next=all
166 | f"---- Transfer prefill result to decode engine pod_slice_name: {decode_engine.pod_slice_name}"
167 | )
168 | decode_engine.transfer(prefill_result)
169 |
170 | print(
171 | # pylint: disable-next=all
172 | f"---- Do insert in decode engine pod_slice_name: {decode_engine.pod_slice_name}"
173 | )
174 | decode_state = decode_engine.insert(prefill_result, None, slot=slot)
175 | sampled_tokens_list = []
176 | while True:
177 | # pylint: disable-next=all
178 | decode_state, result_tokens = decode_engine.generate(None, decode_state)
179 | result_tokens = result_tokens.convert_to_numpy()
180 |
181 | slot_data = result_tokens.get_result_at_slot(slot)
182 | slot_tokens = slot_data.tokens
183 | slot_lengths = slot_data.lengths
184 |
185 | token_id = slot_tokens[slot, 0].item()
186 | if slot_lengths > max_output_length or token_id in stop_tokens:
187 | break
188 |
189 | sampled_tokens_list.append(token_id)
190 |
191 | print("---- All output tokens.")
192 | print(sampled_tokens_list)
193 | print("---- All output text.")
194 | print(vocab.tokenizer.decode(sampled_tokens_list))
195 |
196 | if _PROFILING_OUTPUT.value:
197 | jax.profiler.stop_trace()
198 |
199 |
200 | if __name__ == "__main__":
201 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
202 | app.run(main)
203 |
--------------------------------------------------------------------------------
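The runner above is the disaggregated path: prefill executes on one TPU pod slice, the resulting KV cache is transferred, and token generation continues on a second slice. Below is a minimal sketch of that per-prompt flow factored into a helper; the engine methods are exactly the ones used above, while the function name and defaults are illustrative (it assumes `decode_engine.init_decode_state()` has already been called, as in `main`):

```python
# Hedged sketch of the per-prompt disaggregated flow above. The engine methods
# (prefill/transfer/insert/generate) are the ones used in this file; the helper
# itself is illustrative and assumes init_decode_state() was already called.
def serve_one_prompt(
    prefill_engine, decode_engine, tokens, true_length, slot, stop_tokens,
    max_output_length=1024,
):
  """Prefill on one pod slice, move the KV cache, then decode on another."""
  prefill_result, _ = prefill_engine.prefill(
      params=None, padded_tokens=tokens, true_length=true_length
  )
  decode_engine.transfer(prefill_result)  # ship the prefill KV cache across slices
  decode_state = decode_engine.insert(prefill_result, None, slot=slot)

  sampled = []
  while True:
    decode_state, result_tokens = decode_engine.generate(None, decode_state)
    result_tokens = result_tokens.convert_to_numpy()
    slot_data = result_tokens.get_result_at_slot(slot)
    token_id = slot_data.tokens[slot, 0].item()
    if slot_data.lengths > max_output_length or token_id in stop_tokens:
      break
    sampled.append(token_id)
  return sampled
```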
/run_interactive_multiple_host.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import random
17 | import time
18 | from typing import List
19 |
20 | import jax
21 | from absl import app, flags
22 | from jetstream.engine import token_utils
23 | from jetstream_pt import ray_engine
24 | from jetstream_pt.config import FLAGS
25 |
26 | _NUM_HOSTS = flags.DEFINE_integer(
27 | "num_hosts", 0, "Number of TPU host", required=False
28 | )
29 |
30 | _WORKER_CHIPS = flags.DEFINE_integer(
31 | "worker_chips", 4, "Number of TPU chips per worker", required=False
32 | )
33 |
34 | _TPU_CHIPS = flags.DEFINE_integer(
35 | "tpu_chips", 4, "All devices TPU chips", required=False
36 | )
37 |
38 |
39 | def create_engine():
40 | """create a pytorch engine"""
41 | jax.config.update("jax_default_prng_impl", "unsafe_rbg")
42 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
43 |
44 | start = time.perf_counter()
45 | engine = ray_engine.create_pytorch_ray_engine(
46 | model_name=FLAGS.model_name,
47 | tokenizer_path=FLAGS.tokenizer_path,
48 | ckpt_path=FLAGS.checkpoint_path,
49 | bf16_enable=FLAGS.bf16_enable,
50 | param_size=FLAGS.size,
51 | context_length=FLAGS.context_length,
52 | batch_size=FLAGS.batch_size,
53 | quantize_weights=FLAGS.quantize_weights,
54 | quantize_kv=FLAGS.quantize_kv_cache,
55 | max_cache_length=FLAGS.max_cache_length,
56 | sharding_config=FLAGS.sharding_config,
57 | num_hosts=_NUM_HOSTS.value,
58 | worker_chips=_WORKER_CHIPS.value,
59 | tpu_chips=_TPU_CHIPS.value,
60 | )
61 |
62 | print("Initialize engine", time.perf_counter() - start)
63 | return engine
64 |
65 |
66 | # pylint: disable-next=all
67 | def main(argv):
68 |
69 | engine = create_engine()
70 |
71 | start = time.perf_counter()
72 | engine.load_params()
73 | print("Load params ", time.perf_counter() - start)
74 |
75 | metadata = engine.get_tokenizer()
76 | vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
77 | stop_tokens = [vocab.eos_id, vocab.pad_id]
78 | max_output_length = 1024
79 |
80 | profiling_output = FLAGS.profiling_output
81 | if profiling_output:
82 | jax.profiler.start_trace(profiling_output)
83 |
84 | engine.init_decode_state()
85 | prompts: List[str] = [
86 | "I believe the meaning of life is",
87 | # pylint: disable-next=all
88 | "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.",
89 | # pylint: disable-next=all
90 | "[INST] <>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]",
91 | # pylint: disable-next=all
92 | "[INST] <>\nYou are an AI assistant that helps people find information. Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST",
93 | # pylint: disable-next=all
94 | "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]",
95 | ]
96 | for prompt in prompts:
97 | slot = random.randint(0, FLAGS.batch_size - 1)
98 | tokens, true_length = token_utils.tokenize_and_pad(
99 | prompt, vocab, is_bos=True, jax_padding=False
100 | )
101 | print(f"---- Input prompts are: {prompt}")
102 | print(f"---- Encoded tokens are: {tokens}")
103 |
104 | # pylint: disable-next=all
105 | prefill_result, _ = engine.prefill(
106 | params=None, padded_tokens=tokens, true_length=true_length
107 | )
108 | # pylint: disable-next=all
109 | decode_state = engine.insert(prefill_result, None, slot=slot)
110 | sampled_tokens_list = []
111 | while True:
112 | # pylint: disable-next=all
113 | decode_state, result_tokens = engine.generate(None, decode_state)
114 | result_tokens = result_tokens.convert_to_numpy()
115 |
116 | slot_data = result_tokens.get_result_at_slot(slot)
117 | slot_tokens = slot_data.tokens
118 | slot_lengths = slot_data.lengths
119 |
120 | token_id = slot_tokens[slot, 0].item()
121 | if slot_lengths > max_output_length or token_id in stop_tokens:
122 | break
123 |
124 | sampled_tokens_list.append(token_id)
125 |
126 | print("---- All output tokens.")
127 | print(sampled_tokens_list)
128 | print("---- All output text.")
129 | print(vocab.tokenizer.decode(sampled_tokens_list))
130 |
131 | if profiling_output:
132 | jax.profiler.stop_trace()
133 |
134 |
135 | if __name__ == "__main__":
136 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
137 | app.run(main)
138 |
--------------------------------------------------------------------------------
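Both interactive runners prepare prompts the same way before hitting the engine: load the vocabulary from the engine's tokenizer metadata, derive the stop tokens, and tokenize with JAX padding disabled. A small sketch of just that step, assuming any JetStream engine object that exposes `get_tokenizer()`:

```python
# Sketch: prompt preparation shared by both interactive runners above.
# `engine` is any JetStream engine exposing get_tokenizer(); token_utils comes
# from the google-jetstream dependency.
from jetstream.engine import token_utils


def prepare_prompt(engine, prompt: str):
  metadata = engine.get_tokenizer()
  vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
  stop_tokens = [vocab.eos_id, vocab.pad_id]
  tokens, true_length = token_utils.tokenize_and_pad(
      prompt, vocab, is_bos=True, jax_padding=False
  )
  return vocab, stop_tokens, tokens, true_length
```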
/run_ray_serve_interleave.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """ Runs a RayServe deployment with Jetstream interleave mode."""
16 | import os
17 | import time
18 | from typing import AsyncIterator
19 | from absl import app, flags
20 |
21 | from ray import serve
22 | from ray.serve.config import gRPCOptions
23 |
24 | from jetstream.core import config_lib
25 | from jetstream.core import orchestrator
26 | from jetstream.core.config_lib import ServerConfig
27 | from jetstream.core.proto import jetstream_pb2
28 | from jetstream_pt import ray_engine
29 | from jetstream_pt.config import FLAGS
30 |
31 |
32 | flags.DEFINE_string("tpu_generation", "v4", "TPU generation")
33 | flags.DEFINE_integer("tpu_chips", 16, "device tpu_chips")
34 | flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler")
35 | flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server")
36 | flags.DEFINE_integer("num_hosts", 4, "Number of TPU host", required=False)
37 | flags.DEFINE_integer(
38 | "worker_chips", 4, "Number of TPU chips per worker", required=False
39 | )
40 |
41 |
42 | def create_head_resource_name(generation, tpu_chips):
43 | if generation == "v5litepod":
44 | return f"TPU-{generation}-{tpu_chips}-head"
45 | else:
46 | tpu_cores = tpu_chips * 2
47 | return f"TPU-{generation}-{tpu_cores}-head"
48 |
49 |
50 | def create_engine(**kwargs):
51 | """create a pytorch engine"""
52 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
53 |
54 | start = time.perf_counter()
55 | engine = ray_engine.create_pytorch_ray_engine(**kwargs)
56 |
57 | print("Initialize engine", time.perf_counter() - start)
58 | return engine
59 |
60 |
61 | @serve.deployment
62 | class JetStreamDeployment:
63 | """JetStream deployment."""
64 |
65 | def __init__(self, **kwargs):
66 | os.environ["XLA_FLAGS"] = (
67 | "--xla_dump_to=/tmp/xla_logs --xla_dump_hlo_as_text"
68 | )
69 | devices = []
70 | for i in range(kwargs["tpu_chips"]):
71 | devices.append(i)
72 |
73 | print(f"devices: {devices}")
74 |
75 | self.engine = create_engine(**kwargs)
76 | server_config = ServerConfig(
77 | interleaved_slices=(f"tpu={len(devices)}",),
78 | interleaved_engine_create_fns=(lambda a: self.engine,),
79 | )
80 |
81 | engines = config_lib.get_engines(server_config, devices=devices)
82 | prefill_params = [pe.load_params() for pe in engines.prefill_engines]
83 | generate_params = [ge.load_params() for ge in engines.generate_engines]
84 | shared_params = [ie.load_params() for ie in engines.interleaved_engines]
85 | print("Loaded all weights.")
86 |
87 | self.driver = orchestrator.Driver(
88 | prefill_engines=engines.prefill_engines + engines.interleaved_engines,
89 | generate_engines=engines.generate_engines + engines.interleaved_engines,
90 | prefill_params=prefill_params + shared_params,
91 | generate_params=generate_params + shared_params,
92 | interleaved_mode=True,
93 | jax_padding=False,
94 | metrics_collector=None,
95 | is_ray_backend=True,
96 | )
97 |
98 | self.orchestrator = orchestrator.LLMOrchestrator(driver=self.driver)
99 |
100 | print("Started jetstream driver....")
101 |
102 | # pylint: disable-next=all
103 | async def Decode(
104 | self,
105 | # pylint: disable-next=all
106 | request: jetstream_pb2.DecodeRequest,
107 | # pylint: disable-next=all
108 | ) -> AsyncIterator[jetstream_pb2.DecodeResponse]:
109 | """Async decode function."""
110 | return self.orchestrator.Decode(request)
111 |
112 |
113 | def main(_argv):
114 | """Main function"""
115 | resource_name = create_head_resource_name(
116 | FLAGS.tpu_generation, FLAGS.tpu_chips
117 | )
118 | print(f"Using head resource {resource_name}")
119 | # pylint: disable-next=all
120 | deployment = JetStreamDeployment.options(
121 | ray_actor_options={"resources": {resource_name: 1}}
122 | ).bind(
123 | tpu_chips=FLAGS.tpu_chips,
124 | worker_chips=FLAGS.worker_chips,
125 | num_hosts=FLAGS.num_hosts,
126 | model_name=FLAGS.model_name,
127 | tokenizer_path=FLAGS.tokenizer_path,
128 | ckpt_path=FLAGS.checkpoint_path,
129 | bf16_enable=FLAGS.bf16_enable,
130 | param_size=FLAGS.size,
131 | context_length=FLAGS.context_length,
132 | batch_size=FLAGS.batch_size,
133 | quantize_weights=FLAGS.quantize_weights,
134 | quantize_kv=FLAGS.quantize_kv_cache,
135 | max_cache_length=FLAGS.max_cache_length,
136 | sharding_config=FLAGS.sharding_config,
137 | enable_jax_profiler=FLAGS.enable_jax_profiler,
138 | jax_profiler_port=FLAGS.jax_profiler_port,
139 | )
140 |
141 | grpc_port = 8888
142 | grpc_servicer_functions = [
143 | "jetstream.core.proto.jetstream_pb2_grpc.add_OrchestratorServicer_to_server",
144 | ]
145 | serve.start(
146 | grpc_options=gRPCOptions(
147 | port=grpc_port,
148 | grpc_servicer_functions=grpc_servicer_functions,
149 | ),
150 | )
151 |
152 | serve.run(deployment)
153 |
154 |
155 | if __name__ == "__main__":
156 | app.run(main)
157 |
--------------------------------------------------------------------------------
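The head resource label returned by `create_head_resource_name` above is what pins the Ray Serve deployment to the TPU head node; v5litepod slices are labeled by chip count, other generations by core count (two cores per chip). A couple of illustrative values, assuming the repository root is on `PYTHONPATH` and the dependencies are installed:

```python
# Illustrative values for create_head_resource_name (defined in the file above).
from run_ray_serve_interleave import create_head_resource_name

# v5litepod head resources are named by chip count, other generations by cores.
assert create_head_resource_name("v5litepod", 8) == "TPU-v5litepod-8-head"
assert create_head_resource_name("v4", 8) == "TPU-v4-16-head"
```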
/run_server_with_ray.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Runs a pytorch server with ray."""
16 | import os
17 | import time
18 | from typing import Sequence
19 | from absl import app, flags
20 |
21 | # import torch_xla2 first!
22 | import jax
23 | from jetstream.core import server_lib
24 | from jetstream.core.config_lib import ServerConfig
25 | from jetstream_pt import ray_engine
26 | from jetstream_pt.config import FLAGS
27 |
28 | flags.DEFINE_integer("port", 9000, "port to listen on")
29 | flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
30 | flags.DEFINE_string(
31 | "config",
32 | "InterleavedCPUTestServer",
33 | "available servers",
34 | )
35 | flags.DEFINE_integer("prometheus_port", 0, "")
36 | flags.DEFINE_integer("tpu_chips", 16, "all devices tpu_chips")
37 |
38 | flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler")
39 | flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server")
40 |
41 | flags.DEFINE_bool(
42 | "is_disaggregated", False, "Disaggregated serving if it's True"
43 | )
44 |
45 | flags.DEFINE_integer("num_hosts", 0, "Number of TPU host", required=False)
46 |
47 | flags.DEFINE_integer(
48 | "worker_chips", 4, "Number of TPU chips per worker", required=False
49 | )
50 |
51 | flags.DEFINE_string("decode_pod_slice_name", "", "Decode pod slice name")
52 |
53 |
54 | def create_engine():
55 | """create a pytorch engine"""
56 | jax.config.update("jax_default_prng_impl", "unsafe_rbg")
57 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
58 |
59 | start = time.perf_counter()
60 | engine = ray_engine.create_pytorch_ray_engine(
61 | model_name=FLAGS.model_name,
62 | tokenizer_path=FLAGS.tokenizer_path,
63 | ckpt_path=FLAGS.checkpoint_path,
64 | bf16_enable=FLAGS.bf16_enable,
65 | param_size=FLAGS.size,
66 | context_length=FLAGS.context_length,
67 | batch_size=FLAGS.batch_size,
68 | quantize_weights=FLAGS.quantize_weights,
69 | quantize_kv=FLAGS.quantize_kv_cache,
70 | max_cache_length=FLAGS.max_cache_length,
71 | sharding_config=FLAGS.sharding_config,
72 | enable_jax_profiler=FLAGS.enable_jax_profiler,
73 | jax_profiler_port=FLAGS.jax_profiler_port,
74 | num_hosts=FLAGS.num_hosts,
75 | worker_chips=FLAGS.worker_chips,
76 | tpu_chips=FLAGS.tpu_chips,
77 | )
78 |
79 | print("Initialize engine", time.perf_counter() - start)
80 | return engine
81 |
82 |
83 | def create_disaggregated_engine():
84 | """create a pytorch engine"""
85 | jax.config.update("jax_default_prng_impl", "unsafe_rbg")
86 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
87 |
88 | start = time.perf_counter()
89 | prefill_engine_list, decode_engine_list = (
90 | ray_engine.create_pytorch_ray_engine(
91 | model_name=FLAGS.model_name,
92 | tokenizer_path=FLAGS.tokenizer_path,
93 | ckpt_path=FLAGS.checkpoint_path,
94 | bf16_enable=FLAGS.bf16_enable,
95 | param_size=FLAGS.size,
96 | context_length=FLAGS.context_length,
97 | batch_size=FLAGS.batch_size,
98 | quantize_weights=FLAGS.quantize_weights,
99 | quantize_kv=FLAGS.quantize_kv_cache,
100 | max_cache_length=FLAGS.max_cache_length,
101 | sharding_config=FLAGS.sharding_config,
102 | enable_jax_profiler=FLAGS.enable_jax_profiler,
103 | jax_profiler_port=FLAGS.jax_profiler_port,
104 | is_disaggregated=FLAGS.is_disaggregated,
105 | num_hosts=FLAGS.num_hosts,
106 | decode_pod_slice_name=FLAGS.decode_pod_slice_name,
107 | )
108 | )
109 |
110 | print("Initialize engine", time.perf_counter() - start)
111 | return (prefill_engine_list, decode_engine_list)
112 |
113 |
114 | # pylint: disable-next=all
115 | def main(argv: Sequence[str]):
116 | del argv
117 | os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_logs --xla_dump_hlo_as_text"
118 | devices = []
119 | for i in range(FLAGS.tpu_chips):
120 | devices.append(i)
121 |
122 | print(f"devices: {devices}")
123 |
124 | if FLAGS.is_disaggregated:
125 | prefill_engine_list, decode_engine_list = create_disaggregated_engine()
126 | chips = int(len(devices) / 2)
127 | server_config = ServerConfig(
128 | prefill_slices=(f"tpu={chips}",),
129 | prefill_engine_create_fns=(lambda a: prefill_engine_list[0],),
130 | generate_slices=(f"tpu={chips}",),
131 | generate_engine_create_fns=(lambda a: decode_engine_list[0],),
132 | is_ray_backend=True,
133 | )
134 |
135 | else:
136 | engine = create_engine()
137 | server_config = ServerConfig(
138 | interleaved_slices=(f"tpu={len(devices)}",),
139 | interleaved_engine_create_fns=(lambda a: engine,),
140 | )
141 |
142 | print(f"server_config: {server_config}")
143 |
144 | jetstream_server = server_lib.run(
145 | threads=FLAGS.threads,
146 | port=FLAGS.port,
147 | config=server_config,
148 | devices=devices,
149 | jax_padding=False, # Jax_padding must be set as False
150 | )
151 | print("Started jetstream_server....")
152 | jetstream_server.wait_for_termination()
153 |
154 |
155 | if __name__ == "__main__":
156 | app.run(main)
157 |
--------------------------------------------------------------------------------
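In disaggregated mode, `main` above splits the visible TPU chips evenly between a prefill slice and a generate slice, whereas interleaved mode hands every chip to a single engine. A small sketch of that sizing under the default `--tpu_chips=16`:

```python
# Sketch of the slice sizing in main() above, under the default --tpu_chips=16.
tpu_chips = 16
devices = list(range(tpu_chips))

# Disaggregated: half the chips prefill, the other half generate.
chips_per_slice = len(devices) // 2
prefill_slices = (f"tpu={chips_per_slice}",)    # ("tpu=8",)
generate_slices = (f"tpu={chips_per_slice}",)   # ("tpu=8",)

# Interleaved: one engine over every chip.
interleaved_slices = (f"tpu={len(devices)}",)   # ("tpu=16",)
```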
/scripts/create_empty_sharding_map.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | from absl import app
18 | from absl import flags
19 |
20 | from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData, QuantizationConfig, process_sharding_name
21 | from jetstream_pt.third_party.llama import model_exportable, model_args
22 | from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model
23 |
24 | FLAGS = flags.FLAGS
25 |
26 | _MODEL_NAME = flags.DEFINE_string(
27 | "model_name", None, "model type", required=False
28 | )
29 |
30 | _SIZE = flags.DEFINE_string("size", "tiny", "size of model")
31 |
32 | _COLLAPSE_SAME_LAYERS = flags.DEFINE_bool("collapse_same_layers", True, "")
33 |
34 |
35 | def create_model():
36 | batch_size = 3
37 | quant_config = QuantizationConfig(
38 | enable_weight_quantization=True, enable_kv_quantization=True
39 | )
40 | env_data = JetEngineEnvironmentData(
41 | batch_size=3,
42 | max_decode_length=1024,
43 | max_input_sequence_length=1024,
44 | quant_config=quant_config,
45 | cache_sequence_length=1024,
46 | bf16_enable=True,
47 | )
48 | model_name = _MODEL_NAME.value
49 | param_size = _SIZE.value
50 | if model_name.startswith("llama"):
51 |
52 | args = model_args.get_model_args(
53 | param_size,
54 | 1024,
55 | batch_size,
56 | vocab_size=32000,
57 | bf16_enable=True,
58 | )
59 | args.device = "meta"
60 | env = JetEngineEnvironment(env_data)
61 | return model_exportable.Transformer(args, env)
62 | elif model_name == "gemma":
63 | args = gemma_config.get_model_config(param_size)
64 | args.device = "meta"
65 | env_data.model_type = "gemma-" + param_size
66 | env_data.num_layers = args.num_hidden_layers
67 | env = JetEngineEnvironment(env_data)
68 | pt_model = gemma_model.GemmaModel(args, env)
69 | return pt_model
70 |
71 |
72 | # pylint: disable-next=all
73 | def main(argv):
74 | model = create_model()
75 | res = {}
76 | for k, v in model.state_dict().items():
77 | res[process_sharding_name(k)] = v
78 |
79 | print(
80 | f"""
81 | # Sharding config for {_MODEL_NAME.value}
82 | # Sharding should either be an int between 0 and rank - 1
83 | # signifying the axis to shard or -1 / null signifying replicated
84 |
85 | """
86 | )
87 |
88 | for k, v in res.items():
89 | print(k, ":", -1, "# ", str(v.dtype), tuple(v.shape))
90 |
91 |
92 | if __name__ == "__main__":
93 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
94 | app.run(main)
95 |
--------------------------------------------------------------------------------
/scripts/validate_hf_ckpt_conversion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from safetensors import safe_open
3 |
4 | """
5 | Script to compare converted checkpoints for debugging purposes.
6 | """
7 |
8 | converted_from_orig = (
9 | "/mnt/disks/lsiyuan/llama_weight/7B-FT-chat-converted/model.safetensors"
10 | )
11 |
12 | converted_from_hf = "/mnt/disks/lsiyuan/llama_weight/hf_llama_2_7b_converted_bf16/model.safetensors"
13 |
14 | orig_state_dict = {}
15 | with safe_open(converted_from_orig, framework="pt", device="cpu") as f:
16 | for key in f.keys():
17 | orig_state_dict[key] = f.get_tensor(key)
18 |
19 | hf_state_dict = {}
20 | with safe_open(converted_from_hf, framework="pt", device="cpu") as f:
21 | for key in f.keys():
22 | hf_state_dict[key] = f.get_tensor(key)
23 |
24 | for key in orig_state_dict.keys():
25 | if key != "rope.freqs":
26 | assert key in hf_state_dict, f"{key} in orig but not in hf"
27 | else:
28 | print("rope.freqs skipped.")
29 |
30 | for key in hf_state_dict.keys():
31 | assert key in orig_state_dict, f"{key} in hf but not in orig"
32 |
33 |
34 | def _calc_cosine_dist(x, y):
35 | x = x.flatten().to(torch.float32)
36 | y = y.flatten().to(torch.float32)
37 | return (torch.dot(x, y) / (x.norm() * y.norm())).item()
38 |
39 |
40 | for key in hf_state_dict.keys():
41 | orig_w = orig_state_dict[key]
42 | hf_w = hf_state_dict[key]
43 | print(f"weight diff {key} : {_calc_cosine_dist(orig_w, hf_w)}")
44 |
--------------------------------------------------------------------------------
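Despite its `_dist` suffix, the helper above returns cosine *similarity* (dot product over the product of norms), so values near 1.0 indicate the two converted checkpoints agree. A self-contained sanity check of that formula; the helper is copied here under an explicit name so the snippet runs on its own:

```python
import torch


def cosine_similarity(x, y):
  # Same formula as _calc_cosine_dist above: dot(x, y) / (|x| * |y|).
  x = x.flatten().to(torch.float32)
  y = y.flatten().to(torch.float32)
  return (torch.dot(x, y) / (x.norm() * y.norm())).item()


a = torch.tensor([1.0, 0.0])
b = torch.tensor([0.0, 1.0])
assert abs(cosine_similarity(a, a) - 1.0) < 1e-6  # identical direction -> ~1.0
assert abs(cosine_similarity(a, b)) < 1e-6        # orthogonal -> ~0.0
```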
/tests/.pylintrc:
--------------------------------------------------------------------------------
1 | [MESSAGES CONTROL]
2 | disable=C0114,W0212
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/helpers.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import torch
3 | import torch_xla2
4 | from jetstream_pt.third_party.llama import model_args
5 | from jetstream_pt.third_party.mixtral import config as mixtral_config
6 | from jetstream_pt import environment
7 |
8 |
9 | # pylint: disable-next=all
10 | def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None):
11 | torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
12 | torch.set_default_dtype(torch_dtype)
13 | jax.config.update("jax_dynamic_shapes", False)
14 | jax.config.update("jax_traceback_filtering", "off")
15 | config = model_args.get_model_args("llama-2-tiny", 128, 1, True)
16 | environment_data = environment.JetEngineEnvironmentData()
17 | environment_data.max_input_sequence_length = 128
18 | environment_data.max_input_sequence_length = 128
19 | environment_data.cache_sequence_length = 128
20 | environment_data.bf16_enable = bf16_enable
21 | environment_data.model_type = "llama-2-tiny"
22 | environment_data.batch_size = 1
23 | environment_data.num_layers = config.n_layers
24 | environment_data.cache_shape = (
25 | 1,
26 | config.n_kv_heads,
27 | environment_data.cache_sequence_length,
28 | config.dim // config.n_heads,
29 | )
30 | environment_data.n_reps = config.n_heads // config.n_kv_heads
31 | environment_data.testing = True
32 | env_data_update_fn(environment_data)
33 | env = environment.JetEngineEnvironment(environment_data)
34 | env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu
35 | return env, config
36 |
37 |
38 | # pylint: disable-next=all
39 | def make_mixtral_env(bf16_enable=True):
40 | torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
41 | torch.set_default_dtype(torch_dtype)
42 | jax.config.update("jax_dynamic_shapes", False)
43 | jax.config.update("jax_traceback_filtering", "off")
44 | config = mixtral_config.ModelArgs.from_name("Mixtral-tiny")
45 | environment_data = environment.JetEngineEnvironmentData()
46 | environment_data.max_input_sequence_length = 128
47 | environment_data.cache_sequence_length = 128
48 | environment_data.bf16_enable = bf16_enable
49 | environment_data.model_type = "mixtral"
50 | environment_data.batch_size = 1
51 | environment_data.num_layers = config.n_layer
52 | environment_data.cache_shape = (
53 | 1,
54 | config.n_local_heads,
55 | environment_data.cache_sequence_length,
56 | config.dim // config.n_head,
57 | )
58 | environment_data.n_reps = config.n_head // config.n_local_heads
59 | env = environment.JetEngineEnvironment(environment_data)
60 | env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu
61 | return env, config
62 |
63 |
64 | # pylint: disable-next=all
65 | def to_xla_tensor(tree):
66 | return torch_xla2.default_env().to_xla(tree)
67 |
68 |
69 | # pylint: disable-next=all
70 | def call_xla_model(model, weights, args):
71 | with jax.default_device(jax.devices("cpu")[0]):
72 | xla_weights, xla_inputs = to_xla_tensor((weights, args))
73 | with torch_xla2.default_env():
74 | result = torch.func.functional_call(model, xla_weights, xla_inputs)
75 | result_torch = torch_xla2.tensor.j2t(result.jax())
76 | return result_torch
77 |
78 |
79 | # pylint: disable-next=all
80 | def make_page_attention_env_tiny(
81 | bf16_enable=True, env_data_update_fn=lambda _: None
82 | ):
83 | torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
84 | torch.set_default_dtype(torch_dtype)
85 | jax.config.update("jax_dynamic_shapes", False)
86 | jax.config.update("jax_traceback_filtering", "off")
87 | config = model_args.get_model_args("llama-2-tiny", 128, 1, True)
88 | environment_data = environment.JetEngineEnvironmentData()
89 | environment_data.paged_attention_page_size = 32
90 | environment_data.paged_attention_total_num_pages = 16
91 | environment_data.block_size = 64
92 | environment_data.max_input_sequence_length = 128
93 | environment_data.max_input_sequence_length = 128
94 | environment_data.cache_sequence_length = 128
95 | environment_data.bf16_enable = bf16_enable
96 | environment_data.model_type = "llama-2-tiny"
97 | environment_data.batch_size = 1
98 | environment_data.num_layers = config.n_layers
99 | environment_data.cache_shape = (
100 | config.n_kv_heads,
101 | environment_data.paged_attention_total_num_pages,
102 | environment_data.paged_attention_page_size,
103 | config.dim // config.n_heads,
104 | )
105 | environment_data.testing = True
106 | env_data_update_fn(environment_data)
107 | env = environment.JetEngineEnvironment(environment_data)
108 | env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu
109 | return env, config
110 |
--------------------------------------------------------------------------------
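`helpers.py` above is the shared fixture layer for the tests that follow: `make_env_tiny` builds a CPU-only `JetEngineEnvironment` around the llama-2-tiny config, and `call_xla_model` pushes a torch module and its weights through torch_xla2 and returns a plain `torch.Tensor`. A hedged usage sketch, assuming it is run from the repository root with the dev dependencies installed; the toy `Linear` module is illustrative only:

```python
import torch
from tests import helpers

# Tiny llama environment on CPU; this also sets the default dtype to bfloat16.
env, config = helpers.make_env_tiny(bf16_enable=True)

# Any torch module works here; a single Linear keeps the sketch small.
model = torch.nn.Linear(config.dim, config.dim, bias=False)
weights = model.state_dict()
args = (torch.randn(2, config.dim),)

out = helpers.call_xla_model(model, weights, args)  # plain torch.Tensor
print(out.shape)  # torch.Size([2, config.dim])
```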
/tests/test_hf_names.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | from jetstream_pt.model_base import ModuleBase
4 |
5 |
6 | class TestModuleBase(unittest.TestCase):
7 | """Test module base."""
8 |
9 | def test_get_hf_names_to_real_name(self):
10 | """Test get hugginface names to real name."""
11 |
12 | class MyModule(ModuleBase):
13 | """My module."""
14 |
15 | def __init__(self):
16 | super().__init__()
17 | self.linear1 = torch.nn.Linear(10, 20)
18 | self.linear2 = torch.nn.Linear(20, 30)
19 | self.hf_name("linear1", "model.my_linear1")
20 | self.hf_name("linear2", "model.my_linear2")
21 | self.param = torch.nn.Parameter(torch.randn(10))
22 | self.hf_name("param", "model.param")
23 |
24 | def forward(self):
25 | """Forward function."""
26 |
27 | module = MyModule()
28 | expected_mapping = {
29 | "model.my_linear1.weight": "linear1.weight",
30 | "model.my_linear1.bias": "linear1.bias",
31 | "model.my_linear2.weight": "linear2.weight",
32 | "model.my_linear2.bias": "linear2.bias",
33 | "model.param": "param",
34 | }
35 |
36 | self.assertEqual(module.get_hf_names_to_real_name(), expected_mapping)
37 |
38 | def test_get_sharding_annotations(self):
39 | """Test get sharding annotations."""
40 |
41 | class MyModule(ModuleBase):
42 | """MyModule."""
43 |
44 | def __init__(self):
45 | super().__init__()
46 | self.linear = torch.nn.Linear(10, 20)
47 | self.embedding = torch.nn.Embedding(100, 50)
48 | self.inner = InnerModule()
49 |
50 | def forward(self):
51 | """Forward function."""
52 |
53 | class InnerModule(ModuleBase):
54 | """Inner modeule."""
55 |
56 | def __init__(self):
57 | super().__init__()
58 | self.fc = torch.nn.Linear(50, 100)
59 |
60 | def forward(self):
61 | """Forward function."""
62 |
63 | module = MyModule()
64 | module.annotate_sharding("linear.weight", 0)
65 | module.annotate_sharding("embedding.weight", 1)
66 | module.inner.annotate_sharding("fc.weight", 2)
67 |
68 | expected_mapping = {
69 | "linear.weight": 0,
70 | "embedding.weight": 1,
71 | "inner.fc.weight": 2,
72 | }
73 | self.assertEqual(module.get_sharding_annotations(), expected_mapping)
74 |
75 |
76 | if __name__ == "__main__":
77 | unittest.main()
78 |
--------------------------------------------------------------------------------
/tests/test_jax_torch.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import torch_xla2
4 | import jax
5 | import jax.numpy as jnp
6 |
7 |
8 | class JaxTorchTest(unittest.TestCase):
9 | """Unit test compare Jax and Torch gap with float precision"""
10 |
11 | def test_matmul_bfloat16_xla2(self):
12 | """test jax vs torch matmul diff with bfloat16 on cpu"""
13 | jax.config.update("jax_platform_name", "cpu")
14 | torch.set_default_dtype(torch.bfloat16)
15 | r = c = 1000
16 | q = torch.randn((r, c))
17 | k = torch.randn((r, c))
18 | print(f"torch matlmul: {q.shape} * {k.shape}")
19 | result = torch.matmul(q, k)
20 |
21 | jax_q = torch_xla2.tensor.t2j(q)
22 | jax_k = torch_xla2.tensor.t2j(k)
23 | print(f"torch matlmul: {jax_q.shape} * {jax_k.shape}")
24 | jax_result = jnp.matmul(jax_q, jax_k)
25 | target_result = torch_xla2.tensor.j2t(jax_result)
26 | print(
27 | f"----------------------- matmul: Diff norm {(target_result - result).norm()}"
28 | )
29 | self.assertTrue(torch.allclose(target_result, result, atol=1))
30 |
31 | def test_matmul_float32(self):
32 | """test jax vs torch matmul diff with float32 on cpu"""
33 | jax.config.update("jax_platform_name", "cpu")
34 | torch.set_default_dtype(torch.float32)
35 | r = c = 1000
36 | q = torch.randn((r, c))
37 | k = torch.randn((r, c))
38 | print(f"torch matlmul: {q.shape} * {k.shape}")
39 | result = torch.matmul(q, k)
40 |
41 | jax_q = torch_xla2.tensor.t2j(q)
42 | jax_k = torch_xla2.tensor.t2j(k)
43 | print(f"torch matlmul: {jax_q.shape} * {jax_k.shape}")
44 | jax_result = jnp.matmul(jax_q, jax_k)
45 | target_result = torch_xla2.tensor.j2t(jax_result)
46 | print(
47 | f"----------------------- matmul: Diff norm {(target_result - result).norm()}"
48 | )
49 | self.assertTrue(torch.allclose(target_result, result, atol=1e-4))
50 |
51 |
52 | if __name__ == "__main__":
53 | unittest.main()
54 |
--------------------------------------------------------------------------------
/tests/test_kv_cache_manager.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import jax
4 | import numpy as np
5 | import jax.numpy as jnp
6 | import torch
7 |
8 | from jetstream_pt.third_party.llama import model_args
9 | from jetstream_pt import environment
10 | from jetstream_pt.page_attention_manager import PageAttentionManager
11 | from jetstream_pt.cache_manager import PageKVCacheGenerate, KVCachePrefill
12 | from jetstream_pt import torchjax
13 | from absl.testing import parameterized
14 |
15 | P = jax.sharding.PartitionSpec
16 |
17 |
18 | class PageAttentionTest(parameterized.TestCase):
19 |
20 | def _make_env(self, bf16_enable=True):
21 | torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
22 | torch.set_default_dtype(torch_dtype)
23 | jax.config.update("jax_dynamic_shapes", False)
24 | jax.config.update("jax_traceback_filtering", "off")
25 | jax.config.update("jax_platform_name", "cpu")
26 | config = model_args.get_model_args("tiny", 128, 1, True)
27 | environment_data = environment.JetEngineEnvironmentData()
28 | environment_data.max_input_sequence_length = 128
29 | environment_data.max_input_sequence_length = 128
30 | environment_data.cache_sequence_length = 128
31 | environment_data.bf16_enable = bf16_enable
32 | environment_data.model_type = "llama-2-tiny"
33 | environment_data.batch_size = 3
34 | environment_data.num_layers = config.n_layers
35 | environment_data.cache_shape = (
36 | 1,
37 | config.n_kv_heads,
38 | environment_data.cache_sequence_length,
39 | config.dim // config.n_heads,
40 | )
41 | env = environment.JetEngineEnvironment(environment_data)
42 | env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu
43 | mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=("x",))
44 | replicated = jax.sharding.NamedSharding(mesh, P())
45 | env.sharding = replicated
46 | return env, config
47 |
48 | def test_page_attention_update(self):
49 | jax.config.update("jax_platform_name", "cpu")
50 | print(f"---------> {jax.devices()}")
51 |
52 | env, _ = self._make_env()
53 |
54 | pam = PageAttentionManager(
55 | batch_size=5,
56 | paged_attention_total_num_pages=20,
57 | paged_attention_page_size=4,
58 | max_pages_per_sequence=4,
59 | )
60 | shape = (1, 20, 4, 2)
61 | decode_caches = []
62 | decode_caches.append(
63 | PageKVCacheGenerate.empty(shape=shape, device=None, env=env)
64 | )
65 | decode_caches = [c.state() for c in decode_caches]
66 |
67 | self.cache_sharding = env.cache_sharding
68 |
69 | def _insert_prefill(seq_len, dim, slot):
70 | prefill_cache = KVCachePrefill()
71 | k, v = jnp.arange(seq_len * dim), jnp.arange(seq_len * dim)
72 | k, v = jnp.reshape(k, (1, 1, seq_len, dim)), jnp.reshape(
73 | v, (1, 1, seq_len, dim)
74 | )
75 | prefill_cache.update(k, v, 0)
76 | prefill_caches = [prefill_cache]
77 | prefill_caches = [c.state() for c in prefill_caches]
78 | num_pages, update_indexes = pam.reserve_pages_insert(slot, seq_len)
79 | _, kv_heads, _, dim = prefill_caches[0][0].shape
80 | tep_kv = jnp.zeros((kv_heads, num_pages * 4, dim), dtype=jnp.bfloat16)
81 |
82 | caches = pam.insert_prefill_cache(
83 | prefill_caches=prefill_caches,
84 | decode_caches=decode_caches,
85 | update_indexes=update_indexes,
86 | tep_kv=tep_kv,
87 | sharding=env.sharding,
88 | )
89 |
90 | return caches
91 |
92 | decode_caches = _insert_prefill(3, 2, 0)
93 | decode_caches = _insert_prefill(8, 2, 1)
94 | decode_caches = _insert_prefill(13, 2, 3)
95 |
96 | lens = np.asarray([3, 8, 0, 13, 0])
97 | pam.fill_new_pages(lens)
98 | np_page_token_indices = pam.get_page_token_indices(lens)
99 | page_token_indices = jnp.asarray(np_page_token_indices)
100 | page_token_indices = torchjax.to_torch(page_token_indices)
101 |
102 | caches_obj = [
103 | PageKVCacheGenerate(
104 | k, v, pam, page_token_indices, self.cache_sharding, env=env
105 | )
106 | for k, v in torchjax.to_torch(decode_caches)
107 | ]
108 | xk, xv = jnp.arange(-1, -11, -1).reshape(5, 1, 1, 2), jnp.arange(
109 | -1, -11, -1
110 | ).reshape(5, 1, 1, 2)
111 | xk = torchjax.to_torch(xk)
112 | xv = torchjax.to_torch(xv)
113 | decode_caches = caches_obj[0].update(xk, xv)
114 | expected = jnp.asarray([[0, 1], [2, 3], [4, 5], [-1, -2]])
115 | self.assertTrue(jnp.array_equal(decode_caches[0][0][0], expected))
116 | expected = jnp.asarray([[-3, -4], [0, 0], [0, 0], [0, 0]])
117 | self.assertTrue(jnp.array_equal(decode_caches[0][0][7], expected))
118 | expected = jnp.asarray([[24, 25], [-7, -8], [0, 0], [0, 0]])
119 | self.assertTrue(jnp.array_equal(decode_caches[0][0][6], expected))
120 |
121 |
122 | if __name__ == "__main__":
123 | unittest.main()
124 |
--------------------------------------------------------------------------------
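The expectations in `test_page_attention_update` above follow from the page geometry: with `paged_attention_page_size=4`, prefilling a sequence of length L reserves ceil(L / 4) pages, so the lens `[3, 8, 0, 13, 0]` correspond to 1, 2, 0, 4, and 0 prefill pages. A tiny illustrative helper (not the real `PageAttentionManager` API) that reproduces that arithmetic:

```python
# Illustrative only: the arithmetic behind the reserve_pages_insert calls above.
# This is not the PageAttentionManager API, just the page count it implies.
import math

PAGE_SIZE = 4  # paged_attention_page_size used throughout these tests


def pages_for_prefill(seq_len: int) -> int:
  return math.ceil(seq_len / PAGE_SIZE)


assert [pages_for_prefill(n) for n in (3, 8, 13)] == [1, 2, 4]
```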
/tests/test_page_attention.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import jax
4 | import numpy as np
5 | import jax.numpy as jnp
6 | import torch
7 |
8 | from jetstream_pt.third_party.llama import model_args
9 | from jetstream_pt import environment
10 | from jetstream_pt.page_attention_manager import PageAttentionManager
11 | from jetstream_pt.cache_manager import PageKVCacheGenerate, KVCachePrefill
12 | from absl.testing import parameterized
13 |
14 | P = jax.sharding.PartitionSpec
15 |
16 |
17 | class PageAttentionTest(parameterized.TestCase):
18 |
19 | def _make_env(self, bf16_enable=True):
20 | torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
21 | torch.set_default_dtype(torch_dtype)
22 | jax.config.update("jax_dynamic_shapes", False)
23 | jax.config.update("jax_traceback_filtering", "off")
24 | jax.config.update("jax_platform_name", "cpu")
25 | jax.config.update("jax_enable_x64", False)
26 | mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=("x",))
27 | replicated = jax.sharding.NamedSharding(mesh, P())
28 | config = model_args.get_model_args("tiny", 128, 1, True)
29 | environment_data = environment.JetEngineEnvironmentData()
30 | environment_data.max_input_sequence_length = 128
31 | environment_data.max_input_sequence_length = 128
32 | environment_data.cache_sequence_length = 128
33 | environment_data.bf16_enable = bf16_enable
34 | environment_data.model_type = "llama-2-tiny"
35 | environment_data.batch_size = 3
36 | environment_data.num_layers = config.n_layers
37 | environment_data.cache_shape = (
38 | 1,
39 | config.n_kv_heads,
40 | environment_data.cache_sequence_length,
41 | config.dim // config.n_heads,
42 | )
43 | env = environment.JetEngineEnvironment(environment_data)
44 | env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu
45 | env.sharding = replicated
46 | return env, config
47 |
48 | def test_prefill_insert(self):
49 |
50 | env, _ = self._make_env()
51 |
52 | pam = PageAttentionManager(
53 | batch_size=3,
54 | paged_attention_total_num_pages=20,
55 | paged_attention_page_size=4,
56 | max_pages_per_sequence=4,
57 | )
58 | shape = (1, 6, 4, 2)
59 | decode_caches = []
60 | decode_caches.append(
61 | PageKVCacheGenerate.empty(shape=shape, device=None, env=env)
62 | )
63 | decode_caches = [c.state() for c in decode_caches]
64 |
65 | prefill_cache = KVCachePrefill()
66 | k, v = jnp.arange(6), jnp.arange(6)
67 | k, v = jnp.reshape(k, (1, 1, 3, 2)), jnp.reshape(v, (1, 1, 3, 2))
68 | prefill_cache.update(k, v, 0)
69 | prefill_caches = [prefill_cache]
70 | prefill_caches = [c.state() for c in prefill_caches]
71 |
72 | num_pages, update_indexes = pam.reserve_pages_insert(0, 3)
73 | _, kv_heads, _, dim = prefill_caches[0][0].shape
74 | tep_kv = jnp.zeros((kv_heads, num_pages * 4, dim), dtype=jnp.bfloat16)
75 |
76 | caches = pam.insert_prefill_cache(
77 | prefill_caches=prefill_caches,
78 | decode_caches=decode_caches,
79 | update_indexes=update_indexes,
80 | tep_kv=tep_kv,
81 | sharding=env.sharding,
82 | )
83 | expected_kv = jnp.arange(6).reshape(3, 2)
84 | padding = jnp.asarray([[0, 0]])
85 | expected_kv = jnp.concatenate((expected_kv, padding))
86 |
87 | self.assertTrue(
88 | jnp.array_equal(
89 | caches[0][0][0, 0, 0:4, 0:2], expected_kv.astype(jnp.bfloat16)
90 | )
91 | )
92 | self.assertTrue(
93 | jnp.array_equal(
94 | caches[0][1][0, 0, 0:4, 0:2], expected_kv.astype(jnp.bfloat16)
95 | )
96 | )
97 |
98 | def test_prefill_insert_multiple_pages(self):
99 |
100 | jax.config.update("jax_platform_name", "cpu")
101 | print(f"---------> {jax.devices()}")
102 |
103 | env, _ = self._make_env()
104 |
105 | pam = PageAttentionManager(
106 | batch_size=3,
107 | paged_attention_total_num_pages=20,
108 | paged_attention_page_size=4,
109 | max_pages_per_sequence=4,
110 | )
111 | shape = (1, 20, 4, 2)
112 | decode_caches = []
113 | decode_caches.append(
114 | PageKVCacheGenerate.empty(shape=shape, device=None, env=env)
115 | )
116 | decode_caches = [c.state() for c in decode_caches]
117 |
118 | self.cache_sharding = env.cache_sharding
119 |
120 | prefill_cache = KVCachePrefill()
121 | k, v = jnp.arange(12), jnp.arange(12)
122 | k, v = jnp.reshape(k, (1, 1, 6, 2)), jnp.reshape(v, (1, 1, 6, 2))
123 | prefill_cache.update(k, v, 0)
124 | prefill_caches = [prefill_cache]
125 | prefill_caches = [c.state() for c in prefill_caches]
126 |
127 | num_pages, update_indexes = pam.reserve_pages_insert(0, 6)
128 | _, kv_heads, _, dim = prefill_caches[0][0].shape
129 | tep_kv = jnp.zeros((kv_heads, num_pages * 4, dim), dtype=jnp.bfloat16)
130 |
131 | decode_caches = pam.insert_prefill_cache(
132 | prefill_caches=prefill_caches,
133 | decode_caches=decode_caches,
134 | update_indexes=update_indexes,
135 | tep_kv=tep_kv,
136 | sharding=env.sharding,
137 | )
138 |
139 | self.assertEqual(len(decode_caches), 1)
140 | expected = jnp.arange(16).at[12:16].set([0, 0, 0, 0]).reshape(1, 2, 4, 2)
141 |
142 | updated_k = jax.lax.slice_in_dim(decode_caches[0][0], 0, 2, axis=1)
143 | self.assertTrue(jnp.array_equal(updated_k, expected))
144 | noupdated_k = jax.lax.slice_in_dim(decode_caches[0][0], 2, 20, axis=1)
145 | self.assertTrue(jnp.array_equal(noupdated_k, jnp.zeros_like(noupdated_k)))
146 |
147 | def test_reserve_pages_decode(self):
148 |
149 | env, _ = self._make_env()
150 |
151 | pam = PageAttentionManager(
152 | batch_size=3,
153 | paged_attention_total_num_pages=20,
154 | paged_attention_page_size=4,
155 | max_pages_per_sequence=4,
156 | )
157 | slot = 1
158 | seq_len = 8
159 | pam.reserve_pages_insert(slot, seq_len)
160 | expected_slot_page_indices = np.asarray([0, 1])
161 | slot_page_indices = pam.page_indices[slot][0:2]
162 | self.assertTrue(
163 | np.array_equal(slot_page_indices, expected_slot_page_indices)
164 | )
165 |
166 | lens = np.asarray([0, seq_len, 0])
167 | pam.fill_new_pages(lens)
168 | expected_slot_page_indices = np.asarray([0, 1, 2, 19])
169 | slot_page_indices = pam.page_indices[slot]
170 | self.assertTrue(
171 | np.array_equal(slot_page_indices, expected_slot_page_indices)
172 | )
173 |
174 | expected_0_page_indices = np.asarray([19, 19, 19, 19])
175 | zero_page_indices = pam.page_indices[0][0:4]
176 | self.assertTrue(np.array_equal(zero_page_indices, expected_0_page_indices))
177 |
178 | def test_get_page_token_indices(self):
179 | env, _ = self._make_env()
180 |
181 | pam = PageAttentionManager(
182 | batch_size=5,
183 | paged_attention_total_num_pages=20,
184 | paged_attention_page_size=4,
185 | max_pages_per_sequence=4,
186 | )
187 | pam.reserve_pages_insert(1, 8)
188 | pam.reserve_pages_insert(3, 13)
189 | pam.reserve_pages_insert(0, 3)
190 |
191 | lens = np.asarray([3, 8, 0, 13, 0])
192 | pam.fill_new_pages(lens)
193 |
194 | page_token_indices = pam.get_page_token_indices(lens)
195 |
196 | expected_page_indices = np.asarray([6, 7, 5])
197 | expected_token_indices = np.asarray([3, 4, 9])
198 | self.assertTrue(
199 | np.array_equal(page_token_indices[0], expected_page_indices)
200 | )
201 | self.assertTrue(
202 | np.array_equal(page_token_indices[1], expected_token_indices)
203 | )
204 |
205 |
206 | if __name__ == "__main__":
207 | unittest.main()
208 |
--------------------------------------------------------------------------------