├── .gitignore
├── .travis.yml
├── AUTHORS
├── CONTRIBUTING.md
├── ISSUE_TEMPLATE.md
├── LICENSE
├── README.md
├── docs
├── cloud_mlengine.md
├── cloud_tpu.md
├── distributed_training.md
├── index.md
├── multi_problem.md
├── new_model.md
├── new_problem.md
├── overview.md
├── tutorials
│ └── asr_with_transformer.md
└── walkthrough.md
├── floyd.yml
├── floyd_requirements.txt
├── oss_scripts
├── oss_integration_test.sh
├── oss_pip_install.sh
├── oss_release.sh
└── oss_tests.sh
├── pylintrc
├── setup.py
└── tensor2tensor
├── __init__.py
├── bin
├── __init__.py
├── build_vocab.py
├── make_tf_configs.py
├── t2t-avg-all
├── t2t-bleu
├── t2t-datagen
├── t2t-decoder
├── t2t-eval
├── t2t-exporter
├── t2t-insights-server
├── t2t-make-tf-configs
├── t2t-query-server
├── t2t-trainer
├── t2t-translate-all
├── t2t_attack.py
├── t2t_avg_all.py
├── t2t_bleu.py
├── t2t_datagen.py
├── t2t_decoder.py
├── t2t_distill.py
├── t2t_eval.py
├── t2t_prune.py
├── t2t_trainer.py
├── t2t_trainer_test.py
└── t2t_translate_all.py
├── data_generators
├── README.md
├── __init__.py
├── algorithmic.py
├── algorithmic_math.py
├── algorithmic_math_deepmind.py
├── algorithmic_math_test.py
├── algorithmic_math_two_variables.py
├── algorithmic_test.py
├── all_problems.py
├── allen_brain.py
├── allen_brain_test.py
├── audio.py
├── audio_encoder.py
├── audio_test.py
├── babi_qa.py
├── bair_robot_pushing.py
├── celeba.py
├── celeba_test.py
├── celebahq.py
├── cifar.py
├── cipher.py
├── cleaner_en_xx.py
├── cnn_dailymail.py
├── cola.py
├── common_voice.py
├── common_voice_test.py
├── conll_ner.py
├── desc2code.py
├── desc2code_test.py
├── dialog_abstract.py
├── dialog_cornell.py
├── dialog_dailydialog.py
├── dialog_opensubtitles.py
├── dialog_personachat.py
├── dna_encoder.py
├── dna_encoder_test.py
├── enwik8.py
├── fsns.py
├── function_docstring.py
├── gene_expression.py
├── gene_expression_test.py
├── generator_utils.py
├── generator_utils_test.py
├── google_robot_pushing.py
├── gym_env.py
├── gym_env_test.py
├── ice_parsing.py
├── image_lsun.py
├── image_utils.py
├── image_utils_test.py
├── imagenet.py
├── imagenet_test.py
├── imdb.py
├── inspect_tfrecord.py
├── lambada.py
├── librispeech.py
├── lm1b.py
├── lm1b_imdb.py
├── lm1b_mnli.py
├── mnist.py
├── moving_mnist.py
├── mrpc.py
├── mscoco.py
├── mscoco_test.py
├── multi_problem.py
├── multi_problem_v2.py
├── multi_problem_v2_test.py
├── multinli.py
├── ocr.py
├── ops
│ ├── pack_sequences_ops.cc
│ ├── pack_sequences_ops_test.py
│ ├── subword_text_encoder.cc
│ ├── subword_text_encoder.h
│ ├── subword_text_encoder_ops.cc
│ ├── subword_text_encoder_ops_test.py
│ ├── subword_text_encoder_test.cc
│ └── testdata
│ │ └── subwords
├── paraphrase_ms_coco.py
├── paraphrase_ms_coco_test.py
├── pointer_generator_word.py
├── problem.py
├── problem_hparams.py
├── problem_test.py
├── program_search.py
├── program_search_test.py
├── ptb.py
├── qnli.py
├── quora_qpairs.py
├── rte.py
├── scitail.py
├── seq2edits.py
├── snli.py
├── speech_recognition.py
├── squad.py
├── sst_binary.py
├── stanford_nli.py
├── style_transfer.py
├── style_transfer_test.py
├── subject_verb_agreement.py
├── test_data
│ ├── 1.csv
│ ├── corpus-1.txt
│ ├── corpus-2.txt
│ ├── vocab-1.txt
│ └── vocab-2.txt
├── text_encoder.py
├── text_encoder_build_subword.py
├── text_encoder_test.py
├── text_problems.py
├── text_problems_test.py
├── timeseries.py
├── timeseries_data_generator.py
├── timeseries_data_generator_test.py
├── timeseries_test.py
├── tokenizer.py
├── tokenizer_test.py
├── transduction_problems.py
├── transduction_problems_test.py
├── translate.py
├── translate_encs.py
├── translate_encs_cubbitt.py
├── translate_ende.py
├── translate_ende_test.py
├── translate_enes.py
├── translate_enet.py
├── translate_enfr.py
├── translate_enid.py
├── translate_enmk.py
├── translate_enro.py
├── translate_entn.py
├── translate_envi.py
├── translate_enzh.py
├── translate_test.py
├── video_generated.py
├── video_utils.py
├── video_utils_test.py
├── vqa.py
├── vqa_utils.py
├── wiki.py
├── wiki_lm.py
├── wiki_multi_problems.py
├── wiki_revision.py
├── wiki_revision_utils.py
├── wikifact
│ └── README.md
├── wikisum
│ ├── README.md
│ ├── __init__.py
│ ├── delete_instances.sh
│ ├── generate_vocab.py
│ ├── get_references_commoncrawl.py
│ ├── get_references_web.py
│ ├── get_references_web_single_group.py
│ ├── html.py
│ ├── parallel_launch.py
│ ├── produce_examples.py
│ ├── test_data
│ │ ├── para_bad1.txt
│ │ └── para_good1.txt
│ ├── utils.py
│ ├── utils_test.py
│ ├── validate_data.py
│ └── wikisum.py
├── wikitext103.py
├── wnli.py
├── wsj_parsing.py
├── yelp_full.py
└── yelp_polarity.py
├── envs
├── __init__.py
├── env_problem.py
├── env_problem_utils.py
├── env_problem_utils_test.py
├── gym_env_problem.py
├── gym_env_problem_test.py
├── gym_spaces_utils.py
├── gym_spaces_utils_test.py
├── mujoco_problems.py
├── mujoco_problems_test.py
├── rendered_env_problem.py
├── rendered_env_problem_test.py
├── tic_tac_toe_env.py
├── tic_tac_toe_env_problem.py
├── tic_tac_toe_env_problem_test.py
├── tic_tac_toe_env_test.py
├── time_step.py
├── time_step_test.py
├── trajectory.py
└── trajectory_test.py
├── insights
├── README.md
├── __init__.py
├── graph.py
├── insight_configuration.proto
├── polymer
│ ├── .bowerrc
│ ├── attention_visualization
│ │ ├── attention-visualization.html
│ │ └── attention-visualization.js
│ ├── bower.json
│ ├── common-types.js
│ ├── explore_view
│ │ ├── explore-view.html
│ │ └── explore-view.js
│ ├── graph_visualization
│ │ ├── graph-visualization.html
│ │ └── graph-visualization.js
│ ├── index.html
│ ├── insights_app
│ │ ├── insights-app.html
│ │ └── insights-app.js
│ ├── language_selector
│ │ ├── language-selector-content.html
│ │ ├── language-selector-content.js
│ │ ├── language-selector.html
│ │ └── language-selector.js
│ ├── processing_visualization
│ │ ├── processing-visualization.html
│ │ └── processing-visualization.js
│ ├── query_card
│ │ ├── query-card.html
│ │ └── query-card.js
│ ├── tensor2tensor.html
│ └── translation_result
│ │ ├── translation-result.html
│ │ └── translation-result.js
├── query_processor.py
├── server.py
└── transformer_model.py
├── layers
├── __init__.py
├── area_attention.py
├── area_attention_test.py
├── common_attention.py
├── common_attention_test.py
├── common_audio.py
├── common_hparams.py
├── common_image_attention.py
├── common_image_attention_test.py
├── common_layers.py
├── common_layers_test.py
├── common_video.py
├── common_video_test.py
├── discretization.py
├── discretization_test.py
├── latent_layers.py
├── latent_layers_test.py
├── message_passing_attention.py
├── modalities.py
├── modalities_test.py
├── ngram.py
├── ngram_test.py
├── transformer_glow_layers.py
├── transformer_glow_layers_ops.py
├── transformer_glow_layers_ops_test.py
├── transformer_glow_layers_test.py
├── transformer_layers.py
├── transformer_memory.py
├── transformer_memory_test.py
├── vq_discrete.py
└── vqa_layers.py
├── metrics
├── __init__.py
├── video_conditional_fvd.py
└── video_conditional_fvd_test.py
├── models
├── README.md
├── __init__.py
├── basic.py
├── basic_test.py
├── bytenet.py
├── bytenet_test.py
├── distillation.py
├── evolved_transformer.py
├── evolved_transformer_test.py
├── image_transformer.py
├── image_transformer_2d.py
├── image_transformer_2d_test.py
├── image_transformer_test.py
├── lstm.py
├── lstm_test.py
├── mtf_image_transformer.py
├── mtf_image_transformer_test.py
├── mtf_resnet.py
├── mtf_transformer.py
├── mtf_transformer2.py
├── mtf_transformer_test.py
├── neural_architecture_search
│ ├── README.md
│ ├── __init__.py
│ ├── nas_layers.py
│ ├── nas_layers_test.py
│ ├── nas_model.py
│ └── nas_model_test.py
├── neural_assistant.py
├── neural_gpu.py
├── neural_gpu_test.py
├── research
│ ├── __init__.py
│ ├── adafactor_experiments.py
│ ├── aligned.py
│ ├── attention_lm.py
│ ├── attention_lm_moe.py
│ ├── autoencoders.py
│ ├── autoencoders_test.py
│ ├── cycle_gan.py
│ ├── gene_expression.py
│ ├── gene_expression_test.py
│ ├── glow.py
│ ├── glow_init_hook.py
│ ├── glow_ops.py
│ ├── glow_ops_test.py
│ ├── glow_test.py
│ ├── lm_experiments.py
│ ├── moe.py
│ ├── moe_experiments.py
│ ├── multiquery_paper.py
│ ├── neural_stack.py
│ ├── neural_stack_test.py
│ ├── residual_shuffle_exchange.py
│ ├── rl.py
│ ├── shuffle_network.py
│ ├── similarity_transformer.py
│ ├── super_lm.py
│ ├── transformer_aux.py
│ ├── transformer_aux_test.py
│ ├── transformer_moe.py
│ ├── transformer_nat.py
│ ├── transformer_parallel.py
│ ├── transformer_revnet.py
│ ├── transformer_revnet_test.py
│ ├── transformer_seq2edits.py
│ ├── transformer_sketch.py
│ ├── transformer_symshard.py
│ ├── transformer_vae.py
│ ├── transformer_vae_flow_prior.py
│ ├── transformer_vae_flow_prior_ops.py
│ ├── transformer_vae_test.py
│ ├── universal_transformer.py
│ ├── universal_transformer_test.py
│ ├── universal_transformer_util.py
│ ├── vqa_attention.py
│ ├── vqa_attention_test.py
│ ├── vqa_recurrent_self_attention.py
│ └── vqa_self_attention.py
├── resnet.py
├── resnet_test.py
├── revnet.py
├── revnet_test.py
├── shake_shake.py
├── slicenet.py
├── slicenet_test.py
├── text_cnn.py
├── transformer.py
├── transformer_test.py
├── vanilla_gan.py
├── video
│ ├── __init__.py
│ ├── base.py
│ ├── base_vae.py
│ ├── basic_deterministic.py
│ ├── basic_deterministic_params.py
│ ├── basic_deterministic_test.py
│ ├── basic_recurrent.py
│ ├── basic_recurrent_test.py
│ ├── basic_stochastic.py
│ ├── basic_stochastic_test.py
│ ├── emily.py
│ ├── emily_test.py
│ ├── epva.py
│ ├── epva_params.py
│ ├── next_frame_glow.py
│ ├── nfg_conv3d_test.py
│ ├── nfg_conv_lstm_test.py
│ ├── nfg_conv_test.py
│ ├── nfg_interpolate.py
│ ├── nfg_test_utils.py
│ ├── nfg_uncond_test.py
│ ├── savp.py
│ ├── savp_params.py
│ ├── savp_test.py
│ ├── sv2p.py
│ ├── sv2p_params.py
│ ├── sv2p_test.py
│ └── tests_utils.py
├── xception.py
└── xception_test.py
├── notebooks
├── Transformer_translate.ipynb
├── asr_transformer.ipynb
├── hello_t2t-rl.ipynb
├── hello_t2t.ipynb
└── t2t_problem.ipynb
├── problems.py
├── problems_colab.py
├── problems_test.py
├── rl
├── README.md
├── __init__.py
├── batch_dqn_agent_test.py
├── batch_runner_test.py
├── datagen_with_agent.py
├── dopamine_connector.py
├── envs
│ ├── __init__.py
│ ├── in_graph_batch_env.py
│ ├── py_func_batch_env.py
│ ├── simulated_batch_env.py
│ ├── simulated_batch_gym_env.py
│ └── tf_atari_wrappers.py
├── evaluator.py
├── evaluator_test.py
├── gym_utils.py
├── gym_utils_test.py
├── player.py
├── player_utils.py
├── policy_learner.py
├── ppo.py
├── ppo_learner.py
├── restarter.py
├── restarter_test.py
├── rl_utils.py
├── trainer_model_based.py
├── trainer_model_based_agent_only.py
├── trainer_model_based_params.py
├── trainer_model_based_recurrent_test.py
├── trainer_model_based_stochastic_test.py
├── trainer_model_based_sv2p_test.py
├── trainer_model_based_test.py
├── trainer_model_free.py
├── trainer_model_free_test.py
└── trainer_model_free_tictactoe_test.py
├── serving
├── README.md
├── __init__.py
├── export.py
├── query.py
└── serving_utils.py
├── test_data
├── example_usr_dir
│ ├── __init__.py
│ ├── my_submodule.py
│ └── requirements.txt
├── transformer_test_ckpt
│ ├── checkpoint
│ ├── flags.txt
│ ├── hparams.json
│ ├── model.ckpt-1.data-00000-of-00002
│ ├── model.ckpt-1.data-00001-of-00002
│ ├── model.ckpt-1.index
│ └── model.ckpt-1.meta
├── vocab.translate_ende_wmt32k.32768.subwords
└── vocab.translate_ende_wmt8k.8192.subwords
├── utils
├── __init__.py
├── adafactor.py
├── adafactor_test.py
├── adv_attack_utils.py
├── avg_checkpoints.py
├── beam_search.py
├── beam_search_test.py
├── bleu_hook.py
├── bleu_hook_test.py
├── checkpoint_compatibility_test.py
├── cloud_mlengine.py
├── compute_video_metrics.py
├── contrib.py
├── data_reader.py
├── data_reader_test.py
├── decoding.py
├── devices.py
├── diet.py
├── diet_test.py
├── expert_utils.py
├── expert_utils_test.py
├── flags.py
├── get_cnndm_rouge.sh
├── get_ende_bleu.sh
├── get_rouge.py
├── hparam.py
├── hparam_test.py
├── hparams_lib.py
├── hparams_lib_test.py
├── learning_rate.py
├── metrics.py
├── metrics_hook.py
├── metrics_hook_test.py
├── metrics_test.py
├── misc_utils.py
├── misc_utils_test.py
├── mlperf_log.py
├── mlperf_tags.py
├── mtf_model.py
├── multistep_optimizer.py
├── multistep_optimizer_test.py
├── multistep_with_adamoptimizer.py
├── multistep_with_adamoptimizer_test.py
├── optimize.py
├── optimize_test.py
├── partial_checkpoint_load_hook.py
├── pruning_utils.py
├── quantization.py
├── registry.py
├── registry_test.py
├── restore_hook.py
├── rouge.py
├── rouge_test.py
├── sari_hook.py
├── sari_hook_test.py
├── scheduled_sampling.py
├── t2t_model.py
├── t2t_model_test.py
├── test_utils.py
├── test_utils_test.py
├── trainer_lib.py
├── trainer_lib_test.py
├── update_ops_hook.py
├── usr_dir.py
├── video
│ ├── prediction2gif.py
│ └── reward_confusion.py
├── video2gif.py
├── video_metrics.py
├── video_metrics_test.py
├── yellowfin.py
└── yellowfin_test.py
└── visualization
├── TransformerVisualization.ipynb
├── __init__.py
├── attention.js
├── attention.py
├── visualization.py
└── visualization_test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled python modules.
2 | *.pyc
3 |
4 | # Byte-compiled
5 | __pycache__/
6 | .cache/
7 |
8 | # Python egg metadata, regenerated from source files by setuptools.
9 | /*.egg-info
10 | .eggs/
11 |
12 | # PyPI distribution artifacts.
13 | build/
14 | dist/
15 |
16 | # Sublime project files
17 | *.sublime-project
18 | *.sublime-workspace
19 |
20 | # Tests
21 | .pytest_cache/
22 |
23 | # Other
24 | *.DS_Store
25 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | sudo: required
2 | language: python
3 | cache: pip
4 | git:
5 | depth: 3
6 | quiet: true
7 | services:
8 | - docker
9 | python:
10 | - "3.6"
11 | env:
12 | global:
13 | - T2T_PROBLEM=algorithmic_reverse_binary40_test
14 | - T2T_DATA_DIR=/tmp/t2t-data
15 | - T2T_TRAIN_DIR=/tmp/t2t-train
16 | - TF_LATEST="1.15.*"
17 | # This is necessary to have gsutil work with Python 2.7
18 | - BOTO_CONFIG=/dev/null
19 | matrix:
20 | - TF_VERSION="1.15.*"
21 | install:
22 | - ./oss_scripts/oss_pip_install.sh
23 | script:
24 | - ./oss_scripts/oss_tests.sh
25 | - ./oss_scripts/oss_integration_test.sh
26 |
27 | # Conditional commands should each be in a separate block to get proper
28 | # errors on Travis.
29 | #
30 | # TODO(afrozm): Re-enable if this becomes an issue.
31 | # - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
32 | # pylint -j 2 tensor2tensor;
33 | # fi
34 |
--------------------------------------------------------------------------------
/AUTHORS:
--------------------------------------------------------------------------------
1 | # This is the list of T2T authors for copyright purposes.
2 | #
3 | # This does not necessarily list everyone who has contributed code, since in
4 | # some cases, their employer may be the copyright holder. To see the full list
5 | # of contributors, see the revision history in source control.
6 |
7 | Google Inc.
8 | Artit Wangperawong
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | # Issues
4 |
5 | * Please tag your issue with `bug`, `feature request`, or `question` to help us
6 | effectively respond.
7 | * Please include the versions of TensorFlow and Tensor2Tensor you are running
8 | (run `pip list | grep tensor`)
9 | * Please provide the command line you ran as well as the log output.
10 |
11 | # Pull Requests
12 |
13 | We'd love to accept your patches and contributions to this project. There are
14 | just a few small guidelines you need to follow.
15 |
16 | ## Contributor License Agreement
17 |
18 | Contributions to this project must be accompanied by a Contributor License
19 | Agreement. You (or your employer) retain the copyright to your contribution,
20 | this simply gives us permission to use and redistribute your contributions as
21 | part of the project. Head over to <https://cla.developers.google.com/> to see
22 | your current agreements on file or to sign a new one.
23 |
24 | You generally only need to submit a CLA once, so if you've already submitted one
25 | (even if it was for a different project), you probably don't need to do it
26 | again.
27 |
28 | ## Code reviews
29 |
30 | All submissions, including submissions by project members, require review. We
31 | use GitHub pull requests for this purpose. Consult
32 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
33 | information on using pull requests.
34 |
--------------------------------------------------------------------------------
/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ### Description
2 |
3 | ...
4 |
5 | ### Environment information
6 |
7 | ```
8 | OS:
9 |
10 | $ pip freeze | grep tensor
11 | # your output here
12 |
13 | $ python -V
14 | # your output here
15 | ```
16 |
17 | ### For bugs: reproduction and error logs
18 |
19 | ```
20 | # Steps to reproduce:
21 | ...
22 | ```
23 |
24 | ```
25 | # Error logs:
26 | ...
27 | ```
28 |
--------------------------------------------------------------------------------
/docs/cloud_tpu.md:
--------------------------------------------------------------------------------
1 | # Running on Cloud TPUs
2 |
3 | Tensor2Tensor supports running on Google Cloud Platform TPUs, chips
4 | specialized for ML training. See the official tutorials for [running the
5 | T2T Transformer for text on Cloud TPUs](https://cloud.google.com/tpu/docs/tutorials/transformer) and
6 | [Transformer for Speech Recognition](https://cloud.google.com/tpu/docs/tutorials/automated-speech-recognition).
7 |
8 | ## Other models on TPU
9 |
10 | Many of Tensor2Tensor's models work on TPU.
11 |
12 | You can provision a VM and TPU with `ctpu up`. Use the `t2t-trainer` command
13 | on the VM as usual with the additional flags `--use_tpu` and
14 | `--cloud_tpu_name=$TPU_NAME`.
15 |
16 | Note that because the `TPUEstimator` does not catch the `OutOfRangeError`
17 | during evaluation, you should ensure that `--eval_steps` is small enough to
18 | not exhaust the evaluation data.
19 |
20 | A non-exhaustive list of T2T models that work on TPU:
21 |
22 | * Image generation: `imagetransformer` with `imagetransformer_base_tpu` (or
23 | `imagetransformer_tiny_tpu`)
24 | * Super-resolution: `img2img_transformer` with `img2img_transformer_base_tpu`
25 | (or `img2img_transformer_tiny_tpu`)
26 | * `resnet` with `resnet_50` (or `resnet_18` or `resnet_34`)
27 | * `revnet` with `revnet_104` (or `revnet_38_cifar`)
28 | * `shake_shake` with `shakeshake_tpu` (or `shakeshake_small`)
29 |
30 | ## Example invocation
31 |
32 | Use `ctpu up` to bring up the VM and TPU machines; once the machines are ready
33 | it will SSH you into the VM and you can run the following:
34 |
35 | ```
36 | # DATA_DIR and OUT_DIR should be GCS buckets
37 | # TPU_NAME should have been set automatically by the ctpu tool
38 |
39 | t2t-trainer \
40 | --model=shake_shake \
41 | --hparams_set=shakeshake_tpu \
42 | --problem=image_cifar10 \
43 | --train_steps=180000 \
44 | --eval_steps=9 \
45 | --local_eval_frequency=100 \
46 | --data_dir=$DATA_DIR \
47 | --output_dir=$OUT_DIR \
48 | --use_tpu \
49 | --cloud_tpu_name=$TPU_NAME
50 | ```
51 |
--------------------------------------------------------------------------------
/docs/tutorials/asr_with_transformer.md:
--------------------------------------------------------------------------------
1 | # Automated Speech Recognition with the Transformer model
2 |
3 | See the
4 | [official tutorial](https://cloud.google.com/tpu/docs/tutorials/automated-speech-recognition).
5 |
--------------------------------------------------------------------------------
/floyd.yml:
--------------------------------------------------------------------------------
1 | env: tensorflow-1.12
2 | machine: gpu
3 |
--------------------------------------------------------------------------------
/floyd_requirements.txt:
--------------------------------------------------------------------------------
1 | tensor2tensor
2 |
--------------------------------------------------------------------------------
/oss_scripts/oss_integration_test.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Note that this test script requires docker to be installed and running.

set -v  # print commands as they're executed
set -e  # fail and exit on any command erroring

: "${TF_VERSION:?}"
: "${TF_LATEST:?}"
: "${T2T_DATA_DIR:?}"
: "${T2T_TRAIN_DIR:?}"
: "${T2T_PROBLEM:?}"

# Test --t2t_usr_dir.
# The grep must be the last command of a standalone pipeline. In
# `a | grep x && echo passed`, a grep failure is exempt from `set -e`, so the
# check could never fail the build.
t2t-trainer --registry_help --t2t_usr_dir=./tensor2tensor/test_data/example_usr_dir 2>&1 \
  | grep my_very_own_hparams
echo passed

# Run data generation, training, and decoding on a dummy problem
t2t-datagen --problem="$T2T_PROBLEM" --data_dir="$T2T_DATA_DIR"
t2t-trainer --problem="$T2T_PROBLEM" --data_dir="$T2T_DATA_DIR" --model=transformer --hparams_set=transformer_tiny --train_steps=5 --eval_steps=5 --output_dir="$T2T_TRAIN_DIR"
t2t-decoder --problem="$T2T_PROBLEM" --data_dir="$T2T_DATA_DIR" --model=transformer --hparams_set=transformer_tiny --output_dir="$T2T_TRAIN_DIR" --decode_hparams='num_samples=10'

# Test serving (only against the latest supported TF version).
if [[ "$TF_VERSION" == "$TF_LATEST" ]]; then
  # Export for serving
  pip install tensorflow_hub
  t2t-exporter \
    --problem="$T2T_PROBLEM" \
    --data_dir="$T2T_DATA_DIR" \
    --model=transformer \
    --hparams_set=transformer_tiny \
    --output_dir="$T2T_TRAIN_DIR"

  # Run model server
  server_port=8500
  model_name=my_model
  docker run -d -p "$server_port:$server_port" \
    --mount "type=bind,source=$T2T_TRAIN_DIR/export,target=/models/$model_name" \
    -e "MODEL_NAME=$model_name" -t tensorflow/serving
  # Give the server time to load the exported model before querying it.
  sleep 10

  # Query
  pip install "tensorflow-serving-api==$TF_VERSION"
  t2t-query-server \
    --server="localhost:$server_port" \
    --servable_name="$model_name" \
    --problem="$T2T_PROBLEM" \
    --data_dir="$T2T_DATA_DIR" \
    --inputs_once='1 0 1 0 1 0'
fi
51 |
--------------------------------------------------------------------------------
/oss_scripts/oss_pip_install.sh:
--------------------------------------------------------------------------------
#!/bin/bash

set -v  # print commands as they're executed
set -e  # fail and exit on any command erroring

: "${TF_VERSION:?}"

# Make sure we have the latest pip and setuptools installed.
pip install -q -U pip
pip install -q -U setuptools

# Make sure we have the latest version of numpy - avoid problems we were
# seeing with Python 3
pip install -q -U numpy
pip install -q "tensorflow==$TF_VERSION"

# Just print the version again to make sure.
python -c 'import tensorflow as tf; print(tf.__version__)'

# First ensure that the base dependencies are sufficient for a full import.
pip install -q -e .
# Discard stdout only; stderr stays visible for debugging.
# This is what the original `2>&1 >/dev/null` ordering actually did.
t2t-trainer --registry_help >/dev/null
# The grep must be the last command of a standalone pipeline so that `set -e`
# aborts when translate_ende is not registered. A failure on the left side of
# `&& echo passed` is silently ignored.
t2t-datagen 2>&1 | grep -q translate_ende
echo passed

# Then install the test dependencies. The extras spec is quoted so the shell
# cannot glob-expand the brackets.
pip install -q -e ".[tests,allen]"
# Make sure to install the atari extras for gym
pip install "gym[atari]"
29 |
--------------------------------------------------------------------------------
/oss_scripts/oss_release.sh:
--------------------------------------------------------------------------------
#!/bin/bash

set -v  # print commands as they're executed
set -e  # fail and exit on any command erroring

# Usage: oss_release.sh <git-commit-id>
GIT_COMMIT_ID=${1:-""}
if [[ -z "$GIT_COMMIT_ID" ]]; then
  echo "Must provide a commit" >&2
  exit 1
fi

TMP_DIR=$(mktemp -d)
# Remove the scratch checkout on every exit path, including failures under
# `set -e` (previously the directory leaked when any step failed).
trap 'rm -rf "$TMP_DIR"' EXIT
pushd "$TMP_DIR"

echo "Cloning tensor2tensor and checking out commit $GIT_COMMIT_ID"
git clone https://github.com/tensorflow/tensor2tensor.git
cd tensor2tensor
git checkout "$GIT_COMMIT_ID"

# Without `python -m` we sometimes get module not callable error:
# https://stackoverflow.com/questions/58451650/pip-no-longer-working-after-update-error-module-object-is-not-callable
python -m pip install wheel twine pyopenssl

# Build the distribution
echo "Building distribution"
python setup.py sdist
python setup.py bdist_wheel --universal

# Publish to PyPI
echo "Publishing to PyPI"
twine upload dist/*

# Cleanup of build artifacts; the EXIT trap removes $TMP_DIR itself.
rm -rf build/ dist/ tensor2tensor.egg-info/
popd
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/tensor2tensor/bin/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/tensor2tensor/bin/build_vocab.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Build vocab for a subclass of Text2TextProblem.
17 |
18 | build_vocab \
19 | --problem=program_search_algolisp \
20 | --data_dir=~/t2t_data \
21 | --tmp_dir=~/t2t_data/tmp
22 | """
23 |
24 | from __future__ import absolute_import
25 | from __future__ import division
26 | from __future__ import print_function
27 |
28 | import os
29 |
30 | from tensor2tensor import problems as problems_lib # pylint: disable=unused-import
31 | from tensor2tensor.data_generators import text_problems
32 | from tensor2tensor.utils import registry
33 | import tensorflow.compat.v1 as tf
34 |
flags = tf.flags
FLAGS = flags.FLAGS

# Where the generated vocabulary file is written.
flags.DEFINE_string(
    name="data_dir",
    default="/tmp/t2t/data_dir",
    help="Directory to place the generated vocabulary file in.")

# Scratch directory handed to the problem's get_or_create_vocab().
flags.DEFINE_string(
    name="tmp_dir",
    default="/tmp/t2t/tmp_dir",
    help="Temporary storage directory.")

# Registered problem name; it has no default, so it must be supplied.
flags.DEFINE_string(
    name="problem",
    default=None,
    help="Problem to generate the vocabulary file for.")
flags.mark_flag_as_required("problem")
48 |
49 |
def main(_):
  """Builds (or reuses) the vocabulary for `--problem` and saves it to data_dir.

  Both `--data_dir` and `--tmp_dir` are `~`-expanded and created if missing
  before the problem's `get_or_create_vocab` is called.

  Raises:
    ValueError: if the registered problem is not a
      `text_problems.Text2TextProblem`. The code below relies on its
      `get_or_create_vocab` method and `vocab_filename` attribute.
  """
  problem = registry.problem(FLAGS.problem)

  # Use an explicit check rather than `assert`, which is stripped when Python
  # runs with -O and would let an unsupported problem fail later with a less
  # helpful error.
  if not isinstance(problem, text_problems.Text2TextProblem):
    raise ValueError(
        "Problem %s is not a subclass of Text2TextProblem; cannot build a "
        "vocabulary for it." % FLAGS.problem)

  data_dir = os.path.expanduser(FLAGS.data_dir)
  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)

  tf.gfile.MakeDirs(data_dir)
  tf.gfile.MakeDirs(tmp_dir)

  tf.logging.info("Saving vocabulary to data_dir: %s", data_dir)

  problem.get_or_create_vocab(data_dir, tmp_dir)

  tf.logging.info("Saved vocabulary file: %s",
                  os.path.join(data_dir, problem.vocab_filename))


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
73 |
--------------------------------------------------------------------------------
/tensor2tensor/bin/t2t-avg-all:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """t2t-avg-all."""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from tensor2tensor.bin import t2t_avg_all
8 |
# Use the v1 compat API explicitly, as t2t-datagen does. `tf.logging` and
# `tf.app` are not available at the top level of TensorFlow 2.x.
import tensorflow.compat.v1 as tf


def main(argv):
  """Entry point for `tf.app.run`; forwards the parsed argv to t2t_avg_all."""
  t2t_avg_all.main(argv)


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
18 |
--------------------------------------------------------------------------------
/tensor2tensor/bin/t2t-bleu:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """t2t-bleu."""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from tensor2tensor.bin import t2t_bleu
8 |
# Use the v1 compat API explicitly, as t2t-datagen does. `tf.logging` and
# `tf.app` are not available at the top level of TensorFlow 2.x.
import tensorflow.compat.v1 as tf


def main(argv):
  """Entry point for `tf.app.run`; forwards the parsed argv to t2t_bleu."""
  t2t_bleu.main(argv)


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
19 |
--------------------------------------------------------------------------------
/tensor2tensor/bin/t2t-datagen:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Data generation for Tensor2Tensor.
3 |
4 | This script is used to generate data to train your models
5 | for a number of problems for which open-source data is available.
6 |
7 | For example, to generate data for MNIST run this:
8 |
9 | t2t-datagen \
10 | --problem=image_mnist \
11 | --data_dir=~/t2t_data \
12 | --tmp_dir=~/t2t_data/tmp
13 | """
14 | from __future__ import absolute_import
15 | from __future__ import division
16 | from __future__ import print_function
17 |
18 | from tensor2tensor.bin import t2t_datagen
19 |
20 | import tensorflow.compat.v1 as tf
21 |
def main(argv):
  """Hands the flag-parsed argv straight to `t2t_datagen.main`."""
  t2t_datagen.main(argv)


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  # Name the entry point explicitly rather than relying on the implicit
  # lookup of `main` in the __main__ module (same function either way).
  tf.app.run(main=main)
29 |
--------------------------------------------------------------------------------
/tensor2tensor/bin/t2t-decoder:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """t2t-decoder."""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from tensor2tensor.bin import t2t_decoder
8 |
# Use the v1 compat API explicitly, as t2t-datagen does. `tf.logging` and
# `tf.app` are not available at the top level of TensorFlow 2.x.
import tensorflow.compat.v1 as tf


def main(argv):
  """Entry point for `tf.app.run`; forwards the parsed argv to t2t_decoder."""
  t2t_decoder.main(argv)


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
18 |
--------------------------------------------------------------------------------
/tensor2tensor/bin/t2t-eval:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Run t2t-eval from a trained checkpoint.
3 |
4 | This script is used to run evaluation from a trained checkpoint. For example,
5 | to run evaluation on the test set when the trained checkpoint is in /output_dir:
6 |
7 | t2t-eval \
8 | --problem=image_mnist \
9 | --model=imagetransformer \
10 | --data_dir=~/t2t \
11 | --output_dir=/output_dir \
12 | --eval_use_test_set=True
13 | """
14 | from __future__ import absolute_import
15 | from __future__ import division
16 | from __future__ import print_function
17 |
18 | from tensor2tensor.bin import t2t_eval
19 |
20 | import tensorflow as tf
21 |
def main(argv):
  """Forward the parsed command-line `argv` to `t2t_eval.main`."""
  t2t_eval.main(argv)


if __name__ == "__main__":
  # TensorFlow 2.x no longer exposes `tf.logging` or `tf.app` at the top
  # level, so the module-level `import tensorflow as tf` fails here with an
  # AttributeError. Rebind `tf` to the v1 compatibility module, which provides
  # both APIs on TF 1.13+ and 2.x.
  import tensorflow.compat.v1 as tf  # pylint: disable=g-import-not-at-top
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main)
29 |
--------------------------------------------------------------------------------
/tensor2tensor/bin/t2t-exporter:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """t2t-exporter."""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from tensor2tensor.serving import export
8 |
9 | import tensorflow as tf
10 |
def main(argv):
  """Forward the parsed command-line `argv` to `export.main`."""
  export.main(argv)


if __name__ == "__main__":
  # TensorFlow 2.x no longer exposes `tf.logging` or `tf.app` at the top
  # level, so the module-level `import tensorflow as tf` fails here with an
  # AttributeError. Rebind `tf` to the v1 compatibility module, which provides
  # both APIs on TF 1.13+ and 2.x.
  import tensorflow.compat.v1 as tf  # pylint: disable=g-import-not-at-top
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main)
18 |
--------------------------------------------------------------------------------
/tensor2tensor/bin/t2t-insights-server:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """t2t-insights-server."""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from tensor2tensor.insights import server
8 |
9 | import tensorflow as tf
10 |
def main(argv):
  """Forward the parsed command-line `argv` to `server.main`."""
  server.main(argv)


if __name__ == "__main__":
  # TensorFlow 2.x no longer exposes `tf.logging` or `tf.app` at the top
  # level, so the module-level `import tensorflow as tf` fails here with an
  # AttributeError. Rebind `tf` to the v1 compatibility module, which provides
  # both APIs on TF 1.13+ and 2.x.
  import tensorflow.compat.v1 as tf  # pylint: disable=g-import-not-at-top
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main)
18 |
--------------------------------------------------------------------------------
/tensor2tensor/bin/t2t-make-tf-configs:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """t2t-make-tf-configs."""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from tensor2tensor.bin import make_tf_configs
8 |
9 | import tensorflow as tf
10 |
def main(argv):
  """Forward the parsed command-line `argv` to `make_tf_configs.main`."""
  make_tf_configs.main(argv)


if __name__ == "__main__":
  # TensorFlow 2.x no longer exposes `tf.logging` or `tf.app` at the top
  # level, so the module-level `import tensorflow as tf` fails here with an
  # AttributeError. Rebind `tf` to the v1 compatibility module, which provides
  # both APIs on TF 1.13+ and 2.x.
  import tensorflow.compat.v1 as tf  # pylint: disable=g-import-not-at-top
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main)
18 |
--------------------------------------------------------------------------------
/tensor2tensor/bin/t2t-query-server:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """t2t-query-server."""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from tensor2tensor.serving import query
8 |
9 | import tensorflow as tf
10 |
def main(argv):
  """Forward the parsed command-line `argv` to `query.main`."""
  query.main(argv)


if __name__ == "__main__":
  # TensorFlow 2.x no longer exposes `tf.logging` or `tf.app` at the top
  # level, so the module-level `import tensorflow as tf` fails here with an
  # AttributeError. Rebind `tf` to the v1 compatibility module, which provides
  # both APIs on TF 1.13+ and 2.x.
  import tensorflow.compat.v1 as tf  # pylint: disable=g-import-not-at-top
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main)
18 |
--------------------------------------------------------------------------------
/tensor2tensor/bin/t2t-trainer:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Trainer for Tensor2Tensor.
3 |
4 | This script is used to train your models in Tensor2Tensor.
5 |
6 | For example, to train a shake-shake model on MNIST run this:
7 |
8 | t2t-trainer \
9 | --generate_data \
10 | --problem=image_mnist \
11 | --data_dir=~/t2t_data \
12 | --tmp_dir=~/t2t_data/tmp \
13 | --model=shake_shake \
14 | --hparams_set=shake_shake_quick \
15 | --output_dir=~/t2t_train/mnist1 \
16 | --train_steps=1000 \
17 | --eval_steps=100
18 | """
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | from tensor2tensor.bin import t2t_trainer
24 |
25 | import tensorflow.compat.v1 as tf
26 |
def main(argv):
  """Hand the parsed command line over to `t2t_trainer.main`."""
  t2t_trainer.main(argv)


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main=main)
34 |
--------------------------------------------------------------------------------
/tensor2tensor/bin/t2t-translate-all:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """t2t-translate-all."""
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from tensor2tensor.bin import t2t_translate_all
8 |
9 | import tensorflow as tf
10 |
def main(argv):
  """Forward the parsed command-line `argv` to `t2t_translate_all.main`."""
  t2t_translate_all.main(argv)


if __name__ == "__main__":
  # TensorFlow 2.x no longer exposes `tf.logging` or `tf.app` at the top
  # level, so the module-level `import tensorflow as tf` fails here with an
  # AttributeError. Rebind `tf` to the v1 compatibility module, which provides
  # both APIs on TF 1.13+ and 2.x.
  import tensorflow.compat.v1 as tf  # pylint: disable=g-import-not-at-top
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main)
19 |
--------------------------------------------------------------------------------
/tensor2tensor/bin/t2t_eval.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Perform evaluation on trained T2T models using the Estimator API."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.bin import t2t_trainer # pylint: disable=unused-import
23 | from tensor2tensor.data_generators import problem # pylint: disable=unused-import
24 | from tensor2tensor.utils import trainer_lib
25 | from tensor2tensor.utils import usr_dir
26 | import tensorflow.compat.v1 as tf
27 | from tensorflow.compat.v1 import estimator as tf_estimator
28 |
29 | flags = tf.flags
30 | FLAGS = flags.FLAGS
31 |
32 |
def main(_):
  """Evaluates each checkpoint yielded for the model directory.

  Builds hparams and an EVAL-mode input function from command-line flags,
  creates an estimator, then iterates over the checkpoints produced by
  `trainer_lib.next_checkpoint` (passing `--eval_timeout_mins` through as
  its timeout argument) and logs the result of evaluating each one.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  hparams = trainer_lib.create_hparams(
      FLAGS.hparams_set,
      FLAGS.hparams,
      data_dir=FLAGS.data_dir,
      problem_name=FLAGS.problem)

  # Ask for the "test" split only when --eval_use_test_set is given; None
  # leaves the split choice to make_estimator_input_fn.
  split = None
  if FLAGS.eval_use_test_set:
    split = "test"
  input_fn = hparams.problem.make_estimator_input_fn(
      tf_estimator.ModeKeys.EVAL, hparams,
      dataset_kwargs={"dataset_split": split})
  run_config = t2t_trainer.create_run_config(hparams)

  # The summary hook in tf.estimator.EstimatorSpec needs hparams.model_dir.
  hparams.add_hparam("model_dir", run_config.model_dir)

  estimator = trainer_lib.create_estimator(
      FLAGS.model, hparams, run_config, use_tpu=FLAGS.use_tpu)
  checkpoints = trainer_lib.next_checkpoint(
      hparams.model_dir, FLAGS.eval_timeout_mins)
  for checkpoint_path in checkpoints:
    eval_results = estimator.evaluate(
        input_fn, steps=FLAGS.eval_steps, checkpoint_path=checkpoint_path)
    tf.logging.info(eval_results)


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main=main)
66 |
--------------------------------------------------------------------------------
/tensor2tensor/bin/t2t_trainer_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for t2t_trainer."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | from tensor2tensor.bin import t2t_trainer
22 | from tensor2tensor.utils import trainer_lib_test
23 |
24 | import tensorflow.compat.v1 as tf
25 |
26 | FLAGS = tf.flags.FLAGS
27 |
28 |
class TrainerTest(tf.test.TestCase):

  @classmethod
  def setUpClass(cls):
    # Reuse TrainerLibTest's class fixture; presumably it prepares the
    # tiny_algo problem/data used below -- TODO confirm in trainer_lib_test.
    trainer_lib_test.TrainerLibTest.setUpClass()

  def testTrain(self):
    """Runs t2t_trainer.main with one train step and one eval step."""
    flag_overrides = [
        ("problem", "tiny_algo"),
        ("model", "transformer"),
        ("hparams_set", "transformer_tiny"),
        ("train_steps", 1),
        ("eval_steps", 1),
        ("output_dir", tf.test.get_temp_dir()),
        ("data_dir", tf.test.get_temp_dir()),
    ]
    for flag_name, flag_value in flag_overrides:
      setattr(FLAGS, flag_name, flag_value)
    t2t_trainer.main(None)


if __name__ == "__main__":
  tf.test.main()
48 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/audio_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2tensor.data_generators.audio."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import io
23 | import os
24 | from tensor2tensor.data_generators import audio
25 |
26 | import tensorflow.compat.v1 as tf
27 |
28 |
class AudioTest(tf.test.TestCase):

  def testDataCollection(self):
    """_collect_data keys its result by each file's extension-less path.

    Creates empty .WAV/.WRD pairs, some in nested directories, and expects
    the collected keys to be exactly `tmp_dir/<name>` for every pair.
    """
    tmp_dir = self.get_temp_dir()
    test_files = [
        "dir1/file1",
        "dir1/file2",
        "dir1/dir2/file3",
        "dir1/dir2/dir3/file4",
    ]
    for filename in test_files:
      input_filename = os.path.join(tmp_dir, filename + ".WAV")
      target_filename = os.path.join(tmp_dir, filename + ".WRD")
      directories = os.path.dirname(input_filename)
      if not os.path.exists(directories):
        os.makedirs(directories)
      for path in (input_filename, target_filename):
        # Create an empty file and close its handle right away instead of
        # leaking the open file object.
        with io.open(path, "wb"):
          pass
        # Registered cleanups run even when the assertion below fails, so
        # the generated files are always removed.
        self.addCleanup(os.remove, path)

    data_dict = audio._collect_data(tmp_dir, ".WAV", ".WRD")
    expected = [os.path.join(tmp_dir, filename) for filename in test_files]
    self.assertEqual(sorted(data_dict), sorted(expected))


if __name__ == "__main__":
  tf.test.main()
61 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/celeba_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for CelebA."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl.testing import parameterized
23 | from tensor2tensor.data_generators import celeba
24 | from tensor2tensor.utils import hparam
25 |
26 | import tensorflow.compat.v1 as tf
27 | from tensorflow.compat.v1 import estimator as tf_estimator
28 |
29 |
class CelebaTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.named_parameters(
      ("Default", None),
      ("Area", "AREA"),
      ("Dilated", "DILATED"))
  def testCelebaMultiResolutionPreprocessExample(self, resize_method):
    """Checks output shapes of ImageCelebaMultiResolution preprocessing.

    Runs once per resize method; None leaves `resize_method` unset on the
    HParams. A random 218x178x3 input should yield a 138x138x3 "inputs"
    image and a (42, 32, 3) "targets" image. 42 rows of width 32 hold
    8**2 + 16**2 + 32**2 = 1344 pixels, presumably every requested
    resolution flattened together.
    """
    image = tf.random_uniform([218, 178, 3], minval=-1.)
    hparams = hparam.HParams(resolutions=[8, 16, 32])
    if resize_method is not None:
      hparams.resize_method = resize_method

    problem = celeba.ImageCelebaMultiResolution()
    processed = problem.preprocess_example(
        {"inputs": image}, tf_estimator.ModeKeys.TRAIN, hparams)
    self.assertLen(processed, 2)
    self.assertEqual(processed["inputs"].shape, (138, 138, 3))
    self.assertEqual(processed["targets"].shape, (42, 32, 3))


if __name__ == "__main__":
  tf.test.main()
52 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/common_voice_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2tensor.data_generators.common_voice."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 | from tensor2tensor.data_generators import common_voice
24 |
25 | import tensorflow.compat.v1 as tf
26 |
# Directory holding this test module and its bundled "test_data" fixtures.
pkg_dir, _ = os.path.split(__file__)
_TESTDATA = os.path.join(pkg_dir, "test_data")


class CommonVoiceTest(tf.test.TestCase):

  def testCollectData(self):
    """_collect_data on the bundled test_data yields exactly one row."""
    output = common_voice._collect_data(_TESTDATA)
    self.assertEqual(1, len(output))

    # NOTE: the fixture has no header row, so output[0] is the first data row.
    # assertEqual reports both values on failure; assertTrue(a == b) only
    # reports "False is not true".
    self.assertEqual("my_media", output[0][0])
    self.assertEqual("my_label", output[0][2])


if __name__ == "__main__":
  tf.test.main()
44 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/desc2code_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for desc2code."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.data_generators import desc2code
23 |
24 | import tensorflow.compat.v1 as tf
25 |
# C++ source mixing `//` line comments, a `/* */` block comment, tab escapes
# and an embedded escaped newline; fed to
# ProgrammingDesc2codeCpp.preprocess_target in Desc2codeTest.
CODE_CPP_IN = """
#include

void main() { // This comment will be removed
// This too.
//
/* Not this one */
\t
\t
int a \t\n = 3;//
//
}

"""

# Expected preprocess_target output for CODE_CPP_IN: line comments dropped,
# whitespace collapsed to single spaces, the block comment kept.
CODE_CPP_OUT = ("#include void main() { /* Not this one */ int a = "
                "3; }")
43 |
44 |
class Desc2codeTest(tf.test.TestCase):

  def testCppPreprocess(self):
    """Checks that C++ target code is preprocessed as expected."""
    preprocess = desc2code.ProgrammingDesc2codeCpp().preprocess_target

    cases = [
        # Line comments are stripped and the two lines are joined by a space.
        ("firstline//comm1\nsecondline//comm2\n", "firstline secondline"),
        # Both comments and surplus whitespace are handled.
        (CODE_CPP_IN, CODE_CPP_OUT),
        # A trailing `//` comment with no terminating newline is kept; only
        # the outer whitespace is trimmed.
        (" not removed //abcd ", "not removed //abcd"),
    ]
    for source, expected in cases:
      self.assertEqual(preprocess(source), expected)


if __name__ == "__main__":
  tf.test.main()
63 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/dna_encoder_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2tensor.data_generators.dna_encoder."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | from tensor2tensor.data_generators import dna_encoder
22 | import tensorflow.compat.v1 as tf
23 |
24 |
class DnaEncoderTest(tf.test.TestCase):

  def _assert_round_trip(self, encoder_cls, sequence):
    """Asserts decode(encode(sequence)) == sequence for several chunk sizes.

    Args:
      encoder_cls: encoder class constructed with a `chunk_size` keyword.
      sequence: DNA string to round-trip.
    """
    for size in (1, 2, 4, 6, 8):
      codec = encoder_cls(chunk_size=size)
      self.assertEqual(sequence, codec.decode(codec.encode(sequence)))

  def test_encode_decode(self):
    self._assert_round_trip(
        dna_encoder.DNAEncoder,
        'TTCGCGGNNNAACCCAACGCCATCTATGTANNTTGAGTTGTTGAGTTAAA')

  def test_delimited_dna_encoder(self):
    self._assert_round_trip(
        dna_encoder.DelimitedDNAEncoder,
        'TTCGCGGNNN,AACCCAACGC,CATCTATGTA,NNTTGAGTTG,TTGAGTTAAA')


if __name__ == '__main__':
  tf.test.main()
50 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/gene_expression_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for Genetics problems."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import numpy as np
21 |
22 | from tensor2tensor.data_generators import dna_encoder
23 | from tensor2tensor.data_generators import gene_expression
24 |
25 | import tensorflow.compat.v1 as tf
26 |
27 |
class GeneticsTest(tf.test.TestCase):

  def _one_hot_bases(self, bases):
    """One-hot encodes `bases` over the alphabet A, C, T, G (in that order).

    Bases outside the alphabet (e.g. "N") become all-False rows.

    Returns:
      numpy array with one boolean row of four entries per base.
    """
    alphabet = ["A", "C", "T", "G"]
    return np.array([[base == letter for letter in alphabet]
                     for base in bases])

  def testRecordToExample(self):
    encoder = dna_encoder.DNAEncoder(chunk_size=2)
    raw_inputs = ["A", "C", "G", "N", "C", "T"]

    # Build numpy arrays in the same layout as the h5 file.
    inputs = self._one_hot_bases(raw_inputs)
    mask = np.array([True, False, True])
    outputs = np.array([[1.0, 2.0, 3.0], [5.0, 1.0, 0.2], [5.1, 2.3, 2.3]])
    example = gene_expression.to_example_dict(encoder, inputs, mask, outputs)

    # Six bases at chunk size 2 give three ids, followed by an extra id 1
    # (presumably EOS -- confirm in gene_expression.to_example_dict).
    self.assertEqual(len(raw_inputs) // 2 + 1, len(example["inputs"]))
    self.assertAllEqual(encoder.encode(raw_inputs) + [1], example["inputs"])
    self.assertAllEqual([1.0, 0.0, 1.0], example["targets_mask"])
    # Targets come back flattened row by row, with the 3x3 shape kept alongside.
    self.assertAllEqual([1.0, 2.0, 3.0, 5.0, 1.0, 0.2, 5.1, 2.3, 2.3],
                        example["targets"])
    self.assertAllEqual([3, 3], example["targets_shape"])

  def testGenerateShardArgs(self):
    num_examples, num_shards = 37, 4
    outfiles = [str(shard) for shard in range(num_shards)]
    shard_args = gene_expression.generate_shard_args(outfiles, num_examples)

    # 37 examples over 4 shards: 9 each, with the remainder in the last shard.
    starts, ends, fnames = zip(*shard_args)
    self.assertAllEqual([0, 9, 18, 27], starts)
    self.assertAllEqual([9, 18, 27, 37], ends)
    self.assertAllEqual(fnames, outfiles)


if __name__ == "__main__":
  tf.test.main()
72 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/imagenet_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for ImageNet."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl.testing import parameterized
23 | from tensor2tensor.data_generators import imagenet
24 | from tensor2tensor.utils import hparam
25 |
26 | import tensorflow.compat.v1 as tf
27 | from tensorflow.compat.v1 import estimator as tf_estimator
28 |
29 |
class ImagenetTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.named_parameters(
      ("Default", None),
      ("Area", "AREA"),
      ("Dilated", "DILATED"))
  def testImagenetMultiResolutionPreprocessExample(self, resize_method):
    """Checks the output shape of ImageImagenetMultiResolutionGen preprocessing.

    Runs once per resize method; None leaves `resize_method` unset on the
    HParams. The single "inputs" entry should be (42, 32, 3). 42 rows of
    width 32 hold 8**2 + 16**2 + 32**2 = 1344 pixels, presumably every
    requested resolution flattened together.
    """
    image = tf.random_uniform([64, 64, 3], minval=-1.)
    hparams = hparam.HParams(resolutions=[8, 16, 32])
    if resize_method is not None:
      hparams.resize_method = resize_method

    problem = imagenet.ImageImagenetMultiResolutionGen()
    processed = problem.preprocess_example(
        {"inputs": image}, tf_estimator.ModeKeys.TRAIN, hparams)
    self.assertLen(processed, 1)
    self.assertEqual(processed["inputs"].shape, (42, 32, 3))

  def testImagenetIsNormalized(self):
    """ImageImagenet224 normalizes images; its NoNormalization variant does not."""
    self.assertTrue(imagenet.ImageImagenet224().normalize_image)
    self.assertFalse(
        imagenet.ImageImagenet224NoNormalization().normalize_image)


if __name__ == "__main__":
  tf.test.main()
57 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/lm1b_imdb.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Data generators for LM1B and IMDb combined data-set."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.data_generators import imdb
23 | from tensor2tensor.data_generators import lm1b
24 | from tensor2tensor.data_generators import multi_problem
25 | from tensor2tensor.data_generators import text_problems
26 | from tensor2tensor.utils import registry
27 |
28 |
@registry.register_problem
class LanguagemodelLm1bSentimentIMDB(multi_problem.MultiProblem):
  """Multitask mix of LM1B language modeling and IMDb sentiment.

  Both sub-problems are the `...Characters` variants, and `vocab_type`
  reports a character vocabulary for the combined problem.
  """

  def __init__(self, was_reversed=False, was_copy=False):
    super(LanguagemodelLm1bSentimentIMDB, self).__init__(was_reversed, was_copy)
    for task in (lm1b.LanguagemodelLm1bCharacters(),
                 imdb.SentimentIMDBCharacters()):
      self.task_list.append(task)

  @property
  def vocab_type(self):
    """Returns VocabType.CHARACTER for the combined problem."""
    return text_problems.VocabType.CHARACTER
41 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/lm1b_mnli.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Data generators for LM1B and MNLI combined datasets."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.data_generators import lm1b
23 | from tensor2tensor.data_generators import multi_problem
24 | from tensor2tensor.data_generators import multinli
25 | from tensor2tensor.data_generators import text_problems
26 | from tensor2tensor.utils import registry
27 |
28 |
@registry.register_problem
class LanguagemodelLm1bMultiNLISubwords(multi_problem.MultiProblem):
  """Multitask mix of LM1B language modeling and MultiNLI, subword variant.

  Combines `LanguagemodelLm1b32k` with `MultiNLISharedVocab`; `vocab_type`
  reports a subword vocabulary for the combined problem.
  """

  def __init__(self, was_reversed=False, was_copy=False):
    super(LanguagemodelLm1bMultiNLISubwords, self).__init__(
        was_reversed, was_copy)
    for task in (lm1b.LanguagemodelLm1b32k(),
                 multinli.MultiNLISharedVocab()):
      self.task_list.append(task)

  @property
  def vocab_type(self):
    """Returns VocabType.SUBWORD for the combined problem."""
    return text_problems.VocabType.SUBWORD
42 |
43 |
@registry.register_problem
class LanguagemodelLm1bMultiNLI(multi_problem.MultiProblem):
  """Multitask mix of LM1B language modeling and MultiNLI, character variant.

  Combines the `...Characters` versions of both problems; `vocab_type`
  reports a character vocabulary for the combined problem.
  """

  def __init__(self, was_reversed=False, was_copy=False):
    super(LanguagemodelLm1bMultiNLI, self).__init__(was_reversed, was_copy)
    for task in (lm1b.LanguagemodelLm1bCharacters(),
                 multinli.MultiNLICharacters()):
      self.task_list.append(task)

  @property
  def vocab_type(self):
    """Returns VocabType.CHARACTER for the combined problem."""
    return text_problems.VocabType.CHARACTER
56 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/mscoco_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for MS COCO."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl.testing import parameterized
23 | from tensor2tensor.data_generators import mscoco
24 | from tensor2tensor.utils import hparam
25 |
26 | import tensorflow.compat.v1 as tf
27 | from tensorflow.compat.v1 import estimator as tf_estimator
28 |
29 |
class MscocoTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the multi-resolution MS COCO image problem."""

  @parameterized.named_parameters(
      ("Default", None),
      ("Area", "AREA"),
      ("Dilated", "DILATED"))
  def testMsCocoMultiResolutionPreprocessExample(self, resize_method):
    # Random 400x400 RGB image with values drawn from [-1, 1).
    image = tf.random_uniform([400, 400, 3], minval=-1.)
    hparams = hparam.HParams(resolutions=[8, 16, 32])
    # Only override the resize method when the case supplies one; otherwise
    # whatever default preprocess_example uses applies.
    if resize_method is not None:
      hparams.resize_method = resize_method

    problem = mscoco.ImageTextMsCocoMultiResolution()
    processed = problem.preprocess_example(
        {"inputs": image}, tf_estimator.ModeKeys.TRAIN, hparams)

    self.assertLen(processed, 1)
    # Presumably each resolution r contributes r**2 // 32 rows of width 32
    # (2 + 8 + 32 == 42) -- confirm against mscoco.preprocess_example.
    self.assertEqual(processed["inputs"].shape, (42, 32, 3))
47 |
48 |
if __name__ == "__main__":
  # Run the test cases in this module through TensorFlow's test runner.
  tf.test.main()
51 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/ops/subword_text_encoder.h:
--------------------------------------------------------------------------------
1 | #ifndef TENSOR2TESNOR_DATA_GENERATORS_OPS_SUBWORD_TEXT_ENCODER_H_
2 | #define TENSOR2TESNOR_DATA_GENERATORS_OPS_SUBWORD_TEXT_ENCODER_H_
3 |
4 | #include "third_party/absl/container/flat_hash_map.h"
5 | #include "third_party/absl/container/flat_hash_set.h"
6 | #include "third_party/absl/strings/string_view.h"
7 | #include "third_party/icu/include/unicode/uchar.h"
8 | #include "third_party/tensorflow/core/framework/tensor.h"
9 |
10 | namespace tensor2tensor {
11 |
// A subword text encoder with a built-in tokenizer.
//
// Equivalent to tensor2tensor's Python SubwordTextEncoder in
// https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/text_encoder.py.
// This code (or a suitable replacement) should eventually move into tfds
// and be deleted from tensor2tensor.
18 |
19 | class SubwordTextEncoder {
20 | public:
21 | explicit SubwordTextEncoder(const std::string& vocab_filename);
22 | virtual ~SubwordTextEncoder() {}
23 |
24 | // Breaks up input text into subtokens.
25 | void Encode(absl::string_view text, std::vector* ids);
26 |
27 | private:
28 | // Given a full token as input, breaks the token up into subtokens and appends
29 | // corresponding IDs to the ids vector.
30 | void EncodeSubtokens(absl::string_view token, std::vector* ids);
31 |
32 | // Escapes a token so unencodable characters are replaced by escape sequences.
33 | std::string EscapeToken(absl::string_view token);
34 |
35 | // Maps subword tokens to IDs.
36 | absl::flat_hash_map vocab_;
37 | // A set containing all valid unicode code points that can be encoded without
38 | // being escaped.
39 | absl::flat_hash_set alphabet_;
40 | };
41 |
42 | } // namespace tensor2tensor
43 |
44 | #endif // TENSOR2TESNOR_DATA_GENERATORS_OPS_SUBWORD_TEXT_ENCODER_H_
45 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/ops/subword_text_encoder_ops.cc:
--------------------------------------------------------------------------------
#include <algorithm>
#include <memory>

#include "third_party/py/tensor2tensor/data_generators/ops/subword_text_encoder.h"
#include "third_party/tensorflow/core/framework/op_kernel.h"
#include "third_party/tensorflow/core/framework/shape_inference.h"
#include "third_party/tensorflow/core/framework/tensor.h"
#include "third_party/tensorflow/core/framework/types.h"
8 |
9 | namespace tensor2tensor {
10 | namespace {
11 |
12 | using ::tensorflow::DEVICE_CPU;
13 | using ::tensorflow::OpKernel;
14 | using ::tensorflow::OpKernelConstruction;
15 | using ::tensorflow::OpKernelContext;
16 | using ::tensorflow::Status;
17 | using ::tensorflow::Tensor;
18 | using ::tensorflow::TensorShape;
19 | using ::tensorflow::tstring;
20 | using ::tensorflow::shape_inference::InferenceContext;
21 |
22 | REGISTER_OP("SubwordTextEncoderEncode")
23 | .Input("s: string")
24 | .Output("encoded: int64")
25 | .Attr("vocab_filename: string")
26 | .SetShapeFn([](InferenceContext* ctx) {
27 | ctx->set_output(0, ctx->Vector(ctx->UnknownDim()));
28 | return tensorflow::Status();
29 | });
30 |
31 | class SubwordTextEncoderEncodeOp : public OpKernel {
32 | public:
33 | explicit SubwordTextEncoderEncodeOp(
34 | OpKernelConstruction* ctx) : OpKernel(ctx) {
35 | std::string vocab_filename;
36 | OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_filename", &vocab_filename));
37 | encoder_ = std::make_unique(vocab_filename);
38 | }
39 |
40 | void Compute(OpKernelContext* ctx) override {
41 | // Get input string and deserialize into ArticleExample proto.
42 | absl::string_view s = ctx->input(0).scalar()();
43 |
44 | // Construct encoded output tensors.
45 | std::vector encoded_ids;
46 | encoder_->Encode(s, &encoded_ids);
47 | Tensor* encoded;
48 | OP_REQUIRES_OK(
49 | ctx, ctx->allocate_output(
50 | 0, TensorShape({static_cast(encoded_ids.size())}),
51 | &encoded));
52 | auto encoded_vec = encoded->vec();
53 | // TODO(noam): find someone who remembers c++ eigen and ask the proper way
54 | // to copy a std::Vector to an Eigen whatever-this-is
55 | for (int i = 0; i < encoded_ids.size(); i++) {
56 | encoded_vec(i) = encoded_ids[i];
57 | }
58 | }
59 |
60 | private:
61 | std::unique_ptr encoder_;
62 | };
63 |
64 | REGISTER_KERNEL_BUILDER(Name("SubwordTextEncoderEncode").Device(DEVICE_CPU),
65 | SubwordTextEncoderEncodeOp);
66 |
67 | } // namespace
68 | } // namespace tensor2tensor
69 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/ops/subword_text_encoder_ops_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for subword_text_encoder_ops."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.data_generators.ops import subword_text_encoder_ops
23 | import tensorflow.compat.v1 as tf
24 |
# Relative path of the small test vocabulary (one quoted subtoken per line)
# that the op under test loads.
vocab_file = (
    "third_party/py/tensor2tensor/data_generators/ops/testdata/subwords")
27 |
28 |
class SubwordTextEncoderOpsTest(tf.test.TestCase):
  """Checks the SubwordTextEncoderEncode op against a known encoding."""

  def test_subword_text_encoder_encode(self):
    text = "the quick brown fox jumps over the lazy dog"
    expected_ids = [2, 3, 4, 5, 6, 7, 8, 9, 2, 11, 12, 1]
    actual_ids = subword_text_encoder_ops.subword_text_encoder_encode(
        text, vocab_file)
    self.assertAllEqual(actual_ids, expected_ids)
36 |
37 |
if __name__ == "__main__":
  # Eager mode lets the test compare the op's output directly, without a
  # Session.
  tf.enable_eager_execution()
  tf.test.main()
41 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/ops/subword_text_encoder_test.cc:
--------------------------------------------------------------------------------
1 | #include "third_party/py/tensor2tensor/data_generators/ops/subword_text_encoder.h"
2 |
3 | #include "testing/base/public/gunit.h"
4 | #include "third_party/tensorflow/core/framework/tensor.h"
5 | #include "third_party/tensorflow/core/framework/tensor_testutil.h"
6 |
7 | namespace tensor2tensor {
8 | namespace {
9 |
10 | TEST(SubwordTextEncoderTest, EncodesSubTokens) {
11 | SubwordTextEncoder encoder("third_party/py/tensor2tensor/"
12 | "data_generators/ops/testdata/subwords");
13 | std::vector t;
14 | encoder.Encode("the quick brown fox jumps over the lazy dog", &t);
15 | EXPECT_EQ(t, std::vector({2, 3, 4, 5, 6, 7, 8, 9, 2, 11, 12, 1}));
16 | }
17 |
18 | TEST(SubwordTextEncoderTest, EncodesUnicodeSubTokens) {
19 | SubwordTextEncoder encoder("third_party/py/tensor2tensor/"
20 | "data_generators/ops/testdata/subwords");
21 | std::vector t;
22 | encoder.Encode("ɧęĻĽÒ", &t);
23 | EXPECT_EQ(t, std::vector({13, 14, 1}));
24 | }
25 |
26 | TEST(SubwordTextEncoderTest, EncodesUnicodeCodePoints) {
27 | SubwordTextEncoder encoder("third_party/py/tensor2tensor/"
28 | "data_generators/ops/testdata/subwords");
29 | std::vector t;
30 | encoder.Encode("⻦ ⻭", &t);
31 | EXPECT_EQ(t, std::vector({15, 18, 16, 17, 1}));
32 | }
33 |
34 | TEST(SubwordTextEncoderTest, EncodesCharactersNotInAlphabet) {
35 | SubwordTextEncoder encoder("third_party/py/tensor2tensor/"
36 | "data_generators/ops/testdata/subwords");
37 | std::vector t;
38 | encoder.Encode("!", &t);
39 | // Subtokens: '\', '3', '3', ';', '_', '', ''.
40 | EXPECT_EQ(t, std::vector({19, 23, 23, 30, 17, 1}));
41 | }
42 |
43 | } // namespace
44 | } // namespace tensor2tensor
45 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/ops/testdata/subwords:
--------------------------------------------------------------------------------
'<pad>'
'<EOS>'
3 | 'the_'
4 | 'quick_'
5 | 'brow'
6 | 'n_'
7 | 'fox_'
8 | 'jump'
9 | 's_'
10 | 'over_'
11 | 'the_'
12 | 'lazy_'
13 | 'dog_'
14 | 'ɧę'
15 | 'ĻĽÒ_'
16 | '⻦'
17 | '⻭'
18 | '_'
19 | ' '
20 | '\'
21 | '0'
22 | '1'
23 | '2'
24 | '3'
25 | '4'
26 | '5'
27 | '6'
28 | '7'
29 | '8'
30 | '9'
31 | ';'
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/test_data/1.csv:
--------------------------------------------------------------------------------
1 | media_name,label
2 | my_media,my_label
3 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/test_data/corpus-1.txt:
--------------------------------------------------------------------------------
1 | One morning I shot an elephant in my pajamas. How he got in my pajamas, I don't
2 | know.
3 |
4 | Groucho Marx
5 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/test_data/corpus-2.txt:
--------------------------------------------------------------------------------
1 | I haven't slept for 10 days... because that would be too long.
2 |
3 | Mitch Hedberg
4 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/test_data/vocab-1.txt:
--------------------------------------------------------------------------------
1 | lollipop,8
2 | reverberated,12
3 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/test_data/vocab-2.txt:
--------------------------------------------------------------------------------
1 | kattywampus,11
2 | kaput
3 | balderdash,10
4 | jiggery-pokery,14
5 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/timeseries_data_generator.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Data generator for the timeseries problem."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import numpy as np
22 |
23 |
def generate_data(timeseries_length, timeseries_params):
  """Generates synthetic timeseries using input parameters.

  Each timeseries is the sum of a linear trend, a periodic component and
  Gaussian noise, clamped below at zero:

    y[i] = max(m * i + b + A * fn(i / freqcoeff) + noise[i], 0.0)

  where noise[i] is drawn from a normal distribution with mean 0 and standard
  deviation rndA.

  Args:
    timeseries_length: Number of data points to generate for each timeseries.
    timeseries_params: Iterable of dicts, one per timeseries, each with keys:
      m = Slope of the linear trend.
      b = y-intercept of the linear trend.
      A = Amplitude of the periodic component.
      freqcoeff = Divisor applied to the time index before calling `fn`;
        larger values give a slower oscillation.
      rndA = Standard deviation of the Gaussian noise.
      fn = Base periodic function (np.cos or np.sin).
    Example params for two timeseries.
      [{"m": 0.006, "b": 300.0, "A":50.0, "freqcoeff":1500.0, "rndA":15.0,
        "fn": np.sin},
       {"m": 0.000, "b": 500.0, "A":35.0, "freqcoeff":3500.0, "rndA":25.0,
        "fn": np.cos}]

  Returns:
    Multi-timeseries (list of list of floats): one inner list of length
    `timeseries_length` per entry of `timeseries_params`. Noise comes from
    NumPy's global random state, so seed it for reproducible output.
  """
  x = range(timeseries_length)

  multi_timeseries = []
  for p in timeseries_params:
    # Trend
    y1 = [p["m"] * i + p["b"] for i in x]
    # Period
    y2 = [p["A"] * p["fn"](i / p["freqcoeff"]) for i in x]
    # Noise
    y3 = np.random.normal(0, p["rndA"], timeseries_length).tolist()
    # Sum of trend, period and noise. Negative values are replaced with 0.0
    # (a float, not the int 0) so every element has the same numeric type.
    y = [max(a + b + c, 0.0) for a, b, c in zip(y1, y2, y3)]
    multi_timeseries.append(y)

  return multi_timeseries
64 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/translate_ende_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2tensor.data_generators.translate_ende."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.data_generators import problem
23 | from tensor2tensor.data_generators import translate_ende
24 |
25 | import tensorflow.compat.v1 as tf
26 |
27 |
class TranslateEndeTest(tf.test.TestCase):
  """Tests that some TranslateEnde subclasses inherit information correctly."""

  def _wmt_problems(self):
    """Returns fresh (8k, 32k) WMT En-De problem instances."""
    return (translate_ende.TranslateEndeWmt8k(),
            translate_ende.TranslateEndeWmt32k())

  def test_vocab_size(self):
    wmt_8k, wmt_32k = self._wmt_problems()
    self.assertEqual(wmt_8k.approx_vocab_size, 2**13)
    self.assertEqual(wmt_32k.approx_vocab_size, 2**15)

  def test_additional_datasets(self):
    for wmt in self._wmt_problems():
      self.assertListEqual(wmt.additional_training_datasets, [])

  def test_source_data_files(self):
    wmt_8k, wmt_32k = self._wmt_problems()
    # Both vocab sizes must read the same, non-empty set of source files for
    # every split.
    for split in (problem.DatasetSplit.EVAL, problem.DatasetSplit.TRAIN):
      files_8k = wmt_8k.source_data_files(split)
      files_32k = wmt_32k.source_data_files(split)
      self.assertListEqual(files_8k, files_32k)
      self.assertGreater(len(files_8k), 0)
59 |
if __name__ == '__main__':
  # Run the test cases in this module through TensorFlow's test runner.
  tf.test.main()
62 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/translate_entn.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Data generators for translation data-sets."""
17 |
18 |
19 | from tensor2tensor.data_generators import problem
20 | from tensor2tensor.data_generators import text_encoder
21 | from tensor2tensor.data_generators import translate
22 | from tensor2tensor.utils import registry
23 |
24 |
# End-of-sentence marker.
EOS = text_encoder.EOS_ID

# Base URL of the English-Setswana data in the ukuxhumana repository. The
# "?raw=true" suffix on the archive URLs below presumably makes GitHub serve
# the raw archive rather than an HTML page -- confirm if downloads break.
_URL = "https://github.com/LauraMartinus/ukuxhumana/blob/master/data/en_tn"

# Each entry is [archive URL, (English file, Setswana file)].
_ENTN_TRAIN_DATASETS = [[
    _URL + "/eng_tswane.train.tar.gz?raw=true",
    ("entn_parallel.train.en", "entn_parallel.train.tn")
]]

# The dev archive is returned for every non-TRAIN split.
_ENTN_TEST_DATASETS = [[
    _URL + "/eng_tswane.dev.tar.gz?raw=true",
    ("entn_parallel.dev.en", "entn_parallel.dev.tn")
]]
38 |
39 |
@registry.register_problem
class TranslateEntnRma(translate.TranslateProblem):
  """Problem spec for English-Setswana translation.

  Uses the RMA Autshumato dataset.
  """

  @property
  def approx_vocab_size(self):
    """Target subword vocabulary size (32768)."""
    return 1 << 15

  @property
  def vocab_filename(self):
    """Vocabulary file name, e.g. "vocab.entn.32768"."""
    return "vocab.entn.%d" % (self.approx_vocab_size,)

  def source_data_files(self, dataset_split):
    """Returns the train archives for TRAIN and the dev archives otherwise."""
    if dataset_split == problem.DatasetSplit.TRAIN:
      return _ENTN_TRAIN_DATASETS
    return _ENTN_TEST_DATASETS
58 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/translate_envi.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Data generators for En-Vi translation."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | from tensor2tensor.data_generators import problem
22 | from tensor2tensor.data_generators import text_encoder
23 | from tensor2tensor.data_generators import translate
24 | from tensor2tensor.utils import registry
25 |
# End-of-sentence marker.
EOS = text_encoder.EOS_ID

# For English-Vietnamese the IWSLT'15 corpus
# from https://nlp.stanford.edu/projects/nmt/ is used.
# The original dataset has 133K parallel sentences.
# Each entry is [archive URL, (English file, Vietnamese file)].
_ENVI_TRAIN_DATASETS = [[
    "https://github.com/stefan-it/nmt-en-vi/raw/master/data/train-en-vi.tgz",  # pylint: disable=line-too-long
    ("train.en", "train.vi")
]]

# For development 1,553 parallel sentences (tst2012) are used; this archive is
# returned for every non-TRAIN split.
_ENVI_TEST_DATASETS = [[
    "https://github.com/stefan-it/nmt-en-vi/raw/master/data/dev-2012-en-vi.tgz",  # pylint: disable=line-too-long
    ("tst2012.en", "tst2012.vi")
]]
42 |
43 |
44 | # See this PR on github for some results with Transformer on this Problem.
45 | # https://github.com/tensorflow/tensor2tensor/pull/611
46 |
47 |
@registry.register_problem
class TranslateEnviIwslt32k(translate.TranslateProblem):
  """Problem spec for IWSLT'15 En-Vi translation."""

  @property
  def approx_vocab_size(self):
    """Target subword vocabulary size (32768)."""
    return 1 << 15

  def source_data_files(self, dataset_split):
    """Returns the train archives for TRAIN and the dev archives otherwise."""
    if dataset_split == problem.DatasetSplit.TRAIN:
      return _ENVI_TRAIN_DATASETS
    return _ENVI_TEST_DATASETS
59 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/wikifact/README.md:
--------------------------------------------------------------------------------
1 | # Assessing the Factual Accuracy of Generated Text
2 |
3 | This directory will contain the code and scripts to generate data and train
4 | models from the paper *Assessing the Factual Accuracy of Generated Text*.
5 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/wikisum/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/wikisum/delete_instances.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Delete Google Compute Engine instances with naming structure $NAME-$INDEX
# (e.g. machines created with parallel_launch.py).
# Usage:
#   delete_instances.sh NAME MAX [MIN]
# Deletes instances NAME-MIN through NAME-MAX inclusive (MIN defaults to 0).
# Example usage:
#   delete_instances.sh fetch-ref-urls 1000

if [[ $# -lt 2 ]]; then
  echo "Usage: $0 NAME MAX [MIN]" >&2
  exit 1
fi

NAME=$1
MAX=$2
MIN=${3:-0}

LOG_F=/tmp/delete-$NAME-logs.txt
# Start each run with an empty log; the background jobs below append to it.
: > "$LOG_F"

echo "Deleting instances $NAME-$MIN through $NAME-$MAX"

for i in $(seq "$MIN" "$MAX")
do
  # Append (>>) rather than truncate (>): the deletions run concurrently in
  # the background, and '>' would let each job wipe the others' output.
  gcloud compute instances delete --quiet "$NAME-$i" >> "$LOG_F" 2>&1 &
  if [[ $(( i % 100 )) == 0 ]]
  then
    # Give it some room to breathe every 100
    sleep 30
  fi
done

echo "Delete commands launched. Logs redirected to $LOG_F"
/tensor2tensor/data_generators/wikisum/generate_vocab.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Generate vocab from references and wikis."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from tensor2tensor.data_generators.wikisum import wikisum
22 |
23 | import tensorflow.compat.v1 as tf
24 |
flags = tf.flags
FLAGS = flags.FLAGS

# Output location for the generated vocabulary.
flags.DEFINE_string("out_dir", None, "Directory to write vocab to.")
# Wiki content shards; defaults to a gs:// bucket path.
flags.DEFINE_string("wikis_dir",
                    "gs://tensor2tensor-data/wikisum/wiki_content/",
                    "Directory with wiki_content.tfrecords shards.")
# Reference shards, laid out as process_X sub-folders.
flags.DEFINE_string("refs_dir", None,
                    "Directory with process_X folders with reference shards.")
# Selects which wikisum problem's generate_vocab is used in main().
flags.DEFINE_bool("for_commoncrawl", False,
                  "Whether to use WikisumCommoncrawl or WikisumWeb.")
36 |
37 |
def main(_):
  """Builds the vocabulary for the problem selected by --for_commoncrawl."""
  problem_cls = (wikisum.WikisumCommoncrawl if FLAGS.for_commoncrawl
                 else wikisum.WikisumWeb)
  problem = problem_cls()
  problem.generate_vocab(FLAGS.out_dir, FLAGS.wikis_dir, FLAGS.refs_dir)
44 |
45 |
if __name__ == "__main__":
  # Log progress at INFO level, then hand control to tf.app.run, which
  # invokes main().
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
49 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/wikisum/get_references_commoncrawl.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Extract references from CommonCrawl files."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import tempfile
23 |
24 | from tensor2tensor.data_generators.wikisum import utils
25 | from tensor2tensor.data_generators.wikisum import wikisum
26 |
27 | import tensorflow.compat.v1 as tf
28 |
flags = tf.flags
FLAGS = flags.FLAGS

# Parallelism: the sorted WET file list is split into num_tasks shards and
# this process handles shard task_id.
flags.DEFINE_integer("num_tasks", 1000, "Number of parallel tasks.")
flags.DEFINE_integer("task_id", 0, "Task id in a parallel run.")
flags.DEFINE_string("metadata_dir",
                    "gs://tensor2tensor-data/wikisum/commoncrawl_metadata/",
                    "Path to metadata files specifying what references are in "
                    "which CommonCrawl files.")
# Output is written under <out_dir>/process_<task_id>.
flags.DEFINE_string("out_dir", None, "Directory to write references to.")
flags.DEFINE_string("commoncrawl_wet_dir", None,
                    "Path to CommonCrawl wet.gz files locally. If not "
                    "provided, will download.")
42 |
43 |
def main(_):
  """Extracts references from this task's shard of CommonCrawl WET files.

  Output is written under `<out_dir>/process_<task_id>`.

  Raises:
    ValueError: if --out_dir or --metadata_dir is empty, or --task_id is not
      in [0, --num_tasks).
  """
  # Validate flags explicitly: `assert` is stripped under `python -O`, and a
  # negative task_id would otherwise silently index a shard from the end.
  if not FLAGS.out_dir:
    raise ValueError("--out_dir is required.")
  if not FLAGS.metadata_dir:
    raise ValueError("--metadata_dir is required.")
  if not 0 <= FLAGS.task_id < FLAGS.num_tasks:
    raise ValueError("--task_id must be in [0, %d), got %d." %
                     (FLAGS.num_tasks, FLAGS.task_id))
  out_dir = os.path.join(FLAGS.out_dir, "process_%d" % FLAGS.task_id)
  tf.gfile.MakeDirs(out_dir)

  with utils.timing("get_refs_commoncrawl"):
    # Get all WET files
    if FLAGS.commoncrawl_wet_dir:
      wet_files = tf.gfile.Glob(
          os.path.join(FLAGS.commoncrawl_wet_dir, "*.wet.gz"))
    else:
      tmp_dir = tempfile.gettempdir()
      wet_files = list(
          utils.wet_download_urls(utils.WET_PATHS_BY_DATE["0917"], tmp_dir))

    # Shard and select this task's work. Sorting first makes the shard
    # assignment identical across all parallel tasks.
    wet_files.sort()
    wet_files = utils.shard(wet_files, FLAGS.num_tasks)[FLAGS.task_id]
    tf.logging.info("Sharded out WET files. Processing %d files",
                    len(wet_files))

    wikisum.extract_references_from_wets(wet_files, FLAGS.metadata_dir, out_dir)
67 |
68 |
if __name__ == "__main__":
  # Log progress at INFO level, then hand control to tf.app.run, which
  # invokes main().
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
72 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/wikisum/test_data/para_bad1.txt:
--------------------------------------------------------------------------------
1 | kolkata ward no 97 37
2 | you are here : india » west bengal » kolkata » kolkata
3 | this paragraph too short
4 | a | b | c | d | e | f | g | h | i | j | k | l | m | n | o | p | q | r | s | t | u | v | w | x | y | z
5 | 123 123 123 123 985 9880 1230 0980 . 12398 .
6 | - 5 . 7 % - 5 . 2 % - 15 . 1 % 4 . 7 % - 13 . 3 %
7 | http : / / www . bbc . co . uk / sport / football / 24351521
8 | no . - 26 beadon street .
9 | { { / playpopup } } { { ^ playpopup } } { { # playinvideopage } } { { / playinvideopage } } { { ^ playinvideopage } } { { / playinvideopage } } { { / playpopup } }
{ { # playpopup } } { { / playpopup } } { { ^ playpopup } } { { # playinvideopage } } { { / playinvideopage } } { { ^ playinvideopage } } { { / playinvideopage } } { { / playpopup } } { { genre } }
10 | denham , samuel coulter , sally 133 oct 28 1819
11 | browse by
12 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/wikisum/test_data/para_good1.txt:
--------------------------------------------------------------------------------
1 | this is a very good paragraph . it even has two sentences .
2 | the castle that was soon to figure so largely in lee’s life lay fourteen miles
3 | to the southwest of where he sat perched atop his tank . topped with storybook
4 | crenelations and accompanied by a rich history , schloss itter , as it’s called
5 | in german , was first mentioned in land records as early as 1240 . since then ,
6 | itter has passed through a number of hands . after germany’s march 1938
7 | annexation of austria , the castle’s robust construction and relatively remote
8 | location attracted the attention of the notoriously secretive nazis . within
9 | months of absorbing austria into the greater reich , the german government
10 | requisitioned castle itter for unspecified “official use”—which included housing
11 | for several months in 1942 an organization called the “german association for
12 | combating the dangers of tobacco . ” on february 7 , 1943 , it fell into new
13 | hands yet again , for on that day , the structure and all its outbuildings were
14 | requisitioned by the wehrmacht on behalf of the ss .
15 | the url for the site is http : / / www . bbc . co . uk / sport / football / 24351521 .
16 |
--------------------------------------------------------------------------------
/tensor2tensor/data_generators/wikisum/utils_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2tensor.data_generators.wikisum.utils."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 | from tensor2tensor.data_generators.wikisum import utils
24 |
25 | import tensorflow.compat.v1 as tf
26 |
# Directory holding the para_good*/para_bad* fixtures, resolved relative to
# this file so the test does not depend on the working directory.
pkg_dir = os.path.abspath(__file__)
pkg_dir, _ = os.path.split(pkg_dir)
_TESTDATA = os.path.join(pkg_dir, "test_data")
30 |
31 |
def _get_testdata(filename):
  """Returns the entire text content of the file at `filename`."""
  with tf.io.gfile.GFile(filename) as handle:
    contents = handle.read()
  return contents
35 |
36 |
class UtilsTest(tf.test.TestCase):
  """Tests for `utils.filter_paragraph` against fixture files."""

  def test_filter_paragraph(self):
    """Bad fixtures must be filtered line by line; good ones must be kept."""
    # Every line of a para_bad*.txt fixture is an independent paragraph that
    # filter_paragraph must reject.
    for bad in tf.io.gfile.glob(os.path.join(_TESTDATA, "para_bad*.txt")):
      for p in _get_testdata(bad).split("\n"):
        self.assertTrue(utils.filter_paragraph(p),
                        msg="Didn't filter %s" % p)
    # Each para_good*.txt fixture is checked as ONE paragraph: the whole file
    # contents are passed at once rather than being split into lines.
    for good in tf.io.gfile.glob(os.path.join(_TESTDATA, "para_good*.txt")):
      p = _get_testdata(good)
      self.assertFalse(utils.filter_paragraph(p), msg="Filtered %s" % p)


if __name__ == "__main__":
  tf.test.main()
52 |
--------------------------------------------------------------------------------
/tensor2tensor/envs/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Environments defined in T2T. Imports here force registration."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.envs import gym_env_problem
23 | from tensor2tensor.envs import tic_tac_toe_env
24 | from tensor2tensor.envs import tic_tac_toe_env_problem
25 |
--------------------------------------------------------------------------------
/tensor2tensor/envs/gym_spaces_utils_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for gym_spaces_utils.py."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from gym.spaces import Box
23 | from gym.spaces import Discrete
24 | import numpy as np
25 | from tensor2tensor.envs import gym_spaces_utils
26 | import tensorflow.compat.v1 as tf
27 |
28 |
class GymSpacesUtilsTest(tf.test.TestCase):
  """Tests spec construction and value encoding in `gym_spaces_utils`."""

  def _assert_fixed_len_spec(self, spec, dtype, shape):
    """Asserts `spec` is a `tf.FixedLenFeature` with the given dtype/shape."""
    self.assertIsInstance(spec, tf.FixedLenFeature)
    self.assertEqual(spec.dtype, dtype)
    self.assertListEqual(list(spec.shape), shape)

  def test_discrete_space_spec(self):
    spec = gym_spaces_utils.gym_space_spec(Discrete(100))
    self._assert_fixed_len_spec(spec, tf.int64, [1])

  def test_box_space_spec(self):
    box = Box(low=0, high=10, shape=[5, 6], dtype=np.float32)
    spec = gym_spaces_utils.gym_space_spec(box)
    self._assert_fixed_len_spec(spec, tf.float32, [5, 6])

  def test_discrete_space_encode(self):
    space = Discrete(100)
    sampled = space.sample()
    encoded = gym_spaces_utils.gym_space_encode(space, sampled)
    self.assertListEqual([sampled], encoded)

  def test_box_space_encode(self):
    space = Box(low=0, high=10, shape=[2], dtype=np.int64)
    encoded = gym_spaces_utils.gym_space_encode(space, np.array([2, 3]))
    self.assertListEqual([2, 3], encoded)


if __name__ == '__main__':
  tf.test.main()
60 |
--------------------------------------------------------------------------------
/tensor2tensor/envs/mujoco_problems.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Mujoco Gym environments."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import functools
23 | from tensor2tensor.envs import rendered_env_problem
24 | from tensor2tensor.layers import modalities
25 | from tensor2tensor.rl import gym_utils
26 | from tensor2tensor.utils import registry
27 |
28 |
29 |
@registry.register_env_problem
class ReacherEnvProblem(rendered_env_problem.RenderedEnvProblem):
  """Mujoco's Reacher-v2 environment as a rendered EnvProblem."""

  def __init__(self):
    # Keyword arguments bound into `gym_utils.gym_env_wrapper` for this env.
    wrapper_kwargs = dict(
        rl_env_max_episode_steps=-1,
        maxskip_env=False,
        rendered_env=True,
        rendered_env_resize_to=None,  # Keep rendered frames unresized.
        sticky_actions=False,
        output_dtype=None,
        num_actions=None,
    )
    env_wrapper_fn = functools.partial(gym_utils.gym_env_wrapper,
                                       **wrapper_kwargs)
    super(ReacherEnvProblem, self).__init__(
        base_env_name="Reacher-v2", env_wrapper_fn=env_wrapper_fn)

  # Inputs and targets are rendered frames (VIDEO modality); actions and
  # rewards use the IDENTITY modality.
  @property
  def input_modality(self):
    return modalities.ModalityType.VIDEO

  @property
  def target_modality(self):
    return modalities.ModalityType.VIDEO

  @property
  def action_modality(self):
    return modalities.ModalityType.IDENTITY

  @property
  def reward_modality(self):
    return modalities.ModalityType.IDENTITY

  # Presumably one symbol per 8-bit pixel value -- TODO confirm.
  @property
  def input_vocab_size(self):
    return 256

  @property
  def target_vocab_size(self):
    return 256
72 |
--------------------------------------------------------------------------------
/tensor2tensor/envs/mujoco_problems_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2tensor.envs.mujoco_problems."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import numpy as np
23 | from tensor2tensor.envs import env_problem_utils
24 | from tensor2tensor.envs import mujoco_problems # pylint: disable=unused-import
25 | from tensor2tensor.utils import registry
26 | import tensorflow.compat.v1 as tf
27 |
28 |
class ReacherEnvProblemTest(tf.test.TestCase):

  def test_registration_and_interaction_with_env_problem(self):
    """Steps a batched Reacher env with random actions and checks outputs."""
    batch_size = 5
    # Looking the problem up by name only works if registration happened.
    env = registry.env_problem("reacher_env_problem", batch_size=batch_size)
    env.reset()

    total_done = 0
    for _ in range(100):
      batch_actions = np.stack(
          [env.action_space.sample() for _ in range(batch_size)])
      observations, rewards, dones, infos = env.step(batch_actions)

      # Every returned quantity must hold one entry per batch element.
      for batched in (observations, rewards, dones, infos):
        self.assertEqual(batch_size, len(batched))

      env.reset(env_problem_utils.done_indices(dones))
      total_done += sum(dones)

    # At least one episode must have finished during the rollout.
    self.assertGreater(total_done, 0)


if __name__ == "__main__":
  tf.test.main()
58 |
--------------------------------------------------------------------------------
/tensor2tensor/envs/tic_tac_toe_env_problem.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """TicTacToeEnvProblem wraps the TicTacToeEnv in an EnvProblem."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.envs import gym_env_problem
23 | from tensor2tensor.layers import modalities
24 | from tensor2tensor.utils import registry
25 |
26 |
@registry.register_env_problem
class TicTacToeEnvProblem(gym_env_problem.GymEnvProblem):
  """Runs `batch_size` independent games of tic-tac-toe side by side."""

  def __init__(self):
    super(TicTacToeEnvProblem, self).__init__(
        base_env_name="T2TEnv-TicTacToeEnv-v0", reward_range=(-1, 1))

  @property
  def input_modality(self):
    return modalities.ModalityType.IDENTITY_SYMBOL

  @property
  def input_vocab_size(self):
    # A cell holds one of three states: x, o, or empty.
    return 3

  @property
  def target_modality(self):
    return modalities.ModalityType.IDENTITY_SYMBOL

  @property
  def target_vocab_size(self):
    # The reward takes one of three values: -1, 0 or +1.
    return 3

  @property
  def action_modality(self):
    return modalities.ModalityType.SYMBOL_WEIGHTS_ALL
57 |
--------------------------------------------------------------------------------
/tensor2tensor/envs/tic_tac_toe_env_problem_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2tensor.envs.tic_tac_toe_env_problem."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import numpy as np
23 | from tensor2tensor.envs import env_problem_utils
24 | from tensor2tensor.envs import tic_tac_toe_env # pylint: disable=unused-import
25 | from tensor2tensor.envs import tic_tac_toe_env_problem # pylint: disable=unused-import
26 | from tensor2tensor.utils import registry
27 | import tensorflow.compat.v1 as tf
28 |
29 |
class TicTacToeEnvProblemTest(tf.test.TestCase):

  def test_registration_and_interaction_with_env_problem(self):
    """Plays random tic-tac-toe games and checks the outcome bookkeeping."""
    batch_size = 5
    # Lookup by name only succeeds if the env problem was registered.
    env = registry.env_problem(
        "tic_tac_toe_env_problem", batch_size=batch_size)
    env.reset()

    finished = lost = drawn = won = 0
    for _ in range(100):
      batch_actions = np.stack(
          [env.action_space.sample() for _ in range(batch_size)])
      observations, rewards, dones, infos = env.step(batch_actions)

      # Each output must hold exactly one entry per game in the batch.
      for batched in (observations, rewards, dones, infos):
        self.assertEqual(batch_size, len(batched))

      env.reset(env_problem_utils.done_indices(dones))
      finished += sum(dones)

      # Tally the final reward of every game that just ended.
      for reward, done in zip(rewards, dones):
        if done:
          if reward == -1:
            lost += 1
          elif reward == 0:
            drawn += 1
          elif reward == 1:
            won += 1
          else:
            raise ValueError(
                "reward should be -1, 0, 1 but is {}".format(reward))

    # Without at least one finished game the next check would be vacuous.
    self.assertGreater(finished, 0)
    # Every finished game is exactly one of won, lost or drawn.
    self.assertEqual(finished, won + lost + drawn)


if __name__ == "__main__":
  tf.test.main()
74 |
--------------------------------------------------------------------------------
/tensor2tensor/envs/time_step_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2tensor.envs.time_step."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.envs import time_step
23 |
24 | import tensorflow.compat.v1 as tf
25 |
26 |
class TimeStepTest(tf.test.TestCase):
  """Tests construction and `replace` of `time_step.TimeStep`."""

  def test_create_time_step(self):
    step = time_step.TimeStep.create_time_step(
        observation=1,
        done=True,
        raw_reward=1.0,
        processed_reward=1,
        action=1,
        info={1: 1, 2: 4},
    )

    self.assertEqual(1, step.observation)
    self.assertTrue(step.done)
    self.assertNear(1.0, step.raw_reward, 1e-6)
    self.assertEqual(1, step.processed_reward)
    self.assertEqual(1, step.action)
    self.assertEqual({1: 1, 2: 4}, step.info)

  def test_replace(self):
    original = time_step.TimeStep.create_time_step(observation=1, action=1)
    self.assertFalse(original.done)

    updated = original.replace(action=2, done=True, info={1: 1, 2: 4})

    # Assert that `replace` left the source time step untouched ...
    self.assertFalse(original.done)
    self.assertEqual(1, original.observation)
    self.assertEqual(1, original.action)

    # ... and that the returned step carries only the requested changes.
    self.assertTrue(updated.done)
    self.assertEqual(1, updated.observation)  # Not overridden.
    self.assertEqual(2, updated.action)  # Overridden.
    self.assertEqual({1: 1, 2: 4}, updated.info)


if __name__ == '__main__':
  tf.test.main()
61 |
--------------------------------------------------------------------------------
/tensor2tensor/insights/README.md:
--------------------------------------------------------------------------------
1 | # Tensor2Tensor Insights
2 |
3 | The Insights package provides an interactive web service for understanding the
4 | inner workings of a Tensor2Tensor model. It will provide a series of
5 | visualizations extracted from a requested T2T model that inform model developers
6 | and model users how to improve or best utilize a model.
7 |
8 | ## Dependencies
9 |
10 | Before using the Insights server, you must install [Bower](https://bower.io/),
11 | which we use to manage our web component dependencies. You can easily install
12 | it with the [Node Package Manager](https://www.npmjs.com/).
13 |
14 | ## Setup Instructions
15 |
16 | After training a model (for example, by following the Quick Start guide), you
17 | can run the `t2t-insights-server` binary and begin querying it.
18 |
19 | First, prepare the bower dependencies by navigating into the
20 | `tensor2tensor/insights/polymer` directory and running `bower install`:
21 |
22 | ```
23 | pushd tensor2tensor/insights/polymer
24 | bower install
25 | popd
26 | ```
27 |
28 | The models run by the server are then configured by a JSON version of the
29 | InsightConfiguration protocol buffer. Using the model trained in the Quick
30 | Start guide, a sample configuration would be:
31 |
32 | ```
33 | {
34 | "configuration": [{
35 | "source_language": "en",
36 | "target_language": "de",
37 | "label": "transformers_wmt32k",
38 | "transformer": {
39 | "model": "transformer",
40 | "model_dir": "/tmp/t2t/train",
41 | "data_dir": "/tmp/t2t/data",
42 | "hparams": "",
43 | "hparams_set": "transformer_base_single_gpu",
44 | "problem": "translate_ende_wmt32k"
45 | }
46 | }],
47 | "language": [{
48 | "code": "en",
49 | "name": "English"
50 | },{
51 | "code": "de",
52 | "name": "German"
53 | }]
54 | }
55 | ```
56 |
57 | With that saved to `configuration.json`, run the following:
58 |
59 | ```
60 | t2t-insights-server \
61 | --configuration=configuration.json \
62 | --static_path=`pwd`/tensor2tensor/insights/polymer
63 | ```
64 |
65 | This will bring up a minimal [Flask](http://flask.pocoo.org/) REST service
66 | served by a [Gunicorn](http://gunicorn.org/) HTTP server.
67 |
68 | ## Features to be developed
69 |
70 | This is a minimal web server. We are in the process of adding additional
71 | exciting features that give insight into a model's behavior:
72 |
73 | * Integrating a multi-head attention visualization.
74 | * Registering multiple models to compare their behavior.
75 | * Indexing training data to find examples related to a current query.
76 | * Tracking interesting query + translation pairs for deeper analysis.
77 |
--------------------------------------------------------------------------------
/tensor2tensor/insights/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/tensor2tensor/insights/insight_configuration.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package tensor2tensor;
4 |
// Configures the Neural Machine Translation Insight Frontend with a set of
// supported query processors and languages.
//
// The insights README shows a sample of this message in JSON form, with the
// top-level keys "configuration" and "language".
message InsightConfiguration {
  // Specifies zero or more models to inspect.
  repeated QueryProcessorConfiguration configuration = 1;

  // Specifies language codes and display names.
  repeated Language language = 2;
}
14 |
// A displayable language name.
message Language {
  // The BCP-47 Language code, e.g. "en" or "de" in the README sample.
  string code = 1;
  // The language's display name, e.g. "English".
  string name = 2;
}
22 |
// Configures a QueryProcessor and registers it with the Insight Frontend when
// responding to analysis queries.
message QueryProcessorConfiguration {
  // The model's BCP-47 source language code.
  string source_language = 1;
  // The model's BCP-47 target language code.
  string target_language = 2;
  // A short label for the model.
  string label = 3;
  // The QueryProcessor to use. By default we just use the TransformerModel.
  // The README sample configuration leaves this field unset.
  string query_processor = 4;

  // Configuration for the TransformerModel.
  TransformerConfiguration transformer = 5;
}
38 |
// Specifies the parameters for a trained Transformer model to inspect. These
// parameters match those in t2t-trainer and t2t-decoder.
message TransformerConfiguration {
  // The model type, e.g. "transformer" in the README sample.
  string model = 1;
  // The trained model directory.
  string model_dir = 2;
  // The data directory for the model.
  string data_dir = 3;

  // The hyperparameter set for running the model.
  string hparams_set = 4;
  // Overriding hyperparameters; presumably in the same format as the
  // t2t-trainer --hparams flag (TODO confirm).
  string hparams = 5;
  // The problem sets over which this model was trained and configured.
  // NOTE(review): the README sample configuration spells this key "problem"
  // (singular). Confirm which name the server actually reads before parsing
  // that JSON through this proto.
  string problems = 6;
}
56 |
--------------------------------------------------------------------------------
/tensor2tensor/insights/polymer/.bowerrc:
--------------------------------------------------------------------------------
1 | {
2 | "directory": "."
3 | }
4 |
--------------------------------------------------------------------------------
/tensor2tensor/insights/polymer/insights_app/insights-app.js:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 The Tensor2Tensor Authors.
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
/**
 * `<insights-app>` is the top-level element of the NMT Insights App; it
 * manages which of the app's views is shown.
 *
 * ### Usage
 *
 *     <insights-app></insights-app>
 */
class InsightsApp extends Polymer.Element {
  /** @return {string} The custom element tag name. */
  static get is() {
    return 'insights-app';
  }

  /** @return {!Object} Polymer property declarations for this element. */
  static get properties() {
    return {
      /**
       * Name of the page being shown; mirrored to the `page` attribute.
       * @type {string}
       */
      page: {type: String, reflectToAttribute: true},
    };
  }

  /** @return {!Array} Polymer complex observers for this element. */
  static get observers() {
    return ['routePageChanged_(routeData.page)'];
  }

  /**
   * Syncs `page` with the route, falling back to 'explore' when unset.
   * @param {?string} page The page name taken from the route.
   * @private
   */
  routePageChanged_(page) {
    // Loose equality on purpose: null and undefined are treated alike.
    if (page == this.page) return;

    this.page = page || 'explore';
    this.set('routeData.page', this.page);

    // The newly selected view may need fresh data for its new state.
    const activePage = this.get('currentPage');
    if (activePage) activePage.refresh();
  }
}

customElements.define(InsightsApp.is, InsightsApp);
79 |
--------------------------------------------------------------------------------
/tensor2tensor/insights/polymer/language_selector/language-selector-content.html:
--------------------------------------------------------------------------------
1 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 | [[item.name]] ([[item.code]])
58 |
59 |
60 |
61 |
62 |
63 |
--------------------------------------------------------------------------------
/tensor2tensor/insights/polymer/language_selector/language-selector.html:
--------------------------------------------------------------------------------
1 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
38 |
39 |
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/tensor2tensor/insights/polymer/language_selector/language-selector.js:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 The Tensor2Tensor Authors.
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
/**
 * `<language-selector>` offers a searchable dropdown of languages.
 *
 * The closed dropdown shows the selected language's name. When it is opened,
 * the search bar narrows the available languages to those whose name or code
 * contains the query text as a substring.
 *
 * Unless `defaultCode` is overridden, a provided language with code 'en' is
 * selected automatically.
 *
 * ### Usage
 *
 *     <language-selector></language-selector>
 */
class LanguageSelector extends Polymer.Element {
  /** @return {string} The custom element tag name. */
  static get is() {
    return 'language-selector';
  }

  /** @return {!Object} Polymer property declarations for this element. */
  static get properties() {
    return {
      /**
       * Label for the selector.
       * @type {string}
       */
      label: {type: String},
      /**
       * Languages available for selection.
       * @type {?Array}
       */
      languages: {type: Array},
      /**
       * The selected language; declared with `notify: true` so changes
       * propagate to bound hosts.
       * @type {!Language}
       */
      value: {type: Object, notify: true},
      /**
       * Code of the language selected by default.
       * @type {string}
       */
      defaultCode: {type: String, value: 'en'},
    };
  }

  /**
   * Pre-selects `language` in the dropdown.
   * @param {Language} language The language to select.
   * @public
   */
  forceSelection(language) {
    const dropdown = this.$.selector;
    dropdown.forceSelection(language);
  }
}

customElements.define(LanguageSelector.is, LanguageSelector);
86 |
--------------------------------------------------------------------------------
/tensor2tensor/insights/polymer/processing_visualization/processing-visualization.js:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 The Tensor2Tensor Authors.
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
/**
 * `<processing-visualization>` summarises the pre/post-processing steps.
 *
 * It shows the pre-processing segmentation steps, plus the post-processing
 * de-segmentation and rewrite steps, applied to a translation query.
 *
 * ### Usage
 *
 *     <processing-visualization></processing-visualization>
 */
class ProcessingVisualization extends Polymer.Element {
  /** @return {string} The custom element tag name. */
  static get is() {
    return 'processing-visualization';
  }

  /** @return {!Object} Polymer property declarations for this element. */
  static get properties() {
    return {
      /**
       * The processing steps to render for one query.
       * @type {!QueryProcessingVisualization}
       */
      data: {type: Object},
    };
  }
}

customElements.define(ProcessingVisualization.is, ProcessingVisualization);
53 |
--------------------------------------------------------------------------------
/tensor2tensor/insights/query_processor.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A base class for all query processing classes."""
17 |
18 |
class QueryProcessor(object):
  """Interface for classes that process sequence queries.

  A QueryProcessor is expected to turn a string query into a series of
  visualization structures. This base implementation produces none.

  TODO(kstevens): Define how the visualization structures should look once the
  protos are in better shape.
  """

  def process(self, query):
    """Builds the visualizations for `query`.

    Args:
      query: The string input.

    Returns:
      A dict with the single key 'result', mapping to a list of visualization
      objects. The base class always returns a fresh, empty list.
    """
    del query  # The base implementation ignores the query.
    no_visualizations = []
    return {"result": no_visualizations}
41 |
--------------------------------------------------------------------------------
/tensor2tensor/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/tensor2tensor/layers/ngram_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for n-gram layer."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.layers import ngram
23 |
24 | from tensor2tensor.utils import test_utils
25 |
26 | import tensorflow.compat.v1 as tf
# Eager execution is switched on at import time; each test is additionally
# run in both graph and eager modes by the decorator below.
tf.enable_eager_execution()


class NGramTest(tf.test.TestCase):

  @test_utils.run_in_graph_and_eager_modes()
  def testNGramLayerShape(self):
    """Output width equals the number of n-grams for n in [minval, maxval)."""
    batch_size, length = 2, 8
    vocab_size, minval, maxval = 3, 1, 4
    inputs = tf.random_uniform(
        [batch_size, length], minval=0, maxval=vocab_size, dtype=tf.int32)
    outputs = ngram.NGram(vocab_size, minval, maxval)(inputs)
    outputs_val = self.evaluate(outputs)
    num_ngrams = sum(vocab_size**n for n in range(minval, maxval))
    self.assertEqual(outputs_val.shape, (batch_size, num_ngrams))

  @test_utils.run_in_graph_and_eager_modes()
  def testNGramLayerOutput(self):
    """Checks exact n-gram counts on a small, hand-computed example."""
    inputs = tf.constant([[0, 0, 0, 0, 1],
                          [2, 1, 2, 1, 0]], dtype=tf.int32)
    outputs = ngram.NGram(3, minval=1, maxval=3)(inputs)
    # 3 + 3**2 = 12 columns. The first three appear to be unigram counts and
    # the remaining nine bigram counts (inferred from the values below).
    expected_outputs = tf.constant(
        [[4., 1., 0., 2., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 2., 2., 0., 0., 0., 0., 0., 0., 0., 2., 0.]], dtype=tf.float32)
    outputs_val, expected_outputs_val = self.evaluate(
        [outputs, expected_outputs])
    self.assertAllEqual(outputs_val, expected_outputs_val)


if __name__ == "__main__":
  tf.test.main()
63 |
64 |
--------------------------------------------------------------------------------
/tensor2tensor/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/tensor2tensor/metrics/video_conditional_fvd_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for video_conditional_fvd."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.metrics import video_conditional_fvd
23 | import tensorflow.compat.v1 as tf
24 |
25 |
class VideoConditionalFvdTest(tf.test.TestCase):

  def test_sample(self):
    """Smoke test: evaluate_model runs with all callbacks set to None."""
    dataset = video_conditional_fvd.VideoEvaluationDataset(
        n_input_frames=4, n_output_frames=10, get_video_batch_fn=None)
    model = video_conditional_fvd.Model(apply_fn=None, load_fn=None)
    # There are no assertions; the test only checks that the call completes
    # without raising.
    video_conditional_fvd.evaluate_model(dataset, model, 10, 16)


if __name__ == '__main__':
  tf.test.main()
41 |
--------------------------------------------------------------------------------
/tensor2tensor/models/README.md:
--------------------------------------------------------------------------------
1 | # Constructing T2T Models.
2 |
3 | This directory contains T2T models, their hyperparameters, and a number
4 | of common layers and hyperparameter settings to help construct new models.
5 | Common building blocks are in `common_layers.py` and `common_attention.py`.
6 | Common hyperparameters are in `common_hparams.py`. Models are imported in
7 | `__init__.py`.
8 |
9 | ## Adding a new model.
10 |
To add a model to the built-in set, create a new file (see, e.g.,
`neural_gpu.py`), define your model class there as a subclass of `T2TModel`,
and decorate it with `registry.register_model`. Then import the new module in
`__init__.py`.
14 |
The model is then available to the trainer binary (`t2t-trainer`) through the
`--model=model_name` flag.
17 |
--------------------------------------------------------------------------------
/tensor2tensor/models/basic.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Basic models for testing simple tasks."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.layers import common_hparams
23 | from tensor2tensor.layers import common_layers
24 | from tensor2tensor.utils import registry
25 | from tensor2tensor.utils import t2t_model
26 |
27 | import tensorflow.compat.v1 as tf
28 |
29 |
@registry.register_model
class BasicFcRelu(t2t_model.T2TModel):
  """Basic fully-connected + ReLU model."""

  def body(self, features):
    """Flattens the inputs and applies a stack of dense/dropout/ReLU layers.

    Args:
      features: dict of Tensors. Only features["inputs"] is read. The reshape
        below multiplies dims 1-3, so inputs are assumed to be 4-D (e.g.
        [batch, height, width, channels]) -- TODO confirm for new problems.

    Returns:
      A Tensor of shape [batch, 1, 1, hparams.hidden_size]. The two singleton
      axes make the output 4-D as T2T expects.
    """
    hparams = self.hparams
    x = features["inputs"]
    shape = common_layers.shape_list(x)
    x = tf.reshape(x, [-1, shape[1] * shape[2] * shape[3]])
    for i in range(hparams.num_hidden_layers):
      x = tf.layers.dense(x, hparams.hidden_size, name="layer_%d" % i)
      # `keep_prob` is deprecated. `rate` takes the drop probability directly,
      # which also avoids the 1 - (1 - p) round trip.
      x = tf.nn.dropout(x, rate=hparams.dropout)
      x = tf.nn.relu(x)
    return tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)  # 4D For T2T.
44 |
45 |
@registry.register_hparams
def basic_fc_small():
  """Small fully connected model: 2 hidden layers of width 256, no dropout."""
  hparams = common_hparams.basic_params1()
  # Architecture.
  hparams.num_hidden_layers = 2
  hparams.hidden_size = 256
  # Initialization.
  hparams.initializer = "uniform_unit_scaling"
  hparams.initializer_gain = 1.0
  # Optimization and regularization.
  hparams.batch_size = 128
  hparams.learning_rate = 0.1
  hparams.weight_decay = 0.0
  hparams.dropout = 0.0
  return hparams
59 |
--------------------------------------------------------------------------------
/tensor2tensor/models/basic_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Basic nets tests."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | import numpy as np
22 |
23 | from tensor2tensor.data_generators import mnist # pylint: disable=unused-import
24 | from tensor2tensor.models import basic
25 | from tensor2tensor.utils import trainer_lib
26 |
27 | import tensorflow.compat.v1 as tf
28 | from tensorflow.compat.v1 import estimator as tf_estimator
29 |
30 |
class BasicTest(tf.test.TestCase):

  def testBasicFcRelu(self):
    """BasicFcRelu on one MNIST-shaped example yields 10-way logits."""
    images = np.random.randint(256, size=(1, 28, 28, 1))
    labels = np.random.randint(10, size=(1, 1))
    hparams = trainer_lib.create_hparams(
        "basic_fc_small", problem_name="image_mnist", data_dir=".")
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(images, dtype=tf.int32),
          "targets": tf.constant(labels, dtype=tf.int32),
      }
      fc_model = basic.BasicFcRelu(hparams, tf_estimator.ModeKeys.TRAIN)
      logits, _ = fc_model(features)
      session.run(tf.global_variables_initializer())
      logits_val = session.run(logits)
      self.assertEqual(logits_val.shape, (1, 1, 1, 1, 10))


if __name__ == "__main__":
  tf.test.main()
52 |
--------------------------------------------------------------------------------
/tensor2tensor/models/bytenet_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """ByteNet tests."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | import numpy as np
22 |
23 | from tensor2tensor.data_generators import problem_hparams
24 | from tensor2tensor.models import bytenet
25 |
26 | import tensorflow.compat.v1 as tf
27 | from tensorflow.compat.v1 import estimator as tf_estimator
28 |
29 |
class ByteNetTest(tf.test.TestCase):

  def testByteNet(self):
    """ByteNet produces per-position logits over the target vocabulary."""
    vocab_size = 9
    inputs = np.random.randint(1, high=vocab_size, size=(3, 5, 1, 1))
    targets = np.random.randint(1, high=vocab_size, size=(3, 6, 1, 1))
    hparams = bytenet.bytenet_base()
    p_hparams = problem_hparams.test_problem_hparams(
        vocab_size, vocab_size, hparams)
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(inputs, dtype=tf.int32),
          "targets": tf.constant(targets, dtype=tf.int32),
      }
      model = bytenet.ByteNet(hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      logits_val = session.run(logits)
      # The length axis is expected to be 50, not the target length 6.
      # Presumably ByteNet pads sequences internally -- see bytenet.py.
      self.assertEqual(logits_val.shape, (3, 50, 1, 1, vocab_size))


if __name__ == "__main__":
  tf.test.main()
55 |
--------------------------------------------------------------------------------
/tensor2tensor/models/neural_architecture_search/README.md:
--------------------------------------------------------------------------------
1 | This directory contains the configurable model code used in the Evolved
2 | Transformer paper (https://arxiv.org/abs/1901.11117). It can be used to train
3 | models in the search space as was done in the paper.
4 |
--------------------------------------------------------------------------------
/tensor2tensor/models/neural_architecture_search/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 |
--------------------------------------------------------------------------------
/tensor2tensor/models/neural_gpu_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for Neural GPU."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | import numpy as np
22 |
23 | from tensor2tensor.data_generators import problem_hparams
24 | from tensor2tensor.layers import common_hparams
25 | from tensor2tensor.models import neural_gpu
26 |
27 | import tensorflow.compat.v1 as tf
28 | from tensorflow.compat.v1 import estimator as tf_estimator
29 |
30 |
class NeuralGPUTest(tf.test.TestCase):

  def testNeuralGPU(self):
    """NeuralGPU logits have shape [batch, target_len, 1, 1, target_vocab]."""
    hparams = common_hparams.basic_params1()
    batch_size, input_length = 3, 5
    target_length = input_length  # The test uses equal input/target lengths.
    input_vocab_size, target_vocab_size = 9, 11
    p_hparams = problem_hparams.test_problem_hparams(
        input_vocab_size, target_vocab_size, hparams)
    inputs = np.random.randint(
        input_vocab_size, size=(batch_size, input_length, 1, 1))
    targets = np.random.randint(
        target_vocab_size, size=(batch_size, target_length, 1, 1))
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(inputs, dtype=tf.int32),
          "targets": tf.constant(targets, dtype=tf.int32),
      }
      model = neural_gpu.NeuralGPU(
          hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      logits_val = session.run(logits)
      expected_shape = (batch_size, target_length, 1, 1, target_vocab_size)
      self.assertEqual(logits_val.shape, expected_shape)


if __name__ == "__main__":
  tf.test.main()
63 |
--------------------------------------------------------------------------------
/tensor2tensor/models/research/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/tensor2tensor/models/research/glow_init_hook.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Hook to run glow initialization on a larger batch."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow.compat.v1 as tf
23 |
24 |
class GlowInitHook(tf.train.SessionRunHook):
  """Runs Glow's data-dependent initialization when training starts.

  When a session is created at global step 0, the first op stored in the tf
  collection "glow_init_op" is run. Look at the "body" in glow.py for how that
  op is built.
  """

  def after_create_session(self, session, coord):
    """Runs the stored init op if the session starts at global step 0.

    Args:
      session: the newly created session.
      coord: unused.
    """
    del coord  # Unused.
    if session.run(tf.train.get_global_step()) != 0:
      return
    init_ops = tf.get_collection("glow_init_op")
    # On a multi-GPU system the collection may hold several ops; only the
    # first one is run.
    if init_ops:
      session.run(init_ops[0])
42 |
--------------------------------------------------------------------------------
/tensor2tensor/models/research/transformer_sketch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Transformer Sketch for im2sketch problems.
17 | """
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | from tensor2tensor.layers import common_layers
24 | from tensor2tensor.models import transformer
25 | from tensor2tensor.utils import registry
26 |
27 | import tensorflow.compat.v1 as tf
28 |
29 |
@registry.register_model
class TransformerSketch(transformer.Transformer):
  """Transformer with strided convolutions."""

  def encode(self, inputs, target_space, hparams, features=None, losses=None):
    """Add layers of strided convolutions on top of encoder.

    The inputs go through `hparams.num_compress_steps` rounds of
    (make_even_size -> 4x4 conv with stride 2 -> layer norm). The result is
    then fed to the regular Transformer encoder.

    Args:
      inputs: Tensor fed to the first convolution. Presumably an image batch
        [batch, height, width, channels] -- TODO confirm.
      target_space: target space id, forwarded to Transformer.encode.
      hparams: hyperparameters used both for the convolutions and for the
        base encoder.
      features: optional features dict, forwarded to Transformer.encode.
      losses: optional losses collection, forwarded to Transformer.encode.

    Returns:
      The (encoder_output, encoder_decoder_attention_bias) tuple produced by
      Transformer.encode.
    """
    # The caller's `hparams` is honored here. The previous version overwrote
    # it with self.hparams, which silently ignored the argument.
    kernel, strides = (4, 4), (2, 2)
    with tf.variable_scope("downstride"):
      x = inputs
      # Down-convolutions.
      for i in range(hparams.num_compress_steps):
        x = common_layers.make_even_size(x)
        x = tf.layers.conv2d(
            x, hparams.hidden_size, kernel, strides=strides,
            padding="SAME", activation=common_layers.belu, name="conv_%d" % i)
        x = common_layers.layer_norm(x)

    encoder_output, encoder_decoder_attention_bias = super(
        TransformerSketch, self).encode(
            x, target_space, hparams, features=features, losses=losses)
    return encoder_output, encoder_decoder_attention_bias
52 |
53 |
@registry.register_hparams
def transformer_sketch():
  """Basic transformer_sketch hparams, built on transformer_small."""
  hparams = transformer.transformer_small()
  # Training.
  hparams.batch_size = 32
  hparams.clip_grad_norm = 2.
  # Number of strided down-convolutions applied by TransformerSketch.encode.
  hparams.num_compress_steps = 4
  hparams.sampling_method = "random"
  return hparams
63 |
--------------------------------------------------------------------------------
/tensor2tensor/models/research/transformer_vae_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2tensor.models.research.transformer_vae."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import numpy as np
21 | from tensor2tensor.data_generators import problem_hparams
22 | from tensor2tensor.models.research import transformer_vae
23 | import tensorflow.compat.v1 as tf
24 | from tensorflow.compat.v1 import estimator as tf_estimator
25 |
26 |
class TransformerVaeTest(tf.test.TestCase):

  def testTransformerAEOnDVQ(self):
    """TransformerAE with the "dvq" bottleneck yields per-target logits."""
    batch_size, input_length, target_length = 3, 5, 16
    vocab_size = 9
    hparams = transformer_vae.transformer_ae_small()
    hparams.bottleneck_kind = "dvq"
    hparams.dp_strength = 0
    p_hparams = problem_hparams.test_problem_hparams(
        vocab_size, vocab_size, hparams)
    hparams.problem_hparams = p_hparams
    inputs = np.random.randint(
        vocab_size, size=(batch_size, input_length, 1, 1))
    targets = np.random.randint(
        vocab_size, size=(batch_size, target_length, 1, 1))
    features = {
        "inputs": tf.constant(inputs, dtype=tf.int32),
        "targets": tf.constant(targets, dtype=tf.int32),
        "target_space_id": tf.constant(1, dtype=tf.int32),
    }
    tf.train.create_global_step()
    model = transformer_vae.TransformerAE(
        hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
    logits, _ = model(features)
    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      logits_val = session.run(logits)
      expected_shape = (batch_size, target_length, 1, 1, vocab_size)
      self.assertEqual(logits_val.shape, expected_shape)


if __name__ == "__main__":
  tf.test.main()
63 |
--------------------------------------------------------------------------------
/tensor2tensor/models/resnet_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Resnet tests."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import numpy as np
23 |
24 | from tensor2tensor.data_generators import problem_hparams
25 | from tensor2tensor.layers import modalities
26 | from tensor2tensor.models import resnet
27 |
28 | import tensorflow.compat.v1 as tf
29 | from tensorflow.compat.v1 import estimator as tf_estimator
30 |
31 |
def resnet_tiny_cpu():
  """resnet_base with layer_sizes [2, 2, 2, 2] and use_nchw disabled."""
  hparams = resnet.resnet_base()
  hparams.use_nchw = False
  hparams.layer_sizes = [2, 2, 2, 2]
  return hparams
37 |
38 |
class ResnetTest(tf.test.TestCase):

  def _test_resnet(self, img_size, output_size):
    """Builds a tiny ResNet classifier and checks the shape of its logits.

    Args:
      img_size: height and width of the random square 3-channel inputs.
      output_size: tuple of the expected logits dims that follow the batch
        dim (before the trailing (1, vocab_size)).
    """
    vocab_size = 9
    batch_size = 2
    images = np.random.randint(256, size=(batch_size, img_size, img_size, 3))
    labels = np.random.randint(1, high=vocab_size, size=(batch_size, 1, 1, 1))
    hparams = resnet_tiny_cpu()
    p_hparams = problem_hparams.test_problem_hparams(
        vocab_size, vocab_size, hparams)
    p_hparams.modality["inputs"] = modalities.ModalityType.IMAGE
    p_hparams.modality["targets"] = modalities.ModalityType.CLASS_LABEL
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(images, dtype=tf.int32),
          "targets": tf.constant(labels, dtype=tf.int32),
      }
      model = resnet.Resnet(hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      logits_val = session.run(logits)
      expected_shape = (batch_size,) + output_size + (1, vocab_size)
      self.assertEqual(logits_val.shape, expected_shape)

  def testResnetLarge(self):
    self._test_resnet(img_size=224, output_size=(1, 1))


if __name__ == "__main__":
  tf.test.main()
71 |
--------------------------------------------------------------------------------
/tensor2tensor/models/video/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/tensor2tensor/models/video/basic_deterministic_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Basic tests for basic deterministic model."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.models.video import basic_deterministic
23 | from tensor2tensor.models.video import basic_deterministic_params
24 | from tensor2tensor.models.video import tests_utils
25 |
26 | import tensorflow.compat.v1 as tf
27 |
28 |
class NextFrameTest(tests_utils.BaseNextFrameTest):

  def testBasicDeterministic(self):
    """Runs the shared next-frame checks on NextFrameBasicDeterministic."""
    hparams = basic_deterministic_params.next_frame_basic_deterministic()
    model_cls = basic_deterministic.NextFrameBasicDeterministic
    self.TestOnVariousInputOutputSizes(hparams, model_cls, 256, False)


if __name__ == "__main__":
  tf.test.main()
40 |
--------------------------------------------------------------------------------
/tensor2tensor/models/video/basic_recurrent.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Basic recurrent models for testing simple tasks."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.layers import common_video
23 | from tensor2tensor.models.video import basic_stochastic
24 | from tensor2tensor.utils import registry
25 |
26 |
@registry.register_model
class NextFrameBasicRecurrent(
    basic_stochastic.NextFrameBasicStochasticDiscrete):
  """Basic next-frame recurrent model."""

  @property
  def is_recurrent_model(self):
    """Always True: this model threads conv-LSTM state via middle_network."""
    return True

  def middle_network(self, layer, internal_states):
    """Applies a stack of `hparams.num_lstm_layers` conv-LSTM layers.

    Args:
      layer: input Tensor for the first conv-LSTM layer.
      internal_states: sequence of per-layer states from a previous call, or
        None to start every layer from a None state.

    Returns:
      A tuple (output, states). `states` is a new list holding the updated
      state of each layer. The caller's `internal_states` is not modified.
    """
    hp = self.hparams
    if internal_states is None:
      lstm_states = [None] * hp.num_lstm_layers
    else:
      # Copy instead of writing into the caller's sequence in place. This also
      # makes tuples of states work.
      lstm_states = list(internal_states)

    # LSTM layers
    x = layer
    for j in range(hp.num_lstm_layers):
      x, lstm_states[j] = common_video.conv_lstm_2d(
          x, lstm_states[j], hp.num_lstm_filters)
    return x, lstm_states
49 |
50 |
@registry.register_hparams
def next_frame_basic_recurrent():
  """Basic recurrent next-frame model with a stochastic tower.

  Uses 4 input frames and 4 target frames, and adds the conv-LSTM stack
  settings read by NextFrameBasicRecurrent.middle_network.
  """
  hparams = basic_stochastic.next_frame_basic_stochastic_discrete()
  # Video setup.
  hparams.video_num_input_frames = 4
  hparams.video_num_target_frames = 4
  # Network size.
  hparams.hidden_size = 64
  hparams.filter_double_steps = 2
  hparams.concat_internal_states = False
  # Conv-LSTM stack.
  hparams.add_hparam("num_lstm_layers", 2)
  hparams.add_hparam("num_lstm_filters", 256)
  return hparams
63 |
--------------------------------------------------------------------------------
/tensor2tensor/models/video/basic_recurrent_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Basic tests for basic deterministic model."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.models.video import basic_recurrent
23 | from tensor2tensor.models.video import tests_utils
24 |
25 | import tensorflow.compat.v1 as tf
26 |
27 |
class NextFrameTest(tests_utils.BaseNextFrameTest):
  """Smoke tests for the basic recurrent next-frame model."""

  def testBasicDeterministic(self):
    # NOTE(review): despite the method name, this exercises
    # NextFrameBasicRecurrent with next_frame_basic_recurrent hparams.
    hparams = basic_recurrent.next_frame_basic_recurrent()
    model_cls = basic_recurrent.NextFrameBasicRecurrent
    self.TestOnVariousInputOutputSizes(hparams, model_cls, 256, False)
36 |
# Run this module's tests with TensorFlow's test runner when executed directly.
if __name__ == "__main__":
  tf.test.main()
39 |
--------------------------------------------------------------------------------
/tensor2tensor/models/video/basic_stochastic_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Basic tests for basic stochastic model."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.models.video import basic_stochastic
23 | from tensor2tensor.models.video import tests_utils
24 |
25 | import tensorflow.compat.v1 as tf
26 |
27 |
class NextFrameTest(tests_utils.BaseNextFrameTest):
  """Smoke tests for the basic stochastic next-frame model."""

  def testBasicStochastic(self):
    hparams = basic_stochastic.next_frame_basic_stochastic()
    model_cls = basic_stochastic.NextFrameBasicStochastic
    self.TestOnVariousInputOutputSizes(hparams, model_cls, 256, False)
36 |
# Run this module's tests with TensorFlow's test runner when executed directly.
if __name__ == "__main__":
  tf.test.main()
39 |
--------------------------------------------------------------------------------
/tensor2tensor/models/video/emily_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Basic tests for emily's model."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.models.video import emily
23 | from tensor2tensor.models.video import tests_utils
24 |
25 |
26 | import tensorflow.compat.v1 as tf
27 |
28 |
class NextFrameTest(tests_utils.BaseNextFrameTest):
  """Smoke tests for the Emily next-frame model."""

  def testEmily(self):
    hparams = emily.next_frame_emily()
    model_cls = emily.NextFrameEmily
    self.TestOnVariousInputOutputSizes(hparams, model_cls, 1)
36 |
37 |
# Run this module's tests with TensorFlow's test runner when executed directly.
if __name__ == "__main__":
  tf.test.main()
40 |
--------------------------------------------------------------------------------
/tensor2tensor/models/video/epva_params.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Param sets for EPVA model."""
17 |
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from tensor2tensor.layers import modalities
22 | from tensor2tensor.models.video import basic_deterministic_params
23 | from tensor2tensor.utils import registry
24 |
25 |
@registry.register_hparams
def next_frame_epva():
  """EPVA hparams.

  Builds on `next_frame_basic_deterministic`. It switches the inputs/targets
  bottoms and the targets top to the raw-video modality functions, uses an L2
  raw-pixel loss and a constant learning rate, and adds the EPVA-specific
  hyperparameters.

  Returns:
    The hparams object for the EPVA model.
  """
  hparams = basic_deterministic_params.next_frame_basic_deterministic()
  hparams.video_num_input_frames = 4
  hparams.video_num_target_frames = 4
  hparams.bottom = {
      "inputs": modalities.video_raw_bottom,
      "targets": modalities.video_raw_targets_bottom,
  }
  hparams.loss = {
      "targets": modalities.video_l2_raw_loss,
  }
  hparams.top = {
      "targets": modalities.video_raw_top,
  }
  hparams.learning_rate_schedule = "constant"
  hparams.learning_rate_constant = 1e-05
  hparams.batch_size = 2
  hparams.clip_grad_norm = 0.01
  # TODO(msaffar): disentangle EPVA from SV2P
  hparams.add_hparam("reward_prediction", False)
  hparams.add_hparam("clip_pixel_values", True)
  hparams.add_hparam("context_frames", 5)
  hparams.add_hparam("enc_learning_rate", 1e-5)
  hparams.add_hparam("enc_pred_loss_scale", 0.1)
  hparams.add_hparam("enc_pred_loss_scale_delay", 6e5)
  hparams.add_hparam("enc_size", 64)
  hparams.add_hparam("enc_keep_prob", .65)
  hparams.add_hparam("enc_pred_use_l1_loss", False)
  hparams.add_hparam("enc_pred_use_l2norm", False)
  hparams.add_hparam("van_learning_rate", 3e-5)
  hparams.add_hparam("van_keep_prob", .9)
  # This name used to carry a trailing space ("sequence_length "), which made
  # the hparam unreachable as `hparams.sequence_length`.
  hparams.add_hparam("sequence_length", 64)
  hparams.add_hparam("skip_num", 2)
  hparams.add_hparam("pred_noise_std", 0)
  hparams.add_hparam("lstm_state_noise_stddev", 0)
  return hparams
64 |
--------------------------------------------------------------------------------
/tensor2tensor/models/video/nfg_conv3d_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Test when the latent-network encoder is a conv3d net."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl.testing import parameterized
23 | from tensor2tensor.models.video import nfg_test_utils
24 | import tensorflow.compat.v1 as tf
25 |
# Cases for parameterized.named_parameters: the first field is the test-name
# suffix; the rest map positionally onto testGlowTrainAndDecode's arguments
# (in_frames, out_frames, latent_dist_encoder, gen_mode, pretrain_steps,
# num_train_frames, cond_first_frame, apply_dilations, activation). Omitted
# trailing fields use that method's defaults.
conv3d_net_hparams = (
    ("conv3d_net", 2, 2, "conv3d_net", "conditional", -1, 3),
    ("conv3d_net_gatu", 2, 2, "conv3d_net", "conditional", -1, 3, False, False,
     "gatu"),
    ("conv3d_dil", 2, 2, "conv3d_net", "conditional", -1, -1, False, True),)
31 |
32 |
class NextFrameGlowConv3DTest(nfg_test_utils.NextFrameGlowTest,
                              parameterized.TestCase):
  """Train/decode tests with a conv3d latent-network encoder."""

  @parameterized.named_parameters(*conv3d_net_hparams)
  def testGlowTrainAndDecode(self, in_frames=1, out_frames=1,
                             latent_dist_encoder="pointwise",
                             gen_mode="conditional", pretrain_steps=-1,
                             num_train_frames=-1, cond_first_frame=False,
                             apply_dilations=False, activation="relu"):
    """Forwards every parameterized setting to GlowTrainAndDecode."""
    settings = dict(
        in_frames=in_frames,
        out_frames=out_frames,
        latent_dist_encoder=latent_dist_encoder,
        gen_mode=gen_mode,
        pretrain_steps=pretrain_steps,
        num_train_frames=num_train_frames,
        cond_first_frame=cond_first_frame,
        apply_dilations=apply_dilations,
        activation=activation)
    self.GlowTrainAndDecode(**settings)
48 |
49 |
# Run this module's tests with TensorFlow's test runner when executed directly.
if __name__ == "__main__":
  tf.test.main()
52 |
--------------------------------------------------------------------------------
/tensor2tensor/models/video/nfg_conv_lstm_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Test when the latent-network encoder is a conv-lstm."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl.testing import parameterized
23 | from tensor2tensor.models.video import nfg_test_utils
24 | import tensorflow.compat.v1 as tf
25 |
# Cases for parameterized.named_parameters: the first field is the test-name
# suffix; the rest map positionally onto testGlowTrainAndDecode's arguments
# (in_frames, out_frames, latent_dist_encoder, gen_mode, pretrain_steps).
# NOTE(review): the "in_3_out_2" suffix does not match in_frames=2,
# out_frames=1 -- confirm which is intended.
conv_lstm_hparams = (
    ("in_3_out_2_lstm", 2, 1, "conv_lstm", "conditional", -1),
    ("lstm_pretrain", 2, 1, "conv_lstm", "conditional", 50000))
29 |
30 |
class NextFrameGlowConv3DTest(nfg_test_utils.NextFrameGlowTest,
                              parameterized.TestCase):
  """Train/decode tests with a conv-LSTM latent-network encoder.

  NOTE(review): the class name says Conv3D, but the cases here use the
  "conv_lstm" encoder.
  """

  @parameterized.named_parameters(*conv_lstm_hparams)
  def testGlowTrainAndDecode(self, in_frames=1, out_frames=1,
                             latent_dist_encoder="pointwise",
                             gen_mode="conditional", pretrain_steps=-1,
                             num_train_frames=-1, cond_first_frame=False):
    """Forwards every parameterized setting to GlowTrainAndDecode."""
    settings = dict(
        in_frames=in_frames,
        out_frames=out_frames,
        latent_dist_encoder=latent_dist_encoder,
        gen_mode=gen_mode,
        pretrain_steps=pretrain_steps,
        num_train_frames=num_train_frames,
        cond_first_frame=cond_first_frame)
    self.GlowTrainAndDecode(**settings)
44 |
45 |
# Run this module's tests with TensorFlow's test runner when executed directly.
if __name__ == "__main__":
  tf.test.main()
48 |
--------------------------------------------------------------------------------
/tensor2tensor/models/video/nfg_conv_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Test when the latent-network encoder is a 2-D conv."""
17 |
18 | from absl.testing import parameterized
19 | from tensor2tensor.models.video import nfg_test_utils
20 | import tensorflow.compat.v1 as tf
21 |
# Cases for parameterized.named_parameters: the first field is the test-name
# suffix; the rest map positionally onto testGlowTrainAndDecode's arguments
# (in_frames, out_frames, latent_dist_encoder, gen_mode, pretrain_steps,
# num_train_frames, cond_first_frame). Omitted trailing fields use defaults.
conv_net_hparams = (
    ("in_3_out_2_conv", 3, 1, "conv_net", "conditional"),
    ("conv_net_cond_first", 2, 2, "conv_net", "conditional", -1, 3, True),)
25 |
26 |
class NextFrameGlowConvTest(nfg_test_utils.NextFrameGlowTest,
                            parameterized.TestCase):
  """Train/decode tests with a 2-D conv latent-network encoder."""

  @parameterized.named_parameters(*conv_net_hparams)
  def testGlowTrainAndDecode(self, in_frames=1, out_frames=1,
                             latent_dist_encoder="pointwise",
                             gen_mode="conditional", pretrain_steps=-1,
                             num_train_frames=-1, cond_first_frame=False):
    """Forwards every parameterized setting to GlowTrainAndDecode."""
    settings = dict(
        in_frames=in_frames,
        out_frames=out_frames,
        gen_mode=gen_mode,
        latent_dist_encoder=latent_dist_encoder,
        pretrain_steps=pretrain_steps,
        num_train_frames=num_train_frames,
        cond_first_frame=cond_first_frame)
    self.GlowTrainAndDecode(**settings)
40 |
41 |
# Run this module's tests with TensorFlow's test runner when executed directly.
if __name__ == "__main__":
  tf.test.main()
44 |
--------------------------------------------------------------------------------
/tensor2tensor/models/video/nfg_uncond_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for unconditional glow."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl.testing import parameterized
23 | from tensor2tensor.models.video import nfg_test_utils
24 | import tensorflow.compat.v1 as tf
25 |
# Cases for parameterized.named_parameters: the first field is the test-name
# suffix; the rest map positionally onto testGlowTrainAndDecode's arguments
# (in_frames, out_frames, latent_dist_encoder, gen_mode, pretrain_steps,
# num_train_frames). Omitted trailing fields use that method's defaults.
uncond_hparams = (
    ("in_1_out_1", 1, 1, "pointwise", "conditional"),
    ("uncond", 1, 3, "pointwise", "unconditional", -1, 1),)
29 |
30 |
class NfgUncondTest(nfg_test_utils.NextFrameGlowTest, parameterized.TestCase):
  """Train/decode tests for conditional and unconditional generation modes."""

  @parameterized.named_parameters(*uncond_hparams)
  def testGlowTrainAndDecode(self, in_frames=1, out_frames=1,
                             latent_dist_encoder="pointwise",
                             gen_mode="conditional", pretrain_steps=-1,
                             num_train_frames=-1, cond_first_frame=False):
    """Forwards every parameterized setting to GlowTrainAndDecode."""
    settings = dict(
        in_frames=in_frames,
        out_frames=out_frames,
        latent_dist_encoder=latent_dist_encoder,
        gen_mode=gen_mode,
        pretrain_steps=pretrain_steps,
        num_train_frames=num_train_frames,
        cond_first_frame=cond_first_frame)
    self.GlowTrainAndDecode(**settings)
43 |
44 |
# Run this module's tests with TensorFlow's test runner when executed directly.
if __name__ == "__main__":
  tf.test.main()
47 |
--------------------------------------------------------------------------------
/tensor2tensor/models/video/savp_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Basic tests for SAVP model."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.models.video import savp
23 | from tensor2tensor.models.video import savp_params
24 | from tensor2tensor.models.video import tests_utils
25 |
26 |
27 | import tensorflow.compat.v1 as tf
28 |
29 |
class NextFrameTest(tests_utils.BaseNextFrameTest):
  """Smoke tests for NextFrameSAVP in its VAE, GAN and VAE+GAN variants."""

  def _savp_hparams(self, use_vae, use_gan):
    """Returns fresh `next_frame_savp` hparams with the given switches."""
    hp = savp_params.next_frame_savp()
    hp.use_vae = use_vae
    hp.use_gan = use_gan
    return hp

  def testSavpVAE(self):
    hp = self._savp_hparams(use_vae=True, use_gan=False)
    # The same hparams object is shared by both checks.
    for check in (self.TestOnVariousInputOutputSizes,
                  self.TestOnVariousUpSampleLayers):
      check(hp, savp.NextFrameSAVP, 1)

  def testSavpGAN(self):
    hp = self._savp_hparams(use_vae=False, use_gan=True)
    # First with the hparams' own gan_optimization, then with "sequential".
    self.TestVideoModel(7, 5, hp, savp.NextFrameSAVP, 1)
    hp.gan_optimization = "sequential"
    self.TestVideoModel(7, 5, hp, savp.NextFrameSAVP, 1)

  def testSavpGANVAE(self):
    hp = self._savp_hparams(use_vae=True, use_gan=True)
    self.TestVideoModel(7, 5, hp, savp.NextFrameSAVP, 1)

  def testInvalidVAEGANCombinations(self):
    # Disabling both the VAE and the GAN is expected to raise ValueError.
    hp = self._savp_hparams(use_vae=False, use_gan=False)
    self.assertRaises(ValueError, self.TestVideoModel,
                      7, 5, hp, savp.NextFrameSAVP, 1)
62 |
# Run this module's tests with TensorFlow's test runner when executed directly.
if __name__ == "__main__":
  tf.test.main()
65 |
--------------------------------------------------------------------------------
/tensor2tensor/models/video/sv2p_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Basic tests for SV2P model."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.models.video import sv2p
23 | from tensor2tensor.models.video import sv2p_params
24 | from tensor2tensor.models.video import tests_utils
25 |
26 | import tensorflow.compat.v1 as tf
27 |
28 |
class NextFrameTest(tests_utils.BaseNextFrameTest):
  """Smoke tests for the SV2P next-frame models."""

  def _run_with_actions_and_rewards(self, internal_loss):
    """Runs TestWithActionAndRewards with hp.internal_loss = internal_loss."""
    hp = sv2p_params.next_frame_sv2p()
    hp.internal_loss = internal_loss
    self.TestWithActionAndRewards(hp, sv2p.NextFrameSv2p, 1, False)

  def testSv2p(self):
    hp = sv2p_params.next_frame_sv2p()
    self.TestOnVariousInputOutputSizes(hp, sv2p.NextFrameSv2p, 1, False)

  def testSv2pWithActions(self):
    hp = sv2p_params.next_frame_sv2p()
    self.TestWithActions(hp, sv2p.NextFrameSv2p, 1, False)

  def testSv2pWithActionsAndRewards(self):
    self._run_with_actions_and_rewards(internal_loss=True)

  def testSv2pWithActionsAndRewardsExternalLoss(self):
    self._run_with_actions_and_rewards(internal_loss=False)

  def testSv2pTwoFrames(self):
    hp = sv2p_params.next_frame_sv2p()
    self.TestOnVariousInputOutputSizes(
        hp, sv2p.NextFrameSv2pTwoFrames, 1, False)
69 |
70 |
# Run this module's tests with TensorFlow's test runner when executed directly.
if __name__ == "__main__":
  tf.test.main()
73 |
--------------------------------------------------------------------------------
/tensor2tensor/models/xception_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Xception tests."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import numpy as np
23 |
24 | from tensor2tensor.data_generators import problem_hparams
25 | from tensor2tensor.layers import modalities
26 | from tensor2tensor.models import xception
27 |
28 | import tensorflow.compat.v1 as tf
29 | from tensorflow.compat.v1 import estimator as tf_estimator
30 |
31 |
class XceptionTest(tf.test.TestCase):
  """Checks the logits shape of a tiny Xception classifier."""

  def _test_xception(self, img_size):
    """Runs one TRAIN-mode forward pass on random img_size x img_size images."""
    vocab_size, batch_size = 9, 3
    images = np.random.randint(256, size=(batch_size, img_size, img_size, 3))
    labels = np.random.randint(1, high=vocab_size, size=(batch_size, 1, 1, 1))

    hparams = xception.xception_tiny()
    p_hparams = problem_hparams.test_problem_hparams(
        vocab_size, vocab_size, hparams)
    p_hparams.modality["inputs"] = modalities.ModalityType.IMAGE
    p_hparams.modality["targets"] = modalities.ModalityType.CLASS_LABEL

    with self.test_session() as session:
      features = {
          "inputs": tf.constant(images, dtype=tf.int32),
          "targets": tf.constant(labels, dtype=tf.int32),
      }
      model = xception.Xception(hparams, tf_estimator.ModeKeys.TRAIN, p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      out = session.run(logits)
    self.assertEqual(out.shape, (batch_size, 1, 1, 1, vocab_size))

  def testXceptionSmallImage(self):
    self._test_xception(img_size=9)

  def testXceptionLargeImage(self):
    self._test_xception(img_size=256)
63 |
64 |
# Run this module's tests with TensorFlow's test runner when executed directly.
if __name__ == "__main__":
  tf.test.main()
67 |
--------------------------------------------------------------------------------
/tensor2tensor/problems.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Access T2T Problems."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from tensor2tensor.data_generators import all_problems
22 | from tensor2tensor.utils import registry
23 |
24 |
def problem(name):
  """Looks up `name` in the T2T registry and returns the result."""
  found = registry.problem(name)
  return found
27 |
28 |
def available():
  """Returns the registry's list of base problem names."""
  names = registry.list_base_problems()
  return names
31 |
32 |
33 | all_problems.import_modules(all_problems.ALL_MODULES)
34 |
--------------------------------------------------------------------------------
/tensor2tensor/problems_colab.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Access T2T Problems."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from tensor2tensor.data_generators import all_problems
22 | from tensor2tensor.utils import registry
23 |
24 |
def problem(name):
  """Looks up `name` in the T2T registry and returns the result."""
  found = registry.problem(name)
  return found
27 |
28 |
def available():
  """Returns the registry's problem names as a new sorted list."""
  names = list(registry.list_problems())
  names.sort()
  return names
31 |
32 |
# Import problem modules: take a copy of all_problems.MODULES (kept as the
# module-level `_modules`) and import each listed module at load time.
_modules = list(all_problems.MODULES)

all_problems.import_modules(_modules)
37 |
--------------------------------------------------------------------------------
/tensor2tensor/problems_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """tensor2tensor.problems test."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from tensor2tensor import problems
22 |
23 | import tensorflow.compat.v1 as tf
24 |
25 |
class ProblemsTest(tf.test.TestCase):
  """Checks that tensor2tensor.problems can be imported."""

  def testImport(self):
    imported_module = problems
    self.assertIsNotNone(imported_module)
30 |
# Run this module's tests with TensorFlow's test runner when executed directly.
if __name__ == "__main__":
  tf.test.main()
33 |
--------------------------------------------------------------------------------
/tensor2tensor/rl/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/tensor2tensor/rl/datagen_with_agent.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Generate trajectories to disk with random or ckpt agent.
17 |
18 | TODO: Usage
19 | """
20 |
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 |
25 | from tensor2tensor.data_generators import gym_env
26 | from tensor2tensor.utils import registry
27 |
28 | import tensorflow.compat.v1 as tf
29 |
# Command-line flags for this script, read through FLAGS in main().
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("data_dir", "", "Data directory.")
flags.DEFINE_string("tmp_dir", "/tmp/t2t_datagen",
                    "Temporary storage directory.")
# No default: main() needs a game name to build the problem name.
flags.DEFINE_string("game", None, "Atari game to generate data for.")
flags.DEFINE_integer("num_env_steps", 5000, "Number of steps to roll out.")
flags.DEFINE_boolean("eval", False, "Whether to run in eval mode.")
39 |
40 |
def main(_):
  """Generates trajectories for FLAGS.game and writes them to FLAGS.data_dir.

  Creates the data/tmp directories, registers the game's problem if it is not
  registered yet, rolls out FLAGS.num_env_steps environment steps, and logs the
  mean reward per "done" when at least one was recorded.

  Raises:
    ValueError: if --game is not set.
  """
  # Without a game name, the problem name would end in "None" and
  # register_game would be called with None, which fails later and obscurely.
  if not FLAGS.game:
    raise ValueError("--game must be set to the name of an Atari game.")

  tf.gfile.MakeDirs(FLAGS.data_dir)
  tf.gfile.MakeDirs(FLAGS.tmp_dir)

  # Create problem if not already defined
  problem_name = "gym_discrete_problem_with_agent_on_%s" % FLAGS.game
  if problem_name not in registry.Registries.problems:
    gym_env.register_game(FLAGS.game)

  # Generate
  tf.logging.info("Running %s environment for %d steps for trajectories.",
                  FLAGS.game, FLAGS.num_env_steps)
  problem = registry.problem(problem_name)
  problem.settable_num_steps = FLAGS.num_env_steps
  problem.settable_eval_phase = FLAGS.eval
  problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)

  # Log stats; the guard avoids dividing by zero when nothing finished.
  if problem.statistics.number_of_dones:
    mean_reward = (problem.statistics.sum_of_rewards /
                   problem.statistics.number_of_dones)
    tf.logging.info("Mean reward: %.2f, Num dones: %d",
                    mean_reward,
                    problem.statistics.number_of_dones)
66 |
67 |
# Parse flags and run main() when executed as a script.
if __name__ == "__main__":
  tf.app.run(main)
70 |
--------------------------------------------------------------------------------
/tensor2tensor/rl/envs/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/tensor2tensor/rl/evaluator_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests the evaluator."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from tensor2tensor.rl import evaluator
22 | from tensor2tensor.utils import registry
23 |
24 | import tensorflow.compat.v1 as tf
25 |
26 |
class EvalTest(tf.test.TestCase):
  """Smoke test for evaluator.evaluate."""

  def test_evaluate_pong_random_agent(self):
    """Evaluates a random agent in "agent_real" mode using tiny hparams."""
    loop_hp = registry.hparams("rlmb_tiny")
    planner_hp = registry.hparams("planner_tiny")
    out_dir = tf.test.get_temp_dir()
    evaluator.evaluate(
        loop_hp,
        planner_hp,
        out_dir,
        out_dir,
        out_dir,
        agent_type="random",
        eval_mode="agent_real",
        eval_with_learner=False,
        log_every_steps=None,
        debug_video_path="",
    )
38 |
39 |
40 | if __name__ == "__main__":
41 | tf.test.main()
42 |
--------------------------------------------------------------------------------
/tensor2tensor/rl/policy_learner.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Unified interface for different RL algorithms."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 |
class PolicyLearner(object):
  """Common API that concrete policy learners are meant to implement.

  The constructor only records its arguments as attributes; `train` and
  `evaluate` are abstract here and raise `NotImplementedError` until a
  subclass overrides them.
  """

  def __init__(
      self, frame_stack_size, base_event_dir, agent_model_dir, total_num_epochs
  ):
    """Stores the shared learner configuration.

    Args:
      frame_stack_size: stored verbatim as `self.frame_stack_size`.
      base_event_dir: stored verbatim as `self.base_event_dir`.
      agent_model_dir: stored verbatim as `self.agent_model_dir`.
      total_num_epochs: stored verbatim as `self.total_num_epochs`.
    """
    self.frame_stack_size = frame_stack_size
    self.base_event_dir = base_event_dir
    self.agent_model_dir = agent_model_dir
    self.total_num_epochs = total_num_epochs

  def train(
      self,
      env_fn,
      hparams,
      simulated,
      save_continuously,
      epoch,
      sampling_temp=1.0,
      num_env_steps=None,
      env_step_multiplier=1,
      eval_env_fn=None,
      report_fn=None
  ):
    """Trains the policy; every concrete learner must override this.

    The argument semantics are defined by the subclasses; this base class
    only fixes the shared signature and its defaults.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError

  def evaluate(self, env_fn, hparams, sampling_temp):
    """Evaluates the policy; every concrete learner must override this.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError
52 |
--------------------------------------------------------------------------------
/tensor2tensor/rl/trainer_model_based_agent_only.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Training of model-based RL agent assuming a fully trained world model.
17 |
18 | Example invocation:
19 |
20 | python -m tensor2tensor.rl.trainer_model_based_agent_only \
21 | --loop_hparams_set=rl_modelrl_base \
22 | --world_model_dir=$HOME/world_model/ \
23 | --data_dir=$HOME/data/ \
24 | --output_dir=$HOME/ppo_agent_only/ \
25 | """
26 | from __future__ import absolute_import
27 | from __future__ import division
28 | from __future__ import print_function
29 |
30 | from tensor2tensor.bin import t2t_trainer # pylint: disable=unused-import
31 | from tensor2tensor.data_generators import gym_env
32 | from tensor2tensor.rl import trainer_model_based
33 | from tensor2tensor.rl import trainer_model_based_params
34 |
35 |
36 | import tensorflow.compat.v1 as tf
37 |
38 |
flags = tf.flags
FLAGS = flags.FLAGS

# Location of the already-trained world model that the agent is trained in.
flags.DEFINE_string(
    name="world_model_dir",
    default="",
    help="Directory containing checkpoints of the world model.",
)
44 |
45 |
def get_simulated_problem_name(game):
  """Returns the simulated-problem name used for `game`.

  Games listed in `gym_env.ATARI_GAMES` get the "_deterministic-v4" mode
  suffix appended before being embedded in the problem name; any other game
  name is embedded unchanged.
  """
  game_with_mode = (
      game + "_deterministic-v4" if game in gym_env.ATARI_GAMES else game)
  return "gym_simulated_discrete_problem_with_agent_on_%s" % game_with_mode
51 |
52 |
def main(_):
  """Trains only the agent, inside an already-trained world model.

  Both the agent model and the events are written under --output_dir; the
  world model is read from --world_model_dir and --data_dir is only needed
  for initial frames.
  """
  loop_hparams = trainer_model_based_params.create_loop_hparams()
  simulated_problem = get_simulated_problem_name(loop_hparams.game)
  output_dir = FLAGS.output_dir

  # NOTE(review): `0` is passed positionally right before `epoch=0`; confirm
  # this still matches the `trainer_model_based.train_agent` signature.
  trainer_model_based.train_agent(
      simulated_problem,
      output_dir,  # agent model dir
      output_dir,  # event dir
      FLAGS.world_model_dir,
      FLAGS.data_dir,  # only required for initial frames
      loop_hparams,
      0,
      epoch=0,
      is_final_epoch=True)


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
75 |
--------------------------------------------------------------------------------
/tensor2tensor/rl/trainer_model_based_recurrent_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tiny run of trainer_model_based. Smoke test."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from tensor2tensor.rl import trainer_model_based
22 |
23 | import tensorflow.compat.v1 as tf
24 |
25 | FLAGS = tf.flags.FLAGS
26 |
27 |
class ModelRLExperimentRecurrentTest(tf.test.TestCase):
  """Smoke test: one tiny model-based run with `rlmb_tiny_recurrent`."""

  def test_basic_recurrent(self):
    # "train" skips evaluation during world-model training to keep this fast.
    FLAGS.schedule = "train"
    FLAGS.loop_hparams_set = "rlmb_tiny_recurrent"
    FLAGS.output_dir = tf.test.get_temp_dir()
    trainer_model_based.main(None)


if __name__ == "__main__":
  tf.test.main()
39 |
--------------------------------------------------------------------------------
/tensor2tensor/rl/trainer_model_based_stochastic_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tiny run of trainer_model_based with stochastic model. Smoke test."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from tensor2tensor.rl import trainer_model_based
22 |
23 | import tensorflow.compat.v1 as tf
24 |
25 | FLAGS = tf.flags.FLAGS
26 |
27 |
class ModelRLExperimentStochasticTest(tf.test.TestCase):
  """Smoke test: one tiny model-based run with `rlmb_tiny_stochastic`."""

  def test_basic_stochastic(self):
    # "train" skips evaluation during world-model training to keep this fast.
    FLAGS.schedule = "train"
    FLAGS.loop_hparams_set = "rlmb_tiny_stochastic"
    FLAGS.output_dir = tf.test.get_temp_dir()
    trainer_model_based.main(None)


if __name__ == "__main__":
  tf.test.main()
39 |
--------------------------------------------------------------------------------
/tensor2tensor/rl/trainer_model_based_sv2p_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tiny run of trainer_model_based with the SV2P stochastic model. Smoke test."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from tensor2tensor.rl import trainer_model_based
22 |
23 | import tensorflow.compat.v1 as tf
24 |
25 | FLAGS = tf.flags.FLAGS
26 |
27 |
class ModelRLExperimentSv2pTest(tf.test.TestCase):
  """Smoke test: one tiny model-based run with `rlmb_tiny_sv2p`."""

  def test_sv2p(self):
    # Unlike the sibling smoke tests, FLAGS.schedule is left untouched here.
    FLAGS.loop_hparams_set = "rlmb_tiny_sv2p"
    FLAGS.output_dir = tf.test.get_temp_dir()
    trainer_model_based.main(None)


if __name__ == "__main__":
  tf.test.main()
38 |
--------------------------------------------------------------------------------
/tensor2tensor/rl/trainer_model_based_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tiny run of trainer_model_based. Smoke test."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from tensor2tensor.rl import trainer_model_based
22 |
23 | import tensorflow.compat.v1 as tf
24 |
25 | FLAGS = tf.flags.FLAGS
26 |
27 |
class ModelRLExperimentTest(tf.test.TestCase):
  """Smoke tests for trainer_model_based with tiny loop hparams sets."""

  def _test_hparams_skip_evaluation(self, hparams_set):
    """Runs `trainer_model_based.main` with `hparams_set`, training only."""
    FLAGS.schedule = "train"  # skip evaluation for world model training
    FLAGS.loop_hparams_set = hparams_set
    FLAGS.output_dir = tf.test.get_temp_dir()
    trainer_model_based.main(None)

  def test_basic(self):
    self._test_hparams_skip_evaluation("rlmb_tiny")

  # TODO(kozak): enable when it works.
  # def test_dqn_basic(self):
  #   self._test_hparams_skip_evaluation("rlmb_dqn_tiny")


if __name__ == "__main__":
  tf.test.main()
46 |
--------------------------------------------------------------------------------
/tensor2tensor/rl/trainer_model_free_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests of basic flow of collecting trajectories and training PPO."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from tensor2tensor.rl import trainer_model_free
22 | from tensor2tensor.utils import registry
23 |
24 | import tensorflow.compat.v1 as tf
25 |
26 | FLAGS = tf.flags.FLAGS
27 |
28 |
class TrainTest(tf.test.TestCase):
  """Short model-free training runs with the tiny PPO and DQN hparams sets."""

  def _test_hparams_set(self, hparams_set):
    """Trains with the registered `hparams_set`, writing into a temp dir."""
    FLAGS.output_dir = tf.test.get_temp_dir()
    loaded_hparams = registry.hparams(hparams_set)
    trainer_model_free.train(
        loaded_hparams, FLAGS.output_dir, env_problem_name=None)

  def test_train_pong(self):
    self._test_hparams_set("rlmf_tiny")

  def test_train_pong_dqn(self):
    self._test_hparams_set("rlmf_dqn_tiny")


if __name__ == "__main__":
  tf.test.main()
46 |
--------------------------------------------------------------------------------
/tensor2tensor/rl/trainer_model_free_tictactoe_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests of basic flow of collecting trajectories and training PPO."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.rl import trainer_model_free
23 | from tensor2tensor.utils import registry
24 |
25 | import tensorflow.compat.v1 as tf
26 |
27 | FLAGS = tf.flags.FLAGS
28 |
29 |
class TrainerModelFreeTicTacToeTest(tf.test.TestCase):
  """Model-free training smoke test on the tic-tac-toe env problem."""

  def test_train_tictactoe(self):
    hparams = registry.hparams("rlmf_tictactoe")

    # Shrink the run so the test stays quick.
    hparams.batch_size = 2
    hparams.epochs_num = 100
    hparams.eval_every_epochs = 25
    hparams.eval_sampling_temps = [0.0, 1.0]

    # These two are added rather than overridden: `add_hparam` requires that
    # they are not already defined by the base set.
    hparams.add_hparam("ppo_epochs_num", 2)
    hparams.add_hparam("ppo_epoch_length", 3)

    FLAGS.env_problem_name = "tic_tac_toe_env_problem"
    FLAGS.output_dir = tf.test.get_temp_dir()
    trainer_model_free.train(hparams, FLAGS.output_dir, FLAGS.env_problem_name)


if __name__ == "__main__":
  tf.test.main()
49 |
--------------------------------------------------------------------------------
/tensor2tensor/serving/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/tensor2tensor/test_data/example_usr_dir/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Example T2T user directory."""
17 | from . import my_submodule
18 |
--------------------------------------------------------------------------------
/tensor2tensor/test_data/example_usr_dir/requirements.txt:
--------------------------------------------------------------------------------
1 | gutenberg
2 |
--------------------------------------------------------------------------------
/tensor2tensor/test_data/transformer_test_ckpt/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "model.ckpt-1"
2 | all_model_checkpoint_paths: "model.ckpt-1"
3 |
--------------------------------------------------------------------------------
/tensor2tensor/test_data/transformer_test_ckpt/flags.txt:
--------------------------------------------------------------------------------
1 | --eval_steps=1
2 | --hparams_range=
3 | --t2t_usr_dir=
4 | --enable_graph_rewriter=False
5 | --sync=False
6 | --eval_run_autoregressive=False
7 | --eval_use_test_set=False
8 | --worker_id=0
9 | --eval_early_stopping_metric_minimize=True
10 | --worker_replicas=1
11 | --random_seed=1234
12 | --worker_gpu_memory_fraction=0.95
13 | --train_steps=1
14 | --iterations_per_loop=1000
15 | --registry_help=False
16 | --worker_gpu=1
17 | --keep_checkpoint_max=20
18 | --save_checkpoints_secs=0
19 | --gpu_order=
20 | --master=
21 | --generate_data=False
22 | --local_eval_frequency=2000
23 | --export_saved_model=False
24 | --eval_early_stopping_steps=None
25 | --output_dir=/tmp/oss_train
26 | --profile=False
27 | --ps_job=/job:ps
28 | --tmp_dir=/tmp/t2t_datagen
29 | --schedule=continuous_train_and_eval
30 | --problem=translate_ende_wmt8k
31 | --hparams=
32 | --use_tpu=False
33 | --eval_early_stopping_metric_delta=0.1
34 | --ps_gpu=0
35 | --keep_checkpoint_every_n_hours=10000
36 | --decode_hparams=
37 | --tfdbg=False
38 | --data_dir=~/t2t/data
39 | --ps_replicas=0
40 | --eval_early_stopping_metric=loss
41 | --log_device_placement=False
42 | --hparams_set=transformer_test
43 | --dbgprofile=False
44 | --timit_paths=
45 | --tpu_num_shards=8
46 | --locally_shard_to_cpu=False
47 | --worker_job=/job:localhost
48 | --model=transformer
49 | --parsing_path=
50 |
--------------------------------------------------------------------------------
/tensor2tensor/test_data/transformer_test_ckpt/hparams.json:
--------------------------------------------------------------------------------
1 | {"daisy_chain_variables": true, "optimizer_adam_beta1": 0.9, "scheduled_sampling_prob": 0.0, "num_hidden_layers": 2, "moe_loss_coef": 0.01, "max_target_seq_length": 0, "clip_grad_norm": 0.0, "pos": "timing", "scheduled_sampling_gold_mixin_prob": 0.5, "initializer": "uniform_unit_scaling", "grad_noise_scale": 0.0, "optimizer_momentum_momentum": 0.9, "nbr_decoder_problems": 1, "attention_key_channels": 0, "eval_drop_long_sequences": false, "learning_rate_cosine_cycle_steps": 250000, "prepend_mode": "none", "weight_decay": 0.0, "symbol_modality_skip_top": false, "weight_noise": 0.0, "target_modality": "default", "attention_dropout": 0.1, "parameter_attention_value_channels": 0, "factored_logits": false, "relu_dropout": 0.1, "no_data_parallelism": false, "layer_preprocess_sequence": "n", "sampling_method": "argmax", "learning_rate": 0.2, "num_heads": 2, "max_length": 256, "summarize_grads": false, "attention_value_channels": 0, "num_encoder_layers": 0, "label_smoothing": 0.1, "use_fixed_batch_size": false, "optimizer": "adam", "moe_k": 2, "self_attention_type": "dot_product", "learning_rate_decay_scheme": "noam", "sampling_temp": 1.0, "kernel_height": 3, "use_pad_remover": true, "batch_size": 4096, "max_relative_position": 0, "force_full_predict": false, "min_length_bucket": 8, "layer_prepostprocess_dropout": 0.1, "eval_run_autoregressive": false, "shared_embedding_and_softmax_weights": true, "symbol_modality_num_shards": 16, "dropout": 0.2, "compress_steps": 0, "parameter_attention_key_channels": 0, "length_bucket_step": 1.1, "kernel_width": 1, "hidden_size": 16, "num_decoder_layers": 0, "input_modalities": "default", "filter_size": 8, "optimizer_adam_beta2": 0.98, "scheduled_sampling_warmup_steps": 50000, "norm_type": "layer", "min_length": 0, "moe_num_experts": 64, "multiply_embedding_mode": "sqrt_depth", "max_input_seq_length": 0, "learning_rate_warmup_steps": 8000, "proximity_bias": false, "ffn_layer": "dense_relu_dense", "initializer_gain": 1.0, 
"layer_postprocess_sequence": "da", "moe_hidden_sizes": "2048", "optimizer_adam_epsilon": 1e-09, "norm_epsilon": 1e-06}
2 |
--------------------------------------------------------------------------------
/tensor2tensor/test_data/transformer_test_ckpt/model.ckpt-1.data-00000-of-00002:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/tensor2tensor/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/test_data/transformer_test_ckpt/model.ckpt-1.data-00000-of-00002
--------------------------------------------------------------------------------
/tensor2tensor/test_data/transformer_test_ckpt/model.ckpt-1.data-00001-of-00002:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/tensor2tensor/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/test_data/transformer_test_ckpt/model.ckpt-1.data-00001-of-00002
--------------------------------------------------------------------------------
/tensor2tensor/test_data/transformer_test_ckpt/model.ckpt-1.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/tensor2tensor/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/test_data/transformer_test_ckpt/model.ckpt-1.index
--------------------------------------------------------------------------------
/tensor2tensor/test_data/transformer_test_ckpt/model.ckpt-1.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/tensor2tensor/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/test_data/transformer_test_ckpt/model.ckpt-1.meta
--------------------------------------------------------------------------------
/tensor2tensor/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/tensor2tensor/utils/adafactor_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for adafactor."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.utils import adafactor
23 |
24 | import tensorflow as tf
25 |
26 |
class AdafactorTest(tf.test.TestCase):
  """Tests for `adafactor.AdafactorOptimizer`."""

  def testCallableLearningRate(self):
    """`apply_gradients` runs when `learning_rate` is a zero-arg callable."""
    optimizer = adafactor.AdafactorOptimizer(learning_rate=lambda: 0.01)
    variables = [tf.Variable([1., 2.]), tf.Variable([3., 4.])]
    with tf.GradientTape() as tape:
      tape.watch(variables)
      loss = variables[0] * variables[1]
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(tuple(zip(gradients, variables)))


if __name__ == '__main__':
  tf.test.main()
45 |
--------------------------------------------------------------------------------
/tensor2tensor/utils/compute_video_metrics.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Computes and saves the metrics for video prediction and generation."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 |
24 | from six.moves import range
25 | from tensor2tensor.bin import t2t_decoder
26 | from tensor2tensor.utils import video_metrics
27 | import tensorflow.compat.v1 as tf
28 |
29 |
30 | FLAGS = tf.flags.FLAGS
31 |
32 |
def main(_):
  """Computes and saves video metrics over all decode runs.

  Builds one `decode_%05d` path under --output_dir per decode (as many as
  `num_decodes` in the decode hparams) and passes them, together with the
  problem's frame shape and target-frame count, to
  `video_metrics.compute_and_save_video_metrics`.
  """
  hparams = t2t_decoder.create_hparams()
  video_problem = hparams.problem
  frame_shape = [
      video_problem.frame_height,
      video_problem.frame_width,
      video_problem.num_channels,
  ]
  num_decodes = t2t_decoder.create_decode_hparams().num_decodes

  decode_dirs = [
      os.path.join(FLAGS.output_dir, "decode_%05d" % index)
      for index in range(num_decodes)
  ]

  video_metrics.compute_and_save_video_metrics(
      decode_dirs, FLAGS.problem, hparams.video_num_target_frames,
      frame_shape)


if __name__ == "__main__":
  tf.app.run(main)
55 |
--------------------------------------------------------------------------------
/tensor2tensor/utils/diet_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for diet variables (tensor2tensor.utils.diet)."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | from tensor2tensor.utils import diet
22 |
23 | import tensorflow.compat.v1 as tf
24 |
25 |
class DietVarTest(tf.test.TestCase):
  """Tests for `diet.fn_with_diet_vars`."""

  def testDiet(self):
    """Every global variable changes after a few backward passes."""
    params = diet.diet_adam_optimizer_params()

    @diet.fn_with_diet_vars(params)
    def model_fn(x):
      y = tf.layers.dense(x, 10, use_bias=False)
      return y

    @diet.fn_with_diet_vars(params)
    def model_fn2(x):
      y = tf.layers.dense(x, 10, use_bias=False)
      return y

    x = tf.random_uniform((10, 10))
    y = model_fn(x) + 10.
    y = model_fn2(y) + 10.
    # Gradients are taken only w.r.t. the input; the test expects the diet
    # variables to be updated as a side effect of this backward pass
    # (presumably by the diet wrapper itself).
    grads = tf.gradients(y, [x])
    with tf.control_dependencies(grads):
      incr_step = tf.assign_add(tf.train.get_or_create_global_step(), 1)

    train_op = tf.group(incr_step, *grads)
    # `test_session` is deprecated; `cached_session` is its drop-in successor.
    with self.cached_session() as sess:
      sess.run(tf.global_variables_initializer())
      orig_vals = sess.run(tf.global_variables())
      for _ in range(10):
        sess.run(train_op)
      new_vals = sess.run(tf.global_variables())

    all_vars = tf.global_variables()
    # Guard against the loop below passing vacuously.
    self.assertNotEmpty(all_vars)
    for var, old, new in zip(all_vars, orig_vals, new_vals):
      self.assertNotAllClose(old, new, msg="%s did not change" % var.name)


if __name__ == "__main__":
  tf.test.main()
68 |
--------------------------------------------------------------------------------
/tensor2tensor/utils/get_cnndm_rouge.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Tokenizes gold and model-generated summaries with the Moses tokenizer and
# scores them with get_rouge.py.
#
# Usage: get_cnndm_rouge.sh MOSESDECODER_DIR TARGETS_FILE DECODES_FILE
#
# Writes TARGETS_FILE.tok and DECODES_FILE.tok next to the inputs.

set -euo pipefail

if [[ $# -lt 3 ]]; then
  echo "Usage: $0 MOSESDECODER_DIR TARGETS_FILE DECODES_FILE" >&2
  exit 1
fi

# Path to moses dir
mosesdecoder=$1

# Path to file containing gold summaries, one per line
targets_file=$2
# Path to file containing model generated summaries, one per line
decodes_file=$3

# Prefer get_rouge.py in the current directory (the original behavior); fall
# back to a copy next to this script so it can be run from anywhere.
rouge_script=get_rouge.py
if [[ ! -f "$rouge_script" ]]; then
  rouge_script="$(dirname "$0")/get_rouge.py"
fi

# Tokenize.
perl "$mosesdecoder/scripts/tokenizer/tokenizer.perl" -l en < "$targets_file" > "$targets_file.tok"
perl "$mosesdecoder/scripts/tokenizer/tokenizer.perl" -l en < "$decodes_file" > "$decodes_file.tok"

# Get rouge scores
python "$rouge_script" --decodes_filename "$decodes_file.tok" --targets_filename "$targets_file.tok"
17 |
--------------------------------------------------------------------------------
/tensor2tensor/utils/get_ende_bleu.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Computes tokenized, compound-split BLEU of German (en->de) decodes.
#
# Usage: get_ende_bleu.sh DECODES_FILE
#
# Environment overrides (defaults match the previous hard-coded values):
#   MOSESDECODER      path to a mosesdecoder checkout (default: ~/mosesdecoder)
#   TOK_GOLD_TARGETS  tokenized gold targets (default: newstest2013.tok.de)
#
# Intermediate files (.n, .tok, .atat) are written next to the inputs.

set -euo pipefail

if [[ $# -lt 1 ]]; then
  echo "Usage: $0 DECODES_FILE" >&2
  exit 1
fi

mosesdecoder=${MOSESDECODER:-$HOME/mosesdecoder}
tok_gold_targets=${TOK_GOLD_TARGETS:-newstest2013.tok.de}

decodes_file=$1

# Replace unicode punctuation.
perl "$mosesdecoder/scripts/tokenizer/replace-unicode-punctuation.perl" -l de < "$decodes_file" > "$decodes_file.n"

# Tokenize.
perl "$mosesdecoder/scripts/tokenizer/tokenizer.perl" -l de < "$decodes_file.n" > "$decodes_file.tok"

# Put compounds in ATAT format (comparable to papers like GNMT, ConvS2S).
# See https://nlp.stanford.edu/projects/nmt/ :
# 'Also, for historical reasons, we split compound words, e.g.,
# "rich-text format" --> rich ##AT##-##AT## text format.'
perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < "$tok_gold_targets" > "$tok_gold_targets.atat"
perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < "$decodes_file.tok" > "$decodes_file.tok.atat"

# Get BLEU.
perl "$mosesdecoder/scripts/generic/multi-bleu.perl" "$tok_gold_targets.atat" < "$decodes_file.tok.atat"
23 |
--------------------------------------------------------------------------------
/tensor2tensor/utils/hparams_lib_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for hparams_lib."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 |
24 | from tensor2tensor.utils import hparams_lib
25 |
26 | import tensorflow.compat.v1 as tf
27 |
28 |
class HparamsLibTest(tf.test.TestCase):
  """Tests for hparams_lib."""

  def testCreateHparamsFromJson(self):
    """Builds hparams from the bundled test checkpoint's hparams.json."""
    # The fixture lives under test_data/ in the parent of this file's
    # directory.
    package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    json_path = os.path.join(
        package_dir, "test_data", "transformer_test_ckpt", "hparams.json")

    hparams = hparams_lib.create_hparams_from_json(json_path)
    # The fixture is expected to define exactly 75 hyperparameters.
    self.assertEqual(75, len(hparams.values()))


if __name__ == "__main__":
  tf.test.main()
46 |
--------------------------------------------------------------------------------
/tensor2tensor/utils/misc_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Miscellaneous utilities."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import pprint
23 | import re
24 |
25 | # Camel case to snake case utils
26 | _first_cap_re = re.compile("(.)([A-Z][a-z0-9]+)")
27 | _all_cap_re = re.compile("([a-z0-9])([A-Z])")
28 |
29 |
30 | def camelcase_to_snakecase(name):
31 | s1 = _first_cap_re.sub(r"\1_\2", name)
32 | return _all_cap_re.sub(r"\1_\2", s1).lower()
33 |
34 |
def snakecase_to_camelcase(name):
  """Converts a snake_case name to CamelCase, e.g. "foo_bar" -> "FooBar".

  Empty segments (from an empty name, or leading, trailing or doubled
  underscores) contribute nothing instead of raising IndexError.

  Args:
    name: a snake_case string.

  Returns:
    The CamelCase string; each segment's first character is upper-cased and
    the rest is left untouched.
  """
  # w[:1] is "" for an empty segment, whereas w[0] would raise IndexError.
  return "".join(w[:1].upper() + w[1:] for w in name.split("_"))
37 |
38 |
def pprint_hparams(hparams):
  """Pretty-formats `hparams.values()` on a fresh line.

  pprint.pformat is called with width=1 so every entry is pushed onto its own
  line; a leading newline is prepended to the formatted text.
  """
  formatted = pprint.pformat(hparams.values(), width=1)
  return "\n" + formatted
42 |
--------------------------------------------------------------------------------
/tensor2tensor/utils/optimize_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2tensor.utils.optimize."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl.testing import parameterized
23 | from tensor2tensor.utils import hparams_lib
24 | from tensor2tensor.utils import optimize
25 | import tensorflow.compat.v1 as tf
26 |
27 |
class OptimizeTest(parameterized.TestCase, tf.test.TestCase):
  """Checks that ConditionalOptimizer accepts each listed optimizer name."""

  @parameterized.parameters(
      # Each pair: a lowercase/snake_case spelling and a CamelCase spelling.
      "sgd", "SGD",
      "rms_prop", "RMSProp",
      "adagrad", "Adagrad",
      "adam", "Adam",
      "adam_w", "AdamW",
  )
  def test_names(self, opt_name):
    """Constructing the optimizer for `opt_name` must not raise."""
    learning_rate = 0.1
    hparams = hparams_lib.create_hparams("basic_1")
    optimize.ConditionalOptimizer(opt_name, learning_rate, hparams)


if __name__ == "__main__":
  tf.test.main()
49 |
--------------------------------------------------------------------------------
/tensor2tensor/utils/partial_checkpoint_load_hook.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Hook to partially load a checkpoint."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow.compat.v1 as tf
22 |
23 |
class PartialCheckpointLoad(tf.train.SessionRunHook):
  """Partially load train_variables from a checkpoint.

  Hook used to load each variable saved in checkpoint into the graph. It
  will ignore any additional variables present in the graph that are not
  saved in the checkpoint. (Note: The loaded variables include ADAM/training
  variables, if they exist in the checkpoint.)
  Can perform mapping if the base scopename for graph variables is different
  from the checkpoint variables.
  """

  def __init__(self, hook_context, chk_scopename, graph_scopename):
    """Initialize the hook with chkp directory and scopenames.

    Args:
      hook_context: HookContext object containing hparams; the checkpoint path
        is read from `hook_context.hparams.partial_load_checkpoint`.
      chk_scopename: Base scopename of variables in the checkpoint being loaded
      graph_scopename: Base scopename of variables in current graph
    """
    self.checkpoint_path = hook_context.hparams.partial_load_checkpoint
    self.chk_scopename = chk_scopename
    self.graph_scopename = graph_scopename

  def _graph_variable_name(self, chk_var_name):
    """Returns the graph tensor name ("<name>:0") for a checkpoint variable.

    Only the first occurrence of `chk_scopename` is rewritten. Replacing every
    occurrence would corrupt names in which the scope string appears more
    than once. With an empty `chk_scopename` it would also insert
    `graph_scopename` between every character.

    Args:
      chk_var_name: variable name as listed in the checkpoint.

    Returns:
      The name of the corresponding graph variable's tensor.
    """
    graph_name = chk_var_name.replace(self.chk_scopename, self.graph_scopename,
                                      1)
    return graph_name + ":0"

  def begin(self):
    """Maps every checkpoint variable to a graph variable and loads it.

    The mapping is handed to `tf.train.init_from_checkpoint`, which makes the
    mapped graph variables initialize from the checkpoint.

    Raises:
      KeyError: if a checkpoint variable has no counterpart in the graph.
    """
    variable_references = {var.name: var for var in tf.global_variables()}
    variable_mappings = {}
    for chk_var_name, _ in tf.train.list_variables(self.checkpoint_path):
      graph_var_name = self._graph_variable_name(chk_var_name)
      if graph_var_name not in variable_references:
        raise KeyError(
            "Checkpoint variable %r maps to %r, which is not in the graph."
            % (chk_var_name, graph_var_name))
      variable_mappings[chk_var_name] = variable_references[graph_var_name]
    tf.logging.info("Partially loading %d variables from checkpoint %s",
                    len(variable_mappings), self.checkpoint_path)
    tf.train.init_from_checkpoint(self.checkpoint_path, variable_mappings)
56 |
--------------------------------------------------------------------------------
/tensor2tensor/utils/restore_hook.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Restore hooks."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import six
23 |
24 | from tensor2tensor.utils import contrib
25 | import tensorflow.compat.v1 as tf
26 |
27 |
class RestoreHook(tf.train.SessionRunHook):
  """Restore variables from a checkpoint path."""

  def __init__(self, checkpoint_path="", new_model_scope="", old_model_scope="",
               include=None, exclude=None):
    """Stores the restore configuration; nothing is read until `begin`.

    Args:
      checkpoint_path: Checkpoint passed to `tf.train.init_from_checkpoint`.
      new_model_scope: Prefix stripped from current-graph variable names;
        variables whose names do not start with it are skipped.
      old_model_scope: After stripping, only names starting with this prefix
        are restored.
      include: Forwarded as `include` to `get_variables_to_restore`.
      exclude: Forwarded as `exclude` to `get_variables_to_restore`.
    """
    self._checkpoint_path = checkpoint_path
    self._new_model_scope = new_model_scope
    self._old_model_scope = old_model_scope
    self._include = include
    self._exclude = exclude

  def begin(self):
    """Load variables from checkpoint.

    New model variables have the following name format:
    new_model_scope/old_model_scope/xxx/xxx:0. To find the map of
    name to variable, need to strip the new_model_scope and then
    match the old_model_scope and remove the suffix :0.
    """
    variables_to_restore = contrib.framework().get_variables_to_restore(
        include=self._include, exclude=self._exclude)
    # Remove new_model_scope from the variable name prefix.
    # NOTE(review): stripping is purely textual, so new_model_scope presumably
    # needs its trailing "/" for the remainder to start with old_model_scope.
    assignment_map = {variable.name[len(self._new_model_scope):]: variable
                      for variable in variables_to_restore
                      if variable.name.startswith(self._new_model_scope)}
    # Remove the ":0" output suffix from the variable name.
    assignment_map = {name.split(":")[0]: variable
                      for name, variable in six.iteritems(assignment_map)
                      if name.startswith(self._old_model_scope)}
    self._assignment_map = assignment_map

    if not assignment_map:
      # A scope mismatch otherwise turns the restore into a silent no-op.
      tf.logging.warning(
          "RestoreHook matched no variables (new_model_scope=%r, "
          "old_model_scope=%r)", self._new_model_scope, self._old_model_scope)
    # Pass arguments to the logger rather than %-formatting eagerly.
    tf.logging.info("restoring %d variables from checkpoint %s",
                    len(assignment_map), self._checkpoint_path)
    tf.train.init_from_checkpoint(self._checkpoint_path, self._assignment_map)
63 |
--------------------------------------------------------------------------------
/tensor2tensor/utils/test_utils_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2tensor.utils.test_utils."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from tensor2tensor.utils import test_utils
23 |
24 | import tensorflow.compat.v1 as tf
25 | tf.enable_eager_execution()
26 |
27 |
class RunInGraphAndEagerTest(tf.test.TestCase):
  """Tests for test_utils.run_in_graph_and_eager_modes."""

  def test_run_in_graph_and_eager_modes(self):
    """Both decorator forms run the body once in eager and once in graph."""
    calls = []

    def inc(self, with_brackets):
      del self  # self argument is required by run_in_graph_and_eager_modes.
      mode = "eager" if tf.executing_eagerly() else "graph"
      label = "with_brackets" if with_brackets else "without_brackets"
      calls.append((label, mode))

    # Bare decorator form: run_in_graph_and_eager_modes(fn).
    test_utils.run_in_graph_and_eager_modes(inc)(self, with_brackets=False)
    # Decorator-factory form: run_in_graph_and_eager_modes()(fn).
    test_utils.run_in_graph_and_eager_modes()(inc)(self, with_brackets=True)

    self.assertEqual(len(calls), 4)
    expected = {
        ("with_brackets", "graph"),
        ("with_brackets", "eager"),
        ("without_brackets", "graph"),
        ("without_brackets", "eager"),
    }
    self.assertEqual(set(calls), expected)

  def test_run_in_graph_and_eager_modes_setup_in_same_mode(self):
    """setUp runs again in graph mode before the graph-mode body."""
    modes = []

    def mode_name():
      return "eager" if tf.executing_eagerly() else "graph"

    class ExampleTest(tf.test.TestCase):

      def runTest(self):
        pass

      def setUp(self):
        modes.append("setup_" + mode_name())

      @test_utils.run_in_graph_and_eager_modes
      def testBody(self):
        modes.append("run_" + mode_name())

    example = ExampleTest()
    example.setUp()
    example.testBody()

    self.assertEqual(modes[:2], ["setup_eager", "run_eager"])
    self.assertEqual(modes[2:], ["setup_graph", "run_graph"])


if __name__ == "__main__":
  tf.test.main()
76 |
--------------------------------------------------------------------------------
/tensor2tensor/utils/update_ops_hook.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Hook to run tf.GraphKeys.UPDATE_OPS."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow.compat.v1 as tf
22 |
23 |
class UpdateOpsHook(tf.train.SessionRunHook):
  """Adds the tf.GraphKeys.UPDATE_OPS collection to each run's fetches."""

  def before_run(self, run_context):
    """Requests every op in UPDATE_OPS as extra fetches for this run."""
    del run_context  # Unused.
    return tf.train.SessionRunArgs(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
31 |
--------------------------------------------------------------------------------
/tensor2tensor/utils/usr_dir.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utility to load code from an external user-supplied directory."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import importlib
22 | import os
23 | import sys
24 | import tensorflow.compat.v1 as tf
25 |
26 |
INTERNAL_USR_DIR_PACKAGE = "t2t_usr_dir_internal"


def import_usr_dir(usr_dir):
  """Import module at usr_dir, if provided.

  The parent of `usr_dir` is temporarily prepended to sys.path so that the
  directory can be imported by its basename. The entry is removed again
  afterwards, even if the import raises.

  Args:
    usr_dir: Path whose basename is importable once its parent directory is
      on sys.path (typically a package directory), or the special name
      INTERNAL_USR_DIR_PACKAGE. Falsy values make this a no-op.

  Raises:
    Whatever importlib.import_module raises (e.g. ImportError).
  """
  if not usr_dir:
    return
  if usr_dir == INTERNAL_USR_DIR_PACKAGE:
    # The package has been installed with pip under this name for Cloud ML
    # Engine so just import it.
    importlib.import_module(INTERNAL_USR_DIR_PACKAGE)
    return

  dir_path = os.path.abspath(os.path.expanduser(usr_dir).rstrip("/"))
  containing_dir, module_name = os.path.split(dir_path)
  tf.logging.info("Importing user module %s from path %s", module_name,
                  containing_dir)
  sys.path.insert(0, containing_dir)
  try:
    importlib.import_module(module_name)
  finally:
    # Remove the entry added above. A bare pop(0) could drop a different path
    # if the imported module prepended to sys.path itself. Skipping cleanup on
    # an import error would leave sys.path polluted.
    if containing_dir in sys.path:
      sys.path.remove(containing_dir)
47 |
--------------------------------------------------------------------------------
/tensor2tensor/utils/video_metrics_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tensor2tensor.utils.video_metrics."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import numpy as np
23 | from tensor2tensor.utils import video_metrics
24 | import tensorflow.compat.v1 as tf
25 |
26 |
class VideoMetricsTest(tf.test.TestCase):
  """Tests for video_metrics."""

  def test_reduce_to_best_decode(self):
    """Checks the best (argmax) and worst (argmin) decode for each sample."""
    high = [30.0, 32.0, 33.0, 34.0]
    low = [22.0, 19.0, 12.0, 13.0]
    zigzag = [30.0, 10.0, 30.0, 10.0]
    flat = [25.0] * 4

    # Layout: (num_decodes=2, num_samples=3, num_frames=4).
    all_decodes = np.array([
        [high, low, zigzag],  # decode 0
        [low, high, flat],    # decode 1
    ])

    best_decode, best_decode_ind = video_metrics.reduce_to_best_decode(
        all_decodes, np.argmax)
    worst_decode, worst_decode_ind = video_metrics.reduce_to_best_decode(
        all_decodes, np.argmin)

    self.assertTrue(np.allclose(best_decode, [high, high, flat]))
    self.assertTrue(np.allclose(worst_decode, [low, low, zigzag]))
    self.assertTrue(np.allclose(best_decode_ind, [0, 1, 1]))
    self.assertTrue(np.allclose(worst_decode_ind, [1, 0, 0]))


if __name__ == '__main__':
  tf.test.main()
61 |
--------------------------------------------------------------------------------
/tensor2tensor/visualization/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 |
--------------------------------------------------------------------------------