├── .gitignore ├── LICENSE ├── README.md ├── configs ├── OLMoE-1B-7B-0924.yml └── ablations │ ├── olmo-1b-newhp-newds-cx5-datafix.yml │ ├── olmo-1b-newhp-newds-cx5-flan.yml │ ├── olmo-1b-newhp-newds-cx5-reddit.yml │ ├── olmo-1b-newhp-newds-cx5.yml │ ├── olmo-1b-newhp-newds-s3.yml │ ├── olmo-1b-newhp-newds.yml │ ├── olmo-1b-newhp-oldds-cx5.yml │ ├── olmo-1b-newhp-oldds-s3.yml │ ├── olmo-1b-newhp-oldds.yml │ ├── olmoe-8x1b-newhp-newds-cx5-fine-shared-s3.yml │ ├── olmoe-8x1b-newhp-newds-cx5-fine-shared.yml │ ├── olmoe-8x1b-newhp-newds-cx5-fine.yml │ ├── olmoe-8x1b-newhp-newds-cx5-fine05.yml │ ├── olmoe-8x1b-newhp-newds-cx5-fine1-datafix.yml │ ├── olmoe-8x1b-newhp-newds-cx5-fine1-docmask-8k.yml │ ├── olmoe-8x1b-newhp-newds-cx5-fine1-docmask.yml │ ├── olmoe-8x1b-newhp-newds-cx5-fine1-newtok.yml │ ├── olmoe-8x1b-newhp-newds-cx5-fine1-normreorder.yml │ ├── olmoe-8x1b-newhp-newds-cx5-fine1.yml │ ├── olmoe-8x1b-newhp-newds-cx5-k2-fine-s3.yml │ ├── olmoe-8x1b-newhp-newds-cx5-k2-fine.yml │ ├── olmoe-8x1b-newhp-newds-cx5-k2.yml │ ├── olmoe-8x1b-newhp-newds-cx5.yml │ ├── olmoe-8x1b-newhp-newds-final-anneal.yml │ ├── olmoe-8x1b-newhp-newds-final-densecomp.yml │ ├── olmoe-8x1b-newhp-newds-final-double-alt.yml │ ├── olmoe-8x1b-newhp-newds-final-double.yml │ ├── olmoe-8x1b-newhp-newds-final-s3.yml │ ├── olmoe-8x1b-newhp-newds-final-v2.yml │ ├── olmoe-8x1b-newhp-newds-final.yml │ ├── olmoe-8x1b-newhp-newds-k2qk.yml │ ├── olmoe-8x1b-newhp-newds-s3-cx5.yml │ ├── olmoe-8x1b-newhp-newds-s3.yml │ ├── olmoe-8x1b-newhp-newds.yml │ ├── olmoe-8x1b-newhp-oldds.yml │ ├── olmoe-8x2b-newhp-newds-final.yml │ ├── olmoe-8x7b-A7B.yml │ ├── olmoe-8x7b.yml │ ├── olmoe17-16x1b-fullshard-swiglu-wrapb-s1k1.yml │ ├── olmoe17-8x1b-final-decemb.yml │ ├── olmoe17-8x1b-final-decln.yml │ ├── olmoe17-8x1b-final-eps-fine.yml │ ├── olmoe17-8x1b-final-eps-noqk.yml │ ├── olmoe17-8x1b-final-eps.yml │ ├── olmoe17-8x1b-final-fine.yml │ ├── olmoe17-8x1b-final-nodecln.yml │ ├── olmoe17-8x1b-final-normdc.yml │ ├── olmoe17-8x1b-final-weka.yaml │ ├── olmoe17-8x1b-final.yml │ ├── olmoe17-8x1b-fullshard-swiglu-wrapb-k2-qknorm-zloss.yml │ └── olmoe17-8x7b-final.yml ├── logs ├── olmoe-dpo-logs.txt └── olmoe-sft-logs.txt ├── scripts ├── adapteval.sh ├── batchjob.sh ├── eval_openlm_ckpt.py ├── humaneval.yaml ├── llm1b.sh ├── make_table.py ├── megatron.sh ├── megatron_dense_46m_8gpu.sh ├── megatron_dmoe_46m_8gpu.sh ├── olmoe-gantry.sh ├── olmoe_visuals.ipynb ├── plot_routing_analysis.ipynb ├── plot_routing_analysis_v2.ipynb ├── plot_routing_analysis_v2_cross_layer.ipynb ├── plot_routing_analysis_v2_top1.ipynb ├── routing_mixtral_v2.jpg ├── routing_olmoe_v2.jpg ├── routing_output.zip ├── routing_output │ ├── mistral │ │ ├── eid2token │ │ │ ├── arxiv.pkl │ │ │ ├── book.pkl │ │ │ ├── c4.pkl │ │ │ ├── github.pkl │ │ │ └── wikipedia.pkl │ │ ├── expert_counts │ │ │ ├── arxiv.pkl │ │ │ ├── book.pkl │ │ │ ├── c4.pkl │ │ │ ├── github.pkl │ │ │ └── wikipedia.pkl │ │ ├── expert_counts_crosslayer │ │ │ ├── arxiv.pkl │ │ │ ├── book.pkl │ │ │ ├── c4.pkl │ │ │ ├── github.pkl │ │ │ └── wikipedia.pkl │ │ ├── expert_counts_crosslayer_top1 │ │ │ ├── arxiv.pkl │ │ │ ├── book.pkl │ │ │ ├── c4.pkl │ │ │ ├── github.pkl │ │ │ └── wikipedia.pkl │ │ └── expert_counts_top1 │ │ │ ├── arxiv.pkl │ │ │ ├── book.pkl │ │ │ ├── c4.pkl │ │ │ ├── github.pkl │ │ │ └── wikipedia.pkl │ ├── olmoe-dpo │ │ └── expert_counts │ │ │ └── tulu.pkl │ ├── olmoe-sft │ │ └── expert_counts │ │ │ └── tulu.pkl │ ├── olmoe │ │ ├── eid2token │ │ │ ├── arxiv.pkl │ │ │ ├── book.pkl │ │ │ ├── c4.pkl │ │ │ ├── 
github.pkl │ │ │ └── wikipedia.pkl │ │ ├── expert_counts │ │ │ ├── arxiv.pkl │ │ │ ├── book.pkl │ │ │ ├── c4.pkl │ │ │ ├── github.pkl │ │ │ ├── tulu.pkl │ │ │ └── wikipedia.pkl │ │ ├── expert_counts_crosslayer │ │ │ ├── arxiv.pkl │ │ │ ├── book.pkl │ │ │ ├── c4.pkl │ │ │ ├── github.pkl │ │ │ └── wikipedia.pkl │ │ ├── expert_counts_crosslayer_top1 │ │ │ ├── arxiv.pkl │ │ │ ├── book.pkl │ │ │ ├── c4.pkl │ │ │ ├── github.pkl │ │ │ └── wikipedia.pkl │ │ └── expert_counts_top1 │ │ │ ├── arxiv.pkl │ │ │ ├── book.pkl │ │ │ ├── c4.pkl │ │ │ ├── github.pkl │ │ │ └── wikipedia.pkl │ ├── routing.jpg │ ├── routing_prob_distribution.png │ └── text │ │ ├── arxiv_texts.txt │ │ ├── b3g_texts.txt │ │ ├── c4_texts.txt │ │ ├── github_oss_with_stack_texts.txt │ │ └── wikipedia_texts.txt ├── run_dclm_evals_heavy.sh ├── run_dclm_evals_heavy_olmo.sh ├── run_dclm_evals_humaneval.sh ├── run_moe_analysis.py ├── run_routing_analysis.py ├── sparsify_ckpt_unsharded.py └── wekatransfer │ ├── s3weka.sh │ ├── s3weka.yml │ └── wekas3.yaml └── visuals ├── emojis ├── olmoe_checkmark.png ├── olmoe_checkmark_yellow.png ├── olmoe_cross.png └── olmoe_warning.png ├── figures ├── adamweps.pdf ├── dataset.pdf ├── datasetredditflan.pdf ├── embdecay.pdf ├── expertchoice.pdf ├── granularity.pdf ├── init.pdf ├── layer_0_heatmap.pdf ├── layer_15_heatmap.pdf ├── layer_7_heatmap.pdf ├── layersharing.pdf ├── lbl.pdf ├── lblprecision.pdf ├── lbltoks.pdf ├── ln.pdf ├── lndecay.pdf ├── lngradnorm.pdf ├── loss.pdf ├── moevsdense.pdf ├── noise.pdf ├── olmoe.pdf ├── overview.jpg ├── overview.pdf ├── qknorm.pdf ├── routing_mixtral.pdf ├── routing_olmoe.pdf ├── routing_prob_distribution_mixtral.pdf ├── routing_prob_distribution_olmoe.pdf ├── shared.pdf ├── token_specialization_top1_olmoe.pdf ├── token_specialization_top2_mixtral.pdf ├── token_specialization_top8_olmoe.pdf ├── top18_changes_over_checkpoints.pdf ├── trainingevalflops.pdf ├── trainingevaltokens.pdf ├── upcycle.pdf └── zloss.pdf ├── logos ├── OLMoE_logo.png ├── OLMoE_logo.svg ├── OLMoE_logo_alt1.png ├── OLMoE_logo_alt1.svg ├── OLMoE_logo_alt2.png ├── OLMoE_logo_alt2.svg ├── OLMoE_logo_alt3.png └── OLMoE_logo_alt3.svg ├── poster_iclr2025.pdf ├── poster_iclr2025.pptx ├── poster_neurips2024.pdf └── twitterblog_images ├── domainspec.png ├── experiments.png ├── logo_transparent.png ├── logo_twitter.png ├── overview_base.png ├── overview_left.png ├── overview_long.png ├── overview_right.png ├── perf_adapt.png ├── perf_during.png ├── perf_pretr.png ├── perf_pretr_adapt.png └── tokenidspec.png /.gitignore: -------------------------------------------------------------------------------- 1 | **/results/* 2 | .DS_Store 3 | */.DS_Store 4 | */*/.DS_Store 5 | */*/*/.DS_Store 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # OLMoE: Open Mixture-of-Experts Language Models 3 | 4 | Fully open, state-of-the-art Mixture-of-Experts model with 1.3 billion active and 6.9 billion total parameters. All data, code, and logs released. 5 | 6 |
7 | 8 | ![](visuals/figures/overview.jpg) 9 | 10 | This repository provides an overview of all resources for the paper ["OLMoE: Open Mixture-of-Experts Language Models"](https://arxiv.org/abs/2409.02060). 11 | 12 | - [Artifacts](#artifacts) 13 | - [Inference](#inference) 14 | - [Pretraining](#pretraining) 15 | - [Adaptation](#adaptation) 16 | - [Evaluation](#evaluation) 17 | - [During pretraining](#during-pretraining) 18 | - [After pretraining](#after-pretraining) 19 | - [After adaptation](#after-adaptation) 20 | - [Visuals](#visuals) 21 | - [Citation](#citation) 22 | 23 | ### Artifacts 24 | 25 | - **Paper**: https://arxiv.org/abs/2409.02060 26 | - **Pretraining**: [Checkpoints](https://hf.co/allenai/OLMoE-1B-7B-0924), [Final Checkpoint GGUF](https://hf.co/allenai/OLMoE-1B-7B-0924-GGUF), [Code](https://github.com/allenai/OLMo/tree/Muennighoff/MoE), [Data](https://huggingface.co/datasets/allenai/OLMoE-mix-0924) and [Logs](https://wandb.ai/ai2-llm/olmoe/reports/OLMoE-1B-7B-0924--Vmlldzo4OTcyMjU3). 27 | - **SFT (Supervised Fine-Tuning)**: [Checkpoints](https://huggingface.co/allenai/OLMoE-1B-7B-0924-SFT), [Code](https://github.com/allenai/open-instruct/), [Data](https://hf.co/datasets/allenai/tulu-v3.1-mix-preview-4096-OLMoE) and [Logs](https://github.com/allenai/OLMoE/blob/main/logs/olmoe-sft-logs.txt). 28 | - **DPO/KTO (Direct Preference Optimization/Kahneman-Tversky Optimization)**: [Checkpoints](https://huggingface.co/allenai/OLMoE-1B-7B-0924-Instruct), [Final Checkpoint GGUF](https://hf.co/allenai/OLMoE-1B-7B-0924-Instruct-GGUF), [Preference Data](https://hf.co/datasets/allenai/ultrafeedback_binarized_cleaned), [DPO code](https://github.com/allenai/open-instruct/), [KTO code](https://github.com/Muennighoff/kto/blob/master/kto.py) and [Logs](https://github.com/allenai/OLMoE/blob/main/logs/olmoe-dpo-logs.txt). 29 | 30 | ### Inference 31 | 32 | OLMoE has been integrated into [vLLM](https://github.com/vllm-project/vllm), [SGLang](https://github.com/sgl-project/sglang), [llama.cpp](https://github.com/ggerganov/llama.cpp), and [transformers](https://github.com/huggingface/transformers). The transformers implementation is slow, so we recommend one of the others, e.g. vLLM, where possible. Below are examples for using it with vLLM, llama.cpp, and transformers. 33 | 34 | #### vLLM 35 | 36 | Install the `vllm` library and run: 37 | 38 | ```python 39 | from vllm import LLM, SamplingParams 40 | model = LLM("allenai/OLMoE-1B-7B-0924") 41 | out = model.generate("Bitcoin is", SamplingParams(temperature=0.0)) 42 | print("Bitcoin is" + out[0].outputs[0].text) 43 | # Bitcoin is a digital currency that is not controlled by any central authority. It is a peer 44 | ``` 45 | 46 | #### llama.cpp 47 | 48 | Install `llama.cpp`, download a quantized GGUF of the final checkpoint (e.g. [`olmoe-1b-7b-0924-q4_0.gguf`](https://hf.co/allenai/OLMoE-1B-7B-0924-GGUF/resolve/main/olmoe-1b-7b-0924-q4_0.gguf)) and run in a shell: 49 | 50 | ```bash 51 | llama-cli -m olmoe-1b-7b-0924-q4_0.gguf -p "Bitcoin is" -n 128 52 | ``` 53 | 54 | #### transformers 55 | 56 | Install the `transformers` & `torch` libraries and run: 57 | 58 | ```python 59 | from transformers import OlmoeForCausalLM, AutoTokenizer 60 | import torch 61 | 62 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 63 | 64 | # Load different ckpts via passing e.g. `revision=step10000-tokens41B` 65 | # also check allenai/OLMoE-1B-7B-0924-SFT & allenai/OLMoE-1B-7B-0924-Instruct 66 | model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924").to(DEVICE) 67 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924") 68 | inputs = tokenizer("Bitcoin is", return_tensors="pt") 69 | inputs = {k: v.to(DEVICE) for k, v in inputs.items()} 70 | out = model.generate(**inputs, max_length=64) 71 | print(tokenizer.decode(out[0])) 72 | # Bitcoin is a digital currency that is created and held electronically. No one controls it. Bitcoins aren’t printed, like dollars or euros – they’re produced by people and businesses running computers all around the world, using software that solves mathematical 73 | ``` 74 | 75 | You can list all revisions/branches by installing `huggingface-hub` & running: 76 | ```python 77 | from huggingface_hub import list_repo_refs 78 | out = list_repo_refs("allenai/OLMoE-1B-7B-0924") 79 | branches = [b.name for b in out.branches] 80 | ``` 81 |
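Combining the two snippets above, here is a minimal sketch (our addition, not part of the repo) for loading an intermediate pretraining checkpoint by passing its branch name as `revision`:

```python
from huggingface_hub import list_repo_refs
from transformers import OlmoeForCausalLM

# Branch names follow the pattern shown above, e.g. "step10000-tokens41B"
refs = list_repo_refs("allenai/OLMoE-1B-7B-0924")
step_branches = sorted(b.name for b in refs.branches if b.name.startswith("step"))
print(step_branches[:3])  # lexicographic order, fine for illustration

# Any of these branch names can be passed as `revision`
model = OlmoeForCausalLM.from_pretrained(
    "allenai/OLMoE-1B-7B-0924", revision=step_branches[0]
)
```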
82 | ### Pretraining 83 | 84 | 1. Clone this [OLMo branch](https://github.com/allenai/OLMo/tree/Muennighoff/MoE) & create an environment with its dependencies via `cd OLMo; pip install -e .`. If you want to use new features in OLMo, clone from the `main` branch instead. 85 | 2. Run `pip install git+https://github.com/Muennighoff/megablocks.git@olmoe` 86 | 3. Set up a config file. `configs/OLMoE-1B-7B-0924.yml` was used for the pretraining of `OLMoE-1B-7B-0924`. You can find configs from various ablations in `configs/ablations`. 87 | 4. Download the data from https://hf.co/datasets/allenai/OLMoE-mix-0924, tokenize it via the command below and adapt the `paths` in your training config to point to it (a tokenizer sanity check follows this list). 88 | ```bash 89 | dolma tokens \ 90 | --documents ${PATH_TO_DOWNLOADED_DATA} \ 91 | --destination ${PATH_WHERE_TO_SAVE_TOKENIZED_DATA} \ 92 | --tokenizer.name_or_path 'allenai/gpt-neox-olmo-dolma-v1_5' \ 93 | --max_size '2_147_483_648' \ 94 | --seed 0 \ 95 | --tokenizer.eos_token_id 50279 \ 96 | --tokenizer.pad_token_id 1 \ 97 | --processes ${NUMBER_OF_CPU_CORES_TO_USE} 98 | ``` 99 | 5. Submit your job. We used `bash scripts/olmoe-gantry.sh`, which invokes https://github.com/allenai/OLMo/blob/Muennighoff/MoE/scripts/train.py and uses [beaker gantry](https://github.com/allenai/beaker-gantry), but you will likely need to change the script to work with your setup. 100 | 6. To run annealing after the main pretraining, we use [this config](https://github.com/allenai/OLMoE/blob/main/configs/ablations/olmoe-8x1b-newhp-newds-final-anneal.yml); the only changes from the pretraining config are the `optimizer` and `scheduler` fields as well as `max_duration` and `stop_at`. 101 | 7. To convert your pretraining checkpoint to Hugging Face transformers after training, you can use the script & instructions [here](https://github.com/huggingface/transformers/blob/8f8af0fb38baa851f3fd69f564fbf91b5af78332/src/transformers/models/olmoe/convert_olmoe_weights_to_hf.py#L14). 102 |
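Before kicking off a long tokenization run, it can help to confirm that the tokenizer resolves to the special-token IDs passed to `dolma` above. A minimal sanity check (our addition; it only assumes the tokenizer is downloadable from the Hugging Face Hub):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("allenai/gpt-neox-olmo-dolma-v1_5")

# Expected to match the flags above: eos_token_id 50279, pad_token_id 1.
# (If pad prints None, it is simply unset in the HF config; the dolma
# command sets it explicitly either way.)
print("eos:", tok.eos_token_id, "pad:", tok.pad_token_id, "vocab:", len(tok))
```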
103 | #### Other design choices 104 | 105 | For most of our experiments on other design choices, you can simply set them in the config file (e.g. change the respective hyperparameter), except for: 106 | 1. **Sparse upcycling:** To sparse upcycle your model, train it dense first, e.g. using [this config](https://github.com/allenai/OLMoE/blob/main/configs/ablations/olmo-1b-newhp-newds-cx5.yml). Then convert any of its checkpoints into an MoE using [this script](https://github.com/allenai/OLMoE/blob/main/scripts/sparsify_ckpt_unsharded.py) and the instructions at its top, making sure to modify the hardcoded values (number of experts etc.) to match the model you want. Place the newly created checkpoint (`model_sparse.safetensors`) into a new folder whose name ends in `-unsharded`, renaming the file to `model.safetensors`. Then launch a job that loads this model, similar to [our sparse upcycling job](https://wandb.ai/ai2-llm/olmoe/runs/1w3srbb3/overview); note the settings `--load_path=path_to_upcycled_ckpt --reset_optimizer_state=True --reset_trainer_state=True`, plus `--fast_forward_batches=XXX` if you also want to continue on the same dataset in the same order. Also make sure to have the changes from this PR in your code: https://github.com/allenai/OLMo/pull/573. Finally, if you want to reproduce upcycling from OLMo-1B (0724) as in the paper, the OLMo-1B checkpoint already turned into an MoE with 8 experts is here: https://huggingface.co/allenai/OLMo-1B-0724-954000steps-unsharded; download the files inside of it (e.g. `wget https://huggingface.co/allenai/OLMo-1B-0724-954000steps-unsharded/resolve/main/model.safetensors`), then use a config similar to [this run's](https://wandb.ai/ai2-llm/olmoe/runs/1w3srbb3/overview) to train the upcycled MoE from it. A minimal sketch of the weight-duplication idea appears after this list. 107 | 2. **Expert choice:** To run experiments with our expert choice implementation, you need to instead use the OLMo branch `Muennighoff/OLMoSE` or simply copy over the small config changes that enable expert choice (i.e. [here](https://github.com/allenai/OLMo/blob/b7b312aa2d9ee0ec0816a042955f34e27f9f4628/olmo/config.py#L524)) to the `Muennighoff/MoE` branch. You can then run expert choice by activating it in your config. It will use this code: https://github.com/Muennighoff/megablocks/blob/4a25bc7b5665bcb9da93d72d5ad0c14d41e1a351/megablocks/layers/moe.py#L462 or https://github.com/Muennighoff/megablocks/blob/4a25bc7b5665bcb9da93d72d5ad0c14d41e1a351/megablocks/layers/moe.py#L477 depending on your selection; both should be roughly equivalent implementations of expert choice, and neither was better than dropless token choice in our experiments. 108 | 3. **Shared MoE layers (Appendix):** For these experiments you need to use the OLMo branch `Muennighoff/OLMoSE` and create your own config, e.g. like the one used in [this run](https://wandb.ai/ai2-llm/olmoe/reports/Plot-Shared-vs-Dense--Vmlldzo4NDI0MTc5). 109 |
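For intuition, here is a minimal, untested sketch of the upcycling idea from step 1 above: every expert starts as a copy of the dense MLP, and each MoE layer gets a freshly initialized router. The parameter names, layer count, and hidden size below are placeholders; the actual logic (and the real names) lives in `scripts/sparsify_ckpt_unsharded.py`.

```python
import torch
from safetensors.torch import load_file, save_file

NUM_EXPERTS = 8   # modify like the hardcoded values in the repo script
NUM_LAYERS = 16   # placeholder
D_MODEL = 2048    # placeholder

state = load_file("model.safetensors")  # dense, unsharded checkpoint
new_state = {}
for name, tensor in state.items():
    if "mlp" in name:  # "mlp" is a stand-in for the real FFN parameter names
        # Upcycling: duplicate the dense FFN weights into every expert.
        for e in range(NUM_EXPERTS):
            new_state[name.replace("mlp", f"experts.{e}.mlp")] = tensor.clone()
    else:
        new_state[name] = tensor

# Each MoE layer also needs a router over the experts, initialized from scratch.
for layer in range(NUM_LAYERS):
    new_state[f"blocks.{layer}.router.weight"] = 0.02 * torch.randn(NUM_EXPERTS, D_MODEL)

save_file(new_state, "model_sparse.safetensors")
```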
110 | ### Adaptation 111 | 112 | 1. Clone Open Instruct [here](https://github.com/allenai/open-instruct/) & follow its setup instructions. If you run into any problems, try upgrading your transformers version with `pip install --upgrade transformers` first. 113 | 2. SFT: After adapting as needed, run: 114 | ``` 115 | accelerate launch \ 116 | --mixed_precision bf16 \ 117 | --num_machines 1 \ 118 | --num_processes 8 \ 119 | --use_deepspeed \ 120 | --deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf \ 121 | open_instruct/finetune.py \ 122 | --model_name_or_path allenai/OLMoE-1B-7B-0924 \ 123 | --tokenizer_name allenai/OLMoE-1B-7B-0924 \ 124 | --use_flash_attn \ 125 | --max_seq_length 4096 \ 126 | --preprocessing_num_workers 128 \ 127 | --per_device_train_batch_size 2 \ 128 | --gradient_accumulation_steps 8 \ 129 | --learning_rate 2e-05 \ 130 | --lr_scheduler_type linear \ 131 | --warmup_ratio 0.03 \ 132 | --weight_decay 0.0 \ 133 | --num_train_epochs 2 \ 134 | --output_dir output/ \ 135 | --with_tracking \ 136 | --report_to wandb \ 137 | --logging_steps 1 \ 138 | --reduce_loss sum \ 139 | --model_revision main \ 140 | --dataset_mixer_list allenai/tulu-v3-mix-preview-4096-OLMoE 1.0 ai2-adapt-dev/daring-anteater-specialized 1.0 \ 141 | --checkpointing_steps epoch \ 142 | --add_bos 143 | ``` 144 | 3. DPO: Run 145 | ``` 146 | accelerate launch \ 147 | --mixed_precision bf16 \ 148 | --num_machines 1 \ 149 | --num_processes 8 \ 150 | --use_deepspeed \ 151 | --deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf \ 152 | open_instruct/dpo_tune.py \ 153 | --model_name_or_path allenai/OLMoE-1B-7B-0924-SFT \ 154 | --tokenizer_name allenai/OLMoE-1B-7B-0924-SFT \ 155 | --use_flash_attn \ 156 | --gradient_checkpointing \ 157 | --dataset_name argilla/ultrafeedback-binarized-preferences-cleaned \ 158 | --max_seq_length 4096 \ 159 | --preprocessing_num_workers 16 \ 160 | --per_device_train_batch_size 1 \ 161 | --gradient_accumulation_steps 4 \ 162 | --learning_rate 5e-7 \ 163 | --lr_scheduler_type linear \ 164 | --warmup_ratio 0.1 \ 165 | --weight_decay 0. \ 166 | --num_train_epochs 3 \ 167 | --output_dir output/ \ 168 | --report_to tensorboard \ 169 | --logging_steps 1 \ 170 | --reduce_loss sum \ 171 | --add_bos \ 172 | --checkpointing_steps epoch \ 173 | --dpo_beta 0.1 174 | ``` 175 | 4. KTO: Install `trl` and run https://github.com/Muennighoff/kto/blob/master/kto.py via `WANDB_PROJECT=olmoe accelerate launch --config_file=config_8gpusdsz2_m7.yml kto.py --model_name_or_path allenai/OLMoE-1B-7B-0924-SFT --output_dir OLMoE-1B-7B-0924-SFT-KTO-3EP --report_to "wandb" --per_device_train_batch_size 4 --gradient_accumulation_steps 1 --optim rmsprop --learning_rate 5e-07 --beta 0.1 --logging_steps 1 --bf16 --sanity_check False --num_train_epochs 3` (if you want to use the Adam optimizer instead, change to `--optim adamw_torch`). We used `trl==0.9.6`. 176 | 177 | ### Evaluation 178 | 179 | #### During pretraining 180 | 181 | Evaluation during pretraining is done automatically and configured in the config file. It uses the code here: https://github.com/allenai/OLMo/tree/Muennighoff/MoE/olmo/eval. 182 | 183 | #### After pretraining 184 | 185 | OLMES Evals: Follow the instructions at https://github.com/allenai/OLMo-Eval/blob/51c5ba579e75ef4ce7e9b29936eaa72c1a0e99eb/olmo_eval/tasks/olmes_v0_1/README.md 186 | 187 | DCLM Evals: Run `scripts/run_dclm_evals*` and refer to the instructions from https://github.com/mlfoundations/dclm 188 | 189 | #### After adaptation 190 | 191 | - Set up https://github.com/allenai/open-instruct/ 192 | - Run `sbatch scripts/adapteval.sh` after changing it as necessary, or extract the commands from the script and run them one by one. A quick qualitative smoke test is sketched below.
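Before running the full suite, it can be useful to sanity-check an adapted checkpoint interactively. A minimal sketch (our addition, not part of the eval pipeline; the prompt and generation settings are arbitrary) using the released Instruct model and its chat template:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924-Instruct").to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924-Instruct")

messages = [{"role": "user", "content": "What is a Mixture-of-Experts model?"}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(DEVICE)
out = model.generate(inputs, max_new_tokens=64)
# Decode only the newly generated tokens
print(tokenizer.decode(out[0][inputs.shape[1]:], skip_special_tokens=True))
```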
193 | 194 | ### Visuals 195 | 196 | - Figure 1, `visuals/figures/overview.pdf`: Run "Main plot" in `scripts/olmoe_visuals.ipynb` (equivalent to [this colab](https://colab.research.google.com/drive/15PTwmoxcbrwWKG6ErY44hlJlLLKAj7Hx?usp=sharing)), then add the result into this drawing to edit it further: https://docs.google.com/drawings/d/1Of9-IgvKH54zhKI_M4x5HOYEF4XUp6qaXluT3Zmv1vk/edit?usp=sharing (the drawing used for [this tweet](https://x.com/Muennighoff/status/1831159130230587486) is [here](https://docs.google.com/drawings/d/133skqIfE8f7iOO9hMidV5tBZ7MAle28VxEm9U8Q8ioU/edit?usp=sharing)) 197 | - Figure 2, `visuals/figures/olmoe.pdf`: https://www.figma.com/design/Es8UpNHKgugMAncPWnSDuK/olmoe?node-id=0-1&t=SeuQKPlaoB12TXqe-1 (also contains some other figures used on Twitter) 198 | - Figures 3 & 25, `visuals/figures/trainingeval*pdf`: Run "During training" in `scripts/olmoe_visuals.ipynb` (equivalent to [this colab](https://colab.research.google.com/drive/15PTwmoxcbrwWKG6ErY44hlJlLLKAj7Hx?usp=sharing)) 199 | - Figures 4-19, 24, 26-29, `visuals/figures/...pdf`: Run the respective parts in `scripts/olmoe_visuals.ipynb` (equivalent to [this colab](https://colab.research.google.com/drive/15PTwmoxcbrwWKG6ErY44hlJlLLKAj7Hx?usp=sharing)) 200 | - Figures 20, 21, 23, 30, 31 & Table 8, `visuals/figures/...pdf`: `scripts/run_moe_analysis.py` (if you do not want to rerun inference on the model to generate the routing statistics, you can download them from https://huggingface.co/datasets/allenai/analysis_olmoe & https://huggingface.co/datasets/allenai/analysis_mixtral) 201 | - Figures 22, 33-36, `visuals/figures/...pdf`: Run `scripts/run_routing_analysis.py` & then `scripts/plot_routing_analysis_v2.ipynb` / `scripts/plot_routing_analysis_v2_top1.ipynb` / `scripts/plot_routing_analysis_v2_cross_layer.ipynb` 202 | - Figure 32, `visuals/figures/...pdf`: Run `scripts/run_routing_analysis.py` & then `scripts/plot_routing_analysis.ipynb` 203 | - Table 13: `scripts/make_table.py` 204 | - All other tables are manually created. 205 | 206 | ### Citation 207 | 208 | ```bibtex 209 | @misc{muennighoff2024olmoeopenmixtureofexpertslanguage, 210 | title={OLMoE: Open Mixture-of-Experts Language Models}, 211 | author={Niklas Muennighoff and Luca Soldaini and Dirk Groeneveld and Kyle Lo and Jacob Morrison and Sewon Min and Weijia Shi and Pete Walsh and Oyvind Tafjord and Nathan Lambert and Yuling Gu and Shane Arora and Akshita Bhagia and Dustin Schwenk and David Wadden and Alexander Wettig and Binyuan Hui and Tim Dettmers and Douwe Kiela and Ali Farhadi and Noah A. Smith and Pang Wei Koh and Amanpreet Singh and Hannaneh Hajishirzi}, 212 | year={2024}, 213 | eprint={2409.02060}, 214 | archivePrefix={arXiv}, 215 | primaryClass={cs.CL}, 216 | url={https://arxiv.org/abs/2409.02060}, 217 | } 218 | ``` 219 | -------------------------------------------------------------------------------- /scripts/adapteval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=eval 3 | #SBATCH --nodes=1 4 | #SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node!
5 | #SBATCH --partition=a3 6 | #SBATCH --gres=gpu:8 # number of gpus 7 | #SBATCH --time 24:00:00 # maximum execution time (HH:MM:SS) 8 | #SBATCH --output=jobs/%x-%j.out # output file name 9 | #SBATCH --exclusive 10 | #SBATCH --array=0-6%7 # adjusted array size and concurrency 11 | 12 | MODEL_PATH=allenai/OLMoE-1B-7B-0924-Instruct 13 | TOKENIZER_PATH=$MODEL_PATH 14 | 15 | cd ~/open-instruct 16 | conda activate YOURENV 17 | 18 | # Commands array 19 | case $SLURM_ARRAY_TASK_ID in 20 | 0) 21 | python -m eval.mmlu.run_eval \ 22 | --ntrain 0 \ 23 | --data_dir /data/niklas/data/eval/mmlu/ \ 24 | --save_dir ${MODEL_PATH}/eval/mmlu \ 25 | --model_name_or_path ${MODEL_PATH} \ 26 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 27 | --use_chat_format \ 28 | --chat_formatting_function eval.templates.create_prompt_with_huggingface_tokenizer_template \ 29 | --eval_batch_size 64 30 | ;; 31 | 1) 32 | python -m eval.gsm.run_eval \ 33 | --data_dir /data/niklas/data/eval/gsm/ \ 34 | --max_num_examples 200 \ 35 | --save_dir ${MODEL_PATH}/eval/gsm8k \ 36 | --model_name_or_path ${MODEL_PATH} \ 37 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 38 | --n_shot 8 \ 39 | --use_chat_format \ 40 | --chat_formatting_function eval.templates.create_prompt_with_huggingface_tokenizer_template \ 41 | --eval_batch_size 64 42 | ;; 43 | 2) 44 | OPENAI_API_KEY=YOUR_KEY IS_ALPACA_EVAL_2=False python -m eval.alpaca_farm.run_eval \ 45 | --model_name_or_path ${MODEL_PATH} \ 46 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 47 | --save_dir ${MODEL_PATH}/eval/alpaca \ 48 | --use_chat_format \ 49 | --chat_formatting_function eval.templates.create_prompt_with_huggingface_tokenizer_template \ 50 | --eval_batch_size 128 51 | ;; 52 | 3) 53 | python -m eval.bbh.run_eval \ 54 | --data_dir /data/niklas/data/eval/bbh/ \ 55 | --save_dir ${MODEL_PATH}/eval/bbh \ 56 | --model_name_or_path ${MODEL_PATH} \ 57 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 58 | --max_num_examples_per_task 40 \ 59 | --use_chat_format \ 60 | --chat_formatting_function eval.templates.create_prompt_with_huggingface_tokenizer_template \ 61 | --eval_batch_size 64 62 | ;; 63 | 4) 64 | OPENAI_API_KEY=YOUR_KEY python -m eval.xstest.run_eval \ 65 | --data_dir /data/niklas/data/eval/xstest/ \ 66 | --save_dir ${MODEL_PATH}/eval/xstest \ 67 | --model_name_or_path ${MODEL_PATH} \ 68 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 69 | --use_chat_format \ 70 | --chat_formatting_function eval.templates.create_prompt_with_huggingface_tokenizer_template \ 71 | --eval_batch_size 64 72 | ;; 73 | 5) 74 | python -m eval.codex_humaneval.run_eval \ 75 | --data_file /data/niklas/data/eval/codex_humaneval/HumanEval.jsonl.gz \ 76 | --eval_pass_at_ks 1 5 10 20 \ 77 | --unbiased_sampling_size_n 20 \ 78 | --temperature 0.8 \ 79 | --save_dir ${MODEL_PATH}/eval/humaneval \ 80 | --model_name_or_path ${MODEL_PATH} \ 81 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 82 | --eval_batch_size 64 83 | ;; 84 | 6) 85 | python -m eval.ifeval.run_eval \ 86 | --data_dir /data/niklas/data/eval/ifeval/ \ 87 | --save_dir ${MODEL_PATH}/eval/ifeval \ 88 | --model_name_or_path ${MODEL_PATH} \ 89 | --tokenizer_name_or_path ${TOKENIZER_PATH} \ 90 | --use_chat_format \ 91 | --chat_formatting_function eval.templates.create_prompt_with_huggingface_tokenizer_template \ 92 | --eval_batch_size 64 93 | ;; 94 | *) 95 | echo "Invalid SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID" 96 | exit 1 97 | ;; 98 | esac 99 | -------------------------------------------------------------------------------- /scripts/batchjob.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=llm 3 | #SBATCH --nodes=1 4 | #SBATCH --ntasks-per-node=1 5 | #SBATCH --hint=nomultithread # we get physical cores not logical 6 | #SBATCH --partition=a3mixed 7 | #SBATCH --gres=gpu:1 # number of gpus 8 | #SBATCH --output=/data/niklas/jobs/%x-%j.out # output file name 9 | #SBATCH --exclusive 10 | 11 | ###################### 12 | ### Set environment ## 13 | ###################### 14 | cd /home/niklas/OLMoE 15 | source /env/bin/start-ctx-user 16 | conda activate llmoe 17 | export WANDB_PROJECT="olmoe" 18 | # Training setup 19 | set -euo pipefail 20 | export GPUS_PER_NODE=1 21 | export NODENAME=$(hostname -s) 22 | export MASTER_ADDR=$(scontrol show hostnames | head -n 1) 23 | export MASTER_PORT=39594 24 | export WORLD_SIZE=$SLURM_NTASKS 25 | export RANK=$SLURM_PROCID 26 | export FS_LOCAL_RANK=$SLURM_PROCID 27 | export LOCAL_WORLD_SIZE=$SLURM_NTASKS_PER_NODE 28 | export LOCAL_RANK=$SLURM_LOCALID 29 | export NODE_RANK=$((($RANK - $LOCAL_RANK) / $LOCAL_WORLD_SIZE)) 30 | export R2_PROFILE=r2 31 | export R2_ENDPOINT_URL=YOUR_URL 32 | export AWS_ACCESS_KEY_ID=YOUR_KEY 33 | export AWS_SECRET_ACCESS_KEY=YOUR_KEY 34 | export HF_DATASETS_OFFLINE=1 35 | #export CUDA_LAUNCH_BLOCKING=1 36 | export TRITON_CACHE_DIR=/data/niklas/tcache 37 | 38 | echo "World size: $WORLD_SIZE" 39 | ###################### 40 | 41 | ###################### 42 | #### Set network ##### 43 | ###################### 44 | head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 45 | ###################### 46 | 47 | python -u scripts/train.py configs/olmoe17/olmoe17-8x1b-fullshard-swiglu-wrapb-ecg-k2.yml \ 48 | "--save_folder=/data/niklas/llm/checkpoints/${SLURM_JOB_ID}/" \ 49 | --eval_interval=1000 \ 50 | --save_interval=5000 \ 51 | --global_train_batch_size=16 \ 52 | --device_train_microbatch_size=2 \ 53 | --save_overwrite=True \ 54 | --evaluators=[] 55 | 56 | # srun \ 57 | # --distribution=block:block \ 58 | # --kill-on-bad-exit \ 59 | # scripts/run_with_environment.sh \ 60 | # python -u scripts/train.py configs/olmoe17/olmoe-8x1b-newhp-newds-s3.yml \ 61 | # "--save_folder=/data/niklas/llm/checkpoints/${SLURM_JOB_ID}/" \ 62 | # --save_overwrite \ 63 | # --fsdp.sharding_strategy=HYBRID_SHARD \ 64 | # --device_train_microbatch_size=2 65 | -------------------------------------------------------------------------------- /scripts/eval_openlm_ckpt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import builtins as __builtin__ 3 | import json 4 | import os 5 | import shutil 6 | import subprocess 7 | import sys 8 | import time 9 | import uuid 10 | from collections import defaultdict 11 | from datetime import datetime 12 | from pathlib import Path 13 | from typing import List 14 | 15 | 16 | sys.path.insert(0, str(Path(__file__).parent.parent)) 17 | 18 | import numpy as np 19 | import pandas as pd 20 | import pytz 21 | import random 22 | import torch 23 | from aggregated_metrics import get_aggregated_results 24 | from utils import update_args_from_openlm_config 25 | from composer.loggers import InMemoryLogger, LoggerDestination 26 | from composer.trainer import Trainer 27 | from composer.utils import dist, get_device, reproducibility 28 | from llmfoundry.utils.builders import build_icl_evaluators, build_logger 29 | from omegaconf import OmegaConf as om 30 | 31 | from open_lm.hf import * 32 | 33 | from open_lm.attention import ATTN_ACTIVATIONS,
ATTN_SEQ_SCALARS 34 | from open_lm.data import get_data 35 | from open_lm.distributed import init_distributed_device, is_master, world_info_from_env 36 | from open_lm.model import create_params 37 | from open_lm.main import load_model 38 | from open_lm.evaluate import evaluate_loop 39 | from open_lm.file_utils import pt_load 40 | from open_lm.utils.llm_foundry_wrapper import SimpleComposerOpenLMCausalLM 41 | from open_lm.utils.transformers.hf_config import OpenLMConfig 42 | from open_lm.utils.transformers.hf_model import OpenLMforCausalLM 43 | from pytz import timezone 44 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GPTNeoXTokenizerFast, LlamaTokenizerFast 45 | 46 | from training.file_utils import download_val_data, load_ppl_yaml 47 | 48 | builtin_print = __builtin__.print 49 | 50 | hf_access_token = open("hf_access_token.txt").read().strip() 51 | 52 | def setup_for_distributed(is_master): 53 | def print(*args, **kwargs): 54 | force = kwargs.pop("force", False) 55 | if is_master or force: 56 | builtin_print(*args, **kwargs) 57 | 58 | __builtin__.print = print 59 | 60 | 61 | def convert_gpqa(gpqa_dir, outdir): 62 | os.makedirs(outdir, exist_ok=True) 63 | for filename in ["gpqa_main.csv", "gpqa_diamond.csv", "gpqa_extended.csv"]: 64 | inpath = os.path.join(gpqa_dir, filename) 65 | outpath = os.path.join(outdir, Path(inpath).with_suffix(".jsonl").name) 66 | with open(outpath, "w") as f: 67 | df = pd.read_csv(inpath) 68 | rng = random.Random(42) 69 | for i in range(df.shape[0]): 70 | question = df["Question"][i] 71 | choices = np.array( 72 | [ 73 | df["Correct Answer"][i], 74 | df["Incorrect Answer 1"][i], 75 | df["Incorrect Answer 2"][i], 76 | df["Incorrect Answer 3"][i], 77 | ] 78 | ) 79 | idx = list(range(4)) 80 | rng.shuffle(idx) 81 | choices = choices[idx] 82 | gold = idx.index(0) 83 | data = {"query": question, "choices": choices.tolist(), "gold": gold} 84 | f.write(json.dumps(data)) 85 | f.write("\n") 86 | 87 | return 88 | 89 | 90 | def check_and_download_data(): 91 | if not os.path.exists("local_data"): 92 | current_dir = os.path.dirname(os.path.realpath(__file__)) 93 | 94 | if os.path.exists(f"{current_dir}/local_data"): 95 | shutil.copytree(f"{current_dir}/local_data", "local_data") 96 | else: 97 | if dist.get_global_rank() == 0: 98 | print("local_data folder does not exist. Running bash script...") 99 | script_path = os.path.join(current_dir, "download_eval_data.sh") 100 | 101 | subprocess.call([script_path]) 102 | 103 | else: 104 | # Let other workers sleep a bit before barrier. 
105 | time.sleep(10) 106 | dist.barrier() 107 | 108 | if not os.path.exists("gpqa_data"): 109 | repo_dir = os.path.dirname(os.path.realpath(__file__)) 110 | if dist.get_global_rank() == 0: 111 | subprocess.run( 112 | [ 113 | "unzip", 114 | "-P", 115 | "deserted-untie-orchid", 116 | os.path.join(repo_dir, "gpqa/dataset.zip"), 117 | "-d", 118 | "gpqa_data_orig/", 119 | ], 120 | check=True, 121 | stdout=subprocess.PIPE, 122 | stderr=subprocess.PIPE, 123 | ) 124 | convert_gpqa("gpqa_data_orig/dataset", "gpqa_data") 125 | shutil.rmtree("gpqa_data_orig") 126 | else: 127 | time.sleep(10) 128 | dist.barrier() 129 | 130 | print("Done downloading data.") 131 | return 132 | 133 | 134 | @torch.no_grad() 135 | def evaluate(model, tokenizer, cfg): 136 | cfg.dist_timeout = cfg.get("dist_timeout", 600.0) 137 | 138 | reproducibility.seed_all(cfg.seed) 139 | dist.initialize_dist(get_device(None), timeout=cfg.dist_timeout) 140 | setup_for_distributed(dist.get_global_rank() == 0) 141 | 142 | # Check if the data is downloaded, if not, download it. 143 | check_and_download_data() 144 | 145 | composer_model = SimpleComposerOpenLMCausalLM(model, tokenizer) 146 | 147 | icl_tasks_w_categories = list( 148 | filter(lambda x: 0 if "has_categories" not in x else x["has_categories"], cfg.icl_tasks) 149 | ) 150 | icl_tasks_w_categories = list(map(lambda x: x["label"], icl_tasks_w_categories)) 151 | 152 | evaluators, logger_keys = build_icl_evaluators( 153 | cfg.icl_tasks, tokenizer, cfg.max_seq_len, cfg.device_eval_batch_size, 154 | # icl_subset_num_batches=cfg.eval_subset_num_batches 155 | ) 156 | 157 | in_memory_logger = InMemoryLogger() # track metrics in the in_memory_logger 158 | loggers: List[LoggerDestination] = [ 159 | build_logger(name, logger_cfg) for name, logger_cfg in (cfg.get("loggers") or {}).items() 160 | ] 161 | loggers.append(in_memory_logger) 162 | 163 | fsdp_config = None 164 | fsdp_config = om.to_container(fsdp_config, resolve=True) if fsdp_config is not None else None 165 | 166 | load_path = cfg.get("load_path", None) 167 | 168 | trainer = Trainer( 169 | model=composer_model, 170 | loggers=loggers, 171 | precision=cfg.precision, 172 | fsdp_config=fsdp_config, # type: ignore 173 | load_path=load_path, 174 | load_weights_only=True, 175 | progress_bar=False, 176 | log_to_console=True, 177 | dist_timeout=cfg.dist_timeout, 178 | # eval_subset_num_batches=cfg.eval_subset_num_batches 179 | ) 180 | 181 | if torch.cuda.is_available(): 182 | torch.cuda.synchronize() 183 | a = time.time() 184 | trainer.eval( 185 | eval_dataloader=evaluators, 186 | # subset_num_batches=cfg.eval_subset_num_batches 187 | ) 188 | 189 | if torch.cuda.is_available(): 190 | torch.cuda.synchronize() 191 | b = time.time() 192 | 193 | print(f"Ran eval in: {b-a} seconds") 194 | 195 | performance_on_tasks = defaultdict(list) 196 | for key in logger_keys: 197 | if key in in_memory_logger.data: 198 | result = in_memory_logger.data[key][0][1].item() 199 | flag = True 200 | if len(icl_tasks_w_categories) > 0: 201 | for task in icl_tasks_w_categories: 202 | if task in key: 203 | performance_on_tasks[task].append(result) 204 | flag = False 205 | if flag: 206 | performance_on_tasks[key].append(result) 207 | 208 | report_results = {} 209 | for task in performance_on_tasks: 210 | result = sum(performance_on_tasks[task]) / len(performance_on_tasks[task]) 211 | if len(task.split("/")) > 1: 212 | label = task.split("/")[1] 213 | report_results[label] = result 214 | else: 215 | report_results[task] = result 216 | print(report_results) 217 | return 
report_results 218 | 219 | 220 | def set_args_for_val(args, data, key): 221 | setattr(args, "val_data", data) 222 | setattr(args, "val_data_key", key) 223 | setattr(args, "squash_mask_left", True) 224 | setattr(args, "target_mask_individual", 50400) 225 | setattr(args, "target_mask_left", 50300) 226 | setattr(args, "val_seq_ci", True) 227 | setattr(args, "val_tok_ci", True) 228 | return args 229 | 230 | 231 | def main(): 232 | """ 233 | Usage: 234 | python eval_openlm_ckpt.py --checkpoint <checkpoint> --model <model_config> --eval-yaml <eval_yaml> --tokenizer <tokenizer> 235 | example: 236 | cd eval 237 | python eval_openlm_ckpt.py --checkpoint ../checkpoints/llama2_7b.pt --model llama2_7b.json --eval-yaml in_memory_hf_eval.yaml --tokenizer <tokenizer> 238 | multi-gpu example: 239 | cd eval 240 | torchrun --nproc_per_node 3 eval_openlm_ckpt.py --checkpoint ../checkpoints/llama2_7b.pt --model llama2_7b.json --eval-yaml in_memory_hf_eval.yaml --tokenizer <tokenizer> 241 | 242 | torchrun --nproc_per_node 3 eval_openlm_ckpt.py --checkpoint checkpoint.pt --config params.txt 243 | """ 244 | parser = argparse.ArgumentParser() 245 | # Arguments that openlm requires when we call load_model 246 | parser.add_argument( 247 | "--seed", 248 | type=int, 249 | default=None, 250 | help="Seed for reproducibility, when None, will use the seed from the eval config file.", 251 | ) 252 | parser.add_argument("--fsdp", default=False, action="store_true") 253 | parser.add_argument("--distributed", default=True, action="store_true") 254 | parser.add_argument("--resume", default=None, type=str) 255 | 256 | # Argument for uploading results 257 | parser.add_argument("--remote-sync", type=str, default=None) 258 | parser.add_argument("--remote-sync-protocol", type=str, default="s3", choices=["s3", "fsspec"]) 259 | 260 | parser.add_argument("--checkpoint", default=None, type=str, help="Path to checkpoint to evaluate.") 261 | parser.add_argument("--eval-yaml", type=str, default="light.yaml") 262 | parser.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b") 263 | parser.add_argument("--config", default=None, type=str) 264 | parser.add_argument("--model", default=None, type=str) 265 | parser.add_argument("--hf-model", default=None) 266 | 267 | parser.add_argument("--hf-cache-dir", default=None) 268 | parser.add_argument("--output-file", type=str, default=None) 269 | parser.add_argument( 270 | "--use-temp-working-dir", 271 | action="store_true", 272 | help="Use a temporary working directory for the evaluation, removing it when done. " 273 | "This is required if you wish to run multiple evaluations with the same datasets" 274 | " in parallel on the same node.", 275 | ) 276 | parser.add_argument( 277 | "--eval_meta_data", default=f"{os.path.dirname(__file__)}/eval_meta_data.csv", help="Eval meta data file" 278 | ) 279 | parser.add_argument( 280 | "--preset-world-size", 281 | type=int, 282 | default=None, 283 | help="Explicitly set the world size.
Useful in cases where a different number of gpus per node needs to be used.", 284 | ) 285 | parser.add_argument( 286 | "--additional_aggregation", 287 | default=f"{os.path.dirname(__file__)}/additional_aggregation.json", 288 | help="Eval aggregation file", 289 | ) 290 | parser.add_argument( 291 | "--val-data", 292 | type=str, 293 | default=None, 294 | help="val data for perplexity calc", 295 | ) 296 | parser.add_argument( 297 | "--moe-freq", 298 | type=int, 299 | default=0, 300 | help="if set > 0, we will add MoE layer to every moe_freq layer.", 301 | ) 302 | parser.add_argument( 303 | "--moe-num-experts", 304 | type=int, 305 | default=None, 306 | help="Number of experts for MoE", 307 | ) 308 | 309 | parser.add_argument( 310 | "--moe-weight-parallelism", 311 | action="store_true", 312 | help="Add weight parallelism to MoE", 313 | ) 314 | 315 | parser.add_argument( 316 | "--moe-expert-model-parallelism", 317 | action="store_true", 318 | help="Add expert model parallelism to MoE", 319 | ) 320 | 321 | parser.add_argument( 322 | "--moe-capacity-factor", 323 | type=float, 324 | default=1.25, 325 | help="MoE capacity factor", 326 | ) 327 | 328 | parser.add_argument( 329 | "--moe-loss-weight", 330 | type=float, 331 | default=0.1, 332 | help="MoE loss weight", 333 | ) 334 | parser.add_argument( 335 | "--moe-top-k", 336 | type=int, 337 | default=2, 338 | help="MoE top k experts", 339 | ) 340 | parser.add_argument( 341 | "--attn-name", 342 | type=str, 343 | default="xformers_attn", 344 | choices=["xformers_attn", "torch_attn", "custom_attn"], 345 | help="type of attention to use", 346 | ) 347 | parser.add_argument( 348 | "--attn-activation", 349 | type=str, 350 | default=None, 351 | choices=list(ATTN_ACTIVATIONS.keys()), 352 | help="activation to use with custom_attn", 353 | ) 354 | parser.add_argument( 355 | "--attn-seq-scalar", 356 | type=str, 357 | default=None, 358 | choices=list(ATTN_SEQ_SCALARS.keys()), 359 | help="different ways to set L, where L^alpha divides attention logits post activation", 360 | ) 361 | parser.add_argument( 362 | "--attn-seq-scalar-alpha", 363 | type=float, 364 | default=None, 365 | help="power alpha to raise L to, where L^alpha divides attention logits post activation", 366 | ) 367 | parser.add_argument( 368 | "--val-max-pop-ci", 369 | default=None, 370 | action="store", 371 | type=int, 372 | help="when running CIs what is the maximum population size for the inner loop", 373 | ) 374 | parser.add_argument( 375 | "--val-iter-ci", 376 | default=10_000, 377 | action="store", 378 | type=int, 379 | help="how many times to sample to construct the CI for the outer loop", 380 | ) 381 | parser.add_argument("--averager-name", help="If specified, load this averager from checkpoint.") 382 | 383 | parser.add_argument("--donot-compute-perplexity", action="store_true") 384 | parser.add_argument("--compute-downstream-perplexity", action="store_true") 385 | parser.add_argument("--compute-paloma-perplexity", action="store_true") 386 | parser.add_argument("--force-xformers", action="store_true") 387 | 388 | parser.add_argument("--num_experts_per_tok", type=int, default=None) 389 | 390 | args = parser.parse_args() 391 | 392 | orig_seed = args.seed # may be overridden by config file if it exists 393 | 394 | if os.path.exists(args.output_file): 395 | print(f"{args.output_file} exists!") 396 | return 397 | 398 | if args.config is not None: 399 | assert args.hf_model is None, ( 400 | "If you are using a config file, " 401 | "you are trying to evaluate an open_lm model.
Please remove hf-model argument." 402 | ) 403 | 404 | update_args_from_openlm_config(args) 405 | # disable wandb for eval 406 | args.wandb = None 407 | else: 408 | # Most probably evaluating a hf-model. 409 | 410 | assert args.hf_model, ( 411 | "If you are not using a config file, you might want to evaluate a Hugging Face model, " 412 | "so please provide hf-model argument." 413 | ) 414 | # Computing perplexity for HF model doesn't make sense. 415 | args.donot_compute_perplexity = True 416 | 417 | # Setting those params as they are needed for distributed evals 418 | # and they are supposed to come from the config file. 419 | args.dist_backend = "nccl" 420 | args.dist_url = "env://" 421 | args.no_set_device_rank = False 422 | args.model = args.hf_model 423 | args.force_distributed = False 424 | with open(args.eval_yaml) as f: 425 | eval_cfg = om.load(f) 426 | if orig_seed is not None: 427 | print(f"Overriding eval config seed ({eval_cfg.seed}) to {orig_seed}") 428 | eval_cfg.seed = orig_seed 429 | 430 | # now need to set the 'fewshot_random_seed' in each config in the icl task configs 431 | for icl_cfg in eval_cfg.icl_tasks: 432 | icl_cfg.fewshot_random_seed = orig_seed 433 | 434 | args.resume = args.checkpoint 435 | args.remote_sync = args.output_file 436 | directory = os.path.dirname(args.output_file) 437 | if directory != "" and not os.path.exists(directory): 438 | os.makedirs(directory) 439 | 440 | CWD = os.getcwd() 441 | if args.use_temp_working_dir: 442 | temp_dir = os.path.join(CWD, "eval_openlm_ckpt_temp_dirs", f"{uuid.uuid4()}") 443 | os.makedirs(temp_dir, exist_ok=True) # in case rank > 0 444 | os.chdir(temp_dir) 445 | print(f"Using temporary working directory: {temp_dir}") 446 | 447 | print("Loading model into the right classes") 448 | if args.hf_model is not None: 449 | 450 | if args.num_experts_per_tok: 451 | config = AutoConfig.from_pretrained( 452 | args.hf_model, token=hf_access_token, trust_remote_code=True, cache_dir=args.hf_cache_dir) 453 | config.num_experts_per_tok = args.num_experts_per_tok 454 | eval_model = AutoModelForCausalLM.from_pretrained( 455 | args.hf_model, config=config, token=hf_access_token, trust_remote_code=True, cache_dir=args.hf_cache_dir 456 | ) 457 | assert eval_model.num_experts_per_tok == args.num_experts_per_tok 458 | 459 | else: 460 | eval_model = AutoModelForCausalLM.from_pretrained( 461 | args.hf_model, token=hf_access_token, trust_remote_code=True, cache_dir=args.hf_cache_dir 462 | ) 463 | else: 464 | params = create_params(args) 465 | eval_model = OpenLMforCausalLM(OpenLMConfig(params)) 466 | 467 | if "gpt-neox-20b" in args.tokenizer: 468 | tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b") 469 | elif "llama" in args.tokenizer: 470 | tokenizer = LlamaTokenizerFast.from_pretrained(args.tokenizer) 471 | if len(tokenizer) > eval_model.config.vocab_size: # happens in llama-3-8b 472 | print(f"Resizing vocab from {eval_model.config.vocab_size} to {len(tokenizer)}") 473 | eval_model.resize_token_embeddings(len(tokenizer)) 474 | else: 475 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, token=hf_access_token, trust_remote_code=True, cache_dir=args.hf_cache_dir) 476 | 477 | if args.checkpoint is not None: 478 | if not args.averager_name: 479 | print(f"Loading checkpoint {args.checkpoint}") 480 | args.distributed = False 481 | load_model(args, eval_model.model, different_seed=True) 482 | args.distributed = True 483 | else: 484 | print(f"Loading checkpoint {args.checkpoint}") 485 | checkpoint = pt_load(args.resume,
map_location="cpu") 486 | if "epoch" in checkpoint: 487 | # resuming a train checkpoint w/ epoch and optimizer state 488 | start_epoch = checkpoint["epoch"] 489 | avg_sd = torch.load(args.checkpoint, map_location="cpu") 490 | if next(iter(avg_sd.items()))[0].startswith("module"): 491 | avg_sd = {k[len("module.") :]: v for k, v in avg_sd.items()} 492 | eval_model.model.load_state_dict(avg_sd) 493 | 494 | # HF model loaded with from_pretrained is by default in eval mode. 495 | # https://github.com/huggingface/transformers/blob/ebfdb9ca62205279d5019ef1403877461b3b2da4/src/transformers/modeling_utils.py#L2500 496 | eval_model.model.eval() 497 | 498 | # Set requires grad = False to reduce memory consumption - o/w composer makes a copy of the model. 499 | for p in eval_model.parameters(): 500 | p.requires_grad = False 501 | 502 | device = init_distributed_device(args) 503 | eval_model = eval_model.to(device) 504 | eval_metrics = {} 505 | 506 | local_rank, _, _ = world_info_from_env() 507 | 508 | if not args.donot_compute_perplexity: 509 | args.per_gpu_val_batch_size = args.per_gpu_batch_size // args.accum_freq 510 | openlm_val_data = download_val_data("open_lm_val", skip_download=local_rank != 0) 511 | args = set_args_for_val(args, [openlm_val_data], ["json"]) 512 | data = get_data(args, epoch=0, tokenizer=None, skip_train=True) 513 | results = evaluate_loop(eval_model.model, data["val_list"], 0, args, None) 514 | perplexity_val = results[0]["loss"] 515 | eval_metrics["perplexity"] = perplexity_val 516 | 517 | if args.compute_paloma_perplexity: 518 | args.per_gpu_val_batch_size = args.per_gpu_batch_size // args.accum_freq 519 | paloma_val_data = download_val_data("paloma_val", skip_download=local_rank != 0) 520 | args = set_args_for_val(args, [paloma_val_data], ["json.gz"]) 521 | data = get_data(args, epoch=0, tokenizer=None, skip_train=True) 522 | results = evaluate_loop(eval_model.model, data["val_list"], 0, args, None) 523 | perplexity_val = results[0]["loss"] 524 | eval_metrics["paloma_perplexity"] = perplexity_val 525 | 526 | if args.compute_downstream_perplexity: 527 | args.per_gpu_val_batch_size = args.per_gpu_batch_size // args.accum_freq 528 | size = args.eval_yaml[:-5] 529 | tasks = load_ppl_yaml(size) 530 | downstream_datas = [ 531 | download_val_data(task_name, skip_download=local_rank != 0) 532 | for task_name in tasks 533 | if "gpqa" not in task_name 534 | ] 535 | args = set_args_for_val(args, downstream_datas, ["txt"] * len(downstream_datas)) 536 | data = get_data(args, epoch=0, tokenizer=None, skip_train=True) 537 | 538 | results = evaluate_loop(eval_model.model, data["val_list"], 0, args, None) 539 | eval_metrics["downstream_perpexity"] = {} 540 | for result in results: 541 | data_name = result["val_data"][0].split("/")[-2] 542 | eval_metrics["downstream_perpexity"][data_name] = result["loss"] 543 | 544 | icl_results = evaluate(eval_model, tokenizer, eval_cfg) 545 | eval_metrics["icl"] = icl_results 546 | 547 | date_format = "%Y_%m_%d-%H_%M_%S" 548 | date = datetime.now(tz=pytz.utc) 549 | date = date.astimezone(timezone("US/Pacific")) 550 | date = date.strftime(date_format) 551 | 552 | output = { 553 | "name": str(args.eval_yaml)[:-5], 554 | "uuid": str(uuid.uuid4()), 555 | "model": args.model, 556 | "creation_date": date, 557 | "eval_metrics": eval_metrics, 558 | } 559 | 560 | with open(args.additional_aggregation, "r") as f: 561 | aggregation_json = json.load(f) 562 | 563 | eval_metadata = pd.read_csv(args.eval_meta_data) 564 | 565 | output = get_aggregated_results(output, 
eval_metadata, aggregation_json) 566 | print("Eval output: ") 567 | print(json.dumps(output, indent=4, sort_keys=True)) 568 | if local_rank == 0: 569 | if args.use_temp_working_dir: 570 | print(f"Removing temporary working directory: {temp_dir} and changing back to {CWD}") 571 | shutil.rmtree(temp_dir) 572 | os.chdir(CWD) # need to change back BEFORE we save the output file 573 | 574 | with open(args.output_file, "w") as f: 575 | json.dump(output, f, indent=4) 576 | 577 | return output 578 | 579 | 580 | if __name__ == "__main__": 581 | main() 582 | -------------------------------------------------------------------------------- /scripts/humaneval.yaml: -------------------------------------------------------------------------------- 1 | epoch: 1.25T 2 | dataset: bigdata 3 | num_params: 1B 4 | max_seq_len: 2048 5 | seed: 1 6 | precision: fp32 7 | 8 | # Tokenizer 9 | tokenizer: 10 | # name: [Add name from memory] 11 | pretrained_model_name_or_path: 12 | kwargs: 13 | model_max_length: 2048 14 | 15 | model: 16 | name: open_lm 17 | # pretrained_model_name_or_path: [add name from memory] 18 | init_device: cpu 19 | pretrained: true 20 | 21 | load_path: # Add your (optional) Composer checkpoint path here! 22 | 23 | device_eval_batch_size: 2 24 | 25 | # FSDP config for model sharding 26 | fsdp_config: 27 | sharding_strategy: FULL_SHARD 28 | mixed_precision: FULL 29 | 30 | 31 | icl_tasks: 32 | - 33 | label: human_eval 34 | dataset_uri: local_data/programming/human_eval.jsonl # ADD YOUR OWN DATASET URI 35 | num_fewshot: [0] 36 | pass_at_k: 1 37 | batch_size: 1 38 | icl_task_type: code_evaluation 39 | generation_kwargs: 40 | num_beams: 5 41 | - 42 | label: human_eval_cpp 43 | dataset_uri: local_data/programming/processed_human_eval_cpp.jsonl # ADD YOUR OWN DATASET URI 44 | num_fewshot: [0] 45 | pass_at_k: 1 46 | batch_size: 1 47 | icl_task_type: code_evaluation 48 | generation_kwargs: 49 | num_beams: 5 50 | - 51 | label: human_eval_js 52 | dataset_uri: local_data/programming/processed_human_eval_js.jsonl # ADD YOUR OWN DATASET URI 53 | num_fewshot: [0] 54 | pass_at_k: 1 55 | batch_size: 1 56 | icl_task_type: code_evaluation 57 | generation_kwargs: 58 | num_beams: 5 59 | - 60 | label: human_eval_return_simple 61 | dataset_uri: local_data/programming/human_eval_return_simple.jsonl # ADD YOUR OWN DATASET URI 62 | num_fewshot: [0] 63 | pass_at_k: 1 64 | batch_size: 1 65 | icl_task_type: code_evaluation 66 | generation_kwargs: 67 | num_beams: 5 68 | - 69 | label: human_eval_return_complex 70 | dataset_uri: local_data/programming/human_eval_return_complex.jsonl # ADD YOUR OWN DATASET URI 71 | num_fewshot: [0] 72 | pass_at_k: 1 73 | batch_size: 1 74 | icl_task_type: code_evaluation 75 | generation_kwargs: 76 | num_beams: 5 77 | - 78 | label: human_eval_25 79 | dataset_uri: local_data/programming/human_eval-0.25.jsonl # ADD YOUR OWN DATASET URI 80 | num_fewshot: [0] 81 | pass_at_k: 1 82 | batch_size: 1 83 | icl_task_type: code_evaluation 84 | generation_kwargs: 85 | num_beams: 5 86 | - 87 | label: human_eval_50 88 | dataset_uri: local_data/programming/human_eval-0.5.jsonl # ADD YOUR OWN DATASET URI 89 | num_fewshot: [0] 90 | pass_at_k: 1 91 | batch_size: 1 92 | icl_task_type: code_evaluation 93 | generation_kwargs: 94 | num_beams: 5 95 | - 96 | label: human_eval_75 97 | dataset_uri: local_data/programming/human_eval-0.75.jsonl # ADD YOUR OWN DATASET URI 98 | num_fewshot: [0] 99 | pass_at_k: 1 100 | batch_size: 1 101 | icl_task_type: code_evaluation 102 | generation_kwargs: 103 | num_beams: 5 104 | 
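# Usage sketch (illustrative): this config is passed to eval_openlm_ckpt.py via
# --eval-yaml, mirroring scripts/run_dclm_evals_humaneval.sh; the model path and the
# output file below are placeholders, not fixed values:
#   python eval_openlm_ckpt.py \
#     --hf-model <model-dir-or-hf-id> \
#     --tokenizer <model-dir-or-hf-id> \
#     --eval-yaml scripts/humaneval.yaml \
#     --output-file results/dclm/humaneval-<model>.json \
#     --use-temp-working-dir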
-------------------------------------------------------------------------------- /scripts/llm1b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=llm 3 | #SBATCH --nodes=8 4 | #SBATCH --ntasks-per-node=8 5 | #SBATCH --hint=nomultithread # we get physical cores not logical 6 | #SBATCH --partition=a3 7 | #SBATCH --gres=gpu:8 # number of gpus 8 | #SBATCH --output=/data/niklas/jobs/%x-%j.out # output file name 9 | #SBATCH --exclusive 10 | 11 | ###################### 12 | ### Set environment ## 13 | ###################### 14 | cd /home/niklas/OLMo 15 | source /env/bin/start-ctx-user 16 | conda activate llm 17 | export WANDB_PROJECT="olmo-small" 18 | # Training setup 19 | set -euo pipefail 20 | export GPUS_PER_NODE=8 21 | export NODENAME=$(hostname -s) 22 | export MASTER_ADDR=$(scontrol show hostnames | head -n 1) 23 | export MASTER_PORT=39594 24 | export WORLD_SIZE=$SLURM_NTASKS 25 | export RANK=$SLURM_PROCID 26 | export FS_LOCAL_RANK=$SLURM_PROCID 27 | export LOCAL_WORLD_SIZE=$SLURM_NTASKS_PER_NODE 28 | export LOCAL_RANK=$SLURM_LOCALID 29 | export NODE_RANK=$((($RANK - $LOCAL_RANK) / $LOCAL_WORLD_SIZE)) 30 | export R2_PROFILE=r2 31 | export R2_ENDPOINT_URL=XXX 32 | export AWS_ACCESS_KEY_ID=XXX 33 | export AWS_SECRET_ACCESS_KEY=XXX 34 | export HF_DATASETS_OFFLINE=1 35 | echo "World size: $WORLD_SIZE" 36 | ###################### 37 | 38 | ###################### 39 | #### Set network ##### 40 | ###################### 41 | head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 42 | ###################### 43 | 44 | srun \ 45 | --distribution=block:block \ 46 | --kill-on-bad-exit \ 47 | scripts/run_with_environment.sh \ 48 | python -u scripts/train.py configs/mitchish1-s3.yaml \ 49 | --activation_checkpointing=fine_grained \ 50 | --canceled_check_interval=50 \ 51 | --gen1_gc_interval=1 \ 52 | --device_train_microbatch_size=8 \ 53 | --global_train_batch_size=512 \ 54 | --run_name=mitchish1 \ 55 | --wandb.group=mitchish1 \ 56 | --model.flash_attention=true \ 57 | --fsdp.wrapping_strategy=null \ 58 | --fsdp.sharding_strategy=SHARD_GRAD_OP \ 59 | --fused_loss=true \ 60 | '--load_path=${path.last_checkpoint:${remote_save_folder}}' \ 61 | "--save_folder=/data/niklas/llm/checkpoints/${SLURM_JOB_ID}/" 62 | -------------------------------------------------------------------------------- /scripts/make_table.py: -------------------------------------------------------------------------------- 1 | """ 2 | Make DCLM results table. 3 | """ 4 | 5 | from pathlib import Path 6 | import json 7 | import pandas as pd 8 | import copy 9 | 10 | 11 | result_dir = Path("results/dclm") 12 | 13 | 14 | model_names = [ 15 | "OLMoE-7B-A1B-main", 16 | "OLMoE-7B-A1B-step1220000-tokens5117B", 17 | "OLMoE-7B-A1B-step1223842-tokens5100B", 18 | "OLMo-1B-0724-hf", 19 | # "OLMo-7B-0724-hf", # Uncomment this once evals are done. 20 | ] 21 | 22 | eval_settings = ["heavy", "humaneval"] 23 | 24 | models_lookup = { 25 | "OLMoE-7B-A1B-main": "OLMoE-1B-7B", 26 | "OLMoE-7B-A1B-step1220000-tokens5117B": "OLMoE-1B-7B step 1,220,000", 27 | "OLMoE-7B-A1B-step1223842-tokens5100B": "OLMoE-1B-7B step 1,223,842", 28 | "OLMo-1B-0724-hf": "OLMo-1B", 29 | # "OLMo-7B-0724-hf": "OLMo-7B", # Uncomment this once evals are done. 
30 | } 31 | 32 | metrics_lookup = { 33 | "agi_eval_lsat_ar": "AGI Eval LSAT-AR$^*$", 34 | "agi_eval_lsat_lr": "AGI Eval LSAT-LR", 35 | "agi_eval_lsat_rc": "AGI Eval LSAT-RC", 36 | "agi_eval_sat_en": "AGI Eval SAT-En", 37 | "agi_eval_sat_math_cot": "AGI Eval SAT-Math CoT", 38 | "aqua_cot": "AQuA CoT", 39 | "arc_challenge": "ARC Challenge$^*$", 40 | "arc_easy": "ARC Easy$^*$", 41 | "bbq": "BBQ", 42 | "bigbench_conceptual_combinations": "BigBench Conceptual Combinations", 43 | "bigbench_conlang_translation": "BigBench Conlang Translation", 44 | "bigbench_cs_algorithms": "BigBench CS Algorithms$^*$", 45 | "bigbench_dyck_languages": "BigBench Dyck Languages$^*$", 46 | "bigbench_elementary_math_qa": "BigBench Elementary Math QA", 47 | "bigbench_language_identification": "BigBench Language Identification$^*$", 48 | "bigbench_logical_deduction": "BigBench Logical Deduction", 49 | "bigbench_misconceptions": "BigBench Misconceptions", 50 | "bigbench_novel_concepts": "BigBench Novel Concepts", 51 | "bigbench_operators": "BigBench Operators$^*$", 52 | "bigbench_qa_wikidata": "BigBench QA Wikidata$^*$", 53 | "bigbench_repeat_copy_logic": "BigBench Repeat Copy Logic$^*$", 54 | "bigbench_strange_stories": "BigBench Strange Stories", 55 | "bigbench_strategy_qa": "BigBench Strategy QA", 56 | "bigbench_understanding_fables": "BigBench Understanding Fables", 57 | "boolq": "BoolQ$^*$", 58 | "commonsense_qa": "CommonsenseQA$^*$", 59 | "copa": "COPA$^*$", 60 | "coqa": "CoQA$^*$", 61 | "enterprise_pii_classification": "Enterprise PII Classification", 62 | "gpqa_diamond": "GPQA Diamond", 63 | "gpqa_main": "GPQA Main", 64 | "gsm8k_cot": "GSM8K CoT", 65 | "hellaswag": "HellaSwag 10-shot$^*$", 66 | "hellaswag_zeroshot": "HellaSwag 0-shot$^*$", 67 | "jeopardy": "Jeopardy$^*$", 68 | "lambada_openai": "LAMBADA$^*$", 69 | "logi_qa": "LogiQA", 70 | "math_qa": "Math QA", 71 | "mmlu_fewshot": "MMLU Few-shot", 72 | "mmlu_zeroshot": "MMLU Zero-shot", 73 | "openbook_qa": "OpenBookQA$^*$", 74 | "piqa": "PIQA$^*$", 75 | "pubmed_qa_labeled": "PubMedQA", 76 | "simple_arithmetic_nospaces": "Simple Arithmetic, no spaces", 77 | "simple_arithmetic_withspaces": "Simple Arithmetic, with spaces", 78 | "siqa": "Social IQA", 79 | "squad": "SQuAD$^*$", 80 | "svamp_cot": "SVAMP CoT", 81 | "triviaqa_sm_sub": "Trivia QA", 82 | "winogender_mc_female": "Winogender Female", 83 | "winogender_mc_male": "Winogender Male", 84 | "winograd": "Winograd$^*$", 85 | "winogrande": "Winogrande$^*$", 86 | "core": "Core", 87 | "extended": "Extended", 88 | } 89 | 90 | res = {} 91 | 92 | for model_name in model_names: 93 | data = json.load(open(result_dir / f"heavy-{model_name}.json")) 94 | to_add = copy.deepcopy(data["eval_metrics"]["icl"]) 95 | to_update = { 96 | "core": data["low_variance_datasets_centered"], 97 | "extended": data["aggregated_centered_results"], 98 | } 99 | to_add.update(to_update) 100 | res[model_name] = to_add 101 | 102 | res = pd.DataFrame(res) * 100 103 | 104 | 105 | # Replace the columns in res according to the `models_lookup` dict. 106 | res.columns = [models_lookup.get(col, col) for col in res.columns] 107 | 108 | # Replace the rows in res according to the `metrics_lookup` dict. 
109 | res.index = [metrics_lookup.get(row, row) for row in res.index] 110 | 111 | res = res.reindex(sorted(res.index.drop(["Core", "Extended"])) + ["Core", "Extended"]) 112 | 113 | res.to_csv("results/dclm-table.tsv", sep="\t", float_format="%.1f") 114 | res.to_latex("results/dclm-table.tex", float_format="%.1f") 115 | -------------------------------------------------------------------------------- /scripts/megatron.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=llm 3 | #SBATCH --nodes=1 4 | #SBATCH --ntasks-per-node=1 5 | #SBATCH --hint=nomultithread # we get physical cores not logical 6 | #SBATCH --partition=a3 7 | #SBATCH --gres=gpu:8 # number of gpus 8 | #SBATCH --time 48:00:00 # maximum execution time (HH:MM:SS) 9 | #SBATCH --output=/data/niklas/jobs/%x-%j.out # output file name 10 | #SBATCH --exclusive 11 | 12 | ###################### 13 | ### Set environment ## 14 | ###################### 15 | cd /home/niklas/megablocks 16 | source /env/bin/start-ctx-user 17 | conda activate megatronmoe 18 | export WANDB_PROJECT="olmoe" 19 | # Training setup 20 | set -euo pipefail 21 | export GPUS_PER_NODE=8 22 | export NODENAME=$(hostname -s) 23 | export MASTER_ADDR=$(scontrol show hostnames | head -n 1) 24 | export MASTER_PORT=39594 25 | #export WORLD_SIZE=$SLURM_NTASKS 26 | #export RANK=$SLURM_PROCID 27 | #export FS_LOCAL_RANK=$SLURM_PROCID 28 | #export LOCAL_WORLD_SIZE=$SLURM_NTASKS_PER_NODE 29 | #export LOCAL_RANK=$SLURM_LOCALID 30 | #export NODE_RANK=$((($RANK - $LOCAL_RANK) / $LOCAL_WORLD_SIZE)) 31 | export R2_PROFILE=r2 32 | export R2_ENDPOINT_URL=YOUR_URL 33 | export AWS_ACCESS_KEY_ID=YOUR_KEY 34 | export AWS_SECRET_ACCESS_KEY=YOUR_KEY 35 | export HF_DATASETS_OFFLINE=1 36 | export CUDA_LAUNCH_BLOCKING=1 37 | export TRITON_CACHE_DIR=/data/niklas/tritoncache 38 | 39 | #echo "World size: $WORLD_SIZE" 40 | ###################### 41 | 42 | ###################### 43 | #### Set network ##### 44 | ###################### 45 | head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 46 | ###################### 47 | 48 | srun /home/niklas/dense_1b_8gpu.sh /data/niklas/llm/checkpoints/${SLURM_JOB_ID}/ 20000 49 | -------------------------------------------------------------------------------- /scripts/megatron_dense_46m_8gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | EXP_DIR=$1 4 | 5 | # scaling law: 1B tokens @ 125M = 2k steps. 6 | # 7 | # 512 * 1k * 400k = 200b tokens. 8 | # 512 * 1k * 200k = 100b tokens. 9 | # 512 * 1k * 100k = 50b tokens (default). 10 | # 512 * 1k * 20k = 10b tokens. 11 | TRAINING_STEPS=100000 12 | if [ -n "${2}" ]; then 13 | TRAINING_STEPS=$2; 14 | fi 15 | 16 | ## 17 | ### Pre-training for a 46M parameter GPT2 model. 18 | ## 19 | 20 | # Distributed hyperparameters. 21 | DISTRIBUTED_ARGUMENTS="\ 22 | --nproc_per_node 8 \ 23 | --nnodes 1 \ 24 | --node_rank 0 \ 25 | --master_addr localhost \ 26 | --master_port 6000" 27 | 28 | # Model hyperparameters. 29 | MODEL_ARGUMENTS="\ 30 | --num-layers 6 \ 31 | --hidden-size 512 \ 32 | --num-attention-heads 8 \ 33 | --seq-length 1024 \ 34 | --max-position-embeddings 1024" 35 | 36 | # Training hyperparameters. 
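# Token budget sanity check (illustrative, mirroring the step notes at the top of this
# script): total tokens ~= global batch size * sequence length * train iters, so the
# default 100k steps give 512 * 1024 * 100000 ~= 52.4B tokens (rounded to "50b" above).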
37 | TRAINING_ARGUMENTS="\ 38 | --micro-batch-size 64 \ 39 | --global-batch-size 512 \ 40 | --train-iters ${TRAINING_STEPS} \ 41 | --lr-decay-iters ${TRAINING_STEPS} \ 42 | --lr 0.0006 \ 43 | --min-lr 0.00006 \ 44 | --lr-decay-style cosine \ 45 | --lr-warmup-fraction 0.01 \ 46 | --clip-grad 1.0 \ 47 | --init-method-std 0.01" 48 | 49 | C4_DATASET="/data/niklas/c4-subsets/55b/gpt2tok_c4_en_55B_text_document" 50 | 51 | # NOTE: We don't train for enough tokens for the 52 | # split to matter. 53 | DATA_ARGUMENTS="\ 54 | --data-path ${C4_DATASET} \ 55 | --vocab-file /data/niklas/vocab.json \ 56 | --merge-file /data/niklas/merges.txt \ 57 | --make-vocab-size-divisible-by 1024 \ 58 | --split 969,30,1" 59 | 60 | COMPUTE_ARGUMENTS="\ 61 | --bf16 \ 62 | --DDP-impl local \ 63 | --no-async-tensor-model-parallel-allreduce \ 64 | --use-flash-attn" 65 | 66 | CHECKPOINT_ARGUMENTS="\ 67 | --save-interval 2000 \ 68 | --save ./${EXP_DIR}" 69 | 70 | EVALUATION_ARGUMENTS="\ 71 | --eval-iters 100 \ 72 | --log-interval 100 \ 73 | --eval-interval 1000" 74 | 75 | torchrun ${DISTRIBUTED_ARGUMENTS} \ 76 | third_party/Megatron-LM/pretrain_gpt.py \ 77 | ${MODEL_ARGUMENTS} \ 78 | ${TRAINING_ARGUMENTS} \ 79 | ${DATA_ARGUMENTS} \ 80 | ${COMPUTE_ARGUMENTS} \ 81 | ${CHECKPOINT_ARGUMENTS} \ 82 | ${EVALUATION_ARGUMENTS} |& tee ./${EXP_DIR}/train.log 83 | -------------------------------------------------------------------------------- /scripts/megatron_dmoe_46m_8gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | EXP_DIR=$1 4 | 5 | # 512 * 1k * 400k = 200b tokens. 6 | # 512 * 1k * 200k = 100b tokens. 7 | # 512 * 1k * 100k = 50b tokens (default). 8 | # 512 * 1k * 20k = 10b tokens. 9 | TRAINING_STEPS=20000 10 | if [ -n "${2}" ]; then 11 | TRAINING_STEPS=$2; 12 | fi 13 | 14 | NUM_EXPERTS=64 15 | if [ -n "${3}" ]; then 16 | NUM_EXPERTS=$3; 17 | fi 18 | 19 | TOP_K=1 20 | if [ -n "${4}" ]; then 21 | TOP_K=$4; 22 | fi 23 | 24 | LOSS_WEIGHT=0.1 25 | if [ -n "${5}" ]; then 26 | LOSS_WEIGHT=$5; 27 | fi 28 | 29 | BATCH_SIZE=64 30 | if [ -n "${6}" ]; then 31 | BATCH_SIZE=$6; 32 | fi 33 | 34 | ## 35 | ### Pre-training for a 46M parameter dMoE model. 36 | ## 37 | 38 | # MoE hyperparameters. 39 | MOE_ARGUMENTS="\ 40 | --moe-num-experts=${NUM_EXPERTS} \ 41 | --moe-loss-weight=${LOSS_WEIGHT} \ 42 | --moe-top-k=${TOP_K}" 43 | 44 | # Distributed hyperparameters. 45 | DISTRIBUTED_ARGUMENTS="\ 46 | --nproc_per_node ${GPUS_PER_NODE} \ 47 | --nnodes 1 \ 48 | --node_rank 0 \ 49 | --master_addr ${MASTER_ADDR} \ 50 | --master_port ${MASTER_PORT}" 51 | 52 | # Model hyperparameters. 53 | MODEL_ARGUMENTS="\ 54 | --num-layers 6 \ 55 | --hidden-size 512 \ 56 | --num-attention-heads 8 \ 57 | --seq-length 1024 \ 58 | --max-position-embeddings 1024" 59 | 60 | # Training hyperparameters. 
61 | TRAINING_ARGUMENTS="\ 62 | --micro-batch-size ${BATCH_SIZE} \ 63 | --global-batch-size 512 \ 64 | --train-iters ${TRAINING_STEPS} \ 65 | --lr-decay-iters ${TRAINING_STEPS} \ 66 | --lr 0.0006 \ 67 | --min-lr 0.00006 \ 68 | --lr-decay-style cosine \ 69 | --lr-warmup-fraction 0.01 \ 70 | --clip-grad 1.0 \ 71 | --init-method-std 0.01 \ 72 | --optimizer adam" 73 | 74 | PILE_DATASET="\ 75 | 1.0 \ 76 | /mount/pile_gpt2/01_text_document \ 77 | 1.0 \ 78 | /mount/pile_gpt2/02_text_document \ 79 | 1.0 \ 80 | /mount/pile_gpt2/03_text_document \ 81 | 1.0 \ 82 | /mount/pile_gpt2/04_text_document \ 83 | 1.0 \ 84 | /mount/pile_gpt2/05_text_document \ 85 | 1.0 \ 86 | /mount/pile_gpt2/06_text_document \ 87 | 1.0 \ 88 | /mount/pile_gpt2/07_text_document \ 89 | 1.0 \ 90 | /mount/pile_gpt2/08_text_document \ 91 | 1.0 \ 92 | /mount/pile_gpt2/09_text_document \ 93 | 1.0 \ 94 | /mount/pile_gpt2/10_text_document \ 95 | 1.0 \ 96 | /mount/pile_gpt2/11_text_document \ 97 | 1.0 \ 98 | /mount/pile_gpt2/12_text_document \ 99 | 1.0 \ 100 | /mount/pile_gpt2/13_text_document \ 101 | 1.0 \ 102 | /mount/pile_gpt2/14_text_document \ 103 | 1.0 \ 104 | /mount/pile_gpt2/15_text_document \ 105 | 1.0 \ 106 | /mount/pile_gpt2/16_text_document \ 107 | 1.0 \ 108 | /mount/pile_gpt2/17_text_document \ 109 | 1.0 \ 110 | /mount/pile_gpt2/18_text_document \ 111 | 1.0 \ 112 | /mount/pile_gpt2/19_text_document \ 113 | 1.0 \ 114 | /mount/pile_gpt2/20_text_document \ 115 | 1.0 \ 116 | /mount/pile_gpt2/21_text_document \ 117 | 1.0 \ 118 | /mount/pile_gpt2/22_text_document \ 119 | 1.0 \ 120 | /mount/pile_gpt2/23_text_document \ 121 | 1.0 \ 122 | /mount/pile_gpt2/24_text_document \ 123 | 1.0 \ 124 | /mount/pile_gpt2/25_text_document \ 125 | 1.0 \ 126 | /mount/pile_gpt2/26_text_document \ 127 | 1.0 \ 128 | /mount/pile_gpt2/27_text_document \ 129 | 1.0 \ 130 | /mount/pile_gpt2/28_text_document \ 131 | 1.0 \ 132 | /mount/pile_gpt2/29_text_document" 133 | 134 | C4_DATASET="/data/niklas/c4-subsets/55b/gpt2tok_c4_en_55B_text_document" 135 | 136 | # NOTE: We don't train for enough tokens for the 137 | # split to matter. 
138 | DATA_ARGUMENTS="\ 139 | --data-path ${C4_DATASET} \ 140 | --vocab-file /data/niklas/vocab.json \ 141 | --merge-file /data/niklas/merges.txt \ 142 | --make-vocab-size-divisible-by 1024 \ 143 | --split 969,30,1" 144 | 145 | COMPUTE_ARGUMENTS="\ 146 | --bf16 \ 147 | --DDP-impl local \ 148 | --moe-expert-model-parallelism \ 149 | --no-async-tensor-model-parallel-allreduce \ 150 | --use-flash-attn" 151 | 152 | CHECKPOINT_ARGUMENTS="\ 153 | --save-interval 2000 \ 154 | --save ./${EXP_DIR}" 155 | 156 | EVALUATION_ARGUMENTS="\ 157 | --eval-iters 100 \ 158 | --log-interval 100 \ 159 | --eval-interval 1000" 160 | 161 | torchrun ${DISTRIBUTED_ARGUMENTS} \ 162 | third_party/Megatron-LM/pretrain_gpt.py \ 163 | ${MOE_ARGUMENTS} \ 164 | ${MODEL_ARGUMENTS} \ 165 | ${TRAINING_ARGUMENTS} \ 166 | ${DATA_ARGUMENTS} \ 167 | ${COMPUTE_ARGUMENTS} \ 168 | ${CHECKPOINT_ARGUMENTS} \ 169 | ${EVALUATION_ARGUMENTS} |& tee ./${EXP_DIR}/train.log 170 | -------------------------------------------------------------------------------- /scripts/olmoe-gantry.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | 4 | CONFIG_PATH=configs/olmoe17/olmoe-8x1b-newhp-newds-final.yml 5 | ARGS='--run_name=olmoe-8x1b-newhp-newds-final --save-overwrite --fsdp.sharding_strategy=FULL_SHARD --device_train_microbatch_size=4 --canceled_check_interval=9999999' 6 | 7 | #NUM_NODES=1 8 | #NUM_NODES=8 9 | #NUM_NODES=16 10 | NUM_NODES=32 11 | BEAKER_REPLICA_RANK=0 12 | 13 | gantry run \ 14 | --weka oe-training-default:/weka/oe-training-default \ 15 | --allow-dirty \ 16 | --preemptible \ 17 | --priority urgent \ 18 | --workspace ai2/olmoe \ 19 | --task-name olmoe \ 20 | --description olmoe \ 21 | --beaker-image shanea/olmo-torch2.2-gantry \ 22 | --budget ai2/oe-training \ 23 | --cluster ai2/jupiter-cirrascale-2 \ 24 | --gpus 8 \ 25 | --replicas "${NUM_NODES}" \ 26 | --env-secret WANDB_API_KEY=WANDB_API_KEY \ 27 | --env-secret AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID \ 28 | --env-secret AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY \ 29 | --env-secret R2_ENDPOINT_URL=R2_ENDPOINT_URL \ 30 | --leader-selection \ 31 | --host-networking \ 32 | --env LOG_FILTER_TYPE=local_rank0_only \ 33 | --env OMP_NUM_THREADS=8 \ 34 | --env OLMO_TASK=model \ 35 | --shared-memory 10GiB \ 36 | --venv base \ 37 | --yes \ 38 | --synchronized-start-timeout 60m \ 39 | -- /bin/bash -c "pip install --upgrade torch==2.3.0; pip install --upgrade flash-attn --no-build-isolation; pip install git+https://github.com/Muennighoff/megablocks.git@zloss; mkdir -p /root/.cache; pushd /root/.cache; curl "https://storage.googleapis.com/dirkgr-public/huggingface_cache_v3.tar.gz" | tar --keep-newer-files -xzf -; popd; export HF_DATASETS_OFFLINE=1; export NCCL_IB_HCA=^=mlx5_bond_0; SLURM_JOB_ID=${BEAKER_JOB_ID} torchrun --nnodes ${NUM_NODES}:${NUM_NODES} --node_rank ${BEAKER_REPLICA_RANK} --nproc-per-node 8 --rdzv_id=12347 --rdzv_backend=c10d --rdzv_conf='read_timeout=420' --rdzv_endpoint=\$BEAKER_LEADER_REPLICA_HOSTNAME:29400 scripts/train.py ${CONFIG_PATH} ${ARGS}" 40 | 41 | # Single node: 42 | #--rdzv_endpoint=\$BEAKER_NODE_HOSTNAME:29400 43 | # Multinode: 44 | #--rdzv_endpoint=\$BEAKER_LEADER_REPLICA_HOSTNAME:29400 45 | # --mount /net/nfs.cirrascale/allennlp/petew/cache:/root/.cache \ 46 | #--node_rank=$BEAKER_REPLICA_RANK 47 | # --nfs \ -------------------------------------------------------------------------------- /scripts/routing_mixtral_v2.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_mixtral_v2.jpg -------------------------------------------------------------------------------- /scripts/routing_olmoe_v2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_olmoe_v2.jpg -------------------------------------------------------------------------------- /scripts/routing_output.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output.zip -------------------------------------------------------------------------------- /scripts/routing_output/mistral/eid2token/arxiv.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/eid2token/arxiv.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/eid2token/book.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/eid2token/book.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/eid2token/c4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/eid2token/c4.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/eid2token/github.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/eid2token/github.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/eid2token/wikipedia.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/eid2token/wikipedia.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts/arxiv.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts/arxiv.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts/book.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts/book.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts/c4.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts/c4.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts/github.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts/github.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts/wikipedia.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts/wikipedia.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts_crosslayer/arxiv.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer/arxiv.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts_crosslayer/book.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer/book.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts_crosslayer/c4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer/c4.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts_crosslayer/github.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer/github.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts_crosslayer/wikipedia.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer/wikipedia.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts_crosslayer_top1/arxiv.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer_top1/arxiv.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts_crosslayer_top1/book.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer_top1/book.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts_crosslayer_top1/c4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer_top1/c4.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts_crosslayer_top1/github.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer_top1/github.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts_crosslayer_top1/wikipedia.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_crosslayer_top1/wikipedia.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts_top1/arxiv.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_top1/arxiv.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts_top1/book.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_top1/book.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts_top1/c4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_top1/c4.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts_top1/github.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_top1/github.pkl -------------------------------------------------------------------------------- /scripts/routing_output/mistral/expert_counts_top1/wikipedia.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/mistral/expert_counts_top1/wikipedia.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe-dpo/expert_counts/tulu.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe-dpo/expert_counts/tulu.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe-sft/expert_counts/tulu.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe-sft/expert_counts/tulu.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/eid2token/arxiv.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/eid2token/arxiv.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/eid2token/book.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/eid2token/book.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/eid2token/c4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/eid2token/c4.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/eid2token/github.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/eid2token/github.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/eid2token/wikipedia.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/eid2token/wikipedia.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts/arxiv.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts/arxiv.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts/book.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts/book.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts/c4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts/c4.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts/github.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts/github.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts/tulu.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts/tulu.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts/wikipedia.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts/wikipedia.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts_crosslayer/arxiv.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer/arxiv.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts_crosslayer/book.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer/book.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts_crosslayer/c4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer/c4.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts_crosslayer/github.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer/github.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts_crosslayer/wikipedia.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer/wikipedia.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts_crosslayer_top1/arxiv.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer_top1/arxiv.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts_crosslayer_top1/book.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer_top1/book.pkl 
-------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts_crosslayer_top1/c4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer_top1/c4.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts_crosslayer_top1/github.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer_top1/github.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts_crosslayer_top1/wikipedia.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_crosslayer_top1/wikipedia.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts_top1/arxiv.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_top1/arxiv.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts_top1/book.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_top1/book.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts_top1/c4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_top1/c4.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts_top1/github.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_top1/github.pkl -------------------------------------------------------------------------------- /scripts/routing_output/olmoe/expert_counts_top1/wikipedia.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/olmoe/expert_counts_top1/wikipedia.pkl -------------------------------------------------------------------------------- /scripts/routing_output/routing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/routing.jpg -------------------------------------------------------------------------------- /scripts/routing_output/routing_prob_distribution.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/scripts/routing_output/routing_prob_distribution.png -------------------------------------------------------------------------------- /scripts/run_dclm_evals_heavy.sh: -------------------------------------------------------------------------------- 1 | # Evaluate model checkpoints using DCLM eval suite. 2 | # Usage: bash scripts/run_dclm_evals_heavy.sh 3 | # Evaluates on all the `heavy` tasks from DCLM, plus the humaneval code tasks. 4 | 5 | # Run using this conda env: `/net/nfs.cirrascale/allennlp/davidw/miniconda3/envs/dclm` 6 | 7 | DCLM_DIR=/net/nfs.cirrascale/allennlp/davidw/proj/dclm/eval 8 | MODEL_DIR=/net/nfs.cirrascale/allennlp/davidw/checkpoints/moe-release 9 | WKDIR=$(pwd) 10 | METRICS_DIR=$WKDIR/results/dclm 11 | 12 | mkdir -p $METRICS_DIR 13 | 14 | 15 | declare -a models=( 16 | OLMoE-7B-A1B/main 17 | OLMoE-7B-A1B/step1220000-tokens5117B 18 | OLMoE-7B-A1B/step1223842-tokens5100B 19 | jetmoe-8b/main 20 | ) 21 | 22 | 23 | cd $DCLM_DIR 24 | export TQDM_DISABLE=1 25 | 26 | for model in "${models[@]}" 27 | do 28 | out_name=${model//\//-} 29 | out_file=$METRICS_DIR/heavy-${out_name}.json 30 | 31 | mason \ 32 | --cluster ai2/pluto-cirrascale \ 33 | --budget ai2/oe-training \ 34 | --gpus 1 \ 35 | --workspace ai2/olmoe \ 36 | --description "Run DCLM evals for MoE model $model" \ 37 | --task_name "eval-${out_name}" \ 38 | --priority high \ 39 | --preemptible \ 40 | -- \ 41 | python eval_openlm_ckpt.py \ 42 | --hf-model $MODEL_DIR/$model \ 43 | --tokenizer $MODEL_DIR/$model \ 44 | --eval-yaml heavy.yaml \ 45 | --output-file $out_file \ 46 | --use-temp-working-dir 47 | done 48 | -------------------------------------------------------------------------------- /scripts/run_dclm_evals_heavy_olmo.sh: -------------------------------------------------------------------------------- 1 | # Evaluate OLMo 1B and 7B. 2 | # Usage: bash scripts/run_dclm_evals_heavy_olmo.sh 3 | # Evaluates on all the `heavy` tasks from DCLM. 4 | 5 | # Run using this conda env: `/net/nfs.cirrascale/allennlp/davidw/miniconda3/envs/dclm` 6 | 7 | DCLM_DIR=/net/nfs.cirrascale/allennlp/davidw/proj/dclm/eval 8 | WKDIR=$(pwd) 9 | METRICS_DIR=$WKDIR/results/dclm 10 | 11 | mkdir -p $METRICS_DIR 12 | 13 | 14 | declare -a models=( 15 | allenai/OLMo-7B-0724-hf 16 | allenai/OLMo-1B-0724-hf 17 | ) 18 | 19 | 20 | cd $DCLM_DIR 21 | export TQDM_DISABLE=1 22 | 23 | for model in "${models[@]}" 24 | do 25 | out_name=$(echo "$model" | awk -F '/' '{ print $NF }') 26 | out_file=$METRICS_DIR/heavy-${out_name}.json 27 | 28 | mason \ 29 | --cluster ai2/s2-cirrascale \ 30 | --budget ai2/oe-training \ 31 | --gpus 4 \ 32 | --workspace ai2/olmoe \ 33 | --description "Run DCLM evals for OLMo model $model" \ 34 | --task_name "eval-${out_name}" \ 35 | --priority high \ 36 | -- \ 37 | python eval_openlm_ckpt.py \ 38 | --hf-model $model \ 39 | --tokenizer $model \ 40 | --eval-yaml heavy.yaml \ 41 | --output-file $out_file \ 42 | --use-temp-working-dir 43 | done 44 | -------------------------------------------------------------------------------- /scripts/run_dclm_evals_humaneval.sh: -------------------------------------------------------------------------------- 1 | # Evaluate model checkpoints using DCLM eval suite. 2 | # Usage: bash scripts/run_dclm_evals_humaneval.sh 3 | # Evaluates on the humaneval code tasks from DCLM. 
4 | 5 | # Run using this conda env: `/net/nfs.cirrascale/allennlp/davidw/miniconda3/envs/dclm` 6 | 7 | DCLM_DIR=/net/nfs.cirrascale/allennlp/davidw/proj/dclm/eval 8 | MODEL_DIR=/net/nfs.cirrascale/allennlp/davidw/checkpoints/moe-release 9 | WKDIR=$(pwd) 10 | METRICS_DIR=$WKDIR/results/dclm 11 | 12 | mkdir -p $METRICS_DIR 13 | 14 | 15 | declare -a models=( 16 | OLMoE-7B-A1B/main 17 | OLMoE-7B-A1B/step1220000-tokens5117B 18 | OLMoE-7B-A1B/step1223842-tokens5100B 19 | jetmoe-8b/main 20 | ) 21 | 22 | 23 | cd $DCLM_DIR 24 | export TQDM_DISABLE=1 25 | 26 | for model in "${models[@]}" 27 | do 28 | out_name=${model//\//-} 29 | out_file=$METRICS_DIR/humaneval-${out_name}.json 30 | 31 | mason \ 32 | --cluster ai2/pluto-cirrascale \ 33 | --budget ai2/oe-training \ 34 | --gpus 1 \ 35 | --workspace ai2/olmoe \ 36 | --description "Run HumanEval DCLM evals for MoE model $model" \ 37 | --task_name "eval-${out_name}" \ 38 | --priority high \ 39 | --preemptible \ 40 | -- \ 41 | python eval_openlm_ckpt.py \ 42 | --hf-model $MODEL_DIR/$model \ 43 | --tokenizer $MODEL_DIR/$model \ 44 | --eval-yaml $WKDIR/scripts/humaneval.yaml \ 45 | --output-file $out_file \ 46 | --use-temp-working-dir 47 | done 48 | -------------------------------------------------------------------------------- /scripts/run_moe_analysis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import time 5 | import pickle as pkl 6 | import json 7 | import numpy as np 8 | 9 | from collections import defaultdict, Counter 10 | 11 | from datasets import load_dataset 12 | from transformers import OlmoeForCausalLM, AutoTokenizer, AutoModelForCausalLM 13 | from huggingface_hub import list_repo_refs 14 | 15 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 16 | if os.path.exists("hf_access_token.txt"): 17 | with open("hf_access_token.txt") as f: 18 | token = f.read().strip() 19 | else: 20 | token = None 21 | 22 | start_time = time.time() 23 | 24 | def tokenize_c4(tokenized_path, sample_ratio=0.005, model="olmoe"): 25 | np.random.seed(2024) 26 | with open(tokenized_path, "w") as f: 27 | if model == "olmoe": 28 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924", token=token) 29 | elif model == "mixtral": 30 | tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1", token=token) 31 | else: raise NotImplementedError(f"model={model}") 32 | cnt = 0 33 | for row in load_dataset("allenai/c4", "en", split="validation", streaming=True): 34 | if np.random.random() < sample_ratio: 35 | input_ids = tokenizer(row["text"])["input_ids"] 36 | f.write(json.dumps({"input_ids": input_ids})+"\n") 37 | cnt += 1 38 | print(f"Loaded {cnt} lines!") 39 | 40 | def load_c4(tokenized_path, bs): 41 | tokens = [] 42 | with open(tokenized_path, "r") as f: 43 | for line in f: 44 | tokens += json.loads(line)["input_ids"] 45 | while len(tokens) >= bs: 46 | yield tokens[:bs] 47 | tokens = tokens[bs:] 48 | 49 | def load_model(revision="main", model="olmoe"): 50 | if model == "olmoe": 51 | model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924", token=token, revision=revision).to(DEVICE) 52 | elif model == "mixtral": 53 | model = AutoModelForCausalLM.from_pretrained( 54 | "mistralai/Mixtral-8x7B-v0.1", token=token, revision=revision, device_map="auto" 55 | ) 56 | else: raise NotImplementedError(f"model={model}") 57 | return model 58 | 59 | def do_inference(args, run_all_checkpoints=False): 60 | run(args, "main") 61 | 62 | if run_all_checkpoints: 63 | 
out = list_repo_refs("allenai/OLMoE-1B-7B-0924", token=token) 64 | cand_branches = [b.name for b in out.branches] 65 | all_branches = [] 66 | # Old checkpoints previously used: ["15000", "130000", "250000", "490000"]: 67 | # Percentages of pretraining: 0.00408549469 ; 0.0980518727 ; 0.20018924011 ; 0.40037848022 68 | for step in ["5000", "120000", "245000", "490000"]: 69 | branches = [name for name in cand_branches if name.startswith(f"step{step}-tokens") and name.endswith("B")] 70 | assert len(branches) == 1, f"{step} ; {branches}" 71 | all_branches.append(branches[0]) 72 | 73 | for b in all_branches: 74 | run(args, b, postfix="_"+b.split("-")[0][4:], save_exp_id_only=True) 75 | 76 | def run(args, revision, postfix="", length=2048, save_exp_id_only=False): 77 | save_path = os.path.join(args.out_dir, f"c4_results{postfix}.jsonl") 78 | assert not os.path.exists(save_path), f"{save_path} already exists" 79 | 80 | model = load_model(revision, model=args.model) 81 | 82 | results = [] 83 | start_time = time.time() 84 | 85 | print("Start inference") 86 | for input_ids in load_c4(args.tokenized_path, length): 87 | input_ids = torch.LongTensor(input_ids).reshape(1, -1).to(DEVICE) 88 | out = model(input_ids=input_ids, output_router_logits=True) 89 | 90 | input_ids = input_ids[0].detach().cpu().numpy().tolist() 91 | logits = out["logits"][0].detach().cpu().numpy() 92 | predicted_token_ids = np.argmax(logits, -1).tolist() 93 | router_logits = [l.detach().cpu().numpy() for l in out["router_logits"]] 94 | if args.model == "olmoe": 95 | assert len(router_logits) == 16 96 | exp_ids = np.stack([np.argsort(-logits, -1)[:, :8].tolist() for logits in router_logits], -1).tolist() 97 | assert np.array(exp_ids).shape == (2048, 8, 16) 98 | elif args.model == "mixtral": 99 | assert len(router_logits) == 32 100 | exp_ids = np.stack([np.argsort(-logits, -1)[:, :2].tolist() for logits in router_logits], -1).tolist() 101 | assert np.array(exp_ids).shape == (2048, 2, 32) 102 | else: raise NotImplementedError(f"model={args.model}") 103 | results.append({"input_ids": input_ids, "predicted_token_ids": predicted_token_ids, "exp_ids": exp_ids}) 104 | 105 | if len(results) % 50 == 0: 106 | print("Finish %d batches (%dmin)" % (len(results), (time.time()-start_time)/60)) 107 | 108 | with open(save_path, "w") as f: 109 | for r in results: 110 | f.write(json.dumps(r)+"\n") 111 | 112 | print("Saved %d batches to %s" % (len(results), save_path)) 113 | 114 | def do_ckpt_analysis(args): 115 | FONTSIZE = 36 116 | with open(os.path.join(args.out_dir, "c4_results.jsonl"), "r") as f: 117 | results = [] 118 | for line in f: 119 | results.append(json.loads(line)) 120 | 121 | assert args.topk in [1, 8, 18] 122 | 123 | import matplotlib.pylab as plt 124 | if args.topk == 18: 125 | fig, axes = plt.subplots(figsize=(24, 8), ncols=2, nrows=1, sharey=True, layout='constrained') 126 | titles = ["Top-k=1", "Top-k=8"] 127 | else: 128 | fig, ax = plt.subplots(figsize=(16, 8)) 129 | axes = [ax] 130 | titles = ["Top-k=1"] if args.topk == 1 else ["Top-k=8"] 131 | 132 | def compare_with_early_results(postfix, topk): 133 | with open(os.path.join(args.out_dir, "c4_results_{}.jsonl".format(postfix)), "rb") as f: 134 | early_results = [] 135 | for line in f: 136 | early_results.append(json.loads(line)) 137 | 138 | equals = Counter() 139 | total = 0 140 | ratio_per_layer = defaultdict(list) 141 | 142 | for r, er in zip(results, early_results): 143 | assert r["input_ids"] == er["input_ids"] 144 | 145 | if topk == 1: 146 | exp_ids = np.array(r["exp_ids"])[:, 
0, :] 147 | e_exp_ids = np.array(er["exp_ids"])[:, 0, :] 148 | assert exp_ids.shape==e_exp_ids.shape, (exp_ids.shape, e_exp_ids.shape) 149 | curr_equals = exp_ids==e_exp_ids # [token_cnt, n_layers] 150 | for i, _curr_equals in enumerate(curr_equals.transpose(1, 0).tolist()): 151 | equals[i] += np.sum(_curr_equals) 152 | total += len(curr_equals) 153 | elif topk == 8: 154 | exp_ids = np.array(r["exp_ids"]) 155 | e_exp_ids = np.array(er["exp_ids"]) 156 | for layer_idx in range(16): 157 | indices = exp_ids[:, :, layer_idx] 158 | e_indices = e_exp_ids[:, :, layer_idx] 159 | assert indices.shape==e_indices.shape==(2048, 8) 160 | for _indices, _e_indices in zip(indices.tolist(), e_indices.tolist()): 161 | ratio = len(set(_indices) & set(_e_indices)) / 8 162 | ratio_per_layer[layer_idx].append(ratio) 163 | else: 164 | raise NotImplementedError(f"topk={topk}") 165 | 166 | row = [postfix] 167 | result = [] 168 | for i in range(16): 169 | if topk == 1: 170 | row.append("%.1f" % (100 * equals[i] / total)) 171 | result.append(equals[i] / total) 172 | else: 173 | row.append("%.1f" % (100 * np.mean(ratio_per_layer[i]))) 174 | result.append(np.mean(ratio_per_layer[i])) 175 | 176 | pt.add_row(row) 177 | return result 178 | 179 | x = ["1", "10", "20", "40"] 180 | palette = ["#9e0142", "#c12a49", "#d53e4f", "#e65949", "#f46d43", "#f78a52", "#fdae61", "#ffd700", "#ffffbf", "#d8ef94", "#66c2a5", "#429db4", "#3288bd", "#5a71c1", "#7c4ab3", "#5e4fa2"] 181 | palette = [ 182 | "#F0539B", "#43C5E0", "#2E3168", "#FDBE15", 183 | "#F0539B", "#43C5E0", "#2E3168", "#FDBE15", 184 | "#F0539B", "#43C5E0", "#2E3168", "#FDBE15", 185 | "#F0539B", "#43C5E0", "#2E3168", "#FDBE15", 186 | ] 187 | alpha = [ 188 | 0.8, 0.8, 0.8, 0.8, 189 | 0.6, 0.6, 0.6, 0.6, 190 | 0.4, 0.4, 0.4, 0.4, 191 | 0.2, 0.2, 0.2, 0.2, 192 | ] 193 | linestyle = [ 194 | "-", "-", "-", "-", 195 | "--", "--", "--", "--", 196 | ":", ":", ":", ":", 197 | "-.", "-.", "-.", "-.", 198 | ] 199 | for i, ax in enumerate(axes): 200 | from prettytable import PrettyTable 201 | pt = PrettyTable() 202 | pt.field_names = [""] + [str(i) for i in range(16)] 203 | topk = int(titles[i][-1]) 204 | #r1 = compare_with_early_results(15000, topk) # 1% 205 | #r2 = compare_with_early_results(130000, topk) # 10% 206 | #r3 = compare_with_early_results(250000, topk) # 20% 207 | #r4 = compare_with_early_results(490000, topk) # 40% 208 | r1 = compare_with_early_results(5000, topk) # 1% 209 | r2 = compare_with_early_results(120000, topk) # 10% 210 | r3 = compare_with_early_results(245000, topk) # 20% 211 | r4 = compare_with_early_results(490000, topk) # 40% 212 | merged_results = np.array([r1, r2, r3, r4]) # [4, 16] 213 | # Define the original 4 colors 214 | colors = ["#F0539B", "#43C5E0", "#2E3168", "#FDBE15"] 215 | 216 | # Create a custom colormap using the defined colors 217 | from matplotlib.colors import LinearSegmentedColormap 218 | cmap = LinearSegmentedColormap.from_list("custom_theme", colors, N=16) 219 | # Generate 16 colors from the colormap 220 | additional_colors = [cmap(i / 15) for i in range(16)] 221 | 222 | #import seaborn as sns 223 | # Define the original 4 colors 224 | #colors = ["#F0539B", "#43C5E0", "#2E3168", "#FDBE15"] 225 | # Create a seaborn color palette with 16 colors based on the 4 theme colors 226 | #palette = sns.color_palette(colors, n_colors=16) 227 | 228 | for j in range(16): 229 | #ax.plot(x, merged_results[:, j] * 100, label=j, color=palette[j], marker='o', markersize=12, linewidth=6) 230 | #ax.plot(x, merged_results[:, j] * 100, label=j, 
color=palette[j], marker='o', markersize=12, linewidth=6, linestyle=linestyle[j]) 231 | ax.plot(x, merged_results[:, j] * 100, label=j, color=additional_colors[j], marker='o', markersize=12, linewidth=6) 232 | ax.tick_params(axis='both', which='major', labelsize=FONTSIZE) 233 | ax.set_title(titles[i], fontsize=FONTSIZE, fontweight='bold') 234 | ax.spines['top'].set_visible(False) 235 | ax.spines['right'].set_visible(False) 236 | # ax.set_ylim(0, 0.9) 237 | if args.topk in [8, 18]: 238 | plt.legend(frameon=True, title="Layer ID", title_fontsize=FONTSIZE, fontsize=FONTSIZE, columnspacing=0.4, labelspacing=0.4, ncol=4) 239 | fig.supxlabel('Pretraining stage (%)', fontsize=FONTSIZE, fontweight='bold') 240 | # fig.supylabel('% of active experts matching\nfinal checkpoint for the same input data', fontsize=FONTSIZE, fontweight='bold') 241 | # fig.supylabel('Active experts matching\n final checkpoint (%)', fontsize=FONTSIZE, fontweight='bold') 242 | fig.supylabel('Router saturation (%)', fontsize=FONTSIZE, fontweight='bold') 243 | plt.savefig(os.path.join(args.fig_dir, f"top{args.topk}_changes_over_checkpoints.png")) 244 | plt.savefig(os.path.join(args.fig_dir, f"top{args.topk}_changes_over_checkpoints.pdf")) 245 | 246 | def do_coactivation_analysis(args): 247 | FONTSIZE = 18 248 | with open(os.path.join(args.out_dir, "c4_results.jsonl"), "rb") as f: 249 | results = [] 250 | for line in f: 251 | try: 252 | results.append(json.loads(line)) 253 | except Exception as e: 254 | print("Failed to load line", e) 255 | break 256 | 257 | pairwise_counter = Counter() 258 | single_counter = Counter() 259 | from more_itertools import pairwise 260 | 261 | layer_num = args.layer_num 262 | for result in results: 263 | for indices in np.array(result["exp_ids"])[:, :, layer_num].tolist(): 264 | assert len(indices) == 8 265 | for i in indices: 266 | single_counter[i] += 1 267 | for (a, b) in pairwise(indices): 268 | pairwise_counter[(a, b)] += 1 269 | pairwise_counter[(b, a)] += 1 270 | pairwise_probs = { 271 | (a, b): pairwise_counter[(a, b)] / single_counter[a] 272 | for (a, b) in pairwise_counter 273 | } 274 | 275 | new_idx_to_orig_idx = [] 276 | N = 16 277 | 278 | for (a, b), p in sorted(pairwise_probs.items(), key=lambda x: -x[1]): 279 | if a not in new_idx_to_orig_idx: 280 | new_idx_to_orig_idx.append(a) 281 | if b not in new_idx_to_orig_idx: 282 | new_idx_to_orig_idx.append(b) 283 | if len(new_idx_to_orig_idx) == N: break 284 | 285 | scores = np.zeros((N, N)) 286 | labels_x, labels_y = [], [] 287 | for i in range(N): 288 | labels_x.append(new_idx_to_orig_idx[i]) 289 | for j in range(N): 290 | scores[i, j] = pairwise_probs.get( 291 | (new_idx_to_orig_idx[i], new_idx_to_orig_idx[j]), 0) * 100 292 | if i == 0: 293 | labels_y.append(new_idx_to_orig_idx[j]) 294 | 295 | import matplotlib.pylab as plt 296 | import seaborn as sns 297 | # ax = sns.heatmap(scores, cmap="Reds", linewidth=.5, center=30, vmax=60, xticklabels=labels_x, yticklabels=labels_y) 298 | from matplotlib.colors import LinearSegmentedColormap 299 | # Define a custom colormap using shades of #F0539B 300 | cmap = LinearSegmentedColormap.from_list("custom_cmap", ["white", "#F0539B", "#4A0033"], N=256) 301 | 302 | # Generate the heatmap with the custom colormap 303 | ax = sns.heatmap( 304 | scores, 305 | cmap=cmap, 306 | linewidth=0.5, 307 | center=30, 308 | vmax=60, 309 | xticklabels=labels_x, 310 | yticklabels=labels_y, 311 | cbar_kws={'ticks': [0, 15, 30, 45, 60]} # Optional: Customize color bar ticks 312 | ) 313 | # Increase tick size 314 | 
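# Interpreting the heatmap built above: cell (i, j) estimates, in percent, the
# probability that expert j is picked rank-adjacent to expert i within a token's
# top-8 at this layer, i.e. pairwise_counter[(i, j)] / single_counter[i]. Axis
# labels are original expert ids; the 16 experts shown are those appearing in
# the highest-probability co-activation pairs, collected greedily in descending
# order of pairwise probability.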
ax.tick_params(axis='both', which='major', labelsize=FONTSIZE) 315 | # Increase colorbar ticks 316 | cbar = ax.collections[0].colorbar 317 | cbar.ax.tick_params(labelsize=FONTSIZE) 318 | # This sets the yticks "upright" with 0, as opposed to sideways with 90. 319 | plt.yticks(rotation=0) 320 | plt.xticks(rotation=-90) 321 | # Set title 322 | plt.title(f"Layer {layer_num}", fontsize=FONTSIZE, fontweight='bold') 323 | plt.savefig(os.path.join(args.fig_dir, f"layer_{layer_num}_heatmap.png")) 324 | plt.savefig(os.path.join(args.fig_dir, f"layer_{layer_num}_heatmap.pdf")) 325 | 326 | def do_token_analysis(args, tex_format=True): 327 | with open(os.path.join(args.out_dir, "c4_results.jsonl"), "rb") as f: 328 | results = [] 329 | for line in f: 330 | try: 331 | results.append(json.loads(line)) 332 | except Exception as e: 333 | print("Failed to load line", e) 334 | break 335 | assert args.model == "olmoe" 336 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924", token=token) 337 | 338 | input_id_to_exp = defaultdict(list) 339 | gt_next_token_to_exp = defaultdict(list) 340 | predicted_next_token_to_exp = defaultdict(list) 341 | 342 | for result in results: 343 | input_ids = result["input_ids"] 344 | gt_next_token_ids = input_ids[1:] 345 | predicted_next_token_ids = result["predicted_token_ids"] 346 | exp_ids = np.array(result["exp_ids"])[:, 0, args.layer_num].tolist() 347 | 348 | for _id, exp_id in zip(input_ids, exp_ids): 349 | input_id_to_exp[_id].append(exp_id) 350 | for _id, exp_id in zip(gt_next_token_ids, exp_ids): 351 | gt_next_token_to_exp[_id].append(exp_id) 352 | for _id, exp_id in zip(predicted_next_token_ids, exp_ids): 353 | predicted_next_token_to_exp[_id].append(exp_id) 354 | 355 | def print_avg(id_to_exp): 356 | probs = [] 357 | exp_id_to_vocabs = defaultdict(list) 358 | 359 | for _id, exp_ids in id_to_exp.items(): 360 | most_freq_id, val = sorted(Counter(exp_ids).items(), key=lambda x: -x[1])[0] 361 | probs.append(val / len(exp_ids)) 362 | if len(exp_ids) >= 10: 363 | exp_id_to_vocabs[most_freq_id].append((_id, val/len(exp_ids))) 364 | 365 | # if you want to allow token IDs to appear for multiple experts: 366 | """ 367 | c = Counter(exp_ids) 368 | # import pdb; pdb.set_trace() 369 | for exp_id in c: 370 | if len(exp_ids) >= 10: 371 | exp_id_to_vocabs[exp_id].append((_id, c[exp_id] / len(exp_ids))) 372 | """ 373 | 374 | print("Average probability:", np.mean(probs)) 375 | return exp_id_to_vocabs 376 | 377 | exp_id_to_vocabs = print_avg(input_id_to_exp) 378 | print_avg(gt_next_token_to_exp) 379 | exp_id_to_predicted_vocabs = print_avg(predicted_next_token_to_exp) 380 | 381 | with open(f"exp_id_to_vocabs_layer{args.layer_num}.txt", "w") as f: 382 | #for exp_id, vocabs in sorted(exp_id_to_vocabs.items(), key=lambda x: -np.mean([p for _, p in x[1]])): 383 | for exp_id, vocabs in sorted(exp_id_to_vocabs.items(), key=lambda x: -np.mean([p for _, p in x[1]]) + -np.mean([p for _, p in exp_id_to_predicted_vocabs[x[0]]])): 384 | #for exp_id, vocabs in sorted(exp_id_to_vocabs.items()): 385 | text = "exp_id: %d" % exp_id 386 | for vocab, p in sorted(vocabs, key=lambda x: -x[1])[:15]: 387 | if tex_format: 388 | text += " \colorbox{lightOlmoeYellow}{%s} (%d\\%%)" % (tokenizer._decode(vocab), 100*p) # + str(vocab) # for unknowns 389 | else: 390 | text += " %s (%d%%)" % (tokenizer._decode(vocab), 100*p) 391 | f.write(text + "\n\n") 392 | 393 | with open(f"exp_id_to_predicted_vocabs_layer{args.layer_num}.txt", "w") as f: 394 | #for exp_id, vocabs in 
sorted(exp_id_to_predicted_vocabs.items(), key=lambda x: -np.mean([p for _, p in x[1]])): 395 | for exp_id, vocabs in sorted(exp_id_to_predicted_vocabs.items(), key=lambda x: -np.mean([p for _, p in x[1]]) + -np.mean([p for _, p in exp_id_to_vocabs[x[0]]])): 396 | #for exp_id, vocabs in sorted(exp_id_to_predicted_vocabs.items()): 397 | text = "exp_id: %d" % exp_id 398 | for vocab, p in sorted(vocabs, key=lambda x: -x[1])[:15]: 399 | if tex_format: 400 | text += " \colorbox{lightOlmoeYellow}{%s} (%d\\%%)" % (tokenizer._decode(vocab), 100*p) # + str(vocab) # for unknowns 401 | else: 402 | text += " %s (%d%%)" % (tokenizer._decode(vocab), 100*p) 403 | f.write(text + "\n\n") 404 | 405 | def do_token_analysis_layers(args, tex_format=True): 406 | FONTSIZE = 28 407 | with open(os.path.join(args.out_dir, "c4_results.jsonl"), "rb") as f: 408 | results = [] 409 | for line in f: 410 | results.append(json.loads(line)) 411 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924", token=token) 412 | 413 | def print_avg(id_to_exp): 414 | probs = [] 415 | exp_id_to_vocabs = defaultdict(list) 416 | 417 | for _id, exp_ids in id_to_exp.items(): 418 | most_freq_id, val = sorted(Counter(exp_ids).items(), key=lambda x: -x[1])[0] 419 | probs.append(val / len(exp_ids)) 420 | if len(exp_ids) >= 10: 421 | # if len(exp_ids) >= 8: 422 | exp_id_to_vocabs[most_freq_id].append((_id, val/len(exp_ids))) 423 | 424 | print("Average probability:", np.mean(probs)) 425 | return np.mean(probs) 426 | 427 | layer_num_to_probs = {} 428 | for layer_num in list(range(16)): 429 | print(f"Layer {layer_num}") 430 | input_id_to_exp = defaultdict(list) 431 | gt_next_token_to_exp = defaultdict(list) 432 | predicted_next_token_to_exp = defaultdict(list) 433 | for result in results: 434 | input_ids = result["input_ids"] 435 | gt_next_token_ids = input_ids[1:] 436 | predicted_next_token_ids = result["predicted_token_ids"] 437 | exp_ids = np.array(result["exp_ids"])[:, 0, layer_num].tolist() 438 | 439 | for _id, exp_id in zip(input_ids, exp_ids): 440 | input_id_to_exp[_id].append(exp_id) 441 | for _id, exp_id in zip(gt_next_token_ids, exp_ids): 442 | gt_next_token_to_exp[_id].append(exp_id) 443 | for _id, exp_id in zip(predicted_next_token_ids, exp_ids): 444 | predicted_next_token_to_exp[_id].append(exp_id) 445 | 446 | input_prob = print_avg(input_id_to_exp) 447 | gt_prob = print_avg(gt_next_token_to_exp) 448 | output_prob = print_avg(predicted_next_token_to_exp) 449 | 450 | layer_num_to_probs[layer_num] = (input_prob, gt_prob, output_prob) 451 | 452 | import matplotlib.pylab as plt 453 | fig, ax = plt.subplots(figsize=(12, 8)) 454 | x = list(range(16)) 455 | y = [layer_num_to_probs[i][0] * 100 for i in x] 456 | ax.plot(x, y, label="Input tokens", color="#F0539B", marker='o', markersize=12, linewidth=6) 457 | y = [layer_num_to_probs[i][2] * 100 for i in x] 458 | ax.plot(x, y, label="Predicted output tokens", color="#2E3168", marker='o', markersize=12, linewidth=6) 459 | y = [layer_num_to_probs[i][1] * 100 for i in x] 460 | ax.plot(x, y, label="Ground-truth output tokens", color="#43C5E0", marker='o', markersize=12, linewidth=6) 461 | ax.tick_params(axis='both', which='major', labelsize=FONTSIZE) 462 | ax.set_xticks(x) 463 | ax.set_xlabel("Layer ID", fontsize=FONTSIZE, fontweight='bold') 464 | ax.set_ylabel("Vocabulary specialization (%)", fontsize=FONTSIZE, fontweight='bold') 465 | ax.spines['top'].set_visible(False) 466 | ax.spines['right'].set_visible(False) 467 | plt.legend(frameon=True, fontsize=FONTSIZE, columnspacing=0.4, 
labelspacing=0.4, loc="lower right") 468 | plt.savefig(os.path.join(args.fig_dir, f"layerwise_token_analysis.pdf")) 469 | plt.savefig(os.path.join(args.fig_dir, f"layerwise_token_analysis.png")) 470 | 471 | def do_token_analysis_experts(args, tex_format=True, do_sort=False): 472 | FONTSIZE = 28 473 | with open(os.path.join(args.out_dir, "c4_results.jsonl"), "rb") as f: 474 | results = [] 475 | for line in f: 476 | results.append(json.loads(line)) 477 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924", token=token) 478 | 479 | def print_avg(id_to_exp): 480 | probs = [] 481 | exp_id_to_probs = defaultdict(list) 482 | for _id, exp_ids in id_to_exp.items(): 483 | most_freq_id, val = sorted(Counter(exp_ids).items(), key=lambda x: -x[1])[0] 484 | if len(exp_ids) >= 10: 485 | probs.append(val / len(exp_ids)) 486 | exp_id_to_probs[most_freq_id].append(val / len(exp_ids)) 487 | # c = Counter(exp_ids) 488 | # import pdb; pdb.set_trace() 489 | # for exp_id in c: 490 | # exp_id_to_probs[exp_id].append(c[exp_id] / len(exp_ids)) 491 | 492 | # avg 493 | exp_id_to_probs = {k: (np.mean(v), len(v)) for k, v in exp_id_to_probs.items()} 494 | print(exp_id_to_probs) 495 | print("Average probability:", np.mean(probs)) 496 | return exp_id_to_probs, np.mean(probs) 497 | 498 | input_id_to_exp = defaultdict(list) 499 | gt_next_token_to_exp = defaultdict(list) 500 | predicted_next_token_to_exp = defaultdict(list) 501 | 502 | for result in results: 503 | input_ids = result["input_ids"] 504 | gt_next_token_ids = input_ids[1:] 505 | predicted_next_token_ids = result["predicted_token_ids"] 506 | exp_ids = np.array(result["exp_ids"])[:, 0, args.layer_num].tolist() 507 | 508 | for _id, exp_id in zip(input_ids, exp_ids): 509 | input_id_to_exp[_id].append(exp_id) 510 | for _id, exp_id in zip(gt_next_token_ids, exp_ids): 511 | gt_next_token_to_exp[_id].append(exp_id) 512 | for _id, exp_id in zip(predicted_next_token_ids, exp_ids): 513 | predicted_next_token_to_exp[_id].append(exp_id) 514 | 515 | input_exp_id_to_probs, input_prob = print_avg(input_id_to_exp) 516 | gt_exp_id_to_probs, gt_prob = print_avg(gt_next_token_to_exp) 517 | predicted_exp_id_to_probs, pred_prob = print_avg(predicted_next_token_to_exp) 518 | 519 | import matplotlib.pylab as plt 520 | fig, ax = plt.subplots(figsize=(12, 8)) 521 | x = list(range(64)) 522 | # Sort x by input_exp_id_to_probs 523 | x_sorted = sorted(x, key=lambda i: -input_exp_id_to_probs[i][0]) if do_sort else x 524 | y = [input_exp_id_to_probs[i][0] * 100 for i in x_sorted] 525 | ax.plot(x, y, label="Input tokens", color="#F0539B", marker='o', markersize=12, linewidth=6) 526 | y = [gt_exp_id_to_probs[i][0] * 100 for i in x_sorted] 527 | ax.plot(x, y, label="Ground-truth output tokens", color="#2E3168", marker='o', markersize=12, linewidth=6) 528 | y = [predicted_exp_id_to_probs[i][0] * 100 for i in x_sorted] 529 | ax.plot(x, y, label="Predicted output tokens", color="#43C5E0", marker='o', markersize=12, linewidth=6) 530 | # Draw horizontal prob lines for input_prob etc 531 | ax.axhline(input_prob * 100, color="#F0539B", linestyle='--', linewidth=3) 532 | ax.axhline(gt_prob * 100, color="#2E3168", linestyle='--', linewidth=3) 533 | ax.axhline(pred_prob * 100, color="#43C5E0", linestyle='--', linewidth=3) 534 | ax.tick_params(axis='both', which='major', labelsize=FONTSIZE) 535 | ax.set_xticks(x, x_sorted) 536 | ax.set_xlabel("Expert ID", fontsize=FONTSIZE, fontweight='bold') 537 | ax.set_ylabel("Vocabulary specialization (%)", fontsize=FONTSIZE, fontweight='bold') 538 | 
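# The dashed horizontal lines above mark the token-averaged specialization for
# each series (the np.mean(probs) values returned by print_avg), so per-expert
# values can be read against the layer-wide average.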
ax.spines['top'].set_visible(False) 539 | ax.spines['right'].set_visible(False) 540 | plt.title(f"Layer {args.layer_num}", fontsize=FONTSIZE, fontweight='bold') 541 | plt.legend(frameon=True, fontsize=FONTSIZE, columnspacing=0.4, labelspacing=0.4, loc="lower right") 542 | plt.savefig(os.path.join(args.fig_dir, f"vocabulary_specialization_experts.pdf")) 543 | plt.savefig(os.path.join(args.fig_dir, f"vocabulary_specialization_experts.png")) 544 | 545 | 546 | def do_token_analysis_layers_experts(args, do_sort=False, normalize=False): 547 | """normalize does not make a big difference, hence not used""" 548 | from matplotlib import rcParams 549 | rcParams.update({'figure.autolayout': True}) 550 | 551 | with open(os.path.join(args.out_dir, "c4_results.jsonl"), "r") as f: 552 | results = [] 553 | for line in f: 554 | results.append(json.loads(line)) 555 | 556 | assert args.topk in [1, 2, 8] 557 | 558 | with open(os.path.join(args.out_dir, "c4_results.jsonl"), "rb") as f: 559 | results = [] 560 | for line in f: 561 | results.append(json.loads(line)) 562 | 563 | import matplotlib.pylab as plt 564 | if args.model == "olmoe": 565 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924", token=token) 566 | random_prob = args.topk/64 567 | num_layers = 16 568 | # fig, axes = plt.subplots(figsize=(32, 8), ncols=2, nrows=1, sharey=True, layout='constrained', width_ratios=[1, 2]) 569 | fig, axes = plt.subplots(figsize=(32, 8), ncols=2, nrows=1, sharey=False, layout='constrained', width_ratios=[1, 2]) 570 | else: 571 | tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1", token=token) 572 | random_prob = args.topk/8 573 | num_layers = 32 574 | # fig, axes = plt.subplots(figsize=(32, 8), ncols=2, nrows=1, sharey=True, layout='constrained', width_ratios=[2, 1]) 575 | fig, axes = plt.subplots(figsize=(32, 8), ncols=2, nrows=1, sharey=False, layout='constrained', width_ratios=[2, 1]) 576 | 577 | def print_avg(id_to_exp): 578 | probs = [] 579 | exp_id_to_probs = defaultdict(list) 580 | for _id, exp_ids in id_to_exp.items(): 581 | most_freq_id, val = sorted(Counter(exp_ids).items(), key=lambda x: -x[1])[0] 582 | max_possible = len(exp_ids) // args.topk 583 | if normalize: 584 | # use random prob to normalize such that 0 is random and 1 is perfect 585 | prob = (val / max_possible - random_prob) / (1 - random_prob) 586 | else: 587 | prob = val / max_possible 588 | probs.append(prob) 589 | exp_id_to_probs[most_freq_id].append(prob) 590 | exp_id_to_probs = {k: (np.mean(v), len(v)) for k, v in exp_id_to_probs.items()} 591 | print("Average probability:", np.mean(probs)) 592 | return exp_id_to_probs, np.mean(probs) 593 | 594 | layer_num_to_probs = {} 595 | for layer_num in list(range(num_layers)): 596 | print(f"Layer {layer_num}") 597 | input_id_to_exp = defaultdict(list) 598 | gt_next_token_to_exp = defaultdict(list) 599 | predicted_next_token_to_exp = defaultdict(list) 600 | for result in results: 601 | input_ids = result["input_ids"] 602 | gt_next_token_ids = input_ids[1:] 603 | predicted_next_token_ids = result["predicted_token_ids"] 604 | exp_ids = np.array(result["exp_ids"])[:, :args.topk, layer_num].tolist() 605 | 606 | for _id, exp_id in zip(input_ids, exp_ids): 607 | # input_id_to_exp[_id].append(exp_id) 608 | input_id_to_exp[_id].extend(exp_id) 609 | for _id, exp_id in zip(gt_next_token_ids, exp_ids): 610 | # gt_next_token_to_exp[_id].append(exp_id) 611 | gt_next_token_to_exp[_id].extend(exp_id) 612 | for _id, exp_id in zip(predicted_next_token_ids, exp_ids): 613 | # 
predicted_next_token_to_exp[_id].append(exp_id) 614 | predicted_next_token_to_exp[_id].extend(exp_id) 615 | 616 | input_prob = print_avg(input_id_to_exp)[1] 617 | gt_prob = print_avg(gt_next_token_to_exp)[1] 618 | output_prob = print_avg(predicted_next_token_to_exp)[1] 619 | layer_num_to_probs[layer_num] = (input_prob, gt_prob, output_prob) 620 | 621 | # Specialization for one specific layer 622 | if layer_num == args.layer_num: 623 | input_exp_id_to_probs, input_prob_chosen = print_avg(input_id_to_exp) 624 | gt_exp_id_to_probs, gt_prob_chosen = print_avg(gt_next_token_to_exp) 625 | predicted_exp_id_to_probs, pred_prob_chosen = print_avg(predicted_next_token_to_exp) 626 | 627 | FONTSIZE = 34 628 | ### Layer spec ### 629 | ax = axes[0] 630 | x = list(range(num_layers)) 631 | y = [layer_num_to_probs[i][0] * 100 for i in x] 632 | ax.plot(x, y, label="Input tokens", color="#F0539B", marker='o', markersize=12, linewidth=6) 633 | y = [layer_num_to_probs[i][2] * 100 for i in x] 634 | ax.plot(x, y, label="Predicted output tokens", color="#43C5E0", marker='o', markersize=12, linewidth=6) 635 | y = [layer_num_to_probs[i][1] * 100 for i in x] 636 | ax.plot(x, y, label="Ground-truth output tokens", color="#2E3168", marker='o', markersize=12, linewidth=6) 637 | ax.tick_params(axis='both', which='major', labelsize=FONTSIZE) 638 | #ax.set_xlim(0.5, num_layers-1) 639 | ax.margins(x=0.01) 640 | ax.set_xticks(x) 641 | ax.set_xlabel("Layer ID", fontsize=FONTSIZE, fontweight='bold') 642 | ax.set_ylabel("Vocabulary specialization (%) ", fontsize=FONTSIZE, fontweight='bold') 643 | # ax.set_title("Layer-wise token analysis", fontsize=FONTSIZE, fontweight='bold') 644 | ax.spines['top'].set_visible(False) 645 | ax.spines['right'].set_visible(False) 646 | ax.set_title("Per layer", fontsize=FONTSIZE, fontweight='bold') 647 | # ax.legend(frameon=True, fontsize=FONTSIZE, columnspacing=0.4, labelspacing=0.4) 648 | ### Expert spec ### 649 | """Line Plot 650 | ax = axes[1] 651 | if args.model == "olmoe": 652 | x = list(range(64))[:32] # limit to 32 experts 653 | else: 654 | x = list(range(8)) 655 | # Sort x by input_exp_id_to_probs 656 | x_sorted = sorted(x, key=lambda i: -input_exp_id_to_probs[i][0]) if do_sort else x 657 | y = [input_exp_id_to_probs[i][0] * 100 for i in x_sorted] 658 | ax.plot(x, y, label="Input tokens", color="#F0539B", marker='o', markersize=16, linewidth=6, linestyle='dotted') 659 | # Scatter instead 660 | # ax.scatter(x, y, color="#F0539B", s=200, label="Input tokens") 661 | y = [gt_exp_id_to_probs[i][0] * 100 for i in x_sorted] 662 | ax.plot(x, y, label="Predicted output tokens", color="#43C5E0", marker='o', markersize=16, linewidth=6, linestyle='dotted') 663 | # ax.scatter(x, y, color="#43C5E0", s=200, label="Predicted output tokens") 664 | y = [predicted_exp_id_to_probs[i][0] * 100 for i in x_sorted] 665 | ax.plot(x, y, label="Ground-truth output tokens", color="#2E3168", marker='o', markersize=16, linewidth=6, linestyle='dotted') 666 | # ax.scatter(x, y, color="#2E3168", s=200, label="Ground-truth output tokens") 667 | # Draw horizontal prob lines for input_prob etc 668 | ax.axhline(input_prob_chosen * 100, color="#F0539B", linestyle='--', linewidth=6, alpha=0.8) 669 | ax.axhline(pred_prob_chosen * 100, color="#43C5E0", linestyle='--', linewidth=6, alpha=0.8) 670 | ax.axhline(gt_prob_chosen * 100, color="#2E3168", linestyle='--', linewidth=6, alpha=0.8) 671 | ax.tick_params(axis='both', which='major', labelsize=FONTSIZE) 672 | #ax.set_xlim(0, x[-1]) 673 | ax.margins(x=0.005) 674 | 
ax.set_xticks(x, x_sorted) 675 | ax.set_xlabel("Expert ID", fontsize=FONTSIZE, fontweight='bold') 676 | # ax.set_ylabel("Vocabulary specialization (%)", fontsize=FONTSIZE, fontweight='bold') 677 | # ax.set_title("Layer-wise token analysis", fontsize=FONTSIZE, fontweight='bold') 678 | ax.spines['top'].set_visible(False) 679 | ax.spines['right'].set_visible(False) 680 | ax.set_title(f"Per expert in layer {args.layer_num}", fontsize=FONTSIZE, fontweight='bold') 681 | plt.legend(frameon=True, fontsize=FONTSIZE, columnspacing=0.4, labelspacing=0.4, loc="lower right") 682 | """ 683 | #"""Bar Plot 684 | ax = axes[1] 685 | if args.model == "olmoe": 686 | x = [0, 1, 2, 3, 4, 5, 6, 7, 8, 27, 37, 58] 687 | else: 688 | x = list(range(8)) 689 | # Sort x by input_exp_id_to_probs 690 | x_sorted = list(range(len(x))) 691 | y = [input_exp_id_to_probs[i][0] * 100 for i in x] 692 | # ax.bar(x, y, color="#F0539B", label="Input tokens") 693 | # Space them out 694 | x_bar = np.array(x_sorted) - 0.275 695 | ax.bar(x_bar, y, color="#F0539B", label="Input token ID", width=0.25, alpha=0.8) 696 | y = [predicted_exp_id_to_probs[i][0] * 100 for i in x] 697 | # ax.bar(x, y, color="#43C5E0", label="Predicted output tokens") 698 | x_bar = np.array(x_sorted) 699 | ax.bar(x_bar, y, color="#43C5E0", label="Predicted output token ID", width=0.25, alpha=0.8) 700 | y = [gt_exp_id_to_probs[i][0] * 100 for i in x] 701 | # ax.bar(x, y, color="#2E3168", label="Ground-truth output tokens") 702 | x_bar = np.array(x_sorted) + 0.275 703 | ax.bar(x_bar, y, color="#2E3168", label="Ground-truth output token ID", width=0.25, alpha=0.8) 704 | # Draw horizontal prob lines for input_prob etc 705 | ax.axhline(input_prob_chosen * 100, color="#F0539B", linestyle='--', linewidth=6) 706 | ax.axhline(pred_prob_chosen * 100, color="#43C5E0", linestyle='--', linewidth=6) 707 | ax.axhline(gt_prob_chosen * 100, color="#2E3168", linestyle='--', linewidth=6) 708 | ax.tick_params(axis='both', which='major', labelsize=FONTSIZE) 709 | ax.margins(x=0.005) 710 | ax.set_xticks(x_sorted, x) 711 | ax.set_xlabel("Expert ID", fontsize=FONTSIZE, fontweight='bold') 712 | ax.spines['top'].set_visible(False) 713 | ax.spines['right'].set_visible(False) 714 | ax.set_title(f"Per expert in layer {args.layer_num}", fontsize=FONTSIZE, fontweight='bold') 715 | plt.legend(frameon=True, fontsize=FONTSIZE, columnspacing=0.4, labelspacing=0.4, loc="lower right") 716 | #""" 717 | plt.savefig(os.path.join(args.fig_dir, f"vocabulary_specialization_top{args.topk}_{args.model}.pdf"), bbox_inches='tight') 718 | plt.savefig(os.path.join(args.fig_dir, f"vocabulary_specialization_top{args.topk}_{args.model}.png"), bbox_inches='tight') 719 | 720 | 721 | if __name__ == '__main__': 722 | """ 723 | First, run the following to save model outputs: 724 | `python moe.py --do_inference --do_inference_all_ckpts` 725 | (skip `--do_inference_all_ckpts` if you won't run ckpt analysis) 726 | 727 | Then, to do ckpt analysis (Router Saturation): 728 | `python moe.py --do_ckpt_analysis --topk 1 729 | python moe.py --do_ckpt_analysis --topk 8` 730 | 731 | To do coactivation analysis: 732 | `python moe.py --do_coactivation_analysis --layer_num 0 733 | python moe.py --do_coactivation_analysis --layer_num 7 734 | python moe.py --do_coactivation_analysis --layer_num 15` 735 | 736 | To do token analysis (Vocabulary specialization): 737 | `python moe.py --do_token_analysis` 738 | `python moe.py --do_token_analysis_layers_experts --topk 1` 739 | 740 | For all commands: use `--tokenized_path`, 
`--out_dir`, `--fig_dir` to control where data, outputs, and figures are saved 741 | To use Mixtral, use `--model mixtral` (also requires rerunning do_inference) 742 | """ 743 | 744 | parser = argparse.ArgumentParser(prog="moe.py", description="Run analyses on OLMoE") 745 | parser.add_argument("--do_inference", action="store_true") 746 | parser.add_argument("--do_inference_all_ckpts", action="store_true") 747 | 748 | parser.add_argument("--do_ckpt_analysis", action="store_true") 749 | parser.add_argument("--do_coactivation_analysis", action="store_true") 750 | parser.add_argument("--do_token_analysis", action="store_true") 751 | parser.add_argument("--do_token_analysis_layers", action="store_true") 752 | parser.add_argument("--do_token_analysis_experts", action="store_true") 753 | parser.add_argument("--do_token_analysis_layers_experts", action="store_true") 754 | 755 | parser.add_argument("--tokenized_path", default="c4_validation.jsonl", type=str, help="path to save tokenized C4 validation data") 756 | parser.add_argument("--out_dir", default="out", type=str, help="directory to save outputs from the model") 757 | parser.add_argument("--fig_dir", default="figs", type=str, help="directory to save figures") 758 | 759 | parser.add_argument("--topk", choices=[1, 2, 8, 18], default=1, type=int) 760 | parser.add_argument("--layer_num", choices=list(range(16)), default=7, type=int) 761 | parser.add_argument("--model", default="olmoe", type=str, help="Which model to analyze: olmoe or mixtral") 762 | 763 | args = parser.parse_args() 764 | 765 | if args.do_inference: 766 | if not os.path.exists(args.tokenized_path): 767 | tokenize_c4(args.tokenized_path, model=args.model) 768 | if not os.path.exists(args.out_dir): 769 | os.mkdir(args.out_dir) 770 | do_inference(args, run_all_checkpoints=args.do_inference_all_ckpts) 771 | 772 | if args.do_ckpt_analysis or args.do_coactivation_analysis or args.do_token_analysis or args.do_token_analysis_layers or args.do_token_analysis_experts or args.do_token_analysis_layers_experts: 773 | if not os.path.exists(args.fig_dir): 774 | os.mkdir(args.fig_dir) 775 | 776 | if args.do_ckpt_analysis: 777 | do_ckpt_analysis(args) 778 | 779 | if args.do_coactivation_analysis: 780 | do_coactivation_analysis(args) 781 | 782 | if args.do_token_analysis: 783 | do_token_analysis(args) 784 | 785 | if args.do_token_analysis_layers: 786 | do_token_analysis_layers(args) 787 | 788 | if args.do_token_analysis_experts: 789 | do_token_analysis_experts(args) 790 | 791 | if args.do_token_analysis_layers_experts: 792 | do_token_analysis_layers_experts(args) -------------------------------------------------------------------------------- /scripts/run_routing_analysis.py: -------------------------------------------------------------------------------- 1 | from transformers import OlmoeForCausalLM, AutoTokenizer, AutoModelForCausalLM 2 | import torch 3 | import time 4 | import pickle as pkl 5 | import json 6 | import numpy as np 7 | from collections import defaultdict, Counter 8 | from tqdm import tqdm 9 | from pathlib import Path 10 | 11 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 12 | token = None 13 | 14 | start_time = time.time() 15 | 16 | 17 | # Adapted from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func 18 | # Fixed to aggregate over all layers 19 | def load_balancing_loss_func( 20 | gate_logits: torch.Tensor, num_experts: int = None, top_k=2 21 | ) -> float: 22 | if gate_logits is None or not isinstance(gate_logits, tuple): 23 | return 0 24 | 25 | if isinstance(gate_logits, tuple): 26 | compute_device = gate_logits[0].device 27 | concatenated_gate_logits = 
torch.stack([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=1) 28 | 29 | routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) 30 | 31 | _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) 32 | 33 | expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) 34 | 35 | # Compute the percentage of tokens routed to each expert 36 | tokens_per_expert = torch.mean(expert_mask.float(), dim=0) 37 | 38 | # Compute the average probability of routing to these experts 39 | router_prob_per_expert = torch.mean(routing_weights, dim=0) 40 | 41 | overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(-2)) / len(gate_logits) 42 | return overall_loss * num_experts 43 | 44 | 45 | def load_analysis_data(tokenizer, domain, bs): 46 | np.random.seed(2024) 47 | tokens = [] 48 | 49 | if domain == "tulu": 50 | data_path = "routing_output/text/tulu-v3.1.jsonl" 51 | with open(data_path) as f: 52 | for line in f: 53 | text = json.loads(line)["text"] 54 | tokens = tokenizer(text, truncation=False)["input_ids"] 55 | while len(tokens) >= bs: 56 | yield tokens[:bs] 57 | tokens = tokens[bs:] 58 | yield tokens 59 | else: 60 | data_path = f"routing_output/text/{domain}_texts.txt" 61 | with open(data_path) as f: 62 | text = f.read() 63 | tokens = tokenizer(text, truncation=False)["input_ids"] 64 | while len(tokens) >= bs: 65 | yield tokens[:bs] 66 | tokens = tokens[bs:] 67 | 68 | def load_sft_model(): 69 | DEVICE = "cuda" 70 | model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924-SFT", token=token).to(DEVICE) 71 | model.eval() 72 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924-SFT", token=token) 73 | return model, tokenizer 74 | 75 | 76 | def load_dpo_model(): 77 | DEVICE = "cuda" 78 | model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924-Instruct", token=token).to(DEVICE) 79 | model.eval() 80 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924-Instruct", token=token) 81 | return model, tokenizer 82 | 83 | 84 | def load_model(): 85 | DEVICE = "cuda" 86 | model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924", token=token).to(DEVICE) 87 | model.eval() 88 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924", token=token) 89 | return model, tokenizer 90 | 91 | def load_model_mistral(): 92 | model_id = "mistralai/Mixtral-8x7B-v0.1" 93 | model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto', token=token) 94 | model.eval() 95 | tokenizer = AutoTokenizer.from_pretrained(model_id, token=token) 96 | return model, tokenizer 97 | 98 | 99 | def print_expert_percentage(exp_counts): 100 | total = sum(exp_counts.values()) 101 | for eid, ecount in exp_counts.most_common(): 102 | print(f"Expert {eid}: {ecount/total*100:.2f}") 103 | 104 | 105 | def run_analysis(domain, model_name=None): 106 | layer_counters = defaultdict(Counter) 107 | crosslayer_counters = defaultdict(Counter) 108 | 109 | eid2token_layer0 = defaultdict(Counter) 110 | eid2token_layer7 = defaultdict(Counter) 111 | eid2token_layer15 = defaultdict(Counter) 112 | 113 | total_token_count = 0 114 | 115 | aux_losses = [] 116 | 117 | # for each expert, what are some common tokens that are assigned to it? 
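# (eid2token_layer0/7/15 above collect exactly this: per expert, a Counter of the
# input token ids routed to it at layers 0, 7, and 15. A hypothetical entry might
# look like eid2token_layer7[42] == Counter({279: 517, 317: 248, ...}), with
# tokenizer ids as keys.)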
118 | # write a code to count that and print it out: {expert_id: Counter({token1: 32, token2: 23})...} 119 | for i, input_ids in tqdm(enumerate(load_analysis_data(tokenizer, domain=domain, bs=length))): 120 | input_ids = torch.LongTensor(input_ids).reshape(1, -1).to(DEVICE) 121 | out = model(input_ids=input_ids, output_router_logits=True) 122 | 123 | aux_loss = load_balancing_loss_func( 124 | out["router_logits"], 125 | model.num_experts, 126 | model.num_experts_per_tok, 127 | ) 128 | aux_losses.append(aux_loss.cpu().item()) 129 | 130 | # input id shapes: 2048 seqlen 131 | input_ids = input_ids[0].detach().cpu().numpy().tolist() 132 | total_token_count += len(input_ids) 133 | 134 | # 16 layer, (2048 tokens, 64 experts) 135 | router_logits = [l.detach().cpu().numpy() for l in out["router_logits"]] 136 | # 2048 tokens, 8 experts, 16 layers 137 | if model_name == "mistral": 138 | exp_ids = np.stack([np.argsort(-logits, -1)[:, :2].tolist() for logits in router_logits], -1) 139 | elif model_name.startswith("olmoe"): 140 | exp_ids = np.stack([np.argsort(-logits, -1)[:, :8].tolist() for logits in router_logits], -1) 141 | # 2048 tokens, 8 experts 142 | exp_ids_layer0 = exp_ids[:, :, 0] 143 | exp_ids_layer7 = exp_ids[:, :, 7] 144 | exp_ids_layer15 = exp_ids[:, :, 15] 145 | 146 | for id, token in enumerate(input_ids): 147 | experts = exp_ids_layer0[id, :] 148 | for e in experts: 149 | eid2token_layer0[e][token] += 1 150 | for e in exp_ids_layer7[id, :]: 151 | eid2token_layer7[e][token] += 1 152 | for e in exp_ids_layer15[id, :]: 153 | eid2token_layer15[e][token] += 1 154 | 155 | for layer in range(exp_ids.shape[2]): 156 | exp_counts = Counter(exp_ids[:, :, layer].flatten()) 157 | layer_counters[layer].update(exp_counts) 158 | 159 | for layer_i in range(exp_ids.shape[2] - 1): 160 | for layer_j in range(exp_ids.shape[2]): 161 | exps_counts = Counter(zip(exp_ids[:, :, layer_i].flatten(), exp_ids[:, :, layer_j].flatten())) 162 | crosslayer_counters[(layer_i, layer_j)].update(exps_counts) 163 | 164 | if total_token_count > 204800: 165 | break 166 | 167 | print(f"Average aux loss: {np.mean(aux_losses)}") 168 | 169 | return layer_counters, crosslayer_counters, eid2token_layer0, eid2token_layer7, eid2token_layer15 170 | 171 | name2finaldata = {"github_oss_with_stack": "github", "arxiv": "arxiv", "c4": "c4", "b3g": "book", "wikipedia": "wikipedia", "tulu": "tulu"} 172 | 173 | if __name__=='__main__': 174 | model_name = "olmoe" 175 | print(model_name) 176 | if model_name == "mistral": 177 | model, tokenizer = load_model_mistral() 178 | elif model_name == "olmoe-sft": 179 | model, tokenizer = load_sft_model() 180 | elif model_name == "olmoe-dpo": 181 | model, tokenizer = load_dpo_model() 182 | elif model_name == "olmoe": 183 | model, tokenizer = load_model() 184 | 185 | length = 2048 186 | for domain in tqdm(["tulu", "github_oss_with_stack", "arxiv", "c4", "b3g", "wikipedia"]): 187 | print(f"Domain: {domain}") 188 | layer_counters, crosslayer_counters, eid2token_layer0, eid2token_layer7, eid2token_layer15 = run_analysis(domain, model_name) 189 | Path(f"routing_output/{model_name}/expert_counts").mkdir(parents=True, exist_ok=True) 190 | Path(f"routing_output/{model_name}/expert_counts_crosslayer").mkdir(parents=True, exist_ok=True) 191 | Path(f"routing_output/{model_name}/eid2token").mkdir(parents=True, exist_ok=True) 192 | with open(f"routing_output/{model_name}/expert_counts/{name2finaldata[domain]}.pkl", "wb") as f: 193 | pkl.dump([layer_counters[0], layer_counters[7], layer_counters[15]], f) 194 | with 
open(f"routing_output/{model_name}/expert_counts_crosslayer/{name2finaldata[domain]}.pkl", "wb") as f: 195 | pkl.dump([crosslayer_counters[(0, 7)], crosslayer_counters[(7, 15)]], f) 196 | with open(f"routing_output/{model_name}/eid2token/{name2finaldata[domain]}.pkl", "wb") as f: 197 | pkl.dump([eid2token_layer0, eid2token_layer7, eid2token_layer15], f) -------------------------------------------------------------------------------- /scripts/sparsify_ckpt_unsharded.py: -------------------------------------------------------------------------------- 1 | """ 2 | 1. Unshard ckpt using `python /home/niklas/OLMoE/scripts/unshard.py /data/niklas/llm/checkpoints/23485/step954000 /data/niklas/llm/checkpoints/1b-954000-unsharded --safe-tensors --model-only` 3 | 2. Run this script via `python /home/niklas/OLMoE/scripts/sparsify_ckpt_unsharded.py /data/niklas/llm/checkpoints/1b-954000-unsharded/model.safetensors` 4 | """ 5 | import copy 6 | import sys 7 | import torch 8 | from olmo.safetensors_util import safetensors_file_to_state_dict, state_dict_to_safetensors_file 9 | 10 | path = sys.argv[1] 11 | sd = safetensors_file_to_state_dict(path) 12 | tensors = {} 13 | swiglu = True 14 | noise = False 15 | share = False 16 | interleave = False 17 | n_experts = 8 18 | D = 2048 19 | 20 | def noise_injection(weight, noise_ratio=0.5, init_std=0.02): 21 | mask = torch.FloatTensor(weight.size()).uniform_() < noise_ratio 22 | mask = mask.to(weight.device) 23 | rand_weight = torch.nn.init.normal_(copy.deepcopy(weight), mean=0.0, std=init_std) 24 | weight[mask] = rand_weight[mask] 25 | return weight 26 | 27 | for key in list(sd.keys()): 28 | if "ff_proj.weight" in key: 29 | block_num = int(key.split(".")[2]) 30 | if interleave and block_num % 2 == 0: 31 | tensors[key] = sd.pop(key) 32 | continue 33 | new_key = key.replace("ff_proj.weight", "ffn.experts.mlp.w1") 34 | if swiglu: 35 | new_key_v1 = new_key.replace("w1", "v1") 36 | # OLMo takes the F.silu on the second part of the tensor which corresponds to v1 37 | v1, w1 = sd.pop(key).chunk(2, dim=0) # e.g. 
[16384, 2048] 38 | tensors[new_key] = torch.cat([w1] * n_experts, dim=0) 39 | tensors[new_key_v1] = torch.cat([v1] * n_experts, dim=0) 40 | if noise: 41 | tensors[new_key] = noise_injection(tensors[new_key]) 42 | tensors[new_key_v1] = noise_injection(tensors[new_key_v1]) 43 | if share: 44 | share_key = new_key.replace("experts.mlp.w1", "shared_expert.up_proj.weight") 45 | share_key_v1 = new_key_v1.replace("experts.mlp.v1", "shared_expert.gate_proj.weight") 46 | tensors[share_key] = w1 47 | tensors[share_key_v1] = v1 48 | else: 49 | tensors[new_key] = torch.cat([sd.pop(key)] * n_experts, dim=0) 50 | elif ("ff_out.weight" in key) and (key != 'transformer.ff_out.weight'): 51 | block_num = int(key.split(".")[2]) 52 | if interleave and block_num % 2 == 0: 53 | tensors[key] = sd.pop(key) 54 | continue 55 | new_key = key.replace("ff_out.weight", "ffn.experts.mlp.w2") 56 | w = sd.pop(key) 57 | tensors[new_key] = torch.cat([w.t()] * n_experts, dim=0) 58 | if noise: 59 | tensors[new_key] = noise_injection(tensors[new_key]) 60 | if share: 61 | share_key = new_key.replace("experts.mlp.w2", "shared_expert.down_proj.weight") 62 | tensors[share_key] = w 63 | # Add router 64 | router_key = key.replace("ff_out.weight", "ffn.router.layer.weight") 65 | # tensors[router_key] = torch.ones((n_experts, D)).squeeze() # Worse perf 66 | tensors[router_key] = torch.nn.init.normal_(torch.ones((n_experts, D)).squeeze(), std=0.02) 67 | else: 68 | tensors[key] = sd.pop(key) 69 | 70 | state_dict_to_safetensors_file(tensors, path.replace("model.safetensors", "model_sparse.safetensors")) -------------------------------------------------------------------------------- /scripts/wekatransfer/s3weka.sh: -------------------------------------------------------------------------------- 1 | export BUDGET=ai2/oe-training 2 | export S3_BUCKET=ai2-llm 3 | #export S3_PREFIX=preprocessed/olmo-mix/danyh-compiled-v1_7 4 | #export S3_PREFIX=preprocessed/fastdclm 5 | export S3_PREFIX=preprocessed/starcoder/v1-decon-100_to_20k-2star-top_token_030 6 | export WEKA_BUCKET=oe-training-default 7 | #export WEKA_PREFIX=ai2-llm/preprocessed/danyh-compiled-v1_7 8 | #WEKA_PREFIX=ai2-llm/preprocessed/fastdclm 9 | WEKA_PREFIX=ai2-llm/preprocessed/starcoder/v1-decon-100_to_20k-2star-top_token_030 -------------------------------------------------------------------------------- /scripts/wekatransfer/s3weka.yml: -------------------------------------------------------------------------------- 1 | version: v2 2 | budget: {{.Env.BUDGET}} 3 | description: Sync contents of S3 bucket "{{.Env.S3_BUCKET}}" to WEKA bucket "{{.Env.WEKA_BUCKET}}" 4 | tasks: 5 | - name: sync_s3_to_weka 6 | image: 7 | beaker: ai2/cuda11.8-ubuntu20.04 8 | command: ['aws', 's3', 'sync', 's3://{{.Env.S3_BUCKET}}/{{.Env.S3_PREFIX}}', '/{{.Env.WEKA_BUCKET}}/{{.Env.WEKA_PREFIX}}'] 9 | datasets: 10 | - mountPath: /{{.Env.WEKA_BUCKET}} 11 | source: 12 | weka: {{.Env.WEKA_BUCKET}} 13 | envVars: 14 | - name: AWS_ACCESS_KEY_ID 15 | secret: AWS_ACCESS_KEY_ID 16 | - name: AWS_SECRET_ACCESS_KEY 17 | secret: AWS_SECRET_ACCESS_KEY 18 | context: 19 | preemptible: true 20 | constraints: 21 | cluster: 22 | - ai2/jupiter-cirrascale-2 -------------------------------------------------------------------------------- /scripts/wekatransfer/wekas3.yaml: -------------------------------------------------------------------------------- 1 | version: v2 2 | budget: {{.Env.BUDGET}} 3 | description: Sync contents of WEKA bucket "{{.Env.WEKA_BUCKET}}" to S3 bucket "{{.Env.S3_BUCKET}}" 4 | tasks: 5 | - name: 
sync_weka_to_s3 6 | image: 7 | beaker: ai2/cuda11.8-ubuntu20.04 8 | command: ['aws', 's3', 'sync', '/{{.Env.WEKA_BUCKET}}/{{.Env.WEKA_PREFIX}}', 's3://{{.Env.S3_BUCKET}}/{{.Env.S3_PREFIX}}'] 9 | datasets: 10 | - mountPath: /{{.Env.WEKA_BUCKET}} 11 | source: 12 | weka: {{.Env.WEKA_BUCKET}} 13 | envVars: 14 | - name: AWS_ACCESS_KEY_ID 15 | secret: AWS_ACCESS_KEY_ID 16 | - name: AWS_SECRET_ACCESS_KEY 17 | secret: AWS_SECRET_ACCESS_KEY 18 | context: 19 | preemptible: true 20 | constraints: 21 | cluster: 22 | - ai2/jupiter-cirrascale-2 -------------------------------------------------------------------------------- /visuals/emojis/olmoe_checkmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/emojis/olmoe_checkmark.png -------------------------------------------------------------------------------- /visuals/emojis/olmoe_checkmark_yellow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/emojis/olmoe_checkmark_yellow.png -------------------------------------------------------------------------------- /visuals/emojis/olmoe_cross.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/emojis/olmoe_cross.png -------------------------------------------------------------------------------- /visuals/emojis/olmoe_warning.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/emojis/olmoe_warning.png -------------------------------------------------------------------------------- /visuals/figures/adamweps.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/adamweps.pdf -------------------------------------------------------------------------------- /visuals/figures/dataset.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/dataset.pdf -------------------------------------------------------------------------------- /visuals/figures/datasetredditflan.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/datasetredditflan.pdf -------------------------------------------------------------------------------- /visuals/figures/embdecay.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/embdecay.pdf -------------------------------------------------------------------------------- /visuals/figures/expertchoice.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/expertchoice.pdf -------------------------------------------------------------------------------- /visuals/figures/granularity.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/granularity.pdf -------------------------------------------------------------------------------- /visuals/figures/init.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/init.pdf -------------------------------------------------------------------------------- /visuals/figures/layer_0_heatmap.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/layer_0_heatmap.pdf -------------------------------------------------------------------------------- /visuals/figures/layer_15_heatmap.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/layer_15_heatmap.pdf -------------------------------------------------------------------------------- /visuals/figures/layer_7_heatmap.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/layer_7_heatmap.pdf -------------------------------------------------------------------------------- /visuals/figures/layersharing.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/layersharing.pdf -------------------------------------------------------------------------------- /visuals/figures/lbl.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/lbl.pdf -------------------------------------------------------------------------------- /visuals/figures/lblprecision.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/lblprecision.pdf -------------------------------------------------------------------------------- /visuals/figures/lbltoks.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/lbltoks.pdf -------------------------------------------------------------------------------- /visuals/figures/ln.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/ln.pdf -------------------------------------------------------------------------------- /visuals/figures/lndecay.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/lndecay.pdf -------------------------------------------------------------------------------- /visuals/figures/lngradnorm.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/lngradnorm.pdf -------------------------------------------------------------------------------- /visuals/figures/loss.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/loss.pdf -------------------------------------------------------------------------------- /visuals/figures/moevsdense.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/moevsdense.pdf -------------------------------------------------------------------------------- /visuals/figures/noise.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/noise.pdf -------------------------------------------------------------------------------- /visuals/figures/olmoe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/olmoe.pdf -------------------------------------------------------------------------------- /visuals/figures/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/overview.jpg -------------------------------------------------------------------------------- /visuals/figures/overview.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/overview.pdf -------------------------------------------------------------------------------- /visuals/figures/qknorm.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/qknorm.pdf -------------------------------------------------------------------------------- /visuals/figures/routing_mixtral.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/routing_mixtral.pdf -------------------------------------------------------------------------------- /visuals/figures/routing_olmoe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/routing_olmoe.pdf -------------------------------------------------------------------------------- /visuals/figures/routing_prob_distribution_mixtral.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/routing_prob_distribution_mixtral.pdf -------------------------------------------------------------------------------- /visuals/figures/routing_prob_distribution_olmoe.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/routing_prob_distribution_olmoe.pdf -------------------------------------------------------------------------------- /visuals/figures/shared.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/shared.pdf -------------------------------------------------------------------------------- /visuals/figures/token_specialization_top1_olmoe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/token_specialization_top1_olmoe.pdf -------------------------------------------------------------------------------- /visuals/figures/token_specialization_top2_mixtral.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/token_specialization_top2_mixtral.pdf -------------------------------------------------------------------------------- /visuals/figures/token_specialization_top8_olmoe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/token_specialization_top8_olmoe.pdf -------------------------------------------------------------------------------- /visuals/figures/top18_changes_over_checkpoints.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/top18_changes_over_checkpoints.pdf -------------------------------------------------------------------------------- /visuals/figures/trainingevalflops.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/trainingevalflops.pdf -------------------------------------------------------------------------------- /visuals/figures/trainingevaltokens.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/trainingevaltokens.pdf -------------------------------------------------------------------------------- /visuals/figures/upcycle.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/upcycle.pdf -------------------------------------------------------------------------------- /visuals/figures/zloss.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/figures/zloss.pdf -------------------------------------------------------------------------------- /visuals/logos/OLMoE_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/logos/OLMoE_logo.png -------------------------------------------------------------------------------- 
/visuals/logos/OLMoE_logo.svg: -------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /visuals/logos/OLMoE_logo_alt1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/logos/OLMoE_logo_alt1.png -------------------------------------------------------------------------------- /visuals/logos/OLMoE_logo_alt1.svg: -------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /visuals/logos/OLMoE_logo_alt2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/logos/OLMoE_logo_alt2.png -------------------------------------------------------------------------------- /visuals/logos/OLMoE_logo_alt2.svg: -------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /visuals/logos/OLMoE_logo_alt3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/logos/OLMoE_logo_alt3.png -------------------------------------------------------------------------------- /visuals/logos/OLMoE_logo_alt3.svg: -------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /visuals/poster_iclr2025.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/poster_iclr2025.pdf -------------------------------------------------------------------------------- /visuals/poster_iclr2025.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/poster_iclr2025.pptx -------------------------------------------------------------------------------- /visuals/poster_neurips2024.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/poster_neurips2024.pdf -------------------------------------------------------------------------------- /visuals/twitterblog_images/domainspec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/domainspec.png -------------------------------------------------------------------------------- /visuals/twitterblog_images/experiments.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/experiments.png -------------------------------------------------------------------------------- /visuals/twitterblog_images/logo_transparent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/logo_transparent.png -------------------------------------------------------------------------------- /visuals/twitterblog_images/logo_twitter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/logo_twitter.png -------------------------------------------------------------------------------- /visuals/twitterblog_images/overview_base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/overview_base.png -------------------------------------------------------------------------------- /visuals/twitterblog_images/overview_left.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/overview_left.png -------------------------------------------------------------------------------- /visuals/twitterblog_images/overview_long.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/overview_long.png -------------------------------------------------------------------------------- /visuals/twitterblog_images/overview_right.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/overview_right.png -------------------------------------------------------------------------------- /visuals/twitterblog_images/perf_adapt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/perf_adapt.png -------------------------------------------------------------------------------- /visuals/twitterblog_images/perf_during.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/perf_during.png -------------------------------------------------------------------------------- /visuals/twitterblog_images/perf_pretr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/perf_pretr.png -------------------------------------------------------------------------------- /visuals/twitterblog_images/perf_pretr_adapt.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/perf_pretr_adapt.png -------------------------------------------------------------------------------- /visuals/twitterblog_images/tokenidspec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/OLMoE/357454f4f647385839c0ff6b99a688dc7cd9c13f/visuals/twitterblog_images/tokenidspec.png --------------------------------------------------------------------------------