├── .gitignore ├── .travis.yml ├── AUTHORS ├── CONTRIBUTING.md ├── ISSUE_TEMPLATE.md ├── LICENSE ├── README.md ├── docs ├── cloud_mlengine.md ├── cloud_tpu.md ├── distributed_training.md ├── index.md ├── multi_problem.md ├── new_model.md ├── new_problem.md ├── overview.md ├── tutorials │ └── asr_with_transformer.md └── walkthrough.md ├── floyd.yml ├── floyd_requirements.txt ├── oss_scripts ├── oss_integration_test.sh ├── oss_pip_install.sh ├── oss_release.sh └── oss_tests.sh ├── pylintrc ├── setup.py └── tensor2tensor ├── __init__.py ├── bin ├── __init__.py ├── build_vocab.py ├── make_tf_configs.py ├── t2t-avg-all ├── t2t-bleu ├── t2t-datagen ├── t2t-decoder ├── t2t-eval ├── t2t-exporter ├── t2t-insights-server ├── t2t-make-tf-configs ├── t2t-query-server ├── t2t-trainer ├── t2t-translate-all ├── t2t_attack.py ├── t2t_avg_all.py ├── t2t_bleu.py ├── t2t_datagen.py ├── t2t_decoder.py ├── t2t_distill.py ├── t2t_eval.py ├── t2t_prune.py ├── t2t_trainer.py ├── t2t_trainer_test.py └── t2t_translate_all.py ├── data_generators ├── README.md ├── __init__.py ├── algorithmic.py ├── algorithmic_math.py ├── algorithmic_math_deepmind.py ├── algorithmic_math_test.py ├── algorithmic_math_two_variables.py ├── algorithmic_test.py ├── all_problems.py ├── allen_brain.py ├── allen_brain_test.py ├── audio.py ├── audio_encoder.py ├── audio_test.py ├── babi_qa.py ├── bair_robot_pushing.py ├── celeba.py ├── celeba_test.py ├── celebahq.py ├── cifar.py ├── cipher.py ├── cleaner_en_xx.py ├── cnn_dailymail.py ├── cola.py ├── common_voice.py ├── common_voice_test.py ├── conll_ner.py ├── desc2code.py ├── desc2code_test.py ├── dialog_abstract.py ├── dialog_cornell.py ├── dialog_dailydialog.py ├── dialog_opensubtitles.py ├── dialog_personachat.py ├── dna_encoder.py ├── dna_encoder_test.py ├── enwik8.py ├── fsns.py ├── function_docstring.py ├── gene_expression.py ├── gene_expression_test.py ├── generator_utils.py ├── generator_utils_test.py ├── google_robot_pushing.py ├── gym_env.py ├── 
gym_env_test.py ├── ice_parsing.py ├── image_lsun.py ├── image_utils.py ├── image_utils_test.py ├── imagenet.py ├── imagenet_test.py ├── imdb.py ├── inspect_tfrecord.py ├── lambada.py ├── librispeech.py ├── lm1b.py ├── lm1b_imdb.py ├── lm1b_mnli.py ├── mnist.py ├── moving_mnist.py ├── mrpc.py ├── mscoco.py ├── mscoco_test.py ├── multi_problem.py ├── multi_problem_v2.py ├── multi_problem_v2_test.py ├── multinli.py ├── ocr.py ├── ops │ ├── pack_sequences_ops.cc │ ├── pack_sequences_ops_test.py │ ├── subword_text_encoder.cc │ ├── subword_text_encoder.h │ ├── subword_text_encoder_ops.cc │ ├── subword_text_encoder_ops_test.py │ ├── subword_text_encoder_test.cc │ └── testdata │ │ └── subwords ├── paraphrase_ms_coco.py ├── paraphrase_ms_coco_test.py ├── pointer_generator_word.py ├── problem.py ├── problem_hparams.py ├── problem_test.py ├── program_search.py ├── program_search_test.py ├── ptb.py ├── qnli.py ├── quora_qpairs.py ├── rte.py ├── scitail.py ├── seq2edits.py ├── snli.py ├── speech_recognition.py ├── squad.py ├── sst_binary.py ├── stanford_nli.py ├── style_transfer.py ├── style_transfer_test.py ├── subject_verb_agreement.py ├── test_data │ ├── 1.csv │ ├── corpus-1.txt │ ├── corpus-2.txt │ ├── vocab-1.txt │ └── vocab-2.txt ├── text_encoder.py ├── text_encoder_build_subword.py ├── text_encoder_test.py ├── text_problems.py ├── text_problems_test.py ├── timeseries.py ├── timeseries_data_generator.py ├── timeseries_data_generator_test.py ├── timeseries_test.py ├── tokenizer.py ├── tokenizer_test.py ├── transduction_problems.py ├── transduction_problems_test.py ├── translate.py ├── translate_encs.py ├── translate_encs_cubbitt.py ├── translate_ende.py ├── translate_ende_test.py ├── translate_enes.py ├── translate_enet.py ├── translate_enfr.py ├── translate_enid.py ├── translate_enmk.py ├── translate_enro.py ├── translate_entn.py ├── translate_envi.py ├── translate_enzh.py ├── translate_test.py ├── video_generated.py ├── video_utils.py ├── video_utils_test.py ├── vqa.py 
├── vqa_utils.py ├── wiki.py ├── wiki_lm.py ├── wiki_multi_problems.py ├── wiki_revision.py ├── wiki_revision_utils.py ├── wikifact │ └── README.md ├── wikisum │ ├── README.md │ ├── __init__.py │ ├── delete_instances.sh │ ├── generate_vocab.py │ ├── get_references_commoncrawl.py │ ├── get_references_web.py │ ├── get_references_web_single_group.py │ ├── html.py │ ├── parallel_launch.py │ ├── produce_examples.py │ ├── test_data │ │ ├── para_bad1.txt │ │ └── para_good1.txt │ ├── utils.py │ ├── utils_test.py │ ├── validate_data.py │ └── wikisum.py ├── wikitext103.py ├── wnli.py ├── wsj_parsing.py ├── yelp_full.py └── yelp_polarity.py ├── envs ├── __init__.py ├── env_problem.py ├── env_problem_utils.py ├── env_problem_utils_test.py ├── gym_env_problem.py ├── gym_env_problem_test.py ├── gym_spaces_utils.py ├── gym_spaces_utils_test.py ├── mujoco_problems.py ├── mujoco_problems_test.py ├── rendered_env_problem.py ├── rendered_env_problem_test.py ├── tic_tac_toe_env.py ├── tic_tac_toe_env_problem.py ├── tic_tac_toe_env_problem_test.py ├── tic_tac_toe_env_test.py ├── time_step.py ├── time_step_test.py ├── trajectory.py └── trajectory_test.py ├── insights ├── README.md ├── __init__.py ├── graph.py ├── insight_configuration.proto ├── polymer │ ├── .bowerrc │ ├── attention_visualization │ │ ├── attention-visualization.html │ │ └── attention-visualization.js │ ├── bower.json │ ├── common-types.js │ ├── explore_view │ │ ├── explore-view.html │ │ └── explore-view.js │ ├── graph_visualization │ │ ├── graph-visualization.html │ │ └── graph-visualization.js │ ├── index.html │ ├── insights_app │ │ ├── insights-app.html │ │ └── insights-app.js │ ├── language_selector │ │ ├── language-selector-content.html │ │ ├── language-selector-content.js │ │ ├── language-selector.html │ │ └── language-selector.js │ ├── processing_visualization │ │ ├── processing-visualization.html │ │ └── processing-visualization.js │ ├── query_card │ │ ├── query-card.html │ │ └── query-card.js │ ├── 
tensor2tensor.html │ └── translation_result │ │ ├── translation-result.html │ │ └── translation-result.js ├── query_processor.py ├── server.py └── transformer_model.py ├── layers ├── __init__.py ├── area_attention.py ├── area_attention_test.py ├── common_attention.py ├── common_attention_test.py ├── common_audio.py ├── common_hparams.py ├── common_image_attention.py ├── common_image_attention_test.py ├── common_layers.py ├── common_layers_test.py ├── common_video.py ├── common_video_test.py ├── discretization.py ├── discretization_test.py ├── latent_layers.py ├── latent_layers_test.py ├── message_passing_attention.py ├── modalities.py ├── modalities_test.py ├── ngram.py ├── ngram_test.py ├── transformer_glow_layers.py ├── transformer_glow_layers_ops.py ├── transformer_glow_layers_ops_test.py ├── transformer_glow_layers_test.py ├── transformer_layers.py ├── transformer_memory.py ├── transformer_memory_test.py ├── vq_discrete.py └── vqa_layers.py ├── metrics ├── __init__.py ├── video_conditional_fvd.py └── video_conditional_fvd_test.py ├── models ├── README.md ├── __init__.py ├── basic.py ├── basic_test.py ├── bytenet.py ├── bytenet_test.py ├── distillation.py ├── evolved_transformer.py ├── evolved_transformer_test.py ├── image_transformer.py ├── image_transformer_2d.py ├── image_transformer_2d_test.py ├── image_transformer_test.py ├── lstm.py ├── lstm_test.py ├── mtf_image_transformer.py ├── mtf_image_transformer_test.py ├── mtf_resnet.py ├── mtf_transformer.py ├── mtf_transformer2.py ├── mtf_transformer_test.py ├── neural_architecture_search │ ├── README.md │ ├── __init__.py │ ├── nas_layers.py │ ├── nas_layers_test.py │ ├── nas_model.py │ └── nas_model_test.py ├── neural_assistant.py ├── neural_gpu.py ├── neural_gpu_test.py ├── research │ ├── __init__.py │ ├── adafactor_experiments.py │ ├── aligned.py │ ├── attention_lm.py │ ├── attention_lm_moe.py │ ├── autoencoders.py │ ├── autoencoders_test.py │ ├── cycle_gan.py │ ├── gene_expression.py │ ├── 
gene_expression_test.py │ ├── glow.py │ ├── glow_init_hook.py │ ├── glow_ops.py │ ├── glow_ops_test.py │ ├── glow_test.py │ ├── lm_experiments.py │ ├── moe.py │ ├── moe_experiments.py │ ├── multiquery_paper.py │ ├── neural_stack.py │ ├── neural_stack_test.py │ ├── residual_shuffle_exchange.py │ ├── rl.py │ ├── shuffle_network.py │ ├── similarity_transformer.py │ ├── super_lm.py │ ├── transformer_aux.py │ ├── transformer_aux_test.py │ ├── transformer_moe.py │ ├── transformer_nat.py │ ├── transformer_parallel.py │ ├── transformer_revnet.py │ ├── transformer_revnet_test.py │ ├── transformer_seq2edits.py │ ├── transformer_sketch.py │ ├── transformer_symshard.py │ ├── transformer_vae.py │ ├── transformer_vae_flow_prior.py │ ├── transformer_vae_flow_prior_ops.py │ ├── transformer_vae_test.py │ ├── universal_transformer.py │ ├── universal_transformer_test.py │ ├── universal_transformer_util.py │ ├── vqa_attention.py │ ├── vqa_attention_test.py │ ├── vqa_recurrent_self_attention.py │ └── vqa_self_attention.py ├── resnet.py ├── resnet_test.py ├── revnet.py ├── revnet_test.py ├── shake_shake.py ├── slicenet.py ├── slicenet_test.py ├── text_cnn.py ├── transformer.py ├── transformer_test.py ├── vanilla_gan.py ├── video │ ├── __init__.py │ ├── base.py │ ├── base_vae.py │ ├── basic_deterministic.py │ ├── basic_deterministic_params.py │ ├── basic_deterministic_test.py │ ├── basic_recurrent.py │ ├── basic_recurrent_test.py │ ├── basic_stochastic.py │ ├── basic_stochastic_test.py │ ├── emily.py │ ├── emily_test.py │ ├── epva.py │ ├── epva_params.py │ ├── next_frame_glow.py │ ├── nfg_conv3d_test.py │ ├── nfg_conv_lstm_test.py │ ├── nfg_conv_test.py │ ├── nfg_interpolate.py │ ├── nfg_test_utils.py │ ├── nfg_uncond_test.py │ ├── savp.py │ ├── savp_params.py │ ├── savp_test.py │ ├── sv2p.py │ ├── sv2p_params.py │ ├── sv2p_test.py │ └── tests_utils.py ├── xception.py └── xception_test.py ├── notebooks ├── Transformer_translate.ipynb ├── asr_transformer.ipynb ├── hello_t2t-rl.ipynb ├── 
hello_t2t.ipynb └── t2t_problem.ipynb ├── problems.py ├── problems_colab.py ├── problems_test.py ├── rl ├── README.md ├── __init__.py ├── batch_dqn_agent_test.py ├── batch_runner_test.py ├── datagen_with_agent.py ├── dopamine_connector.py ├── envs │ ├── __init__.py │ ├── in_graph_batch_env.py │ ├── py_func_batch_env.py │ ├── simulated_batch_env.py │ ├── simulated_batch_gym_env.py │ └── tf_atari_wrappers.py ├── evaluator.py ├── evaluator_test.py ├── gym_utils.py ├── gym_utils_test.py ├── player.py ├── player_utils.py ├── policy_learner.py ├── ppo.py ├── ppo_learner.py ├── restarter.py ├── restarter_test.py ├── rl_utils.py ├── trainer_model_based.py ├── trainer_model_based_agent_only.py ├── trainer_model_based_params.py ├── trainer_model_based_recurrent_test.py ├── trainer_model_based_stochastic_test.py ├── trainer_model_based_sv2p_test.py ├── trainer_model_based_test.py ├── trainer_model_free.py ├── trainer_model_free_test.py └── trainer_model_free_tictactoe_test.py ├── serving ├── README.md ├── __init__.py ├── export.py ├── query.py └── serving_utils.py ├── test_data ├── example_usr_dir │ ├── __init__.py │ ├── my_submodule.py │ └── requirements.txt ├── transformer_test_ckpt │ ├── checkpoint │ ├── flags.txt │ ├── hparams.json │ ├── model.ckpt-1.data-00000-of-00002 │ ├── model.ckpt-1.data-00001-of-00002 │ ├── model.ckpt-1.index │ └── model.ckpt-1.meta ├── vocab.translate_ende_wmt32k.32768.subwords └── vocab.translate_ende_wmt8k.8192.subwords ├── utils ├── __init__.py ├── adafactor.py ├── adafactor_test.py ├── adv_attack_utils.py ├── avg_checkpoints.py ├── beam_search.py ├── beam_search_test.py ├── bleu_hook.py ├── bleu_hook_test.py ├── checkpoint_compatibility_test.py ├── cloud_mlengine.py ├── compute_video_metrics.py ├── contrib.py ├── data_reader.py ├── data_reader_test.py ├── decoding.py ├── devices.py ├── diet.py ├── diet_test.py ├── expert_utils.py ├── expert_utils_test.py ├── flags.py ├── get_cnndm_rouge.sh ├── get_ende_bleu.sh ├── get_rouge.py ├── hparam.py 
├── hparam_test.py ├── hparams_lib.py ├── hparams_lib_test.py ├── learning_rate.py ├── metrics.py ├── metrics_hook.py ├── metrics_hook_test.py ├── metrics_test.py ├── misc_utils.py ├── misc_utils_test.py ├── mlperf_log.py ├── mlperf_tags.py ├── mtf_model.py ├── multistep_optimizer.py ├── multistep_optimizer_test.py ├── multistep_with_adamoptimizer.py ├── multistep_with_adamoptimizer_test.py ├── optimize.py ├── optimize_test.py ├── partial_checkpoint_load_hook.py ├── pruning_utils.py ├── quantization.py ├── registry.py ├── registry_test.py ├── restore_hook.py ├── rouge.py ├── rouge_test.py ├── sari_hook.py ├── sari_hook_test.py ├── scheduled_sampling.py ├── t2t_model.py ├── t2t_model_test.py ├── test_utils.py ├── test_utils_test.py ├── trainer_lib.py ├── trainer_lib_test.py ├── update_ops_hook.py ├── usr_dir.py ├── video │ ├── prediction2gif.py │ └── reward_confusion.py ├── video2gif.py ├── video_metrics.py ├── video_metrics_test.py ├── yellowfin.py └── yellowfin_test.py └── visualization ├── TransformerVisualization.ipynb ├── __init__.py ├── attention.js ├── attention.py ├── visualization.py └── visualization_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Byte-compiled 5 | __pycache__/ 6 | .cache/ 7 | 8 | # Python egg metadata, regenerated from source files by setuptools. 9 | /*.egg-info 10 | .eggs/ 11 | 12 | # PyPI distribution artifacts. 
13 | build/ 14 | dist/ 15 | 16 | # Sublime project files 17 | *.sublime-project 18 | *.sublime-workspace 19 | 20 | # Tests 21 | .pytest_cache/ 22 | 23 | # Other 24 | *.DS_Store 25 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: required 2 | language: python 3 | cache: pip 4 | git: 5 | depth: 3 6 | quiet: true 7 | services: 8 | - docker 9 | python: 10 | - "3.6" 11 | env: 12 | global: 13 | - T2T_PROBLEM=algorithmic_reverse_binary40_test 14 | - T2T_DATA_DIR=/tmp/t2t-data 15 | - T2T_TRAIN_DIR=/tmp/t2t-train 16 | - TF_LATEST="1.15.*" 17 | # This is necessary to have gsutil work with Python 2.7 18 | - BOTO_CONFIG=/dev/null 19 | matrix: 20 | - TF_VERSION="1.15.*" 21 | install: 22 | - ./oss_scripts/oss_pip_install.sh 23 | script: 24 | - ./oss_scripts/oss_tests.sh 25 | - ./oss_scripts/oss_integration_test.sh 26 | 27 | # Conditional commands should each be in a separate block to get proper 28 | # errors on Travis. 29 | # 30 | # TODO(afrozm): Re-enable if this becomes an issue. 31 | # - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then 32 | # pylint -j 2 tensor2tensor; 33 | # fi 34 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the list of T2T authors for copyright purposes. 2 | # 3 | # This does not necessarily list everyone who has contributed code, since in 4 | # some cases, their employer may be the copyright holder. To see the full list 5 | # of contributors, see the revision history in source control. 6 | 7 | Google Inc. 
8 | Artit Wangperawong -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | # Issues 4 | 5 | * Please tag your issue with `bug`, `feature request`, or `question` to help us 6 | effectively respond. 7 | * Please include the versions of TensorFlow and Tensor2Tensor you are running 8 | (run `pip list | grep tensor`) 9 | * Please provide the command line you ran as well as the log output. 10 | 11 | # Pull Requests 12 | 13 | We'd love to accept your patches and contributions to this project. There are 14 | just a few small guidelines you need to follow. 15 | 16 | ## Contributor License Agreement 17 | 18 | Contributions to this project must be accompanied by a Contributor License 19 | Agreement. You (or your employer) retain the copyright to your contribution, 20 | this simply gives us permission to use and redistribute your contributions as 21 | part of the project. Head over to https://cla.developers.google.com/ to see 22 | your current agreements on file or to sign a new one. 23 | 24 | You generally only need to submit a CLA once, so if you've already submitted one 25 | (even if it was for a different project), you probably don't need to do it 26 | again. 27 | 28 | ## Code reviews 29 | 30 | All submissions, including submissions by project members, require review. We 31 | use GitHub pull requests for this purpose. Consult 32 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 33 | information on using pull requests. 34 | -------------------------------------------------------------------------------- /ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ### Description 2 | 3 | ... 
4 | 5 | ### Environment information 6 | 7 | ``` 8 | OS: 9 | 10 | $ pip freeze | grep tensor 11 | # your output here 12 | 13 | $ python -V 14 | # your output here 15 | ``` 16 | 17 | ### For bugs: reproduction and error logs 18 | 19 | ``` 20 | # Steps to reproduce: 21 | ... 22 | ``` 23 | 24 | ``` 25 | # Error logs: 26 | ... 27 | ``` 28 | -------------------------------------------------------------------------------- /docs/cloud_tpu.md: -------------------------------------------------------------------------------- 1 | # Running on Cloud TPUs 2 | 3 | Tensor2Tensor supports running on Google Cloud Platform's TPUs, chips 4 | specialized for ML training. See the official tutorials for [running the 5 | T2T Transformer for text on Cloud TPUs](https://cloud.google.com/tpu/docs/tutorials/transformer) and 6 | [Transformer for Speech Recognition](https://cloud.google.com/tpu/docs/tutorials/automated-speech-recognition). 7 | 8 | ## Other models on TPU 9 | 10 | Many of Tensor2Tensor's models work on TPU. 11 | 12 | You can provision a VM and TPU with `ctpu up`. Use the `t2t-trainer` command 13 | on the VM as usual with the additional flags `--use_tpu` and 14 | `--cloud_tpu_name=$TPU_NAME`. 15 | 16 | Note that because the `TPUEstimator` does not catch the `OutOfRangeError` 17 | during evaluation, you should ensure that `--eval_steps` is small enough to 18 | not exhaust the evaluation data. 
19 | 20 | A non-exhaustive list of T2T models that work on TPU: 21 | 22 | * Image generation: `imagetransformer` with `imagetransformer_base_tpu` (or 23 | `imagetransformer_tiny_tpu`) 24 | * Super-resolution: `img2img_transformer` with `img2img_transformer_base_tpu` 25 | (or `img2img_transformer_tiny_tpu`) 26 | * `resnet` with `resnet_50` (or `resnet_18` or `resnet_34`) 27 | * `revnet` with `revnet_104` (or `revnet_38_cifar`) 28 | * `shake_shake` with `shakeshake_tpu` (or `shakeshake_small`) 29 | 30 | ## Example invocation 31 | 32 | Use `ctpu up` to bring up the VM and TPU machines; once the machines are ready 33 | it will SSH you into the VM and you can run the following: 34 | 35 | ``` 36 | # DATA_DIR and OUT_DIR should be GCS buckets 37 | # TPU_NAME should have been set automatically by the ctpu tool 38 | 39 | t2t-trainer \ 40 | --model=shake_shake \ 41 | --hparams_set=shakeshake_tpu \ 42 | --problem=image_cifar10 \ 43 | --train_steps=180000 \ 44 | --eval_steps=9 \ 45 | --local_eval_frequency=100 \ 46 | --data_dir=$DATA_DIR \ 47 | --output_dir=$OUT_DIR \ 48 | --use_tpu \ 49 | --cloud_tpu_name=$TPU_NAME 50 | ``` 51 | -------------------------------------------------------------------------------- /docs/tutorials/asr_with_transformer.md: -------------------------------------------------------------------------------- 1 | # Automated Speech Recognition with the Transformer model 2 | 3 | See the 4 | [official tutorial](https://cloud.google.com/tpu/docs/tutorials/automated-speech-recognition). 
5 | -------------------------------------------------------------------------------- /floyd.yml: -------------------------------------------------------------------------------- 1 | env: tensorflow-1.12 2 | machine: gpu 3 | -------------------------------------------------------------------------------- /floyd_requirements.txt: -------------------------------------------------------------------------------- 1 | tensor2tensor 2 | -------------------------------------------------------------------------------- /oss_scripts/oss_integration_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Note that this test script requires docker to be installed and running. 4 | 5 | set -v # print commands as they're executed 6 | set -e # fail and exit on any command erroring 7 | 8 | : "${TF_VERSION:?}" 9 | : "${TF_LATEST:?}" 10 | : "${T2T_DATA_DIR:?}" 11 | : "${T2T_TRAIN_DIR:?}" 12 | : "${T2T_PROBLEM:?}" 13 | 14 | # Test --t2t_usr_dir 15 | t2t-trainer --registry_help --t2t_usr_dir=./tensor2tensor/test_data/example_usr_dir 2>&1 | grep my_very_own_hparams && echo passed 16 | 17 | # Run data generation, training, and decoding on a dummy problem 18 | t2t-datagen --problem=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR 19 | t2t-trainer --problem=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --train_steps=5 --eval_steps=5 --output_dir=$T2T_TRAIN_DIR 20 | t2t-decoder --problem=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR --decode_hparams='num_samples=10' 21 | 22 | # Test serving 23 | if [[ "$TF_VERSION" == "$TF_LATEST" ]] 24 | then 25 | # Export for serving 26 | pip install tensorflow_hub 27 | t2t-exporter \ 28 | --problem=$T2T_PROBLEM \ 29 | --data_dir=$T2T_DATA_DIR \ 30 | --model=transformer \ 31 | --hparams_set=transformer_tiny \ 32 | --output_dir=$T2T_TRAIN_DIR 33 | 34 | # Run model server 35 | server_port=8500 36 | 
model_name=my_model 37 | docker run -d -p $server_port:$server_port \ 38 | --mount type=bind,source=$T2T_TRAIN_DIR/export,target=/models/$model_name \ 39 | -e MODEL_NAME=$model_name -t tensorflow/serving 40 | sleep 10 41 | 42 | # Query 43 | pip install tensorflow-serving-api=="$TF_VERSION" 44 | t2t-query-server \ 45 | --server=localhost:$server_port \ 46 | --servable_name=$model_name \ 47 | --problem=$T2T_PROBLEM \ 48 | --data_dir=$T2T_DATA_DIR \ 49 | --inputs_once='1 0 1 0 1 0' 50 | fi 51 | -------------------------------------------------------------------------------- /oss_scripts/oss_pip_install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -v # print commands as they're executed 4 | set -e # fail and exit on any command erroring 5 | 6 | : "${TF_VERSION:?}" 7 | 8 | # Make sure we have the latest pip and setuptools installed. 9 | pip install -q -U pip 10 | pip install -q -U setuptools 11 | 12 | # Make sure we have the latest version of numpy - avoid problems we were 13 | # seeing with Python 3 14 | pip install -q -U numpy 15 | pip install -q "tensorflow==$TF_VERSION" 16 | 17 | # Just print the version again to make sure. 18 | python -c 'import tensorflow as tf; print(tf.__version__)' 19 | 20 | # First ensure that the base dependencies are sufficient for a full import 21 | pip install -q -e . 
22 | t2t-trainer --registry_help 2>&1 >/dev/null 23 | t2t-datagen 2>&1 | grep translate_ende 2>&1 >/dev/null && echo passed 24 | 25 | # Then install the test dependencies 26 | pip install -q -e .[tests,allen] 27 | # Make sure to install the atari extras for gym 28 | pip install "gym[atari]" 29 | -------------------------------------------------------------------------------- /oss_scripts/oss_release.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -v # print commands as they're executed 4 | set -e # fail and exit on any command erroring 5 | 6 | GIT_COMMIT_ID=${1:-""} 7 | [[ -z $GIT_COMMIT_ID ]] && echo "Must provide a commit" && exit 1 8 | 9 | TMP_DIR=$(mktemp -d) 10 | pushd $TMP_DIR 11 | 12 | echo "Cloning tensor2tensor and checking out commit $GIT_COMMIT_ID" 13 | git clone https://github.com/tensorflow/tensor2tensor.git 14 | cd tensor2tensor 15 | git checkout $GIT_COMMIT_ID 16 | 17 | # Without `python -m` we sometimes get module not callable error: 18 | # https://stackoverflow.com/questions/58451650/pip-no-longer-working-after-update-error-module-object-is-not-callable 19 | python -m pip install wheel twine pyopenssl 20 | 21 | # Build the distribution 22 | echo "Building distribution" 23 | python setup.py sdist 24 | python setup.py bdist_wheel --universal 25 | 26 | # Publish to PyPI 27 | echo "Publishing to PyPI" 28 | twine upload dist/* 29 | 30 | # Cleanup 31 | rm -rf build/ dist/ tensor2tensor.egg-info/ 32 | popd 33 | rm -rf $TMP_DIR 34 | -------------------------------------------------------------------------------- /tensor2tensor/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tensor2tensor/bin/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tensor2tensor/bin/build_vocab.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""Build vocab for a subclass of Text2TextProblem. 17 | 18 | build_vocab \ 19 | --problem=program_search_algolisp \ 20 | --data_dir=~/t2t_data \ 21 | --tmp_dir=~/t2t_data/tmp 22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | import os 29 | 30 | from tensor2tensor import problems as problems_lib # pylint: disable=unused-import 31 | from tensor2tensor.data_generators import text_problems 32 | from tensor2tensor.utils import registry 33 | import tensorflow.compat.v1 as tf 34 | 35 | flags = tf.flags 36 | FLAGS = flags.FLAGS 37 | 38 | flags.DEFINE_string("data_dir", "/tmp/t2t/data_dir", 39 | "Directory to place the generated vocabulary file in.") 40 | 41 | flags.DEFINE_string("tmp_dir", "/tmp/t2t/tmp_dir", 42 | "Temporary storage directory.") 43 | 44 | flags.DEFINE_string("problem", None, 45 | "Problem to generate the vocabulary file for.") 46 | 47 | flags.mark_flag_as_required("problem") 48 | 49 | 50 | def main(_): 51 | problem = registry.problem(FLAGS.problem) 52 | 53 | # We make the assumption that the problem is a subclass of Text2TextProblem. 
54 | assert isinstance(problem, text_problems.Text2TextProblem) 55 | 56 | data_dir = os.path.expanduser(FLAGS.data_dir) 57 | tmp_dir = os.path.expanduser(FLAGS.tmp_dir) 58 | 59 | tf.gfile.MakeDirs(data_dir) 60 | tf.gfile.MakeDirs(tmp_dir) 61 | 62 | tf.logging.info("Saving vocabulary to data_dir: %s" % data_dir) 63 | 64 | problem.get_or_create_vocab(data_dir, tmp_dir) 65 | 66 | tf.logging.info("Saved vocabulary file: " + 67 | os.path.join(data_dir, problem.vocab_filename)) 68 | 69 | 70 | if __name__ == "__main__": 71 | tf.logging.set_verbosity(tf.logging.INFO) 72 | tf.app.run() 73 | -------------------------------------------------------------------------------- /tensor2tensor/bin/t2t-avg-all: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """t2t-avg-all.""" 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from tensor2tensor.bin import t2t_avg_all 8 | 9 | import tensorflow as tf 10 | 11 | def main(argv): 12 | t2t_avg_all.main(argv) 13 | 14 | 15 | if __name__ == "__main__": 16 | tf.logging.set_verbosity(tf.logging.INFO) 17 | tf.app.run() 18 | -------------------------------------------------------------------------------- /tensor2tensor/bin/t2t-bleu: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """t2t-bleu.""" 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from tensor2tensor.bin import t2t_bleu 8 | 9 | import tensorflow as tf 10 | 11 | def main(argv): 12 | t2t_bleu.main(argv) 13 | 14 | 15 | 16 | if __name__ == "__main__": 17 | tf.logging.set_verbosity(tf.logging.INFO) 18 | tf.app.run() 19 | -------------------------------------------------------------------------------- /tensor2tensor/bin/t2t-datagen: -------------------------------------------------------------------------------- 
1 | #!/usr/bin/env python 2 | """Data generation for Tensor2Tensor. 3 | 4 | This script is used to generate data to train your models 5 | for a number of problems for which open-source data is available. 6 | 7 | For example, to generate data for MNIST run this: 8 | 9 | t2t-datagen \ 10 | --problem=image_mnist \ 11 | --data_dir=~/t2t_data \ 12 | --tmp_dir=~/t2t_data/tmp 13 | """ 14 | from __future__ import absolute_import 15 | from __future__ import division 16 | from __future__ import print_function 17 | 18 | from tensor2tensor.bin import t2t_datagen 19 | 20 | import tensorflow.compat.v1 as tf 21 | 22 | def main(argv): 23 | t2t_datagen.main(argv) 24 | 25 | 26 | if __name__ == "__main__": 27 | tf.logging.set_verbosity(tf.logging.INFO) 28 | tf.app.run() 29 | -------------------------------------------------------------------------------- /tensor2tensor/bin/t2t-decoder: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """t2t-decoder.""" 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from tensor2tensor.bin import t2t_decoder 8 | 9 | import tensorflow as tf 10 | 11 | def main(argv): 12 | t2t_decoder.main(argv) 13 | 14 | 15 | if __name__ == "__main__": 16 | tf.logging.set_verbosity(tf.logging.INFO) 17 | tf.app.run() 18 | -------------------------------------------------------------------------------- /tensor2tensor/bin/t2t-eval: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Run t2t-eval from a trained checkpoint. 3 | 4 | This script is used to run evaluation from a trained checkpoint. Example 5 | to run evaluation on the test set when trained checkpoint is in /output_dir. 
6 | 7 | t2t-eval \ 8 | --problem=image_mnist \ 9 | --model=imagetransformer \ 10 | --data_dir=~/t2t \ 11 | --output_dir=/output_dir \ 12 | --eval_use_test_set=True 13 | """ 14 | from __future__ import absolute_import 15 | from __future__ import division 16 | from __future__ import print_function 17 | 18 | from tensor2tensor.bin import t2t_eval 19 | 20 | import tensorflow as tf 21 | 22 | def main(argv): 23 | t2t_eval.main(argv) 24 | 25 | 26 | if __name__ == "__main__": 27 | tf.logging.set_verbosity(tf.logging.INFO) 28 | tf.app.run() 29 | -------------------------------------------------------------------------------- /tensor2tensor/bin/t2t-exporter: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """t2t-exporter.""" 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from tensor2tensor.serving import export 8 | 9 | import tensorflow as tf 10 | 11 | def main(argv): 12 | export.main(argv) 13 | 14 | 15 | if __name__ == "__main__": 16 | tf.logging.set_verbosity(tf.logging.INFO) 17 | tf.app.run() 18 | -------------------------------------------------------------------------------- /tensor2tensor/bin/t2t-insights-server: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """t2t-insights-server.""" 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from tensor2tensor.insights import server 8 | 9 | import tensorflow as tf 10 | 11 | def main(argv): 12 | server.main(argv) 13 | 14 | 15 | if __name__ == "__main__": 16 | tf.logging.set_verbosity(tf.logging.INFO) 17 | tf.app.run() 18 | -------------------------------------------------------------------------------- /tensor2tensor/bin/t2t-make-tf-configs: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 
python 2 | """t2t-make-tf-configs.""" 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from tensor2tensor.bin import make_tf_configs 8 | 9 | import tensorflow as tf 10 | 11 | def main(argv): 12 | make_tf_configs.main(argv) 13 | 14 | 15 | if __name__ == "__main__": 16 | tf.logging.set_verbosity(tf.logging.INFO) 17 | tf.app.run() 18 | -------------------------------------------------------------------------------- /tensor2tensor/bin/t2t-query-server: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """t2t-query-server.""" 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from tensor2tensor.serving import query 8 | 9 | import tensorflow as tf 10 | 11 | def main(argv): 12 | query.main(argv) 13 | 14 | 15 | if __name__ == "__main__": 16 | tf.logging.set_verbosity(tf.logging.INFO) 17 | tf.app.run() 18 | -------------------------------------------------------------------------------- /tensor2tensor/bin/t2t-trainer: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Trainer for Tensor2Tensor. 3 | 4 | This script is used to train your models in Tensor2Tensor. 
5 | 6 | For example, to train a shake-shake model on MNIST run this: 7 | 8 | t2t-trainer \ 9 | --generate_data \ 10 | --problem=image_mnist \ 11 | --data_dir=~/t2t_data \ 12 | --tmp_dir=~/t2t_data/tmp \ 13 | --model=shake_shake \ 14 | --hparams_set=shake_shake_quick \ 15 | --output_dir=~/t2t_train/mnist1 \ 16 | --train_steps=1000 \ 17 | --eval_steps=100 18 | """ 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | from tensor2tensor.bin import t2t_trainer 24 | 25 | import tensorflow.compat.v1 as tf 26 | 27 | def main(argv): 28 | t2t_trainer.main(argv) 29 | 30 | 31 | if __name__ == "__main__": 32 | tf.logging.set_verbosity(tf.logging.INFO) 33 | tf.app.run(main) 34 | -------------------------------------------------------------------------------- /tensor2tensor/bin/t2t-translate-all: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """t2t-translate-all.""" 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from tensor2tensor.bin import t2t_translate_all 8 | 9 | import tensorflow as tf 10 | 11 | def main(argv): 12 | t2t_translate_all.main(argv) 13 | 14 | 15 | 16 | if __name__ == "__main__": 17 | tf.logging.set_verbosity(tf.logging.INFO) 18 | tf.app.run() 19 | -------------------------------------------------------------------------------- /tensor2tensor/bin/t2t_eval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""Perform evaluation on trained T2T models using the Estimator API.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from tensor2tensor.bin import t2t_trainer # pylint: disable=unused-import 23 | from tensor2tensor.data_generators import problem # pylint: disable=unused-import 24 | from tensor2tensor.utils import trainer_lib 25 | from tensor2tensor.utils import usr_dir 26 | import tensorflow.compat.v1 as tf 27 | from tensorflow.compat.v1 import estimator as tf_estimator 28 | 29 | flags = tf.flags 30 | FLAGS = flags.FLAGS 31 | 32 | 33 | def main(_): 34 | tf.logging.set_verbosity(tf.logging.INFO) 35 | trainer_lib.set_random_seed(FLAGS.random_seed) 36 | usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) 37 | 38 | hparams = trainer_lib.create_hparams( 39 | FLAGS.hparams_set, FLAGS.hparams, data_dir=FLAGS.data_dir, 40 | problem_name=FLAGS.problem) 41 | 42 | # set appropriate dataset-split, if flags.eval_use_test_set. 43 | dataset_split = "test" if FLAGS.eval_use_test_set else None 44 | dataset_kwargs = {"dataset_split": dataset_split} 45 | eval_input_fn = hparams.problem.make_estimator_input_fn( 46 | tf_estimator.ModeKeys.EVAL, hparams, dataset_kwargs=dataset_kwargs) 47 | config = t2t_trainer.create_run_config(hparams) 48 | 49 | # summary-hook in tf.estimator.EstimatorSpec requires 50 | # hparams.model_dir to be set. 
51 | hparams.add_hparam("model_dir", config.model_dir) 52 | 53 | estimator = trainer_lib.create_estimator( 54 | FLAGS.model, hparams, config, use_tpu=FLAGS.use_tpu) 55 | ckpt_iter = trainer_lib.next_checkpoint( 56 | hparams.model_dir, FLAGS.eval_timeout_mins) 57 | for ckpt_path in ckpt_iter: 58 | predictions = estimator.evaluate( 59 | eval_input_fn, steps=FLAGS.eval_steps, checkpoint_path=ckpt_path) 60 | tf.logging.info(predictions) 61 | 62 | 63 | if __name__ == "__main__": 64 | tf.logging.set_verbosity(tf.logging.INFO) 65 | tf.app.run() 66 | -------------------------------------------------------------------------------- /tensor2tensor/bin/t2t_trainer_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for t2t_trainer.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | from tensor2tensor.bin import t2t_trainer 22 | from tensor2tensor.utils import trainer_lib_test 23 | 24 | import tensorflow.compat.v1 as tf 25 | 26 | FLAGS = tf.flags.FLAGS 27 | 28 | 29 | class TrainerTest(tf.test.TestCase): 30 | 31 | @classmethod 32 | def setUpClass(cls): 33 | trainer_lib_test.TrainerLibTest.setUpClass() 34 | 35 | def testTrain(self): 36 | FLAGS.problem = "tiny_algo" 37 | FLAGS.model = "transformer" 38 | FLAGS.hparams_set = "transformer_tiny" 39 | FLAGS.train_steps = 1 40 | FLAGS.eval_steps = 1 41 | FLAGS.output_dir = tf.test.get_temp_dir() 42 | FLAGS.data_dir = tf.test.get_temp_dir() 43 | t2t_trainer.main(None) 44 | 45 | 46 | if __name__ == "__main__": 47 | tf.test.main() 48 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/audio_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for tensor2tensor.data_generators.audio.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import io 23 | import os 24 | from tensor2tensor.data_generators import audio 25 | 26 | import tensorflow.compat.v1 as tf 27 | 28 | 29 | class AudioTest(tf.test.TestCase): 30 | 31 | def testDataCollection(self): 32 | # Generate a trivial source and target file. 
33 | tmp_dir = self.get_temp_dir() 34 | test_files = [ 35 | "dir1/file1", 36 | "dir1/file2", 37 | "dir1/dir2/file3", 38 | "dir1/dir2/dir3/file4", 39 | ] 40 | for filename in test_files: 41 | input_filename = os.path.join(tmp_dir, filename + ".WAV") 42 | target_filename = os.path.join(tmp_dir, filename + ".WRD") 43 | directories = os.path.dirname(input_filename) 44 | if not os.path.exists(directories): 45 | os.makedirs(directories) 46 | io.open(input_filename, "wb") 47 | io.open(target_filename, "wb") 48 | 49 | data_dict = audio._collect_data(tmp_dir, ".WAV", ".WRD") 50 | expected = [os.path.join(tmp_dir, filename) for filename in test_files] 51 | self.assertEqual(sorted(list(data_dict)), sorted(expected)) 52 | 53 | # Clean up. 54 | for filename in test_files: 55 | os.remove(os.path.join(tmp_dir, "%s.WAV" % filename)) 56 | os.remove(os.path.join(tmp_dir, "%s.WRD" % filename)) 57 | 58 | 59 | if __name__ == "__main__": 60 | tf.test.main() 61 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/celeba_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for CelebA.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from absl.testing import parameterized 23 | from tensor2tensor.data_generators import celeba 24 | from tensor2tensor.utils import hparam 25 | 26 | import tensorflow.compat.v1 as tf 27 | from tensorflow.compat.v1 import estimator as tf_estimator 28 | 29 | 30 | class CelebaTest(parameterized.TestCase, tf.test.TestCase): 31 | 32 | @parameterized.named_parameters( 33 | ("Default", None), 34 | ("Area", "AREA"), 35 | ("Dilated", "DILATED")) 36 | def testCelebaMultiResolutionPreprocessExample(self, resize_method): 37 | example = {"inputs": tf.random_uniform([218, 178, 3], minval=-1.)} 38 | mode = tf_estimator.ModeKeys.TRAIN 39 | hparams = hparam.HParams(resolutions=[8, 16, 32]) 40 | if resize_method is not None: 41 | hparams.resize_method = resize_method 42 | 43 | problem = celeba.ImageCelebaMultiResolution() 44 | preprocessed_example = problem.preprocess_example(example, mode, hparams) 45 | self.assertLen(preprocessed_example, 2) 46 | self.assertEqual(preprocessed_example["inputs"].shape, (138, 138, 3)) 47 | self.assertEqual(preprocessed_example["targets"].shape, (42, 32, 3)) 48 | 49 | 50 | if __name__ == "__main__": 51 | tf.test.main() 52 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/common_voice_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for tensor2tensor.data_generators.common_voice.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | from tensor2tensor.data_generators import common_voice 24 | 25 | import tensorflow.compat.v1 as tf 26 | 27 | pkg_dir, _ = os.path.split(__file__) 28 | _TESTDATA = os.path.join(pkg_dir, "test_data") 29 | 30 | 31 | class CommonVoiceTest(tf.test.TestCase): 32 | 33 | def testCollectData(self): 34 | output = common_voice._collect_data(_TESTDATA) 35 | self.assertEqual(1, len(output)) 36 | 37 | # NOTE: No header. 38 | self.assertTrue("my_media" == output[0][0]) 39 | self.assertTrue("my_label" == output[0][2]) 40 | 41 | 42 | if __name__ == "__main__": 43 | tf.test.main() 44 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/desc2code_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for desc2code.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from tensor2tensor.data_generators import desc2code 23 | 24 | import tensorflow.compat.v1 as tf 25 | 26 | CODE_CPP_IN = """ 27 | #include 28 | 29 | void main() { // This comment will be removed 30 | // This too. 31 | // 32 | /* Not this one */ 33 | \t 34 | \t 35 | int a \t\n = 3;// 36 | // 37 | } 38 | 39 | """ 40 | 41 | CODE_CPP_OUT = ("#include void main() { /* Not this one */ int a = " 42 | "3; }") 43 | 44 | 45 | class Desc2codeTest(tf.test.TestCase): 46 | 47 | def testCppPreprocess(self): 48 | """Check that the source code is preprocessed correctly.""" 49 | cpp_pb = desc2code.ProgrammingDesc2codeCpp() 50 | 51 | self.assertEqual( # Add space between two lines 52 | cpp_pb.preprocess_target("firstline//comm1\nsecondline//comm2\n"), 53 | "firstline secondline") 54 | # Checking for both comments and spaces 55 | self.assertEqual(cpp_pb.preprocess_target(CODE_CPP_IN), CODE_CPP_OUT) 56 | self.assertEqual( 57 | cpp_pb.preprocess_target(" not removed //abcd "), 58 | "not removed //abcd") 59 | 60 | 61 | if __name__ == "__main__": 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/dna_encoder_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for tensor2tensor.data_generators.dna_encoder.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | from tensor2tensor.data_generators import dna_encoder 22 | import tensorflow.compat.v1 as tf 23 | 24 | 25 | class DnaEncoderTest(tf.test.TestCase): 26 | 27 | def test_encode_decode(self): 28 | original = 'TTCGCGGNNNAACCCAACGCCATCTATGTANNTTGAGTTGTTGAGTTAAA' 29 | 30 | # Encoding should be reversible for any reasonable chunk size. 31 | for chunk_size in [1, 2, 4, 6, 8]: 32 | encoder = dna_encoder.DNAEncoder(chunk_size=chunk_size) 33 | encoded = encoder.encode(original) 34 | decoded = encoder.decode(encoded) 35 | self.assertEqual(original, decoded) 36 | 37 | def test_delimited_dna_encoder(self): 38 | original = 'TTCGCGGNNN,AACCCAACGC,CATCTATGTA,NNTTGAGTTG,TTGAGTTAAA' 39 | 40 | # Encoding should be reversible for any reasonable chunk size. 
41 | for chunk_size in [1, 2, 4, 6, 8]: 42 | encoder = dna_encoder.DelimitedDNAEncoder(chunk_size=chunk_size) 43 | encoded = encoder.encode(original) 44 | decoded = encoder.decode(encoded) 45 | self.assertEqual(original, decoded) 46 | 47 | 48 | if __name__ == '__main__': 49 | tf.test.main() 50 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/gene_expression_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for Genetics problems.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import numpy as np 21 | 22 | from tensor2tensor.data_generators import dna_encoder 23 | from tensor2tensor.data_generators import gene_expression 24 | 25 | import tensorflow.compat.v1 as tf 26 | 27 | 28 | class GeneticsTest(tf.test.TestCase): 29 | 30 | def _one_hot_bases(self, bases): 31 | ref = ["A", "C", "T", "G"] 32 | one_hots = [] 33 | for base in bases: 34 | one_hot = [False] * 4 35 | if base in ref: 36 | one_hot[ref.index(base)] = True 37 | one_hots.append(one_hot) 38 | return np.array(one_hots) 39 | 40 | def testRecordToExample(self): 41 | encoder = dna_encoder.DNAEncoder(chunk_size=2) 42 | raw_inputs = ["A", "C", "G", "N", "C", "T"] 43 | 44 | # Put in numpy arrays in the same format as in the h5 file 45 | inputs = self._one_hot_bases(raw_inputs) 46 | mask = np.array([True, False, True]) 47 | outputs = np.array([[1.0, 2.0, 3.0], [5.0, 1.0, 0.2], [5.1, 2.3, 2.3]]) 48 | # Convert to example dict 49 | ex_dict = gene_expression.to_example_dict(encoder, inputs, mask, outputs) 50 | 51 | self.assertEqual(len(raw_inputs) // 2 + 1, len(ex_dict["inputs"])) 52 | self.assertAllEqual(encoder.encode(raw_inputs) + [1], ex_dict["inputs"]) 53 | self.assertAllEqual([1.0, 0.0, 1.0], ex_dict["targets_mask"]) 54 | self.assertAllEqual([1.0, 2.0, 3.0, 5.0, 1.0, 0.2, 5.1, 2.3, 2.3], 55 | ex_dict["targets"]) 56 | self.assertAllEqual([3, 3], ex_dict["targets_shape"]) 57 | 58 | def testGenerateShardArgs(self): 59 | num_examples = 37 60 | num_shards = 4 61 | outfiles = [str(i) for i in range(num_shards)] 62 | shard_args = gene_expression.generate_shard_args(outfiles, num_examples) 63 | 64 | starts, ends, fnames = zip(*shard_args) 65 | self.assertAllEqual([0, 9, 18, 27], starts) 66 | self.assertAllEqual([9, 18, 27, 37], ends) 67 | self.assertAllEqual(fnames, outfiles) 68 | 69 | 70 | if __name__ == "__main__": 71 | 
tf.test.main() 72 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/imagenet_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for ImageNet.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from absl.testing import parameterized 23 | from tensor2tensor.data_generators import imagenet 24 | from tensor2tensor.utils import hparam 25 | 26 | import tensorflow.compat.v1 as tf 27 | from tensorflow.compat.v1 import estimator as tf_estimator 28 | 29 | 30 | class ImagenetTest(parameterized.TestCase, tf.test.TestCase): 31 | 32 | @parameterized.named_parameters( 33 | ("Default", None), 34 | ("Area", "AREA"), 35 | ("Dilated", "DILATED")) 36 | def testImagenetMultiResolutionPreprocessExample(self, resize_method): 37 | example = {"inputs": tf.random_uniform([64, 64, 3], minval=-1.)} 38 | mode = tf_estimator.ModeKeys.TRAIN 39 | hparams = hparam.HParams(resolutions=[8, 16, 32]) 40 | if resize_method is not None: 41 | hparams.resize_method = resize_method 42 | 43 | problem = imagenet.ImageImagenetMultiResolutionGen() 44 | preprocessed_example = problem.preprocess_example(example, mode, hparams) 45 | 
self.assertLen(preprocessed_example, 1) 46 | self.assertEqual(preprocessed_example["inputs"].shape, (42, 32, 3)) 47 | 48 | def testImagenetIsNormalized(self): 49 | problem = imagenet.ImageImagenet224() 50 | self.assertTrue(problem.normalize_image) 51 | problem = imagenet.ImageImagenet224NoNormalization() 52 | self.assertFalse(problem.normalize_image) 53 | 54 | 55 | if __name__ == "__main__": 56 | tf.test.main() 57 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/lm1b_imdb.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Data generators for LM1B and IMDb combined data-set.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from tensor2tensor.data_generators import imdb 23 | from tensor2tensor.data_generators import lm1b 24 | from tensor2tensor.data_generators import multi_problem 25 | from tensor2tensor.data_generators import text_problems 26 | from tensor2tensor.utils import registry 27 | 28 | 29 | @registry.register_problem 30 | class LanguagemodelLm1bSentimentIMDB(multi_problem.MultiProblem): 31 | """LM1b and IMDb mixed problem class for multitask learning.""" 32 | 33 | def __init__(self, was_reversed=False, was_copy=False): 34 | super(LanguagemodelLm1bSentimentIMDB, self).__init__(was_reversed, was_copy) 35 | self.task_list.append(lm1b.LanguagemodelLm1bCharacters()) 36 | self.task_list.append(imdb.SentimentIMDBCharacters()) 37 | 38 | @property 39 | def vocab_type(self): 40 | return text_problems.VocabType.CHARACTER 41 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/lm1b_mnli.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Data generators for LM1B and MNLI combined datasets.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from tensor2tensor.data_generators import lm1b 23 | from tensor2tensor.data_generators import multi_problem 24 | from tensor2tensor.data_generators import multinli 25 | from tensor2tensor.data_generators import text_problems 26 | from tensor2tensor.utils import registry 27 | 28 | 29 | @registry.register_problem 30 | class LanguagemodelLm1bMultiNLISubwords(multi_problem.MultiProblem): 31 | """LM1b and MNLI mixed problem class for multitask learning.""" 32 | 33 | def __init__(self, was_reversed=False, was_copy=False): 34 | super(LanguagemodelLm1bMultiNLISubwords, self).__init__( 35 | was_reversed, was_copy) 36 | self.task_list.append(lm1b.LanguagemodelLm1b32k()) 37 | self.task_list.append(multinli.MultiNLISharedVocab()) 38 | 39 | @property 40 | def vocab_type(self): 41 | return text_problems.VocabType.SUBWORD 42 | 43 | 44 | @registry.register_problem 45 | class LanguagemodelLm1bMultiNLI(multi_problem.MultiProblem): 46 | """LM1b and MNLI mixed problem class for multitask learning.""" 47 | 48 | def __init__(self, was_reversed=False, was_copy=False): 49 | super(LanguagemodelLm1bMultiNLI, self).__init__(was_reversed, was_copy) 50 | self.task_list.append(lm1b.LanguagemodelLm1bCharacters()) 51 | self.task_list.append(multinli.MultiNLICharacters()) 52 | 53 | @property 54 | def vocab_type(self): 55 | return text_problems.VocabType.CHARACTER 56 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/mscoco_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for MS COCO.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from absl.testing import parameterized 23 | from tensor2tensor.data_generators import mscoco 24 | from tensor2tensor.utils import hparam 25 | 26 | import tensorflow.compat.v1 as tf 27 | from tensorflow.compat.v1 import estimator as tf_estimator 28 | 29 | 30 | class MscocoTest(parameterized.TestCase, tf.test.TestCase): 31 | 32 | @parameterized.named_parameters( 33 | ("Default", None), 34 | ("Area", "AREA"), 35 | ("Dilated", "DILATED")) 36 | def testMsCocoMultiResolutionPreprocessExample(self, resize_method): 37 | example = {"inputs": tf.random_uniform([400, 400, 3], minval=-1.)} 38 | mode = tf_estimator.ModeKeys.TRAIN 39 | hparams = hparam.HParams(resolutions=[8, 16, 32]) 40 | if resize_method is not None: 41 | hparams.resize_method = resize_method 42 | 43 | problem = mscoco.ImageTextMsCocoMultiResolution() 44 | preprocessed_example = problem.preprocess_example(example, mode, hparams) 45 | self.assertLen(preprocessed_example, 1) 46 | self.assertEqual(preprocessed_example["inputs"].shape, (42, 32, 3)) 47 | 48 | 49 | if __name__ == "__main__": 50 | tf.test.main() 51 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/ops/subword_text_encoder.h: -------------------------------------------------------------------------------- 1 | #ifndef TENSOR2TESNOR_DATA_GENERATORS_OPS_SUBWORD_TEXT_ENCODER_H_ 2 | 
#define TENSOR2TESNOR_DATA_GENERATORS_OPS_SUBWORD_TEXT_ENCODER_H_ 3 | 4 | #include "third_party/absl/container/flat_hash_map.h" 5 | #include "third_party/absl/container/flat_hash_set.h" 6 | #include "third_party/absl/strings/string_view.h" 7 | #include "third_party/icu/include/unicode/uchar.h" 8 | #include "third_party/tensorflow/core/framework/tensor.h" 9 | 10 | namespace tensor2tensor { 11 | 12 | // A subword text encoder with built in tokenizer. 13 | // 14 | // Equivalent to tensor2tensor's subword text 15 | // https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/text_encoder.py, 16 | // This code (or a suitable replacement) should eventually move into tfds 17 | // and should be deleted from tensor2tensor. 18 | 19 | class SubwordTextEncoder { 20 | public: 21 | explicit SubwordTextEncoder(const std::string& vocab_filename); 22 | virtual ~SubwordTextEncoder() {} 23 | 24 | // Breaks up input text into subtokens. 25 | void Encode(absl::string_view text, std::vector* ids); 26 | 27 | private: 28 | // Given a full token as input, breaks the token up into subtokens and appends 29 | // corresponding IDs to the ids vector. 30 | void EncodeSubtokens(absl::string_view token, std::vector* ids); 31 | 32 | // Escapes a token so unencodable characters are replaced by escape sequences. 33 | std::string EscapeToken(absl::string_view token); 34 | 35 | // Maps subword tokens to IDs. 36 | absl::flat_hash_map vocab_; 37 | // A set containing all valid unicode code points that can be encoded without 38 | // being escaped. 
39 | absl::flat_hash_set alphabet_; 40 | }; 41 | 42 | } // namespace tensor2tensor 43 | 44 | #endif // TENSOR2TESNOR_DATA_GENERATORS_OPS_SUBWORD_TEXT_ENCODER_H_ 45 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/ops/subword_text_encoder_ops.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "third_party/py/tensor2tensor/data_generators/ops/subword_text_encoder.h" 4 | #include "third_party/tensorflow/core/framework/op_kernel.h" 5 | #include "third_party/tensorflow/core/framework/shape_inference.h" 6 | #include "third_party/tensorflow/core/framework/tensor.h" 7 | #include "third_party/tensorflow/core/framework/types.h" 8 | 9 | namespace tensor2tensor { 10 | namespace { 11 | 12 | using ::tensorflow::DEVICE_CPU; 13 | using ::tensorflow::OpKernel; 14 | using ::tensorflow::OpKernelConstruction; 15 | using ::tensorflow::OpKernelContext; 16 | using ::tensorflow::Status; 17 | using ::tensorflow::Tensor; 18 | using ::tensorflow::TensorShape; 19 | using ::tensorflow::tstring; 20 | using ::tensorflow::shape_inference::InferenceContext; 21 | 22 | REGISTER_OP("SubwordTextEncoderEncode") 23 | .Input("s: string") 24 | .Output("encoded: int64") 25 | .Attr("vocab_filename: string") 26 | .SetShapeFn([](InferenceContext* ctx) { 27 | ctx->set_output(0, ctx->Vector(ctx->UnknownDim())); 28 | return tensorflow::Status(); 29 | }); 30 | 31 | class SubwordTextEncoderEncodeOp : public OpKernel { 32 | public: 33 | explicit SubwordTextEncoderEncodeOp( 34 | OpKernelConstruction* ctx) : OpKernel(ctx) { 35 | std::string vocab_filename; 36 | OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_filename", &vocab_filename)); 37 | encoder_ = std::make_unique(vocab_filename); 38 | } 39 | 40 | void Compute(OpKernelContext* ctx) override { 41 | // Get input string and deserialize into ArticleExample proto. 
42 | absl::string_view s = ctx->input(0).scalar()(); 43 | 44 | // Construct encoded output tensors. 45 | std::vector encoded_ids; 46 | encoder_->Encode(s, &encoded_ids); 47 | Tensor* encoded; 48 | OP_REQUIRES_OK( 49 | ctx, ctx->allocate_output( 50 | 0, TensorShape({static_cast(encoded_ids.size())}), 51 | &encoded)); 52 | auto encoded_vec = encoded->vec(); 53 | // TODO(noam): find someone who remembers c++ eigen and ask the proper way 54 | // to copy a std::Vector to an Eigen whatever-this-is 55 | for (int i = 0; i < encoded_ids.size(); i++) { 56 | encoded_vec(i) = encoded_ids[i]; 57 | } 58 | } 59 | 60 | private: 61 | std::unique_ptr encoder_; 62 | }; 63 | 64 | REGISTER_KERNEL_BUILDER(Name("SubwordTextEncoderEncode").Device(DEVICE_CPU), 65 | SubwordTextEncoderEncodeOp); 66 | 67 | } // namespace 68 | } // namespace tensor2tensor 69 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/ops/subword_text_encoder_ops_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for subword_text_encoder_ops.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from tensor2tensor.data_generators.ops import subword_text_encoder_ops 23 | import tensorflow.compat.v1 as tf 24 | 25 | vocab_file = ( 26 | "third_party/py/tensor2tensor/data_generators/ops/testdata/subwords") 27 | 28 | 29 | class SubwordTextEncoderOpsTest(tf.test.TestCase): 30 | 31 | def test_subword_text_encoder_encode(self): 32 | s = "the quick brown fox jumps over the lazy dog" 33 | encoded = subword_text_encoder_ops.subword_text_encoder_encode( 34 | s, vocab_file) 35 | self.assertAllEqual(encoded, [2, 3, 4, 5, 6, 7, 8, 9, 2, 11, 12, 1]) 36 | 37 | 38 | if __name__ == "__main__": 39 | tf.enable_eager_execution() 40 | tf.test.main() 41 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/ops/subword_text_encoder_test.cc: -------------------------------------------------------------------------------- 1 | #include "third_party/py/tensor2tensor/data_generators/ops/subword_text_encoder.h" 2 | 3 | #include "testing/base/public/gunit.h" 4 | #include "third_party/tensorflow/core/framework/tensor.h" 5 | #include "third_party/tensorflow/core/framework/tensor_testutil.h" 6 | 7 | namespace tensor2tensor { 8 | namespace { 9 | 10 | TEST(SubwordTextEncoderTest, EncodesSubTokens) { 11 | SubwordTextEncoder encoder("third_party/py/tensor2tensor/" 12 | "data_generators/ops/testdata/subwords"); 13 | std::vector t; 14 | encoder.Encode("the quick brown fox jumps over the lazy dog", &t); 15 | EXPECT_EQ(t, std::vector({2, 3, 4, 5, 6, 7, 8, 9, 2, 11, 12, 1})); 16 | } 17 | 18 | TEST(SubwordTextEncoderTest, EncodesUnicodeSubTokens) { 19 | SubwordTextEncoder encoder("third_party/py/tensor2tensor/" 20 | "data_generators/ops/testdata/subwords"); 21 | std::vector t; 22 | encoder.Encode("ɧęĻĽÒ", &t); 23 | EXPECT_EQ(t, std::vector({13, 14, 1})); 
24 | } 25 | 26 | TEST(SubwordTextEncoderTest, EncodesUnicodeCodePoints) { 27 | SubwordTextEncoder encoder("third_party/py/tensor2tensor/" 28 | "data_generators/ops/testdata/subwords"); 29 | std::vector t; 30 | encoder.Encode("⻦ ⻭", &t); 31 | EXPECT_EQ(t, std::vector({15, 18, 16, 17, 1})); 32 | } 33 | 34 | TEST(SubwordTextEncoderTest, EncodesCharactersNotInAlphabet) { 35 | SubwordTextEncoder encoder("third_party/py/tensor2tensor/" 36 | "data_generators/ops/testdata/subwords"); 37 | std::vector t; 38 | encoder.Encode("!", &t); 39 | // Subtokens: '\', '3', '3', ';', '_', '', ''. 40 | EXPECT_EQ(t, std::vector({19, 23, 23, 30, 17, 1})); 41 | } 42 | 43 | } // namespace 44 | } // namespace tensor2tensor 45 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/ops/testdata/subwords: -------------------------------------------------------------------------------- 1 | '' 2 | '' 3 | 'the_' 4 | 'quick_' 5 | 'brow' 6 | 'n_' 7 | 'fox_' 8 | 'jump' 9 | 's_' 10 | 'over_' 11 | 'the_' 12 | 'lazy_' 13 | 'dog_' 14 | 'ɧę' 15 | 'ĻĽÒ_' 16 | '⻦' 17 | '⻭' 18 | '_' 19 | ' ' 20 | '\' 21 | '0' 22 | '1' 23 | '2' 24 | '3' 25 | '4' 26 | '5' 27 | '6' 28 | '7' 29 | '8' 30 | '9' 31 | ';' -------------------------------------------------------------------------------- /tensor2tensor/data_generators/test_data/1.csv: -------------------------------------------------------------------------------- 1 | media_name,label 2 | my_media,my_label 3 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/test_data/corpus-1.txt: -------------------------------------------------------------------------------- 1 | One morning I shot an elephant in my pajamas. How he got in my pajamas, I don't 2 | know. 
3 | 4 | Groucho Marx 5 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/test_data/corpus-2.txt: -------------------------------------------------------------------------------- 1 | I haven't slept for 10 days... because that would be too long. 2 | 3 | Mitch Hedberg 4 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/test_data/vocab-1.txt: -------------------------------------------------------------------------------- 1 | lollipop,8 2 | reverberated,12 3 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/test_data/vocab-2.txt: -------------------------------------------------------------------------------- 1 | kattywampus,11 2 | kaput 3 | balderdash,10 4 | jiggery-pokery,14 5 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/timeseries_data_generator.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
def generate_data(timeseries_length, timeseries_params):
  """Builds one synthetic timeseries per entry of timeseries_params.

  Every generated point is the sum of a linear trend, a periodic component and
  zero-mean Gaussian noise; sums below zero are replaced with zero.

  Args:
    timeseries_length: Number of data points to generate for each timeseries.
    timeseries_params: Parameters used to generate the timeseries. The following
      parameters need to be specified for each timeseries:
      m = Slope of the linear trend.
      b = y-intercept of the linear trend.
      A = Amplitude multiplying the output of fn.
      freqcoeff = Value the time index is divided by before it is passed to fn.
      rndA = Scale (standard deviation) of the Gaussian noise.
      fn = Base timeseries function (np.cos or np.sin).
      Example params for two timeseries.
      [{"m": 0.006, "b": 300.0, "A":50.0, "freqcoeff":1500.0, "rndA":15.0,
      "fn": np.sin},
      {"m": 0.000, "b": 500.0, "A":35.0, "freqcoeff":3500.0, "rndA":25.0,
      "fn": np.cos}]

  Returns:
    Multi-timeseries (list of list).
  """
  steps = range(timeseries_length)

  def _single_series(params):
    """Returns the clipped trend + period + noise series for one param dict."""
    trend = [params["m"] * t + params["b"] for t in steps]
    period = [params["A"] * params["fn"](t / params["freqcoeff"])
              for t in steps]
    # Exactly one RNG draw per series, issued after trend and period are built.
    noise = np.random.normal(0, params["rndA"], timeseries_length).tolist()
    return [max(tr + pe + no, 0) for tr, pe, no in zip(trend, period, noise)]

  return [_single_series(params) for params in timeseries_params]
class TranslateEndeTest(tf.test.TestCase):
  """Checks the settings exposed by the WMT 8k and 32k En-De problems."""

  def test_vocab_size(self):
    expected_sizes = (
        (translate_ende.TranslateEndeWmt8k, 8192),
        (translate_ende.TranslateEndeWmt32k, 32768),
    )
    for problem_cls, expected in expected_sizes:
      self.assertEqual(problem_cls().approx_vocab_size, expected)

  def test_additional_datasets(self):
    problem_classes = (
        translate_ende.TranslateEndeWmt8k,
        translate_ende.TranslateEndeWmt32k,
    )
    for problem_cls in problem_classes:
      self.assertListEqual(problem_cls().additional_training_datasets, [])

  def test_source_data_files(self):
    small = translate_ende.TranslateEndeWmt8k()
    large = translate_ende.TranslateEndeWmt32k()

    # Both problems must report the same, non-empty file list for each split.
    for split in (problem.DatasetSplit.EVAL, problem.DatasetSplit.TRAIN):
      small_files = small.source_data_files(split)
      large_files = large.source_data_files(split)
      self.assertListEqual(small_files, large_files)
      self.assertGreater(len(small_files), 0)


if __name__ == '__main__':
  tf.test.main()
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Data generators for translation data-sets.""" 17 | 18 | 19 | from tensor2tensor.data_generators import problem 20 | from tensor2tensor.data_generators import text_encoder 21 | from tensor2tensor.data_generators import translate 22 | from tensor2tensor.utils import registry 23 | 24 | 25 | EOS = text_encoder.EOS_ID 26 | 27 | _URL = "https://github.com/LauraMartinus/ukuxhumana/blob/master/data/en_tn" 28 | 29 | _ENTN_TRAIN_DATASETS = [[ 30 | _URL + "/eng_tswane.train.tar.gz?raw=true", 31 | ("entn_parallel.train.en", "entn_parallel.train.tn") 32 | ]] 33 | 34 | _ENTN_TEST_DATASETS = [[ 35 | _URL + "/eng_tswane.dev.tar.gz?raw=true", 36 | ("entn_parallel.dev.en", "entn_parallel.dev.tn") 37 | ]] 38 | 39 | 40 | @registry.register_problem 41 | class TranslateEntnRma(translate.TranslateProblem): 42 | """Problem spec for English-Setswana translation. 43 | 44 | Uses the RMA Autshumato dataset. 
45 | """ 46 | 47 | @property 48 | def approx_vocab_size(self): 49 | return 2**15 # 32768 50 | 51 | @property 52 | def vocab_filename(self): 53 | return "vocab.entn.%d" % self.approx_vocab_size 54 | 55 | def source_data_files(self, dataset_split): 56 | train = dataset_split == problem.DatasetSplit.TRAIN 57 | return _ENTN_TRAIN_DATASETS if train else _ENTN_TEST_DATASETS 58 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/translate_envi.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Data generators for En-Vi translation.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | from tensor2tensor.data_generators import problem 22 | from tensor2tensor.data_generators import text_encoder 23 | from tensor2tensor.data_generators import translate 24 | from tensor2tensor.utils import registry 25 | 26 | # End-of-sentence marker. 27 | EOS = text_encoder.EOS_ID 28 | 29 | # For English-Vietnamese the IWSLT'15 corpus 30 | # from https://nlp.stanford.edu/projects/nmt/ is used. 31 | # The original dataset has 133K parallel sentences. 
32 | _ENVI_TRAIN_DATASETS = [[ 33 | "https://github.com/stefan-it/nmt-en-vi/raw/master/data/train-en-vi.tgz", # pylint: disable=line-too-long 34 | ("train.en", "train.vi") 35 | ]] 36 | 37 | # For development 1,553 parallel sentences are used. 38 | _ENVI_TEST_DATASETS = [[ 39 | "https://github.com/stefan-it/nmt-en-vi/raw/master/data/dev-2012-en-vi.tgz", # pylint: disable=line-too-long 40 | ("tst2012.en", "tst2012.vi") 41 | ]] 42 | 43 | 44 | # See this PR on github for some results with Transformer on this Problem. 45 | # https://github.com/tensorflow/tensor2tensor/pull/611 46 | 47 | 48 | @registry.register_problem 49 | class TranslateEnviIwslt32k(translate.TranslateProblem): 50 | """Problem spec for IWSLT'15 En-Vi translation.""" 51 | 52 | @property 53 | def approx_vocab_size(self): 54 | return 2**15 # 32768 55 | 56 | def source_data_files(self, dataset_split): 57 | train = dataset_split == problem.DatasetSplit.TRAIN 58 | return _ENVI_TRAIN_DATASETS if train else _ENVI_TEST_DATASETS 59 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/wikifact/README.md: -------------------------------------------------------------------------------- 1 | # Assessing the Factual Accuracy of Generated Text 2 | 3 | This directory will contain the code and scripts to generate data and train 4 | models from the paper *Assessing the Factual Accuracy of Generated Text*. 5 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/wikisum/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
#!/bin/bash

# Delete Google Compute Engine instances with naming structure $NAME-$INDEX
# (e.g. machines created with parallel_launch.py).
#
# Usage:
#   delete_instances.sh NAME MAX [MIN]
# Launches a delete for every index from MIN (default 0) through MAX inclusive.
# Example usage:
#   delete_instances.sh fetch-ref-urls 1000

if [[ $# -lt 2 ]]
then
  echo "Usage: $0 NAME MAX [MIN]" >&2
  exit 1
fi

NAME=$1
MAX=$2
MIN=${3:-0}

LOG_F="/tmp/delete-$NAME-logs.txt"

echo "Deleting instances $NAME-$MIN through $NAME-$MAX"

# Empty the log once up front. Every background job then appends to it;
# redirecting each job with '>' would truncate the shared file per instance
# and lose the output of the other jobs.
: > "$LOG_F"

for i in $(seq "$MIN" "$MAX")
do
  gcloud compute instances delete --quiet "$NAME-$i" >> "$LOG_F" 2>&1 &
  # Note: index 0 also satisfies this check, so a run starting at 0 pauses
  # right after its first launch.
  if (( i % 100 == 0 ))
  then
    # Give it some room to breathe every 100
    sleep 30
  fi
done

echo "Delete commands launched. Logs redirected to $LOG_F"
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("out_dir", None, "Directory to write vocab to.")
flags.DEFINE_string("wikis_dir",
                    "gs://tensor2tensor-data/wikisum/wiki_content/",
                    "Directory with wiki_content.tfrecords shards.")
flags.DEFINE_string("refs_dir", None,
                    "Directory with process_X folders with reference shards.")
flags.DEFINE_bool("for_commoncrawl", False,
                  "Whether to use WikisumCommoncrawl or WikisumWeb.")


def main(_):
  """Calls generate_vocab on the problem selected by --for_commoncrawl."""
  # Only the selected class attribute is looked up, as in an if/else.
  problem_cls = (
      wikisum.WikisumCommoncrawl if FLAGS.for_commoncrawl
      else wikisum.WikisumWeb)
  problem_cls().generate_vocab(FLAGS.out_dir, FLAGS.wikis_dir, FLAGS.refs_dir)


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Extract references from CommonCrawl files.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import tempfile 23 | 24 | from tensor2tensor.data_generators.wikisum import utils 25 | from tensor2tensor.data_generators.wikisum import wikisum 26 | 27 | import tensorflow.compat.v1 as tf 28 | 29 | flags = tf.flags 30 | FLAGS = flags.FLAGS 31 | 32 | flags.DEFINE_integer("num_tasks", 1000, "Number of parallel tasks.") 33 | flags.DEFINE_integer("task_id", 0, "Task id in a parallel run.") 34 | flags.DEFINE_string("metadata_dir", 35 | "gs://tensor2tensor-data/wikisum/commoncrawl_metadata/", 36 | "Path to metadata files specifying what references are in " 37 | "which CommonCrawl files.") 38 | flags.DEFINE_string("out_dir", None, "Directory to write references to.") 39 | flags.DEFINE_string("commoncrawl_wet_dir", None, 40 | "Path to CommonCrawl wet.gz files locally. 
If not " 41 | "provided, will download.") 42 | 43 | 44 | def main(_): 45 | assert FLAGS.out_dir 46 | assert FLAGS.metadata_dir 47 | out_dir = os.path.join(FLAGS.out_dir, "process_%d" % FLAGS.task_id) 48 | tf.gfile.MakeDirs(out_dir) 49 | 50 | with utils.timing("get_refs_commoncrawl"): 51 | # Get all WET files 52 | if FLAGS.commoncrawl_wet_dir: 53 | wet_files = tf.gfile.Glob( 54 | os.path.join(FLAGS.commoncrawl_wet_dir, "*.wet.gz")) 55 | else: 56 | tmp_dir = tempfile.gettempdir() 57 | wet_files = list( 58 | utils.wet_download_urls(utils.WET_PATHS_BY_DATE["0917"], tmp_dir)) 59 | 60 | # Shard and select this task's work 61 | wet_files.sort() 62 | wet_files = utils.shard(wet_files, FLAGS.num_tasks)[FLAGS.task_id] 63 | tf.logging.info("Sharded out WET files. Processing %d files", 64 | len(wet_files)) 65 | 66 | wikisum.extract_references_from_wets(wet_files, FLAGS.metadata_dir, out_dir) 67 | 68 | 69 | if __name__ == "__main__": 70 | tf.logging.set_verbosity(tf.logging.INFO) 71 | tf.app.run() 72 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/wikisum/test_data/para_bad1.txt: -------------------------------------------------------------------------------- 1 | kolkata ward no 97 37 2 | you are here : india » west bengal » kolkata » kolkata 3 | this paragraph too short 4 | a | b | c | d | e | f | g | h | i | j | k | l | m | n | o | p | q | r | s | t | u | v | w | x | y | z 5 | 123 123 123 123 985 9880 1230 0980 . 12398 . 6 | - 5 . 7 % - 5 . 2 % - 15 . 1 % 4 . 7 % - 13 . 3 % 7 | http : / / www . bbc . co . uk / sport / football / 24351521 8 | no . - 26 beadon street . 9 | { { / playpopup } } { { ^ playpopup } } { { # playinvideopage } } { { / playinvideopage } } { { ^ playinvideopage } } { { / playinvideopage } } { { / playpopup } }

{ { # playpopup } } { { / playpopup } } { { ^ playpopup } } { { # playinvideopage } } { { / playinvideopage } } { { ^ playinvideopage } } { { / playinvideopage } } { { / playpopup } } { { genre } } 10 | denham , samuel coulter , sally 133 oct 28 1819 11 | browse by 12 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/wikisum/test_data/para_good1.txt: -------------------------------------------------------------------------------- 1 | this is a very good paragraph . it even has two sentences . 2 | the castle that was soon to figure so largely in lee’s life lay fourteen miles 3 | to the southwest of where he sat perched atop his tank . topped with storybook 4 | crenelations and accompanied by a rich history , schloss itter , as it’s called 5 | in german , was first mentioned in land records as early as 1240 . since then , 6 | itter has passed through a number of hands . after germany’s march 1938 7 | annexation of austria , the castle’s robust construction and relatively remote 8 | location attracted the attention of the notoriously secretive nazis . within 9 | months of absorbing austria into the greater reich , the german government 10 | requisitioned castle itter for unspecified “official use”—which included housing 11 | for several months in 1942 an organization called the “german association for 12 | combating the dangers of tobacco . ” on february 7 , 1943 , it fell into new 13 | hands yet again , for on that day , the structure and all its outbuildings were 14 | requisitioned by the wehrmacht on behalf of the ss . 15 | the url for the site is http : / / www . bbc . co . uk / sport / football / 24351521 . 16 | -------------------------------------------------------------------------------- /tensor2tensor/data_generators/wikisum/utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 
pkg_dir = os.path.abspath(__file__)
pkg_dir, _ = os.path.split(pkg_dir)
_TESTDATA = os.path.join(pkg_dir, "test_data")


def _get_testdata(filename):
  """Returns the full contents of the file at `filename`."""
  with tf.io.gfile.GFile(filename) as f:
    return f.read()


class UtilsTest(tf.test.TestCase):

  def test_filter_paragraph(self):
    """Checks filter_paragraph against the para_bad*/para_good* fixtures.

    Every newline-separated line of a para_bad*.txt file must be filtered.
    Each para_good*.txt file is passed to filter_paragraph as one string and
    must not be filtered.
    """
    for bad in tf.io.gfile.glob(os.path.join(_TESTDATA, "para_bad*.txt")):
      for p in _get_testdata(bad).split("\n"):
        self.assertTrue(utils.filter_paragraph(p),
                        msg="Didn't filter %s" % p)
    for good in tf.io.gfile.glob(os.path.join(_TESTDATA, "para_good*.txt")):
      # The whole file is the paragraph under test, so read and check it once
      # rather than once per line of the file.
      p = _get_testdata(good)
      self.assertFalse(utils.filter_paragraph(p), msg="Filtered %s" % p)


if __name__ == "__main__":
  tf.test.main()
Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Environments defined in T2T. Imports here force registration.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from tensor2tensor.envs import gym_env_problem 23 | from tensor2tensor.envs import tic_tac_toe_env 24 | from tensor2tensor.envs import tic_tac_toe_env_problem 25 | -------------------------------------------------------------------------------- /tensor2tensor/envs/gym_spaces_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for gym_spaces_utils.py.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from gym.spaces import Box 23 | from gym.spaces import Discrete 24 | import numpy as np 25 | from tensor2tensor.envs import gym_spaces_utils 26 | import tensorflow.compat.v1 as tf 27 | 28 | 29 | class GymSpacesUtilsTest(tf.test.TestCase): 30 | 31 | def test_discrete_space_spec(self): 32 | discrete_space = Discrete(100) 33 | spec = gym_spaces_utils.gym_space_spec(discrete_space) 34 | self.assertIsInstance(spec, tf.FixedLenFeature) 35 | self.assertEqual(spec.dtype, tf.int64) 36 | self.assertListEqual(list(spec.shape), [1]) 37 | 38 | def test_box_space_spec(self): 39 | box_space = Box(low=0, high=10, shape=[5, 6], dtype=np.float32) 40 | spec = gym_spaces_utils.gym_space_spec(box_space) 41 | self.assertIsInstance(spec, tf.FixedLenFeature) 42 | self.assertEqual(spec.dtype, tf.float32) 43 | self.assertListEqual(list(spec.shape), [5, 6]) 44 | 45 | def test_discrete_space_encode(self): 46 | discrete_space = Discrete(100) 47 | value = discrete_space.sample() 48 | encoded_value = gym_spaces_utils.gym_space_encode(discrete_space, value) 49 | self.assertListEqual([value], encoded_value) 50 | 51 | def test_box_space_encode(self): 52 | box_space = Box(low=0, high=10, shape=[2], dtype=np.int64) 53 | value = np.array([2, 3]) 54 | encoded_value = gym_spaces_utils.gym_space_encode(box_space, value) 55 | self.assertListEqual([2, 3], encoded_value) 56 | 57 | 58 | if __name__ == '__main__': 59 | tf.test.main() 60 | -------------------------------------------------------------------------------- /tensor2tensor/envs/mujoco_problems.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Mujoco Gym environments.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import functools 23 | from tensor2tensor.envs import rendered_env_problem 24 | from tensor2tensor.layers import modalities 25 | from tensor2tensor.rl import gym_utils 26 | from tensor2tensor.utils import registry 27 | 28 | 29 | 30 | @registry.register_env_problem 31 | class ReacherEnvProblem(rendered_env_problem.RenderedEnvProblem): 32 | """Mujoco's reacher environment.""" 33 | 34 | def __init__(self): 35 | base_env_name = "Reacher-v2" 36 | wrapper_fn = functools.partial( 37 | gym_utils.gym_env_wrapper, **{ 38 | "rl_env_max_episode_steps": -1, 39 | "maxskip_env": False, 40 | "rendered_env": True, 41 | "rendered_env_resize_to": None, # Do not resize frames 42 | "sticky_actions": False, 43 | "output_dtype": None, 44 | "num_actions": None, 45 | }) 46 | super(ReacherEnvProblem, self).__init__( 47 | base_env_name=base_env_name, env_wrapper_fn=wrapper_fn) 48 | 49 | @property 50 | def input_modality(self): 51 | return modalities.ModalityType.VIDEO 52 | 53 | @property 54 | def target_modality(self): 55 | return modalities.ModalityType.VIDEO 56 | 57 | @property 58 | def action_modality(self): 59 | return modalities.ModalityType.IDENTITY 60 | 61 | @property 62 | def reward_modality(self): 63 | return 
modalities.ModalityType.IDENTITY 64 | 65 | @property 66 | def input_vocab_size(self): 67 | return 256 68 | 69 | @property 70 | def target_vocab_size(self): 71 | return 256 72 | -------------------------------------------------------------------------------- /tensor2tensor/envs/mujoco_problems_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for tensor2tensor.envs.mujoco_problems.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import numpy as np 23 | from tensor2tensor.envs import env_problem_utils 24 | from tensor2tensor.envs import mujoco_problems # pylint: disable=unused-import 25 | from tensor2tensor.utils import registry 26 | import tensorflow.compat.v1 as tf 27 | 28 | 29 | class ReacherEnvProblemTest(tf.test.TestCase): 30 | 31 | def test_registration_and_interaction_with_env_problem(self): 32 | batch_size = 5 33 | # This ensures that registration has occurred. 
34 | ep = registry.env_problem("reacher_env_problem", batch_size=batch_size) 35 | ep.reset() 36 | num_done = 0 37 | nsteps = 100 38 | for _ in range(nsteps): 39 | actions = np.stack([ep.action_space.sample() for _ in range(batch_size)]) 40 | obs, rewards, dones, infos = ep.step(actions) 41 | 42 | # Assert that things are happening batchwise. 43 | self.assertEqual(batch_size, len(obs)) 44 | self.assertEqual(batch_size, len(rewards)) 45 | self.assertEqual(batch_size, len(dones)) 46 | self.assertEqual(batch_size, len(infos)) 47 | 48 | done_indices = env_problem_utils.done_indices(dones) 49 | ep.reset(done_indices) 50 | num_done += sum(dones) 51 | 52 | # Assert that at least one episode got done. 53 | self.assertGreater(num_done, 0) 54 | 55 | 56 | if __name__ == "__main__": 57 | tf.test.main() 58 | -------------------------------------------------------------------------------- /tensor2tensor/envs/tic_tac_toe_env_problem.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | """TicTacToeEnvProblem wraps the TicTacToeEnv in an EnvProblem.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from tensor2tensor.envs import gym_env_problem 23 | from tensor2tensor.layers import modalities 24 | from tensor2tensor.utils import registry 25 | 26 | 27 | @registry.register_env_problem 28 | class TicTacToeEnvProblem(gym_env_problem.GymEnvProblem): 29 | """Plays `batch_size` games of tic-tac-toe.""" 30 | 31 | def __init__(self): 32 | super(TicTacToeEnvProblem, self).__init__( 33 | base_env_name="T2TEnv-TicTacToeEnv-v0", 34 | reward_range=(-1, 1)) 35 | 36 | @property 37 | def input_modality(self): 38 | return modalities.ModalityType.IDENTITY_SYMBOL 39 | 40 | @property 41 | def input_vocab_size(self): 42 | # Since a box can be either x or o or empty. 43 | return 3 44 | 45 | @property 46 | def target_modality(self): 47 | return modalities.ModalityType.IDENTITY_SYMBOL 48 | 49 | @property 50 | def target_vocab_size(self): 51 | # Since reward is either -1 or 0 or +1. 52 | return 3 53 | 54 | @property 55 | def action_modality(self): 56 | return modalities.ModalityType.SYMBOL_WEIGHTS_ALL 57 | -------------------------------------------------------------------------------- /tensor2tensor/envs/tic_tac_toe_env_problem_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for tensor2tensor.envs.tic_tac_toe_env_problem.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import numpy as np 23 | from tensor2tensor.envs import env_problem_utils 24 | from tensor2tensor.envs import tic_tac_toe_env # pylint: disable=unused-import 25 | from tensor2tensor.envs import tic_tac_toe_env_problem # pylint: disable=unused-import 26 | from tensor2tensor.utils import registry 27 | import tensorflow.compat.v1 as tf 28 | 29 | 30 | class TicTacToeEnvProblemTest(tf.test.TestCase): 31 | 32 | def test_registration_and_interaction_with_env_problem(self): 33 | batch_size = 5 34 | # This ensures that registration has occurred. 35 | ep = registry.env_problem("tic_tac_toe_env_problem", batch_size=batch_size) 36 | ep.reset() 37 | num_done, num_lost, num_won, num_draw = 0, 0, 0, 0 38 | nsteps = 100 39 | for _ in range(nsteps): 40 | actions = np.stack([ep.action_space.sample() for _ in range(batch_size)]) 41 | obs, rewards, dones, infos = ep.step(actions) 42 | 43 | # Assert that things are happening batchwise. 44 | self.assertEqual(batch_size, len(obs)) 45 | self.assertEqual(batch_size, len(rewards)) 46 | self.assertEqual(batch_size, len(dones)) 47 | self.assertEqual(batch_size, len(infos)) 48 | 49 | done_indices = env_problem_utils.done_indices(dones) 50 | ep.reset(done_indices) 51 | num_done += sum(dones) 52 | for r, d in zip(rewards, dones): 53 | if not d: 54 | continue 55 | if r == -1: 56 | num_lost += 1 57 | elif r == 0: 58 | num_draw += 1 59 | elif r == 1: 60 | num_won += 1 61 | else: 62 | raise ValueError("reward should be -1, 0, 1 but is {}".format(r)) 63 | 64 | # Assert that at least something got done; without that, the next assert is 65 | # meaningless. 66 | self.assertGreater(num_done, 0) 67 | 68 | # Assert that things are consistent.
69 | self.assertEqual(num_done, num_won + num_lost + num_draw) 70 | 71 | 72 | if __name__ == "__main__": 73 | tf.test.main() 74 | -------------------------------------------------------------------------------- /tensor2tensor/envs/time_step_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for tensor2tensor.envs.time_step.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from tensor2tensor.envs import time_step 23 | 24 | import tensorflow.compat.v1 as tf 25 | 26 | 27 | class TimeStepTest(tf.test.TestCase): 28 | 29 | def test_create_time_step(self): 30 | ts = time_step.TimeStep.create_time_step( 31 | observation=1, done=True, raw_reward=1.0, processed_reward=1, action=1, 32 | info={1: 1, 2: 4}) 33 | 34 | self.assertEqual(1, ts.observation) 35 | self.assertTrue(ts.done) 36 | self.assertNear(1.0, ts.raw_reward, 1e-6) 37 | self.assertEqual(1, ts.processed_reward) 38 | self.assertEqual(1, ts.action) 39 | self.assertEqual({1: 1, 2: 4}, ts.info) 40 | 41 | def test_replace(self): 42 | ts = time_step.TimeStep.create_time_step(observation=1, action=1) 43 | self.assertFalse(ts.done) 44 | 45 | tsr = ts.replace(action=2, done=True, info={1: 1, 2: 4}) 46 | 47 | # Assert that ts
didn't change. 48 | self.assertFalse(ts.done) 49 | self.assertEqual(1, ts.observation) 50 | self.assertEqual(1, ts.action) 51 | 52 | # But tsr is as expected. 53 | self.assertTrue(tsr.done) 54 | self.assertEqual(1, tsr.observation) # unchanged 55 | self.assertEqual(2, tsr.action) # changed 56 | self.assertEqual({1: 1, 2: 4}, tsr.info) 57 | 58 | 59 | if __name__ == '__main__': 60 | tf.test.main() 61 | -------------------------------------------------------------------------------- /tensor2tensor/insights/README.md: -------------------------------------------------------------------------------- 1 | # Tensor2Tensor Insights 2 | 3 | The Insights package provides an interactive webservice for understanding the 4 | inner workings of a Tensor2Tensor model. It will provide a series of 5 | visualizations extracted from a requested T2T model that informs model developers 6 | and model users on how to improve or best utilize a model. 7 | 8 | ## Dependencies 9 | 10 | Before using the Insights server, you must install [Bower](https://bower.io/) 11 | which we use to manage our web component dependencies. You can easily install 12 | this with the [Node Package Manager](https://www.npmjs.com/). 13 | 14 | ## Setup Instructions 15 | 16 | After training a model, such as according to the Quick Start guide, you can run 17 | the `t2t-insights-server` binary and begin querying it. 18 | 19 | First, prepare the bower dependencies by navigating into the 20 | `tensor2tensor/insights/polymer` directory and running `bower install`: 21 | 22 | ``` 23 | pushd tensor2tensor/insights/polymer 24 | bower install 25 | popd 26 | ``` 27 | 28 | The models run by the server are then configured by a JSON version of the 29 | InsightConfiguration protocol buffer.
Using the model trained in the Quick 30 | Start guide, a sample configuration would be: 31 | 32 | ``` 33 | { 34 | "configuration": [{ 35 | "source_language": "en", 36 | "target_language": "de", 37 | "label": "transformers_wmt32k", 38 | "transformer": { 39 | "model": "transformer", 40 | "model_dir": "/tmp/t2t/train", 41 | "data_dir": "/tmp/t2t/data", 42 | "hparams": "", 43 | "hparams_set": "transformer_base_single_gpu", 44 | "problem": "translate_ende_wmt32k" 45 | } 46 | }], 47 | "language": [{ 48 | "code": "en", 49 | "name": "English" 50 | },{ 51 | "code": "de", 52 | "name": "German" 53 | }] 54 | } 55 | ``` 56 | 57 | With that saved to `configuration.json`, run the following: 58 | 59 | ``` 60 | t2t-insights-server \ 61 | --configuration=configuration.json \ 62 | --static_path=`pwd`/tensor2tensor/insights/polymer 63 | ``` 64 | 65 | This will bring up a minimal [Flask](http://flask.pocoo.org/) REST service 66 | served by a [GUnicorn](http://gunicorn.org/) HTTP Server. 67 | 68 | ## Features to be developed 69 | 70 | This is a minimal web server. We are in the process of adding additional 71 | exciting features that give insight into a model's behavior: 72 | 73 | * Integrating a multi-head attention visualization. 74 | * Registering multiple models to compare their behavior. 75 | * Indexing training data to find examples related to a current query. 76 | * Tracking interesting query + translation pairs for deeper analysis. 77 | -------------------------------------------------------------------------------- /tensor2tensor/insights/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tensor2tensor/insights/insight_configuration.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensor2tensor; 4 | 5 | // Configures the Neural Machine Translation Insight Frontend with a set of 6 | // supported query processors and languages. 7 | message InsightConfiguration { 8 | // Specifies zero or more models to inspect. 9 | repeated QueryProcessorConfiguration configuration = 1; 10 | 11 | // Specifies language codes and display names. 12 | repeated Language language = 2; 13 | } 14 | 15 | // A displayable language name. 16 | message Language { 17 | // The BCP-47 Language code. 18 | string code = 1; 19 | // The language's display name. 20 | string name = 2; 21 | } 22 | 23 | // Configures a QueryProcessor and registers it with the Insight Frontend when 24 | // responding to analysis queries. 25 | message QueryProcessorConfiguration { 26 | // The model's BCP-47 source language code. 27 | string source_language = 1; 28 | // The model's BCP-47 target language code. 29 | string target_language = 2; 30 | // A short label for the model. 31 | string label = 3; 32 | // The QueryProcessor to use. By default we just use the TransformerModel. 33 | string query_processor = 4; 34 | 35 | // Configuration for the TransformerModel. 36 | TransformerConfiguration transformer = 5; 37 | } 38 | 39 | // Specifies the parameters for a trained Transformer model to inspect. 
These 40 | // parameters match those in t2t-trainer and t2t-decoder. 41 | message TransformerConfiguration { 42 | // The model type. 43 | string model = 1; 44 | // The trained model directory. 45 | string model_dir = 2; 46 | // The data directory for the model. 47 | string data_dir = 3; 48 | 49 | // The hyperparameter set for running the model. 50 | string hparams_set = 4; 51 | // Overriding hyperparameters. 52 | string hparams = 5; 53 | // The problem sets over which this model was trained and configured. 54 | string problems = 6; 55 | } 56 | -------------------------------------------------------------------------------- /tensor2tensor/insights/polymer/.bowerrc: -------------------------------------------------------------------------------- 1 | { 2 | "directory": "." 3 | } 4 | -------------------------------------------------------------------------------- /tensor2tensor/insights/polymer/insights_app/insights-app.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 The Tensor2Tensor Authors. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * `` Manages the views of the NMT Insights App. 20 | * 21 | * ### Usage 22 | * 23 | * 24 | * 25 | */ 26 | class InsightsApp extends Polymer.Element { 27 | /** 28 | * @return {string} The component name. 
29 | */ 30 | static get is() { 31 | return 'insights-app'; 32 | } 33 | 34 | /** 35 | * @return {!Object} The component properties. 36 | */ 37 | static get properties() { 38 | return { 39 | /** 40 | * @type {string} 41 | */ 42 | page: { 43 | type: String, 44 | reflectToAttribute: true, 45 | }, 46 | }; 47 | } 48 | 49 | /** 50 | * @return {!Array} The component observers. 51 | */ 52 | static get observers() { 53 | return [ 54 | 'routePageChanged_(routeData.page)', 55 | ]; 56 | } 57 | 58 | /** 59 | * Updates the page field if page exists or uses a default value. 60 | * @param {?string} page The current page name being viewed. 61 | * @private 62 | */ 63 | routePageChanged_(page) { 64 | if (page == this.page) { 65 | return; 66 | } 67 | this.page = page || 'explore'; 68 | this.set('routeData.page', this.page); 69 | 70 | // Refresh the now selected page in case it needs new data on a new view. 71 | let currentPage = this.get('currentPage'); 72 | if (currentPage) { 73 | currentPage.refresh(); 74 | } 75 | } 76 | } 77 | 78 | customElements.define(InsightsApp.is, InsightsApp); 79 | -------------------------------------------------------------------------------- /tensor2tensor/insights/polymer/language_selector/language-selector-content.html: -------------------------------------------------------------------------------- 1 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /tensor2tensor/insights/polymer/language_selector/language-selector.html: -------------------------------------------------------------------------------- 1 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /tensor2tensor/insights/polymer/language_selector/language-selector.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * 
Copyright 2018 The Tensor2Tensor Authors. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * `` provides a searchable dropdown of languages. 20 | * 21 | * The dropdown will present the selected language's Name. When opened, the 22 | * search bar will filter available languages by any language name or code that 23 | * has the query text as a substring. 24 | * 25 | * By default, this will auto select a provided language with language code 26 | * 'en'. 27 | * 28 | * ### Usage 29 | * 30 | * 31 | * 32 | */ 33 | class LanguageSelector extends Polymer.Element { 34 | /** 35 | * @return {string} The component name. 36 | */ 37 | static get is() { 38 | return 'language-selector'; 39 | } 40 | 41 | /** 42 | * @return {!Object} The component properties. 43 | */ 44 | static get properties() { 45 | return { 46 | /** 47 | * @type {string} 48 | */ 49 | label: { 50 | type: String, 51 | }, 52 | /** 53 | * @type {?Array} 54 | */ 55 | languages: { 56 | type: Array, 57 | }, 58 | /** 59 | * @type {!Language} 60 | */ 61 | value: { 62 | type: Object, 63 | notify: true, 64 | }, 65 | /** 66 | * @type {string} 67 | */ 68 | defaultCode: { 69 | type: String, 70 | value: 'en', 71 | }, 72 | }; 73 | } 74 | 75 | /** 76 | * Selects the language in the drop down. 77 | * @param {Language} language The language to pre-select. 
78 | * @public 79 | */ 80 | forceSelection(language) { 81 | this.$.selector.forceSelection(language); 82 | } 83 | } 84 | 85 | customElements.define(LanguageSelector.is, LanguageSelector); 86 | -------------------------------------------------------------------------------- /tensor2tensor/insights/polymer/processing_visualization/processing-visualization.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 The Tensor2Tensor Authors. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * `` summarises pre/post processing steps. 20 | * 21 | * This element presents the pre-processing segmentation steps and 22 | * post-processing de-segmentation and rewrite steps that are applied to a 23 | * translation query. 24 | * 25 | * ### Usage 26 | * 27 | * 28 | */ 29 | class ProcessingVisualization extends Polymer.Element { 30 | /** 31 | * @return {string} The component name. 32 | */ 33 | static get is() { 34 | return 'processing-visualization'; 35 | } 36 | 37 | /** 38 | * @return {!Object} The component properties. 
39 | */ 40 | static get properties() { 41 | return { 42 | /** 43 | * @type {!QueryProcessingVisualization} 44 | */ 45 | data: { 46 | type: Object, 47 | }, 48 | }; 49 | } 50 | } 51 | 52 | customElements.define(ProcessingVisualization.is, ProcessingVisualization); 53 | -------------------------------------------------------------------------------- /tensor2tensor/insights/query_processor.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A base class for all query processing classes.""" 17 | 18 | 19 | class QueryProcessor(object): 20 | """Base class for any class that wants to process sequence queries. 21 | 22 | QueryProcessor classes are expected to convert a string query to a series of 23 | visualization structures. 24 | 25 | TODO(kstevens): Define how the visualization structures should look once the 26 | protos are in better shape. 27 | """ 28 | 29 | def process(self, query): 30 | """Returns the generated visualizations for query. 31 | 32 | Args: 33 | query: The string input 34 | 35 | Returns: 36 | A dictionary with one key: 'result' that maps to a list of visualization 37 | objects. 
38 | """ 39 | del query 40 | return {"result": []} 41 | -------------------------------------------------------------------------------- /tensor2tensor/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tensor2tensor/layers/ngram_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for n-gram layer.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from tensor2tensor.layers import ngram 23 | 24 | from tensor2tensor.utils import test_utils 25 | 26 | import tensorflow.compat.v1 as tf 27 | tf.enable_eager_execution() 28 | 29 | 30 | class NGramTest(tf.test.TestCase): 31 | 32 | @test_utils.run_in_graph_and_eager_modes() 33 | def testNGramLayerShape(self): 34 | batch_size = 2 35 | length = 8 36 | vocab_size = 3 37 | minval = 1 38 | maxval = 4 39 | inputs = tf.random_uniform( 40 | [batch_size, length], minval=0, maxval=vocab_size, dtype=tf.int32) 41 | layer = ngram.NGram(vocab_size, minval, maxval) 42 | outputs = layer(inputs) 43 | outputs_val = self.evaluate(outputs) 44 | num_ngrams = sum([vocab_size**n for n in range(minval, maxval)]) 45 | self.assertEqual(outputs_val.shape, (batch_size, num_ngrams)) 46 | 47 | @test_utils.run_in_graph_and_eager_modes() 48 | def testNGramLayerOutput(self): 49 | inputs = tf.constant( 50 | [[0, 0, 0, 0, 1], 51 | [2, 1, 2, 1, 0]], dtype=tf.int32) 52 | layer = ngram.NGram(3, minval=1, maxval=3) 53 | outputs = layer(inputs) 54 | expected_outputs = tf.constant( 55 | [[4., 1., 0., 2., 0., 0., 0., 0., 0., 0., 0., 0.], 56 | [1., 2., 2., 0., 0., 0., 0., 0., 0., 0., 2., 0.]], dtype=tf.float32) 57 | outputs_val, expected_outputs_val = self.evaluate( 58 | [outputs, expected_outputs]) 59 | self.assertAllEqual(outputs_val, expected_outputs_val) 60 | 61 | if __name__ == "__main__": 62 | tf.test.main() 63 | 64 | -------------------------------------------------------------------------------- /tensor2tensor/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tensor2tensor/metrics/video_conditional_fvd_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class VideoConditionalFvdTest(tf.test.TestCase):
  """Smoke test for `video_conditional_fvd.evaluate_model`."""

  def test_sample(self):
    """Calls evaluate_model with a dataset and model whose callbacks are None."""
    eval_dataset = video_conditional_fvd.VideoEvaluationDataset(
        n_input_frames=4, n_output_frames=10, get_video_batch_fn=None)
    eval_model = video_conditional_fvd.Model(apply_fn=None, load_fn=None)
    video_conditional_fvd.evaluate_model(eval_dataset, eval_model, 10, 16)
@registry.register_model
class BasicFcRelu(t2t_model.T2TModel):
  """Basic fully-connected + ReLU model."""

  def body(self, features):
    """Flattens the inputs and applies a stack of dense/dropout/ReLU layers.

    Args:
      features: dict of tensors; only `features["inputs"]` is read here.

    Returns:
      The final activations with two size-1 axes inserted after the batch
      axis, giving the 4D layout T2T expects.
    """
    hp = self.hparams
    inputs = features["inputs"]
    dims = common_layers.shape_list(inputs)
    # Collapse axes 1..3 into a single feature axis per example.
    flat_width = dims[1] * dims[2] * dims[3]
    hidden = tf.reshape(inputs, [-1, flat_width])
    keep_prob = 1.0 - hp.dropout
    for layer_index in range(hp.num_hidden_layers):
      hidden = tf.layers.dense(
          hidden, hp.hidden_size, name="layer_%d" % layer_index)
      hidden = tf.nn.dropout(hidden, keep_prob=keep_prob)
      hidden = tf.nn.relu(hidden)
    expanded = tf.expand_dims(hidden, axis=1)
    return tf.expand_dims(expanded, axis=1)  # 4D For T2T.
@registry.register_hparams
def basic_fc_small():
  """Small fully connected model.

  Returns:
    `common_hparams.basic_params1()` with the overrides below applied.
  """
  hparams = common_hparams.basic_params1()
  overrides = (
      ("learning_rate", 0.1),
      ("batch_size", 128),
      ("hidden_size", 256),
      ("num_hidden_layers", 2),
      ("initializer", "uniform_unit_scaling"),
      ("initializer_gain", 1.0),
      ("weight_decay", 0.0),
      ("dropout", 0.0),
  )
  for name, value in overrides:
    setattr(hparams, name, value)
  return hparams
class BasicTest(tf.test.TestCase):
  """Shape test for `basic.BasicFcRelu`."""

  def testBasicFcRelu(self):
    """Runs one forward pass in TRAIN mode and checks the logits shape."""
    images = np.random.randint(256, size=(1, 28, 28, 1))
    labels = np.random.randint(10, size=(1, 1))
    hparams = trainer_lib.create_hparams(
        "basic_fc_small", problem_name="image_mnist", data_dir=".")
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(images, dtype=tf.int32),
          "targets": tf.constant(labels, dtype=tf.int32),
      }
      model = basic.BasicFcRelu(hparams, tf_estimator.ModeKeys.TRAIN)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      logits_val = session.run(logits)
    self.assertEqual(logits_val.shape, (1, 1, 1, 1, 10))
class ByteNetTest(tf.test.TestCase):
  """Shape test for `bytenet.ByteNet`."""

  def testByteNet(self):
    """Runs one forward pass in TRAIN mode and checks the logits shape."""
    vocab_size = 9
    source_ids = np.random.randint(1, high=vocab_size, size=(3, 5, 1, 1))
    target_ids = np.random.randint(1, high=vocab_size, size=(3, 6, 1, 1))
    hparams = bytenet.bytenet_base()
    p_hparams = problem_hparams.test_problem_hparams(
        vocab_size, vocab_size, hparams)
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(source_ids, dtype=tf.int32),
          "targets": tf.constant(target_ids, dtype=tf.int32),
      }
      model = bytenet.ByteNet(hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      logits_val = session.run(logits)
    self.assertEqual(logits_val.shape, (3, 50, 1, 1, vocab_size))
4 | -------------------------------------------------------------------------------- /tensor2tensor/models/neural_architecture_search/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | -------------------------------------------------------------------------------- /tensor2tensor/models/neural_gpu_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class NeuralGPUTest(tf.test.TestCase):
  """Shape test for `neural_gpu.NeuralGPU`."""

  def testNeuralGPU(self):
    """Runs one forward pass in TRAIN mode and checks the logits shape."""
    hparams = common_hparams.basic_params1()
    batch_size, input_length = 3, 5
    target_length = input_length
    input_vocab_size, target_vocab_size = 9, 11
    p_hparams = problem_hparams.test_problem_hparams(
        input_vocab_size, target_vocab_size, hparams)
    input_ids = np.random.randint(
        input_vocab_size, size=(batch_size, input_length, 1, 1))
    target_ids = np.random.randint(
        target_vocab_size, size=(batch_size, target_length, 1, 1))
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(input_ids, dtype=tf.int32),
          "targets": tf.constant(target_ids, dtype=tf.int32)
      }
      model = neural_gpu.NeuralGPU(
          hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      logits_val = session.run(logits)
    expected_shape = (batch_size, target_length, 1, 1, target_vocab_size)
    self.assertEqual(logits_val.shape, expected_shape)
class GlowInitHook(tf.train.SessionRunHook):
  """Hook that runs data-dependent initialization once before the first step.

  The init op is stored in the tf collection glow_init_op. Look at the
  "body" in glow.py for more details.
  """

  def after_create_session(self, session, coord):
    """Runs the first op in "glow_init_op" when the global step is zero.

    Args:
      session: session used to read the global step and run the init op.
      coord: unused.
    """
    del coord
    if session.run(tf.train.get_global_step()) != 0:
      return
    init_ops = tf.get_collection("glow_init_op")
    if not init_ops:
      return
    # In-case of a multi-GPU system, this just runs the first op in the
    # collection.
    session.run(init_ops[0])
@registry.register_model
class TransformerSketch(transformer.Transformer):
  """Transformer with strided convolutions."""

  def encode(self, inputs, target_space, hparams, features=None, losses=None):
    """Add layers of strided convolutions on top of encoder.

    Note that the `hparams` argument is replaced by `self.hparams` before
    it is used, so the value passed by the caller is ignored.

    Args:
      inputs: tensor fed to the strided convolutions.
      target_space: forwarded unchanged to the parent `encode`.
      hparams: ignored; `self.hparams` is used instead.
      features: forwarded unchanged to the parent `encode`.
      losses: forwarded unchanged to the parent `encode`.

    Returns:
      The (encoder_output, encoder_decoder_attention_bias) pair produced by
      the parent `encode` on the downsampled inputs.
    """
    with tf.variable_scope("downstride"):
      hparams = self.hparams
      kernel = (4, 4)
      strides = (2, 2)
      compressed = inputs
      # Down-convolutions.
      for step in range(hparams.num_compress_steps):
        compressed = common_layers.make_even_size(compressed)
        compressed = tf.layers.conv2d(
            compressed,
            hparams.hidden_size,
            kernel,
            strides=strides,
            padding="SAME",
            activation=common_layers.belu,
            name="conv_%d" % step)
        compressed = common_layers.layer_norm(compressed)

      parent = super(TransformerSketch, self)
      output, attention_bias = parent.encode(
          compressed, target_space, hparams, features=features, losses=losses)
      return output, attention_bias


@registry.register_hparams
def transformer_sketch():
  """Basic transformer_sketch hparams."""
  hparams = transformer.transformer_small()
  overrides = (
      ("num_compress_steps", 4),
      ("batch_size", 32),
      ("clip_grad_norm", 2.),
      ("sampling_method", "random"),
  )
  for name, value in overrides:
    setattr(hparams, name, value)
  return hparams
class TransformerVaeTest(tf.test.TestCase):
  """Shape test for `transformer_vae.TransformerAE`."""

  def testTransformerAEOnDVQ(self):
    """Runs a TRAIN-mode forward pass with the "dvq" bottleneck."""
    batch_size, input_length, target_length = 3, 5, 16
    vocab_size = 9
    hparams = transformer_vae.transformer_ae_small()
    hparams.bottleneck_kind = "dvq"
    hparams.dp_strength = 0
    p_hparams = problem_hparams.test_problem_hparams(
        vocab_size, vocab_size, hparams)
    hparams.problem_hparams = p_hparams
    input_ids = np.random.randint(
        vocab_size, size=(batch_size, input_length, 1, 1))
    target_ids = np.random.randint(
        vocab_size, size=(batch_size, target_length, 1, 1))
    features = {
        "inputs": tf.constant(input_ids, dtype=tf.int32),
        "targets": tf.constant(target_ids, dtype=tf.int32),
        "target_space_id": tf.constant(1, dtype=tf.int32),
    }
    tf.train.create_global_step()
    model = transformer_vae.TransformerAE(
        hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
    logits, _ = model(features)
    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      logits_val = session.run(logits)
      expected_shape = (batch_size, target_length, 1, 1, vocab_size)
      self.assertEqual(logits_val.shape, expected_shape)
def resnet_tiny_cpu():
  """Returns `resnet_base()` hparams with 2-block stages and NCHW disabled."""
  hparams = resnet.resnet_base()
  hparams.layer_sizes = [2] * 4
  hparams.use_nchw = False
  return hparams


class ResnetTest(tf.test.TestCase):
  """Shape test for `resnet.Resnet`."""

  def _test_resnet(self, img_size, output_size):
    """Runs a TRAIN-mode forward pass and checks the logits shape.

    Args:
      img_size: height and width of the square random input images.
      output_size: tuple placed between the batch axis and the trailing
        (1, vocab_size) axes of the expected logits shape.
    """
    vocab_size, batch_size = 9, 2
    images = np.random.randint(
        256, size=(batch_size, img_size, img_size, 3))
    labels = np.random.randint(
        1, high=vocab_size, size=(batch_size, 1, 1, 1))
    hparams = resnet_tiny_cpu()
    p_hparams = problem_hparams.test_problem_hparams(
        vocab_size, vocab_size, hparams)
    p_hparams.modality["inputs"] = modalities.ModalityType.IMAGE
    p_hparams.modality["targets"] = modalities.ModalityType.CLASS_LABEL
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(images, dtype=tf.int32),
          "targets": tf.constant(labels, dtype=tf.int32),
      }
      model = resnet.Resnet(hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      logits_val = session.run(logits)
    expected_shape = (batch_size,) + output_size + (1, vocab_size)
    self.assertEqual(logits_val.shape, expected_shape)

  def testResnetLarge(self):
    """Checks 224x224 inputs against a (1, 1) output size."""
    self._test_resnet(img_size=224, output_size=(1, 1))
| # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tensor2tensor/models/video/basic_deterministic_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class NextFrameTest(tests_utils.BaseNextFrameTest):
  """Runs the shared next-frame checks on the basic deterministic model."""

  def testBasicDeterministic(self):
    """Delegates to the base-class helper with this model's hparams/class."""
    hparams = basic_deterministic_params.next_frame_basic_deterministic()
    model_cls = basic_deterministic.NextFrameBasicDeterministic
    self.TestOnVariousInputOutputSizes(hparams, model_cls, 256, False)
@registry.register_model
class NextFrameBasicRecurrent(
    basic_stochastic.NextFrameBasicStochasticDiscrete):
  """Basic next-frame recurrent model."""

  @property
  def is_recurrent_model(self):
    """Always True for this model."""
    return True

  def middle_network(self, layer, internal_states):
    """Passes `layer` through a stack of convolutional LSTM layers.

    Args:
      layer: input tensor for the first LSTM layer.
      internal_states: list with one LSTM state per layer, or None to start
        from empty states. A provided list is updated in place.

    Returns:
      A pair (output of the last LSTM layer, list of updated LSTM states).
    """
    hp = self.hparams
    states = internal_states
    if states is None:
      states = [None] * hp.num_lstm_layers

    # LSTM layers
    activations = layer
    for depth in range(hp.num_lstm_layers):
      activations, states[depth] = common_video.conv_lstm_2d(
          activations, states[depth], hp.num_lstm_filters)
    return activations, states


@registry.register_hparams
def next_frame_basic_recurrent():
  """Basic 2-frame recurrent model with stochastic tower."""
  hparams = basic_stochastic.next_frame_basic_stochastic_discrete()
  overrides = (
      ("filter_double_steps", 2),
      ("hidden_size", 64),
      ("video_num_input_frames", 4),
      ("video_num_target_frames", 4),
      ("concat_internal_states", False),
  )
  for name, value in overrides:
    setattr(hparams, name, value)
  hparams.add_hparam("num_lstm_layers", 2)
  hparams.add_hparam("num_lstm_filters", 256)
  return hparams
class NextFrameTest(tests_utils.BaseNextFrameTest):
  """Runs the shared next-frame checks on the basic recurrent model."""

  def testBasicDeterministic(self):
    """Exercises NextFrameBasicRecurrent despite the method's legacy name."""
    hparams = basic_recurrent.next_frame_basic_recurrent()
    model_cls = basic_recurrent.NextFrameBasicRecurrent
    self.TestOnVariousInputOutputSizes(hparams, model_cls, 256, False)
class NextFrameTest(tests_utils.BaseNextFrameTest):
  """Runs the shared next-frame checks on the basic stochastic model."""

  def testBasicStochastic(self):
    """Delegates to the base-class helper with this model's hparams/class."""
    hparams = basic_stochastic.next_frame_basic_stochastic()
    model_cls = basic_stochastic.NextFrameBasicStochastic
    self.TestOnVariousInputOutputSizes(hparams, model_cls, 256, False)
class NextFrameTest(tests_utils.BaseNextFrameTest):
  """Smoke test for the `emily` next-frame model."""

  def testEmily(self):
    """Delegates to the base-class helper `TestOnVariousInputOutputSizes`."""
    hparams = emily.next_frame_emily()
    model_cls = emily.NextFrameEmily
    # The trailing 1 is forwarded positionally; its meaning is defined by the
    # helper in tests_utils.
    self.TestOnVariousInputOutputSizes(hparams, model_cls, 1)


if __name__ == "__main__":
  tf.test.main()
@registry.register_hparams
def next_frame_epva():
  """EPVA hparams.

  Starts from `next_frame_basic_deterministic()`, overrides the frame counts,
  the bottom/loss/top modality functions and the optimisation settings, and
  then registers the EPVA-specific hyperparameters.

  Returns:
    The hparams object returned by `next_frame_basic_deterministic()` with
    the overrides and additions below applied.
  """
  hparams = basic_deterministic_params.next_frame_basic_deterministic()
  hparams.video_num_input_frames = 4
  hparams.video_num_target_frames = 4
  hparams.bottom = {
      "inputs": modalities.video_raw_bottom,
      "targets": modalities.video_raw_targets_bottom,
  }
  hparams.loss = {
      "targets": modalities.video_l2_raw_loss,
  }
  hparams.top = {
      "targets": modalities.video_raw_top,
  }
  hparams.learning_rate_schedule = "constant"
  hparams.learning_rate_constant = 1e-05
  hparams.batch_size = 2
  hparams.clip_grad_norm = 0.01
  # TODO(msaffar): disentangle EPVA from SV2P
  hparams.add_hparam("reward_prediction", False)
  hparams.add_hparam("clip_pixel_values", True)
  hparams.add_hparam("context_frames", 5)
  hparams.add_hparam("enc_learning_rate", 1e-5)
  hparams.add_hparam("enc_pred_loss_scale", 0.1)
  hparams.add_hparam("enc_pred_loss_scale_delay", 6e5)
  hparams.add_hparam("enc_size", 64)
  hparams.add_hparam("enc_keep_prob", .65)
  hparams.add_hparam("enc_pred_use_l1_loss", False)
  hparams.add_hparam("enc_pred_use_l2norm", False)
  hparams.add_hparam("van_learning_rate", 3e-5)
  hparams.add_hparam("van_keep_prob", .9)
  # The name must carry no stray whitespace: a key such as "sequence_length "
  # cannot be read as `hparams.sequence_length`.
  hparams.add_hparam("sequence_length", 64)
  hparams.add_hparam("skip_num", 2)
  hparams.add_hparam("pred_noise_std", 0)
  hparams.add_hparam("lstm_state_noise_stddev", 0)
  return hparams
# Each entry is (test-case name, *positional args for testGlowTrainAndDecode).
conv3d_net_hparams = (
    ("conv3d_net", 2, 2, "conv3d_net", "conditional", -1, 3),
    ("conv3d_net_gatu", 2, 2, "conv3d_net", "conditional", -1, 3, False, False,
     "gatu"),
    ("conv3d_dil", 2, 2, "conv3d_net", "conditional", -1, -1, False, True),
)


class NextFrameGlowConv3DTest(nfg_test_utils.NextFrameGlowTest,
                              parameterized.TestCase):
  """Parameterized Glow tests using the "conv3d_net" latent encoder."""

  @parameterized.named_parameters(*conv3d_net_hparams)
  def testGlowTrainAndDecode(self, in_frames=1, out_frames=1,
                             latent_dist_encoder="pointwise",
                             gen_mode="conditional", pretrain_steps=-1,
                             num_train_frames=-1, cond_first_frame=False,
                             apply_dilations=False, activation="relu"):
    """Forwards every parameter, by keyword, to `GlowTrainAndDecode`."""
    glow_kwargs = dict(
        in_frames=in_frames,
        out_frames=out_frames,
        latent_dist_encoder=latent_dist_encoder,
        gen_mode=gen_mode,
        pretrain_steps=pretrain_steps,
        num_train_frames=num_train_frames,
        cond_first_frame=cond_first_frame,
        apply_dilations=apply_dilations,
        activation=activation,
    )
    self.GlowTrainAndDecode(**glow_kwargs)


if __name__ == "__main__":
  tf.test.main()
# Each entry is (test-case name, *positional args for testGlowTrainAndDecode).
conv_lstm_hparams = (
    ("in_3_out_2_lstm", 2, 1, "conv_lstm", "conditional", -1),
    ("lstm_pretrain", 2, 1, "conv_lstm", "conditional", 50000),
)


# NOTE(review): the class name says Conv3D, but every case here selects the
# "conv_lstm" encoder; the name is kept so existing test IDs do not change.
class NextFrameGlowConv3DTest(nfg_test_utils.NextFrameGlowTest,
                              parameterized.TestCase):
  """Parameterized Glow tests using the "conv_lstm" latent encoder."""

  @parameterized.named_parameters(*conv_lstm_hparams)
  def testGlowTrainAndDecode(self, in_frames=1, out_frames=1,
                             latent_dist_encoder="pointwise",
                             gen_mode="conditional", pretrain_steps=-1,
                             num_train_frames=-1, cond_first_frame=False):
    """Forwards every parameter, by keyword, to `GlowTrainAndDecode`."""
    glow_kwargs = dict(
        in_frames=in_frames,
        out_frames=out_frames,
        latent_dist_encoder=latent_dist_encoder,
        gen_mode=gen_mode,
        pretrain_steps=pretrain_steps,
        num_train_frames=num_train_frames,
        cond_first_frame=cond_first_frame,
    )
    self.GlowTrainAndDecode(**glow_kwargs)


if __name__ == "__main__":
  tf.test.main()
# Each entry is (test-case name, *positional args for testGlowTrainAndDecode).
conv_net_hparams = (
    ("in_3_out_2_conv", 3, 1, "conv_net", "conditional"),
    ("conv_net_cond_first", 2, 2, "conv_net", "conditional", -1, 3, True),
)


class NextFrameGlowConvTest(nfg_test_utils.NextFrameGlowTest,
                            parameterized.TestCase):
  """Parameterized Glow tests using the "conv_net" latent encoder."""

  @parameterized.named_parameters(*conv_net_hparams)
  def testGlowTrainAndDecode(self, in_frames=1, out_frames=1,
                             latent_dist_encoder="pointwise",
                             gen_mode="conditional", pretrain_steps=-1,
                             num_train_frames=-1, cond_first_frame=False):
    """Forwards every parameter, by keyword, to `GlowTrainAndDecode`."""
    glow_kwargs = dict(
        in_frames=in_frames,
        out_frames=out_frames,
        gen_mode=gen_mode,
        latent_dist_encoder=latent_dist_encoder,
        pretrain_steps=pretrain_steps,
        num_train_frames=num_train_frames,
        cond_first_frame=cond_first_frame,
    )
    self.GlowTrainAndDecode(**glow_kwargs)


if __name__ == "__main__":
  tf.test.main()
# Each entry is (test-case name, *positional args for testGlowTrainAndDecode).
uncond_hparams = (
    ("in_1_out_1", 1, 1, "pointwise", "conditional"),
    ("uncond", 1, 3, "pointwise", "unconditional", -1, 1),
)


class NfgUncondTest(nfg_test_utils.NextFrameGlowTest, parameterized.TestCase):
  """Parameterized Glow tests using the "pointwise" latent encoder."""

  @parameterized.named_parameters(*uncond_hparams)
  def testGlowTrainAndDecode(self, in_frames=1, out_frames=1,
                             latent_dist_encoder="pointwise",
                             gen_mode="conditional", pretrain_steps=-1,
                             num_train_frames=-1, cond_first_frame=False):
    """Forwards every parameter, by keyword, to `GlowTrainAndDecode`."""
    glow_kwargs = dict(
        in_frames=in_frames,
        out_frames=out_frames,
        latent_dist_encoder=latent_dist_encoder,
        gen_mode=gen_mode,
        pretrain_steps=pretrain_steps,
        num_train_frames=num_train_frames,
        cond_first_frame=cond_first_frame,
    )
    self.GlowTrainAndDecode(**glow_kwargs)


if __name__ == "__main__":
  tf.test.main()
class NextFrameTest(tests_utils.BaseNextFrameTest):
  """Tests SAVP under each combination of the `use_vae`/`use_gan` flags."""

  def _savp_hparams(self, use_vae, use_gan):
    """Returns fresh SAVP hparams with the two mode flags set as given."""
    hparams = savp_params.next_frame_savp()
    hparams.use_vae = use_vae
    hparams.use_gan = use_gan
    return hparams

  def testSavpVAE(self):
    """VAE only: runs both base-class sweep helpers."""
    savp_hparams = self._savp_hparams(use_vae=True, use_gan=False)
    model_cls = savp.NextFrameSAVP
    self.TestOnVariousInputOutputSizes(savp_hparams, model_cls, 1)
    self.TestOnVariousUpSampleLayers(savp_hparams, model_cls, 1)

  def testSavpGAN(self):
    """GAN only: default `gan_optimization` first, then "sequential"."""
    hparams = self._savp_hparams(use_vae=False, use_gan=True)
    self.TestVideoModel(7, 5, hparams, savp.NextFrameSAVP, 1)

    # The same hparams object is reused with the optimisation mode switched.
    hparams.gan_optimization = "sequential"
    self.TestVideoModel(7, 5, hparams, savp.NextFrameSAVP, 1)

  def testSavpGANVAE(self):
    """Both VAE and GAN enabled."""
    hparams = self._savp_hparams(use_vae=True, use_gan=True)
    self.TestVideoModel(7, 5, hparams, savp.NextFrameSAVP, 1)

  def testInvalidVAEGANCombinations(self):
    """With both flags off, `TestVideoModel` must raise ValueError."""
    hparams = self._savp_hparams(use_vae=False, use_gan=False)
    with self.assertRaises(ValueError):
      self.TestVideoModel(7, 5, hparams, savp.NextFrameSAVP, 1)


if __name__ == "__main__":
  tf.test.main()
class NextFrameTest(tests_utils.BaseNextFrameTest):
  """Tests for the SV2P models, driven by base-class helpers."""

  def _hparams(self, internal_loss=None):
    """Returns fresh SV2P hparams, overriding `internal_loss` if given."""
    hp = sv2p_params.next_frame_sv2p()
    if internal_loss is not None:
      hp.internal_loss = internal_loss
    return hp

  def testSv2p(self):
    """Default hparams through `TestOnVariousInputOutputSizes`."""
    hp = self._hparams()
    self.TestOnVariousInputOutputSizes(hp, sv2p.NextFrameSv2p, 1, False)

  def testSv2pWithActions(self):
    """Default hparams through `TestWithActions`."""
    hp = self._hparams()
    self.TestWithActions(hp, sv2p.NextFrameSv2p, 1, False)

  def testSv2pWithActionsAndRewards(self):
    """`internal_loss=True` through `TestWithActionAndRewards`."""
    hp = self._hparams(internal_loss=True)
    self.TestWithActionAndRewards(hp, sv2p.NextFrameSv2p, 1, False)

  def testSv2pWithActionsAndRewardsExternalLoss(self):
    """`internal_loss=False` through `TestWithActionAndRewards`."""
    hp = self._hparams(internal_loss=False)
    self.TestWithActionAndRewards(hp, sv2p.NextFrameSv2p, 1, False)

  def testSv2pTwoFrames(self):
    """Default hparams with the `NextFrameSv2pTwoFrames` model class."""
    hp = self._hparams()
    self.TestOnVariousInputOutputSizes(
        hp, sv2p.NextFrameSv2pTwoFrames, 1, False)


if __name__ == "__main__":
  tf.test.main()
class XceptionTest(tf.test.TestCase):
  """Checks the shape of the logits produced by the tiny Xception model."""

  def _test_xception(self, img_size):
    """Runs Xception on random data and asserts the logits shape.

    Args:
      img_size: side length used for both spatial dimensions of the random
        input batch.
    """
    batch_size, vocab_size = 3, 9
    images = np.random.randint(256, size=(batch_size, img_size, img_size, 3))
    labels = np.random.randint(1, high=vocab_size, size=(batch_size, 1, 1, 1))

    hparams = xception.xception_tiny()
    p_hparams = problem_hparams.test_problem_hparams(
        vocab_size, vocab_size, hparams)
    p_hparams.modality["inputs"] = modalities.ModalityType.IMAGE
    p_hparams.modality["targets"] = modalities.ModalityType.CLASS_LABEL

    with self.test_session() as session:
      features = {
          "inputs": tf.constant(images, dtype=tf.int32),
          "targets": tf.constant(labels, dtype=tf.int32),
      }
      model = xception.Xception(
          hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
      logits, _ = model(features)
      # Variables exist only after the model call above.
      session.run(tf.global_variables_initializer())
      res = session.run(logits)

    expected_shape = (batch_size, 1, 1, 1, vocab_size)
    self.assertEqual(res.shape, expected_shape)

  def testXceptionSmallImage(self):
    """Runs the shape check with 9x9 inputs."""
    self._test_xception(img_size=9)

  def testXceptionLargeImage(self):
    """Runs the shape check with 256x256 inputs."""
    self._test_xception(img_size=256)


if __name__ == "__main__":
  tf.test.main()
def problem(name):
  """Returns whatever `registry.problem` yields for `name`."""
  registered = registry.problem(name)
  return registered


def available():
  """Returns the result of `registry.list_base_problems()`."""
  base_problems = registry.list_base_problems()
  return base_problems


# Import-time side effect: load every module listed in ALL_MODULES
# (presumably so that their problems register themselves).
all_problems.import_modules(all_problems.ALL_MODULES)
def problem(name):
  """Returns whatever `registry.problem` yields for `name`."""
  registered = registry.problem(name)
  return registered


def available():
  """Returns the names from `registry.list_problems()` as a sorted list."""
  names = list(registry.list_problems())
  names.sort()
  return names


# Import problem modules
_modules = list(all_problems.MODULES)

all_problems.import_modules(_modules)
class ProblemsTest(tf.test.TestCase):
  """Checks that `tensor2tensor.problems` can be imported."""

  def testImport(self):
    """Passes as long as the imported module object is not None."""
    imported_module = problems
    self.assertIsNotNone(imported_module)


if __name__ == "__main__":
  tf.test.main()
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("data_dir", "", "Data directory.")
flags.DEFINE_string("tmp_dir", "/tmp/t2t_datagen",
                    "Temporary storage directory.")
flags.DEFINE_string("game", None, "Atari game to generate data for.")
flags.DEFINE_integer("num_env_steps", 5000, "Number of steps to roll out.")
flags.DEFINE_boolean("eval", False, "Whether to run in eval mode.")


def main(_):
  """Generates trajectories for `FLAGS.game` and logs the mean reward.

  Creates the data and tmp directories, registers the game's problem when it
  is not in the registry yet, runs `generate_data`, and finally logs reward
  statistics if at least one "done" was recorded.
  """
  for directory in (FLAGS.data_dir, FLAGS.tmp_dir):
    tf.gfile.MakeDirs(directory)

  # Register the problem on demand.
  problem_name = "gym_discrete_problem_with_agent_on_%s" % FLAGS.game
  already_registered = problem_name in registry.Registries.problems
  if not already_registered:
    gym_env.register_game(FLAGS.game)

  # Roll out and write the data.
  tf.logging.info("Running %s environment for %d steps for trajectories.",
                  FLAGS.game, FLAGS.num_env_steps)
  problem = registry.problem(problem_name)
  problem.settable_num_steps = FLAGS.num_env_steps
  problem.settable_eval_phase = FLAGS.eval
  problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)

  # Nothing to report, and no division by zero, when there were no dones.
  if not problem.statistics.number_of_dones:
    return
  mean_reward = (problem.statistics.sum_of_rewards /
                 problem.statistics.number_of_dones)
  tf.logging.info("Mean reward: %.2f, Num dones: %d",
                  mean_reward,
                  problem.statistics.number_of_dones)


if __name__ == "__main__":
  tf.app.run(main)
class EvalTest(tf.test.TestCase):
  """Smoke test for `evaluator.evaluate`."""

  def test_evaluate_pong_random_agent(self):
    """Runs one evaluation with a random agent on the tiny hparams sets.

    The test's temp directory is passed for all three directory arguments.
    """
    temp_dir = tf.test.get_temp_dir()
    hparams_pair = (
        registry.hparams("rlmb_tiny"),
        registry.hparams("planner_tiny"),
    )
    directories = (temp_dir,) * 3
    options = dict(
        agent_type="random",
        eval_mode="agent_real",
        eval_with_learner=False,
        log_every_steps=None,
        debug_video_path="",
    )
    evaluator.evaluate(*(hparams_pair + directories), **options)


if __name__ == "__main__":
  tf.test.main()
class PolicyLearner(object):
  """Abstract interface shared by policy learners.

  The constructor only records its arguments; `train` and `evaluate` raise
  NotImplementedError and are meant to be overridden.
  """

  def __init__(
      self, frame_stack_size, base_event_dir, agent_model_dir, total_num_epochs
  ):
    """Stores the four constructor arguments as same-named attributes."""
    (self.frame_stack_size,
     self.base_event_dir,
     self.agent_model_dir,
     self.total_num_epochs) = (
         frame_stack_size, base_event_dir, agent_model_dir, total_num_epochs)

  def train(
      self,
      env_fn,
      hparams,
      simulated,
      save_continuously,
      epoch,
      sampling_temp=1.0,
      num_env_steps=None,
      env_step_multiplier=1,
      eval_env_fn=None,
      report_fn=None
  ):
    """Train.

    Raises:
      NotImplementedError: always; subclasses supply the implementation.
    """
    raise NotImplementedError

  def evaluate(self, env_fn, hparams, sampling_temp):
    """Evaluate.

    Raises:
      NotImplementedError: always; subclasses supply the implementation.
    """
    raise NotImplementedError
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("world_model_dir", "",
                    "Directory containing checkpoints of the world model.")


def get_simulated_problem_name(game):
  """Returns the simulated-problem name for `game`.

  Games listed in `gym_env.ATARI_GAMES` get the "_deterministic-v4" suffix
  before being substituted into the name template; other games are used
  unchanged.
  """
  is_atari = game in gym_env.ATARI_GAMES
  game_with_mode = game + "_deterministic-v4" if is_atari else game
  return "gym_simulated_discrete_problem_with_agent_on_%s" % game_with_mode


def main(_):
  """Trains the agent for one final epoch (epoch 0).

  The agent model directory and the event directory are both
  `FLAGS.output_dir`.
  """
  hparams = trainer_model_based_params.create_loop_hparams()
  output_dir = FLAGS.output_dir
  positional_args = (
      get_simulated_problem_name(hparams.game),
      output_dir,  # agent_model_dir
      output_dir,  # event_dir
      FLAGS.world_model_dir,
      FLAGS.data_dir,  # epoch data dir; only required for initial frames
      hparams,
      0,
  )
  trainer_model_based.train_agent(
      *positional_args, epoch=0, is_final_epoch=True)


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
class ModelRLExperimentRecurrentTest(tf.test.TestCase):
  """Smoke test of trainer_model_based with "rlmb_tiny_recurrent" hparams."""

  def test_basic_recurrent(self):
    """Runs trainer_model_based.main once, writing to a temp directory."""
    flag_overrides = (
        ("output_dir", tf.test.get_temp_dir()),
        ("loop_hparams_set", "rlmb_tiny_recurrent"),
        ("schedule", "train"),  # skip evaluation for world model training
    )
    for flag_name, flag_value in flag_overrides:
      setattr(FLAGS, flag_name, flag_value)
    trainer_model_based.main(None)
class ModelRLExperimentStochasticTest(tf.test.TestCase):
  """Smoke test of trainer_model_based with "rlmb_tiny_stochastic" hparams."""

  def test_basic_stochastic(self):
    """Runs trainer_model_based.main once, writing to a temp directory."""
    temp_dir = tf.test.get_temp_dir()
    FLAGS.output_dir = temp_dir
    FLAGS.loop_hparams_set = "rlmb_tiny_stochastic"
    # skip evaluation for world model training
    FLAGS.schedule = "train"
    trainer_model_based.main(None)
class ModelRLExperimentSv2pTest(tf.test.TestCase):
  """Smoke test of trainer_model_based with the "rlmb_tiny_sv2p" hparams."""

  def test_sv2p(self):
    """Runs trainer_model_based.main once, writing to a temp directory.

    FLAGS.schedule is not overridden here, so whatever value it currently
    holds is used.
    """
    temp_dir = tf.test.get_temp_dir()
    FLAGS.output_dir = temp_dir
    FLAGS.loop_hparams_set = "rlmb_tiny_sv2p"
    trainer_model_based.main(None)
class ModelRLExperimentTest(tf.test.TestCase):
  """Smoke tests of the model-based RL training loop."""

  def _test_hparams_skip_evaluation(self, hparams_set):
    """Runs trainer_model_based.main for `hparams_set` with schedule "train".

    Args:
      hparams_set: name of the loop hparams set to run.
    """
    flag_overrides = (
        ("output_dir", tf.test.get_temp_dir()),
        ("loop_hparams_set", hparams_set),
        ("schedule", "train"),  # skip evaluation for world model training
    )
    for flag_name, flag_value in flag_overrides:
      setattr(FLAGS, flag_name, flag_value)
    trainer_model_based.main(None)

  def test_basic(self):
    """Runs the loop with the "rlmb_tiny" hparams set."""
    self._test_hparams_skip_evaluation("rlmb_tiny")

  # TODO(kozak): enable when it works.
  # def test_dqn_basic(self):
  #   self._test_hparams_skip_evaluation("rlmb_dqn_tiny")
class TrainTest(tf.test.TestCase):
  """Smoke tests of model-free training for the tiny hparams sets."""

  def _test_hparams_set(self, hparams_set):
    """Trains with the registered `hparams_set`, writing to a temp directory.

    Args:
      hparams_set: name of an hparams set known to the registry.
    """
    train_hparams = registry.hparams(hparams_set)
    temp_dir = tf.test.get_temp_dir()
    FLAGS.output_dir = temp_dir
    trainer_model_free.train(train_hparams, temp_dir, env_problem_name=None)

  def test_train_pong(self):
    """Trains with the "rlmf_tiny" hparams set."""
    self._test_hparams_set("rlmf_tiny")

  def test_train_pong_dqn(self):
    """Trains with the "rlmf_dqn_tiny" hparams set."""
    self._test_hparams_set("rlmf_dqn_tiny")
class TrainerModelFreeTicTacToeTest(tf.test.TestCase):
  """Smoke test of model-free training on the tic-tac-toe env problem."""

  def test_train_tictactoe(self):
    """Trains "rlmf_tictactoe" with a few hparams overridden for the test."""
    hparams = registry.hparams("rlmf_tictactoe")

    # Override existing hparams.
    hparams.batch_size = 2
    hparams.eval_sampling_temps = [0.0, 1.0]

    # Add hparams that are set via add_hparam rather than assignment.
    for new_name, new_value in (("ppo_epochs_num", 2),
                                ("ppo_epoch_length", 3)):
      hparams.add_hparam(new_name, new_value)

    hparams.epochs_num = 100
    hparams.eval_every_epochs = 25

    output_dir = tf.test.get_temp_dir()
    env_problem_name = "tic_tac_toe_env_problem"
    FLAGS.output_dir = output_dir
    FLAGS.env_problem_name = env_problem_name
    trainer_model_free.train(hparams, output_dir, env_problem_name)
15 | 16 | -------------------------------------------------------------------------------- /tensor2tensor/test_data/example_usr_dir/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Example T2T user directory.""" 17 | from . import my_submodule 18 | -------------------------------------------------------------------------------- /tensor2tensor/test_data/example_usr_dir/requirements.txt: -------------------------------------------------------------------------------- 1 | gutenberg 2 | -------------------------------------------------------------------------------- /tensor2tensor/test_data/transformer_test_ckpt/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt-1" 2 | all_model_checkpoint_paths: "model.ckpt-1" 3 | -------------------------------------------------------------------------------- /tensor2tensor/test_data/transformer_test_ckpt/flags.txt: -------------------------------------------------------------------------------- 1 | --eval_steps=1 2 | --hparams_range= 3 | --t2t_usr_dir= 4 | --enable_graph_rewriter=False 5 | --sync=False 6 | --eval_run_autoregressive=False 7 | --eval_use_test_set=False 8 | --worker_id=0 9 | --eval_early_stopping_metric_minimize=True 10 | 
--worker_replicas=1 11 | --random_seed=1234 12 | --worker_gpu_memory_fraction=0.95 13 | --train_steps=1 14 | --iterations_per_loop=1000 15 | --registry_help=False 16 | --worker_gpu=1 17 | --keep_checkpoint_max=20 18 | --save_checkpoints_secs=0 19 | --gpu_order= 20 | --master= 21 | --generate_data=False 22 | --local_eval_frequency=2000 23 | --export_saved_model=False 24 | --eval_early_stopping_steps=None 25 | --output_dir=/tmp/oss_train 26 | --profile=False 27 | --ps_job=/job:ps 28 | --tmp_dir=/tmp/t2t_datagen 29 | --schedule=continuous_train_and_eval 30 | --problem=translate_ende_wmt8k 31 | --hparams= 32 | --use_tpu=False 33 | --eval_early_stopping_metric_delta=0.1 34 | --ps_gpu=0 35 | --keep_checkpoint_every_n_hours=10000 36 | --decode_hparams= 37 | --tfdbg=False 38 | --data_dir=~/t2t/data 39 | --ps_replicas=0 40 | --eval_early_stopping_metric=loss 41 | --log_device_placement=False 42 | --hparams_set=transformer_test 43 | --dbgprofile=False 44 | --timit_paths= 45 | --tpu_num_shards=8 46 | --locally_shard_to_cpu=False 47 | --worker_job=/job:localhost 48 | --model=transformer 49 | --parsing_path= 50 | -------------------------------------------------------------------------------- /tensor2tensor/test_data/transformer_test_ckpt/hparams.json: -------------------------------------------------------------------------------- 1 | {"daisy_chain_variables": true, "optimizer_adam_beta1": 0.9, "scheduled_sampling_prob": 0.0, "num_hidden_layers": 2, "moe_loss_coef": 0.01, "max_target_seq_length": 0, "clip_grad_norm": 0.0, "pos": "timing", "scheduled_sampling_gold_mixin_prob": 0.5, "initializer": "uniform_unit_scaling", "grad_noise_scale": 0.0, "optimizer_momentum_momentum": 0.9, "nbr_decoder_problems": 1, "attention_key_channels": 0, "eval_drop_long_sequences": false, "learning_rate_cosine_cycle_steps": 250000, "prepend_mode": "none", "weight_decay": 0.0, "symbol_modality_skip_top": false, "weight_noise": 0.0, "target_modality": "default", "attention_dropout": 0.1, 
"parameter_attention_value_channels": 0, "factored_logits": false, "relu_dropout": 0.1, "no_data_parallelism": false, "layer_preprocess_sequence": "n", "sampling_method": "argmax", "learning_rate": 0.2, "num_heads": 2, "max_length": 256, "summarize_grads": false, "attention_value_channels": 0, "num_encoder_layers": 0, "label_smoothing": 0.1, "use_fixed_batch_size": false, "optimizer": "adam", "moe_k": 2, "self_attention_type": "dot_product", "learning_rate_decay_scheme": "noam", "sampling_temp": 1.0, "kernel_height": 3, "use_pad_remover": true, "batch_size": 4096, "max_relative_position": 0, "force_full_predict": false, "min_length_bucket": 8, "layer_prepostprocess_dropout": 0.1, "eval_run_autoregressive": false, "shared_embedding_and_softmax_weights": true, "symbol_modality_num_shards": 16, "dropout": 0.2, "compress_steps": 0, "parameter_attention_key_channels": 0, "length_bucket_step": 1.1, "kernel_width": 1, "hidden_size": 16, "num_decoder_layers": 0, "input_modalities": "default", "filter_size": 8, "optimizer_adam_beta2": 0.98, "scheduled_sampling_warmup_steps": 50000, "norm_type": "layer", "min_length": 0, "moe_num_experts": 64, "multiply_embedding_mode": "sqrt_depth", "max_input_seq_length": 0, "learning_rate_warmup_steps": 8000, "proximity_bias": false, "ffn_layer": "dense_relu_dense", "initializer_gain": 1.0, "layer_postprocess_sequence": "da", "moe_hidden_sizes": "2048", "optimizer_adam_epsilon": 1e-09, "norm_epsilon": 1e-06} 2 | -------------------------------------------------------------------------------- /tensor2tensor/test_data/transformer_test_ckpt/model.ckpt-1.data-00000-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/tensor2tensor/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/test_data/transformer_test_ckpt/model.ckpt-1.data-00000-of-00002 -------------------------------------------------------------------------------- 
/tensor2tensor/test_data/transformer_test_ckpt/model.ckpt-1.data-00001-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/tensor2tensor/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/test_data/transformer_test_ckpt/model.ckpt-1.data-00001-of-00002 -------------------------------------------------------------------------------- /tensor2tensor/test_data/transformer_test_ckpt/model.ckpt-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/tensor2tensor/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/test_data/transformer_test_ckpt/model.ckpt-1.index -------------------------------------------------------------------------------- /tensor2tensor/test_data/transformer_test_ckpt/model.ckpt-1.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/tensor2tensor/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/test_data/transformer_test_ckpt/model.ckpt-1.meta -------------------------------------------------------------------------------- /tensor2tensor/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class AdafactorTest(tf.test.TestCase):
  """Tests for AdafactorOptimizer."""

  def testCallableLearningRate(self):
    """Applies gradients with an optimizer whose learning rate is a callable.

    The test has no assertions; it passes when apply_gradients does not raise.
    """
    def lr():
      return 0.01

    optimizer = adafactor.AdafactorOptimizer(learning_rate=lr)
    variables = [tf.Variable([1., 2.]), tf.Variable([3., 4.])]
    with tf.GradientTape() as tape:
      tape.watch(variables)
      loss = variables[0] * variables[1]
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(tuple(zip(gradients, variables)))
def main(_):
  """Computes and saves video metrics for every decode output directory.

  One directory "decode_%05d" under FLAGS.output_dir is used per decode, for
  decode ids 0 .. num_decodes - 1 taken from the decode hparams.
  """
  hparams = t2t_decoder.create_hparams()
  video_problem = hparams.problem
  decode_hparams = t2t_decoder.create_decode_hparams()

  decode_dirs = []
  for decode_id in range(decode_hparams.num_decodes):
    decode_dirs.append(
        os.path.join(FLAGS.output_dir, "decode_%05d" % decode_id))

  frame_shape = [
      video_problem.frame_height,
      video_problem.frame_width,
      video_problem.num_channels,
  ]

  video_metrics.compute_and_save_video_metrics(
      decode_dirs,
      FLAGS.problem,
      hparams.video_num_target_frames,
      frame_shape)
class DietVarTest(tf.test.TestCase):
  """Checks that variables change when only gradients are evaluated."""

  def testDiet(self):
    """Runs ten steps and asserts that every global variable has changed."""

    params = diet.diet_adam_optimizer_params()

    # Two identical dense layers, each wrapped with the diet-vars decorator.
    @diet.fn_with_diet_vars(params)
    def model_fn(x):
      y = tf.layers.dense(x, 10, use_bias=False)
      return y

    @diet.fn_with_diet_vars(params)
    def model_fn2(x):
      y = tf.layers.dense(x, 10, use_bias=False)
      return y

    x = tf.random_uniform((10, 10))
    y = model_fn(x) + 10.
    y = model_fn2(y) + 10.
    # No optimizer op is created in this test; only gradients w.r.t. the input
    # are requested. NOTE(review): presumably the decorator applies the weight
    # updates while the gradients are computed - confirm against diet.py.
    grads = tf.gradients(y, [x])
    with tf.control_dependencies(grads):
      incr_step = tf.assign_add(tf.train.get_or_create_global_step(), 1)

    train_op = tf.group(incr_step, *grads)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      # Snapshot all global variables before and after ten runs of train_op.
      orig_vals = sess.run(tf.global_variables())
      for _ in range(10):
        sess.run(train_op)
      new_vals = sess.run(tf.global_variables())

    # Count the variables whose values are no longer close to the snapshot.
    different = []
    for old, new in zip(orig_vals, new_vals):
      try:
        self.assertAllClose(old, new)
      except AssertionError:
        different.append(True)
    # Every global variable must have changed.
    self.assertEqual(len(different), len(tf.global_variables()))
9 | perl $mosesdecoder/scripts/tokenizer/replace-unicode-punctuation.perl -l de < $decodes_file > $decodes_file.n 10 | 11 | # Tokenize. 12 | perl $mosesdecoder/scripts/tokenizer/tokenizer.perl -l de < $decodes_file.n > $decodes_file.tok 13 | 14 | # Put compounds in ATAT format (comparable to papers like GNMT, ConvS2S). 15 | # See https://nlp.stanford.edu/projects/nmt/ : 16 | # 'Also, for historical reasons, we split compound words, e.g., 17 | # "rich-text format" --> rich ##AT##-##AT## text format."' 18 | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < $tok_gold_targets > $tok_gold_targets.atat 19 | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < $decodes_file.tok > $decodes_file.tok.atat 20 | 21 | # Get BLEU. 22 | perl $mosesdecoder/scripts/generic/multi-bleu.perl $tok_gold_targets.atat < $decodes_file.tok.atat 23 | -------------------------------------------------------------------------------- /tensor2tensor/utils/hparams_lib_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class HparamsLibTest(tf.test.TestCase):
  """Tests for hparams_lib."""

  def testCreateHparamsFromJson(self):
    """Loads the test checkpoint's hparams.json and checks the hparam count."""
    # The test data lives one directory above the directory of this file.
    this_dir = os.path.dirname(os.path.abspath(__file__))
    package_dir = os.path.dirname(this_dir)
    json_path = os.path.join(
        package_dir, "test_data", "transformer_test_ckpt", "hparams.json")

    loaded = hparams_lib.create_hparams_from_json(json_path)
    self.assertEqual(75, len(loaded.values()))
# Camel case to snake case utils
# Inserts "_" before a capitalized word that follows any character.
_first_cap_re = re.compile("(.)([A-Z][a-z0-9]+)")
# Inserts "_" between a lowercase letter or digit and a following capital.
_all_cap_re = re.compile("([a-z0-9])([A-Z])")


def camelcase_to_snakecase(name):
  """Converts a CamelCase name to snake_case.

  Args:
    name: string such as "FooBar" or "HTTPServer".

  Returns:
    The lowercased name with underscores at word boundaries, e.g. "foo_bar"
    or "http_server".
  """
  s1 = _first_cap_re.sub(r"\1_\2", name)
  return _all_cap_re.sub(r"\1_\2", s1).lower()


def snakecase_to_camelcase(name):
  """Converts a snake_case name to CamelCase.

  Only the first character of each underscore-separated segment is
  upper-cased; the rest of the segment is kept as is. Empty segments (from an
  empty name, or leading, trailing or doubled underscores) contribute nothing.

  Args:
    name: string such as "foo_bar".

  Returns:
    The segments joined without underscores, e.g. "FooBar".
  """
  # w[:1] rather than w[0]: an empty segment must not raise IndexError.
  return "".join(w[:1].upper() + w[1:] for w in name.split("_"))


def pprint_hparams(hparams):
  """Represents hparams using its dictionary and calls pprint.pformat on it.

  Args:
    hparams: object whose values() method returns the data to format.

  Returns:
    A newline followed by the pformat output; width=1 puts each entry on its
    own line.
  """
  return "\n{}".format(pprint.pformat(hparams.values(), width=1))
class OptimizeTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for optimize.ConditionalOptimizer."""

  @parameterized.parameters(
      "sgd", "SGD",
      "rms_prop", "RMSProp",
      "adagrad", "Adagrad",
      "adam", "Adam",
      "adam_w", "AdamW",
  )
  def test_names(self, optimizer_name):
    """Constructs a ConditionalOptimizer for each accepted name spelling.

    There are no assertions; the test passes when construction does not raise.
    """
    basic_hparams = hparams_lib.create_hparams("basic_1")
    optimize.ConditionalOptimizer(optimizer_name, 0.1, basic_hparams)
class PartialCheckpointLoad(tf.train.SessionRunHook):
  """Partially load train_variables from a checkpoint.

  Hook used to load each variable saved in checkpoint into the graph. It
  will ignore any additional variables present in the graph that are not
  saved in the checkpoint. (Note: The loaded variables include ADAM/training
  variables, if they exist in the checkpoint)
  Can perform mapping if the base scopename for graph variables is different
  from the checkpoint variables.
  """

  def __init__(self, hook_context, chk_scopename, graph_scopename):
    """Initialize the hook with chkp directory and scopenames.

    Args:
      hook_context: HookContext object containing hparams. The checkpoint
        location is read from `hook_context.hparams.partial_load_checkpoint`.
      chk_scopename: Base scopename of variables in the checkpoint being loaded
      graph_scopename: Base scopename of variables in current graph
    """
    self.checkpoint_path = hook_context.hparams.partial_load_checkpoint
    self.chk_scopename = chk_scopename
    self.graph_scopename = graph_scopename

  def begin(self):
    """Maps every checkpoint variable onto a graph variable and loads it.

    Raises:
      KeyError: If a checkpoint variable, after substituting `chk_scopename`
        with `graph_scopename`, has no counterpart among the graph variables.
    """
    # TODO(karishmamalkan): Add logging for when variables are loaded
    # tf.global_variables() replaces the deprecated tf.all_variables() alias.
    variable_references = {var.name: var for var in tf.global_variables()}
    variable_mappings = {}
    for chk_name, _ in tf.train.list_variables(self.checkpoint_path):
      # Graph variable names carry an output-index suffix that checkpoint
      # names lack.
      graph_name = chk_name.replace(
          self.chk_scopename, self.graph_scopename) + ":0"
      if graph_name not in variable_references:
        raise KeyError(
            "Checkpoint variable %r maps to graph variable %r, which does not "
            "exist in the current graph." % (chk_name, graph_name))
      variable_mappings[chk_name] = variable_references[graph_name]
    tf.train.init_from_checkpoint(self.checkpoint_path, variable_mappings)
class RestoreHook(tf.train.SessionRunHook):
  """Session hook that initializes graph variables from a checkpoint path."""

  def __init__(self, checkpoint_path="", new_model_scope="", old_model_scope="",
               include=None, exclude=None):
    """Stores the checkpoint location and the variable selection settings.

    Args:
      checkpoint_path: Checkpoint to restore from.
      new_model_scope: Prefix stripped from graph variable names.
      old_model_scope: Prefix the stripped names must start with to be restored.
      include: Forwarded to `get_variables_to_restore`.
      exclude: Forwarded to `get_variables_to_restore`.
    """
    self._checkpoint_path = checkpoint_path
    self._new_model_scope = new_model_scope
    self._old_model_scope = old_model_scope
    self._include = include
    self._exclude = exclude

  def begin(self):
    """Load variables from checkpoint.

    New model variables have the following name format:
    new_model_scope/old_model_scope/xxx/xxx:0. To build the map from
    checkpoint name to variable, the new_model_scope prefix is stripped, the
    remainder must start with old_model_scope, and the ":0" suffix is removed.
    """
    candidates = contrib.framework().get_variables_to_restore(
        include=self._include, exclude=self._exclude)
    new_scope = self._new_model_scope
    old_scope = self._old_model_scope

    assignment_map = {}
    for variable in candidates:
      full_name = variable.name
      # Only variables living under the new model scope are considered.
      if not full_name.startswith(new_scope):
        continue
      stripped_name = full_name[len(new_scope):]
      # After stripping, the name has to sit under the old model scope.
      if not stripped_name.startswith(old_scope):
        continue
      # Drop the ":0" suffix.
      assignment_map[stripped_name.split(":")[0]] = variable
    self._assignment_map = assignment_map

    message = "restoring %d variables from checkpoint %s" % (
        len(assignment_map), self._checkpoint_path)
    tf.logging.info(message)
    tf.train.init_from_checkpoint(self._checkpoint_path, self._assignment_map)
# Eager execution is enabled at import time, before any test runs.
tf.enable_eager_execution()


class RunInGraphAndEagerTest(tf.test.TestCase):
  """Tests for the test_utils.run_in_graph_and_eager_modes decorator."""

  def test_run_in_graph_and_eager_modes(self):
    """Both decorator spellings should run the function in graph and eager."""
    calls = []

    def record_call(self, with_brackets):
      del self  # self argument is required by run_in_graph_and_eager_modes.
      if tf.executing_eagerly():
        mode = "eager"
      else:
        mode = "graph"
      label = "with_brackets" if with_brackets else "without_brackets"
      calls.append((label, mode))

    # Decorator applied directly, without parentheses.
    bare = test_utils.run_in_graph_and_eager_modes(record_call)
    bare(self, with_brackets=False)
    # Decorator obtained by calling the factory first.
    bracketed = test_utils.run_in_graph_and_eager_modes()(record_call)
    bracketed(self, with_brackets=True)

    expected = {
        ("with_brackets", "graph"),
        ("with_brackets", "eager"),
        ("without_brackets", "graph"),
        ("without_brackets", "eager"),
    }
    self.assertEqual(len(calls), 4)
    self.assertEqual(set(calls), expected)

  def test_run_in_graph_and_eager_modes_setup_in_same_mode(self):
    """setUp and the test body should be recorded in matching modes."""
    events = []

    def current_mode():
      return "eager" if tf.executing_eagerly() else "graph"

    class ExampleTest(tf.test.TestCase):

      def runTest(self):
        pass

      def setUp(self):
        events.append("setup_" + current_mode())

      @test_utils.run_in_graph_and_eager_modes
      def testBody(self):
        events.append("run_" + current_mode())

    example = ExampleTest()
    example.setUp()
    example.testBody()

    self.assertEqual(events[:2], ["setup_eager", "run_eager"])
    self.assertEqual(events[2:], ["setup_graph", "run_graph"])


if __name__ == "__main__":
  tf.test.main()
class UpdateOpsHook(tf.train.SessionRunHook):
  """Hook that asks the session to run the tf.GraphKeys.UPDATE_OPS ops."""

  def before_run(self, run_context):
    """Returns SessionRunArgs fetching the current UPDATE_OPS collection.

    Args:
      run_context: Unused.

    Returns:
      A tf.train.SessionRunArgs wrapping the ops found in the
      tf.GraphKeys.UPDATE_OPS collection at call time.
    """
    del run_context  # Unused.
    return tf.train.SessionRunArgs(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
# Name under which the user directory is pip-installed for Cloud ML Engine.
INTERNAL_USR_DIR_PACKAGE = "t2t_usr_dir_internal"


def import_usr_dir(usr_dir):
  """Import module at usr_dir, if provided.

  The parent of `usr_dir` is temporarily put at the front of `sys.path` so that
  the directory can be imported as a top-level module by its base name.

  Args:
    usr_dir: Path to the user module directory, the special value
      INTERNAL_USR_DIR_PACKAGE, or a falsy value to do nothing.

  Raises:
    ImportError: If the module cannot be imported. `sys.path` is restored
      before the error propagates.
  """
  if not usr_dir:
    return
  if usr_dir == INTERNAL_USR_DIR_PACKAGE:
    # The package has been installed with pip under this name for Cloud ML
    # Engine so just import it.
    importlib.import_module(INTERNAL_USR_DIR_PACKAGE)
    return

  dir_path = os.path.abspath(os.path.expanduser(usr_dir).rstrip("/"))
  containing_dir, module_name = os.path.split(dir_path)
  tf.logging.info("Importing user module %s from path %s", module_name,
                  containing_dir)
  sys.path.insert(0, containing_dir)
  try:
    importlib.import_module(module_name)
  finally:
    # Remove our entry by value rather than popping index 0: the imported
    # module may itself have prepended to sys.path, and the entry must be
    # removed even when the import fails.
    try:
      sys.path.remove(containing_dir)
    except ValueError:
      pass  # The imported module already removed the entry.
class VideoMetricsTest(tf.test.TestCase):
  """Tests for tensor2tensor.utils.video_metrics."""

  def test_reduce_to_best_decode(self):
    """Checks the per-sample decode selection for np.argmax and np.argmin."""
    # num_decodes=2, num_samples=3, num_frames=4
    all_decodes = np.array([
        [[30.0, 32.0, 33.0, 34.0],
         [22.0, 19.0, 12.0, 13.0],
         [30.0, 10.0, 30.0, 10.0]],
        [[22.0, 19.0, 12.0, 13.0],
         [30.0, 32.0, 33.0, 34.0],
         [25.0, 25.0, 25.0, 25.0]],
    ])

    best, best_ind = video_metrics.reduce_to_best_decode(
        all_decodes, np.argmax)
    worst, worst_ind = video_metrics.reduce_to_best_decode(
        all_decodes, np.argmin)

    expected_best = [
        [30.0, 32.0, 33.0, 34.0],
        [30.0, 32.0, 33.0, 34.0],
        [25.0, 25.0, 25.0, 25.0]]
    expected_worst = [
        [22.0, 19.0, 12.0, 13.0],
        [22.0, 19.0, 12.0, 13.0],
        [30.0, 10.0, 30.0, 10.0]]

    self.assertTrue(np.allclose(best, expected_best))
    self.assertTrue(np.allclose(worst, expected_worst))
    self.assertTrue(np.allclose(best_ind, [0, 1, 1]))
    self.assertTrue(np.allclose(worst_ind, [1, 0, 0]))


if __name__ == '__main__':
  tf.test.main()
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | --------------------------------------------------------------------------------