├── .github ├── ISSUE_TEMPLATE.md └── workflows │ ├── build.yml │ └── close-pull-request.yml ├── CONTRIBUTING ├── LICENSE ├── MANIFEST.in ├── README.md ├── implementation.md ├── prompt_tuning ├── __init__.py ├── configs │ ├── __init__.py │ ├── architectures │ │ ├── __init__.py │ │ ├── prompt_encoder_t5_1_1_flaxformer.gin │ │ └── t5_1_1_flaxformer.gin │ ├── extended │ │ ├── __init__.py │ │ ├── architectures │ │ │ ├── __init__.py │ │ │ ├── ia3_t5_1_1_flaxformer.gin │ │ │ ├── multi_task_prompt_encoder_t5_1_1_flaxformer.gin │ │ │ └── per_layer_prompt_encoder_t5_1_1_flaxformer.gin │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── ia3_t5_1_1.gin │ │ │ ├── mt5_per_layer_prompt.gin │ │ │ ├── multi_task_t5_1_1_base_prompt.gin │ │ │ ├── multi_task_t5_1_1_large_prompt.gin │ │ │ ├── multi_task_t5_1_1_prompt.gin │ │ │ ├── multi_task_t5_1_1_small_prompt.gin │ │ │ ├── multi_task_t5_1_1_xl_prompt.gin │ │ │ ├── multi_task_t5_1_1_xxl_prompt.gin │ │ │ ├── t5_1_1_per_layer_prompt.gin │ │ │ ├── wayward_t5_1_1_base_prompt.gin │ │ │ ├── wayward_t5_1_1_large_prompt.gin │ │ │ └── wayward_t5_1_1_prompt.gin │ │ ├── runs │ │ │ ├── __init__.py │ │ │ ├── ia3_finetune.gin │ │ │ └── multitask_prompt_finetune.gin │ │ └── test │ │ │ ├── __init__.py │ │ │ ├── load_multi_task_t5_1_1_tiny_prompt.gin │ │ │ ├── multi_task_t5_1_1_tiny_prompt.gin │ │ │ └── train_multi_task_t5_1_1_tiny_prompt.gin │ ├── models │ │ ├── __init__.py │ │ ├── decoding │ │ │ ├── __init__.py │ │ │ ├── beam_search.gin │ │ │ └── nucleus_sampling.gin │ │ ├── mt5_150b_prompt.gin │ │ ├── mt5_80b_prompt.gin │ │ ├── mt5_base_prompt.gin │ │ ├── mt5_large_prompt.gin │ │ ├── mt5_prompt.gin │ │ ├── mt5_small_prompt.gin │ │ ├── mt5_xl_prompt.gin │ │ ├── mt5_xxl_prompt.gin │ │ ├── sizes │ │ │ ├── 150b.gin │ │ │ ├── 80b.gin │ │ │ ├── __init__.py │ │ │ ├── base.gin │ │ │ ├── large.gin │ │ │ ├── small.gin │ │ │ ├── xl.gin │ │ │ └── xxl.gin │ │ ├── t5_1_1_base_prompt.gin │ │ ├── t5_1_1_large_prompt.gin │ │ ├── t5_1_1_prompt.gin │ │ ├── t5_1_1_small_prompt.gin │ │ ├── t5_1_1_xl_prompt.gin │ │ └── t5_1_1_xxl_prompt.gin │ ├── prompts │ │ ├── __init__.py │ │ ├── from_class_labels.gin │ │ ├── from_class_labels_numpy.gin │ │ ├── from_file.gin │ │ ├── from_sampled_vocab.gin │ │ └── from_sampled_vocab_numpy.gin │ ├── runs │ │ ├── __init__.py │ │ ├── prompt_eval.gin │ │ ├── prompt_finetune.gin │ │ └── prompt_infer.gin │ └── test │ │ ├── __init__.py │ │ ├── load_t5_1_1_tiny_prompt.gin │ │ ├── t5_1_1_tiny.gin │ │ ├── t5_1_1_tiny_prompt.gin │ │ ├── train_t5_1_1_tiny.gin │ │ └── train_t5_1_1_tiny_prompt.gin ├── data │ ├── __init__.py │ ├── c4.py │ ├── constants.py │ ├── features.py │ ├── glue.py │ ├── glue_transfer.py │ ├── metrics.py │ ├── metrics_test.py │ ├── postprocessors.py │ ├── postprocessors_test.py │ ├── preprocessors.py │ ├── preprocessors_test.py │ ├── qa.py │ ├── show_tasks.py │ ├── summarization.py │ ├── super_glue.py │ ├── tasks.py │ ├── tasks_test.py │ ├── utils.py │ └── utils_test.py ├── extended │ ├── README.md │ ├── __init__.py │ ├── masks.py │ ├── masks_test.py │ ├── multitask_prompts.py │ ├── multitask_prompts_test.py │ ├── perceptron │ │ ├── README.md │ │ ├── __init__.py │ │ ├── configs │ │ │ ├── __init__.py │ │ │ └── models │ │ │ │ ├── __init__.py │ │ │ │ ├── cross_entropy_t5_1_1_prompt.gin │ │ │ │ └── perceptron_t5_1_1_prompt.gin │ │ ├── data │ │ │ ├── __init__.py │ │ │ └── tasks.py │ │ └── train │ │ │ ├── __init__.py │ │ │ ├── feature_converters.py │ │ │ └── models.py │ └── train │ │ ├── __init__.py │ │ ├── ia3.py │ │ ├── ia3_test.py │ │ ├── 
multitask_partitioning.py │ │ ├── multitask_prompts.py │ │ ├── multitask_prompts_test.py │ │ ├── per_layer.py │ │ ├── per_layer_test.py │ │ └── wayward.py ├── masks.py ├── masks_test.py ├── pretrained_prompts │ └── t5_1_1_lm100k_base │ │ ├── README.md │ │ ├── mrpc.npy │ │ ├── rte.npy │ │ └── sst2.npy ├── prompts.py ├── prompts_test.py ├── recycling │ ├── README.md │ ├── __init__.py │ ├── collect_recycling_results.py │ ├── configs │ │ ├── __init__.py │ │ ├── imdb.json │ │ └── sst2.json │ ├── data │ │ ├── __init__.py │ │ ├── c4.py │ │ ├── filtered-vocab-english-only.json │ │ ├── imdb.py │ │ ├── metrics.py │ │ ├── preprocessors.py │ │ ├── preprocessors_test.py │ │ ├── qqp.py │ │ ├── rank_classification.py │ │ ├── record.py │ │ └── sst2.py │ ├── recycle.py │ ├── run_recycle.py │ └── utils.py ├── scripts │ ├── __init__.py │ ├── diff_checkpoints.py │ ├── extract_variable.py │ ├── find_module.py │ ├── mrqa_to_tsv.py │ ├── recreate_checkpoint.py │ ├── sst2-demo-eval.sh │ ├── sst2-demo-xxl.sh │ ├── sst2-demo.sh │ └── subsample_vocab.py ├── spot │ ├── README.md │ ├── __init__.py │ └── data │ │ ├── __init__.py │ │ ├── glue.py │ │ ├── mrqa.py │ │ ├── nli.py │ │ ├── preprocessors.py │ │ ├── preprocessors_test.py │ │ ├── summarization.py │ │ └── tasks.py ├── test_data │ ├── prompt_5x256.npy │ ├── t5_vocab │ ├── test_t5_1_1_tiny │ │ └── checkpoint_3 │ │ │ ├── checkpoint │ │ │ ├── target.decoder.layers_0.encoder_decoder_attention.key.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_0.encoder_decoder_attention.out.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_0.encoder_decoder_attention.query.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_0.encoder_decoder_attention.value.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_0.mlp.wi_0.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_0.mlp.wi_1.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_0.mlp.wo.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_0.self_attention.key.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_0.self_attention.out.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_0.self_attention.query.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_0.self_attention.value.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_1.encoder_decoder_attention.key.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_1.encoder_decoder_attention.out.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_1.encoder_decoder_attention.query.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_1.encoder_decoder_attention.value.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_1.mlp.wi_0.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_1.mlp.wi_1.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_1.mlp.wo.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_1.self_attention.key.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_1.self_attention.out.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_1.self_attention.query.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.layers_1.self_attention.value.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.decoder.logits_dense.kernel │ │ │ ├── .zarray │ │ │ ├── 0.0 │ │ │ ├── 0.1 │ │ │ ├── 0.2 │ │ │ 
├── 0.3 │ │ │ ├── 0.4 │ │ │ ├── 0.5 │ │ │ ├── 0.6 │ │ │ └── 0.7 │ │ │ ├── target.encoder.layers_0.attention.key.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.encoder.layers_0.attention.out.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.encoder.layers_0.attention.query.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.encoder.layers_0.attention.value.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.encoder.layers_0.mlp.wi_0.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.encoder.layers_0.mlp.wi_1.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.encoder.layers_0.mlp.wo.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.encoder.layers_1.attention.key.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.encoder.layers_1.attention.out.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.encoder.layers_1.attention.query.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.encoder.layers_1.attention.value.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.encoder.layers_1.mlp.wi_0.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.encoder.layers_1.mlp.wi_1.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ ├── target.encoder.layers_1.mlp.wo.kernel │ │ │ ├── .zarray │ │ │ └── 0.0 │ │ │ └── target.token_embedder.embedding │ │ │ ├── .zarray │ │ │ ├── 0.0 │ │ │ ├── 1.0 │ │ │ ├── 2.0 │ │ │ ├── 3.0 │ │ │ ├── 4.0 │ │ │ ├── 5.0 │ │ │ ├── 6.0 │ │ │ └── 7.0 │ └── tiny_embeddings_32128x4.npy ├── test_utils.py ├── test_utils_test.py ├── train │ ├── __init__.py │ ├── layers.py │ ├── layers_fixtures.py │ ├── layers_test.py │ ├── models.py │ ├── models_test.py │ ├── optim.py │ ├── partitioning.py │ ├── prompts.py │ ├── prompts_test.py │ ├── train_test.py │ ├── utils.py │ └── utils_test.py └── x_gen │ └── README.md ├── pyproject.toml └── setup.py /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Expected Behavior 2 | 3 | 4 | ## Actual Behavior 5 | 6 | 7 | ## Steps to Reproduce the Problem 8 | 9 | 1. 10 | 1. 11 | 1. 12 | 13 | ## Specifications 14 | 15 | - Version: 16 | - Platform: 17 | 18 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | workflow_dispatch: 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 13 | uses: actions/setup-python@v2 14 | with: 15 | python-version: '3.8.x' 16 | - name: Install Dependencies and Package 17 | run: | 18 | python -m pip install --upgrade pip 19 | python -m pip install .[test] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 20 | - name: Test with Pytest 21 | run: | 22 | python -m pytest 23 | # The below step just reports the success or failure of tests as a "commit status". 24 | # This is needed for copybara integration. 
25 | - name: Report success or failure as github status 26 | if: always() 27 | shell: bash 28 | run: | 29 | status="${{ job.status }}" 30 | lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') 31 | curl -sS --request POST \ 32 | --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ 33 | --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ 34 | --header 'content-type: application/json' \ 35 | --data '{ 36 | "state": "'$lowercase_status'", 37 | "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", 38 | "description": "'$status'", 39 | "context": "github-actions/build" 40 | }' 41 | -------------------------------------------------------------------------------- /.github/workflows/close-pull-request.yml: -------------------------------------------------------------------------------- 1 | name: Close Pull Request 2 | 3 | on: 4 | pull_request_target: 5 | types: [opened] 6 | 7 | jobs: 8 | run: 9 | runs-on: ubuntu-latest 10 | if: ${{ github.event.pusher.name != 'copybara-service' }} 11 | steps: 12 | - uses: superbrothers/close-pull-request@v3 13 | with: 14 | comment: "Unfortunately, we cannot accept contributions to the Prompt Tuning repo at this time. Please file issues as needed though!" 15 | -------------------------------------------------------------------------------- /CONTRIBUTING: -------------------------------------------------------------------------------- 1 | External contributions are not accepted, sorry! 2 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include prompt_tuning/configs *.gin 2 | recursive-include prompt_tuning/pretrained_prompts *.npy 3 | recursive-include prompt_tuning/pretrained_prompts *.md 4 | recursive-include prompt_tuning/recycling/configs *.json 5 | -------------------------------------------------------------------------------- /prompt_tuning/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Prompt Tuning.""" 16 | __version__ = "0.1.0" 17 | -------------------------------------------------------------------------------- /prompt_tuning/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Gin requires gin files to be in a python package, this file makes it one.""" 16 | -------------------------------------------------------------------------------- /prompt_tuning/configs/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Gin requires gin files to be in a python package, this file makes it one.""" 16 | -------------------------------------------------------------------------------- /prompt_tuning/configs/architectures/prompt_encoder_t5_1_1_flaxformer.gin: -------------------------------------------------------------------------------- 1 | # Flaxformer implementation of T5.1.1 architecture with prompting. 2 | # 3 | # Required to be overridden: 4 | # 5 | # - PROMPT 6 | # - PROMPT_LENGTH 7 | 8 | # We disable input order checks because the gin linter seems to want inputs in 9 | # alphabetical order without considering std, third_party, and project packages 10 | # like the python ordering of imports do. So disable that for now. 11 | 12 | # ginlint: disable=bad-import-order 13 | from __gin__ import dynamic_registration 14 | 15 | from flax import linen 16 | from flaxformer.components import embedding 17 | from flaxformer.components import layer_norm 18 | from flaxformer.components import relative_position_biases 19 | from flaxformer.architectures.t5 import t5_architecture 20 | 21 | from prompt_tuning.train import layers as prompt_layers 22 | from prompt_tuning import masks as prompt_masks 23 | 24 | PROMPT = %gin.REQUIRED 25 | PROMPT_LENGTH = %gin.REQUIRED 26 | 27 | include 'prompt_tuning/configs/architectures/t5_1_1_flaxformer.gin' 28 | 29 | # Architecture (Flax Module) 30 | # Use our subclass that has a prompted encoder and normal decoder 31 | ARCHITECTURE = @prompt_layers.PromptEncoderDecoder() 32 | prompt_layers.PromptEncoderDecoder: 33 | # Set the encoder to be my encoder subclass that adds a prompt 34 | encoder_factory = @prompt_layers.PromptEncoder 35 | decoder_factory = @t5_architecture.Decoder 36 | shared_token_embedder_factory = @embedding.Embed 37 | dtype = %ACTIVATION_DTYPE 38 | # Set these so the attention mask turns out correct once the prompt is added. 39 | encoder_mask_factory = @prompt_masks.create_prompt_encoder_mask 40 | add_fake_prompt_factory = @prompt_masks.add_fake_prompt 41 | 42 | # How to create an attention mask for the encoder that considers the prompt 43 | prompt_masks.create_prompt_encoder_mask: 44 | prompt_length = %PROMPT_LENGTH 45 | 46 | # Decoder masking needs to know how long the prompt is because it creates 47 | # full visible attention over the prompts and inputs. 48 | prompt_masks.add_fake_prompt: 49 | prompt_length = %PROMPT_LENGTH 50 | multitask = False 51 | 52 | # Encoder 53 | prompt_layers.PromptEncoder: 54 | # How to create the prompt module. 
55 | prompt_factory = %PROMPT 56 | add_fake_prompt_factory = @prompt_masks.add_fake_prompt 57 | num_layers = %NUM_ENCODER_LAYERS 58 | layer_factory = @t5_architecture.EncoderLayer 59 | input_dropout_factory = %DROPOUT_FACTORY 60 | output_dropout_factory = %DROPOUT_FACTORY 61 | layer_norm_factory = @layer_norm.T5LayerNorm 62 | position_embedder_factory = None 63 | shared_relative_position_bias_factory = @relative_position_biases.RelativePositionBiases 64 | dtype = %ACTIVATION_DTYPE 65 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Gin requires gin files to be in a python package, this file makes it one.""" 16 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Gin requires gin files to be in a python package, this file makes it one.""" 16 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/architectures/ia3_t5_1_1_flaxformer.gin: -------------------------------------------------------------------------------- 1 | # Flaxformer implementation of adding IA^3 to t5 1.1 2 | 3 | from __gin__ import dynamic_registration 4 | 5 | from flax import linen 6 | from prompt_tuning.extended.train import ia3 7 | from flaxformer.components import dense 8 | from flaxformer.components.attention import dense_attention 9 | from flaxformer.architectures.t5 import t5_architecture 10 | 11 | include 'prompt_tuning/configs/architectures/t5_1_1_flaxformer.gin' 12 | 13 | # Add ia3 to all attention implementations 14 | dense_attention.MultiHeadDotProductAttention: 15 | k_conv = @ia3.IA3Attention() 16 | v_conv = @ia3.IA3Attention() 17 | 18 | ia3.IA3Attention: 19 | dtype = %ACTIVATION_DTYPE 20 | 21 | dense.MlpBlock: 22 | intermediate_conv = @ia3.IA3() 23 | 24 | ia3.IA3: 25 | axis_name = ('mlp',) 26 | dtype = %ACTIVATION_DTYPE 27 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/architectures/multi_task_prompt_encoder_t5_1_1_flaxformer.gin: -------------------------------------------------------------------------------- 1 | # Flaxformer implementation of T5.1.1 architecture. 2 | # 3 | # Required to be overridden: 4 | # 5 | # - PROMPT 6 | # - PROMPT_LENGTH 7 | 8 | # We disable input order checks because the gin linter seems to want inputs in 9 | # alphabetical order without considering std, third_party, and project packages 10 | # like the python ordering of imports do. So disable that for now. 11 | 12 | # ginlint: disable=bad-import-order 13 | from __gin__ import dynamic_registration 14 | 15 | from prompt_tuning import masks 16 | from prompt_tuning.extended import masks as prompt_masks 17 | from prompt_tuning.train import layers as prompt_layers 18 | 19 | include 'prompt_tuning/configs/architectures/prompt_encoder_t5_1_1_flaxformer.gin' 20 | 21 | 22 | prompt_layers.PromptEncoderDecoder: 23 | encoder_mask_factory = @prompt_masks.prompt_encoder_attention_mask 24 | 25 | prompt_masks.prompt_encoder_attention_mask: 26 | prompt_length = %PROMPT_LENGTH 27 | multitask = True 28 | masks.add_fake_prompt.multitask = True 29 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/architectures/per_layer_prompt_encoder_t5_1_1_flaxformer.gin: -------------------------------------------------------------------------------- 1 | # Flaxformer implementation of T5.1.1 architecture with prompting. 2 | # 3 | # Required to be overridden: 4 | # 5 | # - PROMPT 6 | # - PROMPT_LENGTH 7 | # - PER_LAYER_PROMPT: When this is on, the combining function should be set to 8 | # either `replace_prompt` or `add_prompt` 9 | 10 | # We disable input order checks because the gin linter seems to want inputs in 11 | # alphabetical order without considering std, third_party, and project packages 12 | # like the python ordering of imports do. So disable that for now. 
13 | 14 | # ginlint: disable=bad-import-order 15 | from __gin__ import dynamic_registration 16 | 17 | from flax import linen 18 | from flaxformer.components.attention import dense_attention 19 | from flaxformer.components import dense 20 | from flaxformer.components import embedding 21 | from flaxformer.components import layer_norm 22 | from flaxformer.components import relative_position_biases 23 | from flaxformer.architectures.t5 import t5_architecture 24 | 25 | from prompt_tuning import masks as prompt_masks 26 | from prompt_tuning.extended.train import per_layer 27 | from prompt_tuning.train import layers as prompt_layers 28 | 29 | PROMPT = %gin.REQUIRED 30 | PROMPT_LENGTH = %gin.REQUIRED 31 | PER_LAYER_PROMPT = %gin.REQUIRED 32 | 33 | include 'prompt_tuning/configs/architectures/t5_1_1_flaxformer.gin' 34 | 35 | # Architecture (Flax Module) 36 | # Use our subclass that has a prompted encoder and normal decoder 37 | ARCHITECTURE = @prompt_layers.PromptEncoderDecoder() 38 | prompt_layers.PromptEncoderDecoder: 39 | # Set the encoder to be the encoder subclass that adds a prompt 40 | encoder_factory = @prompt_layers.PromptEncoder 41 | decoder_factory = @t5_architecture.Decoder 42 | shared_token_embedder_factory = @embedding.Embed 43 | dtype = %ACTIVATION_DTYPE 44 | # Setting a prompt aware mask creation function that will extend the attention 45 | # mask so that inputs can attend to the newly added prompt variables. 46 | encoder_mask_factory = @prompt_masks.create_prompt_encoder_mask 47 | add_fake_prompt_factory = @prompt_masks.add_fake_prompt 48 | 49 | # How to create an attention mask for the encoder that considers the prompt 50 | prompt_masks.create_prompt_encoder_mask: 51 | prompt_length = %PROMPT_LENGTH 52 | 53 | # Decoder masking needs to know how long the prompt is because it creates 54 | # full visible attention over the prompts and inputs. 55 | prompt_masks.add_fake_prompt: 56 | prompt_length = %PROMPT_LENGTH 57 | multitask = False 58 | 59 | # Encoder 60 | prompt_layers.PromptEncoder: 61 | # How to create the prompt module. 62 | prompt_factory = %PROMPT 63 | add_fake_prompt_factory = @prompt_masks.add_fake_prompt 64 | num_layers = %NUM_ENCODER_LAYERS 65 | # Add a Prompt at each layer. 66 | layer_factory = @per_layer.PromptEncoderLayer 67 | input_dropout_factory = %DROPOUT_FACTORY 68 | output_dropout_factory = %DROPOUT_FACTORY 69 | layer_norm_factory = @layer_norm.T5LayerNorm 70 | position_embedder_factory = None 71 | shared_relative_position_bias_factory = @relative_position_biases.RelativePositionBiases 72 | dtype = %ACTIVATION_DTYPE 73 | 74 | # Add a Prompt to each layer 75 | per_layer.PromptEncoderLayer: 76 | attention = @dense_attention.MultiHeadDotProductAttention() 77 | mlp = @dense.MlpBlock() 78 | dropout_factory = %DROPOUT_FACTORY 79 | layer_norm_factory = @layer_norm.T5LayerNorm 80 | activation_partitioning_dims = %ACTIVATION_PARTITIONING_DIMS 81 | prompt_factory = %PER_LAYER_PROMPT 82 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Gin requires gin files to be in a python package, this file makes it one.""" 16 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/models/mt5_per_layer_prompt.gin: -------------------------------------------------------------------------------- 1 | # mT5 1.1 Base Prompt model. 2 | # Provides MODEL, PROMPT, and PROMPT_LENGTH 3 | 4 | include 'prompt_tuning/configs/extended/models/t5_1_1_per_layer_prompt.gin' 5 | 6 | # Update the embeddings and vocab to point to the MT5 versions. 7 | NUM_EMBEDDINGS = 250112 8 | seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model" 9 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/models/multi_task_t5_1_1_base_prompt.gin: -------------------------------------------------------------------------------- 1 | # Multitask Prompt T5.1.1 Base Prompt model. 2 | 3 | include 'prompt_tuning/configs/extended/models/multi_task_t5_1_1_prompt.gin' 4 | include 'prompt_tuning/configs/models/sizes/base.gin' 5 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/models/multi_task_t5_1_1_large_prompt.gin: -------------------------------------------------------------------------------- 1 | # Multitask Prompt T5.1.1 Large Prompt model. 2 | 3 | include 'prompt_tuning/configs/extended/models/multi_task_t5_1_1_prompt.gin' 4 | include 'prompt_tuning/configs/models/sizes/large.gin' 5 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/models/multi_task_t5_1_1_small_prompt.gin: -------------------------------------------------------------------------------- 1 | # Multitask Prompt T5.1.1 Small Prompt model. 2 | 3 | include 'prompt_tuning/configs/extended/models/multi_task_t5_1_1_prompt.gin' 4 | include 'prompt_tuning/configs/models/sizes/small.gin' 5 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/models/multi_task_t5_1_1_xl_prompt.gin: -------------------------------------------------------------------------------- 1 | # Multitask Prompt T5.1.1 XL Prompt model. 2 | 3 | include 'prompt_tuning/configs/extended/models/multi_task_t5_1_1_prompt.gin' 4 | include 'prompt_tuning/configs/models/sizes/xl.gin' 5 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/models/multi_task_t5_1_1_xxl_prompt.gin: -------------------------------------------------------------------------------- 1 | # Multitask Prompt T5.1.1 XXL Prompt model. 
2 | 3 | include 'prompt_tuning/configs/extended/models/multi_task_t5_1_1_prompt.gin' 4 | include 'prompt_tuning/configs/models/sizes/xxl.gin' 5 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/models/wayward_t5_1_1_base_prompt.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Base Prompt model with a wayward prompt. 2 | # 3 | # This is a standard prompt tuning training run except we have an extra term in 4 | # our loss function that regularizes the learned prompt towards the embedded 5 | # representation of a discrete prompt from Khashabi, et al. (2021) 6 | # https://arxiv.org/abs/2112.08348 7 | # 8 | # Provides MODEL, PROMPT, and PROMPT_LENGTH. 9 | # 10 | # You can set TASK_STRING to update the task description we are trying to match. 11 | 12 | include 'prompt_tuning/configs/extended/models/wayward_t5_1_1_prompt.gin' 13 | include 'prompt_tuning/configs/models/sizes/base.gin' 14 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/models/wayward_t5_1_1_large_prompt.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Large Prompt model with a wayward prompt. 2 | # 3 | # This is a standard prompt tuning training run except we have an extra term in 4 | # our loss function that regularizes the learned prompt towards the embedded 5 | # representation of a discrete prompt from Khashabi, et al. (2021) 6 | # https://arxiv.org/abs/2112.08348 7 | # 8 | # Provides MODEL, PROMPT, and PROMPT_LENGTH. 9 | # 10 | # You can set TASK_STRING to update the task description we are trying to match. 11 | 12 | include 'prompt_tuning/configs/extended/models/wayward_t5_1_1_prompt.gin' 13 | include 'prompt_tuning/configs/models/sizes/large.gin' 14 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/models/wayward_t5_1_1_prompt.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Base Prompt model with a wayward prompt. 2 | # 3 | # This is a standard prompt tuning training run except we have an extra term in 4 | # our loss function that regularizes the learned prompt towards the embedded 5 | # representation of a discrete prompt from Khashabi, et al. (2021) 6 | # https://arxiv.org/abs/2112.08348 7 | # 8 | # Provides MODEL, PROMPT, and PROMPT_LENGTH. 9 | # 10 | # You can set TASK_STRING to update the task description we are trying to match. 11 | # 12 | # The default TASK_STRING has a prompt length of 90 and is set up for SST2. 13 | 14 | from __gin__ import dynamic_registration 15 | 16 | from flax import linen 17 | from prompt_tuning import prompts 18 | from prompt_tuning.extended.train import wayward 19 | from prompt_tuning.train import prompts as train_prompts 20 | 21 | include 'prompt_tuning/configs/models/t5_1_1_prompt.gin' 22 | 23 | 24 | # SST2 preprocessors in t5 result in spacing around punctuation which we 25 | # include here. 26 | TASK_STRING = ( 27 | "Classify this movie review based on its sentiment . Use 2 " 28 | "classes . One positive ( for reviews that paint the movie in " 29 | "a favorable light ) and one negative ( for reviews that make " 30 | "you not want to see the movie or think it will be bad ) . Use " 31 | "the string ` positive ` for the positive class , the good / " 32 | "great movies , and use the string ` negative ` for the negative " 33 | "class , the bad movies ." 
34 | ) 35 | 36 | VERBALIZERS = None 37 | 38 | TASK_PIECES = @wayward.encode_string() 39 | wayward.encode_string: 40 | s = %TASK_STRING 41 | vocab = %VOCABULARY 42 | format_values = %VERBALIZERS 43 | 44 | PROMPT_LENGTH = @wayward.length() 45 | wayward.length.x = %TASK_PIECES 46 | 47 | PROMPT = @train_prompts.Prompt 48 | train_prompts.Prompt.prompt = @prompts.Prompt() 49 | 50 | prompts.Prompt: 51 | length = %PROMPT_LENGTH 52 | prompt_init = @prompt_init/prompts.from_embedded_string() 53 | 54 | prompt_init/prompts.from_embedded_string: 55 | embeddings = @prompt_init/prompts.t5x_load() 56 | vocab = %VOCABULARY 57 | text = %TASK_STRING 58 | initializer = @linen.initializers.zeros 59 | 60 | prompt_init/prompts.t5x_load: 61 | checkpoint_path = %INITIAL_CHECKPOINT_PATH 62 | variable_path = "token_embedder/embedding" 63 | 64 | 65 | MODEL = @wayward.WaywardPromptEncoderDecoderModel() 66 | wayward.WaywardPromptEncoderDecoderModel: 67 | module = %ARCHITECTURE 68 | input_vocabulary = %VOCABULARY 69 | output_vocabulary = %VOCABULARY 70 | optimizer_def = %OPTIMIZER 71 | z_loss = %Z_LOSS 72 | label_smoothing = %LABEL_SMOOTHING 73 | loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR 74 | distance = @wayward.squared_l2_distance 75 | discrete_prompt = @prompt_init/wayward.execute_initializer() 76 | prompt_path = "encoder/prompt/prompt/prompt" 77 | gamma = 0.01 78 | 79 | prompt_init/wayward.execute_initializer: 80 | init = @prompt_init/prompts.from_embedded_string() 81 | rng = None 82 | shape = (%PROMPT_LENGTH, %EMBED_DIM) 83 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/runs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Gin requires gin files to be in a python package, this file makes it one.""" 16 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/runs/ia3_finetune.gin: -------------------------------------------------------------------------------- 1 | # Defaults for finetuning with IA3. 2 | # 3 | # See go/t5x-finetune for instructions. 4 | # 5 | # You must also include a binding for MODEL, PROMPT, and PROMPT_LENGTH. 6 | # 7 | # Required to be set: 8 | # 9 | # - MIXTURE_OR_TASK_NAME 10 | # - TASK_FEATURE_LENGTHS 11 | # - TRAIN_STEPS # includes pretrain steps 12 | # - MODEL_DIR # automatically set when using xm_launch 13 | # - INITIAL_CHECKPOINT_PATH 14 | # 15 | # When launching on XManager, `MODEL_DIR` (the directory to write fine-tuned 16 | # checkpoints to) is configured automatically by the XManager launch script. 17 | # When running locally, it needs to be passed in the `gin.MODEL_DIR` flag. 18 | # 19 | # `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt 20 | # has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps. 
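# A minimal launch sketch, assuming the standard `python -m t5x.train` entry
# point; the paths, task name, and step counts below are illustrative rather
# than values prescribed by this repo:
#
#   python -m t5x.train \
#     --gin_file=prompt_tuning/configs/extended/models/ia3_t5_1_1.gin \
#     --gin_file=prompt_tuning/configs/extended/runs/ia3_finetune.gin \
#     --gin.MODEL_DIR="'/tmp/ia3_model'" \
#     --gin.MIXTURE_OR_TASK_MODULE="'prompt_tuning.data.tasks'" \
#     --gin.MIXTURE_OR_TASK_NAME="'taskless_super_glue_boolq_v102_examples'" \
#     --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" \
#     --gin.TRAIN_STEPS=1150000 \
#     --gin.INITIAL_CHECKPOINT_PATH="'/path/to/pretrained/checkpoint'"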
21 | # 22 | # Commonly overridden options: 23 | # - DROPOUT_RATE 24 | # - BATCH_SIZE 25 | # - PjitPartitioner.num_partitions 26 | # - Trainer.num_microbatches 27 | # - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess 28 | # on the fly. Most common tasks are cached, hence this is set to True by 29 | # default. 30 | from __gin__ import dynamic_registration 31 | from t5x import utils 32 | 33 | include "prompt_tuning/configs/runs/prompt_finetune.gin" 34 | 35 | # ========== These are IA3 HPs you might want to override ========== 36 | utils.create_learning_rate_scheduler: 37 | factors = "constant" 38 | # Learning rate from the paper. 39 | base_learning_rate = 3e-3 40 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/runs/multitask_prompt_finetune.gin: -------------------------------------------------------------------------------- 1 | # Defaults for finetuning with train.py. 2 | # 3 | # See go/t5x-finetune for instructions. 4 | # 5 | # You must also include a binding for MODEL, PROMPT, and PROMPT_LENGTHS. 6 | # 7 | # Required to be set: 8 | # 9 | # - MIXTURE_OR_TASK_NAME 10 | # - TASK_FEATURE_LENGTHS 11 | # - TRAIN_STEPS # includes pretrain steps 12 | # - MODEL_DIR # automatically set when using xm_launch 13 | # - INITIAL_CHECKPOINT_PATH 14 | # 15 | # When launching on XManager, `MODEL_DIR` (the directory to write fine-tuned 16 | # checkpoints to) is configured automatically by the XManager launch script. 17 | # When running locally, it needs to be passed in the `gin.MODEL_DIR` flag. 18 | # 19 | # `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt 20 | # has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps. 21 | # 22 | # Commonly overridden options: 23 | # - DROPOUT_RATE 24 | # - BATCH_SIZE 25 | # - PjitPartitioner.num_partitions 26 | # - Trainer.num_microbatches 27 | # - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess 28 | # on the fly. Most common tasks are cached, hence this is set to True by 29 | # default. 30 | from __gin__ import dynamic_registration 31 | from t5x import utils 32 | from t5x import partitioning 33 | from prompt_tuning.extended.train import multitask_partitioning as multitask_prompt_partitioning 34 | 35 | include "t5x/configs/runs/finetune.gin" 36 | 37 | # Prompt Tuning does not support packing. 38 | train/utils.DatasetConfig.pack = False 39 | train_eval/utils.DatasetConfig.pack = False 40 | 41 | # ========== These are Prompt Tuning HPs you might want to override ========== 42 | utils.create_learning_rate_scheduler: 43 | factors = "constant" 44 | base_learning_rate = 0.3 45 | 46 | utils.SaveCheckpointConfig: 47 | period = 1000 48 | # Keep a single checkpoint. Even though the majority of these checkpoint 49 | # weights are unchanged from our initial checkpoint we keep the copy so 50 | # recovery from preemption works. We save our prompt values ourselves so we 51 | # don't have to worry about losing them. 52 | keep = 1 53 | 54 | partitioning.PjitPartitioner: 55 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 56 | 57 | partitioning.standard_logical_axis_rules: 58 | additional_rules = @multitask_prompt_partitioning.standard_logical_axis_rules() 59 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/test/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Gin requires gin files to be in a python package, this file makes it one.""" 16 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/test/load_multi_task_t5_1_1_tiny_prompt.gin: -------------------------------------------------------------------------------- 1 | # Test config to exercise train.py with "T5X core" partitioning. 2 | 3 | # We disable input order checks because the gin linter seems to want inputs in 4 | # alphabetical order without considering std, third_party, and project packages 5 | # like the python ordering of imports do. So disable that for now. 6 | 7 | # ginlint: disable=bad-import-order 8 | from __gin__ import dynamic_registration 9 | from t5x import utils 10 | from t5x import partitioning 11 | 12 | include 'prompt_tuning/configs/extended/runs/multitask_prompt_finetune.gin' 13 | include 'prompt_tuning/configs/extended/test/multi_task_t5_1_1_tiny_prompt.gin' 14 | 15 | MODEL_DIR = "/tmp" # Will be overridden in test. 16 | 17 | TRAIN_STEPS = 3 18 | MIXTURE_OR_TASK_MODULE = "prompt_tuning.data.tasks" 19 | MIXTURE_OR_TASK_NAME = "taskless_super_glue_boolq_v102_examples" 20 | TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512} 21 | 22 | partitioning.PjitPartitioner.num_partitions = 1 23 | utils.SaveCheckpointConfig.period = 2 24 | train/utils.DatasetConfig.batch_size = 8 25 | train_eval/utils.DatasetConfig.batch_size = 8 26 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/test/multi_task_t5_1_1_tiny_prompt.gin: -------------------------------------------------------------------------------- 1 | # Common pieces for the tests with the small model. 2 | 3 | include 'prompt_tuning/configs/extended/models/multi_task_t5_1_1_base_prompt.gin' 4 | 5 | NUM_HEADS = 2 6 | HEAD_DIM = 2 7 | EMBED_DIM = 4 8 | MLP_DIM = 8 9 | NUM_ENCODER_LAYERS = 2 10 | NUM_DECODER_LAYERS = 2 11 | 12 | relative_position_biases.RelativePositionBiases: 13 | num_buckets = 4 14 | max_distance = 8 15 | -------------------------------------------------------------------------------- /prompt_tuning/configs/extended/test/train_multi_task_t5_1_1_tiny_prompt.gin: -------------------------------------------------------------------------------- 1 | # Test config to exercise train.py with "T5X core" partitioning. 2 | 3 | # We disable input order checks because the gin linter seems to want inputs in 4 | # alphabetical order without considering std, third_party, and project packages 5 | # like the python ordering of imports do. So disable that for now. 6 | 7 | # ginlint: disable=bad-import-order 8 | from __gin__ import dynamic_registration 9 | from t5x import utils 10 | from t5x import partitioning 11 | 12 | include 't5x/configs/runs/pretrain.gin' 13 | include 'prompt_tuning/configs/extended/test/multi_task_t5_1_1_tiny_prompt.gin' 14 | 15 | MODEL_DIR = "/tmp" # Will be overridden in test. 
16 | 17 | TRAIN_STEPS = 3 18 | MIXTURE_OR_TASK_MODULE = "prompt_tuning.data.tasks" 19 | MIXTURE_OR_TASK_NAME = "taskless_super_glue_boolq_v102_examples" 20 | TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512} 21 | 22 | partitioning.PjitPartitioner.num_partitions = 1 23 | utils.SaveCheckpointConfig.period = 2 24 | train/utils.DatasetConfig.batch_size = 8 25 | train/utils.DatasetConfig.pack = False 26 | train_eval/utils.DatasetConfig.batch_size = 8 27 | train_eval/utils.DatasetConfig.pack = False 28 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Gin requires gin files to be in a python package, this file makes it one.""" 16 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/decoding/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Gin requires gin files to be in a python package, this file makes it one.""" 16 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/decoding/beam_search.gin: -------------------------------------------------------------------------------- 1 | # Beam search decoding. 2 | # Include as a --gin_file argument after the `models/*.gin` and `prompts/*.gin` 3 | # arguments. 4 | from __gin__ import dynamic_registration 5 | 6 | from t5x import decoding 7 | from t5x import models 8 | 9 | models.EncoderDecoderModel.predict_batch_with_aux: 10 | num_decodes = 4 11 | models.EncoderDecoderModel: 12 | decode_fn = @decoding.beam_search 13 | decoding.beam_search: 14 | alpha = 0.6 15 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/decoding/nucleus_sampling.gin: -------------------------------------------------------------------------------- 1 | # Nucleus sampling decoding. 2 | # Include as a --gin_file argument after the `models/*.gin` and `prompts/*.gin` 3 | # arguments. 
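# For example, a hypothetical flag ordering (the file names exist in this repo,
# the combination is illustrative) would be:
#   --gin_file=prompt_tuning/configs/models/t5_1_1_base_prompt.gin
#   --gin_file=prompt_tuning/configs/prompts/from_file.gin
#   --gin_file=prompt_tuning/configs/models/decoding/nucleus_sampling.gin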
4 | from __gin__ import dynamic_registration 5 | 6 | from t5x import decoding 7 | from t5x import models 8 | 9 | models.EncoderDecoderModel.predict_batch_with_aux: 10 | num_decodes = 1 11 | models.EncoderDecoderModel: 12 | decode_fn = @decoding.temperature_sample 13 | decoding.temperature_sample: 14 | temperature = 1.0 15 | topp = 0.9 16 | topk = 0 17 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/mt5_150b_prompt.gin: -------------------------------------------------------------------------------- 1 | # MUM 150b model. 2 | # Provides MODEL, PROMPT, and PROMPT_LENGTH 3 | 4 | from __gin__ import dynamic_registration 5 | 6 | from flaxformer.architectures.t5 import t5_architecture 7 | from prompt_tuning.train import layers as prompt_layers 8 | import seqio 9 | from t5x import adafactor 10 | 11 | include 'prompt_tuning/configs/models/t5_1_1_prompt.gin' 12 | include 'prompt_tuning/configs/models/sizes/150b.gin' 13 | 14 | VOCABULARY = @seqio.SentencePieceVocabulary() 15 | seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model" 16 | 17 | LAYER_REMAT = 'full' 18 | 19 | # Custom Adafactor Rules needed below. 20 | # ------------------------------------------------------------------------------ 21 | adafactor.Adafactor.factor_map = @adafactor.HParamMap() 22 | adafactor.HParamMap.rules = @adafactor.standard_factor_rules() 23 | adafactor.standard_factor_rules.scan_layers = True 24 | 25 | # Scanned Layers 26 | #------------------------------------------------------------------------------- 27 | t5_architecture.DecoderLayer: 28 | scanned = True 29 | 30 | t5_architecture.Decoder: 31 | shared_relative_position_bias_factory = None 32 | scan_layers = True 33 | layer_remat = %LAYER_REMAT 34 | 35 | t5_architecture.EncoderLayer: 36 | scanned = True 37 | 38 | t5_architecture.Encoder: 39 | shared_relative_position_bias_factory = None 40 | scan_layers = True 41 | layer_remat = %LAYER_REMAT 42 | 43 | t5_architecture.EncoderDecoder: 44 | scan_layers = True 45 | 46 | t5_architecture.DecoderOnly: 47 | scan_layers = True 48 | 49 | prompt_layers.PromptEncoder: 50 | shared_relative_position_bias_factory = None 51 | scan_layers = True 52 | layer_remat = %LAYER_REMAT 53 | 54 | prompt_layers.PromptDecoder: 55 | shared_relative_position_bias_factory = None 56 | scan_layers = True 57 | layer_remat = %LAYER_REMAT 58 | 59 | prompt_layers.PromptEncoderDecoder: 60 | scan_layers = True 61 | 62 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/mt5_80b_prompt.gin: -------------------------------------------------------------------------------- 1 | # MT5 80b model. 2 | # Provides MODEL, PROMPT, and PROMPT_LENGTH 3 | 4 | include 'prompt_tuning/configs/models/mt5_prompt.gin' 5 | include 'prompt_tuning/configs/models/sizes/80b.gin' 6 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/mt5_base_prompt.gin: -------------------------------------------------------------------------------- 1 | # mT5 1.1 Base Prompt model. 
2 | # Provides MODEL, PROMPT, and PROMPT_LENGTH 3 | 4 | include 'prompt_tuning/configs/models/mt5_prompt.gin' 5 | include 'prompt_tuning/configs/models/sizes/base.gin' 6 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/mt5_large_prompt.gin: -------------------------------------------------------------------------------- 1 | # mT5 1.1 Large Prompt model. 2 | # Provides MODEL, PROMPT, and PROMPT_LENGTH 3 | 4 | include 'prompt_tuning/configs/models/mt5_prompt.gin' 5 | include 'prompt_tuning/configs/models/sizes/large.gin' 6 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/mt5_prompt.gin: -------------------------------------------------------------------------------- 1 | # mT5 1.1 Base Prompt model. 2 | # Provides MODEL, PROMPT, and PROMPT_LENGTH 3 | 4 | include 'prompt_tuning/configs/models/t5_1_1_prompt.gin' 5 | 6 | # Update the embeddings and vocab to point to the MT5 versions. 7 | NUM_EMBEDDINGS = 250112 8 | seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model" 9 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/mt5_small_prompt.gin: -------------------------------------------------------------------------------- 1 | # mt5 1.1 small Prompt model. 2 | # Provides Model, Prompt, and PROMPT_LENGTH 3 | 4 | include 'prompt_tuning/configs/models/mt5_prompt.gin' 5 | include 'prompt_tuning/configs/models/sizes/small.gin' 6 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/mt5_xl_prompt.gin: -------------------------------------------------------------------------------- 1 | # mT5 1.1 XL Prompt Model. 2 | # Provides MODEL, PROMPT, and PROMPT_LENGTH 3 | 4 | include 'prompt_tuning/configs/models/mt5_prompt.gin' 5 | include 'prompt_tuning/configs/models/sizes/xl.gin' 6 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/mt5_xxl_prompt.gin: -------------------------------------------------------------------------------- 1 | # mT5 1.1 XXL Prompt model. 2 | # Provides MODEL, PROMPT, and PROMPT_LENGTH 3 | 4 | include 'prompt_tuning/configs/models/mt5_prompt.gin' 5 | include 'prompt_tuning/configs/models/sizes/xxl.gin' 6 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/sizes/150b.gin: -------------------------------------------------------------------------------- 1 | # Architecture overrides 2 | NUM_ENCODER_LAYERS = 32 3 | NUM_DECODER_LAYERS = 32 4 | NUM_HEADS = 96 5 | HEAD_DIM = 128 6 | EMBED_DIM = 12288 7 | MLP_DIM = 36864 8 | 9 | # 150B uses same vocab as smaller models, extra tokens were added for sharding. 10 | NUM_EMBEDDINGS = 250368 11 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/sizes/80b.gin: -------------------------------------------------------------------------------- 1 | # Settings to create an 80b parameter T5 1.1/MT5 model. 
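# A rough, back-of-the-envelope parameter count for these settings, assuming
# the gated-MLP T5.1.1 layout used elsewhere in this repo: each attention block
# is ~4 * EMBED_DIM * NUM_HEADS * HEAD_DIM (~0.54B) and each MLP block is
# ~3 * EMBED_DIM * MLP_DIM (~1.6B), giving ~34B over 16 encoder layers and ~43B
# over 16 decoder layers (which add cross-attention); with the embeddings this
# lands in the neighborhood of 80B parameters.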
2 | 3 | NUM_HEADS = 256 4 | NUM_ENCODER_LAYERS = 16 5 | NUM_DECODER_LAYERS = 16 6 | HEAD_DIM = 64 7 | EMBED_DIM = 8192 8 | MLP_DIM = 65536 9 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/sizes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Make sure this gin dir is picked up as a python package when installed.""" 16 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/sizes/base.gin: -------------------------------------------------------------------------------- 1 | # Settings to create a Base T5 1.1/MT5 model. 2 | 3 | NUM_HEADS = 12 4 | NUM_ENCODER_LAYERS = 12 5 | NUM_DECODER_LAYERS = 12 6 | HEAD_DIM = 64 7 | EMBED_DIM = 768 8 | MLP_DIM = 2048 9 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/sizes/large.gin: -------------------------------------------------------------------------------- 1 | # Settings to create a Large T5 1.1/MT5 model. 2 | 3 | NUM_HEADS = 16 4 | NUM_ENCODER_LAYERS = 24 5 | NUM_DECODER_LAYERS = 24 6 | HEAD_DIM = 64 7 | EMBED_DIM = 1024 8 | MLP_DIM = 2816 9 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/sizes/small.gin: -------------------------------------------------------------------------------- 1 | # Settings to create a Small T5 1.1/MT5 model. 2 | 3 | NUM_HEADS = 6 4 | NUM_ENCODER_LAYERS = 8 5 | NUM_DECODER_LAYERS = 8 6 | HEAD_DIM = 64 7 | EMBED_DIM = 512 8 | MLP_DIM = 1024 9 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/sizes/xl.gin: -------------------------------------------------------------------------------- 1 | # Settings to create an XL T5 1.1/MT5 model. 2 | 3 | NUM_HEADS = 32 4 | NUM_ENCODER_LAYERS = 24 5 | NUM_DECODER_LAYERS = 24 6 | HEAD_DIM = 64 7 | EMBED_DIM = 2048 8 | MLP_DIM = 5120 9 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/sizes/xxl.gin: -------------------------------------------------------------------------------- 1 | # Settings to create an XXL T5 1.1/MT5 model. 2 | 3 | NUM_HEADS = 64 4 | NUM_ENCODER_LAYERS = 24 5 | NUM_DECODER_LAYERS = 24 6 | HEAD_DIM = 64 7 | EMBED_DIM = 4096 8 | MLP_DIM = 10240 9 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/t5_1_1_base_prompt.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Base Prompt model. 
2 | 3 | include 'prompt_tuning/configs/models/t5_1_1_prompt.gin' 4 | include 'prompt_tuning/configs/models/sizes/base.gin' 5 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/t5_1_1_large_prompt.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Large Prompt model. 2 | 3 | include 'prompt_tuning/configs/models/t5_1_1_prompt.gin' 4 | include 'prompt_tuning/configs/models/sizes/large.gin' 5 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/t5_1_1_small_prompt.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 Small Prompt model. 2 | 3 | include 'prompt_tuning/configs/models/t5_1_1_prompt.gin' 4 | include 'prompt_tuning/configs/models/sizes/small.gin' 5 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/t5_1_1_xl_prompt.gin: -------------------------------------------------------------------------------- 1 | # T5 1.1 XL Prompt model. 2 | 3 | include 'prompt_tuning/configs/models/t5_1_1_prompt.gin' 4 | include 'prompt_tuning/configs/models/sizes/xl.gin' 5 | -------------------------------------------------------------------------------- /prompt_tuning/configs/models/t5_1_1_xxl_prompt.gin: -------------------------------------------------------------------------------- 1 | # T5.1.1 XXL Prompt model. 2 | 3 | include 'prompt_tuning/configs/models/t5_1_1_prompt.gin' 4 | include 'prompt_tuning/configs/models/sizes/xxl.gin' 5 | -------------------------------------------------------------------------------- /prompt_tuning/configs/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Gin requires gin files to be in a python package, this file makes it one.""" 16 | -------------------------------------------------------------------------------- /prompt_tuning/configs/prompts/from_class_labels.gin: -------------------------------------------------------------------------------- 1 | # Initialize the Prompt based on the class labels where the embeddings are 2 | # loaded from the initial model checkpoint at `INITIAL_CHECKPOINT_PATH`. 3 | # Provides PROMPT 4 | # 5 | # Requires PROMPT_LENGTH, VOCABULARY, and CLASS_LABELS to be 6 | # set. 7 | # 8 | # Include as a --gin_file argument after the `models/*.gin` arguments. 
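# A hypothetical set of flags (the label values are illustrative; CLASS_LABELS
# is the required variable declared below):
#   --gin_file=prompt_tuning/configs/models/t5_1_1_base_prompt.gin
#   --gin_file=prompt_tuning/configs/prompts/from_class_labels.gin
#   --gin.CLASS_LABELS="['positive', 'negative']"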
9 | from __gin__ import dynamic_registration 10 | 11 | from prompt_tuning import prompts 12 | from prompt_tuning.train import prompts as train_prompts 13 | 14 | CLASS_LABELS = %gin.REQUIRED 15 | 16 | PROMPT = @train_prompts.Prompt 17 | train_prompts.Prompt.prompt = @prompts.Prompt() 18 | 19 | prompts.Prompt: 20 | length = %PROMPT_LENGTH 21 | prompt_init = @prompt_init/prompts.from_embedded_list() 22 | 23 | prompt_init/prompts.from_embedded_list: 24 | embeddings = @prompt_init/prompts.t5x_load() 25 | vocab = %VOCABULARY 26 | texts = %CLASS_LABELS 27 | initializer = @prompt_init/prompts.from_sample_of_embeddings() 28 | 29 | prompt_init/prompts.from_sample_of_embeddings: 30 | embeddings = @prompt_init/prompts.t5x_load() 31 | population_size = 5000 32 | 33 | prompt_init/prompts.t5x_load: 34 | checkpoint_path = %INITIAL_CHECKPOINT_PATH 35 | variable_path = "token_embedder/embedding" 36 | 37 | # Then set overrides in the launch script: 38 | # --gin.CLASS_LABELS="['entailment', 'contradiction', 'neutral']" \ 39 | -------------------------------------------------------------------------------- /prompt_tuning/configs/prompts/from_class_labels_numpy.gin: -------------------------------------------------------------------------------- 1 | # Initialize the Prompt based on the class labels where the embeddings are 2 | # loaded from a saved numpy file. 3 | # Provides PROMPT 4 | # 5 | # Requires PROMPT_LENGTH, VOCABULARY, EMBEDDING_FILE, and CLASS_LABELS to be 6 | # set. 7 | # 8 | # Include as a --gin_file argument after the `models/*.gin` arguments. 9 | from __gin__ import dynamic_registration 10 | 11 | from prompt_tuning import prompts 12 | from prompt_tuning.train import prompts as train_prompts 13 | 14 | EMBEDDING_FILE = %gin.REQUIRED 15 | CLASS_LABELS = %gin.REQUIRED 16 | 17 | PROMPT = @train_prompts.Prompt 18 | train_prompts.Prompt.prompt = @prompts.Prompt() 19 | 20 | prompts.Prompt: 21 | length = %PROMPT_LENGTH 22 | prompt_init = @prompt_init/prompts.from_embedded_list() 23 | 24 | prompt_init/prompts.from_embedded_list: 25 | embeddings = @prompt_init/prompts.np_load() 26 | vocab = %VOCABULARY 27 | texts = %CLASS_LABELS 28 | initializer = @prompt_init/prompts.from_sample_of_embeddings() 29 | 30 | prompt_init/prompts.from_sample_of_embeddings: 31 | embeddings = @prompt_init/prompts.np_load() 32 | population_size = 5000 33 | 34 | prompt_init/prompts.np_load.path = %EMBEDDING_FILE 35 | 36 | # Then set overrides in the launch script: 37 | # --gin.EMBEDDING_FILE="'/path/to/numpy/embeddings/t5_1_1_lm_adaptation/xxl.npy'" \ 38 | # --gin.CLASS_LABELS="['entailment', 'contradiction', 'neutral']" \ 39 | -------------------------------------------------------------------------------- /prompt_tuning/configs/prompts/from_file.gin: -------------------------------------------------------------------------------- 1 | # Initialize the Prompt from a saved numpy file. 2 | # Provides PROMPT 3 | # 4 | # Requires PROMPT_LENGTH and PROMPT_FILE to be set. 5 | # 6 | # Include as a --gin_file argument after the `models/*.gin` arguments.
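To make the initializer plumbing above more concrete, here is a rough numpy sketch of what the class-label and sampled-vocabulary initializers end up computing. The helper names, paths, and token ids below are illustrative only; they are not the actual prompt_tuning functions.

import numpy as np

def sample_prompt_from_vocab(embeddings, length, population_size=5000, seed=0):
    # Pick `length` distinct rows from the first `population_size` vocabulary
    # embeddings, mirroring the "sample from the vocab" initializer.
    rng = np.random.default_rng(seed)
    rows = rng.choice(population_size, size=length, replace=False)
    return embeddings[rows]

def prompt_from_class_labels(embeddings, label_token_ids, length):
    # Fill the prompt from sampled vocab rows, then overwrite the leading
    # positions with the embeddings of the class-label tokens.
    prompt = sample_prompt_from_vocab(embeddings, length)
    for i, token_id in enumerate(label_token_ids[:length]):
        prompt[i] = embeddings[token_id]
    return prompt

# embeddings: [vocab_size, embed_dim], e.g. the "token_embedder/embedding"
# variable from a T5X checkpoint or a saved .npy file (paths are hypothetical).
embeddings = np.load("/tmp/token_embeddings.npy")
prompt = prompt_from_class_labels(embeddings, label_token_ids=[3, 7, 12], length=100)
np.save("/tmp/my_prompt.npy", prompt)  # a file like this is what from_file.gin loads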
7 | from __gin__ import dynamic_registration 8 | 9 | from prompt_tuning import prompts 10 | from prompt_tuning.train import prompts as train_prompts 11 | 12 | PROMPT_FILE = %gin.REQUIRED 13 | 14 | PROMPT = @train_prompts.Prompt 15 | train_prompts.Prompt.prompt = @prompts.Prompt() 16 | prompts.Prompt: 17 | length = %PROMPT_LENGTH 18 | prompt_init = @prompt_init/prompts.from_array() 19 | 20 | prompt_init/prompts.from_array.prompt = @prompt_init/prompts.np_load() 21 | prompt_init/prompts.np_load.path = %PROMPT_FILE 22 | 23 | # Then set overrides in the launch script: 24 | # --gin.PROMPT_FILE="'/path/to/numpy/prompt'" 25 | -------------------------------------------------------------------------------- /prompt_tuning/configs/prompts/from_sampled_vocab.gin: -------------------------------------------------------------------------------- 1 | # Initialize the Prompt based on sampling from the vocabulary where the 2 | # embeddings are loaded from the initial model checkpoint at 3 | # `INITIAL_CHECKPOINT_PATH`. 4 | # Provides PROMPT 5 | # 6 | # Requires PROMPT_LENGTH to be set. 7 | # 8 | # Include as a --gin_file argument after the `models/*.gin` arguments. 9 | from __gin__ import dynamic_registration 10 | 11 | from prompt_tuning import prompts 12 | from prompt_tuning.train import prompts as train_prompts 13 | 14 | PROMPT = @train_prompts.Prompt 15 | train_prompts.Prompt.prompt = @prompts.Prompt() 16 | 17 | prompts.Prompt: 18 | length = %PROMPT_LENGTH 19 | prompt_init = @prompt_init/prompts.from_sample_of_embeddings() 20 | 21 | prompt_init/prompts.from_sample_of_embeddings: 22 | embeddings = @prompt_init/prompts.t5x_load() 23 | population_size = 5000 24 | 25 | prompt_init/prompts.t5x_load: 26 | checkpoint_path = %INITIAL_CHECKPOINT_PATH 27 | variable_path = "token_embedder/embedding" 28 | -------------------------------------------------------------------------------- /prompt_tuning/configs/prompts/from_sampled_vocab_numpy.gin: -------------------------------------------------------------------------------- 1 | # Initialize the Prompt based on sampling from the vocabulary where the 2 | # embeddings are loaded from a saved numpy file. 3 | # Provides PROMPT 4 | # 5 | # Requires PROMPT_LENGTH and EMBEDDING_FILE to be set. 6 | # 7 | # Include as a --gin_file argument after the `models/*.gin` arguments. 8 | from __gin__ import dynamic_registration 9 | 10 | from prompt_tuning import prompts 11 | from prompt_tuning.train import prompts as train_prompts 12 | 13 | EMBEDDING_FILE = %gin.REQUIRED 14 | 15 | PROMPT = @train_prompts.Prompt 16 | train_prompts.Prompt.prompt = @prompts.Prompt() 17 | 18 | prompts.Prompt: 19 | length = %PROMPT_LENGTH 20 | prompt_init = @prompt_init/prompts.from_sample_of_embeddings() 21 | 22 | prompt_init/prompts.from_sample_of_embeddings: 23 | embeddings = @prompt_init/prompts.np_load() 24 | population_size = 5000 25 | 26 | prompt_init/prompts.np_load.path = %EMBEDDING_FILE 27 | 28 | # Then set overrides in the launch script: 29 | # --gin.EMBEDDING_FILE="'/path/to/numpy/embeddings/t5_1_1_lm_adaptation/xxl.npy'" \ 30 | -------------------------------------------------------------------------------- /prompt_tuning/configs/runs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Gin requires gin files to be in a python package, this file makes it one.""" 16 | -------------------------------------------------------------------------------- /prompt_tuning/configs/runs/prompt_eval.gin: -------------------------------------------------------------------------------- 1 | # Defaults for eval.py. 2 | # 3 | # 4 | # You must also include a binding for MODEL. 5 | # 6 | # Required to be set: 7 | # 8 | # - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to evaluate on 9 | # - CHECKPOINT_PATH: The model checkpoint to evaluate 10 | # - EVAL_OUTPUT_DIR: The dir to write results to. 11 | # - PROMPT_FILE: The file to load the prompt from. 12 | # 13 | # 14 | # Commonly overridden options: 15 | # 16 | # - DatasetConfig.split 17 | # - DatasetConfig.batch_size 18 | # - DatasetConfig.use_cached 19 | # - RestoreCheckpointConfig.mode 20 | # - PjitPartitioner.num_partitions 21 | from __gin__ import dynamic_registration 22 | import __main__ as eval_script 23 | from t5x import partitioning 24 | from prompt_tuning.train import partitioning as prompt_partitioning 25 | 26 | include "t5x/configs/runs/eval.gin" 27 | # Force loading a prompt from a file. If you want to evaluate a prompted model 28 | # from the checkpoint produced by the prompt tuning training run (instead of 29 | # loading from the checkpoint used to initialize prompt tuning + the learned 30 | # prompt from a file that this config enables) you can use 31 | # "t5x/configs/runs/infer.py" directly (you will probably need 32 | # to update the RestoreCheckpointConfig to disable fallback_to_scratch used 33 | # during training). 34 | include "prompt_tuning/configs/prompts/from_file.gin" 35 | 36 | # Enable "reinitialization" of parameters so the prompt will be initialized from 37 | # file. 38 | eval_script.evaluate.fallback_init_rng = 0 39 | 40 | # Add partitioning rules for our new axis named `prompt` 41 | partitioning.PjitPartitioner: 42 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 43 | 44 | partitioning.standard_logical_axis_rules: 45 | additional_rules = @prompt_partitioning.standard_logical_axis_rules() 46 | -------------------------------------------------------------------------------- /prompt_tuning/configs/runs/prompt_finetune.gin: -------------------------------------------------------------------------------- 1 | # Defaults for finetuning with train.py. 2 | # 3 | # See go/t5x-finetune for instructions. 4 | # 5 | # You must also include a binding for MODEL, PROMPT, and PROMPT_LENGTH. 6 | # 7 | # Required to be set: 8 | # 9 | # - MIXTURE_OR_TASK_NAME 10 | # - TASK_FEATURE_LENGTHS 11 | # - TRAIN_STEPS # includes pretrain steps 12 | # - MODEL_DIR # automatically set when using xm_launch 13 | # - INITIAL_CHECKPOINT_PATH 14 | # 15 | # When launching on XManager, `MODEL_DIR` (the directory to write fine-tuned 16 | # checkpoints to) is configured automatically by the XManager launch script. 17 | # When running locally, it needs to be passed in the `gin.MODEL_DIR` flag. 
18 | # 19 | # `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt 20 | # has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps. 21 | # 22 | # Commonly overridden options: 23 | # - DROPOUT_RATE 24 | # - BATCH_SIZE 25 | # - PjitPartitioner.num_partitions 26 | # - Trainer.num_microbatches 27 | # - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess 28 | # on the fly. Most common tasks are cached, hence this is set to True by 29 | # default. 30 | from __gin__ import dynamic_registration 31 | from t5x import utils 32 | from t5x import partitioning 33 | from prompt_tuning.train import partitioning as prompt_partitioning 34 | 35 | include "t5x/configs/runs/finetune.gin" 36 | 37 | # Prompt Tuning does not support packing. 38 | train/utils.DatasetConfig.pack = False 39 | train_eval/utils.DatasetConfig.pack = False 40 | 41 | # ========== These are Prompt Tuning HPs you might want to override ========== 42 | utils.create_learning_rate_scheduler: 43 | factors = "constant" 44 | base_learning_rate = 0.3 45 | 46 | utils.SaveCheckpointConfig: 47 | period = 1000 48 | # Keep a single checkpoint. Even though the majority of these checkpoint 49 | # weights are unchanged from our initial checkpoint we keep the copy so 50 | # recovery from preemption works. We save our prompt values ourselves so we 51 | # don't have to worry about losing them. 52 | keep = 1 53 | 54 | partitioning.PjitPartitioner: 55 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 56 | 57 | partitioning.standard_logical_axis_rules: 58 | additional_rules = @prompt_partitioning.standard_logical_axis_rules() 59 | -------------------------------------------------------------------------------- /prompt_tuning/configs/runs/prompt_infer.gin: -------------------------------------------------------------------------------- 1 | # Defaults for infer.py. 2 | # 3 | # 4 | # You must also include a binding for MODEL. 5 | # 6 | # Required to be set: 7 | # 8 | # - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to use for inference 9 | # - TASK_FEATURE_LENGTHS: The lengths per key in the SeqIO Task to trim features 10 | # to. 11 | # - CHECKPOINT_PATH: The model checkpoint to use for inference 12 | # - INFER_OUTPUT_DIR: The dir to write results to. 13 | # - PROMPT_FILE: The file to load the prompt from. 14 | # 15 | # 16 | # Commonly overridden options: 17 | # 18 | # - infer.mode 19 | # - infer.checkpoint_period 20 | # - infer.shard_id 21 | # - infer.num_shards 22 | # - DatasetConfig.split 23 | # - DatasetConfig.batch_size 24 | # - DatasetConfig.use_cached 25 | # - RestoreCheckpointConfig.is_tensorflow 26 | # - RestoreCheckpointConfig.mode 27 | # - PjitPartitioner.num_partitions 28 | from __gin__ import dynamic_registration 29 | import __main__ as infer_script 30 | from t5x import partitioning 31 | from prompt_tuning.train import partitioning as prompt_partitioning 32 | 33 | include "t5x/configs/runs/infer.gin" 34 | # Force loading a prompt from a file. If you want to evaluate a prompted model 35 | # from the checkpoint produced by the prompt tuning training run (instead of 36 | # loading from the checkpoint used to initialize prompt tuning + the learned 37 | # prompt from a file that this config enables) you can use 38 | # "t5x/configs/runs/infer.py" directly (you will probably need 39 | # to update the RestoreCheckpointConfig to disable fallback_to_scratch used 40 | # during training). 
41 | include "prompt_tuning/configs/prompts/from_file.gin" 42 | 43 | # Enable "reinitialization" of parameters so the prompt will be initialized from 44 | # file. 45 | infer_script.infer.fallback_init_rng = 0 46 | 47 | # Add partitioning rules for our new axis named `prompt` 48 | partitioning.PjitPartitioner: 49 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 50 | 51 | partitioning.standard_logical_axis_rules: 52 | additional_rules = @prompt_partitioning.standard_logical_axis_rules() 53 | -------------------------------------------------------------------------------- /prompt_tuning/configs/test/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Gin requires gin files to be in a python package, this file makes it one.""" 16 | -------------------------------------------------------------------------------- /prompt_tuning/configs/test/load_t5_1_1_tiny_prompt.gin: -------------------------------------------------------------------------------- 1 | # Test config to exercise train.py with "T5X core" partitioning. 2 | 3 | # We disable input order checks because the gin linter seems to want inputs in 4 | # alphabetical order without considering std, third_party, and project packages 5 | # like the python ordering of imports do. So disable that for now. 6 | 7 | # ginlint: disable=bad-import-order 8 | from __gin__ import dynamic_registration 9 | from t5x import utils 10 | from t5x import partitioning 11 | 12 | include 'prompt_tuning/configs/runs/prompt_finetune.gin' 13 | include 'prompt_tuning/configs/test/t5_1_1_tiny_prompt.gin' 14 | 15 | MODEL_DIR = "/tmp" # Will be overridden in test. 16 | 17 | TRAIN_STEPS = 3 18 | MIXTURE_OR_TASK_MODULE = "prompt_tuning.data.tasks" 19 | MIXTURE_OR_TASK_NAME = "taskless_super_glue_boolq_v102_examples" 20 | TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512} 21 | 22 | partitioning.PjitPartitioner.num_partitions = 1 23 | utils.SaveCheckpointConfig.period = 2 24 | train/utils.DatasetConfig.batch_size = 8 25 | train_eval/utils.DatasetConfig.batch_size = 8 26 | -------------------------------------------------------------------------------- /prompt_tuning/configs/test/t5_1_1_tiny.gin: -------------------------------------------------------------------------------- 1 | # Common pieces for the tests with the small model. 
2 | 3 | include 'nlp/nlx/infrastructure/flaxformer/t5x/configs/t5/models/t5_1_1_base.gin' 4 | 5 | NUM_HEADS = 2 6 | HEAD_DIM = 2 7 | EMBED_DIM = 4 8 | MLP_DIM = 8 9 | NUM_ENCODER_LAYERS = 2 10 | NUM_DECODER_LAYERS = 2 11 | 12 | relative_position_biases.RelativePositionBiases: 13 | num_buckets = 4 14 | max_distance = 8 15 | -------------------------------------------------------------------------------- /prompt_tuning/configs/test/t5_1_1_tiny_prompt.gin: -------------------------------------------------------------------------------- 1 | # Common pieces for the tests with the small model. 2 | 3 | include 'prompt_tuning/configs/models/t5_1_1_base_prompt.gin' 4 | 5 | NUM_HEADS = 2 6 | HEAD_DIM = 2 7 | EMBED_DIM = 4 8 | MLP_DIM = 8 9 | NUM_ENCODER_LAYERS = 2 10 | NUM_DECODER_LAYERS = 2 11 | 12 | relative_position_biases.RelativePositionBiases: 13 | num_buckets = 4 14 | max_distance = 8 15 | -------------------------------------------------------------------------------- /prompt_tuning/configs/test/train_t5_1_1_tiny.gin: -------------------------------------------------------------------------------- 1 | # Test config to exercise train.py with "T5X core" partitioning. 2 | # This was used to create the `test_t5_1_1_tiny/checkpoint_3` checkpoint. 3 | 4 | # We disable input order checks because the gin linter seems to want inputs in 5 | # alphabetical order without considering std, third_party, and project packages 6 | # like the python ordering of imports do. So disable that for now. 7 | 8 | # ginlint: disable=bad-import-order 9 | from __gin__ import dynamic_registration 10 | from t5x import utils 11 | from t5x import partitioning 12 | 13 | include 't5x/configs/runs/pretrain.gin' 14 | include 'prompt_tuning/configs/test/t5_1_1_tiny.gin' 15 | 16 | MODEL_DIR = "/tmp" # Will be overridden in test. 17 | 18 | TRAIN_STEPS = 3 19 | MIXTURE_OR_TASK_MODULE = "prompt_tuning.data.tasks" 20 | MIXTURE_OR_TASK_NAME = "taskless_super_glue_boolq_v102_examples" 21 | TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512} 22 | 23 | partitioning.PjitPartitioner.num_partitions = 1 24 | utils.SaveCheckpointConfig.period = 2 25 | train/utils.DatasetConfig.batch_size = 8 26 | train/utils.DatasetConfig.pack = False 27 | train/utils.DatasetConfig.use_cached = False 28 | train_eval/utils.DatasetConfig.batch_size = 8 29 | train_eval/utils.DatasetConfig.pack = False 30 | train_eval/utils.DatasetConfig.use_cached = False 31 | -------------------------------------------------------------------------------- /prompt_tuning/configs/test/train_t5_1_1_tiny_prompt.gin: -------------------------------------------------------------------------------- 1 | # Test config to exercise train.py with "T5X core" partitioning. 2 | 3 | # We disable input order checks because the gin linter seems to want inputs in 4 | # alphabetical order without considering std, third_party, and project packages 5 | # like the python ordering of imports do. So disable that for now. 6 | 7 | # ginlint: disable=bad-import-order 8 | from __gin__ import dynamic_registration 9 | from t5x import utils 10 | from t5x import partitioning 11 | 12 | include 't5x/configs/runs/pretrain.gin' 13 | include 'prompt_tuning/configs/test/t5_1_1_tiny_prompt.gin' 14 | 15 | MODEL_DIR = "/tmp" # Will be overridden in test. 
16 | 17 | TRAIN_STEPS = 3 18 | MIXTURE_OR_TASK_MODULE = "prompt_tuning.data.tasks" 19 | MIXTURE_OR_TASK_NAME = "taskless_super_glue_boolq_v102_examples" 20 | TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512} 21 | 22 | partitioning.PjitPartitioner.num_partitions = 1 23 | utils.SaveCheckpointConfig.period = 2 24 | train/utils.DatasetConfig.batch_size = 8 25 | train/utils.DatasetConfig.pack = False 26 | train_eval/utils.DatasetConfig.batch_size = 8 27 | train_eval/utils.DatasetConfig.pack = False 28 | -------------------------------------------------------------------------------- /prompt_tuning/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /prompt_tuning/data/c4.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tasks that use the c4 dataset.""" 16 | 17 | import functools 18 | from prompt_tuning.data import features 19 | from prompt_tuning.data import preprocessors 20 | import seqio 21 | from t5.data import preprocessors as t5_preprocessors 22 | 23 | 24 | 25 | # ===== BART (Lewis et. 
al., 2019) Pre-training-like objectives ===== 26 | seqio.TaskRegistry.add( 27 | "c4_v220_bart_text_infilling", 28 | source=seqio.TfdsDataSource(tfds_name="c4/en:2.2.0"), 29 | preprocessors=[ 30 | functools.partial( 31 | t5_preprocessors.rekey, 32 | key_map={ 33 | "inputs": None, 34 | "targets": "text", 35 | }), 36 | seqio.preprocessors.tokenize, 37 | seqio.CacheDatasetPlaceholder(), 38 | preprocessors.text_infilling, 39 | seqio.preprocessors.append_eos_after_trim, 40 | ], 41 | output_features=features.T5_FEATURES, 42 | metric_fns=[], 43 | postprocess_fn=None 44 | ) 45 | 46 | 47 | seqio.TaskRegistry.add( 48 | "c4_v220_bart_token_deletion", 49 | source=seqio.TfdsDataSource(tfds_name="c4/en:2.2.0"), 50 | preprocessors=[ 51 | functools.partial( 52 | t5_preprocessors.rekey, 53 | key_map={ 54 | "inputs": None, 55 | "targets": "text", 56 | }), 57 | seqio.preprocessors.tokenize, 58 | seqio.CacheDatasetPlaceholder(), 59 | preprocessors.token_deletion, 60 | seqio.preprocessors.append_eos_after_trim, 61 | ], 62 | output_features=features.T5_FEATURES, 63 | metric_fns=[], 64 | postprocess_fn=None 65 | ) 66 | -------------------------------------------------------------------------------- /prompt_tuning/data/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Constants that ensure dict key matches between prompt tuning processors.""" 16 | 17 | PREDICTION = "prediction" 18 | PREDICTION_TEXT = "prediction_pretokenized" 19 | INPUT_FIELD = "inputs" 20 | INPUT_TEXT = "inputs_pretokenized" 21 | TARGET_FIELD = "targets" 22 | TARGET_TEXT = "targets_pretokenized" 23 | CONTEXT_TEXT = "context" 24 | QUESTION_TEXT = "question" 25 | ANSWERS_TEXT = "answers" 26 | EXTRA_ID_0 = "" 27 | EQUIVALENT = "equivalent" 28 | DUPLICATE = "duplicate" 29 | -------------------------------------------------------------------------------- /prompt_tuning/data/features.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """A collection of the default features used in various t5 tasks.""" 16 | 17 | try: 18 | # This makes sure we have access to the byt5 mixtures 19 | import byt5.tasks as byt5_tasks 20 | except ImportError: 21 | byt5_tasks = None 22 | try: 23 | # This makes sure we have access to all mt5 tasks and mixtures 24 | import multilingual_t5.tasks as mt5_tasks 25 | except ImportError: 26 | mt5_tasks = None 27 | # This makes sure we have access to all the t5 mixtures 28 | import t5.data.mixtures # pylint: disable=unused-import 29 | # This makes sure we have access to all the t5 tasks 30 | import t5.data.tasks as t5_tasks 31 | 32 | # Aliases for ease of use when you want features for a specific model. 33 | T5_FEATURES = t5_tasks.DEFAULT_OUTPUT_FEATURES 34 | MODEL_TO_FEATURES = { 35 | # We use the keys to prefix the tasks t5 is the default so we don't want a 36 | # prefix, thus the strange empty string as the key. 37 | "": T5_FEATURES, 38 | } 39 | 40 | if mt5_tasks is not None: 41 | MT5_FEATURES = mt5_tasks.DEFAULT_OUTPUT_FEATURES 42 | MODEL_TO_FEATURES["mt5_"] = MT5_FEATURES 43 | if byt5_tasks is not None: 44 | BYT5_FEATURES = byt5_tasks.DEFAULT_BYTE_OUTPUT_FEATURES 45 | MODEL_TO_FEATURES["byt5_"] = BYT5_FEATURES 46 | -------------------------------------------------------------------------------- /prompt_tuning/data/show_tasks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""A quick tool to examine tasks. 
16 | 17 | Example usage: 18 | See a list of all tasks this module defines: 19 | python -m prompt_tuning.data.show_tasks 20 | 21 | See the first example of the test split of the taskless boolq task: 22 | python -m prompt_tuning.data.show_tasks \ 23 | --task=taskless_super_glue_boolq_v102_examples \ 24 | --split=test 25 | 26 | """ 27 | import importlib 28 | from absl import app 29 | from absl import flags 30 | import byt5.tasks as byt5_tasks # pylint: disable=unused-import 31 | import multilingual_t5.tasks as mt5_tasks # pylint: disable=unused-import 32 | import seqio 33 | import t5.data.mixtures as t5_mixtures # pylint: disable=unused-import 34 | import t5.data.tasks as t5_tasks # pylint: disable=unused-import 35 | 36 | OG_TASKS = frozenset(list(seqio.TaskRegistry._REGISTRY.keys())) # pylint: disable=protected-access 37 | OG_MIXTURES = frozenset(list(seqio.MixtureRegistry._REGISTRY.keys())) # pylint: disable=protected-access 38 | 39 | from prompt_tuning.data import tasks # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order 40 | 41 | FLAGS = flags.FLAGS 42 | flags.DEFINE_string("task", None, "The task you want to look at.") 43 | flags.DEFINE_string("split", "validation", 44 | "The split you want to look at in the task.") 45 | flags.DEFINE_string("module", 46 | None, 47 | "An extra module containing tasks to import.") 48 | 49 | 50 | def main(_): 51 | """Print all new tasks from the registry or a specific task.""" 52 | if FLAGS.module is not None: 53 | importlib.import_module(FLAGS.module) 54 | if FLAGS.task is None: 55 | print("New tasks from `prompt_tuning.data.tasks`") 56 | for task in sorted(seqio.TaskRegistry._REGISTRY.keys()): # pylint: disable=protected-access 57 | if task not in OG_TASKS: 58 | print(task) 59 | print("New mixtures from `prompt_tuning.data.tasks`") 60 | for mix in sorted(seqio.MixtureRegistry._REGISTRY.keys()): # pylint: disable=protected-access 61 | if mix not in OG_MIXTURES: 62 | print(mix) 63 | else: 64 | task = seqio.get_mixture_or_task(FLAGS.task) 65 | dataset = task.get_dataset(None, split=FLAGS.split, shuffle=False) 66 | print(f"The first example from the {FLAGS.split} split of {FLAGS.task}:") 67 | batch = next(iter(dataset)) 68 | for key, value in batch.items(): 69 | if key.endswith("_pretokenized"): 70 | print(f"\t{key}:\n\t\t{value.numpy().decode('utf-8')}") 71 | else: 72 | print(f"\t{key}:\n\t\t{value.numpy()}") 73 | 74 | 75 | if __name__ == "__main__": 76 | app.run(main) 77 | -------------------------------------------------------------------------------- /prompt_tuning/data/summarization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Summarization tasks for T5. 16 | 17 | This includes some summarization tasks, xsum and cnn/dailymail, as t5 tasks 18 | expect they are configured to write examples to Tensorboard when they run 19 | evaluation. 
20 | """ 21 | import functools 22 | 23 | from prompt_tuning.data import features 24 | from prompt_tuning.data import metrics as rom_metrics 25 | from prompt_tuning.data import postprocessors as rom_postprocessors 26 | from prompt_tuning.data import utils 27 | import seqio 28 | from t5.data import preprocessors 29 | from t5.evaluation import metrics 30 | 31 | 32 | # ========== Summarization ========== 33 | for model_prefix, feats in features.MODEL_TO_FEATURES.items(): 34 | # ===== XSUM ===== 35 | seqio.TaskRegistry.add( 36 | f"{model_prefix}xsum_v110_examples", 37 | source=seqio.TfdsDataSource(tfds_name="xsum:1.1.0"), 38 | preprocessors=[ 39 | functools.partial( 40 | preprocessors.summarize, 41 | article_key="document", 42 | summary_key="summary"), 43 | seqio.preprocessors.tokenize, 44 | seqio.CacheDatasetPlaceholder(), 45 | seqio.preprocessors.append_eos_after_trim, 46 | ], 47 | postprocess_fn=functools.partial( 48 | rom_postprocessors.postprocess_with_examples, utils.identity), 49 | metric_fns=[ 50 | functools.partial(rom_metrics.metric_with_examples, metrics.rouge), 51 | functools.partial( 52 | rom_metrics.text_examples, task_name="xsum", find_negative=False), 53 | ], 54 | output_features=feats) 55 | 56 | # ===== CNN/DailyMail ===== 57 | seqio.TaskRegistry.add( 58 | f"{model_prefix}cnn_dailymail_v310_examples", 59 | source=seqio.TfdsDataSource(tfds_name="cnn_dailymail:3.1.0"), 60 | preprocessors=[ 61 | functools.partial( 62 | preprocessors.summarize, 63 | article_key="article", 64 | summary_key="highlights"), 65 | seqio.preprocessors.tokenize, 66 | seqio.CacheDatasetPlaceholder(), 67 | seqio.preprocessors.append_eos_after_trim, 68 | ], 69 | postprocess_fn=functools.partial( 70 | rom_postprocessors.postprocess_with_examples, utils.identity), 71 | metric_fns=[ 72 | functools.partial(rom_metrics.metric_with_examples, metrics.rouge), 73 | functools.partial( 74 | rom_metrics.text_examples, 75 | task_name="cnn_dailymail", 76 | find_negative=False) 77 | ], 78 | output_features=feats) 79 | -------------------------------------------------------------------------------- /prompt_tuning/data/tasks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Add T5 and MT5 tasks that include text examples postprocessors/metrics. 16 | 17 | This includes several definitions of tasks (mostly from glue and super glue) 18 | using both the T5 and the MT5 sentence piece vocabs. They also include special 19 | postprocessing functions that copy the `inputs_pretokenized` and the 20 | `targets_pretokenized` fields from the input into a dictionary. This dict also 21 | has the model prediction text as well as the model prediction after being 22 | postprocessed. There are also special metric functions, one knows how to 23 | extract the targets and predictions from the dict before running the default 24 | metric functions. 
There is also a metric function that returns a 25 | `seqio.metrics.Text` object. This is then written to the text tab of 26 | tensorboard. 27 | """ 28 | 29 | # pylint: disable=unused-import,g-import-not-at-top 30 | # pytype: disable=import-error 31 | from prompt_tuning.data import c4 32 | from prompt_tuning.data import glue 33 | from prompt_tuning.data import glue_transfer 34 | from prompt_tuning.data import qa 35 | from prompt_tuning.data import summarization 36 | from prompt_tuning.data import super_glue 37 | -------------------------------------------------------------------------------- /prompt_tuning/data/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for defining tasks.""" 16 | 17 | from typing import Sequence, Mapping, Optional, Any, Dict 18 | 19 | 20 | def identity(x: Any, *args, **kwargs) -> Any: # pylint: disable=unused-argument 21 | """Identity function used when a task doesn't have a postprocess function.""" 22 | return x 23 | 24 | 25 | def task_mapping(tasks: Sequence[str], 26 | aliases: Optional[Mapping[str, str]]) -> Dict[str, int]: 27 | """Create a mapping from task name to index, sorted by task name. 28 | 29 | Args: 30 | tasks: The tasks that we are creating an index for. If any alias appears in 31 | these tasks it is removed before the sorting happens. 32 | aliases: Optional alternative names for a task, generally used for things 33 | like the SuperGLUE auxiliary tasks where things like the AX-b task is cast 34 | as the RTE task. 35 | 36 | Raises: 37 | ValueError if a alias is supposed to map to a task that was not provided. 38 | 39 | Returns: 40 | A mapping from task names to task indexes, where the tasks have been 41 | assigned based on the sorted task names. If any aliases are provided, they 42 | map to the same index as the task they alias. 43 | """ 44 | # From the python 3.9.6 documentation for a set: 45 | # Return a new set or frozenset object whose elements are taken from iterable. 46 | # tasks now refers to a different set from the one possibly passed in so we 47 | # can do in-place operations like `-=` to it later. 48 | tasks = set(tasks) 49 | if aliases is None: 50 | aliases = {} 51 | # Remove any aliases from the list of tasks. 
52 | tasks -= set(aliases.keys()) 53 | task_index = {task: i for i, task in enumerate(sorted(tasks))} 54 | for alias, target in aliases.items(): 55 | if target not in task_index: 56 | raise ValueError("You are trying to create a task alias from " 57 | f"{alias}->{target} but {target} is not a provided " 58 | "task.") 59 | task_index[alias] = task_index[target] 60 | return task_index 61 | 62 | 63 | def remove_prefix(s: str, prefix: str) -> str: 64 | """Remove prefix from the beginning of the string if present.""" 65 | if s.startswith(prefix): 66 | return s[len(prefix):] 67 | return s[:] 68 | 69 | 70 | def remove_suffix(string: str, suffix: str) -> str: 71 | """Remove suffix from the end of the string if present.""" 72 | if suffix and string.endswith(suffix): 73 | return string[:-len(suffix)] 74 | return string[:] 75 | -------------------------------------------------------------------------------- /prompt_tuning/data/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for utils.""" 16 | 17 | from absl.testing import absltest 18 | from prompt_tuning.data import utils 19 | 20 | 21 | class UtilsTest(absltest.TestCase): 22 | 23 | def test_task_mapping(self): 24 | tasks = ['a', 'alias_a', 'alias_b', 'b', 'c', 'alias_b2'] 25 | aliases = { 26 | 'alias_a': 'a', 27 | 'alias_b': 'b', 28 | 'alias_b2': 'b' 29 | } 30 | gold_tasks = { 31 | 'a': 0, 32 | 'b': 1, 33 | 'c': 2, 34 | 'alias_a': 0, 35 | 'alias_b': 1, 36 | 'alias_b2': 1, 37 | } 38 | task_index = utils.task_mapping(tasks, aliases) 39 | self.assertEqual(task_index, gold_tasks) 40 | 41 | def test_task_mapping_raises_error(self): 42 | with self.assertRaises(ValueError): 43 | utils.task_mapping([], {'alias': 'missing_task'}) 44 | 45 | def test_remove_prefix(self): 46 | prefix = 'prefix' 47 | gold = 'some other words' 48 | inputs = f'{prefix}{gold}' 49 | self.assertEqual(utils.remove_prefix(inputs, prefix), gold) 50 | 51 | def test_remove_prefix_missing(self): 52 | prefix = 'prefix' 53 | gold = 'some other words' 54 | inputs = gold 55 | self.assertEqual(utils.remove_prefix(inputs, prefix), gold) 56 | 57 | def test_remove_prefix_infix(self): 58 | prefix = 'prefix' 59 | post = 'some other words' 60 | pre = 'EXTRA EXTRA, PREFIX ALL ABOUT IT' 61 | inputs = f'{pre}{prefix}{post}' 62 | self.assertEqual(utils.remove_prefix(inputs, prefix), inputs) 63 | 64 | def test_remove_suffix(self): 65 | suffix = 'suffix' 66 | gold = 'some other words' 67 | inputs = f'{gold}{suffix}' 68 | self.assertEqual(utils.remove_suffix(inputs, suffix), gold) 69 | 70 | def test_remove_suffix_missing(self): 71 | suffix = 'suffix' 72 | gold = 'some other words' 73 | inputs = gold 74 | self.assertEqual(utils.remove_suffix(inputs, suffix), gold) 75 | 76 | def test_remove_suffix_infix(self): 77 | suffix = 'suffix' 78 | post = 'things the follow after' 79 | pre = 'i\'m def in front of the suffix' 80 | inputs = 
f'{pre}{suffix}{post}' 81 | self.assertEqual(utils.remove_suffix(inputs, suffix), inputs) 82 | 83 | 84 | if __name__ == '__main__': 85 | absltest.main() 86 | -------------------------------------------------------------------------------- /prompt_tuning/extended/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /prompt_tuning/extended/perceptron/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /prompt_tuning/extended/perceptron/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /prompt_tuning/extended/perceptron/configs/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | -------------------------------------------------------------------------------- /prompt_tuning/extended/perceptron/configs/models/cross_entropy_t5_1_1_prompt.gin: -------------------------------------------------------------------------------- 1 | # T5 1.1 Base Prompt model, trained with a Perceptron Loss. 2 | # Provides MODEL, PROMPT, and PROMPT_LENGTH 3 | 4 | from __gin__ import dynamic_registration 5 | from prompt_tuning.extended.perceptron.train import models 6 | 7 | include 'prompt_tuning/configs/models/t5_1_1_prompt.gin' 8 | 9 | MODEL = @models.CrossEntropyEncoderDecoderModel() 10 | models.CrossEntropyEncoderDecoderModel: 11 | module = %ARCHITECTURE 12 | input_vocabulary = %VOCABULARY 13 | output_vocabulary = %VOCABULARY 14 | optimizer_def = %OPTIMIZER 15 | z_loss = %Z_LOSS 16 | label_smoothing = %LABEL_SMOOTHING 17 | loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR 18 | length_normalize = True 19 | -------------------------------------------------------------------------------- /prompt_tuning/extended/perceptron/configs/models/perceptron_t5_1_1_prompt.gin: -------------------------------------------------------------------------------- 1 | # T5 1.1 Base Prompt model, trained with a Perceptron Loss. 2 | # Provides MODEL, PROMPT, and PROMPT_LENGTH 3 | 4 | from __gin__ import dynamic_registration 5 | from prompt_tuning.extended.perceptron.train import models 6 | 7 | include 'prompt_tuning/configs/models/t5_1_1_prompt.gin' 8 | 9 | MODEL = @models.PerceptronLossEncoderDecoderModel() 10 | models.PerceptronLossEncoderDecoderModel: 11 | module = %ARCHITECTURE 12 | input_vocabulary = %VOCABULARY 13 | output_vocabulary = %VOCABULARY 14 | optimizer_def = %OPTIMIZER 15 | z_loss = %Z_LOSS 16 | label_smoothing = %LABEL_SMOOTHING 17 | loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR 18 | length_normalize = True 19 | hinge = 0.2 20 | -------------------------------------------------------------------------------- /prompt_tuning/extended/perceptron/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /prompt_tuning/extended/perceptron/train/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
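For intuition about the perceptron-loss model configured above (note the `hinge = 0.2` binding), here is a conceptual sketch of a margin loss over candidate sequence scores. This is an editor's illustration only, not the actual models.PerceptronLossEncoderDecoderModel code, which may differ in details such as how length normalization is applied.

import jax.numpy as jnp

def perceptron_hinge_loss(gold_score, best_wrong_score, hinge=0.2):
    # Scores are (possibly length-normalized) log-likelihoods of full target
    # sequences. The loss is zero once the gold target outscores the strongest
    # incorrect candidate by at least the hinge margin.
    return jnp.maximum(0.0, hinge - (gold_score - best_wrong_score))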
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /prompt_tuning/extended/train/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /prompt_tuning/extended/train/ia3.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """IA^3 (https://arxiv.org/abs/2205.05638) implementation.""" 16 | 17 | from typing import Tuple 18 | 19 | import flax.linen as nn 20 | from flax.linen import partitioning 21 | import jax.numpy as jnp 22 | from prompt_tuning import prompts 23 | from flaxformer.types import DType 24 | 25 | 26 | class IA3(nn.Module): 27 | """IA3 scaling to use with a linear layer. 28 | 29 | Note: 30 | This module is used as the intermediate_conv module in the flaxformer 31 | MlpBlock. The MlpBlock only applies this intermediate conv to one of the 32 | parallel activation functions it uses, but because these parallel 33 | activations are combined with multiplication, IA3 applies an multiplicative 34 | scaling, and multiplication is associative we can apply the scaling to just 35 | that activation and get the same result as if we applied it afterwards. 36 | 37 | Attributes: 38 | init: How to initialize the scaling variable. 39 | axis_name: The logical names of the variable axes, used for partitioning. 40 | dtype: The dtype of the activations for this module. 41 | """ 42 | ia3_init: prompts.Initializer = nn.initializers.ones 43 | axis_name: Tuple[str] = ('embed',) 44 | dtype: DType = jnp.float32 45 | 46 | @nn.compact 47 | def __call__(self, x, *args, **kwargs): 48 | del args 49 | del kwargs 50 | *rest, hidden = x.shape 51 | scaling = partitioning.param_with_axes( 52 | 'ia3_scaling', 53 | self.ia3_init, 54 | (hidden,), 55 | axes=self.axis_name 56 | ) 57 | scaling = scaling.astype(self.dtype) 58 | # Reshape to broadcast over batch, seq, etc. 59 | scaling = jnp.reshape(scaling, tuple((1 for _ in rest)) + scaling.shape) 60 | return x * scaling 61 | 62 | 63 | class IA3Attention(nn.Module): 64 | """A version of IA3 scaling to use with the Attention class. 
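For reference, a minimal usage sketch of the IA3 module defined above (shapes are arbitrary; with the default ones initializer the output equals the input until the scaling vector is trained):

import jax
import jax.numpy as jnp
from prompt_tuning.extended.train import ia3

x = jnp.ones((2, 7, 16))                           # [batch, seq, hidden] activations
module = ia3.IA3()
variables = module.init(jax.random.PRNGKey(0), x)  # creates the (hidden,) 'ia3_scaling' param
y = module.apply(variables, x)                     # same shape, scaled per hidden feature
assert y.shape == x.shape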
65 | 66 | Note: 67 | Because of where we can hook into the flaxformer attention class (the 68 | `(k|v)_conv` module) the input to this function is already reshaped into 69 | [..., length, heads, kv] so we shape our scaling to match those last two 70 | dimensions. This will result in the same value as if we were to reshape 71 | the variable and do a single d_model scale. 72 | TODO: Rewrite as a single class that infers the number of dims 73 | to extract from the input to use to shape the param from the number of dims 74 | in the axis names. 75 | 76 | Attributes: 77 | init: How to initialize the scaling variable. 78 | axis_name: The logical names of the variable axes, used for partitioning. 79 | dtype: The dtype of the activations for this module. 80 | """ 81 | ia3_init: prompts.Initializer = nn.initializers.ones 82 | axis_names: Tuple[str, str] = ('heads', 'kv') 83 | dtype: DType = jnp.float32 84 | 85 | @nn.compact 86 | def __call__(self, x, *args, **kwargs): 87 | del args 88 | del kwargs 89 | *rest, heads, kv = x.shape 90 | scaling = partitioning.param_with_axes( 91 | 'ia3_scaling', 92 | self.ia3_init, 93 | (heads, kv), 94 | axes=self.axis_names 95 | ) 96 | scaling = scaling.astype(self.dtype) 97 | # Reshape to broadcast over batch, seq, etc. 98 | scaling = jnp.reshape(scaling, tuple((1 for _ in rest)) + scaling.shape) 99 | return x * scaling 100 | -------------------------------------------------------------------------------- /prompt_tuning/extended/train/ia3_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for ia3.py.""" 16 | 17 | from absl.testing import absltest 18 | import flax.linen as nn 19 | import jax 20 | import jax.numpy as jnp 21 | import numpy as np 22 | from prompt_tuning.extended.train import ia3 23 | 24 | 25 | class IA3Test(absltest.TestCase): 26 | 27 | def test_reshpaed_attention_equal_default(self): 28 | batch_size = 12 29 | seq = 10 30 | heads = 3 31 | kv = 2 32 | hidden = heads * kv 33 | 34 | default_input = jax.random.uniform(jax.random.PRNGKey(0), 35 | (batch_size, seq, hidden)) 36 | attention_input = jnp.reshape(default_input, (batch_size, seq, heads, kv)) 37 | 38 | params = ia3.IA3(ia3_init=nn.initializers.uniform()).init( 39 | jax.random.PRNGKey(0), default_input) 40 | attention_params = jax.tree.map( 41 | lambda x: jnp.reshape(x, (heads, kv)), params) 42 | 43 | default_result = jax.jit(ia3.IA3().apply)(params, default_input) 44 | attention_result = jax.jit(ia3.IA3Attention().apply)(attention_params, 45 | attention_input) 46 | np.testing.assert_allclose(jnp.reshape(default_result, (-1,)), 47 | jnp.reshape(attention_result, (-1,))) 48 | 49 | 50 | if __name__ == "__main__": 51 | absltest.main() 52 | -------------------------------------------------------------------------------- /prompt_tuning/extended/train/multitask_partitioning.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Custom, default partitioning rules for multitask prompt based models.""" 16 | 17 | import itertools 18 | from prompt_tuning.train import partitioning as pt_partitioning 19 | from t5x import partitioning 20 | 21 | 22 | def standard_logical_axis_rules() -> partitioning.LogicalAxisRules: 23 | """Add multitask prompt partitioning rules.""" 24 | return tuple(itertools.chain(pt_partitioning.standard_logical_axis_rules(), 25 | (("tasks", None), ("prompt+embed", None)))) 26 | -------------------------------------------------------------------------------- /prompt_tuning/extended/train/multitask_prompts.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Training subclasses of our prompt that actually do the concatenation. 
16 | 17 | This also creates a unified API between single task prompts, where the prompt is 18 | unbatched, and multi-task prompts, where the input needs to have the first token 19 | removed. 20 | """ 21 | 22 | from typing import Callable 23 | import flax.linen as nn 24 | from prompt_tuning.train import prompts 25 | from flaxformer.types import Array 26 | 27 | 28 | class MultiTaskPrompt(nn.Module): 29 | """Generate a MultiTaskPrompt and concatenate it with the input. 30 | 31 | This is the training time version of prompting a model. Calling the injected 32 | `prompt` module will generate your prompt. This prompt should be batched. This 33 | module then slices off the task index from the input and concatenates the 34 | prompt. This can be used in conjunction with the `multitask=True` arguments 35 | for attention mask creation to do multi-task prompting without needing multitask 36 | subclasses of various flaxformer modules. 37 | 38 | Attributes: 39 | prompt: The model that actually generates the batched prompt. 40 | combine: A function that combines the prompt and the embedded input. 41 | """ 42 | prompt: nn.Module 43 | combine: Callable[[Array, Array, Array], Array] = prompts.prefix_prompt 44 | 45 | def __call__(self, x, x_embed): 46 | prompt = self.prompt(x, x_embed) 47 | # Remove the task index token 48 | x_embed = x_embed[:, 1:] 49 | # Pytype is throwing a false positive here; it probably thinks 50 | # `self.combine` is a method call that is given a `self` parameter, but it 51 | # is actually just a function so there are only 3 arguments, like the type 52 | # annotation says. 53 | return self.combine(prompt, x_embed, x) # pylint: disable=too-many-function-args 54 | -------------------------------------------------------------------------------- /prompt_tuning/extended/train/multitask_prompts_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Tests for multitask training prompts.""" 16 | 17 | from unittest import mock 18 | from absl.testing import absltest 19 | import jax.numpy as jnp 20 | from prompt_tuning import test_utils 21 | from prompt_tuning.extended.train import multitask_prompts as train_multitask_prompts 22 | from prompt_tuning.train import prompts as train_prompts 23 | 24 | 25 | class PromptsTest(absltest.TestCase): 26 | 27 | def test_multitask_prompt_does_concatenation(self): 28 | embed_size = 20 29 | prompt_length = 5 30 | batch_size = 2 31 | seq_len = 20 32 | mock_prompt = mock.MagicMock() 33 | prompt = jnp.zeros((batch_size, prompt_length, embed_size)) 34 | mock_prompt.return_value = prompt 35 | mock_combine = mock.create_autospec( 36 | train_prompts.prefix_prompt, spec_set=True) 37 | prompt_module = train_multitask_prompts.MultiTaskPrompt( 38 | prompt=mock_prompt, combine=mock_combine) 39 | input_tokens = jnp.ones((batch_size, seq_len)) 40 | embed = jnp.ones((batch_size, seq_len, embed_size)) 41 | prompt_module.apply({"params": {}}, input_tokens, embed) 42 | 43 | mock_prompt.assert_called_once_with( 44 | test_utils.ArrayEqualMatcher(input_tokens), 45 | test_utils.ArrayAllCloseMatcher(embed)) 46 | 47 | mock_combine.assert_called_once_with( 48 | test_utils.ArrayAllCloseMatcher(prompt), 49 | test_utils.ArrayAllCloseMatcher(embed[:, 1:]), 50 | test_utils.ArrayEqualMatcher(input_tokens)) 51 | 52 | 53 | if __name__ == "__main__": 54 | absltest.main() 55 | -------------------------------------------------------------------------------- /prompt_tuning/extended/train/per_layer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for per layer prompting.""" 16 | 17 | from absl.testing import absltest 18 | import jax 19 | import jax.numpy as jnp 20 | import numpy as np 21 | from prompt_tuning.extended.train import per_layer 22 | 23 | 24 | class PerLayerPromptsTest(absltest.TestCase): 25 | 26 | def test_replace_prompt(self): 27 | embed_size = 20 28 | prompt_length = 5 29 | batch_size = 2 30 | seq_len = 14 31 | prompt = jnp.zeros((batch_size, prompt_length, embed_size)) 32 | embed = jnp.ones((batch_size, seq_len, embed_size)) 33 | with_prompt = jax.jit(per_layer.replace_prompt)(prompt, embed, None) 34 | self.assertEqual(with_prompt.shape, embed.shape) 35 | np.testing.assert_array_equal(with_prompt[:, :prompt_length], prompt) 36 | np.testing.assert_array_equal(with_prompt[:, prompt_length:], 37 | embed[:, prompt_length:]) 38 | 39 | def test_add_prompt(self): 40 | embed_size = 20 41 | prompt_length = 5 42 | batch_size = 2 43 | seq_len = 14 44 | prompt = 2 * jnp.ones((batch_size, prompt_length, embed_size)) 45 | embed = jnp.ones((batch_size, seq_len, embed_size)) 46 | with_prompt = jax.jit(per_layer.add_prompt)(prompt, embed, None) 47 | self.assertEqual(with_prompt.shape, embed.shape) 48 | prompt_from_output = with_prompt[:, :prompt_length] 49 | np.testing.assert_array_equal(prompt_from_output, 50 | 3 * jnp.ones_like(prompt_from_output)) 51 | np.testing.assert_array_equal(with_prompt[:, prompt_length:], 52 | embed[:, prompt_length:]) 53 | 54 | 55 | if __name__ == "__main__": 56 | absltest.main() 57 | -------------------------------------------------------------------------------- /prompt_tuning/pretrained_prompts/t5_1_1_lm100k_base/README.md: -------------------------------------------------------------------------------- 1 | # Pretrained Prompts 2 | 3 | Prompts trained using T5 1.1 lm100k base as the frozen model. 
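Each prompt is stored as a plain numpy array of shape `[prompt length, embedding dim]` and can be handed to an eval or training run through the `--gin.PROMPT_FILE` override (see `scripts/sst2-demo-eval.sh`). As a minimal sketch (assuming a local checkout of this repository and an installed `numpy`), a prompt can be inspected like this:

```python
# Minimal sketch: inspect one of the pretrained prompts shipped in this
# directory. Assumes the repository is checked out locally.
import numpy as np

prompt = np.load(
    "prompt_tuning/pretrained_prompts/t5_1_1_lm100k_base/sst2.npy")
# Expected shape is [prompt length, embedding dim]; for these base-model
# prompts that should be (100, 768).
print(prompt.shape, prompt.dtype)
```

The table below lists the available prompts and the scores they reached.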
4 | 5 | Path | Prompt Length | Batch Size | Topology | Partition | Dataset | Split | Metric | Score | Note 6 | ---------------------------------------------- | ------------: | ---------: | -------- | -----------: | ------- | ---------- | -------- | --------: |----------------- 7 | pretrained_prompts/t5_1_1_lm100k_base/sst2.npy | 100 | 128 | v3-8 | (1, 1, 1, 1) | SST2 | validation | Accuracy | 95.07 | Class Label Init 8 | pretrained_prompts/t5_1_1_lm100k_base/mrpc.npy | 100 | 32 | v3-8 | (1, 1, 1, 1) | MRPC | validation | F1/Acc | 89.7/85.3 | Class Label Init 9 | pretrained_prompts/t5_1_1_lm100k_base/rte.npy | 100 | 32 | v3-8 | (1, 1, 1, 1) | RTE | validation | Accuracy | 68.6 | Class Label Init 10 | -------------------------------------------------------------------------------- /prompt_tuning/pretrained_prompts/t5_1_1_lm100k_base/mrpc.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/prompt-tuning/72285eea06100f954bcbb16a447ec6cddfc6716c/prompt_tuning/pretrained_prompts/t5_1_1_lm100k_base/mrpc.npy -------------------------------------------------------------------------------- /prompt_tuning/pretrained_prompts/t5_1_1_lm100k_base/rte.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/prompt-tuning/72285eea06100f954bcbb16a447ec6cddfc6716c/prompt_tuning/pretrained_prompts/t5_1_1_lm100k_base/rte.npy -------------------------------------------------------------------------------- /prompt_tuning/pretrained_prompts/t5_1_1_lm100k_base/sst2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/prompt-tuning/72285eea06100f954bcbb16a447ec6cddfc6716c/prompt_tuning/pretrained_prompts/t5_1_1_lm100k_base/sst2.npy -------------------------------------------------------------------------------- /prompt_tuning/recycling/README.md: -------------------------------------------------------------------------------- 1 | # Prompt Recycling 2 | 3 | Note: This is a work in progress. 4 | 5 | Data and code for our paper 6 | [Reducing Retraining by Recycling Parameter-Efficient Prompts](https://arxiv.org/abs/2208.05577). 7 | 8 | # Usage 9 | 10 | 1. First we need a source prompt that will be recycled. See the 11 | [Training A Prompt Section](https://github.com/google-research/prompt-tuning/tree/main/prompt_tuning/README.md#training-a-prompt) 12 | of the main README on how to train a prompt using the Prompt Tuning code 13 | base. This prompt will be used as the input to the recycler. 14 | 2. Second, we need the path to the prompt we just trained. Wherever T5X was 15 | configured to save model checkpoints (controlled by the `--model_dir` 16 | flag), there will be a directory called `numpy_checkpoints`. In it there are 17 | directories for each saved step (`checkpoint_${step}`) and within that is 18 | the prompt, saved as a numpy file. This file will have a name like 19 | `encoder.prompt.prompt.prompt` (for Encoder-Decoder models) which is the 20 | path to the parameter through the model PyTree, using `.` for scoping. We 21 | will need this file. So, to recap, the trained prompt will live at: 22 | 23 | ```shell 24 | ${MODEL_DIR}/numpy_checkpoints/checkpoint_${step}/encoder.prompt.prompt.prompt 25 | ``` 26 | 27 | 3. We need to train a recycler using the source and target models and then 28 | apply it to the source prompt.
The 29 | [run\_recycle.py](https://github.com/google-research/prompt-tuning/tree/main/prompt_tuning/recycling/run_recycle.py) 30 | script is able to do this. It takes the command-line arguments 31 | `--source_model` and `--target_model` which should point to the T5X 32 | checkpoints of the source model (which you trained the prompt with) and the 33 | target model (which you want to use with the recycled prompt) respectively. 34 | It also requires the path to the source prompt as the `--prompt_path` 35 | parameter. Set this to the value above. You can select which recycler to use 36 | with the `--recycler` parameter. Finally, the `--output_path` parameter is 37 | needed to specify where to save the recycled prompt. 38 | 4. Finally, it is time to run eval. Follow the instructions from the 39 | [Inference with a Prompt](https://github.com/google-research/prompt-tuning/tree/main/prompt_tuning/README.md#inference-with-a-prompt) 40 | section, but set the `--gin.PROMPT_FILE` override to the `--output_path` 41 | used above. 42 | 43 | ## Large Scale Automatic Experiments 44 | 45 | The 46 | [recycle.py](https://github.com/google-research/prompt-tuning/tree/main/prompt_tuning/recycling/recycle.py) 47 | script can be used with one of the 48 | [config files](https://github.com/google-research/prompt-tuning/tree/main/prompt_tuning/recycling/configs/) 49 | to train recyclers and generate recycled prompts. It will produce a `.txt` file 50 | of CLI arguments that will be helpful in running all the recycling experiments. 51 | 52 | # Recycler Implementations 53 | 54 | Our recycler implementations include: 55 | 56 | * `v2v-nn` :: A Jax based neural network trained to map a subsample of source 57 | token embeddings to their corresponding target token embeddings. Implemented 58 | in `JaxRecycler`. 59 | * `v2v-lin` :: A linear projection, learned via least squares, that maps a 60 | subsample of source token embeddings to their corresponding target token 61 | embeddings. Implemented in `TFLstSqRecycler`. 62 | * `lin-comb` :: A linear combination of source embeddings that approximates 63 | the source prompt is learned and applied to the target embeddings. 64 | Implemented in `LinearCombinationRecycler`. 65 | 66 | # Vocabulary Filtering 67 | 68 | The final list of our filtered vocabulary items can be found in 69 | [filtered-vocab-english-only.json](https://github.com/google-research/prompt-tuning/tree/main/prompt_tuning/recycling/data/filtered-vocab-english-only.json) 70 | 71 | # How to Cite 72 | 73 | If you build on this code or these ideas, please cite: 74 | 75 | ```bibtex 76 | @article{lester-etal-2022-recycling, 77 | title={{R}educing {R}etraining by {R}ecycling {P}arameter-{E}fficient {P}rompts}, 78 | author={Lester, Brian and Yurtsever, Joshua and Shakeri, Siamak and Constant, Noah}, 79 | year={2022}, 80 | journal={arXiv preprint arXiv:2208.05577}, 81 | url={https://arxiv.org/abs/2208.05577}, 82 | } 83 | ``` 84 | -------------------------------------------------------------------------------- /prompt_tuning/recycling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /prompt_tuning/recycling/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /prompt_tuning/recycling/configs/imdb.json: -------------------------------------------------------------------------------- 1 | { 2 | "root_dir": "...", 3 | "dataset": "imdb", 4 | "task_name": "pr:imdb_v100_negative_positive_verbalizers_test_rc", 5 | "clobber": false, 6 | "steps": [1102000, 1105000, 1110000, 1120000], 7 | "load_embeddings": { 8 | "default": { 9 | "word_list_path": "prompt_tuning/recycling/data/filtered-vocab-english-only.json", 10 | "num_words": 4000, 11 | "word_offset": 1000 12 | } 13 | }, 14 | "recycling_methods": { 15 | "jax-nn": { 16 | "__init__": { 17 | "hidden_scale": 4 18 | }, 19 | "fit": { 20 | "batch_size": 50, 21 | "learning_rate": 0.0003, 22 | "epochs": 25 23 | } 24 | }, 25 | "tf-lstsq": {}, 26 | "linear-combination": {} 27 | }, 28 | "pretrained": { 29 | "Seed 0": "...", 30 | "Seed 2000": "...", 31 | "Default": "gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_base/checkpoint_1100000", 32 | "Large": "gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_large/checkpoint_1100000" 33 | }, 34 | "prompts": { 35 | "Seed 0": { 36 | "class-init": { 37 | "run 1": "...", 38 | "run 2": "...", 39 | "run 3": "..." 40 | }, 41 | "spot 10k": { 42 | "run 1": "...", 43 | "run 2": "...", 44 | "run 3": "..." 45 | }, 46 | "spot 50k": { 47 | "run 1": "...", 48 | "run 2": "...", 49 | "run 3": "..." 50 | } 51 | }, 52 | "Seed 2000": { 53 | "class-init": { 54 | "run 1": "...", 55 | "run 2": "...", 56 | "run 3": "..." 57 | }, 58 | "spot 10k": { 59 | "run 1": "...", 60 | "run 2": "...", 61 | "run 3": "..." 62 | }, 63 | "spot 50k": { 64 | "run 1": "...", 65 | "run 2": "...", 66 | "run 3": "..." 67 | } 68 | }, 69 | "Default": { 70 | "class-init": { 71 | "run 1": "...", 72 | "run 2": "...", 73 | "run 3": "..." 74 | }, 75 | "spot 10k": { 76 | "run 1": "...", 77 | "run 2": "...", 78 | "run 3": "..." 79 | }, 80 | "spot 50k": { 81 | "run 1": "...", 82 | "run 2": "...", 83 | "run 3": "..." 84 | } 85 | }, 86 | "Large": { 87 | "class-init": { 88 | "run 1": "...", 89 | "run 2": "...", 90 | "run 3": "..." 91 | }, 92 | "spot 10k": { 93 | "run 1": "...", 94 | "run 2": "...", 95 | "run 3": "..." 
96 | }, 97 | "spot 50k": { 98 | "run 1": "...", 99 | "run 2": "...", 100 | "run 3": "..." 101 | } 102 | } 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /prompt_tuning/recycling/configs/sst2.json: -------------------------------------------------------------------------------- 1 | { 2 | "root_dir": "...", 3 | "dataset": "sst2", 4 | "task_name": "pr:glue_sst2_v200_test_rc", 5 | "clobber": false, 6 | "steps": [1102000, 1105000, 1110000, 1120000], 7 | "load_embeddings": { 8 | "default": { 9 | "word_list_path": "prompt_tuning/recycling/data/filtered-vocab-english-only.json", 10 | "num_words": 4000, 11 | "word_offset": 1000 12 | }, 13 | "from_string": { 14 | "vocab_file": "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model", 15 | "variable_path": "token_embedder/embedding", 16 | "string": "Classify this movie review based on its sentiment . Use 2 classes . One positive ( for reviews that paint the movie in a favorable light ) and one negative ( for reviews that make you not want to see the movie or think it will be bad ) . Use the string ` positive ` for the positive class , the good / great movies , and use the string ` negative ` for the negative class , the bad movies ." 17 | } 18 | }, 19 | "recycling_methods": { 20 | "jax-nn": { 21 | "__init__": { 22 | "hidden_scale": 4 23 | }, 24 | "fit": { 25 | "batch_size": 50, 26 | "learning_rate": 0.0003, 27 | "epochs": 25 28 | } 29 | }, 30 | "tf-lstsq": {}, 31 | "linear-combination": {} 32 | }, 33 | "pretrained": { 34 | "Seed 0": "...", 35 | "Seed 2000": "...", 36 | "Default": "gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_base/checkpoint_1100000", 37 | "Large": "gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_large/checkpoint_1100000" 38 | }, 39 | "prompts": { 40 | "Seed 0": { 41 | "rand-init": { 42 | "run 1": "...", 43 | "run 2": "...", 44 | "run 3": "..." 45 | }, 46 | "class-init": { 47 | "run 1": "...", 48 | "run 2": "...", 49 | "run 3": "..." 50 | }, 51 | "spot 10k": { 52 | "run 1": "...", 53 | "run 2": "...", 54 | "run 3": "..." 55 | }, 56 | "spot 50k": { 57 | "run 1": "...", 58 | "run 2": "...", 59 | "run 3": "..." 60 | }, 61 | "wayward": { 62 | "run 1": "...", 63 | "run 2": "...", 64 | "run 3": "..." 65 | } 66 | }, 67 | "Seed 2000": { 68 | "rand-init": { 69 | "run 1": "...", 70 | "run 2": "...", 71 | "run 3": "..." 72 | }, 73 | "class-init": { 74 | "run 1": "...", 75 | "run 2": "...", 76 | "run 3": "..." 77 | }, 78 | "spot 10k": { 79 | "run 1": "...", 80 | "run 2": "...", 81 | "run 3": "..." 82 | }, 83 | "spot 50k": { 84 | "run 1": "...", 85 | "run 2": "...", 86 | "run 3": "..." 87 | }, 88 | "wayward": { 89 | "run 1": "...", 90 | "run 2": "...", 91 | "run 3": "..." 92 | } 93 | }, 94 | "Default": { 95 | "rand-init": { 96 | "run 1": "...", 97 | "run 2": "...", 98 | "run 3": "..." 99 | }, 100 | "class-init": { 101 | "run 1": "...", 102 | "run 2": "...", 103 | "run 3": "..." 104 | }, 105 | "spot 10k": { 106 | "run 1": "...", 107 | "run 2": "...", 108 | "run 3": "..." 109 | }, 110 | "spot 50k": { 111 | "run 1": "...", 112 | "run 2": "...", 113 | "run 3": "..." 114 | }, 115 | "wayward": { 116 | "run 1": "...", 117 | "run 2": "...", 118 | "run 3": "..." 119 | } 120 | }, 121 | "Large": { 122 | "rand-init": { 123 | "run 1": "...", 124 | "run 2": "...", 125 | "run 3": "..." 126 | }, 127 | "class-init": { 128 | "run 1": "...", 129 | "run 2": "...", 130 | "run 3": "..." 131 | }, 132 | "spot 10k": { 133 | "run 1": "...", 134 | "run 2": "...", 135 | "run 3": "..." 
136 | }, 137 | "spot 50k": { 138 | "run 1": "...", 139 | "run 2": "...", 140 | "run 3": "..." 141 | }, 142 | "wayward": { 143 | "run 1": "...", 144 | "run 2": "...", 145 | "run 3": "..." 146 | } 147 | } 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /prompt_tuning/recycling/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /prompt_tuning/recycling/data/c4.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Task that uses c4 but has random targets.""" 16 | 17 | import functools 18 | from prompt_tuning.data import features 19 | from prompt_tuning.recycling.data import preprocessors as rec_preprocessors 20 | import seqio 21 | from t5.data import preprocessors as t5_preprocessors 22 | 23 | 24 | seqio.TaskRegistry.add( 25 | "c4_v220_random_positive_negative_targets", 26 | source=seqio.TfdsDataSource(tfds_name="c4/en:2.2.0"), 27 | preprocessors=[ 28 | functools.partial( 29 | t5_preprocessors.rekey, 30 | key_map={ 31 | "inputs": None, 32 | "targets": "text", 33 | }), 34 | seqio.preprocessors.tokenize, 35 | seqio.CacheDatasetPlaceholder(), 36 | t5_preprocessors.prefix_lm, 37 | functools.partial( 38 | rec_preprocessors.random_targets, 39 | targets=["positive", "negative"], 40 | ), 41 | seqio.preprocessors.append_eos_after_trim, 42 | ], 43 | output_features=features.T5_FEATURES, 44 | metric_fns=[], 45 | postprocess_fn=None 46 | ) 47 | -------------------------------------------------------------------------------- /prompt_tuning/recycling/data/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A metric wrapper that updates the metric names. 16 | 17 | This wrapper allows us to have two rank classification metrics for a single 18 | task with different settings (one with length normalization on and one off). 19 | Otherwise we would need to create two copies of the task, which would cause a 20 | large slow-down in evaluation time as inference would be performed on each task 21 | even though they have the same inputs and only the metric calculation changes. The 22 | wrapper enables us to re-use the model's inference for both versions of the metric. 23 | """ 24 | 25 | from typing import Union, Dict, Sequence, Tuple, Callable 26 | 27 | RankClassificationTargets = Sequence[Tuple[Sequence[int], bool, float, int]] 28 | RankClassificationScores = Sequence[float] 29 | MetricMap = Dict[str, Union[float, int]] 30 | RankClassificationMetric = Callable[ 31 | [RankClassificationTargets, RankClassificationScores], MetricMap] 32 | 33 | 34 | # Note: The parameter names need to be `targets` and `scores` as T5X uses 35 | # introspection to determine if it should use the autoregressive or the scoring 36 | # based inference function. 37 | def prefix_metric_names( 38 | targets: RankClassificationTargets, 39 | scores: RankClassificationScores, 40 | metric_prefix: str, 41 | metric_fn: RankClassificationMetric) -> MetricMap: 42 | """Run `metric_fn` and prepend `metric_prefix` to all metric names.""" 43 | metrics = metric_fn(targets, scores) 44 | return {f"{metric_prefix}_{k}": v for k, v in metrics.items()} 45 | -------------------------------------------------------------------------------- /prompt_tuning/recycling/data/preprocessors_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for prompt recycling preprocessors.""" 16 | 17 | import collections 18 | import itertools 19 | import os 20 | from absl.testing import absltest 21 | import numpy as np 22 | from prompt_tuning.recycling.data import preprocessors 23 | import seqio 24 | import t5.data 25 | import tensorflow as tf 26 | import tensorflow_datasets as tfds 27 | 28 | TRIALS = 400_000 29 | 30 | # We need to go up 3 dirs to get to the test data path.
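# That is: recycling/data/ -> recycling/ -> the prompt_tuning package root, which is where the test_data/ directory lives.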
31 | TEST_DATA = os.path.join( 32 | os.path.dirname( 33 | os.path.dirname( 34 | os.path.dirname( 35 | os.path.abspath(__file__)))), 36 | "test_data") 37 | 38 | 39 | class RandomLabelsTest(absltest.TestCase): 40 | 41 | def _collect_stats_from_dataset(self, dataset): 42 | counts = collections.Counter( 43 | [tuple(ex["targets"].tolist()) for ex in tfds.as_numpy(dataset)]) 44 | norm = sum(counts.values()) 45 | return {k: v / norm for k, v in counts.most_common()} 46 | 47 | def test_uniform(self): 48 | targets = ["positive", "negative", "neither"] 49 | vocab = seqio.vocabularies.SentencePieceVocabulary( 50 | os.path.join(TEST_DATA, "t5_vocab")) 51 | 52 | ds = tf.data.Dataset.range(TRIALS) 53 | ds = ds.map(lambda ex: {"inputs": ex, "targets": None}) 54 | ds = preprocessors.random_targets( 55 | ds, 56 | {"targets": t5.data.Feature(vocabulary=vocab)}, 57 | None, 58 | targets, 59 | seed=42) 60 | 61 | stats = self._collect_stats_from_dataset(ds) 62 | for dist1, dist2 in itertools.product(stats.values(), stats.values()): 63 | np.testing.assert_allclose(dist1, dist2, rtol=1e-2) 64 | 65 | def test_skewed(self): 66 | targets = ["positive", "negative", "neither"] 67 | probs = [0.1, 0.3, 0.6] 68 | 69 | vocab = seqio.vocabularies.SentencePieceVocabulary( 70 | os.path.join(TEST_DATA, "t5_vocab")) 71 | 72 | ds = tf.data.Dataset.range(TRIALS) 73 | ds = ds.map(lambda ex: {"inputs": ex, "targets": None}) 74 | ds = preprocessors.random_targets( 75 | ds, 76 | {"targets": t5.data.Feature(vocabulary=vocab)}, 77 | None, 78 | targets, 79 | probs=probs, 80 | seed=42) 81 | 82 | stats = self._collect_stats_from_dataset(ds) 83 | for target, prob in zip(targets, probs): 84 | targ = tuple(vocab.encode(target)) 85 | np.testing.assert_allclose(stats[targ], prob, rtol=1e-2) 86 | 87 | 88 | if __name__ == "__main__": 89 | absltest.main() 90 | -------------------------------------------------------------------------------- /prompt_tuning/recycling/data/rank_classification.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for creating rank classification versions of our datasets.""" 16 | 17 | import tensorflow as tf 18 | 19 | 20 | def get_inputs(ex, labels): 21 | """Get inputs for each target during rank classification.""" 22 | return [ex["inputs"]] * len(labels) 23 | 24 | 25 | def get_targets(ex, labels): 26 | """Get possible targets for rank classification.""" 27 | del ex 28 | return list(labels) 29 | 30 | 31 | def get_correct(ex, labels): 32 | """Get a boolean mask denoting which target is correct.""" 33 | return tf.equal(list(labels), ex["targets"]) 34 | -------------------------------------------------------------------------------- /prompt_tuning/recycling/data/sst2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Register SST2 Generative and Rank Classification Eval mixture.""" 16 | 17 | import functools 18 | from prompt_tuning.data import features 19 | from prompt_tuning.data import preprocessors as pt_preprocessors 20 | from prompt_tuning.recycling.data import rank_classification as pr_rc 21 | import seqio 22 | from t5.data import glue_utils 23 | from t5.data import postprocessors as t5_postprocessors 24 | from t5.data import preprocessors as t5_preprocessors 25 | from t5.evaluation import metrics as t5_metrics 26 | import tensorflow_datasets as tfds 27 | 28 | 29 | builder = tfds.text.glue.Glue.builder_configs["sst2"] 30 | 31 | # Training and Dev set split. 32 | seqio.TaskRegistry.add( 33 | "pr:glue_sst2_v200_train", 34 | source=seqio.TfdsDataSource( 35 | tfds_name="glue/sst2:2.0.0", 36 | splits={ 37 | "train": "train[:-1000]", 38 | "validation": "train[-1000:]" 39 | } 40 | ), 41 | preprocessors=[ 42 | glue_utils.get_glue_text_preprocessor(builder), 43 | pt_preprocessors.remove_first_text_token, 44 | seqio.preprocessors.tokenize, 45 | seqio.preprocessors.append_eos_after_trim, 46 | ], 47 | postprocess_fn=glue_utils.get_glue_postprocess_fn(builder), 48 | metric_fns=glue_utils.get_glue_metric(builder.name), 49 | output_features=features.T5_FEATURES 50 | ) 51 | 52 | # Generative Test split 53 | seqio.TaskRegistry.add( 54 | "pr:glue_sst2_v200_test", 55 | source=seqio.TfdsDataSource( 56 | tfds_name="glue/sst2:2.0.0", 57 | splits=("validation",) 58 | ), 59 | preprocessors=[ 60 | glue_utils.get_glue_text_preprocessor(builder), 61 | pt_preprocessors.remove_first_text_token, 62 | seqio.preprocessors.tokenize, 63 | seqio.preprocessors.append_eos_after_trim, 64 | ], 65 | postprocess_fn=glue_utils.get_glue_postprocess_fn(builder), 66 | metric_fns=glue_utils.get_glue_metric(builder.name), 67 | output_features=features.T5_FEATURES 68 | ) 69 | 70 | # Rank classification Test split 71 | seqio.TaskRegistry.add( 72 | "pr:glue_sst2_v200_test_rc", 73 | source=seqio.TfdsDataSource( 74 | tfds_name="glue/sst2:2.0.0", 75 | splits=("validation",) 76 | ), 77 | preprocessors=[ 78 | glue_utils.get_glue_text_preprocessor(builder), 79 | pt_preprocessors.remove_first_text_token, 80 | functools.partial( 81 | t5_preprocessors.rank_classification, 82 | inputs_fn=functools.partial(pr_rc.get_inputs, 83 | labels=builder.label_classes), 84 | targets_fn=functools.partial(pr_rc.get_targets, 85 | labels=builder.label_classes), 86 | is_correct_fn=functools.partial(pr_rc.get_correct, 87 | labels=builder.label_classes) 88 | ), 89 | seqio.preprocessors.tokenize, 90 | seqio.preprocessors.append_eos_after_trim, 91 | ], 92 | postprocess_fn=t5_postprocessors.rank_classification, 93 | metric_fns=[functools.partial( 94 | t5_metrics.rank_classification, 95 | num_classes=2)], 96 | output_features=features.T5_FEATURES 97 | ) 98 | 99 | # Mix training and eval settings together for training and early stopping. 
100 | seqio.MixtureRegistry.add( 101 | "pr:glue_sst2_v200", 102 | ["pr:glue_sst2_v200_train", 103 | "pr:glue_sst2_v200_test", 104 | "pr:glue_sst2_v200_test_rc"], 105 | default_rate=1.0) 106 | 107 | # Mix together to run generative and RC evals at once from eval.py 108 | seqio.MixtureRegistry.add( 109 | "pr:glue_sst2_v200_eval", 110 | ["pr:glue_sst2_v200_test_rc", 111 | "pr:glue_sst2_v200_test"], 112 | default_rate=1.0) 113 | -------------------------------------------------------------------------------- /prompt_tuning/recycling/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for reading English-only embeddings.""" 16 | 17 | import json 18 | from t5x import checkpoints 19 | from tensorflow.io import gfile 20 | 21 | from flaxformer.types import Array 22 | 23 | 24 | def load_english_only_embedding( 25 | checkpoint_path: str, 26 | word_list_path: str, 27 | ) -> Array: 28 | """Load embedding based on our filtered, English-only list. 29 | 30 | Args: 31 | checkpoint_path: The t5x checkpoint to load from. 32 | word_list_path: The word list json file, format is Mapping[int, str] where 33 | the key is the vocab index in the original vocab and the value is the 34 | actual text of the word. 35 | 36 | Returns: 37 | The embedding table for the words in `word_list_path`. 38 | """ 39 | with gfile.GFile(word_list_path) as f: 40 | filtered_words = json.load(f) 41 | 42 | # JSON keys are strings, convert to ints for actual indexing. 43 | filtered_indices = list(map(int, filtered_words.keys())) 44 | 45 | full_embeddings = checkpoints.load_t5x_checkpoint( 46 | checkpoint_path, lazy_parameters=True) 47 | full_embeddings = ( 48 | full_embeddings["target"]["token_embedder"]["embedding"].get()) 49 | return full_embeddings[filtered_indices] 50 | -------------------------------------------------------------------------------- /prompt_tuning/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /prompt_tuning/scripts/extract_variable.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Extract a variable from a t5x checkpoint and save it as a numpy file. 16 | 17 | Example usage: 18 | python -m prompt_tuning.scripts.extract_variable \ 19 | --checkpoint_dir=/path/to/t5x/checkpoint_step \ 20 | --variable_path=target/encoder/prompt/prompt/prompt \ 21 | --restore_dtype=float32 \ 22 | --output_path=/path/to/save/prompt.npy 23 | 24 | """ 25 | 26 | import os 27 | import re 28 | from typing import Mapping, Any, Sequence 29 | from absl import app 30 | from absl import flags 31 | from absl import logging 32 | import jax.numpy as jnp 33 | import numpy as np 34 | from t5x import checkpoints 35 | from tensorflow.io import gfile 36 | 37 | FLAGS = flags.FLAGS 38 | 39 | flags.DEFINE_string( 40 | "checkpoint_dir", None, "The path to the t5x checkpoint directory") 41 | flags.DEFINE_string( 42 | "variable_path", 43 | None, 44 | "The path to the variable in the checkpoint tree, using `/` for scoping. " 45 | "Leading `/` or `/target` is optional.") 46 | flags.DEFINE_enum( 47 | "restore_dtype", 48 | "float32", 49 | ["float32", "bfloat16"], 50 | "The data type to use when restoring the variable.") 51 | flags.DEFINE_string( 52 | "output_path", 53 | None, 54 | "The path to where the numpy variable should be saved.") 55 | flags.mark_flag_as_required("checkpoint_dir") 56 | flags.mark_flag_as_required("variable_path") 57 | flags.mark_flag_as_required("output_path") 58 | 59 | 60 | def normalize_variable_path(path: str, sep: str = "/") -> str: 61 | """Make sure path starts with `target/`.""" 62 | # TODO: enable saving all variables within a scope if the path 63 | # ends in the separator. 64 | path = path.strip(sep) 65 | path = re.sub(r"^target/", "", path) 66 | return f"target/{path}" 67 | 68 | 69 | def extract_nested_key( 70 | nested_key: str, blob: Mapping[str, Any], sep: str = "/") -> Any: 71 | """Extract a key from nested dicts using a scoping separator.""" 72 | # TODO: Add nicer error handling that shows where in the nested 73 | # dicts your key lookup fails.
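# Walk one level of the nested dicts per path segment; a KeyError raised here means that segment was not found at the current level.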
74 | for key in nested_key.split(sep): 75 | blob = blob[key] 76 | return blob 77 | 78 | 79 | def save_variable(output_path: str, variable: np.ndarray): 80 | """Save variable at output path using numpy.""" 81 | dir_name = os.path.dirname(output_path) 82 | if not gfile.exists(dir_name): 83 | gfile.makedirs(dir_name) 84 | 85 | with gfile.GFile(output_path, "wb") as wf: 86 | np.save(wf, variable) 87 | 88 | 89 | def main(argv: Sequence[str]): 90 | """Extract a numpy value from a t5x checkpoint.""" 91 | if len(argv) > 1: 92 | raise app.UsageError("Too many command-line-arguments.") 93 | 94 | restore_dtype = jnp.dtype(FLAGS.restore_dtype) 95 | 96 | checkpoint = checkpoints.load_t5x_checkpoint( 97 | FLAGS.checkpoint_dir, 98 | restore_dtype=restore_dtype, 99 | lazy_parameters=True) 100 | 101 | logging.info("Reading variables from %s as dtype=%s", 102 | FLAGS.checkpoint_dir, 103 | restore_dtype) 104 | 105 | variable_path = normalize_variable_path(FLAGS.variable_path) 106 | logging.info("Extracting variable found at %s", variable_path) 107 | 108 | variable = extract_nested_key(variable_path, checkpoint) 109 | variable = variable.get() 110 | logging.info("Read variable with shape %s", variable.shape) 111 | 112 | logging.info("Saving variable to %s", FLAGS.output_path) 113 | save_variable(FLAGS.output_path, variable) 114 | 115 | 116 | if __name__ == "__main__": 117 | app.run(main) 118 | -------------------------------------------------------------------------------- /prompt_tuning/scripts/find_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Find where a module is installed. 16 | 17 | This tool is useful for finding where a package like T5X is installed so that 18 | we can easily use the gin configs that are bundled with it. 19 | 20 | Example usage: 21 | 22 | python -m t5x.train \ 23 | --gin_search_paths=`python -m prompt_tuning.scripts.find_module t5x` \ 24 | --gin_file=t5x/configs/... \ 25 | ... 26 | """ 27 | 28 | import importlib 29 | import os 30 | from typing import Sequence 31 | from absl import app 32 | 33 | 34 | def main(argv: Sequence[str]): 35 | if len(argv) != 2: 36 | raise app.UsageError("Missing module argument.") 37 | 38 | module = importlib.import_module(argv[1]) 39 | print(os.path.dirname(os.path.abspath(module.__file__))) 40 | 41 | 42 | if __name__ == "__main__": 43 | app.run(main) 44 | -------------------------------------------------------------------------------- /prompt_tuning/scripts/mrqa_to_tsv.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Convert the MRQA data format into a tsv format that t5 can handle. 16 | 17 | This script converts the jsonl format from MRQA into a tsv file for t5 18 | consumption. It produces a tsv with the following columns: `id`, `question`, 19 | `context`, `answer` and `answers`. `answer` is the answer that will be used as 20 | the target when training, while `answers` will be the multiple human answers 21 | that are used for evaluation. The answers column will be joined together with 22 | the `--delim` argument (defaults to `"|||"`). 23 | 24 | Eventually it would be nice to have a tfds version of these datasets 25 | (at least the out of domain validation set ones) that produce a 26 | `squad-like` example format. There are some hurdles for that though. 27 | Are all the answers actually in the text? Do all the answers have start 28 | offsets associated with them? 29 | 30 | Example usage: 31 | python -m prompt_tuning.scripts.mrqa_to_tsv \ 32 | --jsonl /path/to/MRQA_dataset.jsonl \ 33 | --output /path/to/MRQA_dataset.tsv 34 | 35 | """ 36 | 37 | import json 38 | import os 39 | import re 40 | from typing import Dict 41 | 42 | from absl import app 43 | from absl import flags 44 | import pandas as pd 45 | from tensorflow.io import gfile 46 | 47 | FLAGS = flags.FLAGS 48 | flags.DEFINE_string("jsonl", None, "The MRQA 2019 shared task jsonl file.") 49 | flags.DEFINE_integer("header", 1, "The number of header lines to skip.") 50 | flags.DEFINE_string( 51 | "delim", "|||", 52 | "The default separator of answers when we serialize the list") 53 | flags.DEFINE_string( 54 | "output", None, 55 | "The name of the output file. Defaults to `--jsonl` base name + \".tsv\"") 56 | flags.mark_flag_as_required("jsonl") 57 | 58 | 59 | def normalize_whitespace(s: str) -> str: 60 | """Convert all whitespace (tabs, newlines, etc) into spaces.""" 61 | return re.sub(r"\s+", " ", s, flags=re.MULTILINE) 62 | 63 | 64 | def parse_line(line: str, delim: str = "|||") -> Dict[str, str]: 65 | """Turn a jsonl line into a row that is ready to write to csv.""" 66 | example = json.loads(line) 67 | context = normalize_whitespace(example["context"]) 68 | # Some of the questions have newlines in them and the tensorflow csv utils 69 | # can't handle newlines, so remove them.
70 | question = normalize_whitespace(example["qas"][0]["question"]) 71 | qid = example["qas"][0]["qid"] 72 | answer = normalize_whitespace( 73 | example["qas"][0]["detected_answers"][0]["text"]) 74 | answers = [normalize_whitespace(ans) for ans in example["qas"][0]["answers"]] 75 | return { 76 | "id": qid, 77 | "context": context, 78 | "question": question, 79 | "answer": answer, 80 | "answers": delim.join(answers) 81 | } 82 | 83 | 84 | def main(_): 85 | output = FLAGS.output 86 | if FLAGS.output is None: 87 | output = os.path.splitext(FLAGS.jsonl)[0] + ".tsv" 88 | 89 | examples = [] 90 | with gfile.GFile(FLAGS.jsonl) as f: 91 | for _ in range(FLAGS.header): 92 | f.readline() 93 | for line in f: 94 | examples.append(parse_line(line)) 95 | 96 | df = pd.DataFrame(examples, dtype=str) 97 | with gfile.GFile(output, "w") as wf: 98 | df.to_csv(wf, index=False, sep="\t") 99 | 100 | 101 | if __name__ == "__main__": 102 | app.run(main) 103 | -------------------------------------------------------------------------------- /prompt_tuning/scripts/sst2-demo-eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | EVAL_DIR=${1:-${EVAL_DIR}} 4 | TFDS_DATA_DIR=${2:-${TFDS_DATA_DIR}} 5 | 6 | if [ -z ${EVAL_DIR} ] || [ -z ${TFDS_DATA_DIR} ]; then 7 | echo "usage: ./sst2-demo-eval.sh gs://your-bucket/path/to/eval_dir gs://your-bucket/path/to/tfds/cache" 8 | exit 1 9 | fi 10 | 11 | T5X_DIR="`python3 -m prompt_tuning.scripts.find_module t5x`/.." 12 | FLAXFORMER_DIR="`python3 -m prompt_tuning.scripts.find_module flaxformer`/.." 13 | PROMPT_DIR="`python3 -m prompt_tuning.scripts.find_module prompt_tuning`/.." 14 | echo "Searching for gin configs in:" 15 | echo "- ${T5X_DIR}" 16 | echo "- ${FLAXFORMER_DIR}" 17 | echo "- ${PROMPT_DIR}" 18 | echo "=============================" 19 | PRETRAINED_MODEL="gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_base/checkpoint_1100000" 20 | PROMPT_FILE="`python3 -m prompt_tuning.scripts.find_module prompt_tuning`/pretrained_prompts/t5_1_1_lm100k_base/sst2.npy" 21 | 22 | python3 -m t5x.eval \ 23 | --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" \ 24 | --gin_file="prompt_tuning/configs/models/t5_1_1_base_prompt.gin" \ 25 | --gin_file="prompt_tuning/configs/runs/prompt_eval.gin" \ 26 | --gin.MIXTURE_OR_TASK_NAME="'taskless_glue_sst2_v200_examples'" \ 27 | --gin.MIXTURE_OR_TASK_MODULE="'prompt_tuning.data.glue'" \ 28 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" \ 29 | --gin.CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" \ 30 | --gin.EVAL_OUTPUT_DIR="'${EVAL_DIR}'" \ 31 | --gin.utils.DatasetConfig.split="'validation'" \ 32 | --gin.utils.DatasetConfig.batch_size="128" \ 33 | --gin.USE_CACHED_TASKS="False" \ 34 | --gin.PROMPT_FILE="'${PROMPT_FILE}'" \ 35 | --tfds_data_dir=${TFDS_DATA_DIR} 36 | -------------------------------------------------------------------------------- /prompt_tuning/scripts/sst2-demo-xxl.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | MODEL_DIR=${1:-${MODEL_DIR}} 4 | TFDS_DATA_DIR=${2:-${TFDS_DATA_DIR}} 5 | 6 | if [ -z ${MODEL_DIR} ] || [ -z ${TFDS_DATA_DIR} ]; then 7 | echo "usage: ./sst2-demo-xxl.sh gs://your-bucket/path/to/model_dir gs://your-bucket/path/to/tfds/cache" 8 | exit 1 9 | fi 10 | 11 | T5X_DIR="`python3 -m prompt_tuning.scripts.find_module t5x`/.." 12 | FLAXFORMER_DIR="`python3 -m prompt_tuning.scripts.find_module flaxformer`/.." 
13 | PROMPT_DIR="`python3 -m prompt_tuning.scripts.find_module prompt_tuning`/.." 14 | echo "Searching for gin configs in:" 15 | echo "- ${T5X_DIR}" 16 | echo "- ${FLAXFORMER_DIR}" 17 | echo "- ${PROMPT_DIR}" 18 | echo "=============================" 19 | PRETRAINED_MODEL="gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xxl/checkpoint_1100000" 20 | 21 | python3 -m t5x.train \ 22 | --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" \ 23 | --gin_file="prompt_tuning/configs/models/t5_1_1_xxl_prompt.gin" \ 24 | --gin_file="prompt_tuning/configs/prompts/from_class_labels.gin" \ 25 | --gin_file="prompt_tuning/configs/runs/prompt_finetune.gin" \ 26 | --gin.CLASS_LABELS="['positive', 'negative']" \ 27 | --gin.MODEL_DIR="'${MODEL_DIR}'" \ 28 | --gin.MIXTURE_OR_TASK_NAME="'taskless_glue_sst2_v200_examples'" \ 29 | --gin.MIXTURE_OR_TASK_MODULE="'prompt_tuning.data.glue'" \ 30 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" \ 31 | --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" \ 32 | --gin.TRAIN_STEPS="1_150_000" \ 33 | --gin.USE_CACHED_TASKS="False" \ 34 | --gin.BATCH_SIZE="32" \ 35 | --gin.partitioning.PjitPartitioner.model_parallel_submesh="(4, 4, 1, 2)" \ 36 | --tfds_data_dir=${TFDS_DATA_DIR} 37 | -------------------------------------------------------------------------------- /prompt_tuning/scripts/sst2-demo.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | MODEL_DIR=${1:-${MODEL_DIR}} 4 | TFDS_DATA_DIR=${2:-${TFDS_DATA_DIR}} 5 | 6 | if [ -z ${MODEL_DIR} ] || [ -z ${TFDS_DATA_DIR} ]; then 7 | echo "usage: ./sst2-demo.sh gs://your-bucket/path/to/model_dir gs://your-bucket/path/to/tfds/cache" 8 | exit 1 9 | fi 10 | 11 | T5X_DIR="`python3 -m prompt_tuning.scripts.find_module t5x`/.." 12 | FLAXFORMER_DIR="`python3 -m prompt_tuning.scripts.find_module flaxformer`/.." 13 | PROMPT_DIR="`python3 -m prompt_tuning.scripts.find_module prompt_tuning`/.." 14 | echo "Searching for gin configs in:" 15 | echo "- ${T5X_DIR}" 16 | echo "- ${FLAXFORMER_DIR}" 17 | echo "- ${PROMPT_DIR}" 18 | echo "=============================" 19 | PRETRAINED_MODEL="gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_base/checkpoint_1100000" 20 | 21 | python3 -m t5x.train \ 22 | --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" \ 23 | --gin_file="prompt_tuning/configs/models/t5_1_1_base_prompt.gin" \ 24 | --gin_file="prompt_tuning/configs/prompts/from_class_labels.gin" \ 25 | --gin_file="prompt_tuning/configs/runs/prompt_finetune.gin" \ 26 | --gin.CLASS_LABELS="['positive', 'negative']" \ 27 | --gin.MODEL_DIR="'${MODEL_DIR}'" \ 28 | --gin.MIXTURE_OR_TASK_NAME="'taskless_glue_sst2_v200_examples'" \ 29 | --gin.MIXTURE_OR_TASK_MODULE="'prompt_tuning.data.glue'" \ 30 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" \ 31 | --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" \ 32 | --gin.TRAIN_STEPS="1_150_000" \ 33 | --gin.USE_CACHED_TASKS="False" \ 34 | --tfds_data_dir=${TFDS_DATA_DIR} 35 | -------------------------------------------------------------------------------- /prompt_tuning/spot/README.md: -------------------------------------------------------------------------------- 1 | # SPoT: Soft Prompt Transfer 2 | 3 | Note: This is a work in progress. 4 | 5 | Data and code for our paper 6 | [SPoT: Better Frozen Model Adaptation through Soft Prompt Transfer](https://aclanthology.org/2022.acl-long.346) 7 | , published at [ACL 2022](https://www.2022.aclweb.org/). 
8 | 9 | # Using SPoT 10 | 11 | At its core, SPoT is essentially transfer learning for prompts. A prompt learned 12 | on one or more source tasks is used as the initialization point for training on 13 | a new target task. 14 | 15 | 1. First we need a pre-trained prompt that will be used for initialization. See 16 | the 17 | [Training A Prompt Section](https://github.com/google-research/prompt-tuning/tree/main/prompt_tuning/README.md#training-a-prompt) 18 | of the main README on how to train a prompt using the Prompt Tuning code 19 | base. This prompt will be used for initialization and should generally be 20 | trained on some large pre-training mixture. 21 | 2. Second, we need the path to that prompt we just trained. Wherever T5X was 22 | configured to save model checkpoints (controlled by the `--model_dir` 23 | flag), there will be a directory called `numpy_checkpoints`. In it there are 24 | directories for each saved step (`checkpoint_${step}`) and within that is 25 | the prompt, saved as a numpy file. This file will have a name like 26 | `encoder.prompt.prompt.prompt` (for Encoder-Decoder models) which is the 27 | path to the parameter through the model PyTree, using `.` for scoping. We 28 | will need this file. So, to recap, the trained prompt will live at: 29 | 30 | ```shell 31 | ${MODEL_DIR}/numpy_checkpoints/checkpoint_${step}/encoder.prompt.prompt.prompt 32 | ``` 33 | 34 | 3. Finally, we train a new prompt on our target task, initializing the prompt 35 | from the file above. To do this, we train a prompt as above, this time on 36 | the actual downstream task you care about, and use the gin config 37 | [prompts/from_file.gin](https://github.com/google-research/prompt-tuning/tree/main/prompt_tuning/configs/prompts/from_file.gin) 38 | to automatically set this up. Replace any other `--gin_file=prompts/*.gin` 39 | argument with `--gin_file=prompts/from_file.gin` (or place it after the 40 | `--gin_file=models/*.gin` argument if no prompt config was used). The 41 | inclusion of this gin file will require an additional gin CLI override 42 | `--gin.PROMPT_FILE`. This override points to a numpy file that will be read 43 | and used as the initial prompt value, i.e. 44 | `--gin.PROMPT_FILE=${MODEL_DIR}/numpy_checkpoints/checkpoint_${step}/encoder.prompt.prompt.prompt`. 45 | 46 | # How to Cite 47 | 48 | If you make use of this code or idea, please cite: 49 | 50 | ```bibtex 51 | @inproceedings{vu-etal-2022-spot, 52 | title = "{SP}o{T}: Better Frozen Model Adaptation through Soft Prompt Transfer", 53 | author = "Vu, Tu and 54 | Lester, Brian and 55 | Constant, Noah and 56 | Al-Rfou{'}, Rami and 57 | Cer, Daniel", 58 | booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", 59 | month = may, 60 | year = "2022", 61 | address = "Dublin, Ireland", 62 | publisher = "Association for Computational Linguistics", 63 | url = "https://aclanthology.org/2022.acl-long.346", 64 | doi = "10.18653/v1/2022.acl-long.346", 65 | pages = "5039--5059", 66 | } 67 | ``` 68 | 69 | ## Note 70 | 71 | This is not an officially supported Google product. 72 | -------------------------------------------------------------------------------- /prompt_tuning/spot/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Soft Prompt Transfer.""" 16 | __version__ = "0.1.0" 17 | -------------------------------------------------------------------------------- /prompt_tuning/spot/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Soft Prompt Transfer.""" 16 | __version__ = "0.1.0" 17 | -------------------------------------------------------------------------------- /prompt_tuning/spot/data/glue.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """GLUE datasets. 16 | 17 | Register GLUE datasets using their latest TFDS versions. 
18 | """ 19 | 20 | import seqio 21 | from t5.data import tasks as t5_tasks 22 | from t5.data.glue_utils import get_glue_metric 23 | from t5.data.glue_utils import get_glue_postprocess_fn 24 | from t5.data.glue_utils import get_glue_text_preprocessor 25 | import tensorflow_datasets as tfds 26 | 27 | TaskRegistry = seqio.TaskRegistry 28 | MixtureRegistry = seqio.MixtureRegistry 29 | 30 | for b in tfds.text.glue.Glue.builder_configs.values(): 31 | if b.name != 'ax' and not b.name.startswith('mnli'): 32 | TaskRegistry.add( 33 | f'spot_glue_{b.name}_v200', 34 | source=seqio.TfdsDataSource(tfds_name=f'glue/{b.name}:2.0.0'), 35 | preprocessors=[ 36 | get_glue_text_preprocessor(b), 37 | seqio.preprocessors.tokenize, 38 | seqio.CacheDatasetPlaceholder(), 39 | seqio.preprocessors.append_eos_after_trim, 40 | ], 41 | metric_fns=get_glue_metric(b.name), 42 | output_features=t5_tasks.DEFAULT_OUTPUT_FEATURES, 43 | postprocess_fn=get_glue_postprocess_fn(b)) 44 | elif b.name == 'mnli': 45 | # We will register MNLI's validation sets later 46 | TaskRegistry.add( 47 | f'spot_glue_{b.name}_train_v200', 48 | source=seqio.TfdsDataSource( 49 | tfds_name=f'glue/{b.name}:2.0.0', splits=['train']), 50 | preprocessors=[ 51 | get_glue_text_preprocessor(b), 52 | seqio.preprocessors.tokenize, 53 | seqio.CacheDatasetPlaceholder(), 54 | seqio.preprocessors.append_eos_after_trim, 55 | ], 56 | metric_fns=get_glue_metric(b.name), 57 | output_features=t5_tasks.DEFAULT_OUTPUT_FEATURES, 58 | postprocess_fn=get_glue_postprocess_fn(b)) 59 | 60 | # Register MNLI's validation sets so we can see the results on Tensorboard 61 | b = tfds.text.glue.Glue.builder_configs['mnli'] 62 | for split in ['validation_matched', 'validation_mismatched']: 63 | TaskRegistry.add( 64 | f'spot_glue_mnli_{split}_v200', 65 | source=seqio.TfdsDataSource( 66 | tfds_name='glue/mnli:2.0.0', splits={'validation': split}), 67 | preprocessors=[ 68 | get_glue_text_preprocessor(b), 69 | seqio.preprocessors.tokenize, 70 | seqio.CacheDatasetPlaceholder(), 71 | seqio.preprocessors.append_eos_after_trim, 72 | ], 73 | metric_fns=get_glue_metric(b.name), 74 | output_features=t5_tasks.DEFAULT_OUTPUT_FEATURES, 75 | postprocess_fn=get_glue_postprocess_fn(b)) 76 | 77 | # Create MNLI mixture 78 | MixtureRegistry.add( 79 | 'spot_glue_mnli_and_dev_v200', ([ 80 | 'spot_glue_mnli_train_v200', 81 | 'spot_glue_mnli_validation_matched_v200', 82 | 'spot_glue_mnli_validation_mismatched_v200', 83 | ]), 84 | default_rate=1.0) 85 | -------------------------------------------------------------------------------- /prompt_tuning/spot/data/mrqa.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """MRQA datasets. 16 | 17 | Note that we do not include held-out out-of-domain datasets here. 
18 | """ 19 | 20 | import functools 21 | 22 | from prompt_tuning.data import preprocessors as pt_preprocessors 23 | import seqio 24 | from t5.data import postprocessors as t5_postprocessors 25 | from t5.data import tasks as t5_tasks 26 | from t5.evaluation import metrics as t5_metrics 27 | 28 | TaskRegistry = seqio.TaskRegistry 29 | MixtureRegistry = seqio.MixtureRegistry 30 | 31 | DATASETS = { 32 | 'squad': { 33 | 'tfds_name': 'mrqa/squad:1.0.0', 34 | }, 35 | 'news_qa': { 36 | 'tfds_name': 'mrqa/news_qa:1.0.0', 37 | }, 38 | 'trivia_qa': { 39 | 'tfds_name': 'mrqa/trivia_qa:1.0.0', 40 | }, 41 | 'search_qa': { 42 | 'tfds_name': 'mrqa/search_qa:1.0.0', 43 | }, 44 | 'hotpot_qa': { 45 | 'tfds_name': 'mrqa/hotpot_qa:1.0.0', 46 | }, 47 | 'natural_questions': { 48 | 'tfds_name': 'mrqa/natural_questions:1.0.0', 49 | }, 50 | } 51 | 52 | # Register datasets 53 | for dataset in DATASETS: 54 | version = f"v{DATASETS[dataset]['tfds_name'].split(':')[-1].replace('.', '')}" 55 | TaskRegistry.add( 56 | f'spot_mrqa_{dataset.lower()}_{version}', 57 | source=seqio.TfdsDataSource(tfds_name=DATASETS[dataset]['tfds_name']), 58 | preprocessors=[ 59 | functools.partial( 60 | pt_preprocessors.mrqa, 61 | task_name=dataset.lower(), 62 | ), 63 | seqio.preprocessors.tokenize, 64 | seqio.CacheDatasetPlaceholder(), 65 | seqio.preprocessors.append_eos_after_trim, 66 | ], 67 | postprocess_fn=t5_postprocessors.qa, 68 | metric_fns=[t5_metrics.squad], 69 | output_features=t5_tasks.DEFAULT_OUTPUT_FEATURES) 70 | -------------------------------------------------------------------------------- /prompt_tuning/spot/data/nli.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Summarization datasets.""" 16 | 17 | import functools 18 | 19 | from prompt_tuning.spot.data import preprocessors as spot_preprocessors 20 | import seqio 21 | from t5.data import tasks as t5_tasks 22 | from t5.evaluation import metrics as t5_metrics 23 | 24 | TaskRegistry = seqio.TaskRegistry 25 | MixtureRegistry = seqio.MixtureRegistry 26 | 27 | DATASETS = { 28 | 'anli_r1': { 29 | 'tfds_name': 'anli/r1:0.1.0', 30 | 'text_a_key': 'hypothesis', 31 | 'text_b_key': 'context', 32 | 'label_names': ['entailment', 'neutral', 'contradiction'], 33 | }, 34 | 'anli_r2': { 35 | 'tfds_name': 'anli/r2:0.1.0', 36 | 'text_a_key': 'hypothesis', 37 | 'text_b_key': 'context', 38 | 'label_names': ['entailment', 'neutral', 'contradiction'], 39 | }, 40 | 'anli_r3': { 41 | 'tfds_name': 'anli/r3:0.1.0', 42 | 'text_a_key': 'hypothesis', 43 | 'text_b_key': 'context', 44 | 'label_names': ['entailment', 'neutral', 'contradiction'], 45 | }, 46 | 'doc_nli': { 47 | 'tfds_name': 'doc_nli:1.0.0', 48 | 'text_a_key': 'hypothesis', 49 | 'text_b_key': 'premise', 50 | 'label_names': ['not_entailment', 'entailment'], 51 | }, 52 | 'snli': { 53 | 'tfds_name': 'snli:1.1.0', 54 | 'text_a_key': 'hypothesis', 55 | 'text_b_key': 'premise', 56 | 'label_names': ['entailment', 'neutral', 'contradiction'], 57 | }, 58 | } 59 | 60 | # Register datasets 61 | for dataset in DATASETS: 62 | version = f"v{DATASETS[dataset]['tfds_name'].split(':')[-1].replace('.', '')}" 63 | TaskRegistry.add( 64 | f'spot_{dataset.lower()}_{version}', 65 | source=seqio.TfdsDataSource(tfds_name=DATASETS[dataset]['tfds_name']), 66 | preprocessors=[ 67 | functools.partial( 68 | spot_preprocessors.preprocess_text_classification, 69 | text_a_key=DATASETS[dataset]['text_a_key'], 70 | text_b_key=DATASETS[dataset]['text_b_key'], 71 | task_name=dataset, 72 | label_names=DATASETS[dataset]['label_names']), 73 | seqio.preprocessors.tokenize, 74 | seqio.CacheDatasetPlaceholder(), 75 | seqio.preprocessors.append_eos_after_trim, 76 | ], 77 | metric_fns=[t5_metrics.accuracy], 78 | output_features=t5_tasks.DEFAULT_OUTPUT_FEATURES) 79 | -------------------------------------------------------------------------------- /prompt_tuning/spot/data/preprocessors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Preprocessors for SPoT tasks.""" 16 | 17 | import seqio 18 | import t5.data.preprocessors 19 | import tensorflow.compat.v2 as tf 20 | 21 | # pylint:disable=no-value-for-parameter, protected-access 22 | AUTOTUNE = tf.data.experimental.AUTOTUNE 23 | 24 | _string_join = t5.data.preprocessors._string_join 25 | _pad_punctuation = t5.data.preprocessors._pad_punctuation 26 | 27 | 28 | @seqio.map_over_dataset 29 | def preprocess_text_classification(example, 30 | text_a_key, 31 | text_b_key=None, 32 | task_name=None, 33 | label_names=None): 34 | """Convert a text classification dataset to a text-to-text format. 
35 | 36 | Each {<text_a>, <text_b>} example will have the format: 37 | {'inputs': <task_name> <text_a_key>: <text_a> [<text_b_key>: <text_b>], 38 | 'targets': <label_name>} 39 | 40 | Args: 41 | example: An example to process. 42 | text_a_key: The key for the (first) text. 43 | text_b_key: The key for the second text (if any). 44 | task_name: The name of the task. 45 | label_names: A list of label names corresponding to class indices. 46 | 47 | Returns: 48 | A preprocessed example with the format listed above. 49 | """ 50 | 51 | text_a = example[text_a_key] 52 | text_b = example[text_b_key] if text_b_key is not None else None 53 | 54 | strs_to_join = [f'{text_a_key}:', text_a] 55 | if task_name is not None: 56 | strs_to_join = [task_name] + strs_to_join 57 | 58 | if text_b_key is not None and text_b is not None: 59 | strs_to_join.extend([f'{text_b_key}:', text_b]) 60 | 61 | if label_names is not None: 62 | label = example['label'] 63 | if label.dtype == tf.string: 64 | label = tf.strings.to_number(label, tf.int64) 65 | label_name = tf.cond( 66 | # When no label is provided (label == -1), use "" 67 | tf.equal(label, -1), 68 | lambda: tf.constant(''), 69 | # Otherwise grab the label text from label_names 70 | lambda: tf.gather(label_names, label), 71 | ) 72 | else: 73 | label_name = tf.as_string(example['label']) 74 | 75 | return { 76 | 'inputs': _string_join(strs_to_join), 77 | 'targets': label_name 78 | } 79 | -------------------------------------------------------------------------------- /prompt_tuning/spot/data/summarization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Summarization datasets.""" 16 | 17 | import functools 18 | 19 | from prompt_tuning.data import preprocessors as pt_preprocessors 20 | import seqio 21 | from t5.data import tasks as t5_tasks 22 | from t5.evaluation import metrics as t5_metrics 23 | 24 | TaskRegistry = seqio.TaskRegistry 25 | MixtureRegistry = seqio.MixtureRegistry 26 | 27 | DATASETS = { 28 | 'aeslc': { 29 | 'tfds_name': 'aeslc:1.0.0', 30 | 'source_key': 'email_body', 31 | 'target_key': 'subject_line', 32 | }, 33 | 'billsum': { 34 | 'tfds_name': 'billsum:3.0.0', 35 | 'source_key': 'text', 36 | 'target_key': 'summary', 37 | }, 38 | 'gigaword': { 39 | 'tfds_name': 'gigaword:1.2.0', 40 | 'source_key': 'document', 41 | 'target_key': 'summary', 42 | }, 43 | 'cnn_dailymail': { 44 | 'tfds_name': 'cnn_dailymail:3.2.0', 45 | 'source_key': 'article', 46 | 'target_key': 'highlights', 47 | }, 48 | 'multi_news': { 49 | 'tfds_name': 'multi_news:1.0.0', 50 | 'source_key': 'document', 51 | 'target_key': 'summary', 52 | }, 53 | 'samsum': { 54 | 'tfds_name': 'samsum:1.0.0', 55 | 'source_key': 'dialogue', 56 | 'target_key': 'summary', 57 | }, 58 | 'newsroom': { 59 | 'tfds_name': 'newsroom:1.0.0', 60 | 'source_key': 'text', 61 | 'target_key': 'summary', 62 | }, 63 | } 64 | 65 | # Register datasets 66 | for dataset in DATASETS: 67 | version = f"v{DATASETS[dataset]['tfds_name'].split(':')[-1].replace('.', '')}" 68 | TaskRegistry.add( 69 | f'spot_{dataset.lower()}_{version}', 70 | source=seqio.TfdsDataSource(tfds_name=DATASETS[dataset]['tfds_name']), 71 | preprocessors=[ 72 | functools.partial( 73 | pt_preprocessors.preprocess_text_generation, 74 | source_key=DATASETS[dataset]['source_key'], 75 | target_key=DATASETS[dataset]['target_key'], 76 | task_name=dataset.lower(), 77 | prefix='summarize:', 78 | ), 79 | seqio.preprocessors.tokenize, 80 | seqio.CacheDatasetPlaceholder(), 81 | seqio.preprocessors.append_eos_after_trim, 82 | ], 83 | metric_fns=[t5_metrics.rouge], 84 | output_features=t5_tasks.DEFAULT_OUTPUT_FEATURES) 85 | -------------------------------------------------------------------------------- /prompt_tuning/spot/data/tasks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Import tasks.""" 16 | 17 | # pylint: disable=unused-import,g-import-not-at-top 18 | from prompt_tuning.spot.data import glue 19 | from prompt_tuning.spot.data import mrqa 20 | from prompt_tuning.spot.data import nli 21 | from prompt_tuning.spot.data import summarization 22 | from t5.data import tasks as t5_tasks 23 | -------------------------------------------------------------------------------- /prompt_tuning/test_data/prompt_5x256.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/prompt-tuning/72285eea06100f954bcbb16a447ec6cddfc6716c/prompt_tuning/test_data/prompt_5x256.npy -------------------------------------------------------------------------------- /prompt_tuning/test_data/t5_vocab: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/prompt-tuning/72285eea06100f954bcbb16a447ec6cddfc6716c/prompt_tuning/test_data/t5_vocab -------------------------------------------------------------------------------- /prompt_tuning/test_data/test_t5_1_1_tiny/checkpoint_3/checkpoint: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/prompt-tuning/72285eea06100f954bcbb16a447ec6cddfc6716c/prompt_tuning/test_data/test_t5_1_1_tiny/checkpoint_3/checkpoint -------------------------------------------------------------------------------- /prompt_tuning/test_data/test_t5_1_1_tiny/checkpoint_3/target.decoder.layers_0.encoder_decoder_attention.key.kernel/.zarray: -------------------------------------------------------------------------------- 1 | {"chunks":[4,4],"compressor":{"id":"gzip","level":1},"dtype":" partitioning.LogicalAxisRules: 21 | """Add prompt specific partitioning rules.""" 22 | return (("prompt", None), ("prompt_embed", None)) 23 | -------------------------------------------------------------------------------- /prompt_tuning/x_gen/README.md: -------------------------------------------------------------------------------- 1 | # X-Gen: Zero-Shot Cross-Lingual Generation 2 | 3 | Note: This is a work in progress. 4 | 5 | Data and code for our paper [Overcoming Catastrophic Forgetting in Zero-Shot Cross-Lingual Generation](https://arxiv.org/abs/2205.12647). 6 | 7 | # How to Cite 8 | 9 | If you leverage the ideas or code from our work please cite: 10 | 11 | ```bibtex 12 | @article{Vu2022OvercomingCF, 13 | title={Overcoming Catastrophic Forgetting in Zero-Shot Cross-Lingual Generation}, 14 | author={Tu Vu and Aditya Barua and Brian Lester and Daniel Matthew Cer and Mohit Iyyer and Noah Constant}, 15 | journal={ArXiv}, 16 | year={2022}, 17 | volume={abs/2205.12647} 18 | } 19 | ``` 20 | 21 | ## Note 22 | 23 | This is not an officially supported Google product. 24 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | 4 | [tool.pytest.ini_options] 5 | minversion = "6.0" 6 | python_files = "*_test.py" 7 | log_level = "INFO" 8 | # Skip train_test as it requires TPUs to execute. 9 | # Skip tasks_test as they require TFDS pre-setup. 10 | # Skip utils_test as it requires temporary files that doesn't work well on Cloud. 
11 | addopts = "--ignore=prompt_tuning/train/train_test.py --ignore=prompt_tuning/data/tasks_test.py --ignore=prompt_tuning/train/utils_test.py" 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Install prompt-tuning.""" 16 | 17 | import ast 18 | import setuptools 19 | 20 | 21 | def get_version(file_name: str, version_name: str = "__version__") -> str: 22 | """Find version by AST parsing to avoid needing to import this package.""" 23 | with open(file_name) as f: 24 | tree = ast.parse(f.read()) 25 | # Look for all assignment nodes in the AST, if the variable name is what 26 | # we assigned the version number too, grab the value (the version). 27 | for node in ast.walk(tree): 28 | if isinstance(node, ast.Assign): 29 | if node.targets[0].id == version_name: 30 | return node.value.s 31 | raise ValueError(f"Couldn't find assignment to variable {version_name} " 32 | f"in file {file_name}") 33 | 34 | with open("README.md") as fp: 35 | LONG_DESCRIPTION = fp.read() 36 | 37 | _jax_version = "0.3.1" 38 | 39 | setuptools.setup( 40 | name="prompt-tuning", 41 | version=get_version("prompt_tuning/__init__.py"), 42 | description="Prompt Tuning from Lester et al., 2021", 43 | long_description=LONG_DESCRIPTION, 44 | long_description_content_type="text/markdown", 45 | author="Google Inc.", 46 | author_email="no-reply@google.com", 47 | url="http://github.com/google-research/prompt-tuning", 48 | license="Apache 2.0", 49 | packages=setuptools.find_packages(), 50 | include_package_data=True, 51 | package_data={ 52 | "": ["**/*.gin", "**/*.json"], 53 | }, 54 | scripts=[], 55 | install_requires=[ 56 | "absl-py", 57 | "flax @ git+https://github.com/google/flax#egg=flax", 58 | "gin-config", 59 | f"jax>={_jax_version}", 60 | "numpy", 61 | "seqio-nightly", 62 | "t5", 63 | "tensorflow", 64 | "tensorflow_datasets", 65 | # Install from git as they have setup.pys but are not on PyPI. 66 | "t5x @ git+https://github.com/google-research/t5x@main#egg=t5x", 67 | "flaxformer @ git+https://github.com/google/flaxformer@main#egg=flaxformer", 68 | ], 69 | extras_require={ 70 | "test": ["pytest>=6.0"], 71 | # TODO: mt5 and byt5 are not setup as python packages. 72 | # Figure out best way to bring them in as dependencies. 
73 | "mt5": [], 74 | "byt5": [], 75 | "mrqa": ["pandas"], 76 | "tpu": [f"jax[tpu]>={_jax_version}"] 77 | }, 78 | classifiers=[ 79 | "Development Status :: 4 - Beta", 80 | "Intended Audience :: Developers", 81 | "Intended Audience :: Science/Research", 82 | "License :: OSI Approved :: Apache Software License", 83 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 84 | ], 85 | keywords=[ 86 | "prompt tuning", 87 | "machine learning", 88 | "transformers", 89 | "neural networks", 90 | "pre-trained language models", 91 | "nlp", 92 | "jax", 93 | "flax", 94 | "t5", 95 | "t5x", 96 | ] 97 | ) 98 | --------------------------------------------------------------------------------