├── .gitignore
├── README.md
├── analysis
    └── colossalai_replace
    │   └── layer.py
├── env
    ├── colab_env.txt
    ├── openmoe_infer_dockerfile
    └── prepare_env.sh
├── eval
    ├── plot_bigbench.py
    ├── result_retrieval_bigbenchlite.py
    └── triqa_plot.py
├── figure
    ├── bblite-3-shot.pdf
    ├── bblite-3-shot.png
    ├── mt_bench_turn_0.png
    ├── mt_bench_turn_1.png
    └── mt_bench_turn_2.png
├── logo.jpg
├── paper
    ├── README.md
    └── paper.pdf
├── results.md
└── script
    ├── inference_on_multi_devices.py
    ├── run_eval.sh
    └── run_pretrain.sh


/.gitignore:
--------------------------------------------------------------------------------
1 | BIG-bench
2 | results
3 | .idea


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | <p align="center">
  2 | <img width="200px" alt="OpenMoE" src="https://github.com/XueFuzhao/OpenMoE/blob/main/logo.jpg?raw=true">
  3 | </p>
  4 | <p align="center"><a href="https://github.com/XueFuzhao/OpenMoE/tree/main">[Homepage]</a> | <a href="https://arxiv.org/abs/2402.01739">[Paper]</a> |  <a href="https://colab.research.google.com/drive/1xIfIVafnlCP2XVICmRwkUFK3cwTJYjCY#scrollTo=62T-2mH_tsjG">[Colab Demo]</a> | <a href="https://huggingface.co/OrionZheng">[Huggingface]</a> | <a href="https://discord.gg/bjGnGfjegU">[Discord]</a>  |  <a href="https://twitter.com/xuefz/status/1693696988611739947?s=61&t=Xc2k2W7vU_hlpNizGDCmOw">[Twitter]</a> | <a href="https://xuefuzhao.notion.site/Aug-2023-OpenMoE-v0-2-Release-43808efc0f5845caa788f2db52021879">[Blog]</a></p>
  5 | </p>
  6 | <hr>
  7 | 
  8 | # OpenMoE
  9 | OpenMoE is a project aimed at igniting the open-source MoE community! We are releasing a family of open-source Mixture-of-Experts (MoE) large language models.
 10 | 
 11 | Our project began in the summer of 2023. On August 22, 2023, we released the first batch of intermediate checkpoints (OpenMoE-base & 8B), along with the data and code [[Twitter]](https://twitter.com/xuefz/status/1693696988611739947?s=61&t=Xc2k2W7vU_hlpNizGDCmOw). Subsequently, the OpenMoE-8B training was completed in November 2023. After that, we embarked on exploration of a 34B-scale model, which is still ongoing.
 12 | 
 13 | As a small student team, instead of pursuing the best model with better data, computation, and human power, we are committed to fully sharing our training data, strategies, model architecture, weights, and everything we have with the community. We hope this project will promote research in this promising field and invite more contributors to work on open-source MoE projects together!
 14 | 
 15 | ## News
 16 | 
 17 | [2024/01] 🔥 We released the OpenMoE paper! We conducted an in-depth routing analysis and found many interesting results. Check it out [here](https://github.com/XueFuzhao/OpenMoE/blob/main/paper/paper.pdf)!
 18 | 
 19 | [2024/01] 🔥 OpenMoE-8B-Chat is now available. We've provided a Colab inference [demo](https://colab.research.google.com/drive/1xIfIVafnlCP2XVICmRwkUFK3cwTJYjCY) for everyone to try, as well as a [tutorial](https://colab.research.google.com/drive/1eIT1rtG7pORRQAYtQoMOAekUg7aZLDdn) on converting JAX checkpoints to PyTorch checkpoints (Note: both require Colab Pro).
 20 | 
 21 | [2023/11] 🔥 Thanks to Colossal AI! They released a [PyTorch OpenMoE implementation](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/openmoe) covering both training and inference with expert parallelism.
 22 | 
 23 | [2023/08] 🔥 We released an intermediate OpenMoE-8B checkpoint (OpenMoE-v0.2) along with two other models. Check out the blog [post](https://xuefuzhao.notion.site/Aug-2023-OpenMoE-v0-2-Release-43808efc0f5845caa788f2db52021879).
 24 | 
 25 | ## TODO List
 26 | 
 27 | - [x] PyTorch Implementation with Colossal AI
 28 | - [x] Continue Training to 1T tokens
 29 | - [ ] More Evaluation
 30 | - [ ] Paper
 31 | 
 32 | ## Contents
 33 | - [Model Weights](#model-weights)
 34 | - [Get Started](#get-started)
 35 | - [Approach](#approach)
 36 | - [License](#license)
 37 | - [Authors](#authors)
 38 | - [Citation](#citation)
 39 | 
 40 | 
 41 | ## Model Weights
 42 | Currently, three models are released in total: OpenMoE-base, OpenMoE-8B/8B-Chat, and OpenMoE-34B (at 200B tokens).
 43 | 
 44 | The table below lists the 8B/8B-Chat models, which have completed training on 1.1T tokens.
 45 | 
 46 | | Model Name     | Description                      | #Param   |Huggingface |
 47 | |----------------|-------------------------------------------------|----------|-------------|
 48 | | **OpenMoE-8B (1.1T)**   | 8B MoE with FLOPs comparable to a 2B LLaMA (no SFT)  |8B        |[Link](https://huggingface.co/OrionZheng/openmoe-8b) |
 49 | | **OpenMoE-8B-Chat (1.1T+SFT)**   | OpenMoE-8B-1.1T supervised finetuned on the [WildChat GPT-4 Subset](https://huggingface.co/datasets/allenai/WildChat-nontoxic)   |8B        |[Link](https://huggingface.co/OrionZheng/openmoe-8b-chat) |
 50 | 
 51 | 
 52 | In addition, we provide all our intermediate checkpoints (base, 8B, 34B) for research purposes.
 53 | 
 54 | | Model Name     | Description                      | #Param   |Huggingface |
 55 | |----------------|-------------------------------------------------|----------|-------------|
 56 | | **OpenMoE-34B-200B**   | 34B MoE with FLOPs comparable to a 7B LLaMA (no SFT)  |34B        |[Link](https://huggingface.co/OrionZheng/openmoe-34b-200B) |
 57 | | OpenMoE-8B-200B   | 8B MoE with FLOPs comparable to a 2B LLaMA (no SFT) |8B        |[Link](https://huggingface.co/OrionZheng/openmoe-8b-200B) |
 58 | | OpenMoE-8B-400B   | 8B MoE with FLOPs comparable to a 2B LLaMA (no SFT)  |8B        |[Link](https://huggingface.co/OrionZheng/openmoe-8b-400B) |
 59 | | OpenMoE-8B-600B   | 8B MoE with FLOPs comparable to a 2B LLaMA (no SFT) |8B        |[Link](https://huggingface.co/OrionZheng/openmoe-8b-600B) |
 60 | | OpenMoE-8B-800B   | 8B MoE with FLOPs comparable to a 2B LLaMA (no SFT)  |8B        |[Link](https://huggingface.co/OrionZheng/openmoe-8b-800B) |
 61 | | OpenMoE-8B-1T   | 8B MoE with FLOPs comparable to a 2B LLaMA (no SFT)  |8B        |[Link](https://huggingface.co/OrionZheng/openmoe-8b-1T) |
 62 | | OpenMoE-base(128B)   | A small MoE model for debugging only       |637M      |[Link](https://huggingface.co/OrionZheng/openmoe-base) |  
 63 | | OpenLLaMA-base(128B) | A dense counter-part of OpenMoE-base            |310M      |[Link](https://huggingface.co/fuzhao/OpenLLaMA_Base) |
 64 | 
 65 | 
 66 | The base models, which were trained on 128 billion tokens, served primarily for debugging purposes. After validating the effectiveness of our model architecture, we did not pursue further training. Consequently, their performance might not be very good, and these checkpoints are not suitable for practical applications. Better performance can be observed from our 8B or 34B versions.
 67 | 
 68 | OpenMoE-8B, with 4 MoE layers and 32 experts, has been trained on 1.1T tokens. The SFT version was released after we finetuned OpenMoE-8B-1.1T on the [WildChat](https://huggingface.co/datasets/allenai/WildChat-nontoxic) dataset's GPT-4 subset. The intermediate checkpoints at 200B, 400B, 600B, 800B, and 1T tokens can be used to study the training dynamics of the MoE architecture.
 69 | 
 70 | We are still training OpenMoE-34B, an MoE model with 8 MoE layers and 32 experts. We have released the intermediate checkpoint trained on 200B tokens on Huggingface. If you are interested in the latest checkpoint, please feel free to drop Fuzhao an email (f.xue@u.nus.edu).
 71 | 
 72 | ## Get Started
 73 | 
 74 | ### Inference with PyTorch
 75 | Our PyTorch implementation is supported by [Colossal AI](https://github.com/hpcaitech/ColossalAI). You can install our forked version directly for easier setup:
 76 | ```
 77 | # Install ColossalAI
 78 | git clone --branch my_openmoe https://github.com/Orion-Zheng/ColossalAI.git
 79 | pip install ./ColossalAI
 80 | python -m pip install -r ./ColossalAI/examples/language/openmoe/requirements.txt
 81 | ```
 82 | 
 83 | Then, you can do inference by:
 84 | ```
 85 | import torch
 86 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
 87 | model_path = "ckpts/openmoe-8b-chat"
 88 | config = AutoConfig.from_pretrained(model_path)
 89 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 90 | model = AutoModelForCausalLM.from_pretrained(
 91 |     model_path,
 92 |     torch_dtype=torch.bfloat16,
 93 |     trust_remote_code=True, 
 94 |     device_map='auto'
 95 |     )
 96 | query = 'Question: How do I kill a process? Answer:'
 97 | prompt = f'''<<SYS>>
 98 | You are a helpful, respectful and honest assistant.
 99 | <</SYS>>
100 | 
101 | <s>[INST] {query} [/INST]'''
102 | 
103 | inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
104 | sample = model.generate(**inputs, max_new_tokens=32)
105 | print(tokenizer.decode(sample[0]))
106 | ```
107 | 
108 | We also provide a Colab [tutorial](https://colab.research.google.com/drive/1eIT1rtG7pORRQAYtQoMOAekUg7aZLDdn) demonstrating JAX checkpoint conversion and PyTorch model inference. You can experiment with OpenMoE-8B-Chat on Colab directly via [this notebook](https://colab.research.google.com/drive/1xIfIVafnlCP2XVICmRwkUFK3cwTJYjCY) (Note: both require Colab Pro).
109 | - Running OpenMoE-8B requires ~49GB of memory in float32 or ~23GB in bfloat16. It can be executed on a Colab `CPU High-RAM` runtime or an `A100-40GB` runtime, both of which require Colab Pro. float16 precision is not recommended because it sometimes leads to performance degradation.
110 | - Running OpenMoE-34B requires ~89GB of memory in bfloat16 or ~180GB in float32. To perform inference on multiple devices or offload model weights to RAM, please refer to the script [here](script/inference_on_multi_devices.py); a minimal loading sketch is also given below.
111 | - A more detailed env setup script can be found [here](env/prepare_env.sh). Note: you don't need the t5x and JAX dependencies if you are using our [huggingface ckpts](https://huggingface.co/OrionZheng/openmoe-8b-chat).
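
The snippet below is a minimal sketch of multi-device loading with CPU offloading (the checkpoint path comes from the table above; the `max_memory` budgets are placeholder assumptions you should adapt to your hardware):
```
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_path = "OrionZheng/openmoe-34b-200B"  # or a local checkpoint directory
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# device_map="auto" shards the weights over the visible GPUs and offloads
# whatever does not fit into CPU RAM; max_memory caps each device's budget.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
    max_memory={0: "38GiB", 1: "38GiB", "cpu": "160GiB"},
)
```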
112 | 
113 | 
114 | 
115 | ### Training with TPU/GPU
116 | 1. **On TPUs:** Get a TPU VM and run the following code on all TPU hosts. Researchers can apply to the [TPU Research Cloud](https://sites.research.google/trc/about/) for TPU resources.
117 | ```
118 | git clone https://github.com/XueFuzhao/OpenMoE.git
119 | bash OpenMoE/script/run_pretrain.sh
120 | ```
121 | 2. **On GPUs:** [ColossalAI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/openmoe) provides a PyTorch + GPU implementation for OpenMoE with optimized expert parallel strategies. However, we have recently noticed some issues ([#5163](https://github.com/hpcaitech/ColossalAI/issues/5163), [#5212](https://github.com/hpcaitech/ColossalAI/issues/5212)) about convergence problems. We are actively following up on these concerns and will update our GPU training tutorials soon.
122 | 
123 | ### Evaluation with TPU/GPU
124 | 1. **On GPUs:** You can evaluate our model on MT-Bench by running the code below.
125 | ```
126 | git clone https://github.com/Orion-Zheng/FastChat.git
127 | cd FastChat && pip install -e ".[model_worker,llm_judge]"
128 | cd FastChat/fastchat/llm_judge
129 | python gen_model_answer.py --model-path LOCAL_PATH_TO_MODEL_CKPT/openmoe_8b_chat_ckpt\
130 |                            --model-id openmoe-chat\
131 |                            --dtype bfloat16
132 | ```
133 | 2. **On TPUs:** Get a TPU-vm and run the following code to evaluate the model on the BIG-bench-Lite.
134 | ```
135 | git clone https://github.com/XueFuzhao/OpenMoE.git
136 | bash OpenMoE/script/run_eval.sh
137 | ```
138 | 
139 | ## Approach
140 | ### Data
141 | #### Before 780B tokens:
142 | 50% The RedPajama + 50% The Stack Dedup.
143 | We use a high ratio of coding data to improve reasoning ability.
144 | 
145 | #### After 780B tokens:
146 | 
147 | | dataset                        | Ratio (%) |
148 | | -------------------------------| ----------- |
149 | | redpajama_c4                   | 15.0        |
150 | | redpajama_wikipedia            | 6.5         |
151 | | wikipedia                      | 6.5         |
152 | | redpajama_stackexchange        | 2.5         |
153 | | redpajama_arxiv                | 4.5         |
154 | | redpajama_book                 | 6.5         |
155 | | redpajama_github               | 5.0         |
156 | | redpajama_common_crawl         | 43.5        |
157 | | the_stack_dedup                | 10.0        |
158 | 
159 | We found that the model tends to learn code faster than natural language, so we decided to reduce the ratio of coding data in the later stage of training.
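
For concreteness, a rough sketch of what these ratios imply for the remaining token budget (the mixture switch happened at 780B tokens out of roughly 1.1T in total; the figures below are back-of-the-envelope approximations, not exact dataset sizes):
```
ratios = {
    "redpajama_c4": 15.0, "redpajama_wikipedia": 6.5, "wikipedia": 6.5,
    "redpajama_stackexchange": 2.5, "redpajama_arxiv": 4.5, "redpajama_book": 6.5,
    "redpajama_github": 5.0, "redpajama_common_crawl": 43.5, "the_stack_dedup": 10.0,
}
assert abs(sum(ratios.values()) - 100.0) < 1e-6  # the ratios in the table sum to 100%

remaining_tokens = 1.1e12 - 780e9  # ~320B tokens trained with this mixture
for name, pct in ratios.items():
    print(f"{name:<26s} ~{pct / 100 * remaining_tokens / 1e9:6.1f}B tokens")
```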
160 | 
161 | Below are scripts to generate TFDS for pre-training datasets:   
162 | The RedPajama: https://github.com/Orion-Zheng/redpajama_tfds  
163 | The-Stack-Dedup: https://github.com/Orion-Zheng/the_stack_tfds  
164 | 
165 | ### Tokenizer
166 | We use the [umT5 tokenizer](https://arxiv.org/abs/2304.09151) to support multilingual continual learning in the future. It can be downloaded from [Huggingface](https://huggingface.co/google/umt5-small/tree/main) or [Google Cloud](https://github.com/google-research/t5x/blob/main/docs/models.md#umt5-checkpoints).
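
For example, the tokenizer can be loaded directly from Huggingface (a minimal sketch; the OpenMoE checkpoints on Huggingface already bundle the tokenizer, as in the inference example above):
```
from transformers import AutoTokenizer

# umT5 SentencePiece tokenizer with a multilingual vocabulary
tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
print(tokenizer.tokenize("OpenMoE supports multilingual text, 例如中文。"))
```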
167 | 
168 | ### Model Architecture
169 | OpenMoE is based on [ST-MoE](https://arxiv.org/abs/2202.08906) but uses a decoder-only architecture. The detailed implementation can be found in Fuzhao's [T5x](https://github.com/XueFuzhao/t5x) and [Flaxformer](https://github.com/XueFuzhao/flaxformer) repos.
170 | 
171 | ### Training Objective
172 | 
173 | #### Before 780B tokens:
174 | We use a modified UL2 training objective with a causal attention mask (we use more prefix LM and higher mask ratios because this saves computation); a small sampling sketch is given after the list:
175 | - 50% prefix LM
176 | - 10% span len=3 mask ratio=0.15
177 | - 10% span len=8 mask ratio=0.15
178 | - 10% span len=3 mask ratio=0.5
179 | - 10% span len=8 mask ratio=0.5
180 | - 10% span len=64 mask ratio=0.5
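
The sketch below is only an illustration of how such a mixture could be sampled per training example; it is not our actual T5x/seqio implementation:
```
import random

# (objective, span_len, mask_ratio, weight); span_len/mask_ratio do not apply to prefix LM
UL2_MIXTURE = [
    ("prefix_lm", None, None, 0.5),
    ("span_corruption", 3, 0.15, 0.1),
    ("span_corruption", 8, 0.15, 0.1),
    ("span_corruption", 3, 0.5, 0.1),
    ("span_corruption", 8, 0.5, 0.1),
    ("span_corruption", 64, 0.5, 0.1),
]

def sample_objective(rng: random.Random) -> tuple:
    """Pick the training objective for one example according to the mixture weights."""
    weights = [w for *_, w in UL2_MIXTURE]
    return rng.choices(UL2_MIXTURE, weights=weights, k=1)[0]
```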
181 | 
182 | #### After 780B tokens:
183 | Vanilla next-token prediction, because we observed that the UL2 objective tends to saturate in the later stage of training, although it enables the model to learn faster at the start.
184 | 
185 | ### Other Designs
186 | RoPE, SwiGLU activation, 2K context length. We will release a more detailed report soon.
187 | 
188 | ## Evaluation Results
189 | ### BigBench-Lite
190 | We evaluate our model on BigBench-Lite as our first step. We plot the cost-effectiveness curve in the figure below. 
191 | 
192 | Relative cost is approximated by multiplying the number of activated parameters by the number of training tokens. The size of each dot denotes the number of parameters activated per token, and the light-gray dots denote the total parameters of the MoE models.
193 | <img src="figure/bblite-3-shot.png" width="60%" alt="Bigbench-Lite">
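
As a worked example, matching the computation in `eval/plot_bigbench.py` (which assigns OpenMoE-8B 2B activated parameters and 200B training tokens):
```
# Relative cost as computed in eval/plot_bigbench.py:
# activated params (in billions) * training tokens (in billions) / 1000
activated_params_b = 2    # OpenMoE-8B activates ~2B parameters per token
training_tokens_b = 200   # the checkpoint used in the plot script
relative_cost = activated_params_b * training_tokens_b / 1000.0
print(relative_cost)  # 0.4
```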
194 | 
195 | 
196 | For more detailed results, please see our [Blog](https://www.notion.so/Aug-2023-OpenMoE-v0-2-Release-43808efc0f5845caa788f2db52021879) 
197 | 
198 | ### MT-Bench
199 | We evaluate on MT-Bench and observe that OpenMoE-8B-Chat outperforms dense LLMs trained with around twice the effective FLOPs on the first-turn results.
200 | 
201 | <img src="figure/mt_bench_turn_1.png" width="50%" alt="Bigbench-Lite">
202 | 
203 | ## License
204 | 
205 | Our code is under Apache 2.0 License.
206 | 
207 | Since the models are trained on the RedPajama and The Stack datasets, please also check the licenses of these two datasets before using the models.
208 | 
209 | 
210 | ## Authors
211 | 
212 | This project is currently contributed to by the following authors:
213 | 
214 | - [Fuzhao Xue](https://xuefuzhao.github.io/)
215 | - [Zian Zheng](https://zheng-zian-andy.com)
216 | - [Yao Fu](https://franxyao.github.io/)
217 | - [Jinjie Ni](http://jinjie.one/)
218 | - [Zangwei Zheng](https://zhengzangw.github.io/)
219 | - [Wangchunshu Zhou](https://michaelzhouwang.github.io/)
220 | - [Yang You](https://www.comp.nus.edu.sg/~youy/)
221 | 
222 | ## Acknowledgement
223 | The computational resources for this project were generously provided by the [Google TPU Research Cloud (TRC)](https://sites.research.google/trc/about/). We extend our heartfelt thanks to TRC for their invaluable support, which has been fundamental to the success of our work. We are also extremely grateful to the [ColossalAI Team](https://github.com/hpcaitech/ColossalAI) for their tremendous support with the PyTorch implementation, especially [Xuanlei Zhao](https://oahzxl.github.io/) and [Wenhao Chen](https://github.com/CWHer), who made training and inference of OpenMoE on GPUs a reality.
224 | 
225 | ## Citation
226 | 
227 | Please cite this repo if you use our model or code.
228 | 
229 | ```bibtex
230 | @article{xue2024openmoe,
231 |   title={OpenMoE: An Early Effort on Open Mixture-of-Experts Language Models},
232 |   author={Xue, Fuzhao and Zheng, Zian and Fu, Yao and Ni, Jinjie and Zheng, Zangwei and Zhou, Wangchunshu and You, Yang},
233 |   journal={arXiv preprint arXiv:2402.01739},
234 |   year={2024}
235 | }
236 | ```
237 | 
238 | ## Star History
239 | 
240 | [![Star History Chart](https://api.star-history.com/svg?repos=XueFuzhao/OpenMoE&type=Date)](https://star-history.com/#XueFuzhao/OpenMoE&Date)
241 | 
242 | 
243 | 


--------------------------------------------------------------------------------
/analysis/colossalai_replace/layer.py:
--------------------------------------------------------------------------------
  1 | import dataclasses
  2 | import math
  3 | from typing import Any, Optional, Tuple
  4 | 
  5 | import torch
  6 | import torch.distributed as dist
  7 | import torch.nn as nn
  8 | import torch.nn.functional as F
  9 | from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
 10 | from colossalai.moe.experts import MLPExperts
 11 | from colossalai.moe.manager import MOE_MANAGER
 12 | from colossalai.moe.routers import MoeRouter, get_router_cls
 13 | from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
 14 | from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size
 15 | 
 16 | 
 17 | import json
 18 | import numpy as np
 19 | 
 20 | class SparseMLP(nn.Module):
 21 |     """A class for users to create MoE modules in their models.
 22 | 
 23 |     Args:
 24 |         dim_model (int): Hidden dimension of training model
 25 |         num_experts (int): The number experts
 26 |         top_k (int, optional): The number of experts for dispatchment of each token
 27 |         capacity_factor_train (float, optional): Capacity factor in routing during training
 28 |         capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
 29 |         min_capacity (int, optional): The minimum number of the capacity of each expert
 30 |         noisy_policy (str, optional): The policy of noisy function. Now we have 'Jitter' and 'Gaussian'.
 31 |             'Jitter' can be found in `Switch Transformer paper`_.
 32 |             'Gaussian' can be found in `ViT-MoE paper`_.
 33 |         drop_tks (bool, optional): Whether drops tokens in evaluation
 34 |         use_residual (bool, optional): Makes this MoE layer a Residual MoE.
 35 |             More information can be found in `Microsoft paper`_.
 36 |         residual_instance (nn.Module, optional): The instance of residual module in Residual MoE
 37 |         expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer
 38 |         expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given
 39 |         expert_args (optional): The args of expert when no instance is given
 40 | 
 41 |     .. _Switch Transformer paper:
 42 |         https://arxiv.org/abs/2101.03961
 43 |     .. _ViT-MoE paper:
 44 |         https://arxiv.org/abs/2106.05974
 45 |     .. _Microsoft paper:
 46 |         https://arxiv.org/abs/2201.05596
 47 |     """
 48 | 
 49 |     def __init__(
 50 |         self,
 51 |         num_experts: int,
 52 |         hidden_size: int,
 53 |         intermediate_size: int,
 54 |         router_top_k: int = 1,
 55 |         router_capacity_factor_train: float = 1.25,
 56 |         router_capacity_factor_eval: float = 2.0,
 57 |         router_min_capacity: int = 4,
 58 |         router_noisy_policy: Optional[str] = None,
 59 |         router_drop_tks: bool = True,
 60 |         mlp_activation: Optional[str] = None,
 61 |         mlp_gated: bool = False,
 62 |         enable_load_balance: bool = False,
 63 |         load_balance_tolerance: float = 0.1,
 64 |         load_balance_beam_width: int = 8,
 65 |         load_balance_group_swap_factor: float = 0.4,
 66 |         enable_kernel: bool = False,
 67 |         enable_comm_overlap: bool = False,
 68 |         enable_hierarchical_comm: bool = False,
 69 |         model_output_dir: Optional[str] = None,
 70 |     ):
 71 |         super().__init__()
 72 |         self.hidden_size = hidden_size
 73 |         self.intermediate_size = intermediate_size
 74 |         self.num_experts = num_experts
 75 |         self.gated = mlp_gated
 76 |         self.enable_kernel = enable_kernel
 77 |         self.enable_comm_overlap = enable_comm_overlap
 78 |         self.expert_parallel = MOE_MANAGER.get_parallel()
 79 |         self.model_output_dir = model_output_dir
 80 | 
 81 |         # For MoE Analysis
 82 |         if self.model_output_dir is not None:
 83 |             self.output_json_file = open(f"{self.model_output_dir}/output.json", "w")
 84 | 
 85 |         # moe router
 86 |         noisy_func = get_noise_generator(router_noisy_policy, num_experts)
 87 |         router_cls = get_router_cls(router_top_k)
 88 |         self.topk = router_top_k
 89 |         self.router: MoeRouter = router_cls(
 90 |             capacity_factor_train=router_capacity_factor_train,
 91 |             capacity_factor_eval=router_capacity_factor_eval,
 92 |             min_capacity=router_min_capacity,
 93 |             noisy_func=noisy_func,
 94 |             drop_tks=router_drop_tks,
 95 |         )
 96 | 
 97 |         # gate
 98 |         self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size))
 99 | 
100 |         # moe experts
101 |         self.experts = MLPExperts(
102 |             num_experts=self.num_experts,
103 |             expert_parallel=self.expert_parallel,
104 |             hidden_size=self.hidden_size,
105 |             intermediate_size=self.intermediate_size,
106 |             activation=mlp_activation,
107 |             gated=mlp_gated,
108 |             use_kernel=self.enable_kernel,
109 |         )
110 | 
111 |         # get parallel settings
112 |         if self.expert_parallel is not None:
113 |             self.ep_group = get_ep_group(self.experts)
114 |             self.ep_size = get_ep_size(self.experts)
115 |             self.ep_hierarchical_group = None
116 |             if enable_hierarchical_comm:
117 |                 self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
118 |                     get_ep_group_ranks(self.experts)
119 |                 )
120 |             self.dp_group = get_dp_group(self.experts)
121 |         else:
122 |             self.ep_group = None
123 |             self.dp_group = None
124 |         self.num_local_experts = self.experts.num_local_experts
125 | 
126 |         # load balance
127 |         self.enable_load_balance = enable_load_balance
128 |         if self.enable_load_balance:
129 |             from colossalai.moe.load_balance import LoadBalancer
130 |             self.load_balancer = LoadBalancer(
131 |                 experts=self.experts,
132 |                 gate=self.gate_weight,
133 |                 local_expert_num=self.num_local_experts,
134 |                 expert_num=self.num_experts,
135 |                 ep_group=self.ep_group,
136 |                 dp_group=self.dp_group,
137 |                 tolerance=load_balance_tolerance,
138 |                 beam_width=load_balance_beam_width,
139 |                 group_swap_factor=load_balance_group_swap_factor,
140 |             )
141 | 
142 |         # init param
143 |         self.reset_parameters()
144 | 
145 |     @torch.no_grad()
146 |     def reset_parameters(self):
147 |         torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size))
148 | 
149 |     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
150 |         """
151 |         Args:
152 |             inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size)
153 | 
154 |         Returns:
155 |             torch.Tensor: The output tensor of shape (batch_size, seq_len, hidden_size)
156 |         """
157 |         # reshape the input tokens
158 |         tokens = inputs.reshape(-1, self.hidden_size)
159 | 
160 |         # the data type of the inputs in the gating should be fp32
161 |         fp32_input = tokens.to(torch.float)
162 |         fp32_weight = self.gate_weight.to(torch.float)
163 |         gate_output = F.linear(fp32_input, fp32_weight)
164 | 
165 |         # update expert load
166 |         if self.enable_load_balance:
167 |             with torch.no_grad():
168 |                 # TODO: optimize computation
169 |                 expert_load = torch.topk(gate_output, k=self.topk, dim=-1)[1]
170 |                 # TODO: bincount introduces synchronize, fix it
171 |                 expert_load = torch.bincount(expert_load.view(-1))
172 |                 self.load_balancer.update_load(expert_load)
173 | 
174 |         # the result from the router
175 |         used_capacity, *route_result_list = self.router(
176 |             inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
177 | 
178 |         # For MoE analysis: dump routing statistics to the output JSON file (only when enabled)
179 |         if self.model_output_dir is not None:
180 |             gate_output_np = gate_output.detach().cpu().numpy()
181 |             used_capacity_np = used_capacity.detach().cpu().numpy()
182 |             dispatch_mask_np = route_result_list[1].detach().cpu().numpy()
183 |             combine_score_np = route_result_list[0].detach().cpu().numpy()
184 | 
185 |             # Collect the arrays in a dictionary
186 |             data = {
187 |                 "gate_output": gate_output_np.tolist(),
188 |                 "used_capacity": used_capacity_np.tolist(),
189 |                 "dispatch_mask": dispatch_mask_np.tolist(),
190 |                 "combine_score": combine_score_np.tolist()
191 |             }
192 | 
193 |             # Append the dictionary as one JSON line to the output file
194 |             json.dump(data, self.output_json_file)
195 |             self.output_json_file.write('\n')
196 | 
197 |         # dispatch_data: (num_experts, capacity, hidden_size)
198 |         if self.enable_kernel:
199 |             dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
200 |             dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size)
201 |         else:
202 |             sec_mask_f = route_result_list[1].type_as(inputs)
203 |             dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
204 | 
205 |         # expert_output: (num_groups, num_experts, capacity, hidden_size)
206 |         if self.expert_parallel == "EP":
207 |             expert_output = self._ep_process(
208 |                 dispatch_data,
209 |                 used_capacity,
210 |                 overlap=self.enable_comm_overlap
211 |             )
212 |         elif self.expert_parallel == "TP":
213 |             expert_output = self._tp_process(
214 |                 dispatch_data,
215 |                 used_capacity,
216 |                 overlap=self.enable_comm_overlap
217 |             )
218 |         elif self.expert_parallel is None:
219 |             expert_output = self._local_process(dispatch_data)
220 |         else:
221 |             raise NotImplementedError("This kind of communication has not been implemented yet.\n"
222 |                                       "Please use Experts build function.")
223 | 
224 |         if self.enable_kernel:
225 |             expert_output = expert_output.reshape(-1, self.hidden_size)
226 |             ans = MoeCombine.apply(expert_output, *route_result_list)
227 |         else:
228 |             combine_weights = route_result_list[0].type_as(inputs)
229 |             combine_weights = combine_weights.view(combine_weights.shape[0], -1)
230 |             expert_output = expert_output.view(-1, expert_output.shape[-1])
231 |             ans = torch.matmul(combine_weights, expert_output)
232 | 
233 |         ans = ans.reshape(inputs.shape)
234 |         return ans
235 | 
236 |     def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
237 |         expert_in = expert_in.unsqueeze(0)
238 |         expert_out = self.experts(expert_in)
239 |         return expert_out
240 | 
241 |     def _ep_process(
242 |         self,
243 |         dispatch_data: torch.Tensor,
244 |         used_capacity: torch.Tensor,
245 |         overlap: bool = False
246 |     ) -> torch.Tensor:
247 |         """
248 |         Expert Parallel
249 | 
250 |         Args:
251 |             dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
252 | 
253 |         Returns:
254 |             torch.Tensor: (num_experts, capacity, hidden_size)
255 |         """
256 |         if not overlap or dist.get_world_size(self.ep_group) == 1:
257 |             if self.ep_hierarchical_group is not None:
258 |                 expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
259 |                 expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
260 |                 expert_output = self.experts(expert_input)
261 |                 expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
262 |                 return expert_output
263 |             else:
264 |                 expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
265 |                 expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
266 |                 expert_output = self.experts(expert_input)
267 |                 expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0]
268 |                 return expert_output
269 |         else:
270 | 
271 |             @dataclasses.dataclass
272 |             class Capsule:
273 |                 data: torch.Tensor
274 |                 handle: Any = None
275 | 
276 |             NUM_CHUNK = 4
277 |             NUM_STAGES = 4
278 | 
279 |             assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
280 |             chunk_size = dispatch_data.shape[1] // NUM_CHUNK
281 |             input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
282 |             dispatch_data = dispatch_data.reshape(*input_shape)
283 |             chunk_data = torch.split(dispatch_data, chunk_size, dim=2)
284 |             output = torch.empty_like(dispatch_data)
285 | 
286 |             offset = 0
287 |             _expert_in, expert_in, _expert_out, expert_out = None, None, None, None
288 | 
289 |             for i in range(NUM_CHUNK + NUM_STAGES - 1):
290 |                 if expert_out is not None:
291 |                     expert_out.handle.wait()
292 |                     output[:, :, offset:offset + chunk_size, :] = expert_out.data
293 |                     offset += chunk_size
294 |                     expert_out = None
295 | 
296 |                 # all2all last output
297 |                 if _expert_out is not None:
298 |                     expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
299 |                     _expert_out = None
300 | 
301 |                 # all2all next input
302 |                 if 0 <= i < NUM_CHUNK:
303 |                     _expert_in = Capsule(*AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True))
304 | 
305 |                 # compute
306 |                 if expert_in is not None:
307 |                     expert_in.handle.wait()
308 |                     _expert_out = Capsule(data=self.experts(expert_in.data), handle=None)
309 |                     expert_in = None
310 | 
311 |                 if _expert_in is not None:
312 |                     expert_in = _expert_in
313 |                     _expert_in = None
314 | 
315 |             return output
316 | 
317 |     def _tp_process(
318 |         self,
319 |         dispatch_data: torch.Tensor,
320 |         used_capacity: torch.Tensor,
321 |         overlap: bool = False
322 |     ) -> torch.Tensor:
323 |         """
324 |         without overlap:
325 |                    |    C    |
326 |         |     A    |         |    R    |
327 | 
328 |         with overlap:
329 |               |    C1   ||    C2   ||    C3   ||    C4   |
330 |         | A1 || A2 |     | R1 | A3 || R2 | A4 || R3 |     | R4 |
331 | 
332 |         where C is computation, A is all gather, R is reduce scatter.
333 | 
334 |         Args:
335 |             dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
336 | 
337 |         Returns:
338 |             torch.Tensor: (num_experts, capacity, hidden_size)
339 |         """
340 |         if not overlap or dist.get_world_size(self.ep_group) == 1:
341 |             expert_in = AllGather.apply(dispatch_data, self.ep_group, False)[0]
342 |             expert_out = self.experts(expert_in)
343 |             expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0]
344 |             return expert_out
345 |         else:
346 | 
347 |             @dataclasses.dataclass
348 |             class Capsule:
349 |                 data: torch.Tensor
350 |                 handle: Any
351 |                 indices: Tuple
352 | 
353 |             NUM_CHUNK = 4
354 |             NUM_STAGES = 4
355 | 
356 |             assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
357 |                 "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
358 |             chunk_size = dispatch_data.shape[0] // NUM_CHUNK
359 |             chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
360 |             output = torch.empty_like(dispatch_data)
361 | 
362 |             def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]:
363 |                 return (slice(idx * chunk_size, (idx + 1) * chunk_size),)
364 | 
365 |             _expert_in, expert_in, _expert_out, expert_out = None, None, None, None
366 | 
367 |             for i in range(NUM_CHUNK + NUM_STAGES - 1):
368 |                 if expert_out is not None:
369 |                     expert_out.handle.wait()
370 |                     output[expert_out.indices] = expert_out.data
371 |                     expert_out = None
372 | 
373 |                 # reduce scatter last output
374 |                 if _expert_out is not None:
375 |                     expert_out = Capsule(
376 |                         *ReduceScatter.apply(_expert_out.data, self.ep_group, True),
377 |                         indices=_expert_out.indices,
378 |                     )
379 |                     _expert_out = None
380 | 
381 |                 # all gather next input
382 |                 if 0 <= i < NUM_CHUNK:
383 |                     _expert_in = Capsule(
384 |                         *AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True),
385 |                         indices=get_chunk_slice(i, chunk_size),
386 |                     )
387 | 
388 |                 # compute
389 |                 if expert_in is not None:
390 |                     expert_in.handle.wait()
391 |                     _expert_out = Capsule(
392 |                         self.experts(expert_in.data, expert_in.indices),
393 |                         handle=None,
394 |                         indices=expert_in.indices,
395 |                     )
396 |                     expert_in = None
397 | 
398 |                 if _expert_in is not None:
399 |                     expert_in = _expert_in
400 |                     _expert_in = None
401 | 
402 |             return output
403 | 
404 | 
405 | def apply_load_balance(model: nn.Module, optim: Any) -> None:
406 |     """
407 |     Apply load balancing to every SparseMLP expert layer in the model.
408 |     """
409 | 
410 |     def _apply_recursive(module: nn.Module):
411 |         for _, sub_module in module.named_children():
412 |             if isinstance(sub_module, SparseMLP):
413 |                 if sub_module.enable_load_balance:
414 |                     sub_module.load_balancer.balance_load(optim)
415 |             _apply_recursive(sub_module)
416 | 
417 |     torch.cuda.empty_cache()
418 |     _apply_recursive(model)
419 |     torch.cuda.empty_cache()
420 | 
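# --- Illustrative usage only (not part of the original file) ---
# A hypothetical training loop that periodically rebalances experts; it assumes
# `model` contains SparseMLP layers built with enable_load_balance=True and that
# `optim` is the optimizer wrapping the model parameters:
#
#   for step, batch in enumerate(dataloader):
#       loss = model(**batch).loss
#       loss.backward()
#       optim.step()
#       optim.zero_grad()
#       if (step + 1) % 100 == 0:  # rebalancing interval is arbitrary
#           apply_load_balance(model, optim)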


--------------------------------------------------------------------------------
/env/colab_env.txt:
--------------------------------------------------------------------------------
 1 | absl-py==1.4.0
 2 | clu @ git+https://github.com/google/CommonLoopUtils#egg=clu
 3 | flax @ git+https://github.com/google/flax#egg=flax
 4 | fiddle==0.2.11
 5 | gin-config==0.5.0
 6 | jax==0.4.20
 7 | jaxtyping==0.2.24
 8 | jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.20+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl#sha256=01be66238133f884bf5adf15cd7eaaf8445f9d4b056c5c64df28a997a6aff2fe
 9 | jestimator @ git+https://github.com/google-research/jestimator#egg=jestimator
10 | numpy==1.23.5
11 | optax @ git+https://github.com/deepmind/optax@1a7956d0e6f6f2881fb8f7d4c579a6435c8431eb
12 | orbax-checkpoint==0.4.4
13 | seqio @ git+https://github.com/google/seqio#egg=seqio
14 | tensorflow-cpu==2.15.0.post1
15 | tensorstore==0.1.45
16 | protobuf==3.20.3
17 | colossalai >= 0.3.3
18 | torch >= 2.1.0
19 | transformers==4.34.0
20 | sentencepiece==0.1.99
21 | datasets==2.14.7
22 | chex==0.1.7
23 | 


--------------------------------------------------------------------------------
/env/openmoe_infer_dockerfile:
--------------------------------------------------------------------------------
 1 | # Use NVIDIA PyTorch image as the base
 2 | FROM nvcr.io/nvidia/pytorch:23.12-py3
 3 | 
 4 | # Set the working directory
 5 | WORKDIR /workspace
 6 | 
 7 | # Download the requirements file 
 8 | RUN wget https://github.com/XueFuzhao/OpenMoE/raw/main/env/colab_env.txt -O colab_env.txt
 9 | 
10 | # Install Python packages
11 | RUN python --version && \
12 |     python -m pip install --upgrade pip && \
13 |     pip install -r colab_env.txt
14 | 
15 | # Clone and install t5x
16 | RUN git clone --branch=main https://github.com/Orion-Zheng/t5x && \
17 |     python -m pip install ./t5x
18 | 
19 | # Clone and install ColossalAI
20 | RUN git clone --branch my_openmoe https://github.com/Orion-Zheng/ColossalAI.git && \
21 |     pip install ./ColossalAI && \
22 |     python -m pip install -r ./ColossalAI/examples/language/openmoe/requirements.txt
23 | 
24 | # Set command to run on container start
25 | CMD ["bash"]
26 | 


--------------------------------------------------------------------------------
/env/prepare_env.sh:
--------------------------------------------------------------------------------
 1 | # Make sure that you're using a Python 3.10 environment.
 2 | # For example, you can create a conda virtual environment by running commands below.
 3 | # conda create --name openmoe_env python=3.10
 4 | # conda activate openmoe_env
 5 | python --version  
 6 | python -m pip install --upgrade pip
 7 | 
 8 | # Prepare Colab Env
 9 | pip install -r colab_env.txt
10 | 
11 | # Install t5x
12 | git clone --branch=main https://github.com/Orion-Zheng/t5x
13 | python -m pip install ./t5x
14 | 
15 | # Install ColossalAI
16 | git clone --branch my_openmoe https://github.com/Orion-Zheng/ColossalAI.git
17 | pip install ./ColossalAI
18 | python -m pip install -r ./ColossalAI/examples/language/openmoe/requirements.txt


--------------------------------------------------------------------------------
/eval/plot_bigbench.py:
--------------------------------------------------------------------------------
  1 | import pandas as pd
  2 | import matplotlib.pyplot as plt
  3 | from matplotlib.lines import Line2D
  4 | 
  5 | 
  6 | csv_file_path = '../../final_results/BIG-bench-Lite/retrieved_results.csv'
  7 | output_csv_file_dir = '../../final_results/BIG-bench-Lite'
  8 | 
  9 | model_mapping = {
 10 |     'GPT-3-3B': 'GPT-3_3B',
 11 |     'GPT-3-6B': 'GPT-3_6B',
 12 |     'GPT-3-13B': 'GPT-3_13B',
 13 |     'GPT-3-200B': 'GPT-3_200B',
 14 |     'GPT-3-Small': 'GPT-3_125m',
 15 |     'GPT-3-Medium': 'GPT-3_350m',
 16 |     'GPT-3-Large': 'GPT-3_760m',
 17 |     'GPT-3-XL': 'GPT-3_1300m',
 18 |     "all_examples": "OpenMoE_8B",
 19 | }
 20 | 
 21 | # Dictionary to update model sizes
 22 | model_size_updates = {
 23 |     'BIG-G-sparse_2m': 60630144,
 24 |     'BIG-G-sparse_16m': 234507776,
 25 |     'BIG-G-sparse_53m': 534215808,
 26 |     'BIG-G-sparse_125m': 1777677312,
 27 |     'BIG-G-sparse_244m': 2819788160,
 28 |     'BIG-G-sparse_422m': 4126141440,
 29 |     'BIG-G-sparse_1b': 7581906944,
 30 |     'BIG-G-sparse_2b': 17278886400,
 31 |     'BIG-G-sparse_4b': 25465853952,
 32 |     'BIG-G-sparse_8b': 60261322752,
 33 | }
 34 | 
 35 | 
 36 | 
 37 | def read_csv_and_convert_to_dataframe(file_path):
 38 |     try:
 39 |         df = pd.read_csv(file_path)
 40 |         return df
 41 |     except FileNotFoundError:
 42 |         print("The CSV file could not be found.")
 43 |         return None
 44 |     except Exception as e:
 45 |         print("An error occurred:", str(e))
 46 |         return None
 47 | 
 48 | # Define a function to calculate num_tokens based on model name
 49 | def calculate_num_tokens(model_name):
 50 |     if model_name.startswith('BIG-G'):
 51 |         return 131 * 10**9
 52 |     elif model_name.startswith('PaLM'):
 53 |         return 780 * 10**9
 54 |     elif model_name.startswith('GPT-3'):
 55 |         return 300 * 10**9
 56 |     elif model_name.startswith('OpenMoE'):
 57 |         return 200 * 10**9
 58 |     else:
 59 |         return None
 60 | 
 61 | # Extract the model size information and convert it to desired format
 62 | def convert_size(size_str):
 63 |     print(size_str)
 64 |     size = int(size_str[:-1])
 65 |     unit = size_str[-1]
 66 |     if unit == 'm':
 67 |         return size * 10**6
 68 |     elif unit == 'b' or unit == 'B' :
 69 |         return size * 10**9
 70 |     else:
 71 |         return size
 72 | 
 73 | 
 74 | # Define a function to calculate activated_parameters based on model name
 75 | def calculate_activated_parameters(row):
 76 |     model_name = row['model']
 77 |     model_size = row['model_size']
 78 | 
 79 |     if model_name.startswith('OpenMoE'):
 80 |         model_activate_param_mapping = {
 81 |             'OpenMoE_8B': 2 * 10 ** 9
 82 |         }
 83 |         return model_activate_param_mapping.get(model_name, model_size)
 84 |     else:
 85 |         return model_size
 86 | 
 87 | # Read csv file
 88 | csv_dataframe = read_csv_and_convert_to_dataframe(csv_file_path)
 89 | 
 90 | # write one csv file for each shot
 91 | _number_of_shots = [3] # [0,1,2,3]
 92 | for shot in _number_of_shots:
 93 |     csv_dataframe_tmp = csv_dataframe.copy()
 94 |     for column in csv_dataframe_tmp.columns:
 95 |         csv_dataframe_tmp[column] = csv_dataframe_tmp[column].apply(
 96 |             lambda x: [float(value) for value in x.split()][shot] if " " in x else x)
 97 | 
 98 |     csv_dataframe_tmp['model'] = csv_dataframe_tmp['model'].replace(model_mapping)
 99 | 
100 |     # Filter out rows with 'T=0' in 'model' column
101 |     csv_dataframe_tmp = csv_dataframe_tmp[~csv_dataframe_tmp['model'].str.contains('T=0')]
102 | 
103 |     csv_dataframe_tmp['average'] = csv_dataframe_tmp.iloc[:, 1:].mean(axis=1)
104 | 
105 |     csv_dataframe_tmp['model_name'] = csv_dataframe_tmp['model'].apply(lambda x: x.split('_')[0])
106 | 
107 |     csv_dataframe_tmp['model_size'] = csv_dataframe_tmp['model'].str.extract(r'(\d+[mMbB])')[0].apply(convert_size)
108 | 
109 |     # Add 'activated_parameters' column using the function
110 |     csv_dataframe_tmp['activated_parameters'] = csv_dataframe_tmp.apply(calculate_activated_parameters, axis=1)
111 |     # Update 'model_size' column based on model names in the dictionary
112 |     csv_dataframe_tmp.loc[csv_dataframe_tmp['model'].isin(model_size_updates.keys()), 'model_size'] = csv_dataframe_tmp['model'].map(model_size_updates)
113 | 
114 |     # Add 'num_tokens' column using the function
115 |     csv_dataframe_tmp['num_tokens'] = csv_dataframe_tmp['model'].apply(calculate_num_tokens)
116 | 
117 | 
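    # Relative cost = activated parameters (billions) * training tokens (billions), rescaled by 1/1000 for plotting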
118 |     csv_dataframe_tmp['cost'] = (csv_dataframe_tmp['activated_parameters']/(10.0**9)) * (csv_dataframe_tmp['num_tokens']/(10.0**9))
119 |     csv_dataframe_tmp['cost'] = csv_dataframe_tmp['cost'] / 1000.0
120 |     csv_dataframe_col = csv_dataframe_tmp[
121 |         [
122 |             'model',
123 |             'model_name',
124 |             'model_size',
125 |             'activated_parameters',
126 |             'num_tokens',
127 |             'cost',
128 |             'average'
129 |          ]
130 |     ]
131 |     # print(csv_dataframe_col)
132 |     file_to_save = f'{output_csv_file_dir}/{shot}-shot.csv'
133 |     csv_dataframe_col.to_csv(
134 |         file_to_save, header=True, index=False)
135 | 
136 | 
137 | 
138 | # marker_shapes = {
139 | #     'BIG-G': 's',
140 | #     'GPT-3': 's',
141 | #     'BIG-G-sparse': 'o',
142 | #     'OpenMoE': 'o',
143 | # }
144 | 
145 | 
146 | # Plot and save the figures
147 | for shot in _number_of_shots:
148 |     # read csv file again
149 |     csv_file_path = f'{output_csv_file_dir}/{shot}-shot.csv'
150 |     df = read_csv_and_convert_to_dataframe(csv_file_path)
151 |     # print(df)
152 |     df['cost'] = df['cost'].astype(float)
153 |     df['average'] = df['average'].astype(float)
154 |     df['activated_parameters'] = df['activated_parameters'].astype(float) / 1000000000.0
155 |     df['model_size'] = df['model_size'].astype(float) / 1000000000.0
156 |     df = df[(df['cost'] > 0.05) & (df['cost'] < 1.5)]
157 |     # df.sort_values(by=['model_name', 'activated_parameters'])
158 | 
159 |     model_data = {}
160 |     # Convert the DataFrame to a dictionary and store it in the data_dict
161 |     for model_name, group_tmp in df.groupby('model_name'):
162 |         if model_name not in model_data:
163 |             model_data[model_name] = {"cost": [], "result": [], "total_param": [], "act_param": []}
164 |         # Append each row of the group to the corresponding lists
165 |         # print(model_data[model_name])
166 |         for index, row in group_tmp.sort_values(by=['cost']).iterrows():
167 |             model_data[model_name]["cost"].append(row["cost"])
168 |             model_data[model_name]["result"].append(row['average'])
169 |             # print(float(row['model_size']), float(row['model_size'])/1000000000.0)
170 |             model_data[model_name]["total_param"].append(row['model_size'])
171 |             model_data[model_name]["act_param"].append(row['activated_parameters'])
172 | 
173 |     # Plotting
174 |     plt.figure(figsize=(12, 6))
175 |     colors = plt.cm.get_cmap("viridis", len(model_data))
176 | 
177 |     handles = []
178 |     for i, (model, values) in enumerate(model_data.items()):
179 |         x = values["cost"]
180 |         y = values["result"]
181 |         # print(values["act_param"])
182 |         sizes = [param * 20.0 for param in values["act_param"]]  # Adjust the scaling factor as needed
183 |         total_sizes = [param * 20.0 for param in values["total_param"]]
184 |         color = 'red' if model == "OpenMoE" else colors(i)
185 |         marker = 'o'  # if "MoE" in model else 'o'
186 |         print(model)
187 |         if ("MoE" in model ) or ('sparse' in model):
188 |             print(model)
189 |             plt.scatter(x, y, label=model, color='lightgray', s=total_sizes, marker=marker)
190 |         plt.scatter(x, y, label=model, color=color, s=sizes, marker=marker)
191 |         # handle = mpatches.Patch(color=colors(i), label=f'{model}', marker='o')
192 |         handle = Line2D([0], [0], marker=marker, color='w', label=f'{model}', markersize=12,
193 |                         markerfacecolor=color)
194 |         handles.append(handle)
195 |         # Adding dashed lines to connect dots of the same model
196 |         plt.plot(x, y, linestyle='dashed', color=color)
197 |         # Annotate each dot with the model name and total parameters
198 |         for j, (x_val, y_val) in enumerate(zip(x, y)):
199 |             if values["total_param"][j] >= 1.0:
200 |                 dot_name = f'{model}-{int(values["total_param"][j])}B'
201 |             else:
202 |                 dot_name = f'{model}-{int(values["total_param"][j]*1000)}M'  # total_param is in billions
203 |             plt.annotate(dot_name, (x_val, y_val), textcoords="offset points",
204 |                          xytext=(0, 10), ha='left', size=7)
205 | 
206 |     plt.xlabel("Relative Cost")
207 |     plt.ylabel(f"BigBench-Lite ({shot}-shot)")
208 |     plt.title(f"Relative Cost vs BigBench-Lite ({shot}-shot)")
209 |     # Move the legend to the outside right of the main figure
210 |     plt.legend(handles=handles, loc='upper left', bbox_to_anchor=(1.10, 1), title="Model")
211 |     plt.subplots_adjust(right=0.75)
212 |     plt.grid(True)
213 |     # plt.show()
214 |     plt.savefig(f"../figure/bblite-{shot}-shot.pdf", dpi=300, bbox_inches="tight")
215 |     plt.savefig(f"../figure/bblite-{shot}-shot.png", dpi=300, bbox_inches="tight")
216 | 
217 | 
218 | 
219 |     # df = df.sort_values(by=['model_name', 'activated_parameters'])
220 |     # # Calculate the size of the dots based on activated parameters
221 |     # df['dot_size'] = df['activated_parameters']/1000000000
222 |     # # Get the unique model names
223 |     # unique_model_names = df['model_name'].unique()
224 |     # # Create a color palette based on unique model names
225 |     # color_palette = sns.color_palette("Set1", n_colors=len(df.groupby('model_name')))
226 |     # print(color_palette)
227 |     # # Set style using Seaborn
228 |     # sns.set(style="whitegrid")
229 |     # # Create the plot
230 |     # plt.figure(figsize=(10, 6))
231 |     # # Scatter plot for individual data points
232 |     # sns.scatterplot(
233 |     #     x='cost',
234 |     #     y='average',
235 |     #     hue='model_name',
236 |     #     palette=color_palette,
237 |     #     size='dot_size',  # Use dot size based on activated parameters
238 |     #     # sizes=(20, 200),  # Define the range of dot sizes
239 |     #     style='model_name',
240 |     #     markers=marker_shapes,
241 |     #     data=df
242 |     # )
243 |     #
244 |     # # Line plot to connect models with the same model_name
245 |     # for model_name, group_tmp in df.groupby('model_name'):
246 |     #     group = group_tmp.sort_values(by=['cost'])
247 |     #     sns.lineplot(
248 |     #         x='cost',
249 |     #         y='average',
250 |     #         data=group,
251 |     #         color=color_palette[unique_model_names.tolist().index(model_name)],
252 |     #         dashes=True #  if model_name.startswith('BIG-G') else False  # Use dashed line for models with specific names
253 |     #     )
254 |     # for i, (model_name, group_tmp) in enumerate(df.groupby('model_name')):
255 |     #     group = group_tmp.sort_values(by=['cost'])
256 |     #     sns.scatterplot(
257 |     #         x='cost',
258 |     #         y='average',
259 |     #         hue='model_name',
260 |     #         color_palette=color_palette[i:i+1],
261 |     #         style='model_name',
262 |     #         markers=marker_shapes,
263 |     #         data=group
264 |     #     )
265 |         # plt.plot(
266 |         #     group['cost'],
267 |         #     group['average'],
268 |         #     markers=marker_shapes,
269 |         #     linestyle='--',
270 |         #     label=model_name
271 |         # )
272 | 
273 |     # plt.xlabel('Cost')
274 |     # plt.ylabel('Average')
275 |     # plt.legend()
276 |     # plt.title('Average vs. Cost by Model')
277 |     # plt.grid(True)
278 |     # plt.show()
279 |     # print(asd)
280 | 
281 | 


--------------------------------------------------------------------------------
/eval/result_retrieval_bigbenchlite.py:
--------------------------------------------------------------------------------
  1 | import json
  2 | import subprocess
  3 | import os
  4 | import logging
  5 | import sys
  6 | import warnings
  7 | import pandas as pd
  8 | 
  9 | logging.basicConfig()
 10 | logger = logging.getLogger('result_retrieval_bigbenchlite')
 11 | logger.setLevel(logging.DEBUG)
 12 | 
 13 | 
 14 | #******************************************************************************************************************************************************************************************************************
 15 | #                                   *** organizations of the evaluation metric result files ***
 16 | 
 17 | #* The metric result files categorized by tasks, models, shots, and subtasks should be put under the same directory without hierarchy. 
 18 | #* File name should be organized like "bigbench:{task_name}.{mul/gen}.{some hyperparams}.{shot_num}.{model_name}.{subtask_name}-metrics.jsonl".
 19 | #* e.g. bigbench:bbq_lite_json.mul.t5_default_vocab.3_shot.all_examples.bbq_lite_json_age_ambig-metrics.jsonl
 20 | #* Note: .{subtask_name} in the file name is optional, since some tasks don't have subtasks. For a given task, the corresponding result files shouldn't mix subtask files and non-subtask files.
 21 | #******************************************************************************************************************************************************************************************************************
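
# Illustrative only: a hypothetical helper showing how a metric file name following the
# convention above can be split into its fields (it is not used by the retrieval code below).
def _parse_metric_filename(name: str) -> dict:
    stem = name[len("bigbench:"):-len("-metrics.jsonl")]
    parts = stem.split(".")
    fields = {"task": parts[0], "mul_or_gen": parts[1], "hyperparams": parts[2],
              "shots": parts[3], "model": parts[4]}
    if len(parts) > 5:
        fields["subtask"] = parts[5]
    return fields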
 22 | 
 23 | 
 24 | # The tasks to be retrieved in BIG-bench. If a task is not found, the corresponding row will be Nones.
 25 | _tasks = ["auto_debugging", "bbq_lite_json", "code_line_description", 
 26 |          "conceptual_combinations", "conlang_translation", "emoji_movie", 
 27 |          "formal_fallacies_syllogisms_negation", "hindu_knowledge", 
 28 |          "known_unknowns", "language_identification", "linguistics_puzzles", 
 29 |          "logic_grid_puzzle", "logical_deduction", "misconceptions_russian", 
 30 |          "novel_concepts", "operators", "parsinlu_reading_comprehension", 
 31 |          "play_dialog_same_or_different", "repeat_copy_logic", "strange_stories", 
 32 |          "strategyqa", "symbol_interpretation", "vitaminc_fact_verification", "winowhy"]
 33 | 
 34 | # The models whose results are to be retrieved for each task in BIG-bench. If a model is not found, the corresponding column will be Nones.
 35 | _models_repo = [
 36 | "BIG-G_2m_T=0",
 37 | "BIG-G_2m_T=1",
 38 | "BIG-G_16m_T=0",
 39 | "BIG-G_16m_T=1",
 40 | "BIG-G_53m_T=0",
 41 | "BIG-G_53m_T=1",
 42 | "BIG-G_125m_T=0",
 43 | "BIG-G_125m_T=1",
 44 | "BIG-G_244m_T=0",
 45 | "BIG-G_244m_T=1",
 46 | "BIG-G_422m_T=0",
 47 | "BIG-G_422m_T=1",
 48 | "BIG-G_1b_T=0",
 49 | "BIG-G_1b_T=1",
 50 | "BIG-G_2b_T=0",
 51 | "BIG-G_2b_T=1",
 52 | "BIG-G_4b_T=0",
 53 | "BIG-G_4b_T=1",
 54 | "BIG-G_8b_T=0",
 55 | "BIG-G_8b_T=1",
 56 | "BIG-G_27b_T=0",
 57 | "BIG-G_27b_T=1",
 58 | "BIG-G_128b_T=0",
 59 | "BIG-G_128b_T=1",
 60 | "BIG-G-sparse_2m",
 61 | "BIG-G-sparse_16m",
 62 | "BIG-G-sparse_53m",
 63 | "BIG-G-sparse_125m",
 64 | "BIG-G-sparse_244m",
 65 | "BIG-G-sparse_422m",
 66 | "BIG-G-sparse_1b",
 67 | "BIG-G-sparse_2b",
 68 | "BIG-G-sparse_4b",
 69 | "BIG-G-sparse_8b",
 70 | "GPT-3-3B", 
 71 | "GPT-3-6B", 
 72 | "GPT-3-13B", 
 73 | "GPT-3-200B",
 74 | "GPT-3-Small",
 75 | "GPT-3-Medium",
 76 | "GPT-3-Large",
 77 | "GPT-3-XL",
 78 | "PaLM_8b",
 79 | "PaLM_64b",
 80 | "PaLM_535b"
 81 | ]
 82 | 
 83 | # Our models whose results are to be retrieved for each task.
 84 | _models_ours = [
 85 |     "all_examples"
 86 | ]
 87 | 
 88 | # n-shot results wanted. For example, if 0-shot and 3-shot results are wanted,
 89 | # each cell of the retrieved results will be "score_0 score_3", which denotes the concatenated scores for the 0-shot and 3-shot results
 90 | _number_of_shots = [0,1,2,3,4,5]
 91 | 
 92 | # The path to the retrieved results
 93 | _file_to_save = '../../final_results/BIG-bench-Lite/retrieved_results.csv'
 94 | 
 95 | # The path to the results directory
 96 | _results_dir = '../../inference_results/BIG-bench-Lite-T0'
 97 | 
 98 | # The path to the BIG-bench tasks directory
 99 | _repo_dir = '../../BIG-bench/bigbench/benchmark_tasks'
100 | 
101 | 
102 | def find_repo_task_dir(repo_dir: str, tasks: list):
103 |     task_dirs = []
104 |     existing_tasks = []
105 |     repo_tasks_dir = os.path.abspath(repo_dir)
106 |     for task in tasks:
107 |         for i, dir_name in enumerate(os.listdir(repo_tasks_dir)):
108 |             if task == dir_name:
109 |                 task_dirs.append(os.path.join(repo_tasks_dir, dir_name))
110 |                 existing_tasks.append(dir_name)
111 |                 break
112 |             elif i == len(os.listdir(repo_tasks_dir)) - 1:
113 |                 warnings.warn(f"Task '{task}' not found in the repo tasks dir {repo_tasks_dir}.")
114 |                 task_dirs.append(None)
115 |                 existing_tasks.append(None)
116 |     assert len(task_dirs) == len(existing_tasks)
117 |     return task_dirs, existing_tasks
118 | 
119 | 
120 | def update_task_metainfo(task_metric_meta, model_json):
121 |     with open(model_json, 'r') as json_file:
122 |         results_dict = json.load(json_file)
123 |     scores_dict = {f"{D['subtask_description']}---{D['number_of_shots']}_shot": D for D in results_dict['scores']}
124 |     for key in scores_dict.keys():
125 |         if key not in task_metric_meta.keys():
126 |             task_metric_meta[key] = scores_dict[key]
127 |     return task_metric_meta
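# Note (illustrative): keys of task_metric_meta follow "{subtask_description}---{number_of_shots}_shot";
# for tasks with subtasks the description typically looks like "{task}:{subtask}", so a key might be
# something like "bbq_lite_json:bbq_lite_json_age_ambig---3_shot" (hypothetical example, not read from a real file).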
128 | 
129 | 
130 | def find_repo_task_model_json(task_dir: str, models: list):
131 |     model_jsons = []
132 |     task_metric_meta = {}
133 |     repo_task_result_dir = os.path.abspath(f"{task_dir}/results")
134 |     for model in models:
135 |         if task_dir is not None:
136 |             for i, file_name in enumerate(os.listdir(repo_task_result_dir)):
137 |                 if file_name.startswith('scores') and file_name.endswith('.json') and model in file_name:
138 |                     model_jsons.append(os.path.join(repo_task_result_dir, file_name))
139 |                     task_metric_meta = update_task_metainfo(task_metric_meta, os.path.join(repo_task_result_dir,
140 |                                                                                            file_name))  # update the task meta info regarding that in the repo
141 |                     break
142 |                 elif i == len(os.listdir(repo_task_result_dir)) - 1:
143 |                     warnings.warn(
144 |                         f"Result of model '{model}' not found in the repo task results dir {repo_task_result_dir}.")
145 |                     model_jsons.append(None)
146 |         else:
147 |             model_jsons.append(None)
148 |     assert len(models) == len(model_jsons)
149 | 
150 |     return model_jsons, task_metric_meta
151 | 
152 | 
153 | def min_max_normalize(score, min_score, max_score):
154 |     normalized_score = (score - min_score) / (max_score - min_score)
155 |     return normalized_score
156 | 
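# Example (for illustration): with BIG-bench task metadata, low_score and high_score are roughly the
# random-chance and maximum scores, so min_max_normalize(35.0, 25.0, 100.0) = (35.0 - 25.0) / (100.0 - 25.0) ≈ 0.133.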
157 | 
158 | def compute_scores_repo(models: list, model_jsons: list, number_of_shots: list):
159 |     '''
160 |     Returns a dict mapping each model to a string of normalized scores (one per requested shot number), joined by spaces.
161 |     '''
162 |     normalized_scores_allmodel = {}
163 |     for model_json, model in zip(model_jsons, models):
164 |         if model_json is not None:
165 |             with open(model_json, 'r') as json_file:
166 |                 results_dict = json.load(json_file)
167 |             scores_dict = {f"{D['subtask_description']}---{D['number_of_shots']}_shot": D for D in
168 |                            results_dict['scores']}
169 |             normalized_scores = [{} for _ in
170 |                                  number_of_shots]  # the computed n-shot scores for this model in this task, n corresponds to the n-th digit in this list
171 |             normalized_scores_existing = [0] * len(
172 |                 number_of_shots)  # accumulated normalized task-level score for each shot number (n-th entry corresponds to the n-th shot setting)
173 |             for si, shot_num in enumerate(number_of_shots):
174 |                 for subtask, score_meta in scores_dict.items():
175 |                     if score_meta['number_of_shots'] == shot_num:
176 |                         raw_score = score_meta['score_dict'][score_meta['preferred_score']]
177 |                         if len(subtask.split(':')) == 2:
178 |                             high_score = score_meta['high_score']
179 |                             low_score = score_meta['low_score']
180 |                             normalized_scores[si][subtask] = raw_score
181 |                             # normalized_scores[si][subtask] = raw_score
182 |                         else:
183 |                             high_score = score_meta['high_score']
184 |                             low_score = score_meta['low_score']
185 |                             normalized_scores_existing[si] += min_max_normalize(raw_score, low_score, high_score)
186 |                             normalized_scores[si]['high_score'] = score_meta['high_score']
187 |                             normalized_scores[si]['low_score'] = score_meta['low_score']
188 |                             if len({k: v for k, v in normalized_scores[si].items() if
189 |                                     k != 'high_score' and k != 'low_score'}.keys()) == 0:
190 |                                 normalized_scores[si][subtask] = raw_score
191 | 
192 |             normalized_scores = [min_max_normalize(float(
193 |                 f"{sum(list({k: v for k, v in subtask_normscores_shot.items() if k != 'high_score' and k != 'low_score'}.values())) / len(list({k: v for k, v in subtask_normscores_shot.items() if k != 'high_score' and k != 'low_score'}.values())):.{4}f}"),
194 |                                                    subtask_normscores_shot['low_score'],
195 |                                                    subtask_normscores_shot['high_score'])
196 |                                  if len(list({k: v for k, v in subtask_normscores_shot.items() if
197 |                                               k != 'high_score' and k != 'low_score'}.values())) > 0 else 0.0
198 |                                  for subtask_normscores_shot in normalized_scores]
199 |             normalized_scores_allmodel[model] = ' '.join(
200 |                 [str(f"{100 * float(num):.{2}f}") for num in normalized_scores_existing])
201 |         else:
202 |             normalized_scores_allmodel[model] = 'None'
203 |     return normalized_scores_allmodel
204 | 
205 | 
206 | def retrieve_and_write_repo(file_to_save: str, repo_dir: list,
207 |                             models: list, tasks: list,
208 |                             number_of_shots: list = [3]):
209 |     task_dirs, _ = find_repo_task_dir(repo_dir, tasks)
210 |     normalized_scores = {task: [] for task in tasks}
211 |     task_metric_metainfos = {task: [] for task in tasks}
212 |     for task_dir, task in zip(task_dirs, tasks):
213 |         model_jsons, task_metric_meta = find_repo_task_model_json(task_dir, models)
214 |         normalized_scores[task] = compute_scores_repo(models, model_jsons, number_of_shots)
215 |         task_metric_metainfos[task] = task_metric_meta
216 | 
217 |     def custom_sort_key(item):
218 |         return tasks.index(item[0])
219 | 
220 |     normalized_scores = dict(sorted(normalized_scores.items(), key=custom_sort_key))
221 |     Header = [['model'] + tasks]
222 |     score_mat = [models] + [inner_dict.values() for outer_key, inner_dict in normalized_scores.items()]
223 |     df1 = pd.DataFrame(Header)
224 |     df2 = pd.DataFrame(score_mat).transpose()
225 |     df = pd.concat([df1, df2], ignore_index=True)
226 |     df.to_csv(file_to_save, header=False, index=False)
227 |     return df, task_metric_metainfos
228 | 
229 | 
230 | def find_result_task_jsonsubsets(results_dir: str, tasks: list):
231 |     tasks_origin = tasks
232 |     tasks = [f'{task}.mul' for task in tasks] + [f'{task}.gen' for task in tasks]
233 |     task_dir = {task: [] for task in tasks}
234 |     task_dir_origin = {task: [] for task in tasks_origin}
235 |     for task in tasks:
236 |         for i, result_filename in enumerate(os.listdir(results_dir)):
237 |             if task in result_filename:
238 |                 task_dir[task].append(os.path.join(results_dir, result_filename))
239 |         if task_dir[task] == []:
240 |             task_dir.pop(task)
241 | 
242 |     for task in tasks_origin:
243 |         for i, result_filename in enumerate(os.listdir(results_dir)):
244 |             if task in result_filename:
245 |                 task_dir_origin[task].append('_')
246 |         if task_dir_origin[task] == []:
247 |             warnings.warn(f"Task '{task}' not found in the result dir {results_dir}.")
248 |     return task_dir
249 | 
250 | 
251 | def find_result_taskmodel_json(task, task_jsons: list, models: list, number_of_shots: list):
252 |     model_jsons = {model: {shot: [] for shot in number_of_shots} for model in models}
253 | 
254 |     if task_jsons != []:
255 |         for model in models:
256 |             model_found = False
257 |             for i, file_name in enumerate(task_jsons):
258 |                 for si, number_of_shot in enumerate(number_of_shots):
259 |                     if file_name.endswith(
260 |                             'metrics.jsonl') and model in file_name and f"{number_of_shot}_shot" in file_name:
261 |                         model_found = True
262 |                         model_jsons[model][number_of_shot].append(file_name)
263 |             if not model_found:
264 |                 warnings.warn(f"Result of model '{model}' not found in the {task} subset.")
265 | 
266 |     assert len(models) == len(model_jsons.keys())
267 | 
268 |     return model_jsons
269 | 
270 | 
271 | def compute_scores_result(task: str, models: list, model_jsons: list, number_of_shots: list,
272 |                           task_metric_metainfos: dict):
273 |     assert len(models) == len(model_jsons.keys())
274 | 
275 |     normalized_scores_allmodel = {}
276 |     for model in models:
277 |         model_found = False
278 |         normalized_scores = [{} for _ in number_of_shots]
279 |         for i, number_of_shot in enumerate(number_of_shots):
280 |             if model_jsons[model][number_of_shot] != []:
281 |                 model_found = True
282 |                 for subtask_file in model_jsons[model][number_of_shot]:
283 |                     subtask_name = subtask_file.split('.')[-2].replace('-metrics', '')
284 |                     if subtask_name != model:
285 |                         key_name = f"{task.replace('.mul', '').replace('.gen', '')}:{subtask_name}---{number_of_shot}_shot".replace(
286 |                             'atikamp__', 'atikamp?_')
287 |                     else:
288 |                         key_name = f"{task.replace('.mul', '').replace('.gen', '')}---{number_of_shot}_shot".replace(
289 |                             'atikamp__', 'atikamp?_')
290 |                     key_name_task = f"{task.replace('.mul', '').replace('.gen', '')}---{number_of_shot}_shot".replace(
291 |                         'atikamp__', 'atikamp?_')
292 |                     preferred_score = task_metric_metainfos[task.replace('.mul', '').replace('.gen', '')][key_name][
293 |                         'preferred_score']
294 |                     high_score = task_metric_metainfos[task.replace('.mul', '').replace('.gen', '')][key_name_task][
295 |                         'high_score']
296 |                     low_score = task_metric_metainfos[task.replace('.mul', '').replace('.gen', '')][key_name_task][
297 |                         'low_score']
298 |                     with open(subtask_file, 'r') as json_file:
299 |                         results_dict = json.load(json_file)
300 |                     if preferred_score in results_dict.keys():
301 |                         score = results_dict[preferred_score]
302 |                     else:
303 |                         warnings.warn(f"Task '{task}' does not have the {preferred_score} metric.")
304 |                         return None
305 |                     normalized_scores[i][key_name] = score
306 |                     normalized_scores[i]['low_score'] = low_score
307 |                     normalized_scores[i]['high_score'] = high_score
308 |                     # if 'logical_deduction' in subtask_file:
309 |                     #     print(key_name, score, low_score, high_score)
310 |         for shot_dir in normalized_scores:
311 |             for subtask in {k: v for k, v in shot_dir.items() if k != 'high_score' and k != 'low_score'}.keys():
312 |                 if ':' not in subtask:
313 |                     assert len({k: v for k, v in shot_dir.items() if
314 |                                 k != 'high_score' and k != 'low_score'}.keys()) == 1, f"For task {task}, which doesn't have subtasks, there shouldn't be any other subtask result files."
315 |         #  For each existing shot, average the subtask scores for that shot and normalize; if a shot does not exist, set the corresponding score to 0.
316 |         normalized_scores = [min_max_normalize(float(
317 |             f"{sum(list({k: v for k, v in subtask_normscores_shot.items() if k != 'high_score' and k != 'low_score'}.values())) / len(list({k: v for k, v in subtask_normscores_shot.items() if k != 'high_score' and k != 'low_score'}.values())):.{4}f}"),
318 |                                                subtask_normscores_shot['low_score'],
319 |                                                subtask_normscores_shot['high_score'])
320 |                              if len(list({k: v for k, v in subtask_normscores_shot.items() if
321 |                                           k != 'high_score' and k != 'low_score'}.values())) > 0 else 0.0
322 |                              for subtask_normscores_shot in normalized_scores]
323 |         if model_found:
324 |             normalized_scores_allmodel[model] = ' '.join(
325 |                 [str(f"{100 * float(num):.{2}f}") for num in normalized_scores])
326 |         else:
327 |             normalized_scores_allmodel[model] = 'None'
328 |     return normalized_scores_allmodel
329 | 
330 | 
331 | def retrieve_and_write_results(file_to_save, results_dir, models, tasks, number_of_shots, task_metric_metainfos):
332 |     task_dir = find_result_task_jsonsubsets(results_dir, tasks)
333 |     normalized_scores = {task: {} for task in task_dir.keys()}
334 |     for task, task_jsons in task_dir.items():
335 |         model_jsons = find_result_taskmodel_json(task, task_jsons, models,
336 |                                                  number_of_shots)  # a dict consisting of hierarchically organized json files by the models, shots, and subtasks
337 |         _r = compute_scores_result(task, models, model_jsons, number_of_shots, task_metric_metainfos)
338 |         if _r is not None:
339 |             normalized_scores[task] = _r
340 |     normalized_scores = {k.replace('.mul', '').replace('.gen', ''): v for k, v in normalized_scores.items() if v != {}}
341 | 
342 |     def custom_sort_key(item):
343 |         return tasks.index(item[0])
344 | 
345 |     normalized_scores = dict(sorted(normalized_scores.items(), key=custom_sort_key))
346 | 
347 |     assert len(normalized_scores.keys()) == len(
348 |         tasks), f"len(normalized_scores.keys()), len(tasks): {len(normalized_scores.keys())}, {len(tasks)}"
349 |     score_mat = [models] + [inner_dict.values() for outer_key, inner_dict in normalized_scores.items()]
350 |     df = pd.DataFrame(score_mat).transpose()
351 |     df.to_csv(file_to_save, header=False, index=False, mode='a')
352 | 
353 |     return df
354 |     
355 | 
356 | def result_retrieval_bigbenchlite(file_to_save:str, results_dir: str, repo_dir:list, 
357 |                                   models_repo: list, models_ours: list, tasks: list, 
358 |                                   number_of_shots: list = [3]):
359 |     # retrieve and write from the repository results
360 |     # subprocess.run("git clone git@github.com:google/BIG-bench.git".split())
361 |     _, task_metric_metainfos = retrieve_and_write_repo(file_to_save, repo_dir, models_repo, tasks, number_of_shots)
362 |     retrieve_and_write_results(file_to_save, results_dir, models_ours, tasks, number_of_shots, task_metric_metainfos)
363 |     
364 |             
365 | 
366 | 
367 | 
368 | 
369 | if __name__ == "__main__":
370 |     result_retrieval_bigbenchlite(_file_to_save, 
371 |                                   _results_dir,
372 |                                   _repo_dir,
373 |                                   _models_repo, 
374 |                                   _models_ours, 
375 |                                   _tasks,
376 |                                   _number_of_shots)


--------------------------------------------------------------------------------
/eval/triqa_plot.py:
--------------------------------------------------------------------------------
 1 | import matplotlib.pyplot as plt
 2 | from matplotlib.lines import Line2D
 3 | 
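# Each tuple appears to be: (model name, activated params in B, total params in B, training tokens in B,
# relative training cost, TriviaQA 0-shot EM). The relative cost roughly tracks
# activated params (B) x training tokens (T), e.g. 7 x 0.3 = 2.1 for OPT-7B.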
 4 | data = [
 5 |     ("OPT", 7, 7, 300, 2.1, 22.7),
 6 |     ("OPT", 13, 13, 300, 3.9, 28.2),
 7 |     ("Pythia", 7, 7, 300, 2.1, 19.8),
 8 |     ("Pythia", 12, 12, 300, 3.6, 22.3),
 9 |     ("GPTJ", 6, 6, 400, 2.4, 23.4),
10 |     # ("GPT-NeoX", 20, 20, 475, 9.5, 34.7),
11 |     ("MPT", 7, 7, 1000, 7, 34.3),
12 |     ("LLaMA", 7, 7, 1000, 7, 44.3),
13 |     ("GLaM-Dense", 0.1, 0.1, 600, 0.6, 2.3),
14 |     ("GLaM-Dense", 2, 2, 600, 1.2, 27),
15 |     ("GLaM-Dense", 8, 8, 600, 4.8, 48.1),
16 |     ("GLaM-MoE", 0.1, 2, 600, 0.6, 9.4),
17 |     ("GLaM-MoE", 2, 27, 600, 1.2, 44),
18 |     ("GLaM-MoE", 8, 143, 600, 4.8,  55.1),
19 |     ("Gopher", 1, 1, 300, 0.3, 6.5),
20 |     ("Gopher", 7, 7, 300, 2.1, 19.9),
21 |     ("PaLM", 8, 8, 780, 6.2, 39.5),
22 |     ("GPT-3", 3, 3, 300, 0.9, 31.3),
23 |     ("GPT-3", 7, 7, 300, 2.1, 38.7),
24 |     ("GPT-3", 13, 13, 300, 3.9, 41.8),
25 |     ("OpenMoE", 0.2, 0.5, 200, 0.04, 12.8),
26 |     ("OpenMoE", 2, 8, 200, 0.4, 29.2),
27 | ]
28 | 
29 | # Group data by model name and calculate the average values
30 | model_data = {}
31 | for model, act_param, total_param, tokens, cost, result in data:
32 |     if model not in model_data:
33 |         model_data[model] = {"cost": [], "result": [], "total_param": [], "act_param": []}
34 |     model_data[model]["cost"].append(cost)
35 |     model_data[model]["result"].append(result)
36 |     model_data[model]["total_param"].append(total_param)
37 |     model_data[model]["act_param"].append(act_param)
38 | 
39 | # Plotting
40 | plt.figure(figsize=(12, 6))
41 | colors = plt.cm.get_cmap("tab20", len(model_data))
42 | 
43 | handles = []
44 | for i, (model, values) in enumerate(model_data.items()):
45 |     x = values["cost"]
46 |     y = values["result"]
47 |     sizes = [param * 20 for param in values["act_param"]]  # Adjust the scaling factor as needed
48 |     total_sizes = [param * 20 for param in values["total_param"]]
49 |     color = 'red' if model == "OpenMoE" else colors(i)
50 |     marker= 'o'
51 |     if "MoE" in model:
52 |         plt.scatter(x, y, label=model, color='lightgray', s=total_sizes, marker=marker)
53 |     plt.scatter(x, y, label=model, color=color, s=sizes, marker=marker)
54 |     # handle = mpatches.Patch(color=colors(i), label=f'{model}', marker='o')
55 |     handle = Line2D([0], [0], marker=marker, color='w', label=f'{model}', markersize=12,
56 |            markerfacecolor=color)
57 |     handles.append(handle)
58 |     # Adding dashed lines to connect dots of the same model
59 |     plt.plot(x, y, linestyle='dashed', color=color)
60 |     # Annotate each dot with the model name and total parameters
61 |     for j, (x_val, y_val) in enumerate(zip(x, y)):
62 |         plt.annotate(f'{model}-{values["total_param"][j]}B', (x_val, y_val), textcoords="offset points", xytext=(0,10), ha='left', size=7)
63 | 
64 | plt.xlabel("Relative Cost")
65 | plt.ylabel("TriviaQA (0-shot EM)")
66 | plt.title("Relative Cost vs TriviaQA (0-shot EM)")
67 | # Move the legend to the outside right of the main figure
68 | plt.legend(handles=handles, loc='upper left', bbox_to_anchor=(1.10, 1), title="Model")
69 | plt.subplots_adjust(right=0.75)
70 | plt.grid(True)
71 | # plt.show()
72 | plt.savefig("../figure/triqa.pdf", dpi=300, bbox_inches="tight")
73 | plt.savefig("../figure/triqa.png", dpi=300, bbox_inches="tight")
74 | 


--------------------------------------------------------------------------------
/figure/bblite-3-shot.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XueFuzhao/OpenMoE/ad4c65cc5828721835c4b064504e16e81444e5d2/figure/bblite-3-shot.pdf


--------------------------------------------------------------------------------
/figure/bblite-3-shot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XueFuzhao/OpenMoE/ad4c65cc5828721835c4b064504e16e81444e5d2/figure/bblite-3-shot.png


--------------------------------------------------------------------------------
/figure/mt_bench_turn_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XueFuzhao/OpenMoE/ad4c65cc5828721835c4b064504e16e81444e5d2/figure/mt_bench_turn_0.png


--------------------------------------------------------------------------------
/figure/mt_bench_turn_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XueFuzhao/OpenMoE/ad4c65cc5828721835c4b064504e16e81444e5d2/figure/mt_bench_turn_1.png


--------------------------------------------------------------------------------
/figure/mt_bench_turn_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XueFuzhao/OpenMoE/ad4c65cc5828721835c4b064504e16e81444e5d2/figure/mt_bench_turn_2.png


--------------------------------------------------------------------------------
/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XueFuzhao/OpenMoE/ad4c65cc5828721835c4b064504e16e81444e5d2/logo.jpg


--------------------------------------------------------------------------------
/paper/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 


--------------------------------------------------------------------------------
/paper/paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XueFuzhao/OpenMoE/ad4c65cc5828721835c4b064504e16e81444e5d2/paper/paper.pdf


--------------------------------------------------------------------------------
/results.md:
--------------------------------------------------------------------------------
 1 | # Experimental Results
 2 | 
 3 | ## MT-Bench
 4 | 
 5 | ### Overall
 6 | Note: The baselines use up to 2x the training FLOPs of OpenMoE-8B. OpenMoE-8B has FLOPs similar to a 1.6B dense model.
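Reading the table, the cost columns appear to follow a simple rule of thumb: training cost ≈ dense-equivalent activated parameters (B) × training tokens (T), e.g. GPT-J-6B: 6 × 0.4 ≈ 2.4, and inference cost ≈ dense-equivalent activated parameters (B); this is a rough interpretation of the numbers, not an official definition.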
 7 | 
 8 | <img src="figure/mt_bench_turn_0.png" width="50%" alt="MT-Bench-Turn-0">
 9 | 
10 | | Model            | Training Cost | Inference Cost | MT-Bench 1st Turn | MT-Bench 2nd Turn | MT-Bench Avg |
11 | |------------------|---------------|-----------------|-------------------|-------------------|--------------|
12 | | GPT-J-6B (0.4T)         | 2.4           | 6               | 2.51              | 2.35              | 2.43         |
13 | | TinyLLaMA-1.1B (3T)   | 3.3           | 1.1             | 4.08              | 2.54              | 3.31         |
14 | | OpenLLaMA-3B (1T)    | 3             | 3               | 4.36              | **3.62**              | **3.99**         |
15 | | OpenMoE-8B/32E (1.1T)| 1.8           | 1.6             | **4.69**              | 3.26              | **3.98**         |
16 | 
17 | 
18 | ### First Turn
19 | 
20 | <img src="figure/mt_bench_turn_1.png" width="50%" alt="MT-Bench-Turn-1">
21 | 
22 | ### Second Turn
23 | 
24 | <img src="figure/mt_bench_turn_2.png" width="50%" alt="MT-Bench-Turn-2">
25 | 


--------------------------------------------------------------------------------
/script/inference_on_multi_devices.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | from accelerate import infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
  3 | from transformers import AutoTokenizer, T5Tokenizer, AutoConfig, AutoModelForCausalLM, LogitsProcessorList, LogitsProcessor
  4 | from typing import List, Optional
  5 | from huggingface_hub import snapshot_download
  6 | 
  7 | 
  8 | class StopAfterEosTextGenerated(LogitsProcessor):
  9 |         """Logits processor (to use with HuggingFace `generate()` method :
 10 |         https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/
 11 |         text_generation#transformers.generation_utils.GenerationMixin).
 12 | 
 13 |        Sometimes our model outputs '▁</', 's', '>' separately as the stopping signal (not '▁</s>' as a whole),
 14 |        which cannot be captured by a single EOS token and can cause a very long generation.
 15 |        This logits processor forces generation to stop after '▁</', 's', '>'.
 16 | 
 17 |         Args:
 18 |             base_len (int): Size of the given context. Used to know if this is
 19 |                 the first character to generate.
 20 |             eos_token_id (int): ID of the EOS token.
 21 |         """
 22 |         def __init__(self, base_len: int, eos_token_id: int):
 23 |             super().__init__()
 24 |             self.base_len = base_len
 25 |             self.eos_token_id = eos_token_id
 26 | 
 27 |         def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
 28 |             if input_ids.size(1) > self.base_len:
 29 |                 forced_eos = torch.full((scores.size(1),), -float("inf")).to(scores.device)
 30 |                 forced_eos[self.eos_token_id] = 0
 31 |                 # If the last three tokens of input_ids are the stop_token_ids, an EOS token will be forced to generate afterwards
 32 |                 stop_token_ids = torch.Tensor([15501, 281, 926]).to(scores.device)  # ids for tokens '▁</', 's', '>'
 33 |                 stop_sample_ids = torch.eq(input_ids[:, -len(stop_token_ids): ], stop_token_ids).all(dim=1)
 34 |                 scores[stop_sample_ids] = forced_eos
 35 |             return scores
 36 | 
 37 | def inference(model, tokenizer, input_strs, gen_kwargs,
 38 |               add_special_tokens=True, split_special_tokens=False, output_only=True, verbose=False):
 39 | 
 40 |     model = model.eval()
 41 | 
 42 |     # Tokenization
 43 |     inputs = tokenizer.batch_encode_plus(input_strs,
 44 |                                          padding='longest',
 45 |                                          add_special_tokens=add_special_tokens,
 46 |                                          split_special_tokens=split_special_tokens,
 47 |                                          return_tensors="pt")
 48 |     input_ids = inputs.input_ids.to(model.device)
 49 |     attention_mask = inputs.attention_mask.to(model.device)
 50 |     base_len = inputs.input_ids.size(-1)
 51 |     if verbose:
 52 |         print("Input Tokens:\n", input_ids)
 53 |         print("Num of Input Tokens: ", base_len)
 54 |         print("Attention Mask:\n", attention_mask)
 55 |     logits_processor = LogitsProcessorList([StopAfterEosTextGenerated(base_len, tokenizer.eos_token_id)])
 56 | 
 57 |     output_ids = model.generate(input_ids=input_ids,
 58 |                                 attention_mask=attention_mask,
 59 |                                 bos_token_id=tokenizer.pad_token_id,
 60 |                                 eos_token_id=tokenizer.eos_token_id,
 61 |                                 pad_token_id=tokenizer.pad_token_id,
 62 |                                 logits_processor=logits_processor,
 63 |                                 **gen_kwargs)
 64 |     if output_only:  # Only preserve output tokens
 65 |         output_ids = output_ids[:, input_ids.size(1):]
 66 |     if verbose:
 67 |         print("Generated Tokens:\n", output_ids)
 68 |     output_txts = tokenizer.batch_decode(output_ids,
 69 |                                          clean_up_tokenization_spaces=True,
 70 |                                          skip_special_tokens=False)
 71 |     return output_ids, output_txts
 72 | 
 73 | def apply_llama_chat_template(tokenizer, input_strs, sys_prompt):
 74 |     # Use LLaMA's chat template (slightly different from the original at the beginning; we may switch to the standard LLaMA prompt template later)
 75 |     # input_strs = [('user_input', 'user'), ('AI_response', 'assistant'), ...]
 76 |     tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + message['content'] + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' '  + message['content'] + ' ' + eos_token }}{% endif %}{% endfor %}"
 77 |     system_prompt = {'content': sys_prompt, 'role': 'system'}
 78 |     chat = [system_prompt] + [{'content': input_str, 'role': role} for input_str, role in input_strs]
 79 |     input_str = tokenizer.apply_chat_template(chat,
 80 |                                               tokenize=False,
 81 |                                               add_generation_prompt=True)
 82 |     return input_str
 83 | 
 84 | if __name__ == "__main__":
 85 |         # @markdown 1. Path to the checkpoint repo
 86 |         pytorch_checkpoint_path = "OrionZheng/openmoe-8b-chat"#@param {type:"string"}
 87 |         #@markdown 2. (If any)Specify GPUs you want to use.
 88 |         #@markdown
 89 |         #@markdown - If a single GPU's memory is not enough, you can enter the ids of multiple GPUs (separated by commas). During inference, GPUs will be filled up sequentially.
 90 |         available_gpu_ids_str = "0" # @param ["", "0", "0,1", "0,1,2"] {allow-input: true}
 91 |         #@markdown - Specify available memory of each GPU
 92 |         #@markdown   - Leave some margin for data and activations.
 93 |         #@markdown For example, we used 38GiB of GPU memory on an A100 (40GB)
 94 |         memory_per_gpu = "38GiB" # @param ["", "38GiB"] {allow-input: true}
 95 |         #@markdown 3. Specify available CPU RAM
 96 |         #@markdown
 97 |         #@markdown - The Colab CPU High-RAM Runtime has 51GiB RAM
 98 |         cpu_memory = '50GiB' #@param ["50GiB"] {allow-input: true}
 99 |         # @markdown 4. Specify the model parameters' precision
100 | 
101 |         # @markdown - The CPU runtime only supports inference in float32 precision
102 | 
103 |         # @markdown - The `bfloat16` is only available on A100 Colab runtime
104 | 
105 |         # @markdown - Please use float32/bfloat16 for inference. We observed issues with the model output when running in float16, which may be due to underflow caused by our large vocabulary size.
106 |         model_dtype = 'bfloat16' #@param ["float32", "bfloat16"]
 107 |         #@markdown (Not recommended, very slow) Offload model weights to CPU memory if GPU memory is insufficient, then offload to disk if CPU memory is also insufficient.
108 |         offload = False #@param {type:"boolean"}
109 |         
110 |         input_str = "What is the title of the last Harry Potter novel, published in 2007?" # @param [] {allow-input: true}
111 |         input_strs = [input_str]
112 |         gen_strategy = "greedy" #@param ["greedy", "top_p"]
 113 |         #@markdown Please select the prompt template if a chat model is being used. For the raw language model, leave this field blank.
114 |         prompt_template = "openmoe" #@param ["openmoe", ""]
115 |         max_new_tokens = 32 #@param {type:"slider", min:1, max:512, step:1}
116 |         debug_verbose = True #@param {type:"boolean"}
117 |         cache_dir = "./"
118 |         gen_kwargs = {
119 |                 "greedy": {"do_sample": False, "num_beams": 1, "max_new_tokens": max_new_tokens},  # Greedy Search
120 |                 "top_p": {"do_sample": True, "temperature": 0.5, "top_p": 0.8, "max_new_tokens": max_new_tokens},  # Top-p Sampling
121 |             }
122 |         
123 |         if torch.cuda.is_available():
124 |             cuda_list = available_gpu_ids_str.split(',')
125 |         else:
126 |             available_gpu_ids_str, memory_per_gpu = "", ""
127 |             model_dtype = "float32"
128 |             cuda_list = []
129 | 
130 |         no_split_module_classes = "OpenMoeDecoderLayer"
131 | 
132 |         # 1. Allocate Devices for Inference
133 |         available_memory = {int(cuda): memory_per_gpu for cuda in cuda_list}
134 |         available_memory['cpu'] = cpu_memory
135 |         print('Available Devices and Memory: ', available_memory)
136 | 
137 |         # 2. Load the Model (init with empty weight to save memory)
138 |         config = AutoConfig.from_pretrained(pytorch_checkpoint_path)
139 |         weights_location = snapshot_download(repo_id=pytorch_checkpoint_path,
140 |                                              cache_dir=cache_dir)
141 |         with init_empty_weights():
142 |             model = AutoModelForCausalLM.from_config(config,
143 |                                                      torch_dtype=eval(f'torch.{model_dtype}'),
144 |                                                      trust_remote_code=True)
145 |         print('Model dtype: ', model.dtype)
146 |         device_map = infer_auto_device_map(model,
147 |                                            max_memory=available_memory,
148 |                                            no_split_module_classes=no_split_module_classes)
149 |         print('Inferred Device Map: \n', device_map)
150 |         if offload:
151 |             model = load_checkpoint_and_dispatch(model, weights_location,
152 |                                                  device_map=device_map,
153 |                                                  offload_folder="offload",
154 |                                                  offload_state_dict=True,
155 |                                                  dtype=eval(f'torch.{model_dtype}'),
156 |                                                  no_split_module_classes=[no_split_module_classes])
157 |         else:
158 |             model = load_checkpoint_and_dispatch(model, weights_location,
159 |                                                  device_map=device_map,
160 |                                                  dtype=eval(f'torch.{model_dtype}'),
161 |                                                  no_split_module_classes=[no_split_module_classes])
162 |         print('Fine-grained Device Map: \n', model.hf_device_map)
163 |         
164 | 
165 | 
166 |         # 3. Load the Tokenizer 
167 |         tokenizer = AutoTokenizer.from_pretrained(pytorch_checkpoint_path, trust_remote_code=True)
168 | 
169 |         # 4. Inference
170 |         final_input_strs = []
171 |         for input_str in input_strs:
172 |             if prompt_template == "openmoe":
173 |                 SYS_LLAMA = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature."
174 |                 input_str = apply_llama_chat_template(tokenizer,
175 |                                                       [(input_str, 'user')],
176 |                                                       sys_prompt=SYS_LLAMA)
177 |             final_input_strs.append(input_str)
178 |         print("=========== The Actual Input =============")
179 |         [print(i) for i in final_input_strs]
180 | 
181 |         output_ids, output_txts = inference(model, tokenizer, final_input_strs, gen_kwargs[gen_strategy],
182 |                                             verbose=debug_verbose)
183 | 
184 |         print("============== Output Text ===============")
185 |         for output_txt in output_txts:
186 |             print(output_txt.split('</s>')[0])


--------------------------------------------------------------------------------
/script/run_eval.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | sudo apt update
 4 | sudo apt install -y python3.9 python3.9-venv
 5 | python3.9 -m venv openmoe_venv
 6 | 
 7 | source openmoe_venv/bin/activate
 8 | python3 -m pip install -U pip setuptools wheel ipython
 9 | python3 -m pip install --upgrade pip
10 | python3 -m pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20230724-py3-none-any.whl
11 | 
12 | pip install git+https://github.com/google-research/jestimator
13 | pip install protobuf==3.20.3
14 | git clone --branch=main https://github.com/XueFuzhao/t5x
15 | cd t5x
16 | python3 setup.py install
17 | 
18 | pip install flax
19 | 
20 | echo y | python3 -m pip uninstall t5[gcp]
21 | echo y | python3 -m pip uninstall t5
22 | git clone --branch=main https://github.com/XueFuzhao/text-to-text-transfer-transformer.git
23 | cd text-to-text-transfer-transformer
24 | python3 setup.py install
25 | 
26 | echo y | python3 -m pip uninstall seqio
27 | echo y | python3 -m pip uninstall seqio-nightly
28 | git clone  --branch=main https://github.com/XueFuzhao/seqio.git
29 | cd seqio
30 | python3 setup.py install
31 | cd ../..
32 | git clone  --branch=main https://github.com/XueFuzhao/flaxformer.git
33 | cd flaxformer
34 | python3 setup.py install
35 | 
36 | python3 -m pip install gast
37 | python3 -m pip install astunparse
38 | python3 -m pip install flatbuffers
39 | python3 -m pip install tensorboard
40 | python3 -m pip install keras
41 | python3 -m pip install tensorflow_estimator
42 | python3 -m pip install libcst
43 | python3 -m pip install portalocker
44 | python3 -m pip install tabulate
45 | python3 -m pip install colorama
46 | python3 -m pip install lxml
47 | python3 -m pip install joblib
48 | python3 -m pip install threadpoolctl
49 | python3 -m pip install tfds-nightly==4.6.0.dev202210040045
50 | # python3 -m pip install tensorflow-datasets==4.1.0
51 | python3 -m pip install h5py
52 | 
53 | 
54 | cd ~
55 | git clone https://github.com/google/aqt.git
56 | cd aqt
57 | python3 setup.py install
58 | 
59 | cd ~
60 | git clone https://github.com/google/BIG-bench.git
61 | cd BIG-bench
62 | python3 setup.py sdist
63 | python3 -m pip install -e .
64 | 
65 | cd ~
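# Set YOUR_BUCKET_NAME to your own GCS bucket before running the exports below, e.g.
#   YOUR_BUCKET_NAME=my-openmoe-bucket   # hypothetical bucket name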
66 | export GOOGLE_CLOUD_BUCKET_NAME=${YOUR_BUCKET_NAME}
67 | export TFDS_DATA_DIR=gs://${YOUR_BUCKET_NAME}
68 | export MODEL_DIR=gs://${YOUR_BUCKET_NAME}/openmoe_8b/checkpoint_100000
69 | export OUTPUT_DIR=gs://${YOUR_BUCKET_NAME}/openmoe_8b/checkpoint_100000_eval
70 | export T5X_DIR="./t5x"
71 | 
72 | 
73 | python3  ${T5X_DIR}/t5x/eval.py \
74 |   --gin_file="t5x/examples/t5/t5_1_1/examples/openmoe_large_eval_bblite.gin" \
75 |   --gin.CHECKPOINT_PATH=\"${MODEL_DIR}\" \
76 |   --gin.EVAL_OUTPUT_DIR=\"${OUTPUT_DIR}\" \
77 |   --tfds_data_dir=${TFDS_DATA_DIR}
78 | 


--------------------------------------------------------------------------------
/script/run_pretrain.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | sudo apt update
 4 | sudo apt install -y python3.9 python3.9-venv
 5 | python3.9 -m venv openmoe_venv
 6 | 
 7 | source openmoe_venv/bin/activate
 8 | python3 -m pip install -U pip setuptools wheel ipython
 9 | python3 -m pip install --upgrade pip
10 | python3 -m pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20230724-py3-none-any.whl
11 | 
12 | pip install git+https://github.com/google-research/jestimator
13 | pip install protobuf==3.20.3
14 | git clone --branch=main https://github.com/XueFuzhao/t5x
15 | cd t5x
16 | python3 setup.py install
17 | 
18 | pip install flax
19 | 
20 | echo y | python3 -m pip uninstall t5[gcp]
21 | echo y | python3 -m pip uninstall t5
22 | git clone --branch=main https://github.com/XueFuzhao/text-to-text-transfer-transformer.git
23 | cd text-to-text-transfer-transformer
24 | python3 setup.py install
25 | 
26 | echo y | python3 -m pip uninstall seqio
27 | echo y | python3 -m pip uninstall seqio-nightly
28 | git clone  --branch=main https://github.com/XueFuzhao/seqio.git
29 | cd seqio
30 | python3 setup.py install
31 | cd ../..
32 | git clone  --branch=main https://github.com/XueFuzhao/flaxformer.git
33 | cd flaxformer
34 | python3 setup.py install
35 | 
36 | python3 -m pip install gast
37 | python3 -m pip install astunparse
38 | python3 -m pip install flatbuffers
39 | python3 -m pip install tensorboard
40 | python3 -m pip install keras
41 | python3 -m pip install tensorflow_estimator
42 | python3 -m pip install libcst
43 | python3 -m pip install portalocker
44 | python3 -m pip install tabulate
45 | python3 -m pip install colorama
46 | python3 -m pip install lxml
47 | python3 -m pip install joblib
48 | python3 -m pip install threadpoolctl
49 | python3 -m pip install tfds-nightly==4.6.0.dev202210040045
50 | # python3 -m pip install tensorflow-datasets==4.3.0
51 | python3 -m pip install h5py
52 | 
53 | cd ~
54 | git clone https://github.com/google/aqt.git
55 | cd aqt
56 | python3 setup.py install
57 | 
58 | 
59 | cd ~
60 | export GOOGLE_CLOUD_BUCKET_NAME=${YOUR_BUCKET_NAME}
61 | export TFDS_DATA_DIR=gs://${YOUR_BUCKET_NAME}
62 | export MODEL_DIR=gs://${YOUR_BUCKET_NAME}/openmoe_8b/training
63 | export T5X_DIR="./t5x"
64 | 
65 | python3  ${T5X_DIR}/t5x/train.py \
66 | 	--gin_file="t5x/examples/t5/t5_1_1/examples/openmoe_large.gin" \
67 |   --gin.MODEL_DIR=\"${MODEL_DIR}\" \
68 |   --tfds_data_dir=${TFDS_DATA_DIR}
69 | 


--------------------------------------------------------------------------------