├── .gitignore ├── ACKNOWLEDGMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.txt ├── README.md ├── assets ├── dd-animation.gif ├── sequence_length.png └── throughput.png ├── open_lm.patch └── scripts ├── dclm_download.py ├── get_dd_params.py ├── get_stats.py ├── make_dd_buckets.py ├── train_dd.sh └── wiki_download.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea* 2 | -------------------------------------------------------------------------------- /ACKNOWLEDGMENTS: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this Software may utilize the following copyrighted 3 | material, the use of which is hereby acknowledged. 4 | 5 | ------------------------------------------------ 6 | OpenLM (open_lm) 7 | 8 | MIT License 9 | 10 | Copyright (c) 2023 mlfoundations 11 | 12 | Permission is hereby granted, free of charge, to any person obtaining a copy 13 | of this software and associated documentation files (the "Software"), to deal 14 | in the Software without restriction, including without limitation the rights 15 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 16 | copies of the Software, and to permit persons to whom the Software is 17 | furnished to do so, subject to the following conditions: 18 | 19 | The above copyright notice and this permission notice shall be included in all 20 | copies or substantial portions of the Software. 21 | 22 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 24 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 25 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 26 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 27 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 28 | SOFTWARE. 29 | 30 | ------------------------------------------------ 31 | HuggingFace Datasets (datasets) 32 | 33 | 34 | 35 | Apache License 36 | Version 2.0, January 2004 37 | http://www.apache.org/licenses/ 38 | 39 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 40 | 41 | 1. Definitions. 42 | 43 | "License" shall mean the terms and conditions for use, reproduction, 44 | and distribution as defined by Sections 1 through 9 of this document. 45 | 46 | "Licensor" shall mean the copyright owner or entity authorized by 47 | the copyright owner that is granting the License. 48 | 49 | "Legal Entity" shall mean the union of the acting entity and all 50 | other entities that control, are controlled by, or are under common 51 | control with that entity. For the purposes of this definition, 52 | "control" means (i) the power, direct or indirect, to cause the 53 | direction or management of such entity, whether by contract or 54 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 55 | outstanding shares, or (iii) beneficial ownership of such entity. 56 | 57 | "You" (or "Your") shall mean an individual or Legal Entity 58 | exercising permissions granted by this License. 59 | 60 | "Source" form shall mean the preferred form for making modifications, 61 | including but not limited to software source code, documentation 62 | source, and configuration files. 
63 | 64 | "Object" form shall mean any form resulting from mechanical 65 | transformation or translation of a Source form, including but 66 | not limited to compiled object code, generated documentation, 67 | and conversions to other media types. 68 | 69 | "Work" shall mean the work of authorship, whether in Source or 70 | Object form, made available under the License, as indicated by a 71 | copyright notice that is included in or attached to the work 72 | (an example is provided in the Appendix below). 73 | 74 | "Derivative Works" shall mean any work, whether in Source or Object 75 | form, that is based on (or derived from) the Work and for which the 76 | editorial revisions, annotations, elaborations, or other modifications 77 | represent, as a whole, an original work of authorship. For the purposes 78 | of this License, Derivative Works shall not include works that remain 79 | separable from, or merely link (or bind by name) to the interfaces of, 80 | the Work and Derivative Works thereof. 81 | 82 | "Contribution" shall mean any work of authorship, including 83 | the original version of the Work and any modifications or additions 84 | to that Work or Derivative Works thereof, that is intentionally 85 | submitted to Licensor for inclusion in the Work by the copyright owner 86 | or by an individual or Legal Entity authorized to submit on behalf of 87 | the copyright owner. For the purposes of this definition, "submitted" 88 | means any form of electronic, verbal, or written communication sent 89 | to the Licensor or its representatives, including but not limited to 90 | communication on electronic mailing lists, source code control systems, 91 | and issue tracking systems that are managed by, or on behalf of, the 92 | Licensor for the purpose of discussing and improving the Work, but 93 | excluding communication that is conspicuously marked or otherwise 94 | designated in writing by the copyright owner as "Not a Contribution." 95 | 96 | "Contributor" shall mean Licensor and any individual or Legal Entity 97 | on behalf of whom a Contribution has been received by Licensor and 98 | subsequently incorporated within the Work. 99 | 100 | 2. Grant of Copyright License. Subject to the terms and conditions of 101 | this License, each Contributor hereby grants to You a perpetual, 102 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 103 | copyright license to reproduce, prepare Derivative Works of, 104 | publicly display, publicly perform, sublicense, and distribute the 105 | Work and such Derivative Works in Source or Object form. 106 | 107 | 3. Grant of Patent License. Subject to the terms and conditions of 108 | this License, each Contributor hereby grants to You a perpetual, 109 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 110 | (except as stated in this section) patent license to make, have made, 111 | use, offer to sell, sell, import, and otherwise transfer the Work, 112 | where such license applies only to those patent claims licensable 113 | by such Contributor that are necessarily infringed by their 114 | Contribution(s) alone or by combination of their Contribution(s) 115 | with the Work to which such Contribution(s) was submitted. 
If You 116 | institute patent litigation against any entity (including a 117 | cross-claim or counterclaim in a lawsuit) alleging that the Work 118 | or a Contribution incorporated within the Work constitutes direct 119 | or contributory patent infringement, then any patent licenses 120 | granted to You under this License for that Work shall terminate 121 | as of the date such litigation is filed. 122 | 123 | 4. Redistribution. You may reproduce and distribute copies of the 124 | Work or Derivative Works thereof in any medium, with or without 125 | modifications, and in Source or Object form, provided that You 126 | meet the following conditions: 127 | 128 | (a) You must give any other recipients of the Work or 129 | Derivative Works a copy of this License; and 130 | 131 | (b) You must cause any modified files to carry prominent notices 132 | stating that You changed the files; and 133 | 134 | (c) You must retain, in the Source form of any Derivative Works 135 | that You distribute, all copyright, patent, trademark, and 136 | attribution notices from the Source form of the Work, 137 | excluding those notices that do not pertain to any part of 138 | the Derivative Works; and 139 | 140 | (d) If the Work includes a "NOTICE" text file as part of its 141 | distribution, then any Derivative Works that You distribute must 142 | include a readable copy of the attribution notices contained 143 | within such NOTICE file, excluding those notices that do not 144 | pertain to any part of the Derivative Works, in at least one 145 | of the following places: within a NOTICE text file distributed 146 | as part of the Derivative Works; within the Source form or 147 | documentation, if provided along with the Derivative Works; or, 148 | within a display generated by the Derivative Works, if and 149 | wherever such third-party notices normally appear. The contents 150 | of the NOTICE file are for informational purposes only and 151 | do not modify the License. You may add Your own attribution 152 | notices within Derivative Works that You distribute, alongside 153 | or as an addendum to the NOTICE text from the Work, provided 154 | that such additional attribution notices cannot be construed 155 | as modifying the License. 156 | 157 | You may add Your own copyright statement to Your modifications and 158 | may provide additional or different license terms and conditions 159 | for use, reproduction, or distribution of Your modifications, or 160 | for any such Derivative Works as a whole, provided Your use, 161 | reproduction, and distribution of the Work otherwise complies with 162 | the conditions stated in this License. 163 | 164 | 5. Submission of Contributions. Unless You explicitly state otherwise, 165 | any Contribution intentionally submitted for inclusion in the Work 166 | by You to the Licensor shall be under the terms and conditions of 167 | this License, without any additional terms or conditions. 168 | Notwithstanding the above, nothing herein shall supersede or modify 169 | the terms of any separate license agreement you may have executed 170 | with Licensor regarding such Contributions. 171 | 172 | 6. Trademarks. This License does not grant permission to use the trade 173 | names, trademarks, service marks, or product names of the Licensor, 174 | except as required for reasonable and customary use in describing the 175 | origin of the Work and reproducing the content of the NOTICE file. 176 | 177 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 178 | agreed to in writing, Licensor provides the Work (and each 179 | Contributor provides its Contributions) on an "AS IS" BASIS, 180 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 181 | implied, including, without limitation, any warranties or conditions 182 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 183 | PARTICULAR PURPOSE. You are solely responsible for determining the 184 | appropriateness of using or redistributing the Work and assume any 185 | risks associated with Your exercise of permissions under this License. 186 | 187 | 8. Limitation of Liability. In no event and under no legal theory, 188 | whether in tort (including negligence), contract, or otherwise, 189 | unless required by applicable law (such as deliberate and grossly 190 | negligent acts) or agreed to in writing, shall any Contributor be 191 | liable to You for damages, including any direct, indirect, special, 192 | incidental, or consequential damages of any character arising as a 193 | result of this License or out of the use or inability to use the 194 | Work (including but not limited to damages for loss of goodwill, 195 | work stoppage, computer failure or malfunction, or any and all 196 | other commercial damages or losses), even if such Contributor 197 | has been advised of the possibility of such damages. 198 | 199 | 9. Accepting Warranty or Additional Liability. While redistributing 200 | the Work or Derivative Works thereof, You may choose to offer, 201 | and charge a fee for, acceptance of support, warranty, indemnity, 202 | or other liability obligations and/or rights consistent with this 203 | License. However, in accepting such obligations, You may act only 204 | on Your own behalf and on Your sole responsibility, not on behalf 205 | of any other Contributor, and only if You agree to indemnify, 206 | defend, and hold each Contributor harmless for any liability 207 | incurred by, or claims asserted against, such Contributor by reason 208 | of your accepting any such warranty or additional liability. 209 | 210 | END OF TERMS AND CONDITIONS 211 | 212 | APPENDIX: How to apply the Apache License to your work. 213 | 214 | To apply the Apache License to your work, attach the following 215 | boilerplate notice, with the fields enclosed by brackets "[]" 216 | replaced with your own identifying information. (Don't include 217 | the brackets!) The text should be enclosed in the appropriate 218 | comment syntax for the file format. We also recommend that a 219 | file or class name and description of purpose be included on the 220 | same "printed page" as the copyright notice for easier 221 | identification within third-party archives. 222 | 223 | Copyright [yyyy] [name of copyright owner] 224 | 225 | Licensed under the Apache License, Version 2.0 (the "License"); 226 | you may not use this file except in compliance with the License. 227 | You may obtain a copy of the License at 228 | 229 | http://www.apache.org/licenses/LICENSE-2.0 230 | 231 | Unless required by applicable law or agreed to in writing, software 232 | distributed under the License is distributed on an "AS IS" BASIS, 233 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 234 | See the License for the specific language governing permissions and 235 | limitations under the License. 236 | 237 | ------------------------------------------------ 238 | HuggingFace Transformers (transformers) 239 | 240 | Copyright 2018- The Hugging Face team. 
All rights reserved. 241 | 242 | Apache License 243 | Version 2.0, January 2004 244 | http://www.apache.org/licenses/ 245 | 246 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 247 | 248 | 1. Definitions. 249 | 250 | "License" shall mean the terms and conditions for use, reproduction, 251 | and distribution as defined by Sections 1 through 9 of this document. 252 | 253 | "Licensor" shall mean the copyright owner or entity authorized by 254 | the copyright owner that is granting the License. 255 | 256 | "Legal Entity" shall mean the union of the acting entity and all 257 | other entities that control, are controlled by, or are under common 258 | control with that entity. For the purposes of this definition, 259 | "control" means (i) the power, direct or indirect, to cause the 260 | direction or management of such entity, whether by contract or 261 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 262 | outstanding shares, or (iii) beneficial ownership of such entity. 263 | 264 | "You" (or "Your") shall mean an individual or Legal Entity 265 | exercising permissions granted by this License. 266 | 267 | "Source" form shall mean the preferred form for making modifications, 268 | including but not limited to software source code, documentation 269 | source, and configuration files. 270 | 271 | "Object" form shall mean any form resulting from mechanical 272 | transformation or translation of a Source form, including but 273 | not limited to compiled object code, generated documentation, 274 | and conversions to other media types. 275 | 276 | "Work" shall mean the work of authorship, whether in Source or 277 | Object form, made available under the License, as indicated by a 278 | copyright notice that is included in or attached to the work 279 | (an example is provided in the Appendix below). 280 | 281 | "Derivative Works" shall mean any work, whether in Source or Object 282 | form, that is based on (or derived from) the Work and for which the 283 | editorial revisions, annotations, elaborations, or other modifications 284 | represent, as a whole, an original work of authorship. For the purposes 285 | of this License, Derivative Works shall not include works that remain 286 | separable from, or merely link (or bind by name) to the interfaces of, 287 | the Work and Derivative Works thereof. 288 | 289 | "Contribution" shall mean any work of authorship, including 290 | the original version of the Work and any modifications or additions 291 | to that Work or Derivative Works thereof, that is intentionally 292 | submitted to Licensor for inclusion in the Work by the copyright owner 293 | or by an individual or Legal Entity authorized to submit on behalf of 294 | the copyright owner. For the purposes of this definition, "submitted" 295 | means any form of electronic, verbal, or written communication sent 296 | to the Licensor or its representatives, including but not limited to 297 | communication on electronic mailing lists, source code control systems, 298 | and issue tracking systems that are managed by, or on behalf of, the 299 | Licensor for the purpose of discussing and improving the Work, but 300 | excluding communication that is conspicuously marked or otherwise 301 | designated in writing by the copyright owner as "Not a Contribution." 302 | 303 | "Contributor" shall mean Licensor and any individual or Legal Entity 304 | on behalf of whom a Contribution has been received by Licensor and 305 | subsequently incorporated within the Work. 306 | 307 | 2. 
Grant of Copyright License. Subject to the terms and conditions of 308 | this License, each Contributor hereby grants to You a perpetual, 309 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 310 | copyright license to reproduce, prepare Derivative Works of, 311 | publicly display, publicly perform, sublicense, and distribute the 312 | Work and such Derivative Works in Source or Object form. 313 | 314 | 3. Grant of Patent License. Subject to the terms and conditions of 315 | this License, each Contributor hereby grants to You a perpetual, 316 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 317 | (except as stated in this section) patent license to make, have made, 318 | use, offer to sell, sell, import, and otherwise transfer the Work, 319 | where such license applies only to those patent claims licensable 320 | by such Contributor that are necessarily infringed by their 321 | Contribution(s) alone or by combination of their Contribution(s) 322 | with the Work to which such Contribution(s) was submitted. If You 323 | institute patent litigation against any entity (including a 324 | cross-claim or counterclaim in a lawsuit) alleging that the Work 325 | or a Contribution incorporated within the Work constitutes direct 326 | or contributory patent infringement, then any patent licenses 327 | granted to You under this License for that Work shall terminate 328 | as of the date such litigation is filed. 329 | 330 | 4. Redistribution. You may reproduce and distribute copies of the 331 | Work or Derivative Works thereof in any medium, with or without 332 | modifications, and in Source or Object form, provided that You 333 | meet the following conditions: 334 | 335 | (a) You must give any other recipients of the Work or 336 | Derivative Works a copy of this License; and 337 | 338 | (b) You must cause any modified files to carry prominent notices 339 | stating that You changed the files; and 340 | 341 | (c) You must retain, in the Source form of any Derivative Works 342 | that You distribute, all copyright, patent, trademark, and 343 | attribution notices from the Source form of the Work, 344 | excluding those notices that do not pertain to any part of 345 | the Derivative Works; and 346 | 347 | (d) If the Work includes a "NOTICE" text file as part of its 348 | distribution, then any Derivative Works that You distribute must 349 | include a readable copy of the attribution notices contained 350 | within such NOTICE file, excluding those notices that do not 351 | pertain to any part of the Derivative Works, in at least one 352 | of the following places: within a NOTICE text file distributed 353 | as part of the Derivative Works; within the Source form or 354 | documentation, if provided along with the Derivative Works; or, 355 | within a display generated by the Derivative Works, if and 356 | wherever such third-party notices normally appear. The contents 357 | of the NOTICE file are for informational purposes only and 358 | do not modify the License. You may add Your own attribution 359 | notices within Derivative Works that You distribute, alongside 360 | or as an addendum to the NOTICE text from the Work, provided 361 | that such additional attribution notices cannot be construed 362 | as modifying the License. 
363 | 364 | You may add Your own copyright statement to Your modifications and 365 | may provide additional or different license terms and conditions 366 | for use, reproduction, or distribution of Your modifications, or 367 | for any such Derivative Works as a whole, provided Your use, 368 | reproduction, and distribution of the Work otherwise complies with 369 | the conditions stated in this License. 370 | 371 | 5. Submission of Contributions. Unless You explicitly state otherwise, 372 | any Contribution intentionally submitted for inclusion in the Work 373 | by You to the Licensor shall be under the terms and conditions of 374 | this License, without any additional terms or conditions. 375 | Notwithstanding the above, nothing herein shall supersede or modify 376 | the terms of any separate license agreement you may have executed 377 | with Licensor regarding such Contributions. 378 | 379 | 6. Trademarks. This License does not grant permission to use the trade 380 | names, trademarks, service marks, or product names of the Licensor, 381 | except as required for reasonable and customary use in describing the 382 | origin of the Work and reproducing the content of the NOTICE file. 383 | 384 | 7. Disclaimer of Warranty. Unless required by applicable law or 385 | agreed to in writing, Licensor provides the Work (and each 386 | Contributor provides its Contributions) on an "AS IS" BASIS, 387 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 388 | implied, including, without limitation, any warranties or conditions 389 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 390 | PARTICULAR PURPOSE. You are solely responsible for determining the 391 | appropriateness of using or redistributing the Work and assume any 392 | risks associated with Your exercise of permissions under this License. 393 | 394 | 8. Limitation of Liability. In no event and under no legal theory, 395 | whether in tort (including negligence), contract, or otherwise, 396 | unless required by applicable law (such as deliberate and grossly 397 | negligent acts) or agreed to in writing, shall any Contributor be 398 | liable to You for damages, including any direct, indirect, special, 399 | incidental, or consequential damages of any character arising as a 400 | result of this License or out of the use or inability to use the 401 | Work (including but not limited to damages for loss of goodwill, 402 | work stoppage, computer failure or malfunction, or any and all 403 | other commercial damages or losses), even if such Contributor 404 | has been advised of the possibility of such damages. 405 | 406 | 9. Accepting Warranty or Additional Liability. While redistributing 407 | the Work or Derivative Works thereof, You may choose to offer, 408 | and charge a fee for, acceptance of support, warranty, indemnity, 409 | or other liability obligations and/or rights consistent with this 410 | License. However, in accepting such obligations, You may act only 411 | on Your own behalf and on Your sole responsibility, not on behalf 412 | of any other Contributor, and only if You agree to indemnify, 413 | defend, and hold each Contributor harmless for any liability 414 | incurred by, or claims asserted against, such Contributor by reason 415 | of your accepting any such warranty or additional liability. 416 | 417 | END OF TERMS AND CONDITIONS 418 | 419 | APPENDIX: How to apply the Apache License to your work. 
420 | 421 | To apply the Apache License to your work, attach the following 422 | boilerplate notice, with the fields enclosed by brackets "[]" 423 | replaced with your own identifying information. (Don't include 424 | the brackets!) The text should be enclosed in the appropriate 425 | comment syntax for the file format. We also recommend that a 426 | file or class name and description of purpose be included on the 427 | same "printed page" as the copyright notice for easier 428 | identification within third-party archives. 429 | 430 | Copyright [yyyy] [name of copyright owner] 431 | 432 | Licensed under the Apache License, Version 2.0 (the "License"); 433 | you may not use this file except in compliance with the License. 434 | You may obtain a copy of the License at 435 | 436 | http://www.apache.org/licenses/LICENSE-2.0 437 | 438 | Unless required by applicable law or agreed to in writing, software 439 | distributed under the License is distributed on an "AS IS" BASIS, 440 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 441 | See the License for the specific language governing permissions and 442 | limitations under the License. 443 | 444 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 
45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). 12 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 
9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | 41 | ------------------------------------------------------------------------------- 42 | SOFTWARE DISTRIBUTED WITH ML-Dataset-Decomposition: 43 | 44 | The ML-Dataset-Decomposition software includes a number of subcomponents with separate 45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS. 46 | ------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dataset-Decomposition 2 | 3 | This repository contains the implementation of the [Dataset Decomposition NeurIPS 2024 paper](https://arxiv.org/abs/2405.13226). 4 | The training code is based on the [OpenLM repository](https://github.com/mlfoundations/open_lm), as explained below. 5 | 6 | Dataset decomposition enables **fast pre-training for long-context** inputs by organizing documents into buckets based on their length. 7 | Additionally, a **length-based curriculum** can be applied (starting with short sequences and gradually progressing to longer ones) 8 | to achieve **improved performance** on both regular and long-context benchmarks. 
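For intuition, here is a minimal, self-contained sketch of the bucketing idea (an illustration only, assuming a greedy power-of-two split; the repository's actual tokenize+bucketize+shuffle pipeline lives in `scripts/make_dd_buckets.py` and may differ in details such as how a document's short tail is handled). Each tokenized document is cut into chunks whose lengths are powers of two, and every chunk of length `2^i` is placed in bucket `D_i`:

```python
# Simplified sketch of dataset decomposition (illustration only; see
# scripts/make_dd_buckets.py for the real implementation).
from collections import defaultdict

def decompose(tokens, min_bucket=8, max_bucket=13):
    """Greedily split one tokenized document into power-of-two-length chunks."""
    chunks, pos, remaining = [], 0, len(tokens)
    while remaining >= 2 ** min_bucket:
        # Largest allowed power of two that still fits in the remainder.
        i = min(max_bucket, remaining.bit_length() - 1)
        chunks.append((i, tokens[pos : pos + 2 ** i]))
        pos += 2 ** i
        remaining -= 2 ** i
    return chunks  # leftover tokens shorter than 2^min_bucket are ignored in this sketch

# Group chunks of equal length into buckets D_min_bucket ... D_max_bucket.
buckets = defaultdict(list)
for doc in (list(range(3000)), list(range(700))):  # stand-ins for tokenized documents
    for i, chunk in decompose(doc):
        buckets[i].append(chunk)

for i in sorted(buckets):
    print(f"D_{i}: {len(buckets[i])} sequences of length {2 ** i}")
```

Because every sequence in bucket `D_i` has exactly `2^i` tokens, training can draw each batch from a single bucket and scale the batch size with the batch multipliers described in the training section below, so that each optimization step sees a fixed token budget.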
9 | 10 | 11 | 12 | ## Table of Contents 13 | 14 | - [Setup](#setup) 15 | - [Create Decomposed Datasets](#create-decomposed-datasets) 16 | - [Launch Variable Sequence Length Training](#launch-variable-sequence-length-training) 17 | - [Results](#results) 18 | - [Citation](#citation) 19 | 20 | ## Setup 21 | 22 | Clone [OpenLM](https://github.com/mlfoundations/open_lm) and apply our patch to enable variable sequence length training. 23 | Install the requirements as instructed in the [OpenLM repository](https://github.com/mlfoundations/open_lm). 24 | Then, from the root of this repo, perform the following steps: 25 | ```shell 26 | git clone https://github.com/mlfoundations/open_lm.git 27 | cd open_lm 28 | git checkout 9bb92ef1689333534b7057942a20d18a46d1fa52 29 | git apply ../open_lm.patch 30 | # Install dependencies as required by OpenLM 31 | cd .. 32 | ``` 33 | 34 | ## Create Decomposed Datasets 35 | 36 | Dataset decomposition is a per-document method and is applicable to any dataset. 37 | Here, we show an example for small datasets in the form of JSONL files. 38 | 39 | ### Step 1: 40 | Get some data. Make sure to upgrade the [datasets library](https://pypi.org/project/datasets) (we use version 3.1). 41 | ```shell 42 | mkdir -p /mnt/raw_datasets/wiki 43 | python scripts/wiki_download.py --output-dir /mnt/raw_datasets/wiki 44 | ``` 45 | Once the download is complete, you will have 32 JSONL files. 46 | Alternatively, you can run [scripts/dclm_download.py](scripts/dclm_download.py) to 47 | download a small portion of the [DCLM dataset](https://www.datacomp.ai/dclm/index.html#home). 48 | 49 | ### Step 2: 50 | Run tokenize+bucketize+shuffle: 51 | ```shell 52 | mkdir -p /mnt/processed_datasets/wiki 53 | python scripts/make_dd_buckets.py --input-files /mnt/raw_datasets/wiki/*.jsonl \ 54 | --output-dir /mnt/processed_datasets/wiki --min-bucket 8 --max-bucket 13 --num-workers 32 55 | ``` 56 | We use `32` workers here. You can increase this number for faster processing if you have more JSONL files. 57 | The `--min-bucket` and `--max-bucket` parameters determine the range of buckets for dataset decomposition. 58 | For the example above, buckets will be created for sequences with lengths `2^8=256`, `2^9=512`, ..., `2^13=8192`. 59 | 60 | When this step is completed, buckets will be created with multiple shards per bucket. 61 | 62 | ### Step 3: 63 | Create the WebDataset manifest files, one for each bucket. This ensures that each bucket has its corresponding manifest for proper dataset handling and processing.
64 | ```shell 65 | for i in $(seq 8 13); 66 | do 67 | python open_lm/open_lm/utils/make_wds_manifest.py --data-dir /mnt/processed_datasets/wiki/D_$i --num-workers 16 68 | done 69 | ``` 70 | ### Step 4 (Optional) 71 | Get stats of the (decomposed) dataset you just created: 72 | ```shell 73 | python scripts/get_stats.py --dd-dir /mnt/processed_datasets/wiki 74 | ``` 75 | With the above bucket sizes, we obtain the following statistics for the Wikipedia and DCLM datasets: 76 | #### Wikipedia 77 | ```text 78 | D_8 : # shards: 553 seq-length: 256 # sequences: 2,265,088 # tokens: 579,862,528 79 | D_9 : # shards: 779 seq-length: 512 # sequences: 1,595,392 # tokens: 816,840,704 80 | D_10: # shards: 831 seq-length: 1,024 # sequences: 850,944 # tokens: 871,366,656 81 | D_11: # shards: 690 seq-length: 2,048 # sequences: 353,280 # tokens: 723,517,440 82 | D_12: # shards: 475 seq-length: 4,096 # sequences: 121,600 # tokens: 498,073,600 83 | D_13: # shards: 291 seq-length: 8,192 # sequences: 37,248 # tokens: 305,135,616 84 | ******************** 85 | Total number of tokens = 3,794,796,544 86 | ``` 87 | #### DCLM-Baseline subset 88 | ```text 89 | D_8 : # shards: 3,560 seq-length: 256 # sequences: 14,581,760 # tokens: 3,732,930,560 90 | D_9 : # shards: 5,996 seq-length: 512 # sequences: 12,279,808 # tokens: 6,287,261,696 91 | D_10: # shards: 7,410 seq-length: 1,024 # sequences: 7,587,840 # tokens: 7,769,948,160 92 | D_11: # shards: 6,309 seq-length: 2,048 # sequences: 3,230,208 # tokens: 6,615,465,984 93 | D_12: # shards: 5,157 seq-length: 4,096 # sequences: 1,320,192 # tokens: 5,407,506,432 94 | D_13: # shards: 4,513 seq-length: 8,192 # sequences: 577,664 # tokens: 4,732,223,488 95 | ******************** 96 | Total number of tokens = 34,545,336,320 97 | ``` 98 | 99 | ## Launch Variable Sequence Length Training 100 | 101 | ### Step 1 102 | Modify a run script with your desired hyperparameters and the path to the dataset. 103 | Refer to the Appendix of [the paper](https://arxiv.org/abs/2405.13226) for the full list of hyperparameters. 104 | The dataset path can be either local or on S3. 105 | 106 | ### Step 2 107 | For dataset-decomposition parameters, you can use the following [helper code](scripts/get_dd_params.py) (or set them manually). 108 | For example, the parameters below configure training with a total of `29` billion tokens, `8` epochs/cycles, `8` GPUs, and a global batch size 109 | of `64*8192` tokens, with buckets sized from `256` to `8192`. 110 | (One global batch corresponds to `64` sequences when drawn from the largest bucket, whose sequences have length `8192`): 111 | ```shell 112 | python scripts/get_dd_params.py \ 113 | --tokens 28795904000 \ 114 | --epochs 8 \ 115 | --gpus 8 \ 116 | --global-batch-size 64 \ 117 | --number-of-shards 3560 5996 7410 6309 5157 4513 \ 118 | --sequence-per-shard 4096 2048 1024 512 256 128 \ 119 | --sequence_sizes 256 512 1024 2048 4096 8192 \ 120 | --batch-mult 32 16 8 4 2 1 \ 121 | --train-data-mix-weights 32 16 8 4 2 1 122 | ``` 123 | 124 | Here is a short description of each input argument: 125 | 126 | - `--tokens`: Total number of tokens to be processed. 127 | - `--epochs`: Number of cycles (also determines the number of checkpoints to save). 128 | - `--gpus`: Total number of GPUs. 129 | - `--global-batch-size`: Global batch size (assuming all sequences are of the longest length, e.g., 8192 here). 130 | - `--number-of-shards`: Number of available shards per bucket. 131 | - `--sequence-per-shard`: Number of sequences per shard per bucket.
132 | - `--sequence_sizes`: Length of sequences in each bucket. 133 | - `--batch-mult`: Batch multipliers to maintain a fixed number of tokens regardless of sequence length. 134 | - `--train-data-mix-weights`: Power-of-2 length-based curriculum (prioritizing shorter sequences first). 135 | 136 | The script will output the following: 137 | ```text 138 | **** Use the following arguments: 139 | --epochs 8 140 | --train-num-samples 3607101440 141 | --dataset-batch-mult 32 16 8 4 2 1 142 | --source-num-seq-per-epoch 1507328 1277952 794624 335872 137216 61440 143 | --train-data-mix-weights 1472 1248 776 328 134 60 144 | ``` 145 | 146 | ### Step 3: 147 | Update the [run script](scripts/train_dd.sh) with the above parameters, and launch the training. 148 | Ensure you log in to WandB (or disable WandB reporting) before running the script. 149 | 150 | ```shell 151 | bash scripts/train_dd.sh 152 | ``` 153 | The above set of hyperparameters corresponds to [DCLM-Baseline 1B-1x](https://github.com/mlfoundations/dclm) 154 | with a maximum sequence length of `8192`. 155 | To extend beyond `8192`, make sure to update the model configuration files (located in `open_lm/model_configs`). 156 | 157 | On a node with 8 H100 GPUs, the above training should take less than 28 hours. 158 | For this example, the model performance is as follows: 159 | 160 | | ArcE | ArcC | Hellaswag | LamOAI | Winogrande | Winograd | WikiQA | OBQA | SQuAD | PIQA | COPA | CoQA | BoolQ | 161 | |--------|--------|-----------|--------|------------|----------|--------|-------|-------|-------|-------|-------|-------| 162 | | 65.0 | 35.5 | 57.8 | 61.0 | 58.9 | 75.5 | 52.5 | 39.8 | 35.3 | 73.4 | 72.0 | 31.8 | 61.7 | 163 | 164 | The number of tokens processed per second per GPU (H100) and the sampled sequence length over the course of training 165 | are shown below for this example. 166 | 167 | 168 | 169 | 170 | 171 | ## Results 172 | Please see the full list of ablations in [the paper](https://arxiv.org/abs/2405.13226). 173 | The following table summarizes some billion-scale results: 174 | 175 | - **RW** refers to the [RefinedWeb dataset](https://huggingface.co/datasets/tiiuae/falcon-refinedweb). 176 | - **DCLM** refers to the [DCLM-Baseline dataset](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0). 177 | - For models with **SFT**, we follow the same setup as DCLM-Baseline. 178 | - **DD** refers to Dataset Decomposition, and **C&C** refers to concat-and-chunk. 179 | 180 | All models are trained with a context length of `8192` and a total of `2^40` seen tokens (~1.1 trillion tokens).
181 | 182 | | **Model** | Dataset | Method | SFT | MMLU | ArcE | ArcC | Hellaswag | LambadaOAI | Winogrande | Winograd | WikiQA | OBQA | SQuAD | PIQA | COPA | CoQA | BoolQ | 183 | |-----------------|---------|--------|-----|------|------|------|-----------|------------|------------|----------|--------|------|-------|------|------|------|-------| 184 | | *Shots* | N/A | N/A | N/A | *5* | *3* | *3* | *0* | *0* | *5* | *3* | *3* | *10* | *3* | *0* | *0* | *0* | *0* | 185 | | **Random** | 0 | N/A | N/A | 25 | 25 | 25 | 25 | 0 | 50 | 50 | 0 | 25 | 0 | 50 | 50 | 0 | 50 | 186 | | **OpenLM 160m** | DCLM | DD | No | 24.5 | 51.9 | 26 | 40 | 44.2 | 52.5 | 65.6 | 42.4 | 33 | 17.8 | 68.4 | 61 | 19.9 | 58.8 | 187 | | **OpenLM 160m** | DCLM | DD | Yes | 25.6 | 52.9 | 27.6 | 39 | 39.9 | 50 | 65.6 | 36.2 | 31.4 | 36.2 | 66.1 | 62 | 29.3 | 49.1 | 188 | | **OpenLM 410m** | RW | C&C | No | 24.8 | 53.6 | 26.6 | 52.7 | 50.5 | 56.7 | 70.7 | 52.6 | 35.6 | 25.5 | 71.3 | 69 | 26.9 | 54.1 | 189 | | **OpenLM 410m** | RW | DD | No | 27 | 55.3 | 27.9 | 55.1 | 53.9 | 59 | 74.4 | 56.3 | 35 | 30.1 | 72.6 | 63 | 28.1 | 62.7 | 190 | | **OpenLM 410m** | DCLM | DD | No | 24.9 | 62.4 | 33.9 | 55.9 | 57.2 | 59.9 | 77.7 | 55.3 | 38.8 | 32 | 73.4 | 68 | 31.3 | 56.2 | 191 | | **OpenLM 410m** | DCLM | DD | Yes | 34.8 | 63.3 | 35.4 | 53.5 | 52.9 | 58.7 | 74.4 | 50.1 | 38.4 | 49.4 | 73.2 | 67 | 39.8 | 72.2 | 192 | | **OpenLM 1B** | DCLM | DD | No | 28.6 | 70.6 | 43.2 | 68.9 | 67.6 | 67.6 | 85.7 | 62.9 | 44.2 | 47.6 | 77.1 | 77 | 39.9 | 58.7 | 193 | | **OpenLM 1B** | DCLM | DD | Yes | 49.1 | 70.7 | 43.1 | 68.6 | 61 | 66.3 | 78.4 | 56.8 | 45 | 57.1 | 77 | 80 | 46.5 | 80.7 | 194 | 195 | ## Citation 196 | If you like our work, please consider citing [our NeurIPS 2024 paper](https://arxiv.org/abs/2405.13226): 197 | ```bibtex 198 | @article{pouransari2024dataset, 199 | title={Dataset Decomposition: Faster LLM Training with Variable Sequence Length Curriculum}, 200 | author={Pouransari, Hadi and Li, Chun-Liang and Chang, Jen-Hao Rick and Vasu, Pavan Kumar Anasosalu and Koc, Cem and Shankar, Vaishaal and Tuzel, Oncel}, 201 | journal={arXiv preprint arXiv:2405.13226}, 202 | year={2024}, 203 | url={https://arxiv.org/abs/2405.13226} 204 | } 205 | ``` 206 | -------------------------------------------------------------------------------- /assets/dd-animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-dataset-decomposition/9a23939205bd371184b9e627a7abc9e9e9e07a4b/assets/dd-animation.gif -------------------------------------------------------------------------------- /assets/sequence_length.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-dataset-decomposition/9a23939205bd371184b9e627a7abc9e9e9e07a4b/assets/sequence_length.png -------------------------------------------------------------------------------- /assets/throughput.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-dataset-decomposition/9a23939205bd371184b9e627a7abc9e9e9e07a4b/assets/throughput.png -------------------------------------------------------------------------------- /open_lm.patch: -------------------------------------------------------------------------------- 1 | diff --git a/open_lm/data.py b/open_lm/data.py 2 | index 107ff6e..05f6d08 100644 3 | --- a/open_lm/data.py 4 | +++ b/open_lm/data.py 5 | @@ -38,6 +38,34 @@ from 
webdataset.tariterators import ( 6 | ) 7 | from webdataset.mix import RandomMix 8 | 9 | +class MyRandomMix(RandomMix): 10 | + def __init__(self, datasets, probs=None, longest=False, seed=42): 11 | + super().__init__(datasets, probs=probs, longest=longest) 12 | + self.rng = random.Random() 13 | + self.rng.seed(seed) 14 | + 15 | + def __iter__(self): 16 | + """Return an iterator over the sources.""" 17 | + sources = [iter(d) for d in self.datasets] 18 | + return self.random_samples(sources, self.probs) 19 | + 20 | + def random_samples(self, sources, probs=None): 21 | + if probs is None: 22 | + probs = [1] * len(sources) 23 | + else: 24 | + probs = list(probs) 25 | + while len(sources) > 0: 26 | + cum = (np.array(probs) / np.sum(probs)).cumsum() 27 | + r = self.rng.random() 28 | + i = np.searchsorted(cum, r) 29 | + try: 30 | + yield next(sources[i]) 31 | + except StopIteration: 32 | + if self.longest: 33 | + del sources[i] 34 | + del probs[i] 35 | + else: 36 | + break 37 | 38 | def seed_worker(worker_id): 39 | worker_seed = torch.initial_seed() % 2**32 40 | @@ -344,7 +372,7 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke 41 | all_num_samples = [] 42 | 43 | shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc 44 | - for ii, input_shards in enumerate(input_shards_): 45 | + for ii, (input_shards, batch_mult) in enumerate(zip(input_shards_, args.dataset_batch_mult)): 46 | resampled = getattr(args, "dataset_resampled", False) and is_train 47 | num_shards = None 48 | if is_train: 49 | @@ -421,7 +449,7 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke 50 | ) 51 | 52 | map_handler = {"handler": log_and_continue} if args.ignore_parse_errors else {} 53 | - batch_size = args.per_gpu_batch_size if is_train else args.per_gpu_val_batch_size 54 | + batch_size = int(batch_mult * args.per_gpu_batch_size) if is_train else args.per_gpu_val_batch_size 55 | 56 | if data_key == "json" or data_key == "json.gz": 57 | pipeline.extend( 58 | @@ -430,7 +458,6 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke 59 | wds.rename(json=data_key), 60 | wds.map_dict(json=partial(preprocess_json, vocab_size=args.vocab_size), **map_handler), 61 | wds.to_tuple("json", **map_handler), 62 | - wds.select(partial(filter_lt_seqlen, args.seq_len)), 63 | wds.batched(batch_size, partial=not is_train), 64 | ] 65 | ) 66 | @@ -439,7 +466,6 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke 67 | [ 68 | wds.map_dict(txt=partial(preprocess_txt, vocab_size=args.vocab_size), **map_handler), 69 | wds.to_tuple("txt", **map_handler), 70 | - wds.select(partial(filter_lt_seqlen, args.seq_len)), 71 | wds.batched(batch_size, partial=not is_train), 72 | ] 73 | ) 74 | @@ -451,8 +477,8 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke 75 | all_num_samples.append(num_samples) 76 | 77 | if is_train: 78 | - # TODO: why did we previoulsy wrap with RandomMix_ 79 | - dataset = RandomMix(datasets, probs=args.train_data_mix_weights, longest=True) 80 | + # Use our RandomMix with determined random seed to make sure all nodes choose the same bucket. 81 | + dataset = MyRandomMix(datasets, probs=args.train_data_mix_weights, longest=True, seed=args.seed) 82 | if len(datasets) > 1: 83 | logging.warning("Source mixing is happening during training. 
It is preferred to mix during tokenization.") 84 | else: 85 | @@ -461,17 +487,18 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke 86 | # dataset = datasets[0] 87 | if is_train: 88 | if not resampled: 89 | - num_shards = num_shards or len(expand_urls(input_shards)[0]) 90 | - if num_shards < args.workers * args.world_size: 91 | + shards_per_source_avail = [len(expand_urls(shard_string)[0]) for shard_string in input_shards_] 92 | + print(f"Number of shards available from each source = {shards_per_source_avail}") 93 | + min_num_shards = min(shards_per_source_avail) 94 | + if min_num_shards < args.workers * args.world_size: 95 | print("Please increase --train-num-samples or decrease workers or world size") 96 | - print(f"num_shards: {num_shards}, workers: {args.workers}, world_size: {args.world_size}") 97 | - assert num_shards >= args.workers * args.world_size, "number of shards must be >= total workers" 98 | - # roll over and repeat a few samples to get same number of full batches on each node 99 | + print(f"min num_shards: {min_num_shards}, workers: {args.workers}, world_size: {args.world_size}") 100 | + assert min_num_shards >= args.workers * args.world_size, "number of shards must be >= total workers" 101 | round_fn = math.floor if floor else math.ceil 102 | - global_batch_size = batch_size * args.world_size 103 | total_num_batches = 0 104 | total_num_samples = 0 105 | - for ii in range(len(datasets)): 106 | + for ii, batch_mult in enumerate(args.dataset_batch_mult): 107 | + global_batch_size = int(batch_mult * args.global_batch_size) 108 | # Calculate batches per worker, round as little as possible. 109 | num_workers_per_gpu = max(1, args.workers) 110 | num_worker_batches = round_fn(all_num_samples[ii] / (global_batch_size * num_workers_per_gpu)) 111 | @@ -484,7 +511,7 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke 112 | ) 113 | 114 | num_batches = num_worker_batches * num_workers_per_gpu 115 | - num_samples = num_batches * global_batch_size 116 | + num_samples = num_batches * args.global_batch_size # Number of sequences as if all were the longest (8k) 117 | 118 | # This forces the dataloader to take num_worker_batches steps per worker, so num_batches total. 
119 | datasets[ii] = datasets[ii].repeat(nepochs=1, nbatches=num_worker_batches) 120 | @@ -704,18 +731,6 @@ def mask_sequence(chunk, start_idx, args, ignore_tok=-100): 121 | 122 | 123 | def sample_chunk(chunk, args): 124 | - if chunk.shape[1] == args.seq_len + 1: 125 | - start_idx = 0 126 | - elif chunk.shape[1] > args.seq_len + 1: 127 | - start_idx = torch.randint(0, chunk.shape[1] - args.seq_len, (1,)).item() 128 | - else: 129 | - raise Exception(f"Invalid sequence length: Sequence length {args.seq_len} > {chunk.shape[1]} Chunk size") 130 | - 131 | - inputs = chunk[:, start_idx : start_idx + args.seq_len] 132 | - targets = chunk[:, start_idx + 1 : start_idx + args.seq_len + 1] 133 | - 134 | - # replace elements to be masked with with -100 (pytorch default xent ignore value) 135 | - if args.target_mask_left is not None or args.target_mask_individual is not None: 136 | - inputs, targets = mask_sequence(chunk, start_idx, args) 137 | - 138 | + inputs = chunk[:, :-1] 139 | + targets = chunk[:, 1:] 140 | return inputs, targets 141 | diff --git a/open_lm/file_utils.py b/open_lm/file_utils.py 142 | index f91919b..fe729fa 100644 143 | --- a/open_lm/file_utils.py 144 | +++ b/open_lm/file_utils.py 145 | @@ -134,14 +134,22 @@ def check_exists(file_path): 146 | return True 147 | 148 | 149 | -def get_metadata_file(path, shard_shuffle_seed=None): 150 | +def get_metadata_file(path, shard_shuffle_seed=None, append_a_copy=4): 151 | of = fsspec.open(path, "rb") 152 | with of as f: 153 | out = f.read() 154 | out = [json.loads(o) for o in out.decode("utf-8").split("\n")[:-1]] 155 | + if append_a_copy > 0: 156 | + out_copy = [copy.deepcopy(out) for _ in range(append_a_copy)] 157 | if shard_shuffle_seed is not None: 158 | rng_gen = np.random.default_rng(shard_shuffle_seed) 159 | rng_gen.shuffle(out) 160 | + if append_a_copy > 0: 161 | + for a_copy in out_copy: 162 | + rng_gen.shuffle(a_copy) 163 | + if append_a_copy > 0: 164 | + for a_copy in out_copy: 165 | + out = out + a_copy 166 | return out 167 | 168 | 169 | @@ -218,7 +226,7 @@ def count_small_shards(path, ratio=0.9): 170 | 171 | shard_sizes = np.array(shard_sizes) 172 | 173 | - return np.sum(shard_sizes < ratio * max(shard_sizes)) 174 | + return np.sum(shard_sizes < ratio * max(shard_sizes)), max(shard_sizes) 175 | 176 | 177 | def are_sources_imbalanced_with_each_other(paths, ratio=2): 178 | @@ -262,9 +270,11 @@ def log_num_checkpoints(total_steps, args): 179 | args.world_size, 180 | multi_epoch=args.multiple_data_passes, 181 | shard_shuffle_seed=args.shard_shuffle_seed, 182 | + source_num_seq_per_epoch=args.source_num_seq_per_epoch, 183 | ) 184 | steps_epoch = sum( 185 | - [(n // (args.workers * args.global_batch_size)) * args.workers for n in num_samples_per_source] 186 | + [(n // (args.workers * args.global_batch_size * batch_mult)) * args.workers for n, batch_mult in 187 | + zip(num_samples_per_source, args.dataset_batch_mult)] 188 | ) 189 | steps_done += steps_epoch 190 | if steps_done > total_steps: 191 | @@ -300,15 +310,18 @@ def get_string_for_epoch( 192 | world_size: int, 193 | multi_epoch=False, 194 | shard_shuffle_seed=None, 195 | + source_num_seq_per_epoch=None, 196 | ): 197 | """See _single_epoch_string for full docstring.""" 198 | if multi_epoch: 199 | return _multi_epoch_string( 200 | - num_samples, starting_points, paths, weights, num_workers_per_gpu, world_size, shard_shuffle_seed 201 | + num_samples, starting_points, paths, weights, num_workers_per_gpu, world_size, shard_shuffle_seed, 202 | + source_num_seq_per_epoch 203 | ) 204 | else: 
205 | return _single_epoch_string( 206 | - num_samples, starting_points, paths, weights, num_workers_per_gpu, world_size, shard_shuffle_seed 207 | + num_samples, starting_points, paths, weights, num_workers_per_gpu, world_size, shard_shuffle_seed, 208 | + source_num_seq_per_epoch 209 | ) 210 | 211 | 212 | @@ -370,6 +383,7 @@ def _single_epoch_string( 213 | num_workers_per_gpu: int, 214 | world_size: int, 215 | shard_shuffle_seed: Optional[int], 216 | + source_num_seq_per_epoch: Optional[List[int]] = None, 217 | ): 218 | """Retrieve shards to train on for a particular checkpoint. 219 | 220 | @@ -383,38 +397,25 @@ def _single_epoch_string( 221 | num_workers_per_gpu: Number of workers per gpu process. 222 | world_size: Total number of gpus used for training. 223 | shard_shuffle_seed: Seed to shuffle shards before checkpoint assignment 224 | + source_num_seq_per_epoch: List of number of sequences per bucket per epoch. 225 | """ 226 | 227 | num_sources = len(paths) 228 | - 229 | - if num_sources > 1: 230 | - logging.warning( 231 | - "Multiple sources are not supported fully as of now. It is advised to combine the data into a single " 232 | - "source, by using datapreprocess/ray/tokenize_shuffle.py. Best effort will be done to mix data at the " 233 | - "desired ratio." 234 | - ) 235 | - if are_sources_imbalanced_with_each_other(paths): 236 | - logging.warning( 237 | - "Sources contain highly imbalanced shards (largest median shard size of a source is >2x the smallest " 238 | - "median size of a source). This will lead to deteriorated performance (less frequent checkpoints, " 239 | - "data being skipped, and inaccurate mixing). It is STRONGLY advised to combine into one source." 240 | - ) 241 | + expected_num_sequence_per_shard = [] 242 | 243 | for path in paths: 244 | - num_small_shards = count_small_shards(path) 245 | - if num_small_shards > 0: 246 | - logging.warning( 247 | - f"Source defined by {path} contains {num_small_shards} shards that are smaller than 90% the size of " 248 | - f"the largest shard. These shards might cause deterioration in performance, with more samples being " 249 | - f"skipped than necessary. It is advised to make the shards more uniform." 250 | - ) 251 | + num_small_shards, expected_num_seq = count_small_shards(path) 252 | + expected_num_sequence_per_shard.append(expected_num_seq) 253 | 254 | if weights is None: 255 | weights = [1.0 / num_sources for _ in range(num_sources)] 256 | 257 | assert len(weights) == num_sources, "One weight is needed per source." 
258 | 259 | - needed_samples_per_source = [int(np.ceil(weights[i] * num_samples / sum(weights))) for i in range(num_sources)] 260 | + if source_num_seq_per_epoch is None: 261 | + needed_samples_per_source = [int(np.ceil(weights[i] * num_samples / sum(weights))) for i in range(num_sources)] 262 | + else: 263 | + needed_samples_per_source = source_num_seq_per_epoch 264 | 265 | manifests = [get_metadata_file(path, shard_shuffle_seed=shard_shuffle_seed) for path in paths] 266 | 267 | @@ -424,32 +425,38 @@ def _single_epoch_string( 268 | num_samples_per_source = [[] for _ in range(num_sources)] 269 | 270 | total_num_workers = num_workers_per_gpu * world_size 271 | - while not enough_shards(shard_list_per_source, total_num_workers) or not enough_samples( 272 | - num_samples_per_source, needed_samples_per_source 273 | - ): 274 | - try: 275 | - for i in range(num_sources): 276 | + try: 277 | + for i in range(num_sources): 278 | + while len(shard_list_per_source[i]) < total_num_workers or sum(num_samples_per_source[i]) < \ 279 | + needed_samples_per_source[i]: 280 | # Add shards incrementally 281 | shard_name = manifests[i][next_shard_per_source[i]]["shard"] 282 | try: 283 | num_samples_shard = manifests[i][next_shard_per_source[i]]["num_sequences"] 284 | except KeyError: 285 | num_samples_shard = manifests[i][next_shard_per_source[i]]["num_chunks"] 286 | - 287 | - shard_list_per_source[i].append(shard_name) 288 | - num_samples_per_source[i].append(num_samples_shard) 289 | + if num_samples_shard == expected_num_sequence_per_shard[i]: 290 | + shard_list_per_source[i].append(shard_name) 291 | + num_samples_per_source[i].append(num_samples_shard) 292 | + else: 293 | + print( 294 | + f"Dropping shard = {shard_name} with {num_samples_shard} samples != {expected_num_sequence_per_shard[i]}") 295 | 296 | next_shard_per_source[i] += 1 297 | - 298 | - except IndexError as e: 299 | - logging.error( 300 | - "Number of shards requested for a single epoch is more than the number of shards available. This means " 301 | - "that the amount of data requested to train on is more than the dataloader can serve. This can either " 302 | - "happen because there are not enough data to begin with, or data being skipped due to rounding errors. " 303 | - "To alleviate the latter, consider making more uniform shards, and using less workers/GPUs. This will " 304 | - "allow for better use of the dataset." 305 | - ) 306 | - raise e 307 | + except IndexError as e: 308 | + print(f"For Source = {i}") 309 | + print(f"Need samples = {needed_samples_per_source[i]}, collected {sum(num_samples_per_source[i])}") 310 | + print(f"Total shards so far = {next_shard_per_source[i]}") 311 | + print(f"len(shard_list_per_source[i]) = {len(shard_list_per_source[i])}") 312 | + print(f"total_num_workers = {total_num_workers}") 313 | + logging.error( 314 | + "Number of shards requested for a single epoch is more than the number of shards available. This means " 315 | + "that the amount of data requested to train on is more than the dataloader can serve. This can either " 316 | + "happen because there are not enough data to begin with, or data being skipped due to rounding errors. " 317 | + "To alleviate the latter, consider making more uniform shards, and using less workers/GPUs. This will " 318 | + "allow for better use of the dataset." 
319 | + ) 320 | + raise e 321 | 322 | for i in range(num_sources): 323 | # Ensure the number of shards is a multiple of number of workers, so each worker has the same 324 | @@ -458,6 +465,9 @@ def _single_epoch_string( 325 | # This is a heuristic to minimize how much data we discard when trying to ensure each worker has 326 | # the same number of samples. Shards tend to have similar number of samples, so an extra shard 327 | # in a worker will likely get discarded. 328 | + if not len(shard_list_per_source[i]) % total_num_workers == 0: 329 | + print( 330 | + f"For source {i} number of shards = {len(shard_list_per_source[i])} is not multiple of total workers = {total_num_workers}") 331 | num_multiples = len(shard_list_per_source[i]) // total_num_workers 332 | 333 | shard_list_per_source[i] = shard_list_per_source[i][: num_multiples * total_num_workers] 334 | diff --git a/open_lm/main.py b/open_lm/main.py 335 | index 7c80f55..0da7edc 100644 336 | --- a/open_lm/main.py 337 | +++ b/open_lm/main.py 338 | @@ -793,6 +793,7 @@ def main(args): 339 | args.world_size, 340 | multi_epoch=args.multiple_data_passes, 341 | shard_shuffle_seed=args.shard_shuffle_seed, 342 | + source_num_seq_per_epoch=args.source_num_seq_per_epoch, 343 | ) 344 | 345 | # In the distributed case, make sure that all nodes receive the same string 346 | diff --git a/open_lm/model_configs/open_lm_160m.json b/open_lm/model_configs/open_lm_160m.json 347 | index ea4fe6e..944faf0 100644 348 | --- a/open_lm/model_configs/open_lm_160m.json 349 | +++ b/open_lm/model_configs/open_lm_160m.json 350 | @@ -2,7 +2,7 @@ 351 | "hidden_dim": 768, 352 | "n_layers": 12, 353 | "n_heads": 12, 354 | - "seq_len": 2048, 355 | + "seq_len": 8192, 356 | "vocab_size": 50432, 357 | "post_embed_norm": false, 358 | "weight_tying": false 359 | diff --git a/open_lm/model_configs/open_lm_1b.json b/open_lm/model_configs/open_lm_1b.json 360 | index fc1878e..774fc9b 100644 361 | --- a/open_lm/model_configs/open_lm_1b.json 362 | +++ b/open_lm/model_configs/open_lm_1b.json 363 | @@ -2,7 +2,7 @@ 364 | "hidden_dim": 2048, 365 | "n_layers": 24, 366 | "n_heads": 16, 367 | - "seq_len": 2048, 368 | + "seq_len": 8192, 369 | "vocab_size": 50432, 370 | "post_embed_norm": false, 371 | "weight_tying": false 372 | diff --git a/open_lm/model_configs/open_lm_3b.json b/open_lm/model_configs/open_lm_3b.json 373 | index 64ec0a4..57cc24a 100644 374 | --- a/open_lm/model_configs/open_lm_3b.json 375 | +++ b/open_lm/model_configs/open_lm_3b.json 376 | @@ -2,7 +2,7 @@ 377 | "hidden_dim": 2560, 378 | "n_layers": 32, 379 | "n_heads": 32, 380 | - "seq_len": 2048, 381 | + "seq_len": 8192, 382 | "vocab_size": 50432, 383 | "post_embed_norm": false, 384 | "weight_tying": false 385 | diff --git a/open_lm/model_configs/open_lm_410m.json b/open_lm/model_configs/open_lm_410m.json 386 | index 8532173..1010cf7 100644 387 | --- a/open_lm/model_configs/open_lm_410m.json 388 | +++ b/open_lm/model_configs/open_lm_410m.json 389 | @@ -2,7 +2,7 @@ 390 | "hidden_dim": 1024, 391 | "n_layers": 24, 392 | "n_heads": 16, 393 | - "seq_len": 2048, 394 | + "seq_len": 8192, 395 | "vocab_size": 50432, 396 | "post_embed_norm": false, 397 | "weight_tying": false 398 | diff --git a/open_lm/model_configs/open_lm_7b.json b/open_lm/model_configs/open_lm_7b.json 399 | index e662dab..b9178d0 100644 400 | --- a/open_lm/model_configs/open_lm_7b.json 401 | +++ b/open_lm/model_configs/open_lm_7b.json 402 | @@ -2,7 +2,7 @@ 403 | "hidden_dim": 4096, 404 | "n_layers": 32, 405 | "n_heads": 32, 406 | - "seq_len": 2048, 407 | + "seq_len": 
8192, 408 | "vocab_size": 50432, 409 | "post_embed_norm": false, 410 | "weight_tying": false 411 | diff --git a/open_lm/params.py b/open_lm/params.py 412 | index 0a7a3f6..389b805 100644 413 | --- a/open_lm/params.py 414 | +++ b/open_lm/params.py 415 | @@ -787,6 +787,20 @@ def parse_args(args): 416 | default=0, 417 | help="This is the maximum number of failed checkpoints (due to not having seen enough tokens) that are allowed", 418 | ) 419 | + parser.add_argument( 420 | + "--dataset-batch-mult", 421 | + type=float, 422 | + nargs="+", 423 | + default=None, 424 | + help="Multiplier of batchsize to be used for each dataset (with respect to base batchsize).", 425 | + ) 426 | + parser.add_argument( 427 | + "--source-num-seq-per-epoch", 428 | + type=int, 429 | + nargs="+", 430 | + default=None, 431 | + help="Number of sequences to be used per epoch from each source.", 432 | + ) 433 | 434 | add_model_args(parser) 435 | 436 | diff --git a/open_lm/positional_embedding/rotary.py b/open_lm/positional_embedding/rotary.py 437 | index b48ed89..d5c1af0 100644 438 | --- a/open_lm/positional_embedding/rotary.py 439 | +++ b/open_lm/positional_embedding/rotary.py 440 | @@ -57,7 +57,7 @@ class RotaryEmbedding(torch.nn.Module): 441 | self.reset_parameters() 442 | 443 | def reset_parameters(self): 444 | - self.inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim_model, 2).float() / self.dim_model)) 445 | + self.inv_freq = 1.0 / (100000 ** (torch.arange(0, self.dim_model, 2).float() / self.dim_model)) 446 | self._update_cos_sin_tables(self.seq_len) 447 | 448 | def _update_cos_sin_tables(self, seq_len: int = None, device: torch.device = None, dtype: torch.dtype = None): 449 | diff --git a/open_lm/train.py b/open_lm/train.py 450 | index 0d54bf7..eccf708 100644 451 | --- a/open_lm/train.py 452 | +++ b/open_lm/train.py 453 | @@ -110,13 +110,17 @@ def train_one_epoch( 454 | 455 | try: 456 | batch = next(data_iterator) 457 | - has_data = torch.tensor(1, dtype=torch.long, device=device) 458 | + has_data = torch.tensor([1, len(batch[0])], dtype=torch.long, device=device) 459 | except StopIteration: 460 | - has_data = torch.tensor(0, dtype=torch.long, device=device) 461 | + logging.warning("Could not get a batch!!!") 462 | + has_data = torch.tensor([0, 0], dtype=torch.long, device=device) 463 | 464 | if args.world_size > 1: 465 | dist.all_reduce(has_data, op=ReduceOp.SUM) 466 | - if has_data < args.world_size: 467 | + if has_data[1] != len(batch[0]) * args.world_size: 468 | + logging.warning("Same global sequence length consistency broke! This can reduce performance.") 469 | + if has_data[0] != args.world_size: 470 | + logging.warning("At least one gpu could not get a batch.") 471 | break 472 | 473 | (texts,) = batch 474 | @@ -153,12 +157,12 @@ def train_one_epoch( 475 | # save the loss for the average model for logging 476 | total_loss_avg[key] = loss(out_avg.reshape(-1, args.vocab_size), targets.reshape(-1)) 477 | else: 478 | + inputs, targets = sample_chunk(texts, args) 479 | + 480 | # split up batch into accum_freq chunks -- if you have --batch-size 8 and --accum-freq 4 481 | # then you only process 2 items at a time. batch-size must be divisible by accume-freq. 
482 | - assert args.per_gpu_batch_size % args.accum_freq == 0, "Per-GPU batch size must be divisible by accum_freq" 483 | - per_batch = args.per_gpu_batch_size // args.accum_freq 484 | - 485 | - inputs, targets = sample_chunk(texts, args) 486 | + assert inputs.shape[0] % args.accum_freq == 0, "Per-GPU batch size must be divisible by accum_freq" 487 | + per_batch = inputs.shape[0] // args.accum_freq 488 | 489 | forward_total_time = 0 490 | backward_total_time = 0 491 | @@ -291,7 +295,7 @@ def train_one_epoch( 492 | for key, value in total_loss_avg.items(): 493 | losses_avg_m[key].update(value.item(), batch_size) 494 | if i % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch or step == total_steps - 1: 495 | - num_samples = batch_count * batch_size * args.world_size 496 | + num_samples = batch_count * args.global_batch_size # Number of sequences seen as if all were the longest 497 | samples_per_epoch = dataloader.num_samples 498 | percent_complete = 100.0 * batch_count / num_batches_per_epoch 499 | 500 | @@ -332,6 +336,7 @@ def train_one_epoch( 501 | "tokens": (step + 1) * args.global_batch_size * args.seq_len, 502 | "expected_steps_epoch": data["train"].dataloader.num_batches, 503 | "seen_steps_epoch": batch_count, 504 | + "seq_len": inputs.shape[1], 505 | } 506 | 507 | if averagers is not None and args.log_avg_model_training_loss: 508 | -------------------------------------------------------------------------------- /scripts/dclm_download.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | import os 8 | 9 | from datasets import load_dataset 10 | from tqdm import tqdm 11 | 12 | 13 | def main(output_dir: str, num_shards: int = 128) -> None: 14 | """Downloads a small subset of the DCLM-Baseline dataset. 15 | 16 | :param output_dir: The path to the output directory where the downloaded files will be saved. 17 | :param num_shards: The number of JSONL files to divide the downloaded data into. 18 | :return: None 19 | """ 20 | os.makedirs(output_dir, exist_ok=True) 21 | # Download a small fraction of DCLM-Baseline dataset 22 | data = load_dataset("mlfoundations/dclm-baseline-1.0", 23 | data_dir="global-shard_01_of_10/local-shard_0_of_10") 24 | 25 | for split, dataset in data.items(): 26 | for i in tqdm(range(num_shards)): 27 | dataset_shard = dataset.shard( 28 | num_shards=num_shards, index=i, contiguous=True) 29 | output_file = os.path.join(output_dir, f"dclm_{split}_{i}.jsonl") 30 | dataset_shard.to_json(output_file) 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | "--output-dir", 37 | type=str, 38 | required=True, 39 | help="Where to store the DCLM subset .jsonl files.", 40 | ) 41 | 42 | args = parser.parse_args() 43 | main(args.output_dir) 44 | -------------------------------------------------------------------------------- /scripts/get_dd_params.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import argparse 7 | 8 | import numpy as np 9 | from numpy.typing import NDArray 10 | 11 | 12 | def main(total_tokens: int, 13 | epochs: int, 14 | num_gpus: int, 15 | global_batch: int, 16 | num_workers: int, 17 | curriculum: NDArray, 18 | number_of_shards: NDArray, 19 | sequence_per_shard: NDArray, 20 | sequence_sizes: NDArray, 21 | batch_mult: NDArray) -> None: 22 | """Computes variable batch size training parameters based on the desired training hyperparameters. 23 | 24 | :param total_tokens: The total number of tokens to be processed during training. 25 | :param epochs: The number of epochs/checkpoints to save during training. 26 | :param num_gpus: The total number of GPUs to be used for training. 27 | :param global_batch: The global batch size for training. 28 | :param num_workers: The number of dataloader workers per GPU. 29 | :param curriculum: A numpy array of integers representing the probabilities of selecting a batch from each bucket. 30 | :param number_of_shards: A numpy array indicating the number of shards per source as defined in each bucket's manifest file. 31 | :param sequence_per_shard: A numpy array representing the number of sequences per shard for each source. 32 | :param sequence_sizes: A numpy array specifying the sizes of sequences per shard. 33 | :param batch_mult: A numpy array defining the ratio of each source's batch size compared to the source with the longest sequences. 34 | :return: None 35 | """ 36 | # Number of tokens available per source/bucket: 37 | tokens_per_bucket = number_of_shards * sequence_per_shard * sequence_sizes 38 | 39 | # Ratio of number of tokens needed vs number of tokens available 40 | job_scale_factor = total_tokens / tokens_per_bucket.sum() 41 | 42 | # Number of tokens needed per source/bucket 43 | needed_tokens_per_bucket = job_scale_factor * tokens_per_bucket 44 | 45 | # Number of tokens needed per source/bucket per epoch 46 | needed_tokens_per_bucket_per_epoch = needed_tokens_per_bucket / epochs 47 | 48 | # Number of sequences needed per source/bucket per epoch 49 | needed_sequence_per_bucket_per_epoch = needed_tokens_per_bucket_per_epoch / \ 50 | (sequence_sizes) 51 | 52 | # Number of sequences per source/bucket per epoch should be divisible with 53 | # the following numbers 54 | denom_condition1 = batch_mult * global_batch * num_workers 55 | denom_condition2 = sequence_per_shard * num_gpus * num_workers 56 | # Satisfying the second condition is sufficient 57 | assert np.all(denom_condition2 > denom_condition1) 58 | 59 | factors = needed_sequence_per_bucket_per_epoch / denom_condition2 60 | factors_int = np.int32(np.round(factors)) 61 | 62 | def get_token_diff(proposed_factors): 63 | return total_tokens / epochs - \ 64 | np.sum(proposed_factors * denom_condition2 * sequence_sizes) 65 | 66 | proposed_factors = factors_int 67 | 68 | index = len(factors_int) - 1 69 | 70 | tried_proposed_factors = set() 71 | while get_token_diff(proposed_factors) != 0: 72 | if index < 0: 73 | total_tokens_real = epochs * \ 74 | (proposed_factors * denom_condition2 * sequence_sizes).sum() 75 | print( 76 | f"Cannot match requested number of tokens of {total_tokens:,}. 
Go with {total_tokens_real:,} instead.") 77 | break 78 | if tuple(proposed_factors) in tried_proposed_factors: 79 | index -= 1 80 | diff = get_token_diff(proposed_factors) 81 | tried_proposed_factors.add(tuple(proposed_factors)) 82 | if diff < 0: 83 | proposed_factors[index] -= 1 84 | else: 85 | proposed_factors[index] += 1 86 | 87 | proposed_needed_sequence_per_bucket_per_epoch = proposed_factors * denom_condition2 88 | proposed_num_tokens_per_bucket = proposed_needed_sequence_per_bucket_per_epoch * sequence_sizes 89 | source_num_seq_per_epoch = " ".join( 90 | [str(int(x)) for x in proposed_needed_sequence_per_bucket_per_epoch]) 91 | sampling_weights = proposed_num_tokens_per_bucket / \ 92 | proposed_num_tokens_per_bucket[0] * proposed_factors[0] 93 | sampling_weights = sampling_weights * curriculum 94 | train_data_mix_weights = " ".join([str(int(x)) for x in sampling_weights]) 95 | actual_tokens_per_epoch = ( 96 | proposed_needed_sequence_per_bucket_per_epoch * 97 | sequence_sizes).sum() 98 | 99 | print("**** Use the following arguments:") 100 | 101 | print(f"--epochs {epochs}") 102 | print(f"--train-num-samples {actual_tokens_per_epoch}") 103 | print("--dataset-batch-mult " + 104 | " ".join([str(int(x)) for x in batch_mult])) 105 | print("--source-num-seq-per-epoch " + source_num_seq_per_epoch) 106 | print("--train-data-mix-weights " + train_data_mix_weights) 107 | 108 | 109 | if __name__ == "__main__": 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument( 112 | "--tokens", 113 | type=int, 114 | help=( 115 | "Total number of tokens to be seen during training." 116 | "Can be larger than number of available tokens as we allow repeated tokens.")) 117 | parser.add_argument( 118 | "--epochs", 119 | type=int, 120 | help="Number of epochs/checkpoints to save.") 121 | parser.add_argument( 122 | "--gpus", 123 | type=int, 124 | help="Total number of GPUs to be used for training.") 125 | parser.add_argument( 126 | "--global-batch-size", 127 | type=int, 128 | help="Global batch size.") 129 | parser.add_argument( 130 | "--workers", 131 | type=int, 132 | default=1, 133 | help="Number of dataloader workers per GPU.") 134 | parser.add_argument( 135 | "--number-of-shards", 136 | type=int, 137 | nargs="+", 138 | help="Number of shards per source (can read from manifest files).", 139 | default=[ 140 | 553, 141 | 779, 142 | 831, 143 | 690, 144 | 475, 145 | 291], 146 | ) # Default values for the wiki example 147 | parser.add_argument( 148 | "--sequence-per-shard", 149 | type=int, 150 | nargs="+", 151 | help="Number of sequences per shard for each source (determined in each bucket's manifest file).", 152 | default=[ 153 | 4096, 154 | 2048, 155 | 1024, 156 | 512, 157 | 256, 158 | 128], 159 | ) # Default values for the wiki example 160 | parser.add_argument( 161 | "--sequence_sizes", 162 | type=int, 163 | nargs="+", 164 | help="Size of sequences per shard.", 165 | default=[ 166 | 256, 167 | 512, 168 | 1024, 169 | 2048, 170 | 4096, 171 | 8192], 172 | ) # Default values for the wiki example 173 | parser.add_argument( 174 | "--batch-mult", 175 | type=float, 176 | nargs="+", 177 | help="Ratio of each source batch vs the source with the longest sequences.", 178 | default=[ 179 | 32, 180 | 16, 181 | 8, 182 | 4, 183 | 2, 184 | 1], 185 | ) # Default values for the wiki example 186 | parser.add_argument("--train-data-mix-weights", type=float, nargs="+", 187 | help="List of odds to pick a batch from each bucket.", 188 | default=[32, 16, 8, 4, 2, 1], ) # Pow-2 Curriculum 189 | args = parser.parse_args() 190 
| print(args) 191 | main( 192 | args.tokens, args.epochs, args.gpus, args.global_batch_size, args.workers, np.array( 193 | args.train_data_mix_weights), np.array( 194 | args.number_of_shards), np.array( 195 | args.sequence_per_shard), np.array( 196 | args.sequence_sizes), np.array( 197 | args.batch_mult)) 198 | -------------------------------------------------------------------------------- /scripts/get_stats.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | import os 8 | from pathlib import Path 9 | 10 | import jsonlines 11 | 12 | 13 | def main(dd_dir: str) -> None: 14 | """Generates statistics for a directory containing a webdataset dataset. 15 | 16 | :param dd_dir: The path to the directory containing the webdataset dataset. 17 | :return: None 18 | """ 19 | total_num_tokens = 0 20 | bucket_dir_list = sorted( 21 | os.listdir(dd_dir), 22 | key=lambda x: int( 23 | x.split("_")[1])) 24 | for bucket_dir in bucket_dir_list: 25 | manifest_file = os.path.join(dd_dir, bucket_dir, "manifest.jsonl") 26 | num_shards = 0 27 | num_sequences = 0 28 | sequence_length = 2 ** int(bucket_dir.split("_")[1]) 29 | if os.path.exists(manifest_file): 30 | with jsonlines.open(manifest_file) as reader: 31 | for item in reader: 32 | num_shards += 1 33 | num_sequences += item['num_sequences'] 34 | num_tokens = num_sequences * sequence_length 35 | print( 36 | f"{bucket_dir:<4}: # shards: {num_shards:<6,} seq-length: {sequence_length:<8,} # sequences: {num_sequences:<12,} # tokens: {num_tokens:<12,}") 37 | total_num_tokens += num_tokens 38 | print(20 * "*") 39 | print(f"Total number of tokens = {total_num_tokens:,}") 40 | 41 | 42 | if __name__ == "__main__": 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument( 45 | "--dd-dir", 46 | type=Path, 47 | help="Path to a dataset decomposition directory.") 48 | args = parser.parse_args() 49 | 50 | main(args.dd_dir) 51 | -------------------------------------------------------------------------------- /scripts/make_dd_buckets.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | import glob 8 | import gzip 9 | import io 10 | import math 11 | import multiprocessing 12 | import os 13 | import random 14 | import shutil 15 | import time 16 | from collections import defaultdict 17 | from contextlib import contextmanager 18 | from pathlib import Path 19 | from typing import Dict, Generator, List, Union 20 | 21 | import jsonlines 22 | import numpy as np 23 | import zstandard as zstd 24 | from tqdm.auto import tqdm 25 | from transformers import ( 26 | GPTNeoXTokenizerFast, 27 | PreTrainedTokenizer, 28 | PreTrainedTokenizerFast, 29 | ) 30 | from webdataset import ShardWriter 31 | 32 | # Increasing pool size results in more global shuffling, but use more memory. 33 | SHARD_POOL_FACTOR = 20 34 | 35 | # Number of documents per shard for each bucket. Note that bucket i has documents with length 2**i+1. 36 | # We use smaller number of documents per shard for larger i's. 
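# For example, bucket 13 (documents of 2**13 + 1 tokens) gets 2**(20 - 13) = 128 documents per shard,
# while bucket 8 gets min(2**12, 65536) = 4096 documents per shard, so every shard holds roughly 2**20 tokens.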
37 | SHARD_SIZE = {i: min(2 ** (20 - i), 65536) for i in range(20)} 38 | 39 | EOT_TOKEN = "<|endoftext|>" 40 | 41 | 42 | def write_to_shard(chunks: List[Union[int, str]], 43 | shard_writer: ShardWriter) -> None: 44 | """Writes a list of tokens to a shard file. 45 | 46 | :param chunks: A list of tokens to be written. 47 | :param shard_writer: A Webdataset writer object for managing the shard file. 48 | :return: None 49 | """ 50 | for idx, chunk in enumerate(chunks): 51 | shard_writer.write({"__key__": f"{idx:012d}", "txt": str(chunk)}) 52 | 53 | 54 | @contextmanager 55 | def get_item_reader(file_name: str) -> Generator[jsonlines.Reader, None, None]: 56 | """Creates an iterator for reading .jsonl files or Zstd-compressed .jsonl files. 57 | 58 | :param file_name: The path to the input data file. 59 | :return: A generator that yields items from the .jsonl file or zstd-compressed .jsonl file. 60 | """ 61 | if file_name.endswith(".jsonl"): 62 | with jsonlines.open(file_name) as reader: 63 | yield reader 64 | elif file_name.endswith(".jsonl.gz"): 65 | with gzip.open(file_name, "rb") as f_in: 66 | with jsonlines.Reader(f_in) as jsonl_reader: 67 | yield jsonl_reader 68 | else: 69 | dctx = zstd.ZstdDecompressor() 70 | with open(file_name, "rb") as compressed_file: 71 | with dctx.stream_reader(compressed_file) as reader: 72 | with io.TextIOWrapper(reader, encoding="utf-8") as text_reader: 73 | with jsonlines.Reader(text_reader) as jsonl_reader: 74 | yield jsonl_reader 75 | 76 | 77 | def get_binary(seq: List[Union[int, str]], min_log2: int = 8, max_log2: int = 13, 78 | randomize: bool = True) -> Dict[int, List[Union[int, str]]]: 79 | """Applies binary dataset decomposition to a document. 80 | 81 | :param seq: A list of tokenized documents. 82 | :param min_log2: The log2 of the minimum subsequence length to keep. Smaller subsequences will be ignored. 83 | :param max_log2: The log2 of the maximum subsequence length to keep. Larger subsequences will be further divided. 84 | :param randomize: If True, subsequences larger than 2**max_log2 will be divided randomly into subsequences 85 | with lengths within the range determined by min_log2 and max_log2. If False, the division 86 | prioritizes keeping the longest acceptable subsequences. 87 | :return: A dictionary `d` where `d[i]` contains the subsequences of `seq` with lengths of 2**i+1. 88 | """ 89 | out_map = defaultdict(list) 90 | ps = 2 ** np.arange(max_log2 + 1 - min_log2) 91 | ps = ps / ps.sum() 92 | while len(seq) > 1: 93 | k = int(math.log2(len(seq) - 1)) 94 | if k < min_log2: 95 | return out_map 96 | 97 | if k > max_log2: 98 | if randomize: 99 | k = np.random.choice( 100 | np.arange( 101 | max_log2, min_log2 - 1, -1), p=ps) 102 | else: 103 | k = min(k, max_log2) 104 | 105 | out_map[k].append(seq[:(1 << k) + 1]) 106 | seq = seq[(1 << k) + 1:] 107 | return out_map 108 | 109 | 110 | def tokenize_and_shard( 111 | file_names: List[str], 112 | my_id: int, 113 | output_dir: str, 114 | enc: Union[PreTrainedTokenizerFast, PreTrainedTokenizer], 115 | min_bucket: int, 116 | max_bucket: int) -> None: 117 | """Performs dataset-decomposition tokenize-and-shuffle using a single process. 118 | 119 | :param file_names: A list of input data files. 120 | :param my_id: The process ID for the current worker. 121 | :param output_dir: The path to the output directory where the sharded webdataset files will be stored. 122 | :param enc: A tokenizer object for tokenizing the input data. 123 | :param min_bucket: The index of the bucket containing the shortest sequences. 
124 | :param max_bucket: The index of the bucket containing the longest sequences. 125 | :return: None 126 | """ 127 | start_time = time.time() 128 | 129 | shard_writer = {} 130 | for k in range(min_bucket, max_bucket + 1): 131 | output_dir_k = os.path.join(output_dir, f'{k}') 132 | os.makedirs(output_dir_k, exist_ok=True) 133 | shard_writer[k] = ShardWriter( 134 | os.path.join( 135 | output_dir_k, 136 | "shard-%07d.tar"), 137 | maxcount=SHARD_SIZE[k]) 138 | 139 | # dictionary where keys are log2 length, and values are list of sequences. 140 | chunks = defaultdict(list) 141 | 142 | num_entries = 0 143 | 144 | for file_name in file_names: 145 | with get_item_reader(file_name) as item_reader: 146 | for item in item_reader: 147 | string = item["text"] 148 | try: 149 | tokens = enc(string).input_ids + [EOT_TOKEN] 150 | token_map = get_binary( 151 | tokens, min_log2=min_bucket, max_log2=max_bucket) 152 | num_entries += 1 153 | except BaseException: 154 | print("Failed to encode string.") 155 | continue 156 | 157 | for k, v_list in token_map.items(): 158 | for v in v_list: 159 | chunks[k].append(v) 160 | if len(chunks[k]) == SHARD_POOL_FACTOR * SHARD_SIZE[k]: 161 | random.shuffle(chunks[k]) 162 | write_to_shard( 163 | chunks[k][:SHARD_SIZE[k]], shard_writer[k]) 164 | chunks[k] = chunks[k][SHARD_SIZE[k]:] 165 | 166 | total_time = time.time() - start_time 167 | print( 168 | f"Process {my_id} found {num_entries} entries in {total_time} seconds", 169 | flush=True, 170 | ) 171 | 172 | # Write remaining shards 173 | for k in chunks.keys(): 174 | random.shuffle(chunks[k]) 175 | for i in range(0, len(chunks[k]), SHARD_SIZE[k]): 176 | if i + SHARD_SIZE[k] <= len(chunks[k] 177 | ): # Do not allow partial shards 178 | write_to_shard( 179 | chunks[k][i: i + SHARD_SIZE[k]], shard_writer[k]) 180 | 181 | print(f"Process {my_id} Done.", flush=True) 182 | 183 | 184 | def merge_process_dirs( 185 | output_dir: str, 186 | min_bucket: int, 187 | max_bucket: int, 188 | num_workers: int) -> None: 189 | """Merges multiple webdatasets into one for each bucket. 190 | 191 | :param output_dir: Path to a directory containing [num_workers] subdirectories. Each subdirectory is the output 192 | of a single process of tokenize-and-shuffle and contains subdirectories for different buckets. 193 | :param min_bucket: The index of the bucket with the shortest sequences. 194 | :param max_bucket: The index of the bucket with the longest sequences. 195 | :param num_workers: The number of processes used for parallel computation. 
196 | :return: None 197 | """ 198 | process_dirs = os.listdir(output_dir) 199 | for k in tqdm( 200 | range( 201 | min_bucket, 202 | max_bucket + 203 | 1), 204 | total=max_bucket - 205 | min_bucket + 206 | 1): 207 | wds_dirs = [os.path.join(output_dir, p, f"{k}") for p in process_dirs] 208 | 209 | transfer_map = {} 210 | global_index = 0 211 | for i, dir in enumerate(wds_dirs): 212 | tarfiles = [os.path.join(dir, file) for file in os.listdir( 213 | dir) if file.endswith(".tar")] 214 | for a_tar in tarfiles: 215 | dir_path = os.path.join(output_dir, f"D_{k}") 216 | if not os.path.exists(dir_path): 217 | os.makedirs(dir_path, exist_ok=True) 218 | target_file = os.path.join( 219 | dir_path, "shard-{:07d}.tar".format(global_index)) 220 | global_index += 1 221 | transfer_map[a_tar] = target_file 222 | 223 | with multiprocessing.Pool(processes=num_workers) as pool: 224 | pool.starmap( 225 | shutil.move, 226 | [(src, transfer_map[src]) for src in transfer_map], 227 | ) 228 | # Remove original subdirs 229 | for p in process_dirs: 230 | dir_path = os.path.join(output_dir, p) 231 | shutil.rmtree(dir_path) 232 | 233 | 234 | def tokenize_shuffle( 235 | input_files: List[str], 236 | output_dir: str, 237 | num_workers: int, 238 | min_bucket: int, 239 | max_bucket: int, 240 | ) -> None: 241 | """Performs dataset-decomposition tokenize-and-shuffle using multiple processes. 242 | 243 | :param input_files: A list of input data files. 244 | :param output_dir: The path to the output directory. 245 | :param num_workers: The number of processes to use for parallel computation. 246 | :param min_bucket: The index of the bucket containing the shortest sequences. 247 | :param max_bucket: The index of the bucket containing the longest sequences. 248 | :return: None 249 | """ 250 | input_files = [glob.glob(input_file) for input_file in input_files] 251 | input_files = [x for y in input_files for x in y] 252 | 253 | # Shuffle the input files 254 | random.shuffle(input_files) 255 | print("Number of input files = {}".format(len(input_files))) 256 | 257 | enc = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b") 258 | 259 | # assert len(input_files) % num_workers == 0 260 | files_per_worker = len(input_files) // num_workers 261 | file_groups = [input_files[x: x + files_per_worker] 262 | for x in range(0, len(input_files), files_per_worker)] 263 | 264 | with multiprocessing.Pool(processes=num_workers) as pool: 265 | pool.starmap( 266 | tokenize_and_shard, 267 | [ 268 | ( 269 | fg, 270 | my_id, 271 | os.path.join(output_dir, str(my_id)), 272 | enc, 273 | min_bucket, 274 | max_bucket, 275 | ) 276 | for my_id, fg in enumerate(file_groups) 277 | ], 278 | ) 279 | 280 | 281 | if __name__ == "__main__": 282 | parser = argparse.ArgumentParser() 283 | parser.add_argument( 284 | "--input-files", 285 | type=str, 286 | nargs="+", 287 | help="Set of input data files.") 288 | parser.add_argument( 289 | "--output-dir", 290 | type=Path, 291 | help="Path to output directory.") 292 | parser.add_argument( 293 | "--num-workers", 294 | type=int, 295 | default=32, 296 | help="Number of workers to use.") 297 | parser.add_argument( 298 | "--min-bucket", 299 | type=int, 300 | default=8, 301 | help="log2 of the shortest sequences.") 302 | parser.add_argument( 303 | "--max-bucket", 304 | type=int, 305 | default=13, 306 | help="log2 of the longest sequences.") 307 | 308 | args = parser.parse_args() 309 | 310 | tokenize_shuffle( 311 | args.input_files, 312 | args.output_dir, 313 | args.num_workers, 314 | args.min_bucket, 315 | 
args.max_bucket, 316 | ) 317 | merge_process_dirs( 318 | args.output_dir, 319 | args.min_bucket, 320 | args.max_bucket, 321 | args.num_workers) 322 | -------------------------------------------------------------------------------- /scripts/train_dd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=$PYTHONPATH:./open_lm 4 | torchrun --nproc-per-node 8 \ 5 | --nnodes 1 \ 6 | --node_rank 0 \ 7 | --max_restarts=0 \ 8 | --rdzv_backend c10d \ 9 | --rdzv_conf "timeout=3000,read_timeout=10000" \ 10 | -m open_lm.main \ 11 | --accum-freq 4 \ 12 | --global-batch-size 64 \ 13 | --beta1 0.9 \ 14 | --beta2 0.95 \ 15 | --data-key txt \ 16 | --ffn-type swiglu \ 17 | --fsdp \ 18 | --fsdp-limit-all-gathers \ 19 | --log-every-n-steps 32 \ 20 | --lr 0.003 \ 21 | --lr-cooldown-end 3e-5 \ 22 | --model open_lm_1b \ 23 | --name dd_open_lm_1b_$RANDOM \ 24 | --precision amp_bfloat16 \ 25 | --qk-norm \ 26 | --seed 42 \ 27 | --warmup 5000 \ 28 | --wd 0.033 \ 29 | --workers 1 \ 30 | --z-loss-coefficient 0.0001 \ 31 | --ignore-parse-errors \ 32 | --logs /mnt/open_lm_logs/ \ 33 | --dataset-manifest "/mnt/processed_datasets/dclm/D_8/manifest.jsonl" \ 34 | "/mnt/processed_datasets/dclm/D_9/manifest.jsonl" \ 35 | "/mnt/processed_datasets/dclm/D_10/manifest.jsonl" \ 36 | "/mnt/processed_datasets/dclm/D_11/manifest.jsonl" \ 37 | "/mnt/processed_datasets/dclm/D_12/manifest.jsonl" \ 38 | "/mnt/processed_datasets/dclm/D_13/manifest.jsonl" \ 39 | --epochs 8 \ 40 | --train-num-samples 3607101440 \ 41 | --dataset-batch-mult 32 16 8 4 2 1 \ 42 | --source-num-seq-per-epoch 1507328 1277952 794624 335872 137216 61440 \ 43 | --train-data-mix-weights 1472 1248 776 328 134 60 \ 44 | --fsdp-amp \ 45 | --grad-clip-norm 1 \ 46 | --attn-name xformers_attn \ 47 | --model-norm gain_only_lp_layer_norm \ 48 | --wandb-project-name dd_code \ 49 | --report-to wandb \ 50 | -------------------------------------------------------------------------------- /scripts/wiki_download.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | import os 8 | 9 | from datasets import load_dataset 10 | from tqdm import tqdm 11 | 12 | 13 | def main(output_dir: str, num_shards: int = 32): 14 | """Downloads the Wikipedia dataset. 15 | 16 | :param output_dir: The path to the output directory where the downloaded files will be saved. 17 | :param num_shards: The number of JSONL files to divide the downloaded data into. 18 | :return: None 19 | """ 20 | os.makedirs(output_dir, exist_ok=True) 21 | data = load_dataset("wikipedia", "20220301.en", trust_remote_code=True) 22 | 23 | for split, dataset in data.items(): 24 | for i in tqdm(range(num_shards)): 25 | dataset_shard = dataset.shard( 26 | num_shards=num_shards, index=i, contiguous=True) 27 | output_file = os.path.join( 28 | output_dir, f"wiki_en_20220301_{split}_{i}.jsonl") 29 | dataset_shard.to_json(output_file) 30 | 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument( 35 | "--output-dir", 36 | type=str, 37 | required=True, 38 | help="Where to store the wikipedia .jsonl files", 39 | ) 40 | 41 | args = parser.parse_args() 42 | main(args.output_dir) 43 | --------------------------------------------------------------------------------
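A minimal, self-contained sketch of the bucketing rule implemented by get_binary() in scripts/make_dd_buckets.py, shown here with randomize=False and a small bucket range chosen purely for readability (the script defaults are --min-bucket 8 and --max-bucket 13):

import math
from collections import defaultdict


def decompose(tokens, min_log2=2, max_log2=4):
    # Split one tokenized document into chunks of length 2**k + 1 with
    # min_log2 <= k <= max_log2, mirroring get_binary() with randomize=False.
    buckets = defaultdict(list)
    while len(tokens) > 1:
        k = int(math.log2(len(tokens) - 1))
        if k < min_log2:
            break  # leftover shorter than the smallest bucket is dropped
        k = min(k, max_log2)
        buckets[k].append(tokens[: (1 << k) + 1])
        tokens = tokens[(1 << k) + 1:]
    return buckets


# A 23-token document: the first 2**4 + 1 = 17 tokens go to bucket 4,
# the next 2**2 + 1 = 5 tokens go to bucket 2, and the final token is dropped.
doc = list(range(23))
print({k: [len(c) for c in v] for k, v in sorted(decompose(doc).items())})
# -> {2: [5], 4: [17]}

Each stored chunk carries 2**k + 1 tokens so that the patched sample_chunk() in open_lm.patch (inputs = chunk[:, :-1], targets = chunk[:, 1:]) yields exactly 2**k supervised positions per sequence. The recipe in scripts/train_dd.sh is consistent with the bookkeeping in scripts/get_dd_params.py: the per-bucket counts passed to --source-num-seq-per-epoch, multiplied by their bucket sequence lengths (256 through 8192), sum to the 3,607,101,440 tokens passed to --train-num-samples.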