├── .gitignore ├── LICENSE ├── README.md ├── create_data ├── README.md ├── __init__.py ├── tfdatasets │ ├── __init__.py │ ├── alpaca │ │ └── build.py │ ├── coco_all │ │ ├── __init__.py │ │ └── build.py │ └── flan │ │ └── build.py └── utils.py ├── demo.ipynb ├── demo ├── __init__.py ├── hifigan │ ├── README.md │ ├── checkpoints │ │ ├── config.json │ │ └── g_00930000 │ ├── config_v1.json │ ├── config_v2.json │ ├── config_v3.json │ ├── env.py │ ├── inference.py │ ├── inference_e2e.py │ ├── meldataset.py │ ├── models.py │ ├── requirements.txt │ ├── resample.py │ ├── test_mel.py │ ├── test_mel2.py │ ├── train copy.py │ ├── train.py │ └── utils.py └── utils │ ├── __init__.py │ ├── audio_utils.py │ └── video_utils.py ├── metadata └── coco │ └── coco_class_name_2017.json ├── setup.py └── t5x ├── __init__.py ├── adafactor.py ├── adafactor_test.py ├── binary_search.py ├── binary_search_test.py ├── checkpoint_importer.py ├── checkpoint_importer_test.py ├── checkpoint_importer_vqgan.py ├── checkpoint_utils.py ├── checkpoint_utils_test.py ├── checkpoints.py ├── checkpoints_test.py ├── configs ├── __init__.py └── runs │ ├── __init__.py │ ├── debug.gin │ ├── eval.gin │ ├── export.gin │ ├── export_seqio.gin │ ├── finetune.gin │ ├── infer.gin │ ├── infer_from_tfexample_file.gin │ ├── loss.gin │ ├── multitask.gin │ ├── precompile.gin │ ├── pretrain.gin │ └── vit_vqgan.gin ├── decoding.py ├── eval.py ├── examples ├── __init__.py └── unified_io │ ├── __init__.py │ ├── audio_encoder.py │ ├── audio_vqgan.py │ ├── aux_fns.py │ ├── config.py │ ├── data │ ├── __init__.py │ ├── data_utils.py │ ├── mixtures.py │ ├── nlp_instruction_following.py │ ├── postprocessing.py │ ├── preprocessing.py │ ├── prompt_definition.py │ ├── prompt_dict.py │ ├── tasks.py │ └── visualization_utils.py │ ├── decoding.py │ ├── evaluator.py │ ├── image_encoder.py │ ├── image_vqgan.py │ ├── input_modalities.py │ ├── layers.py │ ├── metrics │ ├── grit_keypoint.py │ ├── grit_localization.py │ ├── grit_normal.py │ ├── grit_segmentation.py │ ├── grit_vqa.py │ ├── metrics.py │ ├── rle.py │ └── utils.py │ ├── modality_processing.py │ ├── model_test.py │ ├── models.py │ ├── network.py │ ├── network_test.py │ ├── packing.py │ ├── packing_test.py │ ├── perceiver.py │ ├── scripts │ ├── __init__.py │ └── dataset_visualize.py │ ├── seq_features.py │ ├── seq_features_test.py │ ├── t5_1_1 │ ├── base.gin │ ├── eval │ │ └── vision_language.gin │ ├── finetune │ │ └── refexp.gin │ ├── large.gin │ ├── tiny.gin │ ├── xl.gin │ └── xxl.gin │ ├── target_modalities.py │ ├── test_utils.py │ ├── utils.py │ └── vocabularies.py ├── export.py ├── export_lib.py ├── gin_utils.py ├── gin_utils_test.py ├── losses.py ├── losses_test.py ├── main.py ├── metrics.py ├── metrics_test.py ├── models.py ├── models_test.py ├── optimizers.py ├── optimizers_test.py ├── partitioning.py ├── partitioning_test.py ├── precompile.py ├── state_utils.py ├── state_utils_test.py ├── test_utils.py ├── testdata ├── mtf_tiny_t5 │ ├── checkpoint │ ├── graph.pbtxt │ ├── model-info.txt │ ├── model.ckpt-0.data-00000-of-00002 │ ├── model.ckpt-0.data-00001-of-00002 │ ├── model.ckpt-0.index │ ├── model.ckpt-0.meta │ └── operative_config.gin ├── pinned_ckpt_dir │ └── PINNED └── test_t5_tiny.checkpoint_0 ├── train.py ├── train_eval.py ├── train_state.py ├── train_state_test.py ├── trainer.py ├── trainer_test.py ├── utils.py ├── utils_test.py └── version.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | 
__pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | *.pyc 163 | .vscode/ 164 | wandb/* 165 | -------------------------------------------------------------------------------- /create_data/README.md: -------------------------------------------------------------------------------- 1 | # Unified-IO-2 Datasets 2 | 3 | This directory contains code to build datasets in a format 4 | that can be consumed by UnifiedIO. 5 | 6 | Some of these scripts install additional dependencies: 7 | 8 | ``` 9 | python3 -m pip install -e '.[data]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -f https://storage.googleapis.com/jax-releases/jax_releases.html 10 | ``` -------------------------------------------------------------------------------- /create_data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/unified-io-2/502ac4d81239f82c891a9f412b000c3c8d4e2946/create_data/__init__.py -------------------------------------------------------------------------------- /create_data/tfdatasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/unified-io-2/502ac4d81239f82c891a9f412b000c3c8d4e2946/create_data/tfdatasets/__init__.py -------------------------------------------------------------------------------- /create_data/tfdatasets/alpaca/build.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import tensorflow as tf 4 | import tensorflow_datasets as tfds 5 | 6 | 7 | class Alpaca(tfds.core.GeneratorBasedBuilder): 8 | VERSION = tfds.core.Version('1.0.0') 9 | 10 | def _info(self) -> tfds.core.DatasetInfo: 11 | features = tfds.features.FeaturesDict(dict( 12 | example_num=tfds.features.Tensor(shape=(), dtype=tf.int32), 13 | instruction=tfds.features.Tensor(shape=(), dtype=tf.string), 14 | input=tfds.features.Tensor(shape=(), dtype=tf.string), 15 | output=tfds.features.Tensor(shape=(), dtype=tf.string), 16 | )) 17 | return tfds.core.DatasetInfo(builder=self, features=features) 18 | 19 | def _split_generators(self, dl_manager: tfds.download.DownloadManager): 20 | return {'train': self._generate_examples()} 21 | 22 | def _generate_examples(self): 23 | import datasets 24 | ds = 
datasets.load_dataset("yahma/alpaca-cleaned")["train"] 25 | for ix, ex in enumerate(ds): 26 | ex["example_num"] = ix 27 | yield ix, ex 28 | 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("data_dir") 33 | args = parser.parse_args() 34 | 35 | builder = Alpaca(data_dir=args.data_dir) 36 | builder.download_and_prepare() 37 | 38 | 39 | if __name__ == '__main__': 40 | main() -------------------------------------------------------------------------------- /create_data/tfdatasets/coco_all/__init__.py: -------------------------------------------------------------------------------- 1 | """coco_all dataset.""" 2 | 3 | from .build import CocoAll 4 | -------------------------------------------------------------------------------- /create_data/tfdatasets/flan/build.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from collections import defaultdict 4 | from os.path import join 5 | import tensorflow as tf 6 | import tensorflow_datasets as tfds 7 | import numpy as np 8 | from PIL import Image 9 | 10 | 11 | MAPPING = { 12 | "Flan2021": "conceptofmind/flan2021_submix_original", 13 | "T0": "conceptofmind/t0_submix_original", 14 | "NIv2": "conceptofmind/niv2_submix_original", 15 | "CoT": "conceptofmind/cot_submix_original", 16 | "Dialog": "conceptofmind/dialog_submix_original", 17 | } 18 | 19 | 20 | class FLAN(tfds.core.GeneratorBasedBuilder): 21 | VERSION = tfds.core.Version('1.0.0') 22 | 23 | def __init__(self, src, **kwargs): 24 | self.src = src 25 | self.__class__.name = f"FLANv2-{src}" 26 | super().__init__(**kwargs) 27 | 28 | def _info(self) -> tfds.core.DatasetInfo: 29 | features = tfds.features.FeaturesDict(dict( 30 | example_num=tfds.features.Tensor(shape=(), dtype=tf.int32), 31 | inputs=tfds.features.Tensor(shape=(), dtype=tf.string), 32 | targets=tfds.features.Tensor(shape=(), dtype=tf.string), 33 | task_source=tfds.features.Tensor(shape=(), dtype=tf.string), 34 | task_name=tfds.features.Tensor(shape=(), dtype=tf.string), 35 | template_type=tfds.features.Tensor(shape=(), dtype=tf.string), 36 | )) 37 | return tfds.core.DatasetInfo( 38 | builder=self, 39 | features=features, 40 | ) 41 | 42 | def _split_generators(self, dl_manager: tfds.download.DownloadManager): 43 | return { 44 | 'train': self._generate_examples(), 45 | } 46 | 47 | def _generate_examples(self): 48 | import datasets 49 | ds = datasets.load_dataset(MAPPING[self.src])["train"] 50 | for ix, ex in enumerate(ds): 51 | example = dict(example_num=ix) 52 | example.update({k: ex[k] for k in 53 | ["inputs", "targets", "task_source", "task_name", "template_type"]}) 54 | yield ix, example 55 | 56 | 57 | def main(): 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument("kind") 60 | parser.add_argument("data_dir") 61 | args = parser.parse_args() 62 | 63 | builder = FLAN(args.kind, data_dir=args.data_dir) 64 | builder.download_and_prepare() 65 | 66 | 67 | if __name__ == '__main__': 68 | main() -------------------------------------------------------------------------------- /demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/unified-io-2/502ac4d81239f82c891a9f412b000c3c8d4e2946/demo/__init__.py -------------------------------------------------------------------------------- /demo/hifigan/README.md: -------------------------------------------------------------------------------- 1 | # HiFi-GAN: Generative Adversarial Networks for 
Efficient and High Fidelity Speech Synthesis 2 | 3 | ### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae 4 | 5 | In our [paper](https://arxiv.org/abs/2010.05646), 6 | we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
7 | We provide our implementation and pretrained models as open source in this repository. 8 | 9 | **Abstract:** 10 | Several recent works on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms. 11 | Although such methods improve the sampling efficiency and memory usage, 12 | their sample quality has not yet reached that of autoregressive and flow-based generative models. 13 | In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis. 14 | As speech audio consists of sinusoidal signals with various periods, 15 | we demonstrate that modeling the periodic patterns of audio is crucial for enhancing sample quality. 16 | A subjective human evaluation (mean opinion score, MOS) on a single-speaker dataset indicates that our proposed method 17 | achieves quality comparable to human speech while generating 22.05 kHz high-fidelity audio 167.9 times faster than 18 | real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen 19 | speakers and end-to-end speech synthesis. Finally, a small-footprint version of HiFi-GAN generates samples 13.4 times 20 | faster than real-time on CPU with quality comparable to an autoregressive counterpart. 21 | 22 | Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples. 23 | 24 | 25 | ## Pre-requisites 26 | 1. Python >= 3.6 27 | 2. Clone this repository. 28 | 3. Install the Python requirements. Please refer to [requirements.txt](requirements.txt). 29 | 4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/). 30 | Then move all wav files to `LJSpeech-1.1/wavs`. 31 | 32 | 33 | ## Training 34 | ``` 35 | python train.py --config config_v1.json 36 | ``` 37 | To train the V2 or V3 generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
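For example, to train the V2 generator:
```
python train.py --config config_v2.json
```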
38 | Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
39 | You can change the path by adding the `--checkpoint_path` option. 40 | 41 | Validation loss during training with the V1 generator.
42 | ![validation loss](./validation_loss.png) 43 | 44 | ## Pretrained Model 45 | You can also use the pretrained models we provide.
46 | [Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
47 | Details of each folder are as follows: 48 | 49 | |Folder Name|Generator|Dataset|Fine-Tuned| 50 | |------|---|---|---| 51 | |LJ_V1|V1|LJSpeech|No| 52 | |LJ_V2|V2|LJSpeech|No| 53 | |LJ_V3|V3|LJSpeech|No| 54 | |LJ_FT_T2_V1|V1|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))| 55 | |LJ_FT_T2_V2|V2|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))| 56 | |LJ_FT_T2_V3|V3|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))| 57 | |VCTK_V1|V1|VCTK|No| 58 | |VCTK_V2|V2|VCTK|No| 59 | |VCTK_V3|V3|VCTK|No| 60 | |UNIVERSAL_V1|V1|Universal|No| 61 | 62 | We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets. 63 | 64 | ## Fine-Tuning 65 | 1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
66 | The file name of each generated mel-spectrogram should match that of its source audio file, and the extension should be `.npy`.
67 | Example: 68 | ``` 69 | Audio File : LJ001-0001.wav 70 | Mel-Spectrogram File : LJ001-0001.npy 71 | ``` 72 | 2. Create a `ft_dataset` folder and copy the generated mel-spectrogram files into it.
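As a point of reference, here is a minimal sketch of writing mel files under the expected names (a hypothetical helper, assuming `mel` is a NumPy array produced by Tacotron2 and `wav_path` points to the source audio):
```
import os
import numpy as np

def save_mel_for_finetuning(wav_path, mel, out_dir="ft_dataset"):
    # Same stem as the audio file, with a .npy extension,
    # e.g. LJ001-0001.wav -> ft_dataset/LJ001-0001.npy
    os.makedirs(out_dir, exist_ok=True)
    stem = os.path.splitext(os.path.basename(wav_path))[0]
    np.save(os.path.join(out_dir, stem + ".npy"), mel)
```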
73 | 3. Run the following command. 74 | ``` 75 | python train.py --fine_tuning True --config config_v1.json 76 | ``` 77 | For other command-line options, please refer to the Training section. 78 | 79 | 80 | ## Inference from wav file 81 | 1. Make a `test_files` directory and copy wav files into it. 82 | 2. Run the following command. 83 | ``` 84 | python inference.py --checkpoint_file [generator checkpoint file path] 85 | ``` 86 | Generated wav files are saved in `generated_files` by default.
87 | You can change the path by adding the `--output_dir` option. 88 | 89 | 90 | ## Inference for end-to-end speech synthesis 91 | 1. Make a `test_mel_files` directory and copy generated mel-spectrogram files into it.
92 | You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2), 93 | [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth. 94 | 2. Run the following command. 95 | ``` 96 | python inference_e2e.py --checkpoint_file [generator checkpoint file path] 97 | ``` 98 | Generated wav files are saved in `generated_files_from_mel` by default.
99 | You can change the path by adding the `--output_dir` option. 100 | 101 | 102 | ## Acknowledgements 103 | We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips) 104 | and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this. 105 | 106 | -------------------------------------------------------------------------------- /demo/hifigan/checkpoints/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "2", 3 | "num_gpus": 8, 4 | "batch_size": 512, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,4], 12 | "upsample_kernel_sizes": [16,16,8], 13 | "upsample_initial_channel": 256, 14 | "resblock_kernel_sizes": [3,5,7], 15 | "resblock_dilation_sizes": [[1,2], [2,6], [3,12]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 128, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 16000, 25 | 26 | "fmin": 0, 27 | "fmax": 8000, 28 | "fmax_for_loss": null, 29 | 30 | "num_workers": 8, 31 | 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | "dist_url": "tcp://127.0.0.1:52111", 35 | "world_size": 1 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /demo/hifigan/config_v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 2, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | "upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 75264, 18 | "num_mels": 64, 19 | "n_fft": 1536, 20 | "hop_size": 588, 21 | "win_size": 1536, 22 | 23 | "sampling_rate": 22050, 24 | 25 | "fmin": 20.0, 26 | "fmax": 11025.0, 27 | "fmax_for_loss": 11025.0, 28 | 29 | "num_workers": 0, 30 | 31 | "dist_config": { 32 | "dist_backend": "nccl", 33 | "dist_url": "tcp://localhost:54321", 34 | "world_size": 1 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /demo/hifigan/config_v2.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | "upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 128, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 22050, 25 | 26 | "fmin": 0, 27 | "fmax": 8000, 28 | "fmax_for_loss": null, 29 | 30 | "num_workers": 4, 31 | 32 | "dist_config": { 33 | "dist_backend": "nccl", 
34 | "dist_url": "tcp://localhost:54321", 35 | "world_size": 1 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /demo/hifigan/config_v3.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "2", 3 | "num_gpus": 8, 4 | "batch_size": 512, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,4], 12 | "upsample_kernel_sizes": [16,16,8], 13 | "upsample_initial_channel": 256, 14 | "resblock_kernel_sizes": [3,5,7], 15 | "resblock_dilation_sizes": [[1,2], [2,6], [3,12]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 128, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 16000, 25 | 26 | "fmin": 0, 27 | "fmax": 8000, 28 | "fmax_for_loss": null, 29 | 30 | "num_workers": 8, 31 | 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | "dist_url": "tcp://127.0.0.1:52111", 35 | "world_size": 1 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /demo/hifigan/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | 10 | 11 | def build_env(config, config_name, path): 12 | t_path = os.path.join(path, config_name) 13 | if config != t_path: 14 | os.makedirs(path, exist_ok=True) 15 | shutil.copyfile(config, os.path.join(path, config_name)) 16 | -------------------------------------------------------------------------------- /demo/hifigan/inference.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import glob 4 | import os 5 | import argparse 6 | import json 7 | import torch 8 | from scipy.io.wavfile import write 9 | from env import AttrDict 10 | from meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav 11 | from models import Generator 12 | 13 | h = None 14 | device = None 15 | 16 | 17 | def load_checkpoint(filepath, device): 18 | assert os.path.isfile(filepath) 19 | print("Loading '{}'".format(filepath)) 20 | checkpoint_dict = torch.load(filepath, map_location=device) 21 | print("Complete.") 22 | return checkpoint_dict 23 | 24 | 25 | def get_mel(x): 26 | return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax) 27 | 28 | 29 | def scan_checkpoint(cp_dir, prefix): 30 | pattern = os.path.join(cp_dir, prefix + '*') 31 | cp_list = glob.glob(pattern) 32 | if len(cp_list) == 0: 33 | return '' 34 | return sorted(cp_list)[-1] 35 | 36 | 37 | def inference(a): 38 | generator = Generator(h).to(device) 39 | 40 | state_dict_g = load_checkpoint(a.checkpoint_file, device) 41 | generator.load_state_dict(state_dict_g['generator']) 42 | 43 | filelist = os.listdir(a.input_wavs_dir) 44 | 45 | os.makedirs(a.output_dir, exist_ok=True) 46 | 47 | generator.eval() 48 | generator.remove_weight_norm() 49 | with torch.no_grad(): 50 | for i, filname in enumerate(filelist): 51 | wav, sr = load_wav(os.path.join(a.input_wavs_dir, filname)) 52 | wav = wav / MAX_WAV_VALUE 53 | wav = torch.FloatTensor(wav).to(device) 54 | x = get_mel(wav.unsqueeze(0)) 55 | y_g_hat = generator(x) 56 | audio = y_g_hat.squeeze() 57 | audio = audio * 
MAX_WAV_VALUE 58 | audio = audio.cpu().numpy().astype('int16') 59 | 60 | output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + '_generated.wav') 61 | write(output_file, h.sampling_rate, audio) 62 | print(output_file) 63 | 64 | 65 | def main(): 66 | print('Initializing Inference Process..') 67 | 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('--input_wavs_dir', default='test_files') 70 | parser.add_argument('--output_dir', default='generated_files') 71 | parser.add_argument('--checkpoint_file', required=True) 72 | a = parser.parse_args() 73 | 74 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json') 75 | with open(config_file) as f: 76 | data = f.read() 77 | 78 | global h 79 | json_config = json.loads(data) 80 | h = AttrDict(json_config) 81 | 82 | torch.manual_seed(h.seed) 83 | global device 84 | if torch.cuda.is_available(): 85 | torch.cuda.manual_seed(h.seed) 86 | device = torch.device('cuda') 87 | else: 88 | device = torch.device('cpu') 89 | 90 | inference(a) 91 | 92 | 93 | if __name__ == '__main__': 94 | main() 95 | 96 | -------------------------------------------------------------------------------- /demo/hifigan/inference_e2e.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import glob 4 | import os 5 | import numpy as np 6 | import argparse 7 | import json 8 | import torch 9 | from scipy.io.wavfile import write 10 | from env import AttrDict 11 | from meldataset import MAX_WAV_VALUE 12 | from models import Generator 13 | 14 | h = None 15 | device = None 16 | 17 | 18 | def load_checkpoint(filepath, device): 19 | assert os.path.isfile(filepath) 20 | print("Loading '{}'".format(filepath)) 21 | checkpoint_dict = torch.load(filepath, map_location=device) 22 | print("Complete.") 23 | return checkpoint_dict 24 | 25 | 26 | def scan_checkpoint(cp_dir, prefix): 27 | pattern = os.path.join(cp_dir, prefix + '*') 28 | cp_list = glob.glob(pattern) 29 | if len(cp_list) == 0: 30 | return '' 31 | return sorted(cp_list)[-1] 32 | 33 | 34 | def inference(a): 35 | generator = Generator(h).to(device) 36 | 37 | state_dict_g = load_checkpoint(a.checkpoint_file, device) 38 | generator.load_state_dict(state_dict_g['generator']) 39 | 40 | filelist = os.listdir(a.input_mels_dir) 41 | 42 | os.makedirs(a.output_dir, exist_ok=True) 43 | 44 | generator.eval() 45 | generator.remove_weight_norm() 46 | with torch.no_grad(): 47 | for i, filname in enumerate(filelist): 48 | x = np.load(os.path.join(a.input_mels_dir, filname)) 49 | x = (x * 3.8312 - 5.0945)[:,:,0] 50 | x = torch.FloatTensor(x).to(device) 51 | y_g_hat = generator(x) 52 | audio = y_g_hat.squeeze() 53 | audio = audio * MAX_WAV_VALUE 54 | audio = audio.cpu().numpy().astype('int16') 55 | 56 | output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + '_generated_e2e.wav') 57 | write(output_file, h.sampling_rate, audio) 58 | print(output_file) 59 | 60 | 61 | def main(): 62 | print('Initializing Inference Process..') 63 | 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('--input_mels_dir', default='test_mel_files') 66 | parser.add_argument('--output_dir', default='generated_files_from_mel') 67 | parser.add_argument('--checkpoint_file', required=True) 68 | a = parser.parse_args() 69 | 70 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json') 71 | with open(config_file) as f: 72 | data = f.read() 73 | 74 | global h 75 | 
json_config = json.loads(data) 76 | h = AttrDict(json_config) 77 | 78 | torch.manual_seed(h.seed) 79 | global device 80 | if torch.cuda.is_available(): 81 | torch.cuda.manual_seed(h.seed) 82 | device = torch.device('cuda') 83 | else: 84 | device = torch.device('cpu') 85 | 86 | inference(a) 87 | 88 | 89 | if __name__ == '__main__': 90 | main() 91 | 92 | -------------------------------------------------------------------------------- /demo/hifigan/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.4.0 2 | numpy==1.17.4 3 | librosa==0.7.2 4 | scipy==1.4.1 5 | tensorboard==2.0 6 | soundfile==0.10.3.post1 7 | matplotlib==3.1.3 -------------------------------------------------------------------------------- /demo/hifigan/resample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from concurrent.futures import ProcessPoolExecutor 4 | from multiprocessing import cpu_count 5 | 6 | import torchaudio 7 | from torchaudio.functional import resample 8 | 9 | from tqdm import tqdm 10 | import subprocess 11 | 12 | 13 | def process_wav(in_path, out_path, sample_rate): 14 | # wav, sr = torchaudio.load(in_path) 15 | # wav = resample(wav, sr, sample_rate) 16 | out_path = Path(str(out_path).replace('mp4', 'wav')) 17 | ffmpeg_process = subprocess.Popen( 18 | ['ffmpeg', '-y', '-i', in_path, '-ac', '1', '-ar', str(sample_rate), out_path], 19 | stdout=-1, stderr=-1, text=True 20 | ) 21 | stdout, stderr = ffmpeg_process.communicate(None, timeout=5.0) 22 | ffmpeg_process.kill() 23 | # torchaudio.save(out_path, wav, sample_rate) 24 | return out_path, 0# wav.size(-1) / sample_rate 25 | 26 | 27 | def preprocess_dataset(args): 28 | args.out_dir.mkdir(parents=True, exist_ok=True) 29 | 30 | futures = [] 31 | executor = ProcessPoolExecutor(max_workers=cpu_count()) 32 | print(f"Resampling audio in {args.in_dir}") 33 | for in_path in args.in_dir.rglob("*.mp4"): 34 | relative_path = in_path.relative_to(args.in_dir) 35 | out_path = args.out_dir / relative_path 36 | out_path.parent.mkdir(parents=True, exist_ok=True) 37 | # process_wav(in_path, out_path, args.sample_rate) 38 | futures.append( 39 | executor.submit(process_wav, in_path, out_path, args.sample_rate) 40 | ) 41 | # import pdb; pdb.set_trace() 42 | 43 | results = [future.result() for future in tqdm(futures)] 44 | 45 | lengths = {path.stem: length for path, length in results} 46 | seconds = sum(lengths.values()) 47 | hours = seconds / 3600 48 | print(f"Wrote {len(lengths)} utterances ({hours:.2f} hours)") 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser(description="Resample an audio dataset.") 53 | parser.add_argument( 54 | "in_dir", metavar="in-dir", help="path to the dataset directory.", type=Path 55 | ) 56 | parser.add_argument( 57 | "out_dir", metavar="out-dir", help="path to the output directory.", type=Path 58 | ) 59 | parser.add_argument( 60 | "--sample-rate", 61 | help="target sample rate (default 16kHz)", 62 | type=int, 63 | default=16000, 64 | ) 65 | args = parser.parse_args() 66 | preprocess_dataset(args) 67 | -------------------------------------------------------------------------------- /demo/hifigan/test_mel.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | import sys 4 | import librosa 5 | import scipy.signal.windows 6 | import soundfile as sf 7 | import numpy as np 8 | from io import BytesIO 9 | from 
PIL import Image 10 | from scipy.io import wavfile 11 | import io 12 | import matplotlib.pyplot as plt 13 | from PIL import Image 14 | import requests 15 | import torch.nn.functional as F 16 | import torchaudio 17 | import torchaudio.transforms as transforms 18 | import torch 19 | from meldataset import mel_spectrogram 20 | from torchaudio.functional import resample 21 | 22 | window_size = 4.08 23 | sample_rate = 16000 24 | n_fft = 1024 25 | win_len = 1024 26 | hop_len=256 27 | n_mels = 128 28 | fmin = 0.0 29 | eps = 0.1 30 | max_wav_value=32768.0 31 | playback_speed = 1 32 | fmax = 8000 33 | 34 | with open("example.wav", "wb") as file: 35 | response = requests.get("https://drive.google.com/uc?export=preview&id=1Y3KuPAhB5VcsmIaokBVKu3LUEZOfhSu8") 36 | file.write(response.content) 37 | 38 | # logmel = LogMelSpectrogram() 39 | 40 | audio_fn = 'sample_1.wav' 41 | waveform1, sample_rate = librosa.load(audio_fn, sr=sample_rate) 42 | 43 | sr, waveform = wavfile.read(audio_fn, mmap=True) 44 | waveform = waveform.astype('float32') 45 | waveform /= max_wav_value 46 | 47 | st = float(60 * 0 + 0.0) 48 | start_idx = int(sr * st) 49 | # end_idx = start_idx + int(sr * window_size) * playback_speed 50 | end_idx = 8192 51 | waveform = waveform[start_idx:end_idx] 52 | 53 | waveform = torch.Tensor(waveform) 54 | 55 | torchaudio_melspec = transforms.MelSpectrogram( 56 | sample_rate=sample_rate, 57 | n_fft=n_fft, 58 | win_length=win_len, 59 | hop_length=hop_len, 60 | center=True, 61 | pad_mode="reflect", 62 | power=2.0, 63 | norm='slaney', 64 | onesided=True, 65 | n_mels=n_mels, 66 | mel_scale="slaney", 67 | f_min = fmin, 68 | f_max = sample_rate / 2.0, 69 | )(waveform) 70 | 71 | librosa_melspec = librosa.feature.melspectrogram( 72 | waveform.numpy(), 73 | sr=sample_rate, 74 | n_fft=n_fft, 75 | hop_length=hop_len, 76 | win_length=win_len, 77 | center=True, 78 | pad_mode="reflect", 79 | power=2.0, 80 | n_mels=n_mels, 81 | ) 82 | 83 | torch_melspec = mel_spectrogram(waveform[None,:], n_fft, n_mels, 84 | sample_rate, hop_len, n_fft, fmin, fmax, 85 | center=True) 86 | 87 | torch_melspec = torch_melspec.squeeze(0) 88 | 89 | # mse = ((torch_melspec - librosa_melspec) ** 2).mean() 90 | 91 | import pdb; pdb.set_trace() 92 | 93 | # mel spectrogram extraction using librosa and torch audio. 
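# (The comparison below is interactive scratch code: `y`, `sr`, `params`,
# `logmel`, and `ty` are assumed to have been defined in an earlier session.)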
94 | mel_lobrosa = librosa.feature.melspectrogram(y=y, sr=sr, **params) 95 | 96 | mel_torch = logmel(ty).numpy() 97 | 98 | 99 | 100 | 101 | 102 | 103 | diff = np.mean((mel_lobrosa - mel_torch) ** 2) 104 | import pdb; pdb.set_trace() 105 | 106 | import pdb; pdb.set_trace() 107 | 108 | import torch 109 | import numpy as np 110 | 111 | # Load checkpoint 112 | hifigan = torch.hub.load("bshall/hifigan:main", "hifigan_hubert_soft").cuda() 113 | # Load mel-spectrogram 114 | 115 | mel = torch.from_numpy(log_mel).unsqueeze(0).cuda() 116 | wav, sr = hifigan.generate(mel.cuda()) 117 | 118 | wavfile.write("original.wav", sample_rate, y) 119 | wavfile.write("decoded.wav", sample_rate, wav.reshape(-1).cpu().numpy()) 120 | 121 | 122 | def plot_spectrogram(log_mel, eps=0.1, ylabel='freq_bin', aspect='auto', xmax=None, to_db=True): 123 | fig, axs = plt.subplots(1, 1) 124 | spec = np.exp(log_mel + np.log(eps)) - eps 125 | if to_db: 126 | spec = librosa.power_to_db(spec, ref=np.max) 127 | axs.set_ylabel(ylabel) 128 | axs.set_xlabel('frame') 129 | im = axs.imshow(spec, origin='lower', aspect=aspect) 130 | if xmax: 131 | axs.set_xlim((0, xmax)) 132 | fig.colorbar(im, ax=axs) 133 | fig.tight_layout(pad=0) 134 | fig.canvas.draw() 135 | 136 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 137 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 138 | plt.close(fig) 139 | return data 140 | 141 | img = plot_spectrogram(log_mel, eps) 142 | 143 | Image.fromarray(img).save('mel_spectrogram.png') -------------------------------------------------------------------------------- /demo/hifigan/test_mel2.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | import sys 4 | import librosa 5 | import scipy.signal.windows 6 | import soundfile as sf 7 | import numpy as np 8 | from io import BytesIO 9 | from PIL import Image 10 | from scipy.io import wavfile 11 | import io 12 | import matplotlib.pyplot as plt 13 | from PIL import Image 14 | import requests 15 | import torch.nn.functional as F 16 | import torchaudio 17 | import torchaudio.transforms as transforms 18 | import torch 19 | from meldataset import mel_spectrogram 20 | from torchaudio.functional import resample 21 | 22 | window_size = 4.08 23 | sample_rate = 16000 24 | n_fft = 1024 25 | win_len = 1024 26 | hop_len=256 27 | n_mels = 128 28 | fmin = 0.0 29 | eps = 0.1 30 | max_wav_value=32768.0 31 | playback_speed = 1 32 | fmax = 8000 33 | 34 | with open("example.wav", "wb") as file: 35 | response = requests.get("https://drive.google.com/uc?export=preview&id=1Y3KuPAhB5VcsmIaokBVKu3LUEZOfhSu8") 36 | file.write(response.content) 37 | 38 | # logmel = LogMelSpectrogram() 39 | 40 | audio_fn = 'sample_1.wav' 41 | waveform1, sample_rate = librosa.load(audio_fn, sr=sample_rate) 42 | 43 | sr, waveform = wavfile.read(audio_fn, mmap=True) 44 | waveform = waveform.astype('float32') 45 | waveform /= max_wav_value 46 | 47 | st = float(60 * 0 + 0.0) 48 | start_idx = int(sr * st) 49 | end_idx = start_idx + int(sr * window_size) * playback_speed 50 | waveform = waveform[start_idx:end_idx] 51 | 52 | waveform = torch.Tensor(waveform) 53 | 54 | librosa_melspec = librosa.feature.melspectrogram( 55 | waveform.numpy(), 56 | sr=sample_rate, 57 | n_fft=n_fft, 58 | hop_length=hop_len, 59 | win_length=win_len, 60 | center=True, 61 | pad_mode="reflect", 62 | power=2.0, 63 | n_mels=n_mels, 64 | ) 65 | 66 | torch_melspec = mel_spectrogram(waveform[None,:], n_fft, n_mels, 67 | sample_rate, hop_len, 
n_fft, fmin, fmax, 68 | center=True) 69 | 70 | torch_melspec = torch_melspec.squeeze(0) 71 | 72 | mse = ((torch_melspec - librosa_melspec) ** 2).mean() 73 | 74 | 75 | import pdb; pdb.set_trace() 76 | 77 | # mel spectrogram extraction using librosa and torch audio. 78 | mel_lobrosa = librosa.feature.melspectrogram(y=y, sr=sr, **params) 79 | mel_torch = logmel(ty).numpy() 80 | 81 | diff = np.mean((mel_lobrosa - mel_torch) ** 2) 82 | 83 | import torch 84 | import numpy as np 85 | 86 | # Load checkpoint 87 | hifigan = torch.hub.load("bshall/hifigan:main", "hifigan_hubert_soft").cuda() 88 | # Load mel-spectrogram 89 | 90 | mel = torch.from_numpy(log_mel).unsqueeze(0).cuda() 91 | wav, sr = hifigan.generate(mel.cuda()) 92 | 93 | wavfile.write("original.wav", sample_rate, y) 94 | wavfile.write("decoded.wav", sample_rate, wav.reshape(-1).cpu().numpy()) 95 | 96 | def plot_spectrogram(log_mel, eps=0.1, ylabel='freq_bin', aspect='auto', xmax=None, to_db=True): 97 | fig, axs = plt.subplots(1, 1) 98 | spec = np.exp(log_mel + np.log(eps)) - eps 99 | if to_db: 100 | spec = librosa.power_to_db(spec, ref=np.max) 101 | axs.set_ylabel(ylabel) 102 | axs.set_xlabel('frame') 103 | im = axs.imshow(spec, origin='lower', aspect=aspect) 104 | if xmax: 105 | axs.set_xlim((0, xmax)) 106 | fig.colorbar(im, ax=axs) 107 | fig.tight_layout(pad=0) 108 | fig.canvas.draw() 109 | 110 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 111 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 112 | plt.close(fig) 113 | return data 114 | 115 | img = plot_spectrogram(log_mel, eps) 116 | 117 | Image.fromarray(img).save('mel_spectrogram.png') -------------------------------------------------------------------------------- /demo/hifigan/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import matplotlib 4 | import torch 5 | from torch.nn.utils import weight_norm 6 | matplotlib.use("Agg") 7 | import matplotlib.pylab as plt 8 | 9 | 10 | def plot_spectrogram(spectrogram): 11 | fig, ax = plt.subplots(figsize=(10, 2)) 12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 13 | interpolation='none') 14 | plt.colorbar(im, ax=ax) 15 | 16 | fig.canvas.draw() 17 | plt.close() 18 | 19 | return fig 20 | 21 | 22 | def init_weights(m, mean=0.0, std=0.01): 23 | classname = m.__class__.__name__ 24 | if classname.find("Conv") != -1: 25 | m.weight.data.normal_(mean, std) 26 | 27 | 28 | def apply_weight_norm(m): 29 | classname = m.__class__.__name__ 30 | if classname.find("Conv") != -1: 31 | weight_norm(m) 32 | 33 | 34 | def get_padding(kernel_size, dilation=1): 35 | return int((kernel_size*dilation - dilation)/2) 36 | 37 | 38 | def load_checkpoint(filepath, device): 39 | assert os.path.isfile(filepath) 40 | print("Loading '{}'".format(filepath)) 41 | checkpoint_dict = torch.load(filepath, map_location=device) 42 | print("Complete.") 43 | return checkpoint_dict 44 | 45 | 46 | def save_checkpoint(filepath, obj): 47 | print("Saving checkpoint to {}".format(filepath)) 48 | torch.save(obj, filepath) 49 | print("Complete.") 50 | 51 | 52 | def scan_checkpoint(cp_dir, prefix): 53 | pattern = os.path.join(cp_dir, prefix + '????????') 54 | cp_list = glob.glob(pattern) 55 | if len(cp_list) == 0: 56 | return None 57 | return sorted(cp_list)[-1] 58 | 59 | -------------------------------------------------------------------------------- /demo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.audio_utils import * 2 | from .video_utils import * 3 | -------------------------------------------------------------------------------- /demo/utils/video_utils.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | import string 4 | import subprocess 5 | import time 6 | 7 | import gradio as gr 8 | 9 | from create_data.utils import ( 10 | get_video_length, 11 | create_audio_from_video, 12 | extract_frames_from_video, 13 | BUFFER_FROM_END, 14 | ) 15 | from demo.utils.audio_utils import extract_spectrograms_from_audio 16 | 17 | __all__ = ["load_video"] 18 | 19 | 20 | def extract_frames_and_spectrograms_from_video( 21 | video_file, 22 | audio_dir, 23 | video_length=None, 24 | video_segment_length=None, 25 | audio_segment_length=None, 26 | times=None, 27 | clip_start_time=0, 28 | clip_end_time=None, 29 | num_frames=None, 30 | target_size=(256, 256), 31 | *, 32 | use_audio, 33 | ): 34 | if times is None: 35 | # get actual video length 36 | if video_length is None: 37 | video_length = get_video_length(video_file) 38 | if video_length is None: 39 | print(f"Couldn't get video length for {video_file}") 40 | return None, None 41 | 42 | if video_segment_length is None: 43 | video_segment_length = video_length / num_frames 44 | if video_length < (video_segment_length / 2.0) - BUFFER_FROM_END: 45 | print( 46 | f"Video is too short ({video_length}s is less than half the segment length of {video_segment_length}s segments" 47 | ) 48 | return None, None 49 | else: 50 | # don't need this if times is given 51 | video_length = None 52 | 53 | # extract image frames 54 | # t0 = perf_counter() 55 | frames, boundaries = extract_frames_from_video( 56 | video_file, 57 | video_length, 58 | video_segment_length, 59 | times=times, 60 | clip_start_time=clip_start_time, 61 | clip_end_time=clip_end_time, 62 | num_frames=num_frames, 63 | multiprocess=False, 64 | resize=True, 65 | target_size=target_size, 66 | ) 67 | # print(f"Load video in {perf_counter() - t0} seconds in total") 68 | 69 | spectrograms = None 70 | if use_audio: 71 | # expects the audio file to be created already (since it takes some time) 72 | audio_file = create_audio_from_video(video_file, audio_dir, force=True) 73 | if os.path.exists(audio_file): # in case video w/o audio 74 | # extract audio segments 75 | spectrograms = extract_spectrograms_from_audio( 76 | audio_file, 77 | audio_length=clip_end_time, 78 | audio_segment_length=audio_segment_length, 79 | spectrogram_length=audio_segment_length, 80 | ) 81 | 82 | return frames, spectrograms 83 | 84 | 85 | def load_video( 86 | path: str, 87 | max_frames: int = 4, 88 | audio_segment_length: float = 4.08, 89 | target_size: tuple = (256, 256), 90 | *, 91 | use_audio: bool, 92 | ): 93 | """max frames could be max image history length + (1 if no image input else 0), similar for audio""" 94 | if path.startswith("http"): 95 | filetype = path.split(".")[-1] 96 | filename = "".join(random.choices(string.ascii_lowercase, k=8)) + "." 
+ filetype 97 | cmd = f"wget -O {filename} {path}" 98 | print(cmd) 99 | _ = subprocess.Popen( 100 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, shell=True 101 | ) 102 | gr.Info("Waiting for video download to finish!") 103 | time.sleep(10) 104 | path = filename 105 | 106 | assert os.path.exists(path), path 107 | video_length = get_video_length(path) 108 | clip_end_time, video_seg_length = None, None 109 | if video_length is not None: 110 | # accommodate the audio segment length 111 | max_length = max_frames * audio_segment_length 112 | clip_end_time = min(max_length, video_length) 113 | if clip_end_time < video_length: 114 | gr.Warning( 115 | f"Use the input video length of {clip_end_time} (original {video_length}) seconds." 116 | ) 117 | # second per frame, not necessary corresponding with the audio in demo 118 | video_seg_length = clip_end_time / max_frames 119 | 120 | frames, spectrograms = extract_frames_and_spectrograms_from_video( 121 | path, 122 | audio_dir=None, 123 | video_length=video_length, 124 | video_segment_length=video_seg_length, 125 | audio_segment_length=audio_segment_length, 126 | clip_start_time=0, 127 | clip_end_time=clip_end_time, 128 | num_frames=None, 129 | use_audio=use_audio, 130 | target_size=target_size, 131 | ) 132 | return frames, spectrograms 133 | -------------------------------------------------------------------------------- /metadata/coco/coco_class_name_2017.json: -------------------------------------------------------------------------------- 1 | {"0": "person", "1": "bicycle", "2": "car", "3": "motorcycle", "4": "airplane", "5": "bus", "6": "train", "7": "truck", "8": "boat", "9": "traffic-light", "10": "fire-hydrant", "11": "stop-sign", "12": "parking-meter", "13": "bench", "14": "bird", "15": "cat", "16": "dog", "17": "horse", "18": "sheep", "19": "cow", "20": "elephant", "21": "bear", "22": "zebra", "23": "giraffe", "24": "backpack", "25": "umbrella", "26": "handbag", "27": "tie", "28": "suitcase", "29": "frisbee", "30": "skis", "31": "snowboard", "32": "sports ball", "33": "kite", "34": "baseball-bat", "35": "baseball-glove", "36": "skateboard", "37": "surfboard", "38": "tennis-racket", "39": "bottle", "40": "wine glass", "41": "cup", "42": "fork", "43": "knife", "44": "spoon", "45": "bowl", "46": "banana", "47": "apple", "48": "sandwich", "49": "orange", "50": "broccoli", "51": "carrot", "52": "hot dog", "53": "pizza", "54": "donut", "55": "cake", "56": "chair", "57": "couch", "58": "potted plant", "59": "bed", "60": "dining table", "61": "toilet", "62": "tv", "63": "laptop", "64": "mouse", "65": "remote", "66": "keyboard", "67": "cell phone", "68": "microwave", "69": "oven", "70": "toaster", "71": "sink", "72": "refrigerator", "73": "book", "74": "clock", "75": "vase", "76": "scissors", "77": "teddy-bear", "78": "hair-drier", "79": "toothbrush"} -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Install T5X.""" 16 | 17 | import os 18 | import sys 19 | import setuptools 20 | 21 | # To enable importing version.py directly, we add its path to sys.path. 22 | version_path = os.path.join(os.path.dirname(__file__), 't5x') 23 | sys.path.append(version_path) 24 | from version import __version__ # pylint: disable=g-import-not-at-top 25 | 26 | # Get the long description from the README file. 27 | with open('README.md') as fp: 28 | _LONG_DESCRIPTION = fp.read() 29 | 30 | _jax_version = '0.2.27' 31 | _jaxlib_version = '0.1.76' 32 | 33 | setuptools.setup( 34 | name='t5x', 35 | version=__version__, 36 | description='UnifiedIO 2', 37 | long_description=_LONG_DESCRIPTION, 38 | long_description_content_type='text/markdown', 39 | license='Apache 2.0', 40 | packages=setuptools.find_packages(), 41 | package_data={ 42 | '': ['**/*.gin'], # not all subdirectories may have __init__.py. 43 | }, 44 | scripts=[], 45 | install_requires=[ 46 | 'absl-py', 47 | 'cached_property', 48 | 'protobuf==3.19.4', 49 | 'google-api-core==2.8.2', 50 | # TODO(adarob): Replace with 'clu' once >0.0.6 is released. 51 | 'clu==0.0.8', 52 | 'flax==0.6.3', 53 | 'gin-config', 54 | f'jax==0.3.25', 55 | f'jaxlib==0.3.25', 56 | 'numpy', 57 | 'orbax==0.0.2', 58 | 't5==0.9.4', 59 | 'tensorflow==2.11.1', 60 | 'einops', 61 | 'tfds-nightly==4.8.3.dev202304050043', 62 | 'tensorflow_probability==0.19.0', 63 | 'tensorflow-addons==0.19.0', 64 | 'tensorflow-datasets==4.8.3', 65 | 'pycocoevalcap', 66 | 'tensorstore >= 0.1.20', 67 | 'librosa', 68 | 'scikit-image', 69 | 'wandb==0.14.0', 70 | "optax==0.1.4", 71 | "tqdm", 72 | "transforms3d==0.4.1", 73 | "pyglove==0.4.3", 74 | "seqio==0.0.8", 75 | ], 76 | extras_require={ 77 | 'data': ['datasets', 'google-cloud-storage', "resampy"], 78 | "demo": ["resampy", 'google-cloud-storage', 'gradio==4.8.0', 'notebook', 'sk-video'], 79 | # Cloud TPU requirements. 80 | 'tpu': [f'jax[tpu]==0.3.25'], 81 | }, 82 | classifiers=[ 83 | 'Intended Audience :: Science/Research', 84 | 'License :: OSI Approved :: Apache Software License', 85 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 86 | ], 87 | ) -------------------------------------------------------------------------------- /t5x/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Import API modules.""" 16 | 17 | import t5x.adafactor 18 | import t5x.checkpoints 19 | import t5x.decoding 20 | import t5x.gin_utils 21 | import t5x.losses 22 | import t5x.models 23 | import t5x.partitioning 24 | import t5x.state_utils 25 | import t5x.train_state 26 | import t5x.trainer 27 | import t5x.utils 28 | 29 | # Version number. 
30 | from t5x.version import __version__ 31 | 32 | # TODO(adarob): Move clients to t5x.checkpointing and rename 33 | # checkpoints.py to checkpointing.py 34 | checkpointing = t5x.checkpoints 35 | -------------------------------------------------------------------------------- /t5x/binary_search_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for binary_search.""" 16 | 17 | from absl.testing import absltest 18 | import jax 19 | import jax.numpy as jnp 20 | import numpy as np 21 | from t5x import binary_search 22 | 23 | _INT32_MIN = np.iinfo(np.int32).min 24 | _INT32_MAX = np.iinfo(np.int32).max 25 | 26 | 27 | class BinarySearchTest(absltest.TestCase): 28 | 29 | def test_int32_bsearch(self): 30 | a = jnp.asarray([ 31 | 1, 32 | 43, 33 | 79, 34 | 2048, 35 | 0, 36 | 2047, 37 | _INT32_MIN, 38 | _INT32_MIN + 1, 39 | _INT32_MAX, 40 | _INT32_MAX - 1, 41 | ], 42 | dtype=jnp.int32) 43 | 44 | def predicate(x): 45 | return x > a 46 | 47 | r = binary_search.int32_bsearch(a.shape, predicate) 48 | np.testing.assert_array_equal(a, r) 49 | 50 | def test_int32_bsearch_extreme_predicates(self): 51 | 52 | def predicate_false(x): 53 | return jnp.full_like(x, False) 54 | 55 | np.testing.assert_array_equal( 56 | jnp.asarray([_INT32_MAX]), 57 | binary_search.int32_bsearch((1,), predicate_false)) 58 | 59 | def predicate_true(x): 60 | return jnp.full_like(x, True) 61 | 62 | np.testing.assert_array_equal( 63 | jnp.asarray([_INT32_MIN]), 64 | binary_search.int32_bsearch((1,), predicate_true)) 65 | 66 | def test_float32_bsearch(self): 67 | a = jnp.asarray([1.23, 0.0, -0.0, 105.4, -1024, 4.3], dtype=jnp.float32) 68 | 69 | def predicate(x): 70 | return x > a 71 | 72 | c = binary_search.float32_bsearch(a.shape, predicate) 73 | # Given that the predicate is based on floating point '>' as implemented by 74 | # JAX, we need our equality test to be based on floating point '==' as 75 | # implemented by JAX, rather than np.testing.assert_array_equal. 76 | # 77 | # Some corner cases on subnormal numbers may be different, depending on what 78 | # platform we run on. 79 | self.assertTrue(jnp.all(a == c), f'a={a}, c={c}') 80 | 81 | def test_topk_mask(self): 82 | mask = -1e10 83 | x = jnp.asarray([ 84 | [1.4, 7.9, -4.3, 100, 71, 6, -1e4], 85 | [8.3, 1.2, 1.3, 1.2, 1.2, 9.7, -100], 86 | ]) 87 | 88 | # Using exact equality here, because topk_mask guarantees it: it is just 89 | # masking some things, not doing arithmetic on the array. 
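# (As the expected arrays below illustrate, topk_mask keeps the k largest
# entries in each row and replaces every other entry with the mask value.)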
90 | np.testing.assert_array_equal( 91 | jnp.asarray([ 92 | [mask, mask, mask, 100, mask, mask, mask], 93 | [mask, mask, mask, mask, mask, 9.7, mask], 94 | ]), 95 | binary_search.topk_mask(x, 1, mask), 96 | ) 97 | np.testing.assert_array_equal( 98 | jnp.asarray([ 99 | [mask, mask, mask, 100, 71, mask, mask], 100 | [8.3, mask, mask, mask, mask, 9.7, mask], 101 | ]), 102 | binary_search.topk_mask(x, 2, mask), 103 | ) 104 | np.testing.assert_array_equal( 105 | jnp.asarray([ 106 | [mask, 7.9, mask, 100, 71, mask, mask], 107 | [8.3, mask, 1.3, mask, mask, 9.7, mask], 108 | ]), 109 | binary_search.topk_mask(x, 3, mask), 110 | ) 111 | np.testing.assert_array_equal( 112 | jnp.asarray([ 113 | [mask, 7.9, mask, 100, 71, 6, mask], 114 | [8.3, 1.2, 1.3, 1.2, 1.2, 9.7, mask], 115 | ]), 116 | binary_search.topk_mask(x, 4, mask), 117 | ) 118 | np.testing.assert_array_equal( 119 | jnp.asarray([ 120 | [1.4, 7.9, mask, 100, 71, 6, mask], 121 | [8.3, 1.2, 1.3, 1.2, 1.2, 9.7, mask], 122 | ]), 123 | binary_search.topk_mask(x, 5, mask), 124 | ) 125 | np.testing.assert_array_equal( 126 | jnp.asarray([ 127 | [1.4, 7.9, -4.3, 100, 71, 6, mask], 128 | [8.3, 1.2, 1.3, 1.2, 1.2, 9.7, mask], 129 | ]), 130 | binary_search.topk_mask(x, 6, mask), 131 | ) 132 | np.testing.assert_array_equal( 133 | jnp.asarray([ 134 | [1.4, 7.9, -4.3, 100, 71, 6, -1e4], 135 | [8.3, 1.2, 1.3, 1.2, 1.2, 9.7, -100], 136 | ]), 137 | binary_search.topk_mask(x, 7, mask), 138 | ) 139 | 140 | def test_topp_mask(self): 141 | probs = jnp.asarray([ 142 | [0.0, 0.7, 0.04, 0.06, 0.2, 0.0], 143 | [0.0, 0.2, 0.2, 0.2, 0.3, 0.1], 144 | ]) 145 | logits = jnp.log(probs) 146 | np.testing.assert_allclose(jax.nn.softmax(logits), probs) 147 | mask = -1e10 148 | 149 | # Using exact equality here, because topp_mask guarantees it: it is just 150 | # masking some things, not doing arithmetic on the array. 
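# (As below: topp_mask performs nucleus-style masking, keeping the smallest
# set of highest-probability logits whose cumulative probability reaches p,
# ties included, and masking the rest.)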
151 | np.testing.assert_array_equal( 152 | jnp.asarray([ 153 | [mask, jnp.log(0.7), mask, mask, mask, mask], 154 | [mask, mask, mask, mask, jnp.log(0.3), mask], 155 | ]), 156 | binary_search.topp_mask(logits, 0.1, mask), 157 | ) 158 | np.testing.assert_array_equal( 159 | jnp.asarray([ 160 | [mask, jnp.log(0.7), mask, mask, mask, mask], 161 | [mask, mask, mask, mask, jnp.log(0.3), mask], 162 | ]), 163 | binary_search.topp_mask(logits, 0.3, mask), 164 | ) 165 | np.testing.assert_array_equal( 166 | jnp.asarray([ 167 | [mask, jnp.log(0.7), mask, mask, mask, mask], 168 | [ 169 | mask, 170 | jnp.log(0.2), 171 | jnp.log(0.2), 172 | jnp.log(0.2), 173 | jnp.log(0.3), mask 174 | ], 175 | ]), 176 | binary_search.topp_mask(logits, 0.4, mask), 177 | ) 178 | np.testing.assert_array_equal( 179 | jnp.asarray([ 180 | [mask, jnp.log(0.7), mask, mask, 181 | jnp.log(0.2), mask], 182 | [ 183 | mask, 184 | jnp.log(0.2), 185 | jnp.log(0.2), 186 | jnp.log(0.2), 187 | jnp.log(0.3), mask 188 | ], 189 | ]), 190 | binary_search.topp_mask(logits, 0.8, mask), 191 | ) 192 | np.testing.assert_array_equal( 193 | jnp.asarray([ 194 | [mask, jnp.log(0.7), mask, 195 | jnp.log(0.06), 196 | jnp.log(0.2), mask], 197 | [ 198 | mask, 199 | jnp.log(0.2), 200 | jnp.log(0.2), 201 | jnp.log(0.2), 202 | jnp.log(0.3), 203 | jnp.log(0.1) 204 | ], 205 | ]), 206 | binary_search.topp_mask(logits, 0.95, mask), 207 | ) 208 | 209 | 210 | if __name__ == '__main__': 211 | absltest.main() 212 | -------------------------------------------------------------------------------- /t5x/checkpoint_importer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for t5x.checkpoint_importer.""" 16 | 17 | import json 18 | import os 19 | 20 | from absl import flags 21 | from absl.testing import absltest 22 | import jax 23 | import numpy as np 24 | from t5x import checkpoint_importer 25 | import tensorflow as tf 26 | 27 | 28 | class CheckpointImporterTest(absltest.TestCase): 29 | 30 | def test_rel_embeddings_shared_layers(self): 31 | # This represents a ckpt where the Mesh TensorFlow's 32 | # transformer_layers.SelfAttention.relative_attention_type = "bias_shared", 33 | # i.e., the same relative attention parameters are shared by all layers 34 | # within the (en|de)coder. 
35 | ckpt_data = { 36 | 'encoder/block_000/layer_000/SelfAttention/relative_attention_bias': 37 | 1, 38 | 'decoder/block_000/layer_000/SelfAttention/relative_attention_bias': 39 | 2, 40 | 'decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v': 41 | 3, 42 | } 43 | t5_data = checkpoint_importer.t5_importer.apply(ckpt_data) 44 | t5_data = checkpoint_importer._maybe_correct_relpos_bias(t5_data) 45 | expected = { 46 | 'target/encoder/relpos_bias/rel_embedding': 1, 47 | 'target/decoder/relpos_bias/rel_embedding': 2, 48 | 'state/param_states/decoder/relpos_bias/rel_embedding/v': 3, 49 | } 50 | self.assertEqual(t5_data, expected) 51 | 52 | def test_rel_embeddings_per_layer(self): 53 | # This represents a ckpt where the Mesh TensorFlow's 54 | # transformer_layers.SelfAttention.relative_attention_type = "bias", i.e., 55 | # each layer has its own relative attention parameters. 56 | ckpt_data = { 57 | 'encoder/block_000/layer_000/SelfAttention/relative_attention_bias': 58 | 1, 59 | 'encoder/block_001/layer_000/SelfAttention/relative_attention_bias': 60 | 2, 61 | 'decoder/block_000/layer_000/SelfAttention/relative_attention_bias': 62 | 3, 63 | 'decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v': 64 | 4, 65 | 'decoder/block_011/layer_000/SelfAttention/relative_attention_bias': 66 | 5 67 | } 68 | t5_data = checkpoint_importer.t5_importer.apply(ckpt_data) 69 | t5_data = checkpoint_importer._maybe_correct_relpos_bias(t5_data) 70 | expected = { 71 | 'target/encoder/layers_0/relpos_bias/rel_embedding': 1, 72 | 'target/encoder/layers_1/relpos_bias/rel_embedding': 2, 73 | 'target/decoder/layers_0/relpos_bias/rel_embedding': 3, 74 | 'state/param_states/decoder/layers_0/relpos_bias/rel_embedding/v': 4, 75 | 'target/decoder/layers_11/relpos_bias/rel_embedding': 5, 76 | } 77 | self.assertEqual(t5_data, expected) 78 | 79 | 80 | if __name__ == '__main__': 81 | absltest.main() 82 | -------------------------------------------------------------------------------- /t5x/checkpoint_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Checkpoint helper functions for managing checkpoints. 16 | 17 | Supports marking checkpoints as pinned to exclude them from the checkpointer 18 | removal process. 19 | """ 20 | 21 | import os 22 | 23 | from absl import logging 24 | 25 | from tensorflow.io import gfile 26 | 27 | # PINNED file in the checkpoint directory indicates that the checkpoint should 28 | # not be removed during the automatic pruning of old checkpoints. 
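# A minimal usage sketch (the checkpoint path below is hypothetical):
#
#   from t5x import checkpoint_utils
#
#   ckpt_dir = '/models/run_1/checkpoint_10000'
#   checkpoint_utils.pin_checkpoint(ckpt_dir)         # writes <ckpt_dir>/PINNED
#   checkpoint_utils.remove_checkpoint_dir(ckpt_dir)  # no-op while pinned
#   checkpoint_utils.unpin_checkpoint(ckpt_dir)       # eligible for pruning again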
29 | _PINNED_CHECKPOINT_FILENAME = 'PINNED' 30 | 31 | 32 | def pinned_checkpoint_filepath(ckpt_dir: str) -> str: 33 | """Full path of the pinned checkpoint file.""" 34 | return os.path.join(ckpt_dir, _PINNED_CHECKPOINT_FILENAME) 35 | 36 | 37 | def is_pinned_checkpoint(ckpt_dir: str) -> bool: 38 | """Returns whether the checkpoint is pinned, and should NOT be removed.""" 39 | pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir) 40 | if gfile.exists(pinned_ckpt_file): 41 | return True 42 | return False 43 | 44 | 45 | def pin_checkpoint(ckpt_dir: str, txt: str = '1') -> None: 46 | """Pin a checkpoint so it does not get deleted by the normal pruning process. 47 | 48 | Creates a PINNED file in the checkpoint directory to indicate the checkpoint 49 | should be excluded from the deletion of old checkpoints. 50 | 51 | Args: 52 | ckpt_dir: The checkpoint step dir that should always be kept. 53 | txt: Text to be written into the checkpoint's PINNED file. 54 | """ 55 | pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir) 56 | with gfile.GFile(pinned_ckpt_file, 'w') as f: 57 | logging.debug('Write %s file : %s.', pinned_ckpt_file, txt) 58 | f.write(txt) 59 | 60 | 61 | def unpin_checkpoint(ckpt_dir: str) -> None: 62 | """Removes the pinned status of the checkpoint so it is open for deletion.""" 63 | if not is_pinned_checkpoint(ckpt_dir): 64 | logging.debug('%s is not PINNED. Nothing to do here.', ckpt_dir) 65 | return 66 | try: 67 | pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir) 68 | logging.debug('Remove %s file.', pinned_ckpt_file) 69 | gfile.rmtree(pinned_ckpt_file) 70 | except IOError: 71 | logging.exception('Failed to unpin %s', ckpt_dir) 72 | 73 | 74 | def remove_checkpoint_dir(ckpt_dir: str) -> None: 75 | """Removes the checkpoint dir if it is not pinned.""" 76 | if not is_pinned_checkpoint(ckpt_dir): 77 | logging.info('Deleting checkpoint: %s', ckpt_dir) 78 | gfile.rmtree(ckpt_dir) 79 | else: 80 | logging.info('Keeping pinned checkpoint: %s', ckpt_dir) 81 | 82 | 83 | def remove_dataset_checkpoint(ckpt_dir: str, train_ds_prefix: str) -> None: 84 | """Removes dataset checkpoints if the checkpoint is not pinned.""" 85 | if not is_pinned_checkpoint(ckpt_dir): 86 | train_ds_pattern = os.path.join(ckpt_dir, train_ds_prefix + '*') 87 | logging.info('Deleting dataset checkpoint: %s', train_ds_pattern) 88 | for file in gfile.glob(train_ds_pattern): 89 | gfile.remove(file) 90 | else: 91 | logging.info('Keeping pinned checkpoint: %s', ckpt_dir) 92 | -------------------------------------------------------------------------------- /t5x/checkpoint_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Tests for t5x.checkpoint_utils.""" 16 | 17 | import os 18 | import traceback 19 | 20 | from absl.testing import absltest 21 | from t5x import checkpoint_utils 22 | from tensorflow.io import gfile 23 | 24 | TESTDATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") 25 | 26 | 27 | class CheckpointsUtilsTest(absltest.TestCase): 28 | 29 | def setUp(self): 30 | super().setUp() 31 | self.checkpoints_dir = self.create_tempdir() 32 | self.ckpt_dir_path = self.checkpoints_dir.full_path 33 | self.pinned_ckpt_file = os.path.join(self.ckpt_dir_path, "PINNED") 34 | self.checkpoints_dir.create_file("checkpoint") 35 | # Create a `train_ds` file representing the dataset checkpoint. 36 | train_ds_basename = "train_ds-00000-of-00001" 37 | self.train_ds_file = os.path.join(self.ckpt_dir_path, train_ds_basename) 38 | self.checkpoints_dir.create_file(train_ds_basename) 39 | 40 | def test_always_keep_checkpoint_file(self): 41 | self.assertEqual( 42 | "/path/to/ckpt/dir/PINNED", 43 | checkpoint_utils.pinned_checkpoint_filepath("/path/to/ckpt/dir")) 44 | 45 | def test_is_pinned_checkpoint_false_by_default(self): 46 | # Ensure regular checkpoint without PINNED file. 47 | self.assertFalse(gfile.exists(os.path.join(self.ckpt_dir_path, "PINNED"))) 48 | 49 | # Validate checkpoints are not pinned by default. 50 | self.assertFalse(checkpoint_utils.is_pinned_checkpoint(self.ckpt_dir_path)) 51 | 52 | def test_is_pinned_checkpoint(self): 53 | # Ensure the checkpoint directory as pinned. 54 | pinned_ckpt_testdata = os.path.join(TESTDATA, "pinned_ckpt_dir") 55 | pinned_file = os.path.join(pinned_ckpt_testdata, "PINNED") 56 | self.assertTrue(gfile.exists(pinned_file)) 57 | 58 | # Test and validate. 59 | self.assertTrue(checkpoint_utils.is_pinned_checkpoint(pinned_ckpt_testdata)) 60 | 61 | def test_is_pinned_missing_ckpt(self): 62 | self.assertFalse( 63 | checkpoint_utils.is_pinned_checkpoint( 64 | os.path.join(self.ckpt_dir_path, "ckpt_does_not_exist"))) 65 | 66 | def test_pin_checkpoint(self): 67 | # Ensure directory isn't already pinned. 68 | self.assertFalse(gfile.exists(self.pinned_ckpt_file)) 69 | 70 | # Test. 71 | checkpoint_utils.pin_checkpoint(self.ckpt_dir_path) 72 | 73 | # Validate. 74 | self.assertTrue(gfile.exists(self.pinned_ckpt_file)) 75 | with open(self.pinned_ckpt_file) as f: 76 | self.assertEqual("1", f.read()) 77 | 78 | def test_pin_checkpoint_txt(self): 79 | checkpoint_utils.pin_checkpoint(self.ckpt_dir_path, "TEXT_IN_PINNED") 80 | self.assertTrue(os.path.exists(os.path.join(self.ckpt_dir_path, "PINNED"))) 81 | with open(self.pinned_ckpt_file) as f: 82 | self.assertEqual("TEXT_IN_PINNED", f.read()) 83 | 84 | def test_unpin_checkpoint(self): 85 | # Mark the checkpoint directory as pinned. 86 | self.checkpoints_dir.create_file("PINNED") 87 | self.assertTrue(checkpoint_utils.is_pinned_checkpoint(self.ckpt_dir_path)) 88 | 89 | # Test. 90 | checkpoint_utils.unpin_checkpoint(self.ckpt_dir_path) 91 | 92 | # Validate the "PINNED" checkpoint file got removed. 93 | self.assertFalse(gfile.exists(os.path.join(self.ckpt_dir_path, "PINNED"))) 94 | 95 | def test_unpin_checkpoint_does_not_exist(self): 96 | missing_ckpt_path = os.path.join(self.ckpt_dir_path, "ckpt_does_not_exist") 97 | self.assertFalse(gfile.exists(missing_ckpt_path)) 98 | 99 | # Test. Assert does not raise error. 100 | try: 101 | checkpoint_utils.unpin_checkpoint(missing_ckpt_path) 102 | except IOError: 103 | # TODO(b/172262005): Remove traceback.format_exc() from the error message. 
104 | self.fail("Unpin checkpoint failed with: %s" % traceback.format_exc()) 105 | 106 | def test_remove_checkpoint_dir(self): 107 | # Ensure the checkpoint directory is setup. 108 | assert gfile.exists(self.ckpt_dir_path) 109 | 110 | # Test. 111 | checkpoint_utils.remove_checkpoint_dir(self.ckpt_dir_path) 112 | 113 | # Validate the checkpoint directory got removed. 114 | self.assertFalse(gfile.exists(self.ckpt_dir_path)) 115 | 116 | def test_remove_checkpoint_dir_pinned(self): 117 | # Mark the checkpoint directory as pinned so it does not get removed. 118 | self.checkpoints_dir.create_file("PINNED") 119 | 120 | # Test. 121 | checkpoint_utils.remove_checkpoint_dir(self.ckpt_dir_path) 122 | 123 | # Validate the checkpoint directory still exists. 124 | self.assertTrue(gfile.exists(self.ckpt_dir_path)) 125 | 126 | def test_remove_dataset_checkpoint(self): 127 | # Ensure the checkpoint directory is setup. 128 | assert gfile.exists(self.ckpt_dir_path) 129 | 130 | # Test. 131 | checkpoint_utils.remove_dataset_checkpoint(self.ckpt_dir_path, "train_ds") 132 | 133 | # Validate the checkpoint directory got removed. 134 | self.assertFalse(gfile.exists(self.train_ds_file)) 135 | self.assertTrue(gfile.exists(self.ckpt_dir_path)) 136 | 137 | def test_remove_dataset_checkpoint_pinned(self): 138 | # Mark the checkpoint directory as pinned so it does not get removed. 139 | self.checkpoints_dir.create_file("PINNED") 140 | 141 | # Test. 142 | checkpoint_utils.remove_dataset_checkpoint(self.ckpt_dir_path, "train_ds") 143 | 144 | # Validate the checkpoint directory still exists. 145 | self.assertTrue(gfile.exists(self.train_ds_file)) 146 | self.assertTrue(gfile.exists(self.ckpt_dir_path)) 147 | 148 | if __name__ == "__main__": 149 | absltest.main() 150 | -------------------------------------------------------------------------------- /t5x/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """This empty file is needed for loading the gin files in this directory.""" 16 | -------------------------------------------------------------------------------- /t5x/configs/runs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # This empty file is needed for loading the gin files in this directory. 16 | -------------------------------------------------------------------------------- /t5x/configs/runs/debug.gin: -------------------------------------------------------------------------------- 1 | # Defaults for pretraining with train.py. 2 | # 3 | # 4 | # You must also include a binding for MODEL. 5 | # 6 | # Required to be set: 7 | # 8 | # - MIXTURE_OR_TASK_NAME 9 | # - TASK_FEATURE_LENGTHS 10 | # - TRAIN_STEPS 11 | # - MODEL_DIR: # automatically set when using xm_launch 12 | # 13 | # Commonly overridden options: 14 | # 15 | # - train/DatasetConfig.batch_size 16 | # - train_eval/DatasetConfig.batch_size 17 | # - PjitPartitioner.num_partitions 18 | # - Trainer.num_microbatches 19 | # - DROPOUT_RATE 20 | from __gin__ import dynamic_registration 21 | 22 | import __main__ as train_script 23 | from t5x import gin_utils 24 | from t5x import partitioning 25 | from t5x import utils 26 | from t5x import trainer 27 | 28 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED 29 | TASK_FEATURE_LENGTHS = %gin.REQUIRED 30 | TRAIN_STEPS = %gin.REQUIRED 31 | MODEL_DIR = %gin.REQUIRED 32 | BATCH_SIZE = 128 33 | USE_CACHED_TASKS = False 34 | BASE_LR = 0.01 35 | WARMUP_STEPS = 100 36 | TIMESCALE = 100 37 | 38 | # DEPRECATED: Import the this module in your gin file. 39 | MIXTURE_OR_TASK_MODULE = None 40 | SHUFFLE_TRAIN_EXAMPLES = False 41 | 42 | # HW RNG is faster than SW, but has limited determinism. 43 | # Most notably it is not deterministic across different 44 | # submeshes. 45 | USE_HARDWARE_RNG = False 46 | # None always uses faster, hardware RNG 47 | RANDOM_SEED = None 48 | 49 | # Can be overridden with `train.*`.` 50 | train_script.train: 51 | model = %MODEL # imported from separate gin file 52 | model_dir = %MODEL_DIR 53 | train_dataset_cfg = @train/utils.DatasetConfig() 54 | train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() 55 | infer_eval_dataset_cfg = None 56 | checkpoint_cfg = @utils.CheckpointConfig() 57 | partitioner = @partitioning.PjitPartitioner() 58 | trainer_cls = @trainer.Trainer 59 | total_steps = %TRAIN_STEPS 60 | eval_steps = 20 61 | eval_period = 1000 62 | random_seed = %RANDOM_SEED 63 | use_hardware_rng = %USE_HARDWARE_RNG 64 | summarize_config_fn = @gin_utils.summarize_gin_config 65 | 66 | partitioning.PjitPartitioner: 67 | num_partitions = 1 68 | model_parallel_submesh = None 69 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 70 | 71 | train/utils.DatasetConfig: 72 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME 73 | task_feature_lengths = %TASK_FEATURE_LENGTHS 74 | split = 'train' 75 | batch_size = %BATCH_SIZE 76 | shuffle = %SHUFFLE_TRAIN_EXAMPLES 77 | seed = None # use a new seed each run/restart 78 | use_cached = %USE_CACHED_TASKS 79 | pack = False 80 | module = %MIXTURE_OR_TASK_MODULE 81 | 82 | train_eval/utils.DatasetConfig: 83 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME_EVAL 84 | task_feature_lengths = %TASK_FEATURE_LENGTHS 85 | split = 'validation' 86 | batch_size = %BATCH_SIZE 87 | shuffle = False 88 | seed = 42 89 | use_cached = %USE_CACHED_TASKS 90 | pack = False 91 | module = %MIXTURE_OR_TASK_MODULE 92 | 93 | utils.CheckpointConfig: 94 | restore = @utils.RestoreCheckpointConfig() 95 | save = @utils.SaveCheckpointConfig() 96 | 97 | utils.RestoreCheckpointConfig: 98 | path = %INITIAL_CHECKPOINT_PATH 99 | mode = 'specific' 100 | dtype = 'float32' 101 | strict = False 102 | # fallback_to_scratch = %FALLBACK_TO_SCRATCH 103 | # state_transformation_fns = 
%STATE_TRANSFORMATION_FNS 104 | 105 | utils.SaveCheckpointConfig: 106 | period = 20000 107 | dtype = 'float32' 108 | keep = 5 # keep only the 5 most recent checkpoints 109 | save_dataset = False # don't checkpoint dataset state 110 | 111 | trainer.Trainer: 112 | num_microbatches = None 113 | learning_rate_fn = @utils.create_learning_rate_scheduler() 114 | 115 | # utils.create_learning_rate_scheduler: 116 | # factors = 'constant * rsqrt_decay' 117 | # base_learning_rate = 1.0 118 | # warmup_steps = 100 # 10k to keep consistent with T5/MTF defaults. 119 | 120 | utils.create_learning_rate_scheduler: 121 | total_steps = %TRAIN_STEPS 122 | base = %BASE_LR 123 | decay_type = 'rsqrt' 124 | warmup_steps = %WARMUP_STEPS 125 | timescale = %TIMESCALE 126 | 127 | WANDB_GROUP = None 128 | WANDB_NAME = None 129 | 130 | from t5x.examples.unified_io import utils as unified_io_utils 131 | 132 | unified_io_utils.init_wandb: 133 | group = %WANDB_GROUP 134 | name = %WANDB_NAME 135 | -------------------------------------------------------------------------------- /t5x/configs/runs/eval.gin: -------------------------------------------------------------------------------- 1 | # Defaults for eval.py. 2 | # 3 | # 4 | # You must also include a binding for MODEL. 5 | # 6 | # Required to be set: 7 | # 8 | # - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to evaluate on 9 | # - CHECKPOINT_PATH: The model checkpoint to evaluate 10 | # - EVAL_OUTPUT_DIR: The dir to write results to. 11 | # 12 | # 13 | # Commonly overridden options: 14 | # 15 | # - DatasetConfig.split 16 | # - DatasetConfig.batch_size 17 | # - DatasetConfig.use_cached 18 | # - RestoreCheckpointConfig.mode 19 | # - PjitPartitioner.num_partitions 20 | from __gin__ import dynamic_registration 21 | 22 | import __main__ as eval_script 23 | import seqio 24 | from t5x import partitioning 25 | from t5x import utils 26 | import t5x.examples.unified_io.evaluator as uio_evaluator 27 | 28 | # Must be overridden 29 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED 30 | CHECKPOINT_PATH = %gin.REQUIRED 31 | EVAL_OUTPUT_DIR = %gin.REQUIRED 32 | TASK_FEATURE_LENGTHS = None # None auto-computes the maximum feature lengths to use. 33 | 34 | # DEPRECATED: Import this module in your gin file. 35 | MIXTURE_OR_TASK_MODULE = None 36 | 37 | eval_script.evaluate: 38 | model = %MODEL # imported from separate gin file 39 | dataset_cfg = @utils.DatasetConfig() 40 | partitioner = @partitioning.PjitPartitioner() 41 | restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() 42 | output_dir = %EVAL_OUTPUT_DIR 43 | inference_evaluator_cls = @uio_evaluator.UnifiedIOEvaluator 44 | 45 | partitioning.PjitPartitioner: 46 | num_partitions = 1 47 | model_parallel_submesh = None 48 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 49 | 50 | uio_evaluator.UnifiedIOEvaluator: 51 | logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] 52 | num_examples = None # Use all examples in the dataset.
53 | use_memory_cache = False 54 | 55 | utils.DatasetConfig: 56 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME 57 | task_feature_lengths = %TASK_FEATURE_LENGTHS 58 | split = 'validation' 59 | batch_size = 32 60 | shuffle = False 61 | seed = 42 62 | use_cached = False 63 | use_memory_cache = False 64 | pack = False 65 | use_custom_packing_ops = False 66 | module = %MIXTURE_OR_TASK_MODULE 67 | 68 | utils.RestoreCheckpointConfig: 69 | path = %CHECKPOINT_PATH 70 | mode = 'specific' 71 | -------------------------------------------------------------------------------- /t5x/configs/runs/export.gin: -------------------------------------------------------------------------------- 1 | # Defaults for single_core_export.py. 2 | # 3 | # You must also include a binding for MODEL. 4 | # 5 | # Required to be set: 6 | # 7 | # - TASK_FEATURE_LENGTHS: The lengths per key in the SeqIO Task to trim features 8 | # to. 9 | # - CHECKPOINT_PATH: The model checkpoint to use for inference 10 | # - INFER_OUTPUT_DIR: The dir to write results to. When launching using 11 | # XManager, this is set automatically. 12 | # 13 | # Commonly overridden options: 14 | # 15 | # warmup_examples: Optional[List[str]] = None 16 | # jit_compile: bool = False 17 | 18 | from __gin__ import dynamic_registration 19 | 20 | import seqio 21 | 22 | from t5x import checkpoints 23 | from t5x import models 24 | from t5x import partitioning 25 | from t5x import utils 26 | from t5x import export_lib 27 | 28 | # Must be overridden 29 | OUTPUT_FEATURES = %gin.REQUIRED 30 | TASK_FEATURE_LENGTHS = %gin.REQUIRED 31 | CHECKPOINT_PATH = %gin.REQUIRED 32 | MODEL_OUTPUT_DIR = %gin.REQUIRED 33 | MODEL_NAME = %gin.REQUIRED 34 | BATCH_SIZE = None 35 | BEAM_SIZE = 1 36 | 37 | OUTPUT_FEATURES = {'inputs': @inputs/seqio.Feature(), 'targets': @outputs/seqio.Feature()} 38 | 39 | # Plumbing to extract the vocabulary directly from MODEL. This is needed to 40 | # tokenize the features from the saved model inputs we aren't provided with 41 | # vocabularies via a Task. 42 | inputs/seqio.Feature.vocabulary = @models.get_input_vocabulary() 43 | models.get_input_vocabulary.model = %MODEL # imported from separate gin file 44 | outputs/seqio.Feature.vocabulary = @models.get_output_vocabulary() 45 | models.get_output_vocabulary.model = %MODEL # imported from separate gin file 46 | 47 | 48 | # Typical for inference settings: 49 | ACTIVATION_DTYPE = 'bfloat16' 50 | 51 | export_lib.save: 52 | model = %MODEL # imported from separate gin file 53 | inference_mode = 'predict' 54 | restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() 55 | exportable_module_cls = @export_lib.ExportableModule 56 | create_preprocessor_fn = @export_lib.create_preprocessor 57 | create_postprocessor_fn = @export_lib.create_postprocessor 58 | write_warmup_example_fn = @export_lib.write_warmup_examples 59 | partitioner = @partitioning.PjitPartitioner() 60 | output_features = %OUTPUT_FEATURES 61 | task_feature_lengths = %TASK_FEATURE_LENGTHS 62 | output_dir = %MODEL_OUTPUT_DIR 63 | model_name = %MODEL_NAME 64 | batch_size = %BATCH_SIZE 65 | native_lowering = False 66 | 67 | utils.RestoreCheckpointConfig: 68 | path = %CHECKPOINT_PATH 69 | mode = 'specific' 70 | dtype = 'bfloat16' 71 | checkpointer_cls = @checkpoints.Checkpointer 72 | # TODO(b/234480674): GDA disabled due to incompatibility with export. 
73 | use_gda = False 74 | 75 | export_lib.create_preprocessor: 76 | output_features = %OUTPUT_FEATURES 77 | task_feature_lengths = %TASK_FEATURE_LENGTHS 78 | 79 | export_lib.create_postprocessor: 80 | output_feature_names = None 81 | 82 | export_lib.ExportableModule: 83 | jit_compile = True 84 | use_batch_function = False 85 | 86 | partitioning.PjitPartitioner: 87 | num_partitions = 1 88 | params_on_devices = True 89 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 90 | 91 | models.EncoderDecoderModel.predict_batch_with_aux: 92 | num_decodes = %BEAM_SIZE 93 | return_all_decodes = True 94 | -------------------------------------------------------------------------------- /t5x/configs/runs/export_seqio.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | from t5x import export_lib 4 | from t5x import partitioning 5 | 6 | include 't5x/configs/runs/export.gin' 7 | 8 | 9 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED 10 | 11 | export_lib.save: 12 | create_preprocessor_fn = @export_lib.create_preprocessor_from_task 13 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME 14 | output_features = None 15 | 16 | export_lib.create_preprocessor_from_task: 17 | model = %MODEL 18 | task_feature_lengths = %TASK_FEATURE_LENGTHS 19 | task_name = %MIXTURE_OR_TASK_NAME 20 | serialized_examples = True 21 | run_precache = False 22 | -------------------------------------------------------------------------------- /t5x/configs/runs/finetune.gin: -------------------------------------------------------------------------------- 1 | # Defaults for finetuning with train.py. 2 | # 3 | # 4 | # You must also include a binding for MODEL. 5 | # 6 | # Required to be set: 7 | # 8 | # - MIXTURE_OR_TASK_NAME 9 | # - TASK_FEATURE_LENGTHS 10 | # - TRAIN_STEPS # includes pretrain steps 11 | # - MODEL_DIR # automatically set when using xm_launch 12 | # - INITIAL_CHECKPOINT_PATH 13 | # 14 | # When running locally, it needs to be passed in the `gin.MODEL_DIR` flag. 15 | # 16 | # `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt 17 | # has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps. 18 | # 19 | # Commonly overridden options: 20 | # - DROPOUT_RATE 21 | # - BATCH_SIZE 22 | # - PjitPartitioner.num_partitions 23 | # - Trainer.num_microbatches 24 | # - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess 25 | # on the fly. Most common tasks are cached, hence this is set to True by 26 | # default. 27 | 28 | from __gin__ import dynamic_registration 29 | 30 | import __main__ as train_script 31 | import seqio 32 | from t5x import gin_utils 33 | from t5x import partitioning 34 | from t5x import utils 35 | from t5x import trainer 36 | import t5x.examples.unified_io.evaluator as uio_evaluator 37 | 38 | # Must be overridden 39 | MODEL_DIR = %gin.REQUIRED 40 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED 41 | MIXTURE_OR_TASK_NAME_EVAL = %gin.REQUIRED 42 | TASK_FEATURE_LENGTHS_TRAIN = %gin.REQUIRED 43 | TASK_FEATURE_LENGTHS_EVAL = %gin.REQUIRED 44 | MIXTURE_OR_TASK_MODULE = %gin.REQUIRED 45 | TRAIN_STEPS = %gin.REQUIRED 46 | INITIAL_CHECKPOINT_PATH = %gin.REQUIRED 47 | 48 | # Commonly overridden 49 | DROPOUT_RATE = 0.1 50 | USE_CACHED_TASKS = True 51 | BATCH_SIZE = 128 52 | 53 | # Sometimes overridden 54 | EVAL_STEPS = 20 55 | EVAL_PERIOD = 1000 56 | 57 | # Convenience overrides. 58 | EVALUATOR_USE_MEMORY_CACHE = True 59 | EVALUATOR_NUM_EXAMPLES = None # Use all examples in the infer_eval dataset. 
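# These, and JSON_WRITE_N_RESULTS below, can also be overridden from the
# launch command, e.g. (hypothetical values):
#   --gin.EVALUATOR_NUM_EXAMPLES=512 --gin.BATCH_SIZE=32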
60 | JSON_WRITE_N_RESULTS = None # Write all inferences. 61 | # HW RNG is faster than SW, but has limited determinism. 62 | # Most notably it is not deterministic across different 63 | # submeshes. 64 | USE_HARDWARE_RNG = False 65 | # None always uses faster, hardware RNG 66 | RANDOM_SEED = None 67 | 68 | # DEPRECATED: Import the this module in your gin file. 69 | MIXTURE_OR_TASK_MODULE = None 70 | TARGET_FIELD_NAME = 'text_targets' 71 | 72 | train_script.train: 73 | model = %MODEL # imported from separate gin file 74 | model_dir = %MODEL_DIR 75 | train_dataset_cfg = @train/utils.DatasetConfig() 76 | train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() 77 | infer_eval_dataset_cfg = None 78 | checkpoint_cfg = @utils.CheckpointConfig() 79 | partitioner = @partitioning.PjitPartitioner() 80 | trainer_cls = @trainer.Trainer 81 | total_steps = %TRAIN_STEPS 82 | eval_steps = %EVAL_STEPS 83 | eval_period = %EVAL_PERIOD 84 | random_seed = %RANDOM_SEED 85 | use_hardware_rng = %USE_HARDWARE_RNG 86 | summarize_config_fn = @gin_utils.summarize_gin_config 87 | inference_evaluator_cls = @uio_evaluator.UnifiedIOEvaluator 88 | 89 | partitioning.PjitPartitioner: 90 | num_partitions = 1 91 | model_parallel_submesh = None 92 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 93 | 94 | uio_evaluator.UnifiedIOEvaluator: 95 | logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] 96 | num_examples = %EVALUATOR_NUM_EXAMPLES 97 | use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE 98 | 99 | seqio.JSONLogger: 100 | write_n_results = %JSON_WRITE_N_RESULTS 101 | 102 | train/utils.DatasetConfig: 103 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME 104 | task_feature_lengths = %TASK_FEATURE_LENGTHS_TRAIN 105 | split = 'train' 106 | batch_size = %BATCH_SIZE 107 | shuffle = True 108 | seed = None # use a new seed each run/restart 109 | use_cached = %USE_CACHED_TASKS 110 | pack = False 111 | module = %MIXTURE_OR_TASK_MODULE 112 | 113 | train_eval/utils.DatasetConfig: 114 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME_EVAL 115 | task_feature_lengths = %TASK_FEATURE_LENGTHS_EVAL 116 | split = 'validation' 117 | batch_size = %BATCH_SIZE 118 | shuffle = False 119 | seed = 42 120 | use_cached = %USE_CACHED_TASKS 121 | pack = False 122 | module = %MIXTURE_OR_TASK_MODULE 123 | 124 | infer_eval/utils.DatasetConfig: 125 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME_EVAL 126 | task_feature_lengths = %TASK_FEATURE_LENGTHS_EVAL # compute max 127 | split = 'validation' 128 | batch_size = %BATCH_SIZE 129 | shuffle = False 130 | seed = 42 131 | use_cached = %USE_CACHED_TASKS 132 | pack = False 133 | module = %MIXTURE_OR_TASK_MODULE 134 | target_field_name = %TARGET_FIELD_NAME 135 | 136 | utils.CheckpointConfig: 137 | restore = @utils.RestoreCheckpointConfig() 138 | save = @utils.SaveCheckpointConfig() 139 | utils.RestoreCheckpointConfig: 140 | path = %INITIAL_CHECKPOINT_PATH 141 | mode = 'specific' 142 | dtype = 'float32' 143 | strict = False 144 | 145 | utils.SaveCheckpointConfig: 146 | period = 5000 147 | dtype = 'float32' 148 | keep = None # keep all checkpoints 149 | save_dataset = False # don't checkpoint dataset state 150 | 151 | trainer.Trainer: 152 | num_microbatches = None 153 | learning_rate_fn = @utils.create_learning_rate_scheduler() 154 | 155 | utils.create_learning_rate_scheduler: 156 | factors = 'constant' 157 | base_learning_rate = 0.0001 158 | warmup_steps = 1000 159 | -------------------------------------------------------------------------------- /t5x/configs/runs/infer.gin: 
-------------------------------------------------------------------------------- 1 | # Defaults for infer.py. 2 | # 3 | # 4 | # You must also include a binding for MODEL. 5 | # 6 | # Required to be set: 7 | # 8 | # - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to use for inference 9 | # - TASK_FEATURE_LENGTHS: The lengths per key in the SeqIO Task to trim features 10 | # to. 11 | # - CHECKPOINT_PATH: The model checkpoint to use for inference 12 | # - INFER_OUTPUT_DIR: The dir to write results to. 13 | # 14 | # 15 | # Commonly overridden options: 16 | # 17 | # - infer.mode 18 | # - infer.checkpoint_period 19 | # - infer.shard_id 20 | # - infer.num_shards 21 | # - DatasetConfig.split 22 | # - DatasetConfig.batch_size 23 | # - DatasetConfig.use_cached 24 | # - RestoreCheckpointConfig.is_tensorflow 25 | # - RestoreCheckpointConfig.mode 26 | # - PjitPartitioner.num_partitions 27 | from __gin__ import dynamic_registration 28 | 29 | import __main__ as infer_script 30 | from t5x import partitioning 31 | from t5x import utils 32 | 33 | # Must be overridden 34 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED 35 | TASK_FEATURE_LENGTHS = %gin.REQUIRED 36 | CHECKPOINT_PATH = %gin.REQUIRED 37 | INFER_OUTPUT_DIR = %gin.REQUIRED 38 | CHECKPOINT_PERIOD = %gin.REQUIRED 39 | # DEPRECATED: Import the this module in your gin file. 40 | MIXTURE_OR_TASK_MODULE = None 41 | 42 | infer_script.infer: 43 | mode = 'predict_with_aux' 44 | model = %MODEL # imported from separate gin file 45 | output_dir = %INFER_OUTPUT_DIR 46 | dataset_cfg = @utils.DatasetConfig() 47 | partitioner = @partitioning.PjitPartitioner() 48 | restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() 49 | checkpoint_period = %CHECKPOINT_PERIOD 50 | shard_id = 0 51 | num_shards = 1 52 | checkpoint_ds_iter = False 53 | 54 | partitioning.PjitPartitioner: 55 | num_partitions = 1 56 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 57 | 58 | utils.DatasetConfig: 59 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME 60 | module = %MIXTURE_OR_TASK_MODULE 61 | task_feature_lengths = %TASK_FEATURE_LENGTHS 62 | use_cached = False 63 | split = 'test' 64 | batch_size = 32 65 | shuffle = False 66 | seed = 0 67 | pack = False 68 | 69 | utils.RestoreCheckpointConfig: 70 | path = %CHECKPOINT_PATH 71 | mode = 'specific' 72 | dtype = 'bfloat16' 73 | -------------------------------------------------------------------------------- /t5x/configs/runs/infer_from_tfexample_file.gin: -------------------------------------------------------------------------------- 1 | # Defaults for infer.py if using a TFExample file as input. 2 | # 3 | # 4 | # The features from each TFExample are tokenized using the model's vocabulary. 5 | # By default, the inputs feature is assumed to be keyed as 'inputs', but this 6 | # can be overridden with `create_task_from_tfexample_file.inputs_key`. 7 | # 8 | # You must also include a binding for MODEL. 9 | # 10 | # Required to be set: 11 | # 12 | # - TF_EXAMPLE_FILE_PATHS: The path to read TF Examples from. 13 | # - TF_EXAMPLE_FILE_TYPE: The type of file to read TF Examples from. Currently 14 | # supported: 'tfrecord', 'recordio', 'sstable'. 15 | # - FEATURE_LENGTHS: The maximum length per feature in the TF Examples. 16 | # - CHECKPOINT_PATH: The model checkpoint to use for inference 17 | # - INFER_OUTPUT_DIR: The dir to write results to. 
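# For example (hypothetical values):
#
#   TF_EXAMPLE_FILE_PATHS = ['/data/infer.tfrecord*']
#   TF_EXAMPLE_FILE_TYPE = 'tfrecord'
#   FEATURE_LENGTHS = {'inputs': 256, 'targets': 128}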
18 | # 19 | # 20 | # Commonly overridden options: 21 | # 22 | # - infer.mode 23 | # - infer.checkpoint_period 24 | # - infer.shard_id 25 | # - infer.num_shards 26 | # - create_task_from_tfexample_file.inputs_key 27 | # - create_task_from_tfexample_file.targets_key 28 | # - DatasetConfig.split 29 | # - DatasetConfig.batch_size 30 | # - RestoreCheckpointConfig.mode 31 | # - PjitPartitioner.num_partitions 32 | from __gin__ import dynamic_registration 33 | 34 | import __main__ as infer_script 35 | import seqio 36 | from t5x import models 37 | from t5x import partitioning 38 | from t5x import utils 39 | 40 | # Must be overridden 41 | TF_EXAMPLE_FILE_PATHS = %gin.REQUIRED 42 | TF_EXAMPLE_FILE_TYPE = %gin.REQUIRED 43 | FEATURE_LENGTHS = %gin.REQUIRED 44 | CHECKPOINT_PATH = %gin.REQUIRED 45 | INFER_OUTPUT_DIR = %gin.REQUIRED 46 | 47 | infer_script.infer: 48 | mode = 'predict' 49 | model = %MODEL # imported from separate gin file 50 | output_dir = %INFER_OUTPUT_DIR 51 | dataset_cfg = @utils.DatasetConfig() 52 | partitioner = @partitioning.PjitPartitioner() 53 | restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() 54 | checkpoint_period = 100 55 | shard_id = 0 56 | num_shards = 1 57 | 58 | partitioning.PjitPartitioner: 59 | num_partitions = 1 60 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 61 | 62 | utils.DatasetConfig: 63 | mixture_or_task_name = @infer_script.create_task_from_tfexample_file() 64 | task_feature_lengths = %FEATURE_LENGTHS 65 | split = 'infer' 66 | batch_size = 32 67 | shuffle = False 68 | seed = 0 69 | pack = False 70 | 71 | infer_script.create_task_from_tfexample_file: 72 | paths = %TF_EXAMPLE_FILE_PATHS 73 | file_type = %TF_EXAMPLE_FILE_TYPE 74 | inputs_key = 'inputs' 75 | targets_key = None 76 | features = {'inputs': @inputs/seqio.Feature(), 'targets': @outputs/seqio.Feature()} 77 | 78 | # Plumbing to extract the vocabulary directly from MODEL. This is needed to 79 | # tokenize the features from the TFExample, since we aren't provided with 80 | # vocabularies via a Task. 81 | inputs/seqio.Feature.vocabulary = @models.get_input_vocabulary() 82 | models.get_input_vocabulary.model = %MODEL 83 | outputs/seqio.Feature.vocabulary = @models.get_output_vocabulary() 84 | models.get_output_vocabulary.model = %MODEL 85 | 86 | utils.RestoreCheckpointConfig: 87 | mode = 'specific' 88 | path = %CHECKPOINT_PATH 89 | dtype = 'bfloat16' 90 | 91 | -------------------------------------------------------------------------------- /t5x/configs/runs/loss.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | import __main__ as loss_script 4 | import seqio 5 | from t5x import partitioning 6 | from t5x import utils 7 | import t5x.examples.unified_io.evaluator as uio_evaluator 8 | 9 | # Must be overridden 10 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED 11 | CHECKPOINT_PATH = %gin.REQUIRED 12 | EVAL_OUTPUT_DIR = %gin.REQUIRED 13 | EVAL_STEPS = %gin.REQUIRED 14 | TASK_FEATURE_LENGTHS = None # None auto-computes the maximum feature lengths to use. 15 | 16 | # DEPRECATED: Import this module in your gin file.
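# (i.e., prefer adding an `import` line for the module that registers your
# tasks to your gin file; this binding remains only for backwards
# compatibility.)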
17 | MIXTURE_OR_TASK_MODULE = None 18 | 19 | loss_script.compute_loss: 20 | model = %MODEL # imported from separate gin file 21 | dataset_cfg = @utils.DatasetConfig() 22 | partitioner = @partitioning.PjitPartitioner() 23 | restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() 24 | output_dir = %EVAL_OUTPUT_DIR 25 | eval_steps = %EVAL_STEPS 26 | 27 | partitioning.PjitPartitioner: 28 | num_partitions = 1 29 | model_parallel_submesh = None 30 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 31 | 32 | utils.DatasetConfig: 33 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME 34 | task_feature_lengths = %TASK_FEATURE_LENGTHS 35 | split = 'validation' 36 | batch_size = 32 37 | shuffle = False 38 | seed = 42 39 | use_cached = False 40 | use_memory_cache = False 41 | pack = False 42 | use_custom_packing_ops = False 43 | module = %MIXTURE_OR_TASK_MODULE 44 | 45 | utils.RestoreCheckpointConfig: 46 | path = %CHECKPOINT_PATH 47 | mode = 'specific' 48 | -------------------------------------------------------------------------------- /t5x/configs/runs/multitask.gin: -------------------------------------------------------------------------------- 1 | # Defaults for multitask training with train.py. 2 | # 3 | # 4 | # You must also include a binding for MODEL. 5 | # 6 | # Required to be set: 7 | # 8 | # - MIXTURE_OR_TASK_NAME 9 | # - TASK_FEATURE_LENGTHS 10 | # - TRAIN_STEPS # includes pretrain steps 11 | # - MODEL_DIR # automatically set when using xm_launch 12 | # - INITIAL_CHECKPOINT_PATH 13 | # 14 | # When running locally, MODEL_DIR needs to be passed via the `gin.MODEL_DIR` flag. 15 | # 16 | # `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt 17 | # has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps. 18 | # 19 | # Commonly overridden options: 20 | # - DROPOUT_RATE 21 | # - BATCH_SIZE 22 | # - PjitPartitioner.num_partitions 23 | # - Trainer.num_microbatches 24 | # - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess 25 | # on the fly. Most common tasks are cached, though this config defaults it 26 | # to False. 27 | 28 | from __gin__ import dynamic_registration 29 | 30 | import __main__ as train_script 31 | import seqio 32 | from t5x import gin_utils 33 | from t5x import partitioning 34 | from t5x import utils 35 | from t5x import trainer 36 | import t5x.examples.unified_io.evaluator as uio_evaluator 37 | 38 | # Must be overridden 39 | MODEL_DIR = %gin.REQUIRED 40 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED 41 | MIXTURE_OR_TASK_NAME_EVAL = %gin.REQUIRED 42 | TASK_FEATURE_LENGTHS_TRAIN = %gin.REQUIRED 43 | TASK_FEATURE_LENGTHS_EVAL = %gin.REQUIRED 44 | MIXTURE_OR_TASK_MODULE = %gin.REQUIRED 45 | TRAIN_STEPS = %gin.REQUIRED 46 | INITIAL_CHECKPOINT_PATH = %gin.REQUIRED 47 | 48 | # Commonly overridden 49 | DROPOUT_RATE = 0.1 50 | USE_CACHED_TASKS = False 51 | BATCH_SIZE = 128 52 | 53 | # Sometimes overridden 54 | EVAL_STEPS = 20 55 | 56 | # Convenience overrides. 57 | EVALUATOR_USE_MEMORY_CACHE = False 58 | EVALUATOR_NUM_EXAMPLES = None # Use all examples in the infer_eval dataset. 59 | JSON_WRITE_N_RESULTS = None # Write all inferences. 60 | # HW RNG is faster than SW, but has limited determinism. 61 | # Most notably it is not deterministic across different 62 | # submeshes. 63 | USE_HARDWARE_RNG = False 64 | # None always uses faster, hardware RNG 65 | RANDOM_SEED = None 66 | 67 | # DEPRECATED: Import this module in your gin file.
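# Mixtures named by MIXTURE_OR_TASK_NAME are registered through seqio's
# MixtureRegistry; data/mixtures.py, for instance, registers 'refexp' as an
# equal-rate mix of the three RefCOCO tasks.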
68 | MIXTURE_OR_TASK_MODULE = None 69 | TARGET_FIELD_NAME = 'text_targets' 70 | 71 | train_script.train: 72 | model = %MODEL # imported from separate gin file 73 | model_dir = %MODEL_DIR 74 | train_dataset_cfg = @train/utils.DatasetConfig() 75 | train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() 76 | infer_eval_dataset_cfg = None 77 | checkpoint_cfg = @utils.CheckpointConfig() 78 | partitioner = @partitioning.PjitPartitioner() 79 | trainer_cls = @trainer.Trainer 80 | total_steps = %TRAIN_STEPS 81 | eval_steps = %EVAL_STEPS 82 | eval_period = 1000 83 | random_seed = %RANDOM_SEED 84 | use_hardware_rng = %USE_HARDWARE_RNG 85 | summarize_config_fn = @gin_utils.summarize_gin_config 86 | inference_evaluator_cls = @uio_evaluator.UnifiedIOEvaluator 87 | 88 | partitioning.PjitPartitioner: 89 | num_partitions = 1 90 | model_parallel_submesh = None 91 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 92 | 93 | uio_evaluator.UnifiedIOEvaluator: 94 | logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] 95 | num_examples = %EVALUATOR_NUM_EXAMPLES 96 | use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE 97 | 98 | seqio.JSONLogger: 99 | write_n_results = %JSON_WRITE_N_RESULTS 100 | 101 | train/utils.DatasetConfig: 102 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME 103 | task_feature_lengths = %TASK_FEATURE_LENGTHS_TRAIN 104 | split = 'train' 105 | batch_size = %BATCH_SIZE 106 | shuffle = True 107 | seed = None # use a new seed each run/restart 108 | use_cached = %USE_CACHED_TASKS 109 | pack = False 110 | module = %MIXTURE_OR_TASK_MODULE 111 | 112 | train_eval/utils.DatasetConfig: 113 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME_EVAL 114 | task_feature_lengths = %TASK_FEATURE_LENGTHS_EVAL 115 | split = 'validation' 116 | batch_size = %BATCH_SIZE 117 | shuffle = False 118 | seed = 42 119 | use_cached = %USE_CACHED_TASKS 120 | pack = False 121 | module = %MIXTURE_OR_TASK_MODULE 122 | 123 | utils.CheckpointConfig: 124 | restore = @utils.RestoreCheckpointConfig() 125 | save = @utils.SaveCheckpointConfig() 126 | 127 | utils.RestoreCheckpointConfig: 128 | path = %INITIAL_CHECKPOINT_PATH 129 | mode = 'specific' 130 | dtype = 'float32' 131 | strict = False 132 | 133 | utils.SaveCheckpointConfig: 134 | period = 50000 135 | dtype = 'float32' 136 | keep = None # keep all checkpoints 137 | save_dataset = False # don't checkpoint dataset state 138 | 139 | trainer.Trainer: 140 | num_microbatches = None 141 | learning_rate_fn = @utils.create_learning_rate_scheduler() 142 | 143 | utils.create_learning_rate_scheduler: 144 | factors = 'constant * rsqrt_decay' 145 | base_learning_rate = 1.0 146 | warmup_steps = 10000 # 10k to keep consistent with T5/MTF defaults. 147 | -------------------------------------------------------------------------------- /t5x/configs/runs/precompile.gin: -------------------------------------------------------------------------------- 1 | # Defaults for precompile mode in main.py. 2 | # 3 | # You must also include a binding for MODEL. 
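# (typically by also including a model gin file, e.g.
# t5x/examples/unified_io/t5_1_1/base.gin, which defines MODEL)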
4 | # 5 | # Required to be set: 6 | # 7 | # - MIXTURE_OR_TASK_NAME 8 | # - TASK_FEATURE_LENGTHS 9 | # - TRAIN_STEPS 10 | # - MODEL_DIR: # automatically set when using xm_launch 11 | # 12 | # Commonly overridden options: 13 | # 14 | # - USE_CACHED_TASKS 15 | # - BATCH_SIZE 16 | # - PjitPartitioner.num_partitions 17 | from __gin__ import dynamic_registration 18 | 19 | import __main__ as train_script 20 | import seqio 21 | from t5x import gin_utils 22 | from t5x import partitioning 23 | from t5x import utils 24 | from t5x import trainer 25 | 26 | MODEL_DIR = %gin.REQUIRED 27 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED 28 | TASK_FEATURE_LENGTHS = %gin.REQUIRED 29 | 30 | 31 | # Commonly overridden 32 | USE_CACHED_TASKS = True 33 | BATCH_SIZE = 128 34 | 35 | # None always uses faster, hardware RNG 36 | RANDOM_SEED = None 37 | 38 | train_script.precompile: 39 | model = %MODEL # imported from separate gin file 40 | model_dir = %MODEL_DIR 41 | train_dataset_cfg = @train/utils.DatasetConfig() 42 | partitioner = @partitioning.PjitPartitioner() 43 | random_seed = %RANDOM_SEED 44 | 45 | partitioning.PjitPartitioner: 46 | num_partitions = 1 47 | model_parallel_submesh = None 48 | backend = "tpu" 49 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 50 | 51 | train/utils.DatasetConfig: 52 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME 53 | task_feature_lengths = %TASK_FEATURE_LENGTHS 54 | split = 'train' 55 | batch_size = %BATCH_SIZE 56 | shuffle = True 57 | seed = None # use a new seed each run/restart 58 | use_cached = %USE_CACHED_TASKS 59 | pack = True -------------------------------------------------------------------------------- /t5x/configs/runs/pretrain.gin: -------------------------------------------------------------------------------- 1 | # Defaults for pretraining with train.py. 2 | # 3 | # 4 | # You must also include a binding for MODEL. 5 | # 6 | # Required to be set: 7 | # 8 | # - MIXTURE_OR_TASK_NAME 9 | # - TASK_FEATURE_LENGTHS 10 | # - TRAIN_STEPS # includes pretrain steps 11 | # - MODEL_DIR # automatically set when using xm_launch 12 | # - INITIAL_CHECKPOINT_PATH 13 | # 14 | # When running locally, MODEL_DIR needs to be passed via the `gin.MODEL_DIR` flag. 15 | # 16 | # `TRAIN_STEPS` should include any steps already in the initial ckpt, e.g., if 17 | # the initial ckpt has 1M steps, TRAIN_STEPS = 1.1M will run 0.1M more steps. 18 | # 19 | # Commonly overridden options: 20 | # - DROPOUT_RATE 21 | # - BATCH_SIZE 22 | # - PjitPartitioner.num_partitions 23 | # - Trainer.num_microbatches 24 | # - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess 25 | # on the fly. Most common tasks are cached, though this config defaults it 26 | # to False.
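# A sketch of a local launch (paths and values are hypothetical; the remaining
# %gin.REQUIRED bindings below must also be supplied):
#
#   python t5x/train.py \
#     --gin_file=t5x/examples/unified_io/t5_1_1/base.gin \
#     --gin_file=t5x/configs/runs/pretrain.gin \
#     --gin.MODEL_DIR="'/tmp/pretrain_run'" \
#     --gin.TRAIN_STEPS=100000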
27 | 28 | from __gin__ import dynamic_registration 29 | 30 | import __main__ as train_script 31 | import seqio 32 | from t5x import gin_utils 33 | from t5x import partitioning 34 | from t5x import utils 35 | from t5x import trainer 36 | import t5x.examples.unified_io.evaluator as uio_evaluator 37 | 38 | # Must be overridden 39 | MODEL_DIR = %gin.REQUIRED 40 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED 41 | MIXTURE_OR_TASK_NAME_EVAL = %gin.REQUIRED 42 | TASK_FEATURE_LENGTHS_TRAIN = %gin.REQUIRED 43 | TASK_FEATURE_LENGTHS_EVAL = %gin.REQUIRED 44 | MIXTURE_OR_TASK_MODULE = %gin.REQUIRED 45 | TRAIN_STEPS = %gin.REQUIRED 46 | INITIAL_CHECKPOINT_PATH = %gin.REQUIRED 47 | 48 | # Commonly overridden 49 | DROPOUT_RATE = 0.1 50 | USE_CACHED_TASKS = False 51 | BATCH_SIZE = 128 52 | 53 | # Sometimes overridden 54 | EVAL_STEPS = 20 55 | 56 | # Convenience overrides. 57 | EVALUATOR_USE_MEMORY_CACHE = False 58 | EVALUATOR_NUM_EXAMPLES = None # Use all examples in the infer_eval dataset. 59 | JSON_WRITE_N_RESULTS = None # Write all inferences. 60 | # HW RNG is faster than SW, but has limited determinism. 61 | # Most notably it is not deterministic across different 62 | # submeshes. 63 | USE_HARDWARE_RNG = False 64 | # None always uses faster, hardware RNG 65 | RANDOM_SEED = None 66 | 67 | # DEPRECATED: Import the this module in your gin file. 68 | MIXTURE_OR_TASK_MODULE = None 69 | TARGET_FIELD_NAME = 'text_targets' 70 | 71 | train_script.train: 72 | model = %MODEL # imported from separate gin file 73 | model_dir = %MODEL_DIR 74 | train_dataset_cfg = @train/utils.DatasetConfig() 75 | train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() 76 | infer_eval_dataset_cfg = None 77 | checkpoint_cfg = @utils.CheckpointConfig() 78 | partitioner = @partitioning.PjitPartitioner() 79 | trainer_cls = @trainer.Trainer 80 | total_steps = %TRAIN_STEPS 81 | eval_steps = %EVAL_STEPS 82 | eval_period = 1000 83 | random_seed = %RANDOM_SEED 84 | use_hardware_rng = %USE_HARDWARE_RNG 85 | summarize_config_fn = @gin_utils.summarize_gin_config 86 | inference_evaluator_cls = @uio_evaluator.UnifiedIOEvaluator 87 | use_gda = True 88 | partitioning.PjitPartitioner: 89 | num_partitions = 1 90 | model_parallel_submesh = None 91 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 92 | 93 | uio_evaluator.UnifiedIOEvaluator: 94 | logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] 95 | num_examples = %EVALUATOR_NUM_EXAMPLES 96 | use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE 97 | 98 | seqio.JSONLogger: 99 | write_n_results = %JSON_WRITE_N_RESULTS 100 | 101 | train/utils.DatasetConfig: 102 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME 103 | task_feature_lengths = %TASK_FEATURE_LENGTHS_TRAIN 104 | split = 'train' 105 | batch_size = %BATCH_SIZE 106 | shuffle = True 107 | seed = None # use a new seed each run/restart 108 | use_cached = %USE_CACHED_TASKS 109 | pack = False 110 | module = %MIXTURE_OR_TASK_MODULE 111 | 112 | train_eval/utils.DatasetConfig: 113 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME_EVAL 114 | task_feature_lengths = %TASK_FEATURE_LENGTHS_EVAL 115 | split = 'validation' 116 | batch_size = %BATCH_SIZE 117 | shuffle = False 118 | seed = 42 119 | use_cached = %USE_CACHED_TASKS 120 | pack = False 121 | module = %MIXTURE_OR_TASK_MODULE 122 | 123 | utils.CheckpointConfig: 124 | restore = @utils.RestoreCheckpointConfig() 125 | save = @utils.SaveCheckpointConfig() 126 | 127 | utils.RestoreCheckpointConfig: 128 | path = %INITIAL_CHECKPOINT_PATH 129 | mode = 'specific' 130 | dtype = 
'float32' 131 | strict = False 132 | 133 | utils.SaveCheckpointConfig: 134 | period = 50000 135 | dtype = 'float32' 136 | keep = None # keep all checkpoints 137 | save_dataset = False # don't checkpoint dataset state 138 | 139 | # trainer.Trainer: 140 | # num_microbatches = None 141 | # learning_rate_fn = @utils.create_learning_rate_scheduler() 142 | 143 | # utils.create_learning_rate_scheduler: 144 | # factors = 'constant * rsqrt_decay' 145 | # base_learning_rate = 1.0 146 | # warmup_steps = 10000 # 10k to keep consistent with T5/MTF defaults. 147 | -------------------------------------------------------------------------------- /t5x/configs/runs/vit_vqgan.gin: -------------------------------------------------------------------------------- 1 | # Defaults for finetuning with train.py. 2 | # 3 | # 4 | # You must also include a binding for MODEL. 5 | # 6 | # Required to be set: 7 | # 8 | # - MIXTURE_OR_TASK_NAME 9 | # - TASK_FEATURE_LENGTHS 10 | # - TRAIN_STEPS # includes pretrain steps 11 | # - MODEL_DIR # automatically set when using xm_launch 12 | # - INITIAL_CHECKPOINT_PATH 13 | # 14 | # When running locally, it needs to be passed in the `gin.MODEL_DIR` flag. 15 | # 16 | # `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt 17 | # has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps. 18 | # 19 | # Commonly overridden options: 20 | # - DROPOUT_RATE 21 | # - BATCH_SIZE 22 | # - PjitPartitioner.num_partitions 23 | # - Trainer.num_microbatches 24 | # - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess 25 | # on the fly. Most common tasks are cached, hence this is set to True by 26 | # default. 27 | 28 | from __gin__ import dynamic_registration 29 | 30 | import __main__ as train_script 31 | import seqio 32 | from t5x import gin_utils 33 | from t5x import partitioning 34 | from t5x import utils 35 | from t5x import trainer 36 | import t5x.examples.unified_io.evaluator as uio_evaluator 37 | from t5x import train_state as train_state_lib 38 | 39 | # Must be overridden 40 | MODEL_DIR = %gin.REQUIRED 41 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED 42 | MIXTURE_OR_TASK_NAME_EVAL = %gin.REQUIRED 43 | TASK_FEATURE_LENGTHS_TRAIN = %gin.REQUIRED 44 | TASK_FEATURE_LENGTHS_EVAL = %gin.REQUIRED 45 | MIXTURE_OR_TASK_MODULE = %gin.REQUIRED 46 | TRAIN_STEPS = %gin.REQUIRED 47 | INITIAL_CHECKPOINT_PATH = %gin.REQUIRED 48 | 49 | # Commonly overridden 50 | DROPOUT_RATE = 0.1 51 | USE_CACHED_TASKS = True 52 | BATCH_SIZE = 128 53 | 54 | # Sometimes overridden 55 | EVAL_STEPS = 20 56 | 57 | # Convenience overrides. 58 | EVALUATOR_USE_MEMORY_CACHE = True 59 | EVALUATOR_NUM_EXAMPLES = None # Use all examples in the infer_eval dataset. 60 | JSON_WRITE_N_RESULTS = None # Write all inferences. 61 | # HW RNG is faster than SW, but has limited determinism. 62 | # Most notably it is not deterministic across different 63 | # submeshes. 64 | USE_HARDWARE_RNG = False 65 | # None always uses faster, hardware RNG 66 | RANDOM_SEED = None 67 | 68 | # DEPRECATED: Import the this module in your gin file. 
69 | MIXTURE_OR_TASK_MODULE = None 70 | TARGET_FIELD_NAME = 'text_targets' 71 | 72 | train_script.train: 73 | model = %MODEL # imported from separate gin file 74 | model_dir = %MODEL_DIR 75 | train_dataset_cfg = @train/utils.DatasetConfig() 76 | train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() 77 | infer_eval_dataset_cfg = None 78 | checkpoint_cfg = @utils.CheckpointConfig() 79 | partitioner = @partitioning.PjitPartitioner() 80 | trainer_cls = @trainer.Trainer 81 | total_steps = %TRAIN_STEPS 82 | eval_steps = %EVAL_STEPS 83 | eval_period = 1000 84 | random_seed = %RANDOM_SEED 85 | use_hardware_rng = %USE_HARDWARE_RNG 86 | summarize_config_fn = @gin_utils.summarize_gin_config 87 | inference_evaluator_cls = @uio_evaluator.UnifiedIOEvaluator 88 | 89 | partitioning.PjitPartitioner: 90 | num_partitions = 1 91 | model_parallel_submesh = None 92 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 93 | 94 | uio_evaluator.UnifiedIOEvaluator: 95 | logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] 96 | num_examples = %EVALUATOR_NUM_EXAMPLES 97 | use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE 98 | 99 | seqio.JSONLogger: 100 | write_n_results = %JSON_WRITE_N_RESULTS 101 | 102 | train/utils.DatasetConfig: 103 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME 104 | task_feature_lengths = %TASK_FEATURE_LENGTHS_TRAIN 105 | split = 'train' 106 | batch_size = %BATCH_SIZE 107 | shuffle = True 108 | seed = None # use a new seed each run/restart 109 | use_cached = %USE_CACHED_TASKS 110 | pack = False 111 | module = %MIXTURE_OR_TASK_MODULE 112 | 113 | train_eval/utils.DatasetConfig: 114 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME_EVAL 115 | task_feature_lengths = %TASK_FEATURE_LENGTHS_EVAL 116 | split = 'validation' 117 | batch_size = %BATCH_SIZE 118 | shuffle = False 119 | seed = 42 120 | use_cached = %USE_CACHED_TASKS 121 | pack = False 122 | module = %MIXTURE_OR_TASK_MODULE 123 | 124 | infer_eval/utils.DatasetConfig: 125 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME_EVAL 126 | task_feature_lengths = %TASK_FEATURE_LENGTHS_EVAL # compute max 127 | split = 'validation' 128 | batch_size = %BATCH_SIZE 129 | shuffle = False 130 | seed = 42 131 | use_cached = %USE_CACHED_TASKS 132 | pack = False 133 | module = %MIXTURE_OR_TASK_MODULE 134 | target_field_name = %TARGET_FIELD_NAME 135 | 136 | utils.CheckpointConfig: 137 | restore = @utils.RestoreCheckpointConfig() 138 | save = @utils.SaveCheckpointConfig() 139 | 140 | utils.RestoreCheckpointConfig: 141 | path = %INITIAL_CHECKPOINT_PATH 142 | mode = 'specific' 143 | dtype = 'float32' 144 | strict = False 145 | 146 | utils.SaveCheckpointConfig: 147 | period = 1000 148 | dtype = 'float32' 149 | keep = 3 # keep only the 3 most recent checkpoints 150 | save_dataset = False # don't checkpoint dataset state -------------------------------------------------------------------------------- /t5x/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/unified-io-2/502ac4d81239f82c891a9f412b000c3c8d4e2946/t5x/examples/__init__.py -------------------------------------------------------------------------------- /t5x/examples/unified_io/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This empty file is needed for loading the gin files in this directory. 16 | -------------------------------------------------------------------------------- /t5x/examples/unified_io/aux_fns.py: -------------------------------------------------------------------------------- 1 | """Logit masking and checkpoint transformation functions""" 2 | 3 | import functools 4 | from typing import Dict 5 | 6 | import gin 7 | import jax 8 | import jax.numpy as jnp 9 | import numpy as np 10 | 11 | from t5x import state_utils 12 | from t5x.examples.unified_io import config 13 | from t5x.examples.unified_io.data.data_utils import get_default_vocabulary 14 | from flax import linen as nn 15 | 16 | 17 | def apply_mask(logits, cur_index, mask): 18 | if len(mask.shape) == 3: 19 | mask = mask[:, cur_index] 20 | elif len(mask.shape) == 2: 21 | mask = jnp.reshape(mask[cur_index], [1, -1]) 22 | else: 23 | mask = jnp.reshape(mask, [1, -1]) 24 | if mask.dtype == jnp.bool_: 25 | flat_logits = jnp.where(mask, -1e10, logits) 26 | else: 27 | flat_logits = mask + logits 28 | return flat_logits 29 | 30 | 31 | def clf_free_logit_mask_fn(logits, _, num_decodes, alpha=10.0): 32 | logits = nn.log_softmax(logits) 33 | logits = jnp.reshape(logits, [-1, num_decodes, logits.shape[-1]]) 34 | bs = logits.shape[0] 35 | logits, clf_free = logits[:bs//2], logits[bs//2:] 36 | logits = (1 + alpha) * logits - alpha * clf_free 37 | return jnp.reshape(jnp.tile(logits, (2, 1, 1)), [bs*num_decodes, -1]) 38 | 39 | 40 | def clf_free_next_token_callback(logits, next_token, num_decodes): 41 | # The classifier free examples need to follow the non-clf-free versions 42 | next_token = jnp.reshape(next_token, [-1, num_decodes]) 43 | bs = next_token.shape[0] 44 | next_token = jnp.tile(next_token[:bs//2], (2, 1)) 45 | return jnp.reshape(next_token, -1) 46 | 47 | 48 | @gin.configurable() 49 | def pose_estimation_mask_fn_part_names(_, lengths): 50 | """Mask logits so that the model only predicts pose part names and location points""" 51 | vocab = get_default_vocabulary() 52 | if config.TOKENIZER == "llama": 53 | vocab_size = 33280 54 | else: 55 | raise NotImplementedError() 56 | masks = [] 57 | loc_mask = np.ones([vocab_size], np.bool_) 58 | loc_mask[32000:33000] = False 59 | for part in config.HUMAN_POSE_PART: 60 | masks += [loc_mask, loc_mask] 61 | for voc_id in vocab.encode(part): 62 | mask = np.ones([vocab_size], np.bool_) 63 | mask[voc_id] = False 64 | masks.append(mask) 65 | eos_mask = np.ones([vocab_size], np.bool_) 66 | eos_mask[1] = 0 67 | masks.append(eos_mask) 68 | mask = jnp.array(np.stack(masks, 0)) 69 | return functools.partial(apply_mask, mask=mask) 70 | 71 | 72 | @gin.configurable 73 | def non_loc_select(_, lengths, thresh=0.5, require_one_box=False): 74 | """Mask logits so EOS is only selected if the total prob over location tokens is < `thresh`""" 75 | voc_size = 33280 if config.TOKENIZER == "llama" else 33152 + 16384 76 | loc_mask = np.zeros([voc_size], np.float32) 77 | loc_mask[:32000] = -10000 78 | loc_mask[33000:] = -10000 79 | loc_mask = jnp.array(loc_mask) 80 | 81 | def _fn(logits, 
cur_index): 82 | logits = jax.nn.log_softmax(logits) 83 | probs = jnp.exp(jax.scipy.special.logsumexp(logits[:, 32000:33000], axis=-1)) 84 | use_loc = probs > thresh 85 | if require_one_box: 86 | use_loc = jnp.logical_or(use_loc, cur_index <= 3) 87 | return logits + loc_mask[None, :] * use_loc[:, None] 88 | return _fn 89 | 90 | 91 | def state_transformation_fns(): 92 | fn = [ 93 | functools.partial( 94 | state_utils.apply_assignment_map, 95 | assignment_map=[ 96 | (r'state.*', None), 97 | ]) 98 | ] 99 | 100 | return fn 101 | 102 | 103 | def remove_optimizer_state(): 104 | fn = [ 105 | functools.partial( 106 | state_utils.apply_assignment_map, 107 | assignment_map=[ 108 | (r'state.*', None), 109 | ]) 110 | ] 111 | return fn 112 | 113 | 114 | def load_vae(state_dict, target_state_dict: Dict, *, is_resuming: bool = False, modality="image"): 115 | if is_resuming: 116 | return target_state_dict 117 | return dict(target={f"target_encoders_{modality}": dict(discrete_vae=state_dict["target"])}) 118 | 119 | 120 | def vit_vqgan_restore_fn(modality="image"): 121 | return [functools.partial(load_vae, modality=modality)] 122 | 123 | 124 | def load_vqgan(state_dict, target_state_dict: Dict, *, is_resuming: bool = False, modality="image"): 125 | if is_resuming: 126 | return target_state_dict 127 | if 'image_vitvqgan' in state_dict["target"]: 128 | return dict(target={f"target_encoders_{modality}": dict(discrete_vae=state_dict["target"]['image_vitvqgan'])}) 129 | else: 130 | return dict(target={f"target_encoders_{modality}": dict(discrete_vae=state_dict["target"])}) 131 | 132 | 133 | def load_vae_all(state_dict, target_state_dict: Dict, *, is_resuming: bool = False): 134 | if is_resuming: 135 | return target_state_dict 136 | 137 | return dict(target=dict( 138 | target_encoders_image=dict(discrete_vae=state_dict["target"]["image_vitvqgan"]), 139 | target_encoders_audio=dict(discrete_vae=state_dict["target"]["audio_vitvqgan"]))) 140 | 141 | 142 | def vit_vqgan_all_restore_fn(): 143 | """State transformation to map VQGAN checkpoint parameters into a UIO2 model""" 144 | return [load_vae_all] 145 | 146 | -------------------------------------------------------------------------------- /t5x/examples/unified_io/data/__init__.py: -------------------------------------------------------------------------------- 1 | from seqio.dataset_providers import * 2 | from seqio.utils import * 3 | from seqio.vocabularies import * -------------------------------------------------------------------------------- /t5x/examples/unified_io/data/mixtures.py: -------------------------------------------------------------------------------- 1 | from t5x.examples.unified_io.data import tasks 2 | from t5x.examples.unified_io.data import nlp_instruction_following 3 | from seqio import MixtureRegistry 4 | 5 | 6 | MixtureRegistry.add( 7 | "refexp", 8 | [ 9 | ("refcoco_plus_unc", 1.0), 10 | ("refcocog_google", 1.0), 11 | ("refcoco_unc", 1.0), 12 | ], 13 | ) 14 | -------------------------------------------------------------------------------- /t5x/examples/unified_io/data/postprocessing.py: -------------------------------------------------------------------------------- 1 | def return_example(output_or_target, example=None, is_target=False): 2 | if is_target: 3 | return example 4 | else: 5 | return output_or_target 6 | 7 | 8 | def return_meta(output_or_target, example=None, is_target=False): 9
| if is_target: 10 | return {k[5:]: v for k, v in example.items() if k.startswith("meta/")} 11 | else: 12 | return output_or_target 13 | 14 | 15 | def return_field(output_or_target, field, example=None, is_target=False): 16 | if is_target: 17 | return example[field] 18 | else: 19 | return output_or_target 20 | -------------------------------------------------------------------------------- /t5x/examples/unified_io/data/prompt_definition.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | 4 | import gin 5 | 6 | from t5x.examples.unified_io.data.prompt_dict import PROMPT_DICT 7 | 8 | 9 | @gin.configurable 10 | class Prompt: 11 | """Configurable interface for getting prompts""" 12 | 13 | def __init__(self, original_flag=True, revised_original_flag=False, manual_flag=True, 14 | gpt3_flag=True, single_prompt=False, dbg=None): 15 | self.prompt_list = [] 16 | self.original_flag = original_flag 17 | self.revised_original_flag = revised_original_flag 18 | self.manual_flag = manual_flag 19 | self.gpt3_flag = gpt3_flag 20 | self.single_prompt = single_prompt 21 | self.dbg = dbg 22 | 23 | def get_prompt_list(self, task_name, dataset_name): 24 | if self.dbg: 25 | logging.info(f"Using dbg prompt {self.dbg}") 26 | return [self.dbg] 27 | prompt_list = [] 28 | if self.original_flag: 29 | if self.revised_original_flag and 'revised_original' in PROMPT_DICT[task_name]: 30 | prompt_list += PROMPT_DICT[task_name]['revised_original'] 31 | else: 32 | prompt_list += PROMPT_DICT[task_name]['original'] 33 | if self.revised_original_flag and 'revised_original' in PROMPT_DICT[dataset_name]: 34 | prompt_list += PROMPT_DICT[dataset_name]['revised_original'] 35 | else: 36 | prompt_list += PROMPT_DICT[dataset_name]['original'] 37 | if self.manual_flag: 38 | if 'manual' in PROMPT_DICT[task_name]: 39 | prompt_list += PROMPT_DICT[task_name]['manual'] 40 | if 'manual' in PROMPT_DICT[dataset_name]: 41 | prompt_list += PROMPT_DICT[dataset_name]['manual'] 42 | if self.gpt3_flag: 43 | if 'gpt3' in PROMPT_DICT[task_name]: 44 | prompt_list += PROMPT_DICT[task_name]['gpt3'] 45 | 46 | if 'gpt3' in PROMPT_DICT[dataset_name]: 47 | prompt_list += PROMPT_DICT[dataset_name]['gpt3'] 48 | if not prompt_list: 49 | raise ValueError(f"No prompts for {task_name}/{dataset_name}") 50 | if self.single_prompt: 51 | logging.info(f"Using prompt \"{prompt_list[0]}\" for {task_name} {dataset_name}") 52 | return prompt_list[:1] 53 | return prompt_list 54 | 55 | -------------------------------------------------------------------------------- /t5x/examples/unified_io/data/tasks.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import seqio 4 | from seqio import TaskRegistry 5 | 6 | from t5x.examples.unified_io.metrics import metrics 7 | from t5x.examples.unified_io.data.postprocessing import return_meta, return_field 8 | from t5x.examples.unified_io.metrics.metrics import exact_match 9 | from t5x.examples.unified_io.modality_processing import unified_io_preprocessor 10 | 11 | from t5x.examples.unified_io import config, modality_processing 12 | from t5x.examples.unified_io.data import preprocessing 13 | from t5x.examples.unified_io.data.preprocessing import rekey 14 | 15 | 16 | def add_refexp(name, src_name=None): 17 | if src_name is None: 18 | src_name = name 19 | TaskRegistry.add( 20 | name, 21 | # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
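# (illustrative note) `rekey` below re-maps nested TFDS features onto the flat keys the later preprocessors expect; e.g. key_map={"bbox": ["objects", "bbox"]} would make example["bbox"] hold tfds_example["objects"]["bbox"].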
22 | source=seqio.TfdsDataSource( 23 | tfds_name=f"ref_coco/{src_name}:1.0.0", 24 | tfds_data_dir=config.MULTITASK_TFDS_DATA_DIR, 25 | ), 26 | preprocessors=[ 27 | functools.partial( 28 | rekey, key_map={ 29 | "image": ["image"], 30 | "bbox": ["objects", "bbox"], 31 | "label": ["objects", "refexp", "raw"], 32 | "refexp_id": ["objects", "refexp", "refexp_id"], 33 | }), 34 | functools.partial( 35 | preprocessing.refer_expression_preprocessor, 36 | dataset_name=name, 37 | ), 38 | unified_io_preprocessor, 39 | ], 40 | postprocess_fn=return_meta, 41 | metric_fns=[metrics.ref_exp_metric], 42 | output_features=modality_processing.OUTPUT_FEATURES, 43 | ) 44 | 45 | 46 | add_refexp("refcoco_unc") 47 | add_refexp("refcocog_google") 48 | add_refexp("refcoco_plus_unc", "refcocoplus_unc") 49 | 50 | 51 | TaskRegistry.add( 52 | "image_generation_coco_2017", 53 | source=seqio.TfdsDataSource( 54 | tfds_name="coco_all:1.0.1", 55 | tfds_data_dir=config.MULTITASK_TFDS_DATA_DIR, 56 | ), 57 | preprocessors=[ 58 | functools.partial( 59 | rekey, key_map={ 60 | "image/filename": ["image/filename"], 61 | "image": ["image"], 62 | "captions": ["captions", "text"] 63 | }), 64 | functools.partial( 65 | preprocessing.image_generation_preprocessor, 66 | dataset_name="image_generation_coco_2017", 67 | ), 68 | unified_io_preprocessor, 69 | ], 70 | output_features=modality_processing.OUTPUT_FEATURES, 71 | ) 72 | 73 | 74 | TaskRegistry.add( 75 | "image_caption_coco_2017", 76 | source=seqio.TfdsDataSource( 77 | tfds_name="coco_all:1.0.1", 78 | tfds_data_dir=config.MULTITASK_TFDS_DATA_DIR, 79 | ), 80 | preprocessors=[ 81 | functools.partial( 82 | rekey, key_map={ 83 | "image/filename": ["image/filename"], 84 | "image": ["image"], 85 | "captions": ["captions", "text"] 86 | }), 87 | functools.partial( 88 | preprocessing.image_caption_preprocessor, 89 | dataset_name="image_caption_coco_2017", 90 | ), 91 | unified_io_preprocessor 92 | ], 93 | output_features=modality_processing.OUTPUT_FEATURES, 94 | ) 95 | 96 | 97 | TaskRegistry.add( 98 | "image_inpainting_coco", 99 | source=seqio.TfdsDataSource( 100 | tfds_name="coco_all:1.0.1", 101 | tfds_data_dir=config.MULTITASK_TFDS_DATA_DIR, 102 | ), 103 | preprocessors=[ 104 | functools.partial( 105 | rekey, key_map={ 106 | "image": ["image"], 107 | "bbox": ["objects", "bbox"], 108 | "label": ["objects", "label"], 109 | }), 110 | functools.partial( 111 | preprocessing.image_inpainting_preprocessor, 112 | dataset_name="image_inpainting_coco", 113 | class_names='metadata/coco/coco_class_name_2017.json', 114 | ), 115 | unified_io_preprocessor 116 | ], 117 | output_features=modality_processing.OUTPUT_FEATURES, 118 | ) 119 | 120 | 121 | TaskRegistry.add( 122 | "vqa_coco_2017", 123 | # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function. 124 | source=seqio.TfdsDataSource( 125 | tfds_name="coco_all:1.0.1", 126 | tfds_data_dir=config.MULTITASK_TFDS_DATA_DIR, 127 | ), 128 | preprocessors=[ 129 | functools.partial( 130 | rekey, key_map={ 131 | "image": ["image"], 132 | "text_inputs": ["vqa", "questions"], 133 | "text_targets": ["vqa", "answers"], 134 | }), 135 | preprocessing.vqa_preprocessor, 136 | unified_io_preprocessor 137 | ], 138 | postprocess_fn=functools.partial(return_field, field="meta/all_references"), 139 | metric_fns=[metrics.vqa_metric], 140 | output_features=modality_processing.OUTPUT_FEATURES, 141 | ) 142 | 143 | 144 | TaskRegistry.add( 145 | "box_classification_coco_2017", 146 | # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function. 
147 | source=seqio.TfdsDataSource( 148 | tfds_name="coco_all:1.0.1", 149 | tfds_data_dir=config.MULTITASK_TFDS_DATA_DIR, 150 | ), 151 | preprocessors=[ 152 | functools.partial( 153 | rekey, key_map={ 154 | "image": ["image"], 155 | "bbox": ["objects", "bbox"], 156 | "label": ["objects", "label"], 157 | "image_id": ["image/filename"], 158 | }), 159 | functools.partial( 160 | preprocessing.box_classification_preprocessor, 161 | dataset_name='box_classification_coco_2017', 162 | class_names='metadata/coco/coco_class_name_2017.json', 163 | ), 164 | unified_io_preprocessor 165 | ], 166 | postprocess_fn=functools.partial(return_field, field="meta/label"), 167 | metric_fns=[exact_match], 168 | output_features=modality_processing.OUTPUT_FEATURES, 169 | ) -------------------------------------------------------------------------------- /t5x/examples/unified_io/evaluator.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from dataclasses import dataclass 3 | from typing import Optional, Mapping, Sequence, Any, List, Dict 4 | 5 | import gin 6 | import seqio 7 | from absl import logging 8 | from seqio import metrics as metrics_lib 9 | 10 | from t5x.examples.unified_io.data.data_utils import get_default_vocabulary 11 | 12 | AllOutputTokensType = Mapping[str, Sequence[Sequence[int]]] 13 | AllOutputScoresType = Mapping[str, Sequence[float]] 14 | AllOutputAuxValuesType = Mapping[str, Mapping[str, Sequence[Any]]] 15 | AllMetricsType = Mapping[str, Mapping[str, Any]] 16 | 17 | 18 | @dataclass 19 | class UnifiedIOOutput: 20 | """Wrapper to make it easier to work with the many different outputs of UIO2""" 21 | aux_values: Dict 22 | 23 | @property 24 | def text(self): 25 | return self.aux_values.get("text") 26 | 27 | @property 28 | def text_tokens(self): 29 | return self.aux_values.get("text-tokens") 30 | 31 | @property 32 | def image_tokens(self): 33 | return self.aux_values.get("image-tokens") 34 | 35 | @property 36 | def image(self): 37 | return self.aux_values.get("image") 38 | 39 | @property 40 | def audio(self): 41 | return self.aux_values.get("audio") 42 | 43 | @property 44 | def scores(self): 45 | if "scores" in self.aux_values: 46 | return self.aux_values["scores"] 47 | else: 48 | # Assume only one modality is present 49 | for x in ["text", "image", "audio"]: 50 | x = f"{x}-scores" 51 | if x in self.aux_values: 52 | return self.aux_values[x] 53 | raise ValueError(f"No scores found in self.aux_values, keys={self.aux_values.keys()}") 54 | 55 | 56 | def build_uio_outputs(aux_values, vocab) -> List[UnifiedIOOutput]: 57 | out = [] 58 | n = len(next(iter(aux_values.values()))) 59 | for ix in range(n): 60 | values = {k: v[ix] for k, v in aux_values.items()} 61 | txt_tokens = values.get("text-tokens") 62 | if txt_tokens is None: 63 | pass 64 | elif len(txt_tokens.shape) == 1: 65 | values["text"] = vocab.decode(txt_tokens) 66 | else: 67 | values["text"] = [vocab.decode(x) for x in txt_tokens] 68 | out.append(UnifiedIOOutput(values)) 69 | return out 70 | 71 | 72 | @gin.configurable() 73 | class UnifiedIOEvaluator(seqio.Evaluator): 74 | """Evaluator for UnifiedIO 2""" 75 | # This class basically follows `seqio.Evaluator` but has a few UIO2 hacks 76 | 77 | def __init__( 78 | self, 79 | mixture_or_task_name: str, 80 | feature_converter, 81 | eval_split: str = "validation", 82 | use_cached: bool = False, 83 | seed: Optional[int] = 42, 84 | sequence_length: Optional[Mapping[str, int]] = None, 85 | num_examples: Optional[int] = None, 86 | shuffle: bool =
False, 87 | logger_cls: Sequence = (), 88 | log_dir: Optional[str] = None, 89 | use_memory_cache: bool = True, 90 | target_field_name: str = "targets", 91 | ): 92 | # We use a simplified `sequence_length` that does not contain fields that exactly match 93 | # the Dataset fields. This can cause an issue because the evaluator will delete those 94 | # non-matching entries before the feature conversion stage, so we alias them to these 95 | # names that do match the dataset structure here so they will be preserved. 96 | if "text_inputs" in sequence_length: 97 | sequence_length["inputs/text/tokens"] = sequence_length["text_inputs"] 98 | if "text_targets" in sequence_length: 99 | sequence_length["targets/text/tokens"] = sequence_length["text_targets"] 100 | super().__init__( 101 | mixture_or_task_name, feature_converter, eval_split, use_cached, seed, sequence_length, 102 | num_examples, shuffle, logger_cls, log_dir, use_memory_cache, target_field_name 103 | ) 104 | 105 | def _compute_metrics(self, 106 | predicted_tokens: AllOutputTokensType, 107 | scores: AllOutputScoresType, 108 | all_aux_values: AllOutputAuxValuesType, 109 | step: Optional[int] = None) -> AllMetricsType: 110 | 111 | vocab = get_default_vocabulary() 112 | all_metrics = {} 113 | for task in self.eval_tasks: 114 | logging.info("Computing metrics for %s", task.name) 115 | task_dataset = self.cached_task_datasets[task.name] 116 | targets = self.cached_targets[task.name] 117 | task_metrics = [] 118 | inferences = {} 119 | 120 | if task.predict_metric_fns or task.predict_with_aux_metric_fns: 121 | (outputs, 122 | postprocessed_outputs) = self._decode_and_postprocess_predictions( 123 | task, predicted_tokens, task_dataset, targets) 124 | inferences["output"] = outputs 125 | inferences["prediction"] = postprocessed_outputs 126 | 127 | if task.predict_metric_fns: 128 | task_metrics.extend([ 129 | metric_fn(targets, inferences["prediction"]) 130 | for metric_fn in task.predict_metric_fns 131 | ]) 132 | 133 | if task.predict_with_aux_metric_fns: 134 | aux_values = all_aux_values[task.name] 135 | uio_output = build_uio_outputs(aux_values, vocab) 136 | task_metrics.extend([ 137 | metric_fn(targets, uio_output, aux_values) 138 | for metric_fn in task.predict_with_aux_metric_fns 139 | ]) 140 | inferences["aux_value"] = aux_values 141 | 142 | if task.score_metric_fns: 143 | task_scores = scores[task.name] 144 | if len(targets) != len(task_scores): 145 | raise ValueError(f"len(targets)({len(targets)}) != " 146 | f"len(task_scores)({len(task_scores)})") 147 | task_metrics.extend([ 148 | metric_fn(targets, task_scores) 149 | for metric_fn in task.score_metric_fns 150 | ]) 151 | inferences["score"] = task_scores 152 | 153 | all_metrics[task.name] = {} 154 | for k, v in itertools.chain(*[m.items() for m in task_metrics]): 155 | if k in all_metrics[task.name]: 156 | raise ValueError(f"Duplicate metric key '{k}' in Task '{task.name}'.") 157 | all_metrics[task.name][k] = v 158 | 159 | metrics = { 160 | k: metrics_lib.Scalar(v) 161 | if not isinstance(v, metrics_lib.MetricValue) else v 162 | for k, v in all_metrics[task.name].items() 163 | } 164 | for logger in self.loggers: 165 | logger(task_name=task.name, step=step, metrics=metrics, 166 | dataset=task_dataset, inferences=inferences, targets=targets) 167 | 168 | return all_metrics 169 | -------------------------------------------------------------------------------- /t5x/examples/unified_io/metrics/grit_keypoint.py: -------------------------------------------------------------------------------- 1 | # 
The object keypoint score (OKS) is computed and averaged within each image 2 | import numpy as np 3 | from scipy.optimize import linear_sum_assignment 4 | 5 | N_KEYPOINTS = 17 6 | N_DIM = 3 7 | 8 | 9 | def get_bbox_from_kp(kp): 10 | k_array3d = np.reshape(np.array(kp),(N_KEYPOINTS,N_DIM)) 11 | kp = k_array3d[:,:2] 12 | k_vis = k_array3d[:,2] 13 | kp_only_labeled = kp[k_vis > 0] 14 | if len(kp_only_labeled) == 0: 15 | raise ValueError("All points are marked as not visible!") 16 | x_min = kp_only_labeled[:,0].min() 17 | y_min = kp_only_labeled[:,1].min() 18 | x_max = kp_only_labeled[:,0].max() 19 | y_max = kp_only_labeled[:,1].max() 20 | bbox = np.array([x_min, y_min, x_max, y_max]) 21 | return bbox 22 | 23 | 24 | def computeOks(dts, gts): 25 | """ 26 | analogous to computing IoUs for localization / segmentation 27 | """ 28 | ious = np.zeros((len(dts), len(gts))) 29 | sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,.62, 1.07, 1.07, .87, .87, .89, .89])/10.0 30 | variances = (sigmas * 2)**2 31 | 32 | # compute oks between each detection and ground truth object 33 | for j, gt in enumerate(gts): 34 | g = np.array(gt) 35 | xg = g[0::3]; yg = g[1::3]; vg = g[2::3] 36 | x_min, y_min, x_max, y_max = get_bbox_from_kp(gt) 37 | area = (y_max-y_min)*(x_max-x_min) 38 | for i, dt in enumerate(dts): 39 | d = np.array(dt) 40 | xd = d[0::3]; yd = d[1::3] 41 | dx = xd - xg 42 | dy = yd - yg 43 | e = (dx**2 + dy**2) / variances / (area+np.spacing(1)) / 2 44 | e = e[vg > 0] 45 | ious[i, j] = np.sum(np.exp(-e)) / e.shape[0] 46 | return ious 47 | 48 | 49 | def assign_instances(pred_points: list, gt_points: list): 50 | ious = computeOks(pred_points, gt_points) 51 | cost = -ious 52 | 53 | # solve assignment 54 | pred_ids, gt_ids = linear_sum_assignment(cost) 55 | pair_ids = list(zip(pred_ids, gt_ids)) 56 | 57 | # select assignments with iou > 0 58 | pair_ids = [(i,j) for i,j in pair_ids if ious[i,j] > 0] 59 | pairs = [(pred_points[i],gt_points[j]) for i,j in pair_ids] 60 | pair_ious = [ious[i,j] for i,j in pair_ids] 61 | 62 | return pairs, pair_ious, pair_ids 63 | 64 | 65 | def kp_metric(pred_points: list, gt_points: list, return_pairs=False) -> float: 66 | num_pred = len(pred_points) 67 | num_gt = len(gt_points) 68 | if num_pred == 0 and num_gt == 0: 69 | return 1 if not return_pairs else (1, [], []) 70 | elif min(num_pred,num_gt) == 0 and max(num_pred,num_gt) > 0: 71 | return 0 if not return_pairs else (0, [], []) 72 | 73 | pairs, pair_ious, pair_ids = assign_instances(pred_points, gt_points) 74 | num_detected = len(pairs) 75 | num_missed = num_gt - num_detected 76 | score = np.sum(pair_ious) / (num_pred + num_missed) 77 | 78 | return score if not return_pairs else (score, pairs, pair_ids) 79 | 80 | 81 | def kp_metric_wrapper(prediction, ground_truth, return_pairs=False): 82 | if 'output' in prediction: 83 | prediction = prediction['output'] 84 | assert ground_truth['output']['example_id'] == prediction['example_id'] 85 | pred_points = prediction['points'] 86 | gt_points = ground_truth['output']['points'] 87 | return kp_metric(pred_points, gt_points, return_pairs) -------------------------------------------------------------------------------- /t5x/examples/unified_io/metrics/grit_localization.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | from scipy.optimize import linear_sum_assignment 5 | 6 | 7 | def compute_iou(bbox1: list, bbox2: list, verbose: bool=False): 8 | x1, y1, x2, y2 = bbox1 9 | x1_, 
y1_, x2_, y2_ = bbox2 10 | 11 | x1_in = max(x1, x1_) 12 | y1_in = max(y1, y1_) 13 | x2_in = min(x2, x2_) 14 | y2_in = min(y2, y2_) 15 | 16 | intersection = compute_area(bbox=[x1_in, y1_in, x2_in, y2_in], invalid=0.0) 17 | area1 = compute_area(bbox1, invalid=0) 18 | area2 = compute_area(bbox2, invalid=0) 19 | union = area1 + area2 - intersection 20 | iou = intersection / (union + 1e-6) 21 | 22 | if verbose: 23 | return iou, intersection, union 24 | 25 | return iou 26 | 27 | 28 | def compute_area(bbox: list, invalid: float=None) -> float: 29 | x1, y1, x2, y2 = bbox 30 | 31 | if (x2 <= x1) or (y2 <= y1): 32 | area = invalid 33 | else: 34 | area = (x2 - x1) * (y2 - y1) 35 | 36 | return area 37 | 38 | 39 | def assign_boxes(pred_boxes: List[List], gt_boxes: List[List]): 40 | n1 = len(pred_boxes) 41 | n2 = len(gt_boxes) 42 | cost = np.zeros([n1,n2]) 43 | ious = np.zeros([n1,n2]) 44 | for i,bbox1 in enumerate(pred_boxes): 45 | for j,bbox2 in enumerate(gt_boxes): 46 | iou = compute_iou(bbox1,bbox2) 47 | ious[i,j] = iou 48 | cost[i,j] = 1-iou 49 | 50 | # solve assignment 51 | pred_box_ids, gt_box_ids = linear_sum_assignment(cost) 52 | pair_ids = list(zip(pred_box_ids, gt_box_ids)) 53 | 54 | # select assignments with iou > 0 55 | pair_ids = [(i,j) for i,j in pair_ids if ious[i,j] > 0] 56 | pairs = [(pred_boxes[i],gt_boxes[j]) for i,j in pair_ids] 57 | pair_ious = [ious[i,j] for i,j in pair_ids] 58 | 59 | return pairs, pair_ious, pair_ids 60 | 61 | 62 | def loc_metric(pred_boxes: List[List], gt_boxes: List[List]) -> float: 63 | num_pred = len(pred_boxes) 64 | num_gt = len(gt_boxes) 65 | if num_pred == 0 and num_gt == 0: 66 | return 1 67 | elif min(num_pred,num_gt) == 0 and max(num_pred,num_gt) > 0: 68 | return 0 69 | 70 | pairs, pair_ious, pair_ids = assign_boxes(pred_boxes,gt_boxes) 71 | num_detected = len(pairs) 72 | num_missed = num_gt - num_detected 73 | return np.sum(pair_ious) / (num_pred + num_missed) -------------------------------------------------------------------------------- /t5x/examples/unified_io/metrics/grit_segmentation.py: -------------------------------------------------------------------------------- 1 | from scipy.optimize import linear_sum_assignment 2 | import numpy as np 3 | from scipy import ndimage 4 | 5 | 6 | # https://github.com/bowenc0221/boundary-iou-api/blob/master/boundary_iou/utils/boundary_utils.py 7 | # General util function to get the boundary of a binary mask. 8 | def mask_to_boundary(mask, dilation_ratio=0.02): 9 | """ 10 | Convert binary mask to boundary mask. 11 | :param mask (numpy array, uint8): binary mask 12 | :param dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal 13 | :return: boundary mask (numpy array) 14 | """ 15 | h, w = mask.shape 16 | img_diag = np.sqrt(h ** 2 + w ** 2) 17 | dilation = int(round(dilation_ratio * img_diag)) 18 | if dilation < 1: 19 | dilation = 1 20 | # Pad image so mask truncated by the image border is also considered as boundary. 21 | new_mask = np.pad(mask, [[1, 1], [1, 1]], constant_values=0) 22 | kernel = np.ones((3, 3), dtype=np.uint8) 23 | new_mask_erode = ndimage.binary_erosion(new_mask, kernel, iterations=dilation) 24 | mask_erode = new_mask_erode[1 : h + 1, 1 : w + 1] 25 | # G_d intersects G in the paper. 
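# (illustrative) e.g. for a 480x640 mask the diagonal is sqrt(480**2 + 640**2) = 800, so dilation = round(0.02 * 800) = 16: the boundary returned below is the band of pixels stripped away by those 16 erosion steps (the XOR of the mask with its eroded self).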
26 | return np.logical_xor(mask, mask_erode) 27 | 28 | 29 | def compute_iou(mask1, mask2, verbose=False): 30 | # predicted and gt masks must already be the same size (resizing is not implemented) 31 | if mask1.shape != mask2.shape: 32 | raise NotImplementedError() 33 | 34 | intersection = np.sum(np.logical_and(mask1, mask2)) 35 | union = np.sum(np.logical_or(mask1, mask2)) 36 | 37 | iou = intersection / (union + 1e-6) 38 | 39 | if verbose: 40 | return iou, intersection, union 41 | 42 | return iou 43 | 44 | 45 | def assign_segmentations(pred_masks: list, gt_masks: list): 46 | n1 = len(pred_masks) 47 | n2 = len(gt_masks) 48 | cost = np.zeros([n1,n2]) 49 | ious = np.zeros([n1,n2]) 50 | for i,mask1 in enumerate(pred_masks): 51 | for j,mask2 in enumerate(gt_masks): 52 | iou = compute_iou(mask1,mask2) 53 | ious[i,j] = iou 54 | cost[i,j] = -iou 55 | 56 | # solve assignment 57 | pred_mask_ids, gt_mask_ids = linear_sum_assignment(cost) 58 | pair_ids = list(zip(pred_mask_ids, gt_mask_ids)) 59 | 60 | # select assignments with iou > 0 61 | pair_ids = [(i,j) for i,j in pair_ids if ious[i,j] > 0] 62 | pairs = [(pred_masks[i],gt_masks[j]) for i,j in pair_ids] 63 | pair_ious = [ious[i,j] for i,j in pair_ids] 64 | 65 | return pairs, pair_ious, pair_ids 66 | 67 | 68 | # expects numpy arrays, could change to RLEs and use the rle iou / merge function above 69 | def seg_metric(pred_masks: list, gt_masks: list, stuff: bool, return_pairs=False) -> float: 70 | """ 71 | pred_masks: list of numpy arrays representing binary masks 72 | gt_masks: list of numpy arrays representing binary masks 73 | stuff: boolean for evaluation type (False for "Thing") 74 | return_pairs: return assignment pairs between prediction and gt instances 75 | """ 76 | if stuff: # merge masks into a single mask 77 | pred_masks = [np.logical_or.reduce(pred_masks)] if pred_masks else [] 78 | gt_masks = [np.logical_or.reduce(gt_masks)] 79 | 80 | num_pred = len(pred_masks) 81 | num_gt = len(gt_masks) 82 | if num_pred == 0 and num_gt == 0: 83 | return 1 if not return_pairs else (1, [], []) 84 | elif min(num_pred,num_gt) == 0 and max(num_pred,num_gt) > 0: 85 | return 0 if not return_pairs else (0, [], []) 86 | 87 | # uses BoundaryIoU 88 | # TODO double check mask_to_boundary is officially used, it's not clear from the grit codebase 89 | pred_boundaries = [mask_to_boundary(m) for m in pred_masks] 90 | gt_boundaries = [mask_to_boundary(m) for m in gt_masks] 91 | 92 | pairs, pair_ious, pair_ids = assign_segmentations(pred_boundaries,gt_boundaries) 93 | num_detected = len(pairs) 94 | num_missed = num_gt - num_detected 95 | score = np.sum(pair_ious) / (num_pred + num_missed) 96 | 97 | return score if not return_pairs else (score, pairs, pair_ids) 98 | -------------------------------------------------------------------------------- /t5x/examples/unified_io/metrics/grit_vqa.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | contractions = { 4 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \ 5 | "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \ 6 | "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \ 7 | "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \ 8 | "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve":
"it'd've", "itll": "it'll", "let's": "let's", \ 9 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \ 10 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \ 11 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \ 12 | "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \ 13 | "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \ 14 | "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \ 15 | "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \ 16 | "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", \ 17 | "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \ 18 | "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \ 19 | "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \ 20 | "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \ 21 | "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \ 22 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \ 23 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \ 24 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \ 25 | "youll": "you'll", "youre": "you're", "youve": "you've"} 26 | 27 | manualMap = { 28 | 'none': '0', 29 | 'zero': '0', 30 | 'one': '1', 31 | 'two': '2', 32 | 'three': '3', 33 | 'four': '4', 34 | 'five': '5', 35 | 'six': '6', 36 | 'seven': '7', 37 | 'eight': '8', 38 | 'nine': '9', 39 | 'ten': '10'} 40 | 41 | articles = ['a','an','the'] 42 | 43 | punct = [ 44 | ';', r"/", '[', ']', '"', '{', '}', 45 | '(', ')', '=', '+', '\\', '_', '-', 46 | '>', '<', '@', '`', ',', '?', '!'] 47 | 48 | periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") 49 | commaStrip = re.compile("(\d)(\,)(\d)") 50 | 51 | 52 | def processPunctuation(inText): 53 | outText = inText 54 | for p in punct: 55 | if (p + ' ' in inText or ' ' + p in inText) or (re.search(commaStrip, inText) != None): 56 | outText = outText.replace(p, '') 57 | else: 58 | outText = outText.replace(p, ' ') 59 | outText = periodStrip.sub("",outText,re.UNICODE) 60 | return outText 61 | 62 | 63 | def processDigitArticle(inText): 64 | outText = [] 65 | tempText = inText.lower().split() 66 | for word in tempText: 67 | word = manualMap.setdefault(word, word) 68 | if word not in articles: 69 | outText.append(word) 70 | else: 71 | pass 72 | for wordId, word in enumerate(outText): 73 | if word in contractions: 74 | outText[wordId] = contractions[word] 75 | outText = ' '.join(outText) 76 | return outText 77 | 78 | 79 | def preprocess_answer(ans, cache={}): 80 | 
"""GRIT VQA pre-processing""" 81 | if ans in cache: 82 | return cache[ans] 83 | ans = ans.replace('\n', ' ') 84 | ans = ans.replace('\t',' ') 85 | ans = ans.lower().strip() 86 | preprocessed = processDigitArticle(processPunctuation(ans)) 87 | cache[ans] = preprocessed 88 | return preprocessed 89 | 90 | -------------------------------------------------------------------------------- /t5x/examples/unified_io/metrics/metrics.py: -------------------------------------------------------------------------------- 1 | """Metrics for UiO2 tasks""" 2 | from collections import Counter 3 | 4 | from absl import logging 5 | from typing import List, Sequence 6 | 7 | import numpy as np 8 | from seqio.metrics import Scalar, Text 9 | 10 | from t5x.examples.unified_io.metrics.utils import extract_coordinates_from_text, \ 11 | undo_box_preprocessing 12 | 13 | from t5x.examples.unified_io.evaluator import UnifiedIOOutput 14 | 15 | from t5x.examples.unified_io.metrics.grit_localization import compute_iou 16 | from t5x.examples.unified_io.metrics.grit_vqa import preprocess_answer as vqa_preprocessing 17 | from t5x.examples.unified_io import config 18 | 19 | 20 | def exact_match(targets, predictions: List[UnifiedIOOutput], 21 | aux_values, print_examples=True): 22 | if isinstance(targets[0], np.ndarray): 23 | # Multiple correct answers stored in a numpy object array 24 | matches = [pred.text.lower() in [x.decode("utf-8").lower() for x in target] for target, pred in zip(targets, predictions)] 25 | else: 26 | if isinstance(targets[0], bytes): 27 | targets = [x.decode("utf-8") for x in targets] 28 | matches = [pred.text.lower() == target.lower() for target, pred in zip(targets, predictions)] 29 | if print_examples: 30 | ixs = np.random.choice(len(targets), min(20, len(targets)), replace=False) 31 | examples = [f"pred={predictions[i].text} gt={targets[i]}" for i in ixs] 32 | for ex in examples: 33 | logging.info(ex) 34 | return { 35 | "score": np.mean(matches), 36 | } 37 | 38 | 39 | def vqa_score(target, pred): 40 | pred = vqa_preprocessing(pred) 41 | if isinstance(target, list): 42 | target = Counter(vqa_preprocessing(x) for x in target) 43 | return min(target[pred] / 3.0, 1) 44 | else: 45 | return float(vqa_preprocessing(pred) == vqa_preprocessing(target)) 46 | 47 | 48 | def vqa_metric(targets: Sequence, predictions: Sequence[UnifiedIOOutput], aux_values): 49 | if isinstance(targets[0], np.ndarray): 50 | targets = [[ans.decode("utf-8") for ans in answer_set] for answer_set in targets] 51 | else: 52 | targets = [answer.decode("utf-8") for answer in targets] 53 | score = np.mean([vqa_score(t, p.text) for t, p in zip(targets, predictions)]) 54 | n_targets = len(targets) 55 | ixs = np.random.choice(n_targets, min(n_targets, 20), replace=False) 56 | examples = [f"{predictions[i].text.lower()} (gt(s)={', '.join(targets[i])})" for i in ixs] 57 | return { 58 | "score": Scalar(score), 59 | "examples": Text(", ".join(examples)) 60 | } 61 | 62 | 63 | def ref_exp_metric(targets, predictions: List[UnifiedIOOutput], 64 | aux_values, original_scale=True): 65 | total_acc = 0 66 | total_iou = 0 67 | for target, pred in zip(targets, predictions): 68 | gt_boxes, image_info, src_boxes = target["boxes"], target["image_info"], target["src_boxes"] 69 | if len(gt_boxes) != 1: 70 | raise ValueError("Should always be one ground truth box") 71 | 72 | p_boxes, classes = extract_coordinates_from_text( 73 | pred.text, image_size=config.IMAGE_INPUT_SIZE, n_coordinates=4, use_label=False) 74 | if original_scale: 75 | p_boxes = 
undo_box_preprocessing(p_boxes, image_info) 76 | h, w = image_info[3:5] 77 | gt_boxes = src_boxes * np.array([h, w, h, w]).reshape(1, 4) 78 | if len(p_boxes) == 0: 79 | iou = 0 80 | else: 81 | iou = compute_iou(p_boxes[0], gt_boxes[0]) 82 | total_iou += iou 83 | total_acc += float((iou > 0.5)) 84 | 85 | n = len(predictions) 86 | return dict(acc=total_acc/n, iou=total_iou/n) 87 | 88 | 89 | def null_metric(targets, predictions, aux_values): 90 | return {} 91 | -------------------------------------------------------------------------------- /t5x/examples/unified_io/metrics/rle.py: -------------------------------------------------------------------------------- 1 | from pycocotools import mask 2 | 3 | import json 4 | import numpy as np 5 | from pycocotools import mask as maskUtils 6 | 7 | 8 | def uncompressed_encode(binary_mask): 9 | binary_mask = np.asfortranarray(binary_mask) 10 | uncompressed_rle = {'counts': [], 'size': list(binary_mask.shape)} 11 | counts = uncompressed_rle.get('counts') 12 | 13 | last_elem = 0 14 | running_length = 0 15 | 16 | for i, elem in enumerate(binary_mask.ravel(order='F')): 17 | if elem == last_elem: 18 | pass 19 | else: 20 | counts.append(running_length) 21 | running_length = 0 22 | last_elem = elem 23 | running_length += 1 24 | 25 | counts.append(running_length) 26 | 27 | return uncompressed_rle 28 | 29 | 30 | def compress(uncompressed_rle): 31 | compressed_rle = mask.frPyObjects(uncompressed_rle, uncompressed_rle.get('size')[0], uncompressed_rle.get('size')[1]) 32 | return compressed_rle 33 | 34 | 35 | def to_utf(rle): 36 | rle = rle.copy() 37 | rle['counts'] = rle['counts'].decode("utf-8", "backslashreplace") 38 | return rle 39 | 40 | 41 | def from_utf(rle): 42 | rle = rle.copy() 43 | rle['counts'] = rle['counts'].encode("utf-8") 44 | return rle 45 | 46 | 47 | def encode(binary_mask, utf=True): 48 | encoded = maskUtils.encode(np.asfortranarray(binary_mask)) 49 | return to_utf(encoded) if utf else encoded 50 | 51 | 52 | def decode(rle, utf=True): 53 | if type(rle) == list: 54 | return [decode(r, utf) for r in rle] 55 | else: 56 | rle = from_utf(rle) if utf else rle 57 | decoded = maskUtils.decode(rle) 58 | return decoded 59 | -------------------------------------------------------------------------------- /t5x/examples/unified_io/model_test.py: -------------------------------------------------------------------------------- 1 | import jax.random 2 | from absl.testing.absltest import TestCase 3 | from flax import traverse_util 4 | 5 | from t5x.examples.unified_io import test_utils, modality_processing 6 | from t5x.examples.unified_io.data.data_utils import get_default_vocabulary 7 | from t5x.examples.unified_io.modality_processing import unified_io_preprocessor, \ 8 | UnifiedIOFeatureConverter 9 | import numpy as np 10 | 11 | from t5x.examples.unified_io.models import EncoderDecoderModel 12 | from t5x.examples.unified_io.test_utils import DEBUG_CONFIG 13 | from t5x.examples.unified_io.utils import get_model 14 | import tensorflow as tf 15 | 16 | 17 | class ModelTest(TestCase): 18 | 19 | @classmethod 20 | def setUpClass(cls): 21 | cls.model = get_model("tiny", ["text"], ["text"]) 22 | cls.cfg = cls.model.module.config 23 | 24 | def test_with_choices(self): 25 | """Test predictions with `choices` is consistent with computing the loss""" 26 | seq_len = dict(text_inputs=32, text_targets=16) 27 | ds = tf.data.Dataset.from_tensors(dict( 28 | text_inputs="Which answer is best?", 29 | text_targets="", 30 | choices=["a fat cat", "a quick dog"] 31 | )) 32 | ds = 
unified_io_preprocessor(ds, modality_processing.OUTPUT_FEATURES, seq_len) 33 | ds = UnifiedIOFeatureConverter()(ds, seq_len) 34 | batch = next(ds.repeat(2).batch(2).as_numpy_iterator()) 35 | 36 | variables = self.model.get_initial_variables( 37 | jax.random.PRNGKey(5919), 38 | {k: v.shape for k, v in batch.items()}, 39 | {k: v.dtype for k, v in batch.items()} 40 | )["params"] 41 | 42 | _, aux = self.model.predict_batch_with_aux(variables, batch) 43 | 44 | # Take the highest-ranked choices and compute loss as a regular batch 45 | # We have to manually build EOS and auto-regressive inputs 46 | tokens = aux["text-tokens"] 47 | features = traverse_util.unflatten_dict(batch, sep="/") 48 | features["targets"] = dict(text=dict( 49 | inputs=np.pad(tokens[:, :-1], [[0, 0], [1, 0]]), 50 | mask=(tokens>0).astype(np.int32), 51 | targets=tokens 52 | )) 53 | batch = traverse_util.flatten_dict(features, sep="/") 54 | loss = self.model.loss_fn( 55 | variables, batch, None, z_loss=0.0, loss_normalizing_by_weight_sum=False, 56 | loss_normalizing_factor=1)[0] 57 | self.assertAlmostEqual(float(loss), float(aux["scores"].sum()), places=4) 58 | -------------------------------------------------------------------------------- /t5x/examples/unified_io/packing_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from absl.testing import absltest 4 | 5 | from t5x.examples.unified_io.packing import batch_with_constraints, pack_in_pairs 6 | 7 | 8 | def build_mask(lens): 9 | c1 = np.floor(lens * np.random.random(lens.shape)) 10 | c2 = lens - c1 11 | l1 = np.maximum(c1.max() + np.random.randint(0, 3), 1) 12 | out1 = np.arange(l1)[None, :] < c1[:, None] 13 | l2 = np.maximum(c2.max() + np.random.randint(0, 3), 1) 14 | out2 = np.arange(l2)[None, :] < c2[:, None] 15 | return out1, out2 16 | 17 | 18 | def build_ds(lens): 19 | input_lens = np.array([x[0] for x in lens]) 20 | target_lens = np.array([x[1] for x in lens]) 21 | i1, i2 = build_mask(input_lens) 22 | t1, t2 = build_mask(target_lens) 23 | data = { 24 | "inputs/image/mask": i1, 25 | "inputs/text/mask": i2, 26 | "targets/image/mask": t1, 27 | "targets/text/mask": t2, 28 | "ixs": np.expand_dims(1 + np.arange(len(input_lens)), 1) 29 | } 30 | return tf.data.Dataset.from_tensor_slices(data) 31 | 32 | 33 | def tolist(ex): 34 | return {k: v.tolist() for k, v in ex.items()} 35 | 36 | 37 | class TestConstraintBatching(absltest.TestCase): 38 | 39 | def test_tiny(self): 40 | ds = tf.data.Dataset.from_tensor_slices(dict( 41 | c1=tf.convert_to_tensor([9, 12, 3, 4]) 42 | )) 43 | batches = list(ex["c1"].tolist() for ex in batch_with_constraints( 44 | ds, 2, 5, [(lambda x: x["c1"], 20)]).as_numpy_iterator()) 45 | self.assertEqual([[9, 3], [12, 4]], batches) 46 | 47 | batches = list(ex["c1"].tolist() for ex in batch_with_constraints( 48 | ds, 2, 5, [(lambda x: x["c1"], 100)]).as_numpy_iterator()) 49 | self.assertEqual([[9, 12], [3, 4]], batches) 50 | 51 | def test_multiple(self): 52 | ds = tf.data.Dataset.from_tensor_slices(dict( 53 | c1=tf.convert_to_tensor([2, 2, 2, 2, 2, 14, 2, 2]), 54 | c2=tf.convert_to_tensor([11, 12, 13, 2, 2, 2, 2, 2]), 55 | example_ids=tf.range(8) 56 | )) 57 | 58 | def _fn1(x): 59 | return x["c1"] 60 | 61 | def _fn2(x): 62 | return x["c2"] 63 | 64 | batch_fns = [(_fn1, 15.1), (_fn2, 20)] 65 | 66 | batches = list(set(ex["example_ids"].tolist()) for ex in batch_with_constraints( 67 | ds, 3, 10, batch_fns).as_numpy_iterator()) 68 | expected = [{0, 3, 4}, {1, 6, 7}] 69 | self.assertEqual(expected, batches)
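# (worked example) with batch size 3 and caps c1 <= 15.1, c2 <= 20: examples 0, 3, 4 give c1 = 2+2+2 = 6 and c2 = 11+2+2 = 15, both within the caps; example 5 (c1=14) and example 2 (c2=13) would push either batch over a cap, so they are left out.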
70 | 71 | def test_random(self): 72 | seed = tf.convert_to_tensor([0, 83452]) 73 | ds = tf.data.Dataset.from_tensor_slices(dict( 74 | c1=tf.random.stateless_uniform((100,), seed, 1, 10, dtype=tf.int64), 75 | c2=tf.random.stateless_uniform((100,), seed+1, 1, 10, dtype=tf.int64) 76 | )) 77 | const = [(lambda x: x["c1"], 18), 78 | (lambda x: x["c2"], 18)] 79 | for ex in batch_with_constraints(ds, 3, 10, const).as_numpy_iterator(): 80 | self.assertLessEqual(ex["c1"].sum(), const[0][1]) 81 | self.assertLessEqual(ex["c2"].sum(), const[1][1]) 82 | 83 | 84 | def _pack_in_pairs(*args, **kwargs): 85 | return pack_in_pairs( 86 | *args, **kwargs, 87 | encoder_masks=[ 88 | ("inputs/image/mask", None), 89 | ("inputs/text/mask", None) 90 | ], 91 | decoder_masks=[ 92 | "targets/image/mask", 93 | "targets/text/mask" 94 | ] 95 | ) 96 | 97 | 98 | class TestPackingBatching(absltest.TestCase): 99 | 100 | def test_tiny_pack(self): 101 | ds = build_ds([(3, 2), (2, 1)]) 102 | ds = _pack_in_pairs(ds, 1, 5, 5, pool_size=2) 103 | exs = next(ds.as_numpy_iterator()) 104 | self.assertEqual(set(exs["ixs"].ravel().tolist()), {1, 2}) 105 | 106 | def test_tiny_pack2(self): 107 | ds = build_ds([(3, 2), (5, 2)]) 108 | ds = _pack_in_pairs(ds, 1, 6, 6, pool_size=1) 109 | exs = next(ds.as_numpy_iterator()) 110 | self.assertEqual(set(exs["ixs"].ravel().tolist()), {0, 2}) 111 | 112 | def test_pool2(self): 113 | ds = build_ds([ 114 | (3, 2), # add to pool 115 | (5, 5), # add to pool 116 | (2, 4), # add to pool, write (5, 5) as batch 117 | (2, 3), # pair with (3, 2) 118 | ]) 119 | ds = _pack_in_pairs(ds, 1, 5, 5, pool_size=2).as_numpy_iterator() 120 | ex1 = next(ds) 121 | self.assertEqual(set(ex1["ixs"].ravel().tolist()), {0, 2}) 122 | ex2 = next(ds) 123 | self.assertEqual(set(ex2["ixs"].ravel().tolist()), {1, 4}) 124 | 125 | 126 | if __name__ == "__main__": 127 | absltest.main() -------------------------------------------------------------------------------- /t5x/examples/unified_io/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/unified-io-2/502ac4d81239f82c891a9f412b000c3c8d4e2946/t5x/examples/unified_io/scripts/__init__.py -------------------------------------------------------------------------------- /t5x/examples/unified_io/t5_1_1/base.gin: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | 3 | # T5.1.1 Base model. 4 | from __gin__ import dynamic_registration 5 | 6 | import seqio 7 | from t5x import adafactor 8 | from t5x.examples.unified_io import models 9 | from t5x.examples.unified_io import network 10 | from t5x.examples.unified_io import config 11 | from t5x import optimizers 12 | from t5x import utils 13 | import optax 14 | from t5x import trainer 15 | 16 | # ------------------- Loss HParam ---------------------------------------------- 17 | Z_LOSS = 0.0001 18 | LABEL_SMOOTHING = 0.0 19 | TEXT_DECODER_LENGTH = None 20 | IMAGE_DECODER_LENGTH = None 21 | # NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) 22 | # the loss normalizing factor should be set to pretraining batch_size * 23 | # target_token_length.
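# Illustrative arithmetic with hypothetical numbers: a pretraining batch size of 2048 with 512 target tokens would give LOSS_NORMALIZING_FACTOR = 2048 * 512 = 1048576.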
24 | LOSS_NORMALIZING_FACTOR = None 25 | LOSS_NORMALIZING_BY_WEIGHT_SUM = True 26 | # Dropout should be specified in the "run" files 27 | DROPOUT_RATE = 0.0 28 | DROPOUT_BROADCAST_DIMS = (-2, ) 29 | DROPPATH_RATE = 0.0 30 | 31 | # Vocabulary (shared by encoder and decoder) 32 | VOCABULARY = @seqio.SentencePieceVocabulary() 33 | seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" 34 | 35 | # ------------------- Optimizer ------------------------------------------------ 36 | # `learning_rate` is set by `Trainer.learning_rate_fn`. 37 | # In this case, we choose to switch to the AdamW optimizer with gradient clip. 38 | OPTIMIZER = None 39 | 40 | # ------------------- Model ---------------------------------------------------- 41 | MODEL = @models.EncoderDecoderModel() 42 | models.EncoderDecoderModel: 43 | module = @network.Transformer() 44 | input_vocabulary = %VOCABULARY 45 | output_vocabulary = %VOCABULARY 46 | optimizer_def = %OPTIMIZER 47 | z_loss = %Z_LOSS 48 | label_smoothing = %LABEL_SMOOTHING 49 | loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR 50 | loss_normalizing_by_weight_sum = %LOSS_NORMALIZING_BY_WEIGHT_SUM 51 | 52 | # ------------------- Network specification ------------------------------------ 53 | network.Transformer.config = @config.T5Config() 54 | 55 | config.T5Config: 56 | vocab_size = 33280 # vocab size rounded to a multiple of 128 for TPU efficiency 57 | image_vocab_size = 16512 # vocab size rounded to a multiple of 128 for TPU efficiency 58 | image_patch_size = 16 59 | audio_vocab_size = 8320 # vocab size rounded to a multiple of 128 for TPU efficiency 60 | audio_patch_size = 16 61 | 62 | dtype = 'bfloat16' 63 | emb_dim = 768 64 | num_heads = 12 65 | num_encoder_layers = 12 66 | num_decoder_layers = 12 67 | head_dim = 64 68 | mlp_dim = 2048 69 | mlp_activations = ('silu', 'linear') 70 | dropout_rate = %DROPOUT_RATE 71 | dropout_broadcast_dims = %DROPOUT_BROADCAST_DIMS 72 | logits_via_embedding = True 73 | float32_attention_logits = True 74 | decoder_xattention_internval = 1 75 | 76 | image_tokenizer_type = 'vqgan' 77 | 78 | from t5x.examples.unified_io import modality_processing 79 | modality_processing.get_input_modalities: 80 | image_vit_cfg = @config.ImageVitFeatureConfig() 81 | audio_vit_cfg = @config.AudioVitFeatureConfig() 82 | image_history_cfg = @config.ImageResamplerConfig() 83 | audio_history_cfg = @config.AudioResamplerConfig() 84 | use_image_vit = True 85 | use_audio_vit = True 86 | use_image_history_vit = True 87 | use_audio_history_vit = True 88 | 89 | modality_processing.get_target_modalities: 90 | image_vae_config = @config.VAEConfig() 91 | audio_vae_config = @config.AudioViTVQGANConfig() 92 | 93 | config.VAEConfig: 94 | embed_dim = 256 95 | n_embed = 16384 96 | double_z = False 97 | z_channels = 4 98 | resolution = 256 99 | in_channels = 3 100 | out_ch = 3 101 | ch = 128 102 | ch_mult = (1,2,2,4) 103 | num_res_blocks = 2 104 | attn_resolutions = (32,) 105 | dropout = 0 106 | default_input_size = (256,256) 107 | patch_size = (8, 8) 108 | 109 | config.AudioViTVQGANConfig: 110 | vocab_size = 8192 111 | proj_dim = 32 112 | # Transformers 113 | encoder_hidden_size = 512 114 | encoder_num_layers = 8 115 | encoder_mlp_dim = 2048 116 | encoder_num_heads = 8 117 | encoder_head_dim = 64 118 | 119 | decoder_hidden_size = 512 120 | decoder_num_layers = 8 121 | decoder_mlp_dim = 2048 122 | decoder_num_heads = 8 123 | decoder_head_dim = 64 124 | 125 | dropout_rate = 0.0 126 | droppath_rate = 0.0
127 | attention_dropout_rate = 0.0 128 | use_bias = False 129 | act_fn = 'relu' 130 | # PE 131 | add_position_embedding = False 132 | # Misc. 133 | dtype = 'bfloat16' 134 | default_input_size = (128, 256) # we need to keep this to make it 135 | patch_size = (8, 8) 136 | 137 | output_channel = 1 138 | use_decoder = True 139 | 140 | config.ImageVitFeatureConfig: 141 | patch_size = 16 142 | pos_patch_size = 16 143 | emb_dim = 768 144 | num_heads = 12 145 | num_layers = 11 # -2 layer 146 | mlp_dim = 3072 147 | mlp_activations = ('gelu', ) 148 | dropout_rate = 0.0 149 | dropout_broadcast_dims = () 150 | default_input_size = (256, 256) 151 | num_pos = 197 152 | dtype = 'float32' 153 | 154 | config.AudioVitFeatureConfig: 155 | patch_size = 16 156 | emb_dim = 768 157 | num_heads = 12 158 | num_layers = 11 # -2 layer 159 | mlp_dim = 3072 160 | mlp_activations = ('gelu', ) 161 | dropout_rate = 0.0 162 | dropout_broadcast_dims = () 163 | default_input_size = (256, 128) 164 | transpose_input = True 165 | dtype = 'float32' 166 | 167 | config.ImageResamplerConfig: 168 | dtype = 'bfloat16' 169 | resampler_type = 'perceiver' 170 | max_frames = 8 171 | latents_size = 32 172 | emb_dim = 768 173 | num_heads = 12 174 | num_layers = 2 175 | xattention_index = (0, 1) 176 | head_dim = 64 177 | mlp_dim = 3072 178 | mlp_activations = ('gelu',) 179 | dropout_broadcast_dims = (-2,) 180 | droppath_rate = 0.0 181 | layer_drop = 0.0 182 | dropout_rate = 0.0 183 | 184 | config.AudioResamplerConfig: 185 | dtype = 'bfloat16' 186 | resampler_type = 'perceiver' 187 | max_frames = 8 188 | latents_size = 16 189 | emb_dim = 768 190 | num_heads = 12 191 | num_layers = 2 192 | xattention_index = (0, 1) 193 | head_dim = 64 194 | mlp_dim = 3072 195 | mlp_activations = ('gelu',) 196 | dropout_broadcast_dims = (-2,) 197 | droppath_rate = 0.0 198 | layer_drop = 0.0 199 | dropout_rate = 0.0 200 | -------------------------------------------------------------------------------- /t5x/examples/unified_io/t5_1_1/eval/vision_language.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | include 't5x/configs/runs/eval.gin' 3 | 4 | from t5x.examples.unified_io import aux_fns 5 | 6 | # IMPORTANT: This assumes pretty short prompts/output, change if needed 7 | TEXT_INPUTS = 128 8 | TEXT_DECODER_LENGTH = 32 9 | IMAGE_INPUT_SAMPLES = None 10 | 11 | DROPOUT_RATE = 0.0 # might be required by an imported model config 12 | 13 | TASK_FEATURE_LENGTHS = { 14 | "text_inputs": %TEXT_INPUTS, 15 | "text_targets": %TEXT_DECODER_LENGTH, 16 | "image_input_samples": %IMAGE_INPUT_SAMPLES, 17 | "is_training": False, 18 | } 19 | 20 | from t5x.examples.unified_io import modality_processing 21 | modality_processing.get_input_modalities.input_modality=["image", "text"] 22 | modality_processing.get_target_modalities.target_modality=["text"] 23 | 24 | from t5x.examples.unified_io import models 25 | models.EncoderDecoderModel.predict_batch_with_aux: 26 | length = %TEXT_DECODER_LENGTH 27 | modality = "text" 28 | 29 | # Import so any gin configurable method in these files will be registered with gin 30 | # and thus can be modified by command line 31 | from t5x.examples.unified_io import network 32 | from t5x.examples.unified_io.metrics import metrics 33 | from t5x.examples.unified_io import decoding 34 | 35 | # Import so the registration happens 36 | from t5x.examples.unified_io.data import tasks 37 | from t5x.examples.unified_io.data import mixtures 38 | from t5x.examples.unified_io import 
aux_fns 39 | -------------------------------------------------------------------------------- /t5x/examples/unified_io/t5_1_1/finetune/refexp.gin: -------------------------------------------------------------------------------- 1 | # Fine-tune on referring expression; required parameters: 2 | # INITIAL_CHECKPOINT_PATH 3 | # MODEL_DIR 4 | 5 | from __gin__ import dynamic_registration 6 | import __main__ as train_script 7 | 8 | # Register necessary SeqIO Tasks/Mixtures. 9 | 10 | from t5x.examples.unified_io.data import tasks 11 | from t5x.examples.unified_io import aux_fns 12 | from t5x.examples.unified_io import models 13 | from t5x import partitioning 14 | from t5x import trainer 15 | import seqio 16 | from t5x.examples.unified_io import packing 17 | from t5x import utils as t5x_utils 18 | from t5x.examples.unified_io import config 19 | 20 | 21 | include 't5x/configs/runs/multitask.gin' 22 | 23 | MIXTURE_OR_TASK_NAME = "refcoco_unc" 24 | MIXTURE_OR_TASK_NAME_EVAL = "refcoco_unc" 25 | 26 | 27 | TRAIN_STEPS = 3_100_000 # 100k fine-tuning steps after 3 million pre-training steps 28 | DROPOUT_RATE = 0.0 29 | BATCH_SIZE = 128 30 | EVAL_STEPS = 50 31 | 32 | 33 | train_script.train: 34 | eval_period = 2500 35 | stats_period = 500 36 | partitioner = @partitioning.PjitPartitioner() 37 | use_wandb = True 38 | concurrent_metrics = False 39 | infer_eval_dataset_cfg = @train_infer/t5x_utils.DatasetConfig() 40 | 41 | 42 | t5x_utils.SaveCheckpointConfig: 43 | period = 20000 44 | 45 | 46 | train_infer/t5x_utils.DatasetConfig: 47 | mixture_or_task_name = %MIXTURE_OR_TASK_NAME_EVAL 48 | task_feature_lengths = %TASK_FEATURE_LENGTHS_EVAL 49 | split = 'validation' 50 | batch_size = %BATCH_SIZE 51 | shuffle = False 52 | seed = 42 53 | use_cached = %USE_CACHED_TASKS 54 | pack = False 55 | 56 | 57 | partitioning.PjitPartitioner.num_partitions = 4 58 | 59 | from t5x.examples.unified_io import utils 60 | from t5x.examples.unified_io import modality_processing 61 | 62 | # Only load the needed modalities 63 | modality_processing.get_input_modalities.input_modality=["image", "text"] 64 | modality_processing.get_target_modalities.target_modality=["text"] 65 | 66 | 67 | TEXT_INPUT_LEN = 256 68 | TEXT_TARGET_LEN = 32 69 | IMAGE_SAMPLES = 1.0 70 | 71 | TASK_FEATURE_LENGTHS_TRAIN = { 72 | "text_inputs": %TEXT_INPUT_LEN, 73 | "text_targets": %TEXT_TARGET_LEN, 74 | "image_input_samples": %IMAGE_SAMPLES, 75 | "image_history_input_samples": 128, 76 | "audio_input_samples": 64, 77 | "audio_history_input_samples": 64, 78 | "num_frames": 4, 79 | "is_training": True, 80 | } 81 | 82 | 83 | TASK_FEATURE_LENGTHS_EVAL = { 84 | "text_inputs": %TEXT_INPUT_LEN, 85 | "text_targets": %TEXT_TARGET_LEN, 86 | "image_input_samples": None, 87 | "image_history_input_samples": 128, 88 | "audio_input_samples": 64, 89 | "audio_history_input_samples": 64, 90 | "num_frames": 4, 91 | "is_training": False, 92 | } 93 | 94 | models.EncoderDecoderModel.predict_batch_with_aux: 95 | length = %TEXT_TARGET_LEN 96 | modality = "text" 97 | 98 | 99 | t5x_utils.create_learning_rate_scheduler: 100 | factors = 'constant * linear_warmup * rsqrt_decay' 101 | # Generally for fine-tuning we halve the learning rate 102 | base_learning_rate = 0.5 103 | warmup_steps = 2000 # shorter than the 10k T5/MTF default.
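# (illustrative note) with these factors, t5x's scheduler roughly gives lr(step) = base_learning_rate * min(1, step / warmup_steps) / sqrt(max(step, warmup_steps)), so the peak here is 0.5 / sqrt(2000) ~= 0.011 at step 2000, decaying as 1/sqrt(step) afterwards.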
104 | 105 | 106 | from t5x import adafactor 107 | OPTIMIZER = @adafactor.Adafactor() 108 | adafactor.Adafactor: 109 | decay_rate = 0.8 110 | beta1 = 0.9 111 | step_offset = 0 112 | logical_factor_rules = @adafactor.standard_logical_factor_rules() 113 | global_norm_clip_threshold = 1.0 114 | skip_nan_updates = True 115 | 116 | 117 | import t5x.examples.unified_io.evaluator as uio_evaluator 118 | uio_evaluator.UnifiedIOEvaluator: 119 | num_examples = 10000 120 | logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @utils.WandbMetricsLogger] 121 | -------------------------------------------------------------------------------- /t5x/examples/unified_io/t5_1_1/large.gin: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | 3 | # T5.1.1 Large model. 4 | 5 | include 't5x/examples/unified_io/t5_1_1/base.gin' # imports vocab, optimizer and model. 6 | 7 | EMBED_DIM = 1024 8 | MLP_DIM = 2816 9 | NUM_HEADS = 16 10 | HEAD_DIM = 64 11 | 12 | from t5x.examples.unified_io import config 13 | network.Transformer.config = @config.T5Config() 14 | config.T5Config: 15 | emb_dim = %EMBED_DIM 16 | num_heads = %NUM_HEADS 17 | num_encoder_layers = 24 18 | num_decoder_layers = 24 19 | head_dim = %HEAD_DIM 20 | mlp_dim = %MLP_DIM 21 | decoder_xattention_internval = 1 22 | 23 | config.ImageResamplerConfig: 24 | dtype = 'bfloat16' 25 | resampler_type = 'perceiver' 26 | emb_dim = 768 27 | num_heads = 12 28 | head_dim = 64 29 | mlp_dim = 2048 30 | num_layers = 2 31 | xattention_index = (0, 1) 32 | dropout_broadcast_dims = (-2,) 33 | mlp_activations = ('gelu',) 34 | max_frames = 8 35 | latents_size = 32 36 | 37 | config.AudioResamplerConfig: 38 | resampler_type = 'perceiver' 39 | dtype = 'bfloat16' 40 | emb_dim = 768 41 | num_heads = 12 42 | head_dim = 64 43 | mlp_dim = 2048 44 | num_layers = 2 45 | xattention_index = (0, 1) 46 | dropout_broadcast_dims = (-2,) 47 | mlp_activations = ('gelu',) 48 | max_frames = 8 49 | latents_size = 16 -------------------------------------------------------------------------------- /t5x/examples/unified_io/t5_1_1/tiny.gin: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | 3 | # T5.1.1 Tiny model. 4 | 5 | include 't5x/examples/unified_io/t5_1_1/base.gin' # imports vocab, optimizer and model. 6 | 7 | EMBED_DIM = 16 8 | MLP_DIM = 32 9 | NUM_HEADS = 4 10 | HEAD_DIM = 16 11 | NUM_RESAMPLER_LAYER = 1 12 | 13 | from t5x.examples.unified_io import config 14 | network.Transformer.config = @config.T5Config() 15 | config.T5Config: 16 | emb_dim = %EMBED_DIM 17 | num_heads = %NUM_HEADS 18 | num_encoder_layers = 1 19 | num_decoder_layers = 1 20 | head_dim = %HEAD_DIM 21 | mlp_dim = %MLP_DIM 22 | 23 | config.ImageResamplerConfig: 24 | emb_dim = %EMBED_DIM 25 | num_heads = %NUM_HEADS 26 | num_layers = %NUM_RESAMPLER_LAYER 27 | head_dim = %HEAD_DIM 28 | mlp_dim = %MLP_DIM 29 | 30 | config.AudioResamplerConfig: 31 | emb_dim = %EMBED_DIM 32 | num_heads = %NUM_HEADS 33 | num_layers = %NUM_RESAMPLER_LAYER 34 | head_dim = %HEAD_DIM 35 | mlp_dim = %MLP_DIM 36 | 37 | -------------------------------------------------------------------------------- /t5x/examples/unified_io/t5_1_1/xl.gin: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | 3 | # T5.1.1 XL model. 4 | 5 | include 't5x/examples/unified_io/t5_1_1/base.gin' # imports vocab, optimizer and model.
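# (note) these size variants follow gin's include-then-override pattern: the bindings below re-bind parameters already set in base.gin, and the later binding wins.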
6 | 
7 | 
8 | IMAGE_LATENT_SIZE = 128
9 | AUDIO_LATENT_SIZE = 64
10 | 
11 | from t5x.examples.unified_io import config
12 | network.Transformer.config = @config.T5Config()
13 | 
14 | config.T5Config:
15 |   vocab_size = 33280  # vocab size rounded to a multiple of 128 for TPU efficiency
16 |   image_vocab_size = 16512  # vocab size rounded to a multiple of 128 for TPU efficiency
17 |   image_patch_size = 16
18 |   audio_vocab_size = 8320  # vocab size rounded to a multiple of 128 for TPU efficiency
19 |   audio_patch_size = 16
20 |   dtype = 'bfloat16'
21 |   emb_dim = 2048
22 |   num_heads = 16
23 |   num_encoder_layers = 24
24 |   num_decoder_layers = 24
25 |   head_dim = 128
26 |   mlp_dim = 5120
27 |   mlp_activations = ('silu', 'linear')
28 |   dropout_rate = %DROPOUT_RATE
29 |   logits_via_embedding = True
30 |   float32_attention_logits = True
31 |   decoder_xattention_internval = 1
32 | 
33 | config.ImageResamplerConfig:
34 |   dtype = 'bfloat16'
35 |   resampler_type = 'perceiver'
36 |   emb_dim = 1024
37 |   num_heads = 16
38 |   head_dim = 64
39 |   mlp_dim = 4096
40 |   num_layers = 2
41 |   xattention_index = (0, 1)
42 |   dropout_broadcast_dims = (-2,)
43 |   mlp_activations = ('gelu',)
44 |   max_frames = 8
45 | 
46 | config.AudioResamplerConfig:
47 |   resampler_type = 'perceiver'
48 |   dtype = 'bfloat16'
49 |   emb_dim = 1024
50 |   num_heads = 16
51 |   head_dim = 64
52 |   mlp_dim = 4096
53 |   num_layers = 2
54 |   xattention_index = (0, 1)
55 |   dropout_broadcast_dims = (-2,)
56 |   mlp_activations = ('gelu',)
57 |   max_frames = 8
--------------------------------------------------------------------------------
/t5x/examples/unified_io/t5_1_1/xxl.gin:
--------------------------------------------------------------------------------
1 | # type: ignore
2 | 
3 | # T5.1.1 XXL model.
4 | 
5 | include 't5x/examples/unified_io/t5_1_1/xl.gin'  # imports vocab, optimizer and model.
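# Everything below overrides settings inherited from xl.gin via the include
# above; later gin bindings win, so only the changed values need to be
# restated (e.g. emb_dim goes from 2048 in xl.gin to 3072 here).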
6 | 7 | EMBED_DIM = 1024 8 | MLP_DIM = 2560 9 | NUM_HEADS = 16 10 | HEAD_DIM = 64 11 | 12 | from t5x.examples.unified_io import config 13 | # ------------------- Network specification overrides -------------------------- 14 | network.Transformer.config = @config.T5Config() 15 | config.T5Config: 16 | vocab_size = 33280 # vocab size rounded to a multiple of 128 for TPU efficiency 17 | image_vocab_size = 16512 # vocab size rounded to a multiple of 128 for TPU efficiency 18 | image_patch_size = 16 19 | audio_vocab_size = 8320 # vocab size rounded to a multiple of 128 for TPU efficiency 20 | audio_patch_size = 16 21 | dtype = 'bfloat16' 22 | emb_dim = 3072 23 | num_heads = 24 24 | num_encoder_layers = 24 25 | num_decoder_layers = 24 26 | head_dim = 128 27 | mlp_dim = 8192 28 | mlp_activations = ('silu', 'linear') 29 | dropout_rate = %DROPOUT_RATE 30 | logits_via_embedding = True 31 | float32_attention_logits = True 32 | decoder_xattention_internval = 1 33 | 34 | config.ImageResamplerConfig: 35 | dtype = 'bfloat16' 36 | resampler_type = 'perceiver' 37 | emb_dim = 1024 38 | num_heads = 16 39 | head_dim = 64 40 | mlp_dim = 4096 41 | num_layers = 2 42 | xattention_index = (0, 1) 43 | dropout_broadcast_dims = (-2,) 44 | mlp_activations = ('gelu',) 45 | max_frames = 8 46 | xattn_qk_norm = False 47 | xattn_scaled_cosine = True 48 | attn_qk_norm = False 49 | attn_scaled_cosine = True 50 | 51 | config.AudioResamplerConfig: 52 | resampler_type = 'perceiver' 53 | dtype = 'bfloat16' 54 | emb_dim = 1024 55 | num_heads = 16 56 | head_dim = 64 57 | mlp_dim = 4096 58 | num_layers = 2 59 | xattention_index = (0, 1) 60 | dropout_broadcast_dims = (-2,) 61 | mlp_activations = ('gelu',) 62 | max_frames = 8 63 | xattn_qk_norm = False 64 | xattn_scaled_cosine = True 65 | attn_qk_norm = False 66 | attn_scaled_cosine = True 67 | -------------------------------------------------------------------------------- /t5x/examples/unified_io/test_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, List 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import flax.linen as nn 8 | from t5x.examples.unified_io import modality_processing 9 | from t5x.examples.unified_io.input_modalities import ModalityEncoder 10 | from t5x.examples.unified_io.network import Transformer 11 | from t5x.examples.unified_io.config import T5Config 12 | from t5x.examples.unified_io.seq_features import InputSequence, TargetSequence 13 | from t5x.examples.unified_io.target_modalities import BasicDecoder 14 | 15 | TARGET_MODALITIES = ["text", "image", "audio"] 16 | 17 | 18 | class NullEncoder(nn.Module): 19 | is_target: bool 20 | modality_id: int 21 | 22 | def __call__(self, embed, mask, init=False, targets=None, loss_mask=None, inputs=None, 23 | rel_atten=None, enable_dropout=None, use_constraints=True, segment_ids=None): 24 | if loss_mask is None: 25 | loss_mask = mask 26 | 27 | if self.is_target: 28 | return TargetSequence( 29 | embed, None, 30 | jnp.array(self.modality_id, dtype=jnp.int32), mask, rel_atten, 31 | subsegments=segment_ids, target_tokens=targets, loss_mask=loss_mask) 32 | else: 33 | return InputSequence(embed, mask, rel_atten) 34 | 35 | 36 | @dataclass 37 | class DebugModalityEncoder(ModalityEncoder): 38 | is_target: bool 39 | modality_id: int 40 | 41 | def get_encoder(self, *args, **kwargs) -> nn.Module: 42 | return NullEncoder(self.is_target, self.modality_id) 43 | 44 | def get_decoder(self, config, 
shared_embedding) -> nn.Module: 45 | return BasicDecoder(None, T5Config(None, logits_via_embedding=True), shared_embedding) 46 | 47 | 48 | class NullModalityEncoder(ModalityEncoder, nn.Module): 49 | def __init__(self, is_target, seq_len): 50 | super().__init__() 51 | self.seq_len = seq_len 52 | self.is_target = is_target 53 | 54 | def get_encoder(self, config, shared_embedding) -> nn.Module: 55 | self.t5_config = config.t5_config 56 | return self 57 | 58 | def __call__(self, null, init=False): 59 | bs = null.shape[0] 60 | if self.is_target: 61 | return TargetSequence.empty( 62 | bs, self.seq_len, self.t5_config.num_heads, self.t5_config.dtype, 2) 63 | else: 64 | return InputSequence.empty(bs, self.seq_len, self.t5_config) 65 | 66 | 67 | def build_random_batch(cfg, rng: np.random.RandomState, batch_size, 68 | input_modalities: List[int], 69 | target_modalities: List[int], 70 | target_segment_ids=False 71 | ): 72 | batch = build_inputs(cfg, rng, batch_size, input_modalities) 73 | batch.update(build_targets( 74 | cfg, rng, batch_size, target_modalities, target_segment_ids)) 75 | return batch 76 | 77 | 78 | def build_inputs(cfg, np_rng: np.random.RandomState, batch_size, input_modalities: List[int]): 79 | out = {} 80 | for ix, seq_len in enumerate(input_modalities): 81 | out[f"inputs/{ix}/mask"] = np_rng.random((batch_size, seq_len)) > 0.5 82 | out[f"inputs/{ix}/embed"] = np_rng.uniform(-1, 1, (batch_size, seq_len, cfg.emb_dim)) 83 | out[f"inputs/{ix}/rel_atten"] = np_rng.uniform(0, 0.5, (batch_size, cfg.num_heads, seq_len, seq_len)) 84 | 85 | return out 86 | 87 | 88 | def build_targets( 89 | cfg, np_rng: np.random.RandomState, batch_size, target_modalities, 90 | segment_ids=False): 91 | out = {} 92 | assert len(target_modalities) <= 3 93 | for name, seq_len in zip(TARGET_MODALITIES, target_modalities): 94 | if seq_len is None: 95 | continue 96 | out.update({ 97 | f"targets/{name}/mask": np_rng.random((batch_size, seq_len)) > 0.5, 98 | f"targets/{name}/embed": np_rng.uniform(-1, 1, (batch_size, seq_len, cfg.emb_dim)), 99 | f"targets/{name}/targets": np_rng.randint(0, cfg.vocab_size, (batch_size, seq_len), dtype=np.int32), 100 | }) 101 | if segment_ids and name == "text": 102 | out[f"targets/{name}/segment_ids"] = np_rng.randint(0, 2, (batch_size, seq_len), dtype=np.int32) 103 | return out 104 | 105 | 106 | DEBUG_CONFIG = T5Config( 107 | num_encoder_layers=2, 108 | num_decoder_layers=2, 109 | vocab_size=100, 110 | dropout_rate=0.0, 111 | emb_dim=8, 112 | num_heads=2, 113 | head_dim=4, 114 | mlp_dim=12, 115 | dtype=jnp.float32, 116 | mlp_activations=('gelu',), 117 | logits_via_embedding=True, 118 | image_vocab_size=1000, 119 | ) 120 | 121 | 122 | def build_test_transformer( 123 | cfg: T5Config, input_modalities: int, 124 | target_modalities: int, 125 | variable_seed=None, 126 | ): 127 | modality_processing.get_input_modalities = lambda: { 128 | str(ix): DebugModalityEncoder(False, ix) for ix in range(input_modalities) 129 | } 130 | 131 | # Use names `TARGET_MODALITIES` since those names are also hardcoded in some places 132 | assert target_modalities <= 3 133 | modality_processing.get_target_modalities = lambda: { 134 | name: DebugModalityEncoder(True, ix) for ix, name in enumerate(TARGET_MODALITIES[:target_modalities]) 135 | } 136 | 137 | trans = Transformer(cfg) 138 | if variable_seed is not None: 139 | rng = jax.random.PRNGKey(variable_seed) 140 | np_rng = np.random.RandomState(variable_seed*9681) 141 | batch = build_random_batch( 142 | trans.config, np_rng, 1, [1]*input_modalities, 
[1]*target_modalities)
143 |     variables = trans.init(rng, batch, init=True)
144 |     return trans, variables
145 |   else:
146 |     return trans
147 | 
--------------------------------------------------------------------------------
/t5x/examples/unified_io/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from typing import Optional, List, Dict, Mapping, Any, Sequence, Tuple
4 | 
5 | import gin
6 | import jax.numpy as jnp
7 | import jax.random
8 | import numpy as np
9 | import seqio
10 | import tensorflow as tf
11 | import wandb
12 | from absl import logging
13 | 
14 | from t5x import checkpoints
15 | from t5x import utils
16 | from t5x.examples.unified_io.config import SHUFFLE_BUFFER_SIZE, CYCLE_LENGTH, BLOCK_LENGTH
17 | from t5x.utils import TrainStateInitializer, RestoreCheckpointConfig, LegacyCheckpointManager
18 | from flax.core import freeze  # required by `get_parameters` below
19 | 
20 | def get_model(model_size, input_modalities=None, target_modalities=None,
21 |               gin_bindings: Optional[List[str]] = None, dtype="bfloat16"):
22 |   """Return an EncoderDecoderModel and configure the code to support that model.
23 | 
24 |   This will also configure the pre-processing functions to be consistent with the returned model.
25 |   """
26 |   if gin.config_is_locked():
27 |     logging.warning("Using `get_model` might override existing gin flags")
28 | 
29 |   with gin.config.unlock_config():
30 |     def _get(model):
31 |       return model
32 | 
33 |     bindings = []
34 |     if gin_bindings:
35 |       bindings += gin_bindings
36 |     if input_modalities:
37 |       bindings.append(f"get_input_modalities.input_modality={input_modalities}")
38 |     if target_modalities:
39 |       bindings.append(f"get_target_modalities.target_modality={target_modalities}")
40 |     if dtype != "bfloat16":
41 |       bindings += [f"{x}.dtype=\"float32\"" for x in [
42 |         "T5Config",
43 |         "AudioViTVQGANConfig",
44 |         "VAEConfig",
45 |         "ImageVitFeatureConfig",
46 |         "AudioVitFeatureConfig",
47 |         "ImageResamplerConfig",
48 |         "AudioResamplerConfig",
49 |       ]]
50 |     _get = gin.configurable(_get)
51 |     gin.parse_config_files_and_bindings(
52 |       config_files=[f"t5x/examples/unified_io/t5_1_1/{model_size}.gin"],
53 |       bindings=bindings + [
54 |         "_get.model=%MODEL"
55 |       ]
56 |     )
57 |     return _get()
58 | 
59 | 
60 | def get_parameters(model, model_checkpoint: Optional[str] = None,
61 |                    partitioner=None, rng: jax.random.PRNGKey = None) -> Tuple:
62 |   """Get parameters for a model.
63 |   Args:
64 |     model: Model to get parameters for
65 |     model_checkpoint: Checkpoint to load; if None, initialize the model from scratch
66 |     partitioner: If given, load parameters with this partitioner
67 |     rng: If initializing from scratch, initialize parameters with this RNG
68 |   """
69 |   from t5x.examples.unified_io.modality_processing import get_input_spec
70 |   t0 = time.perf_counter()
71 |   if rng is None:
72 |     seed = np.random.randint(np.iinfo(np.int32).min, np.iinfo(np.int32).max, (), np.int32)
73 |     rng = jax.random.PRNGKey(seed)
74 | 
75 |   if model_checkpoint is None:
76 |     logging.info("Init model from scratch")
77 |     input_shapes, input_types = get_input_spec()
78 |     if partitioner is None:
79 |       params = model.get_initial_variables(rng, input_shapes, input_types)
80 |       params, param_axes = params["params"], params["params_axes"]
81 |     else:
82 |       train_state_initializer = TrainStateInitializer(
83 |         optimizer_def=None,
84 |         init_fn=model.get_initial_variables,
85 |         input_shapes=input_shapes,
86 |         input_types=input_types,
87 |         partitioner=partitioner
88 |       )
89 |       train_state = train_state_initializer.from_scratch(rng)
90 |       param_axes = 
train_state_initializer.train_state_axes.params 91 | params = freeze(train_state.params) 92 | else: 93 | if not model_checkpoint.startswith("gs://"): 94 | model_checkpoint = os.path.abspath(os.path.expanduser(model_checkpoint)) 95 | assert os.path.exists(model_checkpoint), f"{model_checkpoint=} does not exist!" 96 | logging.info(f"Loading model weights from {model_checkpoint}...") 97 | input_shapes, input_types = get_input_spec(1) 98 | if partitioner is not None: 99 | train_state_initializer = TrainStateInitializer( 100 | optimizer_def=None, 101 | init_fn=model.get_initial_variables, 102 | input_shapes=input_shapes, 103 | input_types=input_types, 104 | partitioner=partitioner 105 | ) 106 | param_axes = train_state_initializer.train_state_axes.params 107 | params = LegacyCheckpointManager( 108 | restore_cfg=RestoreCheckpointConfig(model_checkpoint), 109 | train_state_shape=train_state_initializer.global_train_state_shape, 110 | partitioner=partitioner 111 | ).restore([model_checkpoint], RestoreCheckpointConfig(model_checkpoint)).params 112 | else: 113 | params = checkpoints.load_t5x_checkpoint(path=model_checkpoint)['target'] 114 | params = jax.tree_util.tree_map(jnp.array, params) 115 | param_axes = None 116 | logging.info(f"Done in {time.perf_counter()-t0:0.1f}") 117 | return params, param_axes 118 | 119 | 120 | @gin.configurable() 121 | def init_wandb(name=None, group=None, entity=None, project=None): 122 | utils.create_learning_rate_scheduler() # Makes sure this is registered in `operative_config` 123 | config_str = gin.operative_config_str() 124 | logging.info(f"Init wandb with group={group} name={name}") 125 | wandb.init( 126 | group=group, 127 | name=name, 128 | entity=entity, 129 | project=project, 130 | force=True, 131 | notes=config_str 132 | ) 133 | 134 | 135 | def transpose_lists(lsts): 136 | """Transpose a list of lists.""" 137 | return [list(i) for i in zip(*lsts)] 138 | 139 | 140 | def list_of_dict_to_string(table: List[Dict[str, str]], filler="") -> str: 141 | keys = dict() 142 | for row in table: 143 | keys.update(row) 144 | raw_table = [list(keys)] 145 | raw_table += [[row.get(key, filler) for key in keys] for row in table] 146 | return table_string(raw_table) 147 | 148 | 149 | def table_string(table: List[List[str]]) -> str: 150 | """Table as listoflists to evenly spaces string""" 151 | # print while padding each column to the max column length 152 | if len(table) == 0: 153 | return "" 154 | col_lens = [0] * len(table[0]) 155 | for row in table: 156 | for i, cell in enumerate(row): 157 | col_lens[i] = max(len(cell), col_lens[i]) 158 | 159 | formats = ["{0:<%d}" % x for x in col_lens] 160 | out = [] 161 | for row in table: 162 | out.append(" ".join(formats[i].format(row[i]) for i in range(len(row)))) 163 | return "\n".join(out) 164 | 165 | 166 | class WandbMetricsLogger(seqio.Logger): 167 | """Log metrics to wandb""" 168 | 169 | def __call__( 170 | self, 171 | task_name: str, 172 | step: Optional[int], 173 | metrics: Mapping[str, Any], 174 | dataset: Optional[tf.data.Dataset], 175 | inferences: Optional[Mapping[str, Sequence[Any]]], 176 | targets: Optional[Sequence[Any]], 177 | ) -> None: 178 | if step is None: 179 | raise ValueError() 180 | 181 | wandb_metrics = {} 182 | for metric_name, metric_value in metrics.items(): 183 | if isinstance(metric_value, seqio.metrics.Scalar): 184 | wandb_metrics[f"inference/{task_name}/{metric_name}"] = metric_value.value 185 | else: 186 | logging.warning( 187 | "Skipping WandbLogging of non-serializable metric '%s' of type %s.", 188 | 
metric_name,
189 |           type(metric_value),
190 |       )
191 |     wandb.log(wandb_metrics, step=step)
--------------------------------------------------------------------------------
/t5x/export.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | r"""Exports a T5X model.
16 | 
17 | 
18 | """
19 | import os
20 | from typing import Sequence
21 | from absl import logging
22 | 
23 | # Set Linen to add profiling information when constructing Modules.
24 | # Must be set before flax imports.
25 | # pylint:disable=g-import-not-at-top
26 | os.environ.setdefault('FLAX_PROFILE', 'true')
27 | 
28 | import jax
29 | from t5x import export_lib
30 | 
31 | if __name__ == '__main__':
32 |   # pylint:disable=g-import-not-at-top
33 |   from absl import app
34 |   from absl import flags
35 |   import gin
36 |   from t5x import gin_utils
37 |   # pylint:enable=g-import-not-at-top
38 | 
39 |   FLAGS = flags.FLAGS
40 | 
41 |   jax.config.parse_flags_with_absl()
42 | 
43 | 
44 |   flags.DEFINE_multi_string(
45 |       'gin_file',
46 |       default=None,
47 |       help='Path to gin configuration file. Multiple paths may be passed and '
48 |       'will be imported in the given order, with later configurations '
49 |       'overriding earlier ones.')
50 | 
51 |   flags.DEFINE_multi_string(
52 |       'gin_bindings',
53 |       default=[],
54 |       help='Individual gin bindings. Also used to integrate gin and XManager.')
55 | 
56 |   flags.DEFINE_list(
57 |       'gin_search_paths',
58 |       default=['t5x/configs'],
59 |       help='Comma-separated list of gin config path prefixes to be prepended '
60 |       'to suffixes given via `--gin_file`. If a file appears under multiple '
61 |       'prefixes, only the first prefix that produces a valid path for the '
62 |       'suffix will be used.')
63 | 
64 |   def main(argv: Sequence[str]):
65 |     """Wrapper for g3pdb post mortems."""
66 |     _main(argv)
67 | 
68 |   def _main(argv: Sequence[str]):
69 |     """True main function."""
70 |     if len(argv) > 1:
71 |       raise app.UsageError('Too many command-line arguments.')
72 | 
73 |     save_with_gin = gin.configurable(export_lib.save)
74 | 
75 |     gin_utils.parse_gin_flags(FLAGS.gin_search_paths, FLAGS.gin_file,
76 |                               FLAGS.gin_bindings)
77 |     logging.info('Creating inference function...')
78 |     save_with_gin()
79 | 
80 |   gin_utils.run(main)
--------------------------------------------------------------------------------
/t5x/gin_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Utilities for using gin configurations with T5X binaries."""
16 | import os
17 | from typing import Optional, Sequence, Union
18 | 
19 | from absl import app
20 | from absl import logging
21 | from clu import metric_writers
22 | import gin
23 | import jax
24 | import tensorflow as tf
25 | 
26 | 
27 | 
28 | def parse_gin_flags(gin_search_paths: Sequence[str],
29 |                     gin_files: Sequence[str],
30 |                     gin_bindings: Sequence[str],
31 |                     skip_unknown: Union[bool, Sequence[str]] = False,
32 |                     finalize_config: bool = True):
33 |   """Parses the provided gin files and bindings to override params.
34 | 
35 |   Args:
36 |     gin_search_paths: paths that will be searched for gin files.
37 |     gin_files: paths to gin config files to be parsed. Files will be parsed in
38 |       order, with conflicting settings being overridden by later files. Paths
39 |       may be relative to paths in `gin_search_paths`.
40 |     gin_bindings: individual gin bindings to be applied after the gin files are
41 |       parsed. Will be applied in order, with conflicting settings being
42 |       overridden by later ones.
43 |     skip_unknown: whether to ignore unknown bindings or raise an error (default
44 |       behavior). Alternatively, a list of configurable names to skip if unknown.
45 |     finalize_config: whether to finalize the config so that it cannot be
46 |       modified (default behavior).
47 |   """
48 |   # We import t5.data here since it includes gin configurable functions commonly
49 |   # used by task modules.
50 |   # TODO(adarob): Strip gin from t5.data and remove this import.
51 |   import t5.data  # pylint:disable=unused-import,g-import-not-at-top
52 |   # Register .gin file search paths with gin
53 |   for gin_file_path in gin_search_paths:
54 |     gin.add_config_file_search_path(gin_file_path)
55 | 
56 | 
57 |   # Parse config files and bindings passed via flag.
58 |   gin.parse_config_files_and_bindings(
59 |       gin_files,
60 |       gin_bindings,
61 |       skip_unknown=skip_unknown,
62 |       finalize_config=finalize_config)
63 |   logging.info('Gin Configuration:')
64 |   for line in gin.config_str().splitlines():
65 |     logging.info('%s', line)
66 | 
67 | 
68 | def rewrite_gin_args(args: Sequence[str]) -> Sequence[str]:
69 |   """Rewrite `--gin.NAME=VALUE` flags to `--gin_bindings=NAME=VALUE`."""
70 | 
71 |   def _rewrite_gin_arg(arg):
72 |     if not arg.startswith('--gin.'):
73 |       return arg
74 |     if '=' not in arg:
75 |       raise ValueError(
76 |           "Gin bindings must be of the form '--gin.<param>=<value>', got: " +
77 |           arg)
78 |     # Strip '--gin.'
79 |     arg = arg[6:]
80 |     name, value = arg.split('=', maxsplit=1)
81 |     r_arg = f'--gin_bindings={name} = {value}'
82 |     print(f'Rewritten gin arg: {r_arg}')
83 |     return r_arg
84 | 
85 |   return [_rewrite_gin_arg(arg) for arg in args]
86 | 
87 | 
88 | @gin.register
89 | def summarize_gin_config(model_dir: str,
90 |                          summary_writer: Optional[metric_writers.MetricWriter],
91 |                          step: int):
92 |   """Writes gin config to the model dir and TensorBoard summary."""
93 |   if jax.process_index() == 0:
94 |     config_str = gin.config_str()
95 |     tf.io.gfile.makedirs(model_dir)
96 |     # Write the config as a plain-text .gin file.
97 |     with tf.io.gfile.GFile(os.path.join(model_dir, 'config.gin'), 'w') as f:
98 |       f.write(config_str)
99 |     # Include a raw dump of the config as a text summary.
100 |     if summary_writer is not None:
101 |       summary_writer.write_texts(step, {'config': gin.markdown(config_str)})
102 |       summary_writer.flush()
103 | 
104 | 
105 | def run(main):
106 |   """Wrapper for app.run that rewrites gin args before parsing."""
107 |   app.run(
108 |       main,
109 |       flags_parser=lambda a: app.parse_flags_with_usage(rewrite_gin_args(a)))  # pytype: disable=wrong-arg-types
110 | 
111 | 
112 | # ====================== Configurable Utility Functions ======================
113 | 
114 | 
115 | @gin.configurable
116 | def sum_fn(var1=gin.REQUIRED, var2=gin.REQUIRED):
117 |   """sum function to use inside gin files."""
118 |   return var1 + var2
119 | 
120 | 
121 | @gin.configurable
122 | def bool_fn(var1=gin.REQUIRED):
123 |   """bool function to use inside gin files."""
124 |   return bool(var1)
125 | 
--------------------------------------------------------------------------------
/t5x/gin_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Tests for gin_utils."""
16 | 
17 | from absl.testing import absltest
18 | from t5x import gin_utils
19 | 
20 | 
21 | class GinUtilsTest(absltest.TestCase):
22 | 
23 |   def test_rewrite_gin_args(self):
24 |     test_args = [
25 |         '--gin_file=path/to/file',
26 |         'gin.value=3',
27 |         '--gin.value=3',
28 |         '--gin.value="3"',
29 |         '--gin.value=\'3\'',
30 |         '--gin.tricky="key = value"',
31 |         '--gin.dict={"foo": 4, "bar": "four"}',
32 |         '--gin.gin=bar',
33 |         '--gin.scope/foo=bar',
34 |     ]
35 |     expected_args = [
36 |         '--gin_file=path/to/file',
37 |         'gin.value=3',
38 |         '--gin_bindings=value = 3',
39 |         '--gin_bindings=value = "3"',
40 |         '--gin_bindings=value = \'3\'',
41 |         '--gin_bindings=tricky = "key = value"',
42 |         '--gin_bindings=dict = {"foo": 4, "bar": "four"}',
43 |         '--gin_bindings=gin = bar',
44 |         '--gin_bindings=scope/foo = bar',
45 |     ]
46 |     self.assertSequenceEqual(
47 |         gin_utils.rewrite_gin_args(test_args), expected_args)
48 | 
49 |   def test_rewrite_gin_args_malformed(self):
50 |     test_args = ['--gin.value=3', '--gin.test']
51 |     with self.assertRaisesWithLiteralMatch(
52 |         ValueError,
53 |         "Gin bindings must be of the form '--gin.<param>=<value>', got: "
54 |         '--gin.test'):
55 |       gin_utils.rewrite_gin_args(test_args)
56 | 
57 | 
58 | if __name__ == '__main__':
59 |   absltest.main()
60 | 
--------------------------------------------------------------------------------
/t5x/losses_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for t5x.losses.""" 16 | 17 | from absl.testing import absltest 18 | import jax 19 | import jax.numpy as jnp 20 | import numpy as np 21 | from t5x import losses 22 | 23 | 24 | class LossTest(absltest.TestCase): 25 | 26 | def test_xent(self): 27 | 28 | def lossfn(logits, targets, weights): 29 | loss, z_loss, weight_sum = losses.compute_weighted_cross_entropy( 30 | logits, 31 | targets, 32 | weights, 33 | label_smoothing=0.1, 34 | z_loss=0.1, 35 | loss_normalizing_factor=0.1) 36 | return loss, (z_loss, weight_sum) 37 | 38 | batch_size = 2 39 | length = 4 40 | vocab_size = 8 41 | logits = np.random.normal(size=(batch_size, length, 42 | vocab_size)).astype(np.float32) 43 | targets = np.random.randint(0, vocab_size, size=(batch_size, length)) 44 | weights = np.ones_like(targets) 45 | out = jax.jit(jax.value_and_grad(lossfn, has_aux=True))(logits, targets, 46 | weights) 47 | (loss, (z_loss, weight_sum)), dlogits = out 48 | # Just a smoke test for now 49 | # TODO(t5x): Expand test 50 | print(jax.device_get(((loss, (z_loss, weight_sum)), dlogits))) 51 | 52 | 53 | class SpecialLossNormalizingFactorTest(absltest.TestCase): 54 | 55 | def test_num_real_target_tokens(self): 56 | batch = { 57 | 'decoder_target_tokens': 58 | jnp.asarray([[1, 2, 3, 4, 0], [5, 6, 0, 0, 0]], jnp.int32) 59 | } 60 | 61 | (output_lnf, 62 | output_loss_weights) = losses.get_loss_normalizing_factor_and_weights( 63 | loss_normalizing_factor=losses.SpecialLossNormalizingFactor 64 | .NUM_REAL_TARGET_TOKENS, 65 | batch=batch) 66 | 67 | np.testing.assert_allclose(output_lnf, 6.0, rtol=1e-3) 68 | np.testing.assert_allclose( 69 | output_loss_weights, 70 | np.array([[1.0, 1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0, 0.0]], 71 | dtype=np.float32), 72 | rtol=1e-3) 73 | 74 | def test_num_total_target_tokens(self): 75 | batch = { 76 | 'decoder_target_tokens': 77 | jnp.asarray([[1, 2, 3, 4, 0], [5, 6, 0, 0, 0]], jnp.int32) 78 | } 79 | 80 | (output_lnf, 81 | output_loss_weights) = losses.get_loss_normalizing_factor_and_weights( 82 | loss_normalizing_factor=losses.SpecialLossNormalizingFactor 83 | .NUM_TOTAL_TARGET_TOKENS, 84 | batch=batch) 85 | 86 | np.testing.assert_allclose(output_lnf, 10.0, rtol=1e-3) 87 | np.testing.assert_allclose( 88 | output_loss_weights, 89 | np.array([[1.0, 1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0, 0.0]], 90 | dtype=np.float32), 91 | rtol=1e-3) 92 | 93 | def test_average_per_sequence(self): 94 | batch = { 95 | 'decoder_target_tokens': 96 | jnp.asarray([[1, 2, 3, 4, 0], [5, 6, 0, 0, 0]], jnp.int32) 97 | } 98 | 99 | (output_lnf, 100 | output_loss_weights) = losses.get_loss_normalizing_factor_and_weights( 101 | loss_normalizing_factor=losses.SpecialLossNormalizingFactor 102 | .AVERAGE_PER_SEQUENCE, 103 | batch=batch) 104 | 105 | np.testing.assert_allclose(output_lnf, 2.0, rtol=1e-3) 106 | np.testing.assert_allclose( 107 | output_loss_weights, 108 | jnp.asarray([[0.25, 0.25, 0.25, 0.25, 0.0], [0.5, 0.5, 0.0, 0.0, 0.0]], 109 | jnp.float32), 110 | rtol=1e-3) 111 | 112 | def test_average_per_sequence_with_weights(self): 113 | batch = { 114 | 'decoder_target_tokens': 115 
| jnp.asarray([[1, 2, 3, 4, 0], [5, 6, 0, 0, 0]], jnp.int32), 116 | 'decoder_loss_weights': 117 | jnp.asarray([[0.5, 1.0, 0.25, 2.0, 0.0], [1.0, 1.0, 0.0, 0.0, 0.0]], 118 | jnp.float32) 119 | } 120 | 121 | (output_lnf, 122 | output_loss_weights) = losses.get_loss_normalizing_factor_and_weights( 123 | loss_normalizing_factor=losses.SpecialLossNormalizingFactor 124 | .AVERAGE_PER_SEQUENCE, 125 | batch=batch) 126 | 127 | np.testing.assert_allclose(output_lnf, 2.0, rtol=1e-3) 128 | np.testing.assert_allclose( 129 | output_loss_weights, 130 | jnp.asarray([[0.5 / 3.75, 1.0 / 3.75, 0.25 / 3.75, 2.0 / 3.75, 0.0], 131 | [1.0 / 2.0, 1.0 / 2.0, 0.0, 0.0, 0.0]], jnp.float32), 132 | rtol=1e-3) 133 | 134 | def test_sum_weights_per_segment(self): 135 | weights = jnp.asarray( 136 | [[0.5, 1.0, 0.25, 2.0, 1.5], [1.0, 2.0, 3.0, 4.0, 5.0]], jnp.float32) 137 | positions = jnp.asarray([[0, 1, 2, 0, 0], [0, 0, 1, 0, 0]]) 138 | segment_ids = jnp.asarray([[1, 1, 1, 2, 3], [1, 2, 2, 3, 0]]) 139 | 140 | norm_vec = losses._sum_weights_per_segment(positions, segment_ids, weights) 141 | 142 | np.testing.assert_allclose( 143 | norm_vec, 144 | jnp.asarray([[1.75, 1.75, 1.75, 2.0, 1.5], [1.0, 5.0, 5.0, 4.0, 0.0]], 145 | jnp.float32), 146 | rtol=1e-3) 147 | 148 | def test_average_per_sequence_with_weights_with_packing(self): 149 | batch = { 150 | 'decoder_target_tokens': 151 | jnp.asarray([[1, 2, 3, 4, 5], [5, 6, 7, 8, 0]], jnp.int32), 152 | 'decoder_loss_weights': 153 | jnp.asarray([[0.5, 1.0, 0.25, 2.0, 1.5], [1.0, 2.0, 3.0, 4.0, 5.0]], 154 | jnp.float32), 155 | 'decoder_positions': 156 | jnp.asarray([[0, 1, 2, 0, 0], [0, 0, 1, 0, 0]]), 157 | 'decoder_segment_ids': 158 | jnp.asarray([[1, 1, 1, 2, 3], [1, 2, 2, 3, 0]]) 159 | } 160 | 161 | (output_lnf, 162 | output_loss_weights) = losses.get_loss_normalizing_factor_and_weights( 163 | loss_normalizing_factor=losses.SpecialLossNormalizingFactor 164 | .AVERAGE_PER_SEQUENCE, 165 | batch=batch) 166 | 167 | np.testing.assert_allclose(output_lnf, 6.0, rtol=1e-3) 168 | np.testing.assert_allclose( 169 | output_loss_weights, 170 | jnp.asarray( 171 | [[0.5 / 1.75, 1.0 / 1.75, 0.25 / 1.75, 2.0 / 2.0, 1.5 / 1.5], 172 | [1.0 / 1.0, 2.0 / 5.0, 3.0 / 5.0, 4.0 / 4.0, 0.0]], jnp.float32), 173 | rtol=1e-3) 174 | 175 | if __name__ == '__main__': 176 | absltest.main() 177 | -------------------------------------------------------------------------------- /t5x/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""The main entrance for running any of the T5X supported binaries. 16 | 17 | Currently this includes train/infer/eval/precompile. 
18 | 
19 | Example Local (CPU) Pretrain Gin usage:
20 | 
21 | python -m t5x.main \
22 |   --gin_file=t5x/examples/t5/t5_1_1/tiny.gin \
23 |   --gin_file=t5x/configs/runs/pretrain.gin \
24 |   --gin.MODEL_DIR=\"/tmp/t5x_pretrain\" \
25 |   --gin.TRAIN_STEPS=10 \
26 |   --gin.MIXTURE_OR_TASK_NAME=\"c4_v220_span_corruption\" \
27 |   --gin.MIXTURE_OR_TASK_MODULE=\"t5.data.mixtures\" \
28 |   --gin.TASK_FEATURE_LENGTHS="{'inputs': 128, 'targets': 30}" \
29 |   --gin.DROPOUT_RATE=0.1 \
30 |   --run_mode=train \
31 |   --logtostderr
32 | """
33 | import concurrent.futures  # pylint:disable=unused-import
34 | import enum
35 | import importlib
36 | import os
37 | import sys
38 | from typing import Optional, Sequence
39 | 
40 | from absl import app
41 | from absl import flags
42 | from absl import logging
43 | 
44 | import gin
45 | import jax
46 | import seqio
47 | 
48 | from t5x import gin_utils
49 | from t5x import utils
50 | 
51 | 
52 | @enum.unique
53 | class RunMode(enum.Enum):
54 |   """All the run modes supported in T5X."""
55 |   TRAIN = 'train'
56 |   EVAL = 'eval'
57 |   INFER = 'infer'
58 |   PRECOMPILE = 'precompile'
59 |   EXPORT = 'export'
60 | 
61 | 
62 | _GIN_FILE = flags.DEFINE_multi_string(
63 |     'gin_file',
64 |     default=None,
65 |     help='Path to gin configuration file. Multiple paths may be passed and '
66 |     'will be imported in the given order, with later configurations '
67 |     'overriding earlier ones.')
68 | 
69 | _GIN_BINDINGS = flags.DEFINE_multi_string(
70 |     'gin_bindings', default=[], help='Individual gin bindings.')
71 | 
72 | _GIN_SEARCH_PATHS = flags.DEFINE_list(
73 |     'gin_search_paths',
74 |     default=['.'],
75 |     help='Comma-separated list of gin config path prefixes to be prepended '
76 |     'to suffixes given via `--gin_file`. If a file appears under multiple '
77 |     'prefixes, only the first prefix that produces a valid path for the '
78 |     'suffix will be used.')
79 | 
80 | _RUN_MODE = flags.DEFINE_enum_class(
81 |     'run_mode',
82 |     default=None,
83 |     enum_class=RunMode,
84 |     help='The mode to run T5X under')
85 | 
86 | _TFDS_DATA_DIR = flags.DEFINE_string(
87 |     'tfds_data_dir', None,
88 |     'If set, this directory will be used to store datasets prepared by '
89 |     'TensorFlow Datasets that are not available in the public TFDS GCS '
90 |     'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
91 |     'all `Task`s.')
92 | 
93 | _DRY_RUN = flags.DEFINE_bool(
94 |     'dry_run', False,
95 |     'If set, does not start the function but still loads and logs the config.')
96 | 
97 | 
98 | FLAGS = flags.FLAGS
99 | 
100 | # Automatically search for gin files relative to the T5X package.
101 | _DEFAULT_GIN_SEARCH_PATHS = [
102 |     os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
103 | ]
104 | 
105 | # Mapping of run_mode to the attribute used in the imported module, e.g.
106 | # {EVAL : 'evaluate'} will load 'evaluate' in eval.py.
107 | _ATTR_BY_RUN_MODE = {
108 |     RunMode.TRAIN: 'train',
109 |     RunMode.EVAL: 'evaluate',
110 |     RunMode.INFER: 'infer',
111 |     RunMode.PRECOMPILE: 'precompile',
112 |     RunMode.EXPORT: 'save',
113 | }
114 | 
115 | 
116 | main_module = sys.modules[__name__]
117 | 
118 | 
119 | def main(argv: Sequence[str]):
120 |   if len(argv) > 1:
121 |     raise app.UsageError('Too many command-line arguments.')
122 | 
123 |   if _RUN_MODE.value is None:
124 |     raise ValueError("'run_mode' flag must be specified when using main.py.")
125 |   # Dynamically import the module matching run_mode, e.g.
126 |   # if _RUN_MODE.value is 'train', the code below is equivalent to doing:
127 |   #   from t5x import train
128 |   #   train = train.train
129 | 
130 |   # _RUN_MODE can never be None after this point.
131 |   # pytype: disable=attribute-error
132 |   lib_name = _RUN_MODE.value.name.lower()
133 |   import_attr = _ATTR_BY_RUN_MODE[_RUN_MODE.value]
134 |   # pytype: enable=attribute-error
135 | 
136 |   parent_module = 't5x'
137 | 
138 | 
139 |   module_to_import = f'{parent_module}.{lib_name}'
140 | 
141 |   logging.info('Dynamically importing : %s', module_to_import)
142 |   imported_lib = importlib.import_module(module_to_import)
143 | 
144 |   entry_func = getattr(imported_lib, import_attr)
145 |   setattr(main_module, import_attr, entry_func)
146 | 
147 | 
148 |   if _TFDS_DATA_DIR.value is not None:
149 |     seqio.set_tfds_data_dir_override(_TFDS_DATA_DIR.value)
150 | 
151 | 
152 |   # Register function explicitly under __main__ module, to maintain backward
153 |   # compatibility of existing '__main__' module references.
154 |   gin.register(entry_func, '__main__')
155 |   if _GIN_SEARCH_PATHS.value != ['.']:
156 |     logging.warning(
157 |         'Using absolute paths for the gin files is strongly recommended.')
158 | 
159 |   # User-provided gin paths take precedence if relative paths conflict.
160 |   gin_utils.parse_gin_flags(_GIN_SEARCH_PATHS.value + _DEFAULT_GIN_SEARCH_PATHS,
161 |                             _GIN_FILE.value, _GIN_BINDINGS.value)
162 | 
163 |   if _DRY_RUN.value:
164 |     return
165 | 
166 |   run_with_gin = gin.get_configurable(entry_func)
167 | 
168 |   run_with_gin()
169 | 
170 | 
171 | 
172 | def _flags_parser(args: Sequence[str]) -> Sequence[str]:
173 |   """Flag parser.
174 | 
175 |   See absl.app.parse_flags_with_usage and absl.app.main(..., flags_parser).
176 | 
177 |   Args:
178 |     args: All command line arguments.
179 | 
180 |   Returns:
181 |     [str], a non-empty list of remaining command line arguments after parsing
182 |     flags, including program name.
183 |   """
184 |   return app.parse_flags_with_usage(list(gin_utils.rewrite_gin_args(args)))
185 | 
186 | 
187 | if __name__ == '__main__':
188 |   jax.config.parse_flags_with_absl()
189 |   app.run(main, flags_parser=_flags_parser)
190 | 
--------------------------------------------------------------------------------
/t5x/metrics_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5X Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Tests for t5x.metrics."""
16 | 
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | import jax.numpy as jnp
21 | import numpy as np
22 | from t5x import metrics
23 | 
24 | 
25 | class MetricsTest(parameterized.TestCase):
26 | 
27 |   @parameterized.named_parameters(
28 |       ("0d_values", 2., 2.), ("1d_values", [1, 2, 3], 6.),
29 |       ("2d_values", [[1, 2], [2, 3], [3, 4]], 15.),
30 |       ("3d_values", [[[1, 2], [2, 3]], [[2, 1], [3, 4]], [[3, 1], [4, 1]]], 27.)
31 | ) 32 | def test_sum(self, values, expected_result): 33 | self.assertAlmostEqual( 34 | metrics.Sum.from_model_output(values).compute(), expected_result) 35 | 36 | def test_time_rate(self): 37 | value = np.array([3.]) 38 | duration = 2. 39 | metric = metrics.TimeRate.from_model_output(value).replace_duration( 40 | duration) 41 | self.assertAlmostEqual(metric.compute(), value / duration) 42 | 43 | def test_time_rate_unset_duration(self): 44 | value = jnp.array([3.]) 45 | metric = metrics.TimeRate.from_model_output(value) 46 | with self.assertRaises(ValueError): 47 | metric.compute() 48 | 49 | def test_time_rate_sets_duration_inside_jitted_fn(self): 50 | 51 | @jax.jit 52 | def fn(): 53 | value = jnp.array([3.]) 54 | duration = 2. 55 | metric = metrics.TimeRate.from_model_output(value).replace_duration( 56 | duration) 57 | return metric 58 | 59 | with self.assertRaises(ValueError): 60 | fn() 61 | 62 | def test_time(self): 63 | duration = 2. 64 | metric = metrics.Time().replace_duration(duration) 65 | self.assertAlmostEqual(metric.compute(), duration) 66 | 67 | def test_time_unset_duration(self): 68 | metric = metrics.Time() 69 | with self.assertRaises(ValueError): 70 | metric.compute() 71 | 72 | @parameterized.named_parameters( 73 | ("0d_values", 2., 2.), 74 | ("1d_values", [1, 2, 3], 6.), 75 | ) 76 | def test_average_per_step(self, values, expected_result): 77 | a = metrics.AveragePerStep.from_model_output(values) 78 | m = metrics.set_step_metrics_num_steps({"a": a}, 1) 79 | self.assertAlmostEqual(m["a"].compute(), expected_result) 80 | 81 | steps = 5 82 | b = metrics.AveragePerStep.from_model_output(values, steps=steps) 83 | m = metrics.set_step_metrics_num_steps({"b": b}, steps) 84 | self.assertAlmostEqual(m["b"].compute(), expected_result / steps) 85 | 86 | def test_steps_per_time(self): 87 | steps = 8. 88 | duration = 2. 89 | metric = metrics.StepsPerTime.from_model_output( 90 | steps=steps).replace_duration(duration) 91 | metrics_dict = metrics.set_step_metrics_num_steps({"metric": metric}, steps) 92 | self.assertAlmostEqual(metrics_dict["metric"].compute(), steps / duration) 93 | 94 | 95 | if __name__ == "__main__": 96 | absltest.main() 97 | -------------------------------------------------------------------------------- /t5x/precompile.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Precompile and generates HLO from TPU metadata backend. 16 | 17 | TPU Metadata backend is a TPU backend without real TPU devices while supporting 18 | any TPU topologies, to allow work that doesn't require real TPUs to run as if 19 | it is, e.g., compiling/lowering a HLO graph with the backend. 20 | 21 | Ideally, the precompile defaults to cpu backend for default device array 22 | placement since metadata backend does not have memory allocation. 
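An illustrative invocation (MODEL_DIR is a placeholder; the flag pattern
follows the example in t5x/main.py):

  python -m t5x.main \
    --run_mode=precompile \
    --gin_file=t5x/configs/runs/precompile.gin \
    --gin.MODEL_DIR=\"/tmp/precompile_dump\" \
    --logtostderr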
23 | 24 | The pjit function is pinned to use available TPU Metadata backend, for getting 25 | a proper lowering under TPU mesh. 26 | 27 | """ 28 | 29 | import os 30 | from typing import Callable, Optional 31 | 32 | import clu.data 33 | 34 | import jax 35 | from jax import random 36 | import numpy as np 37 | import t5.data.mixtures # pylint:disable=unused-import 38 | from t5x import models 39 | from t5x import partitioning 40 | from t5x import trainer as trainer_lib 41 | from t5x import utils 42 | import tensorflow as tf 43 | 44 | 45 | 46 | def precompile( 47 | *, 48 | model: models.BaseTransformerModel, 49 | train_dataset_cfg: utils.DatasetConfig, 50 | partitioner: partitioning.BasePartitioner, 51 | model_dir: str, 52 | random_seed: Optional[int], 53 | get_dataset_fn: utils.GetDatasetCallable = utils.get_dataset, 54 | verify_matching_vocabs_fn: Optional[ 55 | Callable[[utils.DatasetConfig, models.BaseTransformerModel], 56 | None]] = utils.verify_matching_vocabs, 57 | ): 58 | """Compiles and dump the HLO to model dir, with HLO text dumps.""" 59 | rng = random.PRNGKey(random_seed or 42) 60 | _, trainer_rng = random.split(rng, 2) 61 | 62 | # TODO(hthu): Find a better way of getting dataset shapes instead of actually 63 | # reading database and iterate on it. 64 | data_layout = partitioner.get_data_layout(train_dataset_cfg.batch_size) 65 | ds_shard_id = data_layout.shard_id 66 | num_ds_shards = data_layout.num_shards 67 | 68 | if verify_matching_vocabs_fn is not None: 69 | verify_matching_vocabs_fn(train_dataset_cfg, model) 70 | 71 | train_iter = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards, 72 | model.FEATURE_CONVERTER_CLS) 73 | if isinstance(train_iter, tf.data.Dataset): 74 | train_iter = clu.data.TfDatasetIterator(train_iter, checkpoint=True) 75 | elif not isinstance(train_iter, clu.data.dataset_iterator.DatasetIterator): 76 | raise ValueError( 77 | f'get_dataset_fn returned unsupported type {type(train_iter)}.') 78 | 79 | # Need to use full batch size. 80 | input_shapes = jax.tree_map(lambda x: (data_layout.batch_size, *x.shape[1:]), 81 | train_iter.element_spec) 82 | input_types = jax.tree_map(lambda x: x.dtype, train_iter.element_spec) 83 | dummy_batch = jax.tree_map(lambda x: np.ones(x.shape, x.dtype), 84 | train_iter.element_spec) 85 | 86 | # Compiling does not care about loading real weights. 87 | train_state_initializer = utils.TrainStateInitializer( 88 | optimizer_def=model.optimizer_def, 89 | init_fn=model.get_initial_variables, 90 | input_shapes=input_shapes, 91 | input_types=input_types, 92 | partitioner=partitioner) 93 | train_state_shape = train_state_initializer.global_train_state_shape 94 | train_state_axes = train_state_initializer.train_state_axes 95 | 96 | def train_step(train_state, batch): 97 | return trainer_lib.train_with_lr( 98 | train_state, 99 | batch, 100 | learning_rate=1e-3, 101 | dropout_rng=trainer_rng, 102 | model=model, 103 | num_microbatches=None, 104 | weight_metrics_computer=None) 105 | 106 | partitioned_step = partitioner.partition( 107 | train_step, 108 | in_axis_resources=(train_state_axes, partitioning.PartitionSpec('data',)), 109 | out_axis_resources=(train_state_axes, None), 110 | donate_argnums=(0,)) 111 | 112 | # PartitionedTrainCallable has lower() defined but isn't exposed in pytype. 113 | # TODO(hthu): Explicitly expose the lower() interface. 
114 | # pytype: disable=attribute-error 115 | lowered = partitioned_step.lower(train_state_shape, dummy_batch) 116 | # pytype: enable=attribute-error 117 | 118 | 119 | # TODO(hthu): Make this a proper library without writing files by default. 120 | tf.io.gfile.makedirs(model_dir) 121 | with tf.io.gfile.GFile( 122 | os.path.join(model_dir, 'lowered_hlo_pre_optimization'), 'w') as f: 123 | f.write(lowered.compiler_ir(dialect='hlo').as_serialized_hlo_module_proto()) 124 | compiled = lowered.compile() 125 | output_path = os.path.join(model_dir, 'lowered_hlo_post_optimization') 126 | with tf.io.gfile.GFile(output_path, 'w') as f: 127 | f.write(compiled.compiler_ir()[0].as_serialized_hlo_module_proto()) 128 | with tf.io.gfile.GFile(os.path.join(model_dir, 'assignment'), 'wb') as f: 129 | np.save(f, partitioner.mesh.device_ids) 130 | -------------------------------------------------------------------------------- /t5x/testdata/mtf_tiny_t5/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt-0" 2 | all_model_checkpoint_paths: "model.ckpt-0" 3 | -------------------------------------------------------------------------------- /t5x/testdata/mtf_tiny_t5/model.ckpt-0.data-00000-of-00002: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /t5x/testdata/mtf_tiny_t5/model.ckpt-0.data-00001-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/unified-io-2/502ac4d81239f82c891a9f412b000c3c8d4e2946/t5x/testdata/mtf_tiny_t5/model.ckpt-0.data-00001-of-00002 -------------------------------------------------------------------------------- /t5x/testdata/mtf_tiny_t5/model.ckpt-0.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/unified-io-2/502ac4d81239f82c891a9f412b000c3c8d4e2946/t5x/testdata/mtf_tiny_t5/model.ckpt-0.index -------------------------------------------------------------------------------- /t5x/testdata/pinned_ckpt_dir/PINNED: -------------------------------------------------------------------------------- 1 | 1 -------------------------------------------------------------------------------- /t5x/testdata/test_t5_tiny.checkpoint_0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/unified-io-2/502ac4d81239f82c891a9f412b000c3c8d4e2946/t5x/testdata/test_t5_tiny.checkpoint_0 -------------------------------------------------------------------------------- /t5x/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Separate file for storing the current version of T5X. 
16 | 17 | Stored in a separate file so that setup.py can reference the version without 18 | pulling in all the dependencies in __init__.py. 19 | """ 20 | __version__ = '0.0.0' 21 | --------------------------------------------------------------------------------