├── data
│   ├── multiprocessing
│   │   ├── main.py
│   │   └── fineweb.py
│   ├── ray_distributed
│   │   ├── test_cluster.py
│   │   └── main.py
│   ├── scripts
│   │   ├── batch_test.sh
│   │   └── setup_tpu.sh
│   ├── make_folder.py
│   ├── requirements.txt
│   └── README.md
├── .env.local
├── public
│   ├── banner.png
│   ├── experts.png
│   ├── loss-load.png
│   ├── loss-val.png
│   └── banner-light.png
├── .gitignore
├── launcher.sh
├── setupTpu.sh
├── run.sh
├── debug_tpu.sh
├── README.md
├── dataset.py
├── utils.py
├── main.py
└── model.py
/data/multiprocessing/main.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.env.local:
--------------------------------------------------------------------------------
1 | WANDB_KEY=your_wandb_key
--------------------------------------------------------------------------------
/public/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/public/banner.png
--------------------------------------------------------------------------------
/public/experts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/public/experts.png
--------------------------------------------------------------------------------
/public/loss-load.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/public/loss-load.png
--------------------------------------------------------------------------------
/public/loss-val.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/public/loss-val.png
--------------------------------------------------------------------------------
/public/banner-light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/divyamakkar0/JAXformer/HEAD/public/banner-light.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.txt
2 | *.npy
3 | *.pyc
4 | trainSet/*
5 | valSet/*
6 | *.tgz
7 | results/*
8 | *.env
9 | wandb/*
10 | .vscode/*
11 | .DS_Store
12 | .env
13 | data_dir/
14 | edu_fineweb10B
15 | checkpoints/
16 | final_data/
17 |
--------------------------------------------------------------------------------
/launcher.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source .env
3 |
4 | IPS=(
5 | "35.186.25.28"
6 | "35.186.39.76"
7 | "107.167.173.215"
8 | "35.186.132.44"
9 | "35.186.24.134"
10 | "35.186.58.69"
11 | "35.186.134.160"
12 | "35.186.107.62"
13 | )
14 |
15 |
16 | printf "%s\n" "${IPS[@]}" | xargs -n 1 -P 0 -I {} bash run.sh {}
--------------------------------------------------------------------------------
/setupTpu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # script to run to setup dependencies on TPU
4 |
5 | pip install -U "jax[tpu]"
6 | pip install flax jaxtyping wandb tpu-info einops tiktoken
7 | pip install google-cloud google-cloud-storage gcloud gcsfs
8 |
9 | if [[ -f "$HOME/.config/gcloud/application_default_credentials.json" ]]; then
10 |   echo "gcloud storage credentials found"
11 | else
12 |   echo "gcloud storage credentials not found"
13 |   gcloud auth application-default login
14 | fi
15 |
16 | echo 'export PATH="$HOME/.local/bin:$PATH"' >> ~/.bashrc
17 | echo -e "\n\ndone run tpu-info to check"
18 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | IP=$1
4 | SESSION="trainingRun"
5 | command="python main.py --checkpoint_steps=75 --n_device_axis 8 2 2 --name moe1B --train_batch_size 32 --use_cache --wandb --eval_steps 10"
6 |
7 | echo "Running on $IP"
8 |
9 | ssh adityamakkar@$IP "
10 |
11 | tmux kill-session -t $SESSION
12 | tmux new-session -d -s $SESSION
13 |
14 | tmux send-keys -t $SESSION:0 'cd ~/Jaxformer && rm -rf samples && mkdir samples' C-m
15 | tmux send-keys -t $SESSION:0 'git fetch origin && git reset --hard origin/main' C-m
16 | tmux send-keys -t $SESSION:0 'bash setupTpu.sh' C-m
17 | tmux send-keys -t $SESSION:0 '$command' C-m
18 | "
19 | echo "done commands"
20 |
--------------------------------------------------------------------------------
/data/ray_distributed/test_cluster.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | import socket
3 | import time
4 |
5 | import ray
6 |
7 | ray.init()
8 |
9 | print('''This cluster consists of
10 | {} nodes in total
11 | {} CPU resources in total
12 | '''.format(len(ray.nodes()), ray.cluster_resources()['CPU']))
13 |
14 | @ray.remote
15 | def f():
16 | time.sleep(0.001)
17 |     # Return the IP address of the node that ran this task.
18 |     return socket.gethostbyname(socket.gethostname())
19 |
20 | object_ids = [f.remote() for _ in range(10000)]
21 | ip_addresses = ray.get(object_ids)
22 | # ip_addresses = [f() for _ in range(10000)]
23 |
24 | print('Tasks executed')
25 | for ip_address, num_tasks in Counter(ip_addresses).items():
26 | print(' {} tasks on {}'.format(num_tasks, ip_address))
--------------------------------------------------------------------------------
/data/scripts/batch_test.sh:
--------------------------------------------------------------------------------
1 | set -euo pipefail
2 |
3 | # TODO: fill in these variables
4 | PY_SCRIPT=""
5 | RUNTIME_S="5"
6 | RESULTS="bs_benchmark_$(date +%Y%m%d_%H%M%S).csv"
7 | EXTRA_ARGS=""
8 |
9 | echo "batch_size,tokens_per_s" > "$RESULTS"
10 |
11 | for bs in $(seq 25 50 750) 750; do
12 | if grep -q "^$bs," "$RESULTS"; then
13 | continue
14 | fi
15 |
16 | echo "Testing BATCH_SIZE=$bs ..."
17 |
18 | LOG="$(BENCHMARK=1 BATCH_SIZE="$bs" timeout --signal=INT "${RUNTIME_S}s" \
19 | python "$PY_SCRIPT" $EXTRA_ARGS 2>&1 || true)"
20 |
21 |   RATE="$(printf "%s\n" "$LOG" \
22 |     | grep -oE '[0-9]+(\.[0-9]+)?[[:space:]]*tokens/s' \
23 |     | tail -n1 | grep -oE '^[0-9]+(\.[0-9]+)?' || true)"
24 | 
25 |   echo "$LOG"
26 | 
27 |   # default to NA if no tokens/s figure was found in the log
28 |   if [ -z "${RATE:-}" ]; then
29 |     RATE="NA"
30 |   fi
31 | 
32 |
33 | echo "$bs,$RATE" | tee -a "$RESULTS"
34 | done
35 |
36 | echo "Done. Results saved to $RESULTS"
37 |
--------------------------------------------------------------------------------
/data/make_folder.py:
--------------------------------------------------------------------------------
1 | # Before running:
2 | # 1) pip install -r requirements.txt in data dir
3 | # 2) authenticate using gcloud CLI
4 |
5 | from google.cloud import storage_control_v2
6 | from dotenv import load_dotenv
7 | import os
8 |
9 | load_dotenv()
10 |
11 | def create_folder(bucket_name: str, folder_name: str) -> None:
12 | storage_control_client = storage_control_v2.StorageControlClient()
13 | project_path = storage_control_client.common_project_path("_")
14 | bucket_path = f"{project_path}/buckets/{bucket_name}"
15 |
16 | request = storage_control_v2.CreateFolderRequest(
17 | parent=bucket_path,
18 | folder_id=folder_name,
19 | )
20 | response = storage_control_client.create_folder(request=request)
21 |
22 | print(f"Created folder: {response.name}")
23 |
24 | if __name__ == '__main__':
25 |     bucket_name = os.getenv("GCP_BUCKET", default="")
26 |     if bucket_name == "":
27 |         raise ValueError("GCP_BUCKET is not set")
28 | 
29 |     # create the three prefixes the training pipeline expects
30 |     for folder_name in ['train', 'test', 'checkpoints']:
31 |         create_folder(bucket_name, folder_name)
--------------------------------------------------------------------------------
/data/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohappyeyeballs==2.6.1
2 | aiohttp==3.12.14
3 | aiosignal==1.4.0
4 | async-timeout==5.0.1
5 | attrs==25.3.0
6 | beautifulsoup4==4.13.4
7 | cachetools==5.5.2
8 | certifi==2025.7.14
9 | charset-normalizer==3.4.2
10 | datasets==4.0.0
11 | dill==0.3.8
12 | filelock==3.18.0
13 | frozenlist==1.7.0
14 | fsspec==2025.3.0
15 | gcloud==0.18.3
16 | google==3.0.0
17 | google-api-core==2.25.1
18 | google-auth==2.40.3
19 | google-cloud-core==2.4.3
20 | google-cloud-storage==3.2.0
21 | google-crc32c==1.7.1
22 | google-resumable-media==2.7.2
23 | googleapis-common-protos==1.70.0
24 | hf-xet==1.1.5
25 | httplib2==0.22.0
26 | huggingface-hub==0.33.4
27 | idna==3.10
28 | multidict==6.6.3
29 | multiprocess==0.70.16
30 | numpy==2.2.6
31 | oauth2client==4.1.3
32 | packaging==25.0
33 | pandas==2.3.1
34 | propcache==0.3.2
35 | proto-plus==1.26.1
36 | protobuf==6.31.1
37 | pyarrow==20.0.0
38 | pyasn1==0.6.1
39 | pyasn1_modules==0.4.2
40 | pyparsing==3.2.3
41 | python-dateutil==2.9.0.post0
42 | python-dotenv==1.1.1
43 | pytz==2025.2
44 | PyYAML==6.0.2
45 | regex==2024.11.6
46 | requests==2.32.4
47 | rsa==4.9.1
48 | six==1.17.0
49 | soupsieve==2.7
50 | tiktoken==0.9.0
51 | tqdm==4.67.1
52 | typing_extensions==4.14.1
53 | tzdata==2025.2
54 | urllib3==2.5.0
55 | xxhash==3.5.0
56 | yarl==1.20.1
57 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ## Tokenization
2 |
3 | A collection of scripts for tokenizing FineWeb-Edu and efficiently uploading the resulting shards to GCP buckets.
4 |
5 | ### Authentication with GCP
6 |
7 | To get started, follow these steps to prepare your Google Cloud environment and configure the scripts for execution.
8 |
9 | 1. **Initialize `gcloud`**: First, you need to set up the Google Cloud SDK and link it to your project. Run the `gcloud init` command in your terminal and follow the prompts to select your project.
10 |
11 | ```bash
12 | gcloud init
13 | ```
14 |
15 | 2. **Authenticate Credentials**: Authenticate your application-default credentials. This step is crucial as it allows the Python scripts to securely access and interact with your Google Cloud Storage bucket.
16 |
17 | ```bash
18 | gcloud auth application-default login
19 | ```
20 |
21 | 3. **Create a GCP Bucket**: Create a new Google Cloud Storage bucket to store your tokenized data. Although Google Cloud Storage has a flat namespace, you can simulate a hierarchical structure (folders) by using prefixes in your object names; the scripts handle this automatically, and the sketch after this list shows what it looks like in practice.
22 |
23 | ```bash
24 | gcloud storage buckets create gs://[YOUR_BUCKET_NAME]
25 | ```
26 |
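Under the hood, a GCS "folder" is nothing more than a shared prefix on object names. The snippet below is a minimal sketch (not one of the repo scripts) that uploads a throwaway object under a `train/` prefix and lists it back, which also confirms that the application-default credentials from step 2 work; `your-bucket-name` is a placeholder.

```python
# sketch: a GCS "folder" is just a prefix on the object name
from google.cloud import storage  # uses the application-default credentials from step 2

BUCKET = "your-bucket-name"  # placeholder: replace with the bucket created above

client = storage.Client()
bucket = client.bucket(BUCKET)

# uploading to "train/hello.txt" makes a "train/" folder appear in the console
bucket.blob("train/hello.txt").upload_from_string("hello")

# list everything stored under the train/ prefix
print([b.name for b in client.list_blobs(BUCKET, prefix="train/")])
```
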
27 | ### Running the Scripts
28 |
29 | 4. **Configure Scripts**: Open the Python scripts and change any placeholder names to match your specific setup, such as the `BUCKET_NAME` and `DATA_CACHE_DIR`.
30 |
31 | 5. **Run `make_folder.py`**: Execute the `make_folder.py` script to create the `train`, `test`, and `checkpoints` folders (object prefixes) inside your GCS bucket.
32 |
33 | ```bash
34 | python make_folder.py
35 | ```
36 |
37 | 6. **Run `main.py`**: Finally, run `main.py` to start the data streaming, tokenization, and upload process. Run it from the variant you want to use: the `multiprocessing` version targets a single VM, while the `ray_distributed` version runs across a Ray cluster. The snippet after this step shows a quick way to sanity-check a tokenized shard.
38 |
39 | ```bash
40 | python main.py
41 | ```
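
After a run, a quick sanity check is to load one of the tokenized shards and decode the first few tokens. The snippet below is a rough sketch that assumes a local shard written by the `multiprocessing` variant (`edu_fineweb10B/edufineweb_val_000000.npy`); for the `ray_distributed` variant, download a shard from the bucket first and adjust the path.

```python
# sketch: decode the start of a tokenized shard to confirm it round-trips
import numpy as np
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")  # the GPT-4 encoding used by the tokenizer scripts

# assumed path: a shard produced by multiprocessing/fineweb.py
tokens = np.load("edu_fineweb10B/edufineweb_val_000000.npy")

print(tokens.dtype, tokens.shape)        # expect uint32 and up to 1e8 tokens per shard
print(enc.decode(tokens[:64].tolist()))  # documents are delimited by <|endoftext|>
```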
--------------------------------------------------------------------------------
/debug_tpu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source .env # wandb key from here
4 | SESSION="mysession"
5 | command="python main.py --checkpoint_steps=1 --n_device_axis 8 2 2 --name moe1B --train_batch_size 16 --use_cache --wandb --eval_steps 1"
6 |
7 | IPS=(
8 | "35.186.25.28"
9 | "35.186.39.76"
10 | "107.167.173.215"
11 | "35.186.132.44"
12 | "35.186.24.134"
13 | "35.186.58.69"
14 | "35.186.134.160"
15 | "35.186.107.62"
16 | )
17 |
18 | tmux new-session -d -s "$SESSION" -n "main" \; \
19 | split-window -v \; split-window -v \; split-window -v \; \
20 | select-pane -t "$SESSION":0.0 \; split-window -h \; \
21 | select-pane -t "$SESSION":0.1 \; split-window -h \; \
22 | select-pane -t "$SESSION":0.2 \; split-window -h \; \
23 | select-pane -t "$SESSION":0.3 \; split-window -h \; \
24 | select-layout tiled
25 |
26 | tmux new-window -t "$SESSION" -n "monitor" \; \
27 | split-window -v \; split-window -v \; split-window -v \; \
28 | select-pane -t "$SESSION":1.0 \; split-window -h \; \
29 | select-pane -t "$SESSION":1.1 \; split-window -h \; \
30 | select-pane -t "$SESSION":1.2 \; split-window -h \; \
31 | select-pane -t "$SESSION":1.3 \; split-window -h \; \
32 | select-layout tiled
33 |
34 | for i in $(seq 0 7); do
35 | tmux send-keys -t "$SESSION":0.$i "ssh adityamakkar@${IPS[$i]}" C-m
36 | tmux send-keys -t "$SESSION":0.$i "cd ~/Jaxformer && rm -rf samples && mkdir samples" C-m
37 | tmux send-keys -t "$SESSION":0.$i "git fetch origin && git reset --hard origin/main" C-m
38 | tmux send-keys -t "$SESSION":0.$i "bash setupTpu.sh" C-m
39 | tmux send-keys -t "$SESSION":0.$i "wandb login $WANDB_KEY" C-m
40 | tmux send-keys -t "$SESSION":0.$i "$command" C-m
41 | done
42 |
43 | for i in $(seq 0 7); do
44 | tmux send-keys -t "$SESSION":1.$i "ssh adityamakkar@${IPS[$i]}" C-m
45 | tmux send-keys -t "$SESSION":1.$i "watch -n 1 tpu-info" C-m
46 | done
47 |
48 | tmux attach -t "$SESSION"
49 |
--------------------------------------------------------------------------------
/data/scripts/setup_tpu.sh:
--------------------------------------------------------------------------------
1 | set -euo pipefail
2 | source .env 2>/dev/null || true
3 |
4 | IPS=(
5 | "LIST GOES HERE"
6 | )
7 |
8 | HEAD_IP="${HEAD_IP:-${IPS[0]}}"
9 |
10 | SSH_USER="${SSH_USER:-$USER}"
11 | SSH_KEY="${SSH_KEY:-}"
12 |
13 | RAY_PORT="${RAY_PORT:-6379}" # ray head port
14 | WORKDIR="${WORKDIR:-~/jaxformer}" # remote project dir on each node
15 | PYTHON="${PYTHON:-python3}" # python on the remote machines
16 | MAIN_SCRIPT="${MAIN_SCRIPT:-main_distributed.py}"
17 | MAIN_ARGS="${MAIN_ARGS:-}" # optional args for your script
18 |
19 |
20 | export HEAD_IP RAY_PORT WORKDIR PYTHON SSH_USER SSH_KEY
21 |
22 | echo "[1/3] Starting Ray on all nodes (head: $HEAD_IP, port: $RAY_PORT)..."
23 | printf "%s\n" "${IPS[@]}" | xargs -n 1 -P 0 -I {} bash run.sh {}
24 |
25 | echo "[2/3] Waiting for Ray cluster to become ready..."
26 |
27 | TIMEOUT_SEC=120
28 | EXPECTED_NODES=${#IPS[@]}
29 | DEADLINE=$(( $(date +%s) + TIMEOUT_SEC ))
30 |
31 | while :; do
32 | set +e
33 | NODES=$(
34 | ssh -o StrictHostKeyChecking=no ${SSH_KEY:+-i "$SSH_KEY"} "$SSH_USER@$HEAD_IP" \
35 | "bash -lc '$PYTHON - <<\"PY\"
36 | import time, ray
37 | ray.init(address=\"auto\", ignore_reinit_error=True, namespace=\"_probe_\")
38 | print(len(ray.nodes()))
39 | PY
40 | '"
41 | )
42 | STATUS=$?
43 | set -e
44 |   if [[ $STATUS -eq 0 && "$NODES" =~ ^[0-9]+$ && $NODES -ge $EXPECTED_NODES ]]; then
45 | echo "Ray is up with $NODES/$EXPECTED_NODES nodes."
46 | break
47 | fi
48 | if [[ $(date +%s) -ge $DEADLINE ]]; then
49 | echo "Timed out waiting for Ray cluster. Got $NODES/$EXPECTED_NODES nodes." >&2
50 | exit 1
51 | fi
52 | sleep 3
53 | done
54 |
55 | echo "[3/3] Launching training on the head node..."
56 | ssh -o StrictHostKeyChecking=no ${SSH_KEY:+-i "$SSH_KEY"} "$SSH_USER@$HEAD_IP" "bash -lc '
57 | set -e
58 | cd \"$WORKDIR\"
59 | # If your script needs the Ray address, pass it explicitly:
60 | # $PYTHON -u $MAIN_SCRIPT --address $HEAD_IP:$RAY_PORT $MAIN_ARGS
61 | # Otherwise, if it uses ray.init(\"auto\"), the below is fine:
62 | $PYTHON -u $MAIN_SCRIPT $MAIN_ARGS
63 | '"
64 |
65 | echo "Done."
66 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | # JAXformer
10 |
11 | This is a zero-to-one guide on scaling modern transformers with n-dimensional parallelism in JAX. The [JAXformer](https://jaxformer.com) blog covers distributed data processing, FSDP, pipeline parallelism, tensor parallelism, weight sharding, activation sharding, MoE scaling, and much more, from scratch. It aims to bridge the gap between theory and end-to-end implementation by demonstrating how to scale a modern language model.
12 |
13 | ## Structure
14 |
15 | The model built throughout the blog is defined in `model.py`, and the main training script is `main.py`; `utils.py` and `dataset.py` contain the config dataclasses and the dataset-processing code. `debug_tpu.sh` launches a tmux session with 8 panes that SSH into the 8 nodes and run the command stored in the `command` variable. `launcher.sh` SSHes into each node headlessly and executes `run.sh`, which starts a tmux session on the remote machine so runs keep going even if the SSH connection drops. `setupTpu.sh` installs all the dependencies on the TPU. The `data` directory contains all the code for tokenization. The sketch below shows the device mesh that the training command's `--n_device_axis` flag describes.
16 |
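This is a minimal sketch (not part of the repo) of that mesh for `--n_device_axis 8 2 2`, shrunk to eight emulated CPU devices so it runs on a single machine; the axis names match the `dp`/`pp`/`tp` axes that `main.py` builds and shards batches over.

```python
# sketch: the 3-D mesh behind --n_device_axis, emulated with 8 CPU devices
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # fake 8 local devices

import jax
from jax.sharding import NamedSharding, PartitionSpec as P

# the real cluster uses (8, 2, 2): 8-way FSDP, 2-way pipeline, 2-way tensor parallelism
mesh = jax.make_mesh((2, 2, 2), ("dp", "pp", "tp"))

# main.py places each batch with P(None, "pp", "dp", "tp")
data_sharding = NamedSharding(mesh, P(None, "pp", "dp", "tp"))
print(mesh)
print(data_sharding)
```
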
17 | ## Results
18 |
19 | Results for a 1B-parameter MoE model (300M active parameters) trained to a validation loss of 3.28 with 3-D sharding on a cluster of 32 TPU v4 chips (8-way FSDP, 2-way pipeline, 2-way tensor parallelism).
20 |
21 | ### Val-Loss
22 |
23 |
24 |
25 |
26 | ### Load-Loss
27 |
28 |
29 |
30 |
31 | ### Expert-per-Head
32 |
33 |
34 |
35 |
36 | ## Contributing and Contact
37 |
38 | If you see any issues or have questions, open an issue or send in a PR. You can also leave a comment on the website itself (powered by Giscus) or in the GitHub Discussions.
39 |
40 | ## Acknowledgements
41 |
42 | This guide was written by [Aditya Makkar](https://x.com/AdityaMakkar000), [Divya Makkar](https://x.com/_DivyaMakkar), and [Chinmay Jindal](https://x.com/chinmayjindal_). We are all undergraduate students studying Computer Science at the University of Waterloo.
43 |
44 | The website uses a Distill-style Jekyll theme called [Al-Folio](https://github.com/alshedivat/al-folio). The idea of the blog and front-end structure is inspired by Google DeepMind's [How to Scale Your Model](https://jax-ml.github.io/scaling-book/) guide. [Google's TRC](https://sites.research.google/trc/about/) was used to provide the compute needed. Thanks!
45 |
--------------------------------------------------------------------------------
/data/multiprocessing/fineweb.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from Andrej's Karpathy tokenizer script in GPT-2 reproduction
3 | Changes: GPT-4 embeddings with uint32, multiprocessing main guards (macOS), checkpoint marking
4 | """
5 |
6 | import os
7 | import multiprocessing as mp
8 | import numpy as np
9 | import tiktoken
10 | from datasets import load_dataset # pip install datasets
11 | from tqdm import tqdm # pip install tqdm
12 |
13 | BLUE = "\033[34m"
14 | RESET = "\033[0m"
15 |
16 | # ------------------------------------------
17 | local_dir = "edu_fineweb10B"
18 | remote_name = "sample-10BT"
19 | shard_size = int(1e8) # 100M tokens per shard, total of 100 shards
20 |
21 | # create the local cache directory if it doesn't exist yet
22 | DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir)
23 | os.makedirs(DATA_CACHE_DIR, exist_ok=True)
24 |
25 | # download the dataset
26 | fw = load_dataset("HuggingFaceFW/fineweb-edu", name=remote_name, split="train")
27 |
28 | # init the tokenizer
29 | enc = tiktoken.encoding_for_model("gpt-4") # 'cl100k_base'
30 |
31 | eot = enc._special_tokens["<|endoftext|>"] # end of text token
32 |
33 |
34 | def tokenize(doc):
35 | # tokenizes a single document and returns a numpy array of uint32 tokens
36 | tokens = [eot] # the special <|endoftext|> token delimits all documents
37 | tokens.extend(enc.encode_ordinary(doc["text"]))
38 | tokens_np = np.array(tokens)
39 | assert (0 <= tokens_np).all() and (tokens_np < 2**32).all(), (
40 | "token dictionary too large for uint32"
41 | )
42 | tokens_np_uint32 = tokens_np.astype(np.uint32)
43 | return tokens_np_uint32
44 |
45 |
46 | def write_datafile(filename, tokens_np):
47 | np.save(filename, tokens_np)
48 |
49 |
50 | print(f"{BLUE}hf dataset has been downloaded{RESET}")
51 |
52 |
53 | # tokenize all documents and write output shards, each of shard_size tokens (last shard has remainder)
54 | cpu_count = os.cpu_count()
55 | nprocs = max(1, cpu_count // 2)
56 |
57 |
58 | def main():
59 | with mp.Pool(nprocs) as pool:
60 | shard_index = 0
61 | # preallocate buffer to hold current shard
62 | all_tokens_np = np.empty((shard_size,), dtype=np.uint32)
63 | token_count = 0
64 | progress_bar = None
65 | for tokens in pool.imap(tokenize, fw, chunksize=16):
66 | # print(tokens)
67 | # break
68 |
69 | # is there enough space in the current shard for the new tokens?
70 | if token_count + len(tokens) < shard_size:
71 | # simply append tokens to current shard
72 | all_tokens_np[token_count : token_count + len(tokens)] = tokens
73 | token_count += len(tokens)
74 | # update progress bar
75 | if progress_bar is None:
76 | progress_bar = tqdm(
77 | total=shard_size, unit="tokens", desc=f"Shard {shard_index}"
78 | )
79 | progress_bar.update(len(tokens))
80 | else:
81 | # write the current shard and start a new one
82 | split = "val" if shard_index == 0 else "train"
83 | filename = os.path.join(
84 | DATA_CACHE_DIR, f"edufineweb_{split}_{shard_index:06d}"
85 | )
86 | # split the document into whatever fits in this shard; the remainder goes to next one
87 | remainder = shard_size - token_count
88 | progress_bar.update(remainder)
89 | all_tokens_np[token_count : token_count + remainder] = tokens[
90 | :remainder
91 | ]
92 | write_datafile(filename, all_tokens_np)
93 | shard_index += 1
94 | progress_bar = None
95 | # populate the next shard with the leftovers of the current doc
96 | all_tokens_np[0 : len(tokens) - remainder] = tokens[remainder:]
97 | token_count = len(tokens) - remainder
98 |
99 | # write any remaining tokens as the last shard
100 | if token_count != 0:
101 | split = "val" if shard_index == 0 else "train"
102 | filename = os.path.join(
103 | DATA_CACHE_DIR, f"edufineweb_{split}_{shard_index:06d}"
104 | )
105 | write_datafile(filename, all_tokens_np[:token_count])
106 |
107 |
108 | if __name__ == "__main__":
109 | main()
110 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import jax
3 | from jax.sharding import NamedSharding
4 | import numpy as np
5 | from typing import Optional, Tuple
6 | from utils import dataConfig
7 | from google.cloud import storage
8 | import time
9 |
10 |
11 | def log(out: str):
12 | if jax.process_index() == 0:
13 | print(out)
14 |
15 |
16 | class Dataset:
17 | def __init__(
18 | self,
19 | process_path: str,
20 | T: int,
21 | batch_size: int,
22 | microbatch: int,
23 | dp: int,
24 | pp: int,
25 | bucket_name: str,
26 | id: str,
27 | partition: Optional[NamedSharding] = None,
28 | ):
29 | assert (batch_size % microbatch) == 0, (
30 | "microbatch should divide batch size per data axis"
31 | )
32 | assert (microbatch % pp) == 0, "pp should divide microbatch size"
33 |
34 | self.T = T
35 | self.batch_size = batch_size
36 | self.dp = dp
37 | self.microbatch = microbatch
38 |
39 | self.step_idx = 0
40 | self.shard_idx = 0
41 | self.partition = partition
42 |
43 | self.bucket_name = bucket_name
44 | self.base_process_path = process_path
45 | self.process_path = process_path
46 | self.id = id
47 | self.data = self.return_blobs(bucket_name, self.id)
48 | self.dir_name = "bucket_downloads"
49 | try:
50 | os.mkdir(self.dir_name)
51 | except OSError as e:
52 | log(f"{self.dir_name} already exists")
53 |
54 | self.load_next_shard()
55 |
56 | def return_blobs(self, bucket_name, prefix, delimiter=None):
57 | res = []
58 | storage_client = storage.Client()
59 | blobs = storage_client.list_blobs(
60 | bucket_name, prefix=prefix, delimiter=delimiter
61 | )
62 | for blob in blobs:
63 | res.append(blob.name)
64 |
65 |         return res[1:]  # drop the first blob, which is the folder placeholder object
66 |
67 | def download_blob_to_stream(self, bucket_name, source_blob_name, file_obj):
68 | storage_client = storage.Client()
69 | bucket = storage_client.bucket(bucket_name)
70 |
71 | blob = bucket.blob(source_blob_name)
72 | blob.download_to_file(file_obj)
73 | log(f"Downloaded blob {source_blob_name} to file-like object.")
74 |
75 | return file_obj
76 |
77 | def download_bucket(self, bucket_name, source_name, f):
78 | while True:
79 | try:
80 | result = self.download_blob_to_stream(bucket_name, source_name, f)
81 | return result
82 | except Exception as e:
83 | log("Failed to download due to exception")
84 | time.sleep(5)
85 |
86 | def download_next(self):
87 | log("Started downloading")
88 | source_name = self.data[self.shard_idx % len(self.data)]
89 | self.shard_idx += 1
90 | log(f" Downloading: {source_name} | Shard_idx: {self.shard_idx}")
91 |
92 | self.process_path = f"{self.base_process_path}_{self.id}_{self.shard_idx}"
93 | with open(self.process_path, "wb") as f:
94 | result = self.download_bucket(self.bucket_name, source_name, f)
95 | log(f"Done downloading {result}")
96 |
97 | def load_next_shard(self):
98 | self.download_next()
99 |
100 | def process_prev():
101 | log(f"Processing shard at {self.process_path}\n\n")
102 |             try:
103 |                 data = np.load(self.process_path)
104 |             except Exception as e:
105 |                 raise RuntimeError(f"couldn't load shard at {self.process_path}") from e
106 | self.dataset = data[:-1]
107 | self.labels = data[1:]
108 |
109 | len_dataset = self.dataset.shape[0]
110 | max_batches = len_dataset // (self.batch_size * self.T * self.dp)
111 |
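            # layout: (steps, microbatch, dp * batch_size / microbatch, T); main.py shards this
            # with P(None, "pp", "dp", "tp"), so axis 1 splits across pipeline stages, axis 2
            # across data-parallel replicas, and axis 3 (the sequence) across tensor-parallel shards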
112 | self.dataset = self.dataset[
113 | : max_batches * self.batch_size * self.T * self.dp
114 | ].reshape(
115 | max_batches,
116 | self.microbatch,
117 | (self.dp * self.batch_size) // self.microbatch,
118 | self.T,
119 | )
120 | self.labels = self.labels[
121 | : max_batches * self.batch_size * self.T * self.dp
122 | ].reshape(
123 | max_batches,
124 | self.microbatch,
125 | (self.dp * self.batch_size) // self.microbatch,
126 | self.T,
127 | )
128 |
129 | self.dataset = jax.device_put(self.dataset, self.partition)
130 | self.labels = jax.device_put(self.labels, self.partition)
131 |
132 | process_prev()
133 |
134 | os.remove(self.process_path)
135 |
136 | def __len__(self):
137 | return self.dataset.shape[0]
138 |
139 | def __call__(self, step=1):
140 | if self.step_idx + step > self.dataset.shape[0]:
141 | self.step_idx = 0
142 | self.load_next_shard()
143 |
144 | x = self.dataset[self.step_idx : self.step_idx + step]
145 | y = self.labels[self.step_idx : self.step_idx + step]
146 | self.step_idx += step
147 |
148 | return x, y
149 |
150 | @classmethod
151 | def getDataset(
152 | cls,
153 | cfg: dataConfig,
154 | partition: Optional[NamedSharding] = None,
155 | dp: int = 1,
156 | pp: int = 1,
157 | tp: int = 1,
158 | ) -> Tuple["Dataset", "Dataset"]:
159 | assert (cfg.T % tp) == 0, "T should be divisible by tensor parallelism"
160 | train_dataset = cls(
161 | cfg.process_path,
162 | cfg.T,
163 | cfg.train_batch_size,
164 | cfg.micro_batch_size,
165 | partition=partition,
166 | dp=dp,
167 | pp=pp,
168 | bucket_name=cfg.bucket_name,
169 | id=cfg.train_folder_name,
170 | )
171 | val_dataset = cls(
172 | cfg.process_path,
173 | cfg.T,
174 | cfg.val_batch_size,
175 | cfg.micro_batch_size,
176 | partition=partition,
177 | dp=dp,
178 | pp=pp,
179 | bucket_name=cfg.bucket_name,
180 | id=cfg.val_folder_name,
181 | )
182 |
183 | return train_dataset, val_dataset
184 |
185 | @property
186 | def tokens_per_step(self):
187 | return self.dp * self.batch_size * self.T
188 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from dataclasses import dataclass, field
3 | from typing import List, Optional
4 | import jax.numpy as jnp
5 | from jax.numpy import dtype
6 | import json
7 |
8 |
9 | @dataclass
10 | class modelConfig:
11 | """model config class"""
12 |
13 | model_dimension: int
14 | vocab_size: int
15 | n_head: int
16 | blocks: int
17 | layers_per_block: int
18 | T: int
19 | latent_dim: int
20 | dhR: int
21 | dropout_rate: float = 0.1
22 | model_dtype: str = "bfloat16"
23 | k: int = 0
24 | n_experts: int = 0
25 | n_shared: int = 0
26 | capacity_factor: float = 1.0
27 |
28 |
29 | @dataclass
30 | class dataConfig:
31 | bucket_name: str
32 | process_path: str = "./bucket_downloads/processShard"
33 | train_folder_name: str = "train"
34 | val_folder_name: str = "val"
35 | T: int = 6
36 | train_batch_size: int = 3
37 | val_batch_size: int = 3
38 | micro_batch_size: int = 1
39 |
40 |
41 | @dataclass
42 | class LRConfig:
43 | max_lr: float
44 | min_lr: float
45 | end_lr: float
46 | warmup_steps: int
47 | end_steps: int
48 |
49 |
50 | @dataclass
51 | class deviceConfig:
52 | n_device_axis: List[int]
53 |
54 |
55 | @dataclass
56 | class inferenceConfig:
57 | prompt: Optional[str] = None
58 | batch_size: int = 1
59 | top_k: int = 10000
60 | temperature: float = 1.0
61 | n_devices: int = 1
62 | max_tokens: int = 256
63 | use_cache: bool = True
64 |
65 |
66 | @dataclass
67 | class config:
68 | model_config: modelConfig
69 | data_config: dataConfig
70 | lr: LRConfig
71 | device_config: deviceConfig
72 | inference_config: inferenceConfig
73 | output_dir: str
74 | training_steps: int
75 | name: str
76 | grad_step: int = 1
77 | alpha: float = 0.001
78 | checkpoint_steps: int = 10
79 | eval_steps: int = 25
80 | seed: int = 0
81 | wandb: bool = True
82 | grad_clip_norm: float = 1.0
83 |
84 |
85 | def parse_args():
86 | parser = argparse.ArgumentParser(description="model training")
87 |
88 | parser.add_argument("--model_dimension", type=int, default=768)
89 | parser.add_argument("--vocab_size", type=int, default=100277)
90 | parser.add_argument("--n_head", type=int, default=12)
91 | parser.add_argument("--blocks", type=int, default=8)
92 | parser.add_argument("--layers_per_block", type=int, default=6)
93 | parser.add_argument("--T", type=int, default=1024)
94 | parser.add_argument("--latent_dim", type=int, default=128)
95 | parser.add_argument("--dhR", type=int, default=128)
96 | parser.add_argument("--dropout_rate", type=float, default=0.2)
97 | parser.add_argument("--model_dtype", type=str, default="bfloat16")
98 | parser.add_argument("--k", type=int, default=2)
99 | parser.add_argument("--n_experts", type=int, default=16)
100 | parser.add_argument("--n_shared", type=int, default=2)
101 | parser.add_argument("--capacity_factor", type=float, default=1.5)
102 |
103 | parser.add_argument(
104 | "--bucket_name",
105 | type=str,
106 | default="350bt_gpt4",
107 | )
108 | parser.add_argument(
109 | "--process_path", type=str, default="./bucket_downloads/processShard"
110 | )
111 | parser.add_argument("--train_folder_name", type=str, default="train")
112 | parser.add_argument("--val_folder_name", type=str, default="val")
113 | parser.add_argument("--train_batch_size", type=int, default=16)
114 | parser.add_argument("--val_batch_size", type=int, default=16)
115 | parser.add_argument("--micro_batch_size", type=int, default=4)
116 |
117 | parser.add_argument("--max_lr", type=float, default=6e-4)
118 | parser.add_argument("--min_lr", type=float, default=0)
119 | parser.add_argument("--end_lr", type=float, default=6e-5)
120 | parser.add_argument("--warmup_steps", type=int, default=5000)
121 | parser.add_argument("--end_steps", type=int, default=75000)
122 |
123 | parser.add_argument("--alpha", type=float, default=0.0001)
124 | parser.add_argument("--name", type=str, default=None, required=True)
125 | parser.add_argument("--output_dir", type=str, default="gs://results_jaxformer/")
126 | parser.add_argument("--checkpoint_steps", type=int, default=100)
127 | parser.add_argument("--seed", type=int, default=0)
128 | parser.add_argument("--wandb", action="store_true")
129 |
130 | parser.add_argument("--training_steps", type=int, default=100000)
131 | parser.add_argument("--grad_step", type=int, default=1)
132 | parser.add_argument("--eval_steps", type=int, default=25)
133 | parser.add_argument("--grad_clip_norm", type=float, default=1.0)
134 |
135 | parser.add_argument(
136 | "--n_device_axis",
137 | type=int,
138 | nargs="*",
139 | default=[1],
140 | )
141 |
142 | parser.add_argument("--inference_batch", type=int, default=1)
143 | parser.add_argument("--top_k", type=int, default=10000)
144 | parser.add_argument("--temperature", type=float, default=1.0)
145 | parser.add_argument("--use_cache", action="store_true")
146 | parser.add_argument("--max_tokens", type=int, default=10)
147 | parser.add_argument("--prompt", type=str, default="hello world")
148 | parser.add_argument("--n_devices", type=int, default=1)
149 |
150 | args = parser.parse_args()
151 |
152 | model_cfg = modelConfig(
153 | model_dimension=args.model_dimension,
154 | vocab_size=args.vocab_size,
155 | n_head=args.n_head,
156 | blocks=args.blocks,
157 | layers_per_block=args.layers_per_block,
158 | T=args.T,
159 | latent_dim=args.latent_dim,
160 | dhR=args.dhR,
161 | dropout_rate=args.dropout_rate,
162 | model_dtype=args.model_dtype,
163 | k=args.k,
164 | n_experts=args.n_experts,
165 | n_shared=args.n_shared,
166 | capacity_factor=args.capacity_factor,
167 | )
168 |
169 | data_cfg = dataConfig(
170 | bucket_name=args.bucket_name,
171 | process_path=args.process_path,
172 | train_folder_name=args.train_folder_name,
173 | val_folder_name=args.val_folder_name,
174 | T=args.T,
175 | train_batch_size=args.train_batch_size,
176 | val_batch_size=args.val_batch_size,
177 | micro_batch_size=args.micro_batch_size,
178 | )
179 |
180 | lr_cfg = LRConfig(
181 | max_lr=args.max_lr,
182 | min_lr=args.min_lr,
183 | end_lr=args.end_lr,
184 | warmup_steps=args.warmup_steps,
185 | end_steps=args.end_steps,
186 | )
187 |
188 | device_cfg = deviceConfig(
189 | n_device_axis=args.n_device_axis,
190 | )
191 |
192 | inference_cfg = inferenceConfig(
193 | prompt=args.prompt,
194 | batch_size=args.inference_batch,
195 | top_k=args.top_k,
196 | temperature=args.temperature,
197 | n_devices=args.n_devices,
198 | max_tokens=args.max_tokens,
199 | use_cache=args.use_cache,
200 | )
201 |
202 | cfg = config(
203 | model_config=model_cfg,
204 | data_config=data_cfg,
205 | lr=lr_cfg,
206 | name=args.name,
207 | output_dir=args.output_dir,
208 | device_config=device_cfg,
209 | checkpoint_steps=args.checkpoint_steps,
210 | inference_config=inference_cfg,
211 | seed=args.seed,
212 | training_steps=args.training_steps,
213 | grad_step=args.grad_step,
214 | eval_steps=args.eval_steps,
215 | alpha=args.alpha,
216 | wandb=args.wandb,
217 | grad_clip_norm=args.grad_clip_norm,
218 | )
219 |
220 | return cfg
221 |
--------------------------------------------------------------------------------
/data/ray_distributed/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import multiprocessing as mp
4 | import numpy as np
5 | import tiktoken
6 | from datasets import load_dataset
7 | from tqdm import tqdm
8 | from google.cloud.storage import Client, transfer_manager
9 | import argparse
10 | import ray
11 |
12 | # init ray in the cluster mode
13 | ray.init(address="auto")
14 |
15 | # constants for splits and multiprocessing
16 | TEST_SPLIT = 350
17 | BUCKET_NAME = "ray_jaxformer"
18 | BATCH_SIZE = 512
19 | WORKERS = int(os.cpu_count())
20 | nprocs = max(1, int(os.cpu_count() / 1.5))
21 |
22 | # other constants for dataset processing
23 | local_dir = "data_dir"
24 | remote_name = "sample-350BT"
25 | shard_size = int(1e8)
26 |
27 | # gcp storage client and bucket
28 | storage_client = Client()
29 | bucket = storage_client.bucket(BUCKET_NAME)
30 |
31 | # create the local cache and checkpoint directories if they don't exist yet
32 | DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir)
33 | checkpoint_dir = os.path.join(os.path.dirname(__file__), 'checkpoints')
34 | os.makedirs(DATA_CACHE_DIR, exist_ok=True)
35 | os.makedirs(checkpoint_dir, exist_ok=True)
36 |
37 | # set up argument parser to check if --continue flag is given
38 | def setup_argument_parser():
39 | parser = argparse.ArgumentParser(description='Process the 350BT dataset')
40 | parser.add_argument('--continue', dest='continue_processing', action='store_true',
41 | help='Continue processing from a checkpoint')
42 | parser.set_defaults(continue_processing=False)
43 | return parser
44 |
45 | parser = setup_argument_parser()
46 | args = parser.parse_args()
47 | continue_processing = args.continue_processing
48 | checkpoint_to_resume = None
49 | shard_to_resume = 0
50 | skip_number = 0
51 |
52 | # if a --continue flag is given, pull latest checkpoint name from gcp bucket called checkpoints
53 | if continue_processing:
54 | # pull latest checkpoint name from gcp bucket called checkpoints
55 | blobs = bucket.list_blobs(prefix="checkpoints/")
56 | checkpoint_blobs = [b for b in blobs if str(b.name).endswith(".txt")]
57 | if not checkpoint_blobs:
58 | continue_processing = False
59 | else:
60 | latest_checkpoint = max(checkpoint_blobs, key=lambda b: b.updated)
61 | checkpoint_to_resume = latest_checkpoint.name[len("checkpoints/"):-4] # remove 'checkpoints/' prefix and '.txt' suffix
62 | shard_to_resume, skip_number = map(int, (latest_checkpoint.download_as_bytes().decode('utf-8')).split(':'))
63 |
64 | # ------------------------------------------
65 |
66 | fw = load_dataset("HuggingFaceFW/fineweb-edu", name=remote_name, split="train", streaming=True)
67 |
68 | # init the tokenizer
69 | enc = tiktoken.encoding_for_model("gpt-4") # 'cl100k_base'
70 | eot = enc._special_tokens['<|endoftext|>'] # end of text token
71 |
72 | # tokenize function with ray remote decorator
73 | @ray.remote
74 | def tokenize(doc):
75 | doc_id_return = doc['id']
76 | tokens = [eot]
77 | tokens.extend(enc.encode_ordinary(doc["text"]))
78 | tokens_np = np.array(tokens)
79 | assert (0 <= tokens_np).all() and (tokens_np < 2**32).all(), "token dictionary too large for uint32"
80 | tokens_np_uint32 = tokens_np.astype(np.uint32)
81 | return tokens_np_uint32, doc_id_return
82 |
83 | def write_datafile(filename, tokens_np):
84 | np.save(filename, tokens_np)
85 |
86 | # function to upload files to gcp bucket using transfer manager
87 | def upload_file(split):
88 | def upload_many_blobs_with_transfer_manager(split, filenames, source_directory="", workers=8):
89 |
90 | blob_names = [split + name for name in filenames]
91 |
92 | blob_file_pairs = [(os.path.join(source_directory, f), bucket.blob(b)) for f, b in zip(filenames, blob_names)]
93 |
94 | results = transfer_manager.upload_many(
95 | blob_file_pairs, skip_if_exists=True, max_workers=workers, worker_type=transfer_manager.THREAD
96 | )
97 |
98 | FILE_NAMES = os.listdir(DATA_CACHE_DIR)
99 | upload_many_blobs_with_transfer_manager(split, FILE_NAMES, DATA_CACHE_DIR, WORKERS)
100 | for file in FILE_NAMES:
101 | full_path = DATA_CACHE_DIR + '/' + file
102 | os.remove(full_path)
103 |
104 | # function to upload checkpoints to gcp bucket and remove local copies
105 | def upload_checkpoint():
106 | checkpoint_files = os.listdir(checkpoint_dir)
107 | for filename in checkpoint_files:
108 | blob = bucket.blob(f"checkpoints/{filename}")
109 | blob.upload_from_filename(os.path.join(checkpoint_dir, filename))
110 | for filename in checkpoint_files:
111 | os.remove(os.path.join(checkpoint_dir, filename))
112 |
113 | # skip to the previous checkpoint (zero by default)
114 | fw = fw.skip(skip_number)  # skip() returns a new streaming dataset, so the result must be reassigned
115 | shard_index = shard_to_resume + 1 if continue_processing else 0
116 |
117 | # variables to keep track of tokens in the current shard
118 | all_tokens_np = np.empty((shard_size,), dtype=np.uint32)
119 | token_count = 0
120 | progress_bar = None
121 | doc_iter = iter(fw)
122 |
123 | while True:
124 | batch = []
125 | try:
126 | for _ in range(BATCH_SIZE):
127 | batch.append(next(doc_iter))
128 | except StopIteration:
129 | pass
130 |
131 | if not batch:
132 | break
133 |
134 | # get the tokenized results from ray
135 | futures = [tokenize.remote(doc) for doc in batch]
136 | results = ray.get(futures)
137 |
138 | for tokens, doc_id in results:
139 | skip_number += 1
140 |
141 | # if the current document fits in the current shard
142 | if token_count + len(tokens) < shard_size:
143 |
144 | # simply append tokens to current shard
145 | all_tokens_np[token_count:token_count+len(tokens)] = tokens
146 | token_count += len(tokens)
147 |
148 | # update progress bar
149 | if progress_bar is None:
150 | progress_bar = tqdm(total=shard_size, unit="tokens", desc=f"Shard {shard_index}", dynamic_ncols=True)
151 | progress_bar.update(len(tokens))
152 |
153 | else:
154 |
155 | # save a checkpoint for resuming later
156 | checkpoint_filename = os.path.join(checkpoint_dir, f"{doc_id}.txt")
157 | with open(checkpoint_filename, "w") as f:
158 | f.write(str(shard_index) + ':' + str(skip_number))
159 |
160 | # write the current shard and start a new one
161 | if shard_index >= 0 and shard_index < TEST_SPLIT:
162 | split = 'test/'
163 | shard_index_number = shard_index
164 | else:
165 | split = 'train/'
166 | shard_index_number = shard_index - TEST_SPLIT
167 | split_name = split[:-1]
168 |
169 | filename = os.path.join(DATA_CACHE_DIR, f"{split_name}_{shard_index_number:04d}")
170 |
171 | # split the document into whatever fits in this shard; the remainder goes to next one
172 | remainder = shard_size - token_count
173 | progress_bar.update(remainder)
174 | all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]
175 |
176 | write_datafile(filename, all_tokens_np)
177 | upload_file(split)
178 | upload_checkpoint()
179 |
180 | # update shard index and reset progress bar
181 | shard_index += 1
182 | progress_bar = None
183 |
184 | # populate the next shard with the leftovers of the current doc
185 | all_tokens_np[0:len(tokens)-remainder] = tokens[remainder:]
186 | token_count = len(tokens)-remainder
187 |
188 | # write any remaining tokens as the last shard
189 | if token_count != 0:
190 | if shard_index >= 0 and shard_index < TEST_SPLIT:
191 | split = 'test/'
192 | shard_index_number = shard_index
193 | else:
194 | split = 'train/'
195 | shard_index_number = shard_index - TEST_SPLIT
196 | split_name = split[:-1]
197 |
198 | filename = os.path.join(DATA_CACHE_DIR, f"{split_name}_{shard_index_number:04d}")
199 |
200 | write_datafile(filename, all_tokens_np[:token_count])
201 | upload_file(split)
202 | upload_checkpoint()
203 |
204 |
205 | # clean up directory after function terminates
206 | if os.path.exists(DATA_CACHE_DIR):
207 | shutil.rmtree(DATA_CACHE_DIR)
208 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["XLA_FLAGS"] = (
4 | "--xla_gpu_triton_gemm_any=True --xla_gpu_enable_latency_hiding_scheduler=true "
5 | )
6 |
7 | import jax
8 | import jax.numpy as jnp
9 |
10 | jax.config.update("jax_compilation_cache_dir", "gs://jaxformer-cache/")
11 | jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
12 | jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
13 | jax.config.update(
14 | "jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir"
15 | )
16 |
17 | import optax
18 | import numpy as np
19 | import time
20 | import json
21 | import wandb
22 | import orbax.checkpoint as ocp
23 |
24 | from dataclasses import asdict
25 | from jax.sharding import PartitionSpec as P
26 | from functools import partial
27 | from typing import Tuple
28 | from model import shardedModel
29 | from dataset import Dataset
30 | from utils import parse_args, config
31 |
32 |
33 | def log(msg: str):
34 | if jax.process_index() == 0:
35 | print(msg)
36 |
37 |
38 | def init_devices(
39 | axes: Tuple[int, ...], axes_name: Tuple[str, ...]
40 | ) -> jax.sharding.Mesh:
41 | devices = np.array(jax.devices())
42 | for idx in np.ndindex(devices.shape):
43 | d = devices[idx]
44 | log(
45 | f" {idx} ID: {d.id}, Process: {d.process_index}, "
46 | f"Coords: {d.coords}, Core: {d.core_on_chip}"
47 | )
48 |
49 | assert devices.size == np.prod(axes), (
50 | f"Expected {np.prod(axes)} devices, got {devices.shape[0]}"
51 | )
52 | try:
53 | mesh = jax.make_mesh(axes, axes_name)
54 |     except Exception:
55 | log("Failed to create mesh with make_mesh, falling back to sharding.Mesh")
56 | mesh = jax.sharding.Mesh(devices.reshape(axes), axes_name)
57 | return mesh
58 |
59 |
60 | def main(cfg: config):
61 | key = jax.random.PRNGKey(cfg.seed)
62 | DATA_PARALLEL, LAYER_PARALLEL, TENSOR_PARALLEL = cfg.device_config.n_device_axis
63 |
64 | axes = (*cfg.device_config.n_device_axis,)
65 | axes_name = ("dp", "pp", "tp")
66 |
67 | mesh = init_devices(axes, axes_name)
68 | log(mesh)
69 |
70 | checkpoint_dir = cfg.output_dir + cfg.name
71 | options = ocp.CheckpointManagerOptions(max_to_keep=1)
72 | checkpoint_manager = ocp.CheckpointManager(checkpoint_dir, options=options)
73 | load = checkpoint_manager.latest_step() is not None
74 |
75 | data_spec = P(None, "pp", "dp", "tp")
76 | data_partition = jax.sharding.NamedSharding(mesh, data_spec)
77 |
78 | train_dataset, val_dataset = Dataset.getDataset(
79 | cfg.data_config,
80 | partition=data_partition,
81 | dp=DATA_PARALLEL,
82 | pp=LAYER_PARALLEL,
83 | tp=TENSOR_PARALLEL,
84 | )
85 |
86 | model = shardedModel(cfg.model_config)
87 |
88 | log("creating sharded model ...")
89 | key, init_key = jax.random.split(key, 2)
90 | params = model.init_weights(init_key, mesh)
91 |
92 | lr_scheduler = optax.warmup_cosine_decay_schedule(
93 | init_value=cfg.lr.min_lr,
94 | peak_value=cfg.lr.max_lr,
95 | warmup_steps=cfg.lr.warmup_steps,
96 | decay_steps=cfg.lr.end_steps,
97 | end_value=cfg.lr.end_lr,
98 | )
99 |
100 | tx = optax.chain(
101 |         optax.clip_by_global_norm(cfg.grad_clip_norm),  # use the parsed cfg instance, not the config class default
102 | optax.inject_hyperparams(optax.adamw)(learning_rate=lr_scheduler),
103 | )
104 |
105 | default_sharding = jax.sharding.NamedSharding(mesh, P())
106 | opt_state = jax.tree.map(
107 | lambda x: x if jnp.ndim(x) != 0 else jax.device_put(x, default_sharding),
108 | tx.init(params),
109 | )
110 |
111 | init_step = 0
112 | use_wandb = cfg.wandb is True and jax.process_index() == 0
113 | wandb_id = None
114 |
115 | def make_save_tree(step):
116 | model_state = {
117 | "params": params,
118 | "opt_state": opt_state,
119 | }
120 | save_tree = {
121 | "state": model_state,
122 | "key": jax.device_get(key),
123 | "train_step_idx": train_dataset.step_idx,
124 | "train_shard_idx": (train_dataset.shard_idx - 1) % len(train_dataset.data),
125 | "val_step_idx": val_dataset.step_idx,
126 | "val_shard_idx": (val_dataset.shard_idx - 1) % len(val_dataset.data),
127 | "step": step,
128 | }
129 | metadata = {"wandb_id": wandb_id}
130 | return save_tree, metadata
131 |
132 | def save_checkpoint(
133 | step,
134 | ):
135 | save_tree, metadata = make_save_tree(step)
136 | checkpoint_manager.save(
137 | step,
138 | args=ocp.args.Composite(
139 | state=ocp.args.StandardSave(save_tree),
140 | metadata=ocp.args.JsonSave(metadata),
141 | ),
142 | )
143 |
144 | if load:
145 | abstract_tree_state = jax.tree.map(
146 | ocp.utils.to_shape_dtype_struct, make_save_tree(init_step)[0]
147 | )
148 | tree = checkpoint_manager.restore(
149 | checkpoint_manager.latest_step(),
150 | args=ocp.args.Composite(
151 | state=ocp.args.StandardRestore(abstract_tree_state),
152 | metadata=ocp.args.JsonRestore(),
153 | ),
154 | )
155 |
156 | tree_state, tree_metadata = tree.state, tree.metadata
157 |
158 | init_step = tree_state["step"]
159 | log(f"loading checkpoint @ step {init_step}")
160 |
161 |         key = jnp.asarray(tree_state["key"])  # restore the PRNG key saved via jax.device_get
162 | params = tree_state["state"]["params"]
163 | opt_state = tree_state["state"]["opt_state"]
164 |
165 | train_dataset.step_idx = tree_state["train_step_idx"]
166 | train_dataset.shard_idx = tree_state["train_shard_idx"]
167 | train_dataset.load_next_shard()
168 |
169 | val_dataset.step_idx = tree_state["val_step_idx"]
170 | val_dataset.shard_idx = tree_state["val_shard_idx"]
171 | val_dataset.load_next_shard()
172 |
173 | wandb_id = tree_metadata["wandb_id"]
174 | if use_wandb:
175 | assert wandb_id is not None, "wandb_id is None"
176 | wandb.init(
177 | entity="waterloo2",
178 | project="jaxformer",
179 | name=cfg.name,
180 | resume="must",
181 | id=wandb_id,
182 | config=asdict(cfg),
183 | )
184 |
185 | else:
186 | log("no checkpoint found, saving init copy")
187 | if use_wandb:
188 | wandb.init(
189 | entity="waterloo2",
190 | project="jaxformer",
191 | name=cfg.name,
192 | resume="allow",
193 | config=asdict(cfg),
194 | )
195 | wandb_id = wandb.run.id
196 | save_checkpoint(init_step)
197 |
198 | if use_wandb:
199 | table = wandb.Table(
200 | columns=["step"]
201 | + [
202 | f"tokens_{i}"
203 | for i in range(
204 | cfg.inference_config.batch_size
205 | * cfg.inference_config.n_devices
206 | * jax.process_count()
207 | )
208 | ],
209 | log_mode="INCREMENTAL",
210 | )
211 |
212 | param_count, active_param_count = model.param_count(params)
213 | log(f"Total parameters: {param_count:,} with {active_param_count:,} active")
214 |
215 | def step(params, x, y, key, train):
216 | def loss_fn(params, x, y, key):
217 | logits, (_, moe_stat) = model.pipe_step(
218 | params,
219 | x,
220 | key=key,
221 | train=train,
222 | )
223 | log_probs = jax.nn.log_softmax(logits, axis=-1)
224 |
225 | M, B, T, V = logits.shape
226 | y = y.reshape(-1)
227 | log_probs = log_probs.reshape(M * B * T, V)
228 |
229 | loss_idx = lambda x, idx: jax.lax.dynamic_slice(x, (idx,), (1,))
230 | loss_cross = -(jax.vmap(loss_idx, in_axes=(0, 0))(log_probs, y)).mean()
231 |
232 | loss_cross = jax.lax.pmean(loss_cross, axis_name="dp")
233 | loss_cross = jax.lax.pmean(loss_cross, axis_name="tp")
234 | loss_cross = jax.lax.pmean(loss_cross, axis_name="pp")
235 |
236 | loss_balance = 0.0
237 |
238 | moe_stat = jax.tree.map(lambda x: jax.lax.psum(x, axis_name="dp"), moe_stat)
239 | moe_stat = jax.tree.map(lambda x: jax.lax.psum(x, axis_name="tp"), moe_stat)
240 | moe_stat = jax.tree.map(lambda x: jax.lax.psum(x, axis_name="pp"), moe_stat)
241 |
242 | loss_balance = (cfg.model_config.n_experts / cfg.model_config.k) * moe_stat[
243 | "aux_loss"
244 | ].sum()
245 |
246 | loss = loss_cross + cfg.alpha * loss_balance
247 |
248 | metrics = {
249 | "loss": loss,
250 | "loss_cross": loss_cross,
251 | "loss_balance": loss_balance,
252 | "load_expert": moe_stat["tokens_per_expert"],
253 | }
254 |
255 | return loss, metrics
256 |
257 | return loss_fn(params, x, y, key)
258 |
259 | param_spec = shardedModel.get_p_spec(
260 | [model.embedding, model.block], mesh, cfg.model_config
261 | )
262 | opt_spec = jax.tree.map(lambda x: x.sharding.spec, opt_state)
263 | key_spec = P("dp", "pp", "tp")
264 |
265 | @jax.jit
266 | @partial(
267 | jax.shard_map,
268 | mesh=mesh,
269 | in_specs=(param_spec, opt_spec, data_spec, data_spec, key_spec),
270 | out_specs=(param_spec, opt_spec, P()),
271 | check_vma=False,
272 | )
273 | def train_step(params, opt_state, x, y, key):
274 | step_fn = jax.value_and_grad(step, has_aux=True)
275 |
276 | def single_step(grads, batch):
277 | (_, metrics), grads_current = step_fn(params, *batch, train=True)
278 | grads = jax.tree.map(lambda x, y: x + y, grads, grads_current)
279 | return grads, metrics
280 |
281 | grads = jax.tree.map(lambda x: jnp.zeros_like(x), params)
282 | key = key.reshape(cfg.grad_step, 2)
283 |
284 | grads, metrics = jax.lax.scan(
285 | single_step,
286 | grads,
287 | (x, y, key),
288 | )
289 |
290 | grads = jax.tree.map(lambda x: x / cfg.grad_step, grads)
291 |
292 | metrics = jax.tree.map(lambda x: x.mean(axis=0), metrics)
293 |
294 | updates, opt_state = tx.update(grads, opt_state, params)
295 | params = optax.apply_updates(params, updates)
296 |
297 | return params, opt_state, metrics
298 |
299 | @jax.jit
300 | @partial(
301 | jax.shard_map,
302 | mesh=mesh,
303 | in_specs=(param_spec, data_spec, data_spec),
304 | out_specs=P(),
305 | check_vma=False,
306 | )
307 | def eval_step(params, x, y):
308 | def single_step(_, batch):
309 | loss, metrics = step(
310 | params, *batch, key=jax.random.PRNGKey(0), train=False
311 | ) # Key does not matter
312 | return loss, metrics
313 |
314 | _, metrics = jax.lax.scan(single_step, 0, (x, y))
315 | metrics = jax.tree.map(lambda x: x.mean(axis=0), metrics)
316 | return metrics
317 |
318 | total_steps = cfg.training_steps
319 | total_tokens = train_dataset.tokens_per_step * cfg.grad_step
320 |
321 | jax.experimental.multihost_utils.sync_global_devices("sync")
322 | log(f"Total steps: {total_steps}")
323 | log(f"Total tokens per step: {total_tokens:,}")
324 |
325 | key, sample_key = jax.random.split(key, 2)
326 | start = time.time()
327 | train_loss = []
328 | wandb_log_array = []
329 |
330 | @partial(jax.jit, static_argnames=["steps"])
331 | def make_sharded_key(key, steps=1):
332 | key = jax.random.split(
333 | key, DATA_PARALLEL * LAYER_PARALLEL * TENSOR_PARALLEL * steps
334 | )
335 | key = jnp.asarray(key).reshape(
336 | (DATA_PARALLEL, LAYER_PARALLEL, TENSOR_PARALLEL, steps, 2)
337 | )
338 | return key
339 |
340 | for current_step in range(init_step, total_steps):
341 | key, train_key = jax.random.split(key)
342 | train_key = make_sharded_key(train_key, steps=cfg.grad_step)
343 |
344 | x, y = train_dataset(step=cfg.grad_step)
345 |
346 | params, opt_state, metrics = train_step(params, opt_state, x, y, train_key)
347 | train_loss.append(metrics["loss"])
348 |
349 | if use_wandb:
350 | wandb_log = {
351 | "step": current_step,
352 | "loss/train_loss": metrics["loss"],
353 | "loss/train_cross_entropy_loss": metrics["loss_cross"],
354 | "lr": opt_state[1].hyperparams["learning_rate"],
355 | }
356 | wandb_log["loss/load_loss"] = metrics["loss_balance"]
357 | for h in range(cfg.model_config.n_experts):
358 | wandb_log[f"load/head_{h}"] = jax.device_get(metrics[f"load_expert"])[h]
359 |
360 | if current_step % cfg.checkpoint_steps == 0:
361 | time_per_batch = time.time() - start
362 | eval_x, eval_y = val_dataset(step=cfg.eval_steps)
363 | val_metrics = eval_step(params, eval_x, eval_y)
364 |
365 | if use_wandb:
366 | wandb_log["loss/val_loss"] = val_metrics["loss"]
367 | wandb_log["loss/val_cross_entropy_loss"] = val_metrics["loss_cross"]
368 | wandb_log["loss/val_load_loss"] = val_metrics["loss_balance"]
369 | for h in range(cfg.model_config.n_experts):
370 | wandb_log[f"load/head_{h}"] = jax.device_get(
371 | val_metrics[f"load_expert"]
372 | )[h]
373 |
374 | jax.experimental.multihost_utils.sync_global_devices("sync")
375 |
376 | tokens_per_second = cfg.checkpoint_steps * total_tokens / time_per_batch
377 | train_loss = jnp.array(train_loss).mean().item()
378 | eval_loss = val_metrics["loss"].item()
379 | log_string = f"Step {current_step + 1}, Loss: {train_loss:.4f}, Eval Loss: {eval_loss:.4f}, tk/s: {tokens_per_second:,.2f}"
380 | log(log_string)
381 |
382 | start = time.time()
383 | train_loss = []
384 |
385 | if current_step % (10 * cfg.checkpoint_steps) == 0:
386 | outputs = model.generate(
387 | params,
388 | cfg.model_config,
389 | key=sample_key,
390 | x=cfg.inference_config.prompt,
391 | B=cfg.inference_config.batch_size,
392 | k=cfg.inference_config.top_k,
393 | temperature=cfg.inference_config.temperature,
394 | n_devices=cfg.inference_config.n_devices,
395 | use_cache=cfg.inference_config.use_cache,
396 | )
397 |
398 | log("Generated outputs:")
399 | for output in outputs:
400 | log(f"\t{output}")
401 |
402 | if jax.process_index() == 0:
403 | save_path = os.path.join(os.path.abspath("./samples"), cfg.name)
404 | if not os.path.exists(save_path):
405 | os.makedirs(save_path)
406 | with open(
407 | os.path.join(save_path, "tokens.txt"),
408 | "a",
409 | ) as f:
410 | f.write(f"{current_step} | {outputs}\n")
411 |
412 | if use_wandb:
413 | table.add_data(current_step, *outputs)
414 | wandb_log["inference_tokens"] = table
415 |
416 | save_checkpoint(current_step)
417 | gen_time = time.time() - start
418 | log(f"Generation time: {gen_time:.4f} seconds")
419 | start = time.time()
420 |
421 | if use_wandb:
422 | wandb_log_array.append({"data": wandb_log, "step": current_step})
423 | if current_step % (10 * cfg.checkpoint_steps) == 0:
424 | for log_entry in wandb_log_array:
425 | wandb.log(data=log_entry["data"], step=log_entry["step"])
426 | wandb_log_array = []
427 |
428 | if use_wandb:
429 | wandb.finish()
430 |
431 |
432 | if __name__ == "__main__":
433 | jax.distributed.initialize()
434 | cfg = parse_args()
435 | print(json.dumps(cfg.__dict__, indent=4, default=lambda o: o.__dict__))
436 | main(cfg)
437 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax.sharding import PartitionSpec as P
4 | from flax import linen as nn
5 | from einops import rearrange
6 |
7 | import tiktoken
8 | from utils import modelConfig
9 | import numpy as np
10 | from typing import Optional, Tuple, List
11 | from jaxtyping import Array, PyTree
12 | from functools import partial
13 |
14 | import time
15 |
16 | cache_type = Tuple[Optional[Array], Optional[Array]]
17 | dtype_map = {
18 | "bfloat16": jnp.bfloat16,
19 | "float32": jnp.float32,
20 | "float16": jnp.float16,
21 | "int32": jnp.int32,
22 | "int64": jnp.int64,
23 | }
24 |
25 |
26 | def convert_dtype(dtype_str):
27 | if dtype_str in dtype_map:
28 | return dtype_map[dtype_str]
29 | else:
30 | raise ValueError(f"Unsupported dtype: {dtype_str}")
31 |
32 |
33 | class Dense(nn.Module):
34 | features: int
35 | dtype: jnp.dtype = jnp.float32
36 |
37 | @nn.compact
38 | def __call__(self, x: Array) -> Array:
39 | if self.is_mutable_collection("params"):
40 | kernel = self.param(
41 | "kernel",
42 | nn.initializers.lecun_normal(),
43 | (x.shape[-1], self.features),
44 | jnp.float32,
45 | )
46 | else:
47 | kernel = self.scope.get_variable("params", "kernel")
48 | kernel = jax.lax.all_gather(kernel, "dp", axis=-1, tiled=True)  # re-gather the FSDP ("dp") shards of the kernel's output features
49 |
50 | bias = self.param("bias", nn.initializers.zeros, (self.features,), jnp.float32)
51 | x, kernel, bias = jax.tree.map(
52 | lambda x: x.astype(self.dtype), (x, kernel, bias)
53 | )
54 |
55 | x = jnp.einsum("...d,df->...f", x, kernel)  # partial product over this shard's input slice
56 | tensor_size = jax.lax.psum(1, axis_name="tp")
57 | x = x + (1 / tensor_size) * bias  # scale so the bias is added exactly once after the psum below
58 | x = jax.lax.psum_scatter(x, "tp", scatter_dimension=x.ndim - 1, tiled=True)  # sum partials, keep this shard's slice of the output features
59 |
60 | return x
61 |
62 |
63 | class FeedForward(nn.Module):
64 | model_dimension: int
65 | dropout_rate: float = 0.1
66 | model_dtype: jnp.dtype = jnp.bfloat16
67 |
68 | @nn.compact
69 | def __call__(self, x: Array, train=True) -> Array:
70 | x = Dense(features=self.model_dimension * 4, dtype=self.model_dtype)(x)
71 | x = nn.gelu(x)
72 | x = Dense(features=self.model_dimension, dtype=self.model_dtype)(x)
73 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
74 | return x
75 |
76 |
77 | class RMSNorm(nn.Module):
78 | model_dtype: jnp.dtype = jnp.float32
79 |
80 | @nn.compact
81 | def __call__(self, x: Array) -> Array:
82 | x_type = x.dtype
83 | x = x.astype(jnp.float32)
84 |
85 | rms = jnp.sum(jnp.square(x), axis=-1, keepdims=True)
86 | rms = jax.lax.psum(rms, axis_name="tp")
87 | rms = rms / jax.lax.psum(x.shape[-1], axis_name="tp")
88 |
89 | x = x / jnp.sqrt(rms + 1e-6)
90 | x = x.astype(x_type)
91 |
92 | gamma = self.param(
93 | "gamma", nn.initializers.ones, (1, 1, x.shape[-1]), jnp.float32
94 | )
95 | beta = self.param(
96 | "beta", nn.initializers.zeros, (1, 1, x.shape[-1]), jnp.float32
97 | )
98 |
99 | x, gamma, beta = jax.tree.map(
100 | lambda x: x.astype(self.model_dtype), (x, gamma, beta)
101 | )
102 |
103 | x = x * gamma + beta
104 |
105 | return x
106 |
107 |
108 | class NoisyKGate(nn.Module):
109 | n_experts: int
110 | k: int
111 | model_dtype: jnp.dtype
112 |
113 | def setup(self):
114 | self.centroids = Dense(features=self.n_experts, dtype=self.model_dtype)
115 |
116 | def top(self, x: Array) -> Tuple[Array, Array]:
117 | assert x.shape[0] == self.n_experts, "x must be of shape (n_experts, )"
118 | g_i, i = jax.lax.top_k(x, self.k)
119 | g = g_i / jnp.sum(g_i, axis=-1)
120 |
121 | return g, i
122 |
123 | def __call__(self, x: Array) -> Tuple[Array, Array, Array]:
124 | local_scores = nn.sigmoid(self.centroids(x))
125 |
126 | scores = jax.lax.all_gather(
127 | local_scores,
128 | "tp",
129 | axis=x.ndim - 1,
130 | tiled=True,
131 | ) # (B, T, n_experts) gathered across "tp"
132 | g_scores, indices = jnp.apply_along_axis(func1d=self.top, axis=-1, arr=scores)
133 |
134 | return g_scores, indices, scores
135 |
136 |
137 | class MoE(nn.Module):
138 | model_dimension: int
139 | n_shared: int
140 | n_experts: int
141 | k: int
142 | dropout_rate: float
143 | capacity_factor: float = 1.0
144 | model_dtype: jnp.dtype = jnp.float32
145 |
146 | @nn.compact
147 | def __call__(self, x, train=True):
148 | B, T, C = x.shape
149 |
150 | shared = Dense(
151 | features=self.model_dimension * self.n_shared,
152 | dtype=self.model_dtype,
153 | )
154 |
155 | res_shared = shared(x)
156 | res_shared = rearrange(res_shared, "B T (n d) -> B T n d", n=self.n_shared)
157 | res_shared = jnp.sum(res_shared, axis=2) # (B, T, n, d) -> (B, T, d)
158 |
159 | router = NoisyKGate(
160 | n_experts=self.n_experts,
161 | k=self.k,
162 | model_dtype=self.model_dtype,
163 | )
164 | g_scores, indices, scores = router(x) # (B, T, k), (B, T, k), (B, T, n_experts)
165 |
166 | capacity = B * T
167 | if train:
168 | capacity = int(capacity * self.capacity_factor / self.n_experts)
169 |
170 | expert_inputs, score_mask, tokens_per_expert = self.scatter(
171 | x, g_scores, indices, capacity
172 | ) # (e, c, d) , (B * T, e, c), (e,)
173 |
174 | expert = FeedForward(
175 | model_dimension=self.model_dimension,
176 | dropout_rate=self.dropout_rate,
177 | model_dtype=self.model_dtype,
178 | )
179 |
180 | expert_outputs = nn.vmap(
181 | lambda expert, inp: expert(inp, train=train),
182 | in_axes=(0),
183 | out_axes=(0),
184 | variable_axes={"params": 0},
185 | split_rngs={"params": True, "dropout": True},
186 | )(expert, expert_inputs)
187 |
188 | # combine expert outputs, weighting each capacity slot by its routing score
189 | expert_outputs = jnp.einsum("ecd,tec->td", expert_outputs, score_mask)
190 | expert_outputs = expert_outputs.reshape(B, T, C)
191 |
192 | f, p = self.auxiliary_loss(scores, indices)
193 |
194 | aux = {"tokens_per_expert": tokens_per_expert, "f": f, "p": p}
195 |
196 | x = res_shared + expert_outputs
197 |
198 | return x, aux
199 |
200 | def scatter(
201 | self, x: Array, scores: Array, indices: Array, capacity: int
202 | ) -> Tuple[Array, Array, Array]:
203 | B, T, C = x.shape
204 | x = x.reshape(B * T, C)
205 | scores = scores.reshape(B * T, self.k)
206 | indices = indices.reshape(B * T, self.k)
207 |
208 | # sort tokens by their top-choice routing score so the highest-scoring
209 | # tokens are dispatched to their experts first
210 | sorted_token_idx = jnp.argsort(-scores[:, 0], axis=0)
211 | sorted_indices = jnp.take_along_axis(indices, sorted_token_idx[:, None], axis=0)
212 | sorted_scores = jnp.take_along_axis(scores, sorted_token_idx[:, None], axis=0)
213 |
214 | # swapping axes groups assignments by routing choice, highest score first:
215 | # choice_1: [b_1, b_2, .. b_{B * T }], choice_2: [b_1, b_2, .. b_{B * T }], ...
216 | # flatten to get the expert indices in dispatch order
217 | flat_indices = jnp.swapaxes(sorted_indices, 0, 1).reshape(-1)
218 | flat_scores = jnp.swapaxes(sorted_scores, 0, 1).reshape(-1)
219 |
220 | # convert to one hot encoding
221 | # then multiply to get the score for each instead of 1
222 | expert_onehot = jax.nn.one_hot(
223 | flat_indices, self.n_experts, dtype=jnp.int32
224 | ) # (B*T*k, n_experts)
225 | expert_scores = flat_scores[:, None] * expert_onehot # (B*T*k, n_experts)
226 |
227 | position_in_expert = (
228 | jnp.cumsum(expert_onehot, axis=0) * expert_onehot
229 | ) # get which position it is in the expert
230 | # the max position per expert is its total token count (the final cumsum value)
231 | tokens_per_expert = jnp.max(position_in_expert, axis=0) / (
232 | B * T
233 | ) # normalise by the number of tokens
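    | # illustrative walk-through (hypothetical numbers, not from any config): with
    | # B*T = 3 tokens, k = 1, n_experts = 2 and routing indices [1, 0, 1],
    | # expert_onehot is [[0,1],[1,0],[0,1]]; cumsum * onehot gives positions
    | # [[0,1],[1,0],[0,2]], so expert 0 holds 1 token, expert 1 holds 2, and
    | # tokens_per_expert becomes [1/3, 2/3]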
234 |
235 | # reshape back to (k, B*T, n_experts): for each routing choice and token,
236 | # the one-hot row over experts encodes which slot that token occupies in its expert
237 | # same layout for expert_scores
238 | position_in_expert = position_in_expert.reshape(self.k, B * T, self.n_experts)
239 | expert_scores = expert_scores.reshape(self.k, B * T, self.n_experts)
240 |
241 | # go back to the original (B*T, k, n_experts) layout
242 | position_in_expert = jnp.swapaxes(
243 | position_in_expert, 0, 1
244 | ) # (B*T, k, n_experts)
245 | expert_scores = jnp.swapaxes(expert_scores, 0, 1) # (B*T, k, n_experts)
246 |
247 | # for every token, each expert column is non-zero for at most one of the k choices,
248 | # so a max over the choice axis recovers that token's slot in each expert
249 | final_pos = jnp.max(position_in_expert, axis=1) - 1 # make it 0 indexed
250 | final_scores = jnp.max(expert_scores, axis=1) # do the same for the score
251 |
252 | # unsort the indices
253 | unsorted_indices = jnp.argsort(sorted_token_idx)
254 | final_pos = jnp.take_along_axis(final_pos, unsorted_indices[:, None], axis=0)
255 | final_scores = jnp.take_along_axis(
256 | final_scores, unsorted_indices[:, None], axis=0
257 | )
258 | # final_pos is back in the original token order; each entry is the token's slot within its expert
259 | # slots that overflow the capacity, and the -1 for experts a token was not routed to,
260 | # one-hot encode to all zeros, so the dispatch mask tells us, per token and expert,
261 | # which capacity slot the token lands in, if any
262 | dispatch_mask = jax.nn.one_hot(
263 | final_pos, capacity, dtype=jnp.int32
264 | ) # (B*T, n_experts, capacity)
266 | # weight every capacity slot by its routing score;
267 | # broadcasting is safe since at most one slot per (token, expert) is non-zero
267 | scores_mask = (
268 | dispatch_mask * final_scores[..., None]
269 | ) # (B*T, n_experts, capacity)
270 |
271 | # at most one token occupies each (expert, capacity) slot,
272 | # so summing over the token dim gathers the inputs into (experts, capacity, dim)
273 | expert_inputs = jnp.einsum("bd,bec->ecd", x, dispatch_mask)
274 |
275 | return expert_inputs, scores_mask, tokens_per_expert
276 |
277 | def auxiliary_loss(self, scores: Array, indices: Array) -> Tuple[Array, Array]:
278 | B, T, n_experts = scores.shape
279 |
280 | scores = scores / jnp.sum(scores, axis=-1, keepdims=True)
281 | scores = scores.reshape(B * T, n_experts)
282 | p = jnp.sum(scores, axis=0) / (B * T)
283 |
284 | total_batch = B * T * self.k
285 | indices = indices.reshape(total_batch)
286 | f = jax.nn.one_hot(indices, n_experts, dtype=jnp.float32)
287 | f = jnp.sum(f, axis=0) / (B * T)
288 |
289 | return f, p
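    | # f counts, per expert, how many of the top-k assignments it receives
    | # (normalised by the token count) and p is the mean renormalised router
    | # probability; the pipeline later combines them element-wise (f * p) as its
    | # auxiliary load-balancing statistic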
290 |
291 |
292 | class Embedding(nn.Module):
293 | model_dimension: int
294 | vocab_size: int
295 | model_dtype: jnp.dtype
296 |
297 | def setup(self):
298 | self.embedding = nn.Embed(
299 | num_embeddings=self.vocab_size,
300 | features=self.model_dimension,
301 | dtype=self.model_dtype,
302 | )
303 | self.norm = RMSNorm(model_dtype=self.model_dtype)
304 |
305 | def __call__(self, x: Array, out: bool = False) -> Array:
306 | if not out:
307 | *_, T = x.shape
308 | x = self.embedding(x)
309 | x = jax.lax.all_to_all(
310 | x, "tp", split_axis=x.ndim - 1, concat_axis=x.ndim - 2, tiled=True
311 | )
312 | if self.is_mutable_collection("params"):
313 | _ = self.norm(x)
314 | else:
315 | x = self.norm(x)
316 | x = jax.lax.all_to_all(
317 | x, "tp", split_axis=x.ndim - 2, concat_axis=x.ndim - 1, tiled=True
318 | )
319 | x = self.embedding.attend(x)
320 |
321 | return x
322 |
323 |
324 | class RoPE(nn.Module):
325 | T: int
326 | model_dim: int
327 |
328 | def setup(self):
329 | assert self.model_dim % 2 == 0, "model_dim must be even"
330 |
331 | freq = jnp.arange(self.T, dtype=jnp.float32)[:, None] + 1
332 |
333 | pos = jnp.arange(self.model_dim // 2, dtype=jnp.float32)[:, None]
334 | pos = pos.repeat(2, axis=-1).reshape(1, -1)
335 | log_theta_base = jnp.log(10000.0)
336 | theta = jnp.exp(-2 * pos / self.model_dim * log_theta_base)
337 |
338 | idx = jax.lax.axis_index("tp")
339 | tensor_size = jax.lax.psum(1, axis_name="tp")
340 | slice_factor = self.model_dim // tensor_size
341 |
342 | cos = jnp.cos(freq * theta)
343 | sin = jnp.sin(freq * theta)
344 |
345 | self.cos = jax.lax.dynamic_slice_in_dim(
346 | cos, slice_factor * idx, slice_factor, axis=-1
347 | )
348 | self.sin = jax.lax.dynamic_slice_in_dim(
349 | sin, slice_factor * idx, slice_factor, axis=-1
350 | )
351 |
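    | # __call__ applies the standard RoPE rotation x * cos + rotate_half(x) * sin,
    | # where rotate_half maps feature pairs (x1, x2) to (-x2, x1); cos/sin were
    | # sliced in setup so each "tp" shard only keeps the angles for its feature slice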
352 | def __call__(
353 | self,
354 | x: Array,
355 | t_start: int,
356 | ) -> Array:
357 | B, T, C = x.shape
358 | x_dtype = x.dtype
359 | x = x.astype(jnp.float32)
360 |
361 | cos_rope = x * self.cos[t_start : t_start + T, :]
362 |
363 | x_inter = x.reshape((B, T, C // 2, 2))
364 | x_inter_one = x_inter[..., 0]
365 | x_inter_two = -1 * x_inter[..., 1]
366 | x_inter = jnp.stack([x_inter_two, x_inter_one], axis=-1)  # rotate-half pairs (-x2, x1)
367 |
368 | x_inter = x_inter.reshape((B, T, C))
369 | sin_rope = x_inter * self.sin[t_start : t_start + T, :]
370 |
371 | x = cos_rope + sin_rope
372 | x = x.astype(x_dtype)
373 |
374 | return x
375 |
376 |
377 | class MLA(nn.Module):
378 | model_dimension: int
379 | n_heads: int
380 | T: int
381 | latent_dim: int
382 | dhR: int
383 | model_dtype: jnp.dtype
384 | dropout: float = 0.0
385 |
386 | @nn.compact
387 | def __call__(
388 | self,
389 | x: Array,
390 | *,
391 | KV_cache: Optional[Array] = None,
392 | KR_cache: Optional[Array] = None,
393 | train=True,
394 | ) -> Tuple[Array, Tuple[Optional[Array], Optional[Array]]]:
395 | use_rope = self.dhR > 0
396 |
397 | B, T, C = x.shape
398 |
399 | x = Dense(features=2 * self.latent_dim, dtype=self.model_dtype)(x)
400 | kv_latent, q_latent = jnp.split(x, 2, axis=-1)
401 |
402 | if use_rope:
403 | t_start = KV_cache.shape[1] if KV_cache is not None else 0
404 | x_k_r = Dense(features=self.dhR, dtype=self.model_dtype)(x)
405 | x_q_r = Dense(features=self.dhR * self.n_heads, dtype=self.model_dtype)(x)
406 |
407 | rope_k = RoPE(model_dim=self.dhR, T=self.T)
408 | rope_q = RoPE(
409 | model_dim=self.dhR * self.n_heads,
410 | T=self.T,
411 | )
412 |
413 | kRt = rope_k(x_k_r, t_start)
414 |
415 | qRt = rope_q(x_q_r, t_start)
416 | qRt = rearrange(qRt, "B T (nh d) -> B nh T d", nh=self.n_heads)
417 |
418 | if not train:
419 | if KV_cache is not None:
420 | kv_latent = jnp.concatenate([KV_cache, kv_latent], axis=1)
421 | KV_cache = kv_latent
422 |
423 | if use_rope:
424 | if KR_cache is not None:
425 | kRt = jnp.concatenate([KR_cache, kRt], axis=1)
426 | KR_cache = kRt
427 |
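    | # note: only the low-rank kv_latent and the shared rope key stream are cached,
    | # so decode-time cache memory per layer scales with latent_dim (+ dhR) rather
    | # than the 2 * model_dimension of a conventional K/V cache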
428 | k, v = jnp.split(
429 | Dense(features=2 * self.model_dimension, dtype=self.model_dtype)(kv_latent),
430 | 2,
431 | axis=-1,
432 | )
433 | q = Dense(features=self.model_dimension, dtype=self.model_dtype)(q_latent)
434 |
435 | q, k, v = jax.tree.map(
436 | lambda x: rearrange(x, "B T (nh d) -> B nh T d", nh=self.n_heads), (q, k, v)
437 | )
438 |
439 | q, k, v = jax.tree.map(
440 | lambda x: jax.lax.all_to_all(
441 | x, "tp", split_axis=1, concat_axis=3, tiled=True
442 | ),
443 | (q, k, v),
444 | )
445 |
446 | if use_rope:
447 | qRt = jax.lax.all_to_all(qRt, "tp", split_axis=1, concat_axis=3, tiled=True)
448 | q = jnp.concatenate([q, qRt], axis=-1)
449 |
450 | kRt = jnp.repeat(kRt[:, None, :, :], self.n_heads, axis=1)
451 | kRt = jax.lax.all_to_all(kRt, "tp", split_axis=1, concat_axis=3, tiled=True)
452 | k = jnp.concatenate([k, kRt], axis=-1)
453 |
454 | def scaledDotProd(q, k, v, mask):
455 | input_dtype = q.dtype
456 |
457 | q, k, v = jax.tree.map(lambda x: x.astype(jnp.float32), (q, k, v))
458 | dk = q.shape[-1]
459 |
460 | w = jnp.einsum("B n T d, B n t d -> B n T t", q, k) * (dk**-0.5)
461 | w = jnp.where(mask == 0, -jnp.inf, w)
462 | w = jax.nn.softmax(w, axis=-1)
463 | output = jnp.einsum("B n T t, B n t d -> B n T d", w, v)
464 |
465 | output = output.astype(input_dtype)
466 | return output
467 |
468 | local_n_heads = q.shape[1]
469 | if T == 1:
470 | mask = jnp.ones((B, local_n_heads, 1, k.shape[2]))
471 | else:
472 | mask = jnp.tril(
473 | jnp.ones((B, local_n_heads, q.shape[2], k.shape[2])),
474 | )
475 |
476 | output = scaledDotProd(q, k, v, mask)
477 |
478 | output = jax.lax.all_to_all(
479 | output, "tp", split_axis=3, concat_axis=1, tiled=True
480 | )
481 | output = rearrange(output, "B nh T dk -> B T (nh dk)")
482 |
483 | output = Dense(features=self.model_dimension, dtype=self.model_dtype)(output)
484 | output = nn.Dropout(rate=self.dropout)(output, deterministic=not train)
485 |
486 | return output, (KV_cache, KR_cache)
487 |
488 |
489 | class Layer(nn.Module):
490 | model_dimension: int
491 | n_heads: int
492 | T: int
493 | latent_dim: int
494 | dhR: int
495 | n_experts: int
496 | k: int
497 | n_shared: int
498 | capacity_factor: float
499 | use_moe: bool = False
500 | dropout_rate: float = 0.1
501 | model_dtype: jnp.dtype = jnp.bfloat16
502 |
503 | @nn.compact
504 | def __call__(
505 | self, x, cache: Optional[cache_type] = None, train=True
506 | ) -> Tuple[Array, cache_type]:
507 | x_res = x
508 |
509 | x = RMSNorm(model_dtype=self.model_dtype)(x)
510 | x, cache = MLA(
511 | model_dimension=self.model_dimension,
512 | n_heads=self.n_heads,
513 | T=self.T,
514 | latent_dim=self.latent_dim,
515 | dhR=self.dhR,
516 | model_dtype=self.model_dtype,
517 | dropout=self.dropout_rate,
518 | )(x, KV_cache=cache[0], KR_cache=cache[1], train=train)
519 | x = x + x_res
520 | x_res = x
521 |
522 | x = RMSNorm(model_dtype=self.model_dtype)(x)
523 | if self.use_moe:
524 | x, aux = MoE(
525 | model_dimension=self.model_dimension,
526 | n_experts=self.n_experts,
527 | k=self.k,
528 | n_shared=self.n_shared,
529 | capacity_factor=self.capacity_factor,
530 | dropout_rate=self.dropout_rate,
531 | model_dtype=self.model_dtype,
532 | )(x, train=train)
533 | else:
534 | x, aux = (
535 | FeedForward(
536 | model_dimension=self.model_dimension,
537 | dropout_rate=self.dropout_rate,
538 | model_dtype=self.model_dtype,
539 | )(x, train=train),
540 | None,
541 | )
542 | x = x + x_res
543 |
544 | return x, (cache, aux)
545 |
546 |
547 | class Block(nn.Module):
548 | layers: int
549 | model_dimension: int
550 | n_heads: int
551 | T: int
552 | latent_dim: int
553 | dhR: int
554 | n_experts: int
555 | k: int
556 | n_shared: int
557 | capacity_factor: float
558 | dropout_rate: float = 0.1
559 | model_dtype: jnp.dtype = jnp.bfloat16
560 |
561 | @nn.compact
562 | def __call__(
563 | self, x, cache: Optional[cache_type] = None, train=True
564 | ) -> Tuple[Array, cache_type]:
565 | KV_cache = []
566 | KR_cache = []
567 | moe_stat = None
568 |
569 | for i in range(self.layers):
570 | current_cache = [None, None]
571 | if cache is not None:
572 | current_cache[0] = cache[0][i]
573 | if i < self.layers - 1:
574 | current_cache[1] = cache[1][i]
575 |
576 | x, (cache_out, aux) = Layer(
577 | model_dimension=self.model_dimension,
578 | n_heads=self.n_heads,
579 | T=self.T,
580 | latent_dim=self.latent_dim,
581 | dhR=self.dhR if i < self.layers - 1 else 0,
582 | n_experts=self.n_experts,
583 | k=self.k,
584 | n_shared=self.n_shared,
585 | capacity_factor=self.capacity_factor,
586 | use_moe=(i == self.layers - 1),
587 | dropout_rate=self.dropout_rate,
588 | model_dtype=self.model_dtype,
589 | )(x, current_cache, train=train)
590 |
591 | if aux is not None:
592 | moe_stat = aux
593 |
594 | ckV, kRT = cache_out
595 | if ckV is not None:
596 | KV_cache.append(ckV)
597 | if kRT is not None:
598 | KR_cache.append(kRT)
599 |
600 | KV_cache = jnp.stack(KV_cache, axis=0) if len(KV_cache) > 0 else None
601 | KR_cache = jnp.stack(KR_cache, axis=0) if len(KR_cache) > 0 else None
602 |
603 | out_cache = (KV_cache, KR_cache)
604 |
605 | return x, (out_cache, moe_stat)
606 |
607 |
608 | class Transformer(nn.Module):
609 | model_dimension: int
610 | vocab_size: int
611 | n_head: int
612 | blocks: int
613 | layers_per_block: int
614 | T: int
615 | latent_dim: int
616 | dhR: int
617 | n_experts: int
618 | k: int
619 | n_shared: int
620 | capacity_factor: float
621 | dropout_rate: float = 0.1
622 | model_dtype: jnp.dtype = jnp.bfloat16
623 |
624 | @nn.compact
625 | def __call__(
626 | self, x, cache: Optional[cache_type] = None, train=True
627 | ) -> Tuple[Array, cache_type]:
628 | if cache is not None:
629 | x = x[..., -1:]
630 |
631 | *B, T = x.shape
632 | x = x.reshape(-1, T)
633 |
634 | embedding = Embedding(
635 | vocab_size=self.vocab_size,
636 | model_dimension=self.model_dimension,
637 | model_dtype=self.model_dtype,
638 | )
639 |
640 | x = embedding(x)
641 |
642 | KV_cache = []
643 | ckRT_cache = []
644 | moe_stat = []
645 |
646 | for i in range(self.blocks):
647 | if cache is None:
648 | layer_cache = None
649 | else:
650 | cKV = cache[0][i]
651 | kRT = cache[1][i] if cache[1] is not None else None
652 | layer_cache = (cKV, kRT)
653 |
654 | x, (cache_out, moe_stat_out) = Block(
655 | layers=self.layers_per_block,
656 | model_dimension=self.model_dimension,
657 | n_heads=self.n_head,
658 | T=self.T,
659 | latent_dim=self.latent_dim,
660 | dhR=self.dhR,
661 | n_experts=self.n_experts,
662 | k=self.k,
663 | n_shared=self.n_shared,
664 | capacity_factor=self.capacity_factor,
665 | dropout_rate=self.dropout_rate,
666 | model_dtype=self.model_dtype,
667 | )(x, layer_cache, train=train)
668 |
669 | if cache_out[0] is not None:
670 | KV_cache.append(cache_out[0])
671 | if cache_out[1] is not None:
672 | ckRT_cache.append(cache_out[1])
673 |
674 | moe_stat.append(moe_stat_out)
675 |
676 | if len(KV_cache) > 0:
677 | KV_cache = jnp.stack(KV_cache, axis=0)
678 | else:
679 | KV_cache = None
680 | if len(ckRT_cache) > 0:
681 | ckRT_cache = jnp.stack(ckRT_cache, axis=0)
682 | else:
683 | ckRT_cache = None
684 | out_cache = (KV_cache, ckRT_cache)
685 |
686 | moe_stat = jax.tree.map(lambda *x: jnp.stack(x, axis=0), *moe_stat)
687 |
688 | x_out = embedding(x, out=True)
689 | x_out = x_out.reshape(*B, T, self.vocab_size)
690 |
691 | return x_out, (out_cache, moe_stat)
692 |
693 | def init_weights(self, key: jax.random.key, mesh: jax.sharding.Mesh) -> PyTree:
694 | params = self.init(key, jnp.ones((1, self.T), dtype=jnp.int32), train=False)[
695 | "params"
696 | ]
697 | p_spec = Transformer.get_p_spec(params)
698 | params = jax.tree.map(
699 | lambda x, y: jax.device_put(x, jax.sharding.NamedSharding(mesh, y)),
700 | params,
701 | p_spec,
702 | )
703 | return params
704 |
705 | @classmethod
706 | def get_model(cls, cfg: modelConfig) -> "Transformer":
707 | return cls(
708 | model_dimension=cfg.model_dimension,
709 | vocab_size=cfg.vocab_size,
710 | n_head=cfg.n_head,
711 | blocks=cfg.blocks,
712 | layers_per_block=cfg.layers_per_block,
713 | T=cfg.T,
714 | latent_dim=cfg.latent_dim,
715 | dhR=cfg.dhR,
716 | n_experts=cfg.n_experts,
717 | k=cfg.k,
718 | n_shared=cfg.n_shared,
719 | capacity_factor=cfg.capacity_factor,
720 | dropout_rate=cfg.dropout_rate,
721 | model_dtype=convert_dtype(cfg.model_dtype),
722 | )
723 |
724 | @staticmethod
725 | def get_p_spec(params: PyTree):
726 | return jax.tree.map(
727 | lambda _: P(),
728 | params,
729 | )
730 |
731 | def generate(
732 | self,
733 | params: PyTree,
734 | key: jax.random.key,
735 | x: str = "",
736 | *,
737 | B: int = 1,
738 | k: int = 10000,
739 | temperature: float = 1.0,
740 | max_tokens: int = 100,
741 | use_cache=True,
742 | ) -> List[str]:
743 | enc = tiktoken.encoding_for_model("gpt-4")
744 | out = jnp.array(
745 | [enc._special_tokens["<|endoftext|>"]] if x == "" else enc.encode(x),
746 | dtype=jnp.int32,
747 | )
748 |
749 | prompt_length = out.shape[0]
750 | generation_length = min(max_tokens, self.T - prompt_length)
751 | out = jnp.repeat(out[None, :], B, axis=0)
752 | cache = None
753 |
754 | @jax.jit
755 | def sample(key, params, inp, cache):
756 | logits, cache = self.apply(
757 | {"params": params}, inp, cache=cache, train=False
758 | )
759 | logits = logits[:, -1, :]
760 | logits, idx = jax.lax.top_k(logits, k=k)
761 | logits /= temperature
762 |
763 | out_next_idx = jax.random.categorical(key, logits, axis=-1, shape=(B,))
764 | out_next = idx[jnp.arange(B, dtype=jnp.int32), out_next_idx][:, None]
765 |
766 | return out_next, (cache, logits)
767 |
768 | for _ in range(generation_length):
769 | start_time = time.time()
770 | if not use_cache:
771 | cache = None
772 | key, sample_key = jax.random.split(key)
773 | out_next, (cache, _logits) = sample(sample_key, params, out, cache)
774 | out = jnp.concatenate([out, out_next], axis=-1)
775 | end_time = time.time()
776 | token_time = end_time - start_time
777 | print(f"Token {_ + 1} generated \t {1 / token_time:.4f} tk/s")
778 |
779 | tokens = jax.device_get(out)
780 | outputs = list(map(lambda x: enc.decode(x), tokens))
781 |
782 | return outputs
783 |
784 |
785 | class shardedModel:
786 | def __init__(self, cfg: modelConfig):
787 | self.dtype = convert_dtype(cfg.model_dtype)
788 | self.embedding = Embedding(
789 | vocab_size=cfg.vocab_size,
790 | model_dimension=cfg.model_dimension,
791 | model_dtype=self.dtype,
792 | )
793 |
794 | self.block = Block(
795 | layers=cfg.layers_per_block,
796 | model_dimension=cfg.model_dimension,
797 | n_heads=cfg.n_head,
798 | T=cfg.T,
799 | latent_dim=cfg.latent_dim,
800 | dhR=cfg.dhR,
801 | n_experts=cfg.n_experts,
802 | k=cfg.k,
803 | n_shared=cfg.n_shared,
804 | capacity_factor=cfg.capacity_factor,
805 | dropout_rate=cfg.dropout_rate,
806 | model_dtype=self.dtype,
807 | )
808 |
809 | self.cfg = cfg
810 |
811 | def init_weights(self, key, mesh):
812 | out_spec = shardedModel.get_p_spec([self.embedding, self.block], mesh, self.cfg)
813 |
814 | def replace_fsdp(p: jax.sharding.PartitionSpec):
815 | if p[-1] == "dp":
816 | p = P(*p[:-1], None)
817 | return p
818 |
819 | out_spec_no_fsdp = jax.tree.map(lambda x: replace_fsdp(x), out_spec)
820 |
821 | x_embed = jnp.ones((1, self.cfg.T), dtype=jnp.int32)
822 | x_layer = jnp.ones((1, self.cfg.T, self.cfg.model_dimension), dtype=self.dtype)
823 |
824 | layer_devices = mesh.devices.shape[1]
825 | tensor_devices = mesh.devices.shape[2]
826 |
827 | assert self.cfg.blocks % layer_devices == 0, (
828 | "Number of blocks must be divisible by number of devices"
829 | )
830 | layers_per_device = self.cfg.blocks // layer_devices
831 |
832 | key, embed_key = jax.random.split(key, 2)
833 | key, *layer_keys = jax.random.split(key, layer_devices * tensor_devices + 1)
834 | layer_keys = jnp.array(layer_keys).reshape(layer_devices, tensor_devices, 2)
835 |
836 | @jax.jit
837 | @partial(
838 | jax.shard_map,
839 | mesh=mesh,
840 | in_specs=(P(None, "tp"), P(None, None, "tp"), P("pp", "tp")),
841 | out_specs=out_spec_no_fsdp,
842 | )
843 | def init_params(x_embed, x_layer, layer_key):
844 | layer_key = layer_key.reshape(
845 | 2,
846 | )
847 | embedding_params = self.embedding.init(embed_key, x_embed, out=False)[
848 | "params"
849 | ]
850 | layer_params = []
851 |
852 | for _ in range(layers_per_device):
853 | layer_key, init_key = jax.random.split(layer_key)
854 | current_params = self.block.init(init_key, x_layer, train=False)[
855 | "params"
856 | ]
857 | layer_params.append(current_params)
858 | layer_params = jax.tree.map(lambda *x: jnp.stack(x, axis=0), *layer_params)
859 |
860 | return embedding_params, layer_params
861 |
862 | params = init_params(x_embed, x_layer, layer_keys)
863 | params = jax.tree.map(
864 | lambda x, y: jax.device_put(x, jax.sharding.NamedSharding(mesh, y)),
865 | params,
866 | out_spec,
867 | )
868 |
869 | return params
870 |
871 | def pipe_step(self, params, x, key, train, cache=None):
872 | embedding_params, layer_params = params
873 |
874 | if cache is not None:
875 | x = x[..., -1:]
876 |
877 | embeddings = self.embedding.apply({"params": embedding_params}, x, out=False)
878 |
879 | layer_fn = lambda x, params, cache, key: self.block.apply(
880 | {"params": params},
881 | x,
882 | cache=cache,
883 | train=train,
884 | rngs={"dropout": key} if train else None,
885 | )
886 |
887 | @partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
888 | def fwd_fn(state_idx, x, params, cache, key):
889 | def grad_fn(stop_grad):
890 | return (
891 | lambda *args: jax.lax.stop_gradient(layer_fn(*args))
892 | if stop_grad
893 | else layer_fn(*args)
894 | )
895 |
896 | fns = [
897 | grad_fn(stop_grad=True),
898 | grad_fn(stop_grad=False),
899 | ]
900 |
901 | out = jax.lax.switch(
902 | state_idx,
903 | fns,
904 | x,
905 | params,
906 | cache,
907 | key,
908 | )
909 |
910 | return out
911 |
912 | layer_out, (out_cache, moe_stat) = self.pipeline(
913 | fwd_fn, layer_params, embeddings, cache, key
914 | )
915 |
916 | logits = self.embedding.apply({"params": embedding_params}, layer_out, out=True)
917 | return logits, (out_cache, moe_stat)
918 |
919 | def pipeline(
920 | self,
921 | fn,
922 | stage_params: PyTree,
923 | inputs: Array,
924 | cache: Optional[Tuple[Array, Optional[Array]]],
925 | key: jax.random.PRNGKey,
926 | ):
927 | device_idx = jax.lax.axis_index("pp")
928 | n_devices = jax.lax.axis_size("pp")
929 | layers_per_device = stage_params["Layer_0"]["MLA_0"]["Dense_0"]["kernel"].shape[
930 | 0
931 | ]
932 | microbatch_per_device = inputs.shape[0]
933 | microbatches = n_devices * microbatch_per_device
934 | layers = layers_per_device * n_devices
935 | outputs = jnp.zeros_like(inputs) * jnp.nan
936 | state = (
937 | jnp.zeros(
938 | (
939 | layers_per_device,
940 | *inputs.shape[1:],
941 | )
942 | )
943 | * jnp.nan
944 | )
945 |
946 | state_idx = jnp.zeros((layers_per_device,), dtype=jnp.int32)
947 | perm = [(i, (i + 1) % n_devices) for i in range(n_devices)]
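    | # schedule sketch (illustrative numbers): with n_devices = 2, layers_per_device = 1
    | # and microbatch_per_device = 2 the loop runs microbatches + layers - 1 = 5 steps;
    | # device 0 feeds a fresh microbatch into state[0] each step while ppermute rotates
    | # the last stage's activations to the next device, so every stage is busy once the
    | # pipeline has filled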
948 |
949 | KV_cache = []
950 | KR_cache = []
951 |
952 | moe_stat = []
953 |
954 | for i in range(microbatches + layers - 1):
955 | batch_idx = i % microbatch_per_device
956 | layer_idx = (i - layers + 1) % microbatch_per_device
957 |
958 | state = state.at[0].set(
959 | jnp.where(device_idx == 0, inputs[batch_idx], state[0])
960 | )
961 | state_idx = state_idx.at[0].set(jnp.where(device_idx == 0, 1, state_idx[0]))
962 |
963 | key, *layer_keys = jax.random.split(key, layers_per_device + 1)
964 | layer_keys = jnp.array(layer_keys)
965 |
966 | current_cache = None
967 | if cache is not None:
968 | current_cache = [cache[0][i], None]
969 | if cache[1] is not None:
970 | current_cache[1] = cache[1][i]
971 |
972 | state, (out_cache, out_moe_stat) = jax.vmap(fn)(
973 | state_idx, state, stage_params, current_cache, layer_keys
974 | )
975 |
976 | if out_cache[0] is not None:
977 | KV_cache.append(out_cache[0])
978 | if out_cache[1] is not None:
979 | KR_cache.append(out_cache[1])
980 | moe_stat.append(out_moe_stat)
981 |
982 | outputs = outputs.at[layer_idx].set(
983 | jnp.where(device_idx == n_devices - 1, state[-1], outputs[layer_idx])
984 | )
985 |
986 | state = jnp.concat(
987 | [jax.lax.ppermute(state[-1], "pp", perm)[None, ...], state[:-1]], axis=0
988 | )
989 | state_idx = jnp.concat(
990 | [
991 | jax.lax.ppermute(state_idx[-1], "pp", perm)[None, ...],
992 | state_idx[:-1],
993 | ],
994 | axis=0,
995 | )
996 |
997 | if batch_idx == microbatch_per_device - 1 and i < microbatches:
998 | inputs = jax.lax.ppermute(inputs, axis_name="pp", perm=perm)
999 |
1000 | if layer_idx == microbatch_per_device - 1 and i >= layers - 1:
1001 | outputs = jax.lax.ppermute(outputs, axis_name="pp", perm=perm)
1002 |
1003 | outputs = jax.lax.ppermute(outputs, "pp", perm)
1004 |
1005 | if len(KV_cache) > 0:
1006 | KV_cache = jnp.stack(KV_cache, axis=0)
1007 | else:
1008 | KV_cache = None
1009 |
1010 | if len(KR_cache) > 0:
1011 | KR_cache = jnp.stack(KR_cache, axis=0)
1012 | else:
1013 | KR_cache = None
1014 | out_cache = (KV_cache, KR_cache)
1015 |
1016 | moe_stat = jax.tree.map(lambda *x: jnp.stack(x, axis=0), *moe_stat)
1017 |
1018 | def slice_moe(x: Array) -> Array:
1019 | def each_layer(layer_idx, x):
1020 | return jax.lax.dynamic_slice_in_dim(
1021 | x, layers_per_device * device_idx + layer_idx, microbatches, axis=0
1022 | )
1023 |
1024 | sliced_x = jax.vmap(each_layer, in_axes=(0, -2), out_axes=(-2))(
1025 | jnp.arange(layers_per_device), x
1026 | )
1027 | return sliced_x
1028 |
1029 | moe_stat = jax.tree.map(
1030 | lambda x: slice_moe(x).mean(axis=0), # mean across microbatches
1031 | moe_stat,
1032 | )
1033 |
1034 | moe_stat = {
1035 | "tokens_per_expert": moe_stat["tokens_per_expert"].sum(
1036 | axis=0
1037 | ), # (experts,)
1038 | "aux_loss": moe_stat["f"] * moe_stat["p"], # (layers_per_device, experts)
1039 | }
1040 |
1041 | return outputs, (out_cache, moe_stat)
1042 |
1043 | def generate(
1044 | self,
1045 | params: PyTree,
1046 | cfg: modelConfig,
1047 | key: jax.random.key,
1048 | x: str = "",
1049 | *,
1050 | B: int = 1,
1051 | k: int = 10000,
1052 | temperature: float = 1.0,
1053 | max_tokens: int = 10,
1054 | n_devices: int = 1,
1055 | use_cache=True,
1056 | ) -> list[str]:
1057 | assert B % n_devices == 0, "Batch size must be divisible by number of devices"
1058 | assert n_devices <= jax.local_device_count(), (
1059 | "Number of devices exceeds available devices"
1060 | )
1061 |
1062 | mesh = jax.make_mesh(
1063 | (1, n_devices, 1),
1064 | axis_names=("dp", "pp", "tp"),
1065 | devices=np.array(jax.local_devices())[:n_devices],
1066 | )
1067 |
1068 | model = shardedModel(cfg)
1069 | out_spec = shardedModel.get_p_spec([model.embedding, model.block], mesh, cfg)
1070 | params = jax.tree.map(
1071 | lambda x, y: jax.device_put(
1072 | jax.experimental.multihost_utils.process_allgather(x, tiled=True),
1073 | jax.sharding.NamedSharding(mesh, y),
1074 | ),
1075 | params,
1076 | out_spec,
1077 | )
1078 | enc = tiktoken.encoding_for_model("gpt-4")
1079 | out = jnp.array(
1080 | [enc._special_tokens["<|endoftext|>"]] if x == "" else enc.encode(x),
1081 | dtype=jnp.int32,
1082 | )
1083 | out = jnp.repeat(out[None, :], B, axis=0).reshape(n_devices, B // n_devices, -1)
1084 |
1085 | prompt_length = out.shape[-1]
1086 | generation_length = min(max_tokens, cfg.T - prompt_length)
1087 |
1088 | generation = jnp.zeros(
1089 | (n_devices, B // n_devices, generation_length + prompt_length),
1090 | dtype=jnp.int32,
1091 | )
1092 | generation = generation.at[:, :, :prompt_length].set(out)
1093 |
1094 | def sample(params, out, cache, sample_key):
1095 | sample_key, pipe_key = jax.random.split(sample_key, 2)
1096 | logits, (cache, _) = shardedModel.pipe_step(
1097 | model, params, out, pipe_key, train=False, cache=cache
1098 | )
1099 |
1100 | logits = logits[:, :, -1, :]
1101 | M, B_sample, _ = logits.shape
1102 | logits = logits.reshape(M * B_sample, -1)
1103 | logits, idx = jax.lax.top_k(logits, k=k)
1104 | logits /= temperature
1105 |
1106 | sample_prob = lambda key, logits, idx: idx[
1107 | jax.random.categorical(key, logits)
1108 | ]
1109 | sample_key = jnp.array(jax.random.split(sample_key, logits.shape[0]))
1110 | out_next = jax.vmap(sample_prob)(sample_key, logits, idx)
1111 | out_next = out_next.reshape(M, B_sample, 1)
1112 |
1113 | return out_next, (cache, logits)
1114 |
1115 | @jax.jit
1116 | @partial(
1117 | jax.shard_map,
1118 | mesh=mesh,
1119 | in_specs=(
1120 | out_spec,
1121 | P(),
1122 | P("dp", "pp", "tp"),
1123 | ),
1124 | out_specs=P("pp", "dp", "tp"),
1125 | )
1126 | def generate_shard(params, generation_buffer, key):
1127 | cache = None
1128 | key = key.reshape(
1129 | 2,
1130 | )
1131 | for idx in range(generation_length):
1132 | if not use_cache:
1133 | cache = None
1134 | key, sample_key = jax.random.split(key)
1135 | current_idx = prompt_length + idx
1136 | out = jax.lax.dynamic_slice_in_dim(
1137 | generation_buffer, 0, current_idx + 1, axis=-1
1138 | )
1139 | out_next, (cache, _logits) = sample(params, out, cache, sample_key)
1140 |
1141 | generation_buffer = generation_buffer.at[
1142 | :, :, current_idx : current_idx + 1
1143 | ].set(out_next)
1144 |
1145 | return generation_buffer[None, None, ...]
1146 |
1147 | key = jax.random.fold_in(key, jax.process_index())
1148 | sample_key = jnp.array(jax.random.split(key, B)).reshape(
1149 | n_devices, B // n_devices, 2
1150 | )
1151 | out = generate_shard(params, generation, sample_key)
1152 |
1153 | tokens = jax.device_get(out)
1154 | tokens = tokens.reshape(-1, tokens.shape[-1])
1155 | tokens = jax.experimental.multihost_utils.process_allgather(tokens, tiled=True)
1156 |
1157 | outputs = [enc.decode(x) for x in tokens]
1158 |
1159 | return outputs
1160 |
1161 | @staticmethod
1162 | def get_p_spec(
1163 | model: Tuple[Embedding, Block], mesh: jax.sharding.Mesh, config: modelConfig
1164 | ) -> Tuple[jax.sharding.NamedSharding, jax.sharding.NamedSharding]:
1165 | T = config.T
1166 | n_devices = mesh.devices.shape[1]
1167 | n_layers = config.blocks
1168 | assert n_layers % n_devices == 0, (
1169 | "Number of layers must be divisible by number of devices"
1170 | )
1171 |
1172 | embed, layer = model
1173 |
1174 | x_embed = jnp.ones((1, T), dtype=jnp.int32)
1175 | x_layer = jnp.ones((1, T, embed.model_dimension), dtype=jnp.float32)
1176 | key = jax.random.PRNGKey(0)
1177 |
1178 | @partial(
1179 | jax.shard_map,
1180 | mesh=mesh,
1181 | in_specs=(P(None, "tp"), P(None, None, "tp")),
1182 | out_specs=(P("pp")),
1183 | )
1184 | def get_var_spec_shard(x_embed, x_layer):
1185 | embed_shape = embed.init(key, x_embed)["params"]
1186 | layer_shape = []
1187 | for _ in range(n_layers // n_devices):
1188 | layer_shape.append(layer.init(key, x_layer, train=False)["params"])
1189 | layer_shape = jax.tree.map(lambda *x: jnp.stack(x, axis=0), *layer_shape)
1190 |
1191 | return embed_shape, layer_shape
1192 |
1193 | eval_shape = jax.eval_shape(
1194 | get_var_spec_shard,
1195 | x_embed,
1196 | x_layer,
1197 | )
1198 |
1199 | join_fn = lambda path: " ".join(i.key for i in path).lower()
1200 |
1201 | def layer_partition(key: Tuple[str, ...], x: Array) -> P:
1202 | path = join_fn(key)
1203 | if "moe" in path and "feedforward" in path:
1204 | if x.ndim == 4:
1205 | return P("pp", None, "tp", "dp")
1206 | if x.ndim == 3:
1207 | return P("pp", None, None)
1208 |
1209 | if "gamma" in path or "beta" in path:
1210 | return P("pp", None, None, "tp")
1211 |
1212 | if x.ndim == 3:
1213 | return P("pp", "tp", "dp")
1214 |
1215 | return P("pp", None)
1216 |
1217 | def embedding_partition(key: Tuple[str, ...], x: Array) -> P:
1218 | path = join_fn(key)
1219 | if "gamma" in path or "beta" in path:
1220 | return P(None, None, "tp")
1221 | return P(*(None for _ in range(x.ndim)))
1222 |
1223 | embed_p_spec = jax.tree.map_with_path(
1224 | embedding_partition,
1225 | eval_shape[0],
1226 | )
1227 |
1228 | layer_p_spec = jax.tree.map_with_path(
1229 | layer_partition,
1230 | eval_shape[1],
1231 | )
1232 |
1233 | return embed_p_spec, layer_p_spec
1234 |
1235 | def param_count(self, params):
1236 | total_params = jax.tree.reduce(
1237 | lambda x, y: x + y.size,
1238 | params,
1239 | 0,
1240 | )
1241 |
1242 | join_fn = lambda path: " ".join(i.key for i in path).lower()
1243 |
1244 | def count_active_params(key, x):
1245 | path = join_fn(key)
1246 | n_elements = x.size
1247 |
1248 | is_expert = "moe" in path and "feedforward" in path
1249 | if is_expert:
1250 | n_elements = n_elements // self.cfg.n_experts * self.cfg.k
1251 |
1252 | return n_elements
1253 |
1254 | active_params_map = jax.tree.map_with_path(count_active_params, params[1])
1255 | active_params = jax.tree.reduce(lambda x, y: x + y, active_params_map, 0)
1256 |
1257 | return total_params, active_params
1258 |
--------------------------------------------------------------------------------