├── .flake8
├── .github
│ └── workflows
│ ├── build_documentation.yaml
│ ├── ci_cd.yaml
│ ├── patch_version.yaml
│ └── release.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGELOG.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── benchmarks
├── albert
│ ├── README.MD
│ ├── benchmark_hf.py
│ ├── benchmark_tft.py
│ ├── conf
│ │ ├── benchmark
│ │ │ ├── hf.yaml
│ │ │ └── tft.yaml
│ │ └── config.yaml
│ └── run.py
├── gpt2
│ ├── README.MD
│ ├── benchmark_hf.py
│ ├── benchmark_tft.py
│ ├── conf
│ │ ├── benchmark
│ │ │ ├── hf.yaml
│ │ │ └── tft.yaml
│ │ └── config.yaml
│ └── run.py
├── imagenet_clip_benchmark.ipynb
├── imagenet_clip_benchmark.md
├── roberta
│ ├── README.MD
│ ├── __init__.py
│ ├── benchmark_hf.py
│ ├── benchmark_tft.py
│ ├── conf
│ │ ├── benchmark
│ │ │ ├── hf.yaml
│ │ │ └── tft.yaml
│ │ └── config.yaml
│ └── run.py
├── t5
│ ├── README.MD
│ ├── benchmark_hf.py
│ ├── benchmark_tft.py
│ ├── conf
│ │ ├── benchmark
│ │ │ ├── hf.yaml
│ │ │ └── tft.yaml
│ │ └── config.yaml
│ └── run.py
└── vit
│ ├── README.MD
│ ├── benchmark_hf.py
│ ├── benchmark_tft.py
│ ├── conf
│ │ ├── benchmark
│ │ │ ├── hf.yaml
│ │ │ └── tft.yaml
│ │ └── config.yaml
│ └── run.py
├── custom_hook.py
├── docs
├── Makefile
├── index.html
├── requirements.txt
├── requirements_build.txt
└── source
│ ├── README.md
│ ├── README.rst
│ ├── _static
│ ├── css
│ │ ├── Calibre-Light.ttf
│ │ ├── Calibre-Medium.otf
│ │ ├── Calibre-Regular.otf
│ │ ├── Calibre-Thin.otf
│ │ ├── DMSans-Bold.ttf
│ │ ├── DMSans-BoldItalic.ttf
│ │ ├── DMSans-Italic.ttf
│ │ ├── DMSans-Medium.ttf
│ │ ├── DMSans-MediumItalic.ttf
│ │ ├── DMSans-Regular.ttf
│ │ ├── DMSerifText-Italic.ttf
│ │ ├── DMSerifText-Regular.ttf
│ │ ├── code-snippets.css
│ │ └── custom_from_huggingface.css
│ ├── js
│ │ ├── custom.js
│ │ └── huggingface_logo.svg
│ ├── tf_transformers.png
│ ├── tf_transformers_resized.png
│ ├── transformers_blue.png
│ └── transformers_mix.png
│ ├── benchmarks
│ ├── albert.md
│ ├── gpt2.md
│ ├── imagenet_clip_benchmark.md
│ ├── t5.md
│ └── vit.md
│ ├── conf.py
│ ├── favicon.ico
│ ├── imgs
│ ├── long_block_sequencer.gif
│ └── philosophy.png
│ ├── index.rst
│ ├── introduction_docs
│ ├── installation.md
│ ├── philosophy.rst
│ └── quicktour.rst
│ ├── model_doc
│ ├── albert.rst
│ ├── albert_tokenizer.rst
│ ├── bart.rst
│ ├── bert.rst
│ ├── bigbird_tokenizer.rst
│ ├── clip.rst
│ ├── clip_feature_extractor.rst
│ ├── encoder_decoder.rst
│ ├── gpt2.rst
│ ├── m2m100.rst
│ ├── mbart.rst
│ ├── mt5.rst
│ ├── parallelism.md
│ ├── roberta.rst
│ ├── sentence_transformer.rst
│ ├── t5.rst
│ ├── t5_tokenizer.rst
│ ├── visual_bert.rst
│ ├── vit.rst
│ ├── vit_feature_extractor.rst
│ └── wav2vec2.rst
│ ├── model_usage
│ ├── sentence_transformers.ipynb
│ ├── sentence_transformers.md
│ ├── text_generation_using_gpt2.ipynb
│ ├── text_generation_using_gpt2.md
│ ├── text_generation_using_t5.ipynb
│ └── text_generation_using_t5.md
│ ├── research
│ ├── glue.md
│ └── long_block_sequencer.md
│ ├── tflite_tutorials
│ ├── albert_tflite.ipynb
│ ├── albert_tflite.md
│ ├── bert_tflite.ipynb
│ ├── bert_tflite.md
│ ├── roberta_tflite.ipynb
│ └── roberta_tflite.md
│ └── tutorials
│ ├── 1_read_write_tfrecords.ipynb
│ ├── 1_read_write_tfrecords.md
│ ├── 2_text_classification_imdb_albert.ipynb
│ ├── 2_text_classification_imdb_albert.md
│ ├── 3_masked_lm_tpu.ipynb
│ ├── 3_masked_lm_tpu.md
│ ├── 4_image_classification_vit_multi_gpu.ipynb
│ ├── 4_image_classification_vit_multi_gpu.md
│ ├── 5_sentence_embedding_roberta_quora_zeroshot.ipynb
│ ├── 5_sentence_embedding_roberta_quora_zeroshot.md
│ ├── 6_prompt_engineering_clip.ipynb
│ ├── 6_prompt_engineering_clip.md
│ ├── 7_gpt2_question_answering_squad.ipynb
│ ├── 7_gpt2_question_answering_squad.md
│ ├── 8_code_code_java_to_csharp_t5.ipynb
│ ├── 8_code_code_java_to_csharp_t5.md
│ ├── 9_images_tfrecords.ipynb
│ ├── 9_images_tfrecords.md
│ ├── README.ipynb
│ ├── README.md
│ ├── push_model_to_hf_hub.ipynb
│ ├── push_model_to_hf_hub.md
│ └── sample.ipynb
├── isort.cfg
├── mypy.ini
├── patch_version.py
├── poetry.lock
├── pyproject.toml
├── research
├── c4_grammatical_correction
│ ├── README.md
│ ├── __init__.py
│ ├── conf
│ │ └── config.yaml
│ ├── dataset_loader.py
│ ├── model.py
│ ├── run_c4_grammar_correction.py
│ └── train_c4_grammar_correction.py
├── diffusion
│ ├── attention.py
│ ├── beta_schedule.py
│ ├── check_diffusion.ipynb
│ ├── check_unet.ipynb
│ ├── gaussian_diffusion.py
│ ├── resnet.py
│ ├── time_embedding_layers.py
│ ├── unet.py
│ └── utils.py
├── experiments
│ ├── long_sequencer.ipynb
│ ├── long_sequencer_t5_small.ipynb
│ ├── roberta2roberta_pubmed.ipynb
│ └── roberta2roberta_xsum.ipynb
├── glue
│ ├── README.md
│ ├── cola.py
│ ├── conf
│ │ ├── config.yaml
│ │ └── glue
│ │ │ ├── cola.yaml
│ │ │ ├── mnli.yaml
│ │ │ ├── mrpc.yaml
│ │ │ ├── qnli.yaml
│ │ │ ├── qqp.yaml
│ │ │ ├── rte.yaml
│ │ │ ├── sst2.yaml
│ │ │ └── stsb.yaml
│ ├── mnli.py
│ ├── model.py
│ ├── mrpc.py
│ ├── qnli.py
│ ├── qqp.py
│ ├── rte.py
│ ├── run_glue.py
│ ├── run_mnli_mismatched.py
│ ├── score_glue.py
│ ├── sst2.py
│ ├── stsb.py
│ └── test_glue.ipynb
├── long_block_sequencer
│ ├── README.md
│ ├── __init__.py
│ ├── conf
│ │ └── config.yaml
│ ├── dataset_loader.py
│ ├── evaluate.py
│ ├── long_block_encoder.py
│ ├── model.py
│ ├── run_long_block_sequencer.py
│ └── train_long_block_sequencer.py
├── masked_language_model
│ ├── README.md
│ ├── __init__.py
│ ├── callbacks.py
│ ├── conf
│ │ └── config.yaml
│ ├── dataset_loader.py
│ ├── model.py
│ ├── run_mlm.py
│ └── train_mlm.py
├── masked_language_model_old
│ ├── 1_data_prepration.ipynb
│ ├── 1_data_to_text.py
│ ├── 2_text_to_features.py
│ ├── README.MD
│ ├── config
│ │ ├── data_config.yaml
│ │ ├── tfrecord_config.yaml
│ │ └── train_config.yaml
│ ├── prepare_data.py
│ └── train_mlm.py
├── mix_language_model
│ ├── README.MD
│ ├── __init__.py
│ ├── conf
│ │ └── config.yaml
│ ├── dataset_loader.py
│ ├── dataset_loader_with_sentence_split.py
│ ├── mix_lm_model.py
│ ├── model.py
│ ├── run_mix_lm.py
│ └── train_mix_lm.py
├── mix_language_model_old
│ ├── 1_data_prepration.ipynb
│ ├── 1_data_to_text.py
│ ├── 2_text_to_features.py
│ ├── README.MD
│ ├── config
│ │ ├── data_config.yaml
│ │ ├── tfrecord_config.yaml
│ │ └── train_config.yaml
│ ├── eval_mix_lm.py
│ └── train_mix_lm.py
├── sentence2vec
│ ├── dataset_loader.py
│ ├── model.py
│ ├── similarity_model.py
│ └── train.py
├── sentence_language_model
│ ├── README.md
│ ├── __init__.py
│ ├── callbacks.py
│ ├── conf
│ │ └── config.yaml
│ ├── dataset_loader.py
│ ├── model.py
│ ├── run_mlm.py
│ └── train_mlm.py
├── similarity_model_pretraining
│ ├── README.MD
│ ├── __init__.py
│ ├── callbacks.py
│ ├── conf
│ │ └── config.yaml
│ ├── dataset_loader.py
│ ├── model.py
│ ├── run_similarity.py
│ ├── similarity_model.py
│ └── train_similairity.py
└── t5_style_pretraining
│ ├── README.md
│ ├── __init__.py
│ ├── conf
│ │ └── config.yaml
│ ├── dataset_loader.py
│ ├── model.py
│ ├── run_t5_modified.py
│ ├── t5_modified.py
│ ├── t5_tokenizer_modified.py
│ └── train_t5_modified.py
├── src
├── logo.png
├── logo2.png
├── tf_transformers
│ ├── __init__.py
│ ├── activations
│ │ ├── __init__.py
│ │ ├── gelu.py
│ │ ├── swish.py
│ │ └── utils.py
│ ├── callbacks
│ │ ├── __init__.py
│ │ ├── metric_callback_list.py
│ │ └── metrics
│ │ │ ├── __init__.py
│ │ │ ├── callbacks.py
│ │ │ ├── pearson_spearman_callback.py
│ │ │ ├── sklearn_callbacks.py
│ │ │ └── text_generation_callbacks.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── chainer.py
│ │ ├── distribute_utils.py
│ │ ├── keras_utils.py
│ │ ├── legacy_compile.py
│ │ ├── legacy_layer.py
│ │ ├── legacy_model.py
│ │ ├── legacy_module.py
│ │ ├── model_utils_for_all.py
│ │ ├── model_wrapper.py
│ │ ├── performance_utils.py
│ │ ├── read_from_hub.py
│ │ ├── trainer.py
│ │ ├── trainer_for_all.py
│ │ └── transformer_config.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── callbacks
│ │ │ ├── __init__.py
│ │ │ └── mlm_callback.py
│ │ ├── ner_utils_sp.py
│ │ ├── processors
│ │ │ ├── __init__.py
│ │ │ ├── mlm.py
│ │ │ └── mlm_ttext.py
│ │ ├── squad_utils_sp.py
│ │ ├── tfprocessor_utils.py
│ │ ├── tfrecord_utils.py
│ │ └── utils.py
│ ├── layers
│ │ ├── __init__.py
│ │ ├── attention
│ │ │ ├── __init__.py
│ │ │ ├── bart_attention.py
│ │ │ ├── bert_attention.py
│ │ │ ├── bigbird_attention.py
│ │ │ ├── clip_attention.py
│ │ │ ├── gpt2_attention.py
│ │ │ └── t5_attention.py
│ │ ├── bias_layer.py
│ │ ├── dense_einsum.py
│ │ ├── image_embeddings.py
│ │ ├── image_utils.py
│ │ ├── layer_normalization.py
│ │ ├── mask
│ │ │ ├── __init__.py
│ │ │ ├── causal_mask.py
│ │ │ ├── cross_attention_mask.py
│ │ │ ├── masked_softmax.py
│ │ │ ├── prefix_mask.py
│ │ │ └── self_attention_mask.py
│ │ ├── mlm_layer copy.py
│ │ ├── mlm_layer.py
│ │ ├── multihead_attention.py
│ │ ├── on_device_embedding.py
│ │ ├── position_embedding.py
│ │ └── transformer
│ │ │ ├── __init__.py
│ │ │ ├── bart_transformer.py
│ │ │ ├── bert_transformer.py
│ │ │ ├── byt5_transformer.py
│ │ │ ├── clip_transformer.py
│ │ │ ├── gpt2_transformer.py
│ │ │ ├── mt5_transformer.py
│ │ │ ├── t5_transformer.py
│ │ │ └── vit_transformer.py
│ ├── losses
│ │ ├── __init__.py
│ │ ├── cross_entropy.py
│ │ └── loss_wrapper.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── albert
│ │ │ ├── __init__.py
│ │ │ ├── albert.py
│ │ │ ├── albert_model.py
│ │ │ ├── configuration_albert.py
│ │ │ ├── convert.py
│ │ │ └── tokenizer_albert.py
│ │ ├── bart
│ │ │ ├── __init__.py
│ │ │ ├── bart.py
│ │ │ ├── bart_model.py
│ │ │ ├── configuration_bart.py
│ │ │ └── convert.py
│ │ ├── bert
│ │ │ ├── __init__.py
│ │ │ ├── bert.py
│ │ │ ├── bert_model.py
│ │ │ ├── configuration_bert.py
│ │ │ └── convert.py
│ │ ├── bigbird
│ │ │ ├── __init__.py
│ │ │ └── tokenizer_bigbird_roberta.py
│ │ ├── byt5
│ │ │ ├── __init__.py
│ │ │ ├── byt5.py
│ │ │ ├── byt5_model.py
│ │ │ ├── configuration_byt5.py
│ │ │ └── convert.py
│ │ ├── clip
│ │ │ ├── __init__.py
│ │ │ ├── clip.py
│ │ │ ├── clip_feature_extractor.py
│ │ │ ├── clip_image_encoder.py
│ │ │ ├── clip_model.py
│ │ │ ├── clip_text_encoder.py
│ │ │ ├── configuration_clip.py
│ │ │ └── convert.py
│ │ ├── distilbert
│ │ │ ├── __init__.py
│ │ │ ├── configuration_bert.py
│ │ │ ├── convert.py
│ │ │ └── distilbert_model.py
│ │ ├── encoder_decoder
│ │ │ ├── __init__.py
│ │ │ └── encoder_decoder.py
│ │ ├── gpt2
│ │ │ ├── __init__.py
│ │ │ ├── configuration_gpt2.py
│ │ │ ├── convert.py
│ │ │ ├── gpt2.py
│ │ │ └── gpt2_model.py
│ │ ├── minilm
│ │ │ ├── __init__.py
│ │ │ ├── configuration_minilm.py
│ │ │ ├── convert.py
│ │ │ └── minilm_model.py
│ │ ├── model_configs
│ │ │ ├── __init__.py
│ │ │ ├── albert
│ │ │ │ ├── albert_base_v2.py
│ │ │ │ └── albert_large_v2.py
│ │ │ ├── bert
│ │ │ │ ├── bert_base_cased.py
│ │ │ │ ├── bert_base_uncased.py
│ │ │ │ ├── bert_large_cased.py
│ │ │ │ └── bert_large_uncased.py
│ │ │ ├── general_config.py
│ │ │ ├── gpt2
│ │ │ │ ├── gpt2.py
│ │ │ │ └── gpt2_medium.py
│ │ │ ├── mt5
│ │ │ │ └── mt5_small.py
│ │ │ ├── roberta
│ │ │ │ ├── roberta_base.py
│ │ │ │ └── roberta_large.py
│ │ │ ├── t5
│ │ │ │ ├── t5_base.py
│ │ │ │ └── t5_small.py
│ │ │ └── unilm_cnndm
│ │ │ │ └── config.json
│ │ ├── mt5
│ │ │ ├── __init__.py
│ │ │ ├── configuration_mt5.py
│ │ │ ├── convert.py
│ │ │ ├── mt5.py
│ │ │ └── mt5_model.py
│ │ ├── roberta
│ │ │ ├── __init__.py
│ │ │ ├── configuration_roberta.py
│ │ │ ├── convert.py
│ │ │ ├── roberta.py
│ │ │ └── roberta_model.py
│ │ ├── sentence_transformers
│ │ │ ├── __init__.py
│ │ │ ├── distilbert_sentence_model.py
│ │ │ ├── distilroberta_model.py
│ │ │ ├── minilm_model.py
│ │ │ ├── sentence_transformers.py
│ │ │ └── t5_sentence_model.py
│ │ ├── t5
│ │ │ ├── __init__.py
│ │ │ ├── configuration_t5.py
│ │ │ ├── convert.py
│ │ │ ├── t5.py
│ │ │ ├── t5_model.py
│ │ │ └── tokenizer_t5.py
│ │ ├── tasks
│ │ │ ├── __init__.py
│ │ │ ├── classification.py
│ │ │ ├── maked_lm_model.py
│ │ │ ├── similarity_model.py
│ │ │ └── span_selection.py
│ │ └── vit
│ │ │ ├── __init__.py
│ │ │ ├── configuration_vit.py
│ │ │ ├── convert.py
│ │ │ ├── vit.py
│ │ │ ├── vit_feature_extractor.py
│ │ │ └── vit_model.py
│ ├── optimization
│ │ ├── __init__.py
│ │ ├── adafactor_optimization.py
│ │ ├── adam_weighted.py
│ │ ├── learning_rate_utils.py
│ │ └── optimization.py
│ ├── text
│ │ ├── __init__.py
│ │ ├── decoder_utils.py
│ │ ├── lm_tasks
│ │ │ ├── __init__.py
│ │ │ ├── causal_lm.py
│ │ │ ├── masked_lm.py
│ │ │ └── prefix_lm.py
│ │ ├── sentencepiece_layer.py
│ │ ├── sentencepiece_model_pb2.py
│ │ ├── text_decoder.py
│ │ ├── text_decoder_encoder_only.py
│ │ ├── text_decoder_encoder_only_serializable.py
│ │ ├── text_decoder_model.py
│ │ ├── text_decoder_seq2seq.py
│ │ ├── text_decoder_seq2seq_serializable.py
│ │ ├── text_decoder_serializable_encoder_only.py
│ │ └── text_layer_experimental.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── docstring_file_utils.py
│ │ ├── docstring_utils.py
│ │ ├── fast_sp_alignment.py
│ │ ├── positional_bias_utils.py
│ │ ├── push_to_hub.py
│ │ ├── tf_utils.py
│ │ ├── tokenization.py
│ │ ├── utils.py
│ │ └── viz_utils.py
└── transformers_blue.png
├── tests
├── __init__.py
├── model_test_scripts
│ ├── test_modeling_albert.py
│ ├── test_modeling_bert.py
│ ├── test_modeling_gpt2.py
│ ├── test_modeling_mt5.py
│ ├── test_modeling_roberta.py
│ ├── test_modeling_t5.py
│ ├── test_modeling_vit.py
│ └── test_wav2vec2.py
└── test_tf_transformers.py
└── tutorials
├── 1_read_write_tfrecords.ipynb
├── 1_read_write_tfrecords.md
├── 2_text_classification_imdb_albert.ipynb
├── 2_text_classification_imdb_albert.md
├── 3_masked_lm_tpu.ipynb
├── 3_masked_lm_tpu.md
├── 4_image_classification_vit_multi_gpu.ipynb
├── 4_image_classification_vit_multi_gpu.md
├── 5_sentence_embedding_roberta_quora_zeroshot.ipynb
├── 5_sentence_embedding_roberta_quora_zeroshot.md
├── 6_prompt_engineering_clip.ipynb
├── 6_prompt_engineering_clip.md
├── 7_gpt2_question_answering_squad.ipynb
├── 7_gpt2_question_answering_squad.md
├── 8_code_code_java_to_csharp_t5.ipynb
├── 8_code_code_java_to_csharp_t5.md
├── 9_images_tfrecords.ipynb
├── 9_images_tfrecords.md
├── README.ipynb
├── README.md
├── push_model_to_hf_hub.ipynb
└── push_model_to_hf_hub.md
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | select = B,C,E,F,P,T4,W,B9
3 | # F401 - imported but unused
4 | # C901 - complexity
5 | # E722 do not use bare 'except'
6 | ignore = E203, E266, E501, W503, C901, E722, B001, B006, E731
7 | max-complexity = 10
8 | max-line-length = 120
9 | # to avoid import unused in __init__.py
10 | per-file-ignores =
11 | __init__.py:F401
12 | exclude =
13 | .git,
14 | __pycache__,
15 | tests,
16 | docs,
17 | tutorials
18 |
--------------------------------------------------------------------------------
/.github/workflows/build_documentation.yaml:
--------------------------------------------------------------------------------
1 | name: Documentation Build
2 | on:
3 | workflow_dispatch:
4 |
5 | jobs:
6 | build-docs:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - name: Checkout
10 | uses: actions/checkout@master
11 | with:
12 | fetch-depth: 0
13 | - uses: actions/setup-python@v2
14 | with:
15 | python-version: 3.9
16 | - name: Build documentation
17 | run: |
18 | mkdir gh-pages
19 | touch gh-pages/.nojekyll
20 | cd docs/
21 | pip3 install -r requirements_build.txt
22 | make clean html
23 | cp -r * ../gh-pages/
24 | - name: Deploy documentation
25 | uses: JamesIves/github-pages-deploy-action@4.1.4
26 | with:
27 | branch: gh-pages
28 | folder: gh-pages
--------------------------------------------------------------------------------
/.github/workflows/ci_cd.yaml:
--------------------------------------------------------------------------------
1 | name: CI-CD
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 | branches:
8 | - main
9 |
10 | jobs:
11 | ci:
12 | # Step 1. Set up operating system
13 | runs-on: ${{ matrix.os }}
14 | strategy:
15 | matrix:
16 | os: [macos-latest, ubuntu-latest]
17 | steps:
18 | # Step 2. Set up Python 3.9
19 | - uses: actions/setup-python@v2
20 | with:
21 | python-version: 3.9
22 | # Step 3. Check-out repository so we can access its contents
23 | - uses: actions/checkout@v2
24 | - name: Install package
25 | run: |
26 | pip install poetry==1.3.2
27 | poetry install
28 | # Step 6. Run flake8 for tftransformers
29 | - name: Run flake8
30 | run: |
31 | poetry add flake8
32 | poetry run flake8 src/
33 | # Step 5. Run tests for tftransformers
34 | - name: Test with pytest
35 | run: poetry run pytest tests/test_tf_transformers.py --cov-report=xml --cov=tests
36 | # Step 6. Use Codecov to track coverage
37 | - uses: codecov/codecov-action@v2
38 | with:
39 | file: ./coverage.xml # coverage report
40 | fail_ci_if_error: true # terminate workflow if there's an error
41 | token: ${{ secrets.CODECOV_TOKEN }}
42 | flags: unittests
43 | name: codecov-umbrella
44 | verbose: true
45 |
--------------------------------------------------------------------------------
/.github/workflows/patch_version.yaml:
--------------------------------------------------------------------------------
1 | name: Patch version
2 | on:
3 | release:
4 | types: [published]
5 |
6 | jobs:
7 | release:
8 | runs-on: ubuntu-latest
9 | steps:
10 | - uses: actions/checkout@v2
11 | - name: Get tag
12 | id: vars
13 | run: echo ::set-output name=tag::${GITHUB_REF#refs/*/}
14 | - uses: actions/setup-python@v2
15 | with:
16 | python-version: '3.9'
17 | - name: Patch version and commit
18 | run: |
19 | echo "Tag retrived is ${{ steps.vars.outputs.tag }}"
20 | pip install poetry
21 | poetry version ${{ steps.vars.outputs.tag }}
22 | git config user.name github-actions
23 | git config user.email github-actions@github.com
24 | python patch_version.py ${{ steps.vars.outputs.tag }}
25 | git add pyproject.toml
26 | git add src/tf_transformers/__init__.py
27 | git add tests/test_tf_transformers.py
28 | git commit -m "Updated version to ${{ steps.vars.outputs.tag }}"
29 | git push origin HEAD:main
--------------------------------------------------------------------------------
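
The workflow above calls `python patch_version.py <tag>` but the script itself is not reproduced in this section; all the workflow reveals is that the tag is passed as an argument and that `pyproject.toml`, `src/tf_transformers/__init__.py` and `tests/test_tf_transformers.py` are staged afterwards. A minimal sketch of what such a script might do, with the file list taken from the workflow and the version-string patterns being assumptions rather than the real implementation:

```python
# Hypothetical sketch of patch_version.py; the real script is not shown in this section.
import re
import sys


def patch(tag: str) -> None:
    """Rewrite hard-coded version strings so they match the release tag.

    pyproject.toml is already bumped by `poetry version <tag>` in the workflow,
    so only the files the workflow stages afterwards would need editing. A real
    script would use patterns matching how each file actually embeds the version.
    """
    for path in ["src/tf_transformers/__init__.py", "tests/test_tf_transformers.py"]:
        with open(path) as handle:
            text = handle.read()
        text = re.sub(r'__version__\s*=\s*["\'].*?["\']', f'__version__ = "{tag}"', text)
        with open(path, "w") as handle:
            handle.write(text)


if __name__ == "__main__":
    patch(sys.argv[1])
```
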
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
1 | name: Release
2 | on:
3 | workflow_dispatch:
4 |
5 | jobs:
6 | release:
7 | # # Only run this job if new work is pushed to "main"
8 | # if: github.event_name == 'push' && github.ref == 'refs/heads/main'
9 | # Step 1. Set up operating system
10 | runs-on: ${{ matrix.os }}
11 | strategy:
12 | matrix:
13 | os: [ubuntu-latest]
14 | steps:
15 | # Step 2. Set up Python 3.9
16 | - uses: actions/setup-python@v2
17 | with:
18 | python-version: 3.9
19 | # Step 3. Check-out repository so we can access its contents
20 | - uses: actions/checkout@v2
21 | with:
22 | fetch-depth: 0
23 | # Step 4. Build
24 | - name: Build package
25 | run: |
26 | pip install poetry==1.3.2
27 | poetry build
28 | # Step 5. Publish to TestPyPI
29 | - uses: pypa/gh-action-pypi-publish@release/v1
30 | with:
31 | user: __token__
32 | password: ${{ secrets.TEST_PYPI_TOKEN }}
33 | repository_url: https://test.pypi.org/legacy/
34 | skip_existing: true
35 | # Step 6. Test install from TestPyPI
36 | - name: Test install from TestPyPI
37 | run: |
38 | pip install \
39 | --index-url https://test.pypi.org/simple/ \
40 | --extra-index-url https://pypi.org/simple \
41 | tf-transformers
42 | # Step 7. Publish to PyPI
43 | - uses: pypa/gh-action-pypi-publish@release/v1
44 | with:
45 | user: __token__
46 | password: ${{ secrets.PYPI_TOKEN }}
47 | skip_existing: true
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Initially taken from Github's Python gitignore file
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 |
9 | # C extensions
10 | *.so
11 |
12 | # DS store
13 | .DS_Store
14 | .DS_Store/
15 |
16 | # tests and logs
17 | tests/fixtures/*
18 | !tests/fixtures/sample_text_no_unicode.txt
19 | logs/
20 | lightning_logs/
21 | lang_code_data/
22 |
23 | #hydra outputs
24 | benchmarks/*/outputs
25 | research/*/outputs
26 | outputs/
27 |
28 | # Distribution / packaging
29 | .Python
30 | build/
31 | develop-eggs/
32 | dist/
33 | downloads/
34 | eggs/
35 | .eggs/
36 | lib/
37 | lib64/
38 | parts/
39 | sdist/
40 | var/
41 | wheels/
42 | *.egg-info/
43 | .installed.cfg
44 | *.egg
45 | MANIFEST
46 |
47 | # PyInstaller
48 | # Usually these files are written by a python script from a template
49 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
50 | *.manifest
51 | *.spec
52 |
53 | # Installer logs
54 | pip-log.txt
55 | pip-delete-this-directory.txt
56 |
57 | # Unit test / coverage reports
58 | htmlcov/
59 | .tox/
60 | .nox/
61 | .coverage
62 | .coverage.*
63 | .cache
64 | nosetests.xml
65 | coverage.xml
66 | *.cover
67 | .hypothesis/
68 | .pytest_cache/
69 |
70 | # Translations
71 | *.mo
72 | *.pot
73 |
74 | # Django stuff:
75 | *.log
76 | local_settings.py
77 | db.sqlite3
78 |
79 | # Flask stuff:
80 | instance/
81 | .webassets-cache
82 |
83 | # Scrapy stuff:
84 | .scrapy
85 |
86 | # Sphinx documentation
87 | docs/_build/
88 |
89 | # PyBuilder
90 | target/
91 |
92 | # Jupyter Notebook
93 | .ipynb_checkpoints
94 |
95 | # IPython
96 | profile_default/
97 | ipython_config.py
98 |
99 | # pyenv
100 | .python-version
101 |
102 | # celery beat schedule file
103 | celerybeat-schedule
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 |
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 |
135 | # vscode
136 | .vs
137 | .vscode
138 |
139 | # Pycharm
140 | .idea
141 |
142 | # TF code
143 | tensorflow_code
144 |
145 | # Models
146 | proc_data
147 |
148 | # examples
149 | runs
150 | /runs_old
151 | /wandb
152 | /examples/runs
153 | /examples/**/*.args
154 | /examples/rag/sweep
155 |
156 | # data
157 | /data
158 | serialization_dir
159 |
160 | # emacs
161 | *.*~
162 | debug.env
163 |
164 | # vim
165 | .*.swp
166 |
167 | #ctags
168 | tags
169 |
170 | # pre-commit
171 | #.pre-commit*
172 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_stages:
2 | - commit
3 | exclude: .git|.tox|tests|docs
4 | fail_fast: true
5 | repos:
6 | - hooks:
7 | - exclude: ^tutorials/
8 | id: check-docstring-first
9 | - exclude: ^tutorials/
10 | id: trailing-whitespace
11 | - exclude: ^tutorials/
12 | id: end-of-file-fixer
13 | - id: check-toml
14 | - id: check-merge-conflict
15 | - exclude: ^tutorials/
16 | id: check-added-large-files
17 | repo: https://github.com/pre-commit/pre-commit-hooks
18 | rev: v4.4.0
19 | - hooks:
20 | - id: black
21 | repo: https://github.com/psf/black
22 | rev: 23.1.0
23 | - hooks:
24 | - id: isort
25 | repo: https://github.com/timothycrosley/isort
26 | rev: 5.12.0
27 | - hooks:
28 | - additional_dependencies:
29 | - flake8-bugbear
30 | - flake8-implicit-str-concat
31 | id: flake8
32 | repo: https://github.com/pycqa/flake8.git
33 | rev: 6.0.0
34 | - hooks:
35 | - args:
36 | - --no-strict-optional
37 | - --ignore-missing-imports
38 | id: mypy
39 | repo: https://github.com/pre-commit/mirrors-mypy
40 | rev: v1.0.1
41 | - hooks:
42 | - args:
43 | - --sync
44 | - tutorials/*.ipynb
45 | files: ^tutorials/
46 | id: jupytext
47 | repo: https://github.com/mwouts/jupytext
48 | rev: v1.14.4
49 |
50 | # - hooks:
51 | # - id: commitizen
52 | # stages: [commit-msg]
53 | # repo: https://github.com/commitizen-tools/commitizen
54 | # rev: v2.20.3
55 |
56 | - hooks:
57 | - entry: python custom_hook.py
58 | id: custom_hook
59 | language: python
60 | name: custom_hook
61 | verbose: true
62 | repo: local
63 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | ## Unreleased
2 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | Contributions are welcome, and they are greatly appreciated! Every little bit
4 | helps, and credit will always be given.
5 |
6 | ## Types of Contributions
7 |
8 | ### Report Bugs
9 |
10 | If you are reporting a bug, please include:
11 |
12 | * Your operating system name and version.
13 | * Any details about your local setup that might be helpful in troubleshooting.
14 | * Detailed steps to reproduce the bug.
15 |
16 | ### Fix Bugs
17 |
18 | Look through the GitHub issues for bugs. Anything tagged with "bug" and "help
19 | wanted" is open to whoever wants to implement it.
20 |
21 | ### Implement Features
22 |
23 | Look through the GitHub issues for features. Anything tagged with "enhancement"
24 | and "help wanted" is open to whoever wants to implement it.
25 |
26 | ### Write Documentation
27 |
28 | You can never have enough documentation! Please feel free to contribute to any
29 | part of the documentation, such as the official docs, docstrings, or even
30 | on the web in blog posts, articles, and such.
31 |
32 | ### Submit Feedback
33 |
34 | If you are proposing a feature:
35 |
36 | * Explain in detail how it would work.
37 | * Keep the scope as narrow as possible, to make it easier to implement.
38 | * Remember that this is a volunteer-driven project, and that contributions
39 | are welcome :)
40 |
41 | ## Get Started!
42 |
43 | Ready to contribute? Here's how to set up `tf-transformers` for local development.
44 |
45 | 1. Download a copy of `tf-transformers` locally.
46 | 2. Install `tf-transformers` using `poetry`:
47 |
48 | ```console
49 | $ poetry install
50 | ```
51 |
52 | 3. Use `git` (or similar) to create a branch for local development and make your changes:
53 |
54 | ```console
55 | $ git checkout -b name-of-your-bugfix-or-feature
56 | ```
57 |
58 | 4. When you're done making changes, check that your changes conform to any code formatting requirements and pass any tests.
59 |
60 | 5. Commit your changes and open a pull request.
61 |
62 | ## Pull Request Guidelines
63 |
64 | Before you submit a pull request, check that it meets these guidelines:
65 |
66 | 1. The pull request should include additional tests if appropriate.
67 | 2. If the pull request adds functionality, the docs should be updated.
68 | 3. The pull request should work for all currently supported operating systems and versions of Python.
69 |
70 | ## Code of Conduct
71 |
72 | Please note that the `tf-transformers` project is released with a
73 | Code of Conduct. By contributing to this project you agree to abide by its terms. Thanks.
74 |
--------------------------------------------------------------------------------
/benchmarks/albert/conf/benchmark/hf.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: hf
3 | device: cuda
4 |
5 | data:
6 | name: imdb
7 | take_sample: false
8 | batch_size: 32
9 | max_length: 512
10 |
11 | model:
12 | name: albert-base-v2
13 | type: tf
14 |
--------------------------------------------------------------------------------
/benchmarks/albert/conf/benchmark/tft.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: tft
3 |
4 | data:
5 | name: imdb
6 | take_sample: false
7 | batch_size: 32
8 | max_length: 512
9 |
10 | model:
11 | name: albert-base-v2
12 | type: saved_model
13 |
--------------------------------------------------------------------------------
/benchmarks/albert/conf/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - benchmark: tft
3 |
--------------------------------------------------------------------------------
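
`config.yaml` only declares the Hydra defaults list, so the effective configuration comes from composing it with one file from `conf/benchmark/` (`tft.yaml` unless overridden on the command line). A small sketch of that composition using Hydra's compose API, assuming Hydra 1.1+ (where `compose`/`initialize` are importable from the top-level `hydra` package) and that it is run from the `benchmarks/albert` directory; the override shown is the programmatic equivalent of `python run.py benchmark=hf`:

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose conf/config.yaml the same way run.py's @hydra.main would, but select the
# "hf" benchmark group instead of the default "tft" and tweak one data field.
with initialize(config_path="conf"):
    cfg = compose(
        config_name="config",
        overrides=["benchmark=hf", "benchmark.data.batch_size=16"],
    )

print(OmegaConf.to_yaml(cfg))    # hf.yaml's task/data/model blocks appear under the benchmark key
print(cfg.benchmark.model.name)  # albert-base-v2
```
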
/benchmarks/albert/run.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import hydra
4 | from omegaconf import DictConfig, OmegaConf
5 |
6 | # A logger for this file
7 | log = logging.getLogger(__name__)
8 |
9 |
10 | def run_benchmark(cfg):
11 |
12 | if cfg.benchmark.task.name == "tft":
13 | from benchmark_tft import TftBenchmark
14 |
15 | benchmark = TftBenchmark(cfg)
16 | results = benchmark.run()
17 | log.info(results)
18 |
19 | if cfg.benchmark.task.name == "hf":
20 | from benchmark_hf import HFBenchmark
21 |
22 | benchmark = HFBenchmark(cfg)
23 | results = benchmark.run()
24 | log.info(results)
25 |
26 |
27 | @hydra.main(config_path="conf", config_name="config")
28 | def run(cfg: DictConfig) -> None:
29 |     run_benchmark(cfg)
30 | print(OmegaConf.to_yaml(cfg))
31 |
32 |
33 | if __name__ == "__main__":
34 | run()
35 |
--------------------------------------------------------------------------------
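
`benchmark_tft.py` and `benchmark_hf.py` are not reproduced in this section; `run.py` only assumes each exposes a class that is constructed from the composed Hydra config and returns a results mapping from `run()`. A minimal sketch of that contract (the internals are assumptions, not the real implementation):

```python
import time


class TftBenchmark:
    """Skeleton of the interface run.py expects; the real benchmark_tft.py is not shown here."""

    def __init__(self, cfg):
        # The full composed config; model and data settings live under cfg.benchmark.
        self.cfg = cfg

    def run(self):
        start = time.time()
        # A real implementation would build the model named in cfg.benchmark.model.name
        # (as a keras_model or saved_model, per cfg.benchmark.model.type), stream the
        # dataset described by cfg.benchmark.data, and run inference here.
        elapsed = time.time() - start
        return {
            "model_type": self.cfg.benchmark.model.type,
            "time (s)": round(elapsed, 2),
        }
```
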
/benchmarks/gpt2/conf/benchmark/hf.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: hf
3 | device: cuda
4 |
5 | data:
6 | name: cnn_dailymail
7 | take_sample: false
8 | batch_size: 32
9 | max_length: 512
10 |
11 | model:
12 | name: gpt2
13 | type: tf
14 |
15 | text_generation:
16 | max_length: 64
17 |
--------------------------------------------------------------------------------
/benchmarks/gpt2/conf/benchmark/tft.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: tft
3 |
4 | data:
5 | name: cnn_dailymail
6 | take_sample: false
7 | batch_size: 32
8 | max_length: 512
9 |
10 | model:
11 | name: gpt2
12 | type: textdecoder_saved_model
13 |
14 | text_generation:
15 | max_iterations: 64
16 | mode: greedy
17 |
--------------------------------------------------------------------------------
/benchmarks/gpt2/conf/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - benchmark: tft
3 |
--------------------------------------------------------------------------------
/benchmarks/gpt2/run.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import hydra
4 | from omegaconf import DictConfig, OmegaConf
5 |
6 | # A logger for this file
7 | log = logging.getLogger(__name__)
8 |
9 |
10 | def run_benchmark(cfg):
11 |
12 | if cfg.benchmark.task.name == "tft":
13 | from benchmark_tft import TftBenchmark
14 |
15 | benchmark = TftBenchmark(cfg)
16 | results = benchmark.run()
17 | log.info(results)
18 |
19 | if cfg.benchmark.task.name == "hf":
20 | from benchmark_hf import HFBenchmark
21 |
22 | benchmark = HFBenchmark(cfg)
23 | results = benchmark.run()
24 | log.info(results)
25 |
26 |
27 | @hydra.main(config_path="conf", config_name="config")
28 | def run(cfg: DictConfig) -> None:
29 |     run_benchmark(cfg)
30 | print(OmegaConf.to_yaml(cfg))
31 |
32 |
33 | if __name__ == "__main__":
34 | run()
35 |
--------------------------------------------------------------------------------
/benchmarks/roberta/README.MD:
--------------------------------------------------------------------------------
1 |
16 |
17 | # Benchmark Roberta
18 |
19 | - [Code - Roberta Benchmark](https://github.com/legacyai/tf-transformers/tree/main/benchmark/roberta)
20 |
21 | This is used to benchmark the performance of the Roberta model on text classification tasks. We evaluate it using three
22 | frameworks: Tensorflow-Transformers (default), HuggingFace PyTorch and HuggingFace Tensorflow, with HuggingFace JAX pending.
23 | Executing these scripts is fairly straightforward; users are expected to install the necessary libraries before executing
24 | the benchmark script.
25 |
26 | All the configuration is managed using [Hydra](https://github.com/facebookresearch/hydra).
27 |
28 | -> Machine - **Tesla V100-SXM2 32GB**
29 |
30 | -> Tensorflow-version - **2.4.1**
31 |
32 | -> Huggingface-Transformer-Version - **4.12.5**
33 |
34 | -> PyTorch-Version - **1.9.0**
35 |
36 | ## Tensorflow-Transformers. (tft)
37 |
38 | The default benchmark mode is ```tft```.
39 | 1. To execute ```tft``` (default) :
40 | ```python run.py benchmark=tft```
41 |
42 | 2. To execute a specific ```type```, e.g. ```keras_model``` :
43 | ```python run.py benchmark=tft benchmark.model.type=keras_model```
44 |
45 | * a. keras_model - Uses tf.keras.Model.
46 | * b. saved_model - Uses tf.saved_model.
47 |
48 | ## HuggingFace-Tensorflow. (hf-tf)
49 |
50 | 1. To execute ```hf-tf``` (default) :
51 | ```python run.py benchmark=hf benchmark.model.type=tf```
52 |
53 |
54 | ## HuggingFace-PyTorch. (hf-pt)
55 |
56 | 1. To execute ```hf-pt``` (default) :
57 | ```python run.py benchmark=hf benchmark.model.type=pt```
58 |
59 |
60 | ## HuggingFace-JAX. (hf-jax) (Not Available)
61 |
62 | 1. To execute ```hf-jax``` (default) :
63 | ```python run.py benchmark=hf benchmark.model.type=jax```
64 |
65 |
66 | ## Official Benchmarks on IMDB
67 |
68 | ```
69 | Text Classification:
70 | | | batch_size | time (s) | samples/second |
71 | |:---------------------------|-------------:|:-------------:|:---------------|
72 | | tft + saved_model | 32 | 372.44 sec | 67 |
73 | | tft + keras_model | 32 | 367.17 sec | 70 |
74 | | hf_tf | 32 | 287.91 sec | 86 |
75 | | hf_pt | 32 | 253.61 sec | 98 |
76 | | hf_jax (pmap) | 32 | N/A | N/A |
77 | ```
78 |
--------------------------------------------------------------------------------
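
A quick sanity check of the samples/second column above: if the benchmark covers the standard 25,000-example IMDB test split (an assumption; the split size is not stated in the README), throughput is roughly the number of examples divided by wall time, which is within rounding of the published numbers:

```python
# Rough consistency check for the table above (assumes 25,000 IMDB test examples).
results = {
    "tft + saved_model": 372.44,
    "tft + keras_model": 367.17,
    "hf_tf": 287.91,
    "hf_pt": 253.61,
}
for name, seconds in results.items():
    print(f"{name:20s} ~ {25000 / seconds:.0f} samples/second")
```
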
/benchmarks/roberta/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/benchmarks/roberta/__init__.py
--------------------------------------------------------------------------------
/benchmarks/roberta/conf/benchmark/hf.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: hf
3 | device: cuda
4 |
5 | data:
6 | name: imdb
7 | take_sample: false
8 | batch_size: 32
9 | max_length: 512
10 |
11 | model:
12 | name: roberta-base
13 | type: tf
14 |
--------------------------------------------------------------------------------
/benchmarks/roberta/conf/benchmark/tft.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: tft
3 |
4 | data:
5 | name: imdb
6 | take_sample: false
7 | batch_size: 32
8 | max_length: 512
9 |
10 | model:
11 | name: roberta-base
12 | type: saved_model
13 |
--------------------------------------------------------------------------------
/benchmarks/roberta/conf/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - benchmark: tft
3 |
--------------------------------------------------------------------------------
/benchmarks/roberta/run.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import hydra
4 | from omegaconf import DictConfig, OmegaConf
5 |
6 | # A logger for this file
7 | log = logging.getLogger(__name__)
8 |
9 |
10 | def run_benchmark(cfg):
11 |
12 | if cfg.benchmark.task.name == "tft":
13 | from benchmark_tft import TftBenchmark
14 |
15 | benchmark = TftBenchmark(cfg)
16 | results = benchmark.run()
17 | log.info(results)
18 |
19 | if cfg.benchmark.task.name == "hf":
20 | from benchmark_hf import HFBenchmark
21 |
22 | benchmark = HFBenchmark(cfg)
23 | results = benchmark.run()
24 | log.info(results)
25 |
26 |
27 | @hydra.main(config_path="conf", config_name="config")
28 | def run(cfg: DictConfig) -> None:
29 |     run_benchmark(cfg)
30 | print(OmegaConf.to_yaml(cfg))
31 |
32 |
33 | if __name__ == "__main__":
34 | run()
35 |
--------------------------------------------------------------------------------
/benchmarks/t5/conf/benchmark/hf.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: hf
3 | device: cuda
4 |
5 | data:
6 | name: xsum
7 | take_sample: false
8 | batch_size: 32
9 | max_length: 512
10 |
11 | model:
12 | name: t5-small
13 | type: tf
14 |
15 | text_generation:
16 | max_length: 64
17 | eos_token_id: 250000
18 |
--------------------------------------------------------------------------------
/benchmarks/t5/conf/benchmark/tft.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: tft
3 |
4 | data:
5 | name: xsum
6 | take_sample: false
7 | batch_size: 32
8 | max_length: 512
9 |
10 | model:
11 | name: t5-small
12 | type: textdecoder_saved_model
13 |
14 | text_generation:
15 | max_iterations: 64
16 | mode: greedy
17 |
--------------------------------------------------------------------------------
/benchmarks/t5/conf/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - benchmark: tft
3 |
--------------------------------------------------------------------------------
/benchmarks/t5/run.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import hydra
4 | from omegaconf import DictConfig, OmegaConf
5 |
6 | # A logger for this file
7 | log = logging.getLogger(__name__)
8 |
9 |
10 | def run_benchmark(cfg):
11 |
12 | if cfg.benchmark.task.name == "tft":
13 | from benchmark_tft import TftBenchmark
14 |
15 | benchmark = TftBenchmark(cfg)
16 | results = benchmark.run()
17 | log.info(results)
18 |
19 | if cfg.benchmark.task.name == "hf":
20 | from benchmark_hf import HFBenchmark
21 |
22 | benchmark = HFBenchmark(cfg)
23 | results = benchmark.run()
24 | log.info(results)
25 |
26 |
27 | @hydra.main(config_path="conf", config_name="config")
28 | def run(cfg: DictConfig) -> None:
29 |     run_benchmark(cfg)
30 | print(OmegaConf.to_yaml(cfg))
31 |
32 |
33 | if __name__ == "__main__":
34 | run()
35 |
--------------------------------------------------------------------------------
/benchmarks/vit/README.MD:
--------------------------------------------------------------------------------
1 |
16 |
17 | # Benchmark ViT
18 |
19 | - [Code - ViT Benchmark](https://github.com/legacyai/tf-transformers/tree/main/benchmark/vit)
20 |
21 | This is used to benchmark the performance of the ViT model on image classification tasks. We evaluate it using three
22 | frameworks: Tensorflow-Transformers (default), HuggingFace PyTorch and HuggingFace Tensorflow, with HuggingFace JAX pending.
23 | Executing these scripts is fairly straightforward; users are expected to install the necessary libraries before executing
24 | the benchmark script.
25 |
26 | All the configuration is managed using [Hydra](https://github.com/facebookresearch/hydra).
27 |
28 | -> Machine - **Tesla V100-SXM2 32GB**
29 |
30 | -> Tensorflow-version - **2.4.1**
31 |
32 | -> Huggingface-Transformer-Version - **4.12.5**
33 |
34 | -> PyTorch-Version - **1.9.0**
35 |
36 | ## Tensorflow-Transformers. (tft)
37 |
38 | The default benchmark mode is ```tft```.
39 | 1. To execute ```tft``` (default) :
40 | ```python run.py benchmark=tft```
41 |
42 | 2. To execute a specific ```type```, e.g. ```keras_model``` :
43 | ```python run.py benchmark=tft benchmark.model.type=keras_model```
44 |
45 | * a. keras_model - Uses tf.keras.Model.
46 | * b. saved_model - Uses tf.saved_model.
47 | * c. saved_model_tf-io - Uses tf.saved_model; ```model + tf.io``` is serialized together.
48 |
49 |
50 | ## HuggingFace-Tensorflow. (hf-tf)
51 |
52 | 1. To execute ```hf-tf``` (default) :
53 | ```python run.py benchmark=hf benchmark.model.type=tf```
54 |
55 |
56 | ## HuggingFace-PyTorch. (hf-pt)
57 |
58 | 1. To execute ```hf-pt``` (default) :
59 | ```python run.py benchmark=hf benchmark.model.type=pt```
60 |
61 |
62 | ## Official Benchmarks on Keras Flowers dataset (5000 samples)
63 |
64 | ```
65 | Image Classification:
66 | | | batch_size | time (s) | samples/second |
67 | |:---------------------------|-------------:|:-------------:|:---------------|
68 | | tft + saved_model | 32 | 29.06 sec | 126 |
69 | | tft + saved_model + tf-io | 32 | 17.62 sec | 208 |
70 | | tft + keras_model | 32 | 29.48 sec | 124 |
71 | | hf_tf | 32 | 95.84 sec | 38 |
72 | | hf_pt | 32 | 24.79 sec | 148 |
73 | | tft + keras + hf pipeline | 32 | 94.44 sec | 39 |
74 | ```
75 |
--------------------------------------------------------------------------------
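
The `saved_model + tf-io` row above refers to serializing the tf.io image-reading ops together with the model, so the exported SavedModel accepts file paths directly instead of pre-decoded tensors. The actual ViT export code is not shown in this section; a minimal sketch of the general idea, where `model` stands for an assumed Keras ViT classifier taking `[batch, 224, 224, 3]` floats and the preprocessing constants are placeholders:

```python
import tensorflow as tf


def make_serving_fn(model, img_height=224, img_width=224):
    """Build a serving function that bundles tf.io decoding with the model."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def serve(image_paths):
        def _load(path):
            image = tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
            image = tf.image.resize(image, (img_height, img_width))
            return tf.cast(image, tf.float32) / 255.0

        images = tf.map_fn(_load, image_paths, fn_output_signature=tf.float32)
        return model(images)

    return serve


# tf.saved_model.save(model, "vit_with_tf_io",
#                     signatures={"serving_default": make_serving_fn(model)})
```
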
/benchmarks/vit/conf/benchmark/hf.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: hf
3 | device: cuda
4 |
5 | data:
6 | take_sample: false
7 | batch_size: 32
8 |
9 | model:
10 | name: google/vit-base-patch16-224
11 | type: tf
12 |
--------------------------------------------------------------------------------
/benchmarks/vit/conf/benchmark/tft.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: tft
3 |
4 | data:
5 | take_sample: false
6 | batch_size: 32
7 |
8 | model:
9 | name: google/vit-base-patch16-224
10 | type: saved_model
11 |
--------------------------------------------------------------------------------
/benchmarks/vit/conf/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - benchmark: tft
3 |
--------------------------------------------------------------------------------
/benchmarks/vit/run.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import hydra
4 | from omegaconf import DictConfig, OmegaConf
5 |
6 | # A logger for this file
7 | log = logging.getLogger(__name__)
8 |
9 |
10 | def run_benchmark(cfg):
11 |
12 | if cfg.benchmark.task.name == "tft":
13 | from benchmark_tft import TftBenchmark
14 |
15 | benchmark = TftBenchmark(cfg)
16 | results = benchmark.run()
17 | log.info(results)
18 |
19 | if cfg.benchmark.task.name == "hf":
20 | from benchmark_hf import HFBenchmark
21 |
22 | benchmark = HFBenchmark(cfg)
23 | results = benchmark.run()
24 | log.info(results)
25 |
26 |
27 | @hydra.main(config_path="conf", config_name="config")
28 | def run(cfg: DictConfig) -> None:
29 |     run_benchmark(cfg)
30 | print(OmegaConf.to_yaml(cfg))
31 |
32 |
33 | if __name__ == "__main__":
34 | run()
35 |
--------------------------------------------------------------------------------
/custom_hook.py:
--------------------------------------------------------------------------------
1 | def add_new_files_to_jupytext():
2 | """This function will check all .ipynb and .md files.
3 |     If a new .ipynb is present, it will be converted to an equivalent .md
4 | """
5 | import glob
6 | import subprocess
7 |
8 | all_ipynb = glob.glob('tutorials/*.ipynb')
9 | all_md = glob.glob('tutorials/*.md')
10 |
11 | all_ipynb = [name.split('.ipynb')[0] for name in all_ipynb]
12 | all_md = [name.split('.md')[0] for name in all_md]
13 |
14 | notebook_list = []
15 | for notebook_name in all_ipynb:
16 | if notebook_name not in all_md:
17 | notebook_list.append(notebook_name + '.ipynb')
18 |
19 | if notebook_list != []:
20 | for notebook in notebook_list:
21 | subprocess.run(["jupytext --set-formats ipynb,md:myst {}".format(notebook)], shell=True)
22 |
23 |
24 | def move_to_docs():
25 | """Move tutorals to docs"""
26 | print("Copying the tutorials to docs")
27 | import shutil
28 |
29 | shutil.copytree("tutorials", "docs/source/tutorials", dirs_exist_ok=True)
30 |
31 |
32 | if __name__ == '__main__':
33 | add_new_files_to_jupytext() # Convert new notebooks to md using jupytext
34 |     move_to_docs()  # Move new tutorials to docs
35 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | # sphinx <4 required by myst-nb v0.12.0 (Feb 2021)
2 | # sphinx >=3 required by sphinx-autodoc-typehints v1.11.1 (Oct 2020)
3 | sphinx >=3, <4
4 | sphinx_rtd_theme
5 | sphinx-autodoc-typehints==1.11.1
6 | jupyter-sphinx>=0.3.2
7 | myst-nb==v0.12.0
8 |
9 | # Packages as per HuggingFace extensions
10 | sphinx_rtd_theme
11 | recommonmark
12 | sphinx_markdown_tables
13 | sphinxext-opengraph
14 | sphinx-copybutton
15 | matplotlib
16 |
--------------------------------------------------------------------------------
/docs/requirements_build.txt:
--------------------------------------------------------------------------------
1 | # sphinx <4 required by myst-nb v0.12.0 (Feb 2021)
2 | # sphinx >=3 required by sphinx-autodoc-typehints v1.11.1 (Oct 2020)
3 | sphinx >=3, <4
4 | sphinx_rtd_theme==1.0.0
5 | sphinx-autodoc-typehints==1.11.1
6 | jupyter-sphinx>=0.3.2
7 | myst-nb
8 | jinja2==3.0.0
9 |
10 |
11 | # Packages as per HuggingFace extensions
12 | sphinx_rtd_theme
13 | recommonmark
14 | sphinx_markdown_tables
15 | sphinxext-opengraph
16 | sphinx-copybutton
17 | matplotlib
18 |
19 | # Packages for the build
20 | tensorflow-text==2.7.0
21 | sentencepiece
22 | tqdm
23 | transformers
24 |
--------------------------------------------------------------------------------
/docs/source/README.md:
--------------------------------------------------------------------------------
1 | We use jupytext to keep copies of notebooks in sync with their Markdown equivalents.
2 |
3 | ### Adding a new notebook
4 |
5 | ```
6 | jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb
7 | ```
8 |
9 | ### Syncing Notebooks
10 |
11 | After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running:
12 |
13 | ```
14 | jupytext --sync docs/notebooks/*
15 | ```
16 |
--------------------------------------------------------------------------------
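
The jupytext CLI calls above are the workflow used by the pre-commit hook and `custom_hook.py`. For scripting the same conversion from Python, jupytext also exposes a small read/write API; a sketch using a notebook path that exists in this repository (note the CLI's `--set-formats` additionally records the pairing metadata, which this sketch does not):

```python
import jupytext

# Read a notebook and write a MyST-Markdown copy of it, roughly mirroring
# `jupytext --set-formats ipynb,md:myst tutorials/sample.ipynb` for a single file.
notebook = jupytext.read("tutorials/sample.ipynb")
jupytext.write(notebook, "tutorials/sample.md", fmt="md:myst")
```
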
/docs/source/README.rst:
--------------------------------------------------------------------------------
1 | We use jupytext to keep copies of notebooks in sync with their Markdown equivalents.
2 |
3 | ### Adding a new notebook
4 |
5 | ```
6 | jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb
7 | ```
8 |
9 | ### Syncing Notebooks
10 |
11 | After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running:
12 |
13 | ```
14 | jupytext --sync docs/notebooks/*
15 | ```
16 |
--------------------------------------------------------------------------------
/docs/source/_static/css/Calibre-Light.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/Calibre-Light.ttf
--------------------------------------------------------------------------------
/docs/source/_static/css/Calibre-Medium.otf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/Calibre-Medium.otf
--------------------------------------------------------------------------------
/docs/source/_static/css/Calibre-Regular.otf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/Calibre-Regular.otf
--------------------------------------------------------------------------------
/docs/source/_static/css/Calibre-Thin.otf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/Calibre-Thin.otf
--------------------------------------------------------------------------------
/docs/source/_static/css/DMSans-Bold.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/DMSans-Bold.ttf
--------------------------------------------------------------------------------
/docs/source/_static/css/DMSans-BoldItalic.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/DMSans-BoldItalic.ttf
--------------------------------------------------------------------------------
/docs/source/_static/css/DMSans-Italic.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/DMSans-Italic.ttf
--------------------------------------------------------------------------------
/docs/source/_static/css/DMSans-Medium.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/DMSans-Medium.ttf
--------------------------------------------------------------------------------
/docs/source/_static/css/DMSans-MediumItalic.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/DMSans-MediumItalic.ttf
--------------------------------------------------------------------------------
/docs/source/_static/css/DMSans-Regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/DMSans-Regular.ttf
--------------------------------------------------------------------------------
/docs/source/_static/css/DMSerifText-Italic.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/DMSerifText-Italic.ttf
--------------------------------------------------------------------------------
/docs/source/_static/css/DMSerifText-Regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/DMSerifText-Regular.ttf
--------------------------------------------------------------------------------
/docs/source/_static/css/code-snippets.css:
--------------------------------------------------------------------------------
1 | /* Code block Changes Comments */
2 | .highlight .c1, .highlight .sd{
3 | color: rgb(146, 56, 56);
4 | }
5 |
6 | /* Code block Changes from import*/
7 | .highlight .kn, .highlight .nv, .highlight .s2, .highlight .ow {
8 | color: #6670FF;
9 | }
10 |
11 | /* Code block Changes methods*/
12 | .highlight .nn, .highlight .k, .highlight .s1, .highlight .nb, .highlight .bp, .highlight .kc {
13 | color: #3C2478;
14 | }
15 |
16 | /* Code block Changes >>*/
17 | .highlight .gp {
18 | color: #3C2478;
19 | }
20 |
21 | /* Code block Changes Add border*/
22 | .rst-content div[class^='highlight'] {
23 | border-color: #1a1a1a;
24 | }
25 |
--------------------------------------------------------------------------------
/docs/source/_static/tf_transformers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/tf_transformers.png
--------------------------------------------------------------------------------
/docs/source/_static/tf_transformers_resized.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/tf_transformers_resized.png
--------------------------------------------------------------------------------
/docs/source/_static/transformers_blue.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/transformers_blue.png
--------------------------------------------------------------------------------
/docs/source/_static/transformers_mix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/transformers_mix.png
--------------------------------------------------------------------------------
/docs/source/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/favicon.ico
--------------------------------------------------------------------------------
/docs/source/imgs/long_block_sequencer.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/imgs/long_block_sequencer.gif
--------------------------------------------------------------------------------
/docs/source/imgs/philosophy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/imgs/philosophy.png
--------------------------------------------------------------------------------
/docs/source/model_doc/clip_feature_extractor.rst:
--------------------------------------------------------------------------------
1 | ..
2 | Copyright 2020 The HuggingFace Team and TFT Team. All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
5 | the License. You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
10 | an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
11 | specific language governing permissions and limitations under the License.
12 |
13 | CLIP Feature Extractor
14 | -----------------------------------------------------------------------------------------------------------------------
15 |
16 | Overview
17 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
18 |
19 | This page includes information about how to use CLIPFeatureExtractorTF with tensorflow-ops.
20 | This feature extractor works in sync with :class:`~tf.data.Dataset` and so is useful for on-the-fly preprocessing.
21 |
22 | .. code-block::
23 |
24 | >>> from tf_transformers.models import CLIPFeatureExtractorTF
25 |     >>> image_path_list = ["/path/to/img1.jpg", "/path/to/img2.jpg"]  # list of image paths
26 | >>> CLIP_feature_extractor_tf = CLIPFeatureExtractorTF(img_height=224, img_width=224)
27 | >>> outputs = CLIP_feature_extractor_tf({'image': tf.constant(image_path_list)})
28 |
29 |
30 | CLIPFeatureExtractorTF
31 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
32 |
33 | .. autoclass:: tf_transformers.models.CLIPFeatureExtractorTF
34 | :members:
35 |
--------------------------------------------------------------------------------
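
Since the overview above says the feature extractor works in sync with tf.data.Dataset, a common pattern is to map it over a dataset of image paths for on-the-fly preprocessing. A sketch under that assumption; the paths are placeholders and the `{'image': ...}` calling convention follows the code block in the overview:

```python
import tensorflow as tf
from tf_transformers.models import CLIPFeatureExtractorTF

image_paths = ["/path/to/img1.jpg", "/path/to/img2.jpg"]  # placeholder paths
feature_extractor = CLIPFeatureExtractorTF(img_height=224, img_width=224)

# Build a dict-structured dataset and run the feature extractor on the fly,
# matching the {'image': ...} input shown in the overview.
dataset = (
    tf.data.Dataset.from_tensor_slices({"image": image_paths})
    .batch(2)
    .map(feature_extractor, num_parallel_calls=tf.data.AUTOTUNE)
)

for batch in dataset:
    print(tf.nest.map_structure(tf.shape, batch))
```
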
/docs/source/model_doc/encoder_decoder.rst:
--------------------------------------------------------------------------------
1 | ..
2 | Copyright 2020 TFT Team. All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
5 | the License. You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
10 | an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
11 | specific language governing permissions and limitations under the License.
12 |
13 | EncoderDecoder
14 | -----------------------------------------------------------------------------------------------------------------------
15 |
16 | Overview
17 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
18 |
19 | The EncoderDecoder model, also known as a Seq2Seq model, combines an Encoder and a Decoder, for example
20 | Bert with GPT2, or Bert with Bert itself. For more details, please refer to the Bart model :doc:`Bart ` or the T5 model
21 | :doc:`T5 `.
22 | It consists of (see the sketch below):
23 | - an Encoder and a Decoder
24 | - cross attention between the Decoder and the Encoder
25 |
26 |
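A rough construction sketch is shown below. The exact constructor arguments and helper methods here are assumptions; the class reference that follows is the source of truth.

.. code-block::

    >>> # Illustrative sketch: the EncoderDecoder arguments shown here are assumptions.
    >>> from tf_transformers.models import BertModel, GPT2Model, EncoderDecoder
    >>> encoder_layer = BertModel.from_pretrained('bert-base-uncased', return_layer=True)
    >>> decoder_layer = GPT2Model.from_pretrained('gpt2', return_layer=True)
    >>> seq2seq = EncoderDecoder(encoder=encoder_layer, decoder=decoder_layer)
    >>> model = seq2seq.get_model()
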
27 | EncoderDecoder (Seq2Seq)
28 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
29 |
30 | .. autoclass:: tf_transformers.models.EncoderDecoder
31 | :members:
32 |
--------------------------------------------------------------------------------
/docs/source/model_doc/m2m100.rst:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/model_doc/m2m100.rst
--------------------------------------------------------------------------------
/docs/source/model_doc/mbart.rst:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/model_doc/mbart.rst
--------------------------------------------------------------------------------
/docs/source/model_doc/t5_tokenizer.rst:
--------------------------------------------------------------------------------
1 | ..
2 | Copyright 2020 The HuggingFace Team and TFT Team. All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
5 | the License. You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
10 | an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
11 | specific language governing permissions and limitations under the License.
12 |
13 | T5 Tokenizer
14 | -----------------------------------------------------------------------------------------------------------------------
15 |
16 | Overview
17 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
18 |
19 | This page includes information about how to use the T5 tokenizer with tensorflow-text.
20 | This tokenizer works in sync with :class:`~tf.data.Dataset` and so is useful for on-the-fly tokenization.
21 |
22 | .. code-block::
23 |
24 | >>> from tf_transformers.models import T5TokenizerTFText
25 | >>> tokenizer = T5TokenizerTFText.from_pretrained("t5-small")
26 | >>> text = ['The following statements are true about sentences in English:',
27 | '',
28 | 'A new sentence begins with a capital letter.']
29 | >>> inputs = {'text': text}
30 | >>> outputs = tokenizer(inputs) # Ragged Tensor Output
31 |
32 | # Dynamic Padding
33 | >>> tokenizer = T5TokenizerTFText.from_pretrained("t5-small", dynamic_padding=True)
34 | >>> text = ['The following statements are true about sentences in English:',
35 | '',
36 | 'A new sentence begins with a capital letter.']
37 | >>> inputs = {'text': text}
38 | >>> outputs = tokenizer(inputs) # Dict of tf.Tensor
39 |
40 | # Static Padding
41 | >>> tokenizer = T5TokenizerTFText.from_pretrained("t5-small", pack_model_inputs=True)
42 | >>> text = ['The following statements are true about sentences in English:',
43 | '',
44 | 'A new sentence begins with a capital letter.']
45 | >>> inputs = {'text': text}
46 | >>> outputs = tokenizer(inputs) # Dict of tf.Tensor
47 |
48 | # To Add Special Tokens
49 | >>> tokenizer = T5TokenizerTFText.from_pretrained("t5-small", add_special_tokens=True)
50 |
51 | T5TokenizerTFText
52 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
53 |
54 | .. autoclass:: tf_transformers.models.T5TokenizerTFText
55 | :members:
56 |
57 | T5TokenizerLayer
58 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
59 |
60 | .. autoclass:: tf_transformers.models.T5TokenizerLayer
61 | :members:
62 |
--------------------------------------------------------------------------------
/docs/source/model_doc/visual_bert.rst:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/model_doc/visual_bert.rst
--------------------------------------------------------------------------------
/docs/source/model_doc/vit_feature_extractor.rst:
--------------------------------------------------------------------------------
1 | ..
2 | Copyright 2020 TFT Team. All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
5 | the License. You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
10 | an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
11 | specific language governing permissions and limitations under the License.
12 |
13 | ViT Feature Extractor
14 | -----------------------------------------------------------------------------------------------------------------------
15 |
16 | Overview
17 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
18 |
19 | This page includes information about how to use ViTFeatureExtractorTF with tensorflow-ops.
20 | This feature extractor works in sync with :class:`~tf.data.Dataset` and so is useful for on-the-fly preprocessing.
21 |
22 | .. code-block::
23 |
24 | >>> from tf_transformers.models import ViTFeatureExtractorTF
25 | >>> image_path_list = # List of image paths
26 | >>> vit_feature_extractor_tf = ViTFeatureExtractorTF(img_height=224, img_width=224)
27 | >>> outputs = vit_feature_extractor_tf({'image': tf.constant(image_path_list)})
28 |
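Because the extractor runs as TensorFlow ops, it can also be combined with labels to build a :class:`~tf.data.Dataset` training pipeline. A rough sketch (paths and labels are placeholders):

.. code-block::

    >>> import tensorflow as tf
    >>> vit_feature_extractor_tf = ViTFeatureExtractorTF(img_height=224, img_width=224)
    >>> paths, labels = ['cat.jpg', 'dog.jpg'], [0, 1]
    >>> dataset = tf.data.Dataset.from_tensor_slices(({'image': paths}, labels))
    >>> dataset = dataset.batch(2).map(lambda x, y: (vit_feature_extractor_tf(x), y))
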
29 | ViTFeatureExtractorTF
30 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
31 |
32 | .. autoclass:: tf_transformers.models.ViTFeatureExtractorTF
33 | :members:
34 |
--------------------------------------------------------------------------------
/docs/source/model_doc/wav2vec2.rst:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/model_doc/wav2vec2.rst
--------------------------------------------------------------------------------
/docs/source/model_usage/sentence_transformers.md:
--------------------------------------------------------------------------------
1 | ---
2 | jupytext:
3 | formats: ipynb,md:myst
4 | text_representation:
5 | extension: .md
6 | format_name: myst
7 | format_version: 0.13
8 | jupytext_version: 1.13.5
9 | kernelspec:
10 | display_name: Python 3 (ipykernel)
11 | language: python
12 | name: python3
13 | ---
14 |
15 | # Sentence Transformer in tf-transformers
16 |
17 | * This is a simple tutorial to demonstrate how ```SentenceTransformer``` models have been integrated
18 | into ```tf-transformers``` and how to use them.
19 | * The following tutorial is applicable to all supported ```SentenceTransformer``` models.
20 |
21 | ```{code-cell} ipython3
22 |
23 | ```
24 |
25 | ### Load Sentence-t5 model
26 |
27 | ```{code-cell} ipython3
28 | import tensorflow as tf
29 | from tf_transformers.models import SentenceTransformer
30 | ```
31 |
32 | ```{code-cell} ipython3
33 | model_name = 'sentence-transformers/sentence-t5-base' # Load any sentencetransformer model here
34 | model = SentenceTransformer.from_pretrained(model_name)
35 | ```
36 |
37 | ### What's my model input?
38 |
39 | * All models in ```tf-transformers``` are designed with full connections. All you need is ```model.input``` if it's a ```LegacyModel/tf.keras.Model```, or ```model.model_inputs``` if it's a ```LegacyLayer/tf.keras.layers.Layer```.
40 |
41 | ```{code-cell} ipython3
42 | model.input
43 | ```
44 |
45 | ### What's my model output?
46 |
47 | * All models in ```tf-transformers``` are designed with full connections. All you need is ```model.output``` if it's a ```LegacyModel/tf.keras.Model```, or ```model.model_outputs``` if it's a ```LegacyLayer/tf.keras.layers.Layer```.
48 |
49 | ```{code-cell} ipython3
50 | model.output
51 | ```
52 |
53 | ### Sentence vectors
54 |
55 | ```{code-cell} ipython3
56 | from transformers import AutoTokenizer
57 |
58 | tokenizer = AutoTokenizer.from_pretrained(model_name)
59 |
60 | text = ['This is a sentence to get vector', 'This one too']
61 | inputs = tokenizer(text, return_tensors='tf', padding=True)
62 |
63 | inputs_tf = {'input_ids': inputs['input_ids'], 'input_mask': inputs['attention_mask']}
64 | outputs_tf = model(inputs_tf)
65 | print("Sentence vector", outputs_tf['sentence_vector'].shape)
66 | ```
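
As a quick illustrative check, the sentence vectors computed above can be compared with cosine similarity using plain TensorFlow ops:

```{code-cell} ipython3
# Cosine similarity between the two sentence vectors above (illustrative)
normalized = tf.nn.l2_normalize(outputs_tf['sentence_vector'], axis=-1)
cosine_scores = tf.matmul(normalized, normalized, transpose_b=True)
print("Cosine similarity matrix:", cosine_scores.numpy())
```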
67 |
68 | ```{code-cell} ipython3
69 |
70 | ```
71 |
72 | ### Serialize as usual and load it
73 |
74 | * Serialize, load, and assert that the outputs match those of the non-serialized model (```tf.keras.Model```).
75 |
76 | ```{code-cell} ipython3
77 | model_dir = 'MODELS/sentence_t5'
78 | model.save_transformers_serialized(model_dir)
79 |
80 | loaded = tf.saved_model.load(model_dir)
81 | model = loaded.signatures['serving_default']
82 |
83 | outputs_tf_serialized = model(**inputs_tf)
84 |
85 | tf.debugging.assert_near(outputs_tf['sentence_vector'], outputs_tf_serialized['sentence_vector'])
86 | ```
87 |
88 | ```{code-cell} ipython3
89 |
90 | ```
91 |
92 | ```{code-cell} ipython3
93 |
94 | ```
95 |
--------------------------------------------------------------------------------
/docs/source/tutorials/README.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "4193b067",
6 | "metadata": {},
7 | "source": [
8 | "We use jupytext to keep copies of notebooks in sync with their Markdown equivalents.\n",
9 | "\n",
10 | "### Adding a new notebook\n",
11 | "\n",
12 | "```\n",
13 | "jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb\n",
14 | "```\n",
15 | "\n",
16 | "### Syncing Notebooks\n",
17 | "\n",
18 | "After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running:\n",
19 | "\n",
20 | "```\n",
21 | "jupytext --sync docs/notebooks/*\n",
22 | "```"
23 | ]
24 | }
25 | ],
26 | "metadata": {
27 | "jupytext": {
28 | "cell_metadata_filter": "-all",
29 | "formats": "ipynb,md:myst",
30 | "main_language": "python"
31 | }
32 | },
33 | "nbformat": 4,
34 | "nbformat_minor": 5
35 | }
36 |
--------------------------------------------------------------------------------
/docs/source/tutorials/README.md:
--------------------------------------------------------------------------------
1 | ---
2 | jupytext:
3 | cell_metadata_filter: -all
4 | formats: ipynb,md:myst
5 | main_language: python
6 | text_representation:
7 | extension: .md
8 | format_name: myst
9 | format_version: 0.13
10 | jupytext_version: 1.14.4
11 | ---
12 |
13 | We use jupytext to keep copies of notebooks in sync with their Markdown equivalents.
14 |
15 | ### Adding a new notebook
16 |
17 | ```
18 | jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb
19 | ```
20 |
21 | ### Syncing Notebooks
22 |
23 | After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running:
24 |
25 | ```
26 | jupytext --sync docs/notebooks/*
27 | ```
28 |
--------------------------------------------------------------------------------
/docs/source/tutorials/push_model_to_hf_hub.md:
--------------------------------------------------------------------------------
1 | ---
2 | jupytext:
3 | formats: ipynb,md:myst
4 | text_representation:
5 | extension: .md
6 | format_name: myst
7 | format_version: 0.13
8 | jupytext_version: 1.14.4
9 | kernelspec:
10 | display_name: Python 3 (ipykernel)
11 | language: python
12 | name: python3
13 | ---
14 |
15 | ```{code-cell} ipython3
16 |
17 | ```
18 |
19 | ```{code-cell} ipython3
20 | import subprocess
21 | import os
22 | from distutils.dir_util import copy_tree
23 | ```
24 |
25 | ### How to push a model to the hub
26 |
27 | * Make sure you have logged in via ```huggingface-cli login``` using your token.
28 |
29 | ```{code-cell} ipython3
30 |
31 | ```
32 |
33 | ```{code-cell} ipython3
34 | model_name = 'byt5-small'
35 | ```
36 |
37 | #### 1. Create a repo named after the model under the organization
38 |
39 | ```{code-cell} ipython3
40 | subprocess.run(['huggingface-cli', 'repo',
41 | 'create', '{}'.format(model_name),
42 | '--yes',
43 | '--organization', 'tftransformers'])
44 | ```
45 |
46 | ```{code-cell} ipython3
47 |
48 | ```
49 |
50 | #### 2. Clone the repo created above into the local working directory
51 |
52 | ```{code-cell} ipython3
53 | subprocess.run(["git", "clone", "https://huggingface.co/tftransformers/{}".format(model_name)])
54 | ```
55 |
56 | ```{code-cell} ipython3
57 |
58 | ```
59 |
60 | #### 3. Move your model directory into the current working directory, under the ```model_name``` folder
61 |
62 | ```{code-cell} ipython3
63 | cwd = os.getcwd() # Get current working dir
64 | new_working_dir = os.path.join(cwd, model_name) # This is cloned from hf hub under organization
65 | os.chdir("{}".format(new_working_dir)) # Switch to new working dir
66 |
67 | # The cached model directory differs from machine to machine
68 | cached_model_dir = '/var/folders/vq/4fxns8l55gq8_msgygbyb51h0000gn/T/tf_transformers_cache/{}/'.format(model_name)
69 |
70 | # Copy cached model directory , to new working directory
71 | copy_tree(cached_model_dir, new_working_dir)
72 | ```
73 |
74 | ```{code-cell} ipython3
75 |
76 | ```
77 |
78 | #### 4. Push the model to the hub
79 |
80 | ```{code-cell} ipython3
81 | subprocess.run(["git-lfs", "track", "*"])
82 | subprocess.run(["git", "add", "."])
83 | subprocess.run(["git", "commit", "-m", "Pushing new model {}".format(model_name)]) # Commit message
84 | subprocess.run(["git", "push"])
85 |
86 | # Change back to original cwd
87 | os.chdir("{}".format(cwd))
88 | ```
89 |
90 | ```{code-cell} ipython3
91 |
92 | ```
93 |
--------------------------------------------------------------------------------
/docs/source/tutorials/sample.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "source": [],
7 | "outputs": [],
8 | "metadata": {}
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "source": [
13 | "This is a sample ipynb file to check jupytext hooks"
14 | ],
15 | "metadata": {}
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "source": [
21 | "import tensorflow as tf"
22 | ],
23 | "outputs": [],
24 | "metadata": {}
25 | }
26 | ],
27 | "metadata": {
28 | "orig_nbformat": 4,
29 | "language_info": {
30 | "name": "python"
31 | }
32 | },
33 | "nbformat": 4,
34 | "nbformat_minor": 2
35 | }
--------------------------------------------------------------------------------
/isort.cfg:
--------------------------------------------------------------------------------
1 | [settings]
2 | #force_single_line=True
3 | multi_line_output=3
4 | force_grid_wrap=0
5 | use_parentheses=True
6 | line_length=120
7 | profile = black
8 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | show_error_codes = True
3 | disable_error_code = attr-defined
4 |
5 |
6 |
7 | [mypy-six.*]
8 | ignore_missing_imports = True
9 |
10 | [mypy-google.*]
11 | ignore_missing_imports = True
12 |
--------------------------------------------------------------------------------
/patch_version.py:
--------------------------------------------------------------------------------
1 | # This script is adapted from python-semantic-release
2 | # flake8: noqa
3 |
4 | import sys
5 |
6 |
7 | def patch(required_version):
8 | """
9 | Write the new version to init and test file
10 | """
11 |
12 | required_version = required_version.replace('v', '') # Replace v1.0.0 to 1.0.0
13 | # Edit src
14 | with open('src/tf_transformers/__init__.py', 'w') as f:
15 | f.write('__version__ = "{}"\n'.format(required_version))
16 |
17 | # Edit test
18 | test_content = 'from tf_transformers import __version__\n\nversion = "{}"\ndef test_version():\n assert __version__ == version\n'
19 | with open('tests/test_tf_transformers.py', 'w') as f:
20 | f.write(test_content.format(required_version))
21 |
22 |
23 | if __name__ == '__main__':
24 | version = sys.argv[-1] # last cli argument
25 | patch(version)
26 |
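# Usage sketch (illustrative): running `python patch_version.py v1.2.3` rewrites
# src/tf_transformers/__init__.py and tests/test_tf_transformers.py with version "1.2.3".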
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "tf-transformers"
3 | version = "v2.0.0"
4 | description = "NLP with Transformer based models on Tensorflow 2.0"
5 | authors = ["Sarath R Nair "]
6 | maintainers = ["Sarath R Nair "]
7 | license = "MIT"
8 | readme = "README.md"
9 | homepage = ""
10 | repository = ""
11 | documentation = ""
12 | keywords = [
13 | "tensorflow",
14 | "transformers",
15 | "nlp",
16 | "keras",
17 | "bert",
18 | "deep learning"
19 | ]
20 |
21 | [tool.poetry.dependencies]
22 | python = "^3.9"
23 | tqdm = "^4.62.3"
24 | transformers = "^4.15.0"
25 | sentencepiece = "^0.1.96"
26 | absl-py = "^1.0.0"
27 | cffi = "^1.15.1"
28 |
29 |
30 | [tool.poetry.dev-dependencies]
31 | pytest = "^5.2"
32 | pylint = "^2.6.0"
33 | coverage = {extras = ["toml"], version = "^5.5"}
34 | pytest-cov = "^2.11.1"
35 | python-semantic-release = "^7.23.0"
36 |
37 | [tool.semantic_release]
38 | version_variable = [
39 | "src/tf_transformers/__init__.py:__version__",
40 | "tests/test_tf_transformers.py:version"
41 | ]
42 | version_toml = [
43 | "pyproject.toml:tool.poetry.version"
44 | ]
45 | version_pattern = [
46 | "README.md:version: v{version}"
47 |
48 | ]
49 | branch = "main"
50 | dist_path = "dist/"
51 | upload_to_pypi = false
52 | remove_dist = false
53 | build_command = "pip install poetry && poetry build"
54 |
55 | [build-system]
56 | requires = ["poetry-core>=1.0.4"]
57 | build-backend = "poetry.core.masonry.api"
58 |
59 | [tool.isort]
60 | profile = "black"
61 |
62 | [tool.black]
63 | skip-string-normalization = true
64 | line-length = 120
65 | target-version = ['py37']
66 | include = '\.pyi?$'
67 | exclude = '''
68 |
69 | (
70 | /(
71 | \.eggs # exclude a few common directories in the
72 | | \.git # root of the project
73 | | \.hg
74 | | \.mypy_cache
75 | | \.tox
76 | | \.venv
77 | | _build
78 | | buck-out
79 | | build
80 | | dist
81 | | tutorials
82 | )/
83 | | foo.py # also separately exclude a file named foo.py in
84 | # the root of the project
85 | )
86 | '''
87 |
--------------------------------------------------------------------------------
/research/c4_grammatical_correction/README.md:
--------------------------------------------------------------------------------
1 |
2 | ### Sentence Masked Language Modeling using TFText
3 |
4 | This folder has all necessary scripts to run the Sentence MLM using tensorflow text.
5 | Instead of masking words, we mask sentences (sequence of words)
6 |
7 | ### Advantage
8 |
9 | No need to prepare features in TFRecord format; we make use of dynamic MLM. All we need is a text file or a folder of text files, with each line corresponding to one text example.
10 |
11 | ### WandB
12 |
13 | By default we use WandB. If the environment variable ```WANDB_PROJECT``` is not set, wandb will be disabled.
14 |
15 | ``` export WANDB_PROJECT='t5-c4-grammatical-correction' ```
16 | ### Configuration (Hydra)
17 |
18 | All or most configurations can be managed using ```conf/config.yaml```. You can also override them from the command line.
19 |
20 | E.g., for TPU, the data must be in GCS and ```model_checkpoint_dir``` must be in GCS too.
21 |
22 | ``` python3 run_c4_grammar_correction.py \
23 | task.data_directory=gs://legacyai-bucket/c4_grammar_correction_data \
24 | task.train_batch_size=512 \
25 | trainer.dtype=bf16 \
26 | trainer.model_checkpoint_dir=gs://legacyai-bucket/t5_c4_lr_3e5 \
27 | trainer.steps_per_epoch=50000 \
28 | trainer.epochs=10 \
29 | trainer.strategy=tpu \
30 | trainer.tpu_address=legacyai-tpu-2 \
31 | optimizer.learning_rate=3e-5 \
32 | model.is_training=true \
33 | model.use_dropout=true
34 | ```
35 |
36 |
37 |
38 |
39 |
--------------------------------------------------------------------------------
/research/c4_grammatical_correction/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/research/c4_grammatical_correction/__init__.py
--------------------------------------------------------------------------------
/research/c4_grammatical_correction/conf/config.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | data_directory:
3 | max_seq_len: 64
4 | train_batch_size: 512
5 | trainer:
6 | dtype: bf16
7 | num_gpus: 0
8 | tpu_address:
9 | epochs: 10
10 | strategy: mirrored
11 | steps_per_epoch: 50000
12 | model_checkpoint_dir:
13 | global_norm: 1.0
14 | optimizer:
15 | learning_rate: 3e-5
16 | num_warmup_steps: 0.1
17 | decay_function: cosine
18 | adam_beta_1: 0.9
19 | adam_beta_2: 0.95
20 | adam_epsilon: 10e-8
21 | weight_decay_rate: 0.1
22 | optimizer_type: adamw
23 | loss_type:
24 | use_constant_lr: false
25 | model:
26 | is_training: false
27 | use_dropout: false
28 | num_layers: 12
29 |
--------------------------------------------------------------------------------
/research/c4_grammatical_correction/run_c4_grammar_correction.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """This is the main script to run Mix Language Model"""
18 | import os
19 | import warnings
20 |
21 | import hydra
22 | from absl import logging
23 | from omegaconf import DictConfig
24 | from train_c4_grammar_correction import run_train
25 |
26 | logging.set_verbosity("INFO")
27 |
28 | # We read the wandb project name from the WANDB_PROJECT environment variable
29 | WANDB_PROJECT = os.getenv('WANDB_PROJECT', None)
30 | use_wandb = True
31 | if WANDB_PROJECT is None:
32 | warnings.warn(
33 | "WANDB_PROJECT is not set; wandb logging will be disabled. "
34 | "Set it via `export WANDB_PROJECT=<project_name>` to enable it."
35 | )
36 | use_wandb = False
37 |
38 |
39 | @hydra.main(config_path="conf", config_name="config")
40 | def run(cfg: DictConfig) -> None:
41 | print("Config", cfg)
42 | config_dict = dict(cfg)
43 | # For TPU, we need to initialize it before tf text dataset
44 | # starts triggering. Hack
45 | if cfg.trainer.strategy == 'tpu':
46 | from model import get_trainer
47 |
48 | distribution_strategy = 'tpu'
49 | num_gpus = 0
50 | tpu_address = cfg.trainer.tpu_address
51 | get_trainer(
52 | distribution_strategy=distribution_strategy,
53 | num_gpus=num_gpus,
54 | tpu_address=tpu_address,
55 | dtype=cfg.trainer.dtype,
56 | ) # noqa
57 |
58 | if use_wandb:
59 | import wandb
60 |
61 | wandb.init(project=WANDB_PROJECT, config=config_dict, sync_tensorboard=True)
62 | history = run_train(cfg, wandb)
63 | else:
64 | # Set wandb = None
65 | history = run_train(cfg, None)
66 | return history
67 |
68 |
69 | if __name__ == "__main__":
70 | history = run()
71 |
--------------------------------------------------------------------------------
/research/diffusion/beta_schedule.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 |
5 |
6 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
7 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
8 | warmup_time = int(num_diffusion_timesteps * warmup_frac)
9 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
10 | return betas
11 |
12 |
13 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
14 | """
15 | Create a beta schedule that discretizes the given alpha_t_bar function,
16 | which defines the cumulative product of (1-beta) over time from t = [0,1].
17 |
18 | :param num_diffusion_timesteps: the number of betas to produce.
19 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
20 | produces the cumulative product of (1-beta) up to that
21 | part of the diffusion process.
22 | :param max_beta: the maximum beta to use; use values lower than 1 to
23 | prevent singularities.
24 | """
25 | betas = []
26 | for i in range(num_diffusion_timesteps):
27 | t1 = i / num_diffusion_timesteps
28 | t2 = (i + 1) / num_diffusion_timesteps
29 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
30 | return np.array(betas)
31 |
32 |
33 | def get_beta_schedule(beta_schedule, num_diffusion_timesteps, beta_start=0.0001, beta_end=0.02):
34 | if beta_schedule == 'quad':
35 | betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps, dtype=np.float64) ** 2
36 | elif beta_schedule == 'linear':
37 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
38 | elif beta_schedule == 'warmup10':
39 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
40 | elif beta_schedule == 'warmup50':
41 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
42 | elif beta_schedule == 'const':
43 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
44 | elif beta_schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1
45 | betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
46 | elif beta_schedule == 'cosine':
47 | betas = betas_for_alpha_bar(
48 | num_diffusion_timesteps,
49 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
50 | )
51 | else:
52 | raise NotImplementedError(beta_schedule)
53 |
54 | assert betas.shape == (num_diffusion_timesteps,)
55 | return betas
56 |
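# Usage sketch (illustrative): the betas drive the diffusion forward process; the
# cumulative products of (1 - beta) give the alpha_bar schedule described above.
if __name__ == "__main__":
    betas = get_beta_schedule('cosine', num_diffusion_timesteps=1000)
    alphas_cumprod = np.cumprod(1.0 - betas)  # alpha_bar_t for t = 1..T
    print(betas.shape, float(alphas_cumprod[-1]))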
--------------------------------------------------------------------------------
/research/diffusion/time_embedding_layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | class TimeEmbedding(tf.keras.layers.Layer):
5 | """Creates a sinusoidal embedding.
6 |
7 | This layer creates a sinusoidal embedding as described in "Attention is All you need":
8 |
9 | """
10 |
11 | def __init__(
12 | self,
13 | n_channels,
14 | initializer="glorot_uniform",
15 | bias_initializer='zeros',
16 | name="time_embeddings",
17 | activation='swish',
18 | use_bias=True,
19 | dtype=tf.float32,
20 | **kwargs,
21 | ):
22 | """
23 | Args:
24 | n_channels ([int]): Similar to embedding size
25 | scale_factor ([int]): How much to scale embedding size for next dense layer
26 | initializer (str, optional): The initializer to use for the
27 | embedding weights. Defaults to "glorot_uniform".
28 | name (str, optional): name of the layer. Defaults to "time_embeddings".
29 | dtype ([type], optional): [description]. Defaults to tf.float32.
30 | """
31 | super(TimeEmbedding, self).__init__(name=name, dtype=dtype, **kwargs)
32 |
33 | assert (n_channels % 2) == 0
34 | self._n_channels = n_channels
35 | self._initializer = initializer
36 | self._bias_initializer = bias_initializer
37 | self._dtype = dtype
38 |
39 | def get_config(self):
40 | """Config based on init arguments
41 |
42 | Returns:
43 | [dict]: Dict of all init arguments
44 | """
45 | config = {
46 | "n_channels": self._n_channels,
47 | "initializer": self._initializer,
48 | "name": self._name,
49 | "dtype": self._dtype,
50 | }
51 | base_config = super(TimeEmbedding, self).get_config()
52 | return dict(list(base_config.items()) + list(config.items()))
53 |
54 | def call(self, timesteps):
55 | """Call
56 |
57 | Args:
58 | timesteps ([tf.Tensor]): 1D integer timesteps of shape (B,), where B is the batch size
59 | eg: tf.constant([0, 10, 500])
60 |
61 | Returns:
62 | [tf.Tensor]: embeddings 2D (B x n_channels)
63 | """
64 | half_dim = self._n_channels // 2
65 | emb = tf.math.log(10000.0) / (half_dim - 1)
66 | emb = tf.exp(tf.range(half_dim, dtype=self._dtype) * -emb) # 1-D vector of size half_dim
67 | emb = tf.cast(tf.expand_dims(timesteps, axis=1), self._dtype) * tf.expand_dims(emb, axis=0) # B x half_dim
68 | emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=1) # B x self._n_channels
69 |
70 | return emb
71 |
--------------------------------------------------------------------------------
/research/diffusion/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | # As per original diffusion code
5 | def get_initializer(scale=0):
6 | if scale == 0:
7 | scale = 1e-10
8 | initializer = tf.keras.initializers.VarianceScaling(scale=scale, mode='fan_avg', distribution='uniform')
9 | return initializer
10 |
--------------------------------------------------------------------------------
/research/glue/conf/config.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_batch_size: 32
3 | eval_batch_size: 64
4 | take_sample: false
5 | max_seq_length: 128
6 | static_padding: false
7 | trainer:
8 | dtype: fp32
9 | num_gpus: 2
10 | tpu_address:
11 | epochs: 3
12 | strategy: mirrored
13 | optimizer:
14 | learning_rate: 2e-5
15 | loss_type:
16 | model:
17 | is_training: true
18 | use_dropout: true
19 |
--------------------------------------------------------------------------------
/research/glue/conf/glue/cola.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: cola
3 | data:
4 | name: cola
5 | num_classes: 2
6 | max_seq_length: 256
7 |
--------------------------------------------------------------------------------
/research/glue/conf/glue/mnli.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: mnli
3 | data:
4 | name: mnli
5 | num_classes: 3
6 | max_seq_length: 256
7 |
--------------------------------------------------------------------------------
/research/glue/conf/glue/mrpc.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: mrpc
3 | data:
4 | name: mrpc
5 | num_classes: 2
6 | max_seq_length: 256
7 |
--------------------------------------------------------------------------------
/research/glue/conf/glue/qnli.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: qnli
3 | data:
4 | name: qnli
5 | num_classes: 2
6 | max_seq_length: 256
7 |
--------------------------------------------------------------------------------
/research/glue/conf/glue/qqp.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: qqp
3 | data:
4 | name: qqp
5 | num_classes: 2
6 | max_seq_length: 128
7 |
--------------------------------------------------------------------------------
/research/glue/conf/glue/rte.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: rte
3 | data:
4 | name: rte
5 | num_classes: 2
6 | max_seq_length: 256
7 |
--------------------------------------------------------------------------------
/research/glue/conf/glue/sst2.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: sst2
3 | data:
4 | name: sst2
5 | num_classes: 2
6 | max_seq_length: 256
7 |
--------------------------------------------------------------------------------
/research/glue/conf/glue/stsb.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | name: stsb
3 | data:
4 | name: stsb
5 | num_classes: 1
6 | max_seq_length: 128
7 |
--------------------------------------------------------------------------------
/research/glue/model.py:
--------------------------------------------------------------------------------
1 | from transformers import AlbertTokenizer
2 |
3 | from tf_transformers.core import Trainer
4 | from tf_transformers.models import AlbertModel as Model
5 | from tf_transformers.optimization import create_optimizer
6 |
7 | MODEL_NAME = "albert-base-v2"
8 |
9 |
10 | def get_model(return_all_layer_outputs, is_training, use_dropout):
11 | """Get the model"""
12 | model = Model.from_pretrained(
13 | MODEL_NAME, return_all_layer_outputs=return_all_layer_outputs, is_training=is_training, use_dropout=use_dropout
14 | )
15 | return model
16 |
17 |
18 | def get_tokenizer():
19 | """Get Tokenizer"""
20 | return AlbertTokenizer.from_pretrained(MODEL_NAME)
21 |
22 |
23 | def get_optimizer(learning_rate, examples, batch_size, epochs):
24 | """Get optimizer"""
25 | steps_per_epoch = int(examples / batch_size)
26 | num_train_steps = steps_per_epoch * epochs
27 | warmup_steps = int(0.1 * num_train_steps)
28 |
29 | def optimizer_fn():
30 | optimizer, learning_rate_fn = create_optimizer(learning_rate, num_train_steps, warmup_steps)
31 | return optimizer
32 |
33 | return optimizer_fn
34 |
35 |
36 | def get_trainer(distribution_strategy, num_gpus=0, tpu_address=None):
37 | """Get Trainer"""
38 | trainer = Trainer(distribution_strategy, num_gpus=num_gpus, tpu_address=tpu_address)
39 | return trainer
40 |
--------------------------------------------------------------------------------
/research/glue/run_mnli_mismatched.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """Run MNLI Mismatched validation"""
18 |
19 | import os
20 |
21 | import datasets
22 | from model import get_tokenizer
23 |
24 |
25 | def run_mnli_mismatched_evaluation(
26 | model, model_dir, write_tfrecord, read_tfrecord, metric_callback, number_of_checkpoints, max_seq_length
27 | ):
28 | """MNLI Mismatched evaluation"""
29 |
30 | data = datasets.load_dataset("glue", 'mnli')
31 |
32 | # Validation mismatched
33 | tokenizer = get_tokenizer()
34 | tfrecord_dir = "/tmp/glue/mnli_mismatched/"
35 | take_sample = False
36 | eval_batch_size = 32
37 |
38 | write_tfrecord(
39 | data["validation_mismatched"],
40 | max_seq_length,
41 | tokenizer,
42 | tfrecord_dir,
43 | mode="eval",
44 | take_sample=take_sample,
45 | verbose=1000,
46 | )
47 |
48 | # Read TFRecords Validation
49 | eval_tfrecord_dir = os.path.join(tfrecord_dir, "eval")
50 | eval_dataset, total_eval_examples = read_tfrecord(
51 | eval_tfrecord_dir, eval_batch_size, shuffle=False, drop_remainder=False
52 | )
53 |
54 | results_per_epoch = []
55 | for i in range(1, number_of_checkpoints + 1):
56 | # Load checkpoint
57 | ckpt_path = os.path.join(model_dir, "ckpt-{}".format(i))
58 |
59 | model.load_checkpoint(checkpoint_path=ckpt_path)
60 |
61 | result = metric_callback({"model": model, "validation_dataset": eval_dataset})
62 | results_per_epoch.append(result)
63 |
64 | return results_per_epoch
65 |
--------------------------------------------------------------------------------
/research/long_block_sequencer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/research/long_block_sequencer/__init__.py
--------------------------------------------------------------------------------
/research/long_block_sequencer/conf/config.yaml:
--------------------------------------------------------------------------------
1 |
2 | task:
3 | max_seq_len: 4096
4 | decoder_seq_len: 256
5 | num_splits: 8
6 | use_gru_layer: false
7 | projection_dimension: 512
8 | train_batch_size: 8
9 |
10 | trainer:
11 | dtype: fp16
12 | num_gpus: 2
13 | tpu_address:
14 | epochs: 3
15 | strategy: mirrored
16 | model_checkpoint_dir:
17 | optimizer:
18 | learning_rate: 0.001
19 | loss_type:
20 | use_constant_lr: true
21 | model:
22 | num_layers: 12
23 | model_name: 't5-small'
24 | eval:
25 | eval_batch_size: 4
26 | model_checkpoint_dir:
27 | model_checkpoint_path:
28 | take_sample: false
29 | mode:
30 |
--------------------------------------------------------------------------------
/research/long_block_sequencer/run_long_block_sequencer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """This is the main script to run Long Block Sequencer Model"""
18 | import os
19 |
20 | import hydra
21 | from absl import logging
22 | from omegaconf import DictConfig
23 | from train_long_block_sequencer import run_train
24 |
25 | logging.set_verbosity("INFO")
26 |
27 | # We read the wandb project name from the WANDB_PROJECT environment variable
28 | WANDB_PROJECT = os.getenv('WANDB_PROJECT', None)
29 | use_wandb = True
30 | if WANDB_PROJECT is None:
31 | logging.info("Not using wandb, as `WANDB_PROJECT` has not been set via export.")
32 | use_wandb = False
33 |
34 |
35 | @hydra.main(config_path="conf", config_name="config")
36 | def run(cfg: DictConfig) -> None:
37 | print("Config", cfg)
38 | config_dict = dict(cfg)
39 | # For TPU, we need to initialize it before tf text dataset
40 | # starts triggering. Hack
41 | if cfg.trainer.strategy == 'tpu':
42 | from model import get_trainer
43 |
44 | distribution_strategy = 'tpu'
45 | num_gpus = 0
46 | tpu_address = cfg.trainer.tpu_address
47 | get_trainer(
48 | distribution_strategy=distribution_strategy,
49 | num_gpus=num_gpus,
50 | tpu_address=tpu_address,
51 | dtype=cfg.trainer.dtype,
52 | ) # noqa
53 |
54 | if use_wandb:
55 | import wandb
56 |
57 | wandb.init(project=WANDB_PROJECT, config=config_dict, sync_tensorboard=True)
58 | history = run_train(cfg, wandb)
59 | else:
60 | # Set wandb = None
61 | history = run_train(cfg, None)
62 | return history
63 |
64 |
65 | if __name__ == "__main__":
66 | history = run()
67 |
--------------------------------------------------------------------------------
/research/masked_language_model/README.md:
--------------------------------------------------------------------------------
1 |
2 | ### Masked Language Modeling using TFText
3 |
4 | This folder has all necessary scripts to run the MLM using tensorflow text.
5 |
6 | ### Advantage
7 |
8 | No need to prepare features in TFRecord format; we make use of dynamic MLM. All we need is a text file or a folder of text files, with each line corresponding to one text example.
9 |
10 | ### Configuration (Hydra)
11 |
12 | All or most configurations can be managed using ```conf/config.yaml```. You can also override them from the command line.
13 |
14 | E.g., for TPU, the data must be in GCS and ```model_checkpoint_dir``` must be in GCS too.
15 |
16 | ```python3 run_mlm.py data.data_directory=$GCP_BUCKET/data/ trainer.model_checkpoint_dir=$GCP_BUCKET/model```
17 |
18 | ### WandB
19 |
20 | By default we use WandB; it is enabled only when the ```WANDB_PROJECT``` environment variable is set (see ```run_mlm.py```).
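For example (the project name below is just an illustration):

``` export WANDB_PROJECT='mlm-pretraining' ```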
21 |
--------------------------------------------------------------------------------
/research/masked_language_model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/research/masked_language_model/__init__.py
--------------------------------------------------------------------------------
/research/masked_language_model/conf/config.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | data_directory:
3 | train_batch_size: 32
4 | task:
5 | max_seq_len: 128
6 | max_predictions_per_seq: 20
7 | trainer:
8 | dtype: fp32
9 | num_gpus: 2
10 | tpu_address:
11 | epochs: 3
12 | strategy: mirrored
13 | steps_per_epoch: 10000
14 | model_checkpoint_dir:
15 | optimizer:
16 | learning_rate: 3e-5
17 | warmup_rate: 0.2
18 | loss_type:
19 | use_constant_lr: false
20 | model:
21 | is_training: true
22 | use_dropout: true
23 | num_layers: 24
24 |
--------------------------------------------------------------------------------
/research/masked_language_model/dataset_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from random import shuffle
3 |
4 | import tensorflow as tf
5 |
6 |
7 | def get_dataset(data_directory, masked_lm_map_fn, batch_size):
8 | """Convert text to tf.data.Dataset after map fn
9 |
10 | Args:
11 | data_directory ([type]): [description]
12 | masked_lm_map_fn ([type]): [description]
13 | batch_size ([type]): [description]
14 |
15 | Returns:
16 | [type]: [description]
17 | """
18 |
19 | def filter_out_empty_mask(x, y):
20 | """When an example doesn't have multiple sentences\
21 | there won't be any masked sentence. Ignore those examples,
22 | as nothing to predict.
23 | """
24 | return tf.greater(tf.reduce_sum(tf.cast(tf.not_equal(x['masked_lm_positions'], 0), tf.int32)), 0)
25 |
26 | all_text_files = tf.io.gfile.glob(os.path.join(data_directory, '*.txt'))
27 | shuffle(all_text_files)
28 | ds = tf.data.TextLineDataset(all_text_files)
29 | # Our data has sentences joined by '__||__'. So, for word-based MLM
30 | # we need to split on '__||__' and join everything back into a single sentence
31 | # tf.strings.regex_replace not working as expected
32 | ds = ds.map(lambda x: tf.strings.split(x, '__||__'), num_parallel_calls=tf.data.AUTOTUNE)
33 | ds = ds.map(lambda x: tf.strings.reduce_join([x], separator=' '), num_parallel_calls=tf.data.AUTOTUNE)
34 |
35 | # We need to add the text as dict
36 | ds = ds.map(lambda x: {'text': x}, num_parallel_calls=tf.data.AUTOTUNE)
37 |
38 | # Do MLM
39 | ds = ds.map(masked_lm_map_fn, num_parallel_calls=tf.data.AUTOTUNE)
40 |
41 | # Filter out examples that do not contain at least one MASK sentence
42 | ds = ds.filter(filter_out_empty_mask)
43 |
44 | # # Shuffle and Prefetch
45 | ds = ds.shuffle(100, reshuffle_each_iteration=True).prefetch(buffer_size=tf.data.AUTOTUNE)
46 |
47 | # Batch
48 | ds = ds.batch(batch_size, drop_remainder=True)
49 |
50 | # Auto SHARD
51 | options = tf.data.Options()
52 | options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO
53 | ds = ds.with_options(options)
54 |
55 | return ds
56 |
--------------------------------------------------------------------------------
/research/masked_language_model/model.py:
--------------------------------------------------------------------------------
1 | from tf_transformers.core import Trainer
2 | from tf_transformers.losses.loss_wrapper import get_lm_loss
3 | from tf_transformers.models import (
4 | BigBirdRobertaTokenizerTFText,
5 | GPT2Model,
6 | MaskedLMModel,
7 | )
8 | from tf_transformers.optimization import create_optimizer
9 |
10 | MODEL_NAME = 'gpt2'
11 | TOKENIZER_NAME = "google/bigbird-roberta-large"
12 |
13 |
14 | def get_model(return_all_layer_outputs, is_training, use_dropout, vocab_size):
15 | """Get the model from model function"""
16 |
17 | def model_fn():
18 | # We use GPT2 Style model, but we use BigBird Roberta Tokenizer
19 | config = GPT2Model.get_config(MODEL_NAME)
20 | # We update the vocab_size for that reason
21 | config['vocab_size'] = vocab_size
22 | model = GPT2Model.from_config(config, mask_mode='user_defined', return_layer=True)
23 | model = MaskedLMModel(
24 | model,
25 | use_extra_mlm_layer=False,
26 | hidden_size=config['embedding_size'],
27 | layer_norm_epsilon=config['layer_norm_epsilon'],
28 | )
29 | model = model.get_model()
30 | return model
31 |
32 | return model_fn
33 |
34 |
35 | def get_tokenizer():
36 | tokenizer_layer = BigBirdRobertaTokenizerTFText.from_pretrained(TOKENIZER_NAME)
37 | return tokenizer_layer
38 |
39 |
40 | def get_optimizer(learning_rate, steps_per_epoch, epochs, warmup_rate, use_constant_lr=False):
41 | """Get AdamW optimizer"""
42 |
43 | # Total steps over all epochs
44 | num_train_steps = steps_per_epoch * epochs
45 | warmup_steps = int(warmup_rate * num_train_steps)
46 |
47 | def optimizer_fn():
48 | if use_constant_lr:
49 | from tf_transformers.optimization.adam_weighted import AdamWeightDecay
50 |
51 | optimizer = AdamWeightDecay(learning_rate=learning_rate)
52 | return optimizer
53 |
54 | optimizer, learning_rate_fn = create_optimizer(learning_rate, num_train_steps, warmup_steps)
55 | return optimizer
56 |
57 | return optimizer_fn
58 |
59 |
60 | def get_loss(loss_type):
61 | """Get MLM Loss"""
62 | return get_lm_loss(loss_type=loss_type)
63 |
64 |
65 | def get_trainer(distribution_strategy, dtype, num_gpus=0, tpu_address=None):
66 | """Get Trainer"""
67 | trainer = Trainer(distribution_strategy, dtype=dtype, num_gpus=num_gpus, tpu_address=tpu_address)
68 | return trainer
69 |
70 |
71 | def get_hf_tokenizer():
72 | """Get HuggingFace Tokenizer"""
73 | from transformers import BigBirdTokenizer
74 |
75 | tokenizer = BigBirdTokenizer.from_pretrained(TOKENIZER_NAME)
76 | return tokenizer
77 |
--------------------------------------------------------------------------------
/research/masked_language_model/run_mlm.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """This is the main script to run MLM"""
18 | import os
19 | import warnings
20 |
21 | import hydra
22 | from absl import logging
23 | from omegaconf import DictConfig
24 | from train_mlm import run_train
25 |
26 | logging.set_verbosity("INFO")
27 |
28 | # We read the wandb project name from the WANDB_PROJECT environment variable
29 | WANDB_PROJECT = os.getenv('WANDB_PROJECT', None)
30 | use_wandb = True
31 | if WANDB_PROJECT is None:
32 | warnings.warn(
33 | "WANDB_PROJECT is not set; wandb logging will be disabled. "
34 | "Set it via `export WANDB_PROJECT=<project_name>` to enable it."
35 | )
36 | use_wandb = False
37 |
38 |
39 | @hydra.main(config_path="conf", config_name="config")
40 | def run(cfg: DictConfig) -> None:
41 | print("Config", cfg)
42 | config_dict = dict(cfg)
43 | # For TPU, we need to initialize it before tf text dataset
44 | # starts triggering. Hack
45 | if cfg.trainer.strategy == 'tpu':
46 | from model import get_trainer
47 |
48 | distribution_strategy = 'tpu'
49 | num_gpus = 0
50 | tpu_address = cfg.trainer.tpu_address
51 | get_trainer(
52 | distribution_strategy=distribution_strategy,
53 | num_gpus=num_gpus,
54 | tpu_address=tpu_address,
55 | dtype=cfg.trainer.dtype,
56 | ) # noqa
57 |
58 | if use_wandb:
59 | import wandb
60 |
61 | wandb.init(project=WANDB_PROJECT, config=config_dict, sync_tensorboard=True)
62 | history = run_train(cfg, wandb)
63 | else:
64 | # Set wandb = None
65 | history = run_train(cfg, None)
66 | return history
67 |
68 |
69 | if __name__ == "__main__":
70 | history = run()
71 |
--------------------------------------------------------------------------------
/research/masked_language_model_old/1_data_to_text.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import hydra
5 | from omegaconf import DictConfig, OmegaConf
6 |
7 | from tf_transformers.data.utils import hf_dump_chars_to_textfile
8 |
9 | # A logger for this file
10 | log = logging.getLogger(__name__)
11 |
12 |
13 | def write_data(cfg):
14 | """Load dataset and write to txt file"""
15 | output_file = cfg.data.output_text_file
16 | if os.path.isfile(output_file):
17 | raise FileExistsError()
18 |
19 | from datasets import load_dataset
20 |
21 | if cfg.data.version:
22 | dataset = load_dataset(cfg.data.name, cfg.data.version)
23 | else:
24 | dataset = load_dataset(cfg.data.name)
25 |
26 | split = cfg.data.split # train, test, dev
27 | data_keys = cfg.data.keys # text
28 | hf_dump_chars_to_textfile(output_file, dataset[split], data_keys, max_char=-1)
29 |
30 |
31 | @hydra.main(config_path="config", config_name="data_config")
32 | def run(cfg: DictConfig) -> None:
33 | print(OmegaConf.to_yaml(cfg))
34 | write_data(cfg)
35 |
36 |
37 | if __name__ == "__main__":
38 | run()
39 |
--------------------------------------------------------------------------------
/research/masked_language_model_old/2_text_to_features.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import tensorflow as tf
3 | from omegaconf import DictConfig, OmegaConf
4 |
5 | from tf_transformers.data import TFWriter
6 | from tf_transformers.text import SentencepieceTokenizer
7 |
8 |
9 | def load_tokenizer(cfg):
10 | """Load tf text based tokenizer"""
11 | model_file_path = cfg.tokenizer.model_file_path
12 | do_lower_case = cfg.tokenizer.do_lower_case
13 | special_tokens = cfg.tokenizer.special_tokens
14 |
15 | tokenizer_layer = SentencepieceTokenizer(
16 | model_file_path=model_file_path, lower_case=do_lower_case, special_tokens=special_tokens
17 | )
18 |
19 | return tokenizer_layer
20 |
21 |
22 | def create_tfrecords(cfg):
23 | """Prepare tfrecords"""
24 | schema = {
25 | "input_ids": ("var_len", "int"),
26 | }
27 |
28 | tfrecord_output_dir = cfg.data.tfrecord_output_dir
29 | tfrecord_filename = cfg.data.tfrecord_filename
30 | tfrecord_nfiles = cfg.data.tfrecord_nfiles
31 | tfrecord_mode = cfg.data.tfrecord_mode
32 | tfrecord_overwrite = cfg.data.tfrecord_overwrite
33 |
34 | input_text_files = cfg.data.input_text_files
35 | batch_size = cfg.data.batch_size
36 |
37 | tfwriter = TFWriter(
38 | schema=schema,
39 | file_name=tfrecord_filename,
40 | model_dir=tfrecord_output_dir,
41 | tag=tfrecord_mode,
42 | n_files=tfrecord_nfiles,
43 | overwrite=tfrecord_overwrite,
44 | )
45 |
46 | dataset = tf.data.TextLineDataset(input_text_files)
47 |
48 | def text_normalize(line):
49 | """Exclude empty string"""
50 | line = tf.strings.strip(line)
51 | return tf.not_equal(tf.strings.length(line), 0)
52 |
53 | dataset = dataset.filter(text_normalize)
54 | dataset = dataset.apply(tf.data.experimental.unique())
55 | dataset = dataset.batch(batch_size, drop_remainder=False)
56 |
57 | def parse_train():
58 | import tqdm
59 |
60 | tokenizer_layer = load_tokenizer(cfg)
61 | for batch_input in tqdm.tqdm(dataset):
62 | batch_input = {'text': [batch_input]}
63 | batch_tokenized = tokenizer_layer(batch_input)["input_ids"].to_list()
64 | for example_input_ids in batch_tokenized:
65 | yield {"input_ids": example_input_ids}
66 |
67 | # Process
68 | tfwriter.process(parse_fn=parse_train())
69 |
70 |
71 | @hydra.main(config_path="config", config_name="tfrecord_config")
72 | def run(cfg: DictConfig) -> None:
73 | print(OmegaConf.to_yaml(cfg))
74 | create_tfrecords(cfg)
75 |
76 |
77 | if __name__ == "__main__":
78 | run()
79 |
--------------------------------------------------------------------------------
/research/masked_language_model_old/README.MD:
--------------------------------------------------------------------------------
1 |
2 | # Prepare Data
3 |
4 | python3 1_data_to_text.py data.name=wikipedia data.version=20200501.en data.output_text_file=/home/sidhu/Datasets/data/wikipedia.txt
5 | python3 1_data_to_text.py data.name=bookcorpus data.output_text_file=/home/sidhu/datasets/bookcorpus.txt
6 |
7 | # Prepare tfrecords
8 |
9 | nohup python3 2_text_to_features.py tokenizer.model_file_path=/home/sidhu/Datasets/data/t5_extended_vocab/new_spiece.model tokenizer.do_lower_case=false data.tfrecord_output_dir=/home/sidhu/Datasets/data/wiki_tfrecords data.tfrecord_filename=wiki data.tfrecord_nfiles=10 data.input_text_files=[/home/sidhu/Datasets/data/wikipedia.txt] data.batch_size=1024 > wiki_tfrecord.log &
10 |
11 |
12 | Bookcorpus
13 |
14 | nohup python3 2_text_to_features.py tokenizer.model_file_path=/home/sidhu/Datasets/data/t5_extended_vocab/new_spiece.model tokenizer.do_lower_case=false data.tfrecord_output_dir=/home/sidhu/Datasets/data/bookcorpus_tfrecords data.tfrecord_filename=bookcorpus data.tfrecord_nfiles=10 data.input_text_files=[/home/sidhu/Datasets/data/bookcorpus.txt] data.batch_size=1024 > bookcorpus_tfrecord.log &
15 |
16 |
17 | python3 train_mlm.py data.tfrecord_path_list=["/home/sidhu/Datasets/bookcorpus_tfrecords", "/home/sidhu/Datasets/wiki_tfrecords"] \
18 | tokenizer.model_file_path=/home/Sidhu/Datasets/vocab/new_spiece.model
19 |
20 | python3 train_mlm.py tokenizer.model_file_path=/home/sidhu/Datasets/vocab/new_spiece.model \
21 | model.model_save_dir=/home/sidhu/Projects/joint_bert
22 |
--------------------------------------------------------------------------------
/research/masked_language_model_old/config/data_config.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | name: mc4
3 | version:
4 | split: train
5 | output_text_file: 'output.txt'
6 | data_keys: ['text']
7 |
--------------------------------------------------------------------------------
/research/masked_language_model_old/config/tfrecord_config.yaml:
--------------------------------------------------------------------------------
1 |
2 | tokenizer:
3 | model_file_path:
4 | do_lower_case: false
5 | special_tokens: ['[CLS]', '[MASK]', '', '', '']
6 | data:
7 | tfrecord_output_dir:
8 | tfrecord_filename:
9 | tfrecord_nfiles:
10 | tfrecord_mode: train
11 | tfrecord_overwrite: false
12 | input_text_files:
13 | batch_size: 1024
14 |
--------------------------------------------------------------------------------
/research/masked_language_model_old/config/train_config.yaml:
--------------------------------------------------------------------------------
1 |
2 | tokenizer:
3 | vocab_size: 32002
4 | cls_token: '[CLS]'
5 | mask_token: '[MASK]'
6 | pad_token: ''
7 | sep_token: ''
8 | unk_token: ''
9 | model_file_path: 'new_spiece.model'
10 | do_lower_case: false
11 | special_tokens: ['[CLS]', '[MASK]', '', '', '']
12 | data:
13 | max_seq_len: 128
14 | max_predictions_per_batch: 20
15 | batch_size: 512
16 | min_sen_len:
17 |     tfrecord_path_list: ['path1', 'path2']
18 | model:
19 | optimizer:
20 | learning_rate: 5e-5
21 | train_steps: 2000000
22 | warmup_steps: 60000
23 | optimizer_type: adamw
24 | loss:
25 | loss_type: joint
26 | epochs: 2
27 | steps_per_epoch: 200
28 | callback_steps: [100]
29 | model_save_dir:
30 | trainer:
31 | device_type: tpu
32 | device_address: local
33 | dtype: bf16
34 |
--------------------------------------------------------------------------------
/research/mix_language_model/README.MD:
--------------------------------------------------------------------------------
1 |
2 | ### Sentence Masked Language Modeling using TFText
3 |
4 | This folder has all the necessary scripts to run Sentence MLM using TensorFlow Text.
5 | Instead of masking words, we mask sentences (sequences of words).
6 |
7 | ### Advantage
8 |
9 | No need to prepare features in TFRecord format; we make use of dynamic MLM. All we need is a text file or a folder of text files, where each line corresponds to a piece of text.
10 |
11 | ### Configuration (Hydra)
12 |
13 | All or most of the configuration can be managed through ```conf/config.yaml```. You can also override it from the command line.
14 |
15 | E.g. for TPU, the data and ```model_checkpoint_dir``` need to be in GCS.
16 |
17 | ```python3 run_mlm.py task.data_directory= task.train_batch_size=128 trainer.dtype=bf16 trainer.model_checkpoint_dir= trainer.steps_per_epoch=50000 trainer.callback_steps=10000 trainer.epochs=20 trainer.strategy=tpu trainer.tpu_address= optimizer.learning_rate=5e-4```
18 |
19 | ### WandB
20 |
21 | By default we use WandB. If the environment variable ```WANDB_PROJECT``` is not set, WandB will be disabled.
22 |
--------------------------------------------------------------------------------
/research/mix_language_model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/research/mix_language_model/__init__.py
--------------------------------------------------------------------------------
/research/mix_language_model/conf/config.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | data_directory:
3 | max_seq_len: 1024
4 | max_predictions_per_seq: 1024
5 | minimum_prefix_length: 900
6 | train_batch_size: 128
7 | trainer:
8 | dtype: bf16
9 | num_gpus: 0
10 | tpu_address:
11 | epochs: 20
12 | strategy: mirrored
13 | steps_per_epoch: 50000
14 | model_checkpoint_dir:
15 | global_norm: 1.0
16 | optimizer:
17 | learning_rate: 0.006
18 | num_warmup_steps: 0.1
19 | decay_function: cosine
20 | adam_beta_1: 0.9
21 | adam_beta_2: 0.95
22 | adam_epsilon: 10e-8
23 | weight_decay_rate: 0.1
24 | optimizer_type: adamw
25 | loss_type:
26 | use_constant_lr: false
27 | model:
28 | is_training: false
29 | use_dropout: false
30 | num_layers: 12
31 |
--------------------------------------------------------------------------------
/research/mix_language_model/run_mix_lm.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """This is the main script to run Mix Language Model"""
18 | import os
19 | import warnings
20 |
21 | import hydra
22 | from absl import logging
23 | from omegaconf import DictConfig
24 | from train_mix_lm import run_train
25 |
26 | logging.set_verbosity("INFO")
27 |
28 | # We read WANDB_PROJECT from the environment variable
29 | WANDB_PROJECT = os.getenv('WANDB_PROJECT', None)
30 | use_wandb = True
31 | if WANDB_PROJECT is None:
32 | warnings.warn(
33 |         "WANDB_PROJECT environment variable is not set.\
34 |         Set it with `export WANDB_PROJECT=<project_name>` to enable wandb."
35 | )
36 | use_wandb = False
37 |
38 |
39 | @hydra.main(config_path="conf", config_name="config")
40 | def run(cfg: DictConfig) -> None:
41 | print("Config", cfg)
42 | config_dict = dict(cfg)
43 |     # For TPU, we need to initialize it before the tf.text dataset
44 |     # pipeline starts triggering (hack).
45 | if cfg.trainer.strategy == 'tpu':
46 | from model import get_trainer
47 |
48 | distribution_strategy = 'tpu'
49 | num_gpus = 0
50 | tpu_address = cfg.trainer.tpu_address
51 | get_trainer(
52 | distribution_strategy=distribution_strategy,
53 | num_gpus=num_gpus,
54 | tpu_address=tpu_address,
55 | dtype=cfg.trainer.dtype,
56 | ) # noqa
57 |
58 | if use_wandb:
59 | import wandb
60 |
61 | wandb.init(project=WANDB_PROJECT, config=config_dict, sync_tensorboard=True)
62 | history = run_train(cfg, wandb)
63 | else:
64 | # Set wandb = None
65 | history = run_train(cfg, None)
66 | return history
67 |
68 |
69 | if __name__ == "__main__":
70 | history = run()
71 |
--------------------------------------------------------------------------------
/research/mix_language_model_old/1_data_to_text.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import hydra
5 | from omegaconf import DictConfig, OmegaConf
6 |
7 | from tf_transformers.data.utils import hf_dump_chars_to_textfile
8 |
9 | # A logger for this file
10 | log = logging.getLogger(__name__)
11 |
12 |
13 | def write_data(cfg):
14 | """Load dataset and write to txt file"""
15 | output_file = cfg.data.output_text_file
16 | if os.path.isfile(output_file):
17 |         raise FileExistsError("Output file {} already exists.".format(output_file))
18 |
19 | from datasets import load_dataset
20 |
21 | if cfg.data.version:
22 | dataset = load_dataset(cfg.data.name, cfg.data.version)
23 | else:
24 | dataset = load_dataset(cfg.data.name)
25 |
26 | split = cfg.data.split # train, test, dev
27 |     data_keys = cfg.data.data_keys  # e.g. ['text'], as defined in data_config.yaml
28 | hf_dump_chars_to_textfile(output_file, dataset[split], data_keys, max_char=-1)
29 |
30 |
31 | @hydra.main(config_path="config", config_name="data_config")
32 | def run(cfg: DictConfig) -> None:
33 | print(OmegaConf.to_yaml(cfg))
34 | write_data(cfg)
35 |
36 |
37 | if __name__ == "__main__":
38 | run()
39 |
--------------------------------------------------------------------------------
/research/mix_language_model_old/2_text_to_features.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import tensorflow as tf
3 | from omegaconf import DictConfig, OmegaConf
4 |
5 | from tf_transformers.data import TFWriter
6 | from tf_transformers.text import SentencepieceTokenizer
7 |
8 |
9 | def load_tokenizer(cfg):
10 | """Load tf text based tokenizer"""
11 | model_file_path = cfg.tokenizer.model_file_path
12 | do_lower_case = cfg.tokenizer.do_lower_case
13 | special_tokens = cfg.tokenizer.special_tokens
14 |
15 | tokenizer_layer = SentencepieceTokenizer(
16 | model_file_path=model_file_path, lower_case=do_lower_case, special_tokens=special_tokens
17 | )
18 |
19 | return tokenizer_layer
20 |
21 |
22 | def create_tfrecords(cfg):
23 | """Prepare tfrecords"""
24 | schema = {
25 | "input_ids": ("var_len", "int"),
26 | }
27 |
28 | tfrecord_output_dir = cfg.data.tfrecord_output_dir
29 | tfrecord_filename = cfg.data.tfrecord_filename
30 | tfrecord_nfiles = cfg.data.tfrecord_nfiles
31 | tfrecord_mode = cfg.data.tfrecord_mode
32 | tfrecord_overwrite = cfg.data.tfrecord_overwrite
33 |
34 | input_text_files = cfg.data.input_text_files
35 | batch_size = cfg.data.batch_size
36 |
37 | tfwriter = TFWriter(
38 | schema=schema,
39 | file_name=tfrecord_filename,
40 | model_dir=tfrecord_output_dir,
41 | tag=tfrecord_mode,
42 | n_files=tfrecord_nfiles,
43 | overwrite=tfrecord_overwrite,
44 | )
45 |
46 | dataset = tf.data.TextLineDataset(input_text_files)
47 |
48 | def text_normalize(line):
49 | """Exclude empty string"""
50 | line = tf.strings.strip(line)
51 | return tf.not_equal(tf.strings.length(line), 0)
52 |
53 | dataset = dataset.filter(text_normalize)
54 | dataset = dataset.apply(tf.data.experimental.unique())
55 | dataset = dataset.batch(batch_size, drop_remainder=False)
56 |
57 | def parse_train():
58 | import tqdm
59 |
60 | tokenizer_layer = load_tokenizer(cfg)
61 | for batch_input in tqdm.tqdm(dataset):
62 | batch_input = {'text': [batch_input]}
63 | batch_tokenized = tokenizer_layer(batch_input)["input_ids"].to_list()
64 | for example_input_ids in batch_tokenized:
65 | yield {"input_ids": example_input_ids}
66 |
67 | # Process
68 | tfwriter.process(parse_fn=parse_train())
69 |
70 |
71 | @hydra.main(config_path="config", config_name="tfrecord_config")
72 | def run(cfg: DictConfig) -> None:
73 | print(OmegaConf.to_yaml(cfg))
74 | create_tfrecords(cfg)
75 |
76 |
77 | if __name__ == "__main__":
78 | run()
79 |
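80 | # ---------------------------------------------------------------------------
81 | # Illustrative sketch (not part of the original script; the file pattern below
82 | # is a guess): records written with the {"input_ids": ("var_len", "int")} schema
83 | # above can be read back with plain TensorFlow like this.
84 | #
85 | # feature_spec = {"input_ids": tf.io.VarLenFeature(tf.int64)}
86 | #
87 | # def parse_example(serialized):
88 | #     parsed = tf.io.parse_single_example(serialized, feature_spec)
89 | #     return {"input_ids": tf.sparse.to_dense(parsed["input_ids"])}
90 | #
91 | # files = tf.io.gfile.glob(tfrecord_output_dir + "/*.tfrecord")
92 | # dataset = tf.data.TFRecordDataset(files).map(parse_example)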
--------------------------------------------------------------------------------
/research/mix_language_model_old/README.MD:
--------------------------------------------------------------------------------
1 |
2 | # Prepare Data
3 |
4 | python3 1_data_to_text.py data.name=wikipedia data.version= data.output_text_file=/home/Sidhu/datasets/wikipedia.txt
5 | python3 1_data_to_text.py data.name=bookcorpus data.output_text_file=/home/Sidhu/datasets/bookcorpus.txt
6 |
7 | # Prepare tfrecords
8 |
9 | nohup python3 2_text_to_features.py tokenizer.model_file_path=/home/sidhu/Datasets/vocab/new_spiece.model tokenizer.do_lower_case=false data.tfrecord_output_dir=/home/sidhu/Datasets/wiki_tfrecords data.tfrecord_filename=wiki data.tfrecord_nfiles=25 data.input_text_files=[/home/sidhu/Datasets/wikipedia.txt] data.batch_size=1024 > wiki_tfrecord.log &
10 |
11 |
12 | Bookcorpus
13 |
14 | nohup python3 2_text_to_features.py tokenizer.model_file_path=/home/sidhu/Datasets/vocab/new_spiece.model tokenizer.do_lower_case=false data.tfrecord_output_dir=/home/sidhu/Datasets/bookcorpus_tfrecords data.tfrecord_filename=bookcorpus data.tfrecord_nfiles=10 data.input_text_files=[/home/sidhu/Datasets/bookcorpus.txt] data.batch_size=1024 > bookcorpus_tfrecord.log &
15 |
16 |
17 | python3 train_mix_mlm.py tokenizer.model_file_path=/home/sidhu/Datasets/vocab/new_spiece.model \
18 | model.model_save_dir=/home/sidhu/Projects/joint_bert
19 |
--------------------------------------------------------------------------------
/research/mix_language_model_old/config/data_config.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | name: mc4
3 | version:
4 | split: train
5 | output_text_file: 'output.txt'
6 | data_keys: ['text']
7 |
--------------------------------------------------------------------------------
/research/mix_language_model_old/config/tfrecord_config.yaml:
--------------------------------------------------------------------------------
1 |
2 | tokenizer:
3 | model_file_path:
4 | do_lower_case: false
5 | special_tokens: ['[CLS]', '[MASK]', '', '', '']
6 | data:
7 | tfrecord_output_dir:
8 | tfrecord_filename:
9 | tfrecord_nfiles:
10 | tfrecord_mode: train
11 | tfrecord_overwrite: false
12 | input_text_files:
13 | batch_size: 1024
14 |
--------------------------------------------------------------------------------
/research/mix_language_model_old/config/train_config.yaml:
--------------------------------------------------------------------------------
1 |
2 | tokenizer:
3 | vocab_size: 32002
4 | cls_token: '[CLS]'
5 | mask_token: '[MASK]'
6 | pad_token: ''
7 | sep_token: ''
8 | unk_token: ''
9 | model_file_path: 'new_spiece.model'
10 | do_lower_case: false
11 | special_tokens: ['[CLS]', '[MASK]', '', '', '']
12 | data:
13 | max_seq_len: 128
14 | max_predictions_per_batch: 20
15 | batch_size: 2048
16 | min_sen_len:
17 |     tfrecord_path_list: ['path1', 'path2']
18 | model:
19 | optimizer:
20 | learning_rate: 1e-4
21 | train_steps: 2000000
22 | warmup_steps: 60000
23 | optimizer_type: adamw
24 | loss:
25 | loss_type:
26 | epochs: 2
27 | steps_per_epoch: 200
28 | callback_steps: [100]
29 | model_save_dir:
30 | trainer:
31 | device_type: tpu
32 | device_address: local
33 | dtype: bf16
34 |
--------------------------------------------------------------------------------
/research/sentence2vec/train.py:
--------------------------------------------------------------------------------
1 | import tensorflow_text as tf_text # noqa
2 | import wandb
3 | from dataset_loader import get_dataset
4 | from model import get_model, get_optimizer, loss_fn
5 |
6 | from tf_transformers.core import Trainer
7 | from tf_transformers.models import AlbertTokenizerTFText
8 |
9 | wandb.login()
10 |
11 | TPU_ADDRESS = 'legacyai-tpu-2'
12 | DTYPE = 'bf16'
13 |
14 | delim_regex_pattern = '\. ' # noqa
15 | window_length = 10
16 | minimum_sentences = 4
17 | batch_size = 512
18 | max_seq_length = 256
19 |
20 | learning_rate = 2e-5
21 | epochs = 50
22 | steps_per_epoch = 100000
23 | num_train_steps = steps_per_epoch * epochs
24 | num_warmup_steps = 0.1 * num_train_steps
25 | global_norm = 5.0
26 | optimizer_fn = get_optimizer(learning_rate, num_train_steps, num_warmup_steps, decay_function='linear')
27 |
28 | clip_logits = True
29 | use_random_base = False
30 | siamese = True
31 |
32 | model_checkpoint_dir = 'gs://legacyai-bucket/sentence2vec_1'
33 |
34 | WANDB_PROJECT = 'sentence2vec'
35 | config_dict = {}
36 | config_dict['learning_rate'] = learning_rate
37 | config_dict['steps_per_epoch'] = steps_per_epoch
38 | config_dict['epochs'] = epochs
39 | config_dict['num_train_steps'] = steps_per_epoch * epochs
40 | config_dict['num_warmup_steps'] = 0.1 * num_train_steps
41 | config_dict['global_norm'] = global_norm
42 | config_dict['model_checkpoint_dir'] = model_checkpoint_dir
43 | config_dict['clip_logits'] = clip_logits
44 | config_dict['use_random_base'] = use_random_base
45 |
46 | wandb.init(project=WANDB_PROJECT, config=config_dict, sync_tensorboard=True)
47 |
48 |
49 | trainer = Trainer(distribution_strategy='tpu', tpu_address=TPU_ADDRESS, dtype=DTYPE)
50 |
51 | tokenizer_layer = AlbertTokenizerTFText.from_pretrained("albert-base-v2", add_special_tokens=False)
52 | train_dataset = get_dataset(
53 | delim_regex_pattern, minimum_sentences, window_length, tokenizer_layer, max_seq_length, batch_size
54 | )
55 | model_fn = get_model(clip_logits, use_random_base, siamese)
56 |
57 |
58 | # Train
59 | training_loss_names = ['loss_cls', 'loss_mean', 'loss']
60 | history = trainer.run(
61 | model_fn=model_fn,
62 | optimizer_fn=optimizer_fn,
63 | train_dataset=train_dataset,
64 | train_loss_fn=loss_fn,
65 | epochs=epochs,
66 | steps_per_epoch=steps_per_epoch,
67 | model_checkpoint_dir=model_checkpoint_dir,
68 | batch_size=batch_size,
69 | training_loss_names=training_loss_names,
70 | repeat_dataset=True,
71 | wandb=wandb,
72 | clip_norm=global_norm,
73 | )
74 |
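75 | # Worked numbers for the schedule above (plain arithmetic on the constants in this file):
76 | #   num_train_steps  = steps_per_epoch * epochs = 100000 * 50 = 5,000,000
77 | #   num_warmup_steps = 0.1 * num_train_steps    = 500,000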
--------------------------------------------------------------------------------
/research/sentence_language_model/README.md:
--------------------------------------------------------------------------------
1 |
2 | ### Sentence Masked Language Modeling using TFText
3 |
4 | This folder has all the necessary scripts to run Sentence MLM using TensorFlow Text.
5 | Instead of masking words, we mask sentences (sequences of words).
6 |
7 | ### Advantage
8 |
9 | No need to prepare features in TFRecord format; we make use of dynamic MLM. All we need is a text file or a folder of text files, where each line corresponds to a piece of text.
10 |
11 | ### Configuration (Hydra)
12 |
13 | All or most of the configuration can be managed through ```conf/config.yaml```. You can also override it from the command line.
14 |
15 | E.g. for TPU, the data and ```model_checkpoint_dir``` need to be in GCS.
16 |
17 | ```python3 run_mlm.py data.data_directory= data.train_batch_size=128 trainer.dtype=bf16 trainer.model_checkpoint_dir= trainer.steps_per_epoch=50000 trainer.callback_steps=10000 trainer.epochs=20 trainer.strategy=tpu trainer.tpu_address= optimizer.learning_rate=5e-4```
18 |
19 | ### WandB
20 |
21 | By default we use WandB. If the environment variable ```WANDB_PROJECT``` is not set, WandB will be disabled.
22 |
--------------------------------------------------------------------------------
/research/sentence_language_model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/research/sentence_language_model/__init__.py
--------------------------------------------------------------------------------
/research/sentence_language_model/conf/config.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | data_directory:
3 | train_batch_size: 128
4 | task:
5 | max_seq_len: 1024
6 | max_predictions_per_seq: 200
7 | trainer:
8 | dtype: bf16
9 | num_gpus: 0
10 | tpu_address:
11 | epochs: 20
12 | strategy: mirrored
13 | steps_per_epoch: 50000
14 | model_checkpoint_dir:
15 | callback_steps: 10000
16 | optimizer:
17 | learning_rate: 1e-4
18 | warmup_rate: 0.2
19 | loss_type:
20 | use_constant_lr: false
21 | model:
22 | is_training: true
23 | use_dropout: true
24 | num_layers: 12
25 |
--------------------------------------------------------------------------------
/research/sentence_language_model/dataset_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from random import shuffle
3 |
4 | import tensorflow as tf
5 |
6 |
7 | def get_dataset(data_directory, masked_lm_map_fn, batch_size):
8 | """Convert text to tf.data.Dataset after map fn
9 |
10 | Args:
11 |         data_directory (str): Directory containing the input `*.txt` files.
12 |         masked_lm_map_fn (callable): Map function that applies dynamic masking to each example.
13 |         batch_size (int): Batch size.
14 | 
15 |     Returns:
16 |         tf.data.Dataset: Batched, filtered and shuffled dataset.
17 | """
18 |
19 | def filter_out_empty_mask(x, y):
20 |         """When an example doesn't have multiple sentences,\
21 |         there won't be any masked sentence. Ignore those examples,
22 |         as there is nothing to predict.
23 | """
24 | return tf.greater(tf.reduce_sum(tf.cast(tf.not_equal(x['masked_lm_positions'], 0), tf.int32)), 0)
25 |
26 | all_text_files = tf.io.gfile.glob(os.path.join(data_directory, '*.txt'))
27 | shuffle(all_text_files)
28 | ds = tf.data.TextLineDataset(all_text_files)
29 |
30 | # We need to add the text as dict
31 | ds = ds.map(lambda x: {'text': x}, num_parallel_calls=tf.data.AUTOTUNE)
32 |
33 | # Do MLM
34 | ds = ds.map(masked_lm_map_fn, num_parallel_calls=tf.data.AUTOTUNE)
35 |
36 |     # Filter out examples that do not have at least one masked sentence
37 | ds = ds.filter(filter_out_empty_mask)
38 |
39 | # Batch
40 | ds = ds.batch(batch_size, drop_remainder=True)
41 |
42 | # Shuffle and Prefetch
43 | ds = ds.shuffle(100, reshuffle_each_iteration=True).prefetch(buffer_size=tf.data.AUTOTUNE)
44 |
45 | # Auto SHARD
46 | options = tf.data.Options()
47 | options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO
48 | ds = ds.with_options(options)
49 |
50 | return ds
51 |
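52 | # Illustrative usage sketch (comments only; `my_masked_lm_map_fn` is an assumed callable
53 | # that maps {'text': <tf.string>} to an (x, y) pair where x contains 'masked_lm_positions'):
54 | #
55 | # train_ds = get_dataset(
56 | #     data_directory="/path/to/text_dir",  # hypothetical directory of *.txt files
57 | #     masked_lm_map_fn=my_masked_lm_map_fn,
58 | #     batch_size=128,
59 | # )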
--------------------------------------------------------------------------------
/research/sentence_language_model/model.py:
--------------------------------------------------------------------------------
1 | from tf_transformers.core import Trainer
2 | from tf_transformers.losses.loss_wrapper import get_lm_loss
3 | from tf_transformers.models import (
4 | BigBirdRobertaTokenizerTFText,
5 | GPT2Model,
6 | MaskedLMModel,
7 | )
8 | from tf_transformers.optimization import create_optimizer
9 |
10 | MODEL_NAME = 'gpt2'
11 | TOKENIZER_NAME = "google/bigbird-roberta-large"
12 |
13 |
14 | def get_model(return_all_layer_outputs, is_training, use_dropout, vocab_size):
15 | """Get the model from model function"""
16 |
17 | def model_fn():
18 | # We use GPT2 Style model, but we use BigBird Roberta Tokenizer
19 | config = GPT2Model.get_config(MODEL_NAME)
20 | # We update the vocab_size for that reason
21 | config['vocab_size'] = vocab_size
22 | model = GPT2Model.from_config(config, mask_mode='user_defined', return_layer=True)
23 | model = MaskedLMModel(
24 | model,
25 | use_extra_mlm_layer=False,
26 | hidden_size=config['embedding_size'],
27 | layer_norm_epsilon=config['layer_norm_epsilon'],
28 | )
29 | model = model.get_model()
30 | return model
31 |
32 | return model_fn
33 |
34 |
35 | def get_tokenizer():
36 | tokenizer_layer = BigBirdRobertaTokenizerTFText.from_pretrained(TOKENIZER_NAME)
37 | return tokenizer_layer
38 |
39 |
40 | def get_optimizer(learning_rate, steps_per_epoch, epochs, warmup_rate, use_constant_lr=False):
41 | """Get AdamW optimizer"""
42 |
43 | # Total steps over all epochs
44 | num_train_steps = steps_per_epoch * epochs
45 | warmup_steps = int(warmup_rate * num_train_steps)
46 |
47 | def optimizer_fn():
48 | if use_constant_lr:
49 | from tf_transformers.optimization.adam_weighted import AdamWeightDecay
50 |
51 | optimizer = AdamWeightDecay(learning_rate=learning_rate)
52 | return optimizer
53 |
54 | optimizer, learning_rate_fn = create_optimizer(learning_rate, num_train_steps, warmup_steps)
55 | return optimizer
56 |
57 | return optimizer_fn
58 |
59 |
60 | def get_loss(loss_type):
61 | """Get MLM Loss"""
62 | return get_lm_loss(loss_type=loss_type)
63 |
64 |
65 | def get_trainer(distribution_strategy, dtype, num_gpus=0, tpu_address=None):
66 | """Get Trainer"""
67 | trainer = Trainer(distribution_strategy, num_gpus=num_gpus, tpu_address=tpu_address, dtype=dtype)
68 | return trainer
69 |
70 |
71 | def get_hf_tokenizer():
72 | """Get HuggingFace Tokenizer"""
73 | from transformers import BigBirdTokenizer
74 |
75 | tokenizer = BigBirdTokenizer.from_pretrained(TOKENIZER_NAME)
76 | return tokenizer
77 |
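78 | # Worked example for get_optimizer (values borrowed from conf/config.yaml as a sketch):
79 | #   steps_per_epoch=50000, epochs=20, warmup_rate=0.2
80 | #   num_train_steps = 50000 * 20           = 1,000,000
81 | #   warmup_steps    = int(0.2 * 1,000,000) = 200,000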
--------------------------------------------------------------------------------
/research/sentence_language_model/run_mlm.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """This is the main script to run MLM"""
18 | import os
19 | import warnings
20 |
21 | import hydra
22 | from absl import logging
23 | from omegaconf import DictConfig
24 | from train_mlm import run_train
25 |
26 | logging.set_verbosity("INFO")
27 |
28 | # We read WANDB_PROJECT from the environment variable
29 | WANDB_PROJECT = os.getenv('WANDB_PROJECT', None)
30 | use_wandb = True
31 | if WANDB_PROJECT is None:
32 | warnings.warn(
33 |         "WANDB_PROJECT environment variable is not set.\
34 |         Set it with `export WANDB_PROJECT=<project_name>` to enable wandb."
35 | )
36 | use_wandb = False
37 |
38 |
39 | @hydra.main(config_path="conf", config_name="config")
40 | def run(cfg: DictConfig) -> None:
41 | print("Config", cfg)
42 | config_dict = dict(cfg)
43 |     # For TPU, we need to initialize it before the tf.text dataset
44 |     # pipeline starts triggering (hack).
45 | if cfg.trainer.strategy == 'tpu':
46 | from model import get_trainer
47 |
48 | distribution_strategy = 'tpu'
49 | num_gpus = 0
50 | tpu_address = cfg.trainer.tpu_address
51 | get_trainer(
52 | distribution_strategy=distribution_strategy,
53 | num_gpus=num_gpus,
54 | tpu_address=tpu_address,
55 | dtype=cfg.trainer.dtype,
56 | ) # noqa
57 |
58 | if use_wandb:
59 | import wandb
60 |
61 | wandb.init(project=WANDB_PROJECT, config=config_dict, sync_tensorboard=True)
62 | history = run_train(cfg, wandb)
63 | else:
64 | # Set wandb = None
65 | history = run_train(cfg, None)
66 | return history
67 |
68 |
69 | if __name__ == "__main__":
70 | history = run()
71 |
--------------------------------------------------------------------------------
/research/similarity_model_pretraining/README.MD:
--------------------------------------------------------------------------------
1 |
2 | ### Masked Language Modeling using TFText
3 |
4 | This folder has all the necessary scripts to run MLM using TensorFlow Text.
5 |
6 | ### Advantage
7 |
8 | No need to prepare features in TFRecord format; we make use of dynamic MLM. All we need is a text file or a folder of text files, where each line corresponds to a piece of text.
9 |
10 | ### Configuration (Hydra)
11 |
12 | All or most of the configuration can be managed through ```conf/config.yaml```. You can also override it from the command line.
13 |
14 | E.g. for TPU, the data and ```model_checkpoint_dir``` need to be in GCS.
15 |
16 | ```python3 run_mlm.py data.data_directory=$GCP_BUCKET/data/ trainer.model_checkpoint_dir=$GCP_BUCKET/model```
17 |
18 | ### WandB
19 |
20 | By default we use WandB. Check ```run_mlm.py``` to see how to disable it.
21 |
--------------------------------------------------------------------------------
/research/similarity_model_pretraining/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/research/similarity_model_pretraining/__init__.py
--------------------------------------------------------------------------------
/research/similarity_model_pretraining/conf/config.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | data_directory:
3 | train_batch_size: 32
4 | eval_batch_size: 32
5 | task:
6 | max_seq_len: 128
7 | max_predictions_per_seq: 40
8 | trainer:
9 | dtype: fp32
10 | num_gpus: 2
11 | tpu_address:
12 | epochs: 3
13 | strategy: mirrored
14 | steps_per_epoch: 10000
15 | model_checkpoint_dir:
16 | callback_steps: 1000
17 | optimizer:
18 | learning_rate: 5e-4
19 | warmup_rate: 0.1
20 | decay_function: cosine
21 | loss_type:
22 | use_constant_lr: false
23 | model:
24 | is_training: true
25 | use_dropout: true
26 | num_layers: 12
27 |
--------------------------------------------------------------------------------
/research/similarity_model_pretraining/run_similarity.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """This is the main script to run similarity model pretraining"""
18 | import os
19 | import warnings
20 |
21 | import hydra
22 | from absl import logging
23 | from omegaconf import DictConfig
24 | from train_similairity import run_train
25 |
26 | logging.set_verbosity("INFO")
27 |
28 | # We read WANDB_PROJECT from the environment variable
29 | WANDB_PROJECT = os.getenv('WANDB_PROJECT', None)
30 | use_wandb = True
31 | if WANDB_PROJECT is None:
32 | warnings.warn(
33 |         "WANDB_PROJECT environment variable is not set.\
34 |         Set it with `export WANDB_PROJECT=<project_name>` to enable wandb."
35 | )
36 | use_wandb = False
37 |
38 |
39 | @hydra.main(config_path="conf", config_name="config")
40 | def run(cfg: DictConfig) -> None:
41 | print("Config", cfg)
42 | config_dict = dict(cfg)
43 |     # For TPU, we need to initialize it before the tf.text dataset
44 |     # pipeline starts triggering (hack).
45 | if cfg.trainer.strategy == 'tpu':
46 | from model import get_trainer
47 |
48 | distribution_strategy = 'tpu'
49 | num_gpus = 0
50 | tpu_address = cfg.trainer.tpu_address
51 | get_trainer(
52 | distribution_strategy=distribution_strategy,
53 | num_gpus=num_gpus,
54 | tpu_address=tpu_address,
55 | dtype=cfg.trainer.dtype,
56 | ) # noqa
57 |
58 | if use_wandb:
59 | import wandb
60 |
61 | wandb.init(project=WANDB_PROJECT, config=config_dict, sync_tensorboard=True)
62 | history = run_train(cfg, wandb)
63 | else:
64 | # Set wandb = None
65 | history = run_train(cfg, None)
66 | return history
67 |
68 |
69 | if __name__ == "__main__":
70 | history = run()
71 |
--------------------------------------------------------------------------------
/research/t5_style_pretraining/README.md:
--------------------------------------------------------------------------------
1 |
2 | ### Sentence Masked Language Modeling using TFText
3 |
4 | This folder has all the necessary scripts to run Sentence MLM using TensorFlow Text.
5 | Instead of masking words, we mask sentences (sequences of words).
6 |
7 | ### Advantage
8 |
9 | No need to prepare features in TFRecord format; we make use of dynamic MLM. All we need is a text file or a folder of text files, where each line corresponds to a piece of text.
10 |
11 | ### WandB
12 |
13 | By default we use WandB. If the environment variable ```WANDB_PROJECT``` is not set, WandB will be disabled.
14 |
15 | ``` export WANDB_PROJECT='t5-style-pretraining' ```
16 | ### Configuration (Hydra)
17 |
18 | All or most of the configuration can be managed through ```conf/config.yaml```. You can also override it from the command line.
19 |
20 | E.g. for TPU, the data and ```model_checkpoint_dir``` need to be in GCS.
21 |
22 | ``` nohup python3 run_t5_modified.py \
23 | task.data_directory=gs://legacyai-bucket \
24 | task.train_batch_size=128 \
25 | trainer.dtype=bf16 \
26 | trainer.model_checkpoint_dir=gs://legacyai-bucket/t5_style_t5_small_lr_0.0005 \
27 | trainer.steps_per_epoch=10000 \
28 | trainer.epochs=100 \
29 | trainer.strategy=tpu \
30 | trainer.tpu_address=legacyai-tpu-1 \
31 | optimizer.learning_rate=0.01 \
32 | model.is_training=true \
33 | model.use_dropout=true \
34 | model.model_name=t5-small > logs &
35 | ```
36 |
--------------------------------------------------------------------------------
/research/t5_style_pretraining/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/research/t5_style_pretraining/__init__.py
--------------------------------------------------------------------------------
/research/t5_style_pretraining/conf/config.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | data_directory:
3 | max_seq_len: 256
4 | train_batch_size: 512
5 | trainer:
6 | dtype: bf16
7 | num_gpus: 0
8 | tpu_address:
9 | epochs: 100
10 | strategy: mirrored
11 | steps_per_epoch: 10000
12 | model_checkpoint_dir:
13 | global_norm: 5.0
14 | optimizer:
15 | learning_rate: 0.001
16 | num_warmup_steps: 0.1
17 | decay_function: cosine
18 | adam_beta_1: 0.9
19 | adam_beta_2: 0.95
20 | adam_epsilon: 10e-8
21 | weight_decay_rate: 0.1
22 | optimizer_type: adamw
23 | loss_type:
24 | use_constant_lr: false
25 | model:
26 | is_training: true
27 | use_dropout: true
28 | model_name:
29 | num_layers: 12
30 |
--------------------------------------------------------------------------------
/research/t5_style_pretraining/run_t5_modified.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """This is the main script to run T5-style (modified) pretraining"""
18 | import os
19 | import warnings
20 |
21 | import hydra
22 | from absl import logging
23 | from omegaconf import DictConfig
24 | from train_t5_modified import run_train
25 |
26 | logging.set_verbosity("INFO")
27 |
28 | # We read WANDB_PROJECT from the environment variable
29 | WANDB_PROJECT = os.getenv('WANDB_PROJECT', None)
30 | use_wandb = True
31 | if WANDB_PROJECT is None:
32 | warnings.warn(
33 |         "WANDB_PROJECT environment variable is not set.\
34 |         Set it with `export WANDB_PROJECT=<project_name>` to enable wandb."
35 | )
36 | use_wandb = False
37 |
38 |
39 | @hydra.main(config_path="conf", config_name="config")
40 | def run(cfg: DictConfig) -> None:
41 | print("Config", cfg)
42 | config_dict = dict(cfg)
43 |     # For TPU, we need to initialize it before the tf.text dataset
44 |     # pipeline starts triggering (hack).
45 | if cfg.trainer.strategy == 'tpu':
46 | from model import get_trainer
47 |
48 | distribution_strategy = 'tpu'
49 | num_gpus = 0
50 | tpu_address = cfg.trainer.tpu_address
51 | get_trainer(
52 | distribution_strategy=distribution_strategy,
53 | num_gpus=num_gpus,
54 | tpu_address=tpu_address,
55 | dtype=cfg.trainer.dtype,
56 | ) # noqa
57 |
58 | if use_wandb:
59 | import wandb
60 |
61 | wandb.init(project=WANDB_PROJECT, config=config_dict, sync_tensorboard=True)
62 | history = run_train(cfg, wandb)
63 | else:
64 | # Set wandb = None
65 | history = run_train(cfg, None)
66 | return history
67 |
68 |
69 | if __name__ == "__main__":
70 | history = run()
71 |
--------------------------------------------------------------------------------
/src/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/src/logo.png
--------------------------------------------------------------------------------
/src/logo2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/src/logo2.png
--------------------------------------------------------------------------------
/src/tf_transformers/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "2.0.0"
2 |
--------------------------------------------------------------------------------
/src/tf_transformers/activations/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Activations package definition."""
16 | from tf_transformers.activations.gelu import gelu, quick_gelu
17 | from tf_transformers.activations.swish import hard_swish, identity, simple_swish
18 | from tf_transformers.activations.utils import get_activation
19 |
--------------------------------------------------------------------------------
/src/tf_transformers/activations/gelu.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Gaussian error linear unit."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import math
20 |
21 | import tensorflow as tf
22 |
23 |
24 | @tf.keras.utils.register_keras_serializable(package="Text")
25 | def gelu(x):
26 | """Gaussian Error Linear Unit.
27 |
28 | This is a smoother version of the RELU.
29 | Original paper: https://arxiv.org/abs/1606.08415
30 | Args:
31 | x: float Tensor to perform activation.
32 |
33 | Returns:
34 | `x` with the GELU activation applied.
35 | """
36 | cdf = 0.5 * (1.0 + tf.tanh((math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3)))))
37 | return x * cdf
38 |
39 |
40 | @tf.keras.utils.register_keras_serializable(package="Text")
41 | def quick_gelu(x):
42 | """Quick GELU as in CLIP
43 |
44 | Returns:
45 | `x` with the Quick GELU activation applied.
46 | """
47 | return x * tf.sigmoid(1.702 * x)
48 |
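49 | # Illustrative sanity checks (approximate values; not part of the original module):
50 | #   gelu(tf.constant(0.0))       -> 0.0
51 | #   gelu(tf.constant(1.0))       -> ~0.8412 (tanh approximation of the exact GELU)
52 | #   quick_gelu(tf.constant(1.0)) -> ~0.8458 (= 1.0 * sigmoid(1.702))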
--------------------------------------------------------------------------------
/src/tf_transformers/activations/swish.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Customized Swish activation."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import tensorflow as tf
20 |
21 |
22 | @tf.keras.utils.register_keras_serializable(package="Text")
23 | def simple_swish(features):
24 | """Computes the Swish activation function.
25 |
26 | The tf.nn.swish operation uses a custom gradient to reduce memory usage.
27 | Since saving custom gradients in SavedModel is currently not supported, and
28 | one would not be able to use an exported TF-Hub module for fine-tuning, we
29 |     provide this wrapper that allows selecting whether to use the native
30 |     TensorFlow swish operation, or a customized operation that
31 |     uses the default TensorFlow gradient computation.
32 |
33 | Args:
34 | features: A `Tensor` representing preactivation values.
35 |
36 | Returns:
37 | The activation value.
38 | """
39 | features = tf.convert_to_tensor(features)
40 | return features * tf.nn.sigmoid(features)
41 |
42 |
43 | @tf.keras.utils.register_keras_serializable(package="Text")
44 | def hard_swish(features):
45 | """Computes a hard version of the swish function.
46 |
47 | This operation can be used to reduce computational cost and improve
48 | quantization for edge devices.
49 |
50 | Args:
51 | features: A `Tensor` representing preactivation values.
52 |
53 | Returns:
54 | The activation value.
55 | """
56 | features = tf.convert_to_tensor(features)
57 | return features * tf.nn.relu6(features + tf.constant(3.0)) * (1.0 / 6.0)
58 |
59 |
60 | @tf.keras.utils.register_keras_serializable(package="Text")
61 | def identity(features):
62 | """Computes the identity function.
63 |
64 | Useful for helping in quantization.
65 |
66 | Args:
67 | features: A `Tensor` representing preactivation values.
68 |
69 | Returns:
70 | The activation value.
71 | """
72 | features = tf.convert_to_tensor(features)
73 | return tf.identity(features)
74 |
--------------------------------------------------------------------------------
/src/tf_transformers/activations/utils.py:
--------------------------------------------------------------------------------
1 | import six
2 | import tensorflow as tf
3 |
4 | from tf_transformers import activations
5 |
6 |
7 | # TODO(hongkuny): consider moving custom string-map lookup to keras api.
8 | def get_activation(identifier):
9 |     """Maps an identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
10 | 
11 |     It checks strings first, and if the identifier is one of the customized activations
12 |     not in TF, the corresponding activation will be returned. For non-customized activation
13 |     names and callable identifiers, it always falls back to tf.keras.activations.get.
14 |
15 | Args:
16 | identifier: String name of the activation function or callable.
17 |
18 | Returns:
19 | A Python function corresponding to the activation function.
20 | """
21 | if isinstance(identifier, six.string_types):
22 | name_to_fn = {
23 | "gelu": activations.gelu,
24 | "simple_swish": activations.simple_swish,
25 | "hard_swish": activations.hard_swish,
26 | "identity": activations.identity,
27 | "relu": tf.keras.activations.relu,
28 | "quick_gelu": activations.quick_gelu,
29 | }
30 | identifier = str(identifier).lower()
31 | if identifier in name_to_fn:
32 | return tf.keras.activations.get(name_to_fn[identifier])
33 | return tf.keras.activations.get(identifier)
34 |
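35 | # Illustrative usage (comments only; not part of the original module):
36 | #   get_activation("gelu") -> the custom gelu defined in tf_transformers.activations
37 | #   get_activation("relu") -> tf.keras.activations.relu
38 | #   get_activation("tanh") -> falls back to tf.keras.activations.get("tanh")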
--------------------------------------------------------------------------------
/src/tf_transformers/callbacks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/src/tf_transformers/callbacks/__init__.py
--------------------------------------------------------------------------------
/src/tf_transformers/callbacks/metric_callback_list.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """This will return tf.keras.metric object based on name"""
18 |
19 | import tensorflow as tf
20 |
21 | _ALL_METRIC_NAMES = ['binary_accuracy']
22 |
23 |
24 | def show_available_metric_names():
25 | print(_ALL_METRIC_NAMES)
26 |
27 |
28 | def get_callback(metric_name: str):
29 | """Return tf.keras.metric with a name, for callback"""
30 | metric_name = metric_name.lower().strip()
31 |
32 | if metric_name not in _ALL_METRIC_NAMES:
33 | raise ValueError("{} not present in {}".format(metric_name, _ALL_METRIC_NAMES))
34 |
35 | if metric_name == "binary_accuracy":
36 | return tf.keras.metrics.BinaryAccuracy, metric_name
37 |
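38 | # Illustrative usage (comments only; not part of the original module):
39 | #   metric_cls, name = get_callback("binary_accuracy")
40 | #   metric = metric_cls()  # a fresh tf.keras.metrics.BinaryAccuracy instance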
--------------------------------------------------------------------------------
/src/tf_transformers/callbacks/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from tf_transformers.callbacks.metrics.callbacks import MetricCallback
2 | from tf_transformers.callbacks.metrics.pearson_spearman_callback import (
3 | PearsonSpearmanCallback,
4 | )
5 | from tf_transformers.callbacks.metrics.sklearn_callbacks import SklearnMetricCallback
6 | from tf_transformers.callbacks.metrics.text_generation_callbacks import TextGenerationMetricCallback
7 |
--------------------------------------------------------------------------------
/src/tf_transformers/core/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.core.chainer import ClassificationChainer, TextGenerationChainer
18 | from tf_transformers.core.legacy_layer import LegacyLayer
19 | from tf_transformers.core.legacy_model import LegacyModel
20 | from tf_transformers.core.legacy_module import LegacyModule, LegacyModuleCustom
21 | from tf_transformers.core.model_wrapper import ModelWrapper
22 | from tf_transformers.core.trainer import Trainer
23 | from tf_transformers.core.trainer_for_all import TrainerforAll
24 | from tf_transformers.core.transformer_config import TransformerConfig
25 |
--------------------------------------------------------------------------------
/src/tf_transformers/core/model_utils_for_all.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from absl import logging
3 |
4 |
5 | def load_checkpoint_custom(model, checkpoint_dir=None, checkpoint_path=None, options=None, **kwargs):
6 |     """Restore a model (and any extra trackables passed via kwargs) from a checkpoint.
7 | 
8 |     Args:
9 |         checkpoint_dir ([str]): [Directory containing the checkpoint]
10 |         checkpoint_path ([str]): [Path to a specific checkpoint file]
11 |     """
12 |     try:
13 |         options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
14 |     except Exception:  # CheckpointOptions is unavailable in this TF version
15 |         options = None
16 | if checkpoint_dir:
17 | if tf.io.gfile.exists(checkpoint_dir):
18 | if tf.io.gfile.isdir(checkpoint_dir) is False:
19 | raise ValueError("checkpoint_dir expects a directory not a file {}.".format(checkpoint_dir))
20 | if checkpoint_path:
21 | if tf.io.gfile.isdir(checkpoint_path) is True:
22 | raise ValueError("checkpoint_path expects a checkpoint-file not a directory {}.".format(checkpoint_path))
23 | checkpoint = tf.train.Checkpoint(model=model, **kwargs)
24 | if checkpoint_path is None and checkpoint_dir:
25 | checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
26 | if checkpoint_path is None:
27 | if checkpoint_dir:
28 | logging.info("No ❌❌ checkpoint found in {}".format(checkpoint_dir))
29 | else:
30 | logging.info("No ❌❌ checkpoint found")
31 | return None
32 | else:
33 | if options:
34 | status = checkpoint.restore(checkpoint_path, options=options)
35 | else:
36 | status = checkpoint.restore(checkpoint_path)
37 | # Important
38 | if status.assert_existing_objects_matched():
39 | logging.info("Successful ✅✅: Model checkpoints matched and loaded from {}".format(checkpoint_path))
40 | return checkpoint
41 | else:
42 | logging.info("Failed ❌❌ to load the checkpoint. Status Assertion Failed.")
43 | return None
44 |
--------------------------------------------------------------------------------
/src/tf_transformers/core/performance_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | # Map string to TensorFlow dtype
4 | DTYPE_MAP = {
5 | "fp16": tf.float16,
6 | "bf16": tf.bfloat16,
7 | "fp32": tf.float32,
8 | }
9 |
10 |
11 | def get_tf_dtype(dtype):
12 | return DTYPE_MAP[dtype]
13 |
14 |
15 | def is_float16(dtype):
16 | if dtype in [tf.float16]:
17 | return True
18 | return False
19 |
20 |
21 | def set_mixed_precision_policy(dtype):
22 | """Sets mix precision policy."""
23 | if dtype == tf.float16:
24 | tf.keras.mixed_precision.set_global_policy('mixed_float16')
25 | elif dtype == tf.bfloat16:
26 | tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
27 | elif dtype == tf.float32:
28 | tf.keras.mixed_precision.set_global_policy('float32')
29 | else:
30 | raise ValueError('Unexpected dtype: %s' % dtype)
31 |
32 |
33 | def configure_optimizer(optimizer, use_float16=False, use_graph_rewrite=False, loss_scale='dynamic'):
34 | """Configures optimizer object with performance options."""
35 | if use_float16:
36 | if loss_scale == 'dynamic':
37 | optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
38 | else:
39 | # loss_scale is a number. We interpret that as a fixed loss scale.
40 | optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer, dynamic=False, initial_scale=loss_scale)
41 |
42 | if use_graph_rewrite:
43 | # Note: the model dtype must be 'float32', which will ensure
44 | # tf.keras.mixed_precision and enable_mixed_precision_graph_rewrite do not
45 | # double up.
46 | optimizer = tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite(optimizer)
47 | return optimizer
48 |
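49 | # Illustrative usage sketch (comments only; not part of the original module):
50 | #   dtype = get_tf_dtype("fp16")       # -> tf.float16
51 | #   set_mixed_precision_policy(dtype)  # sets the 'mixed_float16' policy
52 | #   optimizer = tf.keras.optimizers.Adam(1e-4)
53 | #   optimizer = configure_optimizer(optimizer, use_float16=is_float16(dtype))
54 | #   # with use_float16=True, the optimizer is wrapped in a dynamic LossScaleOptimizer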
--------------------------------------------------------------------------------
/src/tf_transformers/core/read_from_hub.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | import tensorflow as tf
5 | from absl import logging
6 | from huggingface_hub import hf_hub_download, snapshot_download
7 |
8 | logging.set_verbosity("INFO")
9 |
10 |
11 | def get_config_cache(url: str):
12 | """Load model from Huggingface hub"""
13 | local_cache = snapshot_download(repo_id=url)
14 |
15 | # Load config from cache
16 | config_path = Path(local_cache, "config.json")
17 | if not config_path.exists():
18 | raise ValueError("config.json is not present in model hub {}".format(url))
19 | config_dict = json.load(open(config_path))
20 | return config_dict, local_cache
21 |
22 |
23 | def get_config_only(url: str):
24 | """Load config from Huggingface hub"""
25 | config_path = hf_hub_download(repo_id=url, filename="config.json")
26 | # Load config from cache
27 | config_dict = json.load(open(config_path))
28 | return config_dict
29 |
30 |
31 | def load_pretrained_model(model: tf.keras.Model, local_cache: str, url: str):
32 | """Load model from cache"""
33 | try:
34 | local_device_option = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
35 |     except Exception:
36 |         # tf.train.CheckpointOptions is unavailable in this TF version;
37 |         # log the error and fall back to the default checkpoint options.
38 |         import traceback
39 | 
40 |         print(traceback.format_exc())
41 |         local_device_option = None
42 | 
43 | model.load_checkpoint(local_cache, options=local_device_option)
44 | logging.info("Successful ✅: Loaded model from {}".format(url))
45 |
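46 | # Illustrative usage sketch (comments only; the repo id is a hypothetical placeholder):
47 | #   config_dict = get_config_only("some-org/some-model")
48 | #   config_dict, local_cache = get_config_cache("some-org/some-model")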
--------------------------------------------------------------------------------
/src/tf_transformers/data/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.data.tfprocessor_utils import TFProcessor
18 | from tf_transformers.data.tfrecord_utils import TFReader, TFWriter
19 | from tf_transformers.data.utils import (
20 | pad_dataset,
21 | pad_dataset_normal,
22 | pad_ragged,
23 | separate_x_y,
24 | )
25 |
--------------------------------------------------------------------------------
/src/tf_transformers/data/callbacks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/src/tf_transformers/data/callbacks/__init__.py
--------------------------------------------------------------------------------
/src/tf_transformers/data/ner_utils_sp.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors and The TensorFlow Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 |
18 | from tf_transformers.utils import fast_sp_alignment
19 |
20 |
21 | def get_tokens_labels(aligned_words, orig_to_new_index, label_tokens, sub_words_mapped, label_pad_token="[PAD]"):
22 | """
23 |     Convert each sub-word into a label.
24 |     If a word is split into multiple sub-words,
25 |     the first sub-word is assigned the label and the remaining sub-words are padded.
26 | """
27 | aligned_labels = [label_pad_token] * len(aligned_words)
28 | for original_pos, new_pos in enumerate(orig_to_new_index):
29 | aligned_labels[new_pos] = label_tokens[original_pos]
30 |
31 | flat_tokens = []
32 | flat_labels = []
33 |
34 |     # The first sub-word of each word is assigned the entity label;
35 |     # the remaining sub-words get PAD labels (we mask them while training)
36 |     assert len(aligned_words) == len(sub_words_mapped) == len(aligned_labels)
37 |     for (_align_word, _sub_words, _align_label) in zip(aligned_words, sub_words_mapped, aligned_labels):
38 |         temp_w = []
39 |         for _sub_word in _sub_words:
40 |             temp_w.append(_sub_word)
41 | temp_l = [label_pad_token] * len(temp_w)
42 | temp_l[0] = _align_label
43 | flat_tokens.extend(temp_w)
44 | flat_labels.extend(temp_l)
45 |
46 | return flat_tokens, flat_labels
47 |
48 |
49 | def fast_tokenize_and_align_sentence_for_ner(
50 | tokenizer, sentence, word_tokens, SPECIAL_PIECE, is_training=False, label_tokens=None, label_pad_token=None
51 | ):
52 |
53 | """
54 |     Align sentence sub-words and labels using fast_sp_alignment.
55 | """
56 | orig_to_new_index, aligned_words, sub_words_mapped = fast_sp_alignment(sentence, tokenizer, SPECIAL_PIECE)
57 |
58 | if is_training:
59 | flat_tokens, flat_labels = get_tokens_labels(
60 | aligned_words, orig_to_new_index, label_tokens, sub_words_mapped, label_pad_token
61 | )
62 | return aligned_words, sub_words_mapped, flat_tokens, flat_labels
63 | else:
64 | flat_tokens = [w for sub_words in sub_words_mapped for w in sub_words]
65 | return aligned_words, sub_words_mapped, flat_tokens, orig_to_new_index
66 |
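A minimal usage sketch for the label-spreading helper above; the words, labels, and sub-word pieces are made-up toy values, only `get_tokens_labels` comes from the module:

from tf_transformers.data.ner_utils_sp import get_tokens_labels

# "Manchester" splits into two sub words, so only the first piece keeps the
# entity label and the second piece receives the PAD label.
aligned_words = ["John", "lives", "in", "Manchester"]
orig_to_new_index = [0, 1, 2, 3]
label_tokens = ["B-PER", "O", "O", "B-LOC"]
sub_words_mapped = [["John"], ["lives"], ["in"], ["Man", "##chester"]]

flat_tokens, flat_labels = get_tokens_labels(
    aligned_words, orig_to_new_index, label_tokens, sub_words_mapped, label_pad_token="[PAD]"
)
# flat_tokens -> ["John", "lives", "in", "Man", "##chester"]
# flat_labels -> ["B-PER", "O", "O", "B-LOC", "[PAD]"]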
--------------------------------------------------------------------------------
/src/tf_transformers/data/processors/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/src/tf_transformers/data/processors/__init__.py
--------------------------------------------------------------------------------
/src/tf_transformers/data/processors/mlm_ttext.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/src/tf_transformers/data/processors/mlm_ttext.py
--------------------------------------------------------------------------------
/src/tf_transformers/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.layers.bias_layer import BiasLayer
18 | from tf_transformers.layers.image_embeddings import (
19 | PatchEmbeddings,
20 | PositionEmbeddingImage,
21 | )
22 | from tf_transformers.layers.layer_normalization import (
23 | GPT2LayerNormalization,
24 | T5LayerNormalization,
25 | )
26 | from tf_transformers.layers.mlm_layer import MaskedLM
27 | from tf_transformers.layers.on_device_embedding import OnDeviceEmbedding
28 | from tf_transformers.layers.position_embedding import PositionEmbedding
29 |
--------------------------------------------------------------------------------
/src/tf_transformers/layers/attention/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.layers.attention.bart_attention import BartAttention
18 | from tf_transformers.layers.attention.bert_attention import MultiHeadAttention
19 | from tf_transformers.layers.attention.bigbird_attention import BigBirdAttention
20 | from tf_transformers.layers.attention.clip_attention import CLIPMultiHeadAttention
21 | from tf_transformers.layers.attention.gpt2_attention import GPT2Attention
22 | from tf_transformers.layers.attention.t5_attention import T5Attention
23 |
--------------------------------------------------------------------------------
/src/tf_transformers/layers/bias_layer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors and The TensorFlow Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | import tensorflow as tf
18 |
19 |
20 | class BiasLayer(tf.keras.layers.Layer):
21 | def __init__(self, name="bias", trainable=True, initializer="zeros", *args, **kwargs):
22 | self._trainable = trainable
23 | self._initializer = initializer
24 | self._name = name
25 | super(BiasLayer, self).__init__(name=name, trainable=trainable, **kwargs)
26 |
27 | def build(self, input_shape):
28 | self.bias = self.add_weight(
29 | name="bias", shape=(input_shape[-1],), initializer=self._initializer, trainable=self._trainable
30 | )
31 |
32 | def call(self, x):
33 | return x + self.bias
34 |
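A small usage sketch for the layer above; the batch and feature sizes are arbitrary:

import tensorflow as tf
from tf_transformers.layers import BiasLayer

bias_layer = BiasLayer()
x = tf.zeros([2, 4])   # (batch, hidden)
y = bias_layer(x)      # adds a trainable bias of shape (4,), initialized to zeros
assert y.shape == (2, 4)
assert len(bias_layer.trainable_variables) == 1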
--------------------------------------------------------------------------------
/src/tf_transformers/layers/image_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def read_and_process(img, img_height, img_width, num_channels, rescale, normalize, read):
5 |     """Read and process an image.
6 |     If read is True, tf.io reads and decodes the image from a filepath.
7 |     If read is False, we expect a NumPy array (PIL-opened image) or a TF image (tf.uint8).
8 | 
9 |     Args:
10 |         img (tf.Tensor): tf.string filepath or tf.uint8 image array
11 |         img_height (int): image height for resizing
12 |         img_width (int): image width for resizing
13 |         num_channels (int): 3 for RGB, 1 for grayscale
14 |         rescale (bool): rescale pixel values to the range (0, 1)
15 |         normalize (bool): normalize the image to zero mean and unit stddev
16 |         read (bool): whether to read and decode the image from disk
17 | 
18 |     Returns:
19 |         tf.float32 image array (3D)
20 |     """
21 | if read:
22 | # Read image
23 | img = tf.io.read_file(img)
24 | # convert the compressed string to a 3D uint8 tensor
25 | img = tf.image.decode_jpeg(img, channels=num_channels)
26 |
27 | # resize the image to the desired size
28 | img = tf.image.resize(img, [img_height, img_width])
29 | # Rescale between (0 and 1)
30 | if rescale:
31 | img = tf.keras.layers.experimental.preprocessing.Rescaling(1.0 / 255)(img)
32 | # TODO tf.keras.layers.experimental.preprocessing.Normalization
33 | if normalize:
34 | img = tf.image.per_image_standardization(img)
35 | return img
36 |
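A usage sketch for `read_and_process`, passing an already-decoded tf.uint8 image (read=False); the image and target sizes are arbitrary:

import tensorflow as tf
from tf_transformers.layers.image_utils import read_and_process

# A fake 300x400 RGB image standing in for a decoded JPEG
img = tf.cast(tf.random.uniform([300, 400, 3], maxval=256, dtype=tf.int32), tf.uint8)

processed = read_and_process(
    img, img_height=224, img_width=224, num_channels=3,
    rescale=True, normalize=False, read=False,
)
# processed: float32 tensor of shape (224, 224, 3) with values rescaled to [0, 1]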
--------------------------------------------------------------------------------
/src/tf_transformers/layers/mask/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.layers.mask.causal_mask import CausalMask
18 | from tf_transformers.layers.mask.cross_attention_mask import CrossAttentionMask
19 | from tf_transformers.layers.mask.prefix_mask import prefix_mask
20 | from tf_transformers.layers.mask.self_attention_mask import SelfAttentionMask
21 |
--------------------------------------------------------------------------------
/src/tf_transformers/layers/mask/causal_mask.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors and The TensorFlow Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """Keras layer that creates a causal (lower triangular) attention mask."""
18 |
19 | # from __future__ import google_type_annotations
20 | from __future__ import absolute_import, division, print_function
21 |
22 | import tensorflow as tf
23 |
24 | from tf_transformers.utils import tf_utils
25 |
26 |
27 | def attention_mask_square(nd):
28 | """1's in the lower triangle, counting from the lower right corner.
29 |
30 | Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
31 | """
32 | dtype = tf_utils.get_dtype()
33 | ns = nd
34 | i = tf.range(nd)[:, None]
35 | j = tf.range(ns)
36 | m = i >= j - ns + nd
37 | return tf.cast(m, dtype)
38 |
39 |
40 | @tf.keras.utils.register_keras_serializable(package="Text")
41 | class CausalMask(tf.keras.layers.Layer):
42 |     """Create a 3D causal attention mask from a 2D or 3D tensor.
43 | 
44 |     inputs: from_tensor: 2D or 3D Tensor of shape
45 |         [batch_size, from_seq_length, ...].
46 |     The mask depends only on from_seq_length; no to_mask is required.
47 | 
48 |     Returns:
49 |         float Tensor of shape [batch_size, from_seq_length, from_seq_length].
50 |     """
51 |
52 | def call(self, inputs):
53 | from_tensor = inputs
54 | from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
55 | batch_size = from_shape[0]
56 | from_seq_length = from_shape[1]
57 |
58 | # 2D Lower Triangular Mask
59 | from_mask = attention_mask_square(from_seq_length)
60 |
61 | # Replicate 2D `N` times
62 | mask = tf.cast(tf.ones([batch_size, 1, 1]), from_mask.dtype) * from_mask
63 |
64 | return mask
65 |
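A sketch of the mask these helpers produce, assuming the default float32 dtype policy; batch size and sequence length are arbitrary:

import tensorflow as tf
from tf_transformers.layers.mask import CausalMask
from tf_transformers.layers.mask.causal_mask import attention_mask_square

# Lower triangular 3x3 mask: position i may attend to positions <= i
print(attention_mask_square(3))
# [[1. 0. 0.]
#  [1. 1. 0.]
#  [1. 1. 1.]]

embeddings = tf.zeros([2, 3, 8])   # (batch, seq, hidden)
mask = CausalMask()(embeddings)    # (2, 3, 3): the 3x3 mask above, replicated per batch item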
--------------------------------------------------------------------------------
/src/tf_transformers/layers/mask/cross_attention_mask.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors and The TensorFlow Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """Keras layer that creates a cross-attention mask."""
18 |
19 | # from __future__ import google_type_annotations
20 | from __future__ import absolute_import, division, print_function
21 |
22 | import tensorflow as tf
23 |
24 | from tf_transformers.utils import tf_utils
25 |
26 |
27 | @tf.keras.utils.register_keras_serializable(package="Text")
28 | class CrossAttentionMask(tf.keras.layers.Layer):
29 | """Create 3D attention mask from a 2D tensor mask.
30 |
31 | inputs[0]: from_tensor: 2D or 3D Tensor of shape
32 | [batch_size, from_seq_length, ...].
33 | inputs[1]: to_mask: int32 Tensor of shape [batch_size, to_seq_length].
34 |
35 | Returns:
36 | float Tensor of shape [batch_size, from_seq_length, to_seq_length].
37 | """
38 |
39 | def __init__(self, **kwargs):
40 | # We need to have a default dtype of float32, since the inputs (which Keras
41 | # usually uses to infer the dtype) will always be int32.
42 | if "dtype" not in kwargs:
43 | kwargs["dtype"] = tf_utils.get_dtype()
44 |
45 | super(CrossAttentionMask, self).__init__(**kwargs)
46 | self._dtype = kwargs["dtype"]
47 |
48 | def call(self, inputs):
49 | to_mask = inputs[1]
50 | batch_size, from_seq_length = tf_utils.get_shape_list(inputs[0])
51 | _, to_seq_length = tf_utils.get_shape_list(inputs[1])
52 |
53 | to_mask = tf.cast(tf.reshape(to_mask, [batch_size, 1, to_seq_length]), dtype=self._dtype)
54 |
55 |         # We don't assume that `from_tensor` is a mask (although it could be). We
56 |         # don't actually care if we attend *from* padding tokens (only *to* padding
57 |         # tokens), so we create a tensor of all ones.
58 | #
59 | # `broadcast_ones` = [batch_size, from_seq_length, 1]
60 | broadcast_ones = tf.ones(shape=[batch_size, from_seq_length, 1], dtype=self._dtype)
61 |
62 | # Here we broadcast along two dimensions to create the mask.
63 | mask = broadcast_ones * to_mask
64 |
65 | return mask
66 |
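A sketch of the cross-attention mask above; the token ids are arbitrary, with decoder length 2, encoder length 4, and the last encoder position padded:

import tensorflow as tf
from tf_transformers.layers.mask import CrossAttentionMask

decoder_input_ids = tf.constant([[5, 6]])         # (batch=1, from_seq_length=2)
encoder_input_mask = tf.constant([[1, 1, 1, 0]])  # (batch=1, to_seq_length=4)

mask = CrossAttentionMask()([decoder_input_ids, encoder_input_mask])
# mask shape: (1, 2, 4); every decoder position sees [1., 1., 1., 0.]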
--------------------------------------------------------------------------------
/src/tf_transformers/layers/mask/self_attention_mask.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors and The TensorFlow Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | """Keras layer that creates a self-attention mask."""
18 |
19 | # from __future__ import google_type_annotations
20 | from __future__ import absolute_import, division, print_function
21 |
22 | import tensorflow as tf
23 |
24 | from tf_transformers.utils import tf_utils
25 |
26 |
27 | @tf.keras.utils.register_keras_serializable(package="Text")
28 | class SelfAttentionMask(tf.keras.layers.Layer):
29 | """Create 3D attention mask from a 2D tensor mask.
30 |
31 | inputs[0]: from_tensor: 2D or 3D Tensor of shape
32 | [batch_size, from_seq_length, ...].
33 | inputs[1]: to_mask: int32 Tensor of shape [batch_size, to_seq_length].
34 |
35 | Returns:
36 | float Tensor of shape [batch_size, from_seq_length, to_seq_length].
37 | """
38 |
39 | def call(self, inputs):
40 |         """Build the 3D self-attention mask.
41 | 
42 |         Args:
43 |             inputs: list of [embeddings, input_mask]
44 |                 embeddings: 3D (b x s x h)
45 |                 input_mask: 2D (b x s)
46 |         Returns:
47 |             Tensor: (b x from_seq_length x to_seq_length)
48 |         """
49 | from_tensor = inputs[0]
50 | to_mask = inputs[1]
51 | from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
52 | batch_size = from_shape[0]
53 | from_seq_length = from_shape[1]
54 |
55 | to_shape = tf_utils.get_shape_list(to_mask, expected_rank=2)
56 | to_seq_length = to_shape[1]
57 |
58 | to_mask = tf.cast(tf.reshape(to_mask, [batch_size, 1, to_seq_length]), dtype=from_tensor.dtype)
59 |
60 |         # We don't assume that `from_tensor` is a mask (although it could be). We
61 |         # don't actually care if we attend *from* padding tokens (only *to* padding
62 |         # tokens), so we create a tensor of all ones.
63 | #
64 | # `broadcast_ones` = [batch_size, from_seq_length, 1]
65 | broadcast_ones = tf.ones(shape=[batch_size, from_seq_length, 1], dtype=from_tensor.dtype)
66 |
67 | # Here we broadcast along two dimensions to create the mask.
68 | mask = broadcast_ones * to_mask
69 |
70 | return mask
71 |
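And the padding-aware self-attention variant, with the last of four positions padded; shapes are illustrative:

import tensorflow as tf
from tf_transformers.layers.mask import SelfAttentionMask

embeddings = tf.zeros([1, 4, 8])          # (batch, seq, hidden)
input_mask = tf.constant([[1, 1, 1, 0]])  # (batch, seq)

mask = SelfAttentionMask()([embeddings, input_mask])
# mask shape: (1, 4, 4); every row of mask[0] equals [1., 1., 1., 0.]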
--------------------------------------------------------------------------------
/src/tf_transformers/layers/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.layers.transformer.bart_transformer import TransformerBART
18 | from tf_transformers.layers.transformer.bert_transformer import TransformerBERT
19 | from tf_transformers.layers.transformer.clip_transformer import TransformerCLIP
20 | from tf_transformers.layers.transformer.gpt2_transformer import TransformerGPT2
21 | from tf_transformers.layers.transformer.mt5_transformer import TransformermT5
22 | from tf_transformers.layers.transformer.t5_transformer import TransformerT5
23 | from tf_transformers.layers.transformer.byt5_transformer import TransformerByT5
24 | from tf_transformers.layers.transformer.vit_transformer import TransformerVIT
25 |
--------------------------------------------------------------------------------
/src/tf_transformers/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.losses.cross_entropy import (
18 | cross_entropy_loss,
19 | cross_entropy_loss_for_classification,
20 | cross_entropy_loss_label_smoothing,
21 | )
22 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.albert import (
18 | AlbertConfig,
19 | AlbertEncoder,
20 | AlbertModel,
21 | AlbertTokenizerLayer,
22 | AlbertTokenizerTFText,
23 | )
24 | from tf_transformers.models.bart import BartConfig, BartEncoder, BartModel
25 | from tf_transformers.models.bert import BertConfig, BertEncoder, BertModel
26 | from tf_transformers.models.bigbird import (
27 | BigBirdRobertaTokenizerLayer,
28 | BigBirdRobertaTokenizerTFText,
29 | )
30 | from tf_transformers.models.clip import (
31 | CLIPEncoder,
32 | CLIPFeatureExtractorTF,
33 | CLIPImageConfig,
34 | CLIPImageEncoder,
35 | CLIPModel,
36 | CLIPTextConfig,
37 | CLIPTextEncoder,
38 | )
39 | from tf_transformers.models.distilbert import DistilBertConfig, DistilBertModel
40 | from tf_transformers.models.encoder_decoder import EncoderDecoder
41 | from tf_transformers.models.gpt2 import GPT2Config, GPT2Encoder, GPT2Model
42 | from tf_transformers.models.minilm import MiniLMConfig, MiniLMModel
43 | from tf_transformers.models.mt5 import MT5Config, MT5Encoder, MT5Model
44 | from tf_transformers.models.roberta import RobertaConfig, RobertaEncoder, RobertaModel
45 | from tf_transformers.models.sentence_transformers import SentenceTransformer
46 | from tf_transformers.models.t5 import (
47 | T5Config,
48 | T5Encoder,
49 | T5Model,
50 | T5TokenizerLayer,
51 | T5TokenizerTFText,
52 | )
53 | from tf_transformers.models.byt5 import (
54 | ByT5Config,
55 | ByT5Encoder,
56 | ByT5Model,
57 | )
58 | from tf_transformers.models.tasks import (
59 | Classification_Model,
60 | MaskedLMModel,
61 | Similarity_Model,
62 | Span_Selection_Model,
63 | )
64 | from tf_transformers.models.vit import (
65 | ViTConfig,
66 | ViTEncoder,
67 | ViTFeatureExtractorTF,
68 | ViTModel,
69 | )
70 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/albert/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.albert.albert import AlbertEncoder
18 | from tf_transformers.models.albert.albert_model import AlbertModel
19 | from tf_transformers.models.albert.configuration_albert import AlbertConfig
20 | from tf_transformers.models.albert.tokenizer_albert import (
21 | AlbertTokenizerLayer,
22 | AlbertTokenizerTFText,
23 | )
24 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/bart/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.bart.bart import BartEncoder
18 | from tf_transformers.models.bart.bart_model import BartModel
19 | from tf_transformers.models.bart.configuration_bart import BartConfig
20 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/bert/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.bert.bert import BertEncoder
18 | from tf_transformers.models.bert.bert_model import BertModel
19 | from tf_transformers.models.bert.configuration_bert import BertConfig
20 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/bigbird/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.bigbird.tokenizer_bigbird_roberta import (
18 | BigBirdRobertaTokenizerLayer,
19 | BigBirdRobertaTokenizerTFText,
20 | )
21 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/byt5/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.byt5.configuration_byt5 import ByT5Config
18 | from tf_transformers.models.byt5.byt5 import ByT5Encoder
19 | from tf_transformers.models.byt5.byt5_model import ByT5Model
20 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/clip/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.clip.clip import CLIPEncoder
18 | from tf_transformers.models.clip.clip_image_encoder import CLIPImageEncoder
19 | from tf_transformers.models.clip.clip_text_encoder import CLIPTextEncoder
20 | from tf_transformers.models.clip.configuration_clip import (
21 | CLIPImageConfig,
22 | CLIPTextConfig,
23 | )
24 | from tf_transformers.models.clip.clip_feature_extractor import CLIPFeatureExtractorTF
25 | from tf_transformers.models.clip.clip_model import CLIPModel
26 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/distilbert/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.bert.bert import BertEncoder
18 | from tf_transformers.models.distilbert.configuration_bert import DistilBertConfig
19 | from tf_transformers.models.distilbert.distilbert_model import DistilBertModel
20 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/encoder_decoder/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.encoder_decoder.encoder_decoder import EncoderDecoder
18 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/gpt2/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.gpt2.configuration_gpt2 import GPT2Config
18 | from tf_transformers.models.gpt2.gpt2 import GPT2Encoder
19 | from tf_transformers.models.gpt2.gpt2_model import GPT2Model
20 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/minilm/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.bert.bert import BertEncoder
18 | from tf_transformers.models.minilm.configuration_minilm import MiniLMConfig
19 | from tf_transformers.models.minilm.minilm_model import MiniLMModel
20 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/model_configs/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.model_configs.albert import albert_base_v2, albert_large_v2
18 | from tf_transformers.models.model_configs.bert import (
19 | bert_base_cased,
20 | bert_base_uncased,
21 | bert_large_cased,
22 | bert_large_uncased,
23 | )
24 | from tf_transformers.models.model_configs.general_config import TransformerConfig
25 | from tf_transformers.models.model_configs.gpt2 import gpt2, gpt2_medium
26 | from tf_transformers.models.model_configs.mt5 import mt5_small
27 | from tf_transformers.models.model_configs.roberta import roberta_base, roberta_large
28 | from tf_transformers.models.model_configs.t5 import t5_base, t5_small
29 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/model_configs/albert/albert_base_v2.py:
--------------------------------------------------------------------------------
1 | config = {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "intermediate_act": "gelu",
5 | "hidden_dropout_prob": 0.1,
6 | "embedding_size": 128,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 3072,
9 | "max_position_embeddings": 512,
10 | "num_attention_heads": 12,
11 | "attention_head_size": 64,
12 | "num_hidden_layers": 12,
13 | "type_vocab_size": 2,
14 | "vocab_size": 30000,
15 | "layer_norm_epsilon": 1e-12,
16 | "embedding_projection_size": 768,
17 | }
18 |
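A quick sanity check of how the numbers in this config relate, illustrating ALBERT's factorized embeddings; this snippet is an illustration, not code from the repository:

config = {
    "embedding_size": 128,
    "embedding_projection_size": 768,
    "num_attention_heads": 12,
    "attention_head_size": 64,
    "intermediate_size": 3072,
}

# Token embeddings are kept small (128) and projected up to the hidden size,
# which equals num_attention_heads * attention_head_size.
hidden_size = config["num_attention_heads"] * config["attention_head_size"]
assert hidden_size == config["embedding_projection_size"] == 768
assert config["intermediate_size"] == 4 * hidden_size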
--------------------------------------------------------------------------------
/src/tf_transformers/models/model_configs/albert/albert_large_v2.py:
--------------------------------------------------------------------------------
1 | config = {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "intermediate_act": "gelu",
5 | "hidden_dropout_prob": 0.1,
6 | "embedding_size": 128,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 3072,
9 | "max_position_embeddings": 512,
10 | "num_attention_heads": 16,
11 | "attention_head_size": 64,
12 | "num_hidden_layers": 24,
13 | "type_vocab_size": 2,
14 | "vocab_size": 30000,
15 | "layer_norm_epsilon": 1e-12,
16 | "embedding_projection_size": 1024,
17 | }
18 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/model_configs/bert/bert_base_cased.py:
--------------------------------------------------------------------------------
1 | config = {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "intermediate_act": "gelu",
5 | "hidden_dropout_prob": 0.1,
6 | "embedding_size": 768,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 3072,
9 | "max_position_embeddings": 512,
10 | "num_attention_heads": 12,
11 | "attention_head_size": 64,
12 | "num_hidden_layers": 12,
13 | "type_vocab_size": 2,
14 | "vocab_size": 28996,
15 | "layer_norm_epsilon": 1e-12,
16 | }
17 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/model_configs/bert/bert_base_uncased.py:
--------------------------------------------------------------------------------
1 | config = {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "intermediate_act": "gelu",
5 | "hidden_dropout_prob": 0.1,
6 | "embedding_size": 768,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 3072,
9 | "max_position_embeddings": 512,
10 | "num_attention_heads": 12,
11 | "num_hidden_layers": 12,
12 | "attention_head_size": 64,
13 | "type_vocab_size": 2,
14 | "vocab_size": 30522,
15 | "layer_norm_epsilon": 1e-12,
16 | }
17 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/model_configs/bert/bert_large_cased.py:
--------------------------------------------------------------------------------
1 | config = {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "intermediate_act": "gelu",
5 | "hidden_dropout_prob": 0.1,
6 | "embedding_size": 1024,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 4096,
9 | "max_position_embeddings": 512,
10 | "num_attention_heads": 16,
11 | "attention_head_size": 64,
12 | "num_hidden_layers": 24,
13 | "type_vocab_size": 2,
14 | "vocab_size": 28996,
15 | "layer_norm_epsilon": 1e-12,
16 | }
17 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/model_configs/bert/bert_large_uncased.py:
--------------------------------------------------------------------------------
1 | config = {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "intermediate_act": "gelu",
5 | "hidden_dropout_prob": 0.1,
6 | "embedding_size": 1024,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 4096,
9 | "max_position_embeddings": 512,
10 | "num_attention_heads": 16,
11 | "attention_head_size": 64,
12 | "num_hidden_layers": 24,
13 | "type_vocab_size": 2,
14 | "vocab_size": 30522,
15 | "layer_norm_epsilon": 1e-12,
16 | }
17 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/model_configs/gpt2/gpt2.py:
--------------------------------------------------------------------------------
1 | config = {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "intermediate_act": "gelu",
5 | "hidden_dropout_prob": 0.1,
6 | "embedding_size": 768,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 3072,
9 | "max_position_embeddings": 1024,
10 | "num_attention_heads": 12,
11 | "attention_head_size": 64,
12 | "num_hidden_layers": 12,
13 | "type_vocab_size": -1,
14 | "vocab_size": 50257,
15 | "layer_norm_epsilon": 1e-05,
16 | "mask_mode": "causal",
17 | }
18 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/model_configs/gpt2/gpt2_medium.py:
--------------------------------------------------------------------------------
1 | config = {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "intermediate_act": "gelu",
5 | "hidden_dropout_prob": 0.1,
6 | "embedding_size": 1024,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 4096,
9 | "max_position_embeddings": 1024,
10 | "num_attention_heads": 16,
11 | "attention_head_size": 64,
12 | "num_hidden_layers": 24,
13 | "type_vocab_size": -1,
14 | "vocab_size": 50257,
15 | "layer_norm_epsilon": 1e-05,
16 | "mask_mode": "causal",
17 | }
18 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/model_configs/mt5/mt5_small.py:
--------------------------------------------------------------------------------
1 | config = {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "intermediate_act": "gelu",
5 | "hidden_dropout_prob": 0.1,
6 | "embedding_size": 512,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 1024,
9 | "max_position_embeddings": -1,
10 | "num_attention_heads": 6,
11 | "attention_head_size": 64,
12 | "num_hidden_layers": 8,
13 | "vocab_size": 250112,
14 | "type_vocab_size": -1,
15 | "layer_norm_epsilon": 1e-06,
16 | "bidirectional": True,
17 | "positional_buckets": 32,
18 | }
19 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/model_configs/roberta/roberta_base.py:
--------------------------------------------------------------------------------
1 | config = {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "intermediate_act": "gelu",
5 | "hidden_dropout_prob": 0.1,
6 | "embedding_size": 768,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 3072,
9 | "max_position_embeddings": 512,
10 | "num_attention_heads": 12,
11 | "attention_head_size": 64,
12 | "num_hidden_layers": 12,
13 | "type_vocab_size": 1,
14 | "vocab_size": 50265,
15 | "layer_norm_epsilon": 1e-05,
16 | }
17 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/model_configs/roberta/roberta_large.py:
--------------------------------------------------------------------------------
1 | config = {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "intermediate_act": "gelu",
5 | "hidden_dropout_prob": 0.1,
6 | "embedding_size": 1024,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 3072,
9 | "max_position_embeddings": 512,
10 | "num_attention_heads": 16,
11 | "attention_head_size": 64,
12 | "num_hidden_layers": 24,
13 | "type_vocab_size": 1,
14 | "vocab_size": 50265,
15 | "layer_norm_epsilon": 1e-05,
16 | }
17 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/model_configs/t5/t5_base.py:
--------------------------------------------------------------------------------
1 | config = {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "intermediate_act": "relu",
5 | "hidden_dropout_prob": 0.1,
6 | "embedding_size": 768,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 3072,
9 | "max_position_embeddings": -1,
10 | "num_attention_heads": 12,
11 | "attention_head_size": 64,
12 | "num_hidden_layers": 12,
13 | "vocab_size": 32128,
14 | "type_vocab_size": -1,
15 | "layer_norm_epsilon": 1e-06,
16 | "bidirectional": True,
17 | "positional_buckets": 32,
18 | }
19 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/model_configs/t5/t5_small.py:
--------------------------------------------------------------------------------
1 | config = {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "intermediate_act": "relu",
5 | "hidden_dropout_prob": 0.1,
6 | "embedding_size": 512,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 2048,
9 | "max_position_embeddings": -1,
10 | "num_attention_heads": 8,
11 | "attention_head_size": 64,
12 | "num_hidden_layers": 6,
13 | "vocab_size": 32128,
14 | "type_vocab_size": -1,
15 | "layer_norm_epsilon": 1e-06,
16 | "bidirectional": True,
17 | "positional_buckets": 32,
18 | }
19 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/model_configs/unilm_cnndm/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "intermediate_act": "gelu",
5 | "hidden_dropout_prob": 0.1,
6 | "embedding_size": 1024,
7 | "initializer_range": 0.02,
8 | "intermediate_size": 4096,
9 | "max_position_embeddings": 768,
10 | "num_attention_heads": 16,
11 | "num_hidden_layers": 24,
12 | "type_vocab_size": 6,
13 | "vocab_size": 28996,
14 | "layer_norm_epsilon": 1e-05
15 | }
16 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/mt5/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.mt5.configuration_mt5 import MT5Config
18 | from tf_transformers.models.mt5.mt5 import MT5Encoder
19 | from tf_transformers.models.mt5.mt5_model import MT5Model
20 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/roberta/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.roberta.configuration_roberta import RobertaConfig
18 | from tf_transformers.models.roberta.roberta import RobertaEncoder
19 | from tf_transformers.models.roberta.roberta_model import RobertaModel
20 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/sentence_transformers/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.sentence_transformers.sentence_transformers import (
18 | SentenceTransformer,
19 | )
20 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/t5/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.t5.configuration_t5 import T5Config
18 | from tf_transformers.models.t5.t5 import T5Encoder
19 | from tf_transformers.models.t5.t5_model import T5Model
20 | from tf_transformers.models.t5.tokenizer_t5 import T5TokenizerLayer, T5TokenizerTFText
21 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.tasks.classification import Classification_Model
18 | from tf_transformers.models.tasks.maked_lm_model import MaskedLMModel
19 | from tf_transformers.models.tasks.similarity_model import Similarity_Model
20 | from tf_transformers.models.tasks.span_selection import Span_Selection_Model
21 |
--------------------------------------------------------------------------------
/src/tf_transformers/models/vit/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.models.vit.configuration_vit import ViTConfig
18 | from tf_transformers.models.vit.vit import ViTEncoder
19 | from tf_transformers.models.vit.vit_feature_extractor import ViTFeatureExtractorTF
20 | from tf_transformers.models.vit.vit_model import ViTModel
21 |
--------------------------------------------------------------------------------
/src/tf_transformers/optimization/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.optimization.optimization import create_optimizer
18 |
--------------------------------------------------------------------------------
/src/tf_transformers/text/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.text.decoder_utils import (
18 | _gather_beams,
19 | _log_prob_from_logits,
20 | assign_zeros_to_K_V,
21 | top_k_logits,
22 | top_p_logits,
23 | )
24 | from tf_transformers.text.sentencepiece_layer import SentencepieceTokenizer
25 | from tf_transformers.text.text_decoder import TextDecoder, TextDecoderSerializable
26 | from tf_transformers.text.text_decoder_model import TextDecoderModel
27 | from tf_transformers.text.text_decoder_seq2seq import TextDecoderSeq2Seq
28 | from tf_transformers.text.text_decoder_seq2seq_serializable import (
29 | TextDecoderSerializableSeq2Seq,
30 | )
31 |
--------------------------------------------------------------------------------
/src/tf_transformers/text/lm_tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from tf_transformers.text.lm_tasks.causal_lm import causal_lm_fn
2 | from tf_transformers.text.lm_tasks.masked_lm import mlm_fn
3 | from tf_transformers.text.lm_tasks.prefix_lm import prefix_lm_fn, prefix_lm_fn_v2
4 |
--------------------------------------------------------------------------------
/src/tf_transformers/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 TF-Transformers Authors.
3 | # All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | from tf_transformers.utils.fast_sp_alignment import fast_sp_alignment
18 | from tf_transformers.utils.tokenization import BasicTokenizer
19 | from tf_transformers.utils.utils import (
20 | get_config,
21 | get_model_wrapper,
22 | validate_model_name,
23 | )
24 |
--------------------------------------------------------------------------------
/src/tf_transformers/utils/docstring_file_utils.py:
--------------------------------------------------------------------------------
1 | def add_start_docstrings(*docstr):
2 | def docstring_decorator(fn):
3 | fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
4 | return fn
5 |
6 | return docstring_decorator
7 |
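 8 | # Usage sketch (illustrative only): the decorator prepends shared text to a
 9 | # function's own docstring.
10 | #
11 | #   @add_start_docstrings("Shared intro line.\n")
12 | #   def forward(inputs):
13 | #       """Function-specific details."""
14 | #
15 | #   forward.__doc__ == "Shared intro line.\nFunction-specific details."
16 |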
--------------------------------------------------------------------------------
/src/tf_transformers/utils/positional_bias_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import tensorflow as tf
4 |
5 |
6 | def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
7 | """
8 | Adapted from Mesh Tensorflow:
9 | https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/\
10 | mesh_tensorflow/transformer/transformer_layers.py#L593
11 |
12 | Translate relative position to a bucket number for relative attention.
13 | The relative position is defined as memory_position - query_position, i.e.
14 | the distance in tokens from the attending position to the attended-to
15 | position. If bidirectional=False, then positive relative positions are
16 | invalid.
17 | We use smaller buckets for small absolute relative_position and larger buckets
18 | for larger absolute relative_positions. All relative positions >=max_distance
19 | map to the same bucket. All relative positions <=-max_distance map to the
20 | same bucket. This should allow for more graceful generalization to longer
21 | sequences than the model has been trained on.
22 | Args:
23 | relative_position: an int32 Tensor
24 | bidirectional: a boolean - whether the attention is bidirectional
25 | num_buckets: an integer
26 | max_distance: an integer
27 | Returns:
28 | a Tensor with the same shape as relative_position, containing int32
29 | values in the range [0, num_buckets)
30 | """
31 | ret = 0
32 | n = -relative_position
33 | if bidirectional:
34 | num_buckets //= 2
35 | ret += tf.dtypes.cast(tf.math.less(n, 0), tf.int32) * num_buckets
36 | n = tf.math.abs(n)
37 | else:
38 | n = tf.math.maximum(n, 0)
39 | # now n is in the range [0, inf)
40 | max_exact = num_buckets // 2
41 | is_small = tf.math.less(n, max_exact)
42 | val_if_large = max_exact + tf.dtypes.cast(
43 | tf.math.log(tf.dtypes.cast(n, tf.float32) / max_exact)
44 | / math.log(max_distance / max_exact)
45 | * (num_buckets - max_exact),
46 | tf.int32,
47 | )
48 | val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
49 | ret += tf.where(is_small, n, val_if_large)
50 | return ret
51 |
52 |
53 | def compute_positional_bias(qlen, klen, bidirectional=True, num_buckets=32):
54 | """Compute binned relative position bias"""
55 | context_position = tf.range(qlen)[:, None]
56 | memory_position = tf.range(klen)[None, :]
57 | relative_position = memory_position - context_position # shape (qlen, klen)
58 | rp_bucket = _relative_position_bucket(
59 | relative_position,
60 | bidirectional=bidirectional,
61 | num_buckets=num_buckets,
62 | )
63 | return rp_bucket
64 |
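65 | # Usage sketch (illustrative, using the defaults above):
66 | #
67 | #   rp_bucket = compute_positional_bias(qlen=4, klen=4)
68 | #   # -> int32 tensor of shape (4, 4); row 0 is [0, 17, 18, 19]:
69 | #   #    positions at or behind the query map to buckets [0, num_buckets // 2),
70 | #   #    positions ahead of the query map to buckets [num_buckets // 2, num_buckets).
71 | #
72 | # In T5-style attention these bucket ids index a (num_buckets, num_heads)
73 | # embedding table to produce the additive relative-position bias.
74 |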
--------------------------------------------------------------------------------
/src/tf_transformers/utils/push_to_hub.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import time
4 | from distutils.dir_util import copy_tree
5 |
6 | from absl import logging
7 |
8 | from tf_transformers.models import ViTModel
9 |
10 | logging.set_verbosity("INFO")
11 |
12 | # Load the model (plain encoder, then with a classification head) so that the
13 | # checkpoint is written to the local tf_transformers cache copied below.
14 | model_name = 'vit-large-patch32-384'
15 | model = ViTModel.from_pretrained(model_name)
16 | model = ViTModel.from_pretrained(model_name, classification_labels=1000)
17 |
18 |
19 | models_list = [
20 | 'vit-base-patch16-224',
21 | 'vit-base-patch32-384',
22 | 'vit-base-patch32-224-in21k',
23 | 'vit-large-patch16-224',
24 | 'vit-large-patch32-384',
25 | ]
26 | cwd = os.getcwd()
27 | MODEL_DIR = '/home/sarathrnair/MODELS/'
28 |
29 | for model_name in models_list:
30 |
31 | subprocess.run(
32 | ['huggingface-cli', 'repo', 'create', '{}'.format(model_name), '--yes', '--organization', 'tftransformers']
33 | )
34 |
35 | subprocess.run(["git", "clone", "https://huggingface.co/tftransformers/{}".format(model_name)])
36 | new_working_dir = os.path.join(cwd, model_name)
37 | os.chdir("{}".format(new_working_dir))
38 | cached_model_dir = os.path.join(MODEL_DIR, "tf_transformers_cache/{}/".format(model_name))
39 | copy_tree(cached_model_dir, new_working_dir)
40 |
41 | subprocess.run(["git-lfs", "track", "*"])
42 | subprocess.run(["git", "add", "."])
43 | subprocess.run(["git", "commit", "-m", "New Model"])
44 | subprocess.run(["git", "push"])
45 |
46 | os.chdir("{}".format(cwd))
47 | time.sleep(2)
48 | logging.info("Completed {}".format(model_name))
49 | print("------------------------------------------------------------------")
50 |
--------------------------------------------------------------------------------
/src/tf_transformers/utils/utils.py:
--------------------------------------------------------------------------------
1 | def get_config(module_name, config_name):
 2 |     """Load a config from a .py module using importlib.
 3 |
 4 |     Args:
 5 |         module_name (str): Importable module path that holds the config object.
 6 |         config_name (str): Name of the attribute whose ``config`` is returned.
 7 |
 8 |     Returns:
 9 |         The ``config`` attribute of the resolved object.
10 |
11 |     """
12 | import importlib
13 |
14 | my_module = importlib.import_module(module_name)
15 | config = getattr(my_module, config_name).config
16 | return config
17 |
18 |
19 | def get_model_wrapper(model_name):
20 |
21 | import importlib
22 |
23 | model_name = model_name.split("_")[0].strip() # roberta_base --> roberta
24 | model_cls = importlib.import_module("tf_transformers.models.model_wrappers.{}_wrapper".format(model_name))
25 | return model_cls.modelWrapper
26 |
27 |
28 | def validate_model_name(model_name, allowed_model_names):
29 |     """Validate that model_name is one of the allowed names.
30 |
31 |     Args:
32 |         model_name (str): Name of the model to validate.
33 |         allowed_model_names (list): Collection of accepted model names.
34 |     Raises:
35 |         ValueError: If model_name is not in allowed_model_names.
36 |     """
37 | if model_name not in allowed_model_names:
38 | raise ValueError("{} not in allowed names {}".format(model_name, allowed_model_names))
39 |
40 |
41 | # This is unused
42 | # def pytorch_conversion_debug():
43 | # # Check pytorch conversion wiith TF conversion
44 | # import numpy as np
45 | # for index , var in enumerate(model.variables):
46 |
47 | # var2 = model_tf.variables[index]
48 |
49 | # var_shape = list(var.shape)
50 | # var2_shape = list(var2.shape)
51 |
52 | # assert(var_shape == var2_shape)
53 |
54 | # if len(var_shape) == 1:
55 | # var_sum = tf.reduce_sum(var).numpy()
56 | # var2_sum = tf.reduce_sum(var2).numpy()
57 | # assert(np.allclose(var_sum, var2_sum) == True)
58 |
59 | # if len(var_shape) == 2:
60 | # var_sum = tf.reduce_sum(var, axis=-1).numpy()
61 | # var2_sum = tf.reduce_sum(var2, axis=-1).numpy()
62 | # assert(np.allclose(var_sum, var2_sum) == True)
63 |
64 | # if len(var_shape) == 3:
65 | # var_sum = tf.reduce_sum(var, axis=[0, 2]).numpy()
66 | # var2_sum = tf.reduce_sum(var2, axis=[0, 2]).numpy()
67 | # assert(np.allclose(var_sum, var2_sum) == True)
68 |
--------------------------------------------------------------------------------
/src/tf_transformers/utils/viz_utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import seaborn as sns
4 |
5 |
6 | def compute_scores(vectors):
7 | corr = np.inner(vectors, vectors)
8 | cmax = np.max(corr)
9 | corr /= cmax
10 | return corr
11 |
12 |
13 | def plot_similarity(labels, features1, features2, rotation, title1="Model1", title2="Model2"):
14 |
15 | corr1 = compute_scores(features1)
16 | corr2 = compute_scores(features2)
17 | sns.set(rc={"axes.facecolor": "white", "figure.facecolor": "white"})
18 | sns.set_context("poster")
19 |
20 | fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 8))
21 | fig.subplots_adjust(wspace=0.02)
22 |
23 | sns.set(font_scale=1.0)
24 | g1 = sns.heatmap(
25 | corr1,
26 | ax=ax1,
27 | cbar=False,
28 | yticklabels=labels,
29 | xticklabels=labels,
30 | vmin=np.min(corr1),
31 | vmax=np.max(corr1),
32 | cmap="Blues",
33 | )
34 |
35 | g2 = sns.heatmap(
36 | corr2,
37 | ax=ax2,
38 | cbar=False,
39 | xticklabels=labels,
40 | vmin=np.min(corr2),
41 | vmax=np.max(corr2),
42 | cmap="Blues",
43 | )
44 | g1.set_xticklabels(labels, rotation=rotation)
45 | g2.set_xticklabels(labels, rotation=rotation)
46 | g2.set(yticks=[])
47 | fig.colorbar(ax1.collections[0], ax=ax1, location="right", use_gridspec=False, pad=0.01)
48 | fig.colorbar(ax2.collections[0], ax=ax2, location="right", use_gridspec=False, pad=0.01)
49 |
50 | g1.set_title(title1)
51 | g2.set_title(title2)
52 |
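53 | # Usage sketch (illustrative): compare sentence embeddings from two models.
54 | #   labels = ["hello", "hi", "goodbye"]
55 | #   plot_similarity(labels, embeddings_model1, embeddings_model2, rotation=90)
56 | #   # embeddings_model1 / embeddings_model2: hypothetical (n, dim) NumPy arrays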
--------------------------------------------------------------------------------
/src/transformers_blue.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/src/transformers_blue.png
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/tests/__init__.py
--------------------------------------------------------------------------------
/tests/model_test_scripts/test_wav2vec2.py:
--------------------------------------------------------------------------------
 1 | import librosa
 2 | import numpy as np
 3 | import torch
 4 | from scipy.io import wavfile
 5 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 6 |
 7 | # Load model and processor
 8 | processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-robust-ft-swbd-300h")
 9 | model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-robust-ft-swbd-300h")
10 |
11 | # Inspect the raw audio file
12 | file_name = 'sample.wav'
13 | framerate, sounddata = wavfile.read(file_name)
14 | time = np.arange(0, len(sounddata)) / framerate
15 | print('Sample rate:', framerate, 'Hz')
16 | print('Total time:', len(sounddata) / framerate, 's')
17 |
18 | # Load the audio at 16 kHz, the sampling rate Wav2Vec2 expects
19 | input_audio, _ = librosa.load(file_name, sr=16000)
20 |
21 | input_values = processor(input_audio, sampling_rate=16000, return_tensors="pt").input_values  # torch.Size([1, 3270299])
22 | logits = model(input_values).logits  # torch.Size([1, 10219, 32])
23 | predicted_ids = torch.argmax(logits, dim=-1)  # torch.Size([1, 10219])
24 | transcription = processor.batch_decode(predicted_ids)[0]
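25 |
26 | # Show the decoded transcription
27 | print(transcription)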
--------------------------------------------------------------------------------
/tests/test_tf_transformers.py:
--------------------------------------------------------------------------------
1 | from tf_transformers import __version__
2 |
3 | version = "2.0.0"
4 |
5 |
6 | def test_version():
7 |     assert __version__ == version
8 |
--------------------------------------------------------------------------------
/tutorials/README.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "4193b067",
6 | "metadata": {},
7 | "source": [
8 | "We use jupytext to keep copies of the notebooks in sync with their Markdown equivalents.\n",
9 | "\n",
10 | "### Adding a new notebook\n",
11 | "\n",
12 | "```\n",
13 | "jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb\n",
14 | "```\n",
15 | "\n",
16 | "### Syncing Notebooks\n",
17 | "\n",
18 | "After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running:\n",
19 | "\n",
20 | "```\n",
21 | "jupytext --sync docs/notebooks/*\n",
22 | "```"
23 | ]
24 | }
25 | ],
26 | "metadata": {
27 | "jupytext": {
28 | "cell_metadata_filter": "-all",
29 | "formats": "ipynb,md:myst",
30 | "main_language": "python"
31 | }
32 | },
33 | "nbformat": 4,
34 | "nbformat_minor": 5
35 | }
36 |
--------------------------------------------------------------------------------
/tutorials/README.md:
--------------------------------------------------------------------------------
1 | ---
2 | jupytext:
3 | cell_metadata_filter: -all
4 | formats: ipynb,md:myst
5 | main_language: python
6 | text_representation:
7 | extension: .md
8 | format_name: myst
9 | format_version: 0.13
10 | jupytext_version: 1.14.4
11 | ---
12 |
13 | We use jupytext to keep copies of the notebooks in sync with their Markdown equivalents.
14 |
15 | ### Adding a new notebook
16 |
17 | ```
18 | jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb
19 | ```
20 |
21 | ### Syncing Notebooks
22 |
23 | After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running:
24 |
25 | ```
26 | jupytext --sync docs/notebooks/*
27 | ```
28 |
--------------------------------------------------------------------------------
/tutorials/push_model_to_hf_hub.md:
--------------------------------------------------------------------------------
1 | ---
2 | jupytext:
3 | formats: ipynb,md:myst
4 | text_representation:
5 | extension: .md
6 | format_name: myst
7 | format_version: 0.13
8 | jupytext_version: 1.14.4
9 | kernelspec:
10 | display_name: Python 3 (ipykernel)
11 | language: python
12 | name: python3
13 | ---
14 |
15 | ```{code-cell} ipython3
16 |
17 | ```
18 |
19 | ```{code-cell} ipython3
20 | import subprocess
21 | import os
22 | from distutils.dir_util import copy_tree
23 | ```
24 |
25 | ### How to push a model to the hub
26 |
27 | * Make sure you have logged in with ```huggingface-cli login``` using your token.
28 |
29 | ```{code-cell} ipython3
30 |
31 | ```
32 |
33 | ```{code-cell} ipython3
34 | model_name = 'byt5-small'
35 | ```
36 |
37 | #### 1. Create the model repo under the organization name
38 |
39 | ```{code-cell} ipython3
40 | subprocess.run(['huggingface-cli', 'repo',
41 | 'create', '{}'.format(model_name),
42 | '--yes',
43 | '--organization', 'tftransformers'])
44 | ```
45 |
46 | ```{code-cell} ipython3
47 |
48 | ```
49 |
50 | #### 2. Clone the newly created repo into the local working directory
51 |
52 | ```{code-cell} ipython3
53 | subprocess.run(["git", "clone", "https://huggingface.co/tftransformers/{}".format(model_name)])
54 | ```
55 |
56 | ```{code-cell} ipython3
57 |
58 | ```
59 |
60 | #### 3. Copy the cached model into the ```model_name``` directory inside the current working directory
61 |
62 | ```{code-cell} ipython3
63 | cwd = os.getcwd() # Get the current working directory
64 | new_working_dir = os.path.join(cwd, model_name) # The repo cloned from the HF hub under the organization
65 | os.chdir("{}".format(new_working_dir)) # Switch to the new working directory
66 |
67 | # The cached model directory differs from machine to machine
68 | cached_model_dir = '/var/folders/vq/4fxns8l55gq8_msgygbyb51h0000gn/T/tf_transformers_cache/{}/'.format(model_name)
69 |
70 | # Copy the cached model directory to the new working directory
71 | copy_tree(cached_model_dir, new_working_dir)
72 | ```
73 |
74 | ```{code-cell} ipython3
75 |
76 | ```
77 |
78 | #### 4. Push the model to the hub
79 |
80 | ```{code-cell} ipython3
81 | subprocess.run(["git-lfs", "track", "*"])
82 | subprocess.run(["git", "add", "."])
83 | subprocess.run(["git", "commit", "-m", "Pushing new model {}".format(model_name)]) # Commit message
84 | subprocess.run(["git", "push"])
85 |
86 | # Change back to original cwd
87 | os.chdir("{}".format(cwd))
88 | ```
89 |
90 | ```{code-cell} ipython3
91 |
92 | ```
93 |
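94 | Once the push finishes, the uploaded files can be checked at the repo page, which lives at
95 | the same address used for cloning above:
96 |
97 | ```{code-cell} ipython3
98 | print("https://huggingface.co/tftransformers/{}".format(model_name))
99 | ```
100 |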
--------------------------------------------------------------------------------