├── t5x_retrieval
├── testdata
│ └── test_de_t5_tiny.checkpoint_1000
│ │ ├── checkpoint
│ │ ├── target.encoder.layers_0.mlp.wi.kernel
│ │ ├── .zarray
│ │ └── 0.0
│ │ ├── target.encoder.layers_0.mlp.wo.kernel
│ │ ├── .zarray
│ │ └── 0.0
│ │ ├── target.encoder.layers_1.mlp.wi.kernel
│ │ ├── .zarray
│ │ └── 0.0
│ │ ├── target.encoder.layers_1.mlp.wo.kernel
│ │ ├── .zarray
│ │ └── 0.0
│ │ ├── target.token_embedder.embedding
│ │ ├── .zarray
│ │ ├── 0.0
│ │ ├── 1.0
│ │ ├── 2.0
│ │ ├── 3.0
│ │ ├── 4.0
│ │ ├── 5.0
│ │ ├── 6.0
│ │ └── 7.0
│ │ ├── target.encoder.layers_0.attention.key.kernel
│ │ ├── .zarray
│ │ └── 0.0
│ │ ├── target.encoder.layers_0.attention.out.kernel
│ │ ├── .zarray
│ │ └── 0.0
│ │ ├── target.encoder.layers_0.attention.query.kernel
│ │ ├── .zarray
│ │ └── 0.0
│ │ ├── target.encoder.layers_0.attention.value.kernel
│ │ ├── .zarray
│ │ └── 0.0
│ │ ├── target.encoder.layers_1.attention.key.kernel
│ │ ├── .zarray
│ │ └── 0.0
│ │ ├── target.encoder.layers_1.attention.out.kernel
│ │ ├── .zarray
│ │ └── 0.0
│ │ ├── target.encoder.layers_1.attention.query.kernel
│ │ ├── .zarray
│ │ └── 0.0
│ │ └── target.encoder.layers_1.attention.value.kernel
│ │ ├── .zarray
│ │ └── 0.0
├── configs
│ ├── runs
│ │ ├── trainer.gin
│ │ ├── infer_eval.gin
│ │ ├── pretrain.gin
│ │ ├── eval.gin
│ │ ├── infer.gin
│ │ └── finetune.gin
│ ├── models
│ │ ├── de_t5_3B.gin
│ │ ├── de_t5_11B.gin
│ │ ├── de_t5_large.gin
│ │ ├── de_mt5_xl.gin
│ │ ├── de_mt5_large.gin
│ │ ├── de_mt5_xxl.gin
│ │ ├── de_t5_1_1_xl.gin
│ │ ├── de_t5_1_1_xxl.gin
│ │ ├── de_t5_1_1_large.gin
│ │ ├── de_longt5_1_1_transient_global_large.gin
│ │ ├── de_t5_tiny.gin
│ │ ├── de_mt5_small.gin
│ │ ├── de_mt5_tiny.gin
│ │ ├── de_t5_1_1_tiny.gin
│ │ ├── de_t5_base.gin
│ │ ├── de_t5_1_1_base.gin
│ │ ├── de_mt5_base.gin
│ │ └── de_longt5_1_1_transient_global_base.gin
│ └── architectures
│ │ ├── de_t5_1_1_flaxformer.gin
│ │ ├── de_t5_flaxformer.gin
│ │ └── de_longt5_1_1_transient_global_flaxformer.gin
├── version.py
├── partitioning.py
├── __init__.py
├── adafactor_utils.py
├── metrics.py
├── postprocessors.py
├── preprocessors.py
├── tasks.py
├── losses.py
├── feature_converters.py
├── utils.py
└── models.py
├── CONTRIBUTING.md
├── LICENSE
└── README.md
/t5x_retrieval/testdata/test_de_t5_tiny.checkpoint_1000/checkpoint:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/t5x_retrieval/HEAD/t5x_retrieval/testdata/test_de_t5_tiny.checkpoint_1000/checkpoint
--------------------------------------------------------------------------------
/t5x_retrieval/configs/runs/trainer.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | from t5x import trainer
4 |
5 | # Parameters for trainer.Trainer:
6 | trainer.Trainer.weight_metrics_computer = @trainer.WeightMetricsComputer()
7 |
--------------------------------------------------------------------------------
/t5x_retrieval/testdata/test_de_t5_tiny.checkpoint_1000/target.encoder.layers_0.mlp.wi.kernel/.zarray:
--------------------------------------------------------------------------------
1 | {"chunks":[4,8],"compressor":{"id":"gzip","level":1},"dtype":" to see your current agreements on file or
13 | to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code Reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 |
--------------------------------------------------------------------------------
/t5x_retrieval/configs/architectures/de_t5_1_1_flaxformer.gin:
--------------------------------------------------------------------------------
1 | # Flaxformer implementation of Dual Encoder, based on T5.1.1 architecture.
2 | #
3 | # Required to be overridden:
4 | #
5 | # - NUM_HEADS
6 | # - NUM_LAYERS
7 | # - HEAD_DIM
8 | # - EMBED_DIM
9 | # - MLP_DIM
10 | # - PROJECTION_DIM
11 | from __gin__ import dynamic_registration
12 |
13 | from flax import linen
14 |
15 | from flaxformer.components import dense
16 | from flaxformer.components import layer_norm
17 |
18 | include 't5x_retrieval/configs/architectures/de_t5_flaxformer.gin'
19 |
20 | # Additional constants (may be overridden)
21 | SCALE = 1.0
22 |
23 | # Projection layer
24 | projection_layer/linen.initializers.variance_scaling:
25 | scale = %SCALE
26 |
27 | # Attention (encoder, decoder, self-attention)
28 | attention_kernel_init/linen.initializers.variance_scaling:
29 | scale = %SCALE
30 |
31 | # Relative position biases (encoder, decoder)
32 | relative_position_bias_init/linen.initializers.variance_scaling:
33 | scale = %SCALE
34 |
35 | # MLP (encoder)
36 | dense.MlpBlock:
37 | activations = ('gelu', 'linear')
38 | mlp_kernel_init/linen.initializers.variance_scaling:
39 | scale = %SCALE
40 |
41 | layer_norm.T5LayerNorm.dtype = %ACTIVATION_DTYPE
42 |
--------------------------------------------------------------------------------
/t5x_retrieval/configs/runs/infer_eval.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | import __main__ as train_script
4 | import seqio
5 | from t5x import utils
6 |
7 | # Convenience overrides.
8 | EVALUATOR_USE_MEMORY_CACHE = False
9 | EVALUATOR_NUM_EXAMPLES = None # None means use all examples in the infer_eval dataset.
10 | JSON_WRITE_N_RESULTS = None # None means write all inferences.
11 |
12 | train_script.train:
13 | infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
14 | inference_evaluator_cls = @seqio.Evaluator
15 |
16 | EVAL_MIXTURE_OR_TASK_NAME = %MIXTURE_OR_TASK_NAME
17 | EVAL_TASK_FEATURE_LENGTHS = %TASK_FEATURE_LENGTHS
18 | EVAL_BATCH_SIZE = %BATCH_SIZE
19 |
20 | infer_eval/utils.DatasetConfig:
21 | mixture_or_task_name = %EVAL_MIXTURE_OR_TASK_NAME
22 | task_feature_lengths = %EVAL_TASK_FEATURE_LENGTHS
23 | split = 'validation'
24 | batch_size = %EVAL_BATCH_SIZE
25 | shuffle = False
26 | seed = 42
27 | use_cached = %USE_CACHED_TASKS
28 | pack = False
29 | module = %MIXTURE_OR_TASK_MODULE
30 |
31 | seqio.Evaluator:
32 | logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
33 | num_examples = %EVALUATOR_NUM_EXAMPLES
34 | use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE
35 |
36 | seqio.JSONLogger:
37 | write_n_results = %JSON_WRITE_N_RESULTS
38 |
--------------------------------------------------------------------------------
/t5x_retrieval/configs/models/de_t5_base.gin:
--------------------------------------------------------------------------------
1 | # Dual Encoder based on original T5 (1.0) Base model.
2 | # Provides MODEL
3 | from __gin__ import dynamic_registration
4 |
5 | import seqio
6 | from t5x import adafactor
7 | from t5x_retrieval import feature_converters
8 | from t5x_retrieval import models
9 | from t5x_retrieval import losses
10 |
11 | ARCHITECTURE = %gin.REQUIRED
12 | LOSS_MODULE = @losses.InBatchCrossEntropyLoss
13 |
14 | include 't5x_retrieval/configs/architectures/de_t5_flaxformer.gin'
15 |
16 | # Architecture overrides
17 | NUM_HEADS = 12
18 | NUM_LAYERS = 12
19 | HEAD_DIM = 64
20 | EMBED_DIM = 768
21 | MLP_DIM = 3072
22 | PROJECTION_DIM = 768
23 |
24 | # Vocabulary (shared by encoder and decoder)
25 | VOCABULARY = @seqio.SentencePieceVocabulary()
26 | seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model"
27 |
28 | # Optimizer
29 | # `learning_rate` is set by `Trainer.learning_rate_fn`.
30 | OPTIMIZER = @adafactor.Adafactor()
31 | adafactor.Adafactor:
32 | decay_rate = 0.8
33 | step_offset = 0
34 |
35 | # Model
36 | MODEL = @models.DualEncoderModel()
37 | models.DualEncoderModel:
38 | use_negatives = False
39 | use_align_uniform = False
40 | feature_converter_cls = @feature_converters.DualEncoderFeatureConverterFactory()
41 | module = %ARCHITECTURE # provided by t5_flaxformer
42 | loss_module_factory = %LOSS_MODULE
43 | input_vocabulary = %VOCABULARY
44 | output_vocabulary = %VOCABULARY
45 | optimizer_def = %OPTIMIZER
46 |
--------------------------------------------------------------------------------
/t5x_retrieval/version.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Retrieval Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Copyright 2022 Google LLC
16 | #
17 | # Licensed under the Apache License, Version 2.0 (the "License");
18 | # you may not use this file except in compliance with the License.
19 | # You may obtain a copy of the License at
20 | #
21 | # http://www.apache.org/licenses/LICENSE-2.0
22 | #
23 | # Unless required by applicable law or agreed to in writing, software
24 | # distributed under the License is distributed on an "AS IS" BASIS,
25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 | # See the License for the specific language governing permissions and
27 | # limitations under the License.
28 | r"""Separate file for storing the current version of T5X Retrieval.
29 |
30 | Stored in a separate file so that setup.py can reference the version without
31 | pulling in all the dependencies in __init__.py.
32 | """
33 | __version__ = '0.0.0'
34 |
--------------------------------------------------------------------------------
/t5x_retrieval/configs/models/de_t5_1_1_base.gin:
--------------------------------------------------------------------------------
1 | # Dual Encoder based on T5 (1.1) Base model.
2 | # Provides MODEL
3 | from __gin__ import dynamic_registration
4 |
5 | import seqio
6 | from t5x import adafactor
7 | from t5x_retrieval import feature_converters
8 | from t5x_retrieval import models
9 | from t5x_retrieval import losses
10 |
11 | ARCHITECTURE = %gin.REQUIRED
12 | T5XR_INFERENCE_MODE = 'encode'
13 | LOSS_MODULE = @losses.InBatchCrossEntropyLoss
14 |
15 | include 't5x_retrieval/configs/architectures/de_t5_1_1_flaxformer.gin'
16 |
17 | # Architecture overrides
18 | NUM_HEADS = 12
19 | NUM_LAYERS = 12
20 | HEAD_DIM = 64
21 | EMBED_DIM = 768
22 | MLP_DIM = 2048
23 | PROJECTION_DIM = 768
24 |
25 | # Vocabulary (shared by encoder and decoder)
26 | VOCABULARY = @seqio.SentencePieceVocabulary()
27 | seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model"
28 |
29 | # Optimizer
30 | # `learning_rate` is set by `Trainer.learning_rate_fn`.
31 | OPTIMIZER = @adafactor.Adafactor()
32 | adafactor.Adafactor:
33 | decay_rate = 0.8
34 | step_offset = 0
35 |
36 | # Model
37 | MODEL = @models.DualEncoderModel()
38 | models.DualEncoderModel:
39 | use_negatives = False
40 | use_align_uniform = False
41 | inference_mode = %T5XR_INFERENCE_MODE
42 | feature_converter_cls = @feature_converters.DualEncoderFeatureConverterFactory()
43 | module = %ARCHITECTURE # provided by t5_flaxformer
44 | loss_module_factory = %LOSS_MODULE
45 | input_vocabulary = %VOCABULARY
46 | output_vocabulary = %VOCABULARY
47 | optimizer_def = %OPTIMIZER
48 |
--------------------------------------------------------------------------------
/t5x_retrieval/partitioning.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Retrieval Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Copyright 2022 Google LLC
16 | #
17 | # Licensed under the Apache License, Version 2.0 (the "License");
18 | # you may not use this file except in compliance with the License.
19 | # You may obtain a copy of the License at
20 | #
21 | # http://www.apache.org/licenses/LICENSE-2.0
22 | #
23 | # Unless required by applicable law or agreed to in writing, software
24 | # distributed under the License is distributed on an "AS IS" BASIS,
25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 | # See the License for the specific language governing permissions and
27 | # limitations under the License.
28 |
29 | """Custom, default partitioning rules for t5x retrieval models."""
30 |
31 | from t5x import partitioning
32 |
33 |
def standard_logical_axis_rules() -> partitioning.LogicalAxisRules:
  """Builds the logical axis rules specific to t5x retrieval models.

  Returns:
    A one-element tuple of `(logical_axis_name, value)` pairs; the only entry
    pairs the 'affinity' logical axis with `None`.
  """
  affinity_rule = ('affinity', None)
  return (affinity_rule,)
39 |
--------------------------------------------------------------------------------
/t5x_retrieval/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Retrieval Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Copyright 2022 Google LLC
16 | #
17 | # Licensed under the Apache License, Version 2.0 (the "License");
18 | # you may not use this file except in compliance with the License.
19 | # You may obtain a copy of the License at
20 | #
21 | # http://www.apache.org/licenses/LICENSE-2.0
22 | #
23 | # Unless required by applicable law or agreed to in writing, software
24 | # distributed under the License is distributed on an "AS IS" BASIS,
25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 | # See the License for the specific language governing permissions and
27 | # limitations under the License.
28 | """Import API modules."""
29 |
30 | import t5x_retrieval.adafactor_utils
31 | import t5x_retrieval.feature_converters
32 | import t5x_retrieval.models
33 | import t5x_retrieval.partitioning
34 | import t5x_retrieval.tasks
35 | import t5x_retrieval.utils
36 |
37 | # Version number.
38 | from t5x_retrieval.version import __version__
39 |
--------------------------------------------------------------------------------
/t5x_retrieval/configs/runs/pretrain.gin:
--------------------------------------------------------------------------------
1 | # Defaults for pretraining with train.py.
2 | #
3 | # You must also include a binding for MODEL.
4 | #
5 | # Required to be set:
6 | #
7 | # - MIXTURE_OR_TASK_NAME
8 | # - MIXTURE_OR_TASK_MODULE
9 | # - TASK_FEATURE_LENGTHS
10 | # - train.model_dir
11 | #
12 | # Commonly overridden options:
13 | #
14 | # - DatasetConfig.batch_size
15 | # - PjitPartitioner.num_partitions
16 | # - Trainer.num_microbatches
17 | #
18 | # Currently we don't support inference eval.
19 | from __gin__ import dynamic_registration
20 |
21 | import __main__ as train_script
22 |
23 | from t5x import adafactor
24 | from t5x import partitioning
25 | from t5x import utils
26 | from t5x_retrieval import adafactor_utils
27 | from t5x_retrieval import partitioning as t5xr_partitioning
28 |
29 | include 't5x/configs/runs/pretrain.gin'
30 |
31 | train_script.train:
32 | infer_eval_dataset_cfg = None
33 |
34 | train/utils.DatasetConfig:
35 | use_cached = False
36 | pack = False
37 |
38 | train_eval/utils.DatasetConfig:
39 | use_cached = False
40 | pack = False
41 |
42 | utils.create_learning_rate_scheduler:
43 | factors = 'linear_decay'
44 | base_learning_rate = 0.001
45 | warmup_steps = 1000
46 | decay_factor = 0.0001 # 1 / %TRAIN_STEPS
47 |
48 | adafactor.Adafactor:
49 | decay_rate = 0.8
50 | step_offset = 0
51 | logical_factor_rules = @adafactor_utils.logical_factor_rules()
52 |
53 | partitioning.PjitPartitioner:
54 | logical_axis_rules = @partitioning.standard_logical_axis_rules()
55 |
56 | partitioning.standard_logical_axis_rules:
57 | additional_rules = @t5xr_partitioning.standard_logical_axis_rules()
58 |
--------------------------------------------------------------------------------
/t5x_retrieval/configs/runs/eval.gin:
--------------------------------------------------------------------------------
1 | # Defaults for eval.py.
2 | #
3 | # You must also include a binding for MODEL.
4 | #
5 | # Required to be set:
6 | #
7 | # - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to use for inference
8 | # - TASK_FEATURE_LENGTHS: The lengths per key in the SeqIO Task to trim features
9 | # to.
10 | # - CHECKPOINT_PATH: The model checkpoint to use for inference
11 | # - INFER_OUTPUT_DIR: The dir to write results to.
12 | #
13 | # Commonly overridden options:
14 | #
15 | # - infer.mode
16 | # - infer.checkpoint_period
17 | # - infer.shard_id
18 | # - infer.num_shards
19 | # - DatasetConfig.split
20 | # - DatasetConfig.batch_size
21 | # - DatasetConfig.use_cached
22 | # - RestoreCheckpointConfig.is_tensorflow
23 | # - RestoreCheckpointConfig.mode
24 | # - PjitPartitioner.num_partitions
25 | from __gin__ import dynamic_registration
26 |
27 | import __main__ as infer_script
28 |
29 | from t5x import adafactor
30 | from t5x import partitioning
31 | from t5x import utils
32 | from t5x_retrieval import adafactor_utils
33 | from t5x_retrieval import partitioning as t5xr_partitioning
34 |
35 | include 't5x/configs/runs/eval.gin'
36 |
37 | adafactor.Adafactor:
38 | decay_rate = 0.8
39 | step_offset = 0
40 | logical_factor_rules = @adafactor_utils.logical_factor_rules()
41 |
42 | partitioning.PjitPartitioner:
43 | logical_axis_rules = @partitioning.standard_logical_axis_rules()
44 |
45 | partitioning.standard_logical_axis_rules:
46 | additional_rules = @t5xr_partitioning.standard_logical_axis_rules()
47 |
48 | utils.RestoreCheckpointConfig:
49 | path = %CHECKPOINT_PATH
50 | mode = 'specific'
51 | dtype = 'bfloat16'
52 |
--------------------------------------------------------------------------------
/t5x_retrieval/configs/runs/infer.gin:
--------------------------------------------------------------------------------
1 | # Defaults for infer.py.
2 | #
3 | # You must also include a binding for MODEL.
4 | #
5 | # Required to be set:
6 | #
7 | # - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to use for inference
8 | # - TASK_FEATURE_LENGTHS: The lengths per key in the SeqIO Task to trim features
9 | # to.
10 | # - CHECKPOINT_PATH: The model checkpoint to use for inference
11 | # - INFER_OUTPUT_DIR: The dir to write results to.
12 | #
13 | # Commonly overridden options:
14 | #
15 | # - infer.mode
16 | # - infer.checkpoint_period
17 | # - infer.shard_id
18 | # - infer.num_shards
19 | # - DatasetConfig.split
20 | # - DatasetConfig.batch_size
21 | # - DatasetConfig.use_cached
22 | # - RestoreCheckpointConfig.is_tensorflow
23 | # - RestoreCheckpointConfig.mode
24 | # - PjitPartitioner.num_partitions
25 | from __gin__ import dynamic_registration
26 |
27 | import __main__ as infer_script
28 |
29 | from t5x import adafactor
30 | from t5x import partitioning
31 | from t5x import utils
32 | from t5x_retrieval import adafactor_utils
33 | from t5x_retrieval import partitioning as t5xr_partitioning
34 |
35 | include 't5x/configs/runs/infer.gin'
36 |
37 | adafactor.Adafactor:
38 | decay_rate = 0.8
39 | step_offset = 0
40 | logical_factor_rules = @adafactor_utils.logical_factor_rules()
41 |
42 | partitioning.PjitPartitioner:
43 | logical_axis_rules = @partitioning.standard_logical_axis_rules()
44 |
45 | partitioning.standard_logical_axis_rules:
46 | additional_rules = @t5xr_partitioning.standard_logical_axis_rules()
47 |
48 | utils.RestoreCheckpointConfig:
49 | path = %CHECKPOINT_PATH
50 | mode = 'specific'
51 | dtype = 'bfloat16'
52 |
--------------------------------------------------------------------------------
/t5x_retrieval/configs/models/de_mt5_base.gin:
--------------------------------------------------------------------------------
1 | # Dual Encoder based on mT5 Base model.
2 | # Provides MODEL
3 | from __gin__ import dynamic_registration
4 |
5 | import seqio
6 | from t5x import adafactor
7 | from t5x_retrieval import feature_converters
8 | from t5x_retrieval import models as models
9 | from t5x_retrieval import losses
10 |
11 | ARCHITECTURE = %gin.REQUIRED
12 | T5XR_INFERENCE_MODE = 'encode'
13 | LOSS_MODULE = @losses.InBatchCrossEntropyLoss
14 |
15 | # mT5 is identical to the T5.1.1 architecture except for the vocabulary.
16 | include 't5x_retrieval/configs/architectures/de_t5_1_1_flaxformer.gin'
17 | NUM_EMBEDDINGS = 250112 # vocab size rounded to a multiple of 128 for TPU efficiency
18 |
19 | # Architecture overrides
20 | NUM_HEADS = 12
21 | NUM_LAYERS = 12
22 | HEAD_DIM = 64
23 | EMBED_DIM = 768
24 | MLP_DIM = 2048
25 | PROJECTION_DIM = 768
26 |
27 | # Vocabulary (shared by encoder and decoder)
28 | VOCABULARY = @seqio.SentencePieceVocabulary()
29 | seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model"
30 |
31 | # Optimizer
32 | # `learning_rate` is set by `Trainer.learning_rate_fn`.
33 | OPTIMIZER = @adafactor.Adafactor()
34 | adafactor.Adafactor:
35 | decay_rate = 0.8
36 | step_offset = 0
37 |
38 | # Model
39 | MODEL = @models.DualEncoderModel()
40 | models.DualEncoderModel:
41 | use_negatives = False
42 | use_align_uniform = False
43 | inference_mode = %T5XR_INFERENCE_MODE
44 | feature_converter_cls = @feature_converters.DualEncoderFeatureConverterFactory()
45 | module = %ARCHITECTURE # provided by t5_flaxformer
46 | loss_module_factory = %LOSS_MODULE
47 | input_vocabulary = %VOCABULARY
48 | output_vocabulary = %VOCABULARY
49 | optimizer_def = %OPTIMIZER
50 |
--------------------------------------------------------------------------------
/t5x_retrieval/configs/models/de_longt5_1_1_transient_global_base.gin:
--------------------------------------------------------------------------------
1 | # Dual Encoder LongT5 Base model. Config based on Dual Encoder T5.1.1 Base model.
2 | # Provides MODEL
3 | from __gin__ import dynamic_registration
4 |
5 | import seqio
6 | from t5x import adafactor
7 | from t5x_retrieval import feature_converters
8 | from t5x_retrieval import models
9 | from t5x_retrieval import losses
10 |
11 | ARCHITECTURE = %gin.REQUIRED
12 | T5XR_INFERENCE_MODE = 'encode'
13 | LOSS_MODULE = @losses.InBatchCrossEntropyLoss
14 |
15 | include 't5x_retrieval/configs/architectures/de_longt5_1_1_transient_global_flaxformer.gin'
16 |
17 | # Architecture overrides
18 | NUM_HEADS = 12
19 | NUM_LAYERS = 12
20 | HEAD_DIM = 64
21 | EMBED_DIM = 768
22 | MLP_DIM = 2048
23 | PROJECTION_DIM = 768
24 |
25 | # Vocabulary (shared by encoder and decoder)
26 | VOCABULARY = @seqio.SentencePieceVocabulary()
27 | seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model"
28 | NUM_EMBEDDINGS = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency
29 |
30 | # Optimizer
31 | # `learning_rate` is set by `Trainer.learning_rate_fn`.
32 | OPTIMIZER = @adafactor.Adafactor()
33 | adafactor.Adafactor:
34 | decay_rate = 0.8
35 | step_offset = 0
36 |
37 | # Model
38 | MODEL = @models.DualEncoderModel()
39 | models.DualEncoderModel:
40 | use_negatives = False
41 | use_align_uniform = False
42 | inference_mode = %T5XR_INFERENCE_MODE
43 | feature_converter_cls = @feature_converters.DualEncoderFeatureConverterFactory()
44 | module = %ARCHITECTURE # provided by t5_flaxformer
45 | loss_module_factory = %LOSS_MODULE
46 | input_vocabulary = %VOCABULARY
47 | output_vocabulary = %VOCABULARY
48 | optimizer_def = %OPTIMIZER
49 |
--------------------------------------------------------------------------------
/t5x_retrieval/adafactor_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Retrieval Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Copyright 2022 Google LLC
16 | #
17 | # Licensed under the Apache License, Version 2.0 (the "License");
18 | # you may not use this file except in compliance with the License.
19 | # You may obtain a copy of the License at
20 | #
21 | # http://www.apache.org/licenses/LICENSE-2.0
22 | #
23 | # Unless required by applicable law or agreed to in writing, software
24 | # distributed under the License is distributed on an "AS IS" BASIS,
25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 | # See the License for the specific language governing permissions and
27 | # limitations under the License.
28 | """Adafactor logical rules for T5X Retrieval."""
29 |
30 | from flax.core import freeze
31 | from flax.core import unfreeze
32 | from t5x import adafactor
33 |
34 |
def logical_factor_rules():
  """Returns the Adafactor logical factor rules for T5X Retrieval.

  The rules start from `adafactor.standard_logical_factor_rules()` and add
  entries for the two extra logical axes used by the two-tower models:
  'affinity' is mapped to `FactorDim.COLUMN` and 'idf_buckets' to
  `FactorDim.NONE`.

  Returns:
    The frozen collection of rules, including the two additional entries.
  """
  # The standard rules must be unfrozen before entries can be added.
  merged = unfreeze(adafactor.standard_logical_factor_rules())
  merged['affinity'] = adafactor.FactorDim.COLUMN
  merged['idf_buckets'] = adafactor.FactorDim.NONE
  return freeze(merged)
43 |
--------------------------------------------------------------------------------
/t5x_retrieval/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Retrieval Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Copyright 2022 Google LLC
16 | #
17 | # Licensed under the Apache License, Version 2.0 (the "License");
18 | # you may not use this file except in compliance with the License.
19 | # You may obtain a copy of the License at
20 | #
21 | # http://www.apache.org/licenses/LICENSE-2.0
22 | #
23 | # Unless required by applicable law or agreed to in writing, software
24 | # distributed under the License is distributed on an "AS IS" BASIS,
25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 | # See the License for the specific language governing permissions and
27 | # limitations under the License.
28 | """Customized metric functions for T5X Retrieval."""
29 |
30 | from typing import Dict, Sequence
31 |
32 | import scipy.stats
33 |
34 |
def spearman_corrcoef(targets: Sequence[float],
                      scores: Sequence[float]) -> Dict[str, float]:
  """Computes the Spearman correlation coefficient, scaled by 100.

  Args:
    targets: sequence of reference values.
    scores: sequence of predicted values, aligned with `targets`.

  Returns:
    A dict with the single key 'spearman_corrcoef' whose value is 100 times
    the Spearman correlation between `targets` and `scores`.
  """
  result = scipy.stats.spearmanr(targets, scores)
  # The first element of the result is the correlation coefficient.
  coefficient = result[0]
  return {'spearman_corrcoef': 100 * coefficient}
47 |
--------------------------------------------------------------------------------
/t5x_retrieval/postprocessors.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Retrieval Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Copyright 2022 Google LLC
16 | #
17 | # Licensed under the Apache License, Version 2.0 (the "License");
18 | # you may not use this file except in compliance with the License.
19 | # You may obtain a copy of the License at
20 | #
21 | # http://www.apache.org/licenses/LICENSE-2.0
22 | #
23 | # Unless required by applicable law or agreed to in writing, software
24 | # distributed under the License is distributed on an "AS IS" BASIS,
25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 | # See the License for the specific language governing permissions and
27 | # limitations under the License.
28 | """Postprocessors for T5X Retrieval."""
29 |
30 | from typing import Mapping, Optional, Union
31 |
32 | import tensorflow as tf
33 |
34 |
def extract_label_postprocessor(
    output: Mapping[str, tf.Tensor],
    example: Optional[Mapping[str, tf.Tensor]] = None,
    is_target: Optional[bool] = False
) -> Union[tf.Tensor, Mapping[str, tf.Tensor]]:
  """Selects what is handed to the SeqIO evaluator.

  Args:
    output: A mapping of strings to tensors; returned untouched when the
      postprocessor is not applied to the target.
    example: An optional mapping of strings to tensors. It is only read when
      `is_target` is truthy, in which case it must contain a "labels" entry.
    is_target: Whether the postprocessor is being applied to the target
      (i.e. the "labels" field) rather than to the output.

  Returns:
    `example["labels"]` when `is_target` is truthy, otherwise `output`.
  """
  return example["labels"] if is_target else output
54 |
--------------------------------------------------------------------------------
/t5x_retrieval/preprocessors.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Retrieval Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Copyright 2022 Google LLC
16 | #
17 | # Licensed under the Apache License, Version 2.0 (the "License");
18 | # you may not use this file except in compliance with the License.
19 | # You may obtain a copy of the License at
20 | #
21 | # http://www.apache.org/licenses/LICENSE-2.0
22 | #
23 | # Unless required by applicable law or agreed to in writing, software
24 | # distributed under the License is distributed on an "AS IS" BASIS,
25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 | # See the License for the specific language governing permissions and
27 | # limitations under the License.
28 | """Preprocessors for T5X Retrieval."""
29 |
30 | import tensorflow as tf
31 |
32 |
def to_stsb_label(dataset: tf.data.Dataset, label_field_name: str,
                  label_type: str) -> tf.data.Dataset:
  """Converts the label field to a scaled score or a multi-class label.

  With `label_type="score"` the label is divided by 5; with
  `label_type="multi_class"` the label is multiplied by 2 and rounded. In both
  cases a trailing axis of size 1 is appended to the label.

  Args:
    dataset: A TensorFlow dataset whose examples are mappings that contain
      `label_field_name`.
    label_field_name: A string of the label field name.
    label_type: A string indicating the label type, either "score" or
      "multi_class".

  Returns:
    A TensorFlow dataset after the transformation.

  Raises:
    ValueError: If `label_type` is not one of the supported values.
  """
  # Validate up front so that a bad configuration fails with a plain
  # ValueError before the dataset is touched, rather than from inside the
  # function that `dataset.map` traces.
  if label_type not in ("score", "multi_class"):
    raise ValueError(f"Unsupported label type: {label_type}")

  def map_fn(example):
    label = example[label_field_name]
    if label_type == "score":
      # NOTE(review): presumably STS-B labels lie in [0, 5], which makes the
      # result fall in [0, 1] -- confirm against the dataset.
      label = label / 5
    else:  # "multi_class", guaranteed by the validation above.
      label = tf.round(label * 2)
    label = tf.expand_dims(label, axis=-1)
    example[label_field_name] = label
    return example

  return dataset.map(map_fn, num_parallel_calls=tf.data.AUTOTUNE)
59 |
--------------------------------------------------------------------------------
/t5x_retrieval/configs/runs/finetune.gin:
--------------------------------------------------------------------------------
1 | # Defaults for finetuning with train.py.
2 | #
3 | # You must also include a binding for MODEL.
4 | #
5 | # Required to be set:
6 | #
7 | # - MIXTURE_OR_TASK_NAME
8 | # - MIXTURE_OR_TASK_MODULE
9 | # - TASK_FEATURE_LENGTHS
10 | # - TRAIN_STEPS # includes pretrain steps
11 | # - MODEL_DIR # automatically set when using xm_launch
12 | # - INITIAL_CHECKPOINT_PATH
13 | #
14 | # `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt
15 | # has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps.
16 | #
17 | # Commonly overridden options:
18 | # - DROPOUT_RATE
19 | # - train/DatasetConfig.batch_size
20 | # - train_eval/DatasetConfig.batch_size
21 | # - infer_eval/DatasetConfig.batch_size
22 | # - PjitPartitioner.num_partitions
23 | # - Trainer.num_microbatches
24 | # - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess
25 | # on the fly. Most common tasks are cached, hence this is set to True by
26 | # default.
27 | from __gin__ import dynamic_registration
28 |
29 | import __main__ as train_script
30 |
31 | from t5x import adafactor
32 | from t5x import partitioning
33 | from t5x import utils
34 | from t5x_retrieval import adafactor_utils
35 | from t5x_retrieval import partitioning as t5xr_partitioning
36 |
37 | include 't5x/configs/runs/finetune.gin'
38 |
39 | BATCH_SIZE = 128
40 |
41 | train_script.train:
42 | infer_eval_dataset_cfg = None
43 | eval_steps = 20
44 | eval_period = 1000 # eval frequency
45 | random_seed = None
46 |
47 | train/utils.DatasetConfig:
48 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME
49 | task_feature_lengths = %TASK_FEATURE_LENGTHS
50 | split = 'train'
51 | shuffle = True
52 | seed = None # use a new seed each run/restart
53 | use_cached = %USE_CACHED_TASKS
54 | pack = False
55 | module = %MIXTURE_OR_TASK_MODULE
56 |
57 | train_eval/utils.DatasetConfig:
58 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME
59 | task_feature_lengths = %TASK_FEATURE_LENGTHS
60 | split = 'validation'
61 | shuffle = False
62 | seed = 42
63 | use_cached = %USE_CACHED_TASKS
64 | pack = False
65 | module = %MIXTURE_OR_TASK_MODULE
66 |
67 | utils.RestoreCheckpointConfig:
68 | assignment_map = ((r'state/param_states.*', None),) # Skip optimizer states
69 | fallback_to_scratch = True
70 |
71 | utils.SaveCheckpointConfig:
72 | period = 1000
73 | dtype = 'float32'
74 | keep = None # keep all checkpoints
75 |
76 | utils.create_learning_rate_scheduler:
77 | factors = 'linear_decay'
78 | base_learning_rate = 0.001
79 | warmup_steps = 1000
80 |
81 | adafactor.Adafactor:
82 | decay_rate = 0.8
83 | step_offset = 0
84 | logical_factor_rules = @adafactor_utils.logical_factor_rules()
85 |
86 | partitioning.PjitPartitioner:
87 | logical_axis_rules = @partitioning.standard_logical_axis_rules()
88 |
89 | partitioning.standard_logical_axis_rules:
90 | additional_rules = @t5xr_partitioning.standard_logical_axis_rules()
91 |
--------------------------------------------------------------------------------
/t5x_retrieval/tasks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Retrieval Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Copyright 2022 Google LLC
16 | #
17 | # Licensed under the Apache License, Version 2.0 (the "License");
18 | # you may not use this file except in compliance with the License.
19 | # You may obtain a copy of the License at
20 | #
21 | # http://www.apache.org/licenses/LICENSE-2.0
22 | #
23 | # Unless required by applicable law or agreed to in writing, software
24 | # distributed under the License is distributed on an "AS IS" BASIS,
25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 | # See the License for the specific language governing permissions and
27 | # limitations under the License.
28 | """Add Tasks to registry."""
29 | import functools
30 |
31 | import seqio
32 | import t5.data
33 | from t5x_retrieval import metrics
34 | from t5x_retrieval import postprocessors as t5xr_postprocessors
35 | from t5x_retrieval import preprocessors as t5xr_preprocessors
36 | import tensorflow as tf
37 |
38 |
# Single default T5 vocabulary shared by every text feature in this module.
DEFAULT_VOCAB = t5.data.get_default_vocabulary()

# Output features for the retrieval tasks: both sides are tokenized with the
# default vocabulary and get an EOS token; "inputs" is optional.
DEFAULT_OUTPUT_FEATURES = {
    "inputs":
        seqio.Feature(vocabulary=DEFAULT_VOCAB, add_eos=True, required=False),
    "targets":
        seqio.Feature(vocabulary=DEFAULT_VOCAB, add_eos=True)
}

# Output features for tasks that also carry a float "labels" field. The text
# features reuse DEFAULT_VOCAB instead of constructing further default
# vocabularies; "labels" is passed through untokenized as float32.
RELEVANCE_OUTPUT_FEATURES = {
    "inputs":
        seqio.Feature(
            vocabulary=DEFAULT_VOCAB,
            add_eos=True,
            required=True),
    "targets":
        seqio.Feature(
            vocabulary=DEFAULT_VOCAB, add_eos=True),
    "labels":
        seqio.Feature(
            vocabulary=seqio.PassThroughVocabulary(size=1),
            add_eos=False,
            required=False,
            dtype=tf.float32)
}
63 |
64 |
65 | # =========================== Fine-tuning Tasks/Mixtures =======================
66 | # ----- Beir MS Marco-----
# Registers the BEIR MS MARCO fine-tuning task on the "train" and "validation"
# splits of the TFDS dataset. `rekey` takes {new_key: original_key}, so the
# dataset's "query" field becomes "inputs" and "passage" becomes "targets".
# No metric functions are attached to this task.
seqio.TaskRegistry.add(
    "beir_msmarco_retrieval",
    source=seqio.TfdsDataSource(
        tfds_name="beir/msmarco:1.0.0",
        splits={
            "train": "train",
            "validation": "validation",
        },
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.rekey,
            key_map={
                "inputs": "query",
                "targets": "passage",
            }),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        seqio.preprocessors.append_eos_after_trim,
    ],
    metric_fns=[],
    output_features=DEFAULT_OUTPUT_FEATURES)
89 |
90 |
91 | # ========================== STS Benchmark ====================================
# Registers the STS-B scoring task. `rekey` takes {new_key: original_key}:
# "sentence1"/"sentence2" become "inputs"/"targets" and "label" becomes
# "labels". The "labels" field is then converted by `to_stsb_label` with
# label_type="score" before tokenization. Evaluation uses
# `extract_label_postprocessor` and the Spearman correlation metric.
seqio.TaskRegistry.add(
    "glue_stsb_v002_score",
    source=seqio.TfdsDataSource(tfds_name="glue/stsb:2.0.0", splits=None),
    preprocessors=[
        functools.partial(
            seqio.preprocessors.rekey,
            key_map={
                "inputs": "sentence1",
                "targets": "sentence2",
                "labels": "label"
            }),
        functools.partial(
            t5xr_preprocessors.to_stsb_label,
            label_field_name="labels",
            label_type="score",
        ),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features=RELEVANCE_OUTPUT_FEATURES,
    postprocess_fn=t5xr_postprocessors.extract_label_postprocessor,
    metric_fns=[metrics.spearman_corrcoef])
115 |
116 |
117 | # ============================ Inference Tasks/Mixtures =======================
118 | # ----- Beir MS Marco-----
# Registers one inference task per side of the dual encoder:
# "beir_msmarco_retrieval_query" and "beir_msmarco_retrieval_passage". Each
# task reads the TFDS split named after `split`, feeds the field of the same
# name as "inputs" and uses the "<split>_id" field as "targets". The
# `functools.partial` is built inside the loop body, so each task captures its
# own `split` value. No metric functions are attached.
for split in ["query", "passage"]:
  seqio.TaskRegistry.add(
      f"beir_msmarco_retrieval_{split}",
      source=seqio.TfdsDataSource(
          tfds_name="beir/msmarco:1.0.0",
          splits={split: split},
      ),
      preprocessors=[
          functools.partial(
              t5.data.preprocessors.rekey,
              key_map={
                  "inputs": split,
                  "targets": f"{split}_id",
              }),
          seqio.preprocessors.tokenize,
          seqio.CacheDatasetPlaceholder(),
          seqio.preprocessors.append_eos_after_trim,
      ],
      metric_fns=[],
      output_features=DEFAULT_OUTPUT_FEATURES)
139 |
--------------------------------------------------------------------------------
/t5x_retrieval/configs/architectures/de_t5_flaxformer.gin:
--------------------------------------------------------------------------------
1 | # Flaxformer implementation of Dual Encoder, based on original T5 (1.0) architecture.
2 | #
3 | # Required to be overridden:
4 | #
5 | # - NUM_HEADS
6 | # - NUM_LAYERS
7 | # - HEAD_DIM
8 | # - EMBED_DIM
9 | # - MLP_DIM
10 | # - PROJECTION_DIM
11 | from __gin__ import dynamic_registration
12 |
13 | from flax import linen
14 |
15 | from flaxformer.architectures.dual_encoder import dual_encoder_architecture
16 | from flaxformer.architectures.dual_encoder import l2_norm
17 | from flaxformer.architectures.dual_encoder import poolings
18 | from flaxformer.architectures.dual_encoder import similarity_functions
19 | from flaxformer.architectures.t5 import t5_architecture
20 | from flaxformer.components.attention import dense_attention
21 | from flaxformer.components import dense
22 | from flaxformer.components import embedding
23 | from flaxformer.components import layer_norm
24 | from flaxformer.components import relative_position_biases
25 |
26 | from t5x import models
27 | from t5x import utils
28 | from t5x_retrieval import feature_converters
29 |
30 | # Must be overridden.
31 | NUM_HEADS = %gin.REQUIRED
32 | NUM_LAYERS = %gin.REQUIRED
33 | HEAD_DIM = %gin.REQUIRED
34 | EMBED_DIM = %gin.REQUIRED
35 | MLP_DIM = %gin.REQUIRED
36 | PROJECTION_DIM = %gin.REQUIRED
37 |
38 | # Constants (may be overridden)
39 | ACTIVATION_DTYPE = 'bfloat16'
40 | ACTIVATION_PARTITIONING_DIMS = 1
41 | NUM_EMBEDDINGS = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency
42 | DROPOUT_RATE = 0.0
43 |
44 | # Macros
45 | BIAS_INIT = @bias_init/linen.initializers.normal()
46 | bias_init/linen.initializers.normal.stddev = 1e-6
47 | DROPOUT_FACTORY = @dropout_factory/linen.Dropout
48 | dropout_factory/linen.Dropout:
49 | rate = %DROPOUT_RATE
50 | broadcast_dims = (-2,)
51 |
52 | # Architecture (Flax Module)
53 | ARCHITECTURE = @dual_encoder_architecture.DualEncoder()
54 | dual_encoder_architecture.DualEncoder:
55 | encoder_factory = @t5_architecture.Encoder
56 | pooler_factory = @poolings.MeanPooling
57 | shared_token_embedder_factory = @embedding.Embed
58 | l2_norm_factory = @l2_norm.L2Norm
59 | projection_layer_factory = @projection_layer/dense.DenseGeneral
60 | similarity_layer_factory = @similarity_functions.BatchDotProduct
61 | dtype = %ACTIVATION_DTYPE
62 |
63 | # Encoder
64 | t5_architecture.Encoder:
65 | num_layers = %NUM_LAYERS
66 | layer_factory = @t5_architecture.EncoderLayer
67 | input_dropout_factory = %DROPOUT_FACTORY
68 | output_dropout_factory = %DROPOUT_FACTORY
69 | layer_norm_factory = @layer_norm.T5LayerNorm
70 | position_embedder_factory = None
71 | shared_relative_position_bias_factory = @relative_position_biases.RelativePositionBiases
72 | dtype = %ACTIVATION_DTYPE
73 |
74 | # TODO(b/262657686): Move this to model gin files.
75 | # Infer the input features from the TASK_FEATURE_LENGTHS passed by the user.
76 | feature_converters.DualEncoderFeatureConverterFactory:
77 | feature_specs = (
78 | ("inputs", "int32", 1, 0),
79 | ("targets", "int32", 1, 0),
80 | )
81 |
82 | # Similarity layer
83 | similarity_functions.BatchDotProduct:
84 | name = 'batch_dot_product'
85 |
86 | # Projection layer
87 | projection_layer/dense.DenseGeneral:
88 | features = %PROJECTION_DIM
89 | use_bias = False
90 | dtype = 'float32'
91 | kernel_init = @projection_layer/linen.initializers.variance_scaling()
92 | kernel_axis_names = ('embed', 'affinity')
93 | bias_init = %BIAS_INIT
94 | projection_layer/linen.initializers.variance_scaling:
95 | scale = 1
96 | mode = 'fan_in'
97 | distribution = 'truncated_normal'
98 |
99 | # Encoder Layer
100 | t5_architecture.EncoderLayer:
101 | attention = @dense_attention.MultiHeadDotProductAttention()
102 | mlp = @dense.MlpBlock()
103 | dropout_factory = %DROPOUT_FACTORY
104 | layer_norm_factory = @layer_norm.T5LayerNorm
105 | activation_partitioning_dims = %ACTIVATION_PARTITIONING_DIMS
106 |
107 | # Token Embedder (shared)
108 | embedding.Embed:
109 | num_embeddings= %NUM_EMBEDDINGS
110 | features = %EMBED_DIM
111 | cast_input_dtype = 'int32'
112 | dtype = %ACTIVATION_DTYPE
113 | attend_dtype = 'float32' # for logit training stability
114 | one_hot = True
115 | embedding_init = @token_embedder_init/linen.initializers.normal()
116 | name = 'token_embedder'
117 | token_embedder_init/linen.initializers.normal.stddev = 1.0
118 |
# Attention (encoder self-attention)
120 | dense_attention.MultiHeadDotProductAttention:
121 | num_heads = %NUM_HEADS
122 | head_dim = %HEAD_DIM
123 | dtype = %ACTIVATION_DTYPE
124 | kernel_init = @attention_kernel_init/linen.initializers.variance_scaling()
125 | bias_init = %BIAS_INIT
126 | use_bias = False
127 | broadcast_dropout = True
128 | dropout_rate = %DROPOUT_RATE
129 | attention_kernel_init/linen.initializers.variance_scaling:
130 | scale = 1.0
131 | mode = 'fan_in'
132 | distribution = 'normal'
133 |
# Relative position biases (encoder)
135 | relative_position_biases.RelativePositionBiases:
136 | num_heads = %NUM_HEADS
137 | num_buckets = 32
138 | max_distance = 128
139 | dtype = %ACTIVATION_DTYPE
140 | embedding_init = @relative_position_bias_init/linen.initializers.variance_scaling()
141 | relative_position_bias_init/linen.initializers.variance_scaling:
142 | scale = 1.0
143 | mode = 'fan_avg'
144 | distribution = 'uniform'
145 |
146 | # MLP (encoder)
147 | dense.MlpBlock:
148 | use_bias = False
149 | intermediate_dim = %MLP_DIM
150 | activations = ('relu',)
151 | kernel_init = @mlp_kernel_init/linen.initializers.variance_scaling()
152 | bias_init = %BIAS_INIT
153 | intermediate_dropout_rate = %DROPOUT_RATE
154 | final_dropout_rate = 0
155 | dtype = %ACTIVATION_DTYPE
156 | mlp_kernel_init/linen.initializers.variance_scaling:
157 | scale = 1.0
158 | mode = 'fan_in'
159 | distribution = 'truncated_normal'
160 |
--------------------------------------------------------------------------------
/t5x_retrieval/configs/architectures/de_longt5_1_1_transient_global_flaxformer.gin:
--------------------------------------------------------------------------------
1 | # Flaxformer implementation of Long Dual Encoder, based on LongT5 architecture.
2 | #
3 | # Required to be overridden:
4 | #
5 | # - NUM_HEADS
6 | # - NUM_LAYERS
7 | # - HEAD_DIM
8 | # - EMBED_DIM
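# - MLP_DIM
#
# Also referenced below but not defined in this file, so they must be bound by
# the config that includes it:
#
# - PROJECTION_DIM
# - NUM_EMBEDDINGS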
10 | from __gin__ import dynamic_registration
11 |
12 | from flax import linen
13 | from flaxformer.architectures.dual_encoder import dual_encoder_architecture
14 | from flaxformer.architectures.dual_encoder import l2_norm
15 | from flaxformer.architectures.dual_encoder import poolings
16 | from flaxformer.architectures.dual_encoder import similarity_functions
17 | from flaxformer.architectures.longt5 import long_attention
18 | from flaxformer.architectures.longt5 import longt5_architecture
19 | from flaxformer.architectures.longt5 import relative_position_biases_general
20 | from flaxformer.components import dense
21 | from flaxformer.components import embedding
22 | from flaxformer.components import layer_norm
23 | from flaxformer.components import relative_position_biases
24 |
25 | from t5x_retrieval import feature_converters
26 |
27 |
28 | NUM_LAYERS = %gin.REQUIRED
29 | NUM_HEADS = %gin.REQUIRED
30 | HEAD_DIM = %gin.REQUIRED
31 | EMBED_DIM = %gin.REQUIRED
32 | MLP_DIM = %gin.REQUIRED
33 |
34 |
35 | ACTIVATION_DTYPE = 'bfloat16'
36 | ACTIVATION_PARTITIONING_DIMS = 1
37 | SCALE = 1.0
38 | DROPOUT_RATE = 0.0
39 | LOCAL_RADIUS = 127
40 | TOKENS_PER_BLOCK = 16
41 |
42 | # Macros
43 | BIAS_INIT = @bias_init/linen.initializers.normal()
44 | bias_init/linen.initializers.normal.stddev = 1e-6
45 | DROPOUT_FACTORY = @dropout_factory/linen.Dropout
46 | dropout_factory/linen.Dropout:
47 | rate = %DROPOUT_RATE
48 | broadcast_dims = (-2,)
49 |
50 | # Architecture (Flax Module)
51 | ARCHITECTURE = @dual_encoder_architecture.LongDualEncoder()
52 | dual_encoder_architecture.LongDualEncoder:
53 | encoder_factory = @longt5_architecture.LongEncoder
54 | pooler_factory = @poolings.MeanPooling
55 | shared_token_embedder_factory = @embedding.Embed
56 | l2_norm_factory = @l2_norm.L2Norm
57 | projection_layer_factory = @projection_layer/dense.DenseGeneral
58 | similarity_layer_factory = @similarity_functions.BatchDotProduct
59 | dtype = %ACTIVATION_DTYPE
60 |
61 | # Encoder
62 | longt5_architecture.LongEncoder:
63 | num_layers = %NUM_LAYERS
64 | layer_factory = @longt5_architecture.LongEncoderLayer
65 | input_dropout_factory = %DROPOUT_FACTORY
66 | output_dropout_factory = %DROPOUT_FACTORY
67 | layer_norm_factory = @layer_norm.T5LayerNorm
68 | position_embedder_factory = None
69 | shared_relpos_bias_factory = @relative_position_biases_general.RelativePositionBiasesGeneral
70 | shared_side_relpos_bias_factory = @relative_position_biases_general.RelativePositionBiasesGeneral
71 | dtype = %ACTIVATION_DTYPE
72 |
73 | # Infer the input features from the TASK_FEATURE_LENGTHS passed by the user.
74 | feature_converters.DualEncoderFeatureConverterFactory:
75 | feature_specs = (
76 | ("inputs", "int32", 1, 0),
77 | ("targets", "int32", 1, 0),
78 | )
79 |
80 | # Similarity layer
81 | similarity_functions.BatchDotProduct:
82 | name = 'batch_dot_product'
83 |
84 | # Projection layer
85 | projection_layer/dense.DenseGeneral:
86 | features = %PROJECTION_DIM
87 | use_bias = False
88 | dtype = 'float32'
89 | kernel_init = @projection_layer/linen.initializers.variance_scaling()
90 | kernel_axis_names = ('embed', 'affinity')
91 | bias_init = %BIAS_INIT
92 | projection_layer/linen.initializers.variance_scaling:
93 | scale = 1
94 | mode = 'fan_in'
95 | distribution = 'truncated_normal'
96 |
97 | # Encoder Layer
98 | longt5_architecture.LongEncoderLayer:
99 | attention_factory = @long_attention.EtcTransientGlobalSelfAttention
100 | mlp = @dense.MlpBlock()
101 | dropout_factory = %DROPOUT_FACTORY
102 | layer_norm_factory = @layer_norm.T5LayerNorm
103 | activation_partitioning_dims = %ACTIVATION_PARTITIONING_DIMS
104 |
105 | # Long Attention (encoder, self-attention)
106 | long_attention.EtcTransientGlobalSelfAttention:
107 | num_heads = %NUM_HEADS
108 | tokens_per_block = %TOKENS_PER_BLOCK
109 | local_radius = %LOCAL_RADIUS
110 | dtype = %ACTIVATION_DTYPE
111 | head_dim = %HEAD_DIM
112 | kernel_init = @attention_kernel_init/linen.initializers.variance_scaling()
113 | bias_init = %BIAS_INIT
114 | use_bias = False
115 | broadcast_dropout = True
116 | dropout_rate = %DROPOUT_RATE
117 | attention_kernel_init/linen.initializers.variance_scaling:
118 | scale = %SCALE
119 | mode = 'fan_in'
120 | distribution = 'normal'
121 |
122 | # Relative position biases (encoder)
123 | relative_position_biases_general.RelativePositionBiasesGeneral:
124 | num_heads = %NUM_HEADS
125 | dtype = %ACTIVATION_DTYPE
126 | num_buckets = 32
127 | max_distance = 128
128 | embedding_init = @relative_position_bias_init/linen.initializers.variance_scaling()
129 | relative_position_bias_init/linen.initializers.variance_scaling:
130 | scale = %SCALE
131 | mode = 'fan_avg'
132 | distribution = 'uniform'
133 |
134 | # Token Embedder (shared)
135 | embedding.Embed:
136 | num_embeddings= %NUM_EMBEDDINGS
137 | features = %EMBED_DIM
138 | cast_input_dtype = 'int32'
139 | dtype = %ACTIVATION_DTYPE
140 | attend_dtype = 'float32' # for logit training stability
141 | one_hot = True
142 | embedding_init = @token_embedder_init/linen.initializers.normal()
143 | name = 'token_embedder'
144 | token_embedder_init/linen.initializers.normal.stddev = 1.0
145 |
146 | dense.MlpBlock:
147 | use_bias = False
148 | intermediate_dim = %MLP_DIM
149 | activations = ('gelu', 'linear')
150 | kernel_init = @mlp_kernel_init/linen.initializers.variance_scaling()
151 | bias_init = %BIAS_INIT
152 | intermediate_dropout_rate = %DROPOUT_RATE
153 | final_dropout_rate = 0
154 | dtype = %ACTIVATION_DTYPE
155 | mlp_kernel_init/linen.initializers.variance_scaling:
156 | scale = 1.0
157 | mode = 'fan_in'
158 | distribution = 'truncated_normal'
159 |
--------------------------------------------------------------------------------
/t5x_retrieval/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Retrieval Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Copyright 2022 Google LLC
16 | #
17 | # Licensed under the Apache License, Version 2.0 (the "License");
18 | # you may not use this file except in compliance with the License.
19 | # You may obtain a copy of the License at
20 | #
21 | # http://www.apache.org/licenses/LICENSE-2.0
22 | #
23 | # Unless required by applicable law or agreed to in writing, software
24 | # distributed under the License is distributed on an "AS IS" BASIS,
25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 | # See the License for the specific language governing permissions and
27 | # limitations under the License.
28 | """Loss layer implementations for dual encoders."""
29 | import abc
30 | from typing import Mapping, Optional, Tuple
31 |
32 | import clu.metrics as clu_metrics
33 | from flax import linen as nn
34 | from flax.core import scope as flax_scope
35 | import gin
36 | from jax import numpy as jnp
37 | import seqio
38 | from t5x import metrics as t5x_metrics
39 | from t5x import models as t5x_models
40 | from t5x_retrieval import feature_converters
41 | from t5x_retrieval import utils
42 |
43 | FeatureSpec = feature_converters.FeatureSpec
44 | SeqIOFeatureSpec = seqio.feature_converters.FeatureConverter.FeatureSpec
45 |
46 |
@gin.configurable
class DualEncoderLoss(nn.Module, abc.ABC):
  """Base class for loss layers accepted by dual encoders.

  Subclasses declare the model features they need through
  `LOSS_MODEL_FEATURES`, compute the loss and metrics in `__call__`, and build
  their initial variables in `get_initial_variables`.
  """

  @property
  @abc.abstractmethod
  def LOSS_MODEL_FEATURES(self) -> Mapping[str, SeqIOFeatureSpec]:
    """Model features required by loss layer for computing loss/metrics.

    This property specifies the features that are expected to be present in
    `batch` argument passed to the __call__ method by a dual encoder model.
    These must be described as a map of feature keys to SeqIO `FeatureSpec`s,
    following the format of the `MODEL_FEATURES` attribute of SeqIO's
    `FeatureConverter`s. The features specified here in this manner should be
    validated by `validate_model_features` method (the default base class
    implementation does this).
    """
    pass

  @nn.compact
  @abc.abstractmethod
  def __call__(self, batch: Mapping[str, jnp.ndarray], logits: jnp.ndarray,
               **kwargs) -> Tuple[jnp.float64, t5x_metrics.MetricsMap]:
    """Computes loss and loss-related metrics on the inputs.

    Args:
      batch: Features passed by dual-encoder model. Any external supervision
        signals/labels should be passed through this argument.
      logits: Output of dual-encoder model's similarity layer.
      **kwargs: Additional arguments that may be needed by the call method

    Returns:
      Tuple of (loss, metrics)
    """
    pass

  @abc.abstractmethod
  def get_initial_variables(self, rng, input_shapes,
                            input_types) -> flax_scope.FrozenVariableDict:
    """Gets the initial variables for a loss layer."""
    pass

  def validate_model_features(self, feature_converter: seqio.FeatureConverter):
    """Ensures loss-layer's required features are provided by feature converter.

    This method checks that the feature-spec for loss model features match. Any
    additional validation should be performed by child classes.

    Args:
      feature_converter: Feature converter used by the dual encoder model that
        invokes this loss layer.

    Raises:
      ValueError, if feature is missing or has a spec mismatch
    """
    model_features = feature_converter.MODEL_FEATURES
    for name, feature in self.LOSS_MODEL_FEATURES.items():
      model_feature = model_features.get(name, None)
      if not model_feature:
        raise ValueError(f"Missing required loss-layer feature {name} "
                         f"in model features {model_features}")
      if not isinstance(feature, type(model_feature)):
        # The two literals are joined by implicit concatenation; separating
        # them with a comma would pass two arguments to ValueError and render
        # the message as a tuple.
        raise ValueError(
            "Found incorrect type for feature spec of loss layer feature "
            f"{name}, expected {type(model_feature)}, found {type(feature)}.")
      if feature != model_feature:
        raise ValueError("Found incorrect feature spec for loss layer feature "
                         f"{name}, expected {feature}, found {model_feature}")
117 |
118 |
class InBatchCrossEntropyLoss(DualEncoderLoss):
  """In-batch cross-entropy loss layer for dual encoders.

  Attributes:
    bidirectional: When True and `right_logits` is supplied, a second
      cross-entropy term computed on `right_logits` is combined with the loss
      computed on `logits`.
    label_smoothing: Label smoothing constant forwarded to
      `utils.in_batch_cross_entropy`.
  """

  bidirectional: bool = True
  label_smoothing: float = 0.0

  @property
  def LOSS_MODEL_FEATURES(self):
    """Model features required by this loss layer (none)."""
    # Only the in-batch logits are consumed, so no extra model features are
    # requested.
    return {}

  def get_initial_variables(self, rng, input_shapes,
                            input_types) -> flax_scope.FrozenVariableDict:
    """Initializes the layer with all-ones logits of the expected shapes.

    The dummy `logits` have shape [B, B * (1 + num_negatives)] and the dummy
    `right_logits` have shape [B, B], where B is the leading dimension of
    `input_shapes["left_encoder_input_tokens"]`. `num_negatives` is read from
    the second dimension of `input_shapes["right_negative_encoder_input_tokens"]`
    when that key is present, and is 0 otherwise.
    """
    negatives_key = "right_negative_encoder_input_tokens"
    num_examples = input_shapes["left_encoder_input_tokens"][0]
    negatives_per_example = (
        input_shapes[negatives_key][1] if negatives_key in input_shapes else 0)

    dummy_logits = jnp.ones(
        [num_examples, (num_examples * (1 + negatives_per_example))])
    dummy_right_logits = jnp.ones([num_examples, num_examples])
    return self.init(
        rng,
        params={},
        batch={},
        logits=dummy_logits,
        right_logits=dummy_right_logits)

  @nn.compact
  def __call__(self,
               batch: Mapping[str, jnp.ndarray],
               logits: jnp.ndarray,
               right_logits: Optional[jnp.ndarray] = None,
               **kwargs) -> Tuple[jnp.float64, t5x_metrics.MetricsMap]:
    """Computes the in-batch cross-entropy loss and its metrics.

    Args:
      batch: Features passed by dual-encoder model. Unused.
      logits: Output of the similarity layer that includes negatives, of shape
        [B, B * (1 + num_negatives)].
      right_logits: Output of the similarity layer without negatives, of shape
        [B, B]. When None, the right-hand loss term is skipped.
      **kwargs: Unused.

    Returns:
      loss: The cross-entropy on `logits`; when `bidirectional` is set and
        `right_logits` is given, the mean of that loss plus the cross-entropy
        on `right_logits`.
      metrics: The metrics produced by `t5x_models.compute_base_metrics`.
    """
    # This loss needs no external inputs beyond the logits.
    del batch, kwargs

    smoothing = self.label_smoothing
    loss = utils.in_batch_cross_entropy(logits, label_smoothing=smoothing)
    if self.bidirectional and right_logits is not None:
      right_loss = utils.in_batch_cross_entropy(
          right_logits, label_smoothing=smoothing)
      loss = jnp.mean(loss + right_loss)

    in_batch_targets = utils.sparse_labels_for_in_batch_cross_entropy(logits)
    # z_loss is reported as zero: it is already folded into the loss as a
    # workaround for a numerical-instability issue.
    metrics = t5x_models.compute_base_metrics(
        logits=logits,
        targets=in_batch_targets,
        mask=None,
        loss=loss,
        z_loss=0.)

    return loss, metrics
205 |
206 |
207 |
--------------------------------------------------------------------------------
/t5x_retrieval/feature_converters.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Retrieval Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Copyright 2022 Google LLC
16 | #
17 | # Licensed under the Apache License, Version 2.0 (the "License");
18 | # you may not use this file except in compliance with the License.
19 | # You may obtain a copy of the License at
20 | #
21 | # http://www.apache.org/licenses/LICENSE-2.0
22 | #
23 | # Unless required by applicable law or agreed to in writing, software
24 | # distributed under the License is distributed on an "AS IS" BASIS,
25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 | # See the License for the specific language governing permissions and
27 | # limitations under the License.
28 |
29 | """Feature converter for dual encoders."""
30 |
31 | import dataclasses
32 | from typing import Any, Iterable, Mapping, Sequence, Tuple
33 | import ml_collections
34 | import seqio
35 | import tensorflow as tf
36 |
# Raw configuration for one feature. `FeatureSpec.to_map` unpacks each entry as
# (name, dtype_str, rank, sequence_dim) and resolves `dtype_str` with
# `getattr(tf, dtype_str)`, e.g. "int32".
FeatureSpecConfig = Tuple[str, str, int, int]

# Short module-level aliases for the SeqIO names used in this file.
utils = seqio.utils
FeatureConverter = seqio.FeatureConverter


# TODO(b/262657107): Add features for universal feature converter.
# Maps a feature name (key) to another feature name (value); only "inputs",
# "targets" and "negative_targets" are renamed, the rest map to themselves.
# NOTE(review): presumably keys are task feature names and values are the
# model feature names produced by the feature converter -- the usage is
# outside this excerpt, confirm there.
_MODEL_FEATURES_MAPPING = {
    "inputs": "left_encoder_input_tokens",
    "targets": "right_encoder_input_tokens",
    "negative_targets": "right_negative_encoder_input_tokens",
    "soft_label_pos": "soft_label_pos",
    "soft_label_neg": "soft_label_neg",
    "labels": "labels",
    "loss_weights": "loss_weights",
}
53 |
54 |
@dataclasses.dataclass
class FeatureSpec:
  """Name, dtype, rank and sequence dimension of a single feature."""
  name: str
  dtype: tf.DType
  rank: int
  sequence_dim: int

  @classmethod
  def to_map(
      cls, feature_specs: Iterable[FeatureSpecConfig]) -> Mapping[str, Any]:
    """Builds a name -> FeatureSpec dict from raw spec tuples.

    Args:
      feature_specs: iterable of (name, dtype_str, rank, sequence_dim) tuples.
        `dtype_str` is resolved as an attribute of the `tf` module.

    Returns:
      A dict keyed by feature name; when a name repeats, the last tuple wins.
    """
    return {
        spec_name: cls(spec_name, getattr(tf, dtype_name), spec_rank, seq_dim)
        for spec_name, dtype_name, spec_rank, seq_dim in feature_specs
    }
71 |
72 |
class DualEncoderFeatureConverterFactory:
  """Callable that builds `DualEncoderFeatureConverter` instances.

  The raw feature spec tuples and the multimodal flag are captured at
  construction time; the packing options are supplied on each call.
  """

  def __init__(self, feature_specs: Iterable[FeatureSpecConfig],
               is_multimodal: bool = False):
    """Stores the raw feature specs and the multimodal flag."""
    self.is_multimodal = is_multimodal
    self.feature_specs = feature_specs

  def __call__(self,
               pack: bool = False,
               use_custom_packing_ops: bool = False):
    """Creates a converter from the stored specs and the given options.

    Args:
      pack: forwarded to `DualEncoderFeatureConverter`.
      use_custom_packing_ops: forwarded to `DualEncoderFeatureConverter`.

    Returns:
      A new `DualEncoderFeatureConverter`.
    """
    specs_by_name = FeatureSpec.to_map(self.feature_specs)
    converter_kwargs = dict(
        input_features=specs_by_name.values(),
        is_multimodal=self.is_multimodal,
        pack=pack,
        use_custom_packing_ops=use_custom_packing_ops)
    return DualEncoderFeatureConverter(**converter_kwargs)
90 |
91 |
class DualEncoderFeatureConverter(FeatureConverter):
  """Feature converter for dual-encoder architecture.

  The inputs and targets to the dual-encoder are sent to the left and right
  encoders separately.

  Attributes:
    input_features: a list of feature specs that are used to define the
      feature's name, dtype, rank and sequence dimension.
    is_multimodal: a boolean variable to indicate whether it is a feature
      converter to multimodal inputs. For multimodal inputs, we don't use the
      default _MODEL_FEATURES_MAPPING; task feature names are kept as-is.
  """
  input_features: Sequence[FeatureSpec] = ()
  is_multimodal: bool = False

  def __init__(self,
               input_features: Sequence[FeatureSpec],
               is_multimodal: bool = False,
               pack: bool = False,
               use_custom_packing_ops: bool = True):
    """Stores the feature specs and validates multimodal feature names.

    Args:
      input_features: feature specs describing the task features.
      is_multimodal: whether feature names are passed through unchanged
        instead of being translated with _MODEL_FEATURES_MAPPING.
      pack: forwarded to the base class. Packing is rejected later by
        `_convert_features` and `get_model_feature_lengths`.
      use_custom_packing_ops: forwarded to the base class.

    Raises:
      ValueError: if `is_multimodal` is set and the feature names cover only
        one of the "left_" / "right_" towers.
    """
    self.input_features = input_features
    self.is_multimodal = is_multimodal
    # NOTE: for multimodal inputs, make sure the inputs either (1) include both
    # "left_" and "right_" features for training or (2) don't include any of
    # them for inference time.
    if self.is_multimodal:
      has_left, has_right = False, False
      for f in self.input_features:
        # A name containing "left_" is counted for the left tower only.
        if "left_" in f.name:
          has_left = True
        elif "right_" in f.name:
          has_right = True
      if has_left != has_right:
        raise ValueError(
            "Multimodal inputs features should have both left and right tower "
            "features for training."
        )
    super().__init__(pack=pack, use_custom_packing_ops=use_custom_packing_ops)

  def _model_feature_name(self, task_feature_name: str) -> str:
    """Returns the model feature name for the given task feature name."""
    # NOTE: only use the _MODEL_FEATURES_MAPPING for non-multimodal inputs.
    if self.is_multimodal:
      return task_feature_name
    return _MODEL_FEATURES_MAPPING[task_feature_name]

  @property
  def TASK_FEATURES(self):
    """Task feature name -> seqio FeatureSpec, built from `input_features`."""
    return {
        f.name: seqio.FeatureConverter.FeatureSpec(
            dtype=f.dtype, rank=f.rank, sequence_dim=f.sequence_dim)
        for f in self.input_features
    }

  @property
  def MODEL_FEATURES(self):
    """Model feature name -> seqio FeatureSpec, built from `input_features`."""
    feature_specs_map = {}
    for f in self.input_features:
      name = self._model_feature_name(f.name)
      feature_specs_map[name] = seqio.FeatureConverter.FeatureSpec(
          dtype=f.dtype, rank=f.rank, sequence_dim=f.sequence_dim)
    return feature_specs_map

  @property
  def PACKING_FEATURE_DTYPES(self):
    """Always None; this converter rejects packed examples."""
    return None

  def _convert_features(self, ds: tf.data.Dataset,
                        input_lengths: Mapping[str, int]) -> tf.data.Dataset:
    """Convert the input dataset to an output dataset to be fed to the model.

    The conversion process involves two steps

    1. Each feature in the `input_lengths` is padded, except features whose
       declared rank is greater than 1, which are passed through untouched.
    2. "inputs" fields are mapped to the left encoder input and "targets" are
       mapped to right encoder input.

    Assume the input dataset has two examples each with "inputs" and "targets".

    ds = [{"inputs": [7, 8, 5, 1], "targets": [3, 9, 1]},
          {"inputs": [8, 4, 9, 3, 1], "targets": [4, 1]}]

    task_feature_lengths = {"inputs": 8, "targets": 4}

    First, the `inputs` are padded to length 8 and assigned to
    "left_encoder_input_tokens" field. The `targets` are processed similarly.

    converted_ds = [
    {
        "left_encoder_input_tokens": [7, 8, 5, 1, 0, 0, 0, 0],
        "right_encoder_input_tokens": [3, 9, 1, 0],
    },
    {
        "left_encoder_input_tokens": [8, 4, 9, 3, 1, 0, 0, 0],
        "right_encoder_input_tokens": [4, 1, 0, 0],
    },
    ]

    Args:
      ds: an input tf.data.Dataset to be converted.
      input_lengths: a mapping from a feature to its length

    Returns:
      ds: the converted dataset.

    Raises:
      ValueError: if the converter was created with `pack=True`.
    """

    def convert_example(
        features: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
      # Keeps only the declared features, renamed to their model names.
      return {
          self._model_feature_name(f.name): features[f.name]
          for f in self.input_features
      }

    if self.pack:
      raise ValueError(
          "Dual encoder only takes non packed examples at this moment."
      )

    # Stop padding features with rank > 1, since _pack_or_pad adds padding to
    # the first dimension instead of the last dimension. Those inputs should
    # already be padded and dense.
    task_features = self.TASK_FEATURES  # The property rebuilds a dict per use.
    high_rank = {
        name for name in input_lengths
        if name in task_features and task_features[name].rank > 1
    }
    if high_rank:
      # Build a filtered plain dict rather than mutating the caller's mapping;
      # when nothing is dropped the original mapping is passed through as-is.
      input_lengths = {
          name: length for name, length in input_lengths.items()
          if name not in high_rank
      }

    ds = self._pack_or_pad(ds, input_lengths)
    return ds.map(
        convert_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def get_model_feature_lengths(
      self, task_feature_lengths: Mapping[str, int]) -> Mapping[str, int]:
    """Define the length relationship between input and output features.

    Args:
      task_feature_lengths: a mapping from task feature name to length. It
        must contain an entry for every feature in `TASK_FEATURES`.

    Returns:
      A mapping from model feature name to the length of the corresponding
      task feature.

    Raises:
      ValueError: if the converter was created with `pack=True`.
    """
    model_feature_lengths = {
        self._model_feature_name(k): task_feature_lengths[k]
        for k in self.TASK_FEATURES
    }
    if self.pack:
      raise ValueError("Packing not supported")

    return model_feature_lengths
232 |
233 |
234 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # T5X Retrieval
2 |
3 | T5X Retrieval is a JAX implementation of T5 (Text-to-Text Transfer Transformer) optimized for retrieval applications.
4 | It is built on top of T5 on JAX, aka [T5X](https://github.com/google-research/t5x).
5 | This is targeted at Natural Language Understanding researchers as well as application developers who are aiming to use the latest T5-based Transformer models for search, retrieval and ranking applications, but in the JAX framework as opposed to TensorFlow.
6 |
7 | T5X Retrieval is an efficient training and evaluation framework that supports transformer-based neural retrieval and ranking models such as sentence encoders and dense retrieval models. It supports multi-pod large model training, large cross-batch negatives and the capability to initialize from any pre-trained model trained using T5X.
8 |
9 | This launch open sources the training and inference code, including references to TFDS for training data, actual model training code (Python JAX & Flaxformer), pre-trained models and basic inference example code. This end-to-end example model code is meant to accompany the SentenceT5 and Generalizable T5 Retrieval models, and includes the implementation and performance on relevant benchmarks.
10 |
11 |
12 | # What's here
13 |
14 | - configs/\*.gin - Model configurations
15 | - tasks.py - Task definitions that generate the dataset.
16 | - feature_converters.py - Converters that transform the task features from the dataset to model features
17 | - models.py - High-level models, such as DualEncoderDecoderModel, that take the feature converters' outputs as inputs.
18 |
19 | For more details about the training pipeline and task definitions, you can check out [T5X](https://github.com/google-research/t5x) and [Seqio](https://github.com/google/seqio).
20 |
21 | # Quickstart (Recommended)
22 |
23 | T5X Retrieval supports the training and evaluation options provided by [T5X](https://github.com/google-research/t5x). It can be run with [XManager](https://github.com/deepmind/xmanager) on
24 | [Vertex AI](https://cloud.google.com/vertex-ai), which is a platform for
25 | training that creates TPU instances and runs code on the TPUs.
26 |
27 | We briefly summarize the steps to quickly start the training and inference jobs. You can find more details in the [T5X Quickstart](https://github.com/google-research/t5x#quickstart-recommended).
28 |
29 | 0. [Create a GCP project](https://github.com/deepmind/xmanager#create-a-gcp-project-optional). Create the bucket to store data and models.
30 |
31 | 1. Follow the pre-requisites and directions to install [XManager](https://github.com/deepmind/xmanager).
32 |
33 | 2. [Optional] GCP projects come with 8 cores by default, which is enough to run one training experiment on a single TPU host. Request TPU quota as required if you want to run multi-host training or multiple runs in parallel.
34 |
35 | 3. Install all dependencies such as [T5X](https://github.com/google-research/t5x), [Flaxformer](https://github.com/google/flaxformer), [TFDS](https://github.com/tensorflow/datasets).
36 |
37 | 4. Launch the xmanager script located at `t5x/scripts/xm_launch.py`.
38 |
39 | As a running example, we use the [BEIR MS Marco dataset](https://github.com/beir-cellar/beir#beers-available-datasets).
40 |
41 | ```sh
42 | # Export GOOGLE_CLOUD_BUCKET_NAME to a proper value.
43 | export GOOGLE_CLOUD_BUCKET_NAME=...
44 | export TFDS_DATA_DIR=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x_retrieval/data
45 | export MODEL_DIR=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x_retrieval/$(date +%Y%m%d)
46 |
47 | # Install dependencies.
48 | git clone https://github.com/google-research/t5x /tmp/t5x
49 | git clone https://github.com/google-research/t5x_retrieval /tmp/t5x_retrieval
50 | git clone https://github.com/google/flaxformer /tmp/flaxformer
51 | git clone https://github.com/google/aqt.git /tmp/aqt
52 |
53 | cd /tmp/t5x/
54 |
55 | python3 t5x/scripts/xm_launch.py \
56 | --pip_install="apache_beam[gcp]" \
57 | --model_dir=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/msmarco_ft_$(date +%Y%m%d) \
58 | --tfds_data_dir=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/data \
59 | --project_dirs=/tmp/t5x_retrieval/t5x_retrieval,/tmp/flaxformer/flaxformer,/tmp/aqt/aqt \
60 | --gin_file=t5x_retrieval/configs/models/de_t5_base.gin \
61 | --gin.INITIAL_CHECKPOINT_PATH=\"gs://t5-data/pretrained_models/t5x/t5_base/checkpoint_999900\" \
62 | --gin_file=t5x_retrieval/configs/runs/finetune.gin \
63 | --gin.TRAIN_STEPS=1009900 \
64 | --gin.utils.create_learning_rate_scheduler.step_offset=999900 \
65 | --gin.utils.create_learning_rate_scheduler.warmup_steps=1000 \
66 | --gin.utils.create_learning_rate_scheduler.decay_factor=0.00000125 \
67 | --gin.USE_CACHED_TASKS=False \
68 | --gin.models.DualEncoderModel.use_negatives=False \
69 | --gin.train.eval_period=500 \
70 | --gin.utils.SaveCheckpointConfig.keep=10 \
71 | --gin.utils.SaveCheckpointConfig.period=500 \
72 | --gin.train/DatasetConfig.batch_size=512 \
73 | --gin.MIXTURE_OR_TASK_NAME="'beir_msmarco_retrieval'" \
74 | --gin.MIXTURE_OR_TASK_MODULE="'t5x_retrieval.tasks'" \
75 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 64, 'targets': 256}"
76 | ```
77 |
78 | Notes:
79 |
80 | - Check `gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/` for the output artifacts, which can be read by TensorBoard.
81 | - Add `--pip_install="apache_beam[gcp]"` to the script if you have not downloaded the dataset beforehand.
82 | - `TRAIN_STEPS = step_offset + real_train_steps`, where `step_offset` is the step of the loaded checkpoint and `real_train_steps` is the number of steps that the model will be trained for.
83 |
84 | # Models
85 |
86 | ## Sentence encoders
87 | **SentenceT5** is a family of high performing sentence encoders trained using T5X Retrieval. The sentenceT5 models encode text into high-dimensional vectors that can be used for text classification, semantic similarity, clustering and other natural language processing tasks.
88 |
89 | SentenceT5 models are built on top of the Text-To-Text Transfer Transformer (T5). It is trained on a variety of data sources and initialized from pre-trained T5 models with different model sizes as described in [1]. The input is variable-length English text and the output is a 768-dimensional vector. Note that there's no hard length limit for T5 (i.e., no 512 tokens limit as in BERT), but that it's been trained to produce good embeddings for approximately sentence length text.
90 |
91 | ### Metrics
92 |
93 | * We evaluate this model on the
94 | [SentEval](https://github.com/facebookresearch/SentEval) sentence
95 | representation benchmark.
96 |
97 | Transfer tasks | MR | CR | SUBJ | MPQA | SST | TREC | MRPC | Average
98 | :------------------------------------------------------------ | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ------:
99 | **ST5-Base** | 85.8 | 92.1 | 94.6 | 90.9 | 91.8 | 96.4 | 75.2 | 89.5
100 | [ST5-Large](https://tfhub.dev/google/sentence-t5/st5-large/1) | 88.9 | 93.5 | 95.4 | 91.5 | 94.2 | 96.2 | 77.1 | 91.0
101 | [ST5-3B](https://tfhub.dev/google/sentence-t5/st5-3b/1) | 89.9 | 94.1 | 95.9 | 91.6 | 94.8 | 96.2 | 77.9 | 91.5
102 | [ST5-11B](https://tfhub.dev/google/sentence-t5/st5-11b/1) | 90.8 | 94.4 | 96.3 | 91.7 | 94.8 | 95.4 | 77.9 | 91.6
103 |
104 |
105 |
106 | STS tasks | STS12 | STS13 | STS14 | STS15 | STS16 | STSb | SICK-R | Average
107 | :------------------------------------------------------------ | ----: | ----: | ----: | ----: | ----: | ---: | -----: | ------:
108 | **ST5-Base** | 78.1 | 85.8 | 82.2 | 87.5 | 84.0 | 86.0 | 79.8 | 83.3
109 | [ST5-Large](https://tfhub.dev/google/sentence-t5/st5-large/1) | 79.1 | 87.3 | 83.2 | 88.3 | 84.4 | 86.7 | 79.8 | 84.1
110 | [ST5-3B](https://tfhub.dev/google/sentence-t5/st5-3b/1) | 79.0 | 88.8 | 84.3 | 88.9 | 85.3 | 86.3 | 79.5 | 84.6
111 | [ST5-11B](https://tfhub.dev/google/sentence-t5/st5-11b/1) | 80.1 | 88.8 | 84.7 | 88.9 | 85.2 | 86.8 | 80.4 | 85.0
112 |
113 | More details about the evaluations can be found in the paper [1].
114 |
115 | ## Dense retrieval models
116 | The **Generalizable T5 Retrieval** models are dual encoders that encode two pieces of text into two dense
117 | vectors respectively [2]. This is typically used to encode a query and a document to
118 | compute their similarity for dense retrieval.
119 |
120 | GTR models are built on top of [T5](https://arxiv.org/pdf/1910.10683.pdf) (i.e.
121 | the Text-To-Text Transfer Transformer). The GTR-Base model employs a 12-layer
122 | transformer architecture, which is the same as the T5 base model. The model is
123 | first initialized from the pre-trained T5 checkpoint. It is then further
124 | pre-trained with a set of community question-answer pairs we collected. Finally,
125 | the model is fine-tuned on the [MS Marco](https://microsoft.github.io/msmarco/)
126 | dataset.
127 |
128 | The two encoders are [shared](https://arxiv.org/pdf/2204.07120.pdf) so the GTR model functions as a single text encoder.
129 | The input is variable-length English text and the output is a 768-dimensional
130 | vector.
131 |
132 | ### Metrics
133 |
134 | We evaluate on the [BEIR](https://github.com/UKPLab/beir) benchmark and report the Recall@100.
135 |
136 | Dataset \ Model | **GTR-Base** | [GTR-Large](https://tfhub.dev/google/gtr/gtr-large/1) | [GTR-XL](https://tfhub.dev/google/gtr/gtr-xl/1) | [GTR-XXL](https://tfhub.dev/google/gtr/gtr-xxl/1)
137 | ---------------- | ------------ | ----------------------------------------------------- | ----------------------------------------------- | -------------------------------------------------
138 | MS MARCO | 0.898 | 0.908 | 0.911 | 0.916
139 | Trec-Covid | 0.411 | 0.434 | 0.457 | 0.407
140 | BioASQ | 0.441 | 0.490 | 0.483 | 0.483
141 | NFCorpus | 0.275 | 0.298 | 0.318 | 0.300
142 | NQ | 0.893 | 0.930 | 0.936 | 0.946
143 | HotpotQA | 0.676 | 0.725 | 0.739 | 0.752
144 | FiQA-2018 | 0.670 | 0.742 | 0.755 | 0.780
145 | Signal-1M | 0.263 | 0.261 | 0.268 | 0.268
146 | Trec-News | 0.475 | 0.525 | 0.512 | 0.544
147 | Robust04 | 0.324 | 0.365 | 0.364 | 0.372
148 | ArguAna | 0.974 | 0.978 | 0.980 | 0.983
149 | Touché-2020 | 0.281 | 0.282 | 0.297 | 0.301
150 | Quora | 0.996 | 0.996 | 0.997 | 0.997
151 | DBPedia-entity | 0.418 | 0.480 | 0.480 | 0.494
152 | SCIDOCS | 0.340 | 0.358 | 0.358 | 0.366
153 | Fever | 0.923 | 0.941 | 0.944 | 0.947
154 | Climate-Fever | 0.522 | 0.552 | 0.569 | 0.556
155 | SciFact | 0.872 | 0.899 | 0.911 | 0.900
156 | CQADupStack | 0.681 | 0.714 | 0.729 | 0.740
157 | Avg | 0.596 | 0.625 | 0.632 | 0.634
158 | Avg w/o MS MARCO | 0.580 | 0.609 | 0.616 | 0.619
159 |
160 |
161 | # Released Model Checkpoints
162 |
163 | We have released the following checkpoints for SentenceT5 and GTR pre-trained models:
164 |
165 | * **SentenceT5-Base** ([config](t5x_retrieval/configs/models/de_t5_base.gin), 110M parameters): [gs://t5-data/pretrained_models/t5x/retrieval/st5_base](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/retrieval/st5_base)
166 | * **SentenceT5-Large** ([config](t5x_retrieval/configs/models/de_t5_large.gin), 335M parameters): [gs://t5-data/pretrained_models/t5x/retrieval/st5_large](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/retrieval/st5_large/)
167 | * **SentenceT5-XL** ([config](t5x_retrieval/configs/models/de_t5_3B.gin), 1.24B parameters): [gs://t5-data/pretrained_models/t5x/retrieval/st5_xl](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/retrieval/st5_xl/)
168 | * **SentenceT5-XXL** ([config](t5x_retrieval/configs/models/de_t5_11B.gin), 4.8B parameters): [gs://t5-data/pretrained_models/t5x/retrieval/st5_xxl](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/retrieval/st5_xxl/)
169 | * **GTR-Base** ([config](t5x_retrieval/configs/models/de_t5_base.gin), 110M parameters): [gs://t5-data/pretrained_models/t5x/retrieval/gtr_base](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/retrieval/gtr_base/)
170 | * **GTR-Large** ([config](t5x_retrieval/configs/models/de_t5_large.gin), 335M parameters): [gs://t5-data/pretrained_models/t5x/retrieval/gtr_large](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/retrieval/gtr_large/)
171 | * **GTR-XL** ([config](t5x_retrieval/configs/models/de_t5_3B.gin), 1.24B parameters): [gs://t5-data/pretrained_models/t5x/retrieval/gtr_xl](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/retrieval/gtr_xl/)
172 | * **GTR-XXL** ([config](t5x_retrieval/configs/models/de_t5_11B.gin), 4.8B parameters): [gs://t5-data/pretrained_models/t5x/retrieval/gtr_xxl](https://console.cloud.google.com/storage/browser/t5-data/pretrained_models/t5x/retrieval/gtr_xxl/)
173 |
174 |
175 | # References
176 |
177 | [1] Jianmo Ni, Gustavo Hernández Ábrego, Noah Constant, Ji Ma, Keith B. Hall,
178 | Daniel Cer, Yinfei Yang.
179 | [Sentence-t5: Scalable sentence encoders from pre-trained text-to-text models.](https://arxiv.org/abs/2108.08877)
180 | ACL 2022.
181 |
182 | [2] Jianmo Ni, Chen Qu, Jing Lu, Zhuyun Dai, Gustavo Hernández Ábrego,
183 | Ji Ma, Vincent Zhao, Yi Luan, Keith B. Hall, Ming-wei Chang, Yinfei Yang.
184 | [Large Dual Encoders Are Generalizable Retrievers.](https://arxiv.org/abs/2112.07899)
185 | December 2021.
186 |
187 | This is not an officially supported Google product.
188 |
--------------------------------------------------------------------------------
/t5x_retrieval/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Retrieval Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Copyright 2022 Google LLC
16 | #
17 | # Licensed under the Apache License, Version 2.0 (the "License");
18 | # you may not use this file except in compliance with the License.
19 | # You may obtain a copy of the License at
20 | #
21 | # http://www.apache.org/licenses/LICENSE-2.0
22 | #
23 | # Unless required by applicable law or agreed to in writing, software
24 | # distributed under the License is distributed on an "AS IS" BASIS,
25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 | # See the License for the specific language governing permissions and
27 | # limitations under the License.
28 | """Utility functions for training and evaluation.
29 |
30 | """
31 |
32 | import time
33 | from typing import Callable, Optional, Sequence, Type
34 |
35 | import clu.metrics
36 | import flax
37 | from flax.training import common_utils
38 | import jax
39 | from jax import lax
40 | from jax.experimental import multihost_utils
41 | import jax.numpy as jnp
42 | import numpy as np
43 | import seqio
44 | import sklearn.metrics
45 | from t5x import checkpoints as t5x_checkpoints
46 | from t5x import losses as t5x_losses
47 | from t5x import state_utils as t5x_state_utils
48 | from t5x import utils as t5x_utils
49 | import tensorflow as tf
50 |
# Alias of the T5X dataset configuration type, used in the signatures below.
DatasetConfig = t5x_utils.DatasetConfig
# Concrete type of a JAX pytree structure, taken from the treedef of `None`;
# used in this module for type annotations.
PyTreeDef = type(jax.tree_util.tree_structure(None))
53 |
54 |
55 | # ===== Datasets ===== #
def get_batch_unmixed_dataset(
    cfg: DatasetConfig,
    shard_id: int,
    num_shards: int,
    feature_converter_cls: Type[seqio.FeatureConverter],
    num_epochs: Optional[int] = None,
    continue_from_last_checkpoint: bool = False) -> tf.data.Dataset:
  """Returns a dataset by sampling each batch from each single task.

  One batched dataset is built per task of the configured SeqIO mixture, and
  the per-task datasets are then sampled according to the mixture rates, so
  every emitted batch comes from exactly one task.

  Args:
    cfg: Dataset configuration. `cfg.mixture_or_task_name` must name a SeqIO
      Mixture. The config is temporarily pointed at each member task while the
      per-task datasets are built and is restored before returning.
    shard_id: Index of the data shard to read.
    num_shards: Total number of data shards; must divide `cfg.batch_size`.
    feature_converter_cls: SeqIO feature converter class forwarded to
      `t5x_utils.get_dataset_inner`.
    num_epochs: Ignored; the per-task datasets are always requested with
      `num_epochs=None` (repeat indefinitely).
    continue_from_last_checkpoint: Must be False; not supported here.

  Returns:
    A `tf.data.Dataset` that samples whole batches from the per-task datasets.

  Raises:
    ValueError: If `continue_from_last_checkpoint` is set, if the batch size is
      not divisible by `num_shards`, or if the configured name does not resolve
      to a SeqIO Mixture.
  """
  if continue_from_last_checkpoint:
    raise ValueError(
        '`continue_from_last_checkpoint` must be set to False as this is not '
        'supported by this dataset fn.')
  del continue_from_last_checkpoint

  if cfg.batch_size % num_shards:
    raise ValueError(
        f'Batch size ({cfg.batch_size}) must be divisible by number of '
        f'shards ({num_shards}).')

  shard_info = seqio.ShardInfo(index=shard_id, num_shards=num_shards)

  if cfg.seed is None:
    # Use a shared timestamp across devices as the seed.
    seed = multihost_utils.broadcast_one_to_all(np.int32(time.time()))
  else:
    seed = cfg.seed

  num_epochs = None  # repeat indefinitely.

  mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name)
  if not isinstance(mixture_or_task, seqio.Mixture):
    raise ValueError('Only SeqIO Mixture supports batch unmixed data accesss')

  # `cfg` belongs to the caller. It is re-pointed at every task below, so
  # remember the mixture name and put it back afterwards; otherwise the caller
  # would be left with a config that names the last task instead of the mixture.
  mixture_name = cfg.mixture_or_task_name
  datasets = []
  rates = []
  try:
    for task in mixture_or_task.tasks:
      cfg.mixture_or_task_name = task.name
      # Returns a batched dataset.
      datasets.append(
          t5x_utils.get_dataset_inner(cfg, shard_info, feature_converter_cls,
                                      seed, num_epochs))
      rates.append(mixture_or_task.get_rate(task))
  finally:
    cfg.mixture_or_task_name = mixture_name
  return tf.data.experimental.sample_from_datasets(datasets, rates)
99 |
100 |
101 | # ===== Losses ===== #
102 | # More details about alignment and uniformity loss can be found at
103 | # https://arxiv.org/pdf/2005.10242.pdf. They measure the quality of embeddings.
def compute_align_loss(x: jnp.array, y: jnp.array, alpha: int = 2):
  """Alignment loss: mean over rows of `||x_i - y_i||_2 ** alpha`.

  Args:
    x: array of embeddings; the L2 norm is taken along axis 1.
    y: array of embeddings paired row-by-row with `x`.
    alpha: exponent applied to each row-wise distance.

  Returns:
    Scalar mean of the powered distances.
  """
  pair_distances = jnp.linalg.norm(x - y, ord=2, axis=1)
  # `1.0 * alpha` turns the exponent into a float for `lax.pow`.
  powered = lax.pow(pair_distances, 1.0 * alpha)
  return jnp.mean(powered)
107 |
108 |
def compute_uniform_loss(xs: jnp.array, t: int = 2):
  """Uniformity loss: log of the mean Gaussian potential over all row pairs.

  Evaluates `log(mean_{i, j} exp(-t * ||xs[i] - xs[j]||^2))` over every ordered
  pair of rows of `xs`. Self-pairs (i == j) are included and each contributes
  `exp(0) = 1` to the mean.

  The squared distance is computed directly instead of squaring a square root:
  the derivative of `sqrt` at 0 is undefined, and every self-pair has distance
  0, so the sqrt form yields NaN gradients when this value is differentiated.

  Args:
    xs: array whose leading axis indexes the vectors to compare.
    t: scale applied to the squared distances inside the exponential.

  Returns:
    Scalar uniformity loss.
  """
  squared_distance = lambda x, y: jnp.sum((x - y)**2)
  pairwise_sq = jax.vmap(
      lambda x: jax.vmap(lambda y: squared_distance(x, y))(xs))(xs)
  return jnp.log(jnp.exp(-t * pairwise_sq).mean())
114 |
115 |
def sigmoid_cross_entropy_with_logits(logits: jnp.array, labels: jnp.array):
  """Compute binary cross entropy loss with logits input.

  Uses the numerically stable form
  `max(x, 0) - x * z + log1p(exp(-|x|))`, where `x` are the logits and `z` the
  labels. No reduction is applied.

  Args:
    logits: similarity scores
    labels: ground-truth labels

  Returns:
    element-wise cross-entropy loss
  """

  x, z = logits, labels
  # Follow the implementation in Tensorflow:
  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn_impl.py#L186
  # `log1p` keeps the tail term from rounding to exactly 0 when exp(-|x|) is
  # below float precision relative to 1.
  loss = jnp.maximum(x, 0) - x * z + jnp.log1p(jnp.exp(-jnp.absolute(x)))
  return loss
132 |
133 |
def apply_temperature(probs: jnp.array, temperature: float):
  """Sharpen or soften binary probabilities with a temperature.

  Both `probs` and `1 - probs` are raised to `1 / temperature` and the result
  is renormalized so the two outcomes sum to one again.

  Args:
    probs: probabilities of the positive outcome.
    temperature: positive temperature.

  Returns:
    The temperature-adjusted probabilities of the positive outcome.

  Raises:
    ValueError: if `temperature` is not positive.
  """
  if temperature <= 0:
    raise ValueError('Temperature must be positive.')
  exponent = 1.0 / temperature
  positive_mass = lax.pow(probs, exponent)
  negative_mass = lax.pow(1 - probs, exponent)
  total_mass = positive_mass + negative_mass
  return positive_mass / total_mass
142 |
143 |
def binary_cross_entropy_with_logits(logits: jnp.array,
                                     labels: jnp.array,
                                     weights=None,
                                     temperature=1.0):
  """Binary cross-entropy loss function.

  This loss is used for point-wise distillation where the goal is to match the
  point-wise logits for each example to match the gold labels. The temperature
  is commonly added on both the logits and the labels as described in previous
  distillation methods: https://arxiv.org/abs/1503.02531.

  Args:
    logits: similarity scores corresponding to all pairs in the batch, with
      shape [batch_size].
    labels: labels for each pair.
    weights: optional weights for each pair. When None, no weighting is
      applied.
    temperature: temperature for the distillation.

  Returns:
    binary cross-entropy loss for distillation (element-wise, not reduced).
  """
  logits = logits / temperature
  labels = apply_temperature(labels, temperature)
  loss = sigmoid_cross_entropy_with_logits(logits, labels)

  # Compare against None explicitly: `weights` is an array, and the truth value
  # of an array with more than one element is ambiguous (it raises), while a
  # single zero weight would wrongly be treated as "no weights".
  if weights is not None:
    loss = loss * weights
  # Multiply the loss by temperature^2, to keep it on the same scale as the
  # batch softmax loss.
  return (temperature**2) * loss
174 |
175 |
def sparse_labels_for_in_batch_cross_entropy(logits: jnp.array) -> jnp.array:
  """Returns `[0, 1, ..., num_rows - 1]`, i.e. the diagonal of |logits| as labels.

  Row i of |logits| is assumed to have its ground truth in column i.
  """
  num_rows = logits.shape[0]
  return jnp.arange(num_rows)
179 |
180 |
def in_batch_cross_entropy(
    logits: jnp.array,
    labels: Optional[jnp.array] = None,
    weights: Optional[jnp.array] = None,
    off_diagonal_positive_mask: Optional[jnp.array] = None,
    num_negatives: jnp.int32 = 0,
    label_smoothing: float = 0.0,
    reduce_fn: Optional[Callable[[jnp.array],
                                 jnp.array]] = jnp.mean) -> jnp.array:
  """In batch cross-entropy loss function.

  This corresponds to computing a softmax for each row (and optionally each
  column) of the similarities matrix, where the diagonal element corresponds
  to the only positive pair, and off-diagonal elements are random negatives.

  Args:
    logits: [batch_size, batch_size + sample_size]. Tensor of similarities
      between all pairs of a left and a right element in a batch, with shape
      [batch_size, batch_size + sample_size] where sample_size is the number of
      extra negative right examples, if any.
    labels: [batch_size, batch_size + sample_size]. If None, then this function
      generates one hot labels which assumes the diagonal elements correspond to
      the ground truth. `label_smoothing` is only applied to these generated
      labels.
    weights: [batch_size]. Weights for each pair (or row). If None, rows are
      not weighted.
    off_diagonal_positive_mask: a Boolean Tensor of [batch_size, batch_size].
      A Tensor masks the off-diagonal positives. With off-diagonal positives
      masked by 1, and others with 0, including the diagonal positives.
    num_negatives: number of negative examples. Default to 0. When positive,
      the mask is zero-padded by this many columns on the right.
    label_smoothing: label smoothing constant, used to determine the on and off
      values.
    reduce_fn: Callable on how to reduce losses to a scalar. If none, returns
      row-wise loss.

  Returns:
    [batch_size] array of cross entropy losses if reduce_fn is None. Otherwise,
    return a scalar of reduced row loss.

  Raises:
    ValueError: if `num_negatives` is positive but no
      `off_diagonal_positive_mask` is given.
  """
  if (num_negatives > 0) and (off_diagonal_positive_mask is None):
    raise ValueError(
        'num_negatives (positive integer) is used only in create '
        'off_diagonal_positive_mask.')

  normalizing_constant = 0.0
  if labels is None:
    num_classes = logits.shape[-1]
    sparse_labels = sparse_labels_for_in_batch_cross_entropy(logits)
    confidence, low_confidence, normalizing_constant = 1.0, 0.0, 0.0
    if num_classes > 1:
      confidence = 1.0 - label_smoothing
      low_confidence = (1.0 - confidence) / (num_classes - 1)
      # Entropy of the smoothed target distribution; subtracted from the row
      # loss below. The 1e-20 guards log(0) when there is no smoothing.
      normalizing_constant = -(
          confidence * jnp.log(confidence) +
          (num_classes - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
    labels = common_utils.onehot(
        sparse_labels,
        num_classes,
        on_value=confidence,
        off_value=low_confidence)

  if off_diagonal_positive_mask is None:
    row_loss, _ = t5x_losses.cross_entropy_with_logits(
        logits, labels, z_loss=0.0)
  else:
    # Force the diagonal of the mask to 0 so the true positives are never
    # masked out.
    off_diagonal_positive_mask = off_diagonal_positive_mask * (
        1 - jnp.identity(logits.shape[0]))
    if num_negatives > 0:
      off_diagonal_positive_mask = jnp.pad(
          off_diagonal_positive_mask, [(0, 0), (0, num_negatives)])
    # Masked entries get a large negative logit before the softmax.
    masked_logits = jnp.where(off_diagonal_positive_mask, -1e+9, logits)
    row_loss, _ = t5x_losses.cross_entropy_with_logits(
        masked_logits, labels, z_loss=0.0)

  row_loss = row_loss - normalizing_constant
  # Compare against None explicitly: `weights` is an array, and taking the
  # truth value of an array with more than one element (or of a traced value
  # under jit) raises instead of answering "were weights provided?".
  if weights is not None:
    row_loss = row_loss * weights

  return reduce_fn(row_loss) if reduce_fn is not None else row_loss
258 |
259 |
def get_off_diagonal_positive_mask(reference_labels: jnp.ndarray,
                                   return_dtype: jnp.dtype = bool):
  """Construct mask for off diagonal positives from reference labels.

  For a given batch, the off diagonal mask is constructed by comparing the
  reference labels across examples. E.g. if the batch size is 3, and
  the reference label is [11, 22, 11], the off diagonal positive mask is
  [[0, 0, 1],
   [0, 0, 0],
   [1, 0, 0]]
  indicating the 1st and the 3rd are mutually positive, as their labels are
  same.

  Args:
    reference_labels: a [batch] Tensor, which stores the reference information
      to construct the mask. Size-1 axes are squeezed away first, so e.g.
      [batch, 1] is accepted as well.
    return_dtype: dtype for returned mask.

  Returns:
    mask: a [batch, batch] Tensor of `return_dtype` (Boolean by default) for
    the off-diagonal positive mask.

  Raises:
    ValueError: if more than one reference label per example is provided.
  """
  reference_labels = jnp.squeeze(reference_labels)
  if reference_labels.ndim > 1:
    raise ValueError('Provide only 1 reference label per example in batch! '
                     'The reference_labels should be of the shape [batch].')
  # A batch of one squeezes down to a scalar; restore the batch axis so the
  # pairwise comparison below still yields a [1, 1] mask.
  reference_labels = jnp.atleast_1d(reference_labels)
  same_label = reference_labels[None, :] == reference_labels[:, None]
  # Clear the diagonal with explicit boolean ops. `a == b - identity` parses as
  # `a == (b - identity)`, which only clears the diagonal as long as
  # `label - 1 != label` holds for the label dtype.
  off_diagonal = jnp.logical_not(
      jnp.identity(reference_labels.shape[0], dtype=bool))
  mask = jnp.logical_and(same_label, off_diagonal)
  return mask.astype(return_dtype)
288 |
289 |
290 | # ===== Metrics ===== #
@flax.struct.dataclass
class AUC(clu.metrics.CollectingMetric.from_outputs(('labels', 'logits'))):
  """Collecting metric that reports ROC AUC over the gathered labels/logits."""

  def compute(self):
    """Returns the ROC AUC, or 0.0 when only one class is present."""
    labels = self.values['labels']
    num_positives = jnp.sum(labels)
    # ROC AUC is undefined unless both classes occur among the labels.
    if num_positives == 0 or num_positives == len(labels):
      return 0.0
    scores = self.values['logits']
    return sklearn.metrics.roc_auc_score(labels, scores)
301 |
302 |
def compute_auc(targets: jnp.array,
                predictions: jnp.array,
                targets_threshold=None):
  """Build the ROC AUC metric for a set of predictions.

  ROC - Receiver Operating Characteristic

  Predictions are centered on the midpoint of their range, scaled so the range
  spans roughly [-3, 3], and passed through a sigmoid before they are handed to
  the `AUC` metric.

  Args:
    targets: np.ndarray of targets, either 0 or 1, or continuous values.
    predictions: np.ndarray of predictions, any value.
    targets_threshold: float, if target values are continuous values, this
      threshold binarizes them (values below it become 0, all others 1).

  Returns:
    A dictionary with a single `auc-roc` entry holding the `AUC` metric.
  """
  if targets_threshold is not None:
    targets = jnp.array(targets)
    is_negative = targets < targets_threshold
    targets = jnp.where(is_negative,
                        jnp.zeros_like(targets, dtype=jnp.int32),
                        jnp.ones_like(targets, dtype=jnp.int32))

  lowest = jnp.min(predictions)
  highest = jnp.max(predictions)
  # The small epsilon avoids a division by zero when all predictions are equal.
  scale = 3.0 / (highest - lowest + 1e-6)
  squashed_predictions = jax.nn.sigmoid(
      scale * (2 * predictions - highest - lowest))
  rounded_targets = jnp.round(targets)
  auc_metric = AUC.from_model_output(
      logits=squashed_predictions, labels=rounded_targets)
  return {'auc-roc': auc_metric}
338 |
339 |
def compute_rr(logits: jnp.array, labels: jnp.array):
  """Compute Reciprocal Rank for in-batch examples.

  The sort is done with `jnp.argsort` so that the function also works on traced
  values, e.g. when called inside `jax.jit`.

  Args:
    logits: jnp.array of logits of shape [batch_size, batch_size]
    labels: jnp.array of indices indicating the positive example.

  Returns:
    An jnp.array of reciprocal rank of the positive example in-batch.
  """
  labels = jnp.expand_dims(labels, axis=-1)
  # Column indices of every row, ordered from highest to lowest logit.
  ranked_indices = jnp.argsort(-logits, axis=-1)
  # Each row holds its label exactly once, so there is one hit per row; the
  # last coordinate of a hit is the 0-based rank of the positive. `size` makes
  # the output shape static.
  rank = (
      jnp.argwhere(ranked_indices == labels,
                   size=ranked_indices.shape[0])[:, -1] + 1)
  return jnp.reciprocal(rank)
355 |
356 |
357 | # ===== Checkpoint ===== #
def partially_load_checkpoint(
    excluded_patterns: Sequence[str],
    require_all_rules_match: bool = True
) -> t5x_checkpoints.RestoreStateTransformationFn:
  """Load a checkpoint partially, used in exports to trim the output SavedModel graph.

  Args:
    excluded_patterns: Checkpoint Optimizer param patterns to exclude from the
      export.
    require_all_rules_match: Whether to verify that all the patterns match
      correctly to a path in the checkpoint.

  Returns:
    A RestoreStateTransformationFn that excludes the patterns specified in
    `excluded_patterns`.
  """
  # Every excluded pattern is paired with `None` as its assignment target.
  # NOTE(review): presumably `None` is how the t5x assignment map marks entries
  # to drop -- confirm against t5x.state_utils.apply_assignment_map.
  exclusion_rules = []
  for pattern in excluded_patterns:
    exclusion_rules.append((pattern, None))

  def _apply_exclusions(
      ckpt_optimizer_state: PyTreeDef,
      _: PyTreeDef,  # pylint: disable=unused-argument
      *,
      is_resuming: bool = False):
    """Apply the exclusion rules to the checkpoint optimizer state.

    Setting assignment maps in RestoreCheckpointConfig in load_t5x_checkpoint
    sets optimizer state to an empty dict, failing the assignments. The
    checkpoint state is therefore passed as both the source and the target.

    Args:
      ckpt_optimizer_state: Checkpoint param state.
      is_resuming: `True` iff this restore call is due to a job resuming after
        being temporarily stopped due to, for example, a preemption. Forwarded
        unchanged to `apply_assignment_map`.

    Returns:
      The result of transforming the checkpoint state dict.
    """
    return t5x_state_utils.apply_assignment_map(
        ckpt_optimizer_state,
        ckpt_optimizer_state,  # Optimizer State
        exclusion_rules,
        require_all_rules_match,
        is_resuming=is_resuming)

  return _apply_exclusions
403 |
404 |
def load_tower(side: str) -> t5x_checkpoints.RestoreStateTransformationFn:
  """Load a single `side` tower of an Asymmetric Dual Encoder.

  Args:
    side: The side of the tower to load. Available values are [left, right]

  Returns:
    A restore state transformation function that filters out the weights of the
    other tower (parameters matching the opposite side's name pattern). Only
    set if inference mode is set to `encode_{side}`.
  """
  assert side in ('left', 'right'), (
      f'Expected side to be one of [left, right], but is {side}')
  # Pattern of the parameters to drop, i.e. those of the *other* tower.
  other_tower_pattern = {
      'left': r'.*right_.*',
      'right': r'.*left_.*',
  }
  return partially_load_checkpoint(
      excluded_patterns=[other_tower_pattern[side]])
421 |
--------------------------------------------------------------------------------
/t5x_retrieval/models.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Retrieval Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Copyright 2022 Google LLC
16 | #
17 | # Licensed under the Apache License, Version 2.0 (the "License");
18 | # you may not use this file except in compliance with the License.
19 | # You may obtain a copy of the License at
20 | #
21 | # http://www.apache.org/licenses/LICENSE-2.0
22 | #
23 | # Unless required by applicable law or agreed to in writing, software
24 | # distributed under the License is distributed on an "AS IS" BASIS,
25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 | # See the License for the specific language governing permissions and
27 | # limitations under the License.
28 | """T5X Retrieval Models.
29 |
30 | This module uses Flaxformer modules to build a higher-level model structure and
31 | define methods for the loss computation as well as a train, prediction, and
32 | evaluation steps.
33 | """
34 | # pylint: disable=attribute-defined-outside-init,g-bare-generic,g-multiple-import
35 | from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union
36 |
37 | from flax import linen as nn
38 | from flax import optim
39 | from flax.core import scope as flax_scope
40 | from flax.training import common_utils
41 | import jax
42 | from jax import lax
43 | import jax.numpy as jnp
44 | import ml_collections
45 | import numpy as np
46 |
47 | import seqio
48 | from t5x import losses as t5x_losses
49 | from t5x import metrics as metrics_lib
50 | from t5x import models as t5x_models
51 | from t5x import utils as t5x_utils
52 | from t5x_retrieval import losses
53 | from t5x_retrieval import utils
54 | import tensorflow as tf
55 |
# Union of the array-like types used in the annotations of this module.
Array = Union[np.ndarray, jnp.ndarray, jax.pxla.ShardedDeviceArray, tf.Tensor]
DType = jnp.dtype
ConfigDict = ml_collections.ConfigDict
# Concrete type of a JAX pytree structure, taken from the treedef of `None`.
PyTreeDef = type(jax.tree_util.tree_structure(None))
Optimizer = optim.Optimizer


# Keys of the intermediates dict returned by `_similarity_batch` when
# `return_intermediates` is set.
LEFT_ENCODINGS = 'left_encodings'
RIGHT_ENCODINGS = 'right_encodings'
65 |
66 |
class DualEncoderBase(t5x_models.BaseTransformerModel):
  """Base class for dual encoder models.

  Implements inference through `score_batch` for the `encode` and `similarity`
  modes. `loss_weights` and `predict_batch_with_aux` are not supported and
  raise `NotImplementedError`. Subclasses provide `_compute_logits` and may
  override `score_batch` to support further inference modes.
  """

  FEATURE_CONVERTER_CLS: Callable[..., seqio.FeatureConverter]

  ALLOWED_INFERENCE_MODE = frozenset({'encode', 'similarity'})

  # TODO(b/262639556): Change loss_module from Optional to required once
  # loss-layers have been implemented
  def __init__(
      self,
      module: nn.Module,
      feature_converter_cls: Callable[[bool], seqio.FeatureConverter],
      input_vocabulary: seqio.Vocabulary,
      output_vocabulary: seqio.Vocabulary,
      optimizer_def: optim.OptimizerDef,
      inference_mode: str = 'encode',
      loss_module_factory: Optional[nn.Module] = None,
  ):
    """Initialization function.

    Args:
      module: Flax module.
      feature_converter_cls: SeqIO feature converter to apply to the dataset.
      input_vocabulary: Vocabulary for the input features.
      output_vocabulary: Vocabulary for the output features.
      optimizer_def: Optimizer.
      inference_mode: Inference mode, checked against
        `ALLOWED_INFERENCE_MODE` when `score_batch` is called.
      loss_module_factory: Optional zero-argument callable producing the loss
        module. When given, the produced module validates the features of
        `feature_converter_cls(False)`.
    """
    self.FEATURE_CONVERTER_CLS = feature_converter_cls  # pylint: disable=invalid-name
    self._inference_mode = inference_mode

    # TODO(b/262639556): Remove check once loss-layer is not Optional
    loss_module = None
    if loss_module_factory:
      loss_module = loss_module_factory()
      loss_module.validate_model_features(feature_converter_cls(False))
    self.loss_module = loss_module

    super().__init__(
        module=module,
        input_vocabulary=input_vocabulary,
        output_vocabulary=output_vocabulary,
        optimizer_def=optimizer_def)

  def get_initial_variables(
      self,
      rng: jnp.ndarray,
      input_shapes: Mapping[str, Array],
      input_types: Optional[Mapping[str, DType]] = None
  ) -> flax_scope.FrozenVariableDict:
    """Get the initial variables for an dual-encoder model."""
    if input_types is None:
      input_types = {}
    # Both towers are initialized with dummy inputs of the left tower's dtype.
    token_dtype = input_types.get('left_encoder_input_tokens', jnp.float32)
    left_shape = input_shapes['left_encoder_input_tokens']
    right_shape = input_shapes['right_encoder_input_tokens']
    variables = self.module.init(
        rng,
        jnp.ones(left_shape, token_dtype),
        jnp.ones(right_shape, token_dtype),
        enable_dropout=False)

    # TODO(b/262639556): Remove check once loss-layer is not Optional
    if self.loss_module:
      # Add the variables of the loss module to the module variables.
      variables = variables.copy(
          self.loss_module.get_initial_variables(rng, input_shapes,
                                                 input_types))

    return variables

  def loss_weights(self, batch: Mapping[str,
                                        jnp.ndarray]) -> Optional[jnp.ndarray]:
    """Not supported for dual encoders; always raises."""
    raise NotImplementedError('Not implemented for dual encoder.')

  def predict_batch_with_aux(
      self,
      params: Mapping[str, Array],
      batch: Mapping[str, jnp.ndarray],
      rng: Optional[jax.random.KeyArray] = None,
  ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
    """Not supported for dual encoders; always raises."""
    raise NotImplementedError(
        'Autoregressive prediction is not implemented for dual encoder.')

  def _encode_batch(self, params: Mapping[str, Array],
                    batch: Mapping[str, jnp.ndarray]) -> Array:
    """Encode the embeddings for the left inputs."""
    variables = {'params': params}
    left_tokens = batch['left_encoder_input_tokens']
    # Disable the dropout during inference.
    return self.module.apply(
        variables,
        left_tokens,
        enable_dropout=False,
        method=self.module.encode)

  def _similarity_batch(self,
                        params: Mapping[str, Array],
                        batch: Mapping[str, jnp.ndarray],
                        return_intermediates: bool = False) -> Array:
    """Score the similarity of the left and right inputs.

    Returns the logits, or a `(logits, intermediates)` pair when
    `return_intermediates` is set; the intermediates hold the left and right
    encodings, each wrapped in a 1-tuple.
    """
    left_encodings, right_encodings, logits = self.module.apply(
        {'params': params},
        batch['left_encoder_input_tokens'],
        batch['right_encoder_input_tokens'],
        enable_dropout=False)
    if not return_intermediates:
      return logits
    intermediates = {
        LEFT_ENCODINGS: (left_encodings,),
        RIGHT_ENCODINGS: (right_encodings,),
    }
    return logits, intermediates

  def score_batch(self,
                  params: Mapping[str, Array],
                  batch: Mapping[str, jnp.ndarray],
                  return_intermediates: bool = False) -> jnp.ndarray:
    """Model prediction for batch.

    Args:
      params: Model parameters.
      batch: A batch of inputs.
      return_intermediates: Whether to return intermediates. Only honored in
        `similarity` mode.

    Returns:
      an array of encodings or similarity scores (with optional intermediates).

    Raises:
      ValueError: if the configured inference mode is not allowed.
    """
    mode = self._inference_mode
    if mode not in self.ALLOWED_INFERENCE_MODE:
      raise ValueError(
          'Invalid `inference_mode`: %s. Supported inference mode: %s.' %
          (self._inference_mode, self.ALLOWED_INFERENCE_MODE))
    if mode == 'similarity':
      return self._similarity_batch(params, batch, return_intermediates)
    if mode == 'encode':
      return self._encode_batch(params, batch)
193 |
class DualEncoderModel(DualEncoderBase):
  """Model class for Dual Encoder.

  Adds training (`loss_fn`), optional hard negatives on the right tower, logit
  scaling/margin, and the `encode_left`, `encode_right` and
  `pointwise_similarity` inference modes on top of `DualEncoderBase`.
  """

  ALLOWED_INFERENCE_MODE = frozenset({
      'encode', 'encode_left', 'encode_right', 'similarity',
      'pointwise_similarity'
  })

  def __init__(
      self,
      module: nn.Module,
      feature_converter_cls: Callable[[bool], seqio.FeatureConverter],
      input_vocabulary: seqio.Vocabulary,
      output_vocabulary: seqio.Vocabulary,
      optimizer_def: optim.OptimizerDef,
      loss_module_factory: Optional[losses.DualEncoderLoss] = None,
      inference_mode: str = 'encode',
      use_negatives: bool = False,
      use_align_uniform: bool = False,
      logit_scale: float = 100,
      logit_margin: float = 0.0,
  ):
    """Initialization function.

    Args:
      module: Flax module.
      feature_converter_cls: SeqIO feature converters to apply to the dataset.
      input_vocabulary: Vocabulary for the input features.
      output_vocabulary: Vocabulary for the output features.
      optimizer_def: Optimizer.
      loss_module_factory: Factory to produce loss module.
      inference_mode: Inference mode (e.g. encode or similarity).
      use_negatives: Whether to use hard negatives. If True, the model encodes
        the additional feature for hard negatives on the right tower.
      use_align_uniform: Whether to compute alignment and uniformity metrics. If
        True, compute alignment and uniformity metrics.
      logit_scale: A factor for logits scaling.
      logit_margin: A constant for logits margin.
    """
    self._use_negatives = use_negatives
    self._use_align_uniform = use_align_uniform
    self._logit_scale = logit_scale
    self._logit_margin = logit_margin
    super(DualEncoderModel, self).__init__(
        module=module,
        feature_converter_cls=feature_converter_cls,
        input_vocabulary=input_vocabulary,
        output_vocabulary=output_vocabulary,
        optimizer_def=optimizer_def,
        inference_mode=inference_mode,
        loss_module_factory=loss_module_factory)

  def _compute_logits(
      self,
      params: Mapping[str, Any],
      batch: Mapping[str, jnp.ndarray],
      dropout_rng: Optional[jnp.ndarray] = None
  ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Computes logits via a forward pass of `self.module`.

    Args:
      params: Model parameters.
      batch: A batch of inputs. Must contain
        `right_negative_encoder_input_tokens` if and only if the model was
        built with `use_negatives=True`.
      dropout_rng: RNG for dropout. Dropout and the logit margin are only
        applied when it is provided (training mode).

    Returns:
      A tuple `(left_encodings, right_encodings, left_logits, right_logits)`,
      with both logits multiplied by `logit_scale`.

    Raises:
      ValueError: if the presence of negative inputs in `batch` does not agree
        with `use_negatives`.
    """
    # Dropout is provided only for the training mode.
    rngs = {'dropout': dropout_rng} if dropout_rng is not None else None

    # These two checks must actually raise; merely constructing the exception
    # lets mismatched inputs through (negatives silently ignored, or a bare
    # KeyError further down).
    if not self._use_negatives and 'right_negative_encoder_input_tokens' in batch:
      raise ValueError(
          'Invalid module. Please select `DualEncoderWithNegativesModel` for negative inputs.'
      )

    if self._use_negatives and 'right_negative_encoder_input_tokens' not in batch:
      raise ValueError(
          'Invalid inputs. Please prepare negative inputs for DualEncoderWithNegativesModel.'
      )

    if self._use_negatives:
      left_tokens = batch['left_encoder_input_tokens']
      right_positive_tokens = batch['right_encoder_input_tokens']
      right_negative_tokens = batch['right_negative_encoder_input_tokens']

      # left/right_encoder_input_tokens should be 2d tensor.
      assert left_tokens.ndim == 2
      assert right_positive_tokens.ndim == 2
      # right_negative_encoder_input_tokens can be a 2d tensor (when feature
      # spec set up for single negative) or a 3d tensor (when feature spec
      # set up for multiple negatives).
      assert right_negative_tokens.ndim == 2 or right_negative_tokens.ndim == 3

      # All tensors should have the same batch size.
      batch_size = right_positive_tokens.shape[0]
      assert left_tokens.shape[0] == batch_size
      assert right_negative_tokens.shape[0] == batch_size

      if right_negative_tokens.ndim == 3:
        # We have multiple negatives, so need to reshape the
        # right_negative_encoder_input_tokens.

        # Right positive and negative should have the same sequence length.
        right_seq_length = right_positive_tokens.shape[1]
        assert right_seq_length == right_negative_tokens.shape[2]

        # Flatten [batch, num_negatives, seq] into [batch * num_negatives, seq].
        num_negatives = right_negative_tokens.shape[1]
        right_negative_tokens = jnp.reshape(
            right_negative_tokens,
            (batch_size * num_negatives, right_seq_length))

      (left_encodings, right_encodings,
       logits), _ = self.module.apply({'params': params},
                                      left_tokens,
                                      right_positive_tokens,
                                      right_negative_tokens,
                                      enable_dropout=rngs is not None,
                                      rngs=rngs,
                                      mutable='dropout')

      # `left_logits` is of shape [B, B*(1+num_negatives)] that considers the
      # negatives while `right_logits` is in shape [B, B] that doesn't considers
      # negatives. `num_negatives` could be greater than 1 in the future.
      left_logits, right_logits = logits, jnp.dot(right_encodings,
                                                  left_encodings.transpose())
    else:
      (left_encodings, right_encodings, logits), _ = self.module.apply(
          {'params': params},
          batch['left_encoder_input_tokens'],
          batch['right_encoder_input_tokens'],
          enable_dropout=rngs is not None,
          rngs=rngs,
          mutable='dropout')
      left_logits, right_logits = logits, logits.transpose()

    left_logits *= self._logit_scale
    right_logits *= self._logit_scale

    # Only additive margin to the logits for training mode.
    # For details please check https://arxiv.org/abs/1902.08564. The tensor
    # shapes are not changed after scaling.
    if dropout_rng is not None and self._logit_margin != 0:
      # The margin is subtracted from the diagonal (positive-pair) entries only.
      left_logits = (
          left_logits - self._logit_margin *
          jnp.eye(N=left_logits.shape[0], M=left_logits.shape[1]))
      right_logits = (
          right_logits - self._logit_margin * jnp.eye(right_logits.shape[0]))

    return left_encodings, right_encodings, left_logits, right_logits

  def loss_fn(
      self,
      params: Mapping[str, Any],
      batch: Mapping[str, jnp.ndarray],
      dropout_rng: Optional[jnp.ndarray],
  ) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Loss function used for training with a cross-entropy loss.

    The loss and its metrics are produced by `self.loss_module`, which is used
    unconditionally here, so the model must have been built with a
    `loss_module_factory`. Alignment and uniformity values are added to the
    metrics only (not to the loss) when `use_align_uniform` is set.

    Args:
      params: Model parameters.
      batch: A batch of inputs.
      dropout_rng: RNG for dropout, or None to disable dropout.

    Returns:
      A tuple `(loss, metrics)`.
    """

    left_encodings, right_encodings, left_logits, right_logits = self._compute_logits(
        params, batch, dropout_rng)
    loss, metrics = self.loss_module.apply({'params': params},
                                           batch=batch,
                                           logits=left_logits,
                                           right_logits=right_logits)
    if self._use_align_uniform:
      align_loss = utils.compute_align_loss(left_encodings, right_encodings)
      uniform_loss = utils.compute_uniform_loss(
          left_encodings) + utils.compute_uniform_loss(right_encodings)
      metrics.update({
          'align_loss':
              metrics_lib.AveragePerStep.from_model_output(align_loss),
          'uniform_loss':
              metrics_lib.AveragePerStep.from_model_output(uniform_loss),
      })

    return loss, metrics

  def _encode_batch(self, params: Mapping[str, Array],
                    batch: Mapping[str, jnp.ndarray]) -> Array:
    """Encode the embeddings for the inputs.

    Reads the right inputs in `encode_right` mode and the left inputs in every
    other mode.
    """
    if self._inference_mode == 'encode_right':
      encoder_input_tokens = batch['right_encoder_input_tokens']
    else:
      encoder_input_tokens = batch['left_encoder_input_tokens']
    return self.module.apply(
        {'params': params},
        encoder_input_tokens,
        # Disable the dropout during inference.
        enable_dropout=False,
        method=self.module.encode)

  def score_batch(self,
                  params: Mapping[str, Array],
                  batch: Mapping[str, jnp.ndarray],
                  return_intermediates: bool = False) -> jnp.ndarray:
    """Model prediction for batch.

    Args:
      params: Model parameters.
      batch: A batch of inputs.
      return_intermediates: Whether to return intermediates. Ignored by the
        `encode*` modes.

    Returns:
      an array of encodings or similarity scores (with optional intermediates).
      In `pointwise_similarity` mode only the diagonal of the similarity matrix
      is returned.

    Raises:
      ValueError: if the configured inference mode is not allowed.
    """
    if self._inference_mode not in self.ALLOWED_INFERENCE_MODE:
      raise ValueError(
          'Invalid `inference_mode`: %s. Supported inference mode: %s.' %
          (self._inference_mode, self.ALLOWED_INFERENCE_MODE))
    if self._inference_mode.startswith('encode'):
      return self._encode_batch(params, batch)
    elif self._inference_mode == 'similarity':
      return self._similarity_batch(params, batch, return_intermediates)
    elif self._inference_mode == 'pointwise_similarity':
      if return_intermediates:
        logits, intermediates = (
            self._similarity_batch(params, batch, return_intermediates))
        return jnp.diagonal(logits), intermediates
      else:
        logits = self._similarity_batch(params, batch, return_intermediates)
        return jnp.diagonal(logits)
407 |
408 |
409 |
--------------------------------------------------------------------------------