├── .gitignore
├── LICENSE
├── README.md
├── create_data
│   ├── README.md
│   ├── __init__.py
│   ├── tfdatasets
│   │   ├── __init__.py
│   │   ├── alpaca
│   │   │   └── build.py
│   │   ├── coco_all
│   │   │   ├── __init__.py
│   │   │   └── build.py
│   │   └── flan
│   │       └── build.py
│   └── utils.py
├── demo.ipynb
├── demo
│   ├── __init__.py
│   ├── hifigan
│   │   ├── README.md
│   │   ├── checkpoints
│   │   │   ├── config.json
│   │   │   └── g_00930000
│   │   ├── config_v1.json
│   │   ├── config_v2.json
│   │   ├── config_v3.json
│   │   ├── env.py
│   │   ├── inference.py
│   │   ├── inference_e2e.py
│   │   ├── meldataset.py
│   │   ├── models.py
│   │   ├── requirements.txt
│   │   ├── resample.py
│   │   ├── test_mel.py
│   │   ├── test_mel2.py
│   │   ├── train copy.py
│   │   ├── train.py
│   │   └── utils.py
│   └── utils
│       ├── __init__.py
│       ├── audio_utils.py
│       └── video_utils.py
├── metadata
│   └── coco
│       └── coco_class_name_2017.json
├── setup.py
└── t5x
    ├── __init__.py
    ├── adafactor.py
    ├── adafactor_test.py
    ├── binary_search.py
    ├── binary_search_test.py
    ├── checkpoint_importer.py
    ├── checkpoint_importer_test.py
    ├── checkpoint_importer_vqgan.py
    ├── checkpoint_utils.py
    ├── checkpoint_utils_test.py
    ├── checkpoints.py
    ├── checkpoints_test.py
    ├── configs
    │   ├── __init__.py
    │   └── runs
    │       ├── __init__.py
    │       ├── debug.gin
    │       ├── eval.gin
    │       ├── export.gin
    │       ├── export_seqio.gin
    │       ├── finetune.gin
    │       ├── infer.gin
    │       ├── infer_from_tfexample_file.gin
    │       ├── loss.gin
    │       ├── multitask.gin
    │       ├── precompile.gin
    │       ├── pretrain.gin
    │       └── vit_vqgan.gin
    ├── decoding.py
    ├── eval.py
    ├── examples
    │   ├── __init__.py
    │   └── unified_io
    │       ├── __init__.py
    │       ├── audio_encoder.py
    │       ├── audio_vqgan.py
    │       ├── aux_fns.py
    │       ├── config.py
    │       ├── data
    │       │   ├── __init__.py
    │       │   ├── data_utils.py
    │       │   ├── mixtures.py
    │       │   ├── nlp_instruction_following.py
    │       │   ├── postprocessing.py
    │       │   ├── preprocessing.py
    │       │   ├── prompt_definition.py
    │       │   ├── prompt_dict.py
    │       │   ├── tasks.py
    │       │   └── visualization_utils.py
    │       ├── decoding.py
    │       ├── evaluator.py
    │       ├── image_encoder.py
    │       ├── image_vqgan.py
    │       ├── input_modalities.py
    │       ├── layers.py
    │       ├── metrics
    │       │   ├── grit_keypoint.py
    │       │   ├── grit_localization.py
    │       │   ├── grit_normal.py
    │       │   ├── grit_segmentation.py
    │       │   ├── grit_vqa.py
    │       │   ├── metrics.py
    │       │   ├── rle.py
    │       │   └── utils.py
    │       ├── modality_processing.py
    │       ├── model_test.py
    │       ├── models.py
    │       ├── network.py
    │       ├── network_test.py
    │       ├── packing.py
    │       ├── packing_test.py
    │       ├── perceiver.py
    │       ├── scripts
    │       │   ├── __init__.py
    │       │   └── dataset_visualize.py
    │       ├── seq_features.py
    │       ├── seq_features_test.py
    │       ├── t5_1_1
    │       │   ├── base.gin
    │       │   ├── eval
    │       │   │   └── vision_language.gin
    │       │   ├── finetune
    │       │   │   └── refexp.gin
    │       │   ├── large.gin
    │       │   ├── tiny.gin
    │       │   ├── xl.gin
    │       │   └── xxl.gin
    │       ├── target_modalities.py
    │       ├── test_utils.py
    │       ├── utils.py
    │       └── vocabularies.py
    ├── export.py
    ├── export_lib.py
    ├── gin_utils.py
    ├── gin_utils_test.py
    ├── losses.py
    ├── losses_test.py
    ├── main.py
    ├── metrics.py
    ├── metrics_test.py
    ├── models.py
    ├── models_test.py
    ├── optimizers.py
    ├── optimizers_test.py
    ├── partitioning.py
    ├── partitioning_test.py
    ├── precompile.py
    ├── state_utils.py
    ├── state_utils_test.py
    ├── test_utils.py
    ├── testdata
    │   ├── mtf_tiny_t5
    │   │   ├── checkpoint
    │   │   ├── graph.pbtxt
    │   │   ├── model-info.txt
    │   │   ├── model.ckpt-0.data-00000-of-00002
    │   │   ├── model.ckpt-0.data-00001-of-00002
    │   │   ├── model.ckpt-0.index
    │   │   ├── model.ckpt-0.meta
    │   │   └── operative_config.gin
    │   ├── pinned_ckpt_dir
    │   │   └── PINNED
    │   └── test_t5_tiny.checkpoint_0
    ├── train.py
    ├── train_eval.py
    ├── train_state.py
    ├── train_state_test.py
    ├── trainer.py
    ├── trainer_test.py
    ├── utils.py
    ├── utils_test.py
    └── version.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. Here the more
159 | # nuclear option is taken and the entire .idea folder is ignored.
160 | .idea/
161 |
162 | *.pyc
163 | .vscode/
164 | wandb/*
165 |
--------------------------------------------------------------------------------
/create_data/README.md:
--------------------------------------------------------------------------------
1 | # Unified-IO-2 Datasets
2 |
3 | This directory contains code to build datasets in a format
4 | that can be consumed by UnifiedIO.
5 |
6 | Some of these scripts require additional dependencies, which can be installed with:
7 |
8 | ```
9 | python3 -m pip install -e '.[data]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -f https://storage.googleapis.com/jax-releases/jax_releases.html
10 | ```
--------------------------------------------------------------------------------
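
*Note:* the builder scripts in this directory are run directly, with a TFDS `data_dir` as the target; for example (paths are illustrative):

```
python3 create_data/tfdatasets/alpaca/build.py /path/to/tfds_data_dir
python3 create_data/tfdatasets/flan/build.py CoT /path/to/tfds_data_dir
```
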
/create_data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/unified-io-2/502ac4d81239f82c891a9f412b000c3c8d4e2946/create_data/__init__.py
--------------------------------------------------------------------------------
/create_data/tfdatasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/unified-io-2/502ac4d81239f82c891a9f412b000c3c8d4e2946/create_data/tfdatasets/__init__.py
--------------------------------------------------------------------------------
/create_data/tfdatasets/alpaca/build.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import tensorflow as tf
4 | import tensorflow_datasets as tfds
5 |
6 |
7 | class Alpaca(tfds.core.GeneratorBasedBuilder):
8 | VERSION = tfds.core.Version('1.0.0')
9 |
10 | def _info(self) -> tfds.core.DatasetInfo:
11 | features = tfds.features.FeaturesDict(dict(
12 | example_num=tfds.features.Tensor(shape=(), dtype=tf.int32),
13 | instruction=tfds.features.Tensor(shape=(), dtype=tf.string),
14 | input=tfds.features.Tensor(shape=(), dtype=tf.string),
15 | output=tfds.features.Tensor(shape=(), dtype=tf.string),
16 | ))
17 | return tfds.core.DatasetInfo(builder=self, features=features)
18 |
19 | def _split_generators(self, dl_manager: tfds.download.DownloadManager):
20 | return {'train': self._generate_examples()}
21 |
22 | def _generate_examples(self):
23 | import datasets
24 | ds = datasets.load_dataset("yahma/alpaca-cleaned")["train"]
25 | for ix, ex in enumerate(ds):
26 | ex["example_num"] = ix
27 | yield ix, ex
28 |
29 |
30 | def main():
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument("data_dir")
33 | args = parser.parse_args()
34 |
35 | builder = Alpaca(data_dir=args.data_dir)
36 | builder.download_and_prepare()
37 |
38 |
39 | if __name__ == '__main__':
40 | main()
--------------------------------------------------------------------------------
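
Once built, the dataset can be read back with the same builder class; a minimal sketch, assuming the package is importable and using an illustrative `data_dir`:

```python
from create_data.tfdatasets.alpaca.build import Alpaca

# Read the prepared dataset back with the builder that created it.
builder = Alpaca(data_dir="/path/to/tfds_data_dir")
ds = builder.as_dataset(split="train")
for ex in ds.take(1):
    print(ex["instruction"].numpy().decode())
```
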
/create_data/tfdatasets/coco_all/__init__.py:
--------------------------------------------------------------------------------
1 | """coco_all dataset."""
2 |
3 | from .build import CocoAll
4 |
--------------------------------------------------------------------------------
/create_data/tfdatasets/flan/build.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from collections import defaultdict
4 | from os.path import join
5 | import tensorflow as tf
6 | import tensorflow_datasets as tfds
7 | import numpy as np
8 | from PIL import Image
9 |
10 |
11 | MAPPING = {
12 | "Flan2021": "conceptofmind/flan2021_submix_original",
13 | "T0": "conceptofmind/t0_submix_original",
14 | "NIv2": "conceptofmind/niv2_submix_original",
15 | "CoT": "conceptofmind/cot_submix_original",
16 | "Dialog": "conceptofmind/dialog_submix_original",
17 | }
18 |
19 |
20 | class FLAN(tfds.core.GeneratorBasedBuilder):
21 | VERSION = tfds.core.Version('1.0.0')
22 |
23 | def __init__(self, src, **kwargs):
24 | self.src = src
25 | self.__class__.name = f"FLANv2-{src}"  # register each submix under its own TFDS dataset name
26 | super().__init__(**kwargs)
27 |
28 | def _info(self) -> tfds.core.DatasetInfo:
29 | features = tfds.features.FeaturesDict(dict(
30 | example_num=tfds.features.Tensor(shape=(), dtype=tf.int32),
31 | inputs=tfds.features.Tensor(shape=(), dtype=tf.string),
32 | targets=tfds.features.Tensor(shape=(), dtype=tf.string),
33 | task_source=tfds.features.Tensor(shape=(), dtype=tf.string),
34 | task_name=tfds.features.Tensor(shape=(), dtype=tf.string),
35 | template_type=tfds.features.Tensor(shape=(), dtype=tf.string),
36 | ))
37 | return tfds.core.DatasetInfo(
38 | builder=self,
39 | features=features,
40 | )
41 |
42 | def _split_generators(self, dl_manager: tfds.download.DownloadManager):
43 | return {
44 | 'train': self._generate_examples(),
45 | }
46 |
47 | def _generate_examples(self):
48 | import datasets
49 | ds = datasets.load_dataset(MAPPING[self.src])["train"]
50 | for ix, ex in enumerate(ds):
51 | example = dict(example_num=ix)
52 | example.update({k: ex[k] for k in
53 | ["inputs", "targets", "task_source", "task_name", "template_type"]})
54 | yield ix, example
55 |
56 |
57 | def main():
58 | parser = argparse.ArgumentParser()
59 | parser.add_argument("kind")
60 | parser.add_argument("data_dir")
61 | args = parser.parse_args()
62 |
63 | builder = FLAN(args.kind, data_dir=args.data_dir)
64 | builder.download_and_prepare()
65 |
66 |
67 | if __name__ == '__main__':
68 | main()
--------------------------------------------------------------------------------
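
Because the constructor rewrites `self.__class__.name` to `FLANv2-{src}`, each submix is registered as a separate TFDS dataset. A sketch of preparing all five (path illustrative, assuming the module is importable):

```python
from create_data.tfdatasets.flan.build import FLAN, MAPPING

for kind in MAPPING:  # Flan2021, T0, NIv2, CoT, Dialog
    builder = FLAN(kind, data_dir="/path/to/tfds_data_dir")
    builder.download_and_prepare()  # prepared under the name FLANv2-<kind>
```
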
/demo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/unified-io-2/502ac4d81239f82c891a9f412b000c3c8d4e2946/demo/__init__.py
--------------------------------------------------------------------------------
/demo/hifigan/README.md:
--------------------------------------------------------------------------------
1 | # HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis
2 |
3 | ### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae
4 |
5 | In our [paper](https://arxiv.org/abs/2010.05646),
6 | we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
7 | We provide our implementation and pretrained models as open source in this repository.
8 |
9 | **Abstract :**
10 | Several recent works on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms.
11 | Although such methods improve the sampling efficiency and memory usage,
12 | their sample quality has not yet reached that of autoregressive and flow-based generative models.
13 | In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis.
14 | As speech audio consists of sinusoidal signals with various periods,
15 | we demonstrate that modeling periodic patterns of audio is crucial for enhancing sample quality.
16 | A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method
17 | demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than
18 | real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen
19 | speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times
20 | faster than real-time on CPU with comparable quality to an autoregressive counterpart.
21 |
22 | Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples.
23 |
24 |
25 | ## Pre-requisites
26 | 1. Python >= 3.6
27 | 2. Clone this repository.
28 | 3. Install Python requirements. Please refer to [requirements.txt](requirements.txt).
29 | 4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/),
30 | then move all wav files to `LJSpeech-1.1/wavs`.
31 |
32 |
33 | ## Training
34 | ```
35 | python train.py --config config_v1.json
36 | ```
37 | To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
38 | Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
39 | You can change the path by adding `--checkpoint_path` option.
40 |
41 | Validation loss during training with V1 generator.
42 | 
43 |
44 | ## Pretrained Model
45 | You can also use pretrained models we provide.
46 | [Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
47 | Details of each folder are as follows:
48 |
49 | |Folder Name|Generator|Dataset|Fine-Tuned|
50 | |------|---|---|---|
51 | |LJ_V1|V1|LJSpeech|No|
52 | |LJ_V2|V2|LJSpeech|No|
53 | |LJ_V3|V3|LJSpeech|No|
54 | |LJ_FT_T2_V1|V1|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))|
55 | |LJ_FT_T2_V2|V2|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))|
56 | |LJ_FT_T2_V3|V3|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))|
57 | |VCTK_V1|V1|VCTK|No|
58 | |VCTK_V2|V2|VCTK|No|
59 | |VCTK_V3|V3|VCTK|No|
60 | |UNIVERSAL_V1|V1|Universal|No|
61 |
62 | We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets.
63 |
64 | ## Fine-Tuning
65 | 1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
66 | The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.
67 | Example:
68 | ```
69 | Audio File : LJ001-0001.wav
70 | Mel-Spectrogram File : LJ001-0001.npy
71 | ```
72 | 2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.
73 | 3. Run the following command.
74 | ```
75 | python train.py --fine_tuning True --config config_v1.json
76 | ```
77 | For other command line options, please refer to the training section.
78 |
79 |
80 | ## Inference from wav file
81 | 1. Make `test_files` directory and copy wav files into the directory.
82 | 2. Run the following command.
83 | ```
84 | python inference.py --checkpoint_file [generator checkpoint file path]
85 | ```
86 | Generated wav files are saved in `generated_files` by default.
87 | You can change the path by adding `--output_dir` option.
88 |
89 |
90 | ## Inference for end-to-end speech synthesis
91 | 1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.
92 | You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2),
93 | [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth.
94 | 2. Run the following command.
95 | ```
96 | python inference_e2e.py --checkpoint_file [generator checkpoint file path]
97 | ```
98 | Generated wav files are saved in `generated_files_from_mel` by default.
99 | You can change the path by adding `--output_dir` option.
100 |
101 |
102 | ## Acknowledgements
103 | We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips)
104 | and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this.
105 |
106 |
--------------------------------------------------------------------------------
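
As a companion to the fine-tuning instructions above, a small sketch (directory names taken from the README) that checks every wav file has a matching `.npy` mel-spectrogram:

```python
from pathlib import Path

# Fine-tuning expects ft_dataset/<name>.npy for every LJSpeech-1.1/wavs/<name>.wav.
wavs = {p.stem for p in Path("LJSpeech-1.1/wavs").glob("*.wav")}
mels = {p.stem for p in Path("ft_dataset").glob("*.npy")}
missing = sorted(wavs - mels)
print(f"{len(missing)} wav files have no matching mel-spectrogram:", missing[:5])
```
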
/demo/hifigan/checkpoints/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "2",
3 | "num_gpus": 8,
4 | "batch_size": 512,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.999,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,8,4],
12 | "upsample_kernel_sizes": [16,16,8],
13 | "upsample_initial_channel": 256,
14 | "resblock_kernel_sizes": [3,5,7],
15 | "resblock_dilation_sizes": [[1,2], [2,6], [3,12]],
16 |
17 | "segment_size": 8192,
18 | "num_mels": 128,
19 | "num_freq": 1025,
20 | "n_fft": 1024,
21 | "hop_size": 256,
22 | "win_size": 1024,
23 |
24 | "sampling_rate": 16000,
25 |
26 | "fmin": 0,
27 | "fmax": 8000,
28 | "fmax_for_loss": null,
29 |
30 | "num_workers": 8,
31 |
32 | "dist_config": {
33 | "dist_backend": "nccl",
34 | "dist_url": "tcp://127.0.0.1:52111",
35 | "world_size": 1
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
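
Two invariants tie the numbers in this config together: the generator's `upsample_rates` must multiply out to `hop_size` (each mel frame is upsampled back to one hop of waveform, here 8·8·4 = 256 samples), and `segment_size` should be a whole number of hops (8192 / 256 = 32 mel frames per training segment). A quick sanity-check sketch:

```python
import json
import math

with open("demo/hifigan/checkpoints/config.json") as f:
    h = json.load(f)

assert math.prod(h["upsample_rates"]) == h["hop_size"]  # 8 * 8 * 4 == 256
assert h["segment_size"] % h["hop_size"] == 0            # 8192 == 32 hops
print(h["sampling_rate"] / h["hop_size"], "mel frames per second")  # 62.5
```
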
/demo/hifigan/checkpoints/g_00930000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/unified-io-2/502ac4d81239f82c891a9f412b000c3c8d4e2946/demo/hifigan/checkpoints/g_00930000
--------------------------------------------------------------------------------
/demo/hifigan/config_v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "1",
3 | "num_gpus": 0,
4 | "batch_size": 2,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.999,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,8,2,2],
12 | "upsample_kernel_sizes": [16,16,4,4],
13 | "upsample_initial_channel": 512,
14 | "resblock_kernel_sizes": [3,7,11],
15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16 |
17 | "segment_size": 75264,
18 | "num_mels": 64,
19 | "n_fft": 1536,
20 | "hop_size": 588,
21 | "win_size": 1536,
22 |
23 | "sampling_rate": 22050,
24 |
25 | "fmin": 20.0,
26 | "fmax": 11025.0,
27 | "fmax_for_loss": 11025.0,
28 |
29 | "num_workers": 0,
30 |
31 | "dist_config": {
32 | "dist_backend": "nccl",
33 | "dist_url": "tcp://localhost:54321",
34 | "world_size": 1
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
/demo/hifigan/config_v2.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "1",
3 | "num_gpus": 0,
4 | "batch_size": 16,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.999,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,8,2,2],
12 | "upsample_kernel_sizes": [16,16,4,4],
13 | "upsample_initial_channel": 128,
14 | "resblock_kernel_sizes": [3,7,11],
15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16 |
17 | "segment_size": 8192,
18 | "num_mels": 80,
19 | "num_freq": 1025,
20 | "n_fft": 1024,
21 | "hop_size": 256,
22 | "win_size": 1024,
23 |
24 | "sampling_rate": 22050,
25 |
26 | "fmin": 0,
27 | "fmax": 8000,
28 | "fmax_for_loss": null,
29 |
30 | "num_workers": 4,
31 |
32 | "dist_config": {
33 | "dist_backend": "nccl",
34 | "dist_url": "tcp://localhost:54321",
35 | "world_size": 1
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/demo/hifigan/config_v3.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "2",
3 | "num_gpus": 8,
4 | "batch_size": 512,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.999,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,8,4],
12 | "upsample_kernel_sizes": [16,16,8],
13 | "upsample_initial_channel": 256,
14 | "resblock_kernel_sizes": [3,5,7],
15 | "resblock_dilation_sizes": [[1,2], [2,6], [3,12]],
16 |
17 | "segment_size": 8192,
18 | "num_mels": 128,
19 | "num_freq": 1025,
20 | "n_fft": 1024,
21 | "hop_size": 256,
22 | "win_size": 1024,
23 |
24 | "sampling_rate": 16000,
25 |
26 | "fmin": 0,
27 | "fmax": 8000,
28 | "fmax_for_loss": null,
29 |
30 | "num_workers": 8,
31 |
32 | "dist_config": {
33 | "dist_backend": "nccl",
34 | "dist_url": "tcp://127.0.0.1:52111",
35 | "world_size": 1
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/demo/hifigan/env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 |
5 | class AttrDict(dict):
6 | def __init__(self, *args, **kwargs):
7 | super(AttrDict, self).__init__(*args, **kwargs)
8 | self.__dict__ = self
9 |
10 |
11 | def build_env(config, config_name, path):
12 | t_path = os.path.join(path, config_name)
13 | if config != t_path:
14 | os.makedirs(path, exist_ok=True)
15 | shutil.copyfile(config, os.path.join(path, config_name))
16 |
--------------------------------------------------------------------------------
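
A minimal sketch of how `AttrDict` and `build_env` are used together (the config filename is illustrative; the same pattern appears in `inference.py` and `train.py`):

```python
import json

from env import AttrDict, build_env

# Expose config keys as attributes, e.g. h.n_fft instead of h["n_fft"].
with open("config_v1.json") as f:
    h = AttrDict(json.load(f))

# Copy the config next to the checkpoints so each run is self-describing.
build_env("config_v1.json", "config.json", "cp_hifigan")
print(h.sampling_rate, h.num_mels)
```
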
/demo/hifigan/inference.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 |
3 | import glob
4 | import os
5 | import argparse
6 | import json
7 | import torch
8 | from scipy.io.wavfile import write
9 | from env import AttrDict
10 | from meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav
11 | from models import Generator
12 |
13 | h = None
14 | device = None
15 |
16 |
17 | def load_checkpoint(filepath, device):
18 | assert os.path.isfile(filepath)
19 | print("Loading '{}'".format(filepath))
20 | checkpoint_dict = torch.load(filepath, map_location=device)
21 | print("Complete.")
22 | return checkpoint_dict
23 |
24 |
25 | def get_mel(x):
26 | return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
27 |
28 |
29 | def scan_checkpoint(cp_dir, prefix):
30 | pattern = os.path.join(cp_dir, prefix + '*')
31 | cp_list = glob.glob(pattern)
32 | if len(cp_list) == 0:
33 | return ''
34 | return sorted(cp_list)[-1]
35 |
36 |
37 | def inference(a):
38 | generator = Generator(h).to(device)
39 |
40 | state_dict_g = load_checkpoint(a.checkpoint_file, device)
41 | generator.load_state_dict(state_dict_g['generator'])
42 |
43 | filelist = os.listdir(a.input_wavs_dir)
44 |
45 | os.makedirs(a.output_dir, exist_ok=True)
46 |
47 | generator.eval()
48 | generator.remove_weight_norm()
49 | with torch.no_grad():
50 | for i, filname in enumerate(filelist):
51 | wav, sr = load_wav(os.path.join(a.input_wavs_dir, filname))
52 | wav = wav / MAX_WAV_VALUE
53 | wav = torch.FloatTensor(wav).to(device)
54 | x = get_mel(wav.unsqueeze(0))
55 | y_g_hat = generator(x)
56 | audio = y_g_hat.squeeze()
57 | audio = audio * MAX_WAV_VALUE
58 | audio = audio.cpu().numpy().astype('int16')
59 |
60 | output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + '_generated.wav')
61 | write(output_file, h.sampling_rate, audio)
62 | print(output_file)
63 |
64 |
65 | def main():
66 | print('Initializing Inference Process..')
67 |
68 | parser = argparse.ArgumentParser()
69 | parser.add_argument('--input_wavs_dir', default='test_files')
70 | parser.add_argument('--output_dir', default='generated_files')
71 | parser.add_argument('--checkpoint_file', required=True)
72 | a = parser.parse_args()
73 |
74 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json')
75 | with open(config_file) as f:
76 | data = f.read()
77 |
78 | global h
79 | json_config = json.loads(data)
80 | h = AttrDict(json_config)
81 |
82 | torch.manual_seed(h.seed)
83 | global device
84 | if torch.cuda.is_available():
85 | torch.cuda.manual_seed(h.seed)
86 | device = torch.device('cuda')
87 | else:
88 | device = torch.device('cpu')
89 |
90 | inference(a)
91 |
92 |
93 | if __name__ == '__main__':
94 | main()
95 |
96 |
--------------------------------------------------------------------------------
/demo/hifigan/inference_e2e.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 |
3 | import glob
4 | import os
5 | import numpy as np
6 | import argparse
7 | import json
8 | import torch
9 | from scipy.io.wavfile import write
10 | from env import AttrDict
11 | from meldataset import MAX_WAV_VALUE
12 | from models import Generator
13 |
14 | h = None
15 | device = None
16 |
17 |
18 | def load_checkpoint(filepath, device):
19 | assert os.path.isfile(filepath)
20 | print("Loading '{}'".format(filepath))
21 | checkpoint_dict = torch.load(filepath, map_location=device)
22 | print("Complete.")
23 | return checkpoint_dict
24 |
25 |
26 | def scan_checkpoint(cp_dir, prefix):
27 | pattern = os.path.join(cp_dir, prefix + '*')
28 | cp_list = glob.glob(pattern)
29 | if len(cp_list) == 0:
30 | return ''
31 | return sorted(cp_list)[-1]
32 |
33 |
34 | def inference(a):
35 | generator = Generator(h).to(device)
36 |
37 | state_dict_g = load_checkpoint(a.checkpoint_file, device)
38 | generator.load_state_dict(state_dict_g['generator'])
39 |
40 | filelist = os.listdir(a.input_mels_dir)
41 |
42 | os.makedirs(a.output_dir, exist_ok=True)
43 |
44 | generator.eval()
45 | generator.remove_weight_norm()
46 | with torch.no_grad():
47 | for i, filname in enumerate(filelist):
48 | x = np.load(os.path.join(a.input_mels_dir, filname))
49 | x = (x * 3.8312 - 5.0945)[:,:,0]  # undo the model's log-mel normalization (fixed scale/shift) and drop the trailing channel dim
50 | x = torch.FloatTensor(x).to(device)
51 | y_g_hat = generator(x)
52 | audio = y_g_hat.squeeze()
53 | audio = audio * MAX_WAV_VALUE
54 | audio = audio.cpu().numpy().astype('int16')
55 |
56 | output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + '_generated_e2e.wav')
57 | write(output_file, h.sampling_rate, audio)
58 | print(output_file)
59 |
60 |
61 | def main():
62 | print('Initializing Inference Process..')
63 |
64 | parser = argparse.ArgumentParser()
65 | parser.add_argument('--input_mels_dir', default='test_mel_files')
66 | parser.add_argument('--output_dir', default='generated_files_from_mel')
67 | parser.add_argument('--checkpoint_file', required=True)
68 | a = parser.parse_args()
69 |
70 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json')
71 | with open(config_file) as f:
72 | data = f.read()
73 |
74 | global h
75 | json_config = json.loads(data)
76 | h = AttrDict(json_config)
77 |
78 | torch.manual_seed(h.seed)
79 | global device
80 | if torch.cuda.is_available():
81 | torch.cuda.manual_seed(h.seed)
82 | device = torch.device('cuda')
83 | else:
84 | device = torch.device('cpu')
85 |
86 | inference(a)
87 |
88 |
89 | if __name__ == '__main__':
90 | main()
91 |
92 |
--------------------------------------------------------------------------------
/demo/hifigan/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.4.0
2 | numpy==1.17.4
3 | librosa==0.7.2
4 | scipy==1.4.1
5 | tensorboard==2.0
6 | soundfile==0.10.3.post1
7 | matplotlib==3.1.3
--------------------------------------------------------------------------------
/demo/hifigan/resample.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from concurrent.futures import ProcessPoolExecutor
4 | from multiprocessing import cpu_count
5 |
6 | import torchaudio
7 | from torchaudio.functional import resample
8 |
9 | from tqdm import tqdm
10 | import subprocess
11 |
12 |
13 | def process_wav(in_path, out_path, sample_rate):  # despite the name, extracts mono wav audio from an mp4 via ffmpeg
14 | # wav, sr = torchaudio.load(in_path)
15 | # wav = resample(wav, sr, sample_rate)
16 | out_path = Path(str(out_path).replace('mp4', 'wav'))
17 | ffmpeg_process = subprocess.Popen(
18 | ['ffmpeg', '-y', '-i', in_path, '-ac', '1', '-ar', str(sample_rate), out_path],
19 | stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
20 | )
21 | stdout, stderr = ffmpeg_process.communicate(None, timeout=5.0)
22 | ffmpeg_process.kill()
23 | # torchaudio.save(out_path, wav, sample_rate)
24 | return out_path, 0  # duration not computed here; would be wav.size(-1) / sample_rate
25 |
26 |
27 | def preprocess_dataset(args):
28 | args.out_dir.mkdir(parents=True, exist_ok=True)
29 |
30 | futures = []
31 | executor = ProcessPoolExecutor(max_workers=cpu_count())
32 | print(f"Resampling audio in {args.in_dir}")
33 | for in_path in args.in_dir.rglob("*.mp4"):
34 | relative_path = in_path.relative_to(args.in_dir)
35 | out_path = args.out_dir / relative_path
36 | out_path.parent.mkdir(parents=True, exist_ok=True)
37 | # process_wav(in_path, out_path, args.sample_rate)
38 | futures.append(
39 | executor.submit(process_wav, in_path, out_path, args.sample_rate)
40 | )
41 | # import pdb; pdb.set_trace()
42 |
43 | results = [future.result() for future in tqdm(futures)]
44 |
45 | lengths = {path.stem: length for path, length in results}
46 | seconds = sum(lengths.values())
47 | hours = seconds / 3600
48 | print(f"Wrote {len(lengths)} utterances ({hours:.2f} hours)")
49 |
50 |
51 | if __name__ == "__main__":
52 | parser = argparse.ArgumentParser(description="Resample an audio dataset.")
53 | parser.add_argument(
54 | "in_dir", metavar="in-dir", help="path to the dataset directory.", type=Path
55 | )
56 | parser.add_argument(
57 | "out_dir", metavar="out-dir", help="path to the output directory.", type=Path
58 | )
59 | parser.add_argument(
60 | "--sample-rate",
61 | help="target sample rate (default 16kHz)",
62 | type=int,
63 | default=16000,
64 | )
65 | args = parser.parse_args()
66 | preprocess_dataset(args)
67 |
--------------------------------------------------------------------------------
/demo/hifigan/test_mel.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import os
3 | import sys
4 | import librosa
5 | import scipy.signal.windows
6 | import soundfile as sf
7 | import numpy as np
8 | from io import BytesIO
9 | from PIL import Image
10 | from scipy.io import wavfile
11 | import io
12 | import matplotlib.pyplot as plt
13 | from PIL import Image
14 | import requests
15 | import torch.nn.functional as F
16 | import torchaudio
17 | import torchaudio.transforms as transforms
18 | import torch
19 | from meldataset import mel_spectrogram
20 | from torchaudio.functional import resample
21 |
22 | window_size = 4.08
23 | sample_rate = 16000
24 | n_fft = 1024
25 | win_len = 1024
26 | hop_len=256
27 | n_mels = 128
28 | fmin = 0.0
29 | eps = 0.1
30 | max_wav_value=32768.0
31 | playback_speed = 1
32 | fmax = 8000
33 |
34 | with open("example.wav", "wb") as file:
35 | response = requests.get("https://drive.google.com/uc?export=preview&id=1Y3KuPAhB5VcsmIaokBVKu3LUEZOfhSu8")
36 | file.write(response.content)
37 |
38 | # logmel = LogMelSpectrogram()
39 |
40 | audio_fn = 'sample_1.wav'  # note: the example.wav downloaded above is not used below
41 | waveform1, sample_rate = librosa.load(audio_fn, sr=sample_rate)
42 |
43 | sr, waveform = wavfile.read(audio_fn, mmap=True)
44 | waveform = waveform.astype('float32')
45 | waveform /= max_wav_value
46 |
47 | st = float(60 * 0 + 0.0)
48 | start_idx = int(sr * st)
49 | # end_idx = start_idx + int(sr * window_size) * playback_speed
50 | end_idx = 8192
51 | waveform = waveform[start_idx:end_idx]
52 |
53 | waveform = torch.Tensor(waveform)
54 |
55 | torchaudio_melspec = transforms.MelSpectrogram(
56 | sample_rate=sample_rate,
57 | n_fft=n_fft,
58 | win_length=win_len,
59 | hop_length=hop_len,
60 | center=True,
61 | pad_mode="reflect",
62 | power=2.0,
63 | norm='slaney',
64 | onesided=True,
65 | n_mels=n_mels,
66 | mel_scale="slaney",
67 | f_min = fmin,
68 | f_max = sample_rate / 2.0,
69 | )(waveform)
70 |
71 | librosa_melspec = librosa.feature.melspectrogram(
72 | waveform.numpy(),
73 | sr=sample_rate,
74 | n_fft=n_fft,
75 | hop_length=hop_len,
76 | win_length=win_len,
77 | center=True,
78 | pad_mode="reflect",
79 | power=2.0,
80 | n_mels=n_mels,
81 | )
82 |
83 | torch_melspec = mel_spectrogram(waveform[None,:], n_fft, n_mels,
84 | sample_rate, hop_len, n_fft, fmin, fmax,
85 | center=True)
86 |
87 | torch_melspec = torch_melspec.squeeze(0)
88 |
89 | mse = ((torch_melspec - librosa_melspec) ** 2).mean()
90 | print(f"MSE between meldataset and librosa mel spectrograms: {mse}")
91 |
92 | # compare the librosa and torchaudio mel spectrograms as well
93 | mel_librosa = librosa_melspec
94 | mel_torch = torchaudio_melspec.numpy()
95 | diff = np.mean((mel_librosa - mel_torch) ** 2)
96 | print(f"MSE between librosa and torchaudio mel spectrograms: {diff}")
97 |
98 |
99 |
100 |
101 |
102 |
103 | # log-compress the mel spectrogram; plot_spectrogram below applies the
104 | # inverse transform, so log_mel = log(mel / eps + 1).
105 | log_mel = np.log(mel_torch / eps + 1.0)
106 |
107 |
108 |
109 |
110 |
111 | # Load checkpoint
112 | hifigan = torch.hub.load("bshall/hifigan:main", "hifigan_hubert_soft").cuda()
113 | # Load mel-spectrogram
114 |
115 | mel = torch.from_numpy(log_mel).unsqueeze(0).cuda()
116 | wav, sr = hifigan.generate(mel.cuda())
117 |
118 | wavfile.write("original.wav", sample_rate, waveform.numpy())
119 | wavfile.write("decoded.wav", sample_rate, wav.reshape(-1).cpu().numpy())
120 |
121 |
122 | def plot_spectrogram(log_mel, eps=0.1, ylabel='freq_bin', aspect='auto', xmax=None, to_db=True):
123 | fig, axs = plt.subplots(1, 1)
124 | spec = np.exp(log_mel + np.log(eps)) - eps
125 | if to_db:
126 | spec = librosa.power_to_db(spec, ref=np.max)
127 | axs.set_ylabel(ylabel)
128 | axs.set_xlabel('frame')
129 | im = axs.imshow(spec, origin='lower', aspect=aspect)
130 | if xmax:
131 | axs.set_xlim((0, xmax))
132 | fig.colorbar(im, ax=axs)
133 | fig.tight_layout(pad=0)
134 | fig.canvas.draw()
135 |
136 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
137 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
138 | plt.close(fig)
139 | return data
140 |
141 | img = plot_spectrogram(log_mel, eps)
142 |
143 | Image.fromarray(img).save('mel_spectrogram.png')
--------------------------------------------------------------------------------
/demo/hifigan/test_mel2.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import os
3 | import sys
4 | import librosa
5 | import scipy.signal.windows
6 | import soundfile as sf
7 | import numpy as np
8 | from io import BytesIO
9 | from PIL import Image
10 | from scipy.io import wavfile
11 | import io
12 | import matplotlib.pyplot as plt
13 | from PIL import Image
14 | import requests
15 | import torch.nn.functional as F
16 | import torchaudio
17 | import torchaudio.transforms as transforms
18 | import torch
19 | from meldataset import mel_spectrogram
20 | from torchaudio.functional import resample
21 |
22 | window_size = 4.08
23 | sample_rate = 16000
24 | n_fft = 1024
25 | win_len = 1024
26 | hop_len=256
27 | n_mels = 128
28 | fmin = 0.0
29 | eps = 0.1
30 | max_wav_value=32768.0
31 | playback_speed = 1
32 | fmax = 8000
33 |
34 | with open("example.wav", "wb") as file:
35 | response = requests.get("https://drive.google.com/uc?export=preview&id=1Y3KuPAhB5VcsmIaokBVKu3LUEZOfhSu8")
36 | file.write(response.content)
37 |
38 | # logmel = LogMelSpectrogram()
39 |
40 | audio_fn = 'sample_1.wav'  # note: the example.wav downloaded above is not used below
41 | waveform1, sample_rate = librosa.load(audio_fn, sr=sample_rate)
42 |
43 | sr, waveform = wavfile.read(audio_fn, mmap=True)
44 | waveform = waveform.astype('float32')
45 | waveform /= max_wav_value
46 |
47 | st = float(60 * 0 + 0.0)
48 | start_idx = int(sr * st)
49 | end_idx = start_idx + int(sr * window_size) * playback_speed
50 | waveform = waveform[start_idx:end_idx]
51 |
52 | waveform = torch.Tensor(waveform)
53 |
54 | librosa_melspec = librosa.feature.melspectrogram(
55 | waveform.numpy(),
56 | sr=sample_rate,
57 | n_fft=n_fft,
58 | hop_length=hop_len,
59 | win_length=win_len,
60 | center=True,
61 | pad_mode="reflect",
62 | power=2.0,
63 | n_mels=n_mels,
64 | )
65 |
66 | torch_melspec = mel_spectrogram(waveform[None,:], n_fft, n_mels,
67 | sample_rate, hop_len, n_fft, fmin, fmax,
68 | center=True)
69 |
70 | torch_melspec = torch_melspec.squeeze(0)
71 |
72 | mse = ((torch_melspec - librosa_melspec) ** 2).mean()
73 |
74 |
75 | print(f"MSE between meldataset and librosa mel spectrograms: {mse}")
76 |
77 | # log-compress the mel spectrogram; plot_spectrogram below applies the
78 | # inverse transform, so log_mel = log(mel / eps + 1).
79 | log_mel = np.log(torch_melspec.numpy() / eps + 1.0)
80 |
81 |
82 |
83 |
84 |
85 |
86 | # Load checkpoint
87 | hifigan = torch.hub.load("bshall/hifigan:main", "hifigan_hubert_soft").cuda()
88 | # Load mel-spectrogram
89 |
90 | mel = torch.from_numpy(log_mel).unsqueeze(0).cuda()
91 | wav, sr = hifigan.generate(mel.cuda())
92 |
93 | wavfile.write("original.wav", sample_rate, waveform.numpy())
94 | wavfile.write("decoded.wav", sample_rate, wav.reshape(-1).cpu().numpy())
95 |
96 | def plot_spectrogram(log_mel, eps=0.1, ylabel='freq_bin', aspect='auto', xmax=None, to_db=True):
97 | fig, axs = plt.subplots(1, 1)
98 | spec = np.exp(log_mel + np.log(eps)) - eps
99 | if to_db:
100 | spec = librosa.power_to_db(spec, ref=np.max)
101 | axs.set_ylabel(ylabel)
102 | axs.set_xlabel('frame')
103 | im = axs.imshow(spec, origin='lower', aspect=aspect)
104 | if xmax:
105 | axs.set_xlim((0, xmax))
106 | fig.colorbar(im, ax=axs)
107 | fig.tight_layout(pad=0)
108 | fig.canvas.draw()
109 |
110 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
111 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
112 | plt.close(fig)
113 | return data
114 |
115 | img = plot_spectrogram(log_mel, eps)
116 |
117 | Image.fromarray(img).save('mel_spectrogram.png')
--------------------------------------------------------------------------------
/demo/hifigan/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import matplotlib
4 | import torch
5 | from torch.nn.utils import weight_norm
6 | matplotlib.use("Agg")
7 | import matplotlib.pylab as plt
8 |
9 |
10 | def plot_spectrogram(spectrogram):
11 | fig, ax = plt.subplots(figsize=(10, 2))
12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
13 | interpolation='none')
14 | plt.colorbar(im, ax=ax)
15 |
16 | fig.canvas.draw()
17 | plt.close()
18 |
19 | return fig
20 |
21 |
22 | def init_weights(m, mean=0.0, std=0.01):
23 | classname = m.__class__.__name__
24 | if classname.find("Conv") != -1:
25 | m.weight.data.normal_(mean, std)
26 |
27 |
28 | def apply_weight_norm(m):
29 | classname = m.__class__.__name__
30 | if classname.find("Conv") != -1:
31 | weight_norm(m)
32 |
33 |
34 | def get_padding(kernel_size, dilation=1):
35 | return int((kernel_size*dilation - dilation)/2)  # "same" padding for a stride-1 dilated convolution
36 |
37 |
38 | def load_checkpoint(filepath, device):
39 | assert os.path.isfile(filepath)
40 | print("Loading '{}'".format(filepath))
41 | checkpoint_dict = torch.load(filepath, map_location=device)
42 | print("Complete.")
43 | return checkpoint_dict
44 |
45 |
46 | def save_checkpoint(filepath, obj):
47 | print("Saving checkpoint to {}".format(filepath))
48 | torch.save(obj, filepath)
49 | print("Complete.")
50 |
51 |
52 | def scan_checkpoint(cp_dir, prefix):
53 | pattern = os.path.join(cp_dir, prefix + '????????')
54 | cp_list = glob.glob(pattern)
55 | if len(cp_list) == 0:
56 | return None
57 | return sorted(cp_list)[-1]
58 |
59 |
--------------------------------------------------------------------------------
/demo/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .audio_utils import *
2 | from .video_utils import *
3 |
--------------------------------------------------------------------------------
/demo/utils/video_utils.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import random
3 | import string
4 | import subprocess
5 | import time
6 |
7 | import gradio as gr
8 |
9 | from create_data.utils import (
10 | get_video_length,
11 | create_audio_from_video,
12 | extract_frames_from_video,
13 | BUFFER_FROM_END,
14 | )
15 | from demo.utils.audio_utils import extract_spectrograms_from_audio
16 |
17 | __all__ = ["load_video"]
18 |
19 |
20 | def extract_frames_and_spectrograms_from_video(
21 | video_file,
22 | audio_dir,
23 | video_length=None,
24 | video_segment_length=None,
25 | audio_segment_length=None,
26 | times=None,
27 | clip_start_time=0,
28 | clip_end_time=None,
29 | num_frames=None,
30 | target_size=(256, 256),
31 | *,
32 | use_audio,
33 | ):
34 | if times is None:
35 | # get actual video length
36 | if video_length is None:
37 | video_length = get_video_length(video_file)
38 | if video_length is None:
39 | print(f"Couldn't get video length for {video_file}")
40 | return None, None
41 |
42 | if video_segment_length is None:
43 | video_segment_length = video_length / num_frames
44 | if video_length < (video_segment_length / 2.0) - BUFFER_FROM_END:
45 | print(
46 | f"Video is too short ({video_length}s is less than half the segment length of {video_segment_length}s segments"
47 | )
48 | return None, None
49 | else:
50 | # don't need this if times is given
51 | video_length = None
52 |
53 | # extract image frames
54 | # t0 = perf_counter()
55 | frames, boundaries = extract_frames_from_video(
56 | video_file,
57 | video_length,
58 | video_segment_length,
59 | times=times,
60 | clip_start_time=clip_start_time,
61 | clip_end_time=clip_end_time,
62 | num_frames=num_frames,
63 | multiprocess=False,
64 | resize=True,
65 | target_size=target_size,
66 | )
67 | # print(f"Load video in {perf_counter() - t0} seconds in total")
68 |
69 | spectrograms = None
70 | if use_audio:
71 | # expects the audio file to be created already (since it takes some time)
72 | audio_file = create_audio_from_video(video_file, audio_dir, force=True)
73 | if os.path.exists(audio_file): # in case video w/o audio
74 | # extract audio segments
75 | spectrograms = extract_spectrograms_from_audio(
76 | audio_file,
77 | audio_length=clip_end_time,
78 | audio_segment_length=audio_segment_length,
79 | spectrogram_length=audio_segment_length,
80 | )
81 |
82 | return frames, spectrograms
83 |
84 |
85 | def load_video(
86 | path: str,
87 | max_frames: int = 4,
88 | audio_segment_length: float = 4.08,
89 | target_size: tuple = (256, 256),
90 | *,
91 | use_audio: bool,
92 | ):
93 | """max frames could be max image history length + (1 if no image input else 0), similar for audio"""
94 | if path.startswith("http"):
95 | filetype = path.split(".")[-1]
96 | filename = "".join(random.choices(string.ascii_lowercase, k=8)) + "." + filetype
97 | cmd = f"wget -O {filename} {path}"
98 | print(cmd)
99 | _ = subprocess.Popen(
100 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, shell=True
101 | )
102 | gr.Info("Waiting for video download to finish!")
103 | time.sleep(10)
104 | path = filename
105 |
106 | assert os.path.exists(path), path
107 | video_length = get_video_length(path)
108 | clip_end_time, video_seg_length = None, None
109 | if video_length is not None:
110 | # accommodate the audio segment length
111 | max_length = max_frames * audio_segment_length
112 | clip_end_time = min(max_length, video_length)
113 | if clip_end_time < video_length:
114 | gr.Warning(
115 | f"Use the input video length of {clip_end_time} (original {video_length}) seconds."
116 | )
117 | # seconds per frame; not necessarily aligned with the audio segments in the demo
118 | video_seg_length = clip_end_time / max_frames
119 |
120 | frames, spectrograms = extract_frames_and_spectrograms_from_video(
121 | path,
122 | audio_dir=None,
123 | video_length=video_length,
124 | video_segment_length=video_seg_length,
125 | audio_segment_length=audio_segment_length,
126 | clip_start_time=0,
127 | clip_end_time=clip_end_time,
128 | num_frames=None,
129 | use_audio=use_audio,
130 | target_size=target_size,
131 | )
132 | return frames, spectrograms
133 |
--------------------------------------------------------------------------------
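
A sketch of calling `load_video` on a local clip (file name illustrative); `use_audio` is keyword-only and controls whether spectrograms are extracted alongside the frames:

```python
from demo.utils.video_utils import load_video

frames, spectrograms = load_video(
    "example_video.mp4",
    max_frames=4,
    audio_segment_length=4.08,
    target_size=(256, 256),
    use_audio=True,
)
print("frames:", None if frames is None else len(frames))
print("spectrograms:", None if spectrograms is None else len(spectrograms))
```
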
/metadata/coco/coco_class_name_2017.json:
--------------------------------------------------------------------------------
1 | {"0": "person", "1": "bicycle", "2": "car", "3": "motorcycle", "4": "airplane", "5": "bus", "6": "train", "7": "truck", "8": "boat", "9": "traffic-light", "10": "fire-hydrant", "11": "stop-sign", "12": "parking-meter", "13": "bench", "14": "bird", "15": "cat", "16": "dog", "17": "horse", "18": "sheep", "19": "cow", "20": "elephant", "21": "bear", "22": "zebra", "23": "giraffe", "24": "backpack", "25": "umbrella", "26": "handbag", "27": "tie", "28": "suitcase", "29": "frisbee", "30": "skis", "31": "snowboard", "32": "sports ball", "33": "kite", "34": "baseball-bat", "35": "baseball-glove", "36": "skateboard", "37": "surfboard", "38": "tennis-racket", "39": "bottle", "40": "wine glass", "41": "cup", "42": "fork", "43": "knife", "44": "spoon", "45": "bowl", "46": "banana", "47": "apple", "48": "sandwich", "49": "orange", "50": "broccoli", "51": "carrot", "52": "hot dog", "53": "pizza", "54": "donut", "55": "cake", "56": "chair", "57": "couch", "58": "potted plant", "59": "bed", "60": "dining table", "61": "toilet", "62": "tv", "63": "laptop", "64": "mouse", "65": "remote", "66": "keyboard", "67": "cell phone", "68": "microwave", "69": "oven", "70": "toaster", "71": "sink", "72": "refrigerator", "73": "book", "74": "clock", "75": "vase", "76": "scissors", "77": "teddy-bear", "78": "hair-drier", "79": "toothbrush"}
--------------------------------------------------------------------------------
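
The table maps contiguous class indices 0-79 (stored as JSON string keys) to COCO-2017 category names; a small loading sketch:

```python
import json

with open("metadata/coco/coco_class_name_2017.json") as f:
    class_names = {int(k): v for k, v in json.load(f).items()}

print(class_names[0], class_names[79])  # person toothbrush
```
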
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Install T5X."""
16 |
17 | import os
18 | import sys
19 | import setuptools
20 |
21 | # To enable importing version.py directly, we add its path to sys.path.
22 | version_path = os.path.join(os.path.dirname(__file__), 't5x')
23 | sys.path.append(version_path)
24 | from version import __version__ # pylint: disable=g-import-not-at-top
25 |
26 | # Get the long description from the README file.
27 | with open('README.md') as fp:
28 | _LONG_DESCRIPTION = fp.read()
29 |
30 | _jax_version = '0.2.27'
31 | _jaxlib_version = '0.1.76'
32 |
33 | setuptools.setup(
34 | name='t5x',
35 | version=__version__,
36 | description='UnifiedIO 2',
37 | long_description=_LONG_DESCRIPTION,
38 | long_description_content_type='text/markdown',
39 | license='Apache 2.0',
40 | packages=setuptools.find_packages(),
41 | package_data={
42 | '': ['**/*.gin'], # not all subdirectories may have __init__.py.
43 | },
44 | scripts=[],
45 | install_requires=[
46 | 'absl-py',
47 | 'cached_property',
48 | 'protobuf==3.19.4',
49 | 'google-api-core==2.8.2',
50 | # TODO(adarob): Replace with 'clu' once >0.0.6 is released.
51 | 'clu==0.0.8',
52 | 'flax==0.6.3',
53 | 'gin-config',
54 | 'jax==0.3.25',
55 | 'jaxlib==0.3.25',
56 | 'numpy',
57 | 'orbax==0.0.2',
58 | 't5==0.9.4',
59 | 'tensorflow==2.11.1',
60 | 'einops',
61 | 'tfds-nightly==4.8.3.dev202304050043',
62 | 'tensorflow_probability==0.19.0',
63 | 'tensorflow-addons==0.19.0',
64 | 'tensorflow-datasets==4.8.3',
65 | 'pycocoevalcap',
66 | 'tensorstore >= 0.1.20',
67 | 'librosa',
68 | 'scikit-image',
69 | 'wandb==0.14.0',
70 | "optax==0.1.4",
71 | "tqdm",
72 | "transforms3d==0.4.1",
73 | "pyglove==0.4.3",
74 | "seqio==0.0.8",
75 | ],
76 | extras_require={
77 | 'data': ['datasets', 'google-cloud-storage', "resampy"],
78 | "demo": ["resampy", 'google-cloud-storage', 'gradio==4.8.0', 'notebook', 'sk-video'],
79 | # Cloud TPU requirements.
80 | 'tpu': ['jax[tpu]==0.3.25'],
81 | },
82 | classifiers=[
83 | 'Intended Audience :: Science/Research',
84 | 'License :: OSI Approved :: Apache Software License',
85 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
86 | ],
87 | )
--------------------------------------------------------------------------------
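
The `extras_require` block above defines three optional dependency sets (`data`, `demo`, `tpu`) that are selected at install time, mirroring the install command in `create_data/README.md`:

```
python3 -m pip install -e '.[demo]'
python3 -m pip install -e '.[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
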
/t5x/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Import API modules."""
16 |
17 | import t5x.adafactor
18 | import t5x.checkpoints
19 | import t5x.decoding
20 | import t5x.gin_utils
21 | import t5x.losses
22 | import t5x.models
23 | import t5x.partitioning
24 | import t5x.state_utils
25 | import t5x.train_state
26 | import t5x.trainer
27 | import t5x.utils
28 |
29 | # Version number.
30 | from t5x.version import __version__
31 |
32 | # TODO(adarob): Move clients to t5x.checkpointing and rename
33 | # checkpoints.py to checkpointing.py
34 | checkpointing = t5x.checkpoints
35 |
--------------------------------------------------------------------------------
/t5x/binary_search_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for binary_search."""
16 |
17 | from absl.testing import absltest
18 | import jax
19 | import jax.numpy as jnp
20 | import numpy as np
21 | from t5x import binary_search
22 |
23 | _INT32_MIN = np.iinfo(np.int32).min
24 | _INT32_MAX = np.iinfo(np.int32).max
25 |
26 |
27 | class BinarySearchTest(absltest.TestCase):
28 |
29 | def test_int32_bsearch(self):
30 | a = jnp.asarray([
31 | 1,
32 | 43,
33 | 79,
34 | 2048,
35 | 0,
36 | 2047,
37 | _INT32_MIN,
38 | _INT32_MIN + 1,
39 | _INT32_MAX,
40 | _INT32_MAX - 1,
41 | ],
42 | dtype=jnp.int32)
43 |
44 | def predicate(x):
45 | return x > a
46 |
47 | r = binary_search.int32_bsearch(a.shape, predicate)
48 | np.testing.assert_array_equal(a, r)
49 |
50 | def test_int32_bsearch_extreme_predicates(self):
51 |
52 | def predicate_false(x):
53 | return jnp.full_like(x, False)
54 |
55 | np.testing.assert_array_equal(
56 | jnp.asarray([_INT32_MAX]),
57 | binary_search.int32_bsearch((1,), predicate_false))
58 |
59 | def predicate_true(x):
60 | return jnp.full_like(x, True)
61 |
62 | np.testing.assert_array_equal(
63 | jnp.asarray([_INT32_MIN]),
64 | binary_search.int32_bsearch((1,), predicate_true))
65 |
66 | def test_float32_bsearch(self):
67 | a = jnp.asarray([1.23, 0.0, -0.0, 105.4, -1024, 4.3], dtype=jnp.float32)
68 |
69 | def predicate(x):
70 | return x > a
71 |
72 | c = binary_search.float32_bsearch(a.shape, predicate)
73 | # Given that the predicate is based on floating point '>' as implemented by
74 | # JAX, we need our equality test to be based on floating point '==' as
75 | # implemented by JAX, rather than np.testing.assert_array_equal.
76 | #
77 | # Some corner cases on subnormal numbers may be different, depending on what
78 | # platform we run on.
79 | self.assertTrue(jnp.all(a == c), f'a={a}, c={c}')
80 |
81 | def test_topk_mask(self):
82 | mask = -1e10
83 | x = jnp.asarray([
84 | [1.4, 7.9, -4.3, 100, 71, 6, -1e4],
85 | [8.3, 1.2, 1.3, 1.2, 1.2, 9.7, -100],
86 | ])
87 |
88 | # Using exact equality here, because topk_mask guarantees it: it is just
89 | # masking some things, not doing arithmetic on the array.
90 | np.testing.assert_array_equal(
91 | jnp.asarray([
92 | [mask, mask, mask, 100, mask, mask, mask],
93 | [mask, mask, mask, mask, mask, 9.7, mask],
94 | ]),
95 | binary_search.topk_mask(x, 1, mask),
96 | )
97 | np.testing.assert_array_equal(
98 | jnp.asarray([
99 | [mask, mask, mask, 100, 71, mask, mask],
100 | [8.3, mask, mask, mask, mask, 9.7, mask],
101 | ]),
102 | binary_search.topk_mask(x, 2, mask),
103 | )
104 | np.testing.assert_array_equal(
105 | jnp.asarray([
106 | [mask, 7.9, mask, 100, 71, mask, mask],
107 | [8.3, mask, 1.3, mask, mask, 9.7, mask],
108 | ]),
109 | binary_search.topk_mask(x, 3, mask),
110 | )
111 | np.testing.assert_array_equal(
112 | jnp.asarray([
113 | [mask, 7.9, mask, 100, 71, 6, mask],
114 | [8.3, 1.2, 1.3, 1.2, 1.2, 9.7, mask],
115 | ]),
116 | binary_search.topk_mask(x, 4, mask),
117 | )
118 | np.testing.assert_array_equal(
119 | jnp.asarray([
120 | [1.4, 7.9, mask, 100, 71, 6, mask],
121 | [8.3, 1.2, 1.3, 1.2, 1.2, 9.7, mask],
122 | ]),
123 | binary_search.topk_mask(x, 5, mask),
124 | )
125 | np.testing.assert_array_equal(
126 | jnp.asarray([
127 | [1.4, 7.9, -4.3, 100, 71, 6, mask],
128 | [8.3, 1.2, 1.3, 1.2, 1.2, 9.7, mask],
129 | ]),
130 | binary_search.topk_mask(x, 6, mask),
131 | )
132 | np.testing.assert_array_equal(
133 | jnp.asarray([
134 | [1.4, 7.9, -4.3, 100, 71, 6, -1e4],
135 | [8.3, 1.2, 1.3, 1.2, 1.2, 9.7, -100],
136 | ]),
137 | binary_search.topk_mask(x, 7, mask),
138 | )
139 |
140 | def test_topp_mask(self):
141 | probs = jnp.asarray([
142 | [0.0, 0.7, 0.04, 0.06, 0.2, 0.0],
143 | [0.0, 0.2, 0.2, 0.2, 0.3, 0.1],
144 | ])
145 | logits = jnp.log(probs)
146 | np.testing.assert_allclose(jax.nn.softmax(logits), probs)
147 | mask = -1e10
148 |
149 | # Using exact equality here, because topp_mask guarantees it: it is just
150 | # masking some things, not doing arithmetic on the array.
151 | np.testing.assert_array_equal(
152 | jnp.asarray([
153 | [mask, jnp.log(0.7), mask, mask, mask, mask],
154 | [mask, mask, mask, mask, jnp.log(0.3), mask],
155 | ]),
156 | binary_search.topp_mask(logits, 0.1, mask),
157 | )
158 | np.testing.assert_array_equal(
159 | jnp.asarray([
160 | [mask, jnp.log(0.7), mask, mask, mask, mask],
161 | [mask, mask, mask, mask, jnp.log(0.3), mask],
162 | ]),
163 | binary_search.topp_mask(logits, 0.3, mask),
164 | )
165 | np.testing.assert_array_equal(
166 | jnp.asarray([
167 | [mask, jnp.log(0.7), mask, mask, mask, mask],
168 | [
169 | mask,
170 | jnp.log(0.2),
171 | jnp.log(0.2),
172 | jnp.log(0.2),
173 | jnp.log(0.3), mask
174 | ],
175 | ]),
176 | binary_search.topp_mask(logits, 0.4, mask),
177 | )
178 | np.testing.assert_array_equal(
179 | jnp.asarray([
180 | [mask, jnp.log(0.7), mask, mask,
181 | jnp.log(0.2), mask],
182 | [
183 | mask,
184 | jnp.log(0.2),
185 | jnp.log(0.2),
186 | jnp.log(0.2),
187 | jnp.log(0.3), mask
188 | ],
189 | ]),
190 | binary_search.topp_mask(logits, 0.8, mask),
191 | )
192 | np.testing.assert_array_equal(
193 | jnp.asarray([
194 | [mask, jnp.log(0.7), mask,
195 | jnp.log(0.06),
196 | jnp.log(0.2), mask],
197 | [
198 | mask,
199 | jnp.log(0.2),
200 | jnp.log(0.2),
201 | jnp.log(0.2),
202 | jnp.log(0.3),
203 | jnp.log(0.1)
204 | ],
205 | ]),
206 | binary_search.topp_mask(logits, 0.95, mask),
207 | )
208 |
209 |
210 | if __name__ == '__main__':
211 | absltest.main()
212 |
--------------------------------------------------------------------------------
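A note on the contracts the tests above pin down: `topk_mask` keeps every entry that ties with the k-th largest value in its row, so a row can retain more than k entries (the k=4 row with three 1.2s keeps six values), and `topp_mask` keeps the smallest set of highest-probability entries whose total mass reaches p, again keeping ties. A minimal reference sketch of the top-k contract, using a dense `jax.lax.top_k` threshold rather than the binary search the library implementation uses (`topk_mask_reference` is a hypothetical helper, not part of this repo):

import jax.numpy as jnp
from jax import lax

def topk_mask_reference(x, k, mask_value):
  # The k-th largest value per row is the keep-threshold; ties at the
  # threshold are all kept, matching the k=4 case exercised above.
  kth_largest = lax.top_k(x, k)[0][..., -1:]
  return jnp.where(x >= kth_largest, x, mask_value)
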
/t5x/checkpoint_importer_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for t5x.checkpoint_importer."""
16 |
17 | import json
18 | import os
19 |
20 | from absl import flags
21 | from absl.testing import absltest
22 | import jax
23 | import numpy as np
24 | from t5x import checkpoint_importer
25 | import tensorflow as tf
26 |
27 |
28 | class CheckpointImporterTest(absltest.TestCase):
29 |
30 | def test_rel_embeddings_shared_layers(self):
31 | # This represents a ckpt where the Mesh TensorFlow's
32 | # transformer_layers.SelfAttention.relative_attention_type = "bias_shared",
33 | # i.e., the same relative attention parameters are shared by all layers
34 | # within the (en|de)coder.
35 | ckpt_data = {
36 | 'encoder/block_000/layer_000/SelfAttention/relative_attention_bias':
37 | 1,
38 | 'decoder/block_000/layer_000/SelfAttention/relative_attention_bias':
39 | 2,
40 | 'decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v':
41 | 3,
42 | }
43 | t5_data = checkpoint_importer.t5_importer.apply(ckpt_data)
44 | t5_data = checkpoint_importer._maybe_correct_relpos_bias(t5_data)
45 | expected = {
46 | 'target/encoder/relpos_bias/rel_embedding': 1,
47 | 'target/decoder/relpos_bias/rel_embedding': 2,
48 | 'state/param_states/decoder/relpos_bias/rel_embedding/v': 3,
49 | }
50 | self.assertEqual(t5_data, expected)
51 |
52 | def test_rel_embeddings_per_layer(self):
53 | # This represents a ckpt where the Mesh TensorFlow's
54 | # transformer_layers.SelfAttention.relative_attention_type = "bias", i.e.,
55 | # each layer has its own relative attention parameters.
56 | ckpt_data = {
57 | 'encoder/block_000/layer_000/SelfAttention/relative_attention_bias':
58 | 1,
59 | 'encoder/block_001/layer_000/SelfAttention/relative_attention_bias':
60 | 2,
61 | 'decoder/block_000/layer_000/SelfAttention/relative_attention_bias':
62 | 3,
63 | 'decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v':
64 | 4,
65 | 'decoder/block_011/layer_000/SelfAttention/relative_attention_bias':
66 | 5
67 | }
68 | t5_data = checkpoint_importer.t5_importer.apply(ckpt_data)
69 | t5_data = checkpoint_importer._maybe_correct_relpos_bias(t5_data)
70 | expected = {
71 | 'target/encoder/layers_0/relpos_bias/rel_embedding': 1,
72 | 'target/encoder/layers_1/relpos_bias/rel_embedding': 2,
73 | 'target/decoder/layers_0/relpos_bias/rel_embedding': 3,
74 | 'state/param_states/decoder/layers_0/relpos_bias/rel_embedding/v': 4,
75 | 'target/decoder/layers_11/relpos_bias/rel_embedding': 5,
76 | }
77 | self.assertEqual(t5_data, expected)
78 |
79 |
80 | if __name__ == '__main__':
81 | absltest.main()
82 |
--------------------------------------------------------------------------------
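The two tests above fix the renaming scheme: per-layer Mesh TensorFlow names of the form `(en|de)coder/block_NNN/.../relative_attention_bias` become `target/(en|de)coder/layers_N/relpos_bias/rel_embedding`, the shared-bias variant drops the `layers_N` component, and `_slot_v` suffixes route to optimizer state under `state/param_states/...`. A hedged sketch of just the per-layer rule (the real mapping lives in `checkpoint_importer.t5_importer` and covers many more patterns; this helper is illustrative only):

import re

def rename_per_layer_relpos_bias(mtf_name: str) -> str:
  # Hypothetical illustration of one rewrite rule, matching the
  # expectations in test_rel_embeddings_per_layer above.
  m = re.fullmatch(
      r'(encoder|decoder)/block_(\d+)/layer_\d+/SelfAttention/'
      r'relative_attention_bias', mtf_name)
  tower, block = m.group(1), int(m.group(2))
  return f'target/{tower}/layers_{block}/relpos_bias/rel_embedding'

# rename_per_layer_relpos_bias(
#     'decoder/block_011/layer_000/SelfAttention/relative_attention_bias')
# -> 'target/decoder/layers_11/relpos_bias/rel_embedding'
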
/t5x/checkpoint_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Checkpoint helper functions for managing checkpoints.
16 |
17 | Supports marking checkpoints as pinned to exclude them from the checkpointer
18 | removal process.
19 | """
20 |
21 | import os
22 |
23 | from absl import logging
24 |
25 | from tensorflow.io import gfile
26 |
27 | # PINNED file in the checkpoint directory indicates that the checkpoint should
28 | # not be removed during the automatic pruning of old checkpoints.
29 | _PINNED_CHECKPOINT_FILENAME = 'PINNED'
30 |
31 |
32 | def pinned_checkpoint_filepath(ckpt_dir: str) -> str:
33 | """Full path of the pinned checkpoint file."""
34 | return os.path.join(ckpt_dir, _PINNED_CHECKPOINT_FILENAME)
35 |
36 |
37 | def is_pinned_checkpoint(ckpt_dir: str) -> bool:
38 | """Returns whether the checkpoint is pinned, and should NOT be removed."""
39 | pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir)
40 | if gfile.exists(pinned_ckpt_file):
41 | return True
42 | return False
43 |
44 |
45 | def pin_checkpoint(ckpt_dir: str, txt: str = '1') -> None:
46 | """Pin a checkpoint so it does not get deleted by the normal pruning process.
47 |
48 | Creates a PINNED file in the checkpoint directory to indicate the checkpoint
49 | should be excluded from the deletion of old checkpoints.
50 |
51 | Args:
52 | ckpt_dir: The checkpoint step dir that is to be always kept.
53 | txt: Text to be written into the checkpoint's PINNED file.
54 | """
55 | pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir)
56 | with gfile.GFile(pinned_ckpt_file, 'w') as f:
57 | logging.debug('Write %s file : %s.', pinned_ckpt_file, txt)
58 | f.write(txt)
59 |
60 |
61 | def unpin_checkpoint(ckpt_dir: str) -> None:
62 | """Removes the pinned status of the checkpoint so it is open for deletion."""
63 | if not is_pinned_checkpoint(ckpt_dir):
64 | logging.debug('%s is not PINNED. Nothing to do here.', ckpt_dir)
65 | return
66 | try:
67 | pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir)
68 | logging.debug('Remove %s file.', pinned_ckpt_file)
69 | gfile.rmtree(pinned_ckpt_file)
70 | except IOError:
71 | logging.exception('Failed to unpin %s', ckpt_dir)
72 |
73 |
74 | def remove_checkpoint_dir(ckpt_dir: str) -> None:
75 | """Removes the checkpoint dir if it is not pinned."""
76 | if not is_pinned_checkpoint(ckpt_dir):
77 | logging.info('Deleting checkpoint: %s', ckpt_dir)
78 | gfile.rmtree(ckpt_dir)
79 | else:
80 | logging.info('Keeping pinned checkpoint: %s', ckpt_dir)
81 |
82 |
83 | def remove_dataset_checkpoint(ckpt_dir: str, train_ds_prefix: str) -> None:
84 | """Removes dataset checkpoints if the checkpoint is not pinned."""
85 | if not is_pinned_checkpoint(ckpt_dir):
86 | train_ds_pattern = os.path.join(ckpt_dir, train_ds_prefix + '*')
87 | logging.info('Deleting dataset checkpoint: %s', train_ds_pattern)
88 | for file in gfile.glob(train_ds_pattern):
89 | gfile.remove(file)
90 | else:
91 | logging.info('Keeping pinned checkpoint: %s', ckpt_dir)
92 |
--------------------------------------------------------------------------------
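Taken together, these helpers form a small pinning protocol: dropping a PINNED file into a checkpoint step directory makes `remove_checkpoint_dir` and `remove_dataset_checkpoint` no-ops for it until it is unpinned. A minimal usage sketch (the step directory path is a placeholder):

from t5x import checkpoint_utils

ckpt_dir = '/tmp/model_dir/checkpoint_10000'  # hypothetical step dir

# Exclude this step from automatic pruning.
checkpoint_utils.pin_checkpoint(ckpt_dir, txt='best validation loss')
assert checkpoint_utils.is_pinned_checkpoint(ckpt_dir)

# This is now a no-op; the directory survives.
checkpoint_utils.remove_checkpoint_dir(ckpt_dir)

# Make it eligible for deletion again, then actually delete it.
checkpoint_utils.unpin_checkpoint(ckpt_dir)
checkpoint_utils.remove_checkpoint_dir(ckpt_dir)
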
/t5x/checkpoint_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for t5x.checkpoint_utils."""
16 |
17 | import os
18 | import traceback
19 |
20 | from absl.testing import absltest
21 | from t5x import checkpoint_utils
22 | from tensorflow.io import gfile
23 |
24 | TESTDATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
25 |
26 |
27 | class CheckpointsUtilsTest(absltest.TestCase):
28 |
29 | def setUp(self):
30 | super().setUp()
31 | self.checkpoints_dir = self.create_tempdir()
32 | self.ckpt_dir_path = self.checkpoints_dir.full_path
33 | self.pinned_ckpt_file = os.path.join(self.ckpt_dir_path, "PINNED")
34 | self.checkpoints_dir.create_file("checkpoint")
35 | # Create a `train_ds` file representing the dataset checkpoint.
36 | train_ds_basename = "train_ds-00000-of-00001"
37 | self.train_ds_file = os.path.join(self.ckpt_dir_path, train_ds_basename)
38 | self.checkpoints_dir.create_file(train_ds_basename)
39 |
40 | def test_always_keep_checkpoint_file(self):
41 | self.assertEqual(
42 | "/path/to/ckpt/dir/PINNED",
43 | checkpoint_utils.pinned_checkpoint_filepath("/path/to/ckpt/dir"))
44 |
45 | def test_is_pinned_checkpoint_false_by_default(self):
46 | # Ensure regular checkpoint without PINNED file.
47 | self.assertFalse(gfile.exists(os.path.join(self.ckpt_dir_path, "PINNED")))
48 |
49 | # Validate checkpoints are not pinned by default.
50 | self.assertFalse(checkpoint_utils.is_pinned_checkpoint(self.ckpt_dir_path))
51 |
52 | def test_is_pinned_checkpoint(self):
53 | # Ensure the testdata checkpoint directory is marked as pinned.
54 | pinned_ckpt_testdata = os.path.join(TESTDATA, "pinned_ckpt_dir")
55 | pinned_file = os.path.join(pinned_ckpt_testdata, "PINNED")
56 | self.assertTrue(gfile.exists(pinned_file))
57 |
58 | # Test and validate.
59 | self.assertTrue(checkpoint_utils.is_pinned_checkpoint(pinned_ckpt_testdata))
60 |
61 | def test_is_pinned_missing_ckpt(self):
62 | self.assertFalse(
63 | checkpoint_utils.is_pinned_checkpoint(
64 | os.path.join(self.ckpt_dir_path, "ckpt_does_not_exist")))
65 |
66 | def test_pin_checkpoint(self):
67 | # Ensure directory isn't already pinned.
68 | self.assertFalse(gfile.exists(self.pinned_ckpt_file))
69 |
70 | # Test.
71 | checkpoint_utils.pin_checkpoint(self.ckpt_dir_path)
72 |
73 | # Validate.
74 | self.assertTrue(gfile.exists(self.pinned_ckpt_file))
75 | with open(self.pinned_ckpt_file) as f:
76 | self.assertEqual("1", f.read())
77 |
78 | def test_pin_checkpoint_txt(self):
79 | checkpoint_utils.pin_checkpoint(self.ckpt_dir_path, "TEXT_IN_PINNED")
80 | self.assertTrue(os.path.exists(os.path.join(self.ckpt_dir_path, "PINNED")))
81 | with open(self.pinned_ckpt_file) as f:
82 | self.assertEqual("TEXT_IN_PINNED", f.read())
83 |
84 | def test_unpin_checkpoint(self):
85 | # Mark the checkpoint directory as pinned.
86 | self.checkpoints_dir.create_file("PINNED")
87 | self.assertTrue(checkpoint_utils.is_pinned_checkpoint(self.ckpt_dir_path))
88 |
89 | # Test.
90 | checkpoint_utils.unpin_checkpoint(self.ckpt_dir_path)
91 |
92 | # Validate the "PINNED" checkpoint file got removed.
93 | self.assertFalse(gfile.exists(os.path.join(self.ckpt_dir_path, "PINNED")))
94 |
95 | def test_unpin_checkpoint_does_not_exist(self):
96 | missing_ckpt_path = os.path.join(self.ckpt_dir_path, "ckpt_does_not_exist")
97 | self.assertFalse(gfile.exists(missing_ckpt_path))
98 |
99 | # Test. Assert does not raise error.
100 | try:
101 | checkpoint_utils.unpin_checkpoint(missing_ckpt_path)
102 | except IOError:
103 | # TODO(b/172262005): Remove traceback.format_exc() from the error message.
104 | self.fail("Unpin checkpoint failed with: %s" % traceback.format_exc())
105 |
106 | def test_remove_checkpoint_dir(self):
107 | # Ensure the checkpoint directory is set up.
108 | assert gfile.exists(self.ckpt_dir_path)
109 |
110 | # Test.
111 | checkpoint_utils.remove_checkpoint_dir(self.ckpt_dir_path)
112 |
113 | # Validate the checkpoint directory got removed.
114 | self.assertFalse(gfile.exists(self.ckpt_dir_path))
115 |
116 | def test_remove_checkpoint_dir_pinned(self):
117 | # Mark the checkpoint directory as pinned so it does not get removed.
118 | self.checkpoints_dir.create_file("PINNED")
119 |
120 | # Test.
121 | checkpoint_utils.remove_checkpoint_dir(self.ckpt_dir_path)
122 |
123 | # Validate the checkpoint directory still exists.
124 | self.assertTrue(gfile.exists(self.ckpt_dir_path))
125 |
126 | def test_remove_dataset_checkpoint(self):
127 | # Ensure the checkpoint directory is set up.
128 | assert gfile.exists(self.ckpt_dir_path)
129 |
130 | # Test.
131 | checkpoint_utils.remove_dataset_checkpoint(self.ckpt_dir_path, "train_ds")
132 |
133 | # Validate the checkpoint directory got removed.
134 | self.assertFalse(gfile.exists(self.train_ds_file))
135 | self.assertTrue(gfile.exists(self.ckpt_dir_path))
136 |
137 | def test_remove_dataset_checkpoint_pinned(self):
138 | # Mark the checkpoint directory as pinned so it does not get removed.
139 | self.checkpoints_dir.create_file("PINNED")
140 |
141 | # Test.
142 | checkpoint_utils.remove_dataset_checkpoint(self.ckpt_dir_path, "train_ds")
143 |
144 | # Validate the checkpoint directory still exists.
145 | self.assertTrue(gfile.exists(self.train_ds_file))
146 | self.assertTrue(gfile.exists(self.ckpt_dir_path))
147 |
148 | if __name__ == "__main__":
149 | absltest.main()
150 |
--------------------------------------------------------------------------------
/t5x/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """This empty file is needed for loading the gin files in this directory."""
16 |
--------------------------------------------------------------------------------
/t5x/configs/runs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # This empty file is needed for loading the gin files in this directory.
16 |
--------------------------------------------------------------------------------
/t5x/configs/runs/debug.gin:
--------------------------------------------------------------------------------
1 | # Defaults for debugging with train.py.
2 | #
3 | #
4 | # You must also include a binding for MODEL.
5 | #
6 | # Required to be set:
7 | #
8 | # - MIXTURE_OR_TASK_NAME
9 | # - TASK_FEATURE_LENGTHS
10 | # - TRAIN_STEPS
11 | # - MODEL_DIR # automatically set when using xm_launch
12 | #
13 | # Commonly overridden options:
14 | #
15 | # - train/DatasetConfig.batch_size
16 | # - train_eval/DatasetConfig.batch_size
17 | # - PjitPartitioner.num_partitions
18 | # - Trainer.num_microbatches
19 | # - DROPOUT_RATE
20 | from __gin__ import dynamic_registration
21 |
22 | import __main__ as train_script
23 | from t5x import gin_utils
24 | from t5x import partitioning
25 | from t5x import utils
26 | from t5x import trainer
27 |
28 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED
29 | TASK_FEATURE_LENGTHS = %gin.REQUIRED
30 | TRAIN_STEPS = %gin.REQUIRED
31 | MODEL_DIR = %gin.REQUIRED
32 | BATCH_SIZE = 128
33 | USE_CACHED_TASKS = False
34 | BASE_LR = 0.01
35 | WARMUP_STEPS = 100
36 | TIMESCALE = 100
37 |
38 | # DEPRECATED: Import the task module in your gin file instead.
39 | MIXTURE_OR_TASK_MODULE = None
40 | SHUFFLE_TRAIN_EXAMPLES = False
41 |
42 | # HW RNG is faster than SW, but has limited determinism.
43 | # Most notably it is not deterministic across different
44 | # submeshes.
45 | USE_HARDWARE_RNG = False
46 | # None always uses faster, hardware RNG
47 | RANDOM_SEED = None
48 |
49 | # Can be overridden with `train.*`.
50 | train_script.train:
51 | model = %MODEL # imported from separate gin file
52 | model_dir = %MODEL_DIR
53 | train_dataset_cfg = @train/utils.DatasetConfig()
54 | train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
55 | infer_eval_dataset_cfg = None
56 | checkpoint_cfg = @utils.CheckpointConfig()
57 | partitioner = @partitioning.PjitPartitioner()
58 | trainer_cls = @trainer.Trainer
59 | total_steps = %TRAIN_STEPS
60 | eval_steps = 20
61 | eval_period = 1000
62 | random_seed = %RANDOM_SEED
63 | use_hardware_rng = %USE_HARDWARE_RNG
64 | summarize_config_fn = @gin_utils.summarize_gin_config
65 |
66 | partitioning.PjitPartitioner:
67 | num_partitions = 1
68 | model_parallel_submesh = None
69 | logical_axis_rules = @partitioning.standard_logical_axis_rules()
70 |
71 | train/utils.DatasetConfig:
72 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME
73 | task_feature_lengths = %TASK_FEATURE_LENGTHS
74 | split = 'train'
75 | batch_size = %BATCH_SIZE
76 | shuffle = %SHUFFLE_TRAIN_EXAMPLES
77 | seed = None # use a new seed each run/restart
78 | use_cached = %USE_CACHED_TASKS
79 | pack = False
80 | module = %MIXTURE_OR_TASK_MODULE
81 |
82 | train_eval/utils.DatasetConfig:
83 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME_EVAL
84 | task_feature_lengths = %TASK_FEATURE_LENGTHS
85 | split = 'validation'
86 | batch_size = %BATCH_SIZE
87 | shuffle = False
88 | seed = 42
89 | use_cached = %USE_CACHED_TASKS
90 | pack = False
91 | module = %MIXTURE_OR_TASK_MODULE
92 |
93 | utils.CheckpointConfig:
94 | restore = @utils.RestoreCheckpointConfig()
95 | save = @utils.SaveCheckpointConfig()
96 |
97 | utils.RestoreCheckpointConfig:
98 | path = %INITIAL_CHECKPOINT_PATH
99 | mode = 'specific'
100 | dtype = 'float32'
101 | strict = False
102 | # fallback_to_scratch = %FALLBACK_TO_SCRATCH
103 | # state_transformation_fns = %STATE_TRANSFORMATION_FNS
104 |
105 | utils.SaveCheckpointConfig:
106 | period = 20000
107 | dtype = 'float32'
108 | keep = 5 # keep only the 5 most recent checkpoints
109 | save_dataset = False # don't checkpoint dataset state
110 |
111 | trainer.Trainer:
112 | num_microbatches = None
113 | learning_rate_fn = @utils.create_learning_rate_scheduler()
114 |
115 | # utils.create_learning_rate_scheduler:
116 | # factors = 'constant * rsqrt_decay'
117 | # base_learning_rate = 1.0
118 | # warmup_steps = 100
119 |
120 | utils.create_learning_rate_scheduler:
121 | total_steps = %TRAIN_STEPS
122 | base = %BASE_LR
123 | decay_type = 'rsqrt'
124 | warmup_steps = %WARMUP_STEPS
125 | timescale = %TIMESCALE
126 |
127 | WANDB_GROUP = None
128 | WANDB_NAME = None
129 |
130 | from t5x.examples.unified_io import utils as unified_io_utils
131 |
132 | unified_io_utils.init_wandb:
133 | group = %WANDB_GROUP
134 | name = %WANDB_NAME
135 |
--------------------------------------------------------------------------------
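The BASE_LR/WARMUP_STEPS/TIMESCALE bindings above parameterize an rsqrt-decay schedule with linear warmup. The authoritative formula is whatever `utils.create_learning_rate_scheduler` implements for `decay_type = 'rsqrt'` in this repo; one plausible shape of such a schedule, for intuition only:

import jax.numpy as jnp

def rsqrt_schedule(step, base=0.01, warmup_steps=100, timescale=100):
  # Linear warmup to the base rate, then inverse-square-root decay that
  # kicks in once step exceeds the timescale. An illustrative sketch,
  # not the repo's implementation.
  warmup = jnp.minimum(1.0, step / warmup_steps)
  decay = jnp.sqrt(timescale / jnp.maximum(step, timescale))
  return base * warmup * decay
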
/t5x/configs/runs/eval.gin:
--------------------------------------------------------------------------------
1 | # Defaults for eval.py.
2 | #
3 | #
4 | # You must also include a binding for MODEL.
5 | #
6 | # Required to be set:
7 | #
8 | # - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to evaluate on
9 | # - CHECKPOINT_PATH: The model checkpoint to evaluate
10 | # - EVAL_OUTPUT_DIR: The dir to write results to.
11 | #
12 | #
13 | # Commonly overridden options:
14 | #
15 | # - DatasetConfig.split
16 | # - DatasetConfig.batch_size
17 | # - DatasetConfig.use_cached
18 | # - RestoreCheckpointConfig.mode
19 | # - PjitPartitioner.num_partitions
20 | from __gin__ import dynamic_registration
21 |
22 | import __main__ as eval_script
23 | import seqio
24 | from t5x import partitioning
25 | from t5x import utils
26 | import t5x.examples.unified_io.evaluator as uio_evaluator
27 |
28 | # Must be overridden
29 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED
30 | CHECKPOINT_PATH = %gin.REQUIRED
31 | EVAL_OUTPUT_DIR = %gin.REQUIRED
32 | TASK_FEATURE_LENGTHS = None # auto-computes the maximum feature lengths to use.
33 |
34 | # DEPRECATED: Import the task module in your gin file instead.
35 | MIXTURE_OR_TASK_MODULE = None
36 |
37 | eval_script.evaluate:
38 | model = %MODEL # imported from separate gin file
39 | dataset_cfg = @utils.DatasetConfig()
40 | partitioner = @partitioning.PjitPartitioner()
41 | restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
42 | output_dir = %EVAL_OUTPUT_DIR
43 | inference_evaluator_cls = @uio_evaluator.UnifiedIOEvaluator
44 |
45 | partitioning.PjitPartitioner:
46 | num_partitions = 1
47 | model_parallel_submesh = None
48 | logical_axis_rules = @partitioning.standard_logical_axis_rules()
49 |
50 | uio_evaluator.UnifiedIOEvaluator:
51 | logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
52 | num_examples = None # Use all examples in the dataset.
53 | use_memory_cache = False
54 |
55 | utils.DatasetConfig:
56 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME
57 | task_feature_lengths = %TASK_FEATURE_LENGTHS
58 | split = 'validation'
59 | batch_size = 32
60 | shuffle = False
61 | seed = 42
62 | use_cached = False
63 | use_memory_cache = False
64 | pack = False
65 | use_custom_packing_ops = False
66 | module = %MIXTURE_OR_TASK_MODULE
67 |
68 | utils.RestoreCheckpointConfig:
69 | path = %CHECKPOINT_PATH
70 | mode = 'specific'
71 |
--------------------------------------------------------------------------------
/t5x/configs/runs/export.gin:
--------------------------------------------------------------------------------
1 | # Defaults for single_core_export.py.
2 | #
3 | # You must also include a binding for MODEL.
4 | #
5 | # Required to be set:
6 | #
7 | # - TASK_FEATURE_LENGTHS: The lengths per key in the SeqIO Task to trim features
8 | # to.
9 | # - CHECKPOINT_PATH: The model checkpoint to use for inference
10 | # - INFER_OUTPUT_DIR: The dir to write results to. When launching using
11 | # XManager, this is set automatically.
12 | #
13 | # Commonly overridden options:
14 | #
15 | # warmup_examples: Optional[List[str]] = None
16 | # jit_compile: bool = False
17 |
18 | from __gin__ import dynamic_registration
19 |
20 | import seqio
21 |
22 | from t5x import checkpoints
23 | from t5x import models
24 | from t5x import partitioning
25 | from t5x import utils
26 | from t5x import export_lib
27 |
28 | # Must be overridden
29 | OUTPUT_FEATURES = %gin.REQUIRED
30 | TASK_FEATURE_LENGTHS = %gin.REQUIRED
31 | CHECKPOINT_PATH = %gin.REQUIRED
32 | MODEL_OUTPUT_DIR = %gin.REQUIRED
33 | MODEL_NAME = %gin.REQUIRED
34 | BATCH_SIZE = None
35 | BEAM_SIZE = 1
36 |
37 | OUTPUT_FEATURES = {'inputs': @inputs/seqio.Feature(), 'targets': @outputs/seqio.Feature()}
38 |
39 | # Plumbing to extract the vocabulary directly from MODEL. This is needed to
40 | # tokenize the features from the saved model inputs, since we aren't provided
41 | # with vocabularies via a Task.
42 | inputs/seqio.Feature.vocabulary = @models.get_input_vocabulary()
43 | models.get_input_vocabulary.model = %MODEL # imported from separate gin file
44 | outputs/seqio.Feature.vocabulary = @models.get_output_vocabulary()
45 | models.get_output_vocabulary.model = %MODEL # imported from separate gin file
46 |
47 |
48 | # Typical for inference settings:
49 | ACTIVATION_DTYPE = 'bfloat16'
50 |
51 | export_lib.save:
52 | model = %MODEL # imported from separate gin file
53 | inference_mode = 'predict'
54 | restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
55 | exportable_module_cls = @export_lib.ExportableModule
56 | create_preprocessor_fn = @export_lib.create_preprocessor
57 | create_postprocessor_fn = @export_lib.create_postprocessor
58 | write_warmup_example_fn = @export_lib.write_warmup_examples
59 | partitioner = @partitioning.PjitPartitioner()
60 | output_features = %OUTPUT_FEATURES
61 | task_feature_lengths = %TASK_FEATURE_LENGTHS
62 | output_dir = %MODEL_OUTPUT_DIR
63 | model_name = %MODEL_NAME
64 | batch_size = %BATCH_SIZE
65 | native_lowering = False
66 |
67 | utils.RestoreCheckpointConfig:
68 | path = %CHECKPOINT_PATH
69 | mode = 'specific'
70 | dtype = 'bfloat16'
71 | checkpointer_cls = @checkpoints.Checkpointer
72 | # TODO(b/234480674): GDA disabled due to incompatibility with export.
73 | use_gda = False
74 |
75 | export_lib.create_preprocessor:
76 | output_features = %OUTPUT_FEATURES
77 | task_feature_lengths = %TASK_FEATURE_LENGTHS
78 |
79 | export_lib.create_postprocessor:
80 | output_feature_names = None
81 |
82 | export_lib.ExportableModule:
83 | jit_compile = True
84 | use_batch_function = False
85 |
86 | partitioning.PjitPartitioner:
87 | num_partitions = 1
88 | params_on_devices = True
89 | logical_axis_rules = @partitioning.standard_logical_axis_rules()
90 |
91 | models.EncoderDecoderModel.predict_batch_with_aux:
92 | num_decodes = %BEAM_SIZE
93 | return_all_decodes = True
94 |
--------------------------------------------------------------------------------
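Once `export_lib.save` has written the SavedModel, it can be inspected and served with standard TensorFlow tooling. A minimal loading sketch (the path is a placeholder; use `saved_model_cli show --dir <path> --all` to confirm the exported signature names):

import tensorflow as tf

loaded = tf.saved_model.load('/tmp/exported_model')  # hypothetical export dir
print(list(loaded.signatures))  # typically includes 'serving_default'
infer_fn = loaded.signatures['serving_default']
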
/t5x/configs/runs/export_seqio.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | from t5x import export_lib
4 | from t5x import partitioning
5 |
6 | include 't5x/configs/runs/export.gin'
7 |
8 |
9 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED
10 |
11 | export_lib.save:
12 | create_preprocessor_fn = @export_lib.create_preprocessor_from_task
13 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME
14 | output_features = None
15 |
16 | export_lib.create_preprocessor_from_task:
17 | model = %MODEL
18 | task_feature_lengths = %TASK_FEATURE_LENGTHS
19 | task_name = %MIXTURE_OR_TASK_NAME
20 | serialized_examples = True
21 | run_precache = False
22 |
--------------------------------------------------------------------------------
/t5x/configs/runs/finetune.gin:
--------------------------------------------------------------------------------
1 | # Defaults for finetuning with train.py.
2 | #
3 | #
4 | # You must also include a binding for MODEL.
5 | #
6 | # Required to be set:
7 | #
8 | # - MIXTURE_OR_TASK_NAME
9 | # - TASK_FEATURE_LENGTHS
10 | # - TRAIN_STEPS # includes pretrain steps
11 | # - MODEL_DIR # automatically set when using xm_launch
12 | # - INITIAL_CHECKPOINT_PATH
13 | #
14 | # When running locally, MODEL_DIR needs to be passed via the `--gin.MODEL_DIR` flag.
15 | #
16 | # `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt
17 | # has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps.
18 | #
19 | # Commonly overridden options:
20 | # - DROPOUT_RATE
21 | # - BATCH_SIZE
22 | # - PjitPartitioner.num_partitions
23 | # - Trainer.num_microbatches
24 | # - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess
25 | # on the fly. Most common tasks are cached, hence this is set to True by
26 | # default.
27 |
28 | from __gin__ import dynamic_registration
29 |
30 | import __main__ as train_script
31 | import seqio
32 | from t5x import gin_utils
33 | from t5x import partitioning
34 | from t5x import utils
35 | from t5x import trainer
36 | import t5x.examples.unified_io.evaluator as uio_evaluator
37 |
38 | # Must be overridden
39 | MODEL_DIR = %gin.REQUIRED
40 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED
41 | MIXTURE_OR_TASK_NAME_EVAL = %gin.REQUIRED
42 | TASK_FEATURE_LENGTHS_TRAIN = %gin.REQUIRED
43 | TASK_FEATURE_LENGTHS_EVAL = %gin.REQUIRED
44 | MIXTURE_OR_TASK_MODULE = %gin.REQUIRED
45 | TRAIN_STEPS = %gin.REQUIRED
46 | INITIAL_CHECKPOINT_PATH = %gin.REQUIRED
47 |
48 | # Commonly overridden
49 | DROPOUT_RATE = 0.1
50 | USE_CACHED_TASKS = True
51 | BATCH_SIZE = 128
52 |
53 | # Sometimes overridden
54 | EVAL_STEPS = 20
55 | EVAL_PERIOD = 1000
56 |
57 | # Convenience overrides.
58 | EVALUATOR_USE_MEMORY_CACHE = True
59 | EVALUATOR_NUM_EXAMPLES = None # Use all examples in the infer_eval dataset.
60 | JSON_WRITE_N_RESULTS = None # Write all inferences.
61 | # HW RNG is faster than SW, but has limited determinism.
62 | # Most notably it is not deterministic across different
63 | # submeshes.
64 | USE_HARDWARE_RNG = False
65 | # None always uses faster, hardware RNG
66 | RANDOM_SEED = None
67 |
68 | # DEPRECATED: Import the task module in your gin file instead.
69 | MIXTURE_OR_TASK_MODULE = None
70 | TARGET_FIELD_NAME = 'text_targets'
71 |
72 | train_script.train:
73 | model = %MODEL # imported from separate gin file
74 | model_dir = %MODEL_DIR
75 | train_dataset_cfg = @train/utils.DatasetConfig()
76 | train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
77 | infer_eval_dataset_cfg = None
78 | checkpoint_cfg = @utils.CheckpointConfig()
79 | partitioner = @partitioning.PjitPartitioner()
80 | trainer_cls = @trainer.Trainer
81 | total_steps = %TRAIN_STEPS
82 | eval_steps = %EVAL_STEPS
83 | eval_period = %EVAL_PERIOD
84 | random_seed = %RANDOM_SEED
85 | use_hardware_rng = %USE_HARDWARE_RNG
86 | summarize_config_fn = @gin_utils.summarize_gin_config
87 | inference_evaluator_cls = @uio_evaluator.UnifiedIOEvaluator
88 |
89 | partitioning.PjitPartitioner:
90 | num_partitions = 1
91 | model_parallel_submesh = None
92 | logical_axis_rules = @partitioning.standard_logical_axis_rules()
93 |
94 | uio_evaluator.UnifiedIOEvaluator:
95 | logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
96 | num_examples = %EVALUATOR_NUM_EXAMPLES
97 | use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE
98 |
99 | seqio.JSONLogger:
100 | write_n_results = %JSON_WRITE_N_RESULTS
101 |
102 | train/utils.DatasetConfig:
103 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME
104 | task_feature_lengths = %TASK_FEATURE_LENGTHS_TRAIN
105 | split = 'train'
106 | batch_size = %BATCH_SIZE
107 | shuffle = True
108 | seed = None # use a new seed each run/restart
109 | use_cached = %USE_CACHED_TASKS
110 | pack = False
111 | module = %MIXTURE_OR_TASK_MODULE
112 |
113 | train_eval/utils.DatasetConfig:
114 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME_EVAL
115 | task_feature_lengths = %TASK_FEATURE_LENGTHS_EVAL
116 | split = 'validation'
117 | batch_size = %BATCH_SIZE
118 | shuffle = False
119 | seed = 42
120 | use_cached = %USE_CACHED_TASKS
121 | pack = False
122 | module = %MIXTURE_OR_TASK_MODULE
123 |
124 | infer_eval/utils.DatasetConfig:
125 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME_EVAL
126 | task_feature_lengths = %TASK_FEATURE_LENGTHS_EVAL # compute max
127 | split = 'validation'
128 | batch_size = %BATCH_SIZE
129 | shuffle = False
130 | seed = 42
131 | use_cached = %USE_CACHED_TASKS
132 | pack = False
133 | module = %MIXTURE_OR_TASK_MODULE
134 | target_field_name = %TARGET_FIELD_NAME
135 |
136 | utils.CheckpointConfig:
137 | restore = @utils.RestoreCheckpointConfig()
138 | save = @utils.SaveCheckpointConfig()
139 | utils.RestoreCheckpointConfig:
140 | path = %INITIAL_CHECKPOINT_PATH
141 | mode = 'specific'
142 | dtype = 'float32'
143 | strict = False
144 |
145 | utils.SaveCheckpointConfig:
146 | period = 5000
147 | dtype = 'float32'
148 | keep = None # keep all checkpoints
149 | save_dataset = False # don't checkpoint dataset state
150 |
151 | trainer.Trainer:
152 | num_microbatches = None
153 | learning_rate_fn = @utils.create_learning_rate_scheduler()
154 |
155 | utils.create_learning_rate_scheduler:
156 | factors = 'constant'
157 | base_learning_rate = 0.0001
158 | warmup_steps = 1000
159 |
--------------------------------------------------------------------------------
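All of the `%gin.REQUIRED` macros above must be bound at launch time, usually via repeated `--gin.<NAME>=<value>` flags. The same bindings can be supplied programmatically through gin's Python API; a sketch with made-up values (the task name, feature lengths, and paths are placeholders):

import gin

gin.parse_config("""
MODEL_DIR = '/tmp/finetune_run'
MIXTURE_OR_TASK_NAME = 'my_task'
MIXTURE_OR_TASK_NAME_EVAL = 'my_task'
TASK_FEATURE_LENGTHS_TRAIN = {'text_inputs': 512, 'text_targets': 128}
TASK_FEATURE_LENGTHS_EVAL = {'text_inputs': 512, 'text_targets': 128}
TRAIN_STEPS = 1100000  # 1M pretraining + 0.1M finetuning steps
INITIAL_CHECKPOINT_PATH = '/path/to/pretrained/checkpoint_1000000'
""")
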
/t5x/configs/runs/infer.gin:
--------------------------------------------------------------------------------
1 | # Defaults for infer.py.
2 | #
3 | #
4 | # You must also include a binding for MODEL.
5 | #
6 | # Required to be set:
7 | #
8 | # - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to use for inference
9 | # - TASK_FEATURE_LENGTHS: The lengths per key in the SeqIO Task to trim features
10 | # to.
11 | # - CHECKPOINT_PATH: The model checkpoint to use for inference
12 | # - INFER_OUTPUT_DIR: The dir to write results to.
13 | #
14 | #
15 | # Commonly overridden options:
16 | #
17 | # - infer.mode
18 | # - infer.checkpoint_period
19 | # - infer.shard_id
20 | # - infer.num_shards
21 | # - DatasetConfig.split
22 | # - DatasetConfig.batch_size
23 | # - DatasetConfig.use_cached
24 | # - RestoreCheckpointConfig.is_tensorflow
25 | # - RestoreCheckpointConfig.mode
26 | # - PjitPartitioner.num_partitions
27 | from __gin__ import dynamic_registration
28 |
29 | import __main__ as infer_script
30 | from t5x import partitioning
31 | from t5x import utils
32 |
33 | # Must be overridden
34 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED
35 | TASK_FEATURE_LENGTHS = %gin.REQUIRED
36 | CHECKPOINT_PATH = %gin.REQUIRED
37 | INFER_OUTPUT_DIR = %gin.REQUIRED
38 | CHECKPOINT_PERIOD = %gin.REQUIRED
39 | # DEPRECATED: Import the task module in your gin file instead.
40 | MIXTURE_OR_TASK_MODULE = None
41 |
42 | infer_script.infer:
43 | mode = 'predict_with_aux'
44 | model = %MODEL # imported from separate gin file
45 | output_dir = %INFER_OUTPUT_DIR
46 | dataset_cfg = @utils.DatasetConfig()
47 | partitioner = @partitioning.PjitPartitioner()
48 | restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
49 | checkpoint_period = %CHECKPOINT_PERIOD
50 | shard_id = 0
51 | num_shards = 1
52 | checkpoint_ds_iter = False
53 |
54 | partitioning.PjitPartitioner:
55 | num_partitions = 1
56 | logical_axis_rules = @partitioning.standard_logical_axis_rules()
57 |
58 | utils.DatasetConfig:
59 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME
60 | module = %MIXTURE_OR_TASK_MODULE
61 | task_feature_lengths = %TASK_FEATURE_LENGTHS
62 | use_cached = False
63 | split = 'test'
64 | batch_size = 32
65 | shuffle = False
66 | seed = 0
67 | pack = False
68 |
69 | utils.RestoreCheckpointConfig:
70 | path = %CHECKPOINT_PATH
71 | mode = 'specific'
72 | dtype = 'bfloat16'
73 |
--------------------------------------------------------------------------------
/t5x/configs/runs/infer_from_tfexample_file.gin:
--------------------------------------------------------------------------------
1 | # Defaults for infer.py if using a TFExample file as input.
2 | #
3 | #
4 | # The features from each TFExample are tokenized using the model's vocabulary.
5 | # By default, the inputs feature is assumed to be keyed as 'inputs', but this
6 | # can be overridden with `create_task_from_tfexample_file.inputs_key`.
7 | #
8 | # You must also include a binding for MODEL.
9 | #
10 | # Required to be set:
11 | #
12 | # - TF_EXAMPLE_FILE_PATHS: The path to read TF Examples from.
13 | # - TF_EXAMPLE_FILE_TYPE: The type of file to read TF Examples from. Currently
14 | # supported: 'tfrecord', 'recordio', 'sstable'.
15 | # - FEATURE_LENGTHS: The maximum length per feature in the TF Examples.
16 | # - CHECKPOINT_PATH: The model checkpoint to use for inference
17 | # - INFER_OUTPUT_DIR: The dir to write results to.
18 | #
19 | #
20 | # Commonly overridden options:
21 | #
22 | # - infer.mode
23 | # - infer.checkpoint_period
24 | # - infer.shard_id
25 | # - infer.num_shards
26 | # - create_task_from_tfexample_file.inputs_key
27 | # - create_task_from_tfexample_file.targets_key
28 | # - DatasetConfig.split
29 | # - DatasetConfig.batch_size
30 | # - RestoreCheckpointConfig.mode
31 | # - PjitPartitioner.num_partitions
32 | from __gin__ import dynamic_registration
33 |
34 | import __main__ as infer_script
35 | import seqio
36 | from t5x import models
37 | from t5x import partitioning
38 | from t5x import utils
39 |
40 | # Must be overridden
41 | TF_EXAMPLE_FILE_PATHS = %gin.REQUIRED
42 | TF_EXAMPLE_FILE_TYPE = %gin.REQUIRED
43 | FEATURE_LENGTHS = %gin.REQUIRED
44 | CHECKPOINT_PATH = %gin.REQUIRED
45 | INFER_OUTPUT_DIR = %gin.REQUIRED
46 |
47 | infer_script.infer:
48 | mode = 'predict'
49 | model = %MODEL # imported from separate gin file
50 | output_dir = %INFER_OUTPUT_DIR
51 | dataset_cfg = @utils.DatasetConfig()
52 | partitioner = @partitioning.PjitPartitioner()
53 | restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
54 | checkpoint_period = 100
55 | shard_id = 0
56 | num_shards = 1
57 |
58 | partitioning.PjitPartitioner:
59 | num_partitions = 1
60 | logical_axis_rules = @partitioning.standard_logical_axis_rules()
61 |
62 | utils.DatasetConfig:
63 | mixture_or_task_name = @infer_script.create_task_from_tfexample_file()
64 | task_feature_lengths = %FEATURE_LENGTHS
65 | split = 'infer'
66 | batch_size = 32
67 | shuffle = False
68 | seed = 0
69 | pack = False
70 |
71 | infer_script.create_task_from_tfexample_file:
72 | paths = %TF_EXAMPLE_FILE_PATHS
73 | file_type = %TF_EXAMPLE_FILE_TYPE
74 | inputs_key = 'inputs'
75 | targets_key = None
76 | features = {'inputs': @inputs/seqio.Feature(), 'targets': @outputs/seqio.Feature()}
77 |
78 | # Plumbing to extract the vocabulary directly from MODEL. This is needed to
79 | # tokenize the features from the TFExample, since we aren't provided with
80 | # vocabularies via a Task.
81 | inputs/seqio.Feature.vocabulary = @models.get_input_vocabulary()
82 | models.get_input_vocabulary.model = %MODEL
83 | outputs/seqio.Feature.vocabulary = @models.get_output_vocabulary()
84 | models.get_output_vocabulary.model = %MODEL
85 |
86 | utils.RestoreCheckpointConfig:
87 | mode = 'specific'
88 | path = %CHECKPOINT_PATH
89 | dtype = 'bfloat16'
90 |
91 |
--------------------------------------------------------------------------------
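A file consumed by this config just needs one serialized `tf.train.Example` per record, with the text stored under the `inputs` key (the default `inputs_key` above). A sketch that writes such a tfrecord (the output path and prompt text are placeholders):

import tensorflow as tf

def write_infer_inputs(path, texts):
  # One Example per record; the feature name must match inputs_key.
  with tf.io.TFRecordWriter(path) as writer:
    for text in texts:
      example = tf.train.Example(features=tf.train.Features(feature={
          'inputs': tf.train.Feature(
              bytes_list=tf.train.BytesList(value=[text.encode('utf-8')])),
      }))
      writer.write(example.SerializeToString())

write_infer_inputs('/tmp/infer_inputs.tfrecord',
                   ['what color is the sky?'])
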
/t5x/configs/runs/loss.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | import __main__ as loss_script
4 | import seqio
5 | from t5x import partitioning
6 | from t5x import utils
7 | import t5x.examples.unified_io.evaluator as uio_evaluator
8 |
9 | # Must be overridden
10 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED
11 | CHECKPOINT_PATH = %gin.REQUIRED
12 | EVAL_OUTPUT_DIR = %gin.REQUIRED
13 | EVAL_STEPS = %gin.REQUIRED
14 | TASK_FEATURE_LENGTHS = None # auto-computes the maximum feature lengths to use.
15 |
16 | # DEPRECATED: Import the task module in your gin file instead.
17 | MIXTURE_OR_TASK_MODULE = None
18 |
19 | loss_script.compute_loss:
20 | model = %MODEL # imported from separate gin file
21 | dataset_cfg = @utils.DatasetConfig()
22 | partitioner = @partitioning.PjitPartitioner()
23 | restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
24 | output_dir = %EVAL_OUTPUT_DIR
25 | eval_steps = %EVAL_STEPS
26 |
27 | partitioning.PjitPartitioner:
28 | num_partitions = 1
29 | model_parallel_submesh = None
30 | logical_axis_rules = @partitioning.standard_logical_axis_rules()
31 |
32 | utils.DatasetConfig:
33 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME
34 | task_feature_lengths = %TASK_FEATURE_LENGTHS
35 | split = 'validation'
36 | batch_size = 32
37 | shuffle = False
38 | seed = 42
39 | use_cached = False
40 | use_memory_cache = False
41 | pack = False
42 | use_custom_packing_ops = False
43 | module = %MIXTURE_OR_TASK_MODULE
44 |
45 | utils.RestoreCheckpointConfig:
46 | path = %CHECKPOINT_PATH
47 | mode = 'specific'
48 |
--------------------------------------------------------------------------------
/t5x/configs/runs/multitask.gin:
--------------------------------------------------------------------------------
1 | # Defaults for multitask training with train.py.
2 | #
3 | #
4 | # You must also include a binding for MODEL.
5 | #
6 | # Required to be set:
7 | #
8 | # - MIXTURE_OR_TASK_NAME
9 | # - TASK_FEATURE_LENGTHS
10 | # - TRAIN_STEPS # includes pretrain steps
11 | # - MODEL_DIR # automatically set when using xm_launch
12 | # - INITIAL_CHECKPOINT_PATH
13 | #
14 | # When running locally, MODEL_DIR needs to be passed via the `--gin.MODEL_DIR` flag.
15 | #
16 | # `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt
17 | # has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps.
18 | #
19 | # Commonly overridden options:
20 | # - DROPOUT_RATE
21 | # - BATCH_SIZE
22 | # - PjitPartitioner.num_partitions
23 | # - Trainer.num_microbatches
24 | # - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess
25 | # on the fly. This config defaults it to False; set it to True if your
26 | # tasks have been preprocessed and cached.
27 |
28 | from __gin__ import dynamic_registration
29 |
30 | import __main__ as train_script
31 | import seqio
32 | from t5x import gin_utils
33 | from t5x import partitioning
34 | from t5x import utils
35 | from t5x import trainer
36 | import t5x.examples.unified_io.evaluator as uio_evaluator
37 |
38 | # Must be overridden
39 | MODEL_DIR = %gin.REQUIRED
40 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED
41 | MIXTURE_OR_TASK_NAME_EVAL = %gin.REQUIRED
42 | TASK_FEATURE_LENGTHS_TRAIN = %gin.REQUIRED
43 | TASK_FEATURE_LENGTHS_EVAL = %gin.REQUIRED
44 | MIXTURE_OR_TASK_MODULE = %gin.REQUIRED
45 | TRAIN_STEPS = %gin.REQUIRED
46 | INITIAL_CHECKPOINT_PATH = %gin.REQUIRED
47 |
48 | # Commonly overridden
49 | DROPOUT_RATE = 0.1
50 | USE_CACHED_TASKS = False
51 | BATCH_SIZE = 128
52 |
53 | # Sometimes overridden
54 | EVAL_STEPS = 20
55 |
56 | # Convenience overrides.
57 | EVALUATOR_USE_MEMORY_CACHE = False
58 | EVALUATOR_NUM_EXAMPLES = None # Use all examples in the infer_eval dataset.
59 | JSON_WRITE_N_RESULTS = None # Write all inferences.
60 | # HW RNG is faster than SW, but has limited determinism.
61 | # Most notably it is not deterministic across different
62 | # submeshes.
63 | USE_HARDWARE_RNG = False
64 | # None always uses faster, hardware RNG
65 | RANDOM_SEED = None
66 |
67 | # DEPRECATED: Import the task module in your gin file instead.
68 | MIXTURE_OR_TASK_MODULE = None
69 | TARGET_FIELD_NAME = 'text_targets'
70 |
71 | train_script.train:
72 | model = %MODEL # imported from separate gin file
73 | model_dir = %MODEL_DIR
74 | train_dataset_cfg = @train/utils.DatasetConfig()
75 | train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
76 | infer_eval_dataset_cfg = None
77 | checkpoint_cfg = @utils.CheckpointConfig()
78 | partitioner = @partitioning.PjitPartitioner()
79 | trainer_cls = @trainer.Trainer
80 | total_steps = %TRAIN_STEPS
81 | eval_steps = %EVAL_STEPS
82 | eval_period = 1000
83 | random_seed = %RANDOM_SEED
84 | use_hardware_rng = %USE_HARDWARE_RNG
85 | summarize_config_fn = @gin_utils.summarize_gin_config
86 | inference_evaluator_cls = @uio_evaluator.UnifiedIOEvaluator
87 |
88 | partitioning.PjitPartitioner:
89 | num_partitions = 1
90 | model_parallel_submesh = None
91 | logical_axis_rules = @partitioning.standard_logical_axis_rules()
92 |
93 | uio_evaluator.UnifiedIOEvaluator:
94 | logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
95 | num_examples = %EVALUATOR_NUM_EXAMPLES
96 | use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE
97 |
98 | seqio.JSONLogger:
99 | write_n_results = %JSON_WRITE_N_RESULTS
100 |
101 | train/utils.DatasetConfig:
102 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME
103 | task_feature_lengths = %TASK_FEATURE_LENGTHS_TRAIN
104 | split = 'train'
105 | batch_size = %BATCH_SIZE
106 | shuffle = True
107 | seed = None # use a new seed each run/restart
108 | use_cached = %USE_CACHED_TASKS
109 | pack = False
110 | module = %MIXTURE_OR_TASK_MODULE
111 |
112 | train_eval/utils.DatasetConfig:
113 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME_EVAL
114 | task_feature_lengths = %TASK_FEATURE_LENGTHS_EVAL
115 | split = 'validation'
116 | batch_size = %BATCH_SIZE
117 | shuffle = False
118 | seed = 42
119 | use_cached = %USE_CACHED_TASKS
120 | pack = False
121 | module = %MIXTURE_OR_TASK_MODULE
122 |
123 | utils.CheckpointConfig:
124 | restore = @utils.RestoreCheckpointConfig()
125 | save = @utils.SaveCheckpointConfig()
126 |
127 | utils.RestoreCheckpointConfig:
128 | path = %INITIAL_CHECKPOINT_PATH
129 | mode = 'specific'
130 | dtype = 'float32'
131 | strict = False
132 |
133 | utils.SaveCheckpointConfig:
134 | period = 50000
135 | dtype = 'float32'
136 | keep = None # keep all checkpoints
137 | save_dataset = False # don't checkpoint dataset state
138 |
139 | trainer.Trainer:
140 | num_microbatches = None
141 | learning_rate_fn = @utils.create_learning_rate_scheduler()
142 |
143 | utils.create_learning_rate_scheduler:
144 | factors = 'constant * rsqrt_decay'
145 | base_learning_rate = 1.0
146 | warmup_steps = 10000 # 10k to keep consistent with T5/MTF defaults.
147 |
--------------------------------------------------------------------------------
/t5x/configs/runs/precompile.gin:
--------------------------------------------------------------------------------
1 | # Defaults for precompile mode in main.py.
2 | #
3 | # You must also include a binding for MODEL.
4 | #
5 | # Required to be set:
6 | #
7 | # - MIXTURE_OR_TASK_NAME
8 | # - TASK_FEATURE_LENGTHS
9 | # - TRAIN_STEPS
10 | # - MODEL_DIR # automatically set when using xm_launch
11 | #
12 | # Commonly overridden options:
13 | #
14 | # - USE_CACHED_TASKS
15 | # - BATCH_SIZE
16 | # - PjitPartitioner.num_partitions
17 | from __gin__ import dynamic_registration
18 |
19 | import __main__ as train_script
20 | import seqio
21 | from t5x import gin_utils
22 | from t5x import partitioning
23 | from t5x import utils
24 | from t5x import trainer
25 |
26 | MODEL_DIR = %gin.REQUIRED
27 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED
28 | TASK_FEATURE_LENGTHS = %gin.REQUIRED
29 |
30 |
31 | # Commonly overridden
32 | USE_CACHED_TASKS = True
33 | BATCH_SIZE = 128
34 |
35 | # None always uses faster, hardware RNG
36 | RANDOM_SEED = None
37 |
38 | train_script.precompile:
39 | model = %MODEL # imported from separate gin file
40 | model_dir = %MODEL_DIR
41 | train_dataset_cfg = @train/utils.DatasetConfig()
42 | partitioner = @partitioning.PjitPartitioner()
43 | random_seed = %RANDOM_SEED
44 |
45 | partitioning.PjitPartitioner:
46 | num_partitions = 1
47 | model_parallel_submesh = None
48 | backend = "tpu"
49 | logical_axis_rules = @partitioning.standard_logical_axis_rules()
50 |
51 | train/utils.DatasetConfig:
52 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME
53 | task_feature_lengths = %TASK_FEATURE_LENGTHS
54 | split = 'train'
55 | batch_size = %BATCH_SIZE
56 | shuffle = True
57 | seed = None # use a new seed each run/restart
58 | use_cached = %USE_CACHED_TASKS
59 | pack = True
60 |
--------------------------------------------------------------------------------
/t5x/configs/runs/pretrain.gin:
--------------------------------------------------------------------------------
1 | # Defaults for pretraining with train.py.
2 | #
3 | #
4 | # You must also include a binding for MODEL.
5 | #
6 | # Required to be set:
7 | #
8 | # - MIXTURE_OR_TASK_NAME
9 | # - TASK_FEATURE_LENGTHS
10 | # - TRAIN_STEPS # includes pretrain steps
11 | # - MODEL_DIR # automatically set when using xm_launch
12 | # - INITIAL_CHECKPOINT_PATH
13 | #
14 | # When running locally, MODEL_DIR needs to be passed via the `--gin.MODEL_DIR` flag.
15 | #
16 | # `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt
17 | # has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps.
18 | #
19 | # Commonly overridden options:
20 | # - DROPOUT_RATE
21 | # - BATCH_SIZE
22 | # - PjitPartitioner.num_partitions
23 | # - Trainer.num_microbatches
24 | # - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess
25 | # on the fly. This config defaults it to False; set it to True if your
26 | # tasks have been preprocessed and cached.
27 |
28 | from __gin__ import dynamic_registration
29 |
30 | import __main__ as train_script
31 | import seqio
32 | from t5x import gin_utils
33 | from t5x import partitioning
34 | from t5x import utils
35 | from t5x import trainer
36 | import t5x.examples.unified_io.evaluator as uio_evaluator
37 |
38 | # Must be overridden
39 | MODEL_DIR = %gin.REQUIRED
40 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED
41 | MIXTURE_OR_TASK_NAME_EVAL = %gin.REQUIRED
42 | TASK_FEATURE_LENGTHS_TRAIN = %gin.REQUIRED
43 | TASK_FEATURE_LENGTHS_EVAL = %gin.REQUIRED
44 | MIXTURE_OR_TASK_MODULE = %gin.REQUIRED
45 | TRAIN_STEPS = %gin.REQUIRED
46 | INITIAL_CHECKPOINT_PATH = %gin.REQUIRED
47 |
48 | # Commonly overridden
49 | DROPOUT_RATE = 0.1
50 | USE_CACHED_TASKS = False
51 | BATCH_SIZE = 128
52 |
53 | # Sometimes overridden
54 | EVAL_STEPS = 20
55 |
56 | # Convenience overrides.
57 | EVALUATOR_USE_MEMORY_CACHE = False
58 | EVALUATOR_NUM_EXAMPLES = None # Use all examples in the infer_eval dataset.
59 | JSON_WRITE_N_RESULTS = None # Write all inferences.
60 | # HW RNG is faster than SW, but has limited determinism.
61 | # Most notably it is not deterministic across different
62 | # submeshes.
63 | USE_HARDWARE_RNG = False
64 | # None always uses faster, hardware RNG
65 | RANDOM_SEED = None
66 |
67 | # DEPRECATED: Import the task module in your gin file instead.
68 | MIXTURE_OR_TASK_MODULE = None
69 | TARGET_FIELD_NAME = 'text_targets'
70 |
71 | train_script.train:
72 | model = %MODEL # imported from separate gin file
73 | model_dir = %MODEL_DIR
74 | train_dataset_cfg = @train/utils.DatasetConfig()
75 | train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
76 | infer_eval_dataset_cfg = None
77 | checkpoint_cfg = @utils.CheckpointConfig()
78 | partitioner = @partitioning.PjitPartitioner()
79 | trainer_cls = @trainer.Trainer
80 | total_steps = %TRAIN_STEPS
81 | eval_steps = %EVAL_STEPS
82 | eval_period = 1000
83 | random_seed = %RANDOM_SEED
84 | use_hardware_rng = %USE_HARDWARE_RNG
85 | summarize_config_fn = @gin_utils.summarize_gin_config
86 | inference_evaluator_cls = @uio_evaluator.UnifiedIOEvaluator
87 | use_gda = True
88 | partitioning.PjitPartitioner:
89 | num_partitions = 1
90 | model_parallel_submesh = None
91 | logical_axis_rules = @partitioning.standard_logical_axis_rules()
92 |
93 | uio_evaluator.UnifiedIOEvaluator:
94 | logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
95 | num_examples = %EVALUATOR_NUM_EXAMPLES
96 | use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE
97 |
98 | seqio.JSONLogger:
99 | write_n_results = %JSON_WRITE_N_RESULTS
100 |
101 | train/utils.DatasetConfig:
102 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME
103 | task_feature_lengths = %TASK_FEATURE_LENGTHS_TRAIN
104 | split = 'train'
105 | batch_size = %BATCH_SIZE
106 | shuffle = True
107 | seed = None # use a new seed each run/restart
108 | use_cached = %USE_CACHED_TASKS
109 | pack = False
110 | module = %MIXTURE_OR_TASK_MODULE
111 |
112 | train_eval/utils.DatasetConfig:
113 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME_EVAL
114 | task_feature_lengths = %TASK_FEATURE_LENGTHS_EVAL
115 | split = 'validation'
116 | batch_size = %BATCH_SIZE
117 | shuffle = False
118 | seed = 42
119 | use_cached = %USE_CACHED_TASKS
120 | pack = False
121 | module = %MIXTURE_OR_TASK_MODULE
122 |
123 | utils.CheckpointConfig:
124 | restore = @utils.RestoreCheckpointConfig()
125 | save = @utils.SaveCheckpointConfig()
126 |
127 | utils.RestoreCheckpointConfig:
128 | path = %INITIAL_CHECKPOINT_PATH
129 | mode = 'specific'
130 | dtype = 'float32'
131 | strict = False
132 |
133 | utils.SaveCheckpointConfig:
134 | period = 50000
135 | dtype = 'float32'
136 | keep = None # keep all checkpoints
137 | save_dataset = False # don't checkpoint dataset state
138 |
139 | # trainer.Trainer:
140 | # num_microbatches = None
141 | # learning_rate_fn = @utils.create_learning_rate_scheduler()
142 |
143 | # utils.create_learning_rate_scheduler:
144 | # factors = 'constant * rsqrt_decay'
145 | # base_learning_rate = 1.0
146 | # warmup_steps = 10000 # 10k to keep consistent with T5/MTF defaults.
147 |
--------------------------------------------------------------------------------
/t5x/configs/runs/vit_vqgan.gin:
--------------------------------------------------------------------------------
1 | # Defaults for ViT-VQGAN training with train.py.
2 | #
3 | #
4 | # You must also include a binding for MODEL.
5 | #
6 | # Required to be set:
7 | #
8 | # - MIXTURE_OR_TASK_NAME
9 | # - TASK_FEATURE_LENGTHS
10 | # - TRAIN_STEPS # includes pretrain steps
11 | # - MODEL_DIR # automatically set when using xm_launch
12 | # - INITIAL_CHECKPOINT_PATH
13 | #
14 | # When running locally, MODEL_DIR needs to be passed via the `--gin.MODEL_DIR` flag.
15 | #
16 | # `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt
17 | # has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps.
18 | #
19 | # Commonly overridden options:
20 | # - DROPOUT_RATE
21 | # - BATCH_SIZE
22 | # - PjitPartitioner.num_partitions
23 | # - Trainer.num_microbatches
24 | # - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess
25 | # on the fly. Most common tasks are cached, hence this is set to True by
26 | # default.
27 |
28 | from __gin__ import dynamic_registration
29 |
30 | import __main__ as train_script
31 | import seqio
32 | from t5x import gin_utils
33 | from t5x import partitioning
34 | from t5x import utils
35 | from t5x import trainer
36 | import t5x.examples.unified_io.evaluator as uio_evaluator
37 | from t5x import train_state as train_state_lib
38 |
39 | # Must be overridden
40 | MODEL_DIR = %gin.REQUIRED
41 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED
42 | MIXTURE_OR_TASK_NAME_EVAL = %gin.REQUIRED
43 | TASK_FEATURE_LENGTHS_TRAIN = %gin.REQUIRED
44 | TASK_FEATURE_LENGTHS_EVAL = %gin.REQUIRED
45 | MIXTURE_OR_TASK_MODULE = %gin.REQUIRED
46 | TRAIN_STEPS = %gin.REQUIRED
47 | INITIAL_CHECKPOINT_PATH = %gin.REQUIRED
48 |
49 | # Commonly overridden
50 | DROPOUT_RATE = 0.1
51 | USE_CACHED_TASKS = True
52 | BATCH_SIZE = 128
53 |
54 | # Sometimes overridden
55 | EVAL_STEPS = 20
56 |
57 | # Convenience overrides.
58 | EVALUATOR_USE_MEMORY_CACHE = True
59 | EVALUATOR_NUM_EXAMPLES = None # Use all examples in the infer_eval dataset.
60 | JSON_WRITE_N_RESULTS = None # Write all inferences.
61 | # HW RNG is faster than SW, but has limited determinism.
62 | # Most notably it is not deterministic across different
63 | # submeshes.
64 | USE_HARDWARE_RNG = False
65 | # None always uses faster, hardware RNG
66 | RANDOM_SEED = None
67 |
68 | # DEPRECATED: Import the task module in your gin file instead.
69 | MIXTURE_OR_TASK_MODULE = None
70 | TARGET_FIELD_NAME = 'text_targets'
71 |
72 | train_script.train:
73 | model = %MODEL # imported from separate gin file
74 | model_dir = %MODEL_DIR
75 | train_dataset_cfg = @train/utils.DatasetConfig()
76 | train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
77 | infer_eval_dataset_cfg = None
78 | checkpoint_cfg = @utils.CheckpointConfig()
79 | partitioner = @partitioning.PjitPartitioner()
80 | trainer_cls = @trainer.Trainer
81 | total_steps = %TRAIN_STEPS
82 | eval_steps = %EVAL_STEPS
83 | eval_period = 1000
84 | random_seed = %RANDOM_SEED
85 | use_hardware_rng = %USE_HARDWARE_RNG
86 | summarize_config_fn = @gin_utils.summarize_gin_config
87 | inference_evaluator_cls = @uio_evaluator.UnifiedIOEvaluator
88 |
89 | partitioning.PjitPartitioner:
90 | num_partitions = 1
91 | model_parallel_submesh = None
92 | logical_axis_rules = @partitioning.standard_logical_axis_rules()
93 |
94 | uio_evaluator.UnifiedIOEvaluator:
95 | logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
96 | num_examples = %EVALUATOR_NUM_EXAMPLES
97 | use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE
98 |
99 | seqio.JSONLogger:
100 | write_n_results = %JSON_WRITE_N_RESULTS
101 |
102 | train/utils.DatasetConfig:
103 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME
104 | task_feature_lengths = %TASK_FEATURE_LENGTHS_TRAIN
105 | split = 'train'
106 | batch_size = %BATCH_SIZE
107 | shuffle = True
108 | seed = None # use a new seed each run/restart
109 | use_cached = %USE_CACHED_TASKS
110 | pack = False
111 | module = %MIXTURE_OR_TASK_MODULE
112 |
113 | train_eval/utils.DatasetConfig:
114 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME_EVAL
115 | task_feature_lengths = %TASK_FEATURE_LENGTHS_EVAL
116 | split = 'validation'
117 | batch_size = %BATCH_SIZE
118 | shuffle = False
119 | seed = 42
120 | use_cached = %USE_CACHED_TASKS
121 | pack = False
122 | module = %MIXTURE_OR_TASK_MODULE
123 |
124 | infer_eval/utils.DatasetConfig:
125 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME_EVAL
126 | task_feature_lengths = %TASK_FEATURE_LENGTHS_EVAL # compute max
127 | split = 'validation'
128 | batch_size = %BATCH_SIZE
129 | shuffle = False
130 | seed = 42
131 | use_cached = %USE_CACHED_TASKS
132 | pack = False
133 | module = %MIXTURE_OR_TASK_MODULE
134 | target_field_name = %TARGET_FIELD_NAME
135 |
136 | utils.CheckpointConfig:
137 | restore = @utils.RestoreCheckpointConfig()
138 | save = @utils.SaveCheckpointConfig()
139 |
140 | utils.RestoreCheckpointConfig:
141 | path = %INITIAL_CHECKPOINT_PATH
142 | mode = 'specific'
143 | dtype = 'float32'
144 | strict = False
145 |
146 | utils.SaveCheckpointConfig:
147 | period = 1000
148 | dtype = 'float32'
149 |   keep = 3  # number of most recent checkpoints to keep
150 | save_dataset = False # don't checkpoint dataset state
--------------------------------------------------------------------------------
/t5x/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/unified-io-2/502ac4d81239f82c891a9f412b000c3c8d4e2946/t5x/examples/__init__.py
--------------------------------------------------------------------------------
/t5x/examples/unified_io/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # This empty file is needed for loading the gin files in this directory.
16 |
--------------------------------------------------------------------------------
/t5x/examples/unified_io/aux_fns.py:
--------------------------------------------------------------------------------
1 | """Logit masking and checkpoint transformation functions"""
2 |
3 | import functools
4 | from typing import Dict
5 |
6 | import gin
7 | import jax
8 | import jax.numpy as jnp
9 | import numpy as np
10 |
11 | from t5x import state_utils
12 | from t5x.examples.unified_io import config
13 | from t5x.examples.unified_io.data.data_utils import get_default_vocabulary
14 | from flax import linen as nn
15 |
16 |
17 | def apply_mask(logits, cur_index, mask):
18 | if len(mask.shape) == 3:
19 | mask = mask[:, cur_index]
20 | elif len(mask.shape) == 2:
21 | mask = jnp.reshape(mask[cur_index], [1, -1])
22 | else:
23 | mask = jnp.reshape(mask, [1, -1])
24 | if mask.dtype == jnp.bool_:
25 | flat_logits = jnp.where(mask, -1e10, logits)
26 | else:
27 | flat_logits = mask + logits
28 | return flat_logits
29 |
30 |
31 | def clf_free_logit_mask_fn(logits, _, num_decodes, alpha=10.0):
32 | logits = nn.log_softmax(logits)
33 | logits = jnp.reshape(logits, [-1, num_decodes, logits.shape[-1]])
34 | bs = logits.shape[0]
35 | logits, clf_free = logits[:bs//2], logits[bs//2:]
36 | logits = (1 + alpha) * logits - alpha * clf_free
37 | return jnp.reshape(jnp.tile(logits, (2, 1, 1)), [bs*num_decodes, -1])
38 |
39 |
40 | def clf_free_next_token_callback(logits, next_token, num_decodes):
41 | # The classifier free examples need to follow the non-clf-free versions
42 | next_token = jnp.reshape(next_token, [-1, num_decodes])
43 | bs = next_token.shape[0]
44 | next_token = jnp.tile(next_token[:bs//2], (2, 1))
45 | return jnp.reshape(next_token, -1)
46 |
47 |
48 | @gin.configurable()
49 | def pose_estimation_mask_fn_part_names(_, lengths):
50 | """Mask logits so that the model only predicts pose part names and location points"""
51 | vocab = get_default_vocabulary()
52 | if config.TOKENIZER == "llama":
53 | vocab_size = 33280
54 | else:
55 | raise NotImplementedError()
56 | masks = []
57 | loc_mask = np.ones([vocab_size], np.bool_)
58 | loc_mask[32000:33000] = False
59 | for part in config.HUMAN_POSE_PART:
60 | masks += [loc_mask, loc_mask]
61 | for voc_id in vocab.encode(part):
62 | mask = np.ones([vocab_size], np.bool_)
63 | mask[voc_id] = False
64 | masks.append(mask)
65 | eos_mask = np.ones([vocab_size], np.bool_)
66 | eos_mask[1] = 0
67 | masks.append(eos_mask)
68 | mask = jnp.array(np.stack(masks, 0))
69 | return functools.partial(apply_mask, mask=mask)
70 |
71 |
72 | @gin.configurable
73 | def non_loc_select(_, lengths, thresh=0.5, require_one_box=False):
74 | """Mask logits so EOS is only selected if the total prob over location tokens is < `thresh`"""
75 | voc_size = 33280 if config.TOKENIZER == "llama" else 33152 + 16384
76 | loc_mask = np.zeros([voc_size], np.float32)
77 | loc_mask[:32000] = -10000
78 | loc_mask[33000:] = -10000
79 | loc_mask = jnp.array(loc_mask)
80 |
81 | def _fn(logits, cur_index):
82 | logits = jax.nn.log_softmax(logits)
83 | probs = jnp.exp(jax.scipy.special.logsumexp(logits[:, 32000:33000], axis=-1))
84 | use_loc = probs > thresh
85 | if require_one_box:
86 | use_loc = jnp.logical_or(use_loc, cur_index <= 3)
87 | return logits + loc_mask[None, :] * use_loc[:, None]
88 | return _fn
89 |
90 |
91 | def state_transformation_fns():
92 | fn = [
93 | functools.partial(
94 | state_utils.apply_assignment_map,
95 | assignment_map=[
96 | (r'state.*', None),
97 | ])
98 | ]
99 |
100 | return fn
101 |
102 |
103 | def remove_optimizer_state():
104 | fn = [
105 | functools.partial(
106 | state_utils.apply_assignment_map,
107 | assignment_map=[
108 | (r'state.*', None),
109 | ])
110 | ]
111 | return fn
112 |
113 |
114 | def load_vae(state_dict, target_state_dict: Dict, *, is_resuming: bool = False, modality="image"):
115 | if is_resuming:
116 | return target_state_dict
117 | return dict(target={f"target_encoders_{modality}": dict(discrete_vae=state_dict["target"])})
118 |
119 |
120 | def vit_vqgan_restore_fn(modality="image"):
121 | return [functools.partial(load_vae, modality=modality)]
122 |
123 |
124 | def load_vqgan(state_dict, target_state_dict: Dict, *, is_resuming: bool = False, modality="image"):
125 | if is_resuming:
126 | return target_state_dict
127 | if 'image_vitvqgan' in state_dict["target"]:
128 | return dict(target={f"target_encoders_{modality}": dict(discrete_vae=state_dict["target"]['image_vitvqgan'])})
129 | else:
130 | return dict(target={f"target_encoders_{modality}": dict(discrete_vae=state_dict["target"])})
131 |
132 |
133 | def vqgan_restore_fn(modality="image"):
134 | return [functools.partial(load_vqgan, modality=modality)]
135 |
136 |
137 | def load_vae_all(state_dict, target_state_dict: Dict, *, is_resuming: bool = False):
138 | if is_resuming:
139 | return target_state_dict
140 |
141 | return dict(target=dict(
142 | target_encoders_image=dict(discrete_vae=state_dict["target"]["image_vitvqgan"]),
143 | target_encoders_audio=dict(discrete_vae=state_dict["target"]["audio_vitvqgan"])))
144 |
145 |
146 | def vit_vqgan_all_restore_fn():
147 | """State transformation to map a VQGAN checkpoint parameters into a UIO2 model"""
148 | return [load_vae_all]
149 |
150 |
--------------------------------------------------------------------------------
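A minimal sketch of how `apply_mask` combines a mask with decoder logits, assuming this repo and its dependencies are importable; with a boolean mask, True entries are blocked (set to -1e10), while a float mask is simply added to the logits as a soft penalty:

import jax.numpy as jnp
from t5x.examples.unified_io.aux_fns import apply_mask

logits = jnp.zeros([1, 8])                  # [batch, vocab]
blocked = jnp.arange(8) < 4                 # boolean mask: block the first 4 ids
out = apply_mask(logits, 0, blocked)
# out[0, :4] == -1e10, out[0, 4:] == 0.0

penalty = jnp.where(jnp.arange(8) < 4, -1e4, 0.0)  # additive float mask
out = apply_mask(logits, 0, penalty)
# out[0, :4] == -1e4, out[0, 4:] == 0.0
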
/t5x/examples/unified_io/data/__init__.py:
--------------------------------------------------------------------------------
1 | from seqio.dataset_providers import *
2 | from seqio.utils import *
3 | from seqio.vocabularies import *
--------------------------------------------------------------------------------
/t5x/examples/unified_io/data/mixtures.py:
--------------------------------------------------------------------------------
1 | from t5x.examples.unified_io.data import tasks
2 | from t5x.examples.unified_io.data import nlp_instruction_following
3 | from seqio import MixtureRegistry
4 |
5 |
6 | MixtureRegistry.add(
7 | "refexp",
8 | [
9 | ("refcoco_plus_unc", 1.0),
10 | ("refcocog_google", 1.0),
11 | ("refcoco_unc", 1.0),
12 | ],
13 | )
14 |
--------------------------------------------------------------------------------
/t5x/examples/unified_io/data/postprocessing.py:
--------------------------------------------------------------------------------
1 | def return_example(output_or_target, example=None, is_target=False):
2 | if is_target:
3 | return example
4 | else:
5 | return output_or_target
6 |
7 |
8 | def return_meta(output_or_target, example=None, is_target=False):
9 | if is_target:
10 | return {k[5:]: v for k, v in example.items() if k.startswith("meta/")}
11 | else:
12 | return output_or_target
13 |
14 |
15 | def return_field(output_or_target, field, example=None, is_target=False):
16 | if is_target:
17 | return example[field]
18 | else:
19 | return output_or_target
20 |
--------------------------------------------------------------------------------
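These postprocessors are pure functions, so their behavior is easy to check directly; a quick illustration with a hypothetical example dict:

from t5x.examples.unified_io.data.postprocessing import return_meta, return_field

example = {"meta/image_id": 17, "meta/label": "cat", "image": b"..."}
return_meta("output", example, is_target=True)    # -> {'image_id': 17, 'label': 'cat'}
return_meta("output", example, is_target=False)   # -> 'output'
return_field("output", "meta/label", example, is_target=True)  # -> 'cat'
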
/t5x/examples/unified_io/data/prompt_definition.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 |
4 | import gin
5 |
6 | from t5x.examples.unified_io.data.prompt_dict import PROMPT_DICT
7 |
8 |
9 | @gin.configurable
10 | class Prompt:
11 | """Configurable interface for getting prompts"""
12 |
13 | def __init__(self, original_flag=True, revised_original_flag=False, manual_flag=True,
14 | gpt3_flag=True, single_prompt=False, dbg=None):
15 | self.prompt_list = []
16 | self.original_flag = original_flag
17 | self.revised_original_flag = revised_original_flag
18 | self.manual_flag = manual_flag
19 | self.gpt3_flag = gpt3_flag
20 | self.single_prompt = single_prompt
21 | self.dbg = dbg
22 |
23 | def get_prompt_list(self, task_name, dataset_name):
24 | if self.dbg:
25 |       logging.info(f"Using dbg prompt {self.dbg}")
26 | return [self.dbg]
27 | prompt_list = []
28 | if self.original_flag:
29 | if self.revised_original_flag and 'revised_original' in PROMPT_DICT[task_name]:
30 | prompt_list += PROMPT_DICT[task_name]['revised_original']
31 | else:
32 | prompt_list += PROMPT_DICT[task_name]['original']
33 | if self.revised_original_flag and 'revised_original' in PROMPT_DICT[dataset_name]:
34 | prompt_list += PROMPT_DICT[dataset_name]['revised_original']
35 | else:
36 | prompt_list += PROMPT_DICT[dataset_name]['original']
37 | if self.manual_flag:
38 | if 'manual' in PROMPT_DICT[task_name]:
39 | prompt_list += PROMPT_DICT[task_name]['manual']
40 | if 'manual' in PROMPT_DICT[dataset_name]:
41 | prompt_list += PROMPT_DICT[dataset_name]['manual']
42 | if self.gpt3_flag:
43 | if 'gpt3' in PROMPT_DICT[task_name]:
44 | prompt_list += PROMPT_DICT[task_name]['gpt3']
45 |
46 | if 'gpt3' in PROMPT_DICT[dataset_name]:
47 | prompt_list += PROMPT_DICT[dataset_name]['gpt3']
48 | if not prompt_list:
49 | raise ValueError(f"No prompts for {task_name}/{dataset_name}")
50 | if self.single_prompt:
51 | logging.info(f"Using prompt \"{prompt_list[0]}\" for {task_name} {dataset_name}")
52 | return prompt_list[:1]
53 | return prompt_list
54 |
55 |
--------------------------------------------------------------------------------
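A short sketch of the selection logic with a toy PROMPT_DICT (the real entries live in prompt_dict.py; the keys below are made up). The flags control which prompt groups get pooled together:

from unittest import mock
from t5x.examples.unified_io.data import prompt_definition

toy = {
    "vqa": {"original": ["{question}"], "gpt3": ["Answer this: {question}"]},
    "vqa_coco_2017": {"original": ["Q: {question} A:"], "manual": ["{question} Short answer:"]},
}
with mock.patch.object(prompt_definition, "PROMPT_DICT", toy):
    prompts = prompt_definition.Prompt().get_prompt_list("vqa", "vqa_coco_2017")
# pools the original, manual, and gpt3 prompts of both the task and the dataset;
# with single_prompt=True only the first pooled prompt would be returned
print(prompts)
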
/t5x/examples/unified_io/data/tasks.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import seqio
4 | from seqio import TaskRegistry
5 |
6 | from t5x.examples.unified_io.metrics import metrics
7 | from t5x.examples.unified_io.data.postprocessing import return_meta, return_field
8 | from t5x.examples.unified_io.metrics.metrics import exact_match
9 | from t5x.examples.unified_io.modality_processing import unified_io_preprocessor
10 |
11 | from t5x.examples.unified_io import config, modality_processing
12 | from t5x.examples.unified_io.data import preprocessing
13 | from t5x.examples.unified_io.data.preprocessing import rekey
14 |
15 |
16 | def add_refexp(name, src_name=None):
17 | if src_name is None:
18 | src_name = name
19 | TaskRegistry.add(
20 | name,
21 | # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
22 | source=seqio.TfdsDataSource(
23 | tfds_name=f"ref_coco/{src_name}:1.0.0",
24 | tfds_data_dir=config.MULTITASK_TFDS_DATA_DIR,
25 | ),
26 | preprocessors=[
27 | functools.partial(
28 | rekey, key_map={
29 | "image": ["image"],
30 | "bbox": ["objects", "bbox"],
31 | "label": ["objects", "refexp", "raw"],
32 | "refexp_id": ["objects", "refexp", "refexp_id"],
33 | }),
34 | functools.partial(
35 | preprocessing.refer_expression_preprocessor,
36 | dataset_name=name,
37 | ),
38 | unified_io_preprocessor,
39 | ],
40 | postprocess_fn=return_meta,
41 | metric_fns=[metrics.ref_exp_metric],
42 | output_features=modality_processing.OUTPUT_FEATURES,
43 | )
44 |
45 |
46 | add_refexp("refcoco_unc")
47 | add_refexp("refcocog_google")
48 | add_refexp("refcoco_plus_unc", "refcocoplus_unc")
49 |
50 |
51 | TaskRegistry.add(
52 | "image_generation_coco_2017",
53 | source=seqio.TfdsDataSource(
54 | tfds_name="coco_all:1.0.1",
55 | tfds_data_dir=config.MULTITASK_TFDS_DATA_DIR,
56 | ),
57 | preprocessors=[
58 | functools.partial(
59 | rekey, key_map={
60 | "image/filename": ["image/filename"],
61 | "image": ["image"],
62 | "captions": ["captions", "text"]
63 | }),
64 | functools.partial(
65 | preprocessing.image_generation_preprocessor,
66 | dataset_name="image_generation_coco_2017",
67 | ),
68 | unified_io_preprocessor,
69 | ],
70 | output_features=modality_processing.OUTPUT_FEATURES,
71 | )
72 |
73 |
74 | TaskRegistry.add(
75 | "image_caption_coco_2017",
76 | source=seqio.TfdsDataSource(
77 | tfds_name="coco_all:1.0.1",
78 | tfds_data_dir=config.MULTITASK_TFDS_DATA_DIR,
79 | ),
80 | preprocessors=[
81 | functools.partial(
82 | rekey, key_map={
83 | "image/filename": ["image/filename"],
84 | "image": ["image"],
85 | "captions": ["captions", "text"]
86 | }),
87 | functools.partial(
88 | preprocessing.image_caption_preprocessor,
89 | dataset_name="image_caption_coco_2017",
90 | ),
91 | unified_io_preprocessor
92 | ],
93 | output_features=modality_processing.OUTPUT_FEATURES,
94 | )
95 |
96 |
97 | TaskRegistry.add(
98 | "image_inpainting_coco",
99 | source=seqio.TfdsDataSource(
100 | tfds_name="coco_all:1.0.1",
101 | tfds_data_dir=config.MULTITASK_TFDS_DATA_DIR,
102 | ),
103 | preprocessors=[
104 | functools.partial(
105 | rekey, key_map={
106 | "image": ["image"],
107 | "bbox": ["objects", "bbox"],
108 | "label": ["objects", "label"],
109 | }),
110 | functools.partial(
111 | preprocessing.image_inpainting_preprocessor,
112 | dataset_name="image_inpainting_coco",
113 | class_names='metadata/coco/coco_class_name_2017.json',
114 | ),
115 | unified_io_preprocessor
116 | ],
117 | output_features=modality_processing.OUTPUT_FEATURES,
118 | )
119 |
120 |
121 | TaskRegistry.add(
122 | "vqa_coco_2017",
123 | # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
124 | source=seqio.TfdsDataSource(
125 | tfds_name="coco_all:1.0.1",
126 | tfds_data_dir=config.MULTITASK_TFDS_DATA_DIR,
127 | ),
128 | preprocessors=[
129 | functools.partial(
130 | rekey, key_map={
131 | "image": ["image"],
132 | "text_inputs": ["vqa", "questions"],
133 | "text_targets": ["vqa", "answers"],
134 | }),
135 | preprocessing.vqa_preprocessor,
136 | unified_io_preprocessor
137 | ],
138 | postprocess_fn=functools.partial(return_field, field="meta/all_references"),
139 | metric_fns=[metrics.vqa_metric],
140 | output_features=modality_processing.OUTPUT_FEATURES,
141 | )
142 |
143 |
144 | TaskRegistry.add(
145 | "box_classification_coco_2017",
146 | # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
147 | source=seqio.TfdsDataSource(
148 | tfds_name="coco_all:1.0.1",
149 | tfds_data_dir=config.MULTITASK_TFDS_DATA_DIR,
150 | ),
151 | preprocessors=[
152 | functools.partial(
153 | rekey, key_map={
154 | "image": ["image"],
155 | "bbox": ["objects", "bbox"],
156 | "label": ["objects", "label"],
157 | "image_id": ["image/filename"],
158 | }),
159 | functools.partial(
160 | preprocessing.box_classification_preprocessor,
161 | dataset_name='box_classification_coco_2017',
162 | class_names='metadata/coco/coco_class_name_2017.json',
163 | ),
164 | unified_io_preprocessor
165 | ],
166 | postprocess_fn=functools.partial(return_field, field="meta/label"),
167 | metric_fns=[exact_match],
168 | output_features=modality_processing.OUTPUT_FEATURES,
169 | )
--------------------------------------------------------------------------------
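After these registrations run, the tasks are ordinary SeqIO tasks; a hedged sketch of pulling a few examples (it assumes the ref_coco TFDS dataset has been built under config.MULTITASK_TFDS_DATA_DIR):

import seqio
import t5x.examples.unified_io.data.tasks  # noqa: F401 -- runs the TaskRegistry.add calls

task = seqio.get_mixture_or_task("refcoco_unc")
ds = task.get_dataset(
    sequence_length={"text_inputs": 256, "text_targets": 32},
    split="validation",
    shuffle=False,
)
for ex in ds.take(1):
    print(sorted(ex.keys()))
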
/t5x/examples/unified_io/evaluator.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from dataclasses import dataclass
3 | from typing import Optional, Mapping, Sequence, Any, List, Dict
4 |
5 | import gin
6 | import seqio
7 | from absl import logging
8 | from seqio import metrics as metrics_lib
9 |
10 | from t5x.examples.unified_io.data.data_utils import get_default_vocabulary
11 |
12 | AllOutputTokensType = Mapping[str, Sequence[Sequence[int]]]
13 | AllOutputScoresType = Mapping[str, Sequence[float]]
14 | AllOutputAuxValuesType = Mapping[str, Mapping[str, Sequence[Any]]]
15 | AllMetricsType = Mapping[str, Mapping[str, Any]]
16 |
17 |
18 | @dataclass
19 | class UnifiedIOOutput:
20 | """Wrapper to make it easier to work with the many different outputs of UIO2"""
21 | aux_values: Dict
22 |
23 | @property
24 | def text(self):
25 | return self.aux_values.get("text")
26 |
27 | @property
28 | def text_tokens(self):
29 | return self.aux_values.get("text-tokens")
30 |
31 | @property
32 | def image_tokens(self):
33 | return self.aux_values.get("image-tokens")
34 |
35 | @property
36 | def image(self):
37 | return self.aux_values.get("image")
38 |
39 | @property
40 | def audio(self):
41 | return self.aux_values.get("audio")
42 |
43 | @property
44 | def scores(self):
45 | if "scores" in self.aux_values:
46 | return self.aux_values["scores"]
47 | else:
48 |       # Assume only one modality is present
49 | for x in ["text", "image", "audio"]:
50 | x = f"{x}-scores"
51 | if x in self.aux_values:
52 | return self.aux_values[x]
53 | raise ValueError(f"No scores found in self.aux_values, keys={self.aux_values.keys()}")
54 |
55 |
56 | def build_uio_outputs(aux_values, vocab) -> List[UnifiedIOOutput]:
57 | out = []
58 | n = len(next(iter(aux_values.values())))
59 | for ix in range(n):
60 | values = {k: v[ix] for k, v in aux_values.items()}
61 | txt_tokens = values.get("text-tokens")
62 | if txt_tokens is None:
63 | pass
64 | elif len(txt_tokens.shape) == 1:
65 | values["text"] = vocab.decode(txt_tokens)
66 | else:
67 | values["text"] = [vocab.decode(x) for x in txt_tokens]
68 | out.append(UnifiedIOOutput(values))
69 | return out
70 |
71 |
72 | @gin.configurable()
73 | class UnifiedIOEvaluator(seqio.Evaluator):
74 | """Evaluator for UnifiedIO 2"""
75 | # This class basically follows `seqio.Evaluator` but has a few UIO2 hacks
76 |
77 | def __init__(
78 | self,
79 | mixture_or_task_name: str,
80 | feature_converter,
81 | eval_split: str = "validation",
82 | use_cached: bool = False,
83 | seed: Optional[int] = 42,
84 | sequence_length: Optional[Mapping[str, int]] = None,
85 | num_examples: Optional[int] = None,
86 | shuffle: bool = False,
87 | logger_cls: Sequence = (),
88 | log_dir: Optional[str] = None,
89 | use_memory_cache: bool = True,
90 | target_field_name: str = "targets",
91 | ):
92 |     # We use a simplified `sequence_length` whose keys do not exactly match the
93 |     # dataset fields. The evaluator deletes non-matching entries before the
94 |     # feature conversion stage, so we alias the lengths to key names that do
95 |     # match the dataset structure, ensuring they are preserved.
96 | if "text_inputs" in sequence_length:
97 | sequence_length["inputs/text/tokens"] = sequence_length["text_inputs"]
98 | if "text_targets" in sequence_length:
99 | sequence_length["targets/text/tokens"] = sequence_length["text_targets"]
100 | super().__init__(
101 | mixture_or_task_name, feature_converter, eval_split, use_cached, seed, sequence_length,
102 | num_examples, shuffle, logger_cls, log_dir, use_memory_cache, target_field_name
103 | )
104 |
105 | def _compute_metrics(self,
106 | predicted_tokens: AllOutputTokensType,
107 | scores: AllOutputScoresType,
108 | all_aux_values: AllOutputAuxValuesType,
109 | step: Optional[int] = None) -> AllMetricsType:
110 |
111 | vocab = get_default_vocabulary()
112 | all_metrics = {}
113 | for task in self.eval_tasks:
114 | logging.info("Computing metrics for %s", task.name)
115 | task_dataset = self.cached_task_datasets[task.name]
116 | targets = self.cached_targets[task.name]
117 | task_metrics = []
118 | inferences = {}
119 |
120 | if task.predict_metric_fns or task.predict_with_aux_metric_fns:
121 | (outputs,
122 | postprocessed_outputs) = self._decode_and_postprocess_predictions(
123 | task, predicted_tokens, task_dataset, targets)
124 | inferences["output"] = outputs
125 | inferences["prediction"] = postprocessed_outputs
126 |
127 | if task.predict_metric_fns:
128 | task_metrics.extend([
129 | metric_fn(targets, inferences["prediction"])
130 | for metric_fn in task.predict_metric_fns
131 | ])
132 |
133 | if task.predict_with_aux_metric_fns:
134 | aux_values = all_aux_values[task.name]
135 | uio_output = build_uio_outputs(aux_values, vocab)
136 | task_metrics.extend([
137 | metric_fn(targets, uio_output, aux_values)
138 | for metric_fn in task.predict_with_aux_metric_fns
139 | ])
140 | inferences["aux_value"] = aux_values
141 |
142 | if task.score_metric_fns:
143 | task_scores = scores[task.name]
144 | if len(targets) != len(task_scores):
145 | raise ValueError(f"len(targets)({len(targets)}) != "
146 | f"len(task_scores)({len(task_scores)})")
147 | task_metrics.extend([
148 | metric_fn(targets, task_scores)
149 | for metric_fn in task.score_metric_fns
150 | ])
151 | inferences["score"] = task_scores
152 |
153 | all_metrics[task.name] = {}
154 | for k, v in itertools.chain(*[m.items() for m in task_metrics]):
155 | if k in all_metrics[task.name]:
156 | raise ValueError(f"Duplicate metric key '{k}' in Task '{task.name}'.")
157 | all_metrics[task.name][k] = v
158 |
159 | metrics = {
160 | k: metrics_lib.Scalar(v)
161 | if not isinstance(v, metrics_lib.MetricValue) else v
162 | for k, v in all_metrics[task.name].items()
163 | }
164 | for logger in self.loggers:
165 | logger(task_name=task.name, step=step, metrics=metrics,
166 | dataset=task_dataset, inferences=inferences, targets=targets)
167 |
168 | return all_metrics
169 |
--------------------------------------------------------------------------------
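`build_uio_outputs` only needs a mapping of per-example aux values and an object with a `decode` method, so it can be exercised without a real checkpoint; a minimal sketch with a stub vocabulary (assuming seqio and the repo's other dependencies are installed):

import numpy as np
from t5x.examples.unified_io.evaluator import build_uio_outputs

class StubVocab:
    def decode(self, tokens):
        return " ".join(str(int(t)) for t in tokens if t > 0)

aux_values = {
    "text-tokens": [np.array([5, 9, 1, 0])],   # one example, 1-D token array
    "text-scores": [np.float32(-1.25)],
}
outs = build_uio_outputs(aux_values, StubVocab())
print(outs[0].text)    # decoded from "text-tokens"
print(outs[0].scores)  # falls back to the per-modality "text-scores" entry
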
/t5x/examples/unified_io/metrics/grit_keypoint.py:
--------------------------------------------------------------------------------
1 | # The object keypoint similarity (OKS) is computed and averaged within each image
2 | import numpy as np
3 | from scipy.optimize import linear_sum_assignment
4 |
5 | N_KEYPOINTS = 17
6 | N_DIM = 3
7 |
8 |
9 | def get_bbox_from_kp(kp):
10 | k_array3d = np.reshape(np.array(kp),(N_KEYPOINTS,N_DIM))
11 | kp = k_array3d[:,:2]
12 | k_vis = k_array3d[:,2]
13 | kp_only_labeled = kp[k_vis > 0]
14 | if len(kp_only_labeled) == 0:
15 | raise ValueError("All points are marked as not visible!")
16 | x_min = kp_only_labeled[:,0].min()
17 | y_min = kp_only_labeled[:,1].min()
18 | x_max = kp_only_labeled[:,0].max()
19 | y_max = kp_only_labeled[:,1].max()
20 | bbox = np.array([x_min, y_min, x_max, y_max])
21 | return bbox
22 |
23 |
24 | def computeOks(dts, gts):
25 | """
26 | analogous to computing IoUs for localization / segmentation
27 | """
28 | ious = np.zeros((len(dts), len(gts)))
29 | sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,.62, 1.07, 1.07, .87, .87, .89, .89])/10.0
30 | variances = (sigmas * 2)**2
31 |
32 | # compute oks between each detection and ground truth object
33 | for j, gt in enumerate(gts):
34 | g = np.array(gt)
35 | xg = g[0::3]; yg = g[1::3]; vg = g[2::3]
36 | x_min, y_min, x_max, y_max = get_bbox_from_kp(gt)
37 | area = (y_max-y_min)*(x_max-x_min)
38 | for i, dt in enumerate(dts):
39 | d = np.array(dt)
40 | xd = d[0::3]; yd = d[1::3]
41 | dx = xd - xg
42 | dy = yd - yg
43 | e = (dx**2 + dy**2) / variances / (area+np.spacing(1)) / 2
44 | e = e[vg > 0]
45 | ious[i, j] = np.sum(np.exp(-e)) / e.shape[0]
46 | return ious
47 |
48 |
49 | def assign_instances(pred_points: list, gt_points: list):
50 | ious = computeOks(pred_points, gt_points)
51 | cost = -ious
52 |
53 | # solve assignment
54 | pred_ids, gt_ids = linear_sum_assignment(cost)
55 | pair_ids = list(zip(pred_ids, gt_ids))
56 |
57 | # select assignments with iou > 0
58 | pair_ids = [(i,j) for i,j in pair_ids if ious[i,j] > 0]
59 | pairs = [(pred_points[i],gt_points[j]) for i,j in pair_ids]
60 | pair_ious = [ious[i,j] for i,j in pair_ids]
61 |
62 | return pairs, pair_ious, pair_ids
63 |
64 |
65 | def kp_metric(pred_points: list, gt_points: list, return_pairs=False) -> float:
66 | num_pred = len(pred_points)
67 | num_gt = len(gt_points)
68 | if num_pred == 0 and num_gt == 0:
69 | return 1 if not return_pairs else (1, [], [])
70 | elif min(num_pred,num_gt) == 0 and max(num_pred,num_gt) > 0:
71 | return 0 if not return_pairs else (0, [], [])
72 |
73 | pairs, pair_ious, pair_ids = assign_instances(pred_points, gt_points)
74 | num_detected = len(pairs)
75 | num_missed = num_gt - num_detected
76 | score = np.sum(pair_ious) / (num_pred + num_missed)
77 |
78 | return score if not return_pairs else (score, pairs, pair_ids)
79 |
80 |
81 | def kp_metric_wrapper(prediction, ground_truth, return_pairs=False):
82 | if 'output' in prediction:
83 | prediction = prediction['output']
84 | assert ground_truth['output']['example_id'] == prediction['example_id']
85 | pred_points = prediction['points']
86 | gt_points = ground_truth['output']['points']
87 | return kp_metric(pred_points, gt_points, return_pairs)
--------------------------------------------------------------------------------
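A quick sanity check of the metric (pure NumPy/SciPy, so it runs standalone given the repo on PYTHONPATH): a prediction identical to the ground truth yields e = 0 for every visible keypoint, hence OKS = 1 and a perfect score:

from t5x.examples.unified_io.metrics.grit_keypoint import kp_metric

# one person, 17 keypoints as flat [x, y, visibility] triplets, all visible
gt = [[float(v) for k in range(17) for v in (10 * k, 5 * k, 2)]]
assert kp_metric(gt, gt) == 1.0
assert kp_metric([], gt) == 0   # missed instance
assert kp_metric([], []) == 1   # vacuously perfect
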
/t5x/examples/unified_io/metrics/grit_localization.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | from scipy.optimize import linear_sum_assignment
5 |
6 |
7 | def compute_iou(bbox1: list, bbox2: list, verbose: bool=False):
8 | x1, y1, x2, y2 = bbox1
9 | x1_, y1_, x2_, y2_ = bbox2
10 |
11 | x1_in = max(x1, x1_)
12 | y1_in = max(y1, y1_)
13 | x2_in = min(x2, x2_)
14 | y2_in = min(y2, y2_)
15 |
16 | intersection = compute_area(bbox=[x1_in, y1_in, x2_in, y2_in], invalid=0.0)
17 | area1 = compute_area(bbox1, invalid=0)
18 | area2 = compute_area(bbox2, invalid=0)
19 | union = area1 + area2 - intersection
20 | iou = intersection / (union + 1e-6)
21 |
22 | if verbose:
23 | return iou, intersection, union
24 |
25 | return iou
26 |
27 |
28 | def compute_area(bbox: list, invalid: float=None) -> float:
29 | x1, y1, x2, y2 = bbox
30 |
31 | if (x2 <= x1) or (y2 <= y1):
32 | area = invalid
33 | else:
34 | area = (x2 - x1) * (y2 - y1)
35 |
36 | return area
37 |
38 |
39 | def assign_boxes(pred_boxes: List[List], gt_boxes: List[List]):
40 | n1 = len(pred_boxes)
41 | n2 = len(gt_boxes)
42 | cost = np.zeros([n1,n2])
43 | ious = np.zeros([n1,n2])
44 | for i,bbox1 in enumerate(pred_boxes):
45 | for j,bbox2 in enumerate(gt_boxes):
46 | iou = compute_iou(bbox1,bbox2)
47 | ious[i,j] = iou
48 | cost[i,j] = 1-iou
49 |
50 | # solve assignment
51 | pred_box_ids, gt_box_ids = linear_sum_assignment(cost)
52 | pair_ids = list(zip(pred_box_ids, gt_box_ids))
53 |
54 | # select assignments with iou > 0
55 | pair_ids = [(i,j) for i,j in pair_ids if ious[i,j] > 0]
56 | pairs = [(pred_boxes[i],gt_boxes[j]) for i,j in pair_ids]
57 | pair_ious = [ious[i,j] for i,j in pair_ids]
58 |
59 | return pairs, pair_ious, pair_ids
60 |
61 |
62 | def loc_metric(pred_boxes: List[List], gt_boxes: List[List]) -> float:
63 | num_pred = len(pred_boxes)
64 | num_gt = len(gt_boxes)
65 | if num_pred == 0 and num_gt == 0:
66 | return 1
67 | elif min(num_pred,num_gt) == 0 and max(num_pred,num_gt) > 0:
68 | return 0
69 |
70 | pairs, pair_ious, pair_ids = assign_boxes(pred_boxes,gt_boxes)
71 | num_detected = len(pairs)
72 | num_missed = num_gt - num_detected
73 | return np.sum(pair_ious) / (num_pred + num_missed)
--------------------------------------------------------------------------------
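The IoU arithmetic here is easy to verify by hand; for boxes in [x1, y1, x2, y2] form, a box overlapping half of an identical-area box gives IoU = 50 / (100 + 100 - 50) ≈ 0.333:

from t5x.examples.unified_io.metrics.grit_localization import compute_iou, loc_metric

a = [0, 0, 10, 10]          # area 100
b = [0, 5, 10, 15]          # area 100, overlaps the lower half of `a`
print(compute_iou(a, b))    # ≈ 0.3333 (50 / 150, up to the 1e-6 smoothing term)
print(loc_metric([a], [a])) # ≈ 1.0: perfect one-to-one assignment
print(loc_metric([], [a]))  # 0: a missed ground-truth box
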
/t5x/examples/unified_io/metrics/grit_segmentation.py:
--------------------------------------------------------------------------------
1 | from scipy.optimize import linear_sum_assignment
2 | import numpy as np
3 | from scipy import ndimage
4 |
5 |
6 | # https://github.com/bowenc0221/boundary-iou-api/blob/master/boundary_iou/utils/boundary_utils.py
7 | # General util function to get the boundary of a binary mask.
8 | def mask_to_boundary(mask, dilation_ratio=0.02):
9 | """
10 | Convert binary mask to boundary mask.
11 | :param mask (numpy array, uint8): binary mask
12 | :param dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal
13 | :return: boundary mask (numpy array)
14 | """
15 | h, w = mask.shape
16 | img_diag = np.sqrt(h ** 2 + w ** 2)
17 | dilation = int(round(dilation_ratio * img_diag))
18 | if dilation < 1:
19 | dilation = 1
20 | # Pad image so mask truncated by the image border is also considered as boundary.
21 | new_mask = np.pad(mask, [[1, 1], [1, 1]], constant_values=0)
22 | kernel = np.ones((3, 3), dtype=np.uint8)
23 | new_mask_erode = ndimage.binary_erosion(new_mask, kernel, iterations=dilation)
24 | mask_erode = new_mask_erode[1 : h + 1, 1 : w + 1]
25 | # G_d intersects G in the paper.
26 | return np.logical_xor(mask, mask_erode)
27 |
28 |
29 | def compute_iou(mask1, mask2, verbose=False):
30 | # resize predicted mask to be the same size as gt mask
31 | if mask1.shape != mask2.shape:
32 | raise NotImplementedError()
33 |
34 | intersection = np.sum(np.logical_and(mask1, mask2))
35 | union = np.sum(np.logical_or(mask1, mask2))
36 |
37 | iou = intersection / (union + 1e-6)
38 |
39 | if verbose:
40 | return iou, intersection, union
41 |
42 | return iou
43 |
44 |
45 | def assign_segmentations(pred_masks: list, gt_masks: list):
46 | n1 = len(pred_masks)
47 | n2 = len(gt_masks)
48 | cost = np.zeros([n1,n2])
49 | ious = np.zeros([n1,n2])
50 | for i,mask1 in enumerate(pred_masks):
51 | for j,mask2 in enumerate(gt_masks):
52 | iou = compute_iou(mask1,mask2)
53 | ious[i,j] = iou
54 | cost[i,j] = -iou
55 |
56 | # solve assignment
57 | pred_mask_ids, gt_mask_ids = linear_sum_assignment(cost)
58 | pair_ids = list(zip(pred_mask_ids, gt_mask_ids))
59 |
60 | # select assignments with iou > 0
61 | pair_ids = [(i,j) for i,j in pair_ids if ious[i,j] > 0]
62 | pairs = [(pred_masks[i],gt_masks[j]) for i,j in pair_ids]
63 | pair_ious = [ious[i,j] for i,j in pair_ids]
64 |
65 | return pairs, pair_ious, pair_ids
66 |
67 |
68 | # expects numpy arrays, could change to RLEs and use rle iou / merge function above
69 | def seg_metric(pred_masks: list, gt_masks: list, stuff: bool, return_pairs=False) -> float:
70 | """
71 | pred_masks: list of numpy arrays representing binary masks
72 | gt_masks: list of numpy arrays representing binary masks
73 | stuff: boolean for evaluation type (False for "Thing")
74 |   return_pairs: also return assignment pairs between prediction and gt instances
75 | """
76 |   if stuff: # merge masks into a single mask
77 | pred_masks = [np.logical_or.reduce(pred_masks)] if pred_masks else []
78 | gt_masks = [np.logical_or.reduce(gt_masks)]
79 |
80 | num_pred = len(pred_masks)
81 | num_gt = len(gt_masks)
82 | if num_pred == 0 and num_gt == 0:
83 | return 1 if not return_pairs else (1, [], [])
84 | elif min(num_pred,num_gt) == 0 and max(num_pred,num_gt) > 0:
85 | return 0 if not return_pairs else (0, [], [])
86 |
87 | # uses BoundaryIoU
88 | # TODO double check mask_to_boundary is officially used, its not clear from the grit codebase
89 | pred_boundaries = [mask_to_boundary(m) for m in pred_masks]
90 | gt_boundaries = [mask_to_boundary(m) for m in gt_masks]
91 |
92 | pairs, pair_ious, pair_ids = assign_segmentations(pred_boundaries,gt_boundaries)
93 | num_detected = len(pairs)
94 | num_missed = num_gt - num_detected
95 | score = np.sum(pair_ious) / (num_pred + num_missed)
96 |
97 | return score if not return_pairs else (score, pairs, pair_ids)
98 |
--------------------------------------------------------------------------------
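A standalone check of the boundary-IoU pipeline (NumPy + SciPy only): identical masks produce identical boundaries and therefore a perfect score, while a missed instance scores 0:

import numpy as np
from t5x.examples.unified_io.metrics.grit_segmentation import seg_metric

mask = np.zeros((32, 32), dtype=np.uint8)
mask[8:24, 8:24] = 1

print(seg_metric([mask], [mask], stuff=False))  # ≈ 1.0
print(seg_metric([], [mask], stuff=False))      # 0
print(seg_metric([], [], stuff=False))          # 1 (vacuously perfect)
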
/t5x/examples/unified_io/metrics/grit_vqa.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | contractions = {
4 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \
5 | "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \
6 | "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \
7 | "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \
8 | "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", \
9 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \
10 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \
11 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \
12 | "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \
13 | "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \
14 | "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \
15 | "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \
16 | "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", \
17 | "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \
18 | "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \
19 | "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \
20 | "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \
21 | "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \
22 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \
23 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \
24 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \
25 | "youll": "you'll", "youre": "you're", "youve": "you've"}
26 |
27 | manualMap = {
28 | 'none': '0',
29 | 'zero': '0',
30 | 'one': '1',
31 | 'two': '2',
32 | 'three': '3',
33 | 'four': '4',
34 | 'five': '5',
35 | 'six': '6',
36 | 'seven': '7',
37 | 'eight': '8',
38 | 'nine': '9',
39 | 'ten': '10'}
40 |
41 | articles = ['a','an','the']
42 |
43 | punct = [
44 | ';', r"/", '[', ']', '"', '{', '}',
45 | '(', ')', '=', '+', '\\', '_', '-',
46 | '>', '<', '@', '`', ',', '?', '!']
47 |
48 | periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
49 | commaStrip = re.compile(r"(\d)(\,)(\d)")
50 |
51 |
52 | def processPunctuation(inText):
53 | outText = inText
54 | for p in punct:
55 | if (p + ' ' in inText or ' ' + p in inText) or (re.search(commaStrip, inText) != None):
56 | outText = outText.replace(p, '')
57 | else:
58 | outText = outText.replace(p, ' ')
59 | outText = periodStrip.sub("",outText,re.UNICODE)
60 | return outText
61 |
62 |
63 | def processDigitArticle(inText):
64 | outText = []
65 | tempText = inText.lower().split()
66 | for word in tempText:
67 | word = manualMap.setdefault(word, word)
68 | if word not in articles:
69 | outText.append(word)
70 | else:
71 | pass
72 | for wordId, word in enumerate(outText):
73 | if word in contractions:
74 | outText[wordId] = contractions[word]
75 | outText = ' '.join(outText)
76 | return outText
77 |
78 |
79 | def preprocess_answer(ans, cache={}):
80 | """GRIT VQA pre-processing"""
81 | if ans in cache:
82 | return cache[ans]
83 | ans = ans.replace('\n', ' ')
84 | ans = ans.replace('\t',' ')
85 | ans = ans.lower().strip()
86 | preprocessed = processDigitArticle(processPunctuation(ans))
87 | cache[ans] = preprocessed
88 | return preprocessed
89 |
90 |
--------------------------------------------------------------------------------
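The normalization is deterministic string munging, so a couple of examples show the intent: punctuation is stripped or spaced out, articles are dropped, number words are mapped to digits, and contractions are restored:

from t5x.examples.unified_io.metrics.grit_vqa import preprocess_answer

print(preprocess_answer("Three!"))       # -> "3"
print(preprocess_answer("a red couch"))  # -> "red couch"
print(preprocess_answer("dont know"))    # -> "don't know"
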
/t5x/examples/unified_io/metrics/metrics.py:
--------------------------------------------------------------------------------
1 | """Metrics for UiO2 tasks"""
2 | from collections import Counter
3 |
4 | from absl import logging
5 | from typing import List, Sequence
6 |
7 | import numpy as np
8 | from seqio.metrics import Scalar, Text
9 |
10 | from t5x.examples.unified_io.metrics.utils import extract_coordinates_from_text, \
11 | undo_box_preprocessing
12 |
13 | from t5x.examples.unified_io.evaluator import UnifiedIOOutput
14 |
15 | from t5x.examples.unified_io.metrics.grit_localization import compute_iou
16 | from t5x.examples.unified_io.metrics.grit_vqa import preprocess_answer as vqa_preprocessing
17 | from t5x.examples.unified_io import config
18 |
19 |
20 | def exact_match(targets, predictions: List[UnifiedIOOutput],
21 | aux_values, print_examples=True):
22 | if isinstance(targets[0], np.ndarray):
23 | # Multiple correct answers stored in a numpy object array
24 | matches = [pred.text.lower() in [x.decode("utf-8").lower() for x in target] for target, pred in zip(targets, predictions)]
25 | else:
26 | if isinstance(targets[0], bytes):
27 | targets = [x.decode("utf-8") for x in targets]
28 | matches = [pred.text.lower() == target.lower() for target, pred in zip(targets, predictions)]
29 | if print_examples:
30 | ixs = np.random.choice(len(targets), min(20, len(targets)), replace=False)
31 | examples = [f"pred={predictions[i].text} gt={targets[i]}" for i in ixs]
32 | for ex in examples:
33 | logging.info(ex)
34 | return {
35 | "score": np.mean(matches),
36 | }
37 |
38 |
39 | def vqa_score(target, pred):
40 | pred = vqa_preprocessing(pred)
41 | if isinstance(target, list):
42 | target = Counter(vqa_preprocessing(x) for x in target)
43 | return min(target[pred] / 3.0, 1)
44 | else:
45 | return float(vqa_preprocessing(pred) == vqa_preprocessing(target))
46 |
47 |
48 | def vqa_metric(targets: Sequence, predictions: Sequence[UnifiedIOOutput], aux_values):
49 | if isinstance(targets[0], np.ndarray):
50 | targets = [[ans.decode("utf-8") for ans in answer_set] for answer_set in targets]
51 | else:
52 | targets = [answer.decode("utf-8") for answer in targets]
53 | score = np.mean([vqa_score(t, p.text) for t, p in zip(targets, predictions)])
54 | n_targets = len(targets)
55 | ixs = np.random.choice(n_targets, min(n_targets, 20), replace=False)
56 | examples = [f"{predictions[i].text.lower()} (gt(s)={', '.join(targets[i])})" for i in ixs]
57 | return {
58 | "score": Scalar(score),
59 | "examples": Text(", ".join(examples))
60 | }
61 |
62 |
63 | def ref_exp_metric(targets, predictions: List[UnifiedIOOutput],
64 | aux_values, original_scale=True):
65 | total_acc = 0
66 | total_iou = 0
67 | for target, pred in zip(targets, predictions):
68 | gt_boxes, image_info, src_boxes = target["boxes"], target["image_info"], target["src_boxes"]
69 | if len(gt_boxes) != 1:
70 | raise ValueError("Should always be one ground truth box")
71 |
72 | p_boxes, classes = extract_coordinates_from_text(
73 | pred.text, image_size=config.IMAGE_INPUT_SIZE, n_coordinates=4, use_label=False)
74 | if original_scale:
75 | p_boxes = undo_box_preprocessing(p_boxes, image_info)
76 | h, w = image_info[3:5]
77 | gt_boxes = src_boxes * np.array([h, w, h, w]).reshape(1, 4)
78 | if len(p_boxes) == 0:
79 | iou = 0
80 | else:
81 | iou = compute_iou(p_boxes[0], gt_boxes[0])
82 | total_iou += iou
83 | total_acc += float((iou > 0.5))
84 |
85 | n = len(predictions)
86 | return dict(acc=total_acc/n, iou=total_iou/n)
87 |
88 |
89 | def null_metric(targets, predictions, aux_values):
90 | return {}
91 |
--------------------------------------------------------------------------------
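The VQA scoring above follows the standard "min(#matching annotators / 3, 1)" rule; a self-contained restatement of just that rule (the real `vqa_score` additionally applies the GRIT answer normalization to both sides):

from collections import Counter

def vqa_accuracy(pred: str, answers: list) -> float:
    # an answer is fully credited once at least 3 annotators agree with it
    counts = Counter(answers)
    return min(counts[pred] / 3.0, 1.0)

print(vqa_accuracy("yes", ["yes", "yes", "no", "yes", "no"]))  # 1.0 (3 matches)
print(vqa_accuracy("no", ["yes", "yes", "no", "yes", "no"]))   # ≈ 0.67 (2 matches)
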
/t5x/examples/unified_io/metrics/rle.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pycocotools import mask as maskUtils
6 |
7 |
8 | def uncompressed_encode(binary_mask):
9 | binary_mask = np.asfortranarray(binary_mask)
10 | uncompressed_rle = {'counts': [], 'size': list(binary_mask.shape)}
11 | counts = uncompressed_rle.get('counts')
12 |
13 | last_elem = 0
14 | running_length = 0
15 |
16 | for i, elem in enumerate(binary_mask.ravel(order='F')):
17 | if elem == last_elem:
18 | pass
19 | else:
20 | counts.append(running_length)
21 | running_length = 0
22 | last_elem = elem
23 | running_length += 1
24 |
25 | counts.append(running_length)
26 |
27 | return uncompressed_rle
28 |
29 |
30 | def compress(uncompressed_rle):
31 |   compressed_rle = maskUtils.frPyObjects(uncompressed_rle, uncompressed_rle.get('size')[0], uncompressed_rle.get('size')[1])
32 | return compressed_rle
33 |
34 |
35 | def to_utf(rle):
36 | rle = rle.copy()
37 | rle['counts'] = rle['counts'].decode("utf-8", "backslashreplace")
38 | return rle
39 |
40 |
41 | def from_utf(rle):
42 | rle = rle.copy()
43 | rle['counts'] = rle['counts'].encode("utf-8")
44 | return rle
45 |
46 |
47 | def encode(binary_mask, utf=True):
48 | encoded = maskUtils.encode(np.asfortranarray(binary_mask))
49 | return to_utf(encoded) if utf else encoded
50 |
51 |
52 | def decode(rle, utf=True):
53 | if type(rle) == list:
54 | return [decode(r, utf) for r in rle]
55 | else:
56 | rle = from_utf(rle) if utf else rle
57 | decoded = maskUtils.decode(rle)
58 | return decoded
59 |
--------------------------------------------------------------------------------
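A round-trip sketch (requires pycocotools): encoding a binary mask and decoding it back should reproduce the mask, with the UTF-8 `counts` variant being the JSON-serializable form:

import json
import numpy as np
from t5x.examples.unified_io.metrics.rle import encode, decode

mask = np.zeros((16, 16), dtype=np.uint8)
mask[4:12, 4:12] = 1

rle = encode(mask)      # utf=True: 'counts' is a str, safe for JSON
json.dumps(rle)         # serializes without a custom encoder
restored = decode(rle)
assert np.array_equal(restored, mask)
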
/t5x/examples/unified_io/model_test.py:
--------------------------------------------------------------------------------
1 | import jax.random
2 | from absl.testing.absltest import TestCase
3 | from flax import traverse_util
4 |
5 | from t5x.examples.unified_io import test_utils, modality_processing
6 | from t5x.examples.unified_io.data.data_utils import get_default_vocabulary
7 | from t5x.examples.unified_io.modality_processing import unified_io_preprocessor, \
8 | UnifiedIOFeatureConverter
9 | import numpy as np
10 |
11 | from t5x.examples.unified_io.models import EncoderDecoderModel
12 | from t5x.examples.unified_io.test_utils import DEBUG_CONFIG
13 | from t5x.examples.unified_io.utils import get_model
14 | import tensorflow as tf
15 |
16 |
17 | class ModelTest(TestCase):
18 |
19 | @classmethod
20 | def setUpClass(cls):
21 | cls.model = get_model("tiny", ["text"], ["text"])
22 | cls.cfg = cls.model.module.config
23 |
24 | def test_with_choices(self):
25 | """Test predictions with `choices` is consistent with computing the loss"""
26 | seq_len = dict(text_inputs=32, text_targets=16)
27 | ds = tf.data.Dataset.from_tensors(dict(
28 | text_inputs="Which answer is best?",
29 | text_targets="",
30 | choices=["a fat cat", "a quick dog"]
31 | ))
32 | ds = unified_io_preprocessor(ds, modality_processing.OUTPUT_FEATURES, seq_len)
33 | ds = UnifiedIOFeatureConverter()(ds, seq_len)
34 | batch = next(ds.repeat(2).batch(2).as_numpy_iterator())
35 |
36 | variables = self.model.get_initial_variables(
37 | jax.random.PRNGKey(5919),
38 | {k: v.shape for k, v in batch.items()},
39 | {k: v.dtype for k, v in batch.items()}
40 | )["params"]
41 |
42 | _, aux = self.model.predict_batch_with_aux(variables, batch)
43 |
44 | # Take the highest-ranked choices and compute loss as a regular batch
45 | # We have to manually build EOS and auto-regressive inputs
46 | tokens = aux["text-tokens"]
47 | features = traverse_util.unflatten_dict(batch, sep="/")
48 | features["targets"] = dict(text=dict(
49 | inputs=np.pad(tokens[:, :-1], [[0, 0], [1, 0]]),
50 | mask=(tokens>0).astype(np.int32),
51 | targets=tokens
52 | ))
53 | batch = traverse_util.flatten_dict(features, sep="/")
54 | loss = self.model.loss_fn(
55 | variables, batch, None, z_loss=0.0, loss_normalizing_by_weight_sum=False,
56 | loss_normalizing_factor=1)[0]
57 | self.assertAlmostEqual(float(loss), float(aux["scores"].sum()), places=4)
58 |
--------------------------------------------------------------------------------
/t5x/examples/unified_io/packing_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from absl.testing import absltest
4 |
5 | from t5x.examples.unified_io.packing import batch_with_constraints, pack_in_pairs
6 |
7 |
8 | def build_mask(lens):
9 |   c1 = np.floor(lens * np.random.random(lens.shape))  # random split of each length
10 | c2 = lens - c1
11 | l1 = np.maximum(c1.max() + np.random.randint(0, 3), 1)
12 | out1 = np.arange(l1)[None, :] < c1[:, None]
13 | l2 = np.maximum(c2.max() + np.random.randint(0, 3), 1)
14 | out2 = np.arange(l2)[None, :] < c2[:, None]
15 | return out1, out2
16 |
17 |
18 | def build_ds(lens):
19 | input_lens = np.array([x[0] for x in lens])
20 | target_lens = np.array([x[1] for x in lens])
21 | i1, i2 = build_mask(input_lens)
22 | t1, t2 = build_mask(target_lens)
23 | data = {
24 | "inputs/image/mask": i1,
25 | "inputs/text/mask": i2,
26 | "targets/image/mask": t1,
27 | "targets/text/mask": t2,
28 | "ixs": np.expand_dims(1 + np.arange(len(input_lens)), 1)
29 | }
30 | return tf.data.Dataset.from_tensor_slices(data)
31 |
32 |
33 | def tolist(ex):
34 | return {k: v.tolist() for k, v in ex.items()}
35 |
36 |
37 | class TestConstraintBatching(absltest.TestCase):
38 |
39 | def test_tiny(self):
40 | ds = tf.data.Dataset.from_tensor_slices(dict(
41 | c1=tf.convert_to_tensor([9, 12, 3, 4])
42 | ))
43 | batches = list(ex["c1"].tolist() for ex in batch_with_constraints(
44 | ds, 2, 5, [(lambda x: x["c1"], 20)]).as_numpy_iterator())
45 | self.assertEqual([[9, 3], [12, 4]], batches)
46 |
47 | batches = list(ex["c1"].tolist() for ex in batch_with_constraints(
48 | ds, 2, 5, [(lambda x: x["c1"], 100)]).as_numpy_iterator())
49 | self.assertEqual([[9, 12], [3, 4]], batches)
50 |
51 | def test_multiple(self):
52 | ds = tf.data.Dataset.from_tensor_slices(dict(
53 | c1=tf.convert_to_tensor([2, 2, 2, 2, 2, 14, 2, 2]),
54 | c2=tf.convert_to_tensor([11, 12, 13, 2, 2, 2, 2, 2]),
55 | example_ids=tf.range(8)
56 | ))
57 |
58 | def _fn1(x):
59 | return x["c1"]
60 |
61 | def _fn2(x):
62 | return x["c2"]
63 |
64 | batch_fns = [(_fn1, 15.1), (_fn2, 20)]
65 |
66 | batches = list(set(ex["example_ids"].tolist()) for ex in batch_with_constraints(
67 | ds, 3, 10, batch_fns).as_numpy_iterator())
68 | expected = [{0, 3, 4}, {1, 6, 7}]
69 | self.assertEqual(expected, batches)
70 |
71 | def test_random(self):
72 | seed = tf.convert_to_tensor([0, 83452])
73 | ds = tf.data.Dataset.from_tensor_slices(dict(
74 | c1=tf.random.stateless_uniform((100,), seed, 1, 10, dtype=tf.int64),
75 | c2=tf.random.stateless_uniform((100,), seed+1, 1, 10, dtype=tf.int64)
76 | ))
77 | const = [(lambda x: x["c1"], 18),
78 | (lambda x: x["c2"], 18)]
79 | for ex in batch_with_constraints(ds, 3, 10, const).as_numpy_iterator():
80 | self.assertLessEqual(ex["c1"].sum(), const[0][1])
81 | self.assertLessEqual(ex["c1"].sum(), const[1][1])
82 |
83 |
84 | def _pack_in_pairs(*args, **kwargs):
85 | return pack_in_pairs(
86 | *args, **kwargs,
87 | encoder_masks=[
88 | ("inputs/image/mask", None),
89 | ("inputs/text/mask", None)
90 | ],
91 | decoder_masks=[
92 | "targets/image/mask",
93 | "targets/text/mask"
94 | ]
95 | )
96 |
97 |
98 | class TestPackingBatching(absltest.TestCase):
99 |
100 | def test_tiny_pack(self):
101 | ds = build_ds([(3, 2), (2, 1)])
102 | ds = _pack_in_pairs(ds, 1, 5, 5, pool_size=2)
103 | exs = next(ds.as_numpy_iterator())
104 | self.assertEqual(set(exs["ixs"].ravel().tolist()), {1, 2})
105 |
106 | def test_tiny_pack2(self):
107 | ds = build_ds([(3, 2), (5, 2)])
108 | ds = _pack_in_pairs(ds, 1, 6, 6, pool_size=1)
109 | exs = next(ds.as_numpy_iterator())
110 | self.assertEqual(set(exs["ixs"].ravel().tolist()), {0, 2})
111 |
112 | def test_pool2(self):
113 | ds = build_ds([
114 | (3, 2), # add to pool
115 | (5, 5), # add to pool
116 | (2, 4), # add to pool, write (5, 5) as batch
117 | (2, 3), # pair with (3, 2)
118 | ])
119 | ds = _pack_in_pairs(ds, 1, 5, 5, pool_size=2).as_numpy_iterator()
120 | ex1 = next(ds)
121 | self.assertEqual(set(ex1["ixs"].ravel().tolist()), {0, 2})
122 | ex2 = next(ds)
123 | self.assertEqual(set(ex2["ixs"].ravel().tolist()), {1, 4})
124 |
125 |
126 | if __name__ == "__main__":
127 | absltest.main()
--------------------------------------------------------------------------------
/t5x/examples/unified_io/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/unified-io-2/502ac4d81239f82c891a9f412b000c3c8d4e2946/t5x/examples/unified_io/scripts/__init__.py
--------------------------------------------------------------------------------
/t5x/examples/unified_io/t5_1_1/base.gin:
--------------------------------------------------------------------------------
1 | # type: ignore
2 |
3 | # T5.1.1 Base model.
4 | from __gin__ import dynamic_registration
5 |
6 | import seqio
7 | from t5x import adafactor
8 | from t5x.examples.unified_io import models
9 | from t5x.examples.unified_io import network
10 | from t5x.examples.unified_io import config
11 | from t5x import optimizers
12 | from t5x import utils
13 | import optax
14 | from t5x import trainer
15 |
16 | # ------------------- Loss HParam ----------------------------------------------
17 | Z_LOSS = 0.0001
18 | LABEL_SMOOTHING = 0.0
19 | TEXT_DECODER_LENGTH = None
20 | IMAGE_DECODER_LENGTH = None
21 | # NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF)
22 | # the loss normalizing factor should be set to pretraining batch_size *
23 | # target_token_length.
24 | LOSS_NORMALIZING_FACTOR = None
25 | LOSS_NORMALIZING_BY_WEIGHT_SUM = True
26 | # Dropout should be specified in the "run" files
27 | DROPOUT_RATE = 0.0
28 | DROPOUT_BROADCAST_DIMS = (-2, )
29 | DROPPATH_RATE = 0.0
30 |
31 | # Vocabulary (shared by encoder and decoder)
32 | VOCABULARY = @seqio.SentencePieceVocabulary()
33 | seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model"
34 |
35 | # ------------------- Optimizer ------------------------------------------------
36 | # `learning_rate` is set by `Trainer.learning_rate_fn`.
37 | # Left as None here; run configs are expected to override it (e.g., AdamW with gradient clipping).
38 | OPTIMIZER = None
39 |
40 | # ------------------- Model ----------------------------------------------------
41 | MODEL = @models.EncoderDecoderModel()
42 | models.EncoderDecoderModel:
43 | module = @network.Transformer()
44 | input_vocabulary = %VOCABULARY
45 | output_vocabulary = %VOCABULARY
46 | optimizer_def = %OPTIMIZER
47 | z_loss = %Z_LOSS
48 | label_smoothing = %LABEL_SMOOTHING
49 | loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
50 | loss_normalizing_by_weight_sum = %LOSS_NORMALIZING_BY_WEIGHT_SUM
51 |
52 | # ------------------- Network specification ------------------------------------
53 | network.Transformer.config = @config.T5Config()
54 |
55 | config.T5Config:
56 | vocab_size = 33280 # vocab size rounded to a multiple of 128 for TPU efficiency
57 | image_vocab_size = 16512 # vocab size rounded to a multiple of 128 for TPU efficiency
58 | image_patch_size = 16
59 | audio_vocab_size = 8320 # vocab size rounded to a multiple of 128 for TPU efficiency
60 | audio_patch_size = 16
61 |
62 | dtype = 'bfloat16'
63 | emb_dim = 768
64 | num_heads = 12
65 | num_encoder_layers = 12
66 | num_decoder_layers = 12
67 | head_dim = 64
68 | mlp_dim = 2048
69 | mlp_activations = ('silu', 'linear')
70 | dropout_rate = %DROPOUT_RATE
71 | dropout_broadcast_dims = %DROPOUT_BROADCAST_DIMS
72 | logits_via_embedding = True
73 | float32_attention_logits = True
74 | decoder_xattention_internval = 1
75 |
76 | image_tokenizer_type = 'vqgan'
77 |
78 | from t5x.examples.unified_io import modality_processing
79 | modality_processing.get_input_modalities:
80 | image_vit_cfg = @config.ImageVitFeatureConfig()
81 | audio_vit_cfg = @config.AudioVitFeatureConfig()
82 | image_history_cfg = @config.ImageResamplerConfig()
83 | audio_history_cfg = @config.AudioResamplerConfig()
84 | use_image_vit = True
85 | use_audio_vit = True
86 | use_image_history_vit = True
87 | use_audio_history_vit = True
88 |
89 | modality_processing.get_target_modalities:
90 | image_vae_config = @config.VAEConfig()
91 | audio_vae_config = @config.AudioViTVQGANConfig()
92 |
93 | config.VAEConfig:
94 | embed_dim = 256
95 | n_embed = 16384
96 | double_z = False
97 | z_channels = 4
98 | resolution = 256
99 | in_channels = 3
100 | out_ch = 3
101 | ch = 128
102 | ch_mult = (1,2,2,4)
103 | num_res_blocks = 2
104 | attn_resolutions = (32,)
105 | dropout = 0
106 | default_input_size = (256,256)
107 | patch_size = (8, 8)
108 |
109 | config.AudioViTVQGANConfig:
110 | vocab_size = 8192
111 | proj_dim = 32
112 | # Transformers
113 | encoder_hidden_size = 512
114 | encoder_num_layers = 8
115 | encoder_mlp_dim = 2048
116 | encoder_num_heads = 8
117 | encoder_head_dim = 64
118 |
119 |   decoder_hidden_size = 512
120 |   decoder_num_layers = 8
121 |   decoder_mlp_dim = 2048
122 |   decoder_num_heads = 8
123 |   decoder_head_dim = 64
124 |
125 | dropout_rate = 0.0
126 | droppath_rate = 0.0
127 | attention_dropout_rate = 0.0
128 | use_bias = False
129 | act_fn = 'relu'
130 | # PE
131 | add_position_embedding = False
132 | # Misc.
133 | dtype = 'bfloat16'
134 | default_input_size = (128, 256) # we need to keep this to make it
135 | patch_size = (8, 8)
136 |
137 | output_channel = 1
138 | use_decoder = True
139 |
140 | config.ImageVitFeatureConfig:
141 | patch_size = 16
142 | pos_patch_size = 16
143 | emb_dim = 768
144 | num_heads = 12
145 | num_layers = 11 # -2 layer
146 | mlp_dim = 3072
147 | mlp_activations = ('gelu', )
148 | dropout_rate = 0.0
149 | dropout_broadcast_dims = ()
150 | default_input_size = (256, 256)
151 | num_pos = 197
152 | dtype = 'float32'
153 |
154 | config.AudioVitFeatureConfig:
155 | patch_size = 16
156 | emb_dim = 768
157 | num_heads = 12
158 | num_layers = 11 # -2 layer
159 | mlp_dim = 3072
160 | mlp_activations = ('gelu', )
161 | dropout_rate = 0.0
162 | dropout_broadcast_dims = ()
163 | default_input_size = (256, 128)
164 | transpose_input = True
165 | dtype = 'float32'
166 |
167 | config.ImageResamplerConfig:
168 | dtype = 'bfloat16'
169 | resampler_type = 'perceiver'
170 | max_frames = 8
171 | latents_size = 32
172 | emb_dim = 768
173 | num_heads = 12
174 | num_layers = 2
175 | xattention_index = (0, 1)
176 | head_dim = 64
177 | mlp_dim = 3072
178 | mlp_activations = ('gelu',)
179 | dropout_broadcast_dims = (-2,)
180 | droppath_rate = 0.0
181 | layer_drop = 0.0
182 | dropout_rate = 0.0
183 |
184 | config.AudioResamplerConfig:
185 | dtype = 'bfloat16'
186 | resampler_type = 'perceiver'
187 | max_frames = 8
188 | latents_size = 16
189 | emb_dim = 768
190 | num_heads = 12
191 | num_layers = 2
192 | xattention_index = (0, 1)
193 | head_dim = 64
194 | mlp_dim = 3072
195 | mlp_activations = ('gelu',)
196 | dropout_broadcast_dims = (-2,)
197 | droppath_rate = 0.0
198 | layer_drop = 0.0
199 | dropout_rate = 0.0
200 |
--------------------------------------------------------------------------------
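The vocab sizes above are padded to multiples of 128 so the embedding and logits matrices tile cleanly on TPUs; a small sketch of that rounding (the "+100" special-token count below is illustrative, not taken from the config):

def round_up(n: int, multiple: int = 128) -> int:
    return ((n + multiple - 1) // multiple) * multiple

assert 33280 % 128 == 0   # text vocab
assert 16512 % 128 == 0   # image vocab (16384 VQGAN codes, padded)
assert 8320 % 128 == 0    # audio vocab (8192 VQGAN codes, padded)
print(round_up(16384 + 100))  # -> 16512
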
/t5x/examples/unified_io/t5_1_1/eval/vision_language.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 | include 't5x/configs/runs/eval.gin'
3 |
4 | from t5x.examples.unified_io import aux_fns
5 |
6 | # IMPORTANT: This assumes pretty short prompts/output, change if needed
7 | TEXT_INPUTS = 128
8 | TEXT_DECODER_LENGTH = 32
9 | IMAGE_INPUT_SAMPLES = None
10 |
11 | DROPOUT_RATE = 0.0 # might be required by an imported model config
12 |
13 | TASK_FEATURE_LENGTHS = {
14 | "text_inputs": %TEXT_INPUTS,
15 | "text_targets": %TEXT_DECODER_LENGTH,
16 | "image_input_samples": %IMAGE_INPUT_SAMPLES,
17 | "is_training": False,
18 | }
19 |
20 | from t5x.examples.unified_io import modality_processing
21 | modality_processing.get_input_modalities.input_modality=["image", "text"]
22 | modality_processing.get_target_modalities.target_modality=["text"]
23 |
24 | from t5x.examples.unified_io import models
25 | models.EncoderDecoderModel.predict_batch_with_aux:
26 | length = %TEXT_DECODER_LENGTH
27 | modality = "text"
28 |
29 | # Import so any gin configurable method in these files will be registered with gin
30 | # and thus can be modified by command line
31 | from t5x.examples.unified_io import network
32 | from t5x.examples.unified_io.metrics import metrics
33 | from t5x.examples.unified_io import decoding
34 |
35 | # Import so the registration happens
36 | from t5x.examples.unified_io.data import tasks
37 | from t5x.examples.unified_io.data import mixtures
38 | from t5x.examples.unified_io import aux_fns
39 |
--------------------------------------------------------------------------------
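
Note the `%MACRO` references in the file above: gin macros are late-bound, so a later binding (including one from the command line) changes every place the macro is referenced. A small illustrative sketch of that behavior (the `task_lengths` function is hypothetical):

    import gin

    @gin.configurable
    def task_lengths(text_inputs=gin.REQUIRED, text_targets=gin.REQUIRED):
        return {"text_inputs": text_inputs, "text_targets": text_targets}

    gin.parse_config("""
    TEXT_INPUTS = 128
    TEXT_DECODER_LENGTH = 32
    task_lengths.text_inputs = %TEXT_INPUTS
    task_lengths.text_targets = %TEXT_DECODER_LENGTH
    """)

    # A later binding overrides the macro everywhere it is referenced.
    gin.parse_config("TEXT_INPUTS = 64")
    print(task_lengths())  # {'text_inputs': 64, 'text_targets': 32}
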
/t5x/examples/unified_io/t5_1_1/finetune/refexp.gin:
--------------------------------------------------------------------------------
1 | # Fine tune on referring expression, required parameters:
2 | # INITIAL_CHECKPOINT_PATH
3 | # MODEL_DIR
4 |
5 | from __gin__ import dynamic_registration
6 | import __main__ as train_script
7 |
8 | # Register necessary SeqIO Tasks/Mixtures.
9 |
10 | from t5x.examples.unified_io.data import tasks
11 | from t5x.examples.unified_io import aux_fns
12 | from t5x.examples.unified_io import models
13 | from t5x import partitioning
14 | from t5x import trainer
15 | import seqio
16 | from t5x.examples.unified_io import packing
17 | from t5x import utils as t5x_utils
18 | from t5x.examples.unified_io import config
19 |
20 |
21 | include 't5x/configs/runs/multitask.gin'
22 |
23 | MIXTURE_OR_TASK_NAME = "refcoco_unc"
24 | MIXTURE_OR_TASK_NAME_EVAL = "refcoco_unc"
25 |
26 |
27 | TRAIN_STEPS = 3_100_000 # 100,000 fine-tuning steps after 3,000,000 pre-training steps
28 | DROPOUT_RATE = 0.0
29 | BATCH_SIZE = 128
30 | EVAL_STEPS = 50
31 |
32 |
33 | train_script.train:
34 | eval_period = 2500
35 | stats_period = 500
36 | partitioner = @partitioning.PjitPartitioner()
37 | use_wandb = True
38 | concurrent_metrics = False
39 | infer_eval_dataset_cfg = @train_infer/t5x_utils.DatasetConfig()
40 |
41 |
42 | t5x_utils.SaveCheckpointConfig:
43 | period = 20000
44 |
45 |
46 | train_infer/t5x_utils.DatasetConfig:
47 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME_EVAL
48 | task_feature_lengths = %TASK_FEATURE_LENGTHS_EVAL
49 | split = 'validation'
50 | batch_size = %BATCH_SIZE
51 | shuffle = False
52 | seed = 42
53 | use_cached = %USE_CACHED_TASKS
54 | pack = False
55 |
56 |
57 | partitioning.PjitPartitioner.num_partitions = 4
58 |
59 | from t5x.examples.unified_io import utils
60 | from t5x.examples.unified_io import modality_processing
61 |
62 | # Only load the needed modalities
63 | modality_processing.get_input_modalities.input_modality=["image", "text"]
64 | modality_processing.get_target_modalities.target_modality=["text"]
65 |
66 |
67 | TEXT_INPUT_LEN = 256
68 | TEXT_TARGET_LEN = 32
69 | IMAGE_SAMPLES = 1.0
70 |
71 | TASK_FEATURE_LENGTHS_TRAIN = {
72 | "text_inputs": %TEXT_INPUT_LEN,
73 | "text_targets": %TEXT_TARGET_LEN,
74 | "image_input_samples": %IMAGE_SAMPLES,
75 | "image_history_input_samples": 128,
76 | "audio_input_samples": 64,
77 | "audio_history_input_samples": 64,
78 | "num_frames": 4,
79 | "is_training": True,
80 | }
81 |
82 |
83 | TASK_FEATURE_LENGTHS_EVAL = {
84 | "text_inputs": %TEXT_INPUT_LEN,
85 | "text_targets": %TEXT_TARGET_LEN,
86 | "image_input_samples": None,
87 | "image_history_input_samples": 128,
88 | "audio_input_samples": 64,
89 | "audio_history_input_samples": 64,
90 | "num_frames": 4,
91 | "is_training": False,
92 | }
93 |
94 | models.EncoderDecoderModel.predict_batch_with_aux:
95 | length = %TEXT_TARGET_LEN
96 | modality = "text"
97 |
98 |
99 | t5x_utils.create_learning_rate_scheduler:
100 | factors = 'constant * linear_warmup * rsqrt_decay'
101 |   # Generally for fine-tuning we halve the learning rate
102 |   base_learning_rate = 0.5
103 |   warmup_steps = 2000 # shortened from the 10k default used for T5/MTF pre-training
104 |
105 |
106 | from t5x import adafactor
107 | OPTIMIZER = @adafactor.Adafactor()
108 | adafactor.Adafactor:
109 | decay_rate = 0.8
110 | beta1 = 0.9
111 | step_offset = 0
112 | logical_factor_rules = @adafactor.standard_logical_factor_rules()
113 | global_norm_clip_threshold = 1.0
114 | skip_nan_updates = True
115 |
116 |
117 | import t5x.examples.unified_io.evaluator as uio_evaluator
118 | uio_evaluator.UnifiedIOEvaluator:
119 | num_examples = 10000
120 | logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @utils.WandbMetricsLogger]
121 |
--------------------------------------------------------------------------------
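
As a sanity check on the schedule configured above, here is a rough numeric sketch of 'constant * linear_warmup * rsqrt_decay', assuming the conventional T5 factor definitions (the real implementation lives in t5x.utils.create_learning_rate_scheduler):

    import math

    def lr(step, base=0.5, warmup=2000):
        # constant * linear_warmup * rsqrt_decay, per the usual T5 factor definitions
        linear_warmup = min(1.0, step / warmup)
        rsqrt_decay = 1.0 / math.sqrt(max(step, warmup))
        return base * linear_warmup * rsqrt_decay

    print(lr(500))        # warming up: 0.5 * 0.25 / sqrt(2000) ~ 0.0028
    print(lr(2000))       # peak: 0.5 / sqrt(2000) ~ 0.0112
    print(lr(3_100_000))  # by the end of fine-tuning: ~ 0.00028
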
/t5x/examples/unified_io/t5_1_1/large.gin:
--------------------------------------------------------------------------------
1 | # type: ignore
2 |
3 | # T5.1.1 Large model.
4 |
5 | include 't5x/examples/unified_io/t5_1_1/base.gin' # imports vocab, optimizer and model.
6 |
7 | EMBED_DIM = 1024
8 | MLP_DIM = 2816
9 | NUM_HEADS = 16
10 | HEAD_DIM = 64
11 |
12 | from t5x.examples.unified_io import config
13 | network.Transformer.config = @config.T5Config()
14 | config.T5Config:
15 | emb_dim = %EMBED_DIM
16 | num_heads = %NUM_HEADS
17 | num_encoder_layers = 24
18 | num_decoder_layers = 24
19 | head_dim = %HEAD_DIM
20 | mlp_dim = %MLP_DIM
21 | decoder_xattention_internval = 1
22 |
23 | config.ImageResamplerConfig:
24 | dtype = 'bfloat16'
25 | resampler_type = 'perceiver'
26 | emb_dim = 768
27 | num_heads = 12
28 | head_dim = 64
29 | mlp_dim = 2048
30 | num_layers = 2
31 | xattention_index = (0, 1)
32 | dropout_broadcast_dims = (-2,)
33 | mlp_activations = ('gelu',)
34 | max_frames = 8
35 | latents_size = 32
36 |
37 | config.AudioResamplerConfig:
38 | resampler_type = 'perceiver'
39 | dtype = 'bfloat16'
40 | emb_dim = 768
41 | num_heads = 12
42 | head_dim = 64
43 | mlp_dim = 2048
44 | num_layers = 2
45 | xattention_index = (0, 1)
46 | dropout_broadcast_dims = (-2,)
47 | mlp_activations = ('gelu',)
48 | max_frames = 8
49 | latents_size = 16
--------------------------------------------------------------------------------
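
For orientation, a back-of-the-envelope parameter count for the encoder/decoder stack configured above, assuming the usual T5.1.1 block layout (q/k/v/out attention projections plus a gated two-matrix MLP); embeddings, layer norms, and the resamplers are omitted, so this is only a rough lower bound:

    EMB, MLP, HEADS, HEAD, LAYERS = 1024, 2816, 16, 64, 24
    inner = HEADS * HEAD                    # 1024: attention inner width
    attn = 4 * EMB * inner                  # q, k, v, out projections
    mlp = 3 * EMB * MLP                     # gated MLP: wi_0, wi_1, wo
    encoder = LAYERS * (attn + mlp)
    decoder = LAYERS * (2 * attn + mlp)     # self- plus cross-attention
    print(f"~{(encoder + decoder) / 1e6:.0f}M parameters")  # ~717M
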
/t5x/examples/unified_io/t5_1_1/tiny.gin:
--------------------------------------------------------------------------------
1 | # type: ignore
2 |
3 | # T5.1.1 Tiny model.
4 |
5 | include 't5x/examples/unified_io/t5_1_1/base.gin' # imports vocab, optimizer and model.
6 |
7 | EMBED_DIM = 16
8 | MLP_DIM = 32
9 | NUM_HEADS = 4
10 | HEAD_DIM = 16
11 | NUM_RESAMPLER_LAYER = 1
12 |
13 | from t5x.examples.unified_io import config
14 | network.Transformer.config = @config.T5Config()
15 | config.T5Config:
16 | emb_dim = %EMBED_DIM
17 | num_heads = %NUM_HEADS
18 | num_encoder_layers = 1
19 | num_decoder_layers = 1
20 | head_dim = %HEAD_DIM
21 | mlp_dim = %MLP_DIM
22 |
23 | config.ImageResamplerConfig:
24 | emb_dim = %EMBED_DIM
25 | num_heads = %NUM_HEADS
26 | num_layers = %NUM_RESAMPLER_LAYER
27 | head_dim = %HEAD_DIM
28 | mlp_dim = %MLP_DIM
29 |
30 | config.AudioResamplerConfig:
31 | emb_dim = %EMBED_DIM
32 | num_heads = %NUM_HEADS
33 | num_layers = %NUM_RESAMPLER_LAYER
34 | head_dim = %HEAD_DIM
35 | mlp_dim = %MLP_DIM
36 |
37 |
--------------------------------------------------------------------------------
/t5x/examples/unified_io/t5_1_1/xl.gin:
--------------------------------------------------------------------------------
1 | # type: ignore
2 |
3 | # T5.1.1 XL model.
4 |
5 | include 't5x/examples/unified_io/t5_1_1/base.gin' # imports vocab, optimizer and model.
6 |
7 |
8 | IMAGE_LATENT_SIZE = 128
9 | AUDIO_LATENT_SIZE = 64
10 |
11 | from t5x.examples.unified_io import config
12 | network.Transformer.config = @config.T5Config()
13 |
14 | config.T5Config:
15 | vocab_size = 33280 # vocab size rounded to a multiple of 128 for TPU efficiency
16 | image_vocab_size = 16512 # vocab size rounded to a multiple of 128 for TPU efficiency
17 | image_patch_size = 16
18 | audio_vocab_size = 8320 # vocab size rounded to a multiple of 128 for TPU efficiency
19 | audio_patch_size = 16
20 | dtype = 'bfloat16'
21 | emb_dim = 2048
22 | num_heads = 16
23 | num_encoder_layers = 24
24 | num_decoder_layers = 24
25 | head_dim = 128
26 | mlp_dim = 5120
27 | mlp_activations = ('silu', 'linear')
28 | dropout_rate = %DROPOUT_RATE
29 | logits_via_embedding = True
30 | float32_attention_logits = True
31 | decoder_xattention_internval = 1
32 |
33 | config.ImageResamplerConfig:
34 | dtype = 'bfloat16'
35 | resampler_type = 'perceiver'
36 | emb_dim = 1024
37 | num_heads = 16
38 | head_dim = 64
39 | mlp_dim = 4096
40 | num_layers = 2
41 | xattention_index = (0, 1)
42 | dropout_broadcast_dims = (-2,)
43 | mlp_activations = ('gelu',)
44 | max_frames = 8
45 |
46 | config.AudioResamplerConfig:
47 | resampler_type = 'perceiver'
48 | dtype = 'bfloat16'
49 | emb_dim = 1024
50 | num_heads = 16
51 | head_dim = 64
52 | mlp_dim = 4096
53 | num_layers = 2
54 | xattention_index = (0, 1)
55 | dropout_broadcast_dims = (-2,)
56 | mlp_activations = ('gelu',)
57 | max_frames = 8
--------------------------------------------------------------------------------
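
The "multiple of 128" comments above are easy to verify; padding vocabulary sizes to a multiple of the TPU lane width keeps the embedding and logits matmuls efficiently tiled:

    for v in (33280, 16512, 8320):
        assert v % 128 == 0
        print(v, "=", v // 128, "x 128")
    # 33280 = 260 x 128, 16512 = 129 x 128, 8320 = 65 x 128
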
/t5x/examples/unified_io/t5_1_1/xxl.gin:
--------------------------------------------------------------------------------
1 | # type: ignore
2 |
3 | # T5.1.1 XXL model.
4 |
5 | include 't5x/examples/unified_io/t5_1_1/xl.gin' # imports vocab, optimizer and model.
6 |
7 | EMBED_DIM = 1024
8 | MLP_DIM = 2560
9 | NUM_HEADS = 16
10 | HEAD_DIM = 64
11 |
12 | from t5x.examples.unified_io import config
13 | # ------------------- Network specification overrides --------------------------
14 | network.Transformer.config = @config.T5Config()
15 | config.T5Config:
16 | vocab_size = 33280 # vocab size rounded to a multiple of 128 for TPU efficiency
17 | image_vocab_size = 16512 # vocab size rounded to a multiple of 128 for TPU efficiency
18 | image_patch_size = 16
19 | audio_vocab_size = 8320 # vocab size rounded to a multiple of 128 for TPU efficiency
20 | audio_patch_size = 16
21 | dtype = 'bfloat16'
22 | emb_dim = 3072
23 | num_heads = 24
24 | num_encoder_layers = 24
25 | num_decoder_layers = 24
26 | head_dim = 128
27 | mlp_dim = 8192
28 | mlp_activations = ('silu', 'linear')
29 | dropout_rate = %DROPOUT_RATE
30 | logits_via_embedding = True
31 | float32_attention_logits = True
32 | decoder_xattention_internval = 1
33 |
34 | config.ImageResamplerConfig:
35 | dtype = 'bfloat16'
36 | resampler_type = 'perceiver'
37 | emb_dim = 1024
38 | num_heads = 16
39 | head_dim = 64
40 | mlp_dim = 4096
41 | num_layers = 2
42 | xattention_index = (0, 1)
43 | dropout_broadcast_dims = (-2,)
44 | mlp_activations = ('gelu',)
45 | max_frames = 8
46 | xattn_qk_norm = False
47 | xattn_scaled_cosine = True
48 | attn_qk_norm = False
49 | attn_scaled_cosine = True
50 |
51 | config.AudioResamplerConfig:
52 | resampler_type = 'perceiver'
53 | dtype = 'bfloat16'
54 | emb_dim = 1024
55 | num_heads = 16
56 | head_dim = 64
57 | mlp_dim = 4096
58 | num_layers = 2
59 | xattention_index = (0, 1)
60 | dropout_broadcast_dims = (-2,)
61 | mlp_activations = ('gelu',)
62 | max_frames = 8
63 | xattn_qk_norm = False
64 | xattn_scaled_cosine = True
65 | attn_qk_norm = False
66 | attn_scaled_cosine = True
67 |
--------------------------------------------------------------------------------
/t5x/examples/unified_io/test_utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any, List
3 |
4 | import jax
5 | import jax.numpy as jnp
6 | import numpy as np
7 | import flax.linen as nn
8 | from t5x.examples.unified_io import modality_processing
9 | from t5x.examples.unified_io.input_modalities import ModalityEncoder
10 | from t5x.examples.unified_io.network import Transformer
11 | from t5x.examples.unified_io.config import T5Config
12 | from t5x.examples.unified_io.seq_features import InputSequence, TargetSequence
13 | from t5x.examples.unified_io.target_modalities import BasicDecoder
14 |
15 | TARGET_MODALITIES = ["text", "image", "audio"]
16 |
17 |
18 | class NullEncoder(nn.Module):
19 | is_target: bool
20 | modality_id: int
21 |
22 | def __call__(self, embed, mask, init=False, targets=None, loss_mask=None, inputs=None,
23 | rel_atten=None, enable_dropout=None, use_constraints=True, segment_ids=None):
24 | if loss_mask is None:
25 | loss_mask = mask
26 |
27 | if self.is_target:
28 | return TargetSequence(
29 | embed, None,
30 | jnp.array(self.modality_id, dtype=jnp.int32), mask, rel_atten,
31 | subsegments=segment_ids, target_tokens=targets, loss_mask=loss_mask)
32 | else:
33 | return InputSequence(embed, mask, rel_atten)
34 |
35 |
36 | @dataclass
37 | class DebugModalityEncoder(ModalityEncoder):
38 | is_target: bool
39 | modality_id: int
40 |
41 | def get_encoder(self, *args, **kwargs) -> nn.Module:
42 | return NullEncoder(self.is_target, self.modality_id)
43 |
44 | def get_decoder(self, config, shared_embedding) -> nn.Module:
45 | return BasicDecoder(None, T5Config(None, logits_via_embedding=True), shared_embedding)
46 |
47 |
48 | class NullModalityEncoder(ModalityEncoder, nn.Module):
49 | def __init__(self, is_target, seq_len):
50 | super().__init__()
51 | self.seq_len = seq_len
52 | self.is_target = is_target
53 |
54 | def get_encoder(self, config, shared_embedding) -> nn.Module:
55 | self.t5_config = config.t5_config
56 | return self
57 |
58 | def __call__(self, null, init=False):
59 | bs = null.shape[0]
60 | if self.is_target:
61 | return TargetSequence.empty(
62 | bs, self.seq_len, self.t5_config.num_heads, self.t5_config.dtype, 2)
63 | else:
64 | return InputSequence.empty(bs, self.seq_len, self.t5_config)
65 |
66 |
67 | def build_random_batch(cfg, rng: np.random.RandomState, batch_size,
68 | input_modalities: List[int],
69 | target_modalities: List[int],
70 | target_segment_ids=False
71 | ):
72 | batch = build_inputs(cfg, rng, batch_size, input_modalities)
73 | batch.update(build_targets(
74 | cfg, rng, batch_size, target_modalities, target_segment_ids))
75 | return batch
76 |
77 |
78 | def build_inputs(cfg, np_rng: np.random.RandomState, batch_size, input_modalities: List[int]):
79 | out = {}
80 | for ix, seq_len in enumerate(input_modalities):
81 | out[f"inputs/{ix}/mask"] = np_rng.random((batch_size, seq_len)) > 0.5
82 | out[f"inputs/{ix}/embed"] = np_rng.uniform(-1, 1, (batch_size, seq_len, cfg.emb_dim))
83 | out[f"inputs/{ix}/rel_atten"] = np_rng.uniform(0, 0.5, (batch_size, cfg.num_heads, seq_len, seq_len))
84 |
85 | return out
86 |
87 |
88 | def build_targets(
89 | cfg, np_rng: np.random.RandomState, batch_size, target_modalities,
90 | segment_ids=False):
91 | out = {}
92 | assert len(target_modalities) <= 3
93 | for name, seq_len in zip(TARGET_MODALITIES, target_modalities):
94 | if seq_len is None:
95 | continue
96 | out.update({
97 | f"targets/{name}/mask": np_rng.random((batch_size, seq_len)) > 0.5,
98 | f"targets/{name}/embed": np_rng.uniform(-1, 1, (batch_size, seq_len, cfg.emb_dim)),
99 | f"targets/{name}/targets": np_rng.randint(0, cfg.vocab_size, (batch_size, seq_len), dtype=np.int32),
100 | })
101 | if segment_ids and name == "text":
102 | out[f"targets/{name}/segment_ids"] = np_rng.randint(0, 2, (batch_size, seq_len), dtype=np.int32)
103 | return out
104 |
105 |
106 | DEBUG_CONFIG = T5Config(
107 | num_encoder_layers=2,
108 | num_decoder_layers=2,
109 | vocab_size=100,
110 | dropout_rate=0.0,
111 | emb_dim=8,
112 | num_heads=2,
113 | head_dim=4,
114 | mlp_dim=12,
115 | dtype=jnp.float32,
116 | mlp_activations=('gelu',),
117 | logits_via_embedding=True,
118 | image_vocab_size=1000,
119 | )
120 |
121 |
122 | def build_test_transformer(
123 | cfg: T5Config, input_modalities: int,
124 | target_modalities: int,
125 | variable_seed=None,
126 | ):
127 | modality_processing.get_input_modalities = lambda: {
128 | str(ix): DebugModalityEncoder(False, ix) for ix in range(input_modalities)
129 | }
130 |
131 | # Use names `TARGET_MODALITIES` since those names are also hardcoded in some places
132 | assert target_modalities <= 3
133 | modality_processing.get_target_modalities = lambda: {
134 | name: DebugModalityEncoder(True, ix) for ix, name in enumerate(TARGET_MODALITIES[:target_modalities])
135 | }
136 |
137 | trans = Transformer(cfg)
138 | if variable_seed is not None:
139 | rng = jax.random.PRNGKey(variable_seed)
140 | np_rng = np.random.RandomState(variable_seed*9681)
141 | batch = build_random_batch(
142 | trans.config, np_rng, 1, [1]*input_modalities, [1]*target_modalities)
143 | variables = trans.init(rng, batch, init=True)
144 | return trans, variables
145 | else:
146 | return trans
147 |
--------------------------------------------------------------------------------
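
A minimal sketch of how these helpers combine, mirroring the pattern used by the network tests (the exact `apply` signature follows network.py; treat this as illustrative):

    import numpy as np
    from t5x.examples.unified_io.test_utils import (
        DEBUG_CONFIG, build_random_batch, build_test_transformer)

    # Tiny 2-layer transformer with 2 debug input and 2 debug target modalities.
    trans, variables = build_test_transformer(
        DEBUG_CONFIG, input_modalities=2, target_modalities=2, variable_seed=0)

    # Random batch with matching modality counts (one sequence length per modality).
    np_rng = np.random.RandomState(0)
    batch = build_random_batch(DEBUG_CONFIG, np_rng, batch_size=2,
                               input_modalities=[4, 4], target_modalities=[4, 4])
    out = trans.apply(variables, batch)  # forward pass on the debug batch
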
/t5x/examples/unified_io/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from typing import Optional, List, Dict, Mapping, Any, Sequence, Tuple
4 |
5 | import gin
6 | import jax.numpy as jnp
7 | import jax.random
8 | import numpy as np
9 | import seqio
10 | import tensorflow as tf
11 | import wandb
12 | from absl import logging
13 |
14 | from t5x import checkpoints
15 | from t5x import utils
16 | from t5x.examples.unified_io.config import SHUFFLE_BUFFER_SIZE, CYCLE_LENGTH, BLOCK_LENGTH
17 | from t5x.utils import TrainStateInitializer, RestoreCheckpointConfig, LegacyCheckpointManager
18 | from flax.core import freeze  # used by get_parameters
19 |
20 | def get_model(model_size, input_modalities=None, target_modalities=None,
21 | gin_bindings: Optional[List[str]]=None, dtype="bfloat16"):
22 | """Return a EncoderDecoder model and configure the code to support that model
23 |
24 | This will also configure the pre-processing functions to be consistent with returned model
25 | """
26 | if gin.config_is_locked():
27 | logging.warning("Using `get_model` might override existing gin flags")
28 |
29 | with gin.config.unlock_config():
30 | def _get(model):
31 | return model
32 |
33 | bindings = []
34 | if gin_bindings:
35 | bindings += gin_bindings
36 | if input_modalities:
37 | bindings.append(f"get_input_modalities.input_modality={input_modalities}")
38 | if target_modalities:
39 | bindings.append(f"get_target_modalities.target_modality={target_modalities}")
40 | if dtype != "bfloat16":
41 | bindings += [f"{x}.dtype=\"float32\"" for x in [
42 | "T5Config",
43 | "AudioViTVQGANConfig",
44 | "VAEConfig",
45 | "ImageVitFeatureConfig",
46 | "AudioVitFeatureConfig",
47 | "ImageResamplerConfig",
48 | "AudioResamplerConfig",
49 | ]]
50 | _get = gin.configurable(_get)
51 | gin.parse_config_files_and_bindings(
52 | config_files=[f"t5x/examples/unified_io/t5_1_1/{model_size}.gin"],
53 | bindings=bindings + [
54 | "_get.model=%MODEL"
55 | ]
56 | )
57 | return _get()
58 |
59 |
60 | def get_parameters(model, model_checkpoint: str = None,
61 | partitioner=None, rng: jax.random.PRNGKey=None) -> Tuple:
62 | """Get parameters for a model
63 |
64 | model: Model to get parameters for
65 | model_checkpoint: Checkpoint, if None initialized the model from scratch
66 | partitioner: If given, load parameters with this partitioner
67 | rng: If initializing from scratch, load parameters with this RNG
68 | """
69 | from t5x.examples.unified_io.modality_processing import get_input_spec
70 | t0 = time.perf_counter()
71 | if rng is None:
72 | seed = np.random.randint(np.iinfo(np.int32).min, np.iinfo(np.int32).max, (), np.int32)
73 | rng = jax.random.PRNGKey(seed)
74 |
75 | if model_checkpoint is None:
76 | logging.info("Init model from scratch")
77 | input_shapes, input_types = get_input_spec()
78 | if partitioner is None:
79 | params = model.get_initial_variables(rng, input_shapes, input_types)
80 | params, param_axes = params["params"], params["params_axes"]
81 | else:
82 | train_state_initializer = TrainStateInitializer(
83 | optimizer_def=None,
84 | init_fn=model.get_initial_variables,
85 | input_shapes=input_shapes,
86 | input_types=input_types,
87 | partitioner=partitioner
88 | )
89 | train_state = train_state_initializer.from_scratch(rng)
90 | param_axes = train_state_initializer.train_state_axes.params
91 | params = freeze(train_state.params)
92 | else:
93 | if not model_checkpoint.startswith("gs://"):
94 | model_checkpoint = os.path.abspath(os.path.expanduser(model_checkpoint))
95 | assert os.path.exists(model_checkpoint), f"{model_checkpoint=} does not exist!"
96 | logging.info(f"Loading model weights from {model_checkpoint}...")
97 | input_shapes, input_types = get_input_spec(1)
98 | if partitioner is not None:
99 | train_state_initializer = TrainStateInitializer(
100 | optimizer_def=None,
101 | init_fn=model.get_initial_variables,
102 | input_shapes=input_shapes,
103 | input_types=input_types,
104 | partitioner=partitioner
105 | )
106 | param_axes = train_state_initializer.train_state_axes.params
107 | params = LegacyCheckpointManager(
108 | restore_cfg=RestoreCheckpointConfig(model_checkpoint),
109 | train_state_shape=train_state_initializer.global_train_state_shape,
110 | partitioner=partitioner
111 | ).restore([model_checkpoint], RestoreCheckpointConfig(model_checkpoint)).params
112 | else:
113 | params = checkpoints.load_t5x_checkpoint(path=model_checkpoint)['target']
114 | params = jax.tree_util.tree_map(jnp.array, params)
115 | param_axes = None
116 |   logging.info(f"Done in {time.perf_counter()-t0:0.1f}s")
117 | return params, param_axes
118 |
119 |
120 | @gin.configurable()
121 | def init_wandb(name=None, group=None, entity=None, project=None):
122 | utils.create_learning_rate_scheduler() # Makes sure this is registered in `operative_config`
123 | config_str = gin.operative_config_str()
124 | logging.info(f"Init wandb with group={group} name={name}")
125 | wandb.init(
126 | group=group,
127 | name=name,
128 | entity=entity,
129 | project=project,
130 | force=True,
131 | notes=config_str
132 | )
133 |
134 |
135 | def transpose_lists(lsts):
136 | """Transpose a list of lists."""
137 | return [list(i) for i in zip(*lsts)]
138 |
139 |
140 | def list_of_dict_to_string(table: List[Dict[str, str]], filler="") -> str:
141 | keys = dict()
142 | for row in table:
143 | keys.update(row)
144 | raw_table = [list(keys)]
145 | raw_table += [[row.get(key, filler) for key in keys] for row in table]
146 | return table_string(raw_table)
147 |
148 |
149 | def table_string(table: List[List[str]]) -> str:
150 | """Table as listoflists to evenly spaces string"""
151 | # print while padding each column to the max column length
152 | if len(table) == 0:
153 | return ""
154 | col_lens = [0] * len(table[0])
155 | for row in table:
156 | for i, cell in enumerate(row):
157 | col_lens[i] = max(len(cell), col_lens[i])
158 |
159 | formats = ["{0:<%d}" % x for x in col_lens]
160 | out = []
161 | for row in table:
162 | out.append(" ".join(formats[i].format(row[i]) for i in range(len(row))))
163 | return "\n".join(out)
164 |
165 |
166 | class WandbMetricsLogger(seqio.Logger):
167 | """Log metrics to wandb"""
168 |
169 | def __call__(
170 | self,
171 | task_name: str,
172 | step: Optional[int],
173 | metrics: Mapping[str, Any],
174 | dataset: Optional[tf.data.Dataset],
175 | inferences: Optional[Mapping[str, Sequence[Any]]],
176 | targets: Optional[Sequence[Any]],
177 | ) -> None:
178 | if step is None:
179 |       raise ValueError("WandbMetricsLogger requires a step.")
180 |
181 | wandb_metrics = {}
182 | for metric_name, metric_value in metrics.items():
183 | if isinstance(metric_value, seqio.metrics.Scalar):
184 | wandb_metrics[f"inference/{task_name}/{metric_name}"] = metric_value.value
185 | else:
186 | logging.warning(
187 | "Skipping WandbLogging of non-serializable metric '%s' of type %s.",
188 | metric_name,
189 | type(metric_value),
190 | )
191 | wandb.log(wandb_metrics, step=step)
--------------------------------------------------------------------------------
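
Typical interactive use of `get_model`/`get_parameters` above, e.g. from a notebook (the checkpoint path is a placeholder):

    from t5x.examples.unified_io import utils

    # Configure the 'large' variant with only image/text inputs and text output.
    model = utils.get_model(
        "large",
        input_modalities=["image", "text"],
        target_modalities=["text"],
    )

    # Load weights; pass model_checkpoint=None to initialize from scratch instead.
    params, param_axes = utils.get_parameters(model, "/path/to/checkpoint")
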
/t5x/export.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Exports a T5X model.
16 |
17 |
18 | """
19 | import os
20 | from typing import Sequence
21 | from absl import logging
22 |
23 | # Set Linen to add profiling information when constructing Modules.
24 | # Must be set before flax imports.
25 | # pylint:disable=g-import-not-at-top
26 | os.environ.setdefault('FLAX_PROFILE', 'true')
27 |
28 | import jax
29 | from t5x import export_lib
30 |
31 | if __name__ == '__main__':
32 | # pylint:disable=g-import-not-at-top
33 | from absl import app
34 | from absl import flags
35 | import gin
36 | from t5x import gin_utils
37 | # pylint:enable=g-import-not-at-top
38 |
39 | FLAGS = flags.FLAGS
40 |
41 | jax.config.parse_flags_with_absl()
42 |
43 |
44 | flags.DEFINE_multi_string(
45 | 'gin_file',
46 | default=None,
47 | help='Path to gin configuration file. Multiple paths may be passed and '
48 | 'will be imported in the given order, with later configurations '
49 | 'overriding earlier ones.')
50 |
51 | flags.DEFINE_multi_string(
52 | 'gin_bindings',
53 | default=[],
54 | help='Individual gin bindings. Also used to integrate gin and XManager.')
55 |
56 | flags.DEFINE_list(
57 | 'gin_search_paths',
58 | default=['t5x/configs'],
59 | help='Comma-separated list of gin config path prefixes to be prepended '
60 |       'to suffixes given via `--gin_file`. If a file appears in multiple '
61 |       'paths, only the first prefix that produces a valid path for each '
62 |       'suffix will be used.')
63 |
64 | def main(argv: Sequence[str]):
65 | """Wrapper for g3pdb post mortems."""
66 | _main(argv)
67 |
68 | def _main(argv: Sequence[str]):
69 | """True main function."""
70 | if len(argv) > 1:
71 | raise app.UsageError('Too many command-line arguments.')
72 |
73 | save_with_gin = gin.configurable(export_lib.save)
74 |
75 | gin_utils.parse_gin_flags(FLAGS.gin_search_paths, FLAGS.gin_file,
76 | FLAGS.gin_bindings)
77 | logging.info('Creating inference function...')
78 | save_with_gin()
79 |
80 | gin_utils.run(main)
81 |
--------------------------------------------------------------------------------
/t5x/gin_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for using gin configurations with T5X binaries."""
16 | import os
17 | from typing import Optional, Sequence, Union
18 |
19 | from absl import app
20 | from absl import logging
21 | from clu import metric_writers
22 | import gin
23 | import jax
24 | import tensorflow as tf
25 |
26 |
27 |
28 | def parse_gin_flags(gin_search_paths: Sequence[str],
29 | gin_files: Sequence[str],
30 | gin_bindings: Sequence[str],
31 | skip_unknown: Union[bool, Sequence[str]] = False,
32 | finalize_config: bool = True):
33 | """Parses provided gin files override params.
34 |
35 | Args:
36 | gin_search_paths: paths that will be searched for gin files.
37 |     gin_files: paths to gin config files to be parsed. Files will be parsed in
38 |       order with conflicting settings being overridden by later files. Paths may
39 |       be relative to paths in `gin_search_paths`.
40 |     gin_bindings: individual gin bindings to be applied after the gin files are
41 |       parsed. Will be applied in order with conflicting settings being overridden
42 |       by later ones.
43 | skip_unknown: whether to ignore unknown bindings or raise an error (default
44 | behavior). Alternatively, a list of configurable names to skip if unknown.
45 | finalize_config: whether to finalize the config so that it cannot be
46 | modified (default behavior).
47 | """
48 | # We import t5.data here since it includes gin configurable functions commonly
49 | # used by task modules.
50 | # TODO(adarob): Strip gin from t5.data and remove this import.
51 | import t5.data # pylint:disable=unused-import,g-import-not-at-top
52 | # Register .gin file search paths with gin
53 | for gin_file_path in gin_search_paths:
54 | gin.add_config_file_search_path(gin_file_path)
55 |
56 |
57 | # Parse config files and bindings passed via flag.
58 | gin.parse_config_files_and_bindings(
59 | gin_files,
60 | gin_bindings,
61 | skip_unknown=skip_unknown,
62 | finalize_config=finalize_config)
63 | logging.info('Gin Configuration:')
64 | for line in gin.config_str().splitlines():
65 | logging.info('%s', line)
66 |
67 |
68 | def rewrite_gin_args(args: Sequence[str]) -> Sequence[str]:
69 | """Rewrite `--gin.NAME=VALUE` flags to `--gin_bindings=NAME=VALUE`."""
70 |
71 | def _rewrite_gin_arg(arg):
72 | if not arg.startswith('--gin.'):
73 | return arg
74 | if '=' not in arg:
75 | raise ValueError(
76 | "Gin bindings must be of the form '--gin.=', got: " +
77 | arg)
78 | # Strip '--gin.'
79 | arg = arg[6:]
80 | name, value = arg.split('=', maxsplit=1)
81 | r_arg = f'--gin_bindings={name} = {value}'
82 | print(f'Rewritten gin arg: {r_arg}')
83 | return r_arg
84 |
85 | return [_rewrite_gin_arg(arg) for arg in args]
86 |
87 |
88 | @gin.register
89 | def summarize_gin_config(model_dir: str,
90 | summary_writer: Optional[metric_writers.MetricWriter],
91 | step: int):
92 | """Writes gin config to the model dir and TensorBoard summary."""
93 | if jax.process_index() == 0:
94 | config_str = gin.config_str()
95 | tf.io.gfile.makedirs(model_dir)
96 |     # Write the config to a .gin file in the model dir.
97 | with tf.io.gfile.GFile(os.path.join(model_dir, 'config.gin'), 'w') as f:
98 | f.write(config_str)
99 |     # Include a raw dump of the config as a text summary.
100 | if summary_writer is not None:
101 | summary_writer.write_texts(step, {'config': gin.markdown(config_str)})
102 | summary_writer.flush()
103 |
104 |
105 | def run(main):
106 | """Wrapper for app.run that rewrites gin args before parsing."""
107 | app.run(
108 | main,
109 | flags_parser=lambda a: app.parse_flags_with_usage(rewrite_gin_args(a))) # pytype: disable=wrong-arg-types
110 |
111 |
112 | # ====================== Configurable Utility Functions ======================
113 |
114 |
115 | @gin.configurable
116 | def sum_fn(var1=gin.REQUIRED, var2=gin.REQUIRED):
117 | """sum function to use inside gin files."""
118 | return var1 + var2
119 |
120 |
121 | @gin.configurable
122 | def bool_fn(var1=gin.REQUIRED):
123 | """bool function to use inside gin files."""
124 | return bool(var1)
125 |
--------------------------------------------------------------------------------
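
`sum_fn` and `bool_fn` exist so gin files can do small computations via configurable references. A sketch of the pattern (the `show` function is illustrative, not part of the repo):

    import gin
    from t5x import gin_utils  # registers sum_fn / bool_fn with gin

    @gin.configurable
    def show(total=gin.REQUIRED):
        return total

    gin.parse_config("""
    sum_fn.var1 = 128
    sum_fn.var2 = 32
    show.total = @sum_fn()
    """)
    print(show())  # 160 -- @sum_fn() is evaluated when `show` is called
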
/t5x/gin_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for gin_utils."""
16 |
17 | from absl.testing import absltest
18 | from t5x import gin_utils
19 |
20 |
21 | class GinUtilsTest(absltest.TestCase):
22 |
23 | def test_rewrite_gin_args(self):
24 | test_args = [
25 | '--gin_file=path/to/file',
26 | 'gin.value=3',
27 | '--gin.value=3',
28 | '--gin.value="3"',
29 | '--gin.value=\'3\'',
30 | '--gin.tricky="key = value"',
31 | '--gin.dict={"foo": 4, "bar": "four"}',
32 | '--gin.gin=bar',
33 | '--gin.scope/foo=bar',
34 | ]
35 | expected_args = [
36 | '--gin_file=path/to/file',
37 | 'gin.value=3',
38 | '--gin_bindings=value = 3',
39 | '--gin_bindings=value = "3"',
40 | '--gin_bindings=value = \'3\'',
41 | '--gin_bindings=tricky = "key = value"',
42 | '--gin_bindings=dict = {"foo": 4, "bar": "four"}',
43 | '--gin_bindings=gin = bar',
44 | '--gin_bindings=scope/foo = bar',
45 | ]
46 | self.assertSequenceEqual(
47 | gin_utils.rewrite_gin_args(test_args), expected_args)
48 |
49 | def test_rewrite_gin_args_malformed(self):
50 | test_args = ['--gin.value=3', '--gin.test']
51 | with self.assertRaisesWithLiteralMatch(
52 | ValueError,
53 | "Gin bindings must be of the form '--gin.=', got: "
54 | '--gin.test'):
55 | gin_utils.rewrite_gin_args(test_args)
56 |
57 |
58 | if __name__ == '__main__':
59 | absltest.main()
60 |
--------------------------------------------------------------------------------
/t5x/losses_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for t5x.losses."""
16 |
17 | from absl.testing import absltest
18 | import jax
19 | import jax.numpy as jnp
20 | import numpy as np
21 | from t5x import losses
22 |
23 |
24 | class LossTest(absltest.TestCase):
25 |
26 | def test_xent(self):
27 |
28 | def lossfn(logits, targets, weights):
29 | loss, z_loss, weight_sum = losses.compute_weighted_cross_entropy(
30 | logits,
31 | targets,
32 | weights,
33 | label_smoothing=0.1,
34 | z_loss=0.1,
35 | loss_normalizing_factor=0.1)
36 | return loss, (z_loss, weight_sum)
37 |
38 | batch_size = 2
39 | length = 4
40 | vocab_size = 8
41 | logits = np.random.normal(size=(batch_size, length,
42 | vocab_size)).astype(np.float32)
43 | targets = np.random.randint(0, vocab_size, size=(batch_size, length))
44 | weights = np.ones_like(targets)
45 | out = jax.jit(jax.value_and_grad(lossfn, has_aux=True))(logits, targets,
46 | weights)
47 | (loss, (z_loss, weight_sum)), dlogits = out
48 | # Just a smoke test for now
49 | # TODO(t5x): Expand test
50 | print(jax.device_get(((loss, (z_loss, weight_sum)), dlogits)))
51 |
52 |
53 | class SpecialLossNormalizingFactorTest(absltest.TestCase):
54 |
55 | def test_num_real_target_tokens(self):
56 | batch = {
57 | 'decoder_target_tokens':
58 | jnp.asarray([[1, 2, 3, 4, 0], [5, 6, 0, 0, 0]], jnp.int32)
59 | }
60 |
61 | (output_lnf,
62 | output_loss_weights) = losses.get_loss_normalizing_factor_and_weights(
63 | loss_normalizing_factor=losses.SpecialLossNormalizingFactor
64 | .NUM_REAL_TARGET_TOKENS,
65 | batch=batch)
66 |
67 | np.testing.assert_allclose(output_lnf, 6.0, rtol=1e-3)
68 | np.testing.assert_allclose(
69 | output_loss_weights,
70 | np.array([[1.0, 1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0, 0.0]],
71 | dtype=np.float32),
72 | rtol=1e-3)
73 |
74 | def test_num_total_target_tokens(self):
75 | batch = {
76 | 'decoder_target_tokens':
77 | jnp.asarray([[1, 2, 3, 4, 0], [5, 6, 0, 0, 0]], jnp.int32)
78 | }
79 |
80 | (output_lnf,
81 | output_loss_weights) = losses.get_loss_normalizing_factor_and_weights(
82 | loss_normalizing_factor=losses.SpecialLossNormalizingFactor
83 | .NUM_TOTAL_TARGET_TOKENS,
84 | batch=batch)
85 |
86 | np.testing.assert_allclose(output_lnf, 10.0, rtol=1e-3)
87 | np.testing.assert_allclose(
88 | output_loss_weights,
89 | np.array([[1.0, 1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0, 0.0]],
90 | dtype=np.float32),
91 | rtol=1e-3)
92 |
93 | def test_average_per_sequence(self):
94 | batch = {
95 | 'decoder_target_tokens':
96 | jnp.asarray([[1, 2, 3, 4, 0], [5, 6, 0, 0, 0]], jnp.int32)
97 | }
98 |
99 | (output_lnf,
100 | output_loss_weights) = losses.get_loss_normalizing_factor_and_weights(
101 | loss_normalizing_factor=losses.SpecialLossNormalizingFactor
102 | .AVERAGE_PER_SEQUENCE,
103 | batch=batch)
104 |
105 | np.testing.assert_allclose(output_lnf, 2.0, rtol=1e-3)
106 | np.testing.assert_allclose(
107 | output_loss_weights,
108 | jnp.asarray([[0.25, 0.25, 0.25, 0.25, 0.0], [0.5, 0.5, 0.0, 0.0, 0.0]],
109 | jnp.float32),
110 | rtol=1e-3)
111 |
112 | def test_average_per_sequence_with_weights(self):
113 | batch = {
114 | 'decoder_target_tokens':
115 | jnp.asarray([[1, 2, 3, 4, 0], [5, 6, 0, 0, 0]], jnp.int32),
116 | 'decoder_loss_weights':
117 | jnp.asarray([[0.5, 1.0, 0.25, 2.0, 0.0], [1.0, 1.0, 0.0, 0.0, 0.0]],
118 | jnp.float32)
119 | }
120 |
121 | (output_lnf,
122 | output_loss_weights) = losses.get_loss_normalizing_factor_and_weights(
123 | loss_normalizing_factor=losses.SpecialLossNormalizingFactor
124 | .AVERAGE_PER_SEQUENCE,
125 | batch=batch)
126 |
127 | np.testing.assert_allclose(output_lnf, 2.0, rtol=1e-3)
128 | np.testing.assert_allclose(
129 | output_loss_weights,
130 | jnp.asarray([[0.5 / 3.75, 1.0 / 3.75, 0.25 / 3.75, 2.0 / 3.75, 0.0],
131 | [1.0 / 2.0, 1.0 / 2.0, 0.0, 0.0, 0.0]], jnp.float32),
132 | rtol=1e-3)
133 |
134 | def test_sum_weights_per_segment(self):
135 | weights = jnp.asarray(
136 | [[0.5, 1.0, 0.25, 2.0, 1.5], [1.0, 2.0, 3.0, 4.0, 5.0]], jnp.float32)
137 | positions = jnp.asarray([[0, 1, 2, 0, 0], [0, 0, 1, 0, 0]])
138 | segment_ids = jnp.asarray([[1, 1, 1, 2, 3], [1, 2, 2, 3, 0]])
139 |
140 | norm_vec = losses._sum_weights_per_segment(positions, segment_ids, weights)
141 |
142 | np.testing.assert_allclose(
143 | norm_vec,
144 | jnp.asarray([[1.75, 1.75, 1.75, 2.0, 1.5], [1.0, 5.0, 5.0, 4.0, 0.0]],
145 | jnp.float32),
146 | rtol=1e-3)
147 |
148 | def test_average_per_sequence_with_weights_with_packing(self):
149 | batch = {
150 | 'decoder_target_tokens':
151 | jnp.asarray([[1, 2, 3, 4, 5], [5, 6, 7, 8, 0]], jnp.int32),
152 | 'decoder_loss_weights':
153 | jnp.asarray([[0.5, 1.0, 0.25, 2.0, 1.5], [1.0, 2.0, 3.0, 4.0, 5.0]],
154 | jnp.float32),
155 | 'decoder_positions':
156 | jnp.asarray([[0, 1, 2, 0, 0], [0, 0, 1, 0, 0]]),
157 | 'decoder_segment_ids':
158 | jnp.asarray([[1, 1, 1, 2, 3], [1, 2, 2, 3, 0]])
159 | }
160 |
161 | (output_lnf,
162 | output_loss_weights) = losses.get_loss_normalizing_factor_and_weights(
163 | loss_normalizing_factor=losses.SpecialLossNormalizingFactor
164 | .AVERAGE_PER_SEQUENCE,
165 | batch=batch)
166 |
167 | np.testing.assert_allclose(output_lnf, 6.0, rtol=1e-3)
168 | np.testing.assert_allclose(
169 | output_loss_weights,
170 | jnp.asarray(
171 | [[0.5 / 1.75, 1.0 / 1.75, 0.25 / 1.75, 2.0 / 2.0, 1.5 / 1.5],
172 | [1.0 / 1.0, 2.0 / 5.0, 3.0 / 5.0, 4.0 / 4.0, 0.0]], jnp.float32),
173 | rtol=1e-3)
174 |
175 | if __name__ == '__main__':
176 | absltest.main()
177 |
--------------------------------------------------------------------------------
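
The expected values in the AVERAGE_PER_SEQUENCE tests above follow one rule: each sequence's weights are normalized by their per-sequence sum, and the loss normalizing factor is the number of sequences. A quick numpy check of the numbers used in the tests:

    import numpy as np

    weights = np.array([[0.5, 1.0, 0.25, 2.0, 0.0],
                        [1.0, 1.0, 0.0, 0.0, 0.0]], np.float32)
    per_seq = weights.sum(axis=-1, keepdims=True)  # [[3.75], [2.0]]
    print(weights / per_seq)   # matches the expected loss weights
    print(weights.shape[0])    # 2 sequences -> loss normalizing factor
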
/t5x/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""The main entrance for running any of the T5X supported binaries.
16 |
17 | Currently this includes train/infer/eval/precompile.
18 |
19 | Example Local (CPU) Pretrain Gin usage
20 |
21 | python -m t5x.main \
22 | --gin_file=t5x/examples/t5/t5_1_1/tiny.gin \
23 | --gin_file=t5x/configs/runs/pretrain.gin \
24 | --gin.MODEL_DIR=\"/tmp/t5x_pretrain\" \
25 | --gin.TRAIN_STEPS=10 \
26 | --gin.MIXTURE_OR_TASK_NAME=\"c4_v220_span_corruption\" \
27 | --gin.MIXTURE_OR_TASK_MODULE=\"t5.data.mixtures\" \
28 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 128, 'targets': 30}" \
29 | --gin.DROPOUT_RATE=0.1 \
30 | --run_mode=train \
31 | --logtostderr
32 | """
33 | import concurrent.futures # pylint:disable=unused-import
34 | import enum
35 | import importlib
36 | import os
37 | import sys
38 | from typing import Optional, Sequence
39 |
40 | from absl import app
41 | from absl import flags
42 | from absl import logging
43 |
44 | import gin
45 | import jax
46 | import seqio
47 |
48 | from t5x import gin_utils
49 | from t5x import utils
50 |
51 |
52 | @enum.unique
53 | class RunMode(enum.Enum):
54 | """All the running mode possible in T5X."""
55 | TRAIN = 'train'
56 | EVAL = 'eval'
57 | INFER = 'infer'
58 | PRECOMPILE = 'precompile'
59 | EXPORT = 'export'
60 |
61 |
62 | _GIN_FILE = flags.DEFINE_multi_string(
63 | 'gin_file',
64 | default=None,
65 | help='Path to gin configuration file. Multiple paths may be passed and '
66 | 'will be imported in the given order, with later configurations '
67 | 'overriding earlier ones.')
68 |
69 | _GIN_BINDINGS = flags.DEFINE_multi_string(
70 | 'gin_bindings', default=[], help='Individual gin bindings.')
71 |
72 | _GIN_SEARCH_PATHS = flags.DEFINE_list(
73 | 'gin_search_paths',
74 | default=['.'],
75 | help='Comma-separated list of gin config path prefixes to be prepended '
76 |     'to suffixes given via `--gin_file`. If a file appears in multiple '
77 |     'paths, only the first prefix that produces a valid path for each '
78 |     'suffix will be used.')
79 |
80 | _RUN_MODE = flags.DEFINE_enum_class(
81 | 'run_mode',
82 | default=None,
83 | enum_class=RunMode,
84 | help='The mode to run T5X under')
85 |
86 | _TFDS_DATA_DIR = flags.DEFINE_string(
87 | 'tfds_data_dir', None,
88 | 'If set, this directory will be used to store datasets prepared by '
89 | 'TensorFlow Datasets that are not available in the public TFDS GCS '
90 | 'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
91 | 'all `Task`s.')
92 |
93 | _DRY_RUN = flags.DEFINE_bool(
94 | 'dry_run', False,
95 |     'If set, does not start the function but still loads and logs the config.')
96 |
97 |
98 | FLAGS = flags.FLAGS
99 |
100 | # Automatically search for gin files relative to the T5X package.
101 | _DEFAULT_GIN_SEARCH_PATHS = [
102 | os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
103 | ]
104 |
105 | # Mapping of run_mode to the attribute used in the imported module, e.g.
106 | # {EVAL : 'evaluate'} will load 'evaluate' in eval.py.
107 | _ATTR_BY_RUN_MODE = {
108 | RunMode.TRAIN: 'train',
109 | RunMode.EVAL: 'evaluate',
110 | RunMode.INFER: 'infer',
111 | RunMode.PRECOMPILE: 'precompile',
112 | RunMode.EXPORT: 'save',
113 | }
114 |
115 |
116 | main_module = sys.modules[__name__]
117 |
118 |
119 | def main(argv: Sequence[str]):
120 | if len(argv) > 1:
121 | raise app.UsageError('Too many command-line arguments.')
122 |
123 | if _RUN_MODE.value is None:
124 | raise ValueError("'run_mode' flag must be specified when using main.py.")
125 | # Dynamic import the modules based on run_mode, e.g.
126 | # If _RUN_MODE.value is 'train', below is equivalent of doing:
127 | # from t5x import train
128 | # train = train.train
129 |
130 | # _RUN_MODE can never be None after this point.
131 | # pytype: disable=attribute-error
132 | lib_name = _RUN_MODE.value.name.lower()
133 | import_attr = _ATTR_BY_RUN_MODE[_RUN_MODE.value]
134 | # pytype: enable=attribute-error
135 |
136 | parent_module = 't5x'
137 |
138 |
139 | module_to_import = f'{parent_module}.{lib_name}'
140 |
141 | logging.info('Dynamically importing : %s', module_to_import)
142 | imported_lib = importlib.import_module(module_to_import)
143 |
144 | entry_func = getattr(imported_lib, import_attr)
145 | setattr(main_module, import_attr, entry_func)
146 |
147 |
148 | if _TFDS_DATA_DIR.value is not None:
149 | seqio.set_tfds_data_dir_override(_TFDS_DATA_DIR.value)
150 |
151 |
152 | # Register function explicitly under __main__ module, to maintain backward
153 |   # compatibility of existing '__main__' module references.
154 | gin.register(entry_func, '__main__')
155 | if _GIN_SEARCH_PATHS.value != ['.']:
156 | logging.warning(
157 | 'Using absolute paths for the gin files is strongly recommended.')
158 |
159 | # User-provided gin paths take precedence if relative paths conflict.
160 | gin_utils.parse_gin_flags(_GIN_SEARCH_PATHS.value + _DEFAULT_GIN_SEARCH_PATHS,
161 | _GIN_FILE.value, _GIN_BINDINGS.value)
162 |
163 | if _DRY_RUN.value:
164 | return
165 |
166 | run_with_gin = gin.get_configurable(entry_func)
167 |
168 | run_with_gin()
169 |
170 |
171 |
172 | def _flags_parser(args: Sequence[str]) -> Sequence[str]:
173 | """Flag parser.
174 |
175 | See absl.app.parse_flags_with_usage and absl.app.main(..., flags_parser).
176 |
177 | Args:
178 | args: All command line arguments.
179 |
180 | Returns:
181 | [str], a non-empty list of remaining command line arguments after parsing
182 | flags, including program name.
183 | """
184 | return app.parse_flags_with_usage(list(gin_utils.rewrite_gin_args(args)))
185 |
186 |
187 | if __name__ == '__main__':
188 | jax.config.parse_flags_with_absl()
189 | app.run(main, flags_parser=_flags_parser)
190 |
--------------------------------------------------------------------------------
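
Stripped of flag handling, the run-mode dispatch in `main` above reduces to two standard-library calls:

    import importlib

    run_mode = "train"  # from --run_mode
    attr = {"train": "train", "eval": "evaluate", "infer": "infer",
            "precompile": "precompile", "export": "save"}[run_mode]
    module = importlib.import_module(f"t5x.{run_mode}")  # e.g. t5x.train
    entry_func = getattr(module, attr)                   # e.g. t5x.train.train
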
/t5x/metrics_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for clu.metrics."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | import numpy as np
22 | from t5x import metrics
23 |
24 |
25 | class MetricsTest(parameterized.TestCase):
26 |
27 | @parameterized.named_parameters(
28 | ("0d_values", 2., 2.), ("1d_values", [1, 2, 3], 6.),
29 | ("2d_values", [[1, 2], [2, 3], [3, 4]], 15.),
30 | ("3d_values", [[[1, 2], [2, 3]], [[2, 1], [3, 4]], [[3, 1], [4, 1]]], 27.)
31 | )
32 | def test_sum(self, values, expected_result):
33 | self.assertAlmostEqual(
34 | metrics.Sum.from_model_output(values).compute(), expected_result)
35 |
36 | def test_time_rate(self):
37 | value = np.array([3.])
38 | duration = 2.
39 | metric = metrics.TimeRate.from_model_output(value).replace_duration(
40 | duration)
41 | self.assertAlmostEqual(metric.compute(), value / duration)
42 |
43 | def test_time_rate_unset_duration(self):
44 | value = jnp.array([3.])
45 | metric = metrics.TimeRate.from_model_output(value)
46 | with self.assertRaises(ValueError):
47 | metric.compute()
48 |
49 | def test_time_rate_sets_duration_inside_jitted_fn(self):
50 |
51 | @jax.jit
52 | def fn():
53 | value = jnp.array([3.])
54 | duration = 2.
55 | metric = metrics.TimeRate.from_model_output(value).replace_duration(
56 | duration)
57 | return metric
58 |
59 | with self.assertRaises(ValueError):
60 | fn()
61 |
62 | def test_time(self):
63 | duration = 2.
64 | metric = metrics.Time().replace_duration(duration)
65 | self.assertAlmostEqual(metric.compute(), duration)
66 |
67 | def test_time_unset_duration(self):
68 | metric = metrics.Time()
69 | with self.assertRaises(ValueError):
70 | metric.compute()
71 |
72 | @parameterized.named_parameters(
73 | ("0d_values", 2., 2.),
74 | ("1d_values", [1, 2, 3], 6.),
75 | )
76 | def test_average_per_step(self, values, expected_result):
77 | a = metrics.AveragePerStep.from_model_output(values)
78 | m = metrics.set_step_metrics_num_steps({"a": a}, 1)
79 | self.assertAlmostEqual(m["a"].compute(), expected_result)
80 |
81 | steps = 5
82 | b = metrics.AveragePerStep.from_model_output(values, steps=steps)
83 | m = metrics.set_step_metrics_num_steps({"b": b}, steps)
84 | self.assertAlmostEqual(m["b"].compute(), expected_result / steps)
85 |
86 | def test_steps_per_time(self):
87 | steps = 8.
88 | duration = 2.
89 | metric = metrics.StepsPerTime.from_model_output(
90 | steps=steps).replace_duration(duration)
91 | metrics_dict = metrics.set_step_metrics_num_steps({"metric": metric}, steps)
92 | self.assertAlmostEqual(metrics_dict["metric"].compute(), steps / duration)
93 |
94 |
95 | if __name__ == "__main__":
96 | absltest.main()
97 |
--------------------------------------------------------------------------------
/t5x/precompile.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Precompile and generates HLO from TPU metadata backend.
16 |
17 | The TPU Metadata backend is a TPU backend without real TPU devices that still
18 | supports any TPU topology. It allows work that doesn't require real TPUs, e.g.
19 | compiling/lowering an HLO graph, to run as if devices were present.
20 | 
21 | Precompilation defaults to the CPU backend for device array placement, since
22 | the metadata backend does not allocate memory.
23 | 
24 | The pjit function is pinned to the available TPU Metadata backend to obtain a
25 | proper lowering under the TPU mesh.
26 |
27 | """
28 |
29 | import os
30 | from typing import Callable, Optional
31 |
32 | import clu.data
33 |
34 | import jax
35 | from jax import random
36 | import numpy as np
37 | import t5.data.mixtures # pylint:disable=unused-import
38 | from t5x import models
39 | from t5x import partitioning
40 | from t5x import trainer as trainer_lib
41 | from t5x import utils
42 | import tensorflow as tf
43 |
44 |
45 |
46 | def precompile(
47 | *,
48 | model: models.BaseTransformerModel,
49 | train_dataset_cfg: utils.DatasetConfig,
50 | partitioner: partitioning.BasePartitioner,
51 | model_dir: str,
52 | random_seed: Optional[int],
53 | get_dataset_fn: utils.GetDatasetCallable = utils.get_dataset,
54 | verify_matching_vocabs_fn: Optional[
55 | Callable[[utils.DatasetConfig, models.BaseTransformerModel],
56 | None]] = utils.verify_matching_vocabs,
57 | ):
58 | """Compiles and dump the HLO to model dir, with HLO text dumps."""
59 | rng = random.PRNGKey(random_seed or 42)
60 | _, trainer_rng = random.split(rng, 2)
61 |
62 | # TODO(hthu): Find a better way of getting dataset shapes instead of actually
63 |   # reading the dataset and iterating on it.
64 | data_layout = partitioner.get_data_layout(train_dataset_cfg.batch_size)
65 | ds_shard_id = data_layout.shard_id
66 | num_ds_shards = data_layout.num_shards
67 |
68 | if verify_matching_vocabs_fn is not None:
69 | verify_matching_vocabs_fn(train_dataset_cfg, model)
70 |
71 | train_iter = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards,
72 | model.FEATURE_CONVERTER_CLS)
73 | if isinstance(train_iter, tf.data.Dataset):
74 | train_iter = clu.data.TfDatasetIterator(train_iter, checkpoint=True)
75 | elif not isinstance(train_iter, clu.data.dataset_iterator.DatasetIterator):
76 | raise ValueError(
77 | f'get_dataset_fn returned unsupported type {type(train_iter)}.')
78 |
79 | # Need to use full batch size.
80 | input_shapes = jax.tree_map(lambda x: (data_layout.batch_size, *x.shape[1:]),
81 | train_iter.element_spec)
82 | input_types = jax.tree_map(lambda x: x.dtype, train_iter.element_spec)
83 | dummy_batch = jax.tree_map(lambda x: np.ones(x.shape, x.dtype),
84 | train_iter.element_spec)
85 |
86 | # Compiling does not care about loading real weights.
87 | train_state_initializer = utils.TrainStateInitializer(
88 | optimizer_def=model.optimizer_def,
89 | init_fn=model.get_initial_variables,
90 | input_shapes=input_shapes,
91 | input_types=input_types,
92 | partitioner=partitioner)
93 | train_state_shape = train_state_initializer.global_train_state_shape
94 | train_state_axes = train_state_initializer.train_state_axes
95 |
96 | def train_step(train_state, batch):
97 | return trainer_lib.train_with_lr(
98 | train_state,
99 | batch,
100 | learning_rate=1e-3,
101 | dropout_rng=trainer_rng,
102 | model=model,
103 | num_microbatches=None,
104 | weight_metrics_computer=None)
105 |
106 | partitioned_step = partitioner.partition(
107 | train_step,
108 | in_axis_resources=(train_state_axes, partitioning.PartitionSpec('data',)),
109 | out_axis_resources=(train_state_axes, None),
110 | donate_argnums=(0,))
111 |
112 | # PartitionedTrainCallable has lower() defined but isn't exposed in pytype.
113 | # TODO(hthu): Explicitly expose the lower() interface.
114 | # pytype: disable=attribute-error
115 | lowered = partitioned_step.lower(train_state_shape, dummy_batch)
116 | # pytype: enable=attribute-error
117 |
118 |
119 | # TODO(hthu): Make this a proper library without writing files by default.
120 | tf.io.gfile.makedirs(model_dir)
121 | with tf.io.gfile.GFile(
122 | os.path.join(model_dir, 'lowered_hlo_pre_optimization'), 'w') as f:
123 | f.write(lowered.compiler_ir(dialect='hlo').as_serialized_hlo_module_proto())
124 | compiled = lowered.compile()
125 | output_path = os.path.join(model_dir, 'lowered_hlo_post_optimization')
126 | with tf.io.gfile.GFile(output_path, 'w') as f:
127 | f.write(compiled.compiler_ir()[0].as_serialized_hlo_module_proto())
128 | with tf.io.gfile.GFile(os.path.join(model_dir, 'assignment'), 'wb') as f:
129 | np.save(f, partitioner.mesh.device_ids)
130 |
--------------------------------------------------------------------------------
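
The lower-then-compile pattern above works on any jitted function; a minimal standalone sketch of inspecting HLO the same way, with no TPU required (API names per the JAX version pinned here; newer releases expose lowered.as_text() instead):

    import jax
    import jax.numpy as jnp

    def loss(x):
        return jnp.sum(jnp.sin(x) ** 2)

    lowered = jax.jit(loss).lower(jnp.ones((8,)))
    hlo = lowered.compiler_ir(dialect='hlo')   # XlaComputation
    print(hlo.as_hlo_text()[:300])             # pre-optimization HLO
    compiled = lowered.compile()               # backend-optimized executable
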
/t5x/testdata/mtf_tiny_t5/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "model.ckpt-0"
2 | all_model_checkpoint_paths: "model.ckpt-0"
3 |
--------------------------------------------------------------------------------
/t5x/testdata/mtf_tiny_t5/model.ckpt-0.data-00000-of-00002:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/t5x/testdata/mtf_tiny_t5/model.ckpt-0.data-00001-of-00002:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/unified-io-2/502ac4d81239f82c891a9f412b000c3c8d4e2946/t5x/testdata/mtf_tiny_t5/model.ckpt-0.data-00001-of-00002
--------------------------------------------------------------------------------
/t5x/testdata/mtf_tiny_t5/model.ckpt-0.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/unified-io-2/502ac4d81239f82c891a9f412b000c3c8d4e2946/t5x/testdata/mtf_tiny_t5/model.ckpt-0.index
--------------------------------------------------------------------------------
/t5x/testdata/pinned_ckpt_dir/PINNED:
--------------------------------------------------------------------------------
1 | 1
--------------------------------------------------------------------------------
/t5x/testdata/test_t5_tiny.checkpoint_0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/unified-io-2/502ac4d81239f82c891a9f412b000c3c8d4e2946/t5x/testdata/test_t5_tiny.checkpoint_0
--------------------------------------------------------------------------------
/t5x/version.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Separate file for storing the current version of T5X.
16 |
17 | Stored in a separate file so that setup.py can reference the version without
18 | pulling in all the dependencies in __init__.py.
19 | """
20 | __version__ = '0.0.0'
21 |
--------------------------------------------------------------------------------