├── requirements.txt
├── LICENSE
├── training
│   ├── distributed_attention.py
│   ├── dataset.py
│   ├── train_language_model.py
│   ├── modeling_flash_llama.py
│   └── trainer.py
├── .gitignore
├── train_sft.sh
├── train_64K.sh
├── train_512K.sh
└── README.md
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.32.1
2 | datasets==2.20.0
3 | mosaicml-cli==0.5.34
4 | mosaicml-streaming==0.8.1
5 | ninja==1.11.1.1
6 | numpy==1.26.3
7 | packaging==24.1
8 | pandas==2.2.1
9 | protobuf==4.25.3
10 | python-dateutil==2.9.0
11 | regex==2023.12.25
12 | sentencepiece==0.1.99
13 | tiktoken==0.7.0
14 | torch==2.4.1
15 | tqdm==4.66.4
16 | transformers==4.44.2
17 | triton==3.0.0
18 | wandb==0.17.3
19 | zstandard==0.22.0
20 | zstd==1.5.5.1
21 |
22 | flash-attn==2.6.1 --config-settings=--global-option="--no-build-isolation"
23 | rotary-emb @ git+https://github.com/Dao-AILab/flash-attention.git@9356a1c0389660d7e231ff3163c1ac17d9e3824a#subdirectory=csrc/rotary
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Princeton Natural Language Processing
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/training/distributed_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | # DeepSpeed Team
5 |
6 | import torch
7 |
8 | from typing import Any, Tuple
9 | from torch import Tensor
10 | from torch.nn import Module
11 |
12 | import torch.distributed as dist
13 |
14 | class SeqAllToAll(torch.autograd.Function):
15 | @staticmethod
16 | def forward(ctx: Any, input: Tensor, scatter_idx: int, gather_idx: int, group: Any) -> Tensor:
17 | ctx.scatter_idx = scatter_idx
18 | ctx.gather_idx = gather_idx
19 | ctx.group = group
20 |
21 | world_size = dist.get_world_size(group)
22 |
23 | input_list = [t.contiguous() for t in torch.tensor_split(input, world_size, scatter_idx)]
24 | output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
25 |
26 | dist.all_to_all(output_list, input_list, group=group)
27 | return torch.cat(output_list, dim=gather_idx).contiguous()
28 |
29 | @staticmethod
30 | def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor, None, None, None]:
31 | return (SeqAllToAll.apply(*grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.group), None, None, None)
32 |
33 |
34 | class DistributedAttention(torch.nn.Module):
35 | """Initialization.
36 |
37 | Arguments:
38 | local_attention (Module): local attention with q,k,v
39 | scatter_idx (int): scatter_idx for all2all comm
40 | gather_idx (int): gather_idx for all2all comm
41 | """
42 |
43 | def __init__(
44 | self,
45 | local_attention: Module,
46 | scatter_idx: int = -2,
47 | gather_idx: int = 1,
48 | ) -> None:
49 |
50 | super(DistributedAttention, self).__init__()
51 | self.local_attn = local_attention
52 | self.scatter_idx = scatter_idx # head axis
53 | self.gather_idx = gather_idx # seq axis
54 |
55 | def forward(self, query: Tensor, key_values: Tensor, *args, group: Any = None, **kwargs) -> Tensor:
56 |         """ forward
57 | 
58 |         Arguments:
59 |             query (Tensor): query input to the layer, sharded along the sequence dimension
60 |             key_values (Tensor): packed key/value input to the layer, sharded along the sequence dimension
61 |             args: other args passed through to the local attention
62 |             group: process group used for the all-to-all communication
63 | 
64 |         Returns:
65 |             * output (Tensor): context output, sharded along the sequence dimension
66 |         """
67 |         # in shape: e.g., [batch, seq/p, heads, head_dim] (sequence sharded across ranks)
68 | query_heads = SeqAllToAll.apply(query, self.scatter_idx, self.gather_idx, group)
69 | key_values_heads = SeqAllToAll.apply(key_values, self.scatter_idx, self.gather_idx, group)
70 |
71 |         # after all-to-all: e.g., [batch, seq, heads/p, head_dim] (full sequence, heads sharded)
72 | output_heads = self.local_attn(query_heads, key_values_heads, *args, **kwargs)
73 |
74 |         # out: e.g., [batch, seq/p, heads, head_dim] (back to sequence sharding)
75 | return SeqAllToAll.apply(output_heads, self.gather_idx, self.scatter_idx, group)
76 |
--------------------------------------------------------------------------------
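Note: `SeqAllToAll` turns a sequence-sharded tensor into a head-sharded one (and back), so `DistributedAttention` can run any local attention over the full sequence on each rank's head shard. The snippet below is a hypothetical usage sketch, not part of the repo: it assumes a distributed launch (e.g. `torchrun --nproc-per-node=2`) on NCCL-capable GPUs, and the naive attention function is only a stand-in for the flash-attention kernel that the training code actually wraps.

```python
import torch
import torch.distributed as dist

from training.distributed_attention import DistributedAttention


def naive_local_attention(q, kv, *args, **kwargs):
    # q:  [batch, seq, heads/p, head_dim]; kv: [batch, seq, 2, heads/p, head_dim] (packed key/value)
    k, v = kv.unbind(dim=2)
    scores = torch.einsum("bshd,bthd->bhst", q, k) / q.shape[-1] ** 0.5
    return torch.einsum("bhst,bthd->bshd", scores.softmax(dim=-1), v)


def main():
    dist.init_process_group("nccl")           # launched with torchrun, one process per GPU
    torch.cuda.set_device(dist.get_rank())    # single-node assumption: rank == local rank

    batch, local_seq, heads, head_dim = 1, 8, 4, 16   # heads must be divisible by the world size
    q = torch.randn(batch, local_seq, heads, head_dim, device="cuda")
    kv = torch.randn(batch, local_seq, 2, heads, head_dim, device="cuda")

    attn = DistributedAttention(naive_local_attention, scatter_idx=-2, gather_idx=1)
    out = attn(q, kv)                          # uses the default process group
    print(dist.get_rank(), out.shape)          # [batch, local_seq, heads, head_dim], sequence-sharded again


if __name__ == "__main__":
    main()
```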
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
164 | debug*
165 | checkpoints
166 | datasets
167 |
--------------------------------------------------------------------------------
/train_sft.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -l
2 | #SBATCH -J train_sft
3 | #SBATCH -N 1
4 | #SBATCH --output=slurm/%x-%j.out
5 | #SBATCH --gres=gpu:8
6 | #SBATCH --mem=400G
7 | #SBATCH -c 32
8 |
9 | # !!!! Load your own environment here !!!! #
10 | # !!!! Load your own environment here !!!! #
11 |
12 | # Fine-tune from this model
13 | model=${MODEL:-princeton-nlp/Llama-3-8B-ProLong-512k-Base}
14 | # Point to the base dir of the ProLong 64K data
15 | dataset=${DATASET:-"datasets"}
16 |
17 | # Directories in the dataset root folder; the number after @ is the mixing proportion
18 | domains=(
19 | prolong-ultrachat-64K@1.0
20 | )
21 | domains_name=ultrachat
22 |
23 |
24 | bsz=${BSZ:-64} # * 64k (seq len) = 4M
25 | seq=${SEQ:-1} # per-device batch size
26 | lr=${LR:-2e-5}
27 | steps=${STEPS:-250}
28 | save_steps=${SAVE:-250}
29 | warmup=${WARMUP:-0.05}
30 | suffix=${SUFFIX:-""} # for model saving name
31 |
32 |
33 | run_name="sft_$(basename $model)_$(basename $dataset)_${domains_name}_bsz${bsz}_steps${steps}_lr${lr}_warmup${warmup}${suffix}"
34 | out_dir="checkpoints/$run_name"
35 |
36 | if [ -z "$CUDA_VISIBLE_DEVICES" ]; then
37 | num_gpus=$(nvidia-smi -L | wc -l)
38 | else
39 | num_gpus=$(jq -n "[$CUDA_VISIBLE_DEVICES] | length")
40 | fi
41 | num_gpus=${NUM_GPUS:-$num_gpus}
42 |
43 | num_nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
44 | if [ $num_nodes == 0 ]; then
45 | num_nodes=1
46 | fi
47 | num_nodes=${NUM_NODES:-$num_nodes}
48 |
49 | # Gradient accumulation
50 | accu=$(($bsz / $seq / $num_gpus / $num_nodes))
51 |
52 |
53 | # [0] Disable
54 | # [1] FULL_SHARD (shards optimizer states, gradients and parameters),
55 | # [2] SHARD_GRAD_OP (shards optimizer states and gradients),
56 | # [3] NO_SHARD (DDP),
57 | # [4] HYBRID_SHARD (shards optimizer states, gradients and parameters within each node while each node has full copy),
58 | # [5] HYBRID_SHARD_ZERO2 (shards optimizer states and gradients within each node while each node has full copy). For more information, please refer to the official PyTorch docs.
59 | fsdp=${FSDP:-"1"}
60 | gc=${GC:-"1"}
61 |
62 | export LOGIT_BLOCK_SIZE=2048 # Compute Llama logits in blocks of 2048 tokens
63 |
64 | mkdir -p $out_dir
65 | nvidia-smi
66 |
67 | if [ $num_nodes -gt 1 ]; then
68 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
69 | master_addr=${MASTER_ADDR:-$master_addr}
70 |
71 | # Launch via srun
72 | header="srun torchrun \
73 | --rdzv-backend=c10d \
74 | --rdzv-endpoint=$master_addr:56321 \
75 | --nnodes=$num_nodes \
76 | --nproc-per-node=$num_gpus \
77 | -m training.train_language_model"
78 | else
79 | master_port=$(comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 1)
80 |
81 | # Launch without srun
82 | header="torchrun \
83 | --rdzv-backend=c10d \
84 | --rdzv-endpoint=localhost:$master_port \
85 | --nnodes=1 \
86 | --nproc-per-node=$num_gpus \
87 | -m training.train_language_model"
88 | fi
89 | echo "slurm_nodelist=${SLURM_NODELIST} num_nodes=${num_nodes} master_addr=${master_addr} master_port=${master_port} num_gpus=${num_gpus}"
90 |
91 | export OMP_NUM_THREADS=$num_gpus
92 | export WANDB_PROJECT="prolong"
93 | export WANDB_DIR=$out_dir
94 | export WANDB_MODE="offline" # We turn off wandb online sync by default
95 | export TOKENIZERS_PARALLELISM=true
96 |
97 |
98 | base_arguments=(
99 | --report_to wandb
100 | --do_train
101 |
102 | --model_name $model
103 | --tokenizer_name $model
104 |
105 | --run_name $run_name
106 | --output_dir $out_dir
107 | --config_overrides_json "$overrides"
108 | --gradient_accumulation_steps $accu
109 | --per_device_train_batch_size $seq
110 | --per_device_eval_batch_size $seq
111 |
112 | --bf16
113 | --learning_rate $lr
114 | --min_lr_ratio 0.1
115 | --lr_scheduler_type cosine
116 | --max_grad_norm 1.0
117 | --adam_beta1 0.9
118 | --adam_beta2 0.95
119 | --weight_decay 0.1
120 | --warmup_ratio $warmup
121 | --optim adamw_torch
122 |
123 | --logging_steps 1
124 | --log_level info
125 |
126 | --max_steps $steps
127 | --save_steps $save_steps
128 | --dataloader_num_workers 1
129 |
130 | --disable_tqdm true
131 | --use_fast_tokenizer false
132 | --remove_unused_columns false
133 | --ddp_find_unused_parameters false
134 |
135 | --per_device_max_tokens 65536
136 |
137 | --cuda_empty_cache
138 |
139 | --apply_instruct_masks # mask out the tokens from instructions (instead of responses) when calculating losses
140 | --token_scaled_loss # average losses over valid training tokens instead of devices
141 | )
142 |
143 |
144 |
145 | if [ $fsdp -ne 0 ]; then
146 | export FSDP_SHARDING_STRATEGY=$fsdp
147 | base_arguments+=( --fsdp "auto_wrap" )
148 | # [1] FULL_STATE_DICT, [2] LOCAL_STATE_DICT, [3] SHARDED_STATE_DICT
149 | export FSDP_STATE_DICT_TYPE="FULL_STATE_DICT"
150 | fi
151 |
152 | if [ $gc -ne 0 ]; then
153 | base_arguments+=( --gradient_checkpointing )
154 | fi
155 |
156 | base_arguments+=( --tokenized_mds_train )
157 | for domain in "${domains[@]}"; do
158 | base_arguments+=( $dataset/$domain )
159 | done
160 |
161 | base_arguments+=( $@ )
162 |
163 | echo command: "${header} ${base_arguments[@]}"
164 | ${header} "${base_arguments[@]}" 2>&1 | tee -a $out_dir/log.out
165 |
--------------------------------------------------------------------------------
/train_64K.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -l
2 | #SBATCH -J train_64K
3 | #SBATCH -N 1
4 | #SBATCH --output=slurm/%x-%j.out
5 | #SBATCH --gres=gpu:8
6 | #SBATCH --mem=400G
7 | #SBATCH -c 32
8 |
9 | # !!!! Load your own environment here !!!! #
10 | # !!!! Load your own environment here !!!! #
11 |
12 | # Fine-tune from this model
13 | model=${MODEL:-meta-llama/Meta-Llama-3-8B-Instruct}
14 | # Point to the base dir of the ProLong 64K data
15 | dataset=${DATASET:-"datasets/long-context-65536"}
16 |
17 | # Directories in the dataset root folder; the number after @ is the mixing proportion
18 | domains=(
19 | thestackv1_concat_by_repo-65536@0.3
20 | book-65536@0.3
21 | fineweb-edu@0.1
22 | fineweb-2023-50@0.1
23 | stackexchange@0.04
24 | dolmawiki@0.04
25 | tuluv2@0.03
26 | arxiv@0.03
27 | openwebmath@0.03
28 | textbooks@0.03
29 | )
30 | domains_name=ProLong64KMix
31 |
32 |
33 | bsz=${BSZ:-64} # * 64k (seq len) = 4M
34 | seq=${SEQ:-1} # per-device batch size
35 | lr=${LR:-1e-5}
36 | steps=${STEPS:-5000}
37 | save_steps=${SAVE:-125}
38 | warmup=${WARMUP:-0.1}
39 | suffix=${SUFFIX:-""} # for model saving name
40 |
41 |
42 | run_name="lcft_$(basename $model)_$(basename $dataset)_${domains_name}_bsz${bsz}_steps${steps}_lr${lr}_warmup${warmup}${suffix}"
43 | out_dir="checkpoints/$run_name"
44 |
45 | if [ -z "$CUDA_VISIBLE_DEVICES" ]; then
46 | num_gpus=$(nvidia-smi -L | wc -l)
47 | else
48 | num_gpus=$(jq -n "[$CUDA_VISIBLE_DEVICES] | length")
49 | fi
50 | num_gpus=${NUM_GPUS:-$num_gpus}
51 |
52 | num_nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
53 | if [ $num_nodes == 0 ]; then
54 | num_nodes=1
55 | fi
56 | num_nodes=${NUM_NODES:-$num_nodes}
57 |
58 | # Gradient accumulation
59 | accu=$(($bsz / $seq / $num_gpus / $num_nodes))
60 |
61 |
62 | # [0] Disable
63 | # [1] FULL_SHARD (shards optimizer states, gradients and parameters),
64 | # [2] SHARD_GRAD_OP (shards optimizer states and gradients),
65 | # [3] NO_SHARD (DDP),
66 | # [4] HYBRID_SHARD (shards optimizer states, gradients and parameters within each node while each node has full copy),
67 | # [5] HYBRID_SHARD_ZERO2 (shards optimizer states and gradients within each node while each node has full copy). For more information, please refer to the official PyTorch docs.
68 | fsdp=${FSDP:-"1"}
69 | gc=${GC:-"1"}
70 |
71 | export LOGIT_BLOCK_SIZE=2048 # Compute Llama logits in blocks of 2048 tokens
72 |
73 | mkdir -p $out_dir
74 | nvidia-smi
75 |
76 | if [ $num_nodes -gt 1 ]; then
77 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
78 | master_addr=${MASTER_ADDR:-$master_addr}
79 |
80 | # Launch via srun
81 | header="srun torchrun \
82 | --rdzv-backend=c10d \
83 | --rdzv-endpoint=$master_addr:56321 \
84 | --nnodes=$num_nodes \
85 | --nproc-per-node=$num_gpus \
86 | -m training.train_language_model"
87 | else
88 | master_port=$(comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 1)
89 |
90 | # Launch without srun
91 | header="torchrun \
92 | --rdzv-backend=c10d \
93 | --rdzv-endpoint=localhost:$master_port \
94 | --nnodes=1 \
95 | --nproc-per-node=$num_gpus \
96 | -m training.train_language_model"
97 | fi
98 | echo "slurm_nodelist=${SLURM_NODELIST} num_nodes=${num_nodes} master_addr=${master_addr} master_port=${master_port} num_gpus=${num_gpus}"
99 |
100 | export OMP_NUM_THREADS=$num_gpus
101 | export WANDB_PROJECT="prolong"
102 | export WANDB_DIR=$out_dir
103 | export WANDB_MODE="offline" # We turn off wandb online sync by default
104 | export TOKENIZERS_PARALLELISM=true
105 |
106 |
107 | base_arguments=(
108 | --report_to wandb
109 | --do_train
110 |
111 | --model_name $model
112 | --tokenizer_name $model
113 |
114 | --run_name $run_name
115 | --output_dir $out_dir
116 | --config_overrides_json "$overrides"
117 | --gradient_accumulation_steps $accu
118 | --per_device_train_batch_size $seq
119 | --per_device_eval_batch_size $seq
120 |
121 | --bf16
122 | --learning_rate $lr
123 | --min_lr_ratio 0.1
124 | --lr_scheduler_type cosine
125 | --max_grad_norm 1.0
126 | --adam_beta1 0.9
127 | --adam_beta2 0.95
128 | --weight_decay 0.1
129 | --warmup_ratio $warmup
130 | --optim adamw_torch
131 |
132 | --logging_steps 1
133 | --log_level info
134 |
135 | --max_steps $steps
136 | --save_steps $save_steps
137 | --dataloader_num_workers 1
138 |
139 | --disable_tqdm true
140 | --use_fast_tokenizer false
141 | --remove_unused_columns false
142 | --ddp_find_unused_parameters false
143 |
144 | --per_device_max_tokens 65536
145 |
146 | # --torch_compile
147 | --cuda_empty_cache
148 | --config_overrides "rope_theta=8000000"
149 | )
150 |
151 |
152 |
153 | if [ $fsdp -ne 0 ]; then
154 | export FSDP_SHARDING_STRATEGY=$fsdp
155 | base_arguments+=( --fsdp "auto_wrap" )
156 | # [1] FULL_STATE_DICT, [2] LOCAL_STATE_DICT, [3] SHARDED_STATE_DICT
157 | export FSDP_STATE_DICT_TYPE="FULL_STATE_DICT"
158 | fi
159 |
160 | if [ $gc -ne 0 ]; then
161 | base_arguments+=( --gradient_checkpointing )
162 | fi
163 |
164 | base_arguments+=( --tokenized_mds_train )
165 | for domain in "${domains[@]}"; do
166 | base_arguments+=( $dataset/$domain )
167 | done
168 |
169 | base_arguments+=( $@ )
170 |
171 | echo command: "${header} ${base_arguments[@]}"
172 | ${header} "${base_arguments[@]}" 2>&1 | tee -a $out_dir/log.out
173 |
--------------------------------------------------------------------------------
/train_512K.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -l
2 | #SBATCH -J train_512K
3 | #SBATCH -N 1
4 | #SBATCH --output=slurm/%x-%j.out
5 | #SBATCH --gres=gpu:8
6 | #SBATCH --mem=400G
7 | #SBATCH -c 32
8 |
9 | # !!!! Load your own environment here !!!! #
10 | # !!!! Load your own environment here !!!! #
11 |
12 | # Fine-tune from this model
13 | model=${MODEL:-meta-llama/Meta-Llama-3-8B-Instruct}
14 | # Point to the base dir of the ProLong 512K data
15 | dataset=${DATASET:-"datasets/long-context-524288"}
16 |
17 | # Directories in the dataset root folder; the number after @ is the mixing proportion
18 | domains=(
19 | thestackv1_concat_by_repo-524288@0.15
20 | thestackv1_concat_by_repo-65536@0.15
21 | book-524288@0.05
22 | book-65536@0.25
23 | fineweb-edu@0.1
24 | fineweb-2023-50@0.1
25 | stackexchange@0.04
26 | dolmawiki@0.04
27 | tuluv2@0.03
28 | arxiv@0.03
29 | openwebmath@0.03
30 | textbooks@0.03
31 | )
32 | domains_name=ProLong512KMix
33 |
34 |
35 | bsz=${BSZ:-128} # * 512K (seq len) / 8 (seq parallel size) = 8M
36 | seq=${SEQ:-1} # per-device batch size
37 | lr=${LR:-5e-6}
38 | steps=${STEPS:-2500}
39 | save_steps=${SAVE:-125}
40 | warmup=${WARMUP:-0.1}
41 | suffix=${SUFFIX:-""} # for model saving name
42 |
43 |
44 | run_name="lcft_$(basename $model)_$(basename $dataset)_${domains_name}_bsz${bsz}_steps${steps}_lr${lr}_warmup${warmup}${suffix}"
45 | out_dir="checkpoints/$run_name"
46 |
47 | if [ -z "$CUDA_VISIBLE_DEVICES" ]; then
48 | num_gpus=$(nvidia-smi -L | wc -l)
49 | else
50 | num_gpus=$(jq -n "[$CUDA_VISIBLE_DEVICES] | length")
51 | fi
52 | num_gpus=${NUM_GPUS:-$num_gpus}
53 |
54 | num_nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
55 | if [ $num_nodes == 0 ]; then
56 | num_nodes=1
57 | fi
58 | num_nodes=${NUM_NODES:-$num_nodes}
59 |
60 | # Gradient accumulation
61 | accu=$(($bsz / $seq / $num_gpus / $num_nodes))
62 |
63 |
64 | # [0] Disable
65 | # [1] FULL_SHARD (shards optimizer states, gradients and parameters),
66 | # [2] SHARD_GRAD_OP (shards optimizer states and gradients),
67 | # [3] NO_SHARD (DDP),
68 | # [4] HYBRID_SHARD (shards optimizer states, gradients and parameters within each node while each node has full copy),
69 | # [5] HYBRID_SHARD_ZERO2 (shards optimizer states and gradients within each node while each node has full copy). For more information, please refer to the official PyTorch docs.
70 | fsdp=${FSDP:-"1"}
71 | gc=${GC:-"1"}
72 |
73 | export LOGIT_BLOCK_SIZE=2048 # Compute Llama logits in blocks of 2048 tokens
74 |
75 | mkdir -p $out_dir
76 | nvidia-smi
77 |
78 | if [ $num_nodes -gt 1 ]; then
79 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
80 | master_addr=${MASTER_ADDR:-$master_addr}
81 |
82 | # Launch via srun
83 | header="srun torchrun \
84 | --rdzv-backend=c10d \
85 | --rdzv-endpoint=$master_addr:56321 \
86 | --nnodes=$num_nodes \
87 | --nproc-per-node=$num_gpus \
88 | -m training.train_language_model"
89 | else
90 | master_port=$(comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 1)
91 |
92 | # Launch without srun
93 | header="torchrun \
94 | --rdzv-backend=c10d \
95 | --rdzv-endpoint=localhost:$master_port \
96 | --nnodes=1 \
97 | --nproc-per-node=$num_gpus \
98 | -m training.train_language_model"
99 | fi
100 | echo "slurm_nodelist=${SLURM_NODELIST} num_nodes=${num_nodes} master_addr=${master_addr} master_port=${master_port} num_gpus=${num_gpus}"
101 |
102 | export OMP_NUM_THREADS=$num_gpus
103 | export WANDB_PROJECT="prolong"
104 | export WANDB_DIR=$out_dir
105 | export WANDB_MODE="offline" # We turn off wandb online sync by default
106 | export TOKENIZERS_PARALLELISM=true
107 |
108 |
109 | base_arguments=(
110 | --report_to wandb
111 | --do_train
112 |
113 | --model_name $model
114 | --tokenizer_name $model
115 |
116 | # Initialize model + optimizer state with ProLong64K (please follow the README for the correct setup)
117 | --resume_from_checkpoint path/to/the/root/64K/checkpoint/folder
118 |
119 | --run_name $run_name
120 | --output_dir $out_dir
121 | --config_overrides_json "$overrides"
122 | --gradient_accumulation_steps $accu
123 | --per_device_train_batch_size $seq
124 | --per_device_eval_batch_size $seq
125 |
126 | --bf16
127 | --learning_rate $lr
128 | --min_lr_ratio 0.1
129 | --lr_scheduler_type cosine
130 | --max_grad_norm 1.0
131 | --adam_beta1 0.9
132 | --adam_beta2 0.95
133 | --weight_decay 0.1
134 | --warmup_ratio $warmup
135 | --optim adamw_torch
136 |
137 | --logging_steps 1
138 | --log_level info
139 |
140 | --max_steps $steps
141 | --save_steps $save_steps
142 | --dataloader_num_workers 1
143 |
144 | --disable_tqdm true
145 | --use_fast_tokenizer false
146 | --remove_unused_columns false
147 | --ddp_find_unused_parameters false
148 |
149 | --per_device_max_tokens 524288
150 |
151 | # --torch_compile
152 | --cuda_empty_cache
153 | --config_overrides "rope_theta=128000000"
154 |
155 | --seq_parallel_size 8
156 | )
157 |
158 |
159 |
160 | if [ $fsdp -ne 0 ]; then
161 | export FSDP_SHARDING_STRATEGY=$fsdp
162 | base_arguments+=( --fsdp "auto_wrap" )
163 | # [1] FULL_STATE_DICT, [2] LOCAL_STATE_DICT, [3] SHARDED_STATE_DICT
164 | export FSDP_STATE_DICT_TYPE="FULL_STATE_DICT"
165 | fi
166 |
167 | if [ $gc -ne 0 ]; then
168 | base_arguments+=( --gradient_checkpointing )
169 | fi
170 |
171 | base_arguments+=( --tokenized_mds_train )
172 | for domain in "${domains[@]}"; do
173 | base_arguments+=( $dataset/$domain )
174 | done
175 |
176 | base_arguments+=( $@ )
177 |
178 | echo command: "${header} ${base_arguments[@]}"
179 | ${header} "${base_arguments[@]}" 2>&1 | tee -a $out_dir/log.out
180 |
--------------------------------------------------------------------------------
/training/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from streaming import StreamingDataset, Stream
5 | import logging
6 |
7 | from itertools import islice
8 |
9 | from typing import Dict, Any, List, Tuple
10 | from collections.abc import Iterator
11 |
12 | from training.trainer import TrainingArguments
13 |
14 | from dataclasses import dataclass, field
15 | from typing import Optional, List
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | @dataclass
21 | class DataArguments:
22 | single_seq: bool = field(default=False, metadata={"help": "Ignore the document boundaries and treat the whole packed sequence as a single sequence"})
23 | per_device_max_tokens: Optional[int] = field(default=4_294_967_296, metadata={"help": "Maximum number of tokens per device; this is to avoid some catastrophic cases where the indices or data sequences are not filtered/truncated properly in preprocessing"})
24 | apply_instruct_masks: bool = field(default=False, metadata={"help": "Whether to apply loss masks over the instructions (for instruction tuning). If enabled, will read the `mask` field in the data and set the corresponding labels to -100."})
25 |
26 |
27 | class SafeStream(Stream):
28 | """Safe if multiple processes try to decompress the same shard."""
29 |
30 | def _decompress_shard_part(self, zip_info, zip_filename, raw_filename, compression):
31 | unique_extension = "." + str(os.getenv("SLURM_JOB_ID", "local")) + "-" + str(os.getpid())
32 | super()._decompress_shard_part(zip_info, zip_filename, raw_filename + unique_extension, compression)
33 | os.rename(raw_filename + unique_extension, raw_filename)
34 |
35 |
36 | class DataCollator:
37 | def __init__(self, tokenizer, args: DataArguments):
38 | self.tokenizer = tokenizer
39 | self.args = args
40 |
41 | @torch.no_grad()
42 | def __call__(self, features):
43 | input_ids = []
44 | labels = []
45 | seq_lengths = []
46 |
47 | available_tokens = self.args.per_device_max_tokens
48 | for item in features:
49 | apply_instruct_masks = self.args.apply_instruct_masks and ("mask" in item)
50 | indices = item["indices"] if "indices" in item else [(0, len(item["input_ids"]))]
51 | if self.args.single_seq:
52 | indices = [(0, len(item["input_ids"]))]
53 |
54 | label_seq = torch.tensor(item["input_ids"], dtype=torch.long)
55 |
56 | for a, b in indices:
57 | b = a + min(b - a, available_tokens)
58 | if b - a > 1:
59 | input_seq = torch.tensor(item["input_ids"][a:b], dtype=torch.long)
60 | input_ids.append(input_seq)
61 |
62 | _label = label_seq[a:b]
63 | _label[0] = -100 # Don't predict the first token
64 | if apply_instruct_masks:
65 | # Read the `mask` field and set the corresponding labels to -100
66 | mask = torch.tensor(item["mask"][a:b], dtype=torch.long)
67 | _label[mask == 0] = -100
68 | labels.append(_label)
69 |
70 | seq_lengths.append(b - a)
71 | available_tokens -= b - a
72 | elif available_tokens <= 0:
73 | assert available_tokens == 0, "Available tokens should be non-negative"
74 | break
75 |
76 | input_ids = torch.concat(input_ids, dim=0)
77 | labels = torch.concat(labels, dim=0)
78 | seq_lengths = torch.tensor(seq_lengths, dtype=torch.long)
79 |
80 | return dict(input_ids=input_ids,
81 | attention_mask=None,
82 | labels=labels,
83 | seq_lengths=seq_lengths)
84 |
85 |
86 |
87 | class SortByLengthDataset(StreamingDataset):
88 | def __init__(self, *args, sort_by_length_size=1, data_args=None, **kwargs):
89 | super().__init__(*args, **kwargs)
90 | self.sort_by_length_size = sort_by_length_size
91 | self.data_args = data_args
92 |
93 | def _negative_item_cost(self, item):
94 | if "indices" in item:
95 | return -sum(
96 | (end - start)**2 for start, end in item["indices"]
97 | )
98 | elif "length" in item:
99 | return -item["length"]**2
100 | else:
101 | return -len(item["input_ids"])**2
102 |
103 | def __iter__(self) -> Iterator[Dict[str, Any]]:
104 | if self.sort_by_length_size <= 1:
105 | yield from super().__iter__()
106 | else:
107 | iterator = super().__iter__()
108 | while True:
109 | block = list(islice(iterator, self.sort_by_length_size))
110 | if not block:
111 | return
112 |
113 | yield from sorted(block, key=self._negative_item_cost)
114 |
115 |
116 | def build_dataset(paths, training_args: TrainingArguments, data_args: DataArguments, is_training: bool) -> StreamingDataset:
117 | logger.info(f"Loading datasets for {'training' if is_training else 'evaluation'}")
118 |
119 | streams = []
120 | for path in paths:
121 | if "@" in path:
122 | path, proportion = path.split("@", 1)
123 | logger.info(f"Loading dataset from {path} with proportion {proportion}")
124 | streams.append(SafeStream(remote=path, local=path, proportion=float(proportion)))
125 | elif "#" in path:
126 | path, proportion = path.split("#", 1)
127 | logger.info(f"Loading dataset from {path} with repeat {proportion}")
128 | streams.append(SafeStream(remote=path, local=path, repeat=float(proportion)))
129 | else:
130 | streams.append(SafeStream(remote=path, local=path))
131 |
132 | epoch_size = (
133 | training_args.max_steps * training_args.train_batch_size * training_args.gradient_accumulation_steps *
134 | training_args.world_size // training_args.seq_parallel_size
135 | )
136 |
137 | num_dataloaders = max(training_args.dataloader_num_workers, 1)
138 | per_device_step_size = training_args.gradient_accumulation_steps * training_args.train_batch_size
139 | per_worker_step_size = per_device_step_size // num_dataloaders
140 | assert per_device_step_size % num_dataloaders == 0, "dataloader workers should divide local batch size"
141 |
142 | return SortByLengthDataset(
143 | streams=streams,
144 | shuffle=is_training,
145 | shuffle_seed=training_args.seed,
146 | batch_size=(training_args.train_batch_size if is_training else training_args.eval_batch_size),
147 | epoch_size=(epoch_size if is_training else None),
148 | sort_by_length_size=(per_worker_step_size if is_training else 1),
149 | data_args=data_args,
150 | replication=training_args.seq_parallel_size,
151 | )
152 |
--------------------------------------------------------------------------------
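As a quick illustration of what `DataCollator` produces, here is a hypothetical check (the token ids are made up; it assumes the repo's requirements are installed so that `training.dataset` imports cleanly):

```python
from training.dataset import DataCollator, DataArguments

collator = DataCollator(tokenizer=None, args=DataArguments())  # the tokenizer is not used in __call__

features = [
    # one packed sequence holding two documents: tokens [0, 5) and [5, 9)
    {"input_ids": [10, 11, 12, 13, 14, 20, 21, 22, 23], "indices": [(0, 5), (5, 9)]},
]
batch = collator(features)

print(batch["input_ids"].shape)   # torch.Size([9]) -- documents concatenated into one flat sequence
print(batch["seq_lengths"])       # tensor([5, 4])  -- per-document lengths for variable-length attention
print(batch["labels"][0].item())  # -100 -- the first token of each document is never predicted
print(batch["labels"][5].item())  # -100
```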
/training/train_language_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import torch
5 | import datasets
6 | import transformers
7 | import functools
8 |
9 | from transformers import (
10 | AutoConfig,
11 | AutoTokenizer,
12 | HfArgumentParser,
13 | set_seed,
14 | )
15 |
16 | from training.modeling_flash_llama import LlamaForCausalLM
17 | from training.trainer import Trainer, TrainingArguments
18 | from training.dataset import build_dataset, DataCollator, DataArguments
19 | from training.dataset import logger as dataset_logger
20 |
21 |
22 | from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
23 |
24 | from transformers.trainer_utils import get_last_checkpoint
25 | import json
26 | from dataclasses import dataclass, field
27 | from typing import Optional, List
28 |
29 |
30 | logger = logging.getLogger(__name__)
31 |
32 | @dataclass
33 | class ScriptArguments:
34 | model_name_or_path: Optional[str] = field(
35 | default=None,
36 | metadata={
37 | "help": (
38 |                 "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
39 | )
40 | },
41 | )
42 | config_overrides: Optional[str] = field(
43 | default=None,
44 | metadata={
45 | "help": (
46 | "Override some existing default config settings when a model is trained from scratch. Example: "
47 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
48 | )
49 | },
50 | )
51 | config_overrides_json: Optional[str] = field(
52 | default=None,
53 | metadata={
54 | "help": (
55 | "Override some existing default config settings when a model is trained from scratch. Example: "
56 | "'{\"resid_pdrop\": 0.2}'"
57 | )
58 | },
59 | )
60 | config_name: Optional[str] = field(
61 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
62 | )
63 | tokenizer_name: Optional[str] = field(
64 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
65 | )
66 | cache_dir: Optional[str] = field(
67 | default=None,
68 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
69 | )
70 | use_fast_tokenizer: bool = field(
71 | default=True,
72 |         metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
73 | )
74 | model_revision: str = field(
75 | default="main",
76 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
77 | )
78 | use_auth_token: bool = field(
79 | default=False,
80 | metadata={
81 | "help": (
82 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
83 | "with private models)."
84 | )
85 | },
86 | )
87 |
88 | tokenized_mds_train: List[str] = field(default_factory=list, metadata={"help": "Paths to tokenized training datasets in MDS format"})
89 | tokenized_mds_validation: List[str] = field(default_factory=list, metadata={"help": "Paths to tokenized validation datasets in MDS format"})
90 | tokenized_mds_test: List[str] = field(default_factory=list, metadata={"help": "Paths to tokenized test datasets in MDS format"})
91 |
92 | token_scaled_loss: bool = field(default=False, metadata={"help": "Whether to re-scale the loss by the number of valid training tokens instead of averaging loss across sequences and devices. This should be turned on for instruction tuning, especially when using synthetic data, as the valid training tokens vary across devices."})
93 |
94 |
95 | def main():
96 | # See all possible arguments in src/transformers/training_args.py
97 | # or by passing the --help flag to this script.
98 | # We now keep distinct sets of script_args, for a cleaner separation of concerns.
99 | parser = HfArgumentParser((ScriptArguments, TrainingArguments, DataArguments))
100 | script_args, training_args, data_args = parser.parse_args_into_dataclasses()
101 | # Setup logging
102 | logging.basicConfig(
103 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
104 | datefmt="%m/%d/%Y %H:%M:%S",
105 | handlers=[logging.StreamHandler(sys.stdout)],
106 | )
107 | log_level = training_args.get_process_log_level()
108 | logger.setLevel(log_level)
109 | dataset_logger.setLevel(log_level)
110 | datasets.utils.logging.set_verbosity(log_level)
111 | transformers.utils.logging.set_verbosity(log_level)
112 | transformers.utils.logging.enable_default_handler()
113 | transformers.utils.logging.enable_explicit_format()
114 |
115 | # Log on each process the small summary:
116 | logger.warning(
117 |         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
118 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
119 | )
120 | logger.info(f"Training/evaluation parameters {training_args}")
121 | logger.info(f"Data arguments {data_args}")
122 | logger.info(f"Additional arguments {script_args}")
123 | # Detecting last checkpoint.
124 | last_checkpoint = None
125 | if os.path.isdir(training_args.output_dir):
126 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
127 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
128 | logger.info(
129 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
130 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
131 | )
132 |
133 | # Set seed before initializing model.
134 | set_seed(training_args.seed)
135 | tokenizer = AutoTokenizer.from_pretrained(
136 | script_args.tokenizer_name or script_args.model_name_or_path,
137 | cache_dir=script_args.cache_dir,
138 | use_fast=script_args.use_fast_tokenizer,
139 | revision=script_args.model_revision,
140 | use_auth_token=True if script_args.use_auth_token else None,
141 | )
142 | config = AutoConfig.from_pretrained(
143 | script_args.config_name or script_args.model_name_or_path,
144 | cache_dir=script_args.cache_dir,
145 | revision=script_args.model_revision,
146 | use_auth_token=True if script_args.use_auth_token else None
147 | )
148 | if script_args.config_overrides:
149 | logger.info(f"Overriding config: {script_args.config_overrides}")
150 | config.update_from_string(script_args.config_overrides)
151 | logger.info(f"New config: {config}")
152 |
153 | if script_args.config_overrides_json:
154 | logger.info(f"Overriding config: {script_args.config_overrides_json}")
155 | config.update(json.loads(script_args.config_overrides_json))
156 | logger.info(f"New config: {config}")
157 |
158 | config.pad_token_id = 0
159 |
160 | if script_args.model_name_or_path:
161 | model = LlamaForCausalLM.from_pretrained(
162 | script_args.model_name_or_path,
163 | from_tf=bool(".ckpt" in script_args.model_name_or_path),
164 | config=config,
165 | cache_dir=script_args.cache_dir,
166 | revision=script_args.model_revision,
167 | use_auth_token=True if script_args.use_auth_token else None,
168 | )
169 | else:
170 | logger.warning(f"Initializing new LlamaForCausalLM from scratch")
171 | model = LlamaForCausalLM(config)
172 |
173 | if script_args.tokenizer_name is not None and script_args.model_name_or_path != script_args.tokenizer_name:
174 | model.resize_token_embeddings(len(tokenizer))
175 |
176 | logger.info(f"Model: {model}")
177 |
178 | # This avoids weird issues when doing multiple runs from different codebases
179 | import streaming
180 | streaming.base.util.clean_stale_shared_memory()
181 |
182 | if script_args.token_scaled_loss:
183 | model.token_scaled_loss = True
184 | training_args.token_scaled_loss = True
185 |
186 | # load_datasets
187 | if training_args.do_train:
188 | train_dataset = build_dataset(script_args.tokenized_mds_train, training_args, data_args, is_training=True)
189 |
190 | if training_args.do_eval:
191 | eval_dataset = {
192 |             x.split("/")[-1]: build_dataset([x], training_args, data_args, is_training=False)
193 | for x in script_args.tokenized_mds_validation
194 | }
195 |
196 | if training_args.do_predict:
197 | test_dataset = {
198 |             x.split("/")[-1]: build_dataset([x], training_args, data_args, is_training=False)
199 | for x in script_args.tokenized_mds_test
200 | }
201 |
202 | data_collator = DataCollator(tokenizer, data_args)
203 |
204 | # Initialize our Trainer
205 | trainer = Trainer(
206 | model=model,
207 | args=training_args,
208 | train_dataset=train_dataset if training_args.do_train else None,
209 | eval_dataset=eval_dataset if training_args.do_eval else None,
210 | tokenizer=tokenizer,
211 | data_collator=data_collator,
212 | )
213 |
214 | if trainer.is_fsdp_enabled:
215 | # Identify which modules have "_fsdp_wrap" attribute set to True and wrap these
216 | def fsdp_policy_fn(module):
217 | return getattr(module, "_fsdp_wrap", False)
218 |
219 | auto_wrap_policy = functools.partial(lambda_auto_wrap_policy,
220 | lambda_fn=fsdp_policy_fn)
221 | trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = auto_wrap_policy
222 |
223 | # Training
224 | if training_args.do_train:
225 | checkpoint = None
226 | if training_args.resume_from_checkpoint is not None:
227 | checkpoint = training_args.resume_from_checkpoint
228 | elif last_checkpoint is not None:
229 | checkpoint = last_checkpoint
230 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
231 | trainer.save_model()
232 |
233 | metrics = train_result.metrics
234 | trainer.log_metrics("train", metrics)
235 | trainer.save_metrics("train", metrics)
236 | trainer.save_state()
237 |
238 | if torch.distributed.is_initialized():
239 | torch.distributed.barrier()
240 |
241 |
242 | # Evaluation
243 | if training_args.do_eval:
244 | logger.info("*** Evaluate ***")
245 | metrics = trainer.evaluate(eval_dataset)
246 | trainer.log_metrics("eval", metrics)
247 | trainer.save_metrics("eval", metrics)
248 |
249 | # Predict
250 | if training_args.do_predict:
251 | logger.info("*** Predict ***")
252 | predictions = trainer.predict(test_dataset=test_dataset)
253 | print(predictions)
254 | predictions = predictions.predictions
255 | predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
256 | with open('dump.json', 'w') as f:
257 | print(json.dumps(predictions), file=f, flush=True)
258 |
259 |
260 | if __name__ == "__main__":
261 | main()
262 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ProLong
2 |
3 | [[Paper](https://arxiv.org/pdf/2410.02660)] [[HF Page](https://huggingface.co/collections/princeton-nlp/prolong-66c72d55d2051a86ac7bd7e4)]
4 |
5 | This is the homepage for **ProLong** (Princeton long-context language models).
6 |
7 | ProLong is a family of long-context models that are continued-trained and supervised fine-tuned from Llama-3-8B, with a maximum context window of 512K tokens. Our [main ProLong model](https://huggingface.co/princeton-nlp/Llama-3-8B-ProLong-512k-Instruct) is one of the best-performing long-context models at the 10B scale (evaluated by [HELMET](https://github.com/princeton-nlp/helmet)).
8 |
9 | To train this strong long-context model, we conduct thorough ablations on the long-context pre-training data, SFT data, and numerous other design choices. We demonstrate our findings in our paper, [How to Train Long-Context Language Models (Effectively)](https://arxiv.org/pdf/2410.02660).
10 |
11 | Authors: [Tianyu Gao](https://gaotianyu.xyz/about)\*, [Alexander Wettig](https://www.cs.princeton.edu/~awettig/)\*, [Howard Yen](https://howard-yen.github.io/), [Danqi Chen](https://www.cs.princeton.edu/~danqic/) (* equal contribution)
12 |
13 | ## Release Progress
14 |
15 |
16 | - [x] ProLong models
17 | - [x] ProLong data
18 | - [x] Pre-training and SFT code
19 | - [x] Sequence parallelism
20 |
21 | ## Model card
22 |
23 | Here are some quick facts about our main ProLong model: [princeton-nlp/Llama-3-8B-ProLong-512k-Instruct](https://huggingface.co/princeton-nlp/Llama-3-8B-ProLong-512k-Instruct).
24 | * Base model: [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
25 | * Long-context continued training: 20B tokens on 64K training data, and 20B tokens on 512K training data
26 | * Supervised fine-tuning (SFT): [UltraChat](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
27 | * Maximum context window: 512K tokens
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 | ProLong performance on HELMET averaged over 32K, 64K, and 128K lengths. All models are instruct models.
36 |
37 |
38 |
39 | ## Download the models and packed data
40 |
41 | All ProLong models are available on Hugging Face. They are all based on Llama-3-8B, so any code that supports Llama-3-8B is also compatible with ProLong models (see the loading sketch below the table).
42 |
43 | | Model | HF Link |
44 | |-------|---------|
45 | | ProLong-64k-Base | [princeton-nlp/Llama-3-8B-ProLong-64k-Base](https://huggingface.co/princeton-nlp/Llama-3-8B-ProLong-64k-Base) |
46 | | ProLong-64k-Instruct | [princeton-nlp/Llama-3-8B-ProLong-64k-Instruct](https://huggingface.co/princeton-nlp/Llama-3-8B-ProLong-64k-Instruct) |
47 | | ProLong-512k-Base | [princeton-nlp/Llama-3-8B-ProLong-512k-Base](https://huggingface.co/princeton-nlp/Llama-3-8B-ProLong-512k-Base) |
48 | | ⭐ ProLong-512k-Instruct | [princeton-nlp/Llama-3-8B-ProLong-512k-Instruct](https://huggingface.co/princeton-nlp/Llama-3-8B-ProLong-512k-Instruct) |
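Since these are standard Llama-3-architecture checkpoints, they load with the usual `transformers` API. A minimal sketch (the prompt and generation settings are illustrative; the flash-attention line can be dropped if flash-attn is not installed):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "princeton-nlp/Llama-3-8B-ProLong-512k-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",                       # keep the checkpoint's native precision
    device_map="auto",
    attn_implementation="flash_attention_2",  # recommended for long inputs
)

messages = [{"role": "user", "content": "Summarize the following report: ..."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
```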
49 |
50 | Our training data (packed and sampled version) are also available on Hugging Face (in [mosaicml-streaming](https://docs.mosaicml.com/projects/streaming/en/stable/index.html) format).
51 |
52 | | Data | HF Link |
53 | |------|---------|
54 | | Stage 1: 64K training data (40B tokens) | [princeton-nlp/prolong-data-64K](https://huggingface.co/datasets/princeton-nlp/prolong-data-64K) |
55 | | Stage 2: 512K training data (40B tokens)| [princeton-nlp/prolong-data-512K](https://huggingface.co/datasets/princeton-nlp/prolong-data-512K) |
56 | | SFT: UltraChat (1B tokens) | [princeton-nlp/prolong-ultrachat-64K](https://huggingface.co/datasets/princeton-nlp/prolong-ultrachat-64K) |
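One way to use the packed data with the training scripts in this repo (a hypothetical setup, assuming the dataset repo mirrors the domain directory names used in `train_64K.sh`) is to download it into the `datasets/` layout that the scripts point to by default:

```bash
# Download the packed 64K data into the layout expected by train_64K.sh
# (huggingface-cli ships with the huggingface_hub package)
huggingface-cli download princeton-nlp/prolong-data-64K \
    --repo-type dataset \
    --local-dir datasets/long-context-65536

# Launch stage-1 training; defaults can be overridden via environment variables (see the script)
DATASET=datasets/long-context-65536 sbatch train_64K.sh   # or `bash train_64K.sh` outside Slurm
```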
57 |
58 |
59 |
60 |
61 | ## Download and prepare raw data
62 |
63 | If you want to experiment with different data lengths or data mixtures,
64 | we also provide the (unpacked, unfiltered, but tokenized) raw data from each domain below.
65 | Due to the large size of the raw data, we store it on AWS S3. To download the data, you need an AWS account (with an access key and a secret key). **Note that data downloading will incur a charge on your AWS account**. According to [this S3 document](https://aws.amazon.com/s3/pricing/), each GB of data downloaded costs $0.09, and the first 100 GB is free. You can download the data using the following commands:
66 |
67 | ```bash
68 | # Install AWS CLI if you haven't already
69 | pip install awscli
70 |
71 | # Configure AWS CLI with your credentials (you will need an access key and a secret key from your AWS account)
72 | aws configure
73 |
74 | # Download the raw code repo data (concatenated by repo names from the stack v1)
75 | aws s3 sync s3://princeton-prolong/data_before_packing/code_repos/ /target/path/ --request-payer requester
76 | ```
77 |
78 | Below is the available unpacked raw data (tokenized with the Llama-3 tokenizer). All data is in the [mosaicml-streaming](https://docs.mosaicml.com/projects/streaming/en/stable/index.html) format, with three fields: `domain` (`str`), `input_ids` (`int32 numpy array`, the Llama-3 tokenized document with no BOS/EOS), and `length` (`int32`, number of tokens).
79 |
80 | | Data | Size | S3 path |
81 | |------|------|---------|
82 | | Code repos | 689 GB | s3://princeton-prolong/data_before_packing/code_repos/ |
83 | | Books (SlimPajama)| 180 GB| s3://princeton-prolong/data_before_packing/books/ |
84 | | FineWeb (sampled) | 864 GB | s3://princeton-prolong/data_before_packing/fineweb-2023-50/ |
85 | | FineWeb-edu (sampled) | 365 GB | s3://princeton-prolong/data_before_packing/fineweb-edu-100B/ |
86 | | OpenWebMath | 48 GB| s3://princeton-prolong/data_before_packing/openwebmath/ |
87 | | Wikipedia (Dolma) | 14 GB | s3://princeton-prolong/data_before_packing/wikipedia/ |
88 | | Textbooks | 1 GB | s3://princeton-prolong/data_before_packing/textbooks/ |
89 | | Tulu-v2 | 1 GB | s3://princeton-prolong/data_before_packing/tuluv2/ |
90 | | StackExchange (SlimPajama) | 135 GB | s3://princeton-prolong/data_before_packing/stackexchange/ |
91 | | ArXiv (SlimPajama) | 210 GB | s3://princeton-prolong/data_before_packing/arxiv/ |
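Using the same `LocalDataset` interface shown in the quick guide below, you can sanity-check a locally synced domain before filtering and packing, for example by counting how many documents reach a target length (the local path and the 64K threshold are illustrative):

```python
from streaming import LocalDataset

dataset = LocalDataset("/target/path")  # a domain synced locally via `aws s3 sync`

min_len = 65536  # 64K tokens
long_docs, long_tokens = 0, 0
for i in range(len(dataset)):
    length = dataset[i]["length"]
    if length >= min_len:
        long_docs += 1
        long_tokens += length

print(f"{long_docs}/{len(dataset)} documents have >= {min_len} tokens "
      f"({long_tokens / 1e9:.2f}B tokens in total)")
```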
92 |
93 |
94 |
95 | **A quick guide to mosaicml-streaming**
96 |
97 | Full documentation and installation guide can be found [here](https://docs.mosaicml.com/projects/streaming/en/stable/index.html).
98 |
99 | ```python
100 | >>> from streaming import LocalDataset
101 | >>> dataset = LocalDataset("path/to/dataset")
102 | >>> len(dataset) # number of samples
103 | >>> dataset[0] # allow random access, use like a dictionary/JSON
104 | {'domain': 'book', 'input_ids': array([ 1038, 19017, 2041, ..., 271, 12488, 220], dtype=uint32), 'length': 111200}
105 | ```
106 |
107 |
108 |
109 |
110 |
111 | ### How to filter and pack data
112 |
113 | We use our own [datatools](https://github.com/CodeCreator/datatools) (created by Alex and Tianyu) to filter (by length) and pack data. `datatools` is a versatile repo that supports tokenization, packing, and filtering from various raw formats (json, jsonl, Hugging Face datasets, mosaicml-streaming, etc.) and outputs the data in the mosaicml-streaming format.
114 |
115 | Example usage:
116 | ```bash
117 | pack