├── .flake8 ├── .github └── workflows │ ├── build_documentation.yaml │ ├── ci_cd.yaml │ ├── patch_version.yaml │ └── release.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── benchmarks ├── albert │ ├── README.MD │ ├── benchmark_hf.py │ ├── benchmark_tft.py │ ├── conf │ │ ├── benchmark │ │ │ ├── hf.yaml │ │ │ └── tft.yaml │ │ └── config.yaml │ └── run.py ├── gpt2 │ ├── README.MD │ ├── benchmark_hf.py │ ├── benchmark_tft.py │ ├── conf │ │ ├── benchmark │ │ │ ├── hf.yaml │ │ │ └── tft.yaml │ │ └── config.yaml │ └── run.py ├── imagenet_clip_benchmark.ipynb ├── imagenet_clip_benchmark.md ├── roberta │ ├── README.MD │ ├── __init__.py │ ├── benchmark_hf.py │ ├── benchmark_tft.py │ ├── conf │ │ ├── benchmark │ │ │ ├── hf.yaml │ │ │ └── tft.yaml │ │ └── config.yaml │ └── run.py ├── t5 │ ├── README.MD │ ├── benchmark_hf.py │ ├── benchmark_tft.py │ ├── conf │ │ ├── benchmark │ │ │ ├── hf.yaml │ │ │ └── tft.yaml │ │ └── config.yaml │ └── run.py └── vit │ ├── README.MD │ ├── benchmark_hf.py │ ├── benchmark_tft.py │ ├── conf │ ├── benchmark │ │ ├── hf.yaml │ │ └── tft.yaml │ └── config.yaml │ └── run.py ├── custom_hook.py ├── docs ├── Makefile ├── index.html ├── requirements.txt ├── requirements_build.txt └── source │ ├── README.md │ ├── README.rst │ ├── _static │ ├── css │ │ ├── Calibre-Light.ttf │ │ ├── Calibre-Medium.otf │ │ ├── Calibre-Regular.otf │ │ ├── Calibre-Thin.otf │ │ ├── DMSans-Bold.ttf │ │ ├── DMSans-BoldItalic.ttf │ │ ├── DMSans-Italic.ttf │ │ ├── DMSans-Medium.ttf │ │ ├── DMSans-MediumItalic.ttf │ │ ├── DMSans-Regular.ttf │ │ ├── DMSerifText-Italic.ttf │ │ ├── DMSerifText-Regular.ttf │ │ ├── code-snippets.css │ │ └── custom_from_huggingface.css │ ├── js │ │ ├── custom.js │ │ └── huggingface_logo.svg │ ├── tf_transformers.png │ ├── tf_transformers_resized.png │ ├── transformers_blue.png │ └── transformers_mix.png │ ├── benchmarks │ ├── albert.md │ ├── gpt2.md │ ├── imagenet_clip_benchmark.md │ ├── t5.md │ └── vit.md │ ├── conf.py │ ├── favicon.ico │ ├── imgs │ ├── long_block_sequencer.gif │ └── philosophy.png │ ├── index.rst │ ├── introduction_docs │ ├── installation.md │ ├── philosophy.rst │ └── quicktour.rst │ ├── model_doc │ ├── albert.rst │ ├── albert_tokenizer.rst │ ├── bart.rst │ ├── bert.rst │ ├── bigbird_tokenizer.rst │ ├── clip.rst │ ├── clip_feature_extractor.rst │ ├── encoder_decoder.rst │ ├── gpt2.rst │ ├── m2m100.rst │ ├── mbart.rst │ ├── mt5.rst │ ├── parallelism.md │ ├── roberta.rst │ ├── sentence_transformer.rst │ ├── t5.rst │ ├── t5_tokenizer.rst │ ├── visual_bert.rst │ ├── vit.rst │ ├── vit_feature_extractor.rst │ └── wav2vec2.rst │ ├── model_usage │ ├── sentence_transformers.ipynb │ ├── sentence_transformers.md │ ├── text_generation_using_gpt2.ipynb │ ├── text_generation_using_gpt2.md │ ├── text_generation_using_t5.ipynb │ └── text_generation_using_t5.md │ ├── research │ ├── glue.md │ └── long_block_sequencer.md │ ├── tflite_tutorials │ ├── albert_tflite.ipynb │ ├── albert_tflite.md │ ├── bert_tflite.ipynb │ ├── bert_tflite.md │ ├── roberta_tflite.ipynb │ └── roberta_tflite.md │ └── tutorials │ ├── 1_read_write_tfrecords.ipynb │ ├── 1_read_write_tfrecords.md │ ├── 2_text_classification_imdb_albert.ipynb │ ├── 2_text_classification_imdb_albert.md │ ├── 3_masked_lm_tpu.ipynb │ ├── 3_masked_lm_tpu.md │ ├── 4_image_classification_vit_multi_gpu.ipynb │ ├── 4_image_classification_vit_multi_gpu.md │ ├── 5_sentence_embedding_roberta_quora_zeroshot.ipynb │ ├── 5_sentence_embedding_roberta_quora_zeroshot.md │ ├── 
6_prompt_engineering_clip.ipynb │ ├── 6_prompt_engineering_clip.md │ ├── 7_gpt2_question_answering_squad.ipynb │ ├── 7_gpt2_question_answering_squad.md │ ├── 8_code_code_java_to_csharp_t5.ipynb │ ├── 8_code_code_java_to_csharp_t5.md │ ├── 9_images_tfrecords.ipynb │ ├── 9_images_tfrecords.md │ ├── README.ipynb │ ├── README.md │ ├── push_model_to_hf_hub.ipynb │ ├── push_model_to_hf_hub.md │ └── sample.ipynb ├── isort.cfg ├── mypy.ini ├── patch_version.py ├── poetry.lock ├── pyproject.toml ├── research ├── c4_grammatical_correction │ ├── README.md │ ├── __init__.py │ ├── conf │ │ └── config.yaml │ ├── dataset_loader.py │ ├── model.py │ ├── run_c4_grammar_correction.py │ └── train_c4_grammar_correction.py ├── diffusion │ ├── attention.py │ ├── beta_schedule.py │ ├── check_diffusion.ipynb │ ├── check_unet.ipynb │ ├── gaussian_diffusion.py │ ├── resnet.py │ ├── time_embedding_layers.py │ ├── unet.py │ └── utils.py ├── experiments │ ├── long_sequencer.ipynb │ ├── long_sequencer_t5_small.ipynb │ ├── roberta2roberta_pubmed.ipynb │ └── roberta2roberta_xsum.ipynb ├── glue │ ├── README.md │ ├── cola.py │ ├── conf │ │ ├── config.yaml │ │ └── glue │ │ │ ├── cola.yaml │ │ │ ├── mnli.yaml │ │ │ ├── mrpc.yaml │ │ │ ├── qnli.yaml │ │ │ ├── qqp.yaml │ │ │ ├── rte.yaml │ │ │ ├── sst2.yaml │ │ │ └── stsb.yaml │ ├── mnli.py │ ├── model.py │ ├── mrpc.py │ ├── qnli.py │ ├── qqp.py │ ├── rte.py │ ├── run_glue.py │ ├── run_mnli_mismatched.py │ ├── score_glue.py │ ├── sst2.py │ ├── stsb.py │ └── test_glue.ipynb ├── long_block_sequencer │ ├── README.md │ ├── __init__.py │ ├── conf │ │ └── config.yaml │ ├── dataset_loader.py │ ├── evaluate.py │ ├── long_block_encoder.py │ ├── model.py │ ├── run_long_block_sequencer.py │ └── train_long_block_sequencer.py ├── masked_language_model │ ├── README.md │ ├── __init__.py │ ├── callbacks.py │ ├── conf │ │ └── config.yaml │ ├── dataset_loader.py │ ├── model.py │ ├── run_mlm.py │ └── train_mlm.py ├── masked_language_model_old │ ├── 1_data_prepration.ipynb │ ├── 1_data_to_text.py │ ├── 2_text_to_features.py │ ├── README.MD │ ├── config │ │ ├── data_config.yaml │ │ ├── tfrecord_config.yaml │ │ └── train_config.yaml │ ├── prepare_data.py │ └── train_mlm.py ├── mix_language_model │ ├── README.MD │ ├── __init__.py │ ├── conf │ │ └── config.yaml │ ├── dataset_loader.py │ ├── dataset_loader_with_sentence_split.py │ ├── mix_lm_model.py │ ├── model.py │ ├── run_mix_lm.py │ └── train_mix_lm.py ├── mix_language_model_old │ ├── 1_data_prepration.ipynb │ ├── 1_data_to_text.py │ ├── 2_text_to_features.py │ ├── README.MD │ ├── config │ │ ├── data_config.yaml │ │ ├── tfrecord_config.yaml │ │ └── train_config.yaml │ ├── eval_mix_lm.py │ └── train_mix_lm.py ├── sentence2vec │ ├── dataset_loader.py │ ├── model.py │ ├── similarity_model.py │ └── train.py ├── sentence_language_model │ ├── README.md │ ├── __init__.py │ ├── callbacks.py │ ├── conf │ │ └── config.yaml │ ├── dataset_loader.py │ ├── model.py │ ├── run_mlm.py │ └── train_mlm.py ├── similarity_model_pretraining │ ├── README.MD │ ├── __init__.py │ ├── callbacks.py │ ├── conf │ │ └── config.yaml │ ├── dataset_loader.py │ ├── model.py │ ├── run_similarity.py │ ├── similarity_model.py │ └── train_similairity.py └── t5_style_pretraining │ ├── README.md │ ├── __init__.py │ ├── conf │ └── config.yaml │ ├── dataset_loader.py │ ├── model.py │ ├── run_t5_modified.py │ ├── t5_modified.py │ ├── t5_tokenizer_modified.py │ └── train_t5_modified.py ├── src ├── logo.png ├── logo2.png ├── tf_transformers │ ├── __init__.py │ ├── activations │ │ ├── 
__init__.py │ │ ├── gelu.py │ │ ├── swish.py │ │ └── utils.py │ ├── callbacks │ │ ├── __init__.py │ │ ├── metric_callback_list.py │ │ └── metrics │ │ │ ├── __init__.py │ │ │ ├── callbacks.py │ │ │ ├── pearson_spearman_callback.py │ │ │ ├── sklearn_callbacks.py │ │ │ └── text_generation_callbacks.py │ ├── core │ │ ├── __init__.py │ │ ├── chainer.py │ │ ├── distribute_utils.py │ │ ├── keras_utils.py │ │ ├── legacy_compile.py │ │ ├── legacy_layer.py │ │ ├── legacy_model.py │ │ ├── legacy_module.py │ │ ├── model_utils_for_all.py │ │ ├── model_wrapper.py │ │ ├── performance_utils.py │ │ ├── read_from_hub.py │ │ ├── trainer.py │ │ ├── trainer_for_all.py │ │ └── transformer_config.py │ ├── data │ │ ├── __init__.py │ │ ├── callbacks │ │ │ ├── __init__.py │ │ │ └── mlm_callback.py │ │ ├── ner_utils_sp.py │ │ ├── processors │ │ │ ├── __init__.py │ │ │ ├── mlm.py │ │ │ └── mlm_ttext.py │ │ ├── squad_utils_sp.py │ │ ├── tfprocessor_utils.py │ │ ├── tfrecord_utils.py │ │ └── utils.py │ ├── layers │ │ ├── __init__.py │ │ ├── attention │ │ │ ├── __init__.py │ │ │ ├── bart_attention.py │ │ │ ├── bert_attention.py │ │ │ ├── bigbird_attention.py │ │ │ ├── clip_attention.py │ │ │ ├── gpt2_attention.py │ │ │ └── t5_attention.py │ │ ├── bias_layer.py │ │ ├── dense_einsum.py │ │ ├── image_embeddings.py │ │ ├── image_utils.py │ │ ├── layer_normalization.py │ │ ├── mask │ │ │ ├── __init__.py │ │ │ ├── causal_mask.py │ │ │ ├── cross_attention_mask.py │ │ │ ├── masked_softmax.py │ │ │ ├── prefix_mask.py │ │ │ └── self_attention_mask.py │ │ ├── mlm_layer copy.py │ │ ├── mlm_layer.py │ │ ├── multihead_attention.py │ │ ├── on_device_embedding.py │ │ ├── position_embedding.py │ │ └── transformer │ │ │ ├── __init__.py │ │ │ ├── bart_transformer.py │ │ │ ├── bert_transformer.py │ │ │ ├── byt5_transformer.py │ │ │ ├── clip_transformer.py │ │ │ ├── gpt2_transformer.py │ │ │ ├── mt5_transformer.py │ │ │ ├── t5_transformer.py │ │ │ └── vit_transformer.py │ ├── losses │ │ ├── __init__.py │ │ ├── cross_entropy.py │ │ └── loss_wrapper.py │ ├── models │ │ ├── __init__.py │ │ ├── albert │ │ │ ├── __init__.py │ │ │ ├── albert.py │ │ │ ├── albert_model.py │ │ │ ├── configuration_albert.py │ │ │ ├── convert.py │ │ │ └── tokenizer_albert.py │ │ ├── bart │ │ │ ├── __init__.py │ │ │ ├── bart.py │ │ │ ├── bart_model.py │ │ │ ├── configuration_bart.py │ │ │ └── convert.py │ │ ├── bert │ │ │ ├── __init__.py │ │ │ ├── bert.py │ │ │ ├── bert_model.py │ │ │ ├── configuration_bert.py │ │ │ └── convert.py │ │ ├── bigbird │ │ │ ├── __init__.py │ │ │ └── tokenizer_bigbird_roberta.py │ │ ├── byt5 │ │ │ ├── __init__.py │ │ │ ├── byt5.py │ │ │ ├── byt5_model.py │ │ │ ├── configuration_byt5.py │ │ │ └── convert.py │ │ ├── clip │ │ │ ├── __init__.py │ │ │ ├── clip.py │ │ │ ├── clip_feature_extractor.py │ │ │ ├── clip_image_encoder.py │ │ │ ├── clip_model.py │ │ │ ├── clip_text_encoder.py │ │ │ ├── configuration_clip.py │ │ │ └── convert.py │ │ ├── distilbert │ │ │ ├── __init__.py │ │ │ ├── configuration_bert.py │ │ │ ├── convert.py │ │ │ └── distilbert_model.py │ │ ├── encoder_decoder │ │ │ ├── __init__.py │ │ │ └── encoder_decoder.py │ │ ├── gpt2 │ │ │ ├── __init__.py │ │ │ ├── configuration_gpt2.py │ │ │ ├── convert.py │ │ │ ├── gpt2.py │ │ │ └── gpt2_model.py │ │ ├── minilm │ │ │ ├── __init__.py │ │ │ ├── configuration_minilm.py │ │ │ ├── convert.py │ │ │ └── minilm_model.py │ │ ├── model_configs │ │ │ ├── __init__.py │ │ │ ├── albert │ │ │ │ ├── albert_base_v2.py │ │ │ │ └── albert_large_v2.py │ │ │ ├── bert │ │ │ │ ├── bert_base_cased.py │ │ │ │ 
├── bert_base_uncased.py │ │ │ │ ├── bert_large_cased.py │ │ │ │ └── bert_large_uncased.py │ │ │ ├── general_config.py │ │ │ ├── gpt2 │ │ │ │ ├── gpt2.py │ │ │ │ └── gpt2_medium.py │ │ │ ├── mt5 │ │ │ │ └── mt5_small.py │ │ │ ├── roberta │ │ │ │ ├── roberta_base.py │ │ │ │ └── roberta_large.py │ │ │ ├── t5 │ │ │ │ ├── t5_base.py │ │ │ │ └── t5_small.py │ │ │ └── unilm_cnndm │ │ │ │ └── config.json │ │ ├── mt5 │ │ │ ├── __init__.py │ │ │ ├── configuration_mt5.py │ │ │ ├── convert.py │ │ │ ├── mt5.py │ │ │ └── mt5_model.py │ │ ├── roberta │ │ │ ├── __init__.py │ │ │ ├── configuration_roberta.py │ │ │ ├── convert.py │ │ │ ├── roberta.py │ │ │ └── roberta_model.py │ │ ├── sentence_transformers │ │ │ ├── __init__.py │ │ │ ├── distilbert_sentence_model.py │ │ │ ├── distilroberta_model.py │ │ │ ├── minilm_model.py │ │ │ ├── sentence_transformers.py │ │ │ └── t5_sentence_model.py │ │ ├── t5 │ │ │ ├── __init__.py │ │ │ ├── configuration_t5.py │ │ │ ├── convert.py │ │ │ ├── t5.py │ │ │ ├── t5_model.py │ │ │ └── tokenizer_t5.py │ │ ├── tasks │ │ │ ├── __init__.py │ │ │ ├── classification.py │ │ │ ├── maked_lm_model.py │ │ │ ├── similarity_model.py │ │ │ └── span_selection.py │ │ └── vit │ │ │ ├── __init__.py │ │ │ ├── configuration_vit.py │ │ │ ├── convert.py │ │ │ ├── vit.py │ │ │ ├── vit_feature_extractor.py │ │ │ └── vit_model.py │ ├── optimization │ │ ├── __init__.py │ │ ├── adafactor_optimization.py │ │ ├── adam_weighted.py │ │ ├── learning_rate_utils.py │ │ └── optimization.py │ ├── text │ │ ├── __init__.py │ │ ├── decoder_utils.py │ │ ├── lm_tasks │ │ │ ├── __init__.py │ │ │ ├── causal_lm.py │ │ │ ├── masked_lm.py │ │ │ └── prefix_lm.py │ │ ├── sentencepiece_layer.py │ │ ├── sentencepiece_model_pb2.py │ │ ├── text_decoder.py │ │ ├── text_decoder_encoder_only.py │ │ ├── text_decoder_encoder_only_serializable.py │ │ ├── text_decoder_model.py │ │ ├── text_decoder_seq2seq.py │ │ ├── text_decoder_seq2seq_serializable.py │ │ ├── text_decoder_serializable_encoder_only.py │ │ └── text_layer_experimental.py │ └── utils │ │ ├── __init__.py │ │ ├── docstring_file_utils.py │ │ ├── docstring_utils.py │ │ ├── fast_sp_alignment.py │ │ ├── positional_bias_utils.py │ │ ├── push_to_hub.py │ │ ├── tf_utils.py │ │ ├── tokenization.py │ │ ├── utils.py │ │ └── viz_utils.py └── transformers_blue.png ├── tests ├── __init__.py ├── model_test_scripts │ ├── test_modeling_albert.py │ ├── test_modeling_bert.py │ ├── test_modeling_gpt2.py │ ├── test_modeling_mt5.py │ ├── test_modeling_roberta.py │ ├── test_modeling_t5.py │ ├── test_modeling_vit.py │ └── test_wav2vec2.py └── test_tf_transformers.py └── tutorials ├── 1_read_write_tfrecords.ipynb ├── 1_read_write_tfrecords.md ├── 2_text_classification_imdb_albert.ipynb ├── 2_text_classification_imdb_albert.md ├── 3_masked_lm_tpu.ipynb ├── 3_masked_lm_tpu.md ├── 4_image_classification_vit_multi_gpu.ipynb ├── 4_image_classification_vit_multi_gpu.md ├── 5_sentence_embedding_roberta_quora_zeroshot.ipynb ├── 5_sentence_embedding_roberta_quora_zeroshot.md ├── 6_prompt_engineering_clip.ipynb ├── 6_prompt_engineering_clip.md ├── 7_gpt2_question_answering_squad.ipynb ├── 7_gpt2_question_answering_squad.md ├── 8_code_code_java_to_csharp_t5.ipynb ├── 8_code_code_java_to_csharp_t5.md ├── 9_images_tfrecords.ipynb ├── 9_images_tfrecords.md ├── README.ipynb ├── README.md ├── push_model_to_hf_hub.ipynb └── push_model_to_hf_hub.md /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | select = B,C,E,F,P,T4,W,B9 3 | # F401 - imported but 
unused 4 | # C901 - complexity 5 | # E722 do not use bare 'except' 6 | ignore = E203, E266, E501, W503, C901, E722, B001, B006, E731 7 | max-complexity = 10 8 | max-line-length = 120 9 | # to avoid import unused in __init__.py 10 | per-file-ignores = 11 | __init__.py:F401 12 | exclude = 13 | .git, 14 | __pycache__, 15 | tests, 16 | docs, 17 | tutorials 18 | -------------------------------------------------------------------------------- /.github/workflows/build_documentation.yaml: -------------------------------------------------------------------------------- 1 | name: Documentation Build 2 | on: 3 | workflow_dispatch: 4 | 5 | jobs: 6 | build-docs: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Checkout 10 | uses: actions/checkout@master 11 | with: 12 | fetch-depth: 0 13 | - uses: actions/setup-python@v2 14 | with: 15 | python-version: 3.9 16 | - name: Build documentation 17 | run: | 18 | mkdir gh-pages 19 | touch gh-pages/.nojekyll 20 | cd docs/ 21 | pip3 install -r requirements_build.txt 22 | make clean html 23 | cp -r * ../gh-pages/ 24 | - name: Deploy documentation 25 | uses: JamesIves/github-pages-deploy-action@4.1.4 26 | with: 27 | branch: gh-pages 28 | folder: gh-pages -------------------------------------------------------------------------------- /.github/workflows/ci_cd.yaml: -------------------------------------------------------------------------------- 1 | name: CI-CD 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | branches: 8 | - main 9 | 10 | jobs: 11 | ci: 12 | # Step 1. Set up operating system 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | matrix: 16 | os: [macos-latest, ubuntu-latest] 17 | steps: 18 | # Step 2. Set up Python 3.9 19 | - uses: actions/setup-python@v2 20 | with: 21 | python-version: 3.9 22 | # Step 3. Check-out repository so we can access its contents 23 | - uses: actions/checkout@v2 24 | - name: Install package 25 | run: | 26 | pip install poetry==1.3.2 27 | poetry install 28 | # Step 6. Run flake8 for tftransformers 29 | - name: Run flake8 30 | run: | 31 | poetry add flake8 32 | poetry run flake8 src/ 33 | # Step 5. Run tests for tftransformers 34 | - name: Test with pytest 35 | run: poetry run pytest tests/test_tf_transformers.py --cov-report=xml --cov=tests 36 | # Step 6. 
Use Codecov to track coverage 37 | - uses: codecov/codecov-action@v2 38 | with: 39 | file: ./coverage.xml # coverage report 40 | fail_ci_if_error: true # terminate workflow if there's an error 41 | token: ${{ secrets.CODECOV_TOKEN }} 42 | flags: unittests 43 | name: codecov-umbrella 44 | verbose: true 45 | -------------------------------------------------------------------------------- /.github/workflows/patch_version.yaml: -------------------------------------------------------------------------------- 1 | name: Patch version 2 | on: 3 | release: 4 | types: [published] 5 | 6 | jobs: 7 | release: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v2 11 | - name: Get tag 12 | id: vars 13 | run: echo ::set-output name=tag::${GITHUB_REF#refs/*/} 14 | - uses: actions/setup-python@v2 15 | with: 16 | python-version: '3.9' 17 | - name: Patch version and commit 18 | run: | 19 | echo "Tag retrived is ${{ steps.vars.outputs.tag }}" 20 | pip install poetry 21 | poetry version ${{ steps.vars.outputs.tag }} 22 | git config user.name github-actions 23 | git config user.email github-actions@github.com 24 | python patch_version.py ${{ steps.vars.outputs.tag }} 25 | git add pyproject.toml 26 | git add src/tf_transformers/__init__.py 27 | git add tests/test_tf_transformers.py 28 | git commit -m "Updated version to ${{ steps.vars.outputs.tag }}" 29 | git push origin HEAD:main -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | workflow_dispatch: 4 | 5 | jobs: 6 | release: 7 | # # Only run this job if new work is pushed to "main" 8 | # if: github.event_name == 'push' && github.ref == 'refs/heads/main' 9 | # Step 1. Set up operating system 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | matrix: 13 | os: [ubuntu-latest] 14 | steps: 15 | # Step 2. Set up Python 3.9 16 | - uses: actions/setup-python@v2 17 | with: 18 | python-version: 3.9 19 | # Step 3. Check-out repository so we can access its contents 20 | - uses: actions/checkout@v2 21 | with: 22 | fetch-depth: 0 23 | # Step 4. Build 24 | - name: Build package 25 | run: | 26 | pip install poetry==1.3.2 27 | poetry build 28 | # Step 5. Publish to TestPyPI 29 | - uses: pypa/gh-action-pypi-publish@release/v1 30 | with: 31 | user: __token__ 32 | password: ${{ secrets.TEST_PYPI_TOKEN }} 33 | repository_url: https://test.pypi.org/legacy/ 34 | skip_existing: true 35 | # Step 6. Test install from TestPyPI 36 | - name: Test install from TestPyPI 37 | run: | 38 | pip install \ 39 | --index-url https://test.pypi.org/simple/ \ 40 | --extra-index-url https://pypi.org/simple \ 41 | tf-transformers 42 | # Step 7. 
Publish to PyPI 43 | - uses: pypa/gh-action-pypi-publish@release/v1 44 | with: 45 | user: __token__ 46 | password: ${{ secrets.PYPI_TOKEN }} 47 | skip_existing: true -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from Github's Python gitignore file 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | 9 | # C extensions 10 | *.so 11 | 12 | # DS store 13 | .DS_Store 14 | .DS_Store/ 15 | 16 | # tests and logs 17 | tests/fixtures/* 18 | !tests/fixtures/sample_text_no_unicode.txt 19 | logs/ 20 | lightning_logs/ 21 | lang_code_data/ 22 | 23 | #hydra outputs 24 | benchmarks/*/outputs 25 | research/*/outputs 26 | outputs/ 27 | 28 | # Distribution / packaging 29 | .Python 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .nox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *.cover 67 | .hypothesis/ 68 | .pytest_cache/ 69 | 70 | # Translations 71 | *.mo 72 | *.pot 73 | 74 | # Django stuff: 75 | *.log 76 | local_settings.py 77 | db.sqlite3 78 | 79 | # Flask stuff: 80 | instance/ 81 | .webassets-cache 82 | 83 | # Scrapy stuff: 84 | .scrapy 85 | 86 | # Sphinx documentation 87 | docs/_build/ 88 | 89 | # PyBuilder 90 | target/ 91 | 92 | # Jupyter Notebook 93 | .ipynb_checkpoints 94 | 95 | # IPython 96 | profile_default/ 97 | ipython_config.py 98 | 99 | # pyenv 100 | .python-version 101 | 102 | # celery beat schedule file 103 | celerybeat-schedule 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # vscode 136 | .vs 137 | .vscode 138 | 139 | # Pycharm 140 | .idea 141 | 142 | # TF code 143 | tensorflow_code 144 | 145 | # Models 146 | proc_data 147 | 148 | # examples 149 | runs 150 | /runs_old 151 | /wandb 152 | /examples/runs 153 | /examples/**/*.args 154 | /examples/rag/sweep 155 | 156 | # data 157 | /data 158 | serialization_dir 159 | 160 | # emacs 161 | *.*~ 162 | debug.env 163 | 164 | # vim 165 | .*.swp 166 | 167 | #ctags 168 | tags 169 | 170 | # pre-commit 171 | #.pre-commit* 172 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_stages: 2 | - commit 3 | exclude: .git|.tox|tests|docs 4 | fail_fast: true 5 | repos: 6 | - hooks: 7 | - exclude: ^tutorials/ 8 | id: check-docstring-first 9 | - exclude: ^tutorials/ 10 | id: trailing-whitespace 11 | - exclude: ^tutorials/ 12 | id: end-of-file-fixer 
13 | - id: check-toml 14 | - id: check-merge-conflict 15 | - exclude: ^tutorials/ 16 | id: check-added-large-files 17 | repo: https://github.com/pre-commit/pre-commit-hooks 18 | rev: v4.4.0 19 | - hooks: 20 | - id: black 21 | repo: https://github.com/psf/black 22 | rev: 23.1.0 23 | - hooks: 24 | - id: isort 25 | repo: https://github.com/timothycrosley/isort 26 | rev: 5.12.0 27 | - hooks: 28 | - additional_dependencies: 29 | - flake8-bugbear 30 | - flake8-implicit-str-concat 31 | id: flake8 32 | repo: https://github.com/pycqa/flake8.git 33 | rev: 6.0.0 34 | - hooks: 35 | - args: 36 | - --no-strict-optional 37 | - --ignore-missing-imports 38 | id: mypy 39 | repo: https://github.com/pre-commit/mirrors-mypy 40 | rev: v1.0.1 41 | - hooks: 42 | - args: 43 | - --sync 44 | - tutorials/*.ipynb 45 | files: ^tutorials/ 46 | id: jupytext 47 | repo: https://github.com/mwouts/jupytext 48 | rev: v1.14.4 49 | 50 | # - hooks: 51 | # - id: commitizen 52 | # stages: [commit-msg] 53 | # repo: https://github.com/commitizen-tools/commitizen 54 | # rev: v2.20.3 55 | 56 | - hooks: 57 | - entry: python custom_hook.py 58 | id: custom_hook 59 | language: python 60 | name: custom_hook 61 | verbose: true 62 | repo: local 63 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## Unreleased 2 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions are welcome, and they are greatly appreciated! Every little bit 4 | helps, and credit will always be given. 5 | 6 | ## Types of Contributions 7 | 8 | ### Report Bugs 9 | 10 | If you are reporting a bug, please include: 11 | 12 | * Your operating system name and version. 13 | * Any details about your local setup that might be helpful in troubleshooting. 14 | * Detailed steps to reproduce the bug. 15 | 16 | ### Fix Bugs 17 | 18 | Look through the GitHub issues for bugs. Anything tagged with "bug" and "help 19 | wanted" is open to whoever wants to implement it. 20 | 21 | ### Implement Features 22 | 23 | Look through the GitHub issues for features. Anything tagged with "enhancement" 24 | and "help wanted" is open to whoever wants to implement it. 25 | 26 | ### Write Documentation 27 | 28 | You can never have enough documentation! Please feel free to contribute to any 29 | part of the documentation, such as the official docs, docstrings, or even 30 | on the web in blog posts, articles, and such. 31 | 32 | ### Submit Feedback 33 | 34 | If you are proposing a feature: 35 | 36 | * Explain in detail how it would work. 37 | * Keep the scope as narrow as possible, to make it easier to implement. 38 | * Remember that this is a volunteer-driven project, and that contributions 39 | are welcome :) 40 | 41 | ## Get Started! 42 | 43 | Ready to contribute? Here's how to set up `tf-transformers` for local development. 44 | 45 | 1. Download a copy of `tf-transformers` locally. 46 | 2. Install `tf-transformers` using `poetry`: 47 | 48 | ```console 49 | $ poetry install 50 | ``` 51 | 52 | 3. Use `git` (or similar) to create a branch for local development and make your changes: 53 | 54 | ```console 55 | $ git checkout -b name-of-your-bugfix-or-feature 56 | ``` 57 | 58 | 4. When you're done making changes, check that your changes conform to any code formatting requirements and pass any tests. 59 | 60 | 5. 
Commit your changes and open a pull request. 61 | 62 | ## Pull Request Guidelines 63 | 64 | Before you submit a pull request, check that it meets these guidelines: 65 | 66 | 1. The pull request should include additional tests if appropriate. 67 | 2. If the pull request adds functionality, the docs should be updated. 68 | 3. The pull request should work for all currently supported operating systems and versions of Python. 69 | 70 | ## Code of Conduct 71 | 72 | Please note that the `tf-transformers` project is released with a 73 | Code of Conduct. By contributing to this project you agree to abide by its terms. Thanks. 74 | -------------------------------------------------------------------------------- /benchmarks/albert/conf/benchmark/hf.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | name: hf 3 | device: cuda 4 | 5 | data: 6 | name: imdb 7 | take_sample: false 8 | batch_size: 32 9 | max_length: 512 10 | 11 | model: 12 | name: albert-base-v2 13 | type: tf 14 | -------------------------------------------------------------------------------- /benchmarks/albert/conf/benchmark/tft.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | name: tft 3 | 4 | data: 5 | name: imdb 6 | take_sample: false 7 | batch_size: 32 8 | max_length: 512 9 | 10 | model: 11 | name: albert-base-v2 12 | type: saved_model 13 | -------------------------------------------------------------------------------- /benchmarks/albert/conf/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - benchmark: tft 3 | -------------------------------------------------------------------------------- /benchmarks/albert/run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import hydra 4 | from omegaconf import DictConfig, OmegaConf 5 | 6 | # A logger for this file 7 | log = logging.getLogger(__name__) 8 | 9 | 10 | def run_benchmark(cfg): 11 | 12 | if cfg.benchmark.task.name == "tft": 13 | from benchmark_tft import TftBenchmark 14 | 15 | benchmark = TftBenchmark(cfg) 16 | results = benchmark.run() 17 | log.info(results) 18 | 19 | if cfg.benchmark.task.name == "hf": 20 | from benchmark_hf import HFBenchmark 21 | 22 | benchmark = HFBenchmark(cfg) 23 | results = benchmark.run() 24 | log.info(results) 25 | 26 | 27 | @hydra.main(config_path="conf", config_name="config") 28 | def run(cfg: DictConfig) -> None: 29 | run_benchmark((cfg)) 30 | print(OmegaConf.to_yaml(cfg)) 31 | 32 | 33 | if __name__ == "__main__": 34 | run() 35 | -------------------------------------------------------------------------------- /benchmarks/gpt2/conf/benchmark/hf.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | name: hf 3 | device: cuda 4 | 5 | data: 6 | name: cnn_dailymail 7 | take_sample: false 8 | batch_size: 32 9 | max_length: 512 10 | 11 | model: 12 | name: gpt2 13 | type: tf 14 | 15 | text_generation: 16 | max_length: 64 17 | -------------------------------------------------------------------------------- /benchmarks/gpt2/conf/benchmark/tft.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | name: tft 3 | 4 | data: 5 | name: cnn_dailymail 6 | take_sample: false 7 | batch_size: 32 8 | max_length: 512 9 | 10 | model: 11 | name: gpt2 12 | type: textdecoder_saved_model 13 | 14 | text_generation: 15 | max_iterations: 64 16 | mode: greedy 17 | 
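# Usage sketch (inferred from run.py and conf/config.yaml in this benchmark folder): Hydra composes
# this file when the `benchmark=tft` group is selected, i.e.
#   python run.py benchmark=tft
# Individual fields can be overridden on the command line with standard Hydra syntax, e.g.
#   python run.py benchmark=tft benchmark.data.batch_size=16
# where 16 is only an illustrative value, not one used for the reported numbers.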
-------------------------------------------------------------------------------- /benchmarks/gpt2/conf/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - benchmark: tft 3 | -------------------------------------------------------------------------------- /benchmarks/gpt2/run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import hydra 4 | from omegaconf import DictConfig, OmegaConf 5 | 6 | # A logger for this file 7 | log = logging.getLogger(__name__) 8 | 9 | 10 | def run_benchmark(cfg): 11 | 12 | if cfg.benchmark.task.name == "tft": 13 | from benchmark_tft import TftBenchmark 14 | 15 | benchmark = TftBenchmark(cfg) 16 | results = benchmark.run() 17 | log.info(results) 18 | 19 | if cfg.benchmark.task.name == "hf": 20 | from benchmark_hf import HFBenchmark 21 | 22 | benchmark = HFBenchmark(cfg) 23 | results = benchmark.run() 24 | log.info(results) 25 | 26 | 27 | @hydra.main(config_path="conf", config_name="config") 28 | def run(cfg: DictConfig) -> None: 29 | run_benchmark((cfg)) 30 | print(OmegaConf.to_yaml(cfg)) 31 | 32 | 33 | if __name__ == "__main__": 34 | run() 35 | -------------------------------------------------------------------------------- /benchmarks/roberta/README.MD: -------------------------------------------------------------------------------- 1 | 16 | 17 | # Benchmark Roberta 18 | 19 | - [Code - Roberta Benchmark](https://github.com/legacyai/tf-transformers/tree/main/benchmark/roberta) 20 | 21 | This is used to benchmark the performance of the Roberta model on text classification tasks. We evaluate it using three frameworks: 22 | Tensorflow-Transformers (default), HuggingFace PyTorch, and HuggingFace Tensorflow, with HuggingFace JAX pending. 23 | Executing these scripts is fairly straightforward; users are expected to install the necessary libraries before executing 24 | the benchmark script. 25 | 26 | All the configurations are managed using [Hydra](https://github.com/facebookresearch/hydra). 27 | 28 | -> Machine - **Tesla V100-SXM2 32GB** 29 | 30 | -> Tensorflow-version - **2.4.1** 31 | 32 | -> Huggingface-Transformer-Version - **4.12.5** 33 | 34 | -> PyTorch-Version - **1.9.0** 35 | 36 | ## Tensorflow-Transformers. (tft) 37 | 38 | The default benchmark mode is ```tft```. 39 | 1. To execute ```tft``` (default) : 40 | ```python run.py benchmark=tft``` 41 | 42 | 2. To execute a specific ```type```, e.g. ```keras_model``` : 43 | ```python run.py benchmark=tft benchmark.model.type=keras_model``` 44 | 45 | * a. keras_model - Uses tf.keras.Model. 46 | * b. saved_model - Uses tf.saved_model. 47 | 48 | ## HuggingFace-Tensorflow. (hf-tf) 49 | 50 | 1. To execute ```hf-tf``` (default) : 51 | ```python run.py benchmark=hf benchmark.model.type=tf``` 52 | 53 | 54 | ## HuggingFace-PyTorch. (hf-pt) 55 | 56 | 1. To execute ```hf-pt``` (default) : 57 | ```python run.py benchmark=hf benchmark.model.type=pt``` 58 | 59 | 60 | ## HuggingFace-JAX. (hf-jax) (Not Available) 61 | 62 | 1. 
To execute ```hf-jax``` (default) : 63 | ```python run.py benchmark=hf benchmark.model.type=jax``` 64 | 65 | 66 | ## Official Benchmarks on IMDB 67 | 68 | ``` 69 | Text Classification: 70 | | | batch_size | time (s) | samples/second | 71 | |:---------------------------|-------------:|:-------------:|:-----------|------ 72 | | tft + saved_model | 32 | 372.44 sec | 67 | 73 | | tft + keras_model | 32 | 367.17 sec | 70 | 74 | | hf_tf | 32 | 287.91 sec | 86 | 75 | | hf_pt | 32 | 253.61 sec | 98 | 76 | | hf_jax (pmap) | 32 | N/A | N/A | 77 | ``` 78 | -------------------------------------------------------------------------------- /benchmarks/roberta/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/benchmarks/roberta/__init__.py -------------------------------------------------------------------------------- /benchmarks/roberta/conf/benchmark/hf.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | name: hf 3 | device: cuda 4 | 5 | data: 6 | name: imdb 7 | take_sample: false 8 | batch_size: 32 9 | max_length: 512 10 | 11 | model: 12 | name: roberta-base 13 | type: tf 14 | -------------------------------------------------------------------------------- /benchmarks/roberta/conf/benchmark/tft.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | name: tft 3 | 4 | data: 5 | name: imdb 6 | take_sample: false 7 | batch_size: 32 8 | max_length: 512 9 | 10 | model: 11 | name: roberta-base 12 | type: saved_model 13 | -------------------------------------------------------------------------------- /benchmarks/roberta/conf/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - benchmark: tft 3 | -------------------------------------------------------------------------------- /benchmarks/roberta/run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import hydra 4 | from omegaconf import DictConfig, OmegaConf 5 | 6 | # A logger for this file 7 | log = logging.getLogger(__name__) 8 | 9 | 10 | def run_benchmark(cfg): 11 | 12 | if cfg.benchmark.task.name == "tft": 13 | from benchmark_tft import TftBenchmark 14 | 15 | benchmark = TftBenchmark(cfg) 16 | results = benchmark.run() 17 | log.info(results) 18 | 19 | if cfg.benchmark.task.name == "hf": 20 | from benchmark_hf import HFBenchmark 21 | 22 | benchmark = HFBenchmark(cfg) 23 | results = benchmark.run() 24 | log.info(results) 25 | 26 | 27 | @hydra.main(config_path="conf", config_name="config") 28 | def run(cfg: DictConfig) -> None: 29 | run_benchmark((cfg)) 30 | print(OmegaConf.to_yaml(cfg)) 31 | 32 | 33 | if __name__ == "__main__": 34 | run() 35 | -------------------------------------------------------------------------------- /benchmarks/t5/conf/benchmark/hf.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | name: hf 3 | device: cuda 4 | 5 | data: 6 | name: xsum 7 | take_sample: false 8 | batch_size: 32 9 | max_length: 512 10 | 11 | model: 12 | name: t5-small 13 | type: tf 14 | 15 | text_generation: 16 | max_length: 64 17 | eos_token_id: 250000 18 | -------------------------------------------------------------------------------- /benchmarks/t5/conf/benchmark/tft.yaml: -------------------------------------------------------------------------------- 1 | task: 
2 | name: tft 3 | 4 | data: 5 | name: xsum 6 | take_sample: false 7 | batch_size: 32 8 | max_length: 512 9 | 10 | model: 11 | name: t5-small 12 | type: textdecoder_saved_model 13 | 14 | text_generation: 15 | max_iterations: 64 16 | mode: greedy 17 | -------------------------------------------------------------------------------- /benchmarks/t5/conf/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - benchmark: tft 3 | -------------------------------------------------------------------------------- /benchmarks/t5/run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import hydra 4 | from omegaconf import DictConfig, OmegaConf 5 | 6 | # A logger for this file 7 | log = logging.getLogger(__name__) 8 | 9 | 10 | def run_benchmark(cfg): 11 | 12 | if cfg.benchmark.task.name == "tft": 13 | from benchmark_tft import TftBenchmark 14 | 15 | benchmark = TftBenchmark(cfg) 16 | results = benchmark.run() 17 | log.info(results) 18 | 19 | if cfg.benchmark.task.name == "hf": 20 | from benchmark_hf import HFBenchmark 21 | 22 | benchmark = HFBenchmark(cfg) 23 | results = benchmark.run() 24 | log.info(results) 25 | 26 | 27 | @hydra.main(config_path="conf", config_name="config") 28 | def run(cfg: DictConfig) -> None: 29 | run_benchmark((cfg)) 30 | print(OmegaConf.to_yaml(cfg)) 31 | 32 | 33 | if __name__ == "__main__": 34 | run() 35 | -------------------------------------------------------------------------------- /benchmarks/vit/README.MD: -------------------------------------------------------------------------------- 1 | 16 | 17 | # Benchmark ViT 18 | 19 | - [Code - ViT Benchmark](https://github.com/legacyai/tf-transformers/tree/main/benchmark/vit) 20 | 21 | This is used to benchmark the performance of the ViT model on image classification tasks. We evaluate it using three frameworks: 22 | Tensorflow-Transformers (default), HuggingFace PyTorch, and HuggingFace Tensorflow, with HuggingFace JAX pending. 23 | Executing these scripts is fairly straightforward; users are expected to install the necessary libraries before executing 24 | the benchmark script. 25 | 26 | All the configurations are managed using [Hydra](https://github.com/facebookresearch/hydra). 27 | 28 | -> Machine - **Tesla V100-SXM2 32GB** 29 | 30 | -> Tensorflow-version - **2.4.1** 31 | 32 | -> Huggingface-Transformer-Version - **4.12.5** 33 | 34 | -> PyTorch-Version - **1.9.0** 35 | 36 | ## Tensorflow-Transformers. (tft) 37 | 38 | The default benchmark mode is ```tft```. 39 | 1. To execute ```tft``` (default) : 40 | ```python run.py benchmark=tft``` 41 | 42 | 2. To execute a specific ```type```, e.g. ```keras_model``` : 43 | ```python run.py benchmark=tft benchmark.model.type=keras_model``` 44 | 45 | * a. keras_model - Uses tf.keras.Model. 46 | * b. saved_model - Uses tf.saved_model 47 | * c. saved_model_tf-io - Uses tf.saved_model; ```model + tf.io``` are serialized together. 48 | 49 | 50 | ## HuggingFace-Tensorflow. (hf-tf) 51 | 52 | 1. To execute ```hf-tf``` (default) : 53 | ```python run.py benchmark=hf benchmark.model.type=tf``` 54 | 55 | 56 | ## HuggingFace-PyTorch. (hf-pt) 57 | 58 | 1. 
To execute ```hf-pt``` (default) : 59 | ```python run.py benchmark=hf benchmark.model.type=pt``` 60 | 61 | 62 | ## Official Benchmarks on Keras Flowers dataset (5000 samples) 63 | 64 | ``` 65 | Image Classification: 66 | | | batch_size | time (s) | samples/second | 67 | |:---------------------------|-------------:|:-------------:|:----------------| 68 | | tft + saved_model | 32 | 29.06 sec | 126 | 69 | | tft + saved_model + tf-io | 32 | 17.62 sec | 208 | 70 | | tft + keras_model | 32 | 29.48 sec | 124 | 71 | | hf_tf | 32 | 95.84 sec | 38 | 72 | | hf_pt | 32 | 24.79 sec | 148 | 73 | | tft + keras + hf pipeline | 32 | 94.44 sec | 39 | 74 | ``` 75 | -------------------------------------------------------------------------------- /benchmarks/vit/conf/benchmark/hf.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | name: hf 3 | device: cuda 4 | 5 | data: 6 | take_sample: false 7 | batch_size: 32 8 | 9 | model: 10 | name: google/vit-base-patch16-224 11 | type: tf 12 | -------------------------------------------------------------------------------- /benchmarks/vit/conf/benchmark/tft.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | name: tft 3 | 4 | data: 5 | take_sample: false 6 | batch_size: 32 7 | 8 | model: 9 | name: google/vit-base-patch16-224 10 | type: saved_model 11 | -------------------------------------------------------------------------------- /benchmarks/vit/conf/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - benchmark: tft 3 | -------------------------------------------------------------------------------- /benchmarks/vit/run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import hydra 4 | from omegaconf import DictConfig, OmegaConf 5 | 6 | # A logger for this file 7 | log = logging.getLogger(__name__) 8 | 9 | 10 | def run_benchmark(cfg): 11 | 12 | if cfg.benchmark.task.name == "tft": 13 | from benchmark_tft import TftBenchmark 14 | 15 | benchmark = TftBenchmark(cfg) 16 | results = benchmark.run() 17 | log.info(results) 18 | 19 | if cfg.benchmark.task.name == "hf": 20 | from benchmark_hf import HFBenchmark 21 | 22 | benchmark = HFBenchmark(cfg) 23 | results = benchmark.run() 24 | log.info(results) 25 | 26 | 27 | @hydra.main(config_path="conf", config_name="config") 28 | def run(cfg: DictConfig) -> None: 29 | run_benchmark((cfg)) 30 | print(OmegaConf.to_yaml(cfg)) 31 | 32 | 33 | if __name__ == "__main__": 34 | run() 35 | -------------------------------------------------------------------------------- /custom_hook.py: -------------------------------------------------------------------------------- 1 | def add_new_files_to_jupytext(): 2 | """This function will check all .ipynb and .md files. 
3 | If a new .ipynb is present, it will be converted to the equivalent .md 4 | """ 5 | import glob 6 | import subprocess 7 | 8 | all_ipynb = glob.glob('tutorials/*.ipynb') 9 | all_md = glob.glob('tutorials/*.md') 10 | 11 | all_ipynb = [name.split('.ipynb')[0] for name in all_ipynb] 12 | all_md = [name.split('.md')[0] for name in all_md] 13 | 14 | notebook_list = [] 15 | for notebook_name in all_ipynb: 16 | if notebook_name not in all_md: 17 | notebook_list.append(notebook_name + '.ipynb') 18 | 19 | if notebook_list != []: 20 | for notebook in notebook_list: 21 | subprocess.run(["jupytext --set-formats ipynb,md:myst {}".format(notebook)], shell=True) 22 | 23 | 24 | def move_to_docs(): 25 | """Move tutorials to docs""" 26 | print("Copying the tutorials to docs") 27 | import shutil 28 | 29 | shutil.copytree("tutorials", "docs/source/tutorials", dirs_exist_ok=True) 30 | 31 | 32 | if __name__ == '__main__': 33 | add_new_files_to_jupytext() # Convert new notebooks to md using jupytext 34 | move_to_docs() # Move new tutorials to docs 35 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # sphinx <4 required by myst-nb v0.12.0 (Feb 2021) 2 | # sphinx >=3 required by sphinx-autodoc-typehints v1.11.1 (Oct 2020) 3 | sphinx >=3, <4 4 | sphinx_rtd_theme 5 | sphinx-autodoc-typehints==1.11.1 6 | jupyter-sphinx>=0.3.2 7 | myst-nb==v0.12.0 8 | 9 | # Packages as per HuggingFace extensions 10 | sphinx_rtd_theme 11 | recommonmark 12 | sphinx_markdown_tables 13 | sphinxext-opengraph 14 | sphinx-copybutton 15 | matplotlib 16 | -------------------------------------------------------------------------------- /docs/requirements_build.txt: -------------------------------------------------------------------------------- 1 | # sphinx <4 required by myst-nb v0.12.0 (Feb 2021) 2 | # sphinx >=3 required by sphinx-autodoc-typehints v1.11.1 (Oct 2020) 3 | sphinx >=3, <4 4 | sphinx_rtd_theme==1.0.0 5 | sphinx-autodoc-typehints==1.11.1 6 | jupyter-sphinx>=0.3.2 7 | myst-nb 8 | jinja2==3.0.0 9 | 10 | 11 | # Packages as per HuggingFace extensions 12 | sphinx_rtd_theme 13 | recommonmark 14 | sphinx_markdown_tables 15 | sphinxext-opengraph 16 | sphinx-copybutton 17 | matplotlib 18 | 19 | # Packaged for build 20 | tensorflow-text==2.7.0 21 | sentencepiece 22 | tqdm 23 | transformers 24 | -------------------------------------------------------------------------------- /docs/source/README.md: -------------------------------------------------------------------------------- 1 | We use jupytext to keep copies of notebook in sync with Markdown equivalent. 2 | 3 | ### Adding a new notebook 4 | 5 | ``` 6 | jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb 7 | ``` 8 | 9 | ### Syncing Notebooks 10 | 11 | After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running: 12 | 13 | ``` 14 | jupytext --sync docs/notebooks/* 15 | ``` 16 | -------------------------------------------------------------------------------- /docs/source/README.rst: -------------------------------------------------------------------------------- 1 | We use jupytext to keep copies of notebook in sync with Markdown equivalent. 
2 | 3 | ### Adding a new notebook 4 | 5 | ``` 6 | jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb 7 | ``` 8 | 9 | ### Syncing Notebooks 10 | 11 | After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running: 12 | 13 | ``` 14 | jupytext --sync docs/notebooks/* 15 | ``` 16 | -------------------------------------------------------------------------------- /docs/source/_static/css/Calibre-Light.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/Calibre-Light.ttf -------------------------------------------------------------------------------- /docs/source/_static/css/Calibre-Medium.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/Calibre-Medium.otf -------------------------------------------------------------------------------- /docs/source/_static/css/Calibre-Regular.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/Calibre-Regular.otf -------------------------------------------------------------------------------- /docs/source/_static/css/Calibre-Thin.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/Calibre-Thin.otf -------------------------------------------------------------------------------- /docs/source/_static/css/DMSans-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/DMSans-Bold.ttf -------------------------------------------------------------------------------- /docs/source/_static/css/DMSans-BoldItalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/DMSans-BoldItalic.ttf -------------------------------------------------------------------------------- /docs/source/_static/css/DMSans-Italic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/DMSans-Italic.ttf -------------------------------------------------------------------------------- /docs/source/_static/css/DMSans-Medium.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/DMSans-Medium.ttf -------------------------------------------------------------------------------- /docs/source/_static/css/DMSans-MediumItalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/DMSans-MediumItalic.ttf 
-------------------------------------------------------------------------------- /docs/source/_static/css/DMSans-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/DMSans-Regular.ttf -------------------------------------------------------------------------------- /docs/source/_static/css/DMSerifText-Italic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/DMSerifText-Italic.ttf -------------------------------------------------------------------------------- /docs/source/_static/css/DMSerifText-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/css/DMSerifText-Regular.ttf -------------------------------------------------------------------------------- /docs/source/_static/css/code-snippets.css: -------------------------------------------------------------------------------- 1 | /* Code block Changes Comments */ 2 | .highlight .c1, .highlight .sd{ 3 | color: rgb(146, 56, 56); 4 | } 5 | 6 | /* Code block Changes from import*/ 7 | .highlight .kn, .highlight .nv, .highlight .s2, .highlight .ow { 8 | color: #6670FF; 9 | } 10 | 11 | /* Code block Changes methods*/ 12 | .highlight .nn, .highlight .k, .highlight .s1, .highlight .nb, .highlight .bp, .highlight .kc { 13 | color: #3C2478; 14 | } 15 | 16 | /* Code block Changes >>*/ 17 | .highlight .gp { 18 | color: #3C2478; 19 | } 20 | 21 | /* Code block Changes Add border*/ 22 | .rst-content div[class^='highlight'] { 23 | border-color: #1a1a1a; 24 | } 25 | -------------------------------------------------------------------------------- /docs/source/_static/tf_transformers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/tf_transformers.png -------------------------------------------------------------------------------- /docs/source/_static/tf_transformers_resized.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/tf_transformers_resized.png -------------------------------------------------------------------------------- /docs/source/_static/transformers_blue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/transformers_blue.png -------------------------------------------------------------------------------- /docs/source/_static/transformers_mix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/_static/transformers_mix.png -------------------------------------------------------------------------------- /docs/source/favicon.ico: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/favicon.ico -------------------------------------------------------------------------------- /docs/source/imgs/long_block_sequencer.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/imgs/long_block_sequencer.gif -------------------------------------------------------------------------------- /docs/source/imgs/philosophy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/imgs/philosophy.png -------------------------------------------------------------------------------- /docs/source/model_doc/clip_feature_extractor.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright 2020 The HuggingFace Team and TFT Team. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with 5 | the License. You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on 10 | an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the 11 | specific language governing permissions and limitations under the License. 12 | 13 | CLIP Feature Extractor 14 | ----------------------------------------------------------------------------------------------------------------------- 15 | 16 | Overview 17 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 18 | 19 | This page includes information about how to use CLIPFeatureExtractorTF with tensorflow-ops. 20 | This feature extractors works in sync with :class:`~tf.data.Dataset` and so is useful for on the fly preprocessing. 21 | 22 | .. code-block:: 23 | 24 | >>> from tf_transformers.models import CLIPFeatureExtractorTF 25 | >>> image_path_list = # List fo image paths 26 | >>> CLIP_feature_extractor_tf = CLIPFeatureExtractorTF(img_height=224, img_width=224) 27 | >>> outputs = CLIP_feature_extractor_tf({'image': tf.constant(image_path_list)}) 28 | 29 | 30 | CLIPFeatureExtractorTF 31 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 32 | 33 | .. autoclass:: tf_transformers.models.CLIPFeatureExtractorTF 34 | :members: 35 | -------------------------------------------------------------------------------- /docs/source/model_doc/encoder_decoder.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright 2020 TFT Team. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with 5 | the License. You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on 10 | an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the 11 | specific language governing permissions and limitations under the License. 12 | 13 | EncoderDecoder 14 | ----------------------------------------------------------------------------------------------------------------------- 15 | 16 | Overview 17 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 18 | 19 | The EncoderDecoder model, also known as a Seq2Seq model, consists of Encoder and Decoder models, like 20 | Bert and GPT2, or Bert and Bert itself. For more details, please refer to the Bart model :doc:`Bart ` or the T5 model 21 | :doc:`T5 ` . 22 | It consists of: 23 | - Encoder and Decoder 24 | - Cross attention between Decoder and Encoder 25 | 26 | 27 | EncoderDecoder (Seq2Seq) 28 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 29 | 30 | .. autoclass:: tf_transformers.models.EncoderDecoder 31 | :members: 32 | -------------------------------------------------------------------------------- /docs/source/model_doc/m2m100.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/model_doc/m2m100.rst -------------------------------------------------------------------------------- /docs/source/model_doc/mbart.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/model_doc/mbart.rst -------------------------------------------------------------------------------- /docs/source/model_doc/t5_tokenizer.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright 2020 The HuggingFace Team and TFT Team. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with 5 | the License. You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on 10 | an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the 11 | specific language governing permissions and limitations under the License. 12 | 13 | T5 Tokenizer 14 | ----------------------------------------------------------------------------------------------------------------------- 15 | 16 | Overview 17 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 18 | 19 | This page includes information about how to use T5Tokenizer with tensorflow-text. 20 | This tokenizer works in sync with :class:`~tf.data.Dataset` and so is useful for on-the-fly tokenization. 21 | 22 | ..
code-block:: 23 | 24 | >>> from tf_transformers.models import T5TokenizerTFText 25 | >>> tokenizer = T5TokenizerTFText.from_pretrained("t5-small") 26 | >>> text = ['The following statements are true about sentences in English:', 27 | '', 28 | 'A new sentence begins with a capital letter.'] 29 | >>> inputs = {'text': text} 30 | >>> outputs = tokenizer(inputs) # Ragged Tensor Output 31 | 32 | # Dynamic Padding 33 | >>> tokenizer = T5TokenizerTFText.from_pretrained("t5-small", dynamic_padding=True) 34 | >>> text = ['The following statements are true about sentences in English:', 35 | '', 36 | 'A new sentence begins with a capital letter.'] 37 | >>> inputs = {'text': text} 38 | >>> outputs = tokenizer(inputs) # Dict of tf.Tensor 39 | 40 | # Static Padding 41 | >>> tokenizer = T5TokenizerTFText.from_pretrained("t5-small", pack_model_inputs=True) 42 | >>> text = ['The following statements are true about sentences in English:', 43 | '', 44 | 'A new sentence begins with a capital letter.'] 45 | >>> inputs = {'text': text} 46 | >>> outputs = tokenizer(inputs) # Dict of tf.Tensor 47 | 48 | # To Add Special Tokens 49 | >>> tokenizer = T5TokenizerTFText.from_pretrained("t5-small", add_special_tokens=True) 50 | 51 | T5TokenizerTFText 52 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 53 | 54 | .. autoclass:: tf_transformers.models.T5TokenizerTFText 55 | :members: 56 | 57 | T5TokenizerLayer 58 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 59 | 60 | .. autoclass:: tf_transformers.models.T5TokenizerLayer 61 | :members: 62 | -------------------------------------------------------------------------------- /docs/source/model_doc/visual_bert.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/model_doc/visual_bert.rst -------------------------------------------------------------------------------- /docs/source/model_doc/vit_feature_extractor.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright 2020 TFT Team. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with 5 | the License. You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on 10 | an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the 11 | specific language governing permissions and limitations under the License. 12 | 13 | ViT Feature Extractor 14 | ----------------------------------------------------------------------------------------------------------------------- 15 | 16 | Overview 17 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 18 | 19 | This page includes information about how to use ViTFeatureExtractorTF with tensorflow-ops. 20 | This feature extractors works in sync with :class:`~tf.data.Dataset` and so is useful for on the fly preprocessing. 21 | 22 | .. 
code-block:: 23 | 24 | >>> from tf_transformers.models import ViTFeatureExtractorTF 25 | >>> image_path_list = # List fo image paths 26 | >>> vit_feature_extractor_tf = ViTFeatureExtractorTF(img_height=224, img_width=224) 27 | >>> outputs = vit_feature_extractor_tf({'image': tf.constant(image_path_list)}) 28 | 29 | ViTFeatureExtractorTF 30 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 31 | 32 | .. autoclass:: tf_transformers.models.ViTFeatureExtractorTF 33 | :members: 34 | -------------------------------------------------------------------------------- /docs/source/model_doc/wav2vec2.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/docs/source/model_doc/wav2vec2.rst -------------------------------------------------------------------------------- /docs/source/model_usage/sentence_transformers.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | formats: ipynb,md:myst 4 | text_representation: 5 | extension: .md 6 | format_name: myst 7 | format_version: 0.13 8 | jupytext_version: 1.13.5 9 | kernelspec: 10 | display_name: Python 3 (ipykernel) 11 | language: python 12 | name: python3 13 | --- 14 | 15 | # Sentence Transformer in tf-transformers 16 | 17 | * This is a simple tutorial to demonstrate how ```SentenceTransformer``` models has been integrated 18 | to ```tf-transformers``` and how to use it 19 | * The following tutorial is applicable to all supported ```SentenceTransformer``` models. 20 | 21 | ```{code-cell} ipython3 22 | 23 | ``` 24 | 25 | ### Load Sentence-t5 model 26 | 27 | ```{code-cell} ipython3 28 | import tensorflow as tf 29 | from tf_transformers.models import SentenceTransformer 30 | ``` 31 | 32 | ```{code-cell} ipython3 33 | model_name = 'sentence-transformers/sentence-t5-base' # Load any sentencetransformer model here 34 | model = SentenceTransformer.from_pretrained(model_name) 35 | ``` 36 | 37 | ### Whats my model input? 38 | 39 | * All models in ```tf-transformers``` are designed with full connections. All you need is ```model.input``` if its a ```LegacyModel/tf.keras.Model``` or ```model.model_inputs``` if its a ```LegacyLayer/tf.keras.layers.Layer``` 40 | 41 | ```{code-cell} ipython3 42 | model.input 43 | ``` 44 | 45 | ### Whats my model output? 46 | 47 | * All models in ```tf-transformers``` are designed with full connections. 
All you need is ```model.output``` if its a ```LegacyModel/tf.keras.Model``` or ```model.model_outputs``` if its a ```LegacyLayer/tf.keras.layers.Layer``` 48 | 49 | ```{code-cell} ipython3 50 | model.output 51 | ``` 52 | 53 | ### Sentence vectors 54 | 55 | ```{code-cell} ipython3 56 | from transformers import AutoTokenizer 57 | 58 | tokenizer = AutoTokenizer.from_pretrained(model_name) 59 | 60 | text = ['This is a sentence to get vector', 'This one too'] 61 | inputs = tokenizer(text, return_tensors='tf', padding=True) 62 | 63 | inputs_tf = {'input_ids': inputs['input_ids'], 'input_mask': inputs['attention_mask']} 64 | outputs_tf = model(inputs_tf) 65 | print("Sentence vector", outputs_tf['sentence_vector'].shape) 66 | ``` 67 | 68 | ```{code-cell} ipython3 69 | 70 | ``` 71 | 72 | ### Serialize as usual and load it 73 | 74 | * Serialize, load and assert outputs with non serialized ```(```tf.keras.Model```)``` 75 | 76 | ```{code-cell} ipython3 77 | model_dir = 'MODELS/sentence_t5' 78 | model.save_transformers_serialized(model_dir) 79 | 80 | loaded = tf.saved_model.load(model_dir) 81 | model = loaded.signatures['serving_default'] 82 | 83 | outputs_tf_serialized = model(**inputs_tf) 84 | 85 | tf.debugging.assert_near(outputs_tf['sentence_vector'], outputs_tf_serialized['sentence_vector']) 86 | ``` 87 | 88 | ```{code-cell} ipython3 89 | 90 | ``` 91 | 92 | ```{code-cell} ipython3 93 | 94 | ``` 95 | -------------------------------------------------------------------------------- /docs/source/tutorials/README.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "4193b067", 6 | "metadata": {}, 7 | "source": [ 8 | "We use jupytext to keep copies of notebook in sync with Markdown equivalent.\n", 9 | "\n", 10 | "### Adding a new notebook\n", 11 | "\n", 12 | "```\n", 13 | "jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb\n", 14 | "```\n", 15 | "\n", 16 | "### Syncing Notebooks\n", 17 | "\n", 18 | "After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running:\n", 19 | "\n", 20 | "```\n", 21 | "jupytext --sync docs/notebooks/*\n", 22 | "```" 23 | ] 24 | } 25 | ], 26 | "metadata": { 27 | "jupytext": { 28 | "cell_metadata_filter": "-all", 29 | "formats": "ipynb,md:myst", 30 | "main_language": "python" 31 | } 32 | }, 33 | "nbformat": 4, 34 | "nbformat_minor": 5 35 | } 36 | -------------------------------------------------------------------------------- /docs/source/tutorials/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | cell_metadata_filter: -all 4 | formats: ipynb,md:myst 5 | main_language: python 6 | text_representation: 7 | extension: .md 8 | format_name: myst 9 | format_version: 0.13 10 | jupytext_version: 1.14.4 11 | --- 12 | 13 | We use jupytext to keep copies of notebook in sync with Markdown equivalent. 
14 | 15 | ### Adding a new notebook 16 | 17 | ``` 18 | jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb 19 | ``` 20 | 21 | ### Syncing Notebooks 22 | 23 | After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running: 24 | 25 | ``` 26 | jupytext --sync docs/notebooks/* 27 | ``` 28 | -------------------------------------------------------------------------------- /docs/source/tutorials/push_model_to_hf_hub.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | formats: ipynb,md:myst 4 | text_representation: 5 | extension: .md 6 | format_name: myst 7 | format_version: 0.13 8 | jupytext_version: 1.14.4 9 | kernelspec: 10 | display_name: Python 3 (ipykernel) 11 | language: python 12 | name: python3 13 | --- 14 | 15 | ```{code-cell} ipython3 16 | 17 | ``` 18 | 19 | ```{code-cell} ipython3 20 | import subprocess 21 | import os 22 | from distutils.dir_util import copy_tree 23 | ``` 24 | 25 | ### How to push a model to the hub 26 | 27 | * Make sure you have logged in with ```huggingface-cli login``` using your token. 28 | 29 | ```{code-cell} ipython3 30 | 31 | ``` 32 | 33 | ```{code-cell} ipython3 34 | model_name = 'byt5-small' 35 | ``` 36 | 37 | #### 1. Create a model-name directory under the organization name 38 | 39 | ```{code-cell} ipython3 40 | subprocess.run(['huggingface-cli', 'repo', 41 | 'create', '{}'.format(model_name), 42 | '--yes', 43 | '--organization', 'tftransformers']) 44 | ``` 45 | 46 | ```{code-cell} ipython3 47 | 48 | ``` 49 | 50 | #### 2. Now clone the repo/folder created above to our local cwd 51 | 52 | ```{code-cell} ipython3 53 | subprocess.run(["git", "clone", "https://huggingface.co/tftransformers/{}".format(model_name)]) 54 | ``` 55 | 56 | ```{code-cell} ipython3 57 | 58 | ``` 59 | 60 | #### 3. Now move your model directory to the current working directory, under the ```model_name``` directory 61 | 62 | ```{code-cell} ipython3 63 | cwd = os.getcwd() # Get current working dir 64 | new_working_dir = os.path.join(cwd, model_name) # This is cloned from hf hub under organization 65 | os.chdir("{}".format(new_working_dir)) # Switch to new working dir 66 | 67 | # Cached model directory can change from machine to machine 68 | cached_model_dir = '/var/folders/vq/4fxns8l55gq8_msgygbyb51h0000gn/T/tf_transformers_cache/{}/'.format(model_name) 69 | 70 | # Copy cached model directory to the new working directory 71 | copy_tree(cached_model_dir, new_working_dir) 72 | ``` 73 | 74 | ```{code-cell} ipython3 75 | 76 | ``` 77 | 78 | #### 4. 
Now time to push these model to hub 79 | 80 | ```{code-cell} ipython3 81 | subprocess.run(["git-lfs", "track", "*"]) 82 | subprocess.run(["git", "add", "."]) 83 | subprocess.run(["git", "commit", "-m", "Pushing new model {}".format(model_name)]) # Commit message 84 | subprocess.run(["git", "push"]) 85 | 86 | # Change back to original cwd 87 | os.chdir("{}".format(cwd)) 88 | ``` 89 | 90 | ```{code-cell} ipython3 91 | 92 | ``` 93 | -------------------------------------------------------------------------------- /docs/source/tutorials/sample.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "source": [], 7 | "outputs": [], 8 | "metadata": {} 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "source": [ 13 | "This is a sample ipynb file to check jupytext hooks" 14 | ], 15 | "metadata": {} 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "source": [ 21 | "import tensorflow as tf" 22 | ], 23 | "outputs": [], 24 | "metadata": {} 25 | } 26 | ], 27 | "metadata": { 28 | "orig_nbformat": 4, 29 | "language_info": { 30 | "name": "python" 31 | } 32 | }, 33 | "nbformat": 4, 34 | "nbformat_minor": 2 35 | } -------------------------------------------------------------------------------- /isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | #force_single_line=True 3 | multi_line_output=3 4 | force_grid_wrap=0 5 | use_parentheses=True 6 | line_length=120 7 | profile = "black" 8 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | show_error_codes = True 3 | disable_error_code = attr-defined 4 | 5 | 6 | 7 | [mypy-six.*] 8 | ignore_missing_imports = True 9 | 10 | [mypy-google.*] 11 | ignore_missing_imports = True 12 | -------------------------------------------------------------------------------- /patch_version.py: -------------------------------------------------------------------------------- 1 | # This script is adapted from python-semantic-release 2 | # flake8: noqa 3 | 4 | import sys 5 | 6 | 7 | def patch(required_version): 8 | """ 9 | Write the new version to init and test file 10 | """ 11 | 12 | required_version = required_version.replace('v', '') # Replace v1.0.0 to 1.0.0 13 | # Edit src 14 | with open('src/tf_transformers/__init__.py', 'w') as f: 15 | f.write('__version__ = "{}"\n'.format(required_version)) 16 | 17 | # Edit test 18 | test_content = 'from tf_transformers import __version__\n\nversion = "{}"\ndef test_version():\n assert __version__ == version\n' 19 | with open('tests/test_tf_transformers.py', 'w') as f: 20 | f.write(test_content.format(required_version)) 21 | 22 | 23 | if __name__ == '__main__': 24 | version = sys.argv[-1] # last cli argument 25 | patch(version) 26 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "tf-transformers" 3 | version = "v2.0.0" 4 | description = "NLP with Transformer based models on Tensorflow 2.0" 5 | authors = ["Sarath R Nair "] 6 | maintainers = ["Sarath R Nair "] 7 | license = "MIT" 8 | readme = "README.md" 9 | homepage = "" 10 | repository = "" 11 | documentation = "" 12 | keywords = [ 13 | "tensorflow", 14 | "transformers", 15 | "nlp", 16 | "keras", 17 | 
"bert", 18 | "deep learning" 19 | ] 20 | 21 | [tool.poetry.dependencies] 22 | python = "^3.9" 23 | tqdm = "^4.62.3" 24 | transformers = "^4.15.0" 25 | sentencepiece = "^0.1.96" 26 | absl-py = "^1.0.0" 27 | cffi = "^1.15.1" 28 | 29 | 30 | [tool.poetry.dev-dependencies] 31 | pytest = "^5.2" 32 | pylint = "^2.6.0" 33 | coverage = {extras = ["toml"], version = "^5.5"} 34 | pytest-cov = "^2.11.1" 35 | python-semantic-release = "^7.23.0" 36 | 37 | [tool.semantic_release] 38 | version_variable = [ 39 | "src/tf_transformers/__init__.py:__version__", 40 | "tests/test_tf_transformers.py:version" 41 | ] 42 | version_toml = [ 43 | "pyproject.toml:tool.poetry.version" 44 | ] 45 | version_pattern = [ 46 | "README.md:version: v{version}" 47 | 48 | ] 49 | branch = "main" 50 | dist_path = "dist/" 51 | upload_to_pypi = false 52 | remove_dist = false 53 | build_command = "pip install poetry && poetry build" 54 | 55 | [build-system] 56 | requires = ["poetry-core>=1.0.4"] 57 | build-backend = "poetry.core.masonry.api" 58 | 59 | [tool.isort] 60 | profile = "black" 61 | 62 | [tool.black] 63 | skip-string-normalization = true 64 | line-length = 120 65 | target-version = ['py37'] 66 | include = '\.pyi?$' 67 | exclude = ''' 68 | 69 | ( 70 | /( 71 | \.eggs # exclude a few common directories in the 72 | | \.git # root of the project 73 | | \.hg 74 | | \.mypy_cache 75 | | \.tox 76 | | \.venv 77 | | _build 78 | | buck-out 79 | | build 80 | | dist 81 | | tutorials 82 | )/ 83 | | foo.py # also separately exclude a file named foo.py in 84 | # the root of the project 85 | ) 86 | ''' 87 | -------------------------------------------------------------------------------- /research/c4_grammatical_correction/README.md: -------------------------------------------------------------------------------- 1 | 2 | ### Sentence Masked Language Modeling using TFText 3 | 4 | This folder has all necessary scripts to run the Sentence MLM using tensorflow text. 5 | Instead of masking words, we mask sentences (sequence of words) 6 | 7 | ### Advantage 8 | 9 | No need to prepare features in TFRecord format. We make use of dynamic mlm. All we need is a text file or folder of text files with text files, which each line corrsponding to text. 10 | 11 | ### WandB 12 | 13 | By default we are using Wandb. if enviornment variable ```WANDB_PROJECT=None```, wandb will be disabled. 14 | 15 | ``` export WANDB_PROJECT='t5-c4-grammatical-correction' ``` 16 | ### Configuration (Hydra) 17 | 18 | All or most configurations can be managed using ```conf/config.yaml```. You can override it by command line also. 19 | 20 | Eg: For TPU , we need a data in GCS and model_checkpoint_dir to be in GCS too. 
21 | 22 | ``` python3 run_c4_grammar_correction.py \ 23 | task.data_directory=gs://legacyai-bucket/c4_grammar_correction_data \ 24 | task.train_batch_size=512 \ 25 | trainer.dtype=bf16 \ 26 | trainer.model_checkpoint_dir=gs://legacyai-bucket/t5_c4_lr_3e5 \ 27 | trainer.steps_per_epoch=50000 \ 28 | trainer.epochs=10 \ 29 | trainer.strategy=tpu \ 30 | trainer.tpu_address=legacyai-tpu-2 \ 31 | optimizer.learning_rate=3e-5 32 | model.is_training=true 33 | model.use_dropout=true 34 | ``` 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /research/c4_grammatical_correction/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/research/c4_grammatical_correction/__init__.py -------------------------------------------------------------------------------- /research/c4_grammatical_correction/conf/config.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | data_directory: 3 | max_seq_len: 64 4 | train_batch_size: 512 5 | trainer: 6 | dtype: bf16 7 | num_gpus: 0 8 | tpu_address: 9 | epochs: 10 10 | strategy: mirrored 11 | steps_per_epoch: 50000 12 | model_checkpoint_dir: 13 | global_norm: 1.0 14 | optimizer: 15 | learning_rate: 3e-5 16 | num_warmup_steps: 0.1 17 | decay_function: cosine 18 | adam_beta_1: 0.9 19 | adam_beta_2: 0.95 20 | adam_epsilon: 10e-8 21 | weight_decay_rate: 0.1 22 | optimizer_type: adamw 23 | loss_type: 24 | use_constant_lr: false 25 | model: 26 | is_training: false 27 | use_dropout: false 28 | num_layers: 12 29 | -------------------------------------------------------------------------------- /research/c4_grammatical_correction/run_c4_grammar_correction.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | """This is the main script to run Mix Language Model""" 18 | import os 19 | import warnings 20 | 21 | import hydra 22 | from absl import logging 23 | from omegaconf import DictConfig 24 | from train_c4_grammar_correction import run_train 25 | 26 | logging.set_verbosity("INFO") 27 | 28 | # We set PROJECT_NAME from ENVIORNMENT VARIABLE 29 | WANDB_PROJECT = os.getenv('WANDB_PROJECT', None) 30 | use_wandb = True 31 | if WANDB_PROJECT is None: 32 | warnings.warn( 33 | "For wandb-project should not be None.\ 34 | Set export WANDB_PROJECT=" 35 | ) 36 | use_wandb = False 37 | 38 | 39 | @hydra.main(config_path="conf", config_name="config") 40 | def run(cfg: DictConfig) -> None: 41 | print("Config", cfg) 42 | config_dict = dict(cfg) 43 | # For TPU, we need to initialize it before tf text dataset 44 | # starts triggering. 
Hack 45 | if cfg.trainer.strategy == 'tpu': 46 | from model import get_trainer 47 | 48 | distribution_strategy = 'tpu' 49 | num_gpus = 0 50 | tpu_address = cfg.trainer.tpu_address 51 | get_trainer( 52 | distribution_strategy=distribution_strategy, 53 | num_gpus=num_gpus, 54 | tpu_address=tpu_address, 55 | dtype=cfg.trainer.dtype, 56 | ) # noqa 57 | 58 | if use_wandb: 59 | import wandb 60 | 61 | wandb.init(project=WANDB_PROJECT, config=config_dict, sync_tensorboard=True) 62 | history = run_train(cfg, wandb) 63 | else: 64 | # Set wandb = None 65 | history = run_train(cfg, None) 66 | return history 67 | 68 | 69 | if __name__ == "__main__": 70 | history = run() 71 | -------------------------------------------------------------------------------- /research/diffusion/beta_schedule.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | 5 | 6 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): 7 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 8 | warmup_time = int(num_diffusion_timesteps * warmup_frac) 9 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) 10 | return betas 11 | 12 | 13 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 14 | """ 15 | Create a beta schedule that discretizes the given alpha_t_bar function, 16 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 17 | 18 | :param num_diffusion_timesteps: the number of betas to produce. 19 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 20 | produces the cumulative product of (1-beta) up to that 21 | part of the diffusion process. 22 | :param max_beta: the maximum beta to use; use values lower than 1 to 23 | prevent singularities. 
24 | """ 25 | betas = [] 26 | for i in range(num_diffusion_timesteps): 27 | t1 = i / num_diffusion_timesteps 28 | t2 = (i + 1) / num_diffusion_timesteps 29 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 30 | return np.array(betas) 31 | 32 | 33 | def get_beta_schedule(beta_schedule, num_diffusion_timesteps, beta_start=0.0001, beta_end=0.02): 34 | if beta_schedule == 'quad': 35 | betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps, dtype=np.float64) ** 2 36 | elif beta_schedule == 'linear': 37 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 38 | elif beta_schedule == 'warmup10': 39 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) 40 | elif beta_schedule == 'warmup50': 41 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) 42 | elif beta_schedule == 'const': 43 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 44 | elif beta_schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1 45 | betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64) 46 | elif beta_schedule == 'cosine': 47 | betas = betas_for_alpha_bar( 48 | num_diffusion_timesteps, 49 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 50 | ) 51 | else: 52 | raise NotImplementedError(beta_schedule) 53 | 54 | assert betas.shape == (num_diffusion_timesteps,) 55 | return betas 56 | -------------------------------------------------------------------------------- /research/diffusion/time_embedding_layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class TimeEmbedding(tf.keras.layers.Layer): 5 | """Creates a sinusoidal embedding. 6 | 7 | This layer creates a sinusoidal embedding as described in "Attention is All you need": 8 | 9 | """ 10 | 11 | def __init__( 12 | self, 13 | n_channels, 14 | initializer="glorot_uniform", 15 | bias_initializer='zeros', 16 | name="time_embeddings", 17 | activation='swish', 18 | use_bias=True, 19 | dtype=tf.float32, 20 | **kwargs, 21 | ): 22 | """ 23 | Args: 24 | n_channels ([int]): Similar to embedding size 25 | scale_factor ([int]): How much to scale embedding size for next dense layer 26 | initializer (str, optional): The initializer to use for the 27 | embedding weights. Defaults to "glorot_uniform". 28 | name (str, optional): name of the layer. Defaults to "positional_embeddings". 29 | dtype ([type], optional): [description]. Defaults to tf.float32. 
30 | """ 31 | super(TimeEmbedding, self).__init__(name=name, dtype=dtype, **kwargs) 32 | 33 | assert (n_channels % 2) == 0 34 | self._n_channels = n_channels 35 | self._initializer = initializer 36 | self._bias_initializer = bias_initializer 37 | self._dtype = dtype 38 | 39 | def get_config(self): 40 | """Config based on init arguments 41 | 42 | Returns: 43 | [dict]: Dict of all init arguments 44 | """ 45 | config = { 46 | "n_channels": self._n_channels, 47 | "initializer": self._initializer, 48 | "name": self._name, 49 | "dtype": self._dtype, 50 | } 51 | base_config = super(TimeEmbedding, self).get_config() 52 | return dict(list(base_config.items()) + list(config.items())) 53 | 54 | def call(self, timesteps): 55 | """Call 56 | 57 | Args: 58 | timesteps ([tf.Tensor]): input ids 1D timesteps (B, ) B is batch_size 59 | eg: 60 | 61 | Returns: 62 | [tf.Tensor]: embeddings 3D (b x s x h) 63 | """ 64 | half_dim = self._n_channels // 2 65 | emb = tf.math.log(10000.0) / (half_dim - 1) 66 | emb = tf.exp(tf.range(half_dim, dtype=self._dtype) * -emb) # 1-D vector of size half_dim 67 | emb = tf.cast(tf.expand_dims(timesteps, axis=1), self._dtype) * tf.expand_dims(emb, axis=0) # B x half_dim 68 | emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=1) # B x self._n_channels 69 | 70 | return emb 71 | -------------------------------------------------------------------------------- /research/diffusion/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | # As per original diffusion code 5 | def get_initializer(scale=0): 6 | if scale == 0: 7 | scale = scale = 1e-10 8 | initializer = tf.keras.initializers.VarianceScaling(scale=scale, mode='fan_avg', distribution='uniform') 9 | return initializer 10 | -------------------------------------------------------------------------------- /research/glue/conf/config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_batch_size: 32 3 | eval_batch_size: 64 4 | take_sample: false 5 | max_seq_length: 128 6 | static_padding: false 7 | trainer: 8 | dtype: fp32 9 | num_gpus: 2 10 | tpu_address: 11 | epochs: 3 12 | strategy: mirrored 13 | optimizer: 14 | learning_rate: 2e-5 15 | loss_type: 16 | model: 17 | is_training: true 18 | use_dropout: true 19 | -------------------------------------------------------------------------------- /research/glue/conf/glue/cola.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | name: cola 3 | data: 4 | name: cola 5 | num_classes: 2 6 | max_seq_length: 256 7 | -------------------------------------------------------------------------------- /research/glue/conf/glue/mnli.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | name: mnli 3 | data: 4 | name: mnli 5 | num_classes: 3 6 | max_seq_length: 256 7 | -------------------------------------------------------------------------------- /research/glue/conf/glue/mrpc.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | name: mrpc 3 | data: 4 | name: mrpc 5 | num_classes: 2 6 | max_seq_length: 256 7 | -------------------------------------------------------------------------------- /research/glue/conf/glue/qnli.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | name: qnli 3 | data: 4 | name: qnli 5 | num_classes: 2 6 | max_seq_length: 256 7 | 
-------------------------------------------------------------------------------- /research/glue/conf/glue/qqp.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | name: qqp 3 | data: 4 | name: qqp 5 | num_classes: 2 6 | max_seq_length: 128 7 | -------------------------------------------------------------------------------- /research/glue/conf/glue/rte.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | name: rte 3 | data: 4 | name: rte 5 | num_classes: 2 6 | max_seq_length: 256 7 | -------------------------------------------------------------------------------- /research/glue/conf/glue/sst2.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | name: sst2 3 | data: 4 | name: sst2 5 | num_classes: 2 6 | max_seq_length: 256 7 | -------------------------------------------------------------------------------- /research/glue/conf/glue/stsb.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | name: stsb 3 | data: 4 | name: stsb 5 | num_classes: 1 6 | max_seq_length: 128 7 | -------------------------------------------------------------------------------- /research/glue/model.py: -------------------------------------------------------------------------------- 1 | from transformers import AlbertTokenizer 2 | 3 | from tf_transformers.core import Trainer 4 | from tf_transformers.models import AlbertModel as Model 5 | from tf_transformers.optimization import create_optimizer 6 | 7 | MODEL_NAME = "albert-base-v2" 8 | 9 | 10 | def get_model(return_all_layer_outputs, is_training, use_dropout): 11 | """Get the model""" 12 | model = Model.from_pretrained( 13 | MODEL_NAME, return_all_layer_outputs=return_all_layer_outputs, is_training=is_training, use_dropout=use_dropout 14 | ) 15 | return model 16 | 17 | 18 | def get_tokenizer(): 19 | """Get Tokenizer""" 20 | return AlbertTokenizer.from_pretrained(MODEL_NAME) 21 | 22 | 23 | def get_optimizer(learning_rate, examples, batch_size, epochs): 24 | """Get optimizer""" 25 | steps_per_epoch = int(examples / batch_size) 26 | num_train_steps = steps_per_epoch * epochs 27 | warmup_steps = int(0.1 * num_train_steps) 28 | 29 | def optimizer_fn(): 30 | optimizer, learning_rate_fn = create_optimizer(learning_rate, num_train_steps, warmup_steps) 31 | return optimizer 32 | 33 | return optimizer_fn 34 | 35 | 36 | def get_trainer(distribution_strategy, num_gpus=0, tpu_address=None): 37 | """Get Trainer""" 38 | trainer = Trainer(distribution_strategy, num_gpus=num_gpus, tpu_address=tpu_address) 39 | return trainer 40 | -------------------------------------------------------------------------------- /research/glue/run_mnli_mismatched.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | """Run MNLI Mismatched validation""" 18 | 19 | import os 20 | 21 | import datasets 22 | from model import get_tokenizer 23 | 24 | 25 | def run_mnli_mismatched_evaluation( 26 | model, model_dir, write_tfrecord, read_tfrecord, metric_callback, number_of_checkpoints, max_seq_length 27 | ): 28 | """MNLI Mismatched evaluation""" 29 | 30 | data = datasets.load_dataset("glue", 'mnli') 31 | 32 | # Validation matched 33 | tokenizer = get_tokenizer() 34 | tfrecord_dir = "/tmp/glue/mnli_mismatched/" 35 | take_sample = False 36 | eval_batch_size = 32 37 | 38 | write_tfrecord( 39 | data["validation_mismatched"], 40 | max_seq_length, 41 | tokenizer, 42 | tfrecord_dir, 43 | mode="eval", 44 | take_sample=take_sample, 45 | verbose=1000, 46 | ) 47 | 48 | # Read TFRecords Validation 49 | eval_tfrecord_dir = os.path.join(tfrecord_dir, "eval") 50 | eval_dataset, total_eval_examples = read_tfrecord( 51 | eval_tfrecord_dir, eval_batch_size, shuffle=False, drop_remainder=False 52 | ) 53 | 54 | results_per_epoch = [] 55 | for i in range(1, number_of_checkpoints + 1): 56 | # Load checkpoint 57 | ckpt_path = os.path.join(model_dir, "ckpt-{}".format(i)) 58 | 59 | model.load_checkpoint(checkpoint_path=ckpt_path) 60 | 61 | result = metric_callback({"model": model, "validation_dataset": eval_dataset}) 62 | results_per_epoch.append(result) 63 | 64 | return results_per_epoch 65 | -------------------------------------------------------------------------------- /research/long_block_sequencer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/research/long_block_sequencer/__init__.py -------------------------------------------------------------------------------- /research/long_block_sequencer/conf/config.yaml: -------------------------------------------------------------------------------- 1 | 2 | task: 3 | max_seq_len: 4096 4 | decoder_seq_len: 256 5 | num_splits: 8 6 | use_gru_layer: false 7 | projection_dimension: 512 8 | train_batch_size: 8 9 | 10 | trainer: 11 | dtype: fp16 12 | num_gpus: 2 13 | tpu_address: 14 | epochs: 3 15 | strategy: mirrored 16 | model_checkpoint_dir: 17 | optimizer: 18 | learning_rate: 0.001 19 | loss_type: 20 | use_constant_lr: true 21 | model: 22 | num_layers: 12 23 | model_name: 't5-small' 24 | eval: 25 | eval_batch_size: 4 26 | model_checkpoint_dir: 27 | model_checkpoint_path: 28 | take_sample: false 29 | mode: 30 | -------------------------------------------------------------------------------- /research/long_block_sequencer/run_long_block_sequencer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | """This is the main script to run Long Block Sequencer Model""" 18 | import os 19 | 20 | import hydra 21 | from absl import logging 22 | from omegaconf import DictConfig 23 | from train_long_block_sequencer import run_train 24 | 25 | logging.set_verbosity("INFO") 26 | 27 | # We set PROJECT_NAME from ENVIRONMENT VARIABLE 28 | WANDB_PROJECT = os.getenv('WANDB_PROJECT', None) 29 | use_wandb = True 30 | if WANDB_PROJECT is None: 31 | logging.info("Not using wandb as no `WANDB_PROJECT` has been set via export ") 32 | use_wandb = False 33 | 34 | 35 | @hydra.main(config_path="conf", config_name="config") 36 | def run(cfg: DictConfig) -> None: 37 | print("Config", cfg) 38 | config_dict = dict(cfg) 39 | # For TPU, we need to initialize it before the tf text dataset 40 | # starts triggering. Hack 41 | if cfg.trainer.strategy == 'tpu': 42 | from model import get_trainer 43 | 44 | distribution_strategy = 'tpu' 45 | num_gpus = 0 46 | tpu_address = cfg.trainer.tpu_address 47 | get_trainer( 48 | distribution_strategy=distribution_strategy, 49 | num_gpus=num_gpus, 50 | tpu_address=tpu_address, 51 | dtype=cfg.trainer.dtype, 52 | ) # noqa 53 | 54 | if use_wandb: 55 | import wandb 56 | 57 | wandb.init(project=WANDB_PROJECT, config=config_dict, sync_tensorboard=True) 58 | history = run_train(cfg, wandb) 59 | else: 60 | # Set wandb = None 61 | history = run_train(cfg, None) 62 | return history 63 | 64 | 65 | if __name__ == "__main__": 66 | history = run() 67 | -------------------------------------------------------------------------------- /research/masked_language_model/README.md: -------------------------------------------------------------------------------- 1 | 2 | ### Masked Language Modeling using TFText 3 | 4 | This folder has all the necessary scripts to run MLM using tensorflow-text. 5 | 6 | ### Advantage 7 | 8 | No need to prepare features in TFRecord format. We make use of dynamic MLM. All we need is a text file or a folder of text files, with each line corresponding to one text example. 9 | 10 | ### Configuration (Hydra) 11 | 12 | All or most configurations can be managed using ```conf/config.yaml```. You can also override it from the command line. 13 | 14 | E.g., for TPU, we need the data to be in GCS and model_checkpoint_dir to be in GCS too. 15 | 16 | ```python3 run_mlm.py \ data.data_directory=$GCP_BUCKET/data/ \ trainer.model_checkpoint_dir=$GCP_BUCKET/model``` 17 | 18 | ### WandB 19 | 20 | By default we are using Wandb. Check ```run_mlm.py``` to disable it. 
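To make the expected input concrete, below is a minimal sketch of preparing such a data folder (the directory and file names here are hypothetical). It follows what ```dataset_loader.py``` does: it globs ```*.txt``` files under ```data.data_directory```, treats each line as one example, and, if sentences within a line are joined by ```'__||__'```, splits and rejoins them with spaces before masking.

```python
import os

# Hypothetical folder; point data.data_directory in conf/config.yaml to it.
data_directory = "/tmp/mlm_text_data"
os.makedirs(data_directory, exist_ok=True)

examples = [
    "This is a single training example on one line.",
    "First sentence of an example.__||__Second sentence of the same example.",
]

# One example per line; any number of *.txt shards works.
with open(os.path.join(data_directory, "shard_00.txt"), "w") as f:
    f.write("\n".join(examples) + "\n")
```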
21 | -------------------------------------------------------------------------------- /research/masked_language_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/research/masked_language_model/__init__.py -------------------------------------------------------------------------------- /research/masked_language_model/conf/config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_directory: 3 | train_batch_size: 32 4 | task: 5 | max_seq_len: 128 6 | max_predictions_per_seq: 20 7 | trainer: 8 | dtype: fp32 9 | num_gpus: 2 10 | tpu_address: 11 | epochs: 3 12 | strategy: mirrored 13 | steps_per_epoch: 10000 14 | model_checkpoint_dir: 15 | optimizer: 16 | learning_rate: 3e-5 17 | warmup_rate: 0.2 18 | loss_type: 19 | use_constant_lr: false 20 | model: 21 | is_training: true 22 | use_dropout: true 23 | num_layers: 24 24 | -------------------------------------------------------------------------------- /research/masked_language_model/dataset_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from random import shuffle 3 | 4 | import tensorflow as tf 5 | 6 | 7 | def get_dataset(data_directory, masked_lm_map_fn, batch_size): 8 | """Convert text to tf.data.Dataset after map fn 9 | 10 | Args: 11 | data_directory ([type]): [description] 12 | masked_lm_map_fn ([type]): [description] 13 | batch_size ([type]): [description] 14 | 15 | Returns: 16 | [type]: [description] 17 | """ 18 | 19 | def filter_out_empty_mask(x, y): 20 | """When an example doesn't have multiple sentences\ 21 | there wont be any masked sentence. Ignore those examples, 22 | as nothing to predict. 23 | """ 24 | return tf.greater(tf.reduce_sum(tf.cast(tf.not_equal(x['masked_lm_positions'], 0), tf.int32)), 0) 25 | 26 | all_text_files = tf.io.gfile.glob(os.path.join(data_directory, '*.txt')) 27 | shuffle(all_text_files) 28 | ds = tf.data.TextLineDataset(all_text_files) 29 | # Our data has sentences joined by '__||__'. So, for word based MLM 30 | # we need to replace '__||__', by ''. 
and club it as a single sentence 31 | # tf.strings.regex_replace not working as expected 32 | ds = ds.map(lambda x: tf.strings.split(x, '__||__'), num_parallel_calls=tf.data.AUTOTUNE) 33 | ds = ds.map(lambda x: tf.strings.reduce_join([x], separator=' '), num_parallel_calls=tf.data.AUTOTUNE) 34 | 35 | # We need to add the text as dict 36 | ds = ds.map(lambda x: {'text': x}, num_parallel_calls=tf.data.AUTOTUNE) 37 | 38 | # Do MLM 39 | ds = ds.map(masked_lm_map_fn, num_parallel_calls=tf.data.AUTOTUNE) 40 | 41 | # Filter examples if there is not atleast single MASK sentence 42 | ds = ds.filter(filter_out_empty_mask) 43 | 44 | # # Shuffle and Prefetch 45 | ds = ds.shuffle(100, reshuffle_each_iteration=True).prefetch(buffer_size=tf.data.AUTOTUNE) 46 | 47 | # Batch 48 | ds = ds.batch(batch_size, drop_remainder=True) 49 | 50 | # Auto SHARD 51 | options = tf.data.Options() 52 | options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO 53 | ds = ds.with_options(options) 54 | 55 | return ds 56 | -------------------------------------------------------------------------------- /research/masked_language_model/model.py: -------------------------------------------------------------------------------- 1 | from tf_transformers.core import Trainer 2 | from tf_transformers.losses.loss_wrapper import get_lm_loss 3 | from tf_transformers.models import ( 4 | BigBirdRobertaTokenizerTFText, 5 | GPT2Model, 6 | MaskedLMModel, 7 | ) 8 | from tf_transformers.optimization import create_optimizer 9 | 10 | MODEL_NAME = 'gpt2' 11 | TOKENIZER_NAME = "google/bigbird-roberta-large" 12 | 13 | 14 | def get_model(return_all_layer_outputs, is_training, use_dropout, vocab_size): 15 | """Get the model from model function""" 16 | 17 | def model_fn(): 18 | # We use GPT2 Style model, but we use BigBird Roberta Tokenizer 19 | config = GPT2Model.get_config(MODEL_NAME) 20 | # We update the vocab_size for that reason 21 | config['vocab_size'] = vocab_size 22 | model = GPT2Model.from_config(config, mask_mode='user_defined', return_layer=True) 23 | model = MaskedLMModel( 24 | model, 25 | use_extra_mlm_layer=False, 26 | hidden_size=config['embedding_size'], 27 | layer_norm_epsilon=config['layer_norm_epsilon'], 28 | ) 29 | model = model.get_model() 30 | return model 31 | 32 | return model_fn 33 | 34 | 35 | def get_tokenizer(): 36 | tokenizer_layer = BigBirdRobertaTokenizerTFText.from_pretrained(TOKENIZER_NAME) 37 | return tokenizer_layer 38 | 39 | 40 | def get_optimizer(learning_rate, steps_per_epoch, epochs, warmup_rate, use_constant_lr=False): 41 | """Get AdamW optimizer""" 42 | 43 | # Total steps over all epochs 44 | num_train_steps = steps_per_epoch * epochs 45 | warmup_steps = int(warmup_rate * num_train_steps) 46 | 47 | def optimizer_fn(): 48 | if use_constant_lr: 49 | from tf_transformers.optimization.adam_weighted import AdamWeightDecay 50 | 51 | optimizer = AdamWeightDecay(learning_rate=learning_rate) 52 | return optimizer 53 | 54 | optimizer, learning_rate_fn = create_optimizer(learning_rate, num_train_steps, warmup_steps) 55 | return optimizer 56 | 57 | return optimizer_fn 58 | 59 | 60 | def get_loss(loss_type): 61 | """Get MLM Loss""" 62 | return get_lm_loss(loss_type=loss_type) 63 | 64 | 65 | def get_trainer(distribution_strategy, dtype, num_gpus=0, tpu_address=None): 66 | """Get Trainer""" 67 | trainer = Trainer(distribution_strategy, dtype=dtype, num_gpus=num_gpus, tpu_address=tpu_address) 68 | return trainer 69 | 70 | 71 | def get_hf_tokenizer(): 72 | """Get HuggingFace Tokenizer""" 73 | 
from transformers import BigBirdTokenizer 74 | 75 | tokenizer = BigBirdTokenizer.from_pretrained(TOKENIZER_NAME) 76 | return tokenizer 77 | -------------------------------------------------------------------------------- /research/masked_language_model/run_mlm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | """This is the main script to run MLM""" 18 | import os 19 | import warnings 20 | 21 | import hydra 22 | from absl import logging 23 | from omegaconf import DictConfig 24 | from train_mlm import run_train 25 | 26 | logging.set_verbosity("INFO") 27 | 28 | # We set PROJECT_NAME from ENVIORNMENT VARIABLE 29 | WANDB_PROJECT = os.getenv('WANDB_PROJECT', None) 30 | use_wandb = True 31 | if WANDB_PROJECT is None: 32 | warnings.warn( 33 | "For wandb-project should not be None.\ 34 | Set export WANDB_PROJECT=" 35 | ) 36 | use_wandb = False 37 | 38 | 39 | @hydra.main(config_path="conf", config_name="config") 40 | def run(cfg: DictConfig) -> None: 41 | print("Config", cfg) 42 | config_dict = dict(cfg) 43 | # For TPU, we need to initialize it before tf text dataset 44 | # starts triggering. 
Hack 45 | if cfg.trainer.strategy == 'tpu': 46 | from model import get_trainer 47 | 48 | distribution_strategy = 'tpu' 49 | num_gpus = 0 50 | tpu_address = cfg.trainer.tpu_address 51 | get_trainer( 52 | distribution_strategy=distribution_strategy, 53 | num_gpus=num_gpus, 54 | tpu_address=tpu_address, 55 | dtype=cfg.trainer.dtype, 56 | ) # noqa 57 | 58 | if use_wandb: 59 | import wandb 60 | 61 | wandb.init(project=WANDB_PROJECT, config=config_dict, sync_tensorboard=True) 62 | history = run_train(cfg, wandb) 63 | else: 64 | # Set wandb = None 65 | history = run_train(cfg, None) 66 | return history 67 | 68 | 69 | if __name__ == "__main__": 70 | history = run() 71 | -------------------------------------------------------------------------------- /research/masked_language_model_old/1_data_to_text.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import hydra 5 | from omegaconf import DictConfig, OmegaConf 6 | 7 | from tf_transformers.data.utils import hf_dump_chars_to_textfile 8 | 9 | # A logger for this file 10 | log = logging.getLogger(__name__) 11 | 12 | 13 | def write_data(cfg): 14 | """Load dataset and write to txt file""" 15 | output_file = cfg.data.output_text_file 16 | if os.path.isfile(output_file): 17 | raise FileExistsError() 18 | 19 | from datasets import load_dataset 20 | 21 | if cfg.data.version: 22 | dataset = load_dataset(cfg.data.name, cfg.data.version) 23 | else: 24 | dataset = load_dataset(cfg.data.name) 25 | 26 | split = cfg.data.split # train, test, dev 27 | data_keys = cfg.data.keys # text 28 | hf_dump_chars_to_textfile(output_file, dataset[split], data_keys, max_char=-1) 29 | 30 | 31 | @hydra.main(config_path="config", config_name="data_config") 32 | def run(cfg: DictConfig) -> None: 33 | print(OmegaConf.to_yaml(cfg)) 34 | write_data(cfg) 35 | 36 | 37 | if __name__ == "__main__": 38 | run() 39 | -------------------------------------------------------------------------------- /research/masked_language_model_old/2_text_to_features.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import tensorflow as tf 3 | from omegaconf import DictConfig, OmegaConf 4 | 5 | from tf_transformers.data import TFWriter 6 | from tf_transformers.text import SentencepieceTokenizer 7 | 8 | 9 | def load_tokenizer(cfg): 10 | """Load tf text based tokenizer""" 11 | model_file_path = cfg.tokenizer.model_file_path 12 | do_lower_case = cfg.tokenizer.do_lower_case 13 | special_tokens = cfg.tokenizer.special_tokens 14 | 15 | tokenizer_layer = SentencepieceTokenizer( 16 | model_file_path=model_file_path, lower_case=do_lower_case, special_tokens=special_tokens 17 | ) 18 | 19 | return tokenizer_layer 20 | 21 | 22 | def create_tfrecords(cfg): 23 | """Prepare tfrecords""" 24 | schema = { 25 | "input_ids": ("var_len", "int"), 26 | } 27 | 28 | tfrecord_output_dir = cfg.data.tfrecord_output_dir 29 | tfrecord_filename = cfg.data.tfrecord_filename 30 | tfrecord_nfiles = cfg.data.tfrecord_nfiles 31 | tfrecord_mode = cfg.data.tfrecord_mode 32 | tfrecord_overwrite = cfg.data.tfrecord_overwrite 33 | 34 | input_text_files = cfg.data.input_text_files 35 | batch_size = cfg.data.batch_size 36 | 37 | tfwriter = TFWriter( 38 | schema=schema, 39 | file_name=tfrecord_filename, 40 | model_dir=tfrecord_output_dir, 41 | tag=tfrecord_mode, 42 | n_files=tfrecord_nfiles, 43 | overwrite=tfrecord_overwrite, 44 | ) 45 | 46 | dataset = tf.data.TextLineDataset(input_text_files) 47 | 48 | def 
text_normalize(line): 49 | """Exclude empty string""" 50 | line = tf.strings.strip(line) 51 | return tf.not_equal(tf.strings.length(line), 0) 52 | 53 | dataset = dataset.filter(text_normalize) 54 | dataset = dataset.apply(tf.data.experimental.unique()) 55 | dataset = dataset.batch(batch_size, drop_remainder=False) 56 | 57 | def parse_train(): 58 | import tqdm 59 | 60 | tokenizer_layer = load_tokenizer(cfg) 61 | for batch_input in tqdm.tqdm(dataset): 62 | batch_input = {'text': [batch_input]} 63 | batch_tokenized = tokenizer_layer(batch_input)["input_ids"].to_list() 64 | for example_input_ids in batch_tokenized: 65 | yield {"input_ids": example_input_ids} 66 | 67 | # Process 68 | tfwriter.process(parse_fn=parse_train()) 69 | 70 | 71 | @hydra.main(config_path="config", config_name="tfrecord_config") 72 | def run(cfg: DictConfig) -> None: 73 | print(OmegaConf.to_yaml(cfg)) 74 | create_tfrecords(cfg) 75 | 76 | 77 | if __name__ == "__main__": 78 | run() 79 | -------------------------------------------------------------------------------- /research/masked_language_model_old/README.MD: -------------------------------------------------------------------------------- 1 | 2 | # Prepare Data 3 | 4 | python3 1_data_to_text.py data.name=wikipedia data.version=20200501.en data.output_text_file=/home/sidhu/Datasets/data/wikipedia.txt 5 | python3 1_data_to_text.py data.name=bookcorpus data.output_text_file=/home/sidhu/datasets/bookcorpus.txt 6 | 7 | # Prepare tfrecords 8 | 9 | nohup python3 2_text_to_features.py tokenizer.model_file_path=/home/sidhu/Datasets/data/t5_extended_vocab/new_spiece.model tokenizer.do_lower_case=false data.tfrecord_output_dir=/home/sidhu/Datasets/data/wiki_tfrecords data.tfrecord_filename=wiki data.tfrecord_nfiles=10 data.input_text_files=[/home/sidhu/Datasets/data/wikipedia.txt] data.batch_size=1024 > wiki_tfrecord.log & 10 | 11 | 12 | Bookcorpus 13 | 14 | nohup python3 2_text_to_features.py tokenizer.model_file_path=/home/sidhu/Datasets/data/t5_extended_vocab/new_spiece.model tokenizer.do_lower_case=false data.tfrecord_output_dir=/home/sidhu/Datasets/data/bookcorpus_tfrecords data.tfrecord_filename=bookcorpus data.tfrecord_nfiles=10 data.input_text_files=[/home/sidhu/Datasets/data/bookcorpus.txt] data.batch_size=1024 > bookcorpus_tfrecord.log & 15 | 16 | 17 | python3 train_mlm.py data.tfrecord_path_list=["/home/sidhu/Datasets/bookcorpus_tfrecords", "/home/sidhu/Datasets/wiki_tfrecords"] \ 18 | tokenizer.model_file_path=/home/Sidhu/Datasets/vocab/new_spiece.model 19 | 20 | python3 train_mlm.py tokenizer.model_file_path=/home/sidhu/Datasets/vocab/new_spiece.model \ 21 | model.model_save_dir=/home/sidhu/Projects/joint_bert 22 | -------------------------------------------------------------------------------- /research/masked_language_model_old/config/data_config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: mc4 3 | version: 4 | split: train 5 | output_text_file: 'output.txt' 6 | data_keys: ['text'] 7 | -------------------------------------------------------------------------------- /research/masked_language_model_old/config/tfrecord_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | tokenizer: 3 | model_file_path: 4 | do_lower_case: false 5 | special_tokens: ['[CLS]', '[MASK]', '', '', ''] 6 | data: 7 | tfrecord_output_dir: 8 | tfrecord_filename: 9 | tfrecord_nfiles: 10 | tfrecord_mode: train 11 | tfrecord_overwrite: false 12 | input_text_files: 13 | batch_size: 1024 14 | 
-------------------------------------------------------------------------------- /research/masked_language_model_old/config/train_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | tokenizer: 3 | vocab_size: 32002 4 | cls_token: '[CLS]' 5 | mask_token: '[MASK]' 6 | pad_token: '' 7 | sep_token: '' 8 | unk_token: '' 9 | model_file_path: 'new_spiece.model' 10 | do_lower_case: false 11 | special_tokens: ['[CLS]', '[MASK]', '', '', ''] 12 | data: 13 | max_seq_len: 128 14 | max_predictions_per_batch: 20 15 | batch_size: 512 16 | min_sen_len: 17 | tfrecord_path_list: ['path1, path2'] 18 | model: 19 | optimizer: 20 | learning_rate: 5e-5 21 | train_steps: 2000000 22 | warmup_steps: 60000 23 | optimizer_type: adamw 24 | loss: 25 | loss_type: joint 26 | epochs: 2 27 | steps_per_epoch: 200 28 | callback_steps: [100] 29 | model_save_dir: 30 | trainer: 31 | device_type: tpu 32 | device_address: local 33 | dtype: bf16 34 | -------------------------------------------------------------------------------- /research/mix_language_model/README.MD: -------------------------------------------------------------------------------- 1 | 2 | ### Sentence Masked Language Modeling using TFText 3 | 4 | This folder has all the necessary scripts to run Sentence MLM using tensorflow-text. 5 | Instead of masking words, we mask sentences (sequences of words). 6 | 7 | ### Advantage 8 | 9 | No need to prepare features in TFRecord format. We make use of dynamic MLM. All we need is a text file or a folder of text files, with each line corresponding to one text example. 10 | 11 | ### Configuration (Hydra) 12 | 13 | All or most configurations can be managed using ```conf/config.yaml```. You can also override it from the command line. 14 | 15 | E.g., for TPU, we need the data to be in GCS and model_checkpoint_dir to be in GCS too. 16 | 17 | ```python3 run_mix_lm.py task.data_directory= task.train_batch_size=128 trainer.dtype=bf16 trainer.model_checkpoint_dir= trainer.steps_per_epoch=50000 trainer.callback_steps=10000 trainer.epochs=20 trainer.strategy=tpu trainer.tpu_address= optimizer.learning_rate=5e-4``` 18 | 19 | ### WandB 20 | 21 | By default we are using Wandb. If the environment variable ```WANDB_PROJECT``` is not set, wandb will be disabled. 
22 | -------------------------------------------------------------------------------- /research/mix_language_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/research/mix_language_model/__init__.py -------------------------------------------------------------------------------- /research/mix_language_model/conf/config.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | data_directory: 3 | max_seq_len: 1024 4 | max_predictions_per_seq: 1024 5 | minimum_prefix_length: 900 6 | train_batch_size: 128 7 | trainer: 8 | dtype: bf16 9 | num_gpus: 0 10 | tpu_address: 11 | epochs: 20 12 | strategy: mirrored 13 | steps_per_epoch: 50000 14 | model_checkpoint_dir: 15 | global_norm: 1.0 16 | optimizer: 17 | learning_rate: 0.006 18 | num_warmup_steps: 0.1 19 | decay_function: cosine 20 | adam_beta_1: 0.9 21 | adam_beta_2: 0.95 22 | adam_epsilon: 10e-8 23 | weight_decay_rate: 0.1 24 | optimizer_type: adamw 25 | loss_type: 26 | use_constant_lr: false 27 | model: 28 | is_training: false 29 | use_dropout: false 30 | num_layers: 12 31 | -------------------------------------------------------------------------------- /research/mix_language_model/run_mix_lm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | """This is the main script to run Mix Language Model""" 18 | import os 19 | import warnings 20 | 21 | import hydra 22 | from absl import logging 23 | from omegaconf import DictConfig 24 | from train_mix_lm import run_train 25 | 26 | logging.set_verbosity("INFO") 27 | 28 | # We set PROJECT_NAME from ENVIORNMENT VARIABLE 29 | WANDB_PROJECT = os.getenv('WANDB_PROJECT', None) 30 | use_wandb = True 31 | if WANDB_PROJECT is None: 32 | warnings.warn( 33 | "For wandb-project should not be None.\ 34 | Set export WANDB_PROJECT=" 35 | ) 36 | use_wandb = False 37 | 38 | 39 | @hydra.main(config_path="conf", config_name="config") 40 | def run(cfg: DictConfig) -> None: 41 | print("Config", cfg) 42 | config_dict = dict(cfg) 43 | # For TPU, we need to initialize it before tf text dataset 44 | # starts triggering. 
Hack 45 | if cfg.trainer.strategy == 'tpu': 46 | from model import get_trainer 47 | 48 | distribution_strategy = 'tpu' 49 | num_gpus = 0 50 | tpu_address = cfg.trainer.tpu_address 51 | get_trainer( 52 | distribution_strategy=distribution_strategy, 53 | num_gpus=num_gpus, 54 | tpu_address=tpu_address, 55 | dtype=cfg.trainer.dtype, 56 | ) # noqa 57 | 58 | if use_wandb: 59 | import wandb 60 | 61 | wandb.init(project=WANDB_PROJECT, config=config_dict, sync_tensorboard=True) 62 | history = run_train(cfg, wandb) 63 | else: 64 | # Set wandb = None 65 | history = run_train(cfg, None) 66 | return history 67 | 68 | 69 | if __name__ == "__main__": 70 | history = run() 71 | -------------------------------------------------------------------------------- /research/mix_language_model_old/1_data_to_text.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import hydra 5 | from omegaconf import DictConfig, OmegaConf 6 | 7 | from tf_transformers.data.utils import hf_dump_chars_to_textfile 8 | 9 | # A logger for this file 10 | log = logging.getLogger(__name__) 11 | 12 | 13 | def write_data(cfg): 14 | """Load dataset and write to txt file""" 15 | output_file = cfg.data.output_text_file 16 | if os.path.isfile(output_file): 17 | raise FileExistsError() 18 | 19 | from datasets import load_dataset 20 | 21 | if cfg.data.version: 22 | dataset = load_dataset(cfg.data.name, cfg.data.version) 23 | else: 24 | dataset = load_dataset(cfg.data.name) 25 | 26 | split = cfg.data.split # train, test, dev 27 | data_keys = cfg.data.keys # text 28 | hf_dump_chars_to_textfile(output_file, dataset[split], data_keys, max_char=-1) 29 | 30 | 31 | @hydra.main(config_path="config", config_name="data_config") 32 | def run(cfg: DictConfig) -> None: 33 | print(OmegaConf.to_yaml(cfg)) 34 | write_data(cfg) 35 | 36 | 37 | if __name__ == "__main__": 38 | run() 39 | -------------------------------------------------------------------------------- /research/mix_language_model_old/2_text_to_features.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import tensorflow as tf 3 | from omegaconf import DictConfig, OmegaConf 4 | 5 | from tf_transformers.data import TFWriter 6 | from tf_transformers.text import SentencepieceTokenizer 7 | 8 | 9 | def load_tokenizer(cfg): 10 | """Load tf text based tokenizer""" 11 | model_file_path = cfg.tokenizer.model_file_path 12 | do_lower_case = cfg.tokenizer.do_lower_case 13 | special_tokens = cfg.tokenizer.special_tokens 14 | 15 | tokenizer_layer = SentencepieceTokenizer( 16 | model_file_path=model_file_path, lower_case=do_lower_case, special_tokens=special_tokens 17 | ) 18 | 19 | return tokenizer_layer 20 | 21 | 22 | def create_tfrecords(cfg): 23 | """Prepare tfrecords""" 24 | schema = { 25 | "input_ids": ("var_len", "int"), 26 | } 27 | 28 | tfrecord_output_dir = cfg.data.tfrecord_output_dir 29 | tfrecord_filename = cfg.data.tfrecord_filename 30 | tfrecord_nfiles = cfg.data.tfrecord_nfiles 31 | tfrecord_mode = cfg.data.tfrecord_mode 32 | tfrecord_overwrite = cfg.data.tfrecord_overwrite 33 | 34 | input_text_files = cfg.data.input_text_files 35 | batch_size = cfg.data.batch_size 36 | 37 | tfwriter = TFWriter( 38 | schema=schema, 39 | file_name=tfrecord_filename, 40 | model_dir=tfrecord_output_dir, 41 | tag=tfrecord_mode, 42 | n_files=tfrecord_nfiles, 43 | overwrite=tfrecord_overwrite, 44 | ) 45 | 46 | dataset = tf.data.TextLineDataset(input_text_files) 47 | 48 | def text_normalize(line): 
49 | """Exclude empty string""" 50 | line = tf.strings.strip(line) 51 | return tf.not_equal(tf.strings.length(line), 0) 52 | 53 | dataset = dataset.filter(text_normalize) 54 | dataset = dataset.apply(tf.data.experimental.unique()) 55 | dataset = dataset.batch(batch_size, drop_remainder=False) 56 | 57 | def parse_train(): 58 | import tqdm 59 | 60 | tokenizer_layer = load_tokenizer(cfg) 61 | for batch_input in tqdm.tqdm(dataset): 62 | batch_input = {'text': [batch_input]} 63 | batch_tokenized = tokenizer_layer(batch_input)["input_ids"].to_list() 64 | for example_input_ids in batch_tokenized: 65 | yield {"input_ids": example_input_ids} 66 | 67 | # Process 68 | tfwriter.process(parse_fn=parse_train()) 69 | 70 | 71 | @hydra.main(config_path="config", config_name="tfrecord_config") 72 | def run(cfg: DictConfig) -> None: 73 | print(OmegaConf.to_yaml(cfg)) 74 | create_tfrecords(cfg) 75 | 76 | 77 | if __name__ == "__main__": 78 | run() 79 | -------------------------------------------------------------------------------- /research/mix_language_model_old/README.MD: -------------------------------------------------------------------------------- 1 | 2 | # Prepare Data 3 | 4 | python3 1_data_to_text.py data.name=wikipedia data.version= data.output_text_file=/home/Sidhu/datasets/wikipedia.txt 5 | python3 1_data_to_text.py data.name=bookcorpus data.output_text_file=/home/Sidhu/datasets/bookcorpus.txt 6 | 7 | # Prepare tfrecords 8 | 9 | nohup python3 2_text_to_features.py tokenizer.model_file_path=/home/sidhu/Datasets/vocab/new_spiece.model tokenizer.do_lower_case=false data.tfrecord_output_dir=/home/sidhu/Datasets/wiki_tfrecords data.tfrecord_filename=wiki data.tfrecord_nfiles=25 data.input_text_files=[/home/sidhu/Datasets/wikipedia.txt] data.batch_size=1024 > wiki_tfrecord.log & 10 | 11 | 12 | Bookcorpus 13 | 14 | nohup python3 2_text_to_features.py tokenizer.model_file_path=/home/sidhu/Datasets/vocab/new_spiece.model tokenizer.do_lower_case=false data.tfrecord_output_dir=/home/sidhu/Datasets/bookcorpus_tfrecords data.tfrecord_filename=bookcorpus data.tfrecord_nfiles=10 data.input_text_files=[/home/sidhu/Datasets/bookcorpus.txt] data.batch_size=1024 > bookcorpus_tfrecord.log & 15 | 16 | 17 | python3 train_mix_mlm.py tokenizer.model_file_path=/home/sidhu/Datasets/vocab/new_spiece.model \ 18 | model.model_save_dir=/home/sidhu/Projects/joint_bert 19 | -------------------------------------------------------------------------------- /research/mix_language_model_old/config/data_config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: mc4 3 | version: 4 | split: train 5 | output_text_file: 'output.txt' 6 | data_keys: ['text'] 7 | -------------------------------------------------------------------------------- /research/mix_language_model_old/config/tfrecord_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | tokenizer: 3 | model_file_path: 4 | do_lower_case: false 5 | special_tokens: ['[CLS]', '[MASK]', '', '', ''] 6 | data: 7 | tfrecord_output_dir: 8 | tfrecord_filename: 9 | tfrecord_nfiles: 10 | tfrecord_mode: train 11 | tfrecord_overwrite: false 12 | input_text_files: 13 | batch_size: 1024 14 | -------------------------------------------------------------------------------- /research/mix_language_model_old/config/train_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | tokenizer: 3 | vocab_size: 32002 4 | cls_token: '[CLS]' 5 | mask_token: '[MASK]' 6 | 
pad_token: '' 7 | sep_token: '' 8 | unk_token: '' 9 | model_file_path: 'new_spiece.model' 10 | do_lower_case: false 11 | special_tokens: ['[CLS]', '[MASK]', '', '', ''] 12 | data: 13 | max_seq_len: 128 14 | max_predictions_per_batch: 20 15 | batch_size: 2048 16 | min_sen_len: 17 | tfrecord_path_list: ['path1, path2'] 18 | model: 19 | optimizer: 20 | learning_rate: 1e-4 21 | train_steps: 2000000 22 | warmup_steps: 60000 23 | optimizer_type: adamw 24 | loss: 25 | loss_type: 26 | epochs: 2 27 | steps_per_epoch: 200 28 | callback_steps: [100] 29 | model_save_dir: 30 | trainer: 31 | device_type: tpu 32 | device_address: local 33 | dtype: bf16 34 | -------------------------------------------------------------------------------- /research/sentence2vec/train.py: -------------------------------------------------------------------------------- 1 | import tensorflow_text as tf_text # noqa 2 | import wandb 3 | from dataset_loader import get_dataset 4 | from model import get_model, get_optimizer, loss_fn 5 | 6 | from tf_transformers.core import Trainer 7 | from tf_transformers.models import AlbertTokenizerTFText 8 | 9 | wandb.login() 10 | 11 | TPU_ADDRESS = 'legacyai-tpu-2' 12 | DTYPE = 'bf16' 13 | 14 | delim_regex_pattern = '\. ' # noqa 15 | window_length = 10 16 | minimum_sentences = 4 17 | batch_size = 512 18 | max_seq_length = 256 19 | 20 | learning_rate = 2e-5 21 | epochs = 50 22 | steps_per_epoch = 100000 23 | num_train_steps = steps_per_epoch * epochs 24 | num_warmup_steps = 0.1 * num_train_steps 25 | global_norm = 5.0 26 | optimizer_fn = get_optimizer(learning_rate, num_train_steps, num_warmup_steps, decay_function='linear') 27 | 28 | clip_logits = True 29 | use_random_base = False 30 | siamese = True 31 | 32 | model_checkpoint_dir = 'gs://legacyai-bucket/sentence2vec_1' 33 | 34 | WANDB_PROJECT = 'sentence2vec' 35 | config_dict = {} 36 | config_dict['learning_rate'] = learning_rate 37 | config_dict['steps_per_epoch'] = steps_per_epoch 38 | config_dict['epochs'] = epochs 39 | config_dict['num_train_steps'] = steps_per_epoch * epochs 40 | config_dict['num_warmup_steps'] = 0.1 * num_train_steps 41 | config_dict['global_norm'] = global_norm 42 | config_dict['model_checkpoint_dir'] = model_checkpoint_dir 43 | config_dict['clip_logits'] = clip_logits 44 | config_dict['use_random_base'] = use_random_base 45 | 46 | wandb.init(project=WANDB_PROJECT, config=config_dict, sync_tensorboard=True) 47 | 48 | 49 | trainer = Trainer(distribution_strategy='tpu', tpu_address=TPU_ADDRESS, dtype=DTYPE) 50 | 51 | tokenizer_layer = AlbertTokenizerTFText.from_pretrained("albert-base-v2", add_special_tokens=False) 52 | train_dataset = get_dataset( 53 | delim_regex_pattern, minimum_sentences, window_length, tokenizer_layer, max_seq_length, batch_size 54 | ) 55 | model_fn = get_model(clip_logits, use_random_base, siamese) 56 | 57 | 58 | # Train 59 | training_loss_names = ['loss_cls', 'loss_mean', 'loss'] 60 | history = trainer.run( 61 | model_fn=model_fn, 62 | optimizer_fn=optimizer_fn, 63 | train_dataset=train_dataset, 64 | train_loss_fn=loss_fn, 65 | epochs=epochs, 66 | steps_per_epoch=steps_per_epoch, 67 | model_checkpoint_dir=model_checkpoint_dir, 68 | batch_size=batch_size, 69 | training_loss_names=training_loss_names, 70 | repeat_dataset=True, 71 | wandb=wandb, 72 | clip_norm=global_norm, 73 | ) 74 | -------------------------------------------------------------------------------- /research/sentence_language_model/README.md: -------------------------------------------------------------------------------- 1 | 2 | ### 
Sentence Masked Language Modeling using TFText 3 | 4 | This folder has all the necessary scripts to run the Sentence MLM using TensorFlow Text. 5 | Instead of masking words, we mask sentences (sequences of words). 6 | 7 | ### Advantage 8 | 9 | No need to prepare features in TFRecord format. We make use of dynamic MLM. All we need is a text file, or a folder of text files, with each line corresponding to one piece of text. 10 | 11 | ### Configuration (Hydra) 12 | 13 | All or most configurations can be managed using ```conf/config.yaml```. You can also override them from the command line. 14 | 15 | E.g. for TPU, the data and the model_checkpoint_dir both need to be in GCS. 16 | 17 | ```python3 run_mlm.py data.data_directory= data.train_batch_size=128 trainer.dtype=bf16 trainer.model_checkpoint_dir= trainer.steps_per_epoch=50000 trainer.callback_steps=10000 trainer.epochs=20 trainer.strategy=tpu trainer.tpu_address= optimizer.learning_rate=5e-4``` 18 | 19 | ### WandB 20 | 21 | By default we are using WandB. If the environment variable ```WANDB_PROJECT``` is not set, WandB will be disabled. 22 | -------------------------------------------------------------------------------- /research/sentence_language_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/research/sentence_language_model/__init__.py -------------------------------------------------------------------------------- /research/sentence_language_model/conf/config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_directory: 3 | train_batch_size: 128 4 | task: 5 | max_seq_len: 1024 6 | max_predictions_per_seq: 200 7 | trainer: 8 | dtype: bf16 9 | num_gpus: 0 10 | tpu_address: 11 | epochs: 20 12 | strategy: mirrored 13 | steps_per_epoch: 50000 14 | model_checkpoint_dir: 15 | callback_steps: 10000 16 | optimizer: 17 | learning_rate: 1e-4 18 | warmup_rate: 0.2 19 | loss_type: 20 | use_constant_lr: false 21 | model: 22 | is_training: true 23 | use_dropout: true 24 | num_layers: 12 25 | -------------------------------------------------------------------------------- /research/sentence_language_model/dataset_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from random import shuffle 3 | 4 | import tensorflow as tf 5 | 6 | 7 | def get_dataset(data_directory, masked_lm_map_fn, batch_size): 8 | """Convert text to tf.data.Dataset after map fn 9 | 10 | Args: 11 | data_directory ([type]): [description] 12 | masked_lm_map_fn ([type]): [description] 13 | batch_size ([type]): [description] 14 | 15 | Returns: 16 | [type]: [description] 17 | """ 18 | 19 | def filter_out_empty_mask(x, y): 20 | """When an example doesn't have multiple sentences\ 21 | there won't be any masked sentence. Ignore those examples, 22 | as nothing to predict. 
23 | """ 24 | return tf.greater(tf.reduce_sum(tf.cast(tf.not_equal(x['masked_lm_positions'], 0), tf.int32)), 0) 25 | 26 | all_text_files = tf.io.gfile.glob(os.path.join(data_directory, '*.txt')) 27 | shuffle(all_text_files) 28 | ds = tf.data.TextLineDataset(all_text_files) 29 | 30 | # We need to add the text as dict 31 | ds = ds.map(lambda x: {'text': x}, num_parallel_calls=tf.data.AUTOTUNE) 32 | 33 | # Do MLM 34 | ds = ds.map(masked_lm_map_fn, num_parallel_calls=tf.data.AUTOTUNE) 35 | 36 | # Filter examples if there is not atleast single MASK sentence 37 | ds = ds.filter(filter_out_empty_mask) 38 | 39 | # Batch 40 | ds = ds.batch(batch_size, drop_remainder=True) 41 | 42 | # Shuffle and Prefetch 43 | ds = ds.shuffle(100, reshuffle_each_iteration=True).prefetch(buffer_size=tf.data.AUTOTUNE) 44 | 45 | # Auto SHARD 46 | options = tf.data.Options() 47 | options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO 48 | ds = ds.with_options(options) 49 | 50 | return ds 51 | -------------------------------------------------------------------------------- /research/sentence_language_model/model.py: -------------------------------------------------------------------------------- 1 | from tf_transformers.core import Trainer 2 | from tf_transformers.losses.loss_wrapper import get_lm_loss 3 | from tf_transformers.models import ( 4 | BigBirdRobertaTokenizerTFText, 5 | GPT2Model, 6 | MaskedLMModel, 7 | ) 8 | from tf_transformers.optimization import create_optimizer 9 | 10 | MODEL_NAME = 'gpt2' 11 | TOKENIZER_NAME = "google/bigbird-roberta-large" 12 | 13 | 14 | def get_model(return_all_layer_outputs, is_training, use_dropout, vocab_size): 15 | """Get the model from model function""" 16 | 17 | def model_fn(): 18 | # We use GPT2 Style model, but we use BigBird Roberta Tokenizer 19 | config = GPT2Model.get_config(MODEL_NAME) 20 | # We update the vocab_size for that reason 21 | config['vocab_size'] = vocab_size 22 | model = GPT2Model.from_config(config, mask_mode='user_defined', return_layer=True) 23 | model = MaskedLMModel( 24 | model, 25 | use_extra_mlm_layer=False, 26 | hidden_size=config['embedding_size'], 27 | layer_norm_epsilon=config['layer_norm_epsilon'], 28 | ) 29 | model = model.get_model() 30 | return model 31 | 32 | return model_fn 33 | 34 | 35 | def get_tokenizer(): 36 | tokenizer_layer = BigBirdRobertaTokenizerTFText.from_pretrained(TOKENIZER_NAME) 37 | return tokenizer_layer 38 | 39 | 40 | def get_optimizer(learning_rate, steps_per_epoch, epochs, warmup_rate, use_constant_lr=False): 41 | """Get AdamW optimizer""" 42 | 43 | # Total steps over all epochs 44 | num_train_steps = steps_per_epoch * epochs 45 | warmup_steps = int(warmup_rate * num_train_steps) 46 | 47 | def optimizer_fn(): 48 | if use_constant_lr: 49 | from tf_transformers.optimization.adam_weighted import AdamWeightDecay 50 | 51 | optimizer = AdamWeightDecay(learning_rate=learning_rate) 52 | return optimizer 53 | 54 | optimizer, learning_rate_fn = create_optimizer(learning_rate, num_train_steps, warmup_steps) 55 | return optimizer 56 | 57 | return optimizer_fn 58 | 59 | 60 | def get_loss(loss_type): 61 | """Get MLM Loss""" 62 | return get_lm_loss(loss_type=loss_type) 63 | 64 | 65 | def get_trainer(distribution_strategy, dtype, num_gpus=0, tpu_address=None): 66 | """Get Trainer""" 67 | trainer = Trainer(distribution_strategy, num_gpus=num_gpus, tpu_address=tpu_address, dtype=dtype) 68 | return trainer 69 | 70 | 71 | def get_hf_tokenizer(): 72 | """Get HuggingFace Tokenizer""" 73 | from transformers 
import BigBirdTokenizer 74 | 75 | tokenizer = BigBirdTokenizer.from_pretrained(TOKENIZER_NAME) 76 | return tokenizer 77 | -------------------------------------------------------------------------------- /research/sentence_language_model/run_mlm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | """This is the main script to run MLM""" 18 | import os 19 | import warnings 20 | 21 | import hydra 22 | from absl import logging 23 | from omegaconf import DictConfig 24 | from train_mlm import run_train 25 | 26 | logging.set_verbosity("INFO") 27 | 28 | # We set PROJECT_NAME from ENVIORNMENT VARIABLE 29 | WANDB_PROJECT = os.getenv('WANDB_PROJECT', None) 30 | use_wandb = True 31 | if WANDB_PROJECT is None: 32 | warnings.warn( 33 | "For wandb-project should not be None.\ 34 | Set export WANDB_PROJECT=" 35 | ) 36 | use_wandb = False 37 | 38 | 39 | @hydra.main(config_path="conf", config_name="config") 40 | def run(cfg: DictConfig) -> None: 41 | print("Config", cfg) 42 | config_dict = dict(cfg) 43 | # For TPU, we need to initialize it before tf text dataset 44 | # starts triggering. Hack 45 | if cfg.trainer.strategy == 'tpu': 46 | from model import get_trainer 47 | 48 | distribution_strategy = 'tpu' 49 | num_gpus = 0 50 | tpu_address = cfg.trainer.tpu_address 51 | get_trainer( 52 | distribution_strategy=distribution_strategy, 53 | num_gpus=num_gpus, 54 | tpu_address=tpu_address, 55 | dtype=cfg.trainer.dtype, 56 | ) # noqa 57 | 58 | if use_wandb: 59 | import wandb 60 | 61 | wandb.init(project=WANDB_PROJECT, config=config_dict, sync_tensorboard=True) 62 | history = run_train(cfg, wandb) 63 | else: 64 | # Set wandb = None 65 | history = run_train(cfg, None) 66 | return history 67 | 68 | 69 | if __name__ == "__main__": 70 | history = run() 71 | -------------------------------------------------------------------------------- /research/similarity_model_pretraining/README.MD: -------------------------------------------------------------------------------- 1 | 2 | ### Masked Language Modeling using TFText 3 | 4 | This folder has all necessary scripts to run the MLM using tensorflow text. 5 | 6 | ### Advantage 7 | 8 | No need to prepare features in TFRecord format. We make use of dynamic mlm. All we need is a text file or folder of text files with text files, which each line corrsponding to text. 9 | 10 | ### Configuration (Hydra) 11 | 12 | All or most configurations can be managed using ```conf/config.yaml```. You can override it by command line also. 13 | 14 | Eg: For TPU , we need a data in GCS and model_checkpoint_dir to be in GCS too. 
15 | 16 | ```python3 run_mlm.py \ data.data_directory=$GCP_BUCKET/data/ \ trainer.model_checkpoint_dir=$GCP_BUCKET/model``` 17 | 18 | ### WandB 19 | 20 | By default we are using Wandb. Check ```run_mlm.py``` to disable it. 21 | -------------------------------------------------------------------------------- /research/similarity_model_pretraining/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/research/similarity_model_pretraining/__init__.py -------------------------------------------------------------------------------- /research/similarity_model_pretraining/conf/config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_directory: 3 | train_batch_size: 32 4 | eval_batch_size: 32 5 | task: 6 | max_seq_len: 128 7 | max_predictions_per_seq: 40 8 | trainer: 9 | dtype: fp32 10 | num_gpus: 2 11 | tpu_address: 12 | epochs: 3 13 | strategy: mirrored 14 | steps_per_epoch: 10000 15 | model_checkpoint_dir: 16 | callback_steps: 1000 17 | optimizer: 18 | learning_rate: 5e-4 19 | warmup_rate: 0.1 20 | decay_function: cosine 21 | loss_type: 22 | use_constant_lr: false 23 | model: 24 | is_training: true 25 | use_dropout: true 26 | num_layers: 12 27 | -------------------------------------------------------------------------------- /research/similarity_model_pretraining/run_similarity.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | """This is the main script to run MLM""" 18 | import os 19 | import warnings 20 | 21 | import hydra 22 | from absl import logging 23 | from omegaconf import DictConfig 24 | from train_similairity import run_train 25 | 26 | logging.set_verbosity("INFO") 27 | 28 | # We set PROJECT_NAME from ENVIORNMENT VARIABLE 29 | WANDB_PROJECT = os.getenv('WANDB_PROJECT', None) 30 | use_wandb = True 31 | if WANDB_PROJECT is None: 32 | warnings.warn( 33 | "For wandb-project should not be None.\ 34 | Set export WANDB_PROJECT=" 35 | ) 36 | use_wandb = False 37 | 38 | 39 | @hydra.main(config_path="conf", config_name="config") 40 | def run(cfg: DictConfig) -> None: 41 | print("Config", cfg) 42 | config_dict = dict(cfg) 43 | # For TPU, we need to initialize it before tf text dataset 44 | # starts triggering. 
Hack 45 | if cfg.trainer.strategy == 'tpu': 46 | from model import get_trainer 47 | 48 | distribution_strategy = 'tpu' 49 | num_gpus = 0 50 | tpu_address = cfg.trainer.tpu_address 51 | get_trainer( 52 | distribution_strategy=distribution_strategy, 53 | num_gpus=num_gpus, 54 | tpu_address=tpu_address, 55 | dtype=cfg.trainer.dtype, 56 | ) # noqa 57 | 58 | if use_wandb: 59 | import wandb 60 | 61 | wandb.init(project=WANDB_PROJECT, config=config_dict, sync_tensorboard=True) 62 | history = run_train(cfg, wandb) 63 | else: 64 | # Set wandb = None 65 | history = run_train(cfg, None) 66 | return history 67 | 68 | 69 | if __name__ == "__main__": 70 | history = run() 71 | -------------------------------------------------------------------------------- /research/t5_style_pretraining/README.md: -------------------------------------------------------------------------------- 1 | 2 | ### Sentence Masked Language Modeling using TFText 3 | 4 | This folder has all necessary scripts to run the Sentence MLM using tensorflow text. 5 | Instead of masking words, we mask sentences (sequence of words) 6 | 7 | ### Advantage 8 | 9 | No need to prepare features in TFRecord format. We make use of dynamic mlm. All we need is a text file or folder of text files with text files, which each line corrsponding to text. 10 | 11 | ### WandB 12 | 13 | By default we are using Wandb. if enviornment variable ```WANDB_PROJECT=None```, wandb will be disabled. 14 | 15 | ``` export WANDB_PROJECT='t5-style-pretraining' ``` 16 | ### Configuration (Hydra) 17 | 18 | All or most configurations can be managed using ```conf/config.yaml```. You can override it by command line also. 19 | 20 | Eg: For TPU , we need a data in GCS and model_checkpoint_dir to be in GCS too. 21 | 22 | ``` nohup python3 run_t5_modified.py \ 23 | task.data_directory=gs://legacyai-bucket \ 24 | task.train_batch_size=128 \ 25 | trainer.dtype=bf16 \ 26 | trainer.model_checkpoint_dir=gs://legacyai-bucket/t5_style_t5_small_lr_0.0005 \ 27 | trainer.steps_per_epoch=10000 \ 28 | trainer.epochs=100 \ 29 | trainer.strategy=tpu \ 30 | trainer.tpu_address=legacyai-tpu-1 \ 31 | optimizer.learning_rate=0.01 \ 32 | model.is_training=true \ 33 | model.use_dropout=true \ 34 | model.model_name=t5-small > logs & 35 | ``` 36 | -------------------------------------------------------------------------------- /research/t5_style_pretraining/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/research/t5_style_pretraining/__init__.py -------------------------------------------------------------------------------- /research/t5_style_pretraining/conf/config.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | data_directory: 3 | max_seq_len: 256 4 | train_batch_size: 512 5 | trainer: 6 | dtype: bf16 7 | num_gpus: 0 8 | tpu_address: 9 | epochs: 100 10 | strategy: mirrored 11 | steps_per_epoch: 10000 12 | model_checkpoint_dir: 13 | global_norm: 5.0 14 | optimizer: 15 | learning_rate: 0.001 16 | num_warmup_steps: 0.1 17 | decay_function: cosine 18 | adam_beta_1: 0.9 19 | adam_beta_2: 0.95 20 | adam_epsilon: 10e-8 21 | weight_decay_rate: 0.1 22 | optimizer_type: adamw 23 | loss_type: 24 | use_constant_lr: false 25 | model: 26 | is_training: true 27 | use_dropout: true 28 | model_name: 29 | num_layers: 12 30 | -------------------------------------------------------------------------------- 
/research/t5_style_pretraining/run_t5_modified.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | """This is the main script to run Mix Language Model""" 18 | import os 19 | import warnings 20 | 21 | import hydra 22 | from absl import logging 23 | from omegaconf import DictConfig 24 | from train_t5_modified import run_train 25 | 26 | logging.set_verbosity("INFO") 27 | 28 | # We set PROJECT_NAME from ENVIORNMENT VARIABLE 29 | WANDB_PROJECT = os.getenv('WANDB_PROJECT', None) 30 | use_wandb = True 31 | if WANDB_PROJECT is None: 32 | warnings.warn( 33 | "For wandb-project should not be None.\ 34 | Set export WANDB_PROJECT=" 35 | ) 36 | use_wandb = False 37 | 38 | 39 | @hydra.main(config_path="conf", config_name="config") 40 | def run(cfg: DictConfig) -> None: 41 | print("Config", cfg) 42 | config_dict = dict(cfg) 43 | # For TPU, we need to initialize it before tf text dataset 44 | # starts triggering. Hack 45 | if cfg.trainer.strategy == 'tpu': 46 | from model import get_trainer 47 | 48 | distribution_strategy = 'tpu' 49 | num_gpus = 0 50 | tpu_address = cfg.trainer.tpu_address 51 | get_trainer( 52 | distribution_strategy=distribution_strategy, 53 | num_gpus=num_gpus, 54 | tpu_address=tpu_address, 55 | dtype=cfg.trainer.dtype, 56 | ) # noqa 57 | 58 | if use_wandb: 59 | import wandb 60 | 61 | wandb.init(project=WANDB_PROJECT, config=config_dict, sync_tensorboard=True) 62 | history = run_train(cfg, wandb) 63 | else: 64 | # Set wandb = None 65 | history = run_train(cfg, None) 66 | return history 67 | 68 | 69 | if __name__ == "__main__": 70 | history = run() 71 | -------------------------------------------------------------------------------- /src/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/src/logo.png -------------------------------------------------------------------------------- /src/logo2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/src/logo2.png -------------------------------------------------------------------------------- /src/tf_transformers/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "2.0.0" 2 | -------------------------------------------------------------------------------- /src/tf_transformers/activations/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Activations package definition.""" 16 | from tf_transformers.activations.gelu import gelu, quick_gelu 17 | from tf_transformers.activations.swish import hard_swish, identity, simple_swish 18 | from tf_transformers.activations.utils import get_activation 19 | -------------------------------------------------------------------------------- /src/tf_transformers/activations/gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Gaussian error linear unit.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import math 20 | 21 | import tensorflow as tf 22 | 23 | 24 | @tf.keras.utils.register_keras_serializable(package="Text") 25 | def gelu(x): 26 | """Gaussian Error Linear Unit. 27 | 28 | This is a smoother version of the RELU. 29 | Original paper: https://arxiv.org/abs/1606.08415 30 | Args: 31 | x: float Tensor to perform activation. 32 | 33 | Returns: 34 | `x` with the GELU activation applied. 35 | """ 36 | cdf = 0.5 * (1.0 + tf.tanh((math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3))))) 37 | return x * cdf 38 | 39 | 40 | @tf.keras.utils.register_keras_serializable(package="Text") 41 | def quick_gelu(x): 42 | """Quick GELU as in CLIP 43 | 44 | Returns: 45 | `x` with the Quick GELU activation applied. 46 | """ 47 | return x * tf.sigmoid(1.702 * x) 48 | -------------------------------------------------------------------------------- /src/tf_transformers/activations/swish.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Customized Swish activation.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import tensorflow as tf 20 | 21 | 22 | @tf.keras.utils.register_keras_serializable(package="Text") 23 | def simple_swish(features): 24 | """Computes the Swish activation function. 25 | 26 | The tf.nn.swish operation uses a custom gradient to reduce memory usage. 27 | Since saving custom gradients in SavedModel is currently not supported, and 28 | one would not be able to use an exported TF-Hub module for fine-tuning, we 29 | provide this wrapper that can allow to select whether to use the native 30 | TensorFlow swish operation, or whether to use a customized operation that 31 | has uses default TensorFlow gradient computation. 32 | 33 | Args: 34 | features: A `Tensor` representing preactivation values. 35 | 36 | Returns: 37 | The activation value. 38 | """ 39 | features = tf.convert_to_tensor(features) 40 | return features * tf.nn.sigmoid(features) 41 | 42 | 43 | @tf.keras.utils.register_keras_serializable(package="Text") 44 | def hard_swish(features): 45 | """Computes a hard version of the swish function. 46 | 47 | This operation can be used to reduce computational cost and improve 48 | quantization for edge devices. 49 | 50 | Args: 51 | features: A `Tensor` representing preactivation values. 52 | 53 | Returns: 54 | The activation value. 55 | """ 56 | features = tf.convert_to_tensor(features) 57 | return features * tf.nn.relu6(features + tf.constant(3.0)) * (1.0 / 6.0) 58 | 59 | 60 | @tf.keras.utils.register_keras_serializable(package="Text") 61 | def identity(features): 62 | """Computes the identity function. 63 | 64 | Useful for helping in quantization. 65 | 66 | Args: 67 | features: A `Tensor` representing preactivation values. 68 | 69 | Returns: 70 | The activation value. 71 | """ 72 | features = tf.convert_to_tensor(features) 73 | return tf.identity(features) 74 | -------------------------------------------------------------------------------- /src/tf_transformers/activations/utils.py: -------------------------------------------------------------------------------- 1 | import six 2 | import tensorflow as tf 3 | 4 | from tf_transformers import activations 5 | 6 | 7 | # TODO(hongkuny): consider moving custom string-map lookup to keras api. 8 | def get_activation(identifier): 9 | """Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`. 10 | 11 | It checks string first and if it is one of customized activation not in TF, 12 | the corresponding activation will be returned. For non-customized activation 13 | names and callable identifiers, always fallback to tf.keras.activations.get. 14 | 15 | Args: 16 | identifier: String name of the activation function or callable. 17 | 18 | Returns: 19 | A Python function corresponding to the activation function. 
20 | """ 21 | if isinstance(identifier, six.string_types): 22 | name_to_fn = { 23 | "gelu": activations.gelu, 24 | "simple_swish": activations.simple_swish, 25 | "hard_swish": activations.hard_swish, 26 | "identity": activations.identity, 27 | "relu": tf.keras.activations.relu, 28 | "quick_gelu": activations.quick_gelu, 29 | } 30 | identifier = str(identifier).lower() 31 | if identifier in name_to_fn: 32 | return tf.keras.activations.get(name_to_fn[identifier]) 33 | return tf.keras.activations.get(identifier) 34 | -------------------------------------------------------------------------------- /src/tf_transformers/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/src/tf_transformers/callbacks/__init__.py -------------------------------------------------------------------------------- /src/tf_transformers/callbacks/metric_callback_list.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | """This will return tf.keras.metric object based on name""" 18 | 19 | import tensorflow as tf 20 | 21 | _ALL_METRIC_NAMES = ['binary_accuracy'] 22 | 23 | 24 | def show_available_metric_names(): 25 | print(_ALL_METRIC_NAMES) 26 | 27 | 28 | def get_callback(metric_name: str): 29 | """Return tf.keras.metric with a name, for callback""" 30 | metric_name = metric_name.lower().strip() 31 | 32 | if metric_name not in _ALL_METRIC_NAMES: 33 | raise ValueError("{} not present in {}".format(metric_name, _ALL_METRIC_NAMES)) 34 | 35 | if metric_name == "binary_accuracy": 36 | return tf.keras.metrics.BinaryAccuracy, metric_name 37 | -------------------------------------------------------------------------------- /src/tf_transformers/callbacks/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from tf_transformers.callbacks.metrics.callbacks import MetricCallback 2 | from tf_transformers.callbacks.metrics.pearson_spearman_callback import ( 3 | PearsonSpearmanCallback, 4 | ) 5 | from tf_transformers.callbacks.metrics.sklearn_callbacks import SklearnMetricCallback 6 | from tf_transformers.callbacks.metrics.text_generation_callbacks import TextGenerationMetricCallback 7 | -------------------------------------------------------------------------------- /src/tf_transformers/core/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.core.chainer import ClassificationChainer, TextGenerationChainer 18 | from tf_transformers.core.legacy_layer import LegacyLayer 19 | from tf_transformers.core.legacy_model import LegacyModel 20 | from tf_transformers.core.legacy_module import LegacyModule, LegacyModuleCustom 21 | from tf_transformers.core.model_wrapper import ModelWrapper 22 | from tf_transformers.core.trainer import Trainer 23 | from tf_transformers.core.trainer_for_all import TrainerforAll 24 | from tf_transformers.core.transformer_config import TransformerConfig 25 | -------------------------------------------------------------------------------- /src/tf_transformers/core/model_utils_for_all.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from absl import logging 3 | 4 | 5 | def load_checkpoint_custom(model, checkpoint_dir=None, checkpoint_path=None, options=None, **kwargs): 6 | """[summary] 7 | 8 | Args: 9 | checkpoint_dir ([str]): [Location of the model] 10 | """ 11 | try: 12 | options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost") 13 | except: 14 | options = tf.CheckpointOptions(experimental_io_device="/job:localhost") 15 | 16 | if checkpoint_dir: 17 | if tf.io.gfile.exists(checkpoint_dir): 18 | if tf.io.gfile.isdir(checkpoint_dir) is False: 19 | raise ValueError("checkpoint_dir expects a directory not a file {}.".format(checkpoint_dir)) 20 | if checkpoint_path: 21 | if tf.io.gfile.isdir(checkpoint_path) is True: 22 | raise ValueError("checkpoint_path expects a checkpoint-file not a directory {}.".format(checkpoint_path)) 23 | checkpoint = tf.train.Checkpoint(model=model, **kwargs) 24 | if checkpoint_path is None and checkpoint_dir: 25 | checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir) 26 | if checkpoint_path is None: 27 | if checkpoint_dir: 28 | logging.info("No ❌❌ checkpoint found in {}".format(checkpoint_dir)) 29 | else: 30 | logging.info("No ❌❌ checkpoint found") 31 | return None 32 | else: 33 | if options: 34 | status = checkpoint.restore(checkpoint_path, options=options) 35 | else: 36 | status = checkpoint.restore(checkpoint_path) 37 | # Important 38 | if status.assert_existing_objects_matched(): 39 | logging.info("Successful ✅✅: Model checkpoints matched and loaded from {}".format(checkpoint_path)) 40 | return checkpoint 41 | else: 42 | logging.info("Failed ❌❌ to load the checkpoint. 
Status Assertion Failed.") 43 | return None 44 | -------------------------------------------------------------------------------- /src/tf_transformers/core/performance_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | # Map string to TensorFlow dtype 4 | DTYPE_MAP = { 5 | "fp16": tf.float16, 6 | "bf16": tf.bfloat16, 7 | "fp32": tf.float32, 8 | } 9 | 10 | 11 | def get_tf_dtype(dtype): 12 | return DTYPE_MAP[dtype] 13 | 14 | 15 | def is_float16(dtype): 16 | if dtype in [tf.float16]: 17 | return True 18 | return False 19 | 20 | 21 | def set_mixed_precision_policy(dtype): 22 | """Sets mix precision policy.""" 23 | if dtype == tf.float16: 24 | tf.keras.mixed_precision.set_global_policy('mixed_float16') 25 | elif dtype == tf.bfloat16: 26 | tf.keras.mixed_precision.set_global_policy('mixed_bfloat16') 27 | elif dtype == tf.float32: 28 | tf.keras.mixed_precision.set_global_policy('float32') 29 | else: 30 | raise ValueError('Unexpected dtype: %s' % dtype) 31 | 32 | 33 | def configure_optimizer(optimizer, use_float16=False, use_graph_rewrite=False, loss_scale='dynamic'): 34 | """Configures optimizer object with performance options.""" 35 | if use_float16: 36 | if loss_scale == 'dynamic': 37 | optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer) 38 | else: 39 | # loss_scale is a number. We interpret that as a fixed loss scale. 40 | optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer, dynamic=False, initial_scale=loss_scale) 41 | 42 | if use_graph_rewrite: 43 | # Note: the model dtype must be 'float32', which will ensure 44 | # tf.keras.mixed_precision and enable_mixed_precision_graph_rewrite do not 45 | # double up. 46 | optimizer = tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite(optimizer) 47 | return optimizer 48 | -------------------------------------------------------------------------------- /src/tf_transformers/core/read_from_hub.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import tensorflow as tf 5 | from absl import logging 6 | from huggingface_hub import hf_hub_download, snapshot_download 7 | 8 | logging.set_verbosity("INFO") 9 | 10 | 11 | def get_config_cache(url: str): 12 | """Load model from Huggingface hub""" 13 | local_cache = snapshot_download(repo_id=url) 14 | 15 | # Load config from cache 16 | config_path = Path(local_cache, "config.json") 17 | if not config_path.exists(): 18 | raise ValueError("config.json is not present in model hub {}".format(url)) 19 | config_dict = json.load(open(config_path)) 20 | return config_dict, local_cache 21 | 22 | 23 | def get_config_only(url: str): 24 | """Load config from Huggingface hub""" 25 | config_path = hf_hub_download(repo_id=url, filename="config.json") 26 | # Load config from cache 27 | config_dict = json.load(open(config_path)) 28 | return config_dict 29 | 30 | 31 | def load_pretrained_model(model: tf.keras.Model, local_cache: str, url: str): 32 | """Load model from cache""" 33 | try: 34 | local_device_option = tf.train.CheckpointOptions(experimental_io_device="/job:localhost") 35 | except: 36 | import traceback 37 | 38 | print(traceback.format_exc()) 39 | local_device_option = tf.CheckpointOptions(experimental_io_device="/job:localhost") 40 | else: 41 | local_device_option = None 42 | 43 | model.load_checkpoint(local_cache, options=local_device_option) 44 | logging.info("Successful ✅: Loaded model from {}".format(url)) 45 | 
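For orientation, a rough usage sketch of the hub helpers defined above; the repo id is purely a hypothetical placeholder, and the calls would hit the network when run.

```python
from tf_transformers.core.read_from_hub import get_config_cache, get_config_only

# Fetch only config.json for a model repo on the Hugging Face hub
# (the repo id below is an illustrative placeholder).
config_dict = get_config_only("some-org/some-model")

# Or pull the full snapshot: returns the config plus the local cache directory,
# which load_pretrained_model() expects alongside an already-built Keras model.
config_dict, local_cache = get_config_cache("some-org/some-model")
```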
-------------------------------------------------------------------------------- /src/tf_transformers/data/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.data.tfprocessor_utils import TFProcessor 18 | from tf_transformers.data.tfrecord_utils import TFReader, TFWriter 19 | from tf_transformers.data.utils import ( 20 | pad_dataset, 21 | pad_dataset_normal, 22 | pad_ragged, 23 | separate_x_y, 24 | ) 25 | -------------------------------------------------------------------------------- /src/tf_transformers/data/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/src/tf_transformers/data/callbacks/__init__.py -------------------------------------------------------------------------------- /src/tf_transformers/data/ner_utils_sp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors and The TensorFlow Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # ============================================================================== 17 | 18 | from tf_transformers.utils import fast_sp_alignment 19 | 20 | 21 | def get_tokens_labels(aligned_words, orig_to_new_index, label_tokens, sub_words_mapped, label_pad_token="[PAD]"): 22 | """ 23 | convert each sub word into labels 24 | If a word is split into multiple sub words, 25 | then first sub word is assigned with label and other sub words will be padded 26 | """ 27 | aligned_labels = [label_pad_token] * len(aligned_words) 28 | for original_pos, new_pos in enumerate(orig_to_new_index): 29 | aligned_labels[new_pos] = label_tokens[original_pos] 30 | 31 | flat_tokens = [] 32 | flat_labels = [] 33 | 34 | # The first word of the subword token is assigned entity 35 | # other tokens will be add PAD labels (we will mask it while training) 36 | assert len(aligned_words) == len(sub_words_mapped) == len(aligned_labels) 37 | for (_align_word, _align_word, _align_label) in zip(aligned_words, sub_words_mapped, aligned_labels): 38 | temp_w = [] 39 | for _align_word in _align_word: 40 | temp_w.append(_align_word) 41 | temp_l = [label_pad_token] * len(temp_w) 42 | temp_l[0] = _align_label 43 | flat_tokens.extend(temp_w) 44 | flat_labels.extend(temp_l) 45 | 46 | return flat_tokens, flat_labels 47 | 48 | 49 | def fast_tokenize_and_align_sentence_for_ner( 50 | tokenizer, sentence, word_tokens, SPECIAL_PIECE, is_training=False, label_tokens=None, label_pad_token=None 51 | ): 52 | 53 | """ 54 | align sentence sub words and labels using fast_sp 55 | """ 56 | orig_to_new_index, aligned_words, sub_words_mapped = fast_sp_alignment(sentence, tokenizer, SPECIAL_PIECE) 57 | 58 | if is_training: 59 | flat_tokens, flat_labels = get_tokens_labels( 60 | aligned_words, orig_to_new_index, label_tokens, sub_words_mapped, label_pad_token 61 | ) 62 | return aligned_words, sub_words_mapped, flat_tokens, flat_labels 63 | else: 64 | flat_tokens = [w for sub_words in sub_words_mapped for w in sub_words] 65 | return aligned_words, sub_words_mapped, flat_tokens, orig_to_new_index 66 | -------------------------------------------------------------------------------- /src/tf_transformers/data/processors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/src/tf_transformers/data/processors/__init__.py -------------------------------------------------------------------------------- /src/tf_transformers/data/processors/mlm_ttext.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/src/tf_transformers/data/processors/mlm_ttext.py -------------------------------------------------------------------------------- /src/tf_transformers/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.layers.bias_layer import BiasLayer 18 | from tf_transformers.layers.image_embeddings import ( 19 | PatchEmbeddings, 20 | PositionEmbeddingImage, 21 | ) 22 | from tf_transformers.layers.layer_normalization import ( 23 | GPT2LayerNormalization, 24 | T5LayerNormalization, 25 | ) 26 | from tf_transformers.layers.mlm_layer import MaskedLM 27 | from tf_transformers.layers.on_device_embedding import OnDeviceEmbedding 28 | from tf_transformers.layers.position_embedding import PositionEmbedding 29 | -------------------------------------------------------------------------------- /src/tf_transformers/layers/attention/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.layers.attention.bart_attention import BartAttention 18 | from tf_transformers.layers.attention.bert_attention import MultiHeadAttention 19 | from tf_transformers.layers.attention.bigbird_attention import BigBirdAttention 20 | from tf_transformers.layers.attention.clip_attention import CLIPMultiHeadAttention 21 | from tf_transformers.layers.attention.gpt2_attention import GPT2Attention 22 | from tf_transformers.layers.attention.t5_attention import T5Attention 23 | -------------------------------------------------------------------------------- /src/tf_transformers/layers/bias_layer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors and The TensorFlow Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # ============================================================================== 17 | import tensorflow as tf 18 | 19 | 20 | class BiasLayer(tf.keras.layers.Layer): 21 | def __init__(self, name="bias", trainable=True, initializer="zeros", *args, **kwargs): 22 | self._trainable = trainable 23 | self._initializer = initializer 24 | self._name = name 25 | super(BiasLayer, self).__init__(name=name, trainable=trainable, **kwargs) 26 | 27 | def build(self, input_shape): 28 | self.bias = self.add_weight( 29 | name="bias", shape=(input_shape[-1],), initializer=self._initializer, trainable=self._trainable 30 | ) 31 | 32 | def call(self, x): 33 | return x + self.bias 34 | -------------------------------------------------------------------------------- /src/tf_transformers/layers/image_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def read_and_process(img, img_height, img_width, num_channels, rescale, normalize, read): 5 | """Read and Process image. 6 | If read is True, tf.io will read and parse the image. 7 | If read is False, we expect a numpy array (PIL open image) or TF image (tf.uint8) 8 | 9 | Args: 10 | img (): tf.string Filepath or tf.uint8 Image array 11 | img_height (int): Image Height for resizing 12 | img_width (int): Image Width for resizing 13 | num_channels (int): 3 for RGB and 1 for GrayScale 14 | rescale (bool): rescale image to (0 to 1) 15 | normalize (bool): normalize image by 0 mean and 1 stddev 16 | read (bool): to read and decode the image or to skip it 17 | 18 | Returns: 19 | tf.float32 (image array (3D)) 20 | """ 21 | if read: 22 | # Read image 23 | img = tf.io.read_file(img) 24 | # convert the compressed string to a 3D uint8 tensor 25 | img = tf.image.decode_jpeg(img, channels=num_channels) 26 | 27 | # resize the image to the desired size 28 | img = tf.image.resize(img, [img_height, img_width]) 29 | # Rescale between (0 and 1) 30 | if rescale: 31 | img = tf.keras.layers.experimental.preprocessing.Rescaling(1.0 / 255)(img) 32 | # TODO tf.keras.layers.experimental.preprocessing.Normalization 33 | if normalize: 34 | img = tf.image.per_image_standardization(img) 35 | return img 36 | -------------------------------------------------------------------------------- /src/tf_transformers/layers/mask/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # ============================================================================== 17 | from tf_transformers.layers.mask.causal_mask import CausalMask 18 | from tf_transformers.layers.mask.cross_attention_mask import CrossAttentionMask 19 | from tf_transformers.layers.mask.prefix_mask import prefix_mask 20 | from tf_transformers.layers.mask.self_attention_mask import SelfAttentionMask 21 | -------------------------------------------------------------------------------- /src/tf_transformers/layers/mask/causal_mask.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors and The TensorFlow Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | """Keras layer that creates a causal attention mask.""" 18 | 19 | # from __future__ import google_type_annotations 20 | from __future__ import absolute_import, division, print_function 21 | 22 | import tensorflow as tf 23 | 24 | from tf_transformers.utils import tf_utils 25 | 26 | 27 | def attention_mask_square(nd): 28 | """1's in the lower triangle, counting from the lower right corner. 29 | 30 | Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs. 31 | """ 32 | dtype = tf_utils.get_dtype() 33 | ns = nd 34 | i = tf.range(nd)[:, None] 35 | j = tf.range(ns) 36 | m = i >= j - ns + nd 37 | return tf.cast(m, dtype) 38 | 39 | 40 | @tf.keras.utils.register_keras_serializable(package="Text") 41 | class CausalMask(tf.keras.layers.Layer): 42 | """Create a 3D causal (lower-triangular) attention mask from the input tensor. 43 | 44 | inputs: from_tensor: 2D or 3D Tensor of shape 45 | [batch_size, from_seq_length, ...]. 46 | No separate to_mask input is needed; the mask depends only on from_seq_length. 47 | 48 | Returns: 49 | float Tensor of shape [batch_size, from_seq_length, from_seq_length]. 50 | """ 51 | 52 | def call(self, inputs): 53 | from_tensor = inputs 54 | from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3]) 55 | batch_size = from_shape[0] 56 | from_seq_length = from_shape[1] 57 | 58 | # 2D Lower Triangular Mask 59 | from_mask = attention_mask_square(from_seq_length) 60 | 61 | # Replicate 2D `N` times 62 | mask = tf.cast(tf.ones([batch_size, 1, 1]), from_mask.dtype) * from_mask 63 | 64 | return mask 65 | -------------------------------------------------------------------------------- /src/tf_transformers/layers/mask/cross_attention_mask.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors and The TensorFlow Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | """Keras layer that creates a self-attention mask.""" 18 | 19 | # from __future__ import google_type_annotations 20 | from __future__ import absolute_import, division, print_function 21 | 22 | import tensorflow as tf 23 | 24 | from tf_transformers.utils import tf_utils 25 | 26 | 27 | @tf.keras.utils.register_keras_serializable(package="Text") 28 | class CrossAttentionMask(tf.keras.layers.Layer): 29 | """Create 3D attention mask from a 2D tensor mask. 30 | 31 | inputs[0]: from_tensor: 2D or 3D Tensor of shape 32 | [batch_size, from_seq_length, ...]. 33 | inputs[1]: to_mask: int32 Tensor of shape [batch_size, to_seq_length]. 34 | 35 | Returns: 36 | float Tensor of shape [batch_size, from_seq_length, to_seq_length]. 37 | """ 38 | 39 | def __init__(self, **kwargs): 40 | # We need to have a default dtype of float32, since the inputs (which Keras 41 | # usually uses to infer the dtype) will always be int32. 42 | if "dtype" not in kwargs: 43 | kwargs["dtype"] = tf_utils.get_dtype() 44 | 45 | super(CrossAttentionMask, self).__init__(**kwargs) 46 | self._dtype = kwargs["dtype"] 47 | 48 | def call(self, inputs): 49 | to_mask = inputs[1] 50 | batch_size, from_seq_length = tf_utils.get_shape_list(inputs[0]) 51 | _, to_seq_length = tf_utils.get_shape_list(inputs[1]) 52 | 53 | to_mask = tf.cast(tf.reshape(to_mask, [batch_size, 1, to_seq_length]), dtype=self._dtype) 54 | 55 | # We don't assume that `from_tensor` is a mask (although it could be). We 56 | # don't actually care if we attend *from* padding tokens (only *to* padding) 57 | # tokens so we create a tensor of all ones. 58 | # 59 | # `broadcast_ones` = [batch_size, from_seq_length, 1] 60 | broadcast_ones = tf.ones(shape=[batch_size, from_seq_length, 1], dtype=self._dtype) 61 | 62 | # Here we broadcast along two dimensions to create the mask. 63 | mask = broadcast_ones * to_mask 64 | 65 | return mask 66 | -------------------------------------------------------------------------------- /src/tf_transformers/layers/mask/self_attention_mask.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors and The TensorFlow Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # ============================================================================== 17 | """Keras layer that creates a self-attention mask.""" 18 | 19 | # from __future__ import google_type_annotations 20 | from __future__ import absolute_import, division, print_function 21 | 22 | import tensorflow as tf 23 | 24 | from tf_transformers.utils import tf_utils 25 | 26 | 27 | @tf.keras.utils.register_keras_serializable(package="Text") 28 | class SelfAttentionMask(tf.keras.layers.Layer): 29 | """Create 3D attention mask from a 2D tensor mask. 30 | 31 | inputs[0]: from_tensor: 2D or 3D Tensor of shape 32 | [batch_size, from_seq_length, ...]. 33 | inputs[1]: to_mask: int32 Tensor of shape [batch_size, to_seq_length]. 34 | 35 | Returns: 36 | float Tensor of shape [batch_size, from_seq_length, to_seq_length]. 37 | """ 38 | 39 | def call(self, inputs): 40 | """ 41 | 42 | Args: 43 | inputs : List of ([embeddings, input_mask]) 44 | embeddings: 3D (b x s x h) 45 | input_mask: 2D (b x s) 46 | Returns: 47 | Tensor: (b x from_seq_length x to_seq_length) 48 | """ 49 | from_tensor = inputs[0] 50 | to_mask = inputs[1] 51 | from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3]) 52 | batch_size = from_shape[0] 53 | from_seq_length = from_shape[1] 54 | 55 | to_shape = tf_utils.get_shape_list(to_mask, expected_rank=2) 56 | to_seq_length = to_shape[1] 57 | 58 | to_mask = tf.cast(tf.reshape(to_mask, [batch_size, 1, to_seq_length]), dtype=from_tensor.dtype) 59 | 60 | # We don't assume that `from_tensor` is a mask (although it could be). We 61 | # don't actually care if we attend *from* padding tokens (only *to* padding) 62 | # tokens so we create a tensor of all ones. 63 | # 64 | # `broadcast_ones` = [batch_size, from_seq_length, 1] 65 | broadcast_ones = tf.ones(shape=[batch_size, from_seq_length, 1], dtype=from_tensor.dtype) 66 | 67 | # Here we broadcast along two dimensions to create the mask. 68 | mask = broadcast_ones * to_mask 69 | 70 | return mask 71 | -------------------------------------------------------------------------------- /src/tf_transformers/layers/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # ============================================================================== 17 | from tf_transformers.layers.transformer.bart_transformer import TransformerBART 18 | from tf_transformers.layers.transformer.bert_transformer import TransformerBERT 19 | from tf_transformers.layers.transformer.clip_transformer import TransformerCLIP 20 | from tf_transformers.layers.transformer.gpt2_transformer import TransformerGPT2 21 | from tf_transformers.layers.transformer.mt5_transformer import TransformermT5 22 | from tf_transformers.layers.transformer.t5_transformer import TransformerT5 23 | from tf_transformers.layers.transformer.byt5_transformer import TransformerByT5 24 | from tf_transformers.layers.transformer.vit_transformer import TransformerVIT 25 | -------------------------------------------------------------------------------- /src/tf_transformers/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.losses.cross_entropy import ( 18 | cross_entropy_loss, 19 | cross_entropy_loss_for_classification, 20 | cross_entropy_loss_label_smoothing, 21 | ) 22 | -------------------------------------------------------------------------------- /src/tf_transformers/models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # ============================================================================== 17 | from tf_transformers.models.albert import ( 18 | AlbertConfig, 19 | AlbertEncoder, 20 | AlbertModel, 21 | AlbertTokenizerLayer, 22 | AlbertTokenizerTFText, 23 | ) 24 | from tf_transformers.models.bart import BartConfig, BartEncoder, BartModel 25 | from tf_transformers.models.bert import BertConfig, BertEncoder, BertModel 26 | from tf_transformers.models.bigbird import ( 27 | BigBirdRobertaTokenizerLayer, 28 | BigBirdRobertaTokenizerTFText, 29 | ) 30 | from tf_transformers.models.clip import ( 31 | CLIPEncoder, 32 | CLIPFeatureExtractorTF, 33 | CLIPImageConfig, 34 | CLIPImageEncoder, 35 | CLIPModel, 36 | CLIPTextConfig, 37 | CLIPTextEncoder, 38 | ) 39 | from tf_transformers.models.distilbert import DistilBertConfig, DistilBertModel 40 | from tf_transformers.models.encoder_decoder import EncoderDecoder 41 | from tf_transformers.models.gpt2 import GPT2Config, GPT2Encoder, GPT2Model 42 | from tf_transformers.models.minilm import MiniLMConfig, MiniLMModel 43 | from tf_transformers.models.mt5 import MT5Config, MT5Encoder, MT5Model 44 | from tf_transformers.models.roberta import RobertaConfig, RobertaEncoder, RobertaModel 45 | from tf_transformers.models.sentence_transformers import SentenceTransformer 46 | from tf_transformers.models.t5 import ( 47 | T5Config, 48 | T5Encoder, 49 | T5Model, 50 | T5TokenizerLayer, 51 | T5TokenizerTFText, 52 | ) 53 | from tf_transformers.models.byt5 import ( 54 | ByT5Config, 55 | ByT5Encoder, 56 | ByT5Model, 57 | ) 58 | from tf_transformers.models.tasks import ( 59 | Classification_Model, 60 | MaskedLMModel, 61 | Similarity_Model, 62 | Span_Selection_Model, 63 | ) 64 | from tf_transformers.models.vit import ( 65 | ViTConfig, 66 | ViTEncoder, 67 | ViTFeatureExtractorTF, 68 | ViTModel, 69 | ) 70 | -------------------------------------------------------------------------------- /src/tf_transformers/models/albert/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.models.albert.albert import AlbertEncoder 18 | from tf_transformers.models.albert.albert_model import AlbertModel 19 | from tf_transformers.models.albert.configuration_albert import AlbertConfig 20 | from tf_transformers.models.albert.tokenizer_albert import ( 21 | AlbertTokenizerLayer, 22 | AlbertTokenizerTFText, 23 | ) 24 | -------------------------------------------------------------------------------- /src/tf_transformers/models/bart/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.models.bart.bart import BartEncoder 18 | from tf_transformers.models.bart.bart_model import BartModel 19 | from tf_transformers.models.bart.configuration_bart import BartConfig 20 | -------------------------------------------------------------------------------- /src/tf_transformers/models/bert/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.models.bert.bert import BertEncoder 18 | from tf_transformers.models.bert.bert_model import BertModel 19 | from tf_transformers.models.bert.configuration_bert import BertConfig 20 | -------------------------------------------------------------------------------- /src/tf_transformers/models/bigbird/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.models.bigbird.tokenizer_bigbird_roberta import ( 18 | BigBirdRobertaTokenizerLayer, 19 | BigBirdRobertaTokenizerTFText, 20 | ) 21 | -------------------------------------------------------------------------------- /src/tf_transformers/models/byt5/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.models.byt5.configuration_byt5 import ByT5Config 18 | from tf_transformers.models.byt5.byt5 import ByT5Encoder 19 | from tf_transformers.models.byt5.byt5_model import ByT5Model 20 | -------------------------------------------------------------------------------- /src/tf_transformers/models/clip/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.models.clip.clip import CLIPEncoder 18 | from tf_transformers.models.clip.clip_image_encoder import CLIPImageEncoder 19 | from tf_transformers.models.clip.clip_text_encoder import CLIPTextEncoder 20 | from tf_transformers.models.clip.configuration_clip import ( 21 | CLIPImageConfig, 22 | CLIPTextConfig, 23 | ) 24 | from tf_transformers.models.clip.clip_feature_extractor import CLIPFeatureExtractorTF 25 | from tf_transformers.models.clip.clip_model import CLIPModel 26 | -------------------------------------------------------------------------------- /src/tf_transformers/models/distilbert/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # ============================================================================== 17 | from tf_transformers.models.bert.bert import BertEncoder 18 | from tf_transformers.models.distilbert.configuration_bert import DistilBertConfig 19 | from tf_transformers.models.distilbert.distilbert_model import DistilBertModel 20 | -------------------------------------------------------------------------------- /src/tf_transformers/models/encoder_decoder/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.models.encoder_decoder.encoder_decoder import EncoderDecoder 18 | -------------------------------------------------------------------------------- /src/tf_transformers/models/gpt2/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.models.gpt2.configuration_gpt2 import GPT2Config 18 | from tf_transformers.models.gpt2.gpt2 import GPT2Encoder 19 | from tf_transformers.models.gpt2.gpt2_model import GPT2Model 20 | -------------------------------------------------------------------------------- /src/tf_transformers/models/minilm/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # ============================================================================== 17 | from tf_transformers.models.bert.bert import BertEncoder 18 | from tf_transformers.models.minilm.configuration_minilm import MiniLMConfig 19 | from tf_transformers.models.minilm.minilm_model import MiniLMModel 20 | -------------------------------------------------------------------------------- /src/tf_transformers/models/model_configs/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.models.model_configs.albert import albert_base_v2, albert_large_v2 18 | from tf_transformers.models.model_configs.bert import ( 19 | bert_base_cased, 20 | bert_base_uncased, 21 | bert_large_cased, 22 | bert_large_uncased, 23 | ) 24 | from tf_transformers.models.model_configs.general_config import TransformerConfig 25 | from tf_transformers.models.model_configs.gpt2 import gpt2, gpt2_medium 26 | from tf_transformers.models.model_configs.mt5 import mt5_small 27 | from tf_transformers.models.model_configs.roberta import roberta_base, roberta_large 28 | from tf_transformers.models.model_configs.t5 import t5_base, t5_small 29 | -------------------------------------------------------------------------------- /src/tf_transformers/models/model_configs/albert/albert_base_v2.py: -------------------------------------------------------------------------------- 1 | config = { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "intermediate_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "embedding_size": 128, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "attention_head_size": 64, 12 | "num_hidden_layers": 12, 13 | "type_vocab_size": 2, 14 | "vocab_size": 30000, 15 | "layer_norm_epsilon": 1e-12, 16 | "embedding_projection_size": 768, 17 | } 18 | -------------------------------------------------------------------------------- /src/tf_transformers/models/model_configs/albert/albert_large_v2.py: -------------------------------------------------------------------------------- 1 | config = { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "intermediate_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "embedding_size": 128, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 16, 11 | "attention_head_size": 64, 12 | "num_hidden_layers": 24, 13 | "type_vocab_size": 2, 14 | "vocab_size": 30000, 15 | "layer_norm_epsilon": 1e-12, 16 | "embedding_projection_size": 1024, 17 | } 18 | -------------------------------------------------------------------------------- 
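The plain-dict configs in this model_configs directory can be read back through get_config in src/tf_transformers/utils/utils.py (shown later in this listing), which imports the named module and returns its config attribute. A minimal sketch under the assumption that the imports in model_configs/__init__.py above are in place; the printed values come from the albert_base_v2 dict above:

# Illustrative usage sketch only; not part of the repository files.
from tf_transformers.utils import get_config

# get_config(module_name, config_name) -> getattr(import_module(module_name), config_name).config
albert_config = get_config("tf_transformers.models.model_configs", "albert_base_v2")
print(albert_config["num_hidden_layers"], albert_config["embedding_projection_size"])  # 12 768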
/src/tf_transformers/models/model_configs/bert/bert_base_cased.py: -------------------------------------------------------------------------------- 1 | config = { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "intermediate_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "embedding_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "attention_head_size": 64, 12 | "num_hidden_layers": 12, 13 | "type_vocab_size": 2, 14 | "vocab_size": 28996, 15 | "layer_norm_epsilon": 1e-12, 16 | } 17 | -------------------------------------------------------------------------------- /src/tf_transformers/models/model_configs/bert/bert_base_uncased.py: -------------------------------------------------------------------------------- 1 | config = { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "intermediate_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "embedding_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 12, 12 | "attention_head_size": 64, 13 | "type_vocab_size": 2, 14 | "vocab_size": 30522, 15 | "layer_norm_epsilon": 1e-12, 16 | } 17 | -------------------------------------------------------------------------------- /src/tf_transformers/models/model_configs/bert/bert_large_cased.py: -------------------------------------------------------------------------------- 1 | config = { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "intermediate_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "embedding_size": 1024, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 4096, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 16, 11 | "attention_head_size": 64, 12 | "num_hidden_layers": 24, 13 | "type_vocab_size": 2, 14 | "vocab_size": 28996, 15 | "layer_norm_epsilon": 1e-12, 16 | } 17 | -------------------------------------------------------------------------------- /src/tf_transformers/models/model_configs/bert/bert_large_uncased.py: -------------------------------------------------------------------------------- 1 | config = { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "intermediate_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "embedding_size": 1024, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 4096, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 16, 11 | "attention_head_size": 64, 12 | "num_hidden_layers": 24, 13 | "type_vocab_size": 2, 14 | "vocab_size": 30522, 15 | "layer_norm_epsilon": 1e-12, 16 | } 17 | -------------------------------------------------------------------------------- /src/tf_transformers/models/model_configs/gpt2/gpt2.py: -------------------------------------------------------------------------------- 1 | config = { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "intermediate_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "embedding_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 1024, 10 | "num_attention_heads": 12, 11 | "attention_head_size": 64, 12 | "num_hidden_layers": 12, 13 | "type_vocab_size": -1, 14 | "vocab_size": 50257, 15 | "layer_norm_epsilon": 1e-05, 16 | "mask_mode": "causal", 17 | } 18 | -------------------------------------------------------------------------------- /src/tf_transformers/models/model_configs/gpt2/gpt2_medium.py: 
-------------------------------------------------------------------------------- 1 | config = { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "intermediate_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "embedding_size": 1024, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 4096, 9 | "max_position_embeddings": 1024, 10 | "num_attention_heads": 16, 11 | "attention_head_size": 64, 12 | "num_hidden_layers": 24, 13 | "type_vocab_size": -1, 14 | "vocab_size": 50257, 15 | "layer_norm_epsilon": 1e-05, 16 | "mask_mode": "causal", 17 | } 18 | -------------------------------------------------------------------------------- /src/tf_transformers/models/model_configs/mt5/mt5_small.py: -------------------------------------------------------------------------------- 1 | config = { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "intermediate_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "embedding_size": 512, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 1024, 9 | "max_position_embeddings": -1, 10 | "num_attention_heads": 6, 11 | "attention_head_size": 64, 12 | "num_hidden_layers": 8, 13 | "vocab_size": 250112, 14 | "type_vocab_size": -1, 15 | "layer_norm_epsilon": 1e-06, 16 | "bidirectional": True, 17 | "positional_buckets": 32, 18 | } 19 | -------------------------------------------------------------------------------- /src/tf_transformers/models/model_configs/roberta/roberta_base.py: -------------------------------------------------------------------------------- 1 | config = { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "intermediate_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "embedding_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "attention_head_size": 64, 12 | "num_hidden_layers": 12, 13 | "type_vocab_size": 1, 14 | "vocab_size": 50265, 15 | "layer_norm_epsilon": 1e-05, 16 | } 17 | -------------------------------------------------------------------------------- /src/tf_transformers/models/model_configs/roberta/roberta_large.py: -------------------------------------------------------------------------------- 1 | config = { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "intermediate_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "embedding_size": 1024, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 16, 11 | "attention_head_size": 64, 12 | "num_hidden_layers": 24, 13 | "type_vocab_size": 1, 14 | "vocab_size": 50265, 15 | "layer_norm_epsilon": 1e-05, 16 | } 17 | -------------------------------------------------------------------------------- /src/tf_transformers/models/model_configs/t5/t5_base.py: -------------------------------------------------------------------------------- 1 | config = { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "intermediate_act": "relu", 5 | "hidden_dropout_prob": 0.1, 6 | "embedding_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": -1, 10 | "num_attention_heads": 12, 11 | "attention_head_size": 64, 12 | "num_hidden_layers": 12, 13 | "vocab_size": 32128, 14 | "type_vocab_size": -1, 15 | "layer_norm_epsilon": 1e-06, 16 | "bidirectional": True, 17 | "positional_buckets": 32, 18 | } 19 | -------------------------------------------------------------------------------- 
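The bidirectional and positional_buckets entries in the T5/mT5 configs in this directory appear to be the same knobs accepted by compute_positional_bias in src/tf_transformers/utils/positional_bias_utils.py (shown later in this listing). A small, hedged sketch of what that helper returns; the qlen and klen values are illustrative:

# Illustrative sketch only: bucketised relative positions for T5-style attention bias.
from tf_transformers.utils.positional_bias_utils import compute_positional_bias

# Returns a (qlen, klen) int32 tensor of relative-position bucket ids in [0, num_buckets).
rp_bucket = compute_positional_bias(qlen=8, klen=8, bidirectional=True, num_buckets=32)
print(rp_bucket.shape)  # (8, 8)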
/src/tf_transformers/models/model_configs/t5/t5_small.py: -------------------------------------------------------------------------------- 1 | config = { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "intermediate_act": "relu", 5 | "hidden_dropout_prob": 0.1, 6 | "embedding_size": 512, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 2048, 9 | "max_position_embeddings": -1, 10 | "num_attention_heads": 8, 11 | "attention_head_size": 64, 12 | "num_hidden_layers": 6, 13 | "vocab_size": 32128, 14 | "type_vocab_size": -1, 15 | "layer_norm_epsilon": 1e-06, 16 | "bidirectional": True, 17 | "positional_buckets": 32, 18 | } 19 | -------------------------------------------------------------------------------- /src/tf_transformers/models/model_configs/unilm_cnndm/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "intermediate_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "embedding_size": 1024, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 4096, 9 | "max_position_embeddings": 768, 10 | "num_attention_heads": 16, 11 | "num_hidden_layers": 24, 12 | "type_vocab_size": 6, 13 | "vocab_size": 28996, 14 | "layer_norm_epsilon": 1e-05 15 | } 16 | -------------------------------------------------------------------------------- /src/tf_transformers/models/mt5/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.models.mt5.configuration_mt5 import MT5Config 18 | from tf_transformers.models.mt5.mt5 import MT5Encoder 19 | from tf_transformers.models.mt5.mt5_model import MT5Model 20 | -------------------------------------------------------------------------------- /src/tf_transformers/models/roberta/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # ============================================================================== 17 | from tf_transformers.models.roberta.configuration_roberta import RobertaConfig 18 | from tf_transformers.models.roberta.roberta import RobertaEncoder 19 | from tf_transformers.models.roberta.roberta_model import RobertaModel 20 | -------------------------------------------------------------------------------- /src/tf_transformers/models/sentence_transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.models.sentence_transformers.sentence_transformers import ( 18 | SentenceTransformer, 19 | ) 20 | -------------------------------------------------------------------------------- /src/tf_transformers/models/t5/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.models.t5.configuration_t5 import T5Config 18 | from tf_transformers.models.t5.t5 import T5Encoder 19 | from tf_transformers.models.t5.t5_model import T5Model 20 | from tf_transformers.models.t5.tokenizer_t5 import T5TokenizerLayer, T5TokenizerTFText 21 | -------------------------------------------------------------------------------- /src/tf_transformers/models/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.models.tasks.classification import Classification_Model 18 | from tf_transformers.models.tasks.maked_lm_model import MaskedLMModel 19 | from tf_transformers.models.tasks.similarity_model import Similarity_Model 20 | from tf_transformers.models.tasks.span_selection import Span_Selection_Model 21 | -------------------------------------------------------------------------------- /src/tf_transformers/models/vit/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.models.vit.configuration_vit import ViTConfig 18 | from tf_transformers.models.vit.vit import ViTEncoder 19 | from tf_transformers.models.vit.vit_feature_extractor import ViTFeatureExtractorTF 20 | from tf_transformers.models.vit.vit_model import ViTModel 21 | -------------------------------------------------------------------------------- /src/tf_transformers/optimization/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.optimization.optimization import create_optimizer 18 | -------------------------------------------------------------------------------- /src/tf_transformers/text/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | from tf_transformers.text.decoder_utils import ( 18 | _gather_beams, 19 | _log_prob_from_logits, 20 | assign_zeros_to_K_V, 21 | top_k_logits, 22 | top_p_logits, 23 | ) 24 | from tf_transformers.text.sentencepiece_layer import SentencepieceTokenizer 25 | from tf_transformers.text.text_decoder import TextDecoder, TextDecoderSerializable 26 | from tf_transformers.text.text_decoder_model import TextDecoderModel 27 | from tf_transformers.text.text_decoder_seq2seq import TextDecoderSeq2Seq 28 | from tf_transformers.text.text_decoder_seq2seq_serializable import ( 29 | TextDecoderSerializableSeq2Seq, 30 | ) 31 | -------------------------------------------------------------------------------- /src/tf_transformers/text/lm_tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from tf_transformers.text.lm_tasks.causal_lm import causal_lm_fn 2 | from tf_transformers.text.lm_tasks.masked_lm import mlm_fn 3 | from tf_transformers.text.lm_tasks.prefix_lm import prefix_lm_fn, prefix_lm_fn_v2 4 | -------------------------------------------------------------------------------- /src/tf_transformers/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 TF-Transformers Authors. 3 | # All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # ============================================================================== 17 | from tf_transformers.utils.fast_sp_alignment import fast_sp_alignment 18 | from tf_transformers.utils.tokenization import BasicTokenizer 19 | from tf_transformers.utils.utils import ( 20 | get_config, 21 | get_model_wrapper, 22 | validate_model_name, 23 | ) 24 | -------------------------------------------------------------------------------- /src/tf_transformers/utils/docstring_file_utils.py: -------------------------------------------------------------------------------- 1 | def add_start_docstrings(*docstr): 2 | def docstring_decorator(fn): 3 | fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") 4 | return fn 5 | 6 | return docstring_decorator 7 | -------------------------------------------------------------------------------- /src/tf_transformers/utils/positional_bias_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import tensorflow as tf 4 | 5 | 6 | def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): 7 | """ 8 | Adapted from Mesh Tensorflow: 9 | https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/\ 10 | mesh_tensorflow/transformer/transformer_layers.py#L593 11 | 12 | Translate relative position to a bucket number for relative attention. 13 | The relative position is defined as memory_position - query_position, i.e. 14 | the distance in tokens from the attending position to the attended-to 15 | position. If bidirectional=False, then positive relative positions are 16 | invalid. 17 | We use smaller buckets for small absolute relative_position and larger buckets 18 | for larger absolute relative_positions. All relative positions >=max_distance 19 | map to the same bucket. All relative positions <=-max_distance map to the 20 | same bucket. This should allow for more graceful generalization to longer 21 | sequences than the model has been trained on. 
22 | Args: 23 | relative_position: an int32 Tensor 24 | bidirectional: a boolean - whether the attention is bidirectional 25 | num_buckets: an integer 26 | max_distance: an integer 27 | Returns: 28 | a Tensor with the same shape as relative_position, containing int32 29 | values in the range [0, num_buckets) 30 | """ 31 | ret = 0 32 | n = -relative_position 33 | if bidirectional: 34 | num_buckets //= 2 35 | ret += tf.dtypes.cast(tf.math.less(n, 0), tf.int32) * num_buckets 36 | n = tf.math.abs(n) 37 | else: 38 | n = tf.math.maximum(n, 0) 39 | # now n is in the range [0, inf) 40 | max_exact = num_buckets // 2 41 | is_small = tf.math.less(n, max_exact) 42 | val_if_large = max_exact + tf.dtypes.cast( 43 | tf.math.log(tf.dtypes.cast(n, tf.float32) / max_exact) 44 | / math.log(max_distance / max_exact) 45 | * (num_buckets - max_exact), 46 | tf.int32, 47 | ) 48 | val_if_large = tf.math.minimum(val_if_large, num_buckets - 1) 49 | ret += tf.where(is_small, n, val_if_large) 50 | return ret 51 | 52 | 53 | def compute_positional_bias(qlen, klen, bidirectional=True, num_buckets=32): 54 | """Compute binned relative position bias""" 55 | context_position = tf.range(qlen)[:, None] 56 | memory_position = tf.range(klen)[None, :] 57 | relative_position = memory_position - context_position # shape (qlen, klen) 58 | rp_bucket = _relative_position_bucket( 59 | relative_position, 60 | bidirectional=bidirectional, 61 | num_buckets=num_buckets, 62 | ) 63 | return rp_bucket 64 | -------------------------------------------------------------------------------- /src/tf_transformers/utils/push_to_hub.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import time 4 | from distutils.dir_util import copy_tree 5 | 6 | from absl import logging 7 | 8 | from tf_transformers.models import ViTModel 9 | 10 | logging.set_verbosity("INFO") 11 | 12 | # Load Model 13 | model_name = 'vit-large-patch32-384' 14 | model = ViTModel.from_pretrained(model_name) 15 | 16 | model = ViTModel.from_pretrained(model_name, classification_labels=1000) 17 | 18 | 19 | models_list = [ 20 | 'vit-base-patch16-224', 21 | 'vit-base-patch32-384', 22 | 'vit-base-patch32-224-in21k', 23 | 'vit-large-patch16-224', 24 | 'vit-large-patch32-384', 25 | ] 26 | cwd = os.getcwd() 27 | MODEL_DIR = '/home/sarathrnair/MODELS/' 28 | 29 | for model_name in models_list: 30 | 31 | subprocess.run( 32 | ['huggingface-cli', 'repo', 'create', '{}'.format(model_name), '--yes', '--organization', 'tftransformers'] 33 | ) 34 | 35 | subprocess.run(["git", "clone", "https://huggingface.co/tftransformers/{}".format(model_name)]) 36 | new_working_dir = os.path.join(cwd, model_name) 37 | os.chdir("{}".format(new_working_dir)) 38 | cached_model_dir = os.path.join(MODEL_DIR, "tf_transformers_cache/{}/".format(model_name)) 39 | copy_tree(cached_model_dir, new_working_dir) 40 | 41 | subprocess.run(["git-lfs", "track", "*"]) 42 | subprocess.run(["git", "add", "."]) 43 | subprocess.run(["git", "commit", "-m", "New Model"]) 44 | subprocess.run(["git", "push"]) 45 | 46 | os.chdir("{}".format(cwd)) 47 | time.sleep(2) 48 | logging.info("Completed {}".format(model_name)) 49 | print("------------------------------------------------------------------") 50 | -------------------------------------------------------------------------------- /src/tf_transformers/utils/utils.py: -------------------------------------------------------------------------------- 1 | def get_config(module_name, config_name): 2 | """Load config from .py 
file using importlib. 3 | 4 | Args: 5 | module_name (str): importable module path that defines the config 6 | config_name (str): name of the config attribute inside that module 7 | 8 | Raises: 9 | ModuleNotFoundError / AttributeError: if the module or the config attribute cannot be found 10 | 11 | """ 12 | import importlib 13 | 14 | my_module = importlib.import_module(module_name) 15 | config = getattr(my_module, config_name).config 16 | return config 17 | 18 | 19 | def get_model_wrapper(model_name): 20 | """Return the ``modelWrapper`` from the model family's wrapper module, e.g. ``roberta_base`` -> roberta wrapper.""" 21 | import importlib 22 | 23 | model_name = model_name.split("_")[0].strip() # roberta_base --> roberta 24 | model_cls = importlib.import_module("tf_transformers.models.model_wrappers.{}_wrapper".format(model_name)) 25 | return model_cls.modelWrapper 26 | 27 | 28 | def validate_model_name(model_name, allowed_model_names): 29 | """Validate model_name against allowed_model_names 30 | 31 | Args: 32 | model_name (str): model name to validate 33 | allowed_model_names (list): model names that are supported 34 | Raises: 35 | ValueError: if model_name is not in allowed_model_names 36 | """ 37 | if model_name not in allowed_model_names: 38 | raise ValueError("{} not in allowed names {}".format(model_name, allowed_model_names)) 39 | 40 | 41 | # This is unused 42 | # def pytorch_conversion_debug(): 43 | # # Check pytorch conversion with TF conversion 44 | # import numpy as np 45 | # for index , var in enumerate(model.variables): 46 | 47 | # var2 = model_tf.variables[index] 48 | 49 | # var_shape = list(var.shape) 50 | # var2_shape = list(var2.shape) 51 | 52 | # assert(var_shape == var2_shape) 53 | 54 | # if len(var_shape) == 1: 55 | # var_sum = tf.reduce_sum(var).numpy() 56 | # var2_sum = tf.reduce_sum(var2).numpy() 57 | # assert(np.allclose(var_sum, var2_sum) == True) 58 | 59 | # if len(var_shape) == 2: 60 | # var_sum = tf.reduce_sum(var, axis=-1).numpy() 61 | # var2_sum = tf.reduce_sum(var2, axis=-1).numpy() 62 | # assert(np.allclose(var_sum, var2_sum) == True) 63 | 64 | # if len(var_shape) == 3: 65 | # var_sum = tf.reduce_sum(var, axis=[0, 2]).numpy() 66 | # var2_sum = tf.reduce_sum(var2, axis=[0, 2]).numpy() 67 | # assert(np.allclose(var_sum, var2_sum) == True) 68 | -------------------------------------------------------------------------------- /src/tf_transformers/utils/viz_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import seaborn as sns 4 | 5 | 6 | def compute_scores(vectors): 7 | corr = np.inner(vectors, vectors) 8 | cmax = np.max(corr) 9 | corr /= cmax 10 | return corr 11 | 12 | 13 | def plot_similarity(labels, features1, features2, rotation, title1="Model1", title2="Model2"): 14 | 15 | corr1 = compute_scores(features1) 16 | corr2 = compute_scores(features2) 17 | sns.set(rc={"axes.facecolor": "white", "figure.facecolor": "white"}) 18 | sns.set_context("poster") 19 | 20 | fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 8)) 21 | fig.subplots_adjust(wspace=0.02) 22 | 23 | sns.set(font_scale=1.0) 24 | g1 = sns.heatmap( 25 | corr1, 26 | ax=ax1, 27 | cbar=False, 28 | yticklabels=labels, 29 | xticklabels=labels, 30 | vmin=np.min(corr1), 31 | vmax=np.max(corr1), 32 | cmap="Blues", 33 | ) 34 | 35 | g2 = sns.heatmap( 36 | corr2, 37 | ax=ax2, 38 | cbar=False, 39 | xticklabels=labels, 40 | vmin=np.min(corr2), 41 | vmax=np.max(corr2), 42 | cmap="Blues", 43 | ) 44 | g2.set(yticks=[]) 45 | fig.colorbar(ax2.collections[0], ax=ax1, location="right", use_gridspec=False, pad=0.01) 46 | fig.colorbar(ax2.collections[0], ax=ax2, location="right", use_gridspec=False, pad=0.01) 47 | 48 | g1.set_title(title1) 49 | g2.set_title(title2) 50 | -------------------------------------------------------------------------------- /src/transformers_blue.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/src/transformers_blue.png -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/legacyai/tf-transformers/f08fcde836a320ee4250ede04e5717837db77ada/tests/__init__.py -------------------------------------------------------------------------------- /tests/model_test_scripts/test_wav2vec2.py: -------------------------------------------------------------------------------- 1 | 2 | from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC 3 | from datasets import load_dataset 4 | # import soundfile as sf 5 | import torch 6 | 7 | # load model and processor 8 | processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-robust-ft-swbd-300h") 9 | model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-robust-ft-swbd-300h") 10 | 11 | 12 | from scipy.io import wavfile 13 | import numpy as np 14 | 15 | file_name = 'sample.wav' 16 | data = wavfile.read(file_name) 17 | framerate = data[0] 18 | sounddata = data[1] 19 | time = np.arange(0,len(sounddata))/framerate 20 | print('Sample rate:',framerate,'Hz') 21 | print('Total time:',len(sounddata)/framerate,'s') 22 | 23 | # Load audio with librosa (resampled to 16 kHz) 24 | import librosa 25 | input_audio, _ = librosa.load(file_name, 26 | sr=16000) 27 | 28 | input_values = processor(input_audio, return_tensors="pt").input_values # torch.Size([1, 3270299]) 29 | 30 | logits = model(input_values).logits # torch.Size([1, 10219, 32]) 31 | 32 | predicted_ids = torch.argmax(logits, dim=-1) # torch.Size([1, 10219]) 33 | 34 | 35 | transcription = processor.batch_decode(predicted_ids)[0] -------------------------------------------------------------------------------- /tests/test_tf_transformers.py: -------------------------------------------------------------------------------- 1 | from tf_transformers import __version__ 2 | 3 | version = "2.0.0" 4 | def test_version(): 5 | assert __version__ == version 6 | -------------------------------------------------------------------------------- /tutorials/README.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "4193b067", 6 | "metadata": {}, 7 | "source": [ 8 | "We use jupytext to keep copies of notebooks in sync with their Markdown equivalents.\n", 9 | "\n", 10 | "### Adding a new notebook\n", 11 | "\n", 12 | "```\n", 13 | "jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb\n", 14 | "```\n", 15 | "\n", 16 | "### Syncing Notebooks\n", 17 | "\n", 18 | "After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running:\n", 19 | "\n", 20 | "```\n", 21 | "jupytext --sync docs/notebooks/*\n", 22 | "```" 23 | ] 24 | } 25 | ], 26 | "metadata": { 27 | "jupytext": { 28 | "cell_metadata_filter": "-all", 29 | "formats": "ipynb,md:myst", 30 | "main_language": "python" 31 | } 32 | }, 33 | "nbformat": 4, 34 | "nbformat_minor": 5 35 | } 36 | -------------------------------------------------------------------------------- /tutorials/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | cell_metadata_filter: -all 4 | formats: ipynb,md:myst 5 | main_language: python 6 | text_representation: 7 |
extension: .md 8 | format_name: myst 9 | format_version: 0.13 10 | jupytext_version: 1.14.4 11 | --- 12 | 13 | We use jupytext to keep copies of notebooks in sync with their Markdown equivalents. 14 | 15 | ### Adding a new notebook 16 | 17 | ``` 18 | jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb 19 | ``` 20 | 21 | ### Syncing Notebooks 22 | 23 | After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running: 24 | 25 | ``` 26 | jupytext --sync docs/notebooks/* 27 | ``` 28 | -------------------------------------------------------------------------------- /tutorials/push_model_to_hf_hub.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | formats: ipynb,md:myst 4 | text_representation: 5 | extension: .md 6 | format_name: myst 7 | format_version: 0.13 8 | jupytext_version: 1.14.4 9 | kernelspec: 10 | display_name: Python 3 (ipykernel) 11 | language: python 12 | name: python3 13 | --- 14 | 15 | ```{code-cell} ipython3 16 | 17 | ``` 18 | 19 | ```{code-cell} ipython3 20 | import subprocess 21 | import os 22 | from distutils.dir_util import copy_tree 23 | ``` 24 | 25 | ### How to push a model to the hub 26 | 27 | * Make sure you have logged in via ```huggingface-cli login``` using your token. 28 | 29 | ```{code-cell} ipython3 30 | 31 | ``` 32 | 33 | ```{code-cell} ipython3 34 | model_name = 'byt5-small' 35 | ``` 36 | 37 | #### 1. Create a model repo under the organization name 38 | 39 | ```{code-cell} ipython3 40 | subprocess.run(['huggingface-cli', 'repo', 41 | 'create', '{}'.format(model_name), 42 | '--yes', 43 | '--organization', 'tftransformers']) 44 | ``` 45 | 46 | ```{code-cell} ipython3 47 | 48 | ``` 49 | 50 | #### 2. Now clone the newly created repo to our local cwd 51 | 52 | ```{code-cell} ipython3 53 | subprocess.run(["git", "clone", "https://huggingface.co/tftransformers/{}".format(model_name)]) 54 | ``` 55 | 56 | ```{code-cell} ipython3 57 | 58 | ``` 59 | 60 | #### 3. Now copy the cached model directory into the cloned ```model_name``` directory 61 | 62 | ```{code-cell} ipython3 63 | cwd = os.getcwd() # Get current working dir 64 | new_working_dir = os.path.join(cwd, model_name) # This is the repo cloned from the HF hub under the organization 65 | os.chdir("{}".format(new_working_dir)) # Switch to new working dir 66 | 67 | # The cached model directory differs from machine to machine 68 | cached_model_dir = '/var/folders/vq/4fxns8l55gq8_msgygbyb51h0000gn/T/tf_transformers_cache/{}/'.format(model_name) 69 | 70 | # Copy the cached model directory to the new working directory 71 | copy_tree(cached_model_dir, new_working_dir) 72 | ``` 73 | 74 | ```{code-cell} ipython3 75 | 76 | ``` 77 | 78 | #### 4. Now push the model to the hub 79 | 80 | ```{code-cell} ipython3 81 | subprocess.run(["git-lfs", "track", "*"]) 82 | subprocess.run(["git", "add", "."]) 83 | subprocess.run(["git", "commit", "-m", "Pushing new model {}".format(model_name)]) # Commit message 84 | subprocess.run(["git", "push"]) 85 | 86 | # Change back to original cwd 87 | os.chdir("{}".format(cwd)) 88 | ``` 89 | 90 | ```{code-cell} ipython3 91 | 92 | ``` 93 |
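94 | #### 5. (Optional) Verify the upload by loading the model back 95 | 96 | A minimal sanity check, assuming the pushed checkpoint matches one of the model classes shipped with tf_transformers; ```T5Model``` below is only a placeholder for the ```byt5-small``` checkpoint, so swap in whichever class fits your model. 97 | 98 | ```{code-cell} ipython3 99 | # Illustrative only: reload the checkpoint now hosted under the tftransformers organization. 100 | # Replace T5Model with the model class that matches your checkpoint. 101 | from tf_transformers.models import T5Model 102 | 103 | loaded_model = T5Model.from_pretrained(model_name) 104 | ``` 105 | --------------------------------------------------------------------------------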