├── t5x_retrieval ├── testdata │ └── test_de_t5_tiny.checkpoint_1000 │ │ ├── checkpoint │ │ ├── target.encoder.layers_0.mlp.wi.kernel │ │ ├── .zarray │ │ └── 0.0 │ │ ├── target.encoder.layers_0.mlp.wo.kernel │ │ ├── .zarray │ │ └── 0.0 │ │ ├── target.encoder.layers_1.mlp.wi.kernel │ │ ├── .zarray │ │ └── 0.0 │ │ ├── target.encoder.layers_1.mlp.wo.kernel │ │ ├── .zarray │ │ └── 0.0 │ │ ├── target.token_embedder.embedding │ │ ├── .zarray │ │ ├── 0.0 │ │ ├── 1.0 │ │ ├── 2.0 │ │ ├── 3.0 │ │ ├── 4.0 │ │ ├── 5.0 │ │ ├── 6.0 │ │ └── 7.0 │ │ ├── target.encoder.layers_0.attention.key.kernel │ │ ├── .zarray │ │ └── 0.0 │ │ ├── target.encoder.layers_0.attention.out.kernel │ │ ├── .zarray │ │ └── 0.0 │ │ ├── target.encoder.layers_0.attention.query.kernel │ │ ├── .zarray │ │ └── 0.0 │ │ ├── target.encoder.layers_0.attention.value.kernel │ │ ├── .zarray │ │ └── 0.0 │ │ ├── target.encoder.layers_1.attention.key.kernel │ │ ├── .zarray │ │ └── 0.0 │ │ ├── target.encoder.layers_1.attention.out.kernel │ │ ├── .zarray │ │ └── 0.0 │ │ ├── target.encoder.layers_1.attention.query.kernel │ │ ├── .zarray │ │ └── 0.0 │ │ └── target.encoder.layers_1.attention.value.kernel │ │ ├── .zarray │ │ └── 0.0 ├── configs │ ├── runs │ │ ├── trainer.gin │ │ ├── infer_eval.gin │ │ ├── pretrain.gin │ │ ├── eval.gin │ │ ├── infer.gin │ │ └── finetune.gin │ ├── models │ │ ├── de_t5_3B.gin │ │ ├── de_t5_11B.gin │ │ ├── de_t5_large.gin │ │ ├── de_mt5_xl.gin │ │ ├── de_mt5_large.gin │ │ ├── de_mt5_xxl.gin │ │ ├── de_t5_1_1_xl.gin │ │ ├── de_t5_1_1_xxl.gin │ │ ├── de_t5_1_1_large.gin │ │ ├── de_longt5_1_1_transient_global_large.gin │ │ ├── de_t5_tiny.gin │ │ ├── de_mt5_small.gin │ │ ├── de_mt5_tiny.gin │ │ ├── de_t5_1_1_tiny.gin │ │ ├── de_t5_base.gin │ │ ├── de_t5_1_1_base.gin │ │ ├── de_mt5_base.gin │ │ └── de_longt5_1_1_transient_global_base.gin │ └── architectures │ │ ├── de_t5_1_1_flaxformer.gin │ │ ├── de_t5_flaxformer.gin │ │ └── de_longt5_1_1_transient_global_flaxformer.gin ├── version.py ├── 
partitioning.py ├── __init__.py ├── adafactor_utils.py ├── metrics.py ├── postprocessors.py ├── preprocessors.py ├── tasks.py ├── losses.py ├── feature_converters.py ├── utils.py └── models.py ├── CONTRIBUTING.md ├── LICENSE └── README.md /t5x_retrieval/testdata/test_de_t5_tiny.checkpoint_1000/checkpoint: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/t5x_retrieval/HEAD/t5x_retrieval/testdata/test_de_t5_tiny.checkpoint_1000/checkpoint -------------------------------------------------------------------------------- /t5x_retrieval/configs/runs/trainer.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | from t5x import trainer 4 | 5 | # Parameters for trainer.Trainer: 6 | trainer.Trainer.weight_metrics_computer = @trainer.WeightMetricsComputer() 7 | -------------------------------------------------------------------------------- /t5x_retrieval/testdata/test_de_t5_tiny.checkpoint_1000/target.encoder.layers_0.mlp.wi.kernel/.zarray: -------------------------------------------------------------------------------- 1 | {"chunks":[4,8],"compressor":{"id":"gzip","level":1},"dtype":" to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code Reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 
30 | -------------------------------------------------------------------------------- /t5x_retrieval/configs/architectures/de_t5_1_1_flaxformer.gin: -------------------------------------------------------------------------------- 1 | # Flaxformer implementation of Dual Encoder, based on T5.1.1 architecture. 2 | # 3 | # Required to be overridden: 4 | # 5 | # - NUM_HEADS 6 | # - NUM_LAYERS 7 | # - HEAD_DIM 8 | # - EMBED_DIM 9 | # - MLP_DIM 10 | # - PROJECTION_DIM 11 | from __gin__ import dynamic_registration 12 | 13 | from flax import linen 14 | 15 | from flaxformer.components import dense 16 | from flaxformer.components import layer_norm 17 | 18 | include 't5x_retrieval/configs/architectures/de_t5_flaxformer.gin' 19 | 20 | # Additional constants (may be overridden) 21 | SCALE = 1.0 22 | 23 | # Projection layer 24 | projection_layer/linen.initializers.variance_scaling: 25 | scale = %SCALE 26 | 27 | # Attention (encoder, decoder, self-attention) 28 | attention_kernel_init/linen.initializers.variance_scaling: 29 | scale = %SCALE 30 | 31 | # Relative position biases (encoder, decoder) 32 | relative_position_bias_init/linen.initializers.variance_scaling: 33 | scale = %SCALE 34 | 35 | # MLP (encoder) 36 | dense.MlpBlock: 37 | activations = ('gelu', 'linear') 38 | mlp_kernel_init/linen.initializers.variance_scaling: 39 | scale = %SCALE 40 | 41 | layer_norm.T5LayerNorm.dtype = %ACTIVATION_DTYPE 42 | -------------------------------------------------------------------------------- /t5x_retrieval/configs/runs/infer_eval.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | import __main__ as train_script 4 | import seqio 5 | from t5x import utils 6 | 7 | # Convenience overrides. 8 | EVALUATOR_USE_MEMORY_CACHE = False 9 | EVALUATOR_NUM_EXAMPLES = None # None means use all examples in the infer_eval dataset. 10 | JSON_WRITE_N_RESULTS = None # None means write all inferences. 
11 | 12 | train_script.train: 13 | infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig() 14 | inference_evaluator_cls = @seqio.Evaluator 15 | 16 | EVAL_MIXTURE_OR_TASK_NAME = %MIXTURE_OR_TASK_NAME 17 | EVAL_TASK_FEATURE_LENGTHS = %TASK_FEATURE_LENGTHS 18 | EVAL_BATCH_SIZE = %BATCH_SIZE 19 | 20 | infer_eval/utils.DatasetConfig: 21 | mixture_or_task_name = %EVAL_MIXTURE_OR_TASK_NAME 22 | task_feature_lengths = %EVAL_TASK_FEATURE_LENGTHS 23 | split = 'validation' 24 | batch_size = %EVAL_BATCH_SIZE 25 | shuffle = False 26 | seed = 42 27 | use_cached = %USE_CACHED_TASKS 28 | pack = False 29 | module = %MIXTURE_OR_TASK_MODULE 30 | 31 | seqio.Evaluator: 32 | logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] 33 | num_examples = %EVALUATOR_NUM_EXAMPLES 34 | use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE 35 | 36 | seqio.JSONLogger: 37 | write_n_results = %JSON_WRITE_N_RESULTS 38 | -------------------------------------------------------------------------------- /t5x_retrieval/configs/models/de_t5_base.gin: -------------------------------------------------------------------------------- 1 | # Dual Encoder based on original T5 (1.0) Base model. 
2 | # Provides MODEL 3 | from __gin__ import dynamic_registration 4 | 5 | import seqio 6 | from t5x import adafactor 7 | from t5x_retrieval import feature_converters 8 | from t5x_retrieval import models 9 | from t5x_retrieval import losses 10 | 11 | ARCHITECTURE = %gin.REQUIRED 12 | LOSS_MODULE = @losses.InBatchCrossEntropyLoss 13 | 14 | include 't5x_retrieval/configs/architectures/de_t5_flaxformer.gin' 15 | 16 | # Architecture overrides 17 | NUM_HEADS = 12 18 | NUM_LAYERS = 12 19 | HEAD_DIM = 64 20 | EMBED_DIM = 768 21 | MLP_DIM = 3072 22 | PROJECTION_DIM = 768 23 | 24 | # Vocabulary (shared by encoder and decoder) 25 | VOCABULARY = @seqio.SentencePieceVocabulary() 26 | seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" 27 | 28 | # Optimizer 29 | # `learning_rate` is set by `Trainer.learning_rate_fn`. 30 | OPTIMIZER = @adafactor.Adafactor() 31 | adafactor.Adafactor: 32 | decay_rate = 0.8 33 | step_offset = 0 34 | 35 | # Model 36 | MODEL = @models.DualEncoderModel() 37 | models.DualEncoderModel: 38 | use_negatives = False 39 | use_align_uniform = False 40 | feature_converter_cls = @feature_converters.DualEncoderFeatureConverterFactory() 41 | module = %ARCHITECTURE # provided by t5_flaxformer 42 | loss_module_factory = %LOSS_MODULE 43 | input_vocabulary = %VOCABULARY 44 | output_vocabulary = %VOCABULARY 45 | optimizer_def = %OPTIMIZER 46 | -------------------------------------------------------------------------------- /t5x_retrieval/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Retrieval Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2022 Google LLC 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | r"""Separate file for storing the current version of T5X Retrieval. 29 | 30 | Stored in a separate file so that setup.py can reference the version without 31 | pulling in all the dependencies in __init__.py. 32 | """ 33 | __version__ = '0.0.0' 34 | -------------------------------------------------------------------------------- /t5x_retrieval/configs/models/de_t5_1_1_base.gin: -------------------------------------------------------------------------------- 1 | # Dual Encoder based on T5 (1.1) Base model. 
2 | # Provides MODEL 3 | from __gin__ import dynamic_registration 4 | 5 | import seqio 6 | from t5x import adafactor 7 | from t5x_retrieval import feature_converters 8 | from t5x_retrieval import models 9 | from t5x_retrieval import losses 10 | 11 | ARCHITECTURE = %gin.REQUIRED 12 | T5XR_INFERENCE_MODE = 'encode' 13 | LOSS_MODULE = @losses.InBatchCrossEntropyLoss 14 | 15 | include 't5x_retrieval/configs/architectures/de_t5_1_1_flaxformer.gin' 16 | 17 | # Architecture overrides 18 | NUM_HEADS = 12 19 | NUM_LAYERS = 12 20 | HEAD_DIM = 64 21 | EMBED_DIM = 768 22 | MLP_DIM = 2048 23 | PROJECTION_DIM = 768 24 | 25 | # Vocabulary (shared by encoder and decoder) 26 | VOCABULARY = @seqio.SentencePieceVocabulary() 27 | seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" 28 | 29 | # Optimizer 30 | # `learning_rate` is set by `Trainer.learning_rate_fn`. 31 | OPTIMIZER = @adafactor.Adafactor() 32 | adafactor.Adafactor: 33 | decay_rate = 0.8 34 | step_offset = 0 35 | 36 | # Model 37 | MODEL = @models.DualEncoderModel() 38 | models.DualEncoderModel: 39 | use_negatives = False 40 | use_align_uniform = False 41 | inference_mode = %T5XR_INFERENCE_MODE 42 | feature_converter_cls = @feature_converters.DualEncoderFeatureConverterFactory() 43 | module = %ARCHITECTURE # provided by t5_flaxformer 44 | loss_module_factory = %LOSS_MODULE 45 | input_vocabulary = %VOCABULARY 46 | output_vocabulary = %VOCABULARY 47 | optimizer_def = %OPTIMIZER 48 | -------------------------------------------------------------------------------- /t5x_retrieval/partitioning.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Retrieval Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
def standard_logical_axis_rules() -> partitioning.LogicalAxisRules:
    """Builds the extra logical-axis rules used by t5x retrieval models.

    Returns:
      A one-element tuple of `(logical_axis_name, mesh_axis)` pairs that maps
      the 'affinity' logical axis to `None`.
    """
    # NOTE(review): a `None` mesh axis presumably leaves 'affinity' unsharded --
    # confirm against the t5x partitioning documentation.
    affinity_rule = ('affinity', None)
    return (affinity_rule,)
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2022 Google LLC 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | """Import API modules.""" 29 | 30 | import t5x_retrieval.adafactor_utils 31 | import t5x_retrieval.feature_converters 32 | import t5x_retrieval.models 33 | import t5x_retrieval.partitioning 34 | import t5x_retrieval.tasks 35 | import t5x_retrieval.utils 36 | 37 | # Version number. 38 | from t5x_retrieval.version import __version__ 39 | -------------------------------------------------------------------------------- /t5x_retrieval/configs/runs/pretrain.gin: -------------------------------------------------------------------------------- 1 | # Defaults for pretraining with train.py. 2 | # 3 | # You must also include a binding for MODEL. 
4 | # 5 | # Required to be set: 6 | # 7 | # -MIXTURE_OR_TASK_NAME 8 | # -MIXTURE_OR_TASK_MODULE 9 | # -TASK_FEATURE_LENGTHS 10 | # -train.model_dir 11 | # 12 | # Commonly overridden options: 13 | # 14 | # - DatasetConfig.batch_size 15 | # - PjitPartitioner.num_partitions 16 | # - Trainer.num_microbatches 17 | # 18 | # Currently we don't support inference eval. 19 | from __gin__ import dynamic_registration 20 | 21 | import __main__ as train_script 22 | 23 | from t5x import adafactor 24 | from t5x import partitioning 25 | from t5x import utils 26 | from t5x_retrieval import adafactor_utils 27 | from t5x_retrieval import partitioning as t5xr_partitioning 28 | 29 | include 't5x/configs/runs/pretrain.gin' 30 | 31 | train_script.train: 32 | infer_eval_dataset_cfg = None 33 | 34 | train/utils.DatasetConfig: 35 | use_cached = False 36 | pack = False 37 | 38 | train_eval/utils.DatasetConfig: 39 | use_cached = False 40 | pack = False 41 | 42 | utils.create_learning_rate_scheduler: 43 | factors = 'linear_decay' 44 | base_learning_rate = 0.001 45 | warmup_steps = 1000 46 | decay_factor = 0.0001 # 1 / %TRAIN_STEPS 47 | 48 | adafactor.Adafactor: 49 | decay_rate = 0.8 50 | step_offset = 0 51 | logical_factor_rules = @adafactor_utils.logical_factor_rules() 52 | 53 | partitioning.PjitPartitioner: 54 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 55 | 56 | partitioning.standard_logical_axis_rules: 57 | additional_rules = @t5xr_partitioning.standard_logical_axis_rules() 58 | -------------------------------------------------------------------------------- /t5x_retrieval/configs/runs/eval.gin: -------------------------------------------------------------------------------- 1 | # Defaults for eval.py. 2 | # 3 | # You must also include a binding for MODEL. 4 | # 5 | # Required to be set: 6 | # 7 | # - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to use for inference 8 | # - TASK_FEATURE_LENGTHS: The lengths per key in the SeqIO Task to trim features 9 | # to. 
10 | # - CHECKPOINT_PATH: The model checkpoint to use for inference 11 | # - INFER_OUTPUT_DIR: The dir to write results to. 12 | # 13 | # Commonly overridden options: 14 | # 15 | # - infer.mode 16 | # - infer.checkpoint_period 17 | # - infer.shard_id 18 | # - infer.num_shards 19 | # - DatasetConfig.split 20 | # - DatasetConfig.batch_size 21 | # - DatasetConfig.use_cached 22 | # - RestoreCheckpointConfig.is_tensorflow 23 | # - RestoreCheckpointConfig.mode 24 | # - PjitPartitioner.num_partitions 25 | from __gin__ import dynamic_registration 26 | 27 | import __main__ as infer_script 28 | 29 | from t5x import adafactor 30 | from t5x import partitioning 31 | from t5x import utils 32 | from t5x_retrieval import adafactor_utils 33 | from t5x_retrieval import partitioning as t5xr_partitioning 34 | 35 | include 't5x/configs/runs/eval.gin' 36 | 37 | adafactor.Adafactor: 38 | decay_rate = 0.8 39 | step_offset = 0 40 | logical_factor_rules = @adafactor_utils.logical_factor_rules() 41 | 42 | partitioning.PjitPartitioner: 43 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 44 | 45 | partitioning.standard_logical_axis_rules: 46 | additional_rules = @t5xr_partitioning.standard_logical_axis_rules() 47 | 48 | utils.RestoreCheckpointConfig: 49 | path = %CHECKPOINT_PATH 50 | mode = 'specific' 51 | dtype = 'bfloat16' 52 | -------------------------------------------------------------------------------- /t5x_retrieval/configs/runs/infer.gin: -------------------------------------------------------------------------------- 1 | # Defaults for infer.py. 2 | # 3 | # You must also include a binding for MODEL. 4 | # 5 | # Required to be set: 6 | # 7 | # - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to use for inference 8 | # - TASK_FEATURE_LENGTHS: The lengths per key in the SeqIO Task to trim features 9 | # to. 10 | # - CHECKPOINT_PATH: The model checkpoint to use for inference 11 | # - INFER_OUTPUT_DIR: The dir to write results to. 
12 | # 13 | # Commonly overridden options: 14 | # 15 | # - infer.mode 16 | # - infer.checkpoint_period 17 | # - infer.shard_id 18 | # - infer.num_shards 19 | # - DatasetConfig.split 20 | # - DatasetConfig.batch_size 21 | # - DatasetConfig.use_cached 22 | # - RestoreCheckpointConfig.is_tensorflow 23 | # - RestoreCheckpointConfig.mode 24 | # - PjitPartitioner.num_partitions 25 | from __gin__ import dynamic_registration 26 | 27 | import __main__ as infer_script 28 | 29 | from t5x import adafactor 30 | from t5x import partitioning 31 | from t5x import utils 32 | from t5x_retrieval import adafactor_utils 33 | from t5x_retrieval import partitioning as t5xr_partitioning 34 | 35 | include 't5x/configs/runs/infer.gin' 36 | 37 | adafactor.Adafactor: 38 | decay_rate = 0.8 39 | step_offset = 0 40 | logical_factor_rules = @adafactor_utils.logical_factor_rules() 41 | 42 | partitioning.PjitPartitioner: 43 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 44 | 45 | partitioning.standard_logical_axis_rules: 46 | additional_rules = @t5xr_partitioning.standard_logical_axis_rules() 47 | 48 | utils.RestoreCheckpointConfig: 49 | path = %CHECKPOINT_PATH 50 | mode = 'specific' 51 | dtype = 'bfloat16' 52 | -------------------------------------------------------------------------------- /t5x_retrieval/configs/models/de_mt5_base.gin: -------------------------------------------------------------------------------- 1 | # Dual Encoder based on mT5 Base model. 2 | # Provides MODEL 3 | from __gin__ import dynamic_registration 4 | 5 | import seqio 6 | from t5x import adafactor 7 | from t5x_retrieval import feature_converters 8 | from t5x_retrieval import models as models 9 | from t5x_retrieval import losses 10 | 11 | ARCHITECTURE = %gin.REQUIRED 12 | T5XR_INFERENCE_MODE = 'encode' 13 | LOSS_MODULE = @losses.InBatchCrossEntropyLoss 14 | 15 | # MT5 is identical to t5.1.1 architecture except for the vocabulary. 
16 | include 't5x_retrieval/configs/architectures/de_t5_1_1_flaxformer.gin' 17 | NUM_EMBEDDINGS = 250112 # vocab size rounded to a multiple of 128 for TPU efficiency 18 | 19 | # Architecture overrides 20 | NUM_HEADS = 12 21 | NUM_LAYERS = 12 22 | HEAD_DIM = 64 23 | EMBED_DIM = 768 24 | MLP_DIM = 2048 25 | PROJECTION_DIM = 768 26 | 27 | # Vocabulary (shared by encoder and decoder) 28 | VOCABULARY = @seqio.SentencePieceVocabulary() 29 | seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model" 30 | 31 | # Optimizer 32 | # `learning_rate` is set by `Trainer.learning_rate_fn`. 33 | OPTIMIZER = @adafactor.Adafactor() 34 | adafactor.Adafactor: 35 | decay_rate = 0.8 36 | step_offset = 0 37 | 38 | # Model 39 | MODEL = @models.DualEncoderModel() 40 | models.DualEncoderModel: 41 | use_negatives = False 42 | use_align_uniform = False 43 | inference_mode = %T5XR_INFERENCE_MODE 44 | feature_converter_cls = @feature_converters.DualEncoderFeatureConverterFactory() 45 | module = %ARCHITECTURE # provided by t5_flaxformer 46 | loss_module_factory = %LOSS_MODULE 47 | input_vocabulary = %VOCABULARY 48 | output_vocabulary = %VOCABULARY 49 | optimizer_def = %OPTIMIZER 50 | -------------------------------------------------------------------------------- /t5x_retrieval/configs/models/de_longt5_1_1_transient_global_base.gin: -------------------------------------------------------------------------------- 1 | # Dual Encoder LongT5 Base model. Config based on Dual Encoder T5.1.1 Base model. 
2 | # Provides MODEL 3 | from __gin__ import dynamic_registration 4 | 5 | import seqio 6 | from t5x import adafactor 7 | from t5x_retrieval import feature_converters 8 | from t5x_retrieval import models 9 | from t5x_retrieval import losses 10 | 11 | ARCHITECTURE = %gin.REQUIRED 12 | T5XR_INFERENCE_MODE = 'encode' 13 | LOSS_MODULE = @losses.InBatchCrossEntropyLoss 14 | 15 | include 't5x_retrieval/configs/architectures/de_longt5_1_1_transient_global_flaxformer.gin' 16 | 17 | # Architecture overrides 18 | NUM_HEADS = 12 19 | NUM_LAYERS = 12 20 | HEAD_DIM = 64 21 | EMBED_DIM = 768 22 | MLP_DIM = 2048 23 | PROJECTION_DIM = 768 24 | 25 | # Vocabulary (shared by encoder and decoder) 26 | VOCABULARY = @seqio.SentencePieceVocabulary() 27 | seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" 28 | NUM_EMBEDDINGS = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency 29 | 30 | # Optimizer 31 | # `learning_rate` is set by `Trainer.learning_rate_fn`. 32 | OPTIMIZER = @adafactor.Adafactor() 33 | adafactor.Adafactor: 34 | decay_rate = 0.8 35 | step_offset = 0 36 | 37 | # Model 38 | MODEL = @models.DualEncoderModel() 39 | models.DualEncoderModel: 40 | use_negatives = False 41 | use_align_uniform = False 42 | inference_mode = %T5XR_INFERENCE_MODE 43 | feature_converter_cls = @feature_converters.DualEncoderFeatureConverterFactory() 44 | module = %ARCHITECTURE # provided by t5_flaxformer 45 | loss_module_factory = %LOSS_MODULE 46 | input_vocabulary = %VOCABULARY 47 | output_vocabulary = %VOCABULARY 48 | optimizer_def = %OPTIMIZER 49 | -------------------------------------------------------------------------------- /t5x_retrieval/adafactor_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Retrieval Authors. 
def logical_factor_rules():
    """Returns Adafactor factor rules extended for two-tower retrieval models.

    Starts from `adafactor.standard_logical_factor_rules()`, works on an
    unfrozen (mutable) copy, adds entries for the 'affinity' and 'idf_buckets'
    logical axes, and returns the result frozen again.
    """
    merged = unfreeze(adafactor.standard_logical_factor_rules())
    # Retrieval-specific axes layered on top of the standard t5x rules.
    merged['affinity'] = adafactor.FactorDim.COLUMN
    merged['idf_buckets'] = adafactor.FactorDim.NONE
    return freeze(merged)
def spearman_corrcoef(targets: Sequence[float],
                      scores: Sequence[float]) -> Dict[str, float]:
    """Computes the Spearman rank correlation, scaled by 100.

    Args:
      targets: sequence of reference values.
      scores: sequence of predicted values, aligned with `targets`.

    Returns:
      A single-entry dict keyed by 'spearman_corrcoef' holding 100 times the
      correlation between `targets` and `scores`.
    """
    # The first element of the spearmanr result is the correlation statistic;
    # the second (the p-value) is not reported.
    correlation = scipy.stats.spearmanr(targets, scores)[0]
    scaled = 100 * correlation
    return {'spearman_corrcoef': scaled}
def extract_label_postprocessor(
    output: Mapping[str, tf.Tensor],
    example: Optional[Mapping[str, tf.Tensor]] = None,
    is_target: Optional[bool] = False
) -> Union[tf.Tensor, Mapping[str, tf.Tensor]]:
    """Selects what is handed to the SeqIO evaluator.

    Args:
      output: A mapping from strings to tensors; returned untouched when
        `is_target` is falsy.
      example: An optional mapping from strings to tensors; its "labels" entry
        is returned when `is_target` is truthy.
      is_target: Whether the postprocessor is being applied to the target
        (the "labels" field) rather than to the output.

    Returns:
      `example["labels"]` for targets, otherwise `output` itself.
    """
    # `example` is only dereferenced on the target path.
    return example["labels"] if is_target else output
def to_stsb_label(dataset: tf.data.Dataset, label_field_name: str,
                  label_type: str) -> tf.data.Dataset:
  """Rewrites the label field of every example according to `label_type`.

  Two conversions are supported:
    * "score": the label is divided by 5.
    * "multi_class": the label is doubled and then rounded.
  In both cases a trailing axis of size 1 is appended to the converted label
  before it is written back under `label_field_name`.

  Args:
    dataset: A TensorFlow dataset.
    label_field_name: A string of the label field name.
    label_type: A string indicating the label type, either "score" or
      "multi_class".

  Returns:
    A TensorFlow dataset after the transformation.

  Raises:
    ValueError: If `label_type` is not one of the supported values (raised
      from the mapping function).
  """

  def _convert_label(features):
    # Reject unknown label types before touching the example.
    if label_type not in ("score", "multi_class"):
      raise ValueError(f"Unsupported label type: {label_type}")
    raw_label = features[label_field_name]
    if label_type == "score":
      converted = raw_label / 5
    else:
      converted = tf.round(raw_label * 2)
    features[label_field_name] = tf.expand_dims(converted, axis=-1)
    return features

  return dataset.map(_convert_label, num_parallel_calls=tf.data.AUTOTUNE)
4 | # 5 | # Required to be set: 6 | # 7 | # - MIXTURE_OR_TASK_NAME 8 | # - MIXTURE_OR_TASK_MODULE 9 | # - TASK_FEATURE_LENGTHS 10 | # - TRAIN_STEPS # includes pretrain steps 11 | # - MODEL_DIR # automatically set when using xm_launch 12 | # - INITIAL_CHECKPOINT_PATH 13 | # 14 | # `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt 15 | # has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps. 16 | # 17 | # Commonly overridden options: 18 | # - DROPOUT_RATE 19 | # - train/DatasetConfig.batch_size 20 | # - train_eval/DatasetConfig.batch_size 21 | # - infer_eval/DatasetConfig.batch_size 22 | # - PjitPartitioner.num_partitions 23 | # - Trainer.num_microbatches 24 | # - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess 25 | # on the fly. Most common tasks are cached, hence this is set to True by 26 | # default. 27 | from __gin__ import dynamic_registration 28 | 29 | import __main__ as train_script 30 | 31 | from t5x import adafactor 32 | from t5x import partitioning 33 | from t5x import utils 34 | from t5x_retrieval import adafactor_utils 35 | from t5x_retrieval import partitioning as t5xr_partitioning 36 | 37 | include 't5x/configs/runs/finetune.gin' 38 | 39 | BATCH_SIZE = 128 40 | 41 | train_script.train: 42 | infer_eval_dataset_cfg = None 43 | eval_steps = 20 44 | eval_period = 1000 # eval frequency 45 | random_seed = None 46 | 47 | train/utils.DatasetConfig: 48 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME 49 | task_feature_lengths = %TASK_FEATURE_LENGTHS 50 | split = 'train' 51 | shuffle = True 52 | seed = None # use a new seed each run/restart 53 | use_cached = %USE_CACHED_TASKS 54 | pack = False 55 | module = %MIXTURE_OR_TASK_MODULE 56 | 57 | train_eval/utils.DatasetConfig: 58 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME 59 | task_feature_lengths = %TASK_FEATURE_LENGTHS 60 | split = 'validation' 61 | shuffle = False 62 | seed = 42 63 | use_cached = %USE_CACHED_TASKS 64 | pack = False 65 | module = 
%MIXTURE_OR_TASK_MODULE 66 | 67 | utils.RestoreCheckpointConfig: 68 | assignment_map = ((r'state/param_states.*', None),) # Skip optimizer states 69 | fallback_to_scratch = True 70 | 71 | utils.SaveCheckpointConfig: 72 | period = 1000 73 | dtype = 'float32' 74 | keep = None # keep all checkpoints 75 | 76 | utils.create_learning_rate_scheduler: 77 | factors = 'linear_decay' 78 | base_learning_rate = 0.001 79 | warmup_steps = 1000 80 | 81 | adafactor.Adafactor: 82 | decay_rate = 0.8 83 | step_offset = 0 84 | logical_factor_rules = @adafactor_utils.logical_factor_rules() 85 | 86 | partitioning.PjitPartitioner: 87 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 88 | 89 | partitioning.standard_logical_axis_rules: 90 | additional_rules = @t5xr_partitioning.standard_logical_axis_rules() 91 | -------------------------------------------------------------------------------- /t5x_retrieval/tasks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Retrieval Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2022 Google LLC 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 
"""Add Tasks to registry."""
import functools

import seqio
import t5.data
from t5x_retrieval import metrics
from t5x_retrieval import postprocessors as t5xr_postprocessors
from t5x_retrieval import preprocessors as t5xr_preprocessors
import tensorflow as tf


# Single vocabulary instance shared by every text feature below, so the
# default vocabulary is only constructed once.
DEFAULT_VOCAB = t5.data.get_default_vocabulary()

# Features for plain query/passage dual-encoder tasks.
DEFAULT_OUTPUT_FEATURES = {
    "inputs":
        seqio.Feature(vocabulary=DEFAULT_VOCAB, add_eos=True, required=False),
    "targets":
        seqio.Feature(vocabulary=DEFAULT_VOCAB, add_eos=True)
}

# Features for tasks that additionally carry a float "labels" field, which is
# passed through untokenized.
RELEVANCE_OUTPUT_FEATURES = {
    "inputs":
        seqio.Feature(vocabulary=DEFAULT_VOCAB, add_eos=True, required=True),
    "targets":
        seqio.Feature(vocabulary=DEFAULT_VOCAB, add_eos=True),
    "labels":
        seqio.Feature(
            vocabulary=seqio.PassThroughVocabulary(size=1),
            add_eos=False,
            required=False,
            dtype=tf.float32)
}


# =========================== Fine-tuning Tasks/Mixtures =======================
# ----- Beir MS Marco-----
seqio.TaskRegistry.add(
    "beir_msmarco_retrieval",
    source=seqio.TfdsDataSource(
        tfds_name="beir/msmarco:1.0.0",
        splits={
            "train": "train",
            "validation": "validation",
        },
    ),
    preprocessors=[
        # Queries feed the "inputs" feature, passages the "targets" feature.
        functools.partial(
            t5.data.preprocessors.rekey,
            key_map={
                "inputs": "query",
                "targets": "passage",
            }),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        seqio.preprocessors.append_eos_after_trim,
    ],
    metric_fns=[],
    output_features=DEFAULT_OUTPUT_FEATURES)


# ========================== STS Benchmark ====================================
seqio.TaskRegistry.add(
    "glue_stsb_v002_score",
    source=seqio.TfdsDataSource(tfds_name="glue/stsb:2.0.0", splits=None),
    preprocessors=[
        functools.partial(
            seqio.preprocessors.rekey,
            key_map={
                "inputs": "sentence1",
                "targets": "sentence2",
                "labels": "label"
            }),
        # Convert the raw label with the "score" conversion before
        # tokenization.
        functools.partial(
            t5xr_preprocessors.to_stsb_label,
            label_field_name="labels",
            label_type="score",
        ),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features=RELEVANCE_OUTPUT_FEATURES,
    postprocess_fn=t5xr_postprocessors.extract_label_postprocessor,
    metric_fns=[metrics.spearman_corrcoef])


# ============================ Inference Tasks/Mixtures =======================
# ----- Beir MS Marco-----
# One inference task per side; the "<split>_id" field is mapped to "targets".
for split in ["query", "passage"]:
  seqio.TaskRegistry.add(
      f"beir_msmarco_retrieval_{split}",
      source=seqio.TfdsDataSource(
          tfds_name="beir/msmarco:1.0.0",
          splits={split: split},
      ),
      preprocessors=[
          functools.partial(
              t5.data.preprocessors.rekey,
              key_map={
                  "inputs": split,
                  "targets": f"{split}_id",
              }),
          seqio.preprocessors.tokenize,
          seqio.CacheDatasetPlaceholder(),
          seqio.preprocessors.append_eos_after_trim,
      ],
      metric_fns=[],
      output_features=DEFAULT_OUTPUT_FEATURES)
on original T5 (1.0) architecture. 2 | # 3 | # Required to be overridden: 4 | # 5 | # - NUM_HEADS 6 | # - NUM_LAYERS 7 | # - HEAD_DIM 8 | # - EMBED_DIM 9 | # - MLP_DIM 10 | # - PROJECTION_DIM 11 | from __gin__ import dynamic_registration 12 | 13 | from flax import linen 14 | 15 | from flaxformer.architectures.dual_encoder import dual_encoder_architecture 16 | from flaxformer.architectures.dual_encoder import l2_norm 17 | from flaxformer.architectures.dual_encoder import poolings 18 | from flaxformer.architectures.dual_encoder import similarity_functions 19 | from flaxformer.architectures.t5 import t5_architecture 20 | from flaxformer.components.attention import dense_attention 21 | from flaxformer.components import dense 22 | from flaxformer.components import embedding 23 | from flaxformer.components import layer_norm 24 | from flaxformer.components import relative_position_biases 25 | 26 | from t5x import models 27 | from t5x import utils 28 | from t5x_retrieval import feature_converters 29 | 30 | # Must be overridden. 
31 | NUM_HEADS = %gin.REQUIRED 32 | NUM_LAYERS = %gin.REQUIRED 33 | HEAD_DIM = %gin.REQUIRED 34 | EMBED_DIM = %gin.REQUIRED 35 | MLP_DIM = %gin.REQUIRED 36 | PROJECTION_DIM = %gin.REQUIRED 37 | 38 | # Constants (may be overridden) 39 | ACTIVATION_DTYPE = 'bfloat16' 40 | ACTIVATION_PARTITIONING_DIMS = 1 41 | NUM_EMBEDDINGS = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency 42 | DROPOUT_RATE = 0.0 43 | 44 | # Macros 45 | BIAS_INIT = @bias_init/linen.initializers.normal() 46 | bias_init/linen.initializers.normal.stddev = 1e-6 47 | DROPOUT_FACTORY = @dropout_factory/linen.Dropout 48 | dropout_factory/linen.Dropout: 49 | rate = %DROPOUT_RATE 50 | broadcast_dims = (-2,) 51 | 52 | # Architecture (Flax Module) 53 | ARCHITECTURE = @dual_encoder_architecture.DualEncoder() 54 | dual_encoder_architecture.DualEncoder: 55 | encoder_factory = @t5_architecture.Encoder 56 | pooler_factory = @poolings.MeanPooling 57 | shared_token_embedder_factory = @embedding.Embed 58 | l2_norm_factory = @l2_norm.L2Norm 59 | projection_layer_factory = @projection_layer/dense.DenseGeneral 60 | similarity_layer_factory = @similarity_functions.BatchDotProduct 61 | dtype = %ACTIVATION_DTYPE 62 | 63 | # Encoder 64 | t5_architecture.Encoder: 65 | num_layers = %NUM_LAYERS 66 | layer_factory = @t5_architecture.EncoderLayer 67 | input_dropout_factory = %DROPOUT_FACTORY 68 | output_dropout_factory = %DROPOUT_FACTORY 69 | layer_norm_factory = @layer_norm.T5LayerNorm 70 | position_embedder_factory = None 71 | shared_relative_position_bias_factory = @relative_position_biases.RelativePositionBiases 72 | dtype = %ACTIVATION_DTYPE 73 | 74 | # TODO(b/262657686): Move this to model gin files. 75 | # Infer the input features from the TASK_FEATURE_LENGTHS passed by the user. 
76 | feature_converters.DualEncoderFeatureConverterFactory: 77 | feature_specs = ( 78 | ("inputs", "int32", 1, 0), 79 | ("targets", "int32", 1, 0), 80 | ) 81 | 82 | # Similarity layer 83 | similarity_functions.BatchDotProduct: 84 | name = 'batch_dot_product' 85 | 86 | # Projection layer 87 | projection_layer/dense.DenseGeneral: 88 | features = %PROJECTION_DIM 89 | use_bias = False 90 | dtype = 'float32' 91 | kernel_init = @projection_layer/linen.initializers.variance_scaling() 92 | kernel_axis_names = ('embed', 'affinity') 93 | bias_init = %BIAS_INIT 94 | projection_layer/linen.initializers.variance_scaling: 95 | scale = 1 96 | mode = 'fan_in' 97 | distribution = 'truncated_normal' 98 | 99 | # Encoder Layer 100 | t5_architecture.EncoderLayer: 101 | attention = @dense_attention.MultiHeadDotProductAttention() 102 | mlp = @dense.MlpBlock() 103 | dropout_factory = %DROPOUT_FACTORY 104 | layer_norm_factory = @layer_norm.T5LayerNorm 105 | activation_partitioning_dims = %ACTIVATION_PARTITIONING_DIMS 106 | 107 | # Token Embedder (shared) 108 | embedding.Embed: 109 | num_embeddings= %NUM_EMBEDDINGS 110 | features = %EMBED_DIM 111 | cast_input_dtype = 'int32' 112 | dtype = %ACTIVATION_DTYPE 113 | attend_dtype = 'float32' # for logit training stability 114 | one_hot = True 115 | embedding_init = @token_embedder_init/linen.initializers.normal() 116 | name = 'token_embedder' 117 | token_embedder_init/linen.initializers.normal.stddev = 1.0 118 | 119 | # Attention (encoder, decoder, self-attention) 120 | dense_attention.MultiHeadDotProductAttention: 121 | num_heads = %NUM_HEADS 122 | head_dim = %HEAD_DIM 123 | dtype = %ACTIVATION_DTYPE 124 | kernel_init = @attention_kernel_init/linen.initializers.variance_scaling() 125 | bias_init = %BIAS_INIT 126 | use_bias = False 127 | broadcast_dropout = True 128 | dropout_rate = %DROPOUT_RATE 129 | attention_kernel_init/linen.initializers.variance_scaling: 130 | scale = 1.0 131 | mode = 'fan_in' 132 | distribution = 'normal' 133 | 134 | # 
Relative position biases (encoder, decoder) 135 | relative_position_biases.RelativePositionBiases: 136 | num_heads = %NUM_HEADS 137 | num_buckets = 32 138 | max_distance = 128 139 | dtype = %ACTIVATION_DTYPE 140 | embedding_init = @relative_position_bias_init/linen.initializers.variance_scaling() 141 | relative_position_bias_init/linen.initializers.variance_scaling: 142 | scale = 1.0 143 | mode = 'fan_avg' 144 | distribution = 'uniform' 145 | 146 | # MLP (encoder) 147 | dense.MlpBlock: 148 | use_bias = False 149 | intermediate_dim = %MLP_DIM 150 | activations = ('relu',) 151 | kernel_init = @mlp_kernel_init/linen.initializers.variance_scaling() 152 | bias_init = %BIAS_INIT 153 | intermediate_dropout_rate = %DROPOUT_RATE 154 | final_dropout_rate = 0 155 | dtype = %ACTIVATION_DTYPE 156 | mlp_kernel_init/linen.initializers.variance_scaling: 157 | scale = 1.0 158 | mode = 'fan_in' 159 | distribution = 'truncated_normal' 160 | -------------------------------------------------------------------------------- /t5x_retrieval/configs/architectures/de_longt5_1_1_transient_global_flaxformer.gin: -------------------------------------------------------------------------------- 1 | # Flaxformer implementation of Long Dual Encoder, based on LongT5 architecture. 
2 | # 3 | # Required to be overridden: 4 | # 5 | # - NUM_HEADS 6 | # - NUM_LAYERS 7 | # - HEAD_DIM 8 | # - EMBED_DIM 9 | # - MLP_DIM 10 | from __gin__ import dynamic_registration 11 | 12 | from flax import linen 13 | from flaxformer.architectures.dual_encoder import dual_encoder_architecture 14 | from flaxformer.architectures.dual_encoder import l2_norm 15 | from flaxformer.architectures.dual_encoder import poolings 16 | from flaxformer.architectures.dual_encoder import similarity_functions 17 | from flaxformer.architectures.longt5 import long_attention 18 | from flaxformer.architectures.longt5 import longt5_architecture 19 | from flaxformer.architectures.longt5 import relative_position_biases_general 20 | from flaxformer.components import dense 21 | from flaxformer.components import embedding 22 | from flaxformer.components import layer_norm 23 | from flaxformer.components import relative_position_biases 24 | 25 | from t5x_retrieval import feature_converters 26 | 27 | 28 | NUM_LAYERS = %gin.REQUIRED 29 | NUM_HEADS = %gin.REQUIRED 30 | HEAD_DIM = %gin.REQUIRED 31 | EMBED_DIM = %gin.REQUIRED 32 | MLP_DIM = %gin.REQUIRED 33 | 34 | 35 | ACTIVATION_DTYPE = 'bfloat16' 36 | ACTIVATION_PARTITIONING_DIMS = 1 37 | SCALE = 1.0 38 | DROPOUT_RATE = 0.0 39 | LOCAL_RADIUS = 127 40 | TOKENS_PER_BLOCK = 16 41 | 42 | # Macros 43 | BIAS_INIT = @bias_init/linen.initializers.normal() 44 | bias_init/linen.initializers.normal.stddev = 1e-6 45 | DROPOUT_FACTORY = @dropout_factory/linen.Dropout 46 | dropout_factory/linen.Dropout: 47 | rate = %DROPOUT_RATE 48 | broadcast_dims = (-2,) 49 | 50 | # Architecture (Flax Module) 51 | ARCHITECTURE = @dual_encoder_architecture.LongDualEncoder() 52 | dual_encoder_architecture.LongDualEncoder: 53 | encoder_factory = @longt5_architecture.LongEncoder 54 | pooler_factory = @poolings.MeanPooling 55 | shared_token_embedder_factory = @embedding.Embed 56 | l2_norm_factory = @l2_norm.L2Norm 57 | projection_layer_factory = @projection_layer/dense.DenseGeneral 
58 | similarity_layer_factory = @similarity_functions.BatchDotProduct 59 | dtype = %ACTIVATION_DTYPE 60 | 61 | # Encoder 62 | longt5_architecture.LongEncoder: 63 | num_layers = %NUM_LAYERS 64 | layer_factory = @longt5_architecture.LongEncoderLayer 65 | input_dropout_factory = %DROPOUT_FACTORY 66 | output_dropout_factory = %DROPOUT_FACTORY 67 | layer_norm_factory = @layer_norm.T5LayerNorm 68 | position_embedder_factory = None 69 | shared_relpos_bias_factory = @relative_position_biases_general.RelativePositionBiasesGeneral 70 | shared_side_relpos_bias_factory = @relative_position_biases_general.RelativePositionBiasesGeneral 71 | dtype = %ACTIVATION_DTYPE 72 | 73 | # Infer the input features from the TASK_FEATURE_LENGTHS passed by the user. 74 | feature_converters.DualEncoderFeatureConverterFactory: 75 | feature_specs = ( 76 | ("inputs", "int32", 1, 0), 77 | ("targets", "int32", 1, 0), 78 | ) 79 | 80 | # Similarity layer 81 | similarity_functions.BatchDotProduct: 82 | name = 'batch_dot_product' 83 | 84 | # Projection layer 85 | projection_layer/dense.DenseGeneral: 86 | features = %PROJECTION_DIM 87 | use_bias = False 88 | dtype = 'float32' 89 | kernel_init = @projection_layer/linen.initializers.variance_scaling() 90 | kernel_axis_names = ('embed', 'affinity') 91 | bias_init = %BIAS_INIT 92 | projection_layer/linen.initializers.variance_scaling: 93 | scale = 1 94 | mode = 'fan_in' 95 | distribution = 'truncated_normal' 96 | 97 | # Encoder Layer 98 | longt5_architecture.LongEncoderLayer: 99 | attention_factory = @long_attention.EtcTransientGlobalSelfAttention 100 | mlp = @dense.MlpBlock() 101 | dropout_factory = %DROPOUT_FACTORY 102 | layer_norm_factory = @layer_norm.T5LayerNorm 103 | activation_partitioning_dims = %ACTIVATION_PARTITIONING_DIMS 104 | 105 | # Long Attention (encoder, self-attention) 106 | long_attention.EtcTransientGlobalSelfAttention: 107 | num_heads = %NUM_HEADS 108 | tokens_per_block = %TOKENS_PER_BLOCK 109 | local_radius = %LOCAL_RADIUS 110 | dtype = 
%ACTIVATION_DTYPE 111 | head_dim = %HEAD_DIM 112 | kernel_init = @attention_kernel_init/linen.initializers.variance_scaling() 113 | bias_init = %BIAS_INIT 114 | use_bias = False 115 | broadcast_dropout = True 116 | dropout_rate = %DROPOUT_RATE 117 | attention_kernel_init/linen.initializers.variance_scaling: 118 | scale = %SCALE 119 | mode = 'fan_in' 120 | distribution = 'normal' 121 | 122 | # Relative position biases (encoder) 123 | relative_position_biases_general.RelativePositionBiasesGeneral: 124 | num_heads = %NUM_HEADS 125 | dtype = %ACTIVATION_DTYPE 126 | num_buckets = 32 127 | max_distance = 128 128 | embedding_init = @relative_position_bias_init/linen.initializers.variance_scaling() 129 | relative_position_bias_init/linen.initializers.variance_scaling: 130 | scale = %SCALE 131 | mode = 'fan_avg' 132 | distribution = 'uniform' 133 | 134 | # Token Embedder (shared) 135 | embedding.Embed: 136 | num_embeddings= %NUM_EMBEDDINGS 137 | features = %EMBED_DIM 138 | cast_input_dtype = 'int32' 139 | dtype = %ACTIVATION_DTYPE 140 | attend_dtype = 'float32' # for logit training stability 141 | one_hot = True 142 | embedding_init = @token_embedder_init/linen.initializers.normal() 143 | name = 'token_embedder' 144 | token_embedder_init/linen.initializers.normal.stddev = 1.0 145 | 146 | dense.MlpBlock: 147 | use_bias = False 148 | intermediate_dim = %MLP_DIM 149 | activations = ('gelu', 'linear') 150 | kernel_init = @mlp_kernel_init/linen.initializers.variance_scaling() 151 | bias_init = %BIAS_INIT 152 | intermediate_dropout_rate = %DROPOUT_RATE 153 | final_dropout_rate = 0 154 | dtype = %ACTIVATION_DTYPE 155 | mlp_kernel_init/linen.initializers.variance_scaling: 156 | scale = 1.0 157 | mode = 'fan_in' 158 | distribution = 'truncated_normal' 159 | -------------------------------------------------------------------------------- /t5x_retrieval/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Retrieval 
Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2022 Google LLC 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 
@gin.configurable
class DualEncoderLoss(nn.Module, abc.ABC):
  """Base class for loss layers accepted by dual encoders.

  Subclasses declare the model features they need (`LOSS_MODEL_FEATURES`),
  compute the loss and metrics (`__call__`), and provide their initial
  variables (`get_initial_variables`).
  """

  @property
  @abc.abstractmethod
  def LOSS_MODEL_FEATURES(self) -> Mapping[str, SeqIOFeatureSpec]:
    """Model features required by loss layer for computing loss/metrics.

    This property specifies the features that are expected to be present in
    `batch` argument passed to the __call__ method by a dual encoder model.
    These must be described as a map of feature keys to SeqIO `FeatureSpec`s,
    following the format of the `MODEL_FEATURES` attribute of SeqIO's
    `FeatureConverter`s. The features specified here in this manner should be
    validated by `validate_model_features` method (the default base class
    implementation does this).
    """
    pass

  @nn.compact
  @abc.abstractmethod
  def __call__(self, batch: Mapping[str, jnp.ndarray], logits: jnp.ndarray,
               **kwargs) -> Tuple[jnp.float64, t5x_metrics.MetricsMap]:
    """Computes loss and loss-related metrics on the inputs.

    Args:
      batch: Features passed by dual-encoder model. Any external supervision
        signals/labels should be passed through this argument.
      logits: Output of dual-encoder model's similarity layer.
      **kwargs: Additional arguments that may be needed by the call method

    Returns:
      Tuple of (loss, metrics)
    """
    pass

  @abc.abstractmethod
  def get_initial_variables(self, rng, input_shapes,
                            input_types) -> flax_scope.FrozenVariableDict:
    """Gets the initial variables for a loss layer."""
    pass

  def validate_model_features(self, feature_converter: seqio.FeatureConverter):
    """Ensures loss-layer's required features are provided by feature converter.

    This method checks that the feature-spec for loss model features match. Any
    additional validation should be performed by child classes.

    Args:
      feature_converter: Feature converter used by the dual encoder model that
        invokes this loss layer.

    Raises:
      ValueError: If a required feature is missing from the converter's
        `MODEL_FEATURES` or its spec does not match the one required here.
    """
    model_features = feature_converter.MODEL_FEATURES
    for name, feature in self.LOSS_MODEL_FEATURES.items():
      model_feature = model_features.get(name)
      if model_feature is None:
        raise ValueError(f"Missing required loss-layer feature {name} "
                         f"in model features {model_features}")
      # In both messages below, "expected" is what this loss layer requires
      # and "found" is what the feature converter provides.
      if not isinstance(feature, type(model_feature)):
        # The fragments are concatenated into one message; passing them as
        # separate arguments would make the exception carry a tuple.
        raise ValueError(
            "Found incorrect type for feature spec of loss layer feature "
            f"{name}, expected {type(feature)}, found {type(model_feature)}.")
      if feature != model_feature:
        raise ValueError("Found incorrect feature spec for loss layer feature "
                         f"{name}, expected {feature}, found {model_feature}")


class InBatchCrossEntropyLoss(DualEncoderLoss):
  """Dual encoder in-batch cross-entropy loss implementation.

  Attributes:
    bidirectional: Whether to use bi-directional in-batch softmax loss. If set
      to True, consider both left-to-right and right-to-left losses.
    label_smoothing: Label smoothing constant, used to determine the on and off
      values.
  """

  bidirectional: bool = True
  label_smoothing: float = 0.0

  @property
  def LOSS_MODEL_FEATURES(self):
    """Model features required by loss layer for computing loss/metrics."""
    # This loss relies only on in-batch logits, and therefore doesn't
    # need any additional model features
    return {}

  def get_initial_variables(self, rng, input_shapes,
                            input_types) -> flax_scope.FrozenVariableDict:
    """Gets the initial variables for a loss layer.

    Args:
      rng: Random key forwarded to `self.init`.
      input_shapes: Mapping from model feature name to shape. Must contain
        "left_encoder_input_tokens"; "right_negative_encoder_input_tokens" is
        optional and, when present, determines the number of negatives.
      input_types: Unused.

    Returns:
      The variables produced by `self.init` on dummy logits.
    """
    # `logits` is of shape [B, B*(1+num_negatives)] that considers the
    # negatives while `right_logits` is in shape [B, B] that doesn't consider
    # negatives. `num_negatives` could be greater than 1 in the future.
    left_encoder_shape = input_shapes["left_encoder_input_tokens"]
    batch_size = left_encoder_shape[0]

    num_negatives = 0
    if "right_negative_encoder_input_tokens" in input_shapes:
      # right_negative_encoder_input_tokens: batch_size x num_negatives
      num_negatives = input_shapes["right_negative_encoder_input_tokens"][1]

    return self.init(
        rng,
        params={},
        batch={},
        logits=jnp.ones([batch_size, (batch_size * (1 + num_negatives))]),
        right_logits=jnp.ones([batch_size, batch_size]))

  @nn.compact
  def __call__(self,
               batch: Mapping[str, jnp.ndarray],
               logits: jnp.ndarray,
               right_logits: Optional[jnp.ndarray] = None,
               **kwargs) -> Tuple[jnp.float64, t5x_metrics.MetricsMap]:
    """Computes loss and loss-related metrics on inputs.

    `logits` is of shape [B, B*(1+num_negatives)] that considers the
    negatives while `right_logits` is in shape [B, B] that doesn't consider
    negatives. `num_negatives` could be greater than 1 in the future.

    Args:
      batch: Features passed by dual-encoder model. Unused.
      logits: Output of similarity layer that considers negatives. Has shape [B,
        B*(1+num_negatives)].
      right_logits: Output of similarity layer that doesn't consider negatives.
        Has shape [B, B]. If None, the right loss is skipped.
      **kwargs: Unused.

    Returns:
      loss: a float scalar for contrastive loss.
      metrics: metrics defined in `t5x_models.compute_base_metrics` and `MRR`.
    """
    del kwargs
    del batch  # we don't require any external inputs for this loss

    # z_loss is already added to loss, which is a workaround for the numerical
    # instability issue.
    z_loss = 0.
    loss = utils.in_batch_cross_entropy(
        logits, label_smoothing=self.label_smoothing)
    # The right-to-left term is only added when it is both available and
    # requested via `bidirectional`.
    if right_logits is not None and self.bidirectional:
      right_loss = utils.in_batch_cross_entropy(
          right_logits, label_smoothing=self.label_smoothing)
      loss = jnp.mean(loss + right_loss)

    metrics = t5x_models.compute_base_metrics(
        logits=logits,
        targets=utils.sparse_labels_for_in_batch_cross_entropy(logits),
        mask=None,
        loss=loss,
        z_loss=z_loss)

    return loss, metrics
14 | 15 | # Copyright 2022 Google LLC 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | 29 | """Feature converter for dual encoders.""" 30 | 31 | import dataclasses 32 | from typing import Any, Iterable, Mapping, Sequence, Tuple 33 | import ml_collections 34 | import seqio 35 | import tensorflow as tf 36 | 37 | FeatureSpecConfig = Tuple[str, str, int, int] 38 | 39 | utils = seqio.utils 40 | FeatureConverter = seqio.FeatureConverter 41 | 42 | 43 | # TODO(b/262657107): Add features for universal feature converter. 
# Maps task feature names to the model feature names consumed by the
# dual-encoder model. Only used for non-multimodal converters.
_MODEL_FEATURES_MAPPING = {
    "inputs": "left_encoder_input_tokens",
    "targets": "right_encoder_input_tokens",
    "negative_targets": "right_negative_encoder_input_tokens",
    "soft_label_pos": "soft_label_pos",
    "soft_label_neg": "soft_label_neg",
    "labels": "labels",
    "loss_weights": "loss_weights",
}


@dataclasses.dataclass
class FeatureSpec:
  """Container class for a feature's name, dtype, rank and sequence dim.

  Attributes:
    name: name of the task feature.
    dtype: TensorFlow dtype of the feature.
    rank: rank of the feature.
    sequence_dim: forwarded unchanged as `sequence_dim` to
      `seqio.FeatureConverter.FeatureSpec`.
  """
  name: str
  dtype: tf.DType
  rank: int
  sequence_dim: int

  @classmethod
  def to_map(
      cls, feature_specs: Iterable[FeatureSpecConfig]) -> Mapping[str, Any]:
    """Builds a `{name: FeatureSpec}` map from feature spec tuples.

    Args:
      feature_specs: iterable of `(name, dtype_str, rank, sequence_dim)`
        tuples. `dtype_str` is resolved as an attribute of the `tf` module
        (for example "int32").

    Returns:
      A dict keyed by feature name. A later tuple with a repeated name
      replaces the earlier one.

    Raises:
      AttributeError: if `dtype_str` is not an attribute of `tf`.
    """
    return {
        name: cls(name, getattr(tf, dtype_str), rank, sequence_dim)
        for name, dtype_str, rank, sequence_dim in feature_specs
    }


class DualEncoderFeatureConverterFactory(object):
  """Factory for dual encoder feature converters.

  Stores the feature configuration so that the factory can later be called
  with only the `pack` / `use_custom_packing_ops` arguments to produce a
  configured `DualEncoderFeatureConverter`.
  """

  def __init__(self, feature_specs: Iterable[FeatureSpecConfig],
               is_multimodal: bool = False):
    """Stores the configuration used to build converters.

    Args:
      feature_specs: iterable of `(name, dtype_str, rank, sequence_dim)`
        tuples, see `FeatureSpec.to_map`.
      is_multimodal: whether the produced converters handle multimodal inputs.
    """
    self.feature_specs = feature_specs
    self.is_multimodal = is_multimodal

  def __call__(self,
               pack: bool = False,
               use_custom_packing_ops: bool = False):
    """Creates a `DualEncoderFeatureConverter` from the stored configuration.

    Args:
      pack: forwarded to the converter.
      use_custom_packing_ops: forwarded to the converter.

    Returns:
      A new `DualEncoderFeatureConverter`.
    """
    feature_spec_map = FeatureSpec.to_map(self.feature_specs)
    return DualEncoderFeatureConverter(
        # Materialize the dict view so the converter holds a real sequence.
        input_features=list(feature_spec_map.values()),
        is_multimodal=self.is_multimodal,
        pack=pack,
        use_custom_packing_ops=use_custom_packing_ops)


class DualEncoderFeatureConverter(FeatureConverter):
  """Feature converter for dual-encoder architecture.

  The inputs and targets to the dual-encoder are sent to the left and right
  encoders separately.

  Attributes:
    input_features: a list of feature specs that are used to define the
      feature's name, dtype and rank.
    is_multimodal: a boolean variable to indicate whether it is a feature
      converter to multimodal inputs. For multimodal inputs, we don't use the
      default MODEL_FEATURES_MAPPING.
  """
  input_features: Sequence[FeatureSpec] = ()
  is_multimodal: bool = False

  def __init__(self,
               input_features: Sequence[FeatureSpec],
               is_multimodal: bool = False,
               pack: bool = False,
               use_custom_packing_ops: bool = True):
    """Validates the feature specs and initializes the base converter.

    Args:
      input_features: feature specs describing the task features.
      is_multimodal: if True, task feature names are used as model feature
        names and `_MODEL_FEATURES_MAPPING` is bypassed.
      pack: forwarded to `FeatureConverter.__init__`.
      use_custom_packing_ops: forwarded to `FeatureConverter.__init__`.

    Raises:
      ValueError: if `is_multimodal` is set and the features cover only one of
        the "left_" / "right_" towers.
    """
    self.input_features = input_features
    self.is_multimodal = is_multimodal
    # NOTE: for multimodal inputs, make sure the inputs either (1) include both
    # "left_" and "right_" features for training or (2) don't include any of
    # them for inference time.
    if self.is_multimodal:
      has_left, has_right = False, False
      for f in self.input_features:
        if "left_" in f.name:
          has_left = True
        elif "right_" in f.name:
          has_right = True
      if has_left != has_right:
        # The trailing space keeps the concatenated literals from reading
        # "towerfeatures".
        raise ValueError(
            "Multimodal inputs features should have both left and right tower "
            "features for training."
        )
    super().__init__(pack=pack, use_custom_packing_ops=use_custom_packing_ops)

  def _model_feature_name(self, task_feature_name: str) -> str:
    """Returns the model feature name for the given task feature name."""
    # NOTE: only use the _MODEL_FEATURES_MAPPING for non-multimodal inputs.
    if self.is_multimodal:
      return task_feature_name
    return _MODEL_FEATURES_MAPPING[task_feature_name]

  @property
  def TASK_FEATURES(self):
    """Task feature specs keyed by task feature name."""
    return {
        f.name: seqio.FeatureConverter.FeatureSpec(
            dtype=f.dtype, rank=f.rank, sequence_dim=f.sequence_dim)
        for f in self.input_features
    }

  @property
  def MODEL_FEATURES(self):
    """Model feature specs keyed by model feature name."""
    return {
        self._model_feature_name(f.name): seqio.FeatureConverter.FeatureSpec(
            dtype=f.dtype, rank=f.rank, sequence_dim=f.sequence_dim)
        for f in self.input_features
    }

  @property
  def PACKING_FEATURE_DTYPES(self):
    """Packing is not supported, so there are no packing features."""
    return None

  def _convert_features(self, ds: tf.data.Dataset,
                        input_lengths: Mapping[str, int]) -> tf.data.Dataset:
    """Convert the input dataset to an output dataset to be fed to the model.

    The conversion process involves two steps

    1. Each feature in the `input_lengths` is padded (features with rank > 1
       are skipped, see below).
    2. Task features are renamed to model features, e.g. "inputs" fields are
       mapped to the left encoder input and "targets" are mapped to right
       encoder input.

    Assume the input dataset has two examples each with "inputs" and "targets".

    ds = [{"inputs": [7, 8, 5, 1], "targets": [3, 9, 1]},
          {"inputs": [8, 4, 9, 3, 1], "targets": [4, 1]}]

    task_feature_lengths = {"inputs": 8, "targets": 4}

    First, the `inputs` are padded to length 8 and assigned to
    "left_encoder_input_tokens" field. The `targets` are processed similarly.

    converted_ds = [
    {
        "left_encoder_input_tokens": [7, 8, 5, 1, 0, 0, 0, 0],
        "right_encoder_input_tokens": [3, 9, 1, 0],
    },
    {
        "left_encoder_input_tokens": [8, 4, 9, 3, 1, 0, 0, 0],
        "right_encoder_input_tokens": [4, 1, 0, 0],
    },
    ]

    Args:
      ds: an input tf.data.Dataset to be converted.
      input_lengths: a mapping from a feature to its length

    Returns:
      ds: the converted dataset.

    Raises:
      ValueError: if the converter was created with `pack=True`.
    """
    if self.pack:
      raise ValueError(
          "Dual encoder only takes non packed examples at this moment."
      )

    # The property rebuilds the spec map on every access, so read it once.
    task_features = self.TASK_FEATURES
    # Stop padding features with rank > 1, since _pack_or_pad adds padding to
    # the first dimension instead of the last dimension. Such inputs should
    # already be padded and dense.
    skip_padding = {
        name for name in input_lengths
        if name in task_features and task_features[name].rank > 1
    }
    if skip_padding:
      # Work on a plain-dict copy so the caller's mapping is never mutated.
      input_lengths = {
          name: length
          for name, length in dict(input_lengths).items()
          if name not in skip_padding
      }

    def convert_example(
        features: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
      return {
          self._model_feature_name(f.name): features[f.name]
          for f in self.input_features
      }

    ds = self._pack_or_pad(ds, input_lengths)
    return ds.map(
        convert_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def get_model_feature_lengths(
      self, task_feature_lengths: Mapping[str, int]) -> Mapping[str, int]:
    """Define the length relationship between input and output features.

    Args:
      task_feature_lengths: a mapping from task feature name to its length.

    Returns:
      A mapping from model feature name to the length of the task feature it
      is derived from.

    Raises:
      ValueError: if the converter was created with `pack=True`.
    """
    if self.pack:
      raise ValueError("Packing not supported")

    return {
        self._model_feature_name(name): task_feature_lengths[name]
        for name in self.TASK_FEATURES
    }
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # T5X Retrieval 2 | 3 | T5X Retrieval is a JAX implementation of T5 (Text-to-Text Transfer Transformer) optimized for retrieval applications. 4 | It is built on top of T5 on JAX, aka [T5X](https://github.com/google-research/t5x). 5 | This is targeted at Natural Language Understanding researchers as well as application developers who are aiming to use the latest T5-based Transformer models for search, retrieval and ranking applications, but in the JAX framework as opposed to TensorFlow. 6 | 7 | T5X Retrieval is an efficient training and evaluation framework that supports transformer-based neural retrieval and ranking models such as sentence encoders and dense retrieval models. It supports multi-pod large model training, large cross-batch negatives and the capability to initialize from any pre-trained model trained using T5X. 
8 | 9 | This launch open sources the training and inference code, including references to TFDS for training data, actual model training code (Python JAX & Flaxformer), pre-trained models and basic inference example code. This end-to-end example model code is meant to accompany the SentenceT5 and Generalizable T5 Retrieval models and includes the implementation and performance on relevant benchmarks. 10 | 11 | 12 | # What's here 13 | 14 | - configs/\*.gin - Model configurations 15 | - tasks.py - Task definitions that generate the dataset. 16 | - feature_converters.py - Converters that transform the task features from the dataset to model features 17 | - models.py - High-level models, such as DualEncoderDecoderModel, that take the feature converters' outputs as inputs. 18 | 19 | For more details about the training pipeline and task definitions, you can check out [T5X](https://github.com/google-research/t5x) and [Seqio](https://github.com/google/seqio). 20 | 21 | # Quickstart (Recommended) 22 | 23 | T5X Retrieval supports the training and evaluation options provided by [T5X](https://github.com/google-research/t5x). It can be run with [XManager](https://github.com/deepmind/xmanager) on 24 | [Vertex AI](https://cloud.google.com/vertex-ai), which is a platform for 25 | training that creates TPU instances and runs code on the TPUs. 26 | 27 | We briefly summarize the steps to quickly start the training and inference jobs. You can find more details at the [T5X Quickstart](https://github.com/google-research/t5x#quickstart-recommended). 28 | 29 | 0. [Create a GCP project](https://github.com/deepmind/xmanager#create-a-gcp-project-optional). Create the bucket to store data and models. 30 | 31 | 1. Follow the pre-requisites and directions to install [XManager](https://github.com/deepmind/xmanager). 32 | 33 | 2. [Optional] GCP projects come with 8 cores by default, which is enough to run one training experiment on a single TPU host.
Request TPU quota as required if you want to run multi-host training or multiple runs in parallel. 34 | 35 | 3. Install all dependencies such as [T5X](https://github.com/google-research/t5x), [Flaxformer](https://github.com/google/flaxformer), [TFDS](https://github.com/tensorflow/datasets). 36 | 37 | 4. Launch the xmanager script located at `t5x/scripts/xm_launch.py`. 38 | 39 | As a running example, we use the [BEIR MS Marco dataset](https://github.com/beir-cellar/beir#beers-available-datasets). 40 | 41 | ```sh 42 | # Export GOOGLE_CLOUD_BUCKET_NAME to a proper value. 43 | export GOOGLE_CLOUD_BUCKET_NAME=... 44 | export TFDS_DATA_DIR=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x_retrieval/data 45 | export MODEL_DIR=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x_retrieval/$(date +%Y%m%d) 46 | 47 | # Install dependencies. 48 | git clone https://github.com/google-research/t5x /tmp/t5x 49 | git clone https://github.com/google-research/t5x_retrieval /tmp/t5x_retrieval 50 | git clone https://github.com/google/flaxformer /tmp/flaxformer 51 | git clone https://github.com/google/aqt.git /tmp/aqt 52 | 53 | cd /tmp/t5x/ 54 | 55 | python3 t5x/scripts/xm_launch.py \ 56 | --pip_install="apache_beam[gcp]" \ 57 | --model_dir=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/msmarco_ft_$(date +%Y%m%d) \ 58 | --tfds_data_dir=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/data \ 59 | --project_dirs=/tmp/t5x_retrieval/t5x_retrieval,/tmp/flaxformer/flaxformer,/tmp/aqt/aqt \ 60 | --gin_file=t5x_retrieval/configs/models/de_t5_base.gin \ 61 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_base/checkpoint_999900\" \ 62 | --gin_file=t5x_retrieval/configs/runs/finetune.gin \ 63 | --gin.TRAIN_STEPS=1009900 \ 64 | --gin.utils.create_learning_rate_scheduler.step_offset=999900 \ 65 | --gin.utils.create_learning_rate_scheduler.warmup_steps=1000 \ 66 | --gin.utils.create_learning_rate_scheduler.decay_factor=0.00000125 \ 67 | --gin.USE_CACHED_TASKS=False \ 68 | --gin.models.DualEncoderModel.use_negatives=False \ 69 | 
--gin.train.eval_period=500 \ 70 | --gin.utils.SaveCheckpointConfig.keep=10 \ 71 | --gin.utils.SaveCheckpointConfig.period=500 \ 72 | --gin.train/DatasetConfig.batch_size=512 \ 73 | --gin.MIXTURE_OR_TASK_NAME="'beir_msmarco_retrieval'" \ 74 | --gin.MIXTURE_OR_TASK_MODULE="'t5x_retrieval.tasks'" \ 75 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 64, 'targets': 256}" 76 | ``` 77 | 78 | Notes: 79 | 80 | - Check `gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/` for the output artifacts, which can be read by TensorBoard. 81 | - Add `--pip_install="apache_beam[gcp]"` to the script if you have not downloaded the dataset beforehand. 82 | - `TRAIN_STEPS = step_offset + real_train_steps`, where `step_offset` is the step of the loaded checkpoint and `real_train_steps` is the number of steps that the model will be trained for. 83 | 84 | # Models 85 | 86 | ## Sentence encoders 87 | **SentenceT5** is a family of high-performing sentence encoders trained using T5X Retrieval. The SentenceT5 models encode text into high-dimensional vectors that can be used for text classification, semantic similarity, clustering and other natural language processing tasks. 88 | 89 | SentenceT5 models are built on top of the Text-To-Text Transfer Transformer (T5). They are trained on a variety of data sources and initialized from pre-trained T5 models with different model sizes as described in [1]. The input is variable-length English text and the output is a 768-dimensional vector. Note that there's no hard length limit for T5 (i.e., no 512 tokens limit as in BERT), but that it's been trained to produce good embeddings for approximately sentence-length text. 90 | 91 | ### Metrics 92 | 93 | * We evaluate this model on the 94 | [SentEval](https://github.com/facebookresearch/SentEval) sentence 95 | representation benchmark.
96 | 97 | Transfer tasks | MR | CR | SUBJ | MPQA | SST | TREC | MRPC | Average 98 | :------------------------------------------------------------ | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ------: 99 | **ST5-Base** | 85.8 | 92.1 | 94.6 | 90.9 | 91.8 | 96.4 | 75.2 | 89.5 100 | [ST5-Large](https://tfhub.dev/google/sentence-t5/st5-large/1) | 88.9 | 93.5 | 95.4 | 91.5 | 94.2 | 96.2 | 77.1 | 91.0 101 | [ST5-3B](https://tfhub.dev/google/sentence-t5/st5-3b/1) | 89.9 | 94.1 | 95.9 | 91.6 | 94.8 | 96.2 | 77.9 | 91.5 102 | [ST5-11B](https://tfhub.dev/google/sentence-t5/st5-11b/1) | 90.8 | 94.4 | 96.3 | 91.7 | 94.8 | 95.4 | 77.9 | 91.6 103 | 104 |
105 | 106 | STS tasks | STS12 | STS13 | STS14 | STS15 | STS16 | STSb | SICK-R | Average 107 | :------------------------------------------------------------ | ----: | ----: | ----: | ----: | ----: | ---: | -----: | ------: 108 | **ST5-Base** | 78.1 | 85.8 | 82.2 | 87.5 | 84.0 | 86.0 | 79.8 | 83.3 109 | [ST5-Large](https://tfhub.dev/google/sentence-t5/st5-large/1) | 79.1 | 87.3 | 83.2 | 88.3 | 84.4 | 86.7 | 79.8 | 84.1 110 | [ST5-3B](https://tfhub.dev/google/sentence-t5/st5-3b/1) | 79.0 | 88.8 | 84.3 | 88.9 | 85.3 | 86.3 | 79.5 | 84.6 111 | [ST5-11B](https://tfhub.dev/google/sentence-t5/st5-11b/1) | 80.1 | 88.8 | 84.7 | 88.9 | 85.2 | 86.8 | 80.4 | 85.0 112 | 113 | More details about the evaluations can be found in the paper [1]. 114 | 115 | ## Dense retrieval models 116 | The **Generalizable T5 Retrieval** models are dual encoders that encode two pieces of text into two dense 117 | vectors respectively [2]. This is typically used to encode a query and a document to 118 | compute their similarity for dense retrieval. 119 | 120 | GTR models are built on top of [T5](https://arxiv.org/pdf/1910.10683.pdf) (i.e. 121 | the Text-To-Text Transfer Transformer). The GTR-Base model employs a 12-layer 122 | transformer architecture, which is the same as the T5 base model. The model is 123 | first initialized from the pre-trained T5 checkpoint. It is then further 124 | pre-trained with a set of community question-answer pairs we collected. Finally, 125 | the model is fine-tuned on the [MS Marco](https://microsoft.github.io/msmarco/) 126 | dataset. 127 | 128 | The two encoders are [shared](https://arxiv.org/pdf/2204.07120.pdf) so the GTR model functions as a single text encoder. 129 | The input is variable-length English text and the output is a 768-dimensional 130 | vector. 131 | 132 | ### Metrics 133 | 134 | We evaluate on the [BEIR](https://github.com/UKPLab/beir) benchmark and report the Recall@100.
135 | 136 | Dataset \ Model | **GTR-Base** | [GTR-Large](https://tfhub.dev/google/gtr/gtr-large/1) | [GTR-XL](https://tfhub.dev/google/gtr/gtr-xl/1) | [GTR-XXL](https://tfhub.dev/google/gtr/gtr-xxl/1) 137 | ---------------- | ------------ | ----------------------------------------------------- | ----------------------------------------------- | ------------------------------------------------- 138 | MS MARCO | 0.898 | 0.908 | 0.911 | 0.916 139 | Trec-Covid | 0.411 | 0.434 | 0.457 | 0.407 140 | BioASQ | 0.441 | 0.490 | 0.483 | 0.483 141 | NFCorpus | 0.275 | 0.298 | 0.318 | 0.300 142 | NQ | 0.893 | 0.930 | 0.936 | 0.946 143 | HotpotQA | 0.676 | 0.725 | 0.739 | 0.752 144 | FiQA-2018 | 0.670 | 0.742 | 0.755 | 0.780 145 | Signal-1M | 0.263 | 0.261 | 0.268 | 0.268 146 | Trec-News | 0.475 | 0.525 | 0.512 | 0.544 147 | Robust04 | 0.324 | 0.365 | 0.364 | 0.372 148 | ArguAna | 0.974 | 0.978 | 0.980 | 0.983 149 | Touché-2020 | 0.281 | 0.282 | 0.297 | 0.301 150 | Quora | 0.996 | 0.996 | 0.997 | 0.997 151 | DBPedia-entity | 0.418 | 0.480 | 0.480 | 0.494 152 | SCIDOCS | 0.340 | 0.358 | 0.358 | 0.366 153 | Fever | 0.923 | 0.941 | 0.944 | 0.947 154 | Climate-Fever | 0.522 | 0.552 | 0.569 | 0.556 155 | SciFact | 0.872 | 0.899 | 0.911 | 0.900 156 | CQADupStack | 0.681 | 0.714 | 0.729 | 0.740 157 | Avg | 0.596 | 0.625 | 0.632 | 0.634 158 | Avg w/o MS MARCO | 0.580 | 0.609 | 0.616 | 0.619 159 | 160 | 161 | # Released Model Checkpoints 162 | 163 | We have released the following checkpoints for SentenceT5 and GTR pre-trained models: 164 | 165 | * **SentenceT5-Base** ([config](t5x_retrieval/configs/models/de_t5_base.gin), 110M parameters): [gs://t5-data/pretrained_models/t5x/retrieval/st5_base](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/retrieval/st5_base) 166 | * **SentenceT5-Large** ([config](t5x_retrieval/configs/models/de_t5_large.gin), 335M parameters): 
[gs://t5-data/pretrained_models/t5x/retrieval/st5_large](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/retrieval/st5_large/) 167 | * **SentenceT5-XL** ([config](t5x_retrieval/configs/models/de_t5_3B.gin), 1.24B parameters): [gs://t5-data/pretrained_models/t5x/retrieval/st5_xl](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/retrieval/st5_xl/) 168 | * **SentenceT5-XXL** ([config](t5x_retrieval/configs/models/de_t5_11B.gin), 4.8B parameters): [gs://t5-data/pretrained_models/t5x/retrieval/st5_xxl](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/retrieval/st5_xxl/) 169 | * **GTR-Base** ([config](t5x_retrieval/configs/models/de_t5_base.gin), 110M parameters): [gs://t5-data/pretrained_models/t5x/retrieval/gtr_base](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/retrieval/gtr_base/) 170 | * **GTR-Large** ([config](t5x_retrieval/configs/models/de_t5_large.gin), 335M parameters): [gs://t5-data/pretrained_models/t5x/retrieval/gtr_large](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/retrieval/gtr_large/) 171 | * **GTR-XL** ([config](t5x_retrieval/configs/models/de_t5_3B.gin), 1.24B parameters): [gs://t5-data/pretrained_models/t5x/retrieval/gtr_xl](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/retrieval/gtr_xl/) 172 | * **GTR-XXL** ([config](t5x_retrieval/configs/models/de_t5_11B.gin), 4.8B parameters): [gs://t5-data/pretrained_models/t5x/retrieval/gtr_xxl](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/retrieval/gtr_xxl/) 173 | 174 | 175 | # References 176 | 177 | [1] Jianmo, Ni, Gustavo Hernández Ábrego, Noah Constant, Ji Ma, Keith B. Hall, 178 | Daniel Cer, Yinfei Yang. 179 | [Sentence-t5: Scalable sentence encoders from pre-trained text-to-text models.](https://arxiv.org/abs/2108.08877) 180 | ACL 2022. 
181 | 182 | [2] Jianmo Ni, Chen Qu, Jing Lu, Zhuyun Dai, Gustavo Hernández Ábrego, 183 | Ji Ma, Vincent Zhao, Yi Luan, Keith B. Hall, Ming-wei Chang, Yinfei Yang. 184 | [Large Dual Encoders Are Generalizable Retrievers.](https://arxiv.org/abs/2112.07899) 185 | December 2021. 186 | 187 | This is not an officially supported Google product. 188 | -------------------------------------------------------------------------------- /t5x_retrieval/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Retrieval Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2022 Google LLC 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | """Utility functions for training and evaluation. 
# ===== Datasets ===== #
def get_batch_unmixed_dataset(
    cfg: DatasetConfig,
    shard_id: int,
    num_shards: int,
    feature_converter_cls: Type[seqio.FeatureConverter],
    num_epochs: Optional[int] = None,
    continue_from_last_checkpoint: bool = False) -> tf.data.Dataset:
  """Returns a dataset by sampling each batch from each single task.

  One batched dataset is built per task of the mixture named by
  `cfg.mixture_or_task_name`; the returned dataset samples whole batches from
  those per-task datasets according to the mixture rates.

  Args:
    cfg: dataset configuration. `cfg.mixture_or_task_name` must name a SeqIO
      Mixture. It is temporarily rewritten while the per-task datasets are
      built and is restored before returning.
    shard_id: index of this data shard.
    num_shards: total number of data shards; must divide `cfg.batch_size`.
    feature_converter_cls: feature converter passed on to
      `t5x_utils.get_dataset_inner`.
    num_epochs: ignored; the per-task datasets are always built with
      `num_epochs=None`.
    continue_from_last_checkpoint: must be False; not supported here.

  Returns:
    A `tf.data.Dataset` sampled from the per-task batched datasets.

  Raises:
    ValueError: if `continue_from_last_checkpoint` is set, if the batch size
      is not divisible by `num_shards`, or if `cfg.mixture_or_task_name` does
      not resolve to a `seqio.Mixture`.
  """
  if continue_from_last_checkpoint:
    raise ValueError(
        '`continue_from_last_checkpoint` must be set to False as this is not '
        'supported by this dataset fn.')
  del continue_from_last_checkpoint

  if cfg.batch_size % num_shards:
    raise ValueError(
        f'Batch size ({cfg.batch_size}) must be divisible by number of '
        f'shards ({num_shards}).')

  shard_info = seqio.ShardInfo(index=shard_id, num_shards=num_shards)

  if cfg.seed is None:
    # Use a shared timestamp across devices as the seed.
    seed = multihost_utils.broadcast_one_to_all(np.int32(time.time()))
  else:
    seed = cfg.seed

  num_epochs = None  # repeat indefinitely.

  mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name)
  if not isinstance(mixture_or_task, seqio.Mixture):
    raise ValueError('Only SeqIO Mixture supports batch unmixed data accesss')

  datasets = []
  rates = []
  # The per-task datasets are requested by pointing `cfg` at each task in
  # turn. Restore the mixture name afterwards: leaving the last task name in
  # the caller's config would make a later call with the same `cfg` fail the
  # Mixture check above.
  # NOTE(review): assumes `get_dataset_inner` reads the name while it is
  # called rather than lazily -- confirm against t5x.
  mixture_name = cfg.mixture_or_task_name
  try:
    for task in mixture_or_task.tasks:
      cfg.mixture_or_task_name = task.name
      # Returns a batched dataset.
      datasets.append(
          t5x_utils.get_dataset_inner(cfg, shard_info, feature_converter_cls,
                                      seed, num_epochs))
      rates.append(mixture_or_task.get_rate(task))
  finally:
    cfg.mixture_or_task_name = mixture_name
  return tf.data.experimental.sample_from_datasets(datasets, rates)
84 | 85 | mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name) 86 | if not isinstance(mixture_or_task, seqio.Mixture): 87 | raise ValueError('Only SeqIO Mixture supports batch unmixed data accesss') 88 | 89 | datasets = [] 90 | rates = [] 91 | for task in mixture_or_task.tasks: 92 | cfg.mixture_or_task_name = task.name 93 | # Returns a batched dataset. 94 | datasets.append( 95 | t5x_utils.get_dataset_inner(cfg, shard_info, feature_converter_cls, 96 | seed, num_epochs)) 97 | rates.append(mixture_or_task.get_rate(task)) 98 | return tf.data.experimental.sample_from_datasets(datasets, rates) 99 | 100 | 101 | # ===== Losses ===== # 102 | # More details about alignment and uniformity loss can be found at 103 | # https://arxiv.org/pdf/2005.10242.pdf. They measure the quality of embeddings. 104 | def compute_align_loss(x: jnp.array, y: jnp.array, alpha: int = 2): 105 | loss = jnp.linalg.norm(x - y, ord=2, axis=1) 106 | return lax.pow(loss, 1.0 * alpha).mean() 107 | 108 | 109 | def compute_uniform_loss(xs: jnp.array, t: int = 2): 110 | """Computes the euclidean distance between every pair of row vectors in the input.""" 111 | distance_kernel = lambda x, y: jnp.sqrt(jnp.sum((x - y)**2)) 112 | loss = jax.vmap(lambda x: jax.vmap(lambda y: distance_kernel(x, y))(xs))(xs) 113 | return jnp.log(jnp.exp(-t * lax.pow(loss, 2.0)).mean()) 114 | 115 | 116 | def sigmoid_cross_entropy_with_logits(logits: jnp.array, labels: jnp.array): 117 | """Compute binary cross entropy loss with logits input. 
118 | 119 | Args: 120 | logits: similarity scores 121 | labels: ground-truth labels 122 | 123 | Returns: 124 | cross-entropy loss 125 | """ 126 | 127 | x, z = logits, labels 128 | # Follow the implementation in Tensorflow: 129 | # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn_impl.py#L186 130 | loss = jnp.maximum(x, 0) - x * z + jnp.log(1 + jnp.exp(-jnp.absolute(x))) 131 | return loss 132 | 133 | 134 | def apply_temperature(probs: jnp.array, temperature: float): 135 | """Apply a temperature to probabilities.""" 136 | if temperature <= 0: 137 | raise ValueError('Temperature must be positive.') 138 | inv_temp = 1.0 / temperature 139 | x = lax.pow(probs, inv_temp) 140 | y = lax.pow(1 - probs, inv_temp) 141 | return x / (x + y) 142 | 143 | 144 | def binary_cross_entropy_with_logits(logits: jnp.array, 145 | labels: jnp.array, 146 | weights=None, 147 | temperature=1.0): 148 | """Binary cross-entropy loss function. 149 | 150 | This loss is used for point-wise distillation where the goal is to match the 151 | point-wise logits for each example to match the gold labels. The temperature 152 | is commonly added on both the logits and the labels as described in previous 153 | distillation methods: https://arxiv.org/abs/1503.02531. 154 | 155 | Args: 156 | logits: similarity scores corresponding to all pairs in the batch, with 157 | shape [batch_size]. 158 | labels: labels for each pair. 159 | weights: weights for each pair. 160 | temperature: temperature for the distillation. 161 | 162 | Returns: 163 | binary cross-entropy loss for distillation. 164 | """ 165 | logits = logits / temperature 166 | labels = apply_temperature(labels, temperature) 167 | loss = sigmoid_cross_entropy_with_logits(logits, labels) 168 | 169 | if weights: 170 | loss = loss * weights 171 | # Multiply the loss by temperature^2, to keep it on the same scale as the 172 | # batch softmax loss. 
173 | return (temperature**2) * loss 174 | 175 | 176 | def sparse_labels_for_in_batch_cross_entropy(logits: jnp.array) -> jnp.array: 177 | """Generates labels assuming the diagnoal |logits| are ground truth.""" 178 | return jnp.arange(logits.shape[0]) 179 | 180 | 181 | def in_batch_cross_entropy( 182 | logits: jnp.array, 183 | labels: Optional[jnp.array] = None, 184 | weights: Optional[jnp.array] = None, 185 | off_diagonal_positive_mask: Optional[jnp.array] = None, 186 | num_negatives: jnp.int32 = 0, 187 | label_smoothing: float = 0.0, 188 | reduce_fn: Optional[Callable[[jnp.array], 189 | jnp.array]] = jnp.mean) -> jnp.array: 190 | """In batch cross-entropy loss function. 191 | 192 | This corresponds to computing a softmax for each row (and optionally each 193 | column) of the similarities matrix, where the diagonal element corresponds 194 | to the only positive pair, and off-diagonal elements are random negatives. 195 | 196 | Args: 197 | logits: [batch_size, batch_size + sample_size]. Tensor of similarities 198 | between all pairs of a left and a right element in a batch, with shape 199 | [batch_size, batch_size + sample_size] where sample_size is the number of 200 | extra negative right examples, if any. 201 | labels: [batch_size, batch_size + sample_size]. If None, then this function 202 | generates one hot labels which assumes the diagonal elements correspond to 203 | the ground truth. 204 | weights: [batch_size]. Weights for each pair (or row). 205 | off_diagonal_positive_mask: a Boolean Tensor of [batch_size, batch_size]. 206 | A Tensor masks the off-diagonal positives. With off-diagonal positives 207 | masked by 1, and others with 0, including the diagonal positives. 208 | num_negatives: number of negative examples. Default to 0. 209 | label_smoothing: label smoothing constant, used to determine the on and off 210 | values. 211 | reduce_fn: Callable on how to reduce losses to a scalar. If none, returns 212 | row-wise loss. 
213 | 214 | Returns: 215 | [batch_size] array of cross entropy losses if reduce_fn is None. Otherwise, 216 | return a scalar of reduced row loss. 217 | """ 218 | if (num_negatives > 0) and (off_diagonal_positive_mask is None): 219 | raise ValueError( 220 | 'num_negatives (positive integer) is used only in create ' 221 | 'off_diagonal_positive_mask.') 222 | 223 | normalizing_constant = 0.0 224 | if labels is None: 225 | num_classes = logits.shape[-1] 226 | sparse_labels = sparse_labels_for_in_batch_cross_entropy(logits) 227 | confidence, low_confidence, normalizing_constant = 1.0, 0.0, 0.0 228 | if num_classes > 1: 229 | confidence = 1.0 - label_smoothing 230 | low_confidence = (1.0 - confidence) / (num_classes - 1) 231 | normalizing_constant = -( 232 | confidence * jnp.log(confidence) + 233 | (num_classes - 1) * low_confidence * jnp.log(low_confidence + 1e-20)) 234 | labels = common_utils.onehot( 235 | sparse_labels, 236 | num_classes, 237 | on_value=confidence, 238 | off_value=low_confidence) 239 | 240 | if off_diagonal_positive_mask is None: 241 | row_loss, _ = t5x_losses.cross_entropy_with_logits( 242 | logits, labels, z_loss=0.0) 243 | else: 244 | off_diagonal_positive_mask = off_diagonal_positive_mask * ( 245 | 1 - jnp.identity(logits.shape[0])) 246 | if num_negatives > 0: 247 | off_diagonal_positive_mask = jnp.pad( 248 | off_diagonal_positive_mask, [(0, 0), (0, num_negatives)]) 249 | masked_logits = jnp.where(off_diagonal_positive_mask, -1e+9, logits) 250 | row_loss, _ = t5x_losses.cross_entropy_with_logits( 251 | masked_logits, labels, z_loss=0.0) 252 | 253 | row_loss = row_loss - normalizing_constant 254 | if weights: 255 | row_loss = row_loss * weights 256 | 257 | return reduce_fn(row_loss) if reduce_fn is not None else row_loss 258 | 259 | 260 | def get_off_diagonal_positive_mask(reference_labels: jnp.ndarray, 261 | return_dtype: jnp.dtype = bool): 262 | """Construct mask for off diagonal positives from reference labels. 
263 | 264 | For a given batch, the off diagonal mask is constructed by comparing the 265 | reference labels across examples. E.g. if the batch size is 3, and 266 | the reference label is [11, 22, 11], the off diagonal positive mask is 267 | [[0, 0, 1], 268 | [0, 0, 0], 269 | [1, 0, 0]] 270 | indicating the 1st and the 3nd are mutually positive, as their labels are 271 | same. 272 | 273 | Args: 274 | reference_labels: a [batch] Tensor, which stores the reference information 275 | to construct the mask. 276 | return_dtype: dtype for returned mask. 277 | 278 | Returns: 279 | mask: a [batch, batch] Boolean Tensor for off-diagonal positive mask. 280 | """ 281 | reference_labels = jnp.squeeze(reference_labels) 282 | if reference_labels.ndim > 1: 283 | raise ValueError('Provide only 1 reference label per example in batch! ' 284 | 'The reference_labels should be of the shape [batch].') 285 | mask = (reference_labels[None, :] == reference_labels[:, None] 286 | - jnp.identity(reference_labels.shape[0], dtype=bool)) 287 | return mask.astype(return_dtype) 288 | 289 | 290 | # ===== Metrics ===== # 291 | @flax.struct.dataclass 292 | class AUC(clu.metrics.CollectingMetric.from_outputs(('labels', 'logits'))): 293 | 294 | def compute(self): 295 | labels_sum = jnp.sum(self.values['labels']) 296 | # Do not compute AUC if positives only have one class. 297 | if labels_sum == 0 or labels_sum == len(self.values['labels']): 298 | return 0.0 299 | return sklearn.metrics.roc_auc_score(self.values['labels'], 300 | self.values['logits']) 301 | 302 | 303 | def compute_auc(targets: jnp.array, 304 | predictions: jnp.array, 305 | targets_threshold=None): 306 | """Compute Area Under the ROC and PR curves. 307 | 308 | ROC - Receiver Operating Characteristic 309 | PR - Precision and Recall 310 | 311 | Args: 312 | targets: np.ndarray of targets, either 0 or 1, or continuous values. 313 | predictions: np.ndarray of predictions, any value. 
314 | targets_threshold: float, if target values are continuous values, this 315 | threshold binarizes them. 316 | 317 | Returns: 318 | A dictionary with AUC-ROC and AUC-PR scores. 319 | """ 320 | 321 | if targets_threshold is not None: 322 | targets = jnp.array(targets) 323 | targets = jnp.where(targets < targets_threshold, 324 | jnp.zeros_like(targets, dtype=jnp.int32), 325 | jnp.ones_like(targets, dtype=jnp.int32)) 326 | 327 | a = jnp.min(predictions) 328 | b = jnp.max(predictions) 329 | scale = 3.0 / (b - a + 1e-6) 330 | scaled_predictions = scale * (2 * predictions - b - a) 331 | transformed_predictions = jax.nn.sigmoid(scaled_predictions) 332 | binarized_targets = jnp.round(targets) 333 | return { 334 | 'auc-roc': 335 | AUC.from_model_output( 336 | logits=transformed_predictions, labels=binarized_targets), 337 | } 338 | 339 | 340 | def compute_rr(logits: jnp.array, labels: jnp.array): 341 | """Compute Reciprocal Rank for in-batch examples. 342 | 343 | Args: 344 | logits: jnp.array of logits of shape [batch_size, batch_size] 345 | labels: jnp.array of indices indicating the positive example. 346 | 347 | Returns: 348 | An jnp.array of reciprocal rank of the positive example in-batch. 349 | """ 350 | labels = jnp.expand_dims(labels, axis=-1) 351 | logits_desc = np.argsort(-logits, axis=-1) 352 | rank = ( 353 | jnp.argwhere(logits_desc == labels, size=logits_desc.shape[0])[:, -1] + 1) 354 | return jnp.reciprocal(rank) 355 | 356 | 357 | # ===== Checkpoint ===== # 358 | def partially_load_checkpoint( 359 | excluded_patterns: Sequence[str], 360 | require_all_rules_match: bool = True 361 | ) -> t5x_checkpoints.RestoreStateTransformationFn: 362 | """Load a checkpoint partially, used in exports to trim the output SavedModel graph. 363 | 364 | Args: 365 | excluded_patterns: Checkpoint Optimizer param patterns to exclude from the 366 | export. 367 | require_all_rules_match: Whether to verify that all the patterns match 368 | correctly to a path in the checkpoint. 
369 | 370 | Returns: 371 | A RestoreStateTransformationFn that excludes the pattern specified in the 372 | """ 373 | assignment_map = [(pattern, None) for pattern in excluded_patterns] 374 | 375 | def _wrapped_assignment_map( 376 | ckpt_optimizer_state: PyTreeDef, 377 | _: PyTreeDef, # pylint: disable=unused-argument 378 | *, 379 | is_resuming: bool = False): 380 | """Remap the optimizer state to checkpoint optimizer state. 381 | 382 | Setting assignment maps in RestoreCheckpointConfig in load_t5x_checkpoint 383 | sets optimizer state to an empty dict, failing the assignments. 384 | Args: 385 | ckpt_optimizer_state: Checkpoint param state. 386 | is_resuming: `True` iff this restore call is due to a job resuming after 387 | being temporarily stopped due to, for example, a preemption. This is 388 | useful when there is restore logic that should run when restoring from 389 | some pre-existing checkpoint, but that should not run again when 390 | resuming from a newly-written checkpoint. 391 | 392 | Returns: 393 | The result of transforming the checkpoint state dict. 394 | """ 395 | return t5x_state_utils.apply_assignment_map( 396 | ckpt_optimizer_state, 397 | ckpt_optimizer_state, # Optimizer State 398 | assignment_map, 399 | require_all_rules_match, 400 | is_resuming=is_resuming) 401 | 402 | return _wrapped_assignment_map 403 | 404 | 405 | def load_tower(side: str) -> t5x_checkpoints.RestoreStateTransformationFn: 406 | """Load a single `side` tower of an Asymmetric Dual Encoder. 407 | 408 | Args: 409 | side: The side of the tower to load. Available values are [left, right] 410 | 411 | Returns: 412 | A restore state transformation function that filters out the weights of the 413 | other tower. Only set if inference mode is set to `encode_{side}`. 
414 | """ 415 | assert side in ('left', 'right'), ( 416 | f'Expected side to be one of [left, right], but is {side}') 417 | if side == 'left': 418 | return partially_load_checkpoint(excluded_patterns=[r'.*right_.*']) 419 | else: 420 | return partially_load_checkpoint(excluded_patterns=[r'.*left_.*']) 421 | -------------------------------------------------------------------------------- /t5x_retrieval/models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Retrieval Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2022 Google LLC 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | """T5X Retrieval Models. 
29 | 30 | This module uses Flaxformer modules to build a higher-level model structure and 31 | define methods for the loss computation as well as a train, prediction, and 32 | evaluation steps. 33 | """ 34 | # pylint: disable=attribute-defined-outside-init,g-bare-generic,g-multiple-import 35 | from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union 36 | 37 | from flax import linen as nn 38 | from flax import optim 39 | from flax.core import scope as flax_scope 40 | from flax.training import common_utils 41 | import jax 42 | from jax import lax 43 | import jax.numpy as jnp 44 | import ml_collections 45 | import numpy as np 46 | 47 | import seqio 48 | from t5x import losses as t5x_losses 49 | from t5x import metrics as metrics_lib 50 | from t5x import models as t5x_models 51 | from t5x import utils as t5x_utils 52 | from t5x_retrieval import losses 53 | from t5x_retrieval import utils 54 | import tensorflow as tf 55 | 56 | Array = Union[np.ndarray, jnp.ndarray, jax.pxla.ShardedDeviceArray, tf.Tensor] 57 | DType = jnp.dtype 58 | ConfigDict = ml_collections.ConfigDict 59 | PyTreeDef = type(jax.tree_util.tree_structure(None)) 60 | Optimizer = optim.Optimizer 61 | 62 | 63 | LEFT_ENCODINGS = 'left_encodings' 64 | RIGHT_ENCODINGS = 'right_encodings' 65 | 66 | 67 | class DualEncoderBase(t5x_models.BaseTransformerModel): 68 | """Base class for dual encoder models. 69 | 70 | Subclasses must implement `score_batch` and `_compute_logits`. 
71 | """ 72 | 73 | FEATURE_CONVERTER_CLS: Callable[..., seqio.FeatureConverter] 74 | 75 | ALLOWED_INFERENCE_MODE = frozenset({'encode', 'similarity'}) 76 | 77 | # TODO(b/262639556): Change loss_module from Optional to required once 78 | # loss-layers have been implemented 79 | def __init__( 80 | self, 81 | module: nn.Module, 82 | feature_converter_cls: Callable[[bool], seqio.FeatureConverter], 83 | input_vocabulary: seqio.Vocabulary, 84 | output_vocabulary: seqio.Vocabulary, 85 | optimizer_def: optim.OptimizerDef, 86 | inference_mode: str = 'encode', 87 | loss_module_factory: Optional[nn.Module] = None, 88 | ): 89 | self.FEATURE_CONVERTER_CLS = feature_converter_cls # pylint: disable=invalid-name 90 | self._inference_mode = inference_mode 91 | 92 | self.loss_module = None 93 | # TODO(b/262639556): Remove check once loss-layer is not Optional 94 | if loss_module_factory: 95 | self.loss_module = loss_module_factory() 96 | self.loss_module.validate_model_features(feature_converter_cls(False)) 97 | 98 | super(DualEncoderBase, self).__init__( 99 | module=module, 100 | input_vocabulary=input_vocabulary, 101 | output_vocabulary=output_vocabulary, 102 | optimizer_def=optimizer_def) 103 | 104 | def get_initial_variables( 105 | self, 106 | rng: jnp.ndarray, 107 | input_shapes: Mapping[str, Array], 108 | input_types: Optional[Mapping[str, DType]] = None 109 | ) -> flax_scope.FrozenVariableDict: 110 | """Get the initial variables for an dual-encoder model.""" 111 | input_types = {} if input_types is None else input_types 112 | encoder_type = input_types.get('left_encoder_input_tokens', jnp.float32) 113 | left_encoder_shape = input_shapes['left_encoder_input_tokens'] 114 | right_encoder_shape = input_shapes['right_encoder_input_tokens'] 115 | initial_variables = self.module.init( 116 | rng, 117 | jnp.ones(left_encoder_shape, encoder_type), 118 | jnp.ones(right_encoder_shape, encoder_type), 119 | enable_dropout=False) 120 | 121 | # TODO(b/262639556): Remove check once loss-layer 
is not Optional 122 | if self.loss_module: 123 | loss_variables = self.loss_module.get_initial_variables( 124 | rng, input_shapes, input_types) 125 | initial_variables = initial_variables.copy(loss_variables) 126 | 127 | return initial_variables 128 | 129 | def loss_weights(self, batch: Mapping[str, 130 | jnp.ndarray]) -> Optional[jnp.ndarray]: 131 | raise NotImplementedError('Not implemented for dual encoder.') 132 | 133 | def predict_batch_with_aux( 134 | self, 135 | params: Mapping[str, Array], 136 | batch: Mapping[str, jnp.ndarray], 137 | rng: Optional[jax.random.KeyArray] = None, 138 | ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: 139 | raise NotImplementedError( 140 | 'Autoregressive prediction is not implemented for dual encoder.') 141 | 142 | def _encode_batch(self, params: Mapping[str, Array], 143 | batch: Mapping[str, jnp.ndarray]) -> Array: 144 | """Encode the embeddings for the inputs.""" 145 | return self.module.apply( 146 | {'params': params}, 147 | batch['left_encoder_input_tokens'], 148 | # Disable the dropout during inference. 149 | enable_dropout=False, 150 | method=self.module.encode) 151 | 152 | def _similarity_batch(self, 153 | params: Mapping[str, Array], 154 | batch: Mapping[str, jnp.ndarray], 155 | return_intermediates: bool = False) -> Array: 156 | """Score the similarity of the left and right inputs.""" 157 | left_encodings, right_encodings, logits = self.module.apply( 158 | {'params': params}, 159 | batch['left_encoder_input_tokens'], 160 | batch['right_encoder_input_tokens'], 161 | enable_dropout=False) 162 | if return_intermediates: 163 | return logits, { 164 | LEFT_ENCODINGS: (left_encodings,), 165 | RIGHT_ENCODINGS: (right_encodings,), 166 | } 167 | else: 168 | return logits 169 | 170 | def score_batch(self, 171 | params: Mapping[str, Array], 172 | batch: Mapping[str, jnp.ndarray], 173 | return_intermediates: bool = False) -> jnp.ndarray: 174 | """Model prediction for batch. 175 | 176 | Args: 177 | params: Model parameters. 
178 | batch: A batch of inputs. 179 | return_intermediates: Whether to return intermediates. 180 | 181 | Returns: 182 | an array of encodings or similarity scores (with optional intermediates). 183 | """ 184 | if self._inference_mode not in self.ALLOWED_INFERENCE_MODE: 185 | raise ValueError( 186 | 'Invalid `inference_mode`: %s. Supported inference mode: %s.' % 187 | (self._inference_mode, self.ALLOWED_INFERENCE_MODE)) 188 | if self._inference_mode == 'encode': 189 | return self._encode_batch(params, batch) 190 | elif self._inference_mode == 'similarity': 191 | return self._similarity_batch(params, batch, return_intermediates) 192 | 193 | 194 | class DualEncoderModel(DualEncoderBase): 195 | """Model class for Dual Encoder.""" 196 | 197 | ALLOWED_INFERENCE_MODE = frozenset({ 198 | 'encode', 'encode_left', 'encode_right', 'similarity', 199 | 'pointwise_similarity' 200 | }) 201 | 202 | def __init__( 203 | self, 204 | module: nn.Module, 205 | feature_converter_cls: Callable[[bool], seqio.FeatureConverter], 206 | input_vocabulary: seqio.Vocabulary, 207 | output_vocabulary: seqio.Vocabulary, 208 | optimizer_def: optim.OptimizerDef, 209 | loss_module_factory: Optional[losses.DualEncoderLoss] = None, 210 | inference_mode: str = 'encode', 211 | use_negatives: bool = False, 212 | use_align_uniform: bool = False, 213 | logit_scale: float = 100, 214 | logit_margin: float = 0.0, 215 | ): 216 | """Initialization function. 217 | 218 | Args: 219 | module: Flax module. 220 | feature_converter_cls: SeqIO feature converters to apply to the dataset. 221 | input_vocabulary: Vocabulary for the input features. 222 | output_vocabulary: Vocabulary for the output features. 223 | optimizer_def: Optimizer. 224 | loss_module_factory: Factory to produce loss module. 225 | inference_mode: Inference mode (e.g. encode or similarity). 226 | use_negatives: Whether to use hard negatives. If True, the model encodes 227 | the additional feature for hard negatives on the right tower. 
228 | use_align_uniform: Whether to compute alignment and uniformity metrics. If 229 | True, compute alignment and uniformity metrics. 230 | logit_scale: A factor for logits scaling. 231 | logit_margin: A constant for logits margin. 232 | """ 233 | self._use_negatives = use_negatives 234 | self._use_align_uniform = use_align_uniform 235 | self._logit_scale = logit_scale 236 | self._logit_margin = logit_margin 237 | super(DualEncoderModel, self).__init__( 238 | module=module, 239 | feature_converter_cls=feature_converter_cls, 240 | input_vocabulary=input_vocabulary, 241 | output_vocabulary=output_vocabulary, 242 | optimizer_def=optimizer_def, 243 | inference_mode=inference_mode, 244 | loss_module_factory=loss_module_factory) 245 | 246 | def _compute_logits( 247 | self, 248 | params: Mapping[str, Any], 249 | batch: Mapping[str, jnp.ndarray], 250 | dropout_rng: Optional[jnp.ndarray] = None 251 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: 252 | """Computes logits via a forward pass of `self.module_cls`.""" 253 | # Dropout is provided only for the training mode. 254 | rngs = {'dropout': dropout_rng} if dropout_rng is not None else None 255 | 256 | if not self._use_negatives and 'right_negative_encoder_input_tokens' in batch: 257 | ValueError( 258 | 'Invalid module. Please select `DualEncoderWithNegativesModel` for negative inputs.' 259 | ) 260 | 261 | if self._use_negatives and 'right_negative_encoder_input_tokens' not in batch: 262 | ValueError( 263 | 'Invalid inputs. Please prepare negative inputs for DualEncoderWithNegativesModel.' 264 | ) 265 | 266 | if self._use_negatives: 267 | left_tokens = batch['left_encoder_input_tokens'] 268 | right_positive_tokens = batch['right_encoder_input_tokens'] 269 | right_negative_tokens = batch['right_negative_encoder_input_tokens'] 270 | 271 | # left/right_encoder_input_tokens should be 2d tensor. 
272 | assert left_tokens.ndim == 2 273 | assert right_positive_tokens.ndim == 2 274 | # right_negative_encoder_input_tokens can be a 2d tensor (when feature 275 | # spec set up for single negative) or a 3d tensor (when feature spec 276 | # set up for multiple negatives). 277 | assert right_negative_tokens.ndim == 2 or right_negative_tokens.ndim == 3 278 | 279 | # All tensors should have the same batch size. 280 | batch_size = right_positive_tokens.shape[0] 281 | assert left_tokens.shape[0] == batch_size 282 | assert right_negative_tokens.shape[0] == batch_size 283 | 284 | if right_negative_tokens.ndim == 3: 285 | # We have multiple negatives, so need to reshape the 286 | # right_negative_encoder_input_tokens. 287 | 288 | # Right positive and negative should have the same sequence length. 289 | right_seq_length = right_positive_tokens.shape[1] 290 | assert right_seq_length == right_negative_tokens.shape[2] 291 | 292 | num_negatives = right_negative_tokens.shape[1] 293 | right_negative_tokens = jnp.reshape( 294 | right_negative_tokens, 295 | (batch_size * num_negatives, right_seq_length)) 296 | 297 | (left_encodings, right_encodings, 298 | logits), _ = self.module.apply({'params': params}, 299 | left_tokens, 300 | right_positive_tokens, 301 | right_negative_tokens, 302 | enable_dropout=rngs is not None, 303 | rngs=rngs, 304 | mutable='dropout') 305 | 306 | # `left_logits` is of shape [B, B*(1+num_negatives)] that considers the 307 | # negatives while `right_logits` is in shape [B, B] that doesn't considers 308 | # negatives. `num_negatives` could be greater than 1 in the future. 
309 | left_logits, right_logits = logits, jnp.dot(right_encodings, 310 | left_encodings.transpose()) 311 | else: 312 | (left_encodings, right_encodings, logits), _ = self.module.apply( 313 | {'params': params}, 314 | batch['left_encoder_input_tokens'], 315 | batch['right_encoder_input_tokens'], 316 | enable_dropout=rngs is not None, 317 | rngs=rngs, 318 | mutable='dropout') 319 | left_logits, right_logits = logits, logits.transpose() 320 | 321 | left_logits *= self._logit_scale 322 | right_logits *= self._logit_scale 323 | 324 | # Only additive margin to the logits for training mode. 325 | # For details please check https://arxiv.org/abs/1902.08564. The tensor 326 | # shapes are not changed after scaling. 327 | if dropout_rng is not None and self._logit_margin != 0: 328 | left_logits = ( 329 | left_logits - self._logit_margin * 330 | jnp.eye(N=left_logits.shape[0], M=left_logits.shape[1])) 331 | right_logits = ( 332 | right_logits - self._logit_margin * jnp.eye(right_logits.shape[0])) 333 | 334 | return left_encodings, right_encodings, left_logits, right_logits 335 | 336 | def loss_fn( 337 | self, 338 | params: Mapping[str, Any], 339 | batch: Mapping[str, jnp.ndarray], 340 | dropout_rng: Optional[jnp.ndarray], 341 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 342 | """Loss function used for training with a cross-entropy loss.""" 343 | 344 | left_encodings, right_encodings, left_logits, right_logits = self._compute_logits( 345 | params, batch, dropout_rng) 346 | loss, metrics = self.loss_module.apply({'params': params}, 347 | batch=batch, 348 | logits=left_logits, 349 | right_logits=right_logits) 350 | if self._use_align_uniform: 351 | align_loss = utils.compute_align_loss(left_encodings, right_encodings) 352 | uniform_loss = utils.compute_uniform_loss( 353 | left_encodings) + utils.compute_uniform_loss(right_encodings) 354 | metrics.update({ 355 | 'align_loss': 356 | metrics_lib.AveragePerStep.from_model_output(align_loss), 357 | 'uniform_loss': 358 | 
metrics_lib.AveragePerStep.from_model_output(uniform_loss), 359 | }) 360 | 361 | return loss, metrics 362 | 363 | def _encode_batch(self, params: Mapping[str, Array], 364 | batch: Mapping[str, jnp.ndarray]) -> Array: 365 | """Encode the embeddings for the inputs.""" 366 | if self._inference_mode == 'encode_right': 367 | encoder_input_tokens = batch['right_encoder_input_tokens'] 368 | else: 369 | encoder_input_tokens = batch['left_encoder_input_tokens'] 370 | return self.module.apply( 371 | {'params': params}, 372 | encoder_input_tokens, 373 | # Disable the dropout during inference. 374 | enable_dropout=False, 375 | method=self.module.encode) 376 | 377 | def score_batch(self, 378 | params: Mapping[str, Array], 379 | batch: Mapping[str, jnp.ndarray], 380 | return_intermediates: bool = False) -> jnp.ndarray: 381 | """Model prediction for batch. 382 | 383 | Args: 384 | params: Model parameters. 385 | batch: A batch of inputs. 386 | return_intermediates: Whether to return intermediates. 387 | 388 | Returns: 389 | an array of encodings or similarity scores (with optional intermediates). 390 | """ 391 | if self._inference_mode not in self.ALLOWED_INFERENCE_MODE: 392 | raise ValueError( 393 | 'Invalid `inference_mode`: %s. Supported inference mode: %s.' 
% 394 | (self._inference_mode, self.ALLOWED_INFERENCE_MODE)) 395 | if self._inference_mode.startswith('encode'): 396 | return self._encode_batch(params, batch) 397 | elif self._inference_mode == 'similarity': 398 | return self._similarity_batch(params, batch, return_intermediates) 399 | elif self._inference_mode == 'pointwise_similarity': 400 | if return_intermediates: 401 | logits, intermediates = ( 402 | self._similarity_batch(params, batch, return_intermediates)) 403 | return jnp.diagonal(logits), intermediates 404 | else: 405 | logits = self._similarity_batch(params, batch, return_intermediates) 406 | return jnp.diagonal(logits) 407 | 408 | 409 | --------------------------------------------------------------------------------