├── CONTRIBUTING ├── LICENSE ├── README.md ├── inference ├── trillium │ ├── JetStream-Maxtext │ │ ├── DeepSeek-R1-671B │ │ │ ├── README.md │ │ │ ├── docker │ │ │ │ ├── Dockerfile │ │ │ │ └── cloudbuild.yml │ │ │ ├── pathways.png │ │ │ ├── prepare-model │ │ │ │ └── batch_job.yaml │ │ │ ├── serve-model │ │ │ │ ├── Chart.yaml │ │ │ │ └── templates │ │ │ │ │ ├── model-serve-configmap.yaml │ │ │ │ │ ├── model-serve-launcher.yaml │ │ │ │ │ └── model-serve-svc.yaml │ │ │ └── values.yaml │ │ ├── DeepSeek-R1-Distill-Llama-70B │ │ │ ├── gke │ │ │ │ ├── README.md │ │ │ │ ├── docker │ │ │ │ │ ├── Dockerfile │ │ │ │ │ └── cloudbuild.yml │ │ │ │ ├── serve-model │ │ │ │ │ ├── Chart.yaml │ │ │ │ │ └── templates │ │ │ │ │ │ ├── model-serve-configmap.yaml │ │ │ │ │ │ ├── model-serve-launcher.yaml │ │ │ │ │ │ └── model-serve-svc.yaml │ │ │ │ └── values.yaml │ │ │ └── tpu-vm │ │ │ │ └── README.md │ │ ├── Llama-4-Maverick-17B-128E │ │ │ ├── README.md │ │ │ ├── docker │ │ │ │ ├── Dockerfile │ │ │ │ └── cloudbuild.yml │ │ │ ├── prepare-model │ │ │ │ ├── Chart.yaml │ │ │ │ └── templates │ │ │ │ │ └── model-serve-downloader.yaml │ │ │ ├── serve-model │ │ │ │ ├── Chart.yaml │ │ │ │ └── templates │ │ │ │ │ ├── model-serve-configmap.yaml │ │ │ │ │ └── pathways-server.yaml │ │ │ └── values.yaml │ │ ├── Llama-4-Scout-17B-16E │ │ │ ├── README.md │ │ │ ├── docker │ │ │ │ ├── Dockerfile │ │ │ │ └── cloudbuild.yml │ │ │ ├── prepare-model │ │ │ │ ├── Chart.yaml │ │ │ │ └── templates │ │ │ │ │ └── model-serve-downloader.yaml │ │ │ ├── serve-model │ │ │ │ ├── Chart.yaml │ │ │ │ └── templates │ │ │ │ │ ├── model-serve-configmap.yaml │ │ │ │ │ ├── model-serve-launcher.yaml │ │ │ │ │ └── model-serve-svc.yaml │ │ │ └── values.yaml │ │ └── Llama2-7B │ │ │ └── README.md │ ├── JetStream-Pytorch │ │ └── Llama2-7B │ │ │ └── README.md │ ├── MaxDiffusion │ │ └── SDXL │ │ │ └── README.md │ └── vLLM │ │ ├── Llama3-8b │ │ └── README.md │ │ ├── Llama3.3-70b │ │ └── README.md │ │ ├── Qwen2.5-32B │ │ └── README.md │ │ └── README.md └── v5e │ ├── JetStream-Maxtext │ └── Llama2-7B │ │ └── README.md │ ├── JetStream-Pytorch │ └── Llama2-7B │ │ └── README.md │ └── MaxDiffusion │ └── SDXL │ └── README.md ├── microbenchmarks ├── README.md ├── benchmark_hbm.py ├── benchmark_matmul.py ├── benchmark_utils.py ├── requirements.txt └── trillium │ └── collectives │ ├── README.md │ ├── collectives-1xv6e-256.sh │ ├── collectives-2xv6e-256.sh │ └── collectives-4xv6e-256.sh ├── training ├── trillium │ ├── Diffusion-2-PyTorch │ │ ├── README.md │ │ ├── benchmark.sh │ │ ├── env.sh │ │ ├── host.sh │ │ └── train.sh │ ├── GPT3-175B-MaxText │ │ ├── bf16 │ │ │ ├── README.md │ │ │ └── gpt3-175b-v6e-256.sh │ │ └── fp8 │ │ │ ├── README.md │ │ │ └── gpt3-175b-v6e-256.sh │ ├── Llama2-70B-MaxText │ │ ├── README.md │ │ └── llama2-70b-v6e-256.sh │ ├── Llama3-8B-MaxText │ │ ├── v6e-256 │ │ │ ├── README.md │ │ │ └── llama3-8B-1xv6e-256.sh │ │ └── v6e-8 │ │ │ ├── README.md │ │ │ └── llama3-8B-1xv6e-8.sh │ ├── Llama3.0-70B-PyTorch │ │ ├── GCE │ │ │ ├── README.md │ │ │ ├── benchmark.sh │ │ │ ├── config.json │ │ │ ├── env.sh │ │ │ ├── fsdp_config.json │ │ │ ├── host.sh │ │ │ ├── tpu.Dockerfile │ │ │ └── train.sh │ │ └── XPK │ │ │ ├── README.md │ │ │ ├── benchmark.sh │ │ │ ├── config_70b.json │ │ │ ├── env.sh │ │ │ ├── fsdp_config.json │ │ │ └── train.sh │ ├── Llama3.0-8B-PyTorch │ │ ├── GCE │ │ │ ├── README.md │ │ │ ├── benchmark.sh │ │ │ ├── config.json │ │ │ ├── env.sh │ │ │ ├── fsdp_config.json │ │ │ ├── host.sh │ │ │ └── train.sh │ │ └── XPK │ │ │ ├── README.md │ │ │ ├── 
benchmark.sh │ │ │ ├── config_8b.json │ │ │ ├── env.sh │ │ │ ├── fsdp_config.json │ │ │ └── train.sh │ ├── Llama3.1-405B-MaxText │ │ ├── README.md │ │ └── llama3-1-405b-2xv6e-256.sh │ ├── Llama3.1-405B-PyTorch │ │ ├── GCE │ │ │ ├── README.md │ │ │ ├── benchmark.sh │ │ │ ├── config.json │ │ │ ├── env.sh │ │ │ ├── host.sh │ │ │ └── train.sh │ │ └── XPK │ │ │ ├── README.md │ │ │ ├── benchmark.sh │ │ │ ├── config_405b.json │ │ │ ├── env.sh │ │ │ └── train.sh │ ├── Llama3.1-70B-MaxText │ │ ├── README.md │ │ └── llama3-1-70B-1xv6e-256.sh │ ├── MAXTEXT_README.md │ ├── Mistral-7B-MaxText │ │ ├── README.md │ │ └── mistral-7B-1xv6e-8.sh │ ├── Mixtral-8x22B-MaxText │ │ ├── README.md │ │ ├── mixtral-8x22b-10xv6e-256.sh │ │ ├── mixtral-8x22b-1xv6e-256.sh │ │ ├── mixtral-8x22b-20xv6e-256.sh │ │ ├── mixtral-8x22b-30xv6e-256.sh │ │ └── mixtral-8x22b-40xv6e-256.sh │ ├── Mixtral-8x7B-MaxText │ │ ├── README.md │ │ ├── mixtral-8x7b-1xv6e-256.sh │ │ ├── mixtral-8x7b-2xv6e-256.sh │ │ └── mixtral-8x7b-4xv6e-256.sh │ ├── Mixtral-8x7B-Pytorch │ │ ├── GCE │ │ │ ├── README.md │ │ │ ├── benchmark.sh │ │ │ ├── config.json │ │ │ ├── env.sh │ │ │ ├── fsdp_config.json │ │ │ ├── host.sh │ │ │ └── train.sh │ │ └── XPK │ │ │ ├── README.md │ │ │ ├── benchmark.sh │ │ │ ├── config.json │ │ │ ├── env.sh │ │ │ ├── fsdp_config.json │ │ │ └── train.sh │ └── XPK_README.md └── v5p │ ├── DLRM-V2-Tensorflow │ └── README.md │ ├── Diffusion-2-MaxDiffusion │ ├── README.md │ ├── docker │ │ └── maxdiffusion.Dockerfile │ └── scripts │ │ └── run_v5p-ddp-pbs-16.sh │ ├── Diffusion-2-PyTorch │ ├── README.md │ ├── benchmark.sh │ ├── env.sh │ ├── host.sh │ └── train.sh │ ├── GPT3-175B-MaxText │ └── README.md │ ├── Llama2-7B-Maxtext │ └── README.md │ ├── Llama2-7B-PyTorch │ ├── README.md │ ├── benchmark.sh │ ├── config.sh │ ├── env.sh │ ├── fsdp_config.sh │ ├── host.sh │ └── train.sh │ ├── Llama4-Maverick-17B-128E-Maxtext │ └── README.md │ ├── Llama4-Scout-17B-16E-Maxtext │ └── README.md │ ├── Mixtral-8X7B-Maxtext │ ├── README.md │ └── scripts │ │ └── run_mixtral-8x7b.sh │ ├── Mixtral-8x7B-PyTorch │ ├── README.md │ ├── benchmark.sh │ ├── config.json │ ├── env.sh │ ├── fsdp_config.json │ ├── host.sh │ └── train.sh │ ├── SDXL-MaxDiffusion │ ├── README.md │ ├── docker │ │ └── maxdiffusion.Dockerfile │ └── scripts │ │ └── run_v5p-ddp-pbs-1.sh │ └── XPK_README.md └── utils ├── profile_convert.py └── xplane_pb2.py /CONTRIBUTING: -------------------------------------------------------------------------------- 1 | # Contributions 2 | 3 | We appreciate your interest in contributing! This project is currently **not** accepting external contributions. We may revisit this policy in the future. 4 | 5 | While we aren't accepting code contributions at this time, you can still get involved by reporting any bugs, feature requests or documentation improvements via GitHub issues. 6 | 7 | Thank you for your understanding! -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cloud TPU performance recipes 2 | 3 | This repository provides the necessary instructions to reproduce a 4 | specific workload on Google Cloud TPUs. The focus is on reliably achieving 5 | a performance metric (e.g. throughput) that demonstrates the combined hardware 6 | and software stack on TPUs. 7 | 8 | ## Organization 9 | 10 | - `./training`: instructions to reproduce the training performance of 11 | popular LLMs, diffusion, and other models with PyTorch and JAX. 
12 | 13 | - `./inference`: instructions to reproduce inference performance. 14 | 15 | - `./microbenchmarks`: instructions for low-level TPU benchmarks such as 16 | matrix multiplication performance and memory bandwidth. 17 | 18 | ## Contributor notes 19 | 20 | Note: This is not an officially supported Google product. This project is not 21 | eligible for the [Google Open Source Software Vulnerability Rewards 22 | Program](https://bughunters.google.com/open-source-security). 23 | -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/DeepSeek-R1-671B/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | FROM ubuntu:22.04 16 | 17 | ENV DEBIAN_FRONTEND=noninteractive 18 | 19 | # Install dependencies 20 | RUN apt -y update && apt install -y --no-install-recommends \ 21 | apt-transport-https ca-certificates gnupg git wget \ 22 | python3.10 python3-pip curl nano vim 23 | 24 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 25 | 26 | # Install google cloud sdk 27 | RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" \ 28 | | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list \ 29 | && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg \ 30 | | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg \ 31 | && apt-get update -y \ 32 | && apt-get install google-cloud-sdk -y 33 | 34 | # Install pip 35 | RUN python3 -m pip install --upgrade pip 36 | 37 | RUN pip install "huggingface_hub[cli]" hf_transfer 38 | 39 | # Set environment variables 40 | ENV JAX_PLATFORMS=proxy 41 | ENV JAX_BACKEND_TARGET=grpc://localhost:38681 42 | ENV XCLOUD_ENVIRONMENT=GCP 43 | 44 | # Install JetStream and MaxText 45 | 46 | RUN git clone https://github.com/AI-Hypercomputer/JetStream.git && \ 47 | git clone https://github.com/AI-Hypercomputer/maxtext.git && \ 48 | git clone https://github.com/google/aqt.git 49 | 50 | RUN cd /maxtext && bash setup.sh && pip install torch --index-url https://download.pytorch.org/whl/cpu 51 | 52 | RUN pip install safetensors setuptools fastapi uvicorn rouge_score scikit-learn 53 | 54 | RUN cd /JetStream && pip install -e . 55 | 56 | RUN apt -y update && apt-get -y install python3-dev && apt-get -y install build-essential 57 | RUN cp -r /aqt/aqt/* /usr/local/lib/python3.10/dist-packages/aqt/ 58 | 59 | ENTRYPOINT [ "/bin/bash" ] -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/DeepSeek-R1-671B/docker/cloudbuild.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | steps: 16 | - name: 'gcr.io/cloud-builders/docker' 17 | args: 18 | - 'build' 19 | - '--tag=${_ARTIFACT_REGISTRY}/${_JETSTREAM_MAXTEXT_IMAGE}:${_JETSTREAM_MAXTEXT_VERSION}' 20 | - '--file=Dockerfile' 21 | - '.' 22 | automapSubstitutions: true 23 | 24 | images: 25 | - ${_ARTIFACT_REGISTRY}/${_JETSTREAM_MAXTEXT_IMAGE}:${_JETSTREAM_MAXTEXT_VERSION} -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/DeepSeek-R1-671B/pathways.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Hypercomputer/tpu-recipes/4204255d2ce7f6be0a6b4cd6fa512dbe94824ef5/inference/trillium/JetStream-Maxtext/DeepSeek-R1-671B/pathways.png -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/DeepSeek-R1-671B/prepare-model/batch_job.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | taskGroups: 16 | - taskSpec: 17 | runnables: 18 | - container: 19 | imageUri: ${ARTIFACT_REGISTRY}/${JETSTREAM_MAXTEXT_IMAGE}:${JETSTREAM_MAXTEXT_VERSION} 20 | entrypoint: "/bin/sh" 21 | commands: 22 | - "-c" 23 | - mkdir -p /mnt/disks/persist/models/ && echo "Downloading model ${HF_MODEL_NAME}" && huggingface-cli download ${HF_MODEL_NAME} --local-dir /mnt/disks/persist/models/fp8 && cd /maxtext && echo "Converting checkpoint from fp8 to bf16" && python3 -m MaxText.deepseek_fp8_to_bf16 --input-fp8-hf-path /mnt/disks/persist/models/fp8 --output-bf16-hf-path /mnt/disks/persist/models/bf16 --cache-file-num 16 && echo "Converting checkpoint from bf16 to maxtext/unscanned format" && JAX_PLATFORMS='' python3 -m MaxText.convert_deepseek_unscanned_ckpt --base_model_path /mnt/disks/persist/models/bf16 --maxtext_model_path ${GCS_CKPT_PATH_UNSCANNED} --model_size $MODEL_NAME --use-zarr3 false --use-ocdbt false && echo "Completed checkpoint conversion. 
Unscanned checkpoint saved at ${GCS_CKPT_PATH_UNSCANNED}" 24 | volumes: 25 | - deviceName: persist 26 | mountPath: /mnt/disks/persist 27 | mountOptions: rw,async 28 | computeResource: 29 | cpuMilli: 160000 30 | memoryMib: 3936256 31 | # Define the allocation policy for provisioning VMs 32 | allocationPolicy: 33 | location: 34 | allowedLocations: ["regions/${CLUSTER_CKPT_NODE_REGION}"] 35 | instances: 36 | - policy: 37 | machineType: ${CLUSTER_CKPT_NODE_MACHINE_TYPE} 38 | bootDisk: 39 | type: pd-ssd 40 | sizeGb: 500 41 | disks: 42 | newDisk: 43 | sizeGb: 3000 44 | type: pd-ssd 45 | deviceName: persist 46 | logsPolicy: 47 | destination: CLOUD_LOGGING -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/DeepSeek-R1-671B/serve-model/Chart.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | apiVersion: v2 16 | name: trillium-pathways-jetstream-maxtext-serve-model 17 | description: trillium-pathways-jetstream-maxtext-serve-model 18 | type: application 19 | version: 0.1.0 20 | appVersion: "1.16.0" -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/DeepSeek-R1-671B/serve-model/templates/model-serve-configmap.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | apiVersion: v1 16 | kind: ConfigMap 17 | metadata: 18 | name: "{{ .Release.Name }}" 19 | data: 20 | maxtext-configuration.yaml: |- 21 | {{- range $key, $value := .Values.maxtext_config }} 22 | {{ $key }}: {{ $value }} 23 | {{- end }} 24 | -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/DeepSeek-R1-671B/serve-model/templates/model-serve-svc.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | apiVersion: v1 16 | kind: Service 17 | metadata: 18 | name: jetstream-svc 19 | spec: 20 | selector: 21 | app: jetstream-pathways 22 | ports: 23 | - protocol: TCP 24 | name: jetstream-http 25 | port: {{ .Values.jetstream.service.ports.http }} 26 | targetPort: {{ .Values.jetstream.service.ports.http }} -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/DeepSeek-R1-671B/values.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | clusterName: 17 | 18 | huggingface: 19 | secretName: hf-secret 20 | secretData: 21 | token: "hf_api_token" 22 | 23 | model: 24 | name: &model-name deepseek3-671b 25 | hf_model_name: &hf-model-name deepseek-ai/DeepSeek-R1 26 | 27 | job: 28 | jax_tpu_image: 29 | repository: 30 | tag: 31 | jetstream_http_image: 32 | repository: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http 33 | tag: v0.2.3 34 | pathways_proxy_image: 35 | repository: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server 36 | tag: latest 37 | pathways_rm_image: 38 | repository: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server 39 | tag: latest 40 | 41 | 42 | volumes: 43 | ssdMountPath: "/ssd" 44 | gcsMounts: 45 | - bucketName: 46 | mountPath: "/gcs" 47 | 48 | jetstream: 49 | service: 50 | ports: 51 | http: 8000 52 | grpc: 9000 53 | 54 | convert_hf_ckpt: true 55 | 56 | maxtext_config: 57 | allow_split_physical_axes: true 58 | tokenizer_type: huggingface 59 | hf_access_token: $HF_TOKEN 60 | tokenizer_path: *hf-model-name 61 | model_name: *model-name 62 | use_chat_template: false 63 | load_parameters_path: 64 | max_prefill_predict_length: 1024 65 | max_target_length: 1536 66 | async_checkpointing: false 67 | steps: 1 68 | ici_fsdp_parallelism: 1 69 | ici_autoregressive_parallelism: 1 70 | ici_expert_parallelism: 1 71 | ici_tensor_parallelism: 64 72 | scan_layers: false 73 | weight_dtype: bfloat16 74 | per_device_batch_size: 1 75 | enable_single_controller: true 76 | megablox: false 77 | sparse_matmul: false 78 | capacity_factor: -1.0 79 | attention: "dot_product" 80 | quantize_kvcache: true 81 | kv_quant_dtype: int8 82 | enable_model_warmup: true -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/DeepSeek-R1-Distill-Llama-70B/gke/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # 
Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | FROM ubuntu:22.04 16 | 17 | ENV DEBIAN_FRONTEND=noninteractive 18 | 19 | RUN apt update && apt install --yes --no-install-recommends \ 20 | ca-certificates \ 21 | curl \ 22 | git \ 23 | gnupg \ 24 | cmake \ 25 | python3.10 \ 26 | python3-pip \ 27 | && echo "deb https://packages.cloud.google.com/apt gcsfuse-buster main" \ 28 | | tee /etc/apt/sources.list.d/gcsfuse.list \ 29 | && echo "deb https://packages.cloud.google.com/apt cloud-sdk main" \ 30 | | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list \ 31 | && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add - \ 32 | && apt-get update \ 33 | && apt-get install --yes gcsfuse \ 34 | && apt-get install --yes google-cloud-cli \ 35 | && apt-get clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \ 36 | && mkdir /gcs 37 | 38 | 39 | RUN update-alternatives --install \ 40 | /usr/bin/python3 python3 /usr/bin/python3.10 1 41 | 42 | RUN git clone https://github.com/AI-Hypercomputer/maxtext.git && \ 43 | git clone https://github.com/AI-Hypercomputer/JetStream.git 44 | 45 | RUN cd /JetStream && \ 46 | pip install -e . 47 | 48 | RUN cd /JetStream/benchmarks && \ 49 | pip install -r requirements.in 50 | 51 | RUN cd maxtext/ && \ 52 | bash setup.sh JAX_VERSION=0.5.0 53 | 54 | # Reset working directory to /workspace 55 | WORKDIR /workspace 56 | 57 | ENTRYPOINT [ "/bin/bash" ] 58 | -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/DeepSeek-R1-Distill-Llama-70B/gke/docker/cloudbuild.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | steps: 16 | - name: 'gcr.io/cloud-builders/docker' 17 | args: 18 | - 'build' 19 | - '--tag=${_ARTIFACT_REGISTRY}/maxtext-jetstream-deepseek:latest' 20 | - '--file=Dockerfile' 21 | - '.' 
22 | automapSubstitutions: true 23 | 24 | images: 25 | - '${_ARTIFACT_REGISTRY}/maxtext-jetstream-deepseek:latest' -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/DeepSeek-R1-Distill-Llama-70B/gke/serve-model/Chart.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | apiVersion: v2 16 | name: trillium-maxtext-jetstream-deepseek-prepare-model 17 | description: trillium-maxtext-jetstream-deepseek-prepare-model 18 | type: application 19 | version: 0.1.0 20 | appVersion: "1.16.0" -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/DeepSeek-R1-Distill-Llama-70B/gke/serve-model/templates/model-serve-configmap.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | apiVersion: v1 16 | kind: ConfigMap 17 | metadata: 18 | name: "{{ .Release.Name }}" 19 | data: 20 | maxtext-configuration.yaml: |- 21 | {{- range $key, $value := .Values.maxtext_config }} 22 | {{ $key }}: {{ $value }} 23 | {{- end }} 24 | 25 | libtpu-init-args: |- 26 | --=false 27 | {{- range $key, $value := .Values.xla_flags }} 28 | --{{ $key }}={{ $value }} 29 | {{- end }} -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/DeepSeek-R1-Distill-Llama-70B/gke/serve-model/templates/model-serve-svc.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | apiVersion: v1 16 | kind: Service 17 | 18 | metadata: 19 | name: {{ .Release.Name }}-svc 20 | 21 | spec: 22 | selector: 23 | app: {{ .Release.Name }}-serving 24 | ports: 25 | - protocol: TCP 26 | name: jetstream-http 27 | port: {{ .Values.jetstream.service.ports.http }} 28 | targetPort: {{ .Values.jetstream.service.ports.http }} 29 | - protocol: TCP 30 | name: jetstream-grpc 31 | port: {{ .Values.jetstream.service.ports.grpc }} 32 | targetPort: {{ .Values.jetstream.service.ports.grpc }} -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/DeepSeek-R1-Distill-Llama-70B/gke/values.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Copyright 2025 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | clusterName: 18 | 19 | huggingface: 20 | secretName: hf-secret 21 | secretData: 22 | token: "hf_api_token" 23 | 24 | model: 25 | name: deepseek-ai/DeepSeek-R1-Distill-Llama-70B 26 | 27 | job: 28 | image: 29 | repository: 30 | tag: 31 | 32 | volumes: 33 | ssdMountPath: "/ssd" 34 | gcsMounts: 35 | - bucketName: 36 | mountPath: "/gcs" 37 | 38 | jetstream: 39 | service: 40 | ports: 41 | http: 8000 42 | grpc: 9000 43 | 44 | convert_hf_ckpt: true 45 | 46 | maxtext_config: 47 | tokenizer_path: $TOKENIZER_DIR 48 | load_parameters_path: $CHECKPOINT_TPU_UNSCANNED 49 | model_name: $MODEL_SIZE 50 | weight_dtype: "bfloat16" 51 | scan_layers: false 52 | ici_fsdp_parallelism: 1 53 | ici_autoregressive_parallelism: 1 54 | ici_tensor_parallelism: -1 55 | per_device_batch_size: 2 56 | max_prefill_predict_length: 1024 57 | max_target_length: 1536 58 | attention: "dot_product" 59 | optimize_mesh_for_tpu_v6e: true 60 | quantize_kvcache: true 61 | 62 | xla_flags: 63 | xla_tpu_enable_windowed_einsum_for_reduce_scatter: false 64 | xla_tpu_enable_windowed_einsum_for_all_gather: false 65 | xla_tpu_prefer_latch_optimized_rhs_layouts: true 66 | xla_tpu_enable_experimental_fusion_cost_model: false 67 | xla_tpu_dot_dot_fusion_duplicated: false 68 | xla_tpu_dot_dot_fusion: true 69 | xla_jf_conv_input_fusion: true 70 | xla_jf_conv_output_fusion: true 71 | xla_tpu_rwb_fusion: false 72 | xla_tpu_copy_fusion_pad_unpad_ratio: 0 73 | xla_tpu_licm_size_inflation_ratio: 1 74 | xla_tpu_copy_elision_analysis_allowance: 150000 75 | xla_tpu_copy_insertion_use_region_analysis_limit: 10000 76 | xla_tpu_order_dot_after_layout: true 77 | xla_jf_rematerialization_percent_shared_memory_limit: 100 78 | xla_tpu_use_repeated_instance_for_preferred_prefetch_time: true 79 | xla_tpu_enforce_prefetch_fifo_order: false 80 | xla_tpu_prefetch_interval_picker_size_override: 6000000 81 | xla_tpu_async_copy_bandwidth_scaling_factor: 1 82 | xla_tpu_nd_short_transfer_max_chunks: -1 83 | xla_tpu_enable_aggressive_broadcast_priority_update: true 84 | xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers: SQRT 85 | xla_tpu_memory_bound_loop_optimizer_options: enabled:true 86 
| xla_tpu_enable_copy_fusion: true 87 | xla_tpu_enable_cross_program_prefetch_freeing: false 88 | xla_tpu_enable_dot_strength_reduction: true 89 | xla_tpu_layout_use_dot_grouping: false 90 | xla_tpu_msa_inefficient_use_to_copy_ratio: 0.5 91 | xla_tpu_reduce_loop_fusion_dup_with_unfusable_user: false 92 | xla_tpu_vector_load_fusion_window: 1024 93 | xla_tpu_vector_store_fusion_window: 256 94 | xla_jf_conv_reshape_fusion: false 95 | xla_tpu_input_conv_multi_users: false 96 | xla_tpu_enable_multi_level_input_dot_dot_fusion: false 97 | xla_tpu_enable_multi_level_output_dot_dot_fusion: false 98 | xla_tpu_dot_dot_fusion_separable_convs_only: false 99 | xla_tpu_enable_multi_level_nested_loop_fusion: true 100 | xla_tpu_nested_dot_fusion: true 101 | xla_tpu_enable_multi_level_nested_dot_fusion: false 102 | xla_jf_enable_multi_output_fusion: true 103 | xla_tpu_use_lp_llo_scheduler_for_dot_dot_fusions: false 104 | xla_tpu_enable_flash_attention: true -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/Llama-4-Maverick-17B-128E/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | FROM ubuntu:22.04 16 | 17 | ENV DEBIAN_FRONTEND=noninteractive 18 | 19 | RUN apt -y update && apt install -y --no-install-recommends apt-transport-https ca-certificates gnupg git python3.10 python3-pip curl nano vim 20 | 21 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 22 | RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && apt-get update -y && apt-get install google-cloud-sdk -y 23 | 24 | RUN python3 -m pip install --upgrade pip 25 | 26 | ENV JAX_PLATFORMS=proxy 27 | ENV JAX_BACKEND_TARGET=grpc://localhost:38681 28 | ENV XCLOUD_ENVIRONMENT=GCP 29 | 30 | ENV MAXTEXT_VERSION=main 31 | ENV JETSTREAM_VERSION=main 32 | 33 | RUN git clone https://github.com/AI-Hypercomputer/JetStream.git && \ 34 | git clone https://github.com/AI-Hypercomputer/maxtext.git 35 | 36 | RUN cd maxtext/ && \ 37 | git checkout ${MAXTEXT_VERSION} && \ 38 | bash setup.sh 39 | 40 | RUN cd /JetStream && \ 41 | git checkout ${JETSTREAM_VERSION} && \ 42 | pip install -e . 
43 | 44 | RUN pip install setuptools fastapi uvicorn 45 | 46 | RUN apt -y update && apt-get -y install python3-dev && apt-get -y install build-essential 47 | 48 | ENTRYPOINT [ "/bin/bash" ] 49 | -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/Llama-4-Maverick-17B-128E/docker/cloudbuild.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | steps: 16 | - name: 'gcr.io/cloud-builders/docker' 17 | args: 18 | - 'build' 19 | - '--tag=${_ARTIFACT_REGISTRY}/${_JETSTREAM_MAXTEXT_IMAGE}:${_JETSTREAM_MAXTEXT_VERSION}' 20 | - '--file=Dockerfile' 21 | - '.' 22 | automapSubstitutions: true 23 | 24 | images: 25 | - '${_ARTIFACT_REGISTRY}/${_JETSTREAM_MAXTEXT_IMAGE}:${_JETSTREAM_MAXTEXT_VERSION}' -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/Llama-4-Maverick-17B-128E/prepare-model/Chart.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | apiVersion: v2 16 | name: trillium-maxtext-jetstream-llama-serve-model 17 | description: trillium-maxtext-jetstream-llama-serve-model 18 | type: application 19 | version: 0.1.0 20 | appVersion: "1.16.0" -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/Llama-4-Maverick-17B-128E/serve-model/Chart.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | apiVersion: v2 16 | name: trillium-maxtext-jetstream-llama-serve-model 17 | description: trillium-maxtext-jetstream-llama-serve-model 18 | type: application 19 | version: 0.1.0 20 | appVersion: "1.16.0" -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/Llama-4-Maverick-17B-128E/serve-model/templates/model-serve-configmap.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | apiVersion: v1 16 | kind: ConfigMap 17 | metadata: 18 | name: "{{ .Release.Name }}" 19 | data: 20 | maxtext-configuration.yaml: |- 21 | {{- range $key, $value := .Values.maxtext_config }} 22 | {{ $key }}: {{ $value }} 23 | {{- end }} 24 | 25 | libtpu-init-args: |- 26 | {{- range $key, $value := .Values.xla_flags }} 27 | --{{ $key }}={{ $value }} 28 | {{- end }} -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/Llama-4-Maverick-17B-128E/values.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Copyright 2025 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | 17 | clusterName: 18 | 19 | huggingface: 20 | secretName: hf-secret 21 | secretData: 22 | token: "hf_api_token" 23 | 24 | model: 25 | name: meta-llama/Llama-4-Maverick-17B-128E 26 | 27 | job: 28 | image: 29 | repository: 30 | tag: 31 | 32 | volumes: 33 | ssdMountPath: "/ssd" 34 | gcsMounts: 35 | - bucketName: 36 | mountPath: "/gcs" 37 | 38 | pathwaysDir: 39 | 40 | jetstream: 41 | service: 42 | ports: 43 | http: 8000 44 | grpc: 9000 45 | 46 | convert_hf_ckpt: true 47 | 48 | maxtext_config: 49 | load_parameters_path: $CHECKPOINT_TPU_UNSCANNED 50 | max_prefill_predict_length: 128 51 | max_target_length: 256 52 | async_checkpointing: false 53 | steps: 1 54 | ici_fsdp_parallelism: 1 55 | ici_autoregressive_parallelism: 8 56 | ici_tensor_parallelism: 8 57 | ici_context_autoregressive_parallelism: 1 58 | scan_layers: false 59 | weight_dtype: "bfloat16" 60 | per_device_batch_size: 10 61 | enable_single_controller: true 62 | enable_model_warmup: true 63 | checkpoint_storage_use_ocdbt: false 64 | checkpoint_storage_use_zarr3: false 65 | attention: dot_product 66 | hf_access_token: $HF_TOKEN -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/Llama-4-Scout-17B-16E/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | FROM ubuntu:22.04 16 | 17 | ENV DEBIAN_FRONTEND=noninteractive 18 | 19 | RUN apt update && apt install --yes --no-install-recommends \ 20 | ca-certificates \ 21 | curl \ 22 | git \ 23 | gnupg \ 24 | cmake \ 25 | python3.10 \ 26 | python3-pip \ 27 | && echo "deb https://packages.cloud.google.com/apt gcsfuse-buster main" \ 28 | | tee /etc/apt/sources.list.d/gcsfuse.list \ 29 | && echo "deb https://packages.cloud.google.com/apt cloud-sdk main" \ 30 | | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list \ 31 | && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add - \ 32 | && apt-get update \ 33 | && apt-get install --yes gcsfuse \ 34 | && apt-get install --yes google-cloud-cli \ 35 | && apt-get clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \ 36 | && mkdir /gcs 37 | 38 | 39 | RUN update-alternatives --install \ 40 | /usr/bin/python3 python3 /usr/bin/python3.10 1 41 | 42 | RUN git clone https://github.com/AI-Hypercomputer/maxtext.git && \ 43 | git clone https://github.com/AI-Hypercomputer/JetStream.git 44 | 45 | RUN cd /JetStream && \ 46 | pip install -e . 
47 | 48 | RUN cd /JetStream/benchmarks && \ 49 | pip install -r requirements.in 50 | 51 | RUN cd maxtext/ && \ 52 | bash setup.sh JAX_VERSION=0.5.1 53 | 54 | # Reset working directory to /workspace 55 | WORKDIR /workspace 56 | 57 | ENTRYPOINT [ "/bin/bash" ] 58 | -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/Llama-4-Scout-17B-16E/docker/cloudbuild.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | steps: 16 | - name: 'gcr.io/cloud-builders/docker' 17 | args: 18 | - 'build' 19 | - '--tag=${_ARTIFACT_REGISTRY}/${_JETSTREAM_MAXTEXT_IMAGE}:${_JETSTREAM_MAXTEXT_VERSION}' 20 | - '--file=Dockerfile' 21 | - '.' 22 | automapSubstitutions: true 23 | 24 | images: 25 | - '${_ARTIFACT_REGISTRY}/${_JETSTREAM_MAXTEXT_IMAGE}:${_JETSTREAM_MAXTEXT_VERSION}' -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/Llama-4-Scout-17B-16E/prepare-model/Chart.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | apiVersion: v2 16 | name: trillium-maxtext-jetstream-llama-serve-model 17 | description: trillium-maxtext-jetstream-llama-serve-model 18 | type: application 19 | version: 0.1.0 20 | appVersion: "1.16.0" -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/Llama-4-Scout-17B-16E/serve-model/Chart.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | apiVersion: v2 16 | name: trillium-maxtext-jetstream-llama-serve-model 17 | description: trillium-maxtext-jetstream-llama-serve-model 18 | type: application 19 | version: 0.1.0 20 | appVersion: "1.16.0" -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/Llama-4-Scout-17B-16E/serve-model/templates/model-serve-configmap.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | apiVersion: v1 16 | kind: ConfigMap 17 | metadata: 18 | name: "{{ .Release.Name }}" 19 | data: 20 | maxtext-configuration.yaml: |- 21 | {{- range $key, $value := .Values.maxtext_config }} 22 | {{ $key }}: {{ $value }} 23 | {{- end }} 24 | 25 | libtpu-init-args: |- 26 | {{- range $key, $value := .Values.xla_flags }} 27 | --{{ $key }}={{ $value }} 28 | {{- end }} -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/Llama-4-Scout-17B-16E/serve-model/templates/model-serve-svc.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | apiVersion: v1 16 | kind: Service 17 | 18 | metadata: 19 | name: {{ .Release.Name }}-svc 20 | 21 | spec: 22 | selector: 23 | app: {{ .Release.Name }}-serving 24 | ports: 25 | - protocol: TCP 26 | name: jetstream-http 27 | port: {{ .Values.jetstream.service.ports.http }} 28 | targetPort: {{ .Values.jetstream.service.ports.http }} 29 | - protocol: TCP 30 | name: jetstream-grpc 31 | port: {{ .Values.jetstream.service.ports.grpc }} 32 | targetPort: {{ .Values.jetstream.service.ports.grpc }} -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/Llama-4-Scout-17B-16E/values.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Copyright 2025 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | clusterName: 18 | 19 | huggingface: 20 | secretName: hf-secret 21 | secretData: 22 | token: "hf_api_token" 23 | 24 | model: 25 | name: meta-llama/Llama-4-Scout-17B-16E 26 | 27 | job: 28 | image: 29 | repository: 30 | tag: 31 | 32 | volumes: 33 | ssdMountPath: "/ssd" 34 | gcsMounts: 35 | - bucketName: 36 | mountPath: "/gcs" 37 | 38 | jetstream: 39 | service: 40 | ports: 41 | http: 8000 42 | grpc: 9000 43 | 44 | convert_hf_ckpt: true 45 | 46 | maxtext_config: 47 | scan_layers: false 48 | model_name: llama4-17b-16e 49 | weight_dtype: "bfloat16" 50 | base_output_directory: $BASE_OUTPUT_PATH 51 | run_name: serving-run 52 | load_parameters_path: $CHECKPOINT_TPU_UNSCANNED 53 | sparse_matmul: false 54 | ici_tensor_parallelism: 8 55 | max_prefill_predict_length: 1024 56 | force_unroll: false 57 | max_target_length: 2048 58 | attention: dot_product 59 | hf_access_token: $HF_TOKEN 60 | -------------------------------------------------------------------------------- /inference/trillium/JetStream-Maxtext/Llama2-7B/README.md: -------------------------------------------------------------------------------- 1 | # Setup 2 | 3 | ## Step 1: Download JetStream and MaxText github repository 4 | ```bash 5 | cd ~ 6 | git clone https://github.com/google/maxtext.git 7 | cd maxtext 8 | git checkout main 9 | 10 | cd ~ 11 | git clone https://github.com/google/JetStream.git 12 | cd JetStream 13 | git checkout main 14 | ``` 15 | 16 | ## Step 2: Setup JetStream and MaxText 17 | ```bash 18 | cd ~ 19 | sudo apt install python3.10-venv 20 | python -m venv venv-maxtext 21 | source venv-maxtext/bin/activate 22 | 23 | cd ~ 24 | cd JetStream 25 | pip install -e . 26 | cd benchmarks 27 | pip install -r requirements.in 28 | 29 | cd ~ 30 | cd maxtext/ 31 | bash setup.sh 32 | ``` 33 | 34 | ## Step 3: Checkpoint conversion 35 | 36 | ```bash 37 | # Go to https://llama.meta.com/llama-downloads/ and fill out the form 38 | git clone https://github.com/meta-llama/llama 39 | bash download.sh # When prompted, choose 7B. This should create a directory llama-2-7b inside the llama directory 40 | 41 | 42 | export CHKPT_BUCKET=gs://... 43 | export MAXTEXT_BUCKET_SCANNED=gs://... 44 | export MAXTEXT_BUCKET_UNSCANNED=gs://... 45 | gsutil cp -r llama/llama-2-7b/* ${CHKPT_BUCKET} 46 | 47 | 48 | # Checkpoint conversion 49 | cd maxtext 50 | bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED} 51 | 52 | # The path to the unscanned checkpoint should be set by the script, but set it explicitly if it hasn't 53 | # For example export UNSCANNED_CKPT_PATH=gs://${MAXTEXT_BUCKET_UNSCANNED}/llama2-7b_unscanned_chkpt_2024-08-23-23-17/checkpoints/0/items 54 | export UNSCANNED_CKPT_PATH=gs://.. 
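# Optional helper (illustrative addition, assuming the conversion script wrote into ${MAXTEXT_BUCKET_UNSCANNED}):
# list the bucket to find the timestamped run directory, then append /checkpoints/0/items to form UNSCANNED_CKPT_PATH.
gsutil ls ${MAXTEXT_BUCKET_UNSCANNED}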
55 | ``` 56 | 57 | # Benchmark 58 | 59 | In terminal tab 1, start the server: 60 | ```bash 61 | export TOKENIZER_PATH=assets/tokenizer.llama2 62 | export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH} 63 | export MAX_PREFILL_PREDICT_LENGTH=1024 64 | export MAX_TARGET_LENGTH=2048 65 | export MODEL_NAME=llama2-7b 66 | export ICI_FSDP_PARALLELISM=1 67 | export ICI_AUTOREGRESSIVE_PARALLELISM=1 68 | export ICI_TENSOR_PARALLELISM=-1 69 | export SCAN_LAYERS=false 70 | export WEIGHT_DTYPE=bfloat16 71 | export PER_DEVICE_BATCH_SIZE=11 72 | 73 | cd ~/maxtext 74 | python MaxText/maxengine_server.py \ 75 | MaxText/configs/base.yml \ 76 | tokenizer_path=${TOKENIZER_PATH} \ 77 | load_parameters_path=${LOAD_PARAMETERS_PATH} \ 78 | max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \ 79 | max_target_length=${MAX_TARGET_LENGTH} \ 80 | model_name=${MODEL_NAME} \ 81 | ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \ 82 | ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \ 83 | ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \ 84 | scan_layers=${SCAN_LAYERS} \ 85 | weight_dtype=${WEIGHT_DTYPE} \ 86 | per_device_batch_size=${PER_DEVICE_BATCH_SIZE} 87 | ``` 88 | 89 | In terminal tab 2, run the benchmark: 90 | ```bash 91 | source venv-maxtext/bin/activate 92 | 93 | python JetStream/benchmarks/benchmark_serving.py \ 94 | --tokenizer ~/maxtext/assets/tokenizer.llama2 \ 95 | --warmup-mode sampled \ 96 | --save-result \ 97 | --save-request-outputs \ 98 | --request-outputs-file-path outputs.json \ 99 | --num-prompts 1000 \ 100 | --max-output-length 1024 \ 101 | --dataset openorca 102 | ``` 103 | 104 | After the benchmark finishes, you should see something like 105 | ```bash 106 | Successful requests: 995 107 | Benchmark duration: 305.366344 s 108 | Total input tokens: 217011 109 | Total generated tokens: 934964 110 | Request throughput: 3.26 requests/s 111 | Input token throughput: 710.66 tokens/s 112 | Output token throughput: 3061.78 tokens/s 113 | Mean TTFT: 130288.20 ms 114 | Median TTFT: 140039.96 ms 115 | P99 TTFT: 278498.91 ms 116 | Mean TPOT: 5052.76 ms 117 | Median TPOT: 164.01 ms 118 | P99 TPOT: 112171.56 ms 119 | 120 | ``` 121 | -------------------------------------------------------------------------------- /inference/trillium/MaxDiffusion/SDXL/README.md: -------------------------------------------------------------------------------- 1 | # Setup 2 | 3 | 4 | ## Step 1: Installing the dependencies: 5 | ``` 6 | mkdir -p ~/miniconda3 7 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh 8 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 9 | rm -rf ~/miniconda3/miniconda.sh 10 | 11 | export PATH="$HOME/miniconda3/bin:$PATH" 12 | source ~/.bashrc 13 | 14 | conda create -n tpu python=3.10 15 | source activate tpu 16 | 17 | git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion 18 | git checkout mlperf4.1 19 | 20 | pip install -e .
21 | pip install -r requirements.txt 22 | pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 23 | ``` 24 | 25 | ## Step 2: Running the inference benchmark: 26 | 27 | ``` 28 | LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536" python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run" 29 | ``` 30 | 31 | 32 | -------------------------------------------------------------------------------- /inference/trillium/vLLM/README.md: -------------------------------------------------------------------------------- 1 | # Serve vLLM on Trillium TPUs (v6e): 2 | 3 | This repository provides examples demonstrating how to deploy and serve vLLM on Trillium TPUs using GCE (Google Compute Engine) for a select set of models. 4 | 5 | - [Llama3-8b](./Llama3-8b/README.md) 6 | - [Qwen2.5-32B](./Qwen2.5-32B/README.md) 7 | - [Llama-3.3-70B](./Llama3.3-70b/README.md) 8 | 9 | These models were chosen for demonstration purposes only. You can serve any model from this list: [vLLM Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html) 10 | 11 | If you are looking for GKE-based deployment, please refer to this documentation: [Serve an LLM using TPU Trillium on GKE with vLLM](https://cloud.google.com/kubernetes-engine/docs/tutorials/serve-vllm-tpu) 12 | 13 | -------------------------------------------------------------------------------- /inference/v5e/JetStream-Maxtext/Llama2-7B/README.md: -------------------------------------------------------------------------------- 1 | # Setup 2 | 3 | ## Step 1: Download JetStream and MaxText github repository 4 | ```bash 5 | cd ~ 6 | git clone https://github.com/google/maxtext.git 7 | cd maxtext 8 | git checkout main 9 | 10 | cd ~ 11 | git clone https://github.com/google/JetStream.git 12 | cd JetStream 13 | git checkout main 14 | ``` 15 | 16 | ## Step 2: Setup JetStream and MaxText 17 | ```bash 18 | cd ~ 19 | sudo apt install python3.10-venv 20 | python -m venv venv-maxtext 21 | source venv-maxtext/bin/activate 22 | 23 | cd ~ 24 | cd JetStream 25 | pip install -e . 26 | cd benchmarks 27 | pip install -r requirements.in 28 | 29 | cd ~ 30 | cd maxtext/ 31 | bash setup.sh 32 | ``` 33 | 34 | ## Step 3: Checkpoint conversion 35 | 36 | ```bash 37 | # Go to https://llama.meta.com/llama-downloads/ and fill out the form 38 | git clone https://github.com/meta-llama/llama 39 | bash download.sh # When prompted, choose 7B. This should create a directory llama-2-7b inside the llama directory 40 | 41 | 42 | export CHKPT_BUCKET=gs://... 43 | export MAXTEXT_BUCKET_SCANNED=gs://... 44 | export MAXTEXT_BUCKET_UNSCANNED=gs://... 45 | gsutil cp -r llama/llama-2-7b ${CHKPT_BUCKET} 46 | 47 | 48 | # Checkpoint conversion 49 | cd maxtext 50 | bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED} 51 | 52 | # The path to the unscanned checkpoint should be set by the script, but set it explicitly if it hasn't 53 | # For example export UNSCANNED_CKPT_PATH=gs://${MAXTEXT_BUCKET_UNSCANNED}/llama2-7b_unscanned_chkpt_2024-08-23-23-17/checkpoints/0/items 54 | export UNSCANNED_CKPT_PATH=gs://.. 
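# Optional check (illustrative addition): confirm the unscanned checkpoint directory exists before starting the server.
gsutil ls ${UNSCANNED_CKPT_PATH}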
55 | ``` 56 | 57 | # Benchmark 58 | 59 | In terminal tab 1, start the server: 60 | ```bash 61 | export TOKENIZER_PATH=assets/tokenizer.llama2 62 | export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH} 63 | export MAX_PREFILL_PREDICT_LENGTH=1024 64 | export MAX_TARGET_LENGTH=2048 65 | export MODEL_NAME=llama2-7b 66 | export ICI_FSDP_PARALLELISM=1 67 | export ICI_AUTOREGRESSIVE_PARALLELISM=1 68 | export ICI_TENSOR_PARALLELISM=-1 69 | export SCAN_LAYERS=false 70 | export WEIGHT_DTYPE=bfloat16 71 | export PER_DEVICE_BATCH_SIZE=11 72 | 73 | cd ~/maxtext 74 | python MaxText/maxengine_server.py \ 75 | MaxText/configs/base.yml \ 76 | tokenizer_path=${TOKENIZER_PATH} \ 77 | load_parameters_path=${LOAD_PARAMETERS_PATH} \ 78 | max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \ 79 | max_target_length=${MAX_TARGET_LENGTH} \ 80 | model_name=${MODEL_NAME} \ 81 | ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \ 82 | ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \ 83 | ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \ 84 | scan_layers=${SCAN_LAYERS} \ 85 | weight_dtype=${WEIGHT_DTYPE} \ 86 | per_device_batch_size=${PER_DEVICE_BATCH_SIZE} 87 | ``` 88 | 89 | In terminal tab 2, run the benchmark: 90 | ```bash 91 | source venv-maxtext/bin/activate 92 | 93 | python JetStream/benchmarks/benchmark_serving.py \ 94 | --tokenizer ~/maxtext/assets/tokenizer.llama2 \ 95 | --warmup-mode sampled \ 96 | --save-result \ 97 | --save-request-outputs \ 98 | --request-outputs-file-path outputs.json \ 99 | --num-prompts 1000 \ 100 | --max-output-length 1024 \ 101 | --dataset openorca 102 | ``` 103 | 104 | After the benchmark finishes, you should see something like 105 | ```bash 106 | Successful requests: 995 107 | Benchmark duration: 305.366344 s 108 | Total input tokens: 217011 109 | Total generated tokens: 934964 110 | Request throughput: 3.26 requests/s 111 | Input token throughput: 710.66 tokens/s 112 | Output token throughput: 3061.78 tokens/s 113 | Mean TTFT: 130288.20 ms 114 | Median TTFT: 140039.96 ms 115 | P99 TTFT: 278498.91 ms 116 | Mean TPOT: 5052.76 ms 117 | Median TPOT: 164.01 ms 118 | P99 TPOT: 112171.56 ms 119 | 120 | ``` 121 | -------------------------------------------------------------------------------- /inference/v5e/JetStream-Pytorch/Llama2-7B/README.md: -------------------------------------------------------------------------------- 1 | # Setup 2 | 3 | ## Step 0: (optional) Create a virtual environment for Python packages to install 4 | 5 | ```bash 6 | sudo apt install python3.10-venv 7 | python -m venv venv 8 | source venv/bin/activate 9 | export WORKDIR=$(pwd) # set current dir as workdir (can set to something else) 10 | ``` 11 | 12 | ## Step 1: Get JetStream-PyTorch github repository 13 | 14 | ```bash 15 | git clone https://github.com/google/jetstream-pytorch.git 16 | cd jetstream-pytorch/ 17 | git checkout jetstream-v0.2.3 18 | ``` 19 | 20 | ## Step 2: Setup JetStream and JetStream-PyTorch 21 | ```bash 22 | source install_everything.sh 23 | ``` 24 | 25 | Do not install jetstream separately, the above command will install everything. 26 | 27 | ## Step 3: Get the checkpoint and run conversion 28 | 29 | ```bash 30 | export input_ckpt_dir=$WORKDIR/ckpt/llama2-7b/original 31 | mkdir -p $input_ckpt_dir 32 | 33 | # NOTE: get your own weights from meta! 
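# The command below copies from an example GCS bucket; replace the source with the
# GCS path (or local directory) that holds your own downloaded Llama 2 7B weights, e.g.:
#   gcloud storage cp gs://YOUR_BUCKET/llama-2-7b/* $input_ckpt_dir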
34 | gcloud storage cp gs://hanq-random/llama-2-7b-chat/* $input_ckpt_dir 35 | ``` 36 | 37 | Run the conversion: 38 | ```bash 39 | export model_name=llama-2 40 | export tokenizer_path=$input_ckpt_dir/tokenizer.llama2 41 | 42 | ## Step 1: Convert model 43 | export output_ckpt_dir=$WORKDIR/ckpt/llama2-7b/converted 44 | mkdir -p ${output_ckpt_dir} 45 | python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize_weights=True 46 | ``` 47 | 48 | # Benchmark 49 | 50 | In terminal tab 1, start the server: 51 | ```bash 52 | export tokenizer_path=$input_ckpt_dir/tokenizer.model 53 | python run_server.py --model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml" --quantize_weights=1 --quantize_kv_cache=1 54 | 55 | ``` 56 | 57 | In terminal tab 2, run the benchmark: 58 | One-time setup: 59 | ```bash 60 | source venv/bin/activate 61 | 62 | export model_name=llama-2 63 | export WORKDIR=$(pwd) # set current dir as workdir (can set to something else) 64 | export input_ckpt_dir=$WORKDIR/ckpt/llama2-7b/original 65 | export tokenizer_path=$input_ckpt_dir/tokenizer.model 66 | 67 | cd jetstream-pytorch/deps/JetStream/benchmarks 68 | pip install -r requirements.in 69 | ``` 70 | 71 | Run the benchmark: 72 | ```bash 73 | python benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 1000 --dataset openorca --save-request-outputs --warmup-mode=sampled --model=$model_name 74 | ``` 75 | 76 | 77 | # NOTE: for release 0.2.4 (coming soon), the command-line interface will change. 78 | See more at https://github.com/google/JetStream-pytorch -------------------------------------------------------------------------------- /inference/v5e/MaxDiffusion/SDXL/README.md: -------------------------------------------------------------------------------- 1 | # Setup 2 | 3 | 4 | ## Step 1: Installing the dependencies: 5 | ``` 6 | mkdir -p ~/miniconda3 7 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh 8 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 9 | rm -rf ~/miniconda3/miniconda.sh 10 | 11 | export PATH="$HOME/miniconda3/bin:$PATH" 12 | source ~/.bashrc 13 | 14 | conda create -n tpu python=3.10 15 | source activate tpu 16 | 17 | git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion 18 | git checkout mlperf4.1 19 | 20 | pip install -e .
21 | pip install -r requirements.txt 22 | pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 23 | ``` 24 | 25 | ## Step 2: Running the inference benchmark: 26 | 27 | ``` 28 | LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536" python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run" 29 | ``` 30 | 31 | 32 | -------------------------------------------------------------------------------- /microbenchmarks/README.md: -------------------------------------------------------------------------------- 1 | # Microbenchmarks 2 | 3 | ## Setup 4 | 5 | Set up a v6e TPU VM for single-chip microbenchmarks: 6 | ``` 7 | export TPU_NAME=your-tpu-vm-name 8 | export PROJECT_ID=your-gcloud-project-name 9 | export ZONE=us-east5-b 10 | 11 | gcloud compute tpus tpu-vm create ${TPU_NAME} \ 12 | --project ${PROJECT_ID} \ 13 | --zone=${ZONE} \ 14 | --accelerator-type=v6e-1 \ 15 | --version=v2-alpha-tpuv6e 16 | ``` 17 | Replace the example values for `TPU_NAME`, `PROJECT_ID`, `ZONE` with your own. 18 | If needed, see the full list of [available zones](https://cloud.google.com/tpu/docs/regions-zones). 19 | 20 | SSH into the VM: 21 | ``` 22 | gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} --zone ${ZONE} 23 | ``` 24 | 25 | More info on the previous commands can be found in the [Google Cloud documentation](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm). 26 | 27 | Clone the repo and install the dependencies: 28 | ```bash 29 | git clone https://github.com/AI-Hypercomputer/tpu-recipes.git 30 | cd tpu-recipes/microbenchmarks 31 | pip install -r requirements.txt 32 | ``` 33 | 34 | ## Run Matmul Benchmark 35 | 36 | Usage example: 37 | ``` 38 | python benchmark_matmul.py \ 39 | --dim 8192 8192 8192 \ 40 | --libtpu_args=--xla_tpu_scoped_vmem_limit_kib=65536 \ 41 | --trace_matcher="jit_matmul.*" 42 | ``` 43 | 44 | Example output: 45 | ``` 46 | dtype: bfloat16, matrix dimensions: (8192, 8192, 8192), time taken (median, ms): 1.328756094, TFLOPS: 827.474382048629 47 | ``` 48 | 49 | The figure below shows the trace of the example above. Setting 50 | `--trace_matcher="jit_matmul.*"` means that the completion time is measured by 51 | the duration of the compiled [`matmul`](benchmark_matmul.py#L19) function on 52 | TPUs, which excludes the communication overheads between the host (CPU) and 53 | TPUs. 54 | 55 | 56 | ![Trace Image](https://services.google.com/fh/files/misc/trace.png) 57 | 58 | 59 | If `--trace_matcher` is not set, the completion time will be measured by timing 60 | the function on the host, which includes compilation and host-TPU communication 61 | overheads such as kernel launch, data transfer, and synchronization. 62 | 63 | Example: 64 | ``` 65 | python benchmark_matmul.py \ 66 | --dim 8192 8192 8192 \ 67 | --libtpu_args=--xla_tpu_scoped_vmem_limit_kib=65536 68 | ``` 69 | 70 | Output: 71 | 72 | ``` 73 | dtype: bfloat16, matrix dimensions: (8192, 8192, 8192), time taken (median, ms): 1.457810401916504, TFLOPS: 754.2212803054033 74 | ``` 75 | 76 | Run `python benchmark_matmul.py -h` to see how to set the other arguments.
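
The reported TFLOPS value is simply the total floating-point work of the matmul divided by the measured time. A quick sanity check of the first example above (a standalone sketch, not part of `benchmark_matmul.py`):

```python
# Recompute the TFLOPS reported for the (8192, 8192, 8192) bfloat16 matmul example.
M = N = K = 8192
time_ms = 1.328756094  # median time reported by the benchmark

flops = 2 * M * N * K  # a multiply and an add per output element per K step
tflops = flops / (time_ms / 1e3) / 1e12
print(f"{tflops:.2f} TFLOPS")  # ~827.47, matching the example output
```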
77 | 78 | ## HBM Bandwidth Benchmark 79 | 80 | Usage example: 81 | ``` 82 | python benchmark_hbm.py \ 83 | --num_elements=16777216 \ 84 | --trace_matcher="jit_my_copy.*" 85 | ``` 86 | 87 | Example output: 88 | ``` 89 | Tensor size (bytes): 33554432, time taken (ms, median): 0.049359414, bandwidth (GBps, median): 1359.5960438266143 90 | ``` 91 | 92 | Run `python benchmark_hbm.py -h` to view the how to set the arguments. 93 | -------------------------------------------------------------------------------- /microbenchmarks/requirements.txt: -------------------------------------------------------------------------------- 1 | jax[tpu]==0.5.2 2 | -------------------------------------------------------------------------------- /microbenchmarks/trillium/collectives/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for running Collectives Benchmark on TPU trillium (v6e-256) 2 | 3 | ## XPK setup 4 | Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK 5 | 6 | ## Run Collectives on v6e-256 7 | 8 | ### Starting workload 9 | 10 | Launch the XPK workload, example to run on 1 slice of v6e-256: 11 | ``` 12 | python3 ~/xpk/xpk.py workload create \ 13 | --cluster=${CLUSTER_NAME} \ 14 | --project=${PROJECT} \ 15 | --zone=${ZONE} \ 16 | --device-type=v6e-256 \ 17 | --command="git clone https://github.com/AI-Hypercomputer/accelerator-microbenchmarks.git && cd accelerator-microbenchmarks && git checkout trillium-collectives && pip install -r requirements.txt && echo '4096 41943040 314572800' > /proc/sys/net/ipv4/tcp_rmem && export LIBTPU_INIT_ARGS='--megascale_grpc_premap_memory_bytes=17179869184 --xla_tpu_enable_sunk_dcn_allreduce_done_with_host_reduction=true' && python src/run_benchmark.py --config=configs/1x_v6e_256.yaml" \ 18 | --num-slices=1 \ 19 | --docker-image=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1 \ 20 | --workload=${WORKLOAD_NAME} 21 | ``` 22 | 23 | To run on more than 1 slice, modify the `--num_slices` and `--config` flags to use the target number of slices and the corresponding yaml config file e.g 24 | ``` 25 | --num_slices=2 --config=configs/2x_v6e_256.yaml 26 | ``` 27 | 28 | From your workload logs, you should start seeing benchmark logs: 29 | ``` 30 | psum_dcn: Matrix size: 17408x17408, dtype=, matrix_size_gbyte=0.606076928,achieved_bandwidth_gbyte_s=4.1130934137328214 31 | psum_ici: Matrix size: 17408x17408, dtype=, matrix_size_gbyte=0.606076928,achieved_bandwidth_gbyte_s=235.7595345022845 32 | ``` 33 | 34 | Results will be printed out and also stored at `/tmp/microbenchmarks/collectives`. You can save the stored results to GCS by adding the following to `--command` in the XPK command: 35 | ``` 36 | gsutil cp -r /tmp/microbenchmarks/collectives gs:// 37 | ``` 38 | 39 | ### Run with a custom yaml config 40 | If you would like to run with a custom defined yaml with modified configurations (e.g. warmup_tries, tries, matrix_dim_range) you may do so by uploading it to a GCS bucket, pulling the yaml file from GCS in the workload, and then referencing the yaml file in the benchmark command. 41 | 42 | Start by creating a yaml file `your_config.yaml`. Take a look at [1x_v6e_256.yaml](https://github.com/AI-Hypercomputer/accelerator-microbenchmarks/blob/35c10a42e8cfab7593157327dd3ad3150e4c001d/configs/1x_v6e_256.yaml) for an example yaml config. 
Then upload it to your GCS bucket: 43 | ``` 44 | gsutil cp your_config.yaml gs:// 45 | ``` 46 | 47 | Then use a modified launch command that pulls the yaml file from GCS and references it in the benchmark command: 48 | ``` 49 | python3 ~/xpk/xpk.py workload create \ 50 | --cluster=${CLUSTER_NAME} \ 51 | --project=${PROJECT} \ 52 | --zone=${ZONE} \ 53 | --device-type=v6e-256 \ 54 | --command="git clone https://github.com/AI-Hypercomputer/accelerator-microbenchmarks.git && cd accelerator-microbenchmarks && git checkout trillium-collectives && pip install -r requirements.txt && echo '4096 41943040 314572800' > /proc/sys/net/ipv4/tcp_rmem && export LIBTPU_INIT_ARGS='--megascale_grpc_premap_memory_bytes=17179869184 --xla_tpu_enable_sunk_dcn_allreduce_done_with_host_reduction=true' && gsutil cp gs:///your_config.yaml configs/ && python src/run_benchmark.py --config=configs/your_config.yaml" \ 55 | --num-slices=1 \ 56 | --docker-image=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1 \ 57 | --workload=${WORKLOAD_NAME} 58 | ``` -------------------------------------------------------------------------------- /microbenchmarks/trillium/collectives/collectives-1xv6e-256.sh: -------------------------------------------------------------------------------- 1 | python3 ~/xpk/xpk.py workload create \ 2 | --cluster=${CLUSTER_NAME} \ 3 | --project=${PROJECT} \ 4 | --zone=${ZONE} \ 5 | --device-type=v6e-256 \ 6 | --command="git clone https://github.com/AI-Hypercomputer/accelerator-microbenchmarks.git && cd accelerator-microbenchmarks && git checkout trillium-collectives && pip install -r requirements.txt && echo '4096 41943040 314572800' > /proc/sys/net/ipv4/tcp_rmem && export LIBTPU_INIT_ARGS='--megascale_grpc_premap_memory_bytes=17179869184 --xla_tpu_enable_sunk_dcn_allreduce_done_with_host_reduction=true' && python src/run_benchmark.py --config=configs/1x_v6e_256.yaml" \ 7 | --num-slices=1 \ 8 | --docker-image=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1 \ 9 | --workload=${WORKLOAD_NAME} 10 | -------------------------------------------------------------------------------- /microbenchmarks/trillium/collectives/collectives-2xv6e-256.sh: -------------------------------------------------------------------------------- 1 | python3 ~/xpk/xpk.py workload create \ 2 | --cluster=${CLUSTER_NAME} \ 3 | --project=${PROJECT} \ 4 | --zone=${ZONE} \ 5 | --device-type=v6e-256 \ 6 | --command="git clone https://github.com/AI-Hypercomputer/accelerator-microbenchmarks.git && cd accelerator-microbenchmarks && git checkout trillium-collectives && pip install -r requirements.txt && echo '4096 41943040 314572800' > /proc/sys/net/ipv4/tcp_rmem && export LIBTPU_INIT_ARGS='--megascale_grpc_premap_memory_bytes=17179869184 --xla_tpu_enable_sunk_dcn_allreduce_done_with_host_reduction=true' && python src/run_benchmark.py --config=configs/2x_v6e_256.yaml" \ 7 | --num-slices=2 \ 8 | --docker-image=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1 \ 9 | --workload=${WORKLOAD_NAME} 10 | -------------------------------------------------------------------------------- /microbenchmarks/trillium/collectives/collectives-4xv6e-256.sh: -------------------------------------------------------------------------------- 1 | python3 ~/dev/xpk/xpk.py workload create \ 2 | --cluster=${CLUSTER_NAME} \ 3 | --project=${PROJECT} \ 4 | --zone=${ZONE} \ 5 | --device-type=v6e-256 \ 6 | --command="git clone https://github.com/AI-Hypercomputer/accelerator-microbenchmarks.git && cd 
accelerator-microbenchmarks && git checkout trillium-collectives && pip install -r requirements.txt && echo '4096 41943040 314572800' > /proc/sys/net/ipv4/tcp_rmem && export LIBTPU_INIT_ARGS='--megascale_grpc_premap_memory_bytes=17179869184 --xla_tpu_enable_sunk_dcn_allreduce_done_with_host_reduction=true' && python src/run_benchmark.py --config=configs/4x_v6e_256.yaml" \ 7 | --num-slices=4 \ 8 | --docker-image=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1 \ 9 | --workload=${WORKLOAD_NAME} -------------------------------------------------------------------------------- /training/trillium/Diffusion-2-PyTorch/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Stable Diffusion 2 on TPU Trillium 2 | 3 | 4 | This user guide provides a concise overview of the essential steps required to run StableDiffusion 2.0 base training on Cloud TPUs. 5 | 6 | 7 | ## Environment Setup 8 | 9 | The following setup runs the StableDiffusion 2.0 base training job on GCE TPUs using the docker image from this registry (us-central1-docker.pkg.dev/tpu-pytorch/docker/development/pytorch-tpu-diffusers:v2). The docker image uses the pytorch and torch_xla nightly builds from 09/05 and has all the package dependencies installed. It clones the git repo from [https://github.com/pytorch-tpu/diffusers (commit f08dc9)](https://github.com/pytorch-tpu/diffusers/tree/f08dc92db9d7fd7d8d8ad4efcdfee675e2cd26f2) in order to run Hugging Face Stable Diffusion on TPU. Please follow the corresponding TPU generation's user guide to set up the GCE TPUs first. All the commands below should run from your own machine (not the TPU host you created). 10 | 11 | ### Setup Environment of Your TPUs 12 | Please replace all your-* with your TPUs' information. 13 | ``` 14 | export TPU_NAME=your-tpu-name 15 | export ZONE=your-tpu-zone 16 | export PROJECT=your-tpu-project 17 | ``` 18 | 19 | ### Simple Run Command 20 | Clone the repo, navigate to this recipe's directory, and run the training script: 21 | ```bash 22 | git clone --depth 1 https://github.com/AI-Hypercomputer/tpu-recipes.git 23 | cd tpu-recipes/training/trillium/Diffusion-2-PyTorch 24 | bash benchmark.sh 25 | ``` 26 | The `benchmark.sh` script will upload 1) the environment parameters in `env.sh`, 2) the docker launch script in `host.sh`, and 3) the python training command in `train.sh` to all TPU workers. 27 | 28 | Note that the docker image is specified in `host.sh`. Make sure the docker image is accessible in your GCP project. If not, please download the image first, upload it to your own GCP project, and change the `$DOCKER_IMAGE` env to the registry URL you own. 29 | 30 | When all training steps complete, the benchmark script will print out the average step time. You should see a performance metric in the terminal like: 31 | ``` 32 | [worker :x] Average step time: ... 33 | ``` 34 | This reports the average per-batch step time for each worker. In addition, it will copy the profile back to the current folder under *profile/* and the trained model in safetensors format under *output/*. Use TensorBoard to open the profile and measure the step time from the "Trace View". 35 | 36 | 37 | ### Environment Envs Explained 38 | 39 | To keep things simple, we suggest only changing the following env variables in `env.sh`: 40 | * `PER_HOST_BATCH_SIZE`: Batch size for each host/worker. A high number can cause out-of-memory issues. 41 | * `TRAIN_STEPS`: How many training steps to run.
(choose more than 10 for this example) 42 | * `PROFILE_DURATION`: Length of the profiling time (unit ms). 43 | * `RESOLUTION`: Image resolution. 44 | -------------------------------------------------------------------------------- /training/trillium/Diffusion-2-PyTorch/benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # SCP the environment setup to all instances. Used in `--env-file` in `docker run` on the host script. 3 | gcloud compute tpus tpu-vm scp env.sh train.sh $TPU_NAME:~ --worker=all --project $PROJECT --zone=$ZONE 4 | 5 | # Actually runs the benchmark. 6 | gcloud compute tpus tpu-vm ssh $TPU_NAME --project $PROJECT --zone=$ZONE --worker=all --command="$(cat host.sh)" 7 | 8 | # Copy the profile and output back 9 | gcloud compute tpus tpu-vm scp --recurse $TPU_NAME:~/{profile,output} ./ --project=$PROJECT --zone=$ZONE 10 | -------------------------------------------------------------------------------- /training/trillium/Diffusion-2-PyTorch/env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | XLA_DISABLE_FUNCTIONALIZATION=0 3 | PROFILE_DIR=/tmp/home/profile/ 4 | CACHE_DIR=/tmp/home/xla_cache 5 | DATASET_NAME=lambdalabs/naruto-blip-captions 6 | OUTPUT_DIR=/tmp/home/output/ 7 | PROFILE_DURATION=80000 8 | PER_HOST_BATCH_SIZE=128 9 | TRAIN_STEPS=50 10 | RESOLUTION=512 11 | -------------------------------------------------------------------------------- /training/trillium/Diffusion-2-PyTorch/host.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DOCKER_IMAGE="us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-diffusers:v4" 4 | worker_id=$(curl -s "http://metadata.google.internal/computeMetadata/v1/instance/attributes/agent-worker-number" -H 'Metadata-Flavor: Google') 5 | cat >> /dev/null <&1 | sed "s/^/[worker:$worker_id] /g" | tee runlog 9 | set -o xtrace 10 | # Configure docker 11 | sudo groupadd docker 12 | sudo usermod -aG docker $USER 13 | # newgrp applies updated group permissions 14 | newgrp - docker 15 | gcloud auth configure-docker us-central1-docker.pkg.dev --quiet 16 | # Kill any running benchmarks 17 | docker kill $USER-test 18 | docker pull $DOCKER_IMAGE 19 | docker run --rm \ 20 | --name $USER-test \ 21 | --privileged \ 22 | --env-file env.sh \ 23 | -v /home/$USER:/tmp/home \ 24 | --shm-size=16G \ 25 | --net host \ 26 | -u root \ 27 | --entrypoint /bin/bash $DOCKER_IMAGE \ 28 | /tmp/home/train.sh 29 | 30 | PIPE_EOF 31 | -------------------------------------------------------------------------------- /training/trillium/Diffusion-2-PyTorch/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python /workspace/diffusers/examples/text_to_image/train_text_to_image_xla.py \ 4 | --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base \ 5 | --dataset_name=$DATASET_NAME --resolution=$RESOLUTION --center_crop --random_flip \ 6 | --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS \ 7 | --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=$PROFILE_DURATION \ 8 | --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 \ 9 | --loader_prefetch_size=4 --device_prefetch_size=4 --loader_prefetch_factor=4 10 | -------------------------------------------------------------------------------- /training/trillium/GPT3-175B-MaxText/bf16/README.md: 
-------------------------------------------------------------------------------- 1 | # Instructions for training GPT3-175B-Maxtext on TPU trillium 2 | 3 | ## XPK setup 4 | Please follow the [XPK_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK 5 | 6 | ## Prep for Maxtext 7 | 8 | ### Install MaxText and Build Docker Image 9 | Please follow the [MAXTEXT_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXTEXT_README.md) to install maxtext and build the docker image. The following variables should be set: 10 | 11 | In step 1, use the MaxText [tpu-recipes-v0.1.2](https://github.com/AI-Hypercomputer/maxtext/releases/tag/tpu-recipes-v0.1.2) tag to run this recipe: 12 | ``` 13 | git checkout tpu-recipes-v0.1.2 14 | ``` 15 | 16 | In step 3, use the jax-stable-stack image containing JAX 0.5.2: 17 | ``` 18 | BASE_IMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1 19 | bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=${BASE_IMAGE} 20 | ``` 21 | 22 | ## Run Maxtext GPT3-175B workloads on GKE 23 | 24 | ### Starting workload 25 | 26 | From the MaxText root directory, start your GPT3-175B workload 27 | ``` 28 | python3 -m benchmarks.benchmark_runner xpk \ 29 | --project=$PROJECT \ 30 | --zone=$ZONE \ 31 | --device_type=v6e-256 \ 32 | --num_slices=1 \ 33 | --cluster_name=${CLUSTER_NAME} \ 34 | --base_output_directory=${OUTPUT_DIR} \ 35 | --model_name="gpt_3_175b_bf16" \ 36 | --base_docker_image=maxtext_base_image 37 | ``` 38 | 39 | From your workload logs, you should start seeing step time logs like the following: 40 | ``` 41 | completed step: 15, seconds: 17.182, TFLOP/s/device: 384.891, Tokens/s/device: 357.580, total_weights: 1572864, loss: 388.622 42 | ``` 43 | 44 | ### Workload Details 45 | 46 | For reference, here are the `gpt_3_175b_bf16` workload details as found in `MaxText@tpu-recipes-v0.1.2`: 47 | 48 | ``` 49 | MaxTextModel( 50 | model_name="gpt-3-175b-bf16", 51 | model_type="gpt3-175b", 52 | tuning_params={ 53 | "per_device_batch_size": 3, 54 | "ici_fsdp_parallelism": -1, 55 | "remat_policy": "full", 56 | "attention": "flash", 57 | "gcs_metrics": True, 58 | "dataset_type": "synthetic", 59 | "reuse_example_batch": 1, 60 | "enable_checkpointing": False, 61 | "profiler": "xplane", 62 | "sa_block_q": 1024, 63 | "sa_block_q_dkv": 2048, 64 | "sa_block_q_dq": 2048, 65 | }, 66 | xla_flags=( 67 | xla_flags_library.DENSE_VMEM_LIMIT_FLAG 68 | + xla_flags_library.CF_FOR_ALL_GATHER 69 | + xla_flags_library.DATA_PARALLEL_OVERLAP 70 | + xla_flags_library.DISABLE_BUNDLE_AWARE_COST_MODEL 71 | ), 72 | ) 73 | ``` 74 | 75 | This equivalent workload code can be found in the [maxtext_trillium_model_configs.py](https://github.com/AI-Hypercomputer/maxtext/blob/tpu-recipes-v0.1.2/benchmarks/maxtext_trillium_model_configs.py) file within the MaxText repository. -------------------------------------------------------------------------------- /training/trillium/GPT3-175B-MaxText/bf16/gpt3-175b-v6e-256.sh: -------------------------------------------------------------------------------- 1 | # Run this command from the MaxText root directory using the setup described in the README. 
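# It assumes PROJECT, ZONE, CLUSTER_NAME, and OUTPUT_DIR (a GCS path for logs and artifacts) are set in your environment.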
2 | python3 -m benchmarks.benchmark_runner xpk \ 3 | --project=$PROJECT \ 4 | --zone=$ZONE \ 5 | --device_type=v6e-256 \ 6 | --num_slices=1 \ 7 | --cluster_name=${CLUSTER_NAME} \ 8 | --base_output_directory=${OUTPUT_DIR} \ 9 | --model_name="gpt_3_175b_bf16" \ 10 | --base_docker_image=maxtext_base_image 11 | -------------------------------------------------------------------------------- /training/trillium/GPT3-175B-MaxText/fp8/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training GPT3-175B-Maxtext on TPU trillium 2 | 3 | ## XPK setup 4 | Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK 5 | 6 | ## Prep for Maxtext 7 | Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXTEXT_README.md) to install maxtext and build docker image 8 | 9 | ## Run Maxtext GPT3-175B workloads on GKE 10 | 11 | ### Test Env 12 | jaxlib=0.4.35 13 | 14 | libtpu-nightly=20241028 15 | 16 | [maxtext](https://github.com/AI-Hypercomputer/maxtext.git)@e7292a3a572792a0d797fc8977b21d0f255729f1 17 | 18 | ### Starting workload 19 | 20 | From the MaxText root directory, start your GPT3-175B workload 21 | 22 | ``` 23 | python3 benchmarks/benchmark_runner.py --project=${PROJECT} --zone={zone} --device_type=v6e-256 --num_slices=1 --cluster_name=${CLUSTER_NAME} --base_output_directory=${OUTPUT_DIR} \ 24 | --model_name="gpt_3_175b" --libtpu_version=20241028 --base_docker_image=maxtext_base_image 25 | ``` 26 | 27 | From your workload logs, you should start seeing step time logs like the following: 28 | ``` 29 | step: 100, seconds: 14.245, TFLOP/s/device: 464.261, Tokens/s/device: 431.318, total_weights: 1572864, loss: 0.000 30 | ``` -------------------------------------------------------------------------------- /training/trillium/GPT3-175B-MaxText/fp8/gpt3-175b-v6e-256.sh: -------------------------------------------------------------------------------- 1 | python3 benchmarks/benchmark_runner.py --project=$PROJECT --zone=$ZONE --device_type=v6e-256 --num_slices=1 --cluster_name=${CLUSTER_NAME} --base_output_directory=${OUTPUT_DIR} \ 2 | --model_name="gpt_3_175b" --libtpu_version=20241009 --base_docker_image maxtext_base_image -------------------------------------------------------------------------------- /training/trillium/Llama2-70B-MaxText/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Llama2-70B-Maxtext on TPU trillium 2 | 3 | ## XPK setup 4 | Please follow the [XPK_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK 5 | 6 | ## Prep for Maxtext 7 | 8 | ### Install MaxText and Build Docker Image 9 | Please follow the [MAXTEXT_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXTEXT_README.md) to install maxtext and build the docker image. 
The following variables should be set: 10 | 11 | In step 1, use the MaxText [tpu-recipes-v0.1.2](https://github.com/AI-Hypercomputer/maxtext/releases/tag/tpu-recipes-v0.1.2) tag to run this recipe: 12 | ``` 13 | git checkout tpu-recipes-v0.1.2 14 | ``` 15 | 16 | In step 3, use the jax-stable-stack image containing JAX 0.5.2: 17 | ``` 18 | BASE_IMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1 19 | bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=${BASE_IMAGE} 20 | ``` 21 | 22 | ## Run Maxtext Llama2-70B workloads on GKE 23 | 24 | ### Starting workload 25 | 26 | From the MaxText root directory, start your Llama2-70B workload 27 | ``` 28 | python3 -m benchmarks.benchmark_runner xpk \ 29 | --project=$PROJECT \ 30 | --zone=$ZONE \ 31 | --device_type=v6e-256 \ 32 | --num_slices=1 \ 33 | --cluster_name=${CLUSTER_NAME} \ 34 | --base_output_directory=${OUTPUT_DIR} \ 35 | --model_name="llama2_70b_4096_sc" \ 36 | --base_docker_image=maxtext_base_image 37 | ``` 38 | 39 | From your workload logs, you should start seeing step time logs like the following: 40 | ``` 41 | completed step: 16, seconds: 9.052, TFLOP/s/device: 402.274, Tokens/s/device: 905.021, total_weights: 2097152, loss: 1.104 42 | ``` 43 | 44 | ### Workload Details 45 | 46 | For reference, here are the `llama2_70b_4096_sc` workload details as found in `MaxText@tpu-recipes-v0.1.2`: 47 | 48 | ``` 49 | MaxTextModel( 50 | model_name="llama2-70b-4096-sc", 51 | model_type="llama2-70b", 52 | tuning_params={ 53 | "per_device_batch_size": 3, 54 | "ici_fsdp_parallelism": 1, 55 | "ici_fsdp_transpose_parallelism": -1, 56 | "ici_tensor_parallelism": 1, 57 | "remat_policy": "qkv_proj_offloaded", 58 | "max_target_length": 4096, 59 | "attention": "flash", 60 | "gcs_metrics": True, 61 | "use_iota_embed": True, 62 | "dataset_path": "gs://max-datasets-rogue", 63 | "dataset_type": "synthetic", 64 | "enable_checkpointing": False, 65 | "profiler": "xplane", 66 | "sa_block_q": 1024, 67 | "sa_block_q_dkv": 2048, 68 | "sa_block_q_dq": 2048, 69 | }, 70 | xla_flags=( 71 | xla_flags_library.DENSE_VMEM_LIMIT_FLAG 72 | + xla_flags_library.CF_FOR_ALL_GATHER 73 | + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE 74 | ), 75 | ... 76 | ) 77 | ``` 78 | 79 | This equivalent workload code can be found in the [maxtext_trillium_model_configs.py](https://github.com/AI-Hypercomputer/maxtext/blob/tpu-recipes-v0.1.2/benchmarks/maxtext_trillium_model_configs.py) file within the MaxText repository. 80 | -------------------------------------------------------------------------------- /training/trillium/Llama2-70B-MaxText/llama2-70b-v6e-256.sh: -------------------------------------------------------------------------------- 1 | # Run this command from the MaxText root directory using the setup described in the README.
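# It assumes PROJECT, ZONE, CLUSTER_NAME, and OUTPUT_DIR (a GCS path for logs and artifacts) are set in your environment.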
2 | python3 -m benchmarks.benchmark_runner xpk \ 3 | --project=$PROJECT \ 4 | --zone=$ZONE \ 5 | --device_type=v6e-256 \ 6 | --num_slices=1 \ 7 | --cluster_name=${CLUSTER_NAME} \ 8 | --base_output_directory=${OUTPUT_DIR} \ 9 | --model_name="llama2_70b_4096_sc" \ 10 | --base_docker_image=maxtext_base_image 11 | -------------------------------------------------------------------------------- /training/trillium/Llama3-8B-MaxText/v6e-256/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Llama3.1-8B-MaxText on TPU trillium (v6e-256) 2 | 3 | ## XPK setup 4 | Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK 5 | 6 | ## Prep for Maxtext 7 | Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXTEXT_README.md) to install maxtext and build docker image. 8 | Be sure to use the jax-stable-stack image containing jax0.4.37. 9 | 10 | ## Run Maxtext Llama3.1-8B workloads on GKE 11 | 12 | ### Test Env 13 | jaxlib=0.4.37 14 | 15 | libtpu-nightly=20241209 16 | 17 | [maxtext](https://github.com/AI-Hypercomputer/maxtext.git)@3ad02ba70b122cec488aa5d017925aa00f5ef15f 18 | 19 | ### Starting workload 20 | 21 | From the MaxText root directory, start your Llama3.1-8B workload. Note: this benchmark uses a different model name than the equivalent v6e-8 recipe. 22 | ``` 23 | python3 benchmarks/benchmark_runner.py --project=$PROJECT --zone=$ZONE --device_type=v6e-256 --num_slices=1 --cluster_name=${CLUSTER_NAME} --base_output_directory=${OUTPUT_DIR} \ 24 | --model_name="llama3_1_8b_8192" --libtpu_version=20241209 --base_docker_image maxtext_base_image 25 | ``` 26 | 27 | From your workload logs, you should start seeing step time logs like the following: 28 | ``` 29 | completed step: 7, seconds: 4.225, TFLOP/s/device: 449.171, Tokens/s/device: 7755.989, total_weights: 8388608, loss: 0.000 30 | ``` 31 | If you would like to run on multiple slices of v6e-256, you may modify the `--num_slices` flag. -------------------------------------------------------------------------------- /training/trillium/Llama3-8B-MaxText/v6e-256/llama3-8B-1xv6e-256.sh: -------------------------------------------------------------------------------- 1 | python3 benchmarks/benchmark_runner.py --project=$PROJECT --zone=$ZONE --device_type=v6e-256 --num_slices=1 --cluster_name=${CLUSTER_NAME} --base_output_directory=${OUTPUT_DIR} \ 2 | --model_name="llama3_1_8b_8192" --libtpu_version=20241209 --base_docker_image maxtext_base_image -------------------------------------------------------------------------------- /training/trillium/Llama3-8B-MaxText/v6e-8/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Llama3.1-8B-MaxText on TPU trillium (v6e-8) 2 | 3 | ## XPK setup 4 | Please follow the [XPK_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK 5 | 6 | ## Prep for Maxtext 7 | 8 | ### Install MaxText and Build Docker Image 9 | Please follow the [MAXTEXT_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXTEXT_README.md) to install maxtext and build the docker image. 
The following variables should be set: 10 | 11 | In step 1, use the MaxText [tpu-recipes-v0.1.2](https://github.com/AI-Hypercomputer/maxtext/releases/tag/tpu-recipes-v0.1.2) tag to run this recipe: 12 | ``` 13 | git checkout tpu-recipes-v0.1.2 14 | ``` 15 | 16 | In step 3, use the jax-stable-stack image containing JAX 0.5.2: 17 | ``` 18 | BASE_IMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1 19 | bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=${BASE_IMAGE} 20 | ``` 21 | 22 | ## Run Maxtext Llama3.1-8B workloads on GKE 23 | 24 | ### Starting workload 25 | 26 | From the MaxText root directory, start your Llama3.1-8B workload. Note: this benchmark uses a different model name than the equivalent v6e-256 recipe. 27 | ``` 28 | python3 -m benchmarks.benchmark_runner xpk \ 29 | --project=$PROJECT \ 30 | --zone=$ZONE \ 31 | --device_type=v6e-8 \ 32 | --num_slices=1 \ 33 | --cluster_name=${CLUSTER_NAME} \ 34 | --base_output_directory=${OUTPUT_DIR} \ 35 | --model_name="llama3_1_8b_8192_no_collective_matmul" \ 36 | --base_docker_image=maxtext_base_image 37 | ``` 38 | 39 | From your workload logs, you should start seeing step time logs like the following: 40 | ``` 41 | completed step: 6, seconds: 3.517, TFLOP/s/device: 404.636, Tokens/s/device: 6986.975, total_weights: 196608, loss: 7.507 42 | ``` 43 | If you would like to run on multiple slices of v6e-8, you may modify the `--num_slices` flag. 44 | 45 | ### Workload Details 46 | 47 | For reference, here are the `llama3_1_8b_8192_no_collective_matmul` workload details as found in `MaxText@tpu-recipes-v0.1.2`: 48 | 49 | ``` 50 | MaxTextModel( 51 | model_name="llama3_1-8b-8192-no-collective-matmul", 52 | model_type="llama3.1-8b", 53 | tuning_params={ 54 | "per_device_batch_size": 3, 55 | "ici_fsdp_parallelism": -1, 56 | "remat_policy": "custom", 57 | "decoder_layer_input": "offload", 58 | "out_proj": "offload", 59 | "query_proj": "offload", 60 | "key_proj": "offload", 61 | "value_proj": "offload", 62 | "max_target_length": 8192, 63 | "attention": "flash", 64 | "use_iota_embed": True, 65 | "dataset_path": "gs://max-datasets-rogue", 66 | "dataset_type": "synthetic", 67 | "enable_checkpointing": False, 68 | "sa_block_q": 2048, 69 | "sa_block_kv": 2048, 70 | "sa_block_kv_compute": 2048, 71 | "sa_block_q_dkv": 2048, 72 | "sa_block_kv_dkv": 2048, 73 | "sa_block_kv_dkv_compute": 2048, 74 | "sa_block_q_dq": 2048, 75 | "sa_block_kv_dq": 2048, 76 | "sa_use_fused_bwd_kernel": True, 77 | "profiler": "xplane", 78 | "skip_first_n_steps_for_profiler": 10, 79 | "profiler_steps": 5, 80 | }, 81 | xla_flags=( 82 | xla_flags_library.DENSE_VMEM_LIMIT_FLAG 83 | + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER 84 | + xla_flags_library.DATA_PARALLEL_OVERLAP 85 | + xla_flags_library.CF_FOR_ALL_GATHER 86 | + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE 87 | + xla_flags_library.HOST_OFFLOAD_FLAGS 88 | + xla_flags_library.DISABLE_COLLECTIVE_MATMUL 89 | ), 90 | ) 91 | ``` 92 | 93 | This equivalent workload code can be found in the [maxtext_trillium_model_configs.py](https://github.com/AI-Hypercomputer/maxtext/blob/tpu-recipes-v0.1.2/benchmarks/maxtext_trillium_model_configs.py) file within the MaxText repository. 
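
As a rough sanity check, the throughput in the example log line follows directly from the tuning parameters above (a sketch assuming the synthetic-data config shown and the 8 chips of a v6e-8 slice):

```python
# Relate the example log line to the llama3_1_8b_8192_no_collective_matmul config.
per_device_batch_size = 3
max_target_length = 8192
num_devices = 8          # one v6e-8 slice
step_seconds = 3.517     # "seconds" from the example log line

tokens_per_step = per_device_batch_size * max_target_length * num_devices
print(tokens_per_step)   # 196608, matching total_weights in the log
print(round(tokens_per_step / num_devices / step_seconds))  # ~6988 tokens/s/device, close to the logged 6986.975
```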
-------------------------------------------------------------------------------- /training/trillium/Llama3-8B-MaxText/v6e-8/llama3-8B-1xv6e-8.sh: -------------------------------------------------------------------------------- 1 | # Run this command from the MaxText root directory using the setup described in the README. 2 | python3 -m benchmarks.benchmark_runner xpk \ 3 | --project=$PROJECT \ 4 | --zone=$ZONE \ 5 | --device_type=v6e-8 \ 6 | --num_slices=1 \ 7 | --cluster_name=${CLUSTER_NAME} \ 8 | --base_output_directory=${OUTPUT_DIR} \ 9 | --model_name="llama3_1_8b_8192_no_collective_matmul" \ 10 | --base_docker_image=maxtext_base_image 11 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-70B-PyTorch/GCE/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Llama 3.0 70B on Trillium TPU 2 | 3 | This user guide provides a concise overview of the essential steps required to 4 | run Hugging Face (HF) Llama 3.0 70B training on Trillium TPUs. 5 | 6 | ## Environment Setup 7 | 8 | Please follow the corresponding TPU generation's user guide to set up the GCE TPUs 9 | first. 10 | 11 | Please replace all your-* with your TPUs' information. 12 | 13 | ``` 14 | export TPU_NAME=your-tpu-name 15 | export ZONE=your-tpu-zone 16 | export PROJECT=your-tpu-project 17 | ``` 18 | 19 | You may use this command to create a 256-chip Trillium pod: 20 | 21 | ```bash 22 | gcloud alpha compute tpus tpu-vm create $TPU_NAME \ 23 | --type v6e --topology 16x16 \ 24 | --project $PROJECT --zone $ZONE --version v2-alpha-tpuv6e 25 | ``` 26 | 27 | ## Steps to Run HF Llama 3.0 70B 28 | 29 | The following setup runs the training job with Llama 3.0 70B on GCE TPUs using 30 | the docker image from this registry 31 | (`us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-xla/llama3-70b:jan15built`). 32 | The docker image uses the torch and torch_xla nightly builds from 09/28/2024 33 | and comes with all the package dependencies needed to run the model training. 34 | All the commands below should run from your own machine (not the TPU host you 35 | created). 36 | 37 | 1. Clone the repo, navigate to this recipe's directory, and run the training script: 38 | 39 | ```bash 40 | git clone --depth 1 https://github.com/AI-Hypercomputer/tpu-recipes.git 41 | cd tpu-recipes/training/trillium/Llama3.0-70B-PyTorch/GCE 42 | ``` 43 | 44 | 2. Edit `env.sh` to add the Hugging Face token and/or set up the training parameters. 45 | 46 | ```bash 47 | # add your Hugging Face token 48 | HF_TOKEN=hf_*** 49 | ``` 50 | 51 | 3. Run the training script: 52 | 53 | ```bash 54 | ./benchmark.sh 55 | ``` 56 | 57 | The `benchmark.sh` script will upload 1) the environment parameters in `env.sh`, 2) 58 | the model-related configs `config.json` and `fsdp_config.json`, 3) the docker launch 59 | script in `host.sh`, and 4) the python training command in `train.sh` to all TPU 60 | workers, and then start the training. When all training steps complete, 61 | it will print out the training metrics of each worker in the terminal, as below: 62 | 63 | ``` 64 | [worker :0] ***** train metrics ***** 65 | [worker :0] epoch = 0.3125 66 | [worker :0] total_flos = 10915247040GF 67 | [worker :0] train_loss = 9.278 68 | [worker :0] train_runtime = 0:46:45.60 69 | [worker :0] train_samples = 32816 70 | [worker :0] train_samples_per_second = 3.65 71 | [worker :0] train_steps_per_second = 0.007 72 | ``` 73 | 74 | In addition, it will copy back the trained model under `output/*`.
75 | 76 | ### Appendix: 77 | - historical docker image releases: `us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-xla/llama3-70b:nightly-sep28` 78 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-70B-PyTorch/GCE/benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | # SCP the environment setup to all instances. 5 | gcloud compute tpus tpu-vm scp config.json fsdp_config.json train.sh host.sh env.sh "$TPU_NAME:~" --worker=all --project $PROJECT --zone=$ZONE 6 | 7 | # Actually runs the benchmark. 8 | gcloud compute tpus tpu-vm ssh $TPU_NAME --project $PROJECT --zone=$ZONE --worker=all --command="$(cat host.sh)" 9 | 10 | # Copy the profile and output back 11 | gcloud compute tpus tpu-vm scp --recurse $TPU_NAME:~/output ./ --project=$PROJECT --zone=$ZONE 12 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-70B-PyTorch/GCE/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 128000, 8 | "eos_token_id": 128001, 9 | "hidden_act": "silu", 10 | "hidden_size": 8192, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 28672, 13 | "max_position_embeddings": 8192, 14 | "model_type": "llama", 15 | "num_attention_heads": 64, 16 | "num_hidden_layers": 80, 17 | "num_key_value_heads": 8, 18 | "pretraining_tp": 1, 19 | "rms_norm_eps": 1e-05, 20 | "rope_scaling": null, 21 | "rope_theta": 500000.0, 22 | "tie_word_embeddings": false, 23 | "torch_dtype": "bfloat16", 24 | "transformers_version": "4.40.0.dev0", 25 | "use_cache": false, 26 | "vocab_size": 128256 27 | } 28 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-70B-PyTorch/GCE/env.sh: -------------------------------------------------------------------------------- 1 | # Uncomment below to set the Huggingface token 2 | HF_TOKEN=hf_*** 3 | PJRT_DEVICE=TPU 4 | XLA_IR_DEBUG=1 5 | XLA_HLO_DEBUG=1 6 | PROFILE_EPOCH=0 7 | PROFILE_STEP=3 8 | PROFILE_DURATION_MS=120000 9 | PROFILE_LOGDIR=/tmp/home/profile 10 | XLA_USE_SPMD=1 11 | MAX_STEPS=50 12 | SEQ_LENGTH=4096 13 | 14 | # Per-host batch size is the number of training examples used by a TPU VM 15 | # in each training step. For Trillium, it will be 4 times the per-device batch size, 16 | # since each TPU VM is connected to 4 Trillium TPU chips. The following will lead 17 | # to a per-device batch size of 1. Customize accordingly. 18 | PER_HOST_BATCH_SIZE=2 19 | 20 | # XLA flags 21 | # Quoting is not needed, c.f. 
https://github.com/moby/moby/issues/46773 22 | LIBTPU_INIT_ARGS=--xla_tpu_enable_flash_attention=false --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_scoped_vmem_limit_kib=98304 23 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-70B-PyTorch/GCE/fsdp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fsdp_transformer_layer_cls_to_wrap": [ 3 | "LlamaDecoderLayer" 4 | ], 5 | "xla": true, 6 | "xla_fsdp_v2": true, 7 | "xla_fsdp_grad_ckpt": true 8 | } 9 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-70B-PyTorch/GCE/host.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DOCKER_IMAGE=us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-xla/llama3-70b:jan15built 4 | 5 | worker_id=$(curl -s "http://metadata.google.internal/computeMetadata/v1/instance/attributes/agent-worker-number" -H 'Metadata-Flavor: Google') 6 | 7 | cat >> /dev/null <&1 | sed "s/^/[worker $worker_id] /g" | tee runlog 11 | set -o xtrace 12 | # Configure docker 13 | sudo groupadd docker 14 | sudo usermod -aG docker $USER 15 | # newgrp applies updated group permissions 16 | newgrp - docker 17 | gcloud auth configure-docker us-central1-docker.pkg.dev --quiet 18 | # Kill any running benchmarks 19 | docker kill $USER-test 20 | docker pull $DOCKER_IMAGE 21 | docker run --rm \ 22 | --name $USER-test \ 23 | --privileged \ 24 | --env-file env.sh \ 25 | -v /home/$USER:/tmp/home \ 26 | --shm-size=16G \ 27 | --net host \ 28 | -u root \ 29 | --entrypoint /bin/bash $DOCKER_IMAGE \ 30 | /tmp/home/train.sh 31 | 32 | PIPE_EOF 33 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-70B-PyTorch/GCE/tpu.Dockerfile: -------------------------------------------------------------------------------- 1 | # Base package containing nightly PyTorch/XLA 2 | ARG BASE_IMAGE=us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm 3 | FROM ${BASE_IMAGE} 4 | 5 | # Install transformers library 6 | ARG TRANSFORMERS_REPO=https://github.com/pytorch-tpu/transformers.git 7 | ARG TRANSFORMERS_REF=flash_attention_minibatch_v6e 8 | WORKDIR /workspace 9 | RUN git clone "${TRANSFORMERS_REPO}" transformers && cd transformers && git checkout "${TRANSFORMERS_REF}" 10 | 11 | # Install transformers dependencies 12 | WORKDIR /workspace/transformers 13 | RUN pip3 install git+file://$PWD accelerate datasets evaluate "huggingface_hub[cli]" \ 14 | "torch_xla[pallas]" \ 15 | -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ 16 | -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html 17 | 18 | WORKDIR /workspace 19 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-70B-PyTorch/GCE/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Remove existing repo and old data. 
3 | LOCAL_DIR=/tmp/home/ 4 | rm -rf "${LOCAL_DIR}/output" 5 | rm -rf "${LOCAL_DIR}/plugins" 6 | rm -rf "${LOCAL_DIR}/cache" 7 | mkdir -p "${LOCAL_DIR}/output" 8 | mkdir -p "${LOCAL_DIR}/plugins" 9 | mkdir -p "${LOCAL_DIR}/cache" 10 | 11 | unset LD_PRELOAD 12 | 13 | 14 | cd transformers/ 15 | 16 | 17 | python3 examples/pytorch/language-modeling/run_clm.py \ 18 | --dataset_name wikitext \ 19 | --dataset_config_name wikitext-103-raw-v1 \ 20 | --per_device_train_batch_size "${PER_HOST_BATCH_SIZE}" \ 21 | --do_train \ 22 | --output_dir "${LOCAL_DIR}/output/test-clm" \ 23 | --overwrite_output_dir \ 24 | --config_name "${LOCAL_DIR}/config.json" \ 25 | --cache_dir "${LOCAL_DIR}/cache" \ 26 | --tokenizer_name meta-llama/Meta-Llama-3-70B \ 27 | --block_size "$SEQ_LENGTH" \ 28 | --optim adafactor \ 29 | --save_strategy no \ 30 | --logging_strategy no \ 31 | --fsdp "full_shard" \ 32 | --fsdp_config "${LOCAL_DIR}/fsdp_config.json" \ 33 | --torch_dtype bfloat16 \ 34 | --dataloader_drop_last yes \ 35 | --flash_attention \ 36 | --max_steps "$MAX_STEPS" \ 37 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-70B-PyTorch/XPK/benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source env.sh 4 | 5 | python3 xpk/xpk.py workload create \ 6 | --cluster ${CLUSTER_NAME} \ 7 | --base-docker-image=${BASE_DOCKER_IMAGE} \ 8 | --workload=${WORKLOAD_NAME} \ 9 | --tpu-type=${TPU_TYPE} \ 10 | --num-slices=${NUM_SLICE} \ 11 | --on-demand \ 12 | --zone=$ZONE \ 13 | --project=$PROJECT \ 14 | --enable-debug-logs \ 15 | --command="bash /app/train.sh" 16 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-70B-PyTorch/XPK/config_70b.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 128000, 8 | "eos_token_id": 128001, 9 | "hidden_act": "silu", 10 | "hidden_size": 8192, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 28672, 13 | "max_position_embeddings": 8192, 14 | "model_type": "llama", 15 | "num_attention_heads": 64, 16 | "num_hidden_layers": 80, 17 | "num_key_value_heads": 8, 18 | "pretraining_tp": 1, 19 | "rms_norm_eps": 1e-05, 20 | "rope_scaling": null, 21 | "rope_theta": 500000.0, 22 | "tie_word_embeddings": false, 23 | "torch_dtype": "bfloat16", 24 | "transformers_version": "4.40.0.dev0", 25 | "use_cache": false, 26 | "vocab_size": 128256 27 | } 28 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-70B-PyTorch/XPK/env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Environment variables associated with XPK on GCP. 4 | export ZONE=... 5 | export PROJECT=... 6 | export TPU_TYPE=v6e-256 7 | export NUM_SLICE=1 8 | export CLUSTER_NAME=xpk-$USER-... # use existing CLUSTER if you have 9 | 10 | # Environment variables associated with training config. 11 | export BATCH_PER_DEVICE=2 12 | export SEQUENCE_LENGTH=4096 13 | export MAX_STEP=50 14 | export WORKLOAD_NAME=${USER}-xpk-${TPU_TYPE}-... # Your workload name. Need to update for different run. 
15 | export BASE_DOCKER_IMAGE=us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-llama@sha256:310c661423206337ef27ed06597830c52ae03c3383af411a89b3be9e4bc10aca 16 | export PROFILE_LOG_DIR=... # GSC bucket to store profile in form of gs://... 17 | export HF_TOKEN=... # Add your onw Hugging face token to download model 18 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-70B-PyTorch/XPK/fsdp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fsdp_transformer_layer_cls_to_wrap": [ 3 | "LlamaDecoderLayer" 4 | ], 5 | "xla": true, 6 | "xla_fsdp_v2": true, 7 | "xla_fsdp_grad_ckpt": true 8 | } 9 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-70B-PyTorch/XPK/train.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | # XPK will create a new docker and copy env.sh file under /app/ 4 | source /app/env.sh 5 | 6 | # Calculate the global batch size 7 | # Extract the number after '-' in TPU_TYPE 8 | TPU_NUM=$(echo "$TPU_TYPE" | grep -oP '(?<=-)\d+') 9 | # Calculate GLOBAL_BATCH_SIZE 10 | GLOBAL_BATCH_SIZE=$(( TPU_NUM * BATCH_PER_DEVICE * NUM_SLICE )) 11 | export GLOBAL_BATCH_SIZE 12 | echo "GLOBAL_BATCH_SIZE=$GLOBAL_BATCH_SIZE" 13 | 14 | # Note --per_device_train_batch_size is the global batch size since we overwrite the dataloader in the HF trainer. 15 | cd /root/ && \ 16 | export PJRT_DEVICE=TPU && \ 17 | export XLA_USE_SPMD=1 && \ 18 | export ENABLE_PJRT_COMPATIBILITY=true && \ 19 | export XLA_IR_DEBUG=1 && \ 20 | export XLA_HLO_DEBUG=1 && \ 21 | export PROFILE_EPOCH=0 && \ 22 | export PROFILE_STEP=3 && \ 23 | export PROFILE_DURATION_MS=100000 && \ 24 | export PROFILE_LOGDIR=${PROFILE_LOG_DIR} && \ 25 | export XLA_PERSISTENT_CACHE_PATH=/app/xla_cache/ && \ 26 | export TPU_LIBRARY_PATH=/root/_libtpu.so && \ 27 | export NUM_SLICE=${NUM_SLICE} && \ 28 | 29 | export LIBTPU_INIT_ARGS="--xla_tpu_enable_flash_attention=false --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_scoped_vmem_limit_kib=81920" 30 | 31 | huggingface-cli login --token=${HF_TOKEN} && \ 32 | python3 transformers/examples/pytorch/language-modeling/run_clm.py \ 33 | --dataset_name=wikitext \ 34 | --dataset_config_name=wikitext-103-raw-v1 \ 35 | --per_device_train_batch_size=${GLOBAL_BATCH_SIZE} \ 36 | --do_train \ 37 | --output_dir=test-clm \ 38 | --overwrite_output_dir \ 39 | --config_name=config_70b.json \ 40 | --cache_dir=cache \ 41 | --tokenizer_name=meta-llama/Meta-Llama-3-70B \ 42 | --block_size=${SEQUENCE_LENGTH} \ 43 | --optim=adafactor \ 44 | --save_strategy=no \ 45 | --logging_strategy=no \ 46 | --fsdp="full_shard" \ 47 | --fsdp_config=fsdp_config.json \ 48 | --torch_dtype=bfloat16 \ 49 | --dataloader_drop_last=yes \ 50 | --flash_attention \ 51 | --max_steps=${MAX_STEP} 52 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-8B-PyTorch/GCE/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Llama 3.0 8B on Trillium TPU 2 | 3 | This user guide provides a concise overview of the essential steps required to 4 | run Hugging Face (HF) Llama 3.0 8B training on 
Trillium TPUs. 5 | 6 | ## Environment Setup 7 | 8 | Please follow the corresponding TPU generation's user guide to set up the GCE TPUs 9 | first. 10 | 11 | Please replace all your-* with your TPUs' information. 12 | 13 | ``` 14 | export TPU_NAME=your-tpu-name 15 | export ZONE=your-tpu-zone 16 | export PROJECT=your-tpu-project 17 | ``` 18 | 19 | You may use this command to create a 256-chip Trillium pod: 20 | 21 | ```bash 22 | gcloud alpha compute tpus tpu-vm create $TPU_NAME \ 23 | --type v6e --topology 16x16 \ 24 | --project $PROJECT --zone $ZONE --version v2-alpha-tpuv6e 25 | ``` 26 | 27 | ## Steps to Run HF Llama 3.0 8B 28 | 29 | The following setup runs the training job with Llama 3.0 8B on GCE TPUs using 30 | the docker image from this registry 31 | (`us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-llama:v1`). 32 | The docker image uses the torch and torch_xla nightly builds from 02/11/2025 33 | and comes with all the package dependencies needed to run the model training. 34 | All the commands below should run from your own machine (not the TPU host you 35 | created). The Dockerfile is at https://github.com/pytorch-tpu/transformers/blob/flash_attention/Dockerfile 36 | 37 | 1. Clone the repo, navigate to this recipe's directory, and run the training script: 38 | 39 | ```bash 40 | git clone --depth 1 https://github.com/AI-Hypercomputer/tpu-recipes.git 41 | cd tpu-recipes/training/trillium/Llama3.0-8B-PyTorch/GCE 42 | ``` 43 | 44 | 2. Edit `env.sh` to add the Hugging Face token and/or set up the training parameters. 45 | 46 | ```bash 47 | # add your Hugging Face token 48 | HF_TOKEN=hf_*** 49 | ``` 50 | 51 | 3. Run the training script: 52 | 53 | ```bash 54 | ./benchmark.sh 55 | ``` 56 | 57 | The `benchmark.sh` script will upload 1) the environment parameters in `env.sh`, 2) 58 | the model-related configs `config.json` and `fsdp_config.json`, 3) the docker launch 59 | script in `host.sh`, and 4) the python training command in `train.sh` to all TPU 60 | workers, and then start the training. When all training steps complete, 61 | it will print out the training metrics of each worker in the terminal, as below: 62 | 63 | ``` 64 | [worker :0] ***** train metrics ***** 65 | [worker :0] epoch = 0.3125 66 | [worker :0] total_flos = 10915247040GF 67 | [worker :0] train_loss = 9.278 68 | [worker :0] train_runtime = 0:46:45.60 69 | [worker :0] train_samples = 32816 70 | [worker :0] train_samples_per_second = 3.65 71 | [worker :0] train_steps_per_second = 0.007 72 | ``` 73 | 74 | In addition, it will copy back the trained model under `output/*`. 75 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-8B-PyTorch/GCE/benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | # SCP the environment setup to all instances. 5 | gcloud compute tpus tpu-vm scp config.json fsdp_config.json train.sh host.sh env.sh "$TPU_NAME:~" --worker=all --project $PROJECT --zone=$ZONE 6 | 7 | # Actually runs the benchmark.
8 | gcloud compute tpus tpu-vm ssh $TPU_NAME --project $PROJECT --zone=$ZONE --worker=all --command="$(cat host.sh)" 9 | 10 | # Copy the profile and output back 11 | gcloud compute tpus tpu-vm scp --recurse $TPU_NAME:~/output ./ --project=$PROJECT --zone=$ZONE 12 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-8B-PyTorch/GCE/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 128000, 8 | "eos_token_id": 128001, 9 | "hidden_act": "silu", 10 | "hidden_size": 4096, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 14336, 13 | "max_position_embeddings": 8192, 14 | "model_type": "llama", 15 | "num_attention_heads": 32, 16 | "num_hidden_layers": 32, 17 | "num_key_value_heads": 8, 18 | "pretraining_tp": 1, 19 | "rms_norm_eps": 1e-05, 20 | "rope_scaling": null, 21 | "rope_theta": 500000.0, 22 | "tie_word_embeddings": false, 23 | "torch_dtype": "bfloat16", 24 | "transformers_version": "4.40.0.dev0", 25 | "use_cache": false, 26 | "vocab_size": 128256 27 | } 28 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-8B-PyTorch/GCE/env.sh: -------------------------------------------------------------------------------- 1 | # Uncomment below to set the Huggingface token 2 | HF_TOKEN=hf_*** 3 | PJRT_DEVICE=TPU 4 | XLA_IR_DEBUG=1 5 | XLA_HLO_DEBUG=1 6 | PROFILE_EPOCH=0 7 | PROFILE_STEP=3 8 | PROFILE_DURATION_MS=120000 9 | PROFILE_LOGDIR=/tmp/home/profile 10 | XLA_USE_SPMD=1 11 | MAX_STEPS=20 12 | SEQ_LENGTH=8192 13 | 14 | # Adjust as per needed. The following batch size is known to work for a full 15 | # Trillium pod of 256 chips 16 | GLOBAL_BATCH_SIZE=1024 17 | 18 | # XLA flags 19 | # Quoting is not needed, c.f. 
https://github.com/moby/moby/issues/46773 20 | LIBTPU_INIT_ARGS=--xla_tpu_enable_flash_attention=false --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_scoped_vmem_limit_kib=98304 21 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-8B-PyTorch/GCE/fsdp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fsdp_transformer_layer_cls_to_wrap": [ 3 | "LlamaDecoderLayer" 4 | ], 5 | "xla": true, 6 | "xla_fsdp_v2": true, 7 | "xla_fsdp_grad_ckpt": true 8 | } 9 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-8B-PyTorch/GCE/host.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DOCKER_IMAGE=us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-llama:v2 4 | worker_id=$(curl -s "http://metadata.google.internal/computeMetadata/v1/instance/attributes/agent-worker-number" -H 'Metadata-Flavor: Google') 5 | 6 | cat >> /dev/null <&1 | sed "s/^/[worker $worker_id] /g" | tee runlog 10 | set -o xtrace 11 | # Configure docker 12 | sudo groupadd docker 13 | sudo usermod -aG docker $USER 14 | # newgrp applies updated group permissions 15 | newgrp - docker 16 | gcloud auth configure-docker us-central1-docker.pkg.dev --quiet 17 | # Kill any running benchmarks 18 | docker kill $USER-test 19 | docker pull $DOCKER_IMAGE 20 | docker run --rm \ 21 | --name $USER-test \ 22 | --privileged \ 23 | --env-file env.sh \ 24 | -v /home/$USER:/tmp/home \ 25 | --shm-size=16G \ 26 | --net host \ 27 | -u root \ 28 | --entrypoint /bin/bash $DOCKER_IMAGE \ 29 | /tmp/home/train.sh 30 | 31 | PIPE_EOF 32 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-8B-PyTorch/GCE/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Remove existing repo and old data. 
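# /tmp/home is the host's /home/$USER directory, mounted into the container by
# host.sh (-v /home/$USER:/tmp/home), so anything written under LOCAL_DIR
# persists on the TPU VM after the container exits.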
3 | LOCAL_DIR=/tmp/home/ 4 | rm -rf "${LOCAL_DIR}/output" 5 | rm -rf "${LOCAL_DIR}/plugins" 6 | rm -rf "${LOCAL_DIR}/cache" 7 | mkdir -p "${LOCAL_DIR}/output" 8 | mkdir -p "${LOCAL_DIR}/plugins" 9 | mkdir -p "${LOCAL_DIR}/cache" 10 | 11 | unset LD_PRELOAD 12 | 13 | 14 | cd /workspace/transformers/ 15 | 16 | 17 | python3 examples/pytorch/language-modeling/run_clm.py \ 18 | --dataset_name wikitext \ 19 | --dataset_config_name wikitext-103-raw-v1 \ 20 | --per_device_train_batch_size "${GLOBAL_BATCH_SIZE}" \ 21 | --do_train \ 22 | --output_dir "${LOCAL_DIR}/output/test-clm" \ 23 | --overwrite_output_dir \ 24 | --config_name "${LOCAL_DIR}/config.json" \ 25 | --cache_dir "${LOCAL_DIR}/cache" \ 26 | --tokenizer_name meta-llama/Meta-Llama-3-8B \ 27 | --block_size "$SEQ_LENGTH" \ 28 | --optim adafactor \ 29 | --save_strategy no \ 30 | --logging_strategy no \ 31 | --fsdp "full_shard" \ 32 | --fsdp_config "${LOCAL_DIR}/fsdp_config.json" \ 33 | --torch_dtype bfloat16 \ 34 | --dataloader_drop_last yes \ 35 | --flash_attention \ 36 | --max_steps "$MAX_STEPS" 37 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-8B-PyTorch/XPK/benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source env.sh 4 | 5 | python3 xpk/xpk.py workload create \ 6 | --cluster ${CLUSTER_NAME} \ 7 | --base-docker-image=${BASE_DOCKER_IMAGE} \ 8 | --workload=${WORKLOAD_NAME} \ 9 | --tpu-type=${TPU_TYPE} \ 10 | --num-slices=${NUM_SLICE} \ 11 | --on-demand \ 12 | --zone=$ZONE \ 13 | --project=$PROJECT \ 14 | --enable-debug-logs \ 15 | --command="bash /app/train.sh" 16 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-8B-PyTorch/XPK/config_8b.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 128000, 8 | "eos_token_id": 128001, 9 | "hidden_act": "silu", 10 | "hidden_size": 4096, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 14336, 13 | "max_position_embeddings": 8192, 14 | "model_type": "llama", 15 | "num_attention_heads": 32, 16 | "num_hidden_layers": 32, 17 | "num_key_value_heads": 8, 18 | "pretraining_tp": 1, 19 | "rms_norm_eps": 1e-05, 20 | "rope_scaling": null, 21 | "rope_theta": 500000.0, 22 | "tie_word_embeddings": false, 23 | "torch_dtype": "bfloat16", 24 | "transformers_version": "4.40.0.dev0", 25 | "use_cache": false, 26 | "vocab_size": 128256 27 | } 28 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-8B-PyTorch/XPK/env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Environment variables associated with XPK on GCP. 4 | export ZONE=... 5 | export PROJECT=... 6 | export TPU_TYPE=v6e-256 7 | export NUM_SLICE=1 8 | export CLUSTER_NAME=xpk-$USER-... # use existing CLUSTER if you have 9 | 10 | # Environment variables associated with training config. 11 | export BATCH_PER_DEVICE=4 12 | export SEQUENCE_LENGTH=8192 13 | export MAX_STEP=50 14 | export WORKLOAD_NAME=${USER}-xpk-${TPU_TYPE}-... # Your workload name. Need to update for different run. 15 | export BASE_DOCKER_IMAGE=us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-llama:v1 16 | export PROFILE_LOG_DIR=... # GSC bucket to store profile in form of gs://... 
17 | export HF_TOKEN=... # Add your onw Hugging face token to download model 18 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-8B-PyTorch/XPK/fsdp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fsdp_transformer_layer_cls_to_wrap": [ 3 | "LlamaDecoderLayer" 4 | ], 5 | "xla": true, 6 | "xla_fsdp_v2": true, 7 | "xla_fsdp_grad_ckpt": true 8 | } 9 | -------------------------------------------------------------------------------- /training/trillium/Llama3.0-8B-PyTorch/XPK/train.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | # XPK will create a new docker and copy env.sh file under /app/ 4 | source /app/env.sh 5 | 6 | # Calculate the global batch size 7 | # Extract the number after '-' in TPU_TYPE 8 | TPU_NUM=$(echo "$TPU_TYPE" | grep -oP '(?<=-)\d+') 9 | # Calculate GLOBAL_BATCH_SIZE 10 | GLOBAL_BATCH_SIZE=$(( TPU_NUM * BATCH_PER_DEVICE * NUM_SLICE )) 11 | export GLOBAL_BATCH_SIZE 12 | echo "GLOBAL_BATCH_SIZE=$GLOBAL_BATCH_SIZE" 13 | 14 | # Note --per_device_train_batch_size is the global batch size since we overwrite the dataloader in the HF trainer. 15 | cd /workspace/ && \ 16 | export PJRT_DEVICE=TPU && \ 17 | export XLA_USE_SPMD=1 && \ 18 | export ENABLE_PJRT_COMPATIBILITY=true && \ 19 | export XLA_IR_DEBUG=1 && \ 20 | export XLA_HLO_DEBUG=1 && \ 21 | export PROFILE_EPOCH=0 && \ 22 | export PROFILE_STEP=3 && \ 23 | export PROFILE_DURATION_MS=100000 && \ 24 | export PROFILE_LOGDIR=${PROFILE_LOG_DIR} && \ 25 | export XLA_PERSISTENT_CACHE_PATH=/app/xla_cache/ && \ 26 | export TPU_LIBRARY_PATH=/workspace/_libtpu.so && \ 27 | export NUM_TPU_SLICE=${NUM_SLICE} && \ 28 | export LIBTPU_INIT_ARGS="--xla_tpu_enable_flash_attention=false --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_scoped_vmem_limit_kib=98304" 29 | 30 | huggingface-cli login --token=${HF_TOKEN} && \ 31 | python3 transformers/examples/pytorch/language-modeling/run_clm.py \ 32 | --dataset_name=wikitext \ 33 | --dataset_config_name=wikitext-103-raw-v1 \ 34 | --per_device_train_batch_size=${GLOBAL_BATCH_SIZE} \ 35 | --do_train \ 36 | --output_dir=test-clm \ 37 | --overwrite_output_dir \ 38 | --config_name=/app/config_8b.json \ 39 | --cache_dir=cache \ 40 | --tokenizer_name=meta-llama/Meta-Llama-3-8B \ 41 | --block_size=${SEQUENCE_LENGTH} \ 42 | --optim=adafactor \ 43 | --save_strategy=no \ 44 | --logging_strategy=no \ 45 | --fsdp="full_shard" \ 46 | --fsdp_config=/app/fsdp_config.json \ 47 | --torch_dtype=bfloat16 \ 48 | --dataloader_drop_last=yes \ 49 | --flash_attention \ 50 | --max_steps=${MAX_STEP} 51 | -------------------------------------------------------------------------------- /training/trillium/Llama3.1-405B-MaxText/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Llama3.1-405B-MaxText on TPU trillium 2 | 3 | ## XPK setup 4 | Please follow the [XPK_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK 5 | 6 | ## Prep for Maxtext 7 | 8 | ### Install MaxText and Build Docker Image 9 | Please follow the 
[MAXTEXT_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXTEXT_README.md) to install maxtext and build the docker image. The following variables should be set: 10 | 11 | In step 1, use the MaxText [tpu-recipes-v0.1.2](https://github.com/AI-Hypercomputer/maxtext/releases/tag/tpu-recipes-v0.1.2) tag to run this recipe: 12 | ``` 13 | git checkout tpu-recipes-v0.1.2 14 | ``` 15 | 16 | In step 3, use the jax-stable-stack image containing JAX 0.5.2: 17 | ``` 18 | BASE_IMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1 19 | bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=${BASE_IMAGE} 20 | ``` 21 | 22 | ## Run Maxtext Llama3.1-405B workloads on GKE 23 | 24 | ### Starting workload 25 | 26 | From the MaxText root directory, start your Llama3.1-405B workload. 27 | ``` 28 | python3 -m benchmarks.benchmark_runner xpk \ 29 | --project=$PROJECT \ 30 | --zone=$ZONE \ 31 | --device_type=v6e-256 \ 32 | --num_slices=2 \ 33 | --cluster_name=${CLUSTER_NAME} \ 34 | --base_output_directory=${OUTPUT_DIR} \ 35 | --model_name="llama3_1_405b_8192_pure_fsdp_ici" \ 36 | --base_docker_image=maxtext_base_image 37 | ``` 38 | 39 | From your workload logs, you should start seeing step time logs like the following: 40 | ``` 41 | completed step: 14, seconds: 54.803, TFLOP/s/device: 392.454, Tokens/s/device: 149.482, total_weights: 4194304, loss: 0.297 42 | ``` 43 | 44 | ### Workload Details 45 | 46 | For reference, here are the `llama3_1_405b_8192_pure_fsdp_ici` workload details as found in `MaxText@tpu-recipes-v0.1.2`: 47 | 48 | ``` 49 | MaxTextModel( 50 | model_name="llama3-1-405b-8192-pure-fsdp-ici", 51 | model_type="llama3.1-405b", 52 | tuning_params={ 53 | "per_device_batch_size": 1, 54 | "ici_fsdp_parallelism": 256, 55 | "dcn_fsdp_parallelism": 2, 56 | "remat_policy": "custom", 57 | "decoder_layer_input": "offload", 58 | "max_target_length": 8192, 59 | "attention": "flash", 60 | "gcs_metrics": True, 61 | "use_iota_embed": True, 62 | "dataset_path": "gs://max-datasets-rogue", 63 | "dataset_type": "synthetic", 64 | "reuse_example_batch": 1, 65 | "enable_checkpointing": False, 66 | "profiler": "xplane", 67 | "sa_block_q": 1024, 68 | "sa_block_q_dkv": 2048, 69 | "sa_block_q_dq": 2048, 70 | }, 71 | xla_flags=( 72 | xla_flags_library.DENSE_VMEM_LIMIT_FLAG 73 | + xla_flags_library.CF_FOR_ALL_GATHER 74 | + xla_flags_library.HOST_OFFLOAD_FLAGS 75 | ), 76 | ) 77 | ``` 78 | 79 | This equivalent workload code can be found in the [maxtext_trillium_model_configs.py](https://github.com/AI-Hypercomputer/maxtext/blob/tpu-recipes-v0.1.2/benchmarks/maxtext_trillium_model_configs.py) file within the MaxText repository. -------------------------------------------------------------------------------- /training/trillium/Llama3.1-405B-MaxText/llama3-1-405b-2xv6e-256.sh: -------------------------------------------------------------------------------- 1 | # Run this command from the MaxText root directory using the setup described in the README. 
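# Assumes PROJECT, ZONE, CLUSTER_NAME, and OUTPUT_DIR have already been exported,
# as described in the workload-config step of the MAXTEXT_README.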
2 | python3 -m benchmarks.benchmark_runner xpk \ 3 | --project=$PROJECT \ 4 | --zone=$ZONE \ 5 | --device_type=v6e-256 \ 6 | --num_slices=2 \ 7 | --cluster_name=${CLUSTER_NAME} \ 8 | --base_output_directory=${OUTPUT_DIR} \ 9 | --model_name="llama3_1_405b_8192_pure_fsdp_ici" \ 10 | --base_docker_image=maxtext_base_image 11 | -------------------------------------------------------------------------------- /training/trillium/Llama3.1-405B-PyTorch/GCE/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Llama 3.1 405B on Trillium TPU (1 pod) 2 | 3 | This user guide provides a concise overview of the essential steps required to 4 | run Hugging Face (HF) Llama 3.1 405B training on Trillium TPUs. Specifically, 5 | the instructions and docker image referenced here is optimized for a single 6 | Trillium pod. 7 | 8 | ## Environment Setup 9 | 10 | Please follow the corresponding TPU generation's user guide to setup the GCE TPUs 11 | first. 12 | 13 | Please replace all your-* with your TPUs' information. 14 | 15 | ``` 16 | export TPU_NAME=your-tpu-name 17 | export ZONE=your-tpu-zone 18 | export PROJECT=your-tpu-project 19 | ``` 20 | 21 | You may use this command to create a 256 chip Trillium pod: 22 | 23 | ```bash 24 | gcloud alpha compute tpus tpu-vm create $TPU_NAME \ 25 | --type v6e --topology 16x16 \ 26 | --project $PROJECT --zone $ZONE --version v2-alpha-tpuv6e 27 | ``` 28 | 29 | ## Steps to Run HF Llama 3.1 405B 30 | 31 | The following setup runs the training job with Llama 3.1 405B on GCE TPUs using 32 | the docker image from this registry 33 | (`us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-xla/llama3-405b:nightly-sep28`). 34 | The docker image uses torch and torch_xla nightly build from 09/28/2024 35 | and comes with all the package dependency needed to run the model training. 36 | All the command below should run from your own machine (not the TPU host you 37 | created). 38 | 39 | 1. git clone and navigate to this README repo and run training script: 40 | 41 | ```bash 42 | git clone --depth 1 https://github.com/AI-Hypercomputer/tpu-recipes.git 43 | cd training/trillium/Llama3.1-405B-PyTorch 44 | ``` 45 | 46 | 2. Edit `env.sh` to add the hugging face token and/or setup the training parameters. 47 | 48 | ```bash 49 | # add your hugging face token into `env.sh`, replacing the placeholder there. 50 | HF_TOKEN=hf_*** 51 | ``` 52 | 53 | 3. Run the training script: 54 | 55 | ```bash 56 | ./benchmark.sh 57 | ``` 58 | 59 | `benchmark.sh` script will: upload 1) environment parameters in `env.sh`, 2) 60 | model related config in `config.json`, 3) docker launch 61 | script in `host.sh` and 4) python training command in `train.sh` into all TPU 62 | workers, and starts the training afterwards. When all training steps complete, 63 | it will print out training metrics of each worker as below in terminal: 64 | 65 | ``` 66 | [worker :0] ***** train metrics ***** 67 | [worker :0] epoch = 0.3125 68 | [worker :0] total_flos = 10915247040GF 69 | [worker :0] train_loss = 9.278 70 | [worker :0] train_runtime = 0:46:45.60 71 | [worker :0] train_samples = 32816 72 | [worker :0] train_samples_per_second = 3.65 73 | [worker :0] train_steps_per_second = 0.007 74 | ``` 75 | 76 | ## Profiles 77 | 78 | Profiles will be saved under `/home/$USER/profile` in the host VM. 79 | Use `env.sh` to customize the profiling start step and duration. 
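As a sketch, you can copy the collected profiles back to your machine with the same
`gcloud ... scp` pattern that `benchmark.sh` uses for the training output (the
`--worker=0` choice here is just an example; pick the worker you want):

```bash
# Copy profiles from worker 0 of the TPU VM back to the current directory.
gcloud compute tpus tpu-vm scp --recurse $TPU_NAME:~/profile ./ \
  --project=$PROJECT --zone=$ZONE --worker=0
```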
80 | -------------------------------------------------------------------------------- /training/trillium/Llama3.1-405B-PyTorch/GCE/benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | # SCP the environment setup to all instances. 5 | gcloud compute tpus tpu-vm scp config.json train.sh host.sh env.sh "$TPU_NAME:~" --worker=all --project $PROJECT --zone=$ZONE 6 | 7 | # Actually runs the benchmark. 8 | gcloud compute tpus tpu-vm ssh $TPU_NAME --project $PROJECT --zone=$ZONE --worker=all --command="$(cat host.sh)" 9 | 10 | # Copy the profile and output back 11 | gcloud compute tpus tpu-vm scp --recurse $TPU_NAME:~/output ./ --project=$PROJECT --zone=$ZONE 12 | -------------------------------------------------------------------------------- /training/trillium/Llama3.1-405B-PyTorch/GCE/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 128000, 8 | "eos_token_id": 128001, 9 | "hidden_act": "silu", 10 | "hidden_size": 16384, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 53248, 13 | "max_position_embeddings": 131072, 14 | "mlp_bias": false, 15 | "model_type": "llama", 16 | "num_attention_heads": 128, 17 | "num_hidden_layers": 126, 18 | "num_key_value_heads": 8, 19 | "pretraining_tp": 1, 20 | "rms_norm_eps": 1e-05, 21 | "rope_scaling": { 22 | "factor": 8.0, 23 | "low_freq_factor": 1.0, 24 | "high_freq_factor": 4.0, 25 | "original_max_position_embeddings": 8192, 26 | "rope_type": "llama3" 27 | }, 28 | "rope_theta": 500000.0, 29 | "tie_word_embeddings": false, 30 | "torch_dtype": "bfloat16", 31 | "transformers_version": "4.42.3", 32 | "use_cache": false, 33 | "vocab_size": 128256 34 | } 35 | -------------------------------------------------------------------------------- /training/trillium/Llama3.1-405B-PyTorch/GCE/env.sh: -------------------------------------------------------------------------------- 1 | # Uncomment below to set the Huggingface token 2 | HF_TOKEN=hf_*** 3 | PJRT_DEVICE=TPU 4 | XLA_IR_DEBUG=1 5 | XLA_HLO_DEBUG=1 6 | PROFILE_EPOCH=0 7 | PROFILE_STEP=10 8 | PROFILE_DURATION_MS=240000 9 | PROFILE_LOGDIR=/tmp/home/profile 10 | XLA_USE_SPMD=1 11 | MAX_STEPS=40 12 | SEQ_LENGTH=8192 13 | 14 | # Global batch size in each training step. 15 | GLOBAL_BATCH_SIZE=64 16 | 17 | # XLA flags 18 | # Quoting is not needed, c.f. 
https://github.com/moby/moby/issues/46773 19 | LIBTPU_INIT_ARGS=--xla_tpu_enable_flash_attention=false --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_scoped_vmem_limit_kib=98304 20 | -------------------------------------------------------------------------------- /training/trillium/Llama3.1-405B-PyTorch/GCE/host.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DOCKER_IMAGE=us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-xla/llama3-405b:nightly-sep28 4 | 5 | worker_id=$(curl -s "http://metadata.google.internal/computeMetadata/v1/instance/attributes/agent-worker-number" -H 'Metadata-Flavor: Google') 6 | 7 | cat >> /dev/null <&1 | sed "s/^/[worker $worker_id] /g" | tee runlog 11 | set -o xtrace 12 | # Configure docker 13 | sudo groupadd docker 14 | sudo usermod -aG docker $USER 15 | # newgrp applies updated group permissions 16 | newgrp - docker 17 | gcloud auth configure-docker us-central1-docker.pkg.dev --quiet 18 | # Kill any running benchmarks 19 | docker kill $USER-test 20 | docker pull $DOCKER_IMAGE 21 | docker run --rm \ 22 | --name $USER-test \ 23 | --privileged \ 24 | --env-file env.sh \ 25 | -v /home/$USER:/tmp/home \ 26 | --shm-size=16G \ 27 | --net host \ 28 | -u root \ 29 | --entrypoint /bin/bash $DOCKER_IMAGE \ 30 | /tmp/home/train.sh 31 | 32 | PIPE_EOF 33 | -------------------------------------------------------------------------------- /training/trillium/Llama3.1-405B-PyTorch/GCE/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Remove existing repo and old data. 3 | LOCAL_DIR=/tmp/home/ 4 | rm -rf "${LOCAL_DIR}/output" 5 | rm -rf "${LOCAL_DIR}/plugins" 6 | rm -rf "${LOCAL_DIR}/cache" 7 | mkdir -p "${LOCAL_DIR}/output" 8 | mkdir -p "${LOCAL_DIR}/plugins" 9 | mkdir -p "${LOCAL_DIR}/cache" 10 | 11 | unset LD_PRELOAD 12 | 13 | 14 | cd transformers/ 15 | 16 | # In this command, each host still loads the entire 17 | # minibatch of training data, but will only transfer 18 | # the subset of data required by the TPUs connected 19 | # to the host. 
20 | python3 examples/pytorch/language-modeling/run_clm.py \ 21 | --dataset_name wikitext \ 22 | --dataset_config_name wikitext-103-raw-v1 \ 23 | --per_device_train_batch_size "${GLOBAL_BATCH_SIZE}" \ 24 | --do_train \ 25 | --output_dir "${LOCAL_DIR}/output/test-clm" \ 26 | --overwrite_output_dir \ 27 | --config_name "${LOCAL_DIR}/config.json" \ 28 | --cache_dir "${LOCAL_DIR}/cache" \ 29 | --tokenizer_name meta-llama/Meta-Llama-3.1-405B \ 30 | --block_size "$SEQ_LENGTH" \ 31 | --optim adafactor \ 32 | --save_strategy no \ 33 | --logging_strategy no \ 34 | --torch_dtype bfloat16 \ 35 | --dataloader_drop_last yes \ 36 | --flash_attention \ 37 | --spmd_2d_sharding 4 \ 38 | --max_steps "$MAX_STEPS" 39 | -------------------------------------------------------------------------------- /training/trillium/Llama3.1-405B-PyTorch/XPK/benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source env.sh 4 | 5 | python3 xpk/xpk.py workload create \ 6 | --cluster ${CLUSTER_NAME} \ 7 | --base-docker-image=${BASE_DOCKER_IMAGE} \ 8 | --workload=${WORKLOAD_NAME} \ 9 | --tpu-type=${TPU_TYPE} \ 10 | --num-slices=${NUM_SLICE} \ 11 | --on-demand \ 12 | --zone=$ZONE \ 13 | --project=$PROJECT \ 14 | --enable-debug-logs \ 15 | --command="bash /app/train.sh" 16 | -------------------------------------------------------------------------------- /training/trillium/Llama3.1-405B-PyTorch/XPK/config_405b.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 128000, 8 | "eos_token_id": 128001, 9 | "hidden_act": "silu", 10 | "hidden_size": 16384, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 53248, 13 | "max_position_embeddings": 131072, 14 | "mlp_bias": false, 15 | "model_type": "llama", 16 | "num_attention_heads": 128, 17 | "num_hidden_layers": 126, 18 | "num_key_value_heads": 8, 19 | "pretraining_tp": 1, 20 | "rms_norm_eps": 1e-05, 21 | "rope_scaling": { 22 | "factor": 8.0, 23 | "low_freq_factor": 1.0, 24 | "high_freq_factor": 4.0, 25 | "original_max_position_embeddings": 8192, 26 | "rope_type": "llama3" 27 | }, 28 | "rope_theta": 500000.0, 29 | "tie_word_embeddings": false, 30 | "torch_dtype": "bfloat16", 31 | "transformers_version": "4.42.3", 32 | "use_cache": false, 33 | "vocab_size": 128256 34 | } 35 | -------------------------------------------------------------------------------- /training/trillium/Llama3.1-405B-PyTorch/XPK/env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Environment variables associated with XPK on GCP. 4 | export ZONE=... 5 | export PROJECT=... 6 | export TPU_TYPE=v6e-256 7 | export NUM_SLICE=2 8 | export CLUSTER_NAME=xpk-$USER-... # use existing CLUSTER if you have 9 | 10 | # Environment variables associated with training config. 11 | export BATCH_PER_DEVICE=1 12 | export SEQUENCE_LENGTH=8192 13 | export MAX_STEP=50 14 | export WORKLOAD_NAME=${USER}-xpk-${TPU_TYPE}-... # Your workload name. Need to update for different run. 15 | export BASE_DOCKER_IMAGE=us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-llama@sha256:d3a4c09cd13dab2af8129e8438b0acf3f8b5a2370b94b69e2e3aac16530e3664 16 | export PROFILE_LOG_DIR=... # GSC bucket to store profile in form of gs://... 17 | export HF_TOKEN=... 
# Add your own Hugging face token to download model 18 | -------------------------------------------------------------------------------- /training/trillium/Llama3.1-405B-PyTorch/XPK/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # XPK will create a new docker and copy env.sh file under /app/ 4 | source /app/env.sh 5 | 6 | # Calculate the global batch size 7 | # Extract the number after '-' in TPU_TYPE 8 | TPU_NUM=$(echo "$TPU_TYPE" | grep -oP '(?<=-)\d+') 9 | # Calculate GLOBAL_BATCH_SIZE 10 | GLOBAL_BATCH_SIZE=$(( TPU_NUM * BATCH_PER_DEVICE * NUM_SLICE )) 11 | export GLOBAL_BATCH_SIZE 12 | echo "GLOBAL_BATCH_SIZE=$GLOBAL_BATCH_SIZE" 13 | 14 | # Note --per_device_train_batch_size is the global batch size since we overwrite the dataloader in the HF trainer. 15 | cd /workspace/ && \ 16 | export PJRT_DEVICE=TPU && \ 17 | export XLA_USE_SPMD=1 && \ 18 | export ENABLE_PJRT_COMPATIBILITY=true && \ 19 | export XLA_IR_DEBUG=1 && \ 20 | export XLA_HLO_DEBUG=1 && \ 21 | export PROFILE_EPOCH=0 && \ 22 | export PROFILE_STEP=3 && \ 23 | export PROFILE_DURATION_MS=450000 && \ 24 | export PROFILE_LOGDIR=${PROFILE_LOG_DIR} && \ 25 | export XLA_PERSISTENT_CACHE_PATH=/app/xla_cache/ && \ 26 | export TPU_LIBRARY_PATH=/workspace/_libtpu.so && \ 27 | export NUM_SLICE=${NUM_SLICE} && \ 28 | 29 | export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_scoped_vmem_limit_kib=98304 --xla_tpu_spmd_rng_bit_generator_unsafe=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_use_enhanced_launch_barrier=true --xla_tpu_enable_all_experimental_scheduler_features=true --xla_tpu_enable_scheduler_memory_pressure_tracking=true --xla_tpu_host_transfer_overlap_limit=2 --xla_tpu_aggressive_opt_barrier_removal=ENABLED --xla_lhs_prioritize_async_depth_over_stall=ENABLED --xla_tpu_enable_ag_backward_pipelining=true --xla_should_allow_loop_variant_parameter_in_chain=ENABLED --xla_should_add_loop_invariant_op_in_chain=ENABLED --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=100 --xla_latency_hiding_scheduler_rerun=2 --megascale_graph_hang_threshold=30m --megascale_graph_within_launch_hang_threshold=30m --megascale_grpc_enable_xor_tracer=false --megascale_grpc_premap_memory_bytes=68719476736 --megascale_grpc_use_chaotic_good=true --megascale_grpc_use_event_engine_allocator=true --grpc_enable_tcp_recv_zerocopy=false --grpc_enable_rpc_receive_coalescing=true" 30 | 31 | huggingface-cli login --token=${HF_TOKEN} && \ 32 | python3 transformers/examples/pytorch/language-modeling/run_clm.py \ 33 | --dataset_name=wikitext \ 34 | --dataset_config_name=wikitext-103-raw-v1 \ 35 | --per_device_train_batch_size=${GLOBAL_BATCH_SIZE} \ 36 | --do_train \ 37 | --output_dir=test-clm \ 38 | --overwrite_output_dir \ 39 | --config_name=/app/config_405b.json \ 40 | --cache_dir=cache \ 41 | --tokenizer_name=meta-llama/Meta-Llama-3.1-405B \ 42 | --block_size=${SEQUENCE_LENGTH} \ 43 | --optim=adafactor \ 44 | --save_strategy=no \ 45 | --logging_strategy=no \ 46 | --torch_dtype=bfloat16 \ 47 | --dataloader_drop_last=yes \ 48 | --flash_attention \ 49 | --spmd_2d_sharding=4 \ 50 | --max_steps=${MAX_STEP} 51 | -------------------------------------------------------------------------------- /training/trillium/Llama3.1-70B-MaxText/README.md: 
-------------------------------------------------------------------------------- 1 | # Instructions for training Llama3.1-70B-MaxText on TPU trillium (v6e-256) 2 | 3 | ## XPK setup 4 | Please follow the [XPK_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK 5 | 6 | ## Prep for Maxtext 7 | 8 | ### Install MaxText and Build Docker Image 9 | Please follow the [MAXTEXT_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXTEXT_README.md) to install maxtext and build the docker image. The following variables should be set: 10 | 11 | In step 1, use the MaxText [tpu-recipes-v0.1.2](https://github.com/AI-Hypercomputer/maxtext/releases/tag/tpu-recipes-v0.1.2) tag to run this recipe: 12 | ``` 13 | git checkout tpu-recipes-v0.1.2 14 | ``` 15 | 16 | In step 3, use the jax-stable-stack image containing JAX 0.5.2: 17 | ``` 18 | BASE_IMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1 19 | bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=${BASE_IMAGE} 20 | ``` 21 | 22 | ## Run MaxText Llama3.1-70B workloads on GKE 23 | 24 | ### Starting workload 25 | 26 | From the MaxText root directory, start your Llama3.1-70B workload 27 | ``` 28 | python3 -m benchmarks.benchmark_runner xpk \ 29 | --project=$PROJECT \ 30 | --zone=$ZONE \ 31 | --device_type=v6e-256 \ 32 | --num_slices=1 \ 33 | --cluster_name=${CLUSTER_NAME} \ 34 | --base_output_directory=${OUTPUT_DIR} \ 35 | --model_name="llama3_1_70b_8192" \ 36 | --base_docker_image=maxtext_base_image 37 | ``` 38 | 39 | From your workload logs, you should start seeing step time logs like the following: 40 | ``` 41 | completed step: 7, seconds: 34.562, TFLOP/s/device: 456.442, Tokens/s/device: 948.086, total_weights: 8388608, loss: 8.946 42 | ``` 43 | If you would like to run on multiple slices of v6e-256, you may modify the `--num_slices` flag. 
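For example, a sketch of the same command targeting two v6e-256 slices (only `--num_slices` changes):
```
python3 -m benchmarks.benchmark_runner xpk \
    --project=$PROJECT \
    --zone=$ZONE \
    --device_type=v6e-256 \
    --num_slices=2 \
    --cluster_name=${CLUSTER_NAME} \
    --base_output_directory=${OUTPUT_DIR} \
    --model_name="llama3_1_70b_8192" \
    --base_docker_image=maxtext_base_image
```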
44 | 45 | ### Workload Details 46 | 47 | For reference, here are the `llama3_1_70b_8192` workload details as found in `MaxText@tpu-recipes-v0.1.2`: 48 | 49 | ``` 50 | MaxTextModel( 51 | model_name="llama3_1-70b-8192", 52 | model_type="llama3.1-70b", 53 | tuning_params={ 54 | "per_device_batch_size": 5, 55 | "ici_fsdp_parallelism": -1, 56 | "remat_policy": "custom", 57 | "decoder_layer_input": "offload", 58 | "query_proj": "offload", 59 | "key_proj": "offload", 60 | "value_proj": "offload", 61 | "max_target_length": 8192, 62 | "attention": "flash", 63 | "use_iota_embed": True, 64 | "dataset_path": "gs://max-datasets-rogue", 65 | "dataset_type": "synthetic", 66 | "enable_checkpointing": False, 67 | "sa_block_q": 2048, 68 | "sa_block_kv": 2048, 69 | "sa_block_kv_compute": 2048, 70 | "sa_block_q_dkv": 2048, 71 | "sa_block_kv_dkv": 2048, 72 | "sa_block_kv_dkv_compute": 2048, 73 | "sa_block_q_dq": 2048, 74 | "sa_block_kv_dq": 2048, 75 | "sa_use_fused_bwd_kernel": True, 76 | "profiler": "xplane", 77 | "skip_first_n_steps_for_profiler": 10, 78 | "profiler_steps": 5, 79 | }, 80 | xla_flags=( 81 | xla_flags_library.DENSE_VMEM_LIMIT_FLAG 82 | + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER 83 | + xla_flags_library.DATA_PARALLEL_OVERLAP 84 | + xla_flags_library.CF_FOR_ALL_GATHER 85 | + xla_flags_library.HOST_OFFLOAD_FLAGS 86 | ), 87 | ) 88 | ``` 89 | 90 | This equivalent workload code can be found in the [maxtext_trillium_model_configs.py](https://github.com/AI-Hypercomputer/maxtext/blob/tpu-recipes-v0.1.2/benchmarks/maxtext_trillium_model_configs.py) file within the MaxText repository. -------------------------------------------------------------------------------- /training/trillium/Llama3.1-70B-MaxText/llama3-1-70B-1xv6e-256.sh: -------------------------------------------------------------------------------- 1 | # Run this command from the MaxText root directory using the setup described in the README. 2 | python3 -m benchmarks.benchmark_runner xpk \ 3 | --project=$PROJECT \ 4 | --zone=$ZONE \ 5 | --device_type=v6e-256 \ 6 | --num_slices=1 \ 7 | --cluster_name=${CLUSTER_NAME} \ 8 | --base_output_directory=${OUTPUT_DIR} \ 9 | --model_name="llama3_1_70b_8192" \ 10 | --base_docker_image=maxtext_base_image 11 | -------------------------------------------------------------------------------- /training/trillium/MAXTEXT_README.md: -------------------------------------------------------------------------------- 1 | # Prep for MaxText workloads on GKE 2 | 3 | > **_NOTE:_** We recommend running these instructions and kicking off your recipe 4 | workloads from a VM in GCP using Python 3.10. 5 | 6 | 1. Clone [MaxText](https://github.com/google/maxtext) repo and move to its directory 7 | ```shell 8 | git clone https://github.com/google/maxtext.git 9 | cd maxtext 10 | # Checkout either the commit id or MaxText tag. 11 | # Example: `git checkout tpu-recipes-v0.1.2` 12 | git checkout ${MAXTEXT_COMMIT_ID_OR_TAG} 13 | ``` 14 | 15 | 2. Install MaxText dependencies 16 | ```shell 17 | bash setup.sh 18 | ``` 19 | 20 | Optional: Use a virtual environment to setup and run your workloads. This can help with errors 21 | like `This environment is externally managed`: 22 | ```shell 23 | ## One time step of creating the venv 24 | VENV_DIR=~/venvp3 25 | python3 -m venv $VENV_DIR 26 | ## Enter your venv. 
27 | source $VENV_DIR/bin/activate 28 | ## Install dependencies 29 | bash setup.sh 30 | ``` 31 | 32 | > **_NOTE:_** If you use a virtual environment, you must use the same one when running the 33 | [XPK Installation](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#installation) 34 | steps linked in the [XPK_README](XPK_README.md) as well as your relevant tpu-recipe workloads. 35 | 36 | 3. Run the following commands to build the docker image 37 | ```shell 38 | # Example BASE_IMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1 39 | BASE_IMAGE= 40 | bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=${BASE_IMAGE} 41 | ``` 42 | 43 | 4. Upload your docker image to Container Registry 44 | ```shell 45 | bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner 46 | ``` 47 | 48 | 5. Create your GCS bucket 49 | ```shell 50 | OUTPUT_DIR=gs://v6e-demo-run # 51 | gcloud storage buckets create ${OUTPUT_DIR} --project ${PROJECT} 52 | ``` 53 | 54 | 6. Specify your workload configs 55 | ```shell 56 | export PROJECT=# 57 | export ZONE=# 58 | export CLUSTER_NAME=v6e-demo # 59 | export OUTPUT_DIR=gs://v6e-demo/ # 60 | ``` 61 | 62 | # FAQ 63 | 64 | 1. If you see the following error when creating your virtual environment in step 2, install the 65 | required dependency using the output's provided command. You may need to run the command with `sudo`. This 66 | example is for Python3.10. 67 | ``` 68 | The virtual environment was not created successfully because ensurepip is not 69 | available. On Debian/Ubuntu systems, you need to install the python3-venv 70 | package using the following command. 71 | 72 | apt install python3.10-venv 73 | 74 | You may need to use sudo with that command. After installing the python3-venv 75 | package, recreate your virtual environment. 76 | 77 | Failing command: /home/bvandermoon/venvp3/bin/python3 78 | 79 | -bash: /home/bvandermoon/venvp3/bin/activate: No such file or directory 80 | ``` 81 | 82 | 2. If you see an error like the following while building your Docker image, there could be a pip versioning 83 | conflict in your cache. 84 | 85 | ``` 86 | ERROR: THESE PACKAGES DO NOT MATCH THE HASHES FROM THE REQUIREMENTS FILE. If you have updated the 87 | package versions, please update the hashes. Otherwise, examine the package contents carefully; 88 | someone may have tampered with them. 89 | unknown package: 90 | Expected sha256 b3e54983cd51875855da7c68ec05c05cf8bb08df361b1d5b69e05e40b0c9bd62 91 | Got f3b7ea1da59dc4f182437cebc7ef37b847d55c7ebfbc3ba286302f1c89ff5929 92 | ``` 93 | 94 | Try deleting your pip cache file: `rm ~/.cache/pip -rf`. Then retry the Docker build 95 | -------------------------------------------------------------------------------- /training/trillium/Mistral-7B-MaxText/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Mistral-7B-MaxText on TPU trillium (v6e-8) 2 | 3 | ## XPK setup 4 | Please follow the [XPK_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK 5 | 6 | ## Prep for Maxtext 7 | 8 | ### Install MaxText and Build Docker Image 9 | Please follow the [MAXTEXT_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXTEXT_README.md) to install maxtext and build the docker image. 
The following variables should be set: 10 | 11 | In step 1, use the MaxText [tpu-recipes-v0.1.2](https://github.com/AI-Hypercomputer/maxtext/releases/tag/tpu-recipes-v0.1.2) tag to run this recipe: 12 | ``` 13 | git checkout tpu-recipes-v0.1.2 14 | ``` 15 | 16 | In step 3, use the jax-stable-stack image containing JAX 0.5.2: 17 | ``` 18 | BASE_IMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1 19 | bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=${BASE_IMAGE} 20 | ``` 21 | 22 | ## Run Maxtext Mistral-7B workloads on GKE 23 | 24 | ### Starting workload 25 | 26 | From the MaxText root directory, start your Mistral-7B workload. 27 | ``` 28 | python3 -m benchmarks.benchmark_runner xpk \ 29 | --project=$PROJECT \ 30 | --zone=$ZONE \ 31 | --device_type=v6e-8 \ 32 | --num_slices=1 \ 33 | --cluster_name=${CLUSTER_NAME} \ 34 | --base_output_directory=${OUTPUT_DIR} \ 35 | --model_name="mistral_7b" \ 36 | --base_docker_image=maxtext_base_image 37 | ``` 38 | 39 | From your workload logs, you should start seeing step time logs like the following: 40 | ``` 41 | completed step: 6, seconds: 6.320, TFLOP/s/device: 431.981, Tokens/s/device: 7776.813, total_weights: 393216, loss: 7.378 42 | ``` 43 | If you would like to run on multiple slices of v6e-8, you may modify the `--num_slices` flag. 44 | 45 | ### Workload Details 46 | 47 | For reference, here are the `mistral_7b` workload details as found in `MaxText@tpu-recipes-v0.1.2`: 48 | 49 | ``` 50 | MaxTextModel( 51 | model_name="mistral-7b", 52 | model_type="mistral-7b", 53 | tuning_params={ 54 | "per_device_batch_size": 6, 55 | "ici_fsdp_parallelism": -1, 56 | "remat_policy": "custom", 57 | "decoder_layer_input": "offload", 58 | "out_proj": "offload", 59 | "query_proj": "offload", 60 | "key_proj": "offload", 61 | "value_proj": "offload", 62 | "max_target_length": 8192, 63 | "attention": "flash", 64 | "use_iota_embed": True, 65 | "dataset_path": "gs://max-datasets-rogue", 66 | "dataset_type": "synthetic", 67 | "enable_checkpointing": False, 68 | "sa_block_q": 2048, 69 | "sa_block_kv": 2048, 70 | "sa_block_kv_compute": 2048, 71 | "sa_block_q_dkv": 2048, 72 | "sa_block_kv_dkv": 2048, 73 | "sa_block_kv_dkv_compute": 2048, 74 | "sa_block_q_dq": 2048, 75 | "sa_block_kv_dq": 2048, 76 | "sa_use_fused_bwd_kernel": True, 77 | "profiler": "xplane", 78 | "skip_first_n_steps_for_profiler": 10, 79 | "profiler_steps": 5, 80 | }, 81 | xla_flags=( 82 | xla_flags_library.DENSE_VMEM_LIMIT_FLAG 83 | + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER 84 | + xla_flags_library.DATA_PARALLEL_OVERLAP 85 | + xla_flags_library.CF_FOR_ALL_GATHER 86 | + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE 87 | + xla_flags_library.HOST_OFFLOAD_FLAGS 88 | + xla_flags_library.DISABLE_COLLECTIVE_MATMUL 89 | ), 90 | ) 91 | ``` 92 | 93 | This equivalent workload code can be found in the [maxtext_trillium_model_configs.py](https://github.com/AI-Hypercomputer/maxtext/blob/tpu-recipes-v0.1.2/benchmarks/maxtext_trillium_model_configs.py) file within the MaxText repository. 94 | -------------------------------------------------------------------------------- /training/trillium/Mistral-7B-MaxText/mistral-7B-1xv6e-8.sh: -------------------------------------------------------------------------------- 1 | # Run this command from the MaxText root directory using the setup described in the README. 
2 | python3 -m benchmarks.benchmark_runner xpk \ 3 | --project=$PROJECT \ 4 | --zone=$ZONE \ 5 | --device_type=v6e-8 \ 6 | --num_slices=1 \ 7 | --cluster_name=${CLUSTER_NAME} \ 8 | --base_output_directory=${OUTPUT_DIR} \ 9 | --model_name="mistral_7b" \ 10 | --base_docker_image=maxtext_base_image 11 | -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x22B-MaxText/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Mixtral-8x22B-MaxText on TPU trillium 2 | 3 | ## XPK setup 4 | Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK 5 | 6 | ## Prep for Maxtext 7 | Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXTEXT_README.md) to install MaxText and build the docker image 8 | 9 | ## Run Maxtext Mixtral-8x22B workloads on GKE 10 | 11 | ### Test Env 12 | jaxlib=0.4.35 13 | 14 | libtpu-nightly=20241119 15 | 16 | [maxtext](https://github.com/AI-Hypercomputer/maxtext.git)@261a8be0fc5e909ef9da0521df62549e650ebb79 17 | 18 | ### Starting workload 19 | 20 | From the MaxText root directory, start your Mixtral workload. 21 | 22 | BF16 run: 23 | ``` 24 | python3 benchmarks/benchmark_runner.py --project=${PROJECT} --zone=${ZONE} --device_type=v6e-256 --num_slices=1 --cluster_name=${CLUSTER_NAME} --base_output_directory=${OUTPUT_DIR} \ 25 | --model_name="mixtral_8x22b_dropped" --libtpu_version=20241119 --base_docker_image=maxtext_base_image 26 | ``` 27 | 28 | Note: After commit `f64c51a2d8c115e98b6c4d24d90b546e5f0f826e`, use the xpk flag when running the benchmark script. For example: `python3 benchmarks/benchmark_runner.py xpk --project=${PROJECT} ...`.
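Putting that note together with the command above, a sketch of the BF16 run in the post-commit form (only the `xpk` argument is added):
```
python3 benchmarks/benchmark_runner.py xpk --project=${PROJECT} --zone=${ZONE} --device_type=v6e-256 --num_slices=1 --cluster_name=${CLUSTER_NAME} --base_output_directory=${OUTPUT_DIR} \
    --model_name="mixtral_8x22b_dropped" --libtpu_version=20241119 --base_docker_image=maxtext_base_image
```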
29 | 30 | From your workload logs, you should start seeing step time logs like the following: 31 | ``` 32 | completed step: 9, seconds: 24.706, TFLOP/s/device: 332.463, Tokens/s/device: 1326.307, total_weights: 8388608, loss: 0.045 33 | ``` -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x22B-MaxText/mixtral-8x22b-10xv6e-256.sh: -------------------------------------------------------------------------------- 1 | python3 benchmarks/benchmark_runner.py --project=$PROJECT --zone=$ZONE --device_type=v6e-256 --num_slices=10 --cluster_name=${CLUSTER_NAME} --base_output_directory=${OUTPUT_DIR} \ 2 | --model_name="mixtral_8x22b_dropped" --libtpu_version=20241119 --base_docker_image maxtext_base_image 3 | -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x22B-MaxText/mixtral-8x22b-1xv6e-256.sh: -------------------------------------------------------------------------------- 1 | python3 benchmarks/benchmark_runner.py --project=$PROJECT --zone=$ZONE --device_type=v6e-256 --num_slices=1 --cluster_name=${CLUSTER_NAME} --base_output_directory=${OUTPUT_DIR} \ 2 | --model_name="mixtral_8x22b_dropped" --libtpu_version=20241119 --base_docker_image maxtext_base_image 3 | -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x22B-MaxText/mixtral-8x22b-20xv6e-256.sh: -------------------------------------------------------------------------------- 1 | python3 benchmarks/benchmark_runner.py --project=$PROJECT --zone=$ZONE --device_type=v6e-256 --num_slices=20 --cluster_name=${CLUSTER_NAME} --base_output_directory=${OUTPUT_DIR} \ 2 | --model_name="mixtral_8x22b_dropped" --libtpu_version=20241119 --base_docker_image maxtext_base_image 3 | -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x22B-MaxText/mixtral-8x22b-30xv6e-256.sh: -------------------------------------------------------------------------------- 1 | python3 benchmarks/benchmark_runner.py --project=$PROJECT --zone=$ZONE --device_type=v6e-256 --num_slices=30 --cluster_name=${CLUSTER_NAME} --base_output_directory=${OUTPUT_DIR} \ 2 | --model_name="mixtral_8x22b_dropped" --libtpu_version=20241119 --base_docker_image maxtext_base_image 3 | -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x22B-MaxText/mixtral-8x22b-40xv6e-256.sh: -------------------------------------------------------------------------------- 1 | python3 benchmarks/benchmark_runner.py --project=$PROJECT --zone=$ZONE --device_type=v6e-256 --num_slices=40 --cluster_name=${CLUSTER_NAME} --base_output_directory=${OUTPUT_DIR} \ 2 | --model_name="mixtral_8x22b_dropped" --libtpu_version=20241119 --base_docker_image maxtext_base_image 3 | -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x7B-MaxText/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Mixtral-8x7B-MaxText on TPU trillium 2 | 3 | ## XPK setup 4 | Please follow the [XPK_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK 5 | 6 | ## Prep for Maxtext 7 | 8 | ### Install MaxText and Build Docker Image 9 | Please follow the [MAXTEXT_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXTEXT_README.md) to 
install maxtext and build the docker image. The following variables should be set: 10 | 11 | In step 1, use the MaxText [tpu-recipes-v0.1.2](https://github.com/AI-Hypercomputer/maxtext/releases/tag/tpu-recipes-v0.1.2) tag to run this recipe: 12 | ``` 13 | git checkout tpu-recipes-v0.1.2 14 | ``` 15 | 16 | In step 3, use the jax-stable-stack image containing JAX 0.5.2: 17 | ``` 18 | BASE_IMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1 19 | bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=${BASE_IMAGE} 20 | ``` 21 | 22 | ## Run Maxtext Mixtral-8x7B workloads on GKE 23 | 24 | ### Starting workload 25 | 26 | From the MaxText root directory, start your Mixtral workload. 27 | 28 | ``` 29 | python3 -m benchmarks.benchmark_runner xpk \ 30 | --project=${PROJECT} \ 31 | --zone=${ZONE} \ 32 | --device_type=v6e-256 \ 33 | --num_slices=1 \ 34 | --cluster_name=${CLUSTER_NAME} \ 35 | --base_output_directory=${OUTPUT_DIR} \ 36 | --model_name="mixtral_8x7b_dropped" \ 37 | --base_docker_image=maxtext_base_image 38 | ``` 39 | 40 | From your workload logs, you should start seeing step time logs like the following: 41 | ``` 42 | completed step: 11, seconds: 13.484, TFLOP/s/device: 302.311, Tokens/s/device: 3645.203, total_weights: 12582912, loss: 10.546 43 | ``` 44 | 45 | ### Workload Details 46 | 47 | For reference, here are the `mixtral_8x7b_dropped` workload details as found in `MaxText@tpu-recipes-v0.1.2`: 48 | 49 | ``` 50 | MaxTextModel( 51 | model_name="mixtral_8x7b_dropped", 52 | model_type="mixtral-8x7b", 53 | tuning_params={ 54 | "per_device_batch_size": 12, 55 | "ici_fsdp_parallelism": -1, 56 | "max_target_length": 4096, 57 | "remat_policy": "custom", 58 | "decoder_layer_input": "offload", 59 | "out_proj": "offload", 60 | "query_proj": "offload", 61 | "key_proj": "offload", 62 | "value_proj": "offload", 63 | "attention": "flash", 64 | "gcs_metrics": True, 65 | "use_iota_embed": True, 66 | "dataset_path": "gs://max-datasets-rogue", 67 | "dataset_type": "synthetic", 68 | "reuse_example_batch": 1, 69 | "enable_checkpointing": False, 70 | "profiler": "xplane", 71 | "sa_block_q": 2048, 72 | "sa_block_q_dkv": 2048, 73 | "sa_block_q_dq": 2048, 74 | "megablox": False, 75 | "sparse_matmul": False, 76 | "capacity_factor": 1.25, 77 | "tokenizer_path": "assets/tokenizer.mistral-v1", 78 | }, 79 | xla_flags=( 80 | xla_flags_library.MOE_VMEM_LIMIT_FLAG 81 | + xla_flags_library.CF_FOR_ALL_GATHER 82 | + xla_flags_library.DATA_PARALLEL_OVERLAP 83 | ), 84 | ) 85 | ``` 86 | 87 | This equivalent workload code can be found in the [maxtext_trillium_model_configs.py](https://github.com/AI-Hypercomputer/maxtext/blob/tpu-recipes-v0.1.2/benchmarks/maxtext_trillium_model_configs.py) file within the MaxText repository. 88 | -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x7B-MaxText/mixtral-8x7b-1xv6e-256.sh: -------------------------------------------------------------------------------- 1 | # Run this command from the MaxText root directory using the setup described in the README. 
2 | python3 -m benchmarks.benchmark_runner xpk \ 3 | --project=${PROJECT} \ 4 | --zone=${ZONE} \ 5 | --device_type=v6e-256 \ 6 | --num_slices=1 \ 7 | --cluster_name=${CLUSTER_NAME} \ 8 | --base_output_directory=${OUTPUT_DIR} \ 9 | --model_name="mixtral_8x7b_dropped" \ 10 | --base_docker_image=maxtext_base_image 11 | -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x7B-MaxText/mixtral-8x7b-2xv6e-256.sh: -------------------------------------------------------------------------------- 1 | # Run this command from the MaxText root directory using the setup described in the README. 2 | python3 -m benchmarks.benchmark_runner xpk \ 3 | --project=${PROJECT} \ 4 | --zone=${ZONE} \ 5 | --device_type=v6e-256 \ 6 | --num_slices=2 \ 7 | --cluster_name=${CLUSTER_NAME} \ 8 | --base_output_directory=${OUTPUT_DIR} \ 9 | --model_name="mixtral_8x7b_dropped" \ 10 | --base_docker_image=maxtext_base_image 11 | -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x7B-MaxText/mixtral-8x7b-4xv6e-256.sh: -------------------------------------------------------------------------------- 1 | # Run this command from the MaxText root directory using the setup described in the README. 2 | python3 -m benchmarks.benchmark_runner xpk \ 3 | --project=${PROJECT} \ 4 | --zone=${ZONE} \ 5 | --device_type=v6e-256 \ 6 | --num_slices=4 \ 7 | --cluster_name=${CLUSTER_NAME} \ 8 | --base_output_directory=${OUTPUT_DIR} \ 9 | --model_name="mixtral_8x7b_dropped" \ 10 | --base_docker_image=maxtext_base_image 11 | -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x7B-Pytorch/GCE/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Mixtral-8x7B on Trillium (v6e) TPU 2 | 3 | 4 | This user guide provides a concise overview of the essential steps required to run Hugging Face (HF) Mixtral training on Cloud TPUs. 5 | 6 | 7 | ## Environment Setup 8 | 9 | Please follow the corresponding TPU generation's user guide to set up the GCE TPUs 10 | first. 11 | 12 | Please replace all `your-*` placeholders with your TPU's information. 13 | 14 | ``` 15 | export TPU_NAME=your-tpu-name 16 | export ZONE=your-tpu-zone 17 | export PROJECT=your-tpu-project 18 | ``` 19 | 20 | You may use this command to create a 256 chip v6e slice: 21 | 22 | ``` 23 | gcloud alpha compute tpus tpu-vm create $TPU_NAME \ 24 | --accelerator-type v6e-256 --project $PROJECT --zone $ZONE \ 25 | --version v2-alpha-tpuv6e 26 | ``` 27 | 28 | ## Steps to Run HF Mixtral 8x7B 29 | 30 | The following setup runs the training job with Mixtral 8x7B on GCE TPUs using the docker image from this registry (``). The docker image uses the PyTorch and torch_xla nightly builds from 10/28/2024 and comes with all the package dependencies needed to run the model training. All the commands below should run from your own machine (not the TPU host you created). 31 | 32 | 1. Clone the repo and navigate to this README's directory: 33 | ```bash 34 | git clone https://github.com/AI-Hypercomputer/tpu-recipes.git 35 | cd training/trillium/Mixtral-8x7B-PyTorch 36 | ``` 37 | 2. Edit `env.sh` to add the Hugging Face token and/or set up the training parameters. 38 | ```bash 39 | # add your hugging face token 40 | HF_TOKEN=hf_*** 41 | ``` 42 | 3. Edit `host.sh` to change the docker image URL if the default docker image is not accessible to you.
43 | ```bash 44 | # docker image URL to use for the training 45 | DOCKER_IMAGE=us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-mixtral:dropping 46 | ``` 47 | 4. Run the training script: 48 | ```bash 49 | ./benchmark.sh 50 | ``` 51 | `benchmark.sh` script will upload 1) environment parameters in `env.sh`, 2) model related config in `config.json`, `fsdp_config.json`, 3) docker launch script in `host.sh` and 4) python training command in `train.sh` into all TPU workers, and starts the training afterwards. When all training steps complete, it will print out training metrics of each worker as below in terminal: 52 | ``` 53 | ***** train metrics ***** 54 | [worker :3] ***** train metrics ***** 55 | [worker :3] epoch = 56 | [worker :3] total_flos = 57 | [worker :3] train_loss = 58 | [worker :3] train_runtime = 59 | [worker :3] train_samples = 60 | [worker :3] train_samples_per_second = 61 | ``` 62 | In addition, it will copy back the trained model under `output/*`. -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x7B-Pytorch/GCE/benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # SCP the environment setup to all instances. 4 | gcloud compute tpus tpu-vm scp config.json fsdp_config.json train.sh host.sh env.sh "$TPU_NAME:~" --worker=all --project $PROJECT --zone=$ZONE 5 | 6 | # Actually runs the benchmark. 7 | gcloud compute tpus tpu-vm ssh $TPU_NAME --project $PROJECT --zone=$ZONE --worker=all --command="$(cat host.sh)" 8 | 9 | # Copy the profile and output back 10 | gcloud compute tpus tpu-vm scp --recurse $TPU_NAME:~/output ./ --project=$PROJECT --zone=$ZONE -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x7B-Pytorch/GCE/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "MixtralForCausalLM" 4 | ], 5 | "attention_dropout": 0.0, 6 | "bos_token_id": 1, 7 | "eos_token_id": 2, 8 | "hidden_act": "silu", 9 | "hidden_size": 4096, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 14336, 12 | "max_position_embeddings": 32768, 13 | "model_type": "mixtral", 14 | "num_attention_heads": 32, 15 | "num_experts_per_tok": 2, 16 | "num_hidden_layers": 32, 17 | "num_key_value_heads": 8, 18 | "capacity_factor": 1.25, 19 | "num_local_experts": 8, 20 | "output_router_logits": false, 21 | "rms_norm_eps": 1e-05, 22 | "rope_theta": 1000000.0, 23 | "router_aux_loss_coef": 0.02, 24 | "sliding_window": null, 25 | "tie_word_embeddings": false, 26 | "torch_dtype": "bfloat16", 27 | "transformers_version": "4.36.0.dev0", 28 | "use_cache": false, 29 | "vocab_size": 32000 30 | } -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x7B-Pytorch/GCE/env.sh: -------------------------------------------------------------------------------- 1 | # Uncomment below to set the Huggingface token 2 | # HF_TOKEN=hf_*** 3 | PJRT_DEVICE=TPU 4 | XLA_IR_DEBUG=1 5 | XLA_HLO_DEBUG=1 6 | PROFILE_EPOCH=0 7 | PROFILE_STEP=3 8 | PROFILE_DURATION_MS=120000 9 | XLA_USE_SPMD=1 10 | MAX_STEPS=20 11 | SEQ_LENGTH=4096 12 | 13 | GLOBAL_BATCH_SIZE=2048 14 | 15 | # XLA flags 16 | LIBTPU_INIT_ARGS=--xla_tpu_enable_flash_attention=false --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true 
--xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_scoped_vmem_limit_kib=81920 -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x7B-Pytorch/GCE/fsdp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fsdp_transformer_layer_cls_to_wrap": [ 3 | "MixtralDecoderLayer" 4 | ], 5 | "xla": true, 6 | "xla_fsdp_v2": true, 7 | "xla_fsdp_grad_ckpt": true 8 | } -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x7B-Pytorch/GCE/host.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DOCKER_IMAGE=us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-mixtral:dropping 4 | 5 | worker_id=$(curl -s "http://metadata.google.internal/computeMetadata/v1/instance/attributes/agent-worker-number" -H 'Metadata-Flavor: Google') 6 | 7 | cat >> /dev/null <&1 | sed "s/^/[worker $worker_id] /g" | tee runlog 11 | set -o xtrace 12 | # Configure docker 13 | sudo groupadd docker 14 | sudo usermod -aG docker $USER 15 | # newgrp applies updated group permissions 16 | newgrp - docker 17 | gcloud auth configure-docker us-central1-docker.pkg.dev --quiet 18 | # Kill any running benchmarks 19 | docker kill $USER-test 20 | docker pull $DOCKER_IMAGE 21 | docker run --rm \ 22 | --name $USER-test \ 23 | --privileged \ 24 | --env-file env.sh \ 25 | -v /home/$USER:/tmp/home \ 26 | --shm-size=16G \ 27 | --net host \ 28 | -u root \ 29 | --entrypoint /bin/bash $DOCKER_IMAGE \ 30 | /tmp/home/train.sh 31 | 32 | PIPE_EOF -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x7B-Pytorch/GCE/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Remove existing repo and old data. 
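# Note: /tmp/home inside the container is the worker's home directory, mounted by host.sh (-v /home/$USER:/tmp/home).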
3 | LOCAL_DIR=/tmp/home/ 4 | rm -rf "${LOCAL_DIR}/output" 5 | rm -rf "${LOCAL_DIR}/plugins" 6 | rm -rf "${LOCAL_DIR}/cache" 7 | mkdir -p "${LOCAL_DIR}/output" 8 | mkdir -p "${LOCAL_DIR}/plugins" 9 | mkdir -p "${LOCAL_DIR}/cache" 10 | 11 | unset LD_PRELOAD 12 | 13 | 14 | cd transformers/ 15 | 16 | # The flag --static uses the dropping implementation for MoE 17 | python3 examples/pytorch/language-modeling/run_clm.py \ 18 | --dataset_name wikitext \ 19 | --dataset_config_name wikitext-103-raw-v1 \ 20 | --per_device_train_batch_size "${GLOBAL_BATCH_SIZE}" \ 21 | --do_train \ 22 | --output_dir "${LOCAL_DIR}/output/test-clm" \ 23 | --overwrite_output_dir \ 24 | --config_name "${LOCAL_DIR}/config.json" \ 25 | --cache_dir "${LOCAL_DIR}/cache" \ 26 | --tokenizer_name mistralai/Mixtral-8x7B-v0.1 \ 27 | --block_size "$SEQ_LENGTH" \ 28 | --optim adafactor \ 29 | --save_strategy no \ 30 | --logging_strategy no \ 31 | --fsdp "full_shard" \ 32 | --fsdp_config "${LOCAL_DIR}/fsdp_config.json" \ 33 | --torch_dtype bfloat16 \ 34 | --dataloader_drop_last yes \ 35 | --flash_attention \ 36 | --num_train_epochs 1 \ 37 | --max_steps "$MAX_STEPS" \ 38 | --static 39 | -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x7B-Pytorch/XPK/benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source env.sh 4 | 5 | python3 xpk.py workload create \ 6 | --cluster ${CLUSTER_NAME} \ 7 | --base-docker-image=${BASE_DOCKER_IMAGE} \ 8 | --workload=${WORKLOAD_NAME} \ 9 | --tpu-type=${TPU_TYPE} \ 10 | --num-slices=${NUM_SLICE} \ 11 | --on-demand \ 12 | --zone=$ZONE \ 13 | --project=$PROJECT \ 14 | --enable-debug-logs \ 15 | --command="bash /app/train.sh" -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x7B-Pytorch/XPK/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "MixtralForCausalLM" 4 | ], 5 | "attention_dropout": 0.0, 6 | "bos_token_id": 1, 7 | "eos_token_id": 2, 8 | "hidden_act": "silu", 9 | "hidden_size": 4096, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 14336, 12 | "max_position_embeddings": 32768, 13 | "model_type": "mixtral", 14 | "num_attention_heads": 32, 15 | "num_experts_per_tok": 2, 16 | "num_hidden_layers": 32, 17 | "num_key_value_heads": 8, 18 | "capacity_factor": 1.25, 19 | "num_local_experts": 8, 20 | "output_router_logits": false, 21 | "rms_norm_eps": 1e-05, 22 | "rope_theta": 1000000.0, 23 | "router_aux_loss_coef": 0.02, 24 | "sliding_window": null, 25 | "tie_word_embeddings": false, 26 | "torch_dtype": "bfloat16", 27 | "transformers_version": "4.36.0.dev0", 28 | "use_cache": false, 29 | "vocab_size": 32000 30 | } -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x7B-Pytorch/XPK/env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Environment variables associated with XPK on GCP. 4 | export ZONE=... 5 | export PROJECT=... 6 | export TPU_TYPE=v6e-256 7 | export NUM_SLICE=1 8 | export CLUSTER_NAME=xpk-$USER-... # use existing CLUSTER if you have 9 | 10 | # Environment variables associated with training config. 11 | export BATCH_PER_DEVICE=8 12 | export SEQUENCE_LENGTH=4096 13 | export MAX_STEP=50 14 | export WORKLOAD_NAME=${USER}-xpk-${TPU_TYPE}-... # Your workload name. Need to update for different run. 
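# Docker image consumed by benchmark.sh, plus the profile destination and Hugging Face token read by train.sh.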
15 | export BASE_DOCKER_IMAGE=us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-mixtral:dropping 16 | export PROFILE_LOG_DIR=... # GCS bucket to store profile in form of gs://... 17 | export HF_TOKEN=... # Add your own Hugging face token to download model 18 | -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x7B-Pytorch/XPK/fsdp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fsdp_transformer_layer_cls_to_wrap": [ 3 | "MixtralDecoderLayer" 4 | ], 5 | "xla": true, 6 | "xla_fsdp_v2": true, 7 | "xla_fsdp_grad_ckpt": true 8 | } 9 | -------------------------------------------------------------------------------- /training/trillium/Mixtral-8x7B-Pytorch/XPK/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source /app/env.sh 4 | 5 | # Extract the number after '-' in TPU_TYPE 6 | TPU_NUM=$(echo "$TPU_TYPE" | grep -oP '(?<=-)\d+') 7 | 8 | # Calculate GLOBAL_BATCH_SIZE 9 | GLOBAL_BATCH_SIZE=$(( TPU_NUM * BATCH_PER_DEVICE * NUM_SLICE )) 10 | 11 | export GLOBAL_BATCH_SIZE 12 | 13 | echo "GLOBAL_BATCH_SIZE=$GLOBAL_BATCH_SIZE" 14 | 15 | export PJRT_DEVICE=TPU 16 | export XLA_USE_SPMD=1 17 | export ENABLE_PJRT_COMPATIBILITY=true 18 | export XLA_IR_DEBUG=1 19 | export XLA_HLO_DEBUG=1 20 | export PROFILE_EPOCH=0 21 | export PROFILE_STEP=3 22 | export PROFILE_DURATION_MS=100000 23 | export PROFILE_LOGDIR=${PROFILE_LOG_DIR} 24 | export XLA_PERSISTENT_CACHE_PATH=/app/xla_cache/ 25 | export NUM_TPU_SLICE=${NUM_SLICE} 26 | export LIBTPU_INIT_ARGS="--xla_tpu_enable_flash_attention=false --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_scoped_vmem_limit_kib=81920" 27 | 28 | huggingface-cli login --token=${HF_TOKEN} 29 | 30 | # Note --per_device_train_batch_size is the global batch size since we overwrite the dataloader in the HF trainer. 31 | # --static uses the dropping implementation for MoE. 32 | python3 /workspaces/transformers/examples/pytorch/language-modeling/run_clm.py \ 33 | --dataset_name=wikitext --dataset_config_name=wikitext-103-raw-v1 \ 34 | --per_device_train_batch_size=${GLOBAL_BATCH_SIZE} --do_train --output_dir=test-clm \ 35 | --overwrite_output_dir --config_name=/app/config.json \ 36 | --cache_dir=cache --tokenizer_name=mistralai/Mixtral-8x7B-v0.1 \ 37 | --block_size=${SEQUENCE_LENGTH} --optim=adafactor --save_strategy=no \ 38 | --logging_strategy=no --fsdp="full_shard" \ 39 | --fsdp_config=/app/fsdp_config.json --torch_dtype=bfloat16 \ 40 | --dataloader_drop_last=yes --max_steps=${MAX_STEP} --flash_attention --static -------------------------------------------------------------------------------- /training/trillium/XPK_README.md: -------------------------------------------------------------------------------- 1 | ## Initialization 2 | 3 | > **_NOTE:_** We recommend running these instructions and kicking off your recipe 4 | workloads from a VM in GCP using Python 3.10. 5 | 6 | 1. Run the following commands to initialize the project and zone. 7 | ```shell 8 | export PROJECT=# 9 | export ZONE=# 10 | gcloud config set project $PROJECT 11 | gcloud config set compute/zone $ZONE 12 | ``` 13 | 14 | 2. 
Install XPK by following the [prerequisites](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#prerequisites) and [installation](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#installation) 15 | instructions. Also ensure you have the proper [GCP permissions](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#installation). 16 | 17 | * In order to run the tpu-recipes as-is, run the `git clone` command from your home (~/) directory: 18 | ```shell 19 | # tpu-recipes requiring XPK will look for it in the home directory 20 | cd ~/ 21 | git clone https://github.com/google/xpk.git 22 | ``` 23 | 24 | 3. Run the rest of these commands from the cloned XPK directory: 25 | 26 | ```shell 27 | cd xpk # Should be equivalent to cd ~/xpk 28 | ``` 29 | 30 | > **_NOTE:_** If you use a virtual environment in the 31 | [XPK Installation](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#installation) 32 | steps, you must use the same one to run the steps in the [MAXTEXT_README](MAXTEXT_README.md) 33 | as well as your relevant tpu-recipe workloads. 34 | 35 | ## GKE Cluster Creation 36 | Trillium GKE clusters can be [created](https://cloud.google.com/tpu/docs/v6e-intro#create_an_xpk_cluster_with_multi-nic_support) and 37 | [deleted](https://cloud.google.com/tpu/docs/v6e-intro#delete_xpk_cluster) by following the public GCP documentation. 38 | 39 | > **_NOTE:_** in order to run the training and microbenchmarks tpu-recipes, you should not need to run sections outside of 40 | `Create an XPK cluster with multi-NIC support` when creating your cluster. You can skip the following sections like `Framework setup`. -------------------------------------------------------------------------------- /training/v5p/Diffusion-2-MaxDiffusion/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Stable Diffusion 2 on TPU v5p 2 | 3 | This documents present steps to run StableDiffusion [MaxDiffusion](https://github.com/google/maxdiffusion/tree/main/src/maxdiffusion) workload through [XPK](https://github.com/google/xpk/blob/main/README.md) tool. 4 | 5 | Setup XPK and create cluster [XPK Userguide](Training/TPU-v5p/XPK_README.md) 6 | 7 | Build a local docker image. 8 | 9 | ``` 10 | LOCAL_IMAGE_NAME=maxdiffusion_base_image 11 | docker build --no-cache --network host -f ./maxdiffusion.Dockerfile -t ${LOCAL_IMAGE_NAME} . 12 | ``` 13 | 14 | Run workload using xpk. 15 | 16 | ``` 17 | export BASE_OUTPUT_DIR=gs://output_bucket/ 18 | DATA_DIR=gs://jfacevedo-maxdiffusion/laion400m/raw_data/tf_records_512_encoder_state_fp32 19 | COMMITS=eac9132ef8b1a977372e29720fabc478529cd364 20 | NUM_SLICES=1 21 | 22 | xpk workload create \ 23 | --cluster \ 24 | --base-docker-image maxdiffusion_base_image \ 25 | --workload ${USER}-sd21-v5p \ 26 | --tpu-type= \ 27 | --num-slices=${NUM_SLICES} \ 28 | --command "bash run_v5p-ddp-pbs-16.sh DATA_DIR=${DATA_DIR} BASE_OUTPUT_DIR=${BASE_OUTPUT_DIR} COMMITS=${COMMITS} NUM_SLICES=${NUM_SLICES} " 29 | ``` 30 | 31 | MFU Calculation. 32 | 33 | Above only Unet is trainable modeule, from FLOPS count, Per Step FLOPS = 2.41G FLOPS @BS=1, we get the MFU 34 | ``` 35 | MFU = Per Step FLOPS * BatchSize Per Device / Step Time / Per Device Peak FLOPS 36 | ``` -------------------------------------------------------------------------------- /training/v5p/Diffusion-2-MaxDiffusion/docker/maxdiffusion.Dockerfile: -------------------------------------------------------------------------------- 1 | # Install ip. 
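# Base image: Debian bullseye with Python 3.10; the layers below add networking tools, the Google Cloud SDK, JAX nightlies, and MaxDiffusion (mlperf_4 branch).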
2 | FROM python:3.10-slim-bullseye 3 | RUN apt-get update 4 | RUN apt-get install -y curl procps gnupg git 5 | RUN apt-get install -y net-tools ethtool iproute2 6 | 7 | # Add the Google Cloud SDK package repository 8 | RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list 9 | RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - 10 | 11 | # Install the Google Cloud SDK 12 | RUN apt-get update && apt-get install -y google-cloud-sdk 13 | 14 | # Set the default Python version to 3.10 15 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.10 1 16 | 17 | # Set environment variables for Google Cloud SDK and Python 3.10 18 | ENV PATH="/usr/local/google-cloud-sdk/bin:/usr/local/bin/python3.10:${PATH}" 19 | 20 | RUN pip install --no-cache-dir jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 21 | RUN pip install --no-cache-dir -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html 22 | RUN pip install git+https://github.com/google/jax 23 | RUN pip uninstall jaxlib -y 24 | RUN pip install -U --pre jax[tpu] --no-cache-dir -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 25 | RUN pip install git+https://github.com/mlperf/logging.git 26 | 27 | RUN git clone -b mlperf_4 https://github.com/google/maxdiffusion.git 28 | 29 | WORKDIR maxdiffusion 30 | RUN pip install -r requirements.txt 31 | 32 | 33 | RUN pip install . -------------------------------------------------------------------------------- /training/v5p/Diffusion-2-MaxDiffusion/scripts/run_v5p-ddp-pbs-16.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=$1 2 | BASE_OUTPUT_DIR=$2 3 | COMMITS=$3 4 | 5 | # Set environment variables 6 | for ARGUMENT in "$@"; do 7 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT" 8 | export "$KEY"="$VALUE" 9 | done 10 | 11 | export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_megacore_fusion=false --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true' 12 | 13 | LIBTPU_INIT_ARGS+=' --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_enable_async_all_reduce=true' 14 | LIBTPU_INIT_ARGS+=' --xla_tpu_enable_async_collective_fusion_with_mosaic_custom_call=true --xla_tpu_mosaic_fusion=true' 15 | LIBTPU_INIT_ARGS+=' --xla_enable_async_reduce_scatter_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_reduce_scatter=true' 16 | LIBTPU_INIT_ARGS+=' --xla_tpu_spmd_threshold_for_allgather_cse=1000000 --xla_jf_spmd_threshold_for_windowed_einsum_mib=1000000' 17 | 18 | #reload code to specific commits 19 | cd maxdiffusion 20 | git checkout ${COMMITS} 21 | pip install . 
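# Launch Stable Diffusion 2 base training; per_device_batch_size=16 matches the pbs-16 recipe name and dcn_data_parallelism=${NUM_SLICES} spreads data parallelism across slices.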
22 | 23 | python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml run_name=sd_base2 base_output_directory=${BASE_OUTPUT_DIR} \ 24 | train_data_dir=${DATA_DIR} per_device_batch_size=16 split_head_dim=True attention=flash train_new_unet=true norm_num_groups=16 \ 25 | dcn_data_parallelism=${NUM_SLICES} \ 26 | start_step_to_checkpoint=5120000 enable_profiler=true skip_first_n_steps_for_profiler=5 reuse_example_batch=false max_train_steps=100 -------------------------------------------------------------------------------- /training/v5p/Diffusion-2-PyTorch/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Stable Diffusion 2 on TPU v5p 2 | 3 | 4 | This user guide provides a concise overview of the essential steps required to run StableDiffusion 2.0 base training on Cloud TPUs. 5 | 6 | 7 | ## Environment Setup 8 | 9 | The following setup assumes to run the training job with StableDiffusion 2.0 base on GCE TPUs using the docker image from this registery (us-central1-docker.pkg.dev/tpu-pytorch/docker/development/pytorch-tpu-diffusers:v2), the docker image uses the pytorch and torch_xla nightly build from 09/05 and has all the package dependency installed. It cloned the git repo from [https://github.com/pytorch-tpu/diffusers (commit f08dc9)](https://github.com/pytorch-tpu/diffusers/tree/f08dc92db9d7fd7d8d8ad4efcdfee675e2cd26f2) in order to run hugging face stable diffusion on TPU. Please follow corresponding TPU generation's user guide to setup the GCE TPUs first. All the command below should run from your own machine (not the TPU host you created). 10 | 11 | ### Setup Environment of Your TPUs 12 | Please replace all your-* with your TPUs' information. 13 | ``` 14 | export TPU_NAME=your-tpu-name 15 | export ZONE=your-tpu-zone 16 | export PROJECT=your-tpu-project 17 | ``` 18 | 19 | ### Simple Run Command 20 | git clone and navigate to this README repo and run training script: 21 | ```bash 22 | git clone --depth 1 https://github.com/AI-Hypercomputer/tpu-recipes.git 23 | cd training/v5p/Diffusion-2-PyTorch 24 | bash benchmark.sh 25 | ``` 26 | `benchmark.sh` script will upload 1) environment parameters in `env.sh`, 2) docker launch script in `host.sh` and 3) python training command in `train.sh` into all TPU workers. 27 | 28 | Note that the docker image is specified in `host.sh`. Make sure the docker image is accessible in your GCP project. If not, please download the image first, upload it to your GCP project and change env `$DOCKER_IMAGE` to the registry URL you own. 29 | 30 | When all training steps complete, the benchmark script will print out the average step time. You shall see the performance metric in the terminal like: 31 | ``` 32 | [worker :x] Average step time: ... 33 | ``` 34 | This tells the average step time for each batch run of each worker. In addition, it will copy the profile back to current folder under *profile/* and the trained model in safetensor format under *output/*. Use TensorBoard to open the profile and measure the step time from the "Trace View.". 35 | 36 | 37 | ### Environment Envs Explained 38 | 39 | To make it simple, we suggest only change the following to env variables in env.sh: 40 | * `PER_HOST_BATCH_SIZE`:Batch size for each host/worker. High number can cause out of memory issue. 41 | * `TRAIN_STEPS`: How many training steps to run. (choose more than 10 for this example) 42 | * `PROFILE_DURATION`: Length of the profiling time (unit ms). 
43 | * `RESOLUTION`: Image resolution. 44 | -------------------------------------------------------------------------------- /training/v5p/Diffusion-2-PyTorch/benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # SCP the environment setup to all instances. Used in `--env-file` in `docker run` on the host script. 3 | gcloud compute tpus tpu-vm scp env.sh train.sh $TPU_NAME:~ --worker=all --project $PROJECT --zone=$ZONE 4 | 5 | # Actually runs the benchmark. 6 | gcloud compute tpus tpu-vm ssh $TPU_NAME --project $PROJECT --zone=$ZONE --worker=all --command="$(cat host.sh)" 7 | 8 | # Copy the profile and output back 9 | gcloud compute tpus tpu-vm scp --recurse $TPU_NAME:~/{profile,output} ./ --project=$PROJECT --zone=$ZONE 10 | -------------------------------------------------------------------------------- /training/v5p/Diffusion-2-PyTorch/env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | XLA_DISABLE_FUNCTIONALIZATION=0 3 | PROFILE_DIR=/tmp/home/profile/ 4 | CACHE_DIR=/tmp/home/xla_cache 5 | DATASET_NAME=lambdalabs/naruto-blip-captions 6 | OUTPUT_DIR=/tmp/home/output/ 7 | PROFILE_DURATION=80000 8 | PER_HOST_BATCH_SIZE=256 9 | TRAIN_STEPS=50 10 | RESOLUTION=512 11 | -------------------------------------------------------------------------------- /training/v5p/Diffusion-2-PyTorch/host.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DOCKER_IMAGE="us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-diffusers:v4" 4 | worker_id=$(curl -s "http://metadata.google.internal/computeMetadata/v1/instance/attributes/agent-worker-number" -H 'Metadata-Flavor: Google') 5 | cat >> /dev/null <&1 | sed "s/^/[worker:$worker_id] /g" | tee runlog 9 | set -o xtrace 10 | # Configure docker 11 | sudo groupadd docker 12 | sudo usermod -aG docker $USER 13 | # newgrp applies updated group permissions 14 | newgrp - docker 15 | gcloud auth configure-docker us-central1-docker.pkg.dev --quiet 16 | # Kill any running benchmarks 17 | docker kill $USER-test 18 | docker pull $DOCKER_IMAGE 19 | docker run --rm \ 20 | --name $USER-test \ 21 | --privileged \ 22 | --env-file env.sh \ 23 | -v /home/$USER:/tmp/home \ 24 | --shm-size=16G \ 25 | --net host \ 26 | -u root \ 27 | --entrypoint /bin/bash $DOCKER_IMAGE \ 28 | /tmp/home/train.sh 29 | 30 | PIPE_EOF 31 | -------------------------------------------------------------------------------- /training/v5p/Diffusion-2-PyTorch/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python /workspace/diffusers/examples/text_to_image/train_text_to_image_xla.py \ 4 | --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base \ 5 | --dataset_name=$DATASET_NAME --resolution=$RESOLUTION --center_crop --random_flip \ 6 | --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS \ 7 | --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=$PROFILE_DURATION \ 8 | --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 \ 9 | --loader_prefetch_size=4 --device_prefetch_size=4 --loader_prefetch_factor=4 10 | -------------------------------------------------------------------------------- /training/v5p/GPT3-175B-MaxText/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training GPT3-175B-Maxtext on TPU v5p 2 | 3 | ## XPK setup 4 | Please follow this 
[link](https://github.com/gclouduniverse/reproducibility/tree/main/Training/TPU-v5p/XPK_README.md) to create your GKE cluster with XPK 5 | 6 | ## Prep for Maxtext GPT3-175B workloads on GKE 7 | 1. Clone [Maxtext](https://github.com/AI-Hypercomputer/maxtext) repo and move to its directory 8 | ``` 9 | git clone https://github.com/AI-Hypercomputer/maxtext.git 10 | cd maxtext 11 | ``` 12 | 13 | 2. Run the following commands to build the docker image 14 | ``` 15 | bash docker_build_dependency_image.sh MODE=stable DEVICE=tpu 16 | ``` 17 | 18 | 3. Upload your docker image to Container Registry 19 | ``` 20 | bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner 21 | ``` 22 | 23 | 4. Create your GCS bucket 24 | ``` 25 | GCS_PATH=gs://v5p-demo # 26 | gcloud storage buckets create ${GCS_PATH} --project ${PROJECT} 27 | ``` 28 | 29 | 5. Specify your workload configs 30 | ``` 31 | export CLUSTER_NAME=v5p-demo # 32 | export WORKLOAD_NAME=gpt3-175b-test # 33 | export RUN_NAME=gpt3-175b-run # 34 | export NUM_SLICES=1 # 35 | export LOCAL_IMAGE_NAME=gcr.io/${PROJECT}/${USER}_runner 36 | export OUTPUT_PATH=gs://v5p-demo/ # 37 | ``` 38 | 39 | ## Run Maxtext GPT3-175B workloads on GKE 40 | 41 | ### Configs based on TPU type 42 | 43 | #### v5p-1024 44 | 45 | ``` 46 | export TPU_TYPE=v5p-1024 47 | export SCRIPT=MaxText/configs/v5p/gpt3_175b/v5p_1024.sh 48 | ``` 49 | 50 | #### v5p-2048 51 | 52 | ``` 53 | export TPU_TYPE=v5p-2048 54 | export SCRIPT=MaxText/configs/v5p/gpt3_175b/v5p_2048.sh 55 | ``` 56 | 57 | #### v5p-3072 58 | 59 | ``` 60 | export TPU_TYPE=v5p-3072 61 | export SCRIPT=MaxText/configs/v5p/gpt3_175b/v5p_3072.sh 62 | ``` 63 | 64 | #### v5p-4096 65 | 66 | This will require a custom slice topology of 4x8x64 67 | ``` 68 | export TPU_TYPE=v5p-4096 69 | export SCRIPT=MaxText/configs/v5p/gpt3_175b/v5p_4096.sh 70 | ``` 71 | 72 | #### v5p-8192 73 | 74 | This will require a custom slice topology of 8x16x32 75 | ``` 76 | export TPU_TYPE=v5p-8192 77 | export SCRIPT=MaxText/configs/v5p/gpt3_175b/v5p_8192.sh 78 | ``` 79 | 80 | #### v5p-12288 81 | 82 | This will require a custom slice topology of 8x16x48 83 | ``` 84 | export TPU_TYPE=v5p-12288 85 | export SCRIPT=MaxText/configs/v5p/gpt3_175b/v5p_12288.sh 86 | ``` 87 | 88 | ### Starting workload 89 | 90 | From the MaxText root directory, start your GPT3-175B workload 91 | 92 | ``` 93 | python3 ../xpk.py workload create \ 94 | --project ${PROJECT} \ 95 | --cluster ${CLUSTER_NAME} \ 96 | --workload ${WORKLOAD_NAME} \ 97 | --tpu-type=${TPU_TYPE} \ 98 | --num-slices=1 \ 99 | --base-docker-image=${LOCAL_IMAGE_NAME} \ 100 | --command "bash $SCRIPT $RUN_NAME $OUTPUT_PATH" 101 | ``` 102 | 103 | From your workload logs, you should start seeing step time logs like the following: 104 | ``` 105 | completed step: 2, seconds: 22.197, TFLOP/s/device: 397.246, Tokens/s/device: 369.059, total_weights: 4194304, loss: 0.000 106 | ``` 107 | 108 | [Optional] If you need to delete your workload, you can run the following command: 109 | ``` 110 | cd .. 
# Switch back to the xpk directory 111 | export WORKLOAD_NAME_TO_DELETE=gpt3-175b-test 112 | 113 | python3 xpk.py workload delete \ 114 | --workload ${WORKLOAD_NAME_TO_DELETE} \ 115 | --cluster ${CLUSTER_NAME} 116 | ``` 117 | -------------------------------------------------------------------------------- /training/v5p/Llama2-7B-Maxtext/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Llama2-7B-Maxtext on TPU v5p 2 | 3 | ## XPK setup 4 | Please follow this [link](https://github.com/gclouduniverse/reproducibility/tree/main/Training/TPU-v5p/XPK_README.md) to create your GKE cluster with XPK 5 | 6 | ## Run Maxtext Llama2-7B workloads on GKE 7 | 1. Clone [Maxtext](https://github.com/AI-Hypercomputer/maxtext) repo 8 | ``` 9 | git clone https://github.com/AI-Hypercomputer/maxtext.git 10 | cd maxtext 11 | ``` 12 | 13 | 2. Run the following commands to build the docker image 14 | ``` 15 | bash docker_build_dependency_image.sh MODE=stable DEVICE=tpu 16 | ``` 17 | 18 | 3. Upload your docker image to Container Registry 19 | ``` 20 | bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner 21 | ``` 22 | 23 | 4. Create your GCS bucket 24 | ``` 25 | GCS_PATH=gs://v5p-demo # 26 | gcloud storage buckets create ${GCS_PATH} --project ${PROJECT} 27 | ``` 28 | 29 | 5. Specify your workload configs 30 | ``` 31 | export CLUSTER_NAME=v5p-demo # 32 | export WORKLOAD_NAME=llam2-7b-test # 33 | export RUN_NAME=llama2-7b-run # 34 | export TPU_TYPE=v5p-512 # 35 | export NUM_SLICES=1 # 36 | export LOCAL_IMAGE_NAME=gcr.io/${PROJECT}/${USER}_runner 37 | export OUTPUT_PATH=gs://v5p-demo/ # 38 | ``` 39 | 40 | 6. Switch back to your XPK folder and run Llama2-7B workload 41 | ``` 42 | cd ../ #Make sure you are in the XPK folder 43 | 44 | python3 xpk.py workload create \ 45 | --cluster ${CLUSTER_NAME} \ 46 | --workload ${WORKLOAD_NAME} \ 47 | --tpu-type=${TPU_TYPE} \ 48 | --num-slices=${NUM_SLICES} \ 49 | --docker-image=${LOCAL_IMAGE_NAME} \ 50 | --command "\ 51 | bash MaxText/configs/v5p/llama2_7b.sh RUN_NAME=$RUN_NAME OUTPUT_PATH=$OUTPUT_PATH" 52 | ``` 53 | 54 | 7. [Optional] If you need to delete any of your workload, you can run the following command: 55 | ``` 56 | export WORKLOAD_NAME_TO_DELETE=llam2-7b-test 57 | 58 | python3 xpk.py workload delete \ 59 | --workload ${WORKLOAD_NAME_TO_DELETE} \ 60 | --cluster ${CLUSTER_NAME} 61 | ``` 62 | -------------------------------------------------------------------------------- /training/v5p/Llama2-7B-PyTorch/benchmark.sh: -------------------------------------------------------------------------------- 1 | # SCP the environment setup to all instances. 2 | gcloud compute tpus tpu-vm scp config.json fsdp_config.json train.sh host.sh env.sh $TPU_NAME:~ --worker=all --project $PROJECT --zone=$ZONE 3 | 4 | # Actually runs the benchmark. 
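# host.sh pulls the docker image on every worker and launches train.sh inside the container; TPU_NAME, PROJECT, and ZONE must be exported in this shell.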
5 | gcloud compute tpus tpu-vm ssh $TPU_NAME --project $PROJECT --zone=$ZONE --worker=all --command="$(cat host.sh)" 6 | 7 | # Copy the profile and output back 8 | gcloud compute tpus tpu-vm scp --recurse $TPU_NAME:~/output ./ --project=$PROJECT --zone=$ZONE 9 | -------------------------------------------------------------------------------- /training/v5p/Llama2-7B-PyTorch/config.sh: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "meta-llama/Llama-2-7b-hf", 3 | "architectures": [ 4 | "LlamaForCausalLM" 5 | ], 6 | "bos_token_id": 1, 7 | "eos_token_id": 2, 8 | "hidden_act": "silu", 9 | "hidden_size": 4096, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 11008, 12 | "max_position_embeddings": 4096, 13 | "model_type": "llama", 14 | "num_attention_heads": 32, 15 | "num_hidden_layers": 32, 16 | "num_key_value_heads": 32, 17 | "pretraining_tp": 1, 18 | "rms_norm_eps": 1e-05, 19 | "rope_scaling": null, 20 | "tie_word_embeddings": false, 21 | "torch_dtype": "bfloat16", 22 | "transformers_version": "4.40.0.dev0", 23 | "use_cache": false, 24 | "vocab_size": 32000 25 | } 26 | -------------------------------------------------------------------------------- /training/v5p/Llama2-7B-PyTorch/env.sh: -------------------------------------------------------------------------------- 1 | # Uncomment below to set the Huggingface token 2 | # HF_TOKEN=hf_*** 3 | 4 | PJRT_DEVICE=TPU 5 | XLA_IR_DEBUG=1 6 | XLA_HLO_DEBUG=1 7 | PROFILE_EPOCH=0 8 | PROFILE_STEP=3 9 | PROFILE_DURATION_MS=120000 10 | XLA_USE_SPMD=1 11 | MAX_STEPS=20 12 | SEQ_LENGTH=4096 13 | BATCH_SIZE=512 14 | -------------------------------------------------------------------------------- /training/v5p/Llama2-7B-PyTorch/fsdp_config.sh: -------------------------------------------------------------------------------- 1 | { 2 | "fsdp_transformer_layer_cls_to_wrap": [ 3 | "LlamaDecoderLayer" 4 | ], 5 | "xla": true, 6 | "xla_fsdp_v2": true, 7 | "xla_fsdp_grad_ckpt": true 8 | } 9 | -------------------------------------------------------------------------------- /training/v5p/Llama2-7B-PyTorch/host.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DOCKER_IMAGE="us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-xla/llama2:7b" 4 | 5 | worker_id=$(curl -s "http://metadata.google.internal/computeMetadata/v1/instance/attributes/agent-worker-number" -H 'Metadata-Flavor: Google') 6 | 7 | cat >> /dev/null <&1 | sed "s/^/[worker $slice_id:$worker_id] /g" | tee runlog 11 | set -o xtrace 12 | # Configure docker 13 | sudo groupadd docker 14 | sudo usermod -aG docker $USER 15 | # newgrp applies updated group permissions 16 | newgrp - docker 17 | gcloud auth configure-docker us-central1-docker.pkg.dev --quiet 18 | # Kill any running benchmarks 19 | docker kill $USER-test 20 | docker pull $DOCKER_IMAGE 21 | docker run --rm \ 22 | --name $USER-test \ 23 | --privileged \ 24 | --env-file env.sh \ 25 | -v /home/$USER:/tmp/home \ 26 | --shm-size=16G \ 27 | --net host \ 28 | -u root \ 29 | --entrypoint /bin/bash $DOCKER_IMAGE \ 30 | /tmp/home/train.sh 31 | 32 | PIPE_EOF 33 | -------------------------------------------------------------------------------- /training/v5p/Llama2-7B-PyTorch/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Remove existing repo and old data. 
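# /tmp/home is the container mount of the worker's home directory (host.sh runs docker with -v /home/$USER:/tmp/home).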
3 | LOCAL_DIR=/tmp/home/ 4 | rm -rf "${LOCAL_DIR}/output" 5 | rm -rf "${LOCAL_DIR}/plugins" 6 | rm -rf "${LOCAL_DIR}/cache" 7 | mkdir -p "${LOCAL_DIR}/output" 8 | mkdir -p "${LOCAL_DIR}/plugins" 9 | mkdir -p "${LOCAL_DIR}/cache" 10 | 11 | unset LD_PRELOAD 12 | 13 | 14 | cd transformers/ 15 | python3 examples/pytorch/language-modeling/run_clm.py \ 16 | --dataset_name wikitext \ 17 | --dataset_config_name wikitext-103-raw-v1 \ 18 | --per_device_train_batch_size "$BATCH_SIZE" \ 19 | --do_train \ 20 | --output_dir "${LOCAL_DIR}/output/test-clm" \ 21 | --overwrite_output_dir \ 22 | --config_name "${LOCAL_DIR}/config.json" \ 23 | --cache_dir "${LOCAL_DIR}/cache" \ 24 | --tokenizer_name meta-llama/Llama-2-7b-hf \ 25 | --block_size "$SEQ_LENGTH" \ 26 | --optim adafactor \ 27 | --save_strategy no \ 28 | --logging_strategy no \ 29 | --fsdp "full_shard" \ 30 | --fsdp_config "${LOCAL_DIR}/fsdp_config.json" \ 31 | --torch_dtype bfloat16 \ 32 | --dataloader_drop_last yes \ 33 | --flash_attention \ 34 | --max_steps "$MAX_STEPS" 35 | -------------------------------------------------------------------------------- /training/v5p/Llama4-Maverick-17B-128E-Maxtext/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Llama4-Maverick-17B-128E Maxtext on TPU v5p-256 2 | 3 | This document presents steps to run the Llama4-Maverick-17B-128E [MaxText](https://github.com/google/maxtext) workload through the [XPK](https://github.com/google/xpk/blob/main/README.md) tool. 4 | 5 | ## XPK setup 6 | 7 | Please follow this [link](https://github.com/gclouduniverse/reproducibility/tree/main/Training/TPU-v5p/XPK_README.md) to create your GKE cluster with XPK. 8 | 9 | ## Prep for Maxtext 10 | 11 | Please follow the [MAXTEXT_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXTEXT_README.md) to install MaxText and build the docker image.
The following variables should be set: 12 | 13 | In step 1, Use the MaxText [tpu-recipes-v0.1.3](https://github.com/AI-Hypercomputer/maxtext/releases/tag/tpu-recipes-v0.1.3) tag to run this recipe: 14 | ``` 15 | git checkout tpu-recipes-v0.1.3 16 | ``` 17 | 18 | In step 3, use the jax-stable-stack image containing JAX 0.5.2: 19 | ``` 20 | BASE_IMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1 21 | bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=${BASE_IMAGE} 22 | ``` 23 | 24 | ## Run workloads 25 | 26 | From the MaxText root directory, start your workload 27 | 28 | ``` 29 | python3 -m benchmarks.benchmark_runner xpk \ 30 | --project=$PROJECT \ 31 | --zone=$ZONE \ 32 | --device_type=v5p-256 \ 33 | --num_slices=1 \ 34 | --cluster_name=${CLUSTER_NAME} \ 35 | --base_output_directory=${OUTPUT_DIR} \ 36 | --model_name="llama4_maverick_dropless_v5p_256" \ 37 | --base_docker_image=maxtext_base_image 38 | ``` 39 | 40 | From your workload logs, you should start seeing step time logs like the following: 41 | 42 | ``` 43 | completed step: 12, seconds: 24.792, TFLOP/s/device: 160.005, Tokens/s/device: 1321.725, total_weights: 4194304, loss: 10.034 44 | ``` 45 | 46 | Workload details can be found in `MaxText@tpu-recipes-v0.1.3` [here](https://github.com/AI-Hypercomputer/maxtext/blob/9ca35d7e60b71303b9f6fa885447d32e8a612c47/benchmarks/maxtext_v5p_model_configs.py#L151-L196): 47 | 48 | ``` 49 | MaxTextModel( 50 | model_name="llama4_maverick_dropless_v5p_256", 51 | model_type="llama4-17b-128e", 52 | tuning_params={ 53 | "per_device_batch_size": 4, 54 | "max_target_length": 8192, 55 | "ici_fsdp_parallelism": 32, 56 | "ici_tensor_parallelism": 4, 57 | "enable_checkpointing": False, 58 | "dtype": "bfloat16", 59 | "weight_dtype": "float32", 60 | "megablox": True, 61 | "sparse_matmul": True, 62 | "dataset_type": "synthetic", 63 | "opt_type": "adamw", 64 | "skip_first_n_steps_for_profiler": 5, 65 | "profiler_steps": 3, 66 | "profiler": "xplane", 67 | "remat_policy": "custom", 68 | "decoder_layer_input": "offload", 69 | "out_proj": "offload", 70 | "query_proj": "offload", 71 | "key_proj": "offload", 72 | "value_proj": "offload", 73 | "reuse_example_batch": 1, 74 | "sa_block_q": 2048, 75 | "sa_block_kv": 2048, 76 | "sa_block_kv_compute": 2048, 77 | "sa_block_q_dkv": 2048, 78 | "sa_block_kv_dkv": 2048, 79 | "sa_block_kv_dkv_compute": 2048, 80 | "sa_block_q_dq": 2048, 81 | "sa_block_kv_dq": 2048, 82 | "tokenizer_path": "meta-llama/Llama-4-Maverick-17B-128E", 83 | }, 84 | xla_flags=( 85 | xla_flags_library.MOE_VMEM_LIMIT_FLAG 86 | + xla_flags_library.CF_FOR_ALL_GATHER 87 | + xla_flags_library.DATA_PARALLEL_OVERLAP 88 | + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER 89 | + xla_flags_library.HOST_OFFLOAD_FLAGS 90 | ), 91 | ) 92 | ``` 93 | -------------------------------------------------------------------------------- /training/v5p/Llama4-Scout-17B-16E-Maxtext/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Llama4-Scout-17B-16E Maxtext on TPU v5p-256 2 | 3 | This documents present steps to run Llama4-Scout-17B-16E [MaxText](https://github.com/google/maxtext) workload through [XPK](https://github.com/google/xpk/blob/main/README.md) tool. 4 | 5 | ## XPK setup 6 | 7 | Please follow this [link](https://github.com/gclouduniverse/reproducibility/tree/main/Training/TPU-v5p/XPK_README.md) to create your GKE cluster with XPK. 
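The workload command below assumes the usual XPK environment variables are already exported in your shell. A minimal sketch (all values are placeholders for your own project, zone, cluster, and GCS bucket):

```
export PROJECT=your-gcp-project
export ZONE=your-tpu-zone
export CLUSTER_NAME=your-xpk-cluster
export OUTPUT_DIR=gs://your-output-bucket/
```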
8 | 9 | ## Prep for Maxtext 10 | 11 | Please follow the [MAXTEXT_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXTEXT_README.md) to install maxtext and build the docker image. The following variables should be set: 12 | 13 | In step 1, Use the MaxText [tpu-recipes-v0.1.3](https://github.com/AI-Hypercomputer/maxtext/releases/tag/tpu-recipes-v0.1.3) tag to run this recipe: 14 | ``` 15 | git checkout tpu-recipes-v0.1.3 16 | ``` 17 | 18 | In step 3, use the jax-stable-stack image containing JAX 0.5.2: 19 | ``` 20 | BASE_IMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1 21 | bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=${BASE_IMAGE} 22 | ``` 23 | 24 | ## Run workloads 25 | 26 | From the MaxText root directory, start your workload 27 | 28 | ``` 29 | python3 -m benchmarks.benchmark_runner xpk \ 30 | --project=$PROJECT \ 31 | --zone=$ZONE \ 32 | --device_type=v5p-256 \ 33 | --num_slices=1 \ 34 | --cluster_name=${CLUSTER_NAME} \ 35 | --base_output_directory=${OUTPUT_DIR} \ 36 | --model_name="llama4_scout_dropless_v5p_256" \ 37 | --base_docker_image=maxtext_base_image 38 | ``` 39 | 40 | From your workload logs, you should start seeing step time logs like the following: 41 | 42 | ``` 43 | completed step: 12, seconds: 31.494, TFLOP/s/device: 251.760, Tokens/s/device: 2080.892, total_weights: 8388608, loss: 10.929 44 | ``` 45 | 46 | Workload details can be found in `MaxText@tpu-recipes-v0.1.3` [here](https://github.com/AI-Hypercomputer/maxtext/blob/9ca35d7e60b71303b9f6fa885447d32e8a612c47/benchmarks/maxtext_v5p_model_configs.py#L109-L149): 47 | 48 | ``` 49 | MaxTextModel( 50 | model_name="llama4_scout_dropless_v5p_256", 51 | model_type="llama4-17b-16e", 52 | tuning_params={ 53 | "per_device_batch_size": 8, 54 | "max_target_length": 8192, 55 | "ici_fsdp_parallelism": -1, 56 | "enable_checkpointing": False, 57 | "dtype": "bfloat16", 58 | "weight_dtype": "float32", 59 | "megablox": True, 60 | "sparse_matmul": True, 61 | "dataset_type": "synthetic", 62 | "opt_type": "adamw", 63 | "skip_first_n_steps_for_profiler": 5, 64 | "profiler_steps": 3, 65 | "profiler": "xplane", 66 | "remat_policy": "custom", 67 | "decoder_layer_input": "offload", 68 | "reuse_example_batch": 1, 69 | "sa_block_q": 2048, 70 | "sa_block_kv": 2048, 71 | "sa_block_kv_compute": 2048, 72 | "sa_block_q_dkv": 2048, 73 | "sa_block_kv_dkv": 2048, 74 | "sa_block_kv_dkv_compute": 2048, 75 | "sa_block_q_dq": 2048, 76 | "sa_block_kv_dq": 2048, 77 | "tokenizer_path": "meta-llama/Llama-4-Scout-17B-16E", 78 | }, 79 | xla_flags=( 80 | xla_flags_library.MOE_VMEM_LIMIT_FLAG 81 | + xla_flags_library.CF_FOR_ALL_GATHER 82 | + xla_flags_library.DATA_PARALLEL_OVERLAP 83 | + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER 84 | + xla_flags_library.HOST_OFFLOAD_FLAGS 85 | ), 86 | ) 87 | ``` 88 | -------------------------------------------------------------------------------- /training/v5p/Mixtral-8X7B-Maxtext/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Mixtral-8X7B Maxtext on TPU v5p 2 | 3 | This documents present steps to run Mixtral-8x7B [MaxText](https://github.com/google/maxtext) workload through [XPK](https://github.com/google/xpk/blob/main/README.md) tool. 4 | 5 | ## XPK setup 6 | 7 | Please follow this [link](https://github.com/gclouduniverse/reproducibility/tree/main/Training/TPU-v5p/XPK_README.md) to create your GKE cluster with XPK. 8 | 9 | 10 | ## Run script 11 | 12 | 1. 
Clone [Maxtext](https://github.com/AI-Hypercomputer/maxtext) repo. 13 | ``` 14 | git clone https://github.com/AI-Hypercomputer/maxtext.git 15 | ``` 16 | 17 | 2. Build a local docker image with default name `maxtext_base_image`. 18 | 19 | ``` 20 | cd maxtext 21 | bash docker_build_dependency_image.sh MODE=stable DEVICE=tpu 22 | ``` 23 | 24 | 3. (Optional) Install XPK if you haven't set it up. 25 | 26 | ``` 27 | pip install xpk 28 | ``` 29 | 30 | 4. Specify workload configs. 31 | 32 | ``` 33 | export CLUSTER_NAME=v5p-demo # 34 | export WORKLOAD_NAME=Mixtral-8x7b-test # 35 | export RUN_NAME=Mixtral-8x7b-run # 36 | export TPU_TYPE=v5p-128 # 37 | export NUM_SLICES=1 # 38 | export OUTPUT_PATH=gs://v5p-demo/ # 39 | ``` 40 | 41 | 5. Copy `scripts/run_mixtral-8x7b.sh` script, paste it to `MaxText/configs` folder, and run workload in the maxtext github root directory. 42 | 43 | ``` 44 | xpk workload create \ 45 | --cluster ${CLUSTER_NAME} \ 46 | --workload ${WORKLOAD_NAME} \ 47 | --tpu-type=${TPU_TYPE} \ 48 | --num-slices=${NUM_SLICES} \ 49 | --base-docker-image maxtext_base_image \ 50 | --command "bash MaxText/configs/run_mixtral-8x7b.sh RUN_NAME=${RUN_NAME} OUTPUT_PATH=${OUTPUT_PATH}" 51 | ``` 52 | 53 | 6. (Optional) Clean up the GKE cluster. 54 | 55 | ``` 56 | xpk cluster delete --cluster ${CLUSTER_NAME} 57 | ``` 58 | -------------------------------------------------------------------------------- /training/v5p/Mixtral-8X7B-Maxtext/scripts/run_mixtral-8x7b.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
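# Pre-trains MaxText Mixtral 8x7B for 10 steps on synthetic data with flash attention; outputs go to $OUTPUT_PATH.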
14 | 15 | echo "Running Mixtral 8x7b script" 16 | 17 | # Stop execution if any command exits with error 18 | set -e 19 | 20 | export EXECUTABLE="train.py" 21 | 22 | # Set environment variables 23 | for ARGUMENT in "$@"; do 24 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT" 25 | export "$KEY"="$VALUE" 26 | done 27 | 28 | # Set up RUN_NAME 29 | if [ -n "$RUN_NAME" ]; 30 | then 31 | export M_RUN_NAME=$RUN_NAME 32 | fi 33 | 34 | # Train 35 | export LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_scoped_vmem_limit_kib=81920" 36 | python3 MaxText/$EXECUTABLE MaxText/configs/base.yml\ 37 | model_name=mixtral-8x7b steps=10 per_device_batch_size=36 enable_checkpointing=false\ 38 | max_target_length=4096 base_output_directory=$OUTPUT_PATH dtype=bfloat16 weight_dtype=bfloat16\ 39 | dataset_type=synthetic attention=flash tokenizer_path=assets/tokenizer.mistral-v1 40 | -------------------------------------------------------------------------------- /training/v5p/Mixtral-8x7B-PyTorch/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Mixtral-8X7B on TPU v5p 2 | 3 | 4 | This user guide provides a concise overview of the essential steps required to run HuggingFace (HF) Mixtral training on Cloud TPUs. 5 | 6 | 7 | ## Environment Setup 8 | 9 | Please follow the corresponding TPU generation's user guide to setup the GCE TPUs 10 | first. 11 | 12 | NOTE: For best performance on Mixtral, use "untwisted" variants of TPU topologies. 13 | For example, if you're creating a 128 chip v5p slice, select `4x4x8_untwisted` in 14 | the `gcloud` CLI. The default is "twisted" which has been observed to reduce 15 | performance in Mixtral. See [1] about details on twisted tori. 16 | 17 | Please replace all your-* with your TPUs' information. 18 | 19 | ``` 20 | export TPU_NAME=your-tpu-name 21 | export ZONE=your-tpu-zone 22 | export PROJECT=your-tpu-project 23 | ``` 24 | 25 | You may use this command to create an untwisted 128 chip v5p slice: 26 | 27 | ``` 28 | gcloud alpha compute tpus tpu-vm create $TPU_NAME \ 29 | --type v5p --topology 4x4x8_untwisted \ 30 | --project $PROJECT --zone $ZONE --version v2-alpha-tpuv5 31 | ``` 32 | 33 | ## Steps to Run HF Mixtral 8x7B 34 | 35 | The following setup runs the training job with Mixtral 8x7B on GCE TPUs using the docker image from this registry (`us-central1-docker.pkg.dev/tpu-pytorch/docker/reproducibility/mixtral@sha256:c8f4a66e02a26548c9d71296cd345d3a302f6960db3da7fd3addf34c00332b5b`), the docker image uses the pytorch and torch_xla nightly build from 09/28/2024 and installed with all the package dependency needed to run the model training. All the command below should run from your own machine (not the TPU host you created). 36 | 37 | 1. git clone and navigate to this README repo and run training script: 38 | ```bash 39 | git clone --depth 1 https://github.com/AI-Hypercomputer/tpu-recipes.git 40 | cd training/v5p/Mixtral-8x7B-PyTorch 41 | ``` 42 | 2. Edit `env.sh` to add the hugging face token and/or setup the training parameters. 
43 | ```bash 44 | # add your hugging face token 45 | HF_TOKEN=hf_*** 46 | ``` 47 | 3. Edit `host.sh` to add the docker image URL if default docker image is not accessible to you. 48 | ```bash 49 | # docker image URL to use for the training 50 | DOCKER_IMAGE=us-central1-docker.pkg.dev/tpu-pytorch/docker/reproducibility/mixtral@sha256:c8f4a66e02a26548c9d71296cd345d3a302f6960db3da7fd3addf34c00332b5b 51 | ``` 52 | 4. Run the training script: 53 | ```bash 54 | ./benchmark.sh 55 | ``` 56 | `benchmark.sh` script will upload 1) environment parameters in `env.sh`, 2) model related config in `config.json`, `fsdp_config.json`, 3) docker launch script in `host.sh` and 4) python training command in `train.sh` into all TPU workers, and starts the training afterwards. When all training steps complete, it will print out training metrics of each worker as below in terminal: 57 | ``` 58 | [worker :0] ***** train metrics ***** 59 | [worker :0] epoch = 0.3125 60 | [worker :0] total_flos = 10915247040GF 61 | [worker :0] train_loss = 9.278 62 | [worker :0] train_runtime = 0:46:45.60 63 | [worker :0] train_samples = 32816 64 | [worker :0] train_samples_per_second = 3.65 65 | [worker :0] train_steps_per_second = 0.007 66 | ``` 67 | In addition, it will copy back the trained model under `output/*`. 68 | 69 | 70 | 71 | [1]: https://cloud.google.com/tpu/docs/v4#twisted-tori 72 | 73 | -------------------------------------------------------------------------------- /training/v5p/Mixtral-8x7B-PyTorch/benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # SCP the environment setup to all instances. 4 | gcloud compute tpus tpu-vm scp config.json fsdp_config.json train.sh host.sh env.sh "$TPU_NAME:~" --worker=all --project $PROJECT --zone=$ZONE 5 | 6 | # Actually runs the benchmark. 
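# host.sh runs on every worker: it pulls the docker image and starts train.sh inside the container.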
7 | gcloud compute tpus tpu-vm ssh $TPU_NAME --project $PROJECT --zone=$ZONE --worker=all --command="$(cat host.sh)" 8 | 9 | # Copy the profile and output back 10 | gcloud compute tpus tpu-vm scp --recurse $TPU_NAME:~/output ./ --project=$PROJECT --zone=$ZONE 11 | -------------------------------------------------------------------------------- /training/v5p/Mixtral-8x7B-PyTorch/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "MixtralForCausalLM" 4 | ], 5 | "attention_dropout": 0.0, 6 | "bos_token_id": 1, 7 | "eos_token_id": 2, 8 | "hidden_act": "silu", 9 | "hidden_size": 4096, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 14336, 12 | "max_position_embeddings": 32768, 13 | "model_type": "mixtral", 14 | "num_attention_heads": 32, 15 | "num_experts_per_tok": 2, 16 | "num_hidden_layers": 32, 17 | "num_key_value_heads": 8, 18 | "num_local_experts": 8, 19 | "output_router_logits": false, 20 | "rms_norm_eps": 1e-05, 21 | "rope_theta": 1000000.0, 22 | "router_aux_loss_coef": 0.02, 23 | "sliding_window": null, 24 | "tie_word_embeddings": false, 25 | "torch_dtype": "bfloat16", 26 | "transformers_version": "4.36.0.dev0", 27 | "use_cache": false, 28 | "vocab_size": 32000 29 | } 30 | 31 | -------------------------------------------------------------------------------- /training/v5p/Mixtral-8x7B-PyTorch/env.sh: -------------------------------------------------------------------------------- 1 | # Uncomment below to set the Huggingface token 2 | # HF_TOKEN=hf_*** 3 | PJRT_DEVICE=TPU 4 | XLA_IR_DEBUG=1 5 | XLA_HLO_DEBUG=1 6 | PROFILE_EPOCH=0 7 | PROFILE_STEP=3 8 | PROFILE_DURATION_MS=120000 9 | XLA_USE_SPMD=1 10 | MAX_STEPS=20 11 | SEQ_LENGTH=4096 12 | 13 | # Per-host batch size is the number of training examples used by a TPU VM 14 | # in each training step. For v5p, it will be 4 times the per-device batch size, 15 | # since each TPU VM is connected to 4 v5p TPU chips. The following will lead 16 | # to a per-device batch size of 8. Customize accordingly. 
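# e.g. a per-device batch size of 8 x 4 chips per host = 32 examples per host per step.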
17 | PER_HOST_BATCH_SIZE=32 18 | 19 | # XLA flags 20 | LIBTPU_INIT_ARGS='--xla_enable_async_collective_permute=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_rwb_fusion=false --xla_jf_rematerialization_percent_shared_memory_limit=10000 --xla_tpu_enable_net_router_in_all_gather=false --xla_tpu_prefer_async_allgather_to_allreduce=true --xla_enable_all_gather_3d_emitter=true' 21 | -------------------------------------------------------------------------------- /training/v5p/Mixtral-8x7B-PyTorch/fsdp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fsdp_transformer_layer_cls_to_wrap": [ 3 | "MixtralDecoderLayer" 4 | ], 5 | "xla": true, 6 | "xla_fsdp_v2": true, 7 | "xla_fsdp_grad_ckpt": true 8 | } 9 | -------------------------------------------------------------------------------- /training/v5p/Mixtral-8x7B-PyTorch/host.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DOCKER_IMAGE=us-central1-docker.pkg.dev/tpu-pytorch/docker/reproducibility/mixtral@sha256:c8f4a66e02a26548c9d71296cd345d3a302f6960db3da7fd3addf34c00332b5b 4 | 5 | worker_id=$(curl -s "http://metadata.google.internal/computeMetadata/v1/instance/attributes/agent-worker-number" -H 'Metadata-Flavor: Google') 6 | 7 | cat >> /dev/null <&1 | sed "s/^/[worker $slice_id:$worker_id] /g" | tee runlog 11 | set -o xtrace 12 | # Configure docker 13 | sudo groupadd docker 14 | sudo usermod -aG docker $USER 15 | # newgrp applies updated group permissions 16 | newgrp - docker 17 | gcloud auth configure-docker us-central1-docker.pkg.dev --quiet 18 | # Kill any running benchmarks 19 | docker kill $USER-test 20 | docker pull $DOCKER_IMAGE 21 | docker run --rm \ 22 | --name $USER-test \ 23 | --privileged \ 24 | --env-file env.sh \ 25 | -v /home/$USER:/tmp/home \ 26 | --shm-size=16G \ 27 | --net host \ 28 | -u root \ 29 | --entrypoint /bin/bash $DOCKER_IMAGE \ 30 | /tmp/home/train.sh 31 | 32 | PIPE_EOF 33 | -------------------------------------------------------------------------------- /training/v5p/Mixtral-8x7B-PyTorch/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Remove existing repo and old data. 
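# Output, plugins, and the HF cache directories live under the mounted home directory and are recreated for each run.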
3 | LOCAL_DIR=/tmp/home/ 4 | rm -rf "${LOCAL_DIR}/output" 5 | rm -rf "${LOCAL_DIR}/plugins" 6 | rm -rf "${LOCAL_DIR}/cache" 7 | mkdir -p "${LOCAL_DIR}/output" 8 | mkdir -p "${LOCAL_DIR}/plugins" 9 | mkdir -p "${LOCAL_DIR}/cache" 10 | 11 | unset LD_PRELOAD 12 | 13 | 14 | cd transformers/ 15 | 16 | 17 | python3 examples/pytorch/language-modeling/run_clm.py \ 18 | --dataset_name wikitext \ 19 | --dataset_config_name wikitext-103-raw-v1 \ 20 | --per_device_train_batch_size "${PER_HOST_BATCH_SIZE}" \ 21 | --do_train \ 22 | --output_dir "${LOCAL_DIR}/output/test-clm" \ 23 | --overwrite_output_dir \ 24 | --config_name "${LOCAL_DIR}/config.json" \ 25 | --cache_dir "${LOCAL_DIR}/cache" \ 26 | --tokenizer_name mistralai/Mixtral-8x7B-v0.1 \ 27 | --block_size "$SEQ_LENGTH" \ 28 | --optim adafactor \ 29 | --save_strategy no \ 30 | --logging_strategy no \ 31 | --fsdp "full_shard" \ 32 | --fsdp_config "${LOCAL_DIR}/fsdp_config.json" \ 33 | --torch_dtype bfloat16 \ 34 | --dataloader_drop_last yes \ 35 | --flash_attention \ 36 | --num_train_epochs 1 \ 37 | --max_steps "$MAX_STEPS" \ 38 | --gmm 39 | 40 | -------------------------------------------------------------------------------- /training/v5p/SDXL-MaxDiffusion/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for training Stable Diffusion XL on TPU v5p 2 | 3 | This documents present steps to run StableDiffusion [MaxDiffusion](https://github.com/google/maxdiffusion/tree/main/src/maxdiffusion) workload through [XPK](https://github.com/google/xpk/blob/main/README.md) tool. 4 | 5 | Setup XPK and create cluster [XPK Userguide](../../../Training/TPU-v5p/XPK_README.md) 6 | 7 | Build a local docker image. 8 | 9 | ``` 10 | LOCAL_IMAGE_NAME=maxdiffusion_base_image 11 | docker build --no-cache --network host -f ./docker/maxdiffusion.Dockerfile -t ${LOCAL_IMAGE_NAME} . 12 | ``` 13 | 14 | Run workload using xpk. 15 | 16 | ``` 17 | export BASE_OUTPUT_DIR=gs://output_bucket/ 18 | export NUM_SLICES=1 19 | 20 | xpk workload create \ 21 | --cluster \ 22 | --base-docker-image maxdiffusion_base_image \ 23 | --workload ${USER}-sdxl-v5p \ 24 | --tpu-type= \ 25 | --num-slices=${NUM_SLICES} \ 26 | --zone $ZONE \ 27 | --project $PROJECT \ 28 | --command "bash scripts/run_v5p-ddp-pbs-1.sh BASE_OUTPUT_DIR=${BASE_OUTPUT_DIR} COMMITS=00150750841e9155669fd1ac4c6f2fcd0e0654e0" 29 | ``` 30 | 31 | MFU Calculation. 32 | 33 | Above only UNET is trainable model, FLOPS count = 162.27 TFLOPS @BS=8, we get the MFU 34 | ``` 35 | MFU = UNET FLOPS / Step Time / Per Device Peak FLOPS 36 | ``` -------------------------------------------------------------------------------- /training/v5p/SDXL-MaxDiffusion/docker/maxdiffusion.Dockerfile: -------------------------------------------------------------------------------- 1 | # Install ip. 
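# Base image: Debian bullseye with Python 3.10; later layers add the Google Cloud SDK, JAX nightlies, and MaxDiffusion pinned to the commit referenced in the README.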
2 | FROM python:3.10-slim-bullseye 3 | RUN apt-get update 4 | RUN apt-get install -y curl procps gnupg git 5 | RUN apt-get install -y net-tools ethtool iproute2 6 | 7 | # Add the Google Cloud SDK package repository 8 | RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list 9 | RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - 10 | 11 | # Install the Google Cloud SDK 12 | RUN apt-get update && apt-get install -y google-cloud-sdk 13 | 14 | # Set the default Python version to 3.10 15 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.10 1 16 | 17 | # Set environment variables for Google Cloud SDK and Python 3.10 18 | ENV PATH="/usr/local/google-cloud-sdk/bin:/usr/local/bin/python3.10:${PATH}" 19 | 20 | RUN pip install --no-cache-dir jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 21 | RUN pip install --no-cache-dir -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html 22 | RUN pip install git+https://github.com/google/jax 23 | RUN pip uninstall jaxlib -y 24 | RUN pip install -U --pre jax[tpu] --no-cache-dir -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 25 | RUN pip install git+https://github.com/mlperf/logging.git 26 | 27 | RUN git clone https://github.com/google/maxdiffusion.git 28 | 29 | WORKDIR maxdiffusion 30 | 31 | RUN git checkout 00150750841e9155669fd1ac4c6f2fcd0e0654e0 32 | RUN pip install -r requirements.txt 33 | 34 | RUN pip install . -------------------------------------------------------------------------------- /training/v5p/SDXL-MaxDiffusion/scripts/run_v5p-ddp-pbs-1.sh: -------------------------------------------------------------------------------- 1 | BASE_OUTPUT_DIR=$1 2 | COMMITS=$2 3 | 4 | # Set environment variables 5 | for ARGUMENT in "$@"; do 6 | IFS='=' read -r KEY VALUE <<< "$ARGUMENT" 7 | export "$KEY"="$VALUE" 8 | done 9 | 10 | export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_megacore_fusion=false --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true' 11 | 12 | LIBTPU_INIT_ARGS+=' --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_enable_async_all_reduce=true' 13 | LIBTPU_INIT_ARGS+=' --xla_tpu_enable_async_collective_fusion_with_mosaic_custom_call=true --xla_tpu_mosaic_fusion=true' 14 | LIBTPU_INIT_ARGS+=' --xla_enable_async_reduce_scatter_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_reduce_scatter=true' 15 | LIBTPU_INIT_ARGS+=' --xla_tpu_spmd_threshold_for_allgather_cse=1000000 --xla_jf_spmd_threshold_for_windowed_einsum_mib=1000000' 16 | 17 | #reload code to specific commits 18 | rm -rf maxdiffusion 19 | 20 | git clone https://github.com/google/maxdiffusion.git 21 | cd maxdiffusion 22 | git checkout ${COMMITS} 23 | 24 | python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml 
revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16 resolution=1024 per_device_batch_size=8 output_dir=$BASE_OUTPUT_DIR jax_cache_dir=${BASE_OUTPUT_DIR}/cache_dir/ max_train_steps=5000 attention=flash run_name=sdxl-fsdp-v5p-ddp -------------------------------------------------------------------------------- /training/v5p/XPK_README.md: -------------------------------------------------------------------------------- 1 | ## Initialization 2 | 1. Run the following commands to initialize the project and zone. 3 | ``` 4 | export PROJECT=tpu-prod-env-multipod # 5 | export ZONE=us-central2-b # 6 | gcloud config set project $PROJECT 7 | gcloud config set compute/zone $ZONE 8 | ``` 9 | 2. Install XPK by following the [prerequisites](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#prerequisites) and [installation](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#installation) 10 | instructions. Also ensure you have the proper [GCP permissions](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#installation). 11 | 12 | * In order to run the tpu-recipes as-is, run the `git clone` command from your home directory: 13 | ``` 14 | git clone https://github.com/google/xpk.git 15 | ``` 16 | 17 | 3. Run the rest of these commands from the cloned XPK directory: 18 | 19 | ``` 20 | cd xpk # Should be equivalent to cd ~/xpk 21 | ``` 22 | 23 | 24 | ## GKE Cluster Creation 25 | 1. Specify your TPU GKE cluster configs. 26 | ``` 27 | export CLUSTER_NAME=v5p-demo # 28 | export NETWORK_NAME=${CLUSTER_NAME}-only-mtu9k 29 | export NETWORK_FW_NAME=${NETWORK_NAME}-only-fw 30 | export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}" 31 | export TPU_TYPE=v5p-512 # 32 | export NUM_SLICES=1 # 33 | ``` 34 | 35 | 2. Create the network and firewall for this cluster if they don't exist yet. 36 | ``` 37 | gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT} --subnet-mode=auto --bgp-routing-mode=regional 38 | gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} --allow tcp,icmp,udp --project=${PROJECT} 39 | ``` 40 | 41 | 3. Create a GKE cluster with TPU node pools. 42 | ``` 43 | python3 xpk.py cluster create \ 44 | --default-pool-cpu-machine-type=n1-standard-32 \ 45 | --cluster ${CLUSTER_NAME} \ 46 | --tpu-type=${TPU_TYPE} \ 47 | --num-slices=${NUM_SLICES} \ 48 | --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ 49 | --on-demand 50 | ``` 51 | 52 | * Note: TPUs have `reserved`, `on-demand`, and `spot` quota types. This example uses the `on-demand` quota. If you have reserved or spot quota, please refer to this [link](https://github.com/google/xpk?tab=readme-ov-file#cluster-create). 53 | * To check which quota you have, please refer to this [link](https://cloud.google.com/kubernetes-engine/docs/how-to/tpus#ensure-quota). 54 | * You should be able to see your GKE cluster similar to this once it is created successfully: ![image](https://github.com/user-attachments/assets/60743411-5ee5-4391-bb0e-7ffba4d91c1d) 55 | 56 | 57 | 4. Test your GKE cluster to make sure it is usable. 58 | ``` 59 | python3 xpk.py workload create \ 60 | --cluster ${CLUSTER_NAME} \ 61 | --workload hello-world-test \ 62 | --tpu-type=${TPU_TYPE} \ 63 | --num-slices=${NUM_SLICES} \ 64 | --command "echo Hello World" 65 | ``` 66 | * You should be able to see results like this: ![image](https://github.com/user-attachments/assets/c33010a6-e109-411e-8fb5-afb4edb3fa72) 67 | 68 | 5. 
You can also check your workload status with the following command: 69 | ``` 70 | python3 xpk.py workload list \ 71 | --cluster ${CLUSTER_NAME} 72 | ``` 73 | 6. For more information about XPK, please refer to this [link](https://github.com/google/xpk). 74 | 75 | ## GKE Cluster Deletion 76 | You can use the following command to delete the GKE cluster: 77 | ``` 78 | export CLUSTER_NAME=v5p-demo # 79 | 80 | python3 xpk.py cluster delete \ 81 | --cluster $CLUSTER_NAME 82 | ``` -------------------------------------------------------------------------------- /utils/profile_convert.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import statistics 3 | import xplane_pb2 4 | 5 | 6 | 7 | def analyze_step_duration(file_path: str) -> float: 8 | xspace = xplane_pb2.XSpace() # type: ignore 9 | 10 | # Read and parse the xplane proto 11 | with open(file_path, "rb") as f: 12 | print(f"Parsing {file_path}", file=sys.stderr) 13 | xspace.ParseFromString(f.read()) 14 | 15 | durations = [] 16 | event_count = 0 17 | 18 | for plane in xspace.planes: 19 | if plane.name != "/device:TPU:0": 20 | continue 21 | print(f"Plane ID: {plane.id}, Name: {plane.name}", file=sys.stderr) 22 | for line in plane.lines: 23 | if line.name != "XLA Modules": 24 | continue 25 | print(f" Line ID: {line.id}, Name: {line.name}", file=sys.stderr) 26 | for event in line.events: 27 | name: str = plane.event_metadata[event.metadata_id].name 28 | secs: float = event.duration_ps / 1e12 29 | if name.startswith("SyncTensorsGraph."): 30 | durations.append(secs) 31 | event_count += 1 32 | print( 33 | f" Event Metadata Name: {name}, ID: {event.metadata_id}, Duration: {secs} s", 34 | file=sys.stderr) 35 | 36 | print(f"Got {event_count} iterations", file=sys.stderr) 37 | 38 | if event_count == 0: 39 | raise ValueError("No SyncTensorsGraph events found.") 40 | 41 | if len(durations) < 3: 42 | print( 43 | "[Warning] Not enough SyncTensorsGraph events found to drop outliers.", 44 | file=sys.stderr) 45 | # Compute a simple average. 46 | return sum(durations) / len(durations) 47 | 48 | return statistics.median(durations) 49 | 50 | if __name__ == "__main__": 51 | if len(sys.argv) != 2: 52 | print(f"Usage: {sys.argv[0]} <xplane .pb file>") 53 | sys.exit(1) 54 | proto_file_path = sys.argv[1] 55 | try: 56 | # Median SyncTensorsGraph step duration (simple average if fewer than 3 events). 57 | average_duration = analyze_step_duration(proto_file_path) 58 | print(f"{average_duration:.4f}") 59 | except Exception as e: 60 | print(f"Error: {e}") 61 | sys.exit(1) 62 | --------------------------------------------------------------------------------
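A note on using `utils/profile_convert.py`: the script expects a raw `.xplane.pb` trace such as the one the JAX/XLA profiler writes under a run's `plugins/profile/<session>/` directory. The sketch below shows one way it might be invoked after copying a captured profile to the local machine; the bucket, run, and file names are placeholders, not paths defined by this repository.
```
# Placeholder paths: point PROFILE_DIR at the profile session your training run produced.
PROFILE_DIR=gs://my-output-bucket/my-run/tensorboard/plugins/profile/session_1
mkdir -p /tmp/profile
gsutil cp "${PROFILE_DIR}/*.xplane.pb" /tmp/profile/

# profile_convert.py imports the generated xplane_pb2 module that sits next to it,
# so run it from utils/ with protobuf installed. Pass a single .xplane.pb file;
# the script prints the median SyncTensorsGraph (step) duration in seconds.
pip install protobuf
cd utils
python3 profile_convert.py /tmp/profile/my_host.xplane.pb  # substitute the copied file name
```
Because the script falls back to a simple average only when it finds fewer than three `SyncTensorsGraph.` events, profiling a handful of steps is usually enough to get a stable per-step time.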