├── .github └── workflows │ └── checks.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .vscode ├── extensions.json └── settings.json ├── LICENSE ├── README.md ├── configs └── stories │ ├── llama2 │ ├── 100k.json │ ├── 10m.json │ ├── 1m.json │ ├── 2.5m.json │ ├── 250k.json │ ├── 25m.json │ ├── 500k.json │ ├── 50k.json │ ├── 50m.json │ ├── 5m.json │ ├── README.md │ └── base.json │ └── mamba │ ├── 100k.json │ ├── 10m.json │ ├── 1m.json │ ├── 2.5m.json │ ├── 250k.json │ ├── 25m.json │ ├── 500k.json │ ├── 50k.json │ ├── 50m.json │ ├── 5m.json │ ├── README.md │ └── base.json ├── delphi ├── __init__.py ├── eval.py ├── test_configs │ ├── __init__.py │ └── debug.json ├── tokenization.py ├── train │ ├── __init__.py │ ├── checkpoint_step.py │ ├── config │ │ ├── __init__.py │ │ ├── adam_config.py │ │ ├── dataset_config.py │ │ ├── debug_config.py │ │ ├── training_config.py │ │ └── utils.py │ ├── run_context.py │ ├── shuffle.py │ ├── train_step.py │ ├── training.py │ ├── utils.py │ └── wandb_utils.py └── utils.py ├── notebooks ├── .gitkeep └── eval_notebook.ipynb ├── pyproject.toml ├── scripts ├── .gitkeep ├── get_next_logprobs.py ├── tokenize_dataset.py ├── train_model.py ├── train_tokenizer.py └── validate_configs.py ├── setup.py └── tests ├── __init__.py ├── test_eval.py ├── test_tokeniation.py ├── test_utils.py └── train ├── config └── test_config_utils.py ├── test_shuffle.py ├── test_train_step.py ├── test_utils.py └── test_wandb_utils.py /.github/workflows/checks.yml: -------------------------------------------------------------------------------- 1 | name: checks 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - "*" 10 | 11 | permissions: 12 | actions: write 13 | 14 | jobs: 15 | checks: 16 | name: checks 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v4 20 | with: 21 | submodules: recursive 22 | - name: setup python 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: "3.10" 26 | cache: "pip" 27 | - name: cache models and datasets 28 | uses: actions/cache@v3 29 | with: 30 | path: | 31 | ~/.cache/huggingface 32 | key: ${{ runner.os }}-hf-cache-v0.2 # increment this key to invalidate the cache when new models/datasets are added 33 | - name: dependencies 34 | run: | 35 | python -m pip install --upgrade pip 36 | pip install -e .[dev,notebooks] 37 | - name: black 38 | run: black --check . 39 | - name: isort 40 | run: isort --check . 41 | - name: pytest 42 | run: pytest 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | bin 10 | include 11 | lib64 12 | pyvenv.cfg 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
165 | #.idea/ 166 | 167 | # ignore wandb files 168 | **/wandb/* 169 | **/*.wandb 170 | **/wandb-summary.json 171 | **/wandb-metadata.json 172 | 173 | # scratch notebook 174 | notebooks/scratch.ipynb 175 | 176 | # dsstore 177 | .DS_Store 178 | 179 | # vscode debug configs 180 | **/launch.json -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black-pre-commit-mirror 3 | rev: 23.12.1 4 | hooks: 5 | - id: black 6 | language_version: python3.10 7 | - repo: https://github.com/pycqa/isort 8 | rev: 5.13.2 9 | hooks: 10 | - id: isort 11 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-python.python", 4 | "ms-python.vscode-pylance", 5 | "ms-python.black-formatter", 6 | "ms-python.isort", 7 | "github.copilot", 8 | ] 9 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "[python]": { 3 | "editor.defaultFormatter": "ms-python.black-formatter", 4 | }, 5 | "editor.formatOnSave": true, 6 | "editor.codeActionsOnSave": { 7 | "source.organizeImports": "explicit" 8 | }, 9 | "python.analysis.typeCheckingMode": "basic", 10 | "black-formatter.importStrategy": "fromEnvironment", 11 | "python.testing.pytestArgs": [ 12 | "tests" 13 | ], 14 | "python.testing.unittestEnabled": false, 15 | "python.testing.pytestEnabled": true, 16 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # delphi 2 | 3 | delphi is a set of tools for standardized and (mostly) reproducible training of small language models. You can use delphi to train a custom tokenizer, tokenize your dataset, and train your model. We build on top of HuggingFace, supporting every `CausalLM` architecture. Datasets, tokenizers and models (including checkpoints!) 
can be downloaded from and uploaded to HuggingFace automatically, with no need to manage local files. 4 | 5 | 6 | # Setup 7 | 8 | 1. Clone the repo 9 | ```shell 10 | git clone https://github.com/delphi-suite/delphi.git 11 | cd delphi 12 | ``` 13 | 2. Make and activate a Python >= 3.10 virtual env 14 | ```shell 15 | python3.10 -m venv .venv 16 | source .venv/bin/activate 17 | ``` 18 | 3. Install the project in editable mode 19 | `pip install -e .` 20 | See the `[project.optional-dependencies]` section in `pyproject.toml` for additional dependencies, e.g. you may want to `pip install -e ."[dev,mamba_cuda]"` 21 | 4. Get your HuggingFace and W&B tokens and put them in these environment variables 22 | ```shell 23 | export HF_TOKEN=... 24 | export WANDB_API_KEY=... 25 | ``` 26 | 27 | 28 | # Training a tokenizer 29 | 30 | If you want to train a small and efficient model on a narrow dataset, we recommend using a custom tokenizer with a small vocabulary. To train a reversible, GPT2-style BPE tokenizer, you can use `scripts/train_tokenizer.py`. 31 | 32 | Script usage: 33 | 34 | ``` 35 | > scripts/train_tokenizer.py --help 36 | usage: train_tokenizer.py [-h] --in-dataset IN_DATASET --feature FEATURE --split SPLIT 37 | --vocab-size VOCAB_SIZE 38 | [--out-dir OUT_DIR] [--out-repo OUT_REPO] 39 | 40 | Train a custom, reversible, BPE tokenizer (GPT2-like). You need to provide --out-repo or --out-dir. 41 | 42 | options: 43 | -h, --help show this help message and exit 44 | --in-dataset IN_DATASET, -i IN_DATASET 45 | Dataset you want to train the tokenizer on. Local path or HF repo id 46 | --feature FEATURE, -f FEATURE 47 | Name of the feature (column) containing text documents in the input dataset 48 | --split SPLIT, -s SPLIT 49 | Split of the dataset to be used for tokenizer training, supports slicing like 'train[:10%]' 50 | --vocab-size VOCAB_SIZE, -v VOCAB_SIZE 51 | Vocabulary size of the tokenizer 52 | --out-dir OUT_DIR Local directory to save the resulting tokenizer 53 | --out-repo OUT_REPO HF repo id to upload the resulting tokenizer 54 | ``` 55 | 56 | Here's how we trained the tokenizer for our `stories-*` suite of models. Please note that you can use single-letter abbreviations for most arguments. 57 | 58 | ``` 59 | > scripts/train_tokenizer.py \ 60 | --in-dataset delphi-suite/stories \ 61 | --feature story \ 62 | --split train \ 63 | --vocab-size 4096 \ 64 | --out-repo delphi-suite/stories-tokenizer 65 | ``` 66 | 67 | We use the dataset's only feature, named `story`, in the `train` split of [delphi-suite/stories](https://huggingface.co/datasets/delphi-suite/stories). We train a tokenizer with a vocabulary of 4096 tokens and upload it to the HF model repo [delphi-suite/stories-tokenizer](https://huggingface.co/delphi-suite/stories-tokenizer). 68 | 69 | 70 | # Tokenizing a dataset 71 | 72 | To turn a collection of text documents into the sequences of tokens required for model training, you can use `scripts/tokenize_dataset.py`. All documents are tokenized and concatenated, with the `<eos>` token as a separator, e.g. 73 | ``` 74 | doc1_tok1, doc1_tok2, ..., doc1_tokX, <eos>, doc2_tok1, doc2_tok2, ..., doc2_tokX, <eos>, doc3_tok1, ... 75 | ``` 76 | Then this is divided into chunks, and the `<bos>` token is inserted at the beginning of each chunk, e.g. 77 | ``` 78 | <bos>, doc1_tok1, doc1_tok2, ..., doc1_tokX, <eos>, doc2_tok1 79 | <bos>, doc2_tok2, ..., doc2_tok511 80 | <bos>, doc2_tok512, doc2_tok513, ..., doc2_tokX, <eos>, doc3_tok1, ... 81 | ... 82 | ``` 83 | It produces sequences of the specified length, discarding the last chunk if it's too short. We don't use padding.
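To make the packing scheme concrete, here is a minimal, self-contained sketch of the concatenate-and-chunk logic described above. It is illustrative only: the `pack_documents` helper is not part of delphi (the real, streaming implementation lives in `delphi/tokenization.py`), and the literal token IDs `bos=0`, `eos=1` are simply the defaults used in `configs/stories/*/base.json`.

```python
# Illustrative sketch only -- not the delphi implementation (see delphi/tokenization.py).
# The function name and the bos/eos token IDs below are assumptions for this example.
def pack_documents(
    docs: list[list[int]], seq_len: int, bos: int = 0, eos: int = 1
) -> list[list[int]]:
    # 1) concatenate all tokenized documents, separated by <eos>
    stream: list[int] = []
    for doc in docs:
        stream.extend(doc)
        stream.append(eos)
    # 2) cut the stream into chunks of seq_len - 1 tokens, prepend <bos> to each,
    #    and drop the final chunk if it is too short (no padding)
    samples: list[list[int]] = []
    for start in range(0, len(stream), seq_len - 1):
        chunk = stream[start : start + seq_len - 1]
        if len(chunk) == seq_len - 1:
            samples.append([bos] + chunk)
    return samples


print(pack_documents([[10, 11, 12], [20, 21, 22, 23]], seq_len=4))
# [[0, 10, 11, 12], [0, 1, 20, 21], [0, 22, 23, 1]]
```

In the actual script, the resulting sequences are written to a `tokens` column (the default feature name expected by the training `DatasetConfig`) in parquet files of at most `--chunk-size` sequences each.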
84 | 85 | Script usage: 86 | 87 | ``` 88 | > scripts/tokenize_dataset.py --help 89 | usage: tokenize_dataset.py [-h] --in-dataset IN_DATASET --feature FEATURE --split SPLIT 90 | --tokenizer TOKENIZER --seq-len SEQ_LEN 91 | [--batch-size BATCH_SIZE] [--chunk-size CHUNK_SIZE] 92 | [--out-dir OUT_DIR] [--out-repo OUT_REPO] 93 | 94 | Tokenize a text dataset using a specific tokenizer 95 | 96 | options: 97 | -h, --help show this help message and exit 98 | --in-dataset IN_DATASET, -i IN_DATASET 99 | Dataset you want to tokenize. Local path or HF repo id 100 | --feature FEATURE, -f FEATURE 101 | Name of the feature (column) containing text documents in the input dataset 102 | --split SPLIT, -s SPLIT 103 | Split of the dataset to be tokenized, supports slicing like 'train[:10%]' 104 | --tokenizer TOKENIZER, -t TOKENIZER 105 | HF repo id or local directory containing the tokenizer 106 | --seq-len SEQ_LEN, -l SEQ_LEN 107 | Length of the tokenized sequences 108 | --batch-size BATCH_SIZE, -b BATCH_SIZE 109 | How many text documents to tokenize at once (default: 50) 110 | --chunk-size CHUNK_SIZE, -c CHUNK_SIZE 111 | Maximum number of tokenized sequences in a single parquet file (default: 200_000) 112 | --out-dir OUT_DIR Local directory to save the resulting dataset 113 | --out-repo OUT_REPO HF repo id to upload the resulting dataset 114 | ``` 115 | 116 | Here's how we tokenized the dataset for our `stories-*` suite of models. Please note that you can use single letter abbreviations for most arguments. 117 | 118 | For `train` split: 119 | ``` 120 | > scripts/tokenize_dataset.py \ 121 | --in-dataset delphi-suite/stories \ 122 | --feature story \ 123 | --split train \ 124 | --tokenizer delphi-suite/stories-tokenizer \ 125 | --seq-len 512 \ 126 | --out-repo delphi-suite/stories-tokenized 127 | ``` 128 | For `validation` split, repeated arguments omitted: 129 | ``` 130 | > scripts/tokenize_dataset.py \ 131 | ... 132 | --split validation \ 133 | ... 134 | ``` 135 | 136 | The input dataset is the same as in tokenizer training example above. We tokenize it with our custom [delphi-suite/stories-tokenizer](https://huggingface.co/delphi-suite/stories-tokenizer) into sequences of length 512. We upload it to HF dataset repo [delphi-suite/stories-tokenized](https://huggingface.co/datasets/delphi-suite/stories-tokenized). 137 | 138 | Please note that you can use any HuggingFace tokenizer, you don't need to train a custom one. 139 | 140 | # Training a model 141 | 142 | To train a model, you'll need to create a config file. For examples see `configs/`, and for field descriptions see `delphi/train/config/training_config.py`. The training script is located in `scripts/train_model.py`. 143 | 144 | Script usage: 145 | 146 | ``` 147 | > scripts/train_model.py --help 148 | usage: train_model.py [-h] [--overrides [OVERRIDES ...]] [-v | -s] [config_files ...] 149 | 150 | Train a delphi model 151 | 152 | positional arguments: 153 | config_files Path to json file(s) containing config values, e.g. 'primary_config.json secondary_config.json'. 154 | 155 | options: 156 | -h, --help show this help message and exit 157 | --overrides [OVERRIDES ...] 158 | Override config values with space-separated declarations. e.g. `--overrides model_config.hidden_size=42 run_name=foo` 159 | -v, --verbose Increase verbosity level, repeatable (e.g. -vvv). Mutually exclusive with --silent, --loglevel 160 | -s, --silent Silence all logging. 
Mutually exclusive with --verbose, --loglevel 161 | ``` 162 | 163 | You can specify primary config and secondary config, which is useful if you're training a suite of models that only differ in a few parameters. Additionally, you can override specific fields using the `--overrides` flag. If you don't want to push the model and its checkpoints to HF, you need to explicitly set `out_repo=""`. If you don't want to log to W&B, you need to set `wandb=""`. Please note that by default we save the optimizer state (2x model size) with every checkpoint. 164 | 165 | Here is how we trained our `stories-mamba-100k` model 166 | ``` 167 | > scripts/train_model.py \ 168 | configs/stories/mamba/base.json \ 169 | configs/stories/mamba/100k.json \ 170 | --overrides \ 171 | out_repo="delphi-suite/stories-mamba-100k" \ 172 | wandb="delphi-suite/delphi" 173 | ``` 174 | 175 | # Development 176 | 177 | 1. Install the `dev` and `notebooks` dependencies `pip install -e ."[dev,notebooks]"`. 178 | 2. Run the tests `pytest`. 179 | 3. Install pre-commit `pre-commit install`. 180 | 4. Install the recommended vscode extensions. 181 | 182 | When you save a file vscode should automatically format it. Otherwise, pre-commit will do that, but you will need to add the changes and commit again. 183 | 184 | # Citation 185 | 186 | If you use delphi in your research, please cite using the following 187 | 188 | ```bibtex 189 | @misc{delphi, 190 | title = {delphi: small language models training made easy}, 191 | author = {Janiak, J. and Dhyani, J. and Brinkmann, J. and Paulo, G. and Wendland, J. and Alonso, V. A. and Li, S. and Duong, P. A. and Rigg, A.}, 192 | year = 2024, 193 | url = {https://github.com/delphi-suite/delphi}, 194 | } 195 | ``` 196 | -------------------------------------------------------------------------------- /configs/stories/llama2/100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 12, 4 | "intermediate_size": 32, 5 | "num_attention_heads": 2, 6 | "num_hidden_layers": 1, 7 | "num_key_value_heads": 1 8 | } 9 | } -------------------------------------------------------------------------------- /configs/stories/llama2/10m.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 340, 4 | "intermediate_size": 907, 5 | "num_attention_heads": 10, 6 | "num_hidden_layers": 6, 7 | "num_key_value_heads": 5 8 | } 9 | } -------------------------------------------------------------------------------- /configs/stories/llama2/1m.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 84, 4 | "intermediate_size": 244, 5 | "num_attention_heads": 6, 6 | "num_hidden_layers": 4, 7 | "num_key_value_heads": 3 8 | } 9 | } -------------------------------------------------------------------------------- /configs/stories/llama2/2.5m.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 160, 4 | "intermediate_size": 426, 5 | "num_attention_heads": 8, 6 | "num_hidden_layers": 4, 7 | "num_key_value_heads": 4 8 | } 9 | } -------------------------------------------------------------------------------- /configs/stories/llama2/250k.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 28, 4 | "intermediate_size": 75, 5 | "num_attention_heads": 2, 6 | 
"num_hidden_layers": 2, 7 | "num_key_value_heads": 1 8 | } 9 | } -------------------------------------------------------------------------------- /configs/stories/llama2/25m.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 480, 4 | "intermediate_size": 1280, 5 | "num_attention_heads": 16, 6 | "num_hidden_layers": 8, 7 | "num_key_value_heads": 8 8 | } 9 | } -------------------------------------------------------------------------------- /configs/stories/llama2/500k.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 56, 4 | "intermediate_size": 149, 5 | "num_attention_heads": 4, 6 | "num_hidden_layers": 2, 7 | "num_key_value_heads": 2 8 | } 9 | } -------------------------------------------------------------------------------- /configs/stories/llama2/50k.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 6, 4 | "intermediate_size": 16, 5 | "num_attention_heads": 3, 6 | "num_hidden_layers": 1, 7 | "num_key_value_heads": 1 8 | } 9 | } -------------------------------------------------------------------------------- /configs/stories/llama2/50m.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 704, 4 | "intermediate_size": 1877, 5 | "num_attention_heads": 16, 6 | "num_hidden_layers": 8, 7 | "num_key_value_heads": 8 8 | } 9 | } -------------------------------------------------------------------------------- /configs/stories/llama2/5m.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 240, 4 | "intermediate_size": 640, 5 | "num_attention_heads": 10, 6 | "num_hidden_layers": 5, 7 | "num_key_value_heads": 5 8 | } 9 | } -------------------------------------------------------------------------------- /configs/stories/llama2/README.md: -------------------------------------------------------------------------------- 1 | - use_cache - using default 2 | - pretraining_tp - experimental parallelization we're not using, which is the default 3 | - tie_word_embeddings - llama2 used False and this is better for interpretability, note that llama2.c is using True by default, which is probably more efficient use of parameters for very small models 4 | - rope settings are widely used defaults 5 | - attention_bias - no biases on QKV and output projection is the default and that's what we're using 6 | - attention_dropout - this is the only dropout llama2 can use, it's set to prob=0 by default and that's what we're using -------------------------------------------------------------------------------- /configs/stories/llama2/base.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "model_class": "LlamaForCausalLM", 4 | "vocab_size": 4096, 5 | "hidden_act": "silu", 6 | "max_position_embeddings": 512, 7 | "initializer_range": 0.02, 8 | "rms_norm_eps": 1e-06, 9 | "bos_token_id": 0, 10 | "eos_token_id": 1, 11 | "pad_token_id": 2, 12 | "tie_word_embeddings": false, 13 | "rope_theta": 10000.0, 14 | "rope_scaling": null, 15 | "attention_bias": false, 16 | "attention_dropout": 0.0 17 | }, 18 | "max_seq_len": 512, 19 | "device": "auto", 20 | "checkpoint_interval": 400, 21 | "extra_checkpoint_iters": [ 22 | 1, 23 | 2, 24 | 4, 25 | 8, 26 | 16, 27 | 32, 28 
| 64, 29 | 128, 30 | 256, 31 | 512 32 | ], 33 | "log_interval": 40, 34 | "eval_iters": 10, 35 | "batch_size": 256, 36 | "max_epochs": 10, 37 | "grad_clip": 1.0, 38 | "gradient_accumulation_steps": 1, 39 | "adam": { 40 | "learning_rate": 0.0005, 41 | "weight_decay": 0.1, 42 | "beta1": 0.9, 43 | "beta2": 0.95, 44 | "decay_lr": true, 45 | "warmup_iters": 1000, 46 | "min_lr": 0.0 47 | }, 48 | "batch_ordering_seed": 1337, 49 | "torch_seed": 42, 50 | "dataset": { 51 | "path": "delphi-suite/stories-tokenized" 52 | }, 53 | "tokenizer": "delphi-suite/stories-tokenizer" 54 | } -------------------------------------------------------------------------------- /configs/stories/mamba/100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 24, 4 | "num_hidden_layers": 2 5 | } 6 | } -------------------------------------------------------------------------------- /configs/stories/mamba/10m.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 400, 4 | "num_hidden_layers": 8 5 | } 6 | } -------------------------------------------------------------------------------- /configs/stories/mamba/1m.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 112, 4 | "num_hidden_layers": 6 5 | } 6 | } -------------------------------------------------------------------------------- /configs/stories/mamba/2.5m.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 204, 4 | "num_hidden_layers": 6 5 | } 6 | } -------------------------------------------------------------------------------- /configs/stories/mamba/250k.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 36, 4 | "num_hidden_layers": 4 5 | } 6 | } -------------------------------------------------------------------------------- /configs/stories/mamba/25m.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 664, 4 | "num_hidden_layers": 8 5 | } 6 | } -------------------------------------------------------------------------------- /configs/stories/mamba/500k.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 76, 4 | "num_hidden_layers": 4 5 | } 6 | } -------------------------------------------------------------------------------- /configs/stories/mamba/50k.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 12, 4 | "num_hidden_layers": 2 5 | } 6 | } -------------------------------------------------------------------------------- /configs/stories/mamba/50m.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 952, 4 | "num_hidden_layers": 8 5 | } 6 | } -------------------------------------------------------------------------------- /configs/stories/mamba/5m.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "hidden_size": 308, 4 | "num_hidden_layers": 6 5 | } 6 | } -------------------------------------------------------------------------------- /configs/stories/mamba/README.md: 
-------------------------------------------------------------------------------- 1 | - layer_norm_eps - different than rms norm eps in llama 2 | - initializer_range - different in mamba & llama 3 | - residual_in_fp32 - mamba specific parameter 4 | - time_step_* - mamba specific, sane defaults 5 | - there is no way to untie embeddings and unembeddings in mamba, they're tied by default https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/mamba/modeling_mamba.py#L602-L610 6 | - rescale_prenorm_residual was True in original paper, so we set it to True, despite HF default being false 7 | - using default for use_cache 8 | - state_size is default -------------------------------------------------------------------------------- /configs/stories/mamba/base.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "model_class": "MambaForCausalLM", 4 | "vocab_size": 4096, 5 | "state_size": 16, 6 | "layer_norm_epsilon": 1e-5, 7 | "bos_token_id": 0, 8 | "eos_token_id": 1, 9 | "pad_token_id": 2, 10 | "expand": 2, 11 | "conv_kernel": 4, 12 | "use_bias": false, 13 | "use_conv_bias": true, 14 | "hidden_act": "silu", 15 | "initializer_range": 0.1, 16 | "residual_in_fp32": true, 17 | "rescale_prenorm_residual": true 18 | }, 19 | "max_seq_len": 512, 20 | "device": "auto", 21 | "checkpoint_interval": 400, 22 | "extra_checkpoint_iters": [ 23 | 1, 24 | 2, 25 | 4, 26 | 8, 27 | 16, 28 | 32, 29 | 64, 30 | 128, 31 | 256, 32 | 512 33 | ], 34 | "log_interval": 40, 35 | "eval_iters": 10, 36 | "batch_size": 256, 37 | "max_epochs": 10, 38 | "grad_clip": 1.0, 39 | "gradient_accumulation_steps": 1, 40 | "adam": { 41 | "learning_rate": 0.0005, 42 | "weight_decay": 0.1, 43 | "beta1": 0.9, 44 | "beta2": 0.95, 45 | "decay_lr": true, 46 | "warmup_iters": 1000, 47 | "min_lr": 0.0 48 | }, 49 | "batch_ordering_seed": 1337, 50 | "torch_seed": 42, 51 | "dataset": { 52 | "path": "delphi-suite/stories-tokenized" 53 | }, 54 | "tokenizer": "delphi-suite/stories-tokenizer" 55 | } -------------------------------------------------------------------------------- /delphi/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib.resources import files 2 | from pathlib import Path 3 | from typing import cast 4 | 5 | from beartype.claw import beartype_this_package # <-- hype comes 6 | 7 | beartype_this_package() # <-- hype goes 8 | 9 | __version__ = "0.2" 10 | TEST_CONFIGS_DIR = cast(Path, files("delphi.test_configs")) 11 | -------------------------------------------------------------------------------- /delphi/eval.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import uuid 4 | from typing import Any, Optional, cast 5 | 6 | import numpy as np 7 | import panel as pn 8 | import plotly.graph_objects as go 9 | import torch 10 | from datasets import Dataset 11 | from IPython.core.display import HTML 12 | from IPython.core.display_functions import display 13 | from jaxtyping import Float, Int 14 | from transformers import PreTrainedTokenizerBase 15 | 16 | 17 | def single_loss_diff_to_color(loss_diff: float) -> str: 18 | # if loss_diff is negative, we want the color to be red 19 | # if loss_diff is positive, we want the color to be green 20 | # if loss_diff is 0, we want the color to be white 21 | # the color should be more intense the larger the absolute value of loss_diff 22 | 23 | def sigmoid(x: float) -> float: 24 | return 1 / (1 + 
math.exp(-x)) 25 | 26 | scaled_loss_diff = sigmoid(loss_diff) # scale to 0-1 27 | 28 | if scaled_loss_diff < 0.5: # red 29 | red_val = 255 30 | green_blue_val = min(int(255 * 2 * scaled_loss_diff), 255) 31 | return f"rgb({red_val}, {green_blue_val}, {green_blue_val})" 32 | else: # green 33 | green_val = 255 34 | red_blue_val = min(int(255 * 2 * (1 - scaled_loss_diff)), 255) 35 | return f"rgb({red_blue_val}, {green_val}, {red_blue_val})" 36 | 37 | 38 | def token_to_html( 39 | token: int, 40 | tokenizer: PreTrainedTokenizerBase, 41 | bg_color: str, 42 | data: dict, 43 | class_name: str = "token", 44 | ) -> str: 45 | data = data or {} # equivalent to if not data: data = {} 46 | # non-breakable space, w/o it leading spaces wouldn't be displayed 47 | str_token = tokenizer.decode(token).replace(" ", "&nbsp;") 48 | 49 | # background or user-select (for \n) goes here 50 | specific_styles = {} 51 | # for now just adds line break or doesn't 52 | br = "" 53 | 54 | if bg_color: 55 | specific_styles["background-color"] = bg_color 56 | if str_token == "\n": 57 | # replace new line character with two characters: \ and n 58 | str_token = r"\n" 59 | # add line break in html 60 | br += "<br>" 61 | # this is so we can copy the prompt without "\n"s 62 | specific_styles["user-select"] = "none" 63 | str_token = str_token.replace("<", "&lt;").replace(">", "&gt;") 64 | 65 | style_str = data_str = "" 66 | # converting style dict into the style attribute 67 | if specific_styles: 68 | inside_style_str = "; ".join(f"{k}: {v}" for k, v in specific_styles.items()) 69 | style_str = f" style='{inside_style_str}'" 70 | if data: 71 | data_str = "".join( 72 | f" data-{k}='{v.replace(' ', '&nbsp;')}'" for k, v in data.items() 73 | ) 74 | return f"
<div class='{class_name}'{style_str}{data_str}>{str_token}</div>
{br}" 75 | 76 | 77 | _token_style = { 78 | "border": "1px solid #888", 79 | "display": "inline-block", 80 | # each character of the same width, so we can easily spot a space 81 | "font-family": "monospace", 82 | "font-size": "14px", 83 | "color": "black", 84 | "background-color": "white", 85 | "margin": "1px 0px 1px 1px", 86 | "padding": "0px 1px 1px 1px", 87 | } 88 | _token_emphasized_style = { 89 | "border": "3px solid #888", 90 | "display": "inline-block", 91 | "font-family": "monospace", 92 | "font-size": "14px", 93 | "color": "black", 94 | "background-color": "white", 95 | "margin": "1px 0px 1px 1px", 96 | "padding": "0px 1px 1px 1px", 97 | } 98 | _token_style_str = " ".join([f"{k}: {v};" for k, v in _token_style.items()]) 99 | _token_emphasized_style_str = " ".join( 100 | [f"{k}: {v};" for k, v in _token_emphasized_style.items()] 101 | ) 102 | 103 | 104 | def vis_pos_map( 105 | pos_list: list[tuple[int, int]], 106 | selected_tokens: list[int], 107 | metrics: Float[torch.Tensor, "prompt pos"], 108 | token_ids: Int[torch.Tensor, "prompt pos"], 109 | tokenizer: PreTrainedTokenizerBase, 110 | ): 111 | """ 112 | Randomly sample from pos_map and visualize the loss diff at the corresponding position. 113 | """ 114 | 115 | token_htmls = [] 116 | unique_id = str(uuid.uuid4()) 117 | token_class = f"pretoken_{unique_id}" 118 | selected_token_class = f"token_{unique_id}" 119 | hover_div_id = f"hover_info_{unique_id}" 120 | 121 | # choose a random keys from pos_map 122 | key = random.choice(pos_list) 123 | 124 | prompt, pos = key 125 | all_toks = token_ids[prompt][: pos + 1] 126 | 127 | for i in range(all_toks.shape[0]): 128 | token_id = cast(int, all_toks[i].item()) 129 | value = metrics[prompt][i].item() 130 | token_htmls.append( 131 | token_to_html( 132 | token_id, 133 | tokenizer, 134 | bg_color="white" 135 | if np.isnan(value) 136 | else single_loss_diff_to_color(value), 137 | data={"loss-diff": f"{value:.2f}"}, 138 | class_name=token_class 139 | if token_id not in selected_tokens 140 | else selected_token_class, 141 | ) 142 | ) 143 | 144 | # add break line 145 | token_htmls.append("
<br><br>
") 146 | 147 | html_str = f""" 148 | 149 | {"".join(token_htmls)}
150 | 172 | """ 173 | display(HTML(html_str)) 174 | 175 | 176 | def token_selector( 177 | vocab_map: dict[str, int] 178 | ) -> tuple[pn.widgets.MultiChoice, list[int]]: 179 | tokens = list(vocab_map.keys()) 180 | token_selector_ = pn.widgets.MultiChoice(name="Tokens", options=tokens) 181 | token_ids = [vocab_map[token] for token in cast(list[str], token_selector_.value)] 182 | 183 | def update_tokens(event): 184 | token_ids.clear() 185 | token_ids.extend([vocab_map[token] for token in event.new]) 186 | 187 | token_selector_.param.watch(update_tokens, "value") 188 | return token_selector_, token_ids 189 | 190 | 191 | def calc_model_group_stats( 192 | tokenized_corpus_dataset: Dataset, 193 | logprobs_by_dataset: dict[str, torch.Tensor], 194 | selected_tokens: list[int], 195 | ) -> dict[str, dict[str, float]]: 196 | """ 197 | For each (model, token group) pair, calculate useful stats (for visualization) 198 | 199 | args: 200 | - tokenized_corpus_dataset: a list of the tokenized corpus datasets, e.g. load_dataset(constants.tokenized_corpus_dataset))["validation"] 201 | - logprob_datasets: a dict of lists of logprobs, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]} 202 | - selected_tokens: a list of selected token IDs, e.g. [46, 402, ...] 203 | 204 | returns: a dict of model names as keys and stats dict as values 205 | e.g. {"100k": {"mean": -0.5, "median": -0.4, "min": -0.1, "max": -0.9, "25th": -0.3, "75th": -0.7}, ...} 206 | 207 | Stats calculated: mean, median, min, max, 25th percentile, 75th percentile 208 | """ 209 | model_group_stats = {} 210 | for model in logprobs_by_dataset: 211 | model_logprobs = [] 212 | print(f"Processing model {model}") 213 | dataset = logprobs_by_dataset[model] 214 | for ix_doc_lp, document_lps in enumerate(dataset): 215 | tokens = tokenized_corpus_dataset[ix_doc_lp]["tokens"] 216 | for ix_token, token in enumerate(tokens): 217 | if ix_token == 0: # skip the first token, which isn't predicted 218 | continue 219 | logprob = document_lps[ix_token].item() 220 | if token in selected_tokens: 221 | model_logprobs.append(logprob) 222 | 223 | if model_logprobs: 224 | model_group_stats[model] = { 225 | "mean": np.mean(model_logprobs), 226 | "median": np.median(model_logprobs), 227 | "min": np.min(model_logprobs), 228 | "max": np.max(model_logprobs), 229 | "25th": np.percentile(model_logprobs, 25), 230 | "75th": np.percentile(model_logprobs, 75), 231 | } 232 | return model_group_stats 233 | 234 | 235 | def dict_filter_quantile( 236 | d: dict[Any, float], q_start: float, q_end: float 237 | ) -> dict[Any, float]: 238 | if not (0 <= q_start < q_end <= 1): 239 | raise ValueError("Invalid quantile range") 240 | q_start_val = np.nanquantile(list(d.values()), q_start) 241 | q_end_val = np.nanquantile(list(d.values()), q_end) 242 | return { 243 | k: v for k, v in d.items() if q_start_val <= v <= q_end_val and not np.isnan(v) 244 | } 245 | 246 | 247 | def get_all_tok_metrics_in_label( 248 | token_ids: Int[torch.Tensor, "prompt pos"], 249 | selected_tokens: list[int], 250 | metrics: torch.Tensor, 251 | q_start: Optional[float] = None, 252 | q_end: Optional[float] = None, 253 | ) -> dict[tuple[int, int], float]: 254 | """ 255 | From the token_map, get all the positions of the tokens that have a certain label. 256 | We don't use the token_map because for sampling purposes, iterating through token_ids is more efficient. 257 | Optionally, filter the tokens based on the quantile range of the metrics. 
258 | 259 | Args: 260 | - token_ids (Dataset): token_ids dataset e.g. token_ids[0] = {"tokens": [[1, 2, ...], [2, 5, ...], ...]} 261 | - selected_tokens (list[int]): list of token IDs to search for e.g. [46, 402, ...] 262 | - metrics (torch.Tensor): tensor of metrics to search through e.g. torch.tensor([[0.1, 0.2, ...], [0.3, 0.4, ...], ...]) 263 | - q_start (float): the start of the quantile range to filter the metrics e.g. 0.1 264 | - q_end (float): the end of the quantile range to filter the metrics e.g. 0.9 265 | 266 | Returns: 267 | - tok_positions (dict[tuple[int, int], Number]): dictionary of token positions and their corresponding metrics 268 | """ 269 | 270 | # check if metrics have the same dimensions as token_ids 271 | if metrics.shape != token_ids.shape: 272 | raise ValueError( 273 | f"Expected metrics to have the same shape as token_ids, but got {metrics.shape} and {token_ids.shape} instead." 274 | ) 275 | 276 | tok_positions = {} 277 | for prompt_pos, prompt in enumerate(token_ids.numpy()): 278 | for tok_pos, tok in enumerate(prompt): 279 | if tok in selected_tokens: 280 | tok_positions[(prompt_pos, tok_pos)] = metrics[ 281 | prompt_pos, tok_pos 282 | ].item() 283 | 284 | if q_start is not None and q_end is not None: 285 | tok_positions = dict_filter_quantile(tok_positions, q_start, q_end) 286 | 287 | return tok_positions 288 | 289 | 290 | def visualize_selected_tokens( 291 | input: dict[str | int, tuple[float, float, float]], 292 | log_scale=False, 293 | line_metric="Means", 294 | checkpoint_mode=True, 295 | shade_color="rgba(68, 68, 68, 0.3)", 296 | line_color="rgb(31, 119, 180)", 297 | bar_color="purple", 298 | marker_color="SkyBlue", 299 | background_color="AliceBlue", 300 | ) -> go.FigureWidget: 301 | input_x = list(input.keys()) 302 | 303 | def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]: 304 | return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)] 305 | 306 | def get_plot_values() -> tuple[np.ndarray, np.ndarray, np.ndarray]: 307 | x = np.array([input[x] for x in input_x]).T 308 | means, err_lo, err_hi = x[0], x[1], x[2] 309 | return means, err_lo, err_hi 310 | 311 | means, err_lo, err_hi = get_plot_values() 312 | 313 | if checkpoint_mode: 314 | scatter_plot = go.Figure( 315 | [ 316 | go.Scatter( 317 | name="Upper Bound", 318 | x=input_x, 319 | y=means + err_hi, 320 | mode="lines", 321 | marker=dict(color=shade_color), 322 | line=dict(width=0), 323 | showlegend=False, 324 | ), 325 | go.Scatter( 326 | name="Lower Bound", 327 | x=input_x, 328 | y=means - err_lo, 329 | marker=dict(color=shade_color), 330 | line=dict(width=0), 331 | mode="lines", 332 | fillcolor=shade_color, 333 | fill="tonexty", 334 | showlegend=False, 335 | ), 336 | go.Scatter( 337 | name=line_metric, 338 | x=input_x, 339 | y=means, 340 | mode="lines", 341 | marker=dict( 342 | color=line_color, 343 | size=0, 344 | line=dict(color=line_color, width=1), 345 | ), 346 | ), 347 | ] 348 | ) 349 | else: 350 | scatter_plot = go.Scatter( 351 | x=input_x, 352 | y=means, 353 | error_y=dict( 354 | type="data", 355 | symmetric=False, 356 | array=err_hi, 357 | arrayminus=err_lo, 358 | color=bar_color, 359 | ), 360 | marker=dict( 361 | color=marker_color, 362 | size=15, 363 | line=dict(color=line_color, width=2), 364 | ), 365 | hovertext=get_hovertexts(means, err_lo, err_hi), 366 | hoverinfo="text+x", 367 | ) 368 | g = go.FigureWidget( 369 | data=scatter_plot, 370 | layout=go.Layout( 371 | yaxis=dict( 372 | title="Loss", 373 | type="log" if log_scale else "linear", 
374 | ), 375 | plot_bgcolor=background_color, 376 | ), 377 | ) 378 | 379 | return g 380 | -------------------------------------------------------------------------------- /delphi/test_configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delphi-suite/delphi/1efcabec4bceffbf4e9a383d05958f8c0704d2ce/delphi/test_configs/__init__.py -------------------------------------------------------------------------------- /delphi/test_configs/debug.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_seq_len": 512, 3 | "max_epochs": 2, 4 | "eval_iters": 1, 5 | "batch_ordering_seed": 42, 6 | "torch_seed": 1337, 7 | "batch_size": 64, 8 | "model_config": { 9 | "model_class": "LlamaForCausalLM", 10 | "hidden_size": 48, 11 | "intermediate_size": 48, 12 | "num_attention_heads": 2, 13 | "num_hidden_layers": 2, 14 | "num_key_value_heads": 2, 15 | "vocab_size": 4096 16 | }, 17 | "dataset": { 18 | "path": "delphi-suite/stories-tokenized" 19 | }, 20 | "out_repo": "", 21 | "wandb": "" 22 | } -------------------------------------------------------------------------------- /delphi/tokenization.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from collections import deque 3 | from collections.abc import Iterator 4 | 5 | from datasets import Dataset 6 | from transformers import PreTrainedTokenizerBase 7 | 8 | 9 | def extend_deque( 10 | deq: deque[int], 11 | context_size: int, 12 | dataset: Dataset, 13 | doc_idx: int, 14 | tokenizer: PreTrainedTokenizerBase, 15 | batch_size: int, 16 | ) -> int: 17 | """ 18 | Extends the deque with tokenized text documents until the deque grows large 19 | enough to reach the context size, or until all text documents are processed. 20 | 21 | A deque is used here to save memory, as opposed to loading and tokenizing 22 | all the documents at once. 23 | 24 | Args: 25 | deq: Deque to extend with tokenized tokens. 26 | context_size: Size of the context (input sequences). 27 | dataset: Dataset of (untokenized) text documents to be tokenized. 28 | doc_idx: Index of the current document in the dataset. 29 | tokenizer: Tokenizer to encode the text strings. 30 | batch_size: The size of input into batched tokenization. 31 | Returns: 32 | int: Updated index in the text documents dataset. 33 | """ 34 | feature = dataset.column_names[0] 35 | while len(deq) < context_size and doc_idx < len(dataset): 36 | documents = dataset[doc_idx : doc_idx + batch_size][feature] 37 | batch_input_ids = tokenizer( 38 | documents, return_attention_mask=False, add_special_tokens=False 39 | )["input_ids"] 40 | for input_ids in batch_input_ids: # type: ignore 41 | deq.extend(input_ids + [tokenizer.eos_token_id]) 42 | doc_idx += batch_size 43 | return doc_idx 44 | 45 | 46 | def make_new_sample(deq: deque[int], seq_len: int, bos_token_id: int) -> list[int]: 47 | """ 48 | Generates a new training sample by moving tokens from the deque 49 | into a sequence until it reaches the required length. 50 | 51 | Note: the model is unable to use the last token in an input sequence, 52 | so we repeat this token in the next input sequence. 53 | 54 | Args: 55 | deq: Deque containing tokenized tokens. 56 | seq_len: Length of the sample (input sequence). 57 | bos_token_id: bos_token_id of the tokenizer used. 58 | 59 | Returns: 60 | list[int]: token sequence.
61 | """ 62 | sample = [bos_token_id] 63 | # For the first n-2 elements, pop from the left of the deque 64 | # and add to the new sample, the (n-1)-th element will be retained 65 | # in the deque for making the next sample. 66 | for _ in range(seq_len - 2): 67 | sample.append(deq.popleft()) 68 | sample.append(deq[0]) 69 | return sample 70 | 71 | 72 | def tokenize_dataset( 73 | dataset: Dataset, 74 | tokenizer: PreTrainedTokenizerBase, 75 | seq_len: int, 76 | batch_size: int, 77 | ) -> Iterator[list[int]]: 78 | """ 79 | Tokenizes the input text documents using the provided tokenizer and 80 | generates token sequences of the specified length. 81 | 82 | Args: 83 | text_documents: List[str], 84 | tokenizer, 85 | context_size, 86 | batch_size: The size of input into batched tokenization. 87 | 88 | Returns: 89 | oken sequences of length equal to context_size. 90 | """ 91 | assert tokenizer.bos_token_id is not None 92 | deq = deque() 93 | doc_idx = 0 94 | # iterate through the text documents and tokenize them 95 | while doc_idx < len(dataset): 96 | doc_idx = extend_deque(deq, seq_len, dataset, doc_idx, tokenizer, batch_size) 97 | yield make_new_sample(deq, seq_len, tokenizer.bos_token_id) 98 | # We discard the last chunk, so no processing on the remainder of the deque here 99 | 100 | 101 | def get_tokenized_chunks( 102 | dataset_split: Dataset, 103 | tokenizer: PreTrainedTokenizerBase, 104 | seq_len: int, 105 | batch_size: int, 106 | chunk_size: int, 107 | ) -> Iterator[Dataset]: 108 | seq_it = tokenize_dataset( 109 | dataset_split, 110 | tokenizer, 111 | seq_len=seq_len, 112 | batch_size=batch_size, 113 | ) 114 | while tokens_chunk := tuple(itertools.islice(seq_it, chunk_size)): 115 | yield Dataset.from_dict({"tokens": tokens_chunk}) 116 | -------------------------------------------------------------------------------- /delphi/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delphi-suite/delphi/1efcabec4bceffbf4e9a383d05958f8c0704d2ce/delphi/train/__init__.py -------------------------------------------------------------------------------- /delphi/train/checkpoint_step.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from datasets import Dataset 5 | 6 | from .config import TrainingConfig 7 | from .run_context import RunContext 8 | from .utils import ModelTrainingState, count_tokens_so_far, estimate_loss, save_results 9 | from .wandb_utils import log_to_wandb 10 | 11 | 12 | def should_save_checkpoint(config: TrainingConfig, mts: ModelTrainingState): 13 | return ( 14 | mts.iter_num % config.checkpoint_interval == 0 15 | and mts.iter_num > 0 16 | or mts.iter_num in config.extra_checkpoint_iters 17 | ) 18 | 19 | 20 | def log_and_save_checkpoint( 21 | config: TrainingConfig, 22 | mts: ModelTrainingState, 23 | train_ds: Dataset, 24 | validation_ds: Dataset, 25 | run_context: RunContext, 26 | ): 27 | """ 28 | Save a checkpoint of the current model + training state, evaluate, and optionally upload to huggingface and log to wandb (if configured) 29 | """ 30 | model = mts.model 31 | if config.debug_config.no_eval: 32 | logging.debug("no_eval=True, skipping evaluation and using dummy losses") 33 | losses = {"train": 42.0, "val": 43.0} 34 | else: 35 | losses = estimate_loss( 36 | model=model, 37 | eval_iters=config.eval_iters, 38 | batch_size=config.batch_size, 39 | split_to_ds={"train": train_ds, "val": validation_ds}, 40 | 
device=run_context.device, 41 | feature_name=config.dataset.feature, 42 | ) 43 | logging.info( 44 | f"step {mts.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}" 45 | ) 46 | save_results(config=config, train_results=mts, run_context=run_context) 47 | if config.wandb: 48 | log_to_wandb( 49 | mts=mts, 50 | losses=losses, 51 | tokens_so_far=count_tokens_so_far(config, mts), 52 | ) 53 | -------------------------------------------------------------------------------- /delphi/train/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .adam_config import AdamConfig 2 | from .training_config import TrainingConfig 3 | from .utils import ( 4 | build_config_dict_from_files, 5 | build_config_from_files_and_overrides, 6 | dot_notation_to_dict, 7 | get_user_config_path, 8 | ) 9 | -------------------------------------------------------------------------------- /delphi/train/config/adam_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from beartype import beartype 4 | 5 | 6 | @beartype 7 | @dataclass 8 | class AdamConfig: 9 | # adamw optimizer 10 | learning_rate: float = 5e-4 # max learning rate 11 | weight_decay: float = 1e-1 12 | beta1: float = 0.9 13 | beta2: float = 0.95 14 | # learning rate decay settings 15 | decay_lr: bool = True # whether to decay the learning rate 16 | warmup_iters: int = 1000 # how many steps to warm up for 17 | min_lr: float = 0.0 # should be ~learning_rate/10 per Chinchill 18 | -------------------------------------------------------------------------------- /delphi/train/config/dataset_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | from beartype import beartype 4 | from datasets import Dataset 5 | 6 | from delphi import utils 7 | 8 | 9 | @beartype 10 | @dataclass(frozen=True) 11 | class DatasetConfig: 12 | # tokenized dataset; HF repo id or local directory 13 | path: str 14 | 15 | # feature in the dataset; should be a list of <= max_seq_len token ints 16 | feature: str = "tokens" 17 | 18 | # split of the dataset to use for training 19 | train_split: str = "train" 20 | 21 | # split of the dataset to use for validation 22 | validation_split: str = "validation" 23 | 24 | def _load(self, split) -> Dataset: 25 | ds = utils.load_dataset_split_sequence_int32_feature( 26 | self.path, split, self.feature 27 | ) 28 | ds.set_format("torch") 29 | return ds 30 | 31 | def load_train(self) -> Dataset: 32 | return self._load(self.train_split) 33 | 34 | def load_validation(self) -> Dataset: 35 | return self._load(self.validation_split) 36 | -------------------------------------------------------------------------------- /delphi/train/config/debug_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | from beartype import beartype 4 | 5 | 6 | @beartype 7 | @dataclass(frozen=True) 8 | class DebugConfig: 9 | no_training: bool = field( 10 | default=False, metadata={"help": "skip all actual training, do everything else"} 11 | ) 12 | no_eval: bool = field(default=False, metadata={"help": "skip actual evaluation"}) 13 | -------------------------------------------------------------------------------- /delphi/train/config/training_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from 
dataclasses import dataclass, field 3 | from datetime import datetime 4 | from typing import Any, Optional 5 | 6 | import platformdirs 7 | from beartype import beartype 8 | 9 | from .adam_config import AdamConfig 10 | from .dataset_config import DatasetConfig 11 | from .debug_config import DebugConfig 12 | 13 | 14 | @beartype 15 | @dataclass(frozen=True, kw_only=True) 16 | class TrainingConfig: 17 | # model config; model_class=name of model class in transformers, everything else is kwargs for the corresponding model config 18 | model_config: dict[str, Any] 19 | 20 | max_seq_len: int 21 | run_name: str = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 22 | out_dir: str = os.path.join(platformdirs.user_data_dir(appname="delphi"), run_name) 23 | 24 | # device to use (auto, cuda, mps, cpu) 25 | device: str = "auto" 26 | 27 | # checkpoint every N iters 28 | checkpoint_interval: int = 2000 29 | 30 | # manually list iterations to save checkpoints on 31 | extra_checkpoint_iters: list[int] = field(default_factory=list) 32 | 33 | # log to the console every N iters; this doesn't control wandb logging, which is done only on checkpoints 34 | log_interval: int = 1 35 | 36 | # FIXME: there is a bug in the current implementation, and eval loss is computed on the 37 | # entire dataset. In this implementation, eval_iters controls the number of minibatches 38 | # the dataset is split into for evaluation. 39 | eval_iters: int = 100 40 | 41 | # path to a checkpoint to resume from 42 | resume_from_path: Optional[str] = None 43 | 44 | # number of samples used to compute the gradient for a single optimizer step 45 | batch_size: int = 64 46 | 47 | # total number of training epochs 48 | max_epochs: int = 10 49 | 50 | # clip gradients at this value, or disable if == 0.0 51 | grad_clip: float = 1.0 52 | 53 | # if > 1, reduces memory usage by computing the gradient in microbatches 54 | gradient_accumulation_steps: int = 1 55 | 56 | # AdamW optimizer 57 | adam: AdamConfig = field(default_factory=AdamConfig) 58 | 59 | # seed used for pseudorandomly sampling data during training 60 | batch_ordering_seed: int 61 | 62 | # seed used for torch 63 | torch_seed: int 64 | 65 | # whether to save the optimizer state with each checkpoint 66 | # this is twice as large as the model, but allows resuming training in a reproducible way 67 | save_optimizer: bool = True 68 | 69 | # specify training and validation data 70 | dataset: DatasetConfig 71 | 72 | # HF repo id or local directory containing the tokenizer. Used only to upload it to HF with the model, not for training 73 | tokenizer: str = "" 74 | 75 | # wandb config in 'entity/project' form. Set to empty string to not use wandb. 76 | wandb: str 77 | 78 | # HF repo id. Set to empty string to not push to repo.
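# e.g. "your-user/your-model" (illustrative value, not an actual repo)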
79 | out_repo: str 80 | 81 | # debug config 82 | debug_config: DebugConfig = field(default_factory=DebugConfig) 83 | -------------------------------------------------------------------------------- /delphi/train/config/utils.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | import logging 4 | import os 5 | from dataclasses import fields, is_dataclass 6 | from datetime import datetime 7 | from pathlib import Path 8 | from typing import _GenericAlias # type: ignore 9 | from typing import Any, Type, TypeVar, Union 10 | 11 | import platformdirs 12 | from dacite import Config as dacite_config 13 | from dacite import from_dict 14 | 15 | from .training_config import TrainingConfig 16 | 17 | T = TypeVar("T") 18 | 19 | 20 | def merge_two_dicts(merge_into: dict[str, Any], merge_from: dict[str, Any]): 21 | """recursively merge two dicts, with values in merge_from taking precedence""" 22 | for key, val in merge_from.items(): 23 | if ( 24 | key in merge_into 25 | and isinstance(merge_into[key], dict) 26 | and isinstance(val, dict) 27 | ): 28 | merge_two_dicts(merge_into[key], val) 29 | else: 30 | merge_into[key] = val 31 | 32 | 33 | def merge_dicts(*dicts: dict[str, Any]) -> dict[str, Any]: 34 | """ 35 | Recursively merge multiple dictionaries, with later dictionaries taking precedence. 36 | """ 37 | merged = {} 38 | for d in dicts: 39 | merge_two_dicts(merged, d) 40 | return merged 41 | 42 | 43 | def get_user_config_path() -> Path: 44 | """ 45 | This enables a user-specific config to always be included in the training config. 46 | 47 | This is useful for things like wandb config, where you'll generally want to use your own account. 48 | """ 49 | _user_config_dir = Path(platformdirs.user_config_dir(appname="delphi")) 50 | _user_config_dir.mkdir(parents=True, exist_ok=True) 51 | user_config_path = _user_config_dir / "config.json" 52 | return user_config_path 53 | 54 | 55 | def build_config_dict_from_files(config_files: list[Path]) -> dict[str, Any]: 56 | """ 57 | Given a list of config json paths, merge them into a combined config dict (with later files taking precedence). 58 | """ 59 | config_dicts = [] 60 | for config_file in config_files: 61 | logging.debug(f"Loading {config_file}") 62 | with open(config_file, "r") as f: 63 | config_dicts.append(json.load(f)) 64 | combined_config = merge_dicts(*config_dicts) 65 | return combined_config 66 | 67 | 68 | def cast_types(config: dict[str, Any], target_dataclass: Type): 69 | """ 70 | user overrides are passed in as strings, so we need to cast them to the correct type 71 | """ 72 | dc_fields = {f.name: f for f in fields(target_dataclass)} 73 | for k, v in config.items(): 74 | if k in dc_fields: 75 | field = dc_fields[k] 76 | field_type = _unoptionalize(field.type) 77 | if is_dataclass(field_type): 78 | cast_types(v, field_type) 79 | elif isinstance(field_type, dict): 80 | # for dictionaries, make best effort to cast values to the correct type 81 | for _k, _v in v.items(): 82 | v[_k] = ast.literal_eval(_v) 83 | else: 84 | config[k] = field_type(v) 85 | 86 | 87 | def build_config_from_files_and_overrides( 88 | config_files: list[Path], 89 | overrides: dict[str, Any], 90 | ) -> TrainingConfig: 91 | """ 92 | This is the main entrypoint for building a TrainingConfig object from a list of config files and overrides. 93 | 94 | 1. Load config_files in order, merging them into one dict, with later taking precedence. 95 | 2. 
Cast the strings from overrides to the correct types 96 | (we expect this to be passed as strings w/o type hints from a script argument: 97 | e.g. `--overrides model_config.hidden_size=42 run_name=foo`) 98 | 3. Merge in overrides to config_dict, taking precedence over all config_files values. 99 | 4. Build the TrainingConfig object from the final config dict and return it. 100 | """ 101 | combined_config = build_config_dict_from_files(config_files) 102 | cast_types(overrides, TrainingConfig) 103 | merge_two_dicts(merge_into=combined_config, merge_from=overrides) 104 | return from_dict(TrainingConfig, combined_config, config=dacite_config(strict=True)) 105 | 106 | 107 | def dot_notation_to_dict(vars: dict[str, Any]) -> dict[str, Any]: 108 | """ 109 | Convert {"a.b.c": 4, "foo": false} to {"a": {"b": {"c": 4}}, "foo": False} 110 | """ 111 | nested_dict = dict() 112 | for k, v in vars.items(): 113 | if v is None: 114 | continue 115 | cur = nested_dict 116 | subkeys = k.split(".") 117 | for subkey in subkeys[:-1]: 118 | if subkey not in cur: 119 | cur[subkey] = {} 120 | cur = cur[subkey] 121 | cur[subkeys[-1]] = v 122 | return nested_dict 123 | 124 | 125 | def _unoptionalize(t: Type | _GenericAlias) -> Type: 126 | """unwrap `Optional[T]` to T. 127 | 128 | We need this to correctly interpret user-passed overrides, which are always strings 129 | without any type information attached. We need to look up what type they should be 130 | and cast accordingly. As part of this lookup we need to pierce Optional values - 131 | if the user is setting a value, it's clearly not Optional, and we need to get the underlying 132 | type to cast correctly. 133 | """ 134 | # Under the hood, `Optional` is really `Union[T, None]`. So we 135 | # just check if this is a Union over two types including None, and 136 | # return the other 137 | if hasattr(t, "__origin__") and t.__origin__ is Union: 138 | args = t.__args__ 139 | # Check if one of the Union arguments is type None 140 | if len(args) == 2 and type(None) in args: 141 | return args[0] if args[1] is type(None) else args[1] 142 | return t 143 | -------------------------------------------------------------------------------- /delphi/train/run_context.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import torch 5 | import transformers 6 | 7 | import delphi 8 | 9 | 10 | def get_auto_device_str() -> str: 11 | if torch.cuda.is_available(): 12 | return "cuda" 13 | if torch.backends.mps.is_available(): 14 | return "mps" 15 | return "cpu" 16 | 17 | 18 | def check_set_env_cublas_workspace_config(): 19 | expected_val = ":4096:8" 20 | actual_val = os.getenv("CUBLAS_WORKSPACE_CONFIG") 21 | if actual_val is None: 22 | logging.info( 23 | f"Environment variable CUBLAS_WORKSPACE_CONFIG not set. Setting to '{expected_val}' to ensure reproducibility." 24 | ) 25 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = expected_val 26 | else: 27 | correct_values = [expected_val, ":16:8"] 28 | assert actual_val in correct_values, ( 29 | f"Environment variable CUBLAS_WORKSPACE_CONFIG is set to {actual_val}, which is incompatibe with reproducible training. " 30 | f"Please set it to one of the following values: {correct_values}. " 31 | f"See https://docs.nvidia.com/cuda/archive/12.4.0/cublas/index.html#results-reproducibility for more information." 
32 | ) 33 | 34 | 35 | class RunContext: 36 | def __init__(self, device_str: str): 37 | if device_str == "auto": 38 | device_str = get_auto_device_str() 39 | self.device = torch.device(device_str) 40 | if self.device.type == "cuda": 41 | assert torch.cuda.is_available() 42 | check_set_env_cublas_workspace_config() 43 | self.gpu_name = torch.cuda.get_device_name(self.device) 44 | elif self.device.type == "mps": 45 | assert torch.backends.mps.is_available() 46 | self.torch_version = torch.__version__ 47 | self.delphi_version = delphi.__version__ 48 | self.transformers_version = transformers.__version__ 49 | self.os = os.uname().version 50 | 51 | def asdict(self) -> dict: 52 | asdict = self.__dict__.copy() 53 | asdict["device"] = str(self.device) 54 | return asdict 55 | -------------------------------------------------------------------------------- /delphi/train/shuffle.py: -------------------------------------------------------------------------------- 1 | class RNG: 2 | """Random Number Generator 3 | 4 | Linear Congruential Generator equivalent to minstd_rand in C++11 5 | https://en.cppreference.com/w/cpp/numeric/random 6 | """ 7 | 8 | a = 48271 9 | m = 2147483647 # 2^31 - 1 10 | 11 | def __init__(self, seed: int): 12 | assert 0 <= seed < self.m 13 | self.state = seed 14 | 15 | def __call__(self) -> int: 16 | self.state = (self.state * self.a) % self.m 17 | return self.state 18 | 19 | 20 | def shuffle_list(in_out: list, seed: int): 21 | """Deterministically shuffle a list in-place 22 | 23 | Implements Fisher-Yates shuffle with LCG as randomness source 24 | https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm 25 | """ 26 | rng = RNG(seed) 27 | n = len(in_out) 28 | for i in range(n - 1, 0, -1): 29 | j = rng() % (i + 1) 30 | in_out[i], in_out[j] = in_out[j], in_out[i] 31 | 32 | 33 | def shuffle_epoch(samples: list, seed: int, epoch_nr: int): 34 | """Shuffle the samples in-place for a given training epoch""" 35 | rng = RNG((10_000 + seed) % RNG.m) 36 | for _ in range(epoch_nr): 37 | rng() 38 | shuffle_seed = rng() 39 | shuffle_list(samples, shuffle_seed) 40 | -------------------------------------------------------------------------------- /delphi/train/train_step.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections.abc import Iterable 3 | 4 | import torch 5 | from datasets import Dataset 6 | from transformers import PreTrainedModel 7 | 8 | from .config import TrainingConfig 9 | from .utils import ModelTrainingState, gen_minibatches 10 | 11 | 12 | def train_step( 13 | model_training_state: ModelTrainingState, 14 | train_ds: Dataset, 15 | config: TrainingConfig, 16 | device: torch.device, 17 | ds_indices: list[int], 18 | ): 19 | """ 20 | Runs a training step, updating (mutating in place) model_training_state: 21 | - generate gradient_accumulation_steps batches (each batch is batch_size/gradient_accumulation_steps items) 22 | - forward pass, accumulating gradient/gradient_accumulation_steps over gradient_accumulation_steps batches 23 | - clip gradient where gradient exceeds grad_clip (if configured) 24 | - backward pass, updating model weights 25 | - reset grad 26 | """ 27 | model = model_training_state.model 28 | optimizer = model_training_state.optimizer 29 | 30 | if config.debug_config.no_training: 31 | total_loss = 0.0 32 | logging.debug("no_training set, skipping forward backward pass") 33 | else: 34 | minibatches = gen_minibatches( 35 | dataset=train_ds, 36 | indices=ds_indices, 37 | 
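# note: each minibatch holds batch_size // gradient_accumulation_steps samples,
# so a single train_step still consumes batch_size samples in total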
batch_size=config.batch_size, 38 | num_minibatches=config.gradient_accumulation_steps, 39 | step=model_training_state.step, 40 | device=device, 41 | feature_name=config.dataset.feature, 42 | ) 43 | total_loss = accumulate_gradients( 44 | model=model, 45 | batches=minibatches, 46 | num_batches=config.gradient_accumulation_steps, 47 | ) 48 | # clip the gradient 49 | if config.grad_clip != 0.0: 50 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) # type: ignore 51 | optimizer.step() 52 | # flush the gradients as soon as we can, no need for this memory anymore 53 | optimizer.zero_grad(set_to_none=True) 54 | model_training_state.train_loss = total_loss 55 | 56 | 57 | def accumulate_gradients( 58 | model: PreTrainedModel, 59 | batches: Iterable[torch.Tensor], 60 | num_batches: int, 61 | ) -> float: 62 | """ 63 | Accumulate gradients over multiple batches as if they were a single batch 64 | """ 65 | total_loss = 0.0 66 | for X in batches: 67 | loss = model(X, labels=X, return_dict=True).loss / num_batches 68 | total_loss += loss.item() 69 | loss.backward() 70 | return total_loss 71 | -------------------------------------------------------------------------------- /delphi/train/training.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | from dataclasses import fields 5 | from pathlib import Path 6 | 7 | import torch 8 | from huggingface_hub import HfApi 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer 11 | 12 | from delphi.train.shuffle import shuffle_epoch 13 | 14 | from .checkpoint_step import log_and_save_checkpoint, should_save_checkpoint 15 | from .config import TrainingConfig 16 | from .run_context import RunContext 17 | from .train_step import train_step 18 | from .utils import ( 19 | ModelTrainingState, 20 | initialize_model_training_state, 21 | set_lr, 22 | setup_determinism, 23 | ) 24 | from .wandb_utils import init_wandb 25 | 26 | 27 | def setup_training(config: TrainingConfig): 28 | logging.info("Setting up training...") 29 | os.makedirs(config.out_dir, exist_ok=True) 30 | 31 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 32 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 33 | 34 | setup_determinism(config.torch_seed) 35 | 36 | if config.out_repo: 37 | api = HfApi() 38 | api.create_repo(config.out_repo, exist_ok=True) 39 | 40 | if config.wandb: 41 | init_wandb(config) 42 | 43 | if config.tokenizer: 44 | tokenizer = AutoTokenizer.from_pretrained(config.tokenizer) 45 | tokenizer.save_pretrained(Path(config.out_dir) / "tokenizer") 46 | 47 | 48 | def run_training(config: TrainingConfig) -> tuple[ModelTrainingState, RunContext]: 49 | setup_training(config) 50 | logging.info("Starting training...") 51 | logging.info("Config:") 52 | for field in fields(config): 53 | logging.info(f" {field.name}: {getattr(config, field.name)}") 54 | run_context = RunContext(config.device) 55 | logging.debug(f"Run context: {run_context.asdict()}") 56 | 57 | # load data 58 | logging.info("Loading data...") 59 | train_ds = config.dataset.load_train() 60 | validation_ds = config.dataset.load_validation() 61 | logging.info(f"Train dataset: {len(train_ds)} samples") 62 | logging.info(f"Validation dataset: {len(validation_ds)} samples") 63 | 64 | # derive iteration params 65 | steps_per_epoch = len(train_ds) // config.batch_size 66 | lr_decay_iters = ( 67 | config.max_epochs * steps_per_epoch 68 | ) # should be ~=max_iters per Chinchilla 69 | 70 | # 
model init 71 | model_training_state = initialize_model_training_state(config, run_context.device) 72 | 73 | # training loop 74 | logging.info("Starting training...") 75 | for epoch in range(config.max_epochs): 76 | logging.info(f"Epoch: {epoch+1} / {config.max_epochs}") 77 | train_data_indices = list(range(len(train_ds))) 78 | shuffle_epoch( 79 | train_data_indices, seed=config.batch_ordering_seed, epoch_nr=epoch 80 | ) 81 | model_training_state.epoch = epoch 82 | for step in tqdm(range(steps_per_epoch)): 83 | model_training_state.step = step 84 | if should_save_checkpoint(config, model_training_state): 85 | log_and_save_checkpoint( 86 | config=config, 87 | mts=model_training_state, 88 | train_ds=train_ds, 89 | validation_ds=validation_ds, 90 | run_context=run_context, 91 | ) 92 | model_training_state.lr = set_lr( 93 | lr_decay_iters=lr_decay_iters, 94 | config=config, 95 | optimizer=model_training_state.optimizer, 96 | iter_num=model_training_state.iter_num, 97 | ) 98 | train_step( 99 | model_training_state=model_training_state, 100 | train_ds=train_ds, 101 | config=config, 102 | device=run_context.device, 103 | ds_indices=train_data_indices, 104 | ) 105 | t1 = time.time() 106 | dt = t1 - model_training_state.last_training_step_time 107 | model_training_state.last_training_step_time = t1 108 | if model_training_state.iter_num % config.log_interval == 0: 109 | logging.debug( 110 | ( 111 | f"{model_training_state.iter_num} | loss {model_training_state.train_loss:.4f} | lr {model_training_state.lr:e} | " 112 | f"{dt*1000:.2f}ms" 113 | ) 114 | ) 115 | model_training_state.iter_num += 1 116 | return model_training_state, run_context 117 | -------------------------------------------------------------------------------- /delphi/train/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import math 4 | import os 5 | import shutil 6 | import time 7 | from collections.abc import Iterator 8 | from dataclasses import asdict, dataclass, field 9 | from pathlib import Path 10 | from typing import Any, Type, cast 11 | 12 | import safetensors.torch as st 13 | import torch 14 | import transformers 15 | from datasets import Dataset 16 | from huggingface_hub import HfApi 17 | from torch.optim import AdamW 18 | from transformers import PreTrainedModel 19 | 20 | from delphi.train.config import dot_notation_to_dict 21 | 22 | from .config import TrainingConfig 23 | from .run_context import RunContext 24 | 25 | 26 | @dataclass 27 | class ModelTrainingState: 28 | """mutable training state - stuff that changes over the course of training""" 29 | 30 | model: PreTrainedModel 31 | optimizer: torch.optim.Optimizer 32 | iter_num: int = field( 33 | metadata={"help": "total iterations so far across all epochs"} 34 | ) 35 | last_training_step_time: float = field( 36 | metadata={"help": "time last iteration ended"} 37 | ) 38 | epoch: int = field(metadata={"help": "current epoch"}) 39 | step: int = field(metadata={"help": "step within current epoch"}) 40 | lr: float = field(default=1.0e-5, metadata={"help": "learning rate"}) 41 | train_loss: float = field( 42 | default=0.0, metadata={"help": "loss on most recent train step"} 43 | ) 44 | 45 | 46 | def setup_determinism(seed: int): 47 | logging.debug(f"Setting up torch determinism (seed={seed})...") 48 | torch.use_deterministic_algorithms(True) 49 | torch.backends.cudnn.benchmark = False 50 | torch.manual_seed(seed) 51 | 52 | 53 | def get_lr( 54 | iter_num: int, 55 | warmup_iters: int, 56 | 
learning_rate: float, 57 | lr_decay_iters: int, 58 | min_lr: float, 59 | ): 60 | # 1) linear warmup for warmup_iters steps 61 | if iter_num < warmup_iters: 62 | return learning_rate * iter_num / warmup_iters 63 | # 2) if iter_num > lr_decay_iters, return min learning rate 64 | if iter_num > lr_decay_iters: 65 | return min_lr 66 | # 3) in between, use cosine decay down to min learning rate 67 | decay_ratio = (iter_num - warmup_iters) / (lr_decay_iters - warmup_iters) 68 | assert 0 <= decay_ratio <= 1 69 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 70 | return min_lr + coeff * (learning_rate - min_lr) 71 | 72 | 73 | def set_lr( 74 | lr_decay_iters: int, 75 | config: TrainingConfig, 76 | optimizer: torch.optim.Optimizer, 77 | iter_num: int, 78 | ): 79 | """ 80 | Set the learning rate (calculated by get_lr) on the optimizer 81 | """ 82 | lr = ( 83 | get_lr( 84 | iter_num=iter_num, 85 | warmup_iters=config.adam.warmup_iters, 86 | learning_rate=config.adam.learning_rate, 87 | lr_decay_iters=lr_decay_iters, 88 | min_lr=config.adam.min_lr, 89 | ) 90 | if config.adam.decay_lr 91 | else config.adam.learning_rate 92 | ) 93 | for param_group in optimizer.param_groups: 94 | param_group["lr"] = lr 95 | return lr 96 | 97 | 98 | def initialize_model_training_state( 99 | config: TrainingConfig, device: torch.device 100 | ) -> ModelTrainingState: 101 | t0 = time.time() 102 | model = init_model(config.model_config, seed=config.torch_seed) 103 | model.to(device) # type: ignore 104 | optimizer = AdamW( 105 | lr=config.adam.learning_rate, 106 | params=model.parameters(), 107 | weight_decay=config.adam.weight_decay, 108 | betas=(config.adam.beta1, config.adam.beta2), 109 | ) 110 | training_state_vals = dict() 111 | if config.resume_from_path is not None: 112 | logging.info(f"Resuming training from {config.resume_from_path}") 113 | st.load_model( 114 | model, os.path.join(config.resume_from_path, "model", "model.safetensors") 115 | ) 116 | with open( 117 | os.path.join(config.resume_from_path, "training_state.json"), "r" 118 | ) as f: 119 | training_state_vals = json.load(f) 120 | opt_state_dict_path = Path(os.path.join(config.resume_from_path, "opt.pt")) 121 | if opt_state_dict_path.exists(): 122 | with open(opt_state_dict_path, "rb") as f: 123 | logging.info(f" Loading optimizer state from {opt_state_dict_path}") 124 | optimizer.load_state_dict(torch.load(f)) 125 | return ModelTrainingState( 126 | model=model, 127 | optimizer=optimizer, 128 | last_training_step_time=t0, 129 | iter_num=training_state_vals.get("iter_num", 0), 130 | epoch=training_state_vals.get("epoch", 0), 131 | step=training_state_vals.get("step", 0), 132 | ) 133 | 134 | 135 | def gen_minibatches( 136 | dataset: Dataset, 137 | batch_size: int, 138 | num_minibatches: int, 139 | step: int, 140 | indices: list[int], 141 | device: torch.device, 142 | feature_name: str, 143 | ) -> Iterator[torch.Tensor]: 144 | """ 145 | Generate minibatches from a dataset given a step and indices 146 | """ 147 | minibatch_size = batch_size // num_minibatches 148 | first_minibatch_num = num_minibatches * step 149 | for batch_num in range(first_minibatch_num, first_minibatch_num + num_minibatches): 150 | start = batch_num * minibatch_size 151 | end = (batch_num + 1) * minibatch_size 152 | batch_indices = indices[start:end] 153 | yield dataset[batch_indices][feature_name].to(device) 154 | 155 | 156 | @torch.no_grad() 157 | def estimate_loss( 158 | model: torch.nn.Module, 159 | eval_iters: int, 160 | batch_size: int, 161 | split_to_ds: dict[str,
Dataset], 162 | device: torch.device, 163 | feature_name: str, 164 | ) -> dict[str, float]: 165 | """helps estimate an arbitrarily accurate loss over either split using many batches""" 166 | out = {} 167 | model.eval() 168 | for split, ds in split_to_ds.items(): 169 | indices = list(range(len(ds))) 170 | eval_iters = min(eval_iters, len(ds) // batch_size) 171 | losses = torch.zeros(eval_iters) # keep on CPU 172 | minibatches = gen_minibatches( 173 | dataset=ds, 174 | batch_size=batch_size, 175 | num_minibatches=eval_iters, 176 | step=0, 177 | indices=indices, 178 | device=device, 179 | feature_name=feature_name, 180 | ) 181 | for k, X in enumerate(minibatches): 182 | loss = model(X, labels=X, return_dict=True).loss 183 | losses[k] = loss.item() 184 | out[split] = losses.mean().item() 185 | model.train() 186 | return out 187 | 188 | 189 | def save_results( 190 | config: TrainingConfig, 191 | train_results: ModelTrainingState, 192 | run_context: RunContext, 193 | final: bool = False, 194 | ): 195 | """ 196 | save results to disk, and to huggingface if configured to do so. 197 | 198 | Saves everything required to replicate the current state of training, including optimizer state, 199 | config, context (e.g. hardware), training step, etc 200 | """ 201 | iter_name = "main" if final else f"iter{train_results.iter_num}" 202 | out_dir = Path(config.out_dir) 203 | results_path = out_dir / iter_name 204 | logging.info(f"saving checkpoint to {results_path}") 205 | results_path.mkdir(parents=True, exist_ok=True) 206 | with open(results_path / "training_config.json", "w") as file: 207 | json.dump(asdict(config), file, indent=2) 208 | train_results.model.save_pretrained( 209 | save_directory=results_path, 210 | ) 211 | if config.save_optimizer: 212 | with open(results_path / "optimizer.pt", "wb") as f: 213 | torch.save(train_results.optimizer.state_dict(), f) 214 | with open(results_path / "training_state.json", "w") as file: 215 | training_state_dict = { 216 | "iter_num": train_results.iter_num, 217 | "lr": train_results.lr, 218 | "epoch": train_results.epoch, 219 | "step": train_results.step, 220 | } 221 | json.dump(training_state_dict, file, indent=2) 222 | with open(results_path / "run_context.json", "w") as file: 223 | json.dump(run_context.asdict(), file, indent=2) 224 | if (tokenizer_dir := out_dir / "tokenizer").exists(): 225 | for src_file in tokenizer_dir.iterdir(): 226 | if src_file.is_file(): 227 | dest_file = results_path / src_file.name 228 | shutil.copy2(src_file, dest_file) 229 | if config.out_repo: 230 | try: 231 | api = HfApi() 232 | api.create_branch(config.out_repo, branch=iter_name, exist_ok=True) 233 | api.upload_folder( 234 | folder_path=results_path, 235 | repo_id=config.out_repo, 236 | revision=iter_name, 237 | ) 238 | except Exception as e: 239 | logging.error(f"Failed to upload to huggingface: {e}") 240 | 241 | 242 | def count_tokens_so_far(config: TrainingConfig, mts: ModelTrainingState) -> int: 243 | tokens_per_iter = config.batch_size * config.max_seq_len 244 | return mts.iter_num * tokens_per_iter 245 | 246 | 247 | def init_model(model_config_dict: dict[str, Any], seed: int) -> PreTrainedModel: 248 | """ 249 | Get a model from a model config dictionary 250 | """ 251 | # reseed torch to ensure reproducible results in case other torch calls are different up to this point 252 | torch.random.manual_seed(seed) 253 | model_class = getattr(transformers, model_config_dict["model_class"]) 254 | config_class = cast(Type[transformers.PretrainedConfig], model_class.config_class) 255 | 
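# every key except "model_class" is forwarded as a kwarg to the config class,
# e.g. hidden_size or num_hidden_layers for a Llama config (illustrative)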
model_params_dict = model_config_dict.copy() 256 | model_params_dict.pop("model_class") 257 | return model_class(config_class(**(model_params_dict))) 258 | 259 | 260 | def overrides_to_dict(overrides: list[str]) -> dict[str, Any]: 261 | # ["a.b.c=4", "foo=false"] to {"a": {"b": {"c": 4}}, "foo": False} 262 | config_vars = {k: v for k, v in [x.split("=") for x in overrides if "=" in x]} 263 | return dot_notation_to_dict(config_vars) 264 | -------------------------------------------------------------------------------- /delphi/train/wandb_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import asdict 3 | 4 | import wandb 5 | 6 | from .config import TrainingConfig 7 | from .utils import ModelTrainingState 8 | 9 | 10 | def init_wandb(config: TrainingConfig): 11 | assert "/" in config.wandb, "wandb should be in the 'entity/project' form" 12 | wandb_entity, wandb_project = config.wandb.split("/") 13 | wandb.init( 14 | entity=wandb_entity, 15 | project=wandb_project, 16 | name=config.run_name, 17 | config=asdict(config), 18 | ) 19 | 20 | 21 | def log_to_wandb(mts: ModelTrainingState, losses: dict[str, float], tokens_so_far: int): 22 | try: 23 | wandb.log( 24 | { 25 | "epoch": mts.epoch, 26 | "epoch_iter": mts.step, 27 | "global_iter": mts.iter_num, 28 | "tokens": tokens_so_far, 29 | "loss/train": losses["train"], 30 | "loss/val": losses["val"], 31 | "lr": mts.lr, 32 | }, 33 | step=mts.iter_num, 34 | ) 35 | except Exception as e: 36 | logging.error(f"logging to wandb failed: {e}") 37 | -------------------------------------------------------------------------------- /delphi/utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from typing import cast 3 | 4 | import torch 5 | from datasets import Dataset, Features, Sequence, Value, load_dataset 6 | from jaxtyping import Float, Int 7 | 8 | 9 | def hf_split_to_split_name(split: str) -> str: 10 | return split.split("[")[0] 11 | 12 | 13 | def load_dataset_split_features( 14 | path: str, 15 | split: str, 16 | features: Features, 17 | ) -> Dataset: 18 | dataset = load_dataset( 19 | path, 20 | split=split, 21 | features=features, 22 | ) 23 | dataset = cast(Dataset, dataset) 24 | return dataset 25 | 26 | 27 | def load_dataset_split_string_feature( 28 | path: str, 29 | split: str, 30 | feature_name: str, 31 | ) -> Dataset: 32 | print("Loading string dataset") 33 | print(f"{path=}, {split=}, {feature_name=}") 34 | return load_dataset_split_features( 35 | path, 36 | split, 37 | Features({feature_name: Value("string")}), 38 | ) 39 | 40 | 41 | def load_dataset_split_sequence_int32_feature( 42 | path: str, 43 | split: str, 44 | feature_name: str, 45 | ) -> Dataset: 46 | print("Loading sequence int32 dataset") 47 | print(f"{path=}, {split=}, {feature_name=}") 48 | return load_dataset_split_features( 49 | path, 50 | split, 51 | Features({feature_name: Sequence(Value("int32"))}), 52 | ) 53 | 54 | 55 | def get_all_hf_branch_names(repo_id: str) -> list[str]: 56 | from huggingface_hub import HfApi 57 | 58 | api = HfApi() 59 | refs = api.list_repo_refs(repo_id) 60 | return [branch.name for branch in refs.branches] 61 | 62 | 63 | def gather_logprobs( 64 | logprobs: Float[torch.Tensor, "batch seq vocab"], 65 | tokens: Int[torch.Tensor, "batch seq"], 66 | ) -> Float[torch.Tensor, "batch seq"]: 67 | return torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1) 68 | 69 | 70 | def get_all_logprobs( 71 | 
model: Callable, input_ids: Int[torch.Tensor, "batch seq"] 72 | ) -> Float[torch.Tensor, "batch seq vocab"]: 73 | # batch, seq, vocab 74 | logits = model(input_ids).logits 75 | return torch.log_softmax(logits, dim=-1) 76 | 77 | 78 | def get_all_and_next_logprobs( 79 | model: Callable, 80 | input_ids: Int[torch.Tensor, "batch seq"], 81 | ) -> tuple[ 82 | Float[torch.Tensor, "batch shorter_seq vocab"], 83 | Float[torch.Tensor, "batch shorter_seq"], 84 | ]: 85 | logprobs = get_all_logprobs(model, input_ids[:, :-1]) 86 | next_tokens = input_ids[:, 1:] 87 | return logprobs, gather_logprobs(logprobs, next_tokens) 88 | -------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delphi-suite/delphi/1efcabec4bceffbf4e9a383d05958f8c0704d2ce/notebooks/.gitkeep -------------------------------------------------------------------------------- /notebooks/eval_notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# colab cells (only run if on colab)\n", 10 | "# TODO: experiment on colab to see how to set up the environment" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": {}, 16 | "source": [ 17 | "# Important\n", 18 | "\n", 19 | "Run this cell by cell. The token selecter cell needs to be ran first so the later cells work." 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "data": { 29 | "application/javascript": "(function(root) {\n function now() {\n return new Date();\n }\n\n var force = true;\n var py_version = '3.4.1'.replace('rc', '-rc.').replace('.dev', '-dev.');\n var reloading = false;\n var Bokeh = root.Bokeh;\n\n if (typeof (root._bokeh_timeout) === \"undefined\" || force) {\n root._bokeh_timeout = Date.now() + 5000;\n root._bokeh_failed_load = false;\n }\n\n function run_callbacks() {\n try {\n root._bokeh_onload_callbacks.forEach(function(callback) {\n if (callback != null)\n callback();\n });\n } finally {\n delete root._bokeh_onload_callbacks;\n }\n console.debug(\"Bokeh: all callbacks have finished\");\n }\n\n function load_libs(css_urls, js_urls, js_modules, js_exports, callback) {\n if (css_urls == null) css_urls = [];\n if (js_urls == null) js_urls = [];\n if (js_modules == null) js_modules = [];\n if (js_exports == null) js_exports = {};\n\n root._bokeh_onload_callbacks.push(callback);\n\n if (root._bokeh_is_loading > 0) {\n console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n return null;\n }\n if (js_urls.length === 0 && js_modules.length === 0 && Object.keys(js_exports).length === 0) {\n run_callbacks();\n return null;\n }\n if (!reloading) {\n console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n }\n\n function on_load() {\n root._bokeh_is_loading--;\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n run_callbacks()\n }\n }\n window._bokeh_on_load = on_load\n\n function on_error() {\n console.error(\"failed to load \" + url);\n }\n\n var skip = [];\n if (window.requirejs) {\n window.requirejs.config({'packages': {}, 'paths': {}, 'shim': {}});\n root._bokeh_is_loading = css_urls.length + 0;\n } else {\n root._bokeh_is_loading = 
css_urls.length + js_urls.length + js_modules.length + Object.keys(js_exports).length;\n }\n\n var existing_stylesheets = []\n var links = document.getElementsByTagName('link')\n for (var i = 0; i < links.length; i++) {\n var link = links[i]\n if (link.href != null) {\n\texisting_stylesheets.push(link.href)\n }\n }\n for (var i = 0; i < css_urls.length; i++) {\n var url = css_urls[i];\n if (existing_stylesheets.indexOf(url) !== -1) {\n\ton_load()\n\tcontinue;\n }\n const element = document.createElement(\"link\");\n element.onload = on_load;\n element.onerror = on_error;\n element.rel = \"stylesheet\";\n element.type = \"text/css\";\n element.href = url;\n console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n document.body.appendChild(element);\n } var existing_scripts = []\n var scripts = document.getElementsByTagName('script')\n for (var i = 0; i < scripts.length; i++) {\n var script = scripts[i]\n if (script.src != null) {\n\texisting_scripts.push(script.src)\n }\n }\n for (var i = 0; i < js_urls.length; i++) {\n var url = js_urls[i];\n if (skip.indexOf(url) !== -1 || existing_scripts.indexOf(url) !== -1) {\n\tif (!window.requirejs) {\n\t on_load();\n\t}\n\tcontinue;\n }\n var element = document.createElement('script');\n element.onload = on_load;\n element.onerror = on_error;\n element.async = false;\n element.src = url;\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n document.head.appendChild(element);\n }\n for (var i = 0; i < js_modules.length; i++) {\n var url = js_modules[i];\n if (skip.indexOf(url) !== -1 || existing_scripts.indexOf(url) !== -1) {\n\tif (!window.requirejs) {\n\t on_load();\n\t}\n\tcontinue;\n }\n var element = document.createElement('script');\n element.onload = on_load;\n element.onerror = on_error;\n element.async = false;\n element.src = url;\n element.type = \"module\";\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n document.head.appendChild(element);\n }\n for (const name in js_exports) {\n var url = js_exports[name];\n if (skip.indexOf(url) >= 0 || root[name] != null) {\n\tif (!window.requirejs) {\n\t on_load();\n\t}\n\tcontinue;\n }\n var element = document.createElement('script');\n element.onerror = on_error;\n element.async = false;\n element.type = \"module\";\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n element.textContent = `\n import ${name} from \"${url}\"\n window.${name} = ${name}\n window._bokeh_on_load()\n `\n document.head.appendChild(element);\n }\n if (!js_urls.length && !js_modules.length) {\n on_load()\n }\n };\n\n function inject_raw_css(css) {\n const element = document.createElement(\"style\");\n element.appendChild(document.createTextNode(css));\n document.body.appendChild(element);\n }\n\n var js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-3.4.1.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-3.4.1.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-3.4.1.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-3.4.1.min.js\", \"https://cdn.holoviz.org/panel/1.4.0/dist/panel.min.js\"];\n var js_modules = [];\n var js_exports = {};\n var css_urls = [\"https://cdn.holoviz.org/panel/1.4.0/dist/bundled/font-awesome/css/all.min.css\"];\n var inline_js = [ function(Bokeh) {\n Bokeh.set_log_level(\"info\");\n },\nfunction(Bokeh) {} // ensure no trailing comma for IE\n ];\n\n function run_inline_js() {\n if ((root.Bokeh !== undefined) || (force === true)) {\n for (var i = 0; i < 
inline_js.length; i++) {\n\ttry {\n inline_js[i].call(root, root.Bokeh);\n\t} catch(e) {\n\t if (!reloading) {\n\t throw e;\n\t }\n\t}\n }\n // Cache old bokeh versions\n if (Bokeh != undefined && !reloading) {\n\tvar NewBokeh = root.Bokeh;\n\tif (Bokeh.versions === undefined) {\n\t Bokeh.versions = new Map();\n\t}\n\tif (NewBokeh.version !== Bokeh.version) {\n\t Bokeh.versions.set(NewBokeh.version, NewBokeh)\n\t}\n\troot.Bokeh = Bokeh;\n }} else if (Date.now() < root._bokeh_timeout) {\n setTimeout(run_inline_js, 100);\n } else if (!root._bokeh_failed_load) {\n console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n root._bokeh_failed_load = true;\n }\n root._bokeh_is_initializing = false\n }\n\n function load_or_wait() {\n // Implement a backoff loop that tries to ensure we do not load multiple\n // versions of Bokeh and its dependencies at the same time.\n // In recent versions we use the root._bokeh_is_initializing flag\n // to determine whether there is an ongoing attempt to initialize\n // bokeh, however for backward compatibility we also try to ensure\n // that we do not start loading a newer (Panel>=1.0 and Bokeh>3) version\n // before older versions are fully initialized.\n if (root._bokeh_is_initializing && Date.now() > root._bokeh_timeout) {\n root._bokeh_is_initializing = false;\n root._bokeh_onload_callbacks = undefined;\n console.log(\"Bokeh: BokehJS was loaded multiple times but one version failed to initialize.\");\n load_or_wait();\n } else if (root._bokeh_is_initializing || (typeof root._bokeh_is_initializing === \"undefined\" && root._bokeh_onload_callbacks !== undefined)) {\n setTimeout(load_or_wait, 100);\n } else {\n root._bokeh_is_initializing = true\n root._bokeh_onload_callbacks = []\n var bokeh_loaded = Bokeh != null && (Bokeh.version === py_version || (Bokeh.versions !== undefined && Bokeh.versions.has(py_version)));\n if (!reloading && !bokeh_loaded) {\n\troot.Bokeh = undefined;\n }\n load_libs(css_urls, js_urls, js_modules, js_exports, function() {\n\tconsole.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n\trun_inline_js();\n });\n }\n }\n // Give older versions of the autoload script a head-start to ensure\n // they initialize before we start loading newer version.\n setTimeout(load_or_wait, 100)\n}(window));", 30 | "application/vnd.holoviews_load.v0+json": "" 31 | }, 32 | "metadata": {}, 33 | "output_type": "display_data" 34 | }, 35 | { 36 | "data": { 37 | "application/javascript": "\nif ((window.PyViz === undefined) || (window.PyViz instanceof HTMLElement)) {\n window.PyViz = {comms: {}, comm_status:{}, kernels:{}, receivers: {}, plot_index: []}\n}\n\n\n function JupyterCommManager() {\n }\n\n JupyterCommManager.prototype.register_target = function(plot_id, comm_id, msg_handler) {\n if (window.comm_manager || ((window.Jupyter !== undefined) && (Jupyter.notebook.kernel != null))) {\n var comm_manager = window.comm_manager || Jupyter.notebook.kernel.comm_manager;\n comm_manager.register_target(comm_id, function(comm) {\n comm.on_msg(msg_handler);\n });\n } else if ((plot_id in window.PyViz.kernels) && (window.PyViz.kernels[plot_id])) {\n window.PyViz.kernels[plot_id].registerCommTarget(comm_id, function(comm) {\n comm.onMsg = msg_handler;\n });\n } else if (typeof google != 'undefined' && google.colab.kernel != null) {\n google.colab.kernel.comms.registerTarget(comm_id, (comm) => {\n var messages = comm.messages[Symbol.asyncIterator]();\n function processIteratorResult(result) {\n var message = result.value;\n 
console.log(message)\n var content = {data: message.data, comm_id};\n var buffers = []\n for (var buffer of message.buffers || []) {\n buffers.push(new DataView(buffer))\n }\n var metadata = message.metadata || {};\n var msg = {content, buffers, metadata}\n msg_handler(msg);\n return messages.next().then(processIteratorResult);\n }\n return messages.next().then(processIteratorResult);\n })\n }\n }\n\n JupyterCommManager.prototype.get_client_comm = function(plot_id, comm_id, msg_handler) {\n if (comm_id in window.PyViz.comms) {\n return window.PyViz.comms[comm_id];\n } else if (window.comm_manager || ((window.Jupyter !== undefined) && (Jupyter.notebook.kernel != null))) {\n var comm_manager = window.comm_manager || Jupyter.notebook.kernel.comm_manager;\n var comm = comm_manager.new_comm(comm_id, {}, {}, {}, comm_id);\n if (msg_handler) {\n comm.on_msg(msg_handler);\n }\n } else if ((plot_id in window.PyViz.kernels) && (window.PyViz.kernels[plot_id])) {\n var comm = window.PyViz.kernels[plot_id].connectToComm(comm_id);\n comm.open();\n if (msg_handler) {\n comm.onMsg = msg_handler;\n }\n } else if (typeof google != 'undefined' && google.colab.kernel != null) {\n var comm_promise = google.colab.kernel.comms.open(comm_id)\n comm_promise.then((comm) => {\n window.PyViz.comms[comm_id] = comm;\n if (msg_handler) {\n var messages = comm.messages[Symbol.asyncIterator]();\n function processIteratorResult(result) {\n var message = result.value;\n var content = {data: message.data};\n var metadata = message.metadata || {comm_id};\n var msg = {content, metadata}\n msg_handler(msg);\n return messages.next().then(processIteratorResult);\n }\n return messages.next().then(processIteratorResult);\n }\n }) \n var sendClosure = (data, metadata, buffers, disposeOnDone) => {\n return comm_promise.then((comm) => {\n comm.send(data, metadata, buffers, disposeOnDone);\n });\n };\n var comm = {\n send: sendClosure\n };\n }\n window.PyViz.comms[comm_id] = comm;\n return comm;\n }\n window.PyViz.comm_manager = new JupyterCommManager();\n \n\n\nvar JS_MIME_TYPE = 'application/javascript';\nvar HTML_MIME_TYPE = 'text/html';\nvar EXEC_MIME_TYPE = 'application/vnd.holoviews_exec.v0+json';\nvar CLASS_NAME = 'output';\n\n/**\n * Render data to the DOM node\n */\nfunction render(props, node) {\n var div = document.createElement(\"div\");\n var script = document.createElement(\"script\");\n node.appendChild(div);\n node.appendChild(script);\n}\n\n/**\n * Handle when a new output is added\n */\nfunction handle_add_output(event, handle) {\n var output_area = handle.output_area;\n var output = handle.output;\n if ((output.data == undefined) || (!output.data.hasOwnProperty(EXEC_MIME_TYPE))) {\n return\n }\n var id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n var toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n if (id !== undefined) {\n var nchildren = toinsert.length;\n var html_node = toinsert[nchildren-1].children[0];\n html_node.innerHTML = output.data[HTML_MIME_TYPE];\n var scripts = [];\n var nodelist = html_node.querySelectorAll(\"script\");\n for (var i in nodelist) {\n if (nodelist.hasOwnProperty(i)) {\n scripts.push(nodelist[i])\n }\n }\n\n scripts.forEach( function (oldScript) {\n var newScript = document.createElement(\"script\");\n var attrs = [];\n var nodemap = oldScript.attributes;\n for (var j in nodemap) {\n if (nodemap.hasOwnProperty(j)) {\n attrs.push(nodemap[j])\n }\n }\n attrs.forEach(function(attr) { newScript.setAttribute(attr.name, attr.value) });\n 
newScript.appendChild(document.createTextNode(oldScript.innerHTML));\n oldScript.parentNode.replaceChild(newScript, oldScript);\n });\n if (JS_MIME_TYPE in output.data) {\n toinsert[nchildren-1].children[1].textContent = output.data[JS_MIME_TYPE];\n }\n output_area._hv_plot_id = id;\n if ((window.Bokeh !== undefined) && (id in Bokeh.index)) {\n window.PyViz.plot_index[id] = Bokeh.index[id];\n } else {\n window.PyViz.plot_index[id] = null;\n }\n } else if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n var bk_div = document.createElement(\"div\");\n bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n var script_attrs = bk_div.children[0].attributes;\n for (var i = 0; i < script_attrs.length; i++) {\n toinsert[toinsert.length - 1].childNodes[1].setAttribute(script_attrs[i].name, script_attrs[i].value);\n }\n // store reference to server id on output_area\n output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n }\n}\n\n/**\n * Handle when an output is cleared or removed\n */\nfunction handle_clear_output(event, handle) {\n var id = handle.cell.output_area._hv_plot_id;\n var server_id = handle.cell.output_area._bokeh_server_id;\n if (((id === undefined) || !(id in PyViz.plot_index)) && (server_id !== undefined)) { return; }\n var comm = window.PyViz.comm_manager.get_client_comm(\"hv-extension-comm\", \"hv-extension-comm\", function () {});\n if (server_id !== null) {\n comm.send({event_type: 'server_delete', 'id': server_id});\n return;\n } else if (comm !== null) {\n comm.send({event_type: 'delete', 'id': id});\n }\n delete PyViz.plot_index[id];\n if ((window.Bokeh !== undefined) & (id in window.Bokeh.index)) {\n var doc = window.Bokeh.index[id].model.document\n doc.clear();\n const i = window.Bokeh.documents.indexOf(doc);\n if (i > -1) {\n window.Bokeh.documents.splice(i, 1);\n }\n }\n}\n\n/**\n * Handle kernel restart event\n */\nfunction handle_kernel_cleanup(event, handle) {\n delete PyViz.comms[\"hv-extension-comm\"];\n window.PyViz.plot_index = {}\n}\n\n/**\n * Handle update_display_data messages\n */\nfunction handle_update_output(event, handle) {\n handle_clear_output(event, {cell: {output_area: handle.output_area}})\n handle_add_output(event, handle)\n}\n\nfunction register_renderer(events, OutputArea) {\n function append_mime(data, metadata, element) {\n // create a DOM node to render to\n var toinsert = this.create_output_subarea(\n metadata,\n CLASS_NAME,\n EXEC_MIME_TYPE\n );\n this.keyboard_manager.register_events(toinsert);\n // Render to node\n var props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n render(props, toinsert[0]);\n element.append(toinsert);\n return toinsert\n }\n\n events.on('output_added.OutputArea', handle_add_output);\n events.on('output_updated.OutputArea', handle_update_output);\n events.on('clear_output.CodeCell', handle_clear_output);\n events.on('delete.Cell', handle_clear_output);\n events.on('kernel_ready.Kernel', handle_kernel_cleanup);\n\n OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n safe: true,\n index: 0\n });\n}\n\nif (window.Jupyter !== undefined) {\n try {\n var events = require('base/js/events');\n var OutputArea = require('notebook/js/outputarea').OutputArea;\n if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n register_renderer(events, OutputArea);\n }\n } catch(err) {\n }\n}\n", 38 | "application/vnd.holoviews_load.v0+json": "" 39 | }, 40 | "metadata": {}, 41 | "output_type": "display_data" 42 | }, 43 | { 44 | "data": { 45 | "text/html": [ 46 
| "" 62 | ] 63 | }, 64 | "metadata": {}, 65 | "output_type": "display_data" 66 | }, 67 | { 68 | "data": { 69 | "application/vnd.holoviews_exec.v0+json": "", 70 | "text/html": [ 71 | "
\n", 72 | "
\n", 73 | "
\n", 74 | "" 136 | ] 137 | }, 138 | "metadata": { 139 | "application/vnd.holoviews_exec.v0+json": { 140 | "id": "2ac0543f-58a1-4b93-9570-a2c0e6f09501" 141 | } 142 | }, 143 | "output_type": "display_data" 144 | } 145 | ], 146 | "source": [ 147 | "# imports\n", 148 | "import torch\n", 149 | "import panel as pn\n", 150 | "from delphi.eval import token_selector, vis_pos_map, calc_model_group_stats, visualize_selected_tokens, get_all_tok_metrics_in_label\n", 151 | "from datasets import load_dataset, Dataset\n", 152 | "from transformers import AutoTokenizer\n", 153 | "from typing import cast\n", 154 | "import ipywidgets as widgets\n", 155 | "\n", 156 | "# refer to https://panel.holoviz.org/reference/panes/IPyWidget.html to integrate ipywidgets with panel\n", 157 | "pn.extension('ipywidgets')\n", 158 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 3, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "# specify model names (or checkpoints)\n", 168 | "prefix = \"delphi-suite/v0-next-logprobs-llama2-\"\n", 169 | "suffixes = [\n", 170 | " \"100k\",\n", 171 | " \"200k\",\n", 172 | " \"400k\",\n", 173 | "] # , \"800k\", \"1.6m\", \"3.2m\", \"6.4m\", \"12.8m\", \"25.6m\"]" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 4, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "# load next logprobs data for all models\n", 183 | "split = \"validation[:100]\"\n", 184 | "next_logprobs = {\n", 185 | " suffix: cast(\n", 186 | " Dataset,\n", 187 | " load_dataset(f\"{prefix}{suffix}\", split=split),\n", 188 | " )\n", 189 | " .with_format(\"torch\")\n", 190 | " .map(lambda x: {\"logprobs\": x[\"logprobs\"].to(device)})\n", 191 | " for suffix in suffixes\n", 192 | "}\n", 193 | "next_logprobs_plot = {k: d[\"logprobs\"] for k, d in next_logprobs.items()}\n" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 5, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "\n", 203 | "# load the tokenized dataset\n", 204 | "tokenized_corpus_dataset = (\n", 205 | " cast(\n", 206 | " Dataset,\n", 207 | " load_dataset(\"delphi-suite/stories-tokenized\", split=split),\n", 208 | " )\n", 209 | " .with_format(\"torch\")\n", 210 | " .map(lambda x: {\"tokens\": x[\"tokens\"].to(device)})\n", 211 | ")" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": {}, 217 | "source": [ 218 | "Run this notebook until the following cell, then the rest should work." 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 6, 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "name": "stderr", 228 | "output_type": "stream", 229 | "text": [ 230 | "/Users/jett/Documents/jett/delphi/.venv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
If you want to force a new download, use `force_download=True`.\n", 231 | " warnings.warn(\n" 232 | ] 233 | }, 234 | { 235 | "data": { 236 | "application/vnd.jupyter.widget-view+json": { 237 | "model_id": "29c01ed2f022418ebc6a7c4f2d8210b4", 238 | "version_major": 2, 239 | "version_minor": 0 240 | }, 241 | "text/plain": [ 242 | "BokehModel(combine_events=True, render_bundle={'docs_json': {'f8a47e67-7cc8-4e1f-bc8e-c10325c540a2': {'version…" 243 | ] 244 | }, 245 | "execution_count": 6, 246 | "metadata": {}, 247 | "output_type": "execute_result" 248 | } 249 | ], 250 | "source": [ 251 | "# specific token specification\n", 252 | "tokenizer = AutoTokenizer.from_pretrained(\"delphi-suite/stories-tokenizer\")\n", 253 | "\n", 254 | "# Count the frequency of each token using torch.bincount\n", 255 | "token_counts = torch.bincount(tokenized_corpus_dataset[\"tokens\"].view(-1))\n", 256 | "\n", 257 | "# Get the indices that would sort the token counts in descending order\n", 258 | "sorted_indices = torch.argsort(token_counts, descending=True)\n", 259 | "\n", 260 | "# Get the token IDs in descending order of frequency\n", 261 | "valid_tok_ids = sorted_indices.tolist()\n", 262 | "def format_fix(s):\n", 263 | " if s.startswith(\" \"):\n", 264 | " return \"_\" + s[1:]\n", 265 | " return s\n", 266 | "vocab = {format_fix(tokenizer.decode(t, clean_up_tokenization_spaces=True)): t for t in sorted_indices.tolist() if token_counts[t] > 0}\n", 267 | "\n", 268 | "\n", 269 | "selector, selected_ids = token_selector(vocab) # use selected_ids as a dynamic variable\n", 270 | "pn.Row(selector, height=500).servable()" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 7, 276 | "metadata": {}, 277 | "outputs": [ 278 | { 279 | "name": "stdout", 280 | "output_type": "stream", 281 | "text": [ 282 | "Selected IDs: [40, 2, 14]\n" 283 | ] 284 | } 285 | ], 286 | "source": [ 287 | "if not selected_ids:\n", 288 | " selected_ids = [40, 2, 14]\n", 289 | "print(\"Selected IDs:\", selected_ids)" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 8, 295 | "metadata": {}, 296 | "outputs": [ 297 | { 298 | "data": { 299 | "text/plain": [ 300 | "torch.Size([100, 512])" 301 | ] 302 | }, 303 | "execution_count": 8, 304 | "metadata": {}, 305 | "output_type": "execute_result" 306 | } 307 | ], 308 | "source": [ 309 | "list(next_logprobs_plot.values())[0].shape" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 9, 315 | "metadata": {}, 316 | "outputs": [ 317 | { 318 | "name": "stdout", 319 | "output_type": "stream", 320 | "text": [ 321 | "Processing model 100k\n" 322 | ] 323 | }, 324 | { 325 | "name": "stdout", 326 | "output_type": "stream", 327 | "text": [ 328 | "Processing model 200k\n", 329 | "Processing model 400k\n" 330 | ] 331 | }, 332 | { 333 | "data": { 334 | "application/vnd.jupyter.widget-view+json": { 335 | "model_id": "2084f5f7ca5e4aeca85b0f5c821eaced", 336 | "version_major": 2, 337 | "version_minor": 0 338 | }, 339 | "text/plain": [ 340 | "FigureWidget({\n", 341 | " 'data': [{'line': {'width': 0},\n", 342 | " 'marker': {'color': 'rgba(68, 68, 68, 0.3)'},\n", 343 | " 'mode': 'lines',\n", 344 | " 'name': 'Upper Bound',\n", 345 | " 'showlegend': False,\n", 346 | " 'type': 'scatter',\n", 347 | " 'uid': '3348f11d-9719-4274-9954-43a9dc8f2ce1',\n", 348 | " 'x': [100k, 200k, 400k],\n", 349 | " 'y': array([4.6017912 , 4.03893679, 3.46496367])},\n", 350 | " {'fill': 'tonexty',\n", 351 | " 'fillcolor': 'rgba(68, 68, 68, 0.3)',\n", 352 | " 'line': {'width': 0},\n", 
353 | " 'marker': {'color': 'rgba(68, 68, 68, 0.3)'},\n", 354 | " 'mode': 'lines',\n", 355 | " 'name': 'Lower Bound',\n", 356 | " 'showlegend': False,\n", 357 | " 'type': 'scatter',\n", 358 | " 'uid': '2a90892b-69a9-49d6-bad0-086ee4837fcc',\n", 359 | " 'x': [100k, 200k, 400k],\n", 360 | " 'y': array([1.00667199, 0.88813308, 0.735852 ])},\n", 361 | " {'marker': {'color': 'rgb(31, 119, 180)', 'line': {'color': 'rgb(31, 119, 180)', 'width': 1}, 'size': 0},\n", 362 | " 'mode': 'lines',\n", 363 | " 'name': 'Means',\n", 364 | " 'type': 'scatter',\n", 365 | " 'uid': '24ce8dc7-90cf-48f7-9722-de6cc02ba5a1',\n", 366 | " 'x': [100k, 200k, 400k],\n", 367 | " 'y': array([1.39094847, 1.15670866, 0.93012363])}],\n", 368 | " 'layout': {'template': '...'}\n", 369 | "})" 370 | ] 371 | }, 372 | "execution_count": 9, 373 | "metadata": {}, 374 | "output_type": "execute_result" 375 | } 376 | ], 377 | "source": [ 378 | "model_group_stats = calc_model_group_stats( # i'm not sure if tokenized_corpus_dataset.tolist() is the right input, it was list(tokenized_corpus_dataset) before\n", 379 | " tokenized_corpus_dataset, next_logprobs_plot, selected_ids\n", 380 | ")\n", 381 | "performance_data = {}\n", 382 | "for suffix in suffixes:\n", 383 | " stats = model_group_stats[suffix]\n", 384 | " performance_data[suffix] = (\n", 385 | " -stats[\"median\"],\n", 386 | " -stats[\"75th\"],\n", 387 | " -stats[\"25th\"],\n", 388 | " )\n", 389 | "\n", 390 | "visualize_selected_tokens(performance_data, log_scale=True)" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": 10, 396 | "metadata": {}, 397 | "outputs": [ 398 | { 399 | "data": { 400 | "application/vnd.jupyter.widget-view+json": { 401 | "model_id": "3dee1d95570945f3a119c830cdd6d9b6", 402 | "version_major": 2, 403 | "version_minor": 0 404 | }, 405 | "text/plain": [ 406 | "interactive(children=(FloatRangeSlider(value=(0.25, 0.75), description='Quantiles', max=1.0, step=0.05), Dropd…" 407 | ] 408 | }, 409 | "metadata": {}, 410 | "output_type": "display_data" 411 | }, 412 | { 413 | "data": { 414 | "text/plain": [ 415 | "" 416 | ] 417 | }, 418 | "execution_count": 10, 419 | "metadata": {}, 420 | "output_type": "execute_result" 421 | } 422 | ], 423 | "source": [ 424 | "def show_pos_map(\n", 425 | " quantile: tuple[float, float],\n", 426 | " model_name_1: str,\n", 427 | " model_name_2: str,\n", 428 | "):\n", 429 | " logprobs_diff = next_logprobs[model_name_2][\"logprobs\"] - next_logprobs[model_name_1][\"logprobs\"] # type: ignore\n", 430 | " pos_to_diff = get_all_tok_metrics_in_label(tokenized_corpus_dataset[\"tokens\"], selected_tokens=selected_ids, metrics=logprobs_diff, q_start=quantile[0], q_end=quantile[1]) # type: ignore\n", 431 | " try:\n", 432 | " _ = vis_pos_map(list(pos_to_diff.keys()), selected_ids, logprobs_diff, tokenized_corpus_dataset[\"tokens\"], tokenizer) # type: ignore\n", 433 | " except ValueError:\n", 434 | " if pos_to_diff == {}:\n", 435 | " print(\"No tokens found in this label\")\n", 436 | " return\n", 437 | "\n", 438 | "\n", 439 | "widgets.interact_manual(\n", 440 | " show_pos_map,\n", 441 | " quantile=widgets.FloatRangeSlider(\n", 442 | " min=0.0, max=1.0, step=0.05, description=\"Quantiles\"\n", 443 | " ),\n", 444 | " samples=widgets.IntSlider(min=1, max=5, description=\"Samples\", value=2),\n", 445 | " model_name_1=widgets.Dropdown(\n", 446 | " options=suffixes,\n", 447 | " description=\"Model 1\",\n", 448 | " value=\"100k\",\n", 449 | " ),\n", 450 | " model_name_2=widgets.Dropdown(\n", 451 | " options=suffixes,\n", 452 | " 
description=\"Model 2\",\n", 453 | " value=\"200k\",\n", 454 | " ),\n", 455 | ")" 456 | ] 457 | } 458 | ], 459 | "metadata": { 460 | "kernelspec": { 461 | "display_name": ".venv", 462 | "language": "python", 463 | "name": "python3" 464 | }, 465 | "language_info": { 466 | "codemirror_mode": { 467 | "name": "ipython", 468 | "version": 3 469 | }, 470 | "file_extension": ".py", 471 | "mimetype": "text/x-python", 472 | "name": "python", 473 | "nbconvert_exporter": "python", 474 | "pygments_lexer": "ipython3", 475 | "version": "3.10.13" 476 | } 477 | }, 478 | "nbformat": 4, 479 | "nbformat_minor": 2 480 | } 481 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "delphi" 3 | dynamic = ["version"] 4 | dependencies = [ 5 | "torch==2.1.2", 6 | "datasets==2.16.1", 7 | "tqdm==4.66.1", 8 | "jaxtyping==0.2.25", 9 | "beartype==0.18.2", 10 | "chardet==5.2.0", 11 | "plotly==5.18.0", 12 | "wandb==0.16.3", 13 | "dacite==1.8.1", 14 | "transformers==4.40.0", 15 | "platformdirs==4.2.2" 16 | ] 17 | 18 | [project.optional-dependencies] 19 | mamba_cuda = [ 20 | "mamba_ssm==1.2.0.post1", 21 | "causal-conv1d==1.2.0.post2", 22 | ] 23 | notebooks = [ 24 | "ipykernel==6.29.4", 25 | "panel==1.4.0", 26 | "jupyter_bokeh==4.0.1", 27 | "ipywidgets==8.1.1", 28 | "nbformat==5.9.2", 29 | ] 30 | dev = [ 31 | "pytest==7.4.4", 32 | "black==23.12.1", 33 | "isort==5.13.2", 34 | "pre-commit==3.6.0", 35 | ] 36 | 37 | [build-system] 38 | requires = ["setuptools", "wheel"] 39 | 40 | 41 | 42 | [tool.setuptools.dynamic] 43 | version = {attr = "delphi.__version__"} 44 | 45 | [tool.isort] 46 | profile = 'black' 47 | known_third_party = ['wandb'] 48 | 49 | [tool.pytest.ini_options] 50 | testpaths = ["tests"] -------------------------------------------------------------------------------- /scripts/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delphi-suite/delphi/1efcabec4bceffbf4e9a383d05958f8c0704d2ce/scripts/.gitkeep -------------------------------------------------------------------------------- /scripts/get_next_logprobs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | from collections.abc import Iterable 4 | 5 | import numpy as np 6 | import torch 7 | from datasets import Dataset 8 | from tqdm.auto import trange 9 | from transformers import AutoModelForCausalLM 10 | 11 | from delphi import utils 12 | 13 | torch.set_grad_enabled(False) 14 | 15 | 16 | def main( 17 | in_model_repo_id: str, 18 | branches: Iterable[str], 19 | in_dataset_repo_id: str, 20 | split: str, 21 | feature: str, 22 | batch_size: int, 23 | out_repo_id: str, 24 | ): 25 | """ 26 | Outputs the log probabilities of the next token for each token in the dataset. 27 | And uploads the resulting dataset to huggingface. 
28 | """ 29 | in_dataset_split = utils.load_dataset_split_sequence_int32_feature( 30 | in_dataset_repo_id, split, feature 31 | ) 32 | in_dataset_split.set_format("torch") 33 | for branch in branches: 34 | print(f"Loading model='{in_model_repo_id}', {branch=}") 35 | model = AutoModelForCausalLM.from_pretrained(in_model_repo_id, revision=branch) 36 | logprobs_dataset = get_logprobs_single_model( 37 | model=model, 38 | dataset=in_dataset_split, 39 | feature=feature, 40 | batch_size=batch_size, 41 | ) 42 | logprobs_dataset.push_to_hub( 43 | repo_id=out_repo_id, 44 | split=utils.hf_split_to_split_name(split), 45 | revision=branch, 46 | ) 47 | 48 | 49 | def get_logprobs_single_model( 50 | model: AutoModelForCausalLM, 51 | dataset: Dataset, 52 | feature: str, 53 | batch_size: int, 54 | ) -> Dataset: 55 | n_seq = len(dataset) 56 | seq_len = len(dataset[0][feature]) 57 | logprobs = np.empty((n_seq, seq_len)) 58 | logprobs[:, 0] = float("nan") 59 | print("Running inference...") 60 | for i in trange(0, n_seq, batch_size): 61 | batch_tokens = dataset[i : i + batch_size][feature] 62 | logprobs[i : i + batch_size, 1:] = ( 63 | utils.get_all_and_next_logprobs(model, batch_tokens)[1].cpu().numpy() # type: ignore 64 | ) 65 | return Dataset.from_dict({"logprobs": [row for row in logprobs]}) 66 | 67 | 68 | if __name__ == "__main__": 69 | parser = argparse.ArgumentParser( 70 | description="Run inference and generate log probabilities." 71 | ) 72 | parser.add_argument( 73 | "--in-model-repo-id", 74 | "--im", 75 | type=str, 76 | required=True, 77 | help="The model", 78 | ) 79 | parser.add_argument( 80 | "--branches", 81 | help="comma separated branches of the model to use or 'ALL' to use all branches", 82 | type=str, 83 | default="main", 84 | required=False, 85 | ) 86 | 87 | parser.add_argument( 88 | "--in-dataset-repo-id", 89 | "--id", 90 | type=str, 91 | required=True, 92 | help="The tokenized dataset", 93 | ) 94 | parser.add_argument( 95 | "--feature", 96 | "-f", 97 | type=str, 98 | required=True, 99 | help="Name of the column containing token sequences in the input dataset", 100 | ) 101 | parser.add_argument( 102 | "--split", 103 | "-s", 104 | type=str, 105 | required=True, 106 | help="Split of the tokenized dataset, supports slicing like 'train[:10%%]'", 107 | ) 108 | parser.add_argument( 109 | "--out-repo-id", 110 | "-o", 111 | type=str, 112 | required=True, 113 | help="Where to upload the next logprobs", 114 | ) 115 | parser.add_argument( 116 | "--batch-size", 117 | "-b", 118 | type=int, 119 | default=80, 120 | help="How many sequences to evaluate at once", 121 | ) 122 | # TODO 123 | # parser.add_argument( 124 | # "--chunk-size", 125 | # "-c", 126 | # type=int, 127 | # default=200_000, 128 | # help="Size of the parquet chunks uploaded to HuggingFace", 129 | # ) 130 | args = parser.parse_args() 131 | 132 | branches = ( 133 | args.branches.split(",") 134 | if args.branches != "ALL" 135 | else utils.get_all_hf_branch_names(args.in_model_repo_id) 136 | ) 137 | 138 | main( 139 | in_model_repo_id=args.in_model_repo_id, 140 | branches=branches, 141 | in_dataset_repo_id=args.in_dataset_repo_id, 142 | split=args.split, 143 | feature=args.feature, 144 | batch_size=args.batch_size, 145 | out_repo_id=args.out_repo_id, 146 | ) 147 | -------------------------------------------------------------------------------- /scripts/tokenize_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import io 4 | import os 5 | from pathlib 
import Path 6 | 7 | from datasets import Dataset 8 | from huggingface_hub import HfApi 9 | from transformers import AutoTokenizer 10 | 11 | from delphi import utils 12 | from delphi.tokenization import get_tokenized_chunks 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser( 16 | description="Tokenize a text dataset using a specified tokenizer", 17 | allow_abbrev=False, 18 | ) 19 | 20 | parser.add_argument( 21 | "--in-dataset", 22 | "-i", 23 | type=str, 24 | required=True, 25 | help="Dataset you want to tokenize. Local path or HF repo id", 26 | ) 27 | parser.add_argument( 28 | "--feature", 29 | "-f", 30 | type=str, 31 | required=True, 32 | help="Name of the feature (column) containing text documents in the input dataset", 33 | ) 34 | parser.add_argument( 35 | "--split", 36 | "-s", 37 | type=str, 38 | required=True, 39 | help="Split of the dataset to be tokenized, supports slicing like 'train[:10%%]'", 40 | ) 41 | parser.add_argument( 42 | "--tokenizer", 43 | "-t", 44 | type=str, 45 | required=True, 46 | help="HF repo id or local directory containing the tokenizer", 47 | ) 48 | parser.add_argument( 49 | "--seq-len", 50 | "-l", 51 | type=int, 52 | required=True, 53 | help="Length of the tokenized sequences", 54 | ) 55 | parser.add_argument( 56 | "--batch-size", 57 | "-b", 58 | type=int, 59 | default=50, 60 | help="How many text documents to tokenize at once (default: 50)", 61 | ) 62 | parser.add_argument( 63 | "--chunk-size", 64 | "-c", 65 | type=int, 66 | default=200_000, 67 | help="Maximum number of tokenized sequences in a single parquet file (default: 200_000)", 68 | ) 69 | parser.add_argument( 70 | "--out-dir", 71 | type=str, 72 | required=False, 73 | help="Local directory to save the resulting dataset", 74 | ) 75 | parser.add_argument( 76 | "--out-repo", 77 | type=str, 78 | required=False, 79 | help="HF repo id to upload the resulting dataset", 80 | ) 81 | args = parser.parse_args() 82 | assert args.out_repo or args.out_dir, "You need to provide --out-repo or --out-dir" 83 | 84 | in_dataset_split = utils.load_dataset_split_string_feature( 85 | args.in_dataset, args.split, args.feature 86 | ) 87 | assert isinstance(in_dataset_split, Dataset) 88 | print(f"Loading tokenizer from '{args.tokenizer}'...") 89 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) 90 | assert tokenizer.bos_token_id is not None, "Tokenizer must have a bos_token_id" 91 | assert tokenizer.eos_token_id is not None, "Tokenizer must have a eos_token_id" 92 | 93 | api = None 94 | if args.out_repo: 95 | api = HfApi() 96 | api.create_repo(repo_id=args.out_repo, repo_type="dataset", exist_ok=True) 97 | if args.out_dir: 98 | os.makedirs(args.out_dir, exist_ok=True) 99 | 100 | ds_chunks_it = get_tokenized_chunks( 101 | dataset_split=in_dataset_split, 102 | tokenizer=tokenizer, 103 | seq_len=args.seq_len, 104 | batch_size=args.batch_size, 105 | chunk_size=args.chunk_size, 106 | ) 107 | 108 | print(f"Tokenizing split='{args.split}'...") 109 | split_name = utils.hf_split_to_split_name(args.split) 110 | for chunk_idx, ds_chunk in enumerate(ds_chunks_it): 111 | chunk_name = f"{split_name}-{chunk_idx:05}.parquet" 112 | if args.out_dir: 113 | ds_parquet_chunk = Path(args.out_dir) / chunk_name 114 | print(f"Saving '{ds_parquet_chunk}'...") 115 | else: 116 | ds_parquet_chunk = io.BytesIO() 117 | ds_chunk.to_parquet(ds_parquet_chunk) 118 | if api: 119 | print(f"Uploading '{chunk_name}' to '{args.out_repo}'...") 120 | api.upload_file( 121 | path_or_fileobj=ds_parquet_chunk, 122 | 
path_in_repo=f"data/{chunk_name}", 123 | repo_id=args.out_repo, 124 | repo_type="dataset", 125 | ) 126 | print(f"Done saving/uploading '{chunk_name}'") 127 | -------------------------------------------------------------------------------- /scripts/train_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import logging 4 | import sys 5 | from pathlib import Path 6 | 7 | from delphi.train.config import build_config_from_files_and_overrides 8 | from delphi.train.training import run_training 9 | from delphi.train.utils import overrides_to_dict, save_results 10 | 11 | 12 | def add_logging_args(parser: argparse.ArgumentParser): 13 | logging_group = parser.add_mutually_exclusive_group() 14 | logging_group.add_argument( 15 | "-v", 16 | "--verbose", 17 | action="count", 18 | default=None, 19 | help="Increase verbosity level, repeatable (e.g. -vvv). Mutually exclusive with --silent, --loglevel", 20 | ) 21 | logging_group.add_argument( 22 | "-s", 23 | "--silent", 24 | action="store_true", 25 | help="Silence all logging. Mutually exclusive with --verbose, --loglevel", 26 | default=False, 27 | ) 28 | 29 | 30 | def set_logging(args: argparse.Namespace): 31 | logging.basicConfig(format="%(message)s") 32 | logging.getLogger().setLevel(logging.INFO) 33 | if args.verbose is not None: 34 | if args.verbose == 1: 35 | loglevel = logging.DEBUG 36 | elif args.verbose >= 2: 37 | loglevel = 0 38 | logging.getLogger().setLevel(loglevel) 39 | if args.silent: 40 | logging.getLogger().setLevel(logging.CRITICAL) 41 | else: 42 | logging_level_str = logging.getLevelName( 43 | logging.getLogger().getEffectiveLevel() 44 | ) 45 | print(f"set logging level to {logging_level_str}") 46 | 47 | 48 | def setup_parser() -> argparse.ArgumentParser: 49 | # Setup argparse 50 | parser = argparse.ArgumentParser( 51 | description="Train a delphi model", allow_abbrev=False 52 | ) 53 | parser.add_argument( 54 | "config_files", 55 | help=( 56 | "Path to json file(s) containing config values, e.g. 'primary_config.json secondary_config.json'." 57 | ), 58 | type=str, 59 | nargs="*", 60 | ) 61 | parser.add_argument( 62 | "--overrides", 63 | help=( 64 | "Override config values with space-separated declarations. " 65 | "e.g. 
`--overrides model_config.hidden_size=42 run_name=foo`" 66 | ), 67 | type=str, 68 | required=False, 69 | nargs="*", 70 | default=[], 71 | ) 72 | add_logging_args(parser) 73 | return parser 74 | 75 | 76 | def main(): 77 | parser = setup_parser() 78 | args = parser.parse_args() 79 | if len(sys.argv) == 1: 80 | parser.print_help() 81 | exit(0) 82 | set_logging(args) 83 | 84 | args_dict = overrides_to_dict(args.overrides) 85 | config_files = [Path(f) for f in args.config_files] 86 | config = build_config_from_files_and_overrides(config_files, args_dict) 87 | # run training 88 | results, run_context = run_training(config) 89 | # to save & upload to iterX folder/branch 90 | save_results(config, results, run_context, final=False) 91 | # to save & upload to main folder/branch 92 | save_results(config, results, run_context, final=True) 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /scripts/train_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | 4 | from datasets import Dataset, Features, Value 5 | from tokenizers import ByteLevelBPETokenizer # type: ignore 6 | from transformers import PreTrainedTokenizerFast 7 | 8 | from delphi import utils 9 | 10 | 11 | def train_byte_level_bpe( 12 | dataset: Dataset, feature: str, vocab_size: int 13 | ) -> PreTrainedTokenizerFast: 14 | tokenizer = ByteLevelBPETokenizer() 15 | text_generator = (example[feature] for example in dataset) # type: ignore 16 | tokenizer.train_from_iterator( 17 | text_generator, 18 | vocab_size=vocab_size, 19 | special_tokens=["", "", ""], 20 | show_progress=True, 21 | length=len(dataset), 22 | ) 23 | return PreTrainedTokenizerFast( 24 | tokenizer_object=tokenizer, 25 | bos_token="", 26 | eos_token="", 27 | pad_token="", 28 | ) 29 | 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser( 33 | description="Train a custom, reversible, BPE tokenizer (GPT2-like). You need to provide --out-repo or --out-dir.", 34 | allow_abbrev=False, 35 | ) 36 | 37 | parser.add_argument( 38 | "--in-dataset", 39 | "-i", 40 | type=str, 41 | required=True, 42 | help="Dataset you want to train the tokenizer on. 
Local path or HF repo id", 43 | ) 44 | parser.add_argument( 45 | "--feature", 46 | "-f", 47 | type=str, 48 | required=True, 49 | help="Name of the feature (column) containing text documents in the input dataset", 50 | ) 51 | parser.add_argument( 52 | "--split", 53 | "-s", 54 | type=str, 55 | required=True, 56 | help="Split of the dataset to be used for tokenizer training, supports slicing like 'train[:10%%]'", 57 | ) 58 | parser.add_argument( 59 | "--vocab-size", 60 | "-v", 61 | type=int, 62 | required=True, 63 | help="Vocabulary size of the tokenizer", 64 | ) 65 | parser.add_argument( 66 | "--out-dir", 67 | type=str, 68 | required=False, 69 | help="Local directory to save the resulting tokenizer", 70 | ) 71 | parser.add_argument( 72 | "--out-repo", 73 | type=str, 74 | required=False, 75 | help="HF repo id to upload the resulting tokenizer", 76 | ) 77 | args = parser.parse_args() 78 | assert args.out_repo or args.out_dir, "You need to provide --out-repo or --out-dir" 79 | 80 | in_dataset_split = utils.load_dataset_split_string_feature( 81 | args.in_dataset, args.split, args.feature 82 | ) 83 | assert isinstance(in_dataset_split, Dataset) 84 | tokenizer = train_byte_level_bpe( 85 | dataset=in_dataset_split, 86 | feature=args.feature, 87 | vocab_size=args.vocab_size, 88 | ) 89 | if args.out_dir: 90 | print(f"Saving tokenizer to '{args.out_dir}' directory...") 91 | tokenizer.save_pretrained(args.out_dir) 92 | print("Done.") 93 | if args.out_repo: 94 | print(f"Pushing tokenizer to HF repo '{args.out_repo}'...") 95 | tokenizer.push_to_hub( 96 | repo_id=args.out_repo, 97 | ) 98 | print("Done.") 99 | -------------------------------------------------------------------------------- /scripts/validate_configs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import pathlib 4 | 5 | from delphi.train.config import build_config_from_files_and_overrides 6 | from delphi.train.utils import init_model, overrides_to_dict 7 | 8 | 9 | def get_config_path_with_base(config_path: pathlib.Path) -> list[pathlib.Path]: 10 | """If config path is in directory which includes base.json, include that as the first config.""" 11 | if (config_path.parent / "base.json").exists(): 12 | return [config_path.parent / "base.json", config_path] 13 | return [config_path] 14 | 15 | 16 | def get_config_paths(config_path: str) -> list[list[pathlib.Path]]: 17 | """If config path is a directory, recursively glob all json files in it. Otherwise, just use the path and create a list of 1.""" 18 | paths = ( 19 | list(pathlib.Path(config_path).rglob("*.json")) 20 | if pathlib.Path(config_path).is_dir() 21 | else [pathlib.Path(config_path)] 22 | ) 23 | # exclude base.json files 24 | paths = [path for path in paths if not path.name.startswith("base")] 25 | # supplement non-base configs with base.json if it exists in same dir 26 | return [get_config_path_with_base(path) for path in paths] 27 | 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser() 31 | # we take one positional argument, a path to a directory or config 32 | parser.add_argument( 33 | "config_path", 34 | type=str, 35 | help="path to a training config json or directory of training config jsons", 36 | ) 37 | parser.add_argument( 38 | "--overrides", 39 | help=( 40 | "Override config values with space-separated declarations. " 41 | "e.g. 
`--overrides model_config.hidden_size=42 run_name=foo`" 42 | ), 43 | type=str, 44 | required=False, 45 | nargs="*", 46 | default=[], 47 | ) 48 | parser.add_argument("--init", help="initialize the model", action="store_true") 49 | args = parser.parse_args() 50 | config_paths = get_config_paths(args.config_path) 51 | print( 52 | f"validating configs: {' | '.join(str(config_path[-1]) for config_path in config_paths)}" 53 | ) 54 | overrides = overrides_to_dict(args.overrides) 55 | errors = [] 56 | sizes = [] 57 | for config_path in config_paths: 58 | try: 59 | config = build_config_from_files_and_overrides(config_path, overrides) 60 | if args.init: 61 | model = init_model(config.model_config, seed=config.torch_seed) 62 | sizes.append((config_path, model.num_parameters())) 63 | except Exception as e: 64 | errors.append((config_path, e)) 65 | continue 66 | if errors: 67 | print("errors:") 68 | for config_path, e in errors: 69 | print(f" {config_path[-1]}: {e}") 70 | else: 71 | print("all configs loaded successfully") 72 | if sizes: 73 | print("model sizes:") 74 | for config_path, size in sizes: 75 | print(f" {config_path[-1]}: {size}") 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="delphi", 5 | packages=find_packages(where="."), 6 | package_dir={"": "."}, 7 | package_data={ 8 | "delphi": ["test_configs/**/*"], 9 | }, 10 | include_package_data=True, 11 | ) 12 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delphi-suite/delphi/1efcabec4bceffbf4e9a383d05958f8c0704d2ce/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_eval.py: -------------------------------------------------------------------------------- 1 | from math import isclose 2 | from typing import cast 3 | 4 | import pytest 5 | import torch 6 | from datasets import Dataset 7 | 8 | from delphi.eval import dict_filter_quantile, get_all_tok_metrics_in_label 9 | 10 | 11 | @pytest.mark.filterwarnings( 12 | "ignore::RuntimeWarning" 13 | ) # ignore warnings from numpy empty slice 14 | def test_dict_filter_quantile(): 15 | d = {1: 0.1, 2: 0.2, 3: 0.3, 4: 0.4, 5: 0.5} 16 | result = dict_filter_quantile(d, 0.2, 0.6) 17 | expected = {2: 0.2, 3: 0.3} 18 | 19 | # compare keys 20 | assert result.keys() == expected.keys() 21 | # compare values 22 | for k in result: 23 | assert isclose(result[k], expected[k], rel_tol=1e-6) 24 | 25 | # test with negative values 26 | d = {1: -0.1, 2: -0.2, 3: -0.3, 4: -0.4, 5: -0.5} 27 | result = dict_filter_quantile(d, 0.2, 0.6) 28 | expected = {3: -0.3, 4: -0.4} 29 | 30 | # compare keys 31 | assert result.keys() == expected.keys() 32 | # compare values 33 | for k in result: 34 | assert isclose(result[k], expected[k], rel_tol=1e-6) 35 | 36 | # test invalid quantile range 37 | with pytest.raises(ValueError): 38 | dict_filter_quantile(d, 0.6, 0.2) 39 | with pytest.raises(ValueError): 40 | dict_filter_quantile(d, 0.1, 1.1) 41 | with pytest.raises(ValueError): 42 | dict_filter_quantile(d, -0.1, 0.6) 43 | 44 | # test empty dict, will raise a warning 45 | result = dict_filter_quantile({}, 0.2, 0.6) 46 | assert result == {} 47 | 48 | 49 | def 
test_get_all_tok_metrics_in_label(): 50 | token_ids = Dataset.from_dict( 51 | {"tokens": [[1, 2, 3], [4, 5, 6], [7, 8, 9]]} 52 | ).with_format("torch") 53 | selected_tokens = [2, 4, 6, 8] 54 | metrics = torch.tensor([[-1, 0.45, -0.33], [-1.31, 2.3, 0.6], [0.2, 0.8, 0.1]]) 55 | result = get_all_tok_metrics_in_label( 56 | token_ids["tokens"], # type: ignore 57 | selected_tokens, 58 | metrics, 59 | ) 60 | # key: (prompt_pos, tok_pos), value: logprob 61 | expected = { 62 | (0, 1): 0.45, 63 | (1, 0): -1.31, 64 | (1, 2): 0.6, 65 | (2, 1): 0.8, 66 | } 67 | 68 | # compare keys 69 | assert result.keys() == expected.keys() 70 | # compare values 71 | for k in result: 72 | assert isclose(cast(float, result[k]), expected[k], rel_tol=1e-6) # type: ignore 73 | 74 | # test with quantile filtering 75 | result_q = get_all_tok_metrics_in_label( 76 | token_ids["tokens"], # type: ignore 77 | selected_tokens, 78 | metrics, 79 | q_start=0.6, 80 | q_end=1.0, 81 | ) 82 | expected_q = { 83 | (1, 2): 0.6, 84 | (2, 1): 0.8, 85 | } 86 | 87 | # compare keys 88 | assert result_q.keys() == expected_q.keys() 89 | # compare values 90 | for k in result_q: 91 | assert isclose(cast(float, result_q[k]), expected_q[k], rel_tol=1e-6) # type: ignore 92 | -------------------------------------------------------------------------------- /tests/test_tokeniation.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import random 3 | 4 | import pytest 5 | from datasets import Dataset 6 | from transformers import AutoTokenizer 7 | 8 | from delphi.tokenization import extend_deque, make_new_sample, tokenize_dataset 9 | 10 | 11 | @pytest.fixture 12 | def tokenizer(): 13 | return AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer") 14 | 15 | 16 | def make_random_document(tokenizer): 17 | all_token_ids = range(2, tokenizer.vocab_size) 18 | n_tokens = random.randint(100, 800) 19 | random_tokens = random.choices(all_token_ids, k=n_tokens) 20 | return tokenizer.decode(random_tokens) 21 | 22 | 23 | def get_random_feature_name(): 24 | return "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=10)) 25 | 26 | 27 | def test_extend_deque(tokenizer): 28 | CTX_SIZE = 10 29 | BATCH_SIZE = 2 30 | # generate 100 random stories 31 | documents = [make_random_document(tokenizer) for _ in range(100)] 32 | feature_name = get_random_feature_name() 33 | dataset = Dataset.from_dict({feature_name: documents}) 34 | 35 | prompt_idx = 0 36 | deq = collections.deque() 37 | 38 | while prompt_idx < len(dataset): 39 | prompt_idx = extend_deque( 40 | deq, CTX_SIZE, dataset, prompt_idx, tokenizer, BATCH_SIZE 41 | ) 42 | if prompt_idx < len(dataset) - 1: 43 | # assert that the deque has grown large enough in each round 44 | assert len(deq) >= CTX_SIZE 45 | while len(deq) >= CTX_SIZE: 46 | for _ in range(CTX_SIZE - 1): 47 | deq.popleft() 48 | 49 | 50 | def test_make_new_sample(tokenizer): 51 | for _ in range(100): 52 | total_tokens = random.randint(100, 1000) 53 | context_size = random.randint(5, total_tokens // 2) 54 | dq = collections.deque(random.choices(range(3, 1000), k=total_tokens)) 55 | samples = [] 56 | while len(dq) >= context_size: 57 | samples.append(make_new_sample(dq, context_size, tokenizer.bos_token_id)) 58 | tokens_cnt = 0 59 | for i, sample in enumerate(samples): 60 | assert sample[0] == tokenizer.bos_token_id 61 | if i > 0: 62 | # assert that there is an overlap of the last token in the previous sample 63 | # and the first token in its following sample 64 | assert sample[1] == 
samples[i - 1][-1] 65 | tokens_cnt += len(sample) 66 | 67 | # We discard the last chunk so the following lines are only for testing 68 | tokens_cnt += 1 + len(dq) # the last batch with BOS in the beginning 69 | assert tokens_cnt == total_tokens + ( 70 | 2 * len(samples) + 1 71 | ) # BOS for each batch + overlapping of the last tokens in the batches 72 | assert len(dq) > 0 # always leaving at least one element in the deque 73 | 74 | 75 | def test_tokenize_dataset(tokenizer): 76 | SEQ_LEN = 11 77 | BATCH_SIZE = 2 78 | 79 | documents = [ 80 | "Once upon a", 81 | "Mother woke up alert. She put on her coat", 82 | "Once upon a time, in a small town, there was a weird", 83 | "Once upon a time, there was a", 84 | "Sara and Tom are friends. They like to play in the park.", 85 | ] 86 | feature_name = get_random_feature_name() 87 | dataset = Dataset.from_dict({feature_name: documents}) 88 | expected = [ 89 | [0, 432, 441, 261, 1, 47, 500, 1946, 369, 3444, 16], 90 | [0, 16, 341, 577, 356, 338, 1888, 1, 432, 441, 261], 91 | [0, 261, 400, 14, 315, 261, 561, 1006, 14, 403, 285], 92 | [0, 285, 261, 2607, 1, 432, 441, 261, 400, 14, 403], 93 | [0, 403, 285, 261, 1, 1371, 269, 416, 485, 413, 16], 94 | ] 95 | actual = [x for x in tokenize_dataset(dataset, tokenizer, SEQ_LEN, BATCH_SIZE)] 96 | assert actual == expected 97 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | 4 | import torch 5 | 6 | from delphi.utils import gather_logprobs, hf_split_to_split_name 7 | 8 | 9 | def random_string(length: int) -> str: 10 | return "".join(random.choices(string.ascii_lowercase, k=length)) 11 | 12 | 13 | def test_hf_split_to_split_name(): 14 | random_split_name = random_string(5) 15 | assert hf_split_to_split_name(random_split_name) == random_split_name 16 | assert hf_split_to_split_name(f"{random_split_name}[:10%]") == random_split_name 17 | assert hf_split_to_split_name(f"{random_split_name}[10%:]") == random_split_name 18 | assert hf_split_to_split_name(f"{random_split_name}[10%:20%]") == random_split_name 19 | assert hf_split_to_split_name(f"{random_split_name}[:200]") == random_split_name 20 | assert hf_split_to_split_name(f"{random_split_name}[200:]") == random_split_name 21 | assert hf_split_to_split_name(f"{random_split_name}[200:400]") == random_split_name 22 | 23 | 24 | def test_gather_logprobs(): 25 | # vocab size = 3 26 | logprobs = torch.tensor( 27 | [ 28 | # batch 0 29 | [ 30 | # seq 0 31 | [0.00, 0.01, 0.02], 32 | # seq 1 33 | [0.10, 0.11, 0.12], 34 | ], 35 | # batch 1 36 | [ 37 | # seq 0 38 | [1.00, 1.01, 1.02], 39 | # seq 1 40 | [1.10, 1.11, 1.12], 41 | ], 42 | ] 43 | ) 44 | tokens = torch.tensor( 45 | [ 46 | # batch 0 47 | [0, 2], 48 | # batch 1 49 | [1, 2], 50 | ] 51 | ) 52 | expected_output = torch.tensor( 53 | [ 54 | # batch 0 55 | [0.00, 0.12], 56 | # batch 1 57 | [1.01, 1.12], 58 | ] 59 | ) 60 | result = gather_logprobs(logprobs, tokens) 61 | assert torch.allclose(result, expected_output) 62 | -------------------------------------------------------------------------------- /tests/train/config/test_config_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from delphi import TEST_CONFIGS_DIR 4 | from delphi.train.config.utils import ( 5 | _unoptionalize, 6 | build_config_from_files_and_overrides, 7 | dot_notation_to_dict, 8 | merge_dicts, 9 | 
merge_two_dicts, 10 | ) 11 | 12 | 13 | def test_configs(): 14 | test_configs = list(TEST_CONFIGS_DIR.glob("*.json")) 15 | for config in test_configs: 16 | build_config_from_files_and_overrides([config], {}) 17 | 18 | 19 | def test_merge_two_dicts(): 20 | dict1 = {"a": 1, "b": 2, "c": {"d": 3, "e": 4}} 21 | dict2 = {"a": 5, "c": {"d": 6}} 22 | merge_two_dicts(dict1, dict2) 23 | assert dict1 == {"a": 5, "b": 2, "c": {"d": 6, "e": 4}} 24 | 25 | 26 | def test_merge_dicts(): 27 | dict1 = {"a": 1, "b": 2, "c": {"d": 3, "e": 4}} 28 | dict2 = {"a": 5, "c": {"d": 6}} 29 | dict3 = {"a": 7, "b": 8, "c": {"d": 9, "e": 10}} 30 | merged = merge_dicts(dict1, dict2, dict3) 31 | assert merged == {"a": 7, "b": 8, "c": {"d": 9, "e": 10}} 32 | 33 | 34 | def test_dot_notation_to_dict(): 35 | vars = {"a.b.c": 4, "foo": False} 36 | result = dot_notation_to_dict(vars) 37 | assert result == {"a": {"b": {"c": 4}}, "foo": False} 38 | 39 | 40 | def test_build_config_from_files_and_overrides(): 41 | config_files = [TEST_CONFIGS_DIR / "debug.json"] 42 | overrides = {"model_config": {"hidden_size": 128}, "eval_iters": 5} 43 | config = build_config_from_files_and_overrides(config_files, overrides) 44 | # check overrides 45 | assert config.model_config["hidden_size"] == 128 46 | assert config.eval_iters == 5 47 | # check base values 48 | assert config.max_epochs == 2 49 | assert config.dataset.path == "delphi-suite/stories-tokenized" 50 | 51 | 52 | def test_unoptionalize(): 53 | assert _unoptionalize(int) == int 54 | assert _unoptionalize(Optional[str]) == str 55 | -------------------------------------------------------------------------------- /tests/train/test_shuffle.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import pytest 4 | 5 | from delphi.train.shuffle import RNG, shuffle_epoch, shuffle_list 6 | 7 | 8 | def test_rng(): 9 | """ 10 | Compare to the following C++ code: 11 | 12 | #include 13 | #include 14 | 15 | int main() { 16 | unsigned int seed = 12345; 17 | std::minstd_rand generator(seed); 18 | 19 | for (int i = 0; i < 5; i++) 20 | std::cout << generator() << ", "; 21 | } 22 | """ 23 | rng = RNG(12345) 24 | expected = [595905495, 1558181227, 1498755989, 2021244883, 887213142] 25 | for val in expected: 26 | assert rng() == val 27 | 28 | 29 | @pytest.mark.parametrize( 30 | "input_list, seed", 31 | [(random.sample(range(100), 10), random.randint(1, 1000)) for _ in range(5)], 32 | ) 33 | def test_shuffle_list(input_list, seed): 34 | original_list = input_list.copy() 35 | shuffle_list(input_list, seed) 36 | assert sorted(input_list) == sorted(original_list) 37 | 38 | 39 | @pytest.mark.parametrize( 40 | "seed, epoch_nr, expected", 41 | [ 42 | (1, 1, [2, 5, 1, 3, 4]), 43 | (2, 5, [2, 1, 4, 5, 3]), 44 | (3, 10, [1, 4, 3, 5, 2]), 45 | (4, 100, [3, 4, 5, 1, 2]), 46 | ], 47 | ) 48 | def test_shuffle_epoch(seed, epoch_nr, expected): 49 | samples = [1, 2, 3, 4, 5] 50 | shuffle_epoch(samples, seed, epoch_nr) 51 | assert samples == expected 52 | -------------------------------------------------------------------------------- /tests/train/test_train_step.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict 2 | 3 | import dacite 4 | import pytest 5 | import torch 6 | from datasets import Dataset 7 | from jaxtyping import Float 8 | from transformers import PreTrainedModel 9 | 10 | from delphi import TEST_CONFIGS_DIR 11 | from delphi.train.config import TrainingConfig 12 | from 
delphi.train.config.utils import build_config_from_files_and_overrides 13 | from delphi.train.train_step import accumulate_gradients, train_step 14 | from delphi.train.utils import ( 15 | ModelTrainingState, 16 | gen_minibatches, 17 | init_model, 18 | setup_determinism, 19 | ) 20 | from delphi.utils import get_all_and_next_logprobs 21 | 22 | 23 | def load_test_config(preset_name: str) -> TrainingConfig: 24 | """Load a test config by name, e.g. `load_preset("debug")`.""" 25 | preset_path = TEST_CONFIGS_DIR / f"{preset_name}.json" 26 | return build_config_from_files_and_overrides([preset_path], {}) 27 | 28 | 29 | @pytest.fixture 30 | def dataset(): 31 | ds = Dataset.from_dict( 32 | { 33 | "tokens": [list(range(i, i + 512)) for i in range(64)], 34 | }, 35 | ) 36 | ds.set_format(type="torch") 37 | return ds 38 | 39 | 40 | @pytest.fixture 41 | def model(): 42 | setup_determinism(42) 43 | return init_model( 44 | { 45 | "model_class": "LlamaForCausalLM", 46 | "hidden_size": 48, 47 | "intermediate_size": 48, 48 | "num_attention_heads": 2, 49 | "num_hidden_layers": 2, 50 | "num_key_value_heads": 2, 51 | "vocab_size": 4096, 52 | }, 53 | seed=42, 54 | ) 55 | 56 | 57 | def get_params(model: torch.nn.Module) -> Float[torch.Tensor, "params"]: 58 | params = [ 59 | (name, param) for name, param in model.named_parameters() if param.requires_grad 60 | ] 61 | params.sort(key=lambda x: x[0]) 62 | return torch.cat([p.flatten() for _, p in params]) 63 | 64 | 65 | def test_basic_reproducibility(dataset, model): 66 | """ 67 | check that the same batch produces the same gradient 68 | """ 69 | # setup 70 | optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 71 | model_training_state = ModelTrainingState( 72 | model=model, 73 | optimizer=optimizer, 74 | iter_num=0, 75 | epoch=0, 76 | step=0, 77 | train_loss=0.0, 78 | lr=0.01, 79 | last_training_step_time=0.0, 80 | ) 81 | device = torch.device("cpu") 82 | indices = list(range(len(dataset))) 83 | train_step( 84 | model_training_state, dataset, load_test_config("debug"), device, indices 85 | ) 86 | 87 | params = get_params(model) 88 | 89 | assert torch.isclose( 90 | params[[1000, 2000, 3000]], 91 | torch.tensor([-0.01780166, -0.00762226, 0.03532362]), 92 | ).all() 93 | 94 | 95 | def test_performance(dataset, model): 96 | """check that predictions improve with training""" 97 | # setup 98 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) 99 | model_training_state = ModelTrainingState( 100 | model=model, 101 | optimizer=optimizer, 102 | iter_num=0, 103 | epoch=0, 104 | step=0, 105 | train_loss=0.0, 106 | lr=1e-3, 107 | last_training_step_time=0.0, 108 | ) 109 | device = torch.device("cpu") 110 | indices = list(range(len(dataset))) 111 | 112 | next_logprobs_before = get_all_and_next_logprobs(model, dataset["tokens"])[1] 113 | 114 | train_step( 115 | model_training_state, dataset, load_test_config("debug"), device, indices 116 | ) 117 | 118 | next_logprobs_after = get_all_and_next_logprobs(model, dataset["tokens"])[1] 119 | # should generally increse with training 120 | frac_increased = (next_logprobs_after > next_logprobs_before).float().mean().item() 121 | assert frac_increased > 0.95 122 | 123 | 124 | def get_grads(model: PreTrainedModel) -> Float[torch.Tensor, "grads"]: 125 | grads = [ 126 | param.grad.flatten() for param in model.parameters() if param.grad is not None 127 | ] 128 | return torch.cat(grads) 129 | 130 | 131 | def test_accumulate_gradients_accumulates(dataset, model): 132 | """ 133 | check that gradient accumulation works as expected and 
doesn't reset on each microstep 134 | """ 135 | batch_size = 3 136 | num_batches = 3 137 | # first 2 mini-batches different, last mini-batch the same 138 | indices_set_a = [1, 2, 3] + [4, 5, 6] + [7, 8, 9] 139 | indices_set_b = [7, 8, 9] * 3 140 | 141 | kwargs = dict( 142 | dataset=dataset, 143 | batch_size=batch_size, 144 | num_minibatches=num_batches, 145 | step=0, 146 | device=torch.device("cpu"), 147 | feature_name="tokens", 148 | ) 149 | batches_a = gen_minibatches(indices=indices_set_a, **kwargs) # type: ignore 150 | batches_b = gen_minibatches(indices=indices_set_b, **kwargs) # type: ignore 151 | 152 | # accumulate 153 | _total_loss = accumulate_gradients(model, batches_a, num_batches) 154 | 155 | grads_a = get_grads(model) 156 | 157 | # reset grad on model 158 | model.zero_grad() 159 | 160 | _total_loss = accumulate_gradients(model, batches_b, num_batches) 161 | grads_b = get_grads(model) 162 | 163 | # test 164 | assert not torch.isclose(grads_a, grads_b).all() 165 | 166 | 167 | def test_accumulate_gradients_consistent(dataset, model): 168 | """ 169 | Validate that the gradients are consistent when the same batch is passed to accumulate_gradients 170 | """ 171 | # setup 172 | num_batches = 3 173 | batch_size = 3 174 | kwargs = dict( 175 | indices=list(range(1, 10)), 176 | dataset=dataset, 177 | batch_size=batch_size, 178 | num_minibatches=num_batches, 179 | step=0, 180 | device=torch.device("cpu"), 181 | feature_name="tokens", 182 | ) 183 | batches_a = gen_minibatches(**kwargs) # type: ignore 184 | batches_aa = gen_minibatches(**kwargs) # type: ignore 185 | 186 | # accumulate 187 | _total_loss = accumulate_gradients(model, batches_a, num_batches) 188 | 189 | grads_a = get_grads(model) 190 | 191 | # reset grad on model 192 | model.zero_grad() 193 | 194 | _total_loss = accumulate_gradients(model, batches_aa, num_batches) 195 | grads_aa = get_grads(model) 196 | 197 | # test 198 | assert torch.isclose(grads_a, grads_aa).all() 199 | 200 | 201 | def get_model_training_state(model, optimizer, step): 202 | return ModelTrainingState( 203 | model=model, 204 | optimizer=optimizer, 205 | iter_num=0, 206 | epoch=0, 207 | step=step, 208 | train_loss=0.0, 209 | lr=0.01, 210 | last_training_step_time=0.0, 211 | ) 212 | 213 | 214 | def test_train_step_no_training(dataset, model): 215 | """ 216 | Test train_step when no_training is set to True 217 | """ 218 | # setup 219 | config_dict = asdict(load_test_config("debug")) 220 | config_dict["debug_config"] = {"no_training": True} 221 | config = dacite.from_dict(TrainingConfig, config_dict) 222 | optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 223 | model_training_state = get_model_training_state( 224 | model=model, optimizer=optimizer, step=0 225 | ) 226 | device = torch.device("cpu") 227 | indices = [0, 1, 2, 3] 228 | 229 | # (don't) train 230 | train_step(model_training_state, dataset, config, device, indices) 231 | 232 | # test 233 | assert model_training_state.train_loss == 0.0 234 | 235 | 236 | def test_train_step_with_training(dataset, model): 237 | """ 238 | Test train_step when training is performed 239 | """ 240 | # setup 241 | config_dict = asdict(load_test_config("debug")) 242 | config_dict["debug_config"] = {"no_training": False} 243 | config_dict["batch_size"] = 16 244 | config_dict["optimizer"] = {"gradient_accumulation_steps": 4} 245 | config_dict["grad_clip"] = 1.0 246 | config = dacite.from_dict(TrainingConfig, config_dict) 247 | optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 248 | model_training_state = 
get_model_training_state( 249 | model=model, optimizer=optimizer, step=0 250 | ) 251 | device = torch.device("cpu") 252 | indices = list(range(len(dataset))) 253 | 254 | # train 255 | train_step(model_training_state, dataset, config, device, indices) 256 | 257 | # test 258 | assert model_training_state.train_loss > 0.0 259 | -------------------------------------------------------------------------------- /tests/train/test_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datasets import Dataset 3 | 4 | from delphi.train.utils import gen_minibatches 5 | 6 | 7 | def test_gen_minibatches(): 8 | DS_SIZE = 6 9 | SEQ_LEN = 5 10 | NUM_MINIBATCHES = 3 11 | MINIBATCH_SIZE = DS_SIZE // NUM_MINIBATCHES 12 | FEATURE_NAME = "tokens" 13 | dataset = Dataset.from_dict( 14 | { 15 | FEATURE_NAME: [list(range(i, i + SEQ_LEN)) for i in range(DS_SIZE)], 16 | }, 17 | ) 18 | dataset.set_format(type="torch") 19 | indices = list(range(DS_SIZE - 1, -1, -1)) 20 | minibatches = gen_minibatches( 21 | dataset=dataset, 22 | batch_size=DS_SIZE, 23 | num_minibatches=NUM_MINIBATCHES, 24 | step=0, 25 | indices=indices, 26 | device=torch.device("cpu"), 27 | feature_name=FEATURE_NAME, 28 | ) 29 | minibatches = list(minibatches) 30 | assert len(minibatches) == NUM_MINIBATCHES 31 | shuffled_ds = dataset[FEATURE_NAME][indices] # type: ignore 32 | for i, minibatch in enumerate(minibatches): 33 | assert minibatch.shape == (MINIBATCH_SIZE, SEQ_LEN) 34 | expected_mb = shuffled_ds[i * MINIBATCH_SIZE : (i + 1) * MINIBATCH_SIZE] 35 | assert torch.all(minibatch == expected_mb) 36 | -------------------------------------------------------------------------------- /tests/train/test_wandb_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import asdict 3 | from unittest.mock import MagicMock, patch 4 | 5 | import pytest 6 | import torch 7 | import transformers 8 | from dacite import from_dict 9 | 10 | from delphi import TEST_CONFIGS_DIR 11 | from delphi.train.config import TrainingConfig 12 | from delphi.train.config.utils import build_config_from_files_and_overrides 13 | from delphi.train.utils import ModelTrainingState, initialize_model_training_state 14 | from delphi.train.wandb_utils import init_wandb, log_to_wandb 15 | 16 | 17 | @pytest.fixture 18 | def mock_training_config() -> TrainingConfig: 19 | preset_path = TEST_CONFIGS_DIR / "debug.json" 20 | overrides = { 21 | "run_name": "test_run", 22 | "wandb": "test_entity/test_project", 23 | } 24 | return build_config_from_files_and_overrides([preset_path], overrides) 25 | 26 | 27 | @pytest.fixture 28 | def mock_model_training_state(mock_training_config): 29 | device = torch.device(mock_training_config.device) 30 | # this is gross and horrible, sorry, I'm rushing 31 | mts = initialize_model_training_state(config=mock_training_config, device=device) 32 | mts.step = 1 33 | mts.epoch = 1 34 | mts.iter_num = 1 35 | mts.lr = 0.001 36 | return mts 37 | 38 | 39 | @patch("wandb.init") 40 | def test_init_wandb(mock_wandb_init: MagicMock, mock_training_config): 41 | init_wandb(mock_training_config) 42 | mock_wandb_init.assert_called_once_with( 43 | entity="test_entity", 44 | project="test_project", 45 | name="test_run", 46 | config=asdict(mock_training_config), 47 | ) 48 | 49 | 50 | @patch("wandb.log") 51 | def test_log_to_wandb(mock_wandb_log: MagicMock): 52 | model = MagicMock(spec=transformers.LlamaForCausalLM) 53 | optimizer = 
MagicMock(spec=torch.optim.AdamW) 54 | log_to_wandb( 55 | mts=ModelTrainingState( 56 | model=model, 57 | optimizer=optimizer, 58 | step=5, 59 | epoch=1, 60 | iter_num=55, 61 | lr=0.007, 62 | last_training_step_time=0.0, 63 | ), 64 | losses={"train": 0.5, "val": 0.4}, 65 | tokens_so_far=4242, 66 | ) 67 | assert mock_wandb_log.call_count == 1 68 | mock_wandb_log.assert_called_with( 69 | { 70 | "epoch": 1, 71 | "epoch_iter": 5, 72 | "global_iter": 55, 73 | "tokens": 4242, 74 | "loss/train": 0.5, 75 | "loss/val": 0.4, 76 | "lr": 0.007, 77 | }, 78 | step=55, 79 | ) 80 | --------------------------------------------------------------------------------
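Note: the following is a minimal, self-contained sketch of the inference loop that scripts/get_next_logprobs.py implements, condensed here for reference. The model repo id, the split slice, and the small batch size are illustrative assumptions (only "delphi-suite/stories-tokenized" and the "tokens" feature appear in the files above); delphi.utils.get_all_and_next_logprobs is used with the same call shape as in the script.

# Sketch only: mirrors the core of get_logprobs_single_model() above.
# "delphi-suite/v0-llama2-100k" is a placeholder model id, not a confirmed artifact.
import numpy as np
import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM

from delphi.utils import get_all_and_next_logprobs

torch.set_grad_enabled(False)

model = AutoModelForCausalLM.from_pretrained("delphi-suite/v0-llama2-100k")  # assumed repo id
ds = load_dataset("delphi-suite/stories-tokenized", split="validation[:8]")  # assumed slice
ds.set_format("torch")

n_seq, seq_len = len(ds), len(ds[0]["tokens"])
logprobs = np.empty((n_seq, seq_len))
logprobs[:, 0] = float("nan")  # no next-token prediction exists for position 0

batch_size = 4
for i in range(0, n_seq, batch_size):
    batch_tokens = ds[i : i + batch_size]["tokens"]
    # element [1] of the returned tuple holds the logprob assigned to each actual next token
    logprobs[i : i + batch_size, 1:] = (
        get_all_and_next_logprobs(model, batch_tokens)[1].cpu().numpy()
    )

out = Dataset.from_dict({"logprobs": [row for row in logprobs]})

The resulting dataset has one row per input sequence, which is what the notebook above consumes as the "logprobs" column.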