├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── OLMoE-1B-7B-0924.yml
│   └── ablations
│       ├── olmo-1b-newhp-newds-cx5-datafix.yml
│       ├── olmo-1b-newhp-newds-cx5-flan.yml
│       ├── olmo-1b-newhp-newds-cx5-reddit.yml
│       ├── olmo-1b-newhp-newds-cx5.yml
│       ├── olmo-1b-newhp-newds-s3.yml
│       ├── olmo-1b-newhp-newds.yml
│       ├── olmo-1b-newhp-oldds-cx5.yml
│       ├── olmo-1b-newhp-oldds-s3.yml
│       ├── olmo-1b-newhp-oldds.yml
│       ├── olmoe-8x1b-newhp-newds-cx5-fine-shared-s3.yml
│       ├── olmoe-8x1b-newhp-newds-cx5-fine-shared.yml
│       ├── olmoe-8x1b-newhp-newds-cx5-fine.yml
│       ├── olmoe-8x1b-newhp-newds-cx5-fine05.yml
│       ├── olmoe-8x1b-newhp-newds-cx5-fine1-datafix.yml
│       ├── olmoe-8x1b-newhp-newds-cx5-fine1-docmask-8k.yml
│       ├── olmoe-8x1b-newhp-newds-cx5-fine1-docmask.yml
│       ├── olmoe-8x1b-newhp-newds-cx5-fine1-newtok.yml
│       ├── olmoe-8x1b-newhp-newds-cx5-fine1-normreorder.yml
│       ├── olmoe-8x1b-newhp-newds-cx5-fine1.yml
│       ├── olmoe-8x1b-newhp-newds-cx5-k2-fine-s3.yml
│       ├── olmoe-8x1b-newhp-newds-cx5-k2-fine.yml
│       ├── olmoe-8x1b-newhp-newds-cx5-k2.yml
│       ├── olmoe-8x1b-newhp-newds-cx5.yml
│       ├── olmoe-8x1b-newhp-newds-final-anneal.yml
│       ├── olmoe-8x1b-newhp-newds-final-densecomp.yml
│       ├── olmoe-8x1b-newhp-newds-final-double-alt.yml
│       ├── olmoe-8x1b-newhp-newds-final-double.yml
│       ├── olmoe-8x1b-newhp-newds-final-s3.yml
│       ├── olmoe-8x1b-newhp-newds-final-v2.yml
│       ├── olmoe-8x1b-newhp-newds-final.yml
│       ├── olmoe-8x1b-newhp-newds-k2qk.yml
│       ├── olmoe-8x1b-newhp-newds-s3-cx5.yml
│       ├── olmoe-8x1b-newhp-newds-s3.yml
│       ├── olmoe-8x1b-newhp-newds.yml
│       ├── olmoe-8x1b-newhp-oldds.yml
│       ├── olmoe-8x2b-newhp-newds-final.yml
│       ├── olmoe-8x7b-A7B.yml
│       ├── olmoe-8x7b.yml
│       ├── olmoe17-16x1b-fullshard-swiglu-wrapb-s1k1.yml
│       ├── olmoe17-8x1b-final-decemb.yml
│       ├── olmoe17-8x1b-final-decln.yml
│       ├── olmoe17-8x1b-final-eps-fine.yml
│       ├── olmoe17-8x1b-final-eps-noqk.yml
│       ├── olmoe17-8x1b-final-eps.yml
│       ├── olmoe17-8x1b-final-fine.yml
│       ├── olmoe17-8x1b-final-nodecln.yml
│       ├── olmoe17-8x1b-final-normdc.yml
│       ├── olmoe17-8x1b-final-weka.yaml
│       ├── olmoe17-8x1b-final.yml
│       ├── olmoe17-8x1b-fullshard-swiglu-wrapb-k2-qknorm-zloss.yml
│       └── olmoe17-8x7b-final.yml
├── logs
│   ├── olmoe-dpo-logs.txt
│   └── olmoe-sft-logs.txt
├── scripts
│   ├── adapteval.sh
│   ├── batchjob.sh
│   ├── eval_openlm_ckpt.py
│   ├── humaneval.yaml
│   ├── llm1b.sh
│   ├── make_table.py
│   ├── megatron.sh
│   ├── megatron_dense_46m_8gpu.sh
│   ├── megatron_dmoe_46m_8gpu.sh
│   ├── olmoe-gantry.sh
│   ├── olmoe_visuals.ipynb
│   ├── plot_routing_analysis.ipynb
│   ├── plot_routing_analysis_v2.ipynb
│   ├── plot_routing_analysis_v2_cross_layer.ipynb
│   ├── plot_routing_analysis_v2_top1.ipynb
│   ├── routing_mixtral_v2.jpg
│   ├── routing_olmoe_v2.jpg
│   ├── routing_output.zip
│   ├── routing_output
│   │   ├── mistral
│   │   │   ├── eid2token
│   │   │   │   ├── arxiv.pkl
│   │   │   │   ├── book.pkl
│   │   │   │   ├── c4.pkl
│   │   │   │   ├── github.pkl
│   │   │   │   └── wikipedia.pkl
│   │   │   ├── expert_counts
│   │   │   │   ├── arxiv.pkl
│   │   │   │   ├── book.pkl
│   │   │   │   ├── c4.pkl
│   │   │   │   ├── github.pkl
│   │   │   │   └── wikipedia.pkl
│   │   │   ├── expert_counts_crosslayer
│   │   │   │   ├── arxiv.pkl
│   │   │   │   ├── book.pkl
│   │   │   │   ├── c4.pkl
│   │   │   │   ├── github.pkl
│   │   │   │   └── wikipedia.pkl
│   │   │   ├── expert_counts_crosslayer_top1
│   │   │   │   ├── arxiv.pkl
│   │   │   │   ├── book.pkl
│   │   │   │   ├── c4.pkl
│   │   │   │   ├── github.pkl
│   │   │   │   └── wikipedia.pkl
│   │   │   └── expert_counts_top1
│   │   │       ├── arxiv.pkl
│   │   │       ├── book.pkl
│   │   │       ├── c4.pkl
│   │   │       ├── github.pkl
│   │   │       └── wikipedia.pkl
│   │   ├── olmoe-dpo
│   │   │   └── expert_counts
│   │   │       └── tulu.pkl
│   │   ├── olmoe-sft
│   │   │   └── expert_counts
│   │   │       └── tulu.pkl
│   │   ├── olmoe
│   │   │   ├── eid2token
│   │   │   │   ├── arxiv.pkl
│   │   │   │   ├── book.pkl
│   │   │   │   ├── c4.pkl
│   │   │   │   ├── github.pkl
│   │   │   │   └── wikipedia.pkl
│   │   │   ├── expert_counts
│   │   │   │   ├── arxiv.pkl
│   │   │   │   ├── book.pkl
│   │   │   │   ├── c4.pkl
│   │   │   │   ├── github.pkl
│   │   │   │   ├── tulu.pkl
│   │   │   │   └── wikipedia.pkl
│   │   │   ├── expert_counts_crosslayer
│   │   │   │   ├── arxiv.pkl
│   │   │   │   ├── book.pkl
│   │   │   │   ├── c4.pkl
│   │   │   │   ├── github.pkl
│   │   │   │   └── wikipedia.pkl
│   │   │   ├── expert_counts_crosslayer_top1
│   │   │   │   ├── arxiv.pkl
│   │   │   │   ├── book.pkl
│   │   │   │   ├── c4.pkl
│   │   │   │   ├── github.pkl
│   │   │   │   └── wikipedia.pkl
│   │   │   └── expert_counts_top1
│   │   │       ├── arxiv.pkl
│   │   │       ├── book.pkl
│   │   │       ├── c4.pkl
│   │   │       ├── github.pkl
│   │   │       └── wikipedia.pkl
│   │   ├── routing.jpg
│   │   ├── routing_prob_distribution.png
│   │   └── text
│   │       ├── arxiv_texts.txt
│   │       ├── b3g_texts.txt
│   │       ├── c4_texts.txt
│   │       ├── github_oss_with_stack_texts.txt
│   │       └── wikipedia_texts.txt
│   ├── run_dclm_evals_heavy.sh
│   ├── run_dclm_evals_heavy_olmo.sh
│   ├── run_dclm_evals_humaneval.sh
│   ├── run_moe_analysis.py
│   ├── run_routing_analysis.py
│   ├── sparsify_ckpt_unsharded.py
│   └── wekatransfer
│       ├── s3weka.sh
│       ├── s3weka.yml
│       └── wekas3.yaml
└── visuals
    ├── emojis
    │   ├── olmoe_checkmark.png
    │   ├── olmoe_checkmark_yellow.png
    │   ├── olmoe_cross.png
    │   └── olmoe_warning.png
    ├── figures
    │   ├── adamweps.pdf
    │   ├── dataset.pdf
    │   ├── datasetredditflan.pdf
    │   ├── embdecay.pdf
    │   ├── expertchoice.pdf
    │   ├── granularity.pdf
    │   ├── init.pdf
    │   ├── layer_0_heatmap.pdf
    │   ├── layer_15_heatmap.pdf
    │   ├── layer_7_heatmap.pdf
    │   ├── layersharing.pdf
    │   ├── lbl.pdf
    │   ├── lblprecision.pdf
    │   ├── lbltoks.pdf
    │   ├── ln.pdf
    │   ├── lndecay.pdf
    │   ├── lngradnorm.pdf
    │   ├── loss.pdf
    │   ├── moevsdense.pdf
    │   ├── noise.pdf
    │   ├── olmoe.pdf
    │   ├── overview.jpg
    │   ├── overview.pdf
    │   ├── qknorm.pdf
    │   ├── routing_mixtral.pdf
    │   ├── routing_olmoe.pdf
    │   ├── routing_prob_distribution_mixtral.pdf
    │   ├── routing_prob_distribution_olmoe.pdf
    │   ├── shared.pdf
    │   ├── token_specialization_top1_olmoe.pdf
    │   ├── token_specialization_top2_mixtral.pdf
    │   ├── token_specialization_top8_olmoe.pdf
    │   ├── top18_changes_over_checkpoints.pdf
    │   ├── trainingevalflops.pdf
    │   ├── trainingevaltokens.pdf
    │   ├── upcycle.pdf
    │   └── zloss.pdf
    ├── logos
    │   ├── OLMoE_logo.png
    │   ├── OLMoE_logo.svg
    │   ├── OLMoE_logo_alt1.png
    │   ├── OLMoE_logo_alt1.svg
    │   ├── OLMoE_logo_alt2.png
    │   ├── OLMoE_logo_alt2.svg
    │   ├── OLMoE_logo_alt3.png
    │   └── OLMoE_logo_alt3.svg
    ├── poster_iclr2025.pdf
    ├── poster_iclr2025.pptx
    ├── poster_neurips2024.pdf
    └── twitterblog_images
        ├── domainspec.png
        ├── experiments.png
        ├── logo_transparent.png
        ├── logo_twitter.png
        ├── overview_base.png
        ├── overview_left.png
        ├── overview_long.png
        ├── overview_right.png
        ├── perf_adapt.png
        ├── perf_during.png
        ├── perf_pretr.png
        ├── perf_pretr_adapt.png
        └── tokenidspec.png
/.gitignore:
--------------------------------------------------------------------------------
1 | **/results/*
2 | .DS_Store
3 | */.DS_Store
4 | */*/.DS_Store
5 | */*/*/.DS_Store
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # OLMoE: Open Mixture-of-Experts Language Models
3 |
4 | Fully open, state-of-the-art Mixture-of-Experts language model with 1.3 billion active and 6.9 billion total parameters. All data, code, and logs released.
5 |
6 |
7 |
8 | 
9 |
10 | This repository provides an overview of all resources for the paper ["OLMoE: Open Mixture-of-Experts Language Models"](https://arxiv.org/abs/2409.02060).
11 |
12 | - [Artifacts](#artifacts)
13 | - [Inference](#inference)
14 | - [Pretraining](#pretraining)
15 | - [Adaptation](#adaptation)
16 | - [Evaluation](#evaluation)
17 | - [During pretraining](#during-pretraining)
18 | - [After pretraining](#after-pretraining)
19 | - [After adaptation](#after-adaptation)
20 | - [Visuals](#visuals)
21 | - [Citation](#citation)
22 |
23 | ### Artifacts
24 |
25 | - **Paper**: https://arxiv.org/abs/2409.02060
26 | - **Pretraining** [Checkpoints](https://hf.co/allenai/OLMoE-1B-7B-0924), [Final Checkpoint GGUF](https://hf.co/allenai/OLMoE-1B-7B-0924-GGUF), [Code](https://github.com/allenai/OLMo/tree/Muennighoff/MoE), [Data](https://huggingface.co/datasets/allenai/OLMoE-mix-0924) and [Logs](https://wandb.ai/ai2-llm/olmoe/reports/OLMoE-1B-7B-0924--Vmlldzo4OTcyMjU3).
27 | - **SFT (Supervised Fine-Tuning)** [Checkpoints](https://huggingface.co/allenai/OLMoE-1B-7B-0924-SFT), [Code](https://github.com/allenai/open-instruct/), [Data](https://hf.co/datasets/allenai/tulu-v3.1-mix-preview-4096-OLMoE) and [Logs](https://github.com/allenai/OLMoE/blob/main/logs/olmoe-sft-logs.txt).
28 | - **DPO/KTO (Direct Preference Optimization/Kahneman-Tversky Optimization)** [Checkpoints](https://huggingface.co/allenai/OLMoE-1B-7B-0924-Instruct), [Final Checkpoint GGUF](https://hf.co/allenai/OLMoE-1B-7B-0924-Instruct-GGUF), [Preference Data](https://hf.co/datasets/allenai/ultrafeedback_binarized_cleaned), [DPO code](https://github.com/allenai/open-instruct/), [KTO code](https://github.com/Muennighoff/kto/blob/master/kto.py) and [Logs](https://github.com/allenai/OLMoE/blob/main/logs/olmoe-dpo-logs.txt).
29 |
30 | ### Inference
31 |
32 | OLMoE has been integrated into [vLLM](https://github.com/vllm-project/vllm), [SGLang](https://github.com/sgl-project/sglang), [llama.cpp](https://github.com/ggerganov/llama.cpp), and [transformers](https://github.com/huggingface/transformers). The transformers implementation is slow, so we recommend one of the others, e.g. vLLM, where possible. Below are examples for vLLM, llama.cpp, and transformers.
33 |
34 | #### vLLM
35 |
36 | Install the `vllm` library and run:
37 |
38 | ```python
39 | from vllm import LLM, SamplingParams
40 | model = LLM("allenai/OLMoE-1B-7B-0924")
41 | out = model.generate("Bitcoin is", SamplingParams(temperature=0.0))
42 | print("Bitcoin is" + out[0].outputs[0].text)
43 | # Bitcoin is a digital currency that is not controlled by any central authority. It is a peer
44 | ```
45 |
46 | #### llama.cpp
47 |
48 | Install `llama.cpp`, download a quantized GGUF of the final checkpoint (e.g. [`olmoe-1b-7b-0924-q4_0.gguf`](https://hf.co/allenai/OLMoE-1B-7B-0924-GGUF/resolve/main/olmoe-1b-7b-0924-q4_0.gguf)) and run in a shell:
49 |
50 | ```bash
51 | llama-cli -m olmoe-1b-7b-0924-q4_0.gguf -p "Bitcoin is" -n 128
52 | ```
53 |
54 | #### transformers
55 |
56 | Install the `transformers` & `torch` libraries and run:
57 |
58 | ```python
59 | from transformers import OlmoeForCausalLM, AutoTokenizer
60 | import torch
61 |
62 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
63 |
64 | # Load different ckpts via passing e.g. `revision=step10000-tokens41B`
65 | # also check allenai/OLMoE-1B-7B-0924-SFT & allenai/OLMoE-1B-7B-0924-Instruct
66 | model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924").to(DEVICE)
67 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
68 | inputs = tokenizer("Bitcoin is", return_tensors="pt")
69 | inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
70 | out = model.generate(**inputs, max_length=64)
71 | print(tokenizer.decode(out[0]))
72 | # Bitcoin is a digital currency that is created and held electronically. No one controls it. Bitcoins aren’t printed, like dollars or euros – they’re produced by people and businesses running computers all around the world, using software that solves mathematical
73 | ```
74 |
75 | You can list all revisions/branches by installing `huggingface-hub` & running:
76 | ```python
77 | from huggingface_hub import list_repo_refs
78 | out = list_repo_refs("allenai/OLMoE-1B-7B-0924")
79 | branches = [b.name for b in out.branches]
80 | ```
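
For instance, to load an intermediate pretraining checkpoint programmatically (a minimal sketch; branch names follow the `step...-tokens...B` pattern mentioned above):

```python
from huggingface_hub import list_repo_refs
from transformers import OlmoeForCausalLM

refs = list_repo_refs("allenai/OLMoE-1B-7B-0924")
# Intermediate checkpoints live on branches named like `step10000-tokens41B`
steps = [b.name for b in refs.branches if b.name.startswith("step")]
model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924", revision=steps[0])
```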
81 |
82 | ### Pretraining
83 |
84 | 1. Clone this [OLMo branch](https://github.com/allenai/OLMo/tree/Muennighoff/MoE) & create an environment with its dependencies via `cd OLMo; pip install -e .`. If you want to use newer OLMo features, clone from the `main` branch instead.
85 | 2. Run `pip install git+https://github.com/Muennighoff/megablocks.git@olmoe`
86 | 3. Set up a config file. `configs/OLMoE-1B-7B-0924.yml` was used for the pretraining of `OLMoE-1B-7B-0924`. You can find configs from various ablations in `configs/ablations`.
87 | 4. Download the data from https://hf.co/datasets/allenai/OLMoE-mix-0924, tokenize it via the command below and adapt the `paths` in your training config to point to it (see the sketch after this list).
88 | ```bash
89 | dolma tokens \
90 | --documents ${PATH_TO_DOWNLOADED_DATA} \
91 | --destination ${PATH_WHERE_TO_SAVE_TOKENIZED_DATA} \
92 | --tokenizer.name_or_path 'allenai/gpt-neox-olmo-dolma-v1_5' \
93 | --max_size '2_147_483_648' \
94 | --seed 0 \
95 | --tokenizer.eos_token_id 50279 \
96 | --tokenizer.pad_token_id 1 \
97 | --processes ${NUMBER_OF_CPU_CORES_TO_USE}
98 | ```
99 | 5. Submit your job. We used `bash scripts/olmoe-gantry.sh`, which invokes https://github.com/allenai/OLMo/blob/Muennighoff/MoE/scripts/train.py and uses [beaker gantry](https://github.com/allenai/beaker-gantry), but you will likely need to change the script to work with your setup.
100 | 6. To run annealing after the main pretraining, we use [this config](https://github.com/allenai/OLMoE/blob/main/configs/ablations/olmoe-8x1b-newhp-newds-final-anneal.yml) - the only changes from the pretraining config are the `optimizer` and `scheduler` fields, as well as `max_duration` and `stop_at`.
101 | 7. To convert your pretraining checkpoint to Hugging Face transformers format after training, you can use the script & instructions [here](https://github.com/huggingface/transformers/blob/8f8af0fb38baa851f3fd69f564fbf91b5af78332/src/transformers/models/olmoe/convert_olmoe_weights_to_hf.py#L14).
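
As referenced in step 4, a minimal sketch for pointing the config at your tokenized data (assumes PyYAML is installed and that your config keeps its token files under `data.paths` as the released config does; the glob path is a placeholder):

```python
import glob
import yaml  # PyYAML

with open("configs/OLMoE-1B-7B-0924.yml") as f:
    cfg = yaml.safe_load(f)

# Point training at the tokenized shards produced by `dolma tokens`
# (note: this round-trip drops YAML comments from the config)
cfg["data"]["paths"] = sorted(glob.glob("/path/to/tokenized/**/*.npy", recursive=True))

with open("configs/my-olmoe-run.yml", "w") as f:
    yaml.safe_dump(cfg, f)
```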
102 |
103 | #### Other design choices
104 |
105 | For most of our experiments on other design choices, you can simply set them in the config file (e.g. change the respective hyperparam), except for:
106 | 1. **Sparse upcycling:** To sparse upcycle your model, first train it dense using e.g. [this config](https://github.com/allenai/OLMoE/blob/main/configs/ablations/olmo-1b-newhp-newds-cx5.yml). Then convert any of its checkpoints into an MoE using [this script](https://github.com/allenai/OLMoE/blob/main/scripts/sparsify_ckpt_unsharded.py) and the instructions at its top, making sure to modify the hardcoded values (number of experts etc.) to match your target model (a conceptual sketch of this conversion follows this list). Rename the newly created model file (`model_sparse.safetensors`) to `model.safetensors` and place it in a new folder whose name ends in `-unsharded`. Then launch a job that loads this model, similar to [our sparse upcycling job](https://wandb.ai/ai2-llm/olmoe/runs/1w3srbb3/overview); note the settings `--load_path=path_to_upcycled_ckpt --reset_optimizer_state=True --reset_trainer_state=True`, and `--fast_forward_batches=XXX` if you also want to continue on the same dataset with the same order. Also make sure to have the changes from this PR in your code: https://github.com/allenai/OLMo/pull/573. Finally, if you want to reproduce upcycling from OLMo-1B (0724) as in the paper, the OLMo-1B checkpoint turned into an MoE with 8 experts to start from is here: https://huggingface.co/allenai/OLMo-1B-0724-954000steps-unsharded; download the files inside of it (e.g. `wget https://huggingface.co/allenai/OLMo-1B-0724-954000steps-unsharded/resolve/main/model.safetensors`), then use a config similar to [this one](https://wandb.ai/ai2-llm/olmoe/runs/1w3srbb3/overview) to train the upcycled MoE from it.
107 | 2. **Expert choice:** To run experiments with our expert choice implementation, you need to instead use the olmo branch `Muennighoff/OLMoSE` or simply copy over the small config changes that enable expert choice (i.e. [here](https://github.com/allenai/OLMo/blob/b7b312aa2d9ee0ec0816a042955f34e27f9f4628/olmo/config.py#L524)) to the `Muennighoff/MoE` branch. You can then run expert choice by activating it in your config; depending on your selection, it will use this code: https://github.com/Muennighoff/megablocks/blob/4a25bc7b5665bcb9da93d72d5ad0c14d41e1a351/megablocks/layers/moe.py#L462 or https://github.com/Muennighoff/megablocks/blob/4a25bc7b5665bcb9da93d72d5ad0c14d41e1a351/megablocks/layers/moe.py#L477. Both should be roughly equivalent implementations of expert choice; neither was better than dropless token choice in our experiments (a toy sketch contrasting the two routing schemes follows this list).
108 | 3. **Shared MoE layers (Appendix):** For these experiments you need to use the olmo branch `Muennighoff/OLMoSE` and create your own config e.g. like the one used in [this run](https://wandb.ai/ai2-llm/olmoe/reports/Plot-Shared-vs-Dense--Vmlldzo4NDI0MTc5).
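
The upcycling conversion referenced in item 1 boils down to duplicating each dense FFN into `num_experts` identical copies. A conceptual sketch only - the parameter names below are illustrative assumptions; consult `scripts/sparsify_ckpt_unsharded.py` for the real mapping:

```python
from safetensors.torch import load_file, save_file

NUM_EXPERTS = 8  # assumption: match the MoE config you plan to train

dense = load_file("model.safetensors")  # unsharded dense checkpoint
sparse = {}
for name, tensor in dense.items():
    # Assumption: dense FFN weights contain "ff_proj" / "ff_out" in their names;
    # check the actual names in your checkpoint before relying on this.
    if "ff_proj" in name or "ff_out" in name:
        for e in range(NUM_EXPERTS):
            sparse[name.replace("ff", f"experts.{e}.ff", 1)] = tensor.clone()
    else:
        sparse[name] = tensor

# The per-layer MoE router (gate) weights still need to be newly initialized.
save_file(sparse, "model_sparse.safetensors")
```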
109 |
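And the toy sketch of the routing difference referenced in item 2 (pure illustration; the actual implementations are the megablocks links above):

```python
import torch

torch.manual_seed(0)
scores = torch.randn(16, 8).softmax(dim=-1)  # [num_tokens, num_experts] router probabilities

# Token choice (used in OLMoE): every token picks its top-k experts,
# so per-expert load varies (dropless = no tokens are dropped).
token_choice = scores.topk(k=2, dim=1).indices  # [num_tokens, k]

# Expert choice: every expert picks its top-c tokens (c = capacity),
# so load is perfectly balanced but some tokens may be picked by no expert.
expert_choice = scores.topk(k=4, dim=0).indices  # [c, num_experts]
```
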
110 | ### Adaptation
111 |
112 | 1. Clone Open Instruct [here](https://github.com/allenai/open-instruct/) & follow its setup instructions. If you run into any problems, try upgrading your transformers version with `pip install --upgrade transformers` first.
113 | 2. SFT: After adapting as needed, run:
114 | ```
115 | accelerate launch \
116 | --mixed_precision bf16 \
117 | --num_machines 1 \
118 | --num_processes 8 \
119 | --use_deepspeed \
120 | --deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf \
121 | open_instruct/finetune.py \
122 | --model_name_or_path allenai/OLMoE-1B-7B-0924 \
123 | --tokenizer_name allenai/OLMoE-1B-7B-0924 \
124 | --use_flash_attn \
125 | --max_seq_length 4096 \
126 | --preprocessing_num_workers 128 \
127 | --per_device_train_batch_size 2 \
128 | --gradient_accumulation_steps 8 \
129 | --learning_rate 2e-05 \
130 | --lr_scheduler_type linear \
131 | --warmup_ratio 0.03 \
132 | --weight_decay 0.0 \
133 | --num_train_epochs 2 \
134 | --output_dir output/ \
135 | --with_tracking \
136 | --report_to wandb \
137 | --logging_steps 1 \
138 | --reduce_loss sum \
139 | --model_revision main \
140 | --dataset_mixer_list allenai/tulu-v3-mix-preview-4096-OLMoE 1.0 ai2-adapt-dev/daring-anteater-specialized 1.0 \
141 | --checkpointing_steps epoch \
142 | --add_bos
143 | ```
144 | 3. DPO: Run:
145 | ```
146 | accelerate launch \
147 | --mixed_precision bf16 \
148 | --num_machines 1 \
149 | --num_processes 8 \
150 | --use_deepspeed \
151 | --deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf \
152 | open_instruct/dpo_tune.py \
153 | --model_name_or_path allenai/OLMoE-1B-7B-0924-SFT \
154 | --tokenizer_name allenai/OLMoE-1B-7B-0924-SFT \
155 | --use_flash_attn \
156 | --gradient_checkpointing \
157 | --dataset_name argilla/ultrafeedback-binarized-preferences-cleaned \
158 | --max_seq_length 4096 \
159 | --preprocessing_num_workers 16 \
160 | --per_device_train_batch_size 1 \
161 | --gradient_accumulation_steps 4 \
162 | --learning_rate 5e-7 \
163 | --lr_scheduler_type linear \
164 | --warmup_ratio 0.1 \
165 | --weight_decay 0. \
166 | --num_train_epochs 3 \
167 | --output_dir output/ \
168 | --report_to tensorboard \
169 | --logging_steps 1 \
170 | --reduce_loss sum \
171 | --add_bos \
172 | --checkpointing_steps epoch \
173 | --dpo_beta 0.1
174 | ```
175 | 4. KTO: Install `trl` and run https://github.com/Muennighoff/kto/blob/master/kto.py via `WANDB_PROJECT=olmoe accelerate launch --config_file=config_8gpusdsz2_m7.yml kto.py --model_name_or_path allenai/OLMoE-1B-7B-0924-SFT --output_dir OLMoE-1B-7B-0924-SFT-KTO-3EP --report_to "wandb" --per_device_train_batch_size 4 --gradient_accumulation_steps 1 --optim rmsprop --learning_rate 5e-07 --beta 0.1 --logging_steps 1 --bf16 --sanity_check False --num_train_epochs 3` (to use the Adam optimizer instead, change to `--optim adamw_torch`). We used `trl==0.9.6`.
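
To sanity-check an adapted checkpoint afterwards, you can chat with it via transformers (a quick sketch; assumes the SFT/Instruct tokenizer ships a chat template):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "allenai/OLMoE-1B-7B-0924-Instruct"
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

messages = [{"role": "user", "content": "What is a Mixture-of-Experts model?"}]
inputs = tok.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
out = model.generate(inputs, max_new_tokens=64)
print(tok.decode(out[0][inputs.shape[1]:], skip_special_tokens=True))
```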
176 |
177 | ### Evaluation
178 |
179 | #### During pretraining
180 |
181 | Evaluation during pretraining is done automatically and configured in the config file. It uses the code here: https://github.com/allenai/OLMo/tree/Muennighoff/MoE/olmo/eval.
182 |
183 | #### After pretraining
184 |
185 | OLMES Evals: Follow the instructions at https://github.com/allenai/OLMo-Eval/blob/51c5ba579e75ef4ce7e9b29936eaa72c1a0e99eb/olmo_eval/tasks/olmes_v0_1/README.md
186 |
187 | DCLM Evals: Run `scripts/run_dclm_evals*` and refer to instructions from https://github.com/mlfoundations/dclm
188 |
189 | #### After adaptation
190 |
191 | - Setup https://github.com/allenai/open-instruct/
192 | - Run `sbatch scripts/adapteval.sh` after changing it as necessary, or extract the commands from the script and run them one by one.
193 |
194 | ### Visuals
195 |
196 | - Figure 1, `visuals/figures/overview.pdf`: Run "Main plot" in `scripts/olmoe_visuals.ipynb` equivalent to [this colab](https://colab.research.google.com/drive/15PTwmoxcbrwWKG6ErY44hlJlLLKAj7Hx?usp=sharing) and add the result into this drawing to edit it further: https://docs.google.com/drawings/d/1Of9-IgvKH54zhKI_M4x5HOYEF4XUp6qaXluT3Zmv1vk/edit?usp=sharing (the drawing used for [this tweet](https://x.com/Muennighoff/status/1831159130230587486) is [here](https://docs.google.com/drawings/d/133skqIfE8f7iOO9hMidV5tBZ7MAle28VxEm9U8Q8ioU/edit?usp=sharing))
197 | - Figure 2, `visuals/figures/olmoe.pdf`: https://www.figma.com/design/Es8UpNHKgugMAncPWnSDuK/olmoe?node-id=0-1&t=SeuQKPlaoB12TXqe-1 (also contains some other figures used on Twitter)
198 | - Figure 3 & 25, `visuals/figures/trainingeval*pdf`: Run "During training" in `scripts/olmoe_visuals.ipynb` equivalent to [this colab](https://colab.research.google.com/drive/15PTwmoxcbrwWKG6ErY44hlJlLLKAj7Hx?usp=sharing)
199 | - Figures 4-19, 24, 26-29, `visuals/figures/...pdf`: Run the respective parts in `scripts/olmoe_visuals.ipynb` equivalent to [this colab](https://colab.research.google.com/drive/15PTwmoxcbrwWKG6ErY44hlJlLLKAj7Hx?usp=sharing)
200 | - Figures 20, 21, 23, 30, 31, Table 8, `visuals/figures/...pdf`: `scripts/run_moe_analysis.py` (if you do not want to rerun inference on the models to generate the routing statistics, you can download them from https://huggingface.co/datasets/allenai/analysis_olmoe & https://huggingface.co/datasets/allenai/analysis_mixtral)
201 | - Figures 22, 33-36, `visuals/figures/...pdf`: Run `scripts/run_routing_analysis.py` & then `scripts/plot_routing_analysis_v2.ipynb` / `scripts/plot_routing_analysis_v2_top1.ipynb` / `scripts/plot_routing_analysis_v2_cross_layer.ipynb`
202 | - Figure 32, `visuals/figures/...pdf`: Run `scripts/run_routing_analysis.py` & then `scripts/plot_routing_analysis.ipynb`
203 | - Table 13: `scripts/make_table.py`
204 | - All other tables are manually created.
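
If you just want to poke at the released routing statistics used by the plotting notebooks, a minimal sketch (assumes the pickles under `scripts/routing_output/` hold per-expert counts as written by `scripts/run_routing_analysis.py`):

```python
import pickle

with open("scripts/routing_output/olmoe/expert_counts/c4.pkl", "rb") as f:
    expert_counts = pickle.load(f)

print(type(expert_counts))  # inspect the structure before plotting
```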
205 |
206 | ### Citation
207 |
208 | ```bibtex
209 | @misc{muennighoff2024olmoeopenmixtureofexpertslanguage,
210 | title={OLMoE: Open Mixture-of-Experts Language Models},
211 | author={Niklas Muennighoff and Luca Soldaini and Dirk Groeneveld and Kyle Lo and Jacob Morrison and Sewon Min and Weijia Shi and Pete Walsh and Oyvind Tafjord and Nathan Lambert and Yuling Gu and Shane Arora and Akshita Bhagia and Dustin Schwenk and David Wadden and Alexander Wettig and Binyuan Hui and Tim Dettmers and Douwe Kiela and Ali Farhadi and Noah A. Smith and Pang Wei Koh and Amanpreet Singh and Hannaneh Hajishirzi},
212 | year={2024},
213 | eprint={2409.02060},
214 | archivePrefix={arXiv},
215 | primaryClass={cs.CL},
216 | url={https://arxiv.org/abs/2409.02060},
217 | }
218 | ```
219 |
--------------------------------------------------------------------------------
/scripts/adapteval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=eval
3 | #SBATCH --nodes=1
4 | #SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node!
5 | #SBATCH --partition=a3
6 | #SBATCH --gres=gpu:8 # number of gpus
7 | #SBATCH --time 24:00:00 # maximum execution time (HH:MM:SS)
8 | #SBATCH --output=jobs/%x-%j.out # output file name
9 | #SBATCH --exclusive
10 | #SBATCH --array=0-6%7 # adjusted array size and concurrency
11 |
12 | MODEL_PATH=allenai/OLMoE-1B-7B-0924-Instruct
13 | TOKENIZER_PATH=$MODEL_PATH
14 |
15 | cd ~/open-instruct
16 | conda activate YOURENV
17 |
18 | # Commands array
19 | case $SLURM_ARRAY_TASK_ID in
20 | 0)
21 | python -m eval.mmlu.run_eval \
22 | --ntrain 0 \
23 | --data_dir /data/niklas/data/eval/mmlu/ \
24 | --save_dir ${MODEL_PATH}/eval/mmlu \
25 | --model_name_or_path ${MODEL_PATH} \
26 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
27 | --use_chat_format \
28 | --chat_formatting_function eval.templates.create_prompt_with_huggingface_tokenizer_template \
29 | --eval_batch_size 64
30 | ;;
31 | 1)
32 | python -m eval.gsm.run_eval \
33 | --data_dir /data/niklas/data/eval/gsm/ \
34 | --max_num_examples 200 \
35 | --save_dir ${MODEL_PATH}/eval/gsm8k \
36 | --model_name_or_path ${MODEL_PATH} \
37 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
38 | --n_shot 8 \
39 | --use_chat_format \
40 | --chat_formatting_function eval.templates.create_prompt_with_huggingface_tokenizer_template \
41 | --eval_batch_size 64
42 | ;;
43 | 2)
44 | OPENAI_API_KEY=YOUR_KEY IS_ALPACA_EVAL_2=False python -m eval.alpaca_farm.run_eval \
45 | --model_name_or_path ${MODEL_PATH} \
46 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
47 | --save_dir ${MODEL_PATH}/eval/alpaca \
48 | --use_chat_format \
49 | --chat_formatting_function eval.templates.create_prompt_with_huggingface_tokenizer_template \
50 | --eval_batch_size 128
51 | ;;
52 | 3)
53 | python -m eval.bbh.run_eval \
54 | --data_dir /data/niklas/data/eval/bbh/ \
55 | --save_dir ${MODEL_PATH}/eval/bbh \
56 | --model_name_or_path ${MODEL_PATH} \
57 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
58 | --max_num_examples_per_task 40 \
59 | --use_chat_format \
60 | --chat_formatting_function eval.templates.create_prompt_with_huggingface_tokenizer_template \
61 | --eval_batch_size 64
62 | ;;
63 | 4)
64 | OPENAI_API_KEY=YOUR_KEY python -m eval.xstest.run_eval \
65 | --data_dir /data/niklas/data/eval/xstest/ \
66 | --save_dir ${MODEL_PATH}/eval/xstest \
67 | --model_name_or_path ${MODEL_PATH} \
68 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
69 | --use_chat_format \
70 | --chat_formatting_function eval.templates.create_prompt_with_huggingface_tokenizer_template \
71 | --eval_batch_size 64
72 | ;;
73 | 5)
74 | python -m eval.codex_humaneval.run_eval \
75 | --data_file /data/niklas/data/eval/codex_humaneval/HumanEval.jsonl.gz \
76 | --eval_pass_at_ks 1 5 10 20 \
77 | --unbiased_sampling_size_n 20 \
78 | --temperature 0.8 \
79 | --save_dir ${MODEL_PATH}/eval/humaneval \
80 | --model_name_or_path ${MODEL_PATH} \
81 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
82 | --eval_batch_size 64
83 | ;;
84 | 6)
85 | python -m eval.ifeval.run_eval \
86 | --data_dir /data/niklas/data/eval/ifeval/ \
87 | --save_dir ${MODEL_PATH}/eval/ifeval \
88 | --model_name_or_path ${MODEL_PATH} \
89 | --tokenizer_name_or_path ${TOKENIZER_PATH} \
90 | --use_chat_format \
91 | --chat_formatting_function eval.templates.create_prompt_with_huggingface_tokenizer_template \
92 | --eval_batch_size 64
93 | ;;
94 | *)
95 | echo "Invalid SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"
96 | exit 1
97 | ;;
98 | esac
99 |
--------------------------------------------------------------------------------
/scripts/batchjob.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=llm
3 | #SBATCH --nodes=1
4 | #SBATCH --ntasks-per-node=1
5 | #SBATCH --hint=nomultithread # we get physical cores not logical
6 | #SBATCH --partition=a3mixed
7 | #SBATCH --gres=gpu:1 # number of gpus
8 | #SBATCH --output=/data/niklas/jobs/%x-%j.out # output file name
9 | #SBATCH --exclusive
10 |
11 | ######################
12 | ### Set environment ###
13 | ######################
14 | cd /home/niklas/OLMoE
15 | source /env/bin/start-ctx-user
16 | conda activate llmoe
17 | export WANDB_PROJECT="olmoe"
18 | # Training setup
19 | set -euo pipefail
20 | export GPUS_PER_NODE=1
21 | export NODENAME=$(hostname -s)
22 | export MASTER_ADDR=$(scontrol show hostnames | head -n 1)
23 | export MASTER_PORT=39594
24 | export WORLD_SIZE=$SLURM_NTASKS
25 | export RANK=$SLURM_PROCID
26 | export FS_LOCAL_RANK=$SLURM_PROCID
27 | export LOCAL_WORLD_SIZE=$SLURM_NTASKS_PER_NODE
28 | export LOCAL_RANK=$SLURM_LOCALID
29 | export NODE_RANK=$((($RANK - $LOCAL_RANK) / $LOCAL_WORLD_SIZE))
30 | export R2_PROFILE=r2
31 | export R2_ENDPOINT_URL=YOUR_URL
32 | export AWS_ACCESS_KEY_ID=YOUR_KEY
33 | export AWS_SECRET_ACCESS_KEY=YOUR_KEY
34 | export HF_DATASETS_OFFLINE=1
35 | #export CUDA_LAUNCH_BLOCKING=1
36 | export TRITON_CACHE_DIR=/data/niklas/tcache
37 |
38 | echo "World size: $WORLD_SIZE"
39 | ######################
40 |
41 | ######################
42 | #### Set network #####
43 | ######################
44 | head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
45 | ######################
46 |
47 | python -u scripts/train.py configs/olmoe17/olmoe17-8x1b-fullshard-swiglu-wrapb-ecg-k2.yml \
48 | "--save_folder=/data/niklas/llm/checkpoints/${SLURM_JOB_ID}/" \
49 | --eval_interval=1000 \
50 | --save_interval=5000 \
51 | --global_train_batch_size=16 \
52 | --device_train_microbatch_size=2 \
53 | --save_overwrite=True \
54 | --evaluators=[]
55 |
56 | # srun \
57 | # --distribution=block:block \
58 | # --kill-on-bad-exit \
59 | # scripts/run_with_environment.sh \
60 | # python -u scripts/train.py configs/olmoe17/olmoe-8x1b-newhp-newds-s3.yml \
61 | # "--save_folder=/data/niklas/llm/checkpoints/${SLURM_JOB_ID}/" \
62 | # --save_overwrite \
63 | # --fsdp.sharding_strategy=HYBRID_SHARD \
64 | # --device_train_microbatch_size=2
65 |
--------------------------------------------------------------------------------
/scripts/eval_openlm_ckpt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import builtins as __builtin__
3 | import json
4 | import os
5 | import shutil
6 | import subprocess
7 | import sys
8 | import time
9 | import uuid
10 | from collections import defaultdict
11 | from datetime import datetime
12 | from pathlib import Path
13 | from typing import List
14 |
15 |
16 | sys.path.insert(0, str(Path(__file__).parent.parent))
17 |
18 | import numpy as np
19 | import pandas as pd
20 | import pytz
21 | import random
22 | import torch
23 | from aggregated_metrics import get_aggregated_results
24 | from utils import update_args_from_openlm_config
25 | from composer.loggers import InMemoryLogger, LoggerDestination
26 | from composer.trainer import Trainer
27 | from composer.utils import dist, get_device, reproducibility
28 | from llmfoundry.utils.builders import build_icl_evaluators, build_logger
29 | from omegaconf import OmegaConf as om
30 |
31 | from open_lm.hf import *
32 |
33 | from open_lm.attention import ATTN_ACTIVATIONS, ATTN_SEQ_SCALARS
34 | from open_lm.data import get_data
35 | from open_lm.distributed import init_distributed_device, is_master, world_info_from_env
36 | from open_lm.model import create_params
37 | from open_lm.main import load_model
38 | from open_lm.evaluate import evaluate_loop
39 | from open_lm.file_utils import pt_load
40 | from open_lm.utils.llm_foundry_wrapper import SimpleComposerOpenLMCausalLM
41 | from open_lm.utils.transformers.hf_config import OpenLMConfig
42 | from open_lm.utils.transformers.hf_model import OpenLMforCausalLM
43 | from pytz import timezone
44 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GPTNeoXTokenizerFast, LlamaTokenizerFast
45 |
46 | from training.file_utils import download_val_data, load_ppl_yaml
47 |
48 | builtin_print = __builtin__.print
49 |
50 | hf_access_token = open("hf_access_token.txt").read().strip() if os.path.exists("hf_access_token.txt") else None  # token only needed for gated models
51 |
52 | def setup_for_distributed(is_master):
53 | def print(*args, **kwargs):
54 | force = kwargs.pop("force", False)
55 | if is_master or force:
56 | builtin_print(*args, **kwargs)
57 |
58 | __builtin__.print = print
59 |
60 |
61 | def convert_gpqa(gpqa_dir, outdir):
62 | os.makedirs(outdir, exist_ok=True)
63 | for filename in ["gpqa_main.csv", "gpqa_diamond.csv", "gpqa_extended.csv"]:
64 | inpath = os.path.join(gpqa_dir, filename)
65 | outpath = os.path.join(outdir, Path(inpath).with_suffix(".jsonl").name)
66 | with open(outpath, "w") as f:
67 | df = pd.read_csv(inpath)
68 | rng = random.Random(42)
69 | for i in range(df.shape[0]):
70 | question = df["Question"][i]
71 | choices = np.array(
72 | [
73 | df["Correct Answer"][i],
74 | df["Incorrect Answer 1"][i],
75 | df["Incorrect Answer 2"][i],
76 | df["Incorrect Answer 3"][i],
77 | ]
78 | )
79 | idx = list(range(4))
80 | rng.shuffle(idx)
81 | choices = choices[idx]
82 | gold = idx.index(0)
83 | data = {"query": question, "choices": choices.tolist(), "gold": gold}
84 | f.write(json.dumps(data))
85 | f.write("\n")
86 |
87 | return
88 |
89 |
90 | def check_and_download_data():
91 | if not os.path.exists("local_data"):
92 | current_dir = os.path.dirname(os.path.realpath(__file__))
93 |
94 | if os.path.exists(f"{current_dir}/local_data"):
95 | shutil.copytree(f"{current_dir}/local_data", "local_data")
96 | else:
97 | if dist.get_global_rank() == 0:
98 | print("local_data folder does not exist. Running bash script...")
99 | script_path = os.path.join(current_dir, "download_eval_data.sh")
100 |
101 | subprocess.call([script_path])
102 |
103 | else:
104 | # Let other workers sleep a bit before barrier.
105 | time.sleep(10)
106 | dist.barrier()
107 |
108 | if not os.path.exists("gpqa_data"):
109 | repo_dir = os.path.dirname(os.path.realpath(__file__))
110 | if dist.get_global_rank() == 0:
111 | subprocess.run(
112 | [
113 | "unzip",
114 | "-P",
115 | "deserted-untie-orchid",
116 | os.path.join(repo_dir, "gpqa/dataset.zip"),
117 | "-d",
118 | "gpqa_data_orig/",
119 | ],
120 | check=True,
121 | stdout=subprocess.PIPE,
122 | stderr=subprocess.PIPE,
123 | )
124 | convert_gpqa("gpqa_data_orig/dataset", "gpqa_data")
125 | shutil.rmtree("gpqa_data_orig")
126 | else:
127 | time.sleep(10)
128 | dist.barrier()
129 |
130 | print("Done downloading data.")
131 | return
132 |
133 |
134 | @torch.no_grad()
135 | def evaluate(model, tokenizer, cfg):
136 | cfg.dist_timeout = cfg.get("dist_timeout", 600.0)
137 |
138 | reproducibility.seed_all(cfg.seed)
139 | dist.initialize_dist(get_device(None), timeout=cfg.dist_timeout)
140 | setup_for_distributed(dist.get_global_rank() == 0)
141 |
142 | # Check if the data is downloaded, if not, download it.
143 | check_and_download_data()
144 |
145 | composer_model = SimpleComposerOpenLMCausalLM(model, tokenizer)
146 |
147 | icl_tasks_w_categories = list(
148 | filter(lambda x: 0 if "has_categories" not in x else x["has_categories"], cfg.icl_tasks)
149 | )
150 | icl_tasks_w_categories = list(map(lambda x: x["label"], icl_tasks_w_categories))
151 |
152 | evaluators, logger_keys = build_icl_evaluators(
153 | cfg.icl_tasks, tokenizer, cfg.max_seq_len, cfg.device_eval_batch_size,
154 | # icl_subset_num_batches=cfg.eval_subset_num_batches
155 | )
156 |
157 | in_memory_logger = InMemoryLogger() # track metrics in the in_memory_logger
158 | loggers: List[LoggerDestination] = [
159 | build_logger(name, logger_cfg) for name, logger_cfg in (cfg.get("loggers") or {}).items()
160 | ]
161 | loggers.append(in_memory_logger)
162 |
163 | fsdp_config = None
164 | fsdp_config = om.to_container(fsdp_config, resolve=True) if fsdp_config is not None else None
165 |
166 | load_path = cfg.get("load_path", None)
167 |
168 | trainer = Trainer(
169 | model=composer_model,
170 | loggers=loggers,
171 | precision=cfg.precision,
172 | fsdp_config=fsdp_config, # type: ignore
173 | load_path=load_path,
174 | load_weights_only=True,
175 | progress_bar=False,
176 | log_to_console=True,
177 | dist_timeout=cfg.dist_timeout,
178 | # eval_subset_num_batches=cfg.eval_subset_num_batches
179 | )
180 |
181 | if torch.cuda.is_available():
182 | torch.cuda.synchronize()
183 | a = time.time()
184 | trainer.eval(
185 | eval_dataloader=evaluators,
186 | # subset_num_batches=cfg.eval_subset_num_batches
187 | )
188 |
189 | if torch.cuda.is_available():
190 | torch.cuda.synchronize()
191 | b = time.time()
192 |
193 | print(f"Ran eval in: {b-a} seconds")
194 |
195 | performance_on_tasks = defaultdict(list)
196 | for key in logger_keys:
197 | if key in in_memory_logger.data:
198 | result = in_memory_logger.data[key][0][1].item()
199 | flag = True
200 | if len(icl_tasks_w_categories) > 0:
201 | for task in icl_tasks_w_categories:
202 | if task in key:
203 | performance_on_tasks[task].append(result)
204 | flag = False
205 | if flag:
206 | performance_on_tasks[key].append(result)
207 |
208 | report_results = {}
209 | for task in performance_on_tasks:
210 | result = sum(performance_on_tasks[task]) / len(performance_on_tasks[task])
211 | if len(task.split("/")) > 1:
212 | label = task.split("/")[1]
213 | report_results[label] = result
214 | else:
215 | report_results[task] = result
216 | print(report_results)
217 | return report_results
218 |
219 |
220 | def set_args_for_val(args, data, key):
221 | setattr(args, "val_data", data)
222 | setattr(args, "val_data_key", key)
223 | setattr(args, "squash_mask_left", True)
224 | setattr(args, "target_mask_individual", 50400)
225 | setattr(args, "target_mask_left", 50300)
226 | setattr(args, "val_seq_ci", True)
227 | setattr(args, "val_tok_ci", True)
228 | return args
229 |
230 |
231 | def main():
232 | """
233 | Usage:
234 | python eval_openlm_ckpt.py --checkpoint --model --eval-yaml --tokenizer
235 | example:
236 | cd eval
237 | python eval_openlm_ckpt.py --checkpoint ../checkpoints/llama2_7b.pt --model llama2_7b.json --eval-yaml in_memory_hf_eval.yaml --tokenizer
238 | multi-gpu example:
239 | cd eval
240 | torchrun --nproc_per_node 3 eval_openlm_ckpt.py --checkpoint ../checkpoints/llama2_7b.pt --model llama2_7b.json --eval-yaml in_memory_hf_eval.yaml --tokenizer
241 |
242 | torchrun --nproc_per_node 3 eval_openlm_ckpt.py --checkpoint checkpoint.pt --config params.txt
243 | """
244 | parser = argparse.ArgumentParser()
245 | # Arguments that openlm requires when we call load_model
246 | parser.add_argument(
247 | "--seed",
248 | type=int,
249 | default=None,
250 | help="Seed for reproducibility, when None, will use the seed from the eval config file.",
251 | )
252 | parser.add_argument("--fsdp", default=False, action="store_true")
253 | parser.add_argument("--distributed", default=True, action="store_true")
254 | parser.add_argument("--resume", default=None, type=str)
255 |
256 | # Argument for uploading results
257 | parser.add_argument("--remote-sync", type=str, default=None)
258 | parser.add_argument("--remote-sync-protocol", type=str, default="s3", choices=["s3", "fsspec"])
259 |
260 | parser.add_argument("--checkpoint", default=None, type=str, help="Path to checkpoint to evaluate.")
261 | parser.add_argument("--eval-yaml", type=str, default="light.yaml")
262 | parser.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b")
263 | parser.add_argument("--config", default=None, type=str)
264 | parser.add_argument("--model", default=None, type=str)
265 | parser.add_argument("--hf-model", default=None)
266 |
267 | parser.add_argument("--hf-cache-dir", default=None)
268 | parser.add_argument("--output-file", type=str, default=None)
269 | parser.add_argument(
270 | "--use-temp-working-dir",
271 | action="store_true",
272 | help="Use a temporary working directory for the evaluation. removing it when done. "
273 | "This is required if you wish to run multiple evaluations with the same datasets"
274 | " in parallel on the same node.",
275 | )
276 | parser.add_argument(
277 | "--eval_meta_data", default=f"{os.path.dirname(__file__)}/eval_meta_data.csv", help="Eval meta data file"
278 | )
279 | parser.add_argument(
280 | "--preset-world-size",
281 | type=int,
282 | default=None,
283 | help="Explicitly set the world size. Useful in cases where a different number of gpus per node need to be used.",
284 | )
285 | parser.add_argument(
286 | "--additional_aggregation",
287 | default=f"{os.path.dirname(__file__)}/additional_aggregation.json",
288 | help="Eval aggregation file",
289 | )
290 | parser.add_argument(
291 | "--val-data",
292 | type=str,
293 | default=None,
294 | help="val data for perplexity calc",
295 | )
296 | parser.add_argument(
297 | "--moe-freq",
298 | type=int,
299 | default=0,
300 | help="if set > 0, we will add MoE layer to every moe_freq layer.",
301 | )
302 | parser.add_argument(
303 | "--moe-num-experts",
304 | type=int,
305 | default=None,
306 | help="Number of experts for MoE",
307 | )
308 |
309 | parser.add_argument(
310 | "--moe-weight-parallelism",
311 | action="store_true",
312 | help="Add weight parallelism to MoE",
313 | )
314 |
315 | parser.add_argument(
316 | "--moe-expert-model-parallelism",
317 | action="store_true",
318 | help="Add expert model parallelism to MoE",
319 | )
320 |
321 | parser.add_argument(
322 | "--moe-capacity-factor",
323 | type=float,
324 | default=1.25,
325 | help="MoE capacity factor",
326 | )
327 |
328 | parser.add_argument(
329 | "--moe-loss-weight",
330 | type=float,
331 | default=0.1,
332 | help="MoE loss weight",
333 | )
334 | parser.add_argument(
335 | "--moe-top-k",
336 | type=int,
337 | default=2,
338 | help="MoE top k experts",
339 | )
340 | parser.add_argument(
341 | "--attn-name",
342 | type=str,
343 | default="xformers_attn",
344 | choices=["xformers_attn", "torch_attn", "custom_attn"],
345 | help="type of attention to use",
346 | )
347 | parser.add_argument(
348 | "--attn-activation",
349 | type=str,
350 | default=None,
351 | choices=list(ATTN_ACTIVATIONS.keys()),
352 | help="activation to use with custom_attn",
353 | )
354 | parser.add_argument(
355 | "--attn-seq-scalar",
356 | type=str,
357 | default=None,
358 | choices=list(ATTN_SEQ_SCALARS.keys()),
359 | help="different ways to set L, where L^alpha divides attention logits post activation",
360 | )
361 | parser.add_argument(
362 | "--attn-seq-scalar-alpha",
363 | type=float,
364 | default=None,
365 | help="power alpha to raise L to, where L^alpha divides attention logits post activation",
366 | )
367 | parser.add_argument(
368 | "--val-max-pop-ci",
369 | default=None,
370 | action="store",
371 | type=int,
372 | help="when running CIs what is the maximum population size for the inner loop",
373 | )
374 | parser.add_argument(
375 | "--val-iter-ci",
376 | default=10_000,
377 | action="store",
378 | type=int,
379 | help="how many times to sample to construct the CI for the outer loop",
380 | )
381 | parser.add_argument("--averager-name", help="If specified, load this averager from checkpoint.")
382 |
383 | parser.add_argument("--donot-compute-perplexity", action="store_true")
384 | parser.add_argument("--compute-downstream-perplexity", action="store_true")
385 | parser.add_argument("--compute-paloma-perplexity", action="store_true")
386 | parser.add_argument("--force-xformers", action="store_true")
387 |
388 | parser.add_argument("--num_experts_per_tok", type=int, default=None)
389 |
390 | args = parser.parse_args()
391 |
392 | orig_seed = args.seed # may be overridden by config file if it exists
393 |
394 |     if args.output_file is not None and os.path.exists(args.output_file):
395 |         print(f"{args.output_file} exists!")
396 |         return
397 |
398 | if args.config is not None:
399 |         assert args.hf_model is None, (
400 |             "If you are using a config file, "
401 |             "you are trying to evaluate an open_lm model, so please remove the hf-model argument."
402 |         )
403 |
404 | update_args_from_openlm_config(args)
405 | # disable wandb for eval
406 | args.wandb = None
407 | else:
408 | # Most probably evaling a hf-model.
409 |
410 |         assert args.hf_model, (
411 |             "If you are not using a config file, you might want to evaluate a Hugging Face model, "
412 |             "so please provide the hf-model argument."
413 |         )
414 | # Computing perplexity for HF model doesn't make sense.
415 | args.donot_compute_perplexity = True
416 |
417 |         # Setting these params as they are needed for distributed evals
418 |         # and they are supposed to come from the config file.
419 | args.dist_backend = "nccl"
420 | args.dist_url = "env://"
421 | args.no_set_device_rank = False
422 | args.model = args.hf_model
423 | args.force_distributed = False
424 | with open(args.eval_yaml) as f:
425 | eval_cfg = om.load(f)
426 | if orig_seed is not None:
427 | print(f"Overriding eval config seed ({eval_cfg.seed}) to {orig_seed}")
428 | eval_cfg.seed = orig_seed
429 |
430 | # now need to set the 'fewshot_random_seed' in each config in the icl task configs
431 | for icl_cfg in eval_cfg.icl_tasks:
432 | icl_cfg.fewshot_random_seed = orig_seed
433 |
434 | args.resume = args.checkpoint
435 | args.remote_sync = args.output_file
436 | directory = os.path.dirname(args.output_file)
437 | if directory != "" and not os.path.exists(directory):
438 | os.makedirs(directory)
439 |
440 | CWD = os.getcwd()
441 | if args.use_temp_working_dir:
442 | temp_dir = os.path.join(CWD, "eval_openlm_ckpt_temp_dirs", f"{uuid.uuid4()}")
443 | os.makedirs(temp_dir, exist_ok=True) # in case rank > 0
444 | os.chdir(temp_dir)
445 | print(f"Using temporary working directory: {temp_dir}")
446 |
447 | print("Loading model into the right classes")
448 | if args.hf_model is not None:
449 |
450 | if args.num_experts_per_tok:
451 | config = AutoConfig.from_pretrained(
452 | args.hf_model, token=hf_access_token, trust_remote_code=True, cache_dir=args.hf_cache_dir)
453 | config.num_experts_per_tok = args.num_experts_per_tok
454 | eval_model = AutoModelForCausalLM.from_pretrained(
455 | args.hf_model, config=config, token=hf_access_token, trust_remote_code=True, cache_dir=args.hf_cache_dir
456 | )
457 | assert eval_model.num_experts_per_tok == args.num_experts_per_tok
458 |
459 | else:
460 | eval_model = AutoModelForCausalLM.from_pretrained(
461 | args.hf_model, token=hf_access_token, trust_remote_code=True, cache_dir=args.hf_cache_dir
462 | )
463 | else:
464 | params = create_params(args)
465 | eval_model = OpenLMforCausalLM(OpenLMConfig(params))
466 |
467 | if "gpt-neox-20b" in args.tokenizer:
468 | tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
469 | elif "llama" in args.tokenizer:
470 | tokenizer = LlamaTokenizerFast.from_pretrained(args.tokenizer)
471 | if len(tokenizer) > eval_model.config.vocab_size: # happens in llama-3-8b
472 | print(f"Resizing vocab from {eval_model.config.vocab_size} to {len(tokenizer)}")
473 | eval_model.resize_token_embeddings(len(tokenizer))
474 | else:
475 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, token=hf_access_token, trust_remote_code=True, cache_dir=args.hf_cache_dir)
476 |
477 | if args.checkpoint is not None:
478 | if not args.averager_name:
479 | print(f"Loading checkpoint {args.checkpoint}")
480 | args.distributed = False
481 | load_model(args, eval_model.model, different_seed=True)
482 | args.distributed = True
483 | else:
484 | print(f"Loading checkpoint {args.checkpoint}")
485 | checkpoint = pt_load(args.resume, map_location="cpu")
486 | if "epoch" in checkpoint:
487 | # resuming a train checkpoint w/ epoch and optimizer state
488 | start_epoch = checkpoint["epoch"]
489 | avg_sd = torch.load(args.checkpoint, map_location="cpu")
490 | if next(iter(avg_sd.items()))[0].startswith("module"):
491 | avg_sd = {k[len("module.") :]: v for k, v in avg_sd.items()}
492 | eval_model.model.load_state_dict(avg_sd)
493 |
494 | # HF model loaded with from_pretrained is by default in eval mode.
495 | # https://github.com/huggingface/transformers/blob/ebfdb9ca62205279d5019ef1403877461b3b2da4/src/transformers/modeling_utils.py#L2500
496 | eval_model.model.eval()
497 |
498 | # Set requires_grad = False to reduce memory consumption - otherwise Composer makes a copy of the model.
499 | for p in eval_model.parameters():
500 | p.requires_grad = False
501 |
502 | device = init_distributed_device(args)
503 | eval_model = eval_model.to(device)
504 | eval_metrics = {}
505 |
506 | local_rank, _, _ = world_info_from_env()
507 |
508 | if not args.donot_compute_perplexity:
509 | args.per_gpu_val_batch_size = args.per_gpu_batch_size // args.accum_freq
510 | openlm_val_data = download_val_data("open_lm_val", skip_download=local_rank != 0)
511 | args = set_args_for_val(args, [openlm_val_data], ["json"])
512 | data = get_data(args, epoch=0, tokenizer=None, skip_train=True)
513 | results = evaluate_loop(eval_model.model, data["val_list"], 0, args, None)
514 | perplexity_val = results[0]["loss"]
515 | eval_metrics["perplexity"] = perplexity_val
516 |
517 | if args.compute_paloma_perplexity:
518 | args.per_gpu_val_batch_size = args.per_gpu_batch_size // args.accum_freq
519 | paloma_val_data = download_val_data("paloma_val", skip_download=local_rank != 0)
520 | args = set_args_for_val(args, [paloma_val_data], ["json.gz"])
521 | data = get_data(args, epoch=0, tokenizer=None, skip_train=True)
522 | results = evaluate_loop(eval_model.model, data["val_list"], 0, args, None)
523 | perplexity_val = results[0]["loss"]
524 | eval_metrics["paloma_perplexity"] = perplexity_val
525 |
526 | if args.compute_downstream_perplexity:
527 | args.per_gpu_val_batch_size = args.per_gpu_batch_size // args.accum_freq
528 | size = args.eval_yaml[:-5]
529 | tasks = load_ppl_yaml(size)
530 | downstream_datas = [
531 | download_val_data(task_name, skip_download=local_rank != 0)
532 | for task_name in tasks
533 | if "gpqa" not in task_name
534 | ]
535 | args = set_args_for_val(args, downstream_datas, ["txt"] * len(downstream_datas))
536 | data = get_data(args, epoch=0, tokenizer=None, skip_train=True)
537 |
538 | results = evaluate_loop(eval_model.model, data["val_list"], 0, args, None)
539 | eval_metrics["downstream_perplexity"] = {}
540 | for result in results:
541 | data_name = result["val_data"][0].split("/")[-2]
542 | eval_metrics["downstream_perplexity"][data_name] = result["loss"]
543 |
544 | icl_results = evaluate(eval_model, tokenizer, eval_cfg)
545 | eval_metrics["icl"] = icl_results
546 |
547 | date_format = "%Y_%m_%d-%H_%M_%S"
548 | date = datetime.now(tz=pytz.utc)
549 | date = date.astimezone(timezone("US/Pacific"))
550 | date = date.strftime(date_format)
551 |
552 | output = {
553 | "name": str(args.eval_yaml)[:-5],
554 | "uuid": str(uuid.uuid4()),
555 | "model": args.model,
556 | "creation_date": date,
557 | "eval_metrics": eval_metrics,
558 | }
559 |
560 | with open(args.additional_aggregation, "r") as f:
561 | aggregation_json = json.load(f)
562 |
563 | eval_metadata = pd.read_csv(args.eval_meta_data)
564 |
565 | output = get_aggregated_results(output, eval_metadata, aggregation_json)
566 | print("Eval output: ")
567 | print(json.dumps(output, indent=4, sort_keys=True))
568 | if local_rank == 0:
569 | if args.use_temp_working_dir:
570 | print(f"Removing temporary working directory: {temp_dir} amd changing back to {CWD}")
571 | shutil.rmtree(temp_dir)
572 | os.chdir(CWD) # need to change back BEFORE we save the output file
573 |
574 | with open(args.output_file, "w") as f:
575 | json.dump(output, f, indent=4)
576 |
577 | return output
578 |
579 |
580 | if __name__ == "__main__":
581 | main()
582 |
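A note on the averaged-checkpoint branch above: checkpoints saved from a torch DistributedDataParallel-wrapped model prefix every state-dict key with "module.", which the code strips before calling load_state_dict. A minimal standalone sketch of that normalization (the helper name strip_ddp_prefix is hypothetical, and the all-keys check slightly generalizes the first-key check used above):

    def strip_ddp_prefix(state_dict):
        # DDP wraps the model, so saved keys look like "module.layer.weight".
        # Strip the prefix so the keys match an unwrapped model.
        if all(k.startswith("module.") for k in state_dict):
            return {k[len("module."):]: v for k, v in state_dict.items()}
        return state_dict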
--------------------------------------------------------------------------------
/scripts/humaneval.yaml:
--------------------------------------------------------------------------------
1 | epoch: 1.25T
2 | dataset: bigdata
3 | num_params: 1B
4 | max_seq_len: 2048
5 | seed: 1
6 | precision: fp32
7 |
8 | # Tokenizer
9 | tokenizer:
10 | # name: [Add name from memory]
11 | pretrained_model_name_or_path:
12 | kwargs:
13 | model_max_length: 2048
14 |
15 | model:
16 | name: open_lm
17 | # pretrained_model_name_or_path: [add name from memory]
18 | init_device: cpu
19 | pretrained: true
20 |
21 | load_path: # Add your (optional) Composer checkpoint path here!
22 |
23 | device_eval_batch_size: 2
24 |
25 | # FSDP config for model sharding
26 | fsdp_config:
27 | sharding_strategy: FULL_SHARD
28 | mixed_precision: FULL
29 |
30 |
31 | icl_tasks:
32 | -
33 | label: human_eval
34 | dataset_uri: local_data/programming/human_eval.jsonl # ADD YOUR OWN DATASET URI
35 | num_fewshot: [0]
36 | pass_at_k: 1
37 | batch_size: 1
38 | icl_task_type: code_evaluation
39 | generation_kwargs:
40 | num_beams: 5
41 | -
42 | label: human_eval_cpp
43 | dataset_uri: local_data/programming/processed_human_eval_cpp.jsonl # ADD YOUR OWN DATASET URI
44 | num_fewshot: [0]
45 | pass_at_k: 1
46 | batch_size: 1
47 | icl_task_type: code_evaluation
48 | generation_kwargs:
49 | num_beams: 5
50 | -
51 | label: human_eval_js
52 | dataset_uri: local_data/programming/processed_human_eval_js.jsonl # ADD YOUR OWN DATASET URI
53 | num_fewshot: [0]
54 | pass_at_k: 1
55 | batch_size: 1
56 | icl_task_type: code_evaluation
57 | generation_kwargs:
58 | num_beams: 5
59 | -
60 | label: human_eval_return_simple
61 | dataset_uri: local_data/programming/human_eval_return_simple.jsonl # ADD YOUR OWN DATASET URI
62 | num_fewshot: [0]
63 | pass_at_k: 1
64 | batch_size: 1
65 | icl_task_type: code_evaluation
66 | generation_kwargs:
67 | num_beams: 5
68 | -
69 | label: human_eval_return_complex
70 | dataset_uri: local_data/programming/human_eval_return_complex.jsonl # ADD YOUR OWN DATASET URI
71 | num_fewshot: [0]
72 | pass_at_k: 1
73 | batch_size: 1
74 | icl_task_type: code_evaluation
75 | generation_kwargs:
76 | num_beams: 5
77 | -
78 | label: human_eval_25
79 | dataset_uri: local_data/programming/human_eval-0.25.jsonl # ADD YOUR OWN DATASET URI
80 | num_fewshot: [0]
81 | pass_at_k: 1
82 | batch_size: 1
83 | icl_task_type: code_evaluation
84 | generation_kwargs:
85 | num_beams: 5
86 | -
87 | label: human_eval_50
88 | dataset_uri: local_data/programming/human_eval-0.5.jsonl # ADD YOUR OWN DATASET URI
89 | num_fewshot: [0]
90 | pass_at_k: 1
91 | batch_size: 1
92 | icl_task_type: code_evaluation
93 | generation_kwargs:
94 | num_beams: 5
95 | -
96 | label: human_eval_75
97 | dataset_uri: local_data/programming/human_eval-0.75.jsonl # ADD YOUR OWN DATASET URI
98 | num_fewshot: [0]
99 | pass_at_k: 1
100 | batch_size: 1
101 | icl_task_type: code_evaluation
102 | generation_kwargs:
103 | num_beams: 5
104 |
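Every entry under icl_tasks differs only in its label and dataset_uri, so a config like this can also be generated programmatically. A sketch that reproduces the structure above (task names and paths mirror the YAML; PyYAML is assumed available):

    import yaml

    tasks = [
        ("human_eval", "human_eval.jsonl"),
        ("human_eval_cpp", "processed_human_eval_cpp.jsonl"),
        ("human_eval_js", "processed_human_eval_js.jsonl"),
        ("human_eval_return_simple", "human_eval_return_simple.jsonl"),
        ("human_eval_return_complex", "human_eval_return_complex.jsonl"),
        ("human_eval_25", "human_eval-0.25.jsonl"),
        ("human_eval_50", "human_eval-0.5.jsonl"),
        ("human_eval_75", "human_eval-0.75.jsonl"),
    ]
    icl_tasks = [
        {"label": label,
         "dataset_uri": f"local_data/programming/{fname}",
         "num_fewshot": [0], "pass_at_k": 1, "batch_size": 1,
         "icl_task_type": "code_evaluation",
         "generation_kwargs": {"num_beams": 5}}
        for label, fname in tasks
    ]
    print(yaml.safe_dump({"icl_tasks": icl_tasks}, sort_keys=False))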
--------------------------------------------------------------------------------
/scripts/llm1b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=llm
3 | #SBATCH --nodes=8
4 | #SBATCH --ntasks-per-node=8
5 | #SBATCH --hint=nomultithread # we get physical cores not logical
6 | #SBATCH --partition=a3
7 | #SBATCH --gres=gpu:8 # number of gpus
8 | #SBATCH --output=/data/niklas/jobs/%x-%j.out # output file name
9 | #SBATCH --exclusive
10 |
11 | ######################
12 | ### Set environment ###
13 | ######################
14 | cd /home/niklas/OLMo
15 | source /env/bin/start-ctx-user
16 | conda activate llm
17 | export WANDB_PROJECT="olmo-small"
18 | # Training setup
19 | set -euo pipefail
20 | export GPUS_PER_NODE=8
21 | export NODENAME=$(hostname -s)
22 | export MASTER_ADDR=$(scontrol show hostnames | head -n 1)
23 | export MASTER_PORT=39594
24 | export WORLD_SIZE=$SLURM_NTASKS
25 | export RANK=$SLURM_PROCID
26 | export FS_LOCAL_RANK=$SLURM_PROCID
27 | export LOCAL_WORLD_SIZE=$SLURM_NTASKS_PER_NODE
28 | export LOCAL_RANK=$SLURM_LOCALID
29 | export NODE_RANK=$((($RANK - $LOCAL_RANK) / $LOCAL_WORLD_SIZE))
30 | export R2_PROFILE=r2
31 | export R2_ENDPOINT_URL=XXX
32 | export AWS_ACCESS_KEY_ID=XXX
33 | export AWS_SECRET_ACCESS_KEY=XXX
34 | export HF_DATASETS_OFFLINE=1
35 | echo "World size: $WORLD_SIZE"
36 | ######################
37 |
38 | ######################
39 | #### Set network #####
40 | ######################
41 | head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
42 | ######################
43 |
44 | srun \
45 | --distribution=block:block \
46 | --kill-on-bad-exit \
47 | scripts/run_with_environment.sh \
48 | python -u scripts/train.py configs/mitchish1-s3.yaml \
49 | --activation_checkpointing=fine_grained \
50 | --canceled_check_interval=50 \
51 | --gen1_gc_interval=1 \
52 | --device_train_microbatch_size=8 \
53 | --global_train_batch_size=512 \
54 | --run_name=mitchish1 \
55 | --wandb.group=mitchish1 \
56 | --model.flash_attention=true \
57 | --fsdp.wrapping_strategy=null \
58 | --fsdp.sharding_strategy=SHARD_GRAD_OP \
59 | --fused_loss=true \
60 | '--load_path=${path.last_checkpoint:${remote_save_folder}}' \
61 | "--save_folder=/data/niklas/llm/checkpoints/${SLURM_JOB_ID}/"
62 |
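The NODE_RANK line above recovers the node index from the Slurm-provided global rank, local rank, and local world size. A quick Python check of that identity for the 8-node x 8-GPU layout this job requests (purely illustrative):

    LOCAL_WORLD_SIZE = 8  # matches --ntasks-per-node=8
    for rank in range(64):  # 8 nodes x 8 tasks
        local_rank = rank % LOCAL_WORLD_SIZE
        node_rank = (rank - local_rank) // LOCAL_WORLD_SIZE
        assert node_rank == rank // LOCAL_WORLD_SIZE
    print("NODE_RANK arithmetic holds for all 64 ranks")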
--------------------------------------------------------------------------------
/scripts/make_table.py:
--------------------------------------------------------------------------------
1 | """
2 | Make DCLM results table.
3 | """
4 |
5 | from pathlib import Path
6 | import json
7 | import pandas as pd
8 | import copy
9 |
10 |
11 | result_dir = Path("results/dclm")
12 |
13 |
14 | model_names = [
15 | "OLMoE-7B-A1B-main",
16 | "OLMoE-7B-A1B-step1220000-tokens5117B",
17 | "OLMoE-7B-A1B-step1223842-tokens5100B",
18 | "OLMo-1B-0724-hf"
19 | # "OLMo-7B-0724-hf.json", # Uncomment this once evals are done.
20 | ]
21 |
22 | eval_settings = ["heavy", "humaneval"]
23 |
24 | models_lookup = {
25 | "OLMoE-7B-A1B-main": "OLMoE-1B-7B",
26 | "OLMoE-7B-A1B-step1220000-tokens5117B": "OLMoE-1B-7B step 1,220,000",
27 | "OLMoE-7B-A1B-step1223842-tokens5100B": "OLMoE-1B-7B step 1,223,842",
28 | "OLMo-1B-0724-hf": "OLMo-1B",
29 | # "OLMo-7B-0724-hf": "OLMo-7B", # Uncomment this once evals are done.
30 | }
31 |
32 | metrics_lookup = {
33 | "agi_eval_lsat_ar": "AGI Eval LSAT-AR$^*$",
34 | "agi_eval_lsat_lr": "AGI Eval LSAT-LR",
35 | "agi_eval_lsat_rc": "AGI Eval LSAT-RC",
36 | "agi_eval_sat_en": "AGI Eval SAT-En",
37 | "agi_eval_sat_math_cot": "AGI Eval SAT-Math CoT",
38 | "aqua_cot": "AQuA CoT",
39 | "arc_challenge": "ARC Challenge$^*$",
40 | "arc_easy": "ARC Easy$^*$",
41 | "bbq": "BBQ",
42 | "bigbench_conceptual_combinations": "BigBench Conceptual Combinations",
43 | "bigbench_conlang_translation": "BigBench Conlang Translation",
44 | "bigbench_cs_algorithms": "BigBench CS Algorithms$^*$",
45 | "bigbench_dyck_languages": "BigBench Dyck Languages$^*$",
46 | "bigbench_elementary_math_qa": "BigBench Elementary Math QA",
47 | "bigbench_language_identification": "BigBench Language Identification$^*$",
48 | "bigbench_logical_deduction": "BigBench Logical Deduction",
49 | "bigbench_misconceptions": "BigBench Misconceptions",
50 | "bigbench_novel_concepts": "BigBench Novel Concepts",
51 | "bigbench_operators": "BigBench Operators$^*$",
52 | "bigbench_qa_wikidata": "BigBench QA Wikidata$^*$",
53 | "bigbench_repeat_copy_logic": "BigBench Repeat Copy Logic$^*$",
54 | "bigbench_strange_stories": "BigBench Strange Stories",
55 | "bigbench_strategy_qa": "BigBench Strategy QA",
56 | "bigbench_understanding_fables": "BigBench Understanding Fables",
57 | "boolq": "BoolQ$^*$",
58 | "commonsense_qa": "CommonsenseQA$^*$",
59 | "copa": "COPA$^*$",
60 | "coqa": "CoQA$^*$",
61 | "enterprise_pii_classification": "Enterprise PII Classification",
62 | "gpqa_diamond": "GPQA Diamond",
63 | "gpqa_main": "GPQA Main",
64 | "gsm8k_cot": "GSM8K CoT",
65 | "hellaswag": "HellaSwag 10-shot$^*$",
66 | "hellaswag_zeroshot": "HellaSwag 0-shot$^*$",
67 | "jeopardy": "Jeopardy$^*$",
68 | "lambada_openai": "LAMBADA$^*$",
69 | "logi_qa": "LogiQA",
70 | "math_qa": "Math QA",
71 | "mmlu_fewshot": "MMLU Few-shot",
72 | "mmlu_zeroshot": "MMLU Zero-shot",
73 | "openbook_qa": "OpenBookQA$^*$",
74 | "piqa": "PIQA$^*$",
75 | "pubmed_qa_labeled": "PubMedQA",
76 | "simple_arithmetic_nospaces": "Simple Arithmetic, no spaces",
77 | "simple_arithmetic_withspaces": "Simple Arithmetic, with spaces",
78 | "siqa": "Social IQA",
79 | "squad": "SQuAD$^*$",
80 | "svamp_cot": "SVAMP CoT",
81 | "triviaqa_sm_sub": "Trivia QA",
82 | "winogender_mc_female": "Winogender Female",
83 | "winogender_mc_male": "Winogender Male",
84 | "winograd": "Winograd$^*$",
85 | "winogrande": "Winogrande$^*$",
86 | "core": "Core",
87 | "extended": "Extended",
88 | }
89 |
90 | res = {}
91 |
92 | for model_name in model_names:
93 | data = json.load(open(result_dir / f"heavy-{model_name}.json"))
94 | to_add = copy.deepcopy(data["eval_metrics"]["icl"])
95 | to_update = {
96 | "core": data["low_variance_datasets_centered"],
97 | "extended": data["aggregated_centered_results"],
98 | }
99 | to_add.update(to_update)
100 | res[model_name] = to_add
101 |
102 | res = pd.DataFrame(res) * 100
103 |
104 |
105 | # Replace the columns in res according to the `models_lookup` dict.
106 | res.columns = [models_lookup.get(col, col) for col in res.columns]
107 |
108 | # Replace the rows in res according to the `metrics_lookup` dict.
109 | res.index = [metrics_lookup.get(row, row) for row in res.index]
110 |
111 | res = res.reindex(sorted(res.index.drop(["Core", "Extended"])) + ["Core", "Extended"])
112 |
113 | res.to_csv("results/dclm-table.tsv", sep="\t", float_format="%.1f")
114 | res.to_latex("results/dclm-table.tex", float_format="%.1f")
115 |
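The closing reindex sorts the per-task rows alphabetically while pinning the two aggregate rows to the bottom of the table. A toy illustration of the same pattern (index values here are placeholders):

    import pandas as pd

    df = pd.DataFrame({"score": [1.0, 2.0, 3.0, 4.0]},
                      index=["PIQA$^*$", "Core", "BoolQ$^*$", "Extended"])
    order = sorted(df.index.drop(["Core", "Extended"])) + ["Core", "Extended"]
    print(df.reindex(order))  # BoolQ, PIQA, then Core and Extended last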
--------------------------------------------------------------------------------
/scripts/megatron.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=llm
3 | #SBATCH --nodes=1
4 | #SBATCH --ntasks-per-node=1
5 | #SBATCH --hint=nomultithread # we get physical cores not logical
6 | #SBATCH --partition=a3
7 | #SBATCH --gres=gpu:8 # number of gpus
8 | #SBATCH --time 48:00:00 # maximum execution time (HH:MM:SS)
9 | #SBATCH --output=/data/niklas/jobs/%x-%j.out # output file name
10 | #SBATCH --exclusive
11 |
12 | ######################
13 | ### Set environment ###
14 | ######################
15 | cd /home/niklas/megablocks
16 | source /env/bin/start-ctx-user
17 | conda activate megatronmoe
18 | export WANDB_PROJECT="olmoe"
19 | # Training setup
20 | set -euo pipefail
21 | export GPUS_PER_NODE=8
22 | export NODENAME=$(hostname -s)
23 | export MASTER_ADDR=$(scontrol show hostnames | head -n 1)
24 | export MASTER_PORT=39594
25 | #export WORLD_SIZE=$SLURM_NTASKS
26 | #export RANK=$SLURM_PROCID
27 | #export FS_LOCAL_RANK=$SLURM_PROCID
28 | #export LOCAL_WORLD_SIZE=$SLURM_NTASKS_PER_NODE
29 | #export LOCAL_RANK=$SLURM_LOCALID
30 | #export NODE_RANK=$((($RANK - $LOCAL_RANK) / $LOCAL_WORLD_SIZE))
31 | export R2_PROFILE=r2
32 | export R2_ENDPOINT_URL=YOUR_URL
33 | export AWS_ACCESS_KEY_ID=YOUR_KEY
34 | export AWS_SECRET_ACCESS_KEY=YOUR_KEY
35 | export HF_DATASETS_OFFLINE=1
36 | export CUDA_LAUNCH_BLOCKING=1
37 | export TRITON_CACHE_DIR=/data/niklas/tritoncache
38 |
39 | #echo "World size: $WORLD_SIZE"
40 | ######################
41 |
42 | ######################
43 | #### Set network #####
44 | ######################
45 | head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
46 | ######################
47 |
48 | srun /home/niklas/dense_1b_8gpu.sh /data/niklas/llm/checkpoints/${SLURM_JOB_ID}/ 20000
49 |
--------------------------------------------------------------------------------
/scripts/megatron_dense_46m_8gpu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | EXP_DIR=$1
4 |
5 | # scaling law: 1B tokens @ 125m = 2k steps.
6 | #
7 | # 512 * 1k * 400k = 200b tokens.
8 | # 512 * 1k * 200k = 100b tokens.
9 | # 512 * 1k * 100k = 50b tokens (default).
10 | # 512 * 1k * 20k = 10b tokens.
11 | TRAINING_STEPS=100000
12 | if [ -n "${2}" ]; then
13 | TRAINING_STEPS=$2;
14 | fi
15 |
16 | ##
17 | ### Pre-training for a 46M-parameter GPT-2 model.
18 | ##
19 |
20 | # Distributed hyperparameters.
21 | DISTRIBUTED_ARGUMENTS="\
22 | --nproc_per_node 8 \
23 | --nnodes 1 \
24 | --node_rank 0 \
25 | --master_addr localhost \
26 | --master_port 6000"
27 |
28 | # Model hyperparameters.
29 | MODEL_ARGUMENTS="\
30 | --num-layers 6 \
31 | --hidden-size 512 \
32 | --num-attention-heads 8 \
33 | --seq-length 1024 \
34 | --max-position-embeddings 1024"
35 |
36 | # Training hyperparameters.
37 | TRAINING_ARGUMENTS="\
38 | --micro-batch-size 64 \
39 | --global-batch-size 512 \
40 | --train-iters ${TRAINING_STEPS} \
41 | --lr-decay-iters ${TRAINING_STEPS} \
42 | --lr 0.0006 \
43 | --min-lr 0.00006 \
44 | --lr-decay-style cosine \
45 | --lr-warmup-fraction 0.01 \
46 | --clip-grad 1.0 \
47 | --init-method-std 0.01"
48 |
49 | C4_DATASET="/data/niklas/c4-subsets/55b/gpt2tok_c4_en_55B_text_document"
50 |
51 | # NOTE: We don't train for enough tokens for the
52 | # split to matter.
53 | DATA_ARGUMENTS="\
54 | --data-path ${C4_DATASET} \
55 | --vocab-file /data/niklas/vocab.json \
56 | --merge-file /data/niklas/merges.txt \
57 | --make-vocab-size-divisible-by 1024 \
58 | --split 969,30,1"
59 |
60 | COMPUTE_ARGUMENTS="\
61 | --bf16 \
62 | --DDP-impl local \
63 | --no-async-tensor-model-parallel-allreduce \
64 | --use-flash-attn"
65 |
66 | CHECKPOINT_ARGUMENTS="\
67 | --save-interval 2000 \
68 | --save ./${EXP_DIR}"
69 |
70 | EVALUATION_ARGUMENTS="\
71 | --eval-iters 100 \
72 | --log-interval 100 \
73 | --eval-interval 1000"
74 |
75 | torchrun ${DISTRIBUTED_ARGUMENTS} \
76 | third_party/Megatron-LM/pretrain_gpt.py \
77 | ${MODEL_ARGUMENTS} \
78 | ${TRAINING_ARGUMENTS} \
79 | ${DATA_ARGUMENTS} \
80 | ${COMPUTE_ARGUMENTS} \
81 | ${CHECKPOINT_ARGUMENTS} \
82 | ${EVALUATION_ARGUMENTS} |& tee ./${EXP_DIR}/train.log
83 |
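The step-count comments at the top of this script follow tokens = global_batch x seq_len x steps, with "1k" standing in for the 1024-token sequence length and the totals rounded to headline figures. A quick check of that arithmetic:

    global_batch, seq_len = 512, 1024
    for steps, label in [(400_000, "200b"), (200_000, "100b"),
                         (100_000, "50b"), (20_000, "10b")]:
        tokens = global_batch * seq_len * steps
        print(f"{steps:>7} steps -> {tokens / 1e9:.0f}B tokens (comment says {label})")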
--------------------------------------------------------------------------------
/scripts/megatron_dmoe_46m_8gpu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | EXP_DIR=$1
4 |
5 | # 512 * 1k * 400k = 200b tokens.
6 | # 512 * 1k * 200k = 100b tokens.
7 | # 512 * 1k * 100k = 50b tokens (default).
8 | # 512 * 1k * 20k = 10b tokens.
9 | TRAINING_STEPS=20000
10 | if [ -n "${2}" ]; then
11 | TRAINING_STEPS=$2;
12 | fi
13 |
14 | NUM_EXPERTS=64
15 | if [ -n "${3}" ]; then
16 | NUM_EXPERTS=$3;
17 | fi
18 |
19 | TOP_K=1
20 | if [ -n "${4}" ]; then
21 | TOP_K=$4;
22 | fi
23 |
24 | LOSS_WEIGHT=0.1
25 | if [ -n "${5}" ]; then
26 | LOSS_WEIGHT=$5;
27 | fi
28 |
29 | BATCH_SIZE=64
30 | if [ -n "${6}" ]; then
31 | BATCH_SIZE=$6;
32 | fi
33 |
34 | ##
35 | ### Pre-training for a 46M-parameter dMoE model.
36 | ##
37 |
38 | # MoE hyperparameters.
39 | MOE_ARGUMENTS="\
40 | --moe-num-experts=${NUM_EXPERTS} \
41 | --moe-loss-weight=${LOSS_WEIGHT} \
42 | --moe-top-k=${TOP_K}"
43 |
44 | # Distributed hyperparameters.
45 | DISTRIBUTED_ARGUMENTS="\
46 | --nproc_per_node ${GPUS_PER_NODE} \
47 | --nnodes 1 \
48 | --node_rank 0 \
49 | --master_addr ${MASTER_ADDR} \
50 | --master_port ${MASTER_PORT}"
51 |
52 | # Model hyperparameters.
53 | MODEL_ARGUMENTS="\
54 | --num-layers 6 \
55 | --hidden-size 512 \
56 | --num-attention-heads 8 \
57 | --seq-length 1024 \
58 | --max-position-embeddings 1024"
59 |
60 | # Training hyperparameters.
61 | TRAINING_ARGUMENTS="\
62 | --micro-batch-size ${BATCH_SIZE} \
63 | --global-batch-size 512 \
64 | --train-iters ${TRAINING_STEPS} \
65 | --lr-decay-iters ${TRAINING_STEPS} \
66 | --lr 0.0006 \
67 | --min-lr 0.00006 \
68 | --lr-decay-style cosine \
69 | --lr-warmup-fraction 0.01 \
70 | --clip-grad 1.0 \
71 | --init-method-std 0.01 \
72 | --optimizer adam"
73 |
74 | PILE_DATASET="\
75 | 1.0 \
76 | /mount/pile_gpt2/01_text_document \
77 | 1.0 \
78 | /mount/pile_gpt2/02_text_document \
79 | 1.0 \
80 | /mount/pile_gpt2/03_text_document \
81 | 1.0 \
82 | /mount/pile_gpt2/04_text_document \
83 | 1.0 \
84 | /mount/pile_gpt2/05_text_document \
85 | 1.0 \
86 | /mount/pile_gpt2/06_text_document \
87 | 1.0 \
88 | /mount/pile_gpt2/07_text_document \
89 | 1.0 \
90 | /mount/pile_gpt2/08_text_document \
91 | 1.0 \
92 | /mount/pile_gpt2/09_text_document \
93 | 1.0 \
94 | /mount/pile_gpt2/10_text_document \
95 | 1.0 \
96 | /mount/pile_gpt2/11_text_document \
97 | 1.0 \
98 | /mount/pile_gpt2/12_text_document \
99 | 1.0 \
100 | /mount/pile_gpt2/13_text_document \
101 | 1.0 \
102 | /mount/pile_gpt2/14_text_document \
103 | 1.0 \
104 | /mount/pile_gpt2/15_text_document \
105 | 1.0 \
106 | /mount/pile_gpt2/16_text_document \
107 | 1.0 \
108 | /mount/pile_gpt2/17_text_document \
109 | 1.0 \
110 | /mount/pile_gpt2/18_text_document \
111 | 1.0 \
112 | /mount/pile_gpt2/19_text_document \
113 | 1.0 \
114 | /mount/pile_gpt2/20_text_document \
115 | 1.0 \
116 | /mount/pile_gpt2/21_text_document \
117 | 1.0 \
118 | /mount/pile_gpt2/22_text_document \
119 | 1.0 \
120 | /mount/pile_gpt2/23_text_document \
121 | 1.0 \
122 | /mount/pile_gpt2/24_text_document \
123 | 1.0 \
124 | /mount/pile_gpt2/25_text_document \
125 | 1.0 \
126 | /mount/pile_gpt2/26_text_document \
127 | 1.0 \
128 | /mount/pile_gpt2/27_text_document \
129 | 1.0 \
130 | /mount/pile_gpt2/28_text_document \
131 | 1.0 \
132 | /mount/pile_gpt2/29_text_document"
133 |
134 | C4_DATASET="/data/niklas/c4-subsets/55b/gpt2tok_c4_en_55B_text_document"
135 |
136 | # NOTE: We don't train for enough tokens for the
137 | # split to matter.
138 | DATA_ARGUMENTS="\
139 | --data-path ${C4_DATASET} \
140 | --vocab-file /data/niklas/vocab.json \
141 | --merge-file /data/niklas/merges.txt \
142 | --make-vocab-size-divisible-by 1024 \
143 | --split 969,30,1"
144 |
145 | COMPUTE_ARGUMENTS="\
146 | --bf16 \
147 | --DDP-impl local \
148 | --moe-expert-model-parallelism \
149 | --no-async-tensor-model-parallel-allreduce \
150 | --use-flash-attn"
151 |
152 | CHECKPOINT_ARGUMENTS="\
153 | --save-interval 2000 \
154 | --save ./${EXP_DIR}"
155 |
156 | EVALUATION_ARGUMENTS="\
157 | --eval-iters 100 \
158 | --log-interval 100 \
159 | --eval-interval 1000"
160 |
161 | torchrun ${DISTRIBUTED_ARGUMENTS} \
162 | third_party/Megatron-LM/pretrain_gpt.py \
163 | ${MOE_ARGUMENTS} \
164 | ${MODEL_ARGUMENTS} \
165 | ${TRAINING_ARGUMENTS} \
166 | ${DATA_ARGUMENTS} \
167 | ${COMPUTE_ARGUMENTS} \
168 | ${CHECKPOINT_ARGUMENTS} \
169 | ${EVALUATION_ARGUMENTS} |& tee ./${EXP_DIR}/train.log
170 |
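With the defaults above (--moe-num-experts=64, --moe-top-k=1), each token activates only top_k / num_experts of the expert FFN weights in every MoE layer, which is where the dMoE compute savings come from. The fraction, spelled out:

    num_experts, top_k = 64, 1  # script defaults above
    print(f"active expert fraction per MoE layer: {top_k / num_experts:.2%}")  # 1.56%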
--------------------------------------------------------------------------------
/scripts/olmoe-gantry.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -ex
3 |
4 | CONFIG_PATH=configs/olmoe17/olmoe-8x1b-newhp-newds-final.yml
5 | ARGS='--run_name=olmoe-8x1b-newhp-newds-final --save-overwrite --fsdp.sharding_strategy=FULL_SHARD --device_train_microbatch_size=4 --canceled_check_interval=9999999'
6 |
7 | #NUM_NODES=1
8 | #NUM_NODES=8
9 | #NUM_NODES=16
10 | NUM_NODES=32
11 | BEAKER_REPLICA_RANK=0
12 |
13 | gantry run \
14 | --weka oe-training-default:/weka/oe-training-default \
15 | --allow-dirty \
16 | --preemptible \
17 | --priority urgent \
18 | --workspace ai2/olmoe \
19 | --task-name olmoe \
20 | --description olmoe \
21 | --beaker-image shanea/olmo-torch2.2-gantry \
22 | --budget ai2/oe-training \
23 | --cluster ai2/jupiter-cirrascale-2 \
24 | --gpus 8 \
25 | --replicas "${NUM_NODES}" \
26 | --env-secret WANDB_API_KEY=WANDB_API_KEY \
27 | --env-secret AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID \
28 | --env-secret AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY \
29 | --env-secret R2_ENDPOINT_URL=R2_ENDPOINT_URL \
30 | --leader-selection \
31 | --host-networking \
32 | --env LOG_FILTER_TYPE=local_rank0_only \
33 | --env OMP_NUM_THREADS=8 \
34 | --env OLMO_TASK=model \
35 | --shared-memory 10GiB \
36 | --venv base \
37 | --yes \
38 | --synchronized-start-timeout 60m \
39 | -- /bin/bash -c "pip install --upgrade torch==2.3.0; pip install --upgrade flash-attn --no-build-isolation; pip install git+https://github.com/Muennighoff/megablocks.git@zloss; mkdir -p /root/.cache; pushd /root/.cache; curl "https://storage.googleapis.com/dirkgr-public/huggingface_cache_v3.tar.gz" | tar --keep-newer-files -xzf -; popd; export HF_DATASETS_OFFLINE=1; export NCCL_IB_HCA=^=mlx5_bond_0; SLURM_JOB_ID=${BEAKER_JOB_ID} torchrun --nnodes ${NUM_NODES}:${NUM_NODES} --node_rank ${BEAKER_REPLICA_RANK} --nproc-per-node 8 --rdzv_id=12347 --rdzv_backend=c10d --rdzv_conf='read_timeout=420' --rdzv_endpoint=\$BEAKER_LEADER_REPLICA_HOSTNAME:29400 scripts/train.py ${CONFIG_PATH} ${ARGS}"
40 |
41 | # Single node:
42 | #--rdzv_endpoint=\$BEAKER_NODE_HOSTNAME:29400
43 | # Multinode:
44 | #--rdzv_endpoint=\$BEAKER_LEADER_REPLICA_HOSTNAME:29400
45 | # --mount /net/nfs.cirrascale/allennlp/petew/cache:/root/.cache \
46 | #--node_rank=$BEAKER_REPLICA_RANK
47 | # --nfs \
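For reference, the world size implied by the flags above is replicas x gpus, and torchrun is launched with a fixed-size rendezvous (--nnodes N:N), so all replicas must join before training starts. The arithmetic:

    num_nodes, gpus_per_node = 32, 8  # --replicas and --gpus above
    print("world size:", num_nodes * gpus_per_node)  # 256 ranks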
--------------------------------------------------------------------------------
/scripts/routing_mixtral_v2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_mixtral_v2.jpg
--------------------------------------------------------------------------------
/scripts/routing_olmoe_v2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_olmoe_v2.jpg
--------------------------------------------------------------------------------
/scripts/routing_output.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output.zip
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/eid2token/arxiv.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/eid2token/arxiv.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/eid2token/book.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/eid2token/book.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/eid2token/c4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/eid2token/c4.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/eid2token/github.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/eid2token/github.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/eid2token/wikipedia.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/eid2token/wikipedia.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts/arxiv.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts/arxiv.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts/book.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts/book.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts/c4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts/c4.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts/github.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts/github.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts/wikipedia.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts/wikipedia.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts_crosslayer/arxiv.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer/arxiv.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts_crosslayer/book.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer/book.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts_crosslayer/c4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer/c4.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts_crosslayer/github.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer/github.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts_crosslayer/wikipedia.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer/wikipedia.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts_crosslayer_top1/arxiv.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer_top1/arxiv.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts_crosslayer_top1/book.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer_top1/book.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts_crosslayer_top1/c4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer_top1/c4.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts_crosslayer_top1/github.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer_top1/github.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts_crosslayer_top1/wikipedia.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer_top1/wikipedia.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts_top1/arxiv.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_top1/arxiv.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts_top1/book.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_top1/book.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts_top1/c4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_top1/c4.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts_top1/github.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_top1/github.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/mistral/expert_counts_top1/wikipedia.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_top1/wikipedia.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe-dpo/expert_counts/tulu.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe-dpo/expert_counts/tulu.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe-sft/expert_counts/tulu.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe-sft/expert_counts/tulu.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/eid2token/arxiv.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/eid2token/arxiv.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/eid2token/book.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/eid2token/book.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/eid2token/c4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/eid2token/c4.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/eid2token/github.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/eid2token/github.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/eid2token/wikipedia.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/eid2token/wikipedia.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts/arxiv.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts/arxiv.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts/book.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts/book.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts/c4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts/c4.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts/github.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts/github.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts/tulu.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts/tulu.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts/wikipedia.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts/wikipedia.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts_crosslayer/arxiv.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer/arxiv.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts_crosslayer/book.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer/book.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts_crosslayer/c4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer/c4.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts_crosslayer/github.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer/github.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts_crosslayer/wikipedia.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer/wikipedia.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts_crosslayer_top1/arxiv.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer_top1/arxiv.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts_crosslayer_top1/book.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer_top1/book.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts_crosslayer_top1/c4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer_top1/c4.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts_crosslayer_top1/github.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer_top1/github.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts_crosslayer_top1/wikipedia.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer_top1/wikipedia.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts_top1/arxiv.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_top1/arxiv.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts_top1/book.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_top1/book.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts_top1/c4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_top1/c4.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts_top1/github.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_top1/github.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/olmoe/expert_counts_top1/wikipedia.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_top1/wikipedia.pkl
--------------------------------------------------------------------------------
/scripts/routing_output/routing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/routing.jpg
--------------------------------------------------------------------------------
/scripts/routing_output/routing_prob_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/routing_prob_distribution.png
--------------------------------------------------------------------------------
/scripts/run_dclm_evals_heavy.sh:
--------------------------------------------------------------------------------
1 | # Evaluate model checkpoints using the DCLM eval suite.
2 | # Usage: bash scripts/run_dclm_evals_heavy.sh
3 | # Evaluates on all the `heavy` tasks from DCLM, plus the humaneval code tasks.
4 |
5 | # Run using this conda env: `/net/nfs.cirrascale/allennlp/davidw/miniconda3/envs/dclm`
6 |
7 | DCLM_DIR=/net/nfs.cirrascale/allennlp/davidw/proj/dclm/eval
8 | MODEL_DIR=/net/nfs.cirrascale/allennlp/davidw/checkpoints/moe-release
9 | WKDIR=$(pwd)
10 | METRICS_DIR=$WKDIR/results/dclm
11 |
12 | mkdir -p $METRICS_DIR
13 |
14 |
15 | declare -a models=(
16 | OLMoE-7B-A1B/main
17 | OLMoE-7B-A1B/step1220000-tokens5117B
18 | OLMoE-7B-A1B/step1223842-tokens5100B
19 | jetmoe-8b/main
20 | )
21 |
22 |
23 | cd $DCLM_DIR
24 | export TQDM_DISABLE=1
25 |
26 | for model in "${models[@]}"
27 | do
28 | out_name=${model//\//-}
29 | out_file=$METRICS_DIR/heavy-${out_name}.json
30 |
31 | mason \
32 | --cluster ai2/pluto-cirrascale \
33 | --budget ai2/oe-training \
34 | --gpus 1 \
35 | --workspace ai2/olmoe \
36 | --description "Run DCLM evals for MoE model $model" \
37 | --task_name "eval-${out_name}" \
38 | --priority high \
39 | --preemptible \
40 | -- \
41 | python eval_openlm_ckpt.py \
42 | --hf-model $MODEL_DIR/$model \
43 | --tokenizer $MODEL_DIR/$model \
44 | --eval-yaml heavy.yaml \
45 | --output-file $out_file \
46 | --use-temp-working-dir
47 | done
48 |
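The ${model//\//-} substitution flattens checkpoint paths into file names, which is exactly what make_table.py expects when it reads heavy-<name>.json. A Python mirror of that mapping (model list abbreviated):

    models = ["OLMoE-7B-A1B/main", "OLMoE-7B-A1B/step1220000-tokens5117B"]
    for m in models:
        out_name = m.replace("/", "-")  # mirrors bash ${model//\//-}
        print(f"results/dclm/heavy-{out_name}.json")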
--------------------------------------------------------------------------------
/scripts/run_dclm_evals_heavy_olmo.sh:
--------------------------------------------------------------------------------
1 | # Evaluate OLMo 1B and 7B.
2 | # Usage: bash scripts/run_dclm_evals_heavy_olmo.sh
3 | # Evaluates on all the `heavy` tasks from DCLM.
4 |
5 | # Run using this conda env: `/net/nfs.cirrascale/allennlp/davidw/miniconda3/envs/dclm`
6 |
7 | DCLM_DIR=/net/nfs.cirrascale/allennlp/davidw/proj/dclm/eval
8 | WKDIR=$(pwd)
9 | METRICS_DIR=$WKDIR/results/dclm
10 |
11 | mkdir -p $METRICS_DIR
12 |
13 |
14 | declare -a models=(
15 | allenai/OLMo-7B-0724-hf
16 | allenai/OLMo-1B-0724-hf
17 | )
18 |
19 |
20 | cd $DCLM_DIR
21 | export TQDM_DISABLE=1
22 |
23 | for model in "${models[@]}"
24 | do
25 | out_name=$(echo "$model" | awk -F '/' '{ print $NF }')
26 | out_file=$METRICS_DIR/heavy-${out_name}.json
27 |
28 | mason \
29 | --cluster ai2/s2-cirrascale \
30 | --budget ai2/oe-training \
31 | --gpus 4 \
32 | --workspace ai2/olmoe \
33 | --description "Run DCLM evals for OLMo model $model" \
34 | --task_name "eval-${out_name}" \
35 | --priority high \
36 | -- \
37 | python eval_openlm_ckpt.py \
38 | --hf-model $model \
39 | --tokenizer $model \
40 | --eval-yaml heavy.yaml \
41 | --output-file $out_file \
42 | --use-temp-working-dir
43 | done
44 |
--------------------------------------------------------------------------------
/scripts/run_dclm_evals_humaneval.sh:
--------------------------------------------------------------------------------
1 | # Evaluate model checkpoints using the DCLM eval suite.
2 | # Usage: bash scripts/run_dclm_evals_humaneval.sh
3 | # Evaluates on the humaneval code tasks from DCLM.
4 |
5 | # Run using this conda env: `/net/nfs.cirrascale/allennlp/davidw/miniconda3/envs/dclm`
6 |
7 | DCLM_DIR=/net/nfs.cirrascale/allennlp/davidw/proj/dclm/eval
8 | MODEL_DIR=/net/nfs.cirrascale/allennlp/davidw/checkpoints/moe-release
9 | WKDIR=$(pwd)
10 | METRICS_DIR=$WKDIR/results/dclm
11 |
12 | mkdir -p $METRICS_DIR
13 |
14 |
15 | declare -a models=(
16 | OLMoE-7B-A1B/main
17 | OLMoE-7B-A1B/step1220000-tokens5117B
18 | OLMoE-7B-A1B/step1223842-tokens5100B
19 | jetmoe-8b/main
20 | )
21 |
22 |
23 | cd $DCLM_DIR
24 | export TQDM_DISABLE=1
25 |
26 | for model in "${models[@]}"
27 | do
28 | out_name=${model//\//-}
29 | out_file=$METRICS_DIR/humaneval-${out_name}.json
30 |
31 | mason \
32 | --cluster ai2/pluto-cirrascale \
33 | --budget ai2/oe-training \
34 | --gpus 1 \
35 | --workspace ai2/olmoe \
36 | --description "Run HumanEval DCLM evals for MoE model $model" \
37 | --task_name "eval-${out_name}" \
38 | --priority high \
39 | --preemptible \
40 | -- \
41 | python eval_openlm_ckpt.py \
42 | --hf-model $MODEL_DIR/$model \
43 | --tokenizer $MODEL_DIR/$model \
44 | --eval-yaml $WKDIR/scripts/humaneval.yaml \
45 | --output-file $out_file \
46 | --use-temp-working-dir
47 | done
48 |
--------------------------------------------------------------------------------
/scripts/run_moe_analysis.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import time
5 | import pickle as pkl
6 | import json
7 | import numpy as np
8 |
9 | from collections import defaultdict, Counter
10 |
11 | from datasets import load_dataset
12 | from transformers import OlmoeForCausalLM, AutoTokenizer, AutoModelForCausalLM
13 | from huggingface_hub import list_repo_refs
14 |
15 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16 | if os.path.exists("hf_access_token.txt"):
17 | with open("hf_access_token.txt") as f:
18 | token = f.read().strip()
19 | else:
20 | token = None
21 |
22 | start_time = time.time()
23 |
24 | def tokenize_c4(tokenized_path, sample_ratio=0.005, model="olmoe"):
25 | np.random.seed(2024)
26 | with open(tokenized_path, "w") as f:
27 | if model == "olmoe":
28 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924", token=token)
29 | elif model == "mixtral":
30 | tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1", token=token)
31 | else: raise NotImplementedError(f"model={model}")
32 | cnt = 0
33 | for row in load_dataset("allenai/c4", "en", split="validation", streaming=True):
34 | if np.random.random() < sample_ratio:
35 | input_ids = tokenizer(row["text"])["input_ids"]
36 | f.write(json.dumps({"input_ids": input_ids})+"\n")
37 | cnt += 1
38 | print(f"Loaded {cnt} lines!")
39 |
40 | def load_c4(tokenized_path, bs):
41 | tokens = []
42 | with open(tokenized_path, "r") as f:
43 | for line in f:
44 | tokens += json.loads(line)["input_ids"]
45 | while len(tokens) >= bs:
46 | yield tokens[:bs]
47 | tokens = tokens[bs:]
48 |
49 | def load_model(revision="main", model="olmoe"):
50 | if model == "olmoe":
51 | model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924", token=token, revision=revision).to(DEVICE)
52 | elif model == "mixtral":
53 | model = AutoModelForCausalLM.from_pretrained(
54 | "mistralai/Mixtral-8x7B-v0.1", token=token, revision=revision, device_map="auto"
55 | )
56 | else: raise NotImplementedError(f"model={model}")
57 | return model
58 |
59 | def do_inference(args, run_all_checkpoints=False):
60 | run(args, "main")
61 |
62 | if run_all_checkpoints:
63 | out = list_repo_refs("allenai/OLMoE-1B-7B-0924", token=token)
64 | cand_branches = [b.name for b in out.branches]
65 | all_branches = []
66 | # Old checkpoints previously used: ["15000", "130000", "250000", "490000"]:
67 | # Percentages of pretraining: 0.00408549469 ; 0.0980518727 ; 0.20018924011 ; 0.40037848022
68 | for step in ["5000", "120000", "245000", "490000"]:
69 | branches = [name for name in cand_branches if name.startswith(f"step{step}-tokens") and name.endswith("B")]
70 | assert len(branches) == 1, f"{step} ; {branches}"
71 | all_branches.append(branches[0])
72 |
73 | for b in all_branches:
74 | run(args, b, postfix="_"+b.split("-")[0][4:], save_exp_id_only=True)
75 |
76 | def run(args, revision, postfix="", length=2048, save_exp_id_only=False):
77 | save_path = os.path.join(args.out_dir, f"c4_results{postfix}.jsonl")
78 | assert not os.path.exists(save_path), f"{save_path} already exists"
79 |
80 | model = load_model(revision, model=args.model)
81 |
82 | results = []
83 | start_time = time.time()
84 |
85 | print("Start inference")
86 | for input_ids in load_c4(args.tokenized_path, length):
87 | input_ids = torch.LongTensor(input_ids).reshape(1, -1).to(DEVICE)
88 | out = model(input_ids=input_ids, output_router_logits=True)
89 |
90 | input_ids = input_ids[0].detach().cpu().numpy().tolist()
91 | logits = out["logits"][0].detach().cpu().numpy()
92 | predicted_token_ids = np.argmax(logits, -1).tolist()
93 | router_logits = [l.detach().cpu().numpy() for l in out["router_logits"]]
94 | if args.model == "olmoe":
95 | assert len(router_logits) == 16
96 | exp_ids = np.stack([np.argsort(-logits, -1)[:, :8].tolist() for logits in router_logits], -1).tolist()
97 | assert np.array(exp_ids).shape == (2048, 8, 16)
98 | elif args.model == "mixtral":
99 | assert len(router_logits) == 32
100 | exp_ids = np.stack([np.argsort(-logits, -1)[:, :2].tolist() for logits in router_logits], -1).tolist()
101 | assert np.array(exp_ids).shape == (2048, 2, 32)
102 | else: raise NotImplementedError(f"model={args.model}")
103 | results.append({"input_ids": input_ids, "predicted_token_ids": predicted_token_ids, "exp_ids": exp_ids})
104 |
105 | if len(results) % 50 == 0:
106 | print("Finish %d batches (%dmin)" % (len(results), (time.time()-start_time)/60))
107 |
108 | with open(save_path, "w") as f:
109 | for r in results:
110 | f.write(json.dumps(r)+"\n")
111 |
112 | print("Saved %d batches to %s" % (len(results), save_path))
113 |
114 | def do_ckpt_analysis(args):
115 | FONTSIZE = 36
116 | with open(os.path.join(args.out_dir, "c4_results.jsonl"), "r") as f:
117 | results = []
118 | for line in f:
119 | results.append(json.loads(line))
120 |
121 | assert args.topk in [1, 8, 18]
122 |
123 | import matplotlib.pylab as plt
124 | if args.topk == 18:
125 | fig, axes = plt.subplots(figsize=(24, 8), ncols=2, nrows=1, sharey=True, layout='constrained')
126 | titles = ["Top-k=1", "Top-k=8"]
127 | else:
128 | fig, ax = plt.subplots(figsize=(16, 8))
129 | axes = [ax]
130 | titles = ["Top-k=1"] if args.topk == 1 else ["Top-k=8"]
131 |
132 | def compare_with_early_results(postfix, topk):
133 | with open(os.path.join(args.out_dir, "c4_results_{}.jsonl".format(postfix)), "rb") as f:
134 | early_results = []
135 | for line in f:
136 | early_results.append(json.loads(line))
137 |
138 | equals = Counter()
139 | total = 0
140 | ratio_per_layer = defaultdict(list)
141 |
142 | for r, er in zip(results, early_results):
143 | assert r["input_ids"] == er["input_ids"]
144 |
145 | if topk == 1:
146 | exp_ids = np.array(r["exp_ids"])[:, 0, :]
147 | e_exp_ids = np.array(er["exp_ids"])[:, 0, :]
148 | assert exp_ids.shape==e_exp_ids.shape, (exp_ids.shape, e_exp_ids.shape)
149 | curr_equals = exp_ids==e_exp_ids # [token_cnt, n_layers]
150 | for i, _curr_equals in enumerate(curr_equals.transpose(1, 0).tolist()):
151 | equals[i] += np.sum(_curr_equals)
152 | total += len(curr_equals)
153 | elif topk == 8:
154 | exp_ids = np.array(r["exp_ids"])
155 | e_exp_ids = np.array(er["exp_ids"])
156 | for layer_idx in range(16):
157 | indices = exp_ids[:, :, layer_idx]
158 | e_indices = e_exp_ids[:, :, layer_idx]
159 | assert indices.shape==e_indices.shape==(2048, 8)
160 | for _indices, _e_indices in zip(indices.tolist(), e_indices.tolist()):
161 | ratio = len(set(_indices) & set(_e_indices)) / 8
162 | ratio_per_layer[layer_idx].append(ratio)
163 | else:
164 | raise NotImplementedError(f"topk={topk}")
165 |
166 | row = [postfix]
167 | result = []
168 | for i in range(16):
169 | if topk == 1:
170 | row.append("%.1f" % (100 * equals[i] / total))
171 | result.append(equals[i] / total)
172 | else:
173 | row.append("%.1f" % (100 * np.mean(ratio_per_layer[i])))
174 | result.append(np.mean(ratio_per_layer[i]))
175 |
176 | pt.add_row(row)
177 | return result
178 |
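# Worked example of the saturation metric above (hypothetical expert sets):
# for top-k=8, saturation per token and layer is the overlap ratio between an
# early checkpoint's active experts and the final checkpoint's, e.g.
#   early = {3, 7, 12, 20, 31, 40, 55, 60}; final = {3, 7, 12, 21, 31, 41, 55, 62}
#   len(early & final) / 8  ->  5 / 8 = 0.625
# for top-k=1 it reduces to an exact-match rate on the top expert per layer.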
179 | x = ["1", "10", "20", "40"]
180 | palette = ["#9e0142", "#c12a49", "#d53e4f", "#e65949", "#f46d43", "#f78a52", "#fdae61", "#ffd700", "#ffffbf", "#d8ef94", "#66c2a5", "#429db4", "#3288bd", "#5a71c1", "#7c4ab3", "#5e4fa2"]
181 | palette = [
182 | "#F0539B", "#43C5E0", "#2E3168", "#FDBE15",
183 | "#F0539B", "#43C5E0", "#2E3168", "#FDBE15",
184 | "#F0539B", "#43C5E0", "#2E3168", "#FDBE15",
185 | "#F0539B", "#43C5E0", "#2E3168", "#FDBE15",
186 | ]
187 | alpha = [
188 | 0.8, 0.8, 0.8, 0.8,
189 | 0.6, 0.6, 0.6, 0.6,
190 | 0.4, 0.4, 0.4, 0.4,
191 | 0.2, 0.2, 0.2, 0.2,
192 | ]
193 | linestyle = [
194 | "-", "-", "-", "-",
195 | "--", "--", "--", "--",
196 | ":", ":", ":", ":",
197 | "-.", "-.", "-.", "-.",
198 | ]
199 | for i, ax in enumerate(axes):
200 | from prettytable import PrettyTable
201 | pt = PrettyTable()
202 | pt.field_names = [""] + [str(i) for i in range(16)]
203 | topk = int(titles[i][-1])
204 | #r1 = compare_with_early_results(15000, topk) # 1%
205 | #r2 = compare_with_early_results(130000, topk) # 10%
206 | #r3 = compare_with_early_results(250000, topk) # 20%
207 | #r4 = compare_with_early_results(490000, topk) # 40%
208 | r1 = compare_with_early_results(5000, topk) # 1%
209 | r2 = compare_with_early_results(120000, topk) # 10%
210 | r3 = compare_with_early_results(245000, topk) # 20%
211 | r4 = compare_with_early_results(490000, topk) # 40%
212 | merged_results = np.array([r1, r2, r3, r4]) # [4, 16]
213 | # Define the original 4 colors
214 | colors = ["#F0539B", "#43C5E0", "#2E3168", "#FDBE15"]
215 |
216 | # Create a custom colormap using the defined colors
217 | from matplotlib.colors import LinearSegmentedColormap
218 | cmap = LinearSegmentedColormap.from_list("custom_theme", colors, N=16)
219 | # Generate 16 colors from the colormap
220 | additional_colors = [cmap(i / 15) for i in range(16)]
221 |
222 | #import seaborn as sns
223 | # Define the original 4 colors
224 | #colors = ["#F0539B", "#43C5E0", "#2E3168", "#FDBE15"]
225 | # Create a seaborn color palette with 16 colors based on the 4 theme colors
226 | #palette = sns.color_palette(colors, n_colors=16)
227 |
228 | for j in range(16):
229 | #ax.plot(x, merged_results[:, j] * 100, label=j, color=palette[j], marker='o', markersize=12, linewidth=6)
230 | #ax.plot(x, merged_results[:, j] * 100, label=j, color=palette[j], marker='o', markersize=12, linewidth=6, linestyle=linestyle[j])
231 | ax.plot(x, merged_results[:, j] * 100, label=j, color=additional_colors[j], marker='o', markersize=12, linewidth=6)
232 | ax.tick_params(axis='both', which='major', labelsize=FONTSIZE)
233 | ax.set_title(titles[i], fontsize=FONTSIZE, fontweight='bold')
234 | ax.spines['top'].set_visible(False)
235 | ax.spines['right'].set_visible(False)
236 | # ax.set_ylim(0, 0.9)
237 | if args.topk in [8, 18]:
238 | plt.legend(frameon=True, title="Layer ID", title_fontsize=FONTSIZE, fontsize=FONTSIZE, columnspacing=0.4, labelspacing=0.4, ncol=4)
239 | fig.supxlabel('Pretraining stage (%)', fontsize=FONTSIZE, fontweight='bold')
240 | # fig.supylabel('% of active experts matching\nfinal checkpoint for the same input data', fontsize=FONTSIZE, fontweight='bold')
241 | # fig.supylabel('Active experts matching\n final checkpoint (%)', fontsize=FONTSIZE, fontweight='bold')
242 | fig.supylabel('Router saturation (%)', fontsize=FONTSIZE, fontweight='bold')
243 | plt.savefig(os.path.join(args.fig_dir, f"top{args.topk}_changes_over_checkpoints.png"))
244 | plt.savefig(os.path.join(args.fig_dir, f"top{args.topk}_changes_over_checkpoints.pdf"))
245 |
246 | def do_coactivation_analysis(args):
247 | FONTSIZE = 18
248 | with open(os.path.join(args.out_dir, "c4_results.jsonl"), "rb") as f:
249 | results = []
250 | for line in f:
251 | try:
252 | results.append(json.loads(line))
253 | except Exception as e:
254 | print("Failed to load line", e)
255 | break
256 |
257 | pairwise_counter = Counter()
258 | single_counter = Counter()
259 | from more_itertools import pairwise
260 |
261 | layer_num = args.layer_num
262 | for result in results:
263 | for indices in np.array(result["exp_ids"])[:, :, layer_num].tolist():
264 | assert len(indices) == 8
265 | for i in indices:
266 | single_counter[i] += 1
267 | for (a, b) in pairwise(indices):
268 | pairwise_counter[(a, b)] += 1
269 | pairwise_counter[(b, a)] += 1
270 | pairwise_probs = {
271 | (a, b): pairwise_counter[(a, b)] / single_counter[a]
272 | for (a, b) in pairwise_counter
273 | }
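# Illustrative arithmetic (hypothetical counts): pairwise(indices) walks the
# rank-ordered top-8 list, so (a, b) is counted whenever a and b sit at adjacent
# ranks. If expert 4 is selected 200 times and is rank-adjacent to expert 9 on
# 60 of those tokens, then pairwise_probs[(4, 9)] = 60 / 200 = 0.30, rendered
# as 30 in the heatmap below.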
274 |
275 | new_idx_to_orig_idx = []
276 | N = 16
277 |
278 | for (a, b), p in sorted(pairwise_probs.items(), key=lambda x: -x[1]):
279 | if a not in new_idx_to_orig_idx:
280 | new_idx_to_orig_idx.append(a)
281 | if b not in new_idx_to_orig_idx:
282 | new_idx_to_orig_idx.append(b)
283 | if len(new_idx_to_orig_idx) == N: break
284 |
285 | scores = np.zeros((N, N))
286 | labels_x, labels_y = [], []
287 | for i in range(N):
288 | labels_x.append(new_idx_to_orig_idx[i])
289 | for j in range(N):
290 | scores[i, j] = pairwise_probs.get(
291 | (new_idx_to_orig_idx[i], new_idx_to_orig_idx[j]), 0) * 100
292 | if i == 0:
293 | labels_y.append(new_idx_to_orig_idx[j])
294 |
295 | import matplotlib.pylab as plt
296 | import seaborn as sns
297 | # ax = sns.heatmap(scores, cmap="Reds", linewidth=.5, center=30, vmax=60, xticklabels=labels_x, yticklabels=labels_y)
298 | from matplotlib.colors import LinearSegmentedColormap
299 | # Define a custom colormap using shades of #F0539B
300 | cmap = LinearSegmentedColormap.from_list("custom_cmap", ["white", "#F0539B", "#4A0033"], N=256)
301 |
302 | # Generate the heatmap with the custom colormap
303 | ax = sns.heatmap(
304 | scores,
305 | cmap=cmap,
306 | linewidth=0.5,
307 | center=30,
308 | vmax=60,
309 | xticklabels=labels_x,
310 | yticklabels=labels_y,
311 | cbar_kws={'ticks': [0, 15, 30, 45, 60]} # Optional: Customize color bar ticks
312 | )
313 | # Increase tick size
314 | ax.tick_params(axis='both', which='major', labelsize=FONTSIZE)
315 | # Increase colorbar ticks
316 | cbar = ax.collections[0].colorbar
317 | cbar.ax.tick_params(labelsize=FONTSIZE)
318 | # This sets the yticks "upright" with 0, as opposed to sideways with 90.
319 | plt.yticks(rotation=0)
320 | plt.xticks(rotation=-90)
321 | # Set title
322 | plt.title(f"Layer {layer_num}", fontsize=FONTSIZE, fontweight='bold')
323 | plt.savefig(os.path.join(args.fig_dir, f"layer_{layer_num}_heatmap.png"))
324 | plt.savefig(os.path.join(args.fig_dir, f"layer_{layer_num}_heatmap.pdf"))
325 |
326 | def do_token_analysis(args, tex_format=True):
327 | with open(os.path.join(args.out_dir, "c4_results.jsonl"), "rb") as f:
328 | results = []
329 | for line in f:
330 | try:
331 | results.append(json.loads(line))
332 | except Exception as e:
333 | print("Failed to load line", e)
334 | break
335 | assert args.model == "olmoe"
336 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924", token=token)
337 |
338 | input_id_to_exp = defaultdict(list)
339 | gt_next_token_to_exp = defaultdict(list)
340 | predicted_next_token_to_exp = defaultdict(list)
341 |
342 | for result in results:
343 | input_ids = result["input_ids"]
344 | gt_next_token_ids = input_ids[1:]
345 | predicted_next_token_ids = result["predicted_token_ids"]
346 | exp_ids = np.array(result["exp_ids"])[:, 0, args.layer_num].tolist()
347 |
348 | for _id, exp_id in zip(input_ids, exp_ids):
349 | input_id_to_exp[_id].append(exp_id)
350 | for _id, exp_id in zip(gt_next_token_ids, exp_ids):
351 | gt_next_token_to_exp[_id].append(exp_id)
352 | for _id, exp_id in zip(predicted_next_token_ids, exp_ids):
353 | predicted_next_token_to_exp[_id].append(exp_id)
354 |
355 | def print_avg(id_to_exp):
356 | probs = []
357 | exp_id_to_vocabs = defaultdict(list)
358 |
359 | for _id, exp_ids in id_to_exp.items():
360 | most_freq_id, val = sorted(Counter(exp_ids).items(), key=lambda x: -x[1])[0]
361 | probs.append(val / len(exp_ids))
362 | if len(exp_ids) >= 10:
363 | exp_id_to_vocabs[most_freq_id].append((_id, val/len(exp_ids)))
364 |
365 | # if you want to allow token IDs to appear for multiple experts:
366 | """
367 | c = Counter(exp_ids)
368 | # import pdb; pdb.set_trace()
369 | for exp_id in c:
370 | if len(exp_ids) >= 10:
371 | exp_id_to_vocabs[exp_id].append((_id, c[exp_id] / len(exp_ids)))
372 | """
373 |
374 | print("Average probability:", np.mean(probs))
375 | return exp_id_to_vocabs
376 |
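# Hedged example of the specialization score (hypothetical counts): if a token
# id occurs 12 times and its most frequent expert handles 9 of those, its score
# is 9 / 12 = 0.75; "Average probability" is the mean of this over all token
# ids, and only ids seen >= 10 times enter exp_id_to_vocabs.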
377 | exp_id_to_vocabs = print_avg(input_id_to_exp)
378 | print_avg(gt_next_token_to_exp)
379 | exp_id_to_predicted_vocabs = print_avg(predicted_next_token_to_exp)
380 |
381 | with open(f"exp_id_to_vocabs_layer{args.layer_num}.txt", "w") as f:
382 | #for exp_id, vocabs in sorted(exp_id_to_vocabs.items(), key=lambda x: -np.mean([p for _, p in x[1]])):
383 | for exp_id, vocabs in sorted(exp_id_to_vocabs.items(), key=lambda x: -np.mean([p for _, p in x[1]]) + -np.mean([p for _, p in exp_id_to_predicted_vocabs[x[0]]])):
384 | #for exp_id, vocabs in sorted(exp_id_to_vocabs.items()):
385 | text = "exp_id: %d" % exp_id
386 | for vocab, p in sorted(vocabs, key=lambda x: -x[1])[:15]:
387 | if tex_format:
388 | text += " \\colorbox{lightOlmoeYellow}{%s} (%d\\%%)" % (tokenizer._decode(vocab), 100*p) # + str(vocab) # for unknowns
389 | else:
390 | text += " %s (%d%%)" % (tokenizer._decode(vocab), 100*p)
391 | f.write(text + "\n\n")
392 |
393 | with open(f"exp_id_to_predicted_vocabs_layer{args.layer_num}.txt", "w") as f:
394 | #for exp_id, vocabs in sorted(exp_id_to_predicted_vocabs.items(), key=lambda x: -np.mean([p for _, p in x[1]])):
395 | for exp_id, vocabs in sorted(exp_id_to_predicted_vocabs.items(), key=lambda x: -np.mean([p for _, p in x[1]]) + -np.mean([p for _, p in exp_id_to_vocabs[x[0]]])):
396 | #for exp_id, vocabs in sorted(exp_id_to_predicted_vocabs.items()):
397 | text = "exp_id: %d" % exp_id
398 | for vocab, p in sorted(vocabs, key=lambda x: -x[1])[:15]:
399 | if tex_format:
400 | text += " \\colorbox{lightOlmoeYellow}{%s} (%d\\%%)" % (tokenizer._decode(vocab), 100*p) # + str(vocab) # for unknowns
401 | else:
402 | text += " %s (%d%%)" % (tokenizer._decode(vocab), 100*p)
403 | f.write(text + "\n\n")
404 |
405 | def do_token_analysis_layers(args, tex_format=True):
406 | FONTSIZE = 28
407 | with open(os.path.join(args.out_dir, "c4_results.jsonl"), "rb") as f:
408 | results = []
409 | for line in f:
410 | results.append(json.loads(line))
411 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924", token=token)
412 |
413 | def print_avg(id_to_exp):
414 | probs = []
415 | exp_id_to_vocabs = defaultdict(list)
416 |
417 | for _id, exp_ids in id_to_exp.items():
418 | most_freq_id, val = sorted(Counter(exp_ids).items(), key=lambda x: -x[1])[0]
419 | probs.append(val / len(exp_ids))
420 | if len(exp_ids) >= 10:
421 | # if len(exp_ids) >= 8:
422 | exp_id_to_vocabs[most_freq_id].append((_id, val/len(exp_ids)))
423 |
424 | print("Average probability:", np.mean(probs))
425 | return np.mean(probs)
426 |
427 | layer_num_to_probs = {}
428 | for layer_num in list(range(16)):
429 | print(f"Layer {layer_num}")
430 | input_id_to_exp = defaultdict(list)
431 | gt_next_token_to_exp = defaultdict(list)
432 | predicted_next_token_to_exp = defaultdict(list)
433 | for result in results:
434 | input_ids = result["input_ids"]
435 | gt_next_token_ids = input_ids[1:]
436 | predicted_next_token_ids = result["predicted_token_ids"]
437 | exp_ids = np.array(result["exp_ids"])[:, 0, layer_num].tolist()
438 |
439 | for _id, exp_id in zip(input_ids, exp_ids):
440 | input_id_to_exp[_id].append(exp_id)
441 | for _id, exp_id in zip(gt_next_token_ids, exp_ids):
442 | gt_next_token_to_exp[_id].append(exp_id)
443 | for _id, exp_id in zip(predicted_next_token_ids, exp_ids):
444 | predicted_next_token_to_exp[_id].append(exp_id)
445 |
446 | input_prob = print_avg(input_id_to_exp)
447 | gt_prob = print_avg(gt_next_token_to_exp)
448 | output_prob = print_avg(predicted_next_token_to_exp)
449 |
450 | layer_num_to_probs[layer_num] = (input_prob, gt_prob, output_prob)
451 |
452 | import matplotlib.pylab as plt
453 | fig, ax = plt.subplots(figsize=(12, 8))
454 | x = list(range(16))
455 | y = [layer_num_to_probs[i][0] * 100 for i in x]
456 | ax.plot(x, y, label="Input tokens", color="#F0539B", marker='o', markersize=12, linewidth=6)
457 | y = [layer_num_to_probs[i][2] * 100 for i in x]
458 | ax.plot(x, y, label="Predicted output tokens", color="#2E3168", marker='o', markersize=12, linewidth=6)
459 | y = [layer_num_to_probs[i][1] * 100 for i in x]
460 | ax.plot(x, y, label="Ground-truth output tokens", color="#43C5E0", marker='o', markersize=12, linewidth=6)
461 | ax.tick_params(axis='both', which='major', labelsize=FONTSIZE)
462 | ax.set_xticks(x)
463 | ax.set_xlabel("Layer ID", fontsize=FONTSIZE, fontweight='bold')
464 | ax.set_ylabel("Vocabulary specialization (%)", fontsize=FONTSIZE, fontweight='bold')
465 | ax.spines['top'].set_visible(False)
466 | ax.spines['right'].set_visible(False)
467 | plt.legend(frameon=True, fontsize=FONTSIZE, columnspacing=0.4, labelspacing=0.4, loc="lower right")
468 | plt.savefig(os.path.join(args.fig_dir, f"layerwise_token_analysis.pdf"))
469 | plt.savefig(os.path.join(args.fig_dir, f"layerwise_token_analysis.png"))
470 |
471 | def do_token_analysis_experts(args, tex_format=True, do_sort=False):
472 | FONTSIZE = 28
473 | with open(os.path.join(args.out_dir, "c4_results.jsonl"), "rb") as f:
474 | results = []
475 | for line in f:
476 | results.append(json.loads(line))
477 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924", token=token)
478 |
479 | def print_avg(id_to_exp):
480 | probs = []
481 | exp_id_to_probs = defaultdict(list)
482 | for _id, exp_ids in id_to_exp.items():
483 | most_freq_id, val = sorted(Counter(exp_ids).items(), key=lambda x: -x[1])[0]
484 | if len(exp_ids) >= 10:
485 | probs.append(val / len(exp_ids))
486 | exp_id_to_probs[most_freq_id].append(val / len(exp_ids))
487 | # c = Counter(exp_ids)
488 | # import pdb; pdb.set_trace()
489 | # for exp_id in c:
490 | # exp_id_to_probs[exp_id].append(c[exp_id] / len(exp_ids))
491 |
492 | # avg
493 | exp_id_to_probs = {k: (np.mean(v), len(v)) for k, v in exp_id_to_probs.items()}
494 | print(exp_id_to_probs)
495 | print("Average probability:", np.mean(probs))
496 | return exp_id_to_probs, np.mean(probs)
497 |
498 | input_id_to_exp = defaultdict(list)
499 | gt_next_token_to_exp = defaultdict(list)
500 | predicted_next_token_to_exp = defaultdict(list)
501 |
502 | for result in results:
503 | input_ids = result["input_ids"]
504 | gt_next_token_ids = input_ids[1:]
505 | predicted_next_token_ids = result["predicted_token_ids"]
506 | exp_ids = np.array(result["exp_ids"])[:, 0, args.layer_num].tolist()
507 |
508 | for _id, exp_id in zip(input_ids, exp_ids):
509 | input_id_to_exp[_id].append(exp_id)
510 | for _id, exp_id in zip(gt_next_token_ids, exp_ids):
511 | gt_next_token_to_exp[_id].append(exp_id)
512 | for _id, exp_id in zip(predicted_next_token_ids, exp_ids):
513 | predicted_next_token_to_exp[_id].append(exp_id)
514 |
515 | input_exp_id_to_probs, input_prob = print_avg(input_id_to_exp)
516 | gt_exp_id_to_probs, gt_prob = print_avg(gt_next_token_to_exp)
517 | predicted_exp_id_to_probs, pred_prob = print_avg(predicted_next_token_to_exp)
518 |
519 | import matplotlib.pylab as plt
520 | fig, ax = plt.subplots(figsize=(12, 8))
521 | x = list(range(64))
522 | # Sort x by input_exp_id_to_probs
523 | x_sorted = sorted(x, key=lambda i: -input_exp_id_to_probs[i][0]) if do_sort else x
524 | y = [input_exp_id_to_probs[i][0] * 100 for i in x_sorted]
525 | ax.plot(x, y, label="Input tokens", color="#F0539B", marker='o', markersize=12, linewidth=6)
526 | y = [predicted_exp_id_to_probs[i][0] * 100 for i in x_sorted]
527 | ax.plot(x, y, label="Predicted output tokens", color="#43C5E0", marker='o', markersize=12, linewidth=6)
528 | y = [gt_exp_id_to_probs[i][0] * 100 for i in x_sorted]
529 | ax.plot(x, y, label="Ground-truth output tokens", color="#2E3168", marker='o', markersize=12, linewidth=6)
530 | # Draw horizontal prob lines for input_prob etc
531 | ax.axhline(input_prob * 100, color="#F0539B", linestyle='--', linewidth=3)
532 | ax.axhline(gt_prob * 100, color="#2E3168", linestyle='--', linewidth=3)
533 | ax.axhline(pred_prob * 100, color="#43C5E0", linestyle='--', linewidth=3)
534 | ax.tick_params(axis='both', which='major', labelsize=FONTSIZE)
535 | ax.set_xticks(x, x_sorted)
536 | ax.set_xlabel("Expert ID", fontsize=FONTSIZE, fontweight='bold')
537 | ax.set_ylabel("Vocabulary specialization (%)", fontsize=FONTSIZE, fontweight='bold')
538 | ax.spines['top'].set_visible(False)
539 | ax.spines['right'].set_visible(False)
540 | plt.title(f"Layer {args.layer_num}", fontsize=FONTSIZE, fontweight='bold')
541 | plt.legend(frameon=True, fontsize=FONTSIZE, columnspacing=0.4, labelspacing=0.4, loc="lower right")
542 | plt.savefig(os.path.join(args.fig_dir, f"vocabulary_specialization_experts.pdf"))
543 | plt.savefig(os.path.join(args.fig_dir, f"vocabulary_specialization_experts.png"))
544 |
545 |
546 | def do_token_analysis_layers_experts(args, do_sort=False, normalize=False):
547 | """normalize does not make a big difference, hence not used"""
548 | from matplotlib import rcParams
549 | rcParams.update({'figure.autolayout': True})
550 |
551 | with open(os.path.join(args.out_dir, "c4_results.jsonl"), "r") as f:
552 | results = []
553 | for line in f:
554 | results.append(json.loads(line))
555 |
556 | assert args.topk in [1, 2, 8]
562 |
563 | import matplotlib.pylab as plt
564 | if args.model == "olmoe":
565 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924", token=token)
566 | random_prob = args.topk/64
567 | num_layers = 16
568 | # fig, axes = plt.subplots(figsize=(32, 8), ncols=2, nrows=1, sharey=True, layout='constrained', width_ratios=[1, 2])
569 | fig, axes = plt.subplots(figsize=(32, 8), ncols=2, nrows=1, sharey=False, layout='constrained', width_ratios=[1, 2])
570 | else:
571 | tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1", token=token)
572 | random_prob = args.topk/8
573 | num_layers = 32
574 | # fig, axes = plt.subplots(figsize=(32, 8), ncols=2, nrows=1, sharey=True, layout='constrained', width_ratios=[2, 1])
575 | fig, axes = plt.subplots(figsize=(32, 8), ncols=2, nrows=1, sharey=False, layout='constrained', width_ratios=[2, 1])
576 |
577 | def print_avg(id_to_exp):
578 | probs = []
579 | exp_id_to_probs = defaultdict(list)
580 | for _id, exp_ids in id_to_exp.items():
581 | most_freq_id, val = sorted(Counter(exp_ids).items(), key=lambda x: -x[1])[0]
582 | max_possible = len(exp_ids) // args.topk
583 | if normalize:
584 | # use random prob to normalize such that 0 is random and 1 is perfect
585 | prob = (val / max_possible - random_prob) / (1 - random_prob)
586 | else:
587 | prob = val / max_possible
588 | probs.append(prob)
589 | exp_id_to_probs[most_freq_id].append(prob)
590 | exp_id_to_probs = {k: (np.mean(v), len(v)) for k, v in exp_id_to_probs.items()}
591 | print("Average probability:", np.mean(probs))
592 | return exp_id_to_probs, np.mean(probs)
593 |
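# Worked normalization example (hypothetical numbers): for OLMoE with top-k=8
# of 64 experts, random_prob = 8/64 = 0.125, so a raw specialization of 0.30
# would normalize to (0.30 - 0.125) / (1 - 0.125) = 0.20; as the docstring
# notes, this barely changes the picture, so normalize stays off.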
594 | layer_num_to_probs = {}
595 | for layer_num in list(range(num_layers)):
596 | print(f"Layer {layer_num}")
597 | input_id_to_exp = defaultdict(list)
598 | gt_next_token_to_exp = defaultdict(list)
599 | predicted_next_token_to_exp = defaultdict(list)
600 | for result in results:
601 | input_ids = result["input_ids"]
602 | gt_next_token_ids = input_ids[1:]
603 | predicted_next_token_ids = result["predicted_token_ids"]
604 | exp_ids = np.array(result["exp_ids"])[:, :args.topk, layer_num].tolist()
605 |
606 | for _id, exp_id in zip(input_ids, exp_ids):
607 | # input_id_to_exp[_id].append(exp_id)
608 | input_id_to_exp[_id].extend(exp_id)
609 | for _id, exp_id in zip(gt_next_token_ids, exp_ids):
610 | # gt_next_token_to_exp[_id].append(exp_id)
611 | gt_next_token_to_exp[_id].extend(exp_id)
612 | for _id, exp_id in zip(predicted_next_token_ids, exp_ids):
613 | # predicted_next_token_to_exp[_id].append(exp_id)
614 | predicted_next_token_to_exp[_id].extend(exp_id)
615 |
616 | input_prob = print_avg(input_id_to_exp)[1]
617 | gt_prob = print_avg(gt_next_token_to_exp)[1]
618 | output_prob = print_avg(predicted_next_token_to_exp)[1]
619 | layer_num_to_probs[layer_num] = (input_prob, gt_prob, output_prob)
620 |
621 | # Specialization for one specific layer
622 | if layer_num == args.layer_num:
623 | input_exp_id_to_probs, input_prob_chosen = print_avg(input_id_to_exp)
624 | gt_exp_id_to_probs, gt_prob_chosen = print_avg(gt_next_token_to_exp)
625 | predicted_exp_id_to_probs, pred_prob_chosen = print_avg(predicted_next_token_to_exp)
626 |
627 | FONTSIZE = 34
628 | ### Layer spec ###
629 | ax = axes[0]
630 | x = list(range(num_layers))
631 | y = [layer_num_to_probs[i][0] * 100 for i in x]
632 | ax.plot(x, y, label="Input tokens", color="#F0539B", marker='o', markersize=12, linewidth=6)
633 | y = [layer_num_to_probs[i][2] * 100 for i in x]
634 | ax.plot(x, y, label="Predicted output tokens", color="#43C5E0", marker='o', markersize=12, linewidth=6)
635 | y = [layer_num_to_probs[i][1] * 100 for i in x]
636 | ax.plot(x, y, label="Ground-truth output tokens", color="#2E3168", marker='o', markersize=12, linewidth=6)
637 | ax.tick_params(axis='both', which='major', labelsize=FONTSIZE)
638 | #ax.set_xlim(0.5, num_layers-1)
639 | ax.margins(x=0.01)
640 | ax.set_xticks(x)
641 | ax.set_xlabel("Layer ID", fontsize=FONTSIZE, fontweight='bold')
642 | ax.set_ylabel("Vocabulary specialization (%) ", fontsize=FONTSIZE, fontweight='bold')
643 | # ax.set_title("Layer-wise token analysis", fontsize=FONTSIZE, fontweight='bold')
644 | ax.spines['top'].set_visible(False)
645 | ax.spines['right'].set_visible(False)
646 | ax.set_title("Per layer", fontsize=FONTSIZE, fontweight='bold')
647 | # ax.legend(frameon=True, fontsize=FONTSIZE, columnspacing=0.4, labelspacing=0.4)
648 | ### Expert spec ###
649 | """Line Plot
650 | ax = axes[1]
651 | if args.model == "olmoe":
652 | x = list(range(64))[:32] # limit to 32 experts
653 | else:
654 | x = list(range(8))
655 | # Sort x by input_exp_id_to_probs
656 | x_sorted = sorted(x, key=lambda i: -input_exp_id_to_probs[i][0]) if do_sort else x
657 | y = [input_exp_id_to_probs[i][0] * 100 for i in x_sorted]
658 | ax.plot(x, y, label="Input tokens", color="#F0539B", marker='o', markersize=16, linewidth=6, linestyle='dotted')
659 | # Scatter instead
660 | # ax.scatter(x, y, color="#F0539B", s=200, label="Input tokens")
661 | y = [gt_exp_id_to_probs[i][0] * 100 for i in x_sorted]
662 | ax.plot(x, y, label="Predicted output tokens", color="#43C5E0", marker='o', markersize=16, linewidth=6, linestyle='dotted')
663 | # ax.scatter(x, y, color="#43C5E0", s=200, label="Predicted output tokens")
664 | y = [predicted_exp_id_to_probs[i][0] * 100 for i in x_sorted]
665 | ax.plot(x, y, label="Ground-truth output tokens", color="#2E3168", marker='o', markersize=16, linewidth=6, linestyle='dotted')
666 | # ax.scatter(x, y, color="#2E3168", s=200, label="Ground-truth output tokens")
667 | # Draw horizontal prob lines for input_prob etc
668 | ax.axhline(input_prob_chosen * 100, color="#F0539B", linestyle='--', linewidth=6, alpha=0.8)
669 | ax.axhline(pred_prob_chosen * 100, color="#43C5E0", linestyle='--', linewidth=6, alpha=0.8)
670 | ax.axhline(gt_prob_chosen * 100, color="#2E3168", linestyle='--', linewidth=6, alpha=0.8)
671 | ax.tick_params(axis='both', which='major', labelsize=FONTSIZE)
672 | #ax.set_xlim(0, x[-1])
673 | ax.margins(x=0.005)
674 | ax.set_xticks(x, x_sorted)
675 | ax.set_xlabel("Expert ID", fontsize=FONTSIZE, fontweight='bold')
676 | # ax.set_ylabel("Vocabulary specialization (%)", fontsize=FONTSIZE, fontweight='bold')
677 | # ax.set_title("Layer-wise token analysis", fontsize=FONTSIZE, fontweight='bold')
678 | ax.spines['top'].set_visible(False)
679 | ax.spines['right'].set_visible(False)
680 | ax.set_title(f"Per expert in layer {args.layer_num}", fontsize=FONTSIZE, fontweight='bold')
681 | plt.legend(frameon=True, fontsize=FONTSIZE, columnspacing=0.4, labelspacing=0.4, loc="lower right")
682 | """
683 | #"""Bar Plot
684 | ax = axes[1]
685 | if args.model == "olmoe":
686 | x = [0, 1, 2, 3, 4, 5, 6, 7, 8, 27, 37, 58]
687 | else:
688 | x = list(range(8))
689 | # Sort x by input_exp_id_to_probs
690 | x_sorted = list(range(len(x)))
691 | y = [input_exp_id_to_probs[i][0] * 100 for i in x]
692 | # ax.bar(x, y, color="#F0539B", label="Input tokens")
693 | # Space them out
694 | x_bar = np.array(x_sorted) - 0.275
695 | ax.bar(x_bar, y, color="#F0539B", label="Input token ID", width=0.25, alpha=0.8)
696 | y = [predicted_exp_id_to_probs[i][0] * 100 for i in x]  # index by expert id, not bar position
697 | # ax.bar(x, y, color="#43C5E0", label="Predicted output tokens")
698 | x_bar = np.array(x_sorted)
699 | ax.bar(x_bar, y, color="#43C5E0", label="Predicted output token ID", width=0.25, alpha=0.8)
700 | y = [gt_exp_id_to_probs[i][0] * 100 for i in x]  # index by expert id, not bar position
701 | # ax.bar(x, y, color="#2E3168", label="Ground-truth output tokens")
702 | x_bar = np.array(x_sorted) + 0.275
703 | ax.bar(x_bar, y, color="#2E3168", label="Ground-truth output token ID", width=0.25, alpha=0.8)
704 | # Draw horizontal prob lines for input_prob etc
705 | ax.axhline(input_prob_chosen * 100, color="#F0539B", linestyle='--', linewidth=6)
706 | ax.axhline(pred_prob_chosen * 100, color="#43C5E0", linestyle='--', linewidth=6)
707 | ax.axhline(gt_prob_chosen * 100, color="#2E3168", linestyle='--', linewidth=6)
708 | ax.tick_params(axis='both', which='major', labelsize=FONTSIZE)
709 | ax.margins(x=0.005)
710 | ax.set_xticks(x_sorted, x)
711 | ax.set_xlabel("Expert ID", fontsize=FONTSIZE, fontweight='bold')
712 | ax.spines['top'].set_visible(False)
713 | ax.spines['right'].set_visible(False)
714 | ax.set_title(f"Per expert in layer {args.layer_num}", fontsize=FONTSIZE, fontweight='bold')
715 | plt.legend(frameon=True, fontsize=FONTSIZE, columnspacing=0.4, labelspacing=0.4, loc="lower right")
716 | #"""
717 | plt.savefig(os.path.join(args.fig_dir, f"vocabulary_specialization_top{args.topk}_{args.model}.pdf"), bbox_inches='tight')
718 | plt.savefig(os.path.join(args.fig_dir, f"vocabulary_specialization_top{args.topk}_{args.model}.png"), bbox_inches='tight')
719 |
720 |
721 | if __name__ == '__main__':
722 | """
723 | First, run the following to save model outputs:
724 | `python moe.py --do_inference --do_inference_all_ckpts`
725 | (skip `--do_inference_all_ckpts` if you won't run the ckpt analysis)
726 |
727 | Then, to do ckpt analysis (Router Saturation):
728 | `python moe.py --do_ckpt_analysis --topk 1
729 | python moe.py --do_ckpt_analysis --topk 8`
730 |
731 | To do coactivation analysis:
732 | `python moe.py --do_coactivation_analysis --layer_num 0
733 | python moe.py --do_coactivation_analysis --layer_num 7
734 | python moe.py --do_coactivation_analysis --layer_num 15`
735 |
736 | To do token analysis (Vocabulary specialization):
737 | `python moe.py --do_token_analysis`
738 | `python moe.py --do_token_analysis_layers_experts --topk 1`
739 |
740 | For all commands: use `--tokenized_path`, `--out_dir`, `--fig_dir` to specify where the tokenized data, model outputs, and figures are saved
741 | To use Mixtral, use `--model mixtral` (also requires rerunning do_inference)
742 | """
743 |
744 | parser = argparse.ArgumentParser(prog="moe.py", description="Run analyses on OLMoE")
745 | parser.add_argument("--do_inference", action="store_true")
746 | parser.add_argument("--do_inference_all_ckpts", action="store_true")
747 |
748 | parser.add_argument("--do_ckpt_analysis", action="store_true")
749 | parser.add_argument("--do_coactivation_analysis", action="store_true")
750 | parser.add_argument("--do_token_analysis", action="store_true")
751 | parser.add_argument("--do_token_analysis_layers", action="store_true")
752 | parser.add_argument("--do_token_analysis_experts", action="store_true")
753 | parser.add_argument("--do_token_analysis_layers_experts", action="store_true")
754 |
755 | parser.add_argument("--tokenized_path", default="c4_validation.jsonl", type=str, help="directory to save tokenized c4 data.")
756 | parser.add_argument("--out_dir", default="out", type=str, help="directory to save outputs from the model")
757 | parser.add_argument("--fig_dir", default="figs", type=str, help="directory to save figures")
758 |
759 | parser.add_argument("--topk", choices=[1, 2, 8, 18], default=1, type=int)
760 | parser.add_argument("--layer_num", choices=list(range(16)), default=7, type=int)
761 | parser.add_argument("--model", default="olmoe", type=str, help="Which model; if not olmoe, then mixtral")
762 |
763 | args = parser.parse_args()
764 |
765 | if args.do_inference:
766 | if not os.path.exists(args.tokenized_path):
767 | tokenize_c4(args.tokenized_path, model=args.model)
768 | if not os.path.exists(args.out_dir):
769 | os.mkdir(args.out_dir)
770 | do_inference(args, run_all_checkpoints=args.do_inference_all_ckpts)
771 |
772 | if args.do_ckpt_analysis or args.do_coactivation_analysis or args.do_token_analysis or args.do_token_analysis_layers or args.do_token_analysis_experts or args.do_token_analysis_layers_experts:
773 | if not os.path.exists(args.fig_dir):
774 | os.mkdir(args.fig_dir)
775 |
776 | if args.do_ckpt_analysis:
777 | do_ckpt_analysis(args)
778 |
779 | if args.do_coactivation_analysis:
780 | do_coactivation_analysis(args)
781 |
782 | if args.do_token_analysis:
783 | do_token_analysis(args)
784 |
785 | if args.do_token_analysis_layers:
786 | do_token_analysis_layers(args)
787 |
788 | if args.do_token_analysis_experts:
789 | do_token_analysis_experts(args)
790 |
791 | if args.do_token_analysis_layers_experts:
792 | do_token_analysis_layers_experts(args)
--------------------------------------------------------------------------------
/scripts/run_routing_analysis.py:
--------------------------------------------------------------------------------
1 | from transformers import OlmoeForCausalLM, AutoTokenizer, AutoModelForCausalLM
2 | import torch
3 | import time
4 | import pickle as pkl
5 | import json
6 | import numpy as np
7 | from collections import defaultdict, Counter
8 | from tqdm import tqdm
9 | from pathlib import Path
10 |
11 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
12 | token = None
13 |
14 | start_time = time.time()
15 |
16 |
17 | # Adapted from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
18 | # Fixed to aggregate over all layers
19 | def load_balancing_loss_func(
20 | gate_logits: torch.Tensor, num_experts: int = None, top_k=2
21 | ) -> float:
22 | if gate_logits is None or not isinstance(gate_logits, tuple):
23 | return 0
24 |
25 | if isinstance(gate_logits, tuple):
26 | compute_device = gate_logits[0].device
27 | concatenated_gate_logits = torch.stack([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=1)
28 |
29 | routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
30 |
31 | _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
32 |
33 | expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
34 |
35 | # Compute the percentage of tokens routed to each expert
36 | tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
37 |
38 | # Compute the average probability of routing to these experts
39 | router_prob_per_expert = torch.mean(routing_weights, dim=0)
40 |
41 | overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(-2)) / len(gate_logits)
42 | return overall_loss * num_experts
43 |
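# Hedged sanity check (illustrative, not executed by this script): with
# perfectly uniform router logits, tokens_per_expert sums to top_k per layer
# and router_prob_per_expert is 1/num_experts, so the loss evaluates to top_k:
#   gate_logits = tuple(torch.zeros(4, 8) for _ in range(2))  # 2 layers, 4 tokens, 8 experts
#   loss = load_balancing_loss_func(gate_logits, num_experts=8, top_k=2)
#   assert torch.isclose(loss, torch.tensor(2.0))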
44 |
45 | def load_analysis_data(tokenizer, domain, bs):
46 | np.random.seed(2024)
47 | tokens = []
48 |
49 | if domain == "tulu":
50 | data_path = f"routing_output/text/tulu-v3.1.jsonl"
51 | with open(data_path) as f:
52 | for line in f:
53 | text = json.loads(line)["text"]
54 | tokens = tokenizer(text, truncation=False)["input_ids"]
55 | while len(tokens) >= bs:
56 | yield tokens[:bs]
57 | tokens = tokens[bs:]
58 | yield tokens
59 | else:
60 | data_path = f"routing_output/text/{domain}_texts.txt"
61 | with open(data_path) as f:
62 | text = f.read()
63 | tokens = tokenizer(text, truncation=False)["input_ids"]
64 | while len(tokens) >= bs:
65 | yield tokens[:bs]
66 | tokens = tokens[bs:]
67 |
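# Chunking sketch (hypothetical sizes): a 5000-token document with bs=2048
# yields two full 2048-token chunks and drops the 904-token tail; the "tulu"
# branch instead yields the leftover tail of each line as a final short chunk.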
68 | def load_sft_model():
69 | DEVICE = "cuda"
70 | model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924-SFT", token=token).to(DEVICE)
71 | model.eval()
72 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924-SFT", token=token)
73 | return model, tokenizer
74 |
75 |
76 | def load_dpo_model():
77 | DEVICE = "cuda"
78 | model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924-Instruct", token=token).to(DEVICE)
79 | model.eval()
80 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924-Instruct", token=token)
81 | return model, tokenizer
82 |
83 |
84 | def load_model():
85 | DEVICE = "cuda"
86 | model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924", token=token).to(DEVICE)
87 | model.eval()
88 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924", token=token)
89 | return model, tokenizer
90 |
91 | def load_model_mistral():
92 | model_id = "mistralai/Mixtral-8x7B-v0.1"
93 | model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto', token=token)
94 | model.eval()
95 | tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
96 | return model, tokenizer
97 |
98 |
99 | def print_expert_percentage(exp_counts):
100 | total = sum(exp_counts.values())
101 | for eid, ecount in exp_counts.most_common():
102 | print(f"Expert {eid}: {ecount/total*100:.2f}")
103 |
104 |
105 | def run_analysis(domain, model_name=None):
106 | layer_counters = defaultdict(Counter)
107 | crosslayer_counters = defaultdict(Counter)
108 |
109 | eid2token_layer0 = defaultdict(Counter)
110 | eid2token_layer7 = defaultdict(Counter)
111 | eid2token_layer15 = defaultdict(Counter)
112 |
113 | total_token_count = 0
114 |
115 | aux_losses = []
116 |
117 | # For each expert, what are some common tokens assigned to it?
118 | # Count them and print the result: {expert_id: Counter({token1: 32, token2: 23}), ...}
119 | for i, input_ids in tqdm(enumerate(load_analysis_data(tokenizer, domain=domain, bs=length))):
120 | input_ids = torch.LongTensor(input_ids).reshape(1, -1).to(DEVICE)
121 | out = model(input_ids=input_ids, output_router_logits=True)
122 |
123 | aux_loss = load_balancing_loss_func(
124 | out["router_logits"],
125 | model.num_experts,
126 | model.num_experts_per_tok,
127 | )
128 | aux_losses.append(aux_loss.cpu().item())
129 |
130 | # input id shapes: 2048 seqlen
131 | input_ids = input_ids[0].detach().cpu().numpy().tolist()
132 | total_token_count += len(input_ids)
133 |
134 | # 16 layer, (2048 tokens, 64 experts)
135 | router_logits = [l.detach().cpu().numpy() for l in out["router_logits"]]
136 | # 2048 tokens, 8 experts, 16 layers
137 | if model_name == "mistral":
138 | exp_ids = np.stack([np.argsort(-logits, -1)[:, :2].tolist() for logits in router_logits], -1)
139 | elif model_name.startswith("olmoe"):
140 | exp_ids = np.stack([np.argsort(-logits, -1)[:, :8].tolist() for logits in router_logits], -1)
141 | # 2048 tokens, 8 experts
142 | exp_ids_layer0 = exp_ids[:, :, 0]
143 | exp_ids_layer7 = exp_ids[:, :, 7]
144 | exp_ids_layer15 = exp_ids[:, :, 15]
145 |
146 | for pos, tok in enumerate(input_ids):  # pos/tok avoid shadowing builtin id() and the global `token`
147 | experts = exp_ids_layer0[pos, :]
148 | for e in experts:
149 | eid2token_layer0[e][tok] += 1
150 | for e in exp_ids_layer7[pos, :]:
151 | eid2token_layer7[e][tok] += 1
152 | for e in exp_ids_layer15[pos, :]:
153 | eid2token_layer15[e][tok] += 1
154 |
155 | for layer in range(exp_ids.shape[2]):
156 | exp_counts = Counter(exp_ids[:, :, layer].flatten())
157 | layer_counters[layer].update(exp_counts)
158 |
159 | for layer_i in range(exp_ids.shape[2] - 1):
160 | for layer_j in range(exp_ids.shape[2]):
161 | exps_counts = Counter(zip(exp_ids[:, :, layer_i].flatten(), exp_ids[:, :, layer_j].flatten()))
162 | crosslayer_counters[(layer_i, layer_j)].update(exps_counts)
163 |
164 | if total_token_count > 204800:
165 | break
166 |
167 | print(f"Average aux loss: {np.mean(aux_losses)}")
168 |
169 | return layer_counters, crosslayer_counters, eid2token_layer0, eid2token_layer7, eid2token_layer15
170 |
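# Hypothetical inspection snippet (not run by default): decode the five most
# common token ids from the eid2token_layer0 counter returned above, e.g. for
# expert 3:
#   for tok_id, cnt in eid2token_layer0[3].most_common(5):
#       print(repr(tokenizer.decode([tok_id])), cnt)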
171 | name2finaldata = {"github_oss_with_stack": "github", "arxiv": "arxiv", "c4": "c4", "b3g": "book", "wikipedia": "wikipedia", "tulu": "tulu"}
172 |
173 | if __name__=='__main__':
174 | model_name = "olmoe"
175 | print(model_name)
176 | if model_name == "mistral":
177 | model, tokenizer = load_model_mistral()
178 | elif model_name == "olmoe-sft":
179 | model, tokenizer = load_sft_model()
180 | elif model_name == "olmoe-dpo":
181 | model, tokenizer = load_dpo_model()
182 | elif model_name == "olmoe":
183 | model, tokenizer = load_model()
184 |
185 | length = 2048
186 | for domain in tqdm(["tulu", "github_oss_with_stack", "arxiv", "c4", "b3g", "wikipedia"]):
187 | print(f"Domain: {domain}")
188 | layer_counters, crosslayer_counters, eid2token_layer0, eid2token_layer7, eid2token_layer15 = run_analysis(domain, model_name)
189 | Path(f"routing_output/{model_name}/expert_counts").mkdir(parents=True, exist_ok=True)
190 | Path(f"routing_output/{model_name}/expert_counts_crosslayer").mkdir(parents=True, exist_ok=True)
191 | Path(f"routing_output/{model_name}/eid2token").mkdir(parents=True, exist_ok=True)
192 | with open(f"routing_output/{model_name}/expert_counts/{name2finaldata[domain]}.pkl", "wb") as f:
193 | pkl.dump([layer_counters[0], layer_counters[7], layer_counters[15]], f)
194 | with open(f"routing_output/{model_name}/expert_counts_crosslayer/{name2finaldata[domain]}.pkl", "wb") as f:
195 | pkl.dump([crosslayer_counters[(0, 7)], crosslayer_counters[(7, 15)]], f)
196 | with open(f"routing_output/{model_name}/eid2token/{name2finaldata[domain]}.pkl", "wb") as f:
197 | pkl.dump([eid2token_layer0, eid2token_layer7, eid2token_layer15], f)
--------------------------------------------------------------------------------
/scripts/sparsify_ckpt_unsharded.py:
--------------------------------------------------------------------------------
1 | """
2 | 1. Unshard ckpt using `python /home/niklas/OLMoE/scripts/unshard.py /data/niklas/llm/checkpoints/23485/step954000 /data/niklas/llm/checkpoints/1b-954000-unsharded --safe-tensors --model-only`
3 | 2. Run this script via `python /home/niklas/OLMoE/scripts/sparsify_ckpt_unsharded.py /data/niklas/llm/checkpoints/1b-954000-unsharded/model.safetensors`
4 | """
5 | import copy
6 | import sys
7 | import torch
8 | from olmo.safetensors_util import safetensors_file_to_state_dict, state_dict_to_safetensors_file
9 |
10 | path = sys.argv[1]
11 | sd = safetensors_file_to_state_dict(path)
12 | tensors = {}
13 | swiglu = True
14 | noise = False
15 | share = False
16 | interleave = False
17 | n_experts = 8
18 | D = 2048
19 |
20 | def noise_injection(weight, noise_ratio=0.5, init_std=0.02):
21 | mask = torch.FloatTensor(weight.size()).uniform_() < noise_ratio
22 | mask = mask.to(weight.device)
23 | rand_weight = torch.nn.init.normal_(copy.deepcopy(weight), mean=0.0, std=init_std)
24 | weight[mask] = rand_weight[mask]
25 | return weight
26 |
27 | for key in list(sd.keys()):
28 | if "ff_proj.weight" in key:
29 | block_num = int(key.split(".")[2])
30 | if interleave and block_num % 2 == 0:
31 | tensors[key] = sd.pop(key)
32 | continue
33 | new_key = key.replace("ff_proj.weight", "ffn.experts.mlp.w1")
34 | if swiglu:
35 | new_key_v1 = new_key.replace("w1", "v1")
36 | # OLMo applies F.silu to the second half of the tensor, which corresponds to v1
37 | v1, w1 = sd.pop(key).chunk(2, dim=0) # e.g. [16384, 2048]
38 | tensors[new_key] = torch.cat([w1] * n_experts, dim=0)
39 | tensors[new_key_v1] = torch.cat([v1] * n_experts, dim=0)
40 | if noise:
41 | tensors[new_key] = noise_injection(tensors[new_key])
42 | tensors[new_key_v1] = noise_injection(tensors[new_key_v1])
43 | if share:
44 | share_key = new_key.replace("experts.mlp.w1", "shared_expert.up_proj.weight")
45 | share_key_v1 = new_key_v1.replace("experts.mlp.v1", "shared_expert.gate_proj.weight")
46 | tensors[share_key] = w1
47 | tensors[share_key_v1] = v1
48 | else:
49 | tensors[new_key] = torch.cat([sd.pop(key)] * n_experts, dim=0)
50 | elif ("ff_out.weight" in key) and (key != 'transformer.ff_out.weight'):
51 | block_num = int(key.split(".")[2])
52 | if interleave and block_num % 2 == 0:
53 | tensors[key] = sd.pop(key)
54 | continue
55 | new_key = key.replace("ff_out.weight", "ffn.experts.mlp.w2")
56 | w = sd.pop(key)
57 | tensors[new_key] = torch.cat([w.t()] * n_experts, dim=0)
58 | if noise:
59 | tensors[new_key] = noise_injection(tensors[new_key])
60 | if share:
61 | share_key = new_key.replace("experts.mlp.w2", "shared_expert.down_proj.weight")
62 | tensors[share_key] = w
63 | # Add router
64 | router_key = key.replace("ff_out.weight", "ffn.router.layer.weight")
65 | # tensors[router_key] = torch.ones((n_experts, D)).squeeze() # Worse perf
66 | tensors[router_key] = torch.nn.init.normal_(torch.ones((n_experts, D)).squeeze(), std=0.02)
67 | else:
68 | tensors[key] = sd.pop(key)
69 |
70 | state_dict_to_safetensors_file(tensors, path.replace("model.safetensors", "model_sparse.safetensors"))
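# Hypothetical shape check for the SwiGLU split above (not part of the script):
#   ff_proj = torch.randn(16384, 2048)        # dense [2*hidden, d_model]
#   v1, w1 = ff_proj.chunk(2, dim=0)          # two [8192, 2048] halves
#   w1_tiled = torch.cat([w1] * n_experts, dim=0)
#   assert w1_tiled.shape == (8 * 8192, 2048) # one tiled copy per expert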
--------------------------------------------------------------------------------
/scripts/wekatransfer/s3weka.sh:
--------------------------------------------------------------------------------
1 | export BUDGET=ai2/oe-training
2 | export S3_BUCKET=ai2-llm
3 | #export S3_PREFIX=preprocessed/olmo-mix/danyh-compiled-v1_7
4 | #export S3_PREFIX=preprocessed/fastdclm
5 | export S3_PREFIX=preprocessed/starcoder/v1-decon-100_to_20k-2star-top_token_030
6 | export WEKA_BUCKET=oe-training-default
7 | #export WEKA_PREFIX=ai2-llm/preprocessed/danyh-compiled-v1_7
8 | #export WEKA_PREFIX=ai2-llm/preprocessed/fastdclm
9 | export WEKA_PREFIX=ai2-llm/preprocessed/starcoder/v1-decon-100_to_20k-2star-top_token_030
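# Hypothetical usage (assuming the Beaker CLI expands {{.Env.*}} from the
# shell environment): source this file, then submit the spec, e.g.
#   source scripts/wekatransfer/s3weka.sh
#   beaker experiment create scripts/wekatransfer/s3weka.yml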
--------------------------------------------------------------------------------
/scripts/wekatransfer/s3weka.yml:
--------------------------------------------------------------------------------
1 | version: v2
2 | budget: {{.Env.BUDGET}}
3 | description: Sync contents of S3 bucket "{{.Env.S3_BUCKET}}" to WEKA bucket "{{.Env.WEKA_BUCKET}}"
4 | tasks:
5 | - name: sync_s3_to_weka
6 | image:
7 | beaker: ai2/cuda11.8-ubuntu20.04
8 | command: ['aws', 's3', 'sync', 's3://{{.Env.S3_BUCKET}}/{{.Env.S3_PREFIX}}', '/{{.Env.WEKA_BUCKET}}/{{.Env.WEKA_PREFIX}}']
9 | datasets:
10 | - mountPath: /{{.Env.WEKA_BUCKET}}
11 | source:
12 | weka: {{.Env.WEKA_BUCKET}}
13 | envVars:
14 | - name: AWS_ACCESS_KEY_ID
15 | secret: AWS_ACCESS_KEY_ID
16 | - name: AWS_SECRET_ACCESS_KEY
17 | secret: AWS_SECRET_ACCESS_KEY
18 | context:
19 | preemptible: true
20 | constraints:
21 | cluster:
22 | - ai2/jupiter-cirrascale-2
--------------------------------------------------------------------------------
/scripts/wekatransfer/wekas3.yaml:
--------------------------------------------------------------------------------
1 | version: v2
2 | budget: {{.Env.BUDGET}}
3 | description: Sync contents of WEKA bucket "{{.Env.WEKA_BUCKET}}" to S3 bucket "{{.Env.S3_BUCKET}}"
4 | tasks:
5 | - name: sync_weka_to_s3
6 | image:
7 | beaker: ai2/cuda11.8-ubuntu20.04
8 | command: ['aws', 's3', 'sync', '/{{.Env.WEKA_BUCKET}}/{{.Env.WEKA_PREFIX}}', 's3://{{.Env.S3_BUCKET}}/{{.Env.S3_PREFIX}}']
9 | datasets:
10 | - mountPath: /{{.Env.WEKA_BUCKET}}
11 | source:
12 | weka: {{.Env.WEKA_BUCKET}}
13 | envVars:
14 | - name: AWS_ACCESS_KEY_ID
15 | secret: AWS_ACCESS_KEY_ID
16 | - name: AWS_SECRET_ACCESS_KEY
17 | secret: AWS_SECRET_ACCESS_KEY
18 | context:
19 | preemptible: true
20 | constraints:
21 | cluster:
22 | - ai2/jupiter-cirrascale-2
--------------------------------------------------------------------------------
/visuals/emojis/olmoe_checkmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/emojis/olmoe_checkmark.png
--------------------------------------------------------------------------------
/visuals/emojis/olmoe_checkmark_yellow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/emojis/olmoe_checkmark_yellow.png
--------------------------------------------------------------------------------
/visuals/emojis/olmoe_cross.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/emojis/olmoe_cross.png
--------------------------------------------------------------------------------
/visuals/emojis/olmoe_warning.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/emojis/olmoe_warning.png
--------------------------------------------------------------------------------
/visuals/figures/adamweps.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/adamweps.pdf
--------------------------------------------------------------------------------
/visuals/figures/dataset.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/dataset.pdf
--------------------------------------------------------------------------------
/visuals/figures/datasetredditflan.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/datasetredditflan.pdf
--------------------------------------------------------------------------------
/visuals/figures/embdecay.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/embdecay.pdf
--------------------------------------------------------------------------------
/visuals/figures/expertchoice.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/expertchoice.pdf
--------------------------------------------------------------------------------
/visuals/figures/granularity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/granularity.pdf
--------------------------------------------------------------------------------
/visuals/figures/init.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/init.pdf
--------------------------------------------------------------------------------
/visuals/figures/layer_0_heatmap.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/layer_0_heatmap.pdf
--------------------------------------------------------------------------------
/visuals/figures/layer_15_heatmap.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/layer_15_heatmap.pdf
--------------------------------------------------------------------------------
/visuals/figures/layer_7_heatmap.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/layer_7_heatmap.pdf
--------------------------------------------------------------------------------
/visuals/figures/layersharing.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/layersharing.pdf
--------------------------------------------------------------------------------
/visuals/figures/lbl.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/lbl.pdf
--------------------------------------------------------------------------------
/visuals/figures/lblprecision.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/lblprecision.pdf
--------------------------------------------------------------------------------
/visuals/figures/lbltoks.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/lbltoks.pdf
--------------------------------------------------------------------------------
/visuals/figures/ln.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/ln.pdf
--------------------------------------------------------------------------------
/visuals/figures/lndecay.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/lndecay.pdf
--------------------------------------------------------------------------------
/visuals/figures/lngradnorm.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/lngradnorm.pdf
--------------------------------------------------------------------------------
/visuals/figures/loss.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/loss.pdf
--------------------------------------------------------------------------------
/visuals/figures/moevsdense.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/moevsdense.pdf
--------------------------------------------------------------------------------
/visuals/figures/noise.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/noise.pdf
--------------------------------------------------------------------------------
/visuals/figures/olmoe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/olmoe.pdf
--------------------------------------------------------------------------------
/visuals/figures/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/overview.jpg
--------------------------------------------------------------------------------
/visuals/figures/overview.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/overview.pdf
--------------------------------------------------------------------------------
/visuals/figures/qknorm.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/qknorm.pdf
--------------------------------------------------------------------------------
/visuals/figures/routing_mixtral.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/routing_mixtral.pdf
--------------------------------------------------------------------------------
/visuals/figures/routing_olmoe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/routing_olmoe.pdf
--------------------------------------------------------------------------------
/visuals/figures/routing_prob_distribution_mixtral.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/routing_prob_distribution_mixtral.pdf
--------------------------------------------------------------------------------
/visuals/figures/routing_prob_distribution_olmoe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/routing_prob_distribution_olmoe.pdf
--------------------------------------------------------------------------------
/visuals/figures/shared.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/shared.pdf
--------------------------------------------------------------------------------
/visuals/figures/token_specialization_top1_olmoe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/token_specialization_top1_olmoe.pdf
--------------------------------------------------------------------------------
/visuals/figures/token_specialization_top2_mixtral.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/token_specialization_top2_mixtral.pdf
--------------------------------------------------------------------------------
/visuals/figures/token_specialization_top8_olmoe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/token_specialization_top8_olmoe.pdf
--------------------------------------------------------------------------------
/visuals/figures/top18_changes_over_checkpoints.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/top18_changes_over_checkpoints.pdf
--------------------------------------------------------------------------------
/visuals/figures/trainingevalflops.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/trainingevalflops.pdf
--------------------------------------------------------------------------------
/visuals/figures/trainingevaltokens.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/trainingevaltokens.pdf
--------------------------------------------------------------------------------
/visuals/figures/upcycle.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/upcycle.pdf
--------------------------------------------------------------------------------
/visuals/figures/zloss.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/zloss.pdf
--------------------------------------------------------------------------------
/visuals/logos/OLMoE_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/logos/OLMoE_logo.png
--------------------------------------------------------------------------------
/visuals/logos/OLMoE_logo.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/logos/OLMoE_logo.svg
--------------------------------------------------------------------------------
/visuals/logos/OLMoE_logo_alt1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/logos/OLMoE_logo_alt1.png
--------------------------------------------------------------------------------
/visuals/logos/OLMoE_logo_alt1.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/logos/OLMoE_logo_alt1.svg
--------------------------------------------------------------------------------
/visuals/logos/OLMoE_logo_alt2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/logos/OLMoE_logo_alt2.png
--------------------------------------------------------------------------------
/visuals/logos/OLMoE_logo_alt2.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/logos/OLMoE_logo_alt2.svg
--------------------------------------------------------------------------------
/visuals/logos/OLMoE_logo_alt3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/logos/OLMoE_logo_alt3.png
--------------------------------------------------------------------------------
/visuals/logos/OLMoE_logo_alt3.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/logos/OLMoE_logo_alt3.svg
--------------------------------------------------------------------------------
/visuals/poster_iclr2025.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/poster_iclr2025.pdf
--------------------------------------------------------------------------------
/visuals/poster_iclr2025.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/poster_iclr2025.pptx
--------------------------------------------------------------------------------
/visuals/poster_neurips2024.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/poster_neurips2024.pdf
--------------------------------------------------------------------------------
/visuals/twitterblog_images/domainspec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/domainspec.png
--------------------------------------------------------------------------------
/visuals/twitterblog_images/experiments.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/experiments.png
--------------------------------------------------------------------------------
/visuals/twitterblog_images/logo_transparent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/logo_transparent.png
--------------------------------------------------------------------------------
/visuals/twitterblog_images/logo_twitter.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/logo_twitter.png
--------------------------------------------------------------------------------
/visuals/twitterblog_images/overview_base.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/overview_base.png
--------------------------------------------------------------------------------
/visuals/twitterblog_images/overview_left.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/overview_left.png
--------------------------------------------------------------------------------
/visuals/twitterblog_images/overview_long.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/overview_long.png
--------------------------------------------------------------------------------
/visuals/twitterblog_images/overview_right.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/overview_right.png
--------------------------------------------------------------------------------
/visuals/twitterblog_images/perf_adapt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/perf_adapt.png
--------------------------------------------------------------------------------
/visuals/twitterblog_images/perf_during.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/perf_during.png
--------------------------------------------------------------------------------
/visuals/twitterblog_images/perf_pretr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/perf_pretr.png
--------------------------------------------------------------------------------
/visuals/twitterblog_images/perf_pretr_adapt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/perf_pretr_adapt.png
--------------------------------------------------------------------------------
/visuals/twitterblog_images/tokenidspec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/tokenidspec.png
--------------------------------------------------------------------------------