├── .coveragerc
├── .github
│   └── workflows
│       └── build.yaml
├── .readthedocs.yaml
├── .travis.yml
├── AUTHORS
├── CONTRIBUTING.md
├── ISSUE_TEMPLATE.md
├── LICENSE
├── README.md
├── docs
│   ├── .readthedocs.yaml
│   ├── Makefile
│   ├── requirements.txt
│   └── source
│       ├── conf.py
│       ├── index.rst
│       ├── notebooks
│       │   ├── layers_intro.ipynb
│       │   ├── tf_numpy_and_keras.ipynb
│       │   └── trax_intro.ipynb
│       ├── trax.data.rst
│       ├── trax.fastmath.rst
│       ├── trax.layers.rst
│       ├── trax.models.rst
│       ├── trax.optimizers.rst
│       ├── trax.rl.rst
│       ├── trax.rst
│       └── trax.supervised.rst
├── oss_scripts
│   ├── oss_pip_install.sh
│   ├── oss_release.sh
│   └── oss_tests.sh
├── pylintrc
├── setup.py
└── trax
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── debug_data_pipeline.py
    │   ├── inputs.py
    │   ├── inputs_test.py
    │   ├── testdata
    │   │   ├── bert_uncased_vocab.txt
    │   │   ├── c4
    │   │   │   └── en
    │   │   │       └── 2.3.0
    │   │   │           ├── c4-train.tfrecord-00000-of-00001
    │   │   │           ├── c4-validation.tfrecord-00000-of-00001
    │   │   │           └── dataset_info.json
    │   │   ├── corpus-1.txt
    │   │   ├── corpus-2.txt
    │   │   ├── en_8k.subword
    │   │   ├── para_crawl
    │   │   │   └── ende
    │   │   │       └── 1.2.0
    │   │   │           ├── dataset_info.json
    │   │   │           ├── features.json
    │   │   │           └── para_crawl-train.tfrecord-00000-of-00001
    │   │   ├── sentencepiece.model
    │   │   ├── squad
    │   │   │   └── v1.1
    │   │   │       └── 3.0.0
    │   │   │           ├── dataset_info.json
    │   │   │           ├── squad-train.tfrecord-00000-of-00001
    │   │   │           └── squad-validation.tfrecord-00000-of-00001
    │   │   ├── vocab-1.txt
    │   │   └── vocab-2.txt
    │   ├── text_encoder.py
    │   ├── text_encoder_build_subword.py
    │   ├── text_encoder_test.py
    │   ├── tf_inputs.py
    │   ├── tf_inputs_test.py
    │   ├── tokenizer.py
    │   └── tokenizer_test.py
    ├── examples
    │   ├── Deep_N_Gram_Models.ipynb
    │   ├── Fashion_MNIST_with_Trax.ipynb
    │   ├── Knowledge_Tracing_Transformer.ipynb
    │   ├── MathQA_Python_generation_notebook.ipynb
    │   ├── NER_using_Reformer.ipynb
    │   ├── NMT_with_Transformers_Reformers_using_Trax.ipynb
    │   ├── README.md
    │   ├── Terraformer_from_scratch.ipynb
    │   ├── earlystopping.ipynb
    │   ├── illustrated_wideresnet.ipynb
    │   ├── semantic_segmentation.ipynb
    │   └── trax_data_Explained.ipynb
    ├── fastmath
    │   ├── __init__.py
    │   ├── jax.py
    │   ├── numpy.py
    │   ├── ops.py
    │   ├── ops_test.py
    │   └── tf.py
    ├── import_test.py
    ├── intro.ipynb
    ├── jaxboard.py
    ├── layers
    │   ├── README.md
    │   ├── __init__.py
    │   ├── acceleration.py
    │   ├── acceleration_test.py
    │   ├── activation_fns.py
    │   ├── activation_fns_test.py
    │   ├── assert_shape.py
    │   ├── assert_shape_test.py
    │   ├── attention.py
    │   ├── attention_test.py
    │   ├── base.py
    │   ├── base_test.py
    │   ├── combinators.py
    │   ├── combinators_test.py
    │   ├── convolution.py
    │   ├── convolution_test.py
    │   ├── core.py
    │   ├── core_test.py
    │   ├── deconvolution.py
    │   ├── deconvolution_test.py
    │   ├── initializers.py
    │   ├── initializers_test.py
    │   ├── intro.ipynb
    │   ├── metrics.py
    │   ├── metrics_test.py
    │   ├── normalization.py
    │   ├── normalization_test.py
    │   ├── pooling.py
    │   ├── pooling_test.py
    │   ├── research
    │   │   ├── __init__.py
    │   │   ├── efficient_attention.py
    │   │   ├── efficient_attention_test.py
    │   │   ├── position_encodings.py
    │   │   ├── position_encodings_test.py
    │   │   ├── rel_attention.py
    │   │   ├── rel_attention_test.py
    │   │   ├── resampling.py
    │   │   ├── rotary_positional_embedding.py
    │   │   ├── rotary_positional_embedding_test.py
    │   │   ├── sparsity.py
    │   │   └── sparsity_test.py
    │   ├── reversible.py
    │   ├── reversible_test.py
    │   ├── rnn.py
    │   ├── rnn_test.py
    │   ├── test_utils.py
    │   └── test_utils_test.py
    ├── models
    │   ├── Attention_Visualization_in_Trax.ipynb
    │   ├── __init__.py
    │   ├── atari_cnn.py
    │   ├── atari_cnn_test.py
    │   ├── mlp.py
    │   ├── mlp_test.py
    │   ├── neural_gpu.py
    │   ├── neural_gpu_test.py
    │   ├── reformer
    │   │   ├── README.md
    │   │   ├── __init__.py
    │   │   ├── image_generation.ipynb
    │   │   ├── machine_translation.ipynb
    │   │   ├── reformer.py
    │   │   ├── reformer_e2e_test.py
    │   │   ├── reformer_test.py
    │   │   ├── testdata
    │   │   │   ├── translate_ende_wmt32k-dev-00000-of-00001
    │   │   │   ├── translate_ende_wmt32k-train-00000-of-00001
    │   │   │   └── vocab.translate_ende_wmt32k.32768.subwords
    │   │   └── text_generation.ipynb
    │   ├── research
    │   │   ├── __init__.py
    │   │   ├── bert.py
    │   │   ├── configurable_transformer.py
    │   │   ├── configurable_transformer_test.py
    │   │   ├── examples
    │   │   │   ├── hourglass_downsampled_imagenet.ipynb
    │   │   │   └── hourglass_enwik8.ipynb
    │   │   ├── hourglass.py
    │   │   ├── hourglass_test.py
    │   │   ├── layerdrop_transformer.py
    │   │   ├── layerdrop_transformer_test.py
    │   │   ├── predict_terraformer.py
    │   │   ├── rezero.py
    │   │   ├── rezero_test.py
    │   │   ├── rse.py
    │   │   ├── rse_test.py
    │   │   ├── terraformer.py
    │   │   ├── terraformer_e2e_test.py
    │   │   ├── terraformer_memory_test.py
    │   │   ├── terraformer_oom_test.py
    │   │   ├── terraformer_test.py
    │   │   ├── testdata
    │   │   │   ├── translate_ende_wmt32k-dev-00000-of-00001
    │   │   │   ├── translate_ende_wmt32k-train-00000-of-00001
    │   │   │   └── vocab.translate_ende_wmt32k.32768.subwords
    │   │   ├── transformer2.py
    │   │   └── transformer2_test.py
    │   ├── resnet.py
    │   ├── resnet_test.py
    │   ├── rl.py
    │   ├── rl_test.py
    │   ├── rnn.py
    │   ├── rnn_test.py
    │   ├── transformer.py
    │   └── transformer_test.py
    ├── optimizers
    │   ├── __init__.py
    │   ├── adafactor.py
    │   ├── adam.py
    │   ├── base.py
    │   ├── momentum.py
    │   ├── optimizers_test.py
    │   ├── rms_prop.py
    │   ├── sm3.py
    │   ├── trainer.py
    │   └── trainer_test.py
    ├── predict_drop.py
    ├── rl
    │   ├── __init__.py
    │   ├── actor_critic.py
    │   ├── actor_critic_joint.py
    │   ├── actor_critic_joint_test.py
    │   ├── actor_critic_test.py
    │   ├── advantages.py
    │   ├── advantages_test.py
    │   ├── atari_test.py
    │   ├── configs
    │   │   ├── dqn_cartpole_regression.gin
    │   │   ├── light_a2c_joint_atari_sweep.yaml
    │   │   ├── light_atari.gin
    │   │   ├── light_atari_sweep.yaml
    │   │   ├── light_awr_cartpole_sweep.yaml
    │   │   ├── light_awr_joint_atari_sweep.yaml
    │   │   ├── light_awr_joint_cartpole.gin
    │   │   ├── light_awr_joint_cartpole_sweep.yaml
    │   │   ├── light_cartpole.gin
    │   │   ├── light_copy.gin
    │   │   ├── light_copy_sweep.yaml
    │   │   ├── light_joint_atari.gin
    │   │   ├── light_joint_cartpole.gin
    │   │   ├── light_lunarlander.gin
    │   │   ├── light_mujoco.gin
    │   │   ├── light_mujoco_regression_test.gin
    │   │   ├── light_mujoco_sweep.yaml
    │   │   ├── light_ppo_atari.gin
    │   │   ├── light_ppo_boxing_regression_test.gin
    │   │   ├── light_ppo_cartpole_regression_test.gin
    │   │   ├── light_ppo_half_cheetah_regression_test.gin
    │   │   ├── light_ppo_joint_atari.gin
    │   │   ├── light_ppo_joint_atari_sweep.yaml
    │   │   ├── light_ppo_lunar_lander_regression_test.gin
    │   │   ├── light_ppo_pong_regression_test.gin
    │   │   ├── ppo_atari_sweep.yaml
    │   │   ├── ppo_cartpole_sweep.yaml
    │   │   └── transformer_srl_sine.gin
    │   ├── distributions.py
    │   ├── distributions_test.py
    │   ├── envs
    │   │   ├── __init__.py
    │   │   ├── data_envs.py
    │   │   └── data_envs_test.py
    │   ├── normalization.py
    │   ├── normalization_test.py
    │   ├── policy_tasks.py
    │   ├── rl_layers.py
    │   ├── serialization_utils.py
    │   ├── serialization_utils_test.py
    │   ├── space_serializer.py
    │   ├── space_serializer_test.py
    │   ├── task.py
    │   ├── task_test.py
    │   ├── training.py
    │   ├── training_test.py
    │   ├── value_tasks.py
    │   └── value_tasks_test.py
    ├── rl_trainer.py
    ├── shapes.py
    ├── shapes_test.py
    ├── supervised
    │   ├── __init__.py
    │   ├── callbacks.py
    │   ├── callbacks_test.py
    │   ├── configs
    │   │   ├── bert.gin
    │   │   ├── bert_glue_classification.gin
    │   │   ├── bert_glue_regression.gin
    │   │   ├── bert_glue_sweep_regression_task.yaml
    │   │   ├── bert_glue_sweep_single_sentence.yaml
    │   │   ├── bert_glue_sweep_two_sentences.yaml
    │   │   ├── bert_pretraining.gin
    │   │   ├── bert_pretraining_onlymlm.gin
    │   │   ├── bert_pretraining_onlynsp.gin
    │   │   ├── c4.gin
    │   │   ├── c4_pretrain_16gb_adafactor.gin
    │   │   ├── c4_trax_data.gin
    │   │   ├── cond_skipping_transformer_lm1b.gin
    │   │   ├── gru_copy.gin
    │   │   ├── hourglass_cifar10.gin
    │   │   ├── hourglass_enwik8.gin
    │   │   ├── hourglass_imagenet32.gin
    │   │   ├── hourglass_imagenet64.gin
    │   │   ├── hourglass_wiki40b.gin
    │   │   ├── layerdrop_every_transformer_lm1b.gin
    │   │   ├── layerdrop_transformer_lm1b.gin
    │   │   ├── layerdrop_ushape_transformer_lm1b.gin
    │   │   ├── lstm_lm1b.gin
    │   │   ├── lstm_seq2seq_wmt_ende.gin
    │   │   ├── mlp_mnist.gin
    │   │   ├── reformer_addition.gin
    │   │   ├── reformer_bair_robot_pushing.gin
    │   │   ├── reformer_c4.gin
    │   │   ├── reformer_cifar10.gin
    │   │   ├── reformer_copy.gin
    │   │   ├── reformer_enwik8.gin
    │   │   ├── reformer_imagenet64.gin
    │   │   ├── reformer_imagenet64_testing.gin
    │   │   ├── reformer_pc_enpl.gin
    │   │   ├── reformer_wmt_ende.gin
    │   │   ├── reformer_wmt_ende_big.gin
    │   │   ├── resnet50_frn_imagenet_8gb.gin
    │   │   ├── resnet50_imagenet_8gb_testing.gin
    │   │   ├── rezero_wmt_ende_16gb_adafactor_testing.gin
    │   │   ├── rse_addition.gin
    │   │   ├── rse_addition_sweep.yaml
    │   │   ├── scientific_papers_reformer_lm.gin
    │   │   ├── scientific_papers_terraformer.gin
    │   │   ├── scientific_papers_terraformer_favor.gin
    │   │   ├── scientific_papers_terraformer_pretrained.gin
    │   │   ├── skipping_transformer_lm1b.gin
    │   │   ├── sp_sweep.yaml
    │   │   ├── sparse_c4_pretrain_16gb_adafactor.gin
    │   │   ├── sparse_lm1b_pretrain_16gb.gin
    │   │   ├── t5_aqua_parallel.gin
    │   │   ├── t5_drop.gin
    │   │   ├── t5_glue_classification.gin
    │   │   ├── t5_glue_classification_mnli.gin
    │   │   ├── t5_glue_classification_parallel.gin
    │   │   ├── t5_glue_classification_two_constants.gin
    │   │   ├── t5_mathqa.gin
    │   │   ├── t5_mathqa_drop_loop.gin
    │   │   ├── t5_mathqa_drop_sweep.yaml
    │   │   ├── t5_mathqa_multi.gin
    │   │   ├── t5_mathqa_parallel.gin
    │   │   ├── t5_mathqa_parallel_full.gin
    │   │   ├── t5_mathqa_parallel_full_correct_order.gin
    │   │   ├── t5_mathqa_parallel_full_order.gin
    │   │   ├── t5_mathqa_parallel_with_drop_annot.gin
    │   │   ├── t5_sweep.yaml
    │   │   ├── t5_sweep_temperature.yaml
    │   │   ├── terraformer_c4_medium.gin
    │   │   ├── terraformer_copy.gin
    │   │   ├── terraformer_copy_self_attn.gin
    │   │   ├── terraformer_purelsh_copy.gin
    │   │   ├── terraformer_wmt_ende.gin
    │   │   ├── transformer_big_lm1b_8gb.gin
    │   │   ├── transformer_finetune_squad_16gb.gin
    │   │   ├── transformer_imdb_8gb.gin
    │   │   ├── transformer_imdb_tfds.gin
    │   │   ├── transformer_lm1b_8gb_testing.gin
    │   │   ├── transformer_lm_cnndailymail.gin
    │   │   ├── transformer_lm_wmt_ende_16gb.gin
    │   │   ├── transformer_lm_wmt_ende_8gb.gin
    │   │   ├── transformer_ptb_16gb.gin
    │   │   ├── transformer_wmt_ende_16gb_adafactor_testing.gin
    │   │   ├── transformer_wmt_ende_8gb.gin
    │   │   └── wide_resnet_cifar10_8gb.gin
    │   ├── decoding.py
    │   ├── decoding_test.py
    │   ├── decoding_timing_test.py
    │   ├── history.py
    │   ├── history_test.py
    │   ├── lr_schedules.py
    │   ├── lr_schedules_test.py
    │   ├── mnist_test.py
    │   ├── pretrain_finetune.py
    │   ├── testdata
    │   │   ├── reformerlm_copy_lsh_attn.pkl.gz
    │   │   ├── terraformer_copy_lsh_attn.pkl.gz
    │   │   ├── terraformer_copy_self_attn.pkl.gz
    │   │   ├── terraformer_purelsh_copy.pkl.gz
    │   │   ├── transformer_copy.pkl.gz
    │   │   └── transformerlm_copy.pkl.gz
    │   ├── trainer_lib.py
    │   ├── trainer_lib_test.py
    │   ├── training.py
    │   └── training_test.py
    ├── test_utils.py
    ├── tf_numpy
    │   ├── __init__.py
    │   ├── examples
    │   │   └── mnist
    │   │       ├── dataset.py
    │   │       ├── model.py
    │   │       ├── train.py
    │   │       └── train_test.py
    │   ├── extensions
    │   │   ├── __init__.py
    │   │   ├── extensions.py
    │   │   └── extensions_test.py
    │   ├── jax_tests
    │   │   ├── config.py
    │   │   ├── lax_numpy_einsum_test.py
    │   │   ├── lax_numpy_indexing_test.py
    │   │   ├── lax_numpy_test.py
    │   │   ├── test_util.py
    │   │   └── vmap_test.py
    │   ├── numpy
    │   │   └── __init__.py
    │   ├── numpy_impl
    │   │   ├── __init__.py
    │   │   ├── array_ops.py
    │   │   ├── arrays.py
    │   │   ├── dtypes.py
    │   │   ├── math_ops.py
    │   │   ├── random.py
    │   │   ├── tests
    │   │   │   ├── array_ops_test.py
    │   │   │   ├── arrays_test.py
    │   │   │   ├── backprop_test.py
    │   │   │   ├── logic_test.py
    │   │   │   ├── math_ops_test.py
    │   │   │   ├── random_test.py
    │   │   │   └── utils_test.py
    │   │   └── utils.py
    │   └── public_symbol_test.py
    ├── tf_numpy_and_keras.ipynb
    ├── trainer.py
    ├── trainer_flags.py
    ├── trax2keras.py
    └── trax2keras_test.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | source =
3 |     trax/
4 | omit =
5 |     *_test.py
6 |     */site-packages/*
7 | 
8 | [report]
9 | omit =
10 |     */site-packages/*
11 | exclude_lines =
12 |     pragma: no cover
13 |     def __repr__
14 |     raise NotImplementedError
15 |     if __name__ == .__main__.:
16 |
17 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # .readthedocs.yml
16 | # Read the Docs configuration file
17 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
18 |
19 | # Required
20 | version: 2
21 |
22 | # Build documentation in the docs/source directory with Sphinx.
23 | sphinx:
24 |   configuration: docs/source/conf.py
25 | 
26 | # Build docs in additional formats (PDF, ePub).
27 | formats: all
28 | 
29 | # Optionally set the version of Python and requirements required to build your docs
30 | python:
31 |   version: 3.7
32 |   install:
33 |     - requirements: docs/requirements.txt
34 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | sudo: required
2 | language: python
3 | cache: pip
4 | git:
5 |   depth: 3
6 |   quiet: true
7 | python:
8 |   - "3.6"
9 | env:
10 |   global:
11 |     - TF_VERSION="2.4.*"
12 |   matrix:
13 |     - TRAX_TEST="lib"
14 |     - TRAX_TEST="research"
15 | install:
16 |   - ./oss_scripts/oss_pip_install.sh
17 | script:
18 |   - ./oss_scripts/oss_tests.sh
19 |
--------------------------------------------------------------------------------
/AUTHORS:
--------------------------------------------------------------------------------
1 | # This is the list of Trax authors for copyright purposes.
2 | #
3 | # This does not necessarily list everyone who has contributed code, since in
4 | # some cases, their employer may be the copyright holder. To see the full list
5 | # of contributors, see the revision history in source control.
6 |
7 | Google Inc.
8 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | # Issues
4 |
5 | * Please tag your issue with `bug`, `feature request`, or `question` to help us
6 | effectively respond.
7 | * Please include the versions of JAX or Tensorflow you are running.
8 | * Please provide the command line you ran as well as the log output.
9 |
10 | # Pull Requests
11 |
12 | We'd love to accept your patches and contributions to this project. There are
13 | just a few small guidelines you need to follow.
14 |
15 | ## Contributor License Agreement
16 |
17 | Contributions to this project must be accompanied by a Contributor License
18 | Agreement. You (or your employer) retain the copyright to your contribution,
19 | this simply gives us permission to use and redistribute your contributions as
20 | part of the project. Head over to <https://cla.developers.google.com/> to see
21 | your current agreements on file or to sign a new one.
22 |
23 | You generally only need to submit a CLA once, so if you've already submitted one
24 | (even if it was for a different project), you probably don't need to do it
25 | again.
26 |
27 | ## Code reviews
28 |
29 | All submissions, including submissions by project members, require review. We
30 | use GitHub pull requests for this purpose. Consult
31 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
32 | information on using pull requests.
33 |
34 | ## Community Guidelines
35 |
36 | This project follows
37 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
38 |
--------------------------------------------------------------------------------
/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ### Description
2 |
3 | ...
4 |
5 | ### Environment information
6 |
7 | ```
8 | OS:
9 |
10 | $ pip freeze | grep trax
11 | # your output here
12 |
13 | $ pip freeze | grep tensor
14 | # your output here
15 |
16 | $ pip freeze | grep jax
17 | # your output here
18 |
19 | $ python -V
20 | # your output here
21 | ```
22 |
23 | ### For bugs: reproduction and error logs
24 |
25 | ```
26 | # Steps to reproduce:
27 | ...
28 | ```
29 |
30 | ```
31 | # Error logs:
32 | ...
33 | ```
34 |
--------------------------------------------------------------------------------
/docs/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # .readthedocs.yml
16 | # Read the Docs configuration file
17 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
18 |
19 | # Required
20 | version: 2
21 |
22 | # Build documentation in the docs/source directory with Sphinx.
23 | sphinx:
24 |   configuration: docs/source/conf.py
25 | 
26 | # Build docs in additional formats (PDF, ePub).
27 | formats: all
28 | 
29 | # Optionally set the version of Python and requirements required to build your docs
30 | python:
31 |   version: 3.7
32 |   install:
33 |     - requirements: docs/requirements.txt
34 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SOURCEDIR = source
8 | BUILDDIR = build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
20 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | nbsphinx
2 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. Trax documentation master file.
2 |    Should contain at minimum the root `toctree` directive.
3 | 
4 | Trax Tutorials
5 | ==============
6 | 
7 | .. toctree::
8 |    :caption: Introductory Notebooks
9 |    :maxdepth: 2
10 | 
11 |    notebooks/trax_intro
12 |    notebooks/layers_intro
13 |    notebooks/tf_numpy_and_keras
14 | 
15 | 
16 | Trax API
17 | ========
18 | 
19 | .. toctree::
20 |    :caption: Packages/modules
21 |    :maxdepth: 2
22 | 
23 |    trax.*
24 |
25 |
26 | Indices and Tables
27 | ==================
28 |
29 | * :ref:`genindex`
30 | * :ref:`modindex`
31 | * :ref:`search`
32 |
--------------------------------------------------------------------------------
/docs/source/trax.data.rst:
--------------------------------------------------------------------------------
1 | trax.data
2 | =========
3 |
4 | inputs
5 | ------
6 |
7 | .. automodule:: trax.data.inputs
8 |
9 | tf\_inputs
10 | ----------
11 |
12 | .. automodule:: trax.data.tf_inputs
13 |
--------------------------------------------------------------------------------
/docs/source/trax.fastmath.rst:
--------------------------------------------------------------------------------
1 | trax.fastmath
2 | =============
3 |
4 | ops
5 | ---
6 |
7 | .. automodule:: trax.fastmath.ops
8 |
--------------------------------------------------------------------------------
/docs/source/trax.layers.rst:
--------------------------------------------------------------------------------
1 | trax.layers
2 | ===========
3 |
4 | acceleration
5 | ------------
6 |
7 | .. automodule:: trax.layers.acceleration
8 |
9 | activation\_fns
10 | ---------------
11 |
12 | .. automodule:: trax.layers.activation_fns
13 |
14 | attention
15 | ---------
16 |
17 | .. automodule:: trax.layers.attention
18 |
19 | base
20 | ----
21 |
22 | .. automodule:: trax.layers.base
23 |
24 | combinators
25 | -----------
26 |
27 | .. automodule:: trax.layers.combinators
28 |
29 | convolution
30 | -----------
31 |
32 | .. automodule:: trax.layers.convolution
33 |
34 | core
35 | ----
36 |
37 | .. automodule:: trax.layers.core
38 |
39 | initializers
40 | ------------
41 |
42 | .. automodule:: trax.layers.initializers
43 |
44 | metrics
45 | -------
46 |
47 | .. automodule:: trax.layers.metrics
48 |
49 | normalization
50 | -------------
51 |
52 | .. automodule:: trax.layers.normalization
53 |
54 | pooling
55 | -------
56 |
57 | .. automodule:: trax.layers.pooling
58 |
59 | reversible
60 | ----------
61 |
62 | .. automodule:: trax.layers.reversible
63 |
64 | rnn
65 | ---
66 |
67 | .. automodule:: trax.layers.rnn
68 |
69 | research.efficient\_attention
70 | -----------------------------
71 |
72 | .. automodule:: trax.layers.research.efficient_attention
73 |
74 | research.position\_encodings
75 | ----------------------------
76 |
77 | .. automodule:: trax.layers.research.position_encodings
78 |
--------------------------------------------------------------------------------
/docs/source/trax.models.rst:
--------------------------------------------------------------------------------
1 | trax.models
2 | ===========
3 |
4 | atari\_cnn
5 | ----------
6 |
7 | .. automodule:: trax.models.atari_cnn
8 |
9 | mlp
10 | ---
11 |
12 | .. automodule:: trax.models.mlp
13 |
14 | neural\_gpu
15 | -----------
16 |
17 | .. automodule:: trax.models.neural_gpu
18 |
19 | resnet
20 | ------
21 |
22 | .. automodule:: trax.models.resnet
23 |
24 | rl
25 | --
26 |
27 | .. automodule:: trax.models.rl
28 |
29 | rnn
30 | ---
31 |
32 | .. automodule:: trax.models.rnn
33 |
34 | transformer
35 | -----------
36 |
37 | .. automodule:: trax.models.transformer
38 |
39 | reformer.reformer
40 | -----------------
41 |
42 | .. automodule:: trax.models.reformer.reformer
43 |
44 | research.bert
45 | -------------
46 |
47 | .. automodule:: trax.models.research.bert
48 |
49 | research.skipping\_transformer
50 | ------------------------------
51 |
52 | .. automodule:: trax.models.research.skipping_transformer
53 |
--------------------------------------------------------------------------------
/docs/source/trax.optimizers.rst:
--------------------------------------------------------------------------------
1 | trax.optimizers
2 | ===============
3 |
4 | adafactor
5 | ---------
6 |
7 | .. automodule:: trax.optimizers.adafactor
8 |
9 | adam
10 | ----
11 |
12 | .. automodule:: trax.optimizers.adam
13 |
14 | base
15 | ----
16 |
17 | .. automodule:: trax.optimizers.base
18 |
19 | momentum
20 | --------
21 |
22 | .. automodule:: trax.optimizers.momentum
23 |
24 | rms\_prop
25 | ---------
26 |
27 | .. automodule:: trax.optimizers.rms_prop
28 |
29 | sm3
30 | ---
31 |
32 | .. automodule:: trax.optimizers.sm3
33 |
--------------------------------------------------------------------------------
/docs/source/trax.rl.rst:
--------------------------------------------------------------------------------
1 | trax.rl package
2 | ===============
3 |
4 | actor\_critic
5 | -------------
6 |
7 | .. automodule:: trax.rl.actor_critic
8 |
9 | actor\_critic\_joint
10 | --------------------
11 |
12 | .. automodule:: trax.rl.actor_critic_joint
13 |
14 | advantages
15 | ----------
16 |
17 | .. automodule:: trax.rl.advantages
18 |
19 | distributions
20 | -------------
21 |
22 | .. automodule:: trax.rl.distributions
23 |
24 | normalization
25 | -------------
26 |
27 | .. automodule:: trax.rl.normalization
28 |
29 | rl\_layers
30 | ----------
31 |
32 | .. automodule:: trax.rl.rl_layers
33 |
34 | serialization\_utils
35 | --------------------
36 |
37 | .. automodule:: trax.rl.serialization_utils
38 |
39 | space\_serializer
40 | -----------------
41 |
42 | .. automodule:: trax.rl.space_serializer
43 |
44 | task
45 | ----
46 |
47 | .. automodule:: trax.rl.task
48 |
49 | training
50 | --------
51 |
52 | .. automodule:: trax.rl.training
53 |
--------------------------------------------------------------------------------
/docs/source/trax.rst:
--------------------------------------------------------------------------------
1 | trax
2 | ====
3 |
4 | .. toctree::
5 |    fastmath.*
6 |    layers.*
7 |    models.*
8 |    data.*
9 |    optimizers.*
10 |    supervised.*
11 |    rl.*
12 |
13 |
14 | shapes
15 | ------
16 |
17 | .. automodule:: trax.shapes
18 |
19 | trainer
20 | -------
21 |
22 | .. automodule:: trax.trainer
23 |
24 | rl\_trainer
25 | -----------
26 |
27 | .. automodule:: trax.rl_trainer
28 |
29 |
30 | trax2keras
31 | ----------
32 |
33 | .. automodule:: trax.trax2keras
34 |
--------------------------------------------------------------------------------
/docs/source/trax.supervised.rst:
--------------------------------------------------------------------------------
1 | trax.supervised
2 | ===============
3 |
4 | decoding
5 | --------
6 |
7 | .. automodule:: trax.supervised.decoding
8 |
9 | lr\_schedules
10 | -------------
11 |
12 | .. automodule:: trax.supervised.lr_schedules
13 |
14 | training
15 | --------
16 |
17 | .. automodule:: trax.supervised.training
18 |
--------------------------------------------------------------------------------
/oss_scripts/oss_pip_install.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | #!/bin/bash
16 |
17 | set -v # print commands as they're executed
18 | set -e # fail and exit on any command erroring
19 |
20 | : "${TF_VERSION:?}"
21 |
22 | # Make sure we have the latest pip and setuptools installed.
23 | pip install -q -U pip
24 | pip install -q -U setuptools
25 |
26 | # Make sure we have the latest version of numpy - avoid problems we were
27 | # seeing with Python 3
28 | pip install -q -U numpy
29 |
30 | # Install the appropriate version of tensorflow.
31 | if [[ "$TF_VERSION" == "tf-nightly" ]]
32 | then
33 |   pip install tf-nightly
34 | else
35 |   pip install -q "tensorflow==$TF_VERSION"
36 | fi
37 |
38 | # Just print the version again to make sure.
39 | python -c 'import tensorflow as tf; print(tf.__version__)'
40 |
41 | # First ensure that the base dependencies are sufficient for a full import
42 | pip install -q -e .
43 |
44 | # Then install the test dependencies
45 | pip install -q -e .[tests]
46 | # Make sure to install the atari extras for gym
47 | pip install "gym[atari]"
48 |
49 | # Coverage.
50 | pip install coverage coveralls
51 |
--------------------------------------------------------------------------------
/oss_scripts/oss_release.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | #!/bin/bash
16 |
17 | set -v # print commands as they're executed
18 | set -e # fail and exit on any command erroring
19 |
20 | GIT_COMMIT_ID=${1:-""}
21 | [[ -z $GIT_COMMIT_ID ]] && echo "Must provide a commit" && exit 1
22 |
23 | TMP_DIR=$(mktemp -d)
24 | pushd $TMP_DIR
25 |
26 | echo "Cloning trax and checking out commit $GIT_COMMIT_ID"
27 | git clone https://github.com/google/trax.git
28 | cd trax
29 | git checkout $GIT_COMMIT_ID
30 |
31 | python3 -m pip install wheel twine pyopenssl
32 |
33 | # Build the distribution
34 | echo "Building distribution"
35 | python3 setup.py sdist
36 | python3 setup.py bdist_wheel --universal
37 |
38 | # Publish to PyPI
39 | echo "Publishing to PyPI"
40 | python3 -m twine upload dist/*
41 |
42 | # Cleanup
43 | rm -rf build/ dist/ trax.egg-info/
44 | popd
45 | rm -rf $TMP_DIR
46 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # coding=utf-8
17 | """Install trax."""
18 |
19 | from setuptools import find_packages
20 | from setuptools import setup
21 |
22 | setup(
23 |     name='trax',
24 |     version='1.4.1',
25 |     description='Trax',
26 |     long_description=(
27 |         'Trax helps you understand deep learning. We start with basic maths and'
28 |         ' go through layers, models, supervised and reinforcement learning. We '
29 |         'get to advanced deep learning results, including recent papers and '
30 |         'state-of-the-art models.'
31 |     ),
32 |     author='Google Inc.',
33 |     author_email='no-reply@google.com',
34 |     url='http://github.com/google/trax',
35 |     license='Apache 2.0',
36 |     packages=find_packages(),
37 |     install_requires=[
38 |         'absl-py',
39 |         'funcsigs',
40 |         'gin-config',
41 |         'gym',
42 |         'jax',
43 |         'jaxlib',
44 |         'matplotlib',
45 |         'numpy',
46 |         'psutil',
47 |         'scipy',
48 |         'six',
49 |         'tensorflow-datasets',
50 |         'tensorflow-text',
51 |     ],
52 |     extras_require={
53 |         'tensorflow': ['tensorflow>=1.15.0'],
54 |         'tensorflow_gpu': ['tensorflow-gpu>=1.15.0'],
55 |         't5': ['t5>=0.4.0'],
56 |         'tests': [
57 |             'attrs',
58 |             'jupyter',
59 |             'mock',
60 |             'parameterized',
61 |             'pylint',
62 |             'pytest',
63 |             'wrapt==1.11.*',
64 |         ],
65 |         't2t': ['tensor2tensor',],
66 |     },
67 |     classifiers=[
68 |         'Development Status :: 4 - Beta',
69 |         'Intended Audience :: Developers',
70 |         'Intended Audience :: Science/Research',
71 |         'License :: OSI Approved :: Apache Software License',
72 |         'Topic :: Scientific/Engineering :: Artificial Intelligence',
73 |     ],
74 |     keywords='tensorflow machine learning jax',
75 | )
76 |
--------------------------------------------------------------------------------
/trax/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Trax top level import."""
17 |
18 | from trax import data
19 | from trax import fastmath
20 | from trax import layers
21 | from trax import models
22 | from trax import optimizers
23 | from trax import shapes
24 | from trax import supervised
25 | from trax.supervised import lr_schedules as lr
26 | from trax.trax2keras import AsKeras
27 |
--------------------------------------------------------------------------------
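A quick orientation sketch — not one of the repository files — showing how these top-level re-exports are typically used; the tiny `TransformerLM` configuration is arbitrary, chosen only so the snippet runs fast:

```python
# Hedged sketch (assumes trax and its dependencies are installed).
import numpy as np
import trax

model = trax.models.TransformerLM(
    vocab_size=32, d_model=32, d_ff=64, n_layers=1, n_heads=2)
tokens = np.zeros((1, 8), dtype=np.int32)   # [batch, sequence]
model.init(trax.shapes.signature(tokens))   # initialize weights and state
log_probs = model(tokens)                   # shape (1, 8, 32)
```

--------------------------------------------------------------------------------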
/trax/data/debug_data_pipeline.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A debugging decorator for TRAX input pipeline."""
17 |
18 | import functools
19 |
20 | from absl import logging
21 | import gin
22 |
23 |
24 | @gin.configurable(denylist=['f'])
25 | def debug_pipeline(f, debug=False, method='pow', log_prefix=None):
26 | """Decorator for input pipeline generators that logs examples at intervals."""
27 |   if not debug:
28 |     return f
29 | 
30 |   assert method in ('pow', 'every')
31 |   @functools.wraps(f)
32 |   def wrapper(*args, **kwargs):
33 |     count = 0
34 |     prefix = log_prefix or f.__name__
35 |     for example in f(*args, **kwargs):
36 |       count += 1
37 |       if method == 'every' or (method == 'pow' and (count & count - 1 == 0)):
38 |         logging.info('%s example[%d] = %r', prefix, count, example)
39 |       yield example
40 | 
41 |   return wrapper
42 |
--------------------------------------------------------------------------------
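A hedged usage sketch of the decorator above (`count_up` is a made-up pipeline stage, not part of the repository). With `method='pow'`, the test `(count & count - 1 == 0)` parses as `(count & (count - 1)) == 0`, which is true exactly when `count` is a power of two, so logging becomes sparser as the stream grows:

```python
# Hedged sketch: wrap a toy generator so examples 1, 2, 4, 8, ... get logged.
from trax.data.debug_data_pipeline import debug_pipeline

def count_up():  # hypothetical input pipeline stage
  for i in range(100):
    yield i

logged = debug_pipeline(count_up, debug=True, method='pow')
for example in logged():  # examples are logged via absl.logging
  pass
```

--------------------------------------------------------------------------------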
/trax/data/testdata/c4/en/2.3.0/c4-train.tfrecord-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/data/testdata/c4/en/2.3.0/c4-train.tfrecord-00000-of-00001
--------------------------------------------------------------------------------
/trax/data/testdata/c4/en/2.3.0/c4-validation.tfrecord-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/data/testdata/c4/en/2.3.0/c4-validation.tfrecord-00000-of-00001
--------------------------------------------------------------------------------
/trax/data/testdata/c4/en/2.3.0/dataset_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "citation": "\n@article{2019t5,\n author = {Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu},\n title = {Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer},\n journal = {arXiv e-prints},\n year = {2019},\n archivePrefix = {arXiv},\n eprint = {1910.10683},\n}\n",
3 | "description": "A colossal, cleaned version of Common Crawl's web crawl corpus.\n\nBased on Common Crawl dataset: \"https://commoncrawl.org\"\n\nDue to the overhead of cleaning the dataset, it is recommend you prepare it with\na distributed service like Cloud Dataflow. More info at\nhttps://www.tensorflow.org/datasets/beam_datasets.\n",
4 | "downloadSize": "7650623283551",
5 | "location": {
6 | "urls": [
7 | "https://github.com/google-research/text-to-text-transfer-transformer#datasets"
8 | ]
9 | },
10 | "name": "c4",
11 | "splits": [
12 | {
13 | "name": "train",
14 | "numBytes": "10603",
15 | "numShards": "1",
16 | "shardLengths": [
17 | "10"
18 | ]
19 | },
20 | {
21 | "name": "validation",
22 | "numBytes": "16851",
23 | "numShards": "1",
24 | "shardLengths": [
25 | "10"
26 | ]
27 | }
28 | ],
29 | "version": "2.3.0"
30 | }
31 |
--------------------------------------------------------------------------------
/trax/data/testdata/corpus-1.txt:
--------------------------------------------------------------------------------
1 | One morning I shot an elephant in my pajamas. How he got in my pajamas, I don't
2 | know.
3 |
4 | Groucho Marx
5 |
--------------------------------------------------------------------------------
/trax/data/testdata/corpus-2.txt:
--------------------------------------------------------------------------------
1 | I haven't slept for 10 days... because that would be too long.
2 |
3 | Mitch Hedberg
4 |
--------------------------------------------------------------------------------
/trax/data/testdata/para_crawl/ende/1.2.0/dataset_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "citation": "@misc {paracrawl,\n title = \"ParaCrawl\",\n year = \"2018\",\n url = \"http://paracrawl.eu/download.html.\"\n}",
3 | "configDescription": "Translation dataset from English to de.",
4 | "configName": "ende",
5 | "description": "Web-Scale Parallel Corpora for Official European Languages.",
6 | "downloadSize": "1307754745",
7 | "location": {
8 | "urls": [
9 | "https://paracrawl.eu/releases.html"
10 | ]
11 | },
12 | "name": "para_crawl",
13 | "splits": [
14 | {
15 | "name": "train",
16 | "numBytes": "3241",
17 | "shardLengths": [
18 | "10"
19 | ]
20 | }
21 | ],
22 | "supervisedKeys": {
23 | "input": "en",
24 | "output": "de"
25 | },
26 | "version": "1.2.0"
27 | }
28 |
--------------------------------------------------------------------------------
/trax/data/testdata/para_crawl/ende/1.2.0/features.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": "tensorflow_datasets.core.features.translation_feature.Translation",
3 | "content": {
4 | "languages": [
5 | "de",
6 | "en"
7 | ]
8 | }
9 | }
--------------------------------------------------------------------------------
/trax/data/testdata/para_crawl/ende/1.2.0/para_crawl-train.tfrecord-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/data/testdata/para_crawl/ende/1.2.0/para_crawl-train.tfrecord-00000-of-00001
--------------------------------------------------------------------------------
/trax/data/testdata/sentencepiece.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/data/testdata/sentencepiece.model
--------------------------------------------------------------------------------
/trax/data/testdata/squad/v1.1/3.0.0/dataset_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "citation": "@article{2016arXiv160605250R,\n author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},\n Konstantin and {Liang}, Percy},\n title = \"{SQuAD: 100,000+ Questions for Machine Comprehension of Text}\",\n journal = {arXiv e-prints},\n year = 2016,\n eid = {arXiv:1606.05250},\n pages = {arXiv:1606.05250},\narchivePrefix = {arXiv},\n eprint = {1606.05250},\n}\n",
3 | "description": "Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.\n",
4 | "location": {
5 | "urls": [
6 | "https://rajpurkar.github.io/SQuAD-explorer/"
7 | ]
8 | },
9 | "name": "squad",
10 | "schema": {
11 | "feature": [
12 | {
13 | "name": "answers"
14 | },
15 | {
16 | "name": "context",
17 | "type": "BYTES"
18 | },
19 | {
20 | "name": "id",
21 | "type": "BYTES"
22 | },
23 | {
24 | "name": "question",
25 | "type": "BYTES"
26 | },
27 | {
28 | "name": "title",
29 | "type": "BYTES"
30 | }
31 | ]
32 | },
33 | "sizeInBytes": "35142551",
34 | "splits": [
35 | {
36 | "name": "train",
37 | "numShards": "1",
38 | "shardLengths": [
39 | "10"
40 | ]
41 | },
42 | {
43 | "name": "validation",
44 | "numShards": "1",
45 | "shardLengths": [
46 | "10"
47 | ]
48 | }
49 | ],
50 | "version": "3.0.0"
51 | }
52 |
--------------------------------------------------------------------------------
/trax/data/testdata/squad/v1.1/3.0.0/squad-train.tfrecord-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/data/testdata/squad/v1.1/3.0.0/squad-train.tfrecord-00000-of-00001
--------------------------------------------------------------------------------
/trax/data/testdata/squad/v1.1/3.0.0/squad-validation.tfrecord-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/data/testdata/squad/v1.1/3.0.0/squad-validation.tfrecord-00000-of-00001
--------------------------------------------------------------------------------
/trax/data/testdata/vocab-1.txt:
--------------------------------------------------------------------------------
1 | lollipop,8
2 | reverberated,12
3 |
--------------------------------------------------------------------------------
/trax/data/testdata/vocab-2.txt:
--------------------------------------------------------------------------------
1 | kattywampus,11
2 | kaput
3 | balderdash,10
4 | jiggery-pokery,14
5 |
--------------------------------------------------------------------------------
/trax/examples/README.md:
--------------------------------------------------------------------------------
1 | # Trax Examples
2 |
3 |
4 | Here are a few example notebooks:
5 |
6 | * [**trax.data API explained**](https://github.com/google/trax/blob/master/trax/examples/trax_data_Explained.ipynb) : Explains some of the major functions in the `trax.data` API.
7 | * [**Named Entity Recognition using Reformer**](https://github.com/google/trax/blob/master/trax/examples/NER_using_Reformer.ipynb) : Uses a [Kaggle dataset](https://www.kaggle.com/abhinavwalia95/entity-annotated-corpus) for implementing Named Entity Recognition using the [Reformer](https://arxiv.org/abs/2001.04451) architecture.
8 | * [**Deep N-Gram models**](https://github.com/google/trax/blob/master/trax/examples/Deep_N_Gram_Models.ipynb) : Implementation of deep n-gram models trained on Shakespeare's works.
9 | * [**Knowledge Tracing Transformer**](https://github.com/google/trax/blob/master/trax/examples/Knowledge_Tracing_Transformer.ipynb): End-to-end example adapting basic Transformer architecture to accommodate a knowledge tracing task.
10 | * [**Neural Machine Translation with Transformers/Reformers**](https://github.com/google/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb): A guide to Neural Machine Translation using Transformers/Reformers. Includes a detailed tutorial using Trax.
11 |
--------------------------------------------------------------------------------
/trax/fastmath/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Trax fast math -- NumPy-style math on accelerators."""
17 |
18 | from trax.fastmath.numpy import nested_map
19 | from trax.fastmath.numpy import nested_map_multiarg
20 | from trax.fastmath.numpy import nested_stack
21 | from trax.fastmath.numpy import nested_zip
22 | from trax.fastmath.numpy import tree_flatten
23 | from trax.fastmath.numpy import tree_leaves
24 | from trax.fastmath.numpy import tree_unflatten
25 | from trax.fastmath.ops import * # pylint: disable=wildcard-import
26 |
--------------------------------------------------------------------------------
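For orientation — not a repository file — a small sketch of the NumPy-style API this package exposes: backend-dispatched ops from `trax.fastmath.ops` plus the pytree helpers imported above; the numbers are only illustrative:

```python
# Hedged sketch of trax.fastmath usage (assumes the default JAX backend).
from trax import fastmath
from trax.fastmath import numpy as jnp

grad_fn = fastmath.grad(lambda v: jnp.sum(v * v))
print(grad_fn(jnp.arange(4.)))  # gradient of sum(v^2) is 2v: [0., 2., 4., 6.]

# nested_map applies a function to every leaf of a nested structure.
doubled = fastmath.nested_map(lambda t: t * 2, {'a': jnp.ones(2), 'b': [3, 4]})
```

--------------------------------------------------------------------------------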
/trax/import_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for importing Trax."""
17 |
18 | from absl.testing import absltest
19 |
20 |
21 | class ImportTest(absltest.TestCase):
22 | 
23 |   def test_import_trax(self):
24 |     try:
25 |       # Import trax
26 |       import trax  # pylint: disable=g-import-not-at-top
27 |       # Access a few symbols.
28 |       dir(trax.fastmath)
29 |       dir(trax.layers)
30 |       dir(trax.models)
31 |     except ImportError as e:
32 |       raise e
33 | 
34 | 
35 | if __name__ == '__main__':
36 |   absltest.main()
37 |
--------------------------------------------------------------------------------
/trax/layers/README.md:
--------------------------------------------------------------------------------
1 | # Trax Layers
2 |
3 |
4 |
5 | ## Base layer structure
6 |
7 | All layers inherit from the Layer class and generally need to implement 2
8 | methods:
9 |
10 | ```python
11 | def forward(self, inputs):
12 |   """Computes the layer's output as part of a forward pass through the model."""
13 | 
14 | def init_weights_and_state(self, input_signature):
15 |   """Initializes weights and state for inputs with the given signature."""
16 | ```
17 |
18 | The base Layer class wraps these functions and provides initialization
19 | and call functions to be used as follows.
20 |
21 | ```python
22 | layer = MyLayer()
23 | x = np.zeros(10)
24 | layer.init(signature(x))
25 | output = layer(x)
26 | ```
27 |
28 | ## Fn layer
29 |
30 | To create simple layers without parameters, use the Fn layer.
31 |
32 | ```python
33 | def Relu():
34 |   return Fn('Relu', lambda x: np.maximum(x, np.zeros_like(x)))
35 | ```
36 |
37 | ## Parameter sharing
38 |
39 | Parameters are shared when the same layer object is used.
40 |
41 | ```python
42 | standard_mlp = layers.Serial(layers.Dense(10), layers.Dense(10))
43 | layer = layers.Dense(10)
44 | shared_parameters_mlp = layers.Serial(layer, layer)
45 | ```
46 |
47 | ## Core layers
48 |
49 | * Dense
50 | * Conv
51 |
52 | ## Layer composition
53 |
54 | * Serial
55 | * Parallel
56 |
--------------------------------------------------------------------------------
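Tying the README fragments together, a runnable end-to-end sketch (assuming trax is installed; the layer sizes are arbitrary):

```python
# Hedged sketch: build, initialize, and call a small model as described above.
import numpy as np
from trax import layers as tl
from trax import shapes

model = tl.Serial(tl.Dense(16), tl.Relu(), tl.Dense(4))
x = np.ones((2, 8), dtype=np.float32)  # [batch, features]
model.init(shapes.signature(x))        # initialize weights and state
y = model(x)                           # y.shape == (2, 4)
```

--------------------------------------------------------------------------------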
/trax/layers/activation_fns_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for activation function layers."""
17 |
18 | from absl.testing import absltest
19 | import numpy as np
20 |
21 | import trax.layers as tl
22 |
23 |
24 | class ActivationFnsTest(absltest.TestCase):
25 | 
26 |   def test_relu(self):
27 |     layer = tl.Relu()
28 |     x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0])
29 |     y = layer(x)
30 |     self.assertEqual(tl.to_list(y), [0.0, 0.0, 0.0, 2.0, 3.0, 5.0])
31 | 
32 |   def test_parametric_relu(self):
33 |     layer = tl.ParametricRelu(a=.25)
34 |     x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0])
35 |     y = layer(x)
36 |     self.assertEqual(tl.to_list(y), [0.0, 0.0, 0.0, .5, .75, 1.25])
37 | 
38 |   def test_leaky_relu(self):
39 |     layer = tl.LeakyRelu(a=.125)
40 |     x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0])
41 |     y = layer(x)
42 |     self.assertEqual(tl.to_list(y), [-.25, -.125, 0.0, 2.0, 3.0, 5.0])
43 | 
44 |   def test_hard_sigmoid(self):
45 |     layer = tl.HardSigmoid()
46 |     x = np.array([-1.5, -.5, -.25, 0.0, .25, .5, 1.5])
47 |     y = layer(x)
48 |     self.assertEqual(tl.to_list(y), [0.0, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0])
49 | 
50 |   def test_hard_tanh(self):
51 |     layer = tl.HardTanh()
52 |     x = np.array([-1.5, -.5, -.25, 0.0, .25, .5, 1.5])
53 |     y = layer(x)
54 |     self.assertEqual(tl.to_list(y), [-1.0, -.5, -.25, 0.0, .25, .5, 1.0])
55 | 
56 | 
57 | if __name__ == '__main__':
58 |   absltest.main()
59 |
--------------------------------------------------------------------------------
/trax/layers/deconvolution_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for Deconvolution layers."""
17 |
18 | from absl.testing import absltest
19 | import numpy as np
20 |
21 | from trax import shapes
22 | import trax.layers as tl
23 |
24 |
25 | class ConvTransposeTest(absltest.TestCase):
26 | 
27 |   def test_call(self):
28 |     layer = tl.ConvTranspose(30, (3, 3))
29 |     x = np.ones((9, 5, 5, 20))
30 |     layer.init(shapes.signature(x))
31 | 
32 |     y = layer(x)
33 |     self.assertEqual(y.shape, (9, 7, 7, 30))
34 | 
35 | 
36 | if __name__ == '__main__':
37 |   absltest.main()
38 |
--------------------------------------------------------------------------------
/trax/layers/research/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/trax/layers/research/rel_attention_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for trax.layers.research.rel_attention."""
17 | 
18 | from absl.testing import absltest
19 | import numpy as np
20 | 
21 | import trax.layers as tl
22 | import trax.layers.research.rel_attention as ra
23 | 
24 | 
25 | class RelAttentionTest(absltest.TestCase):
26 | 
27 |   def test_fast_shift_matrix(self):
28 |     layer = ra._fast_matrix_shift
29 |     x = np.array([[[[-3., -2., -1., 0.],
30 |                     [-3., -2., -1., 0.],
31 |                     [-3., -2., -1., 0.],
32 |                     [-3., -2., -1., 0.]]]]).astype(np.float32)
33 | 
34 |     y = layer(x)
35 |     self.assertEqual(y.dtype, np.float32)
36 |     self.assertEqual(
37 |         tl.to_list(y), [[[[0., 0., -3., -2.],
38 |                           [-1., 0., 0., -3.],
39 |                           [-2., -1., 0., 0.],
40 |                           [-3., -2., -1., 0.]]]])
41 | 
42 | 
43 | if __name__ == '__main__':
44 |   absltest.main()
45 | 
--------------------------------------------------------------------------------
/trax/layers/research/rotary_positional_embedding.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Rotary positional embeddings.
17 |
18 | Rotary positional embedding implementation, as described in:
19 | https://arxiv.org/pdf/2104.09864.pdf
20 | """
21 |
22 | # from trax import layers as tl
23 | from trax.fastmath import numpy as jnp
24 | from trax.layers import core
25 |
26 |
27 | def rotate(x):
28 |   """Applies rotary position embeddings along the length axis of x."""
29 |   _, l, d = x.shape
30 |   inv_freq = jnp.exp(jnp.arange(0, d, 2) * -(jnp.log(10000.0) / d))
31 |   positions = jnp.arange(l)
32 |   freqs = jnp.einsum('i,j->ij', positions, inv_freq)
33 |   emb = jnp.concatenate((freqs, freqs), axis=-1)
34 |   cos = jnp.cos(emb)
35 |   sin = jnp.sin(emb)
36 | 
37 |   def mul(vecs, pos_emb):
38 |     return jnp.einsum('bld,ld->bld', vecs, pos_emb)
39 | 
40 |   def rotate_half(x):
41 |     x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
42 |     return jnp.concatenate((-x2, x1), axis=x1.ndim - 1)
43 | 
44 |   return mul(x, cos) + mul(rotate_half(x), sin)
45 | 
46 | 
47 | def Rotate():  # pylint: disable=invalid-name
48 |   return core.Fn('Rotate', rotate)
49 |
--------------------------------------------------------------------------------
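As a reading aid — not part of the module — `rotate` above is the rotation form of RoPE from the paper cited in its docstring. With inverse frequencies $\theta_j = 10000^{-2j/d}$ and position $m$ it computes

\[
\mathrm{rotate}(x)_m \;=\; x_m \odot \cos(m\,\theta) \;+\; \operatorname{rotate\_half}(x_m) \odot \sin(m\,\theta),
\qquad \theta_j = 10000^{-2j/d},
\]

where `rotate_half` maps the two feature halves $[x_1, x_2]$ to $[-x_2, x_1]$. Each feature pair is thereby rotated by angle $m\,\theta_j$, so dot products of rotated queries and keys depend only on the relative position of the two tokens.

--------------------------------------------------------------------------------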
/trax/layers/research/rotary_positional_embedding_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for trax.layers.research.rotary_positional_embedding."""
17 |
18 | from absl.testing import absltest
19 | import numpy as np
20 | from trax.layers.research import rotary_positional_embedding as rotary_pe
21 |
22 |
23 | class RotaryPositionalEmbeddingTest(absltest.TestCase):
24 |
25 | def test_rotary_monotonicity(self):
26 | layer = rotary_pe.Rotate()
27 | batch_size = 1
28 | seq_len = 32
29 | d_model = 512
30 | shape = (batch_size, seq_len, d_model)
31 | q, k = np.ones(shape).astype(np.float32), np.ones(shape).astype(np.float32)
32 | q, k = layer(q), layer(k)
33 |
34 | self.assertEqual(q.dtype, np.float32)
35 | self.assertEqual(q.shape, shape)
36 |
37 |     # Test monotonicity of the resulting dot product for the first two tokens
38 |     # and their close neighbors.
39 | dot_product = np.einsum('bnd, bmd -> bnm', q, k)
40 |
41 | self.assertTrue((dot_product[0, 0, :9] > dot_product[0, 0, 1:10]).all())
42 | self.assertTrue((dot_product[0, 1, 1:10] > dot_product[0, 1, 2:11]).all())
43 |
44 |
45 | if __name__ == '__main__':
46 | absltest.main()
47 |
--------------------------------------------------------------------------------
/trax/layers/reversible_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for reversible layers."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import numpy as np
21 |
22 | from trax import fastmath
23 | import trax.layers as tl
24 |
25 |
26 | BACKENDS = [fastmath.Backend.JAX]
27 |
28 |
29 | class ReversibleLayerTest(parameterized.TestCase):
30 |
31 | @parameterized.named_parameters([('_' + b.value, b) for b in BACKENDS])
32 | def test_reversible_swap(self, backend):
33 | with fastmath.use_backend(backend):
34 | layer = tl.ReversibleSwap()
35 | xs = [np.array([1, 2]), np.array([10, 20])]
36 | ys = layer(xs)
37 | self.assertEqual(tl.to_list(ys), [[10, 20], [1, 2]])
38 |
39 |
40 | if __name__ == '__main__':
41 | absltest.main()
42 |
--------------------------------------------------------------------------------
/trax/layers/rnn_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for rnn layers."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import numpy as np
21 |
22 | from trax import fastmath
23 | from trax import shapes
24 | import trax.layers as tl
25 |
26 |
27 | BACKENDS = [fastmath.Backend.JAX]
28 |
29 |
30 | @parameterized.named_parameters(
31 | ('_' + b.value, b) for b in BACKENDS)
32 | class RnnTest(parameterized.TestCase):
33 |
34 | def test_conv_gru_cell(self, backend):
35 | with fastmath.use_backend(backend):
36 | layer = tl.ConvGRUCell(9, kernel_size=(3, 3))
37 | x = np.ones((8, 1, 7, 9))
38 | _, _ = layer.init(shapes.signature(x))
39 | y = layer(x)
40 | self.assertEqual(y.shape, x.shape)
41 |
42 | def test_gru_cell(self, backend):
43 | with fastmath.use_backend(backend):
44 | layer = tl.GRUCell(9)
45 | xs = [np.ones((8, 7, 9)), np.ones((8, 7, 9))]
46 | _, _ = layer.init(shapes.signature(xs))
47 | ys = layer(xs)
48 | self.assertEqual([y.shape for y in ys], [(8, 7, 9), (8, 7, 9)])
49 |
50 | def test_lstm_cell(self, backend):
51 | with fastmath.use_backend(backend):
52 | layer = tl.LSTMCell(9)
53 | xs = [np.ones((8, 9)), np.ones((8, 18))]
54 | _, _ = layer.init(shapes.signature(xs))
55 | ys = layer(xs)
56 | self.assertEqual([y.shape for y in ys], [(8, 9), (8, 18)])
57 |
58 | def test_sru(self, backend):
59 | with fastmath.use_backend(backend):
60 | layer = tl.SRU(7)
61 | x = np.ones((8, 9, 7), np.float32)
62 | _, _ = layer.init(shapes.signature(x))
63 | y = layer(x)
64 | self.assertEqual(y.shape, x.shape)
65 |
66 | def test_names(self, backend):
67 | with fastmath.use_backend(backend):
68 | layer = tl.LSTM(3)
69 | self.assertEqual('LSTM_3', str(layer))
70 | layer = tl.GRU(5)
71 | self.assertEqual('GRU_5', str(layer))
72 | layer = tl.SRU(7)
73 | self.assertEqual('SRU_7', str(layer))
74 |
75 |
76 | if __name__ == '__main__':
77 | absltest.main()
78 |
--------------------------------------------------------------------------------
/trax/models/atari_cnn_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for trax.models.atari_cnn."""
17 |
18 | import functools
19 | import operator as op
20 | import numpy as np
21 | from tensorflow import test
22 | from trax.models import atari_cnn
23 | from trax.shapes import ShapeDtype
24 |
25 |
26 | class AtariCnnTest(test.TestCase):
27 |
28 | def test_computes(self):
29 | hidden_size = (4, 4)
30 | output_size = 6
31 | model = atari_cnn.AtariCnn(
32 | hidden_sizes=hidden_size, output_size=output_size)
33 | B, T, OBS = 2, 2, (28, 28, 3) # pylint: disable=invalid-name
34 | input_signature = ShapeDtype((1, 1) + OBS)
35 | _, _ = model.init(input_signature)
36 | x = np.arange(B * (T + 1) * functools.reduce(op.mul, OBS)).reshape(
37 | B, T + 1, *OBS)
38 | y = model(x)
39 | self.assertEqual((B, T + 1, output_size), y.shape)
40 |
41 |
42 | class FrameStackMLPTest(test.TestCase):
43 |
44 | def test_computes(self):
45 | hidden_size = (4, 4)
46 | output_size = 6
47 | model = atari_cnn.FrameStackMLP(
48 | hidden_sizes=hidden_size, output_size=output_size)
49 | B, T, OBS = 2, 2, 3 # pylint: disable=invalid-name
50 | input_signature = ShapeDtype((1, 1, OBS))
51 | _, _ = model.init(input_signature)
52 | x = np.arange(B * (T + 1) * OBS).reshape(B, T + 1, OBS)
53 | y = model(x)
54 | self.assertEqual((B, T + 1, output_size), y.shape)
55 |
56 |
57 | if __name__ == '__main__':
58 | test.main()
59 |
--------------------------------------------------------------------------------
/trax/models/mlp.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """mlp -- functions that assemble "multilayer perceptron" networks."""
17 |
18 | from trax import layers as tl
19 |
20 |
21 | def MLP(
22 | layer_widths=(128, 64),
23 | activation_fn=tl.Relu,
24 | out_activation=False,
25 | flatten=True,
26 | mode='train'):
27 | """A "multilayer perceptron" (MLP) network.
28 |
29 | This is a classic fully connected feedforward network, with one or more
30 | layers and a (nonlinear) activation function between each layer. For
31 | historical reasons, such networks are often called multilayer perceptrons;
32 | but they are more accurately described as multilayer networks, where
33 | each layer + activation function is a perceptron-like unit (see, e.g.,
34 | [https://en.wikipedia.org/wiki/Multilayer_perceptron#Terminology]).
35 |
36 | Args:
37 |   layer_widths: Tuple of ints specifying the number of layers and the width of
38 | each layer. For example, setting `layer_widths=(128, 64, 32)` would
39 | yield 3 layers with successive widths of 128, 64, and 32.
40 | activation_fn: Type of activation function between pairs of fully connected
41 | layers; must be an activation-type subclass of `Layer`.
42 | out_activation: If True, include a copy of the activation function as the
43 | last layer in the network.
44 | flatten: If True, insert a layer at the head of the network to flatten the
45 |     input tensor into a matrix of shape (batch_size, -1).
46 | mode: Ignored.
47 |
48 | Returns:
49 | An assembled MLP network with the specified layers. This network can either
50 | be initialized and trained as a full model, or can be used as a building
51 | block in a larger network.
52 | """
53 | del mode
54 |
55 | layers = []
56 | for width in layer_widths:
57 | layers.append(tl.Dense(width))
58 | layers.append(activation_fn())
59 |
60 | if not out_activation:
61 | # Don't need the last activation.
62 | layers.pop()
63 |
64 | return tl.Serial(
65 | [tl.Flatten()] if flatten else [],
66 | layers,
67 | )
68 |
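# --- Example (editor's note) -------------------------------------------------
# A minimal usage sketch, assuming a working Trax install. It builds the
# 3-layer network from the docstring example and runs a forward pass; the
# leading Flatten collapses each (28, 28, 3) image into a single vector.

import numpy as np
from trax import shapes
from trax.models import mlp

model = mlp.MLP(layer_widths=(128, 64, 32))
x = np.ones((2, 28, 28, 3), dtype=np.float32)
model.init(shapes.signature(x))
y = model(x)
assert y.shape == (2, 32)  # batch of 2, final layer width 32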
--------------------------------------------------------------------------------
/trax/models/mlp_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for MLP."""
17 |
18 | from absl.testing import absltest
19 | import numpy as np
20 |
21 | from trax import fastmath
22 | from trax import shapes
23 | from trax.models import mlp
24 |
25 |
26 | class MLPTest(absltest.TestCase):
27 |
28 | def test_mlp_forward_shape(self):
29 | model = mlp.MLP(layer_widths=(32, 16, 8))
30 | x = np.ones((7, 28, 28, 3)).astype(np.float32)
31 | _, _ = model.init(shapes.signature(x))
32 | y = model(x)
33 | self.assertEqual(y.shape, (7, 8))
34 |
35 |
36 |
37 | if __name__ == '__main__':
38 | absltest.main()
39 |
--------------------------------------------------------------------------------
/trax/models/neural_gpu.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implementation of the improved Neural GPU (NGPU)."""
17 |
18 | from trax import layers as tl
19 | from trax.fastmath import numpy as jnp
20 |
21 |
22 | # TODO(ddohan): Combinator to add saturation costs to loss
23 | def SaturationCost(x, limit=0.9):
24 |   return jnp.maximum(0, jnp.abs(x) - limit)  # nonzero only past the limit
25 |
26 |
27 | def DiagonalGate():
28 | """Split channels in 3 parts. Shifts 1st and 3rd sections to left/right."""
29 |
30 | def f(x): # pylint: disable=invalid-name
31 | # x : [batch, 1, length, depth]
32 | x = jnp.pad(x, [(0, 0), (0, 0), (1, 1), (0, 0)],
33 | mode='constant', constant_values=0.0)
34 | depth = x.shape[-1] // 3
35 | assert 3 * depth == x.shape[-1], ('Depth must be divisible by 3', depth,
36 | x.shape)
37 | xs = [
38 | x[:, :, :-2, :depth], x[:, :, 1:-1, depth:2 * depth],
39 | x[:, :, 2:, 2 * depth:3 * depth]
40 | ]
41 | return jnp.concatenate(xs, axis=3)
42 | return tl.Fn('DiagonalGate', f)
43 |
44 |
45 | def ConvDiagonalGRU(units, kernel_size=(3, 3)):
46 | """Build convolutional GRU with diagonal gating as in ImprovedNGPU."""
47 |
48 | def BuildConv():
49 | return tl.Conv(filters=units, kernel_size=kernel_size, padding='SAME')
50 |
51 | return tl.GeneralGRUCell(
52 | candidate_transform=BuildConv,
53 | memory_transform_fn=DiagonalGate,
54 | gate_nonlinearity=tl.HardSigmoid,
55 | candidate_nonlinearity=tl.HardTanh)
56 |
57 |
58 | def NeuralGPU(d_feature=96, steps=16, vocab_size=2, mode='train'):
59 | """Implementation of Neural GPU: https://arxiv.org/abs/1702.08727.
60 |
61 | Args:
62 | d_feature: Number of memory channels (dimensionality of feature embedding).
63 |     steps: Number of depthwise recurrence steps.
64 | vocab_size: Vocabulary size.
65 | mode: Whether we are training or evaluating or doing inference.
66 |
67 | Returns:
68 |     A NeuralGPU Trax model.
69 | """
70 | del mode
71 |
72 | core = ConvDiagonalGRU(units=d_feature)
73 | return tl.Serial(
74 | tl.Embedding(vocab_size=vocab_size, d_feature=d_feature),
75 | [core] * steps,
76 | tl.Dense(vocab_size),
77 | )
78 |
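# --- Example (editor's note) -------------------------------------------------
# A minimal sketch of DiagonalGate's behavior, assuming a working Trax install.
# The channel depth must be divisible by 3; the first third is read one step to
# the left and the last third one step to the right, so the output shape
# matches the input shape.

import numpy as np
from trax.models import neural_gpu

gate = neural_gpu.DiagonalGate()
x = np.ones((2, 1, 5, 6), dtype=np.float32)  # [batch, 1, length, 3 * depth]
y = gate(x)
assert y.shape == x.shape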
--------------------------------------------------------------------------------
/trax/models/neural_gpu_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for trax.models.neural_gpu."""
17 |
18 | from absl.testing import absltest
19 | import numpy as np
20 |
21 | from trax import shapes
22 | from trax.models import neural_gpu
23 |
24 |
25 | class NeuralGPUTest(absltest.TestCase):
26 |
27 | def test_ngpu(self):
28 | model = neural_gpu.NeuralGPU(d_feature=30, steps=4, vocab_size=22)
29 | x = np.ones((3, 5, 7)).astype(np.int32)
30 | _, _ = model.init(shapes.signature(x))
31 | y = model(x)
32 | self.assertEqual(y.shape, (3, 5, 7, 22))
33 |
34 |
35 | if __name__ == '__main__':
36 | absltest.main()
37 |
--------------------------------------------------------------------------------
/trax/models/reformer/README.md:
--------------------------------------------------------------------------------
1 | ## Reformer and Terraformer: The Efficient Transformers
2 |
3 | Reformer and Terraformer are more efficient versions of the Transformer that use reversible layers, locality-sensitive hashing, and sparse layers.
4 |
5 | ### Papers
6 |
7 | * Reformer: Read about the details of Reformer in the [Reformer paper](https://arxiv.org/abs/2001.04451) which was selected for oral presentation at [ICLR 2020](https://iclr.cc/Conferences/2020/).
8 |
9 |
10 | * Terraformer: Read about the details of Terraformer in the paper [Sparse is Enough in Scaling Transformers](https://arxiv.org/abs/2111.12763).
11 |
12 | ### Models
13 |
14 |
15 | * Generate images with Reformer using [this colab](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/image_generation.ipynb).
16 |
17 | * Translate from English to German with a reversible encoder-decoder model using [this colab](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/machine_translation.ipynb).
18 |
19 | * A medium-size (~700M weights) Terraformer model for summarizing arXiv articles is available at `gs://trax-ml/terraformer/medium`.
20 | 
21 | * A large (~7B weights) Terraformer model pre-trained on C4 is available at `gs://trax-ml/terraformer/big` (see the loading sketch below).
22 |
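As a rough sketch (editor's note, assuming a working Trax install and access to the public `trax-ml` bucket), checkpoints like these can be loaded with `init_from_file`; the constructor arguments and checkpoint filename below are illustrative assumptions, not the exact published configuration:

```python
import trax

# Hypothetical setup: the real checkpoints expect the matching gin config.
model = trax.models.ConfigurableTerraformer(input_vocab_size=32000,
                                            mode='predict')
model.init_from_file('gs://trax-ml/terraformer/medium/model.pkl.gz',
                     weights_only=True)
```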
--------------------------------------------------------------------------------
/trax/models/reformer/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/trax/models/reformer/reformer_e2e_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """End to end test for Reformer."""
17 |
18 | import os
19 |
20 | from absl.testing import absltest
21 | import gin
22 |
23 | from trax import test_utils
24 | from trax.models.reformer import reformer # pylint: disable=unused-import
25 | from trax.supervised import trainer_lib
26 | from trax.tf_numpy import numpy as tf_np # pylint: disable=unused-import
27 |
28 | pkg_dir, _ = os.path.split(__file__)
29 | _TESTDATA = os.path.join(pkg_dir, 'testdata')
30 | _CONFIG_DIR = os.path.join(pkg_dir, '../../supervised/configs/')
31 |
32 |
33 | class ReformerE2ETest(absltest.TestCase):
34 |
35 | def setUp(self):
36 | super().setUp()
37 | gin.clear_config()
38 | gin.add_config_file_search_path(_CONFIG_DIR)
39 | test_utils.ensure_flag('test_tmpdir')
40 |
41 | def test_reformer_wmt_ende(self):
42 | batch_size_per_device = 2
43 | steps = 1
44 | n_layers = 2
45 | d_ff = 32
46 |
47 | gin.parse_config_file('reformer_wmt_ende.gin')
48 |
49 | gin.bind_parameter('data_streams.data_dir', _TESTDATA)
50 | gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device)
51 | gin.bind_parameter('train.steps', steps)
52 | gin.bind_parameter('Reformer.n_encoder_layers', n_layers)
53 | gin.bind_parameter('Reformer.n_decoder_layers', n_layers)
54 | gin.bind_parameter('Reformer.d_ff', d_ff)
55 |
56 | output_dir = self.create_tempdir().full_path
57 | _ = trainer_lib.train(output_dir=output_dir)
58 |
59 | def test_reformer_copy(self):
60 | batch_size_per_device = 2
61 | steps = 1
62 | n_layers = 2
63 | d_ff = 32
64 | d_model = 32
65 |
66 | gin.parse_config_file('reformer_copy.gin')
67 |
68 | gin.bind_parameter('data_streams.data_dir', _TESTDATA)
69 | gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device)
70 | gin.bind_parameter('train.steps', steps)
71 | gin.bind_parameter('ReformerLM.n_layers', n_layers)
72 | gin.bind_parameter('ReformerLM.d_ff', d_ff)
73 | gin.bind_parameter('ReformerLM.d_model', d_model)
74 |
75 | output_dir = self.create_tempdir().full_path
76 | _ = trainer_lib.train(output_dir=output_dir)
77 |
78 |
79 | if __name__ == '__main__':
80 | absltest.main()
81 |
--------------------------------------------------------------------------------
/trax/models/reformer/testdata/translate_ende_wmt32k-dev-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/models/reformer/testdata/translate_ende_wmt32k-dev-00000-of-00001
--------------------------------------------------------------------------------
/trax/models/reformer/testdata/translate_ende_wmt32k-train-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/models/reformer/testdata/translate_ende_wmt32k-train-00000-of-00001
--------------------------------------------------------------------------------
/trax/models/research/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/trax/models/research/rezero_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for ReZero models."""
17 |
18 | from absl.testing import absltest
19 | import numpy as np
20 |
21 | from trax import layers as tl
22 | from trax import shapes
23 | from trax.models.research import rezero
24 |
25 |
26 | class ResidualZeroTest(absltest.TestCase):
27 |
28 | def test_residual_layer_forward(self):
29 | """Tests that the forward pass runs and returns the expected shape."""
30 | model = rezero.ResidualZero(tl.Dense(5))
31 | x = [np.arange(5).astype(np.float32)]
32 | _, _ = model.init(shapes.signature(x))
33 | y = model(x)
34 | self.assertEqual(y.tolist(), [0., 1., 2., 3., 4.])
35 |
36 |
37 | class ReZeroTransformerLMTest(absltest.TestCase):
38 |
39 | def test_rezero_lm_forward_shape(self):
40 | """Tests that the forward pass runs and returns the expected shape."""
41 | vocab_size = 16
42 | model = rezero.ReZeroTransformerLM(
43 | vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2, max_len=16)
44 | xs = [np.ones((1, 8)).astype(np.int32),
45 | np.ones((1, 8)).astype(np.int32)]
46 | _, _ = model.init(shapes.signature(xs))
47 | ys = model(xs)
48 | self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)])
49 |
50 |
51 | class ReZeroTransformerTest(absltest.TestCase):
52 |
53 | def test_rezero_forward_shape(self):
54 | """Tests that the forward pass runs and returns the expected shape."""
55 | vocab_size = 16
56 | model = rezero.ReZeroTransformer(
57 | vocab_size, d_model=32, d_ff=64, n_encoder_layers=2, n_decoder_layers=2,
58 | n_heads=2, max_len=16)
59 | xs = [np.ones((1, 8)).astype(np.int32),
60 | np.ones((1, 8)).astype(np.int32)]
61 | _, _ = model.init(shapes.signature(xs))
62 | ys = model(xs)
63 | self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)])
64 |
65 |
66 | if __name__ == '__main__':
67 | absltest.main()
68 |
--------------------------------------------------------------------------------
/trax/models/research/terraformer_memory_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Test for memory usage in Terraformer models.
17 |
18 | This test is designed to run on TPUv3 hardware, processing 1 million tokens at a
19 | time while just barely fitting within the 16 GB memory budget.
20 | """
21 |
22 |
23 | from absl.testing import absltest
24 | from jax import config
25 | 
26 | 
27 | class TerraformerMemoryTest(absltest.TestCase):
28 |
29 |
30 | def test_terraformer_memory(self):
31 | pass # TODO(jonni): Figure out an OSS-compatible memory test.
32 |
33 |
34 | if __name__ == '__main__':
35 | config.config_with_absl()
36 | absltest.main()
37 |
--------------------------------------------------------------------------------
/trax/models/research/testdata/translate_ende_wmt32k-dev-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/models/research/testdata/translate_ende_wmt32k-dev-00000-of-00001
--------------------------------------------------------------------------------
/trax/models/research/testdata/translate_ende_wmt32k-train-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/models/research/testdata/translate_ende_wmt32k-train-00000-of-00001
--------------------------------------------------------------------------------
/trax/models/resnet_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for Resnet models."""
17 |
18 | from absl.testing import absltest
19 | import numpy as np
20 |
21 | from trax import fastmath
22 | from trax import shapes
23 | from trax.models import resnet
24 |
25 |
26 | class ResnetTest(absltest.TestCase):
27 |
28 | def test_resnet(self):
29 | model = resnet.Resnet50(d_hidden=8, n_output_classes=10)
30 | x = np.ones((3, 256, 256, 3)).astype(np.float32)
31 | _, _ = model.init(shapes.signature(x))
32 | y = model(x)
33 | self.assertEqual(y.shape, (3, 10))
34 |
35 | def test_wide_resnet(self):
36 | model = resnet.WideResnet(n_blocks=1, n_output_classes=10)
37 | x = np.ones((3, 32, 32, 3)).astype(np.float32)
38 | _, _ = model.init(shapes.signature(x))
39 | y = model(x)
40 | self.assertEqual(y.shape, (3, 10))
41 |
42 |
43 |
44 | if __name__ == '__main__':
45 | absltest.main()
46 |
--------------------------------------------------------------------------------
/trax/models/rl_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for RL."""
17 |
18 | from unittest import mock
19 | from absl.testing import absltest
20 | import numpy as np
21 |
22 | from trax import shapes
23 | from trax.models import rl
24 |
25 |
26 | class RLTest(absltest.TestCase):
27 |
28 | def test_policy_forward_shape(self):
29 | mock_dist = mock.MagicMock()
30 | mock_dist.n_inputs = 4
31 | model = rl.Policy(policy_distribution=mock_dist)
32 | x = np.ones((2, 3))
33 | _, _ = model.init(shapes.signature(x))
34 | y = model(x)
35 | self.assertEqual(y.shape, (2, 4))
36 |
37 | def test_value_forward_shape(self):
38 | model = rl.Value()
39 | x = np.ones((2, 3))
40 | _, _ = model.init(shapes.signature(x))
41 | y = model(x)
42 | self.assertEqual(y.shape, (2, 1))
43 |
44 | def test_policy_and_value_forward_shape(self):
45 | mock_dist = mock.MagicMock()
46 | mock_dist.n_inputs = 4
47 | model = rl.PolicyAndValue(policy_distribution=mock_dist)
48 | x = np.ones((2, 3))
49 | _, _ = model.init(shapes.signature(x))
50 | ys = model(x)
51 | self.assertEqual([y.shape for y in ys], [(2, 4), (2, 1)])
52 |
53 |
54 | if __name__ == '__main__':
55 | absltest.main()
56 |
--------------------------------------------------------------------------------
/trax/models/rnn_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for RNNs."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import numpy as np
21 |
22 | from trax import fastmath
23 | from trax import shapes
24 | from trax.models import rnn
25 |
26 | BACKENDS = [fastmath.Backend.JAX]
27 |
28 |
29 | @parameterized.named_parameters(
30 | ('_' + b.value, b) for b in BACKENDS)
31 | class RNNTest(parameterized.TestCase):
32 |
33 | def test_rnnlm_forward_shape(self, backend):
34 | with fastmath.use_backend(backend):
35 | model = rnn.RNNLM(vocab_size=20, d_model=16)
36 | x = np.ones((3, 28)).astype(np.int32)
37 | _, _ = model.init(shapes.signature(x))
38 | y = model(x)
39 | self.assertEqual(y.shape, (3, 28, 20))
40 |
41 | def test_grulm_forward_shape(self, backend):
42 | with fastmath.use_backend(backend):
43 | model = rnn.GRULM(vocab_size=20, d_model=16)
44 | x = np.ones((3, 28)).astype(np.int32)
45 | _, _ = model.init(shapes.signature(x))
46 | y = model(x)
47 | self.assertEqual(y.shape, (3, 28, 20))
48 |
49 | def test_lstmseq2seqattn_forward_shape(self, backend):
50 | with fastmath.use_backend(backend):
51 | model = rnn.LSTMSeq2SeqAttn(
52 | input_vocab_size=20, target_vocab_size=20, d_model=16)
53 | x = np.ones((3, 28)).astype(np.int32)
54 | _, _ = model.init([shapes.signature(x), shapes.signature(x)])
55 | ys = model([x, x])
56 | self.assertEqual([y.shape for y in ys], [(3, 28, 20), (3, 28)])
57 |
58 |
59 | if __name__ == '__main__':
60 | absltest.main()
61 |
--------------------------------------------------------------------------------
/trax/models/transformer_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for Transformer models."""
17 |
18 | import functools
19 |
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 | import numpy as np
23 |
24 | from trax import fastmath
25 | from trax import shapes
26 | from trax.layers import test_utils
27 | from trax.models import transformer
28 |
29 |
30 | class TransformerTest(parameterized.TestCase):
31 |
32 | def test_transformer_lm_forward_shape(self):
33 | vocab_size = 16
34 | model = transformer.TransformerLM(
35 | vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2)
36 | x = np.ones((3, 5)).astype(np.int32)
37 | _, _ = model.init(shapes.signature(x))
38 | y = model(x)
39 | self.assertEqual(y.shape, (3, 5, vocab_size))
40 |
41 | def _test_transformer_forward_shape(self, input_vocab_size,
42 | output_vocab_size):
43 | model = transformer.Transformer(
44 | input_vocab_size, output_vocab_size, d_model=32, d_ff=64,
45 | n_encoder_layers=2, n_decoder_layers=2, n_heads=2)
46 | xs = [np.ones((3, 5)).astype(np.int32), np.ones((3, 5)).astype(np.int32)]
47 | _, _ = model.init(shapes.signature(xs))
48 | y, _ = model(xs)
49 |
50 | vocab_size = output_vocab_size or input_vocab_size
51 | self.assertEqual(y.shape, (3, 5, vocab_size))
52 |
53 | @parameterized.named_parameters(
54 | ('same_vocab', 16, None),
55 | ('same_size', 16, 16),
56 | ('different_size', 16, 50))
57 | def test_transformer_forward_shape(self, input_vocab_size, output_vocab_size):
58 | """Run the Transformer forward and check output shape."""
59 | self._test_transformer_forward_shape(input_vocab_size, output_vocab_size)
60 |
61 |
62 | def test_dot_product_causal_attention_fast_inference(self):
63 | model_fn = functools.partial(
64 | transformer.TransformerLM, d_model=4, d_ff=8, n_layers=2, n_heads=2
65 | )
66 | test_utils.test_eval_equals_predict_discrete(model_fn)
67 |
68 |
69 | if __name__ == '__main__':
70 | absltest.main()
71 |
--------------------------------------------------------------------------------
/trax/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Optimizers for use with Trax layers."""
17 |
18 | import gin
19 |
20 | from trax.optimizers import adafactor
21 | from trax.optimizers import adam
22 | from trax.optimizers import base
23 | from trax.optimizers import momentum
24 | from trax.optimizers import rms_prop
25 | from trax.optimizers import sm3
26 | from trax.optimizers import trainer
27 | from trax.optimizers.trainer import ReversibleSerialTrainer
28 | from trax.optimizers.trainer import Trainer
29 |
30 |
31 | def opt_configure(*args, **kwargs):
32 | kwargs['module'] = 'trax.optimizers'
33 | return gin.external_configurable(*args, **kwargs)
34 |
35 | # Optimizers (using upper-case names).
36 | # pylint: disable=invalid-name
37 | SGD = opt_configure(base.SGD)
38 | Momentum = opt_configure(momentum.Momentum)
39 | RMSProp = opt_configure(rms_prop.RMSProp)
40 | Adam = opt_configure(adam.Adam)
41 | Adafactor = opt_configure(adafactor.Adafactor)
42 | SM3 = opt_configure(sm3.SM3)
43 |
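# --- Example (editor's note) -------------------------------------------------
# A minimal sketch of what the `opt_configure` registration above enables,
# assuming a working Trax install: each optimizer can be constructed directly
# in Python, or configured by name from gin (as the .gin files elsewhere in
# this repo do with `AWRJoint.optimizer = @trax.optimizers.Adam`).

import gin
from trax import optimizers

opt = optimizers.Adam(learning_rate=0.001)     # direct construction

gin.parse_config('Adam.learning_rate = 0.01')  # gin-driven configuration
opt_from_gin = optimizers.Adam()               # picks up the gin binding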
--------------------------------------------------------------------------------
/trax/optimizers/momentum.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Nesterov momentum optimizer (also known as Nesterov Accelerated Gradient)."""
17 |
18 | from trax.fastmath import numpy as jnp
19 | from trax.optimizers import base
20 |
21 |
22 | # TODO(jonni): Consider renaming this class to NesterovMomentum.
23 | class Momentum(base.Optimizer):
24 | r"""A momentum optimizer.
25 |
26 | This class implements two variants of momentum stochastic gradient descent
27 | (SGD): with and without the Nesterov correction. The implementation of the
28 | Nesterov update is based on the concepts in Sutskever et al. (2013)
29 | [http://jmlr.org/proceedings/papers/v28/sutskever13.pdf], reformulated in
30 | Bengio et al. (2012) [https://arxiv.org/abs/1212.0901], to work well with
31 | backpropagation (equations 6 and 7):
32 |
33 | .. math::
34 | v_t &= \mu_{t-1}v_{t-1} - \epsilon_{t-1}\nabla f(\Theta_{t-1}) \\
35 | \Theta_t &= \Theta_{t-1} - \mu_{t-1} v_{t-1} + \mu_t v_t + v_t
36 |
37 | where :math:`\mu_{t-1}` is the momentum (decay) coefficient at time step
38 | :math:`t-1` and :math:`\epsilon_{t-1}` is the learning rate at :math:`t-1`.
39 |
40 | Note that the implementation below also includes a weight decay rate
41 | (:math:`\alpha`) on the parameters, independent of the Nesterov momentum.
42 | """
43 |
44 | def __init__(
45 | self, learning_rate=0.01, mass=0.9, weight_decay_rate=1e-5, nesterov=True
46 |   ):
47 | super().__init__(
48 | learning_rate=learning_rate,
49 | mass=mass,
50 | weight_decay_rate=weight_decay_rate,
51 | )
52 | self._nesterov = nesterov
53 |
54 | def init(self, weights):
55 | return jnp.zeros_like(weights)
56 |
57 | def update(self, step, grads, weights, velocity, opt_params):
58 | del step
59 | v = velocity
60 | mu = opt_params['mass']
61 | alpha = opt_params['weight_decay_rate']
62 | epsilon = opt_params['learning_rate']
63 |
64 | new_v = mu * v + grads
65 | if self._nesterov:
66 | weight_update = mu * new_v + grads
67 | else:
68 | weight_update = new_v
69 | new_weights = (1 - alpha) * weights - epsilon * weight_update
70 |
71 | new_weights = new_weights.astype(weights.dtype)
72 | return (new_weights, new_v)
73 |
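# --- Example (editor's note) -------------------------------------------------
# A minimal numeric sketch of one Nesterov update, assuming a working Trax
# install. With w = 1.0, grad = 0.5, zero initial velocity, mass mu = 0.9,
# learning rate eps = 0.1, and no weight decay, the `update` method above gives:
#   new_v         = 0.9 * 0 + 0.5     = 0.5
#   weight_update = 0.9 * 0.5 + 0.5   = 0.95
#   new_w         = 1.0 - 0.1 * 0.95  = 0.905

import numpy as np
from trax.optimizers import momentum

opt = momentum.Momentum(learning_rate=0.1, mass=0.9, weight_decay_rate=0.0)
w = np.array(1.0, dtype=np.float32)
v = opt.init(w)  # zero initial velocity
new_w, new_v = opt.update(
    step=0, grads=np.float32(0.5), weights=w, velocity=v,
    opt_params={'mass': 0.9, 'weight_decay_rate': 0.0, 'learning_rate': 0.1})
np.testing.assert_allclose(new_w, 0.905, rtol=1e-6)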
--------------------------------------------------------------------------------
/trax/optimizers/optimizers_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for supervised training optimizers."""
17 |
18 | from absl.testing import absltest
19 |
20 | import numpy as np
21 |
22 | from trax import optimizers
23 | from trax.optimizers import momentum
24 |
25 |
26 | class OptimizersTest(absltest.TestCase):
27 |
28 | def test_slots(self):
29 | weights_shape = (3, 5)
30 | weight_tree = np.arange(15).reshape(weights_shape)
31 |
32 | # SGD - an optimizer that doesn't use slots.
33 | opt_1 = optimizers.SGD(.01)
34 | self.assertIsNone(opt_1.slots)
35 | opt_1.tree_init(weight_tree)
36 | self.assertIsInstance(opt_1.slots, tuple)
37 | self.assertLen(opt_1.slots, 1)
38 | self.assertIsNone(opt_1.slots[0])
39 |
40 | # Momentum - an optimizer with slots
41 | opt_2 = momentum.Momentum(.01)
42 | self.assertIsNone(opt_2.slots)
43 | opt_2.tree_init(weight_tree)
44 | self.assertIsInstance(opt_2.slots, tuple)
45 | self.assertLen(opt_2.slots, 1)
46 | self.assertEqual(weights_shape, opt_2.slots[0].shape)
47 |
48 |
49 | if __name__ == '__main__':
50 | absltest.main()
51 |
--------------------------------------------------------------------------------
/trax/optimizers/rms_prop.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """RMSProp optimizer class."""
17 |
18 | from trax.fastmath import numpy as jnp
19 | from trax.optimizers import base as opt_base
20 |
21 |
22 | class RMSProp(opt_base.Optimizer):
23 | """RMSProp optimizer.
24 |
25 |   Uses optimizer weights ("slots") to maintain an exponentially decaying
26 |   average of squared gradients from prior training batches.
27 | """
28 |
29 | def __init__(self, learning_rate=0.001, gamma=0.9,
30 | eps=1e-8, clip_grad_norm=None): # pylint: disable=useless-super-delegation
31 | super().__init__(
32 | learning_rate=learning_rate,
33 | gamma=gamma,
34 | eps=eps,
35 | clip_grad_norm=clip_grad_norm
36 | )
37 |
38 | def init(self, weights):
39 | return jnp.ones_like(weights)
40 |
41 | def update(self, step, grads, weights, avg_sq_grad, opt_params):
42 | del step
43 | lr = opt_params['learning_rate']
44 | gamma = opt_params['gamma']
45 | eps = opt_params['eps']
46 | avg_sq_grad = avg_sq_grad * gamma + grads**2 * (1. - gamma)
47 | weights = weights - (lr * grads /
48 | (jnp.sqrt(avg_sq_grad) + eps)).astype(weights.dtype)
49 | return weights, avg_sq_grad
50 |
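# --- Example (editor's note) -------------------------------------------------
# A minimal numeric sketch of one RMSProp step, assuming a working Trax
# install. The slot starts at 1.0 (see `init` above), so with grad = 1.0 and
# gamma = 0.9 the running average stays at 1.0 and the step is roughly
# lr * grad / (sqrt(1.0) + eps), i.e. about lr.

import numpy as np
from trax.optimizers import rms_prop

opt = rms_prop.RMSProp(learning_rate=0.001)
w = np.array(1.0, dtype=np.float32)
slot = opt.init(w)  # avg_sq_grad, initialized to ones
new_w, new_slot = opt.update(
    step=0, grads=np.float32(1.0), weights=w, avg_sq_grad=slot,
    opt_params={'learning_rate': 0.001, 'gamma': 0.9, 'eps': 1e-8})
np.testing.assert_allclose(new_w, 1.0 - 0.001, rtol=1e-5)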
--------------------------------------------------------------------------------
/trax/rl/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Trax RL library."""
17 |
18 | import gin
19 |
20 | from trax.rl import actor_critic
21 | from trax.rl import actor_critic_joint
22 | from trax.rl import envs
23 | from trax.rl import serialization_utils
24 | from trax.rl import training
25 |
26 |
27 | def configure_rl(*args, **kwargs):
28 | kwargs['module'] = 'trax.rl'
29 | kwargs['denylist'] = ['task', 'output_dir']
30 | return gin.external_configurable(*args, **kwargs)
31 |
32 |
33 | A2C = configure_rl(actor_critic.A2C)
34 | AWR = configure_rl(actor_critic.AWR)
35 | LoopAWR = configure_rl(actor_critic.LoopAWR)
36 | PPO = configure_rl(actor_critic.PPO)
37 | SamplingAWR = configure_rl(actor_critic.SamplingAWR)
38 |
39 | A2CJoint = configure_rl(actor_critic_joint.A2CJoint)
40 | AWRJoint = configure_rl(actor_critic_joint.AWRJoint)
41 | PPOJoint = configure_rl(actor_critic_joint.PPOJoint)
42 |
43 | PolicyGradient = configure_rl(training.PolicyGradient)
44 | ExpertIteration = configure_rl(training.ExpertIteration)
45 | DQN = configure_rl(training.DQN)
46 |
47 | TimeSeriesModel = gin.external_configurable(
48 | serialization_utils.TimeSeriesModel, module='trax.rl'
49 | )
50 |
--------------------------------------------------------------------------------
/trax/rl/atari_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for RL training."""
17 |
18 | import functools
19 |
20 | from absl.testing import absltest
21 |
22 | from trax import models
23 | from trax import optimizers as opt
24 | from trax.models import atari_cnn
25 | from trax.rl import actor_critic
26 | from trax.rl import task as rl_task
27 | from trax.supervised import lr_schedules
28 |
29 |
30 |
31 |
32 | if __name__ == '__main__':
33 | absltest.main()
34 |
--------------------------------------------------------------------------------
/trax/rl/configs/light_a2c_joint_atari_sweep.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | A2CJoint.train_steps_per_epoch: [40, 80]
16 | multifactor.constant: [0.0007, 0.0001]
17 | A2CJoint.value_loss_coeff: [0.1, 0.25]
18 | A2CJoint.batch_size: [128, 256]
19 | A2CJoint.n_trajectories_per_epoch: [5,]
20 | RLTask.max_steps: [2000,]
21 | A2CJoint.entropy_coeff: [0.01, 0.1,]
22 |
--------------------------------------------------------------------------------
/trax/rl/configs/light_atari_sweep.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | RLTask.env: [
16 | "boxing",
17 | "breakout",
18 | "freeway",
19 | "gopher",
20 | "pong",
21 | "seaquest",
22 | ]
23 |
24 | # Sweep 1:
25 | train_steps: [100, 300, 1000, 3000]
26 | SamplingAWR.n_interactions_per_epoch: [200, 500, 1000, 2000]
27 | SamplingAWR.n_replay_epochs: [50, 100, 200, 500, 1000]
28 |
29 | # Sweep 2:
30 | SamplingAWR.q_value_temperature: [0.3, 1.0]
31 | policy_lr: [0.00003, 0.0001, 0.0003]
32 | value_lr: [0.001, 0.0005, 0.0001]
33 |
34 | # Sweep 3:
35 | SamplingAWR.value_evals_per_epoch: [5, 10, 20, 50]
36 |
--------------------------------------------------------------------------------
/trax/rl/configs/light_awr_cartpole_sweep.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | AWRJoint.policy_train_steps_per_epoch: [10, 100, 1000]
16 | multifactor.constant: [0.0001, 0.001, 0.01]
17 | AWRJoint.batch_size: [32, 128]
18 |
--------------------------------------------------------------------------------
/trax/rl/configs/light_awr_joint_atari_sweep.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | AWRJoint.train_steps_per_epoch: [10, 100, 1000]
16 | multifactor.constant: [0.0001, 0.001, 0.01]
17 | AWRJoint.batch_size: [32, 128]
18 | RLTask.max_steps: [200, 2000]
19 | RLTask.env: ["pong"]
20 |
--------------------------------------------------------------------------------
/trax/rl/configs/light_awr_joint_cartpole.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.supervised.lr_schedules
16 | import trax.models
17 | import trax.optimizers
18 | import trax.rl
19 | import trax.rl_trainer
20 |
21 | # Parameters for PolicyAndValue:
22 | # ==============================================================================
23 | PolicyAndValue.body = @trax.models.MLP
24 |
25 | # Parameters for MLP:
26 | # ==============================================================================
27 | MLP.flatten = False
28 | MLP.layer_widths = (128,)
29 | MLP.out_activation = True
30 |
31 | # Parameters for multifactor:
32 | # ==============================================================================
33 | multifactor.constant = 0.01
34 | multifactor.factors = 'constant'
35 |
36 | # Parameters for RLTask:
37 | # ==============================================================================
38 | RLTask.env = "CartPole-v0"
39 | RLTask.initial_trajectories = 1000
40 | RLTask.gamma = 0.99
41 | RLTask.max_steps = 200
42 |
43 | # Parameters for AWR:
44 | # ==============================================================================
45 | AWRJoint.joint_model = @trax.models.PolicyAndValue
46 | AWRJoint.optimizer = @trax.optimizers.Adam
47 | AWRJoint.batch_size = 32
48 | AWRJoint.train_steps_per_epoch = 1000
49 | AWRJoint.lr_schedule = @multifactor
50 | AWRJoint.n_trajectories_per_epoch = 10
51 | AWRJoint.beta = 1.0
52 | AWRJoint.w_max = 20
53 | AWRJoint.max_slice_length = 1
54 |
55 | # Parameters for train_rl:
56 | # ==============================================================================
57 | train_rl.light_rl = True
58 | train_rl.light_rl_trainer = @trax.rl.AWRJoint
59 | train_rl.n_epochs = 10000
60 |
--------------------------------------------------------------------------------
/trax/rl/configs/light_awr_joint_cartpole_sweep.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | AWRJoint.train_steps_per_epoch: [10, 100, 1000]
16 | multifactor.constant: [0.0001, 0.001, 0.01]
17 | AWRJoint.batch_size: [32, 128]
18 |
--------------------------------------------------------------------------------
/trax/rl/configs/light_copy_sweep.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Task difficulty:
16 | copy_length: [3, 4, 6, 8]
17 | # Model params:
18 | TransformerDecoder.n_layers: [1, 2, 3]
19 | d_model: [32, 64, 128, 256, 512]
20 |
--------------------------------------------------------------------------------
/trax/rl/configs/light_mujoco_sweep.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | RLTask.env: [
16 | "DM-HalfCheetah-v2",
17 | "DM-Hopper-v2",
18 | "DM-Humanoid-v2",
19 | "DM-Walker2d-v2",
20 | ]
21 |
22 | # Sweep 1:
23 | train_steps: [100, 300, 1000, 3000]
24 | SamplingAWR.n_interactions_per_epoch: [200, 500, 1000, 2000]
25 | SamplingAWR.n_replay_epochs: [50, 100, 200, 500, 1000]
26 |
27 | # Sweep 2:
28 | SamplingAWR.q_value_temperature: [0.3, 1.0]
29 | policy_lr: [0.00003, 0.0001, 0.0003]
30 | value_lr: [0.001, 0.0005, 0.0001]
31 |
32 | # Sweep 3:
33 | SamplingAWR.value_evals_per_epoch: [5, 10, 20, 50]
34 |
--------------------------------------------------------------------------------
/trax/rl/configs/light_ppo_atari.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.layers
16 | import trax.supervised.lr_schedules
17 | import trax.models
18 | import trax.optimizers
19 | import trax.rl
20 | import trax.rl_trainer
21 |
22 | # Parameters for Policy:
23 | # ==============================================================================
24 | Policy.body = @trax.models.AtariCnnBody
25 |
26 | # Parameters for Value:
27 | # ==============================================================================
28 | Value.body = @trax.models.AtariCnnBody
29 |
30 | # Parameters for AtariCnnBody:
31 | # ==============================================================================
32 | AtariCnnBody.kernel_initializer = @trax.layers.AtariConvInit
33 |
34 | # Parameters for multifactor:
35 | # ==============================================================================
36 | value/multifactor.constant = 0.0001
37 | value/multifactor.factors = 'constant'
38 | policy/multifactor.constant = 0.0001
39 | policy/multifactor.factors = 'constant'
40 |
41 | # Parameters for RLTask:
42 | # ==============================================================================
43 | RLTask.env = "boxing"
44 | RLTask.initial_trajectories = 100
45 | RLTask.gamma = 0.99
46 | RLTask.max_steps = 200
47 | RLTask.dm_suite = True
48 |
49 | # Parameters for PPO:
50 | # ==============================================================================
51 | PPO.n_shared_layers = 0
52 | PPO.value_model = @trax.models.Value
53 | PPO.value_optimizer = @trax.optimizers.Adam
54 | PPO.value_batch_size = 4
55 | PPO.value_train_steps_per_epoch = 10
56 | PPO.value_lr_schedule = @value/multifactor
57 | PPO.policy_model = @trax.models.Policy
58 | PPO.policy_optimizer = @trax.optimizers.Adam
59 | PPO.policy_batch_size = 4
60 | PPO.policy_train_steps_per_epoch = 10
61 | PPO.policy_lr_schedule = @policy/multifactor
62 | PPO.n_trajectories_per_epoch = 50
63 |
64 | # Parameters for train_rl:
65 | # ==============================================================================
66 | train_rl.light_rl = True
67 | train_rl.light_rl_trainer = @trax.rl.PPO
68 | train_rl.n_epochs = 5000
69 |
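The `value/` and `policy/` prefixes above are gin scopes: the same configurable receives different bindings depending on the scope it is called under, which is how the value and policy networks get separate learning-rate schedules. A minimal sketch, assuming only the gin-config package (`multifactor` here is a stand-in for trax's schedule, and the two constants are deliberately different for illustration):

    import gin

    @gin.configurable
    def multifactor(constant=1.0, factors='constant'):
        return constant, factors

    gin.parse_config("""
    value/multifactor.constant = 0.0001
    policy/multifactor.constant = 0.00025
    """)

    with gin.config_scope('value'):
        print(multifactor())   # (0.0001, 'constant')
    with gin.config_scope('policy'):
        print(multifactor())   # (0.00025, 'constant')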
--------------------------------------------------------------------------------
/trax/rl/configs/light_ppo_boxing_regression_test.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.supervised.lr_schedules
16 | import trax.models
17 | import trax.optimizers
18 | import trax.rl
19 | import trax.rl_trainer
20 |
21 | # Parameters for Adam:
22 | # ==============================================================================
23 | Adam.clip_grad_norm = 0.5
24 |
25 | # Parameters for PolicyAndValue:
26 | # ==============================================================================
27 | PolicyAndValue.body = @trax.models.AtariCnnBody
28 |
29 | # Parameters for AtariCnnBody:
30 | # ==============================================================================
31 | AtariCnnBody.kernel_initializer = @trax.layers.AtariConvInit
32 | AtariCnnBody.n_frames = 1
33 | AtariCnnBody.padding = 'VALID'
34 |
35 | # Parameters for multifactor:
36 | # ==============================================================================
37 | multifactor.constant = 0.0001
38 | multifactor.factors = 'constant * linear_warmup'
39 | multifactor.warmup_steps = 100
40 |
41 | # Parameters for RLTask:
42 | # ==============================================================================
43 | RLTask.env = "boxing"
44 | RLTask.initial_trajectories = 0
45 | RLTask.gamma = 0.99
46 | RLTask.max_steps = 2000
47 | RLTask.dm_suite = True
48 | RLTask.num_stacked_frames = 4
49 |
50 | # Parameters for PPOJoint:
51 | # ==============================================================================
52 | PPOJoint.joint_model = @trax.models.PolicyAndValue
53 | PPOJoint.optimizer = @trax.optimizers.Adam
54 | PPOJoint.batch_size = 256
55 | PPOJoint.train_steps_per_epoch = 80
56 | PPOJoint.lr_schedule = @multifactor
57 | PPOJoint.n_trajectories_per_epoch = 5
58 | PPOJoint.epsilon = 0.1
59 | PPOJoint.value_loss_coeff = 0.1
60 | PPOJoint.entropy_coeff = 0.01
61 |
62 | # Parameters for train_rl:
63 | # ==============================================================================
64 | train_rl.light_rl = True
65 | train_rl.light_rl_trainer = @trax.rl.PPOJoint
66 | train_rl.n_epochs = 900
67 |
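The `'constant * linear_warmup'` factor string composes multiplicatively. A worked sketch, assuming `linear_warmup` is min(1, step / warmup_steps) as the name suggests: the rate ramps linearly from 0 to the constant over the first 100 steps and stays flat afterwards.

    def lr(step, constant=0.0001, warmup_steps=100):
        # 'constant * linear_warmup' under the stated assumption.
        return constant * min(1.0, step / warmup_steps)

    print(lr(1), lr(50), lr(100), lr(5000))
    # ~1e-06  ~5e-05  0.0001  0.0001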
--------------------------------------------------------------------------------
/trax/rl/configs/light_ppo_cartpole_regression_test.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.supervised.lr_schedules
16 | import trax.models
17 | import trax.optimizers
18 | import trax.rl
19 | import trax.rl_trainer
20 |
21 | # Parameters for Adam:
22 | # ==============================================================================
23 | Adam.clip_grad_norm = 0.5
24 |
25 | # Parameters for PolicyAndValue:
26 | # ==============================================================================
27 | PolicyAndValue.body = @trax.models.MLP
28 |
29 | # Parameters for MLP:
30 | # ==============================================================================
31 | MLP.flatten = False
32 | MLP.layer_widths = (64,64)
33 | MLP.out_activation = True
34 |
35 | # Parameters for multifactor:
36 | # ==============================================================================
37 | multifactor.constant = 0.0025
38 | multifactor.factors = 'constant * linear_warmup'
39 | multifactor.warmup_steps = 100
40 |
41 | # Parameters for RLTask:
42 | # ==============================================================================
43 | RLTask.env = "CartPole-v0"
44 | RLTask.initial_trajectories = 0
45 | RLTask.gamma = 0.99
46 | RLTask.max_steps = 200
47 |
48 | # Parameters for PPOJoint:
49 | # ==============================================================================
50 | PPOJoint.joint_model = @trax.models.PolicyAndValue
51 | PPOJoint.optimizer = @trax.optimizers.Adam
52 | PPOJoint.batch_size = 256
53 | PPOJoint.train_steps_per_epoch = 40
54 | PPOJoint.lr_schedule = @multifactor
55 | PPOJoint.n_trajectories_per_epoch = 20
56 | PPOJoint.epsilon = 0.1
57 | PPOJoint.value_loss_coeff = 0
58 | PPOJoint.entropy_coeff = 0
59 |
60 |
61 | # Parameters for train_rl:
62 | # ==============================================================================
63 | train_rl.light_rl = True
64 | train_rl.light_rl_trainer = @trax.rl.PPOJoint
65 | train_rl.n_epochs = 500
66 |
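`PPOJoint.epsilon = 0.1` is the PPO clipping range. A sketch of the standard clipped surrogate it controls (an assumption that trax follows the usual formulation; masking and averaging details are omitted):

    import numpy as np

    def ppo_objective(ratio, advantage, epsilon=0.1):
        # ratio = pi_new(a|s) / pi_old(a|s); taking the minimum of the
        # unclipped and clipped terms caps the incentive to move the
        # policy more than epsilon away from the old one.
        clipped = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon)
        return np.minimum(ratio * advantage, clipped * advantage)

    print(ppo_objective(1.3, advantage=2.0))   # 2.2 -- gain capped at (1+eps)*A
    print(ppo_objective(0.5, advantage=-1.0))  # -0.9 -- clipped term dominates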
--------------------------------------------------------------------------------
/trax/rl/configs/light_ppo_joint_atari.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.supervised.lr_schedules
16 | import trax.models
17 | import trax.optimizers
18 | import trax.rl
19 | import trax.rl_trainer
20 |
21 | # Parameters for Adam:
22 | # ==============================================================================
23 | Adam.clip_grad_norm = 0.5
24 |
25 | # Parameters for PolicyAndValue:
26 | # ==============================================================================
27 | PolicyAndValue.body = @trax.models.AtariCnnBody
28 |
29 | # Parameters for AtariCnnBody:
30 | # ==============================================================================
31 | AtariCnnBody.kernel_initializer = @trax.layers.AtariConvInit
32 | AtariCnnBody.n_frames = 1
33 | AtariCnnBody.padding = 'VALID'
34 |
35 | # Parameters for multifactor:
36 | # ==============================================================================
37 | multifactor.constant = 0.01
38 | multifactor.factors = 'constant'
39 |
40 | # Parameters for RLTask:
41 | # ==============================================================================
42 | RLTask.env = "boxing"
43 | RLTask.initial_trajectories = 0
44 | RLTask.gamma = 0.99
45 | RLTask.max_steps = 2000
46 | RLTask.dm_suite = True
47 | RLTask.num_stacked_frames = 4
48 |
49 | # Parameters for PPOJoint:
50 | # ==============================================================================
51 | PPOJoint.joint_model = @trax.models.PolicyAndValue
52 | PPOJoint.optimizer = @trax.optimizers.Adam
53 | PPOJoint.batch_size = 16
54 | PPOJoint.train_steps_per_epoch = 20
55 | PPOJoint.lr_schedule = @multifactor
56 | PPOJoint.n_trajectories_per_epoch = 25
57 | PPOJoint.epsilon = 0.1
58 | PPOJoint.value_loss_coeff = 0.1
59 | PPOJoint.entropy_coeff = 0.01
60 |
61 | # Parameters for train_rl:
62 | # ==============================================================================
63 | train_rl.light_rl = True
64 | train_rl.light_rl_trainer = @trax.rl.PPOJoint
65 | train_rl.n_epochs = 10000
66 |
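With a joint model, the three coefficients above plausibly combine into one objective; a hedged sketch of that combination (the exact form of trax's joint PPO loss is an assumption here):

    def joint_loss(policy_loss, value_loss, entropy,
                   value_loss_coeff=0.1, entropy_coeff=0.01):
        # Policy loss plus a weighted value-function error, minus an
        # entropy bonus that rewards exploration.
        return (policy_loss
                + value_loss_coeff * value_loss
                - entropy_coeff * entropy)

    print(joint_loss(policy_loss=0.5, value_loss=2.0, entropy=1.2))  # ~0.688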
--------------------------------------------------------------------------------
/trax/rl/configs/light_ppo_joint_atari_sweep.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | PPOJoint.train_steps_per_epoch: [10, 100]
16 | multifactor.constant: [0.0001, 0.000025]
17 | PPOJoint.value_loss_coeff: [0.01, 0.1, 1]
18 | PPOJoint.batch_size: [8, 16, 32, 256]
19 | PPOJoint.epsilon: [0.1, 0.2, 0.3]
20 | RLTask.max_steps: [200, 2000]
21 | RLTask.env: ["pong", "boxing", "breakout", "freeway"]
22 |
--------------------------------------------------------------------------------
/trax/rl/configs/light_ppo_lunar_lander_regression_test.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.supervised.lr_schedules
16 | import trax.models
17 | import trax.optimizers
18 | import trax.rl
19 | import trax.rl_trainer
20 |
21 | # Parameters for Adam:
22 | # ==============================================================================
23 | Adam.clip_grad_norm = 0.5
24 |
25 | # Parameters for PolicyAndValue:
26 | # ==============================================================================
27 | PolicyAndValue.body = @trax.models.MLP
28 |
29 | # Parameters for MLP:
30 | # ==============================================================================
31 | MLP.flatten = False
32 | MLP.layer_widths = (64,64)
33 | MLP.out_activation = True
34 |
35 | # Parameters for multifactor:
36 | # ==============================================================================
37 | multifactor.constant = 0.0025
38 | multifactor.factors = 'constant * linear_warmup'
39 | multifactor.warmup_steps = 100
40 |
41 | # Parameters for RLTask:
42 | # ==============================================================================
43 | RLTask.env = "LunarLander-v2"
44 | RLTask.initial_trajectories = 0
45 | RLTask.gamma = 0.99
46 | RLTask.max_steps = 1000
47 |
48 | # Parameters for PPOJoint:
49 | # ==============================================================================
50 | PPOJoint.joint_model = @trax.models.PolicyAndValue
51 | PPOJoint.optimizer = @trax.optimizers.Adam
52 | PPOJoint.batch_size = 256
53 | PPOJoint.train_steps_per_epoch = 200
54 | PPOJoint.lr_schedule = @multifactor
55 | PPOJoint.n_trajectories_per_epoch = 20
56 | PPOJoint.epsilon = 0.1
57 | PPOJoint.value_loss_coeff = 0
58 | PPOJoint.entropy_coeff = 0.01
59 |
60 |
61 | # Parameters for train_rl:
62 | # ==============================================================================
63 | train_rl.light_rl = True
64 | train_rl.light_rl_trainer = @trax.rl.PPOJoint
65 | train_rl.n_epochs = 130
66 |
--------------------------------------------------------------------------------
/trax/rl/configs/light_ppo_pong_regression_test.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.supervised.lr_schedules
16 | import trax.models
17 | import trax.optimizers
18 | import trax.rl
19 | import trax.rl_trainer
20 |
21 | # Parameters for Adam:
22 | # ==============================================================================
23 | Adam.clip_grad_norm = 0.5
24 |
25 | # Parameters for PolicyAndValue:
26 | # ==============================================================================
27 | PolicyAndValue.body = @trax.models.AtariCnnBody
28 |
29 | # Parameters for AtariCnnBody:
30 | # ==============================================================================
31 | AtariCnnBody.kernel_initializer = @trax.layers.AtariConvInit
32 | AtariCnnBody.n_frames = 1
33 | AtariCnnBody.padding = 'VALID'
34 |
35 | # Parameters for multifactor:
36 | # ==============================================================================
37 | multifactor.constant = 0.0001
38 | multifactor.factors = 'constant * linear_warmup'
39 | multifactor.warmup_steps = 100
40 |
41 | # Parameters for RLTask:
42 | # ==============================================================================
43 | RLTask.env = "pong"
44 | RLTask.initial_trajectories = 0
45 | RLTask.gamma = 0.99
46 | RLTask.max_steps = 2000
47 | RLTask.dm_suite = True
48 | RLTask.num_stacked_frames = 4
49 |
50 | # Parameters for PPOJoint:
51 | # ==============================================================================
52 | PPOJoint.joint_model = @trax.models.PolicyAndValue
53 | PPOJoint.optimizer = @trax.optimizers.Adam
54 | PPOJoint.batch_size = 256
55 | PPOJoint.train_steps_per_epoch = 80
56 | PPOJoint.lr_schedule = @multifactor
57 | PPOJoint.n_trajectories_per_epoch = 5
58 | PPOJoint.epsilon = 0.1
59 | PPOJoint.value_loss_coeff = 0.25
60 | PPOJoint.entropy_coeff = 0.1
61 |
62 | # Parameters for train_rl:
63 | # ==============================================================================
64 | train_rl.light_rl = True
65 | train_rl.light_rl_trainer = @trax.rl.PPOJoint
66 | train_rl.n_epochs = 900
67 |
--------------------------------------------------------------------------------
/trax/rl/configs/ppo_atari_sweep.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | RLTask.gamma: [0.99]
16 | RLTask.env: ["boxing"]
17 | value__.multifactor.constant: [0.0001, 0.001]
18 | value__.multifactor.factors: ['constant']
19 | policy__.multifactor.constant: [0.0001, 0.00025]
20 | policy__.multifactor.factors: ['constant']
21 |
--------------------------------------------------------------------------------
/trax/rl/configs/ppo_cartpole_sweep.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | RLTask.gamma: [0.99, 0.95]
16 |
--------------------------------------------------------------------------------
/trax/rl/configs/transformer_srl_sine.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Not really an RL config, but it's here because it depends on serialization
16 | # logic which is currently in trax.rl.
17 |
18 | import trax.data.tf_inputs
19 | import trax.models
20 | import trax.optimizers
21 | import trax.supervised.trainer_lib
22 |
23 | # Module trax.data.inputs:
24 | # ==============================================================================
25 | sine_inputs.batch_size = 64
26 | sine_inputs.length = 100
27 | sine_inputs.min_period = 0.1
28 | sine_inputs.max_period = 10
29 |
30 | # Module trax.models.transformer:
31 | # ==============================================================================
32 | TransformerLM.d_model = 128
33 | TransformerLM.d_ff = 256
34 | TransformerLM.dropout = 0.1
35 | TransformerLM.max_len = 2048
36 | TransformerLM.mode = 'train'
37 | TransformerLM.n_heads = 2
38 | TransformerLM.n_layers = 2
39 |
40 | # Module trax.rl.serialization_utils:
41 | # ==============================================================================
42 | TimeSeriesModel.seq_model = @trax.models.TransformerLM
43 | TimeSeriesModel.low = -1.0
44 | TimeSeriesModel.high = 1.0
45 | TimeSeriesModel.precision = 2
46 | TimeSeriesModel.vocab_size = 16
47 | TimeSeriesModel.significance_decay = 0.7
48 |
49 | # Module trax.supervised.lr_schedules:
50 | # ==============================================================================
51 | multifactor.constant = 0.1
52 | multifactor.factors = 'constant * linear_warmup * rsqrt_decay'
53 | multifactor.warmup_steps = 1000
54 |
55 | # Module trax.supervised.trainer_lib:
56 | # ==============================================================================
57 | train.inputs = @trax.data.sine_inputs
58 | train.eval_frequency = 1000
59 | train.eval_steps = 10
60 | train.model = @trax.rl.TimeSeriesModel
61 | train.optimizer = @trax.optimizers.Adam
62 | train.steps = 10000
63 | train.callbacks = (
64 | @trax.supervised.callbacks.SerializedModelEvaluation,
65 | )
66 |
67 | # Module trax.supervised.callbacks:
68 | # ==============================================================================
69 | SerializedModelEvaluation.eval_at = 500
70 | SerializedModelEvaluation.context_lengths = (0, 1, 10)
71 | SerializedModelEvaluation.horizon_lengths = (1, 3, 10)
72 |
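The TimeSeriesModel settings above indicate how continuous values are serialized into tokens: a value in [low, high] becomes `precision` base-`vocab_size` digits. A sketch under that assumption (16**2 = 256 bins for this config; trax's exact rounding rules are not shown here):

    def encode(x, low=-1.0, high=1.0, precision=2, vocab_size=16):
        frac = (x - low) / (high - low)  # map to [0, 1]
        n = min(int(frac * vocab_size**precision), vocab_size**precision - 1)
        digits = []
        for _ in range(precision):
            n, d = divmod(n, vocab_size)
            digits.append(d)
        return digits[::-1]  # most significant digit first

    print(encode(0.0))  # [8, 0] -- the midpoint of [-1, 1]
    print(encode(0.7))  # [13, 9]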
--------------------------------------------------------------------------------
/trax/rl/envs/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Trax RL environments library."""
17 |
18 | import gin
19 | from trax.rl.envs import data_envs
20 |
21 |
22 | def configure_rl_env(*args, **kwargs):
23 | kwargs['module'] = 'trax.rl.envs'
24 | return gin.external_configurable(*args, **kwargs)
25 |
26 |
27 | copy_stream = configure_rl_env(data_envs.copy_stream)
28 | SequenceDataEnv = configure_rl_env(data_envs.SequenceDataEnv)
29 |
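`gin.external_configurable` registers an existing function or class with gin without decorating it at its definition site; the wrapper above merely fixes the module name so configs can reference, e.g., `trax.rl.envs.copy_stream`. A minimal sketch with a toy function, assuming only the gin-config package:

    import gin

    def greet(name='world'):  # a plain, undecorated function
        return f'hello {name}'

    greet = gin.external_configurable(greet, module='example')
    gin.parse_config("example.greet.name = 'trax'")
    print(greet())  # hello trax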
--------------------------------------------------------------------------------
/trax/rl/normalization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for trax.rl.normalization."""
17 |
18 | from absl.testing import absltest
19 | import numpy as np
20 |
21 | from trax import shapes
22 | from trax.rl import normalization
23 |
24 |
25 | class NormalizationTest(absltest.TestCase):
26 |
27 | def test_running_mean(self):
28 | x = np.random.uniform(size=10)
29 | state = normalization.running_mean_init(shape=())
30 | for i in range(len(x)):
31 | state = normalization.running_mean_update(x[i], state)
32 | np.testing.assert_almost_equal(
33 | normalization.running_mean_get_mean(state), np.mean(x[:i + 1])
34 | )
35 |
36 | def test_running_variance(self):
37 | x = np.random.uniform(size=10)
38 | state = normalization.running_mean_and_variance_init(shape=())
39 | for i in range(len(x)):
40 | state = normalization.running_mean_and_variance_update(x[i], state)
41 | np.testing.assert_almost_equal(
42 | normalization.running_mean_and_variance_get_variance(state),
43 | np.var(x[:i + 1]),
44 | )
45 |
46 | def test_normalize_collect(self):
47 | x = np.random.uniform(size=(2, 3, 4, 5))
48 | normalize = normalization.Normalize(mode='collect')
49 | normalize.init(shapes.signature(x))
50 | old_state = normalize.state
51 | y = normalize(x)
52 | with self.assertRaises(AssertionError):
53 | np.testing.assert_equal(normalize.state, old_state)
54 | with self.assertRaises(AssertionError):
55 | np.testing.assert_almost_equal(x, y)
56 |
57 | def test_normalize_train(self):
58 | x = np.random.uniform(size=(2, 3, 4, 5))
59 | normalize = normalization.Normalize(mode='train', epsilon=0.0)
60 | normalize.init(shapes.signature(x))
61 | old_state = normalize.state
62 | y = normalize(x)
63 | np.testing.assert_equal(normalize.state, old_state)
64 | np.testing.assert_almost_equal(x, y)
65 |
66 |
67 | if __name__ == '__main__':
68 | absltest.main()
69 |
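The running statistics these tests check can be maintained in a single pass. A sketch using Welford's online algorithm, which reproduces np.mean and np.var as the assertions require (whether trax's implementation uses this exact recurrence is an assumption):

    import numpy as np

    def welford_update(x, state):
        n, mean, m2 = state
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)  # running sum of squared deviations
        return n, mean, m2

    xs = np.random.uniform(size=10)
    state = (0, 0.0, 0.0)
    for x in xs:
        state = welford_update(x, state)
    n, mean, m2 = state
    np.testing.assert_almost_equal(mean, np.mean(xs))
    np.testing.assert_almost_equal(m2 / n, np.var(xs))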
--------------------------------------------------------------------------------
/trax/supervised/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Supervised learning imports in Trax."""
17 |
18 | from trax.supervised import callbacks
19 | from trax.supervised import decoding
20 | from trax.supervised import lr_schedules
21 | from trax.supervised import trainer_lib
22 | from trax.supervised import training
23 | from trax.supervised.trainer_lib import train
24 | from trax.supervised.trainer_lib import Trainer
25 | from trax.supervised.training import EvalTask
26 | from trax.supervised.training import TrainTask
27 |
--------------------------------------------------------------------------------
/trax/supervised/configs/bert.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import t5.data.preprocessors
16 | import trax.data
17 | import trax.layers
18 | import trax.models
19 | import trax.optimizers
20 | import trax.supervised.lr_schedules
21 | import trax.supervised.trainer_lib
22 | import trax.layers.metrics
23 |
24 | # Parameters for TFDS data pipeline:
25 | # ==============================================================================
26 | data.Tokenize.vocab_file = 'bert_uncased_vocab.txt'
27 | data.Tokenize.vocab_type = 'bert-lowercase'
28 | # If the directory trax/data/testdata containing the vocab file is not
29 | # accessible when the binary runs, copy the file to an accessible location
30 | # and change the path accordingly.
31 | data.Tokenize.vocab_dir = 'trax/data/testdata/'
32 | data.Tokenize.keys = [0, 1]
33 | data.PadToLength.len_map = {0: 512, 1: 512, 2: 512}
34 | data.PadToLength.pad_value = {0: 0, 1: 0, 2: 0}
35 | data.TruncateToLength.len_map = {0: (512,), 1: (512,), 2: (512,)}
36 | data.Batch.batch_size = 16
37 |
38 | # Parameters for train:
39 | # ==============================================================================
40 | train.optimizer = @trax.optimizers.Adam
41 | train.eval_frequency = 20
42 | train.eval_steps = 10
43 | train.inputs = @data.make_inputs
44 | train.model = @trax.models.BERT
45 | train.steps = 200000
46 | train.checkpoint_highest = 'accuracy'
47 |
48 | # Parameters for BERT:
49 | # ==============================================================================
50 | BERT.init_checkpoint = 'bert-base-uncased'
51 |
52 | # Parameters for multifactor:
53 | # ==============================================================================
54 | multifactor.constant = 3e-5
55 | multifactor.factors = 'constant * linear_warmup'
56 | multifactor.warmup_steps = 1000
57 |
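The data classes configured above also compose directly in code; a sketch assuming trax is installed, using trax.data.Serial to chain them (the ordering is illustrative, since bert.gin only binds parameters, not pipeline order):

    import trax

    pipeline = trax.data.Serial(
        trax.data.Tokenize(vocab_file='bert_uncased_vocab.txt',
                           vocab_type='bert-lowercase',
                           vocab_dir='trax/data/testdata/',
                           keys=[0, 1]),
        trax.data.TruncateToLength(len_map={0: (512,), 1: (512,), 2: (512,)}),
        trax.data.PadToLength(len_map={0: 512, 1: 512, 2: 512},
                              pad_value={0: 0, 1: 0, 2: 0}),
        trax.data.Batch(batch_size=16),
    )
    # `pipeline` maps a generator of raw examples to a generator of batches.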
--------------------------------------------------------------------------------
/trax/supervised/configs/bert_glue_classification.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models.research.bert
17 |
18 | include 'bert.gin'
19 |
20 | # Parameters for training and eval data:
21 | # ==============================================================================
22 | #
23 | # See https://www.tensorflow.org/datasets/catalog/glue. Valid benchmark names:
24 | #
25 | # cola, sst2, mrpc, qqp, stsb, mnli, qnli, rte, wnli.
26 | #
27 | # Training on WNLI with this setup is not recommended and will likely result in
28 | # lower than baseline accuracy.
29 | make_inputs.train_stream = @data.BertGlueTrainStream
30 | make_inputs.eval_stream = @data.BertGlueEvalStream
31 | BertGlueTrainStream.benchmark = 'mnli'
32 | BertGlueEvalStream.benchmark = 'mnli'
33 |
34 | # Parameters for BERT:
35 | # ==============================================================================
36 | BERT.head = @bert.BERTClassifierHead
37 | bert.BERTClassifierHead.n_classes = 3
38 |
--------------------------------------------------------------------------------
/trax/supervised/configs/bert_glue_regression.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models.research.bert
17 |
18 | include 'bert.gin'
19 |
20 | # Parameters for training and eval data:
21 | # ==============================================================================
22 | #
23 | # See https://www.tensorflow.org/datasets/catalog/glue. Valid benchmark names:
24 | #
25 | # cola, sst2, mrpc, qqp, stsb, mnli, qnli, rte, wnli.
26 | #
27 | # Training on WNLI with this setup is not recommended and will likely result in
28 | # lower than baseline accuracy.
29 | make_inputs.train_stream = @data.BertGlueTrainStream
30 | make_inputs.eval_stream = @data.BertGlueEvalStream
31 | BertGlueTrainStream.benchmark = 'stsb'
32 | BertGlueEvalStream.benchmark = 'stsb'
33 |
34 | # Parameters for train:
35 | # ==============================================================================
36 | train.loss_fn = @trax.layers.L2Loss()
37 | train.metrics = {'loss': @trax.layers.L2Loss()}
38 |
39 | # Parameters for BERT:
40 | # ==============================================================================
41 | BERT.head = @bert.BERTRegressionHead
42 |
--------------------------------------------------------------------------------
/trax/supervised/configs/bert_glue_sweep_regression_task.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | dataset_name: [
16 | 'glue/stsb',
17 | ]
18 | multifactor.constant: [2.0e-5, 3.0e-5, 5.0e-5]
19 |
--------------------------------------------------------------------------------
/trax/supervised/configs/bert_glue_sweep_single_sentence.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | dataset_name: [
16 | 'glue/cola',
17 | 'glue/sst2',
18 | ]
19 | multifactor.constant: [2.0e-5, 3.0e-5, 5.0e-5]
20 | CreateBertInputs.double_sentence: [False]
21 | data.Tokenize.keys: [[0,]]
22 |
--------------------------------------------------------------------------------
/trax/supervised/configs/bert_glue_sweep_two_sentences.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | dataset_name: [
16 | 'glue/mrpc',
17 | 'glue/qqp',
18 | 'glue/mnli',
19 | 'glue/qnli',
20 | 'glue/rte',
21 | 'glue/wnli',
22 | ]
23 | multifactor.constant: [2.0e-5, 3.0e-5, 5.0e-5]
24 |
--------------------------------------------------------------------------------
/trax/supervised/configs/bert_pretraining.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models.research.bert
17 |
18 | include 'bert.gin'
19 |
20 | dataset_name = 'wiki40b'
21 |
22 | # Parameters for TFDS data pipeline:
23 | # ==============================================================================
24 | make_inputs.train_stream = [
25 | @train/data.BertNextSentencePredictionInputs(),
26 | @data.Tokenize(),
27 | @data.CreateBertInputs(),
28 | @data.Shuffle(),
29 | @data.PadToLength(),
30 | @data.TruncateToLength(),
31 | @data.mask_random_tokens,
32 | @data.Batch()
33 | ]
34 | make_inputs.eval_stream = [
35 | @eval/data.BertNextSentencePredictionInputs(),
36 | @data.Tokenize(),
37 | @data.CreateBertInputs(),
38 | @data.Shuffle(),
39 | @data.PadToLength(),
40 | @data.TruncateToLength(),
41 | @data.mask_random_tokens,
42 | @data.Batch()
43 | ]
44 |
45 | train/data.BertNextSentencePredictionInputs.dataset_name = %dataset_name
46 | train/data.TFDS.train = True
47 | eval/data.BertNextSentencePredictionInputs.dataset_name = %dataset_name
48 | eval/data.TFDS.train = False
49 |
50 | # Parameters for train:
51 | # ==============================================================================
52 | train.loss_fn = @bert.BERTPretrainingLoss()
53 | train.metrics = {'loss': @bert.BERTPretrainingLoss()}
54 |
55 | # Parameters for BERT:
56 | # ==============================================================================
57 | BERT.head = @bert.BERTPretrainingHead
58 | bert.BERTPretrainingHead.n_classes = 2
59 |
--------------------------------------------------------------------------------
/trax/supervised/configs/bert_pretraining_onlymlm.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models.research.bert
17 |
18 | include 'bert.gin'
19 |
20 | dataset_name = 'wiki40b'
21 |
22 | # Parameters for TFDS data pipeline:
23 | # ==============================================================================
24 | make_inputs.train_stream = [
25 | @train/data.CorpusToRandomChunks(),
26 | @data.Tokenize(),
27 | @data.CreateBertInputs(),
28 | @data.Shuffle(),
29 | @data.PadToLength(),
30 | @data.TruncateToLength(),
31 | @data.mask_random_tokens,
32 | @data.Batch()
33 | ]
34 | make_inputs.eval_stream = [
35 | @eval/data.CorpusToRandomChunks(),
36 | @data.Tokenize(),
37 | @data.CreateBertInputs(),
38 | @data.Shuffle(),
39 | @data.PadToLength(),
40 | @data.TruncateToLength(),
41 | @data.mask_random_tokens,
42 | @data.Batch()
43 | ]
44 |
45 | train/data.CorpusToRandomChunks.dataset_name = %dataset_name
46 | train/data.TFDS.train = True
47 | eval/data.CorpusToRandomChunks.dataset_name = %dataset_name
48 | eval/data.TFDS.train = False
49 |
50 | data.CreateBertInputs.labeled = False
51 | data.CreateBertInputs.double_sentence = False
52 |
53 | # Parameters for BERT:
54 | # ==============================================================================
55 | BERT.head = @bert.BERTMLMHead
56 |
57 |
--------------------------------------------------------------------------------
/trax/supervised/configs/bert_pretraining_onlynsp.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models.research.bert
17 |
18 | include 'bert.gin'
19 |
20 | dataset_name = 'wiki40b'
21 |
22 | # Parameters for TFDS data pipeline:
23 | # ==============================================================================
24 | make_inputs.train_stream = [
25 | @train/data.BertNextSentencePredictionInputs(),
26 | @data.Tokenize(),
27 | @data.CreateBertInputs(),
28 | @data.Shuffle(),
29 | @data.PadToLength(),
30 | @data.TruncateToLength(),
31 | @data.Batch()
32 | ]
33 | make_inputs.eval_stream = [
34 | @eval/data.BertNextSentencePredictionInputs(),
35 | @data.Tokenize(),
36 | @data.CreateBertInputs(),
37 | @data.Shuffle(),
38 | @data.PadToLength(),
39 | @data.TruncateToLength(),
40 | @data.Batch()
41 | ]
42 |
43 | train/data.BertNextSentencePredictionInputs.dataset_name = %dataset_name
44 | train/data.TFDS.train = True
45 | eval/data.BertNextSentencePredictionInputs.dataset_name = %dataset_name
46 | eval/data.TFDS.train = False
47 |
48 | # Parameters for BERT:
49 | # ==============================================================================
50 | BERT.head = @bert.BERTClassifierHead
51 | bert.BERTClassifierHead.n_classes = 2
52 |
53 |
--------------------------------------------------------------------------------
/trax/supervised/configs/cond_skipping_transformer_lm1b.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models
17 | import trax.optimizers
18 | import trax.supervised.trainer_lib
19 |
20 | # Parameters for batcher:
21 | # ==============================================================================
22 | batcher.data_streams = @data.data_streams
23 | batcher.batch_size_per_device = 96
24 | batcher.eval_batch_size = 128
25 | batcher.max_eval_length = 2048
26 | batcher.id_to_mask = 0
27 |
28 | # Parameters for data_streams:
29 | # ==============================================================================
30 | data_streams.data_dir = None
31 | data_streams.dataset_name = 't2t_languagemodel_lm1b32k'
32 | data_streams.input_name = 'targets'
33 | data_streams.preprocess_fn = @data.lm1b_preprocess
34 |
35 | # Parameters for multifactor:
36 | # ==============================================================================
37 | multifactor.constant = 0.3
38 | multifactor.factors = 'constant * linear_warmup'
39 | multifactor.warmup_steps = 8000
40 |
41 | # Parameters for lm1b_preprocess:
42 | # ==============================================================================
43 | lm1b_preprocess.max_target_length = 512
44 | lm1b_preprocess.max_eval_target_length = 2048
45 |
46 | # Parameters for train:
47 | # ==============================================================================
48 | train.eval_frequency = 1000
49 | train.eval_steps = 10
50 | train.model = @trax.models.SkippingTransformerLM
51 | train.optimizer = @trax.optimizers.SM3
52 | train.steps = 500000
53 |
54 | # Parameters for DotProductCausalAttention:
55 | # ==============================================================================
56 | DotProductCausalAttention.dropout = 0.1
57 |
58 | # Parameters for SkippingTransformerLM:
59 | # ==============================================================================
60 | SkippingTransformerLM.d_model = 1024
61 | SkippingTransformerLM.d_ff = 4096
62 | SkippingTransformerLM.dropout = 0.2
63 | SkippingTransformerLM.max_len = 2048
64 | SkippingTransformerLM.mode = 'train'
65 | SkippingTransformerLM.n_heads = 8
66 | SkippingTransformerLM.n_layers = 8
67 | SkippingTransformerLM.vocab_size = 32000
68 |
--------------------------------------------------------------------------------
/trax/supervised/configs/gru_copy.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models
17 | import trax.optimizers
18 | import trax.supervised.trainer_lib
19 |
20 | n_symbols = 32
21 | length = 16
22 | batch = 512
23 |
24 | # Parameters for data_streams:
25 | # ==============================================================================
26 | data_streams.data_dir = None
27 | data_streams.dataset_name = 'gru_copy'
28 |
29 | # Parameters for sequence_copy_inputs:
30 | # ==============================================================================
31 | sequence_copy_inputs.vocab_size = %n_symbols
32 | sequence_copy_inputs.batch_size = %batch
33 | sequence_copy_inputs.train_length = %length
34 | sequence_copy_inputs.eval_min_length = 2
35 | sequence_copy_inputs.eval_max_length = %length
36 | sequence_copy_inputs.reverse = False
37 |
38 | # Parameters for multifactor:
39 | # ==============================================================================
40 | multifactor.constant = 0.001
41 | multifactor.factors = 'constant * linear_warmup'
42 | multifactor.warmup_steps = 8000
43 |
44 | # Parameters for train:
45 | # ==============================================================================
46 | train.eval_frequency = 1000
47 | train.eval_steps = 10
48 | train.inputs = @trax.data.sequence_copy_inputs
49 | train.model = @trax.models.RNNLM
50 | train.optimizer = @trax.optimizers.Adam
51 | train.steps = 500000
52 |
53 | # Parameters for RNNLM:
54 | # ==============================================================================
55 | RNNLM.rnn_cell = @trax.layers.GRUCell
56 | RNNLM.rnn_cell_d_state_multiplier = 1
57 | RNNLM.d_model = 128
58 | RNNLM.dropout = 0.1
59 | RNNLM.n_layers = 2
60 | RNNLM.vocab_size = %n_symbols
61 |
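The copy task itself is easy to picture: the model reads a random symbol sequence and must emit the same sequence. A toy sketch of one example (padding, separators, and the exact layout produced by trax's sequence_copy_inputs are omitted as assumptions):

    import numpy as np

    def copy_example(length=16, vocab_size=32, rng=np.random.default_rng(0)):
        seq = rng.integers(1, vocab_size, size=length)  # 0 reserved for padding
        return seq, seq.copy()  # the target is the input, verbatim

    x, y = copy_example()
    assert (x == y).all()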
--------------------------------------------------------------------------------
/trax/supervised/configs/hourglass_cifar10.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.layers
17 | import trax.models
18 | import trax.optimizers
19 | import trax.supervised.trainer_lib
20 |
21 | train_steps = 100000
22 |
23 | # Parameters for batcher:
24 | # ==============================================================================
25 | batcher.data_streams = @data.data_streams
26 | batcher.batch_size_per_device = 1
27 | batcher.eval_batch_size = 8
28 | batcher.max_eval_length = 3072 # 32 * 32 * 3
29 | batcher.variable_shapes = False
30 |
31 | # Parameters for data_streams:
32 | # ==============================================================================
33 | data_streams.data_dir = None
34 | data_streams.dataset_name = 'cifar10'
35 | data_streams.input_name = 'image'
36 | data_streams.target_name = 'image'
37 | data_streams.bare_preprocess_fn = \
38 | @data.downsampled_imagenet_flatten_bare_preprocess
39 |
40 | # Parameters for multifactor:
41 | multifactor.constant = 1e-3
42 | multifactor.factors = 'constant * linear_warmup * cosine_decay'
43 | multifactor.warmup_steps = 5000
44 | multifactor.steps_per_cycle = %train_steps
45 |
46 | # Parameters for Adam:
47 | # ==============================================================================
48 | Adam.weight_decay_rate = 0.0
49 | Adam.b1 = 0.9
50 | Adam.b2 = 0.98
51 | Adam.eps = 1e-9
52 |
53 | # Parameters for train:
54 | # ==============================================================================
55 | train.eval_frequency = 2000
56 | train.eval_steps = 625
57 | train.checkpoints_at = [100000]
58 | train.model = @trax.models.HourglassLM
59 | train.optimizer = @trax.optimizers.Adam
60 | train.steps = %train_steps
61 |
62 |
63 | # Parameters for HourglassLM:
64 | # ==============================================================================
65 | HourglassLM.d_model = 512
66 | HourglassLM.d_ff = 2048
67 | HourglassLM.vanilla_layers = (1, 1)
68 | HourglassLM.hierarchy = '8@3'
69 | HourglassLM.dropout = 0.0
70 | HourglassLM.mode = 'train'
71 | HourglassLM.n_heads = 8
72 | HourglassLM.vocab_size = 256
73 | HourglassLM.attention_downsampling_fn = @LinearPooling
74 | HourglassLM.attention_upsampling_fn = @LinearUpsampling
75 |
--------------------------------------------------------------------------------
/trax/supervised/configs/hourglass_enwik8.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.layers
17 | import trax.models
18 | import trax.optimizers
19 | import trax.supervised.trainer_lib
20 |
21 |
22 | # Parameters for batcher:
23 | # ==============================================================================
24 | batcher.data_streams = @data.data_streams
25 | batcher.max_eval_length = 2049
26 | batcher.buckets = ([2049], [8])
27 | batcher.id_to_mask = 0
28 |
29 | # Parameters for data_streams:
30 | # ==============================================================================
31 | data_streams.data_dir = None
32 | data_streams.dataset_name = 't2t_enwik8_l2k'
33 | data_streams.input_name = 'targets'
34 |
35 | # Parameters for multifactor:
36 | # ==============================================================================
37 | # Note: d_model^-0.5 = 512^-0.5 ~= 0.044; the constant below is tuned lower.
38 | multifactor.constant = 4.1e-4
39 | multifactor.factors = 'constant * linear_warmup * cosine_decay'
40 | multifactor.warmup_steps = 4000
41 | multifactor.steps_per_cycle = 350000
42 |
43 | # Parameters for Adam:
44 | # ==============================================================================
45 | Adam.weight_decay_rate = 0.0
46 | Adam.b1 = 0.9
47 | Adam.b2 = 0.98
48 | Adam.eps = 1e-9
49 |
50 | # Parameters for train:
51 | # ==============================================================================
52 | train.eval_frequency = 2000
53 | train.eval_steps = 305
54 | train.model = @trax.models.HourglassLM
55 | train.optimizer = @trax.optimizers.Adam
56 | train.steps = 263000
57 | train.save_graphs = False
58 | train.checkpoints_at = [150000, 175000, 263000]
59 |
60 |
61 | # Parameters for HourglassLM:
62 | # ==============================================================================
63 | HourglassLM.d_ff = 2048
64 | HourglassLM.d_model = 512
65 | HourglassLM.dropout = 0.2
66 | HourglassLM.vanilla_layers = (5, 5)
67 | HourglassLM.hierarchy = '24@3'
68 | HourglassLM.n_heads = 8
69 | HourglassLM.vocab_size = 256
70 | HourglassLM.attention_upsampling_fn = @NaiveUpsampling
71 | HourglassLM.ff_activation = @trax.layers.FastGelu
72 |
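As a minimal sketch, a config like this can be driven from Python (the
equivalent CLI is "python -m trax.trainer --config_file=... --output_dir=...");
the config and output paths below are placeholders.

import gin
from trax.supervised import trainer_lib

# Paths are placeholders; point them at your checkout and scratch directory.
gin.parse_config_file('trax/supervised/configs/hourglass_enwik8.gin')
trainer_lib.train(output_dir='/tmp/hourglass_enwik8')
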
--------------------------------------------------------------------------------
/trax/supervised/configs/layerdrop_every_transformer_lm1b.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models
17 | import trax.optimizers
18 | import trax.supervised.trainer_lib
19 |
20 | # Parameters for batcher:
21 | # ==============================================================================
22 | batcher.data_streams = @data.data_streams
23 | batcher.batch_size_per_device = 96
24 | batcher.eval_batch_size = 128
25 | batcher.max_eval_length = 2048
26 | batcher.id_to_mask = 0
27 |
28 | # Parameters for data_streams:
29 | # ==============================================================================
30 | data_streams.data_dir = None
31 | data_streams.dataset_name = 't2t_languagemodel_lm1b32k'
32 | data_streams.input_name = 'targets'
33 | data_streams.preprocess_fn = @data.lm1b_preprocess
34 |
35 | # Parameters for multifactor:
36 | # ==============================================================================
37 | multifactor.constant = 0.3
38 | multifactor.factors = 'constant * linear_warmup'
39 | multifactor.warmup_steps = 8000
40 |
41 | # Parameters for lm1b_preprocess:
42 | # ==============================================================================
43 | lm1b_preprocess.max_target_length = 512
44 | lm1b_preprocess.max_eval_target_length = 2048
45 |
46 | # Parameters for train:
47 | # ==============================================================================
48 | train.eval_frequency = 1000
49 | train.eval_steps = 10
50 | train.model = @trax.models.EveryOtherLayerDropTransformerLM
51 | train.optimizer = @trax.optimizers.SM3
52 | train.steps = 500000
53 |
54 | # Parameters for DotProductCausalAttention:
55 | # ==============================================================================
56 | DotProductCausalAttention.dropout = 0.1
57 |
58 | # Parameters for EveryOtherLayerDropTransformerLM:
59 | # ==============================================================================
60 | EveryOtherLayerDropTransformerLM.d_model = 1024
61 | EveryOtherLayerDropTransformerLM.d_ff = 4096
62 | EveryOtherLayerDropTransformerLM.dropout = 0.2
63 | EveryOtherLayerDropTransformerLM.max_len = 2048
64 | EveryOtherLayerDropTransformerLM.mode = 'train'
65 | EveryOtherLayerDropTransformerLM.n_heads = 8
66 | EveryOtherLayerDropTransformerLM.n_layers = 8
67 | EveryOtherLayerDropTransformerLM.vocab_size = 32000
68 | EveryOtherLayerDropTransformerLM.skip_mode = 'even'
69 |
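The 'even' skip mode is easiest to see in a toy form. The following is not the
trax implementation (that is trax.models.EveryOtherLayerDropTransformerLM),
just an illustration of the idea: during training, even-indexed layers may be
dropped stochastically while the remaining layers always run.

import random

def forward(x, layers, skip_mode='even', skip_prob=0.5):
  # Toy illustration only; layer indexing and probabilities are assumptions.
  for i, layer in enumerate(layers):
    skippable = (i % 2 == 0) if skip_mode == 'even' else (i % 2 == 1)
    if skippable and random.random() < skip_prob:
      continue  # a skipped layer acts as the identity on the residual stream
    x = layer(x)
  return x
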
--------------------------------------------------------------------------------
/trax/supervised/configs/layerdrop_transformer_lm1b.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models
17 | import trax.optimizers
18 | import trax.supervised.trainer_lib
19 |
20 | # Parameters for batcher:
21 | # ==============================================================================
22 | batcher.data_streams = @data.data_streams
23 | batcher.batch_size_per_device = 96
24 | batcher.eval_batch_size = 128
25 | batcher.max_eval_length = 2048
26 | batcher.id_to_mask = 0
27 |
28 | # Parameters for data_streams:
29 | # ==============================================================================
30 | data_streams.data_dir = None
31 | data_streams.dataset_name = 't2t_languagemodel_lm1b32k'
32 | data_streams.input_name = 'targets'
33 | data_streams.preprocess_fn = @data.lm1b_preprocess
34 |
35 | # Parameters for multifactor:
36 | # ==============================================================================
37 | multifactor.constant = 0.3
38 | multifactor.factors = 'constant * linear_warmup'
39 | multifactor.warmup_steps = 8000
40 |
41 | # Parameters for lm1b_preprocess:
42 | # ==============================================================================
43 | lm1b_preprocess.max_target_length = 512
44 | lm1b_preprocess.max_eval_target_length = 2048
45 |
46 | # Parameters for train:
47 | # ==============================================================================
48 | train.eval_frequency = 1000
49 | train.eval_steps = 10
50 | train.model = @trax.models.LayerDropTransformerLM
51 | train.optimizer = @trax.optimizers.SM3
52 | train.steps = 500000
53 |
54 | # Parameters for DotProductCausalAttention:
55 | # ==============================================================================
56 | DotProductCausalAttention.dropout = 0.1
57 |
58 | # Parameters for LayerDropTransformerLM:
59 | # ==============================================================================
60 | LayerDropTransformerLM.d_model = 1024
61 | LayerDropTransformerLM.d_ff = 4096
62 | LayerDropTransformerLM.dropout = 0.2
63 | LayerDropTransformerLM.max_len = 2048
64 | LayerDropTransformerLM.mode = 'train'
65 | LayerDropTransformerLM.n_heads = 8
66 | LayerDropTransformerLM.n_layers = 8
67 | LayerDropTransformerLM.vocab_size = 32000
68 | LayerDropTransformerLM.skip_fraction = 0.4
69 |
--------------------------------------------------------------------------------
/trax/supervised/configs/layerdrop_ushape_transformer_lm1b.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models
17 | import trax.optimizers
18 | import trax.supervised.trainer_lib
19 |
20 | # Parameters for batcher:
21 | # ==============================================================================
22 | batcher.data_streams = @data.data_streams
23 | batcher.batch_size_per_device = 96
24 | batcher.eval_batch_size = 128
25 | batcher.max_eval_length = 2048
26 | batcher.id_to_mask = 0
27 |
28 | # Parameters for data_streams:
29 | # ==============================================================================
30 | data_streams.data_dir = None
31 | data_streams.dataset_name = 't2t_languagemodel_lm1b32k'
32 | data_streams.input_name = 'targets'
33 | data_streams.preprocess_fn = @data.lm1b_preprocess
34 |
35 | # Parameters for multifactor:
36 | # ==============================================================================
37 | multifactor.constant = 0.3
38 | multifactor.factors = 'constant * linear_warmup'
39 | multifactor.warmup_steps = 8000
40 |
41 | # Parameters for lm1b_preprocess:
42 | # ==============================================================================
43 | lm1b_preprocess.max_target_length = 512
44 | lm1b_preprocess.max_eval_target_length = 2048
45 |
46 | # Parameters for train:
47 | # ==============================================================================
48 | train.eval_frequency = 1000
49 | train.eval_steps = 10
50 | train.model = @trax.models.LayerDropTransformerLM
51 | train.optimizer = @trax.optimizers.SM3
52 | train.steps = 500000
53 |
54 | # Parameters for DotProductCausalAttention:
55 | # ==============================================================================
56 | DotProductCausalAttention.dropout = 0.1
57 |
58 | # Parameters for LayerDropTransformerLM:
59 | # ==============================================================================
60 | LayerDropTransformerLM.d_model = 1024
61 | LayerDropTransformerLM.d_ff = 4096
62 | LayerDropTransformerLM.dropout = 0.2
63 | LayerDropTransformerLM.max_len = 2048
64 | LayerDropTransformerLM.mode = 'train'
65 | LayerDropTransformerLM.n_heads = 8
66 | LayerDropTransformerLM.n_layers = 8
67 | LayerDropTransformerLM.vocab_size = 32000
68 | LayerDropTransformerLM.skip_fraction = [0.1, 0.1, 0.3, 0.3, 0.3, 0.3, 0.1, 0.1]
69 |
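Compared with layerdrop_transformer_lm1b.gin, the only change is that
skip_fraction is a per-layer list rather than a scalar, giving the U-shape of
the file name: middle layers are dropped more often than the outer ones. A
sketch of the assumed broadcasting rule:

def per_layer_skip(skip_fraction, n_layers):
  # Assumption: a scalar applies uniformly; a list is taken per layer.
  if isinstance(skip_fraction, (int, float)):
    return [float(skip_fraction)] * n_layers
  assert len(skip_fraction) == n_layers
  return list(skip_fraction)

print(per_layer_skip(0.4, 8))
print(per_layer_skip([0.1, 0.1, 0.3, 0.3, 0.3, 0.3, 0.1, 0.1], 8))
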
--------------------------------------------------------------------------------
/trax/supervised/configs/lstm_lm1b.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models
17 | import trax.optimizers
18 | import trax.supervised.trainer_lib
19 |
20 | # Parameters for batcher:
21 | # ==============================================================================
22 | batcher.data_streams = @data.data_streams
23 | batcher.batch_size_per_device = 64
24 | batcher.eval_batch_size = 64
25 | batcher.max_eval_length = 2048
26 | batcher.id_to_mask = 0
27 |
28 | # Parameters for data_streams:
29 | # ==============================================================================
30 | data_streams.data_dir = None
31 | data_streams.dataset_name = 't2t_languagemodel_lm1b32k'
32 | data_streams.input_name = 'targets'
33 | data_streams.preprocess_fn = @data.lm1b_preprocess
34 |
35 | # Parameters for multifactor:
36 | # ==============================================================================
37 | multifactor.constant = 0.001
38 | multifactor.factors = 'constant * linear_warmup'
39 | multifactor.warmup_steps = 8000
40 |
41 | # Parameters for lm1b_preprocess:
42 | # ==============================================================================
43 | lm1b_preprocess.max_target_length = 512
44 | lm1b_preprocess.max_eval_target_length = 2048
45 |
46 | # Parameters for train:
47 | # ==============================================================================
48 | train.eval_frequency = 1000
49 | train.eval_steps = 10
50 | train.model = @trax.models.RNNLM
51 | train.optimizer = @trax.optimizers.Adam
52 | train.steps = 500000
53 |
54 | # Parameters for RNNLM:
55 | # ==============================================================================
56 | RNNLM.rnn_cell = @trax.layers.LSTMCell
57 | RNNLM.rnn_cell_d_state_multiplier = 2
58 | RNNLM.d_model = 512
59 | RNNLM.dropout = 0.1
60 | RNNLM.n_layers = 2
61 | RNNLM.vocab_size = 32000
62 |
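Outside of gin, the model this config trains can be built directly; a minimal
sketch mirroring the RNNLM bindings above (batch and sequence sizes are
arbitrary placeholders):

import numpy as np
import trax
from trax import shapes

model = trax.models.RNNLM(vocab_size=32000, d_model=512, n_layers=2,
                          rnn_cell=trax.layers.LSTMCell, dropout=0.1,
                          mode='train')
# Initialize weights for int32 token inputs of shape (batch=1, length=16).
model.init(shapes.ShapeDtype((1, 16), dtype=np.int32))
tokens = np.zeros((1, 16), dtype=np.int32)
logits = model(tokens)  # shape (1, 16, 32000): per-token log-probabilities
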
--------------------------------------------------------------------------------
/trax/supervised/configs/lstm_seq2seq_wmt_ende.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models
17 | import trax.optimizers
18 | import trax.supervised.lr_schedules
19 | import trax.supervised.trainer_lib
20 |
21 | # Parameters for batcher:
22 | # ==============================================================================
23 | batcher.data_streams = @data.data_streams
24 | batcher.batch_size_per_device = 64
25 | batcher.eval_batch_size = 64
26 | batcher.max_eval_length = 512
27 | batcher.bucket_length = 32
28 | batcher.buckets_include_inputs_in_length = True
29 | batcher.id_to_mask = 0
30 |
31 | # Parameters for data_streams:
32 | # ==============================================================================
33 | data_streams.data_dir = None
34 | data_streams.dataset_name = 't2t_translate_ende_wmt32k'
35 | data_streams.preprocess_fn = @data.wmt_preprocess
36 |
37 | # Parameters for wmt_preprocess:
38 | # ==============================================================================
39 | wmt_preprocess.max_length = 256
40 | wmt_preprocess.max_eval_length = 512
41 |
42 | # Parameters for lr_schedules.warmup:
43 | # ==============================================================================
44 | lr_schedules.warmup.max_value = 0.001
45 | lr_schedules.warmup.n_warmup_steps = 1000
46 |
47 | # Parameters for Adam:
48 | # ==============================================================================
49 | Adam.weight_decay_rate = 0.0
50 |
51 | # Parameters for train:
52 | # ==============================================================================
53 | train.eval_frequency = 1000
54 | train.eval_steps = 10
55 | train.model = @trax.models.LSTMSeq2SeqAttn
56 | train.optimizer = @trax.optimizers.Adam
57 | train.lr_schedule_fn = @lr_schedules.warmup
58 | train.steps = 250000
59 |
60 | # Parameters for LSTMSeq2SeqAttn:
61 | # ==============================================================================
62 | LSTMSeq2SeqAttn.d_model = 1024
63 | LSTMSeq2SeqAttn.n_encoder_layers = 2
64 | LSTMSeq2SeqAttn.n_decoder_layers = 2
65 | LSTMSeq2SeqAttn.attention_dropout = 0.2
66 | LSTMSeq2SeqAttn.n_attention_heads = 8
67 | LSTMSeq2SeqAttn.input_vocab_size = 33300
68 | LSTMSeq2SeqAttn.target_vocab_size = 33300
69 |
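A sketch of the warmup schedule the train block selects, assuming the usual
definition (ramp linearly to max_value over n_warmup_steps, then hold); see
trax.supervised.lr_schedules.warmup for the real one.

def warmup_lr(step, max_value=0.001, n_warmup_steps=1000):
  # Assumed definition: linear ramp, then constant at max_value.
  return max_value * min(1.0, step / n_warmup_steps)

print(warmup_lr(500))   # 0.0005: halfway up the ramp
print(warmup_lr(5000))  # 0.001: held after warmup
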
--------------------------------------------------------------------------------
/trax/supervised/configs/mlp_mnist.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models
17 | import trax.optimizers
18 | import trax.supervised.lr_schedules
19 | import trax.supervised.trainer_lib
20 |
21 | # Parameters for batcher:
22 | # ==============================================================================
23 | batcher.data_streams = @data.data_streams
24 | batcher.batch_size_per_device = 256
25 | batcher.eval_batch_size = 256
26 | batcher.variable_shapes = False
27 |
28 | # Parameters for data.data_streams:
29 | # ==============================================================================
30 | data_streams.data_dir = None
31 | data_streams.dataset_name = 'mnist'
32 |
33 | # Parameters for MLP:
34 | # ==============================================================================
35 | MLP.layer_widths = (128, 64)
36 |
37 | # Parameters for lr_schedules.constant:
38 | # ==============================================================================
39 | lr_schedules.constant.value = 0.0001
40 |
41 | # Parameters for train:
42 | # ==============================================================================
43 | train.optimizer = @trax.optimizers.Adam
44 | train.eval_frequency = 200
45 | train.eval_steps = 10
46 | train.model = @trax.models.MLP
47 | train.steps = 2000
48 | train.checkpoint_highest = 'accuracy'
49 | train.lr_schedule_fn = @lr_schedules.constant
50 |
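For reference, the configured model is small enough to inspect directly; a
sketch assuming only the layer_widths binding above, with everything else at
its defaults:

import trax

model = trax.models.MLP(layer_widths=(128, 64))
print(model)  # prints the resulting Serial combinator, layer by layer
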
--------------------------------------------------------------------------------
/trax/supervised/configs/reformer_cifar10.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models
17 | import trax.optimizers
18 | import trax.supervised.trainer_lib
19 |
20 | # Parameters that will vary between experiments:
21 | # ==============================================================================
22 | train.model = @trax.models.ReformerLM
23 | attn_kv = 64
24 | n_layers = 9
25 |
26 | # Parameters for batcher:
27 | # ==============================================================================
28 | batcher.data_streams = @data.data_streams
29 | batcher.batch_size_per_device = 1
30 | batcher.eval_batch_size = 8
31 | batcher.max_eval_length = 12288 # 64 * 64 * 3
32 |
33 | # Parameters for data_streams:
34 | # ==============================================================================
35 | data_streams.data_dir = None
36 | data_streams.dataset_name = 'cifar10'
37 | data_streams.preprocess_fn = @data.cifar10_augmentation_flatten_preprocess
38 |
39 | # Parameters for multifactor:
40 | # ==============================================================================
41 | multifactor.constant = 1.0
42 | multifactor.factors = 'constant * linear_warmup * rsqrt_decay'
43 | multifactor.warmup_steps = 8000
44 |
45 | # Parameters for train:
46 | # ==============================================================================
47 | train.eval_frequency = 500
48 | train.eval_steps = 64
49 | # train.model: see top
50 | train.optimizer = @trax.optimizers.Adafactor
51 | train.steps = 500000
52 | train.save_graphs = False
53 | train.checkpoints_at = \
54 | [1000, 5000, 10000, 20000, 40000, 60000, 80000,
55 | 100000, 200000, 300000, 400000, 500000]
56 |
57 | # Parameters for ReformerLM:
58 | # ==============================================================================
59 | ReformerLM.d_attention_key = %attn_kv
60 | ReformerLM.d_attention_value = %attn_kv
61 | ReformerLM.d_model = 1024
62 | ReformerLM.d_ff = 4096
63 | ReformerLM.dropout = 0.1
64 | ReformerLM.max_len = 12288 # 64 * 64 * 3
65 | ReformerLM.mode = 'train'
66 | ReformerLM.n_heads = 8
67 | ReformerLM.n_layers = %n_layers
68 | ReformerLM.vocab_size = 256
69 |
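The attn_kv and n_layers lines at the top are gin macros, referenced below as
%attn_kv and %n_layers; they can be rebound without editing this file. A
sketch, assuming the config path below:

import gin

# Path is a placeholder for wherever the config lives in your checkout.
gin.parse_config_file('trax/supervised/configs/reformer_cifar10.gin')
gin.parse_config("""
attn_kv = 32
n_layers = 6
""")  # later bindings override the macro values set in the file
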
--------------------------------------------------------------------------------
/trax/supervised/configs/resnet50_frn_imagenet_8gb.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.supervised.lr_schedules
17 | import trax.models
18 | import trax.optimizers
19 | import trax.supervised.trainer_lib
20 |
21 | # Parameters for batcher:
22 | # ==============================================================================
23 | batcher.data_streams = @data.data_streams
24 | batcher.batch_size_per_device = 32
25 | batcher.bucket_length = 32
26 | batcher.buckets = None
27 | batcher.eval_batch_size = 32
28 | batcher.variable_shapes = False
29 |
30 | # Parameters for data_streams:
31 | # ==============================================================================
32 | data_streams.data_dir = None
33 | data_streams.dataset_name = 't2t_image_imagenet224'
34 |
35 | # Parameters for FilterResponseNorm:
36 | # ==============================================================================
37 | # TODO(afrozm): Make this work in the learn_epsilon = True setting.
38 | FilterResponseNorm.learn_epsilon = False
39 |
40 | # Parameters for multifactor:
41 | # ==============================================================================
42 | multifactor.constant = 30.0
43 | multifactor.factors = 'constant * linear_warmup * rsqrt_decay'
44 | multifactor.warmup_steps = 30000
45 |
46 | # Parameters for Momentum:
47 | # ==============================================================================
48 | Momentum.mass = 0.9
49 |
50 | # Parameters for Resnet50:
51 | # ==============================================================================
52 | Resnet50.d_hidden = 64
53 | Resnet50.n_output_classes = 1001
54 | Resnet50.norm = @trax.layers.FilterResponseNorm
55 | Resnet50.non_linearity = @trax.layers.ThresholdedLinearUnit
56 |
57 | # Parameters for train:
58 | # ==============================================================================
59 | train.eval_frequency = 2000
60 | train.eval_steps = 20
61 | train.model = @trax.models.Resnet50
62 | train.optimizer = @trax.optimizers.Momentum
63 | train.steps = 300000
64 |
--------------------------------------------------------------------------------
/trax/supervised/configs/resnet50_imagenet_8gb_testing.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.supervised.lr_schedules
17 | import trax.models
18 | import trax.optimizers
19 | import trax.supervised.trainer_lib
20 |
21 | # Parameters for batcher:
22 | # ==============================================================================
23 | batcher.data_streams = @data.data_streams
24 | batcher.batch_size_per_device = 32
25 | batcher.eval_batch_size = 32
26 | batcher.variable_shapes = False
27 |
28 | # Parameters for data_streams:
29 | # ==============================================================================
30 | data_streams.data_dir = None
31 | data_streams.dataset_name = 't2t_image_imagenet224'
32 | data_streams.preprocess_fn = @data.squeeze_targets_preprocess
33 |
34 | # Parameters for multifactor:
35 | # ==============================================================================
36 | multifactor.factors = 'constant * linear_warmup'
37 | multifactor.constant = 0.2
38 | multifactor.warmup_steps = 400
39 |
40 | # Parameters for Momentum:
41 | # ==============================================================================
42 | Momentum.mass = 0.9
43 |
44 | # Parameters for Resnet50:
45 | # ==============================================================================
46 | Resnet50.d_hidden = 64
47 | Resnet50.n_output_classes = 1001
48 |
49 | # Parameters for train:
50 | # ==============================================================================
51 | train.eval_frequency = 2000
52 | train.eval_steps = 20
53 | train.model = @trax.models.Resnet50
54 | train.optimizer = @trax.optimizers.Momentum
55 | train.steps = 100000
56 |
--------------------------------------------------------------------------------
/trax/supervised/configs/rse_addition_sweep.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # multifactor.warmup_steps: [4000, 5000, 6000]
16 | # multifactor.constant: [1.0e-1, 5.0e-1, 1.0, 5.0]
17 |
18 | # ResidualShuffleExchange.n_blocks: [1, 2]
19 | ResidualShuffleExchange.dropout: [0.0, 0.05, 0.1]
20 | ResidualShuffleExchange.input_dropout: [0.0, 0.05, 0.1]
21 |
22 | # WeightedCategoryCrossEntropy.label_smoothing: [0., 0.1, 0.05, 0.01, 0.005]
23 |
24 | # train.optimizer: ["%Adam", "%Adafactor"]
25 | # Adafactor.do_momentum: [True, False]
26 |
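These .yaml files are hyperparameter sweeps: each key lists candidate values,
and commented-out keys are disabled. The sweep runner itself is not part of
this snapshot, but a hypothetical expansion into gin overrides is just a
cartesian product of the active value lists:

import itertools

# The two uncommented keys from the sweep file above.
sweep = {
    'ResidualShuffleExchange.dropout': [0.0, 0.05, 0.1],
    'ResidualShuffleExchange.input_dropout': [0.0, 0.05, 0.1],
}
keys = sorted(sweep)
for values in itertools.product(*(sweep[k] for k in keys)):
  print([f'{k} = {v}' for k, v in zip(keys, values)])  # 9 configurations
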
--------------------------------------------------------------------------------
/trax/supervised/configs/skipping_transformer_lm1b.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models
17 | import trax.optimizers
18 | import trax.supervised.trainer_lib
19 |
20 | # Parameters for batcher:
21 | # ==============================================================================
22 | batcher.data_streams = @data.data_streams
23 | batcher.batch_size_per_device = 96
24 | batcher.eval_batch_size = 128
25 | batcher.max_eval_length = 2048
26 | batcher.id_to_mask = 0
27 |
28 | # Parameters for data_streams:
29 | # ==============================================================================
30 | data_streams.data_dir = None
31 | data_streams.dataset_name = 't2t_languagemodel_lm1b32k'
32 | data_streams.input_name = 'targets'
33 | data_streams.preprocess_fn = @data.lm1b_preprocess
34 |
35 | # Parameters for multifactor:
36 | # ==============================================================================
37 | multifactor.constant = 0.3
38 | multifactor.factors = 'constant * linear_warmup'
39 | multifactor.warmup_steps = 8000
40 |
41 | # Parameters for lm1b_preprocess:
42 | # ==============================================================================
43 | lm1b_preprocess.max_target_length = 512
44 | lm1b_preprocess.max_eval_target_length = 2048
45 |
46 | # Parameters for train:
47 | # ==============================================================================
48 | train.eval_frequency = 1000
49 | train.eval_steps = 10
50 | train.model = @trax.models.SkippingTransformerLM
51 | train.optimizer = @trax.optimizers.SM3
52 | train.steps = 500000
53 |
54 | # Parameters for DotProductCausalAttention:
55 | # ==============================================================================
56 | DotProductCausalAttention.dropout = 0.1
57 |
58 | # Parameters for SkippingTransformerLM:
59 | # ==============================================================================
60 | SkippingTransformerLM.d_model = 1024
61 | SkippingTransformerLM.d_ff = 4096
62 | SkippingTransformerLM.dropout = 0.2
63 | SkippingTransformerLM.max_len = 2048
64 | SkippingTransformerLM.mode = 'train'
65 | SkippingTransformerLM.n_heads = 8
66 | SkippingTransformerLM.n_layers = 8
67 | SkippingTransformerLM.vocab_size = 32000
68 |
--------------------------------------------------------------------------------
/trax/supervised/configs/sp_sweep.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | ConfigurableTerraformer.d_ff: [2048, 4096]
16 | ConfigurableTerraformer.ff_chunk_size: [0, 2048, 4096, 8192]
17 | LSHSelfAttention.chunk_len: [64, 128, 256]
18 | LSHSelfAttention.n_buckets: [128, 256, 512]
19 | multifactor.constant: [0.03125, 0.0442, 1.0]
20 |
--------------------------------------------------------------------------------
/trax/supervised/configs/t5_mathqa_drop_sweep.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | Adafactor.weight_decay_rate: [0, ]
16 | multifactor.constant: [0.001, ]
17 | CreateAnnotatedDropInputs.percentile: [0.01, 0.05, 0.1, 0.25, 0.5, 1]
18 |
--------------------------------------------------------------------------------
/trax/supervised/configs/t5_sweep.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | dataset_name: [
16 | 'glue/mrpc',
17 | 'glue/sst2',
18 | 'glue/qqp',
19 | 'glue/mnli',
20 | 'glue/qnli',
21 | 'glue/rte',
22 | ]
23 | Adafactor.weight_decay_rate: [0, ]
24 | multifactor.constant: [0.001, 0.0001, 0.00001]
25 |
26 |
27 |
--------------------------------------------------------------------------------
/trax/supervised/configs/t5_sweep_temperature.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | dataset_name: [
16 | 'glue/mrpc',
17 | 'glue/sst2',
18 | 'glue/qqp',
19 | 'glue/mnli',
20 | 'glue/qnli',
21 | 'glue/rte',
22 | ]
23 | ConfigurableTransformer.ff_sparsity: ['64 64 0.0 1.0', '64 64 0.0 0.0']
24 |
--------------------------------------------------------------------------------
/trax/supervised/configs/terraformer_copy.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # -*-Python-*-
16 |
17 | include 'reformer_copy.gin'
18 |
19 | import trax.data
20 | import trax.models
21 | import trax.optimizers
22 | import trax.supervised.trainer_lib
23 |
24 | # Parameters for ConfigurableTerraformer:
25 | # ==============================================================================
26 | ConfigurableTerraformer.d_model = 256
27 | ConfigurableTerraformer.d_ff = 512
28 | ConfigurableTerraformer.dropout = 0.05
29 | ConfigurableTerraformer.max_len = %max_len
30 | ConfigurableTerraformer.n_heads = 4
31 | ConfigurableTerraformer.mode = 'train'
32 | ConfigurableTerraformer.n_encoder_layers = 3
33 | ConfigurableTerraformer.n_decoder_layers = 3
34 | ConfigurableTerraformer.ff_use_sru = 0
35 | ConfigurableTerraformer.d_attention_key = 64
36 | ConfigurableTerraformer.d_attention_value = 64
37 | ConfigurableTerraformer.encoder_attention_type = @LSHSelfAttention
38 | ConfigurableTerraformer.encoder_decoder_attention_type = @LSHSelfAttention
39 | ConfigurableTerraformer.n_decoder_attention_layers = 1
40 | ConfigurableTerraformer.input_vocab_size = %vocab_size
41 | ConfigurableTerraformer.pos_type = 'fixed-base'
42 |
43 | # Parameters for train:
44 | # ==============================================================================
45 | train.inputs = @trax.data.simple_sequence_copy_inputs
46 | train.model = @trax.models.ConfigurableTerraformer
47 |
--------------------------------------------------------------------------------
/trax/supervised/configs/terraformer_copy_self_attn.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # -*-Python-*-
16 |
17 | include 'terraformer_copy.gin'
18 |
19 | import trax.models
20 |
21 |
22 | # Parameters for ConfigurableTerraformer:
23 | # ==============================================================================
24 | ConfigurableTerraformer.encoder_attention_type = @SelfAttention
25 | ConfigurableTerraformer.encoder_decoder_attention_type = @SelfAttention
26 |
27 | # Parameters for SelfAttention:
28 | # ==============================================================================
29 | # 0 < predict_drop_len <= predict_mem_len
30 | SelfAttention.predict_mem_len = %max_len
31 | SelfAttention.predict_drop_len = %max_len
32 |
--------------------------------------------------------------------------------
/trax/supervised/configs/terraformer_purelsh_copy.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # -*-Python-*-
16 |
17 | include 'terraformer_copy.gin'
18 |
19 | import trax.data
20 | import trax.models
21 | import trax.optimizers
22 | import trax.supervised.trainer_lib
23 |
24 |
25 | # Parameters for PureLSHSelfAttention:
26 | # ==============================================================================
27 | PureLSHSelfAttention.attention_dropout = 0.0
28 | PureLSHSelfAttention.chunk_len = 16
29 | PureLSHSelfAttention.n_buckets = [32, 32]
30 | PureLSHSelfAttention.n_chunks_after = 0
31 | PureLSHSelfAttention.n_chunks_before = 1
32 | PureLSHSelfAttention.n_hashes = 2
33 | PureLSHSelfAttention.n_parallel_heads = 1
34 | PureLSHSelfAttention.max_length_for_buckets = 1024
35 | # 0 < predict_drop_len <= predict_mem_len
36 | PureLSHSelfAttention.predict_mem_len = %max_len
37 | PureLSHSelfAttention.predict_drop_len = %max_len
38 |
39 | # Parameters for PureLSHSelfAttentionWrapper:
40 | # ==============================================================================
41 | PureLSHSelfAttentionWrapper.pure_lsh_implementation = @PureLSHSelfAttention
42 |
43 | # We need something special for the encoder.
44 | enc/PureLSHSelfAttention.n_chunks_after = 1
45 | encoder/PureLSHSelfAttentionWrapper.pure_lsh_implementation = @enc/PureLSHSelfAttention
46 |
47 | # Parameters for ConfigurableTerraformer:
48 | # ==============================================================================
49 | ConfigurableTerraformer.encoder_attention_type = @encoder/PureLSHSelfAttentionWrapper
50 | ConfigurableTerraformer.encoder_decoder_attention_type = @PureLSHSelfAttentionWrapper
51 |
52 | # Parameters for train:
53 | # ==============================================================================
54 | train.inputs = @trax.data.simple_sequence_copy_inputs
55 | train.model = @trax.models.ConfigurableTerraformer
56 |
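The enc/ and encoder/ prefixes above are gin scopes: a scoped reference such
as @enc/PureLSHSelfAttention sees the base bindings plus any enc/-prefixed
overrides, which is how the encoder gets n_chunks_after = 1 while the decoder
keeps 0. A self-contained sketch of the mechanism, using a toy configurable:

import gin

@gin.configurable
def attention(n_chunks_before=0, n_chunks_after=0):
  # Toy stand-in for PureLSHSelfAttention, for demonstrating scopes only.
  return (n_chunks_before, n_chunks_after)

gin.parse_config("""
attention.n_chunks_before = 1
attention.n_chunks_after = 0
enc/attention.n_chunks_after = 1
""")

print(attention())            # (1, 0): causal, decoder-style
with gin.config_scope('enc'):
  print(attention())          # (1, 1): the encoder looks one chunk ahead
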
--------------------------------------------------------------------------------
/trax/supervised/configs/transformer_big_lm1b_8gb.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models
17 | import trax.optimizers
18 | import trax.supervised.trainer_lib
19 |
20 | # Parameters for batcher:
21 | # ==============================================================================
22 | batcher.data_streams = @data.data_streams
23 | batcher.batch_size_per_device = 32
24 | batcher.eval_batch_size = 64
25 | batcher.max_eval_length = 512
26 | batcher.id_to_mask = 0
27 |
28 | # Parameters for data_streams:
29 | # ==============================================================================
30 | data_streams.data_dir = None
31 | data_streams.dataset_name = 't2t_languagemodel_lm1b32k'
32 | data_streams.input_name = 'targets'
33 | data_streams.preprocess_fn = @data.lm1b_preprocess
34 |
35 | # Parameters for multifactor:
36 | # ==============================================================================
37 | multifactor.constant = 0.1
38 | multifactor.factors = 'constant * linear_warmup'
39 | multifactor.warmup_steps = 16000
40 |
41 | # Parameters for lm1b_preprocess:
42 | # ==============================================================================
43 | lm1b_preprocess.max_target_length = 512
44 | lm1b_preprocess.max_eval_target_length = 512
45 |
46 | # Parameters for train:
47 | # ==============================================================================
48 | train.eval_frequency = 1000
49 | train.eval_steps = 10
50 | train.model = @trax.models.TransformerLM
51 | train.optimizer = @trax.optimizers.SM3
52 | train.steps = 500000
53 |
54 | # Parameters for TransformerLM:
55 | # ==============================================================================
56 | TransformerLM.d_model = 1024
57 | TransformerLM.d_ff = 8192
58 | TransformerLM.dropout = 0.1
59 | TransformerLM.max_len = 2048
60 | TransformerLM.mode = 'train'
61 | TransformerLM.n_heads = 8
62 | TransformerLM.n_layers = 8
63 | TransformerLM.vocab_size = 32000
64 |
--------------------------------------------------------------------------------
/trax/supervised/configs/transformer_finetune_squad_16gb.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | include 'c4_pretrain_16gb_adafactor.gin'
17 |
18 |
19 | # Parameters for train:
20 | # ==============================================================================
21 | # This should always be overridden since this is relative to the pre-trained
22 | # checkpoint.
23 | train.steps = 262144
24 |
25 |
26 | # Parameters for batcher:
27 | # ==============================================================================
28 | batcher.buckets = ([513,], [8, 8])
29 | batcher.strict_pad_on_len = True
30 | batcher.buckets_include_inputs_in_length = True
31 |
32 |
33 | # Parameters for multifactor:
34 | # ==============================================================================
35 | multifactor.factors = 'constant'
36 | multifactor.constant = 0.001
37 |
38 |
39 | # Parameters for data_streams:
40 | # ==============================================================================
41 | data_streams.dataset_name = 'squad/plain_text:1.0.0'
42 | data_streams.bare_preprocess_fn = @data.generic_text_dataset_preprocess_fn
43 |
44 |
45 | # Parameters for get_t5_preprocessor_by_name:
46 | # ==============================================================================
47 | squad/get_t5_preprocessor_by_name.name = 'squad'
48 | squad/get_t5_preprocessor_by_name.fn_kwargs = {'include_context': True}
49 |
50 |
51 | # Parameters for generic_text_dataset_preprocess_fn:
52 | # ==============================================================================
53 | generic_text_dataset_preprocess_fn.text_preprocess_fns = [
54 | @squad/get_t5_preprocessor_by_name()
55 | ]
56 | generic_text_dataset_preprocess_fn.token_preprocess_fns = [
57 | @data.add_eos_to_output_features
58 | ]
59 |
--------------------------------------------------------------------------------
/trax/supervised/configs/transformer_lm1b_8gb_testing.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models
17 | import trax.optimizers
18 | import trax.supervised.trainer_lib
19 |
20 | # Module trax.data:
21 | # ==============================================================================
22 | batcher.data_streams = @data.data_streams
23 | batcher.batch_size_per_device = 128
24 | batcher.eval_batch_size = 128
25 | batcher.max_eval_length = 2048
26 | batcher.id_to_mask = 0
27 |
28 | data_streams.data_dir = None
29 | data_streams.dataset_name = 't2t_languagemodel_lm1b32k'
30 | data_streams.input_name = 'targets'
31 | data_streams.preprocess_fn = @data.lm1b_preprocess
32 |
33 | lm1b_preprocess.max_target_length = 512
34 | lm1b_preprocess.max_eval_target_length = 2048
35 |
36 | # Module trax.models.transformer:
37 | # ==============================================================================
38 | TransformerLM.d_model = 512
39 | TransformerLM.d_ff = 2048
40 | TransformerLM.dropout = 0.1
41 | TransformerLM.max_len = 2048
42 | TransformerLM.mode = 'train'
43 | TransformerLM.n_heads = 8
44 | TransformerLM.n_layers = 6
45 | TransformerLM.vocab_size = 32000
46 |
47 | # Module trax.supervised.lr_schedules:
48 | # ==============================================================================
49 | multifactor.constant = 0.1
50 | multifactor.factors = 'constant * linear_warmup * rsqrt_decay'
51 | multifactor.warmup_steps = 8000
52 |
53 | # Module trax.supervised.trainer_lib:
54 | # ==============================================================================
55 | train.eval_frequency = 1000
56 | train.eval_steps = 10
57 | train.model = @trax.models.TransformerLM
58 | train.optimizer = @trax.optimizers.Adam
59 | train.steps = 100000
60 |
--------------------------------------------------------------------------------
/trax/supervised/configs/transformer_ptb_16gb.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.models
17 | import trax.optimizers
18 | import trax.supervised.lr_schedules
19 | import trax.supervised.trainer_lib
20 |
21 | # Parameters for batcher:
22 | # ==============================================================================
23 | batcher.data_streams = @data.data_streams
24 | batcher.batch_size_per_device = 64
25 | batcher.eval_batch_size = 512
26 | batcher.max_eval_length = 2048
27 | batcher.id_to_mask = 0
28 |
29 | # Parameters for data_streams:
30 | # ==============================================================================
31 | data_streams.data_dir = None
32 | data_streams.dataset_name = 't2t_languagemodel_ptb10k'
33 | data_streams.input_name = 'targets'
34 | data_streams.preprocess_fn = @data.lm1b_preprocess
35 |
36 | # Parameters for multifactor:
37 | # ==============================================================================
38 | multifactor.constant = 1.0
39 | multifactor.factors = 'constant * linear_warmup * rsqrt_decay'
40 | multifactor.warmup_steps = 8000
41 |
42 | # Parameters for lm1b_preprocess:
43 | # ==============================================================================
44 | lm1b_preprocess.max_target_length = 512
45 | lm1b_preprocess.max_eval_target_length = 2048
46 |
47 | # Parameters for train:
48 | # ==============================================================================
49 | train.eval_frequency = 200
50 | train.eval_steps = 2
51 | train.model = @trax.models.TransformerLM
52 | train.optimizer = @trax.optimizers.Adafactor
53 | train.steps = 20000
54 |
55 | # Parameters for TransformerLM:
56 | # ==============================================================================
57 | TransformerLM.d_model = 512
58 | TransformerLM.d_ff = 2048
59 | TransformerLM.dropout = 0.5
60 | TransformerLM.max_len = 2048
61 | TransformerLM.mode = 'train'
62 | TransformerLM.n_heads = 8
63 | TransformerLM.n_layers = 6
64 | TransformerLM.vocab_size = 10240
65 |
--------------------------------------------------------------------------------
/trax/supervised/configs/wide_resnet_cifar10_8gb.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Trax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import trax.data
16 | import trax.supervised.lr_schedules
17 | import trax.models
18 | import trax.optimizers
19 | import trax.supervised.trainer_lib
20 |
21 | # Parameters for batcher:
22 | # ==============================================================================
23 | batcher.data_streams = @data.data_streams
24 | batcher.batch_size_per_device = 256
25 | batcher.bucket_length = 32
26 | batcher.buckets = None
27 | batcher.eval_batch_size = 512
28 | batcher.variable_shapes = False
29 |
30 | # Parameters for data_streams:
31 | # ==============================================================================
32 | data_streams.data_dir = None
33 | data_streams.dataset_name = 'cifar10'
34 | data_streams.preprocess_fn = @data.cifar10_augmentation_preprocess
35 |
36 | # Parameters for multifactor:
37 | # ==============================================================================
38 | multifactor.factors = 'constant * linear_warmup'
39 | multifactor.constant = 0.5
40 | multifactor.warmup_steps = 400
41 |
42 | # Parameters for Momentum:
43 | # ==============================================================================
44 | Momentum.mass = 0.9
45 | Momentum.weight_decay_rate = 5e-4
46 |
47 | # Parameters for WideResnet:
48 | # ==============================================================================
49 | WideResnet.widen_factor = 10
50 | WideResnet.n_blocks = 4
51 | WideResnet.n_output_classes = 10
52 |
53 | # Parameters for train:
54 | # ==============================================================================
55 | train.eval_frequency = 100
56 | train.eval_steps = 10
57 | train.model = @trax.models.WideResnet
58 | train.optimizer = @trax.optimizers.Momentum
59 | train.steps = 10000
60 |
--------------------------------------------------------------------------------
/trax/supervised/history_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for trax.supervised.history."""
17 |
18 | from absl.testing import absltest
19 |
20 | from trax.supervised import history as trax_history
21 |
22 |
23 | class HistoryTest(absltest.TestCase):
24 |
25 | def test_unknown_mode(self):
26 | history = trax_history.History()
27 | history.append('train', 'metric1', 1, 0.1)
28 | self.assertEqual(history.get('unknown_mode', 'metric1'), [])
29 |
30 | def test_unknown_metric(self):
31 | history = trax_history.History()
32 | history.append('train', 'metric1', 1, 0.1)
33 | self.assertEqual(history.get('train', 'unknown_metric'), [])
34 |
35 | def test_serializer_and_deserializer(self):
36 | history = trax_history.History()
37 | history.append('train', 'metric1', 1, 0.1)
38 | json_object = history.to_dict()
39 | history2 = trax_history.History.from_dict(json_object)
40 | self.assertEqual(history2.get('train', 'metric1'), [(1, 0.1)])
41 |
42 | def test_modes(self):
43 | history = trax_history.History()
44 | history.append('train', 'metric1', 1, 0.1)
45 | history.append('test', 'metric2', 2, 0.2)
46 | self.assertEqual(history.modes, ['test', 'train'])
47 |
48 | def test_metrics_for_mode(self):
49 | history = trax_history.History()
50 | history.append('train', 'metric1', 1, 0.1)
51 | history.append('train', 'metric2', 2, 0.2)
52 | self.assertEqual(history.metrics_for_mode('train'), ['metric1', 'metric2'])
53 |
54 |
55 | if __name__ == '__main__':
56 | absltest.main()
57 |
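The serializer round-trip tested above is what lets training checkpoints carry their metric history. A minimal sketch of persisting a History through JSON on disk; the path is a placeholder, and it assumes `to_dict()` produces JSON-serializable data (as the `json_object` name in the test suggests):

import json

from trax.supervised import history as trax_history

h = trax_history.History()
h.append('train', 'loss', 1, 0.25)  # mode, metric, step, value

with open('/tmp/history.json', 'w') as f:
  json.dump(h.to_dict(), f)
with open('/tmp/history.json') as f:
  restored = trax_history.History.from_dict(json.load(f))

print(restored.get('train', 'loss'))  # the (1, 0.25) entry recorded above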
--------------------------------------------------------------------------------
/trax/supervised/testdata/reformerlm_copy_lsh_attn.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/supervised/testdata/reformerlm_copy_lsh_attn.pkl.gz
--------------------------------------------------------------------------------
/trax/supervised/testdata/terraformer_copy_lsh_attn.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/supervised/testdata/terraformer_copy_lsh_attn.pkl.gz
--------------------------------------------------------------------------------
/trax/supervised/testdata/terraformer_copy_self_attn.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/supervised/testdata/terraformer_copy_self_attn.pkl.gz
--------------------------------------------------------------------------------
/trax/supervised/testdata/terraformer_purelsh_copy.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/supervised/testdata/terraformer_purelsh_copy.pkl.gz
--------------------------------------------------------------------------------
/trax/supervised/testdata/transformer_copy.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/supervised/testdata/transformer_copy.pkl.gz
--------------------------------------------------------------------------------
/trax/supervised/testdata/transformerlm_copy.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/trax/7c189c2c3a37851df7fc641b322a99bbe9f8aa10/trax/supervised/testdata/transformerlm_copy.pkl.gz
--------------------------------------------------------------------------------
/trax/test_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A few utilities for tests."""
17 |
18 | import sys
19 |
20 | from absl import flags
21 |
22 | FLAGS = flags.FLAGS
23 |
24 |
25 | # pytest doesn't run tests through a main() entry point, so absl flags are
26 | # never parsed. If a test requires a flag, this parses the flags manually
27 | # (from sys.argv) when needed and asserts that the desired flag exists.
28 | def ensure_flag(flag_str):
29 | try:
30 | getattr(FLAGS, flag_str)
31 | except flags.UnparsedFlagAccessError:
32 | # Manually parse flags.
33 | FLAGS(sys.argv)
34 | finally:
35 | assert getattr(FLAGS, flag_str)
36 |
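A sketch of the intended use under pytest; the test class is hypothetical, and `test_tmpdir` is absl's standard testing flag:

from absl import flags
from absl.testing import absltest

from trax import test_utils

FLAGS = flags.FLAGS


class ExampleTest(absltest.TestCase):  # hypothetical test

  def setUp(self):
    super().setUp()
    # Under pytest there is no absltest.main() to parse sys.argv, so the
    # first FLAGS access would raise UnparsedFlagAccessError without this.
    test_utils.ensure_flag('test_tmpdir')

  def test_flag_is_available(self):
    self.assertTrue(FLAGS.test_tmpdir)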
--------------------------------------------------------------------------------
/trax/tf_numpy/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/trax/tf_numpy/examples/mnist/train_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Test that the example training script works on fake data."""
17 | import mock
18 | import numpy as np
19 | import tensorflow.compat.v2 as tf
20 |
21 | from trax.tf_numpy.examples.mnist import dataset
22 | from trax.tf_numpy.examples.mnist import train
23 |
24 |
25 | class TFNumpyMnistExampleTest(tf.test.TestCase):
26 |
27 | def testRuns(self):
28 | with mock.patch.object(dataset, 'load', new=fake_mnist_data):
29 | train.train(
30 | batch_size=1,
31 | learning_rate=0.1,
32 | num_training_iters=10,
33 | validation_steps=5)
34 | train.train(
35 | batch_size=2,
36 | learning_rate=0.1,
37 | num_training_iters=5,
38 | validation_steps=2)
39 | train.train(
40 | batch_size=10,
41 | learning_rate=0.1,
42 | num_training_iters=1,
43 | validation_steps=1)
44 |
45 |
46 | def fake_mnist_data():
47 |
48 | def gen_examples(num_examples):
49 | x = np.asarray(
50 | np.random.randn(num_examples, 784), dtype=np.float32)
51 | y = np.zeros((num_examples, 10), dtype=np.float32)
52 | y[:, 0] = 1.  # One-hot labels; `y[:][0]` would only set the first row.
53 | return (x, y)
54 |
55 | return (gen_examples(100), gen_examples(10), gen_examples(10))
56 |
57 |
58 | if __name__ == '__main__':
59 | tf.compat.v1.enable_eager_execution()
60 | tf.test.main()
61 |
--------------------------------------------------------------------------------
/trax/tf_numpy/extensions/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """JAX-like function transformations and extensions for TF-numpy."""
17 |
18 | # pylint: disable=wildcard-import
19 | from trax.tf_numpy.extensions.extensions import *
20 | # pylint: enable=wildcard-import
21 |
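The wildcard import re-exports the transformation API defined in extensions.py. A sketch of the JAX-like usage it enables, assuming `grad` and `jit` are among the exported names (they are not shown in this excerpt):

import trax.tf_numpy.numpy as tf_np
from trax.tf_numpy import extensions


def loss(x):
  return tf_np.sum(x * x)


# Differentiate and compile, JAX-style (names assumed, see above).
grad_fn = extensions.grad(loss)
fast_grad_fn = extensions.jit(grad_fn)
print(fast_grad_fn(tf_np.asarray([1.0, 2.0])))  # expected: [2. 4.]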
--------------------------------------------------------------------------------
/trax/tf_numpy/numpy/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """NumPy like wrapper for Tensorflow."""
17 |
18 | # pylint: disable=wildcard-import
19 | # pylint: disable=g-import-not-at-top
20 | # pylint: disable=g-direct-tensorflow-import
21 |
22 | try:
23 | # Note that this import will work in tf-nightly and TF versions 2.4 and
24 | # higher.
25 | from tensorflow.experimental.numpy import *
26 | # TODO(agarwal): get rid of following imports.
27 | from tensorflow.experimental.numpy import random
28 | from tensorflow import bfloat16
29 | import numpy as onp
30 | from tensorflow.python.ops.numpy_ops.np_dtypes import canonicalize_dtype
31 | from tensorflow.python.ops.numpy_ops.np_dtypes import default_float_type
32 | from tensorflow.python.ops.numpy_ops.np_dtypes import is_allow_float64
33 | from tensorflow.python.ops.numpy_ops.np_dtypes import set_allow_float64
34 |
35 | random.DEFAULT_RANDN_DTYPE = onp.float32
36 | except ImportError:
37 | try:
38 | # Note that this import will work in TF 2.3 and higher.
39 | from tensorflow.python.ops.numpy_ops import *
40 | from tensorflow import bfloat16
41 |
42 | except ImportError:
43 | # Note that this fallback will be needed for TF 2.2.
44 | from tensorflow import newaxis
45 |
46 | from trax.tf_numpy.numpy_impl import random
47 |
48 | # pylint: disable=wildcard-import
49 | from trax.tf_numpy.numpy_impl.array_ops import *
50 | from trax.tf_numpy.numpy_impl.arrays import *
51 | from trax.tf_numpy.numpy_impl.dtypes import *
52 | from trax.tf_numpy.numpy_impl.math_ops import *
53 | from trax.tf_numpy.numpy_impl.utils import finfo
54 | from trax.tf_numpy.numpy_impl.utils import promote_types
55 | from trax.tf_numpy.numpy_impl.utils import result_type
56 | # pylint: enable=wildcard-import
57 |
58 | max = amax # pylint: disable=redefined-builtin,undefined-variable
59 | min = amin # pylint: disable=redefined-builtin,undefined-variable
60 | round = around # pylint: disable=redefined-builtin,undefined-variable
61 |
62 | try:
63 | from tensorflow.python.ops.numpy_ops.np_config import enable_numpy_behavior
64 | # TODO(b/171429739): This should be moved to every individual file/test.
65 | enable_numpy_behavior()
66 |
67 | except ImportError:
68 | pass
69 |
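Whichever import branch succeeds, the module is meant to behave as a drop-in numpy namespace. A small smoke test, with outputs hedged since they depend on the branch taken:

import trax.tf_numpy.numpy as tf_np

a = tf_np.asarray([[1.0, 2.0], [3.0, 4.0]])
print(tf_np.max(a))  # 4.0 -- `max` aliases `amax` in the fallback branch
print(tf_np.sum(a))  # 10.0
print(a.dtype)       # float64 or float32, depending on allow_float64 settings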
--------------------------------------------------------------------------------
/trax/tf_numpy/numpy_impl/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """NumPy API. Deprecated."""
17 |
--------------------------------------------------------------------------------
/trax/tf_numpy/numpy_impl/dtypes.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Dtypes and dtype utilities."""
17 | import numpy as np
18 |
19 | # We use numpy's dtypes instead of TF's, because the user expects to use them
20 | # with numpy facilities such as `np.dtype(np.int64)` and
21 | # `if x.dtype.type is np.int64`.
22 | # pylint: disable=unused-import
23 | # pylint: disable=g-bad-import-order
24 | from numpy import bool_
25 | from numpy import int_
26 | from numpy import int16
27 | from numpy import int32
28 | from numpy import int64
29 | from numpy import int8
30 | from numpy import uint16
31 | from numpy import uint32
32 | from numpy import uint64
33 | from numpy import uint8
34 | from numpy import float16
35 | from numpy import float32
36 | from numpy import float64
37 | float_ = float64
38 | from numpy import complex64
39 | from numpy import complex128
40 | complex_ = complex128
41 |
42 | from numpy import inexact
43 |
44 | from numpy import iinfo
45 | from numpy import issubdtype
46 |
47 | from numpy import inf
48 |
49 | # TODO(wangpeng): Make bfloat16 a numpy dtype instead of using TF's
50 | from tensorflow.compat.v2 import bfloat16
51 | # pylint: enable=g-bad-import-order
52 | # pylint: enable=unused-import
53 |
54 |
55 | _to_float32 = {
56 | np.dtype('float64'): np.dtype('float32'),
57 | np.dtype('complex128'): np.dtype('complex64'),
58 | }
59 |
60 |
61 | _allow_float64 = True
62 |
63 |
64 | def is_allow_float64():
65 | return _allow_float64
66 |
67 |
68 | def set_allow_float64(b):
69 | global _allow_float64
70 | _allow_float64 = b
71 |
72 |
73 | def canonicalize_dtype(dtype):
74 | if not is_allow_float64():
75 | return _to_float32.get(dtype, dtype)
76 | else:
77 | return dtype
78 |
79 |
80 | def _result_type(*arrays_and_dtypes):
81 | dtype = np.result_type(*arrays_and_dtypes)
82 | return canonicalize_dtype(dtype)
83 |
84 |
85 | def default_float_type():
86 | """Gets the default float type.
87 |
88 | Returns:
89 | If `is_allow_float64()` is true, returns float64; otherwise returns float32.
90 | """
91 | if is_allow_float64():
92 | return float64
93 | else:
94 | return float32
95 |
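The float64 gate is process-global, so flipping it changes how every subsequent dtype is canonicalized. A short demonstration grounded in the functions above:

import numpy as np

from trax.tf_numpy.numpy_impl import dtypes

# With float64 disallowed, 64-bit types map down per _to_float32.
dtypes.set_allow_float64(False)
assert dtypes.canonicalize_dtype(np.dtype('float64')) == np.dtype('float32')
assert dtypes.canonicalize_dtype(np.dtype('complex128')) == np.dtype('complex64')
assert dtypes.default_float_type() == dtypes.float32

# Restore the default: dtypes now pass through unchanged.
dtypes.set_allow_float64(True)
assert dtypes.canonicalize_dtype(np.dtype('float64')) == np.dtype('float64')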
--------------------------------------------------------------------------------
/trax/tf_numpy/numpy_impl/random.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Random functions."""
17 | import numpy as np
18 | import tensorflow.compat.v2 as tf
19 |
20 | from trax.tf_numpy.numpy_impl import utils
21 |
22 |
23 | DEFAULT_RANDN_DTYPE = np.float32
24 |
25 |
26 | def randn(*args):
27 | """Returns samples from a normal distribution.
28 |
29 | Uses `tf.random.normal`.
30 |
31 | Args:
32 | *args: The shape of the output array.
33 |
34 | Returns:
35 | An ndarray with shape `args` and dtype `DEFAULT_RANDN_DTYPE` (float32).
36 | """
37 | # TODO(wangpeng): Use new stateful RNG
38 | if utils.isscalar(args):
39 | args = (args,)
40 | return utils.tensor_to_ndarray(
41 | tf.random.normal(args, dtype=DEFAULT_RANDN_DTYPE))
42 |
43 |
44 | def seed(s):
45 | """Sets the seed for the random number generator.
46 |
47 | Uses `tf.random.set_seed`.
48 |
49 | Args:
50 | s: an integer.
51 | """
52 | # TODO(wangpeng): make the signature the same as numpy
53 | tf.random.set_seed(s)
54 |
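A short usage sketch for the two functions above:

from trax.tf_numpy.numpy_impl import random

random.seed(42)          # delegates to tf.random.set_seed
x = random.randn(2, 3)   # float32 samples with shape (2, 3)
print(x.shape, x.dtype)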
--------------------------------------------------------------------------------
/trax/tf_numpy/numpy_impl/tests/backprop_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for backpropgration on tf-numpy functions."""
17 | import tensorflow.compat.v2 as tf
18 |
19 | from trax.tf_numpy.numpy_impl import array_ops
20 | # Required for operator overloads
21 | from trax.tf_numpy.numpy_impl import math_ops # pylint: disable=unused-import
22 |
23 |
24 | class BackpropTest(tf.test.TestCase):
25 |
26 | def test_setitem(self):
27 | # Single integer index.
28 | a = array_ops.array([1., 2., 3.])
29 | b = array_ops.array(5.)
30 | c = array_ops.array(10.)
31 |
32 | tensors = [arr.data for arr in [a, b, c]]
33 | with tf.GradientTape() as g:
34 | g.watch(tensors)
35 | a[1] = b + c
36 | loss = array_ops.sum(a)
37 |
38 | gradients = g.gradient(loss.data, tensors)
39 | self.assertSequenceEqual(
40 | array_ops.array(gradients[0]).tolist(), [1., 0., 1.])
41 | self.assertEqual(array_ops.array(gradients[1]).tolist(), 1.)
42 | self.assertEqual(array_ops.array(gradients[2]).tolist(), 1.)
43 |
44 | # Tuple index.
45 | a = array_ops.array([[[1., 2.], [3., 4.]], [[5., 6.],
46 | [7., 8.]]]) # 2x2x2 array.
47 | b = array_ops.array([10., 11.])
48 |
49 | tensors = [arr.data for arr in [a, b]]
50 | with tf.GradientTape() as g:
51 | g.watch(tensors)
52 | a[(1, 0)] = b
53 | loss = array_ops.sum(a)
54 |
55 | gradients = g.gradient(loss.data, tensors)
56 | self.assertSequenceEqual(
57 | array_ops.array(gradients[0]).tolist(),
58 | [[[1., 1.], [1., 1.]], [[0., 0.], [1., 1.]]])
59 | self.assertEqual(array_ops.array(gradients[1]).tolist(), [1., 1.])
60 |
61 |
62 | if __name__ == '__main__':
63 | tf.compat.v1.enable_eager_execution()
64 | tf.test.main()
65 |
--------------------------------------------------------------------------------
/trax/tf_numpy/numpy_impl/tests/utils_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for utils.py."""
17 | import tensorflow.compat.v2 as tf
18 |
19 | from trax.tf_numpy.numpy_impl import utils
20 |
21 |
22 | class UtilsTest(tf.test.TestCase):
23 |
24 | # pylint: disable=unused-argument
25 | def testNpDoc(self):
26 | def np_fun(x):
27 | """np_fun docstring."""
28 | return
29 | @utils.np_doc(np_fun)
30 | def f():
31 | """f docstring."""
32 | return
33 | expected = """TensorFlow variant of `numpy.np_fun`.
34 |
35 | Unsupported arguments: `x`.
36 |
37 | f docstring.
38 |
39 | Documentation for `numpy.np_fun`:
40 |
41 | np_fun docstring."""
42 | self.assertEqual(f.__doc__, expected)
43 |
44 |
45 | if __name__ == '__main__':
46 | tf.enable_v2_behavior()
47 | tf.test.main()
48 |
--------------------------------------------------------------------------------
/trax/tf_numpy/public_symbol_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Trax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests different ways to use the public tf-numpy module."""
17 | import numpy as onp
18 |
19 | import tensorflow as tf
20 | import tensorflow.experimental.numpy as np1
21 | from tensorflow.experimental import numpy as np2 # pylint: disable=reimported
22 |
23 |
24 | np3 = tf.experimental.numpy
25 |
26 |
27 | class PublicSymbolTest(tf.test.TestCase):
28 |
29 | def testSimple(self):
30 | a = 0.1
31 | b = 0.2
32 | for op in [np1.add, np2.add, np3.add]:
33 | self.assertAllClose(onp.add(a, b), op(a, b))
34 |
35 |
36 | if __name__ == "__main__":
37 | tf.compat.v1.enable_eager_execution()
38 | tf.test.main()
39 |
--------------------------------------------------------------------------------