├── apps ├── __init__.py └── main │ ├── __init__.py │ └── configs │ ├── eval.yaml │ ├── llama_1B.yaml │ └── llama_7B.yaml ├── fixtures ├── test-cfgs │ ├── list.yaml │ ├── override.yaml │ ├── middle.yaml │ ├── top.yaml │ └── root.yaml ├── test_docs.jsonl ├── tokenizer_data.json ├── tokenizer_data_bpe_delim.json └── tokens_with_entropies.json ├── blt-figure.jpg ├── blt-figure.pdf ├── bytelatent ├── data │ ├── __init__.py │ ├── iterators │ │ ├── __init__.py │ │ ├── abstract_iterator.py │ │ ├── looping_iterator.py │ │ ├── test_limit_iterator.py │ │ ├── limit_iterator.py │ │ ├── sampling_iterator.py │ │ ├── dev_iterators.py │ │ ├── test_iters.py │ │ ├── test_arrow_iterator.py │ │ ├── preprocess_iterator.py │ │ └── sequence_iterator.py │ ├── test_data.py │ ├── data_types.py │ ├── ngram_processor.py │ └── file_util.py ├── model │ ├── __init__.py │ ├── latent_transformer.py │ └── utils.py ├── plotting │ ├── __init__.py │ ├── config_scaling_figures.yaml │ ├── config_entropy_figure.yaml │ ├── entropy_figure.py │ └── scaling_figures.py ├── preprocess │ ├── __init__.py │ ├── fsspec_target.py │ ├── data_pipeline.py │ ├── parallel_entropies.py │ └── preprocess_entropies.py ├── tokenizers │ ├── __init__.py │ ├── constants.py │ ├── abstract_tokenizer.py │ ├── test_blt_tokenizer.py │ ├── build_tokenizer.py │ ├── sentence_piece_tokenizer.py │ ├── tiktoken_tokenizer.py │ └── blt_tokenizer.py ├── .DS_Store ├── __init__.py ├── constants.py ├── print_config.py ├── templates │ └── stool_template.sh.jinja ├── iterate_data.py ├── entropy_model.py ├── configs │ ├── entropy_model.yaml │ └── debug.yaml ├── test_entropy_model.py ├── config_parser.py ├── norms.py ├── profiling.py ├── logger.py ├── optim.py ├── float8.py ├── test_config_parser.py ├── hf.py ├── stool.py ├── generate_blt.py └── metrics.py ├── dev └── lint.sh ├── .prettierrc ├── .github └── workflows │ ├── isort.yml │ └── black.yml ├── requirements.txt ├── download_blt_weights.py ├── setup.py ├── CONTRIBUTING.md ├── setup ├── create_env.sh ├── download_tokenizer.py └── download_prepare_hf_data.py ├── plot_data ├── scores.json └── entropy_figure.json ├── pyproject.toml ├── demo.py ├── CODE_OF_CONDUCT.md └── .gitignore /apps/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apps/main/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fixtures/test-cfgs/list.yaml: -------------------------------------------------------------------------------- 1 | [1, 2, 3] 2 | -------------------------------------------------------------------------------- /fixtures/test-cfgs/override.yaml: -------------------------------------------------------------------------------- 1 | a: 100 2 | -------------------------------------------------------------------------------- /blt-figure.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/blt/HEAD/blt-figure.jpg -------------------------------------------------------------------------------- /blt-figure.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/blt/HEAD/blt-figure.pdf -------------------------------------------------------------------------------- /bytelatent/data/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | -------------------------------------------------------------------------------- /bytelatent/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | -------------------------------------------------------------------------------- /bytelatent/plotting/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | -------------------------------------------------------------------------------- /bytelatent/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | -------------------------------------------------------------------------------- /bytelatent/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | -------------------------------------------------------------------------------- /dev/lint.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | isort . 3 | black . 4 | -------------------------------------------------------------------------------- /fixtures/test-cfgs/middle.yaml: -------------------------------------------------------------------------------- 1 | config: fixtures/test-cfgs/root.yaml 2 | b: 3 | y: 10 4 | -------------------------------------------------------------------------------- /fixtures/test-cfgs/top.yaml: -------------------------------------------------------------------------------- 1 | config: fixtures/test-cfgs/middle.yaml 2 | 3 | hello: world 4 | -------------------------------------------------------------------------------- /bytelatent/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/blt/HEAD/bytelatent/.DS_Store -------------------------------------------------------------------------------- /bytelatent/data/iterators/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | -------------------------------------------------------------------------------- /fixtures/test-cfgs/root.yaml: -------------------------------------------------------------------------------- 1 | seed: -1 2 | a: 1 3 | b: 4 | x: 0 5 | y: ??? 6 | z: ??? 7 | -------------------------------------------------------------------------------- /bytelatent/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | class ByteLatentError(Exception): 3 | pass 4 | -------------------------------------------------------------------------------- /.prettierrc: -------------------------------------------------------------------------------- 1 | { 2 | "overrides": [ 3 | { 4 | "files": "*.yaml", 5 | "options": { "tabWidth": 2 } 6 | } 7 | ] 8 | } 9 | -------------------------------------------------------------------------------- /fixtures/test_docs.jsonl: -------------------------------------------------------------------------------- 1 | {"sample_id": "0", "text": "test_0"} 2 | {"sample_id": "1", "text": "test_1"} 3 | {"sample_id": "2", "text": "test_2"} 4 | -------------------------------------------------------------------------------- /bytelatent/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | import os 3 | from pathlib import Path 4 | 5 | BLT_DATA = Path(os.environ.get("BLT_DATA", "data")) 6 | -------------------------------------------------------------------------------- /bytelatent/plotting/config_scaling_figures.yaml: -------------------------------------------------------------------------------- 1 | df_dir: /home/par/blt_df 2 | output_chart_dir: figures/ 3 | frame_files: 4 | ["4b_df.json", "500m_df.json", "scaling_arch_df.json", "scaling_df.json"] 5 | -------------------------------------------------------------------------------- /bytelatent/plotting/config_entropy_figure.yaml: -------------------------------------------------------------------------------- 1 | data_path: plot_data/entropy_figure.json 2 | chart_path: figures/entropy_figure.pdf 3 | threshold_override: 1.7171002626419067 4 | score_override_path: plot_data/scores.json 5 | -------------------------------------------------------------------------------- /.github/workflows/isort.yml: -------------------------------------------------------------------------------- 1 | name: Lint with isort 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | - uses: isort/isort-action@master 11 | -------------------------------------------------------------------------------- /bytelatent/tokenizers/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
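# NOTE: the constants below define the BLT byte-level ID scheme exercised by the
# tokenizer fixtures later in this listing (fixtures/tokenizer_data.json): each byte
# maps to byte_value + OFFSET, with BOS_ID / EOS_ID added at the ends. A minimal
# illustrative sketch, inferred from these constants and the fixtures rather than
# taken from the actual BltTokenizer implementation:
#
#     def encode_bytes(text: str) -> list[int]:
#         return [BOS_ID] + [b + OFFSET for b in text.encode("utf-8")] + [EOS_ID]
#
#     encode_bytes("Let")  # -> [1, 80, 105, 120, 2]; compare the leading ids in
#                          #    fixtures/tokenizer_data.json ([1, 80, 105, 120, ...])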
2 | 3 | 4 | SEP = " " 5 | BOS_ID: int = 1 6 | EOS_ID: int = 2 7 | PAD_ID: int = -1 8 | BOE_ID: int = 0 9 | BPE_ID: int = 3 10 | OFFSET: int = 4 11 | 12 | BYTE_UNITS: int = 256 13 | -------------------------------------------------------------------------------- /fixtures/tokenizer_data.json: -------------------------------------------------------------------------------- 1 | {"texts": ["Let's check if these tokenizers match!"], "tokens": [[1, 80, 105, 120, 43, 119, 36, 103, 108, 105, 103, 111, 36, 109, 106, 36, 120, 108, 105, 119, 105, 36, 120, 115, 111, 105, 114, 109, 126, 105, 118, 119, 36, 113, 101, 120, 103, 108, 37, 2]]} -------------------------------------------------------------------------------- /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: Lint with Black 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | - uses: psf/black@stable 11 | with: 12 | version: "24.8.0" 13 | -------------------------------------------------------------------------------- /fixtures/tokenizer_data_bpe_delim.json: -------------------------------------------------------------------------------- 1 | {"texts": ["Let's check if these tokenizers match!"], "tokens": [[1, 3, 80, 105, 120, 3, 43, 3, 119, 3, 36, 103, 108, 105, 103, 111, 3, 36, 109, 106, 3, 36, 120, 108, 105, 119, 105, 3, 36, 120, 115, 111, 105, 114, 3, 109, 126, 105, 118, 119, 3, 36, 113, 101, 120, 103, 108, 3, 37, 2]]} -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | omegaconf 3 | msgspec 4 | rouge-score 5 | sacrebleu 6 | sentencepiece 7 | tiktoken 8 | blobfile 9 | wandb 10 | viztracer 11 | lm-eval 12 | scipy 13 | pynvml 14 | datatrove 15 | orjson 16 | luigi 17 | pydantic 18 | altair 19 | submitit 20 | typer 21 | rich 22 | fsspec[full] 23 | huggingface-hub==0.30.* 24 | -------------------------------------------------------------------------------- /bytelatent/print_config.py: -------------------------------------------------------------------------------- 1 | from bytelatent.args import TrainArgs 2 | from bytelatent.config_parser import parse_args_to_pydantic_model 3 | 4 | 5 | def main(): 6 | train_args = parse_args_to_pydantic_model(TrainArgs) 7 | print(train_args.model_dump_json(indent=4)) 8 | 9 | 10 | if __name__ == "__main__": 11 | main() 12 | -------------------------------------------------------------------------------- /download_blt_weights.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import typer 4 | from huggingface_hub import snapshot_download 5 | 6 | 7 | def main(): 8 | if not os.path.exists("hf-weights"): 9 | os.makedirs("hf-weights") 10 | snapshot_download(f"facebook/blt", local_dir=f"hf-weights") 11 | 12 | 13 | if __name__ == "__main__": 14 | typer.run(main) 15 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="bytelatent", 5 | version="0.1.0", 6 | description="Byte Latent Transformer: Patches Scale Better Than Tokens", 7 | author="Meta Platforms, Inc. 
and affiliates.", 8 | url="https://github.com/facebookresearch/blt", 9 | packages=find_packages(), 10 | install_requires=["sentencepiece", "tiktoken", "xformers"], 11 | ) 12 | -------------------------------------------------------------------------------- /bytelatent/tokenizers/abstract_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | import abc 3 | 4 | 5 | class Tokenizer(abc.ABC): 6 | @abc.abstractmethod 7 | def encode(self, text: str, add_bos: bool, add_eos: bool): 8 | pass 9 | 10 | @abc.abstractmethod 11 | def decode(self, tokens: list[int]): 12 | pass 13 | 14 | @abc.abstractmethod 15 | def get_token_offsets( 16 | self, text: str, tokens: list[int] | None = None 17 | ) -> tuple[list[str], list[int]]: 18 | """Return the offsets of the tokens in the original text. Only used for evaluation.""" 19 | pass 20 | 21 | @abc.abstractmethod 22 | def get_vocab_size(self) -> int: 23 | pass 24 | -------------------------------------------------------------------------------- /apps/main/configs/eval.yaml: -------------------------------------------------------------------------------- 1 | name: "debug_evals" 2 | # ckpt_dir: !!CHANGETHIS!! 3 | # dump_dir: !!CHANGETHIS!! 4 | generator: 5 | max_tokens: 8192 6 | dtype: bf16 7 | temperature: 1.0 8 | top_p: 0.95 9 | harness: 10 | tasks: 11 | - hellaswag 12 | - task: boolq 13 | dataset_kwargs: 14 | trust_remote_code: true 15 | - task: nq_open 16 | num_fewshot: 5 17 | - piqa 18 | - task: social_iqa 19 | dataset_kwargs: 20 | trust_remote_code: true 21 | - triviaqa 22 | - winogrande 23 | - openbookqa 24 | - arc_easy 25 | - arc_challenge 26 | - race 27 | - commonsense_qa 28 | # - coqa 29 | - copa 30 | - gsm8k 31 | - bbh 32 | - mmlu 33 | - mmlu_pro 34 | validation: 35 | max_steps: 1000 36 | -------------------------------------------------------------------------------- /bytelatent/templates/stool_template.sh.jinja: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | {{ exclude }} 4 | {{ qos }} 5 | {{ account }} 6 | {{ constraint }} 7 | #SBATCH --job-name={{ name }} 8 | #SBATCH --nodes={{ nodes }} 9 | #SBATCH --gres=gpu:{{ ngpus }} 10 | #SBATCH --cpus-per-gpu={{ ncpu }} 11 | #SBATCH --time={{ time }} 12 | #SBATCH --partition={{ partition }} 13 | #SBATCH --mem={{ mem }} 14 | 15 | #SBATCH --output={{ dump_dir }}/logs/%j/%j.stdout 16 | #SBATCH --error={{ dump_dir }}/logs/%j/%j.stderr 17 | 18 | #SBATCH --open-mode=append 19 | #SBATCH --signal=USR2@120 20 | #SBATCH --distribution=block 21 | 22 | {% if use_conda %} 23 | # Mimic the effect of "conda init", which doesn't work for scripts 24 | eval "$({{ conda_exe }} shell.bash hook)" 25 | source activate {{ conda_env_path }} 26 | {% endif %} 27 | 28 | {{ go_to_code_dir }} 29 | 30 | export OMP_NUM_THREADS=1 31 | export LAUNCH_WITH="SBATCH" 32 | export DUMP_DIR={{ dump_dir }} 33 | srun {{ log_output }} -n {{ tasks }} -N {{ nodes_per_run }} {{ python_command }} -u -m {{ script }} config=$DUMP_DIR/base_config.yaml dump_dir=$DUMP_DIR name={{ name }} 34 | -------------------------------------------------------------------------------- /bytelatent/preprocess/fsspec_target.py: -------------------------------------------------------------------------------- 1 | import fsspec 2 | from luigi.target import FileSystem, FileSystemTarget 3 | 4 | 5 | class FSSpecFileSystem(FileSystem): 6 | def __init__(self, fs: fsspec.AbstractFileSystem): 7 | self.fs = fs 8 | 9 | def 
exists(self, path):
10 |         return self.fs.exists(path)
11 | 
12 |     def remove(self, path, recursive=True, skip_trash=True):
13 |         raise NotImplementedError()
14 | 
15 |     def isdir(self, path):
16 |         return self.fs.isdir(path)
17 | 
18 |     def listdir(self, path):
19 |         return self.fs.ls(path)
20 | 
21 | 
22 | class FSSpecTarget(FileSystemTarget):
23 |     def __init__(self, path, fs: fsspec.AbstractFileSystem | None = None):
24 |         self.path = path
25 |         if fs is None:
26 |             self.fsspec_fs = fsspec.filesystem("file")
27 |         else:
28 |             self.fsspec_fs = fs
29 |         self._fs = None
30 | 
31 |     @property
32 |     def fs(self):
33 |         if self._fs is None:
34 |             self._fs = FSSpecFileSystem(self.fsspec_fs)
35 |         return self._fs
36 | 
37 |     def open(self, mode):
38 |         # FSSpecFileSystem does not implement open(); use the underlying fsspec filesystem.
39 |         return self.fsspec_fs.open(self.path, mode=mode)
40 | 
--------------------------------------------------------------------------------
/bytelatent/data/iterators/abstract_iterator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | import abc
3 | from typing import Any, Generator, Generic, TypeVar
4 | 
5 | import pydantic
6 | 
7 | T = TypeVar("T")
8 | C = TypeVar("C")
9 | 
10 | 
11 | class StatefulIterator(Generic[T, C], abc.ABC):
12 | 
13 |     @abc.abstractmethod
14 |     def get_state(self) -> C:
15 |         pass
16 | 
17 |     @abc.abstractmethod
18 |     def create_iter(self) -> Generator[T, Any, None]:
19 |         pass
20 | 
21 | 
22 | class IteratorState(Generic[C]):
23 |     @abc.abstractmethod
24 |     def build(self) -> StatefulIterator[T, C]:
25 |         pass
26 | 
27 | 
28 | class PydanticIteratorState(pydantic.BaseModel, IteratorState):
29 |     model_config = pydantic.ConfigDict(extra="forbid")
30 | 
31 | 
32 | def get_state_and_refresh(iterator: StatefulIterator):
33 |     # Re-initializing the dataloader and iterator is necessary because get_state()
34 |     # on the multiprocess iterator shuts down its workers in order to persist
35 |     # state correctly, so iteration has to be restarted from the rebuilt state.
36 |     state = iterator.get_state()
37 |     data_loader = state.build()
38 |     py_iterator = data_loader.create_iter()
39 |     return state, data_loader, py_iterator
40 | 
--------------------------------------------------------------------------------
/bytelatent/data/iterators/looping_iterator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
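# NOTE: LoopingIterator below follows the StatefulIterator / IteratorState contract
# defined in abstract_iterator.py above. A minimal illustrative sketch of that
# contract for checkpoint-and-resume iteration (the CountIterator / CountState
# classes are hypothetical and not part of the repo):
#
#     class CountState(PydanticIteratorState):
#         position: int
#
#         def build(self) -> "CountIterator":
#             return CountIterator(position=self.position)
#
#     class CountIterator(StatefulIterator):
#         def __init__(self, position: int = 0):
#             self.position = position
#
#         def get_state(self) -> CountState:
#             return CountState(position=self.position)
#
#         def create_iter(self):
#             while True:
#                 value = self.position
#                 self.position += 1
#                 yield value
#
#     it = CountIterator()
#     gen = it.create_iter()
#     next(gen), next(gen)              # -> (0, 1)
#     resumed = it.get_state().build()  # persist the state, rebuild later
#     next(resumed.create_iter())       # -> 2, iteration resumes where it left off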
2 | 3 | from bytelatent.data.iterators.abstract_iterator import ( 4 | PydanticIteratorState, 5 | StatefulIterator, 6 | ) 7 | from bytelatent.data.iterators.arrow_iterator import ( 8 | ArrowFileIterator, 9 | ArrowFileIteratorState, 10 | ) 11 | 12 | 13 | class LoopingIteratorState(PydanticIteratorState): 14 | file_iterator_state: ArrowFileIteratorState 15 | epoch: int 16 | 17 | def build(self) -> "LoopingIterator": 18 | return LoopingIterator( 19 | file_iterator=self.file_iterator_state.build(), 20 | epoch=self.epoch, 21 | ) 22 | 23 | 24 | class LoopingIterator(StatefulIterator): 25 | def __init__(self, file_iterator: ArrowFileIterator, epoch: int = -1): 26 | self.file_iterator = file_iterator 27 | self.epoch = epoch 28 | 29 | def get_state(self): 30 | return LoopingIteratorState( 31 | file_iterator_state=self.file_iterator.get_state(), epoch=self.epoch 32 | ) 33 | 34 | def create_iter(self): 35 | while True: 36 | self.epoch += 1 37 | iterator = self.file_iterator.create_iter() 38 | yield from iterator 39 | -------------------------------------------------------------------------------- /bytelatent/iterate_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pyarrow 4 | import typer 5 | from rich.progress import track 6 | 7 | from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIteratorState 8 | from bytelatent.logger import init_logger 9 | 10 | 11 | def main( 12 | state_file: str, 13 | steps: int = 3_000, 14 | io_thread_count: int = 2, 15 | cpu_count: int = 2, 16 | log_freq: int = 100, 17 | ): 18 | init_logger() 19 | pyarrow.set_io_thread_count(io_thread_count) 20 | pyarrow.set_cpu_count(cpu_count) 21 | with open(state_file) as f: 22 | train_state = json.load(f) 23 | dl_state = MultiprocessIteratorState(**train_state["data_loader_state"]) 24 | packing_iterator_state = dl_state.base_iterator_state 25 | print("building") 26 | packing_iterator = packing_iterator_state.build() 27 | print("iter") 28 | batch_iter = packing_iterator.create_iter() 29 | print("looping") 30 | for i in track(range(steps)): 31 | _ = next(batch_iter) 32 | if i % log_freq == 0: 33 | print(pyarrow.default_memory_pool()) 34 | print(i) 35 | print(pyarrow.default_memory_pool()) 36 | 37 | 38 | if __name__ == "__main__": 39 | typer.run(main) 40 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to 2 | 3 | We want to make contributing to this project as easy and transparent as 4 | possible. 5 | 6 | ## Pull Requests 7 | 8 | We actively welcome your pull requests. 9 | 10 | 1. Fork the repo and create your branch from `main`. 11 | 2. If you've added code that should be tested, add tests. 12 | 3. If you've changed APIs, update the documentation. 13 | 4. Ensure the test suite passes. 14 | 5. Make sure your code lints. 15 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 16 | 17 | ## Contributor License Agreement ("CLA") 18 | 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Meta's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | 26 | We use GitHub issues to track public bugs. Please ensure your description is 27 | clear and has sufficient instructions to be able to reproduce the issue. 
28 | 29 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 30 | disclosure of security bugs. In those cases, please go through the process 31 | outlined on that page and do not file a public issue. 32 | 33 | ## License 34 | 35 | By contributing to BLT, you agree that your contributions will be licensed 36 | under the LICENSE file in the root directory of this source tree. 37 | -------------------------------------------------------------------------------- /bytelatent/tokenizers/test_blt_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | import json 3 | 4 | from bytelatent.constants import BLT_DATA 5 | from bytelatent.tokenizers.blt_tokenizer import BltTokenizer 6 | from bytelatent.tokenizers.build_tokenizer import TokenizerArgs 7 | 8 | 9 | def test_tokenizer_bytes(): 10 | with open("fixtures/tokenizer_data.json") as f: 11 | data = json.load(f) 12 | 13 | examples: list[str] = data["texts"] 14 | examples_tokens: list[list[int]] = data["tokens"] 15 | 16 | tokenizer = BltTokenizer(bpe_delim=False) 17 | for i in range(len(examples)): 18 | assert tokenizer.encode(examples[i]) == examples_tokens[i] 19 | 20 | 21 | def test_tokenizer_bpe(): 22 | with open("fixtures/tokenizer_data_bpe_delim.json") as f: 23 | data = json.load(f) 24 | 25 | examples: list[str] = data["texts"] 26 | examples_tokens: list[list[int]] = data["tokens"] 27 | 28 | tokenizer = BltTokenizer(bpe_delim=True) 29 | for i in range(len(examples)): 30 | assert tokenizer.encode(examples[i]) == examples_tokens[i] 31 | 32 | 33 | def test_build_tokenizer_from_args(): 34 | tokenizer_args = TokenizerArgs( 35 | name="blt", 36 | init_kwargs={ 37 | "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model" 38 | }, 39 | ) 40 | tokenizer = tokenizer_args.build() 41 | assert tokenizer.encode("test text") is not None 42 | -------------------------------------------------------------------------------- /setup/create_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | 4 | #SBATCH --job-name=env_creation 5 | #SBATCH --nodes=1 6 | #SBATCH --ntasks=1 7 | #SBATCH --gres=gpu:8 8 | #SBATCH --exclusive 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH --cpus-per-task=128 11 | #SBATCH --mem=0 12 | #SBATCH --time=01:00:00 13 | 14 | # Exit immediately if a command exits with a non-zero status 15 | set -e 16 | 17 | # Start timer 18 | start_time=$(date +%s) 19 | 20 | # Get the current date 21 | current_date=$(date +%y%m%d) 22 | 23 | # Create environment name with the current date 24 | env_prefix=blt_$current_date 25 | 26 | # Create the conda environment 27 | 28 | source $CONDA_ROOT/etc/profile.d/conda.sh 29 | conda create -n $env_prefix python=3.12 -y 30 | conda activate $env_prefix 31 | 32 | echo "Currently in env $(which python)" 33 | 34 | # Install packages 35 | pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 36 | pip install ninja 37 | pip install -v -U git+https://github.com/facebookresearch/xformers.git@de742ec3d64bd83b1184cc043e541f15d270c148 38 | pip install -r requirements.txt 39 | 40 | # End timer 41 | end_time=$(date +%s) 42 | 43 | # Calculate elapsed time in seconds 44 | elapsed_time=$((end_time - start_time)) 45 | 46 | # Convert elapsed time to minutes 47 | elapsed_minutes=$((elapsed_time / 60)) 48 | 49 | echo "Environment $env_prefix created and all packages installed successfully in $elapsed_minutes minutes!" 50 | -------------------------------------------------------------------------------- /bytelatent/data/iterators/test_limit_iterator.py: -------------------------------------------------------------------------------- 1 | from bytelatent.data.iterators.dev_iterators import BltTestIterator 2 | from bytelatent.data.iterators.limit_iterator import LimitIterator 3 | 4 | 5 | def test_limit_iterator(): 6 | total = 10 7 | limit = 5 8 | base_iterator = BltTestIterator(total=total) 9 | limit_iterator = LimitIterator(base_iterator, limit=limit) 10 | iterator = limit_iterator.create_iter() 11 | n = 0 12 | for example in iterator: 13 | assert example.sample_id == f"test_{n}" 14 | n += 1 15 | assert n == limit 16 | 17 | limit = 10 18 | base_iterator = BltTestIterator(total=total) 19 | limit_iterator = LimitIterator(base_iterator, limit=limit) 20 | iterator = limit_iterator.create_iter() 21 | n = 0 22 | for example in iterator: 23 | assert example.sample_id == f"test_{n}" 24 | n += 1 25 | assert n == limit == total 26 | 27 | limit = 20 28 | base_iterator = BltTestIterator(total=total) 29 | limit_iterator = LimitIterator(base_iterator, limit=limit) 30 | iterator = limit_iterator.create_iter() 31 | n = 0 32 | for example in iterator: 33 | assert example.sample_id == f"test_{n}" 34 | n += 1 35 | assert n == total 36 | 37 | limit = -1 38 | base_iterator = BltTestIterator(total=total) 39 | limit_iterator = LimitIterator(base_iterator, limit=limit) 40 | iterator = limit_iterator.create_iter() 41 | n = 0 42 | for example in iterator: 43 | assert example.sample_id == f"test_{n}" 44 | n += 1 45 | assert n == total 46 | -------------------------------------------------------------------------------- /bytelatent/entropy_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | import json 3 | import logging 4 | import os 5 | 6 | import torch 7 | 8 | from bytelatent.transformer import LMTransformer, LMTransformerArgs 9 | 10 | logger = logging.getLogger() 11 | 12 | 13 | def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"): 14 | with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr: 15 | reloaded = json.loads(fr.read()) 16 | 17 | torch.set_default_dtype(torch.bfloat16) 18 | model_params = reloaded["entropy_model"] 19 | logger.warning( 20 | "Update checkpoint to load attn and sliding window args from checkpoint" 21 | ) 22 | entropy_model_args = LMTransformerArgs( 23 | dim=model_params["dim"], 24 | n_layers=model_params["n_layers"], 25 | n_heads=model_params["n_heads"], 26 | max_seqlen=model_params["max_seqlen"], 27 | ffn_dim_multiplier=model_params["ffn_dim_multiplier"], 28 | vocab_size=model_params["vocab_size"], 29 | attn_bias_type="local_block_causal", 30 | attn_impl="xformers", 31 | sliding_window=512, 32 | ) 33 | entropy_model = LMTransformer(entropy_model_args) 34 | 35 | entropy_model.load_state_dict( 36 | torch.load(state_dict_path, map_location=device)["model"], strict=False 37 | ) 38 | entropy_model.to(device) 39 | entropy_model = entropy_model.eval() 40 | # no grads for the model: 41 | for param in entropy_model.parameters(): 42 | param.requires_grad = False 43 | return entropy_model, entropy_model_args 44 | -------------------------------------------------------------------------------- /plot_data/scores.json: -------------------------------------------------------------------------------- 1 | {"score": [3.3949153423309326, 2.1647746562957764, 2.3216569423675537, 2.8114914894104004, 1.505232334136963, 0.04055612534284592, 0.09150367230176926, 0.06008715182542801, 0.3453567624092102, 1.0483067035675049, 0.1967127025127411, 0.12737397849559784, 0.05923430994153023, 0.001597292022779584, 0.004362526815384626, 0.005547997076064348, 0.0011689786333590746, 0.0010273229563608766, 1.0228447914123535, 3.6863417625427246, 0.46605175733566284, 0.048645928502082825, 2.2544963359832764, 0.37329360842704773, 1.001160979270935, 2.9116122722625732, 1.8948925733566284, 1.4017235040664673, 0.3879640996456146, 0.2652309536933899, 1.780383825302124, 0.013964788988232613, 0.005456871818751097, 0.5426468253135681, 0.20666983723640442, 0.0051853349432349205, 0.0005802579107694328, 0.0007443525246344507, 0.0004390323010738939, 0.005452247802168131, 1.1932975053787231, 0.023798620328307152, 3.1230878829956055, 1.3915895223617554, 3.0489213466644287, 1.7018193006515503, 1.873910903930664, 1.4662408828735352, 0.004920408595353365, 0.02599342167377472, 0.6620859503746033, 0.31743818521499634, 2.8409600257873535, 1.1354060173034668, 0.0520976223051548, 0.3519965708255768, 0.40707266330718994, 2.5438783168792725, 1.3343133926391602, 0.023993035778403282, 3.445943832397461, 1.8542104959487915, 0.7849258780479431, 0.6848396062850952, 0.06938046962022781, 0.20923230051994324, 0.10084306448698044, 0.18334199488162994, 0.4126923978328705, 0.5505472421646118, 0.1042013093829155, 0.019447727128863335, 0.0014866517158225179, 0.0009848219342529774, 0.00021391961490735412, 0.007746236398816109, 0.00038792978739365935, 0.0007933690212666988, 1.2369810342788696, 0.4436197578907013, 4.6366687456611544e-05]} -------------------------------------------------------------------------------- /bytelatent/data/iterators/limit_iterator.py: -------------------------------------------------------------------------------- 1 | from pydantic import 
ConfigDict 2 | 3 | from bytelatent.data.iterators.abstract_iterator import ( 4 | PydanticIteratorState, 5 | StatefulIterator, 6 | ) 7 | from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState 8 | from bytelatent.data.iterators.dev_iterators import BltTestIteratorState 9 | 10 | 11 | class LimitIteratorState(PydanticIteratorState): 12 | model_config = ConfigDict(extra="forbid") 13 | base_iterator_state: ( 14 | BltTestIteratorState | ArrowFileIteratorState | PydanticIteratorState 15 | ) 16 | n_yielded: int 17 | limit: int 18 | 19 | def build(self) -> "LimitIterator": 20 | return LimitIterator( 21 | base_iterator=self.base_iterator_state.build(), 22 | n_yielded=self.n_yielded, 23 | limit=self.limit, 24 | ) 25 | 26 | 27 | class LimitIterator(StatefulIterator): 28 | def __init__(self, base_iterator: StatefulIterator, limit: int, n_yielded: int = 0): 29 | self.base_iterator = base_iterator 30 | self.n_yielded = n_yielded 31 | self.limit = limit 32 | 33 | def get_state(self): 34 | return LimitIteratorState( 35 | base_iterator_state=self.base_iterator.get_state(), 36 | n_yielded=self.n_yielded, 37 | limit=self.limit, 38 | ) 39 | 40 | def create_iter(self): 41 | iterator = self.base_iterator.create_iter() 42 | try: 43 | while self.n_yielded < self.limit or self.limit < 0: 44 | yield next(iterator) 45 | self.n_yielded += 1 46 | except StopIteration: 47 | pass 48 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "blt" 3 | version = "0.1.0" 4 | description = "BLT" 5 | readme = "README.md" 6 | requires-python = "==3.12.*" 7 | dependencies = [ 8 | "altair>=5.5.0", 9 | "datatrove>=0.5.0", 10 | "fsspec>=2024.6.1", 11 | "huggingface-hub==0.30.*", 12 | "jinja2>=3.1.6", 13 | "lm-eval>=0.4.8", 14 | "luigi>=3.6.0", 15 | "numpy>=2.1.2", 16 | "omegaconf>=2.3.0", 17 | "orjson>=3.10.18", 18 | "pydantic>=2.11.4", 19 | "pynvml>=12.0.0", 20 | "rich>=14.0.0", 21 | "s3fs>=2024.6.1", 22 | "scipy>=1.15.2", 23 | "sentencepiece>=0.2.0", 24 | "submitit>=1.5.2", 25 | "tiktoken>=0.8.0", 26 | "typer>=0.15.3", 27 | "viztracer>=1.0.3", 28 | "wandb>=0.19.10", 29 | "xformers", 30 | ] 31 | 32 | [[tool.uv.index]] 33 | name = "torch-nightly-cu121" 34 | url = "https://download.pytorch.org/whl/nightly/cu121" 35 | 36 | 37 | [tool.uv.sources] 38 | torch = {index = "torch-nightly-cu121"} 39 | xformers = { git = "https://github.com/facebookresearch/xformers.git", rev = "de742ec3d64bd83b1184cc043e541f15d270c148" } 40 | 41 | [dependency-groups] 42 | pre_build = [ 43 | "setuptools", 44 | "ninja", 45 | "torch==2.6.0.dev20241112", 46 | ] 47 | compile_xformers = ['xformers'] 48 | dev = [ 49 | "black==24.8.0", 50 | "ipython>=9.2.0", 51 | "isort>=6.0.1", 52 | "pudb>=2025.1", 53 | ] 54 | 55 | 56 | [tool.uv] 57 | no-build-isolation-package = ["xformers"] 58 | index-strategy = "unsafe-best-match" 59 | override-dependencies = ["torch==2.6.0.dev20241112"] 60 | 61 | [tool.isort] 62 | profile = "black" 63 | known_bytelatent = "bytelatent" 64 | known_apps = "apps" 65 | known_third_party = "wandb" 66 | sections = "FUTURE,STDLIB,THIRDPARTY,BYTELATENT,APPS,FIRSTPARTY,LOCALFOLDER" 67 | -------------------------------------------------------------------------------- /apps/main/configs/llama_1B.yaml: -------------------------------------------------------------------------------- 1 | # dump_dir: !!!CHANGE_THIS!!! 
2 | name: large_lm 3 | steps: 60_000 4 | probe_freq: null 5 | seed: 777 6 | 7 | optim: 8 | lr: 3e-3 9 | weight_decay: 0.033 10 | warmup: 5000 11 | lr_min_ratio: 0.000001 12 | clip: 1.0 13 | 14 | distributed: 15 | fsdp_type: full_shard 16 | compile: true 17 | model_dtype: bf16 18 | matmul_allow_tf32: false 19 | selective_activation_checkpointing: false 20 | tp_size: 1 21 | 22 | model: 23 | dim: 2048 24 | n_layers: 25 25 | n_heads: 16 26 | 27 | data: 28 | root_dir: data/shuffled 29 | sources: 30 | dclm_baseline_1.0: 100.0 31 | batch_size: 4 32 | prefetch_size: 1024 33 | seq_len: 4096 34 | n_views: 2 35 | load_async: true 36 | add_bos: true 37 | add_eos: true 38 | tokenizer: 39 | name: tiktoken 40 | path: tokenizers/cl_toplang_128k.tiktoken 41 | 42 | profiling: 43 | run: true 44 | mem_warmup: 0 45 | mem_steps: 4 46 | profile_warmup: 100 47 | profile_steps: 4 48 | 49 | checkpoint: 50 | dump: 51 | every: 2500 52 | keep: 3 53 | eval: 54 | every: 5000 55 | keep: -1 56 | 57 | logging: 58 | freq: 1 59 | 60 | async_eval_gpus: 8 61 | eval: 62 | harness: 63 | tasks: 64 | - hellaswag 65 | - task: boolq 66 | dataset_kwargs: 67 | trust_remote_code: true 68 | - piqa 69 | - task: social_iqa 70 | dataset_kwargs: 71 | trust_remote_code: true 72 | - winogrande 73 | - openbookqa 74 | - arc_easy 75 | - arc_challenge 76 | - race 77 | - commonsense_qa 78 | - copa 79 | # - coqa 80 | # - task: nq_open 81 | # num_fewshot: 5 82 | # - triviaqa 83 | validation: 84 | max_steps: 1000 85 | generator: 86 | max_tokens: 16384 87 | dtype: bf16 88 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import typer 5 | 6 | from bytelatent.distributed import DistributedArgs, setup_torch_distributed 7 | from bytelatent.generate import load_consolidated_model_and_tokenizer 8 | from bytelatent.generate_blt import generate_nocache 9 | from bytelatent.model.blt import ByteLatentTransformer 10 | from bytelatent.tokenizers.blt_tokenizer import BltTokenizer 11 | 12 | 13 | def main(prompt: str, model_name: str = "blt-1b"): 14 | assert model_name in ["blt-1b", "blt-7b"] 15 | model_name = model_name.replace("-", "_") 16 | distributed_args = DistributedArgs() 17 | distributed_args.configure_world() 18 | if not torch.distributed.is_initialized(): 19 | setup_torch_distributed(distributed_args) 20 | checkpoint_path = os.path.join("hf-weights", model_name) 21 | print(f"Loading BLT model: {model_name}") 22 | model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer( 23 | checkpoint_path, 24 | ) 25 | assert isinstance(model, ByteLatentTransformer) 26 | assert isinstance(tokenizer, BltTokenizer) 27 | patcher_args = train_cfg.data.patcher_args.model_copy(deep=True) 28 | patcher_args.realtime_patching = True 29 | print("Loading entropy model and patcher") 30 | patcher_args.entropy_model_checkpoint_dir = os.path.join( 31 | "hf-weights", "entropy_model" 32 | ) 33 | patcher = patcher_args.build() 34 | prompts = [prompt] 35 | outputs = generate_nocache( 36 | prompts, model=model, tokenizer=tokenizer, patcher=patcher 37 | ) 38 | text_outputs = [tokenizer.decode(t) for t in outputs] 39 | for p, t in zip(prompts, text_outputs): 40 | print(f'Prompt: "{p}" Completion: "{t}"') 41 | print() 42 | 43 | 44 | if __name__ == "__main__": 45 | typer.run(main) 46 | -------------------------------------------------------------------------------- /bytelatent/configs/entropy_model.yaml: 
-------------------------------------------------------------------------------- 1 | # Template config, need to change dump_dir, data.root_dir and tokenizer.path 2 | # Evals can be activated by uncommenting its config 3 | # python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest 4 | 5 | dump_dir: /tmp/blt-entropy 6 | name: "debug" 7 | steps: 100_000 8 | max_steps: null 9 | probe_freq: null 10 | seed: 777 11 | optim: 12 | lr: 4e-04 13 | warmup: 500 14 | lr_min_ratio: 0.1 15 | clip: 10.0 16 | 17 | distributed: 18 | fsdp_type: full_shard 19 | model_dtype: bf16 20 | matmul_allow_tf32: false 21 | selective_activation_checkpointing: false 22 | tp_size: 1 23 | 24 | train_entropy_model: true 25 | model: null 26 | entropy_model: 27 | dim: 768 28 | n_layers: 14 29 | n_heads: 12 30 | max_seqlen: 8192 31 | # vocab_size: -1 32 | vocab_size: 260 33 | ffn_dim_multiplier: 1.0 34 | sliding_window: 512 35 | attn_bias_type: "local_block_causal" 36 | attn_impl: "xformers" 37 | 38 | data: 39 | root_dir: ??? 40 | sources: 41 | dclm_baseline_1.0: 1.0 42 | batch_size: 2 43 | prefetch_size: 64 44 | # seqlen is in terms of patches and 45 | # max_encoder_seq_length is in terms of bytes. 46 | # For entropy model, these are the same since 1 patch=1 byte 47 | seq_len: 8192 48 | max_encoder_seq_length: 8192 49 | load_async: true 50 | preprocess_dir: ??? 51 | # We don't need patches for this model 52 | add_patches: false 53 | patcher_args: 54 | # This doesn't matter since byte entropy model doesn't use patching, 55 | # so pick the most efficient, so static 56 | patching_mode: byte 57 | tokenizer_args: 58 | name: blt 59 | 60 | profiling: 61 | run: false 62 | 63 | checkpoint: 64 | dump: 65 | every: 500 66 | keep: 3 67 | eval: 68 | every: 1000 69 | keep: -1 70 | 71 | logging: 72 | freq: 10 73 | 74 | eval_on_gpus: 8 75 | eval: null 76 | -------------------------------------------------------------------------------- /bytelatent/data/test_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import pytest 5 | from omegaconf import OmegaConf 6 | 7 | from bytelatent.args import TrainArgs 8 | from bytelatent.constants import BLT_DATA 9 | 10 | 11 | def get_test_config(): 12 | if "BLT_INTERNAL" in os.environ: 13 | internal_dir = os.environ["BLT_INTERNAL"] 14 | else: 15 | internal_dir = "../internal-blt/configs" 16 | test_config = os.path.join(internal_dir, "tests.yaml") 17 | return test_config 18 | 19 | 20 | @pytest.mark.skipif( 21 | not os.path.exists(get_test_config()), 22 | reason="Skipping since internal config is missing", 23 | ) 24 | def test_first_batch_matches(): 25 | test_config_path = get_test_config() 26 | default_cfg = OmegaConf.create(TrainArgs().model_dump()) 27 | file_cfg = OmegaConf.load(test_config_path) 28 | merged_cfg = OmegaConf.merge(default_cfg, file_cfg) 29 | merged_cfg = OmegaConf.to_container(merged_cfg, resolve=True, throw_on_missing=True) 30 | train_args = TrainArgs.model_validate(merged_cfg) 31 | # MP doesn't work with async very well, but it doesn't change logic 32 | train_args.data.load_async = False 33 | 34 | # Test data created by pickling first batch in train loop then exiting 35 | with open(os.path.join(BLT_DATA, "fixtures", "first_batch_0.pickle"), "rb") as f: 36 | first_batch = pickle.load(f) 37 | 38 | # Emulate 1 node, 8 gpu training 39 | data_loader = train_args.data.build_from_rank(0, 8) 40 | batch_iterator = data_loader.create_iter() 41 | print("Getting first 
batch") 42 | batch = next(batch_iterator) 43 | assert (batch.x == first_batch.x).all() 44 | assert (batch.y == first_batch.y).all() 45 | assert (batch.mask == first_batch.mask).all() 46 | assert (batch.patch_lengths == first_batch.patch_lengths).all() 47 | assert batch.ngram_ids is None and first_batch.ngram_ids is None 48 | assert batch.is_final == False and batch.is_final == False 49 | -------------------------------------------------------------------------------- /apps/main/configs/llama_7B.yaml: -------------------------------------------------------------------------------- 1 | #python -m lingua.stool config=apps/main/configs/llama2_7B.yaml nodes=32 account=fair_amaia_cw_codegen qos=lowest 2 | # dump_dir: !!!CHANGE_THIS!!! 3 | name: "7b_baseline" 4 | steps: 100_000 5 | grad_acc_steps: 1 6 | probe_freq: 100 7 | 8 | seed: 777 9 | optim: 10 | lr: 1.0e-3 11 | weight_decay: 0.1 12 | warmup: 2000 13 | lr_min_ratio: 0.000001 14 | clip: 1.0 15 | 16 | distributed: 17 | fsdp_type: full_shard 18 | compile: true 19 | model_dtype: bf16 20 | matmul_allow_tf32: false 21 | selective_activation_checkpointing: false 22 | tp_size: 1 23 | 24 | model: 25 | dim: 4096 26 | n_layers: 32 27 | n_heads: 32 28 | rope_theta: 100_000 29 | ffn_dim_multiplier: 1.0 30 | multiple_of: 256 31 | 32 | data: 33 | root_dir: data/shuffled 34 | sources: 35 | dclm_baseline_1.0: 1.0 36 | batch_size: 2 37 | prefetch_size: 1024 38 | seq_len: 4096 39 | n_views: 2 40 | load_async: true 41 | tokenizer: 42 | name: tiktoken 43 | path: tokenizers/cl_toplang_128k.tiktoken 44 | 45 | profiling: 46 | run: true 47 | mem_warmup: 0 48 | mem_steps: 4 49 | profile_warmup: 100 50 | profile_steps: 4 51 | 52 | checkpoint: 53 | dump: 54 | every: 10000 55 | keep: -1 56 | eval: 57 | every: 1000 58 | keep: 3 59 | 60 | logging: 61 | freq: 1 62 | 63 | async_eval_gpus: 8 64 | eval: 65 | dataset_dir: datasets/eval 66 | harness: 67 | tasks: 68 | - hellaswag 69 | - task: boolq 70 | dataset_kwargs: 71 | trust_remote_code: true 72 | - piqa 73 | - task: social_iqa 74 | dataset_kwargs: 75 | trust_remote_code: true 76 | - winogrande 77 | - openbookqa 78 | - arc_easy 79 | - arc_challenge 80 | - race 81 | - commonsense_qa 82 | # - coqa 83 | - copa 84 | - mmlu 85 | - mmlu_pro 86 | # - task: nq_open 87 | # num_fewshot: 5 88 | # - triviaqa 89 | # - gsm8k 90 | # - bbh 91 | validation: 92 | max_steps: 1000 93 | generator: 94 | max_tokens: 8192 95 | dtype: bf16 96 | -------------------------------------------------------------------------------- /setup/download_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | import argparse 3 | import os 4 | from typing import Optional 5 | 6 | from requests.exceptions import HTTPError 7 | 8 | TOKENIZER = { 9 | "llama2": ("meta-llama/Llama-2-7b", "tokenizer.model"), 10 | "llama3": ("meta-llama/Meta-Llama-3-8B", "original/tokenizer.model"), 11 | "gemma": ("google/gemma-2-9b", "tokenizer.model"), 12 | } 13 | 14 | 15 | def main(tokenizer_name: str, path_to_save: str, api_key: Optional[str] = None): 16 | if tokenizer_name in TOKENIZER: 17 | repo_id, filename = TOKENIZER[tokenizer_name] 18 | 19 | from huggingface_hub import hf_hub_download 20 | 21 | try: 22 | hf_hub_download( 23 | repo_id=repo_id, 24 | filename=filename, 25 | local_dir=path_to_save, 26 | local_dir_use_symlinks=False, 27 | token=api_key if api_key else None, 28 | ) 29 | except HTTPError as e: 30 | if e.response.status_code == 401: 31 | print( 32 | "You need to pass a valid `--hf_token=...` to download private checkpoints." 33 | ) 34 | else: 35 | raise e 36 | else: 37 | from tiktoken import get_encoding 38 | 39 | if "TIKTOKEN_CACHE_DIR" not in os.environ: 40 | os.environ["TIKTOKEN_CACHE_DIR"] = path_to_save 41 | try: 42 | get_encoding(tokenizer_name) 43 | except ValueError: 44 | print( 45 | f"Tokenizer {tokenizer_name} not found. Please check the name and try again." 46 | ) 47 | 48 | 49 | if __name__ == "__main__": 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument("tokenizer_name", type=str) 52 | parser.add_argument("tokenizer_dir", type=str, default=8) 53 | parser.add_argument("--api_key", type=str, default="") 54 | args = parser.parse_args() 55 | 56 | main( 57 | tokenizer_name=args.tokenizer_name, 58 | path_to_save=args.tokenizer_dir, 59 | api_key=args.api_key, 60 | ) 61 | -------------------------------------------------------------------------------- /bytelatent/tokenizers/build_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | import logging 3 | from typing import Any 4 | 5 | from pydantic import BaseModel 6 | 7 | from bytelatent.tokenizers.blt_tokenizer import BltTokenizer 8 | from bytelatent.tokenizers.tiktoken_tokenizer import TikTokenTokenizer 9 | 10 | try: 11 | from sentencepiece import SentencePieceProcessor 12 | 13 | has_sp = True 14 | except ImportError: 15 | has_sp = False 16 | 17 | try: 18 | import tiktoken 19 | from tiktoken.load import load_tiktoken_bpe 20 | 21 | has_tiktoken = True 22 | except ImportError: 23 | has_tiktoken = False 24 | 25 | from bytelatent.tokenizers.abstract_tokenizer import Tokenizer 26 | from bytelatent.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | class MockTokenizer(Tokenizer): 32 | n_words: int = 256 33 | 34 | def encode(self, text: str, add_bos: bool, add_eos: bool): 35 | return text 36 | 37 | def decode(self, tokens): 38 | raise NotImplementedError() 39 | 40 | def get_token_offsets( 41 | self, text: str, tokens: list[int] | None = None 42 | ) -> tuple[list[str]]: 43 | raise NotImplementedError() 44 | 45 | 46 | class TokenizerArgs(BaseModel): 47 | name: str = "bytes" 48 | init_kwargs: dict[str, Any] | None = None 49 | 50 | def build(self) -> Tokenizer: 51 | if self.init_kwargs is None: 52 | init_kwargs = {} 53 | else: 54 | init_kwargs = self.init_kwargs 55 | if self.name == "blt": 56 | return BltTokenizer(**init_kwargs) 57 | elif self.name == "mock": 58 | return MockTokenizer(**init_kwargs) 59 | elif self.name == "sp": 60 | assert has_sp, "sentencepiece not installed" 61 | return SentencePieceTokenizer(**init_kwargs) 62 | elif self.name == "tiktoken": 63 | assert has_tiktoken, "tiktoken not installed" 64 | return TikTokenTokenizer(**init_kwargs) 65 | else: 66 | raise NotImplementedError(f"{self.name} tokenizer type is not implemented") 67 | -------------------------------------------------------------------------------- /bytelatent/tokenizers/sentence_piece_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | import logging 3 | import os 4 | 5 | try: 6 | from sentencepiece import SentencePieceProcessor 7 | 8 | has_sp = True 9 | except ImportError: 10 | has_sp = False 11 | 12 | from bytelatent.tokenizers.abstract_tokenizer import Tokenizer 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class SentencePieceTokenizer(Tokenizer): 18 | def __init__( 19 | self, model_path: str, add_bos: bool = True, add_eos: bool = True 20 | ) -> None: 21 | assert os.path.isfile(model_path), model_path 22 | self.sp_model = SentencePieceProcessor(model_file=model_path) 23 | 24 | logger.info(f"Reloaded SentencePiece model from {model_path}") 25 | 26 | # BOS / EOS token IDs 27 | self.n_words: int = self.sp_model.vocab_size() 28 | self.bos_id: int = self.sp_model.bos_id() 29 | self.eos_id: int = self.sp_model.eos_id() 30 | self.pad_id: int = self.sp_model.pad_id() 31 | self.add_bos = add_bos 32 | self.add_eos = add_eos 33 | logger.info( 34 | f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" 35 | ) 36 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() 37 | 38 | def get_vocab_size(self) -> int: 39 | return self.n_words 40 | 41 | def encode(self, s: str, add_bos: bool | None = None, add_eos: bool | None = None): 42 | if add_bos is None: 43 | add_bos = self.add_bos 44 | 45 | if add_eos is None: 46 | add_eos = self.add_eos 47 | assert type(s) is str 48 | tokens = ( 49 | [self.bos_id] * add_bos + self.sp_model.encode(s) + [self.eos_id] * add_eos 50 | ) 51 | return tokens 52 | 53 | def decode(self, tokens: list[int]): 54 | return self.sp_model.decode(tokens) 55 | 56 | def get_token_offsets( 57 | self, text: str, tokens: list[int] | None = None 58 | ) -> tuple[list[str], list[int]]: 59 | pieces = self.sp_model.encode_as_immutable_proto(text).pieces 60 | substrs = [p.surface for p in pieces] 61 | offsets = [p.begin for p in pieces] 62 | return substrs, offsets 63 | -------------------------------------------------------------------------------- /bytelatent/test_entropy_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | import os
3 | 
4 | import torch
5 | 
6 | from bytelatent.constants import BLT_DATA
7 | from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
8 | from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
9 | from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum, entropy
10 | from bytelatent.entropy_model import load_entropy_model
11 | from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
12 | 
13 | ENTROPY_MODEL = "transformer_100m"
14 | ARROW_TEST_DATA = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow")
15 | 
16 | 
17 | def test_entropy_model():
18 |     initial_state = ArrowFileIteratorState(
19 |         file_path=None,
20 |         num_workers=1,
21 |         worker_id=0,
22 |         preprocess_dir=None,
23 |         entropy_model_name=ENTROPY_MODEL,
24 |         dataset_files=[ARROW_TEST_DATA],
25 |         row_num=0,
26 |         arrow_batch_size=100,
27 |         s3_profile=None,
28 |         file_format="arrow",
29 |     )
30 |     arrow_file = initial_state.build()
31 |     tokenizer_args = TokenizerArgs(
32 |         name="blt",
33 |         init_kwargs={
34 |             "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
35 |         },
36 |     )
37 |     entropy_model, _ = load_entropy_model(
38 |         BLT_DATA / "checkpoint_0100000_consolidated",
39 |         os.path.join(BLT_DATA, "entropy_model.pth"),
40 |     )
41 |     # load_entropy_model returns (model, args); call .cuda() on the model,
42 |     # not on the returned tuple.
43 |     entropy_model = entropy_model.cuda()
44 |     preprocess_iter = PreprocessIterator(
45 |         arrow_file,
46 |         tokenizer_args=tokenizer_args,
47 |         patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.entropy),
48 |         add_patches=False,
49 |     )
50 |     for example in preprocess_iter.create_iter():
51 |         tokens = torch.tensor(example.tokens).unsqueeze(0)
52 |         expected_entropies = torch.tensor(example.entropies).unsqueeze(0)
53 |         preds = entropy_model(tokens.cuda())
54 |         pred_entropies = entropy(preds)
55 |         assert pred_entropies.shape == expected_entropies.shape
56 |         assert torch.allclose(
57 |             pred_entropies.cpu(), expected_entropies, rtol=1.0, atol=3.5
58 |         )
59 |         break
60 | 
--------------------------------------------------------------------------------
/bytelatent/preprocess/data_pipeline.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | import subprocess 3 | from pathlib import Path 4 | 5 | import luigi 6 | 7 | # CHANGEME: Change this to point to your data 8 | BASE_DIR = Path("datasets") 9 | DATASETS = ["dclm"] 10 | TARGET_DIR = Path("entropy_preprocess") 11 | 12 | SHARD_SCRIPT = """split -C 2500m -d {source} {destination}.shard_""" 13 | 14 | 15 | def list_dataset_shards(dataset: str): 16 | dataset_dir = BASE_DIR / dataset 17 | return list(dataset_dir.glob("*.chunk.*.jsonl")) 18 | 19 | 20 | class ChunkFile(luigi.ExternalTask): 21 | file = luigi.Parameter() 22 | 23 | def output(self): 24 | return luigi.LocalTarget(self.file) 25 | 26 | 27 | class ShardDatasetChunk(luigi.Task): 28 | dataset_name = luigi.Parameter() 29 | chunk_file = luigi.Parameter() 30 | 31 | def _chunk_filename(self): 32 | return Path(self.chunk_file).name 33 | 34 | def requires(self): 35 | return ChunkFile(self.chunk_file) 36 | 37 | def run(self): 38 | destination_dir = TARGET_DIR / str(self.dataset_name) 39 | destination_dir.mkdir(parents=True, exist_ok=True) 40 | destination = destination_dir / self._chunk_filename() 41 | subprocess.check_output( 42 | SHARD_SCRIPT.format(source=str(self.chunk_file), destination=destination), 43 | shell=True, 44 | ) 45 | ( 46 | Path(TARGET_DIR) 47 | / str(self.dataset_name) 48 | / f"{self._chunk_filename()}.shard.COMPLETE" 49 | ).touch() 50 | 51 | def output(self): 52 | return luigi.LocalTarget( 53 | TARGET_DIR 54 | / str(self.dataset_name) 55 | / f"{self._chunk_filename()}.shard.COMPLETE" 56 | ) 57 | 58 | 59 | class ShardDataset(luigi.WrapperTask): 60 | dataset_name = luigi.Parameter() 61 | 62 | def requires(self): 63 | for f in list_dataset_shards(self.dataset_name): 64 | yield ShardDatasetChunk(dataset_name=self.dataset_name, chunk_file=str(f)) 65 | 66 | 67 | class ShardAllDatasets(luigi.WrapperTask): 68 | def requires(self): 69 | for d in DATASETS: 70 | yield ShardDataset(dataset_name=d) 71 | 72 | 73 | if __name__ == "__main__": 74 | luigi.build([ShardAllDatasets()], local_scheduler=True, workers=128) 75 | -------------------------------------------------------------------------------- /bytelatent/data/iterators/sampling_iterator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | from typing import Any 3 | 4 | import numpy as np 5 | from pydantic import ConfigDict 6 | 7 | from bytelatent.data.iterators.abstract_iterator import ( 8 | PydanticIteratorState, 9 | StatefulIterator, 10 | ) 11 | from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState 12 | 13 | 14 | class SamplingIteratorState(PydanticIteratorState): 15 | model_config = ConfigDict(extra="forbid") 16 | rng_state: dict[str, Any] 17 | source_to_weight: dict[str, float] 18 | source_to_iterator_state: dict[str, SequenceIteratorState] 19 | 20 | def build(self) -> "SamplingIterator": 21 | return SamplingIterator( 22 | rng_state=self.rng_state, 23 | source_to_weight=self.source_to_weight, 24 | source_to_iterator={ 25 | source: state.build() 26 | for source, state in self.source_to_iterator_state.items() 27 | }, 28 | ) 29 | 30 | 31 | class SamplingIterator(StatefulIterator): 32 | def __init__( 33 | self, 34 | *, 35 | rng_state: dict[str, Any], 36 | source_to_weight: dict[str, float], 37 | source_to_iterator: dict[str, StatefulIterator], 38 | ): 39 | self.rng = np.random.default_rng() 40 | self.rng.bit_generator.state = rng_state 41 | self.source_to_weight = source_to_weight 42 | self.source_to_iterator = source_to_iterator 43 | 44 | def get_state(self) -> SamplingIteratorState: 45 | return SamplingIteratorState( 46 | rng_state=self.rng.bit_generator.state, 47 | source_to_weight=self.source_to_weight, 48 | source_to_iterator_state={ 49 | source: iterator.get_state() 50 | for source, iterator in self.source_to_iterator.items() 51 | }, 52 | ) 53 | 54 | def create_iter(self): 55 | n_sources = len(self.source_to_weight) 56 | possible_sources = [] 57 | weights = [] 58 | for source, w in self.source_to_weight.items(): 59 | possible_sources.append(source) 60 | weights.append(w) 61 | 62 | source_to_python_iter = { 63 | source: self.source_to_iterator[source].create_iter() 64 | for source in possible_sources 65 | } 66 | while True: 67 | norm_weights = np.array(weights) / np.array(weights).sum() 68 | source_choice = possible_sources[self.rng.choice(n_sources, p=norm_weights)] 69 | yield next(source_to_python_iter[source_choice]) 70 | -------------------------------------------------------------------------------- /bytelatent/data/iterators/dev_iterators.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pydantic import ConfigDict 3 | 4 | from bytelatent.data.data_types import BltExample 5 | from bytelatent.data.iterators.abstract_iterator import ( 6 | PydanticIteratorState, 7 | StatefulIterator, 8 | ) 9 | 10 | 11 | class BltTestIteratorState(PydanticIteratorState): 12 | model_config = ConfigDict(extra="forbid") 13 | position: int 14 | total: int 15 | 16 | def build(self): 17 | blt_iter = BltTestIteratorState(total=self.total) 18 | blt_iter.position = self.position 19 | return blt_iter 20 | 21 | 22 | class BltTestIterator(StatefulIterator): 23 | def __init__(self, total: int): 24 | self.position = 0 25 | self.total = total 26 | 27 | def get_state(self): 28 | return BltTestIteratorState(position=self.position, total=self.total) 29 | 30 | def create_iter(self): 31 | for i in range(self.total): 32 | self.position += 1 33 | yield BltExample( 34 | sample_id=f"test_{i}", 35 | text=f"This is some test {i} text.", 36 | tokens=None, 37 | mask=None, 38 | entropies=None, 39 | patch_lengths=None, 40 | ) 41 | 42 | 43 | class BltTestWithEntropiesIteratorState(PydanticIteratorState): 44 | model_config = ConfigDict(extra="forbid") 45 | position: int 
46 |     total: int
47 | 
48 |     def build(self):
49 |         blt_iter = BltTestWithEntropiesIterator(total=self.total)
50 |         blt_iter.position = self.position
51 |         return blt_iter
52 | 
53 | 
54 | class BltTestWithEntropiesIterator(StatefulIterator):
55 |     def __init__(self, total: int):
56 |         self.position = 0
57 |         self.total = total
58 | 
59 |     def get_state(self):
60 |         return BltTestWithEntropiesIteratorState(position=self.position, total=self.total)
61 | 
62 |     def create_iter(self):
63 |         text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
64 |         df = pd.read_json("fixtures/tokens_with_entropies.json")
65 |         tokens = df["token_ids"].tolist()
66 |         entropies = df["entropies"].tolist()
67 |         # BOS and EOS
68 |         assert len(tokens) == len(text) + 2
69 |         for i in range(self.total):
70 |             self.position += 1
71 |             yield BltExample(
72 |                 sample_id=f"test_{i}",
73 |                 text=text,
74 |                 tokens=tokens,
75 |                 mask=[True] * len(tokens),
76 |                 entropies=entropies,
77 |                 patch_lengths=None,
78 |             )
79 | 
--------------------------------------------------------------------------------
/bytelatent/configs/debug.yaml:
--------------------------------------------------------------------------------
1 | # Template config, need to change dump_dir, data.root_dir and tokenizer.path
2 | # Evals can be activated by uncommenting its config
3 | # python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest
4 | 
5 | dump_dir: /tmp/
6 | name: "debug"
7 | steps: 100_000
8 | probe_freq: null
9 | seed: 777
10 | optim:
11 |   lr: 4e-04
12 |   warmup: 500
13 |   lr_min_ratio: 0.1
14 |   clip: 10.0
15 | 
16 | distributed:
17 |   fsdp_type: full_shard
18 |   model_dtype: bf16
19 |   matmul_allow_tf32: false
20 |   selective_activation_checkpointing: false
21 |   tp_size: 1
22 | 
23 | model:
24 |   n_heads: 8
25 |   dim: 512
26 |   vocab_size: 260
27 |   dim_token: 256
28 |   patch_size: 6
29 |   patching_mode: "space"
30 |   tie_local_encoder_decoder_logits: false
31 |   patch_in_forward: false
32 |   max_encoder_seq_length: 12288
33 |   pad_to_max_length: true
34 |   patching_threshold: 3.1439168453216553
35 |   encoder_hash_byte_group_size: [4]
36 |   encoder_hash_byte_group_vocab: 50002
37 |   encoder_hash_byte_group_nb_functions: 3
38 |   encoder_enable_byte_ngrams: false
39 |   cross_attn_encoder: true # assuming cross_attention is true
40 |   cross_attn_decoder: true # assuming cross_attention is true
41 |   cross_attn_window_encoder: 512
42 |   cross_attn_window_decoder: 512
43 |   dim_local_encoder: 256
44 |   dim_local_decoder: 256
45 |   cross_attn_k: 8
46 |   cross_attn_nheads: 4
47 |   cross_attn_all_layers_decoder: true
48 |   cross_attn_all_layers_encoder: true
49 |   cross_attn_use_flex_attention: true
50 |   cross_attn_init_by_pooling: true
51 |   log_patch_lengths: true
52 |   non_linearity: "swiglu"
53 |   use_rope: true
54 |   recompute_fc1_out: false
55 |   recompute_fc3_out: false
56 |   recompute_attn: false
57 |   custom_bwd: false
58 |   layer_ckpt: "none"
59 |   use_local_encoder_transformer: true
60 |   init_use_gaussian: true
61 |   init_use_depth: "current"
62 |   attn_impl: "xformers"
63 |   attn_bias_type: "block_causal"
64 |   alpha_depth: "disabled"
65 |   max_length: 256
66 |   local_attention_window_len: 512
67 |   max_seqlen: 12288
68 |   downsampling_by_pooling: "max"
69 | 
70 | data:
71 |   root_dir: ???
72 |   sources:
73 |     dclm_baseline_1.0: 1.0
74 |   batch_size: 2
75 |   prefetch_size: 64
76 |   seq_len: 4096
77 |   load_async: true
78 |   preprocess_dir: ???
79 |   tokenizer_args:
80 |     name: blt
81 |     init_kwargs:
82 |       bpe_tokenizer_path: ???
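# NOTE (illustrative, not in the upstream config): the ??? values above (data.root_dir,
# data.preprocess_dir and bpe_tokenizer_path) are OmegaConf "missing" markers; config
# parsing raises unless they are overridden, e.g. from the command line of whatever
# entry point consumes this config (the paths below are placeholders):
#
#   ... config=bytelatent/configs/debug.yaml \
#       data.root_dir=/path/to/jsonl_shards \
#       data.preprocess_dir=/path/to/entropy_preprocess \
#       data.tokenizer_args.init_kwargs.bpe_tokenizer_path=/path/to/tokenizer.model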
83 | 84 | profiling: 85 | run: false 86 | 87 | checkpoint: 88 | dump: 89 | every: 500 90 | keep: 3 91 | eval: 92 | every: 1000 93 | keep: -1 94 | 95 | logging: 96 | freq: 10 97 | 98 | eval_on_gpus: 8 99 | eval: null 100 | -------------------------------------------------------------------------------- /bytelatent/data/data_types.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | import json 3 | from dataclasses import dataclass 4 | from typing import Any, Iterator 5 | 6 | import numpy as np 7 | from pydantic import BaseModel, ConfigDict 8 | 9 | 10 | class BltExample(BaseModel): 11 | model_config = ConfigDict(extra="forbid") 12 | sample_id: str 13 | text: str 14 | tokens: list[int] | None 15 | entropies: list[float] | None 16 | patch_lengths: list[int] | None 17 | mask: list[bool] | None 18 | 19 | 20 | class MultiChoiceState(BaseModel): 21 | model_config = ConfigDict(extra="forbid") 22 | root_dir: str 23 | sources: dict[str, float] 24 | source_to_state: dict[str, Any] 25 | rng_state: dict[str, Any] 26 | 27 | 28 | class PrefetchState(BaseModel): 29 | model_config = ConfigDict(extra="forbid") 30 | seq_idx: int 31 | rng_state: dict[str, Any] 32 | prefetch_size: int 33 | batch_size: int 34 | 35 | 36 | class BltPackTokensState(BaseModel): 37 | model_config = ConfigDict(extra="forbid") 38 | start_token: int 39 | output_seq_len: int 40 | n_views: int = 2 41 | 42 | 43 | class BltSequence(BaseModel): 44 | tokens: list[int] 45 | mask: list[bool] 46 | patch_lengths: list[int] | None 47 | 48 | 49 | @dataclass 50 | class Batch: 51 | x: np.ndarray 52 | y: np.ndarray 53 | mask: np.ndarray | None = None 54 | patch_lengths: np.ndarray | None = None 55 | ngram_ids: np.ndarray | None = None 56 | is_final: bool = False 57 | 58 | def to_python_dict(self) -> dict: 59 | x = self.x.tolist() 60 | y = self.y.tolist() 61 | if self.mask is None: 62 | mask = None 63 | else: 64 | mask = self.mask.tolist() 65 | if self.patch_lengths is None: 66 | patch_lengths = None 67 | else: 68 | patch_lengths = self.patch_lengths.tolist() 69 | if self.ngram_ids is None: 70 | ngram_ids = None 71 | else: 72 | ngram_ids = self.ngram_ids.tolist() 73 | return { 74 | "x": x, 75 | "y": y, 76 | "mask": mask, 77 | "patch_lengths": patch_lengths, 78 | "ngram_ids": ngram_ids, 79 | "is_final": self.is_final, 80 | } 81 | 82 | @classmethod 83 | def from_python_dict(cls, data: dict) -> "Batch": 84 | x = np.array(data["x"]) 85 | y = np.array(data["y"]) 86 | if data["mask"] is None: 87 | mask = None 88 | else: 89 | mask = np.array(data["mask"]) 90 | if data["patch_lengths"] is None: 91 | patch_lengths = None 92 | else: 93 | patch_lengths = np.array(data["patch_lengths"]) 94 | if data["ngram_ids"] is None: 95 | ngram_ids = None 96 | else: 97 | ngram_ids = np.array(data["ngram_ids"]) 98 | return Batch( 99 | x=x, 100 | y=y, 101 | mask=mask, 102 | patch_lengths=patch_lengths, 103 | ngram_ids=ngram_ids, 104 | is_final=data["is_final"], 105 | ) 106 | -------------------------------------------------------------------------------- /bytelatent/data/iterators/test_iters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
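# NOTE (illustrative, not in the upstream file): these tests drive PreprocessIterator over
# the in-memory dev iterators and check that tokenization and patching stay consistent.
# They assume the BPE tokenizer file under BLT_DATA is available locally; a typical run is:
#
#     pytest bytelatent/data/iterators/test_iters.py -q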
2 | 3 | from bytelatent.constants import BLT_DATA 4 | from bytelatent.data.iterators.dev_iterators import ( 5 | BltTestIterator, 6 | BltTestWithEntropiesIterator, 7 | ) 8 | from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator 9 | from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum 10 | from bytelatent.tokenizers.build_tokenizer import TokenizerArgs 11 | 12 | 13 | def test_preprocess_iter(): 14 | total = 3 15 | tokenizer_args = TokenizerArgs( 16 | name="blt", 17 | init_kwargs={ 18 | "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model" 19 | }, 20 | ) 21 | for mode in [ 22 | PatchingModeEnum.bpe, 23 | PatchingModeEnum.space, 24 | ]: 25 | data_it = BltTestIterator(total) 26 | patcher_args = PatcherArgs(patching_mode=mode) 27 | example_it = PreprocessIterator( 28 | data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args 29 | ) 30 | count = 0 31 | for example in example_it.create_iter(): 32 | assert isinstance(example.tokens, list) 33 | assert isinstance(example.tokens[0], int) 34 | # BOS and EOS 35 | assert len(example.tokens) == len(example.text) + 2 36 | assert example.mask is not None 37 | assert len(example.tokens) == len(example.mask) 38 | count += 1 39 | 40 | assert count == total 41 | 42 | 43 | def test_non_entropy_patch_iter(): 44 | total = 3 45 | tokenizer_args = TokenizerArgs( 46 | name="blt", 47 | init_kwargs={ 48 | "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model" 49 | }, 50 | ) 51 | for mode in [ 52 | PatchingModeEnum.bpe, 53 | PatchingModeEnum.space, 54 | ]: 55 | patcher_args = PatcherArgs(patching_mode=mode) 56 | data_it = BltTestIterator(total) 57 | example_it = PreprocessIterator( 58 | data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args 59 | ) 60 | 61 | count = 0 62 | for example in example_it.create_iter(): 63 | assert isinstance(example.patch_lengths, list) 64 | assert isinstance(example.patch_lengths[0], int) 65 | assert len(example.tokens) == sum(example.patch_lengths) 66 | count += 1 67 | 68 | assert count == total 69 | 70 | 71 | def test_entropy_patch_iter(): 72 | total = 2 73 | patcher_args = PatcherArgs( 74 | patching_mode=PatchingModeEnum.entropy, threshold=1.335442066192627 75 | ) 76 | tokenizer_args = TokenizerArgs( 77 | name="blt", 78 | init_kwargs={ 79 | "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model" 80 | }, 81 | ) 82 | data_it = BltTestWithEntropiesIterator(total) 83 | example_it = PreprocessIterator( 84 | data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args 85 | ) 86 | 87 | count = 0 88 | for example in example_it.create_iter(): 89 | assert isinstance(example.patch_lengths, list) 90 | assert isinstance(example.patch_lengths[0], int) 91 | assert len(example.tokens) == sum(example.patch_lengths) 92 | count += 1 93 | 94 | assert count == total 95 | -------------------------------------------------------------------------------- /bytelatent/plotting/entropy_figure.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
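# NOTE (illustrative, not in the upstream file): this script takes a YAML config as its
# first argument and merges any remaining key=value CLI overrides on top of it, e.g.
# (the output path is a placeholder):
#
#     python -m bytelatent.plotting.entropy_figure \
#         bytelatent/plotting/config_entropy_figure.yaml chart_path=figures/entropy.pdf
#
# data_path must point at a JSON file matching PlotEntropiesData below, i.e. containing
# the plotted text, the patching threshold, and a serialized per-byte entropy dataframe.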
2 | import json 3 | import os 4 | import sys 5 | from pathlib import Path 6 | 7 | import altair as alt 8 | import pandas as pd 9 | from omegaconf import OmegaConf 10 | from pydantic import BaseModel 11 | 12 | 13 | class PlotEntropiesConfig(BaseModel): 14 | data_path: str | None 15 | chart_path: str 16 | score_override_path: str | None = None 17 | threshold_override: float | None = None 18 | 19 | class Config: 20 | extra = "forbid" 21 | 22 | 23 | class PlotEntropiesData(BaseModel): 24 | text: str 25 | threshold: float = 1.335442066192627 26 | dataframe_json: str | None 27 | 28 | class Config: 29 | extra = "forbid" 30 | 31 | 32 | def main(): 33 | config_path = sys.argv[1] 34 | file_config = OmegaConf.load(config_path) 35 | # Omit program name and config file name 36 | cli_conf = OmegaConf.from_cli(sys.argv[2:]) 37 | conf_dict = OmegaConf.to_container( 38 | OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True 39 | ) 40 | plot_config = PlotEntropiesConfig(**conf_dict) 41 | with open(plot_config.data_path) as f: 42 | json_data = f.read() 43 | 44 | plot_data = PlotEntropiesData.model_validate_json(json_data) 45 | df = pd.read_json(plot_data.dataframe_json) 46 | print("LEN", len(df)) 47 | if plot_config.threshold_override is None: 48 | threshold = plot_data.threshold 49 | else: 50 | threshold = plot_config.threshold_override 51 | if plot_config.score_override_path is not None: 52 | with open(plot_config.score_override_path) as f: 53 | scores = json.load(f)["score"] 54 | assert len(scores) == len(df) 55 | df["entropies"] = scores 56 | df["start"] = [1] + (df["entropies"] > threshold).values.tolist()[:-1] 57 | 58 | x_ticks = [] 59 | for row in df.itertuples(): 60 | position = row.position 61 | token = row.tokens 62 | x_ticks.append(f"{str(position).zfill(3)}|{token}") 63 | df["position_with_token"] = x_ticks 64 | print(df) 65 | 66 | x_axis = alt.Axis( 67 | labelExpr="split(datum.label, '|')[1]", 68 | grid=False, 69 | labelOverlap=False, 70 | labelAngle=0, 71 | ) 72 | width = 1200 73 | height = 150 74 | base = alt.Chart(df).properties(width=width, height=height) 75 | points = base.mark_line(point=True).encode( 76 | x=alt.X("position_with_token:O", title=None, axis=x_axis), 77 | y=alt.Y( 78 | "entropies", 79 | title="Entropy of Next Byte", 80 | ), 81 | ) 82 | rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode( 83 | y=alt.datum(threshold), 84 | ) 85 | patch_rules = ( 86 | alt.Chart(df[df["start"] > 0]) 87 | .properties(width=width, height=height) 88 | .mark_rule(color="#474747", strokeDash=[4, 2]) 89 | .encode(x=alt.X("position_with_token:O", axis=x_axis)) 90 | ) 91 | 92 | chart = patch_rules + rule + points 93 | chart = chart.configure_axis(labelFontSize=15, titleFontSize=15) 94 | path = Path(plot_config.chart_path) 95 | path.parent.mkdir(exist_ok=True) 96 | chart.save(path) 97 | 98 | 99 | if __name__ == "__main__": 100 | main() 101 | -------------------------------------------------------------------------------- /bytelatent/tokenizers/tiktoken_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
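# NOTE (illustrative, not in the upstream file): minimal usage sketch for the tokenizer
# defined below; it requires tiktoken to be installed and a BPE .model file (the path here
# is a placeholder):
#
#     tok = TikTokenTokenizer("/path/to/tokenizer.model")
#     ids = tok.encode("hello world", add_bos=True, add_eos=True)
#     assert ids[0] == tok.bos_id and ids[-1] == tok.eos_id
#     assert tok.decode(ids[1:-1]) == "hello world"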
2 | import logging 3 | from copy import copy 4 | from pathlib import Path 5 | 6 | from bytelatent.tokenizers.abstract_tokenizer import Tokenizer 7 | 8 | try: 9 | import tiktoken 10 | from tiktoken.load import load_tiktoken_bpe 11 | 12 | has_tiktoken = True 13 | except ImportError: 14 | has_tiktoken = False 15 | DEFAULT_TIKTOKEN_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" 16 | DEFAULT_TIKTOKEN_SPECIAL_TOKENS = { 17 | "<|begin_of_text|>": 0, 18 | "<|end_of_text|>": 1, 19 | "<|fim_prefix|>": 2, 20 | "<|fim_middle|>": 3, 21 | "<|fim_end_fill|>": 253, 22 | "<|fim_pad|>": 254, 23 | "<|fim_suffix|>": 255, 24 | } 25 | TIKTOKEN_MAX_ENCODE_CHARS = 400_000 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | class TikTokenTokenizer(Tokenizer): 31 | def __init__(self, model_path: str) -> None: 32 | mergeable_ranks = load_tiktoken_bpe(model_path) 33 | all_special_tokens_with_ids = copy(DEFAULT_TIKTOKEN_SPECIAL_TOKENS) 34 | missing_ids = set(range(256)) - set(all_special_tokens_with_ids.values()) 35 | for id in missing_ids: 36 | all_special_tokens_with_ids[f"<|reserved_special_token_{id}|>"] = id 37 | for name in all_special_tokens_with_ids: 38 | all_special_tokens_with_ids[name] += len(mergeable_ranks) 39 | 40 | self.tkt_model = tiktoken.core.Encoding( 41 | name=Path(model_path).stem, 42 | pat_str=DEFAULT_TIKTOKEN_PATTERN, 43 | mergeable_ranks=mergeable_ranks, 44 | special_tokens=all_special_tokens_with_ids, 45 | ) 46 | 47 | self.bos_id: int = self.tkt_model.encode_single_token("<|begin_of_text|>") 48 | self.eos_id: int = self.tkt_model.encode_single_token("<|end_of_text|>") 49 | 50 | self.n_words: int = self.tkt_model.n_vocab 51 | 52 | logger.info( 53 | f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" 54 | ) 55 | 56 | def get_vocab_size(self) -> int: 57 | return self.n_words 58 | 59 | def encode(self, s: str, add_bos: bool, add_eos: bool): 60 | assert isinstance(s, str) 61 | 62 | subs = [] 63 | for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS): 64 | subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS]) 65 | return ( 66 | [self.bos_id] * add_bos 67 | + sum(self.tkt_model.encode_ordinary_batch(subs), start=[]) 68 | + [self.eos_id] * add_eos 69 | ) 70 | 71 | def decode(self, tokens: list[int]): 72 | return self.tkt_model.decode(tokens) 73 | 74 | def get_token_offsets( 75 | self, text: str, tokens: list[int] | None = None 76 | ) -> tuple[list[str], list[int]]: 77 | if tokens is not None: 78 | token_bytes = self.tkt_model.decode_tokens_bytes(tokens) 79 | else: 80 | token_bytes = self.tkt_model.decode_tokens_bytes( 81 | self.tkt_model.encode(text, allowed_special="all") 82 | ) 83 | 84 | text_len, offsets = 0, [] 85 | for token in token_bytes: 86 | offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0))) 87 | text_len += sum(1 for c in token if not 0x80 <= c < 0xC0) 88 | substrs = [text[s:e] for s, e in zip(offsets, offsets[1:] + [None])] 89 | return substrs, offsets 90 | -------------------------------------------------------------------------------- /bytelatent/config_parser.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Any, Type, TypeVar 3 | 4 | import omegaconf 5 | from omegaconf import DictConfig, OmegaConf 6 | from pydantic import BaseModel 7 | from pydantic_core import PydanticUndefined 8 | 9 | 10 | def parse_file_config(path: str) -> DictConfig: 11 | file_cfg = OmegaConf.load(path) 12 | if not 
isinstance(file_cfg, DictConfig): 13 | raise ValueError( 14 | f"File paths must parse to DictConfig, but it was: {type(file_cfg)}" 15 | ) 16 | return file_cfg 17 | 18 | 19 | def recursively_parse_config(cfg: DictConfig) -> list[DictConfig]: 20 | if "config" not in cfg: 21 | return [cfg] 22 | 23 | ordered_cfgs = [] 24 | cfg = copy.deepcopy(cfg) 25 | config_arg = cfg["config"] 26 | del cfg["config"] 27 | ordered_cfgs.append(cfg) 28 | 29 | if isinstance(config_arg, str): 30 | file_cfg = parse_file_config(config_arg) 31 | sub_configs = recursively_parse_config(file_cfg) 32 | ordered_cfgs = sub_configs + ordered_cfgs 33 | elif isinstance(config_arg, omegaconf.listconfig.ListConfig): 34 | sub_configs = [] 35 | for c in config_arg: 36 | if not isinstance(c, str): 37 | raise ValueError( 38 | f'If "config" is specified, it must be either a string path or a list of string paths. It was config={config_arg}' 39 | ) 40 | config_to_parse = parse_file_config(c) 41 | sub_configs.extend(recursively_parse_config(config_to_parse)) 42 | ordered_cfgs = sub_configs + ordered_cfgs 43 | else: 44 | raise ValueError( 45 | f'If "config" is specified, it must be either a string path or a list of string paths, it was config={config_arg}' 46 | ) 47 | return ordered_cfgs 48 | 49 | 50 | def parse_args_with_default( 51 | *, default_cfg: DictConfig | None = None, cli_args: DictConfig | None = None 52 | ): 53 | if cli_args is None: 54 | cli_args = OmegaConf.from_cli() 55 | assert isinstance( 56 | cli_args, DictConfig 57 | ), f"CLI Args must be a DictConfig, not {type(cli_args)}" 58 | ordered_cfgs = recursively_parse_config(cli_args) 59 | if default_cfg is not None: 60 | ordered_cfgs.insert(0, default_cfg) 61 | cfg = OmegaConf.merge(*ordered_cfgs) 62 | # TODO: Change sources to list[tuple,str, float]] so that this special case isn't needed 63 | for c in reversed(ordered_cfgs): 64 | if "data" in c and "sources" in c["data"]: 65 | cfg["data"]["sources"] = c["data"]["sources"] 66 | break 67 | return OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) 68 | 69 | 70 | T = TypeVar("T", bound=BaseModel) 71 | 72 | 73 | def get_pydantic_default_args(args_cls: Type[T]) -> dict[str, Any]: 74 | defaults = {} 75 | for field, info in args_cls.model_fields.items(): 76 | if info.default != PydanticUndefined: 77 | defaults[field] = info.default 78 | return defaults 79 | 80 | 81 | def parse_args_to_pydantic_model( 82 | args_cls: Type[T], 83 | cli_args: DictConfig | None = None, 84 | instantiate_default_cls: bool = True, 85 | ) -> T: 86 | if instantiate_default_cls: 87 | default_cfg = OmegaConf.create(args_cls().model_dump()) 88 | else: 89 | default_cfg = OmegaConf.create(get_pydantic_default_args(args_cls)) 90 | parsed_cfg = parse_args_with_default(default_cfg=default_cfg, cli_args=cli_args) 91 | print(default_cfg) 92 | print() 93 | print(parsed_cfg) 94 | pydantic_args = args_cls.model_validate(parsed_cfg) 95 | return pydantic_args 96 | -------------------------------------------------------------------------------- /bytelatent/plotting/scaling_figures.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
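# NOTE (illustrative, not in the upstream file): like entropy_figure.py, this script takes
# a YAML config (ScalingPlotsConfig: df_dir, output_chart_dir, frame_files) as argv[1] plus
# optional key=value overrides, and writes one PDF chart per frame file, e.g.
#
#     python -m bytelatent.plotting.scaling_figures \
#         bytelatent/plotting/config_scaling_figures.yaml output_chart_dir=figures/scaling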
2 | import sys 3 | from pathlib import Path 4 | 5 | import altair as alt 6 | import pandas as pd 7 | import pydantic 8 | from omegaconf import OmegaConf 9 | 10 | 11 | class ScalingPlotsConfig(pydantic.BaseModel): 12 | df_dir: str 13 | output_chart_dir: str 14 | frame_files: list[str] 15 | 16 | class Config: 17 | extra = "forbid" 18 | 19 | 20 | def determine_family(key: str): 21 | if key.startswith("Megabyte++"): 22 | return "Megabyte++" 23 | elif key.startswith("BLT"): 24 | return "BLT" 25 | elif key.startswith("LLaMA"): 26 | return "LLaMA" 27 | elif key.startswith("Space"): 28 | return "Space" 29 | 30 | 31 | file_to_vars = {} 32 | 33 | 34 | def create_chart(df: pd.DataFrame, output_file: str): 35 | df["metric"] = df["bpb/not_heldout.jsonl"] 36 | df["family"] = df["key"].map(determine_family) 37 | model_domain = [ 38 | "BLT Space ps=6", 39 | "BLT Space w/o cross-attn", 40 | "SpaceByte", 41 | "LLaMA 3 BPE", 42 | "Megabyte++ ps=4", 43 | "Megabyte++ ps=6", 44 | ] 45 | color_range = ["#1f77b4", "#1f77b4", "#1f77b4", "#ff7f0e", "#2ca02c", "#2ca02c"] 46 | shape_range = [ 47 | "circle", 48 | "square", 49 | "cross", 50 | "diamond", 51 | "triangle-up", 52 | "triangle-down", 53 | ] 54 | color_scale = alt.Scale(domain=model_domain, range=color_range) 55 | shape_scale = alt.Scale( 56 | domain=model_domain, 57 | range=shape_range, 58 | ) 59 | base_chart = alt.Chart(df).encode( 60 | x=alt.X("flops", title="Training FLOPS") 61 | .scale(type="log", domain=[2e20, 1.25e22]) 62 | .axis(values=[2e20, 4e20, 8e20, 1e21, 2e21, 4e21, 8e21, 1e22]), 63 | y=alt.Y("metric", title="Bits per Byte (BPB)").scale(zero=False), 64 | ) 65 | lines = base_chart.encode( 66 | color=alt.Color("key", title="Model Color", scale=color_scale, legend=None), 67 | strokeDash=alt.StrokeDash("family", title="Model Family", legend=None), 68 | ).mark_line() 69 | points = base_chart.encode( 70 | color=alt.Color("key", title="Model", scale=color_scale), 71 | shape=alt.Shape("key", title="", scale=shape_scale), 72 | ).mark_point(size=70) 73 | chart = ( 74 | (lines + points) 75 | .resolve_scale( 76 | color="independent", 77 | shape="independent", 78 | # strokeDash="independent", 79 | ) 80 | .configure_legend(orient="right") 81 | .properties(height=300, width=400) 82 | ) 83 | print("Saving", output_file) 84 | chart.save(output_file) 85 | 86 | 87 | def main(): 88 | config_path = sys.argv[1] 89 | file_config = OmegaConf.load(config_path) 90 | # Omit program name and config file name 91 | cli_conf = OmegaConf.from_cli(sys.argv[2:]) 92 | conf_dict = OmegaConf.to_container( 93 | OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True 94 | ) 95 | plot_config = ScalingPlotsConfig(**conf_dict) 96 | df_dir = Path(plot_config.df_dir) 97 | chart_dir = Path(plot_config.output_chart_dir) 98 | chart_dir.mkdir(exist_ok=True, parents=True) 99 | for ff in plot_config.frame_files: 100 | path = df_dir / ff 101 | df = pd.read_json(path) 102 | print(df) 103 | print(df.columns) 104 | create_chart(df, chart_dir / f"{path.name}.pdf") 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /bytelatent/preprocess/parallel_entropies.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
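# NOTE (illustrative, not in the upstream file): this Typer CLI fans the entropy
# preprocessing out over Slurm, launching one array task per input shard, each running
# bytelatent.preprocess.preprocess_entropies on a single file. Sketch of an invocation
# (paths and qos are placeholders for your cluster):
#
#     python -m bytelatent.preprocess.parallel_entropies \
#         /checkpoint/me/jobs entropy_preprocess/dclm preprocessed_arrow/dclm \
#         --qos explore --slurm-batch-size 1000
#
# Add --check-only to report how many .arrow.complete output markers already exist without
# launching anything, and --wait to block until all submitted jobs finish.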
2 | import subprocess 3 | from pathlib import Path 4 | 5 | import submitit 6 | import typer 7 | 8 | 9 | class PreprocessEntropiesJob(submitit.helpers.Checkpointable): 10 | def __init__(self) -> None: 11 | pass 12 | 13 | def __call__(self, shard_file: str, output_filename: str): 14 | subprocess.run( 15 | [ 16 | "python", 17 | "-u", 18 | "-m", 19 | "bytelatent.preprocess.preprocess_entropies", 20 | str(shard_file), 21 | str(output_filename), 22 | ], 23 | check=True, 24 | ) 25 | return True 26 | 27 | 28 | def chunk(items, size): 29 | for i in range(0, len(items), size): 30 | yield items[i : i + size] 31 | 32 | 33 | def main( 34 | job_folder: str, 35 | input_dir: str, 36 | output_dir: str, 37 | qos: str = "explore", 38 | slurm_batch_size: int = 1000, 39 | check_only: bool = False, 40 | wait: bool = False, 41 | ): 42 | input_dir = Path(input_dir) 43 | output_dir = Path(output_dir) 44 | shard_files = [ 45 | p for p in input_dir.glob("*.jsonl.shard*") if "COMPLETE" not in p.name 46 | ] 47 | if check_only: 48 | exist = [] 49 | missing = [] 50 | for shard_file in shard_files: 51 | shard_file = Path(shard_file) 52 | complete_file = output_dir / f"{shard_file.name}.arrow.complete" 53 | if complete_file.exists(): 54 | exist.append(complete_file) 55 | else: 56 | missing.append(complete_file) 57 | print("Checked for output files for input_dir=", input_dir) 58 | print("Exist:", len(exist)) 59 | print("Missing:", len(missing)) 60 | print(missing) 61 | return 62 | print("Running parallel job over N files=", len(shard_files)) 63 | print("Input Directory:", input_dir) 64 | print("Output Directory:", output_dir) 65 | output_dir.mkdir(exist_ok=True, parents=True) 66 | 67 | executor = submitit.SlurmExecutor(job_folder) 68 | executor.update_parameters( 69 | # 12 hours in minutes 70 | time=60 * 12, 71 | qos=qos, 72 | exclusive="user", 73 | cpus_per_task=4, 74 | num_gpus=1, 75 | mem_per_gpu="80G", 76 | array_parallelism=slurm_batch_size, 77 | ) 78 | 79 | jobs = [] 80 | n_batches = 0 81 | n_skipped = 0 82 | n_launched = 0 83 | for file_batch in chunk(shard_files, slurm_batch_size): 84 | with executor.batch(): 85 | for shard_file in file_batch: 86 | output_filename = Path(output_dir) / f"{shard_file.name}.arrow" 87 | complete_output_filename = ( 88 | Path(output_dir) / f"{shard_file.name}.arrow.complete" 89 | ) 90 | if complete_output_filename.exists(): 91 | n_skipped += 1 92 | else: 93 | job = executor.submit( 94 | PreprocessEntropiesJob(), str(shard_file), str(output_filename) 95 | ) 96 | n_launched += 1 97 | jobs.append(job) 98 | n_batches += 1 99 | print("launched array jobs n=", n_launched) 100 | print("skipped (completed) array jobs n=", n_skipped) 101 | print("number of slurm batches=", n_batches) 102 | if wait: 103 | output = [job.result() for job in jobs] 104 | assert all(output) 105 | 106 | 107 | if __name__ == "__main__": 108 | typer.run(main) 109 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, 
religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /bytelatent/data/iterators/test_arrow_iterator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | import numpy as np 3 | import pyarrow as pa 4 | 5 | # pyarrow needs the initialization from this import 6 | import pyarrow.dataset # pyright: ignore 7 | 8 | from bytelatent.constants import BLT_DATA 9 | from bytelatent.data.iterators.arrow_iterator import ( 10 | ArrowFileIterator, 11 | ArrowFileIteratorState, 12 | ) 13 | 14 | ENTROPY_MODEL = "transformer_100m" 15 | ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow") 16 | ARROW_TEST_DATA_2 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_01.arrow") 17 | 18 | 19 | def test_basic_arrow_file(): 20 | dataset = pa.dataset.dataset(ARROW_TEST_DATA_1, format="arrow") 21 | n_head = 1000 22 | head_df = dataset.head(n_head).to_pandas() 23 | 24 | initial_state = ArrowFileIteratorState( 25 | file_path=None, 26 | num_workers=1, 27 | worker_id=0, 28 | preprocess_dir=None, 29 | entropy_model_name=ENTROPY_MODEL, 30 | dataset_files=[ARROW_TEST_DATA_1], 31 | row_num=0, 32 | arrow_batch_size=100, 33 | s3_profile=None, 34 | file_format="arrow", 35 | ) 36 | arrow_file = initial_state.build() 37 | start_state = arrow_file.get_state() 38 | assert start_state.row_num == initial_state.row_num 39 | 40 | sample_id = None 41 | for example in arrow_file.create_iter(): 42 | sample_id = example.sample_id 43 | assert head_df.iloc[0]["sample_id"] == sample_id 44 | break 45 | 46 | assert arrow_file.get_state().row_num == 1 47 | arrow_file = initial_state.build() 48 | for example in arrow_file.create_iter(): 49 | assert example.sample_id == sample_id 50 | assert head_df.iloc[0]["sample_id"] == sample_id 51 | break 52 | 53 | # Test resume far enough in to be past the batch size of 100 54 | resumed_state = ArrowFileIteratorState( 55 | file_path=None, 56 | num_workers=1, 57 | worker_id=0, 58 | preprocess_dir=None, 59 | entropy_model_name=ENTROPY_MODEL, 60 | dataset_files=[ARROW_TEST_DATA_1], 61 | row_num=251, 62 | arrow_batch_size=100, 63 | s3_profile=None, 64 | file_format="arrow", 65 | ) 66 | arrow_file = resumed_state.build() 67 | for example in arrow_file.create_iter(): 68 | assert example.sample_id == head_df.iloc[251]["sample_id"] 69 | assert arrow_file.get_state().row_num == 252 70 | break 71 | 72 | world_rank = 1 73 | world_size = 4 74 | # Test World Size and Rank 75 | rank_state = ArrowFileIteratorState( 76 | file_path=None, 77 | num_workers=world_size, 78 | worker_id=world_rank, 79 | preprocess_dir=None, 80 | entropy_model_name=ENTROPY_MODEL, 81 | dataset_files=[ARROW_TEST_DATA_1], 82 | row_num=0, 83 | arrow_batch_size=100, 84 | s3_profile=None, 85 | file_format="arrow", 86 | ) 87 | arrow_file = rank_state.build() 88 | expected_ids = [] 89 | for i in range(n_head): 90 | if i % world_size == world_rank: 91 | expected_ids.append(head_df.iloc[i]["sample_id"]) 92 | print(len(expected_ids)) 93 | i = 0 94 | for example in arrow_file.create_iter(): 95 | assert example.sample_id == 
expected_ids[i] 96 | i += 1 97 | if i >= len(expected_ids): 98 | break 99 | 100 | 101 | def test_read_jsonl_from_arrow(): 102 | arrow_iterator = ArrowFileIterator( 103 | file_path="fixtures/test_docs.jsonl", 104 | num_workers=1, 105 | worker_id=0, 106 | preprocess_dir=None, 107 | entropy_model_name=None, 108 | file_format="json", 109 | arrow_batch_size=100, 110 | ) 111 | iterator = arrow_iterator.create_iter() 112 | for i, example in enumerate(iterator): 113 | assert example.sample_id == str(i) 114 | assert example.text == f"test_{i}" 115 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | *.out 164 | 165 | figures/ 166 | .vscode/ 167 | .DS_Store 168 | internal/ 169 | jobs_parallel-copy/ 170 | wandb/ 171 | *.ipynb 172 | hf-weights/ 173 | -------------------------------------------------------------------------------- /fixtures/tokens_with_entropies.json: -------------------------------------------------------------------------------- 1 | {"position":{"0":0,"1":1,"2":2,"3":3,"4":4,"5":5,"6":6,"7":7,"8":8,"9":9,"10":10,"11":11,"12":12,"13":13,"14":14,"15":15,"16":16,"17":17,"18":18,"19":19,"20":20,"21":21,"22":22,"23":23,"24":24,"25":25,"26":26,"27":27,"28":28,"29":29,"30":30,"31":31,"32":32,"33":33,"34":34,"35":35,"36":36,"37":37,"38":38,"39":39,"40":40,"41":41,"42":42,"43":43,"44":44,"45":45,"46":46,"47":47,"48":48,"49":49,"50":50,"51":51,"52":52,"53":53,"54":54,"55":55,"56":56,"57":57,"58":58,"59":59,"60":60,"61":61,"62":62,"63":63,"64":64,"65":65,"66":66,"67":67,"68":68,"69":69,"70":70,"71":71,"72":72,"73":73,"74":74,"75":75,"76":76,"77":77,"78":78,"79":79,"80":80},"tokens":{"0":"<","1":"D","2":"a","3":"e","4":"n","5":"e","6":"r","7":"y","8":"s","9":"_","10":"T","11":"a","12":"r","13":"g","14":"a","15":"r","16":"y","17":"e","18":"n","19":"_","20":"i","21":"s","22":"_","23":"i","24":"n","25":"_","26":"G","27":"a","28":"m","29":"e","30":"_","31":"o","32":"f","33":"_","34":"T","35":"h","36":"r","37":"o","38":"n","39":"e","40":"s","41":",","42":"_","43":"a","44":"_","45":"f","46":"a","47":"n","48":"t","49":"a","50":"s","51":"y","52":"_","53":"e","54":"p","55":"i","56":"c","57":"_","58":"b","59":"y","60":"_","61":"G","62":"e","63":"o","64":"r","65":"g","66":"e","67":"_","68":"R","69":".","70":"R","71":".","72":"_","73":"M","74":"a","75":"r","76":"t","77":"i","78":"n","79":".","80":">"},"token_ids":{"0":1,"1":72,"2":101,"3":105,"4":114,"5":105,"6":118,"7":125,"8":119,"9":36,"10":88,"11":101,"12":118,"13":107,"14":101,"15":118,"16":125,"17":105,"18":114,"19":36,"20":109,"21":119,"22":36,"23":109,"24":114,"25":36,"26":75,"27":101,"28":113,"29":105,"30":36,"31":115,"32":106,"33":36,"34":88,"35":108,"36":118,"37":115,"38":114,"39":105,"40":119,"41":48,"42":36,"43":101,"44":36,"45":106,"46":101,"47":114,"48":120,"49":101,"50":119,"51":125,"52":36,"53":105,"54":116,"55":109,"56":103,"57":36,"58":102,"59":125,"60":36,"61":75,"62":105,"63":115,"64":118,"65":107,"66":105,"67":36,"68":86,"69":50,"70":86,"71":50,"72":36,"73":81,"74":101,"75":118,"76":120,"7
7":109,"78":114,"79":50,"80":2},"entropies":{"0":3.3949158192,"1":2.1656451225,"2":2.3216569424,"3":2.8214058876,"4":1.5249242783,"5":0.0401624143,"6":0.0981037766,"7":0.0544578359,"8":0.3430138826,"9":1.0546212196,"10":0.25252828,"11":0.1494535804,"12":0.0624754503,"13":0.001355894,"14":0.0050173439,"15":0.0052358187,"16":0.0011725067,"17":0.0010307421,"18":1.0241208076,"19":3.6867966652,"20":0.4502205253,"21":0.0484119244,"22":2.2572875023,"23":0.3789347112,"24":1.0042934418,"25":2.9090054035,"26":1.8933598995,"27":1.3859074116,"28":0.3827198744,"29":0.2646365762,"30":1.7742085457,"31":0.0136727821,"32":0.0053820172,"33":0.5485631227,"34":0.2064044327,"35":0.0049266233,"36":0.0005439016,"37":0.0007023578,"38":0.0004170335,"39":0.0054524317,"40":1.1938130856,"41":0.0238215197,"42":3.1279797554,"43":1.3883389235,"44":3.0503094196,"45":1.695879817,"46":1.8551058769,"47":1.4570231438,"48":0.0047810897,"49":0.026396824,"50":0.6633765101,"51":0.3141393065,"52":2.8411159515,"53":1.143143177,"54":0.0520330966,"55":0.3398066461,"56":0.4140175879,"57":2.5563707352,"58":1.3370712996,"59":0.0227173548,"60":3.4447185993,"61":1.8576486111,"62":0.8189754486,"63":0.6776530743,"64":0.0677763447,"65":0.212713033,"66":0.1003480032,"67":0.1746164262,"68":0.4123829603,"69":0.5507118702,"70":0.1047425047,"71":0.0194335245,"72":0.001482119,"73":0.0009310447,"74":0.0002176317,"75":0.0076908777,"76":0.0003866984,"77":0.0008008487,"78":1.2395234108,"79":0.4564163089,"80":0.0000461392},"patch":{"0":0,"1":1,"2":2,"3":3,"4":4,"5":5,"6":5,"7":5,"8":5,"9":5,"10":5,"11":5,"12":5,"13":5,"14":5,"15":5,"16":5,"17":5,"18":5,"19":5,"20":6,"21":6,"22":6,"23":7,"24":7,"25":7,"26":8,"27":9,"28":10,"29":10,"30":10,"31":11,"32":11,"33":11,"34":11,"35":11,"36":11,"37":11,"38":11,"39":11,"40":11,"41":11,"42":11,"43":12,"44":13,"45":14,"46":15,"47":16,"48":17,"49":17,"50":17,"51":17,"52":17,"53":18,"54":18,"55":18,"56":18,"57":18,"58":19,"59":20,"60":20,"61":21,"62":22,"63":22,"64":22,"65":22,"66":22,"67":22,"68":22,"69":22,"70":22,"71":22,"72":22,"73":22,"74":22,"75":22,"76":22,"77":22,"78":22,"79":22,"80":22},"start":{"0":1,"1":1,"2":1,"3":1,"4":1,"5":1,"6":0,"7":0,"8":0,"9":0,"10":0,"11":0,"12":0,"13":0,"14":0,"15":0,"16":0,"17":0,"18":0,"19":0,"20":1,"21":0,"22":0,"23":1,"24":0,"25":0,"26":1,"27":1,"28":1,"29":0,"30":0,"31":1,"32":0,"33":0,"34":0,"35":0,"36":0,"37":0,"38":0,"39":0,"40":0,"41":0,"42":0,"43":1,"44":1,"45":1,"46":1,"47":1,"48":1,"49":0,"50":0,"51":0,"52":0,"53":1,"54":0,"55":0,"56":0,"57":0,"58":1,"59":1,"60":0,"61":1,"62":1,"63":0,"64":0,"65":0,"66":0,"67":0,"68":0,"69":0,"70":0,"71":0,"72":0,"73":0,"74":0,"75":0,"76":0,"77":0,"78":0,"79":0,"80":0}} -------------------------------------------------------------------------------- /bytelatent/norms.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.utils._foreach_utils import ( 6 | _device_has_foreach_support, 7 | _group_tensors_by_device_and_dtype, 8 | _has_foreach_support, 9 | ) 10 | 11 | 12 | @torch.no_grad() 13 | def fixed_clip_grad_norm_( 14 | parameters: torch.Tensor | list[torch.Tensor], 15 | max_norm: float, 16 | norm_type: float = 2.0, 17 | error_if_nonfinite: bool = False, 18 | foreach: Optional[bool] = None, 19 | ) -> torch.Tensor: 20 | r"""Clip the gradient norm of an iterable of parameters. 
21 | 22 | The norm is computed over the norms of the individual gradients of all parameters, 23 | as if the norms of the individual gradients were concatenated into a single vector. 24 | Gradients are modified in-place. 25 | 26 | Args: 27 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a 28 | single Tensor that will have gradients normalized 29 | max_norm (float): max norm of the gradients 30 | norm_type (float): type of the used p-norm. Can be ``'inf'`` for 31 | infinity norm. 32 | error_if_nonfinite (bool): if True, an error is thrown if the total 33 | norm of the gradients from :attr:`parameters` is ``nan``, 34 | ``inf``, or ``-inf``. Default: False (will switch to True in the future) 35 | foreach (bool): use the faster foreach-based implementation. 36 | If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently 37 | fall back to the slow implementation for other device types. 38 | Default: ``None`` 39 | 40 | Returns: 41 | Total norm of the parameter gradients (viewed as a single vector). 42 | """ 43 | if isinstance(parameters, torch.Tensor): 44 | parameters = [parameters] 45 | grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None] 46 | max_norm = float(max_norm) 47 | norm_type = float(norm_type) 48 | if len(grads) == 0: 49 | return torch.tensor(0.0) 50 | first_device = grads[0].device 51 | grouped_grads: Dict[ 52 | Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]] 53 | ] = _group_tensors_by_device_and_dtype( 54 | [grads] 55 | ) # type: ignore[assignment] 56 | 57 | norms: List[Tensor] = [] 58 | for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] 59 | if (foreach is None and _has_foreach_support(device_grads, device)) or ( 60 | foreach and _device_has_foreach_support(device) 61 | ): 62 | norms.extend(torch._foreach_norm(device_grads, norm_type)) 63 | elif foreach: 64 | raise RuntimeError( 65 | f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" 66 | ) 67 | else: 68 | norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads]) 69 | 70 | total_norm = torch.linalg.vector_norm( 71 | torch.stack([norm.to(first_device) for norm in norms]), norm_type 72 | ) 73 | 74 | if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): 75 | raise RuntimeError( 76 | f"The total norm of order {norm_type} for gradients from " 77 | "`parameters` is non-finite, so it cannot be clipped. To disable " 78 | "this error and scale the gradients by the non-finite norm anyway, " 79 | "set `error_if_nonfinite=False`" 80 | ) 81 | clip_coef = max_norm / (total_norm + 1e-6) 82 | # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so 83 | # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization 84 | # when the gradients do not reside in CPU memory. 
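# Net effect: every gradient is scaled by min(1, max_norm / (total_norm + 1e-6)), so
# gradients are left untouched when the total norm is already below max_norm and are
# shrunk proportionally otherwise; the 1e-6 guards against division by zero.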
85 | clip_coef_clamped = torch.clamp(clip_coef, max=1.0) 86 | for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] 87 | if (foreach is None and _has_foreach_support(device_grads, device)) or ( 88 | foreach and _device_has_foreach_support(device) 89 | ): 90 | torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) 91 | elif foreach: 92 | raise RuntimeError( 93 | f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" 94 | ) 95 | else: 96 | clip_coef_clamped_device = clip_coef_clamped.to(device) 97 | for g in device_grads: 98 | g.mul_(clip_coef_clamped_device) 99 | 100 | return total_norm 101 | -------------------------------------------------------------------------------- /bytelatent/data/iterators/preprocess_iterator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | from typing import Any, Generator 3 | 4 | import torch 5 | from pydantic import BaseModel, ConfigDict 6 | 7 | from bytelatent.data.data_types import BltExample 8 | from bytelatent.data.iterators.abstract_iterator import ( 9 | PydanticIteratorState, 10 | StatefulIterator, 11 | ) 12 | from bytelatent.data.iterators.arrow_iterator import ( 13 | ArrowFileIterator, 14 | ArrowFileIteratorState, 15 | ) 16 | from bytelatent.data.iterators.limit_iterator import LimitIterator, LimitIteratorState 17 | from bytelatent.data.iterators.looping_iterator import ( 18 | LoopingIterator, 19 | LoopingIteratorState, 20 | ) 21 | from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum 22 | from bytelatent.tokenizers.blt_tokenizer import BltTokenizer 23 | from bytelatent.tokenizers.build_tokenizer import TokenizerArgs 24 | 25 | 26 | class PreprocessIteratorState(PydanticIteratorState): 27 | model_config = ConfigDict(extra="forbid") 28 | arrow_file_iterator_state: ( 29 | ArrowFileIteratorState | LoopingIteratorState | LimitIteratorState 30 | ) 31 | add_tokens: bool 32 | add_patches: bool 33 | tokenizer_args: TokenizerArgs 34 | patcher_args: PatcherArgs 35 | 36 | def build(self): 37 | arrow_iterator = self.arrow_file_iterator_state.build() 38 | return PreprocessIterator( 39 | arrow_iterator, 40 | patcher_args=self.patcher_args, 41 | tokenizer_args=self.tokenizer_args, 42 | add_tokens=self.add_tokens, 43 | add_patches=self.add_patches, 44 | ) 45 | 46 | 47 | class PreprocessIterator(StatefulIterator): 48 | """ 49 | Take BltExamples with fields filled in only from ArrowFileIterator, and fill in fields that require 50 | preprocessing like tokenization and patching 51 | """ 52 | 53 | def __init__( 54 | self, 55 | arrow_iterator: ArrowFileIterator | LoopingIterator | LimitIterator, 56 | *, 57 | patcher_args: PatcherArgs, 58 | tokenizer_args: TokenizerArgs, 59 | add_tokens: bool = True, 60 | add_patches: bool = True, 61 | ): 62 | self.arrow_iterator = arrow_iterator 63 | self.tokenizer_args = tokenizer_args 64 | self.patcher_args = patcher_args 65 | self.add_tokens = add_tokens 66 | self.add_patches = add_patches 67 | self.tokenizer: BltTokenizer | None = None 68 | self.patcher: Patcher | None = None 69 | 70 | def get_state(self) -> PreprocessIteratorState: 71 | """ 72 | The only state to maintain here is from arrow, there 73 | isn't any internal state on this iterator. 
74 | """ 75 | return PreprocessIteratorState( 76 | arrow_file_iterator_state=self.arrow_iterator.get_state(), 77 | tokenizer_args=self.tokenizer_args, 78 | patcher_args=self.patcher_args, 79 | add_tokens=self.add_tokens, 80 | add_patches=self.add_patches, 81 | ) 82 | 83 | def create_iter(self) -> Generator[BltExample, Any, None]: 84 | if self.tokenizer is None and self.add_tokens: 85 | self.tokenizer = self.tokenizer_args.build() 86 | if self.patcher is None and self.add_patches: 87 | self.patcher = self.patcher_args.build() 88 | 89 | example_iter = self.arrow_iterator.create_iter() 90 | for example in example_iter: 91 | if self.add_tokens: 92 | tokens = self.tokenizer.encode(example.text) 93 | else: 94 | tokens = example.tokens 95 | if ( 96 | self.patcher is not None 97 | and self.patcher.patching_mode == PatchingModeEnum.entropy 98 | ): 99 | assert ( 100 | example.entropies is not None 101 | ), "For patching, entropies cannot be None" 102 | entropies = torch.tensor(example.entropies).unsqueeze(0) 103 | else: 104 | entropies = None 105 | if self.patcher is None: 106 | patch_lengths = None 107 | else: 108 | patch_lengths = self.patcher.patch( 109 | torch.tensor(tokens).unsqueeze(0), 110 | include_next_token=False, 111 | entropies=entropies, 112 | )[0][0].tolist() 113 | yield BltExample( 114 | sample_id=example.sample_id, 115 | text=example.text, 116 | tokens=tokens, 117 | mask=[True] * len(tokens), 118 | patch_lengths=patch_lengths, 119 | entropies=example.entropies, 120 | ) 121 | -------------------------------------------------------------------------------- /bytelatent/profiling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
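# NOTE (illustrative, not in the upstream file): with the ProfilerArgs defaults below,
# maybe_run_profiler records a GPU memory snapshot on steps [100, 102) and a PyTorch
# profiler trace on steps [102, 104), writing both under <dump_dir>/<trace_folder> and
# uploading them to the active wandb run (if any) on the master rank. In a training
# config this is enabled with:
#
#     profiling:
#       run: true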
3 | 4 | import contextlib 5 | import logging 6 | import os 7 | from pathlib import Path 8 | 9 | import torch.distributed 10 | import wandb 11 | import xformers.profiler 12 | from pydantic import BaseModel 13 | from torch.profiler.profiler import profile 14 | from xformers.profiler import MemSnapshotsProfiler, PyTorchProfiler 15 | 16 | from bytelatent.distributed import get_is_master 17 | 18 | 19 | class ProfilerArgs(BaseModel): 20 | run: bool = False 21 | trace_folder: str = "profiling" 22 | mem_warmup: int = 100 23 | mem_steps: int = 2 24 | profile_warmup: int = 102 25 | profile_steps: int = 2 26 | 27 | 28 | logger = logging.getLogger() 29 | 30 | 31 | def perfetto_to_html(json_file, html_file): 32 | import gzip 33 | import string 34 | 35 | import viztracer 36 | 37 | root = os.path.dirname(viztracer.__file__) 38 | sub = {} 39 | json_file = gzip.open(json_file) if ".gz" in str(json_file) else open(json_file) 40 | with open( 41 | os.path.join(root, "html/trace_viewer_embedder.html"), encoding="utf-8" 42 | ) as f: 43 | tmpl = f.read() 44 | with open(os.path.join(root, "html/trace_viewer_full.html"), encoding="utf-8") as f: 45 | sub["trace_viewer_full"] = f.read() 46 | with json_file as j: 47 | content = j.read() 48 | if isinstance(content, bytes): 49 | content = content.decode("utf-8") 50 | sub["json_data"] = content.replace("", "<\\/script>") # type: ignore 51 | with open(html_file, "w+", encoding="utf-8") as output_file: 52 | output_file.write(string.Template(tmpl).substitute(sub)) 53 | 54 | 55 | class PyTorchProfilerWandb(PyTorchProfiler): 56 | def __init__(self, main_profiler) -> None: 57 | self.main_profiler = main_profiler 58 | self.num_steps = 0 59 | self.pytorch_profiler = torch.profiler.profile( 60 | on_trace_ready=self._on_trace, 61 | profile_memory=True, 62 | record_shapes=True, 63 | # With stack gives huge profile traces 64 | # and bugs out because of some non ascii 65 | # character somewhere in pytorch 66 | with_stack=False, 67 | with_flops=True, 68 | activities=self.ACTIVITIES, 69 | ) 70 | 71 | def _analyze_trace(self, prof: profile): 72 | logger.info("Begin analyze trace") 73 | super()._analyze_trace(prof) 74 | logger.info("End analyze trace") 75 | 76 | def _on_trace(self, prof: torch.profiler.profiler.profile) -> None: 77 | super()._on_trace(prof) 78 | if get_is_master() and wandb.run is not None: 79 | filename = list( 80 | Path(self.main_profiler.output_dir).glob( 81 | "profile_CPU_CUDA*/*.pt.trace.json*" 82 | ) 83 | )[0] 84 | html_path = str(filename).replace(".json", ".html") 85 | perfetto_to_html(filename, html_path) 86 | wandb.log({"profile_trace": wandb.Html(html_path)}) 87 | 88 | 89 | class MemSnapshotsProfilerWandb(MemSnapshotsProfiler): 90 | def __exit__(self, exc_type, exc_val, exc_tb): 91 | super().__exit__(exc_type, exc_val, exc_tb) 92 | if get_is_master() and wandb.run is not None: 93 | filename = list( 94 | Path(self.main_profiler.output_dir).glob("memory_trace_plot/*.html") 95 | )[0] 96 | wandb.log({"memory_trace": wandb.Html(open(filename), inject=False)}) 97 | 98 | 99 | @contextlib.contextmanager 100 | def maybe_run_profiler(dump_dir, module, config: ProfilerArgs): 101 | # get user defined profiler settings 102 | 103 | if config.run: 104 | trace_dir = os.path.join(dump_dir, config.trace_folder) 105 | 106 | logger.info(f"Profiling active. 
Traces will be saved at {trace_dir}") 107 | 108 | if get_is_master() and not os.path.exists(trace_dir): 109 | os.makedirs(trace_dir) 110 | if torch.distributed.is_initialized(): 111 | torch.distributed.barrier() 112 | 113 | with xformers.profiler.profile( 114 | output_dir=trace_dir, 115 | module=module, 116 | schedule=[ 117 | ( 118 | MemSnapshotsProfilerWandb, 119 | config.mem_warmup, 120 | config.mem_warmup + config.mem_steps, 121 | ), 122 | ( 123 | PyTorchProfilerWandb, 124 | config.profile_warmup, 125 | config.profile_warmup + config.profile_steps, 126 | ), 127 | ], 128 | ) as profiler: 129 | yield profiler 130 | 131 | else: 132 | torch_profiler = contextlib.nullcontext() 133 | yield None 134 | -------------------------------------------------------------------------------- /bytelatent/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import logging 4 | import math 5 | import sys 6 | import time 7 | from datetime import timedelta 8 | 9 | import fsspec 10 | 11 | from bytelatent.distributed import get_global_rank, get_is_slurm_job 12 | 13 | 14 | class LogFormatter(logging.Formatter): 15 | """ 16 | Custom logger for distributed jobs, displaying rank 17 | and preserving indent from the custom prefix format. 18 | """ 19 | 20 | def __init__(self): 21 | self.start_time = time.time() 22 | self.rank = get_global_rank() 23 | self.show_rank = not get_is_slurm_job() # srun has --label 24 | 25 | def formatTime(self, record): 26 | subsecond, seconds = math.modf(record.created) 27 | curr_date = ( 28 | time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds)) 29 | + f".{int(subsecond * 1_000_000):06d}" 30 | ) 31 | delta = timedelta(seconds=round(record.created - self.start_time)) 32 | return f"{curr_date} - {delta}" 33 | 34 | def formatPrefix(self, record): 35 | fmt_time = self.formatTime(record) 36 | if self.show_rank: 37 | return f"{self.rank}: {record.levelname:<7} {fmt_time} - " 38 | else: 39 | return f"{record.levelname:<7} {fmt_time} - " 40 | 41 | def formatMessage(self, record, indent: str): 42 | content = record.getMessage() 43 | content = content.replace("\n", "\n" + indent) 44 | # Exception handling as in the default formatter, albeit with indenting 45 | # according to our custom prefix 46 | if record.exc_info: 47 | # Cache the traceback text to avoid converting it multiple times 48 | # (it's constant anyway) 49 | if not record.exc_text: 50 | record.exc_text = self.formatException(record.exc_info) 51 | if record.exc_text: 52 | if content[-1:] != "\n": 53 | content = content + "\n" + indent 54 | content = content + indent.join( 55 | [l + "\n" for l in record.exc_text.splitlines()] 56 | ) 57 | if content[-1:] == "\n": 58 | content = content[:-1] 59 | if record.stack_info: 60 | if content[-1:] != "\n": 61 | content = content + "\n" + indent 62 | stack_text = self.formatStack(record.stack_info) 63 | content = content + indent.join([l + "\n" for l in stack_text.splitlines()]) 64 | if content[-1:] == "\n": 65 | content = content[:-1] 66 | 67 | return content 68 | 69 | def format(self, record): 70 | prefix = self.formatPrefix(record) 71 | indent = " " * len(prefix) 72 | content = self.formatMessage(record, indent) 73 | return prefix + content 74 | 75 | 76 | def set_root_log_level(log_level: str): 77 | logger = logging.getLogger() 78 | level: int | str = log_level.upper() 79 | try: 80 | level = int(log_level) 81 | except ValueError: 82 | pass 83 | try: 84 | logger.setLevel(level) # type: 
ignore 85 | except Exception: 86 | logger.warning( 87 | f"Failed to set logging level to {log_level}, using default 'NOTSET'" 88 | ) 89 | logger.setLevel(logging.NOTSET) 90 | 91 | 92 | def init_logger( 93 | log_file: str | None = None, 94 | *, 95 | name: str | None = None, 96 | level: str = "INFO", 97 | fs: fsspec.AbstractFileSystem | None = None, 98 | ): 99 | """ 100 | Setup logging. 101 | 102 | Args: 103 | log_file: A file name to save file logs to. 104 | name: The name of the logger to configure, by default the root logger. 105 | level: The logging level to use. 106 | """ 107 | set_root_log_level(level) 108 | logger = logging.getLogger(name) 109 | 110 | # stdout: everything 111 | stdout_handler = logging.StreamHandler(sys.stdout) 112 | stdout_handler.setLevel(logging.NOTSET) 113 | stdout_handler.setFormatter(LogFormatter()) 114 | 115 | # stderr: warnings / errors and above 116 | stderr_handler = logging.StreamHandler(sys.stderr) 117 | stderr_handler.setLevel(logging.WARNING) 118 | stderr_handler.setFormatter(LogFormatter()) 119 | 120 | # set stream handlers 121 | logger.handlers.clear() 122 | logger.handlers.append(stdout_handler) 123 | logger.handlers.append(stderr_handler) 124 | 125 | if log_file is not None and get_global_rank() == 0: 126 | # build file handler 127 | if fs is None: 128 | file_handler = logging.FileHandler(log_file, "a") 129 | else: 130 | file_stream = fs.open(log_file, mode="a") 131 | file_handler = logging.StreamHandler(file_stream) 132 | file_handler.setLevel(logging.NOTSET) 133 | file_handler.setFormatter(LogFormatter()) 134 | # update logger 135 | logger = logging.getLogger() 136 | logger.addHandler(file_handler) 137 | -------------------------------------------------------------------------------- /bytelatent/optim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
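# A minimal usage sketch of this module (illustrative only: model, the data
# loading, and the hyperparameter values below are assumptions, not choices made
# by this file):
#
#     args = OptimArgs(lr=3e-4, warmup=2000, scheduler="cosine")
#     optimizer, scheduler = build_optimizer(model, args, n_steps=60_000)
#     for batch in loader:
#         loss = model(batch).mean()
#         loss.backward()
#         nn.utils.clip_grad_norm_(model.parameters(), args.clip)
#         optimizer.step()
#         scheduler.step()
#         optimizer.zero_grad()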
2 | 3 | import logging 4 | import math 5 | from functools import partial 6 | 7 | from pydantic import BaseModel, ConfigDict 8 | from torch import nn 9 | from torch.optim import AdamW, lr_scheduler 10 | 11 | logger = logging.getLogger() 12 | 13 | 14 | class OptimArgs(BaseModel): 15 | model_config = ConfigDict(extra="forbid") 16 | lr: float = 3e-4 17 | weight_decay: float = 0.1 18 | epsilon: float = 1e-8 19 | beta1: float = 0.9 20 | beta2: float = 0.95 21 | clip: float = 1.0 22 | 23 | scheduler: str = "cosine" 24 | warmup: int = 2000 25 | lr_min_ratio: float = 0.1 26 | cycle_length: float = 1.0 27 | cosine_theta: float = 1.0 28 | annealing_step: int = 1000 29 | decay_fraction: float = 0.1 30 | 31 | exp_factor: float = 0.5 32 | 33 | 34 | def lr_linear(step: int, warmup: int, n_steps: int, min_ratio: float) -> float: 35 | if step < warmup: 36 | lr = float(step) / warmup 37 | elif step <= n_steps: 38 | s = float(step - warmup) / (n_steps - warmup) 39 | lr = s * min_ratio + (1 - s) 40 | else: 41 | lr = min_ratio 42 | return lr 43 | 44 | 45 | def lr_inv_sqrt(step: int, warmup: int, exp_factor: float, min_ratio: float) -> float: 46 | if step < warmup: 47 | lr = float(step) / warmup 48 | else: 49 | lr = max((warmup**exp_factor) / (step**exp_factor), min_ratio) 50 | return lr 51 | 52 | 53 | def lr_cosine( 54 | step: int, 55 | warmup: int, 56 | n_steps: int, 57 | cycle_length: float, 58 | theta: float, 59 | min_ratio: float, 60 | ) -> float: 61 | sign = ((step // (n_steps * cycle_length)) % 2) * -2 + 1 62 | if step < warmup: 63 | lr = float(step) / warmup 64 | elif step <= n_steps: 65 | s = float(step - warmup) / (n_steps - warmup) 66 | lr = min_ratio + 0.5 * (1 - min_ratio) * ( 67 | sign * math.cos(math.pi * s**theta / cycle_length) + 1 68 | ) 69 | else: 70 | lr = min_ratio 71 | return lr 72 | 73 | 74 | def lr_wsd( 75 | step: int, 76 | warmup: int, 77 | n_steps: int, 78 | decay_fraction: float, 79 | cycle_length: float, 80 | min_ratio: float, 81 | ) -> float: 82 | """ 83 | UNDERSTANDING WARMUP-STABLE-DECAY LEARNING RATES: A RIVER VALLEY LOSS LANDSCAPE PERSPECTIVE 84 | https://arxiv.org/pdf/2410.05192 85 | """ 86 | cycle_num = step // int(n_steps * cycle_length) + 1 87 | curr_n_steps = int(n_steps * cycle_length) * cycle_num 88 | decay_length = int(curr_n_steps * decay_fraction) 89 | 90 | if step < warmup: 91 | lr = float(step) / warmup 92 | elif step <= curr_n_steps - decay_length: 93 | lr = 1.0 94 | elif step > curr_n_steps - decay_length and step <= curr_n_steps: 95 | # Linear interpolation gives similar results 96 | # slope = -(1.0 - min_ratio) / decay_length 97 | # intercept = min_ratio + ((1.0 - min_ratio) * curr_n_steps) / decay_length 98 | # lr = slope * step + intercept 99 | 100 | step = step - (curr_n_steps - decay_length) 101 | lr = 1 / ((step / curr_n_steps) * (1 / min_ratio) + (1 - step / curr_n_steps)) 102 | else: 103 | lr = min_ratio 104 | 105 | return lr 106 | 107 | 108 | def build_lr_fn(args: OptimArgs, n_steps: int): 109 | if args.scheduler == "constant": 110 | lr_fn = lambda x: 1.0 111 | elif args.scheduler == "linear": 112 | lr_fn = partial( 113 | lr_linear, warmup=args.warmup, n_steps=n_steps, min_ratio=args.lr_min_ratio 114 | ) 115 | elif args.scheduler == "inv_sqrt": 116 | lr_fn = partial( 117 | lr_inv_sqrt, 118 | warmup=args.warmup, 119 | exp_factor=args.exp_factor, 120 | min_ratio=args.lr_min_ratio, 121 | ) 122 | elif args.scheduler == "cosine": 123 | lr_fn = partial( 124 | lr_cosine, 125 | warmup=args.warmup, 126 | n_steps=n_steps, 127 | cycle_length=args.cycle_length, 128 | 
theta=args.cosine_theta, 129 | min_ratio=args.lr_min_ratio, 130 | ) 131 | elif args.scheduler == "wsd": 132 | assert args.decay_fraction < args.cycle_length 133 | lr_fn = partial( 134 | lr_wsd, 135 | warmup=args.warmup, 136 | n_steps=n_steps, 137 | decay_fraction=args.decay_fraction, 138 | cycle_length=args.cycle_length, 139 | min_ratio=args.lr_min_ratio, 140 | ) 141 | else: 142 | raise NotImplementedError(f"Unknown scheduler: {args.scheduler}") 143 | return lr_fn 144 | 145 | 146 | def build_optimizer(model: nn.Module, args: OptimArgs, n_steps: int): 147 | logger.info("Starting build of optimizer...") 148 | optimizer = AdamW( 149 | model.parameters(), 150 | lr=args.lr, 151 | betas=(args.beta1, args.beta2), 152 | weight_decay=args.weight_decay, 153 | eps=args.epsilon, 154 | fused=True, # Faster optim.step but can throw errors 155 | ) 156 | 157 | # scheduler 158 | lr_fn = build_lr_fn(args, n_steps) 159 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_fn) 160 | 161 | logger.info("Done with build of optimizer.") 162 | return optimizer, scheduler 163 | -------------------------------------------------------------------------------- /bytelatent/tokenizers/blt_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | import re 3 | 4 | from bytelatent.tokenizers.abstract_tokenizer import Tokenizer 5 | from bytelatent.tokenizers.constants import ( 6 | BOE_ID, 7 | BOS_ID, 8 | BPE_ID, 9 | BYTE_UNITS, 10 | EOS_ID, 11 | OFFSET, 12 | PAD_ID, 13 | ) 14 | from bytelatent.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer 15 | 16 | 17 | def convert_to_bytes(s): 18 | # check if the output is a bytes like object of the format <0x00> 19 | if re.match(r"<0x[0-9a-fA-F]+>", s): 20 | return bytes.fromhex(s[3:-1]) 21 | else: 22 | return bytes(s, "utf-8", errors="ignore") 23 | 24 | 25 | def text2bytes_bpe_delims( 26 | text: str, 27 | *, 28 | bpe_tokenizer, 29 | bpe_id: int, 30 | offsetting_special_char: int, 31 | add_bos: bool, 32 | add_eos: bool, 33 | ): 34 | cur_bpe = bpe_tokenizer.encode(text, add_bos=add_bos, add_eos=add_eos) 35 | # merge the leading space tokens 36 | leading_space_tokens = [] 37 | other_bpe_tokens = [] 38 | leading = True 39 | for token in cur_bpe: 40 | bpe_str = bpe_tokenizer.sp_model.id_to_piece(token) 41 | if leading and all(c == "▁" for c in bpe_str): 42 | leading_space_tokens.append(bpe_str) 43 | else: 44 | leading = False 45 | other_bpe_tokens.append(bpe_str) 46 | cur_bpe_strs = ["".join(leading_space_tokens)] + other_bpe_tokens 47 | 48 | # Remove the '▁' characters 49 | bpe_strs = [] 50 | for i, bpe_str in enumerate(cur_bpe_strs): 51 | if ( 52 | len(bpe_strs) <= 1 53 | and all([c == " " for s in bpe_strs for c in s]) 54 | and not all(c == "▁" for c in bpe_str) 55 | ): 56 | # Remove leading space for first non space token. 
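# ("▁" is the SentencePiece whitespace marker: the first real piece after the
# merged run of leading-space pieces drops it entirely, while later pieces turn
# it back into a literal space, keeping the byte stream aligned with the text.)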
57 | bpe_str = bpe_str.replace("▁", "") 58 | elif i == 0 and all(c == "▁" for c in bpe_str): 59 | bpe_str = " " * (len(text) - len(text.lstrip(" "))) 60 | else: 61 | bpe_str = bpe_str.replace("▁", " ") 62 | if len(bpe_str) > 0: 63 | bpe_strs.append(bpe_str) 64 | ex_seq = [] 65 | # Convert bpe tokens to bytes 66 | for s in bpe_strs: 67 | byte_chunk = convert_to_bytes(s) 68 | proc_chunk = [int(unit) for unit in byte_chunk] 69 | ex_seq.extend([bpe_id - offsetting_special_char] + proc_chunk) 70 | 71 | return ex_seq 72 | 73 | 74 | class BltTokenizer(Tokenizer): 75 | def __init__( 76 | self, 77 | *, 78 | vocab_size_unit_1: int = BYTE_UNITS, 79 | bpe_delim: bool = False, 80 | bpe_tokenizer_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model", 81 | add_bos: bool = True, 82 | add_eos: bool = True, 83 | ): 84 | self.add_bos = add_bos 85 | self.add_eos = add_eos 86 | self.vocab_size_unit_1 = vocab_size_unit_1 87 | self.boe_id = BOE_ID 88 | self.bos_id = BOS_ID 89 | self.eos_id = EOS_ID 90 | self.pad_id = PAD_ID 91 | self.bpe_id = BPE_ID 92 | self.bpe_tokenizer_path = bpe_tokenizer_path 93 | if bpe_delim: 94 | self.bpe_tokenizer = SentencePieceTokenizer( 95 | model_path=self.bpe_tokenizer_path 96 | ) 97 | else: 98 | self.bpe_tokenizer = None 99 | self.bpe_delim = bpe_delim 100 | self.offsetting_special_char = OFFSET 101 | self.vocab_size_unit_1 = vocab_size_unit_1 102 | self.n_words = vocab_size_unit_1 + self.offsetting_special_char 103 | 104 | def get_vocab_size(self) -> int: 105 | return self.n_words 106 | 107 | def encode( 108 | self, text: str, add_bos: bool | None = None, add_eos: bool | None = None 109 | ): 110 | if add_bos is None: 111 | add_bos = self.add_bos 112 | if add_eos is None: 113 | add_eos = self.add_eos 114 | 115 | if self.bpe_delim: 116 | tokens = text2bytes_bpe_delims( 117 | text, 118 | bpe_tokenizer=self.bpe_tokenizer, 119 | bpe_id=self.bpe_id, 120 | offsetting_special_char=self.offsetting_special_char, 121 | add_bos=False, 122 | add_eos=False, 123 | ) 124 | else: 125 | tokens = bytes(text, encoding="utf-8", errors="ignore") 126 | 127 | # Offsetting 128 | tokens = [int(unit) + self.offsetting_special_char for unit in tokens] 129 | 130 | if add_bos: 131 | tokens.insert(0, self.bos_id) 132 | if add_eos: 133 | tokens.append(self.eos_id) 134 | 135 | return tokens 136 | 137 | def decode(self, tokens: list[int], cut_at_eos: bool = False): 138 | if cut_at_eos: 139 | for k, t in enumerate(tokens): 140 | if t == self.eos_id: 141 | tokens = tokens[: k + 1] 142 | break 143 | return bytes( 144 | [ 145 | tok - self.offsetting_special_char 146 | for tok in tokens 147 | if tok - self.offsetting_special_char >= 0 148 | ] 149 | ).decode("utf-8", errors="ignore") 150 | 151 | def get_token_offsets(self, text: str, tokens: list[int] | None = None): 152 | # TODO: Figure out what this does 153 | raise NotImplementedError() 154 | -------------------------------------------------------------------------------- /bytelatent/float8.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
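# A rough usage sketch (illustrative; the filter regex and the compile call are
# assumptions about how a caller might wire this module up, not part of it):
#
#     model = convert_linears_to_fp8(model, recipe="rowwise", filter=r"layers\.\d+\.")
#     model = torch.compile(model)  # optional: lets Inductor fuse the amax/scale reductions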
2 | 3 | import re 4 | import warnings 5 | from typing import Callable 6 | 7 | import torch 8 | 9 | # avoid division by zero when calculating scale 10 | EPS = 1e-12 11 | 12 | 13 | def scale(t, amax_t, dtype_t): 14 | min_v, max_v = torch.finfo(dtype_t).min, torch.finfo(dtype_t).max 15 | scale_t = torch.clamp(amax_t.float(), min=EPS) / max_v 16 | t_fp8 = (t / scale_t).clamp(min=min_v, max=max_v).to(dtype_t) 17 | return t_fp8, scale_t 18 | 19 | 20 | def matmul( 21 | first, amax_first, dtype_first, second_t, amax_second_t, dtype_second_t, bias 22 | ): 23 | first_fp8, scale_first = scale(first, amax_first, dtype_first) 24 | second_t_fp8, scale_second_t = scale(second_t, amax_second_t, dtype_second_t) 25 | output = torch._scaled_mm( 26 | first_fp8, 27 | second_t_fp8.t(), 28 | scale_a=scale_first, 29 | scale_b=scale_second_t.t(), 30 | bias=bias, 31 | out_dtype=torch.bfloat16, 32 | use_fast_accum=True, 33 | ) 34 | return output 35 | 36 | 37 | @torch._dynamo.allow_in_graph 38 | class Fp8LinearFn(torch.autograd.Function): 39 | @staticmethod 40 | def forward(ctx, a, b_t, bias): 41 | amax_a = a.abs().amax(dim=-1, keepdim=True) 42 | amax_b_t = b_t.abs().amax(dim=-1, keepdim=True) 43 | out = matmul( 44 | a, amax_a, torch.float8_e4m3fn, b_t, amax_b_t, torch.float8_e4m3fn, bias 45 | ) 46 | 47 | ctx.a_requires_grad = a.requires_grad 48 | ctx.b_requires_grad = b_t.requires_grad 49 | ctx.bias_requires_grad = bias.requires_grad if bias is not None else False 50 | 51 | ctx.save_for_backward(a, b_t, amax_b_t.max()) 52 | 53 | return out 54 | 55 | @staticmethod 56 | def backward(ctx, grad_out): 57 | a, b_t, amax_b = ctx.saved_tensors 58 | 59 | if ctx.a_requires_grad: 60 | b = b_t.t().contiguous() 61 | amax_grad_out = grad_out.abs().amax(dim=-1, keepdim=True) 62 | amax_b = amax_b.repeat(b.shape[0], 1) 63 | grad_a = matmul( 64 | grad_out, 65 | amax_grad_out, 66 | torch.float8_e4m3fn, 67 | b, 68 | amax_b, 69 | torch.float8_e4m3fn, 70 | None, 71 | ) 72 | else: 73 | grad_a = None 74 | if ctx.b_requires_grad: 75 | grad_b = grad_out.t() @ a 76 | else: 77 | grad_b = None 78 | if ctx.bias_requires_grad: 79 | grad_bias = grad_out.sum(dim=0) 80 | else: 81 | grad_bias = None 82 | 83 | return grad_a, grad_b, grad_bias 84 | 85 | 86 | class Fp8Linear(torch.nn.Linear): 87 | def forward(self, input: torch.Tensor) -> torch.Tensor: 88 | out = Fp8LinearFn.apply(input.flatten(end_dim=-2), self.weight, self.bias) 89 | out = out.unflatten(0, input.shape[:-1]) 90 | return out 91 | 92 | 93 | def named_replace( 94 | fn: Callable[[torch.nn.Module, str], torch.nn.Module], 95 | module: torch.nn.Module, 96 | name="", 97 | ) -> torch.nn.Module: 98 | for child_name, child_module in list(module.named_children()): 99 | full_name = f"{name}.{child_name}" if name else child_name 100 | new_child_module = named_replace(fn, child_module, full_name) 101 | setattr(module, child_name, new_child_module) 102 | module = fn(module, name) 103 | return module 104 | 105 | 106 | def convert_linears_to_fp8( 107 | root_module: torch.nn.Module, recipe: str, filter: str 108 | ) -> torch.nn.Module: 109 | if recipe not in ["rowwise"]: 110 | raise RuntimeError(f"Unknown float8 recipe {recipe!r}") 111 | 112 | if recipe == "rowwise" and torch.__version__ < "2.5": 113 | # We need https://github.com/pytorch/pytorch/pull/134781. 114 | warnings.warn("Float8 row-wise scaling is slow in PyTorch prior to v2.5.0") 115 | 116 | # Multi-kernel makes Inductor auto-tune between a regular "streaming"-based 117 | # reduction kernel and a "persistent" reduction kernel. 
Since fp8 has some 118 | # multi-pass steps (e.g., first get amax, then scale), persistent kernels 119 | # should perform better. 120 | torch._inductor.config.triton.multi_kernel = 1 121 | 122 | filter_re = re.compile(filter) 123 | 124 | def replace(module: torch.nn.Module, name: str) -> torch.nn.Module: 125 | if not isinstance(module, torch.nn.Linear) or not filter_re.search(name): 126 | return module 127 | if type(module) == torch.nn.Linear: 128 | if recipe == "rowwise": 129 | new_module = Fp8Linear( 130 | in_features=module.in_features, 131 | out_features=module.out_features, 132 | bias=module.bias is not None, 133 | dtype=module.weight.dtype, 134 | device=module.weight.device, 135 | ) 136 | new_module.weight = module.weight 137 | new_module.bias = module.bias 138 | else: 139 | assert False, recipe 140 | else: 141 | assert False, str(type(module)) 142 | return new_module 143 | 144 | out = named_replace(replace, root_module) 145 | 146 | # Force re-compile everything 147 | torch._dynamo.reset_code_caches() 148 | from torch._inductor.cudagraph_trees import reset_cudagraph_trees 149 | 150 | reset_cudagraph_trees() 151 | 152 | return out 153 | -------------------------------------------------------------------------------- /plot_data/entropy_figure.json: -------------------------------------------------------------------------------- 1 | {"text":"Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.","threshold":1.335442066192627,"dataframe_json":"{\"position\":{\"0\":0,\"1\":1,\"2\":2,\"3\":3,\"4\":4,\"5\":5,\"6\":6,\"7\":7,\"8\":8,\"9\":9,\"10\":10,\"11\":11,\"12\":12,\"13\":13,\"14\":14,\"15\":15,\"16\":16,\"17\":17,\"18\":18,\"19\":19,\"20\":20,\"21\":21,\"22\":22,\"23\":23,\"24\":24,\"25\":25,\"26\":26,\"27\":27,\"28\":28,\"29\":29,\"30\":30,\"31\":31,\"32\":32,\"33\":33,\"34\":34,\"35\":35,\"36\":36,\"37\":37,\"38\":38,\"39\":39,\"40\":40,\"41\":41,\"42\":42,\"43\":43,\"44\":44,\"45\":45,\"46\":46,\"47\":47,\"48\":48,\"49\":49,\"50\":50,\"51\":51,\"52\":52,\"53\":53,\"54\":54,\"55\":55,\"56\":56,\"57\":57,\"58\":58,\"59\":59,\"60\":60,\"61\":61,\"62\":62,\"63\":63,\"64\":64,\"65\":65,\"66\":66,\"67\":67,\"68\":68,\"69\":69,\"70\":70,\"71\":71,\"72\":72,\"73\":73,\"74\":74,\"75\":75,\"76\":76,\"77\":77,\"78\":78,\"79\":79,\"80\":80},\"tokens\":{\"0\":\"<\",\"1\":\"D\",\"2\":\"a\",\"3\":\"e\",\"4\":\"n\",\"5\":\"e\",\"6\":\"r\",\"7\":\"y\",\"8\":\"s\",\"9\":\"_\",\"10\":\"T\",\"11\":\"a\",\"12\":\"r\",\"13\":\"g\",\"14\":\"a\",\"15\":\"r\",\"16\":\"y\",\"17\":\"e\",\"18\":\"n\",\"19\":\"_\",\"20\":\"i\",\"21\":\"s\",\"22\":\"_\",\"23\":\"i\",\"24\":\"n\",\"25\":\"_\",\"26\":\"G\",\"27\":\"a\",\"28\":\"m\",\"29\":\"e\",\"30\":\"_\",\"31\":\"o\",\"32\":\"f\",\"33\":\"_\",\"34\":\"T\",\"35\":\"h\",\"36\":\"r\",\"37\":\"o\",\"38\":\"n\",\"39\":\"e\",\"40\":\"s\",\"41\":\",\",\"42\":\"_\",\"43\":\"a\",\"44\":\"_\",\"45\":\"f\",\"46\":\"a\",\"47\":\"n\",\"48\":\"t\",\"49\":\"a\",\"50\":\"s\",\"51\":\"y\",\"52\":\"_\",\"53\":\"e\",\"54\":\"p\",\"55\":\"i\",\"56\":\"c\",\"57\":\"_\",\"58\":\"b\",\"59\":\"y\",\"60\":\"_\",\"61\":\"G\",\"62\":\"e\",\"63\":\"o\",\"64\":\"r\",\"65\":\"g\",\"66\":\"e\",\"67\":\"_\",\"68\":\"R\",\"69\":\".\",\"70\":\"R\",\"71\":\".\",\"72\":\"_\",\"73\":\"M\",\"74\":\"a\",\"75\":\"r\",\"76\":\"t\",\"77\":\"i\",\"78\":\"n\",\"79\":\".\",\"80\":\">\"},\"token_ids\":{\"0\":1,\"1\":72,\"2\":101,\"3\":105,\"4\":114,\"5\":105,\"6\":118,\"7\":125,\"8\":119,\"9\":36,\"10\":88,\"11\":101,\"12\":118,\"13\":107,\"14\":101,\"15\":118,\"16\":125,\"17\":105
,\"18\":114,\"19\":36,\"20\":109,\"21\":119,\"22\":36,\"23\":109,\"24\":114,\"25\":36,\"26\":75,\"27\":101,\"28\":113,\"29\":105,\"30\":36,\"31\":115,\"32\":106,\"33\":36,\"34\":88,\"35\":108,\"36\":118,\"37\":115,\"38\":114,\"39\":105,\"40\":119,\"41\":48,\"42\":36,\"43\":101,\"44\":36,\"45\":106,\"46\":101,\"47\":114,\"48\":120,\"49\":101,\"50\":119,\"51\":125,\"52\":36,\"53\":105,\"54\":116,\"55\":109,\"56\":103,\"57\":36,\"58\":102,\"59\":125,\"60\":36,\"61\":75,\"62\":105,\"63\":115,\"64\":118,\"65\":107,\"66\":105,\"67\":36,\"68\":86,\"69\":50,\"70\":86,\"71\":50,\"72\":36,\"73\":81,\"74\":101,\"75\":118,\"76\":120,\"77\":109,\"78\":114,\"79\":50,\"80\":2},\"entropies\":{\"0\":3.3949158192,\"1\":2.1656451225,\"2\":2.3216569424,\"3\":2.8214058876,\"4\":1.5249242783,\"5\":0.0401624143,\"6\":0.0981037766,\"7\":0.0544578359,\"8\":0.3430138826,\"9\":1.0546212196,\"10\":0.25252828,\"11\":0.1494535804,\"12\":0.0624754503,\"13\":0.001355894,\"14\":0.0050173439,\"15\":0.0052358187,\"16\":0.0011725067,\"17\":0.0010307421,\"18\":1.0241208076,\"19\":3.6867966652,\"20\":0.4502205253,\"21\":0.0484119244,\"22\":2.2572875023,\"23\":0.3789347112,\"24\":1.0042934418,\"25\":2.9090054035,\"26\":1.8933598995,\"27\":1.3859074116,\"28\":0.3827198744,\"29\":0.2646365762,\"30\":1.7742085457,\"31\":0.0136727821,\"32\":0.0053820172,\"33\":0.5485631227,\"34\":0.2064044327,\"35\":0.0049266233,\"36\":0.0005439016,\"37\":0.0007023578,\"38\":0.0004170335,\"39\":0.0054524317,\"40\":1.1938130856,\"41\":0.0238215197,\"42\":3.1279797554,\"43\":1.3883389235,\"44\":3.0503094196,\"45\":1.695879817,\"46\":1.8551058769,\"47\":1.4570231438,\"48\":0.0047810897,\"49\":0.026396824,\"50\":0.6633765101,\"51\":0.3141393065,\"52\":2.8411159515,\"53\":1.143143177,\"54\":0.0520330966,\"55\":0.3398066461,\"56\":0.4140175879,\"57\":2.5563707352,\"58\":1.3370712996,\"59\":0.0227173548,\"60\":3.4447185993,\"61\":1.8576486111,\"62\":0.8189754486,\"63\":0.6776530743,\"64\":0.0677763447,\"65\":0.212713033,\"66\":0.1003480032,\"67\":0.1746164262,\"68\":0.4123829603,\"69\":0.5507118702,\"70\":0.1047425047,\"71\":0.0194335245,\"72\":0.001482119,\"73\":0.0009310447,\"74\":0.0002176317,\"75\":0.0076908777,\"76\":0.0003866984,\"77\":0.0008008487,\"78\":1.2395234108,\"79\":0.4564163089,\"80\":0.0000461392},\"patch\":{\"0\":0,\"1\":1,\"2\":2,\"3\":3,\"4\":4,\"5\":5,\"6\":5,\"7\":5,\"8\":5,\"9\":5,\"10\":5,\"11\":5,\"12\":5,\"13\":5,\"14\":5,\"15\":5,\"16\":5,\"17\":5,\"18\":5,\"19\":5,\"20\":6,\"21\":6,\"22\":6,\"23\":7,\"24\":7,\"25\":7,\"26\":8,\"27\":9,\"28\":10,\"29\":10,\"30\":10,\"31\":11,\"32\":11,\"33\":11,\"34\":11,\"35\":11,\"36\":11,\"37\":11,\"38\":11,\"39\":11,\"40\":11,\"41\":11,\"42\":11,\"43\":12,\"44\":13,\"45\":14,\"46\":15,\"47\":16,\"48\":17,\"49\":17,\"50\":17,\"51\":17,\"52\":17,\"53\":18,\"54\":18,\"55\":18,\"56\":18,\"57\":18,\"58\":19,\"59\":20,\"60\":20,\"61\":21,\"62\":22,\"63\":22,\"64\":22,\"65\":22,\"66\":22,\"67\":22,\"68\":22,\"69\":22,\"70\":22,\"71\":22,\"72\":22,\"73\":22,\"74\":22,\"75\":22,\"76\":22,\"77\":22,\"78\":22,\"79\":22,\"80\":22},\"start\":{\"0\":1,\"1\":1,\"2\":1,\"3\":1,\"4\":1,\"5\":1,\"6\":0,\"7\":0,\"8\":0,\"9\":0,\"10\":0,\"11\":0,\"12\":0,\"13\":0,\"14\":0,\"15\":0,\"16\":0,\"17\":0,\"18\":0,\"19\":0,\"20\":1,\"21\":0,\"22\":0,\"23\":1,\"24\":0,\"25\":0,\"26\":1,\"27\":1,\"28\":1,\"29\":0,\"30\":0,\"31\":1,\"32\":0,\"33\":0,\"34\":0,\"35\":0,\"36\":0,\"37\":0,\"38\":0,\"39\":0,\"40\":0,\"41\":0,\"42\":0,\"43\":1,\"44\":1,\"45\":1,\"46\":1,\"47\":1,\"48\":1,\"49\":0,\"50\":0,\"51\":0,\"52\":0,\"53
\":1,\"54\":0,\"55\":0,\"56\":0,\"57\":0,\"58\":1,\"59\":1,\"60\":0,\"61\":1,\"62\":1,\"63\":0,\"64\":0,\"65\":0,\"66\":0,\"67\":0,\"68\":0,\"69\":0,\"70\":0,\"71\":0,\"72\":0,\"73\":0,\"74\":0,\"75\":0,\"76\":0,\"77\":0,\"78\":0,\"79\":0,\"80\":0}}"} -------------------------------------------------------------------------------- /bytelatent/test_config_parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from omegaconf import DictConfig, MissingMandatoryValue, OmegaConf 5 | from pydantic import BaseModel, ConfigDict 6 | 7 | from bytelatent.config_parser import ( 8 | parse_args_to_pydantic_model, 9 | parse_file_config, 10 | recursively_parse_config, 11 | ) 12 | 13 | FIXTURE_DIR = "fixtures/test-cfgs" 14 | 15 | 16 | def test_parse_file_config(): 17 | with pytest.raises(ValueError): 18 | cfg = parse_file_config(os.path.join(FIXTURE_DIR, "list.yaml")) 19 | assert isinstance(cfg, DictConfig) 20 | 21 | 22 | def test_nop(): 23 | cfg = OmegaConf.create({"a": 1}) 24 | parsed_cfgs = recursively_parse_config(cfg) 25 | assert len(parsed_cfgs) == 1 26 | assert parsed_cfgs[0] == cfg 27 | 28 | 29 | def test_root(): 30 | cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "root.yaml")}) 31 | parsed_cfgs = recursively_parse_config(cli_cfg) 32 | assert len(parsed_cfgs) == 2 33 | assert len(parsed_cfgs[1]) == 0 34 | assert parsed_cfgs[0]["seed"] == -1 35 | with pytest.raises(MissingMandatoryValue): 36 | assert parsed_cfgs[0]["b"]["y"] is not None 37 | 38 | # Test basic cli override 39 | cli_cfg = OmegaConf.create( 40 | {"config": os.path.join(FIXTURE_DIR, "root.yaml"), "seed": 42} 41 | ) 42 | parsed_cfgs = recursively_parse_config(cli_cfg) 43 | assert parsed_cfgs[1]["seed"] == 42 44 | cfg = OmegaConf.merge(*parsed_cfgs) 45 | assert cfg["seed"] == 42 46 | 47 | 48 | def test_one_level_include(): 49 | cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "middle.yaml")}) 50 | parsed_cfgs = recursively_parse_config(cli_cfg) 51 | assert len(parsed_cfgs) == 3 52 | assert parsed_cfgs[0]["seed"] == -1 53 | assert parsed_cfgs[1]["b"]["y"] == 10 54 | assert len(parsed_cfgs[2]) == 0 55 | cfg = OmegaConf.merge(*parsed_cfgs) 56 | assert cfg["b"]["y"] == 10 57 | 58 | cli_cfg = OmegaConf.create( 59 | {"config": os.path.join(FIXTURE_DIR, "middle.yaml"), "b": {"y": 100}} 60 | ) 61 | parsed_cfgs = recursively_parse_config(cli_cfg) 62 | assert len(parsed_cfgs) == 3 63 | assert parsed_cfgs[0]["seed"] == -1 64 | assert parsed_cfgs[1]["b"]["y"] == 10 65 | assert parsed_cfgs[2]["b"]["y"] == 100 66 | cfg = OmegaConf.merge(*parsed_cfgs) 67 | assert cfg["b"]["y"] == 100 68 | 69 | 70 | def test_two_level_include(): 71 | cli_cfg = OmegaConf.create( 72 | {"config": os.path.join(FIXTURE_DIR, "top.yaml"), "p": 500, "b": {"z": -2}} 73 | ) 74 | parsed_cfgs = recursively_parse_config(cli_cfg) 75 | assert len(parsed_cfgs) == 4 76 | assert parsed_cfgs[0]["seed"] == -1 77 | assert parsed_cfgs[1]["b"]["y"] == 10 78 | assert parsed_cfgs[2]["hello"] == "world" 79 | assert parsed_cfgs[3]["p"] == 500 80 | assert parsed_cfgs[3]["b"]["z"] == -2 81 | cfg = OmegaConf.merge(*parsed_cfgs) 82 | assert cfg["a"] == 1 83 | assert cfg["seed"] == -1 84 | assert cfg["b"]["x"] == 0 85 | assert cfg["b"]["y"] == 10 86 | assert cfg["b"]["z"] == -2 87 | assert cfg["hello"] == "world" 88 | 89 | 90 | def test_multiple_includes(): 91 | cli_cfg = OmegaConf.create( 92 | { 93 | "config": [ 94 | os.path.join(FIXTURE_DIR, "top.yaml"), 95 | os.path.join(FIXTURE_DIR, 
"override.yaml"), 96 | ], 97 | "p": 500, 98 | "b": {"z": -2}, 99 | } 100 | ) 101 | parsed_cfgs = recursively_parse_config(cli_cfg) 102 | assert len(parsed_cfgs) == 5 103 | assert parsed_cfgs[0]["seed"] == -1 104 | assert parsed_cfgs[1]["b"]["y"] == 10 105 | assert parsed_cfgs[2]["hello"] == "world" 106 | assert parsed_cfgs[3]["a"] == 100 107 | assert parsed_cfgs[4]["p"] == 500 108 | assert parsed_cfgs[4]["b"]["z"] == -2 109 | cfg = OmegaConf.merge(*parsed_cfgs) 110 | assert cfg["a"] == 100 111 | assert cfg["seed"] == -1 112 | assert cfg["b"]["x"] == 0 113 | assert cfg["b"]["y"] == 10 114 | assert cfg["b"]["z"] == -2 115 | assert cfg["hello"] == "world" 116 | 117 | cli_cfg = OmegaConf.create( 118 | { 119 | "config": [ 120 | os.path.join(FIXTURE_DIR, "top.yaml"), 121 | os.path.join(FIXTURE_DIR, "override.yaml"), 122 | ], 123 | "p": 500, 124 | "b": {"z": -2}, 125 | "a": 1000, 126 | } 127 | ) 128 | parsed_cfgs = recursively_parse_config(cli_cfg) 129 | assert len(parsed_cfgs) == 5 130 | assert parsed_cfgs[0]["seed"] == -1 131 | assert parsed_cfgs[1]["b"]["y"] == 10 132 | assert parsed_cfgs[2]["hello"] == "world" 133 | assert parsed_cfgs[3]["a"] == 100 134 | assert parsed_cfgs[4]["p"] == 500 135 | assert parsed_cfgs[4]["b"]["z"] == -2 136 | cfg = OmegaConf.merge(*parsed_cfgs) 137 | assert cfg["a"] == 1000 138 | assert cfg["seed"] == -1 139 | assert cfg["b"]["x"] == 0 140 | assert cfg["b"]["y"] == 10 141 | assert cfg["b"]["z"] == -2 142 | assert cfg["hello"] == "world" 143 | 144 | 145 | class SubConfig(BaseModel): 146 | model_config = ConfigDict(extra="forbid") 147 | x: int = -100 148 | y: int = -100 149 | z: int = -5 150 | 151 | 152 | class SampleConfig(BaseModel): 153 | model_config = ConfigDict(extra="forbid") 154 | a: int = -100 155 | seed: int = -100 156 | b: SubConfig = SubConfig() 157 | hello: str = "" 158 | p: int = -100 159 | 160 | 161 | def test_pydantic_parse(): 162 | cli_cfg = OmegaConf.create( 163 | { 164 | "config": [ 165 | os.path.join(FIXTURE_DIR, "top.yaml"), 166 | os.path.join(FIXTURE_DIR, "override.yaml"), 167 | ], 168 | "p": 500, 169 | "a": 1000, 170 | } 171 | ) 172 | cfg = parse_args_to_pydantic_model(SampleConfig, cli_args=cli_cfg) 173 | assert isinstance(cfg, SampleConfig) 174 | assert cfg.a == 1000 175 | assert cfg.p == 500 176 | assert cfg.seed == -1 177 | assert cfg.b.x == 0 178 | assert cfg.b.y == 10 179 | assert cfg.b.z == -5 180 | assert cfg.hello == "world" 181 | -------------------------------------------------------------------------------- /bytelatent/data/ngram_processor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | import pickle 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | 7 | from bytelatent import ByteLatentError 8 | 9 | LOOKUP_OFFSET = 4 10 | 11 | 12 | def apply_lookup_table_wrapper(ngram_to_idx: dict[tuple, int], lookup_offset=1): 13 | """ 14 | Wrapper function for applying the lookup table to each n-gram. 15 | 16 | :param ngram: Array of numbers representing an n-gram. 17 | :param lookup_table: Dictionary where keys are tuples (n-grams) and values are the desired outputs. 18 | :param lookup_offset: Offset to add to the lookup result. 19 | :return: The value associated with the n-gram tuple in the dictionary, or None if not found. 20 | """ 21 | 22 | def apply_lookup_table(ngram): 23 | """ 24 | Function to apply to each n-gram: converts it to a tuple and looks it up in a dictionary. 25 | 26 | :param ngram: Array of numbers representing an n-gram. 
27 | :return: The value associated with the n-gram tuple in the dictionary, or None if not found. 28 | """ 29 | # Convert the n-gram to a tuple 30 | ngram_tuple = tuple(ngram) 31 | 32 | if ngram_tuple not in ngram_to_idx: 33 | return 0 34 | else: 35 | return ngram_to_idx[ngram_tuple] + lookup_offset 36 | 37 | return apply_lookup_table 38 | 39 | 40 | def get_byte_ngrams_ids( 41 | byte_array: np.ndarray, n: int, ngram_to_idx: dict[tuple, int], pad_value=0 42 | ): 43 | """ 44 | Generate n-grams from a 2D numpy array. 45 | 46 | :param n: The length of each n-gram. 47 | :param pad_value: The value used for padding of the byte values to maintain the same dimensions for the n-grams. 48 | :return: A 2D numpy array where each element is the ID of an n-gram offset by LOOKUP_OFFSET. 49 | """ 50 | num_rows, num_cols = byte_array.shape 51 | 52 | # Create an array to hold the padded version of the original array 53 | padded_array = np.pad( 54 | byte_array, ((0, 0), (n - 1, 0)), mode="constant", constant_values=pad_value 55 | ) 56 | 57 | # Use stride tricks to avoid explicit looping 58 | strided = np.lib.stride_tricks.as_strided 59 | shape = (num_rows, num_cols, n) 60 | strides = padded_array.strides[:2] + (padded_array.strides[1],) 61 | ngrams = strided(padded_array, shape=shape, strides=strides) 62 | 63 | ngram_ids = np.apply_along_axis( 64 | apply_lookup_table_wrapper(ngram_to_idx, lookup_offset=LOOKUP_OFFSET), 2, ngrams 65 | ) 66 | assert ngram_ids.shape == byte_array.shape 67 | return ngram_ids 68 | 69 | 70 | def reload_tables( 71 | ngram_table_dir: str, ngram_to_size: dict[int, int], offset: int = LOOKUP_OFFSET 72 | ) -> tuple[dict[int, list], dict[tuple, int], dict[int, int]]: 73 | """ 74 | Reload lookup tables from a directory. Reload only the ngrams in the dictionary and per ngram, 75 | only load up to the max specified size. Return the actual number of ngrams taken per ngram size. 
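    Illustrative call (directory and sizes are hypothetical): with LOOKUP_OFFSET = 4,
    reload_tables("tables/", {2: 1000, 3: 500}) yields vocab_sizes == {2: 1004, 3: 504},
    alongside the per-ngram-size lookup tables in both directions.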
76 | """ 77 | idx_to_ngram_tables = {} 78 | ngram_to_idx_tables = {} 79 | vocab_sizes = {} 80 | for ngram, size in ngram_to_size.items(): 81 | with open(Path(ngram_table_dir) / f"ngram-{ngram}.pickle", "rb") as f: 82 | # These are already sorted by count 83 | # Value: tuple of: count, ngram, dataset 84 | ngram_data: list[tuple[tuple, tuple[int, int, str]]] = pickle.load(f)[ 85 | "counts" 86 | ] 87 | table = [ngram for ngram, _ in ngram_data][:size] 88 | if len(table) != size: 89 | raise ValueError( 90 | f"Ngram table for {ngram}-gram is not large enough to get {size} ngrams, max size is {len(ngram_data)}" 91 | ) 92 | ngram_to_idx = {ngram: idx for idx, ngram in enumerate(table)} 93 | actual_size = len(table) 94 | idx_to_ngram_tables[ngram] = table 95 | ngram_to_idx_tables[ngram] = ngram_to_idx 96 | vocab_sizes[ngram] = actual_size + offset 97 | return ngram_to_idx_tables, idx_to_ngram_tables, vocab_sizes 98 | 99 | 100 | def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int] | None: 101 | if ngram_to_size_str is None: 102 | return None 103 | ngram_to_size = {} 104 | for entry in ngram_to_size_str.split(","): 105 | ngram, size = entry.split(":") 106 | ngram = int(ngram) 107 | size = int(size) 108 | ngram_to_size[ngram] = size 109 | return ngram_to_size 110 | 111 | 112 | class NgramProcessor: 113 | def __init__( 114 | self, 115 | ngram_table_dir: str | None = None, 116 | ngram_to_size: dict[int, int] | None = None, 117 | ): 118 | if ngram_table_dir is None or ngram_to_size is None: 119 | raise ByteLatentError( 120 | "ngram_table_dir and ngram_to_size cannot be None if enable_byte_ngrams is True" 121 | ) 122 | ( 123 | self.ngram_to_idx_tables, 124 | self.idx_to_ngram_tables, 125 | self.ngram_vocab_sizes, 126 | ) = reload_tables(ngram_table_dir, ngram_to_size) 127 | # Lowest to highest ngram 128 | self.ngram_sizes = sorted(list(self.ngram_to_idx_tables.keys())) 129 | # Although the model might not use all the ngrams, we need the tokenizer 130 | # to produce ngram_ids such that index zero is the 2-gram, later on in 131 | # src.model.megabyte.Megabyte.forward 132 | assert self.ngram_sizes[0] == 2 133 | 134 | def encode_single_ngram_table(self, data: np.ndarray, n: int): 135 | """ 136 | Return the n-gram ids of the input data for a given n, 137 | as a numpy array of ids with shape data.shape 138 | """ 139 | return get_byte_ngrams_ids(data, n, self.ngram_to_idx_tables[n], pad_value=0) 140 | 141 | def encode_token_ngrams(self, data: np.ndarray): 142 | """ 143 | Return the n-grams of the input data. 144 | output shape: [ids with data.shape for n in self.ngram_sizes] 145 | """ 146 | return [self.encode_single_ngram_table(data, n) for n in self.ngram_sizes] 147 | -------------------------------------------------------------------------------- /setup/download_prepare_hf_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
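# A typical invocation, as a sketch (the dataset name must be one of the keys
# handled in main(); memory is the shuffle-buffer budget handed to terashuf
# through the MEMORY environment variable):
#
#     python setup/download_prepare_hf_data.py fineweb_edu_10bt 64 --data_dir ./data --nchunks 32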
2 | 3 | import argparse 4 | import os 5 | import subprocess 6 | import time 7 | 8 | import fsspec 9 | import requests 10 | from huggingface_hub import snapshot_download 11 | 12 | 13 | def run_command(command): 14 | print(f"Running: {command}") 15 | subprocess.run(command, shell=True, check=True) 16 | 17 | 18 | def download_dataset(repo_id, local_dir, allow_patterns): 19 | print(f"Downloading dataset from {repo_id}...") 20 | max_retries = 5 21 | retry_delay = 10 # seconds 22 | for attempt in range(max_retries): 23 | try: 24 | snapshot_download( 25 | repo_id, 26 | repo_type="dataset", 27 | local_dir=local_dir, 28 | allow_patterns=allow_patterns, 29 | resume_download=True, 30 | max_workers=16, # Don't hesitate to increase this number to lower the download time 31 | ) 32 | break 33 | except requests.exceptions.ReadTimeout: 34 | if attempt < max_retries - 1: 35 | print(f"Timeout occurred. Retrying in {retry_delay} seconds...") 36 | time.sleep(retry_delay) 37 | else: 38 | raise 39 | print(f"Dataset downloaded to {local_dir}") 40 | 41 | 42 | def parquet_to_jsonl( 43 | dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None 44 | ): 45 | from datatrove.executor import LocalPipelineExecutor 46 | from datatrove.pipeline.readers import ParquetReader 47 | from datatrove.pipeline.writers import JsonlWriter 48 | 49 | if tgt_dir.startswith("s3://"): 50 | if s3_profile is None: 51 | out_spec = tgt_dir 52 | else: 53 | out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile)) 54 | else: 55 | out_spec = tgt_dir 56 | 57 | pipeline_exec = LocalPipelineExecutor( 58 | pipeline=[ 59 | ParquetReader( 60 | src_dir, 61 | file_progress=True, 62 | doc_progress=True, 63 | glob_pattern="**/*.parquet", 64 | ), 65 | JsonlWriter( 66 | out_spec, 67 | output_filename=dataset + ".chunk.${rank}.jsonl", 68 | compression=None, 69 | ), 70 | ], 71 | tasks=ntasks, 72 | logging_dir=os.path.join(work_dir, "datatrove"), 73 | ) 74 | pipeline_exec.run() 75 | 76 | 77 | def setup_terashuf(work_dir): 78 | terashuf_dir = os.path.join(work_dir, "terashuf") 79 | terashuf_executable = os.path.join(terashuf_dir, "terashuf") 80 | 81 | if os.path.exists(terashuf_executable): 82 | print("terashuf executable already exists. Skipping setup.") 83 | return terashuf_dir 84 | 85 | print("Setting up terashuf...") 86 | run_command(f"git clone https://github.com/alexandres/terashuf {terashuf_dir}") 87 | run_command(f"make -C {terashuf_dir}") 88 | return terashuf_dir 89 | 90 | 91 | def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None): 92 | # Configuration 93 | repo_id = { 94 | "fineweb_edu": "HuggingFaceFW/fineweb-edu", 95 | "fineweb_edu_10bt": "HuggingFaceFW/fineweb-edu", 96 | "dclm_baseline_1.0": "mlfoundations/dclm-baseline-1.0", 97 | "dclm_baseline_1.0_10prct": "mlfoundations/dclm-baseline-1.0", 98 | }[dataset] 99 | src_dir = f"{data_dir}/{dataset}" 100 | out_dir = f"{src_dir}_shuffled" 101 | os.makedirs(out_dir, exist_ok=True) 102 | work_dir = src_dir # use the dataset directory as the working directory 103 | prefix = f"{dataset}.chunk." 
104 | orig_extension = { 105 | "fineweb_edu": ".jsonl", 106 | "fineweb_edu_10bt": ".jsonl", 107 | "dclm_baseline_1.0": ".jsonl.zst", 108 | "dclm_baseline_1.0_10prct": ".jsonl.zst", 109 | }[dataset] 110 | cat_command = { 111 | "fineweb_edu": "cat", 112 | "fineweb_edu_10bt": "cat", 113 | "dclm_baseline_1.0": "zstdcat", 114 | "dclm_baseline_1.0_10prct": "zstdcat", 115 | }[dataset] 116 | allow_patterns = { 117 | "fineweb_edu": None, 118 | "fineweb_edu_10bt": "sample/10BT/*", 119 | "dclm_baseline_1.0": "*.jsonl.zst", 120 | "dclm_baseline_1.0_10prct": "global-shard_01_of_10/*.jsonl.zst", 121 | }[dataset] 122 | suffix = ".jsonl" 123 | k_validation = 10000 # Number of lines to take from each chunk for validation 124 | 125 | # Setup terashuf 126 | terashuf_dir = setup_terashuf(work_dir) 127 | 128 | # Download dataset 129 | download_dataset(repo_id, src_dir, allow_patterns) 130 | 131 | if "fineweb" in dataset: 132 | parquet_to_jsonl(dataset, work_dir, src_dir, src_dir) 133 | 134 | # Set up environment variables 135 | os.environ["MEMORY"] = f"{memory}" 136 | os.environ["SEED"] = f"{seed}" 137 | 138 | # Run the original shuffling and splitting command 139 | terashuf_executable = os.path.join(terashuf_dir, "terashuf") 140 | run_command( 141 | f"ulimit -n 100000 && " 142 | f"find {src_dir} -type f -name '*{orig_extension}' -print0 | xargs -0 {cat_command} | {terashuf_executable} | " 143 | f"split -n r/{nchunks} -d --suffix-length 2 --additional-suffix {suffix} - {out_dir}/{prefix}" 144 | "; trap 'echo \"Caught signal 13, exiting with code 1\"; exit 1' SIGPIPE;" 145 | ) 146 | 147 | # Create validation set and remove lines from chunks 148 | validation_file = f"{out_dir}/{dataset}.val{suffix}" 149 | for i in range(nchunks): 150 | chunk_file = f"{out_dir}/{prefix}{i:02d}{suffix}" 151 | run_command(f"head -n {k_validation} {chunk_file} >> {validation_file}") 152 | run_command(f"sed -i '1,{k_validation}d' {chunk_file}") 153 | 154 | print("All tasks completed successfully!") 155 | 156 | 157 | if __name__ == "__main__": 158 | parser = argparse.ArgumentParser() 159 | parser.add_argument("dataset", type=str) 160 | parser.add_argument("memory", type=float, default=8) 161 | parser.add_argument("--data_dir", type=str, default="data") 162 | parser.add_argument("--seed", type=int, default=42) 163 | parser.add_argument("--nchunks", type=int, default=32) 164 | 165 | args = parser.parse_args() 166 | 167 | main(args.dataset, args.memory, args.data_dir, args.seed, args.nchunks) 168 | -------------------------------------------------------------------------------- /bytelatent/data/file_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import fsspec 4 | import pyarrow as pa 5 | 6 | # pyarrow needs the initialization from this import 7 | import pyarrow.dataset # pyright: ignore 8 | import typer 9 | from pyarrow.lib import ArrowInvalid 10 | from rich.progress import track 11 | 12 | 13 | def is_valid_arrow_file(path: str): 14 | try: 15 | dataset = pa.dataset.dataset(path, format="arrow") 16 | return True 17 | except ArrowInvalid: 18 | return False 19 | 20 | 21 | app = typer.Typer() 22 | 23 | S3_PREFIX = "s3://" 24 | 25 | 26 | def get_fs( 27 | path: str, s3_profile: str | None = None, use_listings_cache: None | bool = None 28 | ) -> fsspec.AbstractFileSystem: 29 | if path.startswith("s3://"): 30 | config_kwargs = {"retries": {"max_attempts": 10, "mode": "standard"}} 31 | if s3_profile is None: 32 | return fsspec.filesystem( 33 | "s3", 
use_listings_cache=use_listings_cache, config_kwargs=config_kwargs 34 | ) 35 | else: 36 | return fsspec.filesystem( 37 | "s3", 38 | profile=s3_profile, 39 | use_listings_cache=use_listings_cache, 40 | config_kwargs=config_kwargs, 41 | ) 42 | else: 43 | return fsspec.filesystem("file") 44 | 45 | 46 | @app.command() 47 | def print_local_to_delete( 48 | blob_dir: str, local_dirs: list[str], s3_profile: str = "blt" 49 | ): 50 | for s in local_dirs: 51 | assert s.endswith("/"), "Dirs must end with /" 52 | assert blob_dir.endswith("/"), "Dirs must end with /" 53 | blob_fs = fsspec.filesystem("s3", profile=s3_profile) 54 | blob_files = blob_fs.find(blob_dir) 55 | for f in track(blob_files): 56 | size = blob_fs.info(f)["Size"] 57 | if not f.lower().endswith(".complete"): 58 | assert size != 0, f"Size was invalidly zero for {f}" 59 | 60 | blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files} 61 | local_fs = fsspec.filesystem("file") 62 | 63 | files_to_delete = [] 64 | for local_dir in local_dirs: 65 | local_files = local_fs.find(local_dir) 66 | for f in local_files: 67 | relative_path = f[len(local_dir) :] 68 | if relative_path in blob_relative_paths and not os.path.islink(f): 69 | files_to_delete.append(f) 70 | print(len(files_to_delete)) 71 | with open("/tmp/files_to_delete.txt", "w") as f: 72 | for file in files_to_delete: 73 | f.write(f"{file}\n") 74 | 75 | 76 | @app.command() 77 | def compare_local_to_blob( 78 | source_dirs: list[str], 79 | dst_dir: str, 80 | s3_profile: str = "blt", 81 | print_sizes: bool = False, 82 | ): 83 | for s in source_dirs: 84 | assert s.endswith("/"), "Dirs must end with /" 85 | assert dst_dir.endswith("/"), "Dirs must end with /" 86 | assert len(source_dirs) != 0 87 | assert dst_dir.startswith("s3://") 88 | local_fs = fsspec.filesystem("file") 89 | dst_fs = fsspec.filesystem("s3", profile=s3_profile) 90 | source_to_files = {} 91 | source_file_to_size = {} 92 | all_local_files = set() 93 | for s in source_dirs: 94 | skipped = [] 95 | if s not in source_to_files: 96 | source_to_files[s] = [] 97 | for f in local_fs.find(s): 98 | if os.path.islink(f): 99 | continue 100 | if f.endswith(".COMPLETE") or f.endswith(".complete"): 101 | is_complete_file = True 102 | assert os.path.getsize(f) == 0, ".COMPLETE files should be empty" 103 | else: 104 | is_complete_file = False 105 | 106 | if not is_complete_file and os.path.getsize(f) == 0: 107 | skipped.append(f) 108 | continue 109 | if f.endswith(".arrow"): 110 | if not is_valid_arrow_file(f): 111 | skipped.append(f) 112 | continue 113 | 114 | file_without_prefix = f[len(s) :] 115 | if file_without_prefix not in source_file_to_size: 116 | source_file_to_size[file_without_prefix] = os.path.getsize(f) 117 | else: 118 | source_file_to_size[file_without_prefix] = max( 119 | source_file_to_size[file_without_prefix], os.path.getsize(f) 120 | ) 121 | 122 | source_to_files[s].append(f) 123 | all_local_files.add(file_without_prefix) 124 | print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) 125 | 126 | dst_files = dst_fs.find(dst_dir) 127 | print(dst_dir, len(dst_files)) 128 | 129 | dst_file_to_size = {} 130 | dst_file_set = set() 131 | for f in dst_files: 132 | dst_file_without_prefix = f[len(dst_dir) - len(S3_PREFIX) :] 133 | dst_file_set.add(dst_file_without_prefix) 134 | dst_file_to_size[dst_file_without_prefix] = dst_fs.size(f) 135 | 136 | diff = all_local_files.symmetric_difference(dst_file_set) 137 | print("Local files", len(all_local_files)) 138 | print("DST Files", 
len(dst_file_set)) 139 | print("Symmetric difference", len(diff)) 140 | dst_only_files = dst_file_set - all_local_files 141 | print("DST only", len(dst_only_files), list(dst_only_files)[:10]) 142 | 143 | all_files = dst_file_set | all_local_files 144 | print("Check that files match") 145 | size_success = True 146 | for f in sorted(all_files): 147 | if f in source_file_to_size and f in dst_file_to_size: 148 | if source_file_to_size[f] != dst_file_to_size[f]: 149 | size_success = False 150 | print( 151 | f"Mismatch file size for {f}, Local: {source_file_to_size[f]} Blob: {dst_file_to_size[f]}" 152 | ) 153 | else: 154 | if print_sizes: 155 | print(f"Matching file size: {dst_file_to_size[f]} for {f}") 156 | elif f not in source_file_to_size: 157 | size_success = False 158 | print(f"Missing file in source: {f}") 159 | elif f not in dst_file_to_size: 160 | size_success = False 161 | print(f"missing file in dst: {f}") 162 | else: 163 | raise ValueError("Unexpected to be missing file in src and dst") 164 | 165 | if size_success: 166 | print("All files pass size check") 167 | else: 168 | raise ValueError("At least one file failed size comparison check") 169 | 170 | 171 | if __name__ == "__main__": 172 | app() 173 | -------------------------------------------------------------------------------- /bytelatent/model/latent_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | import logging 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import torch 6 | import torch.nn 7 | import torch.nn as nn 8 | from torch.nn import functional as F 9 | from torch.nn.attention.flex_attention import BlockMask 10 | from xformers.ops import AttentionBias 11 | 12 | from bytelatent.base_transformer import ( 13 | BaseTransformer, 14 | BaseTransformerArgs, 15 | flex_attention_comp, 16 | repeat_kv, 17 | ) 18 | from bytelatent.model.utils import create_causal_mask 19 | 20 | logger = logging.getLogger() 21 | try: 22 | from apex.normalization.fused_layer_norm import FusedRMSNorm 23 | 24 | RMSNorm = FusedRMSNorm 25 | except (ImportError, ModuleNotFoundError): 26 | logging.debug("Apex not found. Using nn.RMSNorm") 27 | RMSNorm = nn.RMSNorm 28 | 29 | 30 | class CrossAttention(nn.Module): 31 | """ 32 | CrossAttention block to attend to the encoder states from the decoder. 33 | Rope is not supported. 
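    Shape sketch: x is (batch, seq_len, dim) and provides the queries, kv is
    (batch, kv_len, dim) and provides the keys/values; the returned tensor has
    the same shape as x and already includes the residual connection.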
34 | """ 35 | 36 | def __init__( 37 | self, 38 | dim: int, 39 | head_dim: int, 40 | n_heads: int, 41 | n_kv_heads: int, 42 | norm_eps: float, 43 | ): 44 | super().__init__() 45 | 46 | self.dim = dim 47 | self.head_dim = head_dim 48 | 49 | self.n_heads = n_heads 50 | self.n_kv_heads = n_kv_heads 51 | self.heads_per_group = self.n_heads // self.n_kv_heads 52 | 53 | self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps) 54 | self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps) 55 | 56 | self.wq = nn.Linear( 57 | dim, 58 | n_heads * head_dim, 59 | bias=False, 60 | ) 61 | self.wk = nn.Linear( 62 | dim, 63 | n_kv_heads * head_dim, 64 | bias=False, 65 | ) 66 | self.wv = nn.Linear( 67 | dim, 68 | n_kv_heads * head_dim, 69 | bias=False, 70 | ) 71 | 72 | self.wo = nn.Linear( 73 | n_heads * head_dim, 74 | dim, 75 | bias=False, 76 | ) 77 | 78 | def forward( 79 | self, 80 | x: torch.Tensor, 81 | kv: torch.Tensor, 82 | mask: Optional[Union[BlockMask, AttentionBias, str]] = None, 83 | ) -> torch.Tensor: 84 | # B S D 85 | bsz, seq_len, _ = x.shape 86 | _, slen_kv, _ = kv.shape 87 | x_norm = self.cross_attn_norm_q(x) 88 | kv = self.cross_attn_norm_kv(kv) 89 | 90 | xq = self.wq(x_norm) 91 | xk = self.wk(kv) 92 | xv = self.wv(kv) 93 | 94 | output_shape = xq.shape 95 | # B S D -> B S H D 96 | xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) 97 | xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) 98 | xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) 99 | 100 | xk = repeat_kv(xk, self.heads_per_group, dim=2) 101 | xv = repeat_kv(xv, self.heads_per_group, dim=2) 102 | 103 | assert mask is None or isinstance(mask, BlockMask) 104 | xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) 105 | output = flex_attention_comp(xq, xk, xv, block_mask=mask) 106 | output = output.transpose(1, 2).contiguous() # B H S D -> B S H D 107 | 108 | output = self.wo(output.reshape(output_shape)) 109 | 110 | return x + output 111 | 112 | def init_weights(self, base_std: float, factor: float = 1.0): 113 | std = base_std or (self.dim ** (-0.5)) / factor 114 | 115 | nn.init.trunc_normal_( 116 | self.wq.weight, 117 | mean=0.0, 118 | std=std, 119 | a=-3 * std, 120 | b=3 * std, 121 | ) 122 | 123 | nn.init.trunc_normal_( 124 | self.wk.weight, 125 | mean=0.0, 126 | std=std, 127 | a=-3 * std, 128 | b=3 * std, 129 | ) 130 | 131 | nn.init.trunc_normal_( 132 | self.wv.weight, 133 | mean=0.0, 134 | std=std, 135 | a=-3 * std, 136 | b=3 * std, 137 | ) 138 | 139 | nn.init.trunc_normal_( 140 | self.wo.weight, 141 | mean=0.0, 142 | std=std, 143 | a=-3 * std, 144 | b=3 * std, 145 | ) 146 | self.cross_attn_norm_q.reset_parameters() 147 | self.cross_attn_norm_kv.reset_parameters() 148 | 149 | 150 | class GlobalTransformer(BaseTransformer): 151 | def __init__(self, args: BaseTransformerArgs): 152 | super().__init__(args) 153 | self.dropout = args.dropout 154 | self.eos_id = args.eos_id 155 | self.dim_token_emb = args.dim_token_emb 156 | 157 | self.token_embedding_projection = None 158 | if args.dim_token_emb is not None and args.dim_token_emb != self.dim: 159 | self.token_embedding_projection = nn.Linear( 160 | args.dim_token_emb, 161 | args.dim, 162 | bias=False, 163 | ) 164 | 165 | def forward( 166 | self, 167 | tokens: torch.Tensor, 168 | tok_idx: Optional[torch.Tensor] = None, 169 | embeds: Optional[torch.Tensor] = None, 170 | mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, 171 | cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, 172 | ): 173 | """ 174 | Similar to 
BaseTransformer.forward, but with an additional embeds argument 175 | and projection to the token space. 176 | """ 177 | bs, seqlen = tokens.shape 178 | 179 | h = embeds 180 | 181 | mask = ( 182 | mask 183 | if mask is not None 184 | else create_causal_mask( 185 | seqlen, 186 | self.attn_impl, 187 | self.attn_bias_type, 188 | tokens=tokens, 189 | eos_id=self.eos_id, 190 | ) 191 | ) 192 | 193 | if self.token_embedding_projection is not None and h.shape[-1] != self.dim: 194 | h = self.token_embedding_projection(h) 195 | 196 | h = F.dropout(h, p=self.dropout, training=self.training) 197 | 198 | h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) 199 | return h, cache 200 | 201 | def init_weights(self): 202 | super().init_weights() 203 | std = self.dim_token_emb ** (-0.5) 204 | if self.token_embedding_projection is not None: 205 | nn.init.trunc_normal_( 206 | self.token_embedding_projection.weight, 207 | mean=0.0, 208 | std=std, 209 | a=-3 * std, 210 | b=3 * std, 211 | ) 212 | -------------------------------------------------------------------------------- /bytelatent/model/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | import logging 3 | import os 4 | 5 | import torch 6 | from torch.nn.attention.flex_attention import create_block_mask 7 | from xformers.ops import fmha 8 | 9 | logger = logging.getLogger() 10 | 11 | 12 | def patch_reduce(h, max_num_patches, reduction, patch_ids): 13 | """ 14 | Reduce variable length patches to single embedding per patch 15 | Note: this works with variable number of patches for different sequences in the batch 16 | It handles variable length patches by assuming that patch_lengths will be 0 for any 17 | extra patches on the *right*. Since there can be a variable number of patches 18 | this function also return the number of patches for each sequence in the batch. 19 | Any embeddings on the right that are not allocated to a patch 20 | (i.e. if the sum(patch_lengths[i]) < seq_len for any i) 21 | will be sent to a dummy patch, which is trimmed before returning. 22 | """ 23 | bs, seq_len, emb_dim = h.shape 24 | 25 | patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) 26 | 27 | reduced_embs = torch.zeros( 28 | (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device 29 | ) 30 | reduced_embs = reduced_embs.scatter_reduce( 31 | src=h, 32 | dim=1, 33 | index=patch_ids, 34 | reduce=reduction, 35 | include_self=False, 36 | ) 37 | reduced_embs = reduced_embs[:, :max_num_patches, :] 38 | 39 | return reduced_embs 40 | 41 | 42 | def concat_downsample(h, patch_lengths, patch_size): 43 | # The assumption in this function is that seq_len = patch_size * num_patches. 44 | bs, seq_len, emb_dim = h.shape 45 | patch_end_ids = torch.cumsum(patch_lengths, dim=1) 46 | patch_ids = patch_end_ids.unsqueeze(-1) - torch.arange(patch_size, 0, -1).to( 47 | patch_end_ids.device 48 | ) 49 | # Is clamp ok here? 
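    # (clamp(min=0) pins positions that fall before the sequence start, i.e. the
    # right-aligned window of the first patches, to index 0, so they gather the
    # first embedding repeatedly instead of indexing out of bounds.)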
50 | patch_ids = patch_ids.clamp(min=0).unsqueeze(-1).expand(-1, -1, -1, h.shape[-1]) 51 | patch_ids = patch_ids.view(bs, -1, emb_dim) 52 | # after gather h.shape = [batch_size, seq_len, dim] 53 | h = torch.gather(h, 1, patch_ids) 54 | h = h.reshape(bs, patch_lengths.shape[1], patch_size * h.size(-1)) 55 | return h 56 | 57 | 58 | def pooling_downsample(h, max_num_patches, pooling_mode, patch_ids): 59 | cat = [] 60 | if "avg" in pooling_mode or "mean" in pooling_mode: 61 | cat.append(patch_reduce(h, max_num_patches, "mean", patch_ids)) 62 | if "min" in pooling_mode: 63 | cat.append(patch_reduce(h, max_num_patches, "amin", patch_ids)) 64 | if "max" in pooling_mode: 65 | cat.append(patch_reduce(h, max_num_patches, "amax", patch_ids)) 66 | assert len(cat) > 0 67 | h = torch.cat(cat, dim=-1) 68 | return h 69 | 70 | 71 | def downsample( 72 | h, 73 | num_patches, 74 | patch_lengths=None, 75 | patch_ids=None, 76 | downsampling_by_pooling=None, 77 | patch_size=4, 78 | ): 79 | """ 80 | Downsampling: 81 | a. concatenating embeddings in the patch 82 | Note: with dynamic patching, patch the last patch_size tokens. 83 | b. pooling embeddings in the patch 84 | """ 85 | # input: h.shape = [batch_size, seq_len, dim] 86 | # input: pool h.shape = [batch_size, seq_len / patch_size, dim] 87 | # if we don't use the cros_attn, we pool so that we convert bytes rep to patch rep 88 | if downsampling_by_pooling is not None and len(downsampling_by_pooling) > 0: 89 | # By pooling 90 | max_num_patches = num_patches 91 | assert patch_ids is not None 92 | h = pooling_downsample(h, max_num_patches, downsampling_by_pooling, patch_ids) 93 | else: 94 | # TODO: remove this condition 95 | # By concatenating (fixed lengths patching) 96 | assert patch_lengths is not None 97 | h = concat_downsample(h, patch_lengths, patch_size) 98 | return h 99 | 100 | 101 | def causal_mask(b, h, q_idx, kv_idx): 102 | return q_idx >= kv_idx 103 | 104 | 105 | def tokens_to_seqlen(batch: torch.Tensor, eos_id: int): 106 | """ 107 | 0 0 0 1 0 0 0 1 0 0 0 108 | 0 1 0 0 0 1 0 0 0 0 0 109 | -> 4 4 3 2 4 5 110 | """ 111 | mask = batch == eos_id 112 | mask[:, -1] = True # virtual eos at the end of each row 113 | 114 | # 0 0 0 1 0 0 0 1 0 0 X 115 | # 0 1 0 0 0 1 0 0 0 0 X 116 | row, col = torch.where(mask) 117 | 118 | # row = 0, 0, 0, 1, 1, 1 119 | # col = 3, 7, 10, 1, 5, 10 120 | seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1] 121 | # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5) 122 | return [int(col[0].item() + 1)] + seqlens.tolist() 123 | 124 | 125 | def create_causal_mask( 126 | seqlen, 127 | attn_impl: str, 128 | attn_bias_type: str | None, 129 | *, 130 | eos_id: int | None = None, 131 | tokens: torch.Tensor | None = None, 132 | sliding_window: int | None = None, 133 | ): 134 | if attn_impl == "xformers": 135 | if attn_bias_type is None: 136 | return fmha.attn_bias.LowerTriangularMask() 137 | elif attn_bias_type == "causal": 138 | assert sliding_window is None 139 | return fmha.attn_bias.LowerTriangularMask() 140 | elif attn_bias_type == "block_causal": 141 | assert sliding_window is None 142 | assert eos_id is not None 143 | assert tokens is not None 144 | return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( 145 | q_seqlen=tokens_to_seqlen(tokens, eos_id) 146 | ) 147 | elif attn_bias_type == "local_block_causal": 148 | assert sliding_window is not None 149 | assert eos_id is not None 150 | assert tokens is not None 151 | return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( 152 | 
q_seqlen=tokens_to_seqlen(tokens, eos_id) 153 | ).make_local_attention(sliding_window) 154 | else: 155 | return fmha.attn_bias.LocalAttentionFromBottomRightMask( 156 | window_left=sliding_window - 1, window_right=0 157 | ) 158 | elif attn_impl == "sdpa": 159 | BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0)) 160 | 161 | if attn_bias_type == "causal": 162 | return "causal" 163 | 164 | if BLT_SUPPRESS_ATTN_ERROR == 1: 165 | return "causal" 166 | else: 167 | raise ValueError( 168 | "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1" 169 | ) 170 | elif attn_impl == "flex_attention": 171 | return create_block_mask(causal_mask, None, None, seqlen, seqlen) 172 | elif attn_impl == "fmha": 173 | return None 174 | else: 175 | raise NotImplementedError( 176 | f"Attention {attn_impl} with {sliding_window} sliding window not implemented" 177 | ) 178 | -------------------------------------------------------------------------------- /bytelatent/preprocess/preprocess_entropies.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | import time 3 | 4 | import fsspec 5 | import jsonlines 6 | import numpy as np 7 | import pyarrow as pa 8 | import torch 9 | import typer 10 | from rich.progress import Progress, TextColumn 11 | 12 | from bytelatent.data.file_util import get_fs 13 | from bytelatent.data.patcher import calculate_entropies 14 | from bytelatent.entropy_model import load_entropy_model 15 | from bytelatent.tokenizers.build_tokenizer import TokenizerArgs 16 | 17 | 18 | def get_id_key(doc: dict) -> int: 19 | """ 20 | We need a reliable way to ensure that samples from jsonl 21 | and arrow are the same, but there is no unique id field, 22 | so derive the best possible 23 | """ 24 | if "sample_id" in doc: 25 | return "sample_id" 26 | elif "title" in doc: 27 | return "title" 28 | elif "qid" in doc: 29 | return "qid" 30 | elif "paper_id" in doc: 31 | return "paper_id" 32 | elif "path" in doc: 33 | return "path" 34 | elif "url" in doc: 35 | return "url" 36 | elif "id" in doc: 37 | return "id" 38 | else: 39 | raise ValueError(f"Could not find a id key from: {doc.keys()}") 40 | 41 | 42 | def get_id_from_doc(doc: dict) -> int: 43 | """ 44 | We need a reliable way to ensure that samples from jsonl 45 | and arrow are the same, but there is no unique id field, 46 | so derive the best possible 47 | """ 48 | return str(doc[get_id_key(doc)]) 49 | 50 | 51 | def get_text(doc: dict): 52 | if "text" in doc: 53 | text = doc["text"] 54 | elif "content" in doc: 55 | text = doc["content"] 56 | else: 57 | raise ValueError(f"Could not find a text key from: {doc.keys()}") 58 | return text 59 | 60 | 61 | def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str): 62 | with fs.open(path) as f: 63 | reader = jsonlines.Reader(f) 64 | yield from reader 65 | 66 | 67 | def main( 68 | input_file: str, 69 | output_file: str, 70 | patching_device: str = "cuda", 71 | log_step: int = 10_000, 72 | entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint", 73 | entropy_model_state_dict_path: str = "public_data/entropy_model.pth", 74 | bpe_tokenizer_path: str = "public_data/tokenizer.model", 75 | dry_run: bool = False, 76 | s3_profile: str | None = None, 77 | ): 78 | print(f"Preprocessing entropies, input: {input_file}, 
output: {output_file}") 79 | print("Loading entropy model", entropy_model_checkpoint_dir) 80 | input_fs = get_fs(input_file, s3_profile=s3_profile) 81 | input_doc_iterator = jsonl_file_iterator(input_fs, input_file) 82 | 83 | if dry_run: 84 | return 85 | entropy_model, _ = load_entropy_model( 86 | entropy_model_checkpoint_dir, 87 | entropy_model_state_dict_path, 88 | device=patching_device, 89 | ) 90 | 91 | print("Creating patcher") 92 | patching_batch_size = 32 93 | print("Creating tokenizer") 94 | tokenizer_args = TokenizerArgs( 95 | name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path} 96 | ) 97 | tokenizer = tokenizer_args.build() 98 | step = 0 99 | print("starting") 100 | start_time = time.time() 101 | patch_time = 0 102 | entropy_field = pa.field("entropies", pa.list_(pa.float16()), nullable=False) 103 | sample_id_field = pa.field("sample_id", pa.string(), nullable=False) 104 | text_field = pa.field("text", pa.string(), nullable=False) 105 | schema = pa.schema([sample_id_field, text_field, entropy_field]) 106 | arrow_batch_size = 1_000 107 | 108 | output_fs = get_fs(output_file, s3_profile=s3_profile) 109 | 110 | try: 111 | with output_fs.open(output_file, "wb") as sink: 112 | with pa.ipc.new_file(sink, schema) as writer: 113 | id_buffer = [] 114 | entropies_buffer = [] 115 | text_buffer = [] 116 | with Progress( 117 | *Progress.get_default_columns(), 118 | TextColumn("Completed: {task.completed}"), 119 | ) as progress: 120 | task = progress.add_task( 121 | "[green]Calculating entropies...", total=None 122 | ) 123 | for doc in input_doc_iterator: 124 | sample_id = get_id_from_doc(doc) 125 | text = get_text(doc) 126 | tokens = torch.tensor(tokenizer.encode(text)) 127 | patch_start = time.time() 128 | scores, _ = calculate_entropies( 129 | tokens, 130 | entropy_model, 131 | patching_batch_size, 132 | patching_device, 133 | ) 134 | entropies_buffer.append( 135 | np.array(scores.tolist(), dtype=np.float16) 136 | ) 137 | id_buffer.append(sample_id) 138 | text_buffer.append(text) 139 | if len(entropies_buffer) == arrow_batch_size: 140 | batch = pa.record_batch( 141 | { 142 | "entropies": entropies_buffer, 143 | "sample_id": id_buffer, 144 | "text": text_buffer, 145 | }, 146 | schema, 147 | ) 148 | writer.write(batch) 149 | entropies_buffer = [] 150 | id_buffer = [] 151 | text_buffer = [] 152 | patch_time += time.time() - patch_start 153 | step += 1 154 | if step % log_step == 0: 155 | print("Completed steps:", step) 156 | progress.update(task, advance=1) 157 | if len(entropies_buffer) > 0: 158 | # Write last things 159 | batch = pa.record_batch( 160 | { 161 | "entropies": entropies_buffer, 162 | "sample_id": id_buffer, 163 | "text": text_buffer, 164 | }, 165 | schema, 166 | ) 167 | writer.write(batch) 168 | entropies_buffer = [] 169 | id_buffer = [] 170 | text_buffer = [] 171 | output_fs.touch(f"{output_file}.complete") 172 | except: 173 | if output_fs.exists(output_file): 174 | output_fs.rm(output_file) 175 | raise 176 | elapsed = time.time() - start_time 177 | print("steps", step) 178 | print("done in:", elapsed) 179 | 180 | 181 | if __name__ == "__main__": 182 | typer.run(main) 183 | -------------------------------------------------------------------------------- /bytelatent/data/iterators/sequence_iterator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | from logging import getLogger 3 | from typing import Any 4 | 5 | import numpy as np 6 | from pydantic import BaseModel, ConfigDict 7 | 8 | from bytelatent.data.data_types import BltSequence 9 | from bytelatent.data.iterators.abstract_iterator import ( 10 | PydanticIteratorState, 11 | StatefulIterator, 12 | ) 13 | from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator 14 | from bytelatent.data.iterators.limit_iterator import LimitIterator 15 | from bytelatent.data.iterators.looping_iterator import LoopingIterator 16 | from bytelatent.data.iterators.preprocess_iterator import ( 17 | PreprocessIterator, 18 | PreprocessIteratorState, 19 | ) 20 | 21 | logger = getLogger() 22 | 23 | 24 | class SequencePackingArgs(BaseModel): 25 | model_config = ConfigDict(extra="forbid") 26 | output_seq_len: int 27 | buffer_size: int 28 | 29 | 30 | class SequenceIteratorState(PydanticIteratorState): 31 | model_config = ConfigDict(extra="forbid") 32 | sequence_packing_args: SequencePackingArgs 33 | preprocess_iterator_state: PreprocessIteratorState 34 | # If None, rng is disabled. 35 | rng_state: dict[str, Any] | None 36 | 37 | def build(self): 38 | preprocess_iterator = self.preprocess_iterator_state.build() 39 | return SequenceIterator( 40 | preprocess_iterator, 41 | sequence_packing_args=self.sequence_packing_args, 42 | rng_state=self.rng_state, 43 | ) 44 | 45 | 46 | def get_datafile( 47 | iterator: PreprocessIterator | ArrowFileIterator | LoopingIterator | LimitIterator, 48 | ): 49 | if isinstance(iterator, ArrowFileIterator): 50 | return f"file={iterator.file_path} n_shards={len(iterator.dataset_files) if iterator.dataset_files is not None else None}" 51 | elif isinstance(iterator, PreprocessIterator): 52 | return get_datafile(iterator.arrow_iterator) 53 | elif isinstance(iterator, LoopingIterator): 54 | return get_datafile(iterator.file_iterator) 55 | elif isinstance(iterator, LimitIterator): 56 | return get_datafile(iterator.base_iterator) 57 | else: 58 | raise NotImplementedError() 59 | 60 | 61 | class SequenceIterator(StatefulIterator): 62 | def __init__( 63 | self, 64 | preprocess_iterator: PreprocessIterator, 65 | *, 66 | rng_state: dict[str, Any] | None, 67 | sequence_packing_args: SequencePackingArgs, 68 | ): 69 | self.preprocess_iterator = preprocess_iterator 70 | self.sequence_packing_args = sequence_packing_args 71 | self.output_seq_len = sequence_packing_args.output_seq_len 72 | self.buffer_size = sequence_packing_args.buffer_size 73 | if rng_state is None: 74 | self.rng = None 75 | else: 76 | self.rng = np.random.default_rng() 77 | self.rng.bit_generator.state = rng_state 78 | 79 | def get_state(self): 80 | # TODO: need to also perist the current shuffle buffer 81 | return SequenceIteratorState( 82 | sequence_packing_args=self.sequence_packing_args, 83 | preprocess_iterator_state=self.preprocess_iterator.get_state(), 84 | rng_state=None if self.rng is None else self.rng.bit_generator.state, 85 | ) 86 | 87 | def create_iter(self): 88 | example_iter = self.preprocess_iterator.create_iter() 89 | n_buffer_patches = self.buffer_size * self.output_seq_len 90 | 91 | patch_lengths: list[int] = [] 92 | tokens: list[int] = [] 93 | mask: list[bool] = [] 94 | first = True 95 | logger.info( 96 | "Starting first buffer for: %s", 97 | get_datafile(self.preprocess_iterator), 98 | ) 99 | for example in example_iter: 100 | assert example.tokens is not None 101 | assert example.mask is not None 102 | if self.preprocess_iterator.add_patches: 103 | assert example.patch_lengths is not None 104 | 
assert len(example.tokens) == sum(example.patch_lengths) 105 | else: 106 | assert example.patch_lengths is None 107 | assert len(example.tokens) != 0 108 | assert len(example.mask) != 0 109 | assert len(example.tokens) == len(example.mask) 110 | 111 | tokens.extend(example.tokens) 112 | mask.extend(example.mask) 113 | if self.preprocess_iterator.add_patches: 114 | patch_lengths.extend(example.patch_lengths) 115 | else: 116 | # This lets the rest of the code work as expected and just yield byte seqs 117 | patch_lengths.extend([1] * len(example.tokens)) 118 | 119 | while len(patch_lengths) >= n_buffer_patches: 120 | if first: 121 | first = False 122 | logger.info( 123 | "First buffer complete for: %s", 124 | get_datafile(self.preprocess_iterator), 125 | ) 126 | 127 | x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape( 128 | self.buffer_size, self.output_seq_len 129 | ) 130 | seq_tokens = [] 131 | seq_mask = [] 132 | start_id = 0 133 | # We fix the number of patches and therefore global steps per batch 134 | # so we have a variable number of tokens we need to account for 135 | for num_tokens in x_patches.sum(axis=-1): 136 | seq_tokens.append(tokens[start_id : start_id + num_tokens]) 137 | seq_mask.append(mask[start_id : start_id + num_tokens]) 138 | start_id += num_tokens 139 | 140 | assert start_id == x_patches.sum() 141 | 142 | # Remove what we just added from the buffer 143 | patch_lengths = patch_lengths[n_buffer_patches:] 144 | tokens = tokens[x_patches.sum() :] 145 | mask = mask[x_patches.sum() :] 146 | 147 | seq_patch_lengths: list[list[int]] = x_patches.tolist() 148 | assert len(seq_patch_lengths) == self.buffer_size 149 | if self.rng is None: 150 | permutations = list(range(len(seq_patch_lengths))) 151 | else: 152 | permutations = self.rng.permutation(len(seq_patch_lengths)) 153 | 154 | for idx in permutations: 155 | assert len(seq_patch_lengths[idx]) == self.output_seq_len 156 | assert ( 157 | sum(seq_patch_lengths[idx]) 158 | == len(seq_tokens[idx]) 159 | == len(seq_mask[idx]) 160 | ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}" 161 | assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}" 162 | if self.preprocess_iterator.add_patches: 163 | yield BltSequence( 164 | tokens=seq_tokens[idx], 165 | mask=seq_mask[idx], 166 | patch_lengths=seq_patch_lengths[idx], 167 | ) 168 | else: 169 | yield BltSequence( 170 | tokens=seq_tokens[idx], 171 | mask=seq_mask[idx], 172 | patch_lengths=None, 173 | ) 174 | -------------------------------------------------------------------------------- /bytelatent/hf.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | from pathlib import Path 5 | from typing import Dict, Optional, Union 6 | 7 | import torch 8 | import typer 9 | from huggingface_hub import hf_hub_download 10 | from huggingface_hub.hub_mixin import ModelHubMixin 11 | 12 | from bytelatent.args import TrainArgs 13 | from bytelatent.data.patcher import PatcherArgs, to_device 14 | from bytelatent.distributed import DistributedArgs, setup_torch_distributed 15 | from bytelatent.entropy_model import load_entropy_model 16 | from bytelatent.generate import load_consolidated_model_and_tokenizer 17 | from bytelatent.generate_blt import generate_nocache 18 | from bytelatent.model.blt import ByteLatentTransformer 19 | from bytelatent.tokenizers.blt_tokenizer import BltTokenizer 20 | from bytelatent.tokenizers.build_tokenizer import TokenizerArgs 21 | from 
bytelatent.transformer import LMTransformer 22 | 23 | app = typer.Typer() 24 | 25 | 26 | class BltTokenizerAndPatcher(ModelHubMixin): 27 | def __init__( 28 | self, 29 | *, 30 | patcher_args: PatcherArgs, 31 | tokenizer_args: TokenizerArgs, 32 | distributed_args: DistributedArgs, 33 | ): 34 | self.patcher_args = patcher_args 35 | self.tokenizer_args = tokenizer_args 36 | self.distributed_args = distributed_args 37 | 38 | def push_to_hub(self, *args, **kwargs): 39 | raise ValueError( 40 | "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct." 41 | ) 42 | 43 | def save_pretrained(self, *args, **kwargs): 44 | raise ValueError( 45 | "Tokenizer and Patcher are saved by BLT, this class is just for loading" 46 | ) 47 | 48 | def _save_pretrained(self, *args, **kwargs): 49 | raise ValueError( 50 | "Tokenizer and Patcher are saved by BLT, this class is just for loading" 51 | ) 52 | 53 | @classmethod 54 | def _from_pretrained( 55 | cls, 56 | *, 57 | model_id: str, 58 | revision: Optional[str], 59 | cache_dir: Optional[Union[str, Path]], 60 | force_download: bool, 61 | proxies: Optional[Dict], 62 | resume_download: Optional[bool], 63 | local_files_only: bool, 64 | token: Optional[Union[str, bool]], 65 | **model_kwargs, 66 | ): 67 | if os.path.isdir(model_id): 68 | train_args_file = os.path.join(model_id, "train_args.json") 69 | else: 70 | train_args_file = hf_hub_download( 71 | repo_id=model_id, 72 | filename="train_args.json", 73 | revision=revision, 74 | cache_dir=cache_dir, 75 | force_download=force_download, 76 | proxies=proxies, 77 | resume_download=resume_download, 78 | local_files_only=local_files_only, 79 | token=token, 80 | ) 81 | 82 | with open(train_args_file) as f: 83 | train_args = TrainArgs(**json.load(f)) 84 | return cls( 85 | patcher_args=train_args.data.patcher_args, 86 | tokenizer_args=train_args.data.tokenizer_args, 87 | distributed_args=train_args.distributed, 88 | ) 89 | 90 | 91 | @app.command() 92 | def convert_to_transformers(blt_weights_dir: str, output_dir: str): 93 | if not os.path.exists(output_dir): 94 | os.makedirs(output_dir, exist_ok=True) 95 | model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(blt_weights_dir) 96 | blt_dir = os.path.join(output_dir, "blt") 97 | entropy_dir = os.path.join(output_dir, "entropy") 98 | model.save_pretrained(blt_dir, config={"args": train_cfg.model.model_dump()}) 99 | shutil.copyfile( 100 | os.path.join(blt_weights_dir, "params.json"), 101 | os.path.join(blt_dir, "train_args.json"), 102 | ) 103 | blt_readme_file = os.path.join(blt_dir, "README.md") 104 | if os.path.exists(blt_readme_file): 105 | os.remove(blt_readme_file) 106 | 107 | patcher_args = train_cfg.data.patcher_args.model_copy(deep=True) 108 | patcher_args.realtime_patching = False 109 | print("Loading entropy model and patcher") 110 | patcher_args.entropy_model_checkpoint_dir = os.path.join( 111 | blt_weights_dir, "entropy_model" 112 | ) 113 | state_path = os.path.join( 114 | patcher_args.entropy_model_checkpoint_dir, "consolidated.pth" 115 | ) 116 | entropy_model, entropy_model_args = load_entropy_model( 117 | patcher_args.entropy_model_checkpoint_dir, state_path 118 | ) 119 | entropy_model.save_pretrained( 120 | entropy_dir, config={"args": entropy_model_args.model_dump()} 121 | ) 122 | entropy_readme_file = os.path.join(entropy_dir, "README.md") 123 | if os.path.exists(entropy_readme_file): 124 | os.remove(entropy_readme_file) 125 | 126 | 127 | 
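# Example CLI usage (illustrative only; the local paths below are placeholders, and
# Typer may expose the command names with dashes instead of underscores):
#   python -m bytelatent.hf convert_to_transformers <blt_weights_dir> <output_dir>
#   python -m bytelatent.hf load_transformers hub --prompt "My test prompt"
#   python -m bytelatent.hf load_transformers local --entropy-dir <output_dir>/entropy --blt-dir <output_dir>/blt --prompt "My test prompt"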
@app.command() 128 | def load_transformers( 129 | source: str, 130 | entropy_repo: str = "facebook/blt-entropy", 131 | blt_repo: str = "facebook/blt-1b", 132 | entropy_dir: str | None = None, 133 | blt_dir: str | None = None, 134 | prompt: str | None = None, 135 | ): 136 | if source == "local": 137 | assert entropy_dir is not None 138 | assert blt_dir is not None 139 | entropy_model = LMTransformer.from_pretrained( 140 | entropy_dir, local_files_only=True 141 | ) 142 | blt_model = ByteLatentTransformer.from_pretrained( 143 | blt_dir, local_files_only=True 144 | ) 145 | tok_and_patcher = BltTokenizerAndPatcher.from_pretrained( 146 | blt_dir, local_files_only=True 147 | ) 148 | tokenizer = tok_and_patcher.tokenizer_args.build() 149 | patcher = tok_and_patcher.patcher_args.build() 150 | print("Loaded all local") 151 | print(entropy_model) 152 | print(blt_model) 153 | print(tok_and_patcher) 154 | elif source == "hub": 155 | entropy_model = LMTransformer.from_pretrained(entropy_repo) 156 | blt_model = ByteLatentTransformer.from_pretrained(blt_repo) 157 | tok_and_patcher = BltTokenizerAndPatcher.from_pretrained(blt_repo) 158 | tokenizer = tok_and_patcher.tokenizer_args.build() 159 | patcher = tok_and_patcher.patcher_args.build() 160 | print("Loaded all remote") 161 | print(entropy_model) 162 | print(blt_model) 163 | print(tok_and_patcher) 164 | else: 165 | raise ValueError(f"Unknown source: {source}") 166 | 167 | if prompt is not None: 168 | assert isinstance(tokenizer, BltTokenizer) 169 | # Move args to correct GPU 170 | param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[ 171 | tok_and_patcher.distributed_args.model_dtype 172 | ] 173 | blt_model = blt_model.cuda().eval() 174 | for param in blt_model.parameters(): 175 | param.data = param.data.to(dtype=param_dtype) 176 | 177 | # Enable realtime patching 178 | patcher.realtime_patching = True 179 | patcher.entropy_model, _ = to_device( 180 | entropy_model, tok_and_patcher.patcher_args.patching_device 181 | ) 182 | 183 | # Setup distributed 184 | distributed_args = DistributedArgs() 185 | distributed_args.configure_world() 186 | if not torch.distributed.is_initialized(): 187 | setup_torch_distributed(distributed_args) 188 | prompts = [prompt] 189 | outputs = generate_nocache( 190 | prompts, model=blt_model, tokenizer=tokenizer, patcher=patcher 191 | ) 192 | text_outputs = [tokenizer.decode(t) for t in outputs] 193 | for p, t in zip(prompts, text_outputs): 194 | print(f'Prompt: "{p}"\nCompletion: "{t}"') 195 | print() 196 | 197 | 198 | if __name__ == "__main__": 199 | app() 200 | -------------------------------------------------------------------------------- /bytelatent/stool.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import json 4 | import os 5 | import shutil 6 | import subprocess 7 | from typing import Any, Dict 8 | 9 | import jinja2 10 | from omegaconf import OmegaConf 11 | from pydantic import BaseModel 12 | 13 | from bytelatent.config_parser import parse_args_to_pydantic_model 14 | 15 | 16 | class StoolArgs(BaseModel): 17 | name: str 18 | dump_dir: str 19 | # model_config is a reserved name by pydantic, so use this instead 20 | model_conf: Any = None 21 | launcher: str = "sbatch" # Can be sbatch or bash if already in salloc 22 | python_command: str = "python" 23 | use_conda: bool = True 24 | script: str = "apps.main.train" # The script to run. 
25 | copy_code: bool = True # Whether to copy code to the dump dir
26 | dirs_exists_ok: bool = (
27 | False # Whether to copy new code and config and run even if the dump dir already exists
28 | )
29 | override: bool = (
30 | False # Whether to delete dump dir and restart, requires confirmation
31 | )
32 | force_override: bool = False # Like override, but does not require interactive confirmation
33 | nodes: int = -1 # The number of nodes to run the job on.
34 | ngpu: int = 8 # The number of GPUs required per node.
35 | ncpu: int = 16 # The number of CPUs allocated per GPU.
36 | mem: str = "" # The amount of memory to allocate.
37 | anaconda: str = "default" # The path to the anaconda environment.
38 | constraint: str = "" # The constraint on the nodes.
39 | exclude: str = "" # The nodes to exclude.
40 | time: int = -1 # The time limit of the job (in minutes).
41 | account: str = ""
42 | qos: str = ""
43 | partition: str = "learn"
44 | stdout: bool = False
45 | dry_run: bool = False
46 |
47 |
48 | def copy_dir(input_dir: str, output_dir: str) -> None:
49 | print(f"Copying : {input_dir}\n" f"to : {output_dir} ...")
50 | assert os.path.isdir(input_dir), f"{input_dir} is not a directory"
51 | assert os.path.isdir(output_dir), f"{output_dir} is not a directory"
52 | rsync_cmd = (
53 | f"rsync -rmt --copy-links "
54 | f"--exclude .venv "
55 | f"--include '**/' "
56 | f"--include '*.py' "
57 | f"--exclude='*' "
58 | f"{input_dir}/ {output_dir}"
59 | )
60 | print(f"Copying command: {rsync_cmd}")
61 | subprocess.call([rsync_cmd], shell=True)
62 | print("Copy done.")
63 |
64 |
65 | def retrieve_max_time_per_partition() -> Dict[str, int]:
66 | # retrieve partition max times (a bit slow)
67 |
68 | sinfo = json.loads(subprocess.check_output("sinfo --json", shell=True))["sinfo"]
69 | max_times: Dict[str, int] = {}
70 |
71 | for info in sinfo:
72 | if info["partition"]["maximums"]["time"]["infinite"]:
73 | max_times[info["partition"]["name"]] = 14 * 24 * 60 # 14 days
74 | else:
75 | max_times[info["partition"]["name"]] = info["partition"]["maximums"][
76 | "time"
77 | ][
78 | "number"
79 | ] # in minutes
80 |
81 | return max_times
82 |
83 |
84 | def validate_args(args) -> None:
85 | # Set maximum time limit if not specified
86 | if args.time == -1:
87 | max_times = retrieve_max_time_per_partition()
88 | args.time = max_times.get(
89 | args.partition, 3 * 24 * 60
90 | ) # Default to 3 days if not found
91 | print(
92 | f"No time limit specified, using max time for partition: {args.time} minutes"
93 | )
94 |
95 | if args.constraint:
96 | args.constraint = f"#SBATCH --constraint={args.constraint}"
97 |
98 | if args.account:
99 | args.account = f"#SBATCH --account={args.account}"
100 |
101 | if args.qos:
102 | args.qos = f"#SBATCH --qos={args.qos}"
103 |
104 | if getattr(args, "exclude", ""):
105 | args.exclude = f"#SBATCH --exclude={args.exclude}"
106 |
107 | if hasattr(args, "anaconda") and args.anaconda:
108 | if args.anaconda == "default":
109 | args.anaconda = (
110 | subprocess.check_output("which python", shell=True)
111 | .decode("ascii")
112 | .strip()
113 | )
114 | else:
115 | args.anaconda = f"{args.anaconda}/bin/python"
116 | assert os.path.isfile(args.anaconda)
117 |
118 | args.mem = args.mem or "0"
119 |
120 | assert args.partition
121 | assert args.ngpu > 0
122 | assert args.ncpu > 0
123 | assert args.nodes > 0
124 | assert args.time > 0
125 | assert args.partition
126 |
127 |
128 | def launch_job(args: StoolArgs):
129 | # Set up default args and validate them depending on the cluster or partition requested
130 | validate_args(args)
131
| job_name = args.name or args.model_conf["name"] 132 | dump_dir = os.path.join(args.dump_dir, job_name) or args.model_conf["dump_dir"] 133 | print("Creating directories...") 134 | os.makedirs( 135 | dump_dir, exist_ok=args.dirs_exists_ok or args.override or args.force_override 136 | ) 137 | if args.override or args.force_override: 138 | if args.force_override: 139 | shutil.rmtree(dump_dir) 140 | print(f"Directory '{dump_dir}' has been deleted.") 141 | else: 142 | confirm = input( 143 | f"Are you sure you want to delete the directory '{dump_dir}'? This action cannot be undone. (yes/no): " 144 | ) 145 | if confirm.lower() == "yes": 146 | shutil.rmtree(dump_dir) 147 | print(f"Directory '{dump_dir}' has been deleted.") 148 | else: 149 | print("Operation cancelled.") 150 | return 151 | if args.copy_code: 152 | os.makedirs(f"{dump_dir}/code", exist_ok=args.dirs_exists_ok) 153 | print("Copying code ...") 154 | copy_dir(os.getcwd(), f"{dump_dir}/code") 155 | 156 | print("Saving config file ...") 157 | shutil.copy(args.model_conf, f"{dump_dir}/base_config.yaml") 158 | 159 | conda_exe = os.environ.get("CONDA_EXE", "conda") 160 | conda_env_path = os.path.dirname(os.path.dirname(args.anaconda)) 161 | log_output = ( 162 | "-o $DUMP_DIR/logs/%j/%j_%t.out -e $DUMP_DIR/logs/%j/%j_%t.err" 163 | if not args.stdout 164 | else "" 165 | ) 166 | env = jinja2.Environment( 167 | loader=jinja2.PackageLoader("bytelatent"), 168 | autoescape=jinja2.select_autoescape(), 169 | ) 170 | template = env.get_template("stool_template.sh.jinja") 171 | sbatch_jinja = template.render( 172 | name=job_name, 173 | script=args.script, 174 | dump_dir=dump_dir, 175 | nodes=args.nodes, 176 | tasks=args.nodes * args.ngpu, 177 | nodes_per_run=args.nodes, 178 | ngpus=args.ngpu, 179 | ncpu=args.ncpu, 180 | mem=args.mem, 181 | qos=args.qos, 182 | account=args.account, 183 | constraint=args.constraint, 184 | exclude=args.exclude, 185 | time=args.time, 186 | partition=args.partition, 187 | python_command=args.python_command, 188 | conda_exe=conda_exe, 189 | conda_env_path=conda_env_path, 190 | use_conda=args.use_conda, 191 | log_output=log_output, 192 | go_to_code_dir=f"cd {dump_dir}/code/" if args.copy_code else "", 193 | ) 194 | 195 | print("Writing sbatch command ...") 196 | with open(f"{dump_dir}/submit.slurm", "w") as f: 197 | f.write(sbatch_jinja) 198 | 199 | if args.dry_run: 200 | print("Dry run mode enabled. 
Not submitting job.") 201 | else: 202 | print("Submitting job ...") 203 | os.system(f"{args.launcher} {dump_dir}/submit.slurm") 204 | 205 | print("Done.") 206 | 207 | 208 | if __name__ == "__main__": 209 | """ 210 | The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments 211 | """ 212 | args = parse_args_to_pydantic_model(StoolArgs, instantiate_default_cls=False) 213 | launch_job(args) 214 | -------------------------------------------------------------------------------- /bytelatent/generate_blt.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import torch 5 | 6 | from bytelatent.args import EvalArgs 7 | from bytelatent.config_parser import parse_args_to_pydantic_model 8 | from bytelatent.data.file_util import get_fs 9 | from bytelatent.data.patcher import Patcher 10 | from bytelatent.distributed import ( 11 | DistributedArgs, 12 | dist_max, 13 | dist_min, 14 | dist_sum, 15 | get_device_mesh, 16 | setup_torch_distributed, 17 | ) 18 | from bytelatent.generate import load_consolidated_model_and_tokenizer 19 | from bytelatent.model.blt import ByteLatentTransformer 20 | from bytelatent.tokenizers.blt_tokenizer import BltTokenizer 21 | 22 | logger = logging.getLogger() 23 | 24 | 25 | def get_max_length(input_tokens: list[list[int]] | None) -> int: 26 | # reduce max length prompt over all processes to have an equal number of call on each process with fsdp 27 | if input_tokens is None: 28 | max_length = 0 29 | else: 30 | max_length = max([len(t) for t in input_tokens]) 31 | if torch.distributed.is_initialized(): 32 | max_length = int(dist_max(max_length)) 33 | return max_length 34 | 35 | 36 | def get_min_length(input_tokens: list[list[int]] | None) -> int: 37 | # reduce min length prompt over all processes to have an equal number of call on each process with fsdp 38 | if input_tokens is None: 39 | # TODO: Double check this change from int(1e9) is correct 40 | min_length = 0 41 | else: 42 | min_length = min([len(t) for t in input_tokens]) 43 | if torch.distributed.is_initialized(): 44 | min_length = int(dist_min(min_length)) 45 | return min_length 46 | 47 | 48 | def get_generation_range( 49 | prompt_tokens: list[list[int]] | None, max_gen_len: int 50 | ) -> tuple[int, int]: 51 | batch_min_prompt_length = get_min_length(prompt_tokens) 52 | batch_max_prompt_length = get_max_length(prompt_tokens) 53 | return batch_min_prompt_length, batch_max_prompt_length + max_gen_len 54 | 55 | 56 | def sample_top_k(probs, k): 57 | topk_value, _ = torch.topk(probs, k) # batch_sz x topk 58 | min_value_top_k = topk_value[:, [-1]] 59 | probs[probs < min_value_top_k] = 0.0 60 | probs.div_(probs.sum(dim=-1, keepdim=True)) 61 | next_token = torch.multinomial(probs, num_samples=1) 62 | return next_token 63 | 64 | 65 | def sample_top_p(probs, p): 66 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 67 | probs_sum = torch.cumsum(probs_sort, dim=-1) 68 | mask = probs_sum - probs_sort > p 69 | probs_sort[mask] = 0.0 70 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 71 | next_token = torch.multinomial(probs_sort, num_samples=1) 72 | next_token = torch.gather(probs_idx, -1, next_token) 73 | return next_token 74 | 75 | 76 | @torch.inference_mode() 77 | def generate_nocache( 78 | prompts: list[str] | None, 79 | *, 80 | model: ByteLatentTransformer, 81 | tokenizer: BltTokenizer, 82 | patcher: Patcher, 83 | max_prompt_len: int = 256, 84 | max_gen_len: 
int = 256,
85 | use_sampling: bool = False,
86 | temp: float = 1.0,
87 | top_k: int = 0,
88 | top_p: float = 0.0,
89 | remove_prompts: bool = True,
90 | ) -> list[list[int]]:
91 | assert (
92 | patcher.realtime_patching
93 | ), "generate_nocache requires patcher.realtime_patching=True"
94 | model.eval()
95 | if prompts is None:
96 | prompt_tokens = None
97 | n_truncated_prompts = 0
98 | total_truncated_prompts = 0
99 | else:
100 | prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts]
101 | n_truncated_prompts = sum([max_prompt_len < len(t) for t in prompt_tokens])
102 | total_truncated_prompts = dist_sum(n_truncated_prompts)
103 |
104 | # Truncation
105 | prompt_tokens = [
106 | t if len(t) < max_prompt_len else t[len(t) - max_prompt_len :]
107 | for t in prompt_tokens
108 | ]
109 |
110 | if total_truncated_prompts > 0:
111 | logger.info(
112 | f"There are {total_truncated_prompts} prompts that are truncated on the left, "
113 | f"length greater than max_prompt_len = {max_prompt_len}, "
114 | f"maximum prompt length = {get_max_length(prompt_tokens)} across all gpus."
115 | )
116 |
117 | start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len)
118 | if prompt_tokens is None:
119 | # get_generation_range handles prompt_tokens=None; computing it first ensures end_pos is defined here
120 | prompt_tokens = [[tokenizer.bos_id] for _ in range(end_pos)]
121 | batch_size = len(prompt_tokens)
122 | tokens = torch.full((batch_size, end_pos), tokenizer.pad_id).cuda().long()
123 |
124 | # Copy inputs to tensor for generated tokens
125 | for i, row_tokens in enumerate(prompt_tokens):
126 | tokens[i, : len(row_tokens)] = torch.tensor(row_tokens).long()
127 | input_text_mask = tokens != tokenizer.pad_id
128 |
129 | for i, curr_pos in enumerate(range(start_pos, end_pos)):
130 | current_tokens = tokens[:, :curr_pos]
131 | patch_lengths, _ = patcher.patch(current_tokens, include_next_token=True)
132 | logits = model(current_tokens, patch_lengths=patch_lengths)[:, -1]
133 |
134 | if use_sampling:
135 | probs = torch.softmax(logits / temp, dim=-1)
136 | if top_p > 0.0:
137 | next_token = sample_top_p(probs, top_p)
138 | elif top_k > 0:
139 | next_token = sample_top_k(probs, top_k)
140 | else:
141 | next_token = torch.multinomial(probs, num_samples=1)
142 | else:
143 | next_token = torch.argmax(logits, dim=-1)
144 |
145 | next_token = torch.where(
146 | input_text_mask[:, curr_pos], tokens[:, curr_pos], next_token
147 | )
148 | tokens[:, curr_pos] = next_token
149 |
150 | if remove_prompts:
151 | generated_tokens = [
152 | t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len].tolist()
153 | for i, t in enumerate(tokens)
154 | ]
155 | else:
156 | generated_tokens = [
157 | t[: len(prompt_tokens[i]) + max_gen_len].tolist()
158 | for i, t in enumerate(tokens)
159 | ]
160 | return generated_tokens
161 |
162 |
163 | def launch_generate(eval_args: EvalArgs):
164 | assert eval_args.dump_dir is not None
165 | assert eval_args.ckpt_dir is not None
166 | distributed_args = DistributedArgs()
167 | distributed_args.configure_world()
168 | if not torch.distributed.is_initialized():
169 | setup_torch_distributed(distributed_args)
170 |
171 | world_mesh = get_device_mesh(distributed_args)
172 | dp_mesh = world_mesh["dp_replicate"]
173 | assert distributed_args.dp_shard == 1
174 | world_size = dp_mesh.size()
175 | world_rank = dp_mesh.get_local_rank()
176 |
177 | fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
178 | if (
179 | fs.exists(eval_args.ckpt_dir)
180 | and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
181 | and
len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0 182 | ): 183 | consolidate_path = eval_args.ckpt_dir 184 | else: 185 | raise ValueError("Did not find a consolidated checkpoint in the ckpt_dir") 186 | 187 | model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer( 188 | consolidate_path, 189 | ) 190 | patcher_args = train_cfg.data.patcher_args.model_copy(deep=True) 191 | patcher_args.realtime_patching = True 192 | patcher_args.entropy_model_checkpoint_dir = eval_args.entropy_ckpt_dir 193 | patcher = patcher_args.build() 194 | outputs = generate_nocache( 195 | eval_args.prompts, model=model, tokenizer=tokenizer, patcher=patcher 196 | ) 197 | text_outputs = [tokenizer.decode(t) for t in outputs] 198 | for p, t in zip(eval_args.prompts, text_outputs): 199 | print(f'Prompt: "{p}" Completion: "{t}"') 200 | print() 201 | 202 | 203 | def main(): 204 | eval_args = parse_args_to_pydantic_model(EvalArgs) 205 | launch_generate(eval_args) 206 | 207 | 208 | if __name__ == "__main__": 209 | main() 210 | -------------------------------------------------------------------------------- /bytelatent/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import json 5 | import logging 6 | from collections import namedtuple 7 | from datetime import datetime, timezone 8 | from pathlib import Path 9 | from typing import Any, Union 10 | 11 | import fsspec 12 | import torch 13 | import torch.nn as nn 14 | import wandb 15 | from pydantic import BaseModel, ConfigDict 16 | 17 | from bytelatent.distributed import get_is_master 18 | 19 | logger = logging.getLogger() 20 | 21 | 22 | class WandbArgs(BaseModel): 23 | model_config = ConfigDict(extra="forbid") 24 | job_type: str | None = None 25 | dir: str | None = None 26 | project: str | None = None 27 | entity: str | None = None 28 | tags: list | None = None 29 | group: str | None = None 30 | name: str | None = None 31 | notes: str | None = None 32 | config_exclude_keys: list[str] | None = None 33 | config_include_keys: list[str] | None = None 34 | anonymous: str | None = None 35 | mode: str | None = None 36 | allow_val_change: bool | None = None 37 | resume: Union[bool, str] | None = None 38 | force: bool | None = None 39 | tensorboard: bool | None = None 40 | sync_tensorboard: bool | None = None 41 | monitor_gym: bool | None = None 42 | save_code: bool | None = None 43 | id: str | None = None 44 | fork_from: str | None = None 45 | resume_from: str | None = None 46 | 47 | 48 | class LoggingArgs(BaseModel): 49 | model_config = ConfigDict(extra="forbid") 50 | freq: int = 10 # Log every freq optimizer steps 51 | acc_freq: int | None = None # Log every acc_freq gradient accumulation steps 52 | wandb: WandbArgs | None = None 53 | 54 | 55 | class MetricLogger: 56 | def __init__( 57 | self, 58 | outdir: str, 59 | # args: TrainArgs 60 | args: Any | None = None, 61 | fs: fsspec.AbstractFileSystem | None = None, 62 | ): 63 | self.outdir = outdir 64 | self.jsonl_writer = None 65 | self.fs = fs 66 | self.args = args 67 | 68 | def open(self): 69 | if self.jsonl_writer is None: 70 | if self.fs is None: 71 | self.jsonl_writer = open(self.outdir, "a") 72 | else: 73 | self.jsonl_writer = self.fs.open(self.outdir, "a") 74 | if ( 75 | self.args is not None 76 | and self.args.logging.wandb is not None 77 | and get_is_master() 78 | ): 79 | run = 
wandb.init( 80 | config=self.args.model_dump(), 81 | **self.args.logging.wandb.model_dump(), 82 | ) 83 | 84 | def log(self, metrics: dict[str, Any]): 85 | if ( 86 | self.args is not None 87 | and self.args.logging.wandb is not None 88 | and (wandb.run is not None) 89 | ): 90 | wandb.log(metrics, step=metrics["global_step"]) 91 | 92 | metrics.update({"created_at": datetime.now(timezone.utc).isoformat()}) 93 | print(json.dumps(metrics), file=self.jsonl_writer, flush=True) 94 | 95 | def close(self): 96 | if self.jsonl_writer is not None: 97 | self.jsonl_writer.close() 98 | self.jsonl_writer = None 99 | 100 | def __enter__(self): 101 | self.open() 102 | return self 103 | 104 | def __exit__(self, exc_type, exc_value, traceback): 105 | self.close() 106 | 107 | def __del__(self): 108 | self.close() 109 | 110 | 111 | GPUMemStats = namedtuple( 112 | "GPUMemStats", 113 | [ 114 | "max_active_gib", 115 | "max_active_pct", 116 | "max_reserved_gib", 117 | "max_reserved_pct", 118 | "num_alloc_retries", 119 | "num_ooms", 120 | "power_draw", 121 | ], 122 | ) 123 | 124 | 125 | class GPUMemoryMonitor: 126 | """ 127 | Class to monitor GPU memory usage 128 | """ 129 | 130 | def __init__(self, device: str = "cuda:0"): 131 | self.device = torch.device(device) # device object 132 | self.device_name = torch.cuda.get_device_name(self.device) 133 | self.device_index = torch.cuda.current_device() 134 | self.device_capacity = torch.cuda.get_device_properties( 135 | self.device 136 | ).total_memory 137 | self.device_capacity_gib = self._to_gib(self.device_capacity) 138 | 139 | # reset stats, clear cache 140 | torch.cuda.reset_peak_memory_stats() 141 | torch.cuda.empty_cache() 142 | 143 | def _to_gib(self, memory_in_bytes): 144 | # NOTE: GiB (gibibyte) is 1024, vs GB is 1000 145 | _gib_in_bytes = 1024 * 1024 * 1024 146 | memory_in_gib = memory_in_bytes / _gib_in_bytes 147 | return memory_in_gib 148 | 149 | def _to_pct(self, memory): 150 | return 100 * memory / self.device_capacity 151 | 152 | def get_peak_stats(self): 153 | cuda_info = torch.cuda.memory_stats(self.device) 154 | 155 | max_active = cuda_info["active_bytes.all.peak"] 156 | max_active_gib = self._to_gib(max_active) 157 | max_active_pct = self._to_pct(max_active) 158 | 159 | max_reserved = cuda_info["reserved_bytes.all.peak"] 160 | max_reserved_gib = self._to_gib(max_reserved) 161 | max_reserved_pct = self._to_pct(max_reserved) 162 | 163 | num_retries = cuda_info["num_alloc_retries"] 164 | num_ooms = cuda_info["num_ooms"] 165 | power_draw = torch.cuda.power_draw() 166 | 167 | if num_retries > 0: 168 | logger.warning(f"{num_retries} CUDA memory allocation retries.") 169 | if num_ooms > 0: 170 | logger.warning(f"{num_ooms} CUDA OOM errors thrown.") 171 | 172 | return GPUMemStats( 173 | max_active_gib, 174 | max_active_pct, 175 | max_reserved_gib, 176 | max_reserved_pct, 177 | num_retries, 178 | num_ooms, 179 | power_draw, 180 | ) 181 | 182 | def reset_peak_stats(self): 183 | torch.cuda.reset_peak_memory_stats() 184 | torch.cuda.reset_accumulated_memory_stats() 185 | 186 | def __str__(self): 187 | mem_stats = self.get_peak_stats() 188 | display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gib} GiB capacity, " 189 | display_str += ( 190 | f"{mem_stats.max_reserved_gib} GiB peak, {mem_stats.max_reserved_pct}% peak" 191 | ) 192 | return f"{display_str}" 193 | 194 | 195 | def upload_train_to_wandb( 196 | ckpt_dir, project="lingua", entity="codegen-team", train=True, eval=True 197 | ): 198 | import json 199 | from pathlib import Path 200 | 
201 | import wandb
202 | from omegaconf import OmegaConf
203 |
204 | cfg = OmegaConf.load(Path(ckpt_dir) / "config.yaml")
205 | cfg = OmegaConf.to_container(cfg)
206 |
207 | if train:
208 | wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity)
209 |
210 | with open(Path(ckpt_dir) / "metrics.jsonl") as f:
211 | for line in f:
212 | m = json.loads(line)
213 | wandb.log(m, step=m["global_step"])
214 |
215 | wandb.finish()
216 |
217 | if eval:
218 | wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity)
219 |
220 | with open(Path(ckpt_dir) / "metrics.eval.jsonl") as f:
221 | for line in f:
222 | m = json.loads(line)
223 | wandb.log(
224 | {
225 | f"evals/{name.replace('/','.')}": value
226 | for name, value in m.items()
227 | if "/" in name
228 | },
229 | step=m["global_step"],
230 | )
231 |
232 | wandb.finish()
233 |
234 |
235 | def get_num_params(model: nn.Module) -> int:
236 | """
237 | Get the total number of model parameters,
238 | counting trainable and non-trainable parameters alike.
239 | """
240 | numel = {n: p.numel() for n, p in model.named_parameters()}
241 | return sum(numel.values())
242 | --------------------------------------------------------------------------------