├── LICENSE ├── README.md ├── assets ├── attention_masks.png ├── decoder_results.jpg ├── enc_vs_dec.jpg ├── encoder_results.jpg ├── sizes.jpg ├── sizes_small.png └── training_data.jpg ├── bias_eval ├── README.md ├── batch_eval.py ├── create_plot.py └── eval.py ├── docs ├── decoder-eval.md └── encoder-generative-eval.md ├── glue_evaluation └── README.md ├── pretraining ├── README.md ├── configs │ ├── cross-train │ │ ├── base │ │ │ ├── decoder_base.yaml │ │ │ └── encoder_base.yaml │ │ ├── huge │ │ │ ├── decoder_huge.yaml │ │ │ └── encoder_huge.yaml │ │ ├── large │ │ │ ├── decoder_large.yaml │ │ │ └── encoder_large.yaml │ │ ├── mini │ │ │ ├── decoder_mini.yaml │ │ │ └── encoder_mini.yaml │ │ ├── tiny │ │ │ ├── decoder_tiny.yaml │ │ │ └── encoder_tiny.yaml │ │ └── very_tiny │ │ │ ├── decoder_very_tiny.yaml │ │ │ └── encoder_very_tiny.yaml │ └── standard │ │ ├── base │ │ ├── de_decay │ │ │ ├── decoder_base.yaml │ │ │ └── encoder_base.yaml │ │ ├── decoder_base.yaml │ │ ├── encoder_base.yaml │ │ └── prolong_decay │ │ │ ├── decoder_base.yaml │ │ │ └── encoder_base.yaml │ │ ├── huge │ │ ├── de_decay │ │ │ ├── decoder_huge.yaml │ │ │ └── encoder_huge.yaml │ │ ├── decoder_huge.yaml │ │ ├── encoder_huge.yaml │ │ └── prolong_decay │ │ │ ├── decoder_huge.yaml │ │ │ └── encoder_huge.yaml │ │ ├── large │ │ ├── de_decay │ │ │ ├── decoder_large.yaml │ │ │ └── encoder_large.yaml │ │ ├── decoder_large.yaml │ │ ├── encoder_large.yaml │ │ └── prolong_decay │ │ │ ├── decoder_large.yaml │ │ │ └── encoder_large.yaml │ │ ├── mini │ │ ├── de_decay │ │ │ ├── decoder_mini.yaml │ │ │ └── encoder_mini.yaml │ │ ├── decoder_mini.yaml │ │ ├── encoder_mini.yaml │ │ └── prolong_decay │ │ │ ├── decoder_mini.yaml │ │ │ └── encoder_mini.yaml │ │ ├── tiny │ │ ├── de_decay │ │ │ ├── decoder_tiny.yaml │ │ │ └── encoder_tiny.yaml │ │ ├── decoder_tiny.yaml │ │ ├── encoder_tiny.yaml │ │ └── prolong_decay │ │ │ ├── decoder_tiny.yaml │ │ │ └── encoder_tiny.yaml │ │ └── very_tiny │ │ ├── de_decay │ │ ├── decoder_very_tiny.yaml │ │ └── encoder_very_tiny.yaml │ │ ├── decoder_very_tiny.yaml │ │ ├── encoder_very_tiny.yaml │ │ └── prolong_decay │ │ ├── decoder_very_tiny.yaml │ │ └── encoder_very_tiny.yaml └── data_processing │ ├── README.md │ ├── bin │ ├── chunk_all.sh │ ├── count_instances.py │ ├── datasets.txt │ ├── decompress.py │ ├── dolma_mds_all.sh │ ├── download_all_data.sh │ ├── download_folder.py │ ├── download_from_hub.py │ ├── jsonl_to_mds.py │ ├── make_recursive_root.py │ ├── random_sample_from_all_data.py │ ├── sample_down_or_up_to_meet.py │ ├── sample_for_context_extension_random.py │ ├── sample_too_large_down.py │ ├── tokenize_all.sh │ └── upload_large_folder.py │ ├── requirements.txt │ └── src │ ├── __init__.py │ ├── initial_dataset_creation │ ├── __init__.py │ ├── dolma_to_mds.py │ ├── dolma_urls.py │ ├── hf_to_mds.py │ ├── mds_to_jsonl.py │ └── merge_mds_to_one_index.py │ ├── sampling │ ├── __init__.py │ ├── move_chunks.py │ ├── move_out_final_sampled_chunks.py │ ├── sample_from_chunks.py │ ├── sample_from_chunks_extra_large.py │ ├── sample_from_folders.py │ ├── split_dataset_into_chunks.py │ └── split_dataset_into_chunks_individual.py │ ├── tokenization │ ├── __init__.py │ ├── decode_all.py │ ├── decode_mds.py │ ├── move_tokenized.py │ ├── tokenize_mds.py │ └── tokenize_mds_subfolders.py │ └── utils │ ├── cleanup.py │ ├── cleanup_all.py │ ├── compare_subfolders.py │ ├── compare_train_and_chunked.py │ ├── create_final_dataset_index.py │ ├── data_utils.py │ ├── get_counts_from_hf.py │ └── upload_dataset_by_subfolders.py └── 
retrieval_eval ├── README.md └── train_st.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Johns Hopkins University Center for Language and Speech Processing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /assets/attention_masks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JHU-CLSP/ettin-encoder-vs-decoder/6cbf954a2ecaa7c753be6468a1bccdef401d5f93/assets/attention_masks.png -------------------------------------------------------------------------------- /assets/decoder_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JHU-CLSP/ettin-encoder-vs-decoder/6cbf954a2ecaa7c753be6468a1bccdef401d5f93/assets/decoder_results.jpg -------------------------------------------------------------------------------- /assets/enc_vs_dec.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JHU-CLSP/ettin-encoder-vs-decoder/6cbf954a2ecaa7c753be6468a1bccdef401d5f93/assets/enc_vs_dec.jpg -------------------------------------------------------------------------------- /assets/encoder_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JHU-CLSP/ettin-encoder-vs-decoder/6cbf954a2ecaa7c753be6468a1bccdef401d5f93/assets/encoder_results.jpg -------------------------------------------------------------------------------- /assets/sizes.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JHU-CLSP/ettin-encoder-vs-decoder/6cbf954a2ecaa7c753be6468a1bccdef401d5f93/assets/sizes.jpg -------------------------------------------------------------------------------- /assets/sizes_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JHU-CLSP/ettin-encoder-vs-decoder/6cbf954a2ecaa7c753be6468a1bccdef401d5f93/assets/sizes_small.png -------------------------------------------------------------------------------- /assets/training_data.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JHU-CLSP/ettin-encoder-vs-decoder/6cbf954a2ecaa7c753be6468a1bccdef401d5f93/assets/training_data.jpg -------------------------------------------------------------------------------- /bias_eval/README.md: -------------------------------------------------------------------------------- 1 | # Bias Evaluation Toolkit 2 | 3 | This directory contains tools for evaluating gender bias in language models using the Winogender dataset, specifically focusing on counter-stereotypical examples. 4 | 5 | ## Overview 6 | 7 | The evaluation toolkit tests how well models handle counter-stereotypical gendered pronouns in occupational contexts. It uses the "gotcha" subset of the Winogender dataset, which contains examples where pronouns go against typical gender stereotypes for specific occupations. 8 | 9 | ## Files 10 | 11 | - `eval.py` - Main evaluation script with `WinogenderEvaluator` class 12 | - `batch_eval.py` - Batch evaluation script for multiple models 13 | 14 | ## Key Features 15 | 16 | ### Model Support 17 | - **Encoder Models** (MLM): Masks pronouns and evaluates prediction probabilities 18 | - **Decoder Models** (Causal LM): Compares perplexity across gender variants 19 | 20 | ### Evaluation Metrics 21 | - **Gotcha Preference Rate**: How often the model prefers counter-stereotypical pronouns 22 | - **Gender Distribution**: Breakdown of model predictions by gender 23 | - **Baseline Comparisons**: Against uniform and dataset-based baselines 24 | 25 | ## Usage 26 | 27 | ### Single Model Evaluation 28 | 29 | ```bash 30 | # Evaluate an encoder model 31 | python eval.py --model_name "jhu-clsp/ettin-encoder-17m" --model_type encoder --output_path results.json 32 | 33 | # Evaluate a decoder model 34 | python eval.py --model_name "jhu-clsp/ettin-decoder-17m" --model_type decoder --output_path results.json 35 | ``` 36 | 37 | ### Batch Evaluation 38 | 39 | ```bash 40 | # Run evaluation on all configured models 41 | python batch_eval.py 42 | ``` 43 | 44 | The batch script will: 45 | - Load existing results to avoid re-evaluation 46 | - Evaluate missing models incrementally 47 | - Save results after each model 48 | - Provide progress updates and summary statistics 49 | 50 | ## Evaluation Method 51 | 52 | ### For Encoder Models (MLM) 53 | 1. Takes counter-stereotypical sentences from Winogender "gotcha" dataset 54 | 2. Masks pronouns in each sentence 55 | 3. Evaluates model's pronoun predictions 56 | 4. Classifies predictions by gender (male/female/neutral) 57 | 5. Calculates preference rates for counter-stereotypical pronouns 58 | 59 | ### For Decoder Models (Causal LM) 60 | 1. Takes counter-stereotypical sentences from Winogender "gotcha" dataset 61 | 2. Generates alternative versions with different gender pronouns 62 | 3. Compares perplexity across all gender variants 63 | 4. Determines which gender version the model "prefers" (lowest perplexity) 64 | 5. 
Measures how often the counter-stereotypical version is preferred 65 | 66 | ## Example Output 67 | 68 | ``` 69 | GOTCHA EVALUATION SUMMARY 70 | ======================================== 71 | Model: your-model-name (encoder) 72 | Testing counter-stereotypical vs stereotypical preferences 73 | 74 | 🎯 GOTCHA PREFERENCE RATE: 0.342 75 | (How often model prefers counter-stereotypical pronouns) 76 | Expected baseline: 0.289 (dataset-based) 77 | Uniform baseline: 0.333 (1/3 random) 78 | 79 | Encoder Details: 80 | Total gotcha examples tested: 120 81 | 82 | Dataset gender distribution: 83 | male: 0.400 (40.0%) 84 | female: 0.350 (35.0%) 85 | neutral: 0.250 (25.0%) 86 | 87 | Model prediction distribution: 88 | male: 0.425 (42.5%) 89 | female: 0.308 (30.8%) 90 | neutral: 0.267 (26.7%) 91 | 92 | Gender-specific gotcha preference rates: 93 | male: 0.354 94 | female: 0.310 95 | neutral: 0.367 96 | ``` 97 | 98 | ## Interpretation 99 | 100 | - **Gotcha Preference Rate > 0.33**: Model shows some ability to handle counter-stereotypical cases 101 | - **Rate ≈ baseline**: Model performs similarly to random chance 102 | - **Rate < baseline**: Model shows strong stereotypical bias 103 | - **Gender-specific rates**: Reveals which gender presentations the model handles better 104 | 105 | ## Model Configuration 106 | 107 | The batch evaluation script is configured to test these model families: 108 | - `ettin-encoder` models (17M to 1B parameters) 109 | - `ettin-decoder` models (17M to 1B parameters) 110 | - Cross-trained variants (`enc-from-dec`, `dec-from-enc`) 111 | 112 | To add new models, edit the `models` list in `batch_eval.py`. 113 | 114 | ## Dataset 115 | 116 | Uses the Winogender dataset "gotcha" split from HuggingFace: 117 | - Dataset: `oskarvanderwal/winogender` 118 | - Split: `gotcha` (counter-stereotypical examples) 119 | - Contains sentences with pronouns that go against occupational stereotypes -------------------------------------------------------------------------------- /docs/decoder-eval.md: -------------------------------------------------------------------------------- 1 | # Decoder Evaluation on Generative Tasks 2 | 3 | This guide covers evaluating Ettin decoder models on generative language tasks using the EleutherAI evaluation harness (commit `867413f8677f00f6a817262727cbb041bf36192a`). 4 | 5 | ## Overview 6 | 7 | Ettin decoder models excel at generative tasks and should be evaluated using the standard EleutherAI lm-evaluation-harness. This provides comprehensive evaluation across a wide range of language understanding and generation benchmarks. 8 | 9 | ## Quick Start 10 | 11 | ### Installation 12 | 13 | ```bash 14 | # Clone the specific commit of lm-evaluation-harness 15 | git clone https://github.com/EleutherAI/lm-evaluation-harness.git 16 | cd lm-evaluation-harness 17 | git checkout 867413f8677f00f6a817262727cbb041bf36192a 18 | pip install -e . 
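# Optional sanity check (our suggestion, not part of the original instructions):
# the editable install should put the `lm_eval` CLI used in the commands below on your PATH.
lm_eval --help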
19 | ``` 20 | 21 | ### Basic Evaluation 22 | 23 | ```bash 24 | # Evaluate Ettin decoder on core tasks 25 | lm_eval --model hf \ 26 | --model_args "pretrained=jhu-clsp/ettin-decoder-150m,add_bos_token=True" \ 27 | --tasks hellaswag,arc_easy,arc_challenge,winogrande \ 28 | --device cuda:0 \ 29 | --batch_size 8 \ 30 | --output_path results/ettin-decoder-150m 31 | ``` 32 | 33 | 34 | ## Multi-Model Evaluation Script 35 | 36 | ```bash 37 | #!/bin/bash 38 | # evaluate_all_decoders.sh 39 | 40 | MODELS=( 41 | "jhu-clsp/ettin-decoder-17m" 42 | "jhu-clsp/ettin-decoder-32m" 43 | "jhu-clsp/ettin-decoder-68m" 44 | "jhu-clsp/ettin-decoder-150m" 45 | "jhu-clsp/ettin-decoder-400m" 46 | "jhu-clsp/ettin-decoder-1b" 47 | ) 48 | 49 | TASKS="hellaswag,arc_easy,arc_challenge,winogrande,piqa,boolq" 50 | 51 | for model in "${MODELS[@]}"; do 52 | echo "Evaluating $model..." 53 | output_dir="results/$(basename $model)" 54 | 55 | lm_eval --model hf \ 56 | --model_args "pretrained=$model,add_bos_token=True" \ 57 | --tasks $TASKS \ 58 | --device cuda:0 \ 59 | --batch_size 8 \ 60 | --output_path $output_dir \ 61 | --log_samples 62 | done 63 | ``` 64 | 65 | ## Checkpoint Evaluation 66 | 67 | ```bash 68 | # Evaluate specific training checkpoints 69 | lm_eval --model hf \ 70 | --model_args "pretrained=jhu-clsp/ettin-decoder-400m,revision=step590532,add_bos_token=True" \ 71 | --tasks hellaswag,arc_easy \ 72 | --device cuda:0 \ 73 | --batch_size 8 \ 74 | --output_path results/ettin-decoder-400m-step590532 75 | ``` 76 | 77 | ## Cross-Objective Model Evaluation 78 | Cross-objective models (e.g., `dec-from-enc`) are evaluated in exactly the same way as above, since they are now decoders. 79 | 80 | 81 | ## Links and Resources 82 | 83 | - **Evaluation Harness**: [EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) 84 | - **Specific Commit**: [867413f8677f00f6a817262727cbb041bf36192a](https://github.com/EleutherAI/lm-evaluation-harness/commit/867413f8677f00f6a817262727cbb041bf36192a) 85 | - **Model Collection**: [jhu-clsp on HuggingFace](https://huggingface.co/jhu-clsp) 86 | - **Documentation**: [lm-eval docs](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/docs) 87 | 88 | --- 89 | 90 | For issues with decoder evaluation, please refer to the [EleutherAI evaluation harness documentation](https://github.com/EleutherAI/lm-evaluation-harness) or open an issue in the [Ettin repository](https://github.com/jhu-clsp/ettin-encoder-vs-decoder). -------------------------------------------------------------------------------- /glue_evaluation/README.md: -------------------------------------------------------------------------------- 1 | # GLUE Evaluation 2 | 3 | This directory contains scripts and documentation for evaluating Ettin encoder and decoder models on the GLUE (General Language Understanding Evaluation) benchmark tasks. 4 | 5 | ## Overview 6 | 7 | GLUE is a collection of nine English sentence understanding tasks, designed to evaluate and analyze the general language understanding capabilities of language models. 8 | 9 | ## Quick Start 10 | 11 | ### Installation 12 | 13 | ```bash 14 | # Install required dependencies 15 | pip install transformers datasets evaluate scikit-learn 16 | 17 | # For hyperparameter sweeps (optional) 18 | pip install wandb optuna 19 | 20 | # Clone training repository for advanced configs 21 | git clone https://github.com/orionw/bert24.git 22 | cd bert24 23 | pip install -e . 
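# Optional sanity check (our suggestion, not from the original instructions):
# confirm the core libraries import cleanly after installation.
python -c "import transformers, datasets, evaluate, sklearn; print(transformers.__version__)"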
24 | ``` 25 | 26 | ### Quick Evaluation Example 27 | 28 | TODO 29 | 30 | ## Links and Resources 31 | 32 | - **📊 GLUE Benchmark**: [https://gluebenchmark.com/](https://gluebenchmark.com/) 33 | - **📖 GLUE Paper**: [https://arxiv.org/abs/1804.07461](https://arxiv.org/abs/1804.07461) 34 | - **🤗 HuggingFace GLUE**: [https://huggingface.co/datasets/glue](https://huggingface.co/datasets/glue) 35 | - **📈 Papers With Code**: [https://paperswithcode.com/dataset/glue](https://paperswithcode.com/dataset/glue) 36 | - **📖 Training Repository**: [https://github.com/orionw/bert24](https://github.com/orionw/bert24) 37 | 38 | --- 39 | 40 | For questions about GLUE evaluation, please open an issue in the main [Ettin repository](https://github.com/jhu-clsp/ettin-encoder-vs-decoder/issues). -------------------------------------------------------------------------------- /pretraining/README.md: -------------------------------------------------------------------------------- 1 | # Pre-training Guide 2 | 3 | This guide covers the complete pre-training process for Ettin models, including data preparation, training setup, and the adapted ModernBERT recipe used for both encoder and decoder models. 4 | 5 | ## Overview 6 | 7 | Ettin models are trained using a three-phase approach adapted from the ModernBERT training recipe, with identical data and procedures for both encoder and decoder models to enable fair architectural comparisons. 8 | 9 | ## Training Repository 10 | 11 | **📖 Training code**: [https://github.com/orionw/bert24](https://github.com/orionw/bert24) 12 | 13 | That repository is a fork of the ModernBERT training codebase, extended to support decoder models and their training objectives. Clone it and run the training commands from there. 14 | 15 | ## Training Phases 16 | 17 | ### Phase 1: Pre-training (1.7T tokens) 18 | - **Duration**: ~600k steps 19 | - **Data**: Diverse mixture including web text, books, code, and scientific papers 20 | - **Context Length**: 1024 tokens initially, gradually increased 21 | - **Learning Rate**: Warmed up to the peak LR, then held constant 22 | 23 | ### Phase 2: Mid-training/Extension (250B tokens) 24 | - **Duration**: ~100k steps 25 | - **Data**: Higher-quality filtered subset with domain balancing 26 | - **Context Length**: Extended to 8k tokens 27 | - **Learning Rate**: Decayed to half of the Phase 1 LR 28 | 29 | ### Phase 3: Decay Phase (50B tokens) 30 | - **Duration**: ~20k steps 31 | - **Data**: Premium sources (books, academic papers, curated web content) 32 | - **Context Length**: Maintained at 8k tokens 33 | - **Learning Rate**: A further decay to 0.02x the peak LR 34 | 35 | 36 | ## Data Preprocessing 37 | You can use the existing data available on Hugging Face or create your own. The data must be in MosaicML `streaming` (MDS) format. For the (messy) data preprocessing scripts, see the README in the [Data Processing Guide](data_processing/README.md). 
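For example, if you want to use the prepared pretraining data rather than building your own, a download along these lines should work (a sketch: it assumes `huggingface_hub` is installed so the `huggingface-cli` tool is available, and the local directory name is arbitrary):

```bash
# Fetch the prepared MDS-format pretraining data from the Hub
huggingface-cli download jhu-clsp/ettin-pretraining-data \
  --repo-type dataset \
  --local-dir ./ettin-pretraining-data

# Point `data_local` in your training YAML at the downloaded folder
```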
38 | 39 | ## Model Configurations 40 | 41 | ### Architecture Scaling 42 | 43 | The repository includes configurations for all Ettin model sizes: 44 | 45 | | Model Size | Config Directory | Layers | Hidden Size | Intermediate Size | Attention Heads | 46 | |:-----------|:-----------------|:-------|:------------|:------------------|:----------------| 47 | | 17M | `configs/standard/very_tiny/` | 7 | 256 | 384 | 4 | 48 | | 32M | `configs/standard/tiny/` | 10 | 384 | 576 | 6 | 49 | | 68M | `configs/standard/mini/` | 19 | 512 | 768 | 8 | 50 | | 150M | `configs/standard/base/` | 22 | 768 | 1152 | 12 | 51 | | 400M | `configs/standard/large/` | 28 | 1024 | 2624 | 16 | 52 | | 1B | `configs/standard/huge/` | 28 | 1792 | 3840 | 28 | 53 | 54 | 55 | ## Quick Start 56 | 57 | ### Setup Environment 58 | 59 | ```bash 60 | # Clone the training repository 61 | git clone https://github.com/orionw/bert24.git 62 | cd bert24 63 | 64 | # Install dependencies 65 | pip install -r requirements.txt 66 | pip install -e . 67 | ``` 68 | 69 | ### Data Preparation 70 | 71 | See the [Data Processing Guide](data_processing/README.md) for detailed preprocessing instructions, or download the data from Hugging Face, e.g. the [pretraining data here](https://huggingface.co/datasets/jhu-clsp/ettin-pretraining-data). 72 | 73 | ### Training Commands 74 | The training command automatically infers the number of available GPUs. 75 | 76 | `composer main.py $yaml_config_file` 77 | 78 | You can run the cross-objective versions with the corresponding configs, which are identical except that they load from the opposite-objective checkpoint; see `configs/cross-train` for examples. A concrete launch example is given under Hardware Requirements below. 79 | 80 | ### Decoder → Encoder Conversion 81 | There are a few changes to make: (1) change the tokenizer to behave like an encoder's, (2) change the model class to ModernBERT, and (3) re-combine the qkv layer. We have some messy scripts for this and will upload them soon; if you need them sooner, please open an issue or message us! 82 | 83 | 84 | ## Hardware Requirements 85 | All models are trained on 4x H100s. 
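For reference, a single-node launch of the command from the Training Commands section looks roughly like this; the config path below is illustrative (point it at whichever YAML you want from this repository's `pretraining/configs`), and composer will launch one process per visible GPU:

```bash
# Illustrative launch from inside the bert24 checkout; the YAML path is an example only.
composer main.py /path/to/ettin-encoder-vs-decoder/pretraining/configs/standard/base/encoder_base.yaml

# To pin the run to a specific set of GPUs (e.g., a 4x H100 node), restrict visibility:
CUDA_VISIBLE_DEVICES=0,1,2,3 composer main.py /path/to/ettin-encoder-vs-decoder/pretraining/configs/standard/base/encoder_base.yaml
```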
Training time is approximately: 86 | 87 | - 1B: 2170 hours to do 2T (~90 days, we did only ~40) 88 | - 400M: 950 hours (~40 days) 89 | - 150M: 470 hours (~20 days) 90 | - 68M: 300 hours (~13 days) 91 | - 32M: 212 hours (~9 days) 92 | - 17M: 141 hours (~6 days) 93 | 94 | ## Links and Resources 95 | 96 | - **📖 Training Repository**: [https://github.com/orionw/bert24](https://github.com/orionw/bert24) 97 | - **📊 Training Data**: [HuggingFace Datasets](https://huggingface.co/datasets/jhu-clsp) 98 | - **🔧 Model Configs**: [Configuration Files](./pretraining/configs) 99 | 100 | -- -------------------------------------------------------------------------------- /pretraining/configs/cross-train/base/decoder_base.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.15 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: base_decoder_TO_encoder 11 | 12 | # Model 13 | model: 14 | name: flex_gpt_to_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 22 25 | hidden_size: 768 26 | intermediate_size: 1152 27 | num_attention_heads: 12 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | local_attn_rotary_emb_base: 160000.0 57 | rotary_emb_scale_base: null 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: true 65 | 66 | # Dataloaders 67 | train_loader: 68 | name: text 69 | dataset: 70 | local: ${data_local} 71 | remote: ${data_remote} 72 | split: train 73 | tokenizer_name: ${tokenizer_name} 74 | max_seq_len: ${max_seq_len} 75 | shuffle: true 76 | mlm_probability: ${mlm_probability} 77 | streaming: false 78 | shuffle_seed: 21 79 | drop_last: true 80 | num_workers: 6 81 | sequence_packing: false 82 | batch_size_warmup_min_size: ${device_train_microbatch_size} 83 | batch_size_warmup_tokens: 10_000_000_000tok 84 | 85 | 86 | # Optimization 87 | scheduler: 88 | name: warmup_stable_decay 89 | t_warmup: 3_000_000_000tok # Warmup to the full LR for 6% of the training duration 90 | alpha_f: 0.02 # Linearly 
decay to 0.02x the full LR by the end of the training duration 91 | t_decay: 10_000_000_000tok 92 | 93 | 94 | optimizer: 95 | name: decoupled_stableadamw 96 | lr: 4e-4 # Peak learning rate 97 | betas: 98 | - 0.9 99 | - 0.98 100 | eps: 1.0e-06 101 | weight_decay: 1.0e-5 # Amount of weight decay regularization 102 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 103 | log_grad_norm: true 104 | 105 | max_duration: 50_000_000_000tok 106 | eval_interval: 99999999999999999999999999999999ba # don't do this 107 | global_train_batch_size: 4608 108 | global_eval_batch_size: 1024 109 | 110 | # System 111 | seed: 210 112 | device_eval_batch_size: 12 113 | device_train_microbatch_size: 12 114 | 115 | precision: amp_bf16 116 | 117 | # Logging 118 | progress_bar: true 119 | log_to_console: true 120 | console_log_interval: 100ba 121 | 122 | callbacks: 123 | speed_monitor: 124 | window_size: 100 125 | lr_monitor: {} 126 | scheduled_gc: {} 127 | log_grad_norm: 128 | batch_log_interval: 100 129 | packing_efficiency: 130 | log_interval: 100 131 | 132 | # W&B logging 133 | loggers: 134 | wandb: 135 | project: encoder_vs_decoder-cross 136 | entity: your_wandb_entity 137 | 138 | # Checkpoint to local filesystem or remote object store 139 | save_interval: 10_000_000_000tok 140 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 141 | save_folder: your_save_folder 142 | 143 | 144 | # Load from local filesystem or remote object store to 145 | load_path: your_load_path 146 | reset_time: true 147 | load_weights_only: true 148 | 149 | -------------------------------------------------------------------------------- /pretraining/configs/cross-train/base/encoder_base.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.0 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: base_encoder_TO_decoder 11 | 12 | # Model 13 | model: 14 | name: flex_bert_to_gpt 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 22 25 | hidden_size: 768 26 | intermediate_size: 1152 27 | num_attention_heads: 12 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | 
rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | local_attn_rotary_emb_base: 160000.0 57 | rotary_emb_scale_base: null 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: false 65 | causal_mask: true 66 | pad_logits: true 67 | 68 | # Dataloaders 69 | train_loader: 70 | name: text 71 | dataset: 72 | local: ${data_local} 73 | remote: ${data_remote} 74 | split: train 75 | tokenizer_name: ${tokenizer_name} 76 | max_seq_len: ${max_seq_len} 77 | shuffle: true 78 | mlm_probability: ${mlm_probability} 79 | streaming: false 80 | use_decoder_attn_mask: true 81 | shuffle_seed: 21 82 | suppress_masking: true 83 | drop_last: true 84 | num_workers: 6 85 | sequence_packing: false 86 | batch_size_warmup_min_size: ${device_train_microbatch_size} 87 | batch_size_warmup_tokens: 10_000_000_000tok 88 | 89 | 90 | 91 | # Optimization 92 | scheduler: 93 | name: warmup_stable_decay 94 | t_warmup: 3_000_000_000tok # Warmup to the full LR for 6% of the training duration 95 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 96 | t_decay: 10_000_000_000tok 97 | 98 | 99 | optimizer: 100 | name: decoupled_stableadamw 101 | lr: 4e-4 # Peak learning rate 102 | betas: 103 | - 0.9 104 | - 0.98 105 | eps: 1.0e-06 106 | weight_decay: 1.0e-5 # Amount of weight decay regularization 107 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 108 | log_grad_norm: true 109 | 110 | max_duration: 50_000_000_000tok 111 | eval_interval: 99999999999999999999999999999999ba # don't do this 112 | global_train_batch_size: 4608 113 | global_eval_batch_size: 1024 114 | 115 | # System 116 | seed: 210 117 | device_eval_batch_size: 12 118 | device_train_microbatch_size: 12 119 | 120 | precision: amp_bf16 121 | 122 | # Logging 123 | progress_bar: true 124 | log_to_console: true 125 | console_log_interval: 100ba 126 | 127 | callbacks: 128 | speed_monitor: 129 | window_size: 100 130 | lr_monitor: {} 131 | scheduled_gc: {} 132 | log_grad_norm: 133 | batch_log_interval: 100 134 | packing_efficiency: 135 | log_interval: 100 136 | 137 | # W&B logging 138 | loggers: 139 | wandb: 140 | project: encoder_vs_decoder-cross 141 | entity: your_wandb_entity 142 | 143 | # Checkpoint to local filesystem or remote object store 144 | save_interval: 10_000_000_000tok 145 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 146 | save_folder: your_save_folder 147 | 148 | 149 | # Load from local filesystem or remote object store to 150 | load_path: your_load_path 151 | reset_time: true 152 | load_weights_only: true 153 | -------------------------------------------------------------------------------- /pretraining/configs/cross-train/huge/decoder_huge.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.15 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: huge_decoder_TO_encoder 11 | 12 | # Model 13 | model: 14 | name: flex_gpt_to_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default 
architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 28 25 | hidden_size: 1792 26 | intermediate_size: 3840 27 | num_attention_heads: 28 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | local_attn_rotary_emb_base: 160000.0 57 | rotary_emb_scale_base: null 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: true 65 | 66 | # Dataloaders 67 | train_loader: 68 | name: text 69 | dataset: 70 | local: ${data_local} 71 | remote: ${data_remote} 72 | split: train 73 | tokenizer_name: ${tokenizer_name} 74 | max_seq_len: ${max_seq_len} 75 | shuffle: true 76 | mlm_probability: ${mlm_probability} 77 | streaming: false 78 | shuffle_seed: 21 79 | drop_last: true 80 | num_workers: 6 81 | sequence_packing: false 82 | batch_size_warmup_min_size: ${device_train_microbatch_size} 83 | batch_size_warmup_tokens: 10_000_000_000tok 84 | 85 | 86 | # Optimization 87 | scheduler: 88 | name: warmup_stable_decay 89 | t_warmup: 1_000_000_000tok # Warmup to the full LR for 6% of the training duration 90 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 91 | t_decay: 3_333_000_000tok 92 | 93 | 94 | optimizer: 95 | name: decoupled_stableadamw 96 | lr: 2.5e-4 # Peak learning rate 97 | betas: 98 | - 0.9 99 | - 0.98 100 | eps: 1.0e-06 101 | weight_decay: 5.0e-5 # Amount of weight decay regularization 102 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 103 | log_grad_norm: true 104 | 105 | max_duration: 16_666_666_667tok 106 | eval_interval: 99999999999999999999999999999999ba # don't do this 107 | global_train_batch_size: 4608 108 | global_eval_batch_size: 1024 109 | 110 | # System 111 | seed: 210 112 | device_eval_batch_size: 2 113 | device_train_microbatch_size: 2 114 | 115 | precision: amp_bf16 116 | 117 | # Logging 118 | progress_bar: true 119 | log_to_console: true 120 | console_log_interval: 100ba 121 | 122 | callbacks: 123 | speed_monitor: 124 | window_size: 100 125 | lr_monitor: {} 126 | scheduled_gc: {} 127 | log_grad_norm: 128 | batch_log_interval: 100 129 | packing_efficiency: 130 | log_interval: 100 131 | 132 | # W&B logging 133 | loggers: 134 | wandb: 135 | project: encoder_vs_decoder-cross 136 | entity: your_wandb_entity 137 | 138 | # Checkpoint to local filesystem or remote object store 139 | save_interval: 3_333_333_333tok 140 | save_num_checkpoints_to_keep: 100000 # 
Important, this cleans up checkpoints saved to DISK 141 | save_folder: your_save_folder 142 | 143 | 144 | # Load from local filesystem or remote object store to 145 | load_path: your_load_path 146 | reset_time: true 147 | load_weights_only: true 148 | 149 | -------------------------------------------------------------------------------- /pretraining/configs/cross-train/huge/encoder_huge.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.0 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: huge_encoder_TO_decoder 11 | 12 | # Model 13 | model: 14 | name: flex_bert_to_gpt 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 28 25 | hidden_size: 1792 26 | intermediate_size: 3840 27 | num_attention_heads: 28 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | local_attn_rotary_emb_base: 160000.0 57 | rotary_emb_scale_base: null 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: false 65 | causal_mask: true 66 | pad_logits: true 67 | 68 | # Dataloaders 69 | train_loader: 70 | name: text 71 | dataset: 72 | local: ${data_local} 73 | remote: ${data_remote} 74 | split: train 75 | tokenizer_name: ${tokenizer_name} 76 | max_seq_len: ${max_seq_len} 77 | shuffle: true 78 | mlm_probability: ${mlm_probability} 79 | streaming: false 80 | use_decoder_attn_mask: true 81 | shuffle_seed: 21 82 | suppress_masking: true 83 | drop_last: true 84 | num_workers: 6 85 | sequence_packing: false 86 | batch_size_warmup_min_size: ${device_train_microbatch_size} 87 | batch_size_warmup_tokens: 10_000_000_000tok 88 | 89 | 90 | 91 | # Optimization 92 | scheduler: 93 | name: warmup_stable_decay 94 | t_warmup: 1_000_000_000tok # Warmup to the full LR for 6% of the training duration 95 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 96 | t_decay: 3_333_000_000tok 97 | 98 | 99 | optimizer: 100 | name: decoupled_stableadamw 101 | lr: 2.5e-4 # Peak learning rate 102 | betas: 103 | - 
0.9 104 | - 0.98 105 | eps: 1.0e-06 106 | weight_decay: 5.0e-5 # Amount of weight decay regularization 107 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 108 | log_grad_norm: true 109 | 110 | max_duration: 16_666_666_667tok 111 | eval_interval: 99999999999999999999999999999999ba # don't do this 112 | global_train_batch_size: 4608 113 | global_eval_batch_size: 1024 114 | 115 | # System 116 | seed: 210 117 | device_eval_batch_size: 2 118 | device_train_microbatch_size: 2 119 | 120 | precision: amp_bf16 121 | 122 | # Logging 123 | progress_bar: true 124 | log_to_console: true 125 | console_log_interval: 100ba 126 | 127 | callbacks: 128 | speed_monitor: 129 | window_size: 100 130 | lr_monitor: {} 131 | scheduled_gc: {} 132 | log_grad_norm: 133 | batch_log_interval: 100 134 | packing_efficiency: 135 | log_interval: 100 136 | 137 | # W&B logging 138 | loggers: 139 | wandb: 140 | project: encoder_vs_decoder-cross 141 | entity: your_wandb_entity 142 | 143 | # Checkpoint to local filesystem or remote object store 144 | save_interval: 3_333_333_333tok 145 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 146 | save_folder: your_save_folder 147 | 148 | 149 | # Load from local filesystem or remote object store to 150 | load_path: your_load_path 151 | reset_time: true 152 | load_weights_only: true 153 | -------------------------------------------------------------------------------- /pretraining/configs/cross-train/large/decoder_large.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.15 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: large_decoder_TO_encoder 11 | 12 | # Model 13 | model: 14 | name: flex_gpt_to_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 28 25 | hidden_size: 1024 26 | intermediate_size: 2624 27 | num_attention_heads: 16 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | local_attn_rotary_emb_base: 160000.0 57 | rotary_emb_scale_base: null 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 
128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: true 65 | 66 | # Dataloaders 67 | train_loader: 68 | name: text 69 | dataset: 70 | local: ${data_local} 71 | remote: ${data_remote} 72 | split: train 73 | tokenizer_name: ${tokenizer_name} 74 | max_seq_len: ${max_seq_len} 75 | shuffle: true 76 | mlm_probability: ${mlm_probability} 77 | streaming: false 78 | shuffle_seed: 21 79 | drop_last: true 80 | num_workers: 6 81 | sequence_packing: false 82 | batch_size_warmup_min_size: ${device_train_microbatch_size} 83 | batch_size_warmup_tokens: 10_000_000_000tok 84 | 85 | 86 | # Optimization 87 | scheduler: 88 | name: warmup_stable_decay 89 | t_warmup: 3_000_000_000tok # Warmup to the full LR for 6% of the training duration 90 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 91 | t_decay: 10_000_000_000tok 92 | 93 | 94 | optimizer: 95 | name: decoupled_stableadamw 96 | lr: 2.5e-4 # Peak learning rate 97 | betas: 98 | - 0.9 99 | - 0.98 100 | eps: 1.0e-06 101 | weight_decay: 1.0e-5 # Amount of weight decay regularization 102 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 103 | log_grad_norm: true 104 | 105 | max_duration: 50_000_000_000tok 106 | eval_interval: 99999999999999999999999999999999ba # don't do this 107 | global_train_batch_size: 4608 108 | global_eval_batch_size: 1024 109 | 110 | # System 111 | seed: 210 112 | device_eval_batch_size: 4 113 | device_train_microbatch_size: 4 114 | 115 | precision: amp_bf16 116 | 117 | # Logging 118 | progress_bar: true 119 | log_to_console: true 120 | console_log_interval: 100ba 121 | 122 | callbacks: 123 | speed_monitor: 124 | window_size: 100 125 | lr_monitor: {} 126 | scheduled_gc: {} 127 | log_grad_norm: 128 | batch_log_interval: 100 129 | packing_efficiency: 130 | log_interval: 100 131 | 132 | # W&B logging 133 | loggers: 134 | wandb: 135 | project: encoder_vs_decoder-cross 136 | entity: your_wandb_entity 137 | 138 | # Checkpoint to local filesystem or remote object store 139 | save_interval: 10_000_000_000tok 140 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 141 | save_folder: your_save_folder 142 | 143 | 144 | # Load from local filesystem or remote object store to 145 | load_path: your_load_path 146 | reset_time: true 147 | load_weights_only: true 148 | 149 | -------------------------------------------------------------------------------- /pretraining/configs/cross-train/large/encoder_large.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.0 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: large_encoder_TO_decoder 11 | 12 | # Model 13 | model: 14 | name: flex_bert_to_gpt 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | 
num_hidden_layers: 28 25 | hidden_size: 1024 26 | intermediate_size: 2624 27 | num_attention_heads: 16 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | local_attn_rotary_emb_base: 160000.0 57 | rotary_emb_scale_base: null 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: false 65 | causal_mask: true 66 | pad_logits: true 67 | 68 | # Dataloaders 69 | train_loader: 70 | name: text 71 | dataset: 72 | local: ${data_local} 73 | remote: ${data_remote} 74 | split: train 75 | tokenizer_name: ${tokenizer_name} 76 | max_seq_len: ${max_seq_len} 77 | shuffle: true 78 | mlm_probability: ${mlm_probability} 79 | streaming: false 80 | use_decoder_attn_mask: true 81 | shuffle_seed: 21 82 | suppress_masking: true 83 | drop_last: true 84 | num_workers: 6 85 | sequence_packing: false 86 | batch_size_warmup_min_size: ${device_train_microbatch_size} 87 | batch_size_warmup_tokens: 10_000_000_000tok 88 | 89 | 90 | 91 | # Optimization 92 | scheduler: 93 | name: warmup_stable_decay 94 | t_warmup: 3_000_000_000tok # Warmup to the full LR for 6% of the training duration 95 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 96 | t_decay: 10_000_000_000tok 97 | 98 | 99 | optimizer: 100 | name: decoupled_stableadamw 101 | lr: 2.5e-4 # Peak learning rate 102 | betas: 103 | - 0.9 104 | - 0.98 105 | eps: 1.0e-06 106 | weight_decay: 1.0e-5 # Amount of weight decay regularization 107 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 108 | log_grad_norm: true 109 | 110 | max_duration: 50_000_000_000tok 111 | eval_interval: 99999999999999999999999999999999ba # don't do this 112 | global_train_batch_size: 4608 113 | global_eval_batch_size: 1024 114 | 115 | # System 116 | seed: 210 117 | device_eval_batch_size: 4 118 | device_train_microbatch_size: 4 119 | 120 | precision: amp_bf16 121 | 122 | # Logging 123 | progress_bar: true 124 | log_to_console: true 125 | console_log_interval: 100ba 126 | 127 | callbacks: 128 | speed_monitor: 129 | window_size: 100 130 | lr_monitor: {} 131 | scheduled_gc: {} 132 | log_grad_norm: 133 | batch_log_interval: 100 134 | packing_efficiency: 135 | log_interval: 100 136 | 137 | # W&B logging 138 | loggers: 139 | wandb: 140 | project: encoder_vs_decoder-cross 141 | entity: your_wandb_entity 142 | 143 | # Checkpoint to local filesystem or remote object store 144 | save_interval: 10_000_000_000tok 145 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 146 | save_folder: your_save_folder 147 | 148 | 149 | # Load from local filesystem or remote object store to 150 | load_path: your_load_path 151 | reset_time: true 
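# Note on the checkpoint-loading block: for cross-objective runs, load_path points at a
# checkpoint of the opposite-objective model (see the pretraining README); reset_time
# restarts the training clock for the new run, and load_weights_only (next line) reuses
# only the model weights, not the optimizer state.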
152 | load_weights_only: true 153 | -------------------------------------------------------------------------------- /pretraining/configs/cross-train/mini/decoder_mini.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.15 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: mini_decoder_TO_encoder 11 | 12 | # Model 13 | model: 14 | name: flex_gpt_to_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 19 25 | hidden_size: 512 26 | intermediate_size: 768 27 | num_attention_heads: 8 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | local_attn_rotary_emb_base: 160000.0 57 | rotary_emb_scale_base: null 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: true 65 | 66 | # Dataloaders 67 | train_loader: 68 | name: text 69 | dataset: 70 | local: ${data_local} 71 | remote: ${data_remote} 72 | split: train 73 | tokenizer_name: ${tokenizer_name} 74 | max_seq_len: ${max_seq_len} 75 | shuffle: true 76 | mlm_probability: ${mlm_probability} 77 | streaming: false 78 | shuffle_seed: 21 79 | drop_last: true 80 | num_workers: 6 81 | sequence_packing: false 82 | batch_size_warmup_min_size: ${device_train_microbatch_size} 83 | batch_size_warmup_tokens: 10_000_000_000tok 84 | 85 | 86 | # Optimization 87 | scheduler: 88 | name: warmup_stable_decay 89 | t_warmup: 3_000_000_000tok # Warmup to the full LR for 6% of the training duration 90 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 91 | t_decay: 10_000_000_000tok 92 | 93 | 94 | optimizer: 95 | name: decoupled_stableadamw 96 | lr: 1.5e-3 # Peak learning rate 97 | betas: 98 | - 0.9 99 | - 0.98 100 | eps: 1.0e-06 101 | weight_decay: 3.0e-4 # Amount of weight decay regularization 102 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 103 | log_grad_norm: true 104 | 105 | max_duration: 50_000_000_000tok 106 | eval_interval: 99999999999999999999999999999999ba # don't do 
this 107 | global_train_batch_size: 4608 108 | global_eval_batch_size: 1024 109 | 110 | # System 111 | seed: 210 112 | device_eval_batch_size: 8 113 | device_train_microbatch_size: 8 114 | 115 | precision: amp_bf16 116 | 117 | # Logging 118 | progress_bar: true 119 | log_to_console: true 120 | console_log_interval: 100ba 121 | 122 | callbacks: 123 | speed_monitor: 124 | window_size: 100 125 | lr_monitor: {} 126 | scheduled_gc: {} 127 | log_grad_norm: 128 | batch_log_interval: 100 129 | packing_efficiency: 130 | log_interval: 100 131 | 132 | # W&B logging 133 | loggers: 134 | wandb: 135 | project: encoder_vs_decoder-cross 136 | entity: your_wandb_entity 137 | 138 | # Checkpoint to local filesystem or remote object store 139 | save_interval: 10_000_000_000tok 140 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 141 | save_folder: your_save_folder 142 | 143 | 144 | # Load from local filesystem or remote object store to 145 | load_path: your_load_path 146 | reset_time: true 147 | load_weights_only: true 148 | 149 | -------------------------------------------------------------------------------- /pretraining/configs/cross-train/mini/encoder_mini.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.0 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: mini_encoder_TO_decoder 11 | 12 | # Model 13 | model: 14 | name: flex_bert_to_gpt 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 19 25 | hidden_size: 512 26 | intermediate_size: 768 27 | num_attention_heads: 8 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | local_attn_rotary_emb_base: 160000.0 57 | rotary_emb_scale_base: null 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: false 65 | causal_mask: true 66 | pad_logits: true 67 | 68 | # Dataloaders 69 | train_loader: 70 | name: text 71 | dataset: 72 | local: ${data_local} 73 | remote: ${data_remote} 74 | split: train 75 | tokenizer_name: 
${tokenizer_name} 76 | max_seq_len: ${max_seq_len} 77 | shuffle: true 78 | mlm_probability: ${mlm_probability} 79 | streaming: false 80 | use_decoder_attn_mask: true 81 | shuffle_seed: 21 82 | suppress_masking: true 83 | drop_last: true 84 | num_workers: 6 85 | sequence_packing: false 86 | batch_size_warmup_min_size: ${device_train_microbatch_size} 87 | batch_size_warmup_tokens: 10_000_000_000tok 88 | 89 | 90 | 91 | # Optimization 92 | scheduler: 93 | name: warmup_stable_decay 94 | t_warmup: 3_000_000_000tok # Warmup to the full LR for 6% of the training duration 95 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 96 | t_decay: 10_000_000_000tok 97 | 98 | 99 | optimizer: 100 | name: decoupled_stableadamw 101 | lr: 1.5e-3 # Peak learning rate 102 | betas: 103 | - 0.9 104 | - 0.98 105 | eps: 1.0e-06 106 | weight_decay: 3.0e-4 # Amount of weight decay regularization 107 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 108 | log_grad_norm: true 109 | 110 | max_duration: 50_000_000_000tok 111 | eval_interval: 99999999999999999999999999999999ba # don't do this 112 | global_train_batch_size: 4608 113 | global_eval_batch_size: 1024 114 | 115 | # System 116 | seed: 210 117 | device_eval_batch_size: 8 118 | device_train_microbatch_size: 8 119 | 120 | precision: amp_bf16 121 | 122 | # Logging 123 | progress_bar: true 124 | log_to_console: true 125 | console_log_interval: 100ba 126 | 127 | callbacks: 128 | speed_monitor: 129 | window_size: 100 130 | lr_monitor: {} 131 | scheduled_gc: {} 132 | log_grad_norm: 133 | batch_log_interval: 100 134 | packing_efficiency: 135 | log_interval: 100 136 | 137 | # W&B logging 138 | loggers: 139 | wandb: 140 | project: encoder_vs_decoder-cross 141 | entity: your_wandb_entity 142 | 143 | # Checkpoint to local filesystem or remote object store 144 | save_interval: 10_000_000_000tok 145 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 146 | save_folder: your_save_folder 147 | 148 | 149 | # Load from local filesystem or remote object store to 150 | load_path: your_load_path 151 | reset_time: true 152 | load_weights_only: true 153 | -------------------------------------------------------------------------------- /pretraining/configs/cross-train/tiny/decoder_tiny.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.15 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: tiny_decoder_TO_encoder 11 | 12 | # Model 13 | model: 14 | name: flex_gpt_to_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 10 25 | hidden_size: 384 26 | intermediate_size: 576 27 | num_attention_heads: 6 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | 
attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | local_attn_rotary_emb_base: 160000.0 57 | rotary_emb_scale_base: null 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: true 65 | 66 | # Dataloaders 67 | train_loader: 68 | name: text 69 | dataset: 70 | local: ${data_local} 71 | remote: ${data_remote} 72 | split: train 73 | tokenizer_name: ${tokenizer_name} 74 | max_seq_len: ${max_seq_len} 75 | shuffle: true 76 | mlm_probability: ${mlm_probability} 77 | streaming: false 78 | shuffle_seed: 21 79 | drop_last: true 80 | num_workers: 6 81 | sequence_packing: false 82 | batch_size_warmup_min_size: ${device_train_microbatch_size} 83 | batch_size_warmup_tokens: 10_000_000_000tok 84 | 85 | # Optimization 86 | scheduler: 87 | name: warmup_stable_decay 88 | t_warmup: 3_000_000_000tok # Warmup to the full LR for 6% of the training duration 89 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 90 | t_decay: 10_000_000_000tok 91 | 92 | optimizer: 93 | name: decoupled_stableadamw 94 | lr: 1.5e-3 # Peak learning rate 95 | betas: 96 | - 0.9 97 | - 0.98 98 | eps: 1.0e-06 99 | weight_decay: 3.0e-4 # Amount of weight decay regularization 100 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 101 | log_grad_norm: true 102 | 103 | max_duration: 50_000_000_000tok 104 | eval_interval: 99999999999999999999999999999999ba # don't do this 105 | global_train_batch_size: 4608 106 | global_eval_batch_size: 1024 107 | 108 | # System 109 | seed: 210 110 | device_eval_batch_size: 12 111 | device_train_microbatch_size: 12 112 | 113 | precision: amp_bf16 114 | 115 | # Logging 116 | progress_bar: true 117 | log_to_console: true 118 | console_log_interval: 100ba 119 | 120 | callbacks: 121 | speed_monitor: 122 | window_size: 100 123 | lr_monitor: {} 124 | scheduled_gc: {} 125 | log_grad_norm: 126 | batch_log_interval: 100 127 | packing_efficiency: 128 | log_interval: 100 129 | 130 | # W&B logging 131 | loggers: 132 | wandb: 133 | project: encoder_vs_decoder-cross 134 | entity: your_wandb_entity 135 | 136 | # Checkpoint to local filesystem or remote object store 137 | save_interval: 10_000_000_000tok 138 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 139 | save_folder: your_save_folder 140 | 141 | 142 | # Load from local filesystem or remote object store to 143 | load_path: your_load_path 144 | reset_time: true 145 | load_weights_only: true 146 | 147 | -------------------------------------------------------------------------------- /pretraining/configs/cross-train/tiny/encoder_tiny.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 
| max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.0 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: tiny_encoder_TO_decoder 11 | 12 | # Model 13 | model: 14 | name: flex_bert_to_gpt 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 10 25 | hidden_size: 384 26 | intermediate_size: 576 27 | num_attention_heads: 6 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | local_attn_rotary_emb_base: 160000.0 57 | rotary_emb_scale_base: null 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: false 65 | causal_mask: true 66 | pad_logits: true 67 | 68 | # Dataloaders 69 | train_loader: 70 | name: text 71 | dataset: 72 | local: ${data_local} 73 | remote: ${data_remote} 74 | split: train 75 | tokenizer_name: ${tokenizer_name} 76 | max_seq_len: ${max_seq_len} 77 | shuffle: true 78 | mlm_probability: ${mlm_probability} 79 | streaming: false 80 | use_decoder_attn_mask: true 81 | shuffle_seed: 21 82 | suppress_masking: true 83 | drop_last: true 84 | num_workers: 6 85 | sequence_packing: false 86 | batch_size_warmup_min_size: ${device_train_microbatch_size} 87 | batch_size_warmup_tokens: 10_000_000_000tok 88 | 89 | 90 | 91 | # Optimization 92 | scheduler: 93 | name: warmup_stable_decay 94 | t_warmup: 3_000_000_000tok # Warmup to the full LR for 6% of the training duration 95 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 96 | t_decay: 10_000_000_000tok 97 | 98 | 99 | optimizer: 100 | name: decoupled_stableadamw 101 | lr: 1.5e-3 # Peak learning rate 102 | betas: 103 | - 0.9 104 | - 0.98 105 | eps: 1.0e-06 106 | weight_decay: 3.0e-4 # Amount of weight decay regularization 107 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 108 | log_grad_norm: true 109 | 110 | max_duration: 50_000_000_000tok 111 | eval_interval: 99999999999999999999999999999999ba # don't do this 112 | global_train_batch_size: 4608 113 | global_eval_batch_size: 1024 114 | 115 | # System 116 | seed: 210 117 | device_eval_batch_size: 12 118 | device_train_microbatch_size: 12 119 | 120 | precision: amp_bf16 121 | 122 | # 
Logging 123 | progress_bar: true 124 | log_to_console: true 125 | console_log_interval: 100ba 126 | 127 | callbacks: 128 | speed_monitor: 129 | window_size: 100 130 | lr_monitor: {} 131 | scheduled_gc: {} 132 | log_grad_norm: 133 | batch_log_interval: 100 134 | packing_efficiency: 135 | log_interval: 100 136 | 137 | # W&B logging 138 | loggers: 139 | wandb: 140 | project: encoder_vs_decoder-cross 141 | entity: your_wandb_entity 142 | 143 | # Checkpoint to local filesystem or remote object store 144 | save_interval: 10_000_000_000tok 145 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 146 | save_folder: your_save_folder 147 | 148 | 149 | # Load from local filesystem or remote object store to 150 | load_path: your_load_path 151 | reset_time: true 152 | load_weights_only: true 153 | -------------------------------------------------------------------------------- /pretraining/configs/cross-train/very_tiny/decoder_very_tiny.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.15 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: very_tiny_decoder_TO_encoder 11 | 12 | # Model 13 | model: 14 | name: flex_gpt_to_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 7 25 | hidden_size: 256 26 | intermediate_size: 384 27 | num_attention_heads: 4 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | local_attn_rotary_emb_base: 160000.0 57 | rotary_emb_scale_base: null 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: true 65 | 66 | # Dataloaders 67 | train_loader: 68 | name: text 69 | dataset: 70 | local: ${data_local} 71 | remote: ${data_remote} 72 | split: train 73 | tokenizer_name: ${tokenizer_name} 74 | max_seq_len: ${max_seq_len} 75 | shuffle: true 76 | mlm_probability: ${mlm_probability} 77 | streaming: false 78 | shuffle_seed: 21 79 | drop_last: true 80 | num_workers: 6 81 | sequence_packing: false 82 | batch_size_warmup_min_size: 
${device_train_microbatch_size} 83 | batch_size_warmup_tokens: 10_000_000_000tok 84 | 85 | 86 | # Optimization 87 | scheduler: 88 | name: warmup_stable_decay 89 | t_warmup: 3_000_000_000tok # Warmup to the full LR for 6% of the training duration 90 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 91 | t_decay: 10_000_000_000tok 92 | 93 | 94 | optimizer: 95 | name: decoupled_stableadamw 96 | lr: 1.5e-3 # Peak learning rate 97 | betas: 98 | - 0.9 99 | - 0.98 100 | eps: 1.0e-06 101 | weight_decay: 3.0e-4 # Amount of weight decay regularization 102 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 103 | log_grad_norm: true 104 | 105 | max_duration: 50_000_000_000tok 106 | eval_interval: 99999999999999999999999999999999ba # don't do this 107 | global_train_batch_size: 4608 108 | global_eval_batch_size: 1024 109 | 110 | # System 111 | seed: 210 112 | device_eval_batch_size: 12 113 | device_train_microbatch_size: 12 114 | 115 | precision: amp_bf16 116 | 117 | # Logging 118 | progress_bar: true 119 | log_to_console: true 120 | console_log_interval: 100ba 121 | 122 | callbacks: 123 | speed_monitor: 124 | window_size: 100 125 | lr_monitor: {} 126 | scheduled_gc: {} 127 | log_grad_norm: 128 | batch_log_interval: 100 129 | packing_efficiency: 130 | log_interval: 100 131 | 132 | # W&B logging 133 | loggers: 134 | wandb: 135 | project: encoder_vs_decoder-cross 136 | entity: your_wandb_entity 137 | 138 | # Checkpoint to local filesystem or remote object store 139 | save_interval: 10_000_000_000tok 140 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 141 | save_folder: your_save_folder 142 | 143 | 144 | # Load from local filesystem or remote object store to 145 | load_path: your_load_path 146 | reset_time: true 147 | load_weights_only: true 148 | 149 | -------------------------------------------------------------------------------- /pretraining/configs/cross-train/very_tiny/encoder_very_tiny.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.0 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: very_tiny_encoder_TO_decoder 11 | 12 | # Model 13 | model: 14 | name: flex_bert_to_gpt 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 7 25 | hidden_size: 256 26 | intermediate_size: 384 27 | num_attention_heads: 4 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 
0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | local_attn_rotary_emb_base: 160000.0 57 | rotary_emb_scale_base: null 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: false 65 | causal_mask: true 66 | pad_logits: true 67 | 68 | # Dataloaders 69 | train_loader: 70 | name: text 71 | dataset: 72 | local: ${data_local} 73 | remote: ${data_remote} 74 | split: train 75 | tokenizer_name: ${tokenizer_name} 76 | max_seq_len: ${max_seq_len} 77 | shuffle: true 78 | mlm_probability: ${mlm_probability} 79 | streaming: false 80 | use_decoder_attn_mask: true 81 | shuffle_seed: 21 82 | suppress_masking: true 83 | drop_last: true 84 | num_workers: 6 85 | sequence_packing: false 86 | batch_size_warmup_min_size: ${device_train_microbatch_size} 87 | batch_size_warmup_tokens: 10_000_000_000tok 88 | 89 | 90 | 91 | # Optimization 92 | scheduler: 93 | name: warmup_stable_decay 94 | t_warmup: 3_000_000_000tok # Warmup to the full LR for 6% of the training duration 95 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 96 | t_decay: 10_000_000_000tok 97 | 98 | 99 | optimizer: 100 | name: decoupled_stableadamw 101 | lr: 1.5e-3 # Peak learning rate 102 | betas: 103 | - 0.9 104 | - 0.98 105 | eps: 1.0e-06 106 | weight_decay: 3.0e-4 # Amount of weight decay regularization 107 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 108 | log_grad_norm: true 109 | 110 | max_duration: 50_000_000_000tok 111 | eval_interval: 99999999999999999999999999999999ba # don't do this 112 | global_train_batch_size: 4608 113 | global_eval_batch_size: 1024 114 | 115 | # System 116 | seed: 210 117 | device_eval_batch_size: 12 118 | device_train_microbatch_size: 12 119 | 120 | precision: amp_bf16 121 | 122 | # Logging 123 | progress_bar: true 124 | log_to_console: true 125 | console_log_interval: 100ba 126 | 127 | callbacks: 128 | speed_monitor: 129 | window_size: 100 130 | lr_monitor: {} 131 | scheduled_gc: {} 132 | log_grad_norm: 133 | batch_log_interval: 100 134 | packing_efficiency: 135 | log_interval: 100 136 | 137 | # W&B logging 138 | loggers: 139 | wandb: 140 | project: encoder_vs_decoder-cross 141 | entity: your_wandb_entity 142 | 143 | # Checkpoint to local filesystem or remote object store 144 | save_interval: 10_000_000_000tok 145 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 146 | save_folder: your_save_folder 147 | 148 | 149 | # Load from local filesystem or remote object store to 150 | load_path: your_load_path 151 | reset_time: true 152 | load_weights_only: true 153 | -------------------------------------------------------------------------------- /pretraining/configs/standard/base/de_decay/encoder_base.yaml: -------------------------------------------------------------------------------- 1 | data_local: sample_250B 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: 
false 8 | 9 | # Run Name 10 | run_name: encoder_base_no_packing_de_decay 11 | 12 | # Model 13 | model: 14 | name: flex_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 22 25 | hidden_size: 768 26 | intermediate_size: 1152 27 | num_attention_heads: 12 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | rotary_emb_scale_base: null 57 | rotary_emb_interleaved: false 58 | allow_embedding_resizing: true 59 | sliding_window: 128 60 | global_attn_every_n_layers: 3 61 | unpad_embeddings: true 62 | compile_model: true 63 | masked_prediction: true 64 | 65 | # Dataloaders 66 | train_loader: 67 | name: text 68 | dataset: 69 | local: ${data_local} 70 | remote: ${data_remote} 71 | split: train 72 | tokenizer_name: ${tokenizer_name} 73 | max_seq_len: ${max_seq_len} 74 | shuffle: true 75 | mlm_probability: ${mlm_probability} 76 | streaming: false 77 | shuffle_seed: 21 78 | drop_last: true 79 | num_workers: 6 80 | sequence_packing: false 81 | batch_size_warmup_min_size: ${device_train_microbatch_size} 82 | batch_size_warmup_tokens: 0tok 83 | 84 | 85 | eval_loader: 86 | name: text 87 | dataset: 88 | local: ${data_local} 89 | remote: 90 | split: validation 91 | tokenizer_name: ${tokenizer_name} 92 | max_seq_len: ${max_seq_len} 93 | shuffle: false 94 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 95 | streaming: false 96 | drop_last: false 97 | num_workers: 3 98 | sequence_packing: false 99 | 100 | 101 | # Optimization 102 | scheduler: 103 | name: one_minus_sqrt 104 | alpha_f: 0.5 105 | t_decay: ${max_duration} 106 | t_max: ${max_duration} 107 | 108 | optimizer: 109 | name: decoupled_stableadamw 110 | lr: 8e-4 # Peak learning rate 111 | betas: 112 | - 0.9 113 | - 0.98 114 | eps: 1.0e-06 115 | weight_decay: 1.0e-5 # Amount of weight decay regularization 116 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 117 | log_grad_norm: true 118 | 119 | max_duration: 250_000_000_000tok 120 | eval_interval: 99999999999999999999999999999999ba # don't do this 121 | global_train_batch_size: 4608 122 | global_eval_batch_size: 1024 123 | 124 | # System 125 | seed: 21 126 | device_eval_batch_size: 16 127 | device_train_microbatch_size: 16 128 | 129 | precision: amp_bf16 130 | 131 | # Logging 132 | progress_bar: true 133 | log_to_console: true 134 | console_log_interval: 
100ba 135 | 136 | callbacks: 137 | speed_monitor: 138 | window_size: 100 139 | lr_monitor: {} 140 | scheduled_gc: {} 141 | log_grad_norm: 142 | batch_log_interval: 100 143 | packing_efficiency: 144 | log_interval: 100 145 | 146 | # W&B logging 147 | loggers: 148 | wandb: 149 | project: encoder_vs_decoder 150 | entity: your_wandb_entity 151 | 152 | # Checkpoint to local filesystem or remote object store 153 | save_interval: 8_500_000_000tok 154 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 155 | save_folder: your_save_folder 156 | 157 | # Load from local filesystem or remote object store to 158 | autoresume: true 159 | load_path: your_load_path 160 | reset_time: true -------------------------------------------------------------------------------- /pretraining/configs/standard/base/encoder_base.yaml: -------------------------------------------------------------------------------- 1 | data_local: ettin_tokenized_data 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 1024 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: encoder_no_packing 11 | 12 | # Model 13 | model: 14 | name: flex_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 22 25 | hidden_size: 768 26 | intermediate_size: 1152 27 | num_attention_heads: 12 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 10000.0 56 | rotary_emb_scale_base: null 57 | rotary_emb_interleaved: false 58 | allow_embedding_resizing: true 59 | sliding_window: 128 60 | global_attn_every_n_layers: 3 61 | unpad_embeddings: true 62 | compile_model: true 63 | masked_prediction: true 64 | 65 | # Dataloaders 66 | train_loader: 67 | name: text 68 | dataset: 69 | local: ${data_local} 70 | remote: ${data_remote} 71 | split: train 72 | tokenizer_name: ${tokenizer_name} 73 | max_seq_len: ${max_seq_len} 74 | shuffle: true 75 | mlm_probability: ${mlm_probability} 76 | streaming: false 77 | shuffle_seed: 21 78 | drop_last: true 79 | num_workers: 6 80 | sequence_packing: false 81 | batch_size_warmup_min_size: ${device_train_microbatch_size} 82 | batch_size_warmup_tokens: 50_000_000_000tok 83 | 84 | 85 | eval_loader: 86 | name: text 87 | dataset: 88 | local: ${data_local} 89 | remote: 90 | 
split: validation 91 | tokenizer_name: ${tokenizer_name} 92 | max_seq_len: ${max_seq_len} 93 | shuffle: false 94 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 95 | streaming: false 96 | drop_last: false 97 | num_workers: 3 98 | sequence_packing: false 99 | 100 | 101 | # Optimization 102 | scheduler: 103 | name: warmup_stable_decay 104 | t_warmup: 3_000_000_000tok # Warmup to the full LR for 6% of the training duration 105 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 106 | t_decay: 0tok 107 | 108 | optimizer: 109 | name: decoupled_stableadamw 110 | lr: 8e-4 # Peak learning rate 111 | betas: 112 | - 0.9 113 | - 0.98 114 | eps: 1.0e-06 115 | weight_decay: 1.0e-5 # Amount of weight decay regularization 116 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 117 | log_grad_norm: true 118 | 119 | max_duration: 1_700_000_000_000tok 120 | eval_interval: 99999999999999999999999999999999ba # don't do this 121 | global_train_batch_size: 4608 122 | global_eval_batch_size: 1024 123 | 124 | # System 125 | seed: 21 126 | device_eval_batch_size: 48 127 | device_train_microbatch_size: 72 128 | 129 | precision: amp_bf16 130 | 131 | # Logging 132 | progress_bar: true 133 | log_to_console: true 134 | console_log_interval: 100ba 135 | 136 | callbacks: 137 | speed_monitor: 138 | window_size: 100 139 | lr_monitor: {} 140 | scheduled_gc: {} 141 | log_grad_norm: 142 | batch_log_interval: 100 143 | packing_efficiency: 144 | log_interval: 100 145 | 146 | # W&B logging 147 | loggers: 148 | wandb: 149 | project: encoder_vs_decoder 150 | entity: your_wandb_entity 151 | 152 | # Checkpoint to local filesystem or remote object store 153 | save_interval: 8_500_000_000tok 154 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 155 | save_folder: your_save_folder 156 | 157 | # Load from local filesystem or remote object store to 158 | # load_path: null 159 | 160 | autoresume: true -------------------------------------------------------------------------------- /pretraining/configs/standard/base/prolong_decay/encoder_base.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.15 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: encoder_base_no_packing_prolong_decay_lower_mask 11 | 12 | # Model 13 | model: 14 | name: flex_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 22 25 | hidden_size: 768 26 | intermediate_size: 1152 27 | num_attention_heads: 12 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | 
final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | local_attn_rotary_emb_base: 160000.0 57 | rotary_emb_scale_base: null 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: true 65 | 66 | # Dataloaders 67 | train_loader: 68 | name: text 69 | dataset: 70 | local: ${data_local} 71 | remote: ${data_remote} 72 | split: train 73 | tokenizer_name: ${tokenizer_name} 74 | max_seq_len: ${max_seq_len} 75 | shuffle: true 76 | mlm_probability: ${mlm_probability} 77 | streaming: false 78 | shuffle_seed: 21 79 | drop_last: true 80 | num_workers: 6 81 | sequence_packing: false 82 | batch_size_warmup_min_size: ${device_train_microbatch_size} 83 | batch_size_warmup_tokens: 0tok 84 | 85 | 86 | eval_loader: 87 | name: text 88 | dataset: 89 | local: ${data_local} 90 | remote: 91 | split: validation 92 | tokenizer_name: ${tokenizer_name} 93 | max_seq_len: ${max_seq_len} 94 | shuffle: false 95 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 96 | streaming: false 97 | drop_last: false 98 | num_workers: 3 99 | sequence_packing: false 100 | 101 | 102 | # Optimization 103 | scheduler: 104 | name: one_minus_sqrt 105 | alpha_f: 0.04 106 | t_decay: ${max_duration} 107 | t_max: ${max_duration} 108 | 109 | optimizer: 110 | name: decoupled_stableadamw 111 | lr: 4e-4 # Peak learning rate 112 | betas: 113 | - 0.9 114 | - 0.98 115 | eps: 1.0e-06 116 | weight_decay: 1.0e-5 # Amount of weight decay regularization 117 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 118 | log_grad_norm: true 119 | 120 | max_duration: 50_000_000_000tok 121 | eval_interval: 99999999999999999999999999999999ba # don't do this 122 | global_train_batch_size: 4608 123 | global_eval_batch_size: 1024 124 | 125 | # System 126 | seed: 21 127 | device_eval_batch_size: 12 128 | device_train_microbatch_size: 12 129 | 130 | precision: amp_bf16 131 | 132 | # Logging 133 | progress_bar: true 134 | log_to_console: true 135 | console_log_interval: 100ba 136 | 137 | callbacks: 138 | speed_monitor: 139 | window_size: 100 140 | lr_monitor: {} 141 | scheduled_gc: {} 142 | log_grad_norm: 143 | batch_log_interval: 100 144 | packing_efficiency: 145 | log_interval: 100 146 | 147 | # W&B logging 148 | loggers: 149 | wandb: 150 | project: encoder_vs_decoder 151 | entity: your_wandb_entity 152 | 153 | # Checkpoint to local filesystem or remote object store 154 | save_interval: 8_500_000_000tok 155 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 156 | save_folder: your_save_folder 157 | 158 | 159 | # Load from local filesystem or remote object store to 160 | autoresume: true 161 | load_path: your_load_path 162 | reset_time: true 163 | restart_override: true 164 | -------------------------------------------------------------------------------- /pretraining/configs/standard/huge/de_decay/decoder_huge.yaml: 
-------------------------------------------------------------------------------- 1 | data_local: sample_250B 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.0 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: huge_decoder_no_packing_de_decay 11 | 12 | # Model 13 | model: 14 | name: flex_gpt 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 28 25 | hidden_size: 1792 26 | intermediate_size: 3840 27 | num_attention_heads: 28 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | rotary_emb_scale_base: null 57 | rotary_emb_interleaved: false 58 | allow_embedding_resizing: true 59 | sliding_window: 128 60 | global_attn_every_n_layers: 3 61 | unpad_embeddings: true 62 | compile_model: true 63 | masked_prediction: false 64 | causal_mask: true 65 | pad_logits: true 66 | 67 | 68 | # Dataloaders 69 | train_loader: 70 | name: text 71 | dataset: 72 | local: ${data_local} 73 | remote: ${data_remote} 74 | split: train 75 | tokenizer_name: ${tokenizer_name} 76 | max_seq_len: ${max_seq_len} 77 | shuffle: true 78 | mlm_probability: ${mlm_probability} 79 | streaming: false 80 | use_decoder_attn_mask: true 81 | shuffle_seed: 21 82 | suppress_masking: true 83 | drop_last: true 84 | num_workers: 6 85 | sequence_packing: false 86 | batch_size_warmup_min_size: ${device_train_microbatch_size} 87 | batch_size_warmup_tokens: 0tok 88 | 89 | 90 | eval_loader: 91 | name: text 92 | dataset: 93 | local: ${data_local} 94 | remote: 95 | split: validation 96 | tokenizer_name: ${tokenizer_name} 97 | max_seq_len: ${max_seq_len} 98 | shuffle: false 99 | mlm_probability: 0.0 # We always evaluate at 15% masking for consistent comparison 100 | streaming: false 101 | use_decoder_attn_mask: true 102 | suppress_masking: true 103 | drop_last: false 104 | num_workers: 3 105 | sequence_packing: false 106 | 107 | 108 | # Optimization 109 | scheduler: 110 | name: one_minus_sqrt 111 | alpha_f: 0.5 112 | t_decay: ${max_duration} 113 | t_max: ${max_duration} 114 | 115 | optimizer: 116 | name: decoupled_stableadamw 117 | lr: 5e-4 # Peak learning rate 118 | betas: 119 | - 0.9 120 | - 0.98 121 | eps: 1.0e-06 122 | weight_decay: 5.0e-5 # Amount of weight decay regularization 123 | 
filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 124 | log_grad_norm: true 125 | 126 | max_duration: 83_333_000_000tok 127 | eval_interval: 99999999999999999999999999999999ba # don't do this 128 | global_train_batch_size: 4608 129 | global_eval_batch_size: 1024 130 | 131 | # System 132 | seed: 21 133 | device_eval_batch_size: 2 134 | device_train_microbatch_size: 2 135 | 136 | precision: amp_bf16 137 | 138 | # Logging 139 | progress_bar: true 140 | log_to_console: true 141 | console_log_interval: 100ba 142 | 143 | callbacks: 144 | speed_monitor: 145 | window_size: 100 146 | lr_monitor: {} 147 | scheduled_gc: {} 148 | log_grad_norm: 149 | batch_log_interval: 100 150 | packing_efficiency: 151 | log_interval: 100 152 | 153 | # W&B logging 154 | loggers: 155 | wandb: 156 | project: encoder_vs_decoder 157 | entity: your_wandb_entity 158 | 159 | # Checkpoint to local filesystem or remote object store 160 | save_interval: 8_500_000_000tok 161 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 162 | save_folder: your_save_folder 163 | 164 | # Load from local filesystem or remote object store to 165 | autoresume: true 166 | load_path: your_load_path 167 | -------------------------------------------------------------------------------- /pretraining/configs/standard/huge/de_decay/encoder_huge.yaml: -------------------------------------------------------------------------------- 1 | data_local: sample_250B 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: huge_encoder_no_packing_de_decay 11 | 12 | # Model 13 | model: 14 | name: flex_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 28 25 | hidden_size: 1792 26 | intermediate_size: 3840 27 | num_attention_heads: 28 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | rotary_emb_scale_base: null 57 | rotary_emb_interleaved: false 58 | allow_embedding_resizing: true 59 | sliding_window: 128 60 | global_attn_every_n_layers: 3 61 | unpad_embeddings: true 62 | compile_model: true 63 | masked_prediction: true 64 | 65 | # Dataloaders 66 | train_loader: 67 | name: text 68 | 
dataset: 69 | local: ${data_local} 70 | remote: ${data_remote} 71 | split: train 72 | tokenizer_name: ${tokenizer_name} 73 | max_seq_len: ${max_seq_len} 74 | shuffle: true 75 | mlm_probability: ${mlm_probability} 76 | streaming: false 77 | shuffle_seed: 21 78 | drop_last: true 79 | num_workers: 6 80 | sequence_packing: false 81 | batch_size_warmup_min_size: ${device_train_microbatch_size} 82 | batch_size_warmup_tokens: 0tok 83 | 84 | 85 | eval_loader: 86 | name: text 87 | dataset: 88 | local: ${data_local} 89 | remote: 90 | split: validation 91 | tokenizer_name: ${tokenizer_name} 92 | max_seq_len: ${max_seq_len} 93 | shuffle: false 94 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 95 | streaming: false 96 | drop_last: false 97 | num_workers: 3 98 | sequence_packing: false 99 | 100 | 101 | # Optimization 102 | scheduler: 103 | name: one_minus_sqrt 104 | alpha_f: 0.5 105 | t_decay: ${max_duration} 106 | t_max: ${max_duration} 107 | 108 | optimizer: 109 | name: decoupled_stableadamw 110 | lr: 5e-4 # Peak learning rate 111 | betas: 112 | - 0.9 113 | - 0.98 114 | eps: 1.0e-06 115 | weight_decay: 5.0e-5 # Amount of weight decay regularization 116 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 117 | log_grad_norm: true 118 | 119 | max_duration: 83_333_000_000tok 120 | eval_interval: 99999999999999999999999999999999ba # don't do this 121 | global_train_batch_size: 4608 122 | global_eval_batch_size: 1024 123 | 124 | # System 125 | seed: 21 126 | device_eval_batch_size: 2 127 | device_train_microbatch_size: 2 128 | 129 | precision: amp_bf16 130 | 131 | # Logging 132 | progress_bar: true 133 | log_to_console: true 134 | console_log_interval: 100ba 135 | 136 | callbacks: 137 | speed_monitor: 138 | window_size: 100 139 | lr_monitor: {} 140 | scheduled_gc: {} 141 | log_grad_norm: 142 | batch_log_interval: 100 143 | packing_efficiency: 144 | log_interval: 100 145 | 146 | # W&B logging 147 | loggers: 148 | wandb: 149 | project: encoder_vs_decoder 150 | entity: your_wandb_entity 151 | 152 | # Checkpoint to local filesystem or remote object store 153 | save_interval: 8_500_000_000tok 154 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 155 | save_folder: your_save_folder 156 | 157 | # Load from local filesystem or remote object store to 158 | autoresume: true 159 | load_path: your_load_path 160 | -------------------------------------------------------------------------------- /pretraining/configs/standard/huge/encoder_huge.yaml: -------------------------------------------------------------------------------- 1 | data_local: ettin_tokenized_data 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 1024 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: huge_encoder_no_packing 11 | 12 | # Model 13 | model: 14 | name: flex_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: 
full_megatron 24 | num_hidden_layers: 28 25 | hidden_size: 1792 26 | intermediate_size: 3840 27 | num_attention_heads: 28 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 10000.0 56 | rotary_emb_scale_base: null 57 | rotary_emb_interleaved: false 58 | allow_embedding_resizing: true 59 | sliding_window: 128 60 | global_attn_every_n_layers: 3 61 | unpad_embeddings: true 62 | compile_model: true 63 | masked_prediction: true 64 | 65 | # Dataloaders 66 | train_loader: 67 | name: text 68 | dataset: 69 | local: ${data_local} 70 | remote: ${data_remote} 71 | split: train 72 | tokenizer_name: ${tokenizer_name} 73 | max_seq_len: ${max_seq_len} 74 | shuffle: true 75 | mlm_probability: ${mlm_probability} 76 | streaming: false 77 | shuffle_seed: 21 78 | drop_last: true 79 | num_workers: 6 80 | sequence_packing: false 81 | batch_size_warmup_min_size: ${device_train_microbatch_size} 82 | batch_size_warmup_tokens: 3_000_000_000tok 83 | 84 | 85 | eval_loader: 86 | name: text 87 | dataset: 88 | local: ${data_local} 89 | remote: 90 | split: validation 91 | tokenizer_name: ${tokenizer_name} 92 | max_seq_len: ${max_seq_len} 93 | shuffle: false 94 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 95 | streaming: false 96 | drop_last: false 97 | num_workers: 3 98 | sequence_packing: false 99 | 100 | 101 | # Optimization 102 | scheduler: 103 | name: warmup_stable_decay 104 | t_warmup: 2_000_000_000tok # Warmup to the full LR for 6% of the training duration 105 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 106 | t_decay: 0tok 107 | 108 | optimizer: 109 | name: decoupled_stableadamw 110 | lr: 5e-4 # Peak learning rate 111 | betas: 112 | - 0.9 113 | - 0.98 114 | eps: 1.0e-06 115 | weight_decay: 5.0e-5 # Amount of weight decay regularization 116 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 117 | log_grad_norm: true 118 | 119 | max_duration: 1_700_000_000_000tok 120 | eval_interval: 99999999999999999999999999999999ba # don't do this 121 | global_train_batch_size: 4608 122 | global_eval_batch_size: 1024 123 | 124 | # System 125 | seed: 21 126 | device_eval_batch_size: 24 127 | device_train_microbatch_size: 24 128 | 129 | precision: amp_bf16 130 | 131 | # Logging 132 | progress_bar: true 133 | log_to_console: true 134 | console_log_interval: 100ba 135 | 136 | callbacks: 137 | speed_monitor: 138 | window_size: 100 139 | lr_monitor: {} 140 | scheduled_gc: {} 141 | log_grad_norm: 142 | batch_log_interval: 100 143 | packing_efficiency: 144 | log_interval: 100 145 | 146 | # W&B logging 147 | loggers: 148 | wandb: 149 | project: encoder_vs_decoder 150 | entity: your_wandb_entity 151 | 152 | # Checkpoint to local filesystem or remote object store 153 | save_interval: 8_500_000_000tok 154 | 
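# Illustrative checkpoint arithmetic for this config: with max_duration 1_700_000_000_000tok
# and save_interval 8_500_000_000tok, roughly 200 checkpoints (1.7T / 8.5B tokens) are written
# over the run; the large save_num_checkpoints_to_keep below means they are effectively never cleaned up from disk.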
save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 155 | save_folder: your_save_folder 156 | 157 | # Load from local filesystem or remote object store to 158 | # load_path: null 159 | 160 | autoresume: true -------------------------------------------------------------------------------- /pretraining/configs/standard/huge/prolong_decay/encoder_huge.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.15 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: huge_encoder_no_packing_prolong_decay_lower_mask 11 | 12 | 13 | # Model 14 | model: 15 | name: flex_bert 16 | pretrained_model_name: bert-base-uncased 17 | tokenizer_name: ${tokenizer_name} 18 | disable_train_metrics: true 19 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 20 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 21 | # the model_config settings match the architecture of the existing model 22 | model_config: 23 | vocab_size: 50368 24 | init_method: full_megatron 25 | num_hidden_layers: 28 26 | hidden_size: 1792 27 | intermediate_size: 3840 28 | num_attention_heads: 28 # to have head size of 64 29 | attention_layer: rope 30 | attention_probs_dropout_prob: 0.0 31 | attn_out_bias: false 32 | attn_out_dropout_prob: 0.1 33 | attn_qkv_bias: false 34 | bert_layer: prenorm 35 | embed_dropout_prob: 0.0 36 | embed_norm: true 37 | final_norm: true 38 | skip_first_prenorm: true 39 | embedding_layer: sans_pos 40 | loss_function: fa_cross_entropy 41 | loss_kwargs: 42 | reduction: mean 43 | mlp_dropout_prob: 0.0 44 | mlp_in_bias: false 45 | mlp_layer: glu 46 | mlp_out_bias: false 47 | normalization: layernorm 48 | norm_kwargs: 49 | eps: 1e-5 50 | bias: false 51 | hidden_act: gelu 52 | head_pred_act: gelu 53 | activation_function: gelu # better safe than sorry 54 | padding: unpadded 55 | rotary_emb_dim: null 56 | rotary_emb_base: 160000.0 57 | local_attn_rotary_emb_base: 160000.0 58 | rotary_emb_scale_base: null 59 | rotary_emb_interleaved: false 60 | allow_embedding_resizing: true 61 | sliding_window: 128 62 | global_attn_every_n_layers: 3 63 | unpad_embeddings: true 64 | compile_model: true 65 | masked_prediction: true 66 | 67 | # Dataloaders 68 | train_loader: 69 | name: text 70 | dataset: 71 | local: ${data_local} 72 | remote: ${data_remote} 73 | split: train 74 | tokenizer_name: ${tokenizer_name} 75 | max_seq_len: ${max_seq_len} 76 | shuffle: true 77 | mlm_probability: ${mlm_probability} 78 | streaming: false 79 | shuffle_seed: 21 80 | drop_last: true 81 | num_workers: 6 82 | sequence_packing: false 83 | batch_size_warmup_min_size: ${device_train_microbatch_size} 84 | batch_size_warmup_tokens: 0tok 85 | 86 | 87 | eval_loader: 88 | name: text 89 | dataset: 90 | local: ${data_local} 91 | remote: 92 | split: validation 93 | tokenizer_name: ${tokenizer_name} 94 | max_seq_len: ${max_seq_len} 95 | shuffle: false 96 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 97 | streaming: false 98 | drop_last: false 99 | num_workers: 3 100 | sequence_packing: false 101 | 102 | 103 | # Optimization 104 | scheduler: 105 | name: one_minus_sqrt 106 | alpha_f: 0.04 107 | t_decay: 
${max_duration} 108 | t_max: ${max_duration} 109 | 110 | optimizer: 111 | name: decoupled_stableadamw 112 | lr: 2.5e-4 # Peak learning rate 113 | betas: 114 | - 0.9 115 | - 0.98 116 | eps: 1.0e-06 117 | weight_decay: 5.0e-5 # Amount of weight decay regularization 118 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 119 | log_grad_norm: true 120 | 121 | max_duration: 16_666_666_667tok 122 | eval_interval: 99999999999999999999999999999999ba # don't do this 123 | global_train_batch_size: 4608 124 | global_eval_batch_size: 1024 125 | 126 | # System 127 | seed: 21 128 | device_eval_batch_size: 2 129 | device_train_microbatch_size: 2 130 | 131 | precision: amp_bf16 132 | 133 | # Logging 134 | progress_bar: true 135 | log_to_console: true 136 | console_log_interval: 100ba 137 | 138 | callbacks: 139 | speed_monitor: 140 | window_size: 100 141 | lr_monitor: {} 142 | scheduled_gc: {} 143 | log_grad_norm: 144 | batch_log_interval: 100 145 | packing_efficiency: 146 | log_interval: 100 147 | 148 | # W&B logging 149 | loggers: 150 | wandb: 151 | project: encoder_vs_decoder 152 | entity: your_wandb_entity 153 | 154 | # Checkpoint to local filesystem or remote object store 155 | save_interval: 8_500_000_000tok 156 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 157 | save_folder: your_save_folder 158 | 159 | 160 | # Load from local filesystem or remote object store to 161 | autoresume: true 162 | load_path: your_load_path 163 | reset_time: true 164 | restart_override: true -------------------------------------------------------------------------------- /pretraining/configs/standard/large/de_decay/encoder_large.yaml: -------------------------------------------------------------------------------- 1 | data_local: sample_250B 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: large_encoder_no_packing_de_decay 11 | 12 | # Model 13 | model: 14 | name: flex_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 28 25 | hidden_size: 1024 26 | intermediate_size: 2624 27 | num_attention_heads: 16 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | 
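# Note on rotary_emb_base below: the 1024-token standard configs use 10000.0, while the
# 7999-token decay and cross-train configs raise it to 160000.0, a common RoPE adjustment
# when extending the context length.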
rotary_emb_base: 160000.0 56 | rotary_emb_scale_base: null 57 | rotary_emb_interleaved: false 58 | allow_embedding_resizing: true 59 | sliding_window: 128 60 | global_attn_every_n_layers: 3 61 | unpad_embeddings: true 62 | compile_model: true 63 | masked_prediction: true 64 | 65 | # Dataloaders 66 | train_loader: 67 | name: text 68 | dataset: 69 | local: ${data_local} 70 | remote: ${data_remote} 71 | split: train 72 | tokenizer_name: ${tokenizer_name} 73 | max_seq_len: ${max_seq_len} 74 | shuffle: true 75 | mlm_probability: ${mlm_probability} 76 | streaming: false 77 | shuffle_seed: 21 78 | drop_last: true 79 | num_workers: 6 80 | sequence_packing: false 81 | batch_size_warmup_min_size: ${device_train_microbatch_size} 82 | batch_size_warmup_tokens: 0tok 83 | 84 | 85 | eval_loader: 86 | name: text 87 | dataset: 88 | local: ${data_local} 89 | remote: 90 | split: validation 91 | tokenizer_name: ${tokenizer_name} 92 | max_seq_len: ${max_seq_len} 93 | shuffle: false 94 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 95 | streaming: false 96 | drop_last: false 97 | num_workers: 3 98 | sequence_packing: false 99 | 100 | 101 | # Optimization 102 | scheduler: 103 | name: one_minus_sqrt 104 | alpha_f: 0.5 105 | t_decay: ${max_duration} 106 | t_max: ${max_duration} 107 | 108 | optimizer: 109 | name: decoupled_stableadamw 110 | lr: 5e-4 # Peak learning rate 111 | betas: 112 | - 0.9 113 | - 0.98 114 | eps: 1.0e-06 115 | weight_decay: 1.0e-5 # Amount of weight decay regularization 116 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 117 | log_grad_norm: true 118 | 119 | max_duration: 250_000_000_000tok 120 | eval_interval: 99999999999999999999999999999999ba # don't do this 121 | global_train_batch_size: 4608 122 | global_eval_batch_size: 1024 123 | 124 | # System 125 | seed: 21 126 | device_eval_batch_size: 6 127 | device_train_microbatch_size: 6 128 | 129 | precision: amp_bf16 130 | 131 | # Logging 132 | progress_bar: true 133 | log_to_console: true 134 | console_log_interval: 100ba 135 | 136 | callbacks: 137 | speed_monitor: 138 | window_size: 100 139 | lr_monitor: {} 140 | scheduled_gc: {} 141 | log_grad_norm: 142 | batch_log_interval: 100 143 | packing_efficiency: 144 | log_interval: 100 145 | 146 | # W&B logging 147 | loggers: 148 | wandb: 149 | project: encoder_vs_decoder 150 | entity: your_wandb_entity 151 | 152 | # Checkpoint to local filesystem or remote object store 153 | save_interval: 8_500_000_000tok 154 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 155 | save_folder: your_save_folder 156 | 157 | # Load from local filesystem or remote object store to 158 | autoresume: true 159 | load_path: your_load_path 160 | reset_time: true -------------------------------------------------------------------------------- /pretraining/configs/standard/large/encoder_large.yaml: -------------------------------------------------------------------------------- 1 | data_local: ettin_tokenized_data 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 1024 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: large_encoder_no_packing 11 | 12 | # Model 13 | model: 14 | name: flex_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 
'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 28 25 | hidden_size: 1024 26 | intermediate_size: 2624 27 | num_attention_heads: 16 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 10000.0 56 | rotary_emb_scale_base: null 57 | rotary_emb_interleaved: false 58 | allow_embedding_resizing: true 59 | sliding_window: 128 60 | global_attn_every_n_layers: 3 61 | unpad_embeddings: true 62 | compile_model: true 63 | masked_prediction: true 64 | 65 | # Dataloaders 66 | train_loader: 67 | name: text 68 | dataset: 69 | local: ${data_local} 70 | remote: ${data_remote} 71 | split: train 72 | tokenizer_name: ${tokenizer_name} 73 | max_seq_len: ${max_seq_len} 74 | shuffle: true 75 | mlm_probability: ${mlm_probability} 76 | streaming: false 77 | shuffle_seed: 21 78 | drop_last: true 79 | num_workers: 6 80 | sequence_packing: false 81 | batch_size_warmup_min_size: ${device_train_microbatch_size} 82 | batch_size_warmup_tokens: 10_000_000_000tok 83 | 84 | 85 | eval_loader: 86 | name: text 87 | dataset: 88 | local: ${data_local} 89 | remote: 90 | split: validation 91 | tokenizer_name: ${tokenizer_name} 92 | max_seq_len: ${max_seq_len} 93 | shuffle: false 94 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 95 | streaming: false 96 | drop_last: false 97 | num_workers: 3 98 | sequence_packing: false 99 | 100 | 101 | # Optimization 102 | scheduler: 103 | name: warmup_stable_decay 104 | t_warmup: 2_000_000_000tok # Warmup to the full LR for 6% of the training duration 105 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 106 | t_decay: 0tok 107 | 108 | optimizer: 109 | name: decoupled_stableadamw 110 | lr: 5e-4 # Peak learning rate 111 | betas: 112 | - 0.9 113 | - 0.98 114 | eps: 1.0e-06 115 | weight_decay: 1.0e-5 # Amount of weight decay regularization 116 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 117 | log_grad_norm: true 118 | 119 | max_duration: 1_700_000_000_000tok 120 | eval_interval: 99999999999999999999999999999999ba # don't do this 121 | global_train_batch_size: 4608 122 | global_eval_batch_size: 1024 123 | 124 | # System 125 | seed: 21 126 | device_eval_batch_size: 48 127 | device_train_microbatch_size: 48 128 | 129 | precision: amp_bf16 130 | 131 | # Logging 132 | progress_bar: true 133 | log_to_console: true 134 | console_log_interval: 100ba 135 | 136 | callbacks: 137 | speed_monitor: 138 | window_size: 100 139 | lr_monitor: {} 140 | scheduled_gc: {} 
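# Descriptive summary of the callbacks in this section (assuming standard Composer semantics
# for the built-ins): speed_monitor reports training throughput over a 100-batch window,
# lr_monitor logs the current learning rate, scheduled_gc triggers Python garbage collection
# on a schedule, and the repo-specific log_grad_norm and packing_efficiency callbacks log
# gradient norms and sequence-packing utilization every 100 batches.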
141 | log_grad_norm: 142 | batch_log_interval: 100 143 | packing_efficiency: 144 | log_interval: 100 145 | 146 | # W&B logging 147 | loggers: 148 | wandb: 149 | project: encoder_vs_decoder 150 | entity: your_wandb_entity 151 | 152 | # Checkpoint to local filesystem or remote object store 153 | save_interval: 8_500_000_000tok 154 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 155 | save_folder: your_save_folder 156 | 157 | # Load from local filesystem or remote object store to 158 | # load_path: null 159 | 160 | autoresume: true -------------------------------------------------------------------------------- /pretraining/configs/standard/large/prolong_decay/encoder_large.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.15 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: large_encoder_no_packing_prolong_decay_lower_mask 11 | 12 | # Model 13 | model: 14 | name: flex_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 28 25 | hidden_size: 1024 26 | intermediate_size: 2624 27 | num_attention_heads: 16 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | local_attn_rotary_emb_base: 160000.0 57 | rotary_emb_scale_base: null 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: true 65 | 66 | # Dataloaders 67 | train_loader: 68 | name: text 69 | dataset: 70 | local: ${data_local} 71 | remote: ${data_remote} 72 | split: train 73 | tokenizer_name: ${tokenizer_name} 74 | max_seq_len: ${max_seq_len} 75 | shuffle: true 76 | mlm_probability: ${mlm_probability} 77 | streaming: false 78 | shuffle_seed: 21 79 | drop_last: true 80 | num_workers: 6 81 | sequence_packing: false 82 | batch_size_warmup_min_size: ${device_train_microbatch_size} 83 | batch_size_warmup_tokens: 0tok 84 | 85 | 86 | eval_loader: 87 | name: text 88 | dataset: 89 | local: ${data_local} 90 | remote: 91 | split: validation 92 | tokenizer_name: ${tokenizer_name} 93 | max_seq_len: 
${max_seq_len} 94 | shuffle: false 95 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 96 | streaming: false 97 | drop_last: false 98 | num_workers: 3 99 | sequence_packing: false 100 | 101 | 102 | # Optimization 103 | scheduler: 104 | name: one_minus_sqrt 105 | alpha_f: 0.04 106 | t_decay: ${max_duration} 107 | t_max: ${max_duration} 108 | 109 | optimizer: 110 | name: decoupled_stableadamw 111 | lr: 2.5e-4 # Peak learning rate 112 | betas: 113 | - 0.9 114 | - 0.98 115 | eps: 1.0e-06 116 | weight_decay: 1.0e-5 # Amount of weight decay regularization 117 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 118 | log_grad_norm: true 119 | 120 | max_duration: 50_000_000_000tok 121 | eval_interval: 99999999999999999999999999999999ba # don't do this 122 | global_train_batch_size: 4608 123 | global_eval_batch_size: 1024 124 | 125 | # System 126 | seed: 21 127 | device_eval_batch_size: 2 128 | device_train_microbatch_size: 2 129 | 130 | precision: amp_bf16 131 | 132 | # Logging 133 | progress_bar: true 134 | log_to_console: true 135 | console_log_interval: 100ba 136 | 137 | callbacks: 138 | speed_monitor: 139 | window_size: 100 140 | lr_monitor: {} 141 | scheduled_gc: {} 142 | log_grad_norm: 143 | batch_log_interval: 100 144 | packing_efficiency: 145 | log_interval: 100 146 | 147 | # W&B logging 148 | loggers: 149 | wandb: 150 | project: encoder_vs_decoder 151 | entity: your_wandb_entity 152 | 153 | # Checkpoint to local filesystem or remote object store 154 | save_interval: 8_500_000_000tok 155 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 156 | save_folder: your_save_folder 157 | 158 | # Load from local filesystem or remote object store to 159 | autoresume: true 160 | load_path: your_load_path 161 | -------------------------------------------------------------------------------- /pretraining/configs/standard/mini/de_decay/decoder_mini.yaml: -------------------------------------------------------------------------------- 1 | data_local: sample_250B 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.0 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: decoder_mini_no_packing_de_decay 11 | 12 | # Model 13 | model: 14 | name: flex_gpt 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 19 25 | hidden_size: 512 26 | intermediate_size: 768 27 | num_attention_heads: 8 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 
| mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | rotary_emb_scale_base: null 57 | rotary_emb_interleaved: false 58 | allow_embedding_resizing: true 59 | sliding_window: 128 60 | global_attn_every_n_layers: 3 61 | unpad_embeddings: true 62 | compile_model: true 63 | masked_prediction: false 64 | causal_mask: true 65 | pad_logits: true 66 | 67 | 68 | # Dataloaders 69 | train_loader: 70 | name: text 71 | dataset: 72 | local: ${data_local} 73 | remote: ${data_remote} 74 | split: train 75 | tokenizer_name: ${tokenizer_name} 76 | max_seq_len: ${max_seq_len} 77 | shuffle: true 78 | mlm_probability: ${mlm_probability} 79 | streaming: false 80 | use_decoder_attn_mask: true 81 | shuffle_seed: 21 82 | suppress_masking: true 83 | drop_last: true 84 | num_workers: 6 85 | sequence_packing: false 86 | batch_size_warmup_min_size: ${device_train_microbatch_size} 87 | batch_size_warmup_tokens: 0tok 88 | 89 | 90 | eval_loader: 91 | name: text 92 | dataset: 93 | local: ${data_local} 94 | remote: 95 | split: validation 96 | tokenizer_name: ${tokenizer_name} 97 | max_seq_len: ${max_seq_len} 98 | shuffle: false 99 | mlm_probability: 0.0 # We always evaluate at 15% masking for consistent comparison 100 | streaming: false 101 | use_decoder_attn_mask: true 102 | suppress_masking: true 103 | drop_last: false 104 | num_workers: 3 105 | sequence_packing: false 106 | 107 | 108 | # Optimization 109 | scheduler: 110 | name: one_minus_sqrt 111 | alpha_f: 0.5 112 | t_decay: ${max_duration} 113 | t_max: ${max_duration} 114 | 115 | optimizer: 116 | name: decoupled_stableadamw 117 | lr: 3e-3 # Peak learning rate 118 | betas: 119 | - 0.9 120 | - 0.98 121 | eps: 1.0e-06 122 | weight_decay: 3.0e-4 # Amount of weight decay regularization 123 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 124 | log_grad_norm: true 125 | 126 | max_duration: 250_000_000_000tok 127 | eval_interval: 99999999999999999999999999999999ba # don't do this 128 | global_train_batch_size: 4608 129 | global_eval_batch_size: 1024 130 | 131 | # System 132 | seed: 21 133 | device_eval_batch_size: 18 134 | device_train_microbatch_size: 18 135 | 136 | precision: amp_bf16 137 | 138 | # Logging 139 | progress_bar: true 140 | log_to_console: true 141 | console_log_interval: 100ba 142 | 143 | callbacks: 144 | speed_monitor: 145 | window_size: 100 146 | lr_monitor: {} 147 | scheduled_gc: {} 148 | log_grad_norm: 149 | batch_log_interval: 100 150 | packing_efficiency: 151 | log_interval: 100 152 | 153 | # W&B logging 154 | loggers: 155 | wandb: 156 | project: encoder_vs_decoder 157 | entity: your_wandb_entity 158 | 159 | # Checkpoint to local filesystem or remote object store 160 | save_interval: 8_500_000_000tok 161 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 162 | save_folder: your_save_folder 163 | 164 | # Load from local filesystem or remote object store to 165 | autoresume: true 166 | load_path: your_load_path 167 | reset_time: true -------------------------------------------------------------------------------- /pretraining/configs/standard/mini/de_decay/encoder_mini.yaml: -------------------------------------------------------------------------------- 1 | data_local: sample_250B 2 | data_remote: # If blank, files must be present in 
data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: encoder_mini_no_packing_de_decay 11 | 12 | # Model 13 | model: 14 | name: flex_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 19 25 | hidden_size: 512 26 | intermediate_size: 768 27 | num_attention_heads: 8 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | rotary_emb_scale_base: null 57 | rotary_emb_interleaved: false 58 | allow_embedding_resizing: true 59 | sliding_window: 128 60 | global_attn_every_n_layers: 3 61 | unpad_embeddings: true 62 | compile_model: true 63 | masked_prediction: true 64 | 65 | # Dataloaders 66 | train_loader: 67 | name: text 68 | dataset: 69 | local: ${data_local} 70 | remote: ${data_remote} 71 | split: train 72 | tokenizer_name: ${tokenizer_name} 73 | max_seq_len: ${max_seq_len} 74 | shuffle: true 75 | mlm_probability: ${mlm_probability} 76 | streaming: false 77 | shuffle_seed: 21 78 | drop_last: true 79 | num_workers: 6 80 | sequence_packing: false 81 | batch_size_warmup_min_size: ${device_train_microbatch_size} 82 | batch_size_warmup_tokens: 0tok 83 | 84 | 85 | eval_loader: 86 | name: text 87 | dataset: 88 | local: ${data_local} 89 | remote: 90 | split: validation 91 | tokenizer_name: ${tokenizer_name} 92 | max_seq_len: ${max_seq_len} 93 | shuffle: false 94 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 95 | streaming: false 96 | drop_last: false 97 | num_workers: 3 98 | sequence_packing: false 99 | 100 | 101 | # Optimization 102 | scheduler: 103 | name: one_minus_sqrt 104 | alpha_f: 0.5 105 | t_decay: ${max_duration} 106 | t_max: ${max_duration} 107 | 108 | optimizer: 109 | name: decoupled_stableadamw 110 | lr: 3e-3 # Peak learning rate 111 | betas: 112 | - 0.9 113 | - 0.98 114 | eps: 1.0e-06 115 | weight_decay: 3.0e-4 # Amount of weight decay regularization 116 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 117 | log_grad_norm: true 118 | 119 | max_duration: 250_000_000_000tok 120 | eval_interval: 99999999999999999999999999999999ba # don't do this 121 | global_train_batch_size: 4608 122 | global_eval_batch_size: 1024 123 | 124 | # System 125 | seed: 21 126 | 
device_eval_batch_size: 18 127 | device_train_microbatch_size: 18 128 | 129 | precision: amp_bf16 130 | 131 | # Logging 132 | progress_bar: true 133 | log_to_console: true 134 | console_log_interval: 100ba 135 | 136 | callbacks: 137 | speed_monitor: 138 | window_size: 100 139 | lr_monitor: {} 140 | scheduled_gc: {} 141 | log_grad_norm: 142 | batch_log_interval: 100 143 | packing_efficiency: 144 | log_interval: 100 145 | 146 | # W&B logging 147 | loggers: 148 | wandb: 149 | project: encoder_vs_decoder 150 | entity: your_wandb_entity 151 | 152 | # Checkpoint to local filesystem or remote object store 153 | save_interval: 8_500_000_000tok 154 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 155 | save_folder: your_save_folder 156 | 157 | 158 | # Load from local filesystem or remote object store to 159 | autoresume: true 160 | load_path: your_load_path 161 | reset_time: true -------------------------------------------------------------------------------- /pretraining/configs/standard/mini/encoder_mini.yaml: -------------------------------------------------------------------------------- 1 | data_local: ettin_tokenized_data 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 1024 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: encoder_mini_no_packing 11 | 12 | # Model 13 | model: 14 | name: flex_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 19 25 | hidden_size: 512 26 | intermediate_size: 768 27 | num_attention_heads: 8 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 10000.0 56 | rotary_emb_scale_base: null 57 | rotary_emb_interleaved: false 58 | allow_embedding_resizing: true 59 | sliding_window: 128 60 | global_attn_every_n_layers: 3 61 | unpad_embeddings: true 62 | compile_model: true 63 | masked_prediction: true 64 | 65 | # Dataloaders 66 | train_loader: 67 | name: text 68 | dataset: 69 | local: ${data_local} 70 | remote: ${data_remote} 71 | split: train 72 | tokenizer_name: ${tokenizer_name} 73 | max_seq_len: ${max_seq_len} 74 | shuffle: true 75 | mlm_probability: ${mlm_probability} 76 | streaming: false 77 | shuffle_seed: 21 78 | drop_last: true 79 | num_workers: 6 80 | sequence_packing: false 81 | 
batch_size_warmup_min_size: ${device_train_microbatch_size} 82 | batch_size_warmup_tokens: 75_000_000_000tok 83 | 84 | 85 | eval_loader: 86 | name: text 87 | dataset: 88 | local: ${data_local} 89 | remote: 90 | split: validation 91 | tokenizer_name: ${tokenizer_name} 92 | max_seq_len: ${max_seq_len} 93 | shuffle: false 94 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 95 | streaming: false 96 | drop_last: false 97 | num_workers: 3 98 | sequence_packing: false 99 | 100 | 101 | # Optimization 102 | scheduler: 103 | name: warmup_stable_decay 104 | t_warmup: 3_000_000_000tok # Warmup to the full LR for 6% of the training duration 105 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 106 | t_decay: 0tok 107 | 108 | optimizer: 109 | name: decoupled_stableadamw 110 | lr: 3e-3 # Peak learning rate 111 | betas: 112 | - 0.9 113 | - 0.98 114 | eps: 1.0e-06 115 | weight_decay: 3.0e-4 # Amount of weight decay regularization 116 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 117 | log_grad_norm: true 118 | 119 | max_duration: 1_700_000_000_000tok 120 | eval_interval: 99999999999999999999999999999999ba # don't do this 121 | global_train_batch_size: 4608 122 | global_eval_batch_size: 1024 123 | 124 | # System 125 | seed: 21 126 | device_eval_batch_size: 96 127 | device_train_microbatch_size: 96 128 | 129 | precision: amp_bf16 130 | 131 | # Logging 132 | progress_bar: true 133 | log_to_console: true 134 | console_log_interval: 100ba 135 | 136 | callbacks: 137 | speed_monitor: 138 | window_size: 100 139 | lr_monitor: {} 140 | scheduled_gc: {} 141 | log_grad_norm: 142 | batch_log_interval: 100 143 | packing_efficiency: 144 | log_interval: 100 145 | 146 | # W&B logging 147 | loggers: 148 | wandb: 149 | project: encoder_vs_decoder 150 | entity: your_wandb_entity 151 | 152 | # Checkpoint to local filesystem or remote object store 153 | save_interval: 8_500_000_000tok 154 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 155 | save_folder: your_save_folder 156 | 157 | 158 | # Load from local filesystem or remote object store to 159 | # load_path: null 160 | 161 | autoresume: true -------------------------------------------------------------------------------- /pretraining/configs/standard/mini/prolong_decay/encoder_mini.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.15 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: encoder_mini_no_packing_prolong_decay_lower_mask 11 | 12 | # Model 13 | model: 14 | name: flex_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 19 25 | hidden_size: 512 26 | intermediate_size: 768 27 | num_attention_heads: 8 # to have head size of 64 28 | attention_layer: rope 
29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | local_attn_rotary_emb_base: 160000.0 57 | rotary_emb_scale_base: null 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: true 65 | 66 | # Dataloaders 67 | train_loader: 68 | name: text 69 | dataset: 70 | local: ${data_local} 71 | remote: ${data_remote} 72 | split: train 73 | tokenizer_name: ${tokenizer_name} 74 | max_seq_len: ${max_seq_len} 75 | shuffle: true 76 | mlm_probability: ${mlm_probability} 77 | streaming: false 78 | shuffle_seed: 21 79 | drop_last: true 80 | num_workers: 6 81 | sequence_packing: false 82 | batch_size_warmup_min_size: ${device_train_microbatch_size} 83 | batch_size_warmup_tokens: 0tok 84 | 85 | 86 | eval_loader: 87 | name: text 88 | dataset: 89 | local: ${data_local} 90 | remote: 91 | split: validation 92 | tokenizer_name: ${tokenizer_name} 93 | max_seq_len: ${max_seq_len} 94 | shuffle: false 95 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 96 | streaming: false 97 | drop_last: false 98 | num_workers: 3 99 | sequence_packing: false 100 | 101 | 102 | # Optimization 103 | scheduler: 104 | name: one_minus_sqrt 105 | alpha_f: 0.04 106 | t_decay: ${max_duration} 107 | t_max: ${max_duration} 108 | 109 | optimizer: 110 | name: decoupled_stableadamw 111 | lr: 1.5e-3 # Peak learning rate 112 | betas: 113 | - 0.9 114 | - 0.98 115 | eps: 1.0e-06 116 | weight_decay: 3.0e-4 # Amount of weight decay regularization 117 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 118 | log_grad_norm: true 119 | 120 | max_duration: 50_000_000_000tok 121 | eval_interval: 99999999999999999999999999999999ba # don't do this 122 | global_train_batch_size: 4608 123 | global_eval_batch_size: 1024 124 | 125 | # System 126 | seed: 21 127 | device_eval_batch_size: 8 128 | device_train_microbatch_size: 8 129 | 130 | precision: amp_bf16 131 | 132 | # Logging 133 | progress_bar: true 134 | log_to_console: true 135 | console_log_interval: 100ba 136 | 137 | callbacks: 138 | speed_monitor: 139 | window_size: 100 140 | lr_monitor: {} 141 | scheduled_gc: {} 142 | log_grad_norm: 143 | batch_log_interval: 100 144 | packing_efficiency: 145 | log_interval: 100 146 | 147 | # W&B logging 148 | loggers: 149 | wandb: 150 | project: encoder_vs_decoder 151 | entity: your_wandb_entity 152 | 153 | # Checkpoint to local filesystem or remote object store 154 | save_interval: 8_500_000_000tok 155 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 156 | save_folder: your_save_folder 157 | 158 | 159 | # Load from local filesystem or remote object store to 160 | autoresume: true 161 | load_path: your_load_path 162 | reset_time: true 163 | 
restart_override: true -------------------------------------------------------------------------------- /pretraining/configs/standard/tiny/de_decay/encoder_tiny.yaml: -------------------------------------------------------------------------------- 1 | data_local: sample_250B 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: encoder_tiny_no_packing_de_decay 11 | 12 | # Model 13 | model: 14 | name: flex_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 10 25 | hidden_size: 384 26 | intermediate_size: 576 27 | num_attention_heads: 6 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | rotary_emb_scale_base: null 57 | rotary_emb_interleaved: false 58 | allow_embedding_resizing: true 59 | sliding_window: 128 60 | global_attn_every_n_layers: 3 61 | unpad_embeddings: true 62 | compile_model: true 63 | masked_prediction: true 64 | 65 | # Dataloaders 66 | train_loader: 67 | name: text 68 | dataset: 69 | local: ${data_local} 70 | remote: ${data_remote} 71 | split: train 72 | tokenizer_name: ${tokenizer_name} 73 | max_seq_len: ${max_seq_len} 74 | shuffle: true 75 | mlm_probability: ${mlm_probability} 76 | streaming: false 77 | shuffle_seed: 21 78 | drop_last: true 79 | num_workers: 6 80 | sequence_packing: false 81 | batch_size_warmup_min_size: ${device_train_microbatch_size} 82 | batch_size_warmup_tokens: 0tok 83 | 84 | 85 | eval_loader: 86 | name: text 87 | dataset: 88 | local: ${data_local} 89 | remote: 90 | split: validation 91 | tokenizer_name: ${tokenizer_name} 92 | max_seq_len: ${max_seq_len} 93 | shuffle: false 94 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 95 | streaming: false 96 | drop_last: false 97 | num_workers: 3 98 | sequence_packing: false 99 | 100 | 101 | # Optimization 102 | scheduler: 103 | name: one_minus_sqrt 104 | alpha_f: 0.5 105 | t_decay: ${max_duration} 106 | t_max: ${max_duration} 107 | 108 | optimizer: 109 | name: decoupled_stableadamw 110 | lr: 3e-3 # Peak learning rate 111 | betas: 112 | - 0.9 113 | - 0.98 114 | eps: 1.0e-06 115 | weight_decay: 3.0e-4 # Amount of weight decay regularization 116 | filter_bias_norm_wd: true # 
If True, doesn't apply weight decay to norm layers and biases 117 | log_grad_norm: true 118 | 119 | max_duration: 250_000_000_000tok 120 | eval_interval: 99999999999999999999999999999999ba # don't do this 121 | global_train_batch_size: 4608 122 | global_eval_batch_size: 1024 123 | 124 | # System 125 | seed: 21 126 | device_eval_batch_size: 24 127 | device_train_microbatch_size: 24 128 | 129 | precision: amp_bf16 130 | 131 | # Logging 132 | progress_bar: true 133 | log_to_console: true 134 | console_log_interval: 100ba 135 | 136 | callbacks: 137 | speed_monitor: 138 | window_size: 100 139 | lr_monitor: {} 140 | scheduled_gc: {} 141 | log_grad_norm: 142 | batch_log_interval: 100 143 | packing_efficiency: 144 | log_interval: 100 145 | 146 | # W&B logging 147 | loggers: 148 | wandb: 149 | project: encoder_vs_decoder 150 | entity: your_wandb_entity 151 | 152 | # Checkpoint to local filesystem or remote object store 153 | save_interval: 8_500_000_000tok 154 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 155 | save_folder: your_save_folder 156 | 157 | # Load from local filesystem or remote object store to 158 | autoresume: true 159 | load_path: your_load_path 160 | reset_time: true -------------------------------------------------------------------------------- /pretraining/configs/standard/tiny/encoder_tiny.yaml: -------------------------------------------------------------------------------- 1 | data_local: ettin_tokenized_data 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 1024 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: encoder_tiny_no_packing 11 | 12 | # Model 13 | model: 14 | name: flex_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 10 25 | hidden_size: 384 26 | intermediate_size: 576 27 | num_attention_heads: 6 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 10000.0 56 | rotary_emb_scale_base: null 57 | rotary_emb_interleaved: false 58 | allow_embedding_resizing: true 59 | sliding_window: 128 60 | global_attn_every_n_layers: 3 61 | unpad_embeddings: true 62 | compile_model: true 63 | masked_prediction: true 64 | 65 | # Dataloaders 66 | train_loader: 67 | name: text 68 | dataset: 69 | local: 
${data_local} 70 | remote: ${data_remote} 71 | split: train 72 | tokenizer_name: ${tokenizer_name} 73 | max_seq_len: ${max_seq_len} 74 | shuffle: true 75 | mlm_probability: ${mlm_probability} 76 | streaming: false 77 | shuffle_seed: 21 78 | drop_last: true 79 | num_workers: 6 80 | sequence_packing: false 81 | batch_size_warmup_min_size: ${device_train_microbatch_size} 82 | batch_size_warmup_tokens: 100_000_000_000tok 83 | 84 | 85 | eval_loader: 86 | name: text 87 | dataset: 88 | local: ${data_local} 89 | remote: 90 | split: validation 91 | tokenizer_name: ${tokenizer_name} 92 | max_seq_len: ${max_seq_len} 93 | shuffle: false 94 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 95 | streaming: false 96 | drop_last: false 97 | num_workers: 3 98 | sequence_packing: false 99 | 100 | 101 | # Optimization 102 | scheduler: 103 | name: warmup_stable_decay 104 | t_warmup: 4_000_000_000tok # Warmup to the full LR for 6% of the training duration 105 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 106 | t_decay: 0tok 107 | 108 | optimizer: 109 | name: decoupled_stableadamw 110 | lr: 3e-3 # Peak learning rate 111 | betas: 112 | - 0.9 113 | - 0.98 114 | eps: 1.0e-06 115 | weight_decay: 3.0e-4 # Amount of weight decay regularization 116 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 117 | log_grad_norm: true 118 | 119 | max_duration: 1_700_000_000_000tok 120 | eval_interval: 99999999999999999999999999999999ba # don't do this 121 | global_train_batch_size: 4608 122 | global_eval_batch_size: 1024 123 | 124 | # System 125 | seed: 21 126 | device_eval_batch_size: 128 127 | device_train_microbatch_size: 128 128 | 129 | precision: amp_bf16 130 | 131 | # Logging 132 | progress_bar: true 133 | log_to_console: true 134 | console_log_interval: 100ba 135 | 136 | callbacks: 137 | speed_monitor: 138 | window_size: 100 139 | lr_monitor: {} 140 | scheduled_gc: {} 141 | log_grad_norm: 142 | batch_log_interval: 100 143 | packing_efficiency: 144 | log_interval: 100 145 | 146 | # W&B logging 147 | loggers: 148 | wandb: 149 | project: encoder_vs_decoder 150 | entity: your_wandb_entity 151 | 152 | # Checkpoint to local filesystem or remote object store 153 | save_interval: 8_500_000_000tok 154 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 155 | save_folder: your_save_folder 156 | 157 | # Load from local filesystem or remote object store to 158 | # load_path: null 159 | 160 | autoresume: true -------------------------------------------------------------------------------- /pretraining/configs/standard/tiny/prolong_decay/encoder_tiny.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.15 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: encoder_tiny_no_packing_prolong_decay_lower_mask 11 | 12 | # Model 13 | model: 14 | name: flex_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 
20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 10 25 | hidden_size: 384 26 | intermediate_size: 576 27 | num_attention_heads: 6 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | local_attn_rotary_emb_base: 160000.0 57 | rotary_emb_scale_base: null 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: true 65 | 66 | # Dataloaders 67 | train_loader: 68 | name: text 69 | dataset: 70 | local: ${data_local} 71 | remote: ${data_remote} 72 | split: train 73 | tokenizer_name: ${tokenizer_name} 74 | max_seq_len: ${max_seq_len} 75 | shuffle: true 76 | mlm_probability: ${mlm_probability} 77 | streaming: false 78 | shuffle_seed: 21 79 | drop_last: true 80 | num_workers: 6 81 | sequence_packing: false 82 | batch_size_warmup_min_size: ${device_train_microbatch_size} 83 | batch_size_warmup_tokens: 0tok 84 | 85 | 86 | eval_loader: 87 | name: text 88 | dataset: 89 | local: ${data_local} 90 | remote: 91 | split: validation 92 | tokenizer_name: ${tokenizer_name} 93 | max_seq_len: ${max_seq_len} 94 | shuffle: false 95 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 96 | streaming: false 97 | drop_last: false 98 | num_workers: 3 99 | sequence_packing: false 100 | 101 | 102 | # Optimization 103 | scheduler: 104 | name: one_minus_sqrt 105 | alpha_f: 0.04 106 | t_decay: ${max_duration} 107 | t_max: ${max_duration} 108 | 109 | optimizer: 110 | name: decoupled_stableadamw 111 | lr: 1.5e-3 # Peak learning rate 112 | betas: 113 | - 0.9 114 | - 0.98 115 | eps: 1.0e-06 116 | weight_decay: 3.0e-4 # Amount of weight decay regularization 117 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 118 | log_grad_norm: true 119 | 120 | max_duration: 50_000_000_000tok 121 | eval_interval: 99999999999999999999999999999999ba # don't do this 122 | global_train_batch_size: 4608 123 | global_eval_batch_size: 1024 124 | 125 | # System 126 | seed: 21 127 | device_eval_batch_size: 12 128 | device_train_microbatch_size: 12 129 | 130 | precision: amp_bf16 131 | 132 | # Logging 133 | progress_bar: true 134 | log_to_console: true 135 | console_log_interval: 100ba 136 | 137 | callbacks: 138 | speed_monitor: 139 | window_size: 100 140 | lr_monitor: {} 141 | scheduled_gc: {} 142 | log_grad_norm: 143 | batch_log_interval: 100 144 | packing_efficiency: 145 | log_interval: 100 146 | 147 | # W&B logging 148 | loggers: 149 | wandb: 150 | project: encoder_vs_decoder 151 | entity: your_wandb_entity 152 | 153 | # Checkpoint to local filesystem or remote object store 154 | save_interval: 
8_500_000_000tok 155 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 156 | save_folder: your_save_folder 157 | 158 | # Load from local filesystem or remote object store to 159 | autoresume: true 160 | load_path: your_load_path 161 | reset_time: true 162 | restart_override: true -------------------------------------------------------------------------------- /pretraining/configs/standard/very_tiny/de_decay/encoder_very_tiny.yaml: -------------------------------------------------------------------------------- 1 | data_local: sample_250B 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: encoder_very_tiny_no_packing_de_decay_rope 11 | 12 | # Model 13 | model: 14 | name: flex_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 7 25 | hidden_size: 256 26 | intermediate_size: 384 27 | num_attention_heads: 4 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | rotary_emb_scale_base: null 57 | local_attn_rotary_emb_base: 10000.0 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: true 65 | 66 | # Dataloaders 67 | train_loader: 68 | name: text 69 | dataset: 70 | local: ${data_local} 71 | remote: ${data_remote} 72 | split: train 73 | tokenizer_name: ${tokenizer_name} 74 | max_seq_len: ${max_seq_len} 75 | shuffle: true 76 | mlm_probability: ${mlm_probability} 77 | streaming: false 78 | shuffle_seed: 21 79 | drop_last: true 80 | num_workers: 6 81 | sequence_packing: false 82 | batch_size_warmup_min_size: ${device_train_microbatch_size} 83 | batch_size_warmup_tokens: 0tok 84 | 85 | 86 | eval_loader: 87 | name: text 88 | dataset: 89 | local: ${data_local} 90 | remote: 91 | split: validation 92 | tokenizer_name: ${tokenizer_name} 93 | max_seq_len: ${max_seq_len} 94 | shuffle: false 95 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 96 | streaming: false 97 | drop_last: false 98 | num_workers: 3 99 | sequence_packing: false 100 | 101 | 102 | # Optimization 103 | 
scheduler: 104 | name: one_minus_sqrt 105 | alpha_f: 0.5 106 | t_decay: ${max_duration} 107 | t_max: ${max_duration} 108 | 109 | optimizer: 110 | name: decoupled_stableadamw 111 | lr: 3e-3 # Peak learning rate 112 | betas: 113 | - 0.9 114 | - 0.98 115 | eps: 1.0e-06 116 | weight_decay: 3.0e-4 # Amount of weight decay regularization 117 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 118 | log_grad_norm: true 119 | 120 | max_duration: 250_000_000_000tok 121 | eval_interval: 99999999999999999999999999999999ba # don't do this 122 | global_train_batch_size: 4608 123 | global_eval_batch_size: 1024 124 | 125 | # System 126 | seed: 21 127 | device_eval_batch_size: 24 128 | device_train_microbatch_size: 24 129 | 130 | precision: amp_bf16 131 | 132 | # Logging 133 | progress_bar: true 134 | log_to_console: true 135 | console_log_interval: 100ba 136 | 137 | callbacks: 138 | speed_monitor: 139 | window_size: 100 140 | lr_monitor: {} 141 | scheduled_gc: {} 142 | log_grad_norm: 143 | batch_log_interval: 100 144 | packing_efficiency: 145 | log_interval: 100 146 | 147 | # W&B logging 148 | loggers: 149 | wandb: 150 | project: encoder_vs_decoder 151 | entity: your_wandb_entity 152 | 153 | # Checkpoint to local filesystem or remote object store 154 | save_interval: 8_500_000_000tok 155 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 156 | save_folder: your_save_folder 157 | 158 | # Load from local filesystem or remote object store to 159 | autoresume: true 160 | load_path: your_load_path 161 | reset_time: true 162 | -------------------------------------------------------------------------------- /pretraining/configs/standard/very_tiny/encoder_very_tiny.yaml: -------------------------------------------------------------------------------- 1 | data_local: ettin_tokenized_data 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 1024 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: encoder_very_tiny_no_packing 11 | 12 | # Model 13 | model: 14 | name: flex_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 7 25 | hidden_size: 256 26 | intermediate_size: 384 27 | num_attention_heads: 4 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 
54 | rotary_emb_dim: null 55 | rotary_emb_base: 10000.0 56 | rotary_emb_scale_base: null 57 | rotary_emb_interleaved: false 58 | allow_embedding_resizing: true 59 | sliding_window: 128 60 | global_attn_every_n_layers: 3 61 | unpad_embeddings: true 62 | compile_model: true 63 | masked_prediction: true 64 | 65 | # Dataloaders 66 | train_loader: 67 | name: text 68 | dataset: 69 | local: ${data_local} 70 | remote: ${data_remote} 71 | split: train 72 | tokenizer_name: ${tokenizer_name} 73 | max_seq_len: ${max_seq_len} 74 | shuffle: true 75 | mlm_probability: ${mlm_probability} 76 | streaming: false 77 | shuffle_seed: 21 78 | drop_last: true 79 | num_workers: 6 80 | sequence_packing: false 81 | batch_size_warmup_min_size: ${device_train_microbatch_size} 82 | batch_size_warmup_tokens: 125_000_000_000tok 83 | 84 | 85 | eval_loader: 86 | name: text 87 | dataset: 88 | local: ${data_local} 89 | remote: 90 | split: validation 91 | tokenizer_name: ${tokenizer_name} 92 | max_seq_len: ${max_seq_len} 93 | shuffle: false 94 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 95 | streaming: false 96 | drop_last: false 97 | num_workers: 3 98 | sequence_packing: false 99 | 100 | 101 | # Optimization 102 | scheduler: 103 | name: warmup_stable_decay 104 | t_warmup: 4_000_000_000tok # Warmup to the full LR for 6% of the training duration 105 | alpha_f: 0.02 # Linearly decay to 0.02x the full LR by the end of the training duration 106 | t_decay: 0tok 107 | 108 | optimizer: 109 | name: decoupled_stableadamw 110 | lr: 3e-3 # Peak learning rate 111 | betas: 112 | - 0.9 113 | - 0.98 114 | eps: 1.0e-06 115 | weight_decay: 3.0e-4 # Amount of weight decay regularization 116 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 117 | log_grad_norm: true 118 | 119 | max_duration: 1_700_000_000_000tok 120 | eval_interval: 99999999999999999999999999999999ba # don't do this 121 | global_train_batch_size: 4608 122 | global_eval_batch_size: 1024 123 | 124 | # System 125 | seed: 21 126 | device_eval_batch_size: 144 127 | device_train_microbatch_size: 144 128 | 129 | precision: amp_bf16 130 | 131 | # Logging 132 | progress_bar: true 133 | log_to_console: true 134 | console_log_interval: 100ba 135 | 136 | callbacks: 137 | speed_monitor: 138 | window_size: 100 139 | lr_monitor: {} 140 | scheduled_gc: {} 141 | log_grad_norm: 142 | batch_log_interval: 100 143 | packing_efficiency: 144 | log_interval: 100 145 | 146 | # W&B logging 147 | loggers: 148 | wandb: 149 | project: encoder_vs_decoder 150 | entity: your_wandb_entity 151 | 152 | # Checkpoint to local filesystem or remote object store 153 | save_interval: 8_500_000_000tok 154 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 155 | save_folder: your_save_folder 156 | 157 | # Load from local filesystem or remote object store to 158 | # load_path: null 159 | 160 | autoresume: true -------------------------------------------------------------------------------- /pretraining/configs/standard/very_tiny/prolong_decay/encoder_very_tiny.yaml: -------------------------------------------------------------------------------- 1 | data_local: ProLong 2 | data_remote: # If blank, files must be present in data_local 3 | 4 | max_seq_len: 7999 5 | tokenizer_name: bclavie/olmo_bert_template 6 | mlm_probability: 0.15 # FlexBERT should use 30% masking for optimal performance 7 | count_padding_tokens: false 8 | 9 | # Run Name 10 | run_name: 
encoder_very_tiny_no_packing_prolong_decay_lower_mask 11 | 12 | # Model 13 | model: 14 | name: flex_bert 15 | pretrained_model_name: bert-base-uncased 16 | tokenizer_name: ${tokenizer_name} 17 | disable_train_metrics: true 18 | # FlexBERT 'base' generally uses the default architecture values from the Hugging Face BertConfig object 19 | # Note: if using the pretrained_checkpoint argument to create a model from an existing checkpoint, make sure 20 | # the model_config settings match the architecture of the existing model 21 | model_config: 22 | vocab_size: 50368 23 | init_method: full_megatron 24 | num_hidden_layers: 7 25 | hidden_size: 256 26 | intermediate_size: 384 27 | num_attention_heads: 4 # to have head size of 64 28 | attention_layer: rope 29 | attention_probs_dropout_prob: 0.0 30 | attn_out_bias: false 31 | attn_out_dropout_prob: 0.1 32 | attn_qkv_bias: false 33 | bert_layer: prenorm 34 | embed_dropout_prob: 0.0 35 | embed_norm: true 36 | final_norm: true 37 | skip_first_prenorm: true 38 | embedding_layer: sans_pos 39 | loss_function: fa_cross_entropy 40 | loss_kwargs: 41 | reduction: mean 42 | mlp_dropout_prob: 0.0 43 | mlp_in_bias: false 44 | mlp_layer: glu 45 | mlp_out_bias: false 46 | normalization: layernorm 47 | norm_kwargs: 48 | eps: 1e-5 49 | bias: false 50 | hidden_act: gelu 51 | head_pred_act: gelu 52 | activation_function: gelu # better safe than sorry 53 | padding: unpadded 54 | rotary_emb_dim: null 55 | rotary_emb_base: 160000.0 56 | local_attn_rotary_emb_base: 160000.0 57 | rotary_emb_scale_base: null 58 | rotary_emb_interleaved: false 59 | allow_embedding_resizing: true 60 | sliding_window: 128 61 | global_attn_every_n_layers: 3 62 | unpad_embeddings: true 63 | compile_model: true 64 | masked_prediction: true 65 | 66 | # Dataloaders 67 | train_loader: 68 | name: text 69 | dataset: 70 | local: ${data_local} 71 | remote: ${data_remote} 72 | split: train 73 | tokenizer_name: ${tokenizer_name} 74 | max_seq_len: ${max_seq_len} 75 | shuffle: true 76 | mlm_probability: ${mlm_probability} 77 | streaming: false 78 | shuffle_seed: 21 79 | drop_last: true 80 | num_workers: 6 81 | sequence_packing: false 82 | batch_size_warmup_min_size: ${device_train_microbatch_size} 83 | batch_size_warmup_tokens: 0tok 84 | 85 | 86 | eval_loader: 87 | name: text 88 | dataset: 89 | local: ${data_local} 90 | remote: 91 | split: validation 92 | tokenizer_name: ${tokenizer_name} 93 | max_seq_len: ${max_seq_len} 94 | shuffle: false 95 | mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison 96 | streaming: false 97 | drop_last: false 98 | num_workers: 3 99 | sequence_packing: false 100 | 101 | 102 | # Optimization 103 | scheduler: 104 | name: one_minus_sqrt 105 | alpha_f: 0.04 106 | t_decay: ${max_duration} 107 | t_max: ${max_duration} 108 | 109 | optimizer: 110 | name: decoupled_stableadamw 111 | lr: 1.5e-3 # Peak learning rate 112 | betas: 113 | - 0.9 114 | - 0.98 115 | eps: 1.0e-06 116 | weight_decay: 3.0e-4 # Amount of weight decay regularization 117 | filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases 118 | log_grad_norm: true 119 | 120 | max_duration: 50_000_000_000tok 121 | eval_interval: 99999999999999999999999999999999ba # don't do this 122 | global_train_batch_size: 4608 123 | global_eval_batch_size: 1024 124 | 125 | # System 126 | seed: 21 127 | device_eval_batch_size: 12 128 | device_train_microbatch_size: 12 129 | 130 | precision: amp_bf16 131 | 132 | # Logging 133 | progress_bar: true 134 | log_to_console: true 135 | 
console_log_interval: 100ba 136 | 137 | callbacks: 138 | speed_monitor: 139 | window_size: 100 140 | lr_monitor: {} 141 | scheduled_gc: {} 142 | log_grad_norm: 143 | batch_log_interval: 100 144 | packing_efficiency: 145 | log_interval: 100 146 | 147 | # W&B logging 148 | loggers: 149 | wandb: 150 | project: encoder_vs_decoder 151 | entity: your_wandb_entity 152 | 153 | # Checkpoint to local filesystem or remote object store 154 | save_interval: 8_500_000_000tok 155 | save_num_checkpoints_to_keep: 100000 # Important, this cleans up checkpoints saved to DISK 156 | save_folder: your_save_folder 157 | 158 | # Load from local filesystem or remote object store to 159 | autoresume: true 160 | load_path: your_load_path 161 | reset_time: true 162 | restart_override: true 163 | -------------------------------------------------------------------------------- /pretraining/data_processing/bin/chunk_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | datasets=( 4 | # list datasets here 5 | ) 6 | 7 | 8 | for dataset in "${datasets[@]}" 9 | do 10 | # change /text/ to /tokenized_olmo/ 11 | dataset=$(echo $dataset | sed 's/text\//tokenized_olmo\//') 12 | # add -tokenized to the end 13 | dataset="$dataset-tokenized" 14 | echo "Chunking $dataset" 15 | python src/sampling/split_dataset_into_chunks.py -s $dataset -c 8192 -m 512 -a 32 --batch_size 1000 --backfill --backfill_no_duplicates --num_processes 40 --add_eos_token --reverse 16 | done 17 | 18 | ## after run: 19 | # python src/ettin_data/sampling/move_chunks.py data/tokenized_olmo/ 8192-512-32-backfill-nodups 20 | # python ./bin/compile_final_dataset_stats.py data/chunked-olmo-8192-512-32-backfill-nodups/ 21 | 22 | -------------------------------------------------------------------------------- /pretraining/data_processing/bin/count_instances.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | def count_samples(file_path): 5 | with open(file_path, 'r') as file: 6 | data = json.load(file) 7 | 8 | total_samples = sum(shard['samples'] for shard in data['shards']) 9 | return total_samples 10 | 11 | def main(): 12 | parser = argparse.ArgumentParser(description="Count total samples from a JSON file containing shard information.") 13 | parser.add_argument("file_path", help="Path to the JSON file") 14 | args = parser.parse_args() 15 | 16 | try: 17 | total_samples = count_samples(args.file_path) 18 | print(f"Total number of samples: {total_samples}") 19 | except FileNotFoundError: 20 | print(f"Error: File '{args.file_path}' not found.") 21 | except json.JSONDecodeError: 22 | print(f"Error: '{args.file_path}' is not a valid JSON file.") 23 | except KeyError: 24 | print("Error: The JSON file does not have the expected structure.") 25 | 26 | if __name__ == "__main__": 27 | main() -------------------------------------------------------------------------------- /pretraining/data_processing/bin/datasets.txt: -------------------------------------------------------------------------------- 1 | orionweller/books_mds_incremental 2 | orionweller/pes2o_mds_incremental 3 | orionweller/tulu_flan_mds_incremental 4 | orionweller/starcoder_mds_incremental 5 | orionweller/stackexchange_mds_incremental 6 | orionweller/arxiv_mds_incremental 7 | # these are examples, add yours here! 
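As a quick illustration of the input that count_instances.py above expects (the counts and path here are made up, not taken from the repository): the script only reads the top-level "shards" list of an MDS index.json and sums each shard's "samples" field, so a minimal file of that shape would be {"shards": [{"samples": 120000}, {"samples": 98000}]}. Running python bin/count_instances.py path/to/index.json on such a file would print "Total number of samples: 218000"; any missing "shards"/"samples" key is caught by the script's KeyError handler.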
-------------------------------------------------------------------------------- /pretraining/data_processing/bin/dolma_mds_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # convert all dolma datasets to mds - I have already run these, but if you want to do similar you can! 3 | python src/initial_dataset_creation/dolma_to_mds.py -s "falcon" -r "orionweller/refinedweb_mds_incremental" 4 | python src/initial_dataset_creation/dolma_to_mds.py -s "c4" -r "orionweller/c4_mds_incremental" 5 | python src/initial_dataset_creation/dolma_to_mds.py -s "books" -r "orionweller/books_mds_incremental" 6 | python src/initial_dataset_creation/dolma_to_mds.py -s "cc_en_head" -r "orionweller/cc_en_head_mds_incremental" 7 | python src/initial_dataset_creation/dolma_to_mds.py -s "cc_en_tail" -r "orionweller/cc_en_tail_mds_incremental" 8 | python src/initial_dataset_creation/dolma_to_mds.py -s "cc_en_middle" -r "orionweller/cc_en_middle_mds_incremental" 9 | python src/initial_dataset_creation/dolma_to_mds.py -s "megawika" -r "orionweller/megawika_mds_incremental" 10 | python src/initial_dataset_creation/dolma_to_mds.py -s "cc_news" -r "orionweller/cc_news_mds_incremental" 11 | python src/initial_dataset_creation/dolma_to_mds.py -s "pes2o" -r "orionweller/pes2o_mds_incremental" 12 | python src/initial_dataset_creation/dolma_to_mds.py -s "tulu_flan" -r "orionweller/tulu_flan_mds_incremental" 13 | python src/initial_dataset_creation/dolma_to_mds.py -s "starcoder" -r "orionweller/starcoder_mds_incremental" 14 | python src/initial_dataset_creation/dolma_to_mds.py -s "stackexchange" -r "orionweller/stackexchange_mds_incremental" 15 | python src/initial_dataset_creation/dolma_to_mds.py -s "arxiv" -r "orionweller/arxiv_mds_incremental" 16 | python src/initial_dataset_creation/dolma_to_mds.py -s "open-web-math" -r "orionweller/open-web-math_mds_incremental" 17 | python src/initial_dataset_creation/dolma_to_mds.py -s "reddit" -r "orionweller/reddit_mds_incremental" 18 | python src/initial_dataset_creation/dolma_to_mds.py -s "algebraic-stack" -r "orionweller/algebraic-stack_mds_incremental" 19 | python src/initial_dataset_creation/dolma_to_mds.py -s "wiki-" -r "orionweller/wikipedia_mds_incremental" 20 | -------------------------------------------------------------------------------- /pretraining/data_processing/bin/download_all_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | BASE_DIR=downloaded_data 3 | mkdir -p $BASE_DIR 4 | 5 | # this is just useful for bulk downloading 6 | # for name in ./bin/dataset.txt, read in each line and call ./bin/download_folder.py -r $name 7 | 8 | while IFS= read -r line 9 | do 10 | echo "Downloading $line" 11 | nice -n 10 python ./bin/download_folder.py -r $line 12 | done < bin/datasets.txt -------------------------------------------------------------------------------- /pretraining/data_processing/bin/download_folder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import glob 4 | import shutil 5 | import huggingface_hub 6 | 7 | def download_dataset(args): 8 | print(f"Downloading {args.repo} to {args.cache_dir}") 9 | root_folder = huggingface_hub.snapshot_download( 10 | repo_id=args.repo, 11 | repo_type="dataset", 12 | cache_dir=args.cache_dir, 13 | ) 14 | print(f"Downloaded to {root_folder}") 15 | 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser() 19 | 
parser.add_argument("-r", "--repo", type=str, required=True, help="The HF repo to download") 20 | parser.add_argument("-c", "--cache_dir", type=str, default="data/text/", help="The cache directory") 21 | args = parser.parse_args() 22 | download_dataset(args) 23 | -------------------------------------------------------------------------------- /pretraining/data_processing/bin/download_from_hub.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from huggingface_hub import snapshot_download 3 | import glob 4 | import os 5 | 6 | def download_files(repo_id, pattern, local_dir, token=None): 7 | """ 8 | Download files from a Hugging Face repository that match a specific pattern. 9 | 10 | Args: 11 | repo_id (str): The repository ID (e.g., 'username/repo-name') 12 | pattern (str): File pattern to match (e.g., '*.txt', 'model/*.safetensors') 13 | local_dir (str): Local directory to save the files 14 | token (str, optional): Hugging Face authentication token for private repos 15 | """ 16 | try: 17 | # Create the local directory if it doesn't exist 18 | os.makedirs(local_dir, exist_ok=True) 19 | if token is None: 20 | token = os.environ.get("HF_TOKEN") 21 | 22 | assert token is not None, "HF_TOKEN environment variable is not set" 23 | 24 | # Use the allow_patterns parameter to filter files 25 | print(f"Downloading files from {repo_id} with pattern `{pattern}` to {local_dir}") 26 | snapshot_download( 27 | repo_id=repo_id, 28 | # allow_patterns=pattern, 29 | local_dir=local_dir, 30 | token=token, 31 | repo_type="dataset", 32 | ) 33 | 34 | print(f"Successfully downloaded files matching '{pattern}' to {local_dir}") 35 | 36 | except Exception as e: 37 | print(f"Error downloading files: {str(e)}") 38 | 39 | def main(): 40 | parser = argparse.ArgumentParser(description='Download files from Hugging Face Hub') 41 | parser.add_argument('--repo', required=True, help='Repository ID (e.g., username/repo-name)') 42 | parser.add_argument('--pattern', required=True, help='File pattern to match (e.g., *.txt)') 43 | parser.add_argument('--output', required=True, help='Local directory to save files') 44 | parser.add_argument('--token', help='Hugging Face authentication token', default=None) 45 | 46 | args = parser.parse_args() 47 | 48 | download_files(args.repo, args.pattern, args.output, args.token) 49 | 50 | if __name__ == "__main__": 51 | main() 52 | # this file is most useful for downloading a single dataset from huggingface -------------------------------------------------------------------------------- /pretraining/data_processing/bin/make_recursive_root.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from streaming.base.util import merge_index 4 | import typer 5 | 6 | def recursive_merge(current_folder): 7 | subdirs = [d for d in glob.glob(os.path.join(current_folder, '*')) if os.path.isdir(d)] 8 | 9 | # Check if all direct subdirectories have index.json 10 | all_subdirs_have_index = all(os.path.exists(os.path.join(d, 'index.json')) for d in subdirs) 11 | 12 | if all_subdirs_have_index: 13 | # If all subdirectories have index.json, merge them 14 | index_files = [os.path.join(d, 'index.json') for d in subdirs] 15 | print(f"Merging {len(index_files)} index files in {current_folder}") 16 | print(f"Example index file: {index_files[0]}") 17 | merge_index(index_files, current_folder) 18 | else: 19 | # If any subdirectory doesn't have index.json, process it recursively 20 | for subdir in 
subdirs: 21 | if not os.path.exists(os.path.join(subdir, 'index.json')): 22 | recursive_merge(subdir) 23 | 24 | # After processing subdirectories, check again if we can merge 25 | if all(os.path.exists(os.path.join(d, 'index.json')) for d in subdirs): 26 | index_files = [os.path.join(d, 'index.json') for d in subdirs] 27 | print(f"Merging {len(index_files)} index files in {current_folder}") 28 | print(f"Example index file: {index_files[0]}") 29 | merge_index(index_files, current_folder) 30 | 31 | app = typer.Typer() 32 | 33 | @app.command() 34 | def recursively_make_root( 35 | folder_path: str = typer.Argument(..., help="Path to the folder to process"), 36 | ): 37 | # Start the recursive merging process 38 | recursive_merge(folder_path) 39 | 40 | if __name__ == "__main__": 41 | app() 42 | 43 | # NOTE: you can easily combine mds datasets with this simple script 44 | # python make_recursive_root.py /path/to/folder/of/train -------------------------------------------------------------------------------- /pretraining/data_processing/bin/tokenize_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | datasets=( 5 | # add dataset with local paths here like 6 | "data/text/datasets--orionweller--wikipedia_mds_incremental/snapshots/aff2afa7d7274979206600f1b53d7869eebc3dc9" 7 | ) 8 | 9 | for dataset in "${datasets[@]}" 10 | do 11 | echo "Tokenizing $dataset" 12 | python src/tokenization/tokenize_mds_subfolders.py -t answerdotai/ModernBERT-base -r $dataset -n 40 13 | # sometimes you have to run the above multiple times and then run the below 14 | python src/utils/compare_subfolders.py -l $dataset 15 | python src/tokenization/move_tokenized.py $dataset --tokenizer_name olmo_space 16 | python ./bin/make_root.py $dataset --tokenizer_name olmo_space 17 | python bin/count_tokenized_tokens_from_file.py --dataset_path $dataset-tokenized --tokenizer_name olmo_space 18 | done -------------------------------------------------------------------------------- /pretraining/data_processing/bin/upload_large_folder.py: -------------------------------------------------------------------------------- 1 | import huggingface_hub 2 | import os 3 | import argparse 4 | 5 | 6 | def upload_folder(args): 7 | print(f"Creating a new repo {args.repo}") 8 | api = huggingface_hub.HfApi() 9 | repo_url = api.create_repo( 10 | args.repo, 11 | repo_type="dataset", 12 | exist_ok=True, 13 | ) 14 | # Upload all the content from the local folder to your remote Space. 
15 | # By default, files are uploaded at the root of the repo 16 | print(f"Uploading {args.folder} to {args.repo}") 17 | api.upload_large_folder( 18 | folder_path=args.folder, 19 | repo_id=args.repo, 20 | repo_type="dataset", 21 | ) 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser(description="Upload a folder to Hugging Face Hub") 25 | parser.add_argument("-f", "--folder", type=str, help="The folder to upload", required=True) 26 | parser.add_argument("-r", "--repo", type=str, help="The repo to upload to", required=True) 27 | args = parser.parse_args() 28 | upload_folder(args) 29 | 30 | 31 | # example usage: 32 | # python push_folder_to_hub.py -f downloaded_data/fineweb-edu-350B -r orionweller/fineweb-edu-350B -------------------------------------------------------------------------------- /pretraining/data_processing/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | numpy 3 | tokenizers 4 | huggingface_hub 5 | pyyaml 6 | ruff 7 | tqdm 8 | mosaicml-streaming==0.8.1 9 | -------------------------------------------------------------------------------- /pretraining/data_processing/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JHU-CLSP/ettin-encoder-vs-decoder/6cbf954a2ecaa7c753be6468a1bccdef401d5f93/pretraining/data_processing/src/__init__.py -------------------------------------------------------------------------------- /pretraining/data_processing/src/initial_dataset_creation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JHU-CLSP/ettin-encoder-vs-decoder/6cbf954a2ecaa7c753be6468a1bccdef401d5f93/pretraining/data_processing/src/initial_dataset_creation/__init__.py -------------------------------------------------------------------------------- /pretraining/data_processing/src/initial_dataset_creation/mds_to_jsonl.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import huggingface_hub 4 | import os 5 | import tqdm 6 | import random 7 | from transformers import AutoTokenizer, set_seed 8 | import datasets 9 | import argparse 10 | from streaming import StreamingDataset 11 | 12 | 13 | from src.utils.data_utils import SOURCE_MAP 14 | 15 | 16 | def mds_to_jsonl(args): 17 | source_repo = SOURCE_MAP[args.source] 18 | assert os.path.isdir(source_repo), f"Source {source_repo} does not exist." 
19 | print(f"Using local dataset {source_repo}...") 20 | dataset = StreamingDataset(local=source_repo, shuffle=False, split=None, batch_size=1, shuffle_seed=9176) 21 | out_f = open(args.out_file, "w") 22 | for instance in tqdm.tqdm(dataset): 23 | out_f.write(json.dumps(instance) + "\n") 24 | 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("-s", "--source", type=str, required=True) 30 | parser.add_argument("-o", "--out_file", type=str, required=True) 31 | args = parser.parse_args() 32 | 33 | mds_to_jsonl(args) 34 | -------------------------------------------------------------------------------- /pretraining/data_processing/src/initial_dataset_creation/merge_mds_to_one_index.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gzip 3 | import numpy as np 4 | import multiprocessing 5 | import huggingface_hub 6 | import glob 7 | import tempfile 8 | 9 | from streaming.base.util import merge_index 10 | 11 | 12 | def merge(root_folder): 13 | # merge them all together by gathering all index.json files 14 | string_files = list(glob.glob(root_folder + "/**/index.json", recursive=True)) 15 | print(f"Merging {len(string_files)} files") 16 | merge_index(string_files, root_folder) 17 | 18 | print(f"Merged to {root_folder}/index.json") 19 | 20 | 21 | if __name__ == "__main__": 22 | import argparse 23 | parser = argparse.ArgumentParser(description="Merge all index.json files in a directory") 24 | parser.add_argument("root_folder", help="Path to the root folder to process") 25 | args = parser.parse_args() 26 | 27 | merge(args.root_folder) 28 | 29 | # python merge_mds_to_one_index.py data/text/mlfoundations-dclm-baseline-1.0-parquet---train---small/ -------------------------------------------------------------------------------- /pretraining/data_processing/src/sampling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JHU-CLSP/ettin-encoder-vs-decoder/6cbf954a2ecaa7c753be6468a1bccdef401d5f93/pretraining/data_processing/src/sampling/__init__.py -------------------------------------------------------------------------------- /pretraining/data_processing/src/sampling/move_out_final_sampled_chunks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import glob 4 | import json 5 | import logging 6 | from tqdm import tqdm 7 | from streaming import StreamingDataset 8 | from streaming.base.util import merge_index 9 | 10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 11 | 12 | def reorganize_and_index_datasets(main_dir): 13 | # Create train and validation directories in the main folder 14 | train_dir = os.path.join(main_dir, 'train') 15 | validation_dir = os.path.join(main_dir, 'validation') 16 | os.makedirs(train_dir, exist_ok=True) 17 | os.makedirs(validation_dir, exist_ok=True) 18 | 19 | # Get all dataset names 20 | dataset_names = [name for name in os.listdir(main_dir) 21 | if os.path.isdir(os.path.join(main_dir, name)) 22 | and name not in ['train', 'validation'] and "-FULL" not in name] 23 | 24 | for dataset_name in tqdm(dataset_names, desc="Processing datasets"): 25 | dataset_path = os.path.join(main_dir, dataset_name) 26 | 27 | # Move train data 28 | src_train = os.path.join(dataset_path, 'train') 29 | dst_train = os.path.join(train_dir, dataset_name) 30 | if os.path.exists(src_train): 31 | 
logging.info(f"Moved {src_train} to {dst_train}") 32 | shutil.move(src_train, dst_train) 33 | 34 | # Move validation data 35 | src_validation = os.path.join(dataset_path, 'validation') 36 | dst_validation = os.path.join(validation_dir, dataset_name) 37 | if os.path.exists(src_validation): 38 | logging.info(f"Moved {src_validation} to {dst_validation}") 39 | shutil.move(src_validation, dst_validation) 40 | 41 | # Combine index.json files for the dataset 42 | for split in ['train']: 43 | split_dir = os.path.join(main_dir, split, dataset_name) 44 | # rename the current index.json file to index.json.old 45 | if os.path.exists(os.path.join(split_dir, 'index.json')): 46 | os.rename(os.path.join(split_dir, 'index.json'), os.path.join(split_dir, 'index.json.old')) 47 | if os.path.exists(split_dir): 48 | index_files = glob.glob(f"{split_dir}/**/index.json", recursive=True) 49 | if index_files: 50 | logging.info(f"Merging {len(index_files)} index files for {dataset_name} {split}") 51 | merge_index(index_files, split_dir) 52 | 53 | # Combine index.json files for the entire dataset 54 | for split in ['train', 'validation']: 55 | split_dir = os.path.join(main_dir, split) 56 | index_files = glob.glob(f"{split_dir}/*/index.json") 57 | if index_files: 58 | logging.info(f"Merging {len(index_files)} index files for entire {split} set") 59 | merge_index(index_files, split_dir) 60 | 61 | # Load the overall dataset and print the number of instances 62 | for split in ['train', 'validation']: 63 | split_dir = os.path.join(main_dir, split) 64 | if os.path.exists(split_dir): 65 | dataset = StreamingDataset(local=split_dir, predownload=1, batch_size=1) 66 | num_instances = len(dataset) 67 | logging.info(f"Number of instances in {split} set: {num_instances}") 68 | 69 | 70 | if __name__ == "__main__": 71 | import argparse 72 | 73 | parser = argparse.ArgumentParser(description="Reorganize datasets and create combined index") 74 | parser.add_argument("main_dir", help="Main directory containing the datasets") 75 | 76 | args = parser.parse_args() 77 | 78 | reorganize_and_index_datasets(args.main_dir) 79 | 80 | # Example usage: 81 | # python move_out_final_sampled_chunks.py data/chunked-olmo-1024-512-128-backfill-nodups -------------------------------------------------------------------------------- /pretraining/data_processing/src/sampling/sample_from_folders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import shutil 5 | import argparse 6 | import time 7 | import tqdm 8 | from transformers import set_seed 9 | 10 | set_seed(42) 11 | 12 | def sample_folders(source_dir, target_tokens, output_dir): 13 | total_tokens = 0 14 | selected_folders = [] 15 | 16 | # Get all subdirectories 17 | all_folders = [f for f in os.listdir(source_dir) if os.path.isdir(os.path.join(source_dir, f))] 18 | 19 | while total_tokens < target_tokens and all_folders: 20 | # Randomly select a folder 21 | folder = random.choice(all_folders) 22 | all_folders.remove(folder) 23 | 24 | folder_path = os.path.join(source_dir, folder) 25 | stats_file = os.path.join(folder_path, 'stats.json') 26 | 27 | if os.path.exists(stats_file): 28 | with open(stats_file, 'r') as f: 29 | stats = json.load(f) 30 | 31 | folder_tokens = stats['total_tokens_written'] 32 | 33 | total_tokens += folder_tokens 34 | selected_folders.append(folder) 35 | print(f"Selected folder: {folder}, Tokens: {folder_tokens:,}, Total: {total_tokens:,}") 36 | 37 | # Create output directory 38 | 
os.makedirs(output_dir, exist_ok=True) 39 | print(f"Created output directory: {output_dir}") 40 | 41 | # Copy selected folders to output directory 42 | for folder in tqdm.tqdm(selected_folders): 43 | src = os.path.join(source_dir, folder) 44 | dst = os.path.join(output_dir, folder) 45 | print(f"copying {src} to {dst}") 46 | shutil.copytree(src, dst) 47 | 48 | print(f"Sampling complete. Total tokens: {total_tokens}") 49 | return total_tokens, selected_folders 50 | 51 | def main(): 52 | parser = argparse.ArgumentParser(description="Sample folders based on token count.") 53 | parser.add_argument("source_dir", help="Source directory to sample from") 54 | parser.add_argument("tokens_to_sample", type=int, help="Number of tokens to sample") 55 | args = parser.parse_args() 56 | 57 | source_dir = args.source_dir 58 | tokens_to_sample = args.tokens_to_sample 59 | output_dir = f"{source_dir}-sampled" 60 | print(f"Writing output to {output_dir}") 61 | time.sleep(5) 62 | 63 | total_tokens, selected_folders = sample_folders(source_dir, tokens_to_sample, output_dir) 64 | 65 | print(f"Sampled {len(selected_folders)} folders") 66 | print(f"Total tokens sampled: {total_tokens}") 67 | print(f"Output directory: {output_dir}") 68 | 69 | if __name__ == "__main__": 70 | main() 71 | 72 | # example usage: 73 | # python sample_from_folders.py "data/chunked-space-1024-512-128-backfill-nodups/mlfoundations-dclm-baseline-1.0-parquet-FULL" 837179337679 # (837,179,337,679) -------------------------------------------------------------------------------- /pretraining/data_processing/src/tokenization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JHU-CLSP/ettin-encoder-vs-decoder/6cbf954a2ecaa7c753be6468a1bccdef401d5f93/pretraining/data_processing/src/tokenization/__init__.py -------------------------------------------------------------------------------- /pretraining/data_processing/src/tokenization/tokenize_mds.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import tqdm 5 | from transformers import AutoTokenizer, set_seed 6 | from streaming import StreamingDataset, MDSWriter 7 | from src.utils.data_utils import MDS_COLS_PRE_TOKENIZED 8 | import multiprocessing as mp 9 | from functools import partial 10 | import numpy as np 11 | import uuid 12 | import gc 13 | import streaming 14 | 15 | def get_uuid(): 16 | return str(uuid.uuid4()) 17 | 18 | def tokenize_batch(tokenizer, batch): 19 | tokenized = tokenizer(batch, truncation=False, padding=False, return_tensors="np") 20 | # Convert each sequence individually to uint32 21 | input_ids = [np.array(seq, dtype=np.uint32) for seq in tokenized['input_ids']] 22 | return input_ids 23 | 24 | def process_chunk(chunk, tokenizer): 25 | assert isinstance(chunk, list) and all(isinstance(item, dict) for item in chunk), "Chunk should be a list of dictionaries" 26 | texts = [item["text"] for item in chunk] 27 | if "id" not in chunk[0]: 28 | # create ids from uuids 29 | ids = [str(uuid.uuid4()) for _ in chunk] 30 | else: 31 | ids = [item["id"] for item in chunk] 32 | input_ids, attention_mask = tokenize_batch(tokenizer, texts) 33 | assert len(ids) == len(input_ids) == len(attention_mask) == len(texts), f"Length mismatch in chunk. 
{len(ids)} {len(input_ids)} {len(attention_mask)} {len(texts)}" 34 | return [{'id': id, 'input_ids': input_id, 'attention_mask': mask, "len": len(input_id)} 35 | for id, input_id, mask in zip(ids, input_ids, attention_mask)] 36 | 37 | def sample_dataset_from_config(args): 38 | assert os.path.isdir(args.dataset), f"Dataset {args.dataset} does not exist." 39 | print(f"Using local dataset {args.dataset}...") 40 | # clean up the shared memory 41 | streaming.base.util.clean_stale_shared_memory() 42 | dataset = StreamingDataset(local=args.dataset, shuffle=False, split=None, batch_size=1, keep_zip=False) 43 | 44 | print(f"Using tokenizer model {args.tokenizer}...") 45 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, use_fast=True, add_prefix_space=True) 46 | 47 | output_dir = args.dataset + "-tokenized" 48 | 49 | num_tokens = 0 50 | num_truncated_tokens = 0 51 | batch_size = 5000 52 | if args.has_domains: 53 | MDS_COLS_PRE_TOKENIZED["domain"] = "str" 54 | 55 | with MDSWriter(out=output_dir, columns=MDS_COLS_PRE_TOKENIZED, compression='zstd') as mds_writer: 56 | pbar = tqdm.tqdm(total=len(dataset), desc="Processing samples") 57 | 58 | for i in range(0, len(dataset), batch_size): 59 | # Get a single batch of data 60 | end = min(i + batch_size, len(dataset)) 61 | chunk = [dataset[k] for k in range(i, end)] 62 | 63 | if not chunk: 64 | break 65 | 66 | # Process the chunk directly without multiprocessing 67 | batch_results = process_chunk(chunk, tokenizer) 68 | 69 | # Write results immediately 70 | for item in batch_results: 71 | mds_writer.write(item) 72 | num_tokens += item['len'] 73 | num_truncated_tokens += min(1024, item['len']) 74 | 75 | # Clear item from memory 76 | del item 77 | 78 | # Clear results from memory 79 | del batch_results 80 | gc.collect() 81 | 82 | pbar.update(len(chunk)) 83 | 84 | if args.debug and pbar.n >= 100 * batch_size: 85 | break 86 | 87 | print(f"Finished writing with a total of {num_tokens} train tokens.") 88 | # save a json file in the directory with the number of tokens and the number of truncated tokens 89 | with open(os.path.join(output_dir, "num_tokens.json"), "w") as f: 90 | f.write(json.dumps({"num_tokens": num_tokens, "num_truncated_tokens": num_truncated_tokens}, indent=2)) 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("-d", "--dataset", type=str, required=True) 95 | parser.add_argument("-t", "--tokenizer", type=str, required=True) 96 | parser.add_argument("--has_domains", action="store_true") 97 | parser.add_argument("--debug", action="store_true") 98 | args = parser.parse_args() 99 | 100 | set_seed(123456789) 101 | sample_dataset_from_config(args) 102 | 103 | # python tokenize_mds.py --dataset data/arxiv/ --tokenizer bclavie/olmo_bert_template -------------------------------------------------------------------------------- /pretraining/data_processing/src/utils/cleanup.py: -------------------------------------------------------------------------------- 1 | from tokenize_mds_subfolders import cleanup_folder 2 | import argparse 3 | 4 | if __name__ == "__main__": 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("-r", "--root_path", type=str, required=True) 7 | args = parser.parse_args() 8 | 9 | cleanup_folder(args.root_path) -------------------------------------------------------------------------------- /pretraining/data_processing/src/utils/cleanup_all.py: -------------------------------------------------------------------------------- 1 | from tokenize_mds_subfolders import 
cleanup_folder 2 | from data_utils import SOURCE_MAP 3 | 4 | 5 | if __name__ == "__main__": 6 | for source_dir in SOURCE_MAP.values(): 7 | print(f"Cleaning up {source_dir}-tokenized ...") 8 | cleanup_folder(source_dir + "-tokenized") -------------------------------------------------------------------------------- /pretraining/data_processing/src/utils/compare_subfolders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from torch.utils.data import DataLoader 4 | from streaming import StreamingDataset 5 | import os 6 | import tqdm 7 | import streaming 8 | 9 | 10 | def compare(args): 11 | 12 | # first check that all folders have a index.json file 13 | for folder in os.listdir(args.local): 14 | if not os.path.isdir(os.path.join(args.local, folder)): 15 | continue 16 | 17 | if folder.endswith("-tokenized"): 18 | continue 19 | 20 | local = os.path.join(args.local, folder) 21 | if not os.path.exists(os.path.join(local, "index.json")): 22 | print(f"Folder: {folder} does not have index.json file") 23 | 24 | local_tokenized = os.path.join(args.local, folder + "-tokenized") 25 | if not os.path.exists(os.path.join(local_tokenized, "index.json")): 26 | print(f"Tokenized folder: {local_tokenized} does not have index.json file") 27 | 28 | # get all folders and sort them 29 | folders = sorted(os.listdir(args.local)) 30 | if args.skip: 31 | folders = folders[args.skip:] 32 | for folder in tqdm.tqdm(folders): 33 | if not os.path.isdir(os.path.join(args.local, folder)): 34 | continue 35 | 36 | if folder.endswith("-tokenized"): 37 | continue 38 | 39 | if folder == "data": 40 | continue 41 | 42 | streaming.base.util.clean_stale_shared_memory() 43 | 44 | local = os.path.join(args.local, folder) 45 | dataset = StreamingDataset(local=local, shuffle=False, split=None, batch_size=1, predownload=1) 46 | len_og = len(dataset) 47 | 48 | local_tokenized = os.path.join(args.local, folder + "-tokenized") 49 | try: 50 | dataset_tokenized = StreamingDataset(local=local_tokenized, shuffle=False, split=None, batch_size=1, predownload=1) 51 | len_tok = len(dataset_tokenized) 52 | except Exception as e: 53 | print(f"Error loading tokenized dataset: {e}") 54 | len_tok = 0 55 | 56 | if len_og != len_tok: 57 | print(f"Folder: {folder} has different lengths: {len_og} vs {len_tok}") 58 | 59 | 60 | if __name__ == "__main__": 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument("-l", "--local", type=str, required=True) 63 | parser.add_argument("-s", "--skip", type=int, default=0) 64 | args = parser.parse_args() 65 | compare(args) 66 | 67 | # NOTE: the skip is there since it only seems to do about 10k before it opens too many files 68 | # so it allows it to try again from the failure point 69 | # python compare_subfolders.py -l data/text/datasets--orionweller--dclm-1T-sample/snapshots/e01c4d93f79aacd04361454cc360da67eefab9a3 -s 42000 -------------------------------------------------------------------------------- /pretraining/data_processing/src/utils/compare_train_and_chunked.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import logging 5 | import pandas as pd 6 | from tqdm import tqdm 7 | import glob 8 | from streaming.base.util import merge_index 9 | 10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 11 | 12 | def create_index_file(directory): 13 | index_files = glob.glob(f"{directory}/**/index.json", 
recursive=True) 14 | if index_files: 15 | logging.info(f"Merging {len(index_files)} index files for {directory}") 16 | merge_index(index_files, directory) 17 | else: 18 | logging.warning(f"No index files found in {directory}") 19 | 20 | def count_instances_from_index(index_path): 21 | if not os.path.exists(index_path): 22 | directory = os.path.dirname(index_path) 23 | create_index_file(directory) 24 | 25 | try: 26 | with open(index_path, 'r') as f: 27 | index_data = json.load(f) 28 | return sum(shard['samples'] for shard in index_data['shards']) 29 | except Exception as e: 30 | logging.error(f"Error loading index from {index_path}: {str(e)}") 31 | return 0 32 | 33 | def compare_datasets(chunking_dir, output_csv): 34 | train_dir = os.path.join(chunking_dir, 'train') 35 | validation_dir = os.path.join(chunking_dir, 'validation') 36 | 37 | subfolders = [f for f in os.listdir(chunking_dir) if os.path.isdir(os.path.join(chunking_dir, f)) 38 | and f not in ['train', 'validation', '.locks'] and "-FULL" not in f] 39 | 40 | results = [] 41 | 42 | for subfolder in tqdm(subfolders, desc="Processing subfolders"): 43 | main_path = os.path.join(chunking_dir, subfolder, 'index.json') 44 | train_path = os.path.join(train_dir, subfolder, 'index.json') 45 | validation_path = os.path.join(validation_dir, subfolder, 'index.json') 46 | 47 | main_count = count_instances_from_index(main_path) 48 | train_count = count_instances_from_index(train_path) 49 | validation_count = count_instances_from_index(validation_path) 50 | 51 | train_ratio = train_count / main_count if main_count > 0 else 0 52 | is_train_ratio_correct = abs(train_ratio - 0.999) <= 0.001 if main_count > 0 else False 53 | is_total_correct = (train_count + validation_count == main_count) 54 | 55 | result = { 56 | 'Dataset': subfolder, 57 | 'Main instances': main_count, 58 | 'Train instances': train_count, 59 | 'Validation instances': validation_count, 60 | 'Train ratio': train_ratio, 61 | 'Is train ratio correct': is_train_ratio_correct, 62 | 'Is total correct': is_total_correct 63 | } 64 | results.append(result) 65 | 66 | logging.info(f"\nDataset: {subfolder}") 67 | logging.info(f"Main instances: {main_count}") 68 | logging.info(f"Train instances: {train_count}") 69 | logging.info(f"Validation instances: {validation_count}") 70 | logging.info(f"Train ratio: {train_ratio:.4f}") 71 | 72 | if not is_train_ratio_correct: 73 | logging.warning(f"Train ratio is not approximately 0.999 (99.9%)") 74 | if not is_total_correct: 75 | logging.warning(f"Sum of train and validation instances does not equal main instances: {train_count} + {validation_count} != {main_count}") 76 | 77 | # Create DataFrame and save to CSV 78 | df = pd.DataFrame(results) 79 | df.to_csv(output_csv, index=False) 80 | logging.info(f"Results saved to {output_csv}") 81 | 82 | # Print summary statistics 83 | logging.info("\nSummary Statistics:") 84 | logging.info(df.describe()) 85 | 86 | # Print datasets with incorrect ratios or totals 87 | incorrect_ratios = df[~df['Is train ratio correct']] 88 | incorrect_totals = df[~df['Is total correct']] 89 | 90 | if not incorrect_ratios.empty: 91 | logging.warning("\nDatasets with incorrect train ratios:") 92 | logging.warning(incorrect_ratios[['Dataset', 'Train ratio']]) 93 | 94 | if not incorrect_totals.empty: 95 | logging.warning("\nDatasets with mismatched totals:") 96 | logging.warning(incorrect_totals[['Dataset', 'Main instances', 'Train instances', 'Validation instances']]) 97 | 98 | if __name__ == "__main__": 99 | parser = 
argparse.ArgumentParser(description="Compare dataset instances in chunking directory and save results to CSV") 100 | parser.add_argument("chunking_dir", help="Path to the chunking directory") 101 | parser.add_argument("--output", default="dataset_comparison_results.csv", help="Output CSV file name") 102 | args = parser.parse_args() 103 | 104 | compare_datasets(args.chunking_dir, args.output) 105 | # python compare_train_and_chunked.py data/chunked-olmo-1024-512-128-backfill-nodups 106 | -------------------------------------------------------------------------------- /pretraining/data_processing/src/utils/create_final_dataset_index.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import logging 5 | import glob 6 | from streaming.base.util import merge_index 7 | from tqdm import tqdm 8 | 9 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 10 | 11 | def create_main_index(main_dir, split): 12 | split_dir = os.path.join(main_dir, split) 13 | index_files = glob.glob(f"{split_dir}/*/index.json") 14 | if index_files: 15 | logging.info(f"Merging {len(index_files)} index files for entire {split} set") 16 | merge_index(index_files, split_dir) 17 | logging.info(f"Created main index.json in {split_dir}") 18 | else: 19 | logging.warning(f"No index files found in {split_dir}") 20 | 21 | def count_instances_from_index(index_path): 22 | try: 23 | with open(index_path, 'r') as f: 24 | index_data = json.load(f) 25 | return sum(shard['samples'] for shard in index_data['shards']) 26 | except Exception as e: 27 | logging.error(f"Error loading index from {index_path}: {str(e)}") 28 | return 0 29 | 30 | def create_and_verify_index(main_dir, split): 31 | # Create the main index.json 32 | create_main_index(main_dir, split) 33 | 34 | # Verify the count 35 | main_index_path = os.path.join(main_dir, split, 'index.json') 36 | main_count = count_instances_from_index(main_index_path) 37 | logging.info(f"Main {split} count from new index.json: {main_count}") 38 | 39 | # Count instances in individual subdirectories for verification 40 | split_dir = os.path.join(main_dir, split) 41 | subfolders = [f for f in os.listdir(split_dir) 42 | if os.path.isdir(os.path.join(split_dir, f)) and f != '.locks'] 43 | 44 | total_count = 0 45 | for subfolder in tqdm(subfolders, desc=f"Verifying {split} subfolders"): 46 | index_path = os.path.join(split_dir, subfolder, 'index.json') 47 | count = count_instances_from_index(index_path) 48 | total_count += count 49 | logging.info(f"Subfolder {subfolder}: {count} instances") 50 | 51 | logging.info(f"Total {split} count from subfolders: {total_count}") 52 | 53 | if main_count == total_count: 54 | logging.info("Verification successful: Main count matches total subfolder count.") 55 | else: 56 | logging.warning(f"Verification failed: Main count ({main_count}) does not match total subfolder count ({total_count}).") 57 | 58 | if __name__ == "__main__": 59 | parser = argparse.ArgumentParser(description="Create main index.json and verify count") 60 | parser.add_argument("main_dir", help="Path to the main directory containing the split") 61 | parser.add_argument("--split", default="train", help="Split to process (default: train)") 62 | args = parser.parse_args() 63 | 64 | create_and_verify_index(args.main_dir, args.split) 65 | # python create_final_dataset_index.py data/chunked-olmo-1024-512-128-backfill-nodups/ 
-------------------------------------------------------------------------------- /pretraining/data_processing/src/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | MDS_COLS_TOKENIZED = { 4 | 'input_ids': 'ndarray:uint32', 5 | 'id': 'str' 6 | } 7 | 8 | MDS_COLS_TEXT = { 9 | 'text': 'str', 10 | 'id': 'str' 11 | } 12 | 13 | MDS_COLS_PRE_TOKENIZED = { 14 | 'input_ids': 'ndarray', 15 | 'id': 'str', 16 | 'len': 'int' 17 | } 18 | 19 | MDS_COLS_OUTPUT_ONLY = { 20 | 'input_ids': 'ndarray:uint32', 21 | 'id': 'str', 22 | } 23 | 24 | 25 | # path to the uploaded datasets on huggingface, not all of these were used 26 | SOURCE_MAP_REMOTE = { 27 | "books": "orionweller/books_mds_incremental", 28 | "wiki": "orionweller/wikipedia_mds_incremental", 29 | "falcon": "orionweller/refinedweb_mds_incremental", 30 | "c4": "orionweller/c4_mds_incremental", 31 | "cc_en_head": "orionweller/cc_en_head_mds_incremental", 32 | "cc_en_tail": "orionweller/cc_en_tail_mds_incremental", 33 | "cc_en_middle": "orionweller/cc_en_middle_mds_incremental", 34 | "megawika": "orionweller/megawika_mds_incremental", 35 | "cc_news": "orionweller/cc_news_mds_incremental", 36 | "pes2o": "orionweller/pes2o_mds_incremental", 37 | "tulu_flan": "orionweller/tulu_flan_mds_incremental", 38 | "starcoder": "orionweller/starcoder_mds_incremental", 39 | "stackexchange": "orionweller/stackexchange_mds_incremental", 40 | "arxiv": "orionweller/arxiv_mds_incremental", 41 | "open_web_math_train": "orionweller/open-web-math_mds_incremental", 42 | "reddit": "orionweller/reddit_mds_incremental", 43 | "algebraic_stack_train": "orionweller/algebraic-stack_mds_incremental", 44 | "caselaw-access-project": "orionweller/caselaw-access-project", 45 | "fineweb-edu-10B": "orionweller/fineweb-edu-10B", 46 | "fineweb-edu-350B": "orionweller/fineweb-edu-350B", 47 | "fineweb-edu-score-2": "orionweller/fineweb-edu-score-2", 48 | } 49 | 50 | 51 | 52 | SOURCE_MAP = { 53 | # Path to the local downloaded versions of the dataset, not all of these were used 54 | "case_access_law": "ettin-data/data/text/TeraflopAI-Caselaw_Access_Project---train---default", 55 | "fineweb-edu": "ettin-data/data/text/HuggingFaceTB-smollm-corpus---train---fineweb-edu-dedup", 56 | "algebraic_stack_train": "ettin-data/data/text/datasets--orionweller--algebraic-stack_mds_incremental/snapshots/5af697376cc89b191fef8b7873280e2c393e8361", 57 | "arxiv": "ettin-data/data/text/datasets--orionweller--arxiv_mds_incremental/snapshots/640f80fa7d7ff93226a1f7115f70145fd1f4ead7", 58 | "books": "ettin-data/data/text/datasets--orionweller--books_mds_incremental/snapshots/502df43dc5445788353f1cf7befdc1a3cbedd6cb", 59 | "c4": "ettin-data/data/text/datasets--orionweller--c4_mds_incremental/snapshots/fdb71eeccbe17fc95d0e0330dea5f9f0e79c7aaa", 60 | "cc_en_head": "ettin-data/data/text/datasets--orionweller--cc_en_head_mds_incremental/snapshots/3f13a7e03eef6df4ee62b486ea912eb926e7be91", 61 | "cc_en_middle": "ettin-data/data/text/datasets--orionweller--cc_en_middle_mds_incremental/snapshots/4e4577a77611d8c6ebcefafa89d64ec7329d8d1b", 62 | "cc_en_tail": "ettin-data/data/text/datasets--orionweller--cc_en_tail_mds_incremental/snapshots/58bf1c63c23598548a1da42cc3dffe42d2672f80", 63 | "cc_news": "ettin-data/data/text/datasets--orionweller--cc_news_mds_incremental/snapshots/846a17dd910daf76ffd96fa735a3dd9240736116", 64 | "megawika": "ettin-data/data/text/datasets--orionweller--megawika_mds_incremental/snapshots/477460d68212afbf7937bbfe0143bf482651b684", 65 | 
"open_web_math_train": "ettin-data/data/text/datasets--orionweller--open-web-math_mds_incremental/snapshots/732910d828ea4f7e1ab62a7a787d5e3bd59210b0", 66 | "pes2o": "ettin-data/data/text/datasets--orionweller--pes2o_mds_incremental/snapshots/71ed50bcfd714e2360c4fcd59d601d4eecc9d1a2", 67 | "reddit": "ettin-data/data/text/datasets--orionweller--reddit_mds_incremental/snapshots/53d1edb1053ffa4b519ba45d9daed14fd82cfd68", 68 | "falcon": "ettin-data/data/text/datasets--orionweller--refinedweb_mds_incremental/snapshots/31ce550bcb0c117c0ce058f166795c338a8f6fa1", 69 | "stackexchange": "ettin-data/data/text/datasets--orionweller--stackexchange_mds_incremental/snapshots/dd95be271cac1709b0dda6776a6e83df93b4d1f0", 70 | "starcoder": "ettin-data/data/text/datasets--orionweller--starcoder_mds_incremental/snapshots/5cef62e15c251baa36779ebab15a9e5ba3a5d7a6", 71 | "tulu_flan": "ettin-data/data/text/datasets--orionweller--tulu_flan_mds_incremental/snapshots/7f00a393e1b26ee2d48b65ca26ab61fd6e82786e", 72 | "wiki": "ettin-data/data/text/datasets--orionweller--wikipedia_mds_incremental/snapshots/aff2afa7d7274979206600f1b53d7869eebc3dc9", 73 | "fineweb-edu-score-2": "ettin-data/data/text/datasets--orionweller--fineweb-edu-score-2/snapshots/755c506cae00da40a7cbe5d8b7dbcf7f6e171de9/HuggingFaceFW-fineweb-edu-score-2---train---default", 74 | "dclm": "ettin-data/data/text/mlfoundations-dclm-baseline-1.0-parquet---train---default", 75 | "cosmopediav2": "ettin-data/data/text/HuggingFaceTB-smollm-corpus---train---cosmopedia-v2" 76 | } 77 | 78 | ALL_REPOS_REMOTE = list(SOURCE_MAP_REMOTE.values()) 79 | ALL_REPOS = list(SOURCE_MAP.values()) -------------------------------------------------------------------------------- /pretraining/data_processing/src/utils/get_counts_from_hf.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import argparse 3 | import random 4 | import json 5 | import datasets 6 | import requests 7 | import math 8 | import os 9 | import gzip 10 | import numpy as np 11 | import multiprocessing 12 | import huggingface_hub 13 | import glob 14 | import tempfile 15 | 16 | from datasets import load_dataset, Dataset, DatasetDict, interleave_datasets 17 | from streaming.base.util import _merge_index_from_root, merge_index 18 | from transformers import set_seed, AutoTokenizer 19 | from streaming import MDSWriter, StreamingDataset 20 | 21 | from huggingface_hub import HfFileSystem 22 | from data_utils import ALL_REPOS 23 | 24 | 25 | def get_counts_for_repo(repo, args): 26 | # download the root index.json only 27 | files_in_repo = [item.path for item in huggingface_hub.list_repo_tree(repo, repo_type="dataset")] 28 | if "index.json" not in files_in_repo: 29 | # it must be in the main folder 30 | main_folder = None 31 | repo_name_folder = repo.split("/")[-1] 32 | for file in files_in_repo: 33 | if file not in [".gitattributes"] and file.count(".") == 0: 34 | main_folder = file 35 | break 36 | main_json = f"{main_folder}/index.json" 37 | print(f"Did not find a root index.json, using {main_json}") 38 | else: 39 | main_json = "index.json" 40 | 41 | with tempfile.TemporaryDirectory() as tmp_cache_dir: 42 | root_folder = huggingface_hub.snapshot_download(repo_id=repo, allow_patterns=main_json, repo_type="dataset", cache_dir=tmp_cache_dir) 43 | dataset = StreamingDataset(local=os.path.join(root_folder, main_json.replace("index.json", "")), shuffle=False, split=None, batch_size=1) 44 | dataset_size = len(dataset) 45 | 46 | base_dir = f"datasets/{repo}" 47 | fs = 
HfFileSystem() 48 | try: 49 | size_of_folder = fs.du(base_dir, total=True, maxdepth=None, withdirs=True) 50 | except Exception as e: 51 | print(f"Error: {e}. Sleeping for 60 seconds and trying again") 52 | import time 53 | time.sleep(60) 54 | size_of_folder = fs.du(base_dir, total=True, maxdepth=None, withdirs=True) 55 | 56 | return {"dataset": repo, "size": size_of_folder / 1e9, "instances": dataset_size} 57 | 58 | 59 | def get_counts(args): 60 | # read in all that have been already processed 61 | if os.path.exists("dataset_info.jsonl"): 62 | with open("dataset_info.jsonl", "r") as f: 63 | processed_datasets = set([json.loads(line)["dataset"] for line in f]) 64 | else: 65 | processed_datasets = set() 66 | 67 | output_f = open("dataset_info.jsonl", "a") 68 | for repo in tqdm.tqdm(args.repos): 69 | if repo in processed_datasets: 70 | print(f"Skipping {repo} since it's already processed") 71 | continue 72 | print(f"Getting counts for {repo}") 73 | output_dict = get_counts_for_repo(repo, args) 74 | output_f.write(json.dumps(output_dict) + "\n") 75 | # flush it 76 | output_f.flush() 77 | 78 | output_f.close() 79 | 80 | # read in the info and sum and print 81 | total_size = 0 82 | total_instances = 0 83 | with open("dataset_info.jsonl", "r") as f: 84 | for line in f: 85 | info = json.loads(line) 86 | total_size += info["size"] 87 | total_instances += info["instances"] 88 | 89 | print(f"Total size: {total_size} GB") 90 | print(f"Total instances: {total_instances}") 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument("--repos", type=str, nargs="+", help="List of repos to get counts for", default=None) 96 | args = parser.parse_args() 97 | 98 | # if repos is None use the default ALL_REPOS 99 | if args.repos is None: 100 | args.repos = ALL_REPOS 101 | 102 | get_counts(args) 103 | 104 | # example usage: 105 | # python get_counts_from_hf.py -------------------------------------------------------------------------------- /pretraining/data_processing/src/utils/upload_dataset_by_subfolders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from huggingface_hub import HfApi, HfFolder 4 | from huggingface_hub.utils import RepositoryNotFoundError 5 | 6 | MAX_FOLDERS = 100000 # HF doesn't allow more than this number of folders in a repo 7 | 8 | def get_existing_folders(api, repo_id): 9 | try: 10 | return set(item.rfilename.split('/')[0] for item in api.list_repo_files(repo_id, repo_type="dataset") if '/' in item.rfilename) 11 | except RepositoryNotFoundError: 12 | return set() 13 | 14 | def upload_subfolder(api, repo_id, subfolder_path, subfolder_name): 15 | print(f"Uploading subfolder: {subfolder_name}") 16 | try: 17 | api.upload_folder( 18 | folder_path=subfolder_path, 19 | repo_id=repo_id, 20 | repo_type="dataset", 21 | path_in_repo=subfolder_name, 22 | commit_message=f"Upload subfolder: {subfolder_name}", 23 | create_pr=False, 24 | multi_commits=True, # Explicitly use multi_commits 25 | multi_commits_verbose=True # Add verbosity for better tracking 26 | ) 27 | print(f"Successfully uploaded {subfolder_name}") 28 | return True 29 | except Exception as e: 30 | print(f"Error uploading {subfolder_name}: {str(e)}") 31 | return False 32 | 33 | def upload_subfolders(args): 34 | api = HfApi() 35 | 36 | # Create repo if it doesn't exist 37 | if not args.skip_create: 38 | try: 39 | api.create_repo(args.repo, repo_type="dataset", exist_ok=True) 40 | print(f"Repository {args.repo} created or 
already exists.") 41 | except Exception as e: 42 | print(f"Error creating repository: {str(e)}") 43 | return 44 | 45 | # Get list of existing folders 46 | existing_folders = get_existing_folders(api, args.repo) 47 | print(f"Existing folders: {existing_folders}") 48 | 49 | # Get list of local subfolders 50 | subfolders = [f for f in os.listdir(args.folder) if os.path.isdir(os.path.join(args.folder, f))] 51 | 52 | # Check if number of subfolders exceeds the limit 53 | if len(subfolders) > MAX_FOLDERS: 54 | print(f"Error: The number of subfolders ({len(subfolders)}) exceeds the maximum allowed ({MAX_FOLDERS}).") 55 | return 56 | 57 | # Upload each subfolder 58 | for subfolder in subfolders: 59 | if subfolder in existing_folders: 60 | print(f"Skipping {subfolder} as it already exists in the repository.") 61 | continue 62 | 63 | subfolder_path = os.path.join(args.folder, subfolder) 64 | success = upload_subfolder(api, args.repo, subfolder_path, subfolder) 65 | 66 | if success: 67 | existing_folders.add(subfolder) 68 | else: 69 | print(f"Failed to upload {subfolder}. Skipping to next subfolder.") 70 | 71 | print("Upload process completed.") 72 | 73 | if __name__ == "__main__": 74 | parser = argparse.ArgumentParser(description="Upload subfolders to Hugging Face Hub") 75 | parser.add_argument("-f", "--folder", type=str, help="The parent folder containing subfolders to upload", required=True) 76 | parser.add_argument("-r", "--repo", type=str, help="The repo to upload to", required=True) 77 | parser.add_argument("--skip_create", action="store_true", help="Skip creating the repository") 78 | args = parser.parse_args() 79 | upload_subfolders(args) 80 | 81 | # python upload_dataset_by_subfolders.py -f data/tokenized_olmo/datasets--orionweller--algebraic-stack_mds_incremental/snapshots/5af697376cc89b191fef8b7873280e2c393e8361-tokenized -------------------------------------------------------------------------------- /retrieval_eval/README.md: -------------------------------------------------------------------------------- 1 | # Retrieval Evaluation 2 | 3 | This directory contains scripts and documentation for evaluating Ettin encoder and decoder models on retrieval tasks, including fine-tuning on MS MARCO and evaluation on MTEB v2 English benchmarks. 4 | 5 | ## Quick Start 6 | 7 | ### Installation 8 | 9 | ```bash 10 | # Install retrieval dependencies 11 | pip install sentence-transformers mteb 12 | ``` 13 | 14 | 15 | ## Training 16 | 17 | The `train_st.py` script allows you to fine-tune Ettin models (both encoder and decoder variants) on the MS MARCO dataset for retrieval tasks. The script supports both encoder-only models and decoder models with configurable pooling strategies. 
18 | 19 | ### Usage 20 | 21 | ```bash 22 | python train_st.py --lr <learning_rate> --model_name <model_name> --model_out_dir <output_dir> --model_suffix <suffix> --accum_steps <accum_steps> --bsize <batch_size> [additional_options] 23 | ``` 24 | 25 | ### Required Arguments 26 | 27 | - `--lr`: Learning rate (float) 28 | - `--model_name`: Path or name of the base model to fine-tune 29 | - `--model_out_dir`: Directory where trained models will be saved 30 | - `--model_suffix`: Suffix to append to the run name for identification 31 | - `--accum_steps`: Number of gradient accumulation steps (int) 32 | - `--bsize`: Per-device training batch size (int) 33 | 34 | ### Optional Arguments 35 | 36 | - `--gc_bsize`: Gradient cache batch size for CachedMultipleNegativesRankingLoss (default: 64) 37 | - `--warmup_ratio`: Warmup ratio for learning rate scheduling (default: 0.05) 38 | - `--scale`: Temperature scaling parameter for the loss function (default: 20) 39 | - `--pooling`: Pooling strategy - choices: `lasttoken`, `mean`, `weightedmean` (default: `lasttoken`) 40 | - `--fp16`: Enable FP16 mixed precision training 41 | - `--bf16`: Enable BF16 mixed precision training 42 | - `--resume_training`: Resume training from checkpoint 43 | - `--decoder`: Use decoder model architecture instead of encoder 44 | 45 | ### Pooling Strategies 46 | 47 | - **`lasttoken`**: Use the last token's representation (suitable for decoder models) 48 | - **`mean`**: Average all token representations 49 | - **`weightedmean`**: Weighted average of token representations 50 | 51 | ### Training Examples 52 | 53 | #### Encoder Training 54 | ```bash 55 | python train_st.py \ 56 | --lr 3e-4 \ 57 | --model_name "jhu-clsp/ettin-encoder-17m" \ 58 | --model_out_dir "./models" \ 59 | --model_suffix "encoder-v1" \ 60 | --bf16 61 | ``` 62 | 63 | #### Decoder Model Training 64 | ```bash 65 | python train_st.py \ 66 | --lr 3e-4 \ 67 | --model_name "jhu-clsp/ettin-decoder-17m" \ 68 | --model_out_dir "./models" \ 69 | --model_suffix "decoder-v1" \ 70 | --decoder \ 71 | --pooling lasttoken \ 72 | --bf16 73 | ``` 74 | 75 | ### Evaluation Examples 76 | Evaluation was performed with [MTEB](https://github.com/embeddings-benchmark/mteb/tree/main). Please see their documentation for more details. To reproduce results on MTEB v2 (English), you can use the following: 77 | 78 | ```python 79 | import mteb 80 | from sentence_transformers import SentenceTransformer 81 | 82 | # Load the fine-tuned sentence-transformers model 83 | model_name = "path_to_your_model" 84 | model = SentenceTransformer(model_name) 85 | benchmark = mteb.get_benchmark("MTEB(eng, v2)") 86 | evaluation = mteb.MTEB(tasks=benchmark) 87 | results = evaluation.run(model, output_folder=f"results/{model_name}") 88 | ``` --------------------------------------------------------------------------------
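After fine-tuning with `train_st.py`, the saved checkpoint can also be used directly for ad-hoc retrieval with `sentence-transformers`, independent of MTEB. The snippet below is a minimal sketch; the model path and the example texts are hypothetical placeholders for whatever `--model_out_dir` and `--model_suffix` produced in your run.

```python
from sentence_transformers import SentenceTransformer, util

# Placeholder path -- substitute the output directory written by train_st.py
model = SentenceTransformer("./models/your-finetuned-run")

queries = ["what is dense retrieval?"]
documents = [
    "Dense retrieval encodes queries and documents as vectors and ranks them by similarity.",
    "BM25 is a sparse lexical ranking function based on term statistics.",
]

# Encode both sides and score with cosine similarity (embeddings are L2-normalized first)
query_emb = model.encode(queries, normalize_embeddings=True)
doc_emb = model.encode(documents, normalize_embeddings=True)
scores = util.cos_sim(query_emb, doc_emb)
print(scores)
```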