├── .github └── workflows │ └── unit_tests.yaml ├── .gitignore ├── .gitmodules ├── .pylintrc ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── benchmarks ├── __init__.py ├── analyze_sharegpt.py ├── basic_ops.py ├── mixtral_offline.sh └── summary.md ├── default_shardings ├── gemma.yaml ├── llama-blockwise-quant.yaml ├── llama.yaml └── mixtral.yaml ├── docker └── jetstream-pytorch-server │ ├── Dockerfile │ ├── README.md │ └── jetstream_pytorch_server_entrypoint.sh ├── docs ├── add_a_new_model.md └── add_hf_checkpoint_conversion.md ├── install_everything.sh ├── install_everything_gpu.sh ├── jetstream_pt ├── __init__.py ├── attention_kernel.py ├── cache_manager.py ├── cli.py ├── config.py ├── engine.py ├── environment.py ├── fetch_models.py ├── gcs_to_cns.sh ├── hf_tokenizer.py ├── layers.py ├── model_base.py ├── page_attention_manager.py ├── quantize.py ├── quantize_model.py ├── ray_engine.py ├── ray_worker.py ├── third_party │ ├── gemma │ │ ├── __init__.py │ │ ├── config.py │ │ ├── model.py │ │ ├── model_original.py │ │ └── tokenizer.py │ ├── llama │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── generation_original.py │ │ ├── model_args.py │ │ ├── model_exportable.py │ │ ├── model_original.py │ │ ├── tokenizer.model │ │ └── tokenizer.py │ └── mixtral │ │ ├── __init__.py │ │ ├── config.py │ │ ├── model.py │ │ ├── model_original.py │ │ └── tokenizer.model └── torchjax.py ├── kuberay ├── image │ └── Dockerfile └── manifests │ ├── ray-cluster.tpu-v4-multihost.yaml │ ├── ray-cluster.tpu-v4-singlehost.yaml │ ├── ray-cluster.tpu-v5-multihost.yaml │ └── ray-cluster.tpu-v5-singlehost.yaml ├── mlperf ├── README.md ├── backend.py ├── benchmark_run.sh ├── dataset.py ├── install.sh ├── main.py ├── mlperf.conf ├── start_server.sh ├── user.conf └── warmup.py ├── poetry.lock ├── pyproject.toml ├── run_interactive_disaggregated.py ├── run_interactive_multiple_host.py ├── run_ray_serve_interleave.py ├── run_server_with_ray.py ├── scripts ├── create_empty_sharding_map.py ├── jax_experiments.py └── validate_hf_ckpt_conversion.py └── tests ├── .pylintrc ├── __init__.py ├── helpers.py ├── test_attention_kernal.py ├── test_engine.py ├── test_hf_names.py ├── test_jax_torch.py ├── test_kv_cache_manager.py ├── test_llama_e2e.py ├── test_model_impl.py ├── test_page_attention.py └── test_quantization.py /.github/workflows/unit_tests.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 16 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 17 | 18 | name: Unit Tests 19 | 20 | on: 21 | pull_request: 22 | push: 23 | branches: [ "main" ] 24 | workflow_dispatch: 25 | schedule: 26 | # Run the job every 12 hours 27 | - cron: '0 */12 * * *' 28 | 29 | jobs: 30 | py: 31 | name: "Python type/lint/format checks" 32 | strategy: 33 | matrix: 34 | os: [ubuntu-20.04] 35 | python-version: ['3.10'] 36 | runs-on: ${{ matrix.os }} 37 | steps: 38 | - name: Checkout 39 | uses: actions/checkout@v4 40 | - name: Setup Python 41 | uses: actions/setup-python@v4 42 | with: 43 | python-version: ${{ matrix.python-version }} 44 | - name: Install Dependencies 45 | run: | 46 | pip install pytype 47 | pip install pylint 48 | pip install pyink 49 | source install_everything.sh 50 | # - name: Typecheck the code with pytype 51 | # run: | 52 | # pytype --jobs auto --disable import-error --disable module-attr jetstream_pt/ 53 | - name: Analysing the code with pylint 54 | run: | 55 | pylint --indent-string=' ' jetstream_pt/ benchmarks/ 56 | - name: Format check with pyink 57 | run: | 58 | pyink --pyink-indentation 2 --line-length 80 --check --verbose --extend-exclude=deps . 59 | 60 | cpu: 61 | name: "jetstream_pt unit tests" 62 | strategy: 63 | matrix: 64 | os: [ubuntu-20.04] 65 | python-version: ['3.10'] 66 | runs-on: ${{ matrix.os }} 67 | steps: 68 | - name: Checkout 69 | uses: actions/checkout@v4 70 | - name: Setup Python 71 | uses: actions/setup-python@v4 72 | with: 73 | python-version: ${{ matrix.python-version }} 74 | - name: Install Dependencies 75 | run: | 76 | source install_everything.sh 77 | - name: Run all unit tests for jetstream_pt (/tests) 78 | run: | 79 | JAX_PLATFORMS=cpu coverage run -m unittest -v 80 | - name: Create test coverage report 81 | run: | 82 | coverage report -m -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # vscode 156 | .vscode/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
163 | #.idea/ 164 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "deps/JetStream"] 2 | path = deps/JetStream 3 | url = https://github.com/google/JetStream.git 4 | [submodule "deps/xla"] 5 | path = deps/xla 6 | url = https://github.com/pytorch/xla.git 7 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MESSAGES CONTROL] 2 | disable=C0114,R0801,R0903,R0913,R0917,E1102,W0613,R1711,too-many-locals 3 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | We'd love to accept your patches and contributions to this project. 4 | 5 | ## Before you begin 6 | 7 | ### Sign our Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a 10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA). 11 | You (or your employer) retain the copyright to your contribution; this simply 12 | gives us permission to use and redistribute your contributions as part of the 13 | project. 14 | 15 | If you or your current employer have already signed the Google CLA (even if it 16 | was for a different project), you probably don't need to do it again. 17 | 18 | Visit <https://cla.developers.google.com/> to see your current agreements or to 19 | sign a new one. 20 | 21 | ### Review our community guidelines 22 | 23 | This project follows 24 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 25 | 26 | ## Contribution process 27 | 28 | ### Code reviews 29 | 30 | All submissions, including submissions by project members, require review. We 31 | use GitHub pull requests for this purpose. Consult 32 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 33 | information on using pull requests. 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Jetstream-PyTorch 2 | JetStream Engine implementation in PyTorch 3 | 4 | # Latest Release: 5 | 6 | The latest release version is tagged with `jetstream-v0.2.3`. If you are running the release version, 7 | please follow the README of that version here: 8 | https://github.com/google/jetstream-pytorch/blob/jetstream-v0.2.3/README.md 9 | 10 | Command-line flags might have changed between the release version and HEAD. 11 | 12 | # Outline 13 | 14 | 1. Ssh to Cloud TPU VM (using v5e-8 TPU VM) 15 | a. Create a Cloud TPU VM if you haven’t 16 | 2. Download jetstream-pytorch github repo 17 | 3. Run the server 18 | 4. Run benchmarks 19 | 5. 
Typical Errors 20 | 21 | # Ssh to Cloud TPU VM (using v5e-8 TPU VM) 22 | 23 | ```bash 24 | gcloud compute config-ssh 25 | gcloud compute tpus tpu-vm ssh "your-tpu-vm" --project "your-project" --zone "your-project-zone" 26 | ``` 27 | ## Create a Cloud TPU VM in a GCP project if you haven’t 28 | Follow the steps in 29 | * https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm 30 | 31 | # Clone repo and install dependencies 32 | 33 | ## Get the jetstream-pytorch code 34 | ```bash 35 | git clone https://github.com/google/jetstream-pytorch.git 36 | git checkout jetstream-v0.2.4 37 | ``` 38 | 39 | (optional) Create a virtual env using `venv` or `conda` and activate it. 40 | 41 | ## Run the installation script: 42 | 43 | ```bash 44 | cd jetstream-pytorch 45 | source install_everything.sh 46 | ``` 47 | 48 | 49 | # Run jetstream pytorch 50 | 51 | ## List out supported models 52 | 53 | ``` 54 | jpt list 55 | ``` 56 | 57 | This will print out the list of supported models and variants: 58 | 59 | ``` 60 | meta-llama/Llama-2-7b-chat-hf 61 | meta-llama/Llama-2-7b-hf 62 | meta-llama/Llama-2-13b-chat-hf 63 | meta-llama/Llama-2-13b-hf 64 | meta-llama/Llama-2-70b-hf 65 | meta-llama/Llama-2-70b-chat-hf 66 | meta-llama/Meta-Llama-3-8B 67 | meta-llama/Meta-Llama-3-8B-Instruct 68 | meta-llama/Meta-Llama-3-70B 69 | meta-llama/Meta-Llama-3-70B-Instruct 70 | meta-llama/Llama-3.1-8B 71 | meta-llama/Llama-3.1-8B-Instruct 72 | meta-llama/Llama-3.2-1B 73 | meta-llama/Llama-3.2-1B-Instruct 74 | meta-llama/Llama-3.3-70B 75 | meta-llama/Llama-3.3-70B-Instruct 76 | google/gemma-2b 77 | google/gemma-2b-it 78 | google/gemma-7b 79 | google/gemma-7b-it 80 | mistralai/Mixtral-8x7B-v0.1 81 | mistralai/Mixtral-8x7B-Instruct-v0.1 82 | ``` 83 | 84 | To run the jetstream-pytorch server with one model: 85 | 86 | ``` 87 | jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct 88 | ``` 89 | 90 | If it's the first time you run this model, it will download weights from 91 | HuggingFace. 92 | 93 | HuggingFace's Llama3 weights are gated, so you need to either run 94 | `huggingface-cli login` to set your token, or pass your hf_token explicitly. 95 | 96 | To pass the hf token explicitly, add the `--hf_token` flag: 97 | ``` 98 | jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct --hf_token=... 99 | ``` 100 | 101 | To log in using the huggingface hub, run: 102 | 103 | ``` 104 | pip install -U "huggingface_hub[cli]" 105 | huggingface-cli login 106 | ``` 107 | Then follow its prompt. 108 | 109 | After the weights are downloaded, 110 | the `--hf_token` flag will no longer be required the next time you run the model. 111 | 112 | To run this model in `int8` quantization, add `--quantize_weights=1`. 113 | Quantization will be done on the fly as the weights load. 114 | 115 | Weights downloaded from HuggingFace will be stored by default in the `checkpoints` folder, 116 | in the place where `jpt` is executed. 117 | 118 | You can change where the weights are stored with the `--working_dir` flag. 119 | 120 | If you wish to use your own checkpoint, place it inside 121 | the `checkpoints/<org>/<model>/hf_original` dir (or the corresponding subdir in `--working_dir`). For example, 122 | Llama2 checkpoints will be at `checkpoints/meta-llama/Llama-2-7b-hf/hf_original/*.safetensors`. You can replace these files with modified 123 | weights in HuggingFace format. 124 | 125 | ## Send one request 126 | 127 | Jetstream-pytorch uses gRPC for handling requests; the script below demonstrates how to 128 | send a gRPC request in Python. You can also use other gRPC clients.
129 | 130 | ```python 131 | import requests 132 | import os 133 | import grpc 134 | 135 | from jetstream.core.proto import jetstream_pb2 136 | from jetstream.core.proto import jetstream_pb2_grpc 137 | 138 | prompt = "What are the top 5 languages?" 139 | 140 | channel = grpc.insecure_channel("localhost:8888") 141 | stub = jetstream_pb2_grpc.OrchestratorStub(channel) 142 | 143 | request = jetstream_pb2.DecodeRequest( 144 | text_content=jetstream_pb2.DecodeRequest.TextContent( 145 | text=prompt 146 | ), 147 | priority=0, 148 | max_tokens=2000, 149 | ) 150 | 151 | response = stub.Decode(request) 152 | output = [] 153 | for resp in response: 154 | output.extend(resp.stream_content.samples[0].text) 155 | 156 | text_output = "".join(output) 157 | print(f"Prompt: {prompt}") 158 | print(f"Response: {text_output}") 159 | ``` 160 | 161 | 162 | # Run the server with ray 163 | Below are the steps to run the server with ray: 164 | 1. Ssh to a Cloud multi-host TPU VM (v5e-16 TPU VM) 165 | 2. Follow step 2 to step 5 in the Outline 166 | 3. Set up the ray cluster 167 | 4. Run the server with ray 168 | 169 | ## Setup Ray Cluster 170 | Log in to the host 0 VM and start the ray head with the command below: 171 | 172 | ```bash 173 | 174 | ray start --head 175 | 176 | ``` 177 | 178 | Log in to the other host VMs and start ray worker nodes with the command below: 179 | 180 | ```bash 181 | 182 | ray start --address='$ip:$port' 183 | 184 | ``` 185 | 186 | Note: Get the IP address and port information from the ray head. 187 | 188 | ## Run server with ray 189 | 190 | Here is an example of running the server with ray for the llama2 7B model: 191 | 192 | ```bash 193 | export DISABLE_XLA2_PJRT_TEST="true" 194 | python run_server_with_ray.py --tpu_chips=16 --num_hosts=4 --worker_chips=4 --model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml" 195 | ``` 196 | 197 | # Run benchmark 198 | Start the server and then go to the deps/JetStream folder (downloaded during `install_everything.sh`). 199 | 200 | ```bash 201 | cd deps/JetStream 202 | wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json 203 | export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json 204 | python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs --warmup-mode=sampled --model=$model_name 205 | ``` 206 | Please look at `deps/JetStream/benchmarks/README.md` for more information. 207 | 208 | 209 | 210 | ## Run server with Ray Serve 211 | 212 | ### Prerequisites 213 | 214 | If running on GKE: 215 | 216 | 1. Follow instructions on [this link](https://github.com/GoogleCloudPlatform/ai-on-gke/tree/main/ray-on-gke/guides/tpu) to set up a GKE cluster and the TPU webhook. 217 | 2. Follow instructions 218 | [here](https://cloud.google.com/kubernetes-engine/docs/how-to/persistent-volumes/cloud-storage-fuse-csi-driver) 219 | to enable GCSFuse for your cluster. This will be needed to store the 220 | converted weights. 221 | 3. 
Deploy one of the sample Kuberay cluster configurations: 222 | ```bash 223 | kubectl apply -f kuberay/manifests/ray-cluster.tpu-v4-singlehost.yaml 224 | ``` 225 | or 226 | ```bash 227 | kubectl apply -f kuberay/manifests/ray-cluster.tpu-v4-multihost.yaml 228 | ``` 229 | 230 | 231 | ### Start a Ray Serve deployment 232 | 233 | Single-host (Llama2 7B): 234 | 235 | ```bash 236 | export RAY_ADDRESS=http://localhost:8265 237 | 238 | kubectl port-forward svc/example-cluster-kuberay-head-svc 8265:8265 & 239 | 240 | ray job submit --runtime-env-json='{"working_dir": "."}' -- python run_ray_serve_interleave.py --tpu_chips=4 --num_hosts=1 --size=7b --model_name=llama-2 --batch_size=32 --max_cache_length=2048 --tokenizer_path=/llama/tokenizer.model --checkpoint_path=/llama/ckpt --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml" 241 | ``` 242 | 243 | Multi-host (Llama2 70B): 244 | 245 | ```bash 246 | export RAY_ADDRESS=http://localhost:8265 247 | 248 | kubectl port-forward svc/example-cluster-kuberay-head-svc 8265:8265 & 249 | 250 | ray job submit --runtime-env-json='{"working_dir": "."}' -- python run_ray_serve_interleave.py --tpu_chips=8 --num_hosts=2 --size=70b --model_name=llama-2 --batch_size=8 --max_cache_length=2048 --tokenizer_path=/llama/tokenizer.model --checkpoint_path=/llama/ckpt --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml" 251 | ``` 252 | 253 | ### Sending an inference request 254 | 255 | Port-forward to port 8888 for gRPC: 256 | ``` 257 | kubectl port-forward svc/example-cluster-kuberay-head-svc 8888:8888 & 258 | ``` 259 | 260 | Sample python script: 261 | 262 | ```python 263 | import requests 264 | import os 265 | import grpc 266 | 267 | from jetstream.core.proto import jetstream_pb2 268 | from jetstream.core.proto import jetstream_pb2_grpc 269 | 270 | prompt = "What are the top 5 languages?" 271 | 272 | channel = grpc.insecure_channel("localhost:8888") 273 | stub = jetstream_pb2_grpc.OrchestratorStub(channel) 274 | 275 | request = jetstream_pb2.DecodeRequest( 276 | text_content=jetstream_pb2.DecodeRequest.TextContent( 277 | text=prompt 278 | ), 279 | priority=0, 280 | max_tokens=2000, 281 | ) 282 | 283 | response = stub.Decode(request) 284 | output = [] 285 | for resp in response: 286 | output.extend(resp.stream_content.samples[0].text) 287 | 288 | text_output = "".join(output) 289 | print(f"Prompt: {prompt}") 290 | print(f"Response: {text_output}") 291 | ``` 292 | 293 | 294 | 295 | # Typical Errors 296 | 297 | ## Unexpected keyword argument 'device' 298 | 299 | Fix: 300 | * Uninstall jax and jaxlib dependencies 301 | * Reinstall using `source install_everything.sh 302 | 303 | ## Out of memory 304 | 305 | Fix: 306 | * Use smaller batch size 307 | * Use quantization 308 | 309 | # Links 310 | 311 | ## JetStream 312 | * https://github.com/google/JetStream 313 | 314 | ## MaxText 315 | * https://github.com/google/maxtext 316 | 317 | 318 | -------------------------------------------------------------------------------- /benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /benchmarks/analyze_sharegpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | 17 | CUTOFF_INPUT = 1024 18 | CUTOFF_OUTPUT = 1024 19 | 20 | 21 | # pylint: disable-next=all 22 | def do_simulation( 23 | sharegpt_path, prefill_bucket_size_to_ms, system_time_per_decode_token_ms 24 | ): 25 | def next_power_of_2(x): 26 | return 1 if x == 0 else 2 ** (x - 1).bit_length() 27 | 28 | def tokens_in_input_str(s): 29 | return_val = int(1.3 * len(s.split())) 30 | # print(f"{s=} -> {return_val=}") 31 | return return_val 32 | 33 | convo_numbers = [] 34 | # Please update with your own data file path 35 | 36 | with open(sharegpt_path, "r", encoding="utf-8") as f: 37 | loaded_share_gpt = json.load(f) 38 | for example in loaded_share_gpt: 39 | if len(example["conversations"]) < 2: 40 | continue 41 | input_tokens = tokens_in_input_str(example["conversations"][0]["value"]) 42 | output_tokens = tokens_in_input_str(example["conversations"][1]["value"]) 43 | convo_numbers.append((input_tokens, output_tokens)) 44 | 45 | num_convos = len(convo_numbers) 46 | kept_convos = [ 47 | c for c in convo_numbers if c[0] <= CUTOFF_INPUT and c[1] <= CUTOFF_OUTPUT 48 | ] 49 | 50 | mean_input = sum(c[0] for c in kept_convos) / len(kept_convos) 51 | mean_output = sum(c[1] for c in kept_convos) / len(kept_convos) 52 | 53 | print( 54 | f"""Total {num_convos=} but only kept {kept_convos=}. 
55 | Out of kept, {mean_input=}, {mean_output=}""" 56 | ) 57 | 58 | total_prefill_system_ms = 0 59 | total_generate_system_ms = 0 60 | 61 | for convo in kept_convos: 62 | input_tok, output_tok = convo 63 | bucket = max(128, next_power_of_2(input_tok)) 64 | generate_system_ms = output_tok * system_time_per_decode_token_ms 65 | prefill_system_ms = prefill_bucket_size_to_ms[bucket] 66 | 67 | print( 68 | f"{convo=} {bucket=}, {prefill_system_ms=:.2f}, {generate_system_ms=:.2f}" 69 | ) 70 | 71 | total_prefill_system_ms += prefill_system_ms 72 | total_generate_system_ms += generate_system_ms 73 | 74 | total_time_ms = total_prefill_system_ms + total_generate_system_ms 75 | input_tokens = sum(c[0] for c in kept_convos) 76 | 77 | output_tokens = sum(c[1] for c in kept_convos) 78 | print( 79 | f"""Output tokens {output_tokens} in {total_time_ms/1000:.2f} seconds, 80 | for {output_tokens/(total_time_ms/1000):.2f} out tok/s""" 81 | ) 82 | 83 | total_prefill_sec = total_prefill_system_ms / 1000 84 | total_generate_sec = total_generate_system_ms / 1000 85 | 86 | print( 87 | f"""Total time {total_time_ms/1000:.2f} seconds, 88 | split {total_prefill_sec=:.2f} seconds and {total_generate_sec=:.2f} seconds""" 89 | ) 90 | 91 | idealized_prefill_sec = ( 92 | 1.1 * input_tokens / 1024 * prefill_bucket_size_to_ms[1024] / 1000 93 | ) 94 | 95 | prefill_savings_sec = total_prefill_sec - idealized_prefill_sec 96 | 97 | idealized_generate_sec = ( 98 | total_generate_sec / 2 99 | ) # (Roughly save 75% on KV cache high cost on the rest) 100 | generate_savings_sec = total_generate_sec - idealized_generate_sec 101 | 102 | print( 103 | f"""we think prefill will take {total_prefill_sec=:.2f}, 104 | we could get it to {idealized_prefill_sec=:.2f} so we'd 105 | save {prefill_savings_sec=:.2f} seconds """ 106 | ) 107 | print( 108 | f"""with sparsity we could go from {total_generate_sec=:.2f}, 109 | we could get it to {idealized_generate_sec=:.2f} so we'd save 110 | {generate_savings_sec=:.2f} seconds """ 111 | ) 112 | 113 | idealized_overall_time = idealized_generate_sec + idealized_prefill_sec 114 | 115 | print( 116 | f"""Idealized out tokens {output_tokens} in {idealized_overall_time:.2f} seconds, 117 | for {output_tokens/idealized_overall_time:.2f} out tok/s""" 118 | ) 119 | print("prfill", prefill_bucket_size_to_ms) 120 | print("decode step", system_time_per_decode_token_ms) 121 | -------------------------------------------------------------------------------- /benchmarks/basic_ops.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Tuple, Callable, List, Optional 3 | import time 4 | import dataclasses 5 | 6 | import numpy as np 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | from jax.experimental import mesh_utils, shard_map 11 | from jax.sharding import PositionalSharding 12 | 13 | 14 | from jax.sharding import Mesh 15 | from jax.sharding import PartitionSpec 16 | from jax.sharding import NamedSharding 17 | 18 | devices = jax.devices() 19 | P = PartitionSpec 20 | 21 | devices = mesh_utils.create_device_mesh((len(devices),)) 22 | mesh = Mesh(devices, axis_names=("x",)) 23 | # y = jax.device_put(x, NamedSharding(mesh, P('a', 'b'))) 24 | 25 | L = 1 << 15 26 | 27 | 28 | @dataclasses.dataclass 29 | class BenchmarkCase: 30 | """BenchmarkCase.""" 31 | 32 | name: str 33 | function: Callable 34 | args_shape: List[Tuple] 35 | args_sharding: List[PartitionSpec] 36 | profiler_output: Optional[str] = None 37 | 38 | 39 | start_key = jax.random.key(0) 40 
| 41 | 42 | def _new_arg(shape, dtype): 43 | global start_key # pylint: disable=all 44 | start_key, _ = jax.random.split(start_key) 45 | with jax.default_device(jax.devices("cpu")[0]): 46 | if dtype == jnp.int8.dtype: 47 | return jax.random.randint(start_key, shape, 0, 100, dtype=dtype) 48 | else: 49 | return jax.random.normal(start_key, shape, dtype=dtype) + 1 50 | 51 | 52 | def _new_args(case, dtype): 53 | args = [] 54 | for shape, sharding in zip(case.args_shape, case.args_sharding): 55 | arg = _new_arg(shape, dtype) 56 | if sharding is not None: 57 | arg = jax.device_put(arg, NamedSharding(mesh, sharding)) 58 | args.append(arg) 59 | return args 60 | 61 | 62 | def _run_case(case, warmup=2, runtimes=5, dtype=jnp.bfloat16.dtype): 63 | for _ in range(warmup): 64 | args = _new_args(case, dtype) 65 | case.function(*args) 66 | 67 | stamps = [] 68 | for i in range(runtimes): 69 | args = _new_args(case, dtype) 70 | jax.block_until_ready(args) 71 | if case.profiler_output is not None and i == (runtimes - 1): 72 | jax.profiler.start_trace(case.profiler_output) 73 | start = time.perf_counter() 74 | jax.block_until_ready(case.function(*args)) 75 | end = time.perf_counter() 76 | if case.profiler_output is not None and i == (runtimes - 1): 77 | jax.profiler.stop_trace() 78 | stamps.append(end - start) 79 | return sum(stamps) / runtimes 80 | 81 | 82 | def _llama_ffn(x, w1, w2, w3): 83 | w1_res = jax.nn.silu((x @ w1).astype(jnp.bfloat16.dtype)) 84 | w3_res = x @ w3 85 | res = (w1_res * w3_res) @ w2 86 | return res 87 | 88 | 89 | @jax.jit 90 | @functools.partial( 91 | shard_map.shard_map, 92 | mesh=mesh, 93 | in_specs=(P(), P(None, "x"), P("x"), P(None, "x")), 94 | out_specs=(P()), 95 | ) 96 | def _llama_ffn_shmap(x, w1, w2, w3): 97 | for _ in range(3): 98 | x = _llama_ffn(x, w1, w2, w3) 99 | x = jax.lax.psum(x, "x") 100 | return x 101 | 102 | 103 | @jax.jit 104 | def _llama_ffn_spmd(x, w1, w2, w3): 105 | for _ in range(3): 106 | x = _llama_ffn(x, w1, w2, w3) 107 | x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P())) 108 | return x 109 | 110 | 111 | dim = 4096 112 | multiple_of = 256 113 | # hidden_dim = 4 * dim 114 | # hidden_dim = int(2 * hidden_dim / 3) 115 | # hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 116 | hidden_dim = 11008 117 | BATCH = 1024 118 | 119 | 120 | @jax.jit 121 | @functools.partial( 122 | shard_map.shard_map, 123 | mesh=mesh, 124 | in_specs=(P("x"),), 125 | out_specs=(P()), 126 | check_rep=False, 127 | ) 128 | def _all_gather(x): 129 | return jax.lax.all_gather(x, "x") 130 | 131 | 132 | @jax.jit 133 | @functools.partial( 134 | shard_map.shard_map, mesh=mesh, in_specs=(P("x"),), out_specs=(P()) 135 | ) 136 | def _all_reduce(x): 137 | return jax.lax.psum(x, "x") 138 | 139 | 140 | allcases = [ 141 | BenchmarkCase( 142 | name="Matmul replicated", 143 | function=jax.jit(jnp.matmul), 144 | args_shape=((L, L), (L, L)), 145 | args_sharding=(P(), P()), # replicated 146 | ), 147 | BenchmarkCase( 148 | name="Matmul sharded colrow", 149 | function=jax.jit(jnp.matmul), 150 | args_shape=((L, L), (L, L)), 151 | args_sharding=(P(None, "x"), P("x")), # replicated 152 | ), 153 | BenchmarkCase( 154 | name="matmul sharded rowcol", 155 | function=jax.jit(jnp.matmul), 156 | args_shape=((L, L), (L, L)), 157 | args_sharding=(P("x"), P("x", None)), # replicated 158 | ), 159 | BenchmarkCase( 160 | name="all_gather", 161 | function=_all_gather, 162 | args_shape=((L, L),), 163 | args_sharding=(P("x"),), # replicated 164 | ), 165 | BenchmarkCase( 166 | 
name="all_reduce", 167 | function=_all_reduce, 168 | args_shape=((L, L),), 169 | args_sharding=(P("x"),), # replicated 170 | ), 171 | BenchmarkCase( 172 | name="Llama 3xffn shardmap", 173 | function=_llama_ffn_shmap, 174 | args_shape=( 175 | (BATCH, dim), 176 | (dim, hidden_dim), 177 | (hidden_dim, dim), 178 | (dim, hidden_dim), 179 | ), 180 | args_sharding=(P(), P(None, "x"), P("x"), P(None, "x")), 181 | ), 182 | BenchmarkCase( 183 | name="Llama 3xffn gspmd", 184 | function=_llama_ffn_spmd, 185 | args_shape=( 186 | (BATCH, dim), 187 | (dim, hidden_dim), 188 | (hidden_dim, dim), 189 | (dim, hidden_dim), 190 | ), 191 | args_sharding=(P(), P(None, "x"), P("x"), P(None, "x")), 192 | ), 193 | ] 194 | 195 | 196 | def _run_call_cases(cases): 197 | for dtype in (jnp.bfloat16.dtype, jnp.int8.dtype): 198 | for case in cases: 199 | avg = _run_case(case, dtype=dtype) 200 | dtype_size = 2 if dtype == jnp.bfloat16.dtype else 1 201 | input_sizes = tuple( 202 | [ 203 | f"{np.prod(size) * dtype_size / (1<<20) :.6} MiB" 204 | for size in case.args_shape 205 | ] 206 | ) 207 | print( 208 | f"{dtype} \t {case.name}: \t{avg * 1000 :.6} ms \t sizes: {input_sizes}" 209 | ) 210 | 211 | 212 | def main(): 213 | print("Number of devices: ", len(devices)) 214 | _run_call_cases(allcases) 215 | 216 | 217 | if __name__ == "__main__": 218 | main() 219 | -------------------------------------------------------------------------------- /benchmarks/mixtral_offline.sh: -------------------------------------------------------------------------------- 1 | CACHE_LENGTH=1024 2 | INPUT_SIZE=512 3 | OUTPUT_SIZE=1024 4 | BATCH_SIZE=512 5 | CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/ 6 | 7 | pushd .. 8 | python -m benchmarks.run_offline \ 9 | --model_name=mixtral \ 10 | --batch_size=$BATCH_SIZE \ 11 | --max_cache_length=$CACHE_LENGTH \ 12 | --max_decode_length=$OUTPUT_SIZE \ 13 | --context_length=$INPUT_SIZE \ 14 | --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \ 15 | --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \ 16 | --quantize_weights=1 \ 17 | --quantize_type=int8_per_channel \ 18 | --quantize_kv_cache=1 \ 19 | --profiling_output=/mnt/disks/hanq/mixtral-profiles 20 | popd -------------------------------------------------------------------------------- /benchmarks/summary.md: -------------------------------------------------------------------------------- 1 | # Benchmark results of various models 2 | 3 | 4 | ## Llama 3 - 8B 5 | 6 | Date | Device | dtype | batch size | cache length |max input length |max output length| throughput (token/s) 7 | ----| ------- | ------ |---------- | -------------|-----------------|------------------|---------------------- 8 | 2024-04-24 | TPU v5e-8 | bfloat16 | 128 | 2048 | 1024 | 1024 | 8249 9 | 2024-04-24 | TPU v5e-8 | int8 | 256 | 2048 | 1024 | 1024 | 10873 10 | 2024-07-29 | TPU v5e-8 | int8 | 256 | 2048 | 1024 | 1024 | 8471.54 11 | 12 | **NOTE:(2024-07-29)** Looks like we have a regression in the past 3 month. We are working in fixing it. 
13 | 14 | 15 | ## Gemma - 7B 16 | 17 | Date | Device | dtype | batch size | cache length |max input length |max output length| throughput (token/s) 18 | ----| ------- | ------ |---------- | -------------|-----------------|------------------|---------------------- 19 | 2024-05-10 | TPU v5e-8 | bfloat16 | 96 | 2048 | 1024 | 1024 | 3236 20 | 2024-05-10 | TPU v5e-8 | int8 | 128 | 2048 | 1024 | 1024 | 4695 21 | 22 | ## Gemma - 2B 23 | 24 | Date | Device | dtype | batch size | cache length |max input length |max output length| throughput (token/s) 25 | ----| ------- | ------ |---------- | -------------|-----------------|------------------|---------------------- 26 | 2024-05-14 | TPU v5e-8 | bfloat16 | 512 | 2048 | 1024 | 1024 | 8700 27 | 2024-05-14 | TPU v5e-8 | int8 | 1024 | 2048 | 1024 | 1024 | 8746 28 | 2024-06-13 | TPU v5e-1 | bfloat16 | 1024 | 2048 | 1024 | 1024 | 4249 29 | 30 | 31 | **NOTE:** Gemma 2B uses the `--shard_on_batch` flag, so it's data parallel instead 32 | of model parallel. 33 | 34 | 35 | ## Llama 2 - 7B 36 | 37 | Date | Device | dtype | batch size | cache length |max input length |max output length| throughput (token/s) 38 | ----| ------- | ------ |---------- | -------------|-----------------|------------------|---------------------- 39 | 2024-03-28 | TPU v5e-8 | bfloat16 | 96 | 2048 | 1024 | 1024 | 3663 40 | 2024-03-28 | TPU v5e-8 | int8 | 96 | 2048 | 1024 | 1024 | 4783 41 | 42 | ## Llama 2 - 13B 43 | 44 | Date | Device | dtype | batch size | cache length |max input length |max output length| throughput (token/s) 45 | ----| ------- | ------ |---------- | -------------|-----------------|------------------|---------------------- 46 | 2024-03-28 | TPU v5e-8 | bfloat16 | 48 | 2048 | 1024 | 1024 | 2056 47 | 2024-03-28 | TPU v5e-8 | int8 | 96 | 2048 | 1024 | 1024 | 3458 48 | 2024-03-28 | TPU v5e-8 | bfloat16 | 80 | 1280 | 1024 | 1024 | 2911 49 | 2024-03-28 | TPU v5e-8 | int8 | 96 | 1280 | 1024 | 1024 | 3938 50 | 51 | **NOTE:** When the cache length is less than the sum of max input length + max output length, 52 | we employ *Rolling window attention*. 53 | 54 | 55 | # Instructions to reproduce: 56 | 57 | Please refer to [README.md](README.md) for instructions on how to get the model weights. 58 | 59 | **NOTE** Different weights can produce different benchmark results (due to generating 60 | different sentence lengths). For llama, we used the `-chat` versions of the weights. 61 | For Gemma we used the `-it` (instruction finetuned) version of the weights.
62 | 63 | ## Run the server 64 | NOTE: the `--platform=tpu=8` flag needs to specify the number of TPU devices (which is 4 for v4-8 and 8 for v5light-8) 65 | 66 | ```bash 67 | python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8 --model=$model_name 68 | ``` 69 | Now you can send gRPC requests to it. 70 | 71 | # Run benchmark 72 | Go to the deps/JetStream folder (downloaded during `install_everything.sh`) 73 | 74 | ```bash 75 | cd deps/JetStream 76 | wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json 77 | export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json 78 | python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs --warmup-first=True 79 | ``` 80 | Please look at `deps/JetStream/benchmarks/README.md` for more information. 81 | -------------------------------------------------------------------------------- /default_shardings/gemma.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Sharding config for gemma 3 | # "replicated" to signify "replicated". 4 | # Integer signify axis to shard: 0 <= shard axis < rank 5 | 6 | freqs_cis : -1 # torch.complex64 (16384, 128) 7 | layers.*.self_attn.o_proj.weight: 1 8 | layers.*.self_attn.wq.weight : 0 # -1, 1] # torch.float32 (2048, 2048) 9 | layers.*.self_attn.wk.weight : 0 # -1, 1] # torch.float32 (256, 2048) 10 | layers.*.self_attn.wv.weight : 0 # -1, 1] # torch.float32 (256, 2048) 11 | layers.*.mlp.gate_proj.weight : 0 # -1, 1] # torch.float32 (16384, 2048) 12 | layers.*.mlp.gate_proj.bias : 0 # -1] # torch.float32 (16384,) 13 | layers.*.mlp.up_proj.weight : 0 # -1, 1] # torch.float32 (16384, 2048) 14 | layers.*.mlp.up_proj.bias : 0 # -1] # torch.float32 (16384,) 15 | layers.*.mlp.down_proj.weight : 1 # 1, -1] # torch.float32 (2048, 16384) 16 | layers.*.mlp.down_proj.bias : -1 # torch.float32 (2048,) 17 | layers.*.input_layernorm.weight : -1 # torch.float32 (2048,) 18 | layers.*.post_attention_layernorm.weight : -1 # torch.float32 (2048,) 19 | norm.weight : -1 # torch.float32 (2048,) 20 | embedder.weight : 1 # # 1, -1] # torch.float32 (256000, 2048) 21 | embedder.weight_scaler : 0 22 | layers.*.self_attn.o_proj.weight_scaler: 0 23 | layers.*.self_attn.wq.weight_scaler : 0 24 | layers.*.self_attn.wk.weight_scaler : 0 25 | layers.*.self_attn.wv.weight_scaler : 0 26 | layers.*.mlp.gate_proj.weight_scaler : 0 27 | layers.*.mlp.up_proj.weight_scaler : 0 28 | layers.*.mlp.down_proj.weight_scaler : 0 29 | -------------------------------------------------------------------------------- /default_shardings/llama-blockwise-quant.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Sharding config for llama-2 (With blockwise quantized linear layers) 3 | # Sharding should either be an int between 0 and rank - 1 4 | # signifying the axis to shard or -1 / null signifying replicated 5 | 6 | 7 | freqs_cis : -1 # torch.complex64 (2048, 64) 8 | tok_embeddings.weight : 1 # torch.int8 (32000, 4096) 9 | tok_embeddings.weight_scaler : 0 # torch.bfloat16 (4096,) 10 | layers.*.attention.wo.weight : 2 # torch.int8 (32, 128, 4096) 11 | layers.*.attention.wo.weight_scaler : 1 # torch.bfloat16 (32, 4096) 12 | layers.*.attention.wq.weight : 0 # 
torch.int8 (32, 128, 4096) 13 | layers.*.attention.wq.weight_scaler : 0 # torch.bfloat16 (32, 4096) 14 | layers.*.attention.wk.weight : 0 # torch.int8 (32, 128, 4096) 15 | layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (32, 4096) 16 | layers.*.attention.wv.weight : 0 # torch.int8 (32, 128, 4096) 17 | layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (32, 4096) 18 | layers.*.feed_forward.w1.weight : 0 # torch.int8 (32, 128, 11008) 19 | layers.*.feed_forward.w1.weight_scaler : 0 # torch.bfloat16 (32, 11008) 20 | layers.*.feed_forward.w2.weight : 2 # torch.int8 (86, 128, 4096) 21 | layers.*.feed_forward.w2.weight_scaler : 1 # torch.bfloat16 (86, 4096) 22 | layers.*.feed_forward.w3.weight : 0 # torch.int8 (32, 128, 11008) 23 | layers.*.feed_forward.w3.weight_scaler : 0 # torch.bfloat16 (32, 11008) 24 | layers.*.attention_norm.weight : -1 # torch.float32 (4096,) 25 | layers.*.ffn_norm.weight : -1 # torch.float32 (4096,) 26 | norm.weight : -1 # torch.float32 (4096,) 27 | output.weight : 0 # torch.int8 (32, 128, 32000) 28 | output.weight_scaler : 0 # torch.float32 (32, 32000) 29 | -------------------------------------------------------------------------------- /default_shardings/llama.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Sharding config for llama-2 3 | # Sharding should either be an int between 0 and rank - 1 4 | # signifying the axis to shard or -1 / null signifying replicated 5 | 6 | 7 | freqs_cis : -1 # torch.complex64 (2048, 64) 8 | tok_embeddings.weight : 1 # torch.float32 (vocab_size, 4096) 9 | tok_embeddings.weight_scaler : 0 # torch.bfloat16 (4096,) 10 | layers.*.attention.wo.weight : 1 # torch.int8 (4096, 4096) 11 | layers.*.attention.wo.weight_scaler : 0 # torch.bfloat16 (4096,) 12 | layers.*.attention.wq.weight : 0 # torch.int8 (4096, 4096) 13 | layers.*.attention.wq.weight_scaler : 0 # torch.bfloat16 (4096,) 14 | layers.*.attention.wk.weight : 0 # torch.int8 (4096, 4096) 15 | layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (4096,) 16 | layers.*.attention.wv.weight : 0 # torch.int8 (4096, 4096) 17 | layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (4096,) 18 | layers.*.feed_forward.w1.weight : 0 # torch.float32 (11008, 4096) 19 | layers.*.feed_forward.w1.weight_scaler : 0 # torch.bfloat16 (4096,) 20 | layers.*.feed_forward.w2.weight : 1 # torch.float32 (4096, 11008) 21 | layers.*.feed_forward.w2.weight_scaler : 0 # torch.bfloat16 (11008,) 22 | layers.*.feed_forward.w3.weight : 0 # torch.float32 (11008, 4096) 23 | layers.*.feed_forward.w3.weight_scaler : 0 # torch.bfloat16 (4096,) 24 | layers.*.attention_norm.weight : -1 # torch.float32 (4096,) 25 | layers.*.ffn_norm.weight : -1 # torch.float32 (4096,) 26 | norm.weight : -1 # torch.float32 (4096,) 27 | output.weight : 0 # torch.float32 (vocab_size, 4096) 28 | output.weight_scaler : 0 # torch.float32 (4096,) 29 | -------------------------------------------------------------------------------- /default_shardings/mixtral.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Sharding config for mixtral 3 | # Sharding should either be an int between 0 and rank - 1 4 | # signifying the axis to shard or -1 / null signifying replicated 5 | 6 | 7 | freqs_cis : -1 # torch.complex64 (2048, 64) 8 | tok_embeddings.weight : 1 # torch.float32 (vocab_size, 4096) 9 | tok_embeddings.weight_scaler : 0 # torch.bfloat16 (4096,) 10 | layers.*.attention.wo.weight : 1 # torch.int8 (4096, 4096) 11 | 
layers.*.attention.wo.weight_scaler : -1 # torch.bfloat16 (4096,) 12 | layers.*.attention.wq.weight : 0 # torch.int8 (4096, 4096) 13 | layers.*.attention.wq.weight_scaler : 0 # torch.bfloat16 (4096,) 14 | layers.*.attention.wk.weight : 0 # torch.int8 (4096, 4096) 15 | layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (4096,) 16 | layers.*.attention.wv.weight : 0 # torch.int8 (4096, 4096) 17 | layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (4096,) 18 | layers.*.attention.wqkv.weight : 0 # torch.int8 (4096, 4096) 19 | layers.*.attention.wqkv.weight_scaler : 0 # torch.bfloat16 (4096,) 20 | layers.*.block_sparse_moe.gate.weight: -1 21 | layers.*.block_sparse_moe.gate.weight_scaler: -1 22 | layers.*.block_sparse_moe.cond_ffn.w1: 1 23 | layers.*.block_sparse_moe.cond_ffn.w1_scaler: 1 24 | layers.*.block_sparse_moe.cond_ffn.w2: 2 25 | layers.*.block_sparse_moe.cond_ffn.w2_scaler: -1 26 | layers.*.block_sparse_moe.cond_ffn.w3: 1 27 | layers.*.block_sparse_moe.cond_ffn.w3_scaler: 1 28 | layers.*.ffn_norm.weight : -1 # torch.float32 (4096,) 29 | layers.*.attention_norm.weight : -1 # torch.float32 (4096,) 30 | norm.weight : -1 # torch.float32 (4096,) 31 | output.weight : 0 # torch.float32 (vocab_size, 4096) 32 | output.weight_scaler : 0 # torch.float32 (4096,) 33 | -------------------------------------------------------------------------------- /docker/jetstream-pytorch-server/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Ubuntu:22.04 16 | # Use Ubuntu 22.04 from Docker Hub. 
17 | # https://hub.docker.com/_/ubuntu/tags?page=1&name=22.04 18 | FROM ubuntu:22.04 19 | 20 | ENV DEBIAN_FRONTEND=noninteractive 21 | ENV PYTORCH_JETSTREAM_VERSION=main 22 | 23 | RUN apt -y update && apt install -y --no-install-recommends \ 24 | ca-certificates \ 25 | git \ 26 | python3.10 \ 27 | python3-pip 28 | 29 | RUN python3 -m pip install --upgrade pip 30 | 31 | RUN update-alternatives --install \ 32 | /usr/bin/python3 python3 /usr/bin/python3.10 1 33 | 34 | RUN git clone https://github.com/AI-Hypercomputer/jetstream-pytorch.git && \ 35 | cd /jetstream-pytorch && \ 36 | git checkout ${PYTORCH_JETSTREAM_VERSION} && \ 37 | bash install_everything.sh 38 | 39 | RUN pip install -U jax[tpu]==0.4.34 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 40 | 41 | COPY jetstream_pytorch_server_entrypoint.sh /usr/bin/ 42 | 43 | RUN chmod +x /usr/bin/jetstream_pytorch_server_entrypoint.sh 44 | 45 | ENTRYPOINT ["/usr/bin/jetstream_pytorch_server_entrypoint.sh"] -------------------------------------------------------------------------------- /docker/jetstream-pytorch-server/README.md: -------------------------------------------------------------------------------- 1 | ## Build and upload JetStream PyTorch Server image 2 | 3 | These instructions are to build the JetStream PyTorch Server image, which calls an entrypoint script that invokes the [JetStream](https://github.com/AI-Hypercomputer/JetStream) inference server with the JetStream-PyTorch framework. 4 | 5 | ``` 6 | docker build -t jetstream-pytorch-server . 7 | docker tag jetstream-pytorch-server us-docker.pkg.dev/${PROJECT_ID}/jetstream/jetstream-pytorch-server:latest 8 | docker push us-docker.pkg.dev/${PROJECT_ID}/jetstream/jetstream-pytorch-server:latest 9 | ``` 10 | 11 | If you would like to change the version of JetStream-PyTorch the image is built off of, change the `PYTORCH_JETSTREAM_VERSION` environment variable: 12 | ``` 13 | ENV PYTORCH_JETSTREAM_VERSION= 14 | ``` -------------------------------------------------------------------------------- /docker/jetstream-pytorch-server/jetstream_pytorch_server_entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2024 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | export HUGGINGFACE_TOKEN_DIR="/huggingface" 18 | 19 | cd /jetstream-pytorch 20 | huggingface-cli login --token $(cat ${HUGGINGFACE_TOKEN_DIR}/HUGGINGFACE_TOKEN) 21 | jpt serve $@ -------------------------------------------------------------------------------- /docs/add_hf_checkpoint_conversion.md: -------------------------------------------------------------------------------- 1 | # Guide on adding HuggingFace checkpoint conversion support 2 | 3 | ## Prerequisites: 4 | The model implementation has been added in JetStream-pt 5 | The checkpoint conversion from a certain format is already supported.
(Or no conversion is needed for the checkpoint) 6 | 7 | Please check this [guide](https://github.com/google/jetstream-pytorch/blob/main/docs/add_a_new_model.md) for adding a new model. 8 | 9 | ## Use case: 10 | The user has the checkpoint for the same model architecture in another format (e.g. HF format for LLaMA model). And want to have JetStream-pt support this checkpoint format. 11 | 12 | ## Guide 13 | 14 | Converting a public checkpoint to JetStream-pt format is mostly about finding the weight key mapping between the public checkpoint and JetStream model implementation. Besides the name mapping, the layout of the weights might be different among different checkpoint formats (e.g. Weight interleaved differently due to difference in Rotary Embedding implementation). These differences are model and checkpoint format specific. 15 | 16 | **Note** The model code and checkpoint format can be different from model to model, the following guide demonstrate a general guide, specific models may require additional effort for the checkpoint conversion support. 17 | 18 | The checkpoint conversion logic in the checkpoint conversion script. 19 | 20 | ### Step 1 Find the HuggingFace checkpoint you want to convert 21 | In this example, let’s use meta-llama/llama-2 7B as an example 22 | 23 | You can download the checkpoints to a local folder using 24 | huggingface-cli download meta-llama/Llama-2-7b-hf --local-dir Llama-2-7b-hf 25 | 26 | 27 | **Note** You may need to go to Huggingface website to sign an agreement to get the permission to download the model 28 | 29 | ### Step 2 Inspect the weight names in the checkpoint: 30 | 31 | Usually there is a model.safetensors.index.json file in the checkpoint. [example](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/model.safetensors.index.json) 32 | 33 | Alternatively, you can load the weights locally and inspect the model key names(Usually it’s in safetensor format, and it’s sharded) 34 | 35 | Example script: 36 | ```Python 37 | import glob 38 | import os 39 | import torch 40 | from safetensors import safe_open 41 | 42 | checkpoint_folder = "/mnt/disks/lsiyuan/llama_weight/Meta-Llama-3-8B-Instruct" 43 | 44 | safetensor_files = glob.glob(os.path.join(checkpoint_folder, "*.safetensors")) 45 | 46 | for st_f in safetensor_files: 47 | with safe_open(st_f, framework="pt", device="cpu") as f: 48 | for key in f.keys(): 49 | weight_tensor = f.get_tensor(key) 50 | print(f"Weight name {key}, Shape: {weight_tensor.shape}, dtype: {weight_tensor.dtype}") 51 | ``` 52 | 53 | Got the following output: 54 | 55 | ``` 56 | lm_head.weight torch.Size([32000, 4096]) x torch.float16 57 | model.norm.weight torch.Size([4096]) x torch.float16 58 | model.embed_tokens.weight torch.Size([32000, 4096]) x torch.float16 59 | model.layers.0.input_layernorm.weight torch.Size([4096]) x torch.float16 60 | model.layers.0.mlp.down_proj.weight torch.Size([4096, 11008]) x torch.float16 61 | model.layers.0.mlp.gate_proj.weight torch.Size([11008, 4096]) x torch.float16 62 | model.layers.0.mlp.up_proj.weight torch.Size([11008, 4096]) x torch.float16 63 | model.layers.0.post_attention_layernorm.weight torch.Size([4096]) x torch.float16 64 | model.layers.0.self_attn.k_proj.weight torch.Size([4096, 4096]) x torch.float16 65 | model.layers.0.self_attn.o_proj.weight torch.Size([4096, 4096]) x torch.float16 66 | model.layers.0.self_attn.q_proj.weight torch.Size([4096, 4096]) x torch.float16 67 | model.layers.0.self_attn.rotary_emb.inv_freq torch.Size([64]) x torch.float32 68 | 
model.layers.0.self_attn.v_proj.weight torch.Size([4096, 4096]) x torch.float16 69 | … # Duplicated name for model.layers.x 70 | ``` 71 | 72 | If it’s hard to tell which layer the weight is for, the HF model class can be checked in the checkpoint config file [example](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L4). Then we can find the model code in the transformer repo by searching the model class name [model code](https://github.com/huggingface/transformers/blob/bdf36dcd48106a4a0278ed7f3cc26cd65ab7b066/src/transformers/models/llama/modeling_llama.py#L1084) 73 | 74 | 75 | ### Step 3 Inspect the weight names in JetStream-pt model implementation: 76 | 77 | Run the model in JetStream using benchmarks/run_offline.py. The weight names, shape and dtype will be printed in the log (Omitting Layer N which are duplicated names) 78 | 79 | Example: 80 | 81 | ``` 82 | Name: freqs_cis, shape: (2048, 64) x complex64 83 | Name: tok_embeddings.weight, shape: (32000, 4096) x bfloat16 84 | Name: layers.0.attention.wo.weight, shape: (4096, 4096) x bfloat16 85 | Name: layers.0.attention.wq.weight, shape: (4096, 4096) x bfloat16 86 | Name: layers.0.attention.wk.weight, shape: (4096, 4096) x bfloat16 87 | Name: layers.0.attention.wv.weight, shape: (4096, 4096) x bfloat16 88 | Name: layers.0.feed_forward.w1.weight, shape: (11008, 4096) x bfloat16 89 | Name: layers.0.feed_forward.w2.weight, shape: (4096, 11008) x bfloat16 90 | Name: layers.0.feed_forward.w3.weight, shape: (11008, 4096) x bfloat16 91 | Name: layers.0.attention_norm.weight, shape: (4096,) x bfloat16 92 | Name: layers.0.ffn_norm.weight, shape: (4096,) x bfloat16 93 | Name: norm.weight, shape: (4096,) x bfloat16 94 | Name: output.weight, shape: (32000, 4096) x bfloat16 95 | ``` 96 | 97 | If it’s hard to tell which layer the weight is for, you can find out the meaning of the weight, please check the model implementation under jetstream_pt/third_party. 98 | 99 | ### Step 4 By comparing the weight names, or diving into the model code, we can find out the mapping: 100 | 101 | In this example: 102 | 103 | HF lm_head.weight -> JetStream-pt output.weight 104 | HF model.norm.weight -> JetStream-pt norm.weight 105 | HF model.embed_tokens.weight -> JetStream-pt tok_embeddings.weight 106 | HF model.layers.X.input_layernorm.weight -> layers.X.attention_norm.weight 107 | HF model.layers.0.post_attention_layernorm.weight -> layers.0.ffn_norm.weight 108 | HF model.layers.X.self_attn.{q/k/v/o}_proj.weight -> layers.X.attention.w{q/k/v/o}.weight 109 | HF model.layers.X.mlp.gate_proj.weight -> layers.X.feed_forward.w1.weight 110 | HF model.layers.X.mlp.down_proj.weight -> layers.X.feed_forward.w2.weight 111 | HF model.layers.X.mlp.up_proj.weight -> layers.X.feed_forward.w3.weight 112 | freqs_cis is a special case, in JetStream PyTorch, the weight is pre-computed during weight loading, so no need to map the Huggingface freq weight over. 113 | 114 | ### Step 5 Validate the converted checkpoint: 115 | 116 | If there is a checkpoint in already supported format, convert the checkpoint in supported format first, as the golden data to compare with the converted checkpoint from the new format. 117 | 118 | Write a small script, or reuse the [script](https://github.com/google/jetstream-pytorch/blob/main/scripts/validate_hf_ckpt_conversion.py) to compare the 2 converted checkpoints. 119 | 120 | Fix the difference between 2 converted checkpoints if there is any. 
(This will be model and checkpoint format specific) 121 | 122 | ### Step 6 End-to-end validation: From checkpoint conversion to serving 123 | 124 | Example 125 | 126 | ``` 127 | export input_ckpt_dir=/mnt/disks/lsiyuan/llama_weight/7B-FT-chat 128 | export output_ckpt_dir=/mnt/disks/lsiyuan/llama_weight/hf_llama_2_7b_converted_bf16_2 129 | export model_name="llama" 130 | export from_hf=True 131 | python -m convert_checkpoints --model_name=$model_name \ 132 | --input_checkpoint_dir=$input_ckpt_dir \ 133 | --output_checkpoint_dir=$output_ckpt_dir \ 134 | --quantize_weights=$quantize_weights \ 135 | --quantize_type=$quantize_type \ 136 | --from_hf=True 137 | ``` -------------------------------------------------------------------------------- /install_everything.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Uninstall existing jax 16 | pip show jax && pip uninstall -y jax 17 | pip show jaxlib && pip uninstall -y jaxlib 18 | pip show libtpu-nightly && pip uninstall -y libtpu-nightly 19 | pip show tensorflow && pip uninstall -y tensorflow 20 | pip show ray && pip uninstall -y ray 21 | pip show flax && pip uninstall -y flax 22 | pip show keras && pip uninstall -y keras 23 | pip show tensorboard && pip uninstall -y tensorboard 24 | pip show tensorflow-text && pip uninstall -y tensorflow-text 25 | pip show torch_xla2 && pip uninstall -y torch_xla2 26 | 27 | pip install flax 28 | pip install tensorflow-text 29 | pip install tensorflow 30 | pip install huggingface_hub 31 | pip install transformers 32 | 33 | pip install ray[default]==2.33.0 34 | # torch cpu 35 | pip install torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu 36 | pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage 37 | pip install safetensors colorama coverage humanize 38 | 39 | git submodule update --init --recursive 40 | pip show google-jetstream && pip uninstall -y google-jetstream 41 | pip show torch_xla2 && pip uninstall -y torch_xla2 42 | pip install -e . 43 | pip install -U jax[tpu]==0.4.37 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 44 | pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu 45 | -------------------------------------------------------------------------------- /install_everything_gpu.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Uninstall existing jax 16 | pip show jax && pip uninstall -y jax 17 | pip show jaxlib && pip uninstall -y jaxlib 18 | pip show libtpu-nightly && pip uninstall -y libtpu-nightly 19 | pip show tensorflow && pip uninstall -y tensorflow 20 | pip show ray && pip uninstall -y ray 21 | pip show flax && pip uninstall -y flax 22 | pip show keras && pip uninstall -y keras 23 | pip show tensorboard && pip uninstall -y tensorboard 24 | pip show tensorflow-text && pip uninstall -y tensorflow-text 25 | pip show torch_xla2 && pip uninstall -y torch_xla2 26 | 27 | pip install flax==0.8.4 28 | pip install tensorflow-text 29 | pip install tensorflow 30 | pip install transformers 31 | 32 | pip install ray[default]==2.22.0 33 | # torch cpu 34 | pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage 35 | pip install safetensors colorama coverage humanize 36 | 37 | git submodule update --init --recursive 38 | pip show google-jetstream && pip uninstall -y google-jetstream 39 | pip show torch_xla2 && pip uninstall -y torch_xla2 40 | pip install -e . 41 | pip install -U jax[cuda12]==0.4.30 42 | pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu 43 | -------------------------------------------------------------------------------- /jetstream_pt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jetstream_pt.engine import create_pytorch_engine 16 | 17 | __all__ = ["create_pytorch_engine"] 18 | -------------------------------------------------------------------------------- /jetstream_pt/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import os 17 | from absl import flags 18 | import jax 19 | from jetstream_pt.environment import QuantizationConfig 20 | 21 | FLAGS = flags.FLAGS 22 | 23 | flags.DEFINE_string("tokenizer_path", None, "The tokenizer model path") 24 | flags.DEFINE_string("model_name", None, "model type") 25 | flags.DEFINE_string("checkpoint_path", None, "Directory for .pth checkpoints") 26 | flags.DEFINE_bool("bf16_enable", True, "Whether to enable bf16") 27 | flags.DEFINE_integer("context_length", 1024, "The context length") 28 | flags.DEFINE_integer("batch_size", 32, "The batch size") 29 | flags.DEFINE_string("size", "tiny", "size of model") 30 | flags.DEFINE_integer("max_cache_length", 1024, "kv_cache_quantize") 31 | flags.DEFINE_integer("max_decode_length", 1024, "max length of generated text") 32 | flags.DEFINE_string("sharding_config", "", "config file for sharding") 33 | flags.DEFINE_bool( 34 | "shard_on_batch", 35 | False, 36 | "whether to shard on batch dimension" 37 | "If set true, sharding_config will be ignored.", 38 | ) 39 | flags.DEFINE_string("profiling_output", "", "The profiling output") 40 | 41 | # Quantization related flags 42 | flags.DEFINE_bool("quantize_weights", False, "weight quantization") 43 | flags.DEFINE_bool( 44 | "quantize_activation", 45 | False, 46 | "Quantize Q,K,V projection and FeedForward activation. Defaults to False", 47 | ) 48 | flags.DEFINE_string( 49 | "quantize_type", "int8_per_channel", "Type of quantization." 50 | ) 51 | flags.DEFINE_bool( 52 | "quantize_kv_cache", None, "defaults to the same value as quantize_weights" 53 | ) 54 | flags.DEFINE_multi_string( 55 | "quantize_exclude_layers", 56 | None, 57 | "List of layer names to exclude from quantization", 58 | ) 59 | 60 | _VALID_QUANTIZATION_TYPE = { 61 | "int8_per_channel", 62 | "int4_per_channel", 63 | "int8_blockwise", 64 | "int4_blockwise", 65 | } 66 | 67 | flags.register_validator( 68 | "quantize_type", 69 | lambda value: value in _VALID_QUANTIZATION_TYPE, 70 | f"quantize_type is invalid, supported quantization types are {_VALID_QUANTIZATION_TYPE}", 71 | ) 72 | flags.DEFINE_bool( 73 | "profiling_prefill", 74 | False, 75 | "Whether to profile the prefill, " 76 | "if set to false, profile generate function only", 77 | required=False, 78 | ) 79 | flags.DEFINE_bool( 80 | "ragged_mha", 81 | False, 82 | "Whether to enable Ragged multi head attention", 83 | required=False, 84 | ) 85 | flags.DEFINE_integer( 86 | "starting_position", 87 | 512, 88 | "The starting position of decoding, " 89 | "for performance tuning and debugging only", 90 | required=False, 91 | ) 92 | flags.DEFINE_bool( 93 | "ring_buffer", 94 | True, 95 | "Whether to enable ring buffer", 96 | required=False, 97 | ) 98 | flags.DEFINE_bool( 99 | "flash_attention", 100 | False, 101 | "Whether to enable flas attention. Only takes effect at test mode", 102 | required=False, 103 | ) 104 | flags.DEFINE_bool( 105 | "generate_cache_stacked", 106 | False, 107 | "Whether to stack the generate cache to the layer dimension. Only takes effect at test mode", 108 | required=False, 109 | ) 110 | flags.DEFINE_bool( 111 | "new_cache_stacked", 112 | False, 113 | "Whether to stack the generate cache to the layer dimension. Only takes effect at test mode", 114 | required=False, 115 | ) 116 | flags.DEFINE_bool( 117 | "lazy_cache_update", 118 | False, 119 | "Whether to update the cache during attention or delayed until all the layers are done. 
" 120 | "Only takes effect at test mode", 121 | required=False, 122 | ) 123 | flags.DEFINE_float( 124 | "temperature", 125 | 1.0, 126 | "temperature parameter for scaling probability." 127 | "Only invoked when sampling algorithm is set to" 128 | "weighted or topk", 129 | ) 130 | flags.DEFINE_string( 131 | "sampling_algorithm", 132 | "greedy", 133 | "sampling algorithm to use. Options:" 134 | "('greedy', 'weighted', 'neucleus', 'topk')", 135 | ) 136 | flags.DEFINE_float( 137 | "nucleus_topp", 138 | 0.0, 139 | "restricting to p probability mass before sampling", 140 | ) 141 | flags.DEFINE_integer( 142 | "topk", 143 | 0, 144 | "size of top k used when sampling next token", 145 | ) 146 | 147 | flags.DEFINE_integer( 148 | "paged_attention_total_num_pages", 149 | 0, 150 | "total number of pages per layer for page attention", 151 | ) 152 | 153 | flags.DEFINE_integer( 154 | "paged_attention_page_size", 155 | 64, 156 | "page size per page", 157 | ) 158 | flags.DEFINE_string( 159 | "internal_jax_compilation_cache_dir", 160 | "~/jax_cache", 161 | "Jax compilation cache directory", 162 | ) 163 | flags.DEFINE_integer( 164 | "internal_jax_persistent_cache_min_entry_size_bytes", 165 | 0, 166 | "Minimum size (in bytes) of an entry that will be cached in the persistent compilation cache", 167 | ) 168 | flags.DEFINE_integer( 169 | "internal_jax_persistent_cache_min_compile_time_secs", 170 | 1, 171 | "Minimum compilation time for a computation to be written to persistent cache", 172 | ) 173 | 174 | 175 | def create_quantization_config_from_flags(): 176 | """Create Quantization Config from cmd flags""" 177 | config = QuantizationConfig() 178 | quantize_weights = FLAGS.quantize_weights 179 | quantize_type = FLAGS.quantize_type 180 | if not quantize_weights: 181 | return config 182 | config.enable_weight_quantization = True 183 | 184 | config.num_bits_weight = 8 if "int8" in quantize_type else 4 185 | config.is_blockwise_weight = "blockwise" in quantize_type 186 | 187 | config.enable_activation_quantization = FLAGS.quantize_activation 188 | config.exclude_layers = FLAGS.quantize_exclude_layers 189 | config.enable_kv_quantization = ( 190 | FLAGS.quantize_kv_cache 191 | if FLAGS.quantize_kv_cache is not None 192 | else FLAGS.quantize_weights 193 | ) 194 | return config 195 | 196 | 197 | def set_jax_compilation_cache_config(): 198 | """Sets the jax compilation cache configuration""" 199 | jax.config.update( 200 | "jax_compilation_cache_dir", 201 | os.path.expanduser(FLAGS.internal_jax_compilation_cache_dir), 202 | ) 203 | jax.config.update( 204 | "jax_persistent_cache_min_entry_size_bytes", 205 | FLAGS.internal_jax_persistent_cache_min_entry_size_bytes, 206 | ) 207 | jax.config.update( 208 | "jax_persistent_cache_min_compile_time_secs", 209 | FLAGS.internal_jax_persistent_cache_min_compile_time_secs, 210 | ) 211 | -------------------------------------------------------------------------------- /jetstream_pt/fetch_models.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import glob 3 | import os 4 | from typing import Optional 5 | from requests.exceptions import HTTPError 6 | from huggingface_hub import snapshot_download 7 | from absl import flags 8 | import torch 9 | from safetensors import safe_open 10 | from jetstream_pt.environment import ( 11 | JetEngineEnvironmentData, 12 | ) 13 | from jetstream_pt.third_party.llama import model_exportable as llama_model 14 | from jetstream_pt.third_party.mixtral import model as mixtral_model 15 | from 
jetstream_pt.third_party.gemma import model as gemma_model 16 | 17 | FLAGS = flags.FLAGS 18 | 19 | flags.DEFINE_string( 20 | "working_dir", 21 | "checkpoints", 22 | "Directory to store downloaded/converted weights", 23 | ) 24 | flags.DEFINE_string("hf_token", "", "huggingface token") 25 | flags.DEFINE_bool( 26 | "internal_use_random_weights", 27 | False, 28 | "Use random weights instead of HF weights. Testing only.", 29 | ) 30 | 31 | flags.DEFINE_bool( 32 | "internal_use_tiny_model", 33 | False, 34 | "Use tiny config instead of real config of HF weights. Testing only.", 35 | ) 36 | 37 | flags.DEFINE_integer( 38 | "override_max_cache_length", 39 | -1, 40 | "Size of cache, defaults to input + output length", 41 | ) 42 | 43 | 44 | @dataclasses.dataclass 45 | class ModelInfo: 46 | """Model information.""" 47 | 48 | model_class: torch.nn.Module 49 | # information needed to allocate cache 50 | num_layers: int 51 | # number of kv heads 52 | num_kv_heads: int 53 | 54 | head_dim: int 55 | n_reps: int # repeatition for GQA 56 | 57 | 58 | _llama2_7 = ModelInfo(llama_model.Transformer, 32, 32, 128, 1) 59 | _llama2_13 = ModelInfo(llama_model.Transformer, 40, 40, 128, 1) 60 | _llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 8) 61 | _llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128, 4) 62 | _llama3_70 = _llama2_70 63 | _llama3_1_8b = _llama3_8 64 | _llama3_2_1b = ModelInfo(llama_model.Transformer, 16, 8, 64, 4) 65 | _llama3_3_70b = _llama2_70 66 | 67 | _mixtral_87 = ModelInfo(mixtral_model.Transformer, 32, 8, 128, 4) 68 | 69 | _gemma_2b = ModelInfo(gemma_model.GemmaModel, 18, 1, 256, 8) 70 | _gemma_7b = ModelInfo(gemma_model.GemmaModel, 28, 16, 256, 1) 71 | 72 | 73 | model_id_to_class = { 74 | "meta-llama/Llama-2-7b-chat-hf": _llama2_7, 75 | "meta-llama/Llama-2-7b-hf": _llama2_7, 76 | "meta-llama/Llama-2-13b-chat-hf": _llama2_13, 77 | "meta-llama/Llama-2-13b-hf": _llama2_13, 78 | "meta-llama/Llama-2-70b-hf": _llama2_70, 79 | "meta-llama/Llama-2-70b-chat-hf": _llama2_70, 80 | "meta-llama/Meta-Llama-3-8B": _llama3_8, 81 | "meta-llama/Meta-Llama-3-8B-Instruct": _llama3_8, 82 | "meta-llama/Meta-Llama-3-70B": _llama3_70, 83 | "meta-llama/Meta-Llama-3-70B-Instruct": _llama3_70, 84 | "meta-llama/Llama-3.1-8B": _llama3_1_8b, 85 | "meta-llama/Llama-3.1-8B-Instruct": _llama3_1_8b, 86 | "meta-llama/Llama-3.2-1B": _llama3_2_1b, 87 | "meta-llama/Llama-3.2-1B-Instruct": _llama3_2_1b, 88 | "meta-llama/Llama-3.3-70B": _llama3_3_70b, 89 | "meta-llama/Llama-3.3-70B-Instruct": _llama3_3_70b, 90 | "google/gemma-2b": _gemma_2b, 91 | "google/gemma-2b-it": _gemma_2b, 92 | "google/gemma-7b": _gemma_7b, 93 | "google/gemma-7b-it": _gemma_7b, 94 | "mistralai/Mixtral-8x7B-v0.1": _mixtral_87, 95 | "mistralai/Mixtral-8x7B-Instruct-v0.1": _mixtral_87, 96 | } 97 | 98 | 99 | def _model_dir(repo_id): 100 | """Model dir structure: 101 | 102 | working_dir/ 103 | repo_id/ 104 | hf_original/ 105 | converted_bfloat/ 106 | converted_int8/ 107 | """ 108 | return os.path.join(FLAGS.working_dir, repo_id) 109 | 110 | 111 | def _hf_dir(repo_id): 112 | """Dir to hf repo""" 113 | return os.path.join(_model_dir(repo_id), "hf_original") 114 | 115 | 116 | def _int_dir(repo_id): 117 | return os.path.join(_model_dir(repo_id), "converted_int8") 118 | 119 | 120 | def construct_env_data_from_model_id( 121 | repo_id, 122 | batch_size, 123 | input_length, 124 | output_length, 125 | ): 126 | """Create Environment from model id and options""" 127 | tokenizer_path = os.path.join(_hf_dir(repo_id), "tokenizer.model") 128 | checkpoint_path = 
_hf_dir(repo_id) 129 | checkpoint_format = "safetensors" 130 | 131 | shard_on_batch = False 132 | 133 | max_cache_length = ( 134 | FLAGS.override_max_cache_length 135 | if FLAGS.override_max_cache_length > 0 136 | else input_length + output_length 137 | ) 138 | 139 | model_info = model_id_to_class.get(repo_id) 140 | env_data = JetEngineEnvironmentData( 141 | tokenizer_path=tokenizer_path, 142 | checkpoint_path=checkpoint_path, 143 | checkpoint_format=checkpoint_format, 144 | batch_size=batch_size, 145 | max_decode_length=output_length, 146 | max_input_sequence_length=input_length, 147 | cache_sequence_length=max_cache_length, 148 | bf16_enable=True, 149 | sharding_config_path="", 150 | shard_on_batch=shard_on_batch, 151 | n_reps=model_info.n_reps, 152 | ) 153 | env_data.cache_shape = ( 154 | batch_size, 155 | model_info.num_kv_heads, 156 | max_cache_length, 157 | model_info.head_dim, 158 | ) 159 | env_data.num_layers = model_info.num_layers 160 | return env_data 161 | 162 | 163 | def _load_weights(directory): 164 | safetensors_files = glob.glob(os.path.join(directory, "*.safetensors")) 165 | state_dict = {} 166 | for file_path in safetensors_files: 167 | with safe_open(file_path, framework="pt") as f: 168 | for key in f.keys(): 169 | state_dict[key] = f.get_tensor(key).to(torch.bfloat16) 170 | # Load the state_dict into the model 171 | if not state_dict: 172 | raise AssertionError( 173 | f"Tried to load weights from {directory}, but couldn't find any." 174 | ) 175 | return state_dict 176 | 177 | 178 | def _make_random_model_weights(model): 179 | result = {} 180 | for key, val in model.state_dict().items(): 181 | new_weights = torch.rand(val.shape, dtype=val.dtype, device="cpu") 182 | result[key] = new_weights 183 | return result 184 | 185 | 186 | def instantiate_model_from_repo_id( 187 | repo_id, 188 | env, 189 | ): 190 | """Create model instance by hf model id.+""" 191 | model_dir = _hf_dir(repo_id) 192 | if not FLAGS.internal_use_random_weights and ( 193 | not os.path.exists(model_dir) 194 | or not glob.glob(os.path.join(model_dir, "*.safetensors")) 195 | ): 196 | # no weights has been downloaded 197 | _hf_download(repo_id, model_dir, FLAGS.hf_token) 198 | model_info = model_id_to_class.get(repo_id) 199 | assert model_info is not None 200 | 201 | env.device = "meta" 202 | model = model_info.model_class.from_hf_model_id( 203 | repo_id, env, FLAGS.internal_use_tiny_model 204 | ) 205 | if FLAGS.internal_use_random_weights or FLAGS.internal_use_tiny_model: 206 | weights = _make_random_model_weights(model) 207 | else: 208 | weights = _load_weights(model_dir) 209 | weights = model.convert_hf_weights(weights) 210 | model.load_state_dict(weights, assign=True, strict=False) 211 | 212 | return model 213 | ## QQ do i need to set the weights onto the model? 214 | 215 | 216 | def _hf_download( 217 | repo_id: str, dest_directory: str, hf_token: Optional[str] = None 218 | ) -> None: 219 | os.makedirs(dest_directory, exist_ok=True) 220 | try: 221 | if not hf_token: 222 | hf_token = os.environ.get("HF_TOKEN") 223 | if not hf_token: 224 | # NOTE: setting true allows hf to read from the config folder. 
225 | hf_token = True 226 | snapshot_download( 227 | repo_id, 228 | local_dir=dest_directory, 229 | local_dir_use_symlinks=False, 230 | token=hf_token, 231 | allow_patterns=[ 232 | "model*.safetensors", 233 | "*.json", 234 | "*.model", 235 | ], 236 | ) 237 | except HTTPError as e: 238 | if e.response.status_code == 401: 239 | print( 240 | "Please use huggingface-cli login to authenticate " 241 | "to download private checkpoints." 242 | ) 243 | print("OR, pass `hf_token=...` explicitly.") 244 | raise e 245 | -------------------------------------------------------------------------------- /jetstream_pt/gcs_to_cns.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | set -e 17 | 18 | LOG_FILE_IN_GCS=$1 19 | filename=$(basename $LOG_FILE_IN_GCS) 20 | output_file=$(date "+%Y-%m-%d-%H:%M:%S")_${filename} 21 | 22 | CNS_PATH=/cns/pi-d/home/${USER}/tensorboard/multislice/ 23 | fileutil mkdir -p ${CNS_PATH} 24 | /google/data/ro/projects/cloud/bigstore/mpm/fileutil_bs/stable/bin/fileutil_bs cp /bigstore/${LOG_FILE_IN_GCS} ${CNS_PATH}/$output_file 25 | echo file to put into xprof: ${CNS_PATH}/$output_file 26 | -------------------------------------------------------------------------------- /jetstream_pt/hf_tokenizer.py: -------------------------------------------------------------------------------- 1 | from jetstream.engine import tokenizer_api, token_utils 2 | 3 | 4 | class HFTokenizerAdapter(tokenizer_api.Tokenizer): 5 | """Implementation of Tokenizer interface backed by HF tokenizer.""" 6 | 7 | def __init__(self, tokenizer): 8 | self.tokenizer = tokenizer 9 | 10 | def encode(self, s: str, **kwargs): 11 | """Tokenize a string. 12 | Args: 13 | s: String to tokenize. 14 | **kwargs: Additional keyword arguments. 15 | Returns: 16 | tokens: Tokenized into integers. 17 | true_length: Actual length of the non-padded sequence 18 | if padding is used. 19 | """ 20 | res = self.tokenizer.encode(s, add_special_tokens=False) 21 | return token_utils.pad_tokens( 22 | res, self.bos_id, self.pad_id, jax_padding=True 23 | ) 24 | 25 | def decode(self, token_ids: list[int], **kwargs) -> str: 26 | """Processess input token ids to generate a string. 27 | Args: 28 | token_ids: List of token ids. 29 | **kwargs: Additional keyword arguments. 30 | Returns: 31 | str: String generated from the token ids. 
32 | """ 33 | return self.tokenizer.decode(token_ids) 34 | 35 | @property 36 | def pad_id(self) -> int: 37 | """ID of the pad token.""" 38 | return self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else 0 39 | 40 | @property 41 | def eos_id(self) -> int: 42 | """ID of EOS token.""" 43 | return self.tokenizer.eos_token_id 44 | 45 | @property 46 | def bos_id(self) -> int: 47 | """ID of BOS token.""" 48 | return self.tokenizer.bos_token_id 49 | 50 | @property 51 | def stop_tokens(self) -> set[int]: 52 | """ID of the stop token.""" 53 | return {self.eos_id} 54 | -------------------------------------------------------------------------------- /jetstream_pt/model_base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import itertools 3 | from typing import Dict, Any, Optional 4 | import dataclasses 5 | from collections import defaultdict 6 | import torch 7 | 8 | 9 | def _get_hf_name(module, key): 10 | if hasattr(module, "attr_to_property") and key in module.attr_to_property: 11 | return module.attr_to_property[key].huggingface_name 12 | return None 13 | 14 | 15 | def _gather_names(module, myprefix, hf_prefix, result): 16 | for key, _ in itertools.chain( 17 | module.named_parameters(recurse=False), 18 | module.named_buffers(recurse=False), 19 | ): 20 | hf_name = _get_hf_name(module, key) or key 21 | result[hf_prefix + hf_name] = myprefix + key 22 | 23 | for name, child in module.named_children(): 24 | hf_name = _get_hf_name(module, name) or name 25 | _gather_names( 26 | child, myprefix + name + ".", hf_prefix + hf_name + ".", result 27 | ) 28 | 29 | 30 | def _gather_sharding_axis(module, myprefix, result): 31 | if hasattr(module, "attr_to_property"): 32 | for key, val in module.attr_to_property.items(): 33 | if val.sharding_axis is not None: 34 | result[myprefix + key] = val.sharding_axis 35 | 36 | for name, child in module.named_children(): 37 | _gather_sharding_axis(child, myprefix + name + ".", result) 38 | 39 | 40 | @dataclasses.dataclass 41 | class AttrProperty: 42 | """Attributes attached to model weights.""" 43 | 44 | huggingface_name: Optional[str] = None 45 | sharding_axis: Optional[int] = None 46 | 47 | 48 | class ModuleBase(torch.nn.Module, metaclass=abc.ABCMeta): 49 | """nn Module that allows attaching properties. 50 | 51 | This class currently serves 2 goals: 52 | 1. Allow model to specify alternative names for submodules / weights 53 | this is needed so that it can *also* load HuggingFace checkpoints 54 | without need to do massive rewrites. 55 | 56 | 2. Allow model to attach information to weights, such as sharding config. 57 | 58 | Quantization config could be another thing to attach, but right now it's not used 59 | this way. 
60 | """ 61 | 62 | attr_to_property: Dict[str, Any] 63 | 64 | def __init__(self): 65 | super().__init__() 66 | self.attr_to_property = defaultdict(AttrProperty) 67 | 68 | def get_hf_names_to_real_name(self): 69 | """Return a dict of attr names to it's hf name.""" 70 | result = {} 71 | _gather_names(self, "", "", result) 72 | return result 73 | 74 | def get_sharding_annotations(self): 75 | """Return a dict of attr names to it's sharding dim.""" 76 | result = {} 77 | _gather_sharding_axis(self, "", result) 78 | return result 79 | 80 | def hf_name(self, orig_name, hf_name): 81 | """Set it's alternative name for a attribute or submodule.""" 82 | self.attr_to_property[orig_name].huggingface_name = hf_name 83 | 84 | def annotate_sharding(self, name, axis): 85 | """Set sharding name for a attribute or submodule.""" 86 | self.attr_to_property[name].sharding_axis = axis 87 | 88 | def convert_hf_weights( 89 | self, hf_weights: Dict[str, torch.Tensor] 90 | ) -> Dict[str, torch.Tensor]: 91 | """Load state_dict with hg weights.""" 92 | weights = {} 93 | updated_keys = self.get_hf_names_to_real_name() 94 | for name, updated in updated_keys.items(): 95 | if name in hf_weights: 96 | weights[updated] = hf_weights[name] 97 | 98 | for name in list(weights.keys()): 99 | if "inv_freq" in name: 100 | weights.pop(name) 101 | if hasattr(self, "freqs_cis"): 102 | weights["freqs_cis"] = self.freqs_cis 103 | return weights 104 | -------------------------------------------------------------------------------- /jetstream_pt/page_attention_manager.py: -------------------------------------------------------------------------------- 1 | import queue 2 | import functools 3 | from typing import List, Tuple 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import jax.sharding as jsharding 8 | import numpy as np 9 | 10 | 11 | class PageAttentionManager: 12 | """Manages page blocks. 13 | 14 | This manager maintains a main list of free page blocks, it support below features: 15 | 1. Reseve pages for prefill insert and decode. 16 | 2. Free pages resource for the slots after decode. Pages indices go to free list. 17 | 3. Get pages indices meta data for all the slots. 18 | 4. Transform and insert prefill caches to decode caches. 
19 | """ 20 | 21 | def __init__( 22 | self, 23 | batch_size: int, 24 | paged_attention_total_num_pages: int, 25 | paged_attention_page_size: int, 26 | max_pages_per_sequence: int, 27 | ): 28 | self.unused_pages = queue.Queue() 29 | self.batch_size = batch_size 30 | self.page_indices = np.full( 31 | (batch_size, max_pages_per_sequence), 32 | paged_attention_total_num_pages - 1, 33 | dtype=np.int32, 34 | ) 35 | self.lengths = np.zeros(batch_size, dtype=np.int32) 36 | self.paged_attention_page_size = paged_attention_page_size 37 | self.max_pages_per_sequence = max_pages_per_sequence 38 | for i in range(paged_attention_total_num_pages): 39 | self.unused_pages.put(i, block=False) 40 | 41 | # pylint: disable-next=all 42 | def reserve_pages_insert(self, slot: int, seq_len: int): 43 | self.lengths[slot] = seq_len 44 | num_pages = ( 45 | seq_len // self.paged_attention_page_size 46 | if seq_len % self.paged_attention_page_size == 0 47 | else seq_len // self.paged_attention_page_size + 1 48 | ) 49 | 50 | indices = [self.unused_pages.get(block=False) for _ in range(num_pages)] 51 | self.page_indices[slot, :num_pages] = indices 52 | return num_pages, self.page_indices[slot, :num_pages] 53 | 54 | # pylint: disable-next=all 55 | def reserve_pages_decode(self, slot: int, seq_len: int): 56 | if seq_len > 0 and seq_len % self.paged_attention_page_size == 0: 57 | index = self.unused_pages.get(block=False) 58 | num_pages = seq_len // self.paged_attention_page_size 59 | self.page_indices[slot, num_pages] = index 60 | 61 | # pylint: disable-next=all 62 | def fill_new_pages(self, lens): 63 | for slot in range(self.batch_size): 64 | self.reserve_pages_decode(slot, lens[slot]) 65 | 66 | # pylint: disable-next=all 67 | def prefill_cache_padding( 68 | self, 69 | caches: List[Tuple[jax.Array, jax.Array]], 70 | seq_len: int, 71 | num_pages: int, 72 | ) -> List[Tuple[jax.Array, jax.Array]]: 73 | 74 | pad_width = num_pages * self.paged_attention_page_size - seq_len 75 | if pad_width == 0: 76 | return caches 77 | 78 | return [ 79 | (self.pad_sequences(k, pad_width), self.pad_sequences(v, pad_width)) 80 | for k, v in caches 81 | ] 82 | 83 | def insert_prefill_cache( 84 | self, 85 | prefill_caches: List[Tuple[jax.Array, jax.Array]], 86 | decode_caches: List[Tuple[jax.Array, jax.Array]], 87 | update_indexes: jax.Array, 88 | tep_kv: jax.Array, 89 | sharding: jsharding.Sharding, 90 | ) -> List[Tuple[jax.Array, jax.Array]]: 91 | """Insert prefill caches to decode caches. 92 | 93 | Args: 94 | prefill_caches: List of Tuple K, V. For each K, V: 95 | [batch_size, num_heads, seq_len, head_dim] jax.Array. 96 | decode_caches: List of Tuple K, V. For each K, V: 97 | [num_heads, paged_attention_total_num_pages, paged_attention_page_size, head_dim] jax.Array. 98 | update_indexes: Page indexes for insertion. 99 | tep_kv: List of Tuple K, V. For each K, V: 100 | kv_heads, num_pages * .paged_attention_page_size, dim. 101 | sharding: Decode cache sharding. 102 | 103 | 104 | Returns: 105 | Decode cache. List of Tuple K, V. For each K, V: 106 | [num_heads, paged_attention_total_num_pages, paged_attention_page_size, head_dim] jax.Array. 
107 | """ 108 | # Reduce cache batch deminsion 109 | # [kv_heads, seq_len, dim] 110 | squeezed_caches = [ 111 | (jnp.squeeze(k, axis=0), jnp.squeeze(v, axis=0)) 112 | for k, v in prefill_caches 113 | ] 114 | tmp_caches = [ 115 | ( 116 | tep_kv.at[:, : k.shape[1], :].set(k), 117 | tep_kv.at[:, : v.shape[1], :].set(v), 118 | ) 119 | for k, v in squeezed_caches 120 | ] 121 | kv_heads, _, dim = tmp_caches[0][0].shape 122 | # [kv_heads, num_pages, paged_attention_page_size, dim] 123 | paged_caches = [ 124 | ( 125 | jnp.reshape(k, (kv_heads, -1, self.paged_attention_page_size, dim)), 126 | jnp.reshape(v, (kv_heads, -1, self.paged_attention_page_size, dim)), 127 | ) 128 | for k, v in tmp_caches 129 | ] 130 | 131 | @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) 132 | def insert(cache, new_entry): 133 | res = cache.at[:, update_indexes, :, :].set(new_entry) 134 | res = jax.lax.with_sharding_constraint(res, sharding) 135 | return res 136 | 137 | caches = [ 138 | (insert(k, newk), insert(v, newv)) 139 | for (k, v), (newk, newv) in zip(decode_caches, paged_caches) 140 | ] 141 | 142 | return caches 143 | 144 | # pylint: disable-next=all 145 | def get_page_token_indices(self, lens): 146 | # assert lens.shape == ( 147 | # self.batch_size, 148 | # 1, 149 | # ), f"len shape: {lens.shape} not equals batch size: {self.batch_size, 1}" 150 | update_page_indices = [] 151 | token_scale_indices = [] 152 | batch_slots = [] 153 | offset = 0 154 | 155 | for slot in range(self.batch_size): 156 | seq_len = lens[slot] 157 | if seq_len == 0: 158 | continue 159 | num_pages = seq_len // self.paged_attention_page_size + 1 160 | token_pos = seq_len % self.paged_attention_page_size 161 | page_index = self.page_indices[slot, num_pages - 1] 162 | 163 | update_page_indices.append(page_index) 164 | token_scale_indices.append(offset + token_pos) 165 | batch_slots.append(slot) 166 | offset += self.paged_attention_page_size 167 | self.lengths = np.where(lens == 0, 0, lens + 1) 168 | update_page_indices = np.asarray(update_page_indices) 169 | token_scale_indices = np.asarray(token_scale_indices) 170 | batch_slots = np.asarray(batch_slots) 171 | return np.stack( 172 | ( 173 | update_page_indices, 174 | token_scale_indices, 175 | batch_slots, 176 | ) 177 | ) 178 | 179 | # pylint: disable-next=all 180 | def get_compress_kv_cache( 181 | self, 182 | decode_caches: List[Tuple[jax.Array, jax.Array]], 183 | slot: int, 184 | ) -> List[Tuple[jax.Array, jax.Array]]: 185 | lens = self.lengths[slot] 186 | indices = self.page_indices[slot] 187 | return [ 188 | ( 189 | self._compress_cache(k, lens, indices), 190 | self._compress_cache(v, lens, indices), 191 | ) 192 | for k, v in decode_caches 193 | ] 194 | 195 | def _compress_cache(self, cache: jax.Array, lens: int, indices: jax.Array): 196 | head, _, _, dim = cache.shape 197 | selected_cache = cache[:, indices, :, :] 198 | selected_cache = selected_cache.reshape((head, -1, dim)) 199 | selected_cache = selected_cache[:, 0:lens, :] 200 | return selected_cache 201 | 202 | # pylint: disable-next=all 203 | def pad_sequences(self, array, pad_width=10): 204 | padding_config = [ 205 | (0, 0), 206 | (0, 0), 207 | (0, pad_width), 208 | (0, 0), 209 | ] # Pad only seq_len and dim 210 | padded_array = jnp.pad(array, padding_config, mode="constant") 211 | return padded_array 212 | 213 | # pylint: disable-next=all 214 | def free_pages_resource(self, slot): 215 | for i in range(self.max_pages_per_sequence): 216 | index = self.page_indices[slot, i] 217 | if index < 0: 218 | break 219 | 
self.unused_pages.put(index, block=False) 220 | 221 | self.page_indices = self.page_indices.at[slot, :].set(jnp.asarray([0])) 222 | return None 223 | -------------------------------------------------------------------------------- /jetstream_pt/quantize.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Tuple, Union 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | import torch 20 | 21 | EPS = 1e-5 22 | 23 | 24 | def quantize_tensor( 25 | w: torch.Tensor, 26 | reduce_axis: Union[Tuple[int], int], 27 | n_bit: int = 8, 28 | symmetric: bool = True, 29 | block_size: int = -1, 30 | ): 31 | """ 32 | Quantize weight tensor w along 'reduce_axis'. 33 | 34 | Args: 35 | w: weight tensor to be quantized. 36 | reduce_axis: axises along which to quantize. 37 | n_bit: Quantize to n_bit bits. (Use int8 container for n_bits < 8). 38 | symmetric: Whether quantization is symmetric. 39 | block_size: Blocksize for blockwise quantization. -1 for per-channel quant. 40 | 41 | Return: 42 | w_q: Quantized weight in int8 container 43 | scale: scalar for quantized tensor 44 | zero_point: zero_point for quantized tensor, None if symmetric quantization 45 | """ 46 | 47 | assert 0 < n_bit <= 8, "Quantization bits must be between [1, 8]." 
48 | if isinstance(reduce_axis, int): 49 | reduce_axis = (reduce_axis,) 50 | 51 | if block_size > 0: 52 | axis = reduce_axis[0] 53 | w_shape = w.shape 54 | assert w_shape[axis] % block_size == 0 55 | w = w.reshape(w_shape[:axis] + (-1, block_size) + w_shape[axis + 1 :]) 56 | reduce_axis = axis + 1 57 | 58 | max_int = 2 ** (n_bit - 1) - 1 59 | min_int = -(2 ** (n_bit - 1)) 60 | if not symmetric: 61 | max_val = w.amax(dim=reduce_axis, keepdim=True) 62 | min_val = w.amin(dim=reduce_axis, keepdim=True) 63 | scales = (max_val - min_val).clamp(min=EPS) / float(max_int - min_int) 64 | zero_point = min_int - min_val / scales 65 | else: 66 | max_val = w.abs().amax(dim=reduce_axis, keepdim=True) 67 | max_val = max_val.clamp(min=EPS) 68 | scales = max_val / max_int 69 | zero_point = 0 70 | 71 | w = torch.clamp( 72 | torch.round(w * (1.0 / scales) + zero_point), min_int, max_int 73 | ).to(torch.int8) 74 | 75 | return w, scales, zero_point if not symmetric else None 76 | 77 | 78 | def dequantize_tensor(w, scale, zero_point=None): 79 | """Dequantize tensor quantized by quantize_tensor.""" 80 | if zero_point is not None: 81 | return (w - zero_point) * scale 82 | 83 | return w * scale 84 | 85 | 86 | def load_q_weight_helper(w_q, scale, zp=None, block_size=-1): 87 | """Helper function to update the shape of quantized weight to match 88 | what quantized linear layer expects.""" 89 | if block_size < 0: 90 | w_q = w_q.to(torch.int8) 91 | if zp is not None: 92 | zp = (zp * scale).squeeze(-1).to(torch.bfloat16) 93 | scale = scale.squeeze(-1).to(torch.bfloat16) 94 | else: 95 | w_q = w_q.permute(1, 2, 0).to(torch.int8) 96 | if zp is not None: 97 | zp = (zp * scale).transpose(1, 0).squeeze(-1).to(torch.bfloat16) 98 | scale = scale.transpose(1, 0).squeeze(-1).to(torch.bfloat16) 99 | return w_q, scale, zp 100 | 101 | 102 | def blockwise_jax_kernel(inputs, weight, weight_scaler, zero_point): 103 | """Blockwise Matmul kernel impl in JAX using einsum""" 104 | weight = weight.astype(jnp.int8) 105 | block_size = weight.shape[1] 106 | inputs_shape = inputs.shape 107 | inputs_new_shape = inputs_shape[:-1] + ( 108 | inputs_shape[-1] // block_size, 109 | block_size, 110 | ) 111 | inputs = inputs.reshape(inputs_new_shape) 112 | out = jnp.einsum("scz,bdsc->bdsz", weight, inputs) 113 | out = jnp.einsum("bdsz,sz->bdz", out, weight_scaler) 114 | if zero_point is not None: 115 | zp_out = jnp.einsum("bdsc,sz->bdz", inputs, zero_point) 116 | out = out - zp_out 117 | return out 118 | 119 | 120 | def blockwise_jax_kernel_dot_general(inputs, weight, weight_scaler, zero_point): 121 | """Blockwise Matmul kernel impl in JAX using dot general""" 122 | inputs_shape = inputs.shape 123 | block_size = weight.shape[2] 124 | bs = inputs_shape[0] 125 | inputs_new_shape = inputs_shape[:-1] + ( 126 | inputs_shape[-1] // block_size, 127 | block_size, 128 | ) 129 | inputs = inputs.reshape(inputs_new_shape) 130 | inputs = jax.lax.collapse(inputs, 0, 2) 131 | out = jax.lax.dot_general( 132 | inputs, weight, dimension_numbers=([(2), (2)], [(1), (0)]) 133 | ) 134 | out = jax.lax.dot_general( 135 | out, weight_scaler, dimension_numbers=([(0), (0)], [(2), (1)]) 136 | ) 137 | out = jax.lax.transpose(out, [1, 0]) 138 | out = out.reshape((bs, -1) + out.shape[1:]) 139 | return out 140 | 141 | 142 | def blockwise_jax_kernel_einsum_flatten( 143 | inputs, weight, weight_scaler, zero_point 144 | ): 145 | """Blockwise Matmul kernel impl in JAX using einsum, with operands flattened""" 146 | weight = weight.astype(jnp.int8) 147 | block_size = weight.shape[1] 148 | 
inputs_shape = inputs.shape 149 | bs = inputs_shape[0] 150 | inputs_new_shape = inputs_shape[:-1] + ( 151 | inputs_shape[-1] // block_size, 152 | block_size, 153 | ) 154 | inputs = inputs.reshape(inputs_new_shape) 155 | inputs = jax.lax.collapse(inputs, 0, 2) 156 | out = jnp.einsum("scz,bsc->bsz", weight, inputs) 157 | out = jnp.einsum("bsz,sz->bz", out, weight_scaler) 158 | out = out.reshape((bs, -1) + out.shape[1:]) 159 | return out 160 | -------------------------------------------------------------------------------- /jetstream_pt/quantize_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .environment import QuantizationConfig 3 | from .layers import ( 4 | create_quantized_from_nn_linear, 5 | create_quantized_from_nn_embedding, 6 | AttentionKernel, 7 | Int8KVAttentionKernel, 8 | ) 9 | 10 | 11 | def quantize_model(float_model, config: QuantizationConfig): 12 | """Apply quantization to linear layers.""" 13 | exclude_mods = None 14 | if config.exclude_layers: 15 | exclude_mods = [ 16 | module 17 | for name, module in float_model.named_modules() 18 | if name in config.exclude_layers 19 | ] 20 | 21 | def quantize_nn_mod(float_model): 22 | for name, mod in float_model.named_modules(): 23 | new_mod = None 24 | if config.exclude_layers and mod in exclude_mods: 25 | continue 26 | if hasattr(mod, "get_quantized_version"): 27 | new_mod = mod.get_quantized_version() 28 | elif isinstance(mod, torch.nn.Linear): 29 | new_mod = create_quantized_from_nn_linear(mod, config) 30 | elif isinstance(mod, torch.nn.Embedding): 31 | new_mod = create_quantized_from_nn_embedding(mod, config) 32 | 33 | if new_mod: 34 | setattr(float_model, name, new_mod) 35 | 36 | if config.enable_kv_quantization: 37 | for name, mod in float_model.__dict__.items(): 38 | if isinstance(mod, AttentionKernel): 39 | new_mod = Int8KVAttentionKernel(mod.env, mod.layer_id) 40 | setattr(float_model, name, new_mod) 41 | 42 | float_model.apply(quantize_nn_mod) 43 | return float_model 44 | -------------------------------------------------------------------------------- /jetstream_pt/third_party/gemma/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /jetstream_pt/third_party/gemma/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Gemma model config.""" 16 | 17 | import dataclasses 18 | import torch 19 | from typing import Optional 20 | 21 | 22 | # Keep a mapping from dtype strings to the supported torch dtypes. 23 | _STR_DTYPE_TO_TORCH_DTYPE = dict( 24 | { 25 | "float16": torch.float16, 26 | "float": torch.float32, 27 | "float32": torch.float32, 28 | "bfloat16": torch.bfloat16, 29 | } 30 | ) 31 | 32 | 33 | @dataclasses.dataclass 34 | class GemmaConfig: 35 | # The number of tokens in the vocabulary. 36 | vocab_size: int = 256000 37 | # The maximum sequence length that this model might ever be used with. 38 | max_position_embeddings: int = 8192 39 | # The number of blocks in the model. 40 | num_hidden_layers: int = 28 41 | # The number of attention heads used in the attention layers of the model. 42 | num_attention_heads: int = 16 43 | # The number of key-value heads for implementing attention. 44 | num_key_value_heads: int = 16 45 | # The hidden size of the model. 46 | hidden_size: int = 3072 47 | # The dimension of the MLP representations. 48 | intermediate_size: int = 24576 49 | # The number of head dimensions. 50 | head_dim: int = 256 51 | # The epsilon used by the rms normalization layers. 52 | rms_norm_eps: float = 1e-6 53 | # The dtype of the weights. 54 | dtype: str = "bfloat16" 55 | # Whether a quantized version of the model is used. 56 | quant: bool = False 57 | # The path to the model tokenizer. 58 | tokenizer: Optional[str] = "tokenizer/tokenizer.model" 59 | 60 | device: str = "meta" 61 | 62 | def get_dtype(self) -> Optional[torch.dtype]: 63 | """Gets the torch dtype from the config dtype string.""" 64 | return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None) 65 | 66 | 67 | def get_config_for_7b() -> GemmaConfig: 68 | return GemmaConfig() 69 | 70 | 71 | def get_config_for_2b() -> GemmaConfig: 72 | return GemmaConfig( 73 | num_hidden_layers=18, 74 | num_attention_heads=8, 75 | num_key_value_heads=1, 76 | hidden_size=2048, 77 | intermediate_size=16384, 78 | ) 79 | 80 | 81 | def get_model_config(variant: str) -> GemmaConfig: 82 | if variant == "7b": 83 | return get_config_for_7b() 84 | elif variant == "2b": 85 | return get_config_for_2b() 86 | return ValueError( 87 | f'Invalid variant {variant}. Supported variants are "2b"' 'and "7b"' 88 | ) 89 | -------------------------------------------------------------------------------- /jetstream_pt/third_party/gemma/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | from typing import List, Optional 16 | 17 | from sentencepiece import SentencePieceProcessor 18 | 19 | 20 | class Tokenizer: 21 | 22 | def __init__(self, model_path: Optional[str]): 23 | # Reload tokenizer. 24 | assert os.path.isfile(model_path), model_path 25 | self.sp_model = SentencePieceProcessor(model_file=model_path) 26 | 27 | # BOS / EOS token IDs. 28 | self.n_words: int = self.sp_model.vocab_size() 29 | self.bos_id: int = self.sp_model.bos_id() 30 | self.eos_id: int = self.sp_model.eos_id() 31 | self.pad_id: int = self.sp_model.pad_id() 32 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() 33 | 34 | def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]: 35 | """Converts a string into a list of tokens.""" 36 | assert isinstance(s, str) 37 | t = self.sp_model.encode(s) 38 | if bos: 39 | t = [self.bos_id] + t 40 | if eos: 41 | t = t + [self.eos_id] 42 | return t 43 | 44 | def decode(self, t: List[int]) -> str: 45 | """Converts a list of tokens into a string.""" 46 | return self.sp_model.decode(t) 47 | -------------------------------------------------------------------------------- /jetstream_pt/third_party/llama/LICENSE: -------------------------------------------------------------------------------- 1 | LLAMA 2 COMMUNITY LICENSE AGREEMENT 2 | Llama 2 Version Release Date: July 18, 2023 3 | 4 | "Agreement" means the terms and conditions for use, reproduction, distribution and 5 | modification of the Llama Materials set forth herein. 6 | 7 | "Documentation" means the specifications, manuals and documentation 8 | accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and- 9 | libraries/llama-downloads/. 10 | 11 | "Licensee" or "you" means you, or your employer or any other person or entity (if 12 | you are entering into this Agreement on such person or entity's behalf), of the age 13 | required under applicable laws, rules or regulations to provide legal consent and that 14 | has legal authority to bind your employer or such other person or entity if you are 15 | entering in this Agreement on their behalf. 16 | 17 | "Llama 2" means the foundational large language models and software and 18 | algorithms, including machine-learning model code, trained model weights, 19 | inference-enabling code, training-enabling code, fine-tuning enabling code and other 20 | elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and- 21 | libraries/llama-downloads/. 22 | 23 | "Llama Materials" means, collectively, Meta's proprietary Llama 2 and 24 | Documentation (and any portion thereof) made available under this Agreement. 25 | 26 | "Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you 27 | are an entity, your principal place of business is in the EEA or Switzerland) and Meta 28 | Platforms, Inc. (if you are located outside of the EEA or Switzerland). 29 | 30 | By clicking "I Accept" below or by using or distributing any portion or element of the 31 | Llama Materials, you agree to be bound by this Agreement. 32 | 33 | 1. License Rights and Redistribution. 34 | 35 | a. Grant of Rights. 
You are granted a non-exclusive, worldwide, non- 36 | transferable and royalty-free limited license under Meta's intellectual property or 37 | other rights owned by Meta embodied in the Llama Materials to use, reproduce, 38 | distribute, copy, create derivative works of, and make modifications to the Llama 39 | Materials. 40 | 41 | b. Redistribution and Use. 42 | 43 | i. If you distribute or make the Llama Materials, or any derivative works 44 | thereof, available to a third party, you shall provide a copy of this Agreement to such 45 | third party. 46 | ii. If you receive Llama Materials, or any derivative works thereof, from 47 | a Licensee as part of an integrated end user product, then Section 2 of this 48 | Agreement will not apply to you. 49 | 50 | iii. You must retain in all copies of the Llama Materials that you 51 | distribute the following attribution notice within a "Notice" text file distributed as a 52 | part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License, 53 | Copyright (c) Meta Platforms, Inc. All Rights Reserved." 54 | 55 | iv. Your use of the Llama Materials must comply with applicable laws 56 | and regulations (including trade compliance laws and regulations) and adhere to the 57 | Acceptable Use Policy for the Llama Materials (available at 58 | https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into 59 | this Agreement. 60 | 61 | v. You will not use the Llama Materials or any output or results of the 62 | Llama Materials to improve any other large language model (excluding Llama 2 or 63 | derivative works thereof). 64 | 65 | 2. Additional Commercial Terms. If, on the Llama 2 version release date, the 66 | monthly active users of the products or services made available by or for Licensee, 67 | or Licensee's affiliates, is greater than 700 million monthly active users in the 68 | preceding calendar month, you must request a license from Meta, which Meta may 69 | grant to you in its sole discretion, and you are not authorized to exercise any of the 70 | rights under this Agreement unless or until Meta otherwise expressly grants you 71 | such rights. 72 | 73 | 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE 74 | LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE 75 | PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 76 | EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY 77 | WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR 78 | FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE 79 | FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING 80 | THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR 81 | USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS. 82 | 83 | 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE 84 | LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, 85 | NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS 86 | AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, 87 | CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN 88 | IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF 89 | ANY OF THE FOREGOING. 90 | 91 | 5. Intellectual Property. 92 | 93 | a. 
No trademark licenses are granted under this Agreement, and in 94 | connection with the Llama Materials, neither Meta nor Licensee may use any name 95 | or mark owned by or associated with the other or any of its affiliates, except as 96 | required for reasonable and customary use in describing and redistributing the 97 | Llama Materials. 98 | 99 | b. Subject to Meta's ownership of Llama Materials and derivatives made by or 100 | for Meta, with respect to any derivative works and modifications of the Llama 101 | Materials that are made by you, as between you and Meta, you are and will be the 102 | owner of such derivative works and modifications. 103 | 104 | c. If you institute litigation or other proceedings against Meta or any entity 105 | (including a cross-claim or counterclaim in a lawsuit) alleging that the Llama 106 | Materials or Llama 2 outputs or results, or any portion of any of the foregoing, 107 | constitutes an infringement of intellectual property or other rights owned or licensable 108 | by you, then any licenses granted to you under this Agreement shall terminate as of 109 | the date such litigation or claim is filed or instituted. You will indemnify and hold 110 | harmless Meta from and against any claim by any third party arising out of or related 111 | to your use or distribution of the Llama Materials. 112 | 113 | 6. Term and Termination. The term of this Agreement will commence upon your 114 | acceptance of this Agreement or access to the Llama Materials and will continue in 115 | full force and effect until terminated in accordance with the terms and conditions 116 | herein. Meta may terminate this Agreement if you are in breach of any term or 117 | condition of this Agreement. Upon termination of this Agreement, you shall delete 118 | and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the 119 | termination of this Agreement. 120 | 121 | 7. Governing Law and Jurisdiction. This Agreement will be governed and 122 | construed under the laws of the State of California without regard to choice of law 123 | principles, and the UN Convention on Contracts for the International Sale of Goods 124 | does not apply to this Agreement. The courts of California shall have exclusive 125 | jurisdiction of any dispute arising out of this Agreement. 
126 | 127 | -------------------------------------------------------------------------------- /jetstream_pt/third_party/llama/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/jetstream-pytorch/f4d775358c3eadcf34d4ab0ba244cb1167f7a628/jetstream_pt/third_party/llama/__init__.py -------------------------------------------------------------------------------- /jetstream_pt/third_party/llama/model_args.py: -------------------------------------------------------------------------------- 1 | # pylint: disable-all 2 | """The original Llama2 model.""" 3 | 4 | import dataclasses 5 | from typing import Optional 6 | 7 | 8 | @dataclasses.dataclass 9 | class RopeScalingArgs: 10 | """Rope scaling configuration parameters.""" 11 | 12 | factor: float = 8.0 13 | low_freq_factor: float = 1.0 14 | high_freq_factor: float = 4.0 15 | original_max_position_embeddings: int = 8192 16 | 17 | 18 | @dataclasses.dataclass 19 | class ModelArgs: 20 | """Model configuration parameters.""" 21 | 22 | dim: int = -1 23 | n_layers: int = -1 24 | n_heads: int = -1 25 | n_kv_heads: Optional[int] = None 26 | vocab_size: int = -1 # defined later by tokenizer 27 | multiple_of: int = ( 28 | 256 # make SwiGLU hidden layer size multiple of large power of 2 29 | ) 30 | ffn_dim_multiplier: Optional[float] = None 31 | norm_eps: float = 1e-5 32 | 33 | max_batch_size: int = -1 34 | max_seq_len: int = -1 35 | 36 | bf16_enable: bool = False 37 | head_dim = -1 38 | infer_length = 0 39 | device = "cpu" 40 | 41 | rope_theta: float = 10000.0 42 | rope_scaling_args: RopeScalingArgs = None 43 | 44 | 45 | def get_arg( 46 | model_name: str, 47 | seqlen, 48 | batch_size, 49 | bf16_enable: bool = False, 50 | ) -> ModelArgs: 51 | """Gets model args.""" 52 | 53 | data = {} 54 | if model_name == "llama-2-tiny": 55 | data = { 56 | "dim": 128, 57 | "vocab_size": 32000, 58 | "multiple_of": 32, 59 | "n_heads": 64, 60 | "n_kv_heads": 8, 61 | "n_layers": 3, 62 | "norm_eps": 1e-05, 63 | } 64 | elif model_name == "llama-2-7b": 65 | data = { 66 | "dim": 4096, 67 | "vocab_size": 32000, 68 | "multiple_of": 256, 69 | "n_heads": 32, 70 | "n_layers": 32, 71 | "norm_eps": 1e-05, 72 | } 73 | elif model_name == "llama-2-13b": 74 | data = { 75 | "dim": 5120, 76 | "vocab_size": 32000, 77 | "multiple_of": 256, 78 | "n_heads": 40, 79 | "n_layers": 40, 80 | "norm_eps": 1e-05, 81 | } 82 | elif model_name == "llama-2-70b": 83 | data = { 84 | "dim": 8192, 85 | "vocab_size": 32000, 86 | "multiple_of": 4096, 87 | "ffn_dim_multiplier": 1.3, 88 | "n_heads": 64, 89 | "n_kv_heads": 8, 90 | "n_layers": 80, 91 | "norm_eps": 1e-05, 92 | } 93 | elif model_name == "llama-3-8b": 94 | data = { 95 | "dim": 4096, 96 | "vocab_size": 128256, 97 | "multiple_of": 1024, 98 | "ffn_dim_multiplier": 1.3, 99 | "n_layers": 32, 100 | "n_heads": 32, 101 | "n_kv_heads": 8, 102 | "norm_eps": 1e-05, 103 | "rope_theta": 500000.0, 104 | } 105 | elif model_name == "llama-3-70b": 106 | data = { 107 | "dim": 8192, 108 | "ffn_dim_multiplier": 1.3, 109 | "multiple_of": 4096, 110 | "n_heads": 64, 111 | "n_kv_heads": 8, 112 | "n_layers": 80, 113 | "norm_eps": 1e-05, 114 | "vocab_size": 128256, 115 | "rope_theta": 500000.0, 116 | } 117 | elif model_name == "llama-3.1-8b": 118 | data = { 119 | "dim": 4096, 120 | "vocab_size": 128256, 121 | "multiple_of": 1024, 122 | "ffn_dim_multiplier": 1.3, 123 | "n_layers": 32, 124 | "n_heads": 32, 125 | "n_kv_heads": 8, 126 | "norm_eps": 1e-05, 127 | "rope_theta": 500000.0, 128 | 
"rope_scaling_args": RopeScalingArgs( 129 | factor=8.0, 130 | low_freq_factor=1.0, 131 | high_freq_factor=4.0, 132 | original_max_position_embeddings=8192, 133 | ), 134 | } 135 | elif model_name == "llama-3.2-1b": 136 | data = { 137 | "dim": 2048, 138 | "vocab_size": 128256, 139 | "multiple_of": 1024, 140 | "ffn_dim_multiplier": 1.5, 141 | "n_layers": 16, 142 | "n_heads": 32, 143 | "n_kv_heads": 8, 144 | "norm_eps": 1e-05, 145 | "rope_theta": 500000.0, 146 | "rope_scaling_args": RopeScalingArgs( 147 | factor=32.0, 148 | low_freq_factor=1.0, 149 | high_freq_factor=4.0, 150 | original_max_position_embeddings=8192, 151 | ), 152 | } 153 | elif model_name == "llama-3.3-70b": 154 | data = { 155 | "dim": 8192, 156 | "vocab_size": 128256, 157 | "multiple_of": 1024, 158 | "ffn_dim_multiplier": 1.3, 159 | "n_layers": 80, 160 | "n_heads": 64, 161 | "n_kv_heads": 8, 162 | "norm_eps": 1e-05, 163 | "rope_theta": 500000.0, 164 | "rope_scaling_args": RopeScalingArgs( 165 | factor=8.0, 166 | low_freq_factor=1.0, 167 | high_freq_factor=4.0, 168 | original_max_position_embeddings=8192, 169 | ), 170 | } 171 | 172 | return ModelArgs( 173 | max_seq_len=seqlen, 174 | max_batch_size=batch_size, 175 | bf16_enable=bf16_enable, 176 | **data, 177 | ) 178 | 179 | 180 | def get_model_args(model_name, context_length, batch_size, bf16_enable): 181 | model_args = get_arg( 182 | model_name=model_name, 183 | seqlen=context_length, 184 | batch_size=batch_size, 185 | bf16_enable=bf16_enable, 186 | ) 187 | model_args.n_kv_heads = ( 188 | model_args.n_heads 189 | if model_args.n_kv_heads is None 190 | else model_args.n_kv_heads 191 | ) 192 | model_args.head_dim = model_args.dim // model_args.n_heads 193 | return model_args 194 | -------------------------------------------------------------------------------- /jetstream_pt/third_party/llama/tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/jetstream-pytorch/f4d775358c3eadcf34d4ab0ba244cb1167f7a628/jetstream_pt/third_party/llama/tokenizer.model -------------------------------------------------------------------------------- /jetstream_pt/third_party/llama/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import os 5 | from typing import List 6 | 7 | from sentencepiece import SentencePieceProcessor 8 | 9 | 10 | """Only use decode to do accuacy varification""" 11 | 12 | 13 | class Tokenizer: 14 | """tokenizing and encoding/decoding text using SentencePiece.""" 15 | 16 | def __init__(self, model_path: str): 17 | """ 18 | Initializes the Tokenizer with a SentencePiece model. 19 | 20 | Args: 21 | model_path (str): The path to the SentencePiece model file. 22 | """ 23 | # reload tokenizer 24 | print(f"model_path: {model_path}") 25 | assert os.path.isfile(model_path), model_path 26 | self.sp_model = SentencePieceProcessor(model_file=model_path) 27 | 28 | # BOS / EOS token IDs 29 | self.n_words: int = self.sp_model.vocab_size() 30 | self.bos_id: int = self.sp_model.bos_id() 31 | self.eos_id: int = self.sp_model.eos_id() 32 | self.pad_id: int = self.sp_model.pad_id() 33 | 34 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() 35 | 36 | def decode(self, t: List[int]) -> str: 37 | """ 38 | Decodes a list of token IDs into a string. 
39 | 40 | Args: 41 | t (List[int]): The list of token IDs to be decoded. 42 | 43 | Returns: 44 | str: The decoded string. 45 | """ 46 | return self.sp_model.decode(t) 47 | -------------------------------------------------------------------------------- /jetstream_pt/third_party/mixtral/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/jetstream-pytorch/f4d775358c3eadcf34d4ab0ba244cb1167f7a628/jetstream_pt/third_party/mixtral/__init__.py -------------------------------------------------------------------------------- /jetstream_pt/third_party/mixtral/config.py: -------------------------------------------------------------------------------- 1 | # pylint: disable-all 2 | # # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Mixtral model config 17 | import dataclasses 18 | from dataclasses import dataclass 19 | 20 | 21 | def find_multiple(n: int, k: int) -> int: 22 | if n % k == 0: 23 | return n 24 | return n + k - (n % k) 25 | 26 | 27 | @dataclass 28 | class ModelArgs: 29 | block_size: int = 2048 30 | vocab_size: int = 32000 31 | n_layer: int = 32 32 | n_head: int = 32 33 | dim: int = 4096 34 | intermediate_size: int = None 35 | n_local_heads: int = -1 36 | head_dim: int = 64 37 | rope_base: float = 10000 38 | norm_eps: float = 1e-5 39 | num_experts: int = 8 40 | num_activated_experts: int = 2 41 | device: str = "meta" 42 | 43 | def __post_init__(self): 44 | if self.n_local_heads == -1: 45 | self.n_local_heads = self.n_head 46 | if self.intermediate_size is None: 47 | hidden_dim = 4 * self.dim 48 | n_hidden = int(2 * hidden_dim / 3) 49 | self.intermediate_size = find_multiple(n_hidden, 256) 50 | self.head_dim = self.dim // self.n_head 51 | 52 | @classmethod 53 | def from_name(cls, name: str): 54 | if name in transformer_configs: 55 | return cls(**transformer_configs[name]) 56 | # fuzzy search 57 | config = [ 58 | config 59 | for config in transformer_configs 60 | if config in str(name).upper() or config in str(name) 61 | ] 62 | assert len(config) == 1, name 63 | return cls(**transformer_configs[config[0]]) 64 | 65 | 66 | transformer_configs = { 67 | "Mixtral-8x7B-v0.1": dict( 68 | block_size=32768, 69 | n_layer=32, 70 | n_head=32, 71 | n_local_heads=8, 72 | dim=4096, 73 | intermediate_size=14336, 74 | rope_base=1000000.0, 75 | num_experts=8, 76 | num_activated_experts=2, 77 | ), 78 | "Mixtral-tiny": dict( 79 | block_size=128, 80 | n_layer=3, 81 | n_head=32, 82 | n_local_heads=8, 83 | dim=128, 84 | intermediate_size=None, 85 | rope_base=1000000.0, 86 | num_experts=8, 87 | num_activated_experts=2, 88 | ), 89 | } 90 | -------------------------------------------------------------------------------- /jetstream_pt/third_party/mixtral/tokenizer.model: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AI-Hypercomputer/jetstream-pytorch/f4d775358c3eadcf34d4ab0ba244cb1167f7a628/jetstream_pt/third_party/mixtral/tokenizer.model -------------------------------------------------------------------------------- /jetstream_pt/torchjax.py: -------------------------------------------------------------------------------- 1 | """This file will serve as proxy APIs for torch_xla2 API. 2 | 3 | It serves 2 purposes: 4 | 5 | 1. torch_xla2 APIs are not 6 | stable yet, and changes of it means lots of code edits throughout 7 | this repo. So future changes of torch_xla2 API we only need to edit 8 | this one file. 9 | 10 | 2. We can iterate API look and feel in this file and the influence 11 | how it looks like in torch_xla2. 12 | """ 13 | 14 | import torch 15 | from torch.utils import _pytree as pytree 16 | 17 | import torch_xla2 18 | import torch_xla2.interop 19 | 20 | call_jax = torch_xla2.interop.call_jax 21 | call_torch = torch_xla2.interop.call_torch 22 | 23 | 24 | def to_torch(tensors): 25 | """Wrap a jax Array into XLATensor.""" 26 | return torch_xla2.default_env().j2t_iso(tensors) 27 | 28 | 29 | def from_torch_with_copy(tensors): 30 | """Convert torch tensor to Jax Array.""" 31 | 32 | def convert_tensor(t): 33 | if isinstance(t, torch_xla2.tensor.XLATensor2): 34 | return t.jax() 35 | return torch_xla2.tensor.t2j(t) 36 | 37 | return pytree.tree_map_only(torch.Tensor, convert_tensor, tensors) 38 | 39 | 40 | def from_torch(tensors): 41 | """Unwrap a XLATensor into jax Array. 42 | 43 | Will raise if passed in a torch.Tensor that is not XLATensor 44 | """ 45 | return torch_xla2.default_env().t2j_iso(tensors) 46 | -------------------------------------------------------------------------------- /kuberay/image/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM rayproject/ray:2.32.0-py310 2 | 3 | RUN pip install flax==0.8.3 4 | RUN pip install jax[tpu]==0.4.30 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 5 | RUN pip install tensorflow-text 6 | RUN pip install tensorflow 7 | 8 | RUN pip install torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu 9 | RUN pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage 10 | RUN pip install safetensors colorama coverage humanize 11 | 12 | RUN git clone https://github.com/google/jetstream-pytorch 13 | WORKDIR jetstream-pytorch 14 | 15 | RUN git submodule update --init --recursive 16 | RUN pip install -e . 17 | -------------------------------------------------------------------------------- /kuberay/manifests/ray-cluster.tpu-v4-multihost.yaml: -------------------------------------------------------------------------------- 1 | # This template contains a Kuberay cluster using a 2x2x2 TPU v4 PodSlice. 
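# A 2x2x2 v4 slice spans 8 chips on 2 hosts, which is why the worker group below
# sets numOfHosts: 2 and requests google.com/tpu: "4" per worker pod.
# The container image and GCS bucket names in this manifest are project-specific
# examples; substitute your own image and checkpoint bucket before applying it.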
2 | # To get access to TPU resources, please follow instructions in this link: 3 | # https://cloud.google.com/kubernetes-engine/docs/how-to/tpus 4 | apiVersion: ray.io/v1 5 | kind: RayCluster 6 | metadata: 7 | name: example-cluster-kuberay 8 | spec: 9 | headGroupSpec: 10 | rayStartParams: 11 | {} 12 | template: 13 | spec: 14 | imagePullSecrets: 15 | [] 16 | serviceAccountName: ray-ksa 17 | containers: 18 | - volumeMounts: 19 | - name: gcs-fuse-checkpoint 20 | mountPath: /llama 21 | readOnly: true 22 | - mountPath: /tmp/ray 23 | name: ray-logs 24 | name: ray-head 25 | image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240729 26 | imagePullPolicy: IfNotPresent 27 | resources: 28 | limits: 29 | cpu: "4" 30 | ephemeral-storage: 30Gi 31 | memory: 40G 32 | requests: 33 | cpu: "4" 34 | ephemeral-storage: 30Gi 35 | memory: 40G 36 | securityContext: 37 | {} 38 | env: 39 | - name: JAX_PLATFORMS 40 | value: "cpu" 41 | - name: RAY_memory_monitor_refresh_ms 42 | value: "0" 43 | - name: RAY_GRAFANA_IFRAME_HOST 44 | value: http://${grafana_host} 45 | - name: RAY_GRAFANA_HOST 46 | value: http://grafana:80 47 | - name: RAY_PROMETHEUS_HOST 48 | value: http://frontend:9090 49 | ports: 50 | - containerPort: 6379 51 | name: gcs 52 | - containerPort: 8265 53 | name: dashboard 54 | - containerPort: 10001 55 | name: client 56 | - containerPort: 8000 57 | name: serve 58 | - containerPort: 8471 59 | name: slicebuilder 60 | - containerPort: 8081 61 | name: mxla 62 | - containerPort: 8888 63 | name: grpc 64 | volumes: 65 | - emptyDir: {} 66 | name: ray-logs 67 | - name: gcs-fuse-checkpoint 68 | csi: 69 | driver: gcsfuse.csi.storage.gke.io 70 | readOnly: true 71 | volumeAttributes: 72 | bucketName: ricliu-llama2-70b-chat 73 | mountOptions: "implicit-dirs" 74 | metadata: 75 | annotations: 76 | gke-gcsfuse/volumes: "true" 77 | labels: 78 | cloud.google.com/gke-ray-node-type: head 79 | app.kubernetes.io/name: kuberay 80 | app.kubernetes.io/instance: example-cluster 81 | 82 | workerGroupSpecs: 83 | - rayStartParams: 84 | {} 85 | replicas: 1 86 | minReplicas: 1 87 | maxReplicas: 1 88 | numOfHosts: 2 89 | groupName: workergroup 90 | template: 91 | spec: 92 | imagePullSecrets: 93 | [] 94 | serviceAccountName: ray-ksa 95 | containers: 96 | - volumeMounts: 97 | - mountPath: /tmp/ray 98 | name: ray-logs 99 | - name: gcs-fuse-checkpoint 100 | mountPath: /llama 101 | readOnly: true 102 | name: ray-worker 103 | image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240729 104 | imagePullPolicy: IfNotPresent 105 | resources: 106 | limits: 107 | cpu: "8" 108 | ephemeral-storage: 30Gi 109 | google.com/tpu: "4" 110 | memory: 200G 111 | requests: 112 | cpu: "8" 113 | ephemeral-storage: 30Gi 114 | google.com/tpu: "4" 115 | memory: 200G 116 | securityContext: 117 | {} 118 | env: 119 | - name: JAX_PLATFORMS 120 | value: "cpu" 121 | ports: 122 | null 123 | volumes: 124 | - emptyDir: {} 125 | name: ray-logs 126 | - name: gcs-fuse-checkpoint 127 | csi: 128 | driver: gcsfuse.csi.storage.gke.io 129 | readOnly: true 130 | volumeAttributes: 131 | bucketName: ricliu-llama2-70b-chat 132 | mountOptions: "implicit-dirs" 133 | nodeSelector: 134 | cloud.google.com/gke-tpu-accelerator: tpu-v4-podslice 135 | cloud.google.com/gke-tpu-topology: 2x2x2 136 | iam.gke.io/gke-metadata-server-enabled: "true" 137 | metadata: 138 | annotations: 139 | gke-gcsfuse/volumes: "true" 140 | labels: 141 | cloud.google.com/gke-ray-node-type: worker 142 | app.kubernetes.io/name: kuberay 143 | app.kubernetes.io/instance: example-cluster 144 | 145 | 
-------------------------------------------------------------------------------- /kuberay/manifests/ray-cluster.tpu-v4-singlehost.yaml: -------------------------------------------------------------------------------- 1 | # This template contains a Kuberay cluster using a 2x2x1 TPU v4 PodSlice. 2 | # To get access to TPU resources, please follow instructions in this link: 3 | # https://cloud.google.com/kubernetes-engine/docs/how-to/tpus 4 | apiVersion: ray.io/v1 5 | kind: RayCluster 6 | metadata: 7 | name: example-cluster-kuberay 8 | spec: 9 | headGroupSpec: 10 | rayStartParams: 11 | {} 12 | template: 13 | spec: 14 | imagePullSecrets: 15 | [] 16 | serviceAccountName: ray-ksa 17 | containers: 18 | - volumeMounts: 19 | - name: gcs-fuse-checkpoint 20 | mountPath: /llama 21 | readOnly: true 22 | - mountPath: /tmp/ray 23 | name: ray-logs 24 | name: ray-head 25 | image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240729 26 | imagePullPolicy: IfNotPresent 27 | resources: 28 | limits: 29 | cpu: "4" 30 | ephemeral-storage: 30Gi 31 | memory: 40G 32 | requests: 33 | cpu: "4" 34 | ephemeral-storage: 30Gi 35 | memory: 40G 36 | securityContext: 37 | {} 38 | env: 39 | - name: JAX_PLATFORMS 40 | value: "cpu" 41 | - name: RAY_memory_monitor_refresh_ms 42 | value: "0" 43 | - name: RAY_GRAFANA_IFRAME_HOST 44 | value: http://${grafana_host} 45 | - name: RAY_GRAFANA_HOST 46 | value: http://grafana:80 47 | - name: RAY_PROMETHEUS_HOST 48 | value: http://frontend:9090 49 | ports: 50 | - containerPort: 6379 51 | name: gcs 52 | - containerPort: 8265 53 | name: dashboard 54 | - containerPort: 10001 55 | name: client 56 | - containerPort: 8000 57 | name: serve 58 | - containerPort: 8888 59 | name: grpc 60 | volumes: 61 | - emptyDir: {} 62 | name: ray-logs 63 | - name: gcs-fuse-checkpoint 64 | csi: 65 | driver: gcsfuse.csi.storage.gke.io 66 | readOnly: true 67 | volumeAttributes: 68 | bucketName: ricliu-llama2 69 | mountOptions: "implicit-dirs" 70 | metadata: 71 | annotations: 72 | gke-gcsfuse/volumes: "true" 73 | labels: 74 | cloud.google.com/gke-ray-node-type: head 75 | app.kubernetes.io/name: kuberay 76 | app.kubernetes.io/instance: example-cluster 77 | 78 | workerGroupSpecs: 79 | - rayStartParams: 80 | {} 81 | replicas: 1 82 | minReplicas: 1 83 | maxReplicas: 1 84 | numOfHosts: 1 85 | groupName: workergroup 86 | template: 87 | spec: 88 | imagePullSecrets: 89 | [] 90 | serviceAccountName: ray-ksa 91 | containers: 92 | - volumeMounts: 93 | - mountPath: /tmp/ray 94 | name: ray-logs 95 | - name: gcs-fuse-checkpoint 96 | mountPath: /llama 97 | readOnly: true 98 | name: ray-worker 99 | image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240729 100 | imagePullPolicy: IfNotPresent 101 | resources: 102 | limits: 103 | cpu: "8" 104 | ephemeral-storage: 30Gi 105 | google.com/tpu: "4" 106 | memory: 200G 107 | requests: 108 | cpu: "8" 109 | ephemeral-storage: 30Gi 110 | google.com/tpu: "4" 111 | memory: 200G 112 | securityContext: 113 | {} 114 | env: 115 | - name: JAX_PLATFORMS 116 | value: "cpu" 117 | ports: 118 | null 119 | volumes: 120 | - emptyDir: {} 121 | name: ray-logs 122 | - name: gcs-fuse-checkpoint 123 | csi: 124 | driver: gcsfuse.csi.storage.gke.io 125 | readOnly: true 126 | volumeAttributes: 127 | bucketName: ricliu-llama2 128 | mountOptions: "implicit-dirs" 129 | nodeSelector: 130 | cloud.google.com/gke-tpu-accelerator: tpu-v4-podslice 131 | cloud.google.com/gke-tpu-topology: 2x2x1 132 | iam.gke.io/gke-metadata-server-enabled: "true" 133 | metadata: 134 | annotations: 135 | gke-gcsfuse/volumes: "true" 136 | 
labels: 137 | cloud.google.com/gke-ray-node-type: worker 138 | app.kubernetes.io/name: kuberay 139 | app.kubernetes.io/instance: example-cluster 140 | 141 | -------------------------------------------------------------------------------- /kuberay/manifests/ray-cluster.tpu-v5-multihost.yaml: -------------------------------------------------------------------------------- 1 | # This template contains a Kuberay cluster using a 2x2x2 TPU v4 PodSlice. 2 | # To get access to TPU resources, please follow instructions in this link: 3 | # https://cloud.google.com/kubernetes-engine/docs/how-to/tpus 4 | apiVersion: ray.io/v1 5 | kind: RayCluster 6 | metadata: 7 | name: example-cluster-kuberay 8 | spec: 9 | headGroupSpec: 10 | rayStartParams: 11 | {} 12 | template: 13 | spec: 14 | imagePullSecrets: 15 | [] 16 | serviceAccountName: ray-ksa 17 | containers: 18 | - volumeMounts: 19 | - name: gcs-fuse-checkpoint 20 | mountPath: /llama 21 | readOnly: true 22 | - mountPath: /tmp/ray 23 | name: ray-logs 24 | name: ray-head 25 | image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240729 26 | imagePullPolicy: IfNotPresent 27 | resources: 28 | limits: 29 | cpu: "4" 30 | ephemeral-storage: 30Gi 31 | memory: 40G 32 | requests: 33 | cpu: "4" 34 | ephemeral-storage: 30Gi 35 | memory: 40G 36 | securityContext: 37 | {} 38 | env: 39 | - name: JAX_PLATFORMS 40 | value: "cpu" 41 | - name: RAY_memory_monitor_refresh_ms 42 | value: "0" 43 | - name: RAY_GRAFANA_IFRAME_HOST 44 | value: http://${grafana_host} 45 | - name: RAY_GRAFANA_HOST 46 | value: http://grafana:80 47 | - name: RAY_PROMETHEUS_HOST 48 | value: http://frontend:9090 49 | ports: 50 | - containerPort: 6379 51 | name: gcs 52 | - containerPort: 8265 53 | name: dashboard 54 | - containerPort: 10001 55 | name: client 56 | - containerPort: 8000 57 | name: serve 58 | - containerPort: 8471 59 | name: slicebuilder 60 | - containerPort: 8081 61 | name: mxla 62 | - containerPort: 8888 63 | name: grpc 64 | volumes: 65 | - emptyDir: {} 66 | name: ray-logs 67 | - name: gcs-fuse-checkpoint 68 | csi: 69 | driver: gcsfuse.csi.storage.gke.io 70 | readOnly: true 71 | volumeAttributes: 72 | bucketName: ricliu-llama2-70b-chat 73 | mountOptions: "implicit-dirs" 74 | metadata: 75 | annotations: 76 | gke-gcsfuse/volumes: "true" 77 | labels: 78 | cloud.google.com/gke-ray-node-type: head 79 | app.kubernetes.io/name: kuberay 80 | app.kubernetes.io/instance: example-cluster 81 | 82 | workerGroupSpecs: 83 | - rayStartParams: 84 | {} 85 | replicas: 1 86 | minReplicas: 1 87 | maxReplicas: 1 88 | numOfHosts: 2 89 | groupName: workergroup 90 | template: 91 | spec: 92 | imagePullSecrets: 93 | [] 94 | serviceAccountName: ray-ksa 95 | containers: 96 | - volumeMounts: 97 | - mountPath: /tmp/ray 98 | name: ray-logs 99 | - name: gcs-fuse-checkpoint 100 | mountPath: /llama 101 | readOnly: true 102 | name: ray-worker 103 | image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240729 104 | imagePullPolicy: IfNotPresent 105 | resources: 106 | limits: 107 | cpu: "8" 108 | ephemeral-storage: 30Gi 109 | google.com/tpu: "4" 110 | memory: 180G 111 | requests: 112 | cpu: "8" 113 | ephemeral-storage: 30Gi 114 | google.com/tpu: "4" 115 | memory: 180G 116 | securityContext: 117 | {} 118 | env: 119 | - name: JAX_PLATFORMS 120 | value: "cpu" 121 | ports: 122 | null 123 | volumes: 124 | - emptyDir: {} 125 | name: ray-logs 126 | - name: gcs-fuse-checkpoint 127 | csi: 128 | driver: gcsfuse.csi.storage.gke.io 129 | readOnly: true 130 | volumeAttributes: 131 | bucketName: ricliu-llama2-70b-chat 132 | mountOptions: 
"implicit-dirs" 133 | nodeSelector: 134 | iam.gke.io/gke-metadata-server-enabled: "true" 135 | cloud.google.com/gke-tpu-topology: 2x4 136 | cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice 137 | metadata: 138 | annotations: 139 | gke-gcsfuse/volumes: "true" 140 | labels: 141 | cloud.google.com/gke-ray-node-type: worker 142 | app.kubernetes.io/name: kuberay 143 | app.kubernetes.io/instance: example-cluster 144 | 145 | -------------------------------------------------------------------------------- /kuberay/manifests/ray-cluster.tpu-v5-singlehost.yaml: -------------------------------------------------------------------------------- 1 | # This template contains a Kuberay cluster using a 2x2x1 TPU v4 PodSlice. 2 | # To get access to TPU resources, please follow instructions in this link: 3 | # https://cloud.google.com/kubernetes-engine/docs/how-to/tpus 4 | apiVersion: ray.io/v1 5 | kind: RayCluster 6 | metadata: 7 | name: example-cluster-kuberay 8 | spec: 9 | headGroupSpec: 10 | rayStartParams: 11 | {} 12 | template: 13 | spec: 14 | imagePullSecrets: 15 | [] 16 | serviceAccountName: ray-ksa 17 | containers: 18 | - volumeMounts: 19 | - name: gcs-fuse-checkpoint 20 | mountPath: /llama 21 | readOnly: true 22 | - mountPath: /tmp/ray 23 | name: ray-logs 24 | name: ray-head 25 | image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240729 26 | imagePullPolicy: IfNotPresent 27 | resources: 28 | limits: 29 | cpu: "4" 30 | ephemeral-storage: 30Gi 31 | memory: 40G 32 | requests: 33 | cpu: "4" 34 | ephemeral-storage: 30Gi 35 | memory: 40G 36 | securityContext: 37 | {} 38 | env: 39 | - name: JAX_PLATFORMS 40 | value: "cpu" 41 | - name: RAY_memory_monitor_refresh_ms 42 | value: "0" 43 | - name: RAY_GRAFANA_IFRAME_HOST 44 | value: http://${grafana_host} 45 | - name: RAY_GRAFANA_HOST 46 | value: http://grafana:80 47 | - name: RAY_PROMETHEUS_HOST 48 | value: http://frontend:9090 49 | ports: 50 | - containerPort: 6379 51 | name: gcs 52 | - containerPort: 8265 53 | name: dashboard 54 | - containerPort: 10001 55 | name: client 56 | - containerPort: 8000 57 | name: serve 58 | - containerPort: 8888 59 | name: grpc 60 | volumes: 61 | - emptyDir: {} 62 | name: ray-logs 63 | - name: gcs-fuse-checkpoint 64 | csi: 65 | driver: gcsfuse.csi.storage.gke.io 66 | readOnly: true 67 | volumeAttributes: 68 | bucketName: ricliu-llama2 69 | mountOptions: "implicit-dirs" 70 | metadata: 71 | annotations: 72 | gke-gcsfuse/volumes: "true" 73 | labels: 74 | cloud.google.com/gke-ray-node-type: head 75 | app.kubernetes.io/name: kuberay 76 | app.kubernetes.io/instance: example-cluster 77 | 78 | workerGroupSpecs: 79 | - rayStartParams: 80 | {} 81 | replicas: 1 82 | minReplicas: 1 83 | maxReplicas: 1 84 | numOfHosts: 1 85 | groupName: workergroup 86 | template: 87 | spec: 88 | imagePullSecrets: 89 | [] 90 | serviceAccountName: ray-ksa 91 | containers: 92 | - volumeMounts: 93 | - mountPath: /tmp/ray 94 | name: ray-logs 95 | - name: gcs-fuse-checkpoint 96 | mountPath: /llama 97 | readOnly: true 98 | name: ray-worker 99 | image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240729 100 | imagePullPolicy: IfNotPresent 101 | resources: 102 | limits: 103 | cpu: "8" 104 | ephemeral-storage: 30Gi 105 | google.com/tpu: "8" 106 | memory: 200G 107 | requests: 108 | cpu: "8" 109 | ephemeral-storage: 30Gi 110 | google.com/tpu: "8" 111 | memory: 200G 112 | securityContext: 113 | {} 114 | env: 115 | - name: JAX_PLATFORMS 116 | value: "cpu" 117 | ports: 118 | null 119 | volumes: 120 | - emptyDir: {} 121 | name: ray-logs 122 | - name: 
gcs-fuse-checkpoint 123 | csi: 124 | driver: gcsfuse.csi.storage.gke.io 125 | readOnly: true 126 | volumeAttributes: 127 | bucketName: ricliu-llama2 128 | mountOptions: "implicit-dirs" 129 | nodeSelector: 130 | cloud.google.com/gke-tpu-topology: 2x4 131 | cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice 132 | iam.gke.io/gke-metadata-server-enabled: "true" 133 | metadata: 134 | annotations: 135 | gke-gcsfuse/volumes: "true" 136 | labels: 137 | cloud.google.com/gke-ray-node-type: worker 138 | app.kubernetes.io/name: kuberay 139 | app.kubernetes.io/instance: example-cluster 140 | 141 | -------------------------------------------------------------------------------- /mlperf/README.md: -------------------------------------------------------------------------------- 1 | # Run MLPerf tests 2 | 3 | NOTE: currently only tried with mixtral; 4 | and only tried with offline benchmark 5 | 6 | # How to run 7 | 8 | ### 1. Install 9 | 10 | ``` 11 | ./install.sh 12 | ``` 13 | 14 | ### 2. Start server 15 | 16 | ``` 17 | ./start_server.sh 18 | ``` 19 | 20 | ### 3. Warm up the server 21 | 22 | ``` 23 | python warmup.py 24 | ``` 25 | 26 | ### 4. Run the benchmark, now it runs offline mode 27 | 28 | ``` 29 | ./benchmark_run.sh 30 | ``` 31 | 32 | -------------------------------------------------------------------------------- /mlperf/benchmark_run.sh: -------------------------------------------------------------------------------- 1 | BASEDIR=mlperf 2 | API_URL=0.0.0.0:9000 3 | USER_CONFIG=$BASEDIR/user.conf 4 | DATA_DISK_DIR=$BASEDIR/data 5 | TOTAL_SAMPLE_COUNT=1000 6 | DATASET_PATH=$BASEDIR/data/mixtral_15k_data.pkl 7 | 8 | # HF model id 9 | TOKENIZER_PATH="mistralai/Mixtral-8x7B-Instruct-v0.1" 10 | 11 | LOADGEN_RUN_TYPE=offline-performance 12 | OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID} 13 | OUTPUT_LOG_ID=${MODEL_NAME}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TIMESTAMP} 14 | 15 | mkdir -p ${OUTPUT_LOG_DIR} && cp ../${USER_CONFIG} ${OUTPUT_LOG_DIR} 16 | 17 | pushd .. 18 | python -m mlperf.main \ 19 | --api-url ${API_URL} \ 20 | --scenario Offline \ 21 | --input-mode tokenized \ 22 | --output-mode tokenized \ 23 | --log-pred-outputs \ 24 | --mlperf-conf $BASEDIR/mlperf.conf \ 25 | --user-conf ${USER_CONFIG} \ 26 | --audit-conf no-audit \ 27 | --total-sample-count ${TOTAL_SAMPLE_COUNT} \ 28 | --dataset-path ${DATASET_PATH} \ 29 | --tokenizer-path ${TOKENIZER_PATH} \ 30 | --log-interval 1000 \ 31 | --output-log-dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_accuracy_log.log 32 | popd -------------------------------------------------------------------------------- /mlperf/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
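# Thin wrapper around the preprocessed MLPerf pickle file: exposes comma-joined
# token-id strings, raw input text, per-sample token lengths, and the source
# dataset name for each sample, in the shape the benchmark SUT consumes.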
14 | 15 | import logging 16 | import os 17 | 18 | import pandas as pd 19 | 20 | logging.basicConfig(level=logging.INFO) 21 | log = logging.getLogger("dataset.py") 22 | 23 | 24 | class Dataset: 25 | 26 | def __init__( 27 | self, 28 | dataset_path: str, 29 | input_mode: str, 30 | total_sample_count: int = 15000, 31 | perf_count_override: int = None, 32 | ): 33 | if not os.path.isfile(dataset_path): 34 | log.warn( 35 | "Processed pickle file {} not found. Please check that the path is correct".format( 36 | dataset_path 37 | ) 38 | ) 39 | self.dataset_path = dataset_path 40 | 41 | self._input_mode = validate_sample_mode(input_mode) 42 | self.load_processed_dataset() 43 | 44 | self.total_sample_count = min(len(self.input_ids_strs), total_sample_count) 45 | self.perf_count = perf_count_override or self.total_sample_count 46 | 47 | @property 48 | def input_ids_strs(self): 49 | return self._input_ids_strs 50 | 51 | @property 52 | def input_texts(self): 53 | return self._input_texts 54 | 55 | @property 56 | def input_token_lengths(self): 57 | return self._input_token_lengths 58 | 59 | @property 60 | def inputs(self): 61 | return self._inputs 62 | 63 | @property 64 | def inputs_with_token_lengths(self): 65 | return self._inputs_with_token_lengths 66 | 67 | @property 68 | def input_datasets(self): 69 | return self._input_datasets 70 | 71 | def load_processed_dataset(self): 72 | processed_data = pd.read_pickle(self.dataset_path) 73 | # processed_data = processed_data[processed_data["dataset"] == "MBXP"] 74 | # processed_data = processed_data.reset_index(drop=True) 75 | 76 | self._input_ids_strs = [] 77 | for input_ids in processed_data["tok_input"]: 78 | input_ids_str = ",".join([str(input_id) for input_id in input_ids]) 79 | self._input_ids_strs.append(input_ids_str) 80 | 81 | self._input_texts = [] 82 | for input_text in processed_data["input"]: 83 | self._input_texts.append(input_text) 84 | 85 | self._input_token_lengths = [] 86 | for token_length in processed_data["tok_input_len"]: 87 | self._input_token_lengths.append(token_length) 88 | 89 | log.info(f"input_mode is {self._input_mode}") 90 | self._inputs = ( 91 | self._input_ids_strs 92 | if self._input_mode == "tokenized" 93 | else self._input_texts 94 | ) 95 | log.info(f"example sample input is {self._inputs[0]}") 96 | self._inputs_with_token_lengths = [ 97 | (input_ids_str_or_input_text, token_length) 98 | for input_ids_str_or_input_text, token_length in zip( 99 | self._inputs, self._input_token_lengths 100 | ) 101 | ] 102 | 103 | self._input_datasets = [] 104 | for dataset in processed_data["dataset"]: 105 | self._input_datasets.append(dataset) 106 | log.info( 107 | f"example sample input dataset is {self._input_datasets[0]} and total {len(self._input_datasets)}" 108 | ) 109 | 110 | def LoadSamplesToRam(self, sample_list): 111 | pass 112 | 113 | def UnloadSamplesFromRam(self, sample_list): 114 | pass 115 | 116 | def __del__(self): 117 | pass 118 | 119 | 120 | SAMPLE_MODE_CHOICES = ["tokenized", "text"] 121 | 122 | 123 | def validate_sample_mode(sample_mode: str) -> str: 124 | if sample_mode not in SAMPLE_MODE_CHOICES: 125 | raise ValueError( 126 | "The sample_mode should be set to either `tokenized` or `text`." 
127 | ) 128 | return sample_mode 129 | -------------------------------------------------------------------------------- /mlperf/install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATA_DISK_DIR=data 4 | 5 | mkdir -p $DATA_DISK_DIR 6 | 7 | pip install -U "huggingface_hub[cli]" 8 | pip install \ 9 | transformers \ 10 | nltk==3.8.1 \ 11 | evaluate==0.4.0 \ 12 | absl-py==1.4.0 \ 13 | rouge-score==0.1.2 \ 14 | sentencepiece==0.1.99 \ 15 | accelerate==0.21.0 16 | 17 | # install loadgen 18 | pip install mlperf-loadgen 19 | 20 | 21 | pushd $DATA_DISK_DIR 22 | 23 | # model weights 24 | gcloud storage cp gs://sixiang_gcp/mixtral-instruct-quantized ./ --recursive 25 | # NOTE: uncomment one so you dont download too much weights to your box 26 | # gcloud storage cp gs://sixiang_gcp/llama2-70b/llama2-70b/ ./ --recursive 27 | 28 | # Get mixtral data 29 | wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl 30 | mv mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl mixtral_15k_data.pkl 31 | wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl 32 | mv mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl mixtral_15k_calibration_data.pkl 33 | 34 | # Get llama70b data 35 | gcloud storage cp \ 36 | gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl \ 37 | processed-calibration-data.pkl 38 | gcloud storage cp \ 39 | gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl \ 40 | processed-data.pkl 41 | popd 42 | -------------------------------------------------------------------------------- /mlperf/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | 17 | import gc 18 | import logging 19 | import os 20 | import sys 21 | 22 | from . import backend 23 | 24 | import mlperf_loadgen as lg 25 | 26 | _MLPERF_ID = "mixtral-8x7b" 27 | 28 | sys.path.insert(0, os.getcwd()) 29 | 30 | logging.basicConfig(level=logging.INFO) 31 | log = logging.getLogger("main.py") 32 | 33 | 34 | def get_args(): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument( 37 | "--scenario", 38 | type=str, 39 | choices=["Offline", "Server"], 40 | default="Offline", 41 | help="Scenario", 42 | ) 43 | parser.add_argument( 44 | "--api-url", type=str, default=None, help="SAX published model path." 
45 | ) 46 | parser.add_argument("--dataset-path", type=str, default=None, help="") 47 | parser.add_argument("--tokenizer-path", type=str, default=None, help="") 48 | parser.add_argument( 49 | "--accuracy", action="store_true", help="Run accuracy mode" 50 | ) 51 | parser.add_argument("--is-stream", action="store_true", help="") 52 | parser.add_argument( 53 | "--input-mode", 54 | type=str, 55 | choices=["text", "tokenized"], 56 | default="tokenized", 57 | ) 58 | parser.add_argument( 59 | "--output-mode", 60 | type=str, 61 | choices=["text", "tokenized"], 62 | default="tokenized", 63 | ) 64 | parser.add_argument( 65 | "--max-output-len", type=int, default=1024, help="Maximum output len" 66 | ) 67 | parser.add_argument( 68 | "--audit-conf", 69 | type=str, 70 | default="audit.conf", 71 | help="audit config for LoadGen settings during compliance runs", 72 | ) 73 | parser.add_argument( 74 | "--mlperf-conf", 75 | type=str, 76 | default="mlperf.conf", 77 | help="mlperf rules config", 78 | ) 79 | parser.add_argument( 80 | "--user-conf", 81 | type=str, 82 | default="user.conf", 83 | help="user config for user LoadGen settings such as target QPS", 84 | ) 85 | parser.add_argument( 86 | "--total-sample-count", 87 | type=int, 88 | default=15000, 89 | help="Number of samples to use in benchmark.", 90 | ) 91 | parser.add_argument( 92 | "--perf-count-override", 93 | type=int, 94 | default=None, 95 | help="Overwrite number of samples to use in benchmark.", 96 | ) 97 | parser.add_argument( 98 | "--output-log-dir", 99 | type=str, 100 | default="output-logs", 101 | help="Where logs are saved.", 102 | ) 103 | parser.add_argument( 104 | "--enable-log-trace", 105 | action="store_true", 106 | help="Enable log tracing. This file can become quite large", 107 | ) 108 | parser.add_argument( 109 | "--num-client-threads", 110 | type=int, 111 | default=200, 112 | help="Number of client threads to use", 113 | ) 114 | parser.add_argument("--batch-size-exp", type=int, default=6, help="") 115 | parser.add_argument("--log-pred-outputs", action="store_true", help="") 116 | parser.add_argument( 117 | "--log-interval", 118 | type=int, 119 | default=1000, 120 | help="Logging interval in seconds", 121 | ) 122 | parser.add_argument( 123 | "--user-conf-override-path", 124 | type=str, 125 | default="", 126 | help="When given overrides the default user.conf path", 127 | ) 128 | 129 | args = parser.parse_args() 130 | return args 131 | 132 | 133 | scenario_map = { 134 | "offline": lg.TestScenario.Offline, 135 | "server": lg.TestScenario.Server, 136 | } 137 | 138 | 139 | def main(): 140 | args = get_args() 141 | 142 | settings = lg.TestSettings() 143 | settings.scenario = scenario_map[args.scenario.lower()] 144 | if args.user_conf_override_path: 145 | user_conf = args.user_conf_override_path 146 | else: 147 | user_conf = args.user_conf 148 | 149 | settings.FromConfig(args.mlperf_conf, _MLPERF_ID, args.scenario) 150 | settings.FromConfig(user_conf, _MLPERF_ID, args.scenario) 151 | log.info("Mlperf config: %s", args.mlperf_conf) 152 | log.info("User config: %s", user_conf) 153 | 154 | if args.accuracy: 155 | settings.mode = lg.TestMode.AccuracyOnly 156 | log.warning( 157 | "Accuracy run will generate the accuracy logs, but the evaluation of the log is not completed yet" 158 | ) 159 | else: 160 | settings.mode = lg.TestMode.PerformanceOnly 161 | settings.print_timestamps = True 162 | 163 | settings.use_token_latencies = True 164 | 165 | os.makedirs(args.output_log_dir, exist_ok=True) 166 | log_output_settings = lg.LogOutputSettings() 167 | 
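# Route LoadGen summary (and, if enabled, trace) output to the configured directory.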
log_output_settings.outdir = args.output_log_dir 168 | log_output_settings.copy_summary_to_stdout = True 169 | log_settings = lg.LogSettings() 170 | log_settings.log_output = log_output_settings 171 | log_settings.enable_trace = args.enable_log_trace 172 | 173 | sut = backend.SUT( 174 | scenario=args.scenario.lower(), 175 | api_url=args.api_url, 176 | is_stream=args.is_stream, 177 | input_mode=args.input_mode, 178 | output_mode=args.output_mode, 179 | max_output_len=args.max_output_len, 180 | dataset_path=args.dataset_path, 181 | total_sample_count=args.total_sample_count, 182 | tokenizer_path=args.tokenizer_path, 183 | perf_count_override=args.perf_count_override, 184 | num_client_threads=args.num_client_threads, 185 | log_interval=args.log_interval, 186 | batch_size_exp=args.batch_size_exp, 187 | pred_outputs_log_path=os.path.join( 188 | args.output_log_dir, "pred_outputs_logger.json" 189 | ) 190 | if args.log_pred_outputs 191 | else None, 192 | ) 193 | 194 | lgSUT = sut.sut # lg.ConstructSUT(sut.issue_queries, sut.flush_queries) 195 | log.info("Starting Benchmark run") 196 | lg.StartTestWithLogSettings( 197 | lgSUT, sut.qsl, settings, log_settings, args.audit_conf 198 | ) 199 | 200 | log.info("Run Completed!") 201 | 202 | log.info("Destroying SUT...") 203 | lg.DestroySUT(lgSUT) 204 | 205 | log.info("Destroying QSL...") 206 | lg.DestroyQSL(sut.qsl) 207 | 208 | 209 | if __name__ == "__main__": 210 | # Disable garbage collection to avoid stalls when running tests. 211 | gc.disable() 212 | main() 213 | -------------------------------------------------------------------------------- /mlperf/mlperf.conf: -------------------------------------------------------------------------------- 1 | # The format of this config file is 'key = value'. 2 | # The key has the format 'model.scenario.key'. Value is mostly int64_t. 3 | # Model maybe '*' as wildcard. In that case the value applies to all models. 4 | # All times are in milli seconds 5 | 6 | # Set performance_sample_count for each model. 7 | # User can optionally set this to higher values in user.conf. 8 | resnet50.*.performance_sample_count_override = 1024 9 | ssd-mobilenet.*.performance_sample_count_override = 256 10 | retinanet.*.performance_sample_count_override = 64 11 | bert.*.performance_sample_count_override = 10833 12 | dlrm.*.performance_sample_count_override = 204800 13 | dlrm-v2.*.performance_sample_count_override = 204800 14 | rnnt.*.performance_sample_count_override = 2513 15 | gptj.*.performance_sample_count_override = 13368 16 | llama2-70b.*.performance_sample_count_override = 24576 17 | stable-diffusion-xl.*.performance_sample_count_override = 5000 18 | # set to 0 to let entire sample set to be performance sample 19 | 3d-unet.*.performance_sample_count_override = 0 20 | 21 | # Set seeds. The seeds will be distributed two weeks before the submission. 22 | *.*.qsl_rng_seed = 3066443479025735752 23 | *.*.sample_index_rng_seed = 10688027786191513374 24 | *.*.schedule_rng_seed = 14962580496156340209 25 | # Set seeds for TEST_05. The seeds will be distributed two weeks before the submission. 
26 | *.*.test05_qsl_rng_seed = 16799458546791641818 27 | *.*.test05_sample_index_rng_seed = 5453809927556429288 28 | *.*.test05_schedule_rng_seed = 5435552105434836064 29 | 30 | 31 | *.SingleStream.target_latency_percentile = 90 32 | *.SingleStream.min_duration = 600000 33 | 34 | *.MultiStream.target_latency_percentile = 99 35 | *.MultiStream.samples_per_query = 8 36 | *.MultiStream.min_duration = 600000 37 | *.MultiStream.min_query_count = 662 38 | retinanet.MultiStream.target_latency = 528 39 | 40 | # 3D-UNet uses equal issue mode because it has non-uniform inputs 41 | 3d-unet.*.sample_concatenate_permutation = 1 42 | 43 | # LLM benchmarks have non-uniform inputs and outputs, and use equal issue mode for all latency scenario 44 | gptj.*.sample_concatenate_permutation = 1 45 | llama2-70b.*.sample_concatenate_permutation = 1 46 | mixtral-8x7B.*.sample_concatenate_permutation = 1 47 | 48 | *.Server.target_latency = 10 49 | *.Server.target_latency_percentile = 99 50 | *.Server.target_duration = 0 51 | *.Server.min_duration = 600000 52 | resnet50.Server.target_latency = 15 53 | retinanet.Server.target_latency = 100 54 | bert.Server.target_latency = 130 55 | dlrm.Server.target_latency = 60 56 | dlrm-v2.Server.target_latency = 60 57 | rnnt.Server.target_latency = 1000 58 | gptj.Server.target_latency = 20000 59 | stable-diffusion-xl.Server.target_latency = 20000 60 | # Llama2-70b benchmarks measures token latencies 61 | llama2-70b.*.use_token_latencies = 1 62 | mixtral-8x7b.*.use_token_latencies = 1 63 | # gptj benchmark infers token latencies 64 | gptj.*.infer_token_latencies = 1 65 | gptj.*.token_latency_scaling_factor = 69 66 | # Only ttft and tpot are tracked for the llama2-70b & mixtral-8x7B benchmark therefore target_latency = 0 67 | llama2-70b.Server.target_latency = 0 68 | llama2-70b.Server.ttft_latency = 2000 69 | llama2-70b.Server.tpot_latency = 200 70 | 71 | mixtral-8x7b.Server.target_latency = 0 72 | mixtral-8x7b.Server.ttft_latency = 2000 73 | mixtral-8x7b.Server.tpot_latency = 200 74 | 75 | *.Offline.target_latency_percentile = 90 76 | *.Offline.min_duration = 600000 77 | 78 | # In Offline scenario, we always have one query. But LoadGen maps this to 79 | # min_sample_count internally in Offline scenario. If the dataset size is larger 80 | # than 24576 we limit the min_query_count to 24576 and otherwise we use 81 | # the dataset size as the limit 82 | 83 | resnet50.Offline.min_query_count = 24576 84 | retinanet.Offline.min_query_count = 24576 85 | dlrm-v2.Offline.min_query_count = 24576 86 | bert.Offline.min_query_count = 10833 87 | gptj.Offline.min_query_count = 13368 88 | rnnt.Offline.min_query_count = 2513 89 | 3d-unet.Offline.min_query_count = 43 90 | stable-diffusion-xl.Offline.min_query_count = 5000 91 | llama2-70b.Offline.min_query_count = 1000 92 | mixtral-8x7b.Offline.min_query_count = 1000 93 | 94 | # These fields should be defined and overridden by user.conf. 95 | *.SingleStream.target_latency = 10 96 | *.MultiStream.target_latency = 80 97 | *.Server.target_qps = 1.0 98 | *.Offline.target_qps = 4.0 99 | -------------------------------------------------------------------------------- /mlperf/start_server.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CACHE_LENGTH=3072 4 | INPUT_SIZE=512 5 | OUTPUT_SIZE=512 6 | CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/ 7 | 8 | pushd .. 
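# The flags below assume the quantized Mixtral checkpoint that mlperf/install.sh
# downloads into mlperf/data/.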
9 | python run_server.py \ 10 | --model_name=mixtral \ 11 | --batch_size=128 \ 12 | --max_cache_length=$CACHE_LENGTH \ 13 | --max_decode_length=$OUTPUT_SIZE \ 14 | --context_length=$INPUT_SIZE \ 15 | --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \ 16 | --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \ 17 | --quantize_weights=1 \ 18 | --quantize_type=int8_per_channel \ 19 | --quantize_kv_cache=1 20 | popd -------------------------------------------------------------------------------- /mlperf/user.conf: -------------------------------------------------------------------------------- 1 | mixtral-8x7b.Server.target_qps = 1.8 2 | mixtral-8x7b.Offline.target_qps = 4.0 3 | 4 | -------------------------------------------------------------------------------- /mlperf/warmup.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | from dataclasses import dataclass, field 4 | from datetime import datetime 5 | import json 6 | import random 7 | import time 8 | from typing import Any, AsyncGenerator, Optional 9 | import os 10 | 11 | 12 | import grpc 13 | from jetstream.core.proto import jetstream_pb2 14 | from jetstream.core.proto import jetstream_pb2_grpc 15 | from jetstream.engine.token_utils import load_vocab 16 | from jetstream.third_party.llama3 import llama3_tokenizer 17 | import numpy as np 18 | from tqdm.asyncio import tqdm # pytype: disable=pyi-error 19 | import pandas 20 | 21 | 22 | @dataclass 23 | class InputRequest: 24 | prompt: str = "" 25 | prompt_len: int = 0 26 | output: str = "" 27 | output_len: int = 0 28 | sample_idx: int = -1 29 | 30 | 31 | @dataclass 32 | class RequestFuncOutput: 33 | input_request: Optional[InputRequest] = None 34 | generated_token_list: list[str] = field(default_factory=list) 35 | generated_text: str = "" 36 | success: bool = False 37 | latency: float = 0 38 | ttft: float = 0 39 | prompt_len: int = 0 40 | 41 | # Flatten the structure and return only the necessary results 42 | def to_dict(self): 43 | return { 44 | "prompt": self.input_request.prompt, 45 | "original_output": self.input_request.output, 46 | "generated_text": self.generated_text, 47 | "success": self.success, 48 | "latency": self.latency, 49 | "prompt_len": self.prompt_len, 50 | "sample_idx": self.input_request.sample_idx, 51 | } 52 | 53 | 54 | async def grpc_async_request( 55 | api_url: str, request: Any 56 | ) -> tuple[list[str], float, float]: 57 | """Send grpc synchronous request since the current grpc server is sync.""" 58 | options = [("grpc.keepalive_timeout_ms", 10000)] 59 | async with grpc.aio.insecure_channel(api_url, options=options) as channel: 60 | stub = jetstream_pb2_grpc.OrchestratorStub(channel) 61 | print("Making request") 62 | ttft = 0 63 | token_list = [] 64 | request_start_time = time.perf_counter() 65 | response = stub.Decode(request) 66 | async for resp in response: 67 | if ttft == 0: 68 | ttft = time.perf_counter() - request_start_time 69 | token_list.extend(resp.stream_content.samples[0].token_ids) 70 | latency = time.perf_counter() - request_start_time 71 | print("Done request: ", latency) 72 | return token_list, ttft, latency 73 | 74 | 75 | async def send_request( 76 | api_url: str, 77 | tokenizer: Any, 78 | input_request: InputRequest, 79 | pbar: tqdm, 80 | session_cache: str, 81 | priority: int, 82 | ) -> RequestFuncOutput: 83 | """Send the request to JetStream server.""" 84 | # Tokenization on client side following MLPerf standard. 
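# Warmup requests use random token ids of the requested bucket length; the goal
# is to exercise each prefill length bucket, not to produce meaningful output.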
85 | token_ids = np.random.randint(0, 1000, input_request.request_len) 86 | request = jetstream_pb2.DecodeRequest( 87 | session_cache=session_cache, 88 | token_content=jetstream_pb2.DecodeRequest.TokenContent( 89 | token_ids=token_ids 90 | ), 91 | priority=priority, 92 | max_tokens=input_request.output_len, 93 | ) 94 | output = RequestFuncOutput() 95 | output.input_request = input_request 96 | output.prompt_len = input_request.prompt_len 97 | generated_token_list, ttft, latency = await grpc_async_request( 98 | api_url, request 99 | ) 100 | output.ttft = ttft 101 | output.latency = latency 102 | output.generated_token_list = generated_token_list 103 | # generated_token_list is a list of token ids, decode it to generated_text. 104 | output.generated_text = "" 105 | output.success = True 106 | if pbar: 107 | pbar.update(1) 108 | return output 109 | 110 | 111 | async def benchmark( 112 | api_url: str, 113 | max_length: int, 114 | tokenizer: Any = None, 115 | request_rate: float = 0, 116 | disable_tqdm: bool = False, 117 | session_cache: str = "", 118 | priority: int = 100, 119 | ): 120 | """Benchmark the online serving performance.""" 121 | 122 | print(f"Traffic request rate: {request_rate}") 123 | 124 | benchmark_start_time = time.perf_counter() 125 | tasks = [] 126 | interesting_buckets = [ 127 | 4, 128 | 8, 129 | 16, 130 | 32, 131 | 64, 132 | 128, 133 | 256, 134 | 512, 135 | 1024, 136 | 2048, 137 | ] 138 | 139 | for length in interesting_buckets: 140 | if length > max_length: 141 | break 142 | request = InputRequest() 143 | request.request_len = length 144 | print("send request of length", request.request_len) 145 | tasks.append( 146 | asyncio.create_task( 147 | send_request( 148 | api_url=api_url, 149 | tokenizer=None, 150 | input_request=request, 151 | pbar=None, 152 | session_cache=session_cache, 153 | priority=priority, 154 | ) 155 | ) 156 | ) 157 | outputs = await asyncio.gather(*tasks) 158 | 159 | benchmark_duration = time.perf_counter() - benchmark_start_time 160 | return benchmark_duration, outputs 161 | 162 | 163 | def main(args: argparse.Namespace): 164 | print(args) 165 | random.seed(args.seed) 166 | np.random.seed(args.seed) 167 | api_url = f"{args.server}:{args.port}" 168 | 169 | benchmark_result, request_outputs = asyncio.run( 170 | benchmark(api_url=api_url, max_length=args.max_length) 171 | ) 172 | print("DURATION:", benchmark_result) 173 | 174 | 175 | if __name__ == "__main__": 176 | 177 | parser = argparse.ArgumentParser( 178 | description="Benchmark the online serving throughput." 
179 | ) 180 | parser.add_argument( 181 | "--server", 182 | type=str, 183 | default="0.0.0.0", 184 | help="Server address.", 185 | ) 186 | parser.add_argument("--seed", type=int, default=0) 187 | 188 | parser.add_argument("--port", type=str, default=9000) 189 | parser.add_argument("--max-length", type=int, default=512) 190 | 191 | parsed_args = parser.parse_args() 192 | main(parsed_args) 193 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | version = "0.2.2" 7 | name = "jetstream_pt" 8 | dependencies = [ 9 | "absl-py", 10 | "flatbuffers", 11 | "flax", 12 | "sentencepiece", 13 | "pytest", 14 | "google-jetstream", 15 | "google-cloud-storage", 16 | "safetensors", 17 | "torch_xla2 @ {root:uri}/deps/xla/experimental/torch_xla2", 18 | "google-jetstream @ {root:uri}/deps/JetStream", 19 | ] 20 | 21 | 22 | requires-python = ">=3.10" 23 | license = {file = "LICENSE"} 24 | 25 | [project.scripts] 26 | jpt = "jetstream_pt.cli:main" 27 | 28 | [tool.hatch.metadata] 29 | allow-direct-references = true 30 | -------------------------------------------------------------------------------- /run_interactive_disaggregated.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
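# Interactive driver for disaggregated serving: a prefill engine and a decode
# engine run on separate Ray-managed TPU pod slices, and prefill results are
# transferred to the decode engine before token generation begins.
# Illustrative invocation (flag values below are placeholders, not tested defaults):
#   python run_interactive_disaggregated.py --model_name=llama-2 --size=tiny \
#     --checkpoint_path=/path/to/ckpt --tokenizer_path=tokenizer.model \
#     --is_disaggregated=True --num_hosts=4 --decode_pod_slice_name=<decode-slice>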
14 | 15 | import os 16 | import random 17 | import time 18 | 19 | from typing import List 20 | from absl import app 21 | from absl import flags 22 | 23 | import jax 24 | 25 | from jetstream.engine import token_utils 26 | from jetstream_pt import ray_engine 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | _TOKENIZER_PATH = flags.DEFINE_string( 31 | "tokenizer_path", 32 | "tokenizer.model", 33 | "The tokenizer model path", 34 | required=False, 35 | ) 36 | _CKPT_PATH = flags.DEFINE_string( 37 | "checkpoint_path", None, "Directory for .pth checkpoints", required=False 38 | ) 39 | _BF16_ENABLE = flags.DEFINE_bool( 40 | "bf16_enable", False, "Whether to enable bf16", required=False 41 | ) 42 | _CONTEXT_LENGTH = flags.DEFINE_integer( 43 | "context_length", 1024, "The context length", required=False 44 | ) 45 | _BATCH_SIZE = flags.DEFINE_integer( 46 | "batch_size", 32, "The batch size", required=False 47 | ) 48 | _PROFILING_OUTPUT = flags.DEFINE_string( 49 | "profiling_output", 50 | "", 51 | "The profiling output", 52 | required=False, 53 | ) 54 | 55 | _SIZE = flags.DEFINE_string("size", "tiny", "size of model") 56 | 57 | _QUANTIZE_WEIGHTS = flags.DEFINE_bool( 58 | "quantize_weights", False, "weight quantization" 59 | ) 60 | _QUANTIZE_KV_CACHE = flags.DEFINE_bool( 61 | "quantize_kv_cache", False, "kv_cache_quantize" 62 | ) 63 | _MAX_CACHE_LENGTH = flags.DEFINE_integer( 64 | "max_cache_length", 1024, "kv_cache_quantize" 65 | ) 66 | 67 | _MODEL_NAME = flags.DEFINE_string( 68 | "model_name", None, "model type", required=False 69 | ) 70 | 71 | _SHARDING_CONFIG = flags.DEFINE_string( 72 | "sharding_config", "", "config file for sharding" 73 | ) 74 | 75 | 76 | _IS_DISAGGREGATED = flags.DEFINE_bool( 77 | "is_disaggregated", False, "Disaggregated serving if it's True" 78 | ) 79 | 80 | _NUM_HOSTS = flags.DEFINE_integer( 81 | "num_hosts", 4, "Number of TPU host", required=False 82 | ) 83 | 84 | _DECODE_POD_SLICE_NAME = flags.DEFINE_string( 85 | "decode_pod_slice_name", "", "Decode pod slice name" 86 | ) 87 | 88 | 89 | def create_disaggregated_engines(): 90 | """create a pytorch engine""" 91 | # jax.config.update("jax_default_prng_impl", "unsafe_rbg") 92 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" 93 | 94 | start = time.perf_counter() 95 | prefill_engine_list, decode_engine_list = ( 96 | ray_engine.create_pytorch_ray_engine( 97 | model_name=_MODEL_NAME.value, 98 | tokenizer_path=_TOKENIZER_PATH.value, 99 | ckpt_path=_CKPT_PATH.value, 100 | bf16_enable=True, 101 | param_size=_SIZE.value, 102 | context_length=_CONTEXT_LENGTH.value, 103 | batch_size=_BATCH_SIZE.value, 104 | quantize_weights=_QUANTIZE_WEIGHTS.value, 105 | quantize_kv=_QUANTIZE_KV_CACHE.value, 106 | max_cache_length=_MAX_CACHE_LENGTH.value, 107 | sharding_config=_SHARDING_CONFIG.value, 108 | is_disaggregated=_IS_DISAGGREGATED.value, 109 | num_hosts=_NUM_HOSTS.value, 110 | decode_pod_slice_name=_DECODE_POD_SLICE_NAME.value, 111 | ) 112 | ) 113 | 114 | print("Initialize engine", time.perf_counter() - start) 115 | return (prefill_engine_list[0], decode_engine_list[0]) 116 | 117 | 118 | # pylint: disable-next=all 119 | def main(argv): 120 | 121 | print("start the test") 122 | prefill_engine, decode_engine = create_disaggregated_engines() 123 | 124 | start = time.perf_counter() 125 | prefill_engine.load_params() 126 | decode_engine.load_params() 127 | print("Load params ", time.perf_counter() - start) 128 | 129 | metadata = prefill_engine.get_tokenizer() 130 | vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids) 131 | stop_tokens = [vocab.eos_id, 
vocab.pad_id] 132 | max_output_length = 1024 133 | 134 | if _PROFILING_OUTPUT.value: 135 | jax.profiler.start_trace(_PROFILING_OUTPUT.value) 136 | 137 | decode_engine.init_decode_state() 138 | prompts: List[str] = [ 139 | "I believe the meaning of life is", 140 | # pylint: disable-next=all 141 | "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.", 142 | # pylint: disable-next=all 143 | "[INST] <>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]", 144 | # pylint: disable-next=all 145 | "[INST] <>\nYou are an AI assistant that helps people find information. Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST", 146 | # pylint: disable-next=all 147 | "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. 
[/INST]", 148 | ] 149 | for prompt in prompts: 150 | slot = random.randint(0, _BATCH_SIZE.value - 1) 151 | tokens, true_length = token_utils.tokenize_and_pad( 152 | prompt, vocab, is_bos=True, jax_padding=False 153 | ) 154 | print(f"---- Input prompts are: {prompt}") 155 | print(f"---- Encoded tokens are: {tokens}") 156 | 157 | print( 158 | # pylint: disable-next=all 159 | f"---- Do prefill in prefill engine pod_slice_name: {prefill_engine.pod_slice_name}" 160 | ) 161 | prefill_result, _ = prefill_engine.prefill( 162 | params=None, padded_tokens=tokens, true_length=true_length 163 | ) 164 | print( 165 | # pylint: disable-next=all 166 | f"---- Transfer prefill result to decode engine pod_slice_name: {decode_engine.pod_slice_name}" 167 | ) 168 | decode_engine.transfer(prefill_result) 169 | 170 | print( 171 | # pylint: disable-next=all 172 | f"---- Do insert in decode engine pod_slice_name: {decode_engine.pod_slice_name}" 173 | ) 174 | decode_state = decode_engine.insert(prefill_result, None, slot=slot) 175 | sampled_tokens_list = [] 176 | while True: 177 | # pylint: disable-next=all 178 | decode_state, result_tokens = decode_engine.generate(None, decode_state) 179 | result_tokens = result_tokens.convert_to_numpy() 180 | 181 | slot_data = result_tokens.get_result_at_slot(slot) 182 | slot_tokens = slot_data.tokens 183 | slot_lengths = slot_data.lengths 184 | 185 | token_id = slot_tokens[slot, 0].item() 186 | if slot_lengths > max_output_length or token_id in stop_tokens: 187 | break 188 | 189 | sampled_tokens_list.append(token_id) 190 | 191 | print("---- All output tokens.") 192 | print(sampled_tokens_list) 193 | print("---- All output text.") 194 | print(vocab.tokenizer.decode(sampled_tokens_list)) 195 | 196 | if _PROFILING_OUTPUT.value: 197 | jax.profiler.stop_trace() 198 | 199 | 200 | if __name__ == "__main__": 201 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" 202 | app.run(main) 203 | -------------------------------------------------------------------------------- /run_interactive_multiple_host.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | import random 17 | import time 18 | from typing import List 19 | 20 | import jax 21 | from absl import app, flags 22 | from jetstream.engine import token_utils 23 | from jetstream_pt import ray_engine 24 | from jetstream_pt.config import FLAGS 25 | 26 | _NUM_HOSTS = flags.DEFINE_integer( 27 | "num_hosts", 0, "Number of TPU host", required=False 28 | ) 29 | 30 | _WORKER_CHIPS = flags.DEFINE_integer( 31 | "worker_chips", 4, "Number of TPU chips per worker", required=False 32 | ) 33 | 34 | _TPU_CHIPS = flags.DEFINE_integer( 35 | "tpu_chips", 4, "All devices TPU chips", required=False 36 | ) 37 | 38 | 39 | def create_engine(): 40 | """create a pytorch engine""" 41 | jax.config.update("jax_default_prng_impl", "unsafe_rbg") 42 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" 43 | 44 | start = time.perf_counter() 45 | engine = ray_engine.create_pytorch_ray_engine( 46 | model_name=FLAGS.model_name, 47 | tokenizer_path=FLAGS.tokenizer_path, 48 | ckpt_path=FLAGS.checkpoint_path, 49 | bf16_enable=FLAGS.bf16_enable, 50 | param_size=FLAGS.size, 51 | context_length=FLAGS.context_length, 52 | batch_size=FLAGS.batch_size, 53 | quantize_weights=FLAGS.quantize_weights, 54 | quantize_kv=FLAGS.quantize_kv_cache, 55 | max_cache_length=FLAGS.max_cache_length, 56 | sharding_config=FLAGS.sharding_config, 57 | num_hosts=_NUM_HOSTS.value, 58 | worker_chips=_WORKER_CHIPS.value, 59 | tpu_chips=_TPU_CHIPS.value, 60 | ) 61 | 62 | print("Initialize engine", time.perf_counter() - start) 63 | return engine 64 | 65 | 66 | # pylint: disable-next=all 67 | def main(argv): 68 | 69 | engine = create_engine() 70 | 71 | start = time.perf_counter() 72 | engine.load_params() 73 | print("Load params ", time.perf_counter() - start) 74 | 75 | metadata = engine.get_tokenizer() 76 | vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids) 77 | stop_tokens = [vocab.eos_id, vocab.pad_id] 78 | max_output_length = 1024 79 | 80 | profiling_output = FLAGS.profiling_output 81 | if profiling_output: 82 | jax.profiler.start_trace(profiling_output) 83 | 84 | engine.init_decode_state() 85 | prompts: List[str] = [ 86 | "I believe the meaning of life is", 87 | # pylint: disable-next=all 88 | "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.", 89 | # pylint: disable-next=all 90 | "[INST] <>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? 
[/INST]", 91 | # pylint: disable-next=all 92 | "[INST] <>\nYou are an AI assistant that helps people find information. Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST", 93 | # pylint: disable-next=all 94 | "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]", 95 | ] 96 | for prompt in prompts: 97 | slot = random.randint(0, FLAGS.batch_size - 1) 98 | tokens, true_length = token_utils.tokenize_and_pad( 99 | prompt, vocab, is_bos=True, jax_padding=False 100 | ) 101 | print(f"---- Input prompts are: {prompt}") 102 | print(f"---- Encoded tokens are: {tokens}") 103 | 104 | # pylint: disable-next=all 105 | prefill_result, _ = engine.prefill( 106 | params=None, padded_tokens=tokens, true_length=true_length 107 | ) 108 | # pylint: disable-next=all 109 | decode_state = engine.insert(prefill_result, None, slot=slot) 110 | sampled_tokens_list = [] 111 | while True: 112 | # pylint: disable-next=all 113 | decode_state, result_tokens = engine.generate(None, decode_state) 114 | result_tokens = result_tokens.convert_to_numpy() 115 | 116 | slot_data = result_tokens.get_result_at_slot(slot) 117 | slot_tokens = slot_data.tokens 118 | slot_lengths = slot_data.lengths 119 | 120 | token_id = slot_tokens[slot, 0].item() 121 | if slot_lengths > max_output_length or token_id in stop_tokens: 122 | break 123 | 124 | sampled_tokens_list.append(token_id) 125 | 126 | print("---- All output tokens.") 127 | print(sampled_tokens_list) 128 | print("---- All output text.") 129 | print(vocab.tokenizer.decode(sampled_tokens_list)) 130 | 131 | if profiling_output: 132 | jax.profiler.stop_trace() 133 | 134 | 135 | if __name__ == "__main__": 136 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" 137 | app.run(main) 138 | -------------------------------------------------------------------------------- /run_ray_serve_interleave.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """ Runs a RayServe deployment with Jetstream interleave mode.""" 16 | import os 17 | import time 18 | from typing import AsyncIterator 19 | from absl import app, flags 20 | 21 | from ray import serve 22 | from ray.serve.config import gRPCOptions 23 | 24 | from jetstream.core import config_lib 25 | from jetstream.core import orchestrator 26 | from jetstream.core.config_lib import ServerConfig 27 | from jetstream.core.proto import jetstream_pb2 28 | from jetstream_pt import ray_engine 29 | from jetstream_pt.config import FLAGS 30 | 31 | 32 | flags.DEFINE_string("tpu_generation", "v4", "TPU generation") 33 | flags.DEFINE_integer("tpu_chips", 16, "device tpu_chips") 34 | flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler") 35 | flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server") 36 | flags.DEFINE_integer("num_hosts", 4, "Number of TPU host", required=False) 37 | flags.DEFINE_integer( 38 | "worker_chips", 4, "Number of TPU chips per worker", required=False 39 | ) 40 | 41 | 42 | def create_head_resource_name(generation, tpu_chips): 43 | if generation == "v5litepod": 44 | return f"TPU-{generation}-{tpu_chips}-head" 45 | else: 46 | tpu_cores = tpu_chips * 2 47 | return f"TPU-{generation}-{tpu_cores}-head" 48 | 49 | 50 | def create_engine(**kwargs): 51 | """create a pytorch engine""" 52 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" 53 | 54 | start = time.perf_counter() 55 | engine = ray_engine.create_pytorch_ray_engine(**kwargs) 56 | 57 | print("Initialize engine", time.perf_counter() - start) 58 | return engine 59 | 60 | 61 | @serve.deployment 62 | class JetStreamDeployment: 63 | """JetStream deployment.""" 64 | 65 | def __init__(self, **kwargs): 66 | os.environ["XLA_FLAGS"] = ( 67 | "--xla_dump_to=/tmp/xla_logs --xla_dump_hlo_as_text" 68 | ) 69 | devices = [] 70 | for i in range(kwargs["tpu_chips"]): 71 | devices.append(i) 72 | 73 | print(f"devices: {devices}") 74 | 75 | self.engine = create_engine(**kwargs) 76 | server_config = ServerConfig( 77 | interleaved_slices=(f"tpu={len(devices)}",), 78 | interleaved_engine_create_fns=(lambda a: self.engine,), 79 | ) 80 | 81 | engines = config_lib.get_engines(server_config, devices=devices) 82 | prefill_params = [pe.load_params() for pe in engines.prefill_engines] 83 | generate_params = [ge.load_params() for ge in engines.generate_engines] 84 | shared_params = [ie.load_params() for ie in engines.interleaved_engines] 85 | print("Loaded all weights.") 86 | 87 | self.driver = orchestrator.Driver( 88 | prefill_engines=engines.prefill_engines + engines.interleaved_engines, 89 | generate_engines=engines.generate_engines + engines.interleaved_engines, 90 | prefill_params=prefill_params + shared_params, 91 | generate_params=generate_params + shared_params, 92 | interleaved_mode=True, 93 | jax_padding=False, 94 | metrics_collector=None, 95 | is_ray_backend=True, 96 | ) 97 | 98 | self.orchestrator = orchestrator.LLMOrchestrator(driver=self.driver) 99 | 100 | print("Started jetstream driver....") 101 | 102 | # pylint: disable-next=all 103 | async def Decode( 104 | self, 105 | # pylint: disable-next=all 106 | request: jetstream_pb2.DecodeRequest, 107 | # pylint: disable-next=all 108 | ) -> AsyncIterator[jetstream_pb2.DecodeResponse]: 109 | """Async decode function.""" 110 | return self.orchestrator.Decode(request) 111 | 112 | 113 | def main(_argv): 114 | """Main function""" 115 | resource_name = create_head_resource_name( 116 | FLAGS.tpu_generation, FLAGS.tpu_chips 117 | ) 118 | print(f"Using head resource 
{resource_name}") 119 | # pylint: disable-next=all 120 | deployment = JetStreamDeployment.options( 121 | ray_actor_options={"resources": {resource_name: 1}} 122 | ).bind( 123 | tpu_chips=FLAGS.tpu_chips, 124 | worker_chips=FLAGS.worker_chips, 125 | num_hosts=FLAGS.num_hosts, 126 | model_name=FLAGS.model_name, 127 | tokenizer_path=FLAGS.tokenizer_path, 128 | ckpt_path=FLAGS.checkpoint_path, 129 | bf16_enable=FLAGS.bf16_enable, 130 | param_size=FLAGS.size, 131 | context_length=FLAGS.context_length, 132 | batch_size=FLAGS.batch_size, 133 | quantize_weights=FLAGS.quantize_weights, 134 | quantize_kv=FLAGS.quantize_kv_cache, 135 | max_cache_length=FLAGS.max_cache_length, 136 | sharding_config=FLAGS.sharding_config, 137 | enable_jax_profiler=FLAGS.enable_jax_profiler, 138 | jax_profiler_port=FLAGS.jax_profiler_port, 139 | ) 140 | 141 | grpc_port = 8888 142 | grpc_servicer_functions = [ 143 | "jetstream.core.proto.jetstream_pb2_grpc.add_OrchestratorServicer_to_server", 144 | ] 145 | serve.start( 146 | grpc_options=gRPCOptions( 147 | port=grpc_port, 148 | grpc_servicer_functions=grpc_servicer_functions, 149 | ), 150 | ) 151 | 152 | serve.run(deployment) 153 | 154 | 155 | if __name__ == "__main__": 156 | app.run(main) 157 | -------------------------------------------------------------------------------- /run_server_with_ray.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Runs a pytorch server with ray.""" 16 | import os 17 | import time 18 | from typing import Sequence 19 | from absl import app, flags 20 | 21 | # import torch_xla2 first! 
22 | import jax 23 | from jetstream.core import server_lib 24 | from jetstream.core.config_lib import ServerConfig 25 | from jetstream_pt import ray_engine 26 | from jetstream_pt.config import FLAGS 27 | 28 | flags.DEFINE_integer("port", 9000, "port to listen on") 29 | flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool") 30 | flags.DEFINE_string( 31 | "config", 32 | "InterleavedCPUTestServer", 33 | "available servers", 34 | ) 35 | flags.DEFINE_integer("prometheus_port", 0, "") 36 | flags.DEFINE_integer("tpu_chips", 16, "all devices tpu_chips") 37 | 38 | flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler") 39 | flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server") 40 | 41 | flags.DEFINE_bool( 42 | "is_disaggregated", False, "Disaggregated serving if it's True" 43 | ) 44 | 45 | flags.DEFINE_integer("num_hosts", 0, "Number of TPU host", required=False) 46 | 47 | flags.DEFINE_integer( 48 | "worker_chips", 4, "Number of TPU chips per worker", required=False 49 | ) 50 | 51 | flags.DEFINE_string("decode_pod_slice_name", "", "Decode pod slice name") 52 | 53 | 54 | def create_engine(): 55 | """create a pytorch engine""" 56 | jax.config.update("jax_default_prng_impl", "unsafe_rbg") 57 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" 58 | 59 | start = time.perf_counter() 60 | engine = ray_engine.create_pytorch_ray_engine( 61 | model_name=FLAGS.model_name, 62 | tokenizer_path=FLAGS.tokenizer_path, 63 | ckpt_path=FLAGS.checkpoint_path, 64 | bf16_enable=FLAGS.bf16_enable, 65 | param_size=FLAGS.size, 66 | context_length=FLAGS.context_length, 67 | batch_size=FLAGS.batch_size, 68 | quantize_weights=FLAGS.quantize_weights, 69 | quantize_kv=FLAGS.quantize_kv_cache, 70 | max_cache_length=FLAGS.max_cache_length, 71 | sharding_config=FLAGS.sharding_config, 72 | enable_jax_profiler=FLAGS.enable_jax_profiler, 73 | jax_profiler_port=FLAGS.jax_profiler_port, 74 | num_hosts=FLAGS.num_hosts, 75 | worker_chips=FLAGS.worker_chips, 76 | tpu_chips=FLAGS.tpu_chips, 77 | ) 78 | 79 | print("Initialize engine", time.perf_counter() - start) 80 | return engine 81 | 82 | 83 | def create_disaggregated_engine(): 84 | """create a pytorch engine""" 85 | jax.config.update("jax_default_prng_impl", "unsafe_rbg") 86 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" 87 | 88 | start = time.perf_counter() 89 | prefill_engine_list, decode_engine_list = ( 90 | ray_engine.create_pytorch_ray_engine( 91 | model_name=FLAGS.model_name, 92 | tokenizer_path=FLAGS.tokenizer_path, 93 | ckpt_path=FLAGS.checkpoint_path, 94 | bf16_enable=FLAGS.bf16_enable, 95 | param_size=FLAGS.size, 96 | context_length=FLAGS.context_length, 97 | batch_size=FLAGS.batch_size, 98 | quantize_weights=FLAGS.quantize_weights, 99 | quantize_kv=FLAGS.quantize_kv_cache, 100 | max_cache_length=FLAGS.max_cache_length, 101 | sharding_config=FLAGS.sharding_config, 102 | enable_jax_profiler=FLAGS.enable_jax_profiler, 103 | jax_profiler_port=FLAGS.jax_profiler_port, 104 | is_disaggregated=FLAGS.is_disaggregated, 105 | num_hosts=FLAGS.num_hosts, 106 | decode_pod_slice_name=FLAGS.decode_pod_slice_name, 107 | ) 108 | ) 109 | 110 | print("Initialize engine", time.perf_counter() - start) 111 | return (prefill_engine_list, decode_engine_list) 112 | 113 | 114 | # pylint: disable-next=all 115 | def main(argv: Sequence[str]): 116 | del argv 117 | os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_logs --xla_dump_hlo_as_text" 118 | devices = [] 119 | for i in range(FLAGS.tpu_chips): 120 | devices.append(i) 121 | 122 | print(f"devices: 
{devices}") 123 | 124 | if FLAGS.is_disaggregated: 125 | prefill_engine_list, decode_engine_list = create_disaggregated_engine() 126 | chips = int(len(devices) / 2) 127 | server_config = ServerConfig( 128 | prefill_slices=(f"tpu={chips}",), 129 | prefill_engine_create_fns=(lambda a: prefill_engine_list[0],), 130 | generate_slices=(f"tpu={chips}",), 131 | generate_engine_create_fns=(lambda a: decode_engine_list[0],), 132 | is_ray_backend=True, 133 | ) 134 | 135 | else: 136 | engine = create_engine() 137 | server_config = ServerConfig( 138 | interleaved_slices=(f"tpu={len(devices)}",), 139 | interleaved_engine_create_fns=(lambda a: engine,), 140 | ) 141 | 142 | print(f"server_config: {server_config}") 143 | 144 | jetstream_server = server_lib.run( 145 | threads=FLAGS.threads, 146 | port=FLAGS.port, 147 | config=server_config, 148 | devices=devices, 149 | jax_padding=False, # Jax_padding must be set as False 150 | ) 151 | print("Started jetstream_server....") 152 | jetstream_server.wait_for_termination() 153 | 154 | 155 | if __name__ == "__main__": 156 | app.run(main) 157 | -------------------------------------------------------------------------------- /scripts/create_empty_sharding_map.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | 17 | from absl import app 18 | from absl import flags 19 | 20 | from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData, QuantizationConfig, process_sharding_name  # QuantizationConfig is assumed to be exported by jetstream_pt.environment 21 | from jetstream_pt.third_party.llama import model_exportable, model_args 22 | from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model 23 | 24 | FLAGS = flags.FLAGS 25 | 26 | _MODEL_NAME = flags.DEFINE_string( 27 | "model_name", None, "model type", required=False 28 | ) 29 | 30 | _SIZE = flags.DEFINE_string("size", "tiny", "size of model") 31 | 32 | _COLLAPSE_SAME_LAYERS = flags.DEFINE_bool("collapse_same_layers", True, "") 33 | 34 | 35 | def create_model(): 36 | batch_size = 3 37 | quant_config = QuantizationConfig( 38 | enable_weight_quantization=True, enable_kv_quantization=True 39 | ) 40 | env_data = JetEngineEnvironmentData( 41 | batch_size=batch_size, 42 | max_decode_length=1024, 43 | max_input_sequence_length=1024, 44 | quant_config=quant_config, 45 | cache_sequence_length=1024, 46 | bf16_enable=True, 47 | ) 48 | model_name = _MODEL_NAME.value 49 | param_size = _SIZE.value 50 | if model_name.startswith("llama"): 51 | 52 | args = model_args.get_model_args( 53 | param_size, 54 | 1024, 55 | batch_size, 56 | vocab_size=32000, 57 | bf16_enable=True, 58 | ) 59 | args.device = "meta" 60 | env = JetEngineEnvironment(env_data) 61 | return model_exportable.Transformer(args, env) 62 | elif model_name == "gemma": 63 | args = gemma_config.get_model_config(param_size) 64 | args.device = "meta" 65 | env_data.model_type = "gemma-" + param_size 66 | env_data.num_layers = args.num_hidden_layers 67 | env = JetEngineEnvironment(env_data) 68 | pt_model = gemma_model.GemmaModel(args, env) 69 | return pt_model 70 | 71 | 72 | # pylint: disable-next=all 73 | def main(argv): 74 | model = create_model() 75 | res = {} 76 | for k, v in model.state_dict().items(): 77 | res[process_sharding_name(k)] = v 78 | 79 | print( 80 | f""" 81 | # Sharding config for {_MODEL_NAME.value} 82 | # Sharding should either be an int between 0 and rank - 1 83 | # signifying the axis to shard or -1 / null signifying replicated 84 | 85 | """ 86 | ) 87 | 88 | for k, v in res.items(): 89 | print(k, ":", -1, "# ", str(v.dtype), tuple(v.shape)) 90 | 91 | 92 | if __name__ == "__main__": 93 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" 94 | app.run(main) 95 | -------------------------------------------------------------------------------- /scripts/validate_hf_ckpt_conversion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from safetensors import safe_open 3 | 4 | """ 5 | Script to compare converted checkpoints, for debugging purposes.
6 | """ 7 | 8 | converted_from_orig = ( 9 | "/mnt/disks/lsiyuan/llama_weight/7B-FT-chat-converted/model.safetensors" 10 | ) 11 | 12 | converted_from_hf = "/mnt/disks/lsiyuan/llama_weight/hf_llama_2_7b_converted_bf16/model.safetensors" 13 | 14 | orig_state_dict = {} 15 | with safe_open(converted_from_orig, framework="pt", device="cpu") as f: 16 | for key in f.keys(): 17 | orig_state_dict[key] = f.get_tensor(key) 18 | 19 | hf_state_dict = {} 20 | with safe_open(converted_from_hf, framework="pt", device="cpu") as f: 21 | for key in f.keys(): 22 | hf_state_dict[key] = f.get_tensor(key) 23 | 24 | for key in orig_state_dict.keys(): 25 | if key != "rope.freqs": 26 | assert key in hf_state_dict, f"{key} in orig but not in hf" 27 | else: 28 | print("rope.freqs skipped.") 29 | 30 | for key in hf_state_dict.keys(): 31 | assert key in orig_state_dict, f"{key} in hf but not in orig" 32 | 33 | 34 | def _calc_cosine_dist(x, y): 35 | x = x.flatten().to(torch.float32) 36 | y = y.flatten().to(torch.float32) 37 | return (torch.dot(x, y) / (x.norm() * y.norm())).item() 38 | 39 | 40 | for key in hf_state_dict.keys(): 41 | orig_w = orig_state_dict[key] 42 | hf_w = hf_state_dict[key] 43 | print(f"weight diff {key} : {_calc_cosine_dist(orig_w, hf_w)}") 44 | -------------------------------------------------------------------------------- /tests/.pylintrc: -------------------------------------------------------------------------------- 1 | [MESSAGES CONTROL] 2 | disable=C0114,W0212 -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /tests/helpers.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import torch 3 | import torch_xla2 4 | from jetstream_pt.third_party.llama import model_args 5 | from jetstream_pt.third_party.mixtral import config as mixtral_config 6 | from jetstream_pt import environment 7 | 8 | 9 | # pylint: disable-next=all 10 | def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None): 11 | torch_dtype = torch.bfloat16 if bf16_enable else torch.float32 12 | torch.set_default_dtype(torch_dtype) 13 | jax.config.update("jax_dynamic_shapes", False) 14 | jax.config.update("jax_traceback_filtering", "off") 15 | config = model_args.get_model_args("llama-2-tiny", 128, 1, True) 16 | environment_data = environment.JetEngineEnvironmentData() 17 | environment_data.max_input_sequence_length = 128 18 | environment_data.max_input_sequence_length = 128 19 | environment_data.cache_sequence_length = 128 20 | environment_data.bf16_enable = bf16_enable 21 | environment_data.model_type = "llama-2-tiny" 22 | environment_data.batch_size = 1 23 | environment_data.num_layers = config.n_layers 24 | environment_data.cache_shape = ( 25 | 1, 26 | config.n_kv_heads, 27 | environment_data.cache_sequence_length, 28 | config.dim // config.n_heads, 29 | ) 30 | environment_data.n_reps = config.n_heads // config.n_kv_heads 31 | environment_data.testing = True 32 | env_data_update_fn(environment_data) 33 | env = environment.JetEngineEnvironment(environment_data) 34 | env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu 35 | return env, config 36 | 37 | 38 | # pylint: disable-next=all 39 | def make_mixtral_env(bf16_enable=True): 40 | torch_dtype = torch.bfloat16 if bf16_enable else torch.float32 41 | torch.set_default_dtype(torch_dtype) 42 | jax.config.update("jax_dynamic_shapes", False) 43 | jax.config.update("jax_traceback_filtering", "off") 44 | config = mixtral_config.ModelArgs.from_name("Mixtral-tiny") 45 | environment_data = environment.JetEngineEnvironmentData() 46 | environment_data.max_input_sequence_length = 128 47 | environment_data.cache_sequence_length = 128 48 | environment_data.bf16_enable = bf16_enable 49 | environment_data.model_type = "mixtral" 50 | environment_data.batch_size = 1 51 | environment_data.num_layers = config.n_layer 52 | environment_data.cache_shape = ( 53 | 1, 54 | config.n_local_heads, 55 | environment_data.cache_sequence_length, 56 | config.dim // config.n_head, 57 | ) 58 | environment_data.n_reps = config.n_head // config.n_local_heads 59 | env = environment.JetEngineEnvironment(environment_data) 60 | env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu 61 | return env, config 62 | 63 | 64 | # pylint: disable-next=all 65 | def to_xla_tensor(tree): 66 | return torch_xla2.default_env().to_xla(tree) 67 | 68 | 69 | # pylint: disable-next=all 70 | def call_xla_model(model, weights, args): 71 | with jax.default_device(jax.devices("cpu")[0]): 72 | xla_weights, xla_inputs = to_xla_tensor((weights, args)) 73 | with torch_xla2.default_env(): 74 | result = torch.func.functional_call(model, xla_weights, xla_inputs) 75 | result_torch = torch_xla2.tensor.j2t(result.jax()) 76 | return result_torch 77 | 78 | 79 | # pylint: disable-next=all 80 | def make_page_attention_env_tiny( 81 | bf16_enable=True, env_data_update_fn=lambda _: None 82 | ): 83 | torch_dtype = torch.bfloat16 if bf16_enable else torch.float32 84 | 
torch.set_default_dtype(torch_dtype) 85 | jax.config.update("jax_dynamic_shapes", False) 86 | jax.config.update("jax_traceback_filtering", "off") 87 | config = model_args.get_model_args("llama-2-tiny", 128, 1, True) 88 | environment_data = environment.JetEngineEnvironmentData() 89 | environment_data.paged_attention_page_size = 32 90 | environment_data.paged_attention_total_num_pages = 16 91 | environment_data.block_size = 64 92 | environment_data.max_input_sequence_length = 128 93 | environment_data.max_input_sequence_length = 128 94 | environment_data.cache_sequence_length = 128 95 | environment_data.bf16_enable = bf16_enable 96 | environment_data.model_type = "llama-2-tiny" 97 | environment_data.batch_size = 1 98 | environment_data.num_layers = config.n_layers 99 | environment_data.cache_shape = ( 100 | config.n_kv_heads, 101 | environment_data.paged_attention_total_num_pages, 102 | environment_data.paged_attention_page_size, 103 | config.dim // config.n_heads, 104 | ) 105 | environment_data.testing = True 106 | env_data_update_fn(environment_data) 107 | env = environment.JetEngineEnvironment(environment_data) 108 | env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu 109 | return env, config 110 | -------------------------------------------------------------------------------- /tests/test_hf_names.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from jetstream_pt.model_base import ModuleBase 4 | 5 | 6 | class TestModuleBase(unittest.TestCase): 7 | """Test module base.""" 8 | 9 | def test_get_hf_names_to_real_name(self): 10 | """Test get hugginface names to real name.""" 11 | 12 | class MyModule(ModuleBase): 13 | """My module.""" 14 | 15 | def __init__(self): 16 | super().__init__() 17 | self.linear1 = torch.nn.Linear(10, 20) 18 | self.linear2 = torch.nn.Linear(20, 30) 19 | self.hf_name("linear1", "model.my_linear1") 20 | self.hf_name("linear2", "model.my_linear2") 21 | self.param = torch.nn.Parameter(torch.randn(10)) 22 | self.hf_name("param", "model.param") 23 | 24 | def forward(self): 25 | """Forward function.""" 26 | 27 | module = MyModule() 28 | expected_mapping = { 29 | "model.my_linear1.weight": "linear1.weight", 30 | "model.my_linear1.bias": "linear1.bias", 31 | "model.my_linear2.weight": "linear2.weight", 32 | "model.my_linear2.bias": "linear2.bias", 33 | "model.param": "param", 34 | } 35 | 36 | self.assertEqual(module.get_hf_names_to_real_name(), expected_mapping) 37 | 38 | def test_get_sharding_annotations(self): 39 | """Test get sharding annotations.""" 40 | 41 | class MyModule(ModuleBase): 42 | """MyModule.""" 43 | 44 | def __init__(self): 45 | super().__init__() 46 | self.linear = torch.nn.Linear(10, 20) 47 | self.embedding = torch.nn.Embedding(100, 50) 48 | self.inner = InnerModule() 49 | 50 | def forward(self): 51 | """Forward function.""" 52 | 53 | class InnerModule(ModuleBase): 54 | """Inner modeule.""" 55 | 56 | def __init__(self): 57 | super().__init__() 58 | self.fc = torch.nn.Linear(50, 100) 59 | 60 | def forward(self): 61 | """Forward function.""" 62 | 63 | module = MyModule() 64 | module.annotate_sharding("linear.weight", 0) 65 | module.annotate_sharding("embedding.weight", 1) 66 | module.inner.annotate_sharding("fc.weight", 2) 67 | 68 | expected_mapping = { 69 | "linear.weight": 0, 70 | "embedding.weight": 1, 71 | "inner.fc.weight": 2, 72 | } 73 | self.assertEqual(module.get_sharding_annotations(), expected_mapping) 74 | 75 | 76 | if __name__ == "__main__": 77 | 
unittest.main() 78 | -------------------------------------------------------------------------------- /tests/test_jax_torch.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torch_xla2 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | 8 | class JaxTorchTest(unittest.TestCase): 9 | """Unit tests comparing the numerical gap between JAX and Torch matmul at different float precisions.""" 10 | 11 | def test_matmul_bfloat16_xla2(self): 12 | """test jax vs torch matmul diff with bfloat16 on cpu""" 13 | jax.config.update("jax_platform_name", "cpu") 14 | torch.set_default_dtype(torch.bfloat16) 15 | r = c = 1000 16 | q = torch.randn((r, c)) 17 | k = torch.randn((r, c)) 18 | print(f"torch matmul: {q.shape} * {k.shape}") 19 | result = torch.matmul(q, k) 20 | 21 | jax_q = torch_xla2.tensor.t2j(q) 22 | jax_k = torch_xla2.tensor.t2j(k) 23 | print(f"jax matmul: {jax_q.shape} * {jax_k.shape}") 24 | jax_result = jnp.matmul(jax_q, jax_k) 25 | target_result = torch_xla2.tensor.j2t(jax_result) 26 | print( 27 | f"----------------------- matmul: Diff norm {(target_result - result).norm()}" 28 | ) 29 | self.assertTrue(torch.allclose(target_result, result, atol=1)) 30 | 31 | def test_matmul_float32(self): 32 | """test jax vs torch matmul diff with float32 on cpu""" 33 | jax.config.update("jax_platform_name", "cpu") 34 | torch.set_default_dtype(torch.float32) 35 | r = c = 1000 36 | q = torch.randn((r, c)) 37 | k = torch.randn((r, c)) 38 | print(f"torch matmul: {q.shape} * {k.shape}") 39 | result = torch.matmul(q, k) 40 | 41 | jax_q = torch_xla2.tensor.t2j(q) 42 | jax_k = torch_xla2.tensor.t2j(k) 43 | print(f"jax matmul: {jax_q.shape} * {jax_k.shape}") 44 | jax_result = jnp.matmul(jax_q, jax_k) 45 | target_result = torch_xla2.tensor.j2t(jax_result) 46 | print( 47 | f"----------------------- matmul: Diff norm {(target_result - result).norm()}" 48 | ) 49 | self.assertTrue(torch.allclose(target_result, result, atol=1e-4)) 50 | 51 | 52 | if __name__ == "__main__": 53 | unittest.main() 54 | -------------------------------------------------------------------------------- /tests/test_kv_cache_manager.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax 4 | import numpy as np 5 | import jax.numpy as jnp 6 | import torch 7 | 8 | from jetstream_pt.third_party.llama import model_args 9 | from jetstream_pt import environment 10 | from jetstream_pt.page_attention_manager import PageAttentionManager 11 | from jetstream_pt.cache_manager import PageKVCacheGenerate, KVCachePrefill 12 | from jetstream_pt import torchjax 13 | from absl.testing import parameterized 14 | 15 | P = jax.sharding.PartitionSpec 16 | 17 | 18 | class PageAttentionTest(parameterized.TestCase): 19 | 20 | def _make_env(self, bf16_enable=True): 21 | torch_dtype = torch.bfloat16 if bf16_enable else torch.float32 22 | torch.set_default_dtype(torch_dtype) 23 | jax.config.update("jax_dynamic_shapes", False) 24 | jax.config.update("jax_traceback_filtering", "off") 25 | jax.config.update("jax_platform_name", "cpu") 26 | config = model_args.get_model_args("tiny", 128, 1, True) 27 | environment_data = environment.JetEngineEnvironmentData() 28 | environment_data.max_input_sequence_length = 128 29 | environment_data.max_input_sequence_length = 128 30 | environment_data.cache_sequence_length = 128 31 | environment_data.bf16_enable = bf16_enable 32 | environment_data.model_type = "llama-2-tiny" 33 | environment_data.batch_size = 3 34 |
environment_data.num_layers = config.n_layers 35 | environment_data.cache_shape = ( 36 | 1, 37 | config.n_kv_heads, 38 | environment_data.cache_sequence_length, 39 | config.dim // config.n_heads, 40 | ) 41 | env = environment.JetEngineEnvironment(environment_data) 42 | env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu 43 | mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=("x",)) 44 | replicated = jax.sharding.NamedSharding(mesh, P()) 45 | env.sharding = replicated 46 | return env, config 47 | 48 | def test_page_attention_update(self): 49 | jax.config.update("jax_platform_name", "cpu") 50 | print(f"---------> {jax.devices()}") 51 | 52 | env, _ = self._make_env() 53 | 54 | pam = PageAttentionManager( 55 | batch_size=5, 56 | paged_attention_total_num_pages=20, 57 | paged_attention_page_size=4, 58 | max_pages_per_sequence=4, 59 | ) 60 | shape = (1, 20, 4, 2) 61 | decode_caches = [] 62 | decode_caches.append( 63 | PageKVCacheGenerate.empty(shape=shape, device=None, env=env) 64 | ) 65 | decode_caches = [c.state() for c in decode_caches] 66 | 67 | self.cache_sharding = env.cache_sharding 68 | 69 | def _insert_prefill(seq_len, dim, slot): 70 | prefill_chache = KVCachePrefill() 71 | k, v = jnp.arange(seq_len * dim), jnp.arange(seq_len * dim) 72 | k, v = jnp.reshape(k, (1, 1, seq_len, dim)), jnp.reshape( 73 | k, (1, 1, seq_len, dim) 74 | ) 75 | prefill_chache.update(k, v, 0) 76 | prefill_caches = [prefill_chache] 77 | prefill_caches = [c.state() for c in prefill_caches] 78 | num_pages, update_indexes = pam.reserve_pages_insert(slot, seq_len) 79 | _, kv_heads, _, dim = prefill_caches[0][0].shape 80 | tep_kv = jnp.zeros((kv_heads, num_pages * 4, dim), dtype=jnp.bfloat16) 81 | 82 | caches = pam.insert_prefill_cache( 83 | prefill_caches=prefill_caches, 84 | decode_caches=decode_caches, 85 | update_indexes=update_indexes, 86 | tep_kv=tep_kv, 87 | sharding=env.sharding, 88 | ) 89 | 90 | return caches 91 | 92 | decode_caches = _insert_prefill(3, 2, 0) 93 | decode_caches = _insert_prefill(8, 2, 1) 94 | decode_caches = _insert_prefill(13, 2, 3) 95 | 96 | lens = np.asarray([3, 8, 0, 13, 0]) 97 | pam.fill_new_pages(lens) 98 | np_page_token_indices = pam.get_page_token_indices(lens) 99 | page_token_indices = jnp.asarray(np_page_token_indices) 100 | page_token_indices = torchjax.to_torch(page_token_indices) 101 | 102 | caches_obj = [ 103 | PageKVCacheGenerate( 104 | k, v, pam, page_token_indices, self.cache_sharding, env=env 105 | ) 106 | for k, v in torchjax.to_torch(decode_caches) 107 | ] 108 | xk, xv = jnp.arange(-1, -11, -1).reshape(5, 1, 1, 2), jnp.arange( 109 | -1, -11, -1 110 | ).reshape(5, 1, 1, 2) 111 | xk = torchjax.to_torch(xk) 112 | xv = torchjax.to_torch(xv) 113 | decode_caches = caches_obj[0].update(xk, xv) 114 | expected = jnp.asarray([[0, 1], [2, 3], [4, 5], [-1, -2]]) 115 | self.assertTrue(jnp.array_equal(decode_caches[0][0][0], expected)) 116 | expected = jnp.asarray([[-3, -4], [0, 0], [0, 0], [0, 0]]) 117 | self.assertTrue(jnp.array_equal(decode_caches[0][0][7], expected)) 118 | expected = jnp.asarray([[24, 25], [-7, -8], [0, 0], [0, 0]]) 119 | self.assertTrue(jnp.array_equal(decode_caches[0][0][6], expected)) 120 | 121 | 122 | if __name__ == "__main__": 123 | unittest.main() 124 | -------------------------------------------------------------------------------- /tests/test_page_attention.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax 4 | import numpy as np 5 | import jax.numpy as jnp 6 
| import torch 7 | 8 | from jetstream_pt.third_party.llama import model_args 9 | from jetstream_pt import environment 10 | from jetstream_pt.page_attention_manager import PageAttentionManager 11 | from jetstream_pt.cache_manager import PageKVCacheGenerate, KVCachePrefill 12 | from absl.testing import parameterized 13 | 14 | P = jax.sharding.PartitionSpec 15 | 16 | 17 | class PageAttentionTest(parameterized.TestCase): 18 | 19 | def _make_env(self, bf16_enable=True): 20 | torch_dtype = torch.bfloat16 if bf16_enable else torch.float32 21 | torch.set_default_dtype(torch_dtype) 22 | jax.config.update("jax_dynamic_shapes", False) 23 | jax.config.update("jax_traceback_filtering", "off") 24 | jax.config.update("jax_platform_name", "cpu") 25 | jax.config.update("jax_enable_x64", False) 26 | mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=("x",)) 27 | replicated = jax.sharding.NamedSharding(mesh, P()) 28 | config = model_args.get_model_args("tiny", 128, 1, True) 29 | environment_data = environment.JetEngineEnvironmentData() 30 | environment_data.max_input_sequence_length = 128 31 | environment_data.max_input_sequence_length = 128 32 | environment_data.cache_sequence_length = 128 33 | environment_data.bf16_enable = bf16_enable 34 | environment_data.model_type = "llama-2-tiny" 35 | environment_data.batch_size = 3 36 | environment_data.num_layers = config.n_layers 37 | environment_data.cache_shape = ( 38 | 1, 39 | config.n_kv_heads, 40 | environment_data.cache_sequence_length, 41 | config.dim // config.n_heads, 42 | ) 43 | env = environment.JetEngineEnvironment(environment_data) 44 | env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu 45 | env.sharding = replicated 46 | return env, config 47 | 48 | def test_prefill_insert(self): 49 | 50 | env, _ = self._make_env() 51 | 52 | pam = PageAttentionManager( 53 | batch_size=3, 54 | paged_attention_total_num_pages=20, 55 | paged_attention_page_size=4, 56 | max_pages_per_sequence=4, 57 | ) 58 | shape = (1, 6, 4, 2) 59 | decode_caches = [] 60 | decode_caches.append( 61 | PageKVCacheGenerate.empty(shape=shape, device=None, env=env) 62 | ) 63 | decode_caches = [c.state() for c in decode_caches] 64 | 65 | prefill_chache = KVCachePrefill() 66 | k, v = jnp.arange(6), jnp.arange(6) 67 | k, v = jnp.reshape(k, (1, 1, 3, 2)), jnp.reshape(k, (1, 1, 3, 2)) 68 | prefill_chache.update(k, v, 0) 69 | prefill_caches = [prefill_chache] 70 | prefill_caches = [c.state() for c in prefill_caches] 71 | 72 | num_pages, update_indexes = pam.reserve_pages_insert(0, 3) 73 | _, kv_heads, _, dim = prefill_caches[0][0].shape 74 | tep_kv = jnp.zeros((kv_heads, num_pages * 4, dim), dtype=jnp.bfloat16) 75 | 76 | caches = pam.insert_prefill_cache( 77 | prefill_caches=prefill_caches, 78 | decode_caches=decode_caches, 79 | update_indexes=update_indexes, 80 | tep_kv=tep_kv, 81 | sharding=env.sharding, 82 | ) 83 | expected_kv = jnp.arange(6).reshape(3, 2) 84 | padding = jnp.asarray([[0, 0]]) 85 | expected_kv = jnp.concatenate((expected_kv, padding)) 86 | 87 | self.assertTrue( 88 | jnp.array_equal( 89 | caches[0][0][0, 0, 0:4, 0:2], expected_kv.astype(jnp.bfloat16) 90 | ) 91 | ) 92 | self.assertTrue( 93 | jnp.array_equal( 94 | caches[0][1][0, 0, 0:4, 0:2], expected_kv.astype(jnp.bfloat16) 95 | ) 96 | ) 97 | 98 | def test_prefill_insert_multiple_pages(self): 99 | 100 | jax.config.update("jax_platform_name", "cpu") 101 | print(f"---------> {jax.devices()}") 102 | 103 | env, _ = self._make_env() 104 | 105 | pam = PageAttentionManager( 106 | batch_size=3, 107 | 
paged_attention_total_num_pages=20, 108 | paged_attention_page_size=4, 109 | max_pages_per_sequence=4, 110 | ) 111 | shape = (1, 20, 4, 2) 112 | decode_caches = [] 113 | decode_caches.append( 114 | PageKVCacheGenerate.empty(shape=shape, device=None, env=env) 115 | ) 116 | decode_caches = [c.state() for c in decode_caches] 117 | 118 | self.cache_sharding = env.cache_sharding 119 | 120 | prefill_chache = KVCachePrefill() 121 | k, v = jnp.arange(12), jnp.arange(12) 122 | k, v = jnp.reshape(k, (1, 1, 6, 2)), jnp.reshape(k, (1, 1, 6, 2)) 123 | prefill_chache.update(k, v, 0) 124 | prefill_caches = [prefill_chache] 125 | prefill_caches = [c.state() for c in prefill_caches] 126 | 127 | num_pages, update_indexes = pam.reserve_pages_insert(0, 6) 128 | _, kv_heads, _, dim = prefill_caches[0][0].shape 129 | tep_kv = jnp.zeros((kv_heads, num_pages * 4, dim), dtype=jnp.bfloat16) 130 | 131 | decode_caches = pam.insert_prefill_cache( 132 | prefill_caches=prefill_caches, 133 | decode_caches=decode_caches, 134 | update_indexes=update_indexes, 135 | tep_kv=tep_kv, 136 | sharding=env.sharding, 137 | ) 138 | 139 | self.assertEqual(len(decode_caches), 1) 140 | expected = jnp.arange(16).at[12:16].set([0, 0, 0, 0]).reshape(1, 2, 4, 2) 141 | 142 | updated_k = jax.lax.slice_in_dim(decode_caches[0][0], 0, 2, axis=1) 143 | self.assertTrue(jnp.array_equal(updated_k, expected)) 144 | noupdated_k = jax.lax.slice_in_dim(decode_caches[0][0], 2, 20, axis=1) 145 | self.assertTrue(jnp.array_equal(noupdated_k, jnp.zeros_like(noupdated_k))) 146 | 147 | def test_reserve_pages_decode(self): 148 | 149 | env, _ = self._make_env() 150 | 151 | pam = PageAttentionManager( 152 | batch_size=3, 153 | paged_attention_total_num_pages=20, 154 | paged_attention_page_size=4, 155 | max_pages_per_sequence=4, 156 | ) 157 | slot = 1 158 | seq_len = 8 159 | pam.reserve_pages_insert(slot, seq_len) 160 | expected_slot_page_indices = np.asarray([0, 1]) 161 | slot_page_indices = pam.page_indices[slot][0:2] 162 | self.assertTrue( 163 | np.array_equal(slot_page_indices, expected_slot_page_indices) 164 | ) 165 | 166 | lens = np.asarray([0, seq_len, 0]) 167 | pam.fill_new_pages(lens) 168 | expected_slot_page_indices = np.asarray([0, 1, 2, 19]) 169 | slot_page_indices = pam.page_indices[slot] 170 | self.assertTrue( 171 | np.array_equal(slot_page_indices, expected_slot_page_indices) 172 | ) 173 | 174 | expected_0_page_indices = np.asarray([19, 19, 19, 19]) 175 | zer0_page_indices = pam.page_indices[0][0:4] 176 | self.assertTrue(np.array_equal(zer0_page_indices, expected_0_page_indices)) 177 | 178 | def test_get_page_token_indices(self): 179 | env, _ = self._make_env() 180 | 181 | pam = PageAttentionManager( 182 | batch_size=5, 183 | paged_attention_total_num_pages=20, 184 | paged_attention_page_size=4, 185 | max_pages_per_sequence=4, 186 | ) 187 | pam.reserve_pages_insert(1, 8) 188 | pam.reserve_pages_insert(3, 13) 189 | pam.reserve_pages_insert(0, 3) 190 | 191 | lens = np.asarray([3, 8, 0, 13, 0]) 192 | pam.fill_new_pages(lens) 193 | 194 | page_token_indices = pam.get_page_token_indices(lens) 195 | 196 | expected_page_indices = np.asarray([6, 7, 5]) 197 | expected_token_indices = np.asarray([3, 4, 9]) 198 | self.assertTrue( 199 | np.array_equal(page_token_indices[0], expected_page_indices) 200 | ) 201 | self.assertTrue( 202 | np.array_equal(page_token_indices[1], expected_token_indices) 203 | ) 204 | 205 | 206 | if __name__ == "__main__": 207 | unittest.main() 208 | --------------------------------------------------------------------------------