├── .gitignore
├── README.md
├── requirements
│   ├── gpu_requirements.txt
│   └── tpu_requirements.txt
├── scripts
│   ├── transformer
│   │   ├── 1.3b.sh
│   │   ├── 125m.sh
│   │   ├── 350m.sh
│   │   └── 760m.sh
│   ├── ttt_linear
│   │   ├── 1.3b.sh
│   │   ├── 125m.sh
│   │   ├── 350m.sh
│   │   └── 760m.sh
│   └── ttt_mlp
│       ├── 1.3b.sh
│       ├── 125m.sh
│       ├── 350m.sh
│       └── 760m.sh
└── ttt
    ├── README.md
    ├── __init__.py
    ├── dataloader
    │   ├── README.md
    │   ├── __init__.py
    │   ├── language_modeling_hf.py
    │   ├── lm_dataset.py
    │   └── tokenization.py
    ├── infra
    │   ├── __init__.py
    │   ├── checkpoint.py
    │   ├── jax_utils.py
    │   └── optimizers.py
    ├── models
    │   ├── __init__.py
    │   ├── bpt.py
    │   ├── model.py
    │   └── ttt_layer.py
    └── train.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | launcher/ 131 | 132 | # TTT Edits 133 | exp 134 | wandb 135 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning to (Learn at Test Time): RNNs with Expressive Hidden States 2 | [**Paper**](https://arxiv.org/abs/2407.04620) 3 | | [**PyTorch Codebase**](https://github.com/test-time-training/ttt-lm-pytorch) 4 | | [**Setup**](#setup) 5 | | [**Replicating Experiments**](#replicating-experiments) 6 | | [**Model Docs**](ttt/README.md) 7 | | [**Dataset Preparation**](ttt/dataloader/README.md) 8 | | [**Inference Benchmark**](https://github.com/test-time-training/ttt-lm-kernels) 9 | 10 | ## Abstract 11 | 12 | Self-attention performs well in long context but has quadratic complexity. Existing RNN layers 13 | have linear complexity, but their performance in long context is limited by the expressive power 14 | of their hidden state. We propose a new class of sequence modeling layers with linear complexity 15 | and an expressive hidden state. The key idea is to make the hidden state a machine learning 16 | model itself, and the update rule a step of self-supervised learning. 17 | 18 | Since the hidden state is updated by training even on test sequences, our layers are called **Test-Time Training (TTT) layers**. 19 | We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden state is a linear model 20 | and a two-layer MLP respectively. 21 | 22 | ## Setup 23 | This codebase is implemented in [JAX](https://jax.readthedocs.io/en/latest/index.html) and has been tested on both GPUs and Cloud TPU VMs with Python 3.11. 24 | 25 | For a PyTorch model definition, please refer to [this link](https://github.com/test-time-training/ttt-lm-pytorch). For inference kernels, or to replicate speed benchmarks from our paper, please view our [kernel implementations](https://github.com/test-time-training/ttt-lm-kernels). 26 | 27 | ### Environment Installation 28 | To set up and run our code on a (local) GPU machine, we highly recommend using [Anaconda](https://anaconda.com/download) when installing Python dependencies. Install GPU requirements using: 29 | ``` 30 | cd requirements 31 | pip install -r gpu_requirements.txt 32 | ``` 33 | 34 | For TPU, please refer to [this link](https://cloud.google.com/tpu/docs/quick-starts) for guidance on creating Cloud TPU VMs. Then, run: 35 | ``` 36 | cd requirements 37 | pip install -r tpu_requirements.txt 38 | ``` 39 | 40 | ### WandB Login 41 | We use WandB for logging training metrics and TTT statistics.
After installing the requirements, log in to WandB using: 42 | ``` 43 | wandb login 44 | ``` 45 | 46 | 47 | ### Dataset Download 48 | Our Llama-2 tokenized datasets are available for download from Google Cloud Buckets: 49 | 50 | ``` 51 | gsutil -m cp -r gs://llama-2-pile/* llama-2-pile/ 52 | gsutil -m cp -r gs://llama-2-books3/* llama-2-books3/ 53 | ``` 54 | 55 | Once downloaded, set the `dataset_path` flag in `train.py` to the directory containing the `tokenizer_name-meta-llama` folder. This will allow the dataloader to find the correct path. 56 | 57 | Alternatively, to tokenize datasets yourself, refer to [dataset preparation](ttt/dataloader/README.md). 58 | 59 | ## Replicating Experiments 60 | We provide scripts corresponding to each experiment in our paper in the `scripts` folder. After specifying the experiment name and directory, select the desired context length and divide 0.5 million by it to calculate the appropriate batch size. 61 | 62 | Depending on the model size, you may need to modify the `mesh_dim` to introduce model sharding. See the [model docs](ttt/README.md) for additional information on the training configuration. 63 | 64 | ## Credits 65 | * This codebase is based on [EasyLM](https://github.com/young-geng/EasyLM). 66 | * Our dataloader is based on [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main/training). 67 | -------------------------------------------------------------------------------- /requirements/gpu_requirements.txt: -------------------------------------------------------------------------------- 1 | -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 2 | numpy==1.26.4 3 | matplotlib==3.9.0 4 | tqdm 5 | --extra-index-url https://download.pytorch.org/whl/cpu 6 | torch==2.3.0 7 | jax==0.4.14 8 | jaxlib==0.4.14+cuda12.cudnn89 9 | flax==0.7.0 10 | optax==0.1.7 11 | transformers==4.41.0 12 | datasets 13 | mlxu>=0.1.13 14 | einops 15 | ml_collections 16 | scipy==1.12.0 -------------------------------------------------------------------------------- /requirements/tpu_requirements.txt: -------------------------------------------------------------------------------- 1 | -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 2 | jax[tpu]==0.4.14 3 | numpy==1.26.4 4 | matplotlib 5 | tqdm 6 | --extra-index-url https://download.pytorch.org/whl/cpu 7 | torch==2.3.0 8 | flax==0.7.0 9 | optax==0.1.7 10 | transformers==4.41.0 11 | datasets 12 | mlxu>=0.1.13 13 | einops 14 | ml_collections 15 | scipy==1.12.0 -------------------------------------------------------------------------------- /scripts/transformer/1.3b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=TODO 4 | DATA_NAME="the_pile" # "books3" 5 | 6 | # Product should equal 0.5 million 7 | SEQ_LEN=2048 8 | BS=256 9 | 10 | # Experiment details 11 | EXP_NAME=TODO 12 | EXP_DIR=TODO 13 | 14 | sudo mkdir -p /${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME}; 15 | cd ../..
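# Token budget check: seq_length * global_batch_size = 2048 * 256 = 524,288 tokens,
# i.e. ~0.5M tokens per batch as used in the paper. If SEQ_LEN changes, rescale BS
# so that the product stays at ~0.5M.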
16 | 17 | python3 -m ttt.train \ 18 | --mesh_dim='!-1,1,1' \ 19 | --dtype='fp32' \ 20 | --total_steps=50000 \ 21 | --save_checkpoint_freq=1000 \ 22 | --save_milestone_freq=2000 \ 23 | --load_model_config='1b' \ 24 | --dataset_path=${DATA_PATH} \ 25 | --dataset_name=${DATA_NAME} \ 26 | --seq_length=${SEQ_LEN} \ 27 | --global_batch_size=${BS} \ 28 | --optimizer.type='adamw' \ 29 | --optimizer.adamw_optimizer.weight_decay=0.1 \ 30 | --optimizer.adamw_optimizer.lr=1e-3 \ 31 | --optimizer.adamw_optimizer.end_lr=1e-5 \ 32 | --optimizer.adamw_optimizer.lr_warmup_steps=5000 \ 33 | --optimizer.adamw_optimizer.lr_decay_steps=50000 \ 34 | --exp_dir=${EXP_DIR} \ 35 | --exp_name=${EXP_NAME} -------------------------------------------------------------------------------- /scripts/transformer/125m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=TODO 4 | DATA_NAME="the_pile" # "books3" 5 | 6 | # Product should equal 0.5 million 7 | SEQ_LEN=2048 8 | BS=256 9 | 10 | # Experiment details 11 | EXP_NAME=TODO 12 | EXP_DIR=TODO 13 | 14 | sudo mkdir -p /${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME}; 15 | cd ../.. 16 | 17 | python3 -m ttt.train \ 18 | --mesh_dim='!-1,1,1' \ 19 | --dtype='fp32' \ 20 | --total_steps=4800 \ 21 | --save_checkpoint_freq=1000 \ 22 | --save_milestone_freq=2000 \ 23 | --load_model_config='125m' \ 24 | --dataset_path=${DATA_PATH} \ 25 | --dataset_name=${DATA_NAME} \ 26 | --seq_length=${SEQ_LEN} \ 27 | --global_batch_size=${BS} \ 28 | --optimizer.type='adamw' \ 29 | --optimizer.adamw_optimizer.weight_decay=0.1 \ 30 | --optimizer.adamw_optimizer.lr=3e-3 \ 31 | --optimizer.adamw_optimizer.end_lr=1e-5 \ 32 | --optimizer.adamw_optimizer.lr_warmup_steps=480 \ 33 | --optimizer.adamw_optimizer.lr_decay_steps=4800 \ 34 | --exp_dir=${EXP_DIR} \ 35 | --exp_name=${EXP_NAME} -------------------------------------------------------------------------------- /scripts/transformer/350m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=TODO 4 | DATA_NAME="the_pile" # "books3" 5 | 6 | # Product should equal 0.5 million 7 | SEQ_LEN=2048 8 | BS=256 9 | 10 | # Experiment details 11 | EXP_NAME=TODO 12 | EXP_DIR=TODO 13 | 14 | sudo mkdir -p /${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME}; 15 | cd ../.. 
16 | 17 | python3 -m ttt.train \ 18 | --mesh_dim='!-1,1,1' \ 19 | --dtype='fp32' \ 20 | --total_steps=13500 \ 21 | --save_checkpoint_freq=1000 \ 22 | --save_milestone_freq=2000 \ 23 | --load_model_config='350m' \ 24 | --dataset_path=${DATA_PATH} \ 25 | --dataset_name=${DATA_NAME} \ 26 | --seq_length=${SEQ_LEN} \ 27 | --global_batch_size=${BS} \ 28 | --optimizer.type='adamw' \ 29 | --optimizer.adamw_optimizer.weight_decay=0.1 \ 30 | --optimizer.adamw_optimizer.lr=1.5e-3 \ 31 | --optimizer.adamw_optimizer.end_lr=1e-5 \ 32 | --optimizer.adamw_optimizer.lr_warmup_steps=1350 \ 33 | --optimizer.adamw_optimizer.lr_decay_steps=13500 \ 34 | --exp_dir=${EXP_DIR} \ 35 | --exp_name=${EXP_NAME} -------------------------------------------------------------------------------- /scripts/transformer/760m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=TODO 4 | DATA_NAME="the_pile" # "books3" 5 | 6 | # Product should equal 0.5 million 7 | SEQ_LEN=2048 8 | BS=256 9 | 10 | # Experiment details 11 | EXP_NAME=TODO 12 | EXP_DIR=TODO 13 | 14 | sudo mkdir -p /${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME}; 15 | cd ../.. 16 | 17 | python3 -m ttt.train \ 18 | --mesh_dim='!-1,1,1' \ 19 | --dtype='fp32' \ 20 | --total_steps=29000 \ 21 | --save_checkpoint_freq=1000 \ 22 | --save_milestone_freq=2000 \ 23 | --load_model_config='760m' \ 24 | --dataset_path=${DATA_PATH} \ 25 | --dataset_name=${DATA_NAME} \ 26 | --seq_length=${SEQ_LEN} \ 27 | --global_batch_size=${BS} \ 28 | --optimizer.type='adamw' \ 29 | --optimizer.adamw_optimizer.weight_decay=0.1 \ 30 | --optimizer.adamw_optimizer.lr=1.25e-3 \ 31 | --optimizer.adamw_optimizer.end_lr=1e-5 \ 32 | --optimizer.adamw_optimizer.lr_warmup_steps=2900 \ 33 | --optimizer.adamw_optimizer.lr_decay_steps=29000 \ 34 | --exp_dir=${EXP_DIR} \ 35 | --exp_name=${EXP_NAME} -------------------------------------------------------------------------------- /scripts/ttt_linear/1.3b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=TODO 4 | DATA_NAME="the_pile" # "books3" 5 | 6 | # Product should equal 0.5 million 7 | SEQ_LEN=2048 8 | BS=256 9 | 10 | # Experiment details 11 | EXP_NAME=TODO 12 | EXP_DIR=TODO 13 | 14 | sudo mkdir -p /${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME}; 15 | cd ../.. 
16 | 17 | python3 -m ttt.train \ 18 | --mesh_dim='!-1,1,1' \ 19 | --dtype='fp32' \ 20 | --total_steps=50000 \ 21 | --save_checkpoint_freq=1000 \ 22 | --save_milestone_freq=2000 \ 23 | --load_model_config='1b-TTT' \ 24 | --update_model_config="dict(seq_modeling_block='ttt_linear', ttt_base_lr=1.0)" \ 25 | --dataset_path=${DATA_PATH} \ 26 | --dataset_name=${DATA_NAME} \ 27 | --seq_length=${SEQ_LEN} \ 28 | --global_batch_size=${BS} \ 29 | --optimizer.type='adamw' \ 30 | --optimizer.adamw_optimizer.weight_decay=0.1 \ 31 | --optimizer.adamw_optimizer.lr=1e-3 \ 32 | --optimizer.adamw_optimizer.end_lr=1e-5 \ 33 | --optimizer.adamw_optimizer.lr_warmup_steps=5000 \ 34 | --optimizer.adamw_optimizer.lr_decay_steps=50000 \ 35 | --exp_dir=${EXP_DIR} \ 36 | --exp_name=${EXP_NAME} -------------------------------------------------------------------------------- /scripts/ttt_linear/125m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=TODO 4 | DATA_NAME="the_pile" # "books3" 5 | 6 | # Product should equal 0.5 million 7 | SEQ_LEN=2048 8 | BS=256 9 | 10 | # Experiment details 11 | EXP_NAME=TODO 12 | EXP_DIR=TODO 13 | 14 | sudo mkdir -p /${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME}; 15 | cd ../.. 16 | 17 | python3 -m ttt.train \ 18 | --mesh_dim='!-1,1,1' \ 19 | --dtype='fp32' \ 20 | --total_steps=4800 \ 21 | --save_checkpoint_freq=1000 \ 22 | --save_milestone_freq=2000 \ 23 | --load_model_config='125m-TTT' \ 24 | --update_model_config="dict(seq_modeling_block='ttt_linear', ttt_base_lr=1.0)" \ 25 | --dataset_path=${DATA_PATH} \ 26 | --dataset_name=${DATA_NAME} \ 27 | --seq_length=${SEQ_LEN} \ 28 | --global_batch_size=${BS} \ 29 | --optimizer.type='adamw' \ 30 | --optimizer.adamw_optimizer.weight_decay=0.1 \ 31 | --optimizer.adamw_optimizer.lr=3e-3 \ 32 | --optimizer.adamw_optimizer.end_lr=1e-5 \ 33 | --optimizer.adamw_optimizer.lr_warmup_steps=480 \ 34 | --optimizer.adamw_optimizer.lr_decay_steps=4800 \ 35 | --exp_dir=${EXP_DIR} \ 36 | --exp_name=${EXP_NAME} -------------------------------------------------------------------------------- /scripts/ttt_linear/350m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=TODO 4 | DATA_NAME="the_pile" # "books3" 5 | 6 | # Product should equal 0.5 million 7 | SEQ_LEN=2048 8 | BS=256 9 | 10 | # Experiment details 11 | EXP_NAME=TODO 12 | EXP_DIR=TODO 13 | 14 | sudo mkdir -p /${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME}; 15 | cd ../.. 
16 | 17 | python3 -m ttt.train \ 18 | --mesh_dim='!-1,1,1' \ 19 | --dtype='fp32' \ 20 | --total_steps=13500 \ 21 | --save_checkpoint_freq=1000 \ 22 | --save_milestone_freq=2000 \ 23 | --load_model_config='350m-TTT' \ 24 | --update_model_config="dict(seq_modeling_block='ttt_linear', ttt_base_lr=1.0)" \ 25 | --dataset_path=${DATA_PATH} \ 26 | --dataset_name=${DATA_NAME} \ 27 | --seq_length=${SEQ_LEN} \ 28 | --global_batch_size=${BS} \ 29 | --optimizer.type='adamw' \ 30 | --optimizer.adamw_optimizer.weight_decay=0.1 \ 31 | --optimizer.adamw_optimizer.lr=1.5e-3 \ 32 | --optimizer.adamw_optimizer.end_lr=1e-5 \ 33 | --optimizer.adamw_optimizer.lr_warmup_steps=1350 \ 34 | --optimizer.adamw_optimizer.lr_decay_steps=13500 \ 35 | --exp_dir=${EXP_DIR} \ 36 | --exp_name=${EXP_NAME} -------------------------------------------------------------------------------- /scripts/ttt_linear/760m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=TODO 4 | DATA_NAME="the_pile" # "books3" 5 | 6 | # Product should equal 0.5 million 7 | SEQ_LEN=2048 8 | BS=256 9 | 10 | # Experiment details 11 | EXP_NAME=TODO 12 | EXP_DIR=TODO 13 | 14 | sudo mkdir -p /${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME}; 15 | cd ../.. 16 | 17 | python3 -m ttt.train \ 18 | --mesh_dim='!-1,1,1' \ 19 | --dtype='fp32' \ 20 | --total_steps=29000 \ 21 | --save_checkpoint_freq=1000 \ 22 | --save_milestone_freq=2000 \ 23 | --load_model_config='760m-TTT' \ 24 | --update_model_config="dict(seq_modeling_block='ttt_linear', ttt_base_lr=1.0)" \ 25 | --dataset_path=${DATA_PATH} \ 26 | --dataset_name=${DATA_NAME} \ 27 | --seq_length=${SEQ_LEN} \ 28 | --global_batch_size=${BS} \ 29 | --optimizer.type='adamw' \ 30 | --optimizer.adamw_optimizer.weight_decay=0.1 \ 31 | --optimizer.adamw_optimizer.lr=1.25e-3 \ 32 | --optimizer.adamw_optimizer.end_lr=1e-5 \ 33 | --optimizer.adamw_optimizer.lr_warmup_steps=2900 \ 34 | --optimizer.adamw_optimizer.lr_decay_steps=29000 \ 35 | --exp_dir=${EXP_DIR} \ 36 | --exp_name=${EXP_NAME} -------------------------------------------------------------------------------- /scripts/ttt_mlp/1.3b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=TODO 4 | DATA_NAME="the_pile" # "books3" 5 | 6 | # Product should equal 0.5 million 7 | SEQ_LEN=2048 8 | BS=256 9 | 10 | # Experiment details 11 | EXP_NAME=TODO 12 | EXP_DIR=TODO 13 | 14 | sudo mkdir -p /${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME}; 15 | cd ../.. 
16 | 17 | python3 -m ttt.train \ 18 | --mesh_dim='!-1,1,1' \ 19 | --dtype='fp32' \ 20 | --total_steps=50000 \ 21 | --save_checkpoint_freq=1000 \ 22 | --save_milestone_freq=2000 \ 23 | --load_model_config='1b-TTT' \ 24 | --update_model_config="dict(seq_modeling_block='ttt_mlp', ttt_base_lr=0.1, ttt_base_lr_init=0.01, ttt_base_lr_warmup=5000)" \ 25 | --dataset_path=${DATA_PATH} \ 26 | --dataset_name=${DATA_NAME} \ 27 | --seq_length=${SEQ_LEN} \ 28 | --global_batch_size=${BS} \ 29 | --optimizer.type='adamw' \ 30 | --optimizer.adamw_optimizer.weight_decay=0.1 \ 31 | --optimizer.adamw_optimizer.lr=1e-3 \ 32 | --optimizer.adamw_optimizer.end_lr=1e-5 \ 33 | --optimizer.adamw_optimizer.lr_warmup_steps=5000 \ 34 | --optimizer.adamw_optimizer.lr_decay_steps=50000 \ 35 | --exp_dir=${EXP_DIR} \ 36 | --exp_name=${EXP_NAME} -------------------------------------------------------------------------------- /scripts/ttt_mlp/125m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=TODO 4 | DATA_NAME="the_pile" # "books3" 5 | 6 | # Product should equal 0.5 million 7 | SEQ_LEN=2048 8 | BS=256 9 | 10 | # Experiment details 11 | EXP_NAME=TODO 12 | EXP_DIR=TODO 13 | 14 | sudo mkdir -p /${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME}; 15 | cd ../.. 16 | 17 | python3 -m ttt.train \ 18 | --mesh_dim='!-1,1,1' \ 19 | --dtype='fp32' \ 20 | --total_steps=4800 \ 21 | --save_checkpoint_freq=1000 \ 22 | --save_milestone_freq=2000 \ 23 | --load_model_config='125m-TTT' \ 24 | --update_model_config="dict(seq_modeling_block='ttt_mlp', ttt_base_lr=0.1, ttt_base_lr_init=0.01, ttt_base_lr_warmup=480)" \ 25 | --dataset_path=${DATA_PATH} \ 26 | --dataset_name=${DATA_NAME} \ 27 | --seq_length=${SEQ_LEN} \ 28 | --global_batch_size=${BS} \ 29 | --optimizer.type='adamw' \ 30 | --optimizer.adamw_optimizer.weight_decay=0.1 \ 31 | --optimizer.adamw_optimizer.lr=3e-3 \ 32 | --optimizer.adamw_optimizer.end_lr=1e-5 \ 33 | --optimizer.adamw_optimizer.lr_warmup_steps=480 \ 34 | --optimizer.adamw_optimizer.lr_decay_steps=4800 \ 35 | --exp_dir=${EXP_DIR} \ 36 | --exp_name=${EXP_NAME} -------------------------------------------------------------------------------- /scripts/ttt_mlp/350m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=TODO 4 | DATA_NAME="the_pile" # "books3" 5 | 6 | # Product should equal 0.5 million 7 | SEQ_LEN=2048 8 | BS=256 9 | 10 | # Experiment details 11 | EXP_NAME=TODO 12 | EXP_DIR=TODO 13 | 14 | sudo mkdir -p /${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME}; 15 | cd ../.. 
16 | 17 | python3 -m ttt.train \ 18 | --mesh_dim='!-1,1,1' \ 19 | --dtype='fp32' \ 20 | --total_steps=13500 \ 21 | --save_checkpoint_freq=1000 \ 22 | --save_milestone_freq=2000 \ 23 | --load_model_config='350m-TTT' \ 24 | --update_model_config="dict(seq_modeling_block='ttt_mlp', ttt_base_lr=0.1, ttt_base_lr_init=0.01, ttt_base_lr_warmup=1350)" \ 25 | --dataset_path=${DATA_PATH} \ 26 | --dataset_name=${DATA_NAME} \ 27 | --seq_length=${SEQ_LEN} \ 28 | --global_batch_size=${BS} \ 29 | --optimizer.type='adamw' \ 30 | --optimizer.adamw_optimizer.weight_decay=0.1 \ 31 | --optimizer.adamw_optimizer.lr=1.5e-3 \ 32 | --optimizer.adamw_optimizer.end_lr=1e-5 \ 33 | --optimizer.adamw_optimizer.lr_warmup_steps=1350 \ 34 | --optimizer.adamw_optimizer.lr_decay_steps=13500 \ 35 | --exp_dir=${EXP_DIR} \ 36 | --exp_name=${EXP_NAME} -------------------------------------------------------------------------------- /scripts/ttt_mlp/760m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=TODO 4 | DATA_NAME="the_pile" # "books3" 5 | 6 | # Product should equal 0.5 million 7 | SEQ_LEN=2048 8 | BS=256 9 | 10 | # Experiment details 11 | EXP_NAME=TODO 12 | EXP_DIR=TODO 13 | 14 | sudo mkdir -p /${EXP_DIR}/${EXP_NAME} && sudo chmod -R 777 ${EXP_DIR}/${EXP_NAME}; 15 | cd ../.. 16 | 17 | python3 -m ttt.train \ 18 | --mesh_dim='!-1,1,1' \ 19 | --dtype='fp32' \ 20 | --total_steps=29000 \ 21 | --save_checkpoint_freq=1000 \ 22 | --save_milestone_freq=2000 \ 23 | --load_model_config='760m-TTT' \ 24 | --update_model_config="dict(seq_modeling_block='ttt_mlp', ttt_base_lr=0.1, ttt_base_lr_init=0.01, ttt_base_lr_warmup=2900)" \ 25 | --dataset_path=${DATA_PATH} \ 26 | --dataset_name=${DATA_NAME} \ 27 | --seq_length=${SEQ_LEN} \ 28 | --global_batch_size=${BS} \ 29 | --optimizer.type='adamw' \ 30 | --optimizer.adamw_optimizer.weight_decay=0.1 \ 31 | --optimizer.adamw_optimizer.lr=1.25e-3 \ 32 | --optimizer.adamw_optimizer.end_lr=1e-5 \ 33 | --optimizer.adamw_optimizer.lr_warmup_steps=2900 \ 34 | --optimizer.adamw_optimizer.lr_decay_steps=29000 \ 35 | --exp_dir=${EXP_DIR} \ 36 | --exp_name=${EXP_NAME} -------------------------------------------------------------------------------- /ttt/README.md: -------------------------------------------------------------------------------- 1 | # Model Documentation 2 | 3 | This codebase is implemented in [JAX](https://jax.readthedocs.io/en/latest/index.html) and is based on [EasyLM](https://github.com/young-geng/EasyLM/tree/main). 4 | 5 | ## Training Flags 6 | - `mesh_dim` refers to the mesh used by JAX to parallelize computation across multiple accelerators and hosts. Please refer to the [EasyLM parallelization documentation](https://github.com/young-geng/EasyLM/blob/main/docs/parallelism.md) for configuration. 7 | - `seq_length` and `global_batch_size` determine the total number of tokens per batch (fixed to 0.5 million in our paper). 8 | - `load_model_config` is used to load a default config from `model.py`. 9 | - `update_model_config` is used to update a default config. To update specific keys, pass a dictionary to the flag: 10 | 11 | ``` 12 | --update_model_config="dict(seq_modeling_block='ttt_linear', ttt_base_lr=1.0)" 13 | ``` 14 | 15 | All additional hyperparameters are specified in Appendix C of our paper. 16 | 17 | ## Model Flags 18 | All model configuration flags can be found in `model.py`.
Here are a few important details to note: 19 | 20 | We implement four TTT choices for the `seq_modeling_block`: 21 | - `ttt_linear` and `ttt_mlp`, which specify TTT layers within the **Mamba backbone**. 22 | - `ttt_linear_base` and `ttt_mlp_base`, which specify TTT layers within the **Transformer backbone**. 23 | 24 | ### TTT LR 25 | - For all `ttt_linear` experiments, `ttt_base_lr` is set to 1.0. 26 | - For all `ttt_mlp` experiments: 27 | - `ttt_base_lr` is set to 0.1 28 | - `ttt_base_lr_init` is set to 0.01 29 | - `ttt_base_lr_warmup` is set to the total number of outer loop warmup steps. 30 | -------------------------------------------------------------------------------- /ttt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/test-time-training/ttt-lm-jax/ac8cdc8a43b811afe23c27d7e82eed34a747b19c/ttt/__init__.py -------------------------------------------------------------------------------- /ttt/dataloader/README.md: -------------------------------------------------------------------------------- 1 |

# Dataset Preparation 2 | 3 | ## Option 1: Download Pre-Tokenized Datasets (Recommended)
4 | 5 | Our Llama-2 tokenized datasets are available for download from Google Cloud Buckets: 6 | 7 | ``` 8 | gsutil -m cp -r gs://llama-2-pile/* llama-2-pile/ 9 | gsutil -m cp -r gs://llama-2-books3/* llama-2-books3/ 10 | ``` 11 | 12 | Once downloaded, set the `dataset_path` flag in `train.py` to the directory containing the `tokenizer_name-meta-llama` folder. This will allow the dataloader to find the correct path. 13 | 14 |
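For a concrete picture of how this path is consumed, here is a hypothetical example (the `/data/llama-2-pile` location is only a placeholder): the training scripts in `scripts/` pass the directory through their `DATA_PATH` and `DATA_NAME` variables.

```
# Example only: assumes the bucket above was copied to /data/llama-2-pile and that
# this directory contains the tokenizer_name-meta-llama folder.
DATA_PATH=/data/llama-2-pile
DATA_NAME="the_pile"   # or "books3"
```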

## Option 2: Tokenize Datasets Yourself
15 | 16 | Since the raw Pile and Books3 datasets are no longer publically available on Huggingface, we recommend acquiring them via correspondence to their authors or from the community. 17 | 18 | Before tokenization, set `raw_json_path` and `cache_dir` in `tokenization.py` to the path where the raw dataset (in json format) is stored and where you want to store the tokenized dataset, respectively. 19 | 20 | Our tokenization script is based on [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main/training#dataset-preparation). Tokenize the raw datasets using the commands below. 21 | 22 | **Pile:** 23 | ``` 24 | export PYTHONPATH=$PWD:$PYTHONPATH 25 | pytest -q -s ttt/dataloader/tokenization.py -k "pile" 26 | ``` 27 | This takes around 20h on a 64-core CPU. The processed dataset is 716G. 28 | 29 | **Books3:** 30 | ``` 31 | export PYTHONPATH=$PWD:$PYTHONPATH 32 | pytest -q -s ttt/dataloader/tokenization.py -k "books" 33 | ``` 34 | This takes around 3h on a 64-core CPU. The processed dataset is 61G. 35 | -------------------------------------------------------------------------------- /ttt/dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/test-time-training/ttt-lm-jax/ac8cdc8a43b811afe23c27d7e82eed34a747b19c/ttt/dataloader/__init__.py -------------------------------------------------------------------------------- /ttt/dataloader/language_modeling_hf.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py 2 | import sys 3 | import os.path as osp 4 | from itertools import chain 5 | from pathlib import Path 6 | import pickle 7 | from typing import Any, List, Union 8 | import subprocess 9 | import mmap 10 | 11 | import numpy as np 12 | from torch.utils.data.dataloader import DataLoader, Dataset 13 | from transformers import AutoTokenizer 14 | from datasets import load_dataset 15 | 16 | from ttt.dataloader.lm_dataset import RandomFaultTolerantSampler, LMDataset 17 | from ttt.infra.jax_utils import master_print 18 | 19 | 20 | class LMDataModule: 21 | def __init__( 22 | self, 23 | dataset_name, 24 | tokenizer_name, 25 | dataset_config_name=None, 26 | max_length=1024, 27 | cache_dir=None, 28 | raw_json_path=None, 29 | val_ratio=0.0005, 30 | val_split_seed=2357, 31 | add_eos=True, 32 | batch_size=32, 33 | batch_size_eval=None, 34 | num_workers=1, 35 | loader_workers=1, 36 | shuffle=False, 37 | pin_memory=False, 38 | drop_last=False, 39 | fault_tolerant=False, 40 | ): 41 | super().__init__() 42 | self.dataset_name = dataset_name 43 | self.dataset_config_name = dataset_config_name 44 | self.tokenizer_name = tokenizer_name 45 | self.cache_dir = None if cache_dir is None else Path(cache_dir).expanduser() 46 | self.raw_json_path = raw_json_path 47 | self.max_length = max_length 48 | self.val_ratio = val_ratio 49 | self.val_split_seed = val_split_seed 50 | self.add_eos = add_eos 51 | self.batch_size = batch_size 52 | self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size 53 | self.num_workers = num_workers 54 | self.loader_workers = loader_workers 55 | self.shuffle = shuffle 56 | self.pin_memory = pin_memory 57 | self.drop_last = drop_last 58 | if fault_tolerant: 59 | assert self.shuffle 60 | self.fault_tolerant = fault_tolerant 61 | 62 | def prepare_data(self): 63 | if self.cache_dir is None: 64 | # Just download the 
dataset 65 | load_dataset(self.dataset_name, self.dataset_config_name) 66 | else: 67 | # Process the dataset and save it 68 | self.process_dataset() 69 | 70 | def setup(self, stage=None): 71 | if stage == "test" and hasattr(self, "dataset_test"): 72 | return 73 | concat_ids, self.tokenizer = self.process_dataset() 74 | self.vocab_size = len(self.tokenizer) 75 | self.dataset_train, self.dataset_val, self.dataset_test = [ 76 | LMDataset( 77 | concat_ids[split], seq_len=self.max_length, llama2=(self.tokenizer_name == "meta-llama/Llama-2-7b-hf") 78 | ) 79 | for split in ["train", "validation", "test"] 80 | ] 81 | 82 | def process_dataset(self): 83 | cache_dir = None if self.cache_dir is None else self.cache_dir / self._cache_dir_name 84 | 85 | if cache_dir is not None: 86 | if cache_dir.is_dir(): 87 | return self._load_from_cache(cache_dir) 88 | 89 | if self.raw_json_path is not None: 90 | raw_datasets = load_dataset("json", data_files=self.raw_json_path) 91 | else: 92 | raw_datasets = load_dataset(self.dataset_name, self.dataset_config_name) 93 | 94 | # https://github.com/stanford-crfm/mistral/blob/main/src/corpora/auto.py 95 | if "validation" not in raw_datasets: 96 | assert "train" in raw_datasets, "You must have train in raw_datasets to make a validation raw_datasets" 97 | raw_datasets = raw_datasets["train"].train_test_split( 98 | test_size=self.val_ratio, 99 | seed=self.val_split_seed, 100 | shuffle=True, # Otherwise test will be at the end of the dataset 101 | ) 102 | raw_datasets["validation"] = raw_datasets["test"] 103 | 104 | tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=True) 105 | # Preprocessing the datasets. 106 | # First we tokenize all the texts. 107 | column_names = raw_datasets["train"].column_names 108 | text_column_name = "text" if "text" in column_names else column_names[0] 109 | if self.add_eos: 110 | add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq 111 | add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs] 112 | tokenize = lambda example: tokenizer(add_eos_batched(example[text_column_name])) 113 | else: 114 | tokenize = lambda example: tokenizer(example[text_column_name]) 115 | 116 | dtype = np.uint16 if tokenizer.vocab_size < 64 * 1024 else np.int32 117 | 118 | def tokenize_concat(examples): 119 | # We just need 'input_ids', not 'attention_mask' (since it's all 1) 120 | input_ids = np.fromiter(chain(*tokenize(examples)["input_ids"]), dtype=dtype) 121 | # Need to return a list since we're doing batched processing 122 | return {"input_ids": [input_ids], "len": [len(input_ids)]} 123 | 124 | tokenized_datasets = raw_datasets.map( 125 | tokenize_concat, 126 | batched=True, 127 | num_proc=max(self.num_workers, 1), 128 | remove_columns=column_names, 129 | desc="Running tokenizer on dataset", 130 | ) 131 | 132 | # Use disk 133 | concat_ids = {} 134 | assert cache_dir is not None 135 | cache_dir.mkdir(parents=True, exist_ok=True) 136 | 137 | def write_ids_to_disk(example, filename): 138 | with open(filename, "r+b") as f: 139 | mm = mmap.mmap(f.fileno(), 0) 140 | start_idx = example["len_offset"] - len(example["input_ids"]) 141 | array_len = len(example["input_ids"]) 142 | arr = np.ndarray((array_len,), dtype=dtype, buffer=mm, offset=np.dtype(dtype).itemsize * start_idx) 143 | arr[:] = example["input_ids"] 144 | mm.flush() 145 | 146 | for name, ds in tokenized_datasets.items(): 147 | tokenized_datasets[name] = ds.add_column("len_offset", np.cumsum(ds["len"])) 148 | array_len = tokenized_datasets[name][-1]["len_offset"] 149 | 150 
| filename = cache_dir / f"{name}.bin" 151 | 152 | # Need to create the file with this specific size first 153 | # https://ostechnix.com/create-files-certain-size-linux/ 154 | subprocess.run(["truncate", "-s", str(array_len * np.dtype(dtype).itemsize), str(filename)], check=True) 155 | 156 | tokenized_datasets[name].map( 157 | write_ids_to_disk, 158 | fn_kwargs={"filename": filename}, # .bin 159 | batched=False, 160 | num_proc=max(self.num_workers, 1), 161 | desc="Concatenating examples", 162 | ) 163 | concat_ids[name] = np.memmap(filename, dtype=dtype, mode="r", shape=(array_len,)) 164 | 165 | if cache_dir is not None: 166 | self._save_to_cache(concat_ids, tokenizer, cache_dir) 167 | 168 | for name in concat_ids: 169 | Path(cache_dir / f"{name}.bin").unlink() 170 | 171 | return concat_ids, tokenizer 172 | 173 | def _save_to_cache(self, concat_ids, tokenizer, cache_dir): 174 | cache_dir.mkdir(parents=True, exist_ok=True) 175 | master_print(f"Saving to cache at {str(cache_dir)}") 176 | for k, v in concat_ids.items(): 177 | np.save(cache_dir / f"{k}.npy", v) 178 | with open(cache_dir / "tokenizer.pkl", "wb") as f: 179 | pickle.dump(tokenizer, f) 180 | 181 | def _load_from_cache(self, cache_dir): 182 | assert cache_dir.is_dir() 183 | master_print(f"Load from cache at {str(cache_dir)}") 184 | concat_ids = { 185 | split: np.load(cache_dir / f"{split}.npy", mmap_mode="r") for split in ["train", "validation", "test"] 186 | } 187 | with open(cache_dir / "tokenizer.pkl", "rb") as f: 188 | tokenizer = pickle.load(f) 189 | return concat_ids, tokenizer 190 | 191 | @property 192 | def _cache_dir_name(self): 193 | return ( 194 | f"tokenizer_name-{self.tokenizer_name}-val_ratio-{self.val_ratio}-" 195 | f"val_split_seed-{self.val_split_seed}-add_eos-{self.add_eos}-detokenize-False" 196 | ) 197 | 198 | def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader: 199 | """The train dataloader""" 200 | if self.shuffle and self.fault_tolerant: 201 | shuffle = False 202 | sampler = RandomFaultTolerantSampler(self.dataset_train) 203 | else: 204 | shuffle = self.shuffle 205 | sampler = None 206 | 207 | return self._data_loader(self.dataset_train, batch_size=self.batch_size, shuffle=shuffle, sampler=sampler) 208 | 209 | def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: 210 | """The val dataloader""" 211 | return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval) 212 | 213 | def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: 214 | """The test dataloader""" 215 | return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval) 216 | 217 | def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False, sampler=None) -> DataLoader: 218 | return DataLoader( 219 | dataset, 220 | batch_size=batch_size, 221 | num_workers=self.loader_workers, 222 | shuffle=shuffle, 223 | sampler=sampler, 224 | drop_last=self.drop_last, 225 | pin_memory=self.pin_memory, 226 | ) 227 | -------------------------------------------------------------------------------- /ttt/dataloader/lm_dataset.py: -------------------------------------------------------------------------------- 1 | # Inspired by https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt/datasets.py 2 | # Except we don't pad the last block and don't use overlapping eval 3 | # And we return both the input and the target 4 | import math 5 | from typing import Iterator 6 | 7 | import numpy as np 8 | import torch 9 | from torch.utils.data 
import RandomSampler 10 | 11 | 12 | class RandomFaultTolerantSampler(RandomSampler): 13 | def __init__(self, *args, generator=None, **kwargs): 14 | if generator is None: 15 | seed = int(torch.empty((), dtype=torch.int64).random_().item()) 16 | generator = torch.Generator().manual_seed(seed) 17 | super().__init__(*args, generator=generator, **kwargs) 18 | self.counter = 0 # Absolute position of data reading 19 | self.is_rollback = False 20 | self.state = self.generator.get_state() # Record the initial state of generator determined by seed 21 | # Should not be changed before an entire loop over dataset is done 22 | # Give same seed, generator state change deterministically after each torch.randperm 23 | self.shuffle_log = [{"shuffle_after": self.counter}] 24 | 25 | def state_dict(self): 26 | return {"random_state": self.state, "counter": self.counter, "shuffle_log": self.shuffle_log} 27 | 28 | def load_state_dict(self, state_dict): 29 | self.state = state_dict["random_state"] 30 | self.counter = state_dict["counter"] 31 | if "shuffle_log" in state_dict: 32 | self.shuffle_log = state_dict["shuffle_log"] # A list of shuffle records, each record is a dict 33 | 34 | def update_shuffle_history(self): 35 | self.shuffle_log.append({"shuffle_after": self.counter}) 36 | 37 | def go_through_shuffle_history(self): 38 | N = len(self.data_source) 39 | initial_shuffle_after = self.shuffle_log[0]["shuffle_after"] 40 | self.generator.set_state(self.state) 41 | indices = torch.randperm(N, generator=self.generator) 42 | 43 | for shuffle_record in self.shuffle_log[1:]: 44 | shuffle_after = shuffle_record["shuffle_after"] 45 | new_order = torch.randperm(N - shuffle_after, generator=self.generator) # 46 | indices = torch.concatenate([indices[:shuffle_after], indices[shuffle_after:][new_order]]) 47 | 48 | return indices 49 | 50 | def __iter__(self) -> Iterator[int]: 51 | 52 | if self.is_rollback: 53 | # Before entering __iter__() due to rollback, set loader.sampler.is_rollback = True manually outside 54 | self.update_shuffle_history() # Add a shuffle action at self.counter, which is where we resume from but need a different coming data order 55 | self.is_rollback = False 56 | 57 | indices = self.go_through_shuffle_history() 58 | indices = indices[self.counter :].tolist() 59 | 60 | for index in indices: 61 | self.counter += 1 62 | yield index 63 | 64 | # End of one loop over the entire dataset 65 | self.counter = 0 66 | self.state = self.generator.get_state() # If have the next epoch, state will definitely be different 67 | self.shuffle_log = [{"shuffle_after": self.counter}] 68 | 69 | 70 | class LMDataset(torch.utils.data.Dataset): 71 | def __init__(self, tokens, seq_len, drop_last=True, llama2=False): 72 | """tokens should be a numpy array""" 73 | self.seq_len = seq_len 74 | ntokens = len(tokens) 75 | if drop_last: 76 | ntokens = ((ntokens - 1) // seq_len) * seq_len + 1 77 | self.ntokens = ntokens 78 | # We're careful not to slice tokens, since it could be a memmap'ed array or H5 dataset, 79 | # and slicing would load it to memory. 
80 | self.tokens = tokens 81 | self.total_sequences = math.ceil((self.ntokens - 1) / self.seq_len) 82 | self.llama2 = llama2 83 | 84 | def __len__(self): 85 | return self.total_sequences 86 | 87 | def __getitem__(self, idx): 88 | idx = idx % self.ntokens 89 | start_idx = idx * self.seq_len 90 | seq_len = min(self.seq_len, self.ntokens - 1 - start_idx) 91 | data = torch.as_tensor(self.tokens[start_idx : (start_idx + seq_len + 1)].astype(np.int32)) 92 | if self.llama2: 93 | return { 94 | "input_tokens": data[:-1], 95 | "target_tokens": data[1:].clone(), 96 | "loss_masks": (data[1:] != 1).to(torch.float32), 97 | } 98 | else: 99 | return { 100 | "input_tokens": data[:-1], 101 | "target_tokens": data[1:].clone(), 102 | "loss_masks": torch.ones_like(data[:-1], dtype=torch.float32), 103 | } 104 | -------------------------------------------------------------------------------- /ttt/dataloader/tokenization.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | current_dir = Path(__file__).parent.absolute() 5 | 6 | 7 | import pytest 8 | 9 | import torch 10 | 11 | import dotenv 12 | 13 | from ttt.dataloader.language_modeling_hf import LMDataModule 14 | 15 | # load environment variables from `.env` file if it exists 16 | # recursively searches for `.env` in all folders starting from work dir 17 | dotenv.load_dotenv(override=True) 18 | 19 | 20 | def div_up(x: int, y: int) -> int: 21 | return (x + y - 1) // y 22 | 23 | 24 | # https://stackoverflow.com/questions/1006289/how-to-find-out-the-number-of-cpus-using-python/55423170#55423170 25 | def num_cpu_cores(): 26 | try: 27 | import psutil 28 | 29 | return psutil.cpu_count(logical=False) 30 | except ImportError: 31 | return len(os.sched_getaffinity(0)) 32 | 33 | 34 | class TestLMDataModule: 35 | def test_the_pile(self): 36 | batch_size = 8 37 | dataset_name = "the_pile" 38 | dataset_config_name = None 39 | cache_dir = Path( 40 | "/mnt/disks/persistent/the_pile_release" 41 | ) # TODO: Fill in your path to save the tokenized dataset 42 | raw_json_path = ( 43 | "/mnt/disks/persistent/PILE" 44 | ) # TODO: Fill in your path that already stores the raw dataset in json format 45 | max_length = 2048 46 | num_workers = num_cpu_cores() // 2 47 | datamodule = LMDataModule( 48 | dataset_name, 49 | tokenizer_name="meta-llama/Llama-2-7b-hf", 50 | dataset_config_name=dataset_config_name, 51 | max_length=max_length, 52 | cache_dir=cache_dir, 53 | raw_json_path=raw_json_path, 54 | add_eos=True, # bos is added by default in llama2 tokenizer 55 | batch_size=batch_size, 56 | num_workers=num_workers, 57 | ) 58 | datamodule.prepare_data() 59 | datamodule.setup(stage="fit") 60 | train_loader = datamodule.train_dataloader() 61 | val_loader = datamodule.val_dataloader() 62 | datamodule.setup(stage="test") 63 | test_loader = datamodule.test_dataloader() 64 | # Token number of The Pile when tokenized by the llama2 tokenizer 65 | train_len = 383509963636 66 | val_len = 393983786 67 | test_len = 383707892 68 | assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size) 69 | assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size) 70 | assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size) 71 | for loader in [train_loader, val_loader, test_loader]: 72 | x, y = next(iter(loader)) 73 | assert x.dim() == 2 74 | assert x.shape == (batch_size, max_length) 75 | assert x.dtype == torch.long 76 | assert torch.allclose(x[:, 1:], y[:, :-1]) 77 | 78 | def 
test_books(self): 79 | batch_size = 8 80 | dataset_name = "books3" 81 | dataset_config_name = None 82 | cache_dir = Path( 83 | "/mnt/disks/persistent/books3_release" 84 | ) # TODO: fill in your path to save the tokenized dataset 85 | raw_json_path = ( 86 | "/mnt/disks/persistent/lwm_raw/lwm_text_data/combined_books.jsonl" 87 | ) # TODO: fill in your path that already stores the raw dataset in json format 88 | max_length = 2048 89 | num_workers = 1 90 | datamodule = LMDataModule( 91 | dataset_name, 92 | tokenizer_name="meta-llama/Llama-2-7b-hf", 93 | dataset_config_name=dataset_config_name, 94 | max_length=max_length, 95 | cache_dir=cache_dir, 96 | raw_json_path=raw_json_path, 97 | add_eos=True, # bos is added by default in llama2 tokenizer 98 | batch_size=batch_size, 99 | num_workers=num_workers, 100 | ) 101 | datamodule.prepare_data() 102 | datamodule.setup(stage="fit") 103 | train_loader = datamodule.train_dataloader() 104 | val_loader = datamodule.val_dataloader() 105 | datamodule.setup(stage="test") 106 | test_loader = datamodule.test_dataloader() 107 | # Token number of Books3 when tokenized by the llama2 tokenizer 108 | train_len = 32585931901 109 | val_len = 14007763 110 | test_len = 14007763 111 | assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size) 112 | assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size) 113 | assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size) 114 | for loader in [train_loader, val_loader, test_loader]: 115 | x, y = next(iter(loader)) 116 | assert x.dim() == 2 117 | assert x.shape == (batch_size, max_length) 118 | assert x.dtype == torch.long 119 | assert torch.allclose(x[:, 1:], y[:, :-1]) 120 | -------------------------------------------------------------------------------- /ttt/infra/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/test-time-training/ttt-lm-jax/ac8cdc8a43b811afe23c27d7e82eed34a747b19c/ttt/infra/__init__.py -------------------------------------------------------------------------------- /ttt/infra/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from ml_collections import ConfigDict 4 | import mlxu 5 | import jax 6 | import jax.numpy as jnp 7 | import flax 8 | from flax.serialization import from_bytes, to_bytes, to_state_dict, from_state_dict 9 | from flax.traverse_util import flatten_dict, unflatten_dict, empty_node 10 | import msgpack 11 | 12 | from ttt.infra.jax_utils import tree_apply, float_tensor_to_dtype 13 | 14 | 15 | class StreamingCheckpointer(object): 16 | """Custom msgpack checkpointer that saves large train states by serializing 17 | and saving tensors one by one in a streaming fashion. Avoids running 18 | out of memory or local TPU disk with default flax checkpointer. 
19 | """ 20 | 21 | @staticmethod 22 | def get_default_config(updates=None): 23 | config = ConfigDict() 24 | config.float_dtype = "bf16" 25 | config.save_optimizer_state = True 26 | 27 | if updates is not None: 28 | config.update(ConfigDict(updates).copy_and_resolve_references()) 29 | return config 30 | 31 | def __init__(self, config, checkpoint_dir, enable=True): 32 | self.config = self.get_default_config(config) 33 | self.checkpoint_dir = checkpoint_dir 34 | self.enable = enable 35 | 36 | def save_checkpoint(self, train_state, filename, gather_fns=None): 37 | if self.enable: 38 | path = os.path.join(self.checkpoint_dir, filename) 39 | os.makedirs(os.path.dirname(path), exist_ok=True) 40 | else: 41 | path = "/dev/null" 42 | self.save_train_state_to_file(train_state, path, gather_fns, self.config.float_dtype) 43 | 44 | @staticmethod 45 | def save_train_state_to_file(train_state, path, gather_fns=None, float_dtype=None): 46 | train_state = to_state_dict(train_state) 47 | packer = msgpack.Packer() 48 | flattend_train_state = flatten_dict(train_state) 49 | if gather_fns is not None: 50 | gather_fns = flatten_dict(to_state_dict(gather_fns)) 51 | 52 | with mlxu.open_file(path, "wb") as fout: 53 | for key, value in flattend_train_state.items(): 54 | if gather_fns is not None: 55 | value = gather_fns[key](value) 56 | value = float_tensor_to_dtype(value, float_dtype) 57 | fout.write(packer.pack((key, to_bytes(value)))) 58 | 59 | def save_pickle(self, obj, filename): 60 | if self.enable: 61 | path = os.path.join(self.checkpoint_dir, filename) 62 | os.makedirs(os.path.dirname(path), exist_ok=True) 63 | else: 64 | path = "/dev/null" 65 | mlxu.save_pickle(obj, path) 66 | 67 | def save_all(self, train_state, gather_fns, metadata=None, dataset=None, milestone=False): 68 | step = int(jax.device_get(train_state.step)) 69 | if self.config.save_optimizer_state: 70 | checkpoint_state = train_state 71 | checkpoint_name = "streaming_train_state" 72 | checkpoint_gather_fns = gather_fns 73 | else: 74 | checkpoint_state = train_state.params["params"] 75 | checkpoint_name = "streaming_params" 76 | checkpoint_gather_fns = gather_fns.params["params"] 77 | 78 | if milestone: 79 | # Save a milestone checkpoint that will not be overwritten 80 | self.save_pickle(metadata, f"step_{step}/metadata_{step}.pkl") 81 | self.save_pickle(dataset, f"step_{step}/dataset_{step}.pkl") 82 | self.save_checkpoint(checkpoint_state, f"step_{step}/{checkpoint_name}_{step}", checkpoint_gather_fns) 83 | # Additionally save a checkpoint that can be overwritten for automatic resuming 84 | self.save_pickle(metadata, "metadata.pkl") 85 | self.save_pickle(dataset, "dataset.pkl") 86 | self.save_checkpoint(checkpoint_state, f"{checkpoint_name}", checkpoint_gather_fns) 87 | else: 88 | # Save a normal checkpoint that can be overwritten 89 | self.save_pickle(metadata, "metadata.pkl") 90 | self.save_pickle(dataset, "dataset.pkl") 91 | self.save_checkpoint(checkpoint_state, f"{checkpoint_name}", checkpoint_gather_fns) 92 | 93 | @staticmethod 94 | def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None): 95 | if shard_fns is not None: 96 | shard_fns = flatten_dict(to_state_dict(shard_fns)) 97 | if remove_dict_prefix is not None: 98 | remove_dict_prefix = tuple(remove_dict_prefix) 99 | flattend_train_state = {} 100 | with mlxu.open_file(path) as fin: 101 | # 83886080 bytes = 80 MB, which is 16 blocks on GCS 102 | unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0) 103 | for key, value in unpacker: 104 | key = 
tuple(key) 105 | if remove_dict_prefix is not None: 106 | if key[: len(remove_dict_prefix)] == remove_dict_prefix: 107 | key = key[len(remove_dict_prefix) :] 108 | else: 109 | continue 110 | 111 | tensor = from_bytes(None, value) 112 | if shard_fns is not None: 113 | tensor = shard_fns[key](tensor) 114 | flattend_train_state[key] = tensor 115 | 116 | if target is not None: 117 | flattened_target = flatten_dict(to_state_dict(target), keep_empty_nodes=True) 118 | for key, value in flattened_target.items(): 119 | if key not in flattend_train_state and value == empty_node: 120 | flattend_train_state[key] = value 121 | 122 | train_state = unflatten_dict(flattend_train_state) 123 | if target is None: 124 | return train_state 125 | 126 | return from_state_dict(target, train_state) 127 | 128 | @staticmethod 129 | def load_flax_checkpoint(path, target=None, shard_fns=None): 130 | """Load a standard flax checkpoint that's not saved with the 131 | msgpack streaming format. 132 | """ 133 | with mlxu.open_file(path, "rb") as fin: 134 | encoded_bytes = fin.read() 135 | 136 | state_dict = flax.serialization.msgpack_restore(encoded_bytes) 137 | if shard_fns is not None: 138 | shard_fns = to_state_dict(shard_fns) 139 | state_dict = tree_apply(shard_fns, state_dict) 140 | 141 | if target is None: 142 | return state_dict 143 | return from_state_dict(target, state_dict) 144 | 145 | @classmethod 146 | def load_trainstate_checkpoint( 147 | cls, load_from, trainstate_target=None, trainstate_shard_fns=None, disallow_trainstate=False 148 | ): 149 | if trainstate_target is not None: 150 | params_target = trainstate_target.params["params"] 151 | else: 152 | params_target = None 153 | 154 | if trainstate_shard_fns is not None: 155 | params_shard_fns = trainstate_shard_fns.params["params"] 156 | else: 157 | params_shard_fns = None 158 | 159 | load_type, load_path = load_from.split("::", 1) 160 | if disallow_trainstate: 161 | assert load_type != "trainstate", "Loading full trainstate is not allowed!" 
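        # `load_from` has the form "<load_type>::<path>"; the <load_type> prefix selects one of the
        # branches below: "trainstate", "trainstate_params", "params", or "flax_params".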
162 | 163 | train_state = None 164 | restored_params = None 165 | 166 | if load_type == "trainstate": 167 | # Load the entire train state in the streaming format 168 | train_state = cls.load_checkpoint(path=load_path, target=trainstate_target, shard_fns=trainstate_shard_fns) 169 | elif load_type == "trainstate_params": 170 | # Load the params part of the train state in the streaming format 171 | restored_params = cls.load_checkpoint( 172 | path=load_path, 173 | target=params_target, 174 | shard_fns=params_shard_fns, 175 | remove_dict_prefix=("params", "params"), 176 | ) 177 | restored_params = flax.core.frozen_dict.freeze({"params": restored_params}) 178 | elif load_type == "params": 179 | # Load the params in the streaming format 180 | restored_params = cls.load_checkpoint(path=load_path, target=params_target, shard_fns=params_shard_fns) 181 | restored_params = flax.core.frozen_dict.freeze({"params": restored_params}) 182 | elif load_type == "flax_params": 183 | # Load the params in the standard flax format (non-streaming) 184 | # This requires the entire params to fit in memory 185 | restored_params = cls.load_flax_checkpoint(path=load_path, target=params_target, shard_fns=params_shard_fns) 186 | restored_params = flax.core.frozen_dict.freeze({"params": restored_params}) 187 | else: 188 | raise ValueError(f"Invalid load_from type: {load_type}") 189 | 190 | return train_state, restored_params 191 | -------------------------------------------------------------------------------- /ttt/infra/jax_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from typing import Any, Mapping, Text, Tuple, Union, NamedTuple 4 | from functools import partial 5 | import re 6 | import dataclasses 7 | import random 8 | from ml_collections import ConfigDict 9 | from ml_collections.config_dict.config_dict import placeholder 10 | 11 | import flax 12 | import jax 13 | import jax.numpy as jnp 14 | from jax.sharding import PartitionSpec as PS 15 | from jax.sharding import Mesh 16 | from jax.experimental import mesh_utils 17 | from jax.experimental.pjit import with_sharding_constraint as _with_sharding_constraint 18 | from jax.experimental.pjit import pjit 19 | from jax.interpreters import pxla 20 | import numpy as np 21 | import torch 22 | from transformers import FlaxLogitsWarper 23 | 24 | from io import BytesIO 25 | from PIL import Image 26 | import wandb 27 | import matplotlib.pyplot as plt 28 | import logging 29 | 30 | matplotlib_logger = logging.getLogger("matplotlib") 31 | matplotlib_logger.setLevel(logging.WARNING) 32 | 33 | 34 | class JaxRNG(object): 35 | """A convenient stateful Jax RNG wrapper. Can be used to wrap RNG inside 36 | pure function. 
37 | """ 38 | 39 | @classmethod 40 | def from_seed(cls, seed): 41 | return cls(jax.random.PRNGKey(seed)) 42 | 43 | def __init__(self, rng): 44 | self.rng = rng 45 | 46 | def __call__(self, keys=None): 47 | if keys is None: 48 | self.rng, split_rng = jax.random.split(self.rng) 49 | return split_rng 50 | elif isinstance(keys, int): 51 | split_rngs = jax.random.split(self.rng, num=keys + 1) 52 | self.rng = split_rngs[0] 53 | return tuple(split_rngs[1:]) 54 | else: 55 | split_rngs = jax.random.split(self.rng, num=len(keys) + 1) 56 | self.rng = split_rngs[0] 57 | return {key: val for key, val in zip(keys, split_rngs[1:])} 58 | 59 | 60 | class JaxDistributedConfig(object): 61 | """Utility class for initializing JAX distributed.""" 62 | 63 | @staticmethod 64 | def get_default_config(updates=None): 65 | config = ConfigDict() 66 | config.initialize_jax_distributed = False 67 | config.coordinator_address = placeholder(str) 68 | config.num_processes = placeholder(int) 69 | config.process_id = placeholder(int) 70 | config.local_device_ids = placeholder(str) 71 | 72 | if updates is not None: 73 | config.update(ConfigDict(updates).copy_and_resolve_references()) 74 | return config 75 | 76 | @classmethod 77 | def initialize(cls, config): 78 | config = cls.get_default_config(config) 79 | if config.initialize_jax_distributed: 80 | if config.local_device_ids is not None: 81 | local_device_ids = [int(x) for x in config.local_device_ids.split(",")] 82 | else: 83 | local_device_ids = None 84 | 85 | jax.distributed.initialize( 86 | coordinator_address=config.coordinator_address, 87 | num_processes=config.num_processes, 88 | process_id=config.process_id, 89 | local_device_ids=local_device_ids, 90 | ) 91 | 92 | 93 | class FlaxTemperatureLogitsWarper(FlaxLogitsWarper): 94 | """JIT traceable version of FlaxLogitsWarper that performs temperature scaling.""" 95 | 96 | def __init__(self, temperature): 97 | self.temperature = temperature 98 | 99 | def __call__(self, input_ids, scores, cur_len): 100 | return scores / jnp.clip(self.temperature, a_min=1e-8) 101 | 102 | 103 | def make_shard_and_gather_fns(partition_specs, dtype_specs=None): 104 | """Create pytree of sharding and gathering functions from pytree of 105 | partition specs. 
106 | """ 107 | float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64) 108 | 109 | def make_to_dtype_fn(dtype_spec): 110 | def to_dtype(tensor): 111 | if dtype_specs in float_dtypes and getattr(tensor, "dtype", None) in float_dtypes: 112 | # Convert all float tensors to the same dtype 113 | return tensor.astype(dtype_specs) 114 | elif hasattr(dtype_spec, "dtype") and hasattr(tensor, "dtype"): 115 | return tensor.astype(dtype_spec.dtype) 116 | return tensor 117 | 118 | return to_dtype 119 | 120 | def make_shard_fn(partition_spec, dtype_spec=None): 121 | jax_shard_function = pjit(make_to_dtype_fn(dtype_spec), in_shardings=None, out_shardings=partition_spec) 122 | 123 | def shard_fn(tensor): 124 | return jax_shard_function(tensor).block_until_ready() 125 | 126 | return shard_fn 127 | 128 | def make_gather_fn(partition_spec, dtype_spec=None): 129 | jax_gather_fn = pjit(make_to_dtype_fn(dtype_spec), in_shardings=partition_spec, out_shardings=None) 130 | 131 | def gather_fn(tensor): 132 | return jax.device_get(jax_gather_fn(tensor)) 133 | 134 | return gather_fn 135 | 136 | if dtype_specs is None or dtype_specs in float_dtypes: 137 | shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs) 138 | gather_fns = jax.tree_util.tree_map(make_gather_fn, partition_specs) 139 | else: 140 | shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs, dtype_specs) 141 | gather_fns = jax.tree_util.tree_map(make_gather_fn, partition_specs, dtype_specs) 142 | return shard_fns, gather_fns 143 | 144 | 145 | def set_random_seed(seed): 146 | np.random.seed(seed) 147 | random.seed(seed) 148 | torch.manual_seed(seed) 149 | init_rng(seed) 150 | 151 | 152 | def get_jax_mesh(axis_dims, names): 153 | if axis_dims.startswith("!"): 154 | # Allow splitting a physical mesh axis if needed 155 | mesh_axis_splitting = True 156 | axis_dims = axis_dims[1:] 157 | else: 158 | mesh_axis_splitting = False 159 | 160 | if ":" in axis_dims: 161 | dims = [] 162 | dim_names = [] 163 | for axis in axis_dims.split(","): 164 | name, dim = axis.split(":") 165 | assert name in names 166 | dims.append(int(dim)) 167 | dim_names.append(name) 168 | assert set(dim_names) == set(names) 169 | else: 170 | dims = [int(x) for x in axis_dims.split(",")] 171 | dim_names = names 172 | assert len(dims) == len(names) 173 | mesh_shape = np.arange(jax.device_count()).reshape(dims).shape 174 | if mesh_axis_splitting: 175 | physical_mesh = np.array(jax.devices()).reshape(mesh_shape) 176 | else: 177 | physical_mesh = mesh_utils.create_device_mesh(mesh_shape) 178 | return Mesh(physical_mesh, dim_names) 179 | 180 | 181 | def names_in_current_mesh(*names): 182 | """Check if current mesh axes contain these names.""" 183 | mesh_axis_names = pxla.thread_resources.env.physical_mesh.axis_names 184 | return set(names) <= set(mesh_axis_names) 185 | 186 | 187 | def get_names_from_parition_spec(partition_specs): 188 | """Return axis names from partition specs.""" 189 | names = set() 190 | if isinstance(partition_specs, dict): 191 | partition_specs = partition_specs.values() 192 | for item in partition_specs: 193 | if item is None: 194 | continue 195 | elif isinstance(item, str): 196 | names.add(item) 197 | else: 198 | names.update(get_names_from_parition_spec(item)) 199 | 200 | return list(names) 201 | 202 | 203 | def with_sharding_constraint(x, partition_specs): 204 | """A smarter version of with_sharding_constraint that only applies the 205 | constraint if the current mesh contains the axes in the partition specs. 
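    Example (illustrative; the axis names follow this codebase's ("dp", "fsdp", "mp") mesh):

        x = with_sharding_constraint(x, PS(("dp", "fsdp"), None, "mp"))

    If the current mesh does not contain the requested axis names, x is returned unchanged.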
206 | """ 207 | axis_names = get_names_from_parition_spec(partition_specs) 208 | if names_in_current_mesh(*axis_names): 209 | x = _with_sharding_constraint(x, partition_specs) 210 | return x 211 | 212 | 213 | def wrap_function_with_rng(rng): 214 | """To be used as decorator, automatically bookkeep a RNG for the wrapped function.""" 215 | 216 | def wrap_function(function): 217 | def wrapped(*args, **kwargs): 218 | nonlocal rng 219 | rng, split_rng = jax.random.split(rng) 220 | return function(split_rng, *args, **kwargs) 221 | 222 | return wrapped 223 | 224 | return wrap_function 225 | 226 | 227 | def init_rng(seed): 228 | global jax_utils_rng 229 | jax_utils_rng = JaxRNG.from_seed(seed) 230 | 231 | 232 | def next_rng(*args, **kwargs): 233 | global jax_utils_rng 234 | return jax_utils_rng(*args, **kwargs) 235 | 236 | 237 | def get_metrics(metrics, unreplicate=False, stack=False): 238 | if unreplicate: 239 | metrics = flax.jax_utils.unreplicate(metrics) 240 | metrics = jax.device_get(metrics) 241 | if stack: 242 | return jax.tree_map(lambda *args: np.stack(args), *metrics) 243 | else: 244 | return {key: float(val) for key, val in metrics.items()} 245 | 246 | 247 | def mse_loss(val, target, valid=None): 248 | if valid is None: 249 | valid = jnp.ones((*target.shape[:2], 1)) 250 | valid = valid.astype(jnp.float32) 251 | loss = jnp.mean(jnp.where(valid > 0.0, jnp.square(val - target), 0.0)) 252 | return loss 253 | 254 | 255 | def cross_entropy_loss_and_accuracy(logits, tokens, valid=None): 256 | if valid is None: 257 | valid = jnp.ones(tokens.shape[:2]) 258 | valid = valid.astype(jnp.float32) 259 | valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10) 260 | logits = logits.astype(jnp.float32) # for numerical stability 261 | token_log_prob = jnp.squeeze( 262 | jnp.take_along_axis(jax.nn.log_softmax(logits, axis=-1), jnp.expand_dims(tokens, -1), axis=-1), -1 263 | ) 264 | token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0)) 265 | loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length) 266 | correct = jnp.where(valid > 0.0, jnp.argmax(logits, axis=-1) == tokens, jnp.array(False)) 267 | accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length) 268 | return loss, accuracy 269 | 270 | 271 | def global_norm(tree): 272 | """Return the global L2 norm of a pytree.""" 273 | squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree) 274 | flattened, _ = jax.flatten_util.ravel_pytree(squared) 275 | return jnp.sqrt(jnp.sum(flattened)) 276 | 277 | 278 | def average_metrics(metrics): 279 | return jax.tree_map(lambda *args: jnp.mean(jnp.stack(args)), *metrics) 280 | 281 | 282 | def get_float_dtype_by_name(dtype): 283 | return { 284 | "bf16": jnp.bfloat16, 285 | "bfloat16": jnp.bfloat16, 286 | "fp16": jnp.float16, 287 | "float16": jnp.float16, 288 | "fp32": jnp.float32, 289 | "float32": jnp.float32, 290 | "fp64": jnp.float64, 291 | "float64": jnp.float64, 292 | }[dtype] 293 | 294 | 295 | def float_tensor_to_dtype(tensor, dtype): 296 | if dtype is None or dtype == "": 297 | return tensor 298 | if isinstance(dtype, str): 299 | dtype = get_float_dtype_by_name(dtype) 300 | float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64) 301 | if getattr(tensor, "dtype", None) in float_dtypes: 302 | tensor = tensor.astype(dtype) 303 | return tensor 304 | 305 | 306 | def float_to_dtype(tree, dtype): 307 | return jax.tree_util.tree_map(partial(float_tensor_to_dtype, dtype=dtype), tree) 308 | 309 | 310 | def get_gradient_checkpoint_policy(name): 311 | 
return { 312 | "everything_saveable": jax.checkpoint_policies.everything_saveable, 313 | "nothing_saveable": jax.checkpoint_policies.nothing_saveable, 314 | "checkpoint_dots": jax.checkpoint_policies.checkpoint_dots, 315 | "checkpoint_dots_with_no_batch_dims": jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, 316 | }[name] 317 | 318 | 319 | def tree_path_to_string(path, sep=None): 320 | keys = [] 321 | for key in path: 322 | if isinstance(key, jax.tree_util.SequenceKey): 323 | keys.append(str(key.idx)) 324 | elif isinstance(key, jax.tree_util.DictKey): 325 | keys.append(str(key.key)) 326 | elif isinstance(key, jax.tree_util.GetAttrKey): 327 | keys.append(str(key.name)) 328 | elif isinstance(key, jax.tree_util.FlattenedIndexKey): 329 | keys.append(str(key.key)) 330 | else: 331 | keys.append(str(key)) 332 | if sep is None: 333 | return tuple(keys) 334 | return sep.join(keys) 335 | 336 | 337 | def flatten_tree(xs, is_leaf=None, sep=None): 338 | flattened, _ = jax.tree_util.tree_flatten_with_path(xs, is_leaf=is_leaf) 339 | output = {} 340 | for key, val in flattened: 341 | output[tree_path_to_string(key, sep=sep)] = val 342 | return output 343 | 344 | 345 | def named_tree_map(f, tree, *rest, is_leaf=None, sep=None): 346 | """An extended version of jax.tree_util.tree_map, where the mapped function 347 | f takes both the name (path) and the tree leaf as input. 348 | """ 349 | return jax.tree_util.tree_map_with_path( 350 | lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r), tree, *rest, is_leaf=is_leaf 351 | ) 352 | 353 | 354 | def match_partition_rules(rules, params): 355 | """Returns a pytree of PartitionSpec according to rules. Supports handling 356 | Flax TrainState and Optax optimizer state. 357 | """ 358 | 359 | def get_partition_spec(name, leaf): 360 | if len(leaf.shape) == 0 or np.prod(leaf.shape) == 1: 361 | """Don't partition scalar values.""" 362 | return PS() 363 | for rule, ps in rules: 364 | if re.search(rule, name) is not None: 365 | return ps 366 | raise ValueError(f"Partition rule not found for param: {name}") 367 | 368 | return named_tree_map(get_partition_spec, params, sep="/") 369 | 370 | 371 | def get_weight_decay_mask(exclusions): 372 | """Return a weight decay mask function that computes the pytree masks 373 | according to the given exclusion rules. 
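    Illustrative example (the exclusion strings here are arbitrary): with
    exclusions=("bias", "norm"), any parameter whose "/"-joined path matches one of
    these regexes is excluded from weight decay (mask value False), while every other
    parameter keeps weight decay (mask value True).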
374 | """ 375 | 376 | def decay(name, _): 377 | for rule in exclusions: 378 | if re.search(rule, name) is not None: 379 | return False 380 | return True 381 | 382 | def weight_decay_mask(params): 383 | return named_tree_map(decay, params, sep="/") 384 | 385 | return weight_decay_mask 386 | 387 | 388 | def tree_apply(fns, tree): 389 | """Apply a pytree of functions to the pytree.""" 390 | return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree) 391 | 392 | 393 | def master_print(msg, logger=None, end="\n"): 394 | if jax.process_index() == 0: 395 | print(msg, flush=True, end=end) 396 | if logger is not None: 397 | logger.writelines(msg) 398 | if end == "\n": 399 | logger.writelines("\n") 400 | logger.flush() 401 | 402 | 403 | def log_plot(fig, name, step): 404 | buf = BytesIO() 405 | fig.savefig(buf, format="png") 406 | buf.seek(0) 407 | image = Image.open(buf) 408 | wandb.log({name: wandb.Image(image)}, step=step) 409 | buf.close() 410 | 411 | 412 | def log_ttt_stats(layer, ttt_stats_layer, x_axis, step): 413 | ssl_tgt_last_in_mini_batch_from_mean_mse = ttt_stats_layer[0] 414 | ttt_loss_mse_init = ttt_stats_layer[1] 415 | ttt_loss_mse_step_0 = ttt_stats_layer[2] 416 | ttt_loss_mse_step_1 = ttt_stats_layer[3] 417 | 418 | fig, ax = plt.subplots() 419 | ax.plot(x_axis, ssl_tgt_last_in_mini_batch_from_mean_mse, label="$\\|E[Y_{ssl}]-Y_{ssl}\\|^2$", color="green") 420 | ax.plot(x_axis, ttt_loss_mse_init, label="$\mathcal{L}(x_t; W_0)$", color="orange") 421 | ax.plot(x_axis, ttt_loss_mse_step_0, label="$\mathcal{L}(x_t; W_{t-b})$", color="blue") 422 | ax.plot(x_axis, ttt_loss_mse_step_1, label="$\mathcal{L}(x_t; W_{t})$", color="red") 423 | ax.set_ylabel("TTT Loss") 424 | ax.set_xlabel("Position in Sequence") 425 | ax.set_title(f"Layer {layer + 1}") 426 | ax.legend() 427 | log_plot(fig, f"Layer {layer + 1} TTT Loss", step) 428 | plt.close() 429 | -------------------------------------------------------------------------------- /ttt/infra/optimizers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from typing import Any, Mapping, Text, Tuple, Union, NamedTuple 4 | from functools import partial 5 | import re 6 | import dataclasses 7 | import random 8 | 9 | from ml_collections.config_dict import config_dict 10 | from ml_collections import ConfigDict 11 | import jax 12 | import jax.numpy as jnp 13 | import numpy as np 14 | from absl import logging 15 | import optax 16 | 17 | from ttt.infra.jax_utils import float_to_dtype 18 | 19 | 20 | class OptimizerFactory(object): 21 | """Configurable optax optimizer factory.""" 22 | 23 | def __init__(self): 24 | raise NotImplementedError 25 | 26 | @staticmethod 27 | def get_default_config(updates=None): 28 | config = ConfigDict() 29 | config.accumulate_gradient_steps = 1 30 | config.type = "adamw" 31 | config.palm_optimizer = PalmOptimizerFactory.get_default_config() 32 | config.adamw_optimizer = AdamWOptimizerFactory.get_default_config() 33 | 34 | if updates is not None: 35 | config.update(ConfigDict(updates).copy_and_resolve_references()) 36 | return config 37 | 38 | @classmethod 39 | def get_optimizer(cls, config, weight_decay_mask=None): 40 | config = cls.get_default_config(config) 41 | if config.type == "palm": 42 | optimizer, optimizer_info = PalmOptimizerFactory.get_optimizer(config.palm_optimizer, weight_decay_mask) 43 | elif config.type == "adamw": 44 | optimizer, optimizer_info = AdamWOptimizerFactory.get_optimizer(config.adamw_optimizer, weight_decay_mask) 45 | else: 46 | raise 
ValueError(f"Unknown optimizer type: {config.type}") 47 | 48 | if config.accumulate_gradient_steps > 1: 49 | optimizer = optax.MultiSteps(optimizer, config.accumulate_gradient_steps) 50 | 51 | return optimizer, optimizer_info 52 | 53 | 54 | class PalmOptimizerFactory(object): 55 | """PaLM optimizer factory. This optimizer implements the optimizer 56 | described in the PaLM paper: https://arxiv.org/abs/2204.02311 57 | """ 58 | 59 | def __init__(self): 60 | raise NotImplementedError 61 | 62 | @staticmethod 63 | def get_default_config(updates=None): 64 | config = ConfigDict() 65 | config.lr = 0.01 66 | config.lr_warmup_steps = 10000 67 | config.b1 = 0.9 68 | config.b2 = 0.99 69 | config.clip_gradient = 1.0 70 | config.weight_decay = 1e-4 71 | config.bf16_momentum = False 72 | 73 | if updates is not None: 74 | config.update(ConfigDict(updates).copy_and_resolve_references()) 75 | return config 76 | 77 | @classmethod 78 | def get_optimizer(cls, config, weight_decay_mask=None): 79 | config = cls.get_default_config(config) 80 | 81 | def learning_rate_schedule(step): 82 | multiplier = config.lr / 0.01 83 | return multiplier / jnp.sqrt(jnp.maximum(step, config.lr_warmup_steps)) 84 | 85 | def weight_decay_schedule(step): 86 | multiplier = config.weight_decay / 1e-4 87 | return -multiplier * jnp.square(learning_rate_schedule(step)) 88 | 89 | optimizer_info = dict( 90 | learning_rate_schedule=learning_rate_schedule, weight_decay_schedule=weight_decay_schedule 91 | ) 92 | 93 | optimizer = optax.chain( 94 | optax.clip_by_global_norm(config.clip_gradient), 95 | optax.adafactor( 96 | learning_rate=learning_rate_schedule, 97 | multiply_by_parameter_scale=True, 98 | momentum=config.b1, 99 | decay_rate=config.b2, 100 | factored=False, 101 | clipping_threshold=None, 102 | dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32, 103 | ), 104 | optax_add_scheduled_weight_decay(weight_decay_schedule, weight_decay_mask), 105 | ) 106 | return optimizer, optimizer_info 107 | 108 | 109 | class AdamWOptimizerFactory(object): 110 | """AdamW optimizer with cosine schedule.""" 111 | 112 | def __init__(self): 113 | raise NotImplementedError 114 | 115 | @staticmethod 116 | def get_default_config(updates=None): 117 | config = ConfigDict() 118 | config.init_lr = 0.0 119 | config.end_lr = 1e-5 # EasyLM's previous default: 0.001 120 | config.lr = 0.01 121 | config.lr_warmup_steps = 2000 122 | config.lr_decay_steps = 500000 123 | config.b1 = 0.9 124 | config.b2 = 0.95 125 | config.clip_gradient = 1.0 126 | config.weight_decay = 0.1 # EasyLM previous default: 1e-4 127 | config.bf16_momentum = False 128 | config.multiply_by_parameter_scale = False 129 | 130 | if updates is not None: 131 | config.update(ConfigDict(updates).copy_and_resolve_references()) 132 | return config 133 | 134 | @classmethod 135 | def get_optimizer(cls, config, weight_decay_mask=None): 136 | config = cls.get_default_config(config) 137 | 138 | learning_rate_schedule = optax.warmup_cosine_decay_schedule( 139 | init_value=config.init_lr, 140 | peak_value=config.lr, 141 | warmup_steps=config.lr_warmup_steps, 142 | decay_steps=config.lr_decay_steps, 143 | end_value=config.end_lr, 144 | ) 145 | 146 | optimizer_info = dict(learning_rate_schedule=learning_rate_schedule) 147 | 148 | if config.multiply_by_parameter_scale: 149 | optimizer = optax.chain( 150 | optax.clip_by_global_norm(config.clip_gradient), 151 | optax.adafactor( 152 | learning_rate=learning_rate_schedule, 153 | multiply_by_parameter_scale=True, 154 | momentum=config.b1, 155 | 
decay_rate=config.b2, 156 | factored=False, 157 | clipping_threshold=None, 158 | dtype_momentum=(jnp.bfloat16 if config.bf16_momentum else jnp.float32), 159 | ), 160 | optax_add_scheduled_weight_decay( 161 | lambda step: -learning_rate_schedule(step) * config.weight_decay, weight_decay_mask 162 | ), 163 | ) 164 | else: 165 | optimizer = optax.chain( 166 | optax.clip_by_global_norm(config.clip_gradient), 167 | optax.adamw( 168 | learning_rate=learning_rate_schedule, 169 | weight_decay=config.weight_decay, 170 | b1=config.b1, 171 | b2=config.b2, 172 | mask=weight_decay_mask, 173 | mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32, 174 | ), 175 | ) 176 | 177 | return optimizer, optimizer_info 178 | 179 | 180 | class OptaxScheduledWeightDecayState(NamedTuple): 181 | count: jax.Array 182 | 183 | 184 | def optax_add_scheduled_weight_decay(schedule_fn, mask=None): 185 | """Apply weight decay with schedule.""" 186 | 187 | def init_fn(params): 188 | del params 189 | return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32)) 190 | 191 | def update_fn(updates, state, params): 192 | if params is None: 193 | raise ValueError("Params cannot be None for weight decay!") 194 | 195 | weight_decay = schedule_fn(state.count) 196 | updates = jax.tree_util.tree_map(lambda g, p: g + weight_decay * p, updates, params) 197 | return updates, OptaxScheduledWeightDecayState(count=optax.safe_int32_increment(state.count)) 198 | 199 | if mask is not None: 200 | return optax.masked(optax.GradientTransformation(init_fn, update_fn), mask) 201 | return optax.GradientTransformation(init_fn, update_fn) 202 | -------------------------------------------------------------------------------- /ttt/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/test-time-training/ttt-lm-jax/ac8cdc8a43b811afe23c27d7e82eed34a747b19c/ttt/models/__init__.py -------------------------------------------------------------------------------- /ttt/models/bpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | An implementation of Blockwise parallel transformer https://arxiv.org/abs/2305.19370 3 | Also include a reference implementation of memory-efficient transformer https://arxiv.org/abs/2112.05682 4 | """ 5 | 6 | import functools 7 | from typing import NamedTuple 8 | 9 | import flax.linen as nn 10 | import jax 11 | import jax.lax as lax 12 | import jax.numpy as jnp 13 | from einops import rearrange 14 | 15 | """ 16 | Computing ffn blockwise without materializing the large hidden tensor, training 17 | 4x longer sequences than the memory-efficient transformer. 18 | Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023 19 | """ 20 | def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True): 21 | # remat_ffn: a rematerialized ffn with policy jax.checkpoint_policies.nothing_saveable() 22 | # inputs: (batch, seq_len, dim) 23 | # chunk_size: the chunk size to split the sequence 24 | inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size) 25 | def scan_ffn(remat_ffn, carry, hidden_states): 26 | # outputs = remat_ffn(hidden_states, deterministic=deterministic) 27 | outputs = remat_ffn(hidden_states, deterministic) # @xinhao: when mlp is rematted, should directly pass `deterministic` instead of using keyword. Otherwise, `deterministic` will be ignored. 
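        # (illustrative note) each scan step applies the FFN to one (batch, chunk_size, dim)
        # slice of the rearranged inputs, so activation memory scales with the chunk size
        # rather than the full sequence length; the rearrange after the scan restores order.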
28 | return carry, outputs 29 | scan_axis = inputs.ndim - 2 30 | _, res = nn.scan( 31 | scan_ffn, 32 | variable_broadcast="params", 33 | split_rngs={"params": False, "dropout": True}, 34 | in_axes=scan_axis, 35 | out_axes=scan_axis, 36 | )(remat_ffn, None, inputs) 37 | res = rearrange(res, 'b c n d -> b (c n) d') 38 | return res 39 | 40 | 41 | """ 42 | Compute attention blockwise without materializing the full attention matrix, 43 | initially proposed in memory-efficient transformer https://arxiv.org/abs/2112.05682 Rabe et al. 2021; 44 | flash attention https://arxiv.org/abs/2205.14135 Dao et al. 2022 proposes a CUDA 45 | efficient implementation; blockwise parallel transformer https://arxiv.org/abs/2305.19370 46 | Liu et al. 2023 proposes blockwise computing both attention and FFN, enabling 4x 47 | longer sequences than memory-efficient/flash-attention and fusion of attention and FFN. 48 | """ 49 | def blockwise_attn( 50 | query, key, value, 51 | bias=None, 52 | deterministic=True, 53 | dropout_rng=None, 54 | attn_pdrop=0.0, 55 | causal=True, 56 | query_chunk_size=2048, 57 | key_chunk_size=2048, 58 | dtype=jnp.float32, 59 | policy=jax.checkpoint_policies.nothing_saveable(), 60 | precision=None, 61 | float32_logits=True, 62 | prevent_cse=True, 63 | ): 64 | # query, key, value: (batch, seq_len, num_heads, dim_per_head) 65 | # bias: (batch, seq_len) can be used to mask out attention (e.g. padding) 66 | # causal: whether to use causal mask 67 | # policy: one of jax.checkpoint_policies 68 | query = query / jnp.sqrt(query.shape[-1]).astype(dtype) 69 | if float32_logits: 70 | query = query.astype(jnp.float32) 71 | key = key.astype(jnp.float32) 72 | 73 | batch, q_len, num_heads, dim_per_head = query.shape 74 | batch, kv_len, num_heads, dim_per_head = key.shape 75 | batch, kv_len, num_heads, dim_per_head = value.shape 76 | 77 | num_q = q_len // query_chunk_size 78 | num_kv = kv_len // key_chunk_size 79 | query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head)) 80 | key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head)) 81 | value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head)) 82 | 83 | query = jnp.moveaxis(query, 1, 0) 84 | key = jnp.moveaxis(key, 1, 0) 85 | value = jnp.moveaxis(value, 1, 0) 86 | 87 | if bias is not None: 88 | for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)): 89 | assert bias_dim == 1 or bias_dim == broadcast_dim 90 | if not deterministic and attn_pdrop > 0.0: 91 | attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng) 92 | attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len)) 93 | else: 94 | attn_dropout = None 95 | 96 | _chunk_bias_fn = functools.partial( 97 | _chunk_attention_bias, 98 | query_chunk_size, key_chunk_size, bias, deterministic, 99 | attn_dropout, attn_pdrop, causal, dtype) 100 | 101 | def scan_attention(args): 102 | query_chunk, query_chunk_idx = args 103 | 104 | @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy) 105 | def scan_kv_block(carry, args): 106 | key_chunk, value_chunk, key_chunk_idx = args 107 | (numerator, denominator, prev_max_score) = carry 108 | attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision) 109 | bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx) 110 | bias_chunk = jnp.moveaxis(bias_chunk, 1, 2) 111 | attn_weights = attn_weights + bias_chunk 112 | 113 | max_score = jnp.max(attn_weights, axis=-1, keepdims=True) 114 | 
max_score = jnp.maximum(prev_max_score, max_score) 115 | max_score = jax.lax.stop_gradient(max_score) 116 | exp_weights = jnp.exp(attn_weights - max_score) 117 | exp_values = jnp.einsum( 118 | 'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision 119 | ) 120 | correction = jnp.exp(prev_max_score - max_score) 121 | numerator = numerator * correction + exp_values 122 | denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True) 123 | return Carry(numerator, denominator, max_score), None 124 | 125 | def skip_upper_half(carry, args): 126 | key_chunk, value_chunk, key_chunk_idx = args 127 | skip_block = jnp.array(False) 128 | if causal: 129 | skip_block = query_chunk_idx < key_chunk_idx 130 | return jax.lax.cond( 131 | skip_block, 132 | lambda carry, args: (carry, None), 133 | scan_kv_block, 134 | carry, 135 | args, 136 | ) 137 | 138 | init_carry = Carry( 139 | jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype), 140 | jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype), 141 | (-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype), 142 | ) 143 | (numerator, denominator, max_score), _ = lax.scan( 144 | skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv)) 145 | ) 146 | outputs = (numerator / denominator).astype(dtype) 147 | return outputs 148 | 149 | _, res = lax.scan( 150 | lambda _, x: ((), scan_attention(x)), 151 | (), xs=(query, jnp.arange(0, num_q)) 152 | ) 153 | res = rearrange(res, 'n b c h d -> b (n c) h d') 154 | return res 155 | 156 | 157 | class Carry(NamedTuple): 158 | numerator: jax.Array 159 | denominator: jax.Array 160 | max_so_far: jax.Array 161 | 162 | 163 | def _chunk_attention_bias(query_chunk_size, key_chunk_size, 164 | bias, deterministic, attn_dropout, attn_pdrop, causal, 165 | dtype, query_chunk_idx, key_chunk_idx): 166 | query_offset = query_chunk_idx * query_chunk_size 167 | key_offset = key_chunk_idx * key_chunk_size 168 | chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype) 169 | if bias is not None: 170 | chunk_bias = lax.dynamic_slice( 171 | bias, 172 | start_indices=(0, 0, query_offset, key_offset), 173 | slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)), 174 | ) 175 | 176 | if causal: 177 | query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0) 178 | key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1) 179 | offset = query_offset - key_offset 180 | query_idx += offset 181 | causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min 182 | chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape) 183 | 184 | if not deterministic and attn_pdrop > 0.0: 185 | attn_dropout_slice = lax.dynamic_slice( 186 | attn_dropout, 187 | start_indices=(0, 0, query_offset, key_offset), 188 | slice_sizes=( 189 | *attn_dropout.shape[:2], 190 | min(attn_dropout.shape[-2], query_chunk_size), 191 | min(attn_dropout.shape[-1], key_chunk_size), 192 | ), 193 | ) 194 | chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min 195 | return chunk_bias.astype(dtype) 196 | 197 | 198 | if __name__ == '__main__': 199 | # test 200 | def reference_attn(query, key, value, causal, dtype): 201 | query = query / jnp.sqrt(query.shape[-1]).astype(dtype) 202 | logits = jnp.einsum("bqhc,bkhc->bhqk", query, key) 203 | if causal: 204 | mask_value = jnp.finfo(logits.dtype).min 205 | _, q_seq_len, _, _ = query.shape 206 | _, kv_seq_len, _, _ = 
key.shape 207 | mask_shape = (q_seq_len, kv_seq_len) 208 | row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0) 209 | col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1) 210 | causal_mask = (row_ids < col_ids)[None, None, :, :] 211 | logits = logits + jnp.where(causal_mask, mask_value, 0.0) 212 | weights = jax.nn.softmax(logits, axis=-1) 213 | out = jnp.einsum("bhqk,bkhc->bqhc", weights, value) 214 | return out 215 | 216 | # random inputs 217 | shape = (1, 32, 8, 64) 218 | query = jax.random.normal(jax.random.PRNGKey(0), shape) 219 | key = jax.random.normal(jax.random.PRNGKey(1), shape) 220 | value = jax.random.normal(jax.random.PRNGKey(2), shape) 221 | 222 | causal = True 223 | chunk_size = 4 224 | policy = jax.checkpoint_policies.nothing_saveable() 225 | 226 | blockwise = blockwise_attn(query, key, value, None, False, None, 0.0, causal, chunk_size, chunk_size, jnp.float32, policy, 'float32', True, False) 227 | reference = reference_attn(query, key, value, causal, 'float32') 228 | 229 | assert jnp.allclose(reference, blockwise, atol=1e-6) 230 | -------------------------------------------------------------------------------- /ttt/models/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, List, Optional, Tuple, Union 3 | import json 4 | 5 | import numpy as np 6 | import jax 7 | import jax.numpy as jnp 8 | from jax import lax 9 | from jax.sharding import PartitionSpec as PS 10 | import flax 11 | import flax.linen as nn 12 | from flax.linen import combine_masks, make_causal_mask 13 | from flax.linen.attention import dot_product_attention_weights 14 | from flax.linen import partitioning as nn_partitioning 15 | 16 | from transformers.configuration_utils import PretrainedConfig 17 | from transformers.modeling_flax_outputs import ModelOutput 18 | 19 | from ml_collections import ConfigDict 20 | from mlxu import function_args_to_config, load_pickle, open_file 21 | 22 | from ttt.models.bpt import blockwise_ffn, blockwise_attn 23 | from ttt.infra.jax_utils import with_sharding_constraint, get_jax_mesh, get_gradient_checkpoint_policy 24 | from ttt.models.ttt_layer import TTTLinear, TTTMLP, TTTLinearBase, TTTMLPBase, precompute_freqs_cis, apply_rotary_emb 25 | 26 | 27 | @flax.struct.dataclass 28 | class BaseModelOutput(ModelOutput): 29 | last_hidden_state: jnp.ndarray = None 30 | hidden_states: Optional[Tuple[jnp.ndarray]] = None 31 | attentions: Optional[Tuple[jnp.ndarray]] = None 32 | ttt_stats: Optional[Tuple[jnp.ndarray]] = None 33 | logits: jnp.ndarray = None 34 | 35 | 36 | CausalLMOutput = BaseModelOutput 37 | 38 | remat = nn_partitioning.remat 39 | 40 | CONFIGS = { 41 | "125m": { 42 | "vocab_size": 32000, 43 | "num_hidden_layers": 12, 44 | "hidden_size": 768, 45 | "num_attention_heads": 12, 46 | "intermediate_size": 2048, 47 | "max_sequence_length": 2048, 48 | "initializer_range": 0.02, 49 | "rms_norm_eps": 1e-6, 50 | "use_cache": True, 51 | "tie_word_embeddings": True, 52 | "seq_modeling_block": "self_attention", 53 | "use_rotary_emb": "sequence", 54 | "rope_theta": 10000.0, 55 | "pre_conv": False, 56 | }, 57 | "125m-TTT": { 58 | "vocab_size": 32000, 59 | "num_hidden_layers": 12, 60 | "hidden_size": 768, 61 | "num_attention_heads": 12, 62 | "intermediate_size": 2048, 63 | "max_sequence_length": 2048, 64 | "initializer_range": 0.02, 65 | "rms_norm_eps": 1e-6, 66 | "use_cache": True, 67 | "tie_word_embeddings": True, 68 | "seq_modeling_block": "ttt_linear", 69 | "ttt_base_lr": 1.0, 70 | 
"ttt_base_lr_init": -1.0, 71 | "ttt_base_lr_warmup": -1, 72 | "mini_batch_size": 16, 73 | "remat_mini_batch_group_size": 4, 74 | "rope_theta": 10000.0, 75 | "pre_conv": True, 76 | "conv_width": 4, 77 | }, 78 | "350m": { 79 | "vocab_size": 32000, 80 | "num_hidden_layers": 24, 81 | "hidden_size": 1024, 82 | "num_attention_heads": 16, 83 | "intermediate_size": 2736, 84 | "max_sequence_length": 2048, 85 | "initializer_range": 0.02, 86 | "rms_norm_eps": 1e-6, 87 | "use_cache": True, 88 | "tie_word_embeddings": True, 89 | "seq_modeling_block": "self_attention", 90 | "use_rotary_emb": "sequence", 91 | "rope_theta": 10000.0, 92 | "pre_conv": False, 93 | }, 94 | "350m-TTT": { 95 | "vocab_size": 32000, 96 | "num_hidden_layers": 24, 97 | "hidden_size": 1024, 98 | "num_attention_heads": 16, 99 | "intermediate_size": 2736, 100 | "max_sequence_length": 2048, 101 | "initializer_range": 0.02, 102 | "rms_norm_eps": 1e-6, 103 | "use_cache": True, 104 | "tie_word_embeddings": True, 105 | "seq_modeling_block": "ttt_linear", 106 | "ttt_base_lr": 1.0, 107 | "ttt_base_lr_init": -1.0, 108 | "ttt_base_lr_warmup": -1, 109 | "mini_batch_size": 16, 110 | "remat_mini_batch_group_size": 4, 111 | "rope_theta": 10000.0, 112 | "pre_conv": True, 113 | "conv_width": 4, 114 | }, 115 | "760m": { 116 | "vocab_size": 32000, 117 | "num_hidden_layers": 24, 118 | "hidden_size": 1536, 119 | "num_attention_heads": 16, 120 | "intermediate_size": 4096, 121 | "max_sequence_length": 2048, 122 | "initializer_range": 0.02, 123 | "rms_norm_eps": 1e-6, 124 | "use_cache": True, 125 | "tie_word_embeddings": True, 126 | "seq_modeling_block": "self_attention", 127 | "use_rotary_emb": "sequence", 128 | "rope_theta": 10000.0, 129 | "pre_conv": False, 130 | }, 131 | "760m-TTT": { 132 | "vocab_size": 32000, 133 | "num_hidden_layers": 24, 134 | "hidden_size": 1536, 135 | "num_attention_heads": 16, 136 | "intermediate_size": 4096, 137 | "max_sequence_length": 2048, 138 | "initializer_range": 0.02, 139 | "rms_norm_eps": 1e-6, 140 | "use_cache": True, 141 | "tie_word_embeddings": True, 142 | "seq_modeling_block": "ttt_linear", 143 | "ttt_base_lr": 1.0, 144 | "ttt_base_lr_init": -1.0, 145 | "ttt_base_lr_warmup": -1, 146 | "mini_batch_size": 16, 147 | "remat_mini_batch_group_size": 4, 148 | "rope_theta": 10000.0, 149 | "pre_conv": True, 150 | "conv_width": 4, 151 | }, 152 | "1b": { 153 | "vocab_size": 32000, 154 | "num_hidden_layers": 24, 155 | "hidden_size": 2048, 156 | "num_attention_heads": 32, 157 | "intermediate_size": 5504, 158 | "max_sequence_length": 2048, 159 | "initializer_range": 0.02, 160 | "rms_norm_eps": 1e-6, 161 | "use_cache": True, 162 | "tie_word_embeddings": True, 163 | "seq_modeling_block": "self_attention", 164 | "use_rotary_emb": "sequence", 165 | "rope_theta": 10000.0, 166 | "pre_conv": False, 167 | }, 168 | "1b-TTT": { 169 | "vocab_size": 32000, 170 | "num_hidden_layers": 24, 171 | "hidden_size": 2048, 172 | "num_attention_heads": 32, 173 | "intermediate_size": 5504, 174 | "max_sequence_length": 2048, 175 | "initializer_range": 0.02, 176 | "rms_norm_eps": 1e-6, 177 | "use_cache": True, 178 | "tie_word_embeddings": True, 179 | "seq_modeling_block": "ttt_linear", 180 | "ttt_base_lr": 1.0, 181 | "ttt_base_lr_init": -1.0, 182 | "ttt_base_lr_warmup": -1, 183 | "mini_batch_size": 16, 184 | "remat_mini_batch_group_size": 4, 185 | "rope_theta": 10000.0, 186 | "pre_conv": True, 187 | "conv_width": 4, 188 | }, 189 | } 190 | 191 | 192 | class ModelConfig(PretrainedConfig): 193 | def __init__( 194 | self, 195 | vocab_size=32000, 196 | 
hidden_size=4096, 197 | intermediate_size=11008, 198 | num_hidden_layers=32, 199 | num_attention_heads=32, 200 | max_sequence_length=2048, 201 | rms_norm_eps=1e-6, 202 | initializer_range=0.02, 203 | use_cache=True, 204 | bos_token_id=1, 205 | eos_token_id=2, 206 | resid_pdrop=0.0, 207 | embd_pdrop=0.0, 208 | attn_pdrop=0.0, 209 | tie_word_embeddings=False, 210 | remat_block="", 211 | remat_attention="", 212 | remat_mlp="", 213 | remat_conv="", 214 | scan_attention=False, 215 | scan_mlp=False, 216 | scan_query_chunk_size=1024, 217 | scan_key_chunk_size=1024, 218 | scan_mlp_chunk_size=1024, 219 | fcm_min_ratio=0.0, 220 | fcm_max_ratio=0.0, 221 | **kwargs, 222 | ): 223 | self.vocab_size = vocab_size 224 | self.hidden_size = hidden_size 225 | self.initializer_range = initializer_range 226 | self.intermediate_size = intermediate_size 227 | self.num_hidden_layers = num_hidden_layers 228 | self.num_attention_heads = num_attention_heads 229 | self.max_sequence_length = max_sequence_length 230 | self.rms_norm_eps = rms_norm_eps 231 | self.use_cache = use_cache 232 | self.resid_pdrop = resid_pdrop 233 | self.embd_pdrop = embd_pdrop 234 | self.attn_pdrop = attn_pdrop 235 | self.remat_block = remat_block 236 | self.remat_attention = remat_attention 237 | self.remat_mlp = remat_mlp 238 | self.remat_conv = remat_conv 239 | self.scan_attention = scan_attention 240 | self.scan_mlp = scan_mlp 241 | self.scan_query_chunk_size = scan_query_chunk_size 242 | self.scan_key_chunk_size = scan_key_chunk_size 243 | self.scan_mlp_chunk_size = scan_mlp_chunk_size 244 | self.fcm_min_ratio = fcm_min_ratio 245 | self.fcm_max_ratio = fcm_max_ratio 246 | super().__init__( 247 | bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs 248 | ) 249 | 250 | @classmethod 251 | def get_default_config(cls, updates=None): 252 | config = function_args_to_config(cls.__init__) 253 | 254 | if updates is not None: 255 | config.update(ConfigDict(updates).copy_and_resolve_references()) 256 | 257 | return config 258 | 259 | @staticmethod 260 | def get_jax_mesh(axis_dims): 261 | return get_jax_mesh(axis_dims, ("dp", "fsdp", "mp")) 262 | 263 | @staticmethod 264 | def get_partition_rules(): 265 | """Partition rules. Note that these rules are orderd, so that 266 | the beginning rules match first. It is important to use 267 | PartitionSpec() instead of None here because JAX does not treat 268 | None as a pytree leaf. 
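        Illustrative note: each rule is a (regex, PartitionSpec) pair; match_partition_rules
        joins each parameter path with "/" and applies the first rule whose regex matches,
        e.g. a path like "model/h/0/feed_forward/w1/kernel" would be sharded as
        PS("fsdp", "mp") via the "feed_forward/w1/kernel" rule.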
269 | """ 270 | return ( 271 | # Embeddings 272 | ("model/wte/embedding", PS("mp", "fsdp")), 273 | # Attention/TTT 274 | ("seq_modeling_block/(wq|wk|wv)/kernel", PS("fsdp", "mp")), 275 | ("seq_modeling_block/wo/kernel", PS("mp", "fsdp")), 276 | # TTT 277 | ("seq_modeling_block/ttt_norm/scale", PS(None)), 278 | ("seq_modeling_block/ttt_norm/bias", PS(None)), 279 | ("seq_modeling_block/post_norm/scale", PS(None)), 280 | ("seq_modeling_block/post_norm/bias", PS(None)), 281 | ("seq_modeling_block/learnable_ttt_lr/kernel", PS(None)), 282 | ("seq_modeling_block/learnable_ttt_lr/bias", PS(None)), 283 | ("seq_modeling_block/ttt_dense_0", PS(None)), 284 | ("seq_modeling_block/ttt_dense_1", PS(None)), 285 | ("seq_modeling_block/ttt_bias_0", PS(None)), 286 | ("seq_modeling_block/ttt_bias_1", PS(None)), 287 | # SwiGLU MLP 288 | ("feed_forward/w1/kernel", PS("fsdp", "mp")), 289 | ("feed_forward/w2/kernel", PS("mp", "fsdp")), 290 | ("feed_forward/w3/kernel", PS("fsdp", "mp")), 291 | # RMS Norm 292 | ("seq_norm/kernel", PS(None)), 293 | ("ffn_norm/kernel", PS(None)), 294 | # Output Head 295 | ("model/ln_f/kernel", PS(None)), 296 | ("lm_head/kernel", PS("fsdp", "mp")), 297 | (".*", PS(None)), 298 | ) 299 | 300 | @staticmethod 301 | def get_weight_decay_exclusions(): 302 | return tuple() 303 | 304 | @staticmethod 305 | def rng_keys(): 306 | return ("params", "dropout", "fcm") 307 | 308 | @classmethod 309 | def load_config(cls, path): 310 | if path in CONFIGS: 311 | return cls.from_dict(CONFIGS[path]) 312 | load_type, load_path = path.split("::", 1) 313 | if load_type == "pickle": 314 | return cls.from_dict(load_pickle(load_path)["config"]) 315 | elif load_type == "json": 316 | with open_file(load_path, "r") as fin: 317 | raw_config = fin.read() 318 | return cls.from_dict(json.loads(raw_config)) 319 | else: 320 | raise ValueError(f"Unsupported load config type: {load_type}") 321 | 322 | 323 | class RMSNorm(nn.Module): 324 | dim: int 325 | eps: float = 1e-6 326 | dtype: jnp.dtype = jnp.float32 327 | param_dtype: jnp.dtype = jnp.float32 328 | 329 | def setup(self) -> None: 330 | self.weight = self.param("kernel", nn.initializers.ones, (self.dim,), self.param_dtype) 331 | 332 | def _norm(self, x: jnp.ndarray) -> jnp.ndarray: 333 | return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps) 334 | 335 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 336 | x = x.astype(jnp.promote_types(self.dtype, jnp.float32)) 337 | output = self._norm(x).astype(self.dtype) 338 | weight = jnp.asarray(self.weight, self.dtype) 339 | return output * weight 340 | 341 | 342 | class SwiGLUMLP(nn.Module): 343 | config: ModelConfig 344 | dtype: jnp.dtype = jnp.float32 345 | param_dtype: jnp.dtype = jnp.float32 346 | precision: Optional[Union[jax.lax.Precision, str]] = None 347 | 348 | def setup(self) -> None: 349 | config = self.config 350 | self.w1 = nn.Dense( 351 | config.intermediate_size, 352 | dtype=self.dtype, 353 | param_dtype=self.param_dtype, 354 | use_bias=False, 355 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 356 | precision=self.precision, 357 | ) 358 | self.w2 = nn.Dense( 359 | config.hidden_size, 360 | dtype=self.dtype, 361 | param_dtype=self.param_dtype, 362 | use_bias=False, 363 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 364 | precision=self.precision, 365 | ) 366 | self.w3 = nn.Dense( 367 | config.intermediate_size, 368 | dtype=self.dtype, 369 | param_dtype=self.param_dtype, 370 | use_bias=False, 371 | 
kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 372 | precision=self.precision, 373 | ) 374 | self.dropout = nn.Dropout(rate=self.config.resid_pdrop) 375 | 376 | def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: 377 | x = self.w2(nn.silu(self.w1(x)) * self.w3(x)) 378 | x = self.dropout(x, deterministic=deterministic) 379 | return x 380 | 381 | 382 | class ConvModule(nn.Module): 383 | config: ModelConfig 384 | dtype: jnp.dtype = jnp.float32 385 | param_dtype: jnp.dtype = jnp.float32 386 | precision: Optional[Union[jax.lax.Precision, str]] = None 387 | 388 | def setup(self): 389 | config = self.config 390 | 391 | if self.config.remat_conv != "": 392 | conv_module = nn_partitioning.remat( 393 | nn.Conv, policy=get_gradient_checkpoint_policy(self.config.remat_conv), prevent_cse=True 394 | ) 395 | else: 396 | conv_module = nn.Conv 397 | 398 | self.conv1 = conv_module( 399 | config.hidden_size, 400 | (config.conv_width,), 401 | padding="CAUSAL", 402 | feature_group_count=config.hidden_size, 403 | dtype=self.dtype, 404 | param_dtype=self.param_dtype, 405 | precision=self.precision, 406 | ) 407 | self.conv_norm = RMSNorm( 408 | self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype 409 | ) 410 | 411 | def __call__(self, hidden_states): 412 | x = hidden_states 413 | x = self.conv_norm(x) 414 | x = self.conv1(x) 415 | return x 416 | 417 | 418 | class Attention(nn.Module): 419 | config: ModelConfig 420 | dtype: jnp.dtype = jnp.float32 421 | param_dtype: jnp.dtype = jnp.float32 422 | precision: Optional[Union[jax.lax.Precision, str]] = None 423 | 424 | def setup(self): 425 | config = self.config 426 | self.embed_dim = config.hidden_size 427 | self.num_heads = config.num_attention_heads 428 | self.head_dim = self.embed_dim // self.num_heads 429 | 430 | self.wq = nn.Dense( 431 | config.num_attention_heads * self.head_dim, 432 | dtype=self.dtype, 433 | param_dtype=self.param_dtype, 434 | use_bias=False, 435 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 436 | precision=self.precision, 437 | ) 438 | self.wk = nn.Dense( 439 | config.num_attention_heads * self.head_dim, 440 | dtype=self.dtype, 441 | param_dtype=self.param_dtype, 442 | use_bias=False, 443 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 444 | precision=self.precision, 445 | ) 446 | self.wv = nn.Dense( 447 | config.num_attention_heads * self.head_dim, 448 | dtype=self.dtype, 449 | param_dtype=self.param_dtype, 450 | use_bias=False, 451 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 452 | precision=self.precision, 453 | ) 454 | self.wo = nn.Dense( 455 | config.hidden_size, 456 | dtype=self.dtype, 457 | param_dtype=self.param_dtype, 458 | use_bias=False, 459 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 460 | precision=self.precision, 461 | ) 462 | 463 | self.resid_dropout = nn.Dropout(rate=config.resid_pdrop) 464 | 465 | self.causal_mask = make_causal_mask(jnp.ones((1, config.max_sequence_length), dtype="bool"), dtype="bool") 466 | 467 | self.freqs_cis = precompute_freqs_cis( 468 | self.head_dim, config.max_sequence_length * 2, theta=config.rope_theta, dtype=self.dtype 469 | ) 470 | 471 | def _split_heads(self, hidden_states): 472 | return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) 473 | 474 | def _merge_heads(self, hidden_states): 475 | return hidden_states.reshape(hidden_states.shape[:2] + 
(self.embed_dim,)) 476 | 477 | @nn.compact 478 | def _concatenate_to_cache(self, key, value, query, attention_mask): 479 | """ 480 | This function takes projected key, value states from a single input token and concatenates the states to cached 481 | states from previous steps. This function is slighly adapted from the official Flax repository: 482 | https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 483 | """ 484 | # detect if we're initializing by absence of existing cache data. 485 | is_initialized = self.has_variable("cache", "cached_key") 486 | cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) 487 | cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) 488 | cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) 489 | 490 | if is_initialized: 491 | *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape 492 | # update key, value caches with our new 1d spatial slices 493 | cur_index = cache_index.value 494 | indices = (0,) * len(batch_dims) + (cur_index, 0, 0) 495 | key = lax.dynamic_update_slice(cached_key.value, key, indices) 496 | value = lax.dynamic_update_slice(cached_value.value, value, indices) 497 | cached_key.value = key 498 | cached_value.value = value 499 | num_updated_cache_vectors = query.shape[1] 500 | cache_index.value = cache_index.value + num_updated_cache_vectors 501 | # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. 502 | pad_mask = jnp.broadcast_to( 503 | jnp.arange(max_length) < cur_index + num_updated_cache_vectors, 504 | tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), 505 | ) 506 | attention_mask = combine_masks(pad_mask, attention_mask) 507 | return key, value, attention_mask 508 | 509 | def __call__( 510 | self, 511 | hidden_states, 512 | attention_mask, 513 | position_ids, 514 | deterministic: bool = True, 515 | init_cache: bool = False, 516 | output_attentions: bool = False, 517 | fcm_mask=None, 518 | ): 519 | xq, xk, xv = (self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)) 520 | 521 | xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), None, "mp")) 522 | xk = with_sharding_constraint(xk, PS(("dp", "fsdp"), None, "mp")) 523 | xv = with_sharding_constraint(xv, PS(("dp", "fsdp"), None, "mp")) 524 | 525 | xq = self._split_heads(xq) 526 | xk = self._split_heads(xk) 527 | xv = self._split_heads(xv) 528 | 529 | freqs_cis = jnp.take(self.freqs_cis, position_ids, axis=0) 530 | 531 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=self.dtype) 532 | 533 | dropout_rng = None 534 | if not deterministic and self.config.attn_pdrop > 0.0: 535 | dropout_rng = self.make_rng("dropout") 536 | 537 | if self.config.scan_attention and not (self.has_variable("cache", "cached_key") or init_cache): 538 | # doesn't need blockwise attention if we are doing autoregressive decoding since no quadratic memory 539 | 540 | # attention mask without nxn materlization, blockwise_attn will handle the rest 541 | attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) 542 | # transform boolean mask into float mask 543 | attention_bias = lax.select( 544 | attention_mask > 0, 545 | jnp.full(attention_mask.shape, 0.0).astype(self.dtype), 546 | jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), 
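                # (illustrative note) converts the boolean attention mask into an additive
                # bias: 0.0 where attending is allowed, the most negative finite value of
                # the dtype where masked, so masked positions vanish after the softmax.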
547 | ) 548 | attn_weights = None 549 | attn_output = blockwise_attn( 550 | xq, 551 | xk, 552 | xv, 553 | bias=attention_bias, 554 | deterministic=deterministic, 555 | dropout_rng=dropout_rng, 556 | attn_pdrop=self.config.attn_pdrop, 557 | causal=True, 558 | query_chunk_size=self.config.scan_query_chunk_size, 559 | key_chunk_size=self.config.scan_key_chunk_size, 560 | dtype=self.dtype, 561 | policy=get_gradient_checkpoint_policy("nothing_saveable"), 562 | precision=self.precision, 563 | float32_logits=True, 564 | prevent_cse=True, 565 | ) 566 | attn_output = with_sharding_constraint(attn_output, PS(("dp", "fsdp"), None, "mp", None)) 567 | else: 568 | query_length, key_length = xq.shape[1], xk.shape[1] 569 | 570 | if self.has_variable("cache", "cached_key"): 571 | mask_shift = self.variables["cache"]["cache_index"] 572 | max_decoder_length = self.variables["cache"]["cached_key"].shape[1] 573 | causal_mask = lax.dynamic_slice( 574 | self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) 575 | ) 576 | else: 577 | causal_mask = self.causal_mask[:, :, :query_length, :key_length] 578 | 579 | batch_size = hidden_states.shape[0] 580 | causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) 581 | 582 | attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) 583 | attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask) 584 | 585 | # During fast autoregressive decoding, we feed one position at a time, 586 | # and cache the keys and values step by step. 587 | if self.has_variable("cache", "cached_key") or init_cache: 588 | xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask) 589 | 590 | # transform boolean mask into float mask 591 | attention_bias = lax.select( 592 | attention_mask > 0, 593 | jnp.full(attention_mask.shape, 0.0).astype(self.dtype), 594 | jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), 595 | ) 596 | attn_weights = dot_product_attention_weights( 597 | xq, 598 | xk, 599 | bias=attention_bias, 600 | dropout_rng=dropout_rng, 601 | dropout_rate=self.config.attn_pdrop, 602 | deterministic=deterministic, 603 | dtype=jnp.promote_types(self.dtype, jnp.float32), 604 | precision=self.precision, 605 | ) 606 | attn_weights = with_sharding_constraint(attn_weights, PS(("dp", "fsdp"), "mp", None, None)) 607 | attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, xv, precision=self.precision) 608 | 609 | attn_output = self._merge_heads(attn_output) 610 | attn_output = self.wo(attn_output) 611 | attn_output = self.resid_dropout(attn_output, deterministic=deterministic) 612 | outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) 613 | return outputs 614 | 615 | 616 | class Block(nn.Module): 617 | config: ModelConfig 618 | dtype: jnp.dtype = jnp.float32 619 | param_dtype: jnp.dtype = jnp.float32 620 | precision: Optional[Union[jax.lax.Precision, str]] = None 621 | 622 | def setup(self) -> None: 623 | if self.config.seq_modeling_block == "self_attention": 624 | seq_modeling_block = Attention 625 | 626 | elif self.config.seq_modeling_block == "ttt_linear": 627 | seq_modeling_block = TTTLinear 628 | 629 | elif self.config.seq_modeling_block == "ttt_mlp": 630 | seq_modeling_block = TTTMLP 631 | 632 | elif self.config.seq_modeling_block == "ttt_linear_base": 633 | seq_modeling_block = TTTLinearBase 634 | 635 | elif self.config.seq_modeling_block == "ttt_mlp_base": 636 | seq_modeling_block = TTTMLPBase 637 | 
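        # (illustrative note) the TTT variants above are imported from ttt/models/ttt_layer.py;
        # unlike Attention, they are called with (hidden_states, input_ids, position_ids,
        # deterministic, output_ttt_stats, ttt_lr_mult), as in __call__ below.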
638 | else: 639 | raise NotImplementedError("Sequence Modeling Layer %s Not Implemented." % (self.config.seq_modeling_block)) 640 | 641 | mlp_module = SwiGLUMLP 642 | 643 | if self.config.remat_attention != "": 644 | if self.config.seq_modeling_block == "self_attention": 645 | static_argnums_tuple = (3, 4, 5) 646 | else: 647 | static_argnums_tuple = (3, 4) 648 | seq_modeling_block = remat( 649 | seq_modeling_block, 650 | static_argnums=static_argnums_tuple, 651 | policy=get_gradient_checkpoint_policy(self.config.remat_attention), 652 | prevent_cse=True, 653 | ) 654 | if self.config.remat_mlp != "": 655 | mlp_module = remat( 656 | SwiGLUMLP, 657 | static_argnums=(1,), 658 | policy=get_gradient_checkpoint_policy(self.config.remat_mlp), 659 | prevent_cse=True, 660 | ) 661 | 662 | self.seq_modeling_block = seq_modeling_block( 663 | self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision 664 | ) 665 | self.feed_forward = mlp_module( 666 | self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision 667 | ) 668 | self.seq_norm = RMSNorm( 669 | self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype 670 | ) 671 | self.ffn_norm = RMSNorm( 672 | self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype 673 | ) 674 | if self.config.pre_conv: 675 | self.conv = ConvModule( 676 | self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision 677 | ) 678 | 679 | def __call__( 680 | self, 681 | hidden_states, 682 | input_ids=None, 683 | attention_mask=None, 684 | position_ids=None, 685 | ttt_lr_mult=1.0, 686 | deterministic: bool = True, 687 | init_cache: bool = False, 688 | output_attentions: bool = False, 689 | output_ttt_stats: bool = False, 690 | fcm_mask: Optional[jnp.ndarray] = None, 691 | ): 692 | if self.config.pre_conv: 693 | conv_outputs = self.conv(hidden_states) 694 | hidden_states = hidden_states + conv_outputs 695 | 696 | hidden_states_pre_normed = self.seq_norm(hidden_states) 697 | 698 | if self.config.seq_modeling_block == "self_attention": 699 | seq_modeling_outputs = self.seq_modeling_block( 700 | hidden_states_pre_normed, 701 | attention_mask, 702 | position_ids, 703 | deterministic, 704 | init_cache, 705 | output_attentions, 706 | fcm_mask, 707 | ) 708 | else: 709 | seq_modeling_outputs = self.seq_modeling_block( 710 | hidden_states_pre_normed, input_ids, position_ids, deterministic, output_ttt_stats, ttt_lr_mult 711 | ) 712 | 713 | seq_modeling_output = seq_modeling_outputs[0] 714 | hidden_states = hidden_states + seq_modeling_output 715 | 716 | feed_forward_input = self.ffn_norm(hidden_states) 717 | if self.config.scan_mlp: 718 | feed_forward_hidden_states = blockwise_ffn( 719 | self.feed_forward, feed_forward_input, self.config.scan_mlp_chunk_size, deterministic 720 | ) 721 | else: 722 | feed_forward_hidden_states = self.feed_forward(feed_forward_input, deterministic) 723 | feed_forward_hidden_states = with_sharding_constraint( 724 | feed_forward_hidden_states, PS(("dp", "fsdp"), None, "mp") 725 | ) 726 | hidden_states = hidden_states + feed_forward_hidden_states 727 | 728 | if len(seq_modeling_outputs) > 1: 729 | if isinstance(seq_modeling_outputs[1], tuple): 730 | return (hidden_states,) + (seq_modeling_outputs[1],) 731 | else: 732 | return (hidden_states,) + ((seq_modeling_outputs[1],),) 733 | else: 734 | return (hidden_states,) 735 | 736 | 737 | class BlockCollection(nn.Module): 738 | config: ModelConfig 739 | 
dtype: jnp.dtype = jnp.float32 740 | param_dtype: jnp.dtype = jnp.float32 741 | precision: Optional[Union[jax.lax.Precision, str]] = None 742 | 743 | def setup(self): 744 | block = Block 745 | if self.config.remat_block != "": 746 | block = remat( 747 | Block, static_argnums=(5, 6, 7, 8), policy=get_gradient_checkpoint_policy(self.config.remat_block) 748 | ) 749 | self.blocks = [ 750 | block(self.config, name=str(i), dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision) 751 | for i in range(self.config.num_hidden_layers) 752 | ] 753 | 754 | def __call__( 755 | self, 756 | hidden_states, 757 | input_ids=None, 758 | attention_mask=None, 759 | position_ids=None, 760 | ttt_lr_mult=1.0, 761 | deterministic: bool = True, 762 | init_cache: bool = False, 763 | output_attentions: bool = False, 764 | output_hidden_states: bool = False, 765 | output_ttt_stats: bool = False, 766 | return_dict: bool = True, 767 | ): 768 | all_attentions = () if output_attentions else None 769 | all_hidden_states = () if output_hidden_states else None 770 | all_ttt_stats = () if output_ttt_stats else None 771 | 772 | if not deterministic and self.config.fcm_max_ratio > 0: 773 | batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1] 774 | fcm_ratio = jax.random.uniform( 775 | self.make_rng("fcm"), 776 | shape=(batch_size, 1, 1, 1), 777 | minval=self.config.fcm_min_ratio, 778 | maxval=self.config.fcm_max_ratio, 779 | ) 780 | fcm_mask = jax.random.uniform(self.make_rng("fcm"), shape=(batch_size, 1, 1, seq_length)) > fcm_ratio 781 | fcm_mask = fcm_mask.at[:, :, :, 0].set(True) 782 | fcm_mask = fcm_mask.astype("bool") 783 | else: 784 | fcm_mask = None 785 | 786 | for block in self.blocks: 787 | if output_hidden_states: 788 | all_hidden_states += (hidden_states,) 789 | 790 | layer_outputs = block( 791 | hidden_states, 792 | input_ids, 793 | attention_mask, 794 | position_ids, 795 | ttt_lr_mult, 796 | deterministic, 797 | init_cache, 798 | output_attentions, 799 | output_ttt_stats, 800 | fcm_mask, 801 | ) 802 | hidden_states = layer_outputs[0] 803 | 804 | if output_attentions: 805 | all_attentions += (layer_outputs[1],) 806 | 807 | if output_ttt_stats: 808 | all_ttt_stats += (layer_outputs[1],) 809 | 810 | outputs = (hidden_states, all_hidden_states, all_attentions, all_ttt_stats) 811 | return outputs 812 | 813 | 814 | class Model(nn.Module): 815 | config: ModelConfig 816 | dtype: jnp.dtype = jnp.float32 817 | param_dtype: jnp.dtype = jnp.float32 818 | precision: Optional[Union[jax.lax.Precision, str]] = None 819 | 820 | def setup(self): 821 | self.embed_dim = self.config.hidden_size 822 | self.wte = nn.Embed( 823 | self.config.vocab_size, 824 | self.config.hidden_size, 825 | embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), 826 | dtype=self.dtype, 827 | param_dtype=self.param_dtype, 828 | ) 829 | self.dropout = nn.Dropout(rate=self.config.embd_pdrop) 830 | self.h = BlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision) 831 | self.ln_f = RMSNorm( 832 | self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype 833 | ) 834 | 835 | def __call__( 836 | self, 837 | input_ids, 838 | attention_mask, 839 | position_ids, 840 | ttt_lr_mult=1.0, 841 | deterministic=True, 842 | init_cache: bool = False, 843 | output_attentions: bool = False, 844 | output_hidden_states: bool = False, 845 | output_ttt_stats: bool = False, 846 | return_dict: bool = True, 847 | ): 848 | input_embeds = 
self.wte(input_ids.astype("i4")) 849 | hidden_states = self.dropout(input_embeds, deterministic=deterministic) 850 | outputs = self.h( 851 | hidden_states, 852 | input_ids=input_ids, 853 | attention_mask=attention_mask, 854 | position_ids=position_ids, 855 | ttt_lr_mult=ttt_lr_mult, 856 | deterministic=deterministic, 857 | init_cache=init_cache, 858 | output_attentions=output_attentions, 859 | output_hidden_states=output_hidden_states, 860 | output_ttt_stats=output_ttt_stats, 861 | return_dict=return_dict, 862 | ) 863 | hidden_states = outputs[0] 864 | hidden_states = self.ln_f(hidden_states) 865 | 866 | if output_hidden_states: 867 | all_hidden_states = outputs[1] + (hidden_states,) 868 | outputs = (hidden_states, all_hidden_states) + outputs[2:] 869 | else: 870 | outputs = (hidden_states,) + outputs[1:] 871 | 872 | if not return_dict: 873 | return tuple(v for v in outputs if v is not None) 874 | 875 | return BaseModelOutput( 876 | last_hidden_state=hidden_states, hidden_states=outputs[1], attentions=outputs[2], ttt_stats=outputs[3] 877 | ) 878 | 879 | 880 | class CausalLM(nn.Module): 881 | config: ModelConfig 882 | dtype: jnp.dtype = jnp.float32 883 | param_dtype: jnp.dtype = jnp.float32 884 | precision: Optional[Union[jax.lax.Precision, str]] = None 885 | 886 | def setup(self): 887 | self.model = Model(self.config, dtype=self.dtype, param_dtype=self.param_dtype) 888 | self.lm_head = nn.Dense( 889 | self.config.vocab_size, 890 | dtype=self.dtype, 891 | param_dtype=self.param_dtype, 892 | use_bias=False, 893 | kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), 894 | precision=self.precision, 895 | ) 896 | 897 | def __call__( 898 | self, 899 | input_ids, 900 | attention_mask=None, 901 | position_ids=None, 902 | ttt_lr_mult=1.0, 903 | deterministic: bool = True, 904 | init_cache: bool = False, 905 | output_attentions: bool = False, 906 | output_hidden_states: bool = False, 907 | output_ttt_stats: bool = False, 908 | return_dict: bool = True, 909 | ): 910 | batch_size, seq_length = input_ids.shape 911 | if attention_mask is None: 912 | attention_mask = jnp.ones_like(input_ids) 913 | if position_ids is None: 914 | position_ids = jnp.broadcast_to( 915 | jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0), (batch_size, seq_length) 916 | ) 917 | outputs = self.model( 918 | input_ids, 919 | attention_mask, 920 | position_ids, 921 | ttt_lr_mult, 922 | deterministic=deterministic, 923 | init_cache=init_cache, 924 | output_attentions=output_attentions, 925 | output_hidden_states=output_hidden_states, 926 | output_ttt_stats=output_ttt_stats, 927 | return_dict=return_dict, 928 | ) 929 | 930 | hidden_states = outputs[0] 931 | 932 | if self.config.tie_word_embeddings: 933 | shared_kernel = self.model.variables["params"]["wte"]["embedding"].T 934 | lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) 935 | else: 936 | lm_logits = self.lm_head(hidden_states) 937 | 938 | if not return_dict: 939 | return (lm_logits,) + outputs[1:] 940 | 941 | return CausalLMOutput( 942 | logits=lm_logits, 943 | hidden_states=outputs.hidden_states, 944 | attentions=outputs.attentions, 945 | ttt_stats=outputs.ttt_stats, 946 | ) 947 | -------------------------------------------------------------------------------- /ttt/models/ttt_layer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | from functools import partial 5 | from typing import Any, Union, Sequence, Optional, Tuple 6 | 7 | 
import jax 8 | import jax.numpy as jnp 9 | import flax 10 | from jax import vmap 11 | from jax.tree_util import tree_map 12 | from jax.sharding import PartitionSpec as PS 13 | from flax import linen as nn 14 | from flax.linen import partitioning as nn_partitioning 15 | 16 | from ttt.infra.jax_utils import with_sharding_constraint, get_gradient_checkpoint_policy 17 | 18 | Axes = Union[int, Sequence[int]] 19 | 20 | 21 | def scan_remat_every_n_iterations_scan(f, n, carry, x): 22 | """ 23 | Remat every n mini batches. 24 | """ 25 | x_grouped = tree_map(lambda x: x.reshape((-1, n, *x.shape[1:])), x) 26 | carry, y_grouped = jax.lax.scan(jax.remat(partial(jax.lax.scan, f), prevent_cse=False), carry, x_grouped) 27 | y = tree_map(lambda x: x.reshape((-1, *x.shape[2:])), y_grouped) 28 | return carry, y 29 | 30 | 31 | def get_multi_head_params(self, params, param_dtype, kernel_init="normal", std=0.02): 32 | flat_params = flax.traverse_util.flatten_dict(params, sep="/") 33 | for k in flat_params.keys(): 34 | new_shape = (self.num_heads, *flat_params[k].shape) 35 | if "scale" in k: 36 | p = self.param(k, jax.nn.initializers.ones, new_shape, param_dtype) 37 | elif "kernel" in k: 38 | if kernel_init == "normal": 39 | initializer = nn.initializers.normal(std) 40 | elif kernel_init == "zeros": 41 | initializer = nn.initializers.zeros 42 | elif kernel_init == "ones": 43 | initializer = nn.initializers.ones 44 | else: 45 | raise NotImplementedError("Initializer %s Not Implemented." % (kernel_init)) 46 | p = self.param(k, initializer, new_shape, param_dtype) 47 | else: 48 | p = self.param(k, jax.nn.initializers.zeros, new_shape, param_dtype) 49 | flat_params[k] = p 50 | params_init = flax.traverse_util.unflatten_dict(flat_params, sep="/") 51 | return params_init 52 | 53 | 54 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype: jnp.dtype = jnp.float32) -> jnp.ndarray: 55 | freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim)) 56 | t = np.arange(end) 57 | freqs = np.outer(t, freqs).astype(dtype) 58 | sin, cos = np.sin(freqs), np.cos(freqs) 59 | freqs_cis = np.complex64(cos + 1j * sin) 60 | return jnp.asarray(freqs_cis) 61 | 62 | 63 | def apply_rotary_emb( 64 | xq: jnp.ndarray, xk: jnp.ndarray, freqs_cis: jnp.ndarray, dtype: jnp.dtype = jnp.float32 65 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 66 | 67 | reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2) 68 | reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2) 69 | 70 | xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1]) 71 | xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1]) 72 | 73 | freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:])) 74 | 75 | xq_out = xq_ * freqs_cis 76 | xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1) 77 | 78 | xk_out = xk_ * freqs_cis 79 | xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1) 80 | 81 | return xq_out.astype(dtype), xk_out.astype(dtype) 82 | 83 | 84 | def diff_gelu(x): 85 | tanh_out = jnp.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) 86 | ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) 87 | return ff 88 | 89 | 90 | class LinearLayerTemplate(nn.Module): 91 | width: int 92 | use_bias: bool 93 | name: str 94 | dtype: jnp.dtype = jnp.float32 95 | param_dtype: jnp.dtype = jnp.float32 96 | 97 | @nn.compact 98 | def __call__(self, x): 99 | x = nn.Dense( 
100 | self.width, use_bias=self.use_bias, name=self.name, dtype=self.dtype, param_dtype=self.param_dtype 101 | )(x) 102 | return x 103 | 104 | 105 | class LayerNormTemplate(nn.Module): 106 | name: str 107 | dtype: jnp.dtype = jnp.float32 108 | param_dtype: jnp.dtype = jnp.float32 109 | 110 | @nn.compact 111 | def __call__(self, x): 112 | x = nn.LayerNorm(name=self.name, dtype=self.dtype, param_dtype=self.param_dtype)(x) 113 | return x 114 | 115 | 116 | class TTTBase(nn.Module): 117 | config: Any = None 118 | dtype: jnp.dtype = jnp.float32 119 | param_dtype: jnp.dtype = jnp.float32 120 | precision: Optional[Union[jax.lax.Precision, str]] = None 121 | 122 | def setup(self): 123 | self.width = self.config.hidden_size 124 | self.num_heads = self.config.num_attention_heads 125 | self.head_dim = self.width // self.num_heads 126 | self.mini_batch_size = self.config.mini_batch_size 127 | self.n_mini_batch = self.config.max_sequence_length // self.mini_batch_size 128 | self.seq_shape = (self.n_mini_batch, self.mini_batch_size) 129 | self.freqs_cis = precompute_freqs_cis( 130 | self.head_dim, self.mini_batch_size * 2, theta=self.config.rope_theta, dtype=self.dtype 131 | ) 132 | 133 | self.setup_qkvo() 134 | self.setup_token_idx() 135 | self.setup_ttt_lr_gate() 136 | 137 | self.ttt_norm = LayerNormTemplate(dtype=self.dtype, param_dtype=self.param_dtype) 138 | ttt_norm_params = self.ttt_norm.init(jax.random.PRNGKey(0), jnp.ones([1, self.head_dim]))["params"] 139 | self.ttt_norm_params = get_multi_head_params( 140 | self, ttt_norm_params, param_dtype=self.param_dtype, kernel_init="layer_norm" 141 | ) 142 | self.post_norm = nn.LayerNorm(dtype=self.dtype, param_dtype=self.param_dtype) 143 | 144 | self.ttt_params = () 145 | 146 | def setup_qkvo(self): 147 | self.wq = nn.Dense( 148 | self.num_heads * self.head_dim, 149 | dtype=self.dtype, 150 | param_dtype=self.param_dtype, 151 | use_bias=False, 152 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 153 | precision=self.precision, 154 | ) 155 | self.wk = nn.Dense( 156 | self.num_heads * self.head_dim, 157 | dtype=self.dtype, 158 | param_dtype=self.param_dtype, 159 | use_bias=False, 160 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 161 | precision=self.precision, 162 | ) 163 | self.wv = nn.Dense( 164 | self.num_heads * self.head_dim, 165 | dtype=self.dtype, 166 | param_dtype=self.param_dtype, 167 | use_bias=False, 168 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 169 | precision=self.precision, 170 | ) 171 | self.wo = nn.Dense( 172 | self.width, 173 | dtype=self.dtype, 174 | param_dtype=self.param_dtype, 175 | use_bias=False, 176 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 177 | precision=self.precision, 178 | ) 179 | 180 | def setup_token_idx(self): 181 | self.token_idx = 1.0 / jnp.arange(1, self.mini_batch_size + 1, dtype=jnp.float32) 182 | self.learnable_token_idx = self.param( 183 | "learnable_token_idx", nn.initializers.zeros, (self.mini_batch_size,), jnp.float32 184 | ) 185 | 186 | def setup_ttt_lr_gate(self): 187 | self.learnable_ttt_lr = LinearLayerTemplate( 188 | width=1, use_bias=True, name="learnable_ttt_lr", dtype=self.dtype, param_dtype=self.param_dtype 189 | ) 190 | learnable_ttt_lr_params = self.learnable_ttt_lr.init(jax.random.PRNGKey(0), jnp.ones([1, self.width]))["params"] 191 | self.learnable_ttt_lr_params = get_multi_head_params( 192 | self, 193 | learnable_ttt_lr_params, 194 | param_dtype=self.param_dtype, 195 | kernel_init="normal", 
196 | std=self.config.initializer_range, 197 | ) 198 | 199 | def _split_heads(self, hidden_states): 200 | return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) 201 | 202 | def _split_mini_batches(self, hidden_states): 203 | B, N, num_head, head_dim = hidden_states.shape 204 | hidden_states = hidden_states.reshape(B, *self.seq_shape, self.num_heads, self.head_dim).transpose( 205 | 0, 3, 1, 2, 4 206 | ) 207 | return hidden_states 208 | 209 | def get_qkv_projections(self, batch): 210 | XQ, XK, XV = self.wq(batch), self.wk(batch), self.wv(batch) 211 | return XQ, XK, XV 212 | 213 | def get_eta(self, X): 214 | learnable_ttt_lr = vmap( 215 | lambda x, p: self.learnable_ttt_lr.apply({"params": p}, x), axis_name="head", in_axes=[None, 0], out_axes=1 216 | )(X, self.learnable_ttt_lr_params) 217 | learnable_ttt_lr = nn.sigmoid(learnable_ttt_lr) 218 | learnable_ttt_lr = learnable_ttt_lr.transpose(0, 1, 2, 4, 3) 219 | 220 | token_idx = self.learnable_token_idx + self.token_idx 221 | token_idx = jnp.clip(token_idx, a_min=0.0) 222 | 223 | eta = ( 224 | (self.config.ttt_base_lr * token_idx).reshape(1, 1, 1, token_idx.shape[0], -1) 225 | * learnable_ttt_lr 226 | / self.head_dim 227 | ) 228 | return eta 229 | 230 | def get_ttt_inputs(self, batch, position_ids): 231 | B, N, F = batch.shape 232 | n_mini_batch = N // self.mini_batch_size 233 | X = batch.reshape(B, *self.seq_shape, self.width) 234 | 235 | XQ, XK, XV = self.get_qkv_projections(batch) 236 | 237 | if self.config.output_ttt_stats: 238 | XV_last_in_mini_batch = XV[:, :: self.mini_batch_size, ...].reshape( 239 | B, n_mini_batch, self.num_heads, self.head_dim 240 | ) 241 | XK_last_in_mini_batch = XK[:, :: self.mini_batch_size, ...].reshape( 242 | B, n_mini_batch, self.num_heads, self.head_dim 243 | ) 244 | ssl_tgt_last_in_mini_batch = XV_last_in_mini_batch - XK_last_in_mini_batch 245 | ssl_tgt_mean = (XV - XK).mean(axis=1, keepdims=True).reshape(B, 1, self.num_heads, self.head_dim) 246 | ssl_tgt_last_in_mini_batch_from_mean_mse = ((ssl_tgt_last_in_mini_batch - ssl_tgt_mean) ** 2).mean( 247 | axis=(0, 2, 3) 248 | ) 249 | else: 250 | ssl_tgt_last_in_mini_batch_from_mean_mse = None 251 | 252 | XQ = with_sharding_constraint(XQ, PS(("dp", "fsdp"), None, "mp")) 253 | XK = with_sharding_constraint(XK, PS(("dp", "fsdp"), None, "mp")) 254 | XV = with_sharding_constraint(XV, PS(("dp", "fsdp"), None, "mp")) 255 | 256 | XQ = self._split_heads(XQ) 257 | XK = self._split_heads(XK) 258 | XV = self._split_heads(XV) 259 | 260 | freqs_cis = jnp.take(self.freqs_cis, position_ids % self.mini_batch_size, axis=0) 261 | XQ, XK = apply_rotary_emb(XQ, XK, freqs_cis=freqs_cis, dtype=self.dtype) 262 | 263 | XQ = self._split_mini_batches(XQ) 264 | XK = self._split_mini_batches(XK) 265 | XV = self._split_mini_batches(XV) 266 | 267 | eta = self.get_eta(X) 268 | 269 | return (XQ, XK, XV, eta, (ssl_tgt_last_in_mini_batch_from_mean_mse,)) 270 | 271 | def apply_gate(self, hidden_states, ttt_output): 272 | return ttt_output 273 | 274 | def project_ttt_outputs(self, XQW_batch): 275 | z_batch = self.wo(XQW_batch) 276 | return z_batch 277 | 278 | def process_mini_batch( 279 | self, 280 | XQ_mini_batch, 281 | XK_mini_batch, 282 | XV_mini_batch, 283 | eta_mini_batch, 284 | ttt_params_init, 285 | ttt_params_mini_batch_init, 286 | ttt_norm_params, 287 | ): 288 | raise NotImplementedError 289 | 290 | def ttt(self, XQ, XK, XV, eta, input_ids): 291 | B, N = XV.shape[0], XV.shape[2] * XV.shape[3] 292 | 293 | @partial(vmap, axis_name="batch") 294 | def 
update_embed(XQ, XK, XV, eta): 295 | @partial(vmap, axis_name="head") 296 | def parallelize_over_heads(XQ, XK, XV, eta, ttt_params_init, ttt_norm_params): 297 | def compute_mini_batch(ttt_params_mini_batch_init, inputs): 298 | XQ_mini_batch = inputs["XQ"] 299 | XK_mini_batch = inputs["XK"] 300 | XV_mini_batch = inputs["XV"] 301 | eta_mini_batch = inputs["eta"] 302 | 303 | ttt_params_last_in_mini_batch, outputs = self.process_mini_batch( 304 | XQ_mini_batch, 305 | XK_mini_batch, 306 | XV_mini_batch, 307 | eta_mini_batch, 308 | ttt_params_init, 309 | ttt_params_mini_batch_init, 310 | ttt_norm_params, 311 | ) 312 | return ttt_params_last_in_mini_batch, outputs 313 | 314 | inputs = {"XQ": XQ, "XK": XK, "XV": XV, "eta": eta} 315 | 316 | _, outputs = scan_remat_every_n_iterations_scan( 317 | compute_mini_batch, self.config.remat_mini_batch_group_size, ttt_params_init, inputs 318 | ) 319 | Z, ttt_loss_mse_init, ttt_loss_mse_step_0, ttt_loss_mse_step_1 = outputs 320 | return (Z.reshape(-1, self.head_dim), ttt_loss_mse_init, ttt_loss_mse_step_0, ttt_loss_mse_step_1) 321 | 322 | outputs = parallelize_over_heads(XQ, XK, XV, eta, self.ttt_params, self.ttt_norm_params) 323 | return outputs 324 | 325 | outputs = update_embed(XQ, XK, XV, eta) 326 | Z, ttt_loss_mse_init, ttt_loss_mse_step_0, ttt_loss_mse_step_1 = outputs 327 | Z = Z.transpose(0, 2, 1, 3).reshape(B, N, -1) 328 | 329 | if self.config.output_ttt_stats: 330 | ttt_loss_mse_init = ttt_loss_mse_init.mean(axis=(0, 1)) 331 | ttt_loss_mse_step_0 = ttt_loss_mse_step_0.mean(axis=(0, 1)) 332 | ttt_loss_mse_step_1 = ttt_loss_mse_step_1.mean(axis=(0, 1)) 333 | 334 | return Z, (ttt_loss_mse_init, ttt_loss_mse_step_0, ttt_loss_mse_step_1) 335 | 336 | def __call__( 337 | self, 338 | hidden_states, 339 | input_ids=None, 340 | position_ids=None, 341 | deterministic: bool = True, 342 | output_ttt_stats: bool = False, 343 | ttt_lr_mult=1.0, 344 | ): 345 | self.config.output_ttt_stats = output_ttt_stats 346 | del deterministic 347 | XQ, XK, XV, eta, precompute_stats = self.get_ttt_inputs(hidden_states, position_ids=position_ids) 348 | eta *= ttt_lr_mult 349 | Z, ttt_stats = self.ttt(XQ, XK, XV, eta, input_ids) 350 | Z = self.post_norm(Z) 351 | Z = self.apply_gate(hidden_states, Z) 352 | ttt_output = self.project_ttt_outputs(Z) 353 | return ttt_output, (*precompute_stats, *ttt_stats) 354 | 355 | 356 | class TTTLinearBase(TTTBase): 357 | def setup(self): 358 | super().setup() 359 | self.W1 = self.param( 360 | "ttt_dense_0", 361 | nn.initializers.normal(self.config.initializer_range), 362 | (self.num_heads, self.head_dim, self.head_dim), 363 | self.param_dtype, 364 | ) 365 | self.b1 = self.param("ttt_bias_0", nn.initializers.zeros, (self.num_heads, 1, self.head_dim), self.param_dtype) 366 | self.ttt_params = (self.W1, self.b1) 367 | 368 | def process_mini_batch( 369 | self, 370 | XQ_mini_batch, 371 | XK_mini_batch, 372 | XV_mini_batch, 373 | eta_mini_batch, 374 | ttt_params_init, 375 | ttt_params_mini_batch_init, 376 | ttt_norm_params, 377 | ): 378 | 379 | W1_init, b1_init = ttt_params_mini_batch_init 380 | square_eta_mini_batch = eta_mini_batch[: self.mini_batch_size] 381 | last_eta_in_mini_batch = eta_mini_batch[-1][:, None] 382 | 383 | X1 = XK_mini_batch 384 | Z1 = X1 @ W1_init + b1_init 385 | ttt_norm_out, ttt_norm_vjp = jax.vjp(lambda z: self.ttt_norm.apply({"params": ttt_norm_params}, z), Z1) 386 | ssl_target = XV_mini_batch - XK_mini_batch 387 | grad_l_wrt_ttt_norm_out = ttt_norm_out - ssl_target 388 | grad_l_wrt_Z1 = 
ttt_norm_vjp(grad_l_wrt_ttt_norm_out)[0] 389 | 390 | # Calculate TTT loss using W_init of the current mini-batch 391 | if self.config.output_ttt_stats: 392 | ttt_loss_mse_step_0 = (grad_l_wrt_ttt_norm_out[-1] ** 2).mean() 393 | else: 394 | ttt_loss_mse_step_0 = None 395 | 396 | # Calculate TTT loss using W_init of the entire sequence 397 | if self.config.output_ttt_stats: 398 | W1_0, b1_0 = ttt_params_init 399 | Z1_0 = X1 @ W1_0 + b1_0 400 | ttt_norm_out_0 = self.ttt_norm.apply({"params": ttt_norm_params}, Z1_0) 401 | ttt_loss_mse_init = ((ttt_norm_out_0 - ssl_target)[-1] ** 2).mean() 402 | else: 403 | ttt_loss_mse_init = None 404 | 405 | X1_bar = XQ_mini_batch 406 | Attn1 = jnp.tril(X1_bar @ X1.transpose(1, 0)) 407 | b1_bar = b1_init - (square_eta_mini_batch * jnp.tril(jnp.ones_like(Attn1))) @ grad_l_wrt_Z1 408 | Z1_bar = X1_bar @ W1_init - (square_eta_mini_batch * Attn1) @ grad_l_wrt_Z1 + b1_bar 409 | ttt_norm_out_bar = self.ttt_norm.apply({"params": ttt_norm_params}, Z1_bar) 410 | 411 | output_mini_batch = X1_bar + ttt_norm_out_bar 412 | 413 | W1_bar_last = W1_init - (last_eta_in_mini_batch * X1).transpose(1, 0) @ grad_l_wrt_Z1 414 | b1_bar_last = b1_init - jnp.sum(last_eta_in_mini_batch * grad_l_wrt_Z1, axis=0, keepdims=True) 415 | 416 | # Calculate ttt loss using the updated W_init by the current mini-batch 417 | if self.config.output_ttt_stats: 418 | X1_last_fwd_new = X1[-1:] @ W1_bar_last + b1_bar_last 419 | X1_last_fwd_new = self.ttt_norm.apply({"params": ttt_norm_params}, X1_last_fwd_new) 420 | ttt_loss_mse_step_1 = ((X1_last_fwd_new - ssl_target[-1:]) ** 2).mean() 421 | else: 422 | ttt_loss_mse_step_1 = None 423 | 424 | ttt_params_mini_batch_new = (W1_bar_last, b1_bar_last) 425 | 426 | return ( 427 | ttt_params_mini_batch_new, 428 | (output_mini_batch, ttt_loss_mse_init, ttt_loss_mse_step_0, ttt_loss_mse_step_1), 429 | ) 430 | 431 | 432 | class TTTLinear(TTTLinearBase): 433 | def setup(self): 434 | super().setup() 435 | self.wg = nn.Dense( 436 | self.width, 437 | dtype=self.dtype, 438 | param_dtype=self.param_dtype, 439 | use_bias=False, 440 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 441 | precision=self.precision, 442 | ) 443 | 444 | def setup_qkvo(self): 445 | self.wq = nn.Dense( 446 | self.num_heads * self.head_dim, 447 | dtype=self.dtype, 448 | param_dtype=self.param_dtype, 449 | use_bias=False, 450 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 451 | precision=self.precision, 452 | ) 453 | if self.config.remat_conv != "": 454 | conv_module = nn_partitioning.remat( 455 | nn.Conv, policy=get_gradient_checkpoint_policy(self.config.remat_conv), prevent_cse=True 456 | ) 457 | else: 458 | conv_module = nn.Conv 459 | self.conv_q = conv_module( 460 | self.config.hidden_size, 461 | (self.config.conv_width,), 462 | padding="CAUSAL", 463 | feature_group_count=self.config.hidden_size, 464 | dtype=self.dtype, 465 | param_dtype=self.param_dtype, 466 | precision=self.precision, 467 | ) 468 | self.conv_k = conv_module( 469 | self.config.hidden_size, 470 | (self.config.conv_width,), 471 | padding="CAUSAL", 472 | feature_group_count=self.config.hidden_size, 473 | dtype=self.dtype, 474 | param_dtype=self.param_dtype, 475 | precision=self.precision, 476 | ) 477 | self.wv = nn.Dense( 478 | self.num_heads * self.head_dim, 479 | dtype=self.dtype, 480 | param_dtype=self.param_dtype, 481 | use_bias=False, 482 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 483 | precision=self.precision, 484 | ) 485 | self.wo = 
nn.Dense( 486 | self.width, 487 | dtype=self.dtype, 488 | param_dtype=self.param_dtype, 489 | use_bias=False, 490 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 491 | precision=self.precision, 492 | ) 493 | 494 | def get_qkv_projections(self, batch): 495 | xqk, XV = self.wq(batch), self.wv(batch) 496 | XQ = self.conv_q(xqk) 497 | XK = self.conv_k(xqk) 498 | return XQ, XK, XV 499 | 500 | def apply_gate(self, hidden_states, ttt_output): 501 | y = self.wg(hidden_states) 502 | y = nn.gelu(y) 503 | output = y * ttt_output 504 | return output 505 | 506 | 507 | class TTTMLPBase(TTTBase): 508 | def setup(self): 509 | super().setup() 510 | self.W1 = self.param( 511 | "ttt_dense_0", 512 | nn.initializers.normal(self.config.initializer_range), 513 | (self.num_heads, self.head_dim, 4 * self.head_dim), 514 | self.param_dtype, 515 | ) 516 | self.b1 = self.param( 517 | "ttt_bias_0", nn.initializers.zeros, (self.num_heads, 1, 4 * self.head_dim), self.param_dtype 518 | ) 519 | self.W2 = self.param( 520 | "ttt_dense_1", 521 | nn.initializers.normal(self.config.initializer_range), 522 | (self.num_heads, 4 * self.head_dim, self.head_dim), 523 | self.param_dtype, 524 | ) 525 | self.b2 = self.param("ttt_bias_1", nn.initializers.zeros, (self.num_heads, 1, self.head_dim), self.param_dtype) 526 | self.ttt_params = (self.W1, self.W2, self.b1, self.b2) 527 | 528 | def process_mini_batch( 529 | self, 530 | XQ_mini_batch, 531 | XK_mini_batch, 532 | XV_mini_batch, 533 | eta_mini_batch, 534 | ttt_params_init, 535 | ttt_params_mini_batch_init, 536 | ttt_norm_params, 537 | ): 538 | 539 | W1_init, W2_init, b1_init, b2_init = ttt_params_mini_batch_init 540 | square_eta_mini_batch = eta_mini_batch[: self.mini_batch_size] 541 | last_eta_in_mini_batch = eta_mini_batch[-1][:, None] 542 | 543 | X1 = XK_mini_batch 544 | Z1 = X1 @ W1_init + b1_init 545 | X2 = nn.gelu(Z1) 546 | Z2 = X2 @ W2_init + b2_init 547 | ttt_norm_out, ttt_norm_vjp = jax.vjp(lambda z: self.ttt_norm.apply({"params": ttt_norm_params}, z), Z2) 548 | 549 | ssl_target = XV_mini_batch - X1 550 | grad_l_wrt_ttt_norm_out = ttt_norm_out - ssl_target 551 | grad_l_wrt_Z2 = ttt_norm_vjp(grad_l_wrt_ttt_norm_out)[0] 552 | grad_l_wrt_Z1 = grad_l_wrt_Z2 @ W2_init.transpose(1, 0) * diff_gelu(Z1) 553 | 554 | if self.config.output_ttt_stats: 555 | ttt_loss_mse_step_0 = (grad_l_wrt_ttt_norm_out[-1] ** 2).mean() 556 | else: 557 | ttt_loss_mse_step_0 = None 558 | 559 | # Calculate ttt loss using W_init of the entire sequence 560 | if self.config.output_ttt_stats: 561 | W1_0, W2_0, b1_0, b2_0 = ttt_params_init 562 | Z1_0 = X1 @ W1_0 + b1_0 563 | X2_0 = nn.gelu(Z1_0) 564 | Z2_0 = X2_0 @ W2_0 + b2_0 565 | ttt_norm_out_0 = self.ttt_norm.apply({"params": ttt_norm_params}, Z2_0) 566 | ttt_loss_mse_init = ((ttt_norm_out_0 - ssl_target)[-1] ** 2).mean() 567 | else: 568 | ttt_loss_mse_init = None 569 | 570 | X1_bar = XQ_mini_batch 571 | Attn1 = jnp.tril(X1_bar @ X1.transpose(1, 0)) 572 | b1_bar = b1_init - (square_eta_mini_batch * jnp.tril(jnp.ones_like(Attn1))) @ grad_l_wrt_Z1 573 | Z1_bar = X1_bar @ W1_init - (square_eta_mini_batch * Attn1) @ grad_l_wrt_Z1 + b1_bar 574 | 575 | X2_bar = nn.gelu(Z1_bar) 576 | Attn2 = jnp.tril(X2_bar @ X2.transpose(1, 0)) 577 | b2_bar = b2_init - (square_eta_mini_batch * jnp.tril(jnp.ones_like(Attn2))) @ grad_l_wrt_Z2 578 | Z2_bar = X2_bar @ W2_init - (square_eta_mini_batch * Attn2) @ grad_l_wrt_Z2 + b2_bar 579 | ttt_norm_out_bar = self.ttt_norm.apply({"params": ttt_norm_params}, Z2_bar) 580 | 581 | output_mini_batch = X1_bar + 
ttt_norm_out_bar 582 | 583 | W1_bar_last = W1_init - (last_eta_in_mini_batch * X1).transpose(1, 0) @ grad_l_wrt_Z1 584 | W2_bar_last = W2_init - (last_eta_in_mini_batch * X2).transpose(1, 0) @ grad_l_wrt_Z2 585 | b1_bar_last = b1_init - jnp.sum(last_eta_in_mini_batch * grad_l_wrt_Z1, axis=0, keepdims=True) 586 | b2_bar_last = b2_init - jnp.sum(last_eta_in_mini_batch * grad_l_wrt_Z2, axis=0, keepdims=True) 587 | 588 | if self.config.output_ttt_stats: 589 | X1_last_fwd_new = nn.gelu((X1[-1:] @ W1_bar_last) + b1_bar_last) @ W2_bar_last + b2_bar_last 590 | X1_last_fwd_new = self.ttt_norm.apply({"params": ttt_norm_params}, X1_last_fwd_new) 591 | ttt_loss_mse_step_1 = ((X1_last_fwd_new - ssl_target[-1:]) ** 2).mean() 592 | else: 593 | ttt_loss_mse_step_1 = None 594 | 595 | ttt_params_mini_batch_new = (W1_bar_last, W2_bar_last, b1_bar_last, b2_bar_last) 596 | 597 | return ( 598 | ttt_params_mini_batch_new, 599 | (output_mini_batch, ttt_loss_mse_init, ttt_loss_mse_step_0, ttt_loss_mse_step_1), 600 | ) 601 | 602 | 603 | class TTTMLP(TTTMLPBase): 604 | def setup(self): 605 | super().setup() 606 | self.wg = nn.Dense( 607 | self.width, 608 | dtype=self.dtype, 609 | param_dtype=self.param_dtype, 610 | use_bias=False, 611 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 612 | precision=self.precision, 613 | ) 614 | 615 | def setup_qkvo(self): 616 | self.wq = nn.Dense( 617 | self.num_heads * self.head_dim, 618 | dtype=self.dtype, 619 | param_dtype=self.param_dtype, 620 | use_bias=False, 621 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 622 | precision=self.precision, 623 | ) 624 | if self.config.remat_conv != "": 625 | conv_module = nn_partitioning.remat( 626 | nn.Conv, policy=get_gradient_checkpoint_policy(self.config.remat_conv), prevent_cse=True 627 | ) 628 | else: 629 | conv_module = nn.Conv 630 | self.conv_q = conv_module( 631 | self.config.hidden_size, 632 | (self.config.conv_width,), 633 | padding="CAUSAL", 634 | feature_group_count=self.config.hidden_size, 635 | dtype=self.dtype, 636 | param_dtype=self.param_dtype, 637 | precision=self.precision, 638 | ) 639 | self.conv_k = conv_module( 640 | self.config.hidden_size, 641 | (self.config.conv_width,), 642 | padding="CAUSAL", 643 | feature_group_count=self.config.hidden_size, 644 | dtype=self.dtype, 645 | param_dtype=self.param_dtype, 646 | precision=self.precision, 647 | ) 648 | self.wv = nn.Dense( 649 | self.num_heads * self.head_dim, 650 | dtype=self.dtype, 651 | param_dtype=self.param_dtype, 652 | use_bias=False, 653 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 654 | precision=self.precision, 655 | ) 656 | self.wo = nn.Dense( 657 | self.width, 658 | dtype=self.dtype, 659 | param_dtype=self.param_dtype, 660 | use_bias=False, 661 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 662 | precision=self.precision, 663 | ) 664 | 665 | def get_qkv_projections(self, batch): 666 | xqk, XV = self.wq(batch), self.wv(batch) 667 | XQ = self.conv_q(xqk) 668 | XK = self.conv_k(xqk) 669 | return XQ, XK, XV 670 | 671 | def apply_gate(self, hidden_states, ttt_output): 672 | y = self.wg(hidden_states) 673 | y = nn.gelu(y) 674 | output = y * ttt_output 675 | return output 676 | -------------------------------------------------------------------------------- /ttt/train.py: -------------------------------------------------------------------------------- 1 | import mlxu 2 | import wandb 3 | import os.path as osp 4 | 5 | from tqdm import tqdm 6 | from copy import 
deepcopy 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | from jax.tree_util import tree_map 11 | from jax.experimental.pjit import pjit 12 | from jax.sharding import PartitionSpec as PS 13 | from jax.experimental.multihost_utils import process_allgather 14 | from flax.training.train_state import TrainState 15 | from flax.traverse_util import flatten_dict 16 | 17 | from ttt.infra.optimizers import OptimizerFactory 18 | from ttt.dataloader.language_modeling_hf import LMDataModule 19 | from ttt.infra.checkpoint import StreamingCheckpointer 20 | from ttt.models.model import ModelConfig, CausalLM 21 | from ttt.infra.jax_utils import ( 22 | JaxRNG, 23 | JaxDistributedConfig, 24 | next_rng, 25 | match_partition_rules, 26 | cross_entropy_loss_and_accuracy, 27 | global_norm, 28 | get_float_dtype_by_name, 29 | set_random_seed, 30 | average_metrics, 31 | get_weight_decay_mask, 32 | make_shard_and_gather_fns, 33 | with_sharding_constraint, 34 | master_print, 35 | log_ttt_stats, 36 | ) 37 | 38 | 39 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default( 40 | seed=0, 41 | mesh_dim="-1,64,1", 42 | dtype="fp32", 43 | eval_mode=False, 44 | load_part="trainstate", 45 | total_steps=100, 46 | load_model_config="", 47 | update_model_config="", 48 | save_checkpoint_freq=100, 49 | save_milestone_freq=0, 50 | dataset_path="", 51 | dataset_name="the_pile", 52 | tokenizer_name="meta-llama/Llama-2-7b-hf", 53 | seq_length=2048, 54 | global_batch_size=1, 55 | accum_steps=1, 56 | loader_workers=48, 57 | optimizer=OptimizerFactory.get_default_config(), 58 | checkpointer=StreamingCheckpointer.get_default_config(), 59 | exp_dir="", 60 | exp_name="", 61 | resume_exp_name="", 62 | resume_step="", 63 | jax_distributed=JaxDistributedConfig.get_default_config(), 64 | is_rollback_reshuffle=False, 65 | ) 66 | 67 | 68 | def make_train_step_fn(model, optimizer_info, model_config, accum_steps=1): 69 | 70 | if accum_steps == 1: 71 | 72 | def train_step(train_state, rng, batch, ttt_lr_mult, output_ttt_stats=False): 73 | rng_generator = JaxRNG(rng) 74 | batch = with_sharding_constraint(batch, PS(("dp", "fsdp"))) 75 | 76 | def loss_and_accuracy(params): 77 | outputs = model.apply( 78 | params, 79 | batch["input_tokens"], 80 | ttt_lr_mult=ttt_lr_mult, 81 | deterministic=False, 82 | output_ttt_stats=output_ttt_stats, 83 | rngs=rng_generator(model_config.rng_keys()), 84 | ) 85 | logits = outputs.logits 86 | ttt_stats = outputs.ttt_stats 87 | loss, _ = cross_entropy_loss_and_accuracy(logits, batch["target_tokens"], batch["loss_masks"]) 88 | return loss, ttt_stats 89 | 90 | grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True) 91 | (loss, ttt_stats), grads = grad_fn(train_state.params) 92 | 93 | train_state = train_state.apply_gradients(grads=grads) 94 | learning_rate = optimizer_info["learning_rate_schedule"](train_state.step) 95 | grads_norm = global_norm(grads) 96 | 97 | return (train_state, loss, ttt_stats, grads_norm, learning_rate, rng_generator()) 98 | 99 | elif accum_steps > 1: 100 | 101 | def train_step(train_state, rng, batch, ttt_lr_mult, output_ttt_stats=False): 102 | rng_generator = JaxRNG(rng) 103 | rngs = rng_generator(model_config.rng_keys()) 104 | 105 | def computation(carry, micro_batch): 106 | sum_grads = carry["sum_grads"] 107 | micro_batch = with_sharding_constraint(micro_batch, PS(("dp", "fsdp"))) 108 | 109 | def loss_and_accuracy(params): 110 | outputs = model.apply( 111 | params, 112 | micro_batch["input_tokens"], 113 | ttt_lr_mult=ttt_lr_mult, 114 | deterministic=False, 115 | 
output_ttt_stats=output_ttt_stats, 116 | rngs=rngs, 117 | ) 118 | logits = outputs.logits 119 | ttt_stats = outputs.ttt_stats 120 | loss, _ = cross_entropy_loss_and_accuracy( 121 | logits, micro_batch["target_tokens"], micro_batch["loss_masks"] 122 | ) 123 | return loss, ttt_stats 124 | 125 | grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True) 126 | (loss, ttt_stats), grads = grad_fn(train_state.params) 127 | sum_grads = tree_map(lambda x, y: x + y, sum_grads, grads) 128 | carry_new = {"sum_grads": sum_grads} 129 | return carry_new, (loss, ttt_stats) 130 | 131 | sum_grads = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), train_state.params) 132 | carry_init = {"sum_grads": sum_grads} 133 | batch = tree_map(lambda x: x.reshape(FLAGS.accum_steps, -1, *x.shape[1:]), batch) 134 | carry_new, outputs = jax.lax.scan(computation, carry_init, batch) 135 | loss, ttt_stats = outputs 136 | loss = jnp.mean(loss) 137 | if output_ttt_stats: 138 | ttt_stats = tree_map(lambda x: jnp.mean(x, axis=0), ttt_stats) 139 | else: 140 | ttt_stats = None 141 | grads = jax.tree_util.tree_map(lambda x: x / FLAGS.accum_steps, carry_new["sum_grads"]) 142 | 143 | train_state = train_state.apply_gradients(grads=grads) 144 | learning_rate = optimizer_info["learning_rate_schedule"](train_state.step) 145 | grads_norm = global_norm(grads) 146 | 147 | return (train_state, loss, ttt_stats, grads_norm, learning_rate, rng_generator()) 148 | 149 | else: 150 | raise ValueError(f"Accum steps must >= 1, got {accum_steps}") 151 | 152 | return train_step 153 | 154 | 155 | def make_eval_step_fn(model, model_config): 156 | def eval_step(train_state, rng, batch): 157 | rng_generator = JaxRNG(rng) 158 | batch = with_sharding_constraint(batch, PS(("dp", "fsdp"))) 159 | logits = model.apply( 160 | train_state.params, batch["input_tokens"], deterministic=True, rngs=rng_generator(model_config.rng_keys()) 161 | ).logits 162 | loss, accuracy = cross_entropy_loss_and_accuracy(logits, batch["target_tokens"], batch["loss_masks"]) 163 | metrics = dict(eval_loss=loss, eval_accuracy=accuracy) 164 | return rng_generator(), metrics 165 | 166 | return eval_step 167 | 168 | 169 | def make_sharded_functions(model, optimizer, optimizer_info, model_config): 170 | def create_trainstate_from_params(params): 171 | return TrainState.create(params=params, tx=optimizer, apply_fn=None) 172 | 173 | def init_fn(rng): 174 | rng_generator = JaxRNG(rng) 175 | params = model.init( 176 | input_ids=jnp.zeros((4, FLAGS.seq_length), dtype=jnp.int32), 177 | position_ids=jnp.zeros((4, FLAGS.seq_length), dtype=jnp.int32), 178 | attention_mask=jnp.ones((4, FLAGS.seq_length), dtype=jnp.int32), 179 | rngs=rng_generator(model_config.rng_keys()), 180 | ) 181 | return TrainState.create(params=params, tx=optimizer, apply_fn=None) 182 | 183 | train_step = make_train_step_fn(model, optimizer_info, model_config, FLAGS.accum_steps) 184 | 185 | train_state_shapes = jax.eval_shape(init_fn, next_rng()) 186 | 187 | train_state_partition = match_partition_rules(model_config.get_partition_rules(), train_state_shapes) 188 | 189 | shard_fns, gather_fns = make_shard_and_gather_fns(train_state_partition, train_state_shapes) 190 | 191 | sharded_init_fn = pjit(init_fn, in_shardings=PS(), out_shardings=train_state_partition) 192 | 193 | sharded_create_trainstate_from_params = pjit( 194 | create_trainstate_from_params, 195 | in_shardings=(train_state_partition.params,), 196 | out_shardings=train_state_partition, 197 | donate_argnums=(0,), 198 | ) 199 | 200 | sharded_train_step = 
pjit( 201 | train_step, 202 | in_shardings=(train_state_partition, PS(), PS(), PS()), 203 | out_shardings=(train_state_partition, PS(), PS(), PS(), PS(), PS()), 204 | static_argnums=4, 205 | donate_argnums=(0,), 206 | ) 207 | 208 | return ( 209 | sharded_init_fn, 210 | sharded_create_trainstate_from_params, 211 | sharded_train_step, 212 | shard_fns, 213 | gather_fns, 214 | train_state_shapes, 215 | train_state_partition, 216 | ) 217 | 218 | 219 | def make_save_checkpoint(checkpointer, gather_fns, variant, flags_config_dict, model_config, global_batch_size): 220 | def save_checkpoint(train_state, train_loader, milestone=False): 221 | step = int(jax.device_get(train_state.step)) 222 | metadata = dict(step=step, variant=variant, flags=flags_config_dict, model_config=model_config.to_dict()) 223 | sampler_state_dict = { 224 | "random_state": train_loader.sampler.state_dict()["random_state"], 225 | "shuffle_log": train_loader.sampler.state_dict()["shuffle_log"], 226 | "counter": step * global_batch_size, 227 | } 228 | checkpointer.save_all( 229 | train_state=train_state, 230 | gather_fns=gather_fns, 231 | metadata=metadata, 232 | dataset=deepcopy(sampler_state_dict), 233 | milestone=milestone, 234 | ) 235 | 236 | return save_checkpoint 237 | 238 | 239 | def make_get_ttt_lr_mult(model_config): 240 | 241 | if ( 242 | hasattr(model_config, "ttt_base_lr_init") 243 | and model_config.ttt_base_lr_init > 0 244 | and model_config.ttt_base_lr_warmup > 0 245 | ): 246 | ttt_lr_mult_warmup_steps = model_config.ttt_base_lr_warmup 247 | ttt_lr_mult_init = model_config.ttt_base_lr_init 248 | ttt_lr_mult_peak = model_config.ttt_base_lr 249 | 250 | def get_ttt_lr_mult(step): 251 | ttt_lr_mult = ttt_lr_mult_init + min(1.0, (step - 1) / ttt_lr_mult_warmup_steps) * ( 252 | ttt_lr_mult_peak - ttt_lr_mult_init 253 | ) 254 | ttt_lr_mult = ttt_lr_mult / ttt_lr_mult_peak * jnp.ones((1,), dtype=jnp.bfloat16) 255 | return ttt_lr_mult 256 | 257 | else: 258 | 259 | def get_ttt_lr_mult(step): 260 | ttt_lr_mult = jnp.ones((1,), dtype=jnp.bfloat16) 261 | return ttt_lr_mult 262 | 263 | return get_ttt_lr_mult 264 | 265 | 266 | def initialize_or_resume( 267 | checkpointer, 268 | train_loader, 269 | train_state_shapes, 270 | sharded_init_fn, 271 | shard_fns, 272 | sharded_create_trainstate_from_params, 273 | FLAGS, 274 | ): 275 | start_step = 1 276 | train_state, restored_params = None, None 277 | if FLAGS.resume_exp_name != "": 278 | assert FLAGS.load_part in ["trainstate", "trainstate_params"] 279 | ckpt_resume_dir = ( 280 | FLAGS.load_part 281 | + "::" 282 | + osp.join( 283 | FLAGS.exp_dir, 284 | FLAGS.resume_exp_name, 285 | ( 286 | f"step_{int(FLAGS.resume_step)}/streaming_train_state_{int(FLAGS.resume_step)}" 287 | if FLAGS.resume_step 288 | else "streaming_train_state" 289 | ), 290 | ) 291 | ) 292 | train_state, restored_params = checkpointer.load_trainstate_checkpoint( 293 | ckpt_resume_dir, train_state_shapes, shard_fns 294 | ) 295 | 296 | if FLAGS.load_part == "trainstate": 297 | start_step = int(jax.device_get(train_state.step)) + 1 298 | master_print(f"Resuming training from checkpoint at step {start_step - 1}...") 299 | dataset_pkl_filename = ( 300 | f"step_{int(FLAGS.resume_step)}/dataset_{int(FLAGS.resume_step)}.pkl" 301 | if FLAGS.resume_step 302 | else "dataset.pkl" 303 | ) 304 | dataset_resume_dir = osp.join(FLAGS.exp_dir, FLAGS.resume_exp_name, dataset_pkl_filename) 305 | train_loader.sampler.load_state_dict(deepcopy(mlxu.load_pickle(dataset_resume_dir))) 306 | 307 | if FLAGS.is_rollback_reshuffle: 308 | 
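# Ask the fault-tolerant sampler to reshuffle the data remaining after a rolled-back resume (the is_rollback_reshuffle handling at the start of the training loop in main() then persists the refreshed sampler state); the exact reshuffle behavior is assumed to live in the sampler implementation.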
train_loader.sampler.is_rollback = True 309 | 310 | if train_state is None and restored_params is None: 311 | train_state = sharded_init_fn(next_rng()) 312 | elif train_state is None and restored_params is not None: 313 | train_state = sharded_create_trainstate_from_params(restored_params) 314 | del restored_params 315 | 316 | return start_step, train_state, train_loader 317 | 318 | 319 | def main(argv): 320 | JaxDistributedConfig.initialize(FLAGS.jax_distributed) 321 | variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF) 322 | flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF) 323 | 324 | set_random_seed(FLAGS.seed) 325 | process_num = jax.process_count() 326 | global_dev_num = jax.device_count() 327 | local_dev_num = jax.local_device_count() 328 | master_process = jax.process_index() == 0 329 | 330 | dev_info = f"Process # {process_num}\tLocal dev # {local_dev_num}\tTotal dev # {global_dev_num}" 331 | master_print(dev_info) 332 | 333 | seq_length = FLAGS.seq_length 334 | global_batch_size = FLAGS.global_batch_size 335 | is_rollback_reshuffle = FLAGS.is_rollback_reshuffle 336 | 337 | # Create dataloader 338 | data_module = LMDataModule( 339 | dataset_name=FLAGS.dataset_name, 340 | dataset_config_name=None, 341 | tokenizer_name=FLAGS.tokenizer_name, 342 | cache_dir=FLAGS.dataset_path, 343 | max_length=seq_length, 344 | add_eos=True, 345 | batch_size=global_batch_size, 346 | batch_size_eval=global_batch_size, 347 | loader_workers=FLAGS.loader_workers, 348 | shuffle=True, 349 | fault_tolerant=True, 350 | drop_last=True, 351 | ) 352 | data_module.prepare_data() 353 | data_module.setup() 354 | train_loader = data_module.train_dataloader() 355 | 356 | # Update model model_config 357 | if FLAGS.load_model_config != "": 358 | model_config = ModelConfig.load_config(FLAGS.load_model_config) 359 | else: 360 | raise RuntimeError(f"model_config must be specified") 361 | if FLAGS.update_model_config: 362 | update_dic = eval(FLAGS.update_model_config) 363 | for key, value in update_dic.items(): 364 | if hasattr(model_config, key): 365 | setattr(model_config, key, value) 366 | else: 367 | raise KeyError(f"Update key {key} not in model_config") 368 | model_config.vocab_size = data_module.vocab_size 369 | model_config.max_sequence_length = seq_length 370 | flags_config_dict.model_config = model_config 371 | 372 | # Create WandB run and checkpointer 373 | if master_process: 374 | wandb.init(project="TTT-LM", config=flags_config_dict, name=FLAGS.exp_name) 375 | ckpt_dir = osp.join(FLAGS.exp_dir, FLAGS.exp_name) 376 | checkpointer = StreamingCheckpointer(FLAGS.checkpointer, ckpt_dir, enable=master_process) 377 | 378 | # Create model and optimizer 379 | model = CausalLM(model_config, dtype=get_float_dtype_by_name(FLAGS.dtype)) 380 | optimizer, optimizer_info = OptimizerFactory.get_optimizer( 381 | FLAGS.optimizer, get_weight_decay_mask(model_config.get_weight_decay_exclusions()) 382 | ) 383 | 384 | # Helper function for dynamic TTT learning rate 385 | get_ttt_lr_mult = make_get_ttt_lr_mult(model_config) 386 | 387 | # Create sharded train functions 388 | ( 389 | sharded_init_fn, 390 | sharded_create_trainstate_from_params, 391 | sharded_train_step, 392 | shard_fns, 393 | gather_fns, 394 | train_state_shapes, 395 | train_state_partition, 396 | ) = make_sharded_functions(model, optimizer, optimizer_info, model_config) 397 | 398 | save_checkpoint = make_save_checkpoint( 399 | checkpointer, gather_fns, variant, flags_config_dict, model_config, global_batch_size 400 | ) 401 | 402 | mesh = 
model_config.get_jax_mesh(FLAGS.mesh_dim) 403 | with mesh: 404 | sharded_rng = next_rng() 405 | 406 | start_step, train_state, train_loader = initialize_or_resume( 407 | checkpointer, 408 | train_loader, 409 | train_state_shapes, 410 | sharded_init_fn, 411 | shard_fns, 412 | sharded_create_trainstate_from_params, 413 | FLAGS, 414 | ) 415 | 416 | if FLAGS.eval_mode: 417 | eval_step = make_eval_step_fn(model, model_config) 418 | sharded_eval_step = pjit( 419 | eval_step, 420 | in_shardings=(train_state_partition, PS(), PS()), 421 | out_shardings=(PS(), PS()), 422 | donate_argnums=(1,), 423 | ) 424 | 425 | val_loader = data_module.val_dataloader() 426 | eval_metric_list = [] 427 | 428 | for eval_batch in tqdm(val_loader, disable=not master_process): 429 | for k in eval_batch.keys(): 430 | eval_batch[k] = eval_batch[k].numpy() 431 | sharded_rng, eval_metrics = sharded_eval_step(train_state, sharded_rng, eval_batch) 432 | eval_metric_list.append(eval_metrics) 433 | 434 | val_loss_avg = average_metrics(process_allgather(eval_metric_list))["eval_loss"].item() 435 | master_print(f"Eval Loss: {val_loss_avg:.4f}") 436 | exit(0) 437 | 438 | train_loader_iterator = iter(train_loader) 439 | 440 | for step in tqdm( 441 | range(start_step, FLAGS.total_steps + 1), 442 | initial=start_step, 443 | total=FLAGS.total_steps, 444 | disable=not master_process, 445 | desc=f"Training {FLAGS.exp_name}", 446 | ): 447 | try: 448 | batch = next(train_loader_iterator) 449 | except StopIteration: 450 | train_loader.sampler.counter = 0 451 | train_loader_iterator = iter(train_loader) 452 | batch = next(train_loader_iterator) 453 | 454 | if is_rollback_reshuffle: 455 | sampler_state_dict = { 456 | "random_state": train_loader.sampler.state_dict()["random_state"], 457 | "shuffle_log": train_loader.sampler.state_dict()["shuffle_log"], 458 | "counter": (step - 1) * global_batch_size, 459 | } 460 | if master_process and FLAGS.resume_exp_name != "": 461 | master_print("Updating sampler state after rollback...") 462 | dataset_pkl_filename = ( 463 | f"step_{int(FLAGS.resume_step)}/dataset_{int(FLAGS.resume_step)}.pkl" 464 | if FLAGS.resume_step 465 | else "dataset_state.pkl" 466 | ) 467 | dataset_resume_dir = osp.join(FLAGS.exp_dir, FLAGS.resume_exp_name, dataset_pkl_filename) 468 | mlxu.save_pickle(deepcopy(sampler_state_dict), dataset_resume_dir) 469 | is_rollback_reshuffle = False 470 | master_print("Finished updating sampler state.") 471 | 472 | for k in batch.keys(): 473 | batch[k] = batch[k].numpy() 474 | 475 | ttt_lr_mult = get_ttt_lr_mult(step) 476 | output_ttt_stats = ( 477 | FLAGS.save_milestone_freq > 0 478 | and step % FLAGS.save_milestone_freq == 0 479 | and model_config.seq_modeling_block != "self_attention" 480 | ) 481 | 482 | train_state, loss, ttt_stats, grads_norm, learning_rate, sharded_rng = sharded_train_step( 483 | train_state, sharded_rng, batch, ttt_lr_mult, output_ttt_stats 484 | ) 485 | 486 | if master_process: 487 | wandb.log( 488 | { 489 | "Train Loss": loss.item(), 490 | "Gradient Norm": grads_norm.item(), 491 | "Learning Rate": learning_rate.item(), 492 | }, 493 | step=step, 494 | ) 495 | 496 | if output_ttt_stats: 497 | for layer in range(len(ttt_stats)): 498 | ttt_stats_layer = process_allgather(ttt_stats[layer]) 499 | n_mini_batch = len(ttt_stats_layer[0]) 500 | x_axis = [model_config.mini_batch_size * i for i in range(1, n_mini_batch + 1)] 501 | log_ttt_stats(layer, ttt_stats_layer, x_axis, step) 502 | 503 | if (FLAGS.save_checkpoint_freq > 0 and step % FLAGS.save_checkpoint_freq == 0) or ( 504 
| step == FLAGS.total_steps 505 | ): 506 | master_print(f"Saving checkpoint at step {step}, do not kill...") 507 | save_checkpoint(train_state, train_loader, milestone=(FLAGS.save_milestone_freq > 0 and step % FLAGS.save_milestone_freq == 0))  # guard against modulo-by-zero when save_milestone_freq is 0 508 | 509 | if step == FLAGS.total_steps: 510 | master_print("Training has completed!") 511 | 512 | 513 | if __name__ == "__main__": 514 | mlxu.run(main) 515 | --------------------------------------------------------------------------------