├── .gitignore
├── README.md
├── requirements
│   ├── gpu_requirements.txt
│   └── tpu_requirements.txt
├── scripts
│   ├── transformer
│   │   ├── 1.3b.sh
│   │   ├── 125m.sh
│   │   ├── 350m.sh
│   │   └── 760m.sh
│   ├── ttt_linear
│   │   ├── 1.3b.sh
│   │   ├── 125m.sh
│   │   ├── 350m.sh
│   │   └── 760m.sh
│   └── ttt_mlp
│       ├── 1.3b.sh
│       ├── 125m.sh
│       ├── 350m.sh
│       └── 760m.sh
└── ttt
    ├── README.md
    ├── __init__.py
    ├── dataloader
    │   ├── README.md
    │   ├── __init__.py
    │   ├── language_modeling_hf.py
    │   ├── lm_dataset.py
    │   └── tokenization.py
    ├── infra
    │   ├── __init__.py
    │   ├── checkpoint.py
    │   ├── jax_utils.py
    │   └── optimizers.py
    ├── models
    │   ├── __init__.py
    │   ├── bpt.py
    │   ├── model.py
    │   └── ttt_layer.py
    └── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | launcher/
131 |
132 | # TTT Edits
133 | exp
134 | wandb
135 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning to (Learn at Test Time): RNNs with Expressive Hidden States
2 | [**Paper**](https://arxiv.org/abs/2407.04620)
3 | | [**PyTorch Codebase**](https://github.com/test-time-training/ttt-lm-pytorch)
4 | | [**Setup**](#setup)
5 | | [**Replicating Experiments**](#replicating-experiments)
6 | | [**Model Docs**](ttt/README.md)
7 | | [**Dataset Preparation**](ttt/dataloader/README.md)
8 | | [**Inference Benchmark**](https://github.com/test-time-training/ttt-lm-kernels)
9 |
10 | ## Abstract
11 |
12 | Self-attention performs well in long context but has quadratic complexity. Existing RNN layers
13 | have linear complexity, but their performance in long context is limited by the expressive power
14 | of their hidden state. We propose a new class of sequence modeling layers with linear complexity
15 | and an expressive hidden state. The key idea is to make the hidden state a machine learning
16 | model itself, and the update rule a step of self-supervised learning.
17 |
18 | Since the hidden state is updated by training even on test sequences, our layers are called **Test-Time Training (TTT) layers**.
19 | We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden state is a linear model
20 | and a two-layer MLP respectively.
21 |
22 | ## Setup
23 | This codebase is implemented in [JAX](https://jax.readthedocs.io/en/latest/index.html) and has been tested on both GPUs and Cloud TPU VMs with Python 3.11.
24 |
25 | For a PyTorch model definition, please refer to [this link](https://github.com/test-time-training/ttt-lm-pytorch). For inference kernels, or to replicate speed benchmarks from our paper, please view our [kernel implementations](https://github.com/test-time-training/ttt-lm-kernels).
26 |
27 | ### Environment Installation
28 | To set up and run our code on a (local) GPU machine, we highly recommend using [Anaconda](https://anaconda.com/download) to manage Python dependencies. Install the GPU requirements with:
29 | ```
30 | cd requirements
31 | pip install -r gpu_requirements.txt
32 | ```
33 |
34 | For TPU, please refer to [this link](https://cloud.google.com/tpu/docs/quick-starts) for guidance on creating cloud TPU VMs. Then, run:
35 | ```
36 | cd requirements
37 | pip install -r tpu_requirements.txt
38 | ```
39 |
40 | ### WandB Login
41 | We use WandB for logging training metrics and TTT statistics. After installing the requirements, log in to WandB with:
42 | ```
43 | wandb login
44 | ```
45 |
46 |
47 | ### Dataset Download
48 | Our Llama-2 tokenized datasets are available for download from Google Cloud Buckets:
49 |
50 | ```
51 | gsutil -m cp -r gs://llama-2-pile/* llama-2-pile/
52 | gsutil -m cp -r gs://llama-2-books3/* llama-2-books3/
53 | ```
54 |
55 | Once downloaded, set the `dataset_path` flag in `train.py` to the directory containing the `tokenizer_name-meta-llama` folder so that the dataloader can locate the tokenized data (see the sketch below).
56 |
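As a rough sketch of the expected layout (assuming the default `val_ratio`/`val_split_seed` and the Llama-2 tokenizer; the exact leaf folder name is produced by `_cache_dir_name` in `ttt/dataloader/language_modeling_hf.py` and may differ):

```
llama-2-pile/
└── tokenizer_name-meta-llama/
    └── Llama-2-7b-hf-val_ratio-0.0005-val_split_seed-2357-add_eos-True-detokenize-False/
        ├── train.npy
        ├── validation.npy
        ├── test.npy
        └── tokenizer.pkl
```

In this example you would pass `--dataset_path=/path/to/llama-2-pile` (path illustrative) when launching `train.py`.
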
57 | Alternatively, to tokenize datasets yourself, refer to [dataset preparation](ttt/dataloader/README.md).
58 |
59 | ## Replicating Experiments
60 | We provide scripts corresponding to each experiment in our paper in the `scripts` folder. After specifying the experiment name and directory, select the desired context length and set the batch size to 0.5 million divided by that context length, so that every batch contains roughly 0.5 million tokens (see the sketch below).
61 |
62 | Depending on the model size, you may need to modify the `mesh_dim` to introduce model sharding. See the [model docs](ttt/README.md) for additional information on the training configuration.
63 |
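As a minimal sketch (values illustrative; the provided scripts ship with the 2048-token setting), adjusting a script for a longer context while keeping roughly 0.5 million tokens per batch looks like:

```
# Product should equal 0.5 million tokens
SEQ_LEN=8192
BS=64    # 8192 * 64 = 524,288, same as the default 2048 * 256
```
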
64 | ## Credits
65 | * This codebase is based on [EasyLM](https://github.com/young-geng/EasyLM).
66 | * Our dataloader is based on [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main/training).
67 |
--------------------------------------------------------------------------------
/requirements/gpu_requirements.txt:
--------------------------------------------------------------------------------
1 | -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
2 | numpy==1.26.4
3 | matplotlib==3.9.0
4 | tqdm
5 | --extra-index-url https://download.pytorch.org/whl/cpu
6 | torch==2.3.0
7 | jax==0.4.14
8 | jaxlib==0.4.14+cuda12.cudnn89
9 | flax==0.7.0
10 | optax==0.1.7
11 | transformers==4.41.0
12 | datasets
13 | mlxu>=0.1.13
14 | einops
15 | ml_collections
16 | scipy==1.12.0
--------------------------------------------------------------------------------
/requirements/tpu_requirements.txt:
--------------------------------------------------------------------------------
1 | -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
2 | jax[tpu]==0.4.14
3 | numpy==1.26.4
4 | matplotlib
5 | tqdm
6 | --extra-index-url https://download.pytorch.org/whl/cpu
7 | torch==2.3.0
8 | flax==0.7.0
9 | optax==0.1.7
10 | transformers==4.41.0
11 | datasets
12 | mlxu>=0.1.13
13 | einops
14 | ml_collections
15 | scipy==1.12.0
--------------------------------------------------------------------------------
/scripts/transformer/1.3b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_PATH=TODO
4 | DATA_NAME="the_pile" # "books3"
5 |
6 | # Product should equal 0.5 million
7 | SEQ_LEN=2048
8 | BS=256
9 |
10 | # Experiment details
11 | EXP_NAME=TODO
12 | EXP_DIR=TODO
13 |
14 | sudo mkdir -p ${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME};
15 | cd ../..
16 |
17 | python3 -m ttt.train \
18 | --mesh_dim='!-1,1,1' \
19 | --dtype='fp32' \
20 | --total_steps=50000 \
21 | --save_checkpoint_freq=1000 \
22 | --save_milestone_freq=2000 \
23 | --load_model_config='1b' \
24 | --dataset_path=${DATA_PATH} \
25 | --dataset_name=${DATA_NAME} \
26 | --seq_length=${SEQ_LEN} \
27 | --global_batch_size=${BS} \
28 | --optimizer.type='adamw' \
29 | --optimizer.adamw_optimizer.weight_decay=0.1 \
30 | --optimizer.adamw_optimizer.lr=1e-3 \
31 | --optimizer.adamw_optimizer.end_lr=1e-5 \
32 | --optimizer.adamw_optimizer.lr_warmup_steps=5000 \
33 | --optimizer.adamw_optimizer.lr_decay_steps=50000 \
34 | --exp_dir=${EXP_DIR} \
35 | --exp_name=${EXP_NAME}
--------------------------------------------------------------------------------
/scripts/transformer/125m.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_PATH=TODO
4 | DATA_NAME="the_pile" # "books3"
5 |
6 | # Product should equal 0.5 million
7 | SEQ_LEN=2048
8 | BS=256
9 |
10 | # Experiment details
11 | EXP_NAME=TODO
12 | EXP_DIR=TODO
13 |
14 | sudo mkdir -p ${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME};
15 | cd ../..
16 |
17 | python3 -m ttt.train \
18 | --mesh_dim='!-1,1,1' \
19 | --dtype='fp32' \
20 | --total_steps=4800 \
21 | --save_checkpoint_freq=1000 \
22 | --save_milestone_freq=2000 \
23 | --load_model_config='125m' \
24 | --dataset_path=${DATA_PATH} \
25 | --dataset_name=${DATA_NAME} \
26 | --seq_length=${SEQ_LEN} \
27 | --global_batch_size=${BS} \
28 | --optimizer.type='adamw' \
29 | --optimizer.adamw_optimizer.weight_decay=0.1 \
30 | --optimizer.adamw_optimizer.lr=3e-3 \
31 | --optimizer.adamw_optimizer.end_lr=1e-5 \
32 | --optimizer.adamw_optimizer.lr_warmup_steps=480 \
33 | --optimizer.adamw_optimizer.lr_decay_steps=4800 \
34 | --exp_dir=${EXP_DIR} \
35 | --exp_name=${EXP_NAME}
--------------------------------------------------------------------------------
/scripts/transformer/350m.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_PATH=TODO
4 | DATA_NAME="the_pile" # "books3"
5 |
6 | # Product should equal 0.5 million
7 | SEQ_LEN=2048
8 | BS=256
9 |
10 | # Experiment details
11 | EXP_NAME=TODO
12 | EXP_DIR=TODO
13 |
14 | sudo mkdir -p ${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME};
15 | cd ../..
16 |
17 | python3 -m ttt.train \
18 | --mesh_dim='!-1,1,1' \
19 | --dtype='fp32' \
20 | --total_steps=13500 \
21 | --save_checkpoint_freq=1000 \
22 | --save_milestone_freq=2000 \
23 | --load_model_config='350m' \
24 | --dataset_path=${DATA_PATH} \
25 | --dataset_name=${DATA_NAME} \
26 | --seq_length=${SEQ_LEN} \
27 | --global_batch_size=${BS} \
28 | --optimizer.type='adamw' \
29 | --optimizer.adamw_optimizer.weight_decay=0.1 \
30 | --optimizer.adamw_optimizer.lr=1.5e-3 \
31 | --optimizer.adamw_optimizer.end_lr=1e-5 \
32 | --optimizer.adamw_optimizer.lr_warmup_steps=1350 \
33 | --optimizer.adamw_optimizer.lr_decay_steps=13500 \
34 | --exp_dir=${EXP_DIR} \
35 | --exp_name=${EXP_NAME}
--------------------------------------------------------------------------------
/scripts/transformer/760m.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_PATH=TODO
4 | DATA_NAME="the_pile" # "books3"
5 |
6 | # Product should equal 0.5 million
7 | SEQ_LEN=2048
8 | BS=256
9 |
10 | # Experiment details
11 | EXP_NAME=TODO
12 | EXP_DIR=TODO
13 |
14 | sudo mkdir -p ${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME};
15 | cd ../..
16 |
17 | python3 -m ttt.train \
18 | --mesh_dim='!-1,1,1' \
19 | --dtype='fp32' \
20 | --total_steps=29000 \
21 | --save_checkpoint_freq=1000 \
22 | --save_milestone_freq=2000 \
23 | --load_model_config='760m' \
24 | --dataset_path=${DATA_PATH} \
25 | --dataset_name=${DATA_NAME} \
26 | --seq_length=${SEQ_LEN} \
27 | --global_batch_size=${BS} \
28 | --optimizer.type='adamw' \
29 | --optimizer.adamw_optimizer.weight_decay=0.1 \
30 | --optimizer.adamw_optimizer.lr=1.25e-3 \
31 | --optimizer.adamw_optimizer.end_lr=1e-5 \
32 | --optimizer.adamw_optimizer.lr_warmup_steps=2900 \
33 | --optimizer.adamw_optimizer.lr_decay_steps=29000 \
34 | --exp_dir=${EXP_DIR} \
35 | --exp_name=${EXP_NAME}
--------------------------------------------------------------------------------
/scripts/ttt_linear/1.3b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_PATH=TODO
4 | DATA_NAME="the_pile" # "books3"
5 |
6 | # Product should equal 0.5 million
7 | SEQ_LEN=2048
8 | BS=256
9 |
10 | # Experiment details
11 | EXP_NAME=TODO
12 | EXP_DIR=TODO
13 |
14 | sudo mkdir -p ${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME};
15 | cd ../..
16 |
17 | python3 -m ttt.train \
18 | --mesh_dim='!-1,1,1' \
19 | --dtype='fp32' \
20 | --total_steps=50000 \
21 | --save_checkpoint_freq=1000 \
22 | --save_milestone_freq=2000 \
23 | --load_model_config='1b-TTT' \
24 | --update_model_config="dict(seq_modeling_block='ttt_linear', ttt_base_lr=1.0)" \
25 | --dataset_path=${DATA_PATH} \
26 | --dataset_name=${DATA_NAME} \
27 | --seq_length=${SEQ_LEN} \
28 | --global_batch_size=${BS} \
29 | --optimizer.type='adamw' \
30 | --optimizer.adamw_optimizer.weight_decay=0.1 \
31 | --optimizer.adamw_optimizer.lr=1e-3 \
32 | --optimizer.adamw_optimizer.end_lr=1e-5 \
33 | --optimizer.adamw_optimizer.lr_warmup_steps=5000 \
34 | --optimizer.adamw_optimizer.lr_decay_steps=50000 \
35 | --exp_dir=${EXP_DIR} \
36 | --exp_name=${EXP_NAME}
--------------------------------------------------------------------------------
/scripts/ttt_linear/125m.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_PATH=TODO
4 | DATA_NAME="the_pile" # "books3"
5 |
6 | # Product should equal 0.5 million
7 | SEQ_LEN=2048
8 | BS=256
9 |
10 | # Experiment details
11 | EXP_NAME=TODO
12 | EXP_DIR=TODO
13 |
14 | sudo mkdir -p ${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME};
15 | cd ../..
16 |
17 | python3 -m ttt.train \
18 | --mesh_dim='!-1,1,1' \
19 | --dtype='fp32' \
20 | --total_steps=4800 \
21 | --save_checkpoint_freq=1000 \
22 | --save_milestone_freq=2000 \
23 | --load_model_config='125m-TTT' \
24 | --update_model_config="dict(seq_modeling_block='ttt_linear', ttt_base_lr=1.0)" \
25 | --dataset_path=${DATA_PATH} \
26 | --dataset_name=${DATA_NAME} \
27 | --seq_length=${SEQ_LEN} \
28 | --global_batch_size=${BS} \
29 | --optimizer.type='adamw' \
30 | --optimizer.adamw_optimizer.weight_decay=0.1 \
31 | --optimizer.adamw_optimizer.lr=3e-3 \
32 | --optimizer.adamw_optimizer.end_lr=1e-5 \
33 | --optimizer.adamw_optimizer.lr_warmup_steps=480 \
34 | --optimizer.adamw_optimizer.lr_decay_steps=4800 \
35 | --exp_dir=${EXP_DIR} \
36 | --exp_name=${EXP_NAME}
--------------------------------------------------------------------------------
/scripts/ttt_linear/350m.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_PATH=TODO
4 | DATA_NAME="the_pile" # "books3"
5 |
6 | # Product should equal 0.5 million
7 | SEQ_LEN=2048
8 | BS=256
9 |
10 | # Experiment details
11 | EXP_NAME=TODO
12 | EXP_DIR=TODO
13 |
14 | sudo mkdir -p ${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME};
15 | cd ../..
16 |
17 | python3 -m ttt.train \
18 | --mesh_dim='!-1,1,1' \
19 | --dtype='fp32' \
20 | --total_steps=13500 \
21 | --save_checkpoint_freq=1000 \
22 | --save_milestone_freq=2000 \
23 | --load_model_config='350m-TTT' \
24 | --update_model_config="dict(seq_modeling_block='ttt_linear', ttt_base_lr=1.0)" \
25 | --dataset_path=${DATA_PATH} \
26 | --dataset_name=${DATA_NAME} \
27 | --seq_length=${SEQ_LEN} \
28 | --global_batch_size=${BS} \
29 | --optimizer.type='adamw' \
30 | --optimizer.adamw_optimizer.weight_decay=0.1 \
31 | --optimizer.adamw_optimizer.lr=1.5e-3 \
32 | --optimizer.adamw_optimizer.end_lr=1e-5 \
33 | --optimizer.adamw_optimizer.lr_warmup_steps=1350 \
34 | --optimizer.adamw_optimizer.lr_decay_steps=13500 \
35 | --exp_dir=${EXP_DIR} \
36 | --exp_name=${EXP_NAME}
--------------------------------------------------------------------------------
/scripts/ttt_linear/760m.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_PATH=TODO
4 | DATA_NAME="the_pile" # "books3"
5 |
6 | # Product should equal 0.5 million
7 | SEQ_LEN=2048
8 | BS=256
9 |
10 | # Experiment details
11 | EXP_NAME=TODO
12 | EXP_DIR=TODO
13 |
14 | sudo mkdir -p ${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME};
15 | cd ../..
16 |
17 | python3 -m ttt.train \
18 | --mesh_dim='!-1,1,1' \
19 | --dtype='fp32' \
20 | --total_steps=29000 \
21 | --save_checkpoint_freq=1000 \
22 | --save_milestone_freq=2000 \
23 | --load_model_config='760m-TTT' \
24 | --update_model_config="dict(seq_modeling_block='ttt_linear', ttt_base_lr=1.0)" \
25 | --dataset_path=${DATA_PATH} \
26 | --dataset_name=${DATA_NAME} \
27 | --seq_length=${SEQ_LEN} \
28 | --global_batch_size=${BS} \
29 | --optimizer.type='adamw' \
30 | --optimizer.adamw_optimizer.weight_decay=0.1 \
31 | --optimizer.adamw_optimizer.lr=1.25e-3 \
32 | --optimizer.adamw_optimizer.end_lr=1e-5 \
33 | --optimizer.adamw_optimizer.lr_warmup_steps=2900 \
34 | --optimizer.adamw_optimizer.lr_decay_steps=29000 \
35 | --exp_dir=${EXP_DIR} \
36 | --exp_name=${EXP_NAME}
--------------------------------------------------------------------------------
/scripts/ttt_mlp/1.3b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_PATH=TODO
4 | DATA_NAME="the_pile" # "books3"
5 |
6 | # Product should equal 0.5 million
7 | SEQ_LEN=2048
8 | BS=256
9 |
10 | # Experiment details
11 | EXP_NAME=TODO
12 | EXP_DIR=TODO
13 |
14 | sudo mkdir -p ${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME};
15 | cd ../..
16 |
17 | python3 -m ttt.train \
18 | --mesh_dim='!-1,1,1' \
19 | --dtype='fp32' \
20 | --total_steps=50000 \
21 | --save_checkpoint_freq=1000 \
22 | --save_milestone_freq=2000 \
23 | --load_model_config='1b-TTT' \
24 | --update_model_config="dict(seq_modeling_block='ttt_mlp', ttt_base_lr=0.1, ttt_base_lr_init=0.01, ttt_base_lr_warmup=5000)" \
25 | --dataset_path=${DATA_PATH} \
26 | --dataset_name=${DATA_NAME} \
27 | --seq_length=${SEQ_LEN} \
28 | --global_batch_size=${BS} \
29 | --optimizer.type='adamw' \
30 | --optimizer.adamw_optimizer.weight_decay=0.1 \
31 | --optimizer.adamw_optimizer.lr=1e-3 \
32 | --optimizer.adamw_optimizer.end_lr=1e-5 \
33 | --optimizer.adamw_optimizer.lr_warmup_steps=5000 \
34 | --optimizer.adamw_optimizer.lr_decay_steps=50000 \
35 | --exp_dir=${EXP_DIR} \
36 | --exp_name=${EXP_NAME}
--------------------------------------------------------------------------------
/scripts/ttt_mlp/125m.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_PATH=TODO
4 | DATA_NAME="the_pile" # "books3"
5 |
6 | # Product should equal 0.5 million
7 | SEQ_LEN=2048
8 | BS=256
9 |
10 | # Experiment details
11 | EXP_NAME=TODO
12 | EXP_DIR=TODO
13 |
14 | sudo mkdir -p ${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME};
15 | cd ../..
16 |
17 | python3 -m ttt.train \
18 | --mesh_dim='!-1,1,1' \
19 | --dtype='fp32' \
20 | --total_steps=4800 \
21 | --save_checkpoint_freq=1000 \
22 | --save_milestone_freq=2000 \
23 | --load_model_config='125m-TTT' \
24 | --update_model_config="dict(seq_modeling_block='ttt_mlp', ttt_base_lr=0.1, ttt_base_lr_init=0.01, ttt_base_lr_warmup=480)" \
25 | --dataset_path=${DATA_PATH} \
26 | --dataset_name=${DATA_NAME} \
27 | --seq_length=${SEQ_LEN} \
28 | --global_batch_size=${BS} \
29 | --optimizer.type='adamw' \
30 | --optimizer.adamw_optimizer.weight_decay=0.1 \
31 | --optimizer.adamw_optimizer.lr=3e-3 \
32 | --optimizer.adamw_optimizer.end_lr=1e-5 \
33 | --optimizer.adamw_optimizer.lr_warmup_steps=480 \
34 | --optimizer.adamw_optimizer.lr_decay_steps=4800 \
35 | --exp_dir=${EXP_DIR} \
36 | --exp_name=${EXP_NAME}
--------------------------------------------------------------------------------
/scripts/ttt_mlp/350m.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_PATH=TODO
4 | DATA_NAME="the_pile" # "books3"
5 |
6 | # Product should equal 0.5 million
7 | SEQ_LEN=2048
8 | BS=256
9 |
10 | # Experiment details
11 | EXP_NAME=TODO
12 | EXP_DIR=TODO
13 |
14 | sudo mkdir -p ${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME};
15 | cd ../..
16 |
17 | python3 -m ttt.train \
18 | --mesh_dim='!-1,1,1' \
19 | --dtype='fp32' \
20 | --total_steps=13500 \
21 | --save_checkpoint_freq=1000 \
22 | --save_milestone_freq=2000 \
23 | --load_model_config='350m-TTT' \
24 | --update_model_config="dict(seq_modeling_block='ttt_mlp', ttt_base_lr=0.1, ttt_base_lr_init=0.01, ttt_base_lr_warmup=1350)" \
25 | --dataset_path=${DATA_PATH} \
26 | --dataset_name=${DATA_NAME} \
27 | --seq_length=${SEQ_LEN} \
28 | --global_batch_size=${BS} \
29 | --optimizer.type='adamw' \
30 | --optimizer.adamw_optimizer.weight_decay=0.1 \
31 | --optimizer.adamw_optimizer.lr=1.5e-3 \
32 | --optimizer.adamw_optimizer.end_lr=1e-5 \
33 | --optimizer.adamw_optimizer.lr_warmup_steps=1350 \
34 | --optimizer.adamw_optimizer.lr_decay_steps=13500 \
35 | --exp_dir=${EXP_DIR} \
36 | --exp_name=${EXP_NAME}
--------------------------------------------------------------------------------
/scripts/ttt_mlp/760m.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_PATH=TODO
4 | DATA_NAME="the_pile" # "books3"
5 |
6 | # Product should equal 0.5 million
7 | SEQ_LEN=2048
8 | BS=256
9 |
10 | # Experiment details
11 | EXP_NAME=TODO
12 | EXP_DIR=TODO
13 |
14 | sudo mkdir -p ${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME};
15 | cd ../..
16 |
17 | python3 -m ttt.train \
18 | --mesh_dim='!-1,1,1' \
19 | --dtype='fp32' \
20 | --total_steps=29000 \
21 | --save_checkpoint_freq=1000 \
22 | --save_milestone_freq=2000 \
23 | --load_model_config='760m-TTT' \
24 | --update_model_config="dict(seq_modeling_block='ttt_mlp', ttt_base_lr=0.1, ttt_base_lr_init=0.01, ttt_base_lr_warmup=2900)" \
25 | --dataset_path=${DATA_PATH} \
26 | --dataset_name=${DATA_NAME} \
27 | --seq_length=${SEQ_LEN} \
28 | --global_batch_size=${BS} \
29 | --optimizer.type='adamw' \
30 | --optimizer.adamw_optimizer.weight_decay=0.1 \
31 | --optimizer.adamw_optimizer.lr=1.25e-3 \
32 | --optimizer.adamw_optimizer.end_lr=1e-5 \
33 | --optimizer.adamw_optimizer.lr_warmup_steps=2900 \
34 | --optimizer.adamw_optimizer.lr_decay_steps=29000 \
35 | --exp_dir=${EXP_DIR} \
36 | --exp_name=${EXP_NAME}
--------------------------------------------------------------------------------
/ttt/README.md:
--------------------------------------------------------------------------------
1 | # Model Documentation
2 |
3 | This codebase is implemented in [JAX](https://jax.readthedocs.io/en/latest/index.html) and is based on [EasyLM](https://github.com/young-geng/EasyLM/tree/main).
4 |
5 | ## Training Flags
6 | - `mesh_dim` specifies the mesh used by JAX to parallelize computation across multiple accelerators and hosts. Please refer to the [EasyLM parallelization documentation](https://github.com/young-geng/EasyLM/blob/main/docs/parallelism.md) for configuration details (see also the sketch at the end of this section).
7 | - `seq_length` and `global_batch_size` determine the total number of tokens per batch (fixed to 0.5 million in our paper).
8 | - `load_model_config` is used to load a default config from `model.py`.
9 | - `update_model_config` is used to update a default config. To update specific keys, pass a dictionary to the flag:
10 |
11 | ```
12 | --update_model_config="dict(seq_modeling_block='ttt_linear', ttt_base_lr=1.0)"
13 | ```
14 |
15 | All additional hyperparameters are specified in Appendix C of our paper.
16 |
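As a rough sketch of `mesh_dim` values (the flag format follows `get_jax_mesh` in `ttt/infra/jax_utils.py`; axis order and names follow EasyLM's convention, so treat the sharded variant as illustrative rather than a tuned setting):

```
# Default in the provided scripts: data parallelism only. -1 fills in all remaining devices,
# and the leading '!' allows splitting a physical mesh axis if needed.
--mesh_dim='!-1,1,1'

# Hypothetical variant for larger models: 4-way sharding on the last mesh axis.
--mesh_dim='!-1,1,4'
```
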
17 | ## Model Flags
18 | All model configuration flags can be found in `model.py`. Here are a few important details to note:
19 |
20 | We implement four TTT choices for the `seq_modeling_block`:
21 | - `ttt_linear` and `ttt_mlp`, which specify TTT layers within the **Mamba backbone**.
22 | - `ttt_linear_base` and `ttt_mlp_base`, which specify TTT layers within the **Transformer backbone**.
23 |
24 | ### TTT LR
25 | - For all `ttt_linear` experiments, `ttt_base_lr` is set to 1.0.
26 | - For all `ttt_mlp` experiments:
27 | - `ttt_base_lr` is set to 0.1
28 | - `ttt_base_lr_init` is set to 0.01
29 |   - `ttt_base_lr_warmup` is set to the total number of outer-loop warmup steps (see the example below).
30 |
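For example, the provided `ttt_mlp` scripts pass the following (the warmup value shown is the 1.3B schedule; set it to your own outer-loop warmup steps):

```
--update_model_config="dict(seq_modeling_block='ttt_mlp', ttt_base_lr=0.1, ttt_base_lr_init=0.01, ttt_base_lr_warmup=5000)"
```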
--------------------------------------------------------------------------------
/ttt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/test-time-training/ttt-lm-jax/ac8cdc8a43b811afe23c27d7e82eed34a747b19c/ttt/__init__.py
--------------------------------------------------------------------------------
/ttt/dataloader/README.md:
--------------------------------------------------------------------------------
1 | # Dataset Preparation
2 |
3 | ## Option 1: Download Pre-Tokenized Datasets (Recommended)
4 |
5 | Our Llama-2 tokenized datasets are available for download from Google Cloud Buckets:
6 |
7 | ```
8 | gsutil -m cp -r gs://llama-2-pile/* llama-2-pile/
9 | gsutil -m cp -r gs://llama-2-books3/* llama-2-books3/
10 | ```
11 |
12 | Once downloaded, set the `dataset_path` flag in `train.py` to the directory containing the `tokenizer_name-meta-llama` folder so that the dataloader can locate the tokenized data.
13 |
14 | ## Option 2: Tokenize Datasets Yourself
15 |
16 | Since the raw Pile and Books3 datasets are no longer publicly available on Hugging Face, we recommend obtaining them by contacting their authors or from the community.
17 |
18 | Before tokenization, set `raw_json_path` and `cache_dir` in `tokenization.py` to the path where the raw dataset (in json format) is stored and where you want to store the tokenized dataset, respectively.
19 |
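For example, the Pile test in `tokenization.py` ships with these placeholder values (replace both paths with your own):

```
cache_dir = Path("/mnt/disks/persistent/the_pile_release")  # where the tokenized dataset will be written
raw_json_path = "/mnt/disks/persistent/PILE"                # where the raw dataset (json) is stored
```
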
20 | Our tokenization script is based on [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main/training#dataset-preparation). Tokenize the raw datasets using the commands below.
21 |
22 | **Pile:**
23 | ```
24 | export PYTHONPATH=$PWD:$PYTHONPATH
25 | pytest -q -s ttt/dataloader/tokenization.py -k "pile"
26 | ```
27 | This takes around 20h on a 64-core CPU. The processed dataset is 716G.
28 |
29 | **Books3:**
30 | ```
31 | export PYTHONPATH=$PWD:$PYTHONPATH
32 | pytest -q -s ttt/dataloader/tokenization.py -k "books"
33 | ```
34 | This takes around 3h on a 64-core CPU. The processed dataset is 61G.
35 |
--------------------------------------------------------------------------------
/ttt/dataloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/test-time-training/ttt-lm-jax/ac8cdc8a43b811afe23c27d7e82eed34a747b19c/ttt/dataloader/__init__.py
--------------------------------------------------------------------------------
/ttt/dataloader/language_modeling_hf.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py
2 | import sys
3 | import os.path as osp
4 | from itertools import chain
5 | from pathlib import Path
6 | import pickle
7 | from typing import Any, List, Union
8 | import subprocess
9 | import mmap
10 |
11 | import numpy as np
12 | from torch.utils.data.dataloader import DataLoader, Dataset
13 | from transformers import AutoTokenizer
14 | from datasets import load_dataset
15 |
16 | from ttt.dataloader.lm_dataset import RandomFaultTolerantSampler, LMDataset
17 | from ttt.infra.jax_utils import master_print
18 |
19 |
20 | class LMDataModule:
21 | def __init__(
22 | self,
23 | dataset_name,
24 | tokenizer_name,
25 | dataset_config_name=None,
26 | max_length=1024,
27 | cache_dir=None,
28 | raw_json_path=None,
29 | val_ratio=0.0005,
30 | val_split_seed=2357,
31 | add_eos=True,
32 | batch_size=32,
33 | batch_size_eval=None,
34 | num_workers=1,
35 | loader_workers=1,
36 | shuffle=False,
37 | pin_memory=False,
38 | drop_last=False,
39 | fault_tolerant=False,
40 | ):
41 | super().__init__()
42 | self.dataset_name = dataset_name
43 | self.dataset_config_name = dataset_config_name
44 | self.tokenizer_name = tokenizer_name
45 | self.cache_dir = None if cache_dir is None else Path(cache_dir).expanduser()
46 | self.raw_json_path = raw_json_path
47 | self.max_length = max_length
48 | self.val_ratio = val_ratio
49 | self.val_split_seed = val_split_seed
50 | self.add_eos = add_eos
51 | self.batch_size = batch_size
52 | self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size
53 | self.num_workers = num_workers
54 | self.loader_workers = loader_workers
55 | self.shuffle = shuffle
56 | self.pin_memory = pin_memory
57 | self.drop_last = drop_last
58 | if fault_tolerant:
59 | assert self.shuffle
60 | self.fault_tolerant = fault_tolerant
61 |
62 | def prepare_data(self):
63 | if self.cache_dir is None:
64 | # Just download the dataset
65 | load_dataset(self.dataset_name, self.dataset_config_name)
66 | else:
67 | # Process the dataset and save it
68 | self.process_dataset()
69 |
70 | def setup(self, stage=None):
71 | if stage == "test" and hasattr(self, "dataset_test"):
72 | return
73 | concat_ids, self.tokenizer = self.process_dataset()
74 | self.vocab_size = len(self.tokenizer)
75 | self.dataset_train, self.dataset_val, self.dataset_test = [
76 | LMDataset(
77 | concat_ids[split], seq_len=self.max_length, llama2=(self.tokenizer_name == "meta-llama/Llama-2-7b-hf")
78 | )
79 | for split in ["train", "validation", "test"]
80 | ]
81 |
82 | def process_dataset(self):
83 | cache_dir = None if self.cache_dir is None else self.cache_dir / self._cache_dir_name
84 |
85 | if cache_dir is not None:
86 | if cache_dir.is_dir():
87 | return self._load_from_cache(cache_dir)
88 |
89 | if self.raw_json_path is not None:
90 | raw_datasets = load_dataset("json", data_files=self.raw_json_path)
91 | else:
92 | raw_datasets = load_dataset(self.dataset_name, self.dataset_config_name)
93 |
94 | # https://github.com/stanford-crfm/mistral/blob/main/src/corpora/auto.py
95 | if "validation" not in raw_datasets:
96 |             assert "train" in raw_datasets, "raw_datasets must contain a train split to create a validation split"
97 | raw_datasets = raw_datasets["train"].train_test_split(
98 | test_size=self.val_ratio,
99 | seed=self.val_split_seed,
100 | shuffle=True, # Otherwise test will be at the end of the dataset
101 | )
102 | raw_datasets["validation"] = raw_datasets["test"]
103 |
104 | tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=True)
105 | # Preprocessing the datasets.
106 | # First we tokenize all the texts.
107 | column_names = raw_datasets["train"].column_names
108 | text_column_name = "text" if "text" in column_names else column_names[0]
109 | if self.add_eos:
110 | add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq
111 | add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs]
112 | tokenize = lambda example: tokenizer(add_eos_batched(example[text_column_name]))
113 | else:
114 | tokenize = lambda example: tokenizer(example[text_column_name])
115 |
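        # Token ids fit in uint16 when the vocab has fewer than 64k entries (Llama-2: 32k), halving the cache size vs int32.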
116 | dtype = np.uint16 if tokenizer.vocab_size < 64 * 1024 else np.int32
117 |
118 | def tokenize_concat(examples):
119 | # We just need 'input_ids', not 'attention_mask' (since it's all 1)
120 | input_ids = np.fromiter(chain(*tokenize(examples)["input_ids"]), dtype=dtype)
121 | # Need to return a list since we're doing batched processing
122 | return {"input_ids": [input_ids], "len": [len(input_ids)]}
123 |
124 | tokenized_datasets = raw_datasets.map(
125 | tokenize_concat,
126 | batched=True,
127 | num_proc=max(self.num_workers, 1),
128 | remove_columns=column_names,
129 | desc="Running tokenizer on dataset",
130 | )
131 |
132 | # Use disk
133 | concat_ids = {}
134 | assert cache_dir is not None
135 | cache_dir.mkdir(parents=True, exist_ok=True)
136 |
137 | def write_ids_to_disk(example, filename):
138 | with open(filename, "r+b") as f:
139 | mm = mmap.mmap(f.fileno(), 0)
140 | start_idx = example["len_offset"] - len(example["input_ids"])
141 | array_len = len(example["input_ids"])
142 | arr = np.ndarray((array_len,), dtype=dtype, buffer=mm, offset=np.dtype(dtype).itemsize * start_idx)
143 | arr[:] = example["input_ids"]
144 | mm.flush()
145 |
146 | for name, ds in tokenized_datasets.items():
147 | tokenized_datasets[name] = ds.add_column("len_offset", np.cumsum(ds["len"]))
148 | array_len = tokenized_datasets[name][-1]["len_offset"]
149 |
150 | filename = cache_dir / f"{name}.bin"
151 |
152 | # Need to create the file with this specific size first
153 | # https://ostechnix.com/create-files-certain-size-linux/
154 | subprocess.run(["truncate", "-s", str(array_len * np.dtype(dtype).itemsize), str(filename)], check=True)
155 |
156 | tokenized_datasets[name].map(
157 | write_ids_to_disk,
158 | fn_kwargs={"filename": filename}, # .bin
159 | batched=False,
160 | num_proc=max(self.num_workers, 1),
161 | desc="Concatenating examples",
162 | )
163 | concat_ids[name] = np.memmap(filename, dtype=dtype, mode="r", shape=(array_len,))
164 |
165 | if cache_dir is not None:
166 | self._save_to_cache(concat_ids, tokenizer, cache_dir)
167 |
168 | for name in concat_ids:
169 | Path(cache_dir / f"{name}.bin").unlink()
170 |
171 | return concat_ids, tokenizer
172 |
173 | def _save_to_cache(self, concat_ids, tokenizer, cache_dir):
174 | cache_dir.mkdir(parents=True, exist_ok=True)
175 | master_print(f"Saving to cache at {str(cache_dir)}")
176 | for k, v in concat_ids.items():
177 | np.save(cache_dir / f"{k}.npy", v)
178 | with open(cache_dir / "tokenizer.pkl", "wb") as f:
179 | pickle.dump(tokenizer, f)
180 |
181 | def _load_from_cache(self, cache_dir):
182 | assert cache_dir.is_dir()
183 | master_print(f"Load from cache at {str(cache_dir)}")
184 | concat_ids = {
185 | split: np.load(cache_dir / f"{split}.npy", mmap_mode="r") for split in ["train", "validation", "test"]
186 | }
187 | with open(cache_dir / "tokenizer.pkl", "rb") as f:
188 | tokenizer = pickle.load(f)
189 | return concat_ids, tokenizer
190 |
191 | @property
192 | def _cache_dir_name(self):
193 | return (
194 | f"tokenizer_name-{self.tokenizer_name}-val_ratio-{self.val_ratio}-"
195 | f"val_split_seed-{self.val_split_seed}-add_eos-{self.add_eos}-detokenize-False"
196 | )
197 |
198 | def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
199 | """The train dataloader"""
200 | if self.shuffle and self.fault_tolerant:
201 | shuffle = False
202 | sampler = RandomFaultTolerantSampler(self.dataset_train)
203 | else:
204 | shuffle = self.shuffle
205 | sampler = None
206 |
207 | return self._data_loader(self.dataset_train, batch_size=self.batch_size, shuffle=shuffle, sampler=sampler)
208 |
209 | def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
210 | """The val dataloader"""
211 | return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval)
212 |
213 | def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
214 | """The test dataloader"""
215 | return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval)
216 |
217 | def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False, sampler=None) -> DataLoader:
218 | return DataLoader(
219 | dataset,
220 | batch_size=batch_size,
221 | num_workers=self.loader_workers,
222 | shuffle=shuffle,
223 | sampler=sampler,
224 | drop_last=self.drop_last,
225 | pin_memory=self.pin_memory,
226 | )
227 |
--------------------------------------------------------------------------------
/ttt/dataloader/lm_dataset.py:
--------------------------------------------------------------------------------
1 | # Inspired by https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt/datasets.py
2 | # Except we don't pad the last block and don't use overlapping eval
3 | # And we return both the input and the target
4 | import math
5 | from typing import Iterator
6 |
7 | import numpy as np
8 | import torch
9 | from torch.utils.data import RandomSampler
10 |
11 |
12 | class RandomFaultTolerantSampler(RandomSampler):
13 | def __init__(self, *args, generator=None, **kwargs):
14 | if generator is None:
15 | seed = int(torch.empty((), dtype=torch.int64).random_().item())
16 | generator = torch.Generator().manual_seed(seed)
17 | super().__init__(*args, generator=generator, **kwargs)
18 | self.counter = 0 # Absolute position of data reading
19 | self.is_rollback = False
20 |         self.state = self.generator.get_state()  # Record the initial generator state, determined by the seed;
21 |         # it must not change until an entire pass over the dataset is done.
22 |         # Given the same seed, the generator state changes deterministically after each torch.randperm.
23 | self.shuffle_log = [{"shuffle_after": self.counter}]
24 |
25 | def state_dict(self):
26 | return {"random_state": self.state, "counter": self.counter, "shuffle_log": self.shuffle_log}
27 |
28 | def load_state_dict(self, state_dict):
29 | self.state = state_dict["random_state"]
30 | self.counter = state_dict["counter"]
31 | if "shuffle_log" in state_dict:
32 | self.shuffle_log = state_dict["shuffle_log"] # A list of shuffle records, each record is a dict
33 |
34 | def update_shuffle_history(self):
35 | self.shuffle_log.append({"shuffle_after": self.counter})
36 |
37 | def go_through_shuffle_history(self):
38 | N = len(self.data_source)
39 | initial_shuffle_after = self.shuffle_log[0]["shuffle_after"]
40 | self.generator.set_state(self.state)
41 | indices = torch.randperm(N, generator=self.generator)
42 |
43 | for shuffle_record in self.shuffle_log[1:]:
44 | shuffle_after = shuffle_record["shuffle_after"]
45 |             new_order = torch.randperm(N - shuffle_after, generator=self.generator)
46 | indices = torch.concatenate([indices[:shuffle_after], indices[shuffle_after:][new_order]])
47 |
48 | return indices
49 |
50 | def __iter__(self) -> Iterator[int]:
51 |
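        # To resume with a rollback (illustrative; the surrounding names are hypothetical):
        #   sampler.load_state_dict(saved_state); sampler.is_rollback = True
        # The next pass over the loader then restarts from `counter` with a reshuffled remaining order.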
52 | if self.is_rollback:
53 | # Before entering __iter__() due to rollback, set loader.sampler.is_rollback = True manually outside
54 |             self.update_shuffle_history()  # Record a shuffle at self.counter: we resume from there, but with a different data order going forward
55 | self.is_rollback = False
56 |
57 | indices = self.go_through_shuffle_history()
58 | indices = indices[self.counter :].tolist()
59 |
60 | for index in indices:
61 | self.counter += 1
62 | yield index
63 |
64 | # End of one loop over the entire dataset
65 | self.counter = 0
66 | self.state = self.generator.get_state() # If have the next epoch, state will definitely be different
67 | self.shuffle_log = [{"shuffle_after": self.counter}]
68 |
69 |
70 | class LMDataset(torch.utils.data.Dataset):
71 | def __init__(self, tokens, seq_len, drop_last=True, llama2=False):
72 | """tokens should be a numpy array"""
73 | self.seq_len = seq_len
74 | ntokens = len(tokens)
75 | if drop_last:
76 | ntokens = ((ntokens - 1) // seq_len) * seq_len + 1
77 | self.ntokens = ntokens
78 | # We're careful not to slice tokens, since it could be a memmap'ed array or H5 dataset,
79 | # and slicing would load it to memory.
80 | self.tokens = tokens
81 | self.total_sequences = math.ceil((self.ntokens - 1) / self.seq_len)
82 | self.llama2 = llama2
83 |
84 | def __len__(self):
85 | return self.total_sequences
86 |
87 | def __getitem__(self, idx):
88 | idx = idx % self.ntokens
89 | start_idx = idx * self.seq_len
90 | seq_len = min(self.seq_len, self.ntokens - 1 - start_idx)
91 | data = torch.as_tensor(self.tokens[start_idx : (start_idx + seq_len + 1)].astype(np.int32))
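        # For Llama-2, token id 1 is BOS (inserted at document boundaries), so it is masked out of the loss targets.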
92 | if self.llama2:
93 | return {
94 | "input_tokens": data[:-1],
95 | "target_tokens": data[1:].clone(),
96 | "loss_masks": (data[1:] != 1).to(torch.float32),
97 | }
98 | else:
99 | return {
100 | "input_tokens": data[:-1],
101 | "target_tokens": data[1:].clone(),
102 | "loss_masks": torch.ones_like(data[:-1], dtype=torch.float32),
103 | }
104 |
--------------------------------------------------------------------------------
/ttt/dataloader/tokenization.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | current_dir = Path(__file__).parent.absolute()
5 |
6 |
7 | import pytest
8 |
9 | import torch
10 |
11 | import dotenv
12 |
13 | from ttt.dataloader.language_modeling_hf import LMDataModule
14 |
15 | # load environment variables from `.env` file if it exists
16 | # recursively searches for `.env` in all folders starting from work dir
17 | dotenv.load_dotenv(override=True)
18 |
19 |
20 | def div_up(x: int, y: int) -> int:
21 | return (x + y - 1) // y
22 |
23 |
24 | # https://stackoverflow.com/questions/1006289/how-to-find-out-the-number-of-cpus-using-python/55423170#55423170
25 | def num_cpu_cores():
26 | try:
27 | import psutil
28 |
29 | return psutil.cpu_count(logical=False)
30 | except ImportError:
31 | return len(os.sched_getaffinity(0))
32 |
33 |
34 | class TestLMDataModule:
35 | def test_the_pile(self):
36 | batch_size = 8
37 | dataset_name = "the_pile"
38 | dataset_config_name = None
39 | cache_dir = Path(
40 | "/mnt/disks/persistent/the_pile_release"
41 | ) # TODO: Fill in your path to save the tokenized dataset
42 | raw_json_path = (
43 | "/mnt/disks/persistent/PILE"
44 | ) # TODO: Fill in your path that already stores the raw dataset in json format
45 | max_length = 2048
46 | num_workers = num_cpu_cores() // 2
47 | datamodule = LMDataModule(
48 | dataset_name,
49 | tokenizer_name="meta-llama/Llama-2-7b-hf",
50 | dataset_config_name=dataset_config_name,
51 | max_length=max_length,
52 | cache_dir=cache_dir,
53 | raw_json_path=raw_json_path,
54 | add_eos=True, # bos is added by default in llama2 tokenizer
55 | batch_size=batch_size,
56 | num_workers=num_workers,
57 | )
58 | datamodule.prepare_data()
59 | datamodule.setup(stage="fit")
60 | train_loader = datamodule.train_dataloader()
61 | val_loader = datamodule.val_dataloader()
62 | datamodule.setup(stage="test")
63 | test_loader = datamodule.test_dataloader()
64 | # Token number of The Pile when tokenized by the llama2 tokenizer
65 | train_len = 383509963636
66 | val_len = 393983786
67 | test_len = 383707892
68 | assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
69 | assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
70 | assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
71 | for loader in [train_loader, val_loader, test_loader]:
72 | x, y = next(iter(loader))
73 | assert x.dim() == 2
74 | assert x.shape == (batch_size, max_length)
75 | assert x.dtype == torch.long
76 | assert torch.allclose(x[:, 1:], y[:, :-1])
77 |
78 | def test_books(self):
79 | batch_size = 8
80 | dataset_name = "books3"
81 | dataset_config_name = None
82 | cache_dir = Path(
83 | "/mnt/disks/persistent/books3_release"
84 | ) # TODO: fill in your path to save the tokenized dataset
85 | raw_json_path = (
86 | "/mnt/disks/persistent/lwm_raw/lwm_text_data/combined_books.jsonl"
87 | ) # TODO: fill in your path that already stores the raw dataset in json format
88 | max_length = 2048
89 | num_workers = 1
90 | datamodule = LMDataModule(
91 | dataset_name,
92 | tokenizer_name="meta-llama/Llama-2-7b-hf",
93 | dataset_config_name=dataset_config_name,
94 | max_length=max_length,
95 | cache_dir=cache_dir,
96 | raw_json_path=raw_json_path,
97 | add_eos=True, # bos is added by default in llama2 tokenizer
98 | batch_size=batch_size,
99 | num_workers=num_workers,
100 | )
101 | datamodule.prepare_data()
102 | datamodule.setup(stage="fit")
103 | train_loader = datamodule.train_dataloader()
104 | val_loader = datamodule.val_dataloader()
105 | datamodule.setup(stage="test")
106 | test_loader = datamodule.test_dataloader()
107 | # Token number of Books3 when tokenized by the llama2 tokenizer
108 | train_len = 32585931901
109 | val_len = 14007763
110 | test_len = 14007763
111 | assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
112 | assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
113 | assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
114 | for loader in [train_loader, val_loader, test_loader]:
115 | x, y = next(iter(loader))
116 | assert x.dim() == 2
117 | assert x.shape == (batch_size, max_length)
118 | assert x.dtype == torch.long
119 | assert torch.allclose(x[:, 1:], y[:, :-1])
120 |
--------------------------------------------------------------------------------
/ttt/infra/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/test-time-training/ttt-lm-jax/ac8cdc8a43b811afe23c27d7e82eed34a747b19c/ttt/infra/__init__.py
--------------------------------------------------------------------------------
/ttt/infra/checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from ml_collections import ConfigDict
4 | import mlxu
5 | import jax
6 | import jax.numpy as jnp
7 | import flax
8 | from flax.serialization import from_bytes, to_bytes, to_state_dict, from_state_dict
9 | from flax.traverse_util import flatten_dict, unflatten_dict, empty_node
10 | import msgpack
11 |
12 | from ttt.infra.jax_utils import tree_apply, float_tensor_to_dtype
13 |
14 |
15 | class StreamingCheckpointer(object):
16 | """Custom msgpack checkpointer that saves large train states by serializing
17 | and saving tensors one by one in a streaming fashion. Avoids running
18 | out of memory or local TPU disk with default flax checkpointer.
19 | """
20 |
21 | @staticmethod
22 | def get_default_config(updates=None):
23 | config = ConfigDict()
24 | config.float_dtype = "bf16"
25 | config.save_optimizer_state = True
26 |
27 | if updates is not None:
28 | config.update(ConfigDict(updates).copy_and_resolve_references())
29 | return config
30 |
31 | def __init__(self, config, checkpoint_dir, enable=True):
32 | self.config = self.get_default_config(config)
33 | self.checkpoint_dir = checkpoint_dir
34 | self.enable = enable
35 |
36 | def save_checkpoint(self, train_state, filename, gather_fns=None):
37 | if self.enable:
38 | path = os.path.join(self.checkpoint_dir, filename)
39 | os.makedirs(os.path.dirname(path), exist_ok=True)
40 | else:
41 | path = "/dev/null"
42 | self.save_train_state_to_file(train_state, path, gather_fns, self.config.float_dtype)
43 |
44 | @staticmethod
45 | def save_train_state_to_file(train_state, path, gather_fns=None, float_dtype=None):
46 | train_state = to_state_dict(train_state)
47 | packer = msgpack.Packer()
48 | flattend_train_state = flatten_dict(train_state)
49 | if gather_fns is not None:
50 | gather_fns = flatten_dict(to_state_dict(gather_fns))
51 |
52 | with mlxu.open_file(path, "wb") as fout:
53 | for key, value in flattend_train_state.items():
54 | if gather_fns is not None:
55 | value = gather_fns[key](value)
56 | value = float_tensor_to_dtype(value, float_dtype)
57 | fout.write(packer.pack((key, to_bytes(value))))
58 |
59 | def save_pickle(self, obj, filename):
60 | if self.enable:
61 | path = os.path.join(self.checkpoint_dir, filename)
62 | os.makedirs(os.path.dirname(path), exist_ok=True)
63 | else:
64 | path = "/dev/null"
65 | mlxu.save_pickle(obj, path)
66 |
67 | def save_all(self, train_state, gather_fns, metadata=None, dataset=None, milestone=False):
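        # Files land under self.checkpoint_dir (summarized from the calls below; the checkpoint file is
        # named streaming_train_state, or streaming_params when save_optimizer_state is disabled):
        #   metadata.pkl, dataset.pkl, <checkpoint_name>                          -- overwritten, for auto-resume
        #   step_<N>/metadata_<N>.pkl, dataset_<N>.pkl, <checkpoint_name>_<N>     -- milestone copies, kept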
68 | step = int(jax.device_get(train_state.step))
69 | if self.config.save_optimizer_state:
70 | checkpoint_state = train_state
71 | checkpoint_name = "streaming_train_state"
72 | checkpoint_gather_fns = gather_fns
73 | else:
74 | checkpoint_state = train_state.params["params"]
75 | checkpoint_name = "streaming_params"
76 | checkpoint_gather_fns = gather_fns.params["params"]
77 |
78 | if milestone:
79 | # Save a milestone checkpoint that will not be overwritten
80 | self.save_pickle(metadata, f"step_{step}/metadata_{step}.pkl")
81 | self.save_pickle(dataset, f"step_{step}/dataset_{step}.pkl")
82 | self.save_checkpoint(checkpoint_state, f"step_{step}/{checkpoint_name}_{step}", checkpoint_gather_fns)
83 | # Additionally save a checkpoint that can be overwritten for automatic resuming
84 | self.save_pickle(metadata, "metadata.pkl")
85 | self.save_pickle(dataset, "dataset.pkl")
86 | self.save_checkpoint(checkpoint_state, f"{checkpoint_name}", checkpoint_gather_fns)
87 | else:
88 | # Save a normal checkpoint that can be overwritten
89 | self.save_pickle(metadata, "metadata.pkl")
90 | self.save_pickle(dataset, "dataset.pkl")
91 | self.save_checkpoint(checkpoint_state, f"{checkpoint_name}", checkpoint_gather_fns)
92 |
93 | @staticmethod
94 | def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None):
95 | if shard_fns is not None:
96 | shard_fns = flatten_dict(to_state_dict(shard_fns))
97 | if remove_dict_prefix is not None:
98 | remove_dict_prefix = tuple(remove_dict_prefix)
99 | flattend_train_state = {}
100 | with mlxu.open_file(path) as fin:
101 | # 83886080 bytes = 80 MB, which is 16 blocks on GCS
102 | unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0)
103 | for key, value in unpacker:
104 | key = tuple(key)
105 | if remove_dict_prefix is not None:
106 | if key[: len(remove_dict_prefix)] == remove_dict_prefix:
107 | key = key[len(remove_dict_prefix) :]
108 | else:
109 | continue
110 |
111 | tensor = from_bytes(None, value)
112 | if shard_fns is not None:
113 | tensor = shard_fns[key](tensor)
114 | flattend_train_state[key] = tensor
115 |
116 | if target is not None:
117 | flattened_target = flatten_dict(to_state_dict(target), keep_empty_nodes=True)
118 | for key, value in flattened_target.items():
119 | if key not in flattend_train_state and value == empty_node:
120 | flattend_train_state[key] = value
121 |
122 | train_state = unflatten_dict(flattend_train_state)
123 | if target is None:
124 | return train_state
125 |
126 | return from_state_dict(target, train_state)
127 |
128 | @staticmethod
129 | def load_flax_checkpoint(path, target=None, shard_fns=None):
130 | """Load a standard flax checkpoint that's not saved with the
131 | msgpack streaming format.
132 | """
133 | with mlxu.open_file(path, "rb") as fin:
134 | encoded_bytes = fin.read()
135 |
136 | state_dict = flax.serialization.msgpack_restore(encoded_bytes)
137 | if shard_fns is not None:
138 | shard_fns = to_state_dict(shard_fns)
139 | state_dict = tree_apply(shard_fns, state_dict)
140 |
141 | if target is None:
142 | return state_dict
143 | return from_state_dict(target, state_dict)
144 |
145 | @classmethod
146 | def load_trainstate_checkpoint(
147 | cls, load_from, trainstate_target=None, trainstate_shard_fns=None, disallow_trainstate=False
148 | ):
149 | if trainstate_target is not None:
150 | params_target = trainstate_target.params["params"]
151 | else:
152 | params_target = None
153 |
154 | if trainstate_shard_fns is not None:
155 | params_shard_fns = trainstate_shard_fns.params["params"]
156 | else:
157 | params_shard_fns = None
158 |
159 | load_type, load_path = load_from.split("::", 1)
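        # load_from has the form "<type>::<path>", e.g. (paths illustrative)
        # "trainstate::/exp_dir/exp_name/streaming_train_state" or "params::/path/to/streaming_params".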
160 | if disallow_trainstate:
161 | assert load_type != "trainstate", "Loading full trainstate is not allowed!"
162 |
163 | train_state = None
164 | restored_params = None
165 |
166 | if load_type == "trainstate":
167 | # Load the entire train state in the streaming format
168 | train_state = cls.load_checkpoint(path=load_path, target=trainstate_target, shard_fns=trainstate_shard_fns)
169 | elif load_type == "trainstate_params":
170 | # Load the params part of the train state in the streaming format
171 | restored_params = cls.load_checkpoint(
172 | path=load_path,
173 | target=params_target,
174 | shard_fns=params_shard_fns,
175 | remove_dict_prefix=("params", "params"),
176 | )
177 | restored_params = flax.core.frozen_dict.freeze({"params": restored_params})
178 | elif load_type == "params":
179 | # Load the params in the streaming format
180 | restored_params = cls.load_checkpoint(path=load_path, target=params_target, shard_fns=params_shard_fns)
181 | restored_params = flax.core.frozen_dict.freeze({"params": restored_params})
182 | elif load_type == "flax_params":
183 | # Load the params in the standard flax format (non-streaming)
184 | # This requires the entire params to fit in memory
185 | restored_params = cls.load_flax_checkpoint(path=load_path, target=params_target, shard_fns=params_shard_fns)
186 | restored_params = flax.core.frozen_dict.freeze({"params": restored_params})
187 | else:
188 | raise ValueError(f"Invalid load_from type: {load_type}")
189 |
190 | return train_state, restored_params
191 |
--------------------------------------------------------------------------------
/ttt/infra/jax_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
4 | from functools import partial
5 | import re
6 | import dataclasses
7 | import random
8 | from ml_collections import ConfigDict
9 | from ml_collections.config_dict.config_dict import placeholder
10 |
11 | import flax
12 | import jax
13 | import jax.numpy as jnp
14 | from jax.sharding import PartitionSpec as PS
15 | from jax.sharding import Mesh
16 | from jax.experimental import mesh_utils
17 | from jax.experimental.pjit import with_sharding_constraint as _with_sharding_constraint
18 | from jax.experimental.pjit import pjit
19 | from jax.interpreters import pxla
20 | import numpy as np
21 | import torch
22 | from transformers import FlaxLogitsWarper
23 |
24 | from io import BytesIO
25 | from PIL import Image
26 | import wandb
27 | import matplotlib.pyplot as plt
28 | import logging
29 |
30 | matplotlib_logger = logging.getLogger("matplotlib")
31 | matplotlib_logger.setLevel(logging.WARNING)
32 |
33 |
34 | class JaxRNG(object):
35 | """A convenient stateful Jax RNG wrapper. Can be used to wrap RNG inside
36 | pure function.
37 | """
38 |
39 | @classmethod
40 | def from_seed(cls, seed):
41 | return cls(jax.random.PRNGKey(seed))
42 |
43 | def __init__(self, rng):
44 | self.rng = rng
45 |
46 | def __call__(self, keys=None):
47 | if keys is None:
48 | self.rng, split_rng = jax.random.split(self.rng)
49 | return split_rng
50 | elif isinstance(keys, int):
51 | split_rngs = jax.random.split(self.rng, num=keys + 1)
52 | self.rng = split_rngs[0]
53 | return tuple(split_rngs[1:])
54 | else:
55 | split_rngs = jax.random.split(self.rng, num=len(keys) + 1)
56 | self.rng = split_rngs[0]
57 | return {key: val for key, val in zip(keys, split_rngs[1:])}
58 |
59 |
60 | class JaxDistributedConfig(object):
61 | """Utility class for initializing JAX distributed."""
62 |
63 | @staticmethod
64 | def get_default_config(updates=None):
65 | config = ConfigDict()
66 | config.initialize_jax_distributed = False
67 | config.coordinator_address = placeholder(str)
68 | config.num_processes = placeholder(int)
69 | config.process_id = placeholder(int)
70 | config.local_device_ids = placeholder(str)
71 |
72 | if updates is not None:
73 | config.update(ConfigDict(updates).copy_and_resolve_references())
74 | return config
75 |
76 | @classmethod
77 | def initialize(cls, config):
78 | config = cls.get_default_config(config)
79 | if config.initialize_jax_distributed:
80 | if config.local_device_ids is not None:
81 | local_device_ids = [int(x) for x in config.local_device_ids.split(",")]
82 | else:
83 | local_device_ids = None
84 |
85 | jax.distributed.initialize(
86 | coordinator_address=config.coordinator_address,
87 | num_processes=config.num_processes,
88 | process_id=config.process_id,
89 | local_device_ids=local_device_ids,
90 | )
91 |
92 |
93 | class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
94 | """JIT traceable version of FlaxLogitsWarper that performs temperature scaling."""
95 |
96 | def __init__(self, temperature):
97 | self.temperature = temperature
98 |
99 | def __call__(self, input_ids, scores, cur_len):
100 | return scores / jnp.clip(self.temperature, a_min=1e-8)
101 |
102 |
103 | def make_shard_and_gather_fns(partition_specs, dtype_specs=None):
104 | """Create pytree of sharding and gathering functions from pytree of
105 | partition specs.
106 | """
107 | float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
108 |
109 | def make_to_dtype_fn(dtype_spec):
110 | def to_dtype(tensor):
111 | if dtype_specs in float_dtypes and getattr(tensor, "dtype", None) in float_dtypes:
112 | # Convert all float tensors to the same dtype
113 | return tensor.astype(dtype_specs)
114 | elif hasattr(dtype_spec, "dtype") and hasattr(tensor, "dtype"):
115 | return tensor.astype(dtype_spec.dtype)
116 | return tensor
117 |
118 | return to_dtype
119 |
120 | def make_shard_fn(partition_spec, dtype_spec=None):
121 | jax_shard_function = pjit(make_to_dtype_fn(dtype_spec), in_shardings=None, out_shardings=partition_spec)
122 |
123 | def shard_fn(tensor):
124 | return jax_shard_function(tensor).block_until_ready()
125 |
126 | return shard_fn
127 |
128 | def make_gather_fn(partition_spec, dtype_spec=None):
129 | jax_gather_fn = pjit(make_to_dtype_fn(dtype_spec), in_shardings=partition_spec, out_shardings=None)
130 |
131 | def gather_fn(tensor):
132 | return jax.device_get(jax_gather_fn(tensor))
133 |
134 | return gather_fn
135 |
136 | if dtype_specs is None or dtype_specs in float_dtypes:
137 | shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs)
138 | gather_fns = jax.tree_util.tree_map(make_gather_fn, partition_specs)
139 | else:
140 | shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs, dtype_specs)
141 | gather_fns = jax.tree_util.tree_map(make_gather_fn, partition_specs, dtype_specs)
142 | return shard_fns, gather_fns
143 |
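# Typical pairing with the path-matching helpers defined later in this file (a sketch;
# assumes an active device mesh, e.g. `with mesh:` from get_jax_mesh below, and
# ModelConfig imported from ttt.models.model):
#   partition_specs = match_partition_rules(ModelConfig.get_partition_rules(), params)
#   shard_fns, gather_fns = make_shard_and_gather_fns(partition_specs, dtype_specs=jnp.float32)
#   params = tree_apply(shard_fns, params)   # place each leaf according to its PartitionSpec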
144 |
145 | def set_random_seed(seed):
146 | np.random.seed(seed)
147 | random.seed(seed)
148 | torch.manual_seed(seed)
149 | init_rng(seed)
150 |
151 |
152 | def get_jax_mesh(axis_dims, names):
153 | if axis_dims.startswith("!"):
154 | # Allow splitting a physical mesh axis if needed
155 | mesh_axis_splitting = True
156 | axis_dims = axis_dims[1:]
157 | else:
158 | mesh_axis_splitting = False
159 |
160 | if ":" in axis_dims:
161 | dims = []
162 | dim_names = []
163 | for axis in axis_dims.split(","):
164 | name, dim = axis.split(":")
165 | assert name in names
166 | dims.append(int(dim))
167 | dim_names.append(name)
168 | assert set(dim_names) == set(names)
169 | else:
170 | dims = [int(x) for x in axis_dims.split(",")]
171 | dim_names = names
172 | assert len(dims) == len(names)
173 | mesh_shape = np.arange(jax.device_count()).reshape(dims).shape
174 | if mesh_axis_splitting:
175 | physical_mesh = np.array(jax.devices()).reshape(mesh_shape)
176 | else:
177 | physical_mesh = mesh_utils.create_device_mesh(mesh_shape)
178 | return Mesh(physical_mesh, dim_names)
179 |
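# Accepted `axis_dims` formats for get_jax_mesh above (illustrative, assuming 8 devices
# and names ("dp", "fsdp", "mp")):
#   "1,8,1"             dims given in the same order as `names`
#   "dp:1,fsdp:8,mp:1"  explicitly named dims (any order, but every name must appear)
#   "1,-1,1"            -1 infers the remaining dim from jax.device_count()
#   "!1,8,1"            leading "!" reshapes jax.devices() directly instead of calling
#                       mesh_utils.create_device_mesh, allowing a physical axis to be split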
180 |
181 | def names_in_current_mesh(*names):
182 | """Check if current mesh axes contain these names."""
183 | mesh_axis_names = pxla.thread_resources.env.physical_mesh.axis_names
184 | return set(names) <= set(mesh_axis_names)
185 |
186 |
187 | def get_names_from_parition_spec(partition_specs):
188 | """Return axis names from partition specs."""
189 | names = set()
190 | if isinstance(partition_specs, dict):
191 | partition_specs = partition_specs.values()
192 | for item in partition_specs:
193 | if item is None:
194 | continue
195 | elif isinstance(item, str):
196 | names.add(item)
197 | else:
198 | names.update(get_names_from_parition_spec(item))
199 |
200 | return list(names)
201 |
202 |
203 | def with_sharding_constraint(x, partition_specs):
204 | """A smarter version of with_sharding_constraint that only applies the
205 | constraint if the current mesh contains the axes in the partition specs.
206 | """
207 | axis_names = get_names_from_parition_spec(partition_specs)
208 | if names_in_current_mesh(*axis_names):
209 | x = _with_sharding_constraint(x, partition_specs)
210 | return x
211 |
212 |
213 | def wrap_function_with_rng(rng):
214 | """To be used as decorator, automatically bookkeep a RNG for the wrapped function."""
215 |
216 | def wrap_function(function):
217 | def wrapped(*args, **kwargs):
218 | nonlocal rng
219 | rng, split_rng = jax.random.split(rng)
220 | return function(split_rng, *args, **kwargs)
221 |
222 | return wrapped
223 |
224 | return wrap_function
225 |
226 |
227 | def init_rng(seed):
228 | global jax_utils_rng
229 | jax_utils_rng = JaxRNG.from_seed(seed)
230 |
231 |
232 | def next_rng(*args, **kwargs):
233 | global jax_utils_rng
234 | return jax_utils_rng(*args, **kwargs)
235 |
236 |
237 | def get_metrics(metrics, unreplicate=False, stack=False):
238 | if unreplicate:
239 | metrics = flax.jax_utils.unreplicate(metrics)
240 | metrics = jax.device_get(metrics)
241 | if stack:
242 | return jax.tree_map(lambda *args: np.stack(args), *metrics)
243 | else:
244 | return {key: float(val) for key, val in metrics.items()}
245 |
246 |
247 | def mse_loss(val, target, valid=None):
248 | if valid is None:
249 | valid = jnp.ones((*target.shape[:2], 1))
250 | valid = valid.astype(jnp.float32)
251 | loss = jnp.mean(jnp.where(valid > 0.0, jnp.square(val - target), 0.0))
252 | return loss
253 |
254 |
255 | def cross_entropy_loss_and_accuracy(logits, tokens, valid=None):
256 | if valid is None:
257 | valid = jnp.ones(tokens.shape[:2])
258 | valid = valid.astype(jnp.float32)
259 | valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
260 | logits = logits.astype(jnp.float32) # for numerical stability
261 | token_log_prob = jnp.squeeze(
262 | jnp.take_along_axis(jax.nn.log_softmax(logits, axis=-1), jnp.expand_dims(tokens, -1), axis=-1), -1
263 | )
264 | token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0))
265 | loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length)
266 | correct = jnp.where(valid > 0.0, jnp.argmax(logits, axis=-1) == tokens, jnp.array(False))
267 | accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)
268 | return loss, accuracy
269 |
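# Note on cross_entropy_loss_and_accuracy above: logits is (batch, seq_len, vocab);
# tokens and the optional `valid` mask are (batch, seq_len). The loss is the mean
# negative log-likelihood per valid token within each sequence, averaged over the
# batch; accuracy is computed analogously from argmax predictions.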
270 |
271 | def global_norm(tree):
272 | """Return the global L2 norm of a pytree."""
273 | squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree)
274 | flattened, _ = jax.flatten_util.ravel_pytree(squared)
275 | return jnp.sqrt(jnp.sum(flattened))
276 |
277 |
278 | def average_metrics(metrics):
279 | return jax.tree_map(lambda *args: jnp.mean(jnp.stack(args)), *metrics)
280 |
281 |
282 | def get_float_dtype_by_name(dtype):
283 | return {
284 | "bf16": jnp.bfloat16,
285 | "bfloat16": jnp.bfloat16,
286 | "fp16": jnp.float16,
287 | "float16": jnp.float16,
288 | "fp32": jnp.float32,
289 | "float32": jnp.float32,
290 | "fp64": jnp.float64,
291 | "float64": jnp.float64,
292 | }[dtype]
293 |
294 |
295 | def float_tensor_to_dtype(tensor, dtype):
296 | if dtype is None or dtype == "":
297 | return tensor
298 | if isinstance(dtype, str):
299 | dtype = get_float_dtype_by_name(dtype)
300 | float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
301 | if getattr(tensor, "dtype", None) in float_dtypes:
302 | tensor = tensor.astype(dtype)
303 | return tensor
304 |
305 |
306 | def float_to_dtype(tree, dtype):
307 | return jax.tree_util.tree_map(partial(float_tensor_to_dtype, dtype=dtype), tree)
308 |
309 |
310 | def get_gradient_checkpoint_policy(name):
311 | return {
312 | "everything_saveable": jax.checkpoint_policies.everything_saveable,
313 | "nothing_saveable": jax.checkpoint_policies.nothing_saveable,
314 | "checkpoint_dots": jax.checkpoint_policies.checkpoint_dots,
315 | "checkpoint_dots_with_no_batch_dims": jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
316 | }[name]
317 |
318 |
319 | def tree_path_to_string(path, sep=None):
320 | keys = []
321 | for key in path:
322 | if isinstance(key, jax.tree_util.SequenceKey):
323 | keys.append(str(key.idx))
324 | elif isinstance(key, jax.tree_util.DictKey):
325 | keys.append(str(key.key))
326 | elif isinstance(key, jax.tree_util.GetAttrKey):
327 | keys.append(str(key.name))
328 | elif isinstance(key, jax.tree_util.FlattenedIndexKey):
329 | keys.append(str(key.key))
330 | else:
331 | keys.append(str(key))
332 | if sep is None:
333 | return tuple(keys)
334 | return sep.join(keys)
335 |
336 |
337 | def flatten_tree(xs, is_leaf=None, sep=None):
338 | flattened, _ = jax.tree_util.tree_flatten_with_path(xs, is_leaf=is_leaf)
339 | output = {}
340 | for key, val in flattened:
341 | output[tree_path_to_string(key, sep=sep)] = val
342 | return output
343 |
344 |
345 | def named_tree_map(f, tree, *rest, is_leaf=None, sep=None):
346 | """An extended version of jax.tree_util.tree_map, where the mapped function
347 | f takes both the name (path) and the tree leaf as input.
348 | """
349 | return jax.tree_util.tree_map_with_path(
350 | lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r), tree, *rest, is_leaf=is_leaf
351 | )
352 |
353 |
354 | def match_partition_rules(rules, params):
355 | """Returns a pytree of PartitionSpec according to rules. Supports handling
356 | Flax TrainState and Optax optimizer state.
357 | """
358 |
359 | def get_partition_spec(name, leaf):
360 | if len(leaf.shape) == 0 or np.prod(leaf.shape) == 1:
361 | """Don't partition scalar values."""
362 | return PS()
363 | for rule, ps in rules:
364 | if re.search(rule, name) is not None:
365 | return ps
366 | raise ValueError(f"Partition rule not found for param: {name}")
367 |
368 | return named_tree_map(get_partition_spec, params, sep="/")
369 |
370 |
371 | def get_weight_decay_mask(exclusions):
372 | """Return a weight decay mask function that computes the pytree masks
373 | according to the given exclusion rules.
374 | """
375 |
376 | def decay(name, _):
377 | for rule in exclusions:
378 | if re.search(rule, name) is not None:
379 | return False
380 | return True
381 |
382 | def weight_decay_mask(params):
383 | return named_tree_map(decay, params, sep="/")
384 |
385 | return weight_decay_mask
386 |
387 |
388 | def tree_apply(fns, tree):
389 | """Apply a pytree of functions to the pytree."""
390 | return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
391 |
392 |
393 | def master_print(msg, logger=None, end="\n"):
394 | if jax.process_index() == 0:
395 | print(msg, flush=True, end=end)
396 | if logger is not None:
397 | logger.writelines(msg)
398 | if end == "\n":
399 | logger.writelines("\n")
400 | logger.flush()
401 |
402 |
403 | def log_plot(fig, name, step):
404 | buf = BytesIO()
405 | fig.savefig(buf, format="png")
406 | buf.seek(0)
407 | image = Image.open(buf)
408 | wandb.log({name: wandb.Image(image)}, step=step)
409 | buf.close()
410 |
411 |
412 | def log_ttt_stats(layer, ttt_stats_layer, x_axis, step):
413 | ssl_tgt_last_in_mini_batch_from_mean_mse = ttt_stats_layer[0]
414 | ttt_loss_mse_init = ttt_stats_layer[1]
415 | ttt_loss_mse_step_0 = ttt_stats_layer[2]
416 | ttt_loss_mse_step_1 = ttt_stats_layer[3]
417 |
418 | fig, ax = plt.subplots()
419 | ax.plot(x_axis, ssl_tgt_last_in_mini_batch_from_mean_mse, label="$\\|E[Y_{ssl}]-Y_{ssl}\\|^2$", color="green")
420 |     ax.plot(x_axis, ttt_loss_mse_init, label="$\\mathcal{L}(x_t; W_0)$", color="orange")
421 |     ax.plot(x_axis, ttt_loss_mse_step_0, label="$\\mathcal{L}(x_t; W_{t-b})$", color="blue")
422 |     ax.plot(x_axis, ttt_loss_mse_step_1, label="$\\mathcal{L}(x_t; W_{t})$", color="red")
423 | ax.set_ylabel("TTT Loss")
424 | ax.set_xlabel("Position in Sequence")
425 | ax.set_title(f"Layer {layer + 1}")
426 | ax.legend()
427 | log_plot(fig, f"Layer {layer + 1} TTT Loss", step)
428 | plt.close()
429 |
--------------------------------------------------------------------------------
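The two path-matching helpers above, `match_partition_rules` and `get_weight_decay_mask`, both key off the "/"-joined parameter path. A minimal self-contained sketch (the toy parameter names and rules below are illustrative, not the ones used by the training scripts; it assumes the repo's requirements are installed so that `ttt.infra.jax_utils` imports):

import jax.numpy as jnp
from jax.sharding import PartitionSpec as PS

from ttt.infra.jax_utils import match_partition_rules, get_weight_decay_mask

# Toy parameter tree; the paths become "model/wte/embedding" and "model/ln_f/kernel".
params = {
    "model": {
        "wte": {"embedding": jnp.zeros((8, 4))},
        "ln_f": {"kernel": jnp.ones((4,))},
    }
}

# Ordered rules: the first regex that matches a path wins; ".*" is the fallback.
rules = (
    ("wte/embedding", PS("mp", "fsdp")),
    (".*", PS()),
)
specs = match_partition_rules(rules, params)
assert specs["model"]["wte"]["embedding"] == PS("mp", "fsdp")

# Exclude norm kernels from weight decay.
mask_fn = get_weight_decay_mask(("ln_f/kernel",))
mask = mask_fn(params)
assert mask["model"]["ln_f"]["kernel"] is False and mask["model"]["wte"]["embedding"] is True
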
/ttt/infra/optimizers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
4 | from functools import partial
5 | import re
6 | import dataclasses
7 | import random
8 |
9 | from ml_collections.config_dict import config_dict
10 | from ml_collections import ConfigDict
11 | import jax
12 | import jax.numpy as jnp
13 | import numpy as np
14 | from absl import logging
15 | import optax
16 |
17 | from ttt.infra.jax_utils import float_to_dtype
18 |
19 |
20 | class OptimizerFactory(object):
21 | """Configurable optax optimizer factory."""
22 |
23 | def __init__(self):
24 | raise NotImplementedError
25 |
26 | @staticmethod
27 | def get_default_config(updates=None):
28 | config = ConfigDict()
29 | config.accumulate_gradient_steps = 1
30 | config.type = "adamw"
31 | config.palm_optimizer = PalmOptimizerFactory.get_default_config()
32 | config.adamw_optimizer = AdamWOptimizerFactory.get_default_config()
33 |
34 | if updates is not None:
35 | config.update(ConfigDict(updates).copy_and_resolve_references())
36 | return config
37 |
38 | @classmethod
39 | def get_optimizer(cls, config, weight_decay_mask=None):
40 | config = cls.get_default_config(config)
41 | if config.type == "palm":
42 | optimizer, optimizer_info = PalmOptimizerFactory.get_optimizer(config.palm_optimizer, weight_decay_mask)
43 | elif config.type == "adamw":
44 | optimizer, optimizer_info = AdamWOptimizerFactory.get_optimizer(config.adamw_optimizer, weight_decay_mask)
45 | else:
46 | raise ValueError(f"Unknown optimizer type: {config.type}")
47 |
48 | if config.accumulate_gradient_steps > 1:
49 | optimizer = optax.MultiSteps(optimizer, config.accumulate_gradient_steps)
50 |
51 | return optimizer, optimizer_info
52 |
53 |
54 | class PalmOptimizerFactory(object):
55 |     """PaLM optimizer factory. This factory builds the optimizer
56 | described in the PaLM paper: https://arxiv.org/abs/2204.02311
57 | """
58 |
59 | def __init__(self):
60 | raise NotImplementedError
61 |
62 | @staticmethod
63 | def get_default_config(updates=None):
64 | config = ConfigDict()
65 | config.lr = 0.01
66 | config.lr_warmup_steps = 10000
67 | config.b1 = 0.9
68 | config.b2 = 0.99
69 | config.clip_gradient = 1.0
70 | config.weight_decay = 1e-4
71 | config.bf16_momentum = False
72 |
73 | if updates is not None:
74 | config.update(ConfigDict(updates).copy_and_resolve_references())
75 | return config
76 |
77 | @classmethod
78 | def get_optimizer(cls, config, weight_decay_mask=None):
79 | config = cls.get_default_config(config)
80 |
81 | def learning_rate_schedule(step):
82 | multiplier = config.lr / 0.01
83 | return multiplier / jnp.sqrt(jnp.maximum(step, config.lr_warmup_steps))
84 |
85 | def weight_decay_schedule(step):
86 | multiplier = config.weight_decay / 1e-4
87 | return -multiplier * jnp.square(learning_rate_schedule(step))
88 |
89 | optimizer_info = dict(
90 | learning_rate_schedule=learning_rate_schedule, weight_decay_schedule=weight_decay_schedule
91 | )
92 |
93 | optimizer = optax.chain(
94 | optax.clip_by_global_norm(config.clip_gradient),
95 | optax.adafactor(
96 | learning_rate=learning_rate_schedule,
97 | multiply_by_parameter_scale=True,
98 | momentum=config.b1,
99 | decay_rate=config.b2,
100 | factored=False,
101 | clipping_threshold=None,
102 | dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
103 | ),
104 | optax_add_scheduled_weight_decay(weight_decay_schedule, weight_decay_mask),
105 | )
106 | return optimizer, optimizer_info
107 |
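# Worked example of the PaLM schedules above with the defaults (lr=0.01,
# lr_warmup_steps=10000, weight_decay=1e-4):
#   multiplier = 0.01 / 0.01 = 1, so lr(step) = 1 / sqrt(max(step, 10000)):
#     lr(1000)  = 1 / sqrt(10000) = 0.01    (flat during warmup)
#     lr(40000) = 1 / sqrt(40000) = 0.005   (inverse-sqrt decay afterwards)
#   weight_decay(step) = -(1e-4 / 1e-4) * lr(step)^2, i.e. the decay strength tracks lr^2
#   and is applied as parameter shrinkage by optax_add_scheduled_weight_decay below.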
108 |
109 | class AdamWOptimizerFactory(object):
110 | """AdamW optimizer with cosine schedule."""
111 |
112 | def __init__(self):
113 | raise NotImplementedError
114 |
115 | @staticmethod
116 | def get_default_config(updates=None):
117 | config = ConfigDict()
118 | config.init_lr = 0.0
119 | config.end_lr = 1e-5 # EasyLM's previous default: 0.001
120 | config.lr = 0.01
121 | config.lr_warmup_steps = 2000
122 | config.lr_decay_steps = 500000
123 | config.b1 = 0.9
124 | config.b2 = 0.95
125 | config.clip_gradient = 1.0
126 | config.weight_decay = 0.1 # EasyLM previous default: 1e-4
127 | config.bf16_momentum = False
128 | config.multiply_by_parameter_scale = False
129 |
130 | if updates is not None:
131 | config.update(ConfigDict(updates).copy_and_resolve_references())
132 | return config
133 |
134 | @classmethod
135 | def get_optimizer(cls, config, weight_decay_mask=None):
136 | config = cls.get_default_config(config)
137 |
138 | learning_rate_schedule = optax.warmup_cosine_decay_schedule(
139 | init_value=config.init_lr,
140 | peak_value=config.lr,
141 | warmup_steps=config.lr_warmup_steps,
142 | decay_steps=config.lr_decay_steps,
143 | end_value=config.end_lr,
144 | )
145 |
146 | optimizer_info = dict(learning_rate_schedule=learning_rate_schedule)
147 |
148 | if config.multiply_by_parameter_scale:
149 | optimizer = optax.chain(
150 | optax.clip_by_global_norm(config.clip_gradient),
151 | optax.adafactor(
152 | learning_rate=learning_rate_schedule,
153 | multiply_by_parameter_scale=True,
154 | momentum=config.b1,
155 | decay_rate=config.b2,
156 | factored=False,
157 | clipping_threshold=None,
158 | dtype_momentum=(jnp.bfloat16 if config.bf16_momentum else jnp.float32),
159 | ),
160 | optax_add_scheduled_weight_decay(
161 | lambda step: -learning_rate_schedule(step) * config.weight_decay, weight_decay_mask
162 | ),
163 | )
164 | else:
165 | optimizer = optax.chain(
166 | optax.clip_by_global_norm(config.clip_gradient),
167 | optax.adamw(
168 | learning_rate=learning_rate_schedule,
169 | weight_decay=config.weight_decay,
170 | b1=config.b1,
171 | b2=config.b2,
172 | mask=weight_decay_mask,
173 | mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
174 | ),
175 | )
176 |
177 | return optimizer, optimizer_info
178 |
179 |
180 | class OptaxScheduledWeightDecayState(NamedTuple):
181 | count: jax.Array
182 |
183 |
184 | def optax_add_scheduled_weight_decay(schedule_fn, mask=None):
185 | """Apply weight decay with schedule."""
186 |
187 | def init_fn(params):
188 | del params
189 | return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32))
190 |
191 | def update_fn(updates, state, params):
192 | if params is None:
193 | raise ValueError("Params cannot be None for weight decay!")
194 |
195 | weight_decay = schedule_fn(state.count)
196 | updates = jax.tree_util.tree_map(lambda g, p: g + weight_decay * p, updates, params)
197 | return updates, OptaxScheduledWeightDecayState(count=optax.safe_int32_increment(state.count))
198 |
199 | if mask is not None:
200 | return optax.masked(optax.GradientTransformation(init_fn, update_fn), mask)
201 | return optax.GradientTransformation(init_fn, update_fn)
202 |
--------------------------------------------------------------------------------
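A minimal end-to-end sketch of `OptimizerFactory` (toy parameters; the default config values shown above, which are not necessarily the ones used for training):

import jax
import jax.numpy as jnp
import optax

from ttt.infra.optimizers import OptimizerFactory

# Build the default optimizer: AdamW with warmup + cosine decay and global-norm clipping.
config = OptimizerFactory.get_default_config()
optimizer, optimizer_info = OptimizerFactory.get_optimizer(config)

# One update step on a toy parameter tree.
params = {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}
grads = jax.tree_util.tree_map(jnp.ones_like, params)
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

# The learning rate actually applied at a given step:
lr_at_step = optimizer_info["learning_rate_schedule"](100)
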
/ttt/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/test-time-training/ttt-lm-jax/ac8cdc8a43b811afe23c27d7e82eed34a747b19c/ttt/models/__init__.py
--------------------------------------------------------------------------------
/ttt/models/bpt.py:
--------------------------------------------------------------------------------
1 | """
2 | An implementation of Blockwise parallel transformer https://arxiv.org/abs/2305.19370
3 | Also includes a reference implementation of the memory-efficient transformer https://arxiv.org/abs/2112.05682
4 | """
5 |
6 | import functools
7 | from typing import NamedTuple
8 |
9 | import flax.linen as nn
10 | import jax
11 | import jax.lax as lax
12 | import jax.numpy as jnp
13 | from einops import rearrange
14 |
15 | """
16 | Compute the FFN blockwise without materializing the large hidden tensor, enabling
17 | training on 4x longer sequences than the memory-efficient transformer.
18 | Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023
19 | """
20 | def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True):
21 | # remat_ffn: a rematerialized ffn with policy jax.checkpoint_policies.nothing_saveable()
22 | # inputs: (batch, seq_len, dim)
23 | # chunk_size: the chunk size to split the sequence
24 | inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)
25 | def scan_ffn(remat_ffn, carry, hidden_states):
26 | # outputs = remat_ffn(hidden_states, deterministic=deterministic)
27 |         outputs = remat_ffn(hidden_states, deterministic)  # @xinhao: when the MLP is rematted, pass `deterministic` positionally; if passed as a keyword it will be ignored.
28 | return carry, outputs
29 | scan_axis = inputs.ndim - 2
30 | _, res = nn.scan(
31 | scan_ffn,
32 | variable_broadcast="params",
33 | split_rngs={"params": False, "dropout": True},
34 | in_axes=scan_axis,
35 | out_axes=scan_axis,
36 | )(remat_ffn, None, inputs)
37 | res = rearrange(res, 'b c n d -> b (c n) d')
38 | return res
39 |
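# Usage sketch (mirrors how Block.__call__ in ttt/models/model.py calls this when
# config.scan_mlp is enabled; SwiGLUMLP and get_gradient_checkpoint_policy come from there):
#   remat_mlp = nn_partitioning.remat(SwiGLUMLP, static_argnums=(1,),
#                                     policy=get_gradient_checkpoint_policy("nothing_saveable"),
#                                     prevent_cse=True)
#   feed_forward = remat_mlp(config, dtype=jnp.float32)
#   out = blockwise_ffn(feed_forward, hidden_states, config.scan_mlp_chunk_size, deterministic=True)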
40 |
41 | """
42 | Compute attention blockwise without materializing the full attention matrix.
43 | The idea was first proposed in the memory-efficient transformer (https://arxiv.org/abs/2112.05682, Rabe et al. 2021);
44 | FlashAttention (https://arxiv.org/abs/2205.14135, Dao et al. 2022) provides a CUDA-efficient
45 | implementation; the Blockwise Parallel Transformer (https://arxiv.org/abs/2305.19370,
46 | Liu et al. 2023) computes both attention and the FFN blockwise, enabling 4x longer
47 | sequences than memory-efficient attention / FlashAttention and fusing attention with the FFN.
48 | """
49 | def blockwise_attn(
50 | query, key, value,
51 | bias=None,
52 | deterministic=True,
53 | dropout_rng=None,
54 | attn_pdrop=0.0,
55 | causal=True,
56 | query_chunk_size=2048,
57 | key_chunk_size=2048,
58 | dtype=jnp.float32,
59 | policy=jax.checkpoint_policies.nothing_saveable(),
60 | precision=None,
61 | float32_logits=True,
62 | prevent_cse=True,
63 | ):
64 | # query, key, value: (batch, seq_len, num_heads, dim_per_head)
65 |     # bias: broadcastable to (batch, num_heads, q_len, kv_len); can be used to mask out attention (e.g. for padding)
66 | # causal: whether to use causal mask
67 | # policy: one of jax.checkpoint_policies
68 | query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
69 | if float32_logits:
70 | query = query.astype(jnp.float32)
71 | key = key.astype(jnp.float32)
72 |
73 | batch, q_len, num_heads, dim_per_head = query.shape
74 | batch, kv_len, num_heads, dim_per_head = key.shape
75 | batch, kv_len, num_heads, dim_per_head = value.shape
76 |
77 | num_q = q_len // query_chunk_size
78 | num_kv = kv_len // key_chunk_size
79 | query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
80 | key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
81 | value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
82 |
83 | query = jnp.moveaxis(query, 1, 0)
84 | key = jnp.moveaxis(key, 1, 0)
85 | value = jnp.moveaxis(value, 1, 0)
86 |
87 | if bias is not None:
88 | for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
89 | assert bias_dim == 1 or bias_dim == broadcast_dim
90 | if not deterministic and attn_pdrop > 0.0:
91 | attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
92 | attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
93 | else:
94 | attn_dropout = None
95 |
96 | _chunk_bias_fn = functools.partial(
97 | _chunk_attention_bias,
98 | query_chunk_size, key_chunk_size, bias, deterministic,
99 | attn_dropout, attn_pdrop, causal, dtype)
100 |
101 | def scan_attention(args):
102 | query_chunk, query_chunk_idx = args
103 |
104 | @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
105 | def scan_kv_block(carry, args):
106 | key_chunk, value_chunk, key_chunk_idx = args
107 | (numerator, denominator, prev_max_score) = carry
108 | attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision)
109 | bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
110 | bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
111 | attn_weights = attn_weights + bias_chunk
112 |
113 | max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
114 | max_score = jnp.maximum(prev_max_score, max_score)
115 | max_score = jax.lax.stop_gradient(max_score)
116 | exp_weights = jnp.exp(attn_weights - max_score)
117 | exp_values = jnp.einsum(
118 | 'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
119 | )
120 | correction = jnp.exp(prev_max_score - max_score)
121 | numerator = numerator * correction + exp_values
122 | denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True)
123 | return Carry(numerator, denominator, max_score), None
124 |
125 | def skip_upper_half(carry, args):
126 | key_chunk, value_chunk, key_chunk_idx = args
127 | skip_block = jnp.array(False)
128 | if causal:
129 | skip_block = query_chunk_idx < key_chunk_idx
130 | return jax.lax.cond(
131 | skip_block,
132 | lambda carry, args: (carry, None),
133 | scan_kv_block,
134 | carry,
135 | args,
136 | )
137 |
138 | init_carry = Carry(
139 | jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
140 | jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
141 | (-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype),
142 | )
143 | (numerator, denominator, max_score), _ = lax.scan(
144 | skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
145 | )
146 | outputs = (numerator / denominator).astype(dtype)
147 | return outputs
148 |
149 | _, res = lax.scan(
150 | lambda _, x: ((), scan_attention(x)),
151 | (), xs=(query, jnp.arange(0, num_q))
152 | )
153 | res = rearrange(res, 'n b c h d -> b (n c) h d')
154 | return res
155 |
156 |
157 | class Carry(NamedTuple):
158 | numerator: jax.Array
159 | denominator: jax.Array
160 | max_so_far: jax.Array
161 |
162 |
163 | def _chunk_attention_bias(query_chunk_size, key_chunk_size,
164 | bias, deterministic, attn_dropout, attn_pdrop, causal,
165 | dtype, query_chunk_idx, key_chunk_idx):
166 | query_offset = query_chunk_idx * query_chunk_size
167 | key_offset = key_chunk_idx * key_chunk_size
168 | chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
169 | if bias is not None:
170 | chunk_bias = lax.dynamic_slice(
171 | bias,
172 | start_indices=(0, 0, query_offset, key_offset),
173 | slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
174 | )
175 |
176 | if causal:
177 | query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
178 | key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
179 | offset = query_offset - key_offset
180 | query_idx += offset
181 | causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
182 | chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)
183 |
184 | if not deterministic and attn_pdrop > 0.0:
185 | attn_dropout_slice = lax.dynamic_slice(
186 | attn_dropout,
187 | start_indices=(0, 0, query_offset, key_offset),
188 | slice_sizes=(
189 | *attn_dropout.shape[:2],
190 | min(attn_dropout.shape[-2], query_chunk_size),
191 | min(attn_dropout.shape[-1], key_chunk_size),
192 | ),
193 | )
194 | chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
195 | return chunk_bias.astype(dtype)
196 |
197 |
198 | if __name__ == '__main__':
199 | # test
200 | def reference_attn(query, key, value, causal, dtype):
201 | query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
202 | logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
203 | if causal:
204 | mask_value = jnp.finfo(logits.dtype).min
205 | _, q_seq_len, _, _ = query.shape
206 | _, kv_seq_len, _, _ = key.shape
207 | mask_shape = (q_seq_len, kv_seq_len)
208 | row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
209 | col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
210 | causal_mask = (row_ids < col_ids)[None, None, :, :]
211 | logits = logits + jnp.where(causal_mask, mask_value, 0.0)
212 | weights = jax.nn.softmax(logits, axis=-1)
213 | out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
214 | return out
215 |
216 | # random inputs
217 | shape = (1, 32, 8, 64)
218 | query = jax.random.normal(jax.random.PRNGKey(0), shape)
219 | key = jax.random.normal(jax.random.PRNGKey(1), shape)
220 | value = jax.random.normal(jax.random.PRNGKey(2), shape)
221 |
222 | causal = True
223 | chunk_size = 4
224 | policy = jax.checkpoint_policies.nothing_saveable()
225 |
226 | blockwise = blockwise_attn(query, key, value, None, False, None, 0.0, causal, chunk_size, chunk_size, jnp.float32, policy, 'float32', True, False)
227 | reference = reference_attn(query, key, value, causal, 'float32')
228 |
229 | assert jnp.allclose(reference, blockwise, atol=1e-6)
230 |
--------------------------------------------------------------------------------
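The heart of `scan_kv_block` above is the running-max ("online") softmax recurrence: when a new key/value chunk arrives, the accumulated numerator and denominator are rescaled by `exp(prev_max - new_max)` before the chunk's own contribution is added. A small self-contained check of that recurrence on a 1-D score vector (independent of the attention code above; illustration only):

import jax
import jax.numpy as jnp

def chunked_softmax_weighted_sum(scores, values, chunk):
    """Streaming softmax(scores) @ values, processing `chunk` scores at a time."""
    numerator = jnp.zeros_like(values[0])
    denominator = 0.0
    running_max = -jnp.inf
    for start in range(0, scores.shape[0], chunk):
        s = scores[start:start + chunk]
        v = values[start:start + chunk]
        new_max = jnp.maximum(running_max, jnp.max(s))
        correction = jnp.exp(running_max - new_max)   # rescale the old partial sums
        w = jnp.exp(s - new_max)
        numerator = numerator * correction + w @ v
        denominator = denominator * correction + jnp.sum(w)
        running_max = new_max
    return numerator / denominator

scores = jnp.array([0.1, 2.0, -1.0, 0.5, 3.0, -0.2, 1.5, 0.0])
values = jnp.arange(8.0).reshape(8, 1)
full = jax.nn.softmax(scores) @ values                             # reference: full softmax
stream = chunked_softmax_weighted_sum(scores, values, chunk=3)     # chunked recurrence
assert jnp.allclose(full, stream, atol=1e-6)
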
/ttt/models/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, Dict, List, Optional, Tuple, Union
3 | import json
4 |
5 | import numpy as np
6 | import jax
7 | import jax.numpy as jnp
8 | from jax import lax
9 | from jax.sharding import PartitionSpec as PS
10 | import flax
11 | import flax.linen as nn
12 | from flax.linen import combine_masks, make_causal_mask
13 | from flax.linen.attention import dot_product_attention_weights
14 | from flax.linen import partitioning as nn_partitioning
15 |
16 | from transformers.configuration_utils import PretrainedConfig
17 | from transformers.modeling_flax_outputs import ModelOutput
18 |
19 | from ml_collections import ConfigDict
20 | from mlxu import function_args_to_config, load_pickle, open_file
21 |
22 | from ttt.models.bpt import blockwise_ffn, blockwise_attn
23 | from ttt.infra.jax_utils import with_sharding_constraint, get_jax_mesh, get_gradient_checkpoint_policy
24 | from ttt.models.ttt_layer import TTTLinear, TTTMLP, TTTLinearBase, TTTMLPBase, precompute_freqs_cis, apply_rotary_emb
25 |
26 |
27 | @flax.struct.dataclass
28 | class BaseModelOutput(ModelOutput):
29 | last_hidden_state: jnp.ndarray = None
30 | hidden_states: Optional[Tuple[jnp.ndarray]] = None
31 | attentions: Optional[Tuple[jnp.ndarray]] = None
32 | ttt_stats: Optional[Tuple[jnp.ndarray]] = None
33 | logits: jnp.ndarray = None
34 |
35 |
36 | CausalLMOutput = BaseModelOutput
37 |
38 | remat = nn_partitioning.remat
39 |
40 | CONFIGS = {
41 | "125m": {
42 | "vocab_size": 32000,
43 | "num_hidden_layers": 12,
44 | "hidden_size": 768,
45 | "num_attention_heads": 12,
46 | "intermediate_size": 2048,
47 | "max_sequence_length": 2048,
48 | "initializer_range": 0.02,
49 | "rms_norm_eps": 1e-6,
50 | "use_cache": True,
51 | "tie_word_embeddings": True,
52 | "seq_modeling_block": "self_attention",
53 | "use_rotary_emb": "sequence",
54 | "rope_theta": 10000.0,
55 | "pre_conv": False,
56 | },
57 | "125m-TTT": {
58 | "vocab_size": 32000,
59 | "num_hidden_layers": 12,
60 | "hidden_size": 768,
61 | "num_attention_heads": 12,
62 | "intermediate_size": 2048,
63 | "max_sequence_length": 2048,
64 | "initializer_range": 0.02,
65 | "rms_norm_eps": 1e-6,
66 | "use_cache": True,
67 | "tie_word_embeddings": True,
68 | "seq_modeling_block": "ttt_linear",
69 | "ttt_base_lr": 1.0,
70 | "ttt_base_lr_init": -1.0,
71 | "ttt_base_lr_warmup": -1,
72 | "mini_batch_size": 16,
73 | "remat_mini_batch_group_size": 4,
74 | "rope_theta": 10000.0,
75 | "pre_conv": True,
76 | "conv_width": 4,
77 | },
78 | "350m": {
79 | "vocab_size": 32000,
80 | "num_hidden_layers": 24,
81 | "hidden_size": 1024,
82 | "num_attention_heads": 16,
83 | "intermediate_size": 2736,
84 | "max_sequence_length": 2048,
85 | "initializer_range": 0.02,
86 | "rms_norm_eps": 1e-6,
87 | "use_cache": True,
88 | "tie_word_embeddings": True,
89 | "seq_modeling_block": "self_attention",
90 | "use_rotary_emb": "sequence",
91 | "rope_theta": 10000.0,
92 | "pre_conv": False,
93 | },
94 | "350m-TTT": {
95 | "vocab_size": 32000,
96 | "num_hidden_layers": 24,
97 | "hidden_size": 1024,
98 | "num_attention_heads": 16,
99 | "intermediate_size": 2736,
100 | "max_sequence_length": 2048,
101 | "initializer_range": 0.02,
102 | "rms_norm_eps": 1e-6,
103 | "use_cache": True,
104 | "tie_word_embeddings": True,
105 | "seq_modeling_block": "ttt_linear",
106 | "ttt_base_lr": 1.0,
107 | "ttt_base_lr_init": -1.0,
108 | "ttt_base_lr_warmup": -1,
109 | "mini_batch_size": 16,
110 | "remat_mini_batch_group_size": 4,
111 | "rope_theta": 10000.0,
112 | "pre_conv": True,
113 | "conv_width": 4,
114 | },
115 | "760m": {
116 | "vocab_size": 32000,
117 | "num_hidden_layers": 24,
118 | "hidden_size": 1536,
119 | "num_attention_heads": 16,
120 | "intermediate_size": 4096,
121 | "max_sequence_length": 2048,
122 | "initializer_range": 0.02,
123 | "rms_norm_eps": 1e-6,
124 | "use_cache": True,
125 | "tie_word_embeddings": True,
126 | "seq_modeling_block": "self_attention",
127 | "use_rotary_emb": "sequence",
128 | "rope_theta": 10000.0,
129 | "pre_conv": False,
130 | },
131 | "760m-TTT": {
132 | "vocab_size": 32000,
133 | "num_hidden_layers": 24,
134 | "hidden_size": 1536,
135 | "num_attention_heads": 16,
136 | "intermediate_size": 4096,
137 | "max_sequence_length": 2048,
138 | "initializer_range": 0.02,
139 | "rms_norm_eps": 1e-6,
140 | "use_cache": True,
141 | "tie_word_embeddings": True,
142 | "seq_modeling_block": "ttt_linear",
143 | "ttt_base_lr": 1.0,
144 | "ttt_base_lr_init": -1.0,
145 | "ttt_base_lr_warmup": -1,
146 | "mini_batch_size": 16,
147 | "remat_mini_batch_group_size": 4,
148 | "rope_theta": 10000.0,
149 | "pre_conv": True,
150 | "conv_width": 4,
151 | },
152 | "1b": {
153 | "vocab_size": 32000,
154 | "num_hidden_layers": 24,
155 | "hidden_size": 2048,
156 | "num_attention_heads": 32,
157 | "intermediate_size": 5504,
158 | "max_sequence_length": 2048,
159 | "initializer_range": 0.02,
160 | "rms_norm_eps": 1e-6,
161 | "use_cache": True,
162 | "tie_word_embeddings": True,
163 | "seq_modeling_block": "self_attention",
164 | "use_rotary_emb": "sequence",
165 | "rope_theta": 10000.0,
166 | "pre_conv": False,
167 | },
168 | "1b-TTT": {
169 | "vocab_size": 32000,
170 | "num_hidden_layers": 24,
171 | "hidden_size": 2048,
172 | "num_attention_heads": 32,
173 | "intermediate_size": 5504,
174 | "max_sequence_length": 2048,
175 | "initializer_range": 0.02,
176 | "rms_norm_eps": 1e-6,
177 | "use_cache": True,
178 | "tie_word_embeddings": True,
179 | "seq_modeling_block": "ttt_linear",
180 | "ttt_base_lr": 1.0,
181 | "ttt_base_lr_init": -1.0,
182 | "ttt_base_lr_warmup": -1,
183 | "mini_batch_size": 16,
184 | "remat_mini_batch_group_size": 4,
185 | "rope_theta": 10000.0,
186 | "pre_conv": True,
187 | "conv_width": 4,
188 | },
189 | }
190 |
191 |
192 | class ModelConfig(PretrainedConfig):
193 | def __init__(
194 | self,
195 | vocab_size=32000,
196 | hidden_size=4096,
197 | intermediate_size=11008,
198 | num_hidden_layers=32,
199 | num_attention_heads=32,
200 | max_sequence_length=2048,
201 | rms_norm_eps=1e-6,
202 | initializer_range=0.02,
203 | use_cache=True,
204 | bos_token_id=1,
205 | eos_token_id=2,
206 | resid_pdrop=0.0,
207 | embd_pdrop=0.0,
208 | attn_pdrop=0.0,
209 | tie_word_embeddings=False,
210 | remat_block="",
211 | remat_attention="",
212 | remat_mlp="",
213 | remat_conv="",
214 | scan_attention=False,
215 | scan_mlp=False,
216 | scan_query_chunk_size=1024,
217 | scan_key_chunk_size=1024,
218 | scan_mlp_chunk_size=1024,
219 | fcm_min_ratio=0.0,
220 | fcm_max_ratio=0.0,
221 | **kwargs,
222 | ):
223 | self.vocab_size = vocab_size
224 | self.hidden_size = hidden_size
225 | self.initializer_range = initializer_range
226 | self.intermediate_size = intermediate_size
227 | self.num_hidden_layers = num_hidden_layers
228 | self.num_attention_heads = num_attention_heads
229 | self.max_sequence_length = max_sequence_length
230 | self.rms_norm_eps = rms_norm_eps
231 | self.use_cache = use_cache
232 | self.resid_pdrop = resid_pdrop
233 | self.embd_pdrop = embd_pdrop
234 | self.attn_pdrop = attn_pdrop
235 | self.remat_block = remat_block
236 | self.remat_attention = remat_attention
237 | self.remat_mlp = remat_mlp
238 | self.remat_conv = remat_conv
239 | self.scan_attention = scan_attention
240 | self.scan_mlp = scan_mlp
241 | self.scan_query_chunk_size = scan_query_chunk_size
242 | self.scan_key_chunk_size = scan_key_chunk_size
243 | self.scan_mlp_chunk_size = scan_mlp_chunk_size
244 | self.fcm_min_ratio = fcm_min_ratio
245 | self.fcm_max_ratio = fcm_max_ratio
246 | super().__init__(
247 | bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
248 | )
249 |
250 | @classmethod
251 | def get_default_config(cls, updates=None):
252 | config = function_args_to_config(cls.__init__)
253 |
254 | if updates is not None:
255 | config.update(ConfigDict(updates).copy_and_resolve_references())
256 |
257 | return config
258 |
259 | @staticmethod
260 | def get_jax_mesh(axis_dims):
261 | return get_jax_mesh(axis_dims, ("dp", "fsdp", "mp"))
262 |
263 | @staticmethod
264 | def get_partition_rules():
265 |         """Partition rules. Note that these rules are ordered, so that
266 |         earlier rules take precedence. It is important to use
267 | PartitionSpec() instead of None here because JAX does not treat
268 | None as a pytree leaf.
269 | """
270 | return (
271 | # Embeddings
272 | ("model/wte/embedding", PS("mp", "fsdp")),
273 | # Attention/TTT
274 | ("seq_modeling_block/(wq|wk|wv)/kernel", PS("fsdp", "mp")),
275 | ("seq_modeling_block/wo/kernel", PS("mp", "fsdp")),
276 | # TTT
277 | ("seq_modeling_block/ttt_norm/scale", PS(None)),
278 | ("seq_modeling_block/ttt_norm/bias", PS(None)),
279 | ("seq_modeling_block/post_norm/scale", PS(None)),
280 | ("seq_modeling_block/post_norm/bias", PS(None)),
281 | ("seq_modeling_block/learnable_ttt_lr/kernel", PS(None)),
282 | ("seq_modeling_block/learnable_ttt_lr/bias", PS(None)),
283 | ("seq_modeling_block/ttt_dense_0", PS(None)),
284 | ("seq_modeling_block/ttt_dense_1", PS(None)),
285 | ("seq_modeling_block/ttt_bias_0", PS(None)),
286 | ("seq_modeling_block/ttt_bias_1", PS(None)),
287 | # SwiGLU MLP
288 | ("feed_forward/w1/kernel", PS("fsdp", "mp")),
289 | ("feed_forward/w2/kernel", PS("mp", "fsdp")),
290 | ("feed_forward/w3/kernel", PS("fsdp", "mp")),
291 | # RMS Norm
292 | ("seq_norm/kernel", PS(None)),
293 | ("ffn_norm/kernel", PS(None)),
294 | # Output Head
295 | ("model/ln_f/kernel", PS(None)),
296 | ("lm_head/kernel", PS("fsdp", "mp")),
297 | (".*", PS(None)),
298 | )
299 |
300 | @staticmethod
301 | def get_weight_decay_exclusions():
302 | return tuple()
303 |
304 | @staticmethod
305 | def rng_keys():
306 | return ("params", "dropout", "fcm")
307 |
308 | @classmethod
309 | def load_config(cls, path):
310 | if path in CONFIGS:
311 | return cls.from_dict(CONFIGS[path])
312 | load_type, load_path = path.split("::", 1)
313 | if load_type == "pickle":
314 | return cls.from_dict(load_pickle(load_path)["config"])
315 | elif load_type == "json":
316 | with open_file(load_path, "r") as fin:
317 | raw_config = fin.read()
318 | return cls.from_dict(json.loads(raw_config))
319 | else:
320 | raise ValueError(f"Unsupported load config type: {load_type}")
321 |
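# ModelConfig.load_config accepts three `path` forms: a key of CONFIGS above (e.g. "125m-TTT"),
# "pickle::/path/to/file.pkl" (reads the "config" entry of the pickle),
# or "json::/path/to/config.json" (reads a JSON dict of config fields).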
322 |
323 | class RMSNorm(nn.Module):
324 | dim: int
325 | eps: float = 1e-6
326 | dtype: jnp.dtype = jnp.float32
327 | param_dtype: jnp.dtype = jnp.float32
328 |
329 | def setup(self) -> None:
330 | self.weight = self.param("kernel", nn.initializers.ones, (self.dim,), self.param_dtype)
331 |
332 | def _norm(self, x: jnp.ndarray) -> jnp.ndarray:
333 | return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps)
334 |
335 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
336 | x = x.astype(jnp.promote_types(self.dtype, jnp.float32))
337 | output = self._norm(x).astype(self.dtype)
338 | weight = jnp.asarray(self.weight, self.dtype)
339 | return output * weight
340 |
341 |
342 | class SwiGLUMLP(nn.Module):
343 | config: ModelConfig
344 | dtype: jnp.dtype = jnp.float32
345 | param_dtype: jnp.dtype = jnp.float32
346 | precision: Optional[Union[jax.lax.Precision, str]] = None
347 |
348 | def setup(self) -> None:
349 | config = self.config
350 | self.w1 = nn.Dense(
351 | config.intermediate_size,
352 | dtype=self.dtype,
353 | param_dtype=self.param_dtype,
354 | use_bias=False,
355 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
356 | precision=self.precision,
357 | )
358 | self.w2 = nn.Dense(
359 | config.hidden_size,
360 | dtype=self.dtype,
361 | param_dtype=self.param_dtype,
362 | use_bias=False,
363 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
364 | precision=self.precision,
365 | )
366 | self.w3 = nn.Dense(
367 | config.intermediate_size,
368 | dtype=self.dtype,
369 | param_dtype=self.param_dtype,
370 | use_bias=False,
371 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
372 | precision=self.precision,
373 | )
374 | self.dropout = nn.Dropout(rate=self.config.resid_pdrop)
375 |
376 | def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
377 | x = self.w2(nn.silu(self.w1(x)) * self.w3(x))
378 | x = self.dropout(x, deterministic=deterministic)
379 | return x
380 |
381 |
382 | class ConvModule(nn.Module):
383 | config: ModelConfig
384 | dtype: jnp.dtype = jnp.float32
385 | param_dtype: jnp.dtype = jnp.float32
386 | precision: Optional[Union[jax.lax.Precision, str]] = None
387 |
388 | def setup(self):
389 | config = self.config
390 |
391 | if self.config.remat_conv != "":
392 | conv_module = nn_partitioning.remat(
393 | nn.Conv, policy=get_gradient_checkpoint_policy(self.config.remat_conv), prevent_cse=True
394 | )
395 | else:
396 | conv_module = nn.Conv
397 |
398 | self.conv1 = conv_module(
399 | config.hidden_size,
400 | (config.conv_width,),
401 | padding="CAUSAL",
402 | feature_group_count=config.hidden_size,
403 | dtype=self.dtype,
404 | param_dtype=self.param_dtype,
405 | precision=self.precision,
406 | )
407 | self.conv_norm = RMSNorm(
408 | self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype
409 | )
410 |
411 | def __call__(self, hidden_states):
412 | x = hidden_states
413 | x = self.conv_norm(x)
414 | x = self.conv1(x)
415 | return x
416 |
417 |
418 | class Attention(nn.Module):
419 | config: ModelConfig
420 | dtype: jnp.dtype = jnp.float32
421 | param_dtype: jnp.dtype = jnp.float32
422 | precision: Optional[Union[jax.lax.Precision, str]] = None
423 |
424 | def setup(self):
425 | config = self.config
426 | self.embed_dim = config.hidden_size
427 | self.num_heads = config.num_attention_heads
428 | self.head_dim = self.embed_dim // self.num_heads
429 |
430 | self.wq = nn.Dense(
431 | config.num_attention_heads * self.head_dim,
432 | dtype=self.dtype,
433 | param_dtype=self.param_dtype,
434 | use_bias=False,
435 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
436 | precision=self.precision,
437 | )
438 | self.wk = nn.Dense(
439 | config.num_attention_heads * self.head_dim,
440 | dtype=self.dtype,
441 | param_dtype=self.param_dtype,
442 | use_bias=False,
443 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
444 | precision=self.precision,
445 | )
446 | self.wv = nn.Dense(
447 | config.num_attention_heads * self.head_dim,
448 | dtype=self.dtype,
449 | param_dtype=self.param_dtype,
450 | use_bias=False,
451 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
452 | precision=self.precision,
453 | )
454 | self.wo = nn.Dense(
455 | config.hidden_size,
456 | dtype=self.dtype,
457 | param_dtype=self.param_dtype,
458 | use_bias=False,
459 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
460 | precision=self.precision,
461 | )
462 |
463 | self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
464 |
465 | self.causal_mask = make_causal_mask(jnp.ones((1, config.max_sequence_length), dtype="bool"), dtype="bool")
466 |
467 | self.freqs_cis = precompute_freqs_cis(
468 | self.head_dim, config.max_sequence_length * 2, theta=config.rope_theta, dtype=self.dtype
469 | )
470 |
471 | def _split_heads(self, hidden_states):
472 | return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
473 |
474 | def _merge_heads(self, hidden_states):
475 | return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
476 |
477 | @nn.compact
478 | def _concatenate_to_cache(self, key, value, query, attention_mask):
479 | """
480 | This function takes projected key, value states from a single input token and concatenates the states to cached
481 |         states from previous steps. This function is slightly adapted from the official Flax repository:
482 | https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
483 | """
484 | # detect if we're initializing by absence of existing cache data.
485 | is_initialized = self.has_variable("cache", "cached_key")
486 | cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
487 | cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
488 | cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
489 |
490 | if is_initialized:
491 | *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
492 | # update key, value caches with our new 1d spatial slices
493 | cur_index = cache_index.value
494 | indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
495 | key = lax.dynamic_update_slice(cached_key.value, key, indices)
496 | value = lax.dynamic_update_slice(cached_value.value, value, indices)
497 | cached_key.value = key
498 | cached_value.value = value
499 | num_updated_cache_vectors = query.shape[1]
500 | cache_index.value = cache_index.value + num_updated_cache_vectors
501 | # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
502 | pad_mask = jnp.broadcast_to(
503 | jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
504 | tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
505 | )
506 | attention_mask = combine_masks(pad_mask, attention_mask)
507 | return key, value, attention_mask
508 |
509 | def __call__(
510 | self,
511 | hidden_states,
512 | attention_mask,
513 | position_ids,
514 | deterministic: bool = True,
515 | init_cache: bool = False,
516 | output_attentions: bool = False,
517 | fcm_mask=None,
518 | ):
519 | xq, xk, xv = (self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states))
520 |
521 | xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), None, "mp"))
522 | xk = with_sharding_constraint(xk, PS(("dp", "fsdp"), None, "mp"))
523 | xv = with_sharding_constraint(xv, PS(("dp", "fsdp"), None, "mp"))
524 |
525 | xq = self._split_heads(xq)
526 | xk = self._split_heads(xk)
527 | xv = self._split_heads(xv)
528 |
529 | freqs_cis = jnp.take(self.freqs_cis, position_ids, axis=0)
530 |
531 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=self.dtype)
532 |
533 | dropout_rng = None
534 | if not deterministic and self.config.attn_pdrop > 0.0:
535 | dropout_rng = self.make_rng("dropout")
536 |
537 | if self.config.scan_attention and not (self.has_variable("cache", "cached_key") or init_cache):
538 |             # blockwise attention is unnecessary during autoregressive decoding, since there is no quadratic memory cost
539 |
540 |             # attention mask without n x n materialization; blockwise_attn will handle the rest
541 | attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
542 | # transform boolean mask into float mask
543 | attention_bias = lax.select(
544 | attention_mask > 0,
545 | jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
546 | jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
547 | )
548 | attn_weights = None
549 | attn_output = blockwise_attn(
550 | xq,
551 | xk,
552 | xv,
553 | bias=attention_bias,
554 | deterministic=deterministic,
555 | dropout_rng=dropout_rng,
556 | attn_pdrop=self.config.attn_pdrop,
557 | causal=True,
558 | query_chunk_size=self.config.scan_query_chunk_size,
559 | key_chunk_size=self.config.scan_key_chunk_size,
560 | dtype=self.dtype,
561 | policy=get_gradient_checkpoint_policy("nothing_saveable"),
562 | precision=self.precision,
563 | float32_logits=True,
564 | prevent_cse=True,
565 | )
566 | attn_output = with_sharding_constraint(attn_output, PS(("dp", "fsdp"), None, "mp", None))
567 | else:
568 | query_length, key_length = xq.shape[1], xk.shape[1]
569 |
570 | if self.has_variable("cache", "cached_key"):
571 | mask_shift = self.variables["cache"]["cache_index"]
572 | max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
573 | causal_mask = lax.dynamic_slice(
574 | self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
575 | )
576 | else:
577 | causal_mask = self.causal_mask[:, :, :query_length, :key_length]
578 |
579 | batch_size = hidden_states.shape[0]
580 | causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
581 |
582 | attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
583 | attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
584 |
585 | # During fast autoregressive decoding, we feed one position at a time,
586 | # and cache the keys and values step by step.
587 | if self.has_variable("cache", "cached_key") or init_cache:
588 | xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask)
589 |
590 | # transform boolean mask into float mask
591 | attention_bias = lax.select(
592 | attention_mask > 0,
593 | jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
594 | jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
595 | )
596 | attn_weights = dot_product_attention_weights(
597 | xq,
598 | xk,
599 | bias=attention_bias,
600 | dropout_rng=dropout_rng,
601 | dropout_rate=self.config.attn_pdrop,
602 | deterministic=deterministic,
603 | dtype=jnp.promote_types(self.dtype, jnp.float32),
604 | precision=self.precision,
605 | )
606 | attn_weights = with_sharding_constraint(attn_weights, PS(("dp", "fsdp"), "mp", None, None))
607 | attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, xv, precision=self.precision)
608 |
609 | attn_output = self._merge_heads(attn_output)
610 | attn_output = self.wo(attn_output)
611 | attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
612 | outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
613 | return outputs
614 |
615 |
616 | class Block(nn.Module):
617 | config: ModelConfig
618 | dtype: jnp.dtype = jnp.float32
619 | param_dtype: jnp.dtype = jnp.float32
620 | precision: Optional[Union[jax.lax.Precision, str]] = None
621 |
622 | def setup(self) -> None:
623 | if self.config.seq_modeling_block == "self_attention":
624 | seq_modeling_block = Attention
625 |
626 | elif self.config.seq_modeling_block == "ttt_linear":
627 | seq_modeling_block = TTTLinear
628 |
629 | elif self.config.seq_modeling_block == "ttt_mlp":
630 | seq_modeling_block = TTTMLP
631 |
632 | elif self.config.seq_modeling_block == "ttt_linear_base":
633 | seq_modeling_block = TTTLinearBase
634 |
635 | elif self.config.seq_modeling_block == "ttt_mlp_base":
636 | seq_modeling_block = TTTMLPBase
637 |
638 | else:
639 | raise NotImplementedError("Sequence Modeling Layer %s Not Implemented." % (self.config.seq_modeling_block))
640 |
641 | mlp_module = SwiGLUMLP
642 |
643 | if self.config.remat_attention != "":
644 | if self.config.seq_modeling_block == "self_attention":
645 | static_argnums_tuple = (3, 4, 5)
646 | else:
647 | static_argnums_tuple = (3, 4)
648 | seq_modeling_block = remat(
649 | seq_modeling_block,
650 | static_argnums=static_argnums_tuple,
651 | policy=get_gradient_checkpoint_policy(self.config.remat_attention),
652 | prevent_cse=True,
653 | )
654 | if self.config.remat_mlp != "":
655 | mlp_module = remat(
656 | SwiGLUMLP,
657 | static_argnums=(1,),
658 | policy=get_gradient_checkpoint_policy(self.config.remat_mlp),
659 | prevent_cse=True,
660 | )
661 |
662 | self.seq_modeling_block = seq_modeling_block(
663 | self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision
664 | )
665 | self.feed_forward = mlp_module(
666 | self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision
667 | )
668 | self.seq_norm = RMSNorm(
669 | self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype
670 | )
671 | self.ffn_norm = RMSNorm(
672 | self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype
673 | )
674 | if self.config.pre_conv:
675 | self.conv = ConvModule(
676 | self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision
677 | )
678 |
679 | def __call__(
680 | self,
681 | hidden_states,
682 | input_ids=None,
683 | attention_mask=None,
684 | position_ids=None,
685 | ttt_lr_mult=1.0,
686 | deterministic: bool = True,
687 | init_cache: bool = False,
688 | output_attentions: bool = False,
689 | output_ttt_stats: bool = False,
690 | fcm_mask: Optional[jnp.ndarray] = None,
691 | ):
692 | if self.config.pre_conv:
693 | conv_outputs = self.conv(hidden_states)
694 | hidden_states = hidden_states + conv_outputs
695 |
696 | hidden_states_pre_normed = self.seq_norm(hidden_states)
697 |
698 | if self.config.seq_modeling_block == "self_attention":
699 | seq_modeling_outputs = self.seq_modeling_block(
700 | hidden_states_pre_normed,
701 | attention_mask,
702 | position_ids,
703 | deterministic,
704 | init_cache,
705 | output_attentions,
706 | fcm_mask,
707 | )
708 | else:
709 | seq_modeling_outputs = self.seq_modeling_block(
710 | hidden_states_pre_normed, input_ids, position_ids, deterministic, output_ttt_stats, ttt_lr_mult
711 | )
712 |
713 | seq_modeling_output = seq_modeling_outputs[0]
714 | hidden_states = hidden_states + seq_modeling_output
715 |
716 | feed_forward_input = self.ffn_norm(hidden_states)
717 | if self.config.scan_mlp:
718 | feed_forward_hidden_states = blockwise_ffn(
719 | self.feed_forward, feed_forward_input, self.config.scan_mlp_chunk_size, deterministic
720 | )
721 | else:
722 | feed_forward_hidden_states = self.feed_forward(feed_forward_input, deterministic)
723 | feed_forward_hidden_states = with_sharding_constraint(
724 | feed_forward_hidden_states, PS(("dp", "fsdp"), None, "mp")
725 | )
726 | hidden_states = hidden_states + feed_forward_hidden_states
727 |
728 | if len(seq_modeling_outputs) > 1:
729 | if isinstance(seq_modeling_outputs[1], tuple):
730 | return (hidden_states,) + (seq_modeling_outputs[1],)
731 | else:
732 | return (hidden_states,) + ((seq_modeling_outputs[1],),)
733 | else:
734 | return (hidden_states,)
735 |
736 |
737 | class BlockCollection(nn.Module):
738 | config: ModelConfig
739 | dtype: jnp.dtype = jnp.float32
740 | param_dtype: jnp.dtype = jnp.float32
741 | precision: Optional[Union[jax.lax.Precision, str]] = None
742 |
743 | def setup(self):
744 | block = Block
745 | if self.config.remat_block != "":
746 | block = remat(
747 | Block, static_argnums=(5, 6, 7, 8), policy=get_gradient_checkpoint_policy(self.config.remat_block)
748 | )
749 | self.blocks = [
750 | block(self.config, name=str(i), dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)
751 | for i in range(self.config.num_hidden_layers)
752 | ]
753 |
754 | def __call__(
755 | self,
756 | hidden_states,
757 | input_ids=None,
758 | attention_mask=None,
759 | position_ids=None,
760 | ttt_lr_mult=1.0,
761 | deterministic: bool = True,
762 | init_cache: bool = False,
763 | output_attentions: bool = False,
764 | output_hidden_states: bool = False,
765 | output_ttt_stats: bool = False,
766 | return_dict: bool = True,
767 | ):
768 | all_attentions = () if output_attentions else None
769 | all_hidden_states = () if output_hidden_states else None
770 | all_ttt_stats = () if output_ttt_stats else None
771 |
772 | if not deterministic and self.config.fcm_max_ratio > 0:
773 | batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1]
774 | fcm_ratio = jax.random.uniform(
775 | self.make_rng("fcm"),
776 | shape=(batch_size, 1, 1, 1),
777 | minval=self.config.fcm_min_ratio,
778 | maxval=self.config.fcm_max_ratio,
779 | )
780 | fcm_mask = jax.random.uniform(self.make_rng("fcm"), shape=(batch_size, 1, 1, seq_length)) > fcm_ratio
781 | fcm_mask = fcm_mask.at[:, :, :, 0].set(True)
782 | fcm_mask = fcm_mask.astype("bool")
783 | else:
784 | fcm_mask = None
785 |
786 | for block in self.blocks:
787 | if output_hidden_states:
788 | all_hidden_states += (hidden_states,)
789 |
790 | layer_outputs = block(
791 | hidden_states,
792 | input_ids,
793 | attention_mask,
794 | position_ids,
795 | ttt_lr_mult,
796 | deterministic,
797 | init_cache,
798 | output_attentions,
799 | output_ttt_stats,
800 | fcm_mask,
801 | )
802 | hidden_states = layer_outputs[0]
803 |
804 | if output_attentions:
805 | all_attentions += (layer_outputs[1],)
806 |
807 | if output_ttt_stats:
808 | all_ttt_stats += (layer_outputs[1],)
809 |
810 | outputs = (hidden_states, all_hidden_states, all_attentions, all_ttt_stats)
811 | return outputs
812 |
813 |
814 | class Model(nn.Module):
815 | config: ModelConfig
816 | dtype: jnp.dtype = jnp.float32
817 | param_dtype: jnp.dtype = jnp.float32
818 | precision: Optional[Union[jax.lax.Precision, str]] = None
819 |
820 | def setup(self):
821 | self.embed_dim = self.config.hidden_size
822 | self.wte = nn.Embed(
823 | self.config.vocab_size,
824 | self.config.hidden_size,
825 | embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
826 | dtype=self.dtype,
827 | param_dtype=self.param_dtype,
828 | )
829 | self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
830 | self.h = BlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)
831 | self.ln_f = RMSNorm(
832 | self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype
833 | )
834 |
835 | def __call__(
836 | self,
837 | input_ids,
838 | attention_mask,
839 | position_ids,
840 | ttt_lr_mult=1.0,
841 | deterministic=True,
842 | init_cache: bool = False,
843 | output_attentions: bool = False,
844 | output_hidden_states: bool = False,
845 | output_ttt_stats: bool = False,
846 | return_dict: bool = True,
847 | ):
848 | input_embeds = self.wte(input_ids.astype("i4"))
849 | hidden_states = self.dropout(input_embeds, deterministic=deterministic)
850 | outputs = self.h(
851 | hidden_states,
852 | input_ids=input_ids,
853 | attention_mask=attention_mask,
854 | position_ids=position_ids,
855 | ttt_lr_mult=ttt_lr_mult,
856 | deterministic=deterministic,
857 | init_cache=init_cache,
858 | output_attentions=output_attentions,
859 | output_hidden_states=output_hidden_states,
860 | output_ttt_stats=output_ttt_stats,
861 | return_dict=return_dict,
862 | )
863 | hidden_states = outputs[0]
864 | hidden_states = self.ln_f(hidden_states)
865 |
866 | if output_hidden_states:
867 | all_hidden_states = outputs[1] + (hidden_states,)
868 | outputs = (hidden_states, all_hidden_states) + outputs[2:]
869 | else:
870 | outputs = (hidden_states,) + outputs[1:]
871 |
872 | if not return_dict:
873 | return tuple(v for v in outputs if v is not None)
874 |
875 | return BaseModelOutput(
876 | last_hidden_state=hidden_states, hidden_states=outputs[1], attentions=outputs[2], ttt_stats=outputs[3]
877 | )
878 |
879 |
880 | class CausalLM(nn.Module):
881 | config: ModelConfig
882 | dtype: jnp.dtype = jnp.float32
883 | param_dtype: jnp.dtype = jnp.float32
884 | precision: Optional[Union[jax.lax.Precision, str]] = None
885 |
886 | def setup(self):
887 |         self.model = Model(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)
888 | self.lm_head = nn.Dense(
889 | self.config.vocab_size,
890 | dtype=self.dtype,
891 | param_dtype=self.param_dtype,
892 | use_bias=False,
893 | kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
894 | precision=self.precision,
895 | )
896 |
897 | def __call__(
898 | self,
899 | input_ids,
900 | attention_mask=None,
901 | position_ids=None,
902 | ttt_lr_mult=1.0,
903 | deterministic: bool = True,
904 | init_cache: bool = False,
905 | output_attentions: bool = False,
906 | output_hidden_states: bool = False,
907 | output_ttt_stats: bool = False,
908 | return_dict: bool = True,
909 | ):
910 | batch_size, seq_length = input_ids.shape
911 | if attention_mask is None:
912 | attention_mask = jnp.ones_like(input_ids)
913 | if position_ids is None:
914 | position_ids = jnp.broadcast_to(
915 | jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0), (batch_size, seq_length)
916 | )
917 | outputs = self.model(
918 | input_ids,
919 | attention_mask,
920 | position_ids,
921 | ttt_lr_mult,
922 | deterministic=deterministic,
923 | init_cache=init_cache,
924 | output_attentions=output_attentions,
925 | output_hidden_states=output_hidden_states,
926 | output_ttt_stats=output_ttt_stats,
927 | return_dict=return_dict,
928 | )
929 |
930 | hidden_states = outputs[0]
931 |
932 | if self.config.tie_word_embeddings:
933 | shared_kernel = self.model.variables["params"]["wte"]["embedding"].T
934 | lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
935 | else:
936 | lm_logits = self.lm_head(hidden_states)
937 |
938 | if not return_dict:
939 | return (lm_logits,) + outputs[1:]
940 |
941 | return CausalLMOutput(
942 | logits=lm_logits,
943 | hidden_states=outputs.hidden_states,
944 | attentions=outputs.attentions,
945 | ttt_stats=outputs.ttt_stats,
946 | )
947 |
--------------------------------------------------------------------------------
/ttt/models/ttt_layer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 | from functools import partial
5 | from typing import Any, Union, Sequence, Optional, Tuple
6 |
7 | import jax
8 | import jax.numpy as jnp
9 | import flax
10 | from jax import vmap
11 | from jax.tree_util import tree_map
12 | from jax.sharding import PartitionSpec as PS
13 | from flax import linen as nn
14 | from flax.linen import partitioning as nn_partitioning
15 |
16 | from ttt.infra.jax_utils import with_sharding_constraint, get_gradient_checkpoint_policy
17 |
18 | Axes = Union[int, Sequence[int]]
19 |
20 |
21 | def scan_remat_every_n_iterations_scan(f, n, carry, x):
22 | """
23 |     Scan f over mini-batches, rematerializing activations every n iterations to save memory.
24 | """
25 | x_grouped = tree_map(lambda x: x.reshape((-1, n, *x.shape[1:])), x)
26 | carry, y_grouped = jax.lax.scan(jax.remat(partial(jax.lax.scan, f), prevent_cse=False), carry, x_grouped)
27 | y = tree_map(lambda x: x.reshape((-1, *x.shape[2:])), y_grouped)
28 | return carry, y
29 |
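# Illustration of the grouping above: with n=4 and 16 mini-batches, the scanned
# inputs are reshaped from (16, ...) to (4, 4, ...); the outer scan runs 4
# rematerialized steps, each of which re-runs an inner scan over 4 mini-batches
# during the backward pass instead of storing all of their activations.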
30 |
31 | def get_multi_head_params(self, params, param_dtype, kernel_init="normal", std=0.02):
32 | flat_params = flax.traverse_util.flatten_dict(params, sep="/")
33 | for k in flat_params.keys():
34 | new_shape = (self.num_heads, *flat_params[k].shape)
35 | if "scale" in k:
36 | p = self.param(k, jax.nn.initializers.ones, new_shape, param_dtype)
37 | elif "kernel" in k:
38 | if kernel_init == "normal":
39 | initializer = nn.initializers.normal(std)
40 | elif kernel_init == "zeros":
41 | initializer = nn.initializers.zeros
42 | elif kernel_init == "ones":
43 | initializer = nn.initializers.ones
44 | else:
45 | raise NotImplementedError("Initializer %s Not Implemented." % (kernel_init))
46 | p = self.param(k, initializer, new_shape, param_dtype)
47 | else:
48 | p = self.param(k, jax.nn.initializers.zeros, new_shape, param_dtype)
49 | flat_params[k] = p
50 | params_init = flax.traverse_util.unflatten_dict(flat_params, sep="/")
51 | return params_init
52 |
53 |
54 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
55 | freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
56 | t = np.arange(end)
57 | freqs = np.outer(t, freqs).astype(dtype)
58 | sin, cos = np.sin(freqs), np.cos(freqs)
59 | freqs_cis = np.complex64(cos + 1j * sin)
60 | return jnp.asarray(freqs_cis)
61 |
62 |
63 | def apply_rotary_emb(
64 | xq: jnp.ndarray, xk: jnp.ndarray, freqs_cis: jnp.ndarray, dtype: jnp.dtype = jnp.float32
65 | ) -> Tuple[jnp.ndarray, jnp.ndarray]:
66 |
67 | reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
68 | reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
69 |
70 | xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
71 | xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
72 |
73 | freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:]))
74 |
75 | xq_out = xq_ * freqs_cis
76 | xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)
77 |
78 | xk_out = xk_ * freqs_cis
79 | xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)
80 |
81 | return xq_out.astype(dtype), xk_out.astype(dtype)
82 |
83 |
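# Closed-form derivative of the tanh-approximated GELU; TTTMLPBase uses it to
# backpropagate the inner-loop gradient through the hidden GELU without an extra VJP.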
84 | def diff_gelu(x):
85 | tanh_out = jnp.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
86 | ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
87 | return ff
88 |
89 |
90 | class LinearLayerTemplate(nn.Module):
91 | width: int
92 | use_bias: bool
93 | name: str
94 | dtype: jnp.dtype = jnp.float32
95 | param_dtype: jnp.dtype = jnp.float32
96 |
97 | @nn.compact
98 | def __call__(self, x):
99 | x = nn.Dense(
100 | self.width, use_bias=self.use_bias, name=self.name, dtype=self.dtype, param_dtype=self.param_dtype
101 | )(x)
102 | return x
103 |
104 |
105 | class LayerNormTemplate(nn.Module):
106 | name: str
107 | dtype: jnp.dtype = jnp.float32
108 | param_dtype: jnp.dtype = jnp.float32
109 |
110 | @nn.compact
111 | def __call__(self, x):
112 | x = nn.LayerNorm(name=self.name, dtype=self.dtype, param_dtype=self.param_dtype)(x)
113 | return x
114 |
115 |
116 | class TTTBase(nn.Module):
117 | config: Any = None
118 | dtype: jnp.dtype = jnp.float32
119 | param_dtype: jnp.dtype = jnp.float32
120 | precision: Optional[Union[jax.lax.Precision, str]] = None
121 |
122 | def setup(self):
123 | self.width = self.config.hidden_size
124 | self.num_heads = self.config.num_attention_heads
125 | self.head_dim = self.width // self.num_heads
126 | self.mini_batch_size = self.config.mini_batch_size
127 | self.n_mini_batch = self.config.max_sequence_length // self.mini_batch_size
128 | self.seq_shape = (self.n_mini_batch, self.mini_batch_size)
129 | self.freqs_cis = precompute_freqs_cis(
130 | self.head_dim, self.mini_batch_size * 2, theta=self.config.rope_theta, dtype=self.dtype
131 | )
132 |
133 | self.setup_qkvo()
134 | self.setup_token_idx()
135 | self.setup_ttt_lr_gate()
136 |
137 | self.ttt_norm = LayerNormTemplate(dtype=self.dtype, param_dtype=self.param_dtype)
138 | ttt_norm_params = self.ttt_norm.init(jax.random.PRNGKey(0), jnp.ones([1, self.head_dim]))["params"]
139 | self.ttt_norm_params = get_multi_head_params(
140 | self, ttt_norm_params, param_dtype=self.param_dtype, kernel_init="layer_norm"
141 | )
142 | self.post_norm = nn.LayerNorm(dtype=self.dtype, param_dtype=self.param_dtype)
143 |
144 | self.ttt_params = ()
145 |
146 | def setup_qkvo(self):
147 | self.wq = nn.Dense(
148 | self.num_heads * self.head_dim,
149 | dtype=self.dtype,
150 | param_dtype=self.param_dtype,
151 | use_bias=False,
152 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
153 | precision=self.precision,
154 | )
155 | self.wk = nn.Dense(
156 | self.num_heads * self.head_dim,
157 | dtype=self.dtype,
158 | param_dtype=self.param_dtype,
159 | use_bias=False,
160 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
161 | precision=self.precision,
162 | )
163 | self.wv = nn.Dense(
164 | self.num_heads * self.head_dim,
165 | dtype=self.dtype,
166 | param_dtype=self.param_dtype,
167 | use_bias=False,
168 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
169 | precision=self.precision,
170 | )
171 | self.wo = nn.Dense(
172 | self.width,
173 | dtype=self.dtype,
174 | param_dtype=self.param_dtype,
175 | use_bias=False,
176 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
177 | precision=self.precision,
178 | )
179 |
180 | def setup_token_idx(self):
181 | self.token_idx = 1.0 / jnp.arange(1, self.mini_batch_size + 1, dtype=jnp.float32)
182 | self.learnable_token_idx = self.param(
183 | "learnable_token_idx", nn.initializers.zeros, (self.mini_batch_size,), jnp.float32
184 | )
185 |
186 | def setup_ttt_lr_gate(self):
187 | self.learnable_ttt_lr = LinearLayerTemplate(
188 | width=1, use_bias=True, name="learnable_ttt_lr", dtype=self.dtype, param_dtype=self.param_dtype
189 | )
190 | learnable_ttt_lr_params = self.learnable_ttt_lr.init(jax.random.PRNGKey(0), jnp.ones([1, self.width]))["params"]
191 | self.learnable_ttt_lr_params = get_multi_head_params(
192 | self,
193 | learnable_ttt_lr_params,
194 | param_dtype=self.param_dtype,
195 | kernel_init="normal",
196 | std=self.config.initializer_range,
197 | )
198 |
199 | def _split_heads(self, hidden_states):
200 | return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
201 |
202 | def _split_mini_batches(self, hidden_states):
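        # (B, N, num_heads, head_dim) -> (B, num_heads, n_mini_batch, mini_batch_size, head_dim)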
203 | B, N, num_head, head_dim = hidden_states.shape
204 | hidden_states = hidden_states.reshape(B, *self.seq_shape, self.num_heads, self.head_dim).transpose(
205 | 0, 3, 1, 2, 4
206 | )
207 | return hidden_states
208 |
209 | def get_qkv_projections(self, batch):
210 | XQ, XK, XV = self.wq(batch), self.wk(batch), self.wv(batch)
211 | return XQ, XK, XV
212 |
213 | def get_eta(self, X):
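        # eta combines a learned per-token gate (a sigmoid over a width-1 head) with the
        # 1/t token_idx schedule, scaled by ttt_base_lr / head_dim. The result has shape
        # (B, num_heads, n_mini_batch, mini_batch_size, mini_batch_size): row i of each
        # square block holds the inner-loop step sizes used when producing output i, and
        # its last row gives the step sizes for the end-of-mini-batch weight update.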
214 | learnable_ttt_lr = vmap(
215 | lambda x, p: self.learnable_ttt_lr.apply({"params": p}, x), axis_name="head", in_axes=[None, 0], out_axes=1
216 | )(X, self.learnable_ttt_lr_params)
217 | learnable_ttt_lr = nn.sigmoid(learnable_ttt_lr)
218 | learnable_ttt_lr = learnable_ttt_lr.transpose(0, 1, 2, 4, 3)
219 |
220 | token_idx = self.learnable_token_idx + self.token_idx
221 | token_idx = jnp.clip(token_idx, a_min=0.0)
222 |
223 | eta = (
224 | (self.config.ttt_base_lr * token_idx).reshape(1, 1, 1, token_idx.shape[0], -1)
225 | * learnable_ttt_lr
226 | / self.head_dim
227 | )
228 | return eta
229 |
230 | def get_ttt_inputs(self, batch, position_ids):
231 | B, N, F = batch.shape
232 | n_mini_batch = N // self.mini_batch_size
233 | X = batch.reshape(B, *self.seq_shape, self.width)
234 |
235 | XQ, XK, XV = self.get_qkv_projections(batch)
236 |
237 | if self.config.output_ttt_stats:
238 | XV_last_in_mini_batch = XV[:, :: self.mini_batch_size, ...].reshape(
239 | B, n_mini_batch, self.num_heads, self.head_dim
240 | )
241 | XK_last_in_mini_batch = XK[:, :: self.mini_batch_size, ...].reshape(
242 | B, n_mini_batch, self.num_heads, self.head_dim
243 | )
244 | ssl_tgt_last_in_mini_batch = XV_last_in_mini_batch - XK_last_in_mini_batch
245 | ssl_tgt_mean = (XV - XK).mean(axis=1, keepdims=True).reshape(B, 1, self.num_heads, self.head_dim)
246 | ssl_tgt_last_in_mini_batch_from_mean_mse = ((ssl_tgt_last_in_mini_batch - ssl_tgt_mean) ** 2).mean(
247 | axis=(0, 2, 3)
248 | )
249 | else:
250 | ssl_tgt_last_in_mini_batch_from_mean_mse = None
251 |
252 | XQ = with_sharding_constraint(XQ, PS(("dp", "fsdp"), None, "mp"))
253 | XK = with_sharding_constraint(XK, PS(("dp", "fsdp"), None, "mp"))
254 | XV = with_sharding_constraint(XV, PS(("dp", "fsdp"), None, "mp"))
255 |
256 | XQ = self._split_heads(XQ)
257 | XK = self._split_heads(XK)
258 | XV = self._split_heads(XV)
259 |
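        # RoPE is indexed by position modulo the mini-batch size, so rotary phases
        # restart at every mini-batch boundary.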
260 | freqs_cis = jnp.take(self.freqs_cis, position_ids % self.mini_batch_size, axis=0)
261 | XQ, XK = apply_rotary_emb(XQ, XK, freqs_cis=freqs_cis, dtype=self.dtype)
262 |
263 | XQ = self._split_mini_batches(XQ)
264 | XK = self._split_mini_batches(XK)
265 | XV = self._split_mini_batches(XV)
266 |
267 | eta = self.get_eta(X)
268 |
269 | return (XQ, XK, XV, eta, (ssl_tgt_last_in_mini_batch_from_mean_mse,))
270 |
271 | def apply_gate(self, hidden_states, ttt_output):
272 | return ttt_output
273 |
274 | def project_ttt_outputs(self, XQW_batch):
275 | z_batch = self.wo(XQW_batch)
276 | return z_batch
277 |
278 | def process_mini_batch(
279 | self,
280 | XQ_mini_batch,
281 | XK_mini_batch,
282 | XV_mini_batch,
283 | eta_mini_batch,
284 | ttt_params_init,
285 | ttt_params_mini_batch_init,
286 | ttt_norm_params,
287 | ):
288 | raise NotImplementedError
289 |
290 | def ttt(self, XQ, XK, XV, eta, input_ids):
291 | B, N = XV.shape[0], XV.shape[2] * XV.shape[3]
292 |
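        # Nested vmaps: the outer vmap runs over the batch axis, the inner one over heads;
        # within each head, a (periodically rematerialized) scan iterates over mini-batches,
        # carrying the TTT fast weights between them.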
293 | @partial(vmap, axis_name="batch")
294 | def update_embed(XQ, XK, XV, eta):
295 | @partial(vmap, axis_name="head")
296 | def parallelize_over_heads(XQ, XK, XV, eta, ttt_params_init, ttt_norm_params):
297 | def compute_mini_batch(ttt_params_mini_batch_init, inputs):
298 | XQ_mini_batch = inputs["XQ"]
299 | XK_mini_batch = inputs["XK"]
300 | XV_mini_batch = inputs["XV"]
301 | eta_mini_batch = inputs["eta"]
302 |
303 | ttt_params_last_in_mini_batch, outputs = self.process_mini_batch(
304 | XQ_mini_batch,
305 | XK_mini_batch,
306 | XV_mini_batch,
307 | eta_mini_batch,
308 | ttt_params_init,
309 | ttt_params_mini_batch_init,
310 | ttt_norm_params,
311 | )
312 | return ttt_params_last_in_mini_batch, outputs
313 |
314 | inputs = {"XQ": XQ, "XK": XK, "XV": XV, "eta": eta}
315 |
316 | _, outputs = scan_remat_every_n_iterations_scan(
317 | compute_mini_batch, self.config.remat_mini_batch_group_size, ttt_params_init, inputs
318 | )
319 | Z, ttt_loss_mse_init, ttt_loss_mse_step_0, ttt_loss_mse_step_1 = outputs
320 | return (Z.reshape(-1, self.head_dim), ttt_loss_mse_init, ttt_loss_mse_step_0, ttt_loss_mse_step_1)
321 |
322 | outputs = parallelize_over_heads(XQ, XK, XV, eta, self.ttt_params, self.ttt_norm_params)
323 | return outputs
324 |
325 | outputs = update_embed(XQ, XK, XV, eta)
326 | Z, ttt_loss_mse_init, ttt_loss_mse_step_0, ttt_loss_mse_step_1 = outputs
327 | Z = Z.transpose(0, 2, 1, 3).reshape(B, N, -1)
328 |
329 | if self.config.output_ttt_stats:
330 | ttt_loss_mse_init = ttt_loss_mse_init.mean(axis=(0, 1))
331 | ttt_loss_mse_step_0 = ttt_loss_mse_step_0.mean(axis=(0, 1))
332 | ttt_loss_mse_step_1 = ttt_loss_mse_step_1.mean(axis=(0, 1))
333 |
334 | return Z, (ttt_loss_mse_init, ttt_loss_mse_step_0, ttt_loss_mse_step_1)
335 |
336 | def __call__(
337 | self,
338 | hidden_states,
339 | input_ids=None,
340 | position_ids=None,
341 | deterministic: bool = True,
342 | output_ttt_stats: bool = False,
343 | ttt_lr_mult=1.0,
344 | ):
345 | self.config.output_ttt_stats = output_ttt_stats
346 | del deterministic
347 | XQ, XK, XV, eta, precompute_stats = self.get_ttt_inputs(hidden_states, position_ids=position_ids)
348 | eta *= ttt_lr_mult
349 | Z, ttt_stats = self.ttt(XQ, XK, XV, eta, input_ids)
350 | Z = self.post_norm(Z)
351 | Z = self.apply_gate(hidden_states, Z)
352 | ttt_output = self.project_ttt_outputs(Z)
353 | return ttt_output, (*precompute_stats, *ttt_stats)
354 |
355 |
356 | class TTTLinearBase(TTTBase):
357 | def setup(self):
358 | super().setup()
359 | self.W1 = self.param(
360 | "ttt_dense_0",
361 | nn.initializers.normal(self.config.initializer_range),
362 | (self.num_heads, self.head_dim, self.head_dim),
363 | self.param_dtype,
364 | )
365 | self.b1 = self.param("ttt_bias_0", nn.initializers.zeros, (self.num_heads, 1, self.head_dim), self.param_dtype)
366 | self.ttt_params = (self.W1, self.b1)
367 |
368 | def process_mini_batch(
369 | self,
370 | XQ_mini_batch,
371 | XK_mini_batch,
372 | XV_mini_batch,
373 | eta_mini_batch,
374 | ttt_params_init,
375 | ttt_params_mini_batch_init,
376 | ttt_norm_params,
377 | ):
378 |
379 | W1_init, b1_init = ttt_params_mini_batch_init
380 | square_eta_mini_batch = eta_mini_batch[: self.mini_batch_size]
381 | last_eta_in_mini_batch = eta_mini_batch[-1][:, None]
382 |
383 | X1 = XK_mini_batch
384 | Z1 = X1 @ W1_init + b1_init
385 | ttt_norm_out, ttt_norm_vjp = jax.vjp(lambda z: self.ttt_norm.apply({"params": ttt_norm_params}, z), Z1)
386 | ssl_target = XV_mini_batch - XK_mini_batch
387 | grad_l_wrt_ttt_norm_out = ttt_norm_out - ssl_target
388 | grad_l_wrt_Z1 = ttt_norm_vjp(grad_l_wrt_ttt_norm_out)[0]
389 |
390 | # Calculate TTT loss using W_init of the current mini-batch
391 | if self.config.output_ttt_stats:
392 | ttt_loss_mse_step_0 = (grad_l_wrt_ttt_norm_out[-1] ** 2).mean()
393 | else:
394 | ttt_loss_mse_step_0 = None
395 |
396 | # Calculate TTT loss using W_init of the entire sequence
397 | if self.config.output_ttt_stats:
398 | W1_0, b1_0 = ttt_params_init
399 | Z1_0 = X1 @ W1_0 + b1_0
400 | ttt_norm_out_0 = self.ttt_norm.apply({"params": ttt_norm_params}, Z1_0)
401 | ttt_loss_mse_init = ((ttt_norm_out_0 - ssl_target)[-1] ** 2).mean()
402 | else:
403 | ttt_loss_mse_init = None
404 |
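        # Dual form of the inner-loop update: instead of materializing per-token weights,
        # the causal (tril) masks below compute, for each token t, the output of the linear
        # learner after accumulating the gradients of all tokens <= t, with every gradient
        # evaluated at the mini-batch-initial weights W1_init / b1_init.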
405 | X1_bar = XQ_mini_batch
406 | Attn1 = jnp.tril(X1_bar @ X1.transpose(1, 0))
407 | b1_bar = b1_init - (square_eta_mini_batch * jnp.tril(jnp.ones_like(Attn1))) @ grad_l_wrt_Z1
408 | Z1_bar = X1_bar @ W1_init - (square_eta_mini_batch * Attn1) @ grad_l_wrt_Z1 + b1_bar
409 | ttt_norm_out_bar = self.ttt_norm.apply({"params": ttt_norm_params}, Z1_bar)
410 |
411 | output_mini_batch = X1_bar + ttt_norm_out_bar
412 |
413 | W1_bar_last = W1_init - (last_eta_in_mini_batch * X1).transpose(1, 0) @ grad_l_wrt_Z1
414 | b1_bar_last = b1_init - jnp.sum(last_eta_in_mini_batch * grad_l_wrt_Z1, axis=0, keepdims=True)
415 |
416 |         # Calculate TTT loss using the weights updated by the current mini-batch
417 | if self.config.output_ttt_stats:
418 | X1_last_fwd_new = X1[-1:] @ W1_bar_last + b1_bar_last
419 | X1_last_fwd_new = self.ttt_norm.apply({"params": ttt_norm_params}, X1_last_fwd_new)
420 | ttt_loss_mse_step_1 = ((X1_last_fwd_new - ssl_target[-1:]) ** 2).mean()
421 | else:
422 | ttt_loss_mse_step_1 = None
423 |
424 | ttt_params_mini_batch_new = (W1_bar_last, b1_bar_last)
425 |
426 | return (
427 | ttt_params_mini_batch_new,
428 | (output_mini_batch, ttt_loss_mse_init, ttt_loss_mse_step_0, ttt_loss_mse_step_1),
429 | )
430 |
431 |
432 | class TTTLinear(TTTLinearBase):
433 | def setup(self):
434 | super().setup()
435 | self.wg = nn.Dense(
436 | self.width,
437 | dtype=self.dtype,
438 | param_dtype=self.param_dtype,
439 | use_bias=False,
440 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
441 | precision=self.precision,
442 | )
443 |
444 | def setup_qkvo(self):
445 | self.wq = nn.Dense(
446 | self.num_heads * self.head_dim,
447 | dtype=self.dtype,
448 | param_dtype=self.param_dtype,
449 | use_bias=False,
450 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
451 | precision=self.precision,
452 | )
453 | if self.config.remat_conv != "":
454 | conv_module = nn_partitioning.remat(
455 | nn.Conv, policy=get_gradient_checkpoint_policy(self.config.remat_conv), prevent_cse=True
456 | )
457 | else:
458 | conv_module = nn.Conv
459 | self.conv_q = conv_module(
460 | self.config.hidden_size,
461 | (self.config.conv_width,),
462 | padding="CAUSAL",
463 | feature_group_count=self.config.hidden_size,
464 | dtype=self.dtype,
465 | param_dtype=self.param_dtype,
466 | precision=self.precision,
467 | )
468 | self.conv_k = conv_module(
469 | self.config.hidden_size,
470 | (self.config.conv_width,),
471 | padding="CAUSAL",
472 | feature_group_count=self.config.hidden_size,
473 | dtype=self.dtype,
474 | param_dtype=self.param_dtype,
475 | precision=self.precision,
476 | )
477 | self.wv = nn.Dense(
478 | self.num_heads * self.head_dim,
479 | dtype=self.dtype,
480 | param_dtype=self.param_dtype,
481 | use_bias=False,
482 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
483 | precision=self.precision,
484 | )
485 | self.wo = nn.Dense(
486 | self.width,
487 | dtype=self.dtype,
488 | param_dtype=self.param_dtype,
489 | use_bias=False,
490 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
491 | precision=self.precision,
492 | )
493 |
494 | def get_qkv_projections(self, batch):
495 | xqk, XV = self.wq(batch), self.wv(batch)
496 | XQ = self.conv_q(xqk)
497 | XK = self.conv_k(xqk)
498 | return XQ, XK, XV
499 |
500 | def apply_gate(self, hidden_states, ttt_output):
501 | y = self.wg(hidden_states)
502 | y = nn.gelu(y)
503 | output = y * ttt_output
504 | return output
505 |
506 |
507 | class TTTMLPBase(TTTBase):
508 | def setup(self):
509 | super().setup()
510 | self.W1 = self.param(
511 | "ttt_dense_0",
512 | nn.initializers.normal(self.config.initializer_range),
513 | (self.num_heads, self.head_dim, 4 * self.head_dim),
514 | self.param_dtype,
515 | )
516 | self.b1 = self.param(
517 | "ttt_bias_0", nn.initializers.zeros, (self.num_heads, 1, 4 * self.head_dim), self.param_dtype
518 | )
519 | self.W2 = self.param(
520 | "ttt_dense_1",
521 | nn.initializers.normal(self.config.initializer_range),
522 | (self.num_heads, 4 * self.head_dim, self.head_dim),
523 | self.param_dtype,
524 | )
525 | self.b2 = self.param("ttt_bias_1", nn.initializers.zeros, (self.num_heads, 1, self.head_dim), self.param_dtype)
526 | self.ttt_params = (self.W1, self.W2, self.b1, self.b2)
527 |
528 | def process_mini_batch(
529 | self,
530 | XQ_mini_batch,
531 | XK_mini_batch,
532 | XV_mini_batch,
533 | eta_mini_batch,
534 | ttt_params_init,
535 | ttt_params_mini_batch_init,
536 | ttt_norm_params,
537 | ):
538 |
539 | W1_init, W2_init, b1_init, b2_init = ttt_params_mini_batch_init
540 | square_eta_mini_batch = eta_mini_batch[: self.mini_batch_size]
541 | last_eta_in_mini_batch = eta_mini_batch[-1][:, None]
542 |
543 | X1 = XK_mini_batch
544 | Z1 = X1 @ W1_init + b1_init
545 | X2 = nn.gelu(Z1)
546 | Z2 = X2 @ W2_init + b2_init
547 | ttt_norm_out, ttt_norm_vjp = jax.vjp(lambda z: self.ttt_norm.apply({"params": ttt_norm_params}, z), Z2)
548 |
549 | ssl_target = XV_mini_batch - X1
550 | grad_l_wrt_ttt_norm_out = ttt_norm_out - ssl_target
551 | grad_l_wrt_Z2 = ttt_norm_vjp(grad_l_wrt_ttt_norm_out)[0]
552 | grad_l_wrt_Z1 = grad_l_wrt_Z2 @ W2_init.transpose(1, 0) * diff_gelu(Z1)
553 |
554 | if self.config.output_ttt_stats:
555 | ttt_loss_mse_step_0 = (grad_l_wrt_ttt_norm_out[-1] ** 2).mean()
556 | else:
557 | ttt_loss_mse_step_0 = None
558 |
559 |         # Calculate TTT loss using W_init of the entire sequence
560 | if self.config.output_ttt_stats:
561 | W1_0, W2_0, b1_0, b2_0 = ttt_params_init
562 | Z1_0 = X1 @ W1_0 + b1_0
563 | X2_0 = nn.gelu(Z1_0)
564 | Z2_0 = X2_0 @ W2_0 + b2_0
565 | ttt_norm_out_0 = self.ttt_norm.apply({"params": ttt_norm_params}, Z2_0)
566 | ttt_loss_mse_init = ((ttt_norm_out_0 - ssl_target)[-1] ** 2).mean()
567 | else:
568 | ttt_loss_mse_init = None
569 |
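        # Same dual-form trick as in TTTLinearBase, applied layer by layer through the
        # two-layer MLP: first reconstruct Z1_bar, then feed nn.gelu(Z1_bar) through the
        # analogous update for the second layer.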
570 | X1_bar = XQ_mini_batch
571 | Attn1 = jnp.tril(X1_bar @ X1.transpose(1, 0))
572 | b1_bar = b1_init - (square_eta_mini_batch * jnp.tril(jnp.ones_like(Attn1))) @ grad_l_wrt_Z1
573 | Z1_bar = X1_bar @ W1_init - (square_eta_mini_batch * Attn1) @ grad_l_wrt_Z1 + b1_bar
574 |
575 | X2_bar = nn.gelu(Z1_bar)
576 | Attn2 = jnp.tril(X2_bar @ X2.transpose(1, 0))
577 | b2_bar = b2_init - (square_eta_mini_batch * jnp.tril(jnp.ones_like(Attn2))) @ grad_l_wrt_Z2
578 | Z2_bar = X2_bar @ W2_init - (square_eta_mini_batch * Attn2) @ grad_l_wrt_Z2 + b2_bar
579 | ttt_norm_out_bar = self.ttt_norm.apply({"params": ttt_norm_params}, Z2_bar)
580 |
581 | output_mini_batch = X1_bar + ttt_norm_out_bar
582 |
583 | W1_bar_last = W1_init - (last_eta_in_mini_batch * X1).transpose(1, 0) @ grad_l_wrt_Z1
584 | W2_bar_last = W2_init - (last_eta_in_mini_batch * X2).transpose(1, 0) @ grad_l_wrt_Z2
585 | b1_bar_last = b1_init - jnp.sum(last_eta_in_mini_batch * grad_l_wrt_Z1, axis=0, keepdims=True)
586 | b2_bar_last = b2_init - jnp.sum(last_eta_in_mini_batch * grad_l_wrt_Z2, axis=0, keepdims=True)
587 |
588 | if self.config.output_ttt_stats:
589 | X1_last_fwd_new = nn.gelu((X1[-1:] @ W1_bar_last) + b1_bar_last) @ W2_bar_last + b2_bar_last
590 | X1_last_fwd_new = self.ttt_norm.apply({"params": ttt_norm_params}, X1_last_fwd_new)
591 | ttt_loss_mse_step_1 = ((X1_last_fwd_new - ssl_target[-1:]) ** 2).mean()
592 | else:
593 | ttt_loss_mse_step_1 = None
594 |
595 | ttt_params_mini_batch_new = (W1_bar_last, W2_bar_last, b1_bar_last, b2_bar_last)
596 |
597 | return (
598 | ttt_params_mini_batch_new,
599 | (output_mini_batch, ttt_loss_mse_init, ttt_loss_mse_step_0, ttt_loss_mse_step_1),
600 | )
601 |
602 |
603 | class TTTMLP(TTTMLPBase):
604 | def setup(self):
605 | super().setup()
606 | self.wg = nn.Dense(
607 | self.width,
608 | dtype=self.dtype,
609 | param_dtype=self.param_dtype,
610 | use_bias=False,
611 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
612 | precision=self.precision,
613 | )
614 |
615 | def setup_qkvo(self):
616 | self.wq = nn.Dense(
617 | self.num_heads * self.head_dim,
618 | dtype=self.dtype,
619 | param_dtype=self.param_dtype,
620 | use_bias=False,
621 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
622 | precision=self.precision,
623 | )
624 | if self.config.remat_conv != "":
625 | conv_module = nn_partitioning.remat(
626 | nn.Conv, policy=get_gradient_checkpoint_policy(self.config.remat_conv), prevent_cse=True
627 | )
628 | else:
629 | conv_module = nn.Conv
630 | self.conv_q = conv_module(
631 | self.config.hidden_size,
632 | (self.config.conv_width,),
633 | padding="CAUSAL",
634 | feature_group_count=self.config.hidden_size,
635 | dtype=self.dtype,
636 | param_dtype=self.param_dtype,
637 | precision=self.precision,
638 | )
639 | self.conv_k = conv_module(
640 | self.config.hidden_size,
641 | (self.config.conv_width,),
642 | padding="CAUSAL",
643 | feature_group_count=self.config.hidden_size,
644 | dtype=self.dtype,
645 | param_dtype=self.param_dtype,
646 | precision=self.precision,
647 | )
648 | self.wv = nn.Dense(
649 | self.num_heads * self.head_dim,
650 | dtype=self.dtype,
651 | param_dtype=self.param_dtype,
652 | use_bias=False,
653 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
654 | precision=self.precision,
655 | )
656 | self.wo = nn.Dense(
657 | self.width,
658 | dtype=self.dtype,
659 | param_dtype=self.param_dtype,
660 | use_bias=False,
661 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
662 | precision=self.precision,
663 | )
664 |
665 | def get_qkv_projections(self, batch):
666 | xqk, XV = self.wq(batch), self.wv(batch)
667 | XQ = self.conv_q(xqk)
668 | XK = self.conv_k(xqk)
669 | return XQ, XK, XV
670 |
671 | def apply_gate(self, hidden_states, ttt_output):
672 | y = self.wg(hidden_states)
673 | y = nn.gelu(y)
674 | output = y * ttt_output
675 | return output
676 |
--------------------------------------------------------------------------------
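The dual form implemented in `TTTLinearBase.process_mini_batch` can be sanity-checked in isolation. The sketch below is illustrative only (toy sizes, a constant `eta`, variable names of my own choosing, and `ttt_norm` omitted): it confirms that the masked-matrix computation of `Z1_bar` matches explicitly materializing the per-token weights and applying them one token at a time.

import jax.numpy as jnp
from jax import random

b, d = 4, 8                                   # toy mini_batch_size and head_dim
kq, kk, kv, kw = random.split(random.PRNGKey(0), 4)
XQ = random.normal(kq, (b, d))
XK = random.normal(kk, (b, d))
XV = random.normal(kv, (b, d))
W0 = 0.02 * random.normal(kw, (d, d))
b0 = jnp.zeros((1, d))
eta = 0.1 * jnp.ones((b, b))                  # constant inner-loop step sizes

# Gradient of the l2 reconstruction loss, taken at (W0, b0) for every token
grad_Z = (XK @ W0 + b0) - (XV - XK)

# Dual form (as in process_mini_batch, with ttt_norm omitted)
Attn = jnp.tril(XQ @ XK.T)
b_bar = b0 - (eta * jnp.tril(jnp.ones_like(Attn))) @ grad_Z
Z_bar = XQ @ W0 - (eta * Attn) @ grad_Z + b_bar

# Primal form: materialize the per-token weights explicitly
Z_ref = []
for t in range(b):
    W_t = W0 - (eta[t, : t + 1, None] * XK[: t + 1]).T @ grad_Z[: t + 1]
    b_t = b0 - (eta[t, : t + 1, None] * grad_Z[: t + 1]).sum(0, keepdims=True)
    Z_ref.append(XQ[t] @ W_t + b_t[0])

print(jnp.allclose(Z_bar, jnp.stack(Z_ref), atol=1e-5))   # expected: True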
/ttt/train.py:
--------------------------------------------------------------------------------
1 | import mlxu
2 | import wandb
3 | import os.path as osp
4 |
5 | from tqdm import tqdm
6 | from copy import deepcopy
7 |
8 | import jax
9 | import jax.numpy as jnp
10 | from jax.tree_util import tree_map
11 | from jax.experimental.pjit import pjit
12 | from jax.sharding import PartitionSpec as PS
13 | from jax.experimental.multihost_utils import process_allgather
14 | from flax.training.train_state import TrainState
15 | from flax.traverse_util import flatten_dict
16 |
17 | from ttt.infra.optimizers import OptimizerFactory
18 | from ttt.dataloader.language_modeling_hf import LMDataModule
19 | from ttt.infra.checkpoint import StreamingCheckpointer
20 | from ttt.models.model import ModelConfig, CausalLM
21 | from ttt.infra.jax_utils import (
22 | JaxRNG,
23 | JaxDistributedConfig,
24 | next_rng,
25 | match_partition_rules,
26 | cross_entropy_loss_and_accuracy,
27 | global_norm,
28 | get_float_dtype_by_name,
29 | set_random_seed,
30 | average_metrics,
31 | get_weight_decay_mask,
32 | make_shard_and_gather_fns,
33 | with_sharding_constraint,
34 | master_print,
35 | log_ttt_stats,
36 | )
37 |
38 |
39 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
40 | seed=0,
41 | mesh_dim="-1,64,1",
42 | dtype="fp32",
43 | eval_mode=False,
44 | load_part="trainstate",
45 | total_steps=100,
46 | load_model_config="",
47 | update_model_config="",
48 | save_checkpoint_freq=100,
49 | save_milestone_freq=0,
50 | dataset_path="",
51 | dataset_name="the_pile",
52 | tokenizer_name="meta-llama/Llama-2-7b-hf",
53 | seq_length=2048,
54 | global_batch_size=1,
55 | accum_steps=1,
56 | loader_workers=48,
57 | optimizer=OptimizerFactory.get_default_config(),
58 | checkpointer=StreamingCheckpointer.get_default_config(),
59 | exp_dir="",
60 | exp_name="",
61 | resume_exp_name="",
62 | resume_step="",
63 | jax_distributed=JaxDistributedConfig.get_default_config(),
64 | is_rollback_reshuffle=False,
65 | )
66 |
67 |
68 | def make_train_step_fn(model, optimizer_info, model_config, accum_steps=1):
69 |
70 | if accum_steps == 1:
71 |
72 | def train_step(train_state, rng, batch, ttt_lr_mult, output_ttt_stats=False):
73 | rng_generator = JaxRNG(rng)
74 | batch = with_sharding_constraint(batch, PS(("dp", "fsdp")))
75 |
76 | def loss_and_accuracy(params):
77 | outputs = model.apply(
78 | params,
79 | batch["input_tokens"],
80 | ttt_lr_mult=ttt_lr_mult,
81 | deterministic=False,
82 | output_ttt_stats=output_ttt_stats,
83 | rngs=rng_generator(model_config.rng_keys()),
84 | )
85 | logits = outputs.logits
86 | ttt_stats = outputs.ttt_stats
87 | loss, _ = cross_entropy_loss_and_accuracy(logits, batch["target_tokens"], batch["loss_masks"])
88 | return loss, ttt_stats
89 |
90 | grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
91 | (loss, ttt_stats), grads = grad_fn(train_state.params)
92 |
93 | train_state = train_state.apply_gradients(grads=grads)
94 | learning_rate = optimizer_info["learning_rate_schedule"](train_state.step)
95 | grads_norm = global_norm(grads)
96 |
97 | return (train_state, loss, ttt_stats, grads_norm, learning_rate, rng_generator())
98 |
99 | elif accum_steps > 1:
100 |
101 | def train_step(train_state, rng, batch, ttt_lr_mult, output_ttt_stats=False):
102 | rng_generator = JaxRNG(rng)
103 | rngs = rng_generator(model_config.rng_keys())
104 |
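            # Gradient accumulation: the global batch is reshaped into
            # (accum_steps, micro_batch, ...) and scanned; per-micro-batch gradients are
            # summed in the carry and averaged before apply_gradients.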
105 | def computation(carry, micro_batch):
106 | sum_grads = carry["sum_grads"]
107 | micro_batch = with_sharding_constraint(micro_batch, PS(("dp", "fsdp")))
108 |
109 | def loss_and_accuracy(params):
110 | outputs = model.apply(
111 | params,
112 | micro_batch["input_tokens"],
113 | ttt_lr_mult=ttt_lr_mult,
114 | deterministic=False,
115 | output_ttt_stats=output_ttt_stats,
116 | rngs=rngs,
117 | )
118 | logits = outputs.logits
119 | ttt_stats = outputs.ttt_stats
120 | loss, _ = cross_entropy_loss_and_accuracy(
121 | logits, micro_batch["target_tokens"], micro_batch["loss_masks"]
122 | )
123 | return loss, ttt_stats
124 |
125 | grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
126 | (loss, ttt_stats), grads = grad_fn(train_state.params)
127 | sum_grads = tree_map(lambda x, y: x + y, sum_grads, grads)
128 | carry_new = {"sum_grads": sum_grads}
129 | return carry_new, (loss, ttt_stats)
130 |
131 | sum_grads = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), train_state.params)
132 | carry_init = {"sum_grads": sum_grads}
133 | batch = tree_map(lambda x: x.reshape(FLAGS.accum_steps, -1, *x.shape[1:]), batch)
134 | carry_new, outputs = jax.lax.scan(computation, carry_init, batch)
135 | loss, ttt_stats = outputs
136 | loss = jnp.mean(loss)
137 | if output_ttt_stats:
138 | ttt_stats = tree_map(lambda x: jnp.mean(x, axis=0), ttt_stats)
139 | else:
140 | ttt_stats = None
141 | grads = jax.tree_util.tree_map(lambda x: x / FLAGS.accum_steps, carry_new["sum_grads"])
142 |
143 | train_state = train_state.apply_gradients(grads=grads)
144 | learning_rate = optimizer_info["learning_rate_schedule"](train_state.step)
145 | grads_norm = global_norm(grads)
146 |
147 | return (train_state, loss, ttt_stats, grads_norm, learning_rate, rng_generator())
148 |
149 | else:
150 |         raise ValueError(f"accum_steps must be >= 1, got {accum_steps}")
151 |
152 | return train_step
153 |
154 |
155 | def make_eval_step_fn(model, model_config):
156 | def eval_step(train_state, rng, batch):
157 | rng_generator = JaxRNG(rng)
158 | batch = with_sharding_constraint(batch, PS(("dp", "fsdp")))
159 | logits = model.apply(
160 | train_state.params, batch["input_tokens"], deterministic=True, rngs=rng_generator(model_config.rng_keys())
161 | ).logits
162 | loss, accuracy = cross_entropy_loss_and_accuracy(logits, batch["target_tokens"], batch["loss_masks"])
163 | metrics = dict(eval_loss=loss, eval_accuracy=accuracy)
164 | return rng_generator(), metrics
165 |
166 | return eval_step
167 |
168 |
169 | def make_sharded_functions(model, optimizer, optimizer_info, model_config):
170 | def create_trainstate_from_params(params):
171 | return TrainState.create(params=params, tx=optimizer, apply_fn=None)
172 |
173 | def init_fn(rng):
174 | rng_generator = JaxRNG(rng)
175 | params = model.init(
176 | input_ids=jnp.zeros((4, FLAGS.seq_length), dtype=jnp.int32),
177 | position_ids=jnp.zeros((4, FLAGS.seq_length), dtype=jnp.int32),
178 | attention_mask=jnp.ones((4, FLAGS.seq_length), dtype=jnp.int32),
179 | rngs=rng_generator(model_config.rng_keys()),
180 | )
181 | return TrainState.create(params=params, tx=optimizer, apply_fn=None)
182 |
183 | train_step = make_train_step_fn(model, optimizer_info, model_config, FLAGS.accum_steps)
184 |
185 | train_state_shapes = jax.eval_shape(init_fn, next_rng())
186 |
187 | train_state_partition = match_partition_rules(model_config.get_partition_rules(), train_state_shapes)
188 |
189 | shard_fns, gather_fns = make_shard_and_gather_fns(train_state_partition, train_state_shapes)
190 |
191 | sharded_init_fn = pjit(init_fn, in_shardings=PS(), out_shardings=train_state_partition)
192 |
193 | sharded_create_trainstate_from_params = pjit(
194 | create_trainstate_from_params,
195 | in_shardings=(train_state_partition.params,),
196 | out_shardings=train_state_partition,
197 | donate_argnums=(0,),
198 | )
199 |
200 | sharded_train_step = pjit(
201 | train_step,
202 | in_shardings=(train_state_partition, PS(), PS(), PS()),
203 | out_shardings=(train_state_partition, PS(), PS(), PS(), PS(), PS()),
204 | static_argnums=4,
205 | donate_argnums=(0,),
206 | )
207 |
208 | return (
209 | sharded_init_fn,
210 | sharded_create_trainstate_from_params,
211 | sharded_train_step,
212 | shard_fns,
213 | gather_fns,
214 | train_state_shapes,
215 | train_state_partition,
216 | )
217 |
218 |
219 | def make_save_checkpoint(checkpointer, gather_fns, variant, flags_config_dict, model_config, global_batch_size):
220 | def save_checkpoint(train_state, train_loader, milestone=False):
221 | step = int(jax.device_get(train_state.step))
222 | metadata = dict(step=step, variant=variant, flags=flags_config_dict, model_config=model_config.to_dict())
223 | sampler_state_dict = {
224 | "random_state": train_loader.sampler.state_dict()["random_state"],
225 | "shuffle_log": train_loader.sampler.state_dict()["shuffle_log"],
226 | "counter": step * global_batch_size,
227 | }
228 | checkpointer.save_all(
229 | train_state=train_state,
230 | gather_fns=gather_fns,
231 | metadata=metadata,
232 | dataset=deepcopy(sampler_state_dict),
233 | milestone=milestone,
234 | )
235 |
236 | return save_checkpoint
237 |
238 |
239 | def make_get_ttt_lr_mult(model_config):
240 |
241 | if (
242 | hasattr(model_config, "ttt_base_lr_init")
243 | and model_config.ttt_base_lr_init > 0
244 | and model_config.ttt_base_lr_warmup > 0
245 | ):
246 | ttt_lr_mult_warmup_steps = model_config.ttt_base_lr_warmup
247 | ttt_lr_mult_init = model_config.ttt_base_lr_init
248 | ttt_lr_mult_peak = model_config.ttt_base_lr
249 |
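        # Linear warm-up of the TTT inner-loop base lr from ttt_base_lr_init to
        # ttt_base_lr over ttt_base_lr_warmup steps. Since get_eta already multiplies
        # by ttt_base_lr (the peak), the multiplier is normalized by the peak value.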
250 | def get_ttt_lr_mult(step):
251 | ttt_lr_mult = ttt_lr_mult_init + min(1.0, (step - 1) / ttt_lr_mult_warmup_steps) * (
252 | ttt_lr_mult_peak - ttt_lr_mult_init
253 | )
254 | ttt_lr_mult = ttt_lr_mult / ttt_lr_mult_peak * jnp.ones((1,), dtype=jnp.bfloat16)
255 | return ttt_lr_mult
256 |
257 | else:
258 |
259 | def get_ttt_lr_mult(step):
260 | ttt_lr_mult = jnp.ones((1,), dtype=jnp.bfloat16)
261 | return ttt_lr_mult
262 |
263 | return get_ttt_lr_mult
264 |
265 |
266 | def initialize_or_resume(
267 | checkpointer,
268 | train_loader,
269 | train_state_shapes,
270 | sharded_init_fn,
271 | shard_fns,
272 | sharded_create_trainstate_from_params,
273 | FLAGS,
274 | ):
275 | start_step = 1
276 | train_state, restored_params = None, None
277 | if FLAGS.resume_exp_name != "":
278 | assert FLAGS.load_part in ["trainstate", "trainstate_params"]
279 | ckpt_resume_dir = (
280 | FLAGS.load_part
281 | + "::"
282 | + osp.join(
283 | FLAGS.exp_dir,
284 | FLAGS.resume_exp_name,
285 | (
286 | f"step_{int(FLAGS.resume_step)}/streaming_train_state_{int(FLAGS.resume_step)}"
287 | if FLAGS.resume_step
288 | else "streaming_train_state"
289 | ),
290 | )
291 | )
292 | train_state, restored_params = checkpointer.load_trainstate_checkpoint(
293 | ckpt_resume_dir, train_state_shapes, shard_fns
294 | )
295 |
296 | if FLAGS.load_part == "trainstate":
297 | start_step = int(jax.device_get(train_state.step)) + 1
298 | master_print(f"Resuming training from checkpoint at step {start_step - 1}...")
299 | dataset_pkl_filename = (
300 | f"step_{int(FLAGS.resume_step)}/dataset_{int(FLAGS.resume_step)}.pkl"
301 | if FLAGS.resume_step
302 | else "dataset.pkl"
303 | )
304 | dataset_resume_dir = osp.join(FLAGS.exp_dir, FLAGS.resume_exp_name, dataset_pkl_filename)
305 | train_loader.sampler.load_state_dict(deepcopy(mlxu.load_pickle(dataset_resume_dir)))
306 |
307 | if FLAGS.is_rollback_reshuffle:
308 | train_loader.sampler.is_rollback = True
309 |
310 | if train_state is None and restored_params is None:
311 | train_state = sharded_init_fn(next_rng())
312 | elif train_state is None and restored_params is not None:
313 | train_state = sharded_create_trainstate_from_params(restored_params)
314 | del restored_params
315 |
316 | return start_step, train_state, train_loader
317 |
318 |
319 | def main(argv):
320 | JaxDistributedConfig.initialize(FLAGS.jax_distributed)
321 | variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
322 | flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
323 |
324 | set_random_seed(FLAGS.seed)
325 | process_num = jax.process_count()
326 | global_dev_num = jax.device_count()
327 | local_dev_num = jax.local_device_count()
328 | master_process = jax.process_index() == 0
329 |
330 | dev_info = f"Process # {process_num}\tLocal dev # {local_dev_num}\tTotal dev # {global_dev_num}"
331 | master_print(dev_info)
332 |
333 | seq_length = FLAGS.seq_length
334 | global_batch_size = FLAGS.global_batch_size
335 | is_rollback_reshuffle = FLAGS.is_rollback_reshuffle
336 |
337 | # Create dataloader
338 | data_module = LMDataModule(
339 | dataset_name=FLAGS.dataset_name,
340 | dataset_config_name=None,
341 | tokenizer_name=FLAGS.tokenizer_name,
342 | cache_dir=FLAGS.dataset_path,
343 | max_length=seq_length,
344 | add_eos=True,
345 | batch_size=global_batch_size,
346 | batch_size_eval=global_batch_size,
347 | loader_workers=FLAGS.loader_workers,
348 | shuffle=True,
349 | fault_tolerant=True,
350 | drop_last=True,
351 | )
352 | data_module.prepare_data()
353 | data_module.setup()
354 | train_loader = data_module.train_dataloader()
355 |
356 |     # Load and update model_config
357 | if FLAGS.load_model_config != "":
358 | model_config = ModelConfig.load_config(FLAGS.load_model_config)
359 | else:
360 |         raise RuntimeError("load_model_config must be specified")
361 | if FLAGS.update_model_config:
362 | update_dic = eval(FLAGS.update_model_config)
363 | for key, value in update_dic.items():
364 | if hasattr(model_config, key):
365 | setattr(model_config, key, value)
366 | else:
367 | raise KeyError(f"Update key {key} not in model_config")
368 | model_config.vocab_size = data_module.vocab_size
369 | model_config.max_sequence_length = seq_length
370 | flags_config_dict.model_config = model_config
371 |
372 | # Create WandB run and checkpointer
373 | if master_process:
374 | wandb.init(project="TTT-LM", config=flags_config_dict, name=FLAGS.exp_name)
375 | ckpt_dir = osp.join(FLAGS.exp_dir, FLAGS.exp_name)
376 | checkpointer = StreamingCheckpointer(FLAGS.checkpointer, ckpt_dir, enable=master_process)
377 |
378 | # Create model and optimizer
379 | model = CausalLM(model_config, dtype=get_float_dtype_by_name(FLAGS.dtype))
380 | optimizer, optimizer_info = OptimizerFactory.get_optimizer(
381 | FLAGS.optimizer, get_weight_decay_mask(model_config.get_weight_decay_exclusions())
382 | )
383 |
384 | # Helper function for dynamic TTT learning rate
385 | get_ttt_lr_mult = make_get_ttt_lr_mult(model_config)
386 |
387 | # Create sharded train functions
388 | (
389 | sharded_init_fn,
390 | sharded_create_trainstate_from_params,
391 | sharded_train_step,
392 | shard_fns,
393 | gather_fns,
394 | train_state_shapes,
395 | train_state_partition,
396 | ) = make_sharded_functions(model, optimizer, optimizer_info, model_config)
397 |
398 | save_checkpoint = make_save_checkpoint(
399 | checkpointer, gather_fns, variant, flags_config_dict, model_config, global_batch_size
400 | )
401 |
402 | mesh = model_config.get_jax_mesh(FLAGS.mesh_dim)
403 | with mesh:
404 | sharded_rng = next_rng()
405 |
406 | start_step, train_state, train_loader = initialize_or_resume(
407 | checkpointer,
408 | train_loader,
409 | train_state_shapes,
410 | sharded_init_fn,
411 | shard_fns,
412 | sharded_create_trainstate_from_params,
413 | FLAGS,
414 | )
415 |
416 | if FLAGS.eval_mode:
417 | eval_step = make_eval_step_fn(model, model_config)
418 | sharded_eval_step = pjit(
419 | eval_step,
420 | in_shardings=(train_state_partition, PS(), PS()),
421 | out_shardings=(PS(), PS()),
422 | donate_argnums=(1,),
423 | )
424 |
425 | val_loader = data_module.val_dataloader()
426 | eval_metric_list = []
427 |
428 | for eval_batch in tqdm(val_loader, disable=not master_process):
429 | for k in eval_batch.keys():
430 | eval_batch[k] = eval_batch[k].numpy()
431 | sharded_rng, eval_metrics = sharded_eval_step(train_state, sharded_rng, eval_batch)
432 | eval_metric_list.append(eval_metrics)
433 |
434 | val_loss_avg = average_metrics(process_allgather(eval_metric_list))["eval_loss"].item()
435 | master_print(f"Eval Loss: {val_loss_avg:.4f}")
436 | exit(0)
437 |
438 | train_loader_iterator = iter(train_loader)
439 |
440 | for step in tqdm(
441 | range(start_step, FLAGS.total_steps + 1),
442 | initial=start_step,
443 | total=FLAGS.total_steps,
444 | disable=not master_process,
445 | desc=f"Training {FLAGS.exp_name}",
446 | ):
447 | try:
448 | batch = next(train_loader_iterator)
449 | except StopIteration:
450 | train_loader.sampler.counter = 0
451 | train_loader_iterator = iter(train_loader)
452 | batch = next(train_loader_iterator)
453 |
454 | if is_rollback_reshuffle:
455 | sampler_state_dict = {
456 | "random_state": train_loader.sampler.state_dict()["random_state"],
457 | "shuffle_log": train_loader.sampler.state_dict()["shuffle_log"],
458 | "counter": (step - 1) * global_batch_size,
459 | }
460 | if master_process and FLAGS.resume_exp_name != "":
461 | master_print("Updating sampler state after rollback...")
462 | dataset_pkl_filename = (
463 | f"step_{int(FLAGS.resume_step)}/dataset_{int(FLAGS.resume_step)}.pkl"
464 | if FLAGS.resume_step
465 | else "dataset_state.pkl"
466 | )
467 | dataset_resume_dir = osp.join(FLAGS.exp_dir, FLAGS.resume_exp_name, dataset_pkl_filename)
468 | mlxu.save_pickle(deepcopy(sampler_state_dict), dataset_resume_dir)
469 | is_rollback_reshuffle = False
470 | master_print("Finished updating sampler state.")
471 |
472 | for k in batch.keys():
473 | batch[k] = batch[k].numpy()
474 |
475 | ttt_lr_mult = get_ttt_lr_mult(step)
476 | output_ttt_stats = (
477 | FLAGS.save_milestone_freq > 0
478 | and step % FLAGS.save_milestone_freq == 0
479 | and model_config.seq_modeling_block != "self_attention"
480 | )
481 |
482 | train_state, loss, ttt_stats, grads_norm, learning_rate, sharded_rng = sharded_train_step(
483 | train_state, sharded_rng, batch, ttt_lr_mult, output_ttt_stats
484 | )
485 |
486 | if master_process:
487 | wandb.log(
488 | {
489 | "Train Loss": loss.item(),
490 | "Gradient Norm": grads_norm.item(),
491 | "Learning Rate": learning_rate.item(),
492 | },
493 | step=step,
494 | )
495 |
496 | if output_ttt_stats:
497 | for layer in range(len(ttt_stats)):
498 | ttt_stats_layer = process_allgather(ttt_stats[layer])
499 | n_mini_batch = len(ttt_stats_layer[0])
500 | x_axis = [model_config.mini_batch_size * i for i in range(1, n_mini_batch + 1)]
501 | log_ttt_stats(layer, ttt_stats_layer, x_axis, step)
502 |
503 | if (FLAGS.save_checkpoint_freq > 0 and step % FLAGS.save_checkpoint_freq == 0) or (
504 | step == FLAGS.total_steps
505 | ):
506 | master_print(f"Saving checkpoint at step {step}, do not kill...")
507 |                 milestone = FLAGS.save_milestone_freq > 0 and step % FLAGS.save_milestone_freq == 0
508 |                 save_checkpoint(train_state, train_loader, milestone=milestone)
509 | if step == FLAGS.total_steps:
510 | master_print("Training has completed!")
511 |
512 |
513 | if __name__ == "__main__":
514 | mlxu.run(main)
515 |
--------------------------------------------------------------------------------