├── .gitignore
├── README.md
├── configs
│   ├── __init__.py
│   ├── ia3.gin
│   ├── ia3_eval.gin
│   ├── lora.gin
│   ├── lora_eval.gin
│   ├── molora.gin
│   ├── molora_eval.gin
│   ├── mov.gin
│   ├── mov_eval.gin
│   ├── t0.gin
│   ├── t0_eval.gin
│   └── t5
│       ├── __init__.py
│       ├── architectures
│       │   ├── flash_attention.gin
│       │   ├── t5_1_1_flaxformer.gin
│       │   └── t5_flaxformer.gin
│       ├── gin_configs_test.py
│       └── models
│           ├── t5_11B.gin
│           ├── t5_1_1_base.gin
│           ├── t5_1_1_large.gin
│           ├── t5_1_1_small.gin
│           ├── t5_1_1_tiny.gin
│           ├── t5_1_1_xl.gin
│           ├── t5_1_1_xxl.gin
│           ├── t5_3B.gin
│           ├── t5_base.gin
│           ├── t5_large.gin
│           └── t5_small.gin
├── demo.png
├── scripts
│   ├── find_module.py
│   ├── ia3_eval.sh
│   ├── ia3_train.sh
│   ├── lora_eval.sh
│   ├── lora_train.sh
│   ├── molora_eval.sh
│   ├── molora_train.sh
│   ├── mov_eval.sh
│   ├── mov_train.sh
│   ├── setup.sh
│   ├── t0_eval.sh
│   └── t0_train.sh
├── src
│   ├── __init__.py
│   ├── adafactor_custom.py
│   ├── ia3.py
│   ├── lora.py
│   ├── molora.py
│   ├── mov.py
│   ├── partitioning_custom.py
│   ├── routing.py
│   └── utils.py
└── t0_data
    ├── LICENSE
    ├── __init__.py
    ├── dataset_split.pickle
    ├── datasets.csv
    ├── datasets_original.csv
    ├── tasks.py
    └── utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## MoV and MoLoRA
This repository contains the official code for the paper: "[Pushing Mixture of Experts to the Limit: Extremely Parameter Efficient MoE for Instruction Tuning](https://arxiv.org/abs/2309.05444)."

The codebase is built on [T5X](https://github.com/google-research/t5x), which
defines the model and training loop;
[Flaxformer](https://github.com/google/flaxformer), which defines the
model computation; [Flax](https://github.com/google/flax), which defines the
low-level model layers; and [Jax](https://github.com/google/jax), which
provides the underlying execution.

![MoV and MoLoRA overview](demo.png)

#### Installation

```sh
# CLONE repo
git clone https://github.com/for-ai/parameter-efficient-moe

# COPY to TPUs (substitute your TPU name and zone)
gcloud alpha compute tpus tpu-vm scp --recurse parameter-efficient-moe <tpu-name>:parameter-efficient-moe --zone=<zone> --worker=all

# RUN on TPUs
bash scripts/setup.sh
```

### Dataset
The dataset used for training and evaluation should be cached using [SeqIO](https://github.com/google/seqio). We used the [bigscience/P3](https://huggingface.co/datasets/bigscience/P3) dataset, which is already prepared. For the dataset preparation, we refer to the [bigscience/t-zero](https://github.com/bigscience-workshop/t-zero/tree/master/training) repository.
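As a quick sanity check that your cache is wired up correctly, the registered mixtures can be loaded directly through SeqIO. This is an illustrative sketch (the mixture name `t0_train` and the feature lengths come from the scripts below; the cache path is a placeholder you must replace):

```python
# Verify the cached P3 mixture is visible to SeqIO before launching training.
import seqio
import t0_data  # registers the t0_train / t0_eval_score_eval mixtures

seqio.add_global_cache_dirs(["gs://your-bucket/raw_tfrecords/your_cache_dir"])
ds = seqio.get_mixture_or_task("t0_train").get_dataset(
    sequence_length={"inputs": 1024, "targets": 256},
    split="train",
    use_cached=True,  # fails fast if the cache is missing
)
for example in ds.take(1):
    print({name: tensor.shape for name, tensor in example.items()})
```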
### Code components

Here is the code layout:

* `configs/` :: contains gin configs for the architecture of each model, including T0, IA3, LoRA, MoV, and MoLoRA.
* `scripts/` :: contains all the training and evaluation scripts for full fine-tuning, vanilla parameter-efficient fine-tuning, and their mixture-of-experts counterparts.
* `src/` :: contains the IA3, LoRA, MoV, and MoLoRA computations, including the router they use.


#### Example script

```sh
gcloud alpha compute tpus tpu-vm ssh <tpu-name> --zone=<zone> --worker=all --command="cd parameter-efficient-moe; bash scripts/mov_train.sh"
```

#### Fine-tuning:

```sh
# moe/scripts/mov_train.sh

MODEL_DIR=${1:-${MODEL_DIR}}  # Model dir to save logs, ckpts, etc., in "gs://model_dir" format.

T5X_DIR="`python3 -m scripts.find_module t5x`/.."  # Directory where the T5X repo is cloned.
FLAXFORMER_DIR="`python3 -m scripts.find_module flaxformer`/.."  # Directory where the Flaxformer repo is cloned.
echo "Searching for gin configs in:"
echo "- ${T5X_DIR}"
echo "- ${FLAXFORMER_DIR}"
echo "============================="

PRETRAINED_MODEL="gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_large/checkpoint_1100000"
CACHE_DIR="raw_tfrecords/your_cache_dir"  # Directory where the cached P3 data is stored, in "gs://..." format.

# Notes on the flags below:
# - t5_1_1_large.gin selects the 770M (T5-large) model; mov.gin layers MoV on
#   top of it as the PEFT architecture.
# - MIXTURE_OR_TASK_NAME is the training subset.
# - TRAIN_STEPS is cumulative: pre-trained steps (1.1M here) plus fine-tuning steps.
python3 -m t5x.train \
  --gin_search_paths="${T5X_DIR}" \
  --gin_file="configs/t5/models/t5_1_1_large.gin" \
  --gin_file="configs/mov.gin" \
  --gin.MODEL_DIR="'${MODEL_DIR}'" \
  --gin.LOSS_NORMALIZING_FACTOR="'AVERAGE_PER_SEQUENCE'" \
  --gin.MIXTURE_OR_TASK_NAME="'t0_train'" \
  --gin.TASK_FEATURE_LENGTHS="{'inputs': 1024, 'targets': 256}" \
  --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" \
  --gin.TRAIN_STEPS="1_600_000" \
  --gin.USE_CACHED_TASKS="True" \
  --gin.PACKING="True" \
  --seqio_additional_cache_dirs=${CACHE_DIR} \
  --gin.BATCH_SIZE="32"
```

#### Evaluation:

```sh
# moe/scripts/mov_eval.sh

CKPT_DIR=${1:-${CKPT_DIR}}  # Directory where the fine-tuned model is stored.
EVAL_DIR=${2:-${EVAL_DIR}}  # Directory to write eval output to.

T5X_DIR="`python3 -m scripts.find_module t5x`/.."  # Directory where t5x is cloned.
FLAXFORMER_DIR="`python3 -m scripts.find_module flaxformer`/.."  # Directory where flaxformer is cloned.
echo "Searching for gin configs in:"
echo "- ${T5X_DIR}"
echo "- ${FLAXFORMER_DIR}"
echo "============================="

CACHE_DIR="raw_tfrecords/your_cache_dir"  # Directory where the cached P3 data is stored, in "gs://..." format.

# mov_eval.gin selects MoV as the PEFT architecture; t0_eval_score_eval is the
# evaluation subset.
python3 -m t5x.eval \
  --gin_search_paths="${T5X_DIR}" \
  --gin_file="configs/t5/models/t5_1_1_large.gin" \
  --gin_file="configs/mov_eval.gin" \
  --gin.EVAL_OUTPUT_DIR="'${EVAL_DIR}'" \
  --gin.MIXTURE_OR_TASK_NAME="'t0_eval_score_eval'" \
  --gin.TASK_FEATURE_LENGTHS="{'inputs': 1024, 'targets': 256}" \
  --gin.CHECKPOINT_PATH="'${CKPT_DIR}'" \
  --seqio_additional_cache_dirs=${CACHE_DIR} \
  --gin.utils.DatasetConfig.use_cached="True" \
  --gin.utils.DatasetConfig.split="'validation'" \
  --gin.BATCH_SIZE="32"
```
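Both scripts locate the installed T5X and Flaxformer packages with `scripts/find_module.py` and append `/..` to get a directory from which paths like `t5x/configs/runs/...` resolve. The helper itself is not shown in this dump; the following is a minimal sketch of what such a module-locating script plausibly looks like, not the repository's actual file:

```python
# scripts/find_module.py (assumed sketch): print where a Python module is
# installed, so the shell scripts above can derive --gin_search_paths from it.
import importlib
import sys

if __name__ == "__main__":
    module = importlib.import_module(sys.argv[1])  # e.g. "t5x" or "flaxformer"
    # Packages expose __path__; plain single-file modules only have __file__.
    path = module.__path__[0] if hasattr(module, "__path__") else module.__file__
    print(path)
```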
#### References
Our IA3 module implementation is based on [prompt-tuning](https://github.com/google-research/prompt-tuning), and our dataset implementation follows [bigscience/t-zero](https://github.com/bigscience-workshop/t-zero/tree/master/training).

#### Citation
Please use the following bibtex entry to cite our work.

```
@article{zadouri2023pushing,
  title={Pushing Mixture of Experts to the Limit: Extremely Parameter Efficient MoE for Instruction Tuning},
  author={Ted Zadouri and Ahmet Üstün and Arash Ahmadian and Beyza Ermiş and Acyr Locatelli and Sara Hooker},
  year={2023},
  url={https://arxiv.org/abs/2309.05444},
}
```
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
"""Gin requires gin files to be in a Python package; this file makes it one."""
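The gin files under `configs/` are plain dependency-injection configs. A minimal illustration of how a T5X-style binary consumes them (standard `gin-config` API; the file and binding names simply mirror the scripts above, and `t5x.train` does this parsing internally via its `--gin_file` / `--gin.NAME` flags):

```python
import gin

gin.add_config_file_search_path("configs")
# Later files and bindings override earlier ones, so mov.gin's MoV modules are
# layered on top of the base t5_1_1_large.gin model definition.
gin.parse_config_files_and_bindings(
    ["configs/t5/models/t5_1_1_large.gin", "configs/mov.gin"],
    bindings=["BATCH_SIZE = 32"],
)
```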
--------------------------------------------------------------------------------
/configs/ia3.gin:
--------------------------------------------------------------------------------
# ginlint: disable=bad-import-order
from __gin__ import dynamic_registration

import seqio
from t5x import models
from t5x import utils
from t5x import adafactor
from t5x import optimizers as optim
from flax import linen
from flax import traverse_util
from flaxformer.components import dense
from flaxformer.components.attention import dense_attention

from t5x import partitioning
from src import adafactor_custom as c_optim
from src import partitioning_custom as c_partitioning

from src import utils as peft_utils
from src import ia3

include 't5x/configs/runs/finetune_no_eval.gin'

# ========== Data Mixture ==========
# SeqIO tasks for p3 (from original repo)
import t0_data

# ========== These are IA3 HPs you might want to override ==========
# If you want to change the actual optimizer itself (to optim.Adam, etc.), make
# sure to update the optimizer that is passed to the MultiOptimizer.
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0
  logical_factor_rules = @c_optim.standard_logical_factor_rules()

# ========== Partitioning ==========
partitioning.PjitPartitioner:
  logical_axis_rules = @partitioning.standard_logical_axis_rules()

partitioning.standard_logical_axis_rules:
  additional_rules = @c_partitioning.standard_logical_axis_rules()

# ========== Partial Loading ==========
# The following is the configuration that allows us to partially load a model
# (using the values in a checkpoint) without it complaining that the shapes
# don't match (because our model has extra parameters, the IA3 scaling values).
# You shouldn't need to update these unless you want to change the optimizer
# itself.
#
# Optimizer
# LR is set by `Trainer.learning_rate_fn`.
# Use our MultiOptimizer wrapper to bind to the variadic
# `*traversals_and_optimizers`
OPTIMIZER = @optim.MultiOptimizer()
optim.MultiOptimizer:
  traversals_and_optimizers = ((@traverse_util.ModelParamTraversal(),
                                @adafactor.Adafactor()),)
traverse_util.ModelParamTraversal:
  filter_fn = @peft_utils.match_any()
# Our MultiOptimizer will match any parameter with a flattened name that
# matches any of these regular expressions.
TRAINABLE_REGEX = [".*/ia3_scaling.*"]
peft_utils.match_any.regexes = %TRAINABLE_REGEX

# These settings allow us to partially reload a checkpoint, that is, we can load
# most of the model weights from the checkpoint, without it complaining that we
# don't have a weight for our IA3 scaling values in the checkpoint.
utils.RestoreCheckpointConfig:
  # Activate the codepath that allows the merging of the optimizer state as
  # specified in the config (with our new parameter) and the optimizer state as
  # defined in the checkpoint.
  fallback_to_scratch = True
  # Use the T5X assignment map to grab values from the checkpoint. Each entry in
  # the map is a regular expression that matches some flattened variable name in
  # the optimizer state as defined in the model created by the config. The
  # second value is the corresponding name in the optimizer state as defined by
  # the checkpoint. It supports interpolating capture groups from the initial
  # regex. If the second pattern is `None` we skip trying to load this variable
  # from the checkpoint.

  # Skip trying to load all keys that have the word ia3_scaling in them; these
  # will be initialized from scratch.
  assignment_map = ((r"^.*ia3_scaling.*$", None),)

utils.create_learning_rate_scheduler:
  factors = "constant"
  # Learning rate from the paper.
  base_learning_rate = 3e-4

#INITIAL_CHECKPOINT_PATH = None
#utils.CheckpointConfig:
#  restore = None

utils.SaveCheckpointConfig:
  period = 50000
  keep = 60

# ========== ARCHITECTURE ==========
# Add IA3 to all attention implementations.
dense_attention.MultiHeadDotProductAttention:
  k_conv = @ia3.IA3Attention()
  v_conv = @ia3.IA3Attention()

ia3.IA3Attention:
  dtype = %ACTIVATION_DTYPE

dense.MlpBlock:
  intermediate_conv = @ia3.IA3()

ia3.IA3:
  axis_name = ('mlp',)
  dtype = %ACTIVATION_DTYPE
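Conceptually, the `ia3.IA3` / `ia3.IA3Attention` modules bound above rescale activations elementwise with learned vectors, following the IA3 method of Liu et al. (2022). A minimal sketch of that computation (illustrative; the repo's `src/ia3.py` is a Flax module with axis metadata and dtype handling omitted here):

```python
import jax.numpy as jnp

def ia3_rescale(h, ia3_scaling):
    # h: [..., features] keys, values, or MLP intermediates.
    # ia3_scaling: [features], typically initialized to ones so that
    # fine-tuning starts exactly at the pretrained model's behavior.
    return h * ia3_scaling
```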
--------------------------------------------------------------------------------
/configs/ia3_eval.gin:
--------------------------------------------------------------------------------
# ginlint: disable=bad-import-order
from __gin__ import dynamic_registration

import seqio
from t5x import models
from t5x import utils
from t5x import adafactor
from t5x import optimizers as optim
from flax import linen
from flax import traverse_util
from flaxformer.components import dense
from flaxformer.components.attention import dense_attention

from t5x import partitioning
from src import adafactor_custom as c_optim
from src import partitioning_custom as c_partitioning

from src import utils as peft_utils
from src import ia3

include 't5x/configs/runs/eval.gin'

# ========== Data Mixture ==========
# SeqIO tasks for p3 (from original repo)
import t0_data

# ========== IA3 settings (keep in sync with ia3.gin) ==========
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0
  logical_factor_rules = @c_optim.standard_logical_factor_rules()

# ========== Partitioning ==========
partitioning.PjitPartitioner:
  logical_axis_rules = @partitioning.standard_logical_axis_rules()

partitioning.standard_logical_axis_rules:
  additional_rules = @c_partitioning.standard_logical_axis_rules()

# ========== ARCHITECTURE ==========
# Add IA3 to all attention implementations.
dense_attention.MultiHeadDotProductAttention:
  k_conv = @ia3.IA3Attention()
  v_conv = @ia3.IA3Attention()

ia3.IA3Attention:
  dtype = %ACTIVATION_DTYPE

dense.MlpBlock:
  intermediate_conv = @ia3.IA3()

ia3.IA3:
  axis_name = ('mlp',)
  dtype = %ACTIVATION_DTYPE
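All of the training configs here freeze the base model by routing the optimizer through `traverse_util.ModelParamTraversal(filter_fn=peft_utils.match_any(...))`, so only parameters whose flattened names match `TRAINABLE_REGEX` receive updates. A sketch of the assumed `match_any` behavior (the README credits prompt-tuning for this pattern; the repo's exact code in `src/utils.py` may differ):

```python
import re
from typing import Any, Sequence

def match_any(regexes: Sequence[str]):
    """Returns a filter_fn for ModelParamTraversal that matches any regex."""
    compiled = [re.compile(regex) for regex in regexes]

    def _match(path: str, _: Any) -> bool:
        # `path` is the flattened parameter name, e.g.
        # "encoder/layers_0/mlp/wi/ia3_scaling".
        return any(regex.fullmatch(path) for regex in compiled)

    return _match

# e.g. match_any([".*/ia3_scaling.*"])("decoder/layers_3/mlp/ia3_scaling", None) -> True
```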
--------------------------------------------------------------------------------
/configs/lora.gin:
--------------------------------------------------------------------------------
# ginlint: disable=bad-import-order
from __gin__ import dynamic_registration

import seqio
from t5x import models
from t5x import utils
from t5x import adafactor
from t5x import partitioning
from t5x import optimizers as optim
from flax import linen
from flax import traverse_util
from flaxformer.components import dense
from flaxformer.components.attention import dense_attention

from src import adafactor_custom as c_optim
from src import partitioning_custom as c_partitioning

from src import utils as peft_utils
from src import routing
from src import lora

include 't5x/configs/runs/finetune_no_eval.gin'

# ========== Data Mixture ==========
# SeqIO tasks for p3 (from original repo)
import t0_data

# ========== These are LoRA HPs you might want to override ==========
# If you want to change the actual optimizer itself (to optim.Adam, etc.), make
# sure to update the optimizer that is passed to the MultiOptimizer.
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0
  logical_factor_rules = @c_optim.standard_logical_factor_rules()

# ========== Partitioning ==========
partitioning.PjitPartitioner:
  logical_axis_rules = @partitioning.standard_logical_axis_rules()

partitioning.standard_logical_axis_rules:
  additional_rules = @c_partitioning.standard_logical_axis_rules()

# ========== Partial Loading ==========
# The following is the configuration that allows us to partially load a model
# (using the values in a checkpoint) without it complaining that the shapes
# don't match (because our model has extra parameters, the LoRA A and B
# matrices).
# You shouldn't need to update these unless you want to change the optimizer
# itself.
#
# Optimizer
# LR is set by `Trainer.learning_rate_fn`.
# Use our MultiOptimizer wrapper to bind to the variadic
# `*traversals_and_optimizers`
OPTIMIZER = @optim.MultiOptimizer()
optim.MultiOptimizer:
  traversals_and_optimizers = ((@traverse_util.ModelParamTraversal(),
                                @adafactor.Adafactor()),)
traverse_util.ModelParamTraversal:
  filter_fn = @peft_utils.match_any()
# Our MultiOptimizer will match any parameter with a flattened name that
# matches any of these regular expressions.
TRAINABLE_REGEX = [".*/lora_A.*", ".*/lora_B.*"]
peft_utils.match_any.regexes = %TRAINABLE_REGEX

# These settings allow us to partially reload a checkpoint, that is, we can load
# most of the model weights from the checkpoint, without it complaining that we
# don't have a weight for our LoRA matrices in the checkpoint.
utils.RestoreCheckpointConfig:
  # Activate the codepath that allows the merging of the optimizer state as
  # specified in the config (with our new parameters) and the optimizer state as
  # defined in the checkpoint.
  fallback_to_scratch = True
  # Use the T5X assignment map to grab values from the checkpoint. Each entry in
  # the map is a regular expression that matches some flattened variable name in
  # the optimizer state as defined in the model created by the config. The
  # second value is the corresponding name in the optimizer state as defined by
  # the checkpoint. It supports interpolating capture groups from the initial
  # regex. If the second pattern is `None` we skip trying to load this variable
  # from the checkpoint.

  # Skip trying to load all keys that have lora_A or lora_B in them; these
  # will be initialized from scratch.
  assignment_map = ((r"^.*lora_A.*$", None),
                    (r"^.*lora_B.*$", None),)

utils.create_learning_rate_scheduler:
  factors = "constant"
  # Learning rate from the paper.
  base_learning_rate = 3e-4

utils.SaveCheckpointConfig:
  period = 50000
  keep = 60

# ========== ARCHITECTURE ==========
dense_attention.MultiHeadDotProductAttention:
  lora_output_conv = @attn/lora.LoRA()
  lora_q_conv = @attn/lora.LoRAAttention()
  lora_k_conv = @attn/lora.LoRAAttention()
  lora_v_conv = @attn/lora.LoRAAttention()

attn/lora.LoRAAttention:
  rank = 16
  num_heads = 16
  dtype = 'bfloat16'
  lora_axis_names_A = ('embed', 'rank')
  lora_axis_names_B = ('rank', 'embed')

attn/lora.LoRA:
  rank = 16
  dtype = 'bfloat16'
  lora_axis_names_A = ('embed', 'rank')
  lora_axis_names_B = ('rank', 'embed')

dense.MlpBlock:
  lora_intermediate_conv = @mlp1/lora.LoRA()
  lora_output_conv = @mlp2/lora.LoRA()

mlp1/lora.LoRA:
  rank = 16
  output_dim = 2816
  dtype = 'bfloat16'
  lora_axis_names_A = ('embed', 'rank')
  lora_axis_names_B = ('rank', 'mlp')

mlp2/lora.LoRA:
  rank = 16
  output_dim = 1024
  dtype = 'bfloat16'
  lora_axis_names_A = ('mlp', 'rank')
  lora_axis_names_B = ('rank', 'embed')
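The `lora.LoRA` modules configured above add a rank-16 low-rank update alongside each frozen dense kernel. A minimal JAX sketch of that delta (standard LoRA; the repo's `src/lora.py` Flax module additionally handles attention-head reshaping, axis annotations, and dtypes):

```python
import jax.numpy as jnp

def lora_delta(x, lora_A, lora_B):
    # x: [batch, length, embed]
    # lora_A: [embed, rank], lora_B: [rank, features]
    # The frozen layer computes x @ W0; LoRA adds x @ A @ B, costing only
    # rank * (embed + features) trainable parameters per adapted kernel.
    return jnp.einsum("...d,dr,rf->...f", x, lora_A, lora_B)
```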
--------------------------------------------------------------------------------
/configs/lora_eval.gin:
--------------------------------------------------------------------------------
# ginlint: disable=bad-import-order
from __gin__ import dynamic_registration

import seqio
from t5x import models
from t5x import utils
from t5x import adafactor
from t5x import partitioning
from t5x import optimizers as optim
from flax import linen
from flax import traverse_util
from flaxformer.components import dense
from flaxformer.components.attention import dense_attention

from src import adafactor_custom as c_optim
from src import partitioning_custom as c_partitioning

from src import utils as peft_utils
from src import routing
from src import lora

include 't5x/configs/runs/eval.gin'

# ========== Data Mixture ==========
# SeqIO tasks for p3 (from original repo)
import t0_data

# ========== LoRA settings (keep in sync with lora.gin) ==========
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0
  logical_factor_rules = @c_optim.standard_logical_factor_rules()

# ========== Partitioning ==========
partitioning.PjitPartitioner:
  logical_axis_rules = @partitioning.standard_logical_axis_rules()

partitioning.standard_logical_axis_rules:
  additional_rules = @c_partitioning.standard_logical_axis_rules()

# ========== ARCHITECTURE ==========
dense_attention.MultiHeadDotProductAttention:
  lora_output_conv = @attn/lora.LoRA()
  lora_q_conv = @attn/lora.LoRAAttention()
  lora_k_conv = @attn/lora.LoRAAttention()
  lora_v_conv = @attn/lora.LoRAAttention()

attn/lora.LoRAAttention:
  rank = 16
  num_heads = 16
  dtype = 'bfloat16'
  lora_axis_names_A = ('embed', 'rank')
  lora_axis_names_B = ('rank', 'embed')

attn/lora.LoRA:
  rank = 16
  dtype = 'bfloat16'
  lora_axis_names_A = ('embed', 'rank')
  lora_axis_names_B = ('rank', 'embed')

dense.MlpBlock:
  lora_intermediate_conv = @mlp1/lora.LoRA()
  lora_output_conv = @mlp2/lora.LoRA()

mlp1/lora.LoRA:
  rank = 16
  output_dim = 2816
  dtype = 'bfloat16'
  lora_axis_names_A = ('embed', 'rank')
  lora_axis_names_B = ('rank', 'mlp')

mlp2/lora.LoRA:
  rank = 16
  output_dim = 1024
  dtype = 'bfloat16'
  lora_axis_names_A = ('mlp', 'rank')
  lora_axis_names_B = ('rank', 'embed')
--------------------------------------------------------------------------------
/configs/molora.gin:
--------------------------------------------------------------------------------
# ginlint: disable=bad-import-order
from __gin__ import dynamic_registration

import seqio
from t5x import models
from t5x import utils
from t5x import adafactor
from t5x import partitioning
from t5x import optimizers as optim
from flax import linen
from flax import traverse_util
from flaxformer.components import dense
from flaxformer.components.attention import dense_attention

from src import adafactor_custom as c_optim
from src import partitioning_custom as c_partitioning

from src import utils as peft_utils
from src import routing
from src import molora

include 't5x/configs/runs/finetune_no_eval.gin'

# ========== Data Mixture ==========
# SeqIO tasks for p3 (from original repo)
import t0_data

# ========== These are MoLoRA HPs you might want to override ==========
# If you want to change the actual optimizer itself (to optim.Adam, etc.), make
# sure to update the optimizer that is passed to the MultiOptimizer.
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0
  logical_factor_rules = @c_optim.standard_logical_factor_rules()

# ========== Partitioning ==========
partitioning.PjitPartitioner:
  logical_axis_rules = @partitioning.standard_logical_axis_rules()

partitioning.standard_logical_axis_rules:
  additional_rules = @c_partitioning.standard_logical_axis_rules()

# ========== Partial Loading ==========
# The following is the configuration that allows us to partially load a model
# (using the values in a checkpoint) without it complaining that the shapes
# don't match (because our model has extra parameters, the MoLoRA expert
# matrices and router weights).
# You shouldn't need to update these unless you want to change the optimizer
# itself.
#
# Optimizer
# LR is set by `Trainer.learning_rate_fn`.
# Use our MultiOptimizer wrapper to bind to the variadic
# `*traversals_and_optimizers`
OPTIMIZER = @optim.MultiOptimizer()
optim.MultiOptimizer:
  traversals_and_optimizers = ((@traverse_util.ModelParamTraversal(),
                                @adafactor.Adafactor()),)
traverse_util.ModelParamTraversal:
  filter_fn = @peft_utils.match_any()
# Our MultiOptimizer will match any parameter with a flattened name that
# matches any of these regular expressions.
TRAINABLE_REGEX = [".*/lora_A.*", ".*/lora_B.*", ".*/router.*"]
peft_utils.match_any.regexes = %TRAINABLE_REGEX

# These settings allow us to partially reload a checkpoint, that is, we can load
# most of the model weights from the checkpoint, without it complaining that we
# don't have a weight for our MoLoRA parameters in the checkpoint.
utils.RestoreCheckpointConfig:
  # Activate the codepath that allows the merging of the optimizer state as
  # specified in the config (with our new parameters) and the optimizer state as
  # defined in the checkpoint.
  fallback_to_scratch = True
  # Use the T5X assignment map to grab values from the checkpoint. Each entry in
  # the map is a regular expression that matches some flattened variable name in
  # the optimizer state as defined in the model created by the config. The
  # second value is the corresponding name in the optimizer state as defined by
  # the checkpoint. It supports interpolating capture groups from the initial
  # regex. If the second pattern is `None` we skip trying to load this variable
  # from the checkpoint.

  # Skip trying to load all keys that have lora_A, lora_B, or router in them;
  # these will be initialized from scratch.
  assignment_map = ((r"^.*lora_A.*$", None),
                    (r"^.*lora_B.*$", None),
                    (r"^.*router.*$", None),)

utils.create_learning_rate_scheduler:
  factors = "constant"
  # Learning rate from the paper.
90 | base_learning_rate = 3e-4 91 | 92 | utils.SaveCheckpointConfig: 93 | period = 10000 94 | keep = 60 95 | 96 | # ========== ARCHITECTURE ========== 97 | dense_attention.MultiHeadDotProductAttention: 98 | lora_output_conv = @attn/molora.MoLoRa() 99 | lora_q_conv = @attn/molora.MoLoRaAttention() 100 | lora_k_conv = @attn/molora.MoLoRaAttention() 101 | lora_v_conv = @attn/molora.MoLoRaAttention() 102 | 103 | attn/molora.MoLoRaAttention: 104 | rank = 4 105 | num_experts = 10 106 | num_heads = 16 107 | router = @attn/routing.Router() 108 | dtype = 'bfloat16' 109 | lora_axis_names_A = ('expert', 'embed', 'rank') 110 | lora_axis_names_B = ('expert', 'rank', 'embed') 111 | 112 | attn/molora.MoLoRa: 113 | rank = 4 114 | num_experts = 10 115 | router = @attn_out/routing.Router() 116 | dtype = 'bfloat16' 117 | lora_axis_names_A = ('expert', 'embed', 'rank') 118 | lora_axis_names_B = ('expert', 'rank', 'embed') 119 | 120 | dense.MlpBlock: 121 | lora_intermediate_conv = @mlp1/molora.MoLoRa() 122 | lora_output_conv = @mlp2/molora.MoLoRa() 123 | 124 | mlp1/molora.MoLoRa: 125 | rank = 4 126 | num_experts = 10 127 | output_dim = 2816 128 | router = @mlp1/routing.Router() 129 | dtype = 'bfloat16' 130 | lora_axis_names_A = ('expert', 'embed', 'rank') 131 | lora_axis_names_B = ('expert', 'rank', 'mlp') 132 | 133 | mlp2/molora.MoLoRa: 134 | rank = 4 135 | num_experts = 10 136 | output_dim = 1024 137 | router = @mlp2/routing.Router() 138 | dtype = 'bfloat16' 139 | lora_axis_names_A = ('expert', 'mlp', 'rank') 140 | lora_axis_names_B = ('expert', 'rank', 'embed') 141 | 142 | attn/routing.Router: 143 | router_weights = @attn/routing.RouterWeights() 144 | input_axis_names = ('batch', 'length', 'embed') 145 | jitter_noise = 0.0 146 | dtype = 'float32' 147 | ignore_padding_tokens = False 148 | 149 | attn/routing.RouterWeights: 150 | use_bias = False 151 | dtype = 'float32' 152 | kernel_axis_names = ('embed', 'expert') 153 | kernel_init = @router_init/linen.initializers.normal() 154 | bias_init = %BIAS_INIT 155 | # We obtain slightly better results adopting typical normally-distributed 156 | # scaling for the router, rather than the 0.1-scaled variance_scaling. May be 157 | # worth revisiting if stability becomes an issue during training. 158 | router_init/linen.initializers.normal: 159 | stddev = 2e-2 160 | 161 | attn_out/routing.Router: 162 | router_weights = @attn_out/routing.RouterWeights() 163 | input_axis_names = ('batch', 'length', 'embed') 164 | jitter_noise = 0.0 165 | dtype = 'float32' 166 | ignore_padding_tokens = False 167 | 168 | attn_out/routing.RouterWeights: 169 | use_bias = False 170 | dtype = 'float32' 171 | kernel_axis_names = ('embed', 'expert') 172 | kernel_init = @router_init/linen.initializers.normal() 173 | bias_init = %BIAS_INIT 174 | # We obtain slightly better results adopting typical normally-distributed 175 | # scaling for the router, rather than the 0.1-scaled variance_scaling. May be 176 | # worth revisiting if stability becomes an issue during training. 
router_init/linen.initializers.normal:
  stddev = 2e-2

mlp1/routing.Router:
  router_weights = @mlp1/routing.RouterWeights()
  input_axis_names = ('batch', 'length', 'embed')
  jitter_noise = 0.0
  dtype = 'float32'
  ignore_padding_tokens = False

mlp1/routing.RouterWeights:
  use_bias = False
  dtype = 'float32'
  kernel_axis_names = ('embed', 'expert')
  kernel_init = @router_init/linen.initializers.normal()
  bias_init = %BIAS_INIT
# We obtain slightly better results adopting typical normally-distributed
# scaling for the router, rather than the 0.1-scaled variance_scaling. May be
# worth revisiting if stability becomes an issue during training.
router_init/linen.initializers.normal:
  stddev = 2e-2

mlp2/routing.Router:
  router_weights = @mlp2/routing.RouterWeights()
  input_axis_names = ('batch', 'length', 'mlp')
  jitter_noise = 0.0
  dtype = 'float32'
  ignore_padding_tokens = False

mlp2/routing.RouterWeights:
  use_bias = False
  dtype = 'float32'
  kernel_axis_names = ('mlp', 'expert')
  kernel_init = @router_init/linen.initializers.normal()
  bias_init = %BIAS_INIT
# We obtain slightly better results adopting typical normally-distributed
# scaling for the router, rather than the 0.1-scaled variance_scaling. May be
# worth revisiting if stability becomes an issue during training.
router_init/linen.initializers.normal:
  stddev = 2e-2
--------------------------------------------------------------------------------
/configs/molora_eval.gin:
--------------------------------------------------------------------------------
# ginlint: disable=bad-import-order
from __gin__ import dynamic_registration

import seqio
from t5x import models
from t5x import utils
from t5x import adafactor
from t5x import partitioning
from t5x import optimizers as optim
from flax import linen
from flax import traverse_util
from flaxformer.components import dense
from flaxformer.components.attention import dense_attention

from src import adafactor_custom as c_optim
from src import partitioning_custom as c_partitioning

from src import utils as peft_utils
from src import routing
from src import molora

include 't5x/configs/runs/eval.gin'

# ========== Data Mixture ==========
# SeqIO tasks for p3 (from original repo)
import t0_data

# ========== MoLoRA settings (keep in sync with molora.gin) ==========
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0
  logical_factor_rules = @c_optim.standard_logical_factor_rules()

# ========== Partitioning ==========
partitioning.PjitPartitioner:
  logical_axis_rules = @partitioning.standard_logical_axis_rules()

partitioning.standard_logical_axis_rules:
  additional_rules = @c_partitioning.standard_logical_axis_rules()

# ========== ARCHITECTURE ==========
dense_attention.MultiHeadDotProductAttention:
  lora_output_conv = @attn/molora.MoLoRa()
  lora_q_conv = @attn/molora.MoLoRaAttention()
  lora_k_conv = @attn/molora.MoLoRaAttention()
  lora_v_conv = @attn/molora.MoLoRaAttention()

attn/molora.MoLoRaAttention:
  rank = 4
  num_experts = 10
  num_heads = 16
  router = @attn/routing.Router()
  dtype = 'bfloat16'
  lora_axis_names_A = ('expert', 'embed', 'rank')
  lora_axis_names_B = ('expert', 'rank', 'embed')

attn/molora.MoLoRa:
  rank = 4
  num_experts = 10
  router = @attn_out/routing.Router()
  dtype = 'bfloat16'
  lora_axis_names_A = ('expert', 'embed', 'rank')
  lora_axis_names_B = ('expert', 'rank', 'embed')

dense.MlpBlock:
  lora_intermediate_conv = @mlp1/molora.MoLoRa()
  lora_output_conv = @mlp2/molora.MoLoRa()

mlp1/molora.MoLoRa:
  rank = 4
  num_experts = 10
  output_dim = 2816
  router = @mlp1/routing.Router()
  dtype = 'bfloat16'
  lora_axis_names_A = ('expert', 'embed', 'rank')
  lora_axis_names_B = ('expert', 'rank', 'mlp')

mlp2/molora.MoLoRa:
  rank = 4
  num_experts = 10
  output_dim = 1024
  router = @mlp2/routing.Router()
  dtype = 'bfloat16'
  lora_axis_names_A = ('expert', 'mlp', 'rank')
  lora_axis_names_B = ('expert', 'rank', 'embed')

attn/routing.Router:
  router_weights = @attn/routing.RouterWeights()
  input_axis_names = ('batch', 'length', 'embed')
  jitter_noise = 0.0
  dtype = 'float32'
  ignore_padding_tokens = False

attn/routing.RouterWeights:
  use_bias = False
  dtype = 'float32'
  kernel_axis_names = ('embed', 'expert')
  kernel_init = @router_init/linen.initializers.normal()
  bias_init = %BIAS_INIT
# We obtain slightly better results adopting typical normally-distributed
# scaling for the router, rather than the 0.1-scaled variance_scaling. May be
# worth revisiting if stability becomes an issue during training.
router_init/linen.initializers.normal:
  stddev = 2e-2

attn_out/routing.Router:
  router_weights = @attn_out/routing.RouterWeights()
  input_axis_names = ('batch', 'length', 'embed')
  jitter_noise = 0.0
  dtype = 'float32'
  ignore_padding_tokens = False

attn_out/routing.RouterWeights:
  use_bias = False
  dtype = 'float32'
  kernel_axis_names = ('embed', 'expert')
  kernel_init = @router_init/linen.initializers.normal()
  bias_init = %BIAS_INIT
# We obtain slightly better results adopting typical normally-distributed
# scaling for the router, rather than the 0.1-scaled variance_scaling. May be
# worth revisiting if stability becomes an issue during training.
router_init/linen.initializers.normal:
  stddev = 2e-2

mlp1/routing.Router:
  router_weights = @mlp1/routing.RouterWeights()
  input_axis_names = ('batch', 'length', 'embed')
  jitter_noise = 0.0
  dtype = 'float32'
  ignore_padding_tokens = False

mlp1/routing.RouterWeights:
  use_bias = False
  dtype = 'float32'
  kernel_axis_names = ('embed', 'expert')
  kernel_init = @router_init/linen.initializers.normal()
  bias_init = %BIAS_INIT
# We obtain slightly better results adopting typical normally-distributed
# scaling for the router, rather than the 0.1-scaled variance_scaling. May be
# worth revisiting if stability becomes an issue during training.
router_init/linen.initializers.normal:
  stddev = 2e-2

mlp2/routing.Router:
  router_weights = @mlp2/routing.RouterWeights()
  input_axis_names = ('batch', 'length', 'mlp')
  jitter_noise = 0.0
  dtype = 'float32'
  ignore_padding_tokens = False

mlp2/routing.RouterWeights:
  use_bias = False
  dtype = 'float32'
  kernel_axis_names = ('mlp', 'expert')
  kernel_init = @router_init/linen.initializers.normal()
  bias_init = %BIAS_INIT
# We obtain slightly better results adopting typical normally-distributed
# scaling for the router, rather than the 0.1-scaled variance_scaling. May be
# worth revisiting if stability becomes an issue during training.
router_init/linen.initializers.normal:
  stddev = 2e-2
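Putting the MoLoRA pieces together: each `molora.MoLoRa` layer keeps `num_experts = 10` rank-4 LoRA pairs and soft-merges their outputs with the router probabilities, so every expert stays dense and differentiable. An illustrative JAX sketch under those assumptions (the actual `src/molora.py` module adds jitter noise, axis metadata, and dtype casts):

```python
import jax
import jax.numpy as jnp

def molora_delta(x, lora_A, lora_B, router_kernel):
    # x: [batch, length, embed]
    # lora_A: [experts, embed, rank], lora_B: [experts, rank, features]
    # router_kernel: [embed, experts]
    logits = jnp.einsum("...d,de->...e", x.astype(jnp.float32), router_kernel)
    probs = jax.nn.softmax(logits, axis=-1)  # soft routing: all experts active
    deltas = jnp.einsum("...d,edr,erf->...ef", x, lora_A, lora_B)
    return jnp.einsum("...ef,...e->...f", deltas, probs.astype(x.dtype))
```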
--------------------------------------------------------------------------------
/configs/mov.gin:
--------------------------------------------------------------------------------
# ginlint: disable=bad-import-order
from __gin__ import dynamic_registration

import seqio
from t5x import models
from t5x import utils
from t5x import adafactor
from t5x import partitioning
from t5x import optimizers as optim
from flax import linen
from flax import traverse_util
from flaxformer.components import dense
from flaxformer.components.attention import dense_attention

from src import adafactor_custom as c_optim
from src import partitioning_custom as c_partitioning

from src import utils as peft_utils
from src import routing
from src import mov

include 't5x/configs/runs/finetune_no_eval.gin'

# ========== Data Mixture ==========
# SeqIO tasks for p3 (from original repo)
import t0_data

# ========== These are MoV HPs you might want to override ==========
# If you want to change the actual optimizer itself (to optim.Adam, etc.), make
# sure to update the optimizer that is passed to the MultiOptimizer.
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0
  logical_factor_rules = @c_optim.standard_logical_factor_rules()

# ========== Partitioning ==========
partitioning.PjitPartitioner:
  logical_axis_rules = @partitioning.standard_logical_axis_rules()

partitioning.standard_logical_axis_rules:
  additional_rules = @c_partitioning.standard_logical_axis_rules()

# ========== Partial Loading ==========
# The following is the configuration that allows us to partially load a model
# (using the values in a checkpoint) without it complaining that the shapes
# don't match (because our model has extra parameters, the MoV scaling vectors
# and routers).
# You shouldn't need to update these unless you want to change the optimizer
# itself.
#
# Optimizer
# LR is set by `Trainer.learning_rate_fn`.
# Use our MultiOptimizer wrapper to bind to the variadic
# `*traversals_and_optimizers`
OPTIMIZER = @optim.MultiOptimizer()
optim.MultiOptimizer:
  traversals_and_optimizers = ((@traverse_util.ModelParamTraversal(),
                                @adafactor.Adafactor()),)
traverse_util.ModelParamTraversal:
  filter_fn = @peft_utils.match_any()
# Our MultiOptimizer will match any parameter with a flattened name that
# matches any of these regular expressions.
TRAINABLE_REGEX = [".*/mov_scaling.*", ".*/router.*"]
peft_utils.match_any.regexes = %TRAINABLE_REGEX

# These settings allow us to partially reload a checkpoint, that is, we can load
# most of the model weights from the checkpoint, without it complaining that we
# don't have a weight for our MoV scaling values in the checkpoint.
utils.RestoreCheckpointConfig:
  # Activate the codepath that allows the merging of the optimizer state as
  # specified in the config (with our new parameters) and the optimizer state as
  # defined in the checkpoint.
  fallback_to_scratch = True
  # Use the T5X assignment map to grab values from the checkpoint. Each entry in
  # the map is a regular expression that matches some flattened variable name in
  # the optimizer state as defined in the model created by the config. The
  # second value is the corresponding name in the optimizer state as defined by
  # the checkpoint. It supports interpolating capture groups from the initial
  # regex. If the second pattern is `None` we skip trying to load this variable
  # from the checkpoint.

  # Skip trying to load all keys that have mov_scaling or router in them; these
  # will be initialized from scratch.
  assignment_map = ((r"^.*mov_scaling.*$", None),
                    (r"^.*router.*$", None),)

utils.create_learning_rate_scheduler:
  factors = "constant"
  # Learning rate from the paper.
  base_learning_rate = 3e-4

#INITIAL_CHECKPOINT_PATH = None
#utils.CheckpointConfig:
#  restore = None

utils.SaveCheckpointConfig:
  period = 10000
  keep = 60

# ========== ARCHITECTURE ==========
# Add MoV to all attention implementations.
dense_attention.MultiHeadDotProductAttention:
  k_conv = @mov.MoVAttention()
  v_conv = @mov.MoVAttention()

mov.MoVAttention:
  num_experts = 30
  router = @attention/routing.Router()
  dtype = 'float32'

dense.MlpBlock:
  intermediate_conv = @mov.MoV()

mov.MoV:
  axis_name = ('unmodeled', 'mlp',)
  dtype = 'float32'
  num_experts = 30
  router = @mlp/routing.Router()

mlp/routing.Router:
  router_weights = @mlp/routing.RouterWeights()
  jitter_noise = 0.0
  dtype = 'float32'
  ignore_padding_tokens = False

mlp/routing.RouterWeights:
  use_bias = False
  dtype = 'float32'
  kernel_axis_names = ('mlp', 'unmodeled',)
  kernel_init = @router_init/linen.initializers.normal()
  bias_init = %BIAS_INIT
# We obtain slightly better results adopting typical normally-distributed
# scaling for the router, rather than the 0.1-scaled variance_scaling. May be
# worth revisiting if stability becomes an issue during training.
router_init/linen.initializers.normal:
  stddev = 2e-2

attention/routing.Router:
  router_weights = @attention/routing.RouterWeights()
  input_axis_names = ('batch', 'length', 'heads', 'kv')
  jitter_noise = 0.0
  dtype = 'float32'
  ignore_padding_tokens = False

attention/routing.RouterWeights:
  use_bias = False
  dtype = 'float32'
  kernel_axis_names = ('kv', 'unmodeled')
  kernel_init = @router_init/linen.initializers.normal()
  bias_init = %BIAS_INIT
# We obtain slightly better results adopting typical normally-distributed
# scaling for the router, rather than the 0.1-scaled variance_scaling. May be
# worth revisiting if stability becomes an issue during training.
router_init/linen.initializers.normal:
  stddev = 2e-2
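Analogously to MoLoRA, MoV soft-merges `num_experts = 30` IA3-style scaling vectors per layer using the router probabilities, which keeps the added parameter count tiny. An illustrative sketch under those assumptions (the real `src/mov.py` module also handles the attention `(heads, kv)` layout and axis names):

```python
import jax
import jax.numpy as jnp

def mov_rescale(h, mov_scaling, router_kernel):
    # h: [batch, length, mlp] intermediate activations.
    # mov_scaling: [experts, mlp], one IA3-style scaling vector per expert.
    # router_kernel: [mlp, experts]
    probs = jax.nn.softmax(
        jnp.einsum("...m,me->...e", h, router_kernel), axis=-1)
    merged = jnp.einsum("...e,em->...m", probs, mov_scaling)  # merged vector
    return h * merged
```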
--------------------------------------------------------------------------------
/configs/mov_eval.gin:
--------------------------------------------------------------------------------
# ginlint: disable=bad-import-order
from __gin__ import dynamic_registration

import seqio
from t5x import models
from t5x import utils
from t5x import adafactor
from t5x import partitioning
from t5x import optimizers as optim
from flax import linen
from flax import traverse_util
from flaxformer.components import dense
from flaxformer.components.attention import dense_attention

from src import adafactor_custom as c_optim
from src import partitioning_custom as c_partitioning

from src import utils as peft_utils
from src import routing
from src import mov

include 't5x/configs/runs/eval.gin'

# ========== Data Mixture ==========
# SeqIO tasks for p3 (from original repo)
import t0_data

# ========== MoV settings (keep in sync with mov.gin) ==========
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0
  logical_factor_rules = @c_optim.standard_logical_factor_rules()

# ========== Partitioning ==========
partitioning.PjitPartitioner:
  logical_axis_rules = @partitioning.standard_logical_axis_rules()

partitioning.standard_logical_axis_rules:
  additional_rules = @c_partitioning.standard_logical_axis_rules()


# ========== ARCHITECTURE ==========
# Add MoV to all attention implementations.
dense_attention.MultiHeadDotProductAttention:
  k_conv = @mov.MoVAttention()
  v_conv = @mov.MoVAttention()

mov.MoVAttention:
  num_experts = 30
  router = @attention/routing.Router()
  dtype = 'float32'

dense.MlpBlock:
  intermediate_conv = @mov.MoV()

mov.MoV:
  axis_name = ('unmodeled', 'mlp',)
  dtype = 'float32'
  num_experts = 30
  router = @mlp/routing.Router()

mlp/routing.Router:
  router_weights = @mlp/routing.RouterWeights()
  jitter_noise = 0.0
  dtype = 'float32'
  ignore_padding_tokens = False

mlp/routing.RouterWeights:
  use_bias = False
  dtype = 'float32'
  kernel_axis_names = ('mlp', 'unmodeled',)
  kernel_init = @router_init/linen.initializers.normal()
  bias_init = %BIAS_INIT
# We obtain slightly better results adopting typical normally-distributed
# scaling for the router, rather than the 0.1-scaled variance_scaling. May be
# worth revisiting if stability becomes an issue during training.
router_init/linen.initializers.normal:
  stddev = 2e-2

attention/routing.Router:
  router_weights = @attention/routing.RouterWeights()
  input_axis_names = ('batch', 'length', 'heads', 'kv')
  jitter_noise = 0.0
  dtype = 'float32'
  ignore_padding_tokens = False

attention/routing.RouterWeights:
  use_bias = False
  dtype = 'float32'
  kernel_axis_names = ('kv', 'unmodeled')
  kernel_init = @router_init/linen.initializers.normal()
  bias_init = %BIAS_INIT
# We obtain slightly better results adopting typical normally-distributed
# scaling for the router, rather than the 0.1-scaled variance_scaling. May be
# worth revisiting if stability becomes an issue during training.
router_init/linen.initializers.normal:
  stddev = 2e-2
--------------------------------------------------------------------------------
/configs/t0.gin:
--------------------------------------------------------------------------------
# ginlint: disable=bad-import-order
from __gin__ import dynamic_registration

import seqio
from t5x import models
from t5x import utils
from t5x import adafactor
from t5x import partitioning
from t5x import optimizers as optim
from flax import linen
from flax import traverse_util

include 't5x/configs/runs/finetune_no_eval.gin'

# ========== Data Mixture ==========
# SeqIO tasks for p3 (from original repo)
import t0_data

OPTIMIZER = @adafactor.Adafactor()
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0
  logical_factor_rules = @adafactor.standard_logical_factor_rules()

utils.create_learning_rate_scheduler:
  factors = "constant"
  # Learning rate from the paper.
28 | base_learning_rate = 1e-3 29 | 30 | utils.SaveCheckpointConfig: 31 | period = 1000 32 | keep = 60 -------------------------------------------------------------------------------- /configs/t0_eval.gin: -------------------------------------------------------------------------------- 1 | # ginlint: disable=bad-import-order 2 | from __gin__ import dynamic_registration 3 | 4 | include 't5x/configs/runs/eval.gin' 5 | 6 | # ========== Data Mixture ========== 7 | # SeqIO tasks for p3 (from original repo) 8 | import t0_data -------------------------------------------------------------------------------- /configs/t5/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cohere-Labs-Community/parameter-efficient-moe/e2683fead45eb480bcbf9c2e646fbbc2f7fc6efb/configs/t5/__init__.py -------------------------------------------------------------------------------- /configs/t5/architectures/flash_attention.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | from flaxformer.components.attention import dense_attention 4 | from flaxformer.components.attention import memory_efficient_attention 5 | 6 | 7 | dense_attention.MultiQueryDotProductAttention: 8 | attention_fn = @memory_efficient_attention.dot_product_attention_multiquery 9 | 10 | # Note that this attention function only works when the sequence length is 11 | # less than, or a multiple of dot_product_attention_multiquery.key_chunk_size. 12 | -------------------------------------------------------------------------------- /configs/t5/architectures/t5_1_1_flaxformer.gin: -------------------------------------------------------------------------------- 1 | # Flaxformer implementation of T5.1.1 architecture. 2 | # 3 | # Required to be overridden: 4 | # 5 | # - NUM_ENCODER_LAYERS 6 | # - NUM_DECODER_LAYERS 7 | # - NUM_HEADS 8 | # - HEAD_DIM 9 | # - EMBED_DIM 10 | # - MLP_DIM 11 | from __gin__ import dynamic_registration 12 | 13 | from flax import linen 14 | 15 | from flaxformer.architectures.t5 import t5_architecture 16 | from flaxformer.components.attention import dense_attention 17 | from flaxformer.components import dense 18 | from flaxformer.components import embedding 19 | from flaxformer.components import layer_norm 20 | from flaxformer.components import relative_position_biases 21 | 22 | # Must be overridden. 
23 | NUM_ENCODER_LAYERS = %gin.REQUIRED 24 | NUM_DECODER_LAYERS = %gin.REQUIRED 25 | NUM_HEADS = %gin.REQUIRED 26 | HEAD_DIM = %gin.REQUIRED 27 | EMBED_DIM = %gin.REQUIRED 28 | MLP_DIM = %gin.REQUIRED 29 | NUM_EMBEDDINGS = %gin.REQUIRED 30 | 31 | # Constants (may be overridden) 32 | ACTIVATION_DTYPE = 'bfloat16' 33 | ACTIVATION_PARTITIONING_DIMS = 1 34 | SCALE = 1.0 35 | DROPOUT_RATE = 0.0 36 | 37 | # Macros 38 | BIAS_INIT = @bias_init/linen.initializers.normal() 39 | bias_init/linen.initializers.normal.stddev = 1e-6 40 | DROPOUT_FACTORY = @dropout_factory/linen.Dropout 41 | dropout_factory/linen.Dropout: 42 | rate = %DROPOUT_RATE 43 | broadcast_dims = (-2,) 44 | 45 | # Architecture (Flax Module) 46 | ARCHITECTURE = @t5_architecture.EncoderDecoder() 47 | t5_architecture.EncoderDecoder: 48 | encoder_factory = @t5_architecture.Encoder 49 | decoder_factory = @t5_architecture.Decoder 50 | shared_token_embedder_factory = @embedding.Embed 51 | dtype = %ACTIVATION_DTYPE 52 | 53 | # Encoder 54 | t5_architecture.Encoder: 55 | num_layers = %NUM_ENCODER_LAYERS 56 | layer_factory = @t5_architecture.EncoderLayer 57 | input_dropout_factory = %DROPOUT_FACTORY 58 | output_dropout_factory = %DROPOUT_FACTORY 59 | layer_norm_factory = @layer_norm.T5LayerNorm 60 | position_embedder_factory = None 61 | shared_relative_position_bias_factory = @relative_position_biases.RelativePositionBiases 62 | dtype = %ACTIVATION_DTYPE 63 | 64 | # Encoder Layer 65 | t5_architecture.EncoderLayer: 66 | attention = @dense_attention.MultiHeadDotProductAttention() 67 | mlp = @dense.MlpBlock() 68 | dropout_factory = %DROPOUT_FACTORY 69 | layer_norm_factory = @layer_norm.T5LayerNorm 70 | activation_partitioning_dims = %ACTIVATION_PARTITIONING_DIMS 71 | 72 | # Decoder 73 | t5_architecture.Decoder: 74 | num_layers = %NUM_DECODER_LAYERS 75 | layer_factory = @t5_architecture.DecoderLayer 76 | dropout_factory = %DROPOUT_FACTORY 77 | layer_norm_factory = @layer_norm.T5LayerNorm 78 | position_embedder_factory = None 79 | shared_relative_position_bias_factory = @relative_position_biases.RelativePositionBiases 80 | output_logits_factory = @output_logits/dense.DenseGeneral 81 | dtype = %ACTIVATION_DTYPE 82 | 83 | # Decoupled embedding 84 | output_logits/dense.DenseGeneral: 85 | features = %NUM_EMBEDDINGS 86 | use_bias = False 87 | dtype = 'float32' 88 | kernel_init = @output_logits_kernel_init/linen.initializers.variance_scaling() 89 | bias_init = %BIAS_INIT 90 | kernel_axis_names = ["embed", "vocab"] 91 | output_logits_kernel_init/linen.initializers.variance_scaling: 92 | scale = %SCALE 93 | mode = 'fan_in' 94 | distribution = 'truncated_normal' 95 | 96 | # Decoder Layer 97 | t5_architecture.DecoderLayer: 98 | self_attention = @dense_attention.MultiHeadDotProductAttention() 99 | encoder_decoder_attention = @dense_attention.MultiHeadDotProductAttention() 100 | mlp = @dense.MlpBlock() 101 | dropout_factory = %DROPOUT_FACTORY 102 | layer_norm_factory = @layer_norm.T5LayerNorm 103 | activation_partitioning_dims = %ACTIVATION_PARTITIONING_DIMS 104 | 105 | # Token Embedder (shared) 106 | embedding.Embed: 107 | num_embeddings= %NUM_EMBEDDINGS 108 | features = %EMBED_DIM 109 | cast_input_dtype = 'int32' 110 | dtype = %ACTIVATION_DTYPE 111 | attend_dtype = 'float32' # for logit training stability 112 | embedding_init = @token_embedder_init/linen.initializers.normal() 113 | one_hot = True 114 | name = 'token_embedder' 115 | token_embedder_init/linen.initializers.normal.stddev = 1.0 116 | 117 | # Attention (encoder, decoder, self-attention) 118 | 
dense_attention.MultiHeadDotProductAttention: 119 | num_heads = %NUM_HEADS 120 | dtype = %ACTIVATION_DTYPE 121 | head_dim = %HEAD_DIM 122 | kernel_init = @attention_kernel_init/linen.initializers.variance_scaling() 123 | bias_init = %BIAS_INIT 124 | use_bias = False 125 | broadcast_dropout = True 126 | dropout_rate = %DROPOUT_RATE 127 | attention_kernel_init/linen.initializers.variance_scaling: 128 | scale = %SCALE 129 | mode = 'fan_in' 130 | distribution = 'normal' 131 | 132 | # Relative position biases (encoder, decoder) 133 | relative_position_biases.RelativePositionBiases: 134 | num_heads = %NUM_HEADS 135 | dtype = %ACTIVATION_DTYPE 136 | num_buckets = 32 137 | max_distance = 128 138 | embedding_init = @relative_position_bias_init/linen.initializers.variance_scaling() 139 | relative_position_bias_init/linen.initializers.variance_scaling: 140 | scale = %SCALE 141 | mode = 'fan_avg' 142 | distribution = 'uniform' 143 | 144 | # MLP (encoder, decoder) 145 | dense.MlpBlock: 146 | use_bias = False 147 | intermediate_dim = %MLP_DIM 148 | activations = ('gelu', 'linear') 149 | kernel_init = @mlp_kernel_init/linen.initializers.variance_scaling() 150 | bias_init = %BIAS_INIT 151 | intermediate_dropout_rate = %DROPOUT_RATE 152 | final_dropout_rate = 0 153 | dtype = %ACTIVATION_DTYPE 154 | mlp_kernel_init/linen.initializers.variance_scaling: 155 | scale = %SCALE 156 | mode = 'fan_in' 157 | distribution = 'truncated_normal' 158 | 159 | layer_norm.T5LayerNorm.dtype = %ACTIVATION_DTYPE 160 | -------------------------------------------------------------------------------- /configs/t5/architectures/t5_flaxformer.gin: -------------------------------------------------------------------------------- 1 | # Flaxformer implementation of original T5 (1.0) architecture. 2 | # 3 | # Required to be overridden: 4 | # 5 | # - NUM_ENCODER_LAYERS 6 | # - NUM_DECODER_LAYERS 7 | # - NUM_HEADS 8 | # - HEAD_DIM 9 | # - EMBED_DIM 10 | # - MLP_DIM 11 | from __gin__ import dynamic_registration 12 | 13 | from flax import linen 14 | from flaxformer.architectures.t5 import t5_architecture 15 | from flaxformer.components.attention import dense_attention 16 | from flaxformer.components import dense 17 | from flaxformer.components import embedding 18 | from flaxformer.components import layer_norm 19 | from flaxformer.components import relative_position_biases 20 | 21 | # Must be overridden. 
22 | NUM_ENCODER_LAYERS = %gin.REQUIRED 23 | NUM_DECODER_LAYERS = %gin.REQUIRED 24 | NUM_HEADS = %gin.REQUIRED 25 | HEAD_DIM = %gin.REQUIRED 26 | EMBED_DIM = %gin.REQUIRED 27 | MLP_DIM = %gin.REQUIRED 28 | NUM_EMBEDDINGS = %gin.REQUIRED 29 | 30 | # Constants (may be overridden) 31 | ACTIVATION_DTYPE = 'bfloat16' 32 | ACTIVATION_PARTITIONING_DIMS = 1 33 | DROPOUT_RATE = 0.0 34 | 35 | # Macros 36 | BIAS_INIT = @bias_init/linen.initializers.normal() 37 | bias_init/linen.initializers.normal.stddev = 1e-6 38 | DROPOUT_FACTORY = @dropout_factory/linen.Dropout 39 | dropout_factory/linen.Dropout: 40 | rate = %DROPOUT_RATE 41 | broadcast_dims = (-2,) 42 | 43 | # Architecture (Flax Module) 44 | ARCHITECTURE = @t5_architecture.EncoderDecoder() 45 | t5_architecture.EncoderDecoder: 46 | encoder_factory = @t5_architecture.Encoder 47 | decoder_factory = @t5_architecture.Decoder 48 | shared_token_embedder_factory = @embedding.Embed 49 | dtype = %ACTIVATION_DTYPE 50 | 51 | # Encoder 52 | t5_architecture.Encoder: 53 | num_layers = %NUM_ENCODER_LAYERS 54 | layer_factory = @t5_architecture.EncoderLayer 55 | input_dropout_factory = %DROPOUT_FACTORY 56 | output_dropout_factory = %DROPOUT_FACTORY 57 | layer_norm_factory = @layer_norm.T5LayerNorm 58 | position_embedder_factory = None 59 | shared_relative_position_bias_factory = @relative_position_biases.RelativePositionBiases 60 | dtype = %ACTIVATION_DTYPE 61 | 62 | # Encoder Layer 63 | t5_architecture.EncoderLayer: 64 | attention = @dense_attention.MultiHeadDotProductAttention() 65 | mlp = @dense.MlpBlock() 66 | dropout_factory = %DROPOUT_FACTORY 67 | layer_norm_factory = @layer_norm.T5LayerNorm 68 | activation_partitioning_dims = %ACTIVATION_PARTITIONING_DIMS 69 | 70 | # Decoder 71 | t5_architecture.Decoder: 72 | num_layers = %NUM_DECODER_LAYERS 73 | layer_factory = @t5_architecture.DecoderLayer 74 | dropout_factory = %DROPOUT_FACTORY 75 | layer_norm_factory = @layer_norm.T5LayerNorm 76 | output_logits_factory = None 77 | position_embedder_factory = None 78 | shared_relative_position_bias_factory = @relative_position_biases.RelativePositionBiases 79 | dtype = %ACTIVATION_DTYPE 80 | 81 | # Decoder Layer 82 | t5_architecture.DecoderLayer: 83 | self_attention = @dense_attention.MultiHeadDotProductAttention() 84 | encoder_decoder_attention = @dense_attention.MultiHeadDotProductAttention() 85 | mlp = @dense.MlpBlock() 86 | dropout_factory = %DROPOUT_FACTORY 87 | layer_norm_factory = @layer_norm.T5LayerNorm 88 | activation_partitioning_dims = %ACTIVATION_PARTITIONING_DIMS 89 | 90 | # Token Embedder (shared) 91 | embedding.Embed: 92 | num_embeddings= %NUM_EMBEDDINGS 93 | features = %EMBED_DIM 94 | cast_input_dtype = 'int32' 95 | dtype = %ACTIVATION_DTYPE 96 | attend_dtype = 'float32' # for logit training stability 97 | one_hot = True 98 | embedding_init = @token_embedder_init/linen.initializers.normal() 99 | name = 'token_embedder' 100 | token_embedder_init/linen.initializers.normal.stddev = 1.0 101 | 102 | # Attention (encoder, decoder, self-attention) 103 | dense_attention.MultiHeadDotProductAttention: 104 | num_heads = %NUM_HEADS 105 | head_dim = %HEAD_DIM 106 | dtype = %ACTIVATION_DTYPE 107 | kernel_init = @attention_kernel_init/linen.initializers.variance_scaling() 108 | bias_init = %BIAS_INIT 109 | use_bias = False 110 | broadcast_dropout = True 111 | dropout_rate = %DROPOUT_RATE 112 | attention_kernel_init/linen.initializers.variance_scaling: 113 | scale = 1.0 114 | mode = 'fan_in' 115 | distribution = 'normal' 116 | 117 | # Relative position biases (encoder, 
decoder) 118 | relative_position_biases.RelativePositionBiases: 119 | num_heads = %NUM_HEADS 120 | num_buckets = 32 121 | max_distance = 128 122 | dtype = %ACTIVATION_DTYPE 123 | embedding_init = @relative_position_bias_init/linen.initializers.variance_scaling() 124 | relative_position_bias_init/linen.initializers.variance_scaling: 125 | scale = 1.0 126 | mode = 'fan_avg' 127 | distribution = 'uniform' 128 | 129 | # MLP (encoder, decoder) 130 | dense.MlpBlock: 131 | use_bias = False 132 | intermediate_dim = %MLP_DIM 133 | activations = ('relu',) 134 | kernel_init = @mlp_kernel_init/linen.initializers.variance_scaling() 135 | bias_init = %BIAS_INIT 136 | intermediate_dropout_rate = %DROPOUT_RATE 137 | final_dropout_rate = 0 138 | dtype = %ACTIVATION_DTYPE 139 | mlp_kernel_init/linen.initializers.variance_scaling: 140 | scale = 1.0 141 | mode = 'fan_in' 142 | distribution = 'truncated_normal' 143 | 144 | layer_norm.T5LayerNorm.dtype = %ACTIVATION_DTYPE 145 | -------------------------------------------------------------------------------- /configs/t5/gin_configs_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for gin configs in this directory.""" 16 | 17 | # "Unused" imports below are needed by gin configs. 18 | # pylint: disable=unused-import 19 | 20 | import os 21 | 22 | from absl.testing import absltest 23 | from absl.testing import parameterized 24 | from flax import linen as nn 25 | import gin 26 | from jax import numpy as jnp 27 | from jax import random 28 | import numpy as np 29 | from t5x import models as t5x_models 30 | from t5x import utils 31 | 32 | 33 | class GinConfigsTest(parameterized.TestCase): 34 | 35 | @classmethod 36 | def setUpClass(cls): 37 | super(GinConfigsTest, cls).setUpClass() 38 | cls.root = os.path.join( 39 | absltest.get_default_test_srcdir(), 40 | 'flaxformer/t5x/configs/t5') 41 | gin.add_config_file_search_path(cls.root) 42 | 43 | def setUp(self): 44 | super().setUp() 45 | gin.clear_config() 46 | 47 | @parameterized.parameters( 48 | 'byt5_small.gin', 49 | 'mt5_small.gin', 50 | 't5_1_1_small.gin', 51 | 't5_small.gin', 52 | ) 53 | def test_model_gin_config(self, filename): 54 | path = os.path.join(self.root, 'models', filename) 55 | gin.parse_config_file(path) 56 | gin.finalize() # Check for required values, etc. 57 | 58 | model_config_ref: gin.ConfigurableReference = gin.query_parameter('%MODEL') 59 | 60 | # Instantiate T5X model (e.g. `t5x.models.EncoderDecoderModel`). 61 | model: t5x_models.BaseModel = model_config_ref.scoped_configurable_fn() 62 | 63 | encoder_input_tokens = jnp.ones((2, 3)) 64 | # For this test, decoder input and target tokens are fake values. 
65 | decoder_input_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]]) 66 | decoder_target_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]]) 67 | decoder_loss_weights = jnp.array([[1, 1, 1, 0], [0, 1, 0, 1]]) 68 | 69 | if 'lamda' in filename: 70 | encoder_kwargs = {} 71 | else: 72 | encoder_kwargs = {'encoder_input_tokens': encoder_input_tokens} 73 | 74 | variables = model.module.init( 75 | random.PRNGKey(0), 76 | decoder_input_tokens=decoder_input_tokens, 77 | decoder_target_tokens=decoder_target_tokens, 78 | enable_dropout=False, 79 | **encoder_kwargs) 80 | 81 | output = model.module.apply({'params': variables['params']}, 82 | decoder_input_tokens=decoder_input_tokens, 83 | decoder_target_tokens=decoder_target_tokens, 84 | enable_dropout=False, 85 | **encoder_kwargs) 86 | del output # Unused. 87 | 88 | batch = { 89 | 'encoder_input_tokens': encoder_input_tokens, 90 | 'decoder_input_tokens': decoder_input_tokens, 91 | 'decoder_target_tokens': decoder_target_tokens, 92 | 'decoder_loss_weights': decoder_loss_weights 93 | } 94 | res = model.score_batch(variables['params'], batch) 95 | del res # Unused. 96 | 97 | @parameterized.parameters('t5_1_1_flaxformer.gin', 't5_flaxformer.gin') 98 | def test_architecture_gin_config(self, filename): 99 | path = os.path.join(self.root, 'architectures', filename) 100 | gin.parse_config_file(path) 101 | gin.parse_config(""" 102 | NUM_HEADS = 2 103 | NUM_ENCODER_LAYERS = 2 104 | NUM_DECODER_LAYERS = 2 105 | NUM_LAYERS = 2 106 | HEAD_DIM = 4 107 | EMBED_DIM = 8 108 | MLP_DIM = 8 109 | NUM_EMBEDDINGS = 128 110 | """) 111 | gin.finalize() # Check for required values, etc. 112 | 113 | arch_config_ref: gin.ConfigurableReference = gin.query_parameter( 114 | '%ARCHITECTURE') 115 | 116 | # Instantiate architecture. 117 | arch: nn.Module = arch_config_ref.scoped_configurable_fn() 118 | 119 | shape = [4, 8] 120 | encoder_input_tokens = np.ones(shape, dtype=np.int32) 121 | decoder_input_tokens = np.ones(shape, dtype=np.int32) 122 | decoder_target_tokens = np.ones(shape, dtype=np.int32) 123 | 124 | if 'lamda' in filename: 125 | encoder_kwargs = {} 126 | else: 127 | encoder_kwargs = {'encoder_input_tokens': encoder_input_tokens} 128 | 129 | output, variables = arch.init_with_output( 130 | random.PRNGKey(0), 131 | decoder_input_tokens=decoder_input_tokens, 132 | decoder_target_tokens=decoder_target_tokens, 133 | enable_dropout=False, 134 | decode=False, 135 | max_decode_length=None, 136 | **encoder_kwargs) 137 | del output # Unused. 138 | 139 | # Call with expected arrays (e.g. Call `__call__` with concrete sequences). 140 | _ = arch.apply( 141 | variables, 142 | decoder_input_tokens=decoder_input_tokens, 143 | decoder_target_tokens=decoder_target_tokens, 144 | **encoder_kwargs) 145 | 146 | 147 | if __name__ == '__main__': 148 | absltest.main() 149 | -------------------------------------------------------------------------------- /configs/t5/models/t5_11B.gin: -------------------------------------------------------------------------------- 1 | # Original T5 (1.0) 11B model. 2 | # Provides MODEL 3 | 4 | include 'configs/t5/models/t5_base.gin' # imports vocab, optimizer and model. 5 | 6 | # Architecture overrides 7 | NUM_ENCODER_LAYERS = 24 8 | NUM_DECODER_LAYERS = 24 9 | NUM_HEADS = 128 10 | HEAD_DIM = 128 11 | EMBED_DIM = 1024 12 | MLP_DIM = 65536 13 | -------------------------------------------------------------------------------- /configs/t5/models/t5_1_1_base.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Base model. 
2 | # Provides MODEL
3 | from __gin__ import dynamic_registration
4 | 
5 | import seqio
6 | from t5x import adafactor
7 | from t5x import models
8 | 
9 | ARCHITECTURE = %gin.REQUIRED
10 | 
11 | include 'configs/t5/architectures/t5_1_1_flaxformer.gin'
12 | 
13 | # Architecture overrides
14 | NUM_ENCODER_LAYERS = 12
15 | NUM_DECODER_LAYERS = 12
16 | NUM_HEADS = 12
17 | HEAD_DIM = 64
18 | EMBED_DIM = 768
19 | MLP_DIM = 2048
20 | 
21 | # Loss HParam defaults
22 | Z_LOSS = 0.0001
23 | LABEL_SMOOTHING = 0.0
24 | # NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF)
25 | # the loss normalizing factor should be set to 2048 * 114 (pretraining
26 | # batch_size * target_token_length).
27 | LOSS_NORMALIZING_FACTOR = None
28 | 
29 | # Vocabulary (shared by encoder and decoder)
30 | VOCABULARY = @seqio.SentencePieceVocabulary()
31 | seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model"
32 | NUM_EMBEDDINGS = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency
33 | 
34 | # Optimizer
35 | # `learning_rate` is set by `Trainer.learning_rate_fn`.
36 | OPTIMIZER = @adafactor.Adafactor()
37 | adafactor.Adafactor:
38 |   decay_rate = 0.8
39 |   step_offset = 0
40 | 
41 | # Model
42 | MODEL = @models.EncoderDecoderModel()
43 | models.EncoderDecoderModel:
44 |   module = %ARCHITECTURE # provided by t5_flaxformer
45 |   input_vocabulary = %VOCABULARY
46 |   output_vocabulary = %VOCABULARY
47 |   optimizer_def = %OPTIMIZER
48 |   z_loss = %Z_LOSS
49 |   label_smoothing = %LABEL_SMOOTHING
50 |   loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
51 | 
--------------------------------------------------------------------------------
/configs/t5/models/t5_1_1_large.gin:
--------------------------------------------------------------------------------
1 | # T5.1.1 Large model.
2 | # Provides MODEL
3 | 
4 | include 'configs/t5/models/t5_1_1_base.gin' # imports vocab, optimizer and model.
5 | 
6 | # Architecture overrides
7 | NUM_ENCODER_LAYERS = 24
8 | NUM_DECODER_LAYERS = 24
9 | NUM_HEADS = 16
10 | HEAD_DIM = 64
11 | EMBED_DIM = 1024
12 | MLP_DIM = 2816
13 | 
--------------------------------------------------------------------------------
/configs/t5/models/t5_1_1_small.gin:
--------------------------------------------------------------------------------
1 | # T5.1.1 Small model.
2 | # Provides MODEL
3 | 
4 | include 'configs/t5/models/t5_1_1_base.gin' # imports vocab, optimizer and model.
5 | 
6 | # Architecture overrides
7 | NUM_ENCODER_LAYERS = 8
8 | NUM_DECODER_LAYERS = 8
9 | NUM_HEADS = 6
10 | HEAD_DIM = 64
11 | EMBED_DIM = 512
12 | MLP_DIM = 1024
13 | 
--------------------------------------------------------------------------------
/configs/t5/models/t5_1_1_tiny.gin:
--------------------------------------------------------------------------------
1 | # T5.1.1 Tiny model.
2 | # Provides MODEL
3 | 
4 | include 'configs/t5/models/t5_1_1_base.gin' # imports vocab, optimizer and model.
5 | 
6 | # Architecture overrides
7 | NUM_HEADS = 2
8 | NUM_ENCODER_LAYERS = 1
9 | NUM_DECODER_LAYERS = 1
10 | HEAD_DIM = 16
11 | EMBED_DIM = 32
12 | MLP_DIM = 128
13 | 
--------------------------------------------------------------------------------
/configs/t5/models/t5_1_1_xl.gin:
--------------------------------------------------------------------------------
1 | # T5.1.1 XL model.
2 | # Provides MODEL
3 | 
4 | include 'configs/t5/models/t5_1_1_base.gin' # imports vocab, optimizer and model.
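Each of these size variants re-includes `t5_1_1_base.gin` and overrides only the architecture macros, since gin bindings that appear after an `include` take precedence over the included ones. A minimal sketch of inspecting the composed bindings (assuming gin-config plus the t5x/flaxformer dependencies are installed and the repo root is the working directory):

```python
import gin

gin.add_config_file_search_path('.')  # so the relative `include` paths resolve
gin.parse_config_file('configs/t5/models/t5_1_1_xl.gin')

# Macros can be queried by name, as configs/t5/gin_configs_test.py does.
print(gin.query_parameter('%EMBED_DIM'))  # 2048, the XL override
print(gin.query_parameter('%NUM_HEADS'))  # 32
```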
5 | 6 | # Architecture overrides 7 | NUM_ENCODER_LAYERS = 24 8 | NUM_DECODER_LAYERS = 24 9 | NUM_HEADS = 32 10 | HEAD_DIM = 64 11 | EMBED_DIM = 2048 12 | MLP_DIM = 5120 13 | -------------------------------------------------------------------------------- /configs/t5/models/t5_1_1_xxl.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 XXL model. 2 | # Provides MODEL 3 | 4 | include 'configs/t5/models/t5_1_1_base.gin' # imports vocab, optimizer and model. 5 | 6 | # Architecture overrides 7 | NUM_ENCODER_LAYERS = 24 8 | NUM_DECODER_LAYERS = 24 9 | NUM_HEADS = 64 10 | HEAD_DIM = 64 11 | EMBED_DIM = 4096 12 | MLP_DIM = 10240 13 | -------------------------------------------------------------------------------- /configs/t5/models/t5_3B.gin: -------------------------------------------------------------------------------- 1 | # Original T5 (1.0) 3B model. 2 | # Provides MODEL 3 | 4 | include 'configs/t5/models/t5_base.gin' # imports vocab, optimizer and model. 5 | 6 | # Architecture overrides 7 | NUM_ENCODER_LAYERS = 24 8 | NUM_DECODER_LAYERS = 24 9 | NUM_HEADS = 32 10 | HEAD_DIM = 128 11 | EMBED_DIM = 1024 12 | MLP_DIM = 16384 13 | -------------------------------------------------------------------------------- /configs/t5/models/t5_base.gin: -------------------------------------------------------------------------------- 1 | # Original T5 (1.0) Base model. 2 | # Provides MODEL 3 | from __gin__ import dynamic_registration 4 | 5 | import seqio 6 | from t5x import adafactor 7 | from t5x import models 8 | 9 | ARCHITECTURE = %gin.REQUIRED 10 | 11 | include 'configs/t5/architectures/t5_flaxformer.gin' 12 | 13 | # Architecture overrides 14 | NUM_ENCODER_LAYERS = 12 15 | NUM_DECODER_LAYERS = 12 16 | NUM_HEADS = 12 17 | HEAD_DIM = 64 18 | EMBED_DIM = 768 19 | MLP_DIM = 3072 20 | 21 | # Loss HParam defaults 22 | Z_LOSS = 0.0001 23 | LABEL_SMOOTHING = 0.0 24 | # NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) 25 | # the loss normalizing factor should be set to 2048 * 114 (pretraining 26 | # batch_size * target_token_length). 27 | LOSS_NORMALIZING_FACTOR = None 28 | 29 | # Vocabulary (shared by encoder and decoder) 30 | VOCABULARY = @seqio.SentencePieceVocabulary() 31 | seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" 32 | NUM_EMBEDDINGS = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency 33 | 34 | # Optimizer 35 | # `learning_rate` is set by `Trainer.learning_rate_fn`. 36 | OPTIMIZER = @adafactor.Adafactor() 37 | adafactor.Adafactor: 38 | decay_rate = 0.8 39 | step_offset = 0 40 | 41 | # Model 42 | MODEL = @models.EncoderDecoderModel() 43 | models.EncoderDecoderModel: 44 | module = %ARCHITECTURE # provided by t5_flaxformer 45 | input_vocabulary = %VOCABULARY 46 | output_vocabulary = %VOCABULARY 47 | optimizer_def = %OPTIMIZER 48 | z_loss = %Z_LOSS 49 | label_smoothing = %LABEL_SMOOTHING 50 | loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR 51 | -------------------------------------------------------------------------------- /configs/t5/models/t5_large.gin: -------------------------------------------------------------------------------- 1 | # Original T5 (1.0) Large model. 2 | # Provides MODEL 3 | 4 | include 'configs/t5/models/t5_base.gin' # imports vocab, optimizer and model. 
5 | 6 | # Architecture overrides 7 | NUM_ENCODER_LAYERS = 24 8 | NUM_DECODER_LAYERS = 24 9 | NUM_HEADS = 16 10 | HEAD_DIM = 64 11 | EMBED_DIM = 1024 12 | MLP_DIM = 4096 13 | -------------------------------------------------------------------------------- /configs/t5/models/t5_small.gin: -------------------------------------------------------------------------------- 1 | # Original T5 (1.0) Small model. 2 | # Provides MODEL 3 | 4 | include 'configs/t5/models/t5_base.gin' # imports vocab, optimizer and model. 5 | 6 | # Architecture overrides 7 | NUM_ENCODER_LAYERS = 6 8 | NUM_DECODER_LAYERS = 6 9 | NUM_HEADS = 8 10 | HEAD_DIM = 64 11 | EMBED_DIM = 512 12 | MLP_DIM = 2048 13 | -------------------------------------------------------------------------------- /demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cohere-Labs-Community/parameter-efficient-moe/e2683fead45eb480bcbf9c2e646fbbc2f7fc6efb/demo.png -------------------------------------------------------------------------------- /scripts/find_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Find where a module is installed. 16 | 17 | This tool is useful for finding where a package like T5X is installed so that 18 | we can easily use the gin configs that are bundled with it. 19 | 20 | Example usage: 21 | 22 | python -m t5x.train \ 23 | --gin_search_paths=`python -m prompt_tuning.scripts.find_module t5x` \ 24 | --gin_file=t5x/configs/... \ 25 | ... 26 | """ 27 | 28 | import importlib 29 | import os 30 | from typing import Sequence 31 | from absl import app 32 | 33 | 34 | def main(argv: Sequence[str]): 35 | if len(argv) != 2: 36 | raise app.UsageError("Missing module argument.") 37 | 38 | module = importlib.import_module(argv[1]) 39 | print(os.path.dirname(os.path.abspath(module.__file__))) 40 | 41 | 42 | if __name__ == "__main__": 43 | app.run(main) 44 | -------------------------------------------------------------------------------- /scripts/ia3_eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CKPT_DIR=${1:-${CKPT_DIR}} 4 | EVAL_DIR=${2:-${EVAL_DIR}} 5 | 6 | T5X_DIR="`python3 -m scripts.find_module t5x`/.." 7 | FLAXFORMER_DIR="`python3 -m scripts.find_module flaxformer`/.." 
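# `find_module` prints the directory a package is installed in; the trailing
# `/..` makes its parent a gin search path, so --gin_file can also reference
# gin configs bundled with the installed t5x and flaxformer packages.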
8 | echo "Searching for gin configs in:" 9 | echo "- ${T5X_DIR}" 10 | echo "- ${FLAXFORMER_DIR}" 11 | echo "=============================" 12 | CACHE_DIR="raw_tfrecords/you_cache_dir" 13 | 14 | python3 -m t5x.eval \ 15 | --gin_search_paths="${T5X_DIR}" \ 16 | --gin_file="configs/t5/models/t5_1_1_large.gin" \ 17 | --gin_file="configs/ia3_eval.gin" \ 18 | --gin.EVAL_OUTPUT_DIR="'${EVAL_DIR}'" \ 19 | --gin.MIXTURE_OR_TASK_NAME="'t0_eval_score_eval'" \ 20 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 1024, 'targets': 256}" \ 21 | --gin.CHECKPOINT_PATH="'${CKPT_DIR}'" \ 22 | --seqio_additional_cache_dirs=${CACHE_DIR} \ 23 | --gin.utils.DatasetConfig.use_cached="True" \ 24 | --gin.utils.DatasetConfig.split="'validation'" \ 25 | --gin.BATCH_SIZE="32" -------------------------------------------------------------------------------- /scripts/ia3_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | MODEL_DIR=${1:-${MODEL_DIR}} 4 | 5 | T5X_DIR="`python3 -m scripts.find_module t5x`/.." 6 | FLAXFORMER_DIR="`python3 -m scripts.find_module flaxformer`/.." 7 | echo "Searching for gin configs in:" 8 | echo "- ${T5X_DIR}" 9 | echo "- ${FLAXFORMER_DIR}" 10 | echo "=============================" 11 | PRETRAINED_MODEL="gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_large/checkpoint_1100000" 12 | CACHE_DIR="raw_tfrecords/you_cache_dir" 13 | 14 | python3 -m t5x.train \ 15 | --gin_search_paths="${T5X_DIR}" \ 16 | --gin_file="configs/t5/models/t5_1_1_large.gin" \ 17 | --gin_file="configs/ia3.gin" \ 18 | --gin.MODEL_DIR="'${MODEL_DIR}'" \ 19 | --gin.LOSS_NORMALIZING_FACTOR="'AVERAGE_PER_SEQUENCE'" \ 20 | --gin.MIXTURE_OR_TASK_NAME="'t0_train'" \ 21 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 1024, 'targets': 256}" \ 22 | --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" \ 23 | --gin.TRAIN_STEPS="1_600_000" \ 24 | --gin.USE_CACHED_TASKS="True" \ 25 | --gin.PACKING="True" \ 26 | --seqio_additional_cache_dirs=${CACHE_DIR} \ 27 | --gin.BATCH_SIZE="32" -------------------------------------------------------------------------------- /scripts/lora_eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CKPT_DIR=${1:-${CKPT_DIR}} 4 | EVAL_DIR=${2:-${EVAL_DIR}} 5 | 6 | T5X_DIR="`python3 -m scripts.find_module t5x`/.." 7 | FLAXFORMER_DIR="`python3 -m scripts.find_module flaxformer`/.." 8 | echo "Searching for gin configs in:" 9 | echo "- ${T5X_DIR}" 10 | echo "- ${FLAXFORMER_DIR}" 11 | echo "=============================" 12 | CACHE_DIR="raw_tfrecords/you_cache_dir" 13 | 14 | python3 -m t5x.eval \ 15 | --gin_search_paths="${T5X_DIR}" \ 16 | --gin_file="configs/t5/models/t5_1_1_large.gin" \ 17 | --gin_file="configs/lora_eval.gin" \ 18 | --gin.EVAL_OUTPUT_DIR="'${EVAL_DIR}'" \ 19 | --gin.MIXTURE_OR_TASK_NAME="'t0_eval_score_eval'" \ 20 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 1024, 'targets': 256}" \ 21 | --gin.CHECKPOINT_PATH="'${CKPT_DIR}'" \ 22 | --seqio_additional_cache_dirs=${CACHE_DIR} \ 23 | --gin.utils.DatasetConfig.use_cached="True" \ 24 | --gin.utils.DatasetConfig.split="'validation'" \ 25 | --gin.BATCH_SIZE="32" -------------------------------------------------------------------------------- /scripts/lora_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | MODEL_DIR=${1:-${MODEL_DIR}} 4 | 5 | T5X_DIR="`python3 -m scripts.find_module t5x`/.." 6 | FLAXFORMER_DIR="`python3 -m scripts.find_module flaxformer`/.." 
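# Note the nested quotes in the --gin.X="'...'" overrides below: t5x rewrites
# each --gin.NAME=VALUE flag into a gin binding, so string values need the
# inner quotes to be parsed as gin strings rather than references.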
7 | echo "Searching for gin configs in:" 8 | echo "- ${T5X_DIR}" 9 | echo "- ${FLAXFORMER_DIR}" 10 | echo "=============================" 11 | PRETRAINED_MODEL="gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_large/checkpoint_1100000" 12 | CACHE_DIR="raw_tfrecords/you_cache_dir" 13 | 14 | python3 -m t5x.train \ 15 | --gin_search_paths="${T5X_DIR}" \ 16 | --gin_file="configs/t5/models/t5_1_1_large.gin" \ 17 | --gin_file="configs/lora.gin" \ 18 | --gin.MODEL_DIR="'${MODEL_DIR}'" \ 19 | --gin.LOSS_NORMALIZING_FACTOR="'AVERAGE_PER_SEQUENCE'" \ 20 | --gin.MIXTURE_OR_TASK_NAME="'t0_train'" \ 21 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 1024, 'targets': 256}" \ 22 | --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" \ 23 | --gin.TRAIN_STEPS="1_600_000" \ 24 | --gin.USE_CACHED_TASKS="True" \ 25 | --gin.PACKING="True" \ 26 | --seqio_additional_cache_dirs=${CACHE_DIR} \ 27 | --gin.BATCH_SIZE="32" \ -------------------------------------------------------------------------------- /scripts/molora_eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CKPT_DIR=${1:-${CKPT_DIR}} 4 | EVAL_DIR=${2:-${EVAL_DIR}} 5 | 6 | T5X_DIR="`python3 -m scripts.find_module t5x`/.." 7 | FLAXFORMER_DIR="`python3 -m scripts.find_module flaxformer`/.." 8 | echo "Searching for gin configs in:" 9 | echo "- ${T5X_DIR}" 10 | echo "- ${FLAXFORMER_DIR}" 11 | echo "=============================" 12 | CACHE_DIR="raw_tfrecords/you_cache_dir" 13 | 14 | python3 -m t5x.eval \ 15 | --gin_search_paths="${T5X_DIR}" \ 16 | --gin_file="configs/t5/models/t5_1_1_large.gin" \ 17 | --gin_file="configs/molora_eval.gin" \ 18 | --gin.EVAL_OUTPUT_DIR="'${EVAL_DIR}'" \ 19 | --gin.MIXTURE_OR_TASK_NAME="'t0_eval_score_eval'" \ 20 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 1024, 'targets': 256}" \ 21 | --gin.CHECKPOINT_PATH="'${CKPT_DIR}'" \ 22 | --seqio_additional_cache_dirs=${CACHE_DIR} \ 23 | --gin.utils.DatasetConfig.use_cached="True" \ 24 | --gin.utils.DatasetConfig.split="'validation'" \ 25 | --gin.BATCH_SIZE="32" -------------------------------------------------------------------------------- /scripts/molora_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | MODEL_DIR=${1:-${MODEL_DIR}} 4 | 5 | T5X_DIR="`python3 -m scripts.find_module t5x`/.." 6 | FLAXFORMER_DIR="`python3 -m scripts.find_module flaxformer`/.." 
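# Apart from selecting configs/molora.gin below, this script is identical to
# lora_train.sh; the method-specific modules and hyperparameters live in the
# gin config.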
7 | echo "Searching for gin configs in:" 8 | echo "- ${T5X_DIR}" 9 | echo "- ${FLAXFORMER_DIR}" 10 | echo "=============================" 11 | PRETRAINED_MODEL="gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_large/checkpoint_1100000" 12 | CACHE_DIR="raw_tfrecords/you_cache_dir" 13 | 14 | python3 -m t5x.train \ 15 | --gin_search_paths="${T5X_DIR}" \ 16 | --gin_file="configs/t5/models/t5_1_1_large.gin" \ 17 | --gin_file="configs/molora.gin" \ 18 | --gin.MODEL_DIR="'${MODEL_DIR}'" \ 19 | --gin.LOSS_NORMALIZING_FACTOR="'AVERAGE_PER_SEQUENCE'" \ 20 | --gin.MIXTURE_OR_TASK_NAME="'t0_train'" \ 21 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 1024, 'targets': 256}" \ 22 | --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" \ 23 | --gin.TRAIN_STEPS="1_600_000" \ 24 | --gin.USE_CACHED_TASKS="True" \ 25 | --gin.PACKING="True" \ 26 | --seqio_additional_cache_dirs=${CACHE_DIR} \ 27 | --gin.BATCH_SIZE="32" -------------------------------------------------------------------------------- /scripts/mov_eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CKPT_DIR=${1:-${CKPT_DIR}} 4 | EVAL_DIR=${2:-${EVAL_DIR}} 5 | 6 | T5X_DIR="`python3 -m scripts.find_module t5x`/.." 7 | FLAXFORMER_DIR="`python3 -m scripts.find_module flaxformer`/.." 8 | echo "Searching for gin configs in:" 9 | echo "- ${T5X_DIR}" 10 | echo "- ${FLAXFORMER_DIR}" 11 | echo "=============================" 12 | CACHE_DIR="raw_tfrecords/you_cache_dir" 13 | 14 | python3 -m t5x.eval \ 15 | --gin_search_paths="${T5X_DIR}" \ 16 | --gin_file="configs/t5/models/t5_1_1_large.gin" \ 17 | --gin_file="configs/mov_eval.gin" \ 18 | --gin.EVAL_OUTPUT_DIR="'${EVAL_DIR}'" \ 19 | --gin.MIXTURE_OR_TASK_NAME="'t0_eval_score_eval'" \ 20 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 1024, 'targets': 256}" \ 21 | --gin.CHECKPOINT_PATH="'${CKPT_DIR}'" \ 22 | --seqio_additional_cache_dirs=${CACHE_DIR} \ 23 | --gin.utils.DatasetConfig.use_cached="True" \ 24 | --gin.utils.DatasetConfig.split="'validation'" \ 25 | --gin.BATCH_SIZE="32" -------------------------------------------------------------------------------- /scripts/mov_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | MODEL_DIR=${1:-${MODEL_DIR}} 4 | 5 | T5X_DIR="`python3 -m scripts.find_module t5x`/.." 6 | FLAXFORMER_DIR="`python3 -m scripts.find_module flaxformer`/.." 
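# TRAIN_STEPS below is an absolute step count in t5x: resuming from the
# 1.1M-step pretrained checkpoint, 1_600_000 total steps corresponds to
# 500k fine-tuning steps.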
7 | echo "Searching for gin configs in:"
8 | echo "- ${T5X_DIR}"
9 | echo "- ${FLAXFORMER_DIR}"
10 | echo "============================="
11 | PRETRAINED_MODEL="gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_large/checkpoint_1100000"
12 | CACHE_DIR="raw_tfrecords/you_cache_dir"
13 | 
14 | python3 -m t5x.train \
15 |   --gin_search_paths="${T5X_DIR}" \
16 |   --gin_file="configs/t5/models/t5_1_1_large.gin" \
17 |   --gin_file="configs/mov.gin" \
18 |   --gin.MODEL_DIR="'${MODEL_DIR}'" \
19 |   --gin.LOSS_NORMALIZING_FACTOR="'AVERAGE_PER_SEQUENCE'" \
20 |   --gin.MIXTURE_OR_TASK_NAME="'t0_train'" \
21 |   --gin.TASK_FEATURE_LENGTHS="{'inputs': 1024, 'targets': 256}" \
22 |   --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" \
23 |   --gin.TRAIN_STEPS="1_600_000" \
24 |   --gin.USE_CACHED_TASKS="True" \
25 |   --gin.PACKING="True" \
26 |   --seqio_additional_cache_dirs=${CACHE_DIR} \
27 |   --gin.BATCH_SIZE="32"
--------------------------------------------------------------------------------
/scripts/setup.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | # Run from the TPU VM home directory
4 | cd ~
5 | cp ~/.bashrc ~/.bashrc.backup
6 | 
7 | # Install miniconda
8 | mkdir -p ~/miniconda3
9 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
10 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
11 | rm -rf ~/miniconda3/miniconda.sh
12 | ~/miniconda3/bin/conda init bash
13 | 
14 | # Install a conda environment with Python 3.10
15 | ~/miniconda3/bin/conda create --name conda-moe-py310 python=3.10 -y
16 | 
17 | # install t5x
18 | git clone https://github.com/ahmetustun/t5x.git; cd t5x; ~/miniconda3/envs/conda-moe-py310/bin/python3 -m pip install -e '.[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; cd ~
19 | 
20 | # install flaxformer
21 | git clone https://github.com/ahmetustun/flaxformer.git; cd flaxformer; ~/miniconda3/envs/conda-moe-py310/bin/python3 -m pip install -e .; cd ~
22 | 
23 | # install other packages and fix version mismatches
24 | ~/miniconda3/envs/conda-moe-py310/bin/pip install t5 datasets promptsource markupsafe==2.0.1 ml_dtypes==0.2.0 orbax-checkpoint==0.2.3 --ignore-requires-python
--------------------------------------------------------------------------------
/scripts/t0_eval.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | CKPT_DIR=${1:-${CKPT_DIR}}
4 | EVAL_DIR=${2:-${EVAL_DIR}}
5 | 
6 | T5X_DIR="`python3 -m scripts.find_module t5x`/.."
7 | FLAXFORMER_DIR="`python3 -m scripts.find_module flaxformer`/.."
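# The t0_eval_score_eval mixture follows the SeqIO score_eval convention:
# answer choices are ranked by model log-likelihood, so metrics come from
# scoring rather than free-form decoding.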
8 | echo "Searching for gin configs in:" 9 | echo "- ${T5X_DIR}" 10 | echo "- ${FLAXFORMER_DIR}" 11 | echo "=============================" 12 | CACHE_DIR="raw_tfrecords/you_cache_dir" 13 | 14 | python3 -m t5x.eval \ 15 | --gin_search_paths="${T5X_DIR}" \ 16 | --gin_file="configs/t5/models/t5_1_1_large.gin" \ 17 | --gin_file="configs/t0_eval.gin" \ 18 | --gin.EVAL_OUTPUT_DIR="'${EVAL_DIR}'" \ 19 | --gin.MIXTURE_OR_TASK_NAME="'t0_eval_score_eval'" \ 20 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 1024, 'targets': 256}" \ 21 | --gin.CHECKPOINT_PATH="'${CKPT_DIR}'" \ 22 | --seqio_additional_cache_dirs=${CACHE_DIR} \ 23 | --gin.utils.DatasetConfig.use_cached="True" \ 24 | --gin.utils.DatasetConfig.split="'validation'" \ 25 | --gin.BATCH_SIZE="32" \ -------------------------------------------------------------------------------- /scripts/t0_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | MODEL_DIR=${1:-${MODEL_DIR}} 4 | 5 | T5X_DIR="`python3 -m scripts.find_module t5x`/.." 6 | FLAXFORMER_DIR="`python3 -m scripts.find_module flaxformer`/.." 7 | echo "Searching for gin configs in:" 8 | echo "- ${T5X_DIR}" 9 | echo "- ${FLAXFORMER_DIR}" 10 | echo "=============================" 11 | PRETRAINED_MODEL="gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_large/checkpoint_1100000" 12 | CACHE_DIR="raw_tfrecords/you_cache_dir" 13 | 14 | python3 -m t5x.train \ 15 | --gin_search_paths="${T5X_DIR}" \ 16 | --gin_file="configs/t5/models/t5_1_1_large.gin" \ 17 | --gin_file="configs/t0.gin" \ 18 | --gin.MODEL_DIR="'${MODEL_DIR}'" \ 19 | --gin.LOSS_NORMALIZING_FACTOR="'AVERAGE_PER_SEQUENCE'" \ 20 | --gin.MIXTURE_OR_TASK_NAME="'t0_train'" \ 21 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 1024, 'targets': 256}" \ 22 | --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" \ 23 | --gin.TRAIN_STEPS="1_110_000" \ 24 | --gin.USE_CACHED_TASKS="True" \ 25 | --gin.PACKING="True" \ 26 | --seqio_additional_cache_dirs=${CACHE_DIR} \ 27 | --gin.BATCH_SIZE="256" -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cohere-Labs-Community/parameter-efficient-moe/e2683fead45eb480bcbf9c2e646fbbc2f7fc6efb/src/__init__.py -------------------------------------------------------------------------------- /src/adafactor_custom.py: -------------------------------------------------------------------------------- 1 | """Custom adafactor rules.""" 2 | 3 | from flax.core.frozen_dict import freeze 4 | from flax.core.frozen_dict import unfreeze 5 | from t5x import adafactor 6 | 7 | 8 | def standard_logical_factor_rules(rules=None): 9 | """Add prompt adafactor rules to your set of rules.""" 10 | if rules is None: 11 | rules = adafactor.standard_logical_factor_rules() 12 | rules = unfreeze(rules) 13 | rules['unmodeled'] = adafactor.FactorDim.NONE 14 | rules['rank'] = adafactor.FactorDim.NONE 15 | rules['expert'] = adafactor.FactorDim.NONE 16 | return freeze(rules) -------------------------------------------------------------------------------- /src/ia3.py: -------------------------------------------------------------------------------- 1 | # This code is taken from https://github.com/google-research/prompt-tuning 2 | # Copyright 2023 Google. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """IA3 implementation"""
17 | 
18 | from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union
19 | 
20 | import flax.linen as nn
21 | from flax.linen import partitioning
22 | import jax.numpy as jnp
23 | from flaxformer.types import Array, Initializer, DType
24 | 
25 | 
26 | class IA3(nn.Module):
27 |   """IA3 scaling to use with a linear layer.
28 | 
29 |   Note:
30 |     This module is used as the intermediate_conv module in the flaxformer
31 |     MlpBlock. The MlpBlock only applies this intermediate conv to one of the
32 |     parallel activation functions it uses, but because these parallel
33 |     activations are combined with multiplication, IA3 applies a multiplicative
34 |     scaling; since multiplication is associative, we can apply the scaling to
35 |     just that activation and get the same result as if we applied it afterwards.
36 | 
37 |   Attributes:
38 |     ia3_init: How to initialize the scaling variable.
39 |     axis_name: The logical names of the variable axes, used for partitioning.
40 |     dtype: The dtype of the activations for this module.
41 |   """
42 |   ia3_init: Callable[[Array, Sequence[int]], Array] = nn.initializers.ones
43 |   axis_name: Tuple[str] = ('embed',)
44 |   dtype: DType = jnp.float32
45 | 
46 |   @nn.compact
47 |   def __call__(self, x, *args, **kwargs):
48 |     del args
49 |     del kwargs
50 |     *rest, hidden = x.shape
51 |     scaling = partitioning.param_with_axes(
52 |         'ia3_scaling',
53 |         self.ia3_init,
54 |         (hidden,),
55 |         axes=self.axis_name
56 |     )
57 |     scaling = scaling.astype(self.dtype)
58 |     # Reshape to broadcast over batch, seq, etc.
59 |     scaling = jnp.reshape(scaling, tuple((1 for _ in rest)) + scaling.shape)
60 |     return x * scaling
61 | 
62 | 
63 | class IA3Attention(nn.Module):
64 |   """A version of IA3 scaling to use with the Attention class.
65 | 
66 |   Note:
67 |     Because of where we can hook into the flaxformer attention class (the
68 |     `(k|v)_conv` module) the input to this function is already reshaped into
69 |     [..., length, heads, kv] so we shape our scaling to match those last two
70 |     dimensions. This will result in the same value as if we were to reshape
71 |     the variable and do a single d_model scale.
72 |     TODO: Rewrite as a single class that infers the number of dims
73 |     to extract from the input to use to shape the param from the number of dims
74 |     in the axis names.
75 | 
76 |   Attributes:
77 |     ia3_init: How to initialize the scaling variable.
78 |     axis_names: The logical names of the variable axes, used for partitioning.
79 |     dtype: The dtype of the activations for this module.
80 |   """
81 |   ia3_init: Callable[[Array, Sequence[int]], Array] = nn.initializers.ones
82 |   axis_names: Tuple[str, str] = ('heads', 'kv')
83 |   dtype: DType = jnp.float32
84 | 
85 |   @nn.compact
86 |   def __call__(self, x, *args, **kwargs):
87 |     del args
88 |     del kwargs
89 |     *rest, heads, kv = x.shape
90 |     scaling = partitioning.param_with_axes(
91 |         'ia3_scaling',
92 |         self.ia3_init,
93 |         (heads, kv),
94 |         axes=self.axis_names
95 |     )
96 |     scaling = scaling.astype(self.dtype)
97 |     # Reshape to broadcast over batch, seq, etc.
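    # The (heads, kv) parameter broadcasts over the leading batch/length dims;
    # because this hook already sees head-split activations, the per-head scale
    # matches a single d_model-sized IA3 vector on the merged dimension, as the
    # class docstring notes.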
98 |     scaling = jnp.reshape(scaling, tuple((1 for _ in rest)) + scaling.shape)
99 |     return x * scaling
--------------------------------------------------------------------------------
/src/lora.py:
--------------------------------------------------------------------------------
1 | """LoRA implementation"""
2 | 
3 | from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union
4 | 
5 | import jax
6 | import jax.numpy as jnp
7 | import flax.linen as nn
8 | from flax.linen import partitioning
9 | from flaxformer.types import Array, Initializer, DType
10 | from flaxformer.components import dense
11 | from src import routing
12 | 
13 | 
14 | class LoRA(nn.Module):
15 |   """LoRA implementation
16 | 
17 |   Attributes:
18 |     rank: LoRA rank
19 |     alpha: LoRA alpha; the low-rank update is
20 |       scaled by alpha / rank
21 |     lora_init_A: LoRA A initializer
22 |     lora_init_B: LoRA B initializer
23 |     lora_axis_names_A: Sharding axis names for LoRA A
24 |     lora_axis_names_B: Sharding axis names for LoRA B
25 |     dtype: Activation dtype
26 |     output_dim: LoRA output dimensions
27 |   """
28 |   rank: int = 2
29 |   lora_init_A: Initializer = nn.initializers.normal(stddev=2e-2)
30 |   lora_init_B: Initializer = nn.initializers.zeros
31 |   lora_axis_names_A: Sequence[str] = ('mlp', 'rank')
32 |   lora_axis_names_B: Sequence[str] = ('rank', 'mlp')
33 |   alpha = 16
34 |   output_dim: Optional[int] = None
35 |   dtype: DType = jnp.float32
36 | 
37 |   @nn.compact
38 |   def __call__(self, x: Array, **kwargs) -> Array:
39 | 
40 |     *rest, hidden = x.shape
41 | 
42 |     x = jax.lax.convert_element_type(x, self.dtype)
43 | 
44 |     #[hidden, rank]
45 |     lora_a = partitioning.param_with_axes(
46 |         'lora_A',
47 |         self.lora_init_A,
48 |         (hidden, self.rank),
49 |         jnp.float32,
50 |         axes=self.lora_axis_names_A)
51 | 
52 |     lora_a = jax.lax.convert_element_type(lora_a, self.dtype)
53 | 
54 |     #[batch, seq_len, rank]
55 |     ax = jnp.einsum('...d,dr->...r',
56 |                     x,
57 |                     lora_a)
58 | 
59 |     # Constrain sharding of the low-rank activations; rank stays unpartitioned.
60 |     ax = partitioning.with_sharding_constraint(ax, ('batch', 'length', 'unmodeled'))
61 |     ax = jax.lax.convert_element_type(ax, self.dtype)
62 | 
63 |     #[rank, hidden]
64 |     lora_b = partitioning.param_with_axes(
65 |         'lora_B',
66 |         self.lora_init_B,
67 |         (self.rank, (self.output_dim if self.output_dim else hidden)),
68 |         jnp.float32,
69 |         axes=self.lora_axis_names_B)
70 | 
71 |     lora_b = jax.lax.convert_element_type(lora_b, self.dtype)
72 | 
73 |     #[batch, seq_len, output_dim]
74 |     bax = jnp.einsum('...r,rd->...d',
75 |                      ax,
76 |                      lora_b)
77 | 
78 |     return bax * (self.alpha / self.rank)
79 | 
80 | 
81 | class LoRAAttention(nn.Module):
82 |   """LoRA implementation for Attention class
83 | 
84 |   Attributes:
85 |     rank: LoRA rank
86 |     alpha: LoRA alpha; the low-rank update is
87 |       scaled by alpha / rank
88 |     lora_init_A: LoRA A initializer
89 |     lora_init_B: LoRA B initializer
90 |     lora_axis_names_A: Sharding axis names for LoRA A
91 |     lora_axis_names_B: Sharding axis names for LoRA B
92 |     dtype: Activation dtype
93 |     output_dim: LoRA output dimensions
94 |     num_heads: Number of heads
95 |   """
96 |   rank: int = 2
97 |   lora_init_A: Initializer = nn.initializers.normal(stddev=2e-2)
98 |   lora_init_B: Initializer = nn.initializers.zeros
99 |   lora_axis_names_A: Sequence[str] = ('mlp', 'rank')
100 |   lora_axis_names_B: Sequence[str] = ('rank', 'mlp')
101 |   alpha = 16
102 |   num_heads: int = 1
103 |   output_dim: Optional[int] = None
104 |   dtype: DType = jnp.float32
105 | 
106 |   @nn.compact
107 |   def __call__(self, x: Array, **kwargs) -> Array:
108 | 
109 |     *rest, hidden = x.shape
110 | 
111 |     x = jax.lax.convert_element_type(x, self.dtype)
112 | 
113 |     #[hidden, rank]
114 |     lora_a = partitioning.param_with_axes(
115 |         'lora_A',
116 |         self.lora_init_A,
117 |         (hidden, self.rank),
118 |         jnp.float32,
119 |         axes=self.lora_axis_names_A)
120 | 
121 |     lora_a = jax.lax.convert_element_type(lora_a, self.dtype)
122 | 
123 |     #[batch, seq_len, rank]
124 |     ax = jnp.einsum('...d,dr->...r',
125 |                     x,
126 |                     lora_a)
127 | 
128 |     # Constrain sharding of the low-rank activations; rank stays unpartitioned.
129 |     ax = partitioning.with_sharding_constraint(ax, ('batch', 'length', 'unmodeled'))
130 |     ax = jax.lax.convert_element_type(ax, self.dtype)
131 | 
132 |     #[rank, hidden]
133 |     lora_b = partitioning.param_with_axes(
134 |         'lora_B',
135 |         self.lora_init_B,
136 |         (self.rank, (self.output_dim if self.output_dim else hidden)),
137 |         jnp.float32,
138 |         axes=self.lora_axis_names_B)
139 | 
140 |     lora_b = jax.lax.convert_element_type(lora_b, self.dtype)
141 | 
142 |     #[batch, seq_len, output_dim]
143 |     bax = jnp.einsum('...r,rd->...d',
144 |                      ax,
145 |                      lora_b)
146 | 
147 |     bax = bax * (self.alpha / self.rank)
148 | 
149 |     # Reshape to [batch, seq_len, num_heads, head_dim]
150 |     bax = jnp.reshape(bax, (*rest, self.num_heads, hidden // self.num_heads))
151 | 
152 |     return bax
--------------------------------------------------------------------------------
/src/molora.py:
--------------------------------------------------------------------------------
1 | """MoLoRa implementation"""
2 | 
3 | from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union
4 | 
5 | import jax
6 | import jax.numpy as jnp
7 | import flax.linen as nn
8 | from flax.linen import partitioning
9 | from flaxformer.types import Array, Initializer, DType
10 | 
11 | from src import routing
12 | 
13 | 
14 | class MoLoRa(nn.Module):
15 |   """MoLoRa implementation
16 | 
17 |   Attributes:
18 |     router: Router class
19 |     rank: LoRA rank
20 |     alpha: LoRA alpha (the update is scaled by alpha / rank)
21 |     lora_init_A: LoRA A initializer
22 |     lora_init_B: LoRA B initializer
23 |     lora_axis_names_A: Sharding axis names for LoRA A
24 |     lora_axis_names_B: Sharding axis names for LoRA B
25 |     num_experts: Number of experts
26 |     dtype: Activation dtype
27 |     output_dim: LoRA output dimensions
28 |   """
29 |   router: routing.Router
30 |   rank: int = 2
31 |   lora_init_A: Initializer = nn.initializers.normal(stddev=2e-2)
32 |   lora_init_B: Initializer = nn.initializers.zeros
33 |   lora_axis_names_A: Sequence[str] = ('unmodeled', 'mlp', 'unmodeled')
34 |   lora_axis_names_B: Sequence[str] = ('unmodeled', 'unmodeled', 'mlp')
35 |   alpha = 16
36 |   num_experts: int = 1
37 |   dtype: DType = jnp.float32
38 |   output_dim: Optional[int] = None
39 | 
40 |   @nn.compact
41 |   def __call__(self, x: Array, **kwargs) -> Array:
42 | 
43 |     *rest, hidden = x.shape
44 | 
45 |     #x = jax.lax.convert_element_type(x, self.dtype)
46 | 
47 |     #[num_experts, hidden, rank]
48 |     molora_a = partitioning.param_with_axes(
49 |         'lora_A',
50 |         self.lora_init_A,
51 |         (self.num_experts, hidden, self.rank),
52 |         jnp.float32,
53 |         axes=self.lora_axis_names_A)
54 | 
55 |     molora_a = jax.lax.convert_element_type(molora_a, self.dtype)
56 | 
57 |     #[batch, seq_len, num_experts, rank]
58 |     ax = jnp.einsum('bsd,edr->bser',
59 |                     x,
60 |                     molora_a)
61 | 
62 |     # Add expert axis name to the partitioning axes
63 |     ax = partitioning.with_sharding_constraint(ax, ('batch', 'length', 'expert', 'rank'))
64 |     ax = jax.lax.convert_element_type(ax, self.dtype)
65 | 
66 |     #[num_experts, rank, output_dim]
67 |     molora_b = partitioning.param_with_axes(
68 |         'lora_B',
69 |         self.lora_init_B,
70 |         (self.num_experts, self.rank, (self.output_dim if self.output_dim else hidden)),
71 |         jnp.float32,
72 |         axes=self.lora_axis_names_B)
73 | 
74 |     molora_b = jax.lax.convert_element_type(molora_b, self.dtype)
75 | 
76 |     #[batch, seq_len, num_experts, output_dim]
77 |     bax = jnp.einsum('bser,erd->bsed',
78 |                      ax,
79 |                      molora_b)
80 | 
81 |     bax = partitioning.with_sharding_constraint(bax, ('batch', 'length', 'expert') + tuple([self.lora_axis_names_B[-1]]))
82 |     bax = jax.lax.convert_element_type(bax, self.dtype)
83 | 
84 |     #[batch, seq_len, num_experts]
85 |     router_probs = self.router(x, self.num_experts)
86 |     router_probs = partitioning.with_sharding_constraint(router_probs,
87 |                                                          ('batch', 'length', 'expert'))
88 | 
89 |     #[batch, seq_len, hidden_dim]
90 |     bax = jnp.einsum('...e,...ed->...d',
91 |                      router_probs,
92 |                      bax)
93 | 
94 |     return bax * (self.alpha / self.rank)
95 | 
96 | 
97 | class MoLoRaAttention(nn.Module):
98 |   """MoLoRa implementation for Attention class
99 | 
100 |   Attributes:
101 |     router: Router class
102 |     rank: LoRA rank
103 |     alpha: LoRA alpha (the update is scaled by alpha / rank)
104 |     lora_init_A: LoRA A initializer
105 |     lora_init_B: LoRA B initializer
106 |     lora_axis_names_A: Sharding axis names for LoRA A
107 |     lora_axis_names_B: Sharding axis names for LoRA B
108 |     num_experts: Number of experts
109 |     dtype: Activation dtype
110 |     output_dim: LoRA output dimensions
111 |     num_heads: Number of heads
112 |   """
113 |   router: routing.Router
114 |   rank: int = 2
115 |   lora_init_A: Initializer = nn.initializers.normal(stddev=2e-2)
116 |   lora_init_B: Initializer = nn.initializers.zeros
117 |   lora_axis_names_A: Sequence[str] = ('unmodeled', 'embed', 'unmodeled')
118 |   lora_axis_names_B: Sequence[str] = ('unmodeled', 'unmodeled', 'joined_kv')
119 |   alpha = 16
120 |   num_experts: int = 1
121 |   num_heads: int = 1
122 |   dtype: DType = jnp.float32
123 |   output_dim: Optional[int] = None
124 | 
125 |   @nn.compact
126 |   def __call__(self, x: Array, **kwargs) -> Array:
127 | 
128 |     *rest, hidden = x.shape
129 | 
130 |     #x = jax.lax.convert_element_type(x, self.dtype)
131 | 
132 |     #[num_experts, hidden, rank]
133 |     molora_a = partitioning.param_with_axes(
134 |         'lora_A',
135 |         self.lora_init_A,
136 |         (self.num_experts, hidden, self.rank),
137 |         jnp.float32,
138 |         axes=self.lora_axis_names_A)
139 | 
140 |     molora_a = jax.lax.convert_element_type(molora_a, self.dtype)
141 | 
142 |     #[batch, seq_len, num_experts, rank]
143 |     ax = jnp.einsum('bsd,edr->bser',
144 |                     x,
145 |                     molora_a)
146 | 
147 |     # Add expert axis name to the partitioning axes
148 |     ax = partitioning.with_sharding_constraint(ax, ('batch', 'length', 'expert', 'rank'))
149 |     ax = jax.lax.convert_element_type(ax, self.dtype)
150 | 
151 |     #[num_experts, rank, output_dim]
152 |     molora_b = partitioning.param_with_axes(
153 |         'lora_B',
154 |         self.lora_init_B,
155 |         (self.num_experts, self.rank, (self.output_dim if self.output_dim else hidden)),
156 |         jnp.float32,
157 |         axes=self.lora_axis_names_B)
158 | 
159 |     molora_b = jax.lax.convert_element_type(molora_b, self.dtype)
160 | 
161 |     #[batch, seq_len, num_experts, output_dim]
162 |     bax = jnp.einsum('bser,erd->bsed',
163 |                      ax,
164 |                      molora_b)
165 | 
166 |     bax = partitioning.with_sharding_constraint(bax, ('batch', 'length', 'expert') + tuple([self.lora_axis_names_B[-1]]))
167 |     bax = jax.lax.convert_element_type(bax, self.dtype)
168 | 
169 |     #[batch, seq_len, num_experts]
170 |     router_probs = self.router(x, self.num_experts)
171 |     router_probs = partitioning.with_sharding_constraint(router_probs,
172 |                                                          ('batch', 'length', 'expert'))
173 | 
174 |     #[batch, seq_len, hidden_dim]
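    # The einsum below soft-merges the experts: each token's update is the
    # router-probability-weighted sum over all expert outputs,
    #   out[b, s, :] = sum_e router_probs[b, s, e] * bax[b, s, e, :],
    # a dense mixture rather than a hard top-k dispatch.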
175 |     bax = jnp.einsum('...e,...ed->...d',
176 |                      router_probs,
177 |                      bax)
178 | 
179 |     # LoRA scaling
180 |     bax = bax * (self.alpha / self.rank)
181 | 
182 |     # Reshape to [batch, seq_len, num_heads, head_dim]
183 |     bax = jnp.reshape(bax, (*rest, self.num_heads, hidden // self.num_heads))
184 | 
185 |     return bax
--------------------------------------------------------------------------------
/src/mov.py:
--------------------------------------------------------------------------------
1 | """MoV implementation"""
2 | 
3 | from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union
4 | 
5 | import jax
6 | import jax.numpy as jnp
7 | import flax.linen as nn
8 | from flax.linen import partitioning
9 | from flaxformer.types import Array, Initializer, DType
10 | from flaxformer.components import dense
11 | from src import routing
12 | 
13 | 
14 | class MoV(nn.Module):
15 |   """MoV implementation
16 | 
17 |   Attributes:
18 |     ia3_init: How to initialize the scaling variable.
19 |     axis_name: The logical names of the variable axes, used for partitioning.
20 |     dtype: The dtype of the activations for this module.
21 |     router: Router class
22 |     num_experts: Number of experts
23 |   """
24 |   router: routing.Router
25 |   ia3_init: Callable[[Array, Sequence[int]], Array] = nn.initializers.ones
26 |   axis_name: Tuple[str, str] = ('unmodeled', 'mlp')
27 |   num_experts: int = 1
28 |   dtype: DType = jnp.float32
29 | 
30 |   @nn.compact
31 |   def __call__(self, x, *args, **kwargs):
32 |     del args
33 |     del kwargs
34 |     *rest, hidden = x.shape
35 |     scaling = partitioning.param_with_axes(
36 |         'mov_scaling',
37 |         self.ia3_init,
38 |         (self.num_experts, hidden),
39 |         jnp.float32,
40 |         axes=self.axis_name
41 |     )
42 |     #[batch, seq_len, num_experts]
43 |     router_probs = self.router(x, self.num_experts)
44 |     router_probs = partitioning.with_sharding_constraint(router_probs,
45 |                                                          ('batch', 'length', 'unmodeled'))
46 | 
47 |     #[num_experts, hidden_dim]
48 |     scaling = jax.lax.convert_element_type(scaling, self.dtype)
49 | 
50 |     #[batch, seq_len, hidden_dim]
51 |     scaling = jnp.einsum('...e,...ed->...d',
52 |                          router_probs,
53 |                          scaling)
54 | 
55 |     #[batch, seq_len, hidden_dim]
56 |     #x = jax.lax.convert_element_type(x, self.dtype)
57 |     return x * scaling
58 | 
59 | 
60 | class MoVAttention(nn.Module):
61 |   """MoV implementation for the Attention class.
62 | 
63 |   Attributes:
64 |     ia3_init: How to initialize the scaling variable.
65 |     axis_names: The logical names of the variable axes, used for partitioning.
66 |     dtype: The dtype of the activations for this module.
67 |     router: Router class
68 |     num_experts: Number of experts
69 |   """
70 |   router: routing.Router
71 |   ia3_init: Callable[[Array, Sequence[int]], Array] = nn.initializers.ones
72 |   axis_names: Tuple[str, str, str] = ('heads', 'unmodeled', 'kv')
73 |   num_experts: int = 1
74 |   dtype: DType = jnp.float32
75 | 
76 |   @nn.compact
77 |   def __call__(self, x, *args, **kwargs):
78 |     del args
79 |     del kwargs
80 |     *rest, heads, kv = x.shape
81 |     scaling = partitioning.param_with_axes(
82 |         'mov_scaling',
83 |         self.ia3_init,
84 |         (heads, self.num_experts, kv),
85 |         jnp.float32,
86 |         axes=self.axis_names
87 |     )
88 | 
89 |     #[batch, seq_len, heads, num_experts]
90 |     router_probs = self.router(x, self.num_experts)
91 |     router_probs = partitioning.with_sharding_constraint(router_probs,
92 |                                                          ('batch', 'length', 'heads', 'unmodeled'))
93 | 
94 |     #[heads, num_experts, kv_hidden]
95 |     scaling = jax.lax.convert_element_type(scaling, self.dtype)
96 | 
97 |     #[batch, seq_len, heads, kv_hidden]
98 |     scaling = jnp.einsum('...e,...ed->...d',
99 |                          router_probs,
100 |                          scaling)
101 |     #x = jax.lax.convert_element_type(x, self.dtype)
102 |     return x * scaling
--------------------------------------------------------------------------------
/src/partitioning_custom.py:
--------------------------------------------------------------------------------
1 | """Custom partitioning rules."""
2 | 
3 | from t5x import partitioning
4 | 
5 | 
6 | def standard_logical_axis_rules() -> partitioning.LogicalAxisRules:
7 |   """Add specific partitioning rules."""
8 |   return (
9 |       ("unmodeled", None),
10 |       ("rank", None),
11 |       ("expert", None),
12 |   )
--------------------------------------------------------------------------------
/src/routing.py:
--------------------------------------------------------------------------------
1 | """Router implementation."""
2 | 
3 | from typing import Any, Iterable, Optional, Sequence, Tuple, Union
4 | 
5 | import jax
6 | import jax.numpy as jnp
7 | import flax.linen as nn
8 | from flax.linen import partitioning as flax_partitioning
9 | from flaxformer.types import Array, Initializer, DType
10 | from flaxformer.components import dense
11 | 
12 | RouterOutput = Any
13 | 
14 | default_kernel_init = nn.initializers.normal(stddev=2e-2)
15 | default_bias_init = nn.initializers.zeros
16 | 
17 | class RouterWeights(nn.Module):
18 |   """Router module converting token inputs to router logits.
19 | 
20 |   Attributes:
21 |     use_bias: Whether or not to use the bias term in computing the logits.
22 |     dtype: Numerical float type for router logit computation.
23 |     kernel_init: Initialization scheme for kernel.
24 |     bias_init: Initialization scheme for bias.
25 |     precision: XLA precision for array computations.
26 |     axis: Axes along which to apply the dense router weights transformation.
27 |       Defaults to final axis (typically the "hidden dimension").
28 |     kernel_axis_names: Logical axis names to use for kernel sharding.
29 |     reshape_kernel: Whether to reshape the kernel parameter to 2D for Adafactor.
30 |   """
31 |   use_bias: bool = True
32 |   dtype: DType = jnp.bfloat16
33 |   kernel_init: Initializer = default_kernel_init  # pytype: disable=annotation-type-mismatch  # jax-types
34 |   bias_init: Initializer = default_bias_init
35 |   precision: jax.lax.Precision = jax.lax.Precision.DEFAULT
36 |   axis: Union[Iterable[int], int] = -1
37 |   kernel_axis_names: Sequence[str] = ('mlp', 'unmodeled')
38 |   reshape_kernel: bool = True
39 | 
40 |   @nn.compact
41 |   def __call__(self, token_inputs: Array, num_experts: int) -> Array:
42 |     """Applies RouterWeights module.
43 | 44 | Args: 45 | token_inputs: Flattened batch of tokens with shape [num_groups, 46 | group_size, hidden_dim]. 47 | num_experts: Number of experts. 48 | 49 | Returns: 50 | Router logits with shape [num_groups, group_size, num_experts]. 51 | """ 52 | return dense.DenseGeneral( 53 | features=num_experts, 54 | axis=self.axis, 55 | use_bias=self.use_bias, 56 | dtype=self.dtype, 57 | kernel_init=self.kernel_init, 58 | bias_init=self.bias_init, 59 | precision=self.precision, 60 | kernel_axis_names=self.kernel_axis_names, 61 | reshape_kernel=self.reshape_kernel, 62 | name='v')( 63 | token_inputs) 64 | 65 | 66 | class Router(nn.Module): 67 | """Abstract base router class, defining router API and inner workings. 68 | 69 | Attributes: 70 | router_weights: Configurable module used to compute router logits from token 71 | inputs. 72 | jitter_noise: Amplitude of jitter noise applied to router logits. 73 | dtype: Numeric float type for returned combine array. All actual 74 | computations are performed in float32 of the input for stability. 75 | ignore_padding_tokens: Whether to ignore padding tokens during routing. Note 76 | that some routers (e.g. TokensChooseMaskedRouter) will completely ignore 77 | padding tokens, while others (e.g. TokensChooseScatterRouter and 78 | ExpertsChooseMaskedRouter) will simply down-weight the probability of 79 | selecting padding tokens. 80 | """ 81 | router_weights: RouterWeights 82 | jitter_noise: float 83 | dtype: jnp.dtype 84 | ignore_padding_tokens: bool = True 85 | input_axis_names: Sequence[str] = ('batch', 'length', 'mlp') 86 | top_k: int = None 87 | load_balancing_loss: bool = False 88 | 89 | def __call__(self, 90 | token_inputs: Array, 91 | num_experts: int, 92 | apply_jitter: bool = True) -> RouterOutput: 93 | """Computes dispatch and combine arrays for routing to experts. 94 | 95 | Args: 96 | token_inputs: [batch, seq_len, hidden_dim] inputs to 97 | send to experts. 98 | num_experts: Number of experts. 99 | apply_jitter: If true, apply jitter noise during routing. 100 | 101 | Returns: 102 | Router indices or mask arrays (depending on router type). 103 | """ 104 | token_inputs = flax_partitioning.with_sharding_constraint(token_inputs, 105 | self.input_axis_names) 106 | 107 | router_probs, router_logits = self._compute_router_probabilities( 108 | token_inputs, num_experts, apply_jitter) 109 | 110 | if self.ignore_padding_tokens: 111 | # To identify non-padding tokens, we rely on the fact that padding tokens 112 | # in the inputs have already been masked in the default T5 architecture. 113 | # See 114 | # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L315 115 | # and 116 | # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L603. 117 | padding_mask = jnp.array((jnp.sum(jnp.abs(token_inputs), axis=-1) > 0), 118 | dtype=token_inputs.dtype) 119 | router_logits *= jnp.expand_dims(padding_mask, axis=-1) 120 | else: 121 | padding_mask = None 122 | 123 | return router_probs 124 | 125 | def _compute_router_probabilities(self, token_inputs: Array, num_experts: int, 126 | apply_jitter: bool) -> Tuple[Array, Array]: 127 | """Computes router probabilities from input tokens. 128 | 129 | Args: 130 | token_inputs: [batch, seq_len, hidden_dim] from which 131 | router probabilities are computed. 132 | num_experts: Number of experts. 133 | apply_jitter: If true, apply jitter noise. 
176 | 
177 | 
178 | def _load_balancing_loss(router_probs: Array, expert_mask: Optional[Array] = None) -> float:
179 |   """Computes an auxiliary loss that encourages a balanced load across experts."""
180 |   num_experts = router_probs.shape[-1]
181 | 
182 |   router_prob_per_expert = jnp.mean(
183 |       router_probs, dtype=jnp.float32, axis=-2)
184 | 
185 |   if expert_mask is not None:
186 |     tokens_per_expert = jnp.mean(
187 |         expert_mask, dtype=jnp.float32, axis=-2)
188 |     return jnp.mean(
189 |         tokens_per_expert * router_prob_per_expert,
190 |         dtype=jnp.float32) * num_experts**2
191 |   else:
192 |     return jnp.mean(
193 |         router_prob_per_expert,
194 |         dtype=jnp.float32) * num_experts**2
195 | 
196 | 
197 | def _router_z_loss(router_logits: Array) -> float:
198 |   """Compute router z-loss.
199 | 
200 |   The router z-loss was introduced in Designing Effective Sparse Expert Models
201 |   (https://arxiv.org/abs/2202.08906). It encourages router logits to remain
202 |   small in an effort to improve stability.
203 | 
204 |   Args:
205 |     router_logits: [num_groups, tokens_per_group, num_experts] router logits.
206 | 
207 |   Returns:
208 |     Scalar router z-loss.
209 |   """
210 |   num_groups, tokens_per_group, _ = router_logits.shape
211 |   log_z = jax.nn.logsumexp(router_logits, axis=-1)
212 |   z_loss = log_z**2
213 |   return jnp.sum(z_loss, dtype=jnp.float32) / (num_groups * tokens_per_group)
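A quick numeric check of `_router_z_loss` with toy logits (not repo code): the loss is simply the mean squared logsumexp of the logits over all tokens.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[[0.5, -0.5], [2.0, 1.0]]])   # [batch=1, seq_len=2, experts=2]
log_z = jax.nn.logsumexp(logits, axis=-1)          # [1, 2]
z_loss = jnp.mean(log_z ** 2)                      # scalar; equals _router_z_loss(logits)
```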
214 | 
215 | 
216 | def _favor_one_hot_slices() -> bool:
217 |   """Returns true iff running on TPUs."""
218 |   return jax.default_backend() == 'tpu' or jax.devices()[0].platform == 'tpu'
219 | 
220 | 
221 | def _take_along_axis(array: Array, indices: Array, axis: int) -> Array:
222 |   """Takes values from the input array by matching 1D index and data slices.
223 | 
224 |   This function serves the same purpose as jax.numpy.take_along_axis, except
225 |   that it uses one-hot matrix multiplications under the hood on TPUs:
226 |   (1) On TPUs, we use one-hot matrix multiplications to select elements from
227 |       the array; this is particularly helpful for avoiding erroneous all-gather
228 |       ops when running under pjit.
229 |   (2) Otherwise, we fall back to jax.numpy.take_along_axis.
230 | 
231 |   Notes:
232 |     - To simplify matters in case (1), we only support slices along the second
233 |       or last dimensions.
234 |     - We may wish to revisit (1) for very large arrays.
235 | 
236 |   Args:
237 |     array: Source array.
238 |     indices: Indices to take along each 1D slice of array.
239 |     axis: Axis along which to take 1D slices.
240 | 
241 |   Returns:
242 |     The indexed result.
243 |   """
244 |   if array.ndim != indices.ndim:
245 |     raise ValueError(
246 |         'indices and array must have the same number of dimensions; '
247 |         f'{indices.ndim} vs. {array.ndim}.')
248 | 
249 |   if (axis != -1 and axis != array.ndim - 1 and  # Not last dimension
250 |       axis != 1 and axis != -array.ndim + 1):  # Not second dimension
251 |     raise ValueError(
252 |         'Only slices along the second or last dimension are supported; '
253 |         f'array.ndim = {array.ndim}, while axis = {axis}.')
254 | 
255 |   if _favor_one_hot_slices():
256 |     one_hot_length = array.shape[axis]
257 |     one_hot_indices = jax.nn.one_hot(indices, one_hot_length, axis=axis)
258 | 
259 |     if axis == -1 or array.ndim == 1:
260 |       # Take i elements from the last dimension (s).
261 |       # We must use HIGHEST precision to accurately reproduce indexing
262 |       # operations with matrix multiplications.
263 |       result = jnp.einsum(
264 |           '...s,...is->...i',
265 |           array,
266 |           one_hot_indices,
267 |           precision=jax.lax.Precision.HIGHEST)
268 |     else:
269 |       # Take i elements from the second dimension (s). We assume here that we
270 |       # always want to slice along the second dimension.
271 |       # We must use HIGHEST precision to accurately reproduce indexing
272 |       # operations with matrix multiplications.
273 |       result = jnp.einsum(
274 |           'ns...,nis...->ni...',
275 |           array,
276 |           one_hot_indices,
277 |           precision=jax.lax.Precision.HIGHEST)
278 |     return jax.lax.convert_element_type(result, array.dtype)
279 |   else:
280 |     return jnp.take_along_axis(array, indices, axis=axis)
281 | 
282 | 
283 | def _top_k_mask(array: Array, k: int) -> Tuple[Array, Array]:
284 |   """Returns a {0, 1} mask of the top-k entries along the last axis, and their indices."""
285 |   top_k_indices = jax.lax.top_k(array, k)[-1]
286 |   mask = jax.nn.one_hot(top_k_indices, array.shape[-1], dtype=jnp.float32)
287 |   mask = jnp.sum(mask, axis=-2)
288 |   return mask, top_k_indices
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Sequence, Callable, Any
3 | 
4 | def match_any(regexes: Sequence[str]) -> Callable[[str, Any], bool]:
5 |   """A traversal that checks if the parameter name matches any regex.
6 | 
7 |   This returns a closure over the actual traversal function that takes the
8 |   parameter name and value. The return value should be the input to the
9 |   Traversal used in the MultiOptimizer.
10 | 
11 |   Args:
12 |     regexes: A list of regular expressions that denote which parameters should
13 |       be updated by this optimizer.
14 | 
15 |   Returns:
16 |     A function that takes the name and value of a parameter and returns True
17 |     if that parameter should be updated by the optimizer.
18 |   """
19 |   regexes = tuple(re.compile(regex) for regex in regexes)
20 | 
21 |   def _match_any(path, _):
22 |     """True if path matches any regex in `regexes`, false otherwise."""
23 |     return any(regex.fullmatch(path) for regex in regexes)
24 | 
25 |   return _match_any
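A small usage sketch of `match_any` (the regexes and parameter paths are hypothetical; actual parameter names depend on the model config):

```python
from src.utils import match_any

# Update only LoRA parameters, freezing everything else.
should_update = match_any([r".*/lora_a.*", r".*/lora_b.*"])

print(should_update("encoder/layers_0/attention/query/lora_a", None))  # True
print(should_update("encoder/layers_0/attention/query/kernel", None))  # False
```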
--------------------------------------------------------------------------------
/t0_data/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 | 
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 |    1. Definitions.
8 | 
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 | 
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 | 
26 |       "Source" form shall mean the preferred form for making modifications,
27 |       including but not limited to software source code, documentation
28 |       source, and configuration files.
29 | 
30 |       "Object" form shall mean any form resulting from mechanical
31 |       transformation or translation of a Source form, including but
32 |       not limited to compiled object code, generated documentation,
33 |       and conversions to other media types.
34 | 
35 |       "Work" shall mean the work of authorship, whether in Source or
36 |       Object form, made available under the License, as indicated by a
37 |       copyright notice that is included in or attached to the work
38 |       (an example is provided in the Appendix below).
39 | 
40 |       "Derivative Works" shall mean any work, whether in Source or Object
41 |       form, that is based on (or derived from) the Work and for which the
42 |       editorial revisions, annotations, elaborations, or other modifications
43 |       represent, as a whole, an original work of authorship. For the purposes
44 |       of this License, Derivative Works shall not include works that remain
45 |       separable from, or merely link (or bind by name) to the interfaces of,
46 |       the Work and Derivative Works thereof.
47 | 
48 |       "Contribution" shall mean any work of authorship, including
49 |       the original version of the Work and any modifications or additions
50 |       to that Work or Derivative Works thereof, that is intentionally
51 |       submitted to Licensor for inclusion in the Work by the copyright owner
52 |       or by an individual or Legal Entity authorized to submit on behalf of
53 |       the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /t0_data/__init__.py: -------------------------------------------------------------------------------- 1 | """Tools for loading prompted tasks in seqio.""" 2 | 3 | from t0_data import tasks, utils 4 | -------------------------------------------------------------------------------- /t0_data/dataset_split.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cohere-Labs-Community/parameter-efficient-moe/e2683fead45eb480bcbf9c2e646fbbc2f7fc6efb/t0_data/dataset_split.pickle -------------------------------------------------------------------------------- /t0_data/datasets.csv: -------------------------------------------------------------------------------- 1 | HF_name,subset,task_by_convention,do_train,do_eval,train_size 2 | crows_pairs,,bias_and_fairness,,BIAS_FAIRNESS, 3 | jigsaw_toxicity_pred,,bias_and_fairness,,BIAS_FAIRNESS, 4 | super_glue,axg,bias_and_fairness,,BIAS_FAIRNESS, 5 | wino_bias,type1_anti,bias_and_fairness,,BIAS_FAIRNESS, 6 | wino_bias,type2_anti,bias_and_fairness,,BIAS_FAIRNESS, 7 | wino_bias,type1_pro,bias_and_fairness,,BIAS_FAIRNESS, 8 | wino_bias,type2_pro,bias_and_fairness,,BIAS_FAIRNESS, 9 | super_glue,wsc.fixed,coreference,SGLUE,BASE,554 10 | winogrande,winogrande_xl,coreference,,BASE,40398 11 | super_glue,cb,NLI,,BASE,250 12 | super_glue,rte,NLI,,BASE,2490 13 | anli,,NLI,,BASE,162865 14 | glue,mrpc,paraphrase,BASE,,3668 15 | glue,qqp,paraphrase,BASE,,363846 16 | paws,labeled_final,paraphrase,BASE,,49401 17 | ai2_arc,ARC-Challenge,QA_closed_book,GPT_EVAL,,1119 18 | ai2_arc,ARC-Easy,QA_closed_book,GPT_EVAL,,2251 19 | kilt_tasks,hotpotqa,QA_closed_book,BASE,,88869 20 | trivia_qa,unfiltered,QA_closed_book,GPT_EVAL,,87622 21 | web_questions,,QA_closed_book,GPT_EVAL,,3778 22 | wiki_qa,,QA_closed_book,BASE,,20360 23 | adversarial_qa,dbidaf,QA_extractive,BASE,,10000 24 | adversarial_qa,dbert,QA_extractive,BASE,,10000 25 | adversarial_qa,droberta,QA_extractive,BASE,,10000 26 | duorc,SelfRC,QA_extractive,BASE,,60721 27 | duorc,ParaphraseRC,QA_extractive,BASE,,69524 28 | ropes,,QA_extractive,BASE,,10924 29 | squad_v2,,QA_extractive,GPT_EVAL,,130319 30 | super_glue,record,QA_extractive,SGLUE,,100730 31 | quoref,,QA_extractive,BASE,,19399 32 | cos_e,v1.11,QA_multiple_choice,BASE,,9741 33 | cosmos_qa,,QA_multiple_choice,BASE,,25262 34 | dream,,QA_multiple_choice,BASE,,6116 35 | openbookqa,main,QA_multiple_choice,GPT_EVAL,,4957 36 | qasc,,QA_multiple_choice,BASE,,8134 37 | quail,,QA_multiple_choice,BASE,,10246 38 | quarel,,QA_multiple_choice,BASE,,1941 39 | quartz,,QA_multiple_choice,BASE,,2696 40 | race,high,QA_multiple_choice,GPT_EVAL,,62445 41 | race,middle,QA_multiple_choice,GPT_EVAL,,25421 42 | sciq,,QA_multiple_choice,BASE,,11679 43 | social_i_qa,,QA_multiple_choice,BASE,,33410 44 | super_glue,boolq,QA_multiple_choice,SGLUE,,9427 45 | super_glue,copa,QA_multiple_choice,SGLUE,BASE,400 46 | super_glue,multirc,QA_multiple_choice,SGLUE,,27243 47 | wiki_hop,original,QA_multiple_choice,BASE,,43738 48 | wiqa,,QA_multiple_choice,BASE,,29808 49 | piqa,,QA_multiple_choice,GPT_EVAL,,16113 50 | amazon_polarity,,sentiment,BASE,,3600000 51 | app_reviews,,sentiment,BASE,,288065 52 | imdb,,sentiment,BASE,,25000 53 | rotten_tomatoes,,sentiment,BASE,,8530 54 | yelp_review_full,,sentiment,BASE,,650000 55 | hellaswag,,story_completion,GPT_EVAL,BASE,39905 56 | common_gen,,structure_to_text,BASE,,67389 57 | 
wiki_bio,,structure_to_text,BASE,,582659 58 | cnn_dailymail,3.0.0,summarization,BASE,,287113 59 | gigaword,,summarization,BASE,,3803957 60 | multi_news,,summarization,BASE,,44972 61 | samsum,,summarization,BASE,,14732 62 | xsum,,summarization,BASE,,204045 63 | ag_news,,topic_classification,BASE,,120000 64 | dbpedia_14,,topic_classification,BASE,,560000 65 | trec,,topic_classification,BASE,,5452 66 | super_glue,wic,word_sense_disambiguation,SGLUE,BASE,5428 67 | -------------------------------------------------------------------------------- /t0_data/datasets_original.csv: -------------------------------------------------------------------------------- 1 | HF_name,subset,task_by_convention,do_train,do_eval,train_size 2 | crows_pairs,,bias_and_fairness,,BIAS_FAIRNESS, 3 | jigsaw_toxicity_pred,,bias_and_fairness,,BIAS_FAIRNESS, 4 | super_glue,axg,bias_and_fairness,,BIAS_FAIRNESS, 5 | wino_bias,type1_anti,bias_and_fairness,,BIAS_FAIRNESS, 6 | wino_bias,type2_anti,bias_and_fairness,,BIAS_FAIRNESS, 7 | wino_bias,type1_pro,bias_and_fairness,,BIAS_FAIRNESS, 8 | wino_bias,type2_pro,bias_and_fairness,,BIAS_FAIRNESS, 9 | super_glue,wsc.fixed,coreference,SGLUE,BASE,554 10 | winogrande,winogrande_xl,coreference,,BASE,40398 11 | super_glue,cb,NLI,,BASE,250 12 | super_glue,rte,NLI,,BASE,2490 13 | anli,,NLI,,BASE,162865 14 | glue,mrpc,paraphrase,BASE,,3668 15 | glue,qqp,paraphrase,BASE,,363846 16 | paws,labeled_final,paraphrase,BASE,,49401 17 | ai2_arc,ARC-Challenge,QA_closed_book,GPT_EVAL,,1119 18 | ai2_arc,ARC-Easy,QA_closed_book,GPT_EVAL,,2251 19 | kilt_tasks,hotpotqa,QA_closed_book,BASE,,88869 20 | trivia_qa,unfiltered,QA_closed_book,GPT_EVAL,,87622 21 | web_questions,,QA_closed_book,GPT_EVAL,,3778 22 | wiki_qa,,QA_closed_book,BASE,,20360 23 | adversarial_qa,dbidaf,QA_extractive,BASE,,10000 24 | adversarial_qa,dbert,QA_extractive,BASE,,10000 25 | adversarial_qa,droberta,QA_extractive,BASE,,10000 26 | duorc,SelfRC,QA_extractive,BASE,,60721 27 | duorc,ParaphraseRC,QA_extractive,BASE,,69524 28 | ropes,,QA_extractive,BASE,,10924 29 | squad_v2,,QA_extractive,GPT_EVAL,,130319 30 | super_glue,record,QA_extractive,SGLUE,,100730 31 | quoref,,QA_extractive,BASE,,19399 32 | cos_e,v1.11,QA_multiple_choice,BASE,,9741 33 | cosmos_qa,,QA_multiple_choice,BASE,,25262 34 | dream,,QA_multiple_choice,BASE,,6116 35 | openbookqa,main,QA_multiple_choice,GPT_EVAL,,4957 36 | qasc,,QA_multiple_choice,BASE,,8134 37 | quail,,QA_multiple_choice,BASE,,10246 38 | quarel,,QA_multiple_choice,BASE,,1941 39 | quartz,,QA_multiple_choice,BASE,,2696 40 | race,high,QA_multiple_choice,GPT_EVAL,,62445 41 | race,middle,QA_multiple_choice,GPT_EVAL,,25421 42 | sciq,,QA_multiple_choice,BASE,,11679 43 | social_i_qa,,QA_multiple_choice,BASE,,33410 44 | super_glue,boolq,QA_multiple_choice,SGLUE,,9427 45 | super_glue,copa,QA_multiple_choice,SGLUE,BASE,400 46 | super_glue,multirc,QA_multiple_choice,SGLUE,,27243 47 | wiki_hop,original,QA_multiple_choice,BASE,,43738 48 | wiqa,,QA_multiple_choice,BASE,,29808 49 | piqa,,QA_multiple_choice,GPT_EVAL,,16113 50 | amazon_polarity,,sentiment,BASE,,3600000 51 | app_reviews,,sentiment,BASE,,288065 52 | imdb,,sentiment,BASE,,25000 53 | rotten_tomatoes,,sentiment,BASE,,8530 54 | yelp_review_full,,sentiment,BASE,,650000 55 | story_cloze,2016,story_completion,,BASE, 56 | hellaswag,,story_completion,GPT_EVAL,BASE,39905 57 | common_gen,,structure_to_text,BASE,,67389 58 | wiki_bio,,structure_to_text,BASE,,582659 59 | cnn_dailymail,3.0.0,summarization,BASE,,287113 60 | gigaword,,summarization,BASE,,3803957 61 | 
multi_news,,summarization,BASE,,44972 62 | samsum,,summarization,BASE,,14732 63 | xsum,,summarization,BASE,,204045 64 | ag_news,,topic_classification,BASE,,120000 65 | dbpedia_14,,topic_classification,BASE,,560000 66 | trec,,topic_classification,BASE,,5452 67 | super_glue,wic,word_sense_disambiguation,SGLUE,BASE,5428 68 | -------------------------------------------------------------------------------- /t0_data/tasks.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is taken from https://github.com/bigscience-workshop/t-zero/tree/master/t0 with small modifications. 3 | 4 | This file defines 8 mixtures that we used in the T-Zero paper: 5 | - t0_train: T0 training mixture 6 | - t0+_train: T0+ training mixture 7 | - t0++_train: T0++ training mixture 8 | - t0_eval_score_eval: T0 main evaluation mixture (Figure 4 for instance) 9 | - t0_train_score_eval: Evaluation mixture for checkpoint selection on T0 (validation splits of the training sets) 10 | - t0_train_one_og_prompt: T0 (p=1) training mixture for - one original-task prompt per dataset. Figure 6 11 | - t0_train_all_og_prompts: T0 (p=5.7) training mixture for - all original-task prompts for all datasets. Figure 6 12 | - bias_fairness_eval_score_eval: Bias & fairness evaluation mixture. Appendix B3 13 | """ 14 | 15 | import csv 16 | import functools 17 | import pdb 18 | import pickle 19 | from typing import Dict, List, Optional, Tuple 20 | 21 | import datasets 22 | import pkg_resources 23 | import seqio 24 | import t5 25 | import tensorflow as tf 26 | from promptsource import templates 27 | from t5.data.glue_utils import get_glue_metric, get_super_glue_metric 28 | from t5.evaluation import metrics as mt 29 | 30 | from t0_data import utils 31 | 32 | GET_METRICS = { 33 | "BLEU": mt.bleu, 34 | "ROUGE": mt.rouge, 35 | "Span Squad": mt.span_squad, 36 | "Squad": mt.squad, 37 | "Trivia QA": mt.trivia_qa, 38 | "Accuracy": mt.accuracy, 39 | "Sequence Accuracy": mt.sequence_accuracy, 40 | "Pearson Correlation": mt.pearson_corrcoef, 41 | "Spearman Correlation": mt.spearman_corrcoef, 42 | "MultiRC": mt.multirc_f1_over_all_answers, 43 | "AUC": mt.auc, 44 | "COQA F1": mt.coqa_f1, 45 | "Edit Distance": mt.edit_distance, 46 | # "Mean Reciprocal Rank": mt.accuracy, # NOTE not in T5? 47 | "Other": mt.accuracy, 48 | # Missing support for mean_multiclass_f1 etc. 
which need a num_classes parameter
49 | }
50 | 
51 | MAX_EXAMPLES_PER_DATASET = 500_000
52 | 
53 | 
54 | def strip_whitespace(output_or_target, example=None, is_target=False):
55 |   """Cached tasks from promptsource all have a leading space on the ground-truth targets."""
56 |   return output_or_target.strip()
57 | 
58 | 
59 | def maybe_get_class_id_postprocessor(template):
60 |   if template.get_fixed_answer_choices_list():
61 | 
62 |     def postprocess_fn(output_or_target, example=None, is_target=False):
63 |       output_or_target = strip_whitespace(output_or_target)
64 |       return t5.data.postprocessors.string_label_to_class_id(
65 |           output_or_target, label_classes=template.get_fixed_answer_choices_list())
66 | 
67 |     return postprocess_fn
68 | 
69 |   else:
70 |     return strip_whitespace
71 | 
72 | 
73 | def get_tf_dataset(split, shuffle_files, seed, dataset_name, subset_name, template, split_mapping):
74 |   # HF datasets does not support file-level shuffling.
75 |   del shuffle_files, seed
76 |   dataset = datasets.load_dataset(dataset_name, subset_name)
77 |   dataset = dataset[split_mapping[split]]
78 |   dataset = utils.apply_template(dataset, template)
79 |   return utils.hf_dataset_to_tf_dataset(dataset)
80 | 
81 | 
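For example (the dataset, split, and template choice below are illustrative; any cached promptsource template works the same way), `get_tf_dataset` yields prompted `inputs`/`targets` pairs as a `tf.data.Dataset`:

```python
from promptsource import templates

rte_templates = templates.TemplateCollection().get_dataset("super_glue", "rte")
template = rte_templates[rte_templates.all_template_names[0]]

ds = get_tf_dataset(
    split="validation",
    shuffle_files=False,  # ignored; HF datasets have no file-level shuffling
    seed=None,            # ignored for the same reason
    dataset_name="super_glue",
    subset_name="rte",
    template=template,
    split_mapping={"validation": "validation"},
)
for ex in ds.take(1):
    print(ex["inputs"].numpy(), ex["targets"].numpy())
```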
82 | def add_task(dataset_name, subset_name, template_name, task_name=None, split_mapping=None, special_split=None):
83 |   template = all_templates.get_dataset(dataset_name, subset_name)[template_name]
84 |   task_name = task_name or utils.get_task_name(dataset_name, subset_name, template_name)
85 | 
86 |   if dataset_name == "glue":
87 |     metrics = get_glue_metric(subset_name)
88 |   elif dataset_name == "super_glue":
89 |     if subset_name in ("wsc.fixed", "multirc"):
90 |       # TODO: WSC and MultiRC need special pre/postprocessing
91 |       metrics = [mt.accuracy]
92 |     else:
93 |       metrics = get_super_glue_metric(subset_name)
94 |   else:
95 |     # TODO: what if the metric is null?
96 |     metrics = [GET_METRICS[m] for m in template.metadata.metrics]
97 | 
98 |   tmp_name = dataset_name + "_" + str(subset_name)
99 |   with open('t0_data/dataset_split.pickle', 'rb') as handle:
100 |     DATASET_INFO = pickle.load(handle)
101 | 
102 |   if dataset_name == "anli":
103 |     # ANLI split info was pickled per round, keyed by e.g. "anli_r1".
104 |     dataset_splits = DATASET_INFO[f"{dataset_name}_{special_split}"]
105 |   else:
106 |     dataset_splits = DATASET_INFO[tmp_name]
107 | 
108 |   split_mapping = split_mapping or {k: k for k in dataset_splits.keys()}
109 | 
110 |   dataset_fn = functools.partial(
111 |       get_tf_dataset,
112 |       seed=None,
113 |       dataset_name=dataset_name,
114 |       subset_name=subset_name,
115 |       template=template,
116 |       split_mapping=split_mapping,
117 |   )
118 |   data_source = seqio.FunctionDataSource(
119 |       dataset_fn,
120 |       splits=list(split_mapping.keys()),
121 |       num_input_examples={s: dataset_splits[split_mapping[s]].num_examples for s in split_mapping.keys()},
122 |   )
123 |   output_features = {
124 |       "inputs": seqio.Feature(t5.data.get_default_vocabulary(), add_eos=False, dtype=tf.int32),
125 |       "targets": seqio.Feature(t5.data.get_default_vocabulary(), add_eos=True, dtype=tf.int32),
126 |   }
127 | 
128 |   preprocessors = [
129 |       seqio.preprocessors.tokenize,
130 |       seqio.preprocessors.append_eos,
131 |       seqio.CacheDatasetPlaceholder(required=True),
132 |   ]
133 | 
134 |   # Add train and normal eval tasks
135 |   seqio.TaskRegistry.add(
136 |       task_name,
137 |       data_source,
138 |       preprocessors=preprocessors,
139 |       output_features=output_features,
140 |       metric_fns=metrics,
141 |       postprocess_fn=maybe_get_class_id_postprocessor(template),
142 |       # shuffle_buffer_size=50000,
143 |   )
144 | 
145 |   # Add rank classification eval task
146 |   if template.answer_choices:
147 |     rank_classification_preprocessor = functools.partial(
148 |         t5.data.preprocessors.rank_classification,
149 |         inputs_fn=lambda ex: tf.fill((len(ex["answer_choices"]),), ex["inputs"]),
150 |         targets_fn=lambda ex: ex["answer_choices"],
151 |         is_correct_fn=lambda ex: tf.equal(ex["answer_choices"], tf.strings.strip(ex["targets"])),
152 |         weight_fn=lambda ex: 1.0,
153 |     )
154 |     fixed_choices = template.get_fixed_answer_choices_list()
155 |     num_classes = len(fixed_choices) if fixed_choices else None
156 |     seqio.TaskRegistry.add(
157 |         task_name + "_score_eval",
158 |         data_source,
159 |         preprocessors=[rank_classification_preprocessor] + preprocessors,
160 |         output_features=output_features,
161 |         metric_fns=[functools.partial(t5.evaluation.metrics.rank_classification, num_classes=num_classes)],
162 |         postprocess_fn=t5.data.postprocessors.rank_classification,
163 |     )
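A hypothetical call, to make the registration flow concrete (the dataset/subset are examples; the template name comes from the promptsource collection):

```python
# Registers the prompted training/eval task plus, because the template has
# answer choices, its "_score_eval" rank-classification twin, e.g.
#   "super_glue_rte_<template>" and "super_glue_rte_<template>_score_eval".
add_task(dataset_name="super_glue", subset_name="rte", template_name="must be true")
```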
164 | 
165 | 
166 | dataset_subset_tuple = Tuple[str, Optional[str]]
167 | t0_eval: Dict[str, List[dataset_subset_tuple]] = {"BASE": [], "BIAS_FAIRNESS": []}
168 | t0_train: Dict[str, List[dataset_subset_tuple]] = {
169 |     "BASE": [],
170 |     # GPT3 evaluation set
171 |     "GPT_EVAL": [],
172 |     # SuperGLUE (except RTE and CB)
173 |     "SGLUE": []
174 | }
175 | 
176 | gsheet: Dict[dataset_subset_tuple, Dict] = {}
177 | experiment_path = pkg_resources.resource_filename(__name__, "datasets.csv")
178 | with open(experiment_path) as exp_file:
179 |   reader = csv.DictReader(exp_file)
180 |   for row in reader:
181 |     if row["subset"] == "":
182 |       row["subset"] = None  # to match promptsource.Template object
183 |     dataset_subset = (row["HF_name"], row["subset"])
184 |     if row["do_train"] != "":
185 |       do_train_source = row["do_train"]
186 |       # sanity checks
187 |       if do_train_source == "SGLUE":
188 |         assert dataset_subset[0] == "super_glue"
189 |       t0_train[do_train_source].append(dataset_subset)
190 |     if row["do_eval"] != "":
191 |       do_eval_source = row["do_eval"]
192 |       # sanity checks
193 |       if do_eval_source == "BIAS_FAIRNESS":
194 |         assert row["task_by_convention"] == "bias_and_fairness"
195 |       t0_eval[do_eval_source].append(dataset_subset)
196 |     gsheet[dataset_subset] = row
197 | 
198 | all_datasets = sum(t0_train.values(), []) + sum(t0_eval.values(), [])
199 | 
200 | all_templates = templates.TemplateCollection()
201 | all_templates.remove("anli")  # Need to special-case ANLI due to weird split conventions.
202 | 
203 | # 3 stages of training/ablation: D4 -> GPT -> SuperGLUE
204 | t0_train_mixture: Dict[str, List[str]] = {key: [] for key in t0_train}
205 | t0_eval_mixture: Dict[str, List[str]] = {key: [] for key in t0_eval}
206 | mixture_cap: Dict[str, int] = {}
207 | single_original_task: Dict[Tuple[str, str], str] = {}
208 | all_original_tasks: List[str] = []
209 | for dataset_name, subset_name in all_templates.keys:
210 |   if (dataset_name, subset_name) not in all_datasets:
211 |     all_templates.remove(dataset_name, subset_name)
212 |     continue
213 | 
214 |   dataset = all_templates.get_dataset(dataset_name, subset_name)
215 |   num_templates = len(dataset.all_template_names)
216 |   train_size = gsheet[(dataset_name, subset_name)]["train_size"]
217 |   if train_size == "":
218 |     train_size = 0
219 |   else:
220 |     train_size = int(train_size)
221 |   if train_size > MAX_EXAMPLES_PER_DATASET:
222 |     cap = MAX_EXAMPLES_PER_DATASET // num_templates
223 |   else:
224 |     cap = train_size
225 | 
226 |   for template_name in dataset.all_template_names:
227 |     add_task(dataset_name, subset_name, template_name)
228 | 
229 | 
230 | 
231 |     template = dataset[template_name]
232 | 
233 |     task_name = utils.get_task_name(dataset_name, subset_name, template_name)
234 | 
235 |     if (dataset_name, subset_name) not in single_original_task and template.metadata.original_task:
236 |       single_original_task[(dataset_name, subset_name)] = task_name
237 | 
238 |     if template.metadata.original_task:
239 |       all_original_tasks.append(task_name)
240 | 
241 |     # Check that the dataset_subset_tuple is in t0_train
242 |     for key, dataset_subset_tuples in t0_train.items():
243 |       if (dataset_name, subset_name) in dataset_subset_tuples:
244 |         t0_train_mixture[key].append(task_name)
245 |         mixture_cap[task_name] = cap
246 | 
247 |     # Check that the dataset_subset_tuple is in t0_eval
248 |     if (dataset_name, subset_name) in t0_eval["BASE"]:
249 |       if template.metadata.original_task:
250 |         t0_eval_mixture["BASE"].append(task_name)
251 |       # TODO use template.metadata.answer_choices here for rank eval
252 |     if (dataset_name, subset_name) in t0_eval["BIAS_FAIRNESS"]:
253 |       t0_eval_mixture["BIAS_FAIRNESS"].append(task_name)
254 | 
255 | # Special case for ANLI, which has weirdly-named splits and rounds that should be subsets
256 | dataset_name, subset_name = ("anli", None)
257 | dataset = all_templates.get_dataset(dataset_name, subset_name)
258 | for anli_round in ("r1", "r2", "r3"):
259 |   for template_name in all_templates.get_dataset(dataset_name, subset_name).all_template_names:
260 |     task_name = utils.get_task_name(dataset_name, subset_name, template_name) + f"_{anli_round}"
261 |     # NOTE: a round-specific mapping ({"train": f"train_{anli_round}",
262 |     #   "validation": f"dev_{anli_round}", "test": f"test_{anli_round}"})
263 |     #   is superseded by the identity mapping below; add_task() resolves the
264 |     #   per-round split info through its `special_split` argument instead.
265 |     split_mapping = {
266 |         "train": "train",
267 |         "validation": "validation",
268 |         "test": "test",
269 |     }
270 | 
271 | 
272 | 
273 |     add_task(dataset_name,
subset_name, template_name, task_name, split_mapping, anli_round) 274 | template = dataset[template_name] 275 | if template.metadata.original_task: 276 | t0_eval_mixture["BASE"].append(task_name) # TODO or add to ANLI special mixture 277 | # TODO use template.metadata.answer_choices here for rank eval 278 | 279 | TASK_BLACKLIST = [ 280 | # Tasks which often tokenize to > 1024 tokens currently 281 | "hotpot_qa_distractor_Generate_Explanations", 282 | "hotpot_qa_fullwiki_Generate_Explanations", 283 | "hotpot_qa_distractor_Generate_Answer_and_Explanations", 284 | "hotpot_qa_fullwiki_Generate_Answer_and_Explanations", 285 | "hotpot_qa_fullwiki_Generate_Answer", 286 | "hotpot_qa_distractor_Generate_Answer", 287 | "hotpot_qa_distractor_Generate_Title_2", 288 | "hotpot_qa_fullwiki_Generate_Title_2", 289 | "hotpot_qa_fullwiki_Generate_Title_1", 290 | "hotpot_qa_distractor_Generate_Title_1", 291 | "hotpot_qa_distractor_Generate_Question", 292 | "hotpot_qa_fullwiki_Generate_Question", 293 | "tab_fact_tab_fact_tab_fact_3", 294 | "tab_fact_tab_fact_tab_fact_2", 295 | "tab_fact_tab_fact_tab_fact_1", 296 | "tab_fact_tab_fact_tab_fact_7", 297 | "tab_fact_tab_fact_tab_fact_4", 298 | "tab_fact_tab_fact_tab_fact_5", 299 | "tab_fact_tab_fact_tab_fact_6", 300 | "wiki_hop_masked_Choose_Best_Object_Candidate", 301 | "wiki_hop_masked_Indirect_Question_about_Birthplace_Citizenship_Place_of_Death", 302 | "narrativeqa_Template_05", 303 | "ecthr_cases_alleged_violation_prediction_silver_rationales", 304 | # Tasks with broken cached files 305 | "gigaword_summarize_", 306 | ] 307 | 308 | # Tasks that failed caching (won't try to fix them for now) - remove when we are done 309 | D4_TRAIN_SCORE_EVAL_TASK_BLACKLIST = [ 310 | "amazon_polarity_Is_this_product_review_positive_score_eval", 311 | "amazon_polarity_Is_this_review_negative_score_eval", 312 | "amazon_polarity_Is_this_review_score_eval", 313 | "amazon_polarity_User_recommend_this_product_score_eval", 314 | "amazon_polarity_convey_negative_or_positive_sentiment_score_eval", 315 | "amazon_polarity_flattering_or_not_score_eval", 316 | "amazon_polarity_negative_or_positive_tone_score_eval", 317 | "amazon_polarity_user_satisfied_score_eval", 318 | "amazon_polarity_would_you_buy_score_eval", 319 | "dbpedia_14_given_a_choice_of_categories__score_eval", 320 | "dbpedia_14_given_list_what_category_does_the_paragraph_belong_to_score_eval", 321 | "dbpedia_14_pick_one_category_for_the_following_text_score_eval", 322 | "wiki_hop_original_choose_best_object_affirmative_1_score_eval", 323 | "wiki_hop_original_choose_best_object_affirmative_2_score_eval", 324 | "wiki_hop_original_choose_best_object_affirmative_3_score_eval", 325 | "wiki_hop_original_choose_best_object_interrogative_1_score_eval", 326 | "wiki_hop_original_choose_best_object_interrogative_2_score_eval", 327 | ] 328 | 329 | #Per Dataset 330 | DATASETS = [ 331 | 'glue_mrpc', 'glue_qqp', 'paws_labeled_final', 'kilt_tasks', 'adversarial_qa', 'duorc', 'ropes', 'quoref', 'cos_e', 332 | 'cosmos_qa', 'dream', 'qasc', 'quail', 'quarel', 'quartz', 'sciq', 'social_i_qa', 'wiki_hop', 'wiqa', 333 | 'amazon_polarity', 'app_reviews', 'imdb', 'rotten_tomatoes', 'yelp_review_full', 'common_gen', 'wiki_bio', 334 | 'cnn_dailymail', 'gigaword', 'multi_news', 'samsum', 'xsum', 'ag_news', 'dbpedia_14', 'trec' 335 | ] 336 | 337 | for dataset in DATASETS: 338 | seqio.MixtureRegistry.add( 339 | dataset, 340 | [task for task in t0_train_mixture["BASE"] if task not in TASK_BLACKLIST and dataset in task], 341 | default_rate=lambda t: 
mixture_cap[t.name], 342 | ) 343 | 344 | ### All Tasks 345 | 346 | seqio.MixtureRegistry.add( 347 | "t0_train", 348 | [task for task in t0_train_mixture["BASE"] if task not in TASK_BLACKLIST], 349 | default_rate=lambda t: mixture_cap[t.name], 350 | ) 351 | 352 | seqio.MixtureRegistry.add( 353 | "t0+_train", 354 | [task for task in t0_train_mixture["BASE"] + t0_train_mixture["GPT_EVAL"] if task not in TASK_BLACKLIST], 355 | default_rate=lambda t: mixture_cap[t.name], 356 | ) 357 | seqio.MixtureRegistry.add( 358 | "t0++_train", 359 | [ 360 | task for task in t0_train_mixture["BASE"] + t0_train_mixture["GPT_EVAL"] + t0_train_mixture["SGLUE"] 361 | if task not in TASK_BLACKLIST 362 | ], 363 | default_rate=lambda t: mixture_cap[t.name], 364 | ) 365 | 366 | seqio.MixtureRegistry.add( 367 | "t0_eval_score_eval", 368 | [ 369 | task for task in seqio.TaskRegistry.names() if task.endswith("_score_eval") and 370 | task.split("_score_eval")[0] in t0_eval_mixture["BASE"] and task.split("_score_eval")[0] not in TASK_BLACKLIST 371 | ], 372 | default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000), 373 | ) 374 | 375 | # Train tasks we don't care about evaluating on 376 | D4_TRAIN_SKIP_EVAL = [ 377 | "paws_labeled_final", 378 | "adversarial_qa_dbidaf", 379 | "adversarial_qa_dbert", 380 | "duorc_ParaphraseRC", 381 | "dream", 382 | "amazon_polarity", 383 | "app_reviews", 384 | "imdb", 385 | "wiki_bio", 386 | "gigaword", 387 | "multi_news", 388 | "samsum", 389 | "dbpedia_14", 390 | "trec", 391 | ] 392 | seqio.MixtureRegistry.add( 393 | "t0_train_score_eval", 394 | [ 395 | task for task in seqio.TaskRegistry.names() 396 | if task.endswith("_score_eval") and task.split("_score_eval")[0] in t0_train_mixture["BASE"] and 397 | task.split("_score_eval")[0] not in TASK_BLACKLIST and task not in D4_TRAIN_SCORE_EVAL_TASK_BLACKLIST and 398 | not any([skip in task for skip in D4_TRAIN_SKIP_EVAL]) and task.split("_score_eval")[0] in all_original_tasks 399 | ], 400 | default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000), 401 | ) 402 | seqio.MixtureRegistry.add( 403 | "t0_train_one_og_prompt", 404 | [task for task in single_original_task.values() if task in t0_train_mixture["BASE"] and task not in TASK_BLACKLIST], 405 | default_rate=lambda t: mixture_cap[t.name], 406 | ) 407 | seqio.MixtureRegistry.add( 408 | "t0_train_all_og_prompts", 409 | [task for task in all_original_tasks if task in t0_train_mixture["BASE"] and task not in TASK_BLACKLIST], 410 | default_rate=lambda t: mixture_cap[t.name], 411 | ) 412 | seqio.MixtureRegistry.add( 413 | "bias_fairness_eval_score_eval", 414 | [ 415 | task for task in seqio.TaskRegistry.names() 416 | if task.endswith("_score_eval") and task.split("_score_eval")[0] in t0_eval_mixture["BIAS_FAIRNESS"] 417 | ], 418 | default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000), 419 | ) 420 | 421 | -------------------------------------------------------------------------------- /t0_data/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import datasets 4 | import tensorflow as tf 5 | 6 | import promptsource.utils 7 | 8 | 9 | def feature_to_spec(feature, length=False): 10 | if isinstance(feature, datasets.ClassLabel): 11 | return tf.TensorSpec(shape=() if not length else (None if length == -1 else length,), dtype=tf.int64) 12 | elif isinstance(feature, datasets.Value): 13 | return tf.TensorSpec( 14 | shape=() if not length else (None if length == -1 else length,), 
dtype=getattr(tf.dtypes, feature.dtype)
15 |     )
16 |   elif hasattr(feature, "dtype") and hasattr(feature, "shape"):
17 |     return tf.TensorSpec(shape=feature.shape, dtype=feature.dtype)
18 |   elif isinstance(feature, datasets.Sequence):
19 |     return feature_to_spec(feature.feature, length=feature.length)
20 |   elif isinstance(feature, list):
21 |     return [feature_to_spec(f, length=length) for f in feature]
22 |   elif isinstance(feature, dict):
23 |     return {k: feature_to_spec(v, length=length) for k, v in feature.items()}
24 |   else:
25 |     raise ValueError(f"Unparseable feature type {type(feature)}")
26 | 
27 | 
28 | def hf_dataset_to_tf_dataset(dataset):
29 |   return tf.data.Dataset.from_generator(
30 |       dataset.__iter__, output_signature={k: feature_to_spec(v) for k, v in dataset.features.items()}
31 |   )
32 | 
33 | 
34 | def apply_template(dataset, template):
35 |   def map_fn(ex):
36 |     ex = promptsource.utils.removeHyphen(ex)
37 |     inputs_and_targets = template.apply(ex)
38 |     answer_choices = template.get_answer_choices_list(ex)
39 |     if len(inputs_and_targets) == 2:
40 |       inputs, targets = inputs_and_targets
41 |       # NOTE: `targets` may be the empty string here; such examples are
42 |       #   filtered out below, so no special-casing is needed.
43 |       ex = {"inputs": inputs, "targets": targets}
44 | 
45 |     # When a template results in an empty example, template.apply returns [""].
46 |     # Also, if the template gets split wrong, len can be > 2.
47 |     # We will filter these out later.
48 |     else:
49 |       ex = {"inputs": "", "targets": ""}
50 | 
51 |     if answer_choices:
52 |       ex["answer_choices"] = answer_choices
53 | 
54 |     return ex
55 | 
56 |   def filter_fn(ex):
57 |     return len(ex["inputs"]) > 0 and len(ex["targets"]) > 0
58 | 
59 |   original_columns = dataset.column_names
60 |   dataset = dataset.map(map_fn).filter(filter_fn)
61 |   # map() keeps the original columns; remove them.
62 |   return dataset.remove_columns(set(original_columns) - {"inputs", "targets", "answer_choices"})
63 | 
64 | 
65 | def get_dataset_splits(dataset_name, subset_name=None):
66 |   # NOTE: datasets.get_dataset_infos fetches dataset metadata from the Hub,
67 |   #   which can be slow on the first call.
68 |   info = datasets.get_dataset_infos(dataset_name)
69 |   subset_name = subset_name or list(info.keys())[0]
70 |   return info[subset_name].splits
71 | 
72 | 
73 | 
74 | def task_clean(text):
75 |   # Clean the text according to allowed characters for a task name.
76 |   return re.sub(r"[^\w\d\._]+", "_", text)
77 | 
78 | 
79 | def get_task_name(dataset_name, subset_name, template_name):
80 |   return task_clean(dataset_name + (f"_{subset_name}_" if subset_name is not None else "_") + template_name)
81 | 
--------------------------------------------------------------------------------
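Taken together, `apply_template` and `hf_dataset_to_tf_dataset` are the bridge from a Hugging Face dataset into the tf.data world that SeqIO caches. A hedged end-to-end sketch (the dataset is illustrative; the first available template is picked for simplicity), mirroring what `get_tf_dataset` in `t0_data/tasks.py` does:

```python
import datasets
from promptsource import templates

from t0_data.utils import apply_template, hf_dataset_to_tf_dataset

imdb_templates = templates.TemplateCollection().get_dataset("imdb", None)
template = imdb_templates[imdb_templates.all_template_names[0]]

hf_ds = datasets.load_dataset("imdb")["train"]
hf_ds = apply_template(hf_ds, template)   # adds "inputs"/"targets" (+ "answer_choices")
tf_ds = hf_dataset_to_tf_dataset(hf_ds)   # tf.data.Dataset ready for SeqIO caching
print(next(iter(tf_ds)))
```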