├── .flake8 ├── .github ├── dependabot.yml └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE.md ├── README.md ├── dev-requirements.txt ├── figures └── hint.png ├── gins ├── finetune_from_scratch.gin ├── full_restore.gin ├── hyper_base.gin ├── hyper_large.gin ├── hyper_small.gin ├── hyper_tiny.gin ├── hyper_xl.gin ├── hyper_xxl.gin ├── hypertune.gin ├── hypertune_decoder_us.gin ├── hypertune_embed.gin ├── hypertune_full_train.gin ├── hypter.gin ├── instruction_embed.gin ├── lora │ ├── hyper_lora_base.gin │ ├── hyper_lora_small.gin │ ├── hyper_lora_xl.gin │ ├── lora.gin │ ├── lora_base.gin │ ├── lora_small.gin │ ├── lora_xl.gin │ └── plain │ │ ├── lora_base.gin │ │ ├── lora_small.gin │ │ └── lora_xl.gin ├── ni_eval.gin ├── ni_train.gin ├── ni_train_mixed.gin ├── partial_train_adafactor.gin ├── partial_train_adafactor_dual.gin ├── partial_train_adafactor_dual_frozen_under.gin ├── partial_train_adafactor_no_roberta.gin ├── partial_train_adam.gin ├── pretrain.gin ├── pretrain_4part.gin ├── pretrain_6part.gin ├── restore_frozen_under.gin ├── restore_pretrained.gin ├── separate_henc.gin ├── t0.gin ├── t0_eval.gin ├── t0_train.gin ├── t0_train_local.gin ├── t0_train_local_copy.gin └── train_only_hnet.gin ├── hyper_task_descriptions ├── __init__.py ├── c4 │ ├── __init__.py │ ├── c4_registry.py │ ├── c4_registry_4part.py │ └── c4_registry_6part.py ├── common │ ├── __init__.py │ └── testing.py ├── hf_vocab.py ├── learning_rate_adafactor.py ├── modeling │ ├── hyper_interactive_model.py │ ├── hyper_network.py │ ├── hyper_transformer.py │ ├── layers.py │ ├── lora.py │ ├── lora_partitioning.py │ ├── losses.py │ ├── roberta_partitioning.py │ └── t5_partitioning.py ├── ni_tasks │ ├── __init__.py │ ├── evaluation.py │ ├── ni_collator.py │ ├── ni_dataset.py │ └── ni_registry.py ├── numeric_task │ ├── numeric_registry.py │ └── words.txt ├── python_scripts │ ├── interactive.py │ ├── lora_params.py │ ├── make_p3_boxplot.py │ ├── p3_perf_over_time.py │ ├── p3_results.py │ ├── poking_the_bear.py │ ├── readme.md │ └── test_all_tf_records.py ├── seqio_tasks │ ├── __init__.py │ ├── all_edited_prompts.txt │ ├── all_t0_task_prefixes.txt │ ├── all_t0_tasks.py │ ├── all_t0_tasks.txt │ ├── beam_requirements.txt │ ├── catwalk_to_seqio.py │ ├── check_t0_tasks.py │ ├── datasets.csv │ ├── few_shot.py │ ├── my_t0_tasks.py │ ├── readme.md │ ├── small_t0_tasks.py │ ├── t0_datasets_mapping.py │ ├── t0_tasks.py │ └── utils.py ├── utils.py └── version.py ├── mypy.ini ├── pyproject.toml ├── pytest.ini ├── requirements.txt ├── scripts ├── eval │ ├── t0_eval.sh │ ├── t0_eval_adapter.sh │ ├── t0_eval_hypter.sh │ └── t0_reg_eval.sh ├── local │ ├── debug_from_t5.sh │ ├── local.sh │ ├── local_debug.sh │ ├── local_eval.sh │ ├── lora_finetune.sh │ └── lora_local.sh ├── lora │ ├── debug_lora.sh │ ├── t0_lora_eval.sh │ ├── train_lora_from_t5.sh │ └── train_plain_lora_from_t5.sh ├── nat_int │ ├── ni_eval.sh │ ├── ni_eval_base.sh │ ├── ni_eval_reg.sh │ ├── ni_train.sh │ ├── ni_train_debug.sh │ ├── ni_train_htune.sh │ ├── ni_train_hypter.sh │ ├── ni_train_mixed.sh │ ├── ni_train_mixed_base.sh │ ├── ni_train_no_fid.sh │ ├── ni_train_no_hnet.sh │ ├── ni_train_only_adapter_no_fid.sh │ ├── ni_train_only_lora_no_fid.sh │ ├── ni_train_only_lora_no_fid_smaller.sh │ ├── ni_train_only_prefix_no_fid.sh │ ├── ni_train_pretrained.sh │ ├── ni_train_pretrained_2pos.sh │ ├── ni_train_pretrained_base.sh │ ├── ni_train_pretrained_base_2pos.sh │ ├── ni_train_pretrained_decoder.sh │ ├── ni_train_pretrained_froz.sh │ ├── ni_train_pretrained_just_prefix.sh │ 
├── ni_train_pretrained_just_prefix_alt.sh │ ├── ni_train_pretrained_just_prefix_tanh.sh │ ├── ni_train_pretrained_mimic.sh │ ├── ni_train_pretrained_no_fid.sh │ ├── ni_train_pretrained_our_decoder.sh │ ├── ni_train_pretrained_xxl.sh │ ├── ni_train_reg.sh │ ├── ni_train_reg_base.sh │ ├── ni_train_reg_xxl.sh │ └── ni_train_sep_encoder_no_fid.sh ├── pretraining │ ├── pretrain.sh │ ├── pretrain_base.sh │ ├── pretrain_decoder.sh │ ├── pretrain_hnet_only.sh │ ├── pretrain_htune.sh │ ├── pretrain_lora.sh │ ├── pretrain_non_layer.sh │ ├── pretrain_only_adapter_no_fid.sh │ ├── pretrain_only_lora_no_fid.sh │ ├── pretrain_only_lora_no_fid_smaller.sh │ ├── pretrain_only_prefix_no_fid.sh │ ├── pretrain_only_prefix_no_fid_4_way.sh │ ├── pretrain_only_prefix_no_fid_alt.sh │ ├── pretrain_our_decoder.sh │ ├── pretrain_prefix_just_prefix.sh │ ├── pretrain_prefix_no_fid.sh │ ├── pretrain_sep_enc.sh │ ├── pretrain_six.sh │ └── pretrain_xxl.sh ├── t0_few_shot │ ├── t0_xshot_eval_hint.sh │ ├── t0_xshot_eval_reg.sh │ ├── t0_xshot_train_hint.sh │ └── t0_xshot_train_reg.sh ├── t0_reg_train.sh ├── t0_train_hypter.sh ├── tpu_setup.sh ├── train_from_pretrained.sh └── train_from_t5.sh ├── setup.py └── tests ├── __init__.py ├── hello_test.py └── modeling ├── __init__.py ├── hyper_network_test.py ├── hyper_transformer_test.py ├── layers_test.py ├── lora_network_test.py ├── lora_test.py └── losses_test.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 115 3 | 4 | ignore = 5 | # these rules don't play well with black 6 | # whitespace before : 7 | E203 8 | # line break before binary operator 9 | W503 10 | 11 | exclude = 12 | .venv 13 | .git 14 | __pycache__ 15 | .mypy_cache 16 | 17 | per-file-ignores = 18 | # __init__.py files are allowed to have unused imports and lines-too-long 19 | */__init__.py:F401 20 | */**/**/__init__.py:F401,E501 21 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | schedule: 6 | interval: "daily" 7 | ignore: 8 | - dependency-name: "tensorflow" 9 | - dependency-name: "tensorflow-text" 10 | open-pull-requests-limit: 10 11 | - package-ecosystem: "github-actions" 12 | directory: "/" 13 | schedule: 14 | interval: "daily" 15 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | concurrency: 4 | group: ${{ github.workflow }}-${{ github.ref }} 5 | cancel-in-progress: true 6 | 7 | on: 8 | pull_request: 9 | branches: 10 | - main 11 | push: 12 | branches: 13 | - main 14 | 15 | env: 16 | # Change this to invalidate existing cache. 17 | CACHE_PREFIX: v2 18 | PYTHON_PATH: ./ 19 | 20 | jobs: 21 | checks: 22 | name: python ${{ matrix.python }} - ${{ matrix.task.name }} 23 | runs-on: [ubuntu-latest] 24 | timeout-minutes: 30 25 | strategy: 26 | fail-fast: false 27 | matrix: 28 | python: [3.8] 29 | task: 30 | - name: Style 31 | run: | 32 | isort --check . 33 | black --check . 34 | 35 | - name: Lint 36 | run: | 37 | flake8 . 38 | 39 | - name: Type check 40 | run: | 41 | mypy . 
42 | 43 | - name: Test 44 | run: | 45 | pytest -v --color=yes tests/ 46 | 47 | steps: 48 | - uses: actions/checkout@v3 49 | 50 | - name: Setup Python 51 | uses: actions/setup-python@v4 52 | with: 53 | python-version: ${{ matrix.python }} 54 | 55 | - name: Install prerequisites 56 | run: | 57 | pip install --upgrade pip setuptools wheel virtualenv 58 | 59 | - name: Set build variables 60 | shell: bash 61 | run: | 62 | # Get the exact Python version to use in the cache key. 63 | echo "PYTHON_VERSION=$(python --version)" >> $GITHUB_ENV 64 | echo "RUNNER_ARCH=$(uname -m)" >> $GITHUB_ENV 65 | # Use week number in cache key so we can refresh the cache weekly. 66 | echo "WEEK_NUMBER=$(date +%V)" >> $GITHUB_ENV 67 | 68 | - uses: actions/cache@v3 69 | id: virtualenv-cache 70 | with: 71 | path: .venv 72 | key: ${{ env.CACHE_PREFIX }}-${{ env.WEEK_NUMBER }}-${{ runner.os }}-${{ env.RUNNER_ARCH }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev-requirements.txt') }} 73 | restore-keys: | 74 | ${{ env.CACHE_PREFIX }}-${{ env.WEEK_NUMBER }}-${{ runner.os }}-${{ env.RUNNER_ARCH }}-${{ env.PYTHON_VERSION }} 75 | 76 | - name: Setup virtual environment (no cache hit) 77 | if: steps.virtualenv-cache.outputs.cache-hit != 'true' 78 | run: | 79 | test -d .venv || virtualenv -p $(which python) --copies --reset-app-data .venv 80 | . .venv/bin/activate 81 | pip install --no-deps -e .[dev] 82 | pip install -r requirements.txt --no-deps 83 | pip install -r dev-requirements.txt --no-deps 84 | 85 | - name: Setup virtual environment (cache hit) 86 | if: steps.virtualenv-cache.outputs.cache-hit == 'true' 87 | run: | 88 | test -d .venv || virtualenv -p $(which python) --copies --reset-app-data .venv 89 | . .venv/bin/activate 90 | pip install --no-deps -e .[dev] 91 | pip install -r requirements.txt --no-deps 92 | pip install -r dev-requirements.txt --no-deps 93 | 94 | - name: Show environment info 95 | run: | 96 | . .venv/bin/activate 97 | which python 98 | python --version 99 | pip freeze 100 | 101 | - name: ${{ matrix.task.name }} 102 | run: | 103 | . .venv/bin/activate 104 | ${{ matrix.task.run }} 105 | 106 | - name: Clean up 107 | if: always() 108 | run: | 109 | . .venv/bin/activate 110 | pip uninstall -y hyper_task_descriptions 111 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # build artifacts 2 | 3 | .eggs/ 4 | .mypy_cache 5 | hyper_task_descriptions.egg-info/ 6 | build/ 7 | dist/ 8 | pip-wheel-metadata/ 9 | 10 | 11 | # dev tools 12 | 13 | .envrc 14 | .python-version 15 | .idea 16 | .venv/ 17 | .vscode/ 18 | /*.iml 19 | 20 | 21 | # jupyter notebooks 22 | 23 | .ipynb_checkpoints 24 | 25 | 26 | # miscellaneous 27 | 28 | .cache/ 29 | doc/_build/ 30 | *.swp 31 | .DS_Store 32 | 33 | 34 | # python 35 | 36 | *.pyc 37 | *.pyo 38 | __pycache__ 39 | 40 | 41 | # testing and continuous integration 42 | 43 | .coverage 44 | .pytest_cache/ 45 | .benchmarks 46 | 47 | # documentation build artifacts 48 | 49 | docs/build 50 | site/ 51 | *.pyc 52 | venv/ 53 | t5x/ 54 | nvenv/ 55 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | #################################### 2 | ###### Main dev dependencies ####### 3 | #################################### 4 | 5 | # Checks style, syntax, and other useful errors. 
6 | flake8==6.0.0 7 | pyflakes==3.0.1 8 | pycodestyle==2.10.0 9 | mccabe==0.7.0 10 | 11 | # Static type checking 12 | mypy==0.971 13 | 14 | # Automatic code formatting 15 | # promptsource specifies these, so we have to match 16 | black==21.12b0 17 | isort==5.8.0 18 | 19 | # Running tests 20 | pytest 21 | 22 | # Flaky tests 23 | flaky>=3.7.0 24 | -------------------------------------------------------------------------------- /figures/hint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/hyper-task-descriptions/fd86a43ac2131582548130d86d57ea977c804ab6/figures/hint.png -------------------------------------------------------------------------------- /gins/full_restore.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Base model. 2 | from __gin__ import dynamic_registration 3 | 4 | from t5x import utils 5 | 6 | 7 | # These setting allow us to partially reload a checkpoint, that is, we can load 8 | # most of the model weights from the checkpoint, without it complaining that we 9 | # don't have a weight for our prompt in the checkpoint. 10 | utils.RestoreCheckpointConfig: 11 | # Activate the codepath that allow of the merging of the optimizer state as 12 | # specified in the config (with our new parameter) and the optimizer state as 13 | # defined in the checkpoint. 14 | fallback_to_scratch = False 15 | # Use the T5X assignment map to grab values from the checkpoint. Each entry in 16 | # the map is a regular expression that matches some flatten variable in the 17 | # optimizer state as defined in the model created by the config. The second 18 | # value is the corresponding name in optimizer state as defined by the 19 | # checkpoint. It supports interpolating capture groups from the initial regex. 20 | # If the second pattern it `None` we skip trying to load this variable from 21 | # the checkpoint. 22 | 23 | # reset just the optimizer for now. 24 | assignment_map = None 25 | -------------------------------------------------------------------------------- /gins/hyper_base.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Base model. 2 | from __gin__ import dynamic_registration 3 | 4 | import seqio 5 | from t5x import adafactor 6 | from t5x import models 7 | from t5x.examples.t5 import network 8 | from hyper_task_descriptions.modeling import hyper_network 9 | from hyper_task_descriptions import hf_vocab 10 | from hyper_task_descriptions.modeling import hyper_transformer 11 | 12 | # ------------------- Loss HParam ---------------------------------------------- 13 | Z_LOSS = 0.0001 14 | LABEL_SMOOTHING = 0.0 15 | # NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) 16 | # the loss normalizing factor should be set to pretraining batch_size * 17 | # target_token_length. 18 | LOSS_NORMALIZING_FACTOR = None 19 | # Dropout should be specified in the "run" files 20 | DROPOUT_RATE = %gin.REQUIRED 21 | 22 | # Vocabulary (shared by encoder and decoder) 23 | # VOCABULARY = @hf_vocab.HuggingfaceVocabulary() 24 | # hf_vocab.HuggingfaceVocabulary.model_name = "t5-base" 25 | VOCABULARY = @seqio.SentencePieceVocabulary() 26 | seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" 27 | 28 | 29 | # ------------------- Optimizer ------------------------------------------------ 30 | # `learning_rate` is set by `Trainer.learning_rate_fn`. 
31 | OPTIMIZER = @adafactor.Adafactor() 32 | adafactor.Adafactor: 33 | decay_rate = 0.8 34 | step_offset = 0 35 | logical_factor_rules = @adafactor.standard_logical_factor_rules() 36 | # clipping_threshold = None # the hypernet updates get *consistently* clipped 37 | 38 | # ------------------- Model ---------------------------------------------------- 39 | MODEL = @hyper_transformer.HyperEncoderDecoderContrastiveModel() 40 | hyper_transformer.HyperEncoderDecoderContrastiveModel: 41 | module = @hyper_network.HyperTransformer() 42 | input_vocabulary = %VOCABULARY 43 | output_vocabulary = %VOCABULARY 44 | optimizer_def = %OPTIMIZER 45 | z_loss = %Z_LOSS 46 | label_smoothing = %LABEL_SMOOTHING 47 | loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR 48 | 49 | # ------------------- Network specification ------------------------------------ 50 | hyper_network.HyperTransformer.config = @hyper_network.HyperT5Config() 51 | hyper_network.HyperT5Config: 52 | vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency 53 | dtype = 'bfloat16' 54 | emb_dim = 768 55 | num_heads = 12 56 | num_encoder_layers = 12 57 | num_decoder_layers = 12 58 | head_dim = 64 59 | mlp_dim = 2048 60 | mlp_activations = ('gelu', 'linear') 61 | dropout_rate = %DROPOUT_RATE 62 | logits_via_embedding = False 63 | adapter_size = 64 64 | num_prefix_tokens = 30 65 | -------------------------------------------------------------------------------- /gins/hyper_large.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 small model. 2 | 3 | include 'gins/hyper_base.gin' # imports vocab, optimizer and model. 4 | 5 | # ------------------- Network specification overrides -------------------------- 6 | hyper_network.HyperTransformer.config = @hyper_network.HyperT5Config() 7 | hyper_network.HyperT5Config: 8 | emb_dim = 1024 9 | num_heads = 16 10 | num_encoder_layers = 24 11 | num_decoder_layers = 24 12 | head_dim = 64 13 | mlp_dim = 2816 14 | adapter_size = 8 15 | num_prefix_tokens = 7 16 | -------------------------------------------------------------------------------- /gins/hyper_small.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 small model. 2 | 3 | include 'gins/hyper_base.gin' # imports vocab, optimizer and model. 4 | 5 | # ------------------- Network specification overrides -------------------------- 6 | hyper_network.HyperTransformer.config = @hyper_network.HyperT5Config() 7 | hyper_network.HyperT5Config: 8 | emb_dim = 512 9 | num_heads = 6 10 | num_encoder_layers = 8 11 | num_decoder_layers = 8 12 | head_dim = 64 13 | mlp_dim = 1024 14 | adapter_size = 8 15 | num_prefix_tokens = 7 16 | -------------------------------------------------------------------------------- /gins/hyper_tiny.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 tiny model. 2 | 3 | include 'gins/hyper_base.gin' # imports vocab, optimizer and model. 
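# gin applies bindings in order, so the HyperT5Config fields set below override
# the corresponding values from hyper_base.gin; anything not re-bound here
# (dropout rate, vocabulary, optimizer, loss settings) is inherited unchanged.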
4 | 5 | # ------------------- Network specification overrides -------------------------- 6 | hyper_network.HyperTransformer.config = @hyper_network.HyperT5Config() 7 | hyper_network.HyperT5Config: 8 | emb_dim = 8 9 | num_heads = 4 10 | num_encoder_layers = 2 11 | num_decoder_layers = 2 12 | head_dim = 3 13 | mlp_dim = 16 14 | adapter_size = 8 15 | num_prefix_tokens = 2 16 | -------------------------------------------------------------------------------- /gins/hyper_xl.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 XL model. 2 | 3 | include 'gins/hyper_base.gin' # imports vocab, optimizer and model. 4 | 5 | # ------------------- Network specification overrides -------------------------- 6 | hyper_network.HyperTransformer.config = @hyper_network.HyperT5Config() 7 | hyper_network.HyperT5Config: 8 | emb_dim = 2048 9 | num_heads = 32 10 | num_encoder_layers = 24 11 | num_decoder_layers = 24 12 | head_dim = 64 13 | mlp_dim = 5120 14 | adapter_size = 512 15 | num_prefix_tokens = 30 16 | -------------------------------------------------------------------------------- /gins/hyper_xxl.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 XL model. 2 | 3 | include 'gins/hyper_base.gin' # imports vocab, optimizer and model. 4 | 5 | # ------------------- Network specification overrides -------------------------- 6 | hyper_network.HyperTransformer.config = @hyper_network.HyperT5Config() 7 | hyper_network.HyperT5Config: 8 | emb_dim = 4096 9 | num_heads = 64 10 | num_encoder_layers = 24 11 | num_decoder_layers = 24 12 | head_dim = 64 13 | mlp_dim = 10240 14 | adapter_size = 512 15 | num_prefix_tokens = 30 16 | -------------------------------------------------------------------------------- /gins/hypertune.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | from t5x import utils 4 | from hyper_task_descriptions.modeling import hyper_network 5 | 6 | # hypertune: decoder in hypernet 7 | hyper_network.HyperTransformer.config = @hyper_network.HyperT5Config() 8 | hyper_network.HyperT5Config: 9 | use_adapter = False 10 | use_prefix = True 11 | use_fusion_in_decoder = False 12 | layer_embedding_method = "decoder" 13 | 14 | # we restore hypernetwork weights from pretrained model 15 | utils.RestoreCheckpointConfig: 16 | fallback_to_scratch = True 17 | assignment_map = ( 18 | # for some reason, we get non-partitioned dict entries. this hacks this by mapping 19 | # them to none. I suspect this means the hypernet has its optimizer states reset... 
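# to illustrate the "regular restore" entries below (parameter names are
# indicative, not exact): a config variable like
#   .../hyper/hyper_encoder/layers_0/attention/query/kernel
# matches '(.*)/(hyper/hyper_encoder|encoder)/(.*)' and gets loaded from
#   .../encoder/layers_0/attention/query/kernel
# in the checkpoint, so the hypernet encoder starts from the pretrained T5 encoder.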
20 | ('(.*)/param_states/(encoder|hyper_encoder)/(.*)/(scale|kernel|rel_embedding)$', None), 21 | ('(.*)/param_states/(decoder|hyper_decoder)/(.*)/(scale|kernel|rel_embedding)$', None), 22 | # regular restore, using groups 23 | ('(.*)/(hyper/hyper_encoder|encoder)/(.*)', r'\1/encoder/\3'), 24 | ('(.*)/(hyper/hyper_decoder|decoder)/(.*)', r'\1/decoder/\3'), 25 | # the non-t5 bits of hypernet need to be initialised from scratch 26 | ('.*hyper/[^h].*', None), 27 | ) -------------------------------------------------------------------------------- /gins/hypertune_decoder_us.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | from t5x import utils 4 | from hyper_task_descriptions.modeling import hyper_network 5 | 6 | # hypertune: decoder in hypernet 7 | hyper_network.HyperTransformer.config = @hyper_network.HyperT5Config() 8 | hyper_network.HyperT5Config: 9 | use_adapter = True 10 | use_prefix = True 11 | use_fusion_in_decoder = True 12 | layer_embedding_method = "decoder_test" 13 | 14 | # we restore hypernetwork weights from pretrained model 15 | utils.RestoreCheckpointConfig: 16 | fallback_to_scratch = True 17 | assignment_map = ( 18 | # for some reason, we get non-partitioned dict entries. this hacks this by mapping 19 | # them to none. I suspect this means the hypernet has its optimizer states reset... 20 | #('(.*)/param_states/(encoder|hyper_encoder)/(.*)/(scale|kernel|rel_embedding)$', None), 21 | #('(.*)/param_states/(decoder|hyper_decoder)/(.*)/(scale|kernel|rel_embedding)$', None), 22 | # regular restore, using groups 23 | ('(.*)/(hyper/hyper_decoder|decoder)/(.*)', r'\1/decoder/\3'), 24 | # the non-t5 bits of hypernet need to be initialised from scratch 25 | ('.*hyper/[^h].*', None), 26 | ) -------------------------------------------------------------------------------- /gins/hypertune_embed.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | from t5x import utils 4 | from hyper_task_descriptions.modeling import hyper_network 5 | 6 | # hypertune: decoder in hypernet 7 | hyper_network.HyperTransformer.config = @hyper_network.HyperT5Config() 8 | hyper_network.HyperT5Config: 9 | use_adapter = False 10 | use_prefix = True 11 | use_fusion_in_decoder = True 12 | layer_embedding_method = "decoder" 13 | per_layer_hnet = True 14 | -------------------------------------------------------------------------------- /gins/hypertune_full_train.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | from t5x import utils 4 | from hyper_task_descriptions.modeling import hyper_network 5 | 6 | # hypertune: decoder in hypernet 7 | hyper_network.HyperTransformer.config = @hyper_network.HyperT5Config() 8 | hyper_network.HyperT5Config: 9 | use_adapter = False 10 | use_prefix = True 11 | use_fusion_in_decoder = True 12 | layer_embedding_method = "decoder" 13 | per_layer_hnet = True 14 | 15 | # we restore hypernetwork weights from pretrained model 16 | utils.RestoreCheckpointConfig: 17 | fallback_to_scratch = True 18 | assignment_map = ( 19 | # for some reason, we get non-partitioned dict entries. this hacks this by mapping 20 | # them to none. I suspect this means the hypernet has its optimizer states reset... 
21 | #('(.*)/param_states/(encoder|hyper_encoder)/(.*)/(scale|kernel|rel_embedding)$', None), 22 | #('(.*)/param_states/(decoder|hyper_decoder)/(.*)/(scale|kernel|rel_embedding)$', None), 23 | # regular restore, using groups 24 | ('(.*)/(hyper/hyper_encoder|encoder)/(.*)', r'\1/encoder/\3'), 25 | ('(.*)/(hyper/hyper_decoder|decoder)/(.*)', r'\1/decoder/\3'), 26 | # the non-t5 bits of hypernet need to be initialised from scratch 27 | ('.*hyper/[^h].*', None), 28 | ) -------------------------------------------------------------------------------- /gins/hypter.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | from t5x import utils 4 | from hyper_task_descriptions.modeling import hyper_network 5 | 6 | # hypter: per-layer hnet, separate (frozen) encoder 7 | # we fully finetune, which is a bit different, but I think justified. 8 | # adapter-only 9 | hyper_network.HyperTransformer.config = @hyper_network.HyperT5Config() 10 | hyper_network.HyperT5Config: 11 | use_adapter = True 12 | adapter_size = 8 # required due to size. 13 | use_prefix = False 14 | use_fusion_in_decoder = False 15 | layer_embedding_method = "none" 16 | per_layer_hnet = True 17 | share_hnet_encoder = False 18 | -------------------------------------------------------------------------------- /gins/instruction_embed.gin: -------------------------------------------------------------------------------- 1 | 2 | # ------------------- Network specification overrides -------------------------- 3 | hyper_network.HyperTransformer.config = @hyper_network.HyperT5Config() 4 | hyper_network.HyperT5Config: 5 | use_adapter = True 6 | use_prefix = True 7 | use_fusion_in_decoder = True 8 | share_hnet_encoder = True 9 | use_linear = False 10 | -------------------------------------------------------------------------------- /gins/lora/hyper_lora_base.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 small model. 2 | 3 | include 'gins/lora/lora_base.gin' # imports vocab, optimizer and model. 4 | 5 | # ------------------- Network specification overrides -------------------------- 6 | lora_network.LoraTransformer.config = @hyper_network.HyperT5Config() 7 | hyper_network.HyperT5Config: 8 | use_lora = True 9 | 10 | -------------------------------------------------------------------------------- /gins/lora/hyper_lora_small.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 small model. 2 | 3 | include 'gins/lora/lora_small.gin' # imports vocab, optimizer and model. 4 | 5 | # ------------------- Network specification overrides -------------------------- 6 | lora_network.LoraTransformer.config = @hyper_network.HyperT5Config() 7 | hyper_network.HyperT5Config: 8 | use_lora = True 9 | 10 | -------------------------------------------------------------------------------- /gins/lora/hyper_lora_xl.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 small model. 2 | 3 | include 'gins/lora/lora_xl.gin' # imports vocab, optimizer and model. 
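# lora_ranks below gives one rank per group of weights that LoRA can wrap (the
# exact mapping is defined in modeling/lora.py); a None entry appears to disable
# LoRA for that group, so (8, None, 8, None) adds rank-8 updates to two of the four.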
4 | 5 | # ------------------- Network specification overrides -------------------------- 6 | lora_network.LoraTransformer.config = @hyper_network.HyperT5Config() 7 | hyper_network.HyperT5Config: 8 | use_lora = True 9 | lora_ranks = (8, None, 8, None) 10 | -------------------------------------------------------------------------------- /gins/lora/lora.gin: -------------------------------------------------------------------------------- 1 | from t5x import utils 2 | from hyper_task_descriptions import utils as hyper_utils 3 | 4 | hyper_utils.match_any_optax.regexes = [".*lora*.*"] 5 | 6 | utils.RestoreCheckpointConfig: 7 | assignment_map = ( 8 | #(r"^.*hyper.*$", None), 9 | #(r"^.*lora_a.*$", None), 10 | #(r"^.*lora_b.*$", None), 11 | ) 12 | 13 | # , ".*hyper.*"] 14 | -------------------------------------------------------------------------------- /gins/lora/lora_base.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Base model. 2 | from __gin__ import dynamic_registration 3 | 4 | import seqio 5 | from t5x import adafactor 6 | from t5x import models 7 | from t5x.examples.t5 import network 8 | from hyper_task_descriptions.modeling import hyper_network 9 | from hyper_task_descriptions.modeling import lora_network 10 | from hyper_task_descriptions import hf_vocab 11 | from hyper_task_descriptions.modeling import hyper_transformer 12 | from hyper_task_descriptions import utils as hyper_utils 13 | 14 | # ------------------- Loss HParam ---------------------------------------------- 15 | Z_LOSS = 0.0001 16 | LABEL_SMOOTHING = 0.0 17 | # NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) 18 | # the loss normalizing factor should be set to pretraining batch_size * 19 | # target_token_length. 20 | LOSS_NORMALIZING_FACTOR = None 21 | # Dropout should be specified in the "run" files 22 | DROPOUT_RATE = %gin.REQUIRED 23 | 24 | # Vocabulary (shared by encoder and decoder) 25 | VOCABULARY = @hf_vocab.HuggingfaceVocabulary() 26 | hf_vocab.HuggingfaceVocabulary.model_name = "t5-base" 27 | #VOCABULARY = @seqio.SentencePieceVocabulary() 28 | #seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" 29 | 30 | 31 | # ------------------- Optimizer ------------------------------------------------ 32 | # `learning_rate` is set by `Trainer.learning_rate_fn`. 
33 | OPTIMIZER = @adafactor.Adafactor() 34 | adafactor.Adafactor: 35 | decay_rate = 0.8 36 | step_offset = 0 37 | logical_factor_rules = @adafactor.standard_logical_factor_rules() 38 | # clipping_threshold = None # the hypernet updates get *consistently* clipped 39 | 40 | # ------------------- Model ---------------------------------------------------- 41 | MODEL = @hyper_transformer.HyperEncoderDecoderContrastiveModel() 42 | hyper_transformer.HyperEncoderDecoderContrastiveModel: 43 | module = @lora_network.LoraTransformer() 44 | input_vocabulary = %VOCABULARY 45 | output_vocabulary = %VOCABULARY 46 | optimizer_def = %OPTIMIZER 47 | z_loss = %Z_LOSS 48 | label_smoothing = %LABEL_SMOOTHING 49 | loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR 50 | 51 | # ------------------- Network specification ------------------------------------ 52 | lora_network.LoraTransformer.config = @hyper_network.HyperT5Config() 53 | hyper_network.HyperT5Config: 54 | vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency 55 | dtype = 'bfloat16' 56 | emb_dim = 768 57 | num_heads = 12 58 | num_encoder_layers = 12 59 | num_decoder_layers = 12 60 | head_dim = 64 61 | mlp_dim = 2048 62 | mlp_activations = ('gelu', 'linear') 63 | dropout_rate = %DROPOUT_RATE 64 | logits_via_embedding = False 65 | adapter_size = 64 66 | num_prefix_tokens = 30 67 | use_lora = False 68 | lora_ranks = (2, None, 2, None) 69 | 70 | hyper_utils.match_any_optax.regexes = [".*lora*.*"] 71 | -------------------------------------------------------------------------------- /gins/lora/lora_small.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 small model. 2 | 3 | include 'gins/lora/lora_base.gin' # imports vocab, optimizer and model. 4 | 5 | # ------------------- Network specification overrides -------------------------- 6 | lora_network.LoraTransformer.config = @hyper_network.HyperT5Config() 7 | hyper_network.HyperT5Config: 8 | emb_dim = 512 9 | num_heads = 6 10 | num_encoder_layers = 8 11 | num_decoder_layers = 8 12 | head_dim = 64 13 | mlp_dim = 1024 14 | adapter_size = 8 15 | num_prefix_tokens = 7 16 | lora_ranks = (8, None, 8, None) 17 | 18 | 19 | -------------------------------------------------------------------------------- /gins/lora/lora_xl.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 XL model. 2 | 3 | include 'gins/lora/lora_base.gin' # imports vocab, optimizer and model. 4 | 5 | # ------------------- Network specification overrides -------------------------- 6 | lora_network.LoraTransformer.config = @hyper_network.HyperT5Config() 7 | hyper_network.HyperT5Config: 8 | emb_dim = 2048 9 | num_heads = 32 10 | num_encoder_layers = 24 11 | num_decoder_layers = 24 12 | head_dim = 64 13 | mlp_dim = 5120 14 | adapter_size = 64 15 | num_prefix_tokens = 15 16 | lora_ranks = (8, None, 8, None) 17 | -------------------------------------------------------------------------------- /gins/lora/plain/lora_base.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Base model. 
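# "plain" LoRA baseline: the LoRA matrices are trained directly, using the
# ordinary HyperEncoderDecoderModel with use_prefix = False; compare
# gins/lora/lora_base.gin, which uses the contrastive hypernetwork model
# (HyperEncoderDecoderContrastiveModel).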
2 | from __gin__ import dynamic_registration 3 | 4 | import seqio 5 | from t5x import adafactor 6 | from t5x import models 7 | from t5x import utils 8 | from t5x.examples.t5 import network 9 | from hyper_task_descriptions.modeling import hyper_network 10 | from hyper_task_descriptions.modeling import lora_network 11 | from hyper_task_descriptions import hf_vocab 12 | from hyper_task_descriptions.modeling import hyper_transformer 13 | from hyper_task_descriptions import utils as hyper_utils 14 | 15 | # ------------------- Loss HParam ---------------------------------------------- 16 | Z_LOSS = 0.0001 17 | LABEL_SMOOTHING = 0.0 18 | # NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) 19 | # the loss normalizing factor should be set to pretraining batch_size * 20 | # target_token_length. 21 | LOSS_NORMALIZING_FACTOR = None 22 | # Dropout should be specified in the "run" files 23 | DROPOUT_RATE = %gin.REQUIRED 24 | 25 | # Vocabulary (shared by encoder and decoder) 26 | VOCABULARY = @hf_vocab.HuggingfaceVocabulary() 27 | hf_vocab.HuggingfaceVocabulary.model_name = "t5-base" 28 | #VOCABULARY = @seqio.SentencePieceVocabulary() 29 | #seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" 30 | 31 | 32 | # ------------------- Optimizer ------------------------------------------------ 33 | # `learning_rate` is set by `Trainer.learning_rate_fn`. 34 | OPTIMIZER = @adafactor.Adafactor() 35 | adafactor.Adafactor: 36 | decay_rate = 0.8 37 | step_offset = 0 38 | logical_factor_rules = @adafactor.standard_logical_factor_rules() 39 | # clipping_threshold = None # the hypernet updates get *consistently* clipped 40 | 41 | # ------------------- Model ---------------------------------------------------- 42 | MODEL = @hyper_transformer.HyperEncoderDecoderModel() 43 | hyper_transformer.HyperEncoderDecoderModel: 44 | module = @lora_network.LoraTransformer() 45 | input_vocabulary = %VOCABULARY 46 | output_vocabulary = %VOCABULARY 47 | optimizer_def = %OPTIMIZER 48 | z_loss = %Z_LOSS 49 | label_smoothing = %LABEL_SMOOTHING 50 | loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR 51 | 52 | # ------------------- Network specification ------------------------------------ 53 | lora_network.LoraTransformer.config = @hyper_network.HyperT5Config() 54 | hyper_network.HyperT5Config: 55 | vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency 56 | dtype = 'bfloat16' 57 | emb_dim = 768 58 | num_heads = 12 59 | num_encoder_layers = 12 60 | num_decoder_layers = 12 61 | head_dim = 64 62 | mlp_dim = 2048 63 | mlp_activations = ('gelu', 'linear') 64 | dropout_rate = %DROPOUT_RATE 65 | logits_via_embedding = False 66 | use_lora = False 67 | lora_ranks = (2, None, 2, None) 68 | use_prefix = False 69 | 70 | hyper_utils.match_any_optax.regexes = [".*lora*.*"] 71 | 72 | utils.RestoreCheckpointConfig: 73 | assignment_map = ( 74 | #(r"^.*hyper.*$", None), 75 | #(r"^.*lora_a.*$", None), 76 | #(r"^.*lora_b.*$", None), 77 | ) 78 | -------------------------------------------------------------------------------- /gins/lora/plain/lora_small.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 small model. 2 | 3 | include 'gins/lora/plain/lora_base.gin' # imports vocab, optimizer and model. 
4 | 5 | # ------------------- Network specification overrides -------------------------- 6 | lora_network.LoraTransformer.config = @hyper_network.HyperT5Config() 7 | hyper_network.HyperT5Config: 8 | emb_dim = 512 9 | num_heads = 6 10 | num_encoder_layers = 8 11 | num_decoder_layers = 8 12 | head_dim = 64 13 | mlp_dim = 1024 14 | lora_ranks = (4, None, 4, None) 15 | 16 | 17 | -------------------------------------------------------------------------------- /gins/lora/plain/lora_xl.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 XL model. 2 | 3 | include 'gins/lora/plain/lora_base.gin' # imports vocab, optimizer and model. 4 | 5 | # ------------------- Network specification overrides -------------------------- 6 | lora_network.LoraTransformer.config = @hyper_network.LoraT5Config() 7 | hyper_network.HyperT5Config: 8 | emb_dim = 2048 9 | num_heads = 32 10 | num_encoder_layers = 24 11 | num_decoder_layers = 24 12 | head_dim = 64 13 | mlp_dim = 5120 14 | lora_ranks = (8, None, 8, None) 15 | -------------------------------------------------------------------------------- /gins/ni_eval.gin: -------------------------------------------------------------------------------- 1 | # Defaults for eval.py. 2 | # 3 | # 4 | # You must also include a binding for MODEL. 5 | # 6 | # Required to be set: 7 | # 8 | # - CHECKPOINT_PATH: The model checkpoint to evaluate 9 | # - EVAL_OUTPUT_DIR: The dir to write results to. 10 | # 11 | # 12 | # Commonly overridden options: 13 | # 14 | # - DatasetConfig.split 15 | # - DatasetConfig.batch_size 16 | from __gin__ import dynamic_registration 17 | 18 | import __main__ as eval_script 19 | from t5x import partitioning 20 | from t5x import utils 21 | 22 | import seqio 23 | from seqio import loggers 24 | from hyper_task_descriptions.ni_tasks import ni_registry # Needed to define the t0 eval mixtures 25 | 26 | # Must be overridden 27 | MIXTURE_OR_TASK_NAME = "natural_instructions" 28 | CHECKPOINT_PATH = %gin.REQUIRED 29 | EVAL_OUTPUT_DIR = %gin.REQUIRED 30 | TASK_FEATURE_LENGTHS = {"inputs": 1024, "hyper_inputs": 1024, "task_names": 1, "targets": 256} 31 | DROPOUT_RATE = 0.0 32 | 33 | # DEPRECATED: Import the this module in your gin file. 34 | MIXTURE_OR_TASK_MODULE = None 35 | 36 | eval_script.evaluate: 37 | model = %MODEL # imported from separate gin file 38 | dataset_cfg = @utils.DatasetConfig() 39 | partitioner = @partitioning.PjitPartitioner() 40 | restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() 41 | output_dir = %EVAL_OUTPUT_DIR 42 | inference_evaluator_cls = @seqio.Evaluator 43 | 44 | seqio.Evaluator.logger_cls = [@loggers.JSONLogger, @seqio.TensorBoardLogger] 45 | 46 | partitioning.PjitPartitioner.num_partitions = 2 47 | 48 | utils.DatasetConfig: 49 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME 50 | task_feature_lengths = %TASK_FEATURE_LENGTHS 51 | split = 'test' 52 | batch_size = 256 53 | shuffle = False 54 | seed = 42 55 | use_cached = %USE_CACHED_TASKS 56 | pack = False 57 | use_custom_packing_ops = False 58 | module = %MIXTURE_OR_TASK_MODULE 59 | 60 | utils.RestoreCheckpointConfig: 61 | path = %CHECKPOINT_PATH 62 | mode = 'specific' 63 | dtype = 'float32' 64 | strict = True # make sure we actually load everything! 65 | -------------------------------------------------------------------------------- /gins/ni_train.gin: -------------------------------------------------------------------------------- 1 | # For training T0 (xxl = 11b, xl = 3b). Make sure you have cached p3 first! 
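# Fine-tunes on the natural_instructions mixture: pulls in the standard t5x
# finetune recipe, the T0-style overrides from t0.gin (batch size 2048, packing
# off), and restore_pretrained.gin, so the hypernetwork weights start from
# scratch while the underlying T5 weights come from the checkpoint.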
2 | from __gin__ import dynamic_registration 3 | 4 | from t5x import models 5 | from t5x import trainer 6 | from t5x import utils 7 | import seqio 8 | from hyper_task_descriptions.ni_tasks import ni_registry 9 | 10 | import __main__ as train_script 11 | 12 | include "t5x/configs/runs/finetune.gin" 13 | include "gins/t0.gin" # This overrides some default config in `t5x/configs/runs/finetune.gin` 14 | include "gins/restore_pretrained.gin" # for loading from checkpoints 15 | 16 | TASK_FEATURE_LENGTHS = {"inputs": 1024, "hyper_inputs": 1024, "task_names": 1, "targets": 256} 17 | MIXTURE_OR_TASK_NAME = "natural_instructions" 18 | 19 | trainer.Trainer.num_microbatches = 16 # 2048 // 16 20 | trainer.Trainer.weight_metrics_computer = @trainer.WeightMetricsComputer() 21 | 22 | # the batch sizes are really big so I increase stats logging 23 | # for more visibility. 24 | train_script.train: 25 | stats_period = 50 # default is eval_period, which is 1000 26 | 27 | 28 | utils.SaveCheckpointConfig: 29 | period = 100 # checkpoint frequency 30 | -------------------------------------------------------------------------------- /gins/ni_train_mixed.gin: -------------------------------------------------------------------------------- 1 | # For training T0 (xxl = 11b, xl = 3b). Make sure you have cached p3 first! 2 | from __gin__ import dynamic_registration 3 | 4 | from t5x import models 5 | from t5x import trainer 6 | from t5x import utils 7 | import seqio 8 | from hyper_task_descriptions.c4 import c4_ni_registry 9 | 10 | import __main__ as train_script 11 | 12 | include "t5x/configs/runs/finetune.gin" 13 | include "gins/t0.gin" # This overrides some default config in `t5x/configs/runs/finetune.gin` 14 | include "gins/restore_pretrained.gin" # for loading from checkpoints 15 | 16 | TASK_FEATURE_LENGTHS = {"inputs": 1024, "hyper_inputs": 1024, "task_names": 1, "targets": 256} 17 | MIXTURE_OR_TASK_NAME = "c4_ni" 18 | 19 | trainer.Trainer.num_microbatches = 16 # 2048 // 16 20 | trainer.Trainer.weight_metrics_computer = @trainer.WeightMetricsComputer() 21 | 22 | # the batch sizes are really big so I increase stats logging 23 | # for more visibility. 24 | train_script.train: 25 | stats_period = 50 # default is eval_period, which is 1000 26 | 27 | 28 | utils.SaveCheckpointConfig: 29 | period = 100 # checkpoint frequency 30 | -------------------------------------------------------------------------------- /gins/partial_train_adafactor.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Base model. 2 | from __gin__ import dynamic_registration 3 | 4 | from t5x import adafactor 5 | from t5x import optimizers 6 | from hyper_task_descriptions import utils as hyper_utils 7 | from flax import traverse_util 8 | 9 | # gin that allows partial training based on regex matching. 10 | 11 | # ------------------- Partial loading ------------------------------------------------ 12 | OPTIMIZER = @optimizers.MultiOptimizer() 13 | # note you can add more traversals if you want different optimizer settings 14 | # for dfferent parts of the model. 15 | # See https://github.com/google-research/t5x/blob/main/docs/usage/gin.md#scoping 16 | # for how to create multiple specialised instances of the same class. 
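# Three parameter groups are defined below: the hypernetwork generators
# (flattened names under .../hyper/ not starting with 'e'), the hypernetwork's
# text encoder (.../hyper/e..., i.e. the roberta-style encoder), and everything
# else, which is the underlying T5 model.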
17 | optimizers.MultiOptimizer: 18 | traversals_and_optimizers = ((@hyper/traverse_util.ModelParamTraversal(), 19 | @hyper/adafactor.Adafactor()), 20 | (@roberta/traverse_util.ModelParamTraversal(), 21 | @roberta/adafactor.Adafactor()), 22 | (@t5/traverse_util.ModelParamTraversal(), 23 | @t5/adafactor.Adafactor()),) 24 | 25 | # MultiOptimizer will match any parameter with a flattened name that 26 | # matches *any* of the regular expressions in the list. 27 | # hyper - we turn off param scaling, offset the step 28 | hyper/adafactor.Adafactor: 29 | multiply_by_parameter_scale = False 30 | step_offset = 1100000 31 | hyper/traverse_util.ModelParamTraversal: 32 | filter_fn = @hyper/hyper_utils.match_any() 33 | hyper/hyper_utils.match_any.regexes = [".*/hyper/[^e].*"] 34 | # hyper encoder - state is fresh so reset step 35 | roberta/adafactor.Adafactor: 36 | step_offset = 1100000 37 | roberta/traverse_util.ModelParamTraversal: 38 | filter_fn = @roberta/hyper_utils.match_any() 39 | roberta/hyper_utils.match_any.regexes = [".*/hyper/e.*"] 40 | # nothing special for this one 41 | t5/traverse_util.ModelParamTraversal: 42 | filter_fn = @t5/hyper_utils.inverse_match_any() 43 | t5/hyper_utils.inverse_match_any.regexes = [".*hyper.*"] 44 | -------------------------------------------------------------------------------- /gins/partial_train_adafactor_dual.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Base model. 2 | from __gin__ import dynamic_registration 3 | 4 | from t5x import adafactor 5 | from t5x import optimizers 6 | from hyper_task_descriptions import utils as hyper_utils 7 | from hyper_task_descriptions import learning_rate_adafactor 8 | from flax import traverse_util 9 | 10 | # gin that allows partial training based on regex matching. 11 | 12 | # general defaults. 13 | learning_rate_adafactor.Adafactor: 14 | decay_rate = 0.8 15 | step_offset = 0 16 | logical_factor_rules = @adafactor.standard_logical_factor_rules() 17 | 18 | # ------------------- Partial loading ------------------------------------------------ 19 | OPTIMIZER = @optimizers.MultiOptimizer() 20 | # note you can add more traversals if you want different optimizer settings 21 | # for dfferent parts of the model. 22 | # See https://github.com/google-research/t5x/blob/main/docs/usage/gin.md#scoping 23 | # for how to create multiple specialised instances of the same class. 24 | optimizers.MultiOptimizer: 25 | traversals_and_optimizers = ((@hyper/traverse_util.ModelParamTraversal(), 26 | @hyper/learning_rate_adafactor.Adafactor()), 27 | (@t5/traverse_util.ModelParamTraversal(), 28 | @t5/learning_rate_adafactor.Adafactor()),) 29 | 30 | # MultiOptimizer will match any parameter with a flattened name that 31 | # matches *any* of the regular expressions in the list. 
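# Unlike partial_train_adafactor.gin, this variant keeps just two groups
# (everything under hyper vs. the rest of the model) and gives each group its
# own learning_rate_adafactor instance with an explicit learning rate.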
32 | # hyper - we offset the step 33 | hyper/learning_rate_adafactor.Adafactor: 34 | multiply_by_parameter_scale = True 35 | step_offset = 1100000 36 | learning_rate = 1e-3 37 | hyper/traverse_util.ModelParamTraversal: 38 | filter_fn = @hyper/hyper_utils.match_any() 39 | hyper/hyper_utils.match_any.regexes = [".*hyper.*"] 40 | # hyper encoder - state is fresh so reset step 41 | # nuked 42 | # nothing special for this one 43 | t5/learning_rate_adafactor.Adafactor: 44 | learning_rate = 1e-3 45 | t5/traverse_util.ModelParamTraversal: 46 | filter_fn = @t5/hyper_utils.inverse_match_any() 47 | t5/hyper_utils.inverse_match_any.regexes = [".*/hyper.*"] 48 | -------------------------------------------------------------------------------- /gins/partial_train_adafactor_dual_frozen_under.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Base model. 2 | from __gin__ import dynamic_registration 3 | 4 | from t5x import adafactor 5 | from t5x import optimizers 6 | from hyper_task_descriptions import utils as hyper_utils 7 | from hyper_task_descriptions import learning_rate_adafactor 8 | from flax import traverse_util 9 | 10 | # gin that allows partial training based on regex matching. 11 | 12 | # general defaults. 13 | learning_rate_adafactor.Adafactor: 14 | decay_rate = 0.8 15 | step_offset = 0 16 | logical_factor_rules = @adafactor.standard_logical_factor_rules() 17 | 18 | # ------------------- Partial loading ------------------------------------------------ 19 | OPTIMIZER = @optimizers.MultiOptimizer() 20 | # note you can add more traversals if you want different optimizer settings 21 | # for dfferent parts of the model. 22 | # See https://github.com/google-research/t5x/blob/main/docs/usage/gin.md#scoping 23 | # for how to create multiple specialised instances of the same class. 24 | optimizers.MultiOptimizer: 25 | traversals_and_optimizers = ((@hyper/traverse_util.ModelParamTraversal(), 26 | @hyper/learning_rate_adafactor.Adafactor()), 27 | (@t5/traverse_util.ModelParamTraversal(), 28 | @t5/learning_rate_adafactor.Adafactor()),) 29 | 30 | # MultiOptimizer will match any parameter with a flattened name that 31 | # matches *any* of the regular expressions in the list. 32 | # hyper - we turn off param scaling, offset the step 33 | hyper/learning_rate_adafactor.Adafactor: 34 | multiply_by_parameter_scale = True 35 | step_offset = 1100000 36 | learning_rate = 1e-3 37 | hyper/traverse_util.ModelParamTraversal: 38 | filter_fn = @hyper/hyper_utils.match_any() 39 | hyper/hyper_utils.match_any.regexes = [".*hyper.*"] 40 | # hyper encoder - state is fresh so reset step 41 | # nuked 42 | # nothing special for this one 43 | t5/learning_rate_adafactor.Adafactor: 44 | learning_rate = 0 45 | t5/traverse_util.ModelParamTraversal: 46 | filter_fn = @t5/hyper_utils.inverse_match_any() 47 | t5/hyper_utils.inverse_match_any.regexes = [".*/hyper.*"] 48 | -------------------------------------------------------------------------------- /gins/partial_train_adafactor_no_roberta.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Base model. 2 | from __gin__ import dynamic_registration 3 | 4 | from t5x import adafactor 5 | from t5x import optimizers 6 | from hyper_task_descriptions import utils as hyper_utils 7 | from flax import traverse_util 8 | 9 | # gin that allows partial training based on regex matching. 
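# Compared to partial_train_adafactor.gin there is no roberta group here: the
# hypernetwork's text encoder (.../hyper/e...) matches neither traversal below,
# so as far as I can tell it simply isn't updated (hence "no_roberta"); the
# generators and the underlying T5 each keep their own Adafactor instance.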
10 | 11 | # ------------------- Partial loading ------------------------------------------------ 12 | OPTIMIZER = @optimizers.MultiOptimizer() 13 | # note you can add more traversals if you want different optimizer settings 14 | # for dfferent parts of the model. 15 | # See https://github.com/google-research/t5x/blob/main/docs/usage/gin.md#scoping 16 | # for how to create multiple specialised instances of the same class. 17 | optimizers.MultiOptimizer: 18 | traversals_and_optimizers = ((@hyper/traverse_util.ModelParamTraversal(), 19 | @hyper/adafactor.Adafactor()), 20 | (@t5/traverse_util.ModelParamTraversal(), 21 | @t5/adafactor.Adafactor()),) 22 | 23 | # MultiOptimizer will match any parameter with a flattened name that 24 | # matches *any* of the regular expressions in the list. 25 | # hyper - we turn off param scaling, offset the step 26 | hyper/adafactor.Adafactor: 27 | multiply_by_parameter_scale = True 28 | step_offset = 1100000 29 | hyper/traverse_util.ModelParamTraversal: 30 | filter_fn = @hyper/hyper_utils.match_any() 31 | hyper/hyper_utils.match_any.regexes = [".*/hyper/[^e].*"] 32 | # nothing special for this one 33 | t5/traverse_util.ModelParamTraversal: 34 | filter_fn = @t5/hyper_utils.inverse_match_any() 35 | t5/hyper_utils.inverse_match_any.regexes = [".*hyper.*"] 36 | -------------------------------------------------------------------------------- /gins/partial_train_adam.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | import optax 4 | from t5x import utils 5 | 6 | from hyper_task_descriptions import utils as hyper_utils 7 | 8 | # WARNING: t5x will log starting from the pretrained model step, 9 | # but optax calls this starting from 0. So ignore the tensorboard 10 | # learning rate logging. 11 | utils.create_learning_rate_scheduler: 12 | factors = 'linear_decay' 13 | base_learning_rate = 1e-4 14 | decay_factor = 0 15 | warmup_steps = 1000 16 | step_offset = 0 # our steps start at 0 no matter what with optax. 
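# hyper_utils.multi_transform (presumably a thin wrapper around
# optax.multi_transform) routes every parameter to one of the three adam
# instances below according to the label ("hyper", "roberta", or "freeze") that
# match_any_optax_trip assigns from its regexes; the labelling logic lives in
# hyper_task_descriptions/utils.py.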
17 | 18 | 19 | # multi optimizer - try to match hyper, then roberta, all else freeze 20 | # hyper = parameter generators 21 | # roberta = hyperencoder 22 | # all else = underlying model 23 | OPTIMIZER = @hyper_utils.multi_transform() 24 | hyper_utils.multi_transform: 25 | transforms = {"hyper": @hyper/optax.adam(), "freeze": @under/optax.adam(), "roberta": @roberta/optax.adam()} 26 | param_labels = @hyper_utils.match_any_optax_trip() # match_any_optax 27 | 28 | hyper/optax.adam: 29 | learning_rate = @hyper/utils.create_learning_rate_scheduler() 30 | weight_decay = 0 31 | 32 | hyper/utils.create_learning_rate_scheduler: 33 | base_learning_rate = 1e-4 34 | decay_factor = 0 35 | 36 | roberta/optax.adam: 37 | learning_rate = @roberta/utils.create_learning_rate_scheduler() 38 | 39 | roberta/utils.create_learning_rate_scheduler: 40 | base_learning_rate = 1e-4 41 | decay_factor = 0 42 | 43 | under/optax.adam: 44 | learning_rate = @under/utils.create_learning_rate_scheduler() 45 | 46 | under/utils.create_learning_rate_scheduler: 47 | base_learning_rate = 1e-4 48 | decay_factor = 0 49 | 50 | hyper_utils.match_any_optax_trip.regexes = [".*hyper/[eap].*"] # select encoder + adapter generators 51 | hyper_utils.match_any_optax_trip.hyper_regexes = [".*hyper.*"] # all other hyper 52 | -------------------------------------------------------------------------------- /gins/pretrain.gin: -------------------------------------------------------------------------------- 1 | # For training T0 (xxl = 11b, xl = 3b). Make sure you have cached p3 first! 2 | from __gin__ import dynamic_registration 3 | 4 | from t5x import models 5 | from t5x import trainer 6 | from t5x import utils 7 | import seqio 8 | from hyper_task_descriptions.c4 import c4_registry 9 | 10 | import __main__ as train_script 11 | 12 | include "t5x/configs/runs/finetune.gin" 13 | include "gins/t0.gin" # This overrides some default config in `t5x/configs/runs/finetune.gin` 14 | include "gins/restore_pretrained.gin" # for loading from checkpoints 15 | 16 | TASK_FEATURE_LENGTHS = {"inputs": 512, "hyper_inputs": 512, "task_names": 1, "targets": 512} 17 | MIXTURE_OR_TASK_NAME = "c4_pretrain" 18 | DROPOUT_RATE = 0.0 19 | 20 | trainer.Trainer.num_microbatches = 16 # 2048 // 16 21 | trainer.Trainer.weight_metrics_computer = @trainer.WeightMetricsComputer() 22 | 23 | # the batch sizes are really big so I increase stats logging 24 | # for more visibility. 25 | train_script.train: 26 | stats_period = 100 # default is eval_period, which is 1000 27 | 28 | 29 | utils.SaveCheckpointConfig: 30 | period = 500 # checkpoint frequency -------------------------------------------------------------------------------- /gins/pretrain_4part.gin: -------------------------------------------------------------------------------- 1 | # For training T0 (xxl = 11b, xl = 3b). Make sure you have cached p3 first! 
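# Identical to pretrain.gin except that the c4_pretrain mixture is registered by
# c4_registry_4part (the four-part variant) rather than c4_registry.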
2 | from __gin__ import dynamic_registration 3 | 4 | from t5x import models 5 | from t5x import trainer 6 | from t5x import utils 7 | import seqio 8 | from hyper_task_descriptions.c4 import c4_registry_4part 9 | 10 | import __main__ as train_script 11 | 12 | include "t5x/configs/runs/finetune.gin" 13 | include "gins/t0.gin" # This overrides some default config in `t5x/configs/runs/finetune.gin` 14 | include "gins/restore_pretrained.gin" # for loading from checkpoints 15 | 16 | TASK_FEATURE_LENGTHS = {"inputs": 512, "hyper_inputs": 512, "task_names": 1, "targets": 512} 17 | MIXTURE_OR_TASK_NAME = "c4_pretrain" 18 | DROPOUT_RATE = 0.0 19 | 20 | trainer.Trainer.num_microbatches = 16 # 2048 // 16 21 | trainer.Trainer.weight_metrics_computer = @trainer.WeightMetricsComputer() 22 | 23 | # the batch sizes are really big so I increase stats logging 24 | # for more visibility. 25 | train_script.train: 26 | stats_period = 100 # default is eval_period, which is 1000 27 | 28 | 29 | utils.SaveCheckpointConfig: 30 | period = 500 # checkpoint frequency -------------------------------------------------------------------------------- /gins/pretrain_6part.gin: -------------------------------------------------------------------------------- 1 | # For training T0 (xxl = 11b, xl = 3b). Make sure you have cached p3 first! 2 | from __gin__ import dynamic_registration 3 | 4 | from t5x import models 5 | from t5x import trainer 6 | from t5x import utils 7 | import seqio 8 | from hyper_task_descriptions.c4 import c4_registry_6part 9 | 10 | import __main__ as train_script 11 | 12 | include "t5x/configs/runs/finetune.gin" 13 | include "gins/t0.gin" # This overrides some default config in `t5x/configs/runs/finetune.gin` 14 | include "gins/restore_pretrained.gin" # for loading from checkpoints 15 | 16 | TASK_FEATURE_LENGTHS = {"inputs": 512, "hyper_inputs": 512, "task_names": 1, "targets": 512} 17 | MIXTURE_OR_TASK_NAME = "c4_pretrain" 18 | DROPOUT_RATE = 0.0 19 | 20 | trainer.Trainer.num_microbatches = 16 # 2048 // 16 21 | trainer.Trainer.weight_metrics_computer = @trainer.WeightMetricsComputer() 22 | 23 | # the batch sizes are really big so I increase stats logging 24 | # for more visibility. 25 | train_script.train: 26 | stats_period = 100 # default is eval_period, which is 1000 27 | 28 | 29 | utils.SaveCheckpointConfig: 30 | period = 500 # checkpoint frequency -------------------------------------------------------------------------------- /gins/restore_frozen_under.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Base model. 2 | from __gin__ import dynamic_registration 3 | 4 | from t5x import utils 5 | 6 | 7 | # These setting allow us to partially reload a checkpoint, that is, we can load 8 | # most of the model weights from the checkpoint, without it complaining that we 9 | # don't have a weight for our prompt in the checkpoint. 10 | utils.RestoreCheckpointConfig: 11 | # Activate the codepath that allow of the merging of the optimizer state as 12 | # specified in the config (with our new parameter) and the optimizer state as 13 | # defined in the checkpoint. 14 | fallback_to_scratch = True 15 | # Use the T5X assignment map to grab values from the checkpoint. Each entry in 16 | # the map is a regular expression that matches some flatten variable in the 17 | # optimizer state as defined in the model created by the config. The second 18 | # value is the corresponding name in optimizer state as defined by the 19 | # checkpoint. 
It supports interpolating capture groups from the initial regex. 20 | # If the second pattern it `None` we skip trying to load this variable from 21 | # the checkpoint. 22 | 23 | # We skip hypernetwork parameters 24 | # any matching regex will not be restored from the checkpoint. 25 | # anything not matching not in the checkpoint will cause an error. 26 | assignment_map = ( 27 | (r"^.*param_states/[ed].*$", None), 28 | ) 29 | -------------------------------------------------------------------------------- /gins/restore_pretrained.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Base model. 2 | from __gin__ import dynamic_registration 3 | 4 | from t5x import utils 5 | 6 | 7 | # These setting allow us to partially reload a checkpoint, that is, we can load 8 | # most of the model weights from the checkpoint, without it complaining that we 9 | # don't have a weight for our prompt in the checkpoint. 10 | utils.RestoreCheckpointConfig: 11 | # Activate the codepath that allow of the merging of the optimizer state as 12 | # specified in the config (with our new parameter) and the optimizer state as 13 | # defined in the checkpoint. 14 | fallback_to_scratch = True 15 | # Use the T5X assignment map to grab values from the checkpoint. Each entry in 16 | # the map is a regular expression that matches some flatten variable in the 17 | # optimizer state as defined in the model created by the config. The second 18 | # value is the corresponding name in optimizer state as defined by the 19 | # checkpoint. It supports interpolating capture groups from the initial regex. 20 | # If the second pattern it `None` we skip trying to load this variable from 21 | # the checkpoint. 22 | 23 | # We skip hypernetwork parameters 24 | # any matching regex will not be restored from the checkpoint. 25 | # anything not matching not in the checkpoint will cause an error. 
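# Concretely: every hypernetwork parameter (anything with "hyper" in its
# flattened name) matches the pattern below and maps to None, so it is
# initialised from scratch via fallback_to_scratch, while the plain T5 weights
# are restored from the checkpoint as usual.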
26 | assignment_map = ( 27 | (r"^.*hyper.*$", None), 28 | #(r"^.*lora_a.*$", None), 29 | #(r"^.*lora_b.*$", None), 30 | ) 31 | -------------------------------------------------------------------------------- /gins/separate_henc.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | from t5x import utils 4 | from hyper_task_descriptions.modeling import hyper_network 5 | 6 | # hypertune: decoder in hypernet 7 | hyper_network.HyperTransformer.config = @hyper_network.HyperT5Config() 8 | hyper_network.HyperT5Config: 9 | share_hnet_encoder = False 10 | 11 | # we restore hypernetwork weights from pretrained model 12 | utils.RestoreCheckpointConfig: 13 | fallback_to_scratch = True 14 | assignment_map = ( 15 | # we map encoder values in the checkpoint to hyperencoder and regular encoder weights 16 | ('(.*)/(hyper/hyper_encoder|encoder)/(.*)', r'\1/encoder/\3'), 17 | # the non-t5 bits of hypernet need to be initialised from scratch 18 | ('.*hyper/[^h].*', None), 19 | ) -------------------------------------------------------------------------------- /gins/t0.gin: -------------------------------------------------------------------------------- 1 | ## T0 overrides to match their training setup 2 | 3 | from __gin__ import dynamic_registration 4 | 5 | import __main__ as train_script 6 | from t5x import models 7 | from t5x import partitioning 8 | from t5x import trainer 9 | from t5x import utils 10 | import seqio 11 | 12 | MIXTURE_OR_TASK_MODULE = "t5.data.mixtures" 13 | TRAIN_STEPS = 1212200 # 1112200 (pretrain) + 100000 (finetune) 14 | DROPOUT_RATE = 0.1 15 | BATCH_SIZE = 2048 16 | 17 | train/utils.DatasetConfig: 18 | batch_size = %BATCH_SIZE 19 | use_cached = %USE_CACHED_TASKS 20 | pack = False # for the hypernet, we need this off. 21 | use_custom_packing_ops = False 22 | 23 | train_eval/utils.DatasetConfig: 24 | batch_size = %BATCH_SIZE 25 | use_cached = %USE_CACHED_TASKS 26 | pack = False 27 | use_custom_packing_ops = False 28 | 29 | train_script.train: 30 | eval_period = 500 31 | eval_steps = 50 32 | random_seed = None 33 | infer_eval_dataset_cfg = None # Prevent to run inference evaluation 34 | train_eval_dataset_cfg = None # Prevent to run evaluation as it seems to OOM me. 35 | 36 | utils.create_learning_rate_scheduler: 37 | base_learning_rate = 0.001 38 | warmup_steps = 10000 # irrelevant 39 | 40 | utils.SaveCheckpointConfig: 41 | period = 1000 # checkpoint frequency 42 | keep = None # only keep one checkpoint 43 | 44 | utils.RestoreCheckpointConfig: 45 | strict = True 46 | dtype = 'bfloat16' 47 | 48 | partitioning.PjitPartitioner.num_partitions = 2 49 | 50 | trainer.Trainer.num_microbatches = 32 # 2048 // 32 51 | 52 | seqio.Evaluator.use_memory_cache = False 53 | -------------------------------------------------------------------------------- /gins/t0_eval.gin: -------------------------------------------------------------------------------- 1 | # Defaults for eval.py. 2 | # 3 | # 4 | # You must also include a binding for MODEL. 5 | # 6 | # Required to be set: 7 | # 8 | # - CHECKPOINT_PATH: The model checkpoint to evaluate 9 | # - EVAL_OUTPUT_DIR: The dir to write results to. 
10 | # 11 | # 12 | # Commonly overridden options: 13 | # 14 | # - DatasetConfig.split 15 | # - DatasetConfig.batch_size 16 | from __gin__ import dynamic_registration 17 | 18 | import __main__ as eval_script 19 | from t5x import partitioning 20 | from t5x import utils 21 | 22 | import seqio 23 | from seqio import loggers 24 | from hyper_task_descriptions.seqio_tasks import all_t0_tasks # Needed to define the t0 eval mixtures 25 | # from hyper_task_descriptions.seqio_tasks import my_t0_tasks 26 | 27 | 28 | # Must be overridden 29 | MIXTURE_OR_TASK_NAME = "t0_eval_score_eval" 30 | CHECKPOINT_PATH = %gin.REQUIRED 31 | EVAL_OUTPUT_DIR = %gin.REQUIRED 32 | TASK_FEATURE_LENGTHS = {"inputs": 1024, "hyper_inputs": 512, "task_names": 1, "targets": 256} 33 | DROPOUT_RATE = 0.0 34 | 35 | # DEPRECATED: Import the this module in your gin file. 36 | MIXTURE_OR_TASK_MODULE = None 37 | 38 | eval_script.evaluate: 39 | model = %MODEL # imported from separate gin file 40 | dataset_cfg = @utils.DatasetConfig() 41 | partitioner = @partitioning.PjitPartitioner() 42 | restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() 43 | output_dir = %EVAL_OUTPUT_DIR 44 | inference_evaluator_cls = @seqio.Evaluator 45 | 46 | seqio.Evaluator.logger_cls = [@loggers.JSONLogger, @seqio.TensorBoardLogger] 47 | 48 | partitioning.PjitPartitioner.num_partitions = 2 49 | 50 | utils.DatasetConfig: 51 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME 52 | task_feature_lengths = %TASK_FEATURE_LENGTHS 53 | split = 'validation' 54 | batch_size = 256 55 | shuffle = False 56 | seed = 42 57 | use_cached = %USE_CACHED_TASKS 58 | pack = False 59 | use_custom_packing_ops = False 60 | module = %MIXTURE_OR_TASK_MODULE 61 | 62 | utils.RestoreCheckpointConfig: 63 | path = %CHECKPOINT_PATH 64 | mode = 'specific' 65 | dtype = 'float32' 66 | strict = True # make sure we actually load everything! 67 | -------------------------------------------------------------------------------- /gins/t0_train.gin: -------------------------------------------------------------------------------- 1 | # For training T0 (xxl = 11b, xl = 3b). Make sure you have cached p3 first! 2 | from __gin__ import dynamic_registration 3 | 4 | from t5x import models 5 | from t5x import trainer 6 | from t5x import utils 7 | import seqio 8 | from hyper_task_descriptions.seqio_tasks import all_t0_tasks 9 | # from hyper_task_descriptions.seqio_tasks import my_t0_tasks 10 | 11 | import __main__ as train_script 12 | 13 | include "t5x/configs/runs/finetune.gin" 14 | include "gins/t0.gin" # This overrides some default config in `t5x/configs/runs/finetune.gin` 15 | include "gins/restore_pretrained.gin" # for loading from checkpoints 16 | 17 | TASK_FEATURE_LENGTHS = {"inputs": 1024, "hyper_inputs": 512, "task_names": 1, "targets": 256} 18 | MIXTURE_OR_TASK_NAME = "t0_train" 19 | 20 | trainer.Trainer.num_microbatches = 16 # 2048 // 16 21 | trainer.Trainer.weight_metrics_computer = @trainer.WeightMetricsComputer() 22 | 23 | # the batch sizes are really big so I increase stats logging 24 | # for more visibility. 25 | train_script.train: 26 | stats_period = 100 # default is eval_period, which is 1000 27 | -------------------------------------------------------------------------------- /gins/t0_train_local.gin: -------------------------------------------------------------------------------- 1 | # For training T0 (xxl = 11b, xl = 3b). Make sure you have cached p3 first! 
2 | from __gin__ import dynamic_registration 3 | 4 | from t5x import models 5 | from t5x import trainer 6 | from t5x import utils 7 | import seqio 8 | from hyper_task_descriptions.seqio_tasks import small_t0_tasks 9 | 10 | import __main__ as train_script 11 | 12 | #include "t5x/configs/runs/finetune.gin" 13 | include "gins/finetune_from_scratch.gin" 14 | include "gins/t0.gin" # This overrides some default config in `t5x/configs/runs/finetune.gin` 15 | include "gins/restore_pretrained.gin" # for loading from checkpoints 16 | 17 | TASK_FEATURE_LENGTHS = {"inputs": 64, "hyper_inputs": 128, "task_names": 1, "targets": 32} 18 | MIXTURE_OR_TASK_NAME = "t0_small_train" 19 | 20 | trainer.Trainer.num_microbatches = 16 # 2048 // 16 21 | trainer.Trainer.weight_metrics_computer = @trainer.WeightMetricsComputer() 22 | 23 | # the batch sizes are really big so I increase stats logging 24 | # for more visibility. 25 | train_script.train: 26 | stats_period = 100 # default is eval_period, which is 1000 27 | -------------------------------------------------------------------------------- /gins/t0_train_local_copy.gin: -------------------------------------------------------------------------------- 1 | # For training T0 (xxl = 11b, xl = 3b). Make sure you have cached p3 first! 2 | from __gin__ import dynamic_registration 3 | 4 | from t5x import models 5 | from t5x import trainer 6 | from t5x import utils 7 | import seqio 8 | from hyper_task_descriptions.numeric_task import numeric_registry 9 | 10 | import __main__ as train_script 11 | 12 | #include "t5x/configs/runs/finetune.gin" 13 | include "gins/finetune_from_scratch.gin" 14 | include "gins/t0.gin" # This overrides some default config in `t5x/configs/runs/finetune.gin` 15 | include "gins/restore_pretrained.gin" # for loading from checkpoints 16 | 17 | TASK_FEATURE_LENGTHS = {"inputs": 2, "hyper_inputs": 4, "task_names": 1, "targets": 3} 18 | MIXTURE_OR_TASK_NAME = "copy_task" 19 | 20 | trainer.Trainer.num_microbatches = 16 # 2048 // 16 21 | trainer.Trainer.weight_metrics_computer = @trainer.WeightMetricsComputer() 22 | 23 | # the batch sizes are really big so I increase stats logging 24 | # for more visibility. 25 | train_script.train: 26 | stats_period = 1 # default is eval_period, which is 1000 27 | -------------------------------------------------------------------------------- /gins/train_only_hnet.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Base model. 2 | from __gin__ import dynamic_registration 3 | 4 | from t5x import adafactor 5 | from t5x import optimizers 6 | from hyper_task_descriptions import utils as hyper_utils 7 | from flax import traverse_util 8 | 9 | # gin that allows partial training based on regex matching. 10 | 11 | # ------------------- Partial loading ------------------------------------------------ 12 | OPTIMIZER = @optimizers.MultiOptimizer() 13 | # note you can add more traversals if you want different optimizer settings 14 | # for dfferent parts of the model. 15 | # See https://github.com/google-research/t5x/blob/main/docs/usage/gin.md#scoping 16 | # for how to create multiple specialised instances of the same class. 17 | optimizers.MultiOptimizer: 18 | traversals_and_optimizers = ((@hyper/traverse_util.ModelParamTraversal(), 19 | @hyper/adafactor.Adafactor()),) 20 | 21 | # MultiOptimizer will match any parameter with a flattened name that 22 | # matches *any* of the regular expressions in the list. 
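# A rough sketch (not used in this config) of adding a second traversal so the rest of
# the model gets its own optimizer settings; the `rest/` scope is hypothetical and
# simply mirrors the `hyper/` wiring below:
#   optimizers.MultiOptimizer.traversals_and_optimizers = (
#       (@hyper/traverse_util.ModelParamTraversal(), @hyper/adafactor.Adafactor()),
#       (@rest/traverse_util.ModelParamTraversal(), @rest/adafactor.Adafactor()))
#   rest/traverse_util.ModelParamTraversal.filter_fn = @rest/hyper_utils.match_any()
#   rest/hyper_utils.match_any.regexes = ["^(?!.*hyper).*$"]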
23 | # hyper - we turn off param scaling, offset the step 24 | hyper/adafactor.Adafactor: 25 | step_offset = 1100000 26 | hyper/traverse_util.ModelParamTraversal: 27 | filter_fn = @hyper/hyper_utils.match_any() 28 | hyper/hyper_utils.match_any.regexes = [".*hyper.*"] 29 | -------------------------------------------------------------------------------- /hyper_task_descriptions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/hyper-task-descriptions/fd86a43ac2131582548130d86d57ea977c804ab6/hyper_task_descriptions/__init__.py -------------------------------------------------------------------------------- /hyper_task_descriptions/c4/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/hyper-task-descriptions/fd86a43ac2131582548130d86d57ea977c804ab6/hyper_task_descriptions/c4/__init__.py -------------------------------------------------------------------------------- /hyper_task_descriptions/c4/c4_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5 Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # modified prefix lm pretraining, using 3-way split. 16 | 17 | import functools 18 | import os 19 | 20 | import seqio 21 | import tensorflow as tf 22 | from t5.data import preprocessors 23 | 24 | from hyper_task_descriptions.hf_vocab import HuggingfaceVocabulary 25 | from hyper_task_descriptions.utils import GOOGLE_BUCKET_PATH 26 | 27 | TaskRegistry = seqio.TaskRegistry 28 | 29 | seqio.add_global_cache_dirs([f"{GOOGLE_BUCKET_PATH}/hyper-task-descriptions/data/c4_pretrain_data"]) 30 | 31 | t5_vocab = HuggingfaceVocabulary("t5-base") 32 | 33 | words_path = os.path.join( 34 | os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "numeric_task", "words.txt" 35 | ) 36 | 37 | words = [line.strip() for line in open(words_path, "r").readlines()] 38 | 39 | 40 | def pack_prefix_lm_encoder_decoder_random_inputs(ds, sequence_length, pad_id=0): 41 | """Setup example for prefix lm. 
no packing becuz im lazy""" 42 | 43 | @seqio.utils.map_over_dataset(num_seeds=2) 44 | def create_example(example, seeds): 45 | split_point_1 = tf.random.stateless_uniform( 46 | (), minval=1, maxval=example["targets"].shape[0] - 2, seed=seeds[0], dtype=tf.int32 47 | ) 48 | split_point_2 = tf.random.stateless_uniform( 49 | (), 50 | minval=split_point_1, 51 | maxval=example["targets"].shape[0] - 2, 52 | seed=seeds[0], 53 | dtype=tf.int32, 54 | ) 55 | hyper_inputs = example["targets"][:split_point_1] 56 | inputs = example["targets"][split_point_1:split_point_2] 57 | targets = example["targets"][split_point_2:] 58 | 59 | # inputs = t5_vocab._encode_tf(random.choice(words)) 60 | # We want the length _after_ tokenization to be sequence_length['inputs'] 61 | # inputs = t5_vocab._encode_tf(' '.join(random.choices(words, k=sequence_length['inputs'] // 4))) 62 | return { 63 | "inputs": inputs, 64 | "hyper_inputs": hyper_inputs, 65 | "targets": targets, 66 | "task_names": [0], 67 | } 68 | 69 | return create_example(ds) 70 | 71 | 72 | # only compatible when we use the T5 encoder as our hypernetwork 73 | seqio.TaskRegistry.add( 74 | "c4_pretrain", 75 | source=seqio.TfdsDataSource(tfds_name="c4/en:3.1.0", splits=["train", "validation"]), 76 | preprocessors=[ 77 | functools.partial( 78 | preprocessors.rekey, 79 | key_map={ 80 | "inputs": None, 81 | "targets": "text", 82 | # "roberta_targets": "text" # for roberta vocab 83 | }, 84 | ), 85 | seqio.preprocessors.tokenize, 86 | seqio.CacheDatasetPlaceholder(), 87 | preprocessors.targets_for_prefix_lm_objective, 88 | pack_prefix_lm_encoder_decoder_random_inputs, 89 | seqio.preprocessors.append_eos, 90 | ], 91 | output_features={ 92 | "inputs": seqio.Feature(vocabulary=t5_vocab, add_eos=True), 93 | "targets": seqio.Feature(vocabulary=t5_vocab, add_eos=True), 94 | "hyper_inputs": seqio.Feature(vocabulary=t5_vocab, add_eos=True), 95 | "task_names": seqio.Feature(seqio.PassThroughVocabulary(1), add_eos=False, dtype=tf.int32), 96 | }, 97 | metric_fns=[], 98 | ) 99 | -------------------------------------------------------------------------------- /hyper_task_descriptions/c4/c4_registry_4part.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5 Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # modified prefix lm pretraining, using 3-way split. 
16 | 17 | import functools 18 | import os 19 | 20 | import seqio 21 | import tensorflow as tf 22 | from t5.data import preprocessors 23 | 24 | from hyper_task_descriptions.hf_vocab import HuggingfaceVocabulary 25 | from hyper_task_descriptions.utils import GOOGLE_BUCKET_PATH 26 | 27 | TaskRegistry = seqio.TaskRegistry 28 | 29 | seqio.add_global_cache_dirs([f"{GOOGLE_BUCKET_PATH}/hyper-task-descriptions/data/c4_pretrain_data"]) 30 | 31 | t5_vocab = HuggingfaceVocabulary("t5-base") 32 | 33 | words_path = os.path.join( 34 | os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "numeric_task", "words.txt" 35 | ) 36 | 37 | words = [line.strip() for line in open(words_path, "r").readlines()] 38 | 39 | 40 | def pack_hypertune(ds, sequence_length, pad_id=0): 41 | """Setup example for prefix lm. no packing becuz im lazy""" 42 | 43 | @seqio.utils.map_over_dataset(num_seeds=3) 44 | def create_example(example, seeds): 45 | # we split the target into 4 parts: 46 | # a -> hypernet 47 | # b (short) -> encoder 48 | # c (short) -> decoder 49 | # d -> hypernet again. 50 | # currently using random lengths 51 | split_point_1 = tf.random.stateless_uniform( 52 | (), minval=1, maxval=example["targets"].shape[0] - 4, seed=seeds[0], dtype=tf.int32 53 | ) 54 | split_point_2 = tf.random.stateless_uniform( 55 | (), 56 | minval=split_point_1, 57 | maxval=example["targets"].shape[0] - 3, 58 | seed=seeds[1], 59 | dtype=tf.int32, 60 | ) 61 | split_point_3 = tf.random.stateless_uniform( 62 | (), 63 | minval=split_point_2, 64 | maxval=example["targets"].shape[0] - 2, 65 | seed=seeds[2], 66 | dtype=tf.int32, 67 | ) 68 | # '1' as eos to mark end of first part of input 69 | hyper_inputs = tf.concat( 70 | [example["targets"][:split_point_1], [1], example["targets"][split_point_3:]], axis=0 71 | ) 72 | inputs = example["targets"][split_point_1:split_point_2] 73 | targets = example["targets"][split_point_2:split_point_3] 74 | 75 | return { 76 | "inputs": inputs, 77 | "hyper_inputs": hyper_inputs, 78 | "targets": targets, 79 | "task_names": [0], 80 | } 81 | 82 | return create_example(ds) 83 | 84 | 85 | seqio.TaskRegistry.add( 86 | "c4_pretrain", 87 | source=seqio.TfdsDataSource(tfds_name="c4/en:3.1.0", splits=["train", "validation"]), 88 | preprocessors=[ 89 | functools.partial( 90 | preprocessors.rekey, 91 | key_map={ 92 | "inputs": None, 93 | "targets": "text", 94 | # "roberta_targets": "text" # for roberta vocab 95 | }, 96 | ), 97 | seqio.preprocessors.tokenize, 98 | seqio.CacheDatasetPlaceholder(), 99 | preprocessors.targets_for_prefix_lm_objective, 100 | pack_hypertune, 101 | seqio.preprocessors.append_eos, 102 | ], 103 | output_features={ 104 | "inputs": seqio.Feature(vocabulary=t5_vocab, add_eos=True), 105 | "targets": seqio.Feature(vocabulary=t5_vocab, add_eos=True), 106 | "hyper_inputs": seqio.Feature(vocabulary=t5_vocab, add_eos=True), 107 | "task_names": seqio.Feature(seqio.PassThroughVocabulary(1), add_eos=False, dtype=tf.int32), 108 | }, 109 | metric_fns=[], 110 | ) 111 | -------------------------------------------------------------------------------- /hyper_task_descriptions/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/hyper-task-descriptions/fd86a43ac2131582548130d86d57ea977c804ab6/hyper_task_descriptions/common/__init__.py -------------------------------------------------------------------------------- /hyper_task_descriptions/hf_vocab.py: 
-------------------------------------------------------------------------------- 1 | from typing import Iterable, Optional, Sequence 2 | 3 | import tensorflow.compat.v2 as tf 4 | from seqio import Vocabulary 5 | from transformers import AutoTokenizer 6 | 7 | 8 | class HuggingfaceVocabulary(Vocabulary): 9 | """Really simple wrapper around huggingface tokenizer.""" 10 | 11 | def __init__(self, model_name: str, extra_ids: int = 0, add_special_tokens: bool = False): 12 | """Vocabulary constructor. 13 | Args: 14 | extra_ids: The number of extra IDs to reserve. 15 | """ 16 | self._tokenizer = None # lazy load tokenizer 17 | self.model_name = model_name 18 | self._extra_ids = extra_ids or 0 19 | assert self._extra_ids == 0 20 | self._add_special_tokens = add_special_tokens 21 | super().__init__(extra_ids=extra_ids) 22 | 23 | def _load_model(self): 24 | self._tokenizer = AutoTokenizer.from_pretrained(self.model_name) 25 | 26 | @property 27 | def tokenizer(self): 28 | if self._tokenizer is None: 29 | self._load_model() 30 | return self._tokenizer 31 | 32 | @property 33 | def eos_id(self) -> Optional[int]: 34 | return self.tokenizer.eos_token_id 35 | 36 | @property 37 | def pad_id(self) -> int: 38 | return self.tokenizer.pad_token_id 39 | 40 | @property 41 | def unk_id(self) -> Optional[int]: 42 | return self.tokenizer.unk_token_id 43 | 44 | @property 45 | def _base_vocab_size(self) -> int: 46 | """Vocabulary size, excluding extra ids but including PAD/EOS/UNK.""" 47 | return self.tokenizer.vocab_size 48 | 49 | def _encode(self, s: str) -> Sequence[int]: 50 | return self.tokenizer(s, add_special_tokens=self._add_special_tokens)["input_ids"] 51 | 52 | def _decode(self, ids): 53 | return self.tokenizer.decode(ids, skip_special_tokens=True) 54 | 55 | def decode(self, ids: Iterable[int]): 56 | """Detokenizes int32 iterable to a string, up through first EOS.""" 57 | clean_ids = list(ids) 58 | 59 | if self.unk_id is not None: 60 | vocab_size = self._base_vocab_size 61 | clean_ids = [self.unk_id if i >= vocab_size else i for i in clean_ids] 62 | 63 | if self.eos_id is not None and self.eos_id in clean_ids: 64 | clean_ids = clean_ids[: clean_ids.index(self.eos_id) + 1] 65 | 66 | return self._decode(clean_ids) 67 | 68 | def _encode_tf(self, s: tf.Tensor) -> tf.Tensor: 69 | def enc(s): 70 | r = self.tokenizer( 71 | s.numpy().decode("utf-8"), 72 | return_tensors="tf", 73 | add_special_tokens=self._add_special_tokens, 74 | )["input_ids"] 75 | return tf.cast(r, tf.int32) 76 | 77 | # we reshape to ensure that we get a 1-dimensional tensor. 78 | return tf.reshape(tf.py_function(enc, [s], Tout=tf.int32), [-1]) 79 | 80 | def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor: 81 | return tf.constant(self.tokenizer.decode(ids, skip_special_tokens=True)) 82 | 83 | def __eq__(self, other): 84 | # this is an overly simple implementation of __eq__, but should be okay. 85 | # if not isinstance(other, HuggingfaceVocabulary): 86 | # return False 87 | # try: 88 | # their_model_name = other.model_name 89 | # except AttributeError: 90 | # return False 91 | # return self.model_name == their_model_name 92 | # hack! 93 | return True 94 | -------------------------------------------------------------------------------- /hyper_task_descriptions/learning_rate_adafactor.py: -------------------------------------------------------------------------------- 1 | """ 2 | T5X default setup does not respect the learning rate set in its arguments. 3 | This optimizer version does. 
4 | """ 5 | from t5x import adafactor 6 | 7 | 8 | class Adafactor(adafactor.Adafactor): 9 | def apply_param_gradient(self, step, hyper_params, param, state, grad, path): 10 | # must use replace function as hyper_params is a struct.dataclass and frozen 11 | hyper_params = hyper_params.replace(learning_rate=self.hyper_params.learning_rate) 12 | return super().apply_param_gradient(step, hyper_params, param, state, grad, path) 13 | 14 | def apply_gradient(self, hyper_params, params, state, grads): 15 | # must use replace function as hyper_params is a struct.dataclass and frozen 16 | hyper_params = hyper_params.replace(learning_rate=self.hyper_params.learning_rate) 17 | return super().apply_gradient(hyper_params, params, state, grads) 18 | -------------------------------------------------------------------------------- /hyper_task_descriptions/modeling/lora_partitioning.py: -------------------------------------------------------------------------------- 1 | """ 2 | Partition map for lora weights 3 | See https://github.com/google-research/t5x/blob/main/docs/usage/partitioning.md 4 | for details on how this works. 5 | """ 6 | 7 | lora_axes_names_override = [ 8 | ( 9 | r"(encoder|decoder)/layers_\d+/(self_attention|attention|encoder_decoder_attention)" 10 | "/(query|key|value)/lora_a", 11 | ("embed", "joined_kv"), 12 | ), 13 | ( 14 | r"(encoder|decoder)/layers_\d+/(self_attention|attention|encoder_decoder_attention)" 15 | "/(query|key|value)/lora_b", 16 | ("embed", "joined_kv"), 17 | ), 18 | ( 19 | r"(encoder|decoder)/layers_\d+/(self_attention|attention|encoder_decoder_attention)" 20 | "/out/lora_a", 21 | ("joined_kv", "embed"), 22 | ), 23 | ( 24 | r"(encoder|decoder)/layers_\d+/(self_attention|attention|encoder_decoder_attention)" 25 | "/out/lora_b", 26 | ("joined_kv", "embed"), 27 | ), 28 | ] 29 | -------------------------------------------------------------------------------- /hyper_task_descriptions/modeling/losses.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | 7 | def cosine_similarity_loss( 8 | pred_vectors: jnp.ndarray, 9 | target_vectors: jnp.ndarray, 10 | ground_truth_similarity: jnp.ndarray, 11 | ) -> jnp.ndarray: 12 | cosine_sim = jax.vmap(cosine_similarity_one_to_many, in_axes=[0, None])( 13 | pred_vectors, target_vectors 14 | ) 15 | # cosine_sim = mask * cosine_sim 16 | loss = jnp.square(cosine_sim - ground_truth_similarity).mean() 17 | return loss 18 | 19 | 20 | def cosine_similarity_one_to_many( 21 | pred_vector: jnp.ndarray, 22 | target_vectors: jnp.ndarray, 23 | ) -> jnp.ndarray: 24 | cosine_sim = cosine_similarity( 25 | pred_vector[ 26 | None, 27 | ], 28 | target_vectors, 29 | ) 30 | return cosine_sim 31 | 32 | 33 | def cosine_similarity( 34 | predictions: jnp.ndarray, 35 | targets: jnp.ndarray, 36 | epsilon: float = 1e-8, 37 | ) -> jnp.ndarray: 38 | """ 39 | Computes the cosine similarity between targets and predictions. 40 | Adapted from optax: https://github.com/deepmind/optax/blob/master/optax/_src/loss.py 41 | eps default adjusted to 1e-8 to match pytorch. 42 | """ 43 | # vectorize norm fn, to treat all dimensions except the last as batch dims. 44 | batched_norm_fn = jnp.vectorize(safe_norm, signature="(k)->()", excluded={1}) 45 | # normalise the last dimension of targets and predictions. 
46 | unit_targets = targets / jnp.expand_dims(batched_norm_fn(targets, epsilon), axis=-1) 47 | unit_predictions = predictions / jnp.expand_dims(batched_norm_fn(predictions, epsilon), axis=-1) 48 | # return cosine similarity. 49 | return jnp.sum(unit_targets * unit_predictions, axis=-1) 50 | 51 | 52 | # taken whole-cloth from optax. 53 | # https://github.com/deepmind/optax/blob/master/optax/_src/numerics.py#L48 54 | def safe_norm( 55 | x: jnp.ndarray, 56 | min_norm: float, 57 | ord: Optional[Union[int, float, str]] = None, # pylint: disable=redefined-builtin 58 | axis: Union[None, Tuple[int, ...], int] = None, 59 | keepdims: bool = False, 60 | ) -> jnp.ndarray: 61 | """Returns jnp.maximum(jnp.linalg.norm(x), min_norm) with correct gradients. 62 | The gradients of `jnp.maximum(jnp.linalg.norm(x), min_norm)` at 0.0 is `NaN`, 63 | because jax will evaluate both branches of the `jnp.maximum`. This function 64 | will instead return the correct gradient of 0.0 also in such setting. 65 | Args: 66 | x: jax array. 67 | min_norm: lower bound for the returned norm. 68 | ord: {non-zero int, inf, -inf, optional. Order of the norm. 69 | inf means numpys inf object. The default is None. 70 | axis: {None, int, 2-tuple of ints}, optional. If axis is an integer, it 71 | specifies the axis of x along which to compute the vector norms. If axis 72 | is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix 73 | norms of these matrices are computed. If axis is None then either a vector 74 | norm (when x is 1-D) or a matrix norm (when x is 2-D) is returned. The 75 | default is None. 76 | keepdims: bool, optional. If this is set to True, the axes which are normed 77 | over are left in the result as dimensions with size one. With this option 78 | the result will broadcast correctly against the original x. 79 | Returns: 80 | The safe norm of the input vector, accounting for correct gradient. 81 | """ 82 | norm = jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=True) 83 | x = jnp.where(norm <= min_norm, jnp.ones_like(x), x) 84 | norm = jnp.squeeze(norm, axis=axis) if not keepdims else norm 85 | masked_norm = jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) 86 | return jnp.where(norm <= min_norm, min_norm, masked_norm) 87 | -------------------------------------------------------------------------------- /hyper_task_descriptions/modeling/roberta_partitioning.py: -------------------------------------------------------------------------------- 1 | """ 2 | Partition map for roberta-base 3 | See https://github.com/google-research/t5x/blob/main/docs/usage/partitioning.md 4 | for details on how this works. 
5 | """ 6 | 7 | roberta_axes_names_override = [ 8 | (r"hyper/encoder/embeddings/[\w_]+/embedding", ("vocab", "embed")), 9 | # attention 10 | ( 11 | r"hyper/encoder/encoder/layer/\d+/attention/self/(query|key|value)/kernel", 12 | ("embed", "joined_kv"), 13 | ), 14 | (r"hyper/encoder/encoder/layer/\d+/attention/self/(query|key|value)/bias", ("joined_kv",)), 15 | (r"hyper/encoder/encoder/layer/\d+/attention/output/dense/kernel", ("joined_kv", "embed")), 16 | (r"hyper/encoder/encoder/layer/\d+/attention/output/dense/bias", ("embed",)), 17 | # intermediate 18 | (r"hyper/encoder/encoder/layer/\d+/intermediate/dense/kernel", ("embed", "mlp")), 19 | (r"hyper/encoder/encoder/layer/\d+/intermediate/dense/bias", ("mlp",)), 20 | # output 21 | (r"hyper/encoder/encoder/layer/\d+/output/dense/kernel", ("mlp", "embed")), 22 | (r"hyper/encoder/encoder/layer/\d+/output/dense/bias", ("embed",)), 23 | # layer norms 24 | (r"hyper/encoder/encoder/layer/\d+/[\w_\/]+/LayerNorm/bias", ("embed",)), 25 | (r"hyper/encoder/encoder/layer/\d+/[\w_\/]+/LayerNorm/scale", ("embed",)), 26 | (r"hyper/encoder/embeddings/LayerNorm/bias", ("embed",)), 27 | (r"hyper/encoder/embeddings/LayerNorm/scale", ("embed",)), 28 | # pooler 29 | (r"hyper/encoder/pooler/dense/kernel", ("embed", "mlp")), 30 | (r"hyper/encoder/pooler/dense/bias", ("embed",)), 31 | ] 32 | -------------------------------------------------------------------------------- /hyper_task_descriptions/modeling/t5_partitioning.py: -------------------------------------------------------------------------------- 1 | """ 2 | Partition map for t5 v1.1 3 | See https://github.com/google-research/t5x/blob/main/docs/usage/partitioning.md 4 | for details on how this works. 5 | """ 6 | 7 | t5_axes_names_override = [ 8 | (r"hyper/encoder/shared/embedding", ("vocab", "embed")), 9 | ( 10 | r"hyper/encoder/encoder/block/\d+/layer/0/SelfAttention/(q|k|v)/kernel", 11 | ("embed", "joined_kv"), 12 | ), 13 | (r"hyper/encoder/encoder/block/\d+/layer/0/SelfAttention/o/kernel", ("joined_kv", "embed")), 14 | ( 15 | r"hyper/encoder/encoder/block/\d+/layer/0/SelfAttention/relative_attention_bias/embedding", 16 | ("heads", "relpos_buckets"), 17 | ), 18 | (r"hyper/encoder/encoder/block/\d+/layer/0/layer_norm/weight", ("embed",)), 19 | (r"hyper/encoder/encoder/block/\d+/layer/1/DenseReluDense/wi_0/kernel", ("embed", "mlp")), 20 | (r"hyper/encoder/encoder/block/\d+/layer/1/DenseReluDense/wi_1/kernel", ("embed", "mlp")), 21 | (r"hyper/encoder/encoder/block/\d+/layer/1/DenseReluDense/wo/kernel", ("mlp", "embed")), 22 | (r"hyper/encoder/encoder/block/\d+/layer/1/layer_norm/weight", ("embed",)), 23 | (r"hyper/encoder/encoder/final_layer_norm/weight", ("embed",)), 24 | ] 25 | -------------------------------------------------------------------------------- /hyper_task_descriptions/ni_tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/hyper-task-descriptions/fd86a43ac2131582548130d86d57ea977c804ab6/hyper_task_descriptions/ni_tasks/__init__.py -------------------------------------------------------------------------------- /hyper_task_descriptions/python_scripts/lora_params.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | 4 | def compute_num_lora_params( 5 | num_encoder_layers: int, 6 | num_decoder_layers: int, 7 | emb_dim: int, 8 | lora_ranks: Tuple, 9 | num_heads: int, 10 | head_dim: int, 11 | ) -> int: 12 | total_layers = num_encoder_layers + 
2 * num_decoder_layers 13 | 14 | q_rank = lora_ranks[0] or 0 15 | q_A = total_layers * emb_dim * q_rank 16 | q_B = total_layers * q_rank * num_heads * head_dim 17 | 18 | k_rank = lora_ranks[1] or 0 19 | k_A = total_layers * emb_dim * k_rank 20 | k_B = total_layers * k_rank * num_heads * head_dim 21 | 22 | v_rank = lora_ranks[2] or 0 23 | v_A = total_layers * emb_dim * v_rank 24 | v_B = total_layers * v_rank * num_heads * head_dim 25 | 26 | o_rank = lora_ranks[3] or 0 27 | o_A = total_layers * num_heads * head_dim * o_rank 28 | o_B = total_layers * o_rank * emb_dim 29 | 30 | lora_A_params = q_A + k_A + v_A + o_A 31 | lora_B_params = q_B + k_B + v_B + o_B 32 | 33 | print(f"Total lora A parameters: {lora_A_params}") 34 | print(f"Total lora B parameters: {lora_B_params}") 35 | print(f"Total lora parameters: {lora_A_params + lora_B_params}") 36 | 37 | return lora_A_params + lora_B_params 38 | 39 | 40 | if __name__ == "__main__": 41 | pass 42 | -------------------------------------------------------------------------------- /hyper_task_descriptions/python_scripts/make_p3_boxplot.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import matplotlib.pyplot as plt 4 | 5 | with open(sys.argv[1], "r") as f: 6 | data = [x.strip().split("\t") for x in f.readlines()][1:] 7 | 8 | tasks = sorted(list(set([x[0] for x in data]))) 9 | 10 | 11 | def get_model_data(model_idx, data): 12 | task_values = {} 13 | for line in data: 14 | task = line[0] 15 | if task not in task_values: 16 | task_values[task] = [] 17 | if model_idx < len(line): 18 | task_values[task].append(float(line[model_idx])) 19 | else: 20 | task_values[task].append(0) 21 | return task_values 22 | 23 | 24 | my_t0_3b = get_model_data(2, data) 25 | original_t0_3b = get_model_data(3, data) 26 | scratch_t0_3b = get_model_data(4, data) 27 | 28 | original_t0_11b = get_model_data(6, data) 29 | scratch_t0_11b = get_model_data(7, data) 30 | 31 | 32 | fig, axs = plt.subplots(2, 6, figsize=(20, 10)) 33 | 34 | models_to_eval = [original_t0_3b, scratch_t0_3b, original_t0_11b, scratch_t0_11b] 35 | model_names = ["T0 3B", "T03Bp", "T011B", "T011Bp"] 36 | print(tasks) 37 | for i, task in enumerate(tasks): 38 | vert_idx = i // 6 39 | horiz_idx = i % 6 40 | axs[vert_idx, horiz_idx].set_title(task if task else "avg") 41 | axs[vert_idx, horiz_idx].set_xticklabels(model_names) 42 | axs[vert_idx, horiz_idx].boxplot([x[task] for x in models_to_eval if task in x]) 43 | for j, model in enumerate(models_to_eval): 44 | axs[vert_idx, horiz_idx].scatter([j + 1] * len(model[task]), model[task]) 45 | plt.tight_layout() 46 | plt.savefig("boxplot.png") 47 | # plt.show() 48 | -------------------------------------------------------------------------------- /hyper_task_descriptions/python_scripts/p3_results.py: -------------------------------------------------------------------------------- 1 | """ 2 | Point this at a bucket folder with results as formatted by t5x to get min/max/avg scores per prompt. 3 | TODO: also output per-prompt scores in easy-to-paste format. 
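Example invocation (the results folder name is hypothetical; it is treated as a prefix
inside the hard-coded `hamishi-us-bucket` bucket):
    python hyper_task_descriptions/python_scripts/p3_results.py -f my_experiment/eval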
4 | """ 5 | import argparse 6 | import json 7 | from collections import defaultdict 8 | from pathlib import Path 9 | from typing import Dict, Tuple 10 | 11 | from google.cloud import storage 12 | from tqdm import tqdm 13 | 14 | client = storage.Client(project="ai2-tpu") 15 | bucket = client.bucket("hamishi-us-bucket") 16 | 17 | 18 | # function from: 19 | # https://github.com/bigscience-workshop/architecture-objective/blob/main/bigscience/eval-spreadsheet/parse_promptsource.py 20 | def process_task_prompt(task_prompt: str) -> Tuple[str, str]: 21 | task_prompt = task_prompt[:-11] # Remove 'score_eval' string at the end 22 | 23 | task, prompt = None, None 24 | if "anli" in task_prompt: 25 | task = "anli" + task_prompt[-3:] 26 | prompt = task_prompt[5:-3] 27 | elif "hellaswag" in task_prompt: 28 | task = "hellaswag" 29 | prompt = task_prompt[10:] 30 | elif "story_cloze" in task_prompt: 31 | task = "story_cloze" 32 | prompt = task_prompt[17:] 33 | elif "super_glue" in task_prompt: 34 | if "cb" in task_prompt: 35 | task = "cb" 36 | prompt = task_prompt[14:] 37 | elif "copa" in task_prompt: 38 | task = "copa" 39 | prompt = task_prompt[16:] 40 | elif "rte" in task_prompt: 41 | task = "rte" 42 | prompt = task_prompt[15:] 43 | elif "wic" in task_prompt: 44 | task = "wic" 45 | prompt = task_prompt[15:] 46 | elif "wsc" in task_prompt: 47 | task = "wsc" 48 | prompt = task_prompt[15:] 49 | elif "winogrande" in task_prompt: 50 | task = "winogrande" 51 | prompt = task_prompt[25:] 52 | 53 | if task is None or prompt is None: 54 | raise ValueError(f"Failed to parse task/prompt: {task_prompt}") 55 | 56 | return task, prompt 57 | 58 | 59 | def process_ps_results(folder: str) -> Dict[str, Dict[str, float]]: 60 | accuracies: Dict[str, Dict[str, float]] = defaultdict(dict) 61 | for blob in tqdm(bucket.list_blobs(prefix=f"{folder}")): 62 | if blob.name.endswith("-metrics.jsonl"): 63 | s = blob.download_as_string().decode("utf-8") 64 | filename = Path(blob.name).stem.replace("-metrics", "") 65 | task, prompt = process_task_prompt(filename) 66 | # last step 67 | accuracies[task][prompt] = json.loads(s.split("\n")[-2])["accuracy"] 68 | return accuracies 69 | 70 | 71 | # min, max, average per prompt 72 | def summarise_ps_results(accuracies: Dict[str, Dict[str, float]]) -> None: 73 | print("TASK: MIN MAX AVG MED") 74 | from statistics import mean, median 75 | 76 | for task in accuracies: 77 | scores = [x for x in accuracies[task].values()] 78 | print( 79 | f"{task}: {min(scores):.2f} {max(scores):.2f} {mean(scores):.2f} {median(scores):.2f}" 80 | ) 81 | print("-------------" * 2) 82 | 83 | 84 | def print_all_results(accuracies: Dict[str, Dict[str, float]]) -> None: 85 | all_accuracies = [] 86 | for task in accuracies: 87 | for prompt, score in accuracies[task].items(): 88 | print(f"{task} {prompt} {score:.2f}") 89 | all_accuracies.append(score) 90 | print() 91 | print(f"Average overall accuracy: {sum(all_accuracies) / len(all_accuracies)}") 92 | 93 | 94 | if __name__ == "__main__": 95 | parser = argparse.ArgumentParser(description="Get formatted results from p3/t0 eval mixture") 96 | parser.add_argument( 97 | "-f", 98 | "--res-folder", 99 | type=str, 100 | help="Path to a promptsource .json result file", 101 | ) 102 | args = parser.parse_args() 103 | 104 | accuracies = process_ps_results(args.res_folder) 105 | summarise_ps_results(accuracies) 106 | print_all_results(accuracies) 107 | print("You should be able to copy/paste the above! 
:)") 108 | -------------------------------------------------------------------------------- /hyper_task_descriptions/python_scripts/poking_the_bear.py: -------------------------------------------------------------------------------- 1 | """ 2 | Export the roberta model from the hypernetwork into a roberta huggingface model. 3 | """ 4 | # flake8: noqa 5 | import argparse 6 | 7 | import jax 8 | from jax import numpy as jnp 9 | from t5x import checkpoints 10 | from transformers import AutoTokenizer, FlaxT5EncoderModel, T5EncoderModel 11 | 12 | 13 | def extract_roberta_model(t5x_checkpoint_path, flax_dump_folder_path): 14 | t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path) 15 | hf_model = FlaxT5EncoderModel.from_pretrained("google/t5-large-lm-adapt") 16 | hf_model.params = t5x_model["target"]["hyper"]["encoder"] 17 | hf_model.save_pretrained(flax_dump_folder_path) 18 | model = T5EncoderModel.from_pretrained(flax_dump_folder_path, from_flax=True) 19 | model.save_pretrained(flax_dump_folder_path + "_pytorch") 20 | 21 | 22 | def get_model_output(model, tok, text): 23 | return model(tok(text, return_tensors="np")["input_ids"])[0] 24 | 25 | 26 | def get_attention_values(model, t5x_model, tok, text): 27 | output = model(tok(text, return_tensors="np")["input_ids"])[0][0] 28 | layer_embeds = t5x_model["target"]["hyper"]["component_embedding"] 29 | # run attn 30 | attn_weights = jnp.einsum("qd,kd->qk", layer_embeds, output) 31 | attn_weights = jax.nn.softmax(attn_weights) 32 | return attn_weights 33 | 34 | 35 | def play_with_model(t5x_checkpoint_path): 36 | tok = AutoTokenizer.from_pretrained("t5-base") 37 | t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path) 38 | # hf_model = FlaxT5EncoderModel.from_pretrained("google/t5-large-lm-adapt", from_pt=True) 39 | # hf_model.params = t5x_model["target"]["hyper"]["encoder"] 40 | # get_out = lambda x: get_model_output(hf_model, tok, x) 41 | # get_attn = lambda x: get_attention_values(hf_model, t5x_model, tok, x) 42 | import pdb 43 | 44 | pdb.set_trace() 45 | 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser() 49 | # Required parameters 50 | parser.add_argument( 51 | "--t5x_checkpoint_path", 52 | "-t", 53 | default=None, 54 | type=str, 55 | required=True, 56 | help="Path the TX5 checkpoint.", 57 | ) 58 | parser.add_argument( 59 | "--flax_dump_folder_path", 60 | "-d", 61 | default=None, 62 | type=str, 63 | required=False, 64 | help="Path the TX5 checkpoint.", 65 | ) 66 | args = parser.parse_args() 67 | # extract_roberta_model(args.t5x_checkpoint_path, args.flax_dump_folder_path) 68 | play_with_model(args.t5x_checkpoint_path) 69 | -------------------------------------------------------------------------------- /hyper_task_descriptions/python_scripts/readme.md: -------------------------------------------------------------------------------- 1 | # Python Scripts 2 | 3 | `p3_results.py` - script for displaying results from evaluation easy. 4 | `fixed_roberta.py` - make and upload the slightly altered roberta model. 5 | 6 | ## poking_the_bear.py 7 | 8 | This is a fun little script for testing pretrained models. First, download a checkpoint locally (this may take a while). 
Make sure you are in the right env (install requrements and add the repo to your python path): 9 | ``` 10 | pip install -r requirements.txt 11 | export PYTHONPATH=$(pwd) 12 | ``` 13 | 14 | You can then run the script as follows: 15 | ``` 16 | python hyper_task_descriptions/python_scripts/poking_the_bear.py -t 17 | ``` 18 | 19 | This will then put you into a `pdb` shell after loading, and I have provided two functions to show how to play: `get_out` and `get_attn`, which get the hyperencoder output and the cross-attention probabilities respectively. Feel free to mess around and add your own functions etc. I was lazy here so if you want to emulate different model output processing, you'll have to code that up yourself. 20 | -------------------------------------------------------------------------------- /hyper_task_descriptions/python_scripts/test_all_tf_records.py: -------------------------------------------------------------------------------- 1 | """ 2 | merge tfrecords. 3 | testing this out. 4 | """ 5 | import glob 6 | import os 7 | from multiprocessing import Pool 8 | 9 | import tensorflow as tf 10 | from tqdm import tqdm 11 | 12 | 13 | def test_tf_record(dir): 14 | for split in ["train", "test", "validation"]: 15 | tf_record_list = glob.glob(f"{dir}/{split}.tfrecord*") 16 | assert len(tf_record_list) < 2, f"{split} has more than one tfrecord file" 17 | if len(tf_record_list) == 0: 18 | continue 19 | dataset = tf.data.TFRecordDataset(tf_record_list) 20 | try: 21 | for _ in tqdm(dataset): 22 | pass 23 | except tf.python.framework.errors_impl.DataLossError as e: 24 | print(f"\n============== {dir} {split} failed ==============\n") 25 | raise e 26 | 27 | 28 | def main(): 29 | dirs = os.listdir("t0_data_new") 30 | dirs = [f"t0_data_new/{dir}" for dir in dirs] 31 | pool = Pool() 32 | for _ in tqdm(pool.imap_unordered(test_tf_record, dirs), total=len(dirs)): 33 | pass 34 | 35 | 36 | if __name__ == "__main__": 37 | main() 38 | -------------------------------------------------------------------------------- /hyper_task_descriptions/seqio_tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/hyper-task-descriptions/fd86a43ac2131582548130d86d57ea977c804ab6/hyper_task_descriptions/seqio_tasks/__init__.py -------------------------------------------------------------------------------- /hyper_task_descriptions/seqio_tasks/all_t0_task_prefixes.txt: -------------------------------------------------------------------------------- 1 | adversarial_qa_dbert 2 | adversarial_qa_dbidaf 3 | adversarial_qa_droberta 4 | ag_news 5 | amazon_polarity 6 | app_reviews 7 | cnn_dailymail_3.0.0 8 | common_gen 9 | cos_e_v1.11 10 | cosmos_qa 11 | dbpedia_14 12 | dream 13 | duorc_ParaphraseRC 14 | duorc_SelfRC 15 | gigaword 16 | glue_mrpc 17 | glue_qqp 18 | imdb 19 | kilt 20 | multi_news 21 | paws_labeled_final 22 | qasc 23 | quail 24 | quarel 25 | quartz 26 | quoref 27 | ropes 28 | rotten_tomatoes 29 | samsum 30 | sciq 31 | social_i_qa 32 | trec 33 | wiki_bio 34 | wiki_hop 35 | wiki_qa 36 | wiqa 37 | xsum 38 | yelp 39 | -------------------------------------------------------------------------------- /hyper_task_descriptions/seqio_tasks/beam_requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | numpy 3 | sentencepiece 4 | tensorflow-text 5 | tfds-nightly 6 | t5 7 | tqdm 8 | promptsource 9 | tensorflow 10 | transformers 
-------------------------------------------------------------------------------- /hyper_task_descriptions/seqio_tasks/check_t0_tasks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Quick script for checking if T0 caching is complete. 3 | """ 4 | from google.cloud import storage 5 | 6 | storage_client = storage.Client(project="ai2-allennlp") 7 | bucket = storage_client.bucket("hamishi-tpu-bucket") 8 | 9 | name = "COMPLETED" 10 | 11 | tasks = [t.strip().decode("utf-8") for t in open("all_t0_tasks.txt", "rb").readlines()] 12 | 13 | for task in tasks: 14 | if not storage.Blob(bucket=bucket, name=f"t0_data_split_descr/{task}/COMPLETED").exists( 15 | storage_client 16 | ): 17 | print(task) 18 | -------------------------------------------------------------------------------- /hyper_task_descriptions/seqio_tasks/datasets.csv: -------------------------------------------------------------------------------- 1 | HF_name,subset,task_by_convention,do_train,do_eval,train_size 2 | crows_pairs,,bias_and_fairness,,BIAS_FAIRNESS, 3 | jigsaw_toxicity_pred,,bias_and_fairness,,BIAS_FAIRNESS, 4 | super_glue,axg,bias_and_fairness,,BIAS_FAIRNESS, 5 | wino_bias,type1_anti,bias_and_fairness,,BIAS_FAIRNESS, 6 | wino_bias,type2_anti,bias_and_fairness,,BIAS_FAIRNESS, 7 | wino_bias,type1_pro,bias_and_fairness,,BIAS_FAIRNESS, 8 | wino_bias,type2_pro,bias_and_fairness,,BIAS_FAIRNESS, 9 | super_glue,wsc.fixed,coreference,SGLUE,BASE,554 10 | winogrande,winogrande_xl,coreference,,BASE,40398 11 | super_glue,cb,NLI,,BASE,250 12 | super_glue,rte,NLI,,BASE,2490 13 | anli,,NLI,,BASE,162865 14 | glue,mrpc,paraphrase,BASE,,3668 15 | glue,qqp,paraphrase,BASE,,363846 16 | paws,labeled_final,paraphrase,BASE,,49401 17 | ai2_arc,ARC-Challenge,QA_closed_book,GPT_EVAL,,1119 18 | ai2_arc,ARC-Easy,QA_closed_book,GPT_EVAL,,2251 19 | kilt_tasks,hotpotqa,QA_closed_book,BASE,,88869 20 | trivia_qa,unfiltered,QA_closed_book,GPT_EVAL,,87622 21 | web_questions,,QA_closed_book,GPT_EVAL,,3778 22 | wiki_qa,,QA_closed_book,BASE,,20360 23 | adversarial_qa,dbidaf,QA_extractive,BASE,,10000 24 | adversarial_qa,dbert,QA_extractive,BASE,,10000 25 | adversarial_qa,droberta,QA_extractive,BASE,,10000 26 | duorc,SelfRC,QA_extractive,BASE,,60721 27 | duorc,ParaphraseRC,QA_extractive,BASE,,69524 28 | ropes,,QA_extractive,BASE,,10924 29 | squad_v2,,QA_extractive,GPT_EVAL,,130319 30 | super_glue,record,QA_extractive,SGLUE,,100730 31 | quoref,,QA_extractive,BASE,,19399 32 | cos_e,v1.11,QA_multiple_choice,BASE,,9741 33 | cosmos_qa,,QA_multiple_choice,BASE,,25262 34 | dream,,QA_multiple_choice,BASE,,6116 35 | openbookqa,main,QA_multiple_choice,GPT_EVAL,,4957 36 | qasc,,QA_multiple_choice,BASE,,8134 37 | quail,,QA_multiple_choice,BASE,,10246 38 | quarel,,QA_multiple_choice,BASE,,1941 39 | quartz,,QA_multiple_choice,BASE,,2696 40 | race,high,QA_multiple_choice,GPT_EVAL,,62445 41 | race,middle,QA_multiple_choice,GPT_EVAL,,25421 42 | sciq,,QA_multiple_choice,BASE,,11679 43 | social_i_qa,,QA_multiple_choice,BASE,,33410 44 | super_glue,boolq,QA_multiple_choice,SGLUE,,9427 45 | super_glue,copa,QA_multiple_choice,SGLUE,BASE,400 46 | super_glue,multirc,QA_multiple_choice,SGLUE,,27243 47 | wiki_hop,original,QA_multiple_choice,BASE,,43738 48 | wiqa,,QA_multiple_choice,BASE,,29808 49 | piqa,,QA_multiple_choice,GPT_EVAL,,16113 50 | amazon_polarity,,sentiment,BASE,,3600000 51 | app_reviews,,sentiment,BASE,,288065 52 | imdb,,sentiment,BASE,,25000 53 | rotten_tomatoes,,sentiment,BASE,,8530 54 | yelp_review_full,,sentiment,BASE,,650000 
55 | story_cloze,2016,story_completion,,BASE, 56 | hellaswag,,story_completion,GPT_EVAL,BASE,39905 57 | common_gen,,structure_to_text,BASE,,67389 58 | wiki_bio,,structure_to_text,BASE,,582659 59 | cnn_dailymail,3.0.0,summarization,BASE,,287113 60 | gigaword,,summarization,BASE,,3803957 61 | multi_news,,summarization,BASE,,44972 62 | samsum,,summarization,BASE,,14732 63 | xsum,,summarization,BASE,,204045 64 | ag_news,,topic_classification,BASE,,120000 65 | dbpedia_14,,topic_classification,BASE,,560000 66 | trec,,topic_classification,BASE,,5452 67 | super_glue,wic,word_sense_disambiguation,SGLUE,BASE,5428 -------------------------------------------------------------------------------- /hyper_task_descriptions/seqio_tasks/readme.md: -------------------------------------------------------------------------------- 1 | # SEQIO TASKS 2 | 3 | This folder defines task registries used for T0, and an additional registry that can be loaded independently to allow for development without loading all of P3. 4 | 5 | t5x expects that tasks are preprocessed and cached prior to running. While I should have a gcloud bucket setup with everything needed, you can also cache tasks yourself with the following command. The task regex will cache all matching tasks defined in the imported module. I recommend caching to a shared drive or bucket so we can reuse data as much as possible :) 6 | 7 | ```bash 8 | seqio_cache_tasks \ 9 | --tasks= \ 10 | --output_cache_dir= \ 11 | --module_import=hyper_task_descriptions.seqio_tasks.all_t0_tasks 12 | ``` 13 | 14 | n.b. you'll probably have to add the repo location to your `PYTHONPATH` to be able to import the t0 task module. 15 | 16 | I'm working on a script for running all the preprocessing and caching it. 17 | -------------------------------------------------------------------------------- /hyper_task_descriptions/seqio_tasks/small_t0_tasks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines 't0_small_train', a minimal set of tasks to allow local dev without loading all of P3. 3 | """ 4 | import seqio 5 | 6 | from hyper_task_descriptions.seqio_tasks.t0_tasks import ( 7 | TASK_BLACKLIST, 8 | create_mixture_lists, 9 | load_t0_csv, 10 | ) 11 | 12 | t0_train, t0_eval, gsheet = load_t0_csv() 13 | 14 | # only want one instance. 15 | t0_train["GPT_EVAL"] = [] 16 | t0_train["SGLUE"] = [] 17 | t0_eval["BIAS_FAIRNESS"] = [] 18 | t0_train["BASE"] = t0_train["BASE"][:1] # [("super_glue", "rte"), ("super_glue", "cb")] 19 | t0_eval["BASE"] = t0_eval["BASE"][:1] # [("super_glue", "rte"), ("super_glue", "cb")] 20 | 21 | # download the dataset infos 22 | mixtures = create_mixture_lists(t0_train, t0_eval, gsheet) 23 | 24 | # create our singular mixture 25 | t0_train_mixture = mixtures[0] 26 | mixture_cap = mixtures[2] 27 | seqio.MixtureRegistry.add( 28 | "t0_small_train", 29 | [task for task in t0_train_mixture["BASE"] if task not in TASK_BLACKLIST], 30 | default_rate=lambda t: mixture_cap[t.name], 31 | ) 32 | -------------------------------------------------------------------------------- /hyper_task_descriptions/seqio_tasks/t0_datasets_mapping.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mapping for dataset names to ints. 3 | Because i had issues passing strings into seqio + t5x. 
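For example, `T0_DS_MAPPING["super_glue_rte"]` is 5, so tasks can be referred to by an
integer id rather than a string.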
4 | """ 5 | 6 | T0_DS_MAPPING = { 7 | "super_glue_record": 0, 8 | "super_glue_multirc": 1, 9 | "super_glue_copa": 2, 10 | "super_glue_axg": 3, 11 | "super_glue_boolq": 4, 12 | "super_glue_rte": 5, 13 | "super_glue_wic": 6, 14 | "super_glue_wsc.fixed": 7, 15 | "super_glue_cb": 8, 16 | "rotten_tomatoes": 9, 17 | "glue_qqp": 10, 18 | "glue_mrpc": 11, 19 | "cosmos_qa": 12, 20 | "sciq": 13, 21 | "trec": 14, 22 | "wiki_qa": 15, 23 | "crows_pairs": 16, 24 | "kilt_tasks_hotpotqa": 17, 25 | "samsum": 18, 26 | "common_gen": 19, 27 | "paws_labeled_final": 20, 28 | "dbpedia_14": 21, 29 | "piqa": 22, 30 | "ropes": 23, 31 | "cos_e_v1.11": 24, 32 | "amazon_polarity": 25, 33 | "trivia_qa_unfiltered": 26, 34 | "cnn_dailymail_3.0.0": 27, 35 | "ai2_arc_ARC-Challenge": 28, 36 | "ai2_arc_ARC-Easy": 29, 37 | "quarel": 30, 38 | "gigaword": 31, 39 | "xsum": 32, 40 | "yelp_review_full": 33, 41 | "winogrande_winogrande_xl": 34, 42 | "wiki_bio": 35, 43 | "openbookqa_main": 36, 44 | "social_i_qa": 37, 45 | "web_questions": 38, 46 | "ag_news": 39, 47 | "qasc": 40, 48 | "wiki_hop_original": 41, 49 | "quoref": 42, 50 | "squad_v2": 43, 51 | "wiqa": 44, 52 | "imdb": 45, 53 | "race_middle": 46, 54 | "race_high": 47, 55 | "app_reviews": 48, 56 | "quartz": 49, 57 | "hellaswag": 50, 58 | "duorc_SelfRC": 51, 59 | "duorc_ParaphraseRC": 52, 60 | "multi_news": 53, 61 | "dream": 54, 62 | "wino_bias_type2_pro": 55, 63 | "wino_bias_type2_anti": 56, 64 | "wino_bias_type1_pro": 57, 65 | "wino_bias_type1_anti": 58, 66 | "quail": 59, 67 | "story_cloze_2016": 60, 68 | "adversarial_qa_dbert": 61, 69 | "adversarial_qa_dbidaf": 62, 70 | "adversarial_qa_droberta": 63, 71 | "anli": 64, 72 | } 73 | -------------------------------------------------------------------------------- /hyper_task_descriptions/version.py: -------------------------------------------------------------------------------- 1 | _MAJOR = "0" 2 | _MINOR = "1" 3 | # On main and in a nightly release the patch should be one ahead of the last 4 | # released build. 5 | _PATCH = "0" 6 | # This is mainly for nightly builds which have the suffix ".dev$DATE". See 7 | # https://semver.org/#is-v123-a-semantic-version for the semantics. 
8 | _SUFFIX = "" 9 | 10 | VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) 11 | VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX) 12 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = true 3 | no_site_packages = true 4 | allow_redefinition = true 5 | 6 | [mypy-tests.*] 7 | strict_optional = false 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | 4 | include = '\.pyi?$' 5 | 6 | exclude = ''' 7 | ( 8 | __pycache__ 9 | | \.git 10 | | \.mypy_cache 11 | | \.pytest_cache 12 | | \.vscode 13 | | \.venv 14 | | \bdist\b 15 | | \bdoc\b 16 | ) 17 | ''' 18 | 19 | [tool.isort] 20 | profile = "black" 21 | multi_line_output = 3 22 | 23 | [build-system] 24 | requires = ["setuptools", "wheel"] 25 | build-backend = "setuptools.build_meta" 26 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = tests/ 3 | python_classes = Test* *Test 4 | log_format = %(asctime)s - %(levelname)s - %(name)s - %(message)s 5 | log_level = DEBUG 6 | markers = 7 | filterwarnings = 8 | ignore::DeprecationWarning 9 | -------------------------------------------------------------------------------- /scripts/eval/t0_eval.sh: -------------------------------------------------------------------------------- 1 | # Model dir to save logs, ckpts, etc. in "gs://model_dir" format. 2 | EXPERIMENT_NAME=$1 3 | CHECKPOINT_NAME=$2 4 | BUCKET_NAME="hamishi-tpu" 5 | 6 | # model checkpoint location 7 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model/${CHECKPOINT_NAME}" 8 | # where to put eval results 9 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval" 10 | 11 | # we go offline to avoid constant calls to get basic info (happens even when cached) 12 | # for your first run, you will probably need to run all these calls :( 13 | # note you pass in a model file and the eval file. 
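# Example invocation (experiment and checkpoint names are hypothetical):
#   bash scripts/eval/t0_eval.sh my_t0_experiment checkpoint_1112200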
14 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 15 | --gin_search_paths="gins" \ 16 | --gin_file="hyper_xl.gin" \ 17 | --gin_file="t0_eval.gin" \ 18 | --gin.USE_CACHED_TASKS=True \ 19 | --gin.utils.DatasetConfig.batch_size=128 \ 20 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 21 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 22 | -------------------------------------------------------------------------------- /scripts/eval/t0_eval_adapter.sh: -------------------------------------------------------------------------------- 1 | # name of experiment folder 2 | EXPERIMENT_NAME=$1 3 | BUCKET_NAME="hamishi-tpu" 4 | 5 | # model checkpoint location 6 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 7 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval" 8 | 9 | 10 | # we go offline to avoid constant calls to get basic info (happens even when cached) 11 | # for your first run, you will probably need to run all these calls :( 12 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 13 | --gin_search_paths=gins \ 14 | --gin_file="hyper_xl.gin" \ 15 | --gin_file="instruction_embed.gin" \ 16 | --gin_file="t0_eval.gin" \ 17 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 18 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 19 | --gin.hyper_network.HyperT5Config.use_instructions=True \ 20 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 21 | --gin.utils.DatasetConfig.batch_size=128 \ 22 | --gin.USE_CACHED_TASKS=True \ 23 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 24 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" \ 25 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" 26 | -------------------------------------------------------------------------------- /scripts/eval/t0_eval_hypter.sh: -------------------------------------------------------------------------------- 1 | # name of experiment folder 2 | EXPERIMENT_NAME=$1 3 | BUCKET_NAME="hamishi-tpu" 4 | 5 | # model checkpoint location 6 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 7 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval" 8 | 9 | 10 | # we go offline to avoid constant calls to get basic info (happens even when cached) 11 | # for your first run, you will probably need to run all these calls :( 12 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 13 | --gin_search_paths=gins \ 14 | --gin_file="hyper_xl.gin" \ 15 | --gin_file="instruction_embed.gin" \ 16 | --gin_file="t0_eval.gin" \ 17 | --gin_file="hypter.gin" \ 18 | --gin.hyper_network.HyperT5Config.adapter_size=4 \ 19 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 20 | --gin.utils.DatasetConfig.batch_size=128 \ 21 | --gin.USE_CACHED_TASKS=True \ 22 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 23 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" \ 24 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" 25 | -------------------------------------------------------------------------------- /scripts/eval/t0_reg_eval.sh: -------------------------------------------------------------------------------- 1 | # Checkpoint to eval on 2 | EXPERIMENT_NAME=$1 3 | BUCKET_NAME="hamishi-tpu" 4 | 5 | # model checkpoint location 6 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 7 | # where to put eval results 8 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval" 9 | 10 | # we go offline to avoid constant calls to get basic info (happens even when cached) 11 | # for your first run, you will probably need to run all these calls :( 12 | # note you pass in a model file and the eval file.
13 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 14 | --gin_search_paths="gins" \ 15 | --gin_file="hyper_xl.gin" \ 16 | --gin_file="t0_eval.gin" \ 17 | --gin.USE_CACHED_TASKS=True \ 18 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 19 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 20 | --gin.hyper_network.HyperT5Config.use_instructions=False \ 21 | --gin.utils.DatasetConfig.batch_size=128 \ 22 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 23 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" \ 24 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" 25 | -------------------------------------------------------------------------------- /scripts/local/debug_from_t5.sh: -------------------------------------------------------------------------------- 1 | # name of experiment folder 2 | EXPERIMENT_NAME=$1 3 | BUCKET_NAME="hamishi-tpu" 4 | 5 | # where model will be saved 6 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 7 | 8 | # we go offline to avoid constant calls to get basic info (happens even when cached) 9 | # for your first run, you will probably need to run all these calls :( 10 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.train \ 11 | --gin_search_paths=gins \ 12 | --gin_file="hyper_xl.gin" \ 13 | --gin_file="t0_train.gin" \ 14 | --gin_file="partial_train_adam.gin" \ 15 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 16 | --gin.TRAIN_STEPS=1100100 \ 17 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 18 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000\" 19 | -------------------------------------------------------------------------------- /scripts/local/local.sh: -------------------------------------------------------------------------------- 1 | # l'il script for running locally. Make sure to login to gcloud for cached data, or change tfds_data_dir 2 | 3 | TRANSFORMERS_OFFLINE=1 HF_DATASETS_OFFLINE=1 JAX_DISABLE_JIT=1 python -m t5x.train \ 4 | --gin_search_paths=./gins \ 5 | --gin_file="hyper_small.gin" \ 6 | --gin_file="t0_train_local_copy.gin" \ 7 | --gin.MODEL_DIR=\"test\" \ 8 | --gin.hyper_network.HyperT5Config.hyperencoder_model=\"google/t5-small-lm-adapt\" \ 9 | --gin.USE_CACHED_TASKS=False \ 10 | --gin.TRAIN_STEPS=1100100 \ 11 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_small/checkpoint_1100000\" 12 | -------------------------------------------------------------------------------- /scripts/local/local_debug.sh: -------------------------------------------------------------------------------- 1 | # script for loading existing models locally 2 | # disable jit so I can debug. 3 | 4 | EXPERIMENT_NAME=$1 5 | BUCKET_NAME="hamishi-tpu" 6 | 7 | # where model will be saved 8 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 9 | 10 | TRANSFORMERS_OFFLINE=1 HF_DATASETS_OFFLINE=1 JAX_DISABLE_JIT=1 python -m t5x.train \ 11 | --gin_search_paths=./gins \ 12 | --gin_file="t0_train_local.gin" \ 13 | --gin.TRAIN_STEPS=1117000 \ 14 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 15 | --gin.DROPOUT_RATE=0.1 \ 16 | --gin.USE_CACHED_TASKS=False 17 | -------------------------------------------------------------------------------- /scripts/local/local_eval.sh: -------------------------------------------------------------------------------- 1 | # li'l script for testing eval stuff locally. 
2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | python3 -m t5x.eval \ 5 | --gin_search_paths="gins" \ 6 | --gin_file="hyper_xl.gin" \ 7 | --gin_file="catwalk_eval.gin" \ 8 | --gin.MIXTURE_OR_TASK_NAME=\"eleuther::cola\" \ 9 | --gin.EVAL_OUTPUT_DIR=\"model_eval_test\" \ 10 | --gin.CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/roberta_contrastive_dup_t0_3b/model/checkpoint_1123000\" \ 11 | --gin.USE_CACHED_TASKS=False 12 | -------------------------------------------------------------------------------- /scripts/local/lora_finetune.sh: -------------------------------------------------------------------------------- 1 | # l'il script for running locally. Make sure to login to gcloud for cached data, or change tfds_data_dir 2 | 3 | TRANSFORMERS_OFFLINE=1 HF_DATASETS_OFFLINE=1 python -m t5x.train \ 4 | --gin_search_paths=./gins \ 5 | --gin_file="t0_train_local.gin" \ 6 | --gin_file="lora/lora_small.gin" \ 7 | --gin.DROPOUT_RATE=0.1 \ 8 | --gin.MODEL_DIR=\"lora_test\" \ 9 | --gin.USE_CACHED_TASKS=False \ 10 | --gin.TRAIN_STEPS=1100010 \ 11 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_small/checkpoint_1100000\" 12 | -------------------------------------------------------------------------------- /scripts/local/lora_local.sh: -------------------------------------------------------------------------------- 1 | # name of experiment folder 2 | EXPERIMENT_NAME=$1 3 | 4 | # where model will be saved 5 | MODEL_DIR="${EXPERIMENT_NAME}" 6 | 7 | # we go offline to avoid constant calls to get basic info (happens even when cached) 8 | # for your first run, you will probably need to run all these calls :( 9 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.train \ 10 | --gin_search_paths=gins \ 11 | --gin_file="lora/lora_small.gin" \ 12 | --gin_file="t0_train_local.gin" \ 13 | --gin_file="partial_train_adam.gin" \ 14 | --gin_file="lora/lora.gin" \ 15 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 16 | --gin.TRAIN_STEPS=1100010 \ 17 | --gin.partitioning.PjitPartitioner.num_partitions=1 \ 18 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_small/checkpoint_1100000\" \ 19 | --gin.USE_CACHED_TASKS=False 20 | -------------------------------------------------------------------------------- /scripts/lora/debug_lora.sh: -------------------------------------------------------------------------------- 1 | # name of experiment folder 2 | EXPERIMENT_NAME=$1 3 | 4 | # where model will be saved 5 | MODEL_DIR="${EXPERIMENT_NAME}" 6 | 7 | # we go offline to avoid constant calls to get basic info (happens even when cached) 8 | # for your first run, you will probably need to run all these calls :( 9 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.train \ 10 | --gin_search_paths=gins \ 11 | --gin_file="lora/plain/lora_small.gin" \ 12 | --gin_file="t0_train_local.gin" \ 13 | --gin_file="partial_train_adam.gin" \ 14 | --gin_file="lora/lora.gin" \ 15 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 16 | --gin.TRAIN_STEPS=1101000 \ 17 | --gin.partitioning.PjitPartitioner.num_partitions=1 \ 18 | --gin.lora_network.LoraT5Config.lora_ranks="(4,None,4,None)" \ 19 | --gin.utils.create_learning_rate_scheduler.base_learning_rate=1e-4 \ 20 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_small/checkpoint_1100000\" 21 | -------------------------------------------------------------------------------- /scripts/lora/t0_lora_eval.sh: -------------------------------------------------------------------------------- 1 | # Model dir to save logs, ckpts, etc. 
in "gs://model_dir" format. 2 | EXPERIMENT_NAME=$1 3 | CHECKPOINT_NAME=$2 4 | BUCKET_NAME="hamishi-tpu" 5 | 6 | # model checkpoint location 7 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model/${CHECKPOINT_NAME}" 8 | # where to put eval results 9 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval" 10 | 11 | #MODEL_DIR="plain-lora-small-4n4n/model/checkpoint_1107000" 12 | #EVAL_OUTPUT_DIR="plain-lora-small-4n4n-eval" 13 | 14 | # we go offline to avoid constant calls to get basic info (happens even when cached) 15 | # for your first run, you will probably need to run all these calls :( 16 | # note you pass in a model file and the eval file. 17 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 18 | --gin_search_paths="gins" \ 19 | --gin_file="lora/plain/lora_xl.gin" \ 20 | --gin_file="t0_eval.gin" \ 21 | --gin.USE_CACHED_TASKS=True \ 22 | --gin.utils.DatasetConfig.batch_size=64 \ 23 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 24 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 25 | -------------------------------------------------------------------------------- /scripts/lora/train_lora_from_t5.sh: -------------------------------------------------------------------------------- 1 | # name of experiment folder 2 | EXPERIMENT_NAME=$1 3 | BUCKET_NAME="hamishi-tpu" 4 | 5 | # where model will be saved 6 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 7 | # MODEL_DIR="${EXPERIMENT_NAME}" 8 | 9 | # we go offline to avoid constant calls to get basic info (happens even when cached) 10 | # for your first run, you will probably need to run all these calls :( 11 | HF_DATASETS_OFFLINE=0 TRANSFORMERS_OFFLINE=0 python3 -m t5x.train \ 12 | --gin_search_paths=gins \ 13 | --gin_file="lora/lora_small.gin" \ 14 | --gin_file="t0_train_local.gin" \ 15 | --gin_file="partial_train_adam.gin" \ 16 | --gin_file="lora/lora.gin" \ 17 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 18 | --gin.TRAIN_STEPS=1107000 \ 19 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 20 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_small/checkpoint_1100000\" 21 | -------------------------------------------------------------------------------- /scripts/lora/train_plain_lora_from_t5.sh: -------------------------------------------------------------------------------- 1 | # name of experiment folder 2 | EXPERIMENT_NAME=$1 3 | BUCKET_NAME="hamishi-tpu" 4 | 5 | # where model will be saved 6 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 7 | # MODEL_DIR="${EXPERIMENT_NAME}" 8 | 9 | # we go offline to avoid constant calls to get basic info (happens even when cached) 10 | # for your first run, you will probably need to run all these calls :( 11 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.train \ 12 | --gin_search_paths=gins \ 13 | --gin_file="lora/plain/lora_xl.gin" \ 14 | --gin_file="t0_train.gin" \ 15 | --gin_file="partial_train_adam.gin" \ 16 | --gin_file="lora/lora.gin" \ 17 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 18 | --gin.TRAIN_STEPS=1107000 \ 19 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 20 | --gin.lora_network.LoraT5Config.lora_ranks="(32,None,32,None)" \ 21 | --gin.utils.create_learning_rate_scheduler.base_learning_rate=1e-5 \ 22 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000\" 23 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_eval.sh: -------------------------------------------------------------------------------- 
1 | # NI training 2 | # for standalone eval 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | CHECKPOINT=$2 7 | BUCKET_NAME="hamishi-tpu" 8 | 9 | # where model will be saved 10 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model/checkpoint_${CHECKPOINT}" 11 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 12 | 13 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 14 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 15 | --gin_search_paths="gins" \ 16 | --gin_file="hyper_xl.gin" \ 17 | --gin_file="instruction_embed.gin" \ 18 | --gin_file="ni_eval.gin" \ 19 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 20 | --gin.USE_CACHED_TASKS=True \ 21 | --gin.utils.DatasetConfig.batch_size=128 \ 22 | --gin.utils.DatasetConfig.split=\"test\" \ 23 | --gin.partitioning.PjitPartitioner.num_partitions=16 \ 24 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 25 | --gin.utils.RestoreCheckpointConfig.mode=\"specific\" \ 26 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 27 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_eval_base.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # for standalone eval 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | BUCKET_NAME="hamishi-tpu" 7 | 8 | # where model will be saved 9 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 10 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 11 | 12 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 13 | --gin_search_paths="gins" \ 14 | --gin_file="hyper_base.gin" \ 15 | --gin_file="instruction_embed.gin" \ 16 | --gin_file="ni_eval.gin" \ 17 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 18 | --gin.USE_CACHED_TASKS=True \ 19 | --gin.utils.DatasetConfig.batch_size=128 \ 20 | --gin.utils.DatasetConfig.split=\"test\" \ 21 | --gin.partitioning.PjitPartitioner.num_partitions=2 \ 22 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 23 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 24 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 25 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_eval_reg.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # for standalone eval 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | BUCKET_NAME="hamishi-tpu" 7 | 8 | # where model will be saved 9 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 10 | 11 | 12 | EVAL_OUTPUT_DIR="${EXPERIMENT_NAME}/eval/" 13 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 14 | --gin_search_paths="gins" \ 15 | --gin_file="hyper_base.gin" \ 16 | --gin_file="ni_eval.gin" \ 17 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 18 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 19 | --gin.hyper_network.HyperT5Config.use_instructions=False \ 20 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions_def\" \ 21 | --gin.USE_CACHED_TASKS=True \ 22 | --gin.utils.DatasetConfig.batch_size=512 \ 23 | --gin.utils.DatasetConfig.split=\"test\" \ 24 | --gin.partitioning.PjitPartitioner.num_partitions=2 \ 25 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 26 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 27 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 28 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train.sh: 
-------------------------------------------------------------------------------- 1 | # NI training 2 | 3 | # name of experiment folder 4 | EXPERIMENT_NAME=$1 5 | BUCKET_NAME="hamishi-tpu" 6 | 7 | # where model will be saved 8 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 9 | 10 | python3 -m t5x.train \ 11 | --gin_search_paths=gins \ 12 | --gin_file="hyper_xl.gin" \ 13 | --gin_file="instruction_embed.gin" \ 14 | --gin_file="ni_train.gin" \ 15 | --gin_file="partial_train_adafactor_dual.gin" \ 16 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 17 | --gin.USE_CACHED_TASKS=True \ 18 | --gin.trainer.Trainer.num_microbatches=16 \ 19 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 20 | --gin.BATCH_SIZE=1024 \ 21 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 22 | --gin.TRAIN_STEPS=1101000 \ 23 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 24 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 25 | 26 | echo "Training done. Now evaluating all checkpoints..." 27 | 28 | 29 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 30 | 31 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 32 | --gin_search_paths="gins" \ 33 | --gin_file="hyper_xl.gin" \ 34 | --gin_file="instruction_embed.gin" \ 35 | --gin_file="ni_eval.gin" \ 36 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 37 | --gin.USE_CACHED_TASKS=True \ 38 | --gin.utils.DatasetConfig.batch_size=512 \ 39 | --gin.utils.DatasetConfig.split=\"test\" \ 40 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 41 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 42 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 43 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 44 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_debug.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | 3 | # name of experiment folder 4 | checkpoint=$1 5 | 6 | # where model will be saved 7 | MODEL_DIR="test/model" 8 | 9 | # JAX_DISABLE_JIT=1 python3 -m t5x.train \ 10 | # --gin_search_paths=gins \ 11 | # --gin_file="hyper_xl.gin" \ 12 | # --gin_file="instruction_embed.gin" \ 13 | # --gin_file="ni_train.gin" \ 14 | # --gin_file="partial_train_adafactor.gin" \ 15 | # --gin_file="full_restore.gin" \ 16 | # --gin.USE_CACHED_TASKS=True \ 17 | # --gin.trainer.Trainer.num_microbatches=1 \ 18 | # --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 19 | # --gin.BATCH_SIZE=2 \ 20 | # --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 21 | # --gin.TRAIN_STEPS=1170000 \ 22 | # --gin.partitioning.PjitPartitioner.num_partitions=8 \ 23 | # --gin.partitioning.PjitPartitioner.use_cpu_pjit=True \ 24 | # --gin.train.use_gda=False \ 25 | # --gin.INITIAL_CHECKPOINT_PATH=\"checkpoint_1111000\" 26 | 27 | 28 | 29 | JAX_DISABLE_JIT=1 python3 -m t5x.train \ 30 | --gin_search_paths=gins \ 31 | --gin_file="hyper_xl.gin" \ 32 | --gin_file="ni_train.gin" \ 33 | --gin_file="partial_train_adafactor.gin" \ 34 | --gin_file="full_restore.gin" \ 35 | --gin.USE_CACHED_TASKS=True \ 36 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 37 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 38 | --gin.hyper_network.HyperT5Config.use_instructions=False \ 39 | --gin.trainer.Trainer.num_microbatches=1 \ 40 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 41 | --gin.BATCH_SIZE=2 \ 42 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 43 | --gin.TRAIN_STEPS=1170000 \ 44 | 
--gin.partitioning.PjitPartitioner.num_partitions=8 \ 45 | --gin.partitioning.PjitPartitioner.use_cpu_pjit=True \ 46 | --gin.train.use_gda=False \ 47 | --gin.INITIAL_CHECKPOINT_PATH=\"checkpoint_1101000\" 48 | 49 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_htune.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_xl.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="train_only_hnet.gin" \ 24 | --gin_file="hypertune.gin" \ 25 | --gin_file="full_restore.gin" \ 26 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 27 | --gin.USE_CACHED_TASKS=True \ 28 | --gin.trainer.Trainer.num_microbatches=16 \ 29 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 30 | --gin.BATCH_SIZE=1024 \ 31 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 32 | --gin.TRAIN_STEPS=$4 \ 33 | --gin.partitioning.PjitPartitioner.num_partitions=16 \ 34 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/checkpoint_${3}\" 35 | 36 | 37 | echo "Training done. Now evaluating all checkpoints..." 38 | 39 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 40 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 41 | --gin_search_paths="gins" \ 42 | --gin_file="hyper_xl.gin" \ 43 | --gin_file="instruction_embed.gin" \ 44 | --gin_file="ni_eval.gin" \ 45 | --gin_file="hypertune.gin" \ 46 | --gin_file="full_restore.gin" \ 47 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 48 | --gin.USE_CACHED_TASKS=True \ 49 | --gin.utils.DatasetConfig.batch_size=512 \ 50 | --gin.utils.DatasetConfig.split=\"test\" \ 51 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 52 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 53 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 54 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 55 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_hypter.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | 3 | # name of experiment folder 4 | EXPERIMENT_NAME=$1 5 | BUCKET_NAME="hamishi-tpu" 6 | 7 | # where model will be saved 8 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 9 | 10 | python3 -m t5x.train \ 11 | --gin_search_paths=gins \ 12 | --gin_file="hyper_base.gin" \ 13 | --gin_file="instruction_embed.gin" \ 14 | --gin_file="ni_train.gin" \ 15 | --gin_file="partial_train_adafactor_no_roberta.gin" \ 16 | --gin_file="hypter.gin" \ 17 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instruction_positive_example_hyper\" \ 18 | --gin.USE_CACHED_TASKS=True \ 19 | --gin.trainer.Trainer.num_microbatches=128 \ 20 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 21 | --gin.BATCH_SIZE=1024 \ 22 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 23 | --gin.TRAIN_STEPS=1101000 \ 24 | 
--gin.partitioning.PjitPartitioner.num_partitions=2 \ 25 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_base/checkpoint_1100000/\" 26 | 27 | echo "Training done. Now evaluating all checkpoints..." 28 | 29 | 30 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 31 | 32 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 33 | --gin_search_paths="gins" \ 34 | --gin_file="hyper_base.gin" \ 35 | --gin_file="instruction_embed.gin" \ 36 | --gin_file="ni_eval.gin" \ 37 | --gin_file="hypter.gin" \ 38 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instruction_positive_example_hyper\" \ 39 | --gin.USE_CACHED_TASKS=True \ 40 | --gin.utils.DatasetConfig.batch_size=128 \ 41 | --gin.utils.DatasetConfig.split=\"test\" \ 42 | --gin.partitioning.PjitPartitioner.num_partitions=2 \ 43 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 44 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 45 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 46 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_mixed.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | 3 | # name of experiment folder 4 | EXPERIMENT_NAME=$1 5 | BUCKET_NAME="hamishi-tpu" 6 | 7 | # where model will be saved 8 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 9 | 10 | python3 -m t5x.train \ 11 | --gin_search_paths=gins \ 12 | --gin_file="hyper_xl.gin" \ 13 | --gin_file="instruction_embed.gin" \ 14 | --gin_file="ni_train_mixed.gin" \ 15 | --gin_file="partial_train_adafactor_dual.gin" \ 16 | --gin.MIXTURE_OR_TASK_NAME=\"c4_ni\" \ 17 | --gin.USE_CACHED_TASKS=True \ 18 | --gin.trainer.Trainer.num_microbatches=16 \ 19 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 20 | --gin.BATCH_SIZE=1024 \ 21 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 22 | --gin.TRAIN_STEPS=1101000 \ 23 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 24 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 25 | 26 | echo "Training done. Now evaluating all checkpoints..." 
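# The eval below reuses the hyper_xl.gin + instruction_embed.gin setup and, because
# RestoreCheckpointConfig.mode is "all", scores every checkpoint saved under MODEL_DIR
# on the natural_instructions test split.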
27 | 28 | 29 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 30 | 31 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 32 | --gin_search_paths="gins" \ 33 | --gin_file="hyper_xl.gin" \ 34 | --gin_file="instruction_embed.gin" \ 35 | --gin_file="ni_eval.gin" \ 36 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 37 | --gin.USE_CACHED_TASKS=True \ 38 | --gin.utils.DatasetConfig.batch_size=512 \ 39 | --gin.utils.DatasetConfig.split=\"test\" \ 40 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 41 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 42 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 43 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 44 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_mixed_base.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | 3 | # name of experiment folder 4 | EXPERIMENT_NAME=$1 5 | BUCKET_NAME="hamishi-tpu" 6 | 7 | # where model will be saved 8 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 9 | 10 | python3 -m t5x.train \ 11 | --gin_search_paths=gins \ 12 | --gin_file="hyper_base.gin" \ 13 | --gin_file="instruction_embed.gin" \ 14 | --gin_file="ni_train_mixed.gin" \ 15 | --gin_file="partial_train_adafactor_dual.gin" \ 16 | --gin.MIXTURE_OR_TASK_NAME=\"c4_ni\" \ 17 | --gin.USE_CACHED_TASKS=True \ 18 | --gin.trainer.Trainer.num_microbatches=16 \ 19 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 20 | --gin.BATCH_SIZE=1024 \ 21 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 22 | --gin.TRAIN_STEPS=1105000 \ 23 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 24 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_base/checkpoint_1100000/\" 25 | 26 | echo "Training done. Now evaluating all checkpoints..." 
27 | 28 | 29 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 30 | 31 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 32 | --gin_search_paths="gins" \ 33 | --gin_file="hyper_base.gin" \ 34 | --gin_file="instruction_embed.gin" \ 35 | --gin_file="ni_eval.gin" \ 36 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 37 | --gin.USE_CACHED_TASKS=True \ 38 | --gin.utils.DatasetConfig.batch_size=512 \ 39 | --gin.utils.DatasetConfig.split=\"test\" \ 40 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 41 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 42 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 43 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 44 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_no_fid.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | 3 | # name of experiment folder 4 | EXPERIMENT_NAME=$1 5 | BUCKET_NAME="hamishi-tpu" 6 | 7 | # where model will be saved 8 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 9 | 10 | python3 -m t5x.train \ 11 | --gin_search_paths=gins \ 12 | --gin_file="hyper_xl.gin" \ 13 | --gin_file="instruction_embed.gin" \ 14 | --gin_file="ni_train.gin" \ 15 | --gin_file="partial_train_adafactor.gin" \ 16 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 17 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 18 | --gin.USE_CACHED_TASKS=True \ 19 | --gin.trainer.Trainer.num_microbatches=16 \ 20 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 21 | --gin.BATCH_SIZE=1024 \ 22 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 23 | --gin.TRAIN_STEPS=1101000 \ 24 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 25 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 26 | 27 | echo "Training done. Now evaluating all checkpoints..." 
28 | 29 | 30 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 31 | 32 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 33 | --gin_search_paths="gins" \ 34 | --gin_file="hyper_xl.gin" \ 35 | --gin_file="instruction_embed.gin" \ 36 | --gin_file="ni_eval.gin" \ 37 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 38 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 39 | --gin.USE_CACHED_TASKS=True \ 40 | --gin.utils.DatasetConfig.batch_size=512 \ 41 | --gin.utils.DatasetConfig.split=\"test\" \ 42 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 43 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 44 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 45 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 46 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_no_hnet.sh: -------------------------------------------------------------------------------- 1 | # NI training - no hnet 2 | 3 | # name of experiment folder 4 | EXPERIMENT_NAME=$1 5 | BUCKET_NAME="hamishi-tpu" 6 | 7 | # where model will be saved 8 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 9 | 10 | python3 -m t5x.train \ 11 | --gin_search_paths=gins \ 12 | --gin_file="hyper_base.gin" \ 13 | --gin_file="instruction_embed.gin" \ 14 | --gin_file="ni_train.gin" \ 15 | --gin_file="partial_train_adafactor_dual.gin" \ 16 | --gin.hyper_network.HyperT5Config.use_instructions=False \ 17 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 18 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions_def_pos_2\" \ 19 | --gin.USE_CACHED_TASKS=True \ 20 | --gin.trainer.Trainer.num_microbatches=32 \ 21 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 22 | --gin.BATCH_SIZE=1024 \ 23 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 24 | --gin.TRAIN_STEPS=1101000 \ 25 | --gin.partitioning.PjitPartitioner.num_partitions=2 \ 26 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_base/checkpoint_1100000/\" 27 | 28 | echo "Training done. Now evaluating all checkpoints..." 29 | 30 | 31 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 32 | 33 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 34 | --gin_search_paths="gins" \ 35 | --gin_file="hyper_base.gin" \ 36 | --gin_file="instruction_embed.gin" \ 37 | --gin_file="ni_eval.gin" \ 38 | --gin.hyper_network.HyperT5Config.use_instructions=False \ 39 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 40 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions_def_pos_2\" \ 41 | --gin.USE_CACHED_TASKS=True \ 42 | --gin.utils.DatasetConfig.batch_size=128 \ 43 | --gin.utils.DatasetConfig.split=\"test\" \ 44 | --gin.partitioning.PjitPartitioner.num_partitions=2 \ 45 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 46 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 47 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 48 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_only_adapter_no_fid.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 
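# Illustrative invocation (all names hypothetical): the positional args are
# EXPERIMENT_NAME, LOAD_MODEL, CHECKPOINT and TRAIN_STEPS, e.g.
#   bash scripts/nat_int/ni_train_only_adapter_no_fid.sh adapter_no_fid_run my_pretrain_run checkpoint_1110000 1112000
# Note that $3 is the full checkpoint directory name (it is appended to gs://hamishi-tpu/$2/model/)
# and $4 must be greater than the step count of that checkpoint.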
12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_xl.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="full_restore.gin" \ 24 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 25 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 26 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 27 | --gin.USE_CACHED_TASKS=True \ 28 | --gin.trainer.Trainer.num_microbatches=16 \ 29 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 30 | --gin.BATCH_SIZE=1024 \ 31 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 32 | --gin.TRAIN_STEPS=$4 \ 33 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 34 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 35 | 36 | 37 | echo "Training done. Now evaluating all checkpoints..." 38 | 39 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 40 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 41 | --gin_search_paths="gins" \ 42 | --gin_file="hyper_xl.gin" \ 43 | --gin_file="instruction_embed.gin" \ 44 | --gin_file="ni_eval.gin" \ 45 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 46 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 47 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 48 | --gin.USE_CACHED_TASKS=True \ 49 | --gin.utils.DatasetConfig.batch_size=256 \ 50 | --gin.utils.DatasetConfig.split=\"test\" \ 51 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 52 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 53 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 54 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 55 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_only_lora_no_fid.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 
12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_xl.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="full_restore.gin" \ 24 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 25 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 26 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 27 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 28 | --gin.hyper_network.HyperT5Config.use_lora=True \ 29 | --gin.hyper_network.HyperT5Config.lora_ranks="(512,None,512,None)" \ 30 | --gin.USE_CACHED_TASKS=True \ 31 | --gin.trainer.Trainer.num_microbatches=16 \ 32 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 33 | --gin.BATCH_SIZE=1024 \ 34 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 35 | --gin.TRAIN_STEPS=$4 \ 36 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 37 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 38 | 39 | 40 | echo "Training done. Now evaluating all checkpoints..." 41 | 42 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 43 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 44 | --gin_search_paths="gins" \ 45 | --gin_file="hyper_xl.gin" \ 46 | --gin_file="instruction_embed.gin" \ 47 | --gin_file="ni_eval.gin" \ 48 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 49 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 50 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 51 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 52 | --gin.hyper_network.HyperT5Config.use_lora=True \ 53 | --gin.hyper_network.HyperT5Config.lora_ranks="(512,None,512,None)" \ 54 | --gin.USE_CACHED_TASKS=True \ 55 | --gin.utils.DatasetConfig.batch_size=256 \ 56 | --gin.utils.DatasetConfig.split=\"test\" \ 57 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 58 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 59 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 60 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 61 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_only_lora_no_fid_smaller.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 
12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_xl.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="full_restore.gin" \ 24 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 25 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 26 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 27 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 28 | --gin.hyper_network.HyperT5Config.use_lora=True \ 29 | --gin.hyper_network.HyperT5Config.lora_ranks="(128,None,128,None)" \ 30 | --gin.USE_CACHED_TASKS=True \ 31 | --gin.trainer.Trainer.num_microbatches=16 \ 32 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 33 | --gin.BATCH_SIZE=1024 \ 34 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 35 | --gin.TRAIN_STEPS=$4 \ 36 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 37 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 38 | 39 | 40 | echo "Training done. Now evaluating all checkpoints..." 41 | 42 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 43 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 44 | --gin_search_paths="gins" \ 45 | --gin_file="hyper_xl.gin" \ 46 | --gin_file="instruction_embed.gin" \ 47 | --gin_file="ni_eval.gin" \ 48 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 49 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 50 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 51 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 52 | --gin.hyper_network.HyperT5Config.use_lora=True \ 53 | --gin.hyper_network.HyperT5Config.lora_ranks="(128,None,128,None)" \ 54 | --gin.USE_CACHED_TASKS=True \ 55 | --gin.utils.DatasetConfig.batch_size=256 \ 56 | --gin.utils.DatasetConfig.split=\"test\" \ 57 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 58 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 59 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 60 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 61 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_only_prefix_no_fid.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 
12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_xl.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="full_restore.gin" \ 24 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 25 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 26 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 27 | --gin.USE_CACHED_TASKS=True \ 28 | --gin.trainer.Trainer.num_microbatches=16 \ 29 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 30 | --gin.BATCH_SIZE=1024 \ 31 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 32 | --gin.TRAIN_STEPS=$4 \ 33 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 34 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 35 | 36 | 37 | echo "Training done. Now evaluating all checkpoints..." 38 | 39 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 40 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 41 | --gin_search_paths="gins" \ 42 | --gin_file="hyper_xl.gin" \ 43 | --gin_file="instruction_embed.gin" \ 44 | --gin_file="ni_eval.gin" \ 45 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 46 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 47 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 48 | --gin.USE_CACHED_TASKS=True \ 49 | --gin.utils.DatasetConfig.batch_size=256 \ 50 | --gin.utils.DatasetConfig.split=\"test\" \ 51 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 52 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 53 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 54 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 55 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_pretrained.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_xl.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="full_restore.gin" \ 24 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 25 | --gin.USE_CACHED_TASKS=True \ 26 | --gin.trainer.Trainer.num_microbatches=16 \ 27 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 28 | --gin.BATCH_SIZE=1024 \ 29 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 30 | --gin.TRAIN_STEPS=$4 \ 31 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 32 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 33 | 34 | 35 | echo "Training done. Now evaluating all checkpoints..." 
36 | 37 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 38 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 39 | --gin_search_paths="gins" \ 40 | --gin_file="hyper_xl.gin" \ 41 | --gin_file="instruction_embed.gin" \ 42 | --gin_file="ni_eval.gin" \ 43 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 44 | --gin.USE_CACHED_TASKS=True \ 45 | --gin.utils.DatasetConfig.batch_size=256 \ 46 | --gin.utils.DatasetConfig.split=\"test\" \ 47 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 48 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 49 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 50 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 51 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_pretrained_2pos.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_xl.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="full_restore.gin" \ 24 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instruction_positive_example_hyper_1\" \ 25 | --gin.USE_CACHED_TASKS=True \ 26 | --gin.trainer.Trainer.num_microbatches=16 \ 27 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 28 | --gin.BATCH_SIZE=1024 \ 29 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 30 | --gin.TRAIN_STEPS=$4 \ 31 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 32 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 33 | 34 | 35 | echo "Training done. Now evaluating all checkpoints..." 36 | 37 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 38 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 39 | --gin_search_paths="gins" \ 40 | --gin_file="hyper_xl.gin" \ 41 | --gin_file="instruction_embed.gin" \ 42 | --gin_file="ni_eval.gin" \ 43 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instruction_positive_example_hyper_1\" \ 44 | --gin.USE_CACHED_TASKS=True \ 45 | --gin.utils.DatasetConfig.batch_size=128 \ 46 | --gin.utils.DatasetConfig.split=\"test\" \ 47 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 48 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 49 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 50 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 51 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_pretrained_base.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 
12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_base.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="full_restore.gin" \ 24 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 25 | --gin.USE_CACHED_TASKS=True \ 26 | --gin.trainer.Trainer.num_microbatches=16 \ 27 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 28 | --gin.BATCH_SIZE=1024 \ 29 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 30 | --gin.TRAIN_STEPS=$4 \ 31 | --gin.partitioning.PjitPartitioner.num_partitions=2 \ 32 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 33 | 34 | 35 | echo "Training done. Now evaluating all checkpoints..." 36 | 37 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 38 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 39 | --gin_search_paths="gins" \ 40 | --gin_file="hyper_base.gin" \ 41 | --gin_file="instruction_embed.gin" \ 42 | --gin_file="ni_eval.gin" \ 43 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 44 | --gin.USE_CACHED_TASKS=True \ 45 | --gin.utils.DatasetConfig.batch_size=256 \ 46 | --gin.utils.DatasetConfig.split=\"test\" \ 47 | --gin.partitioning.PjitPartitioner.num_partitions=2 \ 48 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 49 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 50 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 51 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_pretrained_base_2pos.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since it's a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_base.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="full_restore.gin" \ 24 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instruction_positive_example_hyper\" \ 25 | --gin.USE_CACHED_TASKS=True \ 26 | --gin.trainer.Trainer.num_microbatches=16 \ 27 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 28 | --gin.BATCH_SIZE=1024 \ 29 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 30 | --gin.TRAIN_STEPS=$4 \ 31 | --gin.partitioning.PjitPartitioner.num_partitions=2 \ 32 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 33 | 34 | 35 | echo "Training done. Now evaluating all checkpoints..."
36 | 37 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 38 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 39 | --gin_search_paths="gins" \ 40 | --gin_file="hyper_base.gin" \ 41 | --gin_file="instruction_embed.gin" \ 42 | --gin_file="ni_eval.gin" \ 43 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instruction_positive_example_hyper\" \ 44 | --gin.USE_CACHED_TASKS=True \ 45 | --gin.utils.DatasetConfig.batch_size=256 \ 46 | --gin.utils.DatasetConfig.split=\"test\" \ 47 | --gin.partitioning.PjitPartitioner.num_partitions=2 \ 48 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 49 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 50 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 51 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_pretrained_decoder.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_xl.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="hypertune_embed.gin" \ 24 | --gin_file="full_restore.gin" \ 25 | --gin_file="train_only_hnet.gin" \ 26 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 27 | --gin.USE_CACHED_TASKS=True \ 28 | --gin.trainer.Trainer.num_microbatches=128 \ 29 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 30 | --gin.BATCH_SIZE=1024 \ 31 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 32 | --gin.TRAIN_STEPS=$4 \ 33 | --gin.partitioning.PjitPartitioner.num_partitions=16 \ 34 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 35 | 36 | echo "Training done. Now evaluating all checkpoints..." 37 | 38 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 39 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 40 | --gin_search_paths="gins" \ 41 | --gin_file="hyper_xl.gin" \ 42 | --gin_file="instruction_embed.gin" \ 43 | --gin_file="hypertune_embed.gin" \ 44 | --gin_file="ni_eval.gin" \ 45 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 46 | --gin.USE_CACHED_TASKS=True \ 47 | --gin.utils.DatasetConfig.batch_size=256 \ 48 | --gin.utils.DatasetConfig.split=\"test\" \ 49 | --gin.partitioning.PjitPartitioner.num_partitions=16 \ 50 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 51 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 52 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 53 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_pretrained_froz.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 
12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_xl.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="partial_train_adafactor_dual.gin" \ 24 | --gin_file="restore_frozen_under.gin" \ 25 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 26 | --gin.USE_CACHED_TASKS=True \ 27 | --gin.trainer.Trainer.num_microbatches=16 \ 28 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 29 | --gin.BATCH_SIZE=1024 \ 30 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 31 | --gin.TRAIN_STEPS=$4 \ 32 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 33 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 34 | 35 | echo "Training done. Now evaluating all checkpoints..." 36 | 37 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 38 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 39 | --gin_search_paths="gins" \ 40 | --gin_file="hyper_xl.gin" \ 41 | --gin_file="instruction_embed.gin" \ 42 | --gin_file="ni_eval.gin" \ 43 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 44 | --gin.USE_CACHED_TASKS=True \ 45 | --gin.utils.DatasetConfig.batch_size=512 \ 46 | --gin.utils.DatasetConfig.split=\"test\" \ 47 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 48 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 49 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 50 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 51 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_pretrained_just_prefix.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_xl.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="partial_train_adafactor_dual.gin" \ 24 | --gin_file="full_restore.gin" \ 25 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 26 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 27 | --gin.hyper_network.HyperT5Config.num_prefix_tokens=512 \ 28 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 29 | --gin.USE_CACHED_TASKS=True \ 30 | --gin.trainer.Trainer.num_microbatches=16 \ 31 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 32 | --gin.BATCH_SIZE=1024 \ 33 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 34 | --gin.TRAIN_STEPS=$4 \ 35 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 36 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 37 | 38 | echo "Training done. Now evaluating all checkpoints..." 
39 | 40 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 41 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 42 | --gin_search_paths="gins" \ 43 | --gin_file="hyper_xl.gin" \ 44 | --gin_file="instruction_embed.gin" \ 45 | --gin_file="ni_eval.gin" \ 46 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 47 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 48 | --gin.hyper_network.HyperT5Config.num_prefix_tokens=512 \ 49 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 50 | --gin.USE_CACHED_TASKS=True \ 51 | --gin.utils.DatasetConfig.batch_size=512 \ 52 | --gin.utils.DatasetConfig.split=\"test\" \ 53 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 54 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 55 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 56 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 57 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_pretrained_just_prefix_alt.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_xl.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="full_restore.gin" \ 24 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 25 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 26 | --gin.hyper_network.HyperT5Config.num_prefix_tokens=512 \ 27 | --gin.hyper_network.HyperT5Config.share_hnet_encoder=False \ 28 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 29 | --gin.USE_CACHED_TASKS=True \ 30 | --gin.trainer.Trainer.num_microbatches=16 \ 31 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 32 | --gin.BATCH_SIZE=1024 \ 33 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 34 | --gin.TRAIN_STEPS=$4 \ 35 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 36 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 37 | 38 | echo "Training done. Now evaluating all checkpoints..." 
39 | 40 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 41 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 42 | --gin_search_paths="gins" \ 43 | --gin_file="hyper_xl.gin" \ 44 | --gin_file="instruction_embed.gin" \ 45 | --gin_file="ni_eval.gin" \ 46 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 47 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 48 | --gin.hyper_network.HyperT5Config.num_prefix_tokens=512 \ 49 | --gin.hyper_network.HyperT5Config.share_hnet_encoder=False \ 50 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 51 | --gin.USE_CACHED_TASKS=True \ 52 | --gin.utils.DatasetConfig.batch_size=512 \ 53 | --gin.utils.DatasetConfig.split=\"test\" \ 54 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 55 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 56 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 57 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 58 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_pretrained_just_prefix_tanh.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_xl.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="partial_train_adafactor_dual.gin" \ 24 | --gin_file="full_restore.gin" \ 25 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 26 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 27 | --gin.hyper_network.HyperT5Config.use_tanh_prefix=True \ 28 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 29 | --gin.USE_CACHED_TASKS=True \ 30 | --gin.trainer.Trainer.num_microbatches=16 \ 31 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 32 | --gin.BATCH_SIZE=1024 \ 33 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 34 | --gin.TRAIN_STEPS=$4 \ 35 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 36 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 37 | 38 | echo "Training done. Now evaluating all checkpoints..." 
39 | 40 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 41 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 42 | --gin_search_paths="gins" \ 43 | --gin_file="hyper_xl.gin" \ 44 | --gin_file="instruction_embed.gin" \ 45 | --gin_file="ni_eval.gin" \ 46 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 47 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 48 | --gin.hyper_network.HyperT5Config.use_tanh_prefix=True \ 49 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 50 | --gin.USE_CACHED_TASKS=True \ 51 | --gin.utils.DatasetConfig.batch_size=512 \ 52 | --gin.utils.DatasetConfig.split=\"test\" \ 53 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 54 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 55 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 56 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 57 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_pretrained_mimic.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_xl.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="full_restore.gin" \ 24 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions_split_mimic_def\" \ 25 | --gin.USE_CACHED_TASKS=True \ 26 | --gin.trainer.Trainer.num_microbatches=16 \ 27 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 28 | --gin.BATCH_SIZE=1024 \ 29 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 30 | --gin.TRAIN_STEPS=$4 \ 31 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 32 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 33 | 34 | 35 | echo "Training done. Now evaluating all checkpoints..." 36 | 37 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 38 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 39 | --gin_search_paths="gins" \ 40 | --gin_file="hyper_xl.gin" \ 41 | --gin_file="instruction_embed.gin" \ 42 | --gin_file="ni_eval.gin" \ 43 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions_split_mimic_def\" \ 44 | --gin.USE_CACHED_TASKS=True \ 45 | --gin.utils.DatasetConfig.batch_size=256 \ 46 | --gin.utils.DatasetConfig.split=\"test\" \ 47 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 48 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 49 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 50 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 51 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_pretrained_no_fid.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 
12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_xl.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="full_restore.gin" \ 24 | --gin_file="partial_train_adafactor.gin" \ 25 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 26 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 27 | --gin.USE_CACHED_TASKS=True \ 28 | --gin.trainer.Trainer.num_microbatches=16 \ 29 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 30 | --gin.BATCH_SIZE=1024 \ 31 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 32 | --gin.TRAIN_STEPS=$4 \ 33 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 34 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 35 | 36 | 37 | echo "Training done. Now evaluating all checkpoints..." 38 | 39 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 40 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 41 | --gin_search_paths="gins" \ 42 | --gin_file="hyper_xl.gin" \ 43 | --gin_file="instruction_embed.gin" \ 44 | --gin_file="ni_eval.gin" \ 45 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 46 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 47 | --gin.USE_CACHED_TASKS=True \ 48 | --gin.utils.DatasetConfig.batch_size=512 \ 49 | --gin.utils.DatasetConfig.split=\"test\" \ 50 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 51 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 52 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 53 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 54 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_pretrained_our_decoder.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_xl.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="hypertune_embed.gin" \ 24 | --gin_file="full_restore.gin" \ 25 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 26 | --gin.USE_CACHED_TASKS=True \ 27 | --gin.trainer.Trainer.num_microbatches=128 \ 28 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 29 | --gin.BATCH_SIZE=1024 \ 30 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 31 | --gin.TRAIN_STEPS=$4 \ 32 | --gin.partitioning.PjitPartitioner.num_partitions=16 \ 33 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 34 | 35 | echo "Training done. Now evaluating all checkpoints..." 
36 | 37 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 38 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 39 | --gin_search_paths="gins" \ 40 | --gin_file="hyper_xl.gin" \ 41 | --gin_file="instruction_embed.gin" \ 42 | --gin_file="hypertune_embed.gin" \ 43 | --gin_file="ni_eval.gin" \ 44 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 45 | --gin.USE_CACHED_TASKS=True \ 46 | --gin.utils.DatasetConfig.batch_size=128 \ 47 | --gin.utils.DatasetConfig.split=\"test\" \ 48 | --gin.partitioning.PjitPartitioner.num_partitions=16 \ 49 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 50 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 51 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 52 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_pretrained_xxl.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_xxl.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="partial_train_adafactor_dual.gin" \ 24 | --gin_file="full_restore.gin" \ 25 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 26 | --gin.USE_CACHED_TASKS=True \ 27 | --gin.trainer.Trainer.num_microbatches=32 \ 28 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 29 | --gin.BATCH_SIZE=1024 \ 30 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 31 | --gin.TRAIN_STEPS=$4 \ 32 | --gin.partitioning.PjitPartitioner.num_partitions=16 \ 33 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 34 | 35 | echo "Training done. Now evaluating all checkpoints..." 36 | 37 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 38 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 39 | --gin_search_paths="gins" \ 40 | --gin_file="hyper_xxl.gin" \ 41 | --gin_file="instruction_embed.gin" \ 42 | --gin_file="ni_eval.gin" \ 43 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 44 | --gin.USE_CACHED_TASKS=True \ 45 | --gin.utils.DatasetConfig.batch_size=32 \ 46 | --gin.utils.DatasetConfig.split=\"test\" \ 47 | --gin.partitioning.PjitPartitioner.num_partitions=16 \ 48 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 49 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 50 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 51 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_reg.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 
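# "reg" = regular baseline: the HyperT5Config overrides below disable the adapter, prefix,
# and instruction inputs, so this is effectively plain T5 trained on the
# natural_instructions_def_pos_2 mixture (judging by the name, the task definition plus two
# positive examples concatenated into the model input).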
3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | BUCKET_NAME="hamishi-tpu" 7 | 8 | # where model will be saved 9 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 10 | 11 | # we go offline to avoid constant calls to get basic info (happens even when cached) 12 | # for your first run, you will probably need to run all these calls :( 13 | python3 -m t5x.train \ 14 | --gin_search_paths=gins \ 15 | --gin_file="hyper_large.gin" \ 16 | --gin_file="ni_train.gin" \ 17 | --gin_file="partial_train_adafactor.gin" \ 18 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 19 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 20 | --gin.hyper_network.HyperT5Config.use_instructions=False \ 21 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions_def_pos_2\" \ 22 | --gin.USE_CACHED_TASKS=True \ 23 | --gin.trainer.Trainer.num_microbatches=8 \ 24 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 25 | --gin.BATCH_SIZE=1024 \ 26 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 27 | --gin.TRAIN_STEPS=1101000 \ 28 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 29 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_large/checkpoint_1100000/\" \ 30 | 31 | echo "Training done. Now evaluating all checkpoints..." 32 | # gsutil -m cp -r ${MODEL_DIR} gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model 33 | 34 | 35 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 36 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 37 | --gin_search_paths="gins" \ 38 | --gin_file="hyper_large.gin" \ 39 | --gin_file="ni_eval.gin" \ 40 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 41 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 42 | --gin.hyper_network.HyperT5Config.use_instructions=False \ 43 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions_def_pos_2\" \ 44 | --gin.USE_CACHED_TASKS=True \ 45 | --gin.utils.DatasetConfig.batch_size=512 \ 46 | --gin.utils.DatasetConfig.split=\"test\" \ 47 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 48 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 49 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 50 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 51 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_reg_base.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 
3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | BUCKET_NAME="hamishi-tpu" 7 | 8 | # where model will be saved 9 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 10 | 11 | # we go offline to avoid constant calls to get basic info (happens even when cached) 12 | # for your first run, you will probably need to run all these calls :( 13 | python3 -m t5x.train \ 14 | --gin_search_paths=gins \ 15 | --gin_file="hyper_base.gin" \ 16 | --gin_file="ni_train.gin" \ 17 | --gin_file="partial_train_adafactor.gin" \ 18 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 19 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 20 | --gin.hyper_network.HyperT5Config.use_instructions=False \ 21 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions_def_pos_2\" \ 22 | --gin.USE_CACHED_TASKS=True \ 23 | --gin.trainer.Trainer.num_microbatches=8 \ 24 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 25 | --gin.BATCH_SIZE=1024 \ 26 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 27 | --gin.TRAIN_STEPS=1101000 \ 28 | --gin.partitioning.PjitPartitioner.num_partitions=2 \ 29 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_base/checkpoint_1100000/\" \ 30 | 31 | echo "Training done. Now evaluating all checkpoints..." 32 | # gsutil -m cp -r ${MODEL_DIR} gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model 33 | 34 | 35 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 36 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 37 | --gin_search_paths="gins" \ 38 | --gin_file="hyper_base.gin" \ 39 | --gin_file="ni_eval.gin" \ 40 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 41 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 42 | --gin.hyper_network.HyperT5Config.use_instructions=False \ 43 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions_def_pos_2\" \ 44 | --gin.USE_CACHED_TASKS=True \ 45 | --gin.utils.DatasetConfig.batch_size=512 \ 46 | --gin.utils.DatasetConfig.split=\"test\" \ 47 | --gin.partitioning.PjitPartitioner.num_partitions=2 \ 48 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 49 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 50 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 51 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_reg_xxl.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 
3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | BUCKET_NAME="hamishi-tpu" 7 | 8 | # where model will be saved 9 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 10 | 11 | # we go offline to avoid constant calls to get basic info (happens even when cached) 12 | # for your first run, you will probably need to run all these calls :( 13 | python3 -m t5x.train \ 14 | --gin_search_paths=gins \ 15 | --gin_file="hyper_xxl.gin" \ 16 | --gin_file="ni_train.gin" \ 17 | --gin.hyper_network.HyperT5Config.hyperencoder_model=\"google/t5-base-lm-adapt\" \ 18 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 19 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 20 | --gin.hyper_network.HyperT5Config.use_instructions=False \ 21 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions_def\" \ 22 | --gin.USE_CACHED_TASKS=True \ 23 | --gin.trainer.Trainer.num_microbatches=32 \ 24 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 25 | --gin.BATCH_SIZE=1024 \ 26 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 27 | --gin.TRAIN_STEPS=1101000 \ 28 | --gin.partitioning.PjitPartitioner.num_partitions=16 \ 29 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xxl/checkpoint_1100000/\" \ 30 | 31 | echo "Training done. Now evaluating all checkpoints..." 32 | # gsutil -m cp -r ${MODEL_DIR} gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model 33 | 34 | 35 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 36 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 37 | --gin_search_paths="gins" \ 38 | --gin_file="hyper_xxl.gin" \ 39 | --gin_file="ni_eval.gin" \ 40 | --gin.hyper_network.HyperT5Config.hyperencoder_model=\"google/t5-base-lm-adapt\" \ 41 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 42 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 43 | --gin.hyper_network.HyperT5Config.use_instructions=False \ 44 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions_def\" \ 45 | --gin.USE_CACHED_TASKS=True \ 46 | --gin.utils.DatasetConfig.batch_size=512 \ 47 | --gin.utils.DatasetConfig.split=\"test\" \ 48 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 49 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 50 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 51 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 52 | -------------------------------------------------------------------------------- /scripts/nat_int/ni_train_sep_encoder_no_fid.sh: -------------------------------------------------------------------------------- 1 | # NI training 2 | # eval after since its a short experiment. 3 | 4 | # name of experiment folder 5 | EXPERIMENT_NAME=$1 6 | LOAD_MODEL=$2 7 | CHECKPOINT=$3 8 | TRAIN_STEPS=$4 9 | BUCKET_NAME="hamishi-tpu" 10 | 11 | echo "Make sure train steps is set to > the checkpoint!" 
12 | 13 | # where model will be saved 14 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 15 | 16 | # we go offline to avoid constant calls to get basic info (happens even when cached) 17 | # for your first run, you will probably need to run all these calls :( 18 | python3 -m t5x.train \ 19 | --gin_search_paths=gins \ 20 | --gin_file="hyper_xl.gin" \ 21 | --gin_file="instruction_embed.gin" \ 22 | --gin_file="ni_train.gin" \ 23 | --gin_file="full_restore.gin" \ 24 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 25 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 26 | --gin.hyper_network.HyperT5Config.share_hnet_encoder=False \ 27 | --gin.USE_CACHED_TASKS=True \ 28 | --gin.trainer.Trainer.num_microbatches=16 \ 29 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 30 | --gin.BATCH_SIZE=1024 \ 31 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 32 | --gin.TRAIN_STEPS=$4 \ 33 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 34 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 35 | 36 | 37 | echo "Training done. Now evaluating all checkpoints..." 38 | 39 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval/" 40 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 41 | --gin_search_paths="gins" \ 42 | --gin_file="hyper_xl.gin" \ 43 | --gin_file="instruction_embed.gin" \ 44 | --gin_file="ni_eval.gin" \ 45 | --gin.MIXTURE_OR_TASK_NAME=\"natural_instructions\" \ 46 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 47 | --gin.hyper_network.HyperT5Config.share_hnet_encoder=False \ 48 | --gin.USE_CACHED_TASKS=True \ 49 | --gin.utils.DatasetConfig.batch_size=256 \ 50 | --gin.utils.DatasetConfig.split=\"test\" \ 51 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 52 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 53 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" \ 54 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" 55 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain.gin" \ 12 | --gin_file="partial_train_adafactor_dual.gin" \ 13 | --gin.USE_CACHED_TASKS=True \ 14 | --gin.trainer.Trainer.num_microbatches=8 \ 15 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 16 | --gin.BATCH_SIZE=1024 \ 17 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 18 | --gin.TRAIN_STEPS=1110000 \ 19 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 20 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 21 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain_base.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_base.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain.gin" \ 12 | --gin_file="partial_train_adafactor_dual.gin" \ 13 | --gin.USE_CACHED_TASKS=True \ 14 | 
--gin.trainer.Trainer.num_microbatches=8 \ 15 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 16 | --gin.BATCH_SIZE=1024 \ 17 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 18 | --gin.TRAIN_STEPS=2000000 \ 19 | --gin.partitioning.PjitPartitioner.num_partitions=2 \ 20 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_base/checkpoint_1100000/\" 21 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain_decoder.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain.gin" \ 12 | --gin_file="train_only_hnet.gin" \ 13 | --gin_file="hypertune.gin" \ 14 | --gin.USE_CACHED_TASKS=True \ 15 | --gin.trainer.Trainer.num_microbatches=8 \ 16 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 17 | --gin.BATCH_SIZE=1024 \ 18 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 19 | --gin.TRAIN_STEPS=1120000 \ 20 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 21 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 22 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain_hnet_only.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain.gin" \ 12 | --gin_file="train_only_hnet.gin" \ 13 | --gin.USE_CACHED_TASKS=True \ 14 | --gin.trainer.Trainer.num_microbatches=8 \ 15 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 16 | --gin.BATCH_SIZE=1024 \ 17 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 18 | --gin.TRAIN_STEPS=2000000 \ 19 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 20 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 21 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain_htune.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain_4part.gin" \ 12 | --gin_file="train_only_hnet.gin" \ 13 | --gin_file="hypertune.gin" \ 14 | --gin.USE_CACHED_TASKS=True \ 15 | --gin.trainer.Trainer.num_microbatches=8 \ 16 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 17 | --gin.BATCH_SIZE=1024 \ 18 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 19 | --gin.TRAIN_STEPS=1120000 \ 20 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 21 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 22 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain_lora.sh: 
-------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain.gin" \ 12 | --gin_file="partial_train_adafactor.gin" \ 13 | --gin.USE_CACHED_TASKS=True \ 14 | --gin.trainer.Trainer.num_microbatches=8 \ 15 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 16 | --gin.BATCH_SIZE=1024 \ 17 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 18 | --gin.TRAIN_STEPS=1120000 \ 19 | --gin.hyper_network.HyperT5Config.use_lora=True \ 20 | --gin.hyper_network.HyperT5Config.lora_ranks="(64,None,64,None)" \ 21 | --gin.hyper/hyper_utils.match_any.regexes="[\".*/hyper/[^e].*\", \".*/hyper/lora*.*\"]" \ 22 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 23 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 24 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain_non_layer.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain_6part.gin" \ 12 | --gin_file="partial_train_adafactor_dual.gin" \ 13 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 14 | --gin.USE_CACHED_TASKS=True \ 15 | --gin.trainer.Trainer.num_microbatches=8 \ 16 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 17 | --gin.BATCH_SIZE=1024 \ 18 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 19 | --gin.TRAIN_STEPS=1120000 \ 20 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 21 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 22 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain_only_adapter_no_fid.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain.gin" \ 12 | --gin_file="partial_train_adafactor_dual.gin" \ 13 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 14 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 15 | --gin.USE_CACHED_TASKS=True \ 16 | --gin.trainer.Trainer.num_microbatches=8 \ 17 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 18 | --gin.BATCH_SIZE=1024 \ 19 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 20 | --gin.TRAIN_STEPS=1120000 \ 21 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 22 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 23 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain_only_lora_no_fid.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 
| BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain.gin" \ 12 | --gin_file="partial_train_adafactor_dual.gin" \ 13 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 14 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 15 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 16 | --gin.hyper_network.HyperT5Config.use_lora=True \ 17 | --gin.hyper_network.HyperT5Config.lora_ranks="(512,None,512,None)" \ 18 | --gin.USE_CACHED_TASKS=True \ 19 | --gin.trainer.Trainer.num_microbatches=8 \ 20 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 21 | --gin.BATCH_SIZE=1024 \ 22 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 23 | --gin.TRAIN_STEPS=1120000 \ 24 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 25 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 26 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain_only_lora_no_fid_smaller.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain.gin" \ 12 | --gin_file="partial_train_adafactor_dual.gin" \ 13 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 14 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 15 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 16 | --gin.hyper_network.HyperT5Config.use_lora=True \ 17 | --gin.hyper_network.HyperT5Config.lora_ranks="(128,None,128,None)" \ 18 | --gin.USE_CACHED_TASKS=True \ 19 | --gin.trainer.Trainer.num_microbatches=8 \ 20 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 21 | --gin.BATCH_SIZE=1024 \ 22 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 23 | --gin.TRAIN_STEPS=1120000 \ 24 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 25 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 26 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain_only_prefix_no_fid.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain.gin" \ 12 | --gin_file="partial_train_adafactor_dual_frozen_under.gin" \ 13 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 14 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 15 | --gin.hyper_network.HyperT5Config.num_prefix_tokens=512 \ 16 | --gin.hyper_network.HyperT5Config.share_hnet_encoder=False \ 17 | --gin.USE_CACHED_TASKS=True \ 18 | --gin.trainer.Trainer.num_microbatches=8 \ 19 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 20 | --gin.BATCH_SIZE=1024 \ 21 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 22 | --gin.TRAIN_STEPS=1120000 \ 23 | 
--gin.partitioning.PjitPartitioner.num_partitions=8 \ 24 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 25 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain_only_prefix_no_fid_4_way.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain_4part.gin" \ 12 | --gin_file="partial_train_adafactor_dual.gin" \ 13 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 14 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 15 | --gin.USE_CACHED_TASKS=True \ 16 | --gin.trainer.Trainer.num_microbatches=8 \ 17 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 18 | --gin.BATCH_SIZE=1024 \ 19 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 20 | --gin.TRAIN_STEPS=1120000 \ 21 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 22 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 23 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain_only_prefix_no_fid_alt.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain.gin" \ 12 | --gin_file="partial_train_adafactor_dual.gin" \ 13 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 14 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 15 | --gin.hyper_network.HyperT5Config.use_tanh_prefix=True \ 16 | --gin.USE_CACHED_TASKS=True \ 17 | --gin.trainer.Trainer.num_microbatches=8 \ 18 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 19 | --gin.BATCH_SIZE=1024 \ 20 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 21 | --gin.TRAIN_STEPS=1120000 \ 22 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 23 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 24 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain_our_decoder.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain.gin" \ 12 | --gin_file="hypertune_full_train.gin" \ 13 | --gin.USE_CACHED_TASKS=True \ 14 | --gin.trainer.Trainer.num_microbatches=32 \ 15 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 16 | --gin.BATCH_SIZE=1024 \ 17 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 18 | --gin.TRAIN_STEPS=1120000 \ 19 | --gin.partitioning.PjitPartitioner.num_partitions=16 \ 20 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 21 | 
-------------------------------------------------------------------------------- /scripts/pretraining/pretrain_prefix_just_prefix.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain_4part.gin" \ 12 | --gin_file="partial_train_adafactor_dual.gin" \ 13 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 14 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 15 | --gin.USE_CACHED_TASKS=True \ 16 | --gin.trainer.Trainer.num_microbatches=8 \ 17 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 18 | --gin.BATCH_SIZE=1024 \ 19 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 20 | --gin.TRAIN_STEPS=1120000 \ 21 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 22 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 23 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain_prefix_no_fid.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain.gin" \ 12 | --gin_file="partial_train_adafactor_dual.gin" \ 13 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 14 | --gin.USE_CACHED_TASKS=True \ 15 | --gin.trainer.Trainer.num_microbatches=8 \ 16 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 17 | --gin.BATCH_SIZE=1024 \ 18 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 19 | --gin.TRAIN_STEPS=1120000 \ 20 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 21 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 22 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain_sep_enc.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain.gin" \ 12 | --gin_file="separate_henc.gin" \ 13 | --gin_file="partial_train_adafactor.gin" \ 14 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 15 | --gin.USE_CACHED_TASKS=True \ 16 | --gin.trainer.Trainer.num_microbatches=8 \ 17 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 18 | --gin.BATCH_SIZE=1024 \ 19 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 20 | --gin.TRAIN_STEPS=2000000 \ 21 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 22 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 23 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain_six.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | 
BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain_6part.gin" \ 12 | --gin_file="partial_train_adafactor_dual.gin" \ 13 | --gin.hyper_network.HyperT5Config.use_fusion_in_decoder=False \ 14 | --gin.USE_CACHED_TASKS=True \ 15 | --gin.trainer.Trainer.num_microbatches=8 \ 16 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 17 | --gin.BATCH_SIZE=1024 \ 18 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 19 | --gin.TRAIN_STEPS=1120000 \ 20 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 21 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000/\" 22 | -------------------------------------------------------------------------------- /scripts/pretraining/pretrain_xxl.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME=$1 2 | BUCKET_NAME="hamishi-tpu" 3 | 4 | # where model will be saved 5 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 6 | 7 | python3 -m t5x.train \ 8 | --gin_search_paths=gins \ 9 | --gin_file="hyper_xxl.gin" \ 10 | --gin_file="instruction_embed.gin" \ 11 | --gin_file="pretrain.gin" \ 12 | --gin_file="partial_train_adafactor_dual.gin" \ 13 | --gin.USE_CACHED_TASKS=True \ 14 | --gin.trainer.Trainer.num_microbatches=32 \ 15 | --gin.utils.create_learning_rate_scheduler.warmup_steps=100 \ 16 | --gin.BATCH_SIZE=1024 \ 17 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 18 | --gin.TRAIN_STEPS=1120000 \ 19 | --gin.partitioning.PjitPartitioner.num_partitions=16 \ 20 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xxl/checkpoint_1100000/\" 21 | -------------------------------------------------------------------------------- /scripts/t0_few_shot/t0_xshot_eval_hint.sh: -------------------------------------------------------------------------------- 1 | # name of experiment folder 2 | EXPERIMENT_NAME=$1 3 | SHOT=$2 # must be 1, 2, 4, 5 4 | BUCKET_NAME="hamishi-tpu" 5 | 6 | # where model will be saved 7 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 8 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval" 9 | 10 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 11 | --gin_search_paths="gins" \ 12 | --gin_file="hyper_xl.gin" \ 13 | --gin_file="t0_eval.gin" \ 14 | --gin.USE_CACHED_TASKS=True \ 15 | --gin.utils.DatasetConfig.batch_size=128 \ 16 | --gin.MIXTURE_OR_TASK_NAME=\"t0_eval_score_eval_${SHOT}_shot\" \ 17 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 18 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" \ 19 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" 20 | -------------------------------------------------------------------------------- /scripts/t0_few_shot/t0_xshot_eval_reg.sh: -------------------------------------------------------------------------------- 1 | # name of experiment folder 2 | EXPERIMENT_NAME=$1 3 | SHOT=$2 # must be 1, 2, 4, 5 4 | BUCKET_NAME="hamishi-tpu" 5 | 6 | # where model will be saved 7 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 8 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval" 9 | 10 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 11 | --gin_search_paths="gins" \ 12 | --gin_file="hyper_xl.gin" \ 13 | --gin_file="t0_eval.gin" \ 14 | --gin.USE_CACHED_TASKS=True \ 15 | 
--gin.hyper_network.HyperT5Config.use_adapter=False \ 16 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 17 | --gin.hyper_network.HyperT5Config.use_instructions=False \ 18 | --gin.utils.DatasetConfig.batch_size=128 \ 19 | --gin.MIXTURE_OR_TASK_NAME=\"t0_eval_score_eval_${SHOT}_shot\" \ 20 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 21 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" \ 22 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" 23 | -------------------------------------------------------------------------------- /scripts/t0_few_shot/t0_xshot_train_hint.sh: -------------------------------------------------------------------------------- 1 | # name of experiment folder 2 | EXPERIMENT_NAME=$1 3 | SHOT=$2 # must be 1, 2, 4, 5 4 | BUCKET_NAME="hamishi-tpu" 5 | 6 | # where model will be saved 7 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 8 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval" 9 | 10 | # we go offline to avoid constant calls to get basic info (happens even when cached) 11 | # for your first run, you will probably need to run all these calls :( 12 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.train \ 13 | --gin_search_paths=gins \ 14 | --gin_file="hyper_xl.gin" \ 15 | --gin_file="t0_train.gin" \ 16 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 17 | --gin.MIXTURE_OR_TASK_NAME=\"t0_train_${SHOT}_shot\" \ 18 | --gin.TRAIN_STEPS=1110000 \ 19 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 20 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000\" 21 | 22 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 23 | --gin_search_paths="gins" \ 24 | --gin_file="hyper_xl.gin" \ 25 | --gin_file="t0_eval.gin" \ 26 | --gin.USE_CACHED_TASKS=True \ 27 | --gin.utils.DatasetConfig.batch_size=128 \ 28 | --gin.MIXTURE_OR_TASK_NAME=\"t0_eval_score_eval_${SHOT}_shot\" \ 29 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 30 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" \ 31 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" 32 | -------------------------------------------------------------------------------- /scripts/t0_few_shot/t0_xshot_train_reg.sh: -------------------------------------------------------------------------------- 1 | # name of experiment folder 2 | EXPERIMENT_NAME=$1 3 | SHOT=$2 # must be 1, 2, 4, 5 4 | BUCKET_NAME="hamishi-tpu" 5 | 6 | # where model will be saved 7 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 8 | EVAL_OUTPUT_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/eval" 9 | 10 | # we go offline to avoid constant calls to get basic info (happens even when cached) 11 | # for your first run, you will probably need to run all these calls :( 12 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.train \ 13 | --gin_search_paths=gins \ 14 | --gin_file="hyper_xl.gin" \ 15 | --gin_file="t0_train.gin" \ 16 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 17 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 18 | --gin.hyper_network.HyperT5Config.use_instructions=False \ 19 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 20 | --gin.MIXTURE_OR_TASK_NAME=\"t0_train_${SHOT}_shot\" \ 21 | --gin.TRAIN_STEPS=1110000 \ 22 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 23 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000\" 24 | 25 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.eval \ 26 | --gin_search_paths="gins" \ 27 | --gin_file="hyper_xl.gin" \ 28 | --gin_file="t0_eval.gin" \ 29 | 
--gin.USE_CACHED_TASKS=True \ 30 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 31 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 32 | --gin.hyper_network.HyperT5Config.use_instructions=False \ 33 | --gin.utils.DatasetConfig.batch_size=128 \ 34 | --gin.MIXTURE_OR_TASK_NAME=\"t0_eval_score_eval_${SHOT}_shot\" \ 35 | --gin.CHECKPOINT_PATH=\"$MODEL_DIR\" \ 36 | --gin.EVAL_OUTPUT_DIR=\"$EVAL_OUTPUT_DIR\" \ 37 | --gin.utils.RestoreCheckpointConfig.mode=\"all\" 38 | -------------------------------------------------------------------------------- /scripts/t0_reg_train.sh: -------------------------------------------------------------------------------- 1 | # name of experiment folder 2 | EXPERIMENT_NAME=$1 3 | BUCKET_NAME="hamishi-tpu" 4 | 5 | # where model will be saved 6 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 7 | 8 | # we go offline to avoid constant calls to get basic info (happens even when cached) 9 | # for your first run, you will probably need to run all these calls :( 10 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.train \ 11 | --gin_search_paths=gins \ 12 | --gin_file="hyper_xl.gin" \ 13 | --gin_file="t0_train.gin" \ 14 | --gin.hyper_network.HyperT5Config.use_adapter=False \ 15 | --gin.hyper_network.HyperT5Config.use_prefix=False \ 16 | --gin.hyper_network.HyperT5Config.use_instructions=False \ 17 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 18 | --gin.TRAIN_STEPS=1212200 \ 19 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 20 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000\" 21 | -------------------------------------------------------------------------------- /scripts/t0_train_hypter.sh: -------------------------------------------------------------------------------- 1 | # name of experiment folder 2 | EXPERIMENT_NAME=$1 3 | BUCKET_NAME="hamishi-tpu" 4 | 5 | # where model will be saved 6 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 7 | 8 | # we go offline to avoid constant calls to get basic info (happens even when cached) 9 | # for your first run, you will probably need to run all these calls :( 10 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.train \ 11 | --gin_search_paths=gins \ 12 | --gin_file="hyper_xl.gin" \ 13 | --gin_file="instruction_embed.gin" \ 14 | --gin_file="t0_train.gin" \ 15 | --gin_file="partial_train_adafactor_no_roberta.gin" \ 16 | --gin_file="hypter.gin" \ 17 | --gin.hyper_network.HyperT5Config.adapter_size=4 \ 18 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 19 | --gin.TRAIN_STEPS=1120000 \ 20 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 21 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000\" 22 | -------------------------------------------------------------------------------- /scripts/tpu_setup.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=${PYTHONPATH}:${PWD} 2 | # setup t5x (important) 3 | git clone https://github.com/google-research/t5x.git 4 | cd t5x 5 | git checkout 3282da46b4a7e46bc17b96cdb6673a4dd812a1b6 6 | # no deps as t5x used un-pinned library versions from github 7 | # we have the pinned versions in requirements.txt 8 | python3 -m pip install -e '.[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --no-deps 9 | cd .. 
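# (t5x is pinned to a specific commit on purpose: the gin files in this repo reference
#  t5x internals, and newer t5x revisions may not be compatible.)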
10 | # install deps 11 | python3 -m pip install -r requirements.txt --upgrade --no-deps 12 | # install tpu-specific jax 13 | python3 -m pip install "jax[tpu]==0.3.23" --upgrade -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 14 | echo "----- ALL DEPENDENCIES INSTALLED -----" 15 | # optional: caching. You should do this for the T0 tasks if you haven't, 16 | # but its not needed for the SNI tasks. 17 | # we cache the tokenizers / HF splits used so we don't have to load them later. 18 | # This can take ~15 minutes. 19 | # python3 -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('t5-base')" 20 | # python3 -c "from transformers import AutoModel; AutoModel.from_pretrained('google/t5-large-lm-adapt')" 21 | # python3 -c "from transformers import AutoModel; AutoModel.from_pretrained('google/t5-small-lm-adapt')" 22 | # TRANSFORMERS_OFFLINE=1 python3 -c "import hyper_task_descriptions.seqio_tasks.all_t0_tasks" 23 | # echo "----- CACHED TOKENIZERS AND SPLITS -----" 24 | # and we are done! 25 | echo "----- TPU SETUP COMPLETE -----" 26 | -------------------------------------------------------------------------------- /scripts/train_from_pretrained.sh: -------------------------------------------------------------------------------- 1 | # name of experiment folder 2 | EXPERIMENT_NAME=$1 3 | LOAD_MODEL=$2 4 | CHECKPOINT=$3 5 | TRAIN_STEPS=$4 6 | BUCKET_NAME="hamishi-tpu" 7 | 8 | # where model will be saved 9 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 10 | 11 | # we go offline to avoid constant calls to get basic info (happens even when cached) 12 | # for your first run, you will probably need to run all these calls :( 13 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.train \ 14 | --gin_search_paths=gins \ 15 | --gin_file="hyper_xl.gin" \ 16 | --gin_file="t0_train.gin" \ 17 | --gin_file="partial_train_adafactor_dual.gin" \ 18 | --gin_file="full_restore.gin" \ 19 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 20 | --gin.TRAIN_STEPS=$4 \ 21 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 22 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://${BUCKET_NAME}/$2/model/$3\" 23 | -------------------------------------------------------------------------------- /scripts/train_from_t5.sh: -------------------------------------------------------------------------------- 1 | # name of experiment folder 2 | EXPERIMENT_NAME=$1 3 | BUCKET_NAME="hamishi-tpu" 4 | 5 | # where model will be saved 6 | MODEL_DIR="gs://${BUCKET_NAME}/${EXPERIMENT_NAME}/model" 7 | 8 | # we go offline to avoid constant calls to get basic info (happens even when cached) 9 | # for your first run, you will probably need to run all these calls :( 10 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python3 -m t5x.train \ 11 | --gin_search_paths=gins \ 12 | --gin_file="hyper_xl.gin" \ 13 | --gin_file="t0_train.gin" \ 14 | --gin_file="partial_train_adafactor_dual.gin" \ 15 | --gin.MODEL_DIR=\"${MODEL_DIR}\" \ 16 | --gin.TRAIN_STEPS=1212200 \ 17 | --gin.partitioning.PjitPartitioner.num_partitions=8 \ 18 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000\" 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | 4 | def read_requirements(filename: str): 5 | with open(filename) as requirements_file: 6 | import re 7 | 8 | def fix_url_dependencies(req: str) -> str: 9 | """Pip 
10 |             m = re.match(
11 |                 r"^(git\+)?(https|ssh)://(git@)?github\.com/([\w-]+)/(?P<name>[\w-]+)\.git", req
12 |             )
13 |             if m is None:
14 |                 return req
15 |             else:
16 |                 return f"{m.group('name')} @ {req}"
17 |
18 |         requirements = []
19 |         for line in requirements_file:
20 |             line = line.strip()
21 |             if line.startswith("#") or len(line) <= 0:
22 |                 continue
23 |             requirements.append(fix_url_dependencies(line))
24 |     return requirements
25 |
26 |
27 | # version.py defines the VERSION and VERSION_SHORT variables.
28 | # We use exec here so we don't import hyper_task_descriptions whilst setting up.
29 | VERSION = {}  # type: ignore
30 | with open("hyper_task_descriptions/version.py", "r") as version_file:
31 |     exec(version_file.read(), VERSION)
32 |
33 | setup(
34 |     name="hyper_task_descriptions",
35 |     version=VERSION["VERSION"],
36 |     description="",
37 |     long_description=open("README.md").read(),
38 |     long_description_content_type="text/markdown",
39 |     classifiers=[
40 |         "Intended Audience :: Science/Research",
41 |         "Development Status :: 3 - Alpha",
42 |         "License :: OSI Approved :: Apache Software License",
43 |         "Programming Language :: Python :: 3",
44 |         "Topic :: Scientific/Engineering :: Artificial Intelligence",
45 |     ],
46 |     keywords="",
47 |     url="https://github.com/allenai/hyper_task_descriptions",
48 |     author="Allen Institute for Artificial Intelligence",
49 |     author_email="contact@allenai.org",
50 |     license="Apache",
51 |     packages=find_packages(
52 |         exclude=["*.tests", "*.tests.*", "tests.*", "tests"],
53 |     ),
54 |     package_data={
55 |         "hyper_task_descriptions": [
56 |             "seqio_tasks/datasets.csv",
57 |             "seqio_tasks/all_t0_task_prefixes.txt",
58 |             "seqio_tasks/all_edited_prompts.txt",
59 |         ],
60 |         "": ["requirements.txt", "dev-requirements.txt"],
61 |     },
62 |     install_requires=read_requirements("requirements.txt"),
63 |     extras_require={
64 |         "dev": read_requirements("dev-requirements.txt"),
65 |         "catwalk": ["catwalk @ git+https://github.com/allenai/catwalk.git"],
66 |     },
67 |     python_requires=">=3.7",
68 | )
69 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/hyper-task-descriptions/fd86a43ac2131582548130d86d57ea977c804ab6/tests/__init__.py
--------------------------------------------------------------------------------
/tests/hello_test.py:
--------------------------------------------------------------------------------
 1 | def test_hello():
 2 |     print("Hello, World!")
 3 |
--------------------------------------------------------------------------------
/tests/modeling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/hyper-task-descriptions/fd86a43ac2131582548130d86d57ea977c804ab6/tests/modeling/__init__.py
--------------------------------------------------------------------------------
/tests/modeling/hyper_network_test.py:
--------------------------------------------------------------------------------
 1 | import jax
 2 | import numpy as np
 3 | from absl.testing import parameterized
 4 |
 5 | from hyper_task_descriptions.common.testing import get_test_model
 6 |
 7 |
 8 | class NetworkTest(parameterized.TestCase):
 9 |     def setUp(self):
10 |         super().setUp()
11 |         batch_size, max_decode_len, input_len, hyper_input_len = 2, 3, 4, 5
12 |         self.input_shapes = {
13 |             "encoder_input_tokens": (batch_size, input_len),
14 |             "hyper_encoder_input_tokens": (batch_size, hyper_input_len),
15 |             "decoder_input_tokens": (batch_size, max_decode_len),
16 |         }
17 |         np.random.seed(42)
18 |         self.batch = {
19 |             "encoder_input_tokens": np.random.randint(3, 10, size=(batch_size, input_len)),
20 |             "hyper_encoder_input_tokens": np.random.randint(
21 |                 3, 10, size=(batch_size, hyper_input_len)
22 |             ),
23 |             "decoder_input_tokens": np.random.randint(3, 10, size=(batch_size, max_decode_len)),
24 |             "decoder_target_tokens": np.random.randint(3, 10, size=(batch_size, max_decode_len)),
25 |         }
26 |
27 |     def test_t5_1_1_regression(self):
28 |         np.random.seed(0)
29 |         batch_size, max_decode_len, input_len, hyper_input_len = 2, 3, 4, 5
30 |         batch = {
31 |             "encoder_input_tokens": np.random.randint(3, 10, size=(batch_size, input_len)),
32 |             "hyper_encoder_input_tokens": np.random.randint(
33 |                 3, 10, size=(batch_size, hyper_input_len)
34 |             ),
35 |             "decoder_input_tokens": np.random.randint(3, 10, size=(batch_size, max_decode_len)),
36 |             "decoder_target_tokens": np.random.randint(3, 10, size=(batch_size, max_decode_len)),
37 |         }
38 |         model = get_test_model(
39 |             emb_dim=13,
40 |             head_dim=16,
41 |             num_heads=8,
42 |             mlp_dim=32,
43 |             vocab_size=10,
44 |             num_encoder_layers=1,
45 |             num_decoder_layers=1,
46 |         )
47 |         params = model.get_initial_variables(jax.random.PRNGKey(42), self.input_shapes)["params"]
48 |         loss, _ = jax.jit(model.loss_fn)(params, batch, jax.random.PRNGKey(1))
49 |         self.assertAlmostEqual(loss, 15.268721, delta=0.05)
50 |
51 |         predicted, scores = model.predict_batch_with_aux(params, batch)
52 |         # predicted.shape = 2 x 3 (batch_size x max_decode_len) (best option)
53 |         np.testing.assert_array_equal(predicted, [[2, 6, 1], [2, 6, 5]])
54 |         # scores.shape = 2 (batch_size) (best option)
55 |         np.testing.assert_allclose(scores["scores"], [-3.501333, -2.825637], rtol=1e-3)
56 |
--------------------------------------------------------------------------------
/tests/modeling/losses_test.py:
--------------------------------------------------------------------------------
 1 | import jax.numpy as jnp
 2 |
 3 | from hyper_task_descriptions.modeling.losses import (  # safe_norm,
 4 |     cosine_similarity,
 5 |     cosine_similarity_loss,
 6 |     cosine_similarity_one_to_many,
 7 | )
 8 |
 9 |
10 | def test_cosine_similarity():
11 |
12 |     preds = jnp.array([1, 2, 3])
13 |     targets = jnp.array([1, 2, 3])
14 |     assert jnp.allclose(cosine_similarity(preds, targets), jnp.array(1))
15 |
16 |     targets = jnp.array([-1, 0, 1])
17 |     # [1, 2, 3] dot [-1, 0, 1] = -1 + 0 + 3 = 2
18 |     # ||[1, 2, 3]|| = sqrt(14)
19 |     # ||[-1, 0, 1]|| = sqrt(2)
20 |     # cosine_sim = 2/(sqrt(14)*sqrt(2)) = 1/sqrt(7) ~= 0.37796444
21 |     assert jnp.allclose(cosine_similarity(preds, targets), jnp.array(0.37796444))
22 |
23 |
24 | def test_cosine_similarity_one_to_many():
25 |
26 |     preds = jnp.array([1, 2, 3])
27 |     targets = jnp.array([[1, 2, 3], [-1, 0, 1]])
28 |     assert jnp.allclose(cosine_similarity_one_to_many(preds, targets), jnp.array([1, 0.37796444]))
29 |
30 |
31 | def test_cosine_similarity_loss():
32 |     preds = jnp.array([[1, 2, 3], [1, 2, 3]])
33 |     targets = jnp.array([[1, 2, 3], [-1, 0, 1]])
34 |     gt_sim = jnp.array([1, 0.37796444])
35 |
36 |     assert jnp.allclose(cosine_similarity_loss(preds, targets, gt_sim), jnp.array(0))
37 |
38 |     gt_sim = jnp.array([0.5, 0.37796444])
39 |     expected_loss = ((1 - 0.5) ** 2 + 0) / 2
40 |     assert jnp.allclose(cosine_similarity_loss(preds, targets, gt_sim), jnp.array(expected_loss))
41 |
42 |
43 | def test_safe_norm():
44 |     pass
45 |
--------------------------------------------------------------------------------