├── .gitignore
├── ACKNOWLEDGMENTS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE.txt
├── README.md
├── assets
│   ├── dd-animation.gif
│   ├── sequence_length.png
│   └── throughput.png
├── open_lm.patch
└── scripts
    ├── dclm_download.py
    ├── get_dd_params.py
    ├── get_stats.py
    ├── make_dd_buckets.py
    ├── train_dd.sh
    └── wiki_download.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea*
2 |
--------------------------------------------------------------------------------
/ACKNOWLEDGMENTS:
--------------------------------------------------------------------------------
1 | Acknowledgements
2 | Portions of this Software may utilize the following copyrighted
3 | material, the use of which is hereby acknowledged.
4 |
5 | ------------------------------------------------
6 | OpenLM (open_lm)
7 |
8 | MIT License
9 |
10 | Copyright (c) 2023 mlfoundations
11 |
12 | Permission is hereby granted, free of charge, to any person obtaining a copy
13 | of this software and associated documentation files (the "Software"), to deal
14 | in the Software without restriction, including without limitation the rights
15 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16 | copies of the Software, and to permit persons to whom the Software is
17 | furnished to do so, subject to the following conditions:
18 |
19 | The above copyright notice and this permission notice shall be included in all
20 | copies or substantial portions of the Software.
21 |
22 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28 | SOFTWARE.
29 |
30 | ------------------------------------------------
31 | HuggingFace Datasets (datasets)
32 |
33 |
34 |
35 | Apache License
36 | Version 2.0, January 2004
37 | http://www.apache.org/licenses/
38 |
39 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
40 |
41 | 1. Definitions.
42 |
43 | "License" shall mean the terms and conditions for use, reproduction,
44 | and distribution as defined by Sections 1 through 9 of this document.
45 |
46 | "Licensor" shall mean the copyright owner or entity authorized by
47 | the copyright owner that is granting the License.
48 |
49 | "Legal Entity" shall mean the union of the acting entity and all
50 | other entities that control, are controlled by, or are under common
51 | control with that entity. For the purposes of this definition,
52 | "control" means (i) the power, direct or indirect, to cause the
53 | direction or management of such entity, whether by contract or
54 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
55 | outstanding shares, or (iii) beneficial ownership of such entity.
56 |
57 | "You" (or "Your") shall mean an individual or Legal Entity
58 | exercising permissions granted by this License.
59 |
60 | "Source" form shall mean the preferred form for making modifications,
61 | including but not limited to software source code, documentation
62 | source, and configuration files.
63 |
64 | "Object" form shall mean any form resulting from mechanical
65 | transformation or translation of a Source form, including but
66 | not limited to compiled object code, generated documentation,
67 | and conversions to other media types.
68 |
69 | "Work" shall mean the work of authorship, whether in Source or
70 | Object form, made available under the License, as indicated by a
71 | copyright notice that is included in or attached to the work
72 | (an example is provided in the Appendix below).
73 |
74 | "Derivative Works" shall mean any work, whether in Source or Object
75 | form, that is based on (or derived from) the Work and for which the
76 | editorial revisions, annotations, elaborations, or other modifications
77 | represent, as a whole, an original work of authorship. For the purposes
78 | of this License, Derivative Works shall not include works that remain
79 | separable from, or merely link (or bind by name) to the interfaces of,
80 | the Work and Derivative Works thereof.
81 |
82 | "Contribution" shall mean any work of authorship, including
83 | the original version of the Work and any modifications or additions
84 | to that Work or Derivative Works thereof, that is intentionally
85 | submitted to Licensor for inclusion in the Work by the copyright owner
86 | or by an individual or Legal Entity authorized to submit on behalf of
87 | the copyright owner. For the purposes of this definition, "submitted"
88 | means any form of electronic, verbal, or written communication sent
89 | to the Licensor or its representatives, including but not limited to
90 | communication on electronic mailing lists, source code control systems,
91 | and issue tracking systems that are managed by, or on behalf of, the
92 | Licensor for the purpose of discussing and improving the Work, but
93 | excluding communication that is conspicuously marked or otherwise
94 | designated in writing by the copyright owner as "Not a Contribution."
95 |
96 | "Contributor" shall mean Licensor and any individual or Legal Entity
97 | on behalf of whom a Contribution has been received by Licensor and
98 | subsequently incorporated within the Work.
99 |
100 | 2. Grant of Copyright License. Subject to the terms and conditions of
101 | this License, each Contributor hereby grants to You a perpetual,
102 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
103 | copyright license to reproduce, prepare Derivative Works of,
104 | publicly display, publicly perform, sublicense, and distribute the
105 | Work and such Derivative Works in Source or Object form.
106 |
107 | 3. Grant of Patent License. Subject to the terms and conditions of
108 | this License, each Contributor hereby grants to You a perpetual,
109 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
110 | (except as stated in this section) patent license to make, have made,
111 | use, offer to sell, sell, import, and otherwise transfer the Work,
112 | where such license applies only to those patent claims licensable
113 | by such Contributor that are necessarily infringed by their
114 | Contribution(s) alone or by combination of their Contribution(s)
115 | with the Work to which such Contribution(s) was submitted. If You
116 | institute patent litigation against any entity (including a
117 | cross-claim or counterclaim in a lawsuit) alleging that the Work
118 | or a Contribution incorporated within the Work constitutes direct
119 | or contributory patent infringement, then any patent licenses
120 | granted to You under this License for that Work shall terminate
121 | as of the date such litigation is filed.
122 |
123 | 4. Redistribution. You may reproduce and distribute copies of the
124 | Work or Derivative Works thereof in any medium, with or without
125 | modifications, and in Source or Object form, provided that You
126 | meet the following conditions:
127 |
128 | (a) You must give any other recipients of the Work or
129 | Derivative Works a copy of this License; and
130 |
131 | (b) You must cause any modified files to carry prominent notices
132 | stating that You changed the files; and
133 |
134 | (c) You must retain, in the Source form of any Derivative Works
135 | that You distribute, all copyright, patent, trademark, and
136 | attribution notices from the Source form of the Work,
137 | excluding those notices that do not pertain to any part of
138 | the Derivative Works; and
139 |
140 | (d) If the Work includes a "NOTICE" text file as part of its
141 | distribution, then any Derivative Works that You distribute must
142 | include a readable copy of the attribution notices contained
143 | within such NOTICE file, excluding those notices that do not
144 | pertain to any part of the Derivative Works, in at least one
145 | of the following places: within a NOTICE text file distributed
146 | as part of the Derivative Works; within the Source form or
147 | documentation, if provided along with the Derivative Works; or,
148 | within a display generated by the Derivative Works, if and
149 | wherever such third-party notices normally appear. The contents
150 | of the NOTICE file are for informational purposes only and
151 | do not modify the License. You may add Your own attribution
152 | notices within Derivative Works that You distribute, alongside
153 | or as an addendum to the NOTICE text from the Work, provided
154 | that such additional attribution notices cannot be construed
155 | as modifying the License.
156 |
157 | You may add Your own copyright statement to Your modifications and
158 | may provide additional or different license terms and conditions
159 | for use, reproduction, or distribution of Your modifications, or
160 | for any such Derivative Works as a whole, provided Your use,
161 | reproduction, and distribution of the Work otherwise complies with
162 | the conditions stated in this License.
163 |
164 | 5. Submission of Contributions. Unless You explicitly state otherwise,
165 | any Contribution intentionally submitted for inclusion in the Work
166 | by You to the Licensor shall be under the terms and conditions of
167 | this License, without any additional terms or conditions.
168 | Notwithstanding the above, nothing herein shall supersede or modify
169 | the terms of any separate license agreement you may have executed
170 | with Licensor regarding such Contributions.
171 |
172 | 6. Trademarks. This License does not grant permission to use the trade
173 | names, trademarks, service marks, or product names of the Licensor,
174 | except as required for reasonable and customary use in describing the
175 | origin of the Work and reproducing the content of the NOTICE file.
176 |
177 | 7. Disclaimer of Warranty. Unless required by applicable law or
178 | agreed to in writing, Licensor provides the Work (and each
179 | Contributor provides its Contributions) on an "AS IS" BASIS,
180 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
181 | implied, including, without limitation, any warranties or conditions
182 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
183 | PARTICULAR PURPOSE. You are solely responsible for determining the
184 | appropriateness of using or redistributing the Work and assume any
185 | risks associated with Your exercise of permissions under this License.
186 |
187 | 8. Limitation of Liability. In no event and under no legal theory,
188 | whether in tort (including negligence), contract, or otherwise,
189 | unless required by applicable law (such as deliberate and grossly
190 | negligent acts) or agreed to in writing, shall any Contributor be
191 | liable to You for damages, including any direct, indirect, special,
192 | incidental, or consequential damages of any character arising as a
193 | result of this License or out of the use or inability to use the
194 | Work (including but not limited to damages for loss of goodwill,
195 | work stoppage, computer failure or malfunction, or any and all
196 | other commercial damages or losses), even if such Contributor
197 | has been advised of the possibility of such damages.
198 |
199 | 9. Accepting Warranty or Additional Liability. While redistributing
200 | the Work or Derivative Works thereof, You may choose to offer,
201 | and charge a fee for, acceptance of support, warranty, indemnity,
202 | or other liability obligations and/or rights consistent with this
203 | License. However, in accepting such obligations, You may act only
204 | on Your own behalf and on Your sole responsibility, not on behalf
205 | of any other Contributor, and only if You agree to indemnify,
206 | defend, and hold each Contributor harmless for any liability
207 | incurred by, or claims asserted against, such Contributor by reason
208 | of your accepting any such warranty or additional liability.
209 |
210 | END OF TERMS AND CONDITIONS
211 |
212 | APPENDIX: How to apply the Apache License to your work.
213 |
214 | To apply the Apache License to your work, attach the following
215 | boilerplate notice, with the fields enclosed by brackets "[]"
216 | replaced with your own identifying information. (Don't include
217 | the brackets!) The text should be enclosed in the appropriate
218 | comment syntax for the file format. We also recommend that a
219 | file or class name and description of purpose be included on the
220 | same "printed page" as the copyright notice for easier
221 | identification within third-party archives.
222 |
223 | Copyright [yyyy] [name of copyright owner]
224 |
225 | Licensed under the Apache License, Version 2.0 (the "License");
226 | you may not use this file except in compliance with the License.
227 | You may obtain a copy of the License at
228 |
229 | http://www.apache.org/licenses/LICENSE-2.0
230 |
231 | Unless required by applicable law or agreed to in writing, software
232 | distributed under the License is distributed on an "AS IS" BASIS,
233 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
234 | See the License for the specific language governing permissions and
235 | limitations under the License.
236 |
237 | ------------------------------------------------
238 | HuggingFace Transformers (transformers)
239 |
240 | Copyright 2018- The Hugging Face team. All rights reserved.
241 |
242 | Apache License
243 | Version 2.0, January 2004
244 | http://www.apache.org/licenses/
245 |
246 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
247 |
248 | 1. Definitions.
249 |
250 | "License" shall mean the terms and conditions for use, reproduction,
251 | and distribution as defined by Sections 1 through 9 of this document.
252 |
253 | "Licensor" shall mean the copyright owner or entity authorized by
254 | the copyright owner that is granting the License.
255 |
256 | "Legal Entity" shall mean the union of the acting entity and all
257 | other entities that control, are controlled by, or are under common
258 | control with that entity. For the purposes of this definition,
259 | "control" means (i) the power, direct or indirect, to cause the
260 | direction or management of such entity, whether by contract or
261 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
262 | outstanding shares, or (iii) beneficial ownership of such entity.
263 |
264 | "You" (or "Your") shall mean an individual or Legal Entity
265 | exercising permissions granted by this License.
266 |
267 | "Source" form shall mean the preferred form for making modifications,
268 | including but not limited to software source code, documentation
269 | source, and configuration files.
270 |
271 | "Object" form shall mean any form resulting from mechanical
272 | transformation or translation of a Source form, including but
273 | not limited to compiled object code, generated documentation,
274 | and conversions to other media types.
275 |
276 | "Work" shall mean the work of authorship, whether in Source or
277 | Object form, made available under the License, as indicated by a
278 | copyright notice that is included in or attached to the work
279 | (an example is provided in the Appendix below).
280 |
281 | "Derivative Works" shall mean any work, whether in Source or Object
282 | form, that is based on (or derived from) the Work and for which the
283 | editorial revisions, annotations, elaborations, or other modifications
284 | represent, as a whole, an original work of authorship. For the purposes
285 | of this License, Derivative Works shall not include works that remain
286 | separable from, or merely link (or bind by name) to the interfaces of,
287 | the Work and Derivative Works thereof.
288 |
289 | "Contribution" shall mean any work of authorship, including
290 | the original version of the Work and any modifications or additions
291 | to that Work or Derivative Works thereof, that is intentionally
292 | submitted to Licensor for inclusion in the Work by the copyright owner
293 | or by an individual or Legal Entity authorized to submit on behalf of
294 | the copyright owner. For the purposes of this definition, "submitted"
295 | means any form of electronic, verbal, or written communication sent
296 | to the Licensor or its representatives, including but not limited to
297 | communication on electronic mailing lists, source code control systems,
298 | and issue tracking systems that are managed by, or on behalf of, the
299 | Licensor for the purpose of discussing and improving the Work, but
300 | excluding communication that is conspicuously marked or otherwise
301 | designated in writing by the copyright owner as "Not a Contribution."
302 |
303 | "Contributor" shall mean Licensor and any individual or Legal Entity
304 | on behalf of whom a Contribution has been received by Licensor and
305 | subsequently incorporated within the Work.
306 |
307 | 2. Grant of Copyright License. Subject to the terms and conditions of
308 | this License, each Contributor hereby grants to You a perpetual,
309 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
310 | copyright license to reproduce, prepare Derivative Works of,
311 | publicly display, publicly perform, sublicense, and distribute the
312 | Work and such Derivative Works in Source or Object form.
313 |
314 | 3. Grant of Patent License. Subject to the terms and conditions of
315 | this License, each Contributor hereby grants to You a perpetual,
316 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
317 | (except as stated in this section) patent license to make, have made,
318 | use, offer to sell, sell, import, and otherwise transfer the Work,
319 | where such license applies only to those patent claims licensable
320 | by such Contributor that are necessarily infringed by their
321 | Contribution(s) alone or by combination of their Contribution(s)
322 | with the Work to which such Contribution(s) was submitted. If You
323 | institute patent litigation against any entity (including a
324 | cross-claim or counterclaim in a lawsuit) alleging that the Work
325 | or a Contribution incorporated within the Work constitutes direct
326 | or contributory patent infringement, then any patent licenses
327 | granted to You under this License for that Work shall terminate
328 | as of the date such litigation is filed.
329 |
330 | 4. Redistribution. You may reproduce and distribute copies of the
331 | Work or Derivative Works thereof in any medium, with or without
332 | modifications, and in Source or Object form, provided that You
333 | meet the following conditions:
334 |
335 | (a) You must give any other recipients of the Work or
336 | Derivative Works a copy of this License; and
337 |
338 | (b) You must cause any modified files to carry prominent notices
339 | stating that You changed the files; and
340 |
341 | (c) You must retain, in the Source form of any Derivative Works
342 | that You distribute, all copyright, patent, trademark, and
343 | attribution notices from the Source form of the Work,
344 | excluding those notices that do not pertain to any part of
345 | the Derivative Works; and
346 |
347 | (d) If the Work includes a "NOTICE" text file as part of its
348 | distribution, then any Derivative Works that You distribute must
349 | include a readable copy of the attribution notices contained
350 | within such NOTICE file, excluding those notices that do not
351 | pertain to any part of the Derivative Works, in at least one
352 | of the following places: within a NOTICE text file distributed
353 | as part of the Derivative Works; within the Source form or
354 | documentation, if provided along with the Derivative Works; or,
355 | within a display generated by the Derivative Works, if and
356 | wherever such third-party notices normally appear. The contents
357 | of the NOTICE file are for informational purposes only and
358 | do not modify the License. You may add Your own attribution
359 | notices within Derivative Works that You distribute, alongside
360 | or as an addendum to the NOTICE text from the Work, provided
361 | that such additional attribution notices cannot be construed
362 | as modifying the License.
363 |
364 | You may add Your own copyright statement to Your modifications and
365 | may provide additional or different license terms and conditions
366 | for use, reproduction, or distribution of Your modifications, or
367 | for any such Derivative Works as a whole, provided Your use,
368 | reproduction, and distribution of the Work otherwise complies with
369 | the conditions stated in this License.
370 |
371 | 5. Submission of Contributions. Unless You explicitly state otherwise,
372 | any Contribution intentionally submitted for inclusion in the Work
373 | by You to the Licensor shall be under the terms and conditions of
374 | this License, without any additional terms or conditions.
375 | Notwithstanding the above, nothing herein shall supersede or modify
376 | the terms of any separate license agreement you may have executed
377 | with Licensor regarding such Contributions.
378 |
379 | 6. Trademarks. This License does not grant permission to use the trade
380 | names, trademarks, service marks, or product names of the Licensor,
381 | except as required for reasonable and customary use in describing the
382 | origin of the Work and reproducing the content of the NOTICE file.
383 |
384 | 7. Disclaimer of Warranty. Unless required by applicable law or
385 | agreed to in writing, Licensor provides the Work (and each
386 | Contributor provides its Contributions) on an "AS IS" BASIS,
387 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
388 | implied, including, without limitation, any warranties or conditions
389 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
390 | PARTICULAR PURPOSE. You are solely responsible for determining the
391 | appropriateness of using or redistributing the Work and assume any
392 | risks associated with Your exercise of permissions under this License.
393 |
394 | 8. Limitation of Liability. In no event and under no legal theory,
395 | whether in tort (including negligence), contract, or otherwise,
396 | unless required by applicable law (such as deliberate and grossly
397 | negligent acts) or agreed to in writing, shall any Contributor be
398 | liable to You for damages, including any direct, indirect, special,
399 | incidental, or consequential damages of any character arising as a
400 | result of this License or out of the use or inability to use the
401 | Work (including but not limited to damages for loss of goodwill,
402 | work stoppage, computer failure or malfunction, or any and all
403 | other commercial damages or losses), even if such Contributor
404 | has been advised of the possibility of such damages.
405 |
406 | 9. Accepting Warranty or Additional Liability. While redistributing
407 | the Work or Derivative Works thereof, You may choose to offer,
408 | and charge a fee for, acceptance of support, warranty, indemnity,
409 | or other liability obligations and/or rights consistent with this
410 | License. However, in accepting such obligations, You may act only
411 | on Your own behalf and on Your sole responsibility, not on behalf
412 | of any other Contributor, and only if You agree to indemnify,
413 | defend, and hold each Contributor harmless for any liability
414 | incurred by, or claims asserted against, such Contributor by reason
415 | of your accepting any such warranty or additional liability.
416 |
417 | END OF TERMS AND CONDITIONS
418 |
419 | APPENDIX: How to apply the Apache License to your work.
420 |
421 | To apply the Apache License to your work, attach the following
422 | boilerplate notice, with the fields enclosed by brackets "[]"
423 | replaced with your own identifying information. (Don't include
424 | the brackets!) The text should be enclosed in the appropriate
425 | comment syntax for the file format. We also recommend that a
426 | file or class name and description of purpose be included on the
427 | same "printed page" as the copyright notice for easier
428 | identification within third-party archives.
429 |
430 | Copyright [yyyy] [name of copyright owner]
431 |
432 | Licensed under the Apache License, Version 2.0 (the "License");
433 | you may not use this file except in compliance with the License.
434 | You may obtain a copy of the License at
435 |
436 | http://www.apache.org/licenses/LICENSE-2.0
437 |
438 | Unless required by applicable law or agreed to in writing, software
439 | distributed under the License is distributed on an "AS IS" BASIS,
440 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
441 | See the License for the specific language governing permissions and
442 | limitations under the License.
443 |
444 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 |
3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
4 |
5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
6 |
7 | ## Before you get started
8 |
9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE.txt).
10 |
11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
12 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (C) 2024 Apple Inc. All Rights Reserved.
2 |
3 | IMPORTANT: This Apple software is supplied to you by Apple
4 | Inc. ("Apple") in consideration of your agreement to the following
5 | terms, and your use, installation, modification or redistribution of
6 | this Apple software constitutes acceptance of these terms. If you do
7 | not agree with these terms, please do not use, install, modify or
8 | redistribute this Apple software.
9 |
10 | In consideration of your agreement to abide by the following terms, and
11 | subject to these terms, Apple grants you a personal, non-exclusive
12 | license, under Apple's copyrights in this original Apple software (the
13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple
14 | Software, with or without modifications, in source and/or binary forms;
15 | provided that if you redistribute the Apple Software in its entirety and
16 | without modifications, you must retain this notice and the following
17 | text and disclaimers in all such redistributions of the Apple Software.
18 | Neither the name, trademarks, service marks or logos of Apple Inc. may
19 | be used to endorse or promote products derived from the Apple Software
20 | without specific prior written permission from Apple. Except as
21 | expressly stated in this notice, no other rights or licenses, express or
22 | implied, are granted by Apple herein, including but not limited to any
23 | patent rights that may be infringed by your derivative works or by other
24 | works in which the Apple Software may be incorporated.
25 |
26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31 |
32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39 | POSSIBILITY OF SUCH DAMAGE.
40 |
41 | -------------------------------------------------------------------------------
42 | SOFTWARE DISTRIBUTED WITH ML-Dataset-Decomposition:
43 |
44 | The ML-Dataset-Decomposition software includes a number of subcomponents with separate
45 | copyright notices and license terms - please see the file ACKNOWLEDGMENTS.
46 | -------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dataset-Decomposition
2 |
3 | This repository contains the implementation of the [Dataset Decomposition NeurIPS 2024 paper](https://arxiv.org/abs/2405.13226).
4 | The training code is based on the [OpenLM repository](https://github.com/mlfoundations/open_lm), as explained below.
5 |
6 | Dataset decomposition enables **fast pre-training for long-context** inputs by organizing documents into buckets based on their length.
7 | Additionally, a **length-based curriculum** can be applied (starting with short sequences and gradually progressing to longer ones)
8 | to achieve **improved performance** on both regular and long-context benchmarks.
9 |
10 | ![Dataset decomposition animation](assets/dd-animation.gif)
11 |
12 | ## Table of Contents
13 |
14 | - [Setup](#setup)
15 | - [Create Decomposed Datasets](#create-decomposed-datasets)
16 | - [Launch Variable Sequence Length Training](#launch-variable-sequence-length-training)
17 | - [Results](#results)
18 | - [Citation](#citation)
19 |
20 | ## Setup
21 |
22 | Clone [OpenLM](https://github.com/mlfoundations/open_lm) and apply our patch to enable variable sequence length training.
23 | Install the requirements as instructed in the [OpenLM repository](https://github.com/mlfoundations/open_lm).
24 | Then, from the root of this repo, perform the following steps:
25 | ```shell
26 | git clone https://github.com/mlfoundations/open_lm.git
27 | cd open_lm
28 | git checkout 9bb92ef1689333534b7057942a20d18a46d1fa52
29 | git apply ../open_lm.patch
30 | # Install dependencies as required by OpenLM
31 | cd ..
32 | ```
33 |
34 | ## Create Decomposed Datasets
35 |
36 | Dataset decomposition is a per-document method and is applicable to any dataset.
37 | Here, we show an example for small datasets in the form of JSONL files.
38 |
39 | ### Step 1:
40 | Get some data. Make sure to upgrade the [datasets library](https://pypi.org/project/datasets) (we use version 3.1).
41 | ```shell
42 | mkdir -p /mnt/raw_datasets/wiki
43 | python scripts/wiki_download.py --output-dir /mnt/raw_datasets/wiki
44 | ```
45 | Once the download is complete, you will have 32 JSONL files.
46 | Alternatively, you can run [scripts/dclm_download.py](scripts/dclm_download.py) to
47 | download a small portion of the [DCLM dataset](https://www.datacomp.ai/dclm/index.html#home).
48 |
49 | ### Step 2:
50 | Run tokenize+bucketize+shuffle:
51 | ```shell
52 | mkdir -p /mnt/processed_datasets/wiki
53 | python scripts/make_dd_buckets.py --input-files /mnt/raw_datasets/wiki/*.jsonl \
54 | --output-dir /mnt/processed_datasets/wiki --min-bucket 8 --max-bucket 13 --num-workers 32
55 | ```
56 | We use `32` workers here. You can increase this number for faster processing if you have more JSONL files.
57 | The `--min-bucket` and `--max-bucket` parameters determine the range of buckets for dataset decomposition.
58 | For the example above, buckets will be created for sequences with lengths `2^8=256`, `2^9=512`, ..., `2^13=8192`.
59 |
60 | When this step completes, each bucket will contain multiple shards.
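
For intuition, below is a minimal sketch of the per-document decomposition this step performs: each tokenized document is split greedily into chunks whose lengths are powers of two within the bucket range, and each chunk is routed to the bucket matching its length. This is an illustration only; the actual logic lives in [scripts/make_dd_buckets.py](scripts/make_dd_buckets.py) and may differ in details (for example, how leftover tokens shorter than the smallest bucket are handled).
```python
# Illustrative sketch only -- see scripts/make_dd_buckets.py for the real implementation.
def decompose_document(token_ids, min_bucket=8, max_bucket=13):
    """Split one tokenized document into power-of-two-length chunks.

    A chunk of length 2**i goes to bucket D_i; in this simplified version, leftover
    tokens at the tail that are shorter than 2**min_bucket are dropped.
    """
    buckets = {i: [] for i in range(min_bucket, max_bucket + 1)}
    pos, remaining = 0, len(token_ids)
    while remaining >= 2 ** min_bucket:
        i = min(max_bucket, remaining.bit_length() - 1)  # largest power of two that fits
        size = 2 ** i
        buckets[i].append(token_ids[pos : pos + size])
        pos += size
        remaining -= size
    return buckets
```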
61 |
62 | ### Step 3:
63 | Create the WebDataset manifest files, one per bucket; the dataloader uses these manifests to index each bucket's shards during training.
64 | ```shell
65 | for i in $(seq 8 13);
66 | do
67 | python open_lm/open_lm/utils/make_wds_manifest.py --data-dir /mnt/processed_datasets/wiki/D_$i --num-workers 16
68 | done
69 | ```
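
Each bucket's manifest is a JSONL file with one record per shard. If you want to inspect one directly, a quick tally could look like the sketch below. The field names (`shard`, `num_sequences`/`num_chunks`) follow what the patched OpenLM loader reads; the `manifest.jsonl` filename and path are assumptions for illustration.
```python
import json

# Rough inspection of one bucket's manifest (filename and path are assumptions).
with open("/mnt/processed_datasets/wiki/D_10/manifest.jsonl") as f:
    entries = [json.loads(line) for line in f if line.strip()]

num_sequences = sum(e.get("num_sequences", e.get("num_chunks", 0)) for e in entries)
print(f"{len(entries)} shards, {num_sequences:,} sequences of length 1024 in bucket D_10")
```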
70 | ### Step 4 (Optional)
71 | Get stats of the (decomposed) dataset you just created:
72 | ```shell
73 | python scripts/get_stats.py --dd-dir /mnt/processed_datasets/wiki
74 | ```
75 | With the above bucket sizes, we obtain the following statistics for the Wikipedia and DCLM datasets:
76 | #### Wikipedia
77 | ```text
78 | D_8 : # shards: 553 seq-length: 256 # sequences: 2,265,088 # tokens: 579,862,528
79 | D_9 : # shards: 779 seq-length: 512 # sequences: 1,595,392 # tokens: 816,840,704
80 | D_10: # shards: 831 seq-length: 1,024 # sequences: 850,944 # tokens: 871,366,656
81 | D_11: # shards: 690 seq-length: 2,048 # sequences: 353,280 # tokens: 723,517,440
82 | D_12: # shards: 475 seq-length: 4,096 # sequences: 121,600 # tokens: 498,073,600
83 | D_13: # shards: 291 seq-length: 8,192 # sequences: 37,248 # tokens: 305,135,616
84 | ********************
85 | Total number of tokens = 3,794,796,544
86 | ```
87 | #### DCLM-Baseline subset
88 | ```text
89 | D_8 : # shards: 3,560 seq-length: 256 # sequences: 14,581,760 # tokens: 3,732,930,560
90 | D_9 : # shards: 5,996 seq-length: 512 # sequences: 12,279,808 # tokens: 6,287,261,696
91 | D_10: # shards: 7,410 seq-length: 1,024 # sequences: 7,587,840 # tokens: 7,769,948,160
92 | D_11: # shards: 6,309 seq-length: 2,048 # sequences: 3,230,208 # tokens: 6,615,465,984
93 | D_12: # shards: 5,157 seq-length: 4,096 # sequences: 1,320,192 # tokens: 5,407,506,432
94 | D_13: # shards: 4,513 seq-length: 8,192 # sequences: 577,664 # tokens: 4,732,223,488
95 | ********************
96 | Total number of tokens = 34,545,336,320
97 | ```
98 |
99 | ## Launch Variable Sequence Length Training
100 |
101 | ### Step 1
102 | Modify a run script with your desired hyperparameters and the path to the dataset.
103 | Refer to the Appendix of [the paper](https://arxiv.org/abs/2405.13226) for the full list of hyperparameters.
104 | The dataset path can be either local or on S3.
105 |
106 | ### Step 2
107 | To set the dataset-decomposition parameters, you can use the following [helper code](scripts/get_dd_params.py) (or set them manually).
108 | For example, the parameters below configure training with a total of roughly `29` billion tokens, `8` epochs/cycles, `8` GPUs, and a global batch size
109 | of `64*8192` tokens, with bucket sequence lengths ranging from `256` to `8192`
110 | (so one global batch contains `64` sequences when drawn from the longest bucket, whose sequence length is `8192`):
111 | ```shell
112 | python scripts/get_dd_params.py \
113 | --tokens 28795904000 \
114 | --epochs 8 \
115 | --gpus 8 \
116 | --global-batch-size 64 \
117 | --number-of-shards 3560 5996 7410 6309 5157 4513 \
118 | --sequence-per-shard 4096 2048 1024 512 256 128 \
119 | --sequence_sizes 256 512 1024 2048 4096 8192 \
120 | --batch-mult 32 16 8 4 2 1 \
121 | --train-data-mix-weights 32 16 8 4 2 1
122 | ```
123 |
124 | Here is a short description of each input argument:
125 |
126 | - `--tokens`: Total number of tokens to be processed.
127 | - `--epochs`: Number of cycles (also determines the number of checkpoints to save).
128 | - `--gpus`: Total number of GPUs.
129 | - `--global-batch-size`: Global batch size (assuming all sequences are of the longest length, e.g., 8192 here).
130 | - `--number-of-shards`: Number of available shards per bucket.
131 | - `--sequence-per-shard`: Number of sequences per shard per bucket.
132 | - `--sequence_sizes`: Length of sequences in each bucket.
133 | - `--batch-mult`: Batch multipliers that keep the number of tokens per optimization step fixed regardless of sequence length (see the check after this list).
134 | - `--train-data-mix-weights`: Power-of-2 length-based curriculum (prioritizing shorter sequences first).
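
The batch multipliers keep the token count per optimization step constant across buckets; here is a minimal check with the example values above:
```python
# Verify the batch-multiplier invariant for the example configuration above.
seq_sizes = [256, 512, 1024, 2048, 4096, 8192]
batch_mult = [32, 16, 8, 4, 2, 1]
global_batch_size = 64  # sequences, counted as if all were of the longest length (8192)

tokens_per_step = [m * global_batch_size * s for m, s in zip(batch_mult, seq_sizes)]
assert len(set(tokens_per_step)) == 1  # 524,288 tokens (= 64 * 8192) from every bucket
```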
135 |
136 | Running the helper script produces the following output:
137 | ```text
138 | **** Use the following arguments:
139 | --epochs 8
140 | --train-num-samples 3607101440
141 | --dataset-batch-mult 32 16 8 4 2 1
142 | --source-num-seq-per-epoch 1507328 1277952 794624 335872 137216 61440
143 | --train-data-mix-weights 1472 1248 776 328 134 60
144 | ```
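
As a sanity check on these numbers: the per-epoch sequence counts multiplied by their sequence lengths sum to the reported `--train-num-samples` value, and over `8` epochs land slightly above the requested `--tokens` (presumably because the plan is rounded to whole shards):
```python
# Reproduce the token accounting from the helper output above.
seq_per_epoch = [1507328, 1277952, 794624, 335872, 137216, 61440]
seq_sizes = [256, 512, 1024, 2048, 4096, 8192]

tokens_per_epoch = sum(n * s for n, s in zip(seq_per_epoch, seq_sizes))
print(tokens_per_epoch)      # 3607101440 -> equals --train-num-samples
print(tokens_per_epoch * 8)  # 28856811520 -> slightly above --tokens 28795904000
```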
145 |
146 | ### Step 3:
147 | Update the [run script](scripts/train_dd.sh) with the above parameters, and launch the training.
148 | Ensure you log in to WandB (or disable WandB reporting) before running the script.
149 |
150 | ```shell
151 | bash scripts/train_dd.sh
152 | ```
153 | The above set of hyperparameters corresponds to [DCLM-Baseline 1B-1x](https://github.com/mlfoundations/dclm)
154 | with a maximum sequence length of `8192`.
155 | To extend beyond `8192`, make sure to update the model configuration files (located in `open_lm/model_configs`).
156 |
157 | On a node with 8 H100 GPUs, the above training should take less than 28 hours.
158 | For this example, the model performance is as follows:
159 |
160 | | ArcE | ArcC | Hellaswag | LamOAI | Winogrande | Winograd | WikiQA | OBQA | SQuAD | PIQA | COPA | CoQA | BoolQ |
161 | |--------|--------|-----------|--------|------------|----------|--------|-------|-------|-------|-------|-------|-------|
162 | | 65.0 | 35.5 | 57.8 | 61.0 | 58.9 | 75.5 | 52.5 | 39.8 | 35.3 | 73.4 | 72.0 | 31.8 | 61.7 |
163 |
164 | For this example, the number of tokens processed per second per H100 GPU and the sampled sequence length
165 | over the course of training are shown below.
166 |
167 | ![Training throughput (tokens per second per H100 GPU)](assets/throughput.png)
168 |
169 | ![Sampled sequence length over the course of training](assets/sequence_length.png)
170 |
171 | ## Results
172 | Please see the full list of ablations in [the paper](https://arxiv.org/abs/2405.13226).
173 | The following table summarizes some billion-scale results:
174 |
175 | - **RW** refers to the [RefinedWeb dataset](https://huggingface.co/datasets/tiiuae/falcon-refinedweb).
176 | - **DCLM** refers to the [DCLM-Baseline dataset](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0).
177 | - For models with **SFT**, we follow the same setup as DCLM-Baseline.
178 | - **DD** refers to Dataset Decomposition, and **C&C** refers to concat-and-chunk.
179 |
180 | All models are trained with a context length of `8192` and a total of `2^40` seen tokens (~1.1 trillion tokens).
181 |
182 | | **Model** | Dataset | Method | SFT | MMLU | ArcE | ArcC | Hellaswag | LambadaOAI | Winogrande | Winograd | WikiQA | OBQA | SQuAD | PIQA | COPA | CoQA | BoolQ |
183 | |-----------------|---------|--------|-----|------|------|------|-----------|------------|------------|----------|--------|------|-------|------|------|------|-------|
184 | | *Shots* | N/A | N/A | N/A | *5* | *3* | *3* | *0* | *0* | *5* | *3* | *3* | *10* | *3* | *0* | *0* | *0* | *0* |
185 | | **Random** | 0 | N/A | N/A | 25 | 25 | 25 | 25 | 0 | 50 | 50 | 0 | 25 | 0 | 50 | 50 | 0 | 50 |
186 | | **OpenLM 160m** | DCLM | DD | No | 24.5 | 51.9 | 26 | 40 | 44.2 | 52.5 | 65.6 | 42.4 | 33 | 17.8 | 68.4 | 61 | 19.9 | 58.8 |
187 | | **OpenLM 160m** | DCLM | DD | Yes | 25.6 | 52.9 | 27.6 | 39 | 39.9 | 50 | 65.6 | 36.2 | 31.4 | 36.2 | 66.1 | 62 | 29.3 | 49.1 |
188 | | **OpenLM 410m** | RW | C&C | No | 24.8 | 53.6 | 26.6 | 52.7 | 50.5 | 56.7 | 70.7 | 52.6 | 35.6 | 25.5 | 71.3 | 69 | 26.9 | 54.1 |
189 | | **OpenLM 410m** | RW | DD | No | 27 | 55.3 | 27.9 | 55.1 | 53.9 | 59 | 74.4 | 56.3 | 35 | 30.1 | 72.6 | 63 | 28.1 | 62.7 |
190 | | **OpenLM 410m** | DCLM | DD | No | 24.9 | 62.4 | 33.9 | 55.9 | 57.2 | 59.9 | 77.7 | 55.3 | 38.8 | 32 | 73.4 | 68 | 31.3 | 56.2 |
191 | | **OpenLM 410m** | DCLM | DD | Yes | 34.8 | 63.3 | 35.4 | 53.5 | 52.9 | 58.7 | 74.4 | 50.1 | 38.4 | 49.4 | 73.2 | 67 | 39.8 | 72.2 |
192 | | **OpenLM 1B** | DCLM | DD | No | 28.6 | 70.6 | 43.2 | 68.9 | 67.6 | 67.6 | 85.7 | 62.9 | 44.2 | 47.6 | 77.1 | 77 | 39.9 | 58.7 |
193 | | **OpenLM 1B** | DCLM | DD | Yes | 49.1 | 70.7 | 43.1 | 68.6 | 61 | 66.3 | 78.4 | 56.8 | 45 | 57.1 | 77 | 80 | 46.5 | 80.7 |
194 |
195 | ## Citation
196 | If you like our work, please consider citing [our NeurIPS 2024 paper](https://arxiv.org/abs/2405.13226):
197 | ```bibtex
198 | @article{pouransari2024dataset,
199 | title={Dataset Decomposition: Faster LLM Training with Variable Sequence Length Curriculum},
200 | author={Pouransari, Hadi and Li, Chun-Liang and Chang, Jen-Hao Rick and Vasu, Pavan Kumar Anasosalu and Koc, Cem and Shankar, Vaishaal and Tuzel, Oncel},
201 | journal={arXiv preprint arXiv:2405.13226},
202 | year={2024},
203 | url={https://arxiv.org/abs/2405.13226}
204 | }
205 | ```
206 |
--------------------------------------------------------------------------------
/assets/dd-animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-dataset-decomposition/9a23939205bd371184b9e627a7abc9e9e9e07a4b/assets/dd-animation.gif
--------------------------------------------------------------------------------
/assets/sequence_length.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-dataset-decomposition/9a23939205bd371184b9e627a7abc9e9e9e07a4b/assets/sequence_length.png
--------------------------------------------------------------------------------
/assets/throughput.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-dataset-decomposition/9a23939205bd371184b9e627a7abc9e9e9e07a4b/assets/throughput.png
--------------------------------------------------------------------------------
/open_lm.patch:
--------------------------------------------------------------------------------
1 | diff --git a/open_lm/data.py b/open_lm/data.py
2 | index 107ff6e..05f6d08 100644
3 | --- a/open_lm/data.py
4 | +++ b/open_lm/data.py
5 | @@ -38,6 +38,34 @@ from webdataset.tariterators import (
6 | )
7 | from webdataset.mix import RandomMix
8 |
9 | +class MyRandomMix(RandomMix):
10 | + def __init__(self, datasets, probs=None, longest=False, seed=42):
11 | + super().__init__(datasets, probs=probs, longest=longest)
12 | + self.rng = random.Random()
13 | + self.rng.seed(seed)
14 | +
15 | + def __iter__(self):
16 | + """Return an iterator over the sources."""
17 | + sources = [iter(d) for d in self.datasets]
18 | + return self.random_samples(sources, self.probs)
19 | +
20 | + def random_samples(self, sources, probs=None):
21 | + if probs is None:
22 | + probs = [1] * len(sources)
23 | + else:
24 | + probs = list(probs)
25 | + while len(sources) > 0:
26 | + cum = (np.array(probs) / np.sum(probs)).cumsum()
27 | + r = self.rng.random()
28 | + i = np.searchsorted(cum, r)
29 | + try:
30 | + yield next(sources[i])
31 | + except StopIteration:
32 | + if self.longest:
33 | + del sources[i]
34 | + del probs[i]
35 | + else:
36 | + break
37 |
38 | def seed_worker(worker_id):
39 | worker_seed = torch.initial_seed() % 2**32
40 | @@ -344,7 +372,7 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke
41 | all_num_samples = []
42 |
43 | shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc
44 | - for ii, input_shards in enumerate(input_shards_):
45 | + for ii, (input_shards, batch_mult) in enumerate(zip(input_shards_, args.dataset_batch_mult)):
46 | resampled = getattr(args, "dataset_resampled", False) and is_train
47 | num_shards = None
48 | if is_train:
49 | @@ -421,7 +449,7 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke
50 | )
51 |
52 | map_handler = {"handler": log_and_continue} if args.ignore_parse_errors else {}
53 | - batch_size = args.per_gpu_batch_size if is_train else args.per_gpu_val_batch_size
54 | + batch_size = int(batch_mult * args.per_gpu_batch_size) if is_train else args.per_gpu_val_batch_size
55 |
56 | if data_key == "json" or data_key == "json.gz":
57 | pipeline.extend(
58 | @@ -430,7 +458,6 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke
59 | wds.rename(json=data_key),
60 | wds.map_dict(json=partial(preprocess_json, vocab_size=args.vocab_size), **map_handler),
61 | wds.to_tuple("json", **map_handler),
62 | - wds.select(partial(filter_lt_seqlen, args.seq_len)),
63 | wds.batched(batch_size, partial=not is_train),
64 | ]
65 | )
66 | @@ -439,7 +466,6 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke
67 | [
68 | wds.map_dict(txt=partial(preprocess_txt, vocab_size=args.vocab_size), **map_handler),
69 | wds.to_tuple("txt", **map_handler),
70 | - wds.select(partial(filter_lt_seqlen, args.seq_len)),
71 | wds.batched(batch_size, partial=not is_train),
72 | ]
73 | )
74 | @@ -451,8 +477,8 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke
75 | all_num_samples.append(num_samples)
76 |
77 | if is_train:
78 | - # TODO: why did we previoulsy wrap with RandomMix_
79 | - dataset = RandomMix(datasets, probs=args.train_data_mix_weights, longest=True)
80 | + # Use our RandomMix with determined random seed to make sure all nodes choose the same bucket.
81 | + dataset = MyRandomMix(datasets, probs=args.train_data_mix_weights, longest=True, seed=args.seed)
82 | if len(datasets) > 1:
83 | logging.warning("Source mixing is happening during training. It is preferred to mix during tokenization.")
84 | else:
85 | @@ -461,17 +487,18 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke
86 | # dataset = datasets[0]
87 | if is_train:
88 | if not resampled:
89 | - num_shards = num_shards or len(expand_urls(input_shards)[0])
90 | - if num_shards < args.workers * args.world_size:
91 | + shards_per_source_avail = [len(expand_urls(shard_string)[0]) for shard_string in input_shards_]
92 | + print(f"Number of shards available from each source = {shards_per_source_avail}")
93 | + min_num_shards = min(shards_per_source_avail)
94 | + if min_num_shards < args.workers * args.world_size:
95 | print("Please increase --train-num-samples or decrease workers or world size")
96 | - print(f"num_shards: {num_shards}, workers: {args.workers}, world_size: {args.world_size}")
97 | - assert num_shards >= args.workers * args.world_size, "number of shards must be >= total workers"
98 | - # roll over and repeat a few samples to get same number of full batches on each node
99 | + print(f"min num_shards: {min_num_shards}, workers: {args.workers}, world_size: {args.world_size}")
100 | + assert min_num_shards >= args.workers * args.world_size, "number of shards must be >= total workers"
101 | round_fn = math.floor if floor else math.ceil
102 | - global_batch_size = batch_size * args.world_size
103 | total_num_batches = 0
104 | total_num_samples = 0
105 | - for ii in range(len(datasets)):
106 | + for ii, batch_mult in enumerate(args.dataset_batch_mult):
107 | + global_batch_size = int(batch_mult * args.global_batch_size)
108 | # Calculate batches per worker, round as little as possible.
109 | num_workers_per_gpu = max(1, args.workers)
110 | num_worker_batches = round_fn(all_num_samples[ii] / (global_batch_size * num_workers_per_gpu))
111 | @@ -484,7 +511,7 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke
112 | )
113 |
114 | num_batches = num_worker_batches * num_workers_per_gpu
115 | - num_samples = num_batches * global_batch_size
116 | + num_samples = num_batches * args.global_batch_size # Number of sequences as if all were the longest (8k)
117 |
118 | # This forces the dataloader to take num_worker_batches steps per worker, so num_batches total.
119 | datasets[ii] = datasets[ii].repeat(nepochs=1, nbatches=num_worker_batches)
120 | @@ -704,18 +731,6 @@ def mask_sequence(chunk, start_idx, args, ignore_tok=-100):
121 |
122 |
123 | def sample_chunk(chunk, args):
124 | - if chunk.shape[1] == args.seq_len + 1:
125 | - start_idx = 0
126 | - elif chunk.shape[1] > args.seq_len + 1:
127 | - start_idx = torch.randint(0, chunk.shape[1] - args.seq_len, (1,)).item()
128 | - else:
129 | - raise Exception(f"Invalid sequence length: Sequence length {args.seq_len} > {chunk.shape[1]} Chunk size")
130 | -
131 | - inputs = chunk[:, start_idx : start_idx + args.seq_len]
132 | - targets = chunk[:, start_idx + 1 : start_idx + args.seq_len + 1]
133 | -
134 | - # replace elements to be masked with with -100 (pytorch default xent ignore value)
135 | - if args.target_mask_left is not None or args.target_mask_individual is not None:
136 | - inputs, targets = mask_sequence(chunk, start_idx, args)
137 | -
138 | + inputs = chunk[:, :-1]
139 | + targets = chunk[:, 1:]
140 | return inputs, targets
141 | diff --git a/open_lm/file_utils.py b/open_lm/file_utils.py
142 | index f91919b..fe729fa 100644
143 | --- a/open_lm/file_utils.py
144 | +++ b/open_lm/file_utils.py
145 | @@ -134,14 +134,22 @@ def check_exists(file_path):
146 | return True
147 |
148 |
149 | -def get_metadata_file(path, shard_shuffle_seed=None):
150 | +def get_metadata_file(path, shard_shuffle_seed=None, append_a_copy=4):
151 | of = fsspec.open(path, "rb")
152 | with of as f:
153 | out = f.read()
154 | out = [json.loads(o) for o in out.decode("utf-8").split("\n")[:-1]]
155 | + if append_a_copy > 0:
156 | + out_copy = [copy.deepcopy(out) for _ in range(append_a_copy)]
157 | if shard_shuffle_seed is not None:
158 | rng_gen = np.random.default_rng(shard_shuffle_seed)
159 | rng_gen.shuffle(out)
160 | + if append_a_copy > 0:
161 | + for a_copy in out_copy:
162 | + rng_gen.shuffle(a_copy)
163 | + if append_a_copy > 0:
164 | + for a_copy in out_copy:
165 | + out = out + a_copy
166 | return out
167 |
168 |
169 | @@ -218,7 +226,7 @@ def count_small_shards(path, ratio=0.9):
170 |
171 | shard_sizes = np.array(shard_sizes)
172 |
173 | - return np.sum(shard_sizes < ratio * max(shard_sizes))
174 | + return np.sum(shard_sizes < ratio * max(shard_sizes)), max(shard_sizes)
175 |
176 |
177 | def are_sources_imbalanced_with_each_other(paths, ratio=2):
178 | @@ -262,9 +270,11 @@ def log_num_checkpoints(total_steps, args):
179 | args.world_size,
180 | multi_epoch=args.multiple_data_passes,
181 | shard_shuffle_seed=args.shard_shuffle_seed,
182 | + source_num_seq_per_epoch=args.source_num_seq_per_epoch,
183 | )
184 | steps_epoch = sum(
185 | - [(n // (args.workers * args.global_batch_size)) * args.workers for n in num_samples_per_source]
186 | + [(n // (args.workers * args.global_batch_size * batch_mult)) * args.workers for n, batch_mult in
187 | + zip(num_samples_per_source, args.dataset_batch_mult)]
188 | )
189 | steps_done += steps_epoch
190 | if steps_done > total_steps:
191 | @@ -300,15 +310,18 @@ def get_string_for_epoch(
192 | world_size: int,
193 | multi_epoch=False,
194 | shard_shuffle_seed=None,
195 | + source_num_seq_per_epoch=None,
196 | ):
197 | """See _single_epoch_string for full docstring."""
198 | if multi_epoch:
199 | return _multi_epoch_string(
200 | - num_samples, starting_points, paths, weights, num_workers_per_gpu, world_size, shard_shuffle_seed
201 | + num_samples, starting_points, paths, weights, num_workers_per_gpu, world_size, shard_shuffle_seed,
202 | + source_num_seq_per_epoch
203 | )
204 | else:
205 | return _single_epoch_string(
206 | - num_samples, starting_points, paths, weights, num_workers_per_gpu, world_size, shard_shuffle_seed
207 | + num_samples, starting_points, paths, weights, num_workers_per_gpu, world_size, shard_shuffle_seed,
208 | + source_num_seq_per_epoch
209 | )
210 |
211 |
212 | @@ -370,6 +383,7 @@ def _single_epoch_string(
213 | num_workers_per_gpu: int,
214 | world_size: int,
215 | shard_shuffle_seed: Optional[int],
216 | + source_num_seq_per_epoch: Optional[List[int]] = None,
217 | ):
218 | """Retrieve shards to train on for a particular checkpoint.
219 |
220 | @@ -383,38 +397,25 @@ def _single_epoch_string(
221 | num_workers_per_gpu: Number of workers per gpu process.
222 | world_size: Total number of gpus used for training.
223 | shard_shuffle_seed: Seed to shuffle shards before checkpoint assignment
224 | + source_num_seq_per_epoch: List of number of sequences per bucket per epoch.
225 | """
226 |
227 | num_sources = len(paths)
228 | -
229 | - if num_sources > 1:
230 | - logging.warning(
231 | - "Multiple sources are not supported fully as of now. It is advised to combine the data into a single "
232 | - "source, by using datapreprocess/ray/tokenize_shuffle.py. Best effort will be done to mix data at the "
233 | - "desired ratio."
234 | - )
235 | - if are_sources_imbalanced_with_each_other(paths):
236 | - logging.warning(
237 | - "Sources contain highly imbalanced shards (largest median shard size of a source is >2x the smallest "
238 | - "median size of a source). This will lead to deteriorated performance (less frequent checkpoints, "
239 | - "data being skipped, and inaccurate mixing). It is STRONGLY advised to combine into one source."
240 | - )
241 | + expected_num_sequence_per_shard = []
242 |
243 | for path in paths:
244 | - num_small_shards = count_small_shards(path)
245 | - if num_small_shards > 0:
246 | - logging.warning(
247 | - f"Source defined by {path} contains {num_small_shards} shards that are smaller than 90% the size of "
248 | - f"the largest shard. These shards might cause deterioration in performance, with more samples being "
249 | - f"skipped than necessary. It is advised to make the shards more uniform."
250 | - )
251 | + num_small_shards, expected_num_seq = count_small_shards(path)
252 | + expected_num_sequence_per_shard.append(expected_num_seq)
253 |
254 | if weights is None:
255 | weights = [1.0 / num_sources for _ in range(num_sources)]
256 |
257 | assert len(weights) == num_sources, "One weight is needed per source."
258 |
259 | - needed_samples_per_source = [int(np.ceil(weights[i] * num_samples / sum(weights))) for i in range(num_sources)]
260 | + if source_num_seq_per_epoch is None:
261 | + needed_samples_per_source = [int(np.ceil(weights[i] * num_samples / sum(weights))) for i in range(num_sources)]
262 | + else:
263 | + needed_samples_per_source = source_num_seq_per_epoch
264 |
265 | manifests = [get_metadata_file(path, shard_shuffle_seed=shard_shuffle_seed) for path in paths]
266 |
267 | @@ -424,32 +425,38 @@ def _single_epoch_string(
268 | num_samples_per_source = [[] for _ in range(num_sources)]
269 |
270 | total_num_workers = num_workers_per_gpu * world_size
271 | - while not enough_shards(shard_list_per_source, total_num_workers) or not enough_samples(
272 | - num_samples_per_source, needed_samples_per_source
273 | - ):
274 | - try:
275 | - for i in range(num_sources):
276 | + try:
277 | + for i in range(num_sources):
278 | + while len(shard_list_per_source[i]) < total_num_workers or sum(num_samples_per_source[i]) < \
279 | + needed_samples_per_source[i]:
280 | # Add shards incrementally
281 | shard_name = manifests[i][next_shard_per_source[i]]["shard"]
282 | try:
283 | num_samples_shard = manifests[i][next_shard_per_source[i]]["num_sequences"]
284 | except KeyError:
285 | num_samples_shard = manifests[i][next_shard_per_source[i]]["num_chunks"]
286 | -
287 | - shard_list_per_source[i].append(shard_name)
288 | - num_samples_per_source[i].append(num_samples_shard)
289 | + if num_samples_shard == expected_num_sequence_per_shard[i]:
290 | + shard_list_per_source[i].append(shard_name)
291 | + num_samples_per_source[i].append(num_samples_shard)
292 | + else:
293 | + print(
294 | + f"Dropping shard = {shard_name} with {num_samples_shard} samples != {expected_num_sequence_per_shard[i]}")
295 |
296 | next_shard_per_source[i] += 1
297 | -
298 | - except IndexError as e:
299 | - logging.error(
300 | - "Number of shards requested for a single epoch is more than the number of shards available. This means "
301 | - "that the amount of data requested to train on is more than the dataloader can serve. This can either "
302 | - "happen because there are not enough data to begin with, or data being skipped due to rounding errors. "
303 | - "To alleviate the latter, consider making more uniform shards, and using less workers/GPUs. This will "
304 | - "allow for better use of the dataset."
305 | - )
306 | - raise e
307 | + except IndexError as e:
308 | + print(f"For Source = {i}")
309 | + print(f"Need samples = {needed_samples_per_source[i]}, collected {sum(num_samples_per_source[i])}")
310 | + print(f"Total shards so far = {next_shard_per_source[i]}")
311 | + print(f"len(shard_list_per_source[i]) = {len(shard_list_per_source[i])}")
312 | + print(f"total_num_workers = {total_num_workers}")
313 | + logging.error(
314 | + "Number of shards requested for a single epoch is more than the number of shards available. This means "
315 | + "that the amount of data requested to train on is more than the dataloader can serve. This can either "
316 | + "happen because there are not enough data to begin with, or data being skipped due to rounding errors. "
317 | + "To alleviate the latter, consider making more uniform shards, and using less workers/GPUs. This will "
318 | + "allow for better use of the dataset."
319 | + )
320 | + raise e
321 |
322 | for i in range(num_sources):
323 | # Ensure the number of shards is a multiple of number of workers, so each worker has the same
324 | @@ -458,6 +465,9 @@ def _single_epoch_string(
325 | # This is a heuristic to minimize how much data we discard when trying to ensure each worker has
326 | # the same number of samples. Shards tend to have similar number of samples, so an extra shard
327 | # in a worker will likely get discarded.
328 | + if not len(shard_list_per_source[i]) % total_num_workers == 0:
329 | + print(
330 | + f"For source {i} number of shards = {len(shard_list_per_source[i])} is not multiple of total workers = {total_num_workers}")
331 | num_multiples = len(shard_list_per_source[i]) // total_num_workers
332 |
333 | shard_list_per_source[i] = shard_list_per_source[i][: num_multiples * total_num_workers]
334 | diff --git a/open_lm/main.py b/open_lm/main.py
335 | index 7c80f55..0da7edc 100644
336 | --- a/open_lm/main.py
337 | +++ b/open_lm/main.py
338 | @@ -793,6 +793,7 @@ def main(args):
339 | args.world_size,
340 | multi_epoch=args.multiple_data_passes,
341 | shard_shuffle_seed=args.shard_shuffle_seed,
342 | + source_num_seq_per_epoch=args.source_num_seq_per_epoch,
343 | )
344 |
345 | # In the distributed case, make sure that all nodes receive the same string
346 | diff --git a/open_lm/model_configs/open_lm_160m.json b/open_lm/model_configs/open_lm_160m.json
347 | index ea4fe6e..944faf0 100644
348 | --- a/open_lm/model_configs/open_lm_160m.json
349 | +++ b/open_lm/model_configs/open_lm_160m.json
350 | @@ -2,7 +2,7 @@
351 | "hidden_dim": 768,
352 | "n_layers": 12,
353 | "n_heads": 12,
354 | - "seq_len": 2048,
355 | + "seq_len": 8192,
356 | "vocab_size": 50432,
357 | "post_embed_norm": false,
358 | "weight_tying": false
359 | diff --git a/open_lm/model_configs/open_lm_1b.json b/open_lm/model_configs/open_lm_1b.json
360 | index fc1878e..774fc9b 100644
361 | --- a/open_lm/model_configs/open_lm_1b.json
362 | +++ b/open_lm/model_configs/open_lm_1b.json
363 | @@ -2,7 +2,7 @@
364 | "hidden_dim": 2048,
365 | "n_layers": 24,
366 | "n_heads": 16,
367 | - "seq_len": 2048,
368 | + "seq_len": 8192,
369 | "vocab_size": 50432,
370 | "post_embed_norm": false,
371 | "weight_tying": false
372 | diff --git a/open_lm/model_configs/open_lm_3b.json b/open_lm/model_configs/open_lm_3b.json
373 | index 64ec0a4..57cc24a 100644
374 | --- a/open_lm/model_configs/open_lm_3b.json
375 | +++ b/open_lm/model_configs/open_lm_3b.json
376 | @@ -2,7 +2,7 @@
377 | "hidden_dim": 2560,
378 | "n_layers": 32,
379 | "n_heads": 32,
380 | - "seq_len": 2048,
381 | + "seq_len": 8192,
382 | "vocab_size": 50432,
383 | "post_embed_norm": false,
384 | "weight_tying": false
385 | diff --git a/open_lm/model_configs/open_lm_410m.json b/open_lm/model_configs/open_lm_410m.json
386 | index 8532173..1010cf7 100644
387 | --- a/open_lm/model_configs/open_lm_410m.json
388 | +++ b/open_lm/model_configs/open_lm_410m.json
389 | @@ -2,7 +2,7 @@
390 | "hidden_dim": 1024,
391 | "n_layers": 24,
392 | "n_heads": 16,
393 | - "seq_len": 2048,
394 | + "seq_len": 8192,
395 | "vocab_size": 50432,
396 | "post_embed_norm": false,
397 | "weight_tying": false
398 | diff --git a/open_lm/model_configs/open_lm_7b.json b/open_lm/model_configs/open_lm_7b.json
399 | index e662dab..b9178d0 100644
400 | --- a/open_lm/model_configs/open_lm_7b.json
401 | +++ b/open_lm/model_configs/open_lm_7b.json
402 | @@ -2,7 +2,7 @@
403 | "hidden_dim": 4096,
404 | "n_layers": 32,
405 | "n_heads": 32,
406 | - "seq_len": 2048,
407 | + "seq_len": 8192,
408 | "vocab_size": 50432,
409 | "post_embed_norm": false,
410 | "weight_tying": false
411 | diff --git a/open_lm/params.py b/open_lm/params.py
412 | index 0a7a3f6..389b805 100644
413 | --- a/open_lm/params.py
414 | +++ b/open_lm/params.py
415 | @@ -787,6 +787,20 @@ def parse_args(args):
416 | default=0,
417 | help="This is the maximum number of failed checkpoints (due to not having seen enough tokens) that are allowed",
418 | )
419 | + parser.add_argument(
420 | + "--dataset-batch-mult",
421 | + type=float,
422 | + nargs="+",
423 | + default=None,
424 | + help="Multiplier of batchsize to be used for each dataset (with respect to base batchsize).",
425 | + )
426 | + parser.add_argument(
427 | + "--source-num-seq-per-epoch",
428 | + type=int,
429 | + nargs="+",
430 | + default=None,
431 | + help="Number of sequences to be used per epoch from each source.",
432 | + )
433 |
434 | add_model_args(parser)
435 |
436 | diff --git a/open_lm/positional_embedding/rotary.py b/open_lm/positional_embedding/rotary.py
437 | index b48ed89..d5c1af0 100644
438 | --- a/open_lm/positional_embedding/rotary.py
439 | +++ b/open_lm/positional_embedding/rotary.py
440 | @@ -57,7 +57,7 @@ class RotaryEmbedding(torch.nn.Module):
441 | self.reset_parameters()
442 |
443 | def reset_parameters(self):
444 | - self.inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim_model, 2).float() / self.dim_model))
445 | + self.inv_freq = 1.0 / (100000 ** (torch.arange(0, self.dim_model, 2).float() / self.dim_model))
446 | self._update_cos_sin_tables(self.seq_len)
447 |
448 | def _update_cos_sin_tables(self, seq_len: int = None, device: torch.device = None, dtype: torch.dtype = None):
449 | diff --git a/open_lm/train.py b/open_lm/train.py
450 | index 0d54bf7..eccf708 100644
451 | --- a/open_lm/train.py
452 | +++ b/open_lm/train.py
453 | @@ -110,13 +110,17 @@ def train_one_epoch(
454 |
455 | try:
456 | batch = next(data_iterator)
457 | - has_data = torch.tensor(1, dtype=torch.long, device=device)
458 | + has_data = torch.tensor([1, len(batch[0])], dtype=torch.long, device=device)
459 | except StopIteration:
460 | - has_data = torch.tensor(0, dtype=torch.long, device=device)
461 | + logging.warning("Could not get a batch!!!")
462 | + has_data = torch.tensor([0, 0], dtype=torch.long, device=device)
463 |
464 | if args.world_size > 1:
465 | dist.all_reduce(has_data, op=ReduceOp.SUM)
466 | - if has_data < args.world_size:
467 | + if has_data[1] != len(batch[0]) * args.world_size:
468 | + logging.warning("Same global sequence length consistency broke! This can reduce performance.")
469 | + if has_data[0] != args.world_size:
470 | + logging.warning("At least one gpu could not get a batch.")
471 | break
472 |
473 | (texts,) = batch
474 | @@ -153,12 +157,12 @@ def train_one_epoch(
475 | # save the loss for the average model for logging
476 | total_loss_avg[key] = loss(out_avg.reshape(-1, args.vocab_size), targets.reshape(-1))
477 | else:
478 | + inputs, targets = sample_chunk(texts, args)
479 | +
480 | # split up batch into accum_freq chunks -- if you have --batch-size 8 and --accum-freq 4
481 | # then you only process 2 items at a time. batch-size must be divisible by accume-freq.
482 | - assert args.per_gpu_batch_size % args.accum_freq == 0, "Per-GPU batch size must be divisible by accum_freq"
483 | - per_batch = args.per_gpu_batch_size // args.accum_freq
484 | -
485 | - inputs, targets = sample_chunk(texts, args)
486 | + assert inputs.shape[0] % args.accum_freq == 0, "Per-GPU batch size must be divisible by accum_freq"
487 | + per_batch = inputs.shape[0] // args.accum_freq
488 |
489 | forward_total_time = 0
490 | backward_total_time = 0
491 | @@ -291,7 +295,7 @@ def train_one_epoch(
492 | for key, value in total_loss_avg.items():
493 | losses_avg_m[key].update(value.item(), batch_size)
494 | if i % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch or step == total_steps - 1:
495 | - num_samples = batch_count * batch_size * args.world_size
496 | + num_samples = batch_count * args.global_batch_size # Number of sequences seen as if all were the longest
497 | samples_per_epoch = dataloader.num_samples
498 | percent_complete = 100.0 * batch_count / num_batches_per_epoch
499 |
500 | @@ -332,6 +336,7 @@ def train_one_epoch(
501 | "tokens": (step + 1) * args.global_batch_size * args.seq_len,
502 | "expected_steps_epoch": data["train"].dataloader.num_batches,
503 | "seen_steps_epoch": batch_count,
504 | + "seq_len": inputs.shape[1],
505 | }
506 |
507 | if averagers is not None and args.log_avg_model_training_loss:
508 |
--------------------------------------------------------------------------------
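The patched bookkeeping above relies on each source's per-GPU batch being scaled by its --dataset-batch-mult entry, so every optimizer step carries the same number of tokens no matter which bucket it is drawn from. A back-of-the-envelope sketch of that accounting (not part of the patch; the two buckets and their sizes below are made up for illustration):

# Sketch only: mirrors the steps-per-epoch formula in the patched
# log_num_checkpoints(); all numbers are illustrative.
workers, global_batch_size = 1, 64
seq_lens   = [2048, 8192]      # two hypothetical buckets
batch_mult = [4, 1]            # inversely proportional to sequence length
seqs_epoch = [262144, 65536]   # sequences served per source per epoch

# Tokens per optimizer step are identical across buckets:
assert len({m * global_batch_size * l for m, l in zip(batch_mult, seq_lens)}) == 1

# Steps contributed per epoch, as in the patched formula:
steps_epoch = sum(
    (n // (workers * global_batch_size * m)) * workers
    for n, m in zip(seqs_epoch, batch_mult)
)
print(steps_epoch)  # 262144 // 256 + 65536 // 64 = 1024 + 1024 = 2048
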
/scripts/dclm_download.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import argparse
7 | import os
8 |
9 | from datasets import load_dataset
10 | from tqdm import tqdm
11 |
12 |
13 | def main(output_dir: str, num_shards: int = 128) -> None:
14 | """Downloads a small subset of the DCLM-Baseline dataset.
15 |
16 | :param output_dir: The path to the output directory where the downloaded files will be saved.
17 | :param num_shards: The number of JSONL files to divide the downloaded data into.
18 | :return: None
19 | """
20 | os.makedirs(output_dir, exist_ok=True)
21 | # Download a small fraction of DCLM-Baseline dataset
22 | data = load_dataset("mlfoundations/dclm-baseline-1.0",
23 | data_dir="global-shard_01_of_10/local-shard_0_of_10")
24 |
25 | for split, dataset in data.items():
26 | for i in tqdm(range(num_shards)):
27 | dataset_shard = dataset.shard(
28 | num_shards=num_shards, index=i, contiguous=True)
29 | output_file = os.path.join(output_dir, f"dclm_{split}_{i}.jsonl")
30 | dataset_shard.to_json(output_file)
31 |
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument(
36 | "--output-dir",
37 | type=str,
38 | required=True,
39 | help="Where to store the DCLM subset .jsonl files.",
40 | )
41 |
42 | args = parser.parse_args()
43 | main(args.output_dir)
44 |
--------------------------------------------------------------------------------
/scripts/get_dd_params.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import argparse
7 |
8 | import numpy as np
9 | from numpy.typing import NDArray
10 |
11 |
12 | def main(total_tokens: int,
13 | epochs: int,
14 | num_gpus: int,
15 | global_batch: int,
16 | num_workers: int,
17 | curriculum: NDArray,
18 | number_of_shards: NDArray,
19 | sequence_per_shard: NDArray,
20 | sequence_sizes: NDArray,
21 | batch_mult: NDArray) -> None:
22 | """Computes variable batch size training parameters based on the desired training hyperparameters.
23 |
24 | :param total_tokens: The total number of tokens to be processed during training.
25 | :param epochs: The number of epochs/checkpoints to save during training.
26 | :param num_gpus: The total number of GPUs to be used for training.
27 | :param global_batch: The global batch size for training.
28 | :param num_workers: The number of dataloader workers per GPU.
29 |     :param curriculum: A numpy array of relative sampling weights (odds) for drawing a batch from each bucket.
30 | :param number_of_shards: A numpy array indicating the number of shards per source as defined in each bucket's manifest file.
31 | :param sequence_per_shard: A numpy array representing the number of sequences per shard for each source.
32 |     :param sequence_sizes: A numpy array specifying the sequence length (in tokens) of each source/bucket.
33 | :param batch_mult: A numpy array defining the ratio of each source's batch size compared to the source with the longest sequences.
34 | :return: None
35 | """
36 | # Number of tokens available per source/bucket:
37 | tokens_per_bucket = number_of_shards * sequence_per_shard * sequence_sizes
38 |
39 | # Ratio of number of tokens needed vs number of tokens available
40 | job_scale_factor = total_tokens / tokens_per_bucket.sum()
41 |
42 | # Number of tokens needed per source/bucket
43 | needed_tokens_per_bucket = job_scale_factor * tokens_per_bucket
44 |
45 | # Number of tokens needed per source/bucket per epoch
46 | needed_tokens_per_bucket_per_epoch = needed_tokens_per_bucket / epochs
47 |
48 | # Number of sequences needed per source/bucket per epoch
49 | needed_sequence_per_bucket_per_epoch = needed_tokens_per_bucket_per_epoch / \
50 | (sequence_sizes)
51 |
52 | # Number of sequences per source/bucket per epoch should be divisible with
53 | # the following numbers
54 | denom_condition1 = batch_mult * global_batch * num_workers
55 | denom_condition2 = sequence_per_shard * num_gpus * num_workers
56 | # Satisfying the second condition is sufficient
57 | assert np.all(denom_condition2 > denom_condition1)
58 |
59 | factors = needed_sequence_per_bucket_per_epoch / denom_condition2
60 | factors_int = np.int32(np.round(factors))
61 |
62 | def get_token_diff(proposed_factors):
63 | return total_tokens / epochs - \
64 | np.sum(proposed_factors * denom_condition2 * sequence_sizes)
65 |
66 | proposed_factors = factors_int
67 |
68 | index = len(factors_int) - 1
69 |
70 | tried_proposed_factors = set()
71 | while get_token_diff(proposed_factors) != 0:
72 | if index < 0:
73 | total_tokens_real = epochs * \
74 | (proposed_factors * denom_condition2 * sequence_sizes).sum()
75 | print(
76 | f"Cannot match requested number of tokens of {total_tokens:,}. Go with {total_tokens_real:,} instead.")
77 | break
78 | if tuple(proposed_factors) in tried_proposed_factors:
79 | index -= 1
80 | diff = get_token_diff(proposed_factors)
81 | tried_proposed_factors.add(tuple(proposed_factors))
82 | if diff < 0:
83 | proposed_factors[index] -= 1
84 | else:
85 | proposed_factors[index] += 1
86 |
87 | proposed_needed_sequence_per_bucket_per_epoch = proposed_factors * denom_condition2
88 | proposed_num_tokens_per_bucket = proposed_needed_sequence_per_bucket_per_epoch * sequence_sizes
89 | source_num_seq_per_epoch = " ".join(
90 | [str(int(x)) for x in proposed_needed_sequence_per_bucket_per_epoch])
91 | sampling_weights = proposed_num_tokens_per_bucket / \
92 | proposed_num_tokens_per_bucket[0] * proposed_factors[0]
93 | sampling_weights = sampling_weights * curriculum
94 | train_data_mix_weights = " ".join([str(int(x)) for x in sampling_weights])
95 | actual_tokens_per_epoch = (
96 | proposed_needed_sequence_per_bucket_per_epoch *
97 | sequence_sizes).sum()
98 |
99 | print("**** Use the following arguments:")
100 |
101 | print(f"--epochs {epochs}")
102 | print(f"--train-num-samples {actual_tokens_per_epoch}")
103 | print("--dataset-batch-mult " +
104 | " ".join([str(int(x)) for x in batch_mult]))
105 | print("--source-num-seq-per-epoch " + source_num_seq_per_epoch)
106 | print("--train-data-mix-weights " + train_data_mix_weights)
107 |
108 |
109 | if __name__ == "__main__":
110 | parser = argparse.ArgumentParser()
111 | parser.add_argument(
112 | "--tokens",
113 | type=int,
114 | help=(
115 |             "Total number of tokens to be seen during training. "
116 |             "Can be larger than the number of available tokens, since repeated tokens are allowed."))
117 | parser.add_argument(
118 | "--epochs",
119 | type=int,
120 | help="Number of epochs/checkpoints to save.")
121 | parser.add_argument(
122 | "--gpus",
123 | type=int,
124 | help="Total number of GPUs to be used for training.")
125 | parser.add_argument(
126 | "--global-batch-size",
127 | type=int,
128 | help="Global batch size.")
129 | parser.add_argument(
130 | "--workers",
131 | type=int,
132 | default=1,
133 | help="Number of dataloader workers per GPU.")
134 | parser.add_argument(
135 | "--number-of-shards",
136 | type=int,
137 | nargs="+",
138 | help="Number of shards per source (can read from manifest files).",
139 | default=[
140 | 553,
141 | 779,
142 | 831,
143 | 690,
144 | 475,
145 | 291],
146 | ) # Default values for the wiki example
147 | parser.add_argument(
148 | "--sequence-per-shard",
149 | type=int,
150 | nargs="+",
151 | help="Number of sequences per shard for each source (determined in each bucket's manifest file).",
152 | default=[
153 | 4096,
154 | 2048,
155 | 1024,
156 | 512,
157 | 256,
158 | 128],
159 | ) # Default values for the wiki example
160 | parser.add_argument(
161 | "--sequence_sizes",
162 | type=int,
163 | nargs="+",
164 | help="Size of sequences per shard.",
165 | default=[
166 | 256,
167 | 512,
168 | 1024,
169 | 2048,
170 | 4096,
171 | 8192],
172 | ) # Default values for the wiki example
173 | parser.add_argument(
174 | "--batch-mult",
175 | type=float,
176 | nargs="+",
177 | help="Ratio of each source batch vs the source with the longest sequences.",
178 | default=[
179 | 32,
180 | 16,
181 | 8,
182 | 4,
183 | 2,
184 | 1],
185 | ) # Default values for the wiki example
186 | parser.add_argument("--train-data-mix-weights", type=float, nargs="+",
187 | help="List of odds to pick a batch from each bucket.",
188 | default=[32, 16, 8, 4, 2, 1], ) # Pow-2 Curriculum
189 | args = parser.parse_args()
190 | print(args)
191 | main(
192 | args.tokens, args.epochs, args.gpus, args.global_batch_size, args.workers, np.array(
193 | args.train_data_mix_weights), np.array(
194 | args.number_of_shards), np.array(
195 | args.sequence_per_shard), np.array(
196 | args.sequence_sizes), np.array(
197 | args.batch_mult))
198 |
--------------------------------------------------------------------------------
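The assert in main() above only checks that denom_condition2 dominates denom_condition1. A tiny sketch using the wiki-example defaults (the values for GPUs, global batch size, and workers are illustrative, since those arguments have no defaults) shows that with these defaults denom_condition1 also divides denom_condition2 evenly, which is why satisfying the shard-level condition is enough:

# Sketch only; assumes numpy is installed. gpus/global_batch/workers are
# illustrative values, not script defaults.
import numpy as np

gpus, global_batch, workers = 8, 64, 1
sequence_per_shard = np.array([4096, 2048, 1024, 512, 256, 128])
batch_mult         = np.array([32, 16, 8, 4, 2, 1])

denom1 = batch_mult * global_batch * workers    # [2048 1024  512  256  128   64]
denom2 = sequence_per_shard * gpus * workers    # [32768 16384 8192 4096 2048 1024]
assert np.all(denom2 > denom1)                  # the check the script performs
assert np.all(denom2 % denom1 == 0)             # holds for these defaults: a full shard
                                                # group splits into whole batches
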
/scripts/get_stats.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import argparse
7 | import os
8 | from pathlib import Path
9 |
10 | import jsonlines
11 |
12 |
13 | def main(dd_dir: str) -> None:
14 |     """Generates statistics for a dataset decomposition directory (one webdataset per bucket).
15 | 
16 |     :param dd_dir: The path to the dataset decomposition directory, containing one bucket subdirectory per sequence length.
17 | :return: None
18 | """
19 | total_num_tokens = 0
20 | bucket_dir_list = sorted(
21 | os.listdir(dd_dir),
22 | key=lambda x: int(
23 | x.split("_")[1]))
24 | for bucket_dir in bucket_dir_list:
25 | manifest_file = os.path.join(dd_dir, bucket_dir, "manifest.jsonl")
26 | num_shards = 0
27 | num_sequences = 0
28 | sequence_length = 2 ** int(bucket_dir.split("_")[1])
29 | if os.path.exists(manifest_file):
30 | with jsonlines.open(manifest_file) as reader:
31 | for item in reader:
32 | num_shards += 1
33 | num_sequences += item['num_sequences']
34 | num_tokens = num_sequences * sequence_length
35 | print(
36 | f"{bucket_dir:<4}: # shards: {num_shards:<6,} seq-length: {sequence_length:<8,} # sequences: {num_sequences:<12,} # tokens: {num_tokens:<12,}")
37 | total_num_tokens += num_tokens
38 | print(20 * "*")
39 | print(f"Total number of tokens = {total_num_tokens:,}")
40 |
41 |
42 | if __name__ == "__main__":
43 | parser = argparse.ArgumentParser()
44 | parser.add_argument(
45 | "--dd-dir",
46 | type=Path,
47 | help="Path to a dataset decomposition directory.")
48 | args = parser.parse_args()
49 |
50 | main(args.dd_dir)
51 |
--------------------------------------------------------------------------------
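get_stats.py only needs the num_sequences field of each bucket's manifest.jsonl. A minimal sketch of reading one such manifest (the path and the example entry are illustrative; the field names follow the manifest convention used by the shard-selection code in open_lm.patch above):

# Sketch only: sums sequences for one bucket, as get_stats.py does.
import jsonlines

num_sequences = 0
with jsonlines.open("D_13/manifest.jsonl") as reader:   # hypothetical path
    for item in reader:    # e.g. {"shard": "shard-0000000", "num_sequences": 128}
        num_sequences += item["num_sequences"]
print(num_sequences * 2 ** 13)                           # tokens in bucket D_13
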
/scripts/make_dd_buckets.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import argparse
7 | import glob
8 | import gzip
9 | import io
10 | import math
11 | import multiprocessing
12 | import os
13 | import random
14 | import shutil
15 | import time
16 | from collections import defaultdict
17 | from contextlib import contextmanager
18 | from pathlib import Path
19 | from typing import Dict, Generator, List, Union
20 |
21 | import jsonlines
22 | import numpy as np
23 | import zstandard as zstd
24 | from tqdm.auto import tqdm
25 | from transformers import (
26 | GPTNeoXTokenizerFast,
27 | PreTrainedTokenizer,
28 | PreTrainedTokenizerFast,
29 | )
30 | from webdataset import ShardWriter
31 |
32 | # Increasing pool size results in more global shuffling, but uses more memory.
33 | SHARD_POOL_FACTOR = 20
34 |
35 | # Number of sequences per shard for each bucket. Note that bucket i holds sequences of length 2**i + 1.
36 | # We use a smaller number of sequences per shard for larger i.
37 | SHARD_SIZE = {i: min(2 ** (20 - i), 65536) for i in range(20)}
38 |
39 | EOT_TOKEN = "<|endoftext|>"
40 |
41 |
42 | def write_to_shard(chunks: List[List[Union[int, str]]],
43 | shard_writer: ShardWriter) -> None:
44 | """Writes a list of tokens to a shard file.
45 |
46 |     :param chunks: A list of token sequences to be written, one sample per sequence.
47 | :param shard_writer: A Webdataset writer object for managing the shard file.
48 | :return: None
49 | """
50 | for idx, chunk in enumerate(chunks):
51 | shard_writer.write({"__key__": f"{idx:012d}", "txt": str(chunk)})
52 |
53 |
54 | @contextmanager
55 | def get_item_reader(file_name: str) -> Generator[jsonlines.Reader, None, None]:
56 | """Creates an iterator for reading .jsonl files or Zstd-compressed .jsonl files.
57 |
58 | :param file_name: The path to the input data file.
59 | :return: A generator that yields items from the .jsonl file or zstd-compressed .jsonl file.
60 | """
61 | if file_name.endswith(".jsonl"):
62 | with jsonlines.open(file_name) as reader:
63 | yield reader
64 | elif file_name.endswith(".jsonl.gz"):
65 | with gzip.open(file_name, "rb") as f_in:
66 | with jsonlines.Reader(f_in) as jsonl_reader:
67 | yield jsonl_reader
68 | else:
69 | dctx = zstd.ZstdDecompressor()
70 | with open(file_name, "rb") as compressed_file:
71 | with dctx.stream_reader(compressed_file) as reader:
72 | with io.TextIOWrapper(reader, encoding="utf-8") as text_reader:
73 | with jsonlines.Reader(text_reader) as jsonl_reader:
74 | yield jsonl_reader
75 |
76 |
77 | def get_binary(seq: List[Union[int, str]], min_log2: int = 8, max_log2: int = 13,
78 |                randomize: bool = True) -> Dict[int, List[List[Union[int, str]]]]:
79 | """Applies binary dataset decomposition to a document.
80 |
81 |     :param seq: The list of tokens of a single (tokenized) document.
82 | :param min_log2: The log2 of the minimum subsequence length to keep. Smaller subsequences will be ignored.
83 | :param max_log2: The log2 of the maximum subsequence length to keep. Larger subsequences will be further divided.
84 | :param randomize: If True, subsequences larger than 2**max_log2 will be divided randomly into subsequences
85 | with lengths within the range determined by min_log2 and max_log2. If False, the division
86 | prioritizes keeping the longest acceptable subsequences.
87 | :return: A dictionary `d` where `d[i]` contains the subsequences of `seq` with lengths of 2**i+1.
88 | """
89 | out_map = defaultdict(list)
90 | ps = 2 ** np.arange(max_log2 + 1 - min_log2)
91 | ps = ps / ps.sum()
92 | while len(seq) > 1:
93 | k = int(math.log2(len(seq) - 1))
94 | if k < min_log2:
95 | return out_map
96 |
97 | if k > max_log2:
98 | if randomize:
99 | k = np.random.choice(
100 | np.arange(
101 | max_log2, min_log2 - 1, -1), p=ps)
102 | else:
103 | k = min(k, max_log2)
104 |
105 | out_map[k].append(seq[:(1 << k) + 1])
106 | seq = seq[(1 << k) + 1:]
107 | return out_map
108 |
109 |
110 | def tokenize_and_shard(
111 | file_names: List[str],
112 | my_id: int,
113 | output_dir: str,
114 | enc: Union[PreTrainedTokenizerFast, PreTrainedTokenizer],
115 | min_bucket: int,
116 | max_bucket: int) -> None:
117 | """Performs dataset-decomposition tokenize-and-shuffle using a single process.
118 |
119 | :param file_names: A list of input data files.
120 | :param my_id: The process ID for the current worker.
121 | :param output_dir: The path to the output directory where the sharded webdataset files will be stored.
122 | :param enc: A tokenizer object for tokenizing the input data.
123 | :param min_bucket: The index of the bucket containing the shortest sequences.
124 | :param max_bucket: The index of the bucket containing the longest sequences.
125 | :return: None
126 | """
127 | start_time = time.time()
128 |
129 | shard_writer = {}
130 | for k in range(min_bucket, max_bucket + 1):
131 | output_dir_k = os.path.join(output_dir, f'{k}')
132 | os.makedirs(output_dir_k, exist_ok=True)
133 | shard_writer[k] = ShardWriter(
134 | os.path.join(
135 | output_dir_k,
136 | "shard-%07d.tar"),
137 | maxcount=SHARD_SIZE[k])
138 |
139 | # dictionary where keys are log2 length, and values are list of sequences.
140 | chunks = defaultdict(list)
141 |
142 | num_entries = 0
143 |
144 | for file_name in file_names:
145 | with get_item_reader(file_name) as item_reader:
146 | for item in item_reader:
147 | string = item["text"]
148 | try:
149 | tokens = enc(string).input_ids + [EOT_TOKEN]
150 | token_map = get_binary(
151 | tokens, min_log2=min_bucket, max_log2=max_bucket)
152 | num_entries += 1
153 |                 except Exception:
154 | print("Failed to encode string.")
155 | continue
156 |
157 | for k, v_list in token_map.items():
158 | for v in v_list:
159 | chunks[k].append(v)
160 | if len(chunks[k]) == SHARD_POOL_FACTOR * SHARD_SIZE[k]:
161 | random.shuffle(chunks[k])
162 | write_to_shard(
163 | chunks[k][:SHARD_SIZE[k]], shard_writer[k])
164 | chunks[k] = chunks[k][SHARD_SIZE[k]:]
165 |
166 | total_time = time.time() - start_time
167 | print(
168 | f"Process {my_id} found {num_entries} entries in {total_time} seconds",
169 | flush=True,
170 | )
171 |
172 | # Write remaining shards
173 | for k in chunks.keys():
174 | random.shuffle(chunks[k])
175 | for i in range(0, len(chunks[k]), SHARD_SIZE[k]):
176 |             # Do not allow partial shards
177 |             if i + SHARD_SIZE[k] <= len(chunks[k]):
178 | write_to_shard(
179 | chunks[k][i: i + SHARD_SIZE[k]], shard_writer[k])
180 |
181 | print(f"Process {my_id} Done.", flush=True)
182 |
183 |
184 | def merge_process_dirs(
185 | output_dir: str,
186 | min_bucket: int,
187 | max_bucket: int,
188 | num_workers: int) -> None:
189 | """Merges multiple webdatasets into one for each bucket.
190 |
191 | :param output_dir: Path to a directory containing [num_workers] subdirectories. Each subdirectory is the output
192 | of a single process of tokenize-and-shuffle and contains subdirectories for different buckets.
193 | :param min_bucket: The index of the bucket with the shortest sequences.
194 | :param max_bucket: The index of the bucket with the longest sequences.
195 | :param num_workers: The number of processes used for parallel computation.
196 | :return: None
197 | """
198 | process_dirs = os.listdir(output_dir)
199 | for k in tqdm(
200 | range(
201 | min_bucket,
202 | max_bucket +
203 | 1),
204 | total=max_bucket -
205 | min_bucket +
206 | 1):
207 | wds_dirs = [os.path.join(output_dir, p, f"{k}") for p in process_dirs]
208 |
209 | transfer_map = {}
210 | global_index = 0
211 | for i, dir in enumerate(wds_dirs):
212 | tarfiles = [os.path.join(dir, file) for file in os.listdir(
213 | dir) if file.endswith(".tar")]
214 | for a_tar in tarfiles:
215 | dir_path = os.path.join(output_dir, f"D_{k}")
216 | if not os.path.exists(dir_path):
217 | os.makedirs(dir_path, exist_ok=True)
218 | target_file = os.path.join(
219 | dir_path, "shard-{:07d}.tar".format(global_index))
220 | global_index += 1
221 | transfer_map[a_tar] = target_file
222 |
223 | with multiprocessing.Pool(processes=num_workers) as pool:
224 | pool.starmap(
225 | shutil.move,
226 | [(src, transfer_map[src]) for src in transfer_map],
227 | )
228 | # Remove original subdirs
229 | for p in process_dirs:
230 | dir_path = os.path.join(output_dir, p)
231 | shutil.rmtree(dir_path)
232 |
233 |
234 | def tokenize_shuffle(
235 | input_files: List[str],
236 | output_dir: str,
237 | num_workers: int,
238 | min_bucket: int,
239 | max_bucket: int,
240 | ) -> None:
241 | """Performs dataset-decomposition tokenize-and-shuffle using multiple processes.
242 |
243 | :param input_files: A list of input data files.
244 | :param output_dir: The path to the output directory.
245 | :param num_workers: The number of processes to use for parallel computation.
246 | :param min_bucket: The index of the bucket containing the shortest sequences.
247 | :param max_bucket: The index of the bucket containing the longest sequences.
248 | :return: None
249 | """
250 | input_files = [glob.glob(input_file) for input_file in input_files]
251 | input_files = [x for y in input_files for x in y]
252 |
253 | # Shuffle the input files
254 | random.shuffle(input_files)
255 | print("Number of input files = {}".format(len(input_files)))
256 |
257 | enc = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
258 |
259 | # assert len(input_files) % num_workers == 0
260 | files_per_worker = len(input_files) // num_workers
261 | file_groups = [input_files[x: x + files_per_worker]
262 | for x in range(0, len(input_files), files_per_worker)]
263 |
264 | with multiprocessing.Pool(processes=num_workers) as pool:
265 | pool.starmap(
266 | tokenize_and_shard,
267 | [
268 | (
269 | fg,
270 | my_id,
271 | os.path.join(output_dir, str(my_id)),
272 | enc,
273 | min_bucket,
274 | max_bucket,
275 | )
276 | for my_id, fg in enumerate(file_groups)
277 | ],
278 | )
279 |
280 |
281 | if __name__ == "__main__":
282 | parser = argparse.ArgumentParser()
283 | parser.add_argument(
284 | "--input-files",
285 | type=str,
286 | nargs="+",
287 | help="Set of input data files.")
288 | parser.add_argument(
289 | "--output-dir",
290 | type=Path,
291 | help="Path to output directory.")
292 | parser.add_argument(
293 | "--num-workers",
294 | type=int,
295 | default=32,
296 | help="Number of workers to use.")
297 | parser.add_argument(
298 | "--min-bucket",
299 | type=int,
300 | default=8,
301 | help="log2 of the shortest sequences.")
302 | parser.add_argument(
303 | "--max-bucket",
304 | type=int,
305 | default=13,
306 | help="log2 of the longest sequences.")
307 |
308 | args = parser.parse_args()
309 |
310 | tokenize_shuffle(
311 | args.input_files,
312 | args.output_dir,
313 | args.num_workers,
314 | args.min_bucket,
315 | args.max_bucket,
316 | )
317 | merge_process_dirs(
318 | args.output_dir,
319 | args.min_bucket,
320 | args.max_bucket,
321 | args.num_workers)
322 |
--------------------------------------------------------------------------------
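A minimal usage sketch of get_binary(), assuming the scripts/ directory is on sys.path and the repository's dependencies are installed; the token ids are fake and only the lengths matter:

# Sketch only: decompose one 3000-token "document" into power-of-two buckets.
from make_dd_buckets import get_binary

doc = list(range(3000))                       # stand-in for 3000 token ids
buckets = get_binary(doc, min_log2=8, max_log2=13, randomize=False)
for k, chunks in sorted(buckets.items()):
    print(k, [len(c) for c in chunks])
# Expected greedy decomposition:
#   8 [257]
#   9 [513]
#   11 [2049]
# i.e. chunks of length 2**k + 1 (see the note at the top of make_dd_buckets.py);
# the trailing 181 tokens are shorter than 2**8 + 1 and are dropped.
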
/scripts/train_dd.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export PYTHONPATH=$PYTHONPATH:./open_lm
4 | torchrun --nproc-per-node 8 \
5 | --nnodes 1 \
6 | --node_rank 0 \
7 | --max_restarts=0 \
8 | --rdzv_backend c10d \
9 | --rdzv_conf "timeout=3000,read_timeout=10000" \
10 | -m open_lm.main \
11 | --accum-freq 4 \
12 | --global-batch-size 64 \
13 | --beta1 0.9 \
14 | --beta2 0.95 \
15 | --data-key txt \
16 | --ffn-type swiglu \
17 | --fsdp \
18 | --fsdp-limit-all-gathers \
19 | --log-every-n-steps 32 \
20 | --lr 0.003 \
21 | --lr-cooldown-end 3e-5 \
22 | --model open_lm_1b \
23 | --name dd_open_lm_1b_$RANDOM \
24 | --precision amp_bfloat16 \
25 | --qk-norm \
26 | --seed 42 \
27 | --warmup 5000 \
28 | --wd 0.033 \
29 | --workers 1 \
30 | --z-loss-coefficient 0.0001 \
31 | --ignore-parse-errors \
32 | --logs /mnt/open_lm_logs/ \
33 | --dataset-manifest "/mnt/processed_datasets/dclm/D_8/manifest.jsonl" \
34 | "/mnt/processed_datasets/dclm/D_9/manifest.jsonl" \
35 | "/mnt/processed_datasets/dclm/D_10/manifest.jsonl" \
36 | "/mnt/processed_datasets/dclm/D_11/manifest.jsonl" \
37 | "/mnt/processed_datasets/dclm/D_12/manifest.jsonl" \
38 | "/mnt/processed_datasets/dclm/D_13/manifest.jsonl" \
39 | --epochs 8 \
40 | --train-num-samples 3607101440 \
41 | --dataset-batch-mult 32 16 8 4 2 1 \
42 | --source-num-seq-per-epoch 1507328 1277952 794624 335872 137216 61440 \
43 | --train-data-mix-weights 1472 1248 776 328 134 60 \
44 | --fsdp-amp \
45 | --grad-clip-norm 1 \
46 | --attn-name xformers_attn \
47 | --model-norm gain_only_lp_layer_norm \
48 | --wandb-project-name dd_code \
49 | --report-to wandb \
50 |
--------------------------------------------------------------------------------
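Quick consistency checks for the numbers in this launch script, as a sketch rather than something the script runs: the per-source sequence counts multiplied by their bucket sequence lengths add up to --train-num-samples (which scripts/get_dd_params.py emits in tokens per epoch), and batch multiplier times sequence length is constant, so every optimizer step sees 64 * 8192 tokens:

# Sketch only. Buckets D_8..D_13 hold sequences of length 2**8..2**13 tokens.
seq_lens   = [256, 512, 1024, 2048, 4096, 8192]
seqs_epoch = [1507328, 1277952, 794624, 335872, 137216, 61440]   # --source-num-seq-per-epoch
batch_mult = [32, 16, 8, 4, 2, 1]                                # --dataset-batch-mult

# Tokens per epoch equal --train-num-samples:
assert sum(n * l for n, l in zip(seqs_epoch, seq_lens)) == 3607101440

# Every optimizer step carries the same number of tokens (64 * 8192 = 524,288):
assert {m * 64 * l for m, l in zip(batch_mult, seq_lens)} == {64 * 8192}
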
/scripts/wiki_download.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import argparse
7 | import os
8 |
9 | from datasets import load_dataset
10 | from tqdm import tqdm
11 |
12 |
13 | def main(output_dir: str, num_shards: int = 32) -> None:
14 | """Downloads the Wikipedia dataset.
15 |
16 | :param output_dir: The path to the output directory where the downloaded files will be saved.
17 | :param num_shards: The number of JSONL files to divide the downloaded data into.
18 | :return: None
19 | """
20 | os.makedirs(output_dir, exist_ok=True)
21 | data = load_dataset("wikipedia", "20220301.en", trust_remote_code=True)
22 |
23 | for split, dataset in data.items():
24 | for i in tqdm(range(num_shards)):
25 | dataset_shard = dataset.shard(
26 | num_shards=num_shards, index=i, contiguous=True)
27 | output_file = os.path.join(
28 | output_dir, f"wiki_en_20220301_{split}_{i}.jsonl")
29 | dataset_shard.to_json(output_file)
30 |
31 |
32 | if __name__ == "__main__":
33 | parser = argparse.ArgumentParser()
34 | parser.add_argument(
35 | "--output-dir",
36 | type=str,
37 | required=True,
38 |         help="Where to store the Wikipedia .jsonl files.",
39 | )
40 |
41 | args = parser.parse_args()
42 | main(args.output_dir)
43 |
--------------------------------------------------------------------------------