├── .coveragerc ├── .github └── workflows │ └── build.yaml ├── .readthedocs.yaml ├── .travis.yml ├── AUTHORS ├── CONTRIBUTING.md ├── ISSUE_TEMPLATE.md ├── LICENSE ├── README.md ├── docs ├── .readthedocs.yaml ├── Makefile ├── requirements.txt └── source │ ├── conf.py │ ├── index.rst │ ├── notebooks │ ├── layers_intro.ipynb │ ├── tf_numpy_and_keras.ipynb │ └── trax_intro.ipynb │ ├── trax.data.rst │ ├── trax.fastmath.rst │ ├── trax.layers.rst │ ├── trax.models.rst │ ├── trax.optimizers.rst │ ├── trax.rl.rst │ ├── trax.rst │ └── trax.supervised.rst ├── oss_scripts ├── oss_pip_install.sh ├── oss_release.sh └── oss_tests.sh ├── pylintrc ├── setup.py └── trax ├── __init__.py ├── data ├── __init__.py ├── debug_data_pipeline.py ├── inputs.py ├── inputs_test.py ├── testdata │ ├── bert_uncased_vocab.txt │ ├── c4 │ │ └── en │ │ │ └── 2.3.0 │ │ │ ├── c4-train.tfrecord-00000-of-00001 │ │ │ ├── c4-validation.tfrecord-00000-of-00001 │ │ │ └── dataset_info.json │ ├── corpus-1.txt │ ├── corpus-2.txt │ ├── en_8k.subword │ ├── para_crawl │ │ └── ende │ │ │ └── 1.2.0 │ │ │ ├── dataset_info.json │ │ │ ├── features.json │ │ │ └── para_crawl-train.tfrecord-00000-of-00001 │ ├── sentencepiece.model │ ├── squad │ │ └── v1.1 │ │ │ └── 3.0.0 │ │ │ ├── dataset_info.json │ │ │ ├── squad-train.tfrecord-00000-of-00001 │ │ │ └── squad-validation.tfrecord-00000-of-00001 │ ├── vocab-1.txt │ └── vocab-2.txt ├── text_encoder.py ├── text_encoder_build_subword.py ├── text_encoder_test.py ├── tf_inputs.py ├── tf_inputs_test.py ├── tokenizer.py └── tokenizer_test.py ├── examples ├── Deep_N_Gram_Models.ipynb ├── Fashion_MNIST_with_Trax.ipynb ├── Knowledge_Tracing_Transformer.ipynb ├── MathQA_Python_generation_notebook.ipynb ├── NER_using_Reformer.ipynb ├── NMT_with_Transformers_Reformers_using_Trax.ipynb ├── README.md ├── Terraformer_from_scratch.ipynb ├── earlystopping.ipynb ├── illustrated_wideresnet.ipynb ├── semantic_segmentation.ipynb └── trax_data_Explained.ipynb ├── fastmath ├── __init__.py ├── jax.py ├── numpy.py ├── ops.py ├── ops_test.py └── tf.py ├── import_test.py ├── intro.ipynb ├── jaxboard.py ├── layers ├── README.md ├── __init__.py ├── acceleration.py ├── acceleration_test.py ├── activation_fns.py ├── activation_fns_test.py ├── assert_shape.py ├── assert_shape_test.py ├── attention.py ├── attention_test.py ├── base.py ├── base_test.py ├── combinators.py ├── combinators_test.py ├── convolution.py ├── convolution_test.py ├── core.py ├── core_test.py ├── deconvolution.py ├── deconvolution_test.py ├── initializers.py ├── initializers_test.py ├── intro.ipynb ├── metrics.py ├── metrics_test.py ├── normalization.py ├── normalization_test.py ├── pooling.py ├── pooling_test.py ├── research │ ├── __init__.py │ ├── efficient_attention.py │ ├── efficient_attention_test.py │ ├── position_encodings.py │ ├── position_encodings_test.py │ ├── rel_attention.py │ ├── rel_attention_test.py │ ├── resampling.py │ ├── rotary_positional_embedding.py │ ├── rotary_positional_embedding_test.py │ ├── sparsity.py │ └── sparsity_test.py ├── reversible.py ├── reversible_test.py ├── rnn.py ├── rnn_test.py ├── test_utils.py └── test_utils_test.py ├── models ├── Attention_Visualization_in_Trax.ipynb ├── __init__.py ├── atari_cnn.py ├── atari_cnn_test.py ├── mlp.py ├── mlp_test.py ├── neural_gpu.py ├── neural_gpu_test.py ├── reformer │ ├── README.md │ ├── __init__.py │ ├── image_generation.ipynb │ ├── machine_translation.ipynb │ ├── reformer.py │ ├── reformer_e2e_test.py │ ├── reformer_test.py │ ├── testdata │ │ ├── 
translate_ende_wmt32k-dev-00000-of-00001 │ │ ├── translate_ende_wmt32k-train-00000-of-00001 │ │ └── vocab.translate_ende_wmt32k.32768.subwords │ └── text_generation.ipynb ├── research │ ├── __init__.py │ ├── bert.py │ ├── configurable_transformer.py │ ├── configurable_transformer_test.py │ ├── examples │ │ ├── hourglass_downsampled_imagenet.ipynb │ │ └── hourglass_enwik8.ipynb │ ├── hourglass.py │ ├── hourglass_test.py │ ├── layerdrop_transformer.py │ ├── layerdrop_transformer_test.py │ ├── predict_terraformer.py │ ├── rezero.py │ ├── rezero_test.py │ ├── rse.py │ ├── rse_test.py │ ├── terraformer.py │ ├── terraformer_e2e_test.py │ ├── terraformer_memory_test.py │ ├── terraformer_oom_test.py │ ├── terraformer_test.py │ ├── testdata │ │ ├── translate_ende_wmt32k-dev-00000-of-00001 │ │ ├── translate_ende_wmt32k-train-00000-of-00001 │ │ └── vocab.translate_ende_wmt32k.32768.subwords │ ├── transformer2.py │ └── transformer2_test.py ├── resnet.py ├── resnet_test.py ├── rl.py ├── rl_test.py ├── rnn.py ├── rnn_test.py ├── transformer.py └── transformer_test.py ├── optimizers ├── __init__.py ├── adafactor.py ├── adam.py ├── base.py ├── momentum.py ├── optimizers_test.py ├── rms_prop.py ├── sm3.py ├── trainer.py └── trainer_test.py ├── predict_drop.py ├── rl ├── __init__.py ├── actor_critic.py ├── actor_critic_joint.py ├── actor_critic_joint_test.py ├── actor_critic_test.py ├── advantages.py ├── advantages_test.py ├── atari_test.py ├── configs │ ├── dqn_cartpole_regression.gin │ ├── light_a2c_joint_atari_sweep.yaml │ ├── light_atari.gin │ ├── light_atari_sweep.yaml │ ├── light_awr_cartpole_sweep.yaml │ ├── light_awr_joint_atari_sweep.yaml │ ├── light_awr_joint_cartpole.gin │ ├── light_awr_joint_cartpole_sweep.yaml │ ├── light_cartpole.gin │ ├── light_copy.gin │ ├── light_copy_sweep.yaml │ ├── light_joint_atari.gin │ ├── light_joint_cartpole.gin │ ├── light_lunarlander.gin │ ├── light_mujoco.gin │ ├── light_mujoco_regression_test.gin │ ├── light_mujoco_sweep.yaml │ ├── light_ppo_atari.gin │ ├── light_ppo_boxing_regression_test.gin │ ├── light_ppo_cartpole_regression_test.gin │ ├── light_ppo_half_cheetah_regression_test.gin │ ├── light_ppo_joint_atari.gin │ ├── light_ppo_joint_atari_sweep.yaml │ ├── light_ppo_lunar_lander_regression_test.gin │ ├── light_ppo_pong_regression_test.gin │ ├── ppo_atari_sweep.yaml │ ├── ppo_cartpole_sweep.yaml │ └── transformer_srl_sine.gin ├── distributions.py ├── distributions_test.py ├── envs │ ├── __init__.py │ ├── data_envs.py │ └── data_envs_test.py ├── normalization.py ├── normalization_test.py ├── policy_tasks.py ├── rl_layers.py ├── serialization_utils.py ├── serialization_utils_test.py ├── space_serializer.py ├── space_serializer_test.py ├── task.py ├── task_test.py ├── training.py ├── training_test.py ├── value_tasks.py └── value_tasks_test.py ├── rl_trainer.py ├── shapes.py ├── shapes_test.py ├── supervised ├── __init__.py ├── callbacks.py ├── callbacks_test.py ├── configs │ ├── bert.gin │ ├── bert_glue_classification.gin │ ├── bert_glue_regression.gin │ ├── bert_glue_sweep_regression_task.yaml │ ├── bert_glue_sweep_single_sentence.yaml │ ├── bert_glue_sweep_two_sentences.yaml │ ├── bert_pretraining.gin │ ├── bert_pretraining_onlymlm.gin │ ├── bert_pretraining_onlynsp.gin │ ├── c4.gin │ ├── c4_pretrain_16gb_adafactor.gin │ ├── c4_trax_data.gin │ ├── cond_skipping_transformer_lm1b.gin │ ├── gru_copy.gin │ ├── hourglass_cifar10.gin │ ├── hourglass_enwik8.gin │ ├── hourglass_imagenet32.gin │ ├── hourglass_imagenet64.gin │ ├── hourglass_wiki40b.gin │ ├── 
layerdrop_every_transformer_lm1b.gin │ ├── layerdrop_transformer_lm1b.gin │ ├── layerdrop_ushape_transformer_lm1b.gin │ ├── lstm_lm1b.gin │ ├── lstm_seq2seq_wmt_ende.gin │ ├── mlp_mnist.gin │ ├── reformer_addition.gin │ ├── reformer_bair_robot_pushing.gin │ ├── reformer_c4.gin │ ├── reformer_cifar10.gin │ ├── reformer_copy.gin │ ├── reformer_enwik8.gin │ ├── reformer_imagenet64.gin │ ├── reformer_imagenet64_testing.gin │ ├── reformer_pc_enpl.gin │ ├── reformer_wmt_ende.gin │ ├── reformer_wmt_ende_big.gin │ ├── resnet50_frn_imagenet_8gb.gin │ ├── resnet50_imagenet_8gb_testing.gin │ ├── rezero_wmt_ende_16gb_adafactor_testing.gin │ ├── rse_addition.gin │ ├── rse_addition_sweep.yaml │ ├── scientific_papers_reformer_lm.gin │ ├── scientific_papers_terraformer.gin │ ├── scientific_papers_terraformer_favor.gin │ ├── scientific_papers_terraformer_pretrained.gin │ ├── skipping_transformer_lm1b.gin │ ├── sp_sweep.yaml │ ├── sparse_c4_pretrain_16gb_adafactor.gin │ ├── sparse_lm1b_pretrain_16gb.gin │ ├── t5_aqua_parallel.gin │ ├── t5_drop.gin │ ├── t5_glue_classification.gin │ ├── t5_glue_classification_mnli.gin │ ├── t5_glue_classification_parallel.gin │ ├── t5_glue_classification_two_constants.gin │ ├── t5_mathqa.gin │ ├── t5_mathqa_drop_loop.gin │ ├── t5_mathqa_drop_sweep.yaml │ ├── t5_mathqa_multi.gin │ ├── t5_mathqa_parallel.gin │ ├── t5_mathqa_parallel_full.gin │ ├── t5_mathqa_parallel_full_correct_order.gin │ ├── t5_mathqa_parallel_full_order.gin │ ├── t5_mathqa_parallel_with_drop_annot.gin │ ├── t5_sweep.yaml │ ├── t5_sweep_temperature.yaml │ ├── terraformer_c4_medium.gin │ ├── terraformer_copy.gin │ ├── terraformer_copy_self_attn.gin │ ├── terraformer_purelsh_copy.gin │ ├── terraformer_wmt_ende.gin │ ├── transformer_big_lm1b_8gb.gin │ ├── transformer_finetune_squad_16gb.gin │ ├── transformer_imdb_8gb.gin │ ├── transformer_imdb_tfds.gin │ ├── transformer_lm1b_8gb_testing.gin │ ├── transformer_lm_cnndailymail.gin │ ├── transformer_lm_wmt_ende_16gb.gin │ ├── transformer_lm_wmt_ende_8gb.gin │ ├── transformer_ptb_16gb.gin │ ├── transformer_wmt_ende_16gb_adafactor_testing.gin │ ├── transformer_wmt_ende_8gb.gin │ └── wide_resnet_cifar10_8gb.gin ├── decoding.py ├── decoding_test.py ├── decoding_timing_test.py ├── history.py ├── history_test.py ├── lr_schedules.py ├── lr_schedules_test.py ├── mnist_test.py ├── pretrain_finetune.py ├── testdata │ ├── reformerlm_copy_lsh_attn.pkl.gz │ ├── terraformer_copy_lsh_attn.pkl.gz │ ├── terraformer_copy_self_attn.pkl.gz │ ├── terraformer_purelsh_copy.pkl.gz │ ├── transformer_copy.pkl.gz │ └── transformerlm_copy.pkl.gz ├── trainer_lib.py ├── trainer_lib_test.py ├── training.py └── training_test.py ├── test_utils.py ├── tf_numpy ├── __init__.py ├── examples │ └── mnist │ │ ├── dataset.py │ │ ├── model.py │ │ ├── train.py │ │ └── train_test.py ├── extensions │ ├── __init__.py │ ├── extensions.py │ └── extensions_test.py ├── jax_tests │ ├── config.py │ ├── lax_numpy_einsum_test.py │ ├── lax_numpy_indexing_test.py │ ├── lax_numpy_test.py │ ├── test_util.py │ └── vmap_test.py ├── numpy │ └── __init__.py ├── numpy_impl │ ├── __init__.py │ ├── array_ops.py │ ├── arrays.py │ ├── dtypes.py │ ├── math_ops.py │ ├── random.py │ ├── tests │ │ ├── array_ops_test.py │ │ ├── arrays_test.py │ │ ├── backprop_test.py │ │ ├── logic_test.py │ │ ├── math_ops_test.py │ │ ├── random_test.py │ │ └── utils_test.py │ └── utils.py └── public_symbol_test.py ├── tf_numpy_and_keras.ipynb ├── trainer.py ├── trainer_flags.py ├── trax2keras.py └── trax2keras_test.py /.coveragerc: 
-------------------------------------------------------------------------------- 1 | [run] 2 | source = 3 | trax/ 4 | omit = 5 | *_test.py 6 | */site-packages/* 7 | 8 | [report] 9 | omit = 10 | */site-packages/* 11 | exclude_lines = 12 | pragma: no cover 13 | def __repr__ 14 | raise NotImplementedError 15 | if __name__ == .__main__.: 16 | 17 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # .readthedocs.yml 16 | # Read the Docs configuration file 17 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 18 | 19 | # Required 20 | version: 2 21 | 22 | # Build documentation in the docs/source directory with Sphinx. 23 | sphinx: 24 | configuration: docs/source/conf.py 25 | 26 | # Build docs in additional formats (PDF, ePub). 27 | formats: all 28 | 29 | # Optionally set the version of Python and requirements required to build your docs 30 | python: 31 | version: 3.7 32 | install: 33 | - requirements: docs/requirements.txt 34 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: required 2 | language: python 3 | cache: pip 4 | git: 5 | depth: 3 6 | quiet: true 7 | python: 8 | - "3.6" 9 | env: 10 | global: 11 | - TF_VERSION="2.4.*" 12 | matrix: 13 | - TRAX_TEST="lib" 14 | - TRAX_TEST="research" 15 | install: 16 | - ./oss_scripts/oss_pip_install.sh 17 | script: 18 | - ./oss_scripts/oss_tests.sh 19 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the list of Trax authors for copyright purposes. 2 | # 3 | # This does not necessarily list everyone who has contributed code, since in 4 | # some cases, their employer may be the copyright holder. To see the full list 5 | # of contributors, see the revision history in source control. 6 | 7 | Google Inc. 8 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | # Issues 4 | 5 | * Please tag your issue with `bug`, `feature request`, or `question` to help us 6 | respond effectively. 7 | * Please include the versions of JAX or TensorFlow you are running. 8 | * Please provide the command line you ran as well as the log output. 9 | 10 | # Pull Requests 11 | 12 | We'd love to accept your patches and contributions to this project. There are 13 | just a few small guidelines you need to follow. 14 | 15 | ## Contributor License Agreement 16 | 17 | Contributions to this project must be accompanied by a Contributor License 18 | Agreement.
You (or your employer) retain the copyright to your contribution; 19 | this simply gives us permission to use and redistribute your contributions as 20 | part of the project. Head over to <https://cla.developers.google.com/> to see 21 | your current agreements on file or to sign a new one. 22 | 23 | You generally only need to submit a CLA once, so if you've already submitted one 24 | (even if it was for a different project), you probably don't need to do it 25 | again. 26 | 27 | ## Code reviews 28 | 29 | All submissions, including submissions by project members, require review. We 30 | use GitHub pull requests for this purpose. Consult 31 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 32 | information on using pull requests. 33 | 34 | ## Community Guidelines 35 | 36 | This project follows 37 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 38 | -------------------------------------------------------------------------------- /ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ### Description 2 | 3 | ... 4 | 5 | ### Environment information 6 | 7 | ``` 8 | OS: 9 | 10 | $ pip freeze | grep trax 11 | # your output here 12 | 13 | $ pip freeze | grep tensor 14 | # your output here 15 | 16 | $ pip freeze | grep jax 17 | # your output here 18 | 19 | $ python -V 20 | # your output here 21 | ``` 22 | 23 | ### For bugs: reproduction and error logs 24 | 25 | ``` 26 | # Steps to reproduce: 27 | ... 28 | ``` 29 | 30 | ``` 31 | # Error logs: 32 | ... 33 | ``` 34 | -------------------------------------------------------------------------------- /docs/.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # .readthedocs.yml 16 | # Read the Docs configuration file 17 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 18 | 19 | # Required 20 | version: 2 21 | 22 | # Build documentation in the docs/source directory with Sphinx. 23 | sphinx: 24 | configuration: docs/source/conf.py 25 | 26 | # Build docs in additional formats (PDF, ePub). 27 | formats: all 28 | 29 | # Optionally set the version of Python and requirements required to build your docs 30 | python: 31 | version: 3.7 32 | install: 33 | - requirements: docs/requirements.txt 34 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help".
11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | nbsphinx 2 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Trax documentation master file. 2 | Should contain at minimum the root `toctree` directive. 3 | 4 | Trax Tutorials 5 | ============== 6 | 7 | .. toctree:: 8 | :caption: Introductory Notebooks 9 | :maxdepth: 2 10 | 11 | notebooks/trax_intro 12 | notebooks/layers_intro 13 | notebooks/tf_numpy_and_keras 14 | 15 | 16 | Trax API 17 | ======== 18 | 19 | .. toctree:: 20 | :caption: Packages/modules 21 | :maxdepth: 2 22 | 23 | trax.* 24 | 25 | 26 | Indices and Tables 27 | ================== 28 | 29 | * :ref:`genindex` 30 | * :ref:`modindex` 31 | * :ref:`search` 32 | -------------------------------------------------------------------------------- /docs/source/trax.data.rst: -------------------------------------------------------------------------------- 1 | trax.data 2 | ========= 3 | 4 | inputs 5 | ------ 6 | 7 | .. automodule:: trax.data.inputs 8 | 9 | tf\_inputs 10 | ---------- 11 | 12 | .. automodule:: trax.data.tf_inputs 13 | -------------------------------------------------------------------------------- /docs/source/trax.fastmath.rst: -------------------------------------------------------------------------------- 1 | trax.fastmath 2 | ============= 3 | 4 | ops 5 | --- 6 | 7 | .. automodule:: trax.fastmath.ops 8 | -------------------------------------------------------------------------------- /docs/source/trax.layers.rst: -------------------------------------------------------------------------------- 1 | trax.layers 2 | =========== 3 | 4 | acceleration 5 | ------------ 6 | 7 | .. automodule:: trax.layers.acceleration 8 | 9 | activation\_fns 10 | --------------- 11 | 12 | .. automodule:: trax.layers.activation_fns 13 | 14 | attention 15 | --------- 16 | 17 | .. automodule:: trax.layers.attention 18 | 19 | base 20 | ---- 21 | 22 | .. automodule:: trax.layers.base 23 | 24 | combinators 25 | ----------- 26 | 27 | .. automodule:: trax.layers.combinators 28 | 29 | convolution 30 | ----------- 31 | 32 | .. automodule:: trax.layers.convolution 33 | 34 | core 35 | ---- 36 | 37 | .. automodule:: trax.layers.core 38 | 39 | initializers 40 | ------------ 41 | 42 | .. automodule:: trax.layers.initializers 43 | 44 | metrics 45 | ------- 46 | 47 | .. automodule:: trax.layers.metrics 48 | 49 | normalization 50 | ------------- 51 | 52 | .. automodule:: trax.layers.normalization 53 | 54 | pooling 55 | ------- 56 | 57 | .. automodule:: trax.layers.pooling 58 | 59 | reversible 60 | ---------- 61 | 62 | .. automodule:: trax.layers.reversible 63 | 64 | rnn 65 | --- 66 | 67 | .. automodule:: trax.layers.rnn 68 | 69 | research.efficient\_attention 70 | ----------------------------- 71 | 72 | .. automodule:: trax.layers.research.efficient_attention 73 | 74 | research.position\_encodings 75 | ---------------------------- 76 | 77 | .. 
automodule:: trax.layers.research.position_encodings 78 | -------------------------------------------------------------------------------- /docs/source/trax.models.rst: -------------------------------------------------------------------------------- 1 | trax.models 2 | =========== 3 | 4 | atari\_cnn 5 | ---------- 6 | 7 | .. automodule:: trax.models.atari_cnn 8 | 9 | mlp 10 | --- 11 | 12 | .. automodule:: trax.models.mlp 13 | 14 | neural\_gpu 15 | ----------- 16 | 17 | .. automodule:: trax.models.neural_gpu 18 | 19 | resnet 20 | ------ 21 | 22 | .. automodule:: trax.models.resnet 23 | 24 | rl 25 | -- 26 | 27 | .. automodule:: trax.models.rl 28 | 29 | rnn 30 | --- 31 | 32 | .. automodule:: trax.models.rnn 33 | 34 | transformer 35 | ----------- 36 | 37 | .. automodule:: trax.models.transformer 38 | 39 | reformer.reformer 40 | ----------------- 41 | 42 | .. automodule:: trax.models.reformer.reformer 43 | 44 | research.bert 45 | ------------- 46 | 47 | .. automodule:: trax.models.research.bert 48 | 49 | research.skipping\_transformer 50 | ------------------------------ 51 | 52 | .. automodule:: trax.models.research.skipping_transformer 53 | -------------------------------------------------------------------------------- /docs/source/trax.optimizers.rst: -------------------------------------------------------------------------------- 1 | trax.optimizers 2 | =============== 3 | 4 | adafactor 5 | --------- 6 | 7 | .. automodule:: trax.optimizers.adafactor 8 | 9 | adam 10 | ---- 11 | 12 | .. automodule:: trax.optimizers.adam 13 | 14 | base 15 | ---- 16 | 17 | .. automodule:: trax.optimizers.base 18 | 19 | momentum 20 | -------- 21 | 22 | .. automodule:: trax.optimizers.momentum 23 | 24 | rms\_prop 25 | --------- 26 | 27 | .. automodule:: trax.optimizers.rms_prop 28 | 29 | sm3 30 | --- 31 | 32 | .. automodule:: trax.optimizers.sm3 33 | -------------------------------------------------------------------------------- /docs/source/trax.rl.rst: -------------------------------------------------------------------------------- 1 | trax.rl package 2 | =============== 3 | 4 | actor\_critic 5 | ------------- 6 | 7 | .. automodule:: trax.rl.actor_critic 8 | 9 | actor\_critic\_joint 10 | -------------------- 11 | 12 | .. automodule:: trax.rl.actor_critic_joint 13 | 14 | advantages 15 | ---------- 16 | 17 | .. automodule:: trax.rl.advantages 18 | 19 | distributions 20 | ------------- 21 | 22 | .. automodule:: trax.rl.distributions 23 | 24 | normalization 25 | ------------- 26 | 27 | .. automodule:: trax.rl.normalization 28 | 29 | rl\_layers 30 | ---------- 31 | 32 | .. automodule:: trax.rl.rl_layers 33 | 34 | serialization\_utils 35 | -------------------- 36 | 37 | .. automodule:: trax.rl.serialization_utils 38 | 39 | space\_serializer 40 | ----------------- 41 | 42 | .. automodule:: trax.rl.space_serializer 43 | 44 | task 45 | ---- 46 | 47 | .. automodule:: trax.rl.task 48 | 49 | training 50 | -------- 51 | 52 | .. automodule:: trax.rl.training 53 | -------------------------------------------------------------------------------- /docs/source/trax.rst: -------------------------------------------------------------------------------- 1 | trax 2 | ==== 3 | 4 | .. toctree:: 5 | fastmath.* 6 | layers.* 7 | models.* 8 | data.* 9 | optimizers.* 10 | supervised.* 11 | rl.* 12 | 13 | 14 | shapes 15 | ------ 16 | 17 | .. automodule:: trax.shapes 18 | 19 | trainer 20 | ------- 21 | 22 | .. automodule:: trax.trainer 23 | 24 | rl\_trainer 25 | ----------- 26 | 27 | .. 
automodule:: trax.rl_trainer 28 | 29 | 30 | trax2keras 31 | ---------- 32 | 33 | .. automodule:: trax.trax2keras 34 | -------------------------------------------------------------------------------- /docs/source/trax.supervised.rst: -------------------------------------------------------------------------------- 1 | trax.supervised 2 | =============== 3 | 4 | decoding 5 | -------- 6 | 7 | .. automodule:: trax.supervised.decoding 8 | 9 | lr\_schedules 10 | ------------- 11 | 12 | .. automodule:: trax.supervised.lr_schedules 13 | 14 | training 15 | -------- 16 | 17 | .. automodule:: trax.supervised.training 18 | -------------------------------------------------------------------------------- /oss_scripts/oss_pip_install.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/bin/bash 16 | 17 | set -v  # print commands as they're executed 18 | set -e  # fail and exit on any command erroring 19 | 20 | : "${TF_VERSION:?}" 21 | 22 | # Make sure we have the latest pip and setuptools installed. 23 | pip install -q -U pip 24 | pip install -q -U setuptools 25 | 26 | # Make sure we have the latest version of numpy - avoid problems we were 27 | # seeing with Python 3. 28 | pip install -q -U numpy 29 | 30 | # Install the appropriate version of TensorFlow. 31 | if [[ "$TF_VERSION" == "tf-nightly" ]] 32 | then 33 | pip install tf-nightly; 34 | else 35 | pip install -q "tensorflow==$TF_VERSION" 36 | fi 37 | 38 | # Just print the version again to make sure. 39 | python -c 'import tensorflow as tf; print(tf.__version__)' 40 | 41 | # First ensure that the base dependencies are sufficient for a full import. 42 | pip install -q -e . 43 | 44 | # Then install the test dependencies. 45 | pip install -q -e .[tests] 46 | # Make sure to install the atari extras for gym. 47 | pip install "gym[atari]" 48 | 49 | # Coverage. 50 | pip install coverage coveralls 51 | -------------------------------------------------------------------------------- /oss_scripts/oss_release.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | #!/bin/bash 16 | 17 | set -v # print commands as they're executed 18 | set -e # fail and exit on any command erroring 19 | 20 | GIT_COMMIT_ID=${1:-""} 21 | [[ -z $GIT_COMMIT_ID ]] && echo "Must provide a commit" && exit 1 22 | 23 | TMP_DIR=$(mktemp -d) 24 | pushd $TMP_DIR 25 | 26 | echo "Cloning trax and checking out commit $GIT_COMMIT_ID" 27 | git clone https://github.com/google/trax.git 28 | cd trax 29 | git checkout $GIT_COMMIT_ID 30 | 31 | python3 -m pip install wheel twine pyopenssl 32 | 33 | # Build the distribution 34 | echo "Building distribution" 35 | python3 setup.py sdist 36 | python3 setup.py bdist_wheel --universal 37 | 38 | # Publish to PyPI 39 | echo "Publishing to PyPI" 40 | python3 -m twine upload dist/* 41 | 42 | # Cleanup 43 | rm -rf build/ dist/ trax.egg-info/ 44 | popd 45 | rm -rf $TMP_DIR 46 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # coding=utf-8 17 | """Install trax.""" 18 | 19 | from setuptools import find_packages 20 | from setuptools import setup 21 | 22 | setup( 23 | name='trax', 24 | version='1.4.1', 25 | description='Trax', 26 | long_description=( 27 | 'Trax helps you understand deep learning. We start with basic maths and' 28 | ' go through layers, models, supervised and reinforcement learning. We ' 29 | 'get to advanced deep learning results, including recent papers and ' 30 | 'state-of-the-art models.' 31 | ), 32 | author='Google Inc.', 33 | author_email='no-reply@google.com', 34 | url='http://github.com/google/trax', 35 | license='Apache 2.0', 36 | packages=find_packages(), 37 | install_requires=[ 38 | 'absl-py', 39 | 'funcsigs', 40 | 'gin-config', 41 | 'gym', 42 | 'jax', 43 | 'jaxlib', 44 | 'matplotlib', 45 | 'numpy', 46 | 'psutil', 47 | 'scipy', 48 | 'six', 49 | 'tensorflow-datasets', 50 | 'tensorflow-text', 51 | ], 52 | extras_require={ 53 | 'tensorflow': ['tensorflow>=1.15.0'], 54 | 'tensorflow_gpu': ['tensorflow-gpu>=1.15.0'], 55 | 't5': ['t5>=0.4.0'], 56 | 'tests': [ 57 | 'attrs', 58 | 'jupyter', 59 | 'mock', 60 | 'parameterized', 61 | 'pylint', 62 | 'pytest', 63 | 'wrapt==1.11.*', 64 | ], 65 | 't2t': ['tensor2tensor',], 66 | }, 67 | classifiers=[ 68 | 'Development Status :: 4 - Beta', 69 | 'Intended Audience :: Developers', 70 | 'Intended Audience :: Science/Research', 71 | 'License :: OSI Approved :: Apache Software License', 72 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 73 | ], 74 | keywords='tensorflow machine learning jax', 75 | ) 76 | -------------------------------------------------------------------------------- /trax/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Trax top level import.""" 17 | 18 | from trax import data 19 | from trax import fastmath 20 | from trax import layers 21 | from trax import models 22 | from trax import optimizers 23 | from trax import shapes 24 | from trax import supervised 25 | from trax.supervised import lr_schedules as lr 26 | from trax.trax2keras import AsKeras 27 | -------------------------------------------------------------------------------- /trax/data/debug_data_pipeline.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """A debugging decorator for TRAX input pipeline.""" 17 | 18 | import functools 19 | 20 | from absl import logging 21 | import gin 22 | 23 | 24 | @gin.configurable(denylist=['f']) 25 | def debug_pipeline(f, debug=False, method='pow', log_prefix=None): 26 | """Decorator for input pipeline generators that logs examples at intervals.""" 27 | if not debug: 28 | return f 29 | 30 | assert method in ('pow', 'every') 31 | @functools.wraps(f) 32 | def wrapper(*args, **kwargs): 33 | count = 0 34 | prefix = log_prefix or f.__name__ 35 | for example in f(*args, **kwargs): 36 | count += 1 37 | if method == 'every' or (method == 'pow' and (count & count - 1 == 0)): 38 | logging.info('%s example[%d] = %r', prefix, count, example) 39 | yield example 40 | 41 | return wrapper 42 | -------------------------------------------------------------------------------- /trax/data/testdata/c4/en/2.3.0/c4-train.tfrecord-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/data/testdata/c4/en/2.3.0/c4-train.tfrecord-00000-of-00001 -------------------------------------------------------------------------------- /trax/data/testdata/c4/en/2.3.0/c4-validation.tfrecord-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/data/testdata/c4/en/2.3.0/c4-validation.tfrecord-00000-of-00001 -------------------------------------------------------------------------------- /trax/data/testdata/c4/en/2.3.0/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "citation": "\n@article{2019t5,\n author = {Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu},\n title = {Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer},\n journal = {arXiv e-prints},\n year = {2019},\n archivePrefix = {arXiv},\n eprint = {1910.10683},\n}\n", 3 | "description": "A colossal, cleaned version of Common Crawl's web crawl corpus.\n\nBased on Common Crawl dataset: \"https://commoncrawl.org\"\n\nDue to the overhead of cleaning the dataset, it is recommend you prepare it with\na distributed service like Cloud Dataflow. More info at\nhttps://www.tensorflow.org/datasets/beam_datasets.\n", 4 | "downloadSize": "7650623283551", 5 | "location": { 6 | "urls": [ 7 | "https://github.com/google-research/text-to-text-transfer-transformer#datasets" 8 | ] 9 | }, 10 | "name": "c4", 11 | "splits": [ 12 | { 13 | "name": "train", 14 | "numBytes": "10603", 15 | "numShards": "1", 16 | "shardLengths": [ 17 | "10" 18 | ] 19 | }, 20 | { 21 | "name": "validation", 22 | "numBytes": "16851", 23 | "numShards": "1", 24 | "shardLengths": [ 25 | "10" 26 | ] 27 | } 28 | ], 29 | "version": "2.3.0" 30 | } 31 | -------------------------------------------------------------------------------- /trax/data/testdata/corpus-1.txt: -------------------------------------------------------------------------------- 1 | One morning I shot an elephant in my pajamas. How he got in my pajamas, I don't 2 | know. 
3 | 4 | Groucho Marx 5 | -------------------------------------------------------------------------------- /trax/data/testdata/corpus-2.txt: -------------------------------------------------------------------------------- 1 | I haven't slept for 10 days... because that would be too long. 2 | 3 | Mitch Hedberg 4 | -------------------------------------------------------------------------------- /trax/data/testdata/para_crawl/ende/1.2.0/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "citation": "@misc {paracrawl,\n title = \"ParaCrawl\",\n year = \"2018\",\n url = \"http://paracrawl.eu/download.html.\"\n}", 3 | "configDescription": "Translation dataset from English to de.", 4 | "configName": "ende", 5 | "description": "Web-Scale Parallel Corpora for Official European Languages.", 6 | "downloadSize": "1307754745", 7 | "location": { 8 | "urls": [ 9 | "https://paracrawl.eu/releases.html" 10 | ] 11 | }, 12 | "name": "para_crawl", 13 | "splits": [ 14 | { 15 | "name": "train", 16 | "numBytes": "3241", 17 | "shardLengths": [ 18 | "10" 19 | ] 20 | } 21 | ], 22 | "supervisedKeys": { 23 | "input": "en", 24 | "output": "de" 25 | }, 26 | "version": "1.2.0" 27 | } 28 | -------------------------------------------------------------------------------- /trax/data/testdata/para_crawl/ende/1.2.0/features.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "tensorflow_datasets.core.features.translation_feature.Translation", 3 | "content": { 4 | "languages": [ 5 | "de", 6 | "en" 7 | ] 8 | } 9 | } -------------------------------------------------------------------------------- /trax/data/testdata/para_crawl/ende/1.2.0/para_crawl-train.tfrecord-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/data/testdata/para_crawl/ende/1.2.0/para_crawl-train.tfrecord-00000-of-00001 -------------------------------------------------------------------------------- /trax/data/testdata/sentencepiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/data/testdata/sentencepiece.model -------------------------------------------------------------------------------- /trax/data/testdata/squad/v1.1/3.0.0/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "citation": "@article{2016arXiv160605250R,\n author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},\n Konstantin and {Liang}, Percy},\n title = \"{SQuAD: 100,000+ Questions for Machine Comprehension of Text}\",\n journal = {arXiv e-prints},\n year = 2016,\n eid = {arXiv:1606.05250},\n pages = {arXiv:1606.05250},\narchivePrefix = {arXiv},\n eprint = {1606.05250},\n}\n", 3 | "description": "Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.\n", 4 | "location": { 5 | "urls": [ 6 | "https://rajpurkar.github.io/SQuAD-explorer/" 7 | ] 8 | }, 9 | "name": "squad", 10 | "schema": { 11 | "feature": [ 12 | { 13 | "name": "answers" 14 | }, 15 | { 16 | "name": "context", 17 | "type": "BYTES" 18 | }, 19 | { 20 | 
"name": "id", 21 | "type": "BYTES" 22 | }, 23 | { 24 | "name": "question", 25 | "type": "BYTES" 26 | }, 27 | { 28 | "name": "title", 29 | "type": "BYTES" 30 | } 31 | ] 32 | }, 33 | "sizeInBytes": "35142551", 34 | "splits": [ 35 | { 36 | "name": "train", 37 | "numShards": "1", 38 | "shardLengths": [ 39 | "10" 40 | ] 41 | }, 42 | { 43 | "name": "validation", 44 | "numShards": "1", 45 | "shardLengths": [ 46 | "10" 47 | ] 48 | } 49 | ], 50 | "version": "3.0.0" 51 | } 52 | -------------------------------------------------------------------------------- /trax/data/testdata/squad/v1.1/3.0.0/squad-train.tfrecord-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/data/testdata/squad/v1.1/3.0.0/squad-train.tfrecord-00000-of-00001 -------------------------------------------------------------------------------- /trax/data/testdata/squad/v1.1/3.0.0/squad-validation.tfrecord-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/data/testdata/squad/v1.1/3.0.0/squad-validation.tfrecord-00000-of-00001 -------------------------------------------------------------------------------- /trax/data/testdata/vocab-1.txt: -------------------------------------------------------------------------------- 1 | lollipop,8 2 | reverberated,12 3 | -------------------------------------------------------------------------------- /trax/data/testdata/vocab-2.txt: -------------------------------------------------------------------------------- 1 | kattywampus,11 2 | kaput 3 | balderdash,10 4 | jiggery-pokery,14 5 | -------------------------------------------------------------------------------- /trax/examples/README.md: -------------------------------------------------------------------------------- 1 | # Trax Examples 2 | 3 | 4 | Here are a few example notebooks:- 5 | 6 | * [**trax.data API explained**](https://github.com/google/trax/blob/master/trax/examples/trax_data_Explained.ipynb) : Explains some of the major functions in the `trax.data` API 7 | * [**Named Entity Recognition using Reformer**](https://github.com/google/trax/blob/master/trax/examples/NER_using_Reformer.ipynb) : Uses a [Kaggle dataset](https://www.kaggle.com/abhinavwalia95/entity-annotated-corpus) for implementing Named Entity Recognition using the [Reformer](https://arxiv.org/abs/2001.04451) architecture. 8 | * [**Deep N-Gram models**](https://github.com/google/trax/blob/master/trax/examples/Deep_N_Gram_Models.ipynb) : Implementation of deep n-gram models trained on Shakespeares works 9 | * [**Knowledge Tracing Transformer**](https://github.com/google/trax/blob/master/trax/examples/Knowledge_Tracing_Transformer.ipynb): End-to-end example adapting basic Transformer architecture to accommodate a knowledge tracing task. 10 | * [**Neural Machine Translation with Transformers/Reformers**](https://github.com/google/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb): A guide to Neural Machine Translation using Transformers/Reformers. Includes a detailed tutorial using Trax. 11 | -------------------------------------------------------------------------------- /trax/fastmath/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Trax fast math -- NumPy-style math on accelerators.""" 17 | 18 | from trax.fastmath.numpy import nested_map 19 | from trax.fastmath.numpy import nested_map_multiarg 20 | from trax.fastmath.numpy import nested_stack 21 | from trax.fastmath.numpy import nested_zip 22 | from trax.fastmath.numpy import tree_flatten 23 | from trax.fastmath.numpy import tree_leaves 24 | from trax.fastmath.numpy import tree_unflatten 25 | from trax.fastmath.ops import * # pylint: disable=wildcard-import 26 | -------------------------------------------------------------------------------- /trax/import_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for importing Trax.""" 17 | 18 | from absl.testing import absltest 19 | 20 | 21 | class ImportTest(absltest.TestCase): 22 | 23 | def test_import_trax(self): 24 | try: 25 | # Import trax 26 | import trax # pylint: disable=g-import-not-at-top 27 | # Access a few symbols. 28 | dir(trax.fastmath) 29 | dir(trax.layers) 30 | dir(trax.models) 31 | except ImportError as e: 32 | raise e 33 | 34 | 35 | if __name__ == '__main__': 36 | absltest.main() 37 | -------------------------------------------------------------------------------- /trax/layers/README.md: -------------------------------------------------------------------------------- 1 | # Trax Layers 2 | 3 | 4 | 5 | ## Base layer structure 6 | 7 | All layers inherit from the Layer class and generally need to implement 2 8 | methods: 9 | 10 | ```python 11 | def forward(self, inputs): 12 | """Computes the layer's output as part of a forward pass through the model.""" 13 | 14 | def init_weights_and_state(self, input_signature): 15 | """Initializes weights and state for inputs with the given signature.""" 16 | ``` 17 | 18 | The base Layer class wraps these functions and provides initialization 19 | and call functions to be used as follows. 20 | 21 | ```python 22 | layer = MyLayer() 23 | x = np.zeros(10) 24 | layer.init(signature(x)) 25 | output = layer(x) 26 | ``` 27 | 28 | ## Fn layer 29 | 30 | To create simple layers without parameters, use the Fn layer. 
31 | 32 | ```python 33 | def Relu(): 34 | return Fn('Relu', lambda x: np.maximum(x, np.zeros_like(x))) 35 | ``` 36 | 37 | ## Parameter sharing 38 | 39 | Parameters are shared when the same layer object is used. 40 | 41 | ```python 42 | standard_mlp = layers.Serial(layers.Dense(10), layers.Dense(10)) 43 | layer = layers.Dense(10) 44 | shared_parameters_mlp = layers.Serial(layer, layer) 45 | ``` 46 | 47 | ## Core layers 48 | 49 | * Dense 50 | * Conv 51 | 52 | ## Layer composition 53 | 54 | * Serial 55 | * Parallel 56 | 
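57 | As a quick illustration, here is a minimal composition sketch (an editorial addition, not part of the original README; it assumes the usual `trax.layers` and `trax.shapes` imports, and the layer sizes are arbitrary): 58 | 59 | ```python 60 | import numpy as np 61 | 62 | from trax import layers as tl 63 | from trax import shapes 64 | 65 | # A small convolutional block followed by a dense head, composed with Serial. 66 | model = tl.Serial( 67 | tl.Conv(8, (3, 3)),  # 8 filters, 3x3 kernel, NHWC inputs 68 | tl.Relu(), 69 | tl.Flatten(), 70 | tl.Dense(10), 71 | ) 72 | 73 | x = np.zeros((2, 28, 28, 1), dtype=np.float32) 74 | model.init(shapes.signature(x))  # initialize weights for this input signature 75 | y = model(x)  # y.shape == (2, 10) 76 | ```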
-------------------------------------------------------------------------------- /trax/layers/activation_fns_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for activation function layers.""" 17 | 18 | from absl.testing import absltest 19 | import numpy as np 20 | 21 | import trax.layers as tl 22 | 23 | 24 | class ActivationFnsTest(absltest.TestCase): 25 | 26 | def test_relu(self): 27 | layer = tl.Relu() 28 | x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0]) 29 | y = layer(x) 30 | self.assertEqual(tl.to_list(y), [0.0, 0.0, 0.0, 2.0, 3.0, 5.0]) 31 | 32 | def test_parametric_relu(self): 33 | layer = tl.ParametricRelu(a=.25) 34 | x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0]) 35 | y = layer(x) 36 | self.assertEqual(tl.to_list(y), [0.0, 0.0, 0.0, .5, .75, 1.25]) 37 | 38 | def test_leaky_relu(self): 39 | layer = tl.LeakyRelu(a=.125) 40 | x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0]) 41 | y = layer(x) 42 | self.assertEqual(tl.to_list(y), [-.25, -.125, 0.0, 2.0, 3.0, 5.0]) 43 | 44 | def test_hard_sigmoid(self): 45 | layer = tl.HardSigmoid() 46 | x = np.array([-1.5, -.5, -.25, 0.0, .25, .5, 1.5]) 47 | y = layer(x) 48 | self.assertEqual(tl.to_list(y), [0.0, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0]) 49 | 50 | def test_hard_tanh(self): 51 | layer = tl.HardTanh() 52 | x = np.array([-1.5, -.5, -.25, 0.0, .25, .5, 1.5]) 53 | y = layer(x) 54 | self.assertEqual(tl.to_list(y), [-1.0, -.5, -.25, 0.0, .25, .5, 1.0]) 55 | 56 | 57 | if __name__ == '__main__': 58 | absltest.main() 59 | -------------------------------------------------------------------------------- /trax/layers/deconvolution_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for Deconvolution layers.""" 17 | 18 | from absl.testing import absltest 19 | import numpy as np 20 | 21 | from trax import shapes 22 | import trax.layers as tl 23 | 24 | 25 | class ConvTransposeTest(absltest.TestCase): 26 | 27 | def test_call(self): 28 | layer = tl.ConvTranspose(30, (3, 3)) 29 | x = np.ones((9, 5, 5, 20)) 30 | layer.init(shapes.signature(x)) 31 | 32 | y = layer(x) 33 | self.assertEqual(y.shape, (9, 7, 7, 30)) 34 | 35 | 36 | if __name__ == '__main__': 37 | absltest.main() 38 | -------------------------------------------------------------------------------- /trax/layers/research/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /trax/layers/research/rel_attention_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 
29 | 30 | """Tests for trax.layers.relattention.""" 31 | 32 | from absl.testing import absltest 33 | import numpy as np 34 | 35 | import trax.layers as tl 36 | import trax.layers.research.rel_attention as ra 37 | 38 | 39 | class RelAttentionTest(absltest.TestCase): 40 | 41 | def test_fast_shift_matrix(self): 42 | layer = ra._fast_matrix_shift 43 | x = np.array([[[[-3., -2., -1., 0.], [-3., -2., -1., 44 | 0.], [-3., -2., -1., 0.], 45 | [-3., -2., -1., 0.]]]]).astype(np.float32) 46 | 47 | y = layer(x) 48 | self.assertEqual(y.dtype, np.float32) 49 | self.assertEqual( 50 | tl.to_list(y), [[[[0., 0., -3., -2.], [-1., 0., 0., -3.], 51 | [-2., -1., 0., 0.], [-3., -2., -1., 0.]]]]) 52 | 53 | if __name__ == '__main__': 54 | absltest.main() 55 | -------------------------------------------------------------------------------- /trax/layers/research/rotary_positional_embedding.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Rotary positional embeddings. 17 | 18 | Rotary positional embedding implementation, as described in: 19 | https://arxiv.org/pdf/2104.09864.pdf 20 | """ 21 | 22 | # from trax import layers as tl 23 | from trax.fastmath import numpy as jnp 24 | from trax.layers import core 25 | 26 | 27 | def rotate(x): 28 | """Rotate function.""" 29 | _, l, d = x.shape 30 | inv_freq = jnp.exp(jnp.arange(0, d, 2) * -(jnp.log(10000.0) / d)) 31 | positions = jnp.arange(l) 32 | freqs = jnp.einsum('i,j->ij', positions, inv_freq) 33 | emb = jnp.concatenate((freqs, freqs), axis=-1) 34 | cos = jnp.cos(emb) 35 | sin = jnp.sin(emb) 36 | 37 | def mul(vecs, pos_emb): 38 | return jnp.einsum('bld,ld->bld', vecs, pos_emb) 39 | 40 | def rotate_half(x): 41 | x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] 42 | return jnp.concatenate((-x2, x1), axis=x1.ndim - 1) 43 | 44 | return mul(x, cos) + mul(rotate_half(x), sin) 45 | 46 | 47 | def Rotate(): # pylint: disable=invalid-name 48 | return core.Fn('Rotate', rotate) 49 | -------------------------------------------------------------------------------- /trax/layers/research/rotary_positional_embedding_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for trax.layers.research.rotary_positional_embedding.""" 17 | 18 | from absl.testing import absltest 19 | import numpy as np 20 | from trax.layers.research import rotary_positional_embedding as rotary_pe 21 | 22 | 23 | class RotaryPositionalEmbeddingTest(absltest.TestCase): 24 | 25 | def test_rotary_monotonicity(self): 26 | layer = rotary_pe.Rotate() 27 | batch_size = 1 28 | seq_len = 32 29 | d_model = 512 30 | shape = (batch_size, seq_len, d_model) 31 | q, k = np.ones(shape).astype(np.float32), np.ones(shape).astype(np.float32) 32 | q, k = layer(q), layer(k) 33 | 34 | self.assertEqual(q.dtype, np.float32) 35 | self.assertEqual(q.shape, shape) 36 | 37 | # Check that dot products with nearby tokens decrease monotonically with 38 | # distance, for the first two query positions. 39 | dot_product = np.einsum('bnd, bmd -> bnm', q, k) 40 | 41 | self.assertTrue((dot_product[0, 0, :9] > dot_product[0, 0, 1:10]).all()) 42 | self.assertTrue((dot_product[0, 1, 1:10] > dot_product[0, 1, 2:11]).all()) 43 | 44 | 45 | if __name__ == '__main__': 46 | absltest.main() 47 | -------------------------------------------------------------------------------- /trax/layers/reversible_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for reversible layers.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import numpy as np 21 | 22 | from trax import fastmath 23 | import trax.layers as tl 24 | 25 | 26 | BACKENDS = [fastmath.Backend.JAX] 27 | 28 | 29 | class ReversibleLayerTest(parameterized.TestCase): 30 | 31 | @parameterized.named_parameters([('_' + b.value, b) for b in BACKENDS]) 32 | def test_reversible_swap(self, backend): 33 | with fastmath.use_backend(backend): 34 | layer = tl.ReversibleSwap() 35 | xs = [np.array([1, 2]), np.array([10, 20])] 36 | ys = layer(xs) 37 | self.assertEqual(tl.to_list(ys), [[10, 20], [1, 2]]) 38 | 39 | 40 | if __name__ == '__main__': 41 | absltest.main() 42 | -------------------------------------------------------------------------------- /trax/layers/rnn_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
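# ReversibleSwap in the test above is its own inverse: applying it twice
# returns the original inputs. That invertibility is what lets reversible
# networks recompute activations in the backward pass instead of storing
# them. A small sketch (assuming the default JAX backend, as in the test):
#
#   import numpy as np
#   import trax.layers as tl
#
#   layer = tl.ReversibleSwap()
#   xs = [np.array([1, 2]), np.array([10, 20])]
#   ys = layer(xs)
#   assert tl.to_list(layer(ys)) == tl.to_list(xs)  # swap twice == identity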
15 | 16 | """Tests for rnn layers.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import numpy as np 21 | 22 | from trax import fastmath 23 | from trax import shapes 24 | import trax.layers as tl 25 | 26 | 27 | BACKENDS = [fastmath.Backend.JAX] 28 | 29 | 30 | @parameterized.named_parameters( 31 | ('_' + b.value, b) for b in BACKENDS) 32 | class RnnTest(parameterized.TestCase): 33 | 34 | def test_conv_gru_cell(self, backend): 35 | with fastmath.use_backend(backend): 36 | layer = tl.ConvGRUCell(9, kernel_size=(3, 3)) 37 | x = np.ones((8, 1, 7, 9)) 38 | _, _ = layer.init(shapes.signature(x)) 39 | y = layer(x) 40 | self.assertEqual(y.shape, x.shape) 41 | 42 | def test_gru_cell(self, backend): 43 | with fastmath.use_backend(backend): 44 | layer = tl.GRUCell(9) 45 | xs = [np.ones((8, 7, 9)), np.ones((8, 7, 9))] 46 | _, _ = layer.init(shapes.signature(xs)) 47 | ys = layer(xs) 48 | self.assertEqual([y.shape for y in ys], [(8, 7, 9), (8, 7, 9)]) 49 | 50 | def test_lstm_cell(self, backend): 51 | with fastmath.use_backend(backend): 52 | layer = tl.LSTMCell(9) 53 | xs = [np.ones((8, 9)), np.ones((8, 18))] 54 | _, _ = layer.init(shapes.signature(xs)) 55 | ys = layer(xs) 56 | self.assertEqual([y.shape for y in ys], [(8, 9), (8, 18)]) 57 | 58 | def test_sru(self, backend): 59 | with fastmath.use_backend(backend): 60 | layer = tl.SRU(7) 61 | x = np.ones((8, 9, 7), np.float32) 62 | _, _ = layer.init(shapes.signature(x)) 63 | y = layer(x) 64 | self.assertEqual(y.shape, x.shape) 65 | 66 | def test_names(self, backend): 67 | with fastmath.use_backend(backend): 68 | layer = tl.LSTM(3) 69 | self.assertEqual('LSTM_3', str(layer)) 70 | layer = tl.GRU(5) 71 | self.assertEqual('GRU_5', str(layer)) 72 | layer = tl.SRU(7) 73 | self.assertEqual('SRU_7', str(layer)) 74 | 75 | 76 | if __name__ == '__main__': 77 | absltest.main() 78 | -------------------------------------------------------------------------------- /trax/models/atari_cnn_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
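# Note on the LSTMCell shapes in the rnn tests above: the cell's recurrent
# state is carried as a single tensor packing both the hidden vector h and the
# memory vector c, which is why the second input/output has width 2 * n_units
# (18 = 2 * 9 in the test). This reading is inferred from the test shapes
# themselves rather than from separate documentation.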
15 | 16 | """Tests for trax.models.atari_cnn.""" 17 | 18 | import functools 19 | import operator as op 20 | import numpy as np 21 | from tensorflow import test 22 | from trax.models import atari_cnn 23 | from trax.shapes import ShapeDtype 24 | 25 | 26 | class AtariCnnTest(test.TestCase): 27 | 28 | def test_computes(self): 29 | hidden_size = (4, 4) 30 | output_size = 6 31 | model = atari_cnn.AtariCnn( 32 | hidden_sizes=hidden_size, output_size=output_size) 33 | B, T, OBS = 2, 2, (28, 28, 3) # pylint: disable=invalid-name 34 | input_signature = ShapeDtype((1, 1) + OBS) 35 | _, _ = model.init(input_signature) 36 | x = np.arange(B * (T + 1) * functools.reduce(op.mul, OBS)).reshape( 37 | B, T + 1, *OBS) 38 | y = model(x) 39 | self.assertEqual((B, T + 1, output_size), y.shape) 40 | 41 | 42 | class FrameStackMLPTest(test.TestCase): 43 | 44 | def test_computes(self): 45 | hidden_size = (4, 4) 46 | output_size = 6 47 | model = atari_cnn.FrameStackMLP( 48 | hidden_sizes=hidden_size, output_size=output_size) 49 | B, T, OBS = 2, 2, 3 # pylint: disable=invalid-name 50 | input_signature = ShapeDtype((1, 1, OBS)) 51 | _, _ = model.init(input_signature) 52 | x = np.arange(B * (T + 1) * OBS).reshape(B, T + 1, OBS) 53 | y = model(x) 54 | self.assertEqual((B, T + 1, output_size), y.shape) 55 | 56 | 57 | if __name__ == '__main__': 58 | test.main() 59 | -------------------------------------------------------------------------------- /trax/models/mlp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """mlp -- functions that assemble "multilayer perceptron" networks.""" 17 | 18 | from trax import layers as tl 19 | 20 | 21 | def MLP( 22 | layer_widths=(128, 64), 23 | activation_fn=tl.Relu, 24 | out_activation=False, 25 | flatten=True, 26 | mode='train'): 27 | """A "multilayer perceptron" (MLP) network. 28 | 29 | This is a classic fully connected feedforward network, with one or more 30 | layers and a (nonlinear) activation function between each layer. For 31 | historical reasons, such networks are often called multilayer perceptrons; 32 | but they are more accurately described as multilayer networks, where 33 | each layer + activation function is a perceptron-like unit (see, e.g., 34 | [https://en.wikipedia.org/wiki/Multilayer_perceptron#Terminology]). 35 | 36 | Args: 37 | layer_widths: Tuple of ints telling the number of layers and the width of 38 | each layer. For example, setting `layer_widths=(128, 64, 32)` would 39 | yield 3 layers with successive widths of 128, 64, and 32. 40 | activation_fn: Type of activation function between pairs of fully connected 41 | layers; must be an activation-type subclass of `Layer`. 42 | out_activation: If True, include a copy of the activation function as the 43 | last layer in the network. 
44 | flatten: If True, insert a layer at the head of the network to flatten the 45 | input tensor into a matrix of shape (batch_size, -1). 46 | mode: Ignored. 47 | 48 | Returns: 49 | An assembled MLP network with the specified layers. This network can either 50 | be initialized and trained as a full model, or can be used as a building 51 | block in a larger network. 52 | """ 53 | del mode 54 | 55 | layers = [] 56 | for width in layer_widths: 57 | layers.append(tl.Dense(width)) 58 | layers.append(activation_fn()) 59 | 60 | if not out_activation: 61 | # Don't need the last activation. 62 | layers.pop() 63 | 64 | return tl.Serial( 65 | [tl.Flatten()] if flatten else [], 66 | layers, 67 | ) 68 | -------------------------------------------------------------------------------- /trax/models/mlp_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for MLP.""" 17 | 18 | from absl.testing import absltest 19 | import numpy as np 20 | 21 | from trax import fastmath 22 | from trax import shapes 23 | from trax.models import mlp 24 | 25 | 26 | class MLPTest(absltest.TestCase): 27 | 28 | def test_mlp_forward_shape(self): 29 | model = mlp.MLP(layer_widths=(32, 16, 8)) 30 | x = np.ones((7, 28, 28, 3)).astype(np.float32) 31 | _, _ = model.init(shapes.signature(x)) 32 | y = model(x) 33 | self.assertEqual(y.shape, (7, 8)) 34 | 35 | 36 | 37 | if __name__ == '__main__': 38 | absltest.main() 39 | -------------------------------------------------------------------------------- /trax/models/neural_gpu.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Implementation of the improved Neural GPU (NGPU).""" 17 | 18 | from trax import layers as tl 19 | from trax.fastmath import numpy as jnp 20 | 21 | 22 | # TODO(ddohan): Combinator to add saturation costs to loss 23 | def SaturationCost(x, limit=0.9): 24 | return jnp.minimum(0, jnp.abs(x) - limit) 25 | 26 | 27 | def DiagonalGate(): 28 | """Split channels in 3 parts.
Shifts 1st and 3rd sections to left/right.""" 29 | 30 | def f(x): # pylint: disable=invalid-name 31 | # x : [batch, 1, length, depth] 32 | x = jnp.pad(x, [(0, 0), (0, 0), (1, 1), (0, 0)], 33 | mode='constant', constant_values=0.0) 34 | depth = x.shape[-1] // 3 35 | assert 3 * depth == x.shape[-1], ('Depth must be divisible by 3', depth, 36 | x.shape) 37 | xs = [ 38 | x[:, :, :-2, :depth], x[:, :, 1:-1, depth:2 * depth], 39 | x[:, :, 2:, 2 * depth:3 * depth] 40 | ] 41 | return jnp.concatenate(xs, axis=3) 42 | return tl.Fn('DiagonalGate', f) 43 | 44 | 45 | def ConvDiagonalGRU(units, kernel_size=(3, 3)): 46 | """Build convolutional GRU with diagonal gating as in ImprovedNGPU.""" 47 | 48 | def BuildConv(): 49 | return tl.Conv(filters=units, kernel_size=kernel_size, padding='SAME') 50 | 51 | return tl.GeneralGRUCell( 52 | candidate_transform=BuildConv, 53 | memory_transform_fn=DiagonalGate, 54 | gate_nonlinearity=tl.HardSigmoid, 55 | candidate_nonlinearity=tl.HardTanh) 56 | 57 | 58 | def NeuralGPU(d_feature=96, steps=16, vocab_size=2, mode='train'): 59 | """Implementation of Neural GPU: https://arxiv.org/abs/1702.08727. 60 | 61 | Args: 62 | d_feature: Number of memory channels (dimensionality of feature embedding). 63 | steps: Number of depthwise recurrence steps. 64 | vocab_size: Vocabulary size. 65 | mode: Whether we are training or evaluating or doing inference. 66 | 67 | Returns: 68 | A NeuralGPU Trax model. 69 | """ 70 | del mode 71 | 72 | core = ConvDiagonalGRU(units=d_feature) 73 | return tl.Serial( 74 | tl.Embedding(vocab_size=vocab_size, d_feature=d_feature), 75 | [core] * steps, 76 | tl.Dense(vocab_size), 77 | ) 78 | -------------------------------------------------------------------------------- /trax/models/neural_gpu_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for trax.models.neural_gpu.""" 17 | 18 | from absl.testing import absltest 19 | import numpy as np 20 | 21 | from trax import shapes 22 | from trax.models import neural_gpu 23 | 24 | 25 | class NeuralGPUTest(absltest.TestCase): 26 | 27 | def test_ngpu(self): 28 | model = neural_gpu.NeuralGPU(d_feature=30, steps=4, vocab_size=22) 29 | x = np.ones((3, 5, 7)).astype(np.int32) 30 | _, _ = model.init(shapes.signature(x)) 31 | y = model(x) 32 | self.assertEqual(y.shape, (3, 5, 7, 22)) 33 | 34 | 35 | if __name__ == '__main__': 36 | absltest.main() 37 | -------------------------------------------------------------------------------- /trax/models/reformer/README.md: -------------------------------------------------------------------------------- 1 | ## Reformer and Terraformer: The Efficient Transformers 2 | 3 | Reformer and Terraformer are more efficient versions of Transformer that use reversible layers, locality-sensitive hashing and sparse layers.
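As a quick start, both models can be built directly from `trax.models`; a minimal sketch (the sizes below are illustrative, not a pinned recipe):

```python
import trax

# A small Reformer language model; the argument values are illustrative.
model = trax.models.ReformerLM(vocab_size=32000, d_model=512, d_ff=2048,
                               n_layers=6)
```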
4 | 5 | ### Papers 6 | 7 | * Reformer: Read about the details of Reformer in the [Reformer paper](https://arxiv.org/abs/2001.04451), which was selected for oral presentation at [ICLR 2020](https://iclr.cc/Conferences/2020/). 8 | 9 | 10 | * Terraformer: Read about the details of Terraformer in the paper [Sparse is Enough in Scaling Transformers](https://arxiv.org/abs/2111.12763). 11 | 12 | ### Models 13 | 14 | 15 | * Generate images with Reformer using [this colab](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/image_generation.ipynb). 16 | 17 | * Translate from English to German with a reversible encoder-decoder model using [this colab](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/machine_translation.ipynb). 18 | 19 | * A medium-size (~700M weights) Terraformer model for summarizing arXiv articles is available at `gs://trax-ml/terraformer/medium`. 20 | 21 | * A large (~7B weights) Terraformer model pre-trained on C4 is available at `gs://trax-ml/terraformer/big`. 22 | -------------------------------------------------------------------------------- /trax/models/reformer/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /trax/models/reformer/reformer_e2e_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
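# A hedged sketch of restoring the released checkpoints listed in the README
# above. init_from_file() is the standard Trax way to load saved weights; the
# exact constructor sizes and the weights filename inside the gs:// buckets
# are assumptions here, not something this repository pins down:
#
#   import trax
#   model = trax.models.ReformerLM(vocab_size=32000)  # must match the ckpt
#   model.init_from_file('gs://trax-ml/terraformer/medium/model.pkl.gz',
#                        weights_only=True)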
15 | 16 | """End to end test for Reformer.""" 17 | 18 | import os 19 | 20 | from absl.testing import absltest 21 | import gin 22 | 23 | from trax import test_utils 24 | from trax.models.reformer import reformer # pylint: disable=unused-import 25 | from trax.supervised import trainer_lib 26 | from trax.tf_numpy import numpy as tf_np # pylint: disable=unused-import 27 | 28 | pkg_dir, _ = os.path.split(__file__) 29 | _TESTDATA = os.path.join(pkg_dir, 'testdata') 30 | _CONFIG_DIR = os.path.join(pkg_dir, '../../supervised/configs/') 31 | 32 | 33 | class ReformerE2ETest(absltest.TestCase): 34 | 35 | def setUp(self): 36 | super().setUp() 37 | gin.clear_config() 38 | gin.add_config_file_search_path(_CONFIG_DIR) 39 | test_utils.ensure_flag('test_tmpdir') 40 | 41 | def test_reformer_wmt_ende(self): 42 | batch_size_per_device = 2 43 | steps = 1 44 | n_layers = 2 45 | d_ff = 32 46 | 47 | gin.parse_config_file('reformer_wmt_ende.gin') 48 | 49 | gin.bind_parameter('data_streams.data_dir', _TESTDATA) 50 | gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device) 51 | gin.bind_parameter('train.steps', steps) 52 | gin.bind_parameter('Reformer.n_encoder_layers', n_layers) 53 | gin.bind_parameter('Reformer.n_decoder_layers', n_layers) 54 | gin.bind_parameter('Reformer.d_ff', d_ff) 55 | 56 | output_dir = self.create_tempdir().full_path 57 | _ = trainer_lib.train(output_dir=output_dir) 58 | 59 | def test_reformer_copy(self): 60 | batch_size_per_device = 2 61 | steps = 1 62 | n_layers = 2 63 | d_ff = 32 64 | d_model = 32 65 | 66 | gin.parse_config_file('reformer_copy.gin') 67 | 68 | gin.bind_parameter('data_streams.data_dir', _TESTDATA) 69 | gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device) 70 | gin.bind_parameter('train.steps', steps) 71 | gin.bind_parameter('ReformerLM.n_layers', n_layers) 72 | gin.bind_parameter('ReformerLM.d_ff', d_ff) 73 | gin.bind_parameter('ReformerLM.d_model', d_model) 74 | 75 | output_dir = self.create_tempdir().full_path 76 | _ = trainer_lib.train(output_dir=output_dir) 77 | 78 | 79 | if __name__ == '__main__': 80 | absltest.main() 81 | -------------------------------------------------------------------------------- /trax/models/reformer/testdata/translate_ende_wmt32k-dev-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/models/reformer/testdata/translate_ende_wmt32k-dev-00000-of-00001 -------------------------------------------------------------------------------- /trax/models/reformer/testdata/translate_ende_wmt32k-train-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/models/reformer/testdata/translate_ende_wmt32k-train-00000-of-00001 -------------------------------------------------------------------------------- /trax/models/research/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /trax/models/research/rezero_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for ReZero models.""" 17 | 18 | from absl.testing import absltest 19 | import numpy as np 20 | 21 | from trax import layers as tl 22 | from trax import shapes 23 | from trax.models.research import rezero 24 | 25 | 26 | class ResidualZeroTest(absltest.TestCase): 27 | 28 | def test_residual_layer_forward(self): 29 | """Tests that the forward pass runs and returns the expected shape.""" 30 | model = rezero.ResidualZero(tl.Dense(5)) 31 | x = [np.arange(5).astype(np.float32)] 32 | _, _ = model.init(shapes.signature(x)) 33 | y = model(x) 34 | self.assertEqual(y.tolist(), [0., 1., 2., 3., 4.]) 35 | 36 | 37 | class ReZeroTransformerLMTest(absltest.TestCase): 38 | 39 | def test_rezero_lm_forward_shape(self): 40 | """Tests that the forward pass runs and returns the expected shape.""" 41 | vocab_size = 16 42 | model = rezero.ReZeroTransformerLM( 43 | vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2, max_len=16) 44 | xs = [np.ones((1, 8)).astype(np.int32), 45 | np.ones((1, 8)).astype(np.int32)] 46 | _, _ = model.init(shapes.signature(xs)) 47 | ys = model(xs) 48 | self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) 49 | 50 | 51 | class ReZeroTransformerTest(absltest.TestCase): 52 | 53 | def test_rezero_forward_shape(self): 54 | """Tests that the forward pass runs and returns the expected shape.""" 55 | vocab_size = 16 56 | model = rezero.ReZeroTransformer( 57 | vocab_size, d_model=32, d_ff=64, n_encoder_layers=2, n_decoder_layers=2, 58 | n_heads=2, max_len=16) 59 | xs = [np.ones((1, 8)).astype(np.int32), 60 | np.ones((1, 8)).astype(np.int32)] 61 | _, _ = model.init(shapes.signature(xs)) 62 | ys = model(xs) 63 | self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) 64 | 65 | 66 | if __name__ == '__main__': 67 | absltest.main() 68 | -------------------------------------------------------------------------------- /trax/models/research/terraformer_memory_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Test for memory usage in Terraformer models. 17 | 18 | This test is designed to run on TPUv3 hardware, processing 1 million tokens at a 19 | time while just barely fitting within the 16 GB memory budget. 20 | """ 21 | 22 | 23 | from absl.testing import absltest 24 | from jax.config import config  # provides config_with_absl(), used below 25 | 26 | 27 | 28 | class TerraformerMemoryTest(absltest.TestCase): 29 | 30 | 31 | def test_terraformer_memory(self): 32 | pass # TODO(jonni): Figure out an OSS-compatible memory test. 33 | 34 | 35 | if __name__ == '__main__': 36 | config.config_with_absl() 37 | absltest.main() 38 | -------------------------------------------------------------------------------- /trax/models/research/testdata/translate_ende_wmt32k-dev-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/models/research/testdata/translate_ende_wmt32k-dev-00000-of-00001 -------------------------------------------------------------------------------- /trax/models/research/testdata/translate_ende_wmt32k-train-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/models/research/testdata/translate_ende_wmt32k-train-00000-of-00001 -------------------------------------------------------------------------------- /trax/models/resnet_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | """Tests for Resnet models.""" 17 | 18 | from absl.testing import absltest 19 | import numpy as np 20 | 21 | from trax import fastmath 22 | from trax import shapes 23 | from trax.models import resnet 24 | 25 | 26 | class ResnetTest(absltest.TestCase): 27 | 28 | def test_resnet(self): 29 | model = resnet.Resnet50(d_hidden=8, n_output_classes=10) 30 | x = np.ones((3, 256, 256, 3)).astype(np.float32) 31 | _, _ = model.init(shapes.signature(x)) 32 | y = model(x) 33 | self.assertEqual(y.shape, (3, 10)) 34 | 35 | def test_wide_resnet(self): 36 | model = resnet.WideResnet(n_blocks=1, n_output_classes=10) 37 | x = np.ones((3, 32, 32, 3)).astype(np.float32) 38 | _, _ = model.init(shapes.signature(x)) 39 | y = model(x) 40 | self.assertEqual(y.shape, (3, 10)) 41 | 42 | 43 | 44 | if __name__ == '__main__': 45 | absltest.main() 46 | -------------------------------------------------------------------------------- /trax/models/rl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for RL.""" 17 | 18 | from unittest import mock 19 | from absl.testing import absltest 20 | import numpy as np 21 | 22 | from trax import shapes 23 | from trax.models import rl 24 | 25 | 26 | class RLTest(absltest.TestCase): 27 | 28 | def test_policy_forward_shape(self): 29 | mock_dist = mock.MagicMock() 30 | mock_dist.n_inputs = 4 31 | model = rl.Policy(policy_distribution=mock_dist) 32 | x = np.ones((2, 3)) 33 | _, _ = model.init(shapes.signature(x)) 34 | y = model(x) 35 | self.assertEqual(y.shape, (2, 4)) 36 | 37 | def test_value_forward_shape(self): 38 | model = rl.Value() 39 | x = np.ones((2, 3)) 40 | _, _ = model.init(shapes.signature(x)) 41 | y = model(x) 42 | self.assertEqual(y.shape, (2, 1)) 43 | 44 | def test_policy_and_value_forward_shape(self): 45 | mock_dist = mock.MagicMock() 46 | mock_dist.n_inputs = 4 47 | model = rl.PolicyAndValue(policy_distribution=mock_dist) 48 | x = np.ones((2, 3)) 49 | _, _ = model.init(shapes.signature(x)) 50 | ys = model(x) 51 | self.assertEqual([y.shape for y in ys], [(2, 4), (2, 1)]) 52 | 53 | 54 | if __name__ == '__main__': 55 | absltest.main() 56 | -------------------------------------------------------------------------------- /trax/models/rnn_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for RNNs.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import numpy as np 21 | 22 | from trax import fastmath 23 | from trax import shapes 24 | from trax.models import rnn 25 | 26 | BACKENDS = [fastmath.Backend.JAX] 27 | 28 | 29 | @parameterized.named_parameters( 30 | ('_' + b.value, b) for b in BACKENDS) 31 | class RNNTest(parameterized.TestCase): 32 | 33 | def test_rnnlm_forward_shape(self, backend): 34 | with fastmath.use_backend(backend): 35 | model = rnn.RNNLM(vocab_size=20, d_model=16) 36 | x = np.ones((3, 28)).astype(np.int32) 37 | _, _ = model.init(shapes.signature(x)) 38 | y = model(x) 39 | self.assertEqual(y.shape, (3, 28, 20)) 40 | 41 | def test_grulm_forward_shape(self, backend): 42 | with fastmath.use_backend(backend): 43 | model = rnn.GRULM(vocab_size=20, d_model=16) 44 | x = np.ones((3, 28)).astype(np.int32) 45 | _, _ = model.init(shapes.signature(x)) 46 | y = model(x) 47 | self.assertEqual(y.shape, (3, 28, 20)) 48 | 49 | def test_lstmseq2seqattn_forward_shape(self, backend): 50 | with fastmath.use_backend(backend): 51 | model = rnn.LSTMSeq2SeqAttn( 52 | input_vocab_size=20, target_vocab_size=20, d_model=16) 53 | x = np.ones((3, 28)).astype(np.int32) 54 | _, _ = model.init([shapes.signature(x), shapes.signature(x)]) 55 | ys = model([x, x]) 56 | self.assertEqual([y.shape for y in ys], [(3, 28, 20), (3, 28)]) 57 | 58 | 59 | if __name__ == '__main__': 60 | absltest.main() 61 | -------------------------------------------------------------------------------- /trax/models/transformer_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
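# Note on LSTMSeq2SeqAttn in the rnn tests above: the model consumes a pair
# [input_tokens, target_tokens] and returns (logits, pass-through targets),
# which is why the test expects shapes [(3, 28, 20), (3, 28)]. The same
# (logits, targets) convention shows up for Transformer below, where the test
# unpacks `y, _ = model(xs)`.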
15 | 16 | """Tests for Transformer models.""" 17 | 18 | import functools 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | import numpy as np 23 | 24 | from trax import fastmath 25 | from trax import shapes 26 | from trax.layers import test_utils 27 | from trax.models import transformer 28 | 29 | 30 | class TransformerTest(parameterized.TestCase): 31 | 32 | def test_transformer_lm_forward_shape(self): 33 | vocab_size = 16 34 | model = transformer.TransformerLM( 35 | vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2) 36 | x = np.ones((3, 5)).astype(np.int32) 37 | _, _ = model.init(shapes.signature(x)) 38 | y = model(x) 39 | self.assertEqual(y.shape, (3, 5, vocab_size)) 40 | 41 | def _test_transformer_forward_shape(self, input_vocab_size, 42 | output_vocab_size): 43 | model = transformer.Transformer( 44 | input_vocab_size, output_vocab_size, d_model=32, d_ff=64, 45 | n_encoder_layers=2, n_decoder_layers=2, n_heads=2) 46 | xs = [np.ones((3, 5)).astype(np.int32), np.ones((3, 5)).astype(np.int32)] 47 | _, _ = model.init(shapes.signature(xs)) 48 | y, _ = model(xs) 49 | 50 | vocab_size = output_vocab_size or input_vocab_size 51 | self.assertEqual(y.shape, (3, 5, vocab_size)) 52 | 53 | @parameterized.named_parameters( 54 | ('same_vocab', 16, None), 55 | ('same_size', 16, 16), 56 | ('different_size', 16, 50)) 57 | def test_transformer_forward_shape(self, input_vocab_size, output_vocab_size): 58 | """Run the Transformer forward and check output shape.""" 59 | self._test_transformer_forward_shape(input_vocab_size, output_vocab_size) 60 | 61 | 62 | def test_dot_product_causal_attention_fast_inference(self): 63 | model_fn = functools.partial( 64 | transformer.TransformerLM, d_model=4, d_ff=8, n_layers=2, n_heads=2 65 | ) 66 | test_utils.test_eval_equals_predict_discrete(model_fn) 67 | 68 | 69 | if __name__ == '__main__': 70 | absltest.main() 71 | -------------------------------------------------------------------------------- /trax/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Optimizers for use with Trax layers.""" 17 | 18 | import gin 19 | 20 | from trax.optimizers import adafactor 21 | from trax.optimizers import adam 22 | from trax.optimizers import base 23 | from trax.optimizers import momentum 24 | from trax.optimizers import rms_prop 25 | from trax.optimizers import sm3 26 | from trax.optimizers import trainer 27 | from trax.optimizers.trainer import ReversibleSerialTrainer 28 | from trax.optimizers.trainer import Trainer 29 | 30 | 31 | def opt_configure(*args, **kwargs): 32 | kwargs['module'] = 'trax.optimizers' 33 | return gin.external_configurable(*args, **kwargs) 34 | 35 | # Optimizers (using upper-case names). 
36 | # pylint: disable=invalid-name 37 | SGD = opt_configure(base.SGD) 38 | Momentum = opt_configure(momentum.Momentum) 39 | RMSProp = opt_configure(rms_prop.RMSProp) 40 | Adam = opt_configure(adam.Adam) 41 | Adafactor = opt_configure(adafactor.Adafactor) 42 | SM3 = opt_configure(sm3.SM3) 43 | -------------------------------------------------------------------------------- /trax/optimizers/momentum.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Nesterov momentum optimizer (also known as Nesterov Accelerated Gradient).""" 17 | 18 | from trax.fastmath import numpy as jnp 19 | from trax.optimizers import base 20 | 21 | 22 | # TODO(jonni): Consider renaming this class to NesterovMomentum. 23 | class Momentum(base.Optimizer): 24 | r"""A momentum optimizer. 25 | 26 | This class implements two variants of momentum stochastic gradient descent 27 | (SGD): with and without the Nesterov correction. The implementation of the 28 | Nesterov update is based on the concepts in Sutskever et al. (2013) 29 | [http://jmlr.org/proceedings/papers/v28/sutskever13.pdf], reformulated in 30 | Bengio et al. (2012) [https://arxiv.org/abs/1212.0901], to work well with 31 | backpropagation (equations 6 and 7): 32 | 33 | .. math:: 34 | v_t &= \mu_{t-1}v_{t-1} - \epsilon_{t-1}\nabla f(\Theta_{t-1}) \\ 35 | \Theta_t &= \Theta_{t-1} - \mu_{t-1} v_{t-1} + \mu_t v_t + v_t 36 | 37 | where :math:`\mu_{t-1}` is the momentum (decay) coefficient at time step 38 | :math:`t-1` and :math:`\epsilon_{t-1}` is the learning rate at :math:`t-1`. 39 | 40 | Note that the implementation below also includes a weight decay rate 41 | (:math:`\alpha`) on the parameters, independent of the Nesterov momentum. 42 | """ 43 | 44 | def __init__( 45 | self, learning_rate=0.01, mass=0.9, weight_decay_rate=1e-5, nesterov=True 46 | ): # pylint: disable=useless-super-delegation 47 | super().__init__( 48 | learning_rate=learning_rate, 49 | mass=mass, 50 | weight_decay_rate=weight_decay_rate, 51 | ) 52 | self._nesterov = nesterov 53 | 54 | def init(self, weights): 55 | return jnp.zeros_like(weights) 56 | 57 | def update(self, step, grads, weights, velocity, opt_params): 58 | del step 59 | v = velocity 60 | mu = opt_params['mass'] 61 | alpha = opt_params['weight_decay_rate'] 62 | epsilon = opt_params['learning_rate'] 63 | 64 | new_v = mu * v + grads 65 | if self._nesterov: 66 | weight_update = mu * new_v + grads 67 | else: 68 | weight_update = new_v 69 | new_weights = (1 - alpha) * weights - epsilon * weight_update 70 | 71 | new_weights = new_weights.astype(weights.dtype) 72 | return (new_weights, new_v) 73 | -------------------------------------------------------------------------------- /trax/optimizers/optimizers_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for supervised training optimizers.""" 17 | 18 | from absl.testing import absltest 19 | 20 | import numpy as np 21 | 22 | from trax import optimizers 23 | from trax.optimizers import momentum 24 | 25 | 26 | class OptimizersTest(absltest.TestCase): 27 | 28 | def test_slots(self): 29 | weights_shape = (3, 5) 30 | weight_tree = np.arange(15).reshape(weights_shape) 31 | 32 | # SGD - an optimizer that doesn't use slots. 33 | opt_1 = optimizers.SGD(.01) 34 | self.assertIsNone(opt_1.slots) 35 | opt_1.tree_init(weight_tree) 36 | self.assertIsInstance(opt_1.slots, tuple) 37 | self.assertLen(opt_1.slots, 1) 38 | self.assertIsNone(opt_1.slots[0]) 39 | 40 | # Momentum - an optimizer with slots 41 | opt_2 = momentum.Momentum(.01) 42 | self.assertIsNone(opt_2.slots) 43 | opt_2.tree_init(weight_tree) 44 | self.assertIsInstance(opt_2.slots, tuple) 45 | self.assertLen(opt_2.slots, 1) 46 | self.assertEqual(weights_shape, opt_2.slots[0].shape) 47 | 48 | 49 | if __name__ == '__main__': 50 | absltest.main() 51 | -------------------------------------------------------------------------------- /trax/optimizers/rms_prop.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """RMSProp optimizer class.""" 17 | 18 | from trax.fastmath import numpy as jnp 19 | from trax.optimizers import base as opt_base 20 | 21 | 22 | class RMSProp(opt_base.Optimizer): 23 | """RMSProp optimizer. 24 | 25 | Uses optimizer weights ("slots") to maintain a root-mean-square exponentially 26 | decaying average of gradients from prior training batches. 27 | """ 28 | 29 | def __init__(self, learning_rate=0.001, gamma=0.9, 30 | eps=1e-8, clip_grad_norm=None): # pylint: disable=useless-super-delegation 31 | super().__init__( 32 | learning_rate=learning_rate, 33 | gamma=gamma, 34 | eps=eps, 35 | clip_grad_norm=clip_grad_norm 36 | ) 37 | 38 | def init(self, weights): 39 | return jnp.ones_like(weights) 40 | 41 | def update(self, step, grads, weights, avg_sq_grad, opt_params): 42 | del step 43 | lr = opt_params['learning_rate'] 44 | gamma = opt_params['gamma'] 45 | eps = opt_params['eps'] 46 | avg_sq_grad = avg_sq_grad * gamma + grads**2 * (1. 
- gamma) 47 | weights = weights - (lr * grads / 48 | (jnp.sqrt(avg_sq_grad) + eps)).astype(weights.dtype) 49 | return weights, avg_sq_grad 50 | -------------------------------------------------------------------------------- /trax/rl/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Trax RL library.""" 17 | 18 | import gin 19 | 20 | from trax.rl import actor_critic 21 | from trax.rl import actor_critic_joint 22 | from trax.rl import envs 23 | from trax.rl import serialization_utils 24 | from trax.rl import training 25 | 26 | 27 | def configure_rl(*args, **kwargs): 28 | kwargs['module'] = 'trax.rl' 29 | kwargs['denylist'] = ['task', 'output_dir'] 30 | return gin.external_configurable(*args, **kwargs) 31 | 32 | 33 | A2C = configure_rl(actor_critic.A2C) 34 | AWR = configure_rl(actor_critic.AWR) 35 | LoopAWR = configure_rl(actor_critic.LoopAWR) 36 | PPO = configure_rl(actor_critic.PPO) 37 | SamplingAWR = configure_rl(actor_critic.SamplingAWR) 38 | 39 | A2CJoint = configure_rl(actor_critic_joint.A2CJoint) 40 | AWRJoint = configure_rl(actor_critic_joint.AWRJoint) 41 | PPOJoint = configure_rl(actor_critic_joint.PPOJoint) 42 | 43 | PolicyGradient = configure_rl(training.PolicyGradient) 44 | ExpertIteration = configure_rl(training.ExpertIteration) 45 | DQN = configure_rl(training.DQN) 46 | 47 | TimeSeriesModel = gin.external_configurable( 48 | serialization_utils.TimeSeriesModel, module='trax.rl' 49 | ) 50 | -------------------------------------------------------------------------------- /trax/rl/atari_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
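# A worked numpy example of the RMSProp update defined in rms_prop.py above:
# the squared-gradient average decays with gamma and the step divides the
# gradient by its root-mean-square (an editorial sketch, using the same
# constants as the optimizer's defaults):
#
#   import numpy as np
#
#   lr, gamma, eps = 0.001, 0.9, 1e-8
#   w = np.array([1.0, 2.0])
#   g = np.array([0.5, -0.5])
#   avg_sq = np.ones_like(w)                   # RMSProp.init() starts at ones
#   avg_sq = avg_sq * gamma + g**2 * (1. - gamma)
#   w = w - lr * g / (np.sqrt(avg_sq) + eps)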
15 | 16 | """Tests for RL training.""" 17 | 18 | import functools  # pylint: disable=unused-import 19 | 20 | from absl.testing import absltest 21 | 22 | from trax import models  # pylint: disable=unused-import 23 | from trax import optimizers as opt  # pylint: disable=unused-import 24 | from trax.models import atari_cnn  # pylint: disable=unused-import 25 | from trax.rl import actor_critic  # pylint: disable=unused-import 26 | from trax.rl import task as rl_task  # pylint: disable=unused-import 27 | from trax.supervised import lr_schedules  # pylint: disable=unused-import 28 | 29 | 30 | 31 | 32 | if __name__ == '__main__': 33 | absltest.main() 34 | -------------------------------------------------------------------------------- /trax/rl/configs/light_a2c_joint_atari_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | A2CJoint.train_steps_per_epoch: [40, 80] 16 | multifactor.constant: [0.0007, 0.0001] 17 | A2CJoint.value_loss_coeff: [0.1, 0.25] 18 | A2CJoint.batch_size: [128, 256] 19 | A2CJoint.n_trajectories_per_epoch: [5,] 20 | RLTask.max_steps: [2000,] 21 | A2CJoint.entropy_coeff: [0.01, 0.1,] 22 | -------------------------------------------------------------------------------- /trax/rl/configs/light_atari_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | RLTask.env: [ 16 | "boxing", 17 | "breakout", 18 | "freeway", 19 | "gopher", 20 | "pong", 21 | "seaquest", 22 | ] 23 | 24 | # Sweep 1: 25 | train_steps: [100, 300, 1000, 3000] 26 | SamplingAWR.n_interactions_per_epoch: [200, 500, 1000, 2000] 27 | SamplingAWR.n_replay_epochs: [50, 100, 200, 500, 1000] 28 | 29 | # Sweep 2: 30 | SamplingAWR.q_value_temperature: [0.3, 1.0] 31 | policy_lr: [0.00003, 0.0001, 0.0003] 32 | value_lr: [0.001, 0.0005, 0.0001] 33 | 34 | # Sweep 3: 35 | SamplingAWR.value_evals_per_epoch: [5, 10, 20, 50] 36 | -------------------------------------------------------------------------------- /trax/rl/configs/light_awr_cartpole_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | AWRJoint.policy_train_steps_per_epoch: [10, 100, 1000] 16 | multifactor.constant: [0.0001, 0.001, 0.01] 17 | AWRJoint.batch_size: [32, 128] 18 | -------------------------------------------------------------------------------- /trax/rl/configs/light_awr_joint_atari_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | AWRJoint.train_steps_per_epoch: [10, 100, 1000] 16 | multifactor.constant: [0.0001, 0.001, 0.01] 17 | AWRJoint.batch_size: [32, 128] 18 | RLTask.max_steps: [200, 2000] 19 | RLTask.env: ["pong"] 20 | -------------------------------------------------------------------------------- /trax/rl/configs/light_awr_joint_cartpole.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
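# Note on the *_sweep.yaml files above: each key names a gin binding and each
# list enumerates the values a hyperparameter search tries (the sweep runner
# itself is not part of this repository). Any single run from such a sweep is
# equivalent to fixing one value per binding in a gin file, e.g.:
#
#   AWRJoint.batch_size = 32
#   multifactor.constant = 0.001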
14 | 15 | import trax.supervised.lr_schedules 16 | import trax.models 17 | import trax.optimizers 18 | import trax.rl 19 | import trax.rl_trainer 20 | 21 | # Parameters for PolicyAndValue: 22 | # ============================================================================== 23 | PolicyAndValue.body = @trax.models.MLP 24 | 25 | # Parameters for MLP: 26 | # ============================================================================== 27 | MLP.flatten = False 28 | MLP.layer_widths = (128,) 29 | MLP.out_activation = True 30 | 31 | # Parameters for multifactor: 32 | # ============================================================================== 33 | multifactor.constant = 0.01 34 | multifactor.factors = 'constant' 35 | 36 | # Parameters for RLTask: 37 | # ============================================================================== 38 | RLTask.env = "CartPole-v0" 39 | RLTask.initial_trajectories = 1000 40 | RLTask.gamma = 0.99 41 | RLTask.max_steps = 200 42 | 43 | # Parameters for AWRJoint: 44 | # ============================================================================== 45 | AWRJoint.joint_model = @trax.models.PolicyAndValue 46 | AWRJoint.optimizer = @trax.optimizers.Adam 47 | AWRJoint.batch_size = 32 48 | AWRJoint.train_steps_per_epoch = 1000 49 | AWRJoint.lr_schedule = @multifactor 50 | AWRJoint.n_trajectories_per_epoch = 10 51 | AWRJoint.beta = 1.0 52 | AWRJoint.w_max = 20 53 | AWRJoint.max_slice_length = 1 54 | 55 | # Parameters for train_rl: 56 | # ============================================================================== 57 | train_rl.light_rl = True 58 | train_rl.light_rl_trainer = @trax.rl.AWRJoint 59 | train_rl.n_epochs = 10000 60 | -------------------------------------------------------------------------------- /trax/rl/configs/light_awr_joint_cartpole_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | AWRJoint.train_steps_per_epoch: [10, 100, 1000] 16 | multifactor.constant: [0.0001, 0.001, 0.01] 17 | AWRJoint.batch_size: [32, 128] 18 | -------------------------------------------------------------------------------- /trax/rl/configs/light_copy_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
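# The light_awr_joint_cartpole.gin file above is a complete AWRJoint training
# setup. A typical launch goes through the trax.rl_trainer entry point that it
# imports; a hedged sketch (flag names assumed from the standard Trax trainer
# convention, so verify against trax/rl_trainer.py before relying on them):
#
#   python -m trax.rl_trainer \
#     --config_file=trax/rl/configs/light_awr_joint_cartpole.gin \
#     --output_dir=$HOME/rl_out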
14 | 15 | # Task difficulty: 16 | copy_length: [3, 4, 6, 8] 17 | # Model params: 18 | TransformerDecoder.n_layers: [1, 2, 3] 19 | d_model: [32, 64, 128, 256, 512] 20 | -------------------------------------------------------------------------------- /trax/rl/configs/light_mujoco_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | RLTask.env: [ 16 | "DM-HalfCheetah-v2", 17 | "DM-Hopper-v2", 18 | "DM-Humanoid-v2", 19 | "DM-Walker2d-v2", 20 | ] 21 | 22 | # Sweep 1: 23 | train_steps: [100, 300, 1000, 3000] 24 | SamplingAWR.n_interactions_per_epoch: [200, 500, 1000, 2000] 25 | SamplingAWR.n_replay_epochs: [50, 100, 200, 500, 1000] 26 | 27 | # Sweep 2: 28 | SamplingAWR.q_value_temperature: [0.3, 1.0] 29 | policy_lr: [0.00003, 0.0001, 0.0003] 30 | value_lr: [0.001, 0.0005, 0.0001] 31 | 32 | # Sweep 3: 33 | SamplingAWR.value_evals_per_epoch: [5, 10, 20, 50] 34 | -------------------------------------------------------------------------------- /trax/rl/configs/light_ppo_atari.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import trax.layers 16 | import trax.supervised.lr_schedules 17 | import trax.models 18 | import trax.optimizers 19 | import trax.rl 20 | import trax.rl_trainer 21 | 22 | # Parameters for Policy: 23 | # ============================================================================== 24 | Policy.body = @trax.models.AtariCnnBody 25 | 26 | # Parameters for Value: 27 | # ============================================================================== 28 | Value.body = @trax.models.AtariCnnBody 29 | 30 | # Parameters for the AtariCnnBody: 31 | # ============================================================================== 32 | AtariCnnBody.kernel_initializer = @trax.layers.AtariConvInit 33 | 34 | # Parameters for multifactor: 35 | # ============================================================================== 36 | value/multifactor.constant = 0.0001 37 | value/multifactor.factors = 'constant' 38 | policy/multifactor.constant = 0.0001 39 | policy/multifactor.factors = 'constant' 40 | 41 | # Parameters for RLTask: 42 | # ============================================================================== 43 | RLTask.env = "boxing" 44 | RLTask.initial_trajectories = 100 45 | RLTask.gamma = 0.99 46 | RLTask.max_steps = 200 47 | RLTask.dm_suite = True 48 | 49 | # Parameters for PPO: 50 | # ============================================================================== 51 | PPO.n_shared_layers=0 52 | PPO.value_model = @trax.models.Value 53 | PPO.value_optimizer = @trax.optimizers.Adam 54 | PPO.value_batch_size = 4 55 | PPO.value_train_steps_per_epoch = 10 56 | PPO.value_lr_schedule = @value/multifactor 57 | PPO.policy_model = @trax.models.Policy 58 | PPO.policy_optimizer = @trax.optimizers.Adam 59 | PPO.policy_batch_size = 4 60 | PPO.policy_train_steps_per_epoch = 10 61 | PPO.policy_lr_schedule = @policy/multifactor 62 | PPO.n_trajectories_per_epoch = 50 63 | 64 | # Parameters for train_rl: 65 | # ============================================================================== 66 | train_rl.light_rl = True 67 | train_rl.light_rl_trainer = @trax.rl.PPO 68 | train_rl.n_epochs = 5000 69 | -------------------------------------------------------------------------------- /trax/rl/configs/light_ppo_boxing_regression_test.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
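In light_ppo_atari.gin above, value/multifactor and policy/multifactor are gin scopes: a single configurable (multifactor) receives different bindings depending on the scope it is instantiated under, which is how the value and policy heads get separate learning rates. A self-contained toy illustration of the mechanism (the multifactor below is a stand-in, not the Trax schedule):

    import gin

    @gin.configurable
    def multifactor(constant=1.0):
        return constant

    gin.parse_config("""
    value/multifactor.constant = 0.0001
    policy/multifactor.constant = 0.00025
    """)

    with gin.config_scope('value'):
        assert multifactor() == 0.0001
    with gin.config_scope('policy'):
        assert multifactor() == 0.00025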
14 | 15 | import trax.supervised.lr_schedules 16 | import trax.models 17 | import trax.optimizers 18 | import trax.rl 19 | import trax.rl_trainer 20 | 21 | # Parameters for Adam: 22 | # ============================================================================== 23 | Adam.clip_grad_norm = 0.5 24 | 25 | # Parameters for PolicyAndValue: 26 | # ============================================================================== 27 | PolicyAndValue.body = @trax.models.AtariCnnBody 28 | 29 | # Parameters for the AtariCnnBody: 30 | # ============================================================================== 31 | AtariCnnBody.kernel_initializer = @trax.layers.AtariConvInit 32 | AtariCnnBody.n_frames = 1 33 | AtariCnnBody.padding = 'VALID' 34 | 35 | # Parameters for multifactor: 36 | # ============================================================================== 37 | multifactor.constant = 0.0001 38 | multifactor.factors = 'constant * linear_warmup' 39 | multifactor.warmup_steps = 100 40 | 41 | # Parameters for RLTask: 42 | # ============================================================================== 43 | RLTask.env = "boxing" 44 | RLTask.initial_trajectories = 0 45 | RLTask.gamma = 0.99 46 | RLTask.max_steps = 2000 47 | RLTask.dm_suite = True 48 | RLTask.num_stacked_frames = 4 49 | 50 | # Parameters for PPOJoint: 51 | # ============================================================================== 52 | PPOJoint.joint_model = @trax.models.PolicyAndValue 53 | PPOJoint.optimizer = @trax.optimizers.Adam 54 | PPOJoint.batch_size = 256 55 | PPOJoint.train_steps_per_epoch = 80 56 | PPOJoint.lr_schedule = @multifactor 57 | PPOJoint.n_trajectories_per_epoch = 5 58 | PPOJoint.epsilon = 0.1 59 | PPOJoint.value_loss_coeff = 0.1 60 | PPOJoint.entropy_coeff = 0.01 61 | 62 | # Parameters for train_rl: 63 | # ============================================================================== 64 | train_rl.light_rl = True 65 | train_rl.light_rl_trainer = @trax.rl.PPOJoint 66 | train_rl.n_epochs = 900 67 | -------------------------------------------------------------------------------- /trax/rl/configs/light_ppo_cartpole_regression_test.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
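The boxing config above schedules its learning rate with factors = 'constant * linear_warmup' and warmup_steps = 100: multifactor multiplies the named factors together, so the rate ramps linearly from zero to the constant over the first 100 steps and is flat afterwards. A sketch of that product (same factor semantics as trax.supervised.lr_schedules, but not the Trax implementation):

    def constant_times_linear_warmup(step, constant=0.0001, warmup_steps=100):
        # factors = 'constant * linear_warmup'
        linear_warmup = min(1.0, step / warmup_steps)
        return constant * linear_warmup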
14 | 15 | import trax.supervised.lr_schedules 16 | import trax.models 17 | import trax.optimizers 18 | import trax.rl 19 | import trax.rl_trainer 20 | 21 | # Parameters for Adam: 22 | # ============================================================================== 23 | Adam.clip_grad_norm = 0.5 24 | 25 | # Parameters for PolicyAndValue: 26 | # ============================================================================== 27 | PolicyAndValue.body = @trax.models.MLP 28 | 29 | # Parameters for MLP: 30 | # ============================================================================== 31 | MLP.flatten = False 32 | MLP.layer_widths = (64,64) 33 | MLP.out_activation = True 34 | 35 | # Parameters for multifactor: 36 | # ============================================================================== 37 | multifactor.constant = 0.0025 38 | multifactor.factors = 'constant * linear_warmup' 39 | multifactor.warmup_steps = 100 40 | 41 | # Parameters for RLTask: 42 | # ============================================================================== 43 | RLTask.env = "CartPole-v0" 44 | RLTask.initial_trajectories = 0 45 | RLTask.gamma = 0.99 46 | RLTask.max_steps = 200 47 | 48 | # Parameters for PPOJoint: 49 | # ============================================================================== 50 | PPOJoint.joint_model = @trax.models.PolicyAndValue 51 | PPOJoint.optimizer = @trax.optimizers.Adam 52 | PPOJoint.batch_size = 256 53 | PPOJoint.train_steps_per_epoch = 40 54 | PPOJoint.lr_schedule = @multifactor 55 | PPOJoint.n_trajectories_per_epoch = 20 56 | PPOJoint.epsilon = 0.1 57 | PPOJoint.value_loss_coeff = 0 58 | PPOJoint.entropy_coeff = 0 59 | 60 | 61 | # Parameters for train_rl: 62 | # ============================================================================== 63 | train_rl.light_rl = True 64 | train_rl.light_rl_trainer = @trax.rl.PPOJoint 65 | train_rl.n_epochs = 500 66 | -------------------------------------------------------------------------------- /trax/rl/configs/light_ppo_joint_atari.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
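PPOJoint.epsilon in the configs above is the clipping width of the standard PPO surrogate objective: the policy probability ratio is clipped to [1 - epsilon, 1 + epsilon] before being multiplied by the advantage. A generic sketch of that loss term (not the Trax implementation; prob_ratio is new-policy over old-policy action probability):

    import numpy as np

    def ppo_clipped_surrogate(prob_ratio, advantage, epsilon=0.1):
        # min(r * A, clip(r, 1 - eps, 1 + eps) * A), averaged over the batch.
        unclipped = prob_ratio * advantage
        clipped = np.clip(prob_ratio, 1.0 - epsilon, 1.0 + epsilon) * advantage
        return np.mean(np.minimum(unclipped, clipped))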
14 | 15 | import trax.supervised.lr_schedules 16 | import trax.models 17 | import trax.optimizers 18 | import trax.rl 19 | import trax.rl_trainer 20 | 21 | # Parameters for Adam: 22 | # ============================================================================== 23 | Adam.clip_grad_norm = 0.5 24 | 25 | # Parameters for PolicyAndValue: 26 | # ============================================================================== 27 | PolicyAndValue.body = @trax.models.AtariCnnBody 28 | 29 | # Parameters for the AtariCnnBody: 30 | # ============================================================================== 31 | AtariCnnBody.kernel_initializer = @trax.layers.AtariConvInit 32 | AtariCnnBody.n_frames = 1 33 | AtariCnnBody.padding = 'VALID' 34 | 35 | # Parameters for multifactor: 36 | # ============================================================================== 37 | multifactor.constant = 0.01 38 | multifactor.factors = 'constant' 39 | 40 | # Parameters for RLTask: 41 | # ============================================================================== 42 | RLTask.env = "boxing" 43 | RLTask.initial_trajectories = 0 44 | RLTask.gamma = 0.99 45 | RLTask.max_steps = 2000 46 | RLTask.dm_suite = True 47 | RLTask.num_stacked_frames = 4 48 | 49 | # Parameters for PPOJoint: 50 | # ============================================================================== 51 | PPOJoint.joint_model = @trax.models.PolicyAndValue 52 | PPOJoint.optimizer = @trax.optimizers.Adam 53 | PPOJoint.batch_size = 16 54 | PPOJoint.train_steps_per_epoch = 20 55 | PPOJoint.lr_schedule = @multifactor 56 | PPOJoint.n_trajectories_per_epoch = 25 57 | PPOJoint.epsilon = 0.1 58 | PPOJoint.value_loss_coeff = 0.1 59 | PPOJoint.entropy_coeff = 0.01 60 | 61 | # Parameters for train_rl: 62 | # ============================================================================== 63 | train_rl.light_rl = True 64 | train_rl.light_rl_trainer = @trax.rl.PPOJoint 65 | train_rl.n_epochs = 10000 66 | -------------------------------------------------------------------------------- /trax/rl/configs/light_ppo_joint_atari_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | PPOJoint.train_steps_per_epoch: [10, 100] 16 | multifactor.constant: [0.0001, 0.000025] 17 | PPOJoint.value_loss_coeff: [0.01, 0.1, 1] 18 | PPOJoint.batch_size: [8, 16, 32, 256] 19 | PPOJoint.epsilon: [0.1, 0.2, 0.3] 20 | RLTask.max_steps: [200, 2000] 21 | RLTask.env: ["pong", "boxing", "breakout", "freeway"] 22 | -------------------------------------------------------------------------------- /trax/rl/configs/light_ppo_lunar_lander_regression_test.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import trax.supervised.lr_schedules 16 | import trax.models 17 | import trax.optimizers 18 | import trax.rl 19 | import trax.rl_trainer 20 | 21 | # Parameters for Adam: 22 | # ============================================================================== 23 | Adam.clip_grad_norm = 0.5 24 | 25 | # Parameters for PolicyAndValue: 26 | # ============================================================================== 27 | PolicyAndValue.body = @trax.models.MLP 28 | 29 | # Parameters for MLP: 30 | # ============================================================================== 31 | MLP.flatten = False 32 | MLP.layer_widths = (64,64) 33 | MLP.out_activation = True 34 | 35 | # Parameters for multifactor: 36 | # ============================================================================== 37 | multifactor.constant = 0.0025 38 | multifactor.factors = 'constant * linear_warmup' 39 | multifactor.warmup_steps = 100 40 | 41 | # Parameters for RLTask: 42 | # ============================================================================== 43 | RLTask.env = "LunarLander-v2" 44 | RLTask.initial_trajectories = 0 45 | RLTask.gamma = 0.99 46 | RLTask.max_steps = 1000 47 | 48 | # Parameters for PPOJoint: 49 | # ============================================================================== 50 | PPOJoint.joint_model = @trax.models.PolicyAndValue 51 | PPOJoint.optimizer = @trax.optimizers.Adam 52 | PPOJoint.batch_size = 256 53 | PPOJoint.train_steps_per_epoch = 200 54 | PPOJoint.lr_schedule = @multifactor 55 | PPOJoint.n_trajectories_per_epoch = 20 56 | PPOJoint.epsilon = 0.1 57 | PPOJoint.value_loss_coeff = 0 58 | PPOJoint.entropy_coeff = 0.01 59 | 60 | 61 | # Parameters for train_rl: 62 | # ============================================================================== 63 | train_rl.light_rl = True 64 | train_rl.light_rl_trainer = @trax.rl.PPOJoint 65 | train_rl.n_epochs = 130 66 | -------------------------------------------------------------------------------- /trax/rl/configs/light_ppo_pong_regression_test.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import trax.supervised.lr_schedules 16 | import trax.models 17 | import trax.optimizers 18 | import trax.rl 19 | import trax.rl_trainer 20 | 21 | # Parameters for Adam: 22 | # ============================================================================== 23 | Adam.clip_grad_norm = 0.5 24 | 25 | # Parameters for PolicyAndValue: 26 | # ============================================================================== 27 | PolicyAndValue.body = @trax.models.AtariCnnBody 28 | 29 | # Parameters for the AtariCnnBody: 30 | # ============================================================================== 31 | AtariCnnBody.kernel_initializer = @trax.layers.AtariConvInit 32 | AtariCnnBody.n_frames = 1 33 | AtariCnnBody.padding = 'VALID' 34 | 35 | # Parameters for multifactor: 36 | # ============================================================================== 37 | multifactor.constant = 0.0001 38 | multifactor.factors = 'constant * linear_warmup' 39 | multifactor.warmup_steps = 100 40 | 41 | # Parameters for RLTask: 42 | # ============================================================================== 43 | RLTask.env = "pong" 44 | RLTask.initial_trajectories = 0 45 | RLTask.gamma = 0.99 46 | RLTask.max_steps = 2000 47 | RLTask.dm_suite = True 48 | RLTask.num_stacked_frames = 4 49 | 50 | # Parameters for PPOJoint: 51 | # ============================================================================== 52 | PPOJoint.joint_model = @trax.models.PolicyAndValue 53 | PPOJoint.optimizer = @trax.optimizers.Adam 54 | PPOJoint.batch_size = 256 55 | PPOJoint.train_steps_per_epoch = 80 56 | PPOJoint.lr_schedule = @multifactor 57 | PPOJoint.n_trajectories_per_epoch = 5 58 | PPOJoint.epsilon = 0.1 59 | PPOJoint.value_loss_coeff = 0.25 60 | PPOJoint.entropy_coeff = 0.1 61 | 62 | # Parameters for train_rl: 63 | # ============================================================================== 64 | train_rl.light_rl = True 65 | train_rl.light_rl_trainer = @trax.rl.PPOJoint 66 | train_rl.n_epochs = 900 67 | -------------------------------------------------------------------------------- /trax/rl/configs/ppo_atari_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | RLTask.gamma: [0.99] 16 | RLTask.env: ["boxing"] 17 | value__.multifactor.constant: [0.0001, 0.001] 18 | value__.multifactor.factors: ['constant'] 19 | policy__.multifactor.constant: [0.0001, 0.00025] 20 | policy__.multifactor.factors: ['constant'] 21 | -------------------------------------------------------------------------------- /trax/rl/configs/ppo_cartpole_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | RLTask.gamma: [0.99, 0.95] 16 | -------------------------------------------------------------------------------- /trax/rl/configs/transformer_srl_sine.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Not really an RL config, but it's here because it depends on serialization 16 | # logic which is currently in trax.rl. 17 | 18 | import trax.data.tf_inputs 19 | import trax.models 20 | import trax.optimizers 21 | import trax.supervised.trainer_lib 22 | 23 | # Module trax.data.inputs: 24 | # ============================================================================== 25 | sine_inputs.batch_size = 64 26 | sine_inputs.length = 100 27 | sine_inputs.min_period = 0.1 28 | sine_inputs.max_period = 10 29 | 30 | # Module trax.models.transformer: 31 | # ============================================================================== 32 | TransformerLM.d_model = 128 33 | TransformerLM.d_ff = 256 34 | TransformerLM.dropout = 0.1 35 | TransformerLM.max_len = 2048 36 | TransformerLM.mode = 'train' 37 | TransformerLM.n_heads = 2 38 | TransformerLM.n_layers = 2 39 | 40 | # Module trax.rl.serialization_utils: 41 | # ============================================================================== 42 | TimeSeriesModel.seq_model = @trax.models.TransformerLM 43 | TimeSeriesModel.low = -1.0 44 | TimeSeriesModel.high = 1.0 45 | TimeSeriesModel.precision = 2 46 | TimeSeriesModel.vocab_size = 16 47 | TimeSeriesModel.significance_decay = 0.7 48 | 49 | # Module trax.supervised.lr_schedules: 50 | # ============================================================================== 51 | multifactor.constant = 0.1 52 | multifactor.factors = 'constant * linear_warmup * rsqrt_decay' 53 | multifactor.warmup_steps = 1000 54 | 55 | # Module trax.supervised.trainer_lib: 56 | # ============================================================================== 57 | train.inputs = @trax.data.sine_inputs 58 | train.eval_frequency = 1000 59 | train.eval_steps = 10 60 | train.model = @trax.rl.TimeSeriesModel 61 | train.optimizer = @trax.optimizers.Adam 62 | train.steps = 10000 63 | train.callbacks = ( 64 | @trax.supervised.callbacks.SerializedModelEvaluation, 65 | ) 66 | 67 | # Module trax.supervised.callbacks: 68 | # ============================================================================== 69 | SerializedModelEvaluation.eval_at = 500 70 | SerializedModelEvaluation.context_lengths = (0, 1, 10) 71 | 
SerializedModelEvaluation.horizon_lengths = (1, 3, 10) 72 | -------------------------------------------------------------------------------- /trax/rl/envs/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Trax RL environments library.""" 17 | 18 | import gin 19 | from trax.rl.envs import data_envs 20 | 21 | 22 | def configure_rl_env(*args, **kwargs): 23 | kwargs['module'] = 'trax.rl.envs' 24 | return gin.external_configurable(*args, **kwargs) 25 | 26 | 27 | copy_stream = configure_rl_env(data_envs.copy_stream) 28 | SequenceDataEnv = configure_rl_env(data_envs.SequenceDataEnv) 29 | -------------------------------------------------------------------------------- /trax/rl/normalization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
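The tests that follow exercise the streaming statistics in trax.rl.normalization. For reference, the running mean and variance they check against can be maintained in one pass with Welford's online update; a standalone sketch, independent of the Trax API:

    import numpy as np

    def welford_update(count, mean, m2, x):
        # One online step; after n samples, mean is exact and var = m2 / n.
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)
        return count, mean, m2

    count, mean, m2 = 0, 0.0, 0.0
    for x in np.random.uniform(size=10):
        count, mean, m2 = welford_update(count, mean, m2, x)
    variance = m2 / count  # matches np.var of the samples seen so far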
15 | 16 | """Tests for trax.rl.normalization.""" 17 | 18 | from absl.testing import absltest 19 | import numpy as np 20 | 21 | from trax import shapes 22 | from trax.rl import normalization 23 | 24 | 25 | class NormalizationTest(absltest.TestCase): 26 | 27 | def test_running_mean(self): 28 | x = np.random.uniform(size=10) 29 | state = normalization.running_mean_init(shape=()) 30 | for i in range(len(x)): 31 | state = normalization.running_mean_update(x[i], state) 32 | np.testing.assert_almost_equal( 33 | normalization.running_mean_get_mean(state), np.mean(x[:i + 1]) 34 | ) 35 | 36 | def test_running_variance(self): 37 | x = np.random.uniform(size=10) 38 | state = normalization.running_mean_and_variance_init(shape=()) 39 | for i in range(len(x)): 40 | state = normalization.running_mean_and_variance_update(x[i], state) 41 | np.testing.assert_almost_equal( 42 | normalization.running_mean_and_variance_get_variance(state), 43 | np.var(x[:i + 1]), 44 | ) 45 | 46 | def test_normalize_collect(self): 47 | x = np.random.uniform(size=(2, 3, 4, 5)) 48 | normalize = normalization.Normalize(mode='collect') 49 | normalize.init(shapes.signature(x)) 50 | old_state = normalize.state 51 | y = normalize(x) 52 | with self.assertRaises(AssertionError): 53 | np.testing.assert_equal(normalize.state, old_state) 54 | with self.assertRaises(AssertionError): 55 | np.testing.assert_almost_equal(x, y) 56 | 57 | def test_normalize_train(self): 58 | x = np.random.uniform(size=(2, 3, 4, 5)) 59 | normalize = normalization.Normalize(mode='train', epsilon=0.0) 60 | normalize.init(shapes.signature(x)) 61 | old_state = normalize.state 62 | y = normalize(x) 63 | np.testing.assert_equal(normalize.state, old_state) 64 | np.testing.assert_almost_equal(x, y) 65 | 66 | 67 | if __name__ == '__main__': 68 | absltest.main() 69 | -------------------------------------------------------------------------------- /trax/supervised/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Supervised learning imports in Trax.""" 17 | 18 | from trax.supervised import callbacks 19 | from trax.supervised import decoding 20 | from trax.supervised import lr_schedules 21 | from trax.supervised import trainer_lib 22 | from trax.supervised import training 23 | from trax.supervised.trainer_lib import train 24 | from trax.supervised.trainer_lib import Trainer 25 | from trax.supervised.training import EvalTask 26 | from trax.supervised.training import TrainTask 27 | -------------------------------------------------------------------------------- /trax/supervised/configs/bert.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import t5.data.preprocessors 16 | import trax.data 17 | import trax.layers 18 | import trax.models 19 | import trax.optimizers 20 | import trax.supervised.lr_schedules 21 | import trax.supervised.trainer_lib 22 | import trax.layers.metrics 23 | 24 | # Parameters for TFDS data pipeline: 25 | # ============================================================================== 26 | data.Tokenize.vocab_file = 'bert_uncased_vocab.txt' 27 | data.Tokenize.vocab_type = 'bert-lowercase' 28 | # If during the execution time of the binary the directory trax/data/testdata 29 | # containing the vocab file is not accessible, then copy the file to a drive 30 | # and change the path accordingly. 31 | data.Tokenize.vocab_dir = 'trax/data/testdata/' 32 | data.Tokenize.keys = [0, 1] 33 | data.PadToLength.len_map = {0: 512, 1: 512, 2: 512} 34 | data.PadToLength.pad_value = {0: 0, 1: 0, 2:0} 35 | data.TruncateToLength.len_map = {0: (512,), 1: (512,), 2: (512,)} 36 | data.Batch.batch_size = 16 37 | 38 | # Parameters for train: 39 | # ============================================================================== 40 | train.optimizer = @trax.optimizers.Adam 41 | train.eval_frequency = 20 42 | train.eval_steps = 10 43 | train.inputs = @data.make_inputs 44 | train.model = @trax.models.BERT 45 | train.steps = 200000 46 | train.checkpoint_highest = 'accuracy' 47 | 48 | # Parameters for BERT: 49 | # ============================================================================== 50 | BERT.init_checkpoint = 'bert-base-uncased' 51 | 52 | # Parameters for multifactor: 53 | # ============================================================================== 54 | multifactor.constant = 3e-5 55 | multifactor.factors = 'constant * linear_warmup' 56 | multifactor.warmup_steps = 1000 57 | -------------------------------------------------------------------------------- /trax/supervised/configs/bert_glue_classification.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import trax.data 16 | import trax.models.research.bert 17 | 18 | include 'bert.gin' 19 | 20 | # Parameters for training and eval data: 21 | # ============================================================================== 22 | # 23 | # See https://www.tensorflow.org/datasets/catalog/glue. Valid benchmark names: 24 | # 25 | # cola, sst2, mrpc, qqp, stsb, mnli, qnli, rte, wnli. 
26 | # 27 | # Training on WNLI with this setup is not recommended and will likely result in 28 | # lower than baseline accuracy. 29 | make_inputs.train_stream = @data.BertGlueTrainStream 30 | make_inputs.eval_stream = @data.BertGlueEvalStream 31 | BertGlueTrainStream.benchmark = 'mnli' 32 | BertGlueEvalStream.benchmark = 'mnli' 33 | 34 | # Parameters for BERT: 35 | # ============================================================================== 36 | BERT.head = @bert.BERTClassifierHead 37 | bert.BERTClassifierHead.n_classes = 3 38 | -------------------------------------------------------------------------------- /trax/supervised/configs/bert_glue_regression.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import trax.data 16 | import trax.models.research.bert 17 | 18 | include 'bert.gin' 19 | 20 | # Parameters for training and eval data: 21 | # ============================================================================== 22 | # 23 | # See https://www.tensorflow.org/datasets/catalog/glue. Valid benchmark names: 24 | # 25 | # cola, sst2, mrpc, qqp, stsb, mnli, qnli, rte, wnli. 26 | # 27 | # Training on WNLI with this setup is not recommended and will likely result in 28 | # lower than baseline accuracy. 29 | make_inputs.train_stream = @data.BertGlueTrainStream 30 | make_inputs.eval_stream = @data.BertGlueEvalStream 31 | BertGlueTrainStream.benchmark = 'stsb' 32 | BertGlueEvalStream.benchmark = 'stsb' 33 | 34 | # Parameters for train: 35 | # ============================================================================== 36 | train.loss_fn = @trax.layers.L2Loss() 37 | train.metrics = {'loss': @trax.layers.L2Loss()} 38 | 39 | # Parameters for BERT: 40 | # ============================================================================== 41 | BERT.head = @bert.BERTRegressionHead 42 | -------------------------------------------------------------------------------- /trax/supervised/configs/bert_glue_sweep_regression_task.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
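Both GLUE configs above begin with include 'bert.gin' and then rebind only what differs (the classifier or regression head, and for STS-B the loss and metrics). The layering works because gin applies bindings in parse order, so later bindings override earlier ones; a self-contained toy demonstration of that rule:

    import gin

    @gin.configurable
    def head(n_classes=2):
        return n_classes

    gin.parse_config('head.n_classes = 2')  # base config, as in bert.gin
    gin.parse_config('head.n_classes = 3')  # the including file's binding wins
    assert head() == 3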
14 | 15 | dataset_name: [ 16 | 'glue/stsb', 17 | ] 18 | multifactor.constant: [2.0e-5, 3.0e-5, 5.0e-5] 19 | -------------------------------------------------------------------------------- /trax/supervised/configs/bert_glue_sweep_single_sentence.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | dataset_name: [ 16 | 'glue/cola', 17 | 'glue/sst2', 18 | ] 19 | multifactor.constant: [2.0e-5, 3.0e-5, 5.0e-5] 20 | CreateBertInputs.double_sentence: [False] 21 | data.Tokenize.keys: [[0,]] 22 | -------------------------------------------------------------------------------- /trax/supervised/configs/bert_glue_sweep_two_sentences.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | dataset_name: [ 16 | 'glue/mrpc', 17 | 'glue/qqp', 18 | 'glue/mnli', 19 | 'glue/qnli', 20 | 'glue/rte', 21 | 'glue/wnli', 22 | ] 23 | multifactor.constant: [2.0e-5, 3.0e-5, 5.0e-5] 24 | -------------------------------------------------------------------------------- /trax/supervised/configs/bert_pretraining.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import trax.data 16 | import trax.models.research.bert 17 | 18 | include 'bert.gin' 19 | 20 | dataset_name = 'wiki40b' 21 | 22 | # Parameters for TFDS data pipeline: 23 | # ============================================================================== 24 | make_inputs.train_stream = [ 25 | @train/data.BertNextSentencePredictionInputs(), 26 | @data.Tokenize(), 27 | @data.CreateBertInputs(), 28 | @data.Shuffle(), 29 | @data.PadToLength(), 30 | @data.TruncateToLength(), 31 | @data.mask_random_tokens, 32 | @data.Batch() 33 | ] 34 | make_inputs.eval_stream = [ 35 | @eval/data.BertNextSentencePredictionInputs(), 36 | @data.Tokenize(), 37 | @data.CreateBertInputs(), 38 | @data.Shuffle(), 39 | @data.PadToLength(), 40 | @data.TruncateToLength(), 41 | @data.mask_random_tokens, 42 | @data.Batch() 43 | ] 44 | 45 | train/data.BertNextSentencePredictionInputs.dataset_name = %dataset_name 46 | train/data.TFDS.train = True 47 | eval/data.BertNextSentencePredictionInputs.dataset_name = %dataset_name 48 | eval/data.TFDS.train = False 49 | 50 | # Parameters for train: 51 | # ============================================================================== 52 | train.loss_fn = @bert.BERTPretrainingLoss() 53 | train.metrics = {'loss': @bert.BERTPretrainingLoss()} 54 | 55 | # Parameters for BERT: 56 | # ============================================================================== 57 | BERT.head = @bert.BERTPretrainingHead 58 | bert.BERTPretrainingHead.n_classes = 2 59 | -------------------------------------------------------------------------------- /trax/supervised/configs/bert_pretraining_onlymlm.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
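make_inputs.train_stream in bert_pretraining.gin above is gin shorthand for composing trax.data transforms into a single generator pipeline. A rough Python equivalent of the same composition (the TFDS source step is omitted, keyword values are taken from bert.gin, and this is a sketch rather than the exact pipeline Trax builds):

    from trax import data

    train_stream = data.Serial(
        data.Tokenize(vocab_file='bert_uncased_vocab.txt',
                      vocab_type='bert-lowercase', keys=[0, 1]),
        data.CreateBertInputs(),
        data.Shuffle(),
        data.PadToLength(len_map={0: 512, 1: 512, 2: 512},
                         pad_value={0: 0, 1: 0, 2: 0}),
        data.TruncateToLength(len_map={0: (512,), 1: (512,), 2: (512,)}),
        data.mask_random_tokens,
        data.Batch(batch_size=16),
    )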
14 | 15 | import trax.data 16 | import trax.models.research.bert 17 | 18 | include 'bert.gin' 19 | 20 | dataset_name = 'wiki40b' 21 | 22 | # Parameters for TFDS data pipeline: 23 | # ============================================================================== 24 | make_inputs.train_stream = [ 25 | @train/data.CorpusToRandomChunks(), 26 | @data.Tokenize(), 27 | @data.CreateBertInputs(), 28 | @data.Shuffle(), 29 | @data.PadToLength(), 30 | @data.TruncateToLength(), 31 | @data.mask_random_tokens, 32 | @data.Batch() 33 | ] 34 | make_inputs.eval_stream = [ 35 | @eval/data.CorpusToRandomChunks(), 36 | @data.Tokenize(), 37 | @data.CreateBertInputs(), 38 | @data.Shuffle(), 39 | @data.PadToLength(), 40 | @data.TruncateToLength(), 41 | @data.mask_random_tokens, 42 | @data.Batch() 43 | ] 44 | 45 | train/data.CorpusToRandomChunks.dataset_name = %dataset_name 46 | train/data.TFDS.train = True 47 | eval/data.CorpusToRandomChunks.dataset_name = %dataset_name 48 | eval/data.TFDS.train = False 49 | 50 | data.CreateBertInputs.labeled = False 51 | data.CreateBertInputs.double_sentence = False 52 | 53 | # Parameters for BERT: 54 | # ============================================================================== 55 | BERT.head = @bert.BERTMLMHead 56 | 57 | -------------------------------------------------------------------------------- /trax/supervised/configs/bert_pretraining_onlynsp.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import trax.data 16 | import trax.models.research.bert 17 | 18 | include 'bert.gin' 19 | 20 | dataset_name = 'wiki40b' 21 | 22 | # Parameters for TFDS data pipeline: 23 | # ============================================================================== 24 | make_inputs.train_stream = [ 25 | @train/data.BertNextSentencePredictionInputs(), 26 | @data.Tokenize(), 27 | @data.CreateBertInputs(), 28 | @data.Shuffle(), 29 | @data.PadToLength(), 30 | @data.TruncateToLength(), 31 | @data.Batch() 32 | ] 33 | make_inputs.eval_stream = [ 34 | @eval/data.BertNextSentencePredictionInputs(), 35 | @data.Tokenize(), 36 | @data.CreateBertInputs(), 37 | @data.Shuffle(), 38 | @data.PadToLength(), 39 | @data.TruncateToLength(), 40 | @data.Batch() 41 | ] 42 | 43 | train/data.BertNextSentencePredictionInputs.dataset_name = %dataset_name 44 | train/data.TFDS.train = True 45 | eval/data.BertNextSentencePredictionInputs.dataset_name = %dataset_name 46 | eval/data.TFDS.train = False 47 | 48 | # Parameters for BERT: 49 | # ============================================================================== 50 | BERT.head = @bert.BERTClassifierHead 51 | bert.BERTClassifierHead.n_classes = 2 52 | 53 | -------------------------------------------------------------------------------- /trax/supervised/configs/cond_skipping_transformer_lm1b.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import trax.data 16 | import trax.models 17 | import trax.optimizers 18 | import trax.supervised.trainer_lib 19 | 20 | # Parameters for batcher: 21 | # ============================================================================== 22 | batcher.data_streams = @data.data_streams 23 | batcher.batch_size_per_device = 96 24 | batcher.eval_batch_size = 128 25 | batcher.max_eval_length = 2048 26 | batcher.id_to_mask = 0 27 | 28 | # Parameters for data_streams: 29 | # ============================================================================== 30 | data_streams.data_dir = None 31 | data_streams.dataset_name = 't2t_languagemodel_lm1b32k' 32 | data_streams.input_name = 'targets' 33 | data_streams.preprocess_fn = @data.lm1b_preprocess 34 | 35 | # Parameters for multifactor: 36 | # ============================================================================== 37 | multifactor.constant = 0.3 38 | multifactor.factors = 'constant * linear_warmup' 39 | multifactor.warmup_steps = 8000 40 | 41 | # Parameters for lm1b_preprocess: 42 | # ============================================================================== 43 | lm1b_preprocess.max_target_length = 512 44 | lm1b_preprocess.max_eval_target_length = 2048 45 | 46 | # Parameters for train: 47 | # ============================================================================== 48 | train.eval_frequency = 1000 49 | train.eval_steps = 10 50 | train.model = @trax.models.SkippingTransformerLM 51 | train.optimizer = @trax.optimizers.SM3 52 | train.steps = 500000 53 | 54 | # Parameters for DotProductCausalAttention: 55 | # ============================================================================== 56 | DotProductCausalAttention.dropout = 0.1 57 | 58 | # Parameters for SkippingTransformerLM: 59 | # ============================================================================== 60 | SkippingTransformerLM.d_model = 1024 61 | SkippingTransformerLM.d_ff = 4096 62 | SkippingTransformerLM.dropout = 0.2 63 | SkippingTransformerLM.max_len = 2048 64 | SkippingTransformerLM.mode = 'train' 65 | SkippingTransformerLM.n_heads = 8 66 | SkippingTransformerLM.n_layers = 8 67 | SkippingTransformerLM.vocab_size = 32000 68 | -------------------------------------------------------------------------------- /trax/supervised/configs/gru_copy.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
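cond_skipping_transformer_lm1b.gin above is consumed like the other supervised configs: parse the file, then call the gin-configured train entry point in trax.supervised.trainer_lib. Normally the trax trainer binary does this via flags; the direct calls below are an assumed minimal equivalent for illustration:

    import gin
    import trax  # registers models, optimizers and data functions with gin
    from trax.supervised import trainer_lib

    gin.parse_config_file(
        'trax/supervised/configs/cond_skipping_transformer_lm1b.gin')
    # Model, optimizer, inputs, steps, etc. all come from the gin bindings.
    trainer_lib.train(output_dir='/tmp/skipping_lm1b')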
14 | 15 | import trax.data 16 | import trax.models 17 | import trax.optimizers 18 | import trax.supervised.trainer_lib 19 | 20 | n_symbols = 32 21 | length = 16 22 | batch = 512 23 | 24 | # Parameters for data_streams: 25 | # ============================================================================== 26 | data_streams.data_dir = None 27 | data_streams.dataset_name = 'gru_copy' 28 | 29 | # Parameters for sequence_copy_inputs: 30 | # ============================================================================== 31 | sequence_copy_inputs.vocab_size = %n_symbols 32 | sequence_copy_inputs.batch_size = %batch 33 | sequence_copy_inputs.train_length = %length 34 | sequence_copy_inputs.eval_min_length = 2 35 | sequence_copy_inputs.eval_max_length = %length 36 | sequence_copy_inputs.reverse = False 37 | 38 | # Parameters for multifactor: 39 | # ============================================================================== 40 | multifactor.constant = 0.001 41 | multifactor.factors = 'constant * linear_warmup' 42 | multifactor.warmup_steps = 8000 43 | 44 | # Parameters for train: 45 | # ============================================================================== 46 | train.eval_frequency = 1000 47 | train.eval_steps = 10 48 | train.inputs = @trax.data.sequence_copy_inputs 49 | train.model = @trax.models.RNNLM 50 | train.optimizer = @trax.optimizers.Adam 51 | train.steps = 500000 52 | 53 | # Parameters for RNNLM: 54 | # ============================================================================== 55 | RNNLM.rnn_cell = @trax.layers.GRUCell 56 | RNNLM.rnn_cell_d_state_multiplier = 1 57 | RNNLM.d_model = 128 58 | RNNLM.dropout = 0.1 59 | RNNLM.n_layers = 2 60 | RNNLM.vocab_size = %n_symbols 61 | -------------------------------------------------------------------------------- /trax/supervised/configs/hourglass_cifar10.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import trax.data 16 | import trax.layers 17 | import trax.models 18 | import trax.optimizers 19 | import trax.supervised.trainer_lib 20 | 21 | train_steps = 100000 22 | 23 | # Parameters for batcher: 24 | # ============================================================================== 25 | batcher.data_streams = @data.data_streams 26 | batcher.batch_size_per_device = 1 27 | batcher.eval_batch_size = 8 28 | batcher.max_eval_length = 3072 # 32 * 32 * 3 29 | batcher.variable_shapes = False 30 | 31 | # Parameters for data_streams: 32 | # ============================================================================== 33 | data_streams.data_dir = None 34 | data_streams.dataset_name = 'cifar10' 35 | data_streams.input_name = 'image' 36 | data_streams.target_name = 'image' 37 | data_streams.bare_preprocess_fn = \ 38 | @data.downsampled_imagenet_flatten_bare_preprocess 39 | 40 | # Parameters for multifactor: # ================================================ 41 | multifactor.constant = 1e-3 42 | multifactor.factors = 'constant * linear_warmup * cosine_decay' 43 | multifactor.warmup_steps = 5000 44 | multifactor.steps_per_cycle = %train_steps 45 | 46 | # Parameters for Adam: 47 | # ============================================================================== 48 | Adam.weight_decay_rate=0.0 49 | Adam.b1 = 0.9 50 | Adam.b2 = 0.98 51 | Adam.eps = 1e-9 52 | 53 | # Parameters for train: 54 | # ============================================================================== 55 | train.eval_frequency = 2000 56 | train.eval_steps = 625 57 | train.checkpoints_at = [100000] 58 | train.model = @trax.models.HourglassLM 59 | train.optimizer = @trax.optimizers.Adam 60 | train.steps = %train_steps 61 | 62 | 63 | # Parameters for HourglassLM: 64 | # ============================================================================== 65 | HourglassLM.d_model = 512 66 | HourglassLM.d_ff = 2048 67 | HourglassLM.vanilla_layers = (1, 1) 68 | HourglassLM.hierarchy = '8@3' 69 | HourglassLM.dropout = 0.0 70 | HourglassLM.mode = 'train' 71 | HourglassLM.n_heads = 8 72 | HourglassLM.vocab_size = 256 73 | HourglassLM.attention_downsampling_fn = @LinearPooling 74 | HourglassLM.attention_upsampling_fn = @LinearUpsampling 75 | -------------------------------------------------------------------------------- /trax/supervised/configs/hourglass_enwik8.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
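HourglassLM.hierarchy above ('8@3' for CIFAR-10, '24@3' for enwik8 below) compactly describes the shortened middle stack; as far as these configs show, 'n@k' reads as n layers run at shorten factor k, with space-separated groups for deeper hierarchies. A tiny parser for that notation (parse_hierarchy is a hypothetical helper, and the multi-group reading is an assumption, not taken from the Trax source):

    def parse_hierarchy(spec):
        """'8@3' -> [(8, 3)]; '2@3 2@9 2@3' -> [(2, 3), (2, 9), (2, 3)]."""
        levels = []
        for group in spec.split():
            n_layers, shorten_factor = group.split('@')
            levels.append((int(n_layers), int(shorten_factor)))
        return levels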
14 | 15 | import trax.data 16 | import trax.layers 17 | import trax.models 18 | import trax.optimizers 19 | import trax.supervised.trainer_lib 20 | 21 | 22 | # Parameters for batcher: 23 | # ============================================================================== 24 | batcher.data_streams = @data.data_streams 25 | batcher.max_eval_length = 2049 26 | batcher.buckets = ([2049], [8]) 27 | batcher.id_to_mask = 0 28 | 29 | # Parameters for data_streams: 30 | # ============================================================================== 31 | data_streams.data_dir = None 32 | data_streams.dataset_name = 't2t_enwik8_l2k' 33 | data_streams.input_name = 'targets' 34 | 35 | # Parameters for multifactor: 36 | # ============================================================================== 37 | # The constant below is tuned directly; it is not derived from d_model^-0.5. 38 | multifactor.constant = 4.1e-4 39 | multifactor.factors = 'constant * linear_warmup * cosine_decay' 40 | multifactor.warmup_steps = 4000 41 | multifactor.steps_per_cycle = 350000 42 | 43 | # Parameters for Adam: 44 | # ============================================================================== 45 | Adam.weight_decay_rate = 0.0 46 | Adam.b1 = 0.9 47 | Adam.b2 = 0.98 48 | Adam.eps = 1e-9 49 | 50 | # Parameters for train: 51 | # ============================================================================== 52 | train.eval_frequency = 2000 53 | train.eval_steps = 305 54 | train.model = @trax.models.HourglassLM 55 | train.optimizer = @trax.optimizers.Adam 56 | train.steps = 263000 57 | train.save_graphs = False 58 | train.checkpoints_at = [150000, 175000, 263000] 59 | 60 | 61 | # Parameters for HourglassLM: 62 | # ============================================================================== 63 | HourglassLM.d_ff = 2048 64 | HourglassLM.d_model = 512 65 | HourglassLM.dropout = 0.2 66 | HourglassLM.vanilla_layers = (5,5) 67 | HourglassLM.hierarchy = '24@3' 68 | HourglassLM.n_heads = 8 69 | HourglassLM.vocab_size = 256 70 | HourglassLM.attention_upsampling_fn = @NaiveUpsampling 71 | HourglassLM.ff_activation = @trax.layers.FastGelu 72 | -------------------------------------------------------------------------------- /trax/supervised/configs/layerdrop_every_transformer_lm1b.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import trax.data 16 | import trax.models 17 | import trax.optimizers 18 | import trax.supervised.trainer_lib 19 | 20 | # Parameters for batcher: 21 | # ============================================================================== 22 | batcher.data_streams = @data.data_streams 23 | batcher.batch_size_per_device = 96 24 | batcher.eval_batch_size = 128 25 | batcher.max_eval_length = 2048 26 | batcher.id_to_mask = 0 27 | 28 | # Parameters for data_streams: 29 | # ============================================================================== 30 | data_streams.data_dir = None 31 | data_streams.dataset_name = 't2t_languagemodel_lm1b32k' 32 | data_streams.input_name = 'targets' 33 | data_streams.preprocess_fn = @data.lm1b_preprocess 34 | 35 | # Parameters for multifactor: 36 | # ============================================================================== 37 | multifactor.constant = 0.3 38 | multifactor.factors = 'constant * linear_warmup' 39 | multifactor.warmup_steps = 8000 40 | 41 | # Parameters for lm1b_preprocess: 42 | # ============================================================================== 43 | lm1b_preprocess.max_target_length = 512 44 | lm1b_preprocess.max_eval_target_length = 2048 45 | 46 | # Parameters for train: 47 | # ============================================================================== 48 | train.eval_frequency = 1000 49 | train.eval_steps = 10 50 | train.model = @trax.models.EveryOtherLayerDropTransformerLM 51 | train.optimizer = @trax.optimizers.SM3 52 | train.steps = 500000 53 | 54 | # Parameters for DotProductCausalAttention: 55 | # ============================================================================== 56 | DotProductCausalAttention.dropout = 0.1 57 | 58 | # Parameters for EveryOtherLayerDropTransformerLM: 59 | # ============================================================================== 60 | EveryOtherLayerDropTransformerLM.d_model = 1024 61 | EveryOtherLayerDropTransformerLM.d_ff = 4096 62 | EveryOtherLayerDropTransformerLM.dropout = 0.2 63 | EveryOtherLayerDropTransformerLM.max_len = 2048 64 | EveryOtherLayerDropTransformerLM.mode = 'train' 65 | EveryOtherLayerDropTransformerLM.n_heads = 8 66 | EveryOtherLayerDropTransformerLM.n_layers = 8 67 | EveryOtherLayerDropTransformerLM.vocab_size = 32000 68 | EveryOtherLayerDropTransformerLM.skip_mode = 'even' 69 | -------------------------------------------------------------------------------- /trax/supervised/configs/layerdrop_transformer_lm1b.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import trax.data 16 | import trax.models 17 | import trax.optimizers 18 | import trax.supervised.trainer_lib 19 | 20 | # Parameters for batcher: 21 | # ============================================================================== 22 | batcher.data_streams = @data.data_streams 23 | batcher.batch_size_per_device = 96 24 | batcher.eval_batch_size = 128 25 | batcher.max_eval_length = 2048 26 | batcher.id_to_mask = 0 27 | 28 | # Parameters for data_streams: 29 | # ============================================================================== 30 | data_streams.data_dir = None 31 | data_streams.dataset_name = 't2t_languagemodel_lm1b32k' 32 | data_streams.input_name = 'targets' 33 | data_streams.preprocess_fn = @data.lm1b_preprocess 34 | 35 | # Parameters for multifactor: 36 | # ============================================================================== 37 | multifactor.constant = 0.3 38 | multifactor.factors = 'constant * linear_warmup' 39 | multifactor.warmup_steps = 8000 40 | 41 | # Parameters for lm1b_preprocess: 42 | # ============================================================================== 43 | lm1b_preprocess.max_target_length = 512 44 | lm1b_preprocess.max_eval_target_length = 2048 45 | 46 | # Parameters for train: 47 | # ============================================================================== 48 | train.eval_frequency = 1000 49 | train.eval_steps = 10 50 | train.model = @trax.models.LayerDropTransformerLM 51 | train.optimizer = @trax.optimizers.SM3 52 | train.steps = 500000 53 | 54 | # Parameters for DotProductCausalAttention: 55 | # ============================================================================== 56 | DotProductCausalAttention.dropout = 0.1 57 | 58 | # Parameters for LayerDropTransformerLM: 59 | # ============================================================================== 60 | LayerDropTransformerLM.d_model = 1024 61 | LayerDropTransformerLM.d_ff = 4096 62 | LayerDropTransformerLM.dropout = 0.2 63 | LayerDropTransformerLM.max_len = 2048 64 | LayerDropTransformerLM.mode = 'train' 65 | LayerDropTransformerLM.n_heads = 8 66 | LayerDropTransformerLM.n_layers = 8 67 | LayerDropTransformerLM.vocab_size = 32000 68 | LayerDropTransformerLM.skip_fraction = 0.4 69 | -------------------------------------------------------------------------------- /trax/supervised/configs/layerdrop_ushape_transformer_lm1b.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
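# Unlike the scalar skip_fraction = 0.4 in the config above, the config below
# passes a per-layer list, giving the 8 layers a U-shaped drop profile: outer
# layers (0.1) are kept more often than middle ones (0.3). A sketch of the
# training-time behavior this suggests (illustrative only; the real logic
# lives in trax.models.research.layerdrop_transformer):
#
#   import random
#
#   def maybe_skip(layer_idx, skip_fraction, x, layer_fn):
#     """Stochastically skip a layer during training."""
#     p = (skip_fraction[layer_idx]
#          if isinstance(skip_fraction, (list, tuple)) else skip_fraction)
#     return x if random.random() < p else layer_fn(x)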
14 | 15 | import trax.data 16 | import trax.models 17 | import trax.optimizers 18 | import trax.supervised.trainer_lib 19 | 20 | # Parameters for batcher: 21 | # ============================================================================== 22 | batcher.data_streams = @data.data_streams 23 | batcher.batch_size_per_device = 96 24 | batcher.eval_batch_size = 128 25 | batcher.max_eval_length = 2048 26 | batcher.id_to_mask = 0 27 | 28 | # Parameters for data_streams: 29 | # ============================================================================== 30 | data_streams.data_dir = None 31 | data_streams.dataset_name = 't2t_languagemodel_lm1b32k' 32 | data_streams.input_name = 'targets' 33 | data_streams.preprocess_fn = @data.lm1b_preprocess 34 | 35 | # Parameters for multifactor: 36 | # ============================================================================== 37 | multifactor.constant = 0.3 38 | multifactor.factors = 'constant * linear_warmup' 39 | multifactor.warmup_steps = 8000 40 | 41 | # Parameters for lm1b_preprocess: 42 | # ============================================================================== 43 | lm1b_preprocess.max_target_length = 512 44 | lm1b_preprocess.max_eval_target_length = 2048 45 | 46 | # Parameters for train: 47 | # ============================================================================== 48 | train.eval_frequency = 1000 49 | train.eval_steps = 10 50 | train.model = @trax.models.LayerDropTransformerLM 51 | train.optimizer = @trax.optimizers.SM3 52 | train.steps = 500000 53 | 54 | # Parameters for DotProductCausalAttention: 55 | # ============================================================================== 56 | DotProductCausalAttention.dropout = 0.1 57 | 58 | # Parameters for LayerDropTransformerLM: 59 | # ============================================================================== 60 | LayerDropTransformerLM.d_model = 1024 61 | LayerDropTransformerLM.d_ff = 4096 62 | LayerDropTransformerLM.dropout = 0.2 63 | LayerDropTransformerLM.max_len = 2048 64 | LayerDropTransformerLM.mode = 'train' 65 | LayerDropTransformerLM.n_heads = 8 66 | LayerDropTransformerLM.n_layers = 8 67 | LayerDropTransformerLM.vocab_size = 32000 68 | LayerDropTransformerLM.skip_fraction = [0.1, 0.1, 0.3, 0.3, 0.3, 0.3, 0.1, 0.1] 69 | -------------------------------------------------------------------------------- /trax/supervised/configs/lstm_lm1b.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
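# A note on rnn_cell_d_state_multiplier = 2 below: this reading is inferred
# from the parameter name rather than from documented behavior. An LSTM cell
# carries two state vectors (cell state c and hidden state h), so its packed
# state is twice as wide as d_model; a single-state cell such as GRU would
# use a multiplier of 1.
#
#   d_state = rnn_cell_d_state_multiplier * d_model  # 2 * 512 = 1024 here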
14 | 15 | import trax.data 16 | import trax.models 17 | import trax.optimizers 18 | import trax.supervised.trainer_lib 19 | 20 | # Parameters for batcher: 21 | # ============================================================================== 22 | batcher.data_streams = @data.data_streams 23 | batcher.batch_size_per_device = 64 24 | batcher.eval_batch_size = 64 25 | batcher.max_eval_length = 2048 26 | batcher.id_to_mask = 0 27 | 28 | # Parameters for data_streams: 29 | # ============================================================================== 30 | data_streams.data_dir = None 31 | data_streams.dataset_name = 't2t_languagemodel_lm1b32k' 32 | data_streams.input_name = 'targets' 33 | data_streams.preprocess_fn = @data.lm1b_preprocess 34 | 35 | # Parameters for multifactor: 36 | # ============================================================================== 37 | multifactor.constant = 0.001 38 | multifactor.factors = 'constant * linear_warmup' 39 | multifactor.warmup_steps = 8000 40 | 41 | # Parameters for lm1b_preprocess: 42 | # ============================================================================== 43 | lm1b_preprocess.max_target_length = 512 44 | lm1b_preprocess.max_eval_target_length = 2048 45 | 46 | # Parameters for train: 47 | # ============================================================================== 48 | train.eval_frequency = 1000 49 | train.eval_steps = 10 50 | train.model = @trax.models.RNNLM 51 | train.optimizer = @trax.optimizers.Adam 52 | train.steps = 500000 53 | 54 | # Parameters for RNNLM: 55 | # ============================================================================== 56 | RNNLM.rnn_cell = @trax.layers.LSTMCell 57 | RNNLM.rnn_cell_d_state_multiplier = 2 58 | RNNLM.d_model = 512 59 | RNNLM.dropout = 0.1 60 | RNNLM.n_layers = 2 61 | RNNLM.vocab_size = 32000 62 | -------------------------------------------------------------------------------- /trax/supervised/configs/lstm_seq2seq_wmt_ende.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
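# The config below uses lr_schedules.warmup instead of multifactor: as the
# parameter names suggest, the learning rate ramps up to `max_value` over
# `n_warmup_steps` and then holds. A hedged sketch of that shape:
#
#   def warmup_lr(step, n_warmup_steps=1000, max_value=0.001):
#     return max_value * min(1.0, step / n_warmup_steps)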
14 | 15 | import trax.data 16 | import trax.models 17 | import trax.optimizers 18 | import trax.supervised.lr_schedules 19 | import trax.supervised.trainer_lib 20 | 21 | # Parameters for batcher: 22 | # ============================================================================== 23 | batcher.data_streams = @data.data_streams 24 | batcher.batch_size_per_device = 64 25 | batcher.eval_batch_size = 64 26 | batcher.max_eval_length = 512 27 | batcher.bucket_length = 32 28 | batcher.buckets_include_inputs_in_length = True 29 | batcher.id_to_mask = 0 30 | 31 | # Parameters for data_streams: 32 | # ============================================================================== 33 | data_streams.data_dir = None 34 | data_streams.dataset_name = 't2t_translate_ende_wmt32k' 35 | data_streams.preprocess_fn = @data.wmt_preprocess 36 | 37 | # Parameters for wmt_preprocess: 38 | # ============================================================================== 39 | wmt_preprocess.max_length = 256 40 | wmt_preprocess.max_eval_length = 512 41 | 42 | # Parameters for lr_schedules.warmup: 43 | # ============================================================================== 44 | lr_schedules.warmup.max_value = 0.001 45 | lr_schedules.warmup.n_warmup_steps = 1000 46 | 47 | # Parameters for Adam: 48 | # ============================================================================== 49 | Adam.weight_decay_rate = 0.0 50 | 51 | # Parameters for train: 52 | # ============================================================================== 53 | train.eval_frequency = 1000 54 | train.eval_steps = 10 55 | train.model = @trax.models.LSTMSeq2SeqAttn 56 | train.optimizer = @trax.optimizers.Adam 57 | train.lr_schedule_fn = @lr_schedules.warmup 58 | train.steps = 250000 59 | 60 | # Parameters for LSTMSeq2SeqAttn: 61 | # ============================================================================== 62 | LSTMSeq2SeqAttn.d_model = 1024 63 | LSTMSeq2SeqAttn.n_encoder_layers = 2 64 | LSTMSeq2SeqAttn.n_decoder_layers = 2 65 | LSTMSeq2SeqAttn.attention_dropout = 0.2 66 | LSTMSeq2SeqAttn.n_attention_heads = 8 67 | LSTMSeq2SeqAttn.input_vocab_size = 33300 68 | LSTMSeq2SeqAttn.target_vocab_size = 33300 69 | -------------------------------------------------------------------------------- /trax/supervised/configs/mlp_mnist.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
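# mlp_mnist.gin below is the smallest end-to-end config in this directory,
# which makes it a convenient smoke test. A minimal sketch of running any of
# these .gin files through the gin-configured trainer (the output path is
# illustrative):
#
#   import gin
#   import trax
#
#   gin.parse_config_file('trax/supervised/configs/mlp_mnist.gin')
#   trax.supervised.trainer_lib.train(output_dir='/tmp/mlp_mnist')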
14 | 15 | import trax.data 16 | import trax.models 17 | import trax.optimizers 18 | import trax.supervised.lr_schedules 19 | import trax.supervised.trainer_lib 20 | 21 | # Parameters for batcher: 22 | # ============================================================================== 23 | batcher.data_streams = @data.data_streams 24 | batcher.batch_size_per_device = 256 25 | batcher.eval_batch_size = 256 26 | batcher.variable_shapes = False 27 | 28 | # Parameters for data.data_streams: 29 | # ============================================================================== 30 | data_streams.data_dir = None 31 | data_streams.dataset_name = 'mnist' 32 | 33 | # Parameters for MLP: 34 | # ============================================================================== 35 | MLP.layer_widths = (128, 64) 36 | 37 | # Parameters for lr_schedules.constant: 38 | # ============================================================================== 39 | lr_schedules.constant.value = 0.0001 40 | 41 | # Parameters for train: 42 | # ============================================================================== 43 | train.optimizer = @trax.optimizers.Adam 44 | train.eval_frequency = 200 45 | train.eval_steps = 10 46 | train.model = @trax.models.MLP 47 | train.steps = 2000 48 | train.checkpoint_highest = 'accuracy' 49 | train.lr_schedule_fn = @lr_schedules.constant 50 | -------------------------------------------------------------------------------- /trax/supervised/configs/reformer_cifar10.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
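# Reading guide for the gin syntax below: `@name` binds a parameter to a
# configurable (a function or class), while `%name` references a macro
# defined earlier in the file. For example:
#
#   attn_kv = 64                             # macro definition
#   ReformerLM.d_attention_key = %attn_kv    # macro reference
#   train.model = @trax.models.ReformerLM    # configurable reference
#
# Keeping the experiment-varied values (attn_kv, n_layers) as top-of-file
# macros lets sweeps override them in one place.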
14 | 15 | import trax.data 16 | import trax.models 17 | import trax.optimizers 18 | import trax.supervised.trainer_lib 19 | 20 | # Parameters that will vary between experiments: 21 | # ============================================================================== 22 | train.model = @trax.models.ReformerLM 23 | attn_kv = 64 24 | n_layers = 9 25 | 26 | # Parameters for batcher: 27 | # ============================================================================== 28 | batcher.data_streams = @data.data_streams 29 | batcher.batch_size_per_device = 1 30 | batcher.eval_batch_size = 8 31 | batcher.max_eval_length = 12288 # 64 * 64 * 3 32 | 33 | # Parameters for data_streams: 34 | # ============================================================================== 35 | data_streams.data_dir = None 36 | data_streams.dataset_name = 'cifar10' 37 | data_streams.preprocess_fn = @data.cifar10_augmentation_flatten_preprocess 38 | 39 | # Parameters for multifactor: 40 | # ============================================================================== 41 | multifactor.constant = 1.0 42 | multifactor.factors = 'constant * linear_warmup * rsqrt_decay' 43 | multifactor.warmup_steps = 8000 44 | 45 | # Parameters for train: 46 | # ============================================================================== 47 | train.eval_frequency = 500 48 | train.eval_steps = 64 49 | # train.model: see top 50 | train.optimizer = @trax.optimizers.Adafactor 51 | train.steps = 500000 52 | train.save_graphs = False 53 | train.checkpoints_at = \ 54 | [1000, 5000, 10000, 20000, 40000, 60000, 80000, 55 | 100000, 200000, 300000, 400000, 500000] 56 | 57 | # Parameters for ReformerLM: 58 | # ============================================================================== 59 | ReformerLM.d_attention_key = %attn_kv 60 | ReformerLM.d_attention_value = %attn_kv 61 | ReformerLM.d_model = 1024 62 | ReformerLM.d_ff = 4096 63 | ReformerLM.dropout = 0.1 64 | ReformerLM.max_len = 12288 # 64 * 64 * 3 65 | ReformerLM.mode = 'train' 66 | ReformerLM.n_heads = 8 67 | ReformerLM.n_layers = %n_layers 68 | ReformerLM.vocab_size = 256 69 | -------------------------------------------------------------------------------- /trax/supervised/configs/resnet50_frn_imagenet_8gb.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
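# FilterResponseNorm (FRN) normalizes each channel by its mean squared
# activation instead of batch statistics, and is typically paired with a
# thresholded linear unit (TLU), as the Resnet50 bindings below do. A sketch
# of the math from the FRN paper (Singh & Krishnan, 2019), with eps fixed
# because learn_epsilon = False here:
#
#   nu2 = mean(x ** 2, axis=spatial)         # per-channel second moment
#   y = gamma * x / sqrt(nu2 + eps) + beta   # FRN
#   z = max(y, tau)                          # TLU non-linearity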
14 | 15 | import trax.data 16 | import trax.supervised.lr_schedules 17 | import trax.models 18 | import trax.optimizers 19 | import trax.supervised.trainer_lib 20 | 21 | # Parameters for batcher: 22 | # ============================================================================== 23 | batcher.data_streams = @data.data_streams 24 | batcher.batch_size_per_device = 32 25 | batcher.bucket_length = 32 26 | batcher.buckets = None 27 | batcher.eval_batch_size = 32 28 | batcher.variable_shapes = False 29 | 30 | # Parameters for data_streams: 31 | # ============================================================================== 32 | data_streams.data_dir = None 33 | data_streams.dataset_name = 't2t_image_imagenet224' 34 | 35 | # Parameters for FilterResponseNorm: 36 | # ============================================================================== 37 | # TODO(afrozm): Make this work in the learn_epsilon = True setting. 38 | FilterResponseNorm.learn_epsilon = False 39 | 40 | # Parameters for multifactor: 41 | # ============================================================================== 42 | multifactor.constant = 30.0 43 | multifactor.factors = 'constant * linear_warmup * rsqrt_decay' 44 | multifactor.warmup_steps = 30000 45 | 46 | # Parameters for Momentum: 47 | # ============================================================================== 48 | Momentum.mass = 0.9 49 | 50 | # Parameters for Resnet50: 51 | # ============================================================================== 52 | Resnet50.d_hidden = 64 53 | Resnet50.n_output_classes = 1001 54 | Resnet50.norm = @trax.layers.FilterResponseNorm 55 | Resnet50.non_linearity = @trax.layers.ThresholdedLinearUnit 56 | 57 | # Parameters for train: 58 | # ============================================================================== 59 | train.eval_frequency = 2000 60 | train.eval_steps = 20 61 | train.model = @trax.models.Resnet50 62 | train.optimizer = @trax.optimizers.Momentum 63 | train.steps = 300000 64 | -------------------------------------------------------------------------------- /trax/supervised/configs/resnet50_imagenet_8gb_testing.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import trax.data 16 | import trax.supervised.lr_schedules 17 | import trax.models 18 | import trax.optimizers 19 | import trax.supervised.trainer_lib 20 | 21 | # Parameters for batcher: 22 | # ============================================================================== 23 | batcher.data_streams = @data.data_streams 24 | batcher.batch_size_per_device = 32 25 | batcher.eval_batch_size = 32 26 | batcher.variable_shapes = False 27 | 28 | # Parameters for data_streams: 29 | # ============================================================================== 30 | data_streams.data_dir = None 31 | data_streams.dataset_name = 't2t_image_imagenet224' 32 | data_streams.preprocess_fn = @data.squeeze_targets_preprocess 33 | 34 | # Parameters for multifactor: 35 | # ============================================================================== 36 | multifactor.factors = 'constant * linear_warmup' 37 | multifactor.constant = 0.2 38 | multifactor.warmup_steps = 400 39 | 40 | # Parameters for Momentum: 41 | # ============================================================================== 42 | Momentum.mass = 0.9 43 | 44 | # Parameters for Resnet50: 45 | # ============================================================================== 46 | Resnet50.d_hidden = 64 47 | Resnet50.n_output_classes = 1001 48 | 49 | # Parameters for train: 50 | # ============================================================================== 51 | train.eval_frequency = 2000 52 | train.eval_steps = 20 53 | train.model = @trax.models.Resnet50 54 | train.optimizer = @trax.optimizers.Momentum 55 | train.steps = 100000 56 | -------------------------------------------------------------------------------- /trax/supervised/configs/rse_addition_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # multifactor.warmup_steps: [4000, 5000, 6000] 16 | # multifactor.constant: [1.0e-1, 5.0e-1, 1.0, 5.0] 17 | 18 | # ResidualShuffleExchange.n_blocks: [1, 2] 19 | ResidualShuffleExchange.dropout: [0.0, 0.05, 0.1] 20 | ResidualShuffleExchange.input_dropout: [0.0, 0.05, 0.1] 21 | 22 | # WeightedCategoryCrossEntropy.label_smoothing: [0., 0.1, 0.05, 0.01, 0.005] 23 | 24 | # train.optimizer: ["%Adam", "%Adafactor"] 25 | # Adafactor.do_momentum: [True, False] 26 | -------------------------------------------------------------------------------- /trax/supervised/configs/skipping_transformer_lm1b.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import trax.data 16 | import trax.models 17 | import trax.optimizers 18 | import trax.supervised.trainer_lib 19 | 20 | # Parameters for batcher: 21 | # ============================================================================== 22 | batcher.data_streams = @data.data_streams 23 | batcher.batch_size_per_device = 96 24 | batcher.eval_batch_size = 128 25 | batcher.max_eval_length = 2048 26 | batcher.id_to_mask = 0 27 | 28 | # Parameters for data_streams: 29 | # ============================================================================== 30 | data_streams.data_dir = None 31 | data_streams.dataset_name = 't2t_languagemodel_lm1b32k' 32 | data_streams.input_name = 'targets' 33 | data_streams.preprocess_fn = @data.lm1b_preprocess 34 | 35 | # Parameters for multifactor: 36 | # ============================================================================== 37 | multifactor.constant = 0.3 38 | multifactor.factors = 'constant * linear_warmup' 39 | multifactor.warmup_steps = 8000 40 | 41 | # Parameters for lm1b_preprocess: 42 | # ============================================================================== 43 | lm1b_preprocess.max_target_length = 512 44 | lm1b_preprocess.max_eval_target_length = 2048 45 | 46 | # Parameters for train: 47 | # ============================================================================== 48 | train.eval_frequency = 1000 49 | train.eval_steps = 10 50 | train.model = @trax.models.SkippingTransformerLM 51 | train.optimizer = @trax.optimizers.SM3 52 | train.steps = 500000 53 | 54 | # Parameters for DotProductCausalAttention: 55 | # ============================================================================== 56 | DotProductCausalAttention.dropout = 0.1 57 | 58 | # Parameters for SkippingTransformerLM: 59 | # ============================================================================== 60 | SkippingTransformerLM.d_model = 1024 61 | SkippingTransformerLM.d_ff = 4096 62 | SkippingTransformerLM.dropout = 0.2 63 | SkippingTransformerLM.max_len = 2048 64 | SkippingTransformerLM.mode = 'train' 65 | SkippingTransformerLM.n_heads = 8 66 | SkippingTransformerLM.n_layers = 8 67 | SkippingTransformerLM.vocab_size = 32000 68 | -------------------------------------------------------------------------------- /trax/supervised/configs/sp_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
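# The *_sweep.yaml files in this directory map gin bindings to lists of
# candidate values for a hyperparameter search harness; commented-out keys
# (as in rse_addition_sweep.yaml above) are excluded from the sweep. A sketch
# of the grid such a file implies (the expansion strategy is an assumption;
# the search harness itself is not part of this snapshot):
#
#   import itertools
#
#   sweep = {'ConfigurableTerraformer.d_ff': [2048, 4096],
#            'LSHSelfAttention.chunk_len': [64, 128, 256]}
#   grid = [dict(zip(sweep, vals))
#           for vals in itertools.product(*sweep.values())]
#   assert len(grid) == 6  # one set of gin overrides per trial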
14 | 15 | ConfigurableTerraformer.d_ff: [2048, 4096] 16 | ConfigurableTerraformer.ff_chunk_size: [0, 2048, 4096, 8192] 17 | LSHSelfAttention.chunk_len: [64, 128, 256] 18 | LSHSelfAttention.n_buckets: [128, 256, 512] 19 | multifactor.constant: [0.03125, 0.0442, 1.0] 20 | -------------------------------------------------------------------------------- /trax/supervised/configs/t5_mathqa_drop_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | Adafactor.weight_decay_rate: [0, ] 16 | multifactor.constant: [0.001, ] 17 | CreateAnnotatedDropInputs.percentile: [0.01, 0.05, 0.1, 0.25, 0.5, 1] 18 | -------------------------------------------------------------------------------- /trax/supervised/configs/t5_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | dataset_name: [ 16 | 'glue/mrpc', 17 | 'glue/sst2', 18 | 'glue/qqp', 19 | 'glue/mnli', 20 | 'glue/qnli', 21 | 'glue/rte', 22 | ] 23 | Adafactor.weight_decay_rate: [0, ] 24 | multifactor.constant: [0.001, 0.0001, 0.00001] 25 | 26 | 27 | -------------------------------------------------------------------------------- /trax/supervised/configs/t5_sweep_temperature.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | dataset_name: [ 16 | 'glue/mrpc', 17 | 'glue/sst2', 18 | 'glue/qqp', 19 | 'glue/mnli', 20 | 'glue/qnli', 21 | 'glue/rte', 22 | ] 23 | ConfigurableTransformer.ff_sparsity: ['64 64 0.0 1.0', '64 64 0.0 0.0'] 24 | -------------------------------------------------------------------------------- /trax/supervised/configs/terraformer_copy.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # -*-Python-*- 16 | 17 | include 'reformer_copy.gin' 18 | 19 | import trax.data 20 | import trax.models 21 | import trax.optimizers 22 | import trax.supervised.trainer_lib 23 | 24 | # Parameters for ConfigurableTerraformer: 25 | # ============================================================================== 26 | ConfigurableTerraformer.d_model = 256 27 | ConfigurableTerraformer.d_ff = 512 28 | ConfigurableTerraformer.dropout = 0.05 29 | ConfigurableTerraformer.max_len = %max_len 30 | ConfigurableTerraformer.n_heads = 4 31 | ConfigurableTerraformer.mode = 'train' 32 | ConfigurableTerraformer.n_encoder_layers = 3 33 | ConfigurableTerraformer.n_decoder_layers = 3 34 | ConfigurableTerraformer.ff_use_sru = 0 35 | ConfigurableTerraformer.d_attention_key = 64 36 | ConfigurableTerraformer.d_attention_value = 64 37 | ConfigurableTerraformer.encoder_attention_type = @LSHSelfAttention 38 | ConfigurableTerraformer.encoder_decoder_attention_type = @LSHSelfAttention 39 | ConfigurableTerraformer.n_decoder_attention_layers = 1 40 | ConfigurableTerraformer.input_vocab_size = %vocab_size 41 | ConfigurableTerraformer.pos_type = 'fixed-base' 42 | 43 | # Parameters for train: 44 | # ============================================================================== 45 | train.inputs = @trax.data.simple_sequence_copy_inputs 46 | train.model = @trax.models.ConfigurableTerraformer 47 | -------------------------------------------------------------------------------- /trax/supervised/configs/terraformer_copy_self_attn.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # -*-Python-*- 16 | 17 | include 'terraformer_copy.gin' 18 | 19 | import trax.models 20 | 21 | 22 | # Parameters for ConfigurableTerraformer: 23 | # ============================================================================== 24 | ConfigurableTerraformer.encoder_attention_type = @SelfAttention 25 | ConfigurableTerraformer.encoder_decoder_attention_type = @SelfAttention 26 | 27 | # Parameters for SelfAttention: 28 | # ============================================================================== 29 | # 0 < predict_drop_len <= predict_mem_len 30 | SelfAttention.predict_mem_len = %max_len 31 | SelfAttention.predict_drop_len = %max_len 32 | -------------------------------------------------------------------------------- /trax/supervised/configs/terraformer_purelsh_copy.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # -*-Python-*- 16 | 17 | include 'terraformer_copy.gin' 18 | 19 | import trax.data 20 | import trax.models 21 | import trax.optimizers 22 | import trax.supervised.trainer_lib 23 | 24 | 25 | # Parameters for PureLSHSelfAttention: 26 | # ============================================================================== 27 | PureLSHSelfAttention.attention_dropout = 0.0 28 | PureLSHSelfAttention.chunk_len = 16 29 | PureLSHSelfAttention.n_buckets = [32, 32] 30 | PureLSHSelfAttention.n_chunks_after = 0 31 | PureLSHSelfAttention.n_chunks_before = 1 32 | PureLSHSelfAttention.n_hashes = 2 33 | PureLSHSelfAttention.n_parallel_heads = 1 34 | PureLSHSelfAttention.max_length_for_buckets = 1024 35 | # 0 < predict_drop_len <= predict_mem_len 36 | PureLSHSelfAttention.predict_mem_len = %max_len 37 | PureLSHSelfAttention.predict_drop_len = %max_len 38 | 39 | # Parameters for PureLSHSelfAttentionWrapper: 40 | # ============================================================================== 41 | PureLSHSelfAttentionWrapper.pure_lsh_implementation = @PureLSHSelfAttention 42 | 43 | # We need something special for the encoder. 
44 | enc/PureLSHSelfAttention.n_chunks_after = 1 45 | encoder/PureLSHSelfAttentionWrapper.pure_lsh_implementation = @enc/PureLSHSelfAttention 46 | 47 | # Parameters for ConfigurableTerraformer: 48 | # ============================================================================== 49 | ConfigurableTerraformer.encoder_attention_type = @encoder/PureLSHSelfAttentionWrapper 50 | ConfigurableTerraformer.encoder_decoder_attention_type = @PureLSHSelfAttentionWrapper 51 | 52 | # Parameters for train: 53 | # ============================================================================== 54 | train.inputs = @trax.data.simple_sequence_copy_inputs 55 | train.model = @trax.models.ConfigurableTerraformer 56 | -------------------------------------------------------------------------------- /trax/supervised/configs/transformer_big_lm1b_8gb.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import trax.data 16 | import trax.models 17 | import trax.optimizers 18 | import trax.supervised.trainer_lib 19 | 20 | # Parameters for batcher: 21 | # ============================================================================== 22 | batcher.data_streams = @data.data_streams 23 | batcher.batch_size_per_device = 32 24 | batcher.eval_batch_size = 64 25 | batcher.max_eval_length = 512 26 | batcher.id_to_mask = 0 27 | 28 | # Parameters for data_streams: 29 | # ============================================================================== 30 | data_streams.data_dir = None 31 | data_streams.dataset_name = 't2t_languagemodel_lm1b32k' 32 | data_streams.input_name = 'targets' 33 | data_streams.preprocess_fn = @data.lm1b_preprocess 34 | 35 | # Parameters for multifactor: 36 | # ============================================================================== 37 | multifactor.constant = 0.1 38 | multifactor.factors = 'constant * linear_warmup' 39 | multifactor.warmup_steps = 16000 40 | 41 | # Parameters for lm1b_preprocess: 42 | # ============================================================================== 43 | lm1b_preprocess.max_target_length = 512 44 | lm1b_preprocess.max_eval_target_length = 512 45 | 46 | # Parameters for train: 47 | # ============================================================================== 48 | train.eval_frequency = 1000 49 | train.eval_steps = 10 50 | train.model = @trax.models.TransformerLM 51 | train.optimizer = @trax.optimizers.SM3 52 | train.steps = 500000 53 | 54 | # Parameters for TransformerLM: 55 | # ============================================================================== 56 | TransformerLM.d_model = 1024 57 | TransformerLM.d_ff = 8192 58 | TransformerLM.dropout = 0.1 59 | TransformerLM.max_len = 2048 60 | TransformerLM.mode = 'train' 61 | TransformerLM.n_heads = 8 62 | TransformerLM.n_layers = 8 63 | TransformerLM.vocab_size = 32000 64 | -------------------------------------------------------------------------------- 
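# Back-of-the-envelope size of the TransformerLM configured above, ignoring
# biases, layer norms, and positional embeddings (no weight tying assumed):
#
#   embedding: 32000 * 1024                     ~=  32.8M
#   per layer: 4 * 1024**2 (attention)          ~=   4.2M
#              + 2 * 1024 * 8192 (feed-forward) ~=  16.8M
#   total:     32.8M + 8 * 21.0M                ~= 200M parameters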
/trax/supervised/configs/transformer_finetune_squad_16gb.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import trax.data 16 | include 'c4_pretrain_16gb_adafactor.gin' 17 | 18 | 19 | # Parameters for train: 20 | # ============================================================================== 21 | # This should always be overridden since this is relative to the pre-trained 22 | # checkpoint. 23 | train.steps = 262144 24 | 25 | 26 | # Parameters for batcher: 27 | # ============================================================================== 28 | batcher.buckets = ([513,], [8, 8]) 29 | batcher.strict_pad_on_len = True 30 | batcher.buckets_include_inputs_in_length = True 31 | 32 | 33 | # Parameters for multifactor: 34 | # ============================================================================== 35 | multifactor.factors = 'constant' 36 | multifactor.constant = 0.001 37 | 38 | 39 | # Parameters for data_streams: 40 | # ============================================================================== 41 | data_streams.dataset_name = 'squad/plain_text:1.0.0' 42 | data_streams.bare_preprocess_fn = @data.generic_text_dataset_preprocess_fn 43 | 44 | 45 | # Parameters for get_t5_preprocessor_by_name: 46 | # ============================================================================== 47 | squad/get_t5_preprocessor_by_name.name = 'squad' 48 | squad/get_t5_preprocessor_by_name.fn_kwargs = {'include_context': True} 49 | 50 | 51 | # Parameters for generic_text_dataset_preprocess_fn: 52 | # ============================================================================== 53 | generic_text_dataset_preprocess_fn.text_preprocess_fns = [ 54 | @squad/get_t5_preprocessor_by_name() 55 | ] 56 | generic_text_dataset_preprocess_fn.token_preprocess_fns = [ 57 | @data.add_eos_to_output_features 58 | ] 59 | -------------------------------------------------------------------------------- /trax/supervised/configs/transformer_lm1b_8gb_testing.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
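# A note on the `include` pattern in transformer_finetune_squad_16gb.gin
# above: gin loads the included file first and applies the including file's
# bindings afterwards, so the later bindings win. That is why the fine-tuning
# config can inherit the full c4_pretrain_16gb_adafactor.gin setup and only
# restate the data, schedule, and batching parameters it changes:
#
#   include 'c4_pretrain_16gb_adafactor.gin'   # base bindings
#   multifactor.factors = 'constant'           # overrides the base file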
14 | 15 | import trax.data 16 | import trax.models 17 | import trax.optimizers 18 | import trax.supervised.trainer_lib 19 | 20 | # Module trax.data: 21 | # ============================================================================== 22 | batcher.data_streams = @data.data_streams 23 | batcher.batch_size_per_device = 128 24 | batcher.eval_batch_size = 128 25 | batcher.max_eval_length = 2048 26 | batcher.id_to_mask = 0 27 | 28 | data_streams.data_dir = None 29 | data_streams.dataset_name = 't2t_languagemodel_lm1b32k' 30 | data_streams.input_name = 'targets' 31 | data_streams.preprocess_fn = @data.lm1b_preprocess 32 | 33 | lm1b_preprocess.max_target_length = 512 34 | lm1b_preprocess.max_eval_target_length = 2048 35 | 36 | # Module trax.models.transformer: 37 | # ============================================================================== 38 | TransformerLM.d_model = 512 39 | TransformerLM.d_ff = 2048 40 | TransformerLM.dropout = 0.1 41 | TransformerLM.max_len = 2048 42 | TransformerLM.mode = 'train' 43 | TransformerLM.n_heads = 8 44 | TransformerLM.n_layers = 6 45 | TransformerLM.vocab_size = 32000 46 | 47 | # Module trax.supervised.lr_schedules: 48 | # ============================================================================== 49 | multifactor.constant = 0.1 50 | multifactor.factors = 'constant * linear_warmup * rsqrt_decay' 51 | multifactor.warmup_steps = 8000 52 | 53 | # Module trax.supervised.trainer_lib: 54 | # ============================================================================== 55 | train.eval_frequency = 1000 56 | train.eval_steps = 10 57 | train.model = @trax.models.TransformerLM 58 | train.optimizer = @trax.optimizers.Adam 59 | train.steps = 100000 60 | -------------------------------------------------------------------------------- /trax/supervised/configs/transformer_ptb_16gb.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import trax.data 16 | import trax.models 17 | import trax.optimizers 18 | import trax.supervised.lr_schedules 19 | import trax.supervised.trainer_lib 20 | 21 | # Parameters for batcher: 22 | # ============================================================================== 23 | batcher.data_streams = @data.data_streams 24 | batcher.batch_size_per_device = 64 25 | batcher.eval_batch_size = 512 26 | batcher.max_eval_length = 2048 27 | batcher.id_to_mask = 0 28 | 29 | # Parameters for data_streams: 30 | # ============================================================================== 31 | data_streams.data_dir = None 32 | data_streams.dataset_name = 't2t_languagemodel_ptb10k' 33 | data_streams.input_name = 'targets' 34 | data_streams.preprocess_fn = @data.lm1b_preprocess 35 | 36 | # Parameters for multifactor: 37 | # ============================================================================== 38 | multifactor.constant = 1.0 39 | multifactor.factors = 'constant * linear_warmup * rsqrt_decay' 40 | multifactor.warmup_steps = 8000 41 | 42 | # Parameters for lm1b_preprocess: 43 | # ============================================================================== 44 | lm1b_preprocess.max_target_length = 512 45 | lm1b_preprocess.max_eval_target_length = 2048 46 | 47 | # Parameters for train: 48 | # ============================================================================== 49 | train.eval_frequency = 200 50 | train.eval_steps = 2 51 | train.model = @trax.models.TransformerLM 52 | train.optimizer = @trax.optimizers.Adafactor 53 | train.steps = 20000 54 | 55 | # Parameters for TransformerLM: 56 | # ============================================================================== 57 | TransformerLM.d_model = 512 58 | TransformerLM.d_ff = 2048 59 | TransformerLM.dropout = 0.5 60 | TransformerLM.max_len = 2048 61 | TransformerLM.mode = 'train' 62 | TransformerLM.n_heads = 8 63 | TransformerLM.n_layers = 6 64 | TransformerLM.vocab_size = 10240 65 | -------------------------------------------------------------------------------- /trax/supervised/configs/wide_resnet_cifar10_8gb.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Trax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import trax.data 16 | import trax.supervised.lr_schedules 17 | import trax.models 18 | import trax.optimizers 19 | import trax.supervised.trainer_lib 20 | 21 | # Parameters for batcher: 22 | # ============================================================================== 23 | batcher.data_streams = @data.data_streams 24 | batcher.batch_size_per_device = 256 25 | batcher.bucket_length = 32 26 | batcher.buckets = None 27 | batcher.eval_batch_size = 512 28 | batcher.variable_shapes = False 29 | 30 | # Parameters for data_streams: 31 | # ============================================================================== 32 | data_streams.data_dir = None 33 | data_streams.dataset_name = 'cifar10' 34 | data_streams.preprocess_fn = @data.cifar10_augmentation_preprocess 35 | 36 | # Parameters for multifactor: 37 | # ============================================================================== 38 | multifactor.factors = 'constant * linear_warmup' 39 | multifactor.constant = 0.5 40 | multifactor.warmup_steps = 400 41 | 42 | # Parameters for Momentum: 43 | # ============================================================================== 44 | Momentum.mass = 0.9 45 | Momentum.weight_decay_rate = 5e-4 46 | 47 | # Parameters for WideResnet: 48 | # ============================================================================== 49 | WideResnet.widen_factor = 10 50 | WideResnet.n_blocks = 4 51 | WideResnet.n_output_classes = 10 52 | 53 | # Parameters for train: 54 | # ============================================================================== 55 | train.eval_frequency = 100 56 | train.eval_steps = 10 57 | train.model = @trax.models.WideResnet 58 | train.optimizer = @trax.optimizers.Momentum 59 | train.steps = 10000 60 | -------------------------------------------------------------------------------- /trax/supervised/history_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for trax.supervised.history.""" 17 | 18 | from absl.testing import absltest 19 | 20 | from trax.supervised import history as trax_history 21 | 22 | 23 | class HistoryTest(absltest.TestCase): 24 | 25 | def test_unknown_mode(self): 26 | history = trax_history.History() 27 | history.append('train', 'metric1', 1, 0.1) 28 | self.assertEqual(history.get('unknown_mode', 'metric1'), []) 29 | 30 | def test_unknown_metric(self): 31 | history = trax_history.History() 32 | history.append('train', 'metric1', 1, 0.1) 33 | self.assertEqual(history.get('train', 'unknown_metric'), []) 34 | 35 | def test_serializer_and_deserializer(self): 36 | history = trax_history.History() 37 | history.append('train', 'metric1', 1, 0.1) 38 | json_object = history.to_dict() 39 | history2 = trax_history.History.from_dict(json_object) 40 | self.assertEqual(history2.get('train', 'metric1'), [(1, 0.1)]) 41 | 42 | def test_modes(self): 43 | history = trax_history.History() 44 | history.append('train', 'metric1', 1, 0.1) 45 | history.append('test', 'metric2', 2, 0.2) 46 | self.assertEqual(history.modes, ['test', 'train']) 47 | 48 | def test_metrics_for_mode(self): 49 | history = trax_history.History() 50 | history.append('train', 'metric1', 1, 0.1) 51 | history.append('train', 'metric2', 2, 0.2) 52 | self.assertEqual(history.metrics_for_mode('train'), ['metric1', 'metric2']) 53 | 54 | 55 | if __name__ == '__main__': 56 | absltest.main() 57 | -------------------------------------------------------------------------------- /trax/supervised/testdata/reformerlm_copy_lsh_attn.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/supervised/testdata/reformerlm_copy_lsh_attn.pkl.gz -------------------------------------------------------------------------------- /trax/supervised/testdata/terraformer_copy_lsh_attn.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/supervised/testdata/terraformer_copy_lsh_attn.pkl.gz -------------------------------------------------------------------------------- /trax/supervised/testdata/terraformer_copy_self_attn.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/supervised/testdata/terraformer_copy_self_attn.pkl.gz -------------------------------------------------------------------------------- /trax/supervised/testdata/terraformer_purelsh_copy.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/supervised/testdata/terraformer_purelsh_copy.pkl.gz -------------------------------------------------------------------------------- /trax/supervised/testdata/transformer_copy.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/supervised/testdata/transformer_copy.pkl.gz -------------------------------------------------------------------------------- /trax/supervised/testdata/transformerlm_copy.pkl.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/supervised/testdata/transformerlm_copy.pkl.gz -------------------------------------------------------------------------------- /trax/test_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A few utilities for tests.""" 17 | 18 | import sys 19 | 20 | from absl import flags 21 | 22 | FLAGS = flags.FLAGS 23 | 24 | 25 | # pytest doesn't run tests via a main function, so absl flags are never parsed. 26 | # If a test requires a flag, this helper parses the flags manually and then 27 | # asserts that the requested flag exists. 28 | def ensure_flag(flag_str): 29 | try: 30 | getattr(FLAGS, flag_str) 31 | except flags.UnparsedFlagAccessError: 32 | # Manually parse flags. 33 | FLAGS(sys.argv) 34 | finally: 35 | assert getattr(FLAGS, flag_str) 36 | -------------------------------------------------------------------------------- /trax/tf_numpy/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /trax/tf_numpy/examples/mnist/train_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | """Test that the example training script works on fake data.""" 17 | import mock 18 | import numpy as np 19 | import tensorflow.compat.v2 as tf 20 | 21 | from trax.tf_numpy.examples.mnist import dataset 22 | from trax.tf_numpy.examples.mnist import train 23 | 24 | 25 | class TFNumpyMnistExampleTest(tf.test.TestCase): 26 | 27 | def testRuns(self): 28 | with mock.patch.object(dataset, 'load', new=fake_mnist_data): 29 | train.train( 30 | batch_size=1, 31 | learning_rate=0.1, 32 | num_training_iters=10, 33 | validation_steps=5) 34 | train.train( 35 | batch_size=2, 36 | learning_rate=0.1, 37 | num_training_iters=5, 38 | validation_steps=2) 39 | train.train( 40 | batch_size=10, 41 | learning_rate=0.1, 42 | num_training_iters=1, 43 | validation_steps=1) 44 | 45 | 46 | def fake_mnist_data(): 47 | 48 | def gen_examples(num_examples): 49 | x = np.asarray( 50 | np.random.randn(num_examples, 784), dtype=np.float32) 51 | y = np.zeros((num_examples, 10), dtype=np.float32) 52 | y[:, 0] = 1.  # One-hot labels: mark class 0 for every fake example. 53 | return (x, y) 54 | 55 | return (gen_examples(100), gen_examples(10), gen_examples(10)) 56 | 57 | 58 | if __name__ == '__main__': 59 | tf.compat.v1.enable_eager_execution() 60 | tf.test.main() 61 | -------------------------------------------------------------------------------- /trax/tf_numpy/extensions/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """JAX-like function transformations and extensions for TF-numpy.""" 17 | 18 | # pylint: disable=wildcard-import 19 | from trax.tf_numpy.extensions.extensions import * 20 | # pylint: enable=wildcard-import 21 | -------------------------------------------------------------------------------- /trax/tf_numpy/numpy/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """NumPy-like wrapper for TensorFlow.""" 17 | 18 | # pylint: disable=wildcard-import 19 | # pylint: disable=g-import-not-at-top 20 | # pylint: disable=g-direct-tensorflow-import 21 | 22 | try: 23 | # Note that this import will work in tf-nightly and TF versions 2.4 and 24 | # higher. 25 | from tensorflow.experimental.numpy import * 26 | # TODO(agarwal): get rid of following imports.
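# The explicit imports below backfill names that the wildcard import above may not re-export: the `random` submodule, TF's `bfloat16`, and the dtype-configuration helpers (`canonicalize_dtype`, `default_float_type`, `is_allow_float64`, `set_allow_float64`) from TF's private `np_dtypes` module; `random.DEFAULT_RANDN_DTYPE` is then pointed at float32.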
27 | from tensorflow.experimental.numpy import random 28 | from tensorflow import bfloat16 29 | import numpy as onp 30 | from tensorflow.python.ops.numpy_ops.np_dtypes import canonicalize_dtype 31 | from tensorflow.python.ops.numpy_ops.np_dtypes import default_float_type 32 | from tensorflow.python.ops.numpy_ops.np_dtypes import is_allow_float64 33 | from tensorflow.python.ops.numpy_ops.np_dtypes import set_allow_float64 34 | 35 | random.DEFAULT_RANDN_DTYPE = onp.float32 36 | except ImportError: 37 | try: 38 | # Note that this import will work in TF 2.3 and higher. 39 | from tensorflow.python.ops.numpy_ops import * 40 | from tensorflow import bfloat16 41 | 42 | except ImportError: 43 | # Note that this fallback will be needed for TF 2.2. 44 | from tensorflow import newaxis 45 | 46 | from trax.tf_numpy.numpy_impl import random 47 | 48 | # pylint: disable=wildcard-import 49 | from trax.tf_numpy.numpy_impl.array_ops import * 50 | from trax.tf_numpy.numpy_impl.arrays import * 51 | from trax.tf_numpy.numpy_impl.dtypes import * 52 | from trax.tf_numpy.numpy_impl.math_ops import * 53 | from trax.tf_numpy.numpy_impl.utils import finfo 54 | from trax.tf_numpy.numpy_impl.utils import promote_types 55 | from trax.tf_numpy.numpy_impl.utils import result_type 56 | # pylint: enable=wildcard-import 57 | 58 | max = amax # pylint: disable=redefined-builtin,undefined-variable 59 | min = amin # pylint: disable=redefined-builtin,undefined-variable 60 | round = around # pylint: disable=redefined-builtin,undefined-variable 61 | 62 | try: 63 | from tensorflow.python.ops.numpy_ops.np_config import enable_numpy_behavior 64 | # TODO(b/171429739): This should be moved to every individual file/test. 65 | enable_numpy_behavior() 66 | 67 | except ImportError: 68 | pass 69 | -------------------------------------------------------------------------------- /trax/tf_numpy/numpy_impl/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """NumPy API. Deprecated.""" 17 | -------------------------------------------------------------------------------- /trax/tf_numpy/numpy_impl/dtypes.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Dtypes and dtype utilities.""" 17 | import numpy as np 18 | 19 | # We use numpy's dtypes instead of TF's, because the user expects to use them 20 | # with numpy facilities such as `np.dtype(np.int64)` and 21 | # `if x.dtype.type is np.int64`. 22 | # pylint: disable=unused-import 23 | # pylint: disable=g-bad-import-order 24 | from numpy import bool_ 25 | from numpy import int_ 26 | from numpy import int16 27 | from numpy import int32 28 | from numpy import int64 29 | from numpy import int8 30 | from numpy import uint16 31 | from numpy import uint32 32 | from numpy import uint64 33 | from numpy import uint8 34 | from numpy import float16 35 | from numpy import float32 36 | from numpy import float64 37 | float_ = float64 38 | from numpy import complex64 39 | from numpy import complex128 40 | complex_ = complex128 41 | 42 | from numpy import inexact 43 | 44 | from numpy import iinfo 45 | from numpy import issubdtype 46 | 47 | from numpy import inf 48 | 49 | # TODO(wangpeng): Make bfloat16 a numpy dtype instead of using TF's 50 | from tensorflow.compat.v2 import bfloat16 51 | # pylint: enable=g-bad-import-order 52 | # pylint: enable=unused-import 53 | 54 | 55 | _to_float32 = { 56 | np.dtype('float64'): np.dtype('float32'), 57 | np.dtype('complex128'): np.dtype('complex64'), 58 | } 59 | 60 | 61 | _allow_float64 = True 62 | 63 | 64 | def is_allow_float64(): 65 | return _allow_float64 66 | 67 | 68 | def set_allow_float64(b): 69 | global _allow_float64 70 | _allow_float64 = b 71 | 72 | 73 | def canonicalize_dtype(dtype): 74 | if not is_allow_float64(): 75 | return _to_float32.get(dtype, dtype) 76 | else: 77 | return dtype 78 | 79 | 80 | def _result_type(*arrays_and_dtypes): 81 | dtype = np.result_type(*arrays_and_dtypes) 82 | return canonicalize_dtype(dtype) 83 | 84 | 85 | def default_float_type(): 86 | """Gets the default float type. 87 | 88 | Returns: 89 | If `is_allow_float64()` is true, returns float64; otherwise returns float32. 90 | """ 91 | if is_allow_float64(): 92 | return float64 93 | else: 94 | return float32 95 | -------------------------------------------------------------------------------- /trax/tf_numpy/numpy_impl/random.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Random functions.""" 17 | import numpy as np 18 | import tensorflow.compat.v2 as tf 19 | 20 | from trax.tf_numpy.numpy_impl import utils 21 | 22 | 23 | DEFAULT_RANDN_DTYPE = np.float32 24 | 25 | 26 | def randn(*args): 27 | """Returns samples from a normal distribution. 28 | 29 | Uses `tf.random.normal`. 30 | 31 | Args: 32 | *args: The shape of the output array. 33 | 34 | Returns: 35 | An ndarray with shape `args` and dtype `DEFAULT_RANDN_DTYPE` (float32 unless overridden).
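Example (a minimal sketch; values are random): `x = randn(2, 3)` yields samples with shape `(2, 3)` and dtype `DEFAULT_RANDN_DTYPE`.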
36 | """ 37 | # TODO(wangpeng): Use new stateful RNG 38 | if utils.isscalar(args): 39 | args = (args,) 40 | return utils.tensor_to_ndarray( 41 | tf.random.normal(args, dtype=DEFAULT_RANDN_DTYPE)) 42 | 43 | 44 | def seed(s): 45 | """Sets the seed for the random number generator. 46 | 47 | Uses `tf.random.set_seed`. 48 | 49 | Args: 50 | s: an integer. 51 | """ 52 | # TODO(wangpeng): make the signature the same as numpy 53 | tf.random.set_seed(s) 54 | -------------------------------------------------------------------------------- /trax/tf_numpy/numpy_impl/tests/backprop_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for backpropagation on tf-numpy functions.""" 17 | import tensorflow.compat.v2 as tf 18 | 19 | from trax.tf_numpy.numpy_impl import array_ops 20 | # Required for operator overloads 21 | from trax.tf_numpy.numpy_impl import math_ops  # pylint: disable=unused-import 22 | 23 | 24 | class BackpropTest(tf.test.TestCase): 25 | 26 | def test_setitem(self): 27 | # Single integer index. 28 | a = array_ops.array([1., 2., 3.]) 29 | b = array_ops.array(5.) 30 | c = array_ops.array(10.) 31 | 32 | tensors = [arr.data for arr in [a, b, c]] 33 | with tf.GradientTape() as g: 34 | g.watch(tensors) 35 | a[1] = b + c 36 | loss = array_ops.sum(a) 37 | 38 | gradients = g.gradient(loss.data, tensors) 39 | self.assertSequenceEqual( 40 | array_ops.array(gradients[0]).tolist(), [1., 0., 1.]) 41 | self.assertEqual(array_ops.array(gradients[1]).tolist(), 1.) 42 | self.assertEqual(array_ops.array(gradients[2]).tolist(), 1.) 43 | 44 | # Tuple index. 45 | a = array_ops.array([[[1., 2.], [3., 4.]], [[5., 6.], 46 | [7., 8.]]])  # 2x2x2 array. 47 | b = array_ops.array([10., 11.]) 48 | 49 | tensors = [arr.data for arr in [a, b]] 50 | with tf.GradientTape() as g: 51 | g.watch(tensors) 52 | a[(1, 0)] = b 53 | loss = array_ops.sum(a) 54 | 55 | gradients = g.gradient(loss.data, tensors) 56 | self.assertSequenceEqual( 57 | array_ops.array(gradients[0]).tolist(), 58 | [[[1., 1.], [1., 1.]], [[0., 0.], [1., 1.]]]) 59 | self.assertEqual(array_ops.array(gradients[1]).tolist(), [1., 1.]) 60 | 61 | 62 | if __name__ == '__main__': 63 | tf.compat.v1.enable_eager_execution() 64 | tf.test.main() 65 | -------------------------------------------------------------------------------- /trax/tf_numpy/numpy_impl/tests/utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for utils.py.""" 17 | import tensorflow.compat.v2 as tf 18 | 19 | from trax.tf_numpy.numpy_impl import utils 20 | 21 | 22 | class UtilsTest(tf.test.TestCase): 23 | 24 | # pylint: disable=unused-argument 25 | def testNpDoc(self): 26 | def np_fun(x): 27 | """np_fun docstring.""" 28 | return 29 | @utils.np_doc(np_fun) 30 | def f(): 31 | """f docstring.""" 32 | return 33 | expected = """TensorFlow variant of `numpy.np_fun`. 34 | 35 | Unsupported arguments: `x`. 36 | 37 | f docstring. 38 | 39 | Documentation for `numpy.np_fun`: 40 | 41 | np_fun docstring.""" 42 | self.assertEqual(f.__doc__, expected) 43 | 44 | 45 | if __name__ == '__main__': 46 | tf.enable_v2_behavior() 47 | tf.test.main() 48 | -------------------------------------------------------------------------------- /trax/tf_numpy/public_symbol_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Trax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests different ways to use the public tf-numpy module.""" 17 | import numpy as onp 18 | 19 | import tensorflow as tf 20 | import tensorflow.experimental.numpy as np1 21 | from tensorflow.experimental import numpy as np2 # pylint: disable=reimported 22 | 23 | 24 | np3 = tf.experimental.numpy 25 | 26 | 27 | class PublicSymbolTest(tf.test.TestCase): 28 | 29 | def testSimple(self): 30 | a = 0.1 31 | b = 0.2 32 | for op in [np1.add, np2.add, np3.add]: 33 | self.assertAllClose(onp.add(a, b), op(a, b)) 34 | 35 | 36 | if __name__ == "__main__": 37 | tf.compat.v1.enable_eager_execution() 38 | tf.test.main() 39 | --------------------------------------------------------------------------------
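A minimal, hypothetical usage sketch of the tf-numpy public API exercised in public_symbol_test.py above; it assumes TF 2.5+, where `tf.experimental.numpy` ships `experimental_enable_numpy_behavior()`:

import tensorflow as tf
import tensorflow.experimental.numpy as tnp

# Opt in to NumPy-style type promotion and indexing on TF tensors.
tnp.experimental_enable_numpy_behavior()

x = tnp.asarray([[1., 2.], [3., 4.]], dtype=tnp.float32)
y = tnp.add(x, 1.0)         # NumPy-style broadcasting over a TF tensor
z = tnp.sum(y * y, axis=0)  # reductions mirror numpy.sum
print(z)                    # -> [20. 34.]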