├── .github └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── LICENSE ├── README.md ├── RunEvals.md ├── __init__.py ├── _colbert.py ├── benchmark.py ├── download_artifacts_from_wandb.py ├── efficiency ├── README.md ├── multiprocess_bench.py ├── run_multiprocess_inference_bench_base.sh └── run_multiprocess_inference_bench_large.sh ├── environment.yaml ├── eval.py ├── examples ├── README.md ├── evaluate_pylate.py ├── evaluate_st.py ├── finetune_modernbert_on_glue.ipynb ├── train_pylate.py ├── train_st.py └── train_st_gooaq.py ├── generate_eval_config.py ├── glue.py ├── main.py ├── requirements-colbert.txt ├── requirements-cpu.txt ├── requirements-data.txt ├── requirements.txt ├── ruff.toml ├── run_evals.py ├── sequence_classification.py ├── src ├── __init__.py ├── algorithms │ └── rope_schedule.py ├── bert_layers │ ├── __init__.py │ ├── activation.py │ ├── attention.py │ ├── configuration_bert.py │ ├── embeddings.py │ ├── initialization.py │ ├── layers.py │ ├── loss.py │ ├── mlp.py │ ├── model.py │ ├── normalization.py │ ├── options.py │ ├── padding.py │ └── rotary.py ├── bert_padding.py ├── callbacks │ ├── __init__.py │ ├── dataloader_speed.py │ ├── log_grad_norm.py │ ├── packing_efficiency.py │ └── scheduled_gc.py ├── colbert_beir │ ├── __init__.py │ ├── index_and_score.py │ └── train.py ├── convert_dataset.py ├── data │ ├── README.md │ ├── __init__.py │ ├── data_utils.py │ ├── get_counts_from_hf.py │ ├── hf_to_mds.py │ ├── mds_conversion.py │ ├── relative_prop_to_instance_prop.py │ ├── sample_dataset_from_config.py │ └── source_stats.py ├── evals │ ├── README.md │ ├── __init__.py │ ├── data.py │ ├── finetuning_jobs.py │ ├── glue_jobs.py │ ├── misc_jobs.py │ └── superglue_jobs.py ├── flex_bert.py ├── hf_bert.py ├── mosaic_bert.py ├── optimizer.py ├── scheduler.py ├── sequence_packer.py ├── text_data.py └── utils.py ├── tests ├── __init__.py ├── smoketest_config_ablation_eval.yaml ├── smoketest_config_classification.yaml ├── smoketest_config_glue.yaml ├── smoketest_config_main.yaml ├── smoketest_config_sdpa_fa2.yaml ├── smoketest_config_superglue.yaml ├── test_classification.py ├── test_eval.py ├── test_glue.py ├── test_main.py ├── test_mlm_masking.py ├── test_padding.py ├── test_rotary.py ├── test_sdpa_fa2.py ├── test_sequence_packer.py ├── test_superglue.py ├── test_tiling.py └── test_utils.py ├── wandb_log_live_eval.py └── yamls ├── ablations └── example-config.yaml ├── baselines ├── bert-base-uncased-superglue.yaml ├── bert-base-uncased.yaml ├── colbert │ └── bert-base-uncased.yaml ├── deberta-v3-base.yaml └── deberta-v3-long-context.yaml ├── defaults.yaml ├── finetuning ├── glue │ └── mosaic-bert-base-uncased.yaml ├── hf-bert-base-uncased.yaml └── mosaic-bert-base-uncased.yaml ├── main ├── flex-bert-base-parallel.yaml ├── flex-bert-base.yaml ├── flex-bert-rope-base.yaml ├── flex-bert-rope-parallel-firstprenorm.yaml ├── hf-bert-base-uncased.yaml └── mosaic-bert-base-uncased.yaml ├── models ├── flex_bert.yaml ├── hf_bert.yaml └── mosaic_bert.yaml └── test ├── glue.yaml ├── main.yaml └── sequence_classification.yaml /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | **Changes** 2 | 3 | Description of the changes made by this PR and the reasoning behind them. 4 | 5 | **Discussions** 6 | 7 | If any, please include references to the relevant issues/previous PR/discord discussions around these changes. 8 | 9 | **Tests** 10 | 11 | 12 | - [ ] Is the new feature tested? 
(Not always necessary for all changes -- just adding to the checklist to keep track) 13 | - [ ] Have you ran all the tests? 14 | - [ ] Do the tests all pass? 15 | - [ ] If not, have you included an explanation of which tests this PR breaks and/or why (below this checklisT) 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # vscode 163 | .vscode 164 | 165 | # weights and biases 166 | wandb/ 167 | 168 | # OS X 169 | .DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Welcome! 2 | 3 | This is the repository where you can find ModernBERT, our experiments to bring BERT into modernity via both architecture changes and scaling. 4 | 5 | This repository notably introduces FlexBERT, our modular approach to encoder building blocks, and heavily relies on .yaml configuration files to build models. The codebase builds upon [MosaicBERT](https://github.com/mosaicml/examples/tree/main/examples/benchmarks/bert), and specifically the [unmerged fork bringing Flash Attention 2](https://github.com/Skylion007/mosaicml-examples/tree/skylion007/add-fa2-to-bert) to it, under the terms of its Apache 2.0 license. We extend our thanks to MosaicML for starting the work on modernising encoders! 7 | 8 | This README is very barebones and is still under construction. It will improve with more reproducibility and documentation in the new year, as we gear up for more encoder niceties after the pre-holidays release of ModernBERT. For now, we're mostly looking forward to seeing what people build with the [🤗 model checkpoints](https://huggingface.co/collections/answerdotai/modernbert-67627ad707a4acbf33c41deb). 8 | 9 | For more details on what this repository brings, we recommend reading our [release blog post](https://huggingface.co/blog/modernbert) for a high-level overview, and our [arXiv preprint](https://arxiv.org/abs/2412.13663) for more technical details. 10 | 11 | All code in this repository is the code we used in our experiments for both pre-training and GLUE evaluations; there's no uncommitted secret training sauce. 12 | 13 | **This is the research repository for ModernBERT, focused on pre-training and evaluations. 
If you're seeking the HuggingFace version, designed to integrate with any common pipeline, please head to the [ModernBERT Collection on HuggingFace](https://huggingface.co/collections/answerdotai/modernbert-67627ad707a4acbf33c41deb)** 14 | 15 | *ModernBERT is a collaboration between [Answer.AI](https://answer.ai), [LightOn](https://lighton.ai), and friends.* 16 | 17 | ## Setup 18 | 19 | We have fully documented the environment used to train ModernBERT, which can be installed on a GPU-equipped machine with the following commands: 20 | 21 | ```bash 22 | conda env create -f environment.yaml 23 | # if the conda environment errors out set channel priority to flexible: 24 | # conda config --set channel_priority flexible 25 | conda activate bert24 26 | # if using H100s clone and build flash attention 3 27 | # git clone https://github.com/Dao-AILab/flash-attention.git 28 | # cd flash-attention/hopper 29 | # python setup.py install 30 | # install flash attention 2 (model uses FA3+FA2 or just FA2 if FA3 isn't supported) 31 | pip install "flash_attn==2.6.3" --no-build-isolation 32 | # or download a precompiled wheel from https://github.com/Dao-AILab/flash-attention/releases/tag/v2.6.3 33 | # or limit the number of parallel compilation jobs 34 | # MAX_JOBS=8 pip install "flash_attn==2.6.3" --no-build-isolation 35 | ``` 36 | 37 | ## Training 38 | 39 | Training heavily leverages the [composer](https://github.com/mosaicml/composer) framework. All training runs are configured via YAML files, of which you can find examples in the `yamls` folder. We highly encourage you to check out one of the example yamls, such as `yamls/main/flex-bert-rope-base.yaml`, to explore the configuration options. 40 | 41 | ### Launch command example 42 | To run a training job using `yamls/main/modernbert-base.yaml` on all available GPUs, use the following command: 43 | ``` 44 | composer main.py yamls/main/modernbert-base.yaml 45 | ``` 46 | 47 | ### Data 48 | 49 | There are two dataset classes to choose between: 50 | 51 | `StreamingTextDataset` 52 | * Inherits from [StreamingDataset](https://docs.mosaicml.com/projects/streaming/en/latest/preparing_datasets/dataset_format.html) 53 | * Uses MDS, CSV/TSV, or JSONL format 54 | * Supports both text and tokenized data 55 | * Can also be used with local data 56 | * WARNING: we found memory distribution across accelerators to be uneven 57 | 58 | `NoStreamingDataset` 59 | * Requires decompressed MDS format; compressed MDS data can be decompressed using [src/data/mds_conversion.py](src/data/mds_conversion.py) with the `--decompress` flag. 60 | * Supports both text and tokenized data 61 | 62 | When data is accessed locally, we recommend using `NoStreamingDataset`, as it enabled higher training throughput in our setting. Both classes are located in [src/text_data.py](src/text_data.py), and the class to use can be set per data loader and dataset by setting `streaming: true` (StreamingTextDataset) or `streaming: false` (NoStreamingDataset). 63 | 64 | ``` 65 | train_loader: 66 | name: text 67 | dataset: 68 | streaming: false 69 | ``` 70 | 71 | To get started, you can experiment with c4 data using the [following instructions](https://github.com/mosaicml/examples/tree/main/examples/benchmarks/bert#prepare-your-data). 72 | 73 | 74 | ## Evaluations 75 | 76 | ### GLUE 77 | 78 | GLUE evaluations for a ModernBERT model trained with this repository can be run via `run_evals.py`, by providing it with a checkpoint and a training config. 
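For example, a minimal invocation using an argument file as documented in `RunEvals.md` (the YAML filename below is illustrative):

```bash
python run_evals.py --config run_evals_args.yaml
```
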
To evaluate non-ModernBERT models, you should use `glue.py` in conjunction with a slightly different training YAML, of which you can find examples in the `yamls/finetuning` folder. 79 | 80 | ### Retrieval 81 | 82 | The `examples` subfolder contains scripts for training retrieval models, both dense models based on [Sentence Transformers](https://github.com/UKPLab/sentence-transformers) and ColBERT models via the [PyLate](https://github.com/lightonai/pylate) library: 83 | - `examples/train_pylate.py`: The boilerplate code to train a ModernBERT-based ColBERT model with PyLate. 84 | - `examples/train_st.py`: The boilerplate code to train a ModernBERT-based dense retrieval model with Sentence Transformers. 85 | - `examples/evaluate_pylate.py`: The boilerplate code to evaluate a ModernBERT-based ColBERT model with PyLate. 86 | - `examples/evaluate_st.py`: The boilerplate code to evaluate a ModernBERT-based dense retrieval model with Sentence Transformers. 87 | 88 | 89 | ## Reference 90 | 91 | If you use ModernBERT in your work, be it the released models, the intermediate checkpoints (release pending) or this training repository, please cite: 92 | 93 | ```bibtex 94 | @misc{modernbert, 95 | title={Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference}, 96 | author={Benjamin Warner and Antoine Chaffin and Benjamin Clavié and Orion Weller and Oskar Hallström and Said Taghadouini and Alexis Gallagher and Raja Biswas and Faisal Ladhak and Tom Aarsen and Nathan Cooper and Griffin Adams and Jeremy Howard and Iacopo Poli}, 97 | year={2024}, 98 | eprint={2412.13663}, 99 | archivePrefix={arXiv}, 100 | primaryClass={cs.CL}, 101 | url={https://arxiv.org/abs/2412.13663}, 102 | } 103 | ``` -------------------------------------------------------------------------------- /RunEvals.md: -------------------------------------------------------------------------------- 1 | # How to Run Evaluations 2 | 3 | This document explains how to run fine-tuning evaluations for pre-trained models using the scripts `run_evals.py` and `generate_eval_config.py`. These scripts assume you have a pre-trained or finetuned Composer `FlexBert`checkpoint to evaluate. 4 | 5 | ## 1. Optionally Login to Hugging Face 6 | 7 | First, make sure you are logged into Hugging Face, or provide a Hugging Face token to the `hub_token` argument: 8 | 9 | ```bash 10 | huggingface-cli login 11 | ``` 12 | 13 | Follow the prompts to enter your authentication token. 14 | 15 | ## 2. Run Evaluations 16 | 17 | ### Option 1: Run Evaluations for All Checkpoints using `run_evals.py` 18 | 19 | You can use the `run_evals.py` script to run evaluations for all checkpoints in a directory. 20 | 21 | First, view the available arguments: 22 | ```bash 23 | python run_evals.py --help 24 | ``` 25 | 26 | To simplify the process, create a YAML configuration file (e.g., `run_evals_args.yaml`) with the required argument values. 
For example: 27 | 28 | ```yaml 29 | # Checkpoint & Config Paths 30 | checkpoints: checkpoints 31 | train_config: path/to/training_config.yaml # optional, uses default config if not provided and a wandb run isn't specified 32 | 33 | # Model Options 34 | model_size: base # default FlexBert model config to use 35 | 36 | # Hugging Face Download 37 | hub_repo: {org}/{repo} 38 | hub_token: {your_hf_token} # needed if downloading from a private/gated repo and `huggingface-cli login` wasn't used 39 | hub_files: {checkpoint_files} # optional limit to only download specific repo files or directories 40 | 41 | # Evaluation Tasks 42 | tasks: 43 | - mnli 44 | - sst2 45 | - cola 46 | - mrpc 47 | 48 | # Task Settings 49 | parallel: false 50 | seeds: 51 | - 42 52 | - 314 53 | - 1234 54 | 55 | # Weights & Biases (logging & config downloading) 56 | wandb_run: ${your_pretraining_run_name} # these two options are only needed to download a non-default pretraining config 57 | wandb_project: ${your_pretraining_wandb_project} 58 | 59 | track_run: true # set these options to track the evaluation run in W&B 60 | wandb_entity: ${your_wandb_entity} 61 | track_run_project: ${your_evaluation_wandb_project} 62 | 63 | # GPU Options (which GPUs to use) 64 | gpu_ids: 65 | - 0 66 | - 1 67 | ``` 68 | 69 | Replace the placeholders with your specific values: 70 | 71 | - `{parallel}`: Set to `true` to run evaluations on one checkpoint in parallel. Note that this can randomly error out. 72 | - `{training_config.yaml}`: Path to your optional training configuration file if not using the default config. 73 | - `{org}/{repo}`: Hugging Face Hub repository ID (e.g., `your_org/your_repo` where the Composer checkpoints are stored). 74 | - `{your_hf_token}`: Your Hugging Face authentication token. 75 | - `your_wandb_entity`: Your Weights & Biases entity (username or team name). 76 | - `your_wandb_run_name`: The name of the Weights & Biases run containing the training configuration. 77 | - `your_evaluation_wandb_project`: The name of your Weights & Biases evaluation project to log the eval runs to. 78 | 79 | To run the script, use: 80 | 81 | ```bash 82 | python run_evals.py --config run_evals_args.yaml 83 | ``` 84 | 85 | This will: 86 | 87 | - **Download checkpoints** from the specified Hugging Face repository (if `hub_repo` is provided). 88 | - **Generate evaluation configurations** for the specified tasks. 89 | - **Run evaluations** in parallel on the specified GPUs. 90 | 91 | ### Option 2: Run Evaluation for a Specific Checkpoint using `generate_eval_config.py` and `eval.py` 92 | 93 | If you want to run evaluation for a specific checkpoint, you can use `generate_eval_config.py` to generate the evaluation configuration, and then run `eval.py`. 94 | 95 | #### Step 1: Generate the Evaluation Configuration 96 | 97 | ```bash 98 | python generate_eval_config.py \ 99 | --checkpoint path/to/checkpoint \ 100 | --output-dir configs \ 101 | --model-size base \ 102 | --rope-theta 10000.0 \ 103 | --tasks mnli sst2 \ 104 | --wandb-entity ${your_wandb_entity} \ 105 | --wandb-project ${your_wandb_project} \ 106 | --wandb-run ${your_wandb_run_name} \ 107 | --track-run \ 108 | --track-run-project ${your_wandb_project} 109 | ``` 110 | 111 | Replace the placeholders accordingly: 112 | 113 | - `path/to/checkpoint`: Path to your specific checkpoint file or directory. 114 | - `configs`: Directory where the generated configuration file will be saved. 115 | - `mnli sst2`: List of tasks you want to evaluate. 
116 | - `your_wandb_entity`, `your_wandb_project`, `your_wandb_run_name`: Your Weights & Biases details. 117 | 118 | This command will generate a configuration YAML file in the `configs` directory. 119 | 120 | #### Step 2: Run the Evaluation 121 | 122 | ```bash 123 | python eval.py configs/generated_config.yaml 124 | ``` 125 | 126 | Replace `configs/generated_config.yaml` with the actual path to the generated configuration file. 127 | 128 | ## Tips & Tricks 129 | 130 | 1. **Building Evaluation Configurations for Single Tasks** 131 | 132 | If you want to build a fine-tuning evaluation configuration YAML for a single task, you can use `generate_eval_config.py` with the `--tasks` option to specify the task(s). 133 | 134 | For example: 135 | 136 | python generate_eval_config.py \ 137 | --checkpoint path/to/checkpoint \ 138 | --output-dir configs \ 139 | --tasks mnli 140 | 141 | Then, run the evaluation with `eval.py`: 142 | 143 | python eval.py configs/generated_config.yaml 144 | 145 | 2. **Monitoring GPU Usage** 146 | 147 | Install `nvitop` to monitor GPU usage more effectively: 148 | 149 | pip install nvitop 150 | nvitop 151 | 152 | This provides a more useful and user-friendly interface than `nvidia-smi`. 153 | 154 | ## Additional Notes 155 | 156 | - **Parallel Evaluations**: When running evaluations in parallel, you can specify the GPU IDs to use with the `--gpu-ids` option or in the YAML configuration file. 157 | 158 | - **Configurable Options**: Both `run_evals.py` and `generate_eval_config.py` support various options to fine-tune the evaluation process. Use `--help` with these scripts to see all available options. 159 | 160 | python run_evals.py --help 161 | python generate_eval_config.py --help 162 | 163 | - **Using Configuration Files**: You can use YAML configuration files to specify arguments for the scripts, which can simplify command-line usage. Command-line options will override options specified in the configuration file. 164 | 165 | - **Hugging Face Hub Integration**: If you have your checkpoints stored in a private repository on Hugging Face Hub, ensure you have access by logging in via `huggingface-cli login` and providing your token. 166 | 167 | - **Loading Training Configurations**: If you have a training configuration file or a Weights & Biases run containing the training configuration, you can provide it using the `--train-config` or `--wandb-run` options to ensure consistency between training and evaluation. 168 | 169 | ## Optional: Manual Checkpoint Download 170 | 171 | If you prefer to manually download checkpoints instead of using the automatic download feature in `run_evals.py`, you can use `huggingface-cli`: 172 | 173 | Replace `{org}`, `{repo}`, and `{checkpoint_folder}` with the appropriate organization, repository, and checkpoint folder names. 174 | 175 | ### Example Command: 176 | 177 | ```bash 178 | huggingface-cli download {org}/{repo} --include "{checkpoint_folder}/*" --local-dir checkpoints 179 | huggingface-cli download {org}/{repo} --include "{checkpoint_folder}" --local-dir checkpoints 180 | ``` 181 | 182 | ### Notes: 183 | - If there are multiple Composer checkpoints, use the latest one (usually starts with "ep-1"). 184 | - This manual download is optional since `run_evals.py` can automatically download checkpoints when you specify the `hub_repo`, `hub_folder`, and `hub_token` arguments. 185 | 186 | --- 187 | 188 | This README reflects the latest updates in the scripts `run_evals.py` and `generate_eval_config.py`. 
Be sure to review the scripts and their help messages for the most current information. -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MosaicML Examples authors 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import sys 6 | 7 | # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from 8 | sys.path.append(os.path.dirname(os.path.realpath(__file__))) 9 | 10 | try: 11 | import torch 12 | 13 | # yapf: disable 14 | from src.bert_layers import (BertAlibiEmbeddings, BertAlibiEncoder, BertForMaskedLM, 15 | BertForSequenceClassification, 16 | BertResidualGLU, BertAlibiLayer, 17 | BertLMPredictionHead, BertModel, 18 | BertOnlyMLMHead, BertOnlyNSPHead, BertPooler, 19 | BertPredictionHeadTransform, BertSelfOutput, 20 | BertAlibiUnpadAttention, BertAlibiUnpadSelfAttention) 21 | # yapf: enable 22 | from src.bert_padding import ( 23 | IndexFirstAxis, 24 | IndexPutFirstAxis, 25 | index_first_axis, 26 | index_put_first_axis, 27 | pad_input, 28 | unpad_input, 29 | unpad_input_only, 30 | ) 31 | from src.hf_bert import create_hf_bert_classification, create_hf_bert_mlm 32 | from src.mosaic_bert import create_mosaic_bert_classification, create_mosaic_bert_mlm 33 | except ImportError as e: 34 | try: 35 | is_cuda_available = torch.cuda.is_available() # type: ignore 36 | except Exception: 37 | is_cuda_available = False 38 | 39 | reqs_file = "requirements.txt" if is_cuda_available else "requirements-cpu.txt" 40 | raise ImportError( 41 | f"Please make sure to pip install -r {reqs_file} to get the requirements for the BERT benchmark." 42 | ) from e 43 | 44 | __all__ = [ 45 | "BertAlibiEmbeddings", 46 | "BertAlibiEncoder", 47 | "BertForMaskedLM", 48 | "BertForSequenceClassification", 49 | "BertResidualGLU", 50 | "BertAlibiLayer", 51 | "BertLMPredictionHead", 52 | "BertModel", 53 | "BertOnlyMLMHead", 54 | "BertOnlyNSPHead", 55 | "BertPooler", 56 | "BertPredictionHeadTransform", 57 | "BertSelfOutput", 58 | "BertAlibiUnpadAttention", 59 | "BertAlibiUnpadSelfAttention", 60 | "IndexFirstAxis", 61 | "IndexPutFirstAxis", 62 | "index_first_axis", 63 | "index_put_first_axis", 64 | "pad_input", 65 | "unpad_input", 66 | "unpad_input_only", 67 | "create_hf_bert_classification", 68 | "create_hf_bert_mlm", 69 | "create_mosaic_bert_classification", 70 | "create_mosaic_bert_mlm", 71 | ] 72 | -------------------------------------------------------------------------------- /_colbert.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 onwards Answer.AI, LightOn, and contributors 2 | # License: Apache-2.0 3 | 4 | import sys 5 | import shutil 6 | 7 | import huggingface_hub 8 | import omegaconf as om 9 | import ir_datasets 10 | 11 | from src.colbert_beir import build_colbert_index, colbert_score, colbert_train 12 | 13 | 14 | def _make_passage(doc): 15 | if hasattr(doc, "title"): 16 | return f"{doc.title}\n{doc.text}" 17 | else: 18 | return doc.text 19 | 20 | 21 | if __name__ == "__main__": 22 | try: 23 | yaml_path, args_list = sys.argv[1], sys.argv[2:] 24 | 25 | with open(yaml_path) as f: 26 | yaml_cfg = om.OmegaConf.load(f) 27 | 28 | cli_cfg = om.OmegaConf.from_cli(args_list) 29 | cfg = om.OmegaConf.merge(yaml_cfg, cli_cfg) 30 | 31 | assert isinstance(cfg, om.DictConfig) 32 | 33 | data_dir = f"{cfg.tmp_dir}/data" 34 | 
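        # Download the ColBERT training triples dataset from the Hugging Face Hub into the temporary data directory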
huggingface_hub.snapshot_download(repo_id=cfg.train_dataset_id, repo_type="dataset", local_dir=data_dir) 35 | 36 | if cfg.debug: 37 | import srsly 38 | 39 | triplets_path = f"{data_dir}/triples.train.colbert.jsonl" 40 | triplets = srsly.read_jsonl(triplets_path) 41 | downsampled_triplets = [triplet for i, triplet in enumerate(triplets) if i < 2000] 42 | srsly.write_jsonl(triplets_path, downsampled_triplets) 43 | 44 | model_name = cfg.model_name_or_path.split("/")[-1] if "/" in cfg.model_name_or_path else cfg.model_name_or_path 45 | model_name += "_colbert" 46 | 47 | train_params = cfg.train_params 48 | train_params["root"] = cfg.tmp_dir 49 | train_params["name"] = model_name 50 | 51 | checkpoint = colbert_train( 52 | model_name_or_path=cfg.model_name_or_path, 53 | train_params=train_params, 54 | n_gpu=cfg.n_gpu, 55 | data_path=data_dir, 56 | ) 57 | 58 | for dataset_name in cfg.eval_datasets: 59 | int2docid = {} 60 | docs = [] 61 | ds_split = "" 62 | dataset = ir_datasets.load(dataset_name) 63 | 64 | for i, doc in enumerate(dataset.docs_iter()): 65 | int2docid[i] = doc.doc_id 66 | docs.append(_make_passage(doc)) 67 | 68 | build_colbert_index( 69 | dataset_name=dataset_name, 70 | model_name_or_path=cfg.model_name_or_path, 71 | checkpoint_path=checkpoint, 72 | collection=docs, 73 | tmp_path=cfg.tmp_dir, 74 | ) 75 | score = colbert_score( 76 | model_name_or_path=cfg.model_name_or_path, 77 | dataset_name=dataset_name, 78 | dataset=dataset, 79 | int2docid=int2docid, 80 | tmp_path=cfg.tmp_dir, 81 | ) 82 | print(f"NDCG@10 for {dataset_name}: {score}") 83 | except Exception as e: 84 | print(f"Error: {e}") 85 | finally: 86 | # Clean up ColBERT artifacts 87 | shutil.rmtree("./experiments/default", ignore_errors=True) 88 | shutil.rmtree(cfg.tmp_dir, ignore_errors=True) 89 | -------------------------------------------------------------------------------- /download_artifacts_from_wandb.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 onwards Answer.AI, LightOn, and contributors 2 | # License: Apache-2.0 3 | 4 | import argparse 5 | import itertools 6 | import os 7 | 8 | import wandb 9 | 10 | 11 | def get_model_artifacts(api, entity, project): 12 | runs = api.runs(f"{entity}/{project}") 13 | return list(itertools.chain(*[[(run, a) for a in run.logged_artifacts() if a.type == "model"] for run in runs])) 14 | 15 | 16 | def get_base_folder(artifact_name): 17 | name = artifact_name.replace("checkpoint-", "") 18 | name = "-".join((name.split("-")[:-1])) 19 | return name 20 | 21 | 22 | def get_ba(artifact_name): 23 | name = artifact_name.replace("checkpoint-", "") 24 | ba = name.split("-")[-1].split(":")[0] 25 | ba = int(ba.replace("ba", "").strip()) 26 | return ba 27 | 28 | 29 | def main(api, args): 30 | print("Fetching all model artifacts...") 31 | artifacts = get_model_artifacts(api, args.wandb_entity, args.wandb_project) 32 | print(f"Found {len(artifacts)} model artifacts.") 33 | 34 | for run, artifact in artifacts: 35 | print(f"Run: {run.name}") 36 | print(f"Artifact: {artifact.name}") 37 | 38 | base_dir = os.path.join(args.local_download_dir, get_base_folder(artifact.name)) 39 | 40 | os.makedirs(base_dir, exist_ok=True) 41 | out_dir = os.path.join(base_dir, artifact.name) 42 | if os.path.exists(out_dir): 43 | print(f"Artifact already exists locally: {out_dir}") 44 | artifact.delete(delete_aliases=True) 45 | continue 46 | 47 | os.makedirs(out_dir, exist_ok=True) 48 | artifact.download(root=out_dir) 49 | 50 | meta_fn = os.path.join(out_dir, 
"metadata.json") 51 | meta = { 52 | "artifact_id": artifact.id, 53 | "artifact_name": artifact.name, 54 | "artifact_created_at": artifact.created_at, 55 | "artifact_updated_at": artifact.updated_at, 56 | "run_id": run.id, 57 | "run_name": run.name, 58 | "project": args.wandb_project, 59 | "entity": args.wandb_entity, 60 | } 61 | 62 | with open(meta_fn, "w") as fd: 63 | fd.write(wandb.util.json_dumps_safer(meta)) 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = argparse.ArgumentParser(description="Download WandB artifacts") 68 | parser.add_argument("--wandb_entity", help="WandB entity name") 69 | parser.add_argument("--wandb_project", help="WandB project name") 70 | parser.add_argument( 71 | "--local_download_dir", 72 | help="Local directory to download artifacts", 73 | ) 74 | args = parser.parse_args() 75 | 76 | # Create download directory if it doesn't exist 77 | os.makedirs(args.local_download_dir, exist_ok=True) 78 | 79 | # Usage 80 | # crontab -e 81 | # 0 * * * * WANDB_API_KEY=<> python download_artifacts_from_wandb.py >> <> 2>&1 82 | api = wandb.Api(api_key=os.environ.get("WANDB_API_KEY")) 83 | main(api, args) 84 | -------------------------------------------------------------------------------- /efficiency/README.md: -------------------------------------------------------------------------------- 1 | This folder contains the code for batch size checks and basic efficiency tests. Every result was spot checked individually, for each model, to ensure proper batch size and runtime estimates. You may check these results on the [efficiency-spotchecks](https://github.com/AnswerDotAI/ModernBERT/tree/efficiency-spotchecks/efficiency) branch. -------------------------------------------------------------------------------- /efficiency/run_multiprocess_inference_bench_base.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 && python multiprocess_bench.py --model Alibaba-NLP/gte-base-en-v1.5 > gte_inference_times.log 2>&1 & 2 | export CUDA_VISIBLE_DEVICES=1 && python multiprocess_bench.py --model ModernBERT/bert24-base-v2-learning-rate-decay-v2-50B-3-best-and-last-avg > bert24_inference_times.log 2>&1 & 3 | export CUDA_VISIBLE_DEVICES=2 && python multiprocess_bench.py --model bert-base-uncased > bert_inference_times.log 2>&1 & 4 | export CUDA_VISIBLE_DEVICES=3 && python multiprocess_bench.py --model roberta-base > roberta_inference_times.log 2>&1 & 5 | export CUDA_VISIBLE_DEVICES=4 && python multiprocess_bench.py --model microsoft/deberta-v3-base > debertav3_inference_times.log 2>&1 & 6 | export CUDA_VISIBLE_DEVICES=5 && python multiprocess_bench.py --model nomic-ai/nomic-bert-2048 > nomicbert_inference_times.log 2>&1 & 7 | export CUDA_VISIBLE_DEVICES=6 && python multiprocess_bench.py --model Alibaba-NLP/gte-base-en-v1.5 --xformers > gte_xformers_inference_times.log 2>&1 & 8 | -------------------------------------------------------------------------------- /efficiency/run_multiprocess_inference_bench_large.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 && python multiprocess_bench.py --model Alibaba-NLP/gte-en-mlm-large > large_gte_inference_times.log 2>&1 & 2 | export CUDA_VISIBLE_DEVICES=2 && python multiprocess_bench.py --model ModernBERT/bert24-large-v2-learning-rate-decay-v3-50B-ep0-ba9000-rank0 > large_bert24_inference_times.log 2>&1 & 3 | export CUDA_VISIBLE_DEVICES=3 && python multiprocess_bench.py --model bert-large-uncased > 
large_bert_inference_times.log 2>&1 & 4 | export CUDA_VISIBLE_DEVICES=4 && python multiprocess_bench.py --model roberta-base > roberta_inference_times.log 2>&1 & 5 | export CUDA_VISIBLE_DEVICES=5 && python multiprocess_bench.py --model microsoft/deberta-v3-large > large_debertav3_inference_times.log 2>&1 & 6 | export CUDA_VISIBLE_DEVICES=6 && python multiprocess_bench.py --model nomic-ai/nomic-bert-2048 > nomicbert_inference_times_both.log 2>&1 & 7 | export CUDA_VISIBLE_DEVICES=7 && nohup python multiprocess_bench.py --model Alibaba-NLP/gte-base-en-v1.5 --xformers > gte_xformers_inference_times_both.log 2>&1 & 8 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # ModernBERT retrieval boilerplates 2 | 3 | In this folder, you can find different boilerplates to train and evaluate retrieval models using ModernBERT as the backbone, with [Sentence Transformers](https://github.com/UKPLab/sentence-transformers) for single vector retrieval (DPR) and [PyLate](https://github.com/lightonai/pylate) for multi vector retrieval (ColBERT). 4 | 5 | You can use ```train_st.py``` and ```train_pylate.py``` to train a single vector model using contrastive learning on [MS-MARCO with mined hard negatives](https://huggingface.co/datasets/sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1) and a multi vector model using knowledge distillation on [MS-MARCO with teacher weights from bge-reranker-v2-m3](https://huggingface.co/datasets/lightonai/ms-marco-en-bge) respectively. Alternatively, ```train_st_gooaq.py``` provides a training script for training a single vector model on the [GooAQ](https://huggingface.co/datasets/sentence-transformers/gooaq) question-answer dataset. 6 | 7 | You can launch training on multiple GPUs by using ```accelerate launch --num_processes num_gpu train_st.py``` 8 | 9 | You can then run ```python evaluate_st.py``` or ```python evaluate_pylate.py``` to evaluate the trained models on BEIR datasets. 
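As a sketch, the Sentence Transformers training scripts (`train_st.py` and `train_st_gooaq.py`) also accept `--model_name` and `--lr` arguments, so a multi-GPU run with an explicit learning rate could look like the following; the GPU count and learning rate here are illustrative:

```bash
accelerate launch --num_processes 8 train_st.py --model_name answerdotai/ModernBERT-base --lr 8e-5
```
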
-------------------------------------------------------------------------------- /examples/evaluate_pylate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 onwards Answer.AI, LightOn, and contributors 2 | # License: Apache-2.0 3 | 4 | from collections import defaultdict 5 | 6 | import srsly 7 | 8 | from pylate import evaluation, indexes, models, retrieve 9 | 10 | eval_datasets = ["scifact", "nfcorpus", "fiqa", "trec-covid"] 11 | model_name = "answerdotai/ModernBERT-base" 12 | model_shortname = model_name.split("/")[-1] 13 | lr = 8e-5 14 | model_results = defaultdict(dict) 15 | run_name = f"{model_shortname}-colbert-KD-{lr}" 16 | output_dir = f"output/{model_shortname}/{run_name}" 17 | model = models.ColBERT( 18 | model_name_or_path=f"{output_dir}/final", 19 | document_length=510, 20 | ) 21 | 22 | for eval_dataset in eval_datasets: 23 | index = indexes.Voyager(index_name=eval_dataset, override=True, M=200, ef_construction=500, ef_search=500) 24 | 25 | retriever = retrieve.ColBERT(index=index) 26 | 27 | documents, queries, qrels = evaluation.load_beir( 28 | dataset_name=eval_dataset, 29 | split="test", 30 | ) 31 | 32 | batch_size = 500 33 | 34 | documents_embeddings = model.encode( 35 | sentences=[document["text"] for document in documents], 36 | batch_size=batch_size, 37 | is_query=False, 38 | show_progress_bar=True, 39 | ) 40 | 41 | index.add_documents( 42 | documents_ids=[document["id"] for document in documents], 43 | documents_embeddings=documents_embeddings, 44 | ) 45 | 46 | queries_embeddings = model.encode( 47 | sentences=queries, 48 | is_query=True, 49 | show_progress_bar=True, 50 | batch_size=16, 51 | ) 52 | 53 | scores = retriever.retrieve(queries_embeddings=queries_embeddings, k=100) 54 | 55 | evaluation_scores = evaluation.evaluate( 56 | scores=scores, 57 | qrels=qrels, 58 | queries=queries, 59 | metrics=["ndcg@10"], 60 | ) 61 | print(f"{model_name} - {lr} - {eval_dataset}") 62 | print(evaluation_scores) 63 | print("-----------") 64 | model_results[eval_dataset] = evaluation_scores 65 | srsly.write_json(f"output/{model_shortname}/{model_shortname}_results.json", model_results) 66 | -------------------------------------------------------------------------------- /examples/evaluate_st.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 onwards Answer.AI, LightOn, and contributors 2 | # License: Apache-2.0 3 | 4 | import mteb 5 | from sentence_transformers import SentenceTransformer 6 | 7 | model_name = "answerdotai/ModernBERT-base" 8 | lr = 8e-5 9 | model_shortname = model_name.split("/")[-1] 10 | run_name = f"{model_shortname}-DPR-{lr}" 11 | output_dir = f"output/{model_shortname}/{run_name}" 12 | model = SentenceTransformer(f"{output_dir}/final") 13 | 14 | task_names = ["SciFact", "NFCorpus", "FiQA2018", "TRECCOVID"] 15 | tasks = mteb.get_tasks(tasks=task_names) 16 | evaluation = mteb.MTEB(tasks=tasks) 17 | results = evaluation.run(model, output_folder=f"results/{run_name}") 18 | -------------------------------------------------------------------------------- /examples/train_pylate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 onwards Answer.AI, LightOn, and contributors 2 | # License: Apache-2.0 3 | 4 | from datasets import load_dataset 5 | from pylate import losses, models, utils 6 | from sentence_transformers import ( 7 | SentenceTransformerTrainer, 8 | SentenceTransformerTrainingArguments, 9 | ) 10 | 11 | 
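# Entry point: trains a ModernBERT-based ColBERT model with PyLate via knowledge distillation on lightonai/ms-marco-en-bge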
def main(): 12 | # Load the datasets required for knowledge distillation (train, queries, documents) 13 | train = load_dataset( 14 | path="lightonai/ms-marco-en-bge", 15 | name="train", 16 | ) 17 | 18 | queries = load_dataset( 19 | path="lightonai/ms-marco-en-bge", 20 | name="queries", 21 | ) 22 | 23 | documents = load_dataset( 24 | path="lightonai/ms-marco-en-bge", 25 | name="documents", 26 | ) 27 | 28 | # Set the transformation to load the documents/queries texts using the corresponding ids on the fly 29 | train.set_transform( 30 | utils.KDProcessing(queries=queries, documents=documents).transform, 31 | ) 32 | 33 | # Define the base model, training parameters, and output directory 34 | num_train_epochs = 1 35 | lr = 8e-5 36 | batch_size = 16 37 | accum_steps = 1 38 | model_name = "answerdotai/ModernBERT-base" 39 | model_shortname = model_name.split("/")[-1] 40 | 41 | # Set the run name for logging and output directory 42 | run_name = f"{model_shortname}-colbert-KD-{lr}" 43 | output_dir = f"output/{model_shortname}/{run_name}" 44 | 45 | # Initialize the ColBERT model from the base model 46 | model = models.ColBERT(model_name_or_path=model_name) 47 | 48 | # Configure the training arguments (e.g., epochs, batch size, learning rate) 49 | args = SentenceTransformerTrainingArguments( 50 | output_dir=output_dir, 51 | num_train_epochs=num_train_epochs, 52 | per_device_train_batch_size=batch_size, 53 | fp16=False, # Set to False if you get an error that your GPU can't run on FP16 54 | bf16=True, # Set to True if you have a GPU that supports BF16 55 | run_name=run_name, 56 | logging_steps=10, 57 | learning_rate=lr, 58 | gradient_accumulation_steps=accum_steps, 59 | warmup_ratio=0.05, 60 | ) 61 | 62 | # Use the Distillation loss function for training 63 | train_loss = losses.Distillation(model=model) 64 | 65 | # Initialize the trainer 66 | trainer = SentenceTransformerTrainer( 67 | model=model, 68 | args=args, 69 | train_dataset=train, 70 | loss=train_loss, 71 | data_collator=utils.ColBERTCollator(tokenize_fn=model.tokenize), 72 | ) 73 | 74 | # Start the training process 75 | trainer.train() 76 | 77 | model.save_pretrained(f"{output_dir}/final") 78 | 79 | if __name__ == "__main__": 80 | main() -------------------------------------------------------------------------------- /examples/train_st.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 onwards Answer.AI, LightOn, and contributors 2 | # License: Apache-2.0 3 | 4 | import argparse 5 | 6 | from datasets import load_dataset 7 | from sentence_transformers import ( 8 | SentenceTransformer, 9 | SentenceTransformerTrainer, 10 | SentenceTransformerTrainingArguments, 11 | ) 12 | from sentence_transformers.evaluation import TripletEvaluator 13 | from sentence_transformers.losses import CachedMultipleNegativesRankingLoss 14 | from sentence_transformers.training_args import BatchSamplers 15 | 16 | def main(): 17 | # parse the lr & model name 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--lr", type=float, default=8e-5) 20 | parser.add_argument("--model_name", type=str, default="answerdotai/ModernBERT-base") 21 | args = parser.parse_args() 22 | lr = args.lr 23 | model_name = args.model_name 24 | model_shortname = model_name.split("/")[-1] 25 | 26 | # 1. Load a model to finetune 27 | model = SentenceTransformer(model_name) 28 | 29 | # 2. 
Load a dataset to finetune on 30 | dataset = load_dataset( 31 | "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1", 32 | "triplet-hard", 33 | split="train", 34 | ) 35 | dataset_dict = dataset.train_test_split(test_size=1_000, seed=12) 36 | train_dataset = dataset_dict["train"].select(range(1_250_000)) 37 | eval_dataset = dataset_dict["test"] 38 | 39 | # 3. Define a loss function 40 | loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=16) # Increase mini_batch_size if you have enough VRAM 41 | 42 | run_name = f"{model_shortname}-DPR-{lr}" 43 | # 4. (Optional) Specify training arguments 44 | args = SentenceTransformerTrainingArguments( 45 | # Required parameter: 46 | output_dir=f"output/{model_shortname}/{run_name}", 47 | # Optional training parameters: 48 | num_train_epochs=1, 49 | per_device_train_batch_size=512, 50 | per_device_eval_batch_size=512, 51 | warmup_ratio=0.05, 52 | fp16=False, # Set to False if GPU can't handle FP16 53 | bf16=True, # Set to True if GPU supports BF16 54 | batch_sampler=BatchSamplers.NO_DUPLICATES, # (Cached)MultipleNegativesRankingLoss benefits from no duplicates 55 | learning_rate=lr, 56 | # Optional tracking/debugging parameters: 57 | save_strategy="steps", 58 | save_steps=500, 59 | save_total_limit=2, 60 | logging_steps=500, 61 | run_name=run_name, # Used in `wandb`, `tensorboard`, `neptune`, etc. if installed 62 | ) 63 | 64 | # 5. (Optional) Create an evaluator & evaluate the base model 65 | dev_evaluator = TripletEvaluator( 66 | anchors=eval_dataset["query"], 67 | positives=eval_dataset["positive"], 68 | negatives=eval_dataset["negative"], 69 | name="msmarco-co-condenser-dev", 70 | ) 71 | dev_evaluator(model) 72 | 73 | # 6. Create a trainer & train 74 | trainer = SentenceTransformerTrainer( 75 | model=model, 76 | args=args, 77 | train_dataset=train_dataset, 78 | eval_dataset=eval_dataset, 79 | loss=loss, 80 | evaluator=dev_evaluator, 81 | ) 82 | trainer.train() 83 | 84 | # 7. (Optional) Evaluate the trained model on the evaluator after training 85 | dev_evaluator(model) 86 | 87 | # 8. Save the model 88 | model.save_pretrained(f"output/{model_shortname}/{run_name}/final") 89 | 90 | # 9. (Optional) Push it to the Hugging Face Hub 91 | model.push_to_hub(run_name, private=False) 92 | 93 | if __name__ == "__main__": 94 | main() -------------------------------------------------------------------------------- /examples/train_st_gooaq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 onwards Answer.AI, LightOn, and contributors 2 | # License: Apache-2.0 3 | 4 | import argparse 5 | 6 | from datasets import load_dataset 7 | from sentence_transformers import ( 8 | SentenceTransformer, 9 | SentenceTransformerTrainer, 10 | SentenceTransformerTrainingArguments, 11 | ) 12 | from sentence_transformers.evaluation import NanoBEIREvaluator 13 | from sentence_transformers.losses import CachedMultipleNegativesRankingLoss 14 | from sentence_transformers.training_args import BatchSamplers 15 | 16 | def main(): 17 | # parse the lr & model name 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--lr", type=float, default=8e-5) 20 | parser.add_argument("--model_name", type=str, default="answerdotai/ModernBERT-base") 21 | args = parser.parse_args() 22 | lr = args.lr 23 | model_name = args.model_name 24 | model_shortname = model_name.split("/")[-1] 25 | 26 | # 1. Load a model to finetune 27 | model = SentenceTransformer(model_name) 28 | 29 | # 2. 
Load a dataset to finetune on 30 | dataset = load_dataset("sentence-transformers/gooaq", split="train") 31 | dataset_dict = dataset.train_test_split(test_size=1_000, seed=12) 32 | train_dataset = dataset_dict["train"] 33 | eval_dataset = dataset_dict["test"] 34 | 35 | # 3. Define a loss function 36 | loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=128) # Increase mini_batch_size if you have enough VRAM 37 | 38 | run_name = f"{model_shortname}-gooaq-{lr}" 39 | # 4. (Optional) Specify training arguments 40 | args = SentenceTransformerTrainingArguments( 41 | # Required parameter: 42 | output_dir=f"output/{model_shortname}/{run_name}", 43 | # Optional training parameters: 44 | num_train_epochs=1, 45 | per_device_train_batch_size=2048, 46 | per_device_eval_batch_size=2048, 47 | learning_rate=lr, 48 | warmup_ratio=0.05, 49 | fp16=False, # Set to False if GPU can't handle FP16 50 | bf16=True, # Set to True if GPU supports BF16 51 | batch_sampler=BatchSamplers.NO_DUPLICATES, # (Cached)MultipleNegativesRankingLoss benefits from no duplicates 52 | # Optional tracking/debugging parameters: 53 | eval_strategy="steps", 54 | eval_steps=50, 55 | save_strategy="steps", 56 | save_steps=50, 57 | save_total_limit=2, 58 | logging_steps=10, 59 | run_name=run_name, # Used in `wandb`, `tensorboard`, `neptune`, etc. if installed 60 | ) 61 | 62 | # 5. (Optional) Create an evaluator & evaluate the base model 63 | dev_evaluator = NanoBEIREvaluator(dataset_names=["NQ", "MSMARCO"]) 64 | dev_evaluator(model) 65 | 66 | # 6. Create a trainer & train 67 | trainer = SentenceTransformerTrainer( 68 | model=model, 69 | args=args, 70 | train_dataset=train_dataset, 71 | eval_dataset=eval_dataset, 72 | loss=loss, 73 | evaluator=dev_evaluator, 74 | ) 75 | trainer.train() 76 | 77 | # 7. (Optional) Evaluate the trained model on the evaluator after training 78 | dev_evaluator(model) 79 | 80 | # 8. Save the model 81 | model.save_pretrained(f"output/{model_shortname}/{run_name}/final") 82 | 83 | # 9. 
(Optional) Push it to the Hugging Face Hub 84 | model.push_to_hub(run_name, private=False) 85 | 86 | if __name__ == "__main__": 87 | main() -------------------------------------------------------------------------------- /requirements-colbert.txt: -------------------------------------------------------------------------------- 1 | colbert-ai=0.2.19 2 | faiss-gpu 3 | ranx 4 | ir_datasets -------------------------------------------------------------------------------- /requirements-cpu.txt: -------------------------------------------------------------------------------- 1 | einops==0.8.0 2 | torch==2.3.0 3 | mosaicml[nlp,wandb]>=0.22.0,<0.23 4 | mosaicml-streaming==0.7.6 5 | omegaconf==2.3.0 6 | transformers==4.40.2 7 | -------------------------------------------------------------------------------- /requirements-data.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | datatrove==0.2.0 3 | huggingface_hub==0.23.1 4 | pyyaml 5 | ruff 6 | tqdm 7 | streaming -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.8.0 2 | torch==2.3.0 3 | mosaicml[nlp,wandb]>=0.22.0,<0.23 4 | mosaicml-streaming==0.7.6 5 | omegaconf==2.3.0 6 | transformers==4.40.2 7 | triton==2.3.0 8 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | line-length = 120 2 | target-version = "py311" 3 | ignore = ["E402"] -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MosaicML Examples authors 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import sys 6 | 7 | # Add src folder root to path to allow us to use relative imports regardless of what directory the script is run from 8 | sys.path.append(os.path.dirname(os.path.realpath(__file__))) 9 | 10 | # yapf: disable 11 | from bert_layers import (BertAlibiEmbeddings, BertAlibiEncoder, BertForMaskedLM, 12 | BertForSequenceClassification, BertResidualGLU, 13 | BertAlibiLayer, BertLMPredictionHead, BertModel, 14 | BertOnlyMLMHead, BertOnlyNSPHead, BertPooler, 15 | BertPredictionHeadTransform, BertSelfOutput, 16 | BertAlibiUnpadAttention, BertAlibiUnpadSelfAttention) 17 | # yapf: enable 18 | from bert_padding import ( 19 | IndexFirstAxis, 20 | IndexPutFirstAxis, 21 | index_first_axis, 22 | index_put_first_axis, 23 | pad_input, 24 | unpad_input, 25 | unpad_input_only, 26 | ) 27 | from src.bert_layers.configuration_bert import BertConfig 28 | 29 | from hf_bert import create_hf_bert_classification, create_hf_bert_mlm 30 | from mosaic_bert import create_mosaic_bert_classification, create_mosaic_bert_mlm 31 | 32 | __all__ = [ 33 | "BertAlibiEmbeddings", 34 | "BertAlibiEncoder", 35 | "BertForMaskedLM", 36 | "BertForSequenceClassification", 37 | "BertResidualGLU", 38 | "BertAlibiLayer", 39 | "BertLMPredictionHead", 40 | "BertModel", 41 | "BertOnlyMLMHead", 42 | "BertOnlyNSPHead", 43 | "BertPooler", 44 | "BertPredictionHeadTransform", 45 | "BertSelfOutput", 46 | "BertAlibiUnpadAttention", 47 | "BertAlibiUnpadSelfAttention", 48 | "BertConfig", 49 | "IndexFirstAxis", 50 | "IndexPutFirstAxis", 51 | "index_first_axis", 52 | "index_put_first_axis", 53 | "pad_input", 54 | "unpad_input", 55 | "unpad_input_only", 56 | 
"create_hf_bert_classification", 57 | "create_hf_bert_mlm", 58 | "create_mosaic_bert_classification", 59 | "create_mosaic_bert_mlm", 60 | # These are commented out because they only exist if CUDA is available 61 | # 'flash_attn_func_bert', 62 | # 'flash_attn_qkvpacked_func_bert' 63 | ] 64 | -------------------------------------------------------------------------------- /src/algorithms/rope_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 onwards Answer.AI, LightOn, and contributors 2 | # License: Apache-2.0 3 | 4 | from __future__ import annotations 5 | 6 | import logging 7 | from typing import Union 8 | 9 | from composer import Time 10 | import torch.nn as nn 11 | 12 | from composer.core import Algorithm, Event, State 13 | from composer.loggers import Logger 14 | 15 | from src.bert_layers.attention import FlexBertAttentionBase 16 | from src.bert_layers.model import FlexBertPreTrainedModel 17 | 18 | try: 19 | from flash_attn.layers.rotary import RotaryEmbedding # type: ignore 20 | from src.bert_layers.rotary import UnpaddedRotaryEmbedding # type: ignore 21 | 22 | except ImportError: 23 | RotaryEmbedding = None 24 | UnpaddedRotaryEmbedding = None 25 | 26 | 27 | log = logging.getLogger(__name__) 28 | 29 | __all__ = ["FlexBertRopeSchedule"] 30 | 31 | 32 | class FlexBertRopeSchedule(Algorithm): 33 | def __init__( 34 | self, 35 | min_rope_theta: int, 36 | max_rope_theta: int, 37 | warmup_tokens: Union[str, Time, int], 38 | rope_theta_increment: int = 10_000, 39 | target_layer: nn.Module = UnpaddedRotaryEmbedding, 40 | ignore_sliding_window: bool = True, 41 | batch_log_interval: int = 10, 42 | increment_theta_immediately: bool = False, 43 | ): 44 | if isinstance(warmup_tokens, str): 45 | warmup_tokens = Time.from_timestring(warmup_tokens).value 46 | elif isinstance(warmup_tokens, Time): 47 | warmup_tokens = warmup_tokens.value 48 | self.min_rope_theta = min_rope_theta 49 | self.max_rope_theta = max_rope_theta 50 | self.rope_theta_increment = rope_theta_increment 51 | self.target_layer = target_layer 52 | self.ignore_sliding_window = ignore_sliding_window 53 | self.batch_log_interval = batch_log_interval 54 | self.increment_theta_immediately = increment_theta_immediately 55 | self._rotary_layers = [] 56 | self.warmup_tokens = warmup_tokens # Store warmup_tokens for recalculations 57 | self._min_theta = self.min_rope_theta 58 | self._calculate_increase_every_tokens() 59 | 60 | def _calculate_increase_every_tokens(self): 61 | self._increase_every_tokens = self.warmup_tokens // ( 62 | (self.max_rope_theta - self._min_theta) / self.rope_theta_increment 63 | ) 64 | 65 | def match(self, event: Event, state: State) -> bool: 66 | return event in [Event.INIT, Event.FIT_START, Event.BATCH_START, Event.BATCH_END] 67 | 68 | def apply(self, event: Event, state: State, logger: Logger) -> None: 69 | if event == Event.FIT_START: 70 | flexbert = False 71 | self._current_theta = self._min_theta 72 | for layer in state.model.modules(): 73 | if isinstance(layer, FlexBertPreTrainedModel): 74 | flexbert = True 75 | if isinstance(layer, FlexBertAttentionBase): 76 | if hasattr(layer, "rotary_emb") and isinstance(layer.rotary_emb, self.target_layer): 77 | if ( 78 | not self.ignore_sliding_window 79 | or (hasattr(layer, "sliding_window") and layer.sliding_window == (-1, -1)) 80 | or not hasattr(layer, "sliding_window") 81 | ): 82 | self._rotary_layers.append(layer.rotary_emb) 83 | if layer.rotary_emb.base != self.min_rope_theta: 84 | raise 
ValueError(f"{self.min_rope_theta=} does not match the Rotary Embedding's RoPE theta {layer.rotary_emb.base}") # fmt: skip 85 | if self.increment_theta_immediately: 86 | # Increase the RoPE theta by rope_theta_increment 87 | layer.rotary_emb.base += self.rope_theta_increment 88 | if self.increment_theta_immediately: 89 | self._min_theta += self.rope_theta_increment 90 | self._current_theta = self._min_theta 91 | self._calculate_increase_every_tokens() 92 | if not flexbert: 93 | raise ValueError("Rope Schedule only works with a FlexBertPreTrainedModel") 94 | assert len(self._rotary_layers) > 0, "No layers found to apply Rope Schedule to." 95 | 96 | if event == Event.BATCH_START and state.timestamp.batch.value % self.batch_log_interval == 0: 97 | logger.log_metrics({"trainer/rope_theta": self._current_theta}) 98 | 99 | if event == Event.BATCH_END and self._current_theta != self.max_rope_theta: 100 | tokens = state.timestamp.token.value 101 | 102 | # Calculate the expected number of increments 103 | increments = int(tokens // self._increase_every_tokens) 104 | desired_theta = min(self.max_rope_theta, self._min_theta + increments * self.rope_theta_increment) 105 | 106 | # Check if we need to update the RoPE theta value 107 | if desired_theta > self._current_theta: 108 | self._current_theta = desired_theta 109 | for rotary_emb in self._rotary_layers: 110 | rotary_emb.base = self._current_theta 111 | -------------------------------------------------------------------------------- /src/bert_layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import ( 2 | BertAlibiUnpadAttention, 3 | BertAlibiUnpadSelfAttention, 4 | BertSelfOutput, 5 | FlexBertPaddedAttention, 6 | FlexBertUnpadAttention, 7 | ) 8 | from .embeddings import ( 9 | BertAlibiEmbeddings, 10 | FlexBertAbsoluteEmbeddings, 11 | FlexBertSansPositionEmbeddings, 12 | ) 13 | from .layers import ( 14 | BertAlibiEncoder, 15 | BertAlibiLayer, 16 | BertResidualGLU, 17 | FlexBertPaddedPreNormLayer, 18 | FlexBertPaddedPostNormLayer, 19 | FlexBertUnpadPostNormLayer, 20 | FlexBertUnpadPreNormLayer, 21 | ) 22 | from .model import ( 23 | BertLMPredictionHead, 24 | BertModel, 25 | BertForMaskedLM, 26 | BertForSequenceClassification, 27 | BertForMultipleChoice, 28 | BertOnlyMLMHead, 29 | BertOnlyNSPHead, 30 | BertPooler, 31 | BertPredictionHeadTransform, 32 | FlexBertModel, 33 | FlexBertForMaskedLM, 34 | FlexBertForSequenceClassification, 35 | FlexBertForMultipleChoice, 36 | ) 37 | 38 | 39 | __all__ = [ 40 | "BertAlibiEmbeddings", 41 | "BertAlibiEncoder", 42 | "BertForMaskedLM", 43 | "BertForSequenceClassification", 44 | "BertForMultipleChoice", 45 | "BertResidualGLU", 46 | "BertAlibiLayer", 47 | "BertLMPredictionHead", 48 | "BertModel", 49 | "BertOnlyMLMHead", 50 | "BertOnlyNSPHead", 51 | "BertPooler", 52 | "BertPredictionHeadTransform", 53 | "BertSelfOutput", 54 | "BertAlibiUnpadAttention", 55 | "BertAlibiUnpadSelfAttention", 56 | "FlexBertPaddedAttention", 57 | "FlexBertUnpadAttention", 58 | "FlexBertAbsoluteEmbeddings", 59 | "FlexBertSansPositionEmbeddings", 60 | "FlexBertPaddedPreNormLayer", 61 | "FlexBertPaddedPostNormLayer", 62 | "FlexBertUnpadPostNormLayer", 63 | "FlexBertUnpadPreNormLayer", 64 | "FlexBertModel", 65 | "FlexBertForMaskedLM", 66 | "FlexBertForSequenceClassification", 67 | "FlexBertForMultipleChoice", 68 | ] 69 | -------------------------------------------------------------------------------- /src/bert_layers/activation.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2024 onwards Answer.AI, LightOn, and contributors 2 | # License: Apache-2.0 3 | 4 | # Copyright 2020 The HuggingFace Team. 5 | # License: Apache-2.0 6 | 7 | from collections import OrderedDict 8 | from typing import Union 9 | import torch.nn as nn 10 | from .configuration_bert import FlexBertConfig 11 | 12 | 13 | class ClassInstantier(OrderedDict): 14 | def __getitem__(self, key): 15 | content = super().__getitem__(key) 16 | cls, kwargs = content if isinstance(content, tuple) else (content, {}) 17 | return cls(**kwargs) 18 | 19 | 20 | ACT2CLS = { 21 | "celu": nn.CELU, 22 | "elu": nn.ELU, 23 | "gelu": nn.GELU, 24 | "gelu_tanh": (nn.GELU, {"approximate": "tanh"}), 25 | "hardtanh": nn.Hardtanh, 26 | "hardsigmoid": nn.Hardsigmoid, 27 | "hardshrink": nn.Hardshrink, 28 | "hardswish": nn.Hardswish, 29 | "leaky_relu": nn.LeakyReLU, 30 | "logsigmoid": nn.LogSigmoid, 31 | "mish": nn.Mish, 32 | "prelu": nn.PReLU, 33 | "relu": nn.ReLU, 34 | "relu6": nn.ReLU6, 35 | "rrelu": nn.RReLU, 36 | "selu": nn.SELU, 37 | "sigmoid": nn.Sigmoid, 38 | "silu": nn.SiLU, 39 | "softmin": nn.Softmin, 40 | "softplus": nn.Softplus, 41 | "softshrink": nn.Softshrink, 42 | "softsign": nn.Softsign, 43 | "swish": nn.SiLU, 44 | "tanh": nn.Tanh, 45 | "tanhshrink": nn.Tanhshrink, 46 | "threshold": nn.Threshold, 47 | } 48 | ACT2FN = ClassInstantier(ACT2CLS) 49 | 50 | 51 | def get_act_fn(config: Union[FlexBertConfig, str]) -> nn.Module: 52 | try: 53 | if isinstance(config, str): 54 | return ACT2FN[config] 55 | return ACT2FN[config.hidden_act] 56 | except KeyError: 57 | if isinstance(config, str): 58 | raise ValueError(f"Invalid activation function type: {config}, must be one of {ACT2FN.keys()}.") 59 | else: 60 | raise ValueError(f"Invalid activation function type: {config.hidden_act=}, must be one of {ACT2FN.keys()}.") 61 | -------------------------------------------------------------------------------- /src/bert_layers/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 onwards Answer.AI, LightOn, and contributors 2 | # License: Apache-2.0 3 | 4 | import inspect 5 | import torch.nn as nn 6 | from .configuration_bert import FlexBertConfig 7 | 8 | try: 9 | from flash_attn.losses.cross_entropy import CrossEntropyLoss 10 | except ImportError: 11 | CrossEntropyLoss = None 12 | 13 | LOSS2CLS = { 14 | "cross_entropy": nn.CrossEntropyLoss, 15 | "binary_cross_entropy": nn.BCEWithLogitsLoss, 16 | "mean_squared_error": nn.MSELoss, 17 | } 18 | 19 | if CrossEntropyLoss is not None: 20 | LOSS2CLS["fa_cross_entropy"] = CrossEntropyLoss 21 | 22 | 23 | def get_loss_fn(config: FlexBertConfig) -> nn.Module: 24 | try: 25 | loss_class = LOSS2CLS[config.loss_function] 26 | signature = inspect.signature(loss_class) 27 | loss_kwargs = {k: v for k, v in config.loss_kwargs.items() if k in signature.parameters} 28 | return loss_class(**loss_kwargs) 29 | except KeyError: 30 | raise ValueError(f"Invalid loss function type: {config.loss_function}, must be one of {LOSS2CLS.keys()}.") 31 | -------------------------------------------------------------------------------- /src/bert_layers/normalization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 onwards Answer.AI, LightOn, and contributors 2 | # License: Apache-2.0 3 | 4 | # RMSNorm Implementation: Copyright Meta (from their Llama RMSNorm implementation) 5 | # License: LLAMA 2 COMMUNITY LICENSE 
AGREEMENT 6 | 7 | 8 | import inspect 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import init 12 | 13 | from .configuration_bert import FlexBertConfig 14 | 15 | try: 16 | from flash_attn.ops.triton.layer_norm import RMSNorm as TritonRMSNorm 17 | from flash_attn.ops.triton.layer_norm import layer_norm_fn 18 | 19 | except ImportError: 20 | TritonRMSNorm = None 21 | layer_norm_fn = None 22 | 23 | 24 | class RMSNorm(nn.Module): 25 | """Llama2 RMSNorm implementation""" 26 | 27 | def __init__(self, dim: int, eps: float = 1e-5): 28 | """ 29 | Initialize the RMSNorm normalization layer. 30 | 31 | Args: 32 | dim (int): The dimension of the input tensor. 33 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 34 | 35 | Attributes: 36 | eps (float): A small value added to the denominator for numerical stability. 37 | weight (nn.Parameter): Learnable scaling parameter. 38 | 39 | """ 40 | super().__init__() 41 | self.eps = eps 42 | self.weight = nn.Parameter(torch.ones(dim)) 43 | 44 | def _norm(self, x): 45 | """ 46 | Apply the RMSNorm normalization to the input tensor. 47 | 48 | Args: 49 | x (torch.Tensor): The input tensor. 50 | 51 | Returns: 52 | torch.Tensor: The normalized tensor. 53 | 54 | """ 55 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 56 | 57 | def forward(self, x): 58 | """ 59 | Forward pass through the RMSNorm layer. 60 | 61 | Args: 62 | x (torch.Tensor): The input tensor. 63 | 64 | Returns: 65 | torch.Tensor: The output tensor after applying RMSNorm. 66 | 67 | """ 68 | output = self._norm(x.float()).type_as(x) 69 | return output * self.weight 70 | 71 | def reset_parameters(self): 72 | init.ones_(self.weight) 73 | 74 | 75 | if layer_norm_fn is not None: 76 | 77 | class TritonLayerNorm(nn.LayerNorm): 78 | def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): 79 | return layer_norm_fn( 80 | x, 81 | self.weight, 82 | self.bias, 83 | residual=residual, 84 | eps=self.eps, 85 | prenorm=prenorm, 86 | residual_in_fp32=residual_in_fp32, 87 | ) 88 | else: 89 | TritonLayerNorm = None 90 | 91 | NORM2CLS = { 92 | "layernorm": nn.LayerNorm, 93 | "triton_layernorm": TritonLayerNorm if TritonLayerNorm is not None else nn.LayerNorm, 94 | "rmsnorm": RMSNorm, 95 | "triton_rmsnorm": TritonRMSNorm if TritonRMSNorm is not None else RMSNorm, 96 | } 97 | 98 | 99 | def get_norm_layer(config: FlexBertConfig, compiled_norm: bool = False) -> nn.Module: 100 | try: 101 | if compiled_norm: 102 | # Use non-Triton norms when compiling 103 | if config.normalization.startswith("triton_"): 104 | norm = config.normalization.replace("triton_", "") 105 | else: 106 | norm = config.normalization 107 | else: 108 | norm = config.normalization 109 | signature = inspect.signature(NORM2CLS[norm]) 110 | if hasattr(config, "norm_kwargs"): 111 | norm_kwargs = {k: v for k, v in config.norm_kwargs.items() if k in signature.parameters} 112 | else: 113 | norm_kwargs = {} 114 | return NORM2CLS[norm](config.hidden_size, **norm_kwargs) 115 | except KeyError: 116 | raise ValueError(f"Invalid normalization layer type: {config.normalization}, must be one of {NORM2CLS.keys()}.") 117 | -------------------------------------------------------------------------------- /src/bert_layers/options.py: -------------------------------------------------------------------------------- 1 | from .normalization import NORM2CLS 2 | from .embeddings import EBB2CLS 3 | from .activation import ACT2CLS 4 | from .attention import ATTN2CLS 5 | from .mlp 
import MLP2CLS 6 | from .layers import LAYER2CLS 7 | 8 | 9 | def print_layer_options(): 10 | print("Activation options:") 11 | for option in ACT2CLS: 12 | print(f" {option}") 13 | 14 | print("\nAttention Layer options:") 15 | for option in ATTN2CLS: 16 | print(f" {option}") 17 | 18 | print("\nEmbedding Layer options:") 19 | for option in EBB2CLS: 20 | print(f" {option}") 21 | 22 | print("\nBert Layer options:") 23 | for option in LAYER2CLS: 24 | print(f" {option}") 25 | 26 | print("\nMLP Layer options:") 27 | for option in MLP2CLS: 28 | print(f" {option}") 29 | 30 | print("\nNormalization options:") 31 | for option in NORM2CLS: 32 | print(f" {option}") 33 | -------------------------------------------------------------------------------- /src/bert_layers/padding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import Optional, Tuple 4 | import torch.nn.functional as F 5 | 6 | 7 | def unpad_input( 8 | inputs: Tensor, 9 | attention_mask: Tensor, 10 | position_ids: Optional[Tensor] = None, 11 | labels: Optional[Tensor] = None, 12 | ) -> Tuple[Tensor, Tensor, Tensor, int, Optional[Tensor], Optional[Tensor]]: 13 | """ 14 | Remove padding from input sequences. 15 | 16 | Args: 17 | inputs: (batch, seqlen, ...) or (batch, seqlen) 18 | attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. 19 | position_ids: (batch, seqlen), int, position ids 20 | labels: (batch, seqlen), int, labels 21 | 22 | Returns: 23 | unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. 24 | indices: (total_nnz) 25 | cu_seqlens: (batch + 1), the cumulative sequence lengths 26 | max_seqlen_in_batch: int 27 | unpadded_position_ids: (total_nnz) or None 28 | unpadded_labels: (total_nnz) or None 29 | """ 30 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 31 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 32 | max_seqlen_in_batch = int(seqlens_in_batch.max().item()) 33 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) 34 | 35 | if inputs.dim() == 2: 36 | unpadded_inputs = inputs.flatten()[indices] 37 | else: 38 | batch, seqlen, *rest = inputs.shape 39 | shape = batch * seqlen 40 | unpadded_inputs = inputs.view(shape, *rest)[indices] 41 | 42 | unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None 43 | unpadded_labels = labels.flatten()[indices] if labels is not None else None 44 | 45 | return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels 46 | 47 | 48 | def pad_input( 49 | inputs: Tensor, 50 | indices: Tensor, 51 | batch: int, 52 | seqlen: int, 53 | labels: Optional[Tensor] = None, 54 | ignore_index: int = -100, 55 | ) -> Tuple[Tensor, Optional[Tensor]]: 56 | """ 57 | Add padding to sequences. 58 | 59 | Args: 60 | inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask. 61 | indices: (total_nnz) 62 | batch: int, batch size 63 | seqlen: int, max sequence length 64 | position_ids: (total_nnz) or None 65 | labels: (total_nnz) or None 66 | 67 | Returns: 68 | padded_inputs: (batch, seqlen, ...) 
or (batch, seqlen) 69 | padded_labels: (batch, seqlen) or None 70 | """ 71 | if inputs.dim() == 1: 72 | output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device) 73 | output[indices] = inputs 74 | padded_inputs = output.view(batch, seqlen) 75 | else: 76 | _, *rest = inputs.shape 77 | output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device) 78 | output[indices] = inputs 79 | padded_inputs = output.view(batch, seqlen, *rest) 80 | 81 | padded_labels = None 82 | if labels is not None: 83 | padded_labels = torch.full((batch * seqlen,), fill_value=ignore_index, dtype=labels.dtype, device=labels.device) 84 | padded_labels[indices] = labels 85 | padded_labels = padded_labels.view(batch, seqlen) 86 | 87 | return padded_inputs, padded_labels 88 | -------------------------------------------------------------------------------- /src/bert_padding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MosaicML Examples authors 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py 5 | # Which was adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py 6 | 7 | """Helper functions for padding and unpadding batches. 8 | 9 | These functions are used extensively throughout the Mosaic BERT implementation 10 | in `bert_layers.py`. 11 | """ 12 | 13 | from typing import Tuple, cast 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | from einops import rearrange, repeat 18 | 19 | 20 | class IndexFirstAxis(torch.autograd.Function): 21 | @staticmethod 22 | def forward(ctx, input: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: 23 | """Get just the values of `input` which are at `indices`. 24 | 25 | Arguments: 26 | ctx: the autograd context object 27 | input: (b, ...) 2+ dimensional tensor 28 | indices: (num_idx) 1D tensor 29 | """ 30 | ctx.save_for_backward(indices) 31 | assert input.ndim >= 2 32 | ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] # type: ignore 33 | second_dim = other_shape.numel() # product of sizes of all but first dimension 34 | # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. 35 | return torch.gather( 36 | rearrange(input, "b ... -> b (...)"), # (b, ...) -> (b, second_dim) 37 | 0, 38 | repeat(indices, "z -> z d", d=second_dim), # (indices,) -> (indices, second_dim) 39 | ).reshape(-1, *other_shape) # (num_idx, ...) 40 | 41 | @staticmethod 42 | def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: 43 | (indices,) = ctx.saved_tensors 44 | assert grad_output.ndim >= 2 45 | other_shape = grad_output.shape[1:] 46 | grad_output = rearrange(grad_output, "b ... -> b (...)") 47 | grad_input = torch.zeros( 48 | [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype 49 | ) 50 | # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. 
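# Descriptive note on the line below: grad_input is zero-initialized with one row per
# *padded* token, and scatter_ along dim 0 writes each unpadded gradient row back to its
# original flattened position given by `indices` -- equivalent to the indexing assignment
# shown commented out next, just faster in practice.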
51 | # grad_input[indices] = grad_output 52 | grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) 53 | return grad_input.reshape(ctx.first_axis_dim, *other_shape), None 54 | 55 | 56 | index_first_axis = IndexFirstAxis.apply 57 | 58 | 59 | class IndexPutFirstAxis(torch.autograd.Function): 60 | @staticmethod 61 | def forward(ctx, values: torch.Tensor, indices: torch.Tensor, first_axis_dim) -> torch.Tensor: 62 | ctx.save_for_backward(indices) 63 | assert indices.ndim == 1 64 | assert values.ndim >= 2 65 | output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) 66 | output[indices] = values 67 | return output 68 | 69 | @staticmethod 70 | def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]: 71 | (indices,) = ctx.saved_tensors 72 | grad_values = grad_output[indices] 73 | return grad_values, None, None 74 | 75 | 76 | index_put_first_axis = IndexPutFirstAxis.apply 77 | 78 | 79 | def unpad_input( 80 | hidden_states: torch.Tensor, 81 | attention_mask: torch.Tensor, 82 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: 83 | """Remove padding from input sequences. 84 | 85 | Arguments: 86 | hidden_states: (batch, seqlen, ...) 87 | attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. 88 | 89 | Returns: 90 | hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. 91 | indices: (total_nnz) 92 | cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. 93 | max_seqlen_in_batch: int () 94 | """ 95 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 96 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 97 | max_seqlen_in_batch = int(seqlens_in_batch.max().item()) 98 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) 99 | # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the 100 | # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim 101 | # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to 102 | # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, 103 | # so we write custom forward and backward to make it a bit faster. 104 | hidden_states = cast(torch.Tensor, index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices)) 105 | return hidden_states, indices, cu_seqlens, max_seqlen_in_batch 106 | 107 | 108 | def unpad_input_only( 109 | hidden_states: torch.Tensor, 110 | attention_mask: torch.Tensor, 111 | ) -> torch.Tensor: 112 | """Like unpad_input, but only return the unpadded first tensor. 113 | 114 | Save a small amount of overhead. 115 | 116 | Arguments: 117 | hidden_states: (batch, seqlen, ...) 118 | attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. 119 | 120 | Returns: 121 | hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. 122 | """ 123 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 124 | rearranged = rearrange(hidden_states, "b s ... -> (b s) ...") 125 | return index_first_axis(rearranged, indices) # type: ignore 126 | 127 | 128 | def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor: 129 | """Add padding to sequences. 
130 | 131 | Arguments: 132 | hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. 133 | indices: (total_nnz) 134 | batch: int batch_size 135 | seqlen: int max sequence length 136 | 137 | Returns: 138 | hidden_states: (batch, seqlen, ...) 139 | """ 140 | output = index_put_first_axis(hidden_states, indices, batch * seqlen) 141 | return rearrange(output, "(b s) ... -> b s ...", b=batch) # type: ignore 142 | -------------------------------------------------------------------------------- /src/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/ModernBERT/8c57a0f01c12c4953ead53d398a36f81a4ba9e38/src/callbacks/__init__.py -------------------------------------------------------------------------------- /src/callbacks/dataloader_speed.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 onwards Answer.AI, LightOn, and contributors 2 | # License: Apache-2.0 3 | 4 | import time 5 | from composer.core import Callback, State 6 | from composer.loggers import Logger 7 | 8 | __all__ = ["DataloaderSpeedMonitor"] 9 | 10 | 11 | class DataloaderSpeedMonitor(Callback): 12 | """Measure how long it takes to return a batch from the dataloader.""" 13 | 14 | def before_dataloader(self, state: State, logger: Logger) -> None: 15 | del logger # unused 16 | self.batch_start_time = time.time_ns() 17 | 18 | def after_dataloader(self, state: State, logger: Logger) -> None: 19 | self.batch_serve_time = time.time_ns() - self.batch_start_time 20 | logger.log_metrics( 21 | { 22 | "throughput/batch_serve_time_ns": self.batch_serve_time, 23 | "throughput/batch_serve_time_ms": self.batch_serve_time / 1e6, 24 | } 25 | ) 26 | -------------------------------------------------------------------------------- /src/callbacks/log_grad_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MosaicML Composer authors 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | """Monitor gradients during training.""" 5 | 6 | import logging 7 | import torch 8 | 9 | from composer.core import Callback, State 10 | from composer.loggers import Logger 11 | from composer.utils import dist 12 | 13 | try: 14 | from src.optimizer import StableAdamW 15 | except ImportError: 16 | StableAdamW = None 17 | 18 | __all__ = ["LogGradNorm"] 19 | 20 | 21 | log = logging.getLogger(__name__) 22 | 23 | 24 | class LogGradNorm(Callback): 25 | """Logs the precomputed L1 and L2 gradient norms from StableAdamW""" 26 | 27 | def __init__(self, log_optimizer_metrics: bool = True, batch_log_interval: int = 10): 28 | self.log_optimizer_metrics = log_optimizer_metrics 29 | self.batch_log_interval = batch_log_interval 30 | if StableAdamW is None: 31 | raise ImportError("Install `pip install torch-optimi` to use the StableAdamW optimizer.") 32 | 33 | def epoch_start(self, state: State, logger: Logger): 34 | if state.fsdp_enabled and dist.get_world_size() > 0 and self.log_optimizer_metrics: 35 | raise ValueError("Logging grad_norms is currently incompatible with FSDP.") 36 | if not isinstance(state.optimizers[0], StableAdamW): 37 | self.log_optimizer_metrics = False 38 | log.warn("Disabling `LogGradNorm` as it requires the internal `StableAdamW` optimizer") 39 | 40 | def batch_end(self, state: State, logger: Logger): 41 | if state.timestamp.batch.value % self.batch_log_interval != 0 or not self.log_optimizer_metrics: 42 | return 43 | 44 
| optimizer_metrics = getattr(state.optimizers[0], "grad_norms", None) 45 | if optimizer_metrics is not None: 46 | logged_metrics = {} 47 | for metric, value in optimizer_metrics.items(): 48 | if isinstance(value, torch.Tensor): 49 | value = value.item() 50 | logged_metrics[f"gradient_norms/{metric}"] = value 51 | logger.log_metrics(logged_metrics) 52 | -------------------------------------------------------------------------------- /src/callbacks/packing_efficiency.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 onwards Answer.AI, LightOn, and contributors 2 | # License: Apache-2.0 3 | 4 | from composer.core import Callback, State 5 | from composer.loggers import Logger 6 | 7 | __all__ = ["PackingEfficency"] 8 | 9 | 10 | class PackingEfficency(Callback): 11 | """Records the packing efficiency for each batch.""" 12 | 13 | def __init__(self, log_interval: int = 100): 14 | self.log_interval = log_interval 15 | 16 | def after_dataloader(self, state: State, logger: Logger) -> None: 17 | if state.timestamp.batch.value % self.log_interval != 0: 18 | return 19 | logger.log_metrics( 20 | { 21 | "trainer/packing_efficiency": self._packing_efficiency(state), 22 | } 23 | ) 24 | 25 | def _packing_efficiency(self, state: State) -> float: 26 | return state.batch["attention_mask"].sum().item() / state.batch["attention_mask"].numel() 27 | -------------------------------------------------------------------------------- /src/callbacks/scheduled_gc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MosaicML LLM Foundry authors 2 | # SPDX-License-Identifier: Apache-2.0 3 | # from: https://github.com/mosaicml/llm-foundry/blob/main/llmfoundry/callbacks/scheduled_gc_callback.py 4 | 5 | import gc 6 | from typing import Optional 7 | 8 | import torch 9 | from composer.core import Callback, State 10 | from composer.loggers import Logger 11 | 12 | __all__ = ["ScheduledGarbageCollector"] 13 | 14 | 15 | def gc_cuda(): 16 | """Garbage collect Torch (CUDA) memory.""" 17 | gc.collect() 18 | if torch.cuda.is_available(): 19 | torch.cuda.empty_cache() 20 | 21 | 22 | class ScheduledGarbageCollector(Callback): 23 | """Disable automatic garbage collection and collect garbage at interval. 
24 | 25 | Args: 26 | batch_interval (int): Number of batches between calls to gc.collect() 27 | gen_1_batch_interval(int, optional): Number of batches between calls to gc.collect(1) 28 | eval_keep_disabled (bool): keep gc disabled during eval (default: False) 29 | """ 30 | 31 | def __init__( 32 | self, 33 | batch_interval: int, 34 | gen_1_batch_interval: Optional[int] = None, 35 | eval_keep_disabled: bool = False, 36 | ): 37 | self.batch_interval = batch_interval 38 | self.gen_1_batch_interval = gen_1_batch_interval 39 | self.eval_keep_disabled = eval_keep_disabled 40 | self.gc_init_state = None 41 | 42 | def fit_start(self, state: State, logger: Logger) -> None: 43 | del state, logger # unused 44 | 45 | # cache if automatic garbage collection is enabled; reset at fit_end 46 | self.gc_init_state = gc.isenabled() 47 | 48 | # disable automatic garbage collection 49 | gc.disable() 50 | gc_cuda() 51 | 52 | def fit_end(self, state: State, logger: Logger) -> None: 53 | del state, logger # unused 54 | 55 | gc_cuda() 56 | 57 | # reset automatic garbage collection at fit_end 58 | if self.gc_init_state: 59 | gc.enable() 60 | else: 61 | gc.disable() 62 | 63 | def before_dataloader(self, state: State, logger: Logger) -> None: 64 | del logger # unused 65 | 66 | if self.gen_1_batch_interval is not None and state.timestamp.batch.value % self.gen_1_batch_interval == 0: 67 | gc.collect(1) 68 | 69 | if state.timestamp.batch.value % self.batch_interval == 0: 70 | gc_cuda() 71 | 72 | def eval_start(self, state: State, logger: Logger) -> None: 73 | del state, logger # unused 74 | 75 | gc_cuda() 76 | if not self.eval_keep_disabled: 77 | gc.enable() 78 | 79 | def eval_end(self, state: State, logger: Logger) -> None: 80 | del state, logger # unused 81 | 82 | if not self.eval_keep_disabled: 83 | gc.disable() 84 | 85 | gc_cuda() 86 | -------------------------------------------------------------------------------- /src/colbert_beir/__init__.py: -------------------------------------------------------------------------------- 1 | from .train import colbert_train 2 | from .index_and_score import build_colbert_index, colbert_score 3 | 4 | 5 | __all__ = ["colbert_train", "build_colbert_index", "colbert_score"] 6 | -------------------------------------------------------------------------------- /src/colbert_beir/index_and_score.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from pathlib import Path 3 | from colbert import Indexer, Searcher 4 | from colbert.infra import ColBERTConfig 5 | import ir_datasets 6 | from tqdm import tqdm 7 | from ranx import Run as ranx_run 8 | from ranx import Qrels, evaluate 9 | 10 | 11 | def build_colbert_index( 12 | dataset_name: str, 13 | model_name_or_path: str, 14 | checkpoint_path: str, 15 | collection: list[str], 16 | tmp_path: str, 17 | ): 18 | config = ColBERTConfig( 19 | nbits=8, 20 | root=str(Path(tmp_path) / f"benchmark_{model_name_or_path}"), 21 | overwrite=True, 22 | kmeans_niters=10, 23 | doc_maxlen=300, 24 | ) 25 | indexer = Indexer(checkpoint=checkpoint_path, config=config) 26 | indexer.index( 27 | name=dataset_name, 28 | collection=collection, 29 | overwrite=True, 30 | ) 31 | return True 32 | 33 | 34 | def colbert_score( 35 | model_name_or_path: str, 36 | dataset_name: str, 37 | dataset: ir_datasets.Dataset, 38 | int2docid: dict[int, str], 39 | tmp_path: str, 40 | metric: str = "ndcg@10", 41 | ): 42 | qrels_dict = defaultdict(dict) 43 | for qrel in dataset.qrels_iter(): 44 | 
qrels_dict[qrel.query_id][qrel.doc_id] = qrel.relevance 45 | 46 | qrels = Qrels(qrels_dict) 47 | 48 | qid_to_query = {} 49 | 50 | for query in dataset.queries_iter(): 51 | qid_to_query[query.query_id] = query.text 52 | 53 | config = ColBERTConfig( 54 | nbits=8, 55 | ncells=8, 56 | ndocs=8192, 57 | root=str(Path(tmp_path) / f"benchmark_{model_name_or_path}"), 58 | centroid_score_threshold=0.3, 59 | doc_maxlen=300, 60 | ) 61 | searcher = Searcher(index=dataset_name, config=config) 62 | run_dict = defaultdict(dict) 63 | for qid, query in tqdm(qid_to_query.items(), desc="Querying " + dataset_name): 64 | result = searcher.search(query, k=10) 65 | for i, r in enumerate(result[0]): 66 | run_dict[qid][int2docid[r]] = result[2][i] 67 | run = ranx_run(run_dict) 68 | return evaluate(qrels, run, metric) 69 | -------------------------------------------------------------------------------- /src/colbert_beir/train.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from colbert.infra import Run, RunConfig, ColBERTConfig 3 | from colbert import Trainer 4 | 5 | 6 | def colbert_train(model_name_or_path: str, train_params: dict, n_gpu: int, data_path: str): 7 | with Run().context( 8 | RunConfig( 9 | nranks=n_gpu, 10 | experiment=model_name_or_path, 11 | name=train_params["name"], 12 | root=train_params["root"], 13 | ) 14 | ): 15 | config = ColBERTConfig(doc_maxlen=300, **train_params) 16 | print(config) 17 | data_path = Path(data_path) 18 | 19 | trainer = Trainer( 20 | triples=str(data_path / "triples.train.colbert.jsonl"), 21 | queries=str(data_path / "queries.train.colbert.tsv"), 22 | collection=str(data_path / "corpus.train.colbert.tsv"), 23 | config=config, 24 | ) 25 | 26 | trainer.train(checkpoint=model_name_or_path) 27 | return f"{train_params['root']}/{model_name_or_path}/none/{train_params['name']}/checkpoints/colbert" 28 | -------------------------------------------------------------------------------- /src/data/README.md: -------------------------------------------------------------------------------- 1 | # Training Data 2 | This readme describes the training data and process used in BERT24. 3 | 4 | 5 | 6 | ## Re-Generating the Data 7 | 0. Install dependencies from the `requirements.txt` and `requirements-data.txt` 8 | 3. Turn a HF dataset into MDS via `hf_to_mds.py`. 9 | 6. Sample each dataset using the `sample_dataset_from_config.py TODO` 10 | 11 | 12 | ## Utilities 13 | 14 | #### Gathering the size of HF MDS datasets 15 | Thanks to the MDS format, it is simple and quick to get the number of instances and the size of the data. To do this, run `get_counts_from_hf.py` to get all counts, or use `get_counts_from_hf.py --repos "REPO_1 REPO_2 ... REPO_N"`. 16 | 17 | #### Calculating the tokens in a dataset 18 | We can gather the number of total tokens and the tokens per instance using `python source_stats.py`. This is still in progress. 
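For reference, the core of that token-counting pass looks roughly like the sketch below. It is only an illustration, not the exact `source_stats.py` internals: the repo id is a placeholder, the shards are assumed to expose a `text` field, and the sample cap plays the role of `--dataset_max_size`.

```python
# Minimal sketch of per-source token counting over an MDS dataset (placeholder repo id).
import numpy as np
from streaming import StreamingDataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
dataset = StreamingDataset(remote="hf://datasets/ORG/SOME_MDS_DATASET/", shuffle=False, batch_size=1)

token_lens = []
for i, sample in enumerate(dataset):
    if i >= 1_000:  # cap the number of sampled instances, analogous to --dataset_max_size
        break
    token_lens.append(len(tokenizer(sample["text"])["input_ids"]))

token_lens = np.array(token_lens)
print(f"total tokens: {token_lens.sum()}, mean per instance: {token_lens.mean():.1f}")
print("percentiles:", dict(zip([1, 50, 99], np.percentile(token_lens, [1, 50, 99]))))
```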
19 | 20 | ## Config format 21 | TODO: copy from old README 22 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/ModernBERT/8c57a0f01c12c4953ead53d398a36f81a4ba9e38/src/data/__init__.py -------------------------------------------------------------------------------- /src/data/data_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | MDS_COLS_TOKENIZED = { 4 | 'input_ids': 'ndarray', 5 | 'attention_mask': 'ndarray', 6 | 'id': 'str' 7 | } 8 | 9 | MDS_COLS_TEXT = { 10 | 'text': 'str', 11 | 'id': 'str' 12 | } 13 | 14 | 15 | ALL_REPOS = [ 16 | # TODO: make these 17 | ] -------------------------------------------------------------------------------- /src/data/get_counts_from_hf.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import argparse 3 | import random 4 | import json 5 | import datasets 6 | import requests 7 | import math 8 | import os 9 | import gzip 10 | import numpy as np 11 | import multiprocessing 12 | import huggingface_hub 13 | import glob 14 | import tempfile 15 | 16 | from datasets import load_dataset, Dataset, DatasetDict, interleave_datasets 17 | from streaming.base.util import _merge_index_from_root, merge_index 18 | from transformers import set_seed, AutoTokenizer 19 | from streaming import MDSWriter, StreamingDataset 20 | 21 | from huggingface_hub import HfFileSystem 22 | from data_utils import ALL_REPOS 23 | 24 | 25 | def get_counts_for_repo(repo, args): 26 | # download the root index.json only 27 | files_in_repo = [item.path for item in huggingface_hub.list_repo_tree(repo, repo_type="dataset")] 28 | if "index.json" not in files_in_repo: 29 | # it must be in the main folder 30 | main_folder = None 31 | repo_name_folder = repo.split("/")[-1] 32 | for file in files_in_repo: 33 | if file not in [".gitattributes"] and file.count(".") == 0: 34 | main_folder = file 35 | break 36 | main_json = f"{main_folder}/index.json" 37 | print(f"Did not find a root index.json, using {main_json}") 38 | else: 39 | main_json = "index.json" 40 | 41 | with tempfile.TemporaryDirectory() as tmp_cache_dir: 42 | root_folder = huggingface_hub.snapshot_download(repo_id=repo, allow_patterns=main_json, repo_type="dataset", cache_dir=tmp_cache_dir) 43 | dataset = StreamingDataset(local=os.path.join(root_folder, main_json.replace("index.json", "")), shuffle=False, split=None, batch_size=1) 44 | dataset_size = len(dataset) 45 | 46 | base_dir = f"datasets/{repo}" 47 | fs = HfFileSystem() 48 | try: 49 | size_of_folder = fs.du(base_dir, total=True, maxdepth=None, withdirs=True) 50 | except Exception as e: 51 | print(f"Error: {e}. 
Sleeping for 60 seconds and trying again") 52 | import time 53 | time.sleep(60) 54 | size_of_folder = fs.du(base_dir, total=True, maxdepth=None, withdirs=True) 55 | 56 | return {"dataset": repo, "size": size_of_folder / 1e9, "instances": dataset_size} 57 | 58 | 59 | def get_counts(args): 60 | # read in all that have been already processed 61 | if os.path.exists("dataset_info.jsonl"): 62 | with open("dataset_info.jsonl", "r") as f: 63 | processed_datasets = set([json.loads(line)["dataset"] for line in f]) 64 | else: 65 | processed_datasets = set() 66 | 67 | output_f = open("dataset_info.jsonl", "a") 68 | for repo in tqdm.tqdm(args.repos): 69 | if repo in processed_datasets: 70 | print(f"Skipping {repo} since it's already processed") 71 | continue 72 | print(f"Getting counts for {repo}") 73 | output_dict = get_counts_for_repo(repo, args) 74 | output_f.write(json.dumps(output_dict) + "\n") 75 | # flush it 76 | output_f.flush() 77 | 78 | output_f.close() 79 | 80 | # read in the info and sum and print 81 | total_size = 0 82 | total_instances = 0 83 | with open("dataset_info.jsonl", "r") as f: 84 | for line in f: 85 | info = json.loads(line) 86 | total_size += info["size"] 87 | total_instances += info["instances"] 88 | 89 | print(f"Total size: {total_size} GB") 90 | print(f"Total instances: {total_instances}") 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument("--repos", type=str, nargs="+", help="List of repos to get counts for", default=None) 96 | args = parser.parse_args() 97 | 98 | # if repos is None use the default ALL_REPOS 99 | if args.repos is None: 100 | args.repos = ALL_REPOS 101 | 102 | get_counts(args) 103 | 104 | # example usage: 105 | # python get_counts_from_hf.py -------------------------------------------------------------------------------- /src/data/hf_to_mds.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import argparse 3 | import random 4 | import json 5 | import datasets 6 | import requests 7 | import math 8 | import os 9 | import gzip 10 | import numpy as np 11 | import multiprocessing 12 | import huggingface_hub 13 | import glob 14 | import tempfile 15 | 16 | from datasets import load_dataset, Dataset, DatasetDict, interleave_datasets 17 | from streaming.base.util import _merge_index_from_root, merge_index 18 | from transformers import set_seed, AutoTokenizer 19 | from streaming import MDSWriter, StreamingDataset 20 | from huggingface_hub import HfFileSystem 21 | 22 | from data_utils import MDS_COLS_TEXT 23 | 24 | 25 | set_seed(11111111) 26 | 27 | FILES_INFO = None 28 | 29 | 30 | 31 | def push_to_hub_incrementally(repo_name, local_path): 32 | api = huggingface_hub.HfApi() 33 | # Upload all the content from the local folder to your remote Space. 
34 | # By default, files are uploaded at the root of the repo 35 | print(f"Uploading {local_path} to {repo_name}/{local_path}") 36 | try: 37 | api.upload_folder( 38 | folder_path=local_path, 39 | repo_id=repo_name, 40 | path_in_repo=local_path, 41 | repo_type="dataset", 42 | multi_commits=True, 43 | multi_commits_verbose=True, 44 | ) 45 | except Exception as e: 46 | print(e) 47 | import time 48 | time.sleep(30) 49 | print(f"Error uploading {local_path} to {repo_name}, trying again") 50 | api.upload_folder( 51 | folder_path=local_path, 52 | repo_id=repo_name, 53 | path_in_repo=local_path, 54 | repo_type="dataset", 55 | multi_commits=True, 56 | multi_commits_verbose=True, 57 | ) 58 | 59 | os.system(f"rm -rf {local_path}") 60 | print(f"Pushed {local_path} to {repo_name}") 61 | 62 | 63 | def sample_hf(upload_repo, repo_name, split_name, config_name): 64 | print(f"Sampling the data with repo {repo_name} and {split_name} and {config_name} and pushing to {upload_repo}...") 65 | 66 | if config_name is not None and split_name: 67 | dataset = load_dataset(repo_name, config_name, streaming=True)[split_name] 68 | elif config_name is not None: 69 | dataset = load_dataset(repo_name, config_name, streaming=True) 70 | elif split_name is not None: 71 | dataset = load_dataset(repo_name, streaming=True)[split_name] 72 | else: 73 | dataset = load_dataset(repo_name, streaming=True) 74 | 75 | 76 | try: 77 | files = list(huggingface_hub.list_repo_tree(upload_repo, repo_type="dataset")) 78 | files = [file.path for file in files] 79 | except huggingface_hub.utils._errors.RepositoryNotFoundError: 80 | # make the dataset if it doesn't exist 81 | api = huggingface_hub.HfApi() 82 | repo_url = api.create_repo( 83 | upload_repo, 84 | repo_type="dataset", 85 | exist_ok=False, 86 | ) 87 | files = [] 88 | 89 | 90 | if "data/index.json" not in files: 91 | config_name_dirsafe = config_name.replace("/", "-") if config_name is not None else "default" 92 | split_name_dirsafe = split_name.replace("/", "-") if split_name is not None else "default" 93 | tmp_cache_dir = f"{repo_name.replace('/', '-')}---{split_name_dirsafe}---{config_name_dirsafe}" 94 | if not os.path.isfile(os.path.join(tmp_cache_dir, "index.json")): 95 | print(f"Writing to MDS...") 96 | with MDSWriter(out=tmp_cache_dir, columns=MDS_COLS_TEXT, compression='zstd') as train_writer: 97 | for item in tqdm.tqdm(dataset): 98 | train_writer.write(item) 99 | 100 | print(f"Pushing to HF...") 101 | dataset = StreamingDataset(local=tmp_cache_dir, shuffle=False, split=None, batch_size=1) 102 | num_instances = len(dataset) 103 | push_to_hub_incrementally( 104 | upload_repo, 105 | tmp_cache_dir 106 | ) 107 | else: 108 | print(f"Using existing MDS written out") 109 | 110 | fs = HfFileSystem() 111 | size_of_folder = fs.du(f"datasets/{upload_repo}") 112 | with open("dataset_info.jsonl", "a") as f: 113 | f.write(json.dumps({"dataset": upload_repo, "split_name": split_name, "config_name": config_name, "size": size_of_folder / 1e9, "instances": num_instances}) + "\n") 114 | 115 | 116 | 117 | if __name__ == "__main__": 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument("-u", "--upload_repo", type=str, required=True) 120 | parser.add_argument("-r", "--repo_name", type=str, required=True) 121 | parser.add_argument("-s", "--repo_split", type=str, required=False) 122 | parser.add_argument("-c", "--repo_config", type=str, required=False) 123 | args = parser.parse_args() 124 | 125 | sample_hf(args.upload_repo, args.repo_name, args.repo_split, args.repo_config) 126 | 127 | # 
example usage: 128 | # python hf_to_mds.py -r HF_DATASET -c CONFIG -s SPLIT -u HF_SAVE_PATH 129 | -------------------------------------------------------------------------------- /src/data/mds_conversion.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script allows conversion of mds-data, such as 3 | * Compressing or decompressing a dataset 4 | * Removing unnecessary fields 5 | * Adapting `input_ids` to a more appropriate dtype 6 | """ 7 | 8 | import argparse 9 | import os 10 | import json 11 | import numpy as np 12 | from streaming.base.format.mds.writer import MDSWriter 13 | from streaming.base.format import reader_from_json 14 | from streaming.base.compression import decompress 15 | from tqdm import tqdm 16 | 17 | def maybe_decompress_shard(shard, delete_zip: bool = False): 18 | """ 19 | If shard does not have decompressed data, 20 | this function decompresses the shard 21 | """ 22 | raw_filename = os.path.join(shard.dirname, shard.split, shard.raw_data.basename) 23 | if not os.path.isfile(raw_filename): 24 | zip_filename = os.path.join(shard.dirname, shard.split, shard.zip_data.basename) 25 | data = open(zip_filename, 'rb').read() 26 | data = decompress(shard.compression, data) 27 | tmp_filename = raw_filename + '.tmp' 28 | with open(tmp_filename, 'wb') as out: 29 | out.write(data) 30 | os.rename(tmp_filename, raw_filename) 31 | 32 | # Maybe remove compressed to save space. 33 | if shard.zip_data is not None and delete_zip: 34 | zip_filename = os.path.join(shard.dirname, shard.split, shard.zip_data.basename) 35 | if os.path.exists(zip_filename): 36 | os.remove(zip_filename) 37 | 38 | def main(): 39 | # Initialize the argument parser 40 | parser = argparse.ArgumentParser() 41 | 42 | # Define the arguments 43 | parser.add_argument('--data_path', type=str, required=True, help='Path to the data file') 44 | parser.add_argument('--read_split', type=str, required=True, help='Data split to read data from') 45 | parser.add_argument('--write_split', type=str, default=None, help='Data split to write data to') 46 | parser.add_argument('--dtype', type=str, default=None, help='Data type to convert the values of input_ids to') 47 | parser.add_argument('--columns_to_keep', type=str, nargs='+', default=None, help='List of columns to keep, if None, all columns will be kept') 48 | parser.add_argument('--decompress', action='store_true', help='If data in read_split should be be decompressed. 
Necessary if there is only compressed data in read_split') 49 | parser.add_argument('--delete_zip', action='store_true', help='Whether the compressed files should be kept after decompression or not') 50 | parser.add_argument('--compression', type=str, default=None, help='Compression type to use for the data to write') 51 | 52 | # Parse the arguments 53 | args = parser.parse_args() 54 | 55 | # Verify that the data path exists 56 | if not os.path.exists(args.data_path): 57 | raise FileNotFoundError(f"Data path {args.data_path} does not exist.") 58 | 59 | if not args.write_split: 60 | assert args.decompress and not args.dtype and not args.columns_to_keep, "Only decompression is allowed if no write_split has been specified" 61 | 62 | # Convert args.dtype string into actual np.dtype if given 63 | dtype = np.dtype(args.dtype) if args.dtype else None 64 | 65 | # Load index file 66 | split_path = os.path.join(args.data_path, args.read_split) 67 | index_file_path = os.path.join(split_path, "index.json") 68 | obj = json.load(open(index_file_path)) 69 | 70 | # Load columns from first shard to know what columns to write, and adapt if columns_to_keep is specified 71 | columns_to_write = {col_name: col_enc for col_name, col_enc in zip(obj["shards"][0]["column_names"], obj["shards"][0]["column_encodings"])} 72 | assert "input_ids" in columns_to_write, f"The data in the read path must have `input_ids` in its columns. Its columns: {columns_to_write.keys()}" 73 | if args.columns_to_keep: 74 | # Verify that each column in columns_to_keep is valid 75 | for column in args.columns_to_keep: 76 | assert column in columns_to_write, f"The given column to keep {column} must exist in the data in {args.read_split}" 77 | columns_to_write = {col_name: col_encoding for col_name, col_encoding in columns_to_write.items() if col_name in args.columns_to_keep} 78 | 79 | # read all shards 80 | shards = [] 81 | for info in tqdm(obj['shards'], desc="Reading shards"): 82 | shard = reader_from_json(args.data_path, args.read_split, info) 83 | maybe_decompress_shard(shard, args.delete_zip) 84 | shards.append(shard) 85 | 86 | # potentially filter/alter shards and write the new ones 87 | if args.write_split: 88 | with MDSWriter( 89 | columns=columns_to_write, out=os.path.join(args.data_path, args.write_split), compression=args.compression 90 | ) as out: 91 | for shard in tqdm(shards, desc="Writing shards"): 92 | for sample in shard: 93 | if dtype: 94 | assert np.all(sample["input_ids"]<=np.iinfo(dtype).max), f"value in sample[input_ids] must not exceed {dtype} max" 95 | sample["input_ids"] = sample["input_ids"].astype(dtype) 96 | out.write({k: sample[k] for k in columns_to_write.keys()}) 97 | 98 | if __name__ == "__main__": 99 | main() -------------------------------------------------------------------------------- /src/data/relative_prop_to_instance_prop.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import huggingface_hub 3 | from datasets import concatenate_datasets, load_dataset 4 | from tqdm import tqdm 5 | import tempfile 6 | import yaml 7 | import os 8 | import math 9 | import numpy as np 10 | import multiprocessing 11 | from pathlib import Path 12 | import pandas as pd 13 | import time 14 | from transformers import AutoTokenizer 15 | 16 | 17 | from datasets.utils.logging import disable_progress_bar 18 | disable_progress_bar() 19 | 20 | TOTAL_TOKENS = 1000000000 21 | 22 | 23 | def relative_to_instance(args): 24 | assert os.path.isfile(args.config), f"Config file 
{in_fn} does not exist." 25 | with open(args.config, 'r') as file: 26 | config = yaml.safe_load(file) 27 | 28 | target_tokens = config["target_tokens"] 29 | token_adjustment_ratio = target_tokens / TOTAL_TOKENS 30 | 31 | # contains `tokens_per_instance`, `instance_proportions`, and `num_instances` for each source 32 | existing_cts = pd.read_csv(args.ground_truth) 33 | source2instances = dict(zip(existing_cts["sources"], existing_cts["num_instances"])) 34 | 35 | # get the scaled amount of instances we need, based on the token adjustment ratio 36 | # NOTE: appx, since the number of tokens per instance is appx 37 | source2instances_scaled = {k: v * token_adjustment_ratio for k, v in source2instances.items()} 38 | 39 | # from the config file, get what relative weights we want 40 | sample_coefficients = {x["name"]: x["source_coefficient"] for x in config["sources"]} 41 | 42 | # multiply the relative weights by the number of instances 43 | final_proportions = {k: v * source2instances_scaled[k] for k, v in sample_coefficients.items()} 44 | print(f"Targeting {target_tokens} tokens with the following sampling fractions by source:") 45 | print("\n".join(f"\t- {k} -> {round(v, 4)}" for k, v in final_proportions.items())) 46 | 47 | # make a new config file with the instance numbers instead 48 | # write this out to the instance config folder that will be used for sampling (one dir up) 49 | out_fn = args.config.replace("relative", "instances") 50 | # replace sources with the `final_proportions` dict 51 | config["sources"] = [{"name": k, "num_instances": round(v)} for k, v in final_proportions.items()] 52 | with open(out_fn, 'w') as file: 53 | yaml.dump(config, file) 54 | print(f"New config file written to {out_fn}") 55 | 56 | 57 | if __name__ == "__main__": 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument("--config", type=str, required=True) 60 | parser.add_argument("--ground_truth", type=str, default="statistics/ground_truth.csv") 61 | args = parser.parse_args() 62 | 63 | relative_to_instance(args) 64 | 65 | # example usage: 66 | # python relative_prop_to_instance_prop.py --config configs/relative/stratified_20bn.yaml -------------------------------------------------------------------------------- /src/data/sample_dataset_from_config.py: -------------------------------------------------------------------------------- 1 | from streaming import MDSWriter 2 | import huggingface_hub 3 | from datasets import load_dataset, interleave_datasets 4 | import argparse 5 | import os 6 | import tqdm 7 | import yaml 8 | import random 9 | from transformers import AutoTokenizer, set_seed 10 | import datasets 11 | from streaming import StreamingDataset 12 | 13 | from hf_to_mds import push_to_hub_incrementally 14 | from data_utils import SOURCE_MAP, MDS_COLS_TOKENIZED, MDS_COLS_TEXT 15 | 16 | set_seed(123456789) 17 | 18 | TEST = False 19 | 20 | 21 | def tokenize_and_write(writer, pool, tokenizer): 22 | global TEST 23 | total_tokens = 0 24 | pool_texts = [instance["text"] for instance in pool] 25 | if tokenizer is not None: 26 | texts_tokenized = tokenizer(pool_texts, truncation=False, return_tensors="np") 27 | 28 | for i, instance in enumerate(pool): 29 | instance_dict = { 30 | "text": instance["text"], 31 | "id": instance["id"], 32 | } 33 | if tokenizer is not None: 34 | instance_dict["input_ids"] = texts_tokenized["input_ids"][i].squeeze() 35 | instance_dict["attention_mask"] = texts_tokenized["attention_mask"][i].squeeze() 36 | del instance_dict["text"] 37 | total_tokens += 
len(instance_dict["input_ids"]) 38 | 39 | if not TEST: 40 | print(instance_dict.keys()) 41 | TEST = True 42 | 43 | writer.write(instance_dict) 44 | 45 | return total_tokens 46 | 47 | 48 | def sample_dataset_from_config(args): 49 | assert os.path.isfile(args.config), f"Config file {in_fn} does not exist." 50 | with open(args.config, 'r') as file: 51 | config = yaml.safe_load(file) 52 | 53 | target_tokens = config["target_tokens"] 54 | sample_nums = {x["name"]: x["num_instances"] for x in config["sources"]} 55 | config_file_name = os.path.basename(args.config).split(".")[0] 56 | 57 | train_path = os.path.join(config_file_name, "train") 58 | validation_path = os.path.join(config_file_name, "validation") 59 | os.makedirs(train_path, exist_ok=True) 60 | os.makedirs(validation_path, exist_ok=True) 61 | 62 | if args.tokenizer is not None: 63 | print(f"Using tokenizer model {args.tokenizer}...") 64 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) 65 | 66 | COLS = MDS_COLS_TOKENIZED if args.tokenizer is not None else MDS_COLS_TEXT 67 | with MDSWriter(out=os.path.join(config_file_name, "train"), columns=COLS, compression='zstd') as train_writer: 68 | with MDSWriter(out=os.path.join(config_file_name, "validation"), columns=COLS, compression='zstd') as validation_writer: 69 | for source in tqdm.tqdm(config["sources"], desc="Sources"): 70 | # pools are used to tokenize more than once, using args.tokenization_batch_size 71 | train_pool = [] 72 | validation_pool = [] 73 | 74 | source_name = source["name"] 75 | num_train = sample_nums[source_name] 76 | num_validation = max(1, round(num_train * args.validation_fraction)) 77 | 78 | source_hf_repo = SOURCE_MAP[source_name] 79 | remote = f'hf://datasets/{source_hf_repo}/' 80 | dataset = StreamingDataset(remote=remote, shuffle=True, split=None, batch_size=1, cache_limit="50GB") 81 | 82 | for idx, instance in tqdm.tqdm(dataset): 83 | if args.debug and idx > 100: 84 | break 85 | 86 | if idx < num_train: 87 | train_pool.append(instance_dict) 88 | else: 89 | validation_pool.append(instance_dict) 90 | 91 | if len(train_pool) > args.tokenization_batch_size: 92 | num_tokens += tokenize_and_write(train_writer, train_pool, tokenizer) 93 | train_pool = [] 94 | 95 | if len(validation_pool) > args.tokenization_batch_size: 96 | tokenize_and_write(validation_writer, validation_pool, tokenizer) 97 | validation_pool = [] 98 | 99 | # any that didn't fit in the batch size 100 | if len(train_pool) > 0: 101 | num_tokens += tokenize_and_write(train_writer, train_pool, tokenizer) 102 | 103 | if len(validation_pool) > 0: 104 | tokenize_and_write(validation_writer, validation_pool, tokenizer) 105 | 106 | print(f"Finished writing with a total of {num_tokens} tokens.") 107 | # add the config file to the output directory with the total number of tokens 108 | with open(os.path.join(config_file_name, "config.yaml"), 'w') as file: 109 | config["total_tokens"] = num_tokens 110 | yaml.dump(config, file) 111 | 112 | # now push it to HF 113 | print(f"Pushing to HF...") 114 | upload_repo_path = f"orionweller/{config_file_name}" 115 | api = huggingface_hub.HfApi() 116 | repo_url = api.create_repo( 117 | upload_repo_path, 118 | repo_type="dataset", 119 | exist_ok=False, 120 | ) 121 | push_to_hub_incrementally( 122 | upload_repo_path, 123 | config_file_name 124 | ) 125 | 126 | 127 | 128 | if __name__ == "__main__": 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument("-c", "--config", type=str, required=True) 131 | parser.add_argument("-t", "--tokenizer", type=str, 
required=False, default=None) 132 | parser.add_argument("-v", "--validation_fraction", type=float, default=0.01) 133 | parser.add_argument("-b", "--tokenization_batch_size", type=int, default=100000) 134 | parser.add_argument("-d", "--debug", action="store_true") 135 | args = parser.parse_args() 136 | 137 | sample_dataset_from_config(args) 138 | 139 | # example usage: 140 | # python sample_dataset_from_config.py -c configs/instances/stratified_20bn.yaml -d -------------------------------------------------------------------------------- /src/data/source_stats.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import huggingface_hub 3 | from collections import Counter 4 | from datasets import load_dataset, Dataset 5 | import tqdm 6 | import tempfile 7 | from pathlib import Path 8 | import multiprocessing 9 | import re 10 | import numpy as np 11 | import pandas as pd 12 | import math 13 | import json 14 | import os 15 | from transformers import AutoTokenizer 16 | from streaming import StreamingDataset 17 | from streaming.base.util import clean_stale_shared_memory 18 | 19 | from data_utils import ALL_REPOS, MDS_COLS_TEXT 20 | 21 | 22 | NUM_PROC = int(math.ceil(0.35 * multiprocessing.cpu_count())) 23 | 24 | model_name = "gpt2" 25 | 26 | def main(out_fn, dataset_max_size): 27 | tokenizer = AutoTokenizer.from_pretrained(model_name, fast=True) 28 | 29 | # load all lines in out_fn 30 | percentiles_out_path = str(out_fn).replace(".csv", ".jsonl") 31 | cached_sources = set() 32 | if os.path.exists(percentiles_out_path): 33 | with open(str(out_fn).replace(".csv", ".jsonl"), "r") as f: 34 | for line in f: 35 | cached_sources.add(json.loads(line)["source"]) 36 | 37 | 38 | percentiles = [1, 99] + list(range(0, 101, 5)) 39 | print(f"Saving source-level token count percentiles to {str(out_fn).replace('.csv', '.jsonl')}") 40 | stats = [] 41 | tokens_for_source = [] 42 | current_repos_to_do = [item for item in ALL_REPOS if item.split("/")[-1] not in cached_sources] 43 | prev_src = current_repos_to_do[0].split("/")[-1] 44 | percentiles_out = open(percentiles_out_path, "a") 45 | 46 | for data_dir in tqdm.tqdm(current_repos_to_do): 47 | source = data_dir.split("/")[-1] 48 | print(f"Processing {source}... 
with data_dir {data_dir}") 49 | if source != prev_src: 50 | # add percentiles and reset 51 | tokens_np = np.array(tokens_for_source) 52 | percentile_stats_all = np.percentile(tokens_np, percentiles) 53 | percentile_stats = { 54 | "mean": np.mean(tokens_np), 55 | "std": np.std(tokens_np), 56 | "percentiles": {p: v for p, v in zip(percentiles, percentile_stats_all)} 57 | } 58 | tokens_for_source = [] 59 | percentiles_out.write(json.dumps({prev_src: percentile_stats, "source": prev_src}) + "\n") 60 | percentiles_out.flush() 61 | 62 | prev_src = source 63 | 64 | with tempfile.TemporaryDirectory() as tmp_cache_dir: 65 | remote = f'hf://datasets/orionweller/{source}/' 66 | token_lens = [] 67 | pool = [] 68 | clean_stale_shared_memory() 69 | for idx, instance in tqdm.tqdm(enumerate(StreamingDataset(remote=remote, shuffle=False, split=None, batch_size=1, predownload=dataset_max_size))): 70 | pool.append(instance) 71 | if idx > dataset_max_size: 72 | break 73 | if len(pool) > 1000: 74 | hf_dataset = Dataset.from_list(pool) 75 | try: 76 | tokens = hf_dataset.map( 77 | lambda row: {"num_tokens": tokenizer(row["text"]), "batched": True}, 78 | num_proc=NUM_PROC, remove_columns=MDS_COLS_TEXT.keys() 79 | )["num_tokens"] 80 | except Exception as e: 81 | print(f"Error processing {source} at idx {idx}") 82 | print(e) 83 | tokens = hf_dataset.map( 84 | lambda row: {"num_tokens": tokenizer(row["text"]), "batched": True}, 85 | num_proc=NUM_PROC, remove_columns=MDS_COLS_TEXT.keys() 86 | )["num_tokens"] 87 | token_lens.extend([len(item["input_ids"]) for item in tokens]) 88 | hf_dataset.cleanup_cache_files() 89 | pool = [] 90 | 91 | hf_dataset = Dataset.from_list(pool) 92 | tokens = hf_dataset.map( 93 | lambda row: {"num_tokens": tokenizer(row["text"]), "batched": True}, 94 | num_proc=NUM_PROC, remove_columns=MDS_COLS_TEXT.keys() 95 | )["num_tokens"] 96 | token_lens.extend([len(item["input_ids"]) for item in tokens]) 97 | 98 | tokens_for_source.extend(token_lens) 99 | 100 | # This is overkill, but just in case 101 | hf_dataset.cleanup_cache_files() 102 | 103 | stats.append({ 104 | "source": source, 105 | "num_tokens": sum(token_lens) 106 | }) 107 | 108 | # do the percentile calculation for the last source also 109 | tokens_np = np.array(tokens_for_source) 110 | percentile_stats_all = np.percentile(tokens_np, percentiles) 111 | percentile_stats = { 112 | "mean": np.mean(tokens_np), 113 | "std": np.std(tokens_np), 114 | "percentiles": {p: v for p, v in zip(percentiles, percentile_stats_all)} 115 | } 116 | percentiles_out.write(json.dumps({prev_src: percentile_stats}) + "\n") 117 | 118 | # now get the total stats 119 | stats = pd.DataFrame(stats) 120 | 121 | # Group by source and sum num_tokens 122 | stats = stats.groupby("source").sum().reset_index() 123 | 124 | # Add a column which shows the fractional contribution of each source 125 | stats["fraction"] = stats["num_tokens"] / stats["num_tokens"].sum() 126 | 127 | # Sort by fraction decreasing 128 | stats = stats.sort_values("fraction", ascending=False) 129 | 130 | print(f"Saving source-level token count statistics to {out_fn}") 131 | stats.to_csv(out_fn, index=False) 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser(description="Measure token count per data source.") 136 | parser.add_argument("--out_fn", type=Path, default=Path(__file__).resolve().parent / "statistics" / "source_stats.csv", help="Output file for source stats.") 137 | parser.add_argument("--dataset_max_size", type=int, default=100000, help="Maximum number of 
instances to load at once.") 138 | args = parser.parse_args() 139 | 140 | main(args.out_fn, args.dataset_max_size) 141 | 142 | # example usage: 143 | # python source_stats.py --dataset_max_size 100000 144 | -------------------------------------------------------------------------------- /src/evals/README.md: -------------------------------------------------------------------------------- 1 | # Ablation Evals 2 | 3 | ## Generate config 4 | 5 | Run `python generate_eval_config_from_checkpoint.py --help` for all options. 6 | 7 | ### Create config by specifying checkpoint & config path 8 | ``` 9 | python generate_eval_config_from_checkpoint.py \ 10 | --checkpoint /path/to/checkpoint/folder \ 11 | --train_config /path/to/config.yaml 12 | ``` 13 | 14 | ### Create config from the matching wandb run & add wandb tracking 15 | ``` 16 | python generate_eval_config_from_checkpoint.py \ 17 | --checkpoint /path/to/checkpoint/folder \ 18 | --wandb_entity entity_name \ 19 | --wandb_project project_name \ 20 | --track_run 21 | ``` 22 | 23 | ### Create a config and skip the MNLI eval 24 | 25 | You can skip any number of evals by adding `--skip_` for each eval you want to skip. 26 | 27 | ``` 28 | python generate_eval_config_from_checkpoint.py \ 29 | --checkpoint /path/to/checkpoint/folder \ 30 | --wandb_entity entity_name \ 31 | --wandb_project project_name \ 32 | --track_run \ 33 | --skip_mnli 34 | ``` 35 | 36 | ## Launch a single ablation job 37 | ```bash 38 | python eval.py yamls/ablations/checkpoint_name.yaml 39 | ``` 40 | 41 | ## Automatically generate eval configs for multiple checkpoints and run evals on multiple GPUs 42 | 43 | `run_evals_from_checkpoints.py` can be used to automatically generate configs from the latest checkpoints in a given directory and run all evals on all available GPUs. 44 | 45 | Run `python run_evals_from_checkpoints.py --help` for all options. All options from `generate_eval_config_from_checkpoint.py` are also available. 46 | 47 | The logic for this script is (a sketch of the checkpoint-selection step is shown after this list): 48 | - Each subdir in `--checkpoints` is scanned for model checkpoints. 49 | - If a checkpoint/symlink named "latest-rank0.pt" does not exist, a symlink to the latest checkpoint will be created. 50 | - If the checkpoint/symlink exists, the script will use that checkpoint. 51 | - If you pass `--overwrite_existing_symlinks`, the script will create a new symlink to the latest checkpoint and use it. 52 | - Config files should be stored together with the checkpoint (_evaluation.yaml). 53 | - If not, the script will try to find a matching wandb run in the `wandb_entity`/`wandb_project` project and auto-generate a config. 54 | - If the above fails, then the job will be skipped. 
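For reference, below is a minimal sketch of the checkpoint-selection step described in the list above. The `latest-rank0.pt` naming and the `--overwrite_existing_symlinks` behaviour come from the list; the helper name, the `*.pt` glob, and the mtime-based ordering are illustrative assumptions rather than the script's actual implementation.

```python
# Hypothetical sketch of the per-run checkpoint selection described above.
# Only the `latest-rank0.pt` convention and the overwrite flag are taken from
# the README; the helper name, glob pattern, and mtime ordering are assumed.
from pathlib import Path
from typing import Optional


def select_checkpoint(run_dir: Path, overwrite_existing_symlinks: bool = False) -> Optional[Path]:
    symlink = run_dir / "latest-rank0.pt"
    if symlink.exists() and not overwrite_existing_symlinks:
        return symlink  # an existing checkpoint/symlink is used as-is
    checkpoints = sorted(run_dir.glob("*.pt"), key=lambda p: p.stat().st_mtime)
    checkpoints = [c for c in checkpoints if c.name != "latest-rank0.pt"]
    if not checkpoints:
        return None  # no checkpoint found; this run would be skipped
    symlink.unlink(missing_ok=True)
    symlink.symlink_to(checkpoints[-1])  # (re)point the symlink at the newest checkpoint
    return symlink
```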
55 | 56 | ``` 57 | python run_evals_from_checkpoints.py \ 58 | --checkpoints /home/shared/data-ablations/checkpoints \ 59 | --wandb_entity entity_name \ 60 | --wandb_project project_name \ 61 | --track_run 62 | ``` 63 | 64 | -------------------------------------------------------------------------------- /src/evals/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MosaicML Examples authors 2 | # SPDX-License-Identifier: Apache-2.0 3 | -------------------------------------------------------------------------------- /src/evals/data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MosaicML Examples authors 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | 5 | """from https://arxiv.org/pdf/1905.00537 6 | For classification tasks with sentence-pair inputs (BoolQ, CB, RTE, WiC), we concatenate the 7 | sentences with a [SEP] token, feed the fused input to BERT, and use a logistic regression classifier 8 | that sees the representation corresponding to [CLS]. For WiC, we also concatenate the representation 9 | of the marked word. For COPA, MultiRC, and ReCoRD, for each answer choice, we similarly 10 | concatenate the context with that answer choice and feed the resulting sequence into BERT to produce 11 | an answer representation. For COPA, we project these representations into a scalar, and take as the 12 | answer the choice with the highest associated scalar. For MultiRC, because each question can have 13 | more than one correct answer, we feed each answer representation into a logistic regression classifier. 14 | For ReCoRD, we also evaluate the probability of each candidate independent of other candidates, 15 | and take the most likely candidate as the model’s prediction. For WSC, which is a span-based task, 16 | we use a model inspired by Tenney et al. (2019). Given the BERT representation for each word in the 17 | original sentence, we get span representations of the pronoun and noun phrase via a self-attention 18 | span-pooling operator (Lee et al., 2017), before feeding it into a logistic regression classifier. 
19 | """ 20 | 21 | import logging 22 | 23 | from composer.utils import MissingConditionalImportError, dist 24 | 25 | _glue_task_column_names = { 26 | "cola": ("sentence", None), 27 | "mnli": ("premise", "hypothesis"), 28 | "mrpc": ("sentence1", "sentence2"), 29 | "qnli": ("question", "sentence"), 30 | "qqp": ("question1", "question2"), 31 | "rte": ("sentence1", "sentence2"), 32 | "sst2": ("sentence", None), 33 | "stsb": ("sentence1", "sentence2"), 34 | } 35 | 36 | _superglue_task_column_names = { 37 | "boolq": ("question", "passage"), 38 | "cb": ("premise", "hypothesis"), 39 | "copa": ("premise", "choice1", "choice2", "question"), 40 | "multirc": ("paragraph", "question", "answer"), 41 | # "record": ("question1", "question2"), ['passage', 'query', 'entities', 'entity_spans', 'answers', 'idx' 42 | "rte": ("premise", "hypothesis"), 43 | "wic": ( 44 | "sentence1", 45 | "sentence2", 46 | ), #'word','sentence1' 'sentence2', 'start1', 'start2', 'end1', 'end2', 47 | # "wsc": ("sentence1", "sentence2"), #'text','span1_index', 'span2_index', 'span1_text', 'span2_text', 48 | # "wsc.fixed": ("sentence1", "sentence2"), #'text','span1_index', 'span2_index', 'span1_text', 'span2_text', 49 | # "axb": ("sentence1", "sentence2"), 50 | # "axg": ("premise", "hypothesis"), 51 | } 52 | 53 | log = logging.getLogger(__name__) 54 | 55 | 56 | def create_eval_dataset( 57 | task: str, 58 | tokenizer_name: str, 59 | split: str, 60 | dataset_name: str, 61 | max_seq_length: int = 256, 62 | max_retries: int = 10, 63 | num_workers: int = 0, 64 | dataset_subset: str = None, 65 | task_column_names: dict = _glue_task_column_names, 66 | tokenize_fn_factory: callable = None, 67 | ): 68 | try: 69 | import datasets 70 | import transformers 71 | except ImportError as e: 72 | raise MissingConditionalImportError( 73 | extra_deps_group="nlp", conda_package="transformers" 74 | ) from e 75 | 76 | if task not in task_column_names: 77 | raise ValueError(f"task ({task}) must be one of {task_column_names.keys()}") 78 | 79 | if (max_seq_length % 8) != 0: 80 | log.warning( 81 | "For performance, a max_seq_length as a multiple of 8 is recommended." 82 | ) 83 | 84 | tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name) # type: ignore (thirdparty) 85 | 86 | log.info(f"Loading {task.upper()} on rank {dist.get_global_rank()}") 87 | download_config = datasets.DownloadConfig(max_retries=max_retries) 88 | dataset = datasets.load_dataset( 89 | dataset_name, 90 | dataset_subset if dataset_subset is not None else task, 91 | split=split, 92 | download_config=download_config, 93 | ) 94 | 95 | log.info(f"Starting tokenization by preprocessing over {num_workers} threads!") 96 | text_column_names = task_column_names[task] 97 | 98 | if tokenize_fn_factory is None: 99 | # Calling the BERT tokenizer in this way will insert [SEP] between the 100 | # inputs, e.g. "[CLS] text [SEP] text_pair [SEP]". Without NSP, BERT is 101 | # not exposed to sequences with two [SEP] tokens during pretraining, 102 | # but finetuning on MNLI before finetuning on smaller datasets can help 103 | # the model get used to this. 
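        # Illustrative example, assuming a standard BERT WordPiece tokenizer
        # such as bert-base-uncased (not taken from this file):
        #   tokenizer(text="the premise", text_pair="the hypothesis")
        # returns input_ids that decode to
        #   "[CLS] the premise [SEP] the hypothesis [SEP]"
        # whereas single-sentence tasks (text_pair=None) yield "[CLS] text [SEP]".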
104 | tokenize_fn_factory = lambda tokenizer, max_seq_length: lambda inp: tokenizer( 105 | text=inp[text_column_names[0]], 106 | text_pair=( 107 | inp[text_column_names[1]] if text_column_names[1] in inp else None 108 | ), 109 | padding="max_length", 110 | max_length=max_seq_length, 111 | truncation=True, 112 | ) 113 | 114 | columns_to_remove = [i for i in text_column_names if i is not None] 115 | 116 | assert isinstance(dataset, datasets.Dataset) 117 | dataset = dataset.map( 118 | tokenize_fn_factory(tokenizer, max_seq_length), 119 | batched=True, 120 | num_proc=None if num_workers == 0 else num_workers, 121 | batch_size=1000, 122 | remove_columns=columns_to_remove, 123 | load_from_cache_file=True, 124 | ) 125 | return dataset 126 | 127 | 128 | def create_glue_dataset(**kwargs): 129 | return create_eval_dataset( 130 | **kwargs, dataset_name="glue", task_column_names=_glue_task_column_names 131 | ) 132 | 133 | 134 | def create_superglue_dataset(**kwargs): 135 | return create_eval_dataset( 136 | **kwargs, 137 | dataset_name="aps/super_glue", 138 | task_column_names=_superglue_task_column_names, 139 | ) 140 | 141 | 142 | def create_swag_dataset(**kwargs): 143 | return create_eval_dataset( 144 | **kwargs, 145 | dataset_name="swag", 146 | dataset_subset="regular", 147 | task_column_names={ 148 | "swag": ("sent1", "sent2", "ending0", "ending1", "ending2", "ending3") 149 | }, 150 | ) 151 | 152 | def create_eurlex_dataset(**kwargs): 153 | return create_eval_dataset( 154 | **kwargs, 155 | dataset_name="coastalcph/lex_glue", 156 | dataset_subset="eurlex", 157 | task_column_names={"coastalcph/lex_glue": ("text",)}, 158 | ) 159 | 160 | def create_ultrafeedback_dataset(**kwargs): 161 | return create_eval_dataset( 162 | **kwargs, 163 | dataset_name="rbiswasfc/ultrafeedback-binary-classification", 164 | dataset_subset="", 165 | task_column_names={"rbiswasfc/ultrafeedback-binary-classification": ("prompt", "response_a", "response_b")}, 166 | ) 167 | 168 | def create_mlmmlu_dataset(**kwargs): 169 | dataset_subset = kwargs.pop("dataset_subset") 170 | 171 | if dataset_subset in ['Amateur', 'Semipro']: 172 | task_column_names= ("question", "options", "answer", "category", "cot_content", "src", "question_id", "llama_pred", "llama_correct") 173 | elif dataset_subset in ['Reserve', 'Rookie']: 174 | task_column_names= ("question", "choices", "category", "question_id", "llama_correct", "id_in_subset") 175 | else: 176 | raise NotImplementedError 177 | 178 | return create_eval_dataset( 179 | dataset_name="answerdotai/MLMMLU", 180 | dataset_subset=dataset_subset, 181 | task_column_names={"answerdotai/MLMMLU": task_column_names}, 182 | **kwargs, 183 | ) 184 | 185 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Optuna, Hugging Face 2 | # License: Apache-2.0 3 | 4 | # Copyright 2023 OLMo Authors 5 | # License: Apache-2.0 6 | 7 | import functools 8 | import logging 9 | from enum import Enum 10 | 11 | 12 | @functools.lru_cache(None) 13 | def warning_once(self, *args, **kwargs): 14 | """ 15 | This method is identical to `logger.warning()`, but will emit the warning with the same message only once 16 | 17 | Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache. 18 | The assumption here is that all warning messages are unique across the code. 
If they aren't then need to switch to 19 | another type of cache that includes the caller frame information in the hashing function. 20 | """ 21 | self.warning(*args, **kwargs) 22 | 23 | 24 | logging.Logger.warning_once = warning_once 25 | logging.Logger.warn_once = warning_once 26 | 27 | 28 | class StrEnum(str, Enum): 29 | """ 30 | This is equivalent to Python's :class:`enum.StrEnum` since version 3.11. 31 | We include this here for compatibility with older version of Python. 32 | """ 33 | 34 | def __str__(self) -> str: 35 | return self.value 36 | 37 | def __repr__(self) -> str: 38 | return f"'{str(self)}'" 39 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/ModernBERT/8c57a0f01c12c4953ead53d398a36f81a4ba9e38/tests/__init__.py -------------------------------------------------------------------------------- /tests/smoketest_config_ablation_eval.yaml: -------------------------------------------------------------------------------- 1 | # Whether to run the various GLUE jobs serially or in parallel (must be `false` on CPU) 2 | parallel: false 3 | 4 | # Basic run configuration, additional details will be added to this name for each GLUE task, and each random seed 5 | base_run_name: ablation-eval-test 6 | default_seed: 1111 7 | precision: amp_bf16 8 | 9 | # Tokenizer for dataset creation 10 | tokenizer_name: bert-base-uncased 11 | 12 | # Base model config 13 | model: 14 | pretrained_model_name: prajjwal1/bert-tiny 15 | tokenizer_name: ${tokenizer_name} 16 | model_config: 17 | deterministic_fa2: true 18 | 19 | # Loading 20 | starting_checkpoint_load_path: # Start from scratch for the sake of testing 21 | local_pretrain_checkpoint_folder: ./local-bert-checkpoints/ 22 | 23 | # Saving 24 | save_finetune_checkpoint_prefix: ./local-finetune-checkpoints/ # (local) 25 | save_finetune_checkpoint_folder: ${save_finetune_checkpoint_prefix}/${base_run_name} 26 | 27 | # Callbacks 28 | callbacks: 29 | lr_monitor: {} 30 | speed_monitor: {} 31 | 32 | # Scheduler 33 | scheduler: 34 | name: linear_decay_with_warmup 35 | t_warmup: 0.5dur 36 | alpha_f: 0.0 37 | 38 | # Optimizer 39 | optimizer: 40 | name: decoupled_adamw 41 | lr: 3.0e-05 42 | betas: 43 | - 0.9 44 | - 0.98 45 | eps: 1.0e-06 46 | weight_decay: 3.0e-06 47 | 48 | # Task configuration 49 | tasks: 50 | mnli: 51 | trainer_kwargs: 52 | save_num_checkpoints_to_keep: 1 53 | max_duration: 2ba 54 | eval_subset_num_batches: 2 55 | boolq: 56 | trainer_kwargs: 57 | save_num_checkpoints_to_keep: 0 58 | max_duration: 2ba 59 | eval_subset_num_batches: 2 60 | wic: 61 | trainer_kwargs: 62 | save_num_checkpoints_to_keep: 0 63 | max_duration: 2ba 64 | eval_subset_num_batches: 2 65 | eurlex: 66 | trainer_kwargs: 67 | save_num_checkpoints_to_keep: 0 68 | max_duration: 2ba 69 | eval_subset_num_batches: 2 70 | model_config: 71 | problem_type: multi_label_classification -------------------------------------------------------------------------------- /tests/smoketest_config_classification.yaml: -------------------------------------------------------------------------------- 1 | tokenizer_name: prajjwal1/bert-tiny 2 | max_seq_len: 32 3 | 4 | # Run Name 5 | run_name: test 6 | 7 | # Model 8 | model: 9 | num_labels: 2 10 | pretrained_model_name: ${tokenizer_name} 11 | tokenizer_name: ${tokenizer_name} 12 | 13 | # Dataloaders 14 | train_loader: 15 | split: train 16 | tokenizer_name: ${tokenizer_name} 17 
| max_seq_len: ${max_seq_len} 18 | shuffle: true 19 | drop_last: true 20 | num_workers: 4 21 | 22 | eval_loader: 23 | split: validation 24 | tokenizer_name: ${tokenizer_name} 25 | max_seq_len: ${max_seq_len} 26 | shuffle: false 27 | drop_last: false 28 | num_workers: 4 29 | 30 | # Optimization 31 | scheduler: 32 | name: linear_decay_with_warmup 33 | t_warmup: 0.5dur 34 | alpha_f: 0.02 35 | 36 | optimizer: 37 | name: decoupled_adamw 38 | lr: 2.0e-4 39 | betas: 40 | - 0.9 41 | - 0.95 42 | eps: 1.0e-08 43 | weight_decay: 0.0 44 | filter_bias_norm_wd: false 45 | 46 | # Training duration and evaluation frequency 47 | max_duration: 8ba 48 | eval_interval: 8ba 49 | eval_subset_num_batches: 2 50 | global_train_batch_size: 4 51 | 52 | # System 53 | seed: 17 54 | device_eval_microbatch_size: 4 55 | device_train_microbatch_size: 2 56 | precision: fp32 57 | 58 | # Logging 59 | progress_bar: false 60 | log_to_console: false 61 | console_log_interval: 1ba 62 | 63 | callbacks: 64 | speed_monitor: 65 | window_size: 4 66 | lr_monitor: {} 67 | -------------------------------------------------------------------------------- /tests/smoketest_config_glue.yaml: -------------------------------------------------------------------------------- 1 | # Whether to run the various GLUE jobs serially or in parallel (must be `false` on CPU) 2 | parallel: false 3 | 4 | # Basic run configuration, additional details will be added to this name for each GLUE task, and each random seed 5 | base_run_name: glue-finetuning-benchmark-test 6 | default_seed: 1111 7 | precision: fp32 8 | 9 | # Tokenizer for dataset creation 10 | tokenizer_name: bert-base-uncased 11 | 12 | # Base model config 13 | model: 14 | pretrained_model_name: prajjwal1/bert-tiny 15 | tokenizer_name: ${tokenizer_name} 16 | model_config: 17 | deterministic_fa2: true 18 | 19 | # Loading 20 | starting_checkpoint_load_path: # Start from scratch for the sake of testing 21 | local_pretrain_checkpoint_folder: ./local-bert-checkpoints/ 22 | 23 | # Saving 24 | save_finetune_checkpoint_prefix: ./local-finetune-checkpoints/ # (local) 25 | save_finetune_checkpoint_folder: ${save_finetune_checkpoint_prefix}/${base_run_name} 26 | 27 | # Callbacks 28 | callbacks: 29 | lr_monitor: {} 30 | speed_monitor: {} 31 | 32 | # Scheduler 33 | scheduler: 34 | name: linear_decay_with_warmup 35 | t_warmup: 0.5dur 36 | alpha_f: 0.0 37 | 38 | 39 | # Task configuration 40 | tasks: # Only run MNLI and RTE for the sake of testing 41 | mnli: 42 | # Specify any extra task-specific arguments for the trainer here 43 | trainer_kwargs: 44 | # We keep one MNLI checkpoint locally so that we can start finetuning of 45 | # RTE, MRPC and STS-B from the MNLI checkpoint 46 | save_num_checkpoints_to_keep: 1 47 | max_duration: 2ba 48 | eval_subset_num_batches: 2 49 | rte: 50 | trainer_kwargs: 51 | save_num_checkpoints_to_keep: 0 52 | max_duration: 2ba 53 | eval_subset_num_batches: 2 54 | -------------------------------------------------------------------------------- /tests/smoketest_config_main.yaml: -------------------------------------------------------------------------------- 1 | tokenizer_name: prajjwal1/bert-tiny 2 | max_seq_len: 32 3 | mlm_probability: 0.15 4 | 5 | # Run Name 6 | run_name: test 7 | 8 | # Model 9 | model: 10 | use_pretrained: false 11 | pretrained_model_name: ${tokenizer_name} 12 | tokenizer_name: ${tokenizer_name} 13 | 14 | # Dataloaders 15 | train_loader: 16 | name: text 17 | dataset: 18 | remote: 19 | local: 20 | split: train 21 | tokenizer_name: ${tokenizer_name} 22 | max_seq_len: 
${max_seq_len} 23 | predownload: 1000 24 | shuffle: true 25 | mlm_probability: ${mlm_probability} 26 | num_canonical_nodes: 8 27 | drop_last: true 28 | num_workers: 4 29 | 30 | eval_loader: 31 | name: text 32 | dataset: 33 | remote: 34 | local: 35 | split: val 36 | tokenizer_name: ${tokenizer_name} 37 | max_seq_len: ${max_seq_len} 38 | predownload: 1000 39 | shuffle: false 40 | mlm_probability: ${mlm_probability} 41 | num_canonical_nodes: 8 42 | drop_last: false 43 | num_workers: 4 44 | 45 | # Optimization 46 | scheduler: 47 | name: linear_decay_with_warmup 48 | t_warmup: 0.5dur 49 | alpha_f: 0.02 50 | 51 | optimizer: 52 | name: decoupled_adamw 53 | lr: 2.0e-4 54 | betas: 55 | - 0.9 56 | - 0.95 57 | eps: 1.0e-08 58 | weight_decay: 0.0 59 | filter_bias_norm_wd: false 60 | 61 | # Training duration and evaluation frequency 62 | max_duration: 8ba 63 | eval_interval: 8ba 64 | global_train_batch_size: 4 65 | 66 | # System 67 | seed: 17 68 | device_eval_microbatch_size: 4 69 | device_train_microbatch_size: 2 70 | precision: fp32 71 | 72 | # Logging 73 | progress_bar: false 74 | log_to_console: false 75 | console_log_interval: 1ba 76 | 77 | callbacks: 78 | speed_monitor: 79 | window_size: 4 80 | lr_monitor: {} 81 | 82 | algorithms: 83 | gradient_clipping: 84 | clipping_type: norm 85 | clipping_threshold: 1.0 -------------------------------------------------------------------------------- /tests/smoketest_config_sdpa_fa2.yaml: -------------------------------------------------------------------------------- 1 | tokenizer_name: prajjwal1/bert-tiny 2 | max_seq_len: 32 3 | mlm_probability: 0.15 4 | 5 | # Run Name 6 | run_name: test 7 | 8 | # Model 9 | model: 10 | use_pretrained: false 11 | pretrained_model_name: ${tokenizer_name} 12 | tokenizer_name: ${tokenizer_name} 13 | 14 | # Dataloaders 15 | train_loader: 16 | name: text 17 | dataset: 18 | remote: 19 | local: 20 | split: train 21 | tokenizer_name: ${tokenizer_name} 22 | max_seq_len: ${max_seq_len} 23 | predownload: 1000 24 | shuffle: true 25 | mlm_probability: ${mlm_probability} 26 | num_canonical_nodes: 8 27 | drop_last: true 28 | num_workers: 4 29 | 30 | eval_loader: 31 | name: text 32 | dataset: 33 | remote: 34 | local: 35 | split: val 36 | tokenizer_name: ${tokenizer_name} 37 | max_seq_len: ${max_seq_len} 38 | predownload: 1000 39 | shuffle: false 40 | mlm_probability: ${mlm_probability} 41 | num_canonical_nodes: 8 42 | drop_last: false 43 | num_workers: 4 44 | 45 | # Optimization 46 | scheduler: 47 | name: linear_decay_with_warmup 48 | t_warmup: 0.5dur 49 | alpha_f: 0.02 50 | 51 | optimizer: 52 | name: decoupled_adamw 53 | lr: 2.0e-4 54 | betas: 55 | - 0.9 56 | - 0.95 57 | eps: 1.0e-08 58 | weight_decay: 0.01 59 | filter_bias_norm_wd: true 60 | 61 | # Training duration and evaluation frequency 62 | max_duration: 8ba 63 | eval_interval: 8ba 64 | global_train_batch_size: 4 65 | 66 | # System 67 | seed: 17 68 | device_eval_microbatch_size: 4 69 | device_train_microbatch_size: 2 70 | precision: amp_bf16 71 | 72 | # Logging 73 | progress_bar: false 74 | log_to_console: false 75 | console_log_interval: 1ba 76 | 77 | callbacks: 78 | speed_monitor: 79 | window_size: 4 80 | lr_monitor: {} -------------------------------------------------------------------------------- /tests/smoketest_config_superglue.yaml: -------------------------------------------------------------------------------- 1 | # Whether to run the various GLUE jobs serially or in parallel (must be `false` on CPU) 2 | parallel: false 3 | 4 | # Basic run configuration, additional details 
will be added to this name for each GLUE task, and each random seed 5 | base_run_name: superglue-finetuning-benchmark-test 6 | default_seed: 1111 7 | precision: fp32 8 | 9 | # Tokenizer for dataset creation 10 | tokenizer_name: bert-base-uncased 11 | 12 | # Base model config 13 | model: 14 | pretrained_model_name: prajjwal1/bert-tiny 15 | tokenizer_name: ${tokenizer_name} 16 | model_config: 17 | deterministic_fa2: true 18 | 19 | # Loading 20 | starting_checkpoint_load_path: # Start from scratch for the sake of testing 21 | local_pretrain_checkpoint_folder: ./local-bert-checkpoints/ 22 | 23 | # Saving 24 | save_finetune_checkpoint_prefix: ./local-finetune-checkpoints/ # (local) 25 | save_finetune_checkpoint_folder: ${save_finetune_checkpoint_prefix}/${base_run_name} 26 | 27 | # Callbacks 28 | callbacks: 29 | lr_monitor: {} 30 | speed_monitor: {} 31 | 32 | # Scheduler 33 | scheduler: 34 | name: linear_decay_with_warmup 35 | t_warmup: 0.5dur 36 | alpha_f: 0.0 37 | 38 | 39 | # Task configuration 40 | tasks: # Only run SWAG and COPA for the sake of testing 41 | swag: 42 | # Specify any extra task-specific arguments for the trainer here 43 | trainer_kwargs: 44 | # We keep one SWAG checkpoint locally so that we can start finetuning of 45 | # COPA from the SWAG checkpoint 46 | save_num_checkpoints_to_keep: 1 47 | max_duration: 2ba 48 | eval_subset_num_batches: 2 49 | copa: 50 | trainer_kwargs: 51 | save_num_checkpoints_to_keep: 0 52 | max_duration: 2ba 53 | eval_subset_num_batches: 2 -------------------------------------------------------------------------------- /tests/test_classification.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MosaicML Examples authors 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pytest 5 | from omegaconf import DictConfig, OmegaConf 6 | from ..sequence_classification import train 7 | 8 | 9 | @pytest.mark.parametrize("model_name", ["mosaic_bert", "hf_bert", "flex_bert"]) 10 | def test_classification_script(model_name): 11 | with open("yamls/defaults.yaml") as f: 12 | default_cfg = OmegaConf.load(f) 13 | with open(f"yamls/models/{model_name}.yaml") as f: 14 | model_cfg = OmegaConf.load(f) 15 | with open("tests/smoketest_config_classification.yaml") as f: 16 | test_config = OmegaConf.load(f) 17 | config = OmegaConf.merge(default_cfg, model_cfg, test_config) 18 | assert isinstance(config, DictConfig) 19 | 20 | # The test is that `main` runs successfully 21 | train(config) 22 | -------------------------------------------------------------------------------- /tests/test_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MosaicML Examples authors 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import sys 5 | import os 6 | import shutil 7 | import tempfile 8 | from typing import Any 9 | 10 | import pytest 11 | 12 | # Add tests folder root to path to allow us to use relative imports regardless of what directory the script is run from 13 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 14 | # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from 15 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 16 | from eval import train 17 | from omegaconf import DictConfig, OmegaConf 18 | 19 | 20 | class AblationDirContext(object): 21 | def __init__(self): 22 | self.path = None 23 | 24 | def __enter__(self): 25 | self.path = tempfile.mkdtemp() 26 | return 
self.path 27 | 28 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): 29 | del exc_type, exc_value, traceback # unused 30 | if self.path is not None: 31 | shutil.rmtree(self.path) 32 | 33 | 34 | @pytest.mark.parametrize("model_name", ["mosaic_bert", "hf_bert", "flex_bert"]) 35 | def test_eval_script(model_name: str): 36 | with open("yamls/defaults.yaml") as f: 37 | default_cfg = OmegaConf.load(f) 38 | with open(f"yamls/models/{model_name}.yaml") as f: 39 | model_cfg = OmegaConf.load(f) 40 | with open("tests/smoketest_config_glue.yaml") as f: 41 | test_config = OmegaConf.load(f) 42 | config = OmegaConf.merge(default_cfg, model_cfg, test_config) 43 | assert isinstance(config, DictConfig) 44 | config.model.name = model_name 45 | 46 | if ( 47 | model_name == "flex_bert" 48 | and not config.model.model_config.use_fa2 49 | and config.model.model_config.padding == "unpadded" 50 | ): 51 | pytest.skip("SDPA call currently errors with Glue test on unpadded inputs") 52 | 53 | # The test is that `train` runs successfully 54 | with AblationDirContext() as local_save_dir: 55 | config.save_finetune_checkpoint_prefix = local_save_dir 56 | train(config) 57 | -------------------------------------------------------------------------------- /tests/test_glue.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MosaicML Examples authors 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import sys 5 | import os 6 | import shutil 7 | import tempfile 8 | from typing import Any 9 | 10 | import pytest 11 | 12 | # Add tests folder root to path to allow us to use relative imports regardless of what directory the script is run from 13 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 14 | # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from 15 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 16 | from glue import train 17 | from omegaconf import DictConfig, OmegaConf 18 | 19 | 20 | class GlueDirContext(object): 21 | def __init__(self): 22 | self.path = None 23 | 24 | def __enter__(self): 25 | self.path = tempfile.mkdtemp() 26 | return self.path 27 | 28 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): 29 | del exc_type, exc_value, traceback # unused 30 | if self.path is not None: 31 | shutil.rmtree(self.path) 32 | 33 | 34 | @pytest.mark.parametrize("model_name", ["mosaic_bert", "hf_bert", "flex_bert"]) 35 | def test_glue_script(model_name: str): 36 | with open("yamls/defaults.yaml") as f: 37 | default_cfg = OmegaConf.load(f) 38 | with open(f"yamls/models/{model_name}.yaml") as f: 39 | model_cfg = OmegaConf.load(f) 40 | with open("tests/smoketest_config_glue.yaml") as f: 41 | test_config = OmegaConf.load(f) 42 | config = OmegaConf.merge(default_cfg, model_cfg, test_config) 43 | assert isinstance(config, DictConfig) 44 | config.model.name = model_name 45 | 46 | if ( 47 | model_name == "flex_bert" 48 | and not config.model.model_config.use_fa2 49 | and config.model.model_config.padding == "unpadded" 50 | ): 51 | pytest.skip("SDPA call currently errors with Glue test on unpadded inputs") 52 | 53 | # The test is that `train` runs successfully 54 | with GlueDirContext() as local_save_dir: 55 | config.save_finetune_checkpoint_prefix = local_save_dir 56 | train(config) 57 | -------------------------------------------------------------------------------- /tests/test_main.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2022 MosaicML Examples authors 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import sys 6 | 7 | import pytest 8 | import torch 9 | from omegaconf import DictConfig, OmegaConf 10 | 11 | # Add tests folder root to path to allow us to use relative imports regardless of what directory the script is run from 12 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 13 | # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from 14 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 15 | from main import main 16 | from test_utils import SynthTextDirectory 17 | 18 | 19 | @pytest.mark.parametrize("model_name,seed", [("mosaic_bert", 17), ("hf_bert", 18), ("flex_bert", 42)]) 20 | def test_trainer(model_name: str, seed: int): 21 | with open("yamls/defaults.yaml") as f: 22 | default_cfg = OmegaConf.load(f) 23 | with open(f"yamls/models/{model_name}.yaml") as f: 24 | model_cfg = OmegaConf.load(f) 25 | with open("tests/smoketest_config_main.yaml") as f: 26 | test_config = OmegaConf.load(f) 27 | config = OmegaConf.merge(default_cfg, model_cfg, test_config) 28 | assert isinstance(config, DictConfig) 29 | config.model.name = model_name 30 | config.seed = seed 31 | 32 | with SynthTextDirectory() as tmp_datadir: 33 | config.train_loader.dataset.remote = tmp_datadir 34 | config.train_loader.dataset.local = os.path.join(tmp_datadir, "tr-local1") 35 | config.eval_loader.dataset.remote = tmp_datadir 36 | config.eval_loader.dataset.local = os.path.join(tmp_datadir, "ev-local1") 37 | # Also save checkpoints in the temporary directory 38 | config.save_folder = tmp_datadir 39 | 40 | # Train 41 | trainer1 = main(config, return_trainer=True) 42 | assert trainer1 is not None 43 | model1 = trainer1.state.model.model 44 | 45 | # Check that the checkpoint was saved 46 | chkpt_path = os.path.join(tmp_datadir, "latest-rank0.pt") 47 | assert os.path.isfile(chkpt_path), f"{os.listdir(tmp_datadir)}" 48 | 49 | # Check that the checkpoint was loaded by comparing model weights (with no weight changes) 50 | config.load_path = chkpt_path 51 | config.seed += 10 # change seed 52 | config.train_loader.dataset.local = os.path.join(tmp_datadir, "tr-local2") 53 | config.eval_loader.dataset.local = os.path.join(tmp_datadir, "ev-local2") 54 | trainer2 = main(config, return_trainer=True, do_train=False) 55 | assert trainer2 is not None 56 | model2 = trainer2.state.model.model 57 | 58 | for param1, param2 in zip(model1.parameters(), model2.parameters()): 59 | torch.testing.assert_close(param1, param2) 60 | -------------------------------------------------------------------------------- /tests/test_mlm_masking.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 onwards Answer.AI, LightOn, and contributors 2 | # License: Apache-2.0 3 | 4 | import sys 5 | import os 6 | import pytest 7 | import numpy as np 8 | 9 | # Add tests folder root to path to allow us to use relative imports regardless of what directory the script is run from 10 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 11 | # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 13 | 14 | from src.sequence_packer import SequencePacker 15 | 16 | 17 | import pytest 18 | import 
numpy as np 19 | 20 | 21 | @pytest.mark.parametrize("mask_prob", [0.1, 0.15, 0.3, 0.5]) 22 | def test_mlm_masking(mask_prob): 23 | # Test setup 24 | seq = np.arange(100_000) # A sequence of 100,000 tokens 25 | mask_token = -1 26 | ignore_index = -100 27 | 28 | # Run the function 29 | masked_seq, labels = SequencePacker.mlm_masking( 30 | seq.copy(), mask_prob=mask_prob, mask_token=mask_token, ignore_index=ignore_index 31 | ) 32 | 33 | # Test 1 and 2: Check if the output types and shapes are correct 34 | if not (isinstance(masked_seq, np.ndarray) and isinstance(labels, np.ndarray)): 35 | raise ValueError("Output types are not correct. Expected NumPy arrays.") 36 | if not (masked_seq.shape == labels.shape == seq.shape): 37 | raise ValueError("Output shapes are not correct.") 38 | 39 | # Test 3: Check 80-10-10 rule 40 | masked_indices = labels != ignore_index 41 | total_masked = np.sum(masked_indices) 42 | 43 | if total_masked > 0: 44 | replaced_by_mask = np.sum((masked_seq == mask_token) & masked_indices) 45 | replaced_by_random = np.sum((masked_seq != mask_token) & (masked_seq != seq) & masked_indices) 46 | kept_unchanged = np.sum((masked_seq == seq) & masked_indices) 47 | 48 | mask_ratio = replaced_by_mask / total_masked 49 | random_ratio = replaced_by_random / total_masked 50 | unchanged_ratio = kept_unchanged / total_masked 51 | 52 | if not 0.79 < mask_ratio < 0.81: 53 | raise ValueError(f"Mask token ratio ({mask_ratio:.4f}) is out of expected range [0.79, 0.81]") 54 | if not 0.09 < random_ratio < 0.11: 55 | raise ValueError(f"Random token ratio ({random_ratio:.4f}) is out of expected range [0.09, 0.11]") 56 | if not 0.09 < unchanged_ratio < 0.11: 57 | raise ValueError(f"Unchanged token ratio ({unchanged_ratio:.4f}) is out of expected range [0.09, 0.11]") 58 | 59 | # Test 4: Check overall masking probability 60 | actual_mask_prob = np.mean(masked_indices) 61 | if not mask_prob - 0.01 < actual_mask_prob < mask_prob + 0.01: 62 | raise ValueError( 63 | f"Actual masking probability ({actual_mask_prob:.4f}) is too far from requested probability ({mask_prob})" 64 | ) -------------------------------------------------------------------------------- /tests/test_padding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | import sys 4 | import os 5 | 6 | # Add tests folder root to path to allow us to use relative imports regardless of what directory the script is run from 7 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 8 | # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from 9 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 10 | from src.bert_layers.padding import unpad_input, pad_input 11 | 12 | 13 | @pytest.fixture 14 | def sample_data(): 15 | batch, seqlen, hidden_dim = 2, 4, 3 16 | inputs = torch.randn(batch, seqlen, hidden_dim) 17 | attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.int32) 18 | position_ids = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.long) 19 | labels = torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]], dtype=torch.long) 20 | return inputs, attention_mask, position_ids, labels 21 | 22 | 23 | def test_unpad_input(sample_data): 24 | inputs, attention_mask, position_ids, labels = sample_data 25 | unpadded_inputs, indices, cu_seqlens, max_seqlen, unpadded_position_ids, unpadded_labels = unpad_input( 26 | inputs, attention_mask, position_ids, labels 27 | ) 28 | 29 | 
assert unpadded_inputs.shape == (5, 3) # 5 valid tokens, hidden_dim = 3 30 | assert indices.tolist() == [0, 1, 2, 4, 5] 31 | assert cu_seqlens.tolist() == [0, 3, 5] 32 | assert max_seqlen == 3 33 | assert unpadded_position_ids.tolist() == [0, 1, 2, 0, 1] 34 | assert unpadded_labels.tolist() == [1, 2, 3, 4, 5] 35 | 36 | 37 | def test_pad_input(sample_data): 38 | inputs, attention_mask, _, labels = sample_data 39 | unpadded_inputs, indices, _, _, _, unpadded_labels = unpad_input(inputs, attention_mask, labels=labels) 40 | 41 | padded_inputs, padded_labels = pad_input(unpadded_inputs, indices, batch=2, seqlen=4, labels=unpadded_labels) 42 | 43 | assert padded_inputs.shape == (2, 4, 3) 44 | assert torch.allclose(padded_inputs[attention_mask.bool()], unpadded_inputs) 45 | assert torch.all(padded_inputs[~attention_mask.bool()] == 0) 46 | assert torch.all(padded_labels[attention_mask.bool()] == unpadded_labels) 47 | assert torch.all(padded_labels[~attention_mask.bool()] == -100) 48 | 49 | 50 | def test_roundtrip(sample_data): 51 | inputs, attention_mask, _, labels = sample_data 52 | unpadded_inputs, indices, _, _, _, unpadded_labels = unpad_input(inputs, attention_mask, labels=labels) 53 | padded_inputs, padded_labels = pad_input(unpadded_inputs, indices, batch=2, seqlen=4, labels=unpadded_labels) 54 | 55 | assert torch.allclose(inputs[attention_mask.bool()], padded_inputs[attention_mask.bool()]) 56 | assert torch.all(labels == padded_labels) 57 | 58 | 59 | def test_token_input(): 60 | batch, seqlen, vocab_size = 2, 4, 1000 61 | token_ids = torch.randint(0, vocab_size, (batch, seqlen)) 62 | attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.int32) 63 | 64 | unpadded_inputs, indices, _, _, _, _ = unpad_input(token_ids, attention_mask) 65 | 66 | assert unpadded_inputs.shape == (5,) # 5 valid tokens 67 | assert unpadded_inputs.dtype == torch.long 68 | 69 | padded_inputs, _ = pad_input(unpadded_inputs, indices, batch=2, seqlen=4) 70 | 71 | assert padded_inputs.shape == (2, 4) 72 | assert padded_inputs.dtype == torch.long 73 | assert torch.all(padded_inputs[attention_mask.bool()] == unpadded_inputs) 74 | assert torch.all(padded_inputs[~attention_mask.bool()] == 0) 75 | 76 | 77 | def test_2d_input(): 78 | batch, seqlen = 2, 4 79 | inputs = torch.randn(batch, seqlen) 80 | attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.int32) 81 | 82 | unpadded_inputs, indices, cu_seqlens, max_seqlen, _, _ = unpad_input(inputs, attention_mask) 83 | 84 | assert unpadded_inputs.shape == (5,) # 5 valid tokens 85 | assert indices.tolist() == [0, 1, 2, 4, 5] 86 | assert cu_seqlens.tolist() == [0, 3, 5] 87 | assert max_seqlen == 3 88 | 89 | padded_inputs, _ = pad_input(unpadded_inputs, indices, batch=2, seqlen=4) 90 | 91 | assert padded_inputs.shape == (2, 4) 92 | assert torch.allclose(padded_inputs[attention_mask.bool()], unpadded_inputs) 93 | assert torch.all(padded_inputs[~attention_mask.bool()] == 0) 94 | -------------------------------------------------------------------------------- /tests/test_superglue.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MosaicML Examples authors 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import sys 5 | import os 6 | import shutil 7 | import tempfile 8 | from typing import Any 9 | 10 | import pytest 11 | 12 | # Add tests folder root to path to allow us to use relative imports regardless of what directory the script is run from 13 | 
sys.path.append(os.path.dirname(os.path.abspath(__file__))) 14 | # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from 15 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 16 | from glue import train 17 | from omegaconf import DictConfig, OmegaConf 18 | 19 | 20 | class SuperGlueDirContext(object): 21 | def __init__(self): 22 | self.path = None 23 | 24 | def __enter__(self): 25 | self.path = tempfile.mkdtemp() 26 | return self.path 27 | 28 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): 29 | del exc_type, exc_value, traceback # unused 30 | if self.path is not None: 31 | shutil.rmtree(self.path) 32 | 33 | 34 | @pytest.mark.parametrize("model_name", ["mosaic_bert", "hf_bert", "flex_bert"]) 35 | def test_superglue_script(model_name: str): 36 | with open("yamls/defaults.yaml") as f: 37 | default_cfg = OmegaConf.load(f) 38 | with open(f"yamls/models/{model_name}.yaml") as f: 39 | model_cfg = OmegaConf.load(f) 40 | with open("tests/smoketest_config_superglue.yaml") as f: 41 | test_config = OmegaConf.load(f) 42 | config = OmegaConf.merge(default_cfg, model_cfg, test_config) 43 | assert isinstance(config, DictConfig) 44 | config.model.name = model_name 45 | 46 | if ( 47 | model_name == "flex_bert" 48 | and not config.model.model_config.use_fa2 49 | and config.model.model_config.padding == "unpadded" 50 | ): 51 | pytest.skip("SDPA call currently errors with SuperGlue test on unpadded inputs") 52 | 53 | # The test is that `train` runs successfully 54 | with SuperGlueDirContext() as local_save_dir: 55 | config.save_finetune_checkpoint_prefix = local_save_dir 56 | train(config) 57 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MosaicML Examples authors 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import random 6 | import shutil 7 | import tempfile 8 | from typing import Any 9 | 10 | import numpy as np 11 | import streaming 12 | 13 | 14 | class SynthTextDirectory(object): 15 | def __enter__(self): 16 | path = create_synthetic_text_dataset() 17 | self.path = path # type: ignore (reportUninitializedInstanceVariable) 18 | return self.path 19 | 20 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): 21 | del exc_type, exc_value, traceback # Unused 22 | shutil.rmtree(self.path) 23 | 24 | 25 | def create_synthetic_text_dataset(n_samples: int = 16): 26 | tmp_dirname = tempfile.mkdtemp() 27 | 28 | for split in ["train", "val"]: 29 | dirname = os.path.join(tmp_dirname, split) 30 | hashes = ["sha1", "xxh64"] 31 | size_limit = 1 << 25 32 | with streaming.MDSWriter(columns={"text": "str"}, out=dirname, hashes=hashes, size_limit=size_limit) as out: 33 | for _ in range(n_samples): 34 | n_letters = np.random.randint(low=5, high=256) 35 | letter_str = " ".join([random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(n_letters)]) 36 | out.write({"text": letter_str}) 37 | 38 | return tmp_dirname 39 | -------------------------------------------------------------------------------- /wandb_log_live_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 onwards Answer.AI, LightOn, and contributors 2 | # License: Apache-2.0 3 | 4 | import argparse 5 | import re 6 | import time 7 | from datetime import datetime 8 | 9 | import pandas as pd 10 | import schedule 11 | 
import wandb 12 | 13 | 14 | def parse_model_string(s): 15 | pattern = r"(bert24-(base|large)(?:-v\d+)?(?:-\w+)?)-ba(\d+)_task=(\w+)(?:_\w+)?_seed=(\d+)" 16 | match = re.match(pattern, s) 17 | if match: 18 | full_model, size, batch, task, seed = match.groups() 19 | return {"model": full_model, "size": size, "batch": int(batch), "task": task, "seed": int(seed)} 20 | else: 21 | raise ValueError(f"Could not parse model string: {s}") 22 | 23 | 24 | def init_run(args): 25 | # Initialize meta W&B run 26 | wandb.init(project=args.meta_project, name=f"{args.meta_run_name}") 27 | meta_run_id = wandb.run.id 28 | wandb.finish() 29 | print(f"Initialized meta run with ID: {meta_run_id}") 30 | return meta_run_id 31 | 32 | 33 | def process_data(args): 34 | print(f"Starting data processing at {datetime.now()}") 35 | 36 | # Get runs from source eval project 37 | api = wandb.Api() 38 | runs = api.runs(f"{args.entity}/{args.source_project}") 39 | 40 | # Process data 41 | stats = [] 42 | for run in runs: 43 | if run.state != "finished" or "task=" not in run.name: 44 | continue 45 | try: 46 | meta = parse_model_string(run.name) 47 | except ValueError: 48 | print(f"Skipping run with unparseable name: {run.name}") 49 | continue 50 | task = meta["task"] 51 | summary = run.summary 52 | 53 | for m in args.task2metric_dict[task]: 54 | val = summary.get(m) 55 | if val: 56 | stats.append({**meta, "metric": m, "score": val}) 57 | 58 | # Aggregate stats 59 | stats_df = pd.DataFrame(stats) 60 | print(f"available models: {stats_df.model.unique().tolist()}") 61 | stats_df = stats_df[stats_df["model"] == args.model_name] 62 | 63 | grouped_df = stats_df.groupby(["model", "size", "batch", "task", "metric"])["score"].mean().reset_index() 64 | count_df = stats_df.groupby(["model", "size", "batch", "task", "metric"])["score"].count().reset_index() 65 | count_df.rename(columns={"score": "count"}, inplace=True) 66 | grouped_df = pd.merge(grouped_df, count_df, on=["model", "size", "batch", "task", "metric"]) 67 | 68 | # Log metrics to W&B 69 | batch_ticks = sorted(grouped_df["batch"].unique().tolist()) 70 | all_metrics = args.all_metrics # sorted(grouped_df["metric"].unique().tolist()) 71 | grouped_df = grouped_df[grouped_df["metric"].isin(all_metrics)] 72 | print(batch_ticks) 73 | 74 | with wandb.init(project=args.meta_project, job_type="eval", id=args.meta_run_id, resume="must") as run: 75 | for step in batch_ticks: 76 | # check if all metrics are computed for the current batch 77 | for metric in all_metrics: 78 | ex = grouped_df[(grouped_df["batch"] == step) & (grouped_df["metric"] == metric)] 79 | if len(ex) == 0 or ex["count"].values[0] < args.metric2num_seeds[metric]: 80 | print(f"insufficient data for step={step} and metric={metric}") 81 | print(f"Logged up to step < {step}") 82 | return 83 | 84 | for metric in all_metrics: 85 | ex = grouped_df[(grouped_df["batch"] == step) & (grouped_df["metric"] == metric)] 86 | if len(ex) == 1: 87 | if ex["count"].values[0] >= args.metric2num_seeds[metric]: 88 | score = ex["score"].values[0] 89 | run.log({metric: score}, step=step) 90 | 91 | print(f"Finished data processing at {datetime.now()}") 92 | 93 | 94 | def main(): 95 | parser = argparse.ArgumentParser(description="W&B Logging Script") 96 | parser.add_argument("--entity", type=str, default="bert24", help="W&B entity name") 97 | parser.add_argument("--meta-project", type=str, default="bert24-evals-meta", help="meta project name") 98 | parser.add_argument("--model-name", type=str, default="bert24-large-v2", help="Model name") 99 | 
parser.add_argument("--meta-run-id", type=str, help="ID of the meta run to update") 100 | parser.add_argument("--meta-run-name", type=str, default="bert24-large-v2-evals", help="Meta run name") 101 | 102 | parser.add_argument("--source-project", type=str, default="bert24-large-v2-evals", help="project for eval runs") 103 | parser.add_argument("--interval", type=int, default=60, help="Interval in minutes between data refresh") 104 | parser.add_argument("--init-meta", action="store_true", help="Initialize a new meta run") 105 | 106 | args = parser.parse_args() 107 | 108 | # metadata information --- 109 | args.task2metric_dict = { 110 | "mnli": ["metrics/glue_mnli/MulticlassAccuracy", "metrics/glue_mnli_mismatched/MulticlassAccuracy"], 111 | "ultrafeedback": ["metrics/long_context_ultrafeedback/UltrafeedbackAUROC"], 112 | "mlmmlu_rookie_reserve": [ 113 | "metrics/mlmmlu_rookie/MulticlassAccuracy", 114 | "metrics/mlmmlu_reserve/MulticlassAccuracy", 115 | ], 116 | "wic": ["metrics/superglue_wic/MulticlassAccuracy"], 117 | "boolq": ["metrics/superglue_boolq/MulticlassAccuracy"], 118 | } 119 | 120 | args.metric2num_seeds = { 121 | "metrics/glue_mnli/MulticlassAccuracy": 3, 122 | "metrics/glue_mnli_mismatched/MulticlassAccuracy": 3, 123 | "metrics/mlmmlu_rookie/MulticlassAccuracy": 3, 124 | "metrics/mlmmlu_reserve/MulticlassAccuracy": 3, 125 | "metrics/superglue_wic/MulticlassAccuracy": 3, 126 | "metrics/superglue_boolq/MulticlassAccuracy": 3, 127 | "metrics/long_context_ultrafeedback/UltrafeedbackAUROC": 2, 128 | } 129 | 130 | args.all_metrics = [ 131 | "metrics/glue_mnli/MulticlassAccuracy", 132 | "metrics/glue_mnli_mismatched/MulticlassAccuracy", 133 | # "metrics/mlmmlu_rookie/MulticlassAccuracy", 134 | # "metrics/mlmmlu_reserve/MulticlassAccuracy", 135 | "metrics/superglue_wic/MulticlassAccuracy", 136 | "metrics/superglue_boolq/MulticlassAccuracy", 137 | ] 138 | 139 | if args.init_meta: 140 | meta_run_id = init_run(args) 141 | print(f"Use this meta_run_id for future runs: {meta_run_id}") 142 | return 143 | 144 | if not args.meta_run_id: 145 | parser.error("--meta-run-id is required when not initializing a new meta run") 146 | 147 | schedule.every(args.interval).minutes.do(process_data, args) 148 | process_data(args) # first run 149 | 150 | while True: 151 | try: 152 | schedule.run_pending() 153 | time.sleep(30) 154 | except KeyboardInterrupt: 155 | print("Scheduler stopped by user. 
Exiting...") 156 | break 157 | 158 | 159 | if __name__ == "__main__": 160 | main() 161 | 162 | ## Usage 163 | # python wandb_log_live_eval.py --init-meta --model-name <> --meta-project <> --meta-run-name "<>-evals" 164 | -------------------------------------------------------------------------------- /yamls/ablations/example-config.yaml: -------------------------------------------------------------------------------- 1 | parallel: true 2 | base_run_name: mosaic-bert 3 | default_seed: 19 4 | precision: amp_bf16 5 | tokenizer_name: bclavie/bert24_32k_tok_llama2 6 | model: 7 | name: mosaic_bert 8 | use_pretrained: true 9 | pretrained_model_name: bert-base-uncased 10 | tokenizer_name: ${tokenizer_name} 11 | model_config: 12 | num_attention_heads: 12 13 | num_hidden_layers: 12 14 | head_pred_act: gelu 15 | hidden_act: gelu 16 | normalization: layernorm 17 | allow_embedding_resizing: true 18 | attention_probs_dropout_prob: 0.0 19 | use_fa2: true 20 | head_class_norm: null 21 | head_class_act: tanh 22 | starting_checkpoint_load_path: latest-rank0.pt 23 | local_pretrain_checkpoint_folder: /home/shared/data-ablations/checkpoints/mosaic-bert-1024 24 | save_finetune_checkpoint_prefix: ./finetuned-checkpoints 25 | save_finetune_checkpoint_folder: ${save_finetune_checkpoint_prefix}/${base_run_name} 26 | loggers: 27 | wandb: 28 | project: bert24-data-ablations-evals 29 | entity: bert24 30 | callbacks: 31 | lr_monitor: {} 32 | speed_monitor: {} 33 | scheduler: 34 | name: linear_decay_with_warmup 35 | t_warmup: 0.06dur 36 | alpha_f: 0.0 37 | tasks: 38 | mlmmlu_amateur_semipro: 39 | seeds: 40 | - 233 41 | - 331 42 | - 461 43 | - 567 44 | trainer_kwargs: 45 | save_num_checkpoints_to_keep: 0 46 | mlmmlu_rookie_reserve: 47 | seeds: 48 | - 233 49 | - 331 50 | - 461 51 | - 567 52 | trainer_kwargs: 53 | save_num_checkpoints_to_keep: 0 54 | -------------------------------------------------------------------------------- /yamls/baselines/bert-base-uncased-superglue.yaml: -------------------------------------------------------------------------------- 1 | # Whether to run the various GLUE jobs serially or in parallel (use parallel=True to take advantage of multiple GPUs) 2 | parallel: true 3 | 4 | # Basic run configuration, additional details will be added to this name for each GLUE task, and each random seed 5 | base_run_name: bert-base-uncased-superglue-test 6 | default_seed: 19 7 | precision: amp_bf16 8 | 9 | # Tokenizer for dataset creation 10 | tokenizer_name: bert-base-uncased 11 | 12 | # Base model config 13 | model: 14 | name: hf_bert 15 | use_pretrained: true 16 | pretrained_model_name: ${tokenizer_name} 17 | tokenizer_name: ${tokenizer_name} 18 | 19 | # Saving 20 | save_finetune_checkpoint_prefix: ./bert-finetune-checkpoints 21 | save_finetune_checkpoint_folder: ${save_finetune_checkpoint_prefix}/${base_run_name} 22 | 23 | # (Optional) W&B logging 24 | # loggers: 25 | # wandb: 26 | # project: # Fill this in if using W&B 27 | # entity: # Fill this in if using W&B 28 | 29 | # Callbacks 30 | callbacks: 31 | lr_monitor: {} 32 | speed_monitor: {} 33 | 34 | # Scheduler 35 | scheduler: 36 | name: linear_decay_with_warmup 37 | t_warmup: 0.06dur 38 | alpha_f: 0.0 39 | 40 | # Task configuration 41 | tasks: 42 | mnli: 43 | trainer_kwargs: 44 | # MNLI is not part of SuperGLUE, but we include it here because best 45 | # practice for evaluating RTE involves starting from an MNLI checkpoint, 46 | # which is why we keep one MNLI checkpoint locally. 
47 | save_num_checkpoints_to_keep: 1 48 | swag: 49 | trainer_kwargs: 50 | # SWAG is not part of SuperGLUE, but it is commonly used as a first step 51 | # in the process of fine-tuning COPA, which is why we keep one SWAG 52 | # checkpoint locally as well. 53 | save_num_checkpoints_to_keep: 1 54 | boolq: 55 | seeds: [23, 42, 6033] 56 | trainer_kwargs: 57 | save_num_checkpoints_to_keep: 0 58 | cb: 59 | seeds: [23, 42, 6033] 60 | trainer_kwargs: 61 | save_num_checkpoints_to_keep: 0 62 | rte: 63 | seeds: [19, 8364, 717, 10536, 90166] 64 | trainer_kwargs: 65 | save_num_checkpoints_to_keep: 0 66 | wic: 67 | seeds: [23, 42, 6033] 68 | trainer_kwargs: 69 | save_num_checkpoints_to_keep: 0 70 | copa: 71 | seeds: [23, 42, 6033, 1337, 24] 72 | trainer_kwargs: 73 | save_num_checkpoints_to_keep: 0 74 | multirc: 75 | seeds: [23, 42, 6033, 1337] 76 | trainer_kwargs: 77 | save_num_checkpoints_to_keep: 0 -------------------------------------------------------------------------------- /yamls/baselines/bert-base-uncased.yaml: -------------------------------------------------------------------------------- 1 | # Whether to run the various GLUE jobs serially or in parallel (use parallel=True to take advantage of multiple GPUs) 2 | parallel: true 3 | 4 | # Basic run configuration, additional details will be added to this name for each GLUE task, and each random seed 5 | base_run_name: bert-base-uncased-glue-finetuning 6 | default_seed: 19 7 | precision: amp_bf16 8 | 9 | # Tokenizer for dataset creation 10 | tokenizer_name: bert-base-uncased 11 | 12 | # Base model config 13 | model: 14 | name: hf_bert 15 | use_pretrained: true 16 | pretrained_model_name: ${tokenizer_name} 17 | tokenizer_name: ${tokenizer_name} 18 | 19 | # Saving 20 | save_finetune_checkpoint_prefix: ./bert-finetune-checkpoints 21 | save_finetune_checkpoint_folder: ${save_finetune_checkpoint_prefix}/${base_run_name} 22 | 23 | # (Optional) W&B logging 24 | # loggers: 25 | # wandb: 26 | # project: # Fill this in if using W&B 27 | # entity: # Fill this in if using W&B 28 | 29 | # Callbacks 30 | callbacks: 31 | lr_monitor: {} 32 | speed_monitor: {} 33 | 34 | # Scheduler 35 | scheduler: 36 | name: linear_decay_with_warmup 37 | t_warmup: 0.06dur 38 | alpha_f: 0.0 39 | 40 | # Task configuration 41 | tasks: 42 | mnli: 43 | # Specify any extra task-specific arguments for the trainer here 44 | trainer_kwargs: 45 | # We keep one MNLI checkpoint locally so that we can start finetuning of 46 | # RTE, MRPC and STS-B from the MNLI checkpoint 47 | save_num_checkpoints_to_keep: 1 48 | rte: 49 | seeds: [19, 8364, 717, 10536, 90166] 50 | trainer_kwargs: 51 | save_num_checkpoints_to_keep: 0 52 | qqp: 53 | trainer_kwargs: 54 | save_num_checkpoints_to_keep: 0 55 | qnli: 56 | trainer_kwargs: 57 | save_num_checkpoints_to_keep: 0 58 | sst2: 59 | seeds: [19, 8364, 717] 60 | trainer_kwargs: 61 | save_num_checkpoints_to_keep: 0 62 | stsb: 63 | seeds: [19, 8364, 717, 10536, 90166] 64 | trainer_kwargs: 65 | save_num_checkpoints_to_keep: 0 66 | mrpc: 67 | seeds: [19, 8364, 717, 10536, 90166] 68 | trainer_kwargs: 69 | save_num_checkpoints_to_keep: 0 70 | cola: 71 | seeds: [19, 8364, 717, 10536] 72 | trainer_kwargs: 73 | save_num_checkpoints_to_keep: 0 74 | -------------------------------------------------------------------------------- /yamls/baselines/colbert/bert-base-uncased.yaml: -------------------------------------------------------------------------------- 1 | # This YAML is built to work with the `sequence_classification.py` starter script! 
2 | # 3 | # Follow the instructions in that script to modify the `build_my_dataloader` function 4 | # and fine-tune a BERT model on your own dataset! 5 | # 6 | # 7 | # Note that some of the fields in this template haven't been filled in yet. 8 | # Please resolve any empty fields before launching! 9 | 10 | # Run Name 11 | run_name: finetune-hf-bert 12 | 13 | seed: 42 14 | model_name_or_path: bert-base-uncased 15 | train_dataset_id: bclavie/msmarco-500k-triplets-colbert-format 16 | n_gpu: 1 17 | tmp_dir: ./tmp_colbert 18 | debug: false 19 | 20 | train_params: 21 | lr: 3e-5 22 | use_ib_negatives: true 23 | bsize: 32 24 | 25 | eval_datasets: 26 | - "beir/scifact/test" 27 | - "beir/nfcorpus/test" 28 | - "beir/fiqa/test" 29 | - "beir/scidocs" 30 | - "beir/trec-covid" 31 | - "beir/webis-touche2020/v2" 32 | -------------------------------------------------------------------------------- /yamls/baselines/deberta-v3-base.yaml: -------------------------------------------------------------------------------- 1 | # Whether to run the various GLUE jobs serially or in parallel (use parallel=True to take advantage of multiple GPUs) 2 | parallel: true 3 | 4 | # Basic run configuration, additional details will be added to this name for each GLUE task, and each random seed 5 | base_run_name: deberta-v3-base-glue-finetuning 6 | default_seed: 19 7 | precision: amp_bf16 8 | 9 | # Tokenizer for dataset creation 10 | tokenizer_name: microsoft/deberta-v3-base 11 | 12 | # Base model config 13 | model: 14 | name: hf_bert 15 | use_pretrained: true 16 | pretrained_model_name: ${tokenizer_name} 17 | tokenizer_name: ${tokenizer_name} 18 | 19 | # Saving 20 | save_finetune_checkpoint_prefix: ./deberta-finetune-checkpoints 21 | save_finetune_checkpoint_folder: ${save_finetune_checkpoint_prefix}/${base_run_name} 22 | 23 | # (Optional) W&B logging 24 | # loggers: 25 | # wandb: 26 | # project: # Fill this in if using W&B 27 | # entity: # Fill this in if using W&B 28 | 29 | # Callbacks 30 | callbacks: 31 | lr_monitor: {} 32 | speed_monitor: {} 33 | 34 | # Scheduler 35 | scheduler: 36 | name: linear_decay_with_warmup 37 | t_warmup: 0.06dur 38 | alpha_f: 0.0 39 | 40 | # Task configuration 41 | tasks: 42 | mnli: 43 | # Specify any extra task-specific arguments for the trainer here 44 | trainer_kwargs: 45 | # We keep one MNLI checkpoint locally so that we can start finetuning of 46 | # RTE, MRPC and STS-B from the MNLI checkpoint 47 | save_num_checkpoints_to_keep: 1 48 | rte: 49 | seeds: [19, 8364, 717, 10536, 90166] 50 | trainer_kwargs: 51 | save_num_checkpoints_to_keep: 0 52 | qqp: 53 | trainer_kwargs: 54 | save_num_checkpoints_to_keep: 0 55 | qnli: 56 | trainer_kwargs: 57 | save_num_checkpoints_to_keep: 0 58 | sst2: 59 | seeds: [19, 8364, 717] 60 | trainer_kwargs: 61 | save_num_checkpoints_to_keep: 0 62 | stsb: 63 | seeds: [19, 8364, 717, 10536, 90166] 64 | trainer_kwargs: 65 | save_num_checkpoints_to_keep: 0 66 | mrpc: 67 | seeds: [19, 8364, 717, 10536, 90166] 68 | trainer_kwargs: 69 | save_num_checkpoints_to_keep: 0 70 | cola: 71 | seeds: [19, 8364, 717, 10536] 72 | trainer_kwargs: 73 | save_num_checkpoints_to_keep: 0 74 | -------------------------------------------------------------------------------- /yamls/baselines/deberta-v3-long-context.yaml: -------------------------------------------------------------------------------- 1 | # Whether to run the various GLUE jobs serially or in parallel (use parallel=True to take advantage of multiple GPUs) 2 | parallel: true 3 | 4 | # Basic run configuration, additional details will 
be added to this name for each GLUE task, and each random seed 5 | base_run_name: deberta-v3-long-context-finetuning 6 | default_seed: 19 7 | precision: amp_bf16 8 | 9 | # Tokenizer for dataset creation 10 | tokenizer_name: microsoft/deberta-v3-large # microsoft/deberta-v3-base 11 | 12 | # Base model config 13 | model: 14 | name: hf_bert 15 | use_pretrained: true 16 | pretrained_model_name: ${tokenizer_name} 17 | tokenizer_name: ${tokenizer_name} 18 | gradient_checkpointing: false 19 | model_config: {} 20 | 21 | # Saving 22 | save_finetune_checkpoint_prefix: ./deberta-finetune-checkpoints-long-context 23 | save_finetune_checkpoint_folder: ${save_finetune_checkpoint_prefix}/${base_run_name} 24 | 25 | # (Optional) W&B logging 26 | # loggers: 27 | # wandb: 28 | # project: # Fill this in if using W&B 29 | # entity: # Fill this in if using W&B 30 | 31 | # Callbacks 32 | callbacks: 33 | lr_monitor: {} 34 | speed_monitor: {} 35 | 36 | # Scheduler 37 | scheduler: 38 | name: linear_decay_with_warmup 39 | t_warmup: 0.06dur 40 | alpha_f: 0.0 41 | 42 | # Task configuration 43 | tasks: 44 | eurlex: 45 | seeds: [461, 475, 480] 46 | model_config: 47 | problem_type: multi_label_classification 48 | trainer_kwargs: 49 | save_num_checkpoints_to_keep: 0 -------------------------------------------------------------------------------- /yamls/defaults.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: mosaic_bert 3 | model_config: 4 | normalization: layernorm 5 | hidden_act: gelu -------------------------------------------------------------------------------- /yamls/finetuning/glue/mosaic-bert-base-uncased.yaml: -------------------------------------------------------------------------------- 1 | # Note that some of the fields in this template haven't been filled in yet. 2 | # Please resolve any `null` fields before launching!
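The GLUE and SuperGLUE baseline YAMLs above all share the same pattern: a `tasks:` mapping in which each task can override `seeds` and pass extra `trainer_kwargs`. As a rough illustration of how such a block expands into one fine-tuning run per (task, seed) pair, here is a minimal sketch using OmegaConf; `expand_task_runs` and the run-name format are illustrative only, not the actual logic in `glue.py` or `run_evals.py`.

```python
# Illustrative sketch only: expand a `tasks:` block like the ones above into
# one run description per (task, seed) pair. Tasks without a `seeds` list fall
# back to the config's `default_seed`.
from omegaconf import OmegaConf

def expand_task_runs(config_path: str):
    cfg = OmegaConf.load(config_path)
    runs = []
    for task_name, task_cfg in cfg.tasks.items():
        seeds = task_cfg.get("seeds", [cfg.default_seed])
        trainer_kwargs = task_cfg.get("trainer_kwargs", {})
        for seed in seeds:
            runs.append({
                "run_name": f"{cfg.base_run_name}-{task_name}-seed-{seed}",
                "task": task_name,
                "seed": seed,
                "trainer_kwargs": trainer_kwargs,
            })
    return runs
```

For example, applied to `yamls/baselines/bert-base-uncased.yaml` this would yield a single default-seed run each for MNLI, QQP and QNLI, and five seeded runs each for RTE, STS-B and MRPC.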
3 | 4 | # Whether to run the various GLUE jobs serially or in parallel (use parallel=True to take advantage of multiple GPUs) 5 | parallel: true 6 | 7 | # Basic run configuration, additional details will be added to this name for each GLUE task, and each random seed 8 | base_run_name: mosaic-bert-base-uncased-glue-finetuning # Determines how runs are saved and logged in W&B 9 | default_seed: 19 10 | precision: amp_bf16 11 | 12 | # Tokenizer for dataset creation 13 | tokenizer_name: bert-base-uncased 14 | 15 | # Base model config 16 | model: 17 | name: mosaic_bert 18 | pretrained_model_name: ${tokenizer_name} 19 | tokenizer_name: ${tokenizer_name} 20 | model_config: 21 | deterministic_fa2: true 22 | 23 | # Loading 24 | # (fill this in with the composer checkpoint from the end of pre-training a Mosaic BERT) 25 | starting_checkpoint_load_path: 26 | local_pretrain_checkpoint_folder: ./local-bert-checkpoints/ 27 | 28 | # Saving 29 | save_finetune_checkpoint_prefix: ./local-finetune-checkpoints/ # (local) 30 | # save_finetune_checkpoint_prefix: s3:///remote-finetune-checkpoints # (remote) 31 | save_finetune_checkpoint_folder: ${save_finetune_checkpoint_prefix}/${base_run_name} 32 | 33 | # # (Optional) W&B logging 34 | # loggers: 35 | # wandb: 36 | # project: # Fill this in if using W&B 37 | # entity: # Fill this in if using W&B 38 | 39 | # Callbacks 40 | callbacks: 41 | lr_monitor: {} 42 | speed_monitor: {} 43 | 44 | # Scheduler 45 | scheduler: 46 | name: linear_decay_with_warmup 47 | t_warmup: 0.06dur 48 | alpha_f: 0.0 49 | 50 | # Algorithms 51 | # algorithms: 52 | 53 | # Task configuration 54 | tasks: 55 | mnli: 56 | # Specify any extra task-specific arguments for the trainer here 57 | trainer_kwargs: 58 | # We keep one MNLI checkpoint locally so that we can start finetuning of 59 | # RTE, MRPC and STS-B from the MNLI checkpoint 60 | save_num_checkpoints_to_keep: 1 61 | rte: 62 | seeds: [19, 8364, 717, 10536, 90166] 63 | trainer_kwargs: 64 | save_num_checkpoints_to_keep: 0 65 | qqp: 66 | trainer_kwargs: 67 | save_num_checkpoints_to_keep: 0 68 | qnli: 69 | trainer_kwargs: 70 | save_num_checkpoints_to_keep: 0 71 | sst2: 72 | seeds: [19, 8364, 717] 73 | trainer_kwargs: 74 | save_num_checkpoints_to_keep: 0 75 | stsb: 76 | seeds: [19, 8364, 717, 10536, 90166] 77 | trainer_kwargs: 78 | save_num_checkpoints_to_keep: 0 79 | mrpc: 80 | seeds: [19, 8364, 717, 10536, 90166] 81 | trainer_kwargs: 82 | save_num_checkpoints_to_keep: 0 83 | cola: 84 | seeds: [19, 8364, 717, 10536] 85 | trainer_kwargs: 86 | save_num_checkpoints_to_keep: 0 87 | -------------------------------------------------------------------------------- /yamls/finetuning/hf-bert-base-uncased.yaml: -------------------------------------------------------------------------------- 1 | # This YAML is built to work with the `sequence_classification.py` starter script! 2 | # 3 | # Follow the instructions in that script to modify the `build_my_dataloader` function 4 | # and fine-tune a BERT model on your own dataset! 5 | # 6 | # 7 | # Note that some of the fields in this template haven't been filled in yet. 8 | # Please resolve any empty fields before launching! 9 | 10 | # Run Name 11 | run_name: finetune-hf-bert 12 | 13 | tokenizer_name: bert-base-uncased 14 | max_seq_len: 128 15 | 16 | load_path: # (Optionally) provide a composer checkpoint to use for the starting weights 17 | 18 | # Model 19 | model: 20 | name: hf_bert 21 | num_labels: 2 # <-- Make sure to update these after you modify the starter script! 
22 | use_pretrained: true 23 | pretrained_model_name: ${tokenizer_name} 24 | tokenizer_name: ${tokenizer_name} 25 | 26 | # Dataloaders (make sure to update these after you modify the starter script!) 27 | train_loader: 28 | split: train 29 | tokenizer_name: ${tokenizer_name} 30 | max_seq_len: ${max_seq_len} 31 | shuffle: true 32 | drop_last: true 33 | num_workers: 8 34 | 35 | eval_loader: 36 | split: validation 37 | tokenizer_name: ${tokenizer_name} 38 | max_seq_len: ${max_seq_len} 39 | shuffle: true 40 | drop_last: true 41 | num_workers: 8 42 | 43 | # Optimization 44 | scheduler: 45 | name: linear_decay_with_warmup 46 | t_warmup: 0.06dur # Warmup to the full LR for 6% of the training duration 47 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 48 | 49 | optimizer: 50 | name: decoupled_adamw 51 | lr: 1.0e-5 52 | betas: 53 | - 0.9 54 | - 0.98 55 | eps: 1.0e-06 56 | weight_decay: 1.0e-6 57 | 58 | # Training duration and evaluation frequency 59 | max_duration: 10ep 60 | eval_interval: 1ep 61 | global_train_batch_size: 16 62 | 63 | # System 64 | seed: 17 65 | device_eval_microbatch_size: 16 66 | device_train_microbatch_size: 16 67 | precision: amp_bf16 68 | 69 | # Logging 70 | progress_bar: false 71 | log_to_console: true 72 | console_log_interval: 10ba 73 | 74 | # Optionally log to W&B 75 | # loggers: 76 | # wandb: {} 77 | 78 | callbacks: 79 | speed_monitor: 80 | window_size: 50 81 | lr_monitor: {} 82 | -------------------------------------------------------------------------------- /yamls/finetuning/mosaic-bert-base-uncased.yaml: -------------------------------------------------------------------------------- 1 | # This YAML is built to work with the `sequence_classification.py` starter script! 2 | # 3 | # Follow the instructions in that script to modify the `build_my_dataloader` function 4 | # and fine-tune a BERT model on your own dataset! 5 | # 6 | # 7 | # Note that some of the fields in this template haven't been filled in yet. 8 | # Please resolve any empty fields before launching! 9 | 10 | # Run Name 11 | run_name: finetune-mosaic-bert 12 | 13 | tokenizer_name: bert-base-uncased 14 | max_seq_len: 128 15 | 16 | load_path: # (Optionally) provide a composer checkpoint to use for the starting weights 17 | 18 | # Model 19 | model: 20 | name: mosaic_bert 21 | num_labels: 2 # <-- Make sure to update these after you modify the starter script! 22 | pretrained_model_name: ${tokenizer_name} 23 | tokenizer_name: ${tokenizer_name} 24 | model_config: 25 | deterministic_fa2: true 26 | 27 | # Dataloaders (make sure to update these after you modify the starter script!) 
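Both fine-tuning starter YAMLs (`hf-bert-base-uncased.yaml` and this `mosaic-bert-base-uncased.yaml`) defer the dataset choice to `sequence_classification.py`, which asks you to rewrite `build_my_dataloader`; the loader blocks that follow only carry tokenizer and batching settings. The following is only a hedged sketch of what such a function could look like for a two-label classification dataset; the dataset, column names, and the exact signature the starter script expects are assumptions, not the repo's actual code.

```python
# Hypothetical sketch of a modified `build_my_dataloader`; the real starter
# script may expect a different signature and collation strategy.
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorWithPadding

def build_my_dataloader(cfg, device_batch_size: int) -> DataLoader:
    tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_name)
    dataset = load_dataset("glue", "sst2", split=cfg.split)  # assumed example dataset

    def tokenize(batch):
        return tokenizer(batch["sentence"], truncation=True, max_length=cfg.max_seq_len)

    dataset = dataset.map(tokenize, batched=True, remove_columns=["sentence", "idx"])
    return DataLoader(
        dataset,
        batch_size=device_batch_size,
        shuffle=cfg.get("shuffle", False),
        drop_last=cfg.get("drop_last", False),
        num_workers=cfg.get("num_workers", 0),
        collate_fn=DataCollatorWithPadding(tokenizer),  # pads each batch to its longest sequence
    )
```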
28 | train_loader: 29 | split: train 30 | tokenizer_name: ${tokenizer_name} 31 | max_seq_len: ${max_seq_len} 32 | shuffle: true 33 | drop_last: true 34 | num_workers: 8 35 | 36 | eval_loader: 37 | split: validation 38 | tokenizer_name: ${tokenizer_name} 39 | max_seq_len: ${max_seq_len} 40 | shuffle: true 41 | drop_last: true 42 | num_workers: 8 43 | 44 | # Optimization 45 | scheduler: 46 | name: linear_decay_with_warmup 47 | t_warmup: 0.06dur # Warmup to the full LR for 6% of the training duration 48 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 49 | 50 | optimizer: 51 | name: decoupled_adamw 52 | lr: 1.0e-5 53 | betas: 54 | - 0.9 55 | - 0.98 56 | eps: 1.0e-06 57 | weight_decay: 1.0e-6 58 | 59 | # Training duration and evaluation frequency 60 | max_duration: 10ep 61 | eval_interval: 1ep 62 | global_train_batch_size: 16 63 | 64 | # System 65 | seed: 17 66 | device_eval_microbatch_size: 16 67 | device_train_microbatch_size: 16 68 | precision: amp_bf16 69 | 70 | # Logging 71 | progress_bar: false 72 | log_to_console: true 73 | console_log_interval: 10ba 74 | 75 | # Optionally log to W&B 76 | # loggers: 77 | # wandb: {} 78 | 79 | callbacks: 80 | speed_monitor: 81 | window_size: 50 82 | lr_monitor: {} 83 | -------------------------------------------------------------------------------- /yamls/main/flex-bert-base-parallel.yaml: -------------------------------------------------------------------------------- 1 | # Note that some of the fields in this template haven't been filled in yet. 2 | # Please resolve any `null` fields before launching! 3 | 4 | # Follow the instructions in the README to set up ./my-copy-c4 5 | # Or point data paths to your remote C4 dataset 6 | data_local: ./my-copy-c4 7 | data_remote: # If blank, files must be present in data_local 8 | 9 | max_seq_len: 128 10 | tokenizer_name: bert-base-uncased # switch to bert tokenizer until we add [MASK] token to the llama tokenizer meta-llama/Llama-2-7b-hf 11 | mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance 12 | 13 | # Run Name 14 | run_name: flex-bert-base-parallel 15 | 16 | # Model 17 | model: 18 | name: flex_bert 19 | recompute_metric_loss: false # recompute metric loss, use if passing label_smoothing to record non-label-smoothed loss as a metric 20 | pretrained_model_name: ${tokenizer_name} 21 | tokenizer_name: ${tokenizer_name} 22 | # FlexBERT 'base' generally uses the default architecture values for from the Hugging Face BertConfig object 23 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 24 | # the model_config settings match the architecture of the existing model 25 | model_config: 26 | num_attention_heads: 12 # bert-base default 27 | num_hidden_layers: 12 # bert-base default 28 | attention_layer: parallel 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: False 31 | attn_out_dropout_prob: 0.0 32 | attn_qkv_bias: False 33 | bert_layer: parallel_prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: False 36 | final_norm: True 37 | embedding_layer: absolute_pos 38 | loss_function: fa_cross_entropy 39 | loss_kwargs: 40 | reduction: mean 41 | mlp_dropout_prob: 0.0 42 | mlp_in_bias: False 43 | mlp_layer: parallel_glu 44 | mlp_out_bias: False 45 | norm_kwargs: 46 | eps: 1e-6 47 | normalization: rmsnorm 48 | padding: unpadded 49 | sparse_prediction: False 50 | hidden_act: gelu 51 | init_method: full_megatron 52 | init_std: 0.02 53 | init_cutoff_factor: 2.0 54 | init_small_embedding: False 
55 | deterministic_fa2: false 56 | initial_attention_layer: null 57 | initial_bert_layer: null 58 | initial_mlp_layer: null 59 | num_initial_layers: 0 60 | skip_first_prenorm: true 61 | sliding_window: 128 62 | global_attn_every_n_layers: 3 63 | unpad_embeddings: true 64 | pad_logits: false 65 | 66 | 67 | # Dataloaders 68 | train_loader: 69 | name: text 70 | dataset: 71 | local: ${data_local} 72 | remote: ${data_remote} 73 | split: train_small 74 | tokenizer_name: ${tokenizer_name} 75 | max_seq_len: ${max_seq_len} 76 | shuffle: true 77 | mlm_probability: ${mlm_probability} 78 | drop_last: true 79 | num_workers: 8 80 | 81 | eval_loader: 82 | name: text 83 | dataset: 84 | local: ${data_local} 85 | remote: ${data_remote} 86 | split: val 87 | tokenizer_name: ${tokenizer_name} 88 | max_seq_len: ${max_seq_len} 89 | shuffle: false 90 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 91 | drop_last: false 92 | num_workers: 8 93 | 94 | # Optimization 95 | scheduler: 96 | name: linear_decay_with_warmup 97 | t_warmup: 0.06dur # Warmup to the full LR for 6% of the training duration 98 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 99 | 100 | optimizer: 101 | name: decoupled_adamw 102 | lr: 5.0e-4 # Peak learning rate 103 | betas: 104 | - 0.9 105 | - 0.98 106 | eps: 1.0e-06 107 | weight_decay: 1.0e-5 # Amount of weight decay regularization 108 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 109 | 110 | # algorithms: 111 | 112 | max_duration: 286720000sp # Subsample the training data for ~275M samples 113 | eval_interval: 2000ba 114 | global_train_batch_size: 4096 115 | 116 | # System 117 | seed: 17 118 | device_train_microbatch_size: 128 119 | # device_train_microbatch_size: auto 120 | precision: amp_bf16 121 | 122 | global_eval_batch_size: 256 123 | device_eval_microbatch_size: 64 124 | 125 | # Logging 126 | progress_bar: false 127 | log_to_console: true 128 | console_log_interval: 1ba 129 | 130 | callbacks: 131 | speed_monitor: 132 | window_size: 500 133 | lr_monitor: {} 134 | 135 | algorithms: 136 | gradient_clipping: 137 | clipping_type: norm 138 | clipping_threshold: 1.0 139 | 140 | # (Optional) W&B logging 141 | loggers: 142 | wandb: 143 | project: bert24 # Fill this in 144 | # entity: # Fill this in 145 | 146 | # (Optional) Checkpoint to local filesystem or remote object store 147 | # save_interval: 3500ba 148 | # save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK 149 | # save_folder: # e.g. './{run_name}/ckpt' (local) or 's3://mybucket/mydir/{run_name}/ckpt' (remote) 150 | 151 | # (Optional) Load from local filesystem or remote object store to 152 | # start from an existing model checkpoint; 153 | # e.g. './ckpt/latest-rank{rank}.pt' (local), or 154 | # 's3://mybucket/mydir/ckpt/latest-rank{rank}.pt' (remote) 155 | # load_path: null 156 | -------------------------------------------------------------------------------- /yamls/main/flex-bert-base.yaml: -------------------------------------------------------------------------------- 1 | # Note that some of the fields in this template haven't been filled in yet. 2 | # Please resolve any `null` fields before launching! 
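The FlexBERT pre-training configs here mix attention types: `sliding_window: 128` together with `global_attn_every_n_layers: 3` across the 12 hidden layers. A small sketch of the layer pattern this implies is below; the assumption that counting starts at layer 0 (so layers 0, 3, 6 and 9 are global) is mine, not taken from the model code.

```python
# Illustrative only, not the repo's layer-dispatch code. Shows which layers
# would use global attention vs. a 128-token sliding window if every third
# layer (starting from layer 0) is global.
num_hidden_layers = 12
global_attn_every_n_layers = 3
sliding_window = 128

for layer_idx in range(num_hidden_layers):
    if layer_idx % global_attn_every_n_layers == 0:
        kind = "global attention"
    else:
        kind = f"local attention (window={sliding_window})"
    print(f"layer {layer_idx:2d}: {kind}")
```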
3 | 4 | # Follow the instructions in the README to set up ./my-copy-c4 5 | # Or point data paths to your remote C4 dataset 6 | data_local: ./my-copy-c4 7 | data_remote: # If blank, files must be present in data_local 8 | 9 | max_seq_len: 128 10 | tokenizer_name: bert-base-uncased # switch to bert tokenizer until we add [MASK] token to the llama tokenizer meta-llama/Llama-2-7b-hf 11 | mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance 12 | 13 | # Run Name 14 | run_name: flex-bert-base 15 | 16 | # Model 17 | model: 18 | name: flex_bert 19 | recompute_metric_loss: false # recompute metric loss, use if passing label_smoothing to record non-label-smoothed loss as a metric 20 | pretrained_model_name: ${tokenizer_name} 21 | tokenizer_name: ${tokenizer_name} 22 | # FlexBERT 'base' generally uses the default architecture values for from the Hugging Face BertConfig object 23 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 24 | # the model_config settings match the architecture of the existing model 25 | model_config: 26 | num_attention_heads: 12 # bert-base default 27 | num_hidden_layers: 12 # bert-base default 28 | attention_layer: base 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.0 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: false 36 | final_norm: true 37 | embedding_layer: absolute_pos 38 | loss_function: fa_cross_entropy 39 | loss_kwargs: 40 | reduction: mean 41 | mlp_dropout_prob: 0.0 42 | mlp_in_bias: false 43 | mlp_layer: mlp 44 | mlp_out_bias: false 45 | norm_kwargs: 46 | eps: 1e-6 47 | normalization: rmsnorm 48 | padding: unpadded 49 | sparse_prediction: false 50 | hidden_act: gelu 51 | init_method: full_megatron 52 | init_std: 0.02 53 | init_cutoff_factor: 2.0 54 | init_small_embedding: False 55 | deterministic_fa2: false 56 | initial_attention_layer: null 57 | initial_bert_layer: null 58 | initial_mlp_layer: null 59 | num_initial_layers: 0 60 | skip_first_prenorm: true 61 | sliding_window: 128 62 | global_attn_every_n_layers: 3 63 | unpad_embeddings: true 64 | pad_logits: false 65 | 66 | # Dataloaders 67 | train_loader: 68 | name: text 69 | dataset: 70 | local: ${data_local} 71 | remote: ${data_remote} 72 | split: train 73 | tokenizer_name: ${tokenizer_name} 74 | max_seq_len: ${max_seq_len} 75 | shuffle: true 76 | mlm_probability: ${mlm_probability} 77 | drop_last: true 78 | num_workers: 8 79 | 80 | eval_loader: 81 | name: text 82 | dataset: 83 | local: ${data_local} 84 | remote: ${data_remote} 85 | split: val 86 | tokenizer_name: ${tokenizer_name} 87 | max_seq_len: ${max_seq_len} 88 | shuffle: false 89 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 90 | drop_last: false 91 | num_workers: 8 92 | 93 | # Optimization 94 | scheduler: 95 | name: linear_decay_with_warmup 96 | t_warmup: 0.06dur # Warmup to the full LR for 6% of the training duration 97 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 98 | 99 | optimizer: 100 | name: decoupled_adamw 101 | lr: 5.0e-4 # Peak learning rate 102 | betas: 103 | - 0.9 104 | - 0.98 105 | eps: 1.0e-06 106 | weight_decay: 1.0e-5 # Amount of weight decay regularization 107 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 108 | 109 | # algorithms: 110 | 111 | max_duration: 286720000sp # Subsample the training data for ~275M samples 112 | 
eval_interval: 2000ba 113 | global_train_batch_size: 4096 114 | 115 | # System 116 | seed: 17 117 | device_train_microbatch_size: 128 118 | # device_train_microbatch_size: auto 119 | precision: amp_bf16 120 | 121 | global_eval_batch_size: 256 122 | device_eval_microbatch_size: 64 123 | 124 | # Logging 125 | progress_bar: false 126 | log_to_console: true 127 | console_log_interval: 1ba 128 | 129 | callbacks: 130 | speed_monitor: 131 | window_size: 500 132 | lr_monitor: {} 133 | 134 | algorithms: 135 | gradient_clipping: 136 | clipping_type: norm 137 | clipping_threshold: 1.0 138 | 139 | # (Optional) W&B logging 140 | # loggers: 141 | # wandb: 142 | # project: # Fill this in 143 | # entity: # Fill this in 144 | 145 | # (Optional) Checkpoint to local filesystem or remote object store 146 | # save_interval: 3500ba 147 | # save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK 148 | # save_folder: # e.g. './{run_name}/ckpt' (local) or 's3://mybucket/mydir/{run_name}/ckpt' (remote) 149 | 150 | # (Optional) Load from local filesystem or remote object store to 151 | # start from an existing model checkpoint; 152 | # e.g. './ckpt/latest-rank{rank}.pt' (local), or 153 | # 's3://mybucket/mydir/ckpt/latest-rank{rank}.pt' (remote) 154 | # load_path: null 155 | -------------------------------------------------------------------------------- /yamls/main/flex-bert-rope-base.yaml: -------------------------------------------------------------------------------- 1 | # Note that some of the fields in this template haven't been filled in yet. 2 | # Please resolve any `null` fields before launching! 3 | 4 | # Follow the instructions in the README to set up ./my-copy-c4 5 | # Or point data paths to your remote C4 dataset 6 | data_local: ./my-copy-c4 7 | data_remote: # If blank, files must be present in data_local 8 | 9 | max_seq_len: 512 10 | tokenizer_name: bert-base-uncased # switch to bert tokenizer until we add [MASK] token to the llama tokenizer meta-llama/Llama-2-7b-hf 11 | mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance 12 | 13 | # Run Name 14 | run_name: flex-bert-rope-base 15 | 16 | # Model 17 | model: 18 | name: flex_bert 19 | recompute_metric_loss: false # recompute metric loss, use if passing label_smoothing to record non-label-smoothed loss as a metric 20 | pretrained_model_name: ${tokenizer_name} 21 | tokenizer_name: ${tokenizer_name} 22 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 23 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 24 | # the model_config settings match the architecture of the existing model 25 | model_config: 26 | num_attention_heads: 12 # bert-base default 27 | num_hidden_layers: 12 # bert-base default 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.0 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: false 36 | final_norm: true 37 | embedding_layer: sans_pos 38 | loss_function: fa_cross_entropy 39 | loss_kwargs: 40 | reduction: mean 41 | mlp_dropout_prob: 0.0 42 | mlp_in_bias: false 43 | mlp_layer: mlp 44 | mlp_out_bias: false 45 | normalization: rmsnorm 46 | norm_kwargs: 47 | eps: 1e-6 48 | padding: unpadded 49 | sparse_prediction: false 50 | rotary_emb_dim: null # will be set to headdim by default 51 | rotary_emb_base: 10000.0 52 | rotary_emb_scale_base: null 53 | 
rotary_emb_interleaved: false 54 | hidden_act: gelu 55 | init_method: full_megatron 56 | init_std: 0.02 57 | init_cutoff_factor: 2.0 58 | init_small_embedding: False 59 | deterministic_fa2: false 60 | initial_attention_layer: null 61 | initial_bert_layer: null 62 | initial_mlp_layer: null 63 | num_initial_layers: 0 64 | skip_first_prenorm: true 65 | sliding_window: 128 66 | global_attn_every_n_layers: 3 67 | unpad_embeddings: true 68 | pad_logits: false 69 | 70 | # Dataloaders 71 | train_loader: 72 | name: text 73 | dataset: 74 | local: ${data_local} 75 | remote: ${data_remote} 76 | split: train 77 | tokenizer_name: ${tokenizer_name} 78 | max_seq_len: ${max_seq_len} 79 | shuffle: true 80 | mlm_probability: ${mlm_probability} 81 | drop_last: true 82 | num_workers: 8 83 | 84 | eval_loader: 85 | name: text 86 | dataset: 87 | local: ${data_local} 88 | remote: ${data_remote} 89 | split: val 90 | tokenizer_name: ${tokenizer_name} 91 | max_seq_len: ${max_seq_len} 92 | shuffle: false 93 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 94 | drop_last: false 95 | num_workers: 8 96 | 97 | # Optimization 98 | scheduler: 99 | name: linear_decay_with_warmup 100 | t_warmup: 0.06dur # Warmup to the full LR for 6% of the training duration 101 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 102 | 103 | optimizer: 104 | name: decoupled_adamw 105 | lr: 5.0e-4 # Peak learning rate 106 | betas: 107 | - 0.9 108 | - 0.98 109 | eps: 1.0e-06 110 | weight_decay: 1.0e-5 # Amount of weight decay regularization 111 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 112 | 113 | # algorithms: 114 | 115 | max_duration: 286720000sp # Subsample the training data for ~275M samples 116 | eval_interval: 2000ba 117 | global_train_batch_size: 4096 118 | 119 | # System 120 | seed: 17 121 | device_train_microbatch_size: 128 122 | # device_train_microbatch_size: auto 123 | precision: amp_bf16 124 | 125 | global_eval_batch_size: 256 126 | device_eval_microbatch_size: 64 127 | 128 | # Logging 129 | progress_bar: false 130 | log_to_console: true 131 | console_log_interval: 1ba 132 | 133 | algorithms: 134 | gradient_clipping: 135 | clipping_type: norm 136 | clipping_threshold: 1.0 137 | 138 | callbacks: 139 | speed_monitor: 140 | window_size: 10 141 | lr_monitor: {} 142 | 143 | # (Optional) W&B logging 144 | # loggers: 145 | # wandb: 146 | # project: # Fill this in 147 | # entity: # Fill this in 148 | 149 | # (Optional) Checkpoint to local filesystem or remote object store 150 | # save_interval: 3500ba 151 | # save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK 152 | # save_folder: # e.g. './{run_name}/ckpt' (local) or 's3://mybucket/mydir/{run_name}/ckpt' (remote) 153 | 154 | # (Optional) Load from local filesystem or remote object store to 155 | # start from an existing model checkpoint; 156 | # e.g. './ckpt/latest-rank{rank}.pt' (local), or 157 | # 's3://mybucket/mydir/ckpt/latest-rank{rank}.pt' (remote) 158 | # load_path: null 159 | -------------------------------------------------------------------------------- /yamls/main/flex-bert-rope-parallel-firstprenorm.yaml: -------------------------------------------------------------------------------- 1 | # Note that some of the fields in this template haven't been filled in yet. 2 | # Please resolve any `null` fields before launching! 
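The RoPE configs here leave `rotary_emb_dim` as `null`, which the config comment says defaults to the attention head dimension, with `rotary_emb_base: 10000.0` and non-interleaved rotation. Below is a short sketch of the frequencies this implies, assuming the bert-base default `hidden_size` of 768 (the configs only override the head and layer counts).

```python
# Sketch of the rotary-embedding frequencies implied by the RoPE configs above.
# Assumes bert-base defaults: hidden_size=768, so head_dim = 768 / 12 = 64.
import torch

hidden_size, num_attention_heads = 768, 12
head_dim = hidden_size // num_attention_heads  # 64; rotary_emb_dim defaults to this
rotary_emb_base = 10000.0

# Standard non-interleaved RoPE: one inverse frequency per pair of head dimensions.
inv_freq = 1.0 / (rotary_emb_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
print(inv_freq.shape)  # torch.Size([32])
print(inv_freq[:3])    # the first components rotate fastest with position
```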
3 | 4 | # Follow the instructions in the README to set up ./my-copy-c4 5 | # Or point data paths to your remote C4 dataset 6 | data_local: ./my-copy-c4 7 | data_remote: # If blank, files must be present in data_local 8 | 9 | max_seq_len: 512 10 | tokenizer_name: bert-base-uncased # switch to bert tokenizer until we add [MASK] token to the llama tokenizer meta-llama/Llama-2-7b-hf 11 | mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance 12 | 13 | # Run Name 14 | run_name: flex-bert-rope-parallel-firstprenorm 15 | 16 | # Model 17 | model: 18 | name: flex_bert 19 | recompute_metric_loss: false # recompute metric loss, use if passing label_smoothing to record non-label-smoothed loss as a metric 20 | pretrained_model_name: ${tokenizer_name} 21 | tokenizer_name: ${tokenizer_name} 22 | # FlexBERT 'base' generally uses the default architecture values for from the Hugging Face BertConfig object 23 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 24 | # the model_config settings match the architecture of the existing model 25 | model_config: 26 | num_attention_heads: 12 # bert-base default 27 | num_hidden_layers: 12 # bert-base default 28 | attention_layer: rope_parallel 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.0 32 | attn_qkv_bias: false 33 | bert_layer: parallel_prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | embedding_layer: sans_pos 38 | loss_function: fa_cross_entropy 39 | loss_kwargs: 40 | reduction: mean 41 | mlp_dropout_prob: 0.0 42 | mlp_in_bias: false 43 | mlp_layer: parallel_glu 44 | mlp_out_bias: false 45 | normalization: rmsnorm 46 | norm_kwargs: 47 | eps: 1e-6 48 | padding: unpadded 49 | sparse_prediction: false 50 | rotary_emb_dim: null # will be set to headdim by default 51 | rotary_emb_base: 10000.0 52 | rotary_emb_scale_base: null 53 | rotary_emb_interleaved: false 54 | hidden_act: gelu 55 | init_method: full_megatron 56 | init_std: 0.02 57 | init_cutoff_factor: 2.0 58 | init_small_embedding: False 59 | deterministic_fa2: false 60 | initial_attention_layer: null 61 | initial_bert_layer: null 62 | initial_mlp_layer: null 63 | num_initial_layers: 0 64 | skip_first_prenorm: true 65 | sliding_window: 128 66 | global_attn_every_n_layers: 3 67 | unpad_embeddings: true 68 | pad_logits: false 69 | 70 | # Dataloaders 71 | train_loader: 72 | name: text 73 | dataset: 74 | local: ${data_local} 75 | remote: ${data_remote} 76 | split: train_small 77 | tokenizer_name: ${tokenizer_name} 78 | max_seq_len: ${max_seq_len} 79 | shuffle: true 80 | mlm_probability: ${mlm_probability} 81 | drop_last: true 82 | num_workers: 8 83 | 84 | eval_loader: 85 | name: text 86 | dataset: 87 | local: ${data_local} 88 | remote: ${data_remote} 89 | split: val 90 | tokenizer_name: ${tokenizer_name} 91 | max_seq_len: ${max_seq_len} 92 | shuffle: false 93 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 94 | drop_last: false 95 | num_workers: 8 96 | 97 | # Optimization 98 | scheduler: 99 | name: linear_decay_with_warmup 100 | t_warmup: 0.06dur # Warmup to the full LR for 6% of the training duration 101 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 102 | 103 | optimizer: 104 | name: decoupled_adamw 105 | lr: 5.0e-4 # Peak learning rate 106 | betas: 107 | - 0.9 108 | - 0.98 109 | eps: 1.0e-06 110 | weight_decay: 1.0e-5 # Amount of weight decay regularization 111 | 
filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 112 | 113 | # algorithms: 114 | 115 | max_duration: 286720000sp # Subsample the training data for ~275M samples 116 | eval_interval: 2000ba 117 | global_train_batch_size: 4096 118 | 119 | # System 120 | seed: 17 121 | device_train_microbatch_size: 128 122 | # device_train_microbatch_size: auto 123 | precision: amp_bf16 124 | 125 | global_eval_batch_size: 256 126 | device_eval_microbatch_size: 64 127 | 128 | # Logging 129 | progress_bar: false 130 | log_to_console: true 131 | console_log_interval: 1ba 132 | 133 | callbacks: 134 | speed_monitor: 135 | window_size: 10 136 | lr_monitor: {} 137 | 138 | algorithms: 139 | gradient_clipping: 140 | clipping_type: norm 141 | clipping_threshold: 1.0 142 | 143 | # (Optional) W&B logging 144 | loggers: 145 | wandb: 146 | project: bert24 # Fill this in 147 | # entity: # Fill this in 148 | 149 | # (Optional) Checkpoint to local filesystem or remote object store 150 | # save_interval: 3500ba 151 | # save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK 152 | # save_folder: # e.g. './{run_name}/ckpt' (local) or 's3://mybucket/mydir/{run_name}/ckpt' (remote) 153 | 154 | # (Optional) Load from local filesystem or remote object store to 155 | # start from an existing model checkpoint; 156 | # e.g. './ckpt/latest-rank{rank}.pt' (local), or 157 | # 's3://mybucket/mydir/ckpt/latest-rank{rank}.pt' (remote) 158 | # load_path: null 159 | -------------------------------------------------------------------------------- /yamls/main/hf-bert-base-uncased.yaml: -------------------------------------------------------------------------------- 1 | # Note that some of the fields in this template haven't been filled in yet. 2 | # Please resolve any `null` fields before launching! 3 | 4 | # Follow the instructions in the README to set up ./my-copy-c4 5 | # Or point data paths to your remote C4 dataset 6 | data_local: ./my-copy-c4 7 | data_remote: # If blank, files must be present in data_local 8 | 9 | max_seq_len: 128 10 | tokenizer_name: bert-base-uncased 11 | mlm_probability: 0.15 12 | 13 | # Run Name 14 | run_name: hf-bert-base-uncased 15 | 16 | # Model 17 | model: 18 | name: hf_bert 19 | use_pretrained: false # Train the model from scratch. Set to true to start from the HF off-the-shelf weights. 20 | pretrained_model_name: ${tokenizer_name} 21 | tokenizer_name: ${tokenizer_name} 22 | # This implementation generally uses the default architecture values from the Hugging Face BertConfig object 23 | # These values can be changed here when pretraining from scratch.
Note that these should only be used 24 | # if use_pretrained: false, otherwise the model will not be loaded properly 25 | model_config: 26 | num_attention_heads: 12 # bert-base default 27 | num_hidden_layers: 12 # bert-base default 28 | max_position_embeddings: 512 29 | attention_probs_dropout_prob: 0.1 # bert-base default 30 | 31 | # Dataloaders 32 | train_loader: 33 | name: text 34 | dataset: 35 | local: ${data_local} 36 | remote: ${data_remote} 37 | split: train 38 | tokenizer_name: ${tokenizer_name} 39 | max_seq_len: ${max_seq_len} 40 | shuffle: true 41 | mlm_probability: ${mlm_probability} 42 | drop_last: true 43 | num_workers: 8 44 | 45 | eval_loader: 46 | name: text 47 | dataset: 48 | local: ${data_local} 49 | remote: ${data_remote} 50 | split: val 51 | tokenizer_name: ${tokenizer_name} 52 | max_seq_len: ${max_seq_len} 53 | shuffle: false 54 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 55 | drop_last: false 56 | num_workers: 8 57 | 58 | # Optimization 59 | scheduler: 60 | name: linear_decay_with_warmup 61 | t_warmup: 0.06dur # Warmup to the full LR for 6% of the training duration 62 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 63 | 64 | optimizer: 65 | name: decoupled_adamw 66 | lr: 5.0e-4 # Peak learning rate 67 | betas: 68 | - 0.9 69 | - 0.98 70 | eps: 1.0e-06 71 | weight_decay: 1.0e-5 # Amount of weight decay regularization 72 | filter_bias_norm_wd: false # If True, doesn't apply weight decay to norm layers and biases 73 | 74 | max_duration: 286720000sp # Subsample the training data for ~275M samples 75 | eval_interval: 2000ba 76 | global_train_batch_size: 4096 77 | 78 | # System 79 | seed: 17 80 | device_train_microbatch_size: 128 81 | # device_train_microbatch_size: auto 82 | precision: amp_bf16 83 | 84 | global_eval_batch_size: 256 85 | device_eval_microbatch_size: 64 86 | 87 | # Logging 88 | progress_bar: false 89 | log_to_console: true 90 | console_log_interval: 1ba 91 | 92 | algorithms: 93 | gradient_clipping: 94 | clipping_type: norm 95 | clipping_threshold: 1.0 96 | 97 | callbacks: 98 | speed_monitor: 99 | window_size: 500 100 | lr_monitor: {} 101 | 102 | # (Optional) W&B logging 103 | # loggers: 104 | # wandb: 105 | # project: # Fill this in 106 | # entity: # Fill this in 107 | 108 | # (Optional) Checkpoint to local filesystem or remote object store 109 | # save_interval: 3500ba 110 | # save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK 111 | # save_folder: # e.g. './{run_name}/ckpt' (local) or 's3://mybucket/mydir/{run_name}/ckpt' (remote) 112 | 113 | # (Optional) Load from local filesystem or remote object store to 114 | # start from an existing model checkpoint; 115 | # e.g. './ckpt/latest-rank{rank}.pt' (local), or 116 | # 's3://mybucket/mydir/ckpt/latest-rank{rank}.pt' (remote) 117 | # load_path: null 118 | -------------------------------------------------------------------------------- /yamls/main/mosaic-bert-base-uncased.yaml: -------------------------------------------------------------------------------- 1 | # Note that some of the fields in this template haven't been filled in yet. 2 | # Please resolve any `null` fields before launching!
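All of the pre-training YAMLs in `yamls/main` specify `max_duration` in samples (`286720000sp`) with a global batch size of 4096. A quick back-of-the-envelope check of what that means in optimizer steps and tokens is below; the values are copied from the configs, and the token count assumes `max_seq_len: 128` as in the file above (the RoPE configs use 512 instead).

```python
# Back-of-the-envelope arithmetic for the pre-training configs above.
max_duration_samples = 286_720_000   # 286720000sp
global_train_batch_size = 4_096
max_seq_len = 128                    # 512 for the RoPE configs

train_steps = max_duration_samples // global_train_batch_size
approx_tokens = max_duration_samples * max_seq_len

print(f"{train_steps:,} optimizer steps")          # 70,000 steps
print(f"~{approx_tokens / 1e9:.1f}B tokens seen")  # ~36.7B tokens
```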
3 | 4 | # Follow the instructions in the README to set up ./my-copy-c4 5 | # Or point data paths to your remote C4 dataset 6 | data_local: ./my-copy-c4 7 | data_remote: # If blank, files must be present in data_local 8 | 9 | max_seq_len: 128 10 | tokenizer_name: bert-base-uncased 11 | mlm_probability: 0.3 # MosaicBERT should use 30% masking for optimal performance 12 | 13 | # Run Name 14 | run_name: mosaic-bert-base-uncased 15 | 16 | # Model 17 | model: 18 | name: mosaic_bert 19 | pretrained_model_name: ${tokenizer_name} 20 | tokenizer_name: ${tokenizer_name} 21 | # MosaicBERT 'base' generally uses the default architecture values for from the Hugging Face BertConfig object 22 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 23 | # the model_config settings match the architecture of the existing model 24 | model_config: 25 | num_attention_heads: 12 # bert-base default 26 | num_hidden_layers: 12 # bert-base default 27 | attention_probs_dropout_prob: 0.0 # This can be non zero with Flash Attention 2 28 | deterministic_fa2: false 29 | 30 | # Dataloaders 31 | train_loader: 32 | name: text 33 | dataset: 34 | local: ${data_local} 35 | remote: ${data_remote} 36 | split: train 37 | tokenizer_name: ${tokenizer_name} 38 | max_seq_len: ${max_seq_len} 39 | shuffle: true 40 | mlm_probability: ${mlm_probability} 41 | drop_last: true 42 | num_workers: 8 43 | 44 | eval_loader: 45 | name: text 46 | dataset: 47 | local: ${data_local} 48 | remote: ${data_remote} 49 | split: val 50 | tokenizer_name: ${tokenizer_name} 51 | max_seq_len: ${max_seq_len} 52 | shuffle: false 53 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 54 | drop_last: false 55 | num_workers: 8 56 | 57 | # Optimization 58 | scheduler: 59 | name: linear_decay_with_warmup 60 | t_warmup: 0.06dur # Warmup to the full LR for 6% of the training duration 61 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 62 | 63 | optimizer: 64 | name: decoupled_adamw 65 | lr: 5.0e-4 # Peak learning rate 66 | betas: 67 | - 0.9 68 | - 0.98 69 | eps: 1.0e-06 70 | weight_decay: 1.0e-5 # Amount of weight decay regularization 71 | filter_bias_norm_wd: false # If True, doesn't apply weight decay to norm layers and biases 72 | 73 | # algorithms: 74 | 75 | max_duration: 286720000sp # Subsample the training data for ~275M samples 76 | eval_interval: 2000ba 77 | global_train_batch_size: 4096 78 | 79 | # System 80 | seed: 17 81 | device_train_microbatch_size: 128 82 | # device_train_microbatch_size: auto 83 | precision: amp_bf16 84 | 85 | global_eval_batch_size: 256 86 | device_eval_microbatch_size: 64 87 | 88 | # Logging 89 | progress_bar: false 90 | log_to_console: true 91 | console_log_interval: 1ba 92 | 93 | callbacks: 94 | speed_monitor: 95 | window_size: 500 96 | lr_monitor: {} 97 | 98 | algorithms: 99 | gradient_clipping: 100 | clipping_type: norm 101 | clipping_threshold: 1.0 102 | 103 | # (Optional) W&B logging 104 | # loggers: 105 | # wandb: 106 | # project: # Fill this in 107 | # entity: # Fill this in 108 | 109 | # (Optional) Checkpoint to local filesystem or remote object store 110 | # save_interval: 3500ba 111 | # save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK 112 | # save_folder: # e.g. 
'./{run_name}/ckpt' (local) or 's3://mybucket/mydir/{run_name}/ckpt' (remote) 113 | 114 | # (Optional) Load from local filesystem or remote object store to 115 | # start from an existing model checkpoint; 116 | # e.g. './ckpt/latest-rank{rank}.pt' (local), or 117 | # 's3://mybucket/mydir/ckpt/latest-rank{rank}.pt' (remote) 118 | # load_path: null 119 | -------------------------------------------------------------------------------- /yamls/models/flex_bert.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: flex_bert 3 | model_config: 4 | activation_function: silu 5 | attention_layer: base 6 | attention_probs_dropout_prob: 0.0 7 | attn_out_bias: False 8 | attn_out_dropout_prob: 0.0 9 | attn_qkv_bias: False 10 | bert_layer: prenorm 11 | decoder_bias: True 12 | embed_dropout_prob: 0.0 13 | embed_norm: True 14 | final_norm: True 15 | embedding_layer: absolute_pos 16 | encoder_layer: base 17 | loss_function: cross_entropy 18 | loss_kwargs: 19 | reduction: mean 20 | mlp_dropout_prob: 0.0 21 | mlp_in_bias: False 22 | mlp_layer: mlp 23 | mlp_out_bias: False 24 | norm_kwargs: 25 | eps: 1e-6 26 | normalization: rmsnorm 27 | padding: unpadded 28 | head_class_act: silu 29 | head_class_bias: False 30 | head_class_dropout: 0.0 31 | head_class_norm: False 32 | head_pred_act: silu 33 | head_pred_bias: False 34 | head_pred_dropout: 0.0 35 | head_pred_norm: True 36 | hidden_act: silu 37 | pooling_type: mean 38 | use_fa2: True 39 | use_sdpa_attn_mask: False 40 | init_method: default 41 | init_std: 0.02 42 | init_cutoff_factor: 2.0 43 | init_small_embedding: False 44 | initial_attention_layer: null 45 | initial_bert_layer: null 46 | initial_mlp_layer: null 47 | num_initial_layers: 1 48 | skip_first_prenorm: False 49 | deterministic_fa2: False 50 | sliding_window: -1 51 | global_attn_every_n_layers: -1 52 | local_attn_rotary_emb_base: -1 53 | unpad_embeddings: False 54 | pad_logits: False 55 | 56 | -------------------------------------------------------------------------------- /yamls/models/hf_bert.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: hf_bert 3 | model_config: 4 | normalization: layernorm 5 | hidden_act: gelu -------------------------------------------------------------------------------- /yamls/models/mosaic_bert.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: mosaic_bert 3 | model_config: 4 | normalization: layernorm 5 | hidden_act: gelu 6 | position_embed: alibi # TODO: use in model building 7 | mlp: residual_glu # TODO: use in model building -------------------------------------------------------------------------------- /yamls/test/glue.yaml: -------------------------------------------------------------------------------- 1 | # Use this YAML to verify that GLUE fine-tuning works. Runs on CPU or GPUs (if available). 
2 | # From `examples/bert`, run: 3 | # `python glue.py yamls/test/glue.yaml` to run using the HuggingFace BERT 4 | # `python glue.py yamls/test/glue.yaml model.name=mosaic_bert` to run using the Mosaic BERT 5 | 6 | # Whether to run the various GLUE jobs serially or in parallel (must be `false` on CPU) 7 | parallel: false 8 | 9 | # Basic run configuration, additional details will be added to this name for each GLUE task, and each random seed 10 | base_run_name: glue-finetuning-benchmark-test 11 | default_seed: 1111 12 | precision: fp32 13 | 14 | # Tokenizer for dataset creation 15 | tokenizer_name: bert-base-uncased 16 | 17 | # Base model config 18 | model: 19 | name: hf_bert 20 | pretrained_model_name: prajjwal1/bert-tiny 21 | tokenizer_name: ${tokenizer_name} 22 | 23 | 24 | # Loading 25 | starting_checkpoint_load_path: # Start from scratch for the sake of testing 26 | local_pretrain_checkpoint_folder: ./local-bert-checkpoints/ 27 | 28 | # Saving 29 | save_finetune_checkpoint_prefix: ./local-finetune-checkpoints/ # (local) 30 | save_finetune_checkpoint_folder: ${save_finetune_checkpoint_prefix}/${base_run_name} 31 | 32 | # Callbacks 33 | callbacks: 34 | lr_monitor: {} 35 | speed_monitor: {} 36 | 37 | # Scheduler 38 | scheduler: 39 | name: linear_decay_with_warmup 40 | t_warmup: 0.5dur 41 | alpha_f: 0.0 42 | 43 | # Task configuration 44 | tasks: # Only run MNLI and RTE for the sake of testing 45 | mnli: 46 | # Specify any extra task-specific arguments for the trainer here 47 | trainer_kwargs: 48 | # We keep one MNLI checkpoint locally so that we can start finetuning of 49 | # RTE, MRPC and STS-B from the MNLI checkpoint 50 | save_num_checkpoints_to_keep: 1 51 | max_duration: 10ba 52 | eval_subset_num_batches: 10 53 | rte: 54 | trainer_kwargs: 55 | save_num_checkpoints_to_keep: 0 56 | max_duration: 10ba 57 | eval_subset_num_batches: 10 58 | -------------------------------------------------------------------------------- /yamls/test/main.yaml: -------------------------------------------------------------------------------- 1 | # Use this YAML to verify that MLM pre-training works. Runs on CPU or GPUs (if available). 2 | # From `examples/bert`, run: 3 | # `composer main.py yamls/test/main.yaml` to run using the HuggingFace BERT 4 | # `composer main.py yamls/test/main.yaml model.name=mosaic_bert` to run using the Mosaic BERT 5 | 6 | data_remote: # If blank, files must be present in data_local 7 | data_local: ./my-copy-c4 8 | tokenizer_name: prajjwal1/bert-tiny 9 | max_seq_len: 128 10 | mlm_probability: 0.15 11 | 12 | # Run Name 13 | run_name: test 14 | 15 | # Model 16 | model: 17 | name: hf_bert 18 | use_pretrained: false # Train the model from scratch. Set to true to start from the HF off-the-shelf weights. 
19 | pretrained_model_name: ${tokenizer_name} 20 | tokenizer_name: ${tokenizer_name} 21 | 22 | # Dataloaders 23 | train_loader: 24 | name: text 25 | dataset: 26 | remote: ${data_remote} 27 | local: ${data_local} 28 | split: train_small 29 | tokenizer_name: ${tokenizer_name} 30 | max_seq_len: ${max_seq_len} 31 | predownload: 1000 32 | shuffle: true 33 | mlm_probability: ${mlm_probability} 34 | drop_last: true 35 | num_workers: 8 36 | 37 | eval_loader: 38 | name: text 39 | dataset: 40 | remote: ${data_remote} 41 | local: ${data_local} 42 | split: val 43 | tokenizer_name: ${tokenizer_name} 44 | max_seq_len: ${max_seq_len} 45 | predownload: 1000 46 | shuffle: false 47 | mlm_probability: ${mlm_probability} 48 | drop_last: false 49 | num_workers: 8 50 | 51 | # Optimization 52 | scheduler: 53 | name: linear_decay_with_warmup 54 | t_warmup: 0.5dur # Warmup to the full LR for 6% of the training duration 55 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 56 | 57 | optimizer: 58 | name: decoupled_adamw 59 | lr: 2.0e-4 60 | betas: 61 | - 0.9 62 | - 0.95 63 | eps: 1.0e-08 64 | weight_decay: 0.0 65 | 66 | # Training duration and evaluation frequency 67 | max_duration: 10ba 68 | eval_interval: 10ba 69 | eval_subset_num_batches: 20 # For code testing, evaluate on a subset of 20 batches 70 | global_train_batch_size: 16 71 | global_eval_batch_size: 16 72 | 73 | # System 74 | seed: 17 75 | device_eval_microbatch_size: 16 76 | device_train_microbatch_size: 16 77 | precision: fp32 78 | 79 | # Logging 80 | progress_bar: false 81 | log_to_console: true 82 | console_log_interval: 1ba 83 | 84 | callbacks: 85 | speed_monitor: 86 | window_size: 5 87 | lr_monitor: {} 88 | -------------------------------------------------------------------------------- /yamls/test/sequence_classification.yaml: -------------------------------------------------------------------------------- 1 | # Use this YAML to verify that fine-tuning starter script works. Runs on CPU or GPUs (if available). 
2 | # From `examples/bert`, run: 3 | # `composer sequence_classification.py yamls/test/sequence_classification.yaml` (HuggingFace BERT) 4 | # `composer sequence_classification.py yamls/test/sequence_classification.yaml model.name=mosaic_bert` (Mosaic BERT) 5 | 6 | tokenizer_name: prajjwal1/bert-tiny 7 | max_seq_len: 128 8 | 9 | # Run Name 10 | run_name: test 11 | 12 | load_path: # (Optionally) provide a composer checkpoint to use for the starting weights 13 | 14 | # Model 15 | model: 16 | name: mosaic_bert 17 | num_labels: 2 18 | pretrained_model_name: ${tokenizer_name} 19 | tokenizer_name: ${tokenizer_name} 20 | 21 | # Dataloaders 22 | train_loader: 23 | split: train 24 | tokenizer_name: ${tokenizer_name} 25 | max_seq_len: ${max_seq_len} 26 | shuffle: true 27 | drop_last: true 28 | num_workers: 8 29 | 30 | eval_loader: 31 | split: validation 32 | tokenizer_name: ${tokenizer_name} 33 | max_seq_len: ${max_seq_len} 34 | shuffle: true 35 | drop_last: true 36 | num_workers: 8 37 | 38 | # Optimization 39 | scheduler: 40 | name: linear_decay_with_warmup 41 | t_warmup: 0.06dur # Warmup to the full LR for 6% of the training duration 42 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 43 | 44 | optimizer: 45 | name: decoupled_adamw 46 | lr: 1.0e-5 47 | betas: 48 | - 0.9 49 | - 0.98 50 | eps: 1.0e-06 51 | weight_decay: 1.0e-6 52 | 53 | # Training duration and evaluation frequency 54 | max_duration: 10ba 55 | eval_interval: 10ba 56 | eval_subset_num_batches: 4 # For code testing, evaluate on a subset of 4 batches 57 | global_train_batch_size: 16 58 | 59 | # System 60 | seed: 17 61 | device_eval_microbatch_size: 16 62 | device_train_microbatch_size: 16 63 | precision: fp32 64 | 65 | # Logging 66 | progress_bar: false 67 | log_to_console: true 68 | console_log_interval: 1ba 69 | 70 | # Optionally log to W&B 71 | # loggers: 72 | # wandb: {} 73 | 74 | callbacks: 75 | speed_monitor: 76 | window_size: 5 77 | lr_monitor: {} 78 | --------------------------------------------------------------------------------
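Finally, nearly every config in this tree uses the `linear_decay_with_warmup` scheduler with `t_warmup` expressed as a fraction of the run (`0.06dur` for real runs, `0.5dur` in the smoke tests) and a final multiplier `alpha_f`. The sketch below reconstructs the learning-rate multiplier those comments describe; it is an approximation for intuition, not Composer's actual scheduler implementation.

```python
# Approximate LR multiplier for linear_decay_with_warmup, reconstructed from the
# YAML comments (warm up to the peak LR, then decay linearly to alpha_f * peak).
def lr_multiplier(frac_of_training: float, t_warmup: float = 0.06, alpha_f: float = 0.02) -> float:
    if frac_of_training < t_warmup:
        return frac_of_training / t_warmup                 # linear warmup to 1.0
    decay_frac = (frac_of_training - t_warmup) / (1.0 - t_warmup)
    return 1.0 + (alpha_f - 1.0) * decay_frac              # linear decay to alpha_f

for frac in (0.0, 0.03, 0.06, 0.5, 1.0):
    print(f"{frac:.2f} of training -> {lr_multiplier(frac):.3f}x peak LR")
```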