├── .readthedocs.yaml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
├── Makefile
├── README.md
├── _static
│ ├── continuous_eval_system_diagram.png
│ ├── style.css
│ └── vis_Conv_Cifar10_32x64x64.png
├── _templates
│ └── layout.html
├── conf.py
├── continuous_eval.rst
├── index.rst
├── learned_optimizers.rst
├── notebooks
│ ├── Part1_Introduction.ipynb
│ ├── Part1_Introduction.md
│ ├── Part1_Introduction.py
│ ├── Part2_CustomTasks.ipynb
│ ├── Part2_CustomTasks.md
│ ├── Part2_CustomTasks.py
│ ├── Part3_Truncation_TruncatedStep.ipynb
│ ├── Part3_Truncation_TruncatedStep.md
│ ├── Part3_Truncation_TruncatedStep.py
│ ├── Part4_GradientEstimators.ipynb
│ ├── Part4_GradientEstimators.md
│ ├── Part4_GradientEstimators.py
│ ├── Part5_Meta_training_with_GradientLearner.ipynb
│ ├── Part5_Meta_training_with_GradientLearner.md
│ ├── Part5_Meta_training_with_GradientLearner.py
│ ├── Part6_custom_learned_optimizers.ipynb
│ ├── Part6_custom_learned_optimizers.md
│ ├── Part6_custom_learned_optimizers.py
│ ├── no_dependency_learned_optimizer.ipynb
│ ├── no_dependency_learned_optimizer.md
│ ├── no_dependency_learned_optimizer.py
│ ├── summary_tutorial.ipynb
│ ├── summary_tutorial.md
│ └── summary_tutorial.py
├── optimizer_baselines.rst
├── optimizers.rst
├── outer_trainers.rst
├── population.rst
├── requirements.txt
└── tasks.rst
├── learned_optimization
├── __init__.py
├── baselines
│ ├── __init__.py
│ ├── data
│ │ ├── speedup_over_adam.json
│ │ ├── speedup_over_adam.pkl
│ │ ├── speedup_over_adam_v2.json
│ │ ├── speedup_over_multiple_baselines.json
│ │ └── speedup_over_multiple_baselines_v2.json
│ ├── hparam_sets.py
│ ├── normalizers.py
│ ├── normalizers_test.py
│ ├── run_archive.py
│ ├── run_time_task.py
│ ├── run_trainer.py
│ ├── run_trainer_test.py
│ ├── utils.py
│ └── vis.py
├── checkpoints.py
├── circular_buffer.py
├── circular_buffer_test.py
├── continuous_eval
│ ├── __init__.py
│ ├── evaluation_set.py
│ ├── run_eval_chief.py
│ ├── run_eval_chief_test.py
│ ├── run_eval_worker.py
│ ├── run_eval_worker_test.py
│ ├── task_group_server.py
│ └── task_group_server_test.py
├── distributed.py
├── distributed_test.py
├── entry_points
│ ├── README.md
│ ├── run_check_eval_worker.py
│ ├── run_outer_train_single.py
│ ├── run_outer_trainer.py
│ ├── run_outer_worker.py
│ ├── run_population_controller.py
│ ├── run_single_machine_trainer.py
│ └── run_tboard_eval.py
├── eval_training.py
├── eval_training_multi_device_test.py
├── eval_training_test.py
├── examples
│ ├── README.md
│ ├── simple_lopt_train.py
│ └── simple_lopt_train_test.py
├── filesystem.py
├── jax_utils.py
├── jax_utils_test.py
├── launchers
│ ├── __init__.py
│ ├── launch_template_multi_machine_lopt.py
│ └── utils.py
├── learned_optimizers
│ ├── __init__.py
│ ├── adafac_mlp_lopt.py
│ ├── adafac_mlp_lopt_test.py
│ ├── adafac_nominal.py
│ ├── adafac_nominal_test.py
│ ├── base.py
│ ├── base_test.py
│ ├── common.py
│ ├── mlp_lopt.py
│ ├── mlp_lopt_test.py
│ ├── nn_adam.py
│ ├── nn_adam_test.py
│ ├── opt_from_checkpoint.py
│ ├── rnn_mlp_lopt.py
│ ├── rnn_mlp_lopt_test.py
│ └── test_utils.py
├── notebook_utils.py
├── optimizers
│ ├── __init__.py
│ ├── alias.py
│ ├── base.py
│ ├── base_test.py
│ ├── data
│ │ └── opt_list.json
│ ├── gradient_accumulator.py
│ ├── gradient_accumulator_test.py
│ ├── learning_rate_schedules.py
│ ├── nadamw.py
│ ├── nadamw_test.py
│ ├── opt_list.py
│ ├── opt_list_test.py
│ ├── opt_to_optax.py
│ ├── opt_to_optax_test.py
│ ├── optax_opts.py
│ ├── optax_opts_test.py
│ ├── optimizer_wrappers.py
│ ├── shampoo.py
│ ├── shampoo_test.py
│ └── test_utils.py
├── outer_train.py
├── outer_train_test.py
├── outer_trainers
│ ├── __init__.py
│ ├── common.py
│ ├── full_es.py
│ ├── full_es_test.py
│ ├── full_grad.py
│ ├── full_grad_test.py
│ ├── gradient_learner.py
│ ├── gradient_learner_test.py
│ ├── lopt_truncated_step.py
│ ├── lopt_truncated_step_test.py
│ ├── test_utils.py
│ ├── truncated_es.py
│ ├── truncated_es_test.py
│ ├── truncated_grad.py
│ ├── truncated_grad_test.py
│ ├── truncated_pes.py
│ ├── truncated_pes_pmap_test.py
│ ├── truncated_pes_test.py
│ ├── truncated_step.py
│ ├── truncation_schedule.py
│ └── truncation_schedule_test.py
├── population
│ ├── __init__.py
│ ├── examples
│ │ ├── __init__.py
│ │ ├── complex_cnn
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ └── train_threads.py
│ │ ├── simple_cnn
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── train.py
│ │ │ ├── train_test.py
│ │ │ └── train_threads.py
│ │ └── synthetic
│ │ │ ├── __init__.py
│ │ │ ├── train.py
│ │ │ └── train_test.py
│ ├── mutators
│ │ ├── __init__.py
│ │ ├── fixed_schedule.py
│ │ ├── fixed_schedule_test.py
│ │ ├── single_worker_explore.py
│ │ ├── single_worker_explore_test.py
│ │ ├── winner_take_all_genetic.py
│ │ └── winner_take_all_genetic_test.py
│ ├── population.py
│ └── population_test.py
├── profile.py
├── py_utils.py
├── research
│ ├── README.md
│ ├── __init__.py
│ ├── brax
│ │ ├── brax_env_truncated_step.py
│ │ ├── brax_env_truncated_step_test.py
│ │ └── run_single_machine_rl_trainer.py
│ ├── data_driven
│ │ ├── README.md
│ │ ├── data.py
│ │ ├── mnist_projections.py
│ │ ├── mnist_projections_test.py
│ │ ├── model_components.py
│ │ ├── models.py
│ │ ├── resnet.py
│ │ ├── run_mnist_projections.py
│ │ ├── summary.py
│ │ └── transformer.py
│ ├── discovered_policy_optimisation
│ │ └── truncated_step.py
│ ├── distill
│ │ ├── truncated_distill.py
│ │ └── truncated_distill_test.py
│ ├── es_scaling
│ │ └── es_scaling_tasks.py
│ ├── general_lopt
│ │ ├── Demo_for_training_a_model_with_a_learned_optimizer.ipynb
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── configs
│ │ │ ├── README.md
│ │ │ ├── large_scale_phase1.py
│ │ │ ├── large_scale_phase2.py
│ │ │ ├── large_scale_phase3.py
│ │ │ └── large_scale_phase4.py
│ │ ├── hyper_v2.py
│ │ ├── hyper_v2_test.py
│ │ ├── prefab.py
│ │ ├── prefab_test.py
│ │ ├── pretrained_optimizers.py
│ │ ├── pretrained_optimizers_test.py
│ │ └── tasks
│ │ │ ├── __init__.py
│ │ │ ├── fast_mlp_diff_data.py
│ │ │ └── parametric.py
│ ├── hysteresis
│ │ ├── __init__.py
│ │ ├── data_in_state_tasks.py
│ │ ├── extra_small_lopt.py
│ │ ├── truncated_es_no_hyst.py
│ │ ├── truncated_es_shared_noise.py
│ │ ├── truncated_pes_no_hyst.py
│ │ └── truncated_pes_no_hyst_same_data.py
│ ├── jaxnerf
│ │ ├── datasets.py
│ │ ├── example_data
│ │ │ ├── imgs
│ │ │ │ └── r_0.png
│ │ │ ├── transforms_test.json
│ │ │ └── transforms_train.json
│ │ ├── jaxnerf.py
│ │ └── jaxnerf_test.py
│ ├── scaling
│ │ ├── scaling_tasks.py
│ │ └── scaling_transformer.py
│ └── univ_nfn
│ │ ├── gen_pred
│ │ └── train_pred.py
│ │ ├── learned_opt
│ │ ├── __init__.py
│ │ ├── learned_opts.py
│ │ └── outer_opt.py
│ │ └── nfn
│ │ ├── __init__.py
│ │ ├── ff_layers.py
│ │ ├── siren.py
│ │ ├── universal_layers.py
│ │ └── utils.py
├── run_outer_train_single.py
├── setup_experiment.py
├── summary.py
├── summary_test.py
├── tasks
│ ├── __init__.py
│ ├── base.py
│ ├── datasets
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── beam
│ │ │ └── make_resized_image_beam.py
│ │ ├── image.py
│ │ ├── image_test.py
│ │ ├── language.py
│ │ └── language_test.py
│ ├── es_wrapper.py
│ ├── es_wrapper_test.py
│ ├── fixed
│ │ ├── __init__.py
│ │ ├── all.py
│ │ ├── conv.py
│ │ ├── conv_test.py
│ │ ├── es_wrapped.py
│ │ ├── es_wrapped_test.py
│ │ ├── image_mlp.py
│ │ ├── image_mlp_ae.py
│ │ ├── image_mlp_ae_test.py
│ │ ├── image_mlp_test.py
│ │ ├── lopt.py
│ │ ├── lopt_test.py
│ │ ├── mlp_mixer.py
│ │ ├── mlp_mixer_test.py
│ │ ├── resnet.py
│ │ ├── resnet_test.py
│ │ ├── rnn_lm.py
│ │ ├── rnn_lm_test.py
│ │ ├── transformer_lm.py
│ │ ├── transformer_lm_test.py
│ │ ├── vit.py
│ │ └── vit_test.py
│ ├── generative_model_utils.py
│ ├── parametric
│ │ ├── __init__.py
│ │ ├── cfgobject.py
│ │ ├── cfgobject_test.py
│ │ ├── image_conv.py
│ │ ├── image_conv_test.py
│ │ ├── image_mlp.py
│ │ ├── image_mlp_ae.py
│ │ ├── image_mlp_ae_test.py
│ │ ├── image_mlp_test.py
│ │ ├── image_mlp_vae.py
│ │ ├── image_mlp_vae_test.py
│ │ ├── image_resnet.py
│ │ ├── image_resnet_test.py
│ │ ├── lm_rnn.py
│ │ ├── lm_rnn_test.py
│ │ ├── lm_transformer.py
│ │ ├── lm_transformer_test.py
│ │ ├── lopt.py
│ │ ├── lopt_test.py
│ │ ├── lopt_trunc.py
│ │ ├── lopt_trunc_test.py
│ │ ├── parametric_utils.py
│ │ ├── parametric_utils_test.py
│ │ ├── vit.py
│ │ └── vit_test.py
│ ├── quadratics.py
│ ├── quadratics_test.py
│ ├── resnet.py
│ ├── rnn.py
│ ├── task_augmentation.py
│ ├── task_augmentation_test.py
│ ├── test_utils.py
│ └── transformer.py
├── time_filter
│ ├── model_paths.py
│ ├── run_sample_and_time.py
│ ├── run_train_time_model.py
│ ├── time_model.py
│ ├── time_model_data.py
│ ├── time_model_data_test.py
│ ├── timings.py
│ └── timings_test.py
├── training.py
├── tree_utils.py
└── tree_utils_test.py
├── requirements.txt
└── setup.py
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Set the version of Python and other tools you might need
9 | build:
10 | os: ubuntu-20.04
11 | tools:
12 | python: "3.9"
13 |
14 | # Build documentation in the docs/ directory with Sphinx
15 | sphinx:
16 | configuration: docs/conf.py
17 |
18 | # Optionally declare the Python requirements required to build your docs
19 | python:
20 | install:
21 | - requirements: docs/requirements.txt
22 | - requirements: requirements.txt
23 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | to see your current agreements on file or
13 | to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | sync-notebooks:
16 | jupytext --sync notebooks/*{py,md,ipynb}
17 |
18 | .PHONY: help Makefile
19 |
20 | # Catch-all target: route all unknown targets to Sphinx using the new
21 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
22 | %: Makefile
23 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
24 |
25 | clean:
26 | rm -rf $(BUILDDIR)/*
27 | rm -rf auto_examples/
28 | rm -rf _autosummary/
29 |
30 | html-noplot:
31 | $(SPHINXBUILD) -D plot_gallery=0 -b html $(ALLSPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html
32 | @echo
33 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
34 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Instructions for docs
2 |
3 | First install jupytext:
4 |
5 | `pip3 install jupytext`
6 |
7 | To sync the different forms of the notebooks, cd to this directory and run:
8 |
9 | `make sync-notebooks`
10 |
--------------------------------------------------------------------------------
/docs/_static/continuous_eval_system_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/learned_optimization/3ab34d27bdd38f487af65fb025f9c86869946e81/docs/_static/continuous_eval_system_diagram.png
--------------------------------------------------------------------------------
/docs/_static/style.css:
--------------------------------------------------------------------------------
1 | @import url("theme.css");
2 |
3 | .wy-side-nav-search {
4 | background-color: #fff;
5 | }
6 |
--------------------------------------------------------------------------------
/docs/_static/vis_Conv_Cifar10_32x64x64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/learned_optimization/3ab34d27bdd38f487af65fb025f9c86869946e81/docs/_static/vis_Conv_Cifar10_32x64x64.png
--------------------------------------------------------------------------------
/docs/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {% extends "!layout.html" %}
2 | {% set css_files = css_files + ["_static/style.css"] %}
3 |
--------------------------------------------------------------------------------
/docs/continuous_eval.rst:
--------------------------------------------------------------------------------
1 | Continuous Eval
2 | ================
3 |
4 | Meta-training can be an incredibly noisy process, as one often meta-learns over randomly sampled tasks, and these sampled tasks can have very different meta-objectives.
5 | To get higher-quality evaluations, as well as to evaluate performance on different
6 | distributions of tasks (i.e. to test out-of-distribution generalization),
7 | it is often useful to evaluate a suite of tasks on a single, fixed checkpoint / fixed meta-parameter value.
8 | This is tricky because it can be very expensive -- in the case of learned
9 | optimizers it requires training a number of different tasks for a large number
10 | of inner-steps.
11 | To get faster evaluations, it is desirable to split up this computation and run it in
12 | parallel across a number of machines.
13 |
14 | This module is created to help with distributed evaluation of a given checkpoint
15 | and is meant to be run alongside a meta-training job / meta-training cluster
16 | so as to get stable evaluations while training. It interacts with a training
17 | job by monitoring a directory for new checkpoints and kicking off evaluation
18 | jobs as they appear.
19 |
20 | We summarize the architecture in the figure below.
21 |
22 |
23 | .. path to edit figure: https://docs.google.com/presentation/d/16zWe2ryUUcbWSBRhfYx7D5BwG9z1Mt-4BpWCOKmnais/edit?resourcekey=0-_if_-4xNYC5bgD1ZyN2Pmg#slide=id.p --->
24 |
25 |
26 | .. image:: _static/continuous_eval_system_diagram.png
27 | :width: 800
28 | :alt: Continuous eval system diagram
29 |
30 |
31 | There are two types of jobs: the eval chief and the eval worker.
32 |
33 | Evaluation Chief
34 | ----------------
35 |
36 | The evaluation chief is configured with a set of evaluations to run on
37 | a given checkpoint. The chief then starts up a couple of threads.
38 | First, there is a thread to monitor a given directory for new checkpoints.
39 | When a new checkpoint is found, we send this checkpoint, as well as the list of
40 | evaluations (e.g. different tasks) to run on it, to an instance of `TaskGroupChief`.
41 | This is a task queue specifically structured to manage groups of tasks being evaluated on a single checkpoint.
42 | A final thread monitors this task queue for finished groups of tasks and
43 | writes the results out to tensorboard, and optionally to the population controller
44 | if one is using population-based training.
45 |
46 |
47 | Evaluation configuration
48 | ------------------------
49 |
50 | While it's possible to use similar infrastructure to pass around any kind of
51 | configuration for an evaluation task, we opt to use `gin <https://github.com/google/gin-config>`_ -- in particular
52 | a set of gin bindings which specify the new task.
53 | We find this very convenient, but it comes at the cost of complexity and global
54 | state.
55 |
56 | Evaluation Worker
57 | ------------------------
58 |
59 | For a given evaluation chief, we run multiple evaluation workers. Each worker
60 | grabs an evaluation to work on from the eval chief, parses the evaluation
61 | configuration, loads any weights (e.g. the weights of the learned optimizer),
62 | performs the inner-training procedure which does the evaluation, and finally
63 | reports the results back to the eval chief.
64 | This loop repeats over and over again.
65 |
--------------------------------------------------------------------------------
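
For concreteness, the chief-side queue described in continuous_eval.rst can be sketched as follows. This mirrors the usage in `task_group_server_test.py` later in this repository; the group name, work payloads, and result values are illustrative only.

```python
# Sketch of the TaskGroupChief queue used by the evaluation chief.
# Mirrors task_group_server_test.py; names and values are illustrative.
from learned_optimization.continuous_eval import task_group_server

chief = task_group_server.TaskGroupChief(
    "eval_chief", "/tmp/eval_log_dir", 1, start_server=False)
chief.daemon = True
chief.start()

# A "task group" is one checkpoint plus the list of evaluations to run on it.
chief.add_task_group("checkpoint_00100", ["eval_cfg_a", "eval_cfg_b"])

# Each worker repeatedly asks for work, runs it, and reports a result back.
for _ in range(2):
  work = chief.get_work(0)             # worker id 0 asks for an evaluation
  # ... run the evaluation described by `work` ...
  chief.finish_work(0, {"loss": 1.0})  # report the (illustrative) result

# Once every evaluation in the group is finished, the chief collects the
# results, e.g. to write them out to tensorboard.
task_group, values, tasks = chief.get_finished_task_group()

chief.should_stop = True
chief.join()
```
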
/docs/index.rst:
--------------------------------------------------------------------------------
1 | learned_optimization reference documentation
2 | ============================================
3 |
4 | learned\_optimization is a research codebase for training learned
5 | optimizers. It implements hand-designed and learned optimizers, tasks to meta-train and meta-test them on, and outer-training algorithms such as ES, gradients, and PES.
6 |
7 | .. toctree::
8 | :maxdepth: 0
9 | :caption: Getting Started
10 |
11 | notebooks/Part1_Introduction
12 | notebooks/Part2_CustomTasks
13 | notebooks/Part3_Truncation_TruncatedStep
14 | notebooks/Part4_GradientEstimators
15 | notebooks/Part5_Meta_training_with_GradientLearner
16 | notebooks/Part6_custom_learned_optimizers
17 |
18 | .. toctree::
19 | :maxdepth: 0
20 | :caption: Tutorials
21 |
22 | notebooks/no_dependency_learned_optimizer
23 | notebooks/summary_tutorial
24 |
25 | .. toctree::
26 | :maxdepth: 0
27 | :caption: Modules
28 |
29 | tasks
30 | optimizers
31 | learned_optimizers
32 | outer_trainers
33 | continuous_eval
34 | population
35 | optimizer_baselines
36 |
37 |
38 | Indices and tables
39 | ==================
40 |
41 | * :ref:`genindex`
42 | * :ref:`modindex`
43 | * :ref:`search`
44 |
--------------------------------------------------------------------------------
/docs/learned_optimizers.rst:
--------------------------------------------------------------------------------
1 | Learned Optimizers
2 | ====================
3 |
4 | .. automodule:: learned_optimization.learned_optimizers.base
5 |
6 |
7 | API
8 | ---
9 |
10 | .. autofunction:: learned_optimization.learned_optimizers.base.LearnedOptimizer
11 | .. autofunction:: learned_optimization.learned_optimizers.base.LearnableSGD
12 | .. autofunction:: learned_optimization.learned_optimizers.base.LearnableSGDM
13 | .. autofunction:: learned_optimization.learned_optimizers.base.LearnableAdam
--------------------------------------------------------------------------------
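
As a rough illustration of this interface: a minimal sketch only, where the `opt_fn` step follows the library's usual pattern of turning meta-parameters into a regular `Optimizer`, and exact signatures may differ.

```python
# Minimal sketch of the LearnedOptimizer interface documented above.
import jax
from learned_optimization.learned_optimizers import base as lopt_base

key = jax.random.PRNGKey(0)
lopt = lopt_base.LearnableAdam()

theta = lopt.init(key)     # meta-parameters (the "weights" of the learned optimizer)
opt = lopt.opt_fn(theta)   # a concrete Optimizer instance built from theta

# `opt` can then be used like any hand-designed optimizer to inner-train a task.
```
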
/docs/optimizers.rst:
--------------------------------------------------------------------------------
1 | Optimizers
2 | ===========
3 |
4 | .. automodule:: learned_optimization.optimizers.base
5 |
6 |
7 | API
8 | ---
9 |
10 | .. autofunction:: learned_optimization.optimizers.base.Optimizer
11 | .. autofunction:: learned_optimization.optimizers.base.SGD
12 | .. autofunction:: learned_optimization.optimizers.base.SGDM
13 | .. autofunction:: learned_optimization.optimizers.base.Adam
14 |
--------------------------------------------------------------------------------
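
A minimal usage sketch, pairing `SGD` with a quadratic task from this repository; the keyword arguments of `update` and the unused data argument of the quadratic's loss follow the library's conventions and may differ in detail.

```python
# Sketch of the Optimizer interface documented above.
import jax
from learned_optimization.optimizers import base as opt_base
from learned_optimization.tasks import quadratics

key = jax.random.PRNGKey(0)
task = quadratics.QuadraticTask()
opt = opt_base.SGD()

params = task.init(key)
opt_state = opt.init(params)

for _ in range(10):
  key, skey = jax.random.split(key)
  # Differentiate the task loss with respect to the current parameters.
  loss, grads = jax.value_and_grad(task.loss)(
      opt.get_params(opt_state), skey, None)
  opt_state = opt.update(opt_state, grads, loss=loss)
```
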
/docs/outer_trainers.rst:
--------------------------------------------------------------------------------
1 | Outer Trainers
2 | ===============
3 |
4 | Outer-training, or meta-training, is the process of learning the meta-parameters.
5 |
6 |
7 |
8 | Base
9 | ----
10 |
11 | .. autofunction:: learned_optimization.outer_trainers.gradient_learner.GradientEstimator
12 | .. autofunction:: learned_optimization.outer_trainers.gradient_learner.gradient_worker_compute
13 | .. autofunction:: learned_optimization.outer_trainers.gradient_learner.SingleMachineGradientLearner
14 |
15 |
16 | FullES
17 | --------------
18 |
19 | .. automodule:: learned_optimization.outer_trainers.full_es
20 |
21 | .. autofunction:: learned_optimization.outer_trainers.full_es.FullES
22 |
23 |
24 |
25 | TruncatedPES
26 | --------------
27 |
28 | .. automodule:: learned_optimization.outer_trainers.truncated_pes
29 |
30 | .. autofunction:: learned_optimization.outer_trainers.truncated_pes.TruncatedPES
--------------------------------------------------------------------------------
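
To make the pieces concrete, here is a rough sketch of single-machine meta-training in the style of the Part 5 tutorial notebook. The constructor arguments (truncation schedule, number of tasks, learning rate) are illustrative, and the exact signatures may differ from the current library.

```python
# Rough sketch: meta-train a learned optimizer with PES on a quadratic family.
import jax
from learned_optimization.learned_optimizers import base as lopt_base
from learned_optimization.optimizers import base as opt_base
from learned_optimization.outer_trainers import gradient_learner
from learned_optimization.outer_trainers import lopt_truncated_step
from learned_optimization.outer_trainers import truncated_pes
from learned_optimization.outer_trainers import truncation_schedule
from learned_optimization.tasks import quadratics

key = jax.random.PRNGKey(0)
lopt = lopt_base.LearnableAdam()
task_family = quadratics.FixedDimQuadraticFamily(10)

# How long each unrolled inner-training segment is.
trunc_sched = truncation_schedule.ConstantTruncationSchedule(10)
truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
    task_family, lopt, trunc_sched, num_tasks=4)

grad_est = truncated_pes.TruncatedPES(truncated_step=truncated_step)
theta_opt = opt_base.Adam(1e-3)

outer_trainer = gradient_learner.SingleMachineGradientLearner(
    lopt, [grad_est], theta_opt)

state = outer_trainer.init(key)
for _ in range(5):
  key, skey = jax.random.split(key)
  state, outer_loss, metrics = outer_trainer.update(state, skey)
```
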
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib>=3.3.4
2 | pandoc>=1.0.2
3 | sphinx>=3.5.1
4 | recommonmark>=0.7.1
5 | sphinx_rtd_theme>=0.5.1
6 | sphinx_autodoc_typehints>=1.11.1
7 | ipython>=7.20.0
8 | ipykernel>=5.5.0
9 | nbsphinx>=0.8.7
10 | sphinx_copybutton>=0.4.0
11 | myst_nb>=0.13.0
12 |
--------------------------------------------------------------------------------
/docs/tasks.rst:
--------------------------------------------------------------------------------
1 | Tasks
2 | =====
3 |
4 | .. automodule:: learned_optimization.tasks.base
5 |
6 |
7 | Base
8 | ---
9 |
10 | .. autofunction:: learned_optimization.tasks.base.Task
11 | .. autofunction:: learned_optimization.tasks.base.TaskFamily
12 |
13 |
14 | QuadraticTasks
15 | --------------
16 |
17 | .. automodule:: learned_optimization.tasks.quadratics
18 |
19 | .. autofunction:: learned_optimization.tasks.quadratics.QuadraticTask
20 | .. autofunction:: learned_optimization.tasks.quadratics.BatchQuadraticTask
21 | .. autofunction:: learned_optimization.tasks.quadratics.SumQuadraticTask
22 | .. autofunction:: learned_optimization.tasks.quadratics.FixedDimQuadraticFamily
23 | .. autofunction:: learned_optimization.tasks.quadratics.FixedDimQuadraticFamilyData
--------------------------------------------------------------------------------
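
A short sketch of these two interfaces using the quadratic tasks above; signatures follow the conventions used by the tests elsewhere in this repository and may differ in detail.

```python
# Sketch of the Task and TaskFamily interfaces documented above.
import jax
from learned_optimization.tasks import quadratics

key = jax.random.PRNGKey(0)

# A Task bundles parameter initialization and a loss; this one needs no data.
task = quadratics.QuadraticTask()
params = task.init(key)
loss = task.loss(params, key, None)

# A TaskFamily samples task configurations and builds Tasks from them; this is
# what the outer-trainers consume.
family = quadratics.FixedDimQuadraticFamilyData(10)
cfg = family.sample(key)
sampled_task = family.task_fn(cfg)
```
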
/learned_optimization/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """learned_optimizer module."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/baselines/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Baselines module."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/baselines/data/speedup_over_adam.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/learned_optimization/3ab34d27bdd38f487af65fb025f9c86869946e81/learned_optimization/baselines/data/speedup_over_adam.pkl
--------------------------------------------------------------------------------
/learned_optimization/baselines/normalizers_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for normalizers."""
17 |
18 | from absl.testing import absltest
19 | from learned_optimization.baselines import normalizers
20 |
21 |
22 | class NormalizersTest(absltest.TestCase):
23 |
24 | def test_build_and_apply_one(self):
25 | data = normalizers._speedup_over_adam_build("ImageMLP_Cifar10_8_Relu32")
26 | fn = normalizers._speedup_over_adam_make_func(data)
27 | ret = fn(1.0)
28 | self.assertGreater(ret, 0)
29 |     # check that it is no more than 100k -- the maximum adam training length.
30 | self.assertLess(ret, 100_001)
31 |
32 | def test_speedup_over_adam_normalizer_map(self):
33 | normfns = normalizers.speedup_over_adam_normalizer_map()
34 | ret = normfns["ImageMLP_Cifar10_8_Relu32"](1.0)
35 | self.assertGreater(ret, 0)
36 | self.assertLess(ret, 100_001)
37 |
38 |
39 | if __name__ == "__main__":
40 | absltest.main()
41 |
--------------------------------------------------------------------------------
/learned_optimization/baselines/run_time_task.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Time a single task."""
17 |
18 | import os
19 | import pickle
20 | import uuid
21 |
22 | from absl import app
23 | from absl import logging
24 | import gin
25 | import jax
26 | from learned_optimization import filesystem
27 | from learned_optimization import setup_experiment
28 | from learned_optimization.tasks import base as tasks_base
29 | from learned_optimization.time_filter import timings
30 |
31 |
32 | @gin.configurable
33 | def run_many_eval_and_save(task: tasks_base.Task = gin.REQUIRED,
34 | save_dir: str = gin.REQUIRED):
35 | """Compute and save `num_to_run` runtime statistics."""
36 |
37 | dev = jax.devices()[0]
38 | dev_name = f"{dev.platform}_{dev.device_kind}"
39 | dev_name = dev_name.replace(" ", "")
40 |
41 | save_dir = os.path.join(save_dir, dev_name)
42 | filesystem.make_dirs(save_dir)
43 |
44 | task_family = tasks_base.single_task_to_family(task)
45 |
46 | speeds_and_std = timings.task_family_runtime_stats(
47 | task_family, num_tasks_list=[1, 2, 4, 8])
48 |
49 | output = pickle.dumps(speeds_and_std)
50 |
51 | path = os.path.join(save_dir, f"{str(uuid.uuid4())}.pkl")
52 | logging.info("Writing results to %s", path)
53 | with filesystem.file_open(path, "wb") as f:
54 | f.write(output)
55 |
56 |
57 | def main(argv):
58 | if len(argv) > 1:
59 | raise app.UsageError("Too many command-line arguments.")
60 | setup_experiment.setup_experiment(make_dir=False)
61 | run_many_eval_and_save()
62 |
63 |
64 | if __name__ == "__main__":
65 | app.run(main)
66 |
--------------------------------------------------------------------------------
/learned_optimization/baselines/run_trainer_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for run_trainer."""
17 |
18 | import os
19 | import tempfile
20 |
21 | from absl.testing import absltest
22 | import gin
23 | from learned_optimization.baselines import run_trainer
24 | from learned_optimization.baselines import utils
25 | from learned_optimization.optimizers import base as opt_base
26 | from learned_optimization.tasks import quadratics
27 | import numpy as onp
28 | from numpy import testing
29 |
30 |
31 | @gin.configurable()
32 | def make_quad():
33 | return quadratics.QuadraticTask()
34 |
35 |
36 | class RunTrainerTest(absltest.TestCase):
37 |
38 | def test_inner_train_task(self):
39 | with tempfile.TemporaryDirectory() as tmpdir:
40 | os.environ["LOPT_BASELINE_DIR"] = tmpdir
41 | opt = opt_base.SGD()
42 | task = quadratics.QuadraticTask()
43 | run_trainer.inner_train_task(
44 | task,
45 | opt,
46 | num_steps=10,
47 | eval_every=5,
48 | eval_batches=2,
49 | last_eval_batches=4)
50 |
51 | results = utils.load_baseline_results(
52 | "QuadraticTask",
53 | "SGD_lr0.01",
54 | num_steps=10,
55 | eval_every=5,
56 | eval_batches=2,
57 | last_eval_batches=4,
58 | output_type="curves",
59 | threads=0)
60 |
61 | self.assertLen(results, 1)
62 | results, = results
63 |
64 | self.assertEqual(results["num_steps"], 10)
65 | self.assertEqual(results["eval_batches"], 2)
66 | self.assertEqual(results["last_eval_batches"], 4)
67 | testing.assert_almost_equal(results["train/xs"],
68 | onp.arange(11, dtype=onp.int32))
69 | testing.assert_almost_equal(results["eval/xs"], onp.asarray([0, 5, 10]))
70 |
71 | def test_inner_train_task_from_gin(self):
72 | with tempfile.TemporaryDirectory() as tmpdir:
73 | os.environ["LOPT_BASELINE_DIR"] = tmpdir
74 | configs = [
75 | "inner_train_task.task=@make_quad()", "inner_train_task.opt=@SGD()"
76 | ]
77 | gin.parse_config(configs)
78 |
79 | run_trainer.inner_train_task(
80 | num_steps=10, eval_every=5, eval_batches=2, last_eval_batches=4)
81 |
82 | results = utils.load_baseline_results(
83 | "make_quad",
84 | "SGD_lr0.01",
85 | num_steps=10,
86 | eval_every=5,
87 | eval_batches=2,
88 | last_eval_batches=4,
89 | output_type="curves",
90 | threads=0)
91 |
92 | self.assertLen(results, 1)
93 |
94 |
95 | if __name__ == "__main__":
96 | absltest.main()
97 |
--------------------------------------------------------------------------------
/learned_optimization/circular_buffer_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimization.circular_buffer."""
17 |
18 | from absl.testing import absltest
19 | import haiku as hk
20 | import jax
21 | import jax.numpy as jnp
22 | from learned_optimization import circular_buffer
23 | import numpy as onp
24 |
25 | freeze = hk.data_structures.to_immutable_dict
26 |
27 |
28 | class CircularBufferTest(absltest.TestCase):
29 |
30 | def test_circular_buffer(self):
31 | s = freeze({"a": jnp.ones([1, 3])})
32 | av = jax.tree_util.tree_map(lambda x: x.aval, s)
33 |
34 | buffer = circular_buffer.CircularBuffer(av, 5)
35 |
36 | state = buffer.init()
37 |
38 | state = buffer.add(state, s)
39 |
40 | vs, m = buffer.stack_reorder(state)
41 | assert onp.sum(m) == 1
42 | onp.testing.assert_allclose(m, onp.asarray([0, 0, 0, 0, 1]))
43 | onp.testing.assert_allclose(vs["a"][-1], onp.ones([1, 3]))
44 | onp.testing.assert_allclose(vs["a"][0], onp.zeros([1, 3]))
45 |
46 | vs, idx = buffer.stack_with_idx(state)
47 | onp.testing.assert_allclose(idx, [4, -1, -1, -1, -1])
48 | onp.testing.assert_allclose(vs["a"][0], onp.ones([1, 3]))
49 |
50 | for _ in range(4):
51 | state = buffer.add(state, s)
52 |
53 | s = freeze({"a": jnp.ones([1, 3]) * 2})
54 | state = buffer.add(state, s)
55 |
56 | vs, m = buffer.stack_reorder(state)
57 | assert onp.sum(m) == 5
58 | onp.testing.assert_allclose(m, onp.asarray([1, 1, 1, 1, 1]))
59 | onp.testing.assert_allclose(vs["a"][-1], 2 * onp.ones([1, 3]))
60 | onp.testing.assert_allclose(vs["a"][0], onp.ones([1, 3]))
61 |
62 | vs, idx = buffer.stack_with_idx(state)
63 | onp.testing.assert_allclose(idx, [4, 0, 1, 2, 3])
64 | onp.testing.assert_allclose(vs["a"][0], 2 * onp.ones([1, 3]))
65 | onp.testing.assert_allclose(vs["a"][1], onp.ones([1, 3]))
66 |
67 | def test_gather(self):
68 | s = freeze({"a": jnp.ones([3])})
69 | av = jax.tree_util.tree_map(lambda x: x.aval, s)
70 |
71 | buffer = circular_buffer.CircularBuffer(av, 5)
72 |
73 | state = buffer.init()
74 |
75 | for i in range(12):
76 | s = freeze({"a": jnp.ones([3]) * i})
77 | state = buffer.add(state, s)
78 |
79 | v = buffer.gather_from_present(state, jnp.asarray([4, 3, 0, 0]))
80 | assert v["a"][0][0] == 11.0
81 | assert v["a"][1][0] == 10.0
82 | assert v["a"][2][1] == 7.0
83 |
84 | s = freeze({"a": jnp.ones([3]) * (i + 1)})
85 | state = buffer.add(state, s)
86 | v = buffer.gather_from_present(state, jnp.asarray([4, 3, 0, 0]))
87 | assert v["a"][0][0] == 12.0
88 | assert v["a"][1][0] == 11.0
89 | assert v["a"][2][1] == 8.0
90 |
91 |
92 | if __name__ == "__main__":
93 | absltest.main()
94 |
--------------------------------------------------------------------------------
/learned_optimization/continuous_eval/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Continuous evalulation module."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/continuous_eval/task_group_server_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.continuous_eval.task_group_server."""
17 |
18 | import os
19 | import shutil
20 | import tempfile
21 |
22 | from absl.testing import absltest
23 | from learned_optimization.continuous_eval import task_group_server
24 |
25 |
26 | class TaskGroupServerTest(absltest.TestCase):
27 |
28 | def test_task_group_chief(self):
29 | log_dir = os.path.join(tempfile.gettempdir(), "log_dir")
30 |
31 | if os.path.exists(log_dir):
32 | shutil.rmtree(log_dir)
33 |
34 | chief = task_group_server.TaskGroupChief(
35 | "test", log_dir, 1, start_server=False)
36 | chief.daemon = True
37 |
38 | chief.start()
39 |
40 | tasks = range(5)
41 | chief.add_task_group("HelloTask", tasks)
42 |
43 | for _ in range(5):
44 | self.assertIs(chief.get_finished_task_group(), None)
45 | work = chief.get_work(0)
46 | self.assertIsNot(work, None)
47 | chief.finish_work(0, 123)
48 |
49 | task_group, values, tasks = chief.get_finished_task_group()
50 | self.assertEqual(values, [123] * 5)
51 | self.assertEqual(task_group, "HelloTask")
52 | self.assertIs(chief.get_finished_task_group(), None)
53 |
54 | work = chief.get_work(0)
55 | self.assertIs(work, None)
56 |
57 | chief.should_stop = True
58 | chief.join()
59 |
60 |
61 | if __name__ == "__main__":
62 | absltest.main()
63 |
--------------------------------------------------------------------------------
/learned_optimization/entry_points/README.md:
--------------------------------------------------------------------------------
1 | # Entry points
2 |
3 | This folder contains various runnable scripts, all configured with gin.
4 |
5 |
--------------------------------------------------------------------------------
/learned_optimization/entry_points/run_check_eval_worker.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Local test script to check that the evaluation configurations run."""
17 |
18 | import time
19 | from typing import Sequence
20 |
21 | from absl import app
22 | import gin
23 | import jax
24 | from learned_optimization import eval_training
25 | from learned_optimization import profile
26 | from learned_optimization import setup_experiment
27 | from learned_optimization.continuous_eval import run_eval_chief
28 | from learned_optimization.continuous_eval import run_eval_worker
29 | from learned_optimization.continuous_eval import evaluation_set # pylint: disable=unused-import
30 |
31 |
32 | def main(unused_argv: Sequence[str]) -> None:
33 | setup_experiment.setup_experiment(gin_finalize=False)
34 |
35 | unused_chief_name, unused_num_workers, lopt = run_eval_chief.eval_chief_config(
36 | )
37 |
38 | cfg = gin.query_parameter("run_evaluation_chief.evaluation_set")
39 | eval_task_configs = cfg.configurable.wrapper()
40 | task = eval_task_configs[0]
41 | (eval_cfg, unused_eval_name) = task
42 |
43 | gin.parse_config(eval_cfg, skip_unknown=True)
44 |
45 | key = jax.random.PRNGKey(0)
46 | theta = lopt.init(key)
47 | task_family = run_eval_worker.get_task_family()
48 |
49 | with profile.Profile("inner_train"):
50 | stime = time.time()
51 | losses = eval_training.multi_task_training_curves(
52 | task_family, lopt, theta=theta)
53 | total_time = time.time() - stime
54 | result = {"total_time": total_time, "gen_id": "", "step": 123}
55 | for k, v in losses.items():
56 | result[k] = v
57 | print(result)
58 |
59 |
60 | if __name__ == "__main__":
61 | app.run(main)
62 |
--------------------------------------------------------------------------------
/learned_optimization/entry_points/run_outer_train_single.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Entry point / main script for outer_train.run_train on a single machine."""
17 | from typing import Sequence
18 |
19 | from absl import app
20 | from learned_optimization import outer_train
21 | from learned_optimization import setup_experiment
22 |
23 |
24 | def main(argv: Sequence[str]) -> None:
25 | if len(argv) > 1:
26 | raise app.UsageError('Too many command-line arguments.')
27 |
28 | train_log_dir = setup_experiment.setup_experiment(make_dir=True)
29 | # Run both the trainer and the worker on the same machine.
30 | outer_train.run_train(
31 | train_log_dir,
32 | is_trainer=True,
33 | is_worker=True,
34 | )
35 |
36 |
37 | if __name__ == '__main__':
38 | app.run(main)
39 |
--------------------------------------------------------------------------------
/learned_optimization/entry_points/run_outer_trainer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Entry point to run just the learner process.
17 |
18 | This should be run with run_outer_worker.py.
19 | """
20 | from typing import Sequence
21 |
22 | from absl import app
23 | from learned_optimization import outer_train
24 | from learned_optimization import setup_experiment
25 |
26 |
27 | def main(argv: Sequence[str]) -> None:
28 | if len(argv) > 1:
29 | raise app.UsageError('Too many command-line arguments.')
30 |
31 | train_log_dir = setup_experiment.setup_experiment(make_dir=True)
32 | outer_train.run_train(
33 | train_log_dir,
34 | is_trainer=True,
35 | is_worker=False,
36 | )
37 |
38 |
39 | if __name__ == '__main__':
40 | app.run(main)
41 |
--------------------------------------------------------------------------------
/learned_optimization/entry_points/run_outer_worker.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Entry point to run just the worker process.
17 |
18 | This should be run with run_outer_trainer.py.
19 | """
20 | from typing import Sequence
21 |
22 | from absl import app
23 | from absl import flags
24 | from learned_optimization import outer_train
25 | from learned_optimization import setup_experiment
26 |
27 | FLAGS = flags.FLAGS
28 |
29 |
30 | def main(argv: Sequence[str]) -> None:
31 | if len(argv) > 1:
32 | raise app.UsageError('Too many command-line arguments.')
33 |
34 | train_log_dir = setup_experiment.setup_experiment()
35 | outer_train.run_train(
36 | train_log_dir, is_trainer=False, is_worker=True, worker_id=FLAGS.task)
37 |
38 |
39 | if __name__ == '__main__':
40 | app.run(main)
41 |
--------------------------------------------------------------------------------
/learned_optimization/entry_points/run_population_controller.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Entry point for population based training."""
17 |
18 | from absl import app
19 | from learned_optimization import outer_train
20 | from learned_optimization import setup_experiment
21 |
22 |
23 |
24 | def main(argv):
25 | train_log_dir = setup_experiment.setup_experiment(make_dir=True)
26 |
27 | outer_train.run_population_controller(train_log_dir)
28 |
29 |
30 | if __name__ == "__main__":
31 | app.run(main)
32 |
--------------------------------------------------------------------------------
/learned_optimization/entry_points/run_tboard_eval.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Script that inner trains and logs metrics to eventfiles to view with tboard.
17 |
18 | This is used to debug task implementations.
19 | """
20 | from typing import Optional
21 |
22 | from absl import app
23 | import gin
24 | import jax
25 | from learned_optimization import eval_training
26 | from learned_optimization import setup_experiment
27 | from learned_optimization import summary
28 | from learned_optimization.optimizers import base as opt_base
29 | from learned_optimization.tasks import base as tasks_base
30 | import numpy as onp
31 |
32 |
33 | @gin.configurable
34 | def inner_train_task(
35 | train_log_dir,
36 | task: tasks_base.Task = gin.REQUIRED,
37 | opt: opt_base.Optimizer = gin.REQUIRED,
38 | num_steps: int = gin.REQUIRED,
39 | eval_every: int = 10,
40 | eval_batches: int = 5,
41 | last_eval_batches: int = 10,
42 | eval_task: Optional[tasks_base.Task] = None,
43 | metrics_every: Optional[int] = None,
44 | device: Optional[jax.Device] = None,
45 | ):
46 | """Train and save results of a single training run.
47 |
48 | Args:
49 | train_log_dir: location to put summaries.
50 | task: task to train
51 | opt: optimizer to apply.
52 | num_steps: number of training iterations.
53 | eval_every: how frequently to run evaluation.
54 | eval_batches: number of batches to evaluate on.
55 | last_eval_batches: number of batches to run at end of training.
56 | eval_task: The task used for evaluation. If none we use the same task as
57 | used for training. This is useful when having different train and test
58 | functions as is the case for batchnorm and dropout.
59 | device: device to train with.
60 | """
61 |
62 | seed = onp.random.randint(0, onp.iinfo(onp.int32).max)
63 | key = jax.random.PRNGKey(seed)
64 |
65 | summary_writer = summary.MultiWriter(summary.PrintWriter(),
66 | summary.JaxboardWriter(train_log_dir))
67 |
68 | eval_training.single_task_training_curves(
69 | task=task,
70 | opt=opt,
71 | num_steps=num_steps,
72 | key=key,
73 | eval_every=eval_every,
74 | eval_batches=eval_batches,
75 | last_eval_batches=last_eval_batches,
76 | eval_task=eval_task,
77 | device=device,
78 | metrics_every=metrics_every,
79 | summary_writer=summary_writer)
80 |
81 |
82 | def main(argv):
83 | if len(argv) > 1:
84 | raise app.UsageError("Too many command-line arguments.")
85 | train_log_dir = setup_experiment.setup_experiment(make_dir=True)
86 |
87 | inner_train_task(train_log_dir)
88 |
89 |
90 | if __name__ == "__main__":
91 | app.run(main)
92 |
--------------------------------------------------------------------------------
/learned_optimization/eval_training_multi_device_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.eval_training and multiple devices."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | from learned_optimization import eval_training
21 | from learned_optimization.learned_optimizers import base as lopt_base
22 | from learned_optimization.tasks import quadratics
23 |
24 |
25 | class EvalTrainingMultiDeviceTest(absltest.TestCase):
26 |
27 | def test_multi_task_training_curves(self):
28 | num_devices = len(jax.local_devices())
29 | if num_devices != 8:
30 | self.skipTest(f"Not enough accelerators! Expected 8, found {num_devices}."
31 |                     " Please run this test with exactly 8 accelerators.")
32 |
33 | key = jax.random.PRNGKey(0)
34 | task_family = quadratics.FixedDimQuadraticFamilyData(10)
35 | learned_opt = lopt_base.LearnableSGD()
36 | theta = learned_opt.init(key)
37 |
38 | res = eval_training.multi_task_training_curves(
39 | task_family,
40 | learned_opt,
41 | theta,
42 | n_tasks=16, # 2 tasks per device
43 | key=key,
44 | eval_every=2,
45 | steps_per_jit=2,
46 | n_eval_batches=3,
47 | n_eval_batches_vec=4,
48 | steps=10,
49 | with_aux_values=("l2",),
50 | )
51 | self.assertIn("train/xs", res)
52 | self.assertIn("train/loss", res)
53 | self.assertIn("eval/train/loss", res)
54 |
55 |
56 | if __name__ == "__main__":
57 | absltest.main()
58 |
--------------------------------------------------------------------------------
/learned_optimization/eval_training_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.eval_training."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | from learned_optimization import eval_training
21 | from learned_optimization.learned_optimizers import base as lopt_base
22 | from learned_optimization.optimizers import base as opt_base
23 | from learned_optimization.tasks import quadratics
24 |
25 |
26 | class EvalTrainingTest(absltest.TestCase):
27 |
28 | def test_single_task_training_curves_no_data(self):
29 | key = jax.random.PRNGKey(0)
30 | task = quadratics.QuadraticTask(10)
31 | opt = opt_base.SGD()
32 | res = eval_training.single_task_training_curves(
33 | task, opt, 10, key, eval_every=2)
34 | self.assertIn("train/xs", res)
35 | self.assertIn("train/loss", res)
36 | self.assertEqual(res["train/loss"].shape, (11,))
37 |
38 | def test_single_task_training_curves_with_data(self):
39 | key = jax.random.PRNGKey(0)
40 | task = quadratics.BatchQuadraticTask()
41 | opt = opt_base.SGD()
42 | res = eval_training.single_task_training_curves(
43 | task, opt, 10, key, eval_every=2)
44 | self.assertIn("train/xs", res)
45 | self.assertIn("train/loss", res)
46 | self.assertEqual(res["train/loss"].shape, (11,))
47 |
48 | def test_multi_task_training_curves(self):
49 | key = jax.random.PRNGKey(0)
50 | task_family = quadratics.FixedDimQuadraticFamilyData(10)
51 | learned_opt = lopt_base.LearnableSGD()
52 | theta = learned_opt.init(key)
53 |
54 | res = eval_training.multi_task_training_curves(
55 | task_family,
56 | learned_opt,
57 | theta,
58 | n_tasks=5,
59 | key=key,
60 | eval_every=2,
61 | steps_per_jit=2,
62 | steps=10,
63 | with_aux_values=("l2",),
64 | )
65 | self.assertIn("train/xs", res)
66 | self.assertIn("train/loss", res)
67 | self.assertIn("eval/train/loss", res)
68 | self.assertEqual(res["train/loss"].shape, (
69 | 5,
70 | 10,
71 | ))
72 |
73 | # 10 / 2 and then 1 additional due to eval-ing at the end.
74 | self.assertEqual(res["eval/train/loss"].shape, (
75 | 5,
76 | 6,
77 | ))
78 |
79 | self.assertIn("eval/train/aux/l2", res)
80 | self.assertNotIn("eval/train/aux/l1", res)
81 |
82 |
83 | if __name__ == "__main__":
84 | absltest.main()
85 |
--------------------------------------------------------------------------------
/learned_optimization/examples/README.md:
--------------------------------------------------------------------------------
1 | # Examples
2 |
3 | ### Simple LOpt Train
4 | A small example training a learned optimizer.
5 |
--------------------------------------------------------------------------------
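
Invoking it programmatically looks roughly like the test in `simple_lopt_train_test.py` below; the number of training steps here is illustrative.

```python
# Train the example learned optimizer for a few steps on a quadratic task.
import tempfile

from learned_optimization.examples import simple_lopt_train
from learned_optimization.tasks import quadratics

with tempfile.TemporaryDirectory() as path:
  simple_lopt_train.train(path, 10, task=quadratics.QuadraticTask())
```
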
/learned_optimization/examples/simple_lopt_train_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for simple_lopt_train."""
17 |
18 | import tempfile
19 |
20 | from absl.testing import absltest
21 | from learned_optimization.examples import simple_lopt_train
22 | from learned_optimization.tasks import quadratics
23 |
24 |
25 | class SimpleLoptTrainTest(absltest.TestCase):
26 |
27 | def test_train(self):
28 | with tempfile.TemporaryDirectory() as path:
29 | simple_lopt_train.train(path, 10, task=quadratics.QuadraticTask())
30 |
31 |
32 | if __name__ == '__main__':
33 | absltest.main()
34 |
--------------------------------------------------------------------------------
/learned_optimization/filesystem.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Filesystem interface for writing to different data sources."""
17 |
18 | import glob as py_glob
19 | import os
20 | import shutil
21 | from typing import Sequence
22 | import tensorflow as tf
23 |
24 |
25 |
26 | def _path_on_gcp(path: str) -> bool:
27 | prefixes = ["gs://"]
28 | return any([path.startswith(p) for p in prefixes])
29 |
30 |
31 | def file_open(path: str, mode: str):
32 | """Open a file, returning a file object."""
33 | if _path_on_gcp(path):
34 | return tf.io.gfile.GFile(path, mode)
35 | return open(path, mode)
36 |
37 |
38 | def make_dirs(path: str):
39 | """Make directories for given path."""
40 | if _path_on_gcp(path):
41 | return tf.io.gfile.makedirs(path)
42 |
43 | if not os.path.exists(path):
44 | return os.makedirs(path)
45 |
46 |
47 | def copy(path: str, target: str):
48 | """Copy path to target."""
49 | if _path_on_gcp(path):
50 | return tf.io.gfile.copy(path, target, overwrite=True)
51 |
52 | return shutil.copy(path, target)
53 |
54 |
55 | def rename(path: str, target: str):
56 | """Copy path to target."""
57 | if _path_on_gcp(path):
58 | return tf.io.gfile.rename(path, target, overwrite=True)
59 |
60 | return shutil.move(path, target)
61 |
62 |
63 | def exists(path: str) -> bool:
64 | """Check if a file exists."""
65 | if _path_on_gcp(path):
66 | return tf.io.gfile.exists(path)
67 |
68 | return os.path.exists(path)
69 |
70 |
71 | def glob(pattern: str) -> Sequence[str]:
72 | """Glob the filesystem with given pattern."""
73 | if _path_on_gcp(pattern):
74 | return tf.io.gfile.glob(pattern)
75 |
76 | return py_glob.glob(pattern)
77 |
78 |
79 | def remove(path: str) -> bool:
80 | """Remove a file."""
81 | if _path_on_gcp(path):
82 | return tf.io.gfile.remove(path)
83 | return shutil.rmtree(path) # pytype: disable=bad-return-type
84 |
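
A rough usage sketch (not part of the module): each helper dispatches on the path prefix, routing `gs://` paths through `tf.io.gfile` and everything else through the standard library. The local paths below are made up for illustration.

```python
# Illustrative local-path usage of the filesystem helpers defined above.
from learned_optimization import filesystem

filesystem.make_dirs("/tmp/lopt_demo")
with filesystem.file_open("/tmp/lopt_demo/notes.txt", "w") as f:
  f.write("hello")

assert filesystem.exists("/tmp/lopt_demo/notes.txt")
print(filesystem.glob("/tmp/lopt_demo/*.txt"))  # ['/tmp/lopt_demo/notes.txt']
# A "gs://..." path would instead route each call through tf.io.gfile.
```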
--------------------------------------------------------------------------------
/learned_optimization/jax_utils_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for jax_utils."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | from learned_optimization import jax_utils
21 |
22 |
23 | class JaxUtilsTest(absltest.TestCase):
24 |
25 | def test_in_jit(self):
26 | state = []
27 |
28 | def fn(x):
29 | if jax_utils.in_jit():
30 | state.append(None)
31 | return x * 2
32 |
33 | _ = fn(1)
34 | self.assertEmpty(state)
35 |
36 | _ = jax.jit(fn)(1)
37 | self.assertLen(state, 1)
38 |
39 |
40 | if __name__ == '__main__':
41 | absltest.main()
42 |
--------------------------------------------------------------------------------
/learned_optimization/launchers/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """learned_optimization.launchers"""
17 |
--------------------------------------------------------------------------------
/learned_optimization/learned_optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """learned_optimizers.learned_optimizers module."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/learned_optimizers/adafac_mlp_lopt_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.adafac_mlp_lopt."""
17 | from absl.testing import absltest
18 | import jax
19 | from learned_optimization.learned_optimizers import adafac_mlp_lopt
20 | from learned_optimization.learned_optimizers import test_utils
21 | import numpy as onp
22 |
23 |
24 | class AdafacMLPLOptTest(absltest.TestCase):
25 |
26 | def test_adafac_mlp_lopt(self):
27 | test_utils.smoketest_learned_optimizer(adafac_mlp_lopt.AdafacMLPLOpt())
28 |
29 | def test_split_weights_equal_to_full(self):
30 | lopt1 = adafac_mlp_lopt.AdafacMLPLOpt(
31 | concat_weights=True, step_mult=10, exp_mult=1)
32 | lopt2 = adafac_mlp_lopt.AdafacMLPLOpt(
33 | concat_weights=False, split_weights=True, step_mult=10, exp_mult=1)
34 |
35 | key = jax.random.PRNGKey(0)
36 | theta = lopt1.init(key)
37 | opt1 = lopt1.opt_fn(theta)
38 | opt2 = lopt2.opt_fn(theta)
39 |
40 | p = (jax.random.normal(key, [2, 2]),)
41 | g = (jax.random.normal(key, [2, 2]),)
42 |
43 | opt_state = opt1.init(p, num_steps=10)
44 | print(opt_state)
45 |
46 | opt_state1 = opt1.update(opt_state, g, 1.0)
47 | opt_state2 = opt2.update(opt_state, g, 1.0)
48 |
49 | diff = opt_state1.params[0] - opt_state2.params[0]
50 | self.assertLess(onp.mean(onp.abs(diff)), 1e-6)
51 |
52 |
53 | if __name__ == '__main__':
54 | absltest.main()
55 |
--------------------------------------------------------------------------------
/learned_optimization/learned_optimizers/adafac_nominal_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.adafac_nominal."""
17 | from absl.testing import absltest
18 | from learned_optimization.learned_optimizers import adafac_nominal
19 | from learned_optimization.learned_optimizers import test_utils
20 |
21 |
22 | class MLPNomLOptTest(absltest.TestCase):
23 |
24 | def test_adafac_mlp_lopt(self):
25 | test_utils.smoketest_learned_optimizer(adafac_nominal.MLPNomLOpt())
26 |
27 |
28 | if __name__ == '__main__':
29 | absltest.main()
30 |
--------------------------------------------------------------------------------
/learned_optimization/learned_optimizers/base_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.learned_optimizers.base."""
17 | from absl.testing import absltest
18 | from learned_optimization.learned_optimizers import base
19 | from learned_optimization.learned_optimizers import test_utils
20 |
21 |
22 | class BaseTest(absltest.TestCase):
23 |
24 | def test_learnable_adam(self):
25 | test_utils.smoketest_learned_optimizer(base.LearnableAdam())
26 |
27 | def test_learnable_sgd(self):
28 | test_utils.smoketest_learned_optimizer(base.LearnableSGD())
29 |
30 | def test_learnable_sgdm(self):
31 | test_utils.smoketest_learned_optimizer(base.LearnableSGDM())
32 |
33 |
34 | if __name__ == '__main__':
35 | absltest.main()
36 |
--------------------------------------------------------------------------------
/learned_optimization/learned_optimizers/mlp_lopt_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.mlp_lopt."""
17 | from absl.testing import absltest
18 | from learned_optimization.learned_optimizers import mlp_lopt
19 | from learned_optimization.learned_optimizers import test_utils
20 |
21 |
22 | class MLPLOptTest(absltest.TestCase):
23 |
24 | def test_mlp_lopt(self):
25 | test_utils.smoketest_learned_optimizer(mlp_lopt.MLPLOpt())
26 |
27 |
28 | if __name__ == '__main__':
29 | absltest.main()
30 |
--------------------------------------------------------------------------------
/learned_optimization/learned_optimizers/nn_adam_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.nn_adam."""
17 | from absl.testing import absltest
18 | from learned_optimization.learned_optimizers import nn_adam
19 | from learned_optimization.learned_optimizers import test_utils
20 |
21 |
22 | class NNAdamTest(absltest.TestCase):
23 |
24 | def test_nn_adam(self):
25 | test_utils.smoketest_learned_optimizer(nn_adam.NNAdam())
26 |
27 |
28 | if __name__ == '__main__':
29 | absltest.main()
30 |
--------------------------------------------------------------------------------
/learned_optimization/learned_optimizers/rnn_mlp_lopt_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.rnn_mlp_lopt."""
17 | from absl.testing import absltest
18 | from learned_optimization.learned_optimizers import rnn_mlp_lopt
19 | from learned_optimization.learned_optimizers import test_utils
20 |
21 |
22 | class RNNMLPLOptTest(absltest.TestCase):
23 |
24 | def test_rnn_mlp_lopt(self):
25 | test_utils.smoketest_learned_optimizer(rnn_mlp_lopt.RNNMLPLOpt())
26 |
27 |
28 | if __name__ == '__main__':
29 | absltest.main()
30 |
--------------------------------------------------------------------------------
/learned_optimization/learned_optimizers/test_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Test utilities for testing learned optimizers."""
17 | import jax
18 | from learned_optimization.optimizers import test_utils as opt_test_utils
19 |
20 |
21 | def smoketest_learned_optimizer(learned_optimizer, strict_types=True):
22 | key = jax.random.PRNGKey(0)
23 |
24 | meta_param = learned_optimizer.init(key)
25 | optimizer = learned_optimizer.opt_fn(meta_param)
26 |
27 | opt_test_utils.smoketest_optimizer(optimizer, strict_types=strict_types)
28 |
--------------------------------------------------------------------------------
/learned_optimization/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """learned_optimizers.optimizers module."""
17 | from learned_optimization.optimizers.base import *
18 | from learned_optimization.optimizers.optax_opts import *
19 |
--------------------------------------------------------------------------------
/learned_optimization/optimizers/alias.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Adam configurations, which are named for easy usage from gin."""
17 | import gin
18 | from learned_optimization.optimizers import learning_rate_schedules
19 | from learned_optimization.optimizers import optax_opts
20 |
21 |
22 | @gin.configurable()
23 | def AdamSqrtDecayLR(learning_rate):
24 | lr_decay = learning_rate_schedules.LinearRampupSqrtDecay(
25 | learning_rate, warmup_steps=1000)
26 | return optax_opts.Adam(lr_decay)
27 |
28 |
29 | @gin.configurable()
30 | def AdamExpDecayLR(learning_rate):
31 | # 7e-5 results in roughly a 3 order of magnitude reduction over 100_000 steps.
32 | lr_decay = learning_rate_schedules.ExponentialDecay(learning_rate, 7e-5)
33 | return optax_opts.Adam(lr_decay)
34 |
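
For reference, a hedged sketch of constructing these presets directly in Python rather than binding `learning_rate` through gin; the 3e-4 value is an arbitrary example.

```python
# Sketch only: build the gin-named Adam presets directly.
from learned_optimization.optimizers import alias

opt_sqrt = alias.AdamSqrtDecayLR(3e-4)  # linear warmup, then sqrt-decay schedule
opt_exp = alias.AdamExpDecayLR(3e-4)    # exponentially decaying learning rate
```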
--------------------------------------------------------------------------------
/learned_optimization/optimizers/base_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.optimizers.base."""
17 | from absl.testing import absltest
18 | from learned_optimization.optimizers import base
19 | from learned_optimization.optimizers import optax_opts
20 | from learned_optimization.optimizers import test_utils
21 |
22 |
23 | class BaseTest(absltest.TestCase):
24 |
25 | def test_gradient_clip_adam(self):
26 | test_utils.smoketest_optimizer(
27 | base.GradientClipOptimizer(optax_opts.Adam(1e-4)))
28 |
29 | def test_grafted_optimizer(self):
30 | test_utils.smoketest_optimizer(
31 | base.GraftedOptimizer(optax_opts.SGDM(1e-2), optax_opts.Adam(1e-4)))
32 |
33 | if __name__ == '__main__':
34 | absltest.main()
35 |
--------------------------------------------------------------------------------
/learned_optimization/optimizers/gradient_accumulator_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for gradient_accumulator."""
17 |
18 | from absl.testing import absltest
19 | from learned_optimization.optimizers import gradient_accumulator
20 | from learned_optimization.optimizers import optax_opts
21 | from learned_optimization.optimizers import test_utils
22 |
23 |
24 | class GradientAccumulatorTest(absltest.TestCase):
25 |
26 | def test_gradient_accumulator(self):
27 | opt = optax_opts.Adam()
28 | opt = gradient_accumulator.GradientAccumulator(opt, 5)
29 | test_utils.smoketest_optimizer(opt)
30 |
31 |
32 | if __name__ == '__main__':
33 | absltest.main()
34 |
--------------------------------------------------------------------------------
/learned_optimization/optimizers/nadamw_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.optimizers.nadamw."""
17 | from absl.testing import absltest
18 | from learned_optimization.optimizers import learning_rate_schedules
19 | from learned_optimization.optimizers import nadamw
20 | from learned_optimization.optimizers import test_utils
21 |
22 |
23 | class NAdamWTest(absltest.TestCase):
24 |
25 | def test_nadamw_nesterov(self):
26 | test_utils.smoketest_optimizer(nadamw.NAdamW(use_nesterov=True))
27 |
28 | def test_nadamw_weight_decays(self):
29 | test_utils.smoketest_optimizer(
30 | nadamw.NAdamW(adamw_weight_decay=1.0, l2_weight_decay=1.0))
31 |
32 | def test_nadamw_with_schedule(self):
33 | sched = learning_rate_schedules.CosineLearningRateSchedule(
34 | 1e-3, 0.1, 0.3, 0.1)
35 | test_utils.smoketest_optimizer(nadamw.NAdamW(learning_rate=sched))
36 |
37 |
38 | if __name__ == '__main__':
39 | absltest.main()
40 |
--------------------------------------------------------------------------------
/learned_optimization/optimizers/opt_list_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.optimizers.opt_list."""
17 | from absl.testing import absltest
18 | import jax
19 | import jax.numpy as jnp
20 | from learned_optimization.optimizers import opt_list
21 | from learned_optimization.optimizers import test_utils
22 | import numpy as onp
23 |
24 |
25 | class OptListTest(absltest.TestCase):
26 |
27 |   def test_opt_list(self):
28 | test_utils.smoketest_optimizer(opt_list.OptList(0))
29 |
30 | def test_opt_list_vmap_idx(self):
31 |
32 | def init_update(idx):
33 | opt = opt_list.OptList(idx)
34 | p = jnp.ones(10)
35 | opt_state = opt.init(p, num_steps=10)
36 | return opt.get_params(opt.update(opt_state, p))
37 |
38 | result = jax.vmap(init_update)(jnp.arange(10, dtype=jnp.int32))
39 | # ensure we computed different updates.
40 | self.assertGreater(onp.sum(onp.abs(result[0, :] - result[1, :])), 0.0001)
41 |
42 |
43 | if __name__ == '__main__':
44 | absltest.main()
45 |
--------------------------------------------------------------------------------
/learned_optimization/optimizers/opt_to_optax.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Convert a learned_optimization Optimizer to an optax update."""
17 |
18 | import dataclasses
19 | from typing import Any, Mapping, Optional, Tuple, NamedTuple
20 |
21 | import chex
22 | from learned_optimization import tree_utils
23 | from learned_optimization.optimizers import base
24 | import optax
25 |
26 |
27 | def opt_to_optax_opt(
28 | opt: base.Optimizer, num_steps: Optional[int] = None
29 | ) -> optax.GradientTransformationExtraArgs:
30 | """Convert a learned_optimization optimizer to an optax optimizers.
31 |
32 | This makes use of optax's 'extra_args' argument. For many learned optimizers
33 | this will contain loss values.
34 |
35 | Args:
36 | opt: the learned_optimization optimizer to convert.
37 | num_steps: The number of steps the optimizer is expected to run for.
38 |
39 | Returns:
40 | init_and_update: the optax optimizer.
41 | """
42 |
43 | def init_fn(params: chex.ArrayTree,
44 | *,
45 | extra_args: Optional[Mapping[str, Any]] = None) -> chex.ArrayTree:
46 | del extra_args
47 | opt_state = opt.init(params, num_steps=num_steps)
48 | if dataclasses.is_dataclass(opt_state):
49 | return opt.set_params(opt_state, ())
50 | else:
51 | raise NotImplementedError("Only flax dataclasses are supported!")
52 |
53 | def update_fn(
54 | updates: chex.ArrayTree,
55 | state: chex.ArrayTree,
56 | params: Optional[chex.ArrayTree] = None,
57 | *,
58 | extra_args: Optional[Mapping[str, Any]] = None
59 | ) -> Tuple[chex.ArrayTree, chex.ArrayTree]:
60 | if extra_args is None:
61 | extra_args = {}
62 |
63 | if params is None:
64 | raise ValueError("Params must not be None!")
65 |
66 | if dataclasses.is_dataclass(state):
67 | state = opt.set_params(state, params)
68 | else:
69 | raise NotImplementedError("Only flax dataclasses are supported!")
70 |
71 | next_state = opt.update(state, updates, **extra_args)
72 |
73 | step = tree_utils.tree_sub(opt.get_params(next_state), params)
74 |
75 | next_state = opt.set_params(next_state, ())
76 |
77 | return step, next_state
78 |
79 | return optax.GradientTransformationExtraArgs(init_fn, update_fn)
80 |
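
A hedged usage sketch of the wrapper above, following the same pattern exercised in `opt_to_optax_test.py` below; the parameter pytree and gradients are made-up stand-ins.

```python
# Sketch: wrap a hand-designed Adam and take one optax-style step with it.
import jax
import jax.numpy as jnp
import optax

from learned_optimization.optimizers import opt_to_optax
from learned_optimization.optimizers import optax_opts

params = {"w": jnp.zeros([3]), "b": jnp.zeros([1])}
optax_opt = opt_to_optax.opt_to_optax_opt(optax_opts.Adam(1e-4), num_steps=100)

opt_state = optax_opt.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)  # stand-in gradients
step, opt_state = optax_opt.update(grads, opt_state, params=params)
params = optax.apply_updates(params, step)
```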
--------------------------------------------------------------------------------
/learned_optimization/optimizers/opt_to_optax_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for opt_to_optax."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | from learned_optimization.learned_optimizers import rnn_mlp_lopt
21 | from learned_optimization.optimizers import opt_to_optax
22 | from learned_optimization.optimizers import optax_opts
23 | from learned_optimization.tasks import quadratics
24 | import optax
25 |
26 |
27 | class OptToOptaxTest(absltest.TestCase):
28 |
29 | def test_adam_opt_to_optax_opt(self):
30 | opt = optax_opts.Adam(1e-4)
31 | task = quadratics.QuadraticTask()
32 | key = jax.random.PRNGKey(0)
33 | p = task.init(key)
34 |
35 | optax_opt = opt_to_optax.opt_to_optax_opt(opt)
36 |
37 | opt_state = optax_opt.init(p)
38 | updates = p
39 | step, opt_state = optax_opt.update(updates, opt_state, params=p)
40 | p = optax.apply_updates(p, step)
41 |
42 | step, opt_state = optax_opt.update(updates, opt_state, params=p)
43 | p = optax.apply_updates(p, step)
44 |
45 | def test_learned_opt_to_optax_opt(self):
46 | lopt = rnn_mlp_lopt.RNNMLPLOpt()
47 | key = jax.random.PRNGKey(0)
48 |
49 | lo_opt = lopt.opt_fn(lopt.init(key))
50 | optax_opt = opt_to_optax.opt_to_optax_opt(lo_opt, num_steps=100)
51 |
52 | task = quadratics.QuadraticTask()
53 | p = task.init(key)
54 |
55 | opt_state = optax_opt.init(p)
56 | updates = p
57 | step, opt_state = optax_opt.update(
58 | updates, opt_state, params=p, extra_args={
59 | "loss": 1.0,
60 | "key": key
61 | })
62 | p = optax.apply_updates(p, step)
63 |
64 |
65 | if __name__ == "__main__":
66 | absltest.main()
67 |
--------------------------------------------------------------------------------
/learned_optimization/optimizers/optax_opts_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for optax_opts."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from learned_optimization.optimizers import optax_opts
21 | from learned_optimization.optimizers import test_utils
22 |
23 |
24 | class OptaxOptsTest(parameterized.TestCase):
25 |
26 | @parameterized.named_parameters(
27 | ('RMSProp', lambda: optax_opts.RMSProp(1e-4)),
28 | ('SGDM', lambda: optax_opts.SGDM(1e-4)),
29 | ('SGD', lambda: optax_opts.SGD(1e-3)),
30 | ('Adam', lambda: optax_opts.Adam(1e-4)),
31 | ('AdaBelief', lambda: optax_opts.AdaBelief(1e-4)),
32 | ('AdaGrad', lambda: optax_opts.AdaGrad(1e-4)),
33 | ('Adafactor', lambda: optax_opts.Adafactor(1e-4)),
34 | ('Yogi', lambda: optax_opts.Yogi(1e-4)),
35 | ('RAdam', lambda: optax_opts.RAdam(1e-4)),
36 | ('AdamW', lambda: optax_opts.AdamW(1e-4)),
37 | ('Lamb', lambda: optax_opts.Lamb(1e-4)),
38 | ('Lars', lambda: optax_opts.Lars(1e-4)),
39 | ('Fromage', lambda: optax_opts.Fromage(1e-4)),
40 | ('PiecewiseLinearAdam', optax_opts.PiecewiseLinearAdam))
41 | def test_opt(self, opt_fn):
42 | test_utils.smoketest_optimizer(opt_fn())
43 |
44 | def test_sm3(self):
45 | # SM3 seems to switch the dtype of the state variables.
46 | test_utils.smoketest_optimizer(optax_opts.SM3(1e-4), strict_types=False)
47 |
48 |
49 | if __name__ == '__main__':
50 | absltest.main()
51 |
--------------------------------------------------------------------------------
/learned_optimization/optimizers/shampoo_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.optimizers.opt_list."""
17 | from absl.testing import absltest
18 | from learned_optimization.optimizers import shampoo
19 | from learned_optimization.optimizers import test_utils
20 |
21 |
22 | class ShampooTest(absltest.TestCase):
23 |
24 | def test_shampoo(self):
25 | test_utils.smoketest_optimizer(
26 | shampoo.Shampoo(1e-4, 32), strict_types=False)
27 |
28 |
29 | if __name__ == '__main__':
30 | absltest.main()
31 |
--------------------------------------------------------------------------------
/learned_optimization/outer_trainers/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """learned_optimizers.outer_trainers module."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/outer_trainers/full_grad_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.outer_trainers.full_grad."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from learned_optimization.learned_optimizers import base
21 | from learned_optimization.outer_trainers import full_grad
22 | from learned_optimization.outer_trainers import lopt_truncated_step
23 | from learned_optimization.outer_trainers import test_utils
24 | from learned_optimization.outer_trainers import truncation_schedule
25 | from learned_optimization.tasks import quadratics
26 |
27 |
28 | class FullGradTest(parameterized.TestCase):
29 |
30 |   def test_full_grad_trainer(self):
31 | learned_opt = base.LearnableSGD()
32 | task_family = quadratics.FixedDimQuadraticFamily(10)
33 | trunc_sched = truncation_schedule.NeverEndingTruncationSchedule()
34 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
35 | task_family, learned_opt, trunc_sched, num_tasks=5)
36 |
37 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10)
38 | trainer = full_grad.FullGrad(
39 | truncated_step, truncation_schedule=trunc_sched, loss_type="avg")
40 |
41 | test_utils.trainer_smoketest(trainer)
42 |
43 | def test_full_grad_trainer_with_data(self):
44 | learned_opt = base.LearnableSGD()
45 | task_family = quadratics.FixedDimQuadraticFamilyData(10)
46 | trunc_sched = truncation_schedule.NeverEndingTruncationSchedule()
47 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
48 | task_family, learned_opt, trunc_sched, num_tasks=5)
49 |
50 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10)
51 | trainer = full_grad.FullGrad(
52 | truncated_step, loss_type="last", truncation_schedule=trunc_sched)
53 |
54 | test_utils.trainer_smoketest(trainer)
55 |
56 |
57 | if __name__ == "__main__":
58 | absltest.main()
59 |
--------------------------------------------------------------------------------
/learned_optimization/outer_trainers/test_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utilities to test trainers."""
17 | import jax
18 | from learned_optimization.outer_trainers import gradient_learner
19 |
20 |
21 | def trainer_smoketest(trainer):
22 | key = jax.random.PRNGKey(0)
23 | theta = trainer.truncated_step.outer_init(key)
24 | worker_weights = gradient_learner.WorkerWeights(
25 | theta, None, gradient_learner.OuterState(1) # pytype: disable=wrong-arg-types # jax-ndarray
26 | )
27 | state = trainer.init_worker_state(worker_weights, key=key)
28 | out, metrics = trainer.compute_gradient_estimate(
29 | worker_weights, key, state, with_summary=True)
30 | del out
31 | del metrics
32 |
--------------------------------------------------------------------------------
/learned_optimization/outer_trainers/truncated_es_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.outer_trainers.truncated_es."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from learned_optimization.learned_optimizers import base
21 | from learned_optimization.outer_trainers import lopt_truncated_step
22 | from learned_optimization.outer_trainers import test_utils
23 | from learned_optimization.outer_trainers import truncated_es
24 | from learned_optimization.outer_trainers import truncation_schedule
25 | from learned_optimization.tasks import quadratics
26 |
27 |
28 | class TruncatedESTest(parameterized.TestCase):
29 |
30 | def test_truncated_es_trainer(self):
31 | learned_opt = base.LearnableSGD()
32 | task_family = quadratics.FixedDimQuadraticFamily(10)
33 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10)
34 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
35 | task_family, learned_opt, trunc_sched, num_tasks=5)
36 |
37 | trainer = truncated_es.TruncatedES(
38 | truncated_step, steps_per_jit=3, unroll_length=9)
39 | test_utils.trainer_smoketest(trainer)
40 |
41 | @parameterized.product(meta_loss_split=(None, "train", "same_data"))
42 | def test_truncated_es_trainer_with_data(self, meta_loss_split):
43 | learned_opt = base.LearnableSGD()
44 | task_family = quadratics.FixedDimQuadraticFamilyData(10)
45 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10)
46 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
47 | task_family,
48 | learned_opt,
49 | trunc_sched,
50 | num_tasks=4,
51 | meta_loss_split=meta_loss_split)
52 |
53 | trainer = truncated_es.TruncatedES(
54 | truncated_step, steps_per_jit=5, unroll_length=5)
55 |
56 | test_utils.trainer_smoketest(trainer)
57 |
58 | if __name__ == "__main__":
59 | absltest.main()
60 |
--------------------------------------------------------------------------------
/learned_optimization/outer_trainers/truncated_grad_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.outer_trainers.truncated_grad."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from learned_optimization.learned_optimizers import base
21 | from learned_optimization.outer_trainers import lopt_truncated_step
22 | from learned_optimization.outer_trainers import test_utils
23 | from learned_optimization.outer_trainers import truncated_grad
24 | from learned_optimization.outer_trainers import truncation_schedule
25 | from learned_optimization.tasks import quadratics
26 |
27 |
28 | class TruncatedGradTest(parameterized.TestCase):
29 |
30 | def test_truncated_grad_trainer(self):
31 | learned_opt = base.LearnableSGD()
32 | task_family = quadratics.FixedDimQuadraticFamily(10)
33 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10)
34 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
35 | task_family, learned_opt, trunc_sched, num_tasks=5)
36 |
37 | trainer = truncated_grad.TruncatedGrad(
38 | truncated_step, steps_per_jit=3, unroll_length=9)
39 | test_utils.trainer_smoketest(trainer)
40 |
41 | @parameterized.product(meta_loss_split=(None, "train"))
42 | def test_truncated_grad_trainer_with_data(self, meta_loss_split):
43 | learned_opt = base.LearnableSGD()
44 | task_family = quadratics.FixedDimQuadraticFamilyData(10)
45 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10)
46 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
47 | task_family,
48 | learned_opt,
49 | trunc_sched,
50 | num_tasks=5,
51 | meta_loss_split=meta_loss_split)
52 |
53 | trainer = truncated_grad.TruncatedGrad(
54 | truncated_step, steps_per_jit=5, unroll_length=5)
55 |
56 | test_utils.trainer_smoketest(trainer)
57 |
58 | if __name__ == "__main__":
59 | absltest.main()
60 |
--------------------------------------------------------------------------------
/learned_optimization/outer_trainers/truncated_pes_pmap_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.outer_trainers.truncated_pes."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import jax
21 | from learned_optimization.learned_optimizers import base
22 | from learned_optimization.outer_trainers import lopt_truncated_step
23 | from learned_optimization.outer_trainers import test_utils
24 | from learned_optimization.outer_trainers import truncated_pes
25 | from learned_optimization.outer_trainers import truncation_schedule
26 | from learned_optimization.tasks import quadratics
27 |
28 |
29 | class TruncatedPesTest(parameterized.TestCase):
30 |
31 | @parameterized.product(meta_loss_split=(None, "train"))
32 | def test_truncated_pes_pmap_trainer_with_data(self, meta_loss_split):
33 | learned_opt = base.LearnableSGD()
34 | task_family = quadratics.FixedDimQuadraticFamilyData(10)
35 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10)
36 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
37 | task_family,
38 | learned_opt,
39 | trunc_sched,
40 | meta_loss_split=meta_loss_split,
41 | num_tasks=8)
42 |
43 | num_devices = len(jax.local_devices())
44 | if num_devices != 8:
45 | self.skipTest(f"Not enough accelerators! Expected 8, found {num_devices}."
46 | " Please fun this test with exactly 8 accelerators.")
47 |
48 | trainer = truncated_pes.TruncatedPESPMAP(
49 | truncated_step=truncated_step, num_devices=num_devices, steps_per_jit=5)
50 |
51 | test_utils.trainer_smoketest(trainer)
52 |
53 |
54 | if __name__ == "__main__":
55 | absltest.main()
56 |
--------------------------------------------------------------------------------
/learned_optimization/outer_trainers/truncated_pes_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.outer_trainers.truncated_pes."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from learned_optimization.learned_optimizers import base
21 | from learned_optimization.outer_trainers import lopt_truncated_step
22 | from learned_optimization.outer_trainers import test_utils
23 | from learned_optimization.outer_trainers import truncated_pes
24 | from learned_optimization.outer_trainers import truncation_schedule
25 | from learned_optimization.tasks import quadratics
26 |
27 |
28 | class TruncatedPesTest(parameterized.TestCase):
29 |
30 | def test_truncated_pes_trainer(self):
31 | learned_opt = base.LearnableSGD()
32 | task_family = quadratics.FixedDimQuadraticFamily(10)
33 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10)
34 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
35 | task_family, learned_opt, trunc_sched, num_tasks=5)
36 |
37 | trainer = truncated_pes.TruncatedPES(truncated_step, steps_per_jit=5)
38 | test_utils.trainer_smoketest(trainer)
39 |
40 | @parameterized.product(meta_loss_split=(None, "train"))
41 | def test_truncated_pes_trainer_with_data(self, meta_loss_split):
42 | learned_opt = base.LearnableSGD()
43 | task_family = quadratics.FixedDimQuadraticFamilyData(10)
44 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10)
45 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
46 | task_family,
47 | learned_opt,
48 | trunc_sched,
49 | num_tasks=5,
50 | meta_loss_split=meta_loss_split)
51 |
52 | trainer = truncated_pes.TruncatedPES(truncated_step, steps_per_jit=5)
53 |
54 | test_utils.trainer_smoketest(trainer)
55 |
56 | def test_truncated_pes_stacked(self):
57 | learned_opt = base.LearnableSGD()
58 | task_family = quadratics.FixedDimQuadraticFamilyData(10)
59 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10)
60 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
61 | task_family, learned_opt, trunc_sched, num_tasks=5)
62 | trainer = truncated_pes.TruncatedPES(
63 | truncated_step, steps_per_jit=5, stack_antithetic_samples=True)
64 |
65 | test_utils.trainer_smoketest(trainer)
66 |
67 | if __name__ == "__main__":
68 | absltest.main()
69 |
--------------------------------------------------------------------------------
/learned_optimization/outer_trainers/truncation_schedule_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.outer_trainers.truncation_schedule."""
17 |
18 | from absl.testing import absltest
19 | import flax
20 | import jax
21 | from learned_optimization.optimizers import learning_rate_schedules
22 | from learned_optimization.outer_trainers import truncation_schedule
23 |
24 |
25 | @flax.struct.dataclass
26 | class OuterState:
27 | outer_iteration: int
28 |
29 |
30 | class TruncationScheduleTest(absltest.TestCase):
31 |
32 | def test_constant_truncation_sched(self):
33 | key = jax.random.PRNGKey(0)
34 | trunc_fns = truncation_schedule.ConstantTruncationSchedule(10)
35 | state = trunc_fns.init(key, outer_state=None)
36 | rng = key
37 | for i in range(10):
38 | rng, key = jax.random.split(rng)
39 | state, is_done = trunc_fns.next_state(state, i, key, outer_state=None)
40 | del is_done
41 |
42 | def test_LogUniformLengthSchedule(self):
43 | key = jax.random.PRNGKey(0)
44 | trunc_fns = truncation_schedule.LogUniformLengthSchedule(10, 100)
45 | state = trunc_fns.init(key, outer_state=None)
46 | rng = key
47 | for i in range(10):
48 | rng, key = jax.random.split(rng)
49 | state, is_done = trunc_fns.next_state(state, i, key, outer_state=None)
50 | assert state.length >= 10
51 | assert state.length <= 100
52 | del is_done
53 |
54 | def test_ScheduledTruncationSchedule(self):
55 | sched = learning_rate_schedules.PiecewiseLinear([0, 10], [10, 100])
56 | key = jax.random.PRNGKey(0)
57 | trunc_fns = truncation_schedule.ScheduledTruncationSchedule(sched)
58 | outer_state = OuterState(0)
59 | state = trunc_fns.init(key, outer_state=outer_state)
60 | assert state.length == 10
61 |
62 | outer_state = OuterState(5)
63 | state = trunc_fns.init(key, outer_state=outer_state)
64 |     assert state.length == (10 + 100) // 2
65 |
66 |
67 | if __name__ == '__main__':
68 | absltest.main()
69 |
--------------------------------------------------------------------------------
/learned_optimization/population/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Population based training module."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/population/examples/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """learned_optimization.population.examples."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/population/examples/complex_cnn/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """learned_optimization.population.examples.complex_cnn."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/population/examples/simple_cnn/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """learned_optimization.population.examples.simple_cnn."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/population/examples/simple_cnn/train_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimization.population.examples.simple_cnn.train."""
17 | import tempfile
18 |
19 | from absl.testing import absltest
20 | from learned_optimization.population.examples.simple_cnn import train
21 |
22 |
23 | class TrainTest(absltest.TestCase):
24 |
25 | def test_train(self):
26 | with tempfile.TemporaryDirectory() as directory:
27 | train.train(directory, 100)
28 |
29 |
30 | if __name__ == '__main__':
31 | absltest.main()
32 |
--------------------------------------------------------------------------------
/learned_optimization/population/examples/synthetic/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """learned_optimization.population.examples.synthetic."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/population/examples/synthetic/train.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A synthetic example of population based training in action."""
17 | from absl import app
18 | import jax
19 | from learned_optimization.population import population as population_mod
20 | from learned_optimization.population.mutators import single_worker_explore
21 | import tqdm
22 |
23 |
24 | @jax.jit
25 | def gd_loss(parameters, meta_parameters):
26 | a, = parameters
27 | b, = meta_parameters
28 | return (b + a - 2.0)**2 + (a - 1.0)**2 + (b - 1.)**2
29 |
30 |
31 | @jax.jit
32 | def loss(parameters):
33 | return gd_loss(parameters, (2.0,))
34 |
35 |
36 | @jax.jit
37 | def update(parameters, meta_parameters):
38 | dp = jax.grad(gd_loss)(parameters, meta_parameters)
39 | new_parameters = jax.tree_util.tree_map(lambda a, b: a - 0.01 * b, parameters,
40 | dp)
41 | return new_parameters
42 |
43 |
44 | def mutate_fn(meta_params, direction, _):
45 | if direction == "pos":
46 | return (meta_params[0] + 0.1,)
47 | else:
48 | return (meta_params[0] - 0.1,)
49 |
50 |
51 | def train(steps=10000):
52 | """Train for some number of steps."""
53 | num_workers = 1
54 | mutator = single_worker_explore.BranchingSingleMachine(mutate_fn, 50, 30)
55 | population = population_mod.PopulationController(
56 | [(1.,) for _ in range(num_workers)], mutator)
57 |
58 | params = None
59 | meta_params = None
60 | gen_id = None
61 | step = 0
62 |
63 | for _ in tqdm.trange(steps):
64 | new_data = population.maybe_get_worker_data(0, gen_id, step, params,
65 | meta_params)
66 |
67 | if new_data:
68 | params = new_data.params
69 | meta_params = new_data.meta_params
70 | gen_id = new_data.generation_id
71 | step = new_data.step
72 |
73 |     # If params is None, this simulates a fresh initialization.
74 | if params is None:
75 | params = (0.,)
76 |
77 |     # Only report an eval back to the population every 5 steps.
78 | for i in range(5):
79 | if i % 5 == 0:
80 | l = loss(params)
81 | population.set_eval(0, gen_id, step, params, l)
82 | print(f"\t {l}, params: {params}, meta_params:{meta_params}")
83 |
84 | params = update(params, meta_params)
85 | step += 1
86 |
87 |
88 | def main(_):
89 | train(steps=10000)
90 |
91 |
92 | if __name__ == "__main__":
93 | app.run(main)
94 |
--------------------------------------------------------------------------------
/learned_optimization/population/examples/synthetic/train_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for train."""
17 |
18 | from absl.testing import absltest
19 | from learned_optimization.population.examples.synthetic import train
20 |
21 |
22 | class TrainTest(absltest.TestCase):
23 |
24 | def test_train(self):
25 | train.train(100)
26 |
27 |
28 | if __name__ == '__main__':
29 | absltest.main()
30 |
--------------------------------------------------------------------------------
/learned_optimization/population/mutators/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Mutators module."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/population/mutators/fixed_schedule_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimization.population.mutators.fixed_schedule."""
17 |
18 | from absl.testing import absltest
19 | from learned_optimization.population import population as pop_mod
20 | from learned_optimization.population.mutators import fixed_schedule
21 |
22 |
23 | class FixedScheduleTest(absltest.TestCase):
24 |
25 | def test_follows_schedule(self):
26 | schedule = {0: 1., 10: 2., 30: 3.0}
27 | mutate = fixed_schedule.FixedSchedule(schedule)
28 | population = pop_mod.PopulationController([1.], mutate)
29 |
30 | new_data = population.maybe_get_worker_data(0, None, 0, None, None)
31 | hparams = new_data.meta_params
32 | gen_id = new_data.generation_id
33 |
34 | for step in range(100):
35 | new_data = population.maybe_get_worker_data(0, gen_id, step, None,
36 | hparams)
37 | if new_data:
38 | hparams = new_data.meta_params
39 | gen_id = new_data.generation_id
40 |
41 | population.set_eval(0, gen_id, step, None, 1.0)
42 |
43 | if step < 11:
44 | self.assertEqual(hparams, 1.0)
45 | elif step < 31:
46 | self.assertEqual(hparams, 2.0)
47 | else:
48 | self.assertEqual(hparams, 3.0)
49 |
50 |
51 | if __name__ == '__main__':
52 | absltest.main()
53 |
--------------------------------------------------------------------------------
/learned_optimization/population/mutators/single_worker_explore_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimization.population.mutators.single_worker_explore."""
17 |
18 | from absl.testing import absltest
19 | from learned_optimization.population import population as pop_mod
20 | from learned_optimization.population.mutators import single_worker_explore
21 | import numpy as onp
22 |
23 |
24 | class SingleWorkerExploreTest(absltest.TestCase):
25 |
26 | def test_single_worker_explore(self):
27 |
28 | def mutate_fn(p, loc, _):
29 | if loc == "pos":
30 | return p + 0.05
31 | elif loc == "neg":
32 | return p - 0.05
33 | else:
34 | raise ValueError("Bad loc")
35 |
36 | mutate = single_worker_explore.BranchingSingleMachine(mutate_fn, 2, 2)
37 | population = pop_mod.PopulationController([1.], mutate)
38 |
39 | step = 0
40 | for _ in range(200):
41 | new_data = population.maybe_get_worker_data(0, None, step, None, None)
42 |
43 | if new_data:
44 | hparams = new_data.meta_params
45 | gen_id = new_data.generation_id
46 | step = new_data.step
47 |
48 | population.set_eval(0, gen_id, step + 1, None, hparams**2)
49 |
50 | self.assertLess(onp.abs(hparams), 0.1)
51 |
52 |
53 | if __name__ == "__main__":
54 | absltest.main()
55 |
--------------------------------------------------------------------------------
/learned_optimization/population/mutators/winner_take_all_genetic_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimization.population.mutators.winner_take_all_genetic."""
17 |
18 | from absl.testing import absltest
19 | from learned_optimization.population import population as pop_mod
20 | from learned_optimization.population.mutators import winner_take_all_genetic
21 | import numpy as onp
22 |
23 |
24 | class WinnerTakeAllGeneticTest(absltest.TestCase):
25 |
26 | def test_winner_take_all(self):
27 | onp.random.seed(0)
28 |
29 | def mutate_fn(p):
30 | p += onp.random.normal() * 0.01
31 | return p
32 |
33 | mutate = winner_take_all_genetic.WinnerTakeAllGenetic(mutate_fn, 2)
34 | population = pop_mod.PopulationController([1.], mutate)
35 |
36 | num_worker = 3
37 | population = pop_mod.PopulationController([1. for _ in range(num_worker)],
38 | mutate)
39 |
40 | for step in range(200):
41 | for i in range(num_worker):
42 | new_data = population.maybe_get_worker_data(i, None, step, None, None)
43 |
44 | hparams = new_data.meta_params
45 | gen_id = new_data.generation_id
46 |
47 | population.set_eval(i, gen_id, step + 1, None, hparams**2)
48 |
49 | self.assertLess(onp.abs(hparams), 0.3)
50 |
51 |
52 | if __name__ == '__main__':
53 | absltest.main()
54 |
--------------------------------------------------------------------------------
/learned_optimization/profile.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A minimal profiling library to profile code in a context manager."""
17 | import functools
18 | import sys
19 | import time
20 | from typing import TypeVar
21 |
22 | from absl import flags
23 | from absl import logging
24 |
25 | flags.DEFINE_bool("profile_log_times", False, "Log out timing information.")
26 |
27 | FLAGS = flags.FLAGS
28 | FLAGS(sys.argv, known_only=True) # Ensure flags are parsed at this time.
29 |
30 | T = TypeVar("T")
31 |
32 |
33 | class Profile:
34 | """Context manager for profiling functions.
35 |
36 | ```
37 | with Profile("name"):
38 | run_somecode()
39 | ```
40 | """
41 |
42 | def __init__(self, name: str):
43 | self.name = name
44 |
45 | def __enter__(self):
46 | self._start_time = time.time()
47 |
48 | def __exit__(self, exc_type, exc_value, traceback):
49 | if FLAGS.profile_log_times:
50 | logging.info(f"{self.name} took {time.time()-self._start_time} seconds") # pylint: disable=logging-fstring-interpolation
51 |
52 |
53 | def wrap():
54 | """Wrap a function in a Profile."""
55 |
56 | def _wrapper(fn: T) -> T:
57 |
58 | @functools.wraps(fn)
59 | def _fn(*args, **kwargs):
60 | with Profile(fn.__name__):
61 | return fn(*args, **kwargs)
62 |
63 | return _fn
64 |
65 | return _wrapper
66 |
67 |
68 |
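A minimal usage sketch for this module; load_data and train_step below are hypothetical functions, and timings are only logged when the --profile_log_times flag is set:

  from learned_optimization import profile

  # Time an arbitrary block of code under a chosen name.
  with profile.Profile("data_loading"):
    load_data()  # hypothetical function

  # Or wrap a function so every call is timed under fn.__name__.
  @profile.wrap()
  def train_step(batch):  # hypothetical function
    ...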
--------------------------------------------------------------------------------
/learned_optimization/py_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Common python utilities."""
17 | from concurrent import futures
18 | from typing import Any, Callable, Sequence
19 | import tqdm
20 |
21 |
22 | def threaded_tqdm_map(threads: int, func: Callable[[Any], Any],
23 | data: Sequence[Any]) -> Sequence[Any]:
24 | future_list = []
25 | with futures.ThreadPoolExecutor(threads) as executor:
26 | for l in tqdm.tqdm(data):
27 | future_list.append(executor.submit(func, l))
28 | return [x.result() for x in tqdm.tqdm(future_list)]
29 |
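A usage sketch for threaded_tqdm_map; load_one and the paths list are hypothetical stand-ins for some I/O-bound work:

  from learned_optimization import py_utils

  paths = ["/tmp/shard_%d.npy" % i for i in range(100)]  # hypothetical inputs

  def load_one(path):  # hypothetical I/O-bound function
    ...

  # Maps load_one over paths with 16 threads, showing two tqdm bars
  # (submission and result collection); results are returned in input order.
  results = py_utils.threaded_tqdm_map(16, load_one, paths)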
--------------------------------------------------------------------------------
/learned_optimization/research/README.md:
--------------------------------------------------------------------------------
1 | This directory contains experimental code for research / demo projects.
--------------------------------------------------------------------------------
/learned_optimization/research/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Experimental and research code using learned_optimization."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/research/brax/brax_env_truncated_step_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for brax_env_truncated_step."""
17 |
18 | from absl.testing import absltest
19 | import brax.v1 as brax
20 | import jax
21 | from learned_optimization.outer_trainers import full_es
22 | from learned_optimization.outer_trainers import test_utils
23 | from learned_optimization.outer_trainers import truncated_pes
24 | from learned_optimization.outer_trainers import truncated_step
25 | from learned_optimization.outer_trainers import truncation_schedule
26 | from learned_optimization.research.brax import brax_env_truncated_step
27 |
28 |
29 | class BraxEnvTruncatedStepTest(absltest.TestCase):
30 |
31 | def _get_tstep(self):
32 | env_name = 'fast'
33 | env = brax.envs.create(env_name)
34 | asize = env.action_size
35 | osize = env.observation_size
36 | policy = brax_env_truncated_step.BraxEnvPolicy(osize, asize)
37 | key = jax.random.PRNGKey(0)
38 | params = policy.init(key)
39 | tstep = brax_env_truncated_step.BraxEnvTruncatedStep(
40 | env_name,
41 | policy,
42 | episode_length=100,
43 | random_initial_iteration_offset=100)
44 | return tstep, params
45 |
46 | def test_brax_truncated_step(self):
47 | tstep, params = self._get_tstep()
48 | key = jax.random.PRNGKey(0)
49 | state = tstep.init_step_state(params, None, key)
50 | state, unused_out = tstep.unroll_step(params, state, key, None, None)
51 | unused_state, unused_out = tstep.unroll_step(params, state, key, None, None)
52 |
53 |   def test_brax_truncated_step_with_pes(self):
54 | tstep, _ = self._get_tstep()
55 | vtstep = truncated_step.VectorizedTruncatedStep(tstep, 32)
56 | grad_est = truncated_pes.TruncatedPES(vtstep, std=0.01)
57 | test_utils.trainer_smoketest(grad_est)
58 |
59 |   def test_brax_truncated_step_with_full_es(self):
60 | tstep, _ = self._get_tstep()
61 | vtstep = truncated_step.VectorizedTruncatedStep(tstep, 32)
62 |
63 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10)
64 | grad_est = full_es.FullES(vtstep, std=0.01, truncation_schedule=trunc_sched)
65 | test_utils.trainer_smoketest(grad_est)
66 |
67 |
68 | if __name__ == '__main__':
69 | absltest.main()
70 |
--------------------------------------------------------------------------------
/learned_optimization/research/data_driven/README.md:
--------------------------------------------------------------------------------
1 | # General-Purpose In-Context Learning by Meta-Learning Transformers
2 |
3 | Research code for the paper https://arxiv.org/abs/2212.04458
--------------------------------------------------------------------------------
/learned_optimization/research/data_driven/run_mnist_projections.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Runs MNIST projection experiment with gin configuration."""
17 |
18 | from absl import app
19 | import jax
20 | from learned_optimization import filesystem
21 | from learned_optimization import setup_experiment
22 | from learned_optimization.research.data_driven import mnist_projections
23 | import numpy as np
24 | import yaml
25 |
26 |
27 | def main(_) -> None:
28 | rank = jax.process_index()
29 | train_log_dir = setup_experiment.setup_experiment(make_dir=(rank == 0))
30 | train(train_log_dir)
31 |
32 |
33 | def train(training_log_directory: str):
34 | """Runs a projection experiment.
35 |
36 | Args:
37 | training_log_directory: Directory to store log data to.
38 | """
39 |
40 | experiment = mnist_projections.ProjectionExperiment(training_log_directory)
41 | log_dict = experiment.run()
42 |
43 | if jax.process_index() == 0:
44 | yaml_file_name = f'{training_log_directory}/results.yaml'
45 | with filesystem.file_open(yaml_file_name, 'w') as f:
46 | yaml.dump(
47 | {
48 | k: np.asarray(v).item()
49 | for k, v in log_dict.items()
50 | if np.asarray(v).size == 1
51 | }, f)
52 | np_file_name = f'{training_log_directory}/results.npy'
53 | with filesystem.file_open(np_file_name, 'wb') as f:
54 | np.save(f, log_dict)
55 |
56 |
57 | if __name__ == '__main__':
58 | app.run(main)
59 |
--------------------------------------------------------------------------------
/learned_optimization/research/data_driven/summary.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utils for logging of summary statistics."""
17 |
18 | from concurrent import futures
19 | import jax
20 |
21 | from learned_optimization import summary as lo_summary
22 | from learned_optimization.baselines import utils
23 | import numpy as np
24 |
25 |
26 | def only_first_rank(func):
27 | """Only runs the function on rank 0."""
28 |
29 | def wrappee(*args, **kwargs):
30 | if jax.process_index() == 0:
31 | return func(*args, **kwargs)
32 | return None
33 |
34 | return wrappee
35 |
36 |
37 | class DictSummaryWriter(lo_summary.SummaryWriterBase):
38 |   """A summary writer that stores entire dicts as scalars and npy files."""
39 |
40 | FLUSH_TIMEOUT_SECS = 10
41 |
42 | def __init__(self, base_writer: lo_summary.SummaryWriterBase, log_dir: str):
43 | self._writer = base_writer
44 | self._log_dir = log_dir
45 | self._thread_pool = futures.ThreadPoolExecutor(max_workers=2)
46 | self._pending_dict_writes = []
47 |
48 | @only_first_rank
49 | def scalar(self, name, value, step):
50 | return self._writer.scalar(name, value, step)
51 |
52 | @only_first_rank
53 | def histogram(self, name, value, step):
54 | return self._writer.histogram(name, value, step)
55 |
56 | @only_first_rank
57 | def flush(self):
58 | futures.wait(self._pending_dict_writes, timeout=self.FLUSH_TIMEOUT_SECS)
59 | self._pending_dict_writes.clear()
60 | return self._writer.flush()
61 |
62 | @only_first_rank
63 | def dict(self, dict_value, step):
64 | # Store scalars in tf summary
65 | for k, v in dict_value.items():
66 | if v is None:
67 | continue
68 | if np.isscalar(v) or v.size == 1:
69 | self.scalar(k, v, step)
70 |
71 | # Store entire dictionary as npy file
72 | file_name = f'{self._log_dir}/summary_{step}.npy'
73 | task = self._thread_pool.submit(utils.write_npz, file_name, dict_value)
74 | self._pending_dict_writes.append(task)
75 |
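A usage sketch for DictSummaryWriter; base_writer stands in for any already-constructed lo_summary.SummaryWriterBase implementation, and the metric values are made up:

  from learned_optimization.research.data_driven import summary as dd_summary

  # base_writer: an existing lo_summary.SummaryWriterBase
  # (e.g. a TensorBoard-backed writer); assumed to exist here.
  writer = dd_summary.DictSummaryWriter(base_writer, log_dir="/tmp/run")

  # Scalar entries go to base_writer; the whole dict is also written
  # asynchronously to /tmp/run/summary_100.npy.
  writer.dict({"loss": 0.31, "accuracy": 0.90}, step=100)
  writer.flush()  # waits up to FLUSH_TIMEOUT_SECS for pending npy writes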
--------------------------------------------------------------------------------
/learned_optimization/research/es_scaling/es_scaling_tasks.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """ES tasks to probe scaling (e.g., gradient accumulation)."""
17 |
18 |
--------------------------------------------------------------------------------
/learned_optimization/research/general_lopt/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """learned_optimization.research.general_lopt."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/research/general_lopt/configs/README.md:
--------------------------------------------------------------------------------
1 | Configs used in meta-training.
2 |
3 | At the moment, these are not wired up to run outside of Google's infra.
--------------------------------------------------------------------------------
/learned_optimization/research/general_lopt/hyper_v2_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimization.research.general_lopt.hyper_v2."""
17 | from absl.testing import absltest
18 | from learned_optimization.learned_optimizers import test_utils
19 | from learned_optimization.research.general_lopt import hyper_v2
20 |
21 |
22 | class HyperV2LOptTest(absltest.TestCase):
23 |
24 | def test_hyperv2_lopt_lopt(self):
25 | test_utils.smoketest_learned_optimizer(hyper_v2.HyperV2())
26 |
27 |
28 | if __name__ == '__main__':
29 | absltest.main()
30 |
--------------------------------------------------------------------------------
/learned_optimization/research/general_lopt/prefab_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for pretrained lopt optimizers."""
17 | from absl.testing import absltest
18 | from learned_optimization.optimizers import test_utils as opt_test_utils
19 | from learned_optimization.research.general_lopt import prefab
20 |
21 |
22 | class PrefabLOptTest(absltest.TestCase):
23 |
24 | def test_lopt(self):
25 | opt_test_utils.smoketest_optimizer(prefab.LearnedOptimizer(100))
26 |
27 | def test_lopt_accum(self):
28 | opt = prefab.LearnedOptimizer(100, max_training_steps=10)
29 | self.assertEqual(opt.opt.num_average, 10)
30 | opt_test_utils.smoketest_optimizer(opt)
31 |
32 | def test_lopt_wd(self):
33 | opt = prefab.LearnedOptimizer(100, weight_decay=1e-6)
34 | opt_test_utils.smoketest_optimizer(opt)
35 |
36 |
37 | if __name__ == '__main__':
38 | absltest.main()
39 |
--------------------------------------------------------------------------------
/learned_optimization/research/general_lopt/pretrained_optimizers_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for pretrained lopt optimizers."""
17 | from absl.testing import absltest
18 | from learned_optimization.optimizers import test_utils as opt_test_utils
19 | from learned_optimization.research.general_lopt import pretrained_optimizers
20 |
21 |
22 | class PretrainedLOptTest(absltest.TestCase):
23 |
24 | def test_hyperv2_pretrain(self):
25 | opt_test_utils.smoketest_optimizer(
26 | pretrained_optimizers
27 | .aug12_continue_on_bigger_2xbs_200kstep_bigproblem_v2_5620())
28 |
29 |
30 | def test_hyperv2_pretrain2(self):
31 | opt_test_utils.smoketest_optimizer(
32 | pretrained_optimizers.aug11_aug4_trunc10per_last())
33 |
34 | def test_hyperv2_pretrain3(self):
35 | opt_test_utils.smoketest_optimizer(
36 | pretrained_optimizers
37 | .aug26_aug12_continue_on_bigger_2xbs_200kstep_bigproblem_v2_1397())
38 |
39 |
40 | if __name__ == '__main__':
41 | absltest.main()
42 |
--------------------------------------------------------------------------------
/learned_optimization/research/general_lopt/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """learned_optimization.research.general_lopt.tasks."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/research/general_lopt/tasks/fast_mlp_diff_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A set of tasks for quick iteration on meta-training."""
17 | import functools
18 |
19 | import gin
20 | from learned_optimization.tasks import base
21 | from learned_optimization.tasks import task_augmentation
22 | from learned_optimization.tasks.datasets import image
23 | from learned_optimization.tasks.fixed import image_mlp
24 |
25 | inner_bs = 64
26 |
27 |
28 | def cifar_task():
29 | datasets = image.cifar10_datasets(
30 | batch_size=inner_bs, image_size=(8, 8), convert_to_black_and_white=True)
31 | return image_mlp._MLPImageTask(datasets, [32]) # pylint: disable=protected-access
32 |
33 |
34 | def fashion_mnist_task():
35 | datasets = image.fashion_mnist_datasets(
36 | batch_size=inner_bs, image_size=(8, 8))
37 | return image_mlp._MLPImageTask(datasets, [32]) # pylint: disable=protected-access
38 |
39 |
40 | def mnist_task():
41 | datasets = image.mnist_datasets(batch_size=inner_bs, image_size=(8, 8))
42 | return image_mlp._MLPImageTask(datasets, [32]) # pylint: disable=protected-access
43 |
44 |
45 | def svhn_task():
46 | datasets = image.svhn_cropped_datasets(
47 | batch_size=inner_bs, image_size=(8, 8), convert_to_black_and_white=True)
48 | return image_mlp._MLPImageTask(datasets, [32]) # pylint: disable=protected-access
49 |
50 |
51 | def task_to_augmented_task_family(task_fn):
52 | task_family = base.single_task_to_family(task=task_fn())
53 | return task_augmentation.ReparamWeightsFamily(task_family, "tensor",
54 | (0.01, 100))
55 |
56 |
57 | @gin.configurable
58 | def jun24_fast_mlp_diff_data_task_list():
59 | return [[functools.partial(task_to_augmented_task_family, fn)]
60 | for fn in [cifar_task, fashion_mnist_task, mnist_task, svhn_task]]
61 |
62 |
63 | @functools.lru_cache(None)
64 | def _get_task(task_family_idx, task_family_seed):
65 |   # Cache here so that the task identities are the same and thus the jit
66 |   # cache is reused.
67 | task_family = jun24_fast_mlp_diff_data_task_list()[task_family_idx][0]()
68 | return base.get_task(task_family, task_family_seed=task_family_seed)
69 |
70 |
71 | _datasets = ("cifar", "fashion_mnist", "mnist", "svhn")
72 | for _i in range(4):
73 | n_seed = 8
74 | for _s in range(n_seed):
75 | _task_seed = _i * n_seed + _s
76 | _task_name = f"Jun24FastMLPEval_TaskFamily_{_datasets[_i]}_Seed{_s}"
77 | locals()[_task_name] = functools.partial(_get_task, _i, _task_seed)
78 | gin.external_configurable(locals()[_task_name], _task_name)
79 |
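The loop above registers 4 datasets x 8 seeds = 32 task constructors, both as module attributes and as gin configurables, under names of the form Jun24FastMLPEval_TaskFamily_<dataset>_Seed<s>. A small sketch of pulling one out directly (the choice of mnist / seed 0 is arbitrary, and the usual Task interface of init/loss/datasets is assumed):

  import jax
  from learned_optimization.research.general_lopt.tasks import fast_mlp_diff_data

  # One of the generated, seed-fixed eval tasks: an 8x8 MNIST MLP with
  # reparameterized weights.
  task = fast_mlp_diff_data.Jun24FastMLPEval_TaskFamily_mnist_Seed0()

  key = jax.random.PRNGKey(0)
  params = task.init(key)
  batch = next(task.datasets.train)
  loss = task.loss(params, key, batch)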
--------------------------------------------------------------------------------
/learned_optimization/research/hysteresis/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Experimental research code."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/research/hysteresis/truncated_pes_no_hyst.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """PES with no hysteresis by spending a lot more compute."""
17 | from typing import Any, Mapping, Optional, Tuple
18 |
19 | import gin
20 | import haiku as hk
21 | import jax.numpy as jnp
22 | from learned_optimization import profile
23 | from learned_optimization.outer_trainers import gradient_learner
24 | from learned_optimization.outer_trainers import truncated_pes
25 | from learned_optimization.outer_trainers import truncated_step as truncated_step_mod
26 |
27 | PRNGKey = jnp.ndarray
28 | MetaParams = Any
29 | TruncatedUnrollState = Any
30 |
31 |
32 | @gin.configurable
33 | class TruncatedPESNoHyst(gradient_learner.GradientEstimator):
34 | """PES with no hysteresis by spending a lot more compute."""
35 |
36 | def __init__(
37 | self,
38 | truncated_step: truncated_step_mod.VectorizedTruncatedStep,
39 | burnin_steps: int,
40 | trunc_length=10,
41 | std=0.01,
42 | steps_per_jit=10,
43 | stack_antithetic_samples: bool = False,
44 | sign_delta_loss_scalar: Optional[float] = None,
45 | ):
46 | self.burnin_steps = burnin_steps
47 | self.trunc_length = trunc_length
48 | self.grad_est = truncated_pes.TruncatedPES(truncated_step, trunc_length,
49 | std, steps_per_jit,
50 | stack_antithetic_samples,
51 | sign_delta_loss_scalar)
52 |
53 | @profile.wrap()
54 | def init_worker_state(self, worker_weights: gradient_learner.WorkerWeights,
55 | key: PRNGKey):
56 | return self.grad_est.init_worker_state(worker_weights, key)
57 |
58 | @profile.wrap()
59 | def compute_gradient_estimate( # pytype: disable=signature-mismatch
60 | self,
61 | worker_weights: gradient_learner.WorkerWeights,
62 | key: PRNGKey,
63 | state,
64 | with_summary: bool = False
65 | ) -> Tuple[gradient_learner.GradientEstimatorOut, Mapping[str, jnp.ndarray]]:
66 |
67 | rng = hk.PRNGSequence(key)
68 |
69 | with profile.Profile("new_state"):
70 | state = self.init_worker_state(worker_weights, next(rng))
71 |
72 | assert self.burnin_steps % self.trunc_length == 0
73 |
74 | with profile.Profile("burnin"):
75 | for _ in range(self.burnin_steps // self.trunc_length):
76 | output, unused_metrics = self.grad_est.compute_gradient_estimate(
77 | worker_weights, next(rng), state, with_summary=False)
78 | state = output.unroll_state
79 |
80 | return self.grad_est.compute_gradient_estimate(
81 | worker_weights, next(rng), state, with_summary=with_summary)
82 |
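A construction sketch; vec_truncated_step stands in for a VectorizedTruncatedStep built elsewhere (e.g. from a learned optimizer and a task family). Note that burnin_steps must be an integer multiple of trunc_length, or the assert in compute_gradient_estimate fires:

  from learned_optimization.research.hysteresis import truncated_pes_no_hyst

  # vec_truncated_step: an already-constructed
  # truncated_step_mod.VectorizedTruncatedStep (assumed to exist).
  grad_est = truncated_pes_no_hyst.TruncatedPESNoHyst(
      truncated_step=vec_truncated_step,
      burnin_steps=100,  # re-run 100 fresh burn-in steps for every estimate
      trunc_length=10,   # 100 % 10 == 0, satisfying the assert
      std=0.01)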
--------------------------------------------------------------------------------
/learned_optimization/research/jaxnerf/example_data/imgs/r_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/learned_optimization/3ab34d27bdd38f487af65fb025f9c86869946e81/learned_optimization/research/jaxnerf/example_data/imgs/r_0.png
--------------------------------------------------------------------------------
/learned_optimization/research/jaxnerf/example_data/transforms_test.json:
--------------------------------------------------------------------------------
1 | {"camera_angle_x": 0.6911112070083618, "frames": [{"file_path": "./imgs/r_0", "rotation": 0.012566370614359171, "transform_matrix": [[-0.9999021887779236, 0.004192245192825794, -0.013345719315111637, -0.05379832163453102], [-0.013988681137561798, -0.2996590733528137, 0.95394366979599, 3.845470428466797], [-4.656612873077393e-10, 0.9540371894836426, 0.29968830943107605, 1.2080823183059692], [0.0, 0.0, 0.0, 1.0]]}]}
--------------------------------------------------------------------------------
/learned_optimization/research/jaxnerf/example_data/transforms_train.json:
--------------------------------------------------------------------------------
1 | {"camera_angle_x": 0.6911112070083618, "frames": [{"file_path": "./imgs/r_0", "rotation": 0.012566370614359171, "transform_matrix": [[-0.9999021887779236, 0.004192245192825794, -0.013345719315111637, -0.05379832163453102], [-0.013988681137561798, -0.2996590733528137, 0.95394366979599, 3.845470428466797], [-4.656612873077393e-10, 0.9540371894836426, 0.29968830943107605, 1.2080823183059692], [0.0, 0.0, 0.0, 1.0]]}]}
--------------------------------------------------------------------------------
/learned_optimization/research/jaxnerf/jaxnerf_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for JaxNeRF tasks."""
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from learned_optimization.research.jaxnerf import jaxnerf
20 | from learned_optimization.tasks import test_utils
21 |
22 | tasks = ['JAXNeRF_TestTask']
23 |
24 |
25 | class JaxNeRFTest(parameterized.TestCase):
26 |
27 | @parameterized.parameters(tasks)
28 | def test_tasks(self, task_name):
29 | task = getattr(jaxnerf, task_name)()
30 | test_utils.smoketest_task(task, abstract_data=False)
31 |
32 |
33 | if __name__ == '__main__':
34 | absltest.main()
35 |
--------------------------------------------------------------------------------
/learned_optimization/research/univ_nfn/learned_opt/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/learned_optimization/research/univ_nfn/learned_opt/outer_opt.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Additional outer optimizer helpers."""
17 |
18 | import gin
19 | import jax
20 | from jax import flatten_util
21 | import jax.numpy as jnp
22 | from learned_optimization.optimizers import base as opt_base
23 | from learned_optimization.optimizers import optax_opts
24 | import optax
25 |
26 |
27 | @gin.configurable
28 | class GradientNormClippedOptimizer(opt_base.Optimizer):
29 | """Clip gradients by norm before passing into an optimizer."""
30 |
31 | def __init__(self, opt: opt_base.Optimizer, clip_norm: float = 1.0):
32 | if not isinstance(opt, opt_base.Optimizer):
33 | raise ValueError(
34 |           "Must be an instance of Optimizer. Maybe you are passing the"
35 | f" class and not an instance? Received {opt}."
36 | )
37 | self.opt = opt
38 | self.clip_norm = clip_norm
39 |
40 | def get_params(self, state):
41 | return self.opt.get_params(state)
42 |
43 | def get_state(self, state):
44 | return self.opt.get_state(state)
45 |
46 | def init(self, *args, **kwargs):
47 | return self.opt.init(*args, **kwargs)
48 |
49 | def update(self, opt_state, grad, *args, **kwargs):
50 | g_norm = jnp.sqrt(jnp.sum(flatten_util.ravel_pytree(grad)[0] ** 2))
51 | trigger = jnp.squeeze(g_norm < self.clip_norm)
52 |
53 | def clip_fn(t):
54 | return jax.lax.select(
55 | trigger, t, (t / g_norm.astype(t.dtype)) * self.clip_norm
56 | )
57 |
58 | grad = jax.tree_util.tree_map(clip_fn, grad)
59 | return self.opt.update(opt_state, grad, *args, **kwargs)
60 |
61 |
62 | @gin.configurable
63 | class WarmupCosineAdam(optax_opts.OptaxOptimizer):
64 | """Adam with linear warmup + cosine LR decay."""
65 |
66 | def __init__(
67 | self,
68 | warmup_steps=gin.REQUIRED,
69 | decay_steps=gin.REQUIRED,
70 | learning_rate=1e-3,
71 | beta1=0.9,
72 | beta2=0.999,
73 | epsilon=1e-8,
74 | epsilon_root=1e-8,
75 | ):
76 | final_lr = learning_rate / 10
77 | opt = optax.chain(
78 | optax.scale_by_adam(
79 | b1=beta1, b2=beta2, eps=epsilon, eps_root=epsilon_root
80 | ),
81 | optax.scale_by_schedule(
82 | optax.warmup_cosine_decay_schedule(
83 | 0.0, learning_rate, warmup_steps, decay_steps, final_lr
84 | )
85 | ),
86 | optax.scale(-1),
87 | )
88 | super().__init__(opt)
89 |
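A usage sketch combining the two helpers above; the warmup/decay step counts, learning rate, and the params/grads pytrees are illustrative assumptions:

  from learned_optimization.research.univ_nfn.learned_opt import outer_opt

  # Adam with linear warmup and cosine decay as the base optimizer...
  base_opt = outer_opt.WarmupCosineAdam(
      warmup_steps=1_000, decay_steps=100_000, learning_rate=3e-4)

  # ...with gradients clipped to norm 1.0 before each update.
  opt = outer_opt.GradientNormClippedOptimizer(base_opt, clip_norm=1.0)

  opt_state = opt.init(params)              # params: a pytree of weights
  opt_state = opt.update(opt_state, grads)  # grads: a matching pytree of gradients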
--------------------------------------------------------------------------------
/learned_optimization/research/univ_nfn/nfn/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/learned_optimization/run_outer_train_single.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """This script contains both outer workers and outer learners."""
17 | from typing import Sequence
18 |
19 | from absl import app
20 | from learned_optimization import outer_train
21 | from learned_optimization import setup_experiment
22 |
23 |
24 | def main(argv: Sequence[str]) -> None:
25 | if len(argv) > 1:
26 | raise app.UsageError('Too many command-line arguments.')
27 |
28 | train_log_dir = setup_experiment.setup_experiment(make_dir=True)
29 | outer_train.run_train(
30 | train_log_dir,
31 | is_trainer=True,
32 | is_worker=True,
33 | )
34 |
35 |
36 | if __name__ == '__main__':
37 | app.run(main)
38 |
--------------------------------------------------------------------------------
/learned_optimization/summary_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimization.summary."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | from jax import lax
21 | import jax.numpy as jnp
22 | from learned_optimization import summary
23 |
24 |
25 | @jax.jit
26 | def fn(a):
27 | summary.summary("a", a[0])
28 |
29 | def update(state, _):
30 | s = state + 1
31 | summary.summary("loop", s[0])
32 | return s, s
33 |
34 | a, _ = lax.scan(update, a, jnp.arange(20))
35 |
36 | return a * 2
37 |
38 |
39 | def fn2(static, a):
40 | summary.summary(static, a)
41 | return a * 2
42 |
43 |
44 | if summary.ORYX_LOGGING:
45 |
46 | class SummaryTest(absltest.TestCase):
47 |
48 | def setUp(self):
49 | super().setUp()
50 | summary.reset_summary_counter()
51 |
52 | def test_summary(self):
53 | result, metrics = summary.with_summary_output_reduced(fn)(
54 | jnp.zeros(1) + 123)
55 | del result
56 |
57 | metrics = summary.aggregate_metric_list([metrics])
58 | self.assertIn("mean||a", metrics)
59 | self.assertIn("mean||loop", metrics)
60 | self.assertEqual(metrics["mean||loop"], 133.5)
61 | self.assertEqual(metrics["mean||a"], 123)
62 |
63 | def test_summary_output_reduced_static_args(self):
64 | result, metrics = summary.with_summary_output_reduced(
65 | fn2, static_argnums=(0,))("name", jnp.zeros(1) + 123)
66 | del result
67 |
68 | metrics = summary.aggregate_metric_list([metrics])
69 | self.assertIn("mean||name", metrics)
70 |
71 | def test_add_with_summary(self):
72 | new_fn = summary.add_with_summary(fn2, static_argnums=(0,))
73 | o, m = new_fn("test", 1, with_summary=False)
74 | self.assertEmpty(m)
75 | self.assertEqual(o, 2)
76 |
77 | summary.reset_summary_counter()
78 |
79 | o, m = new_fn("test", 1, with_summary=True)
80 | self.assertEqual(o, 2)
81 | self.assertEqual(m["mean||test"], 1)
82 |
83 |
84 | if __name__ == "__main__":
85 | absltest.main()
86 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """learned_optimization.tasks module."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """learned_optimization.tasks.datasets module."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/datasets/language_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from absl.testing import absltest
17 | from learned_optimization.tasks.datasets import language
18 | import numpy as np
19 |
20 |
21 | class LanguageTest(absltest.TestCase):
22 |
23 | def test_lm1b_32k_datasets(self):
24 | datasets = language.lm1b_32k_datasets(32, 8)
25 | self.assertEqual(datasets.abstract_batch["obs"].shape, (32, 8))
26 | self.assertEqual(datasets.abstract_batch["target"].shape, (32, 8))
27 |
28 | data = next(datasets.train)
29 | self.assertEqual(data["obs"].shape, (32, 8))
30 | self.assertEqual(data["target"].shape, (32, 8))
31 | self.assertTrue(np.all(data["obs"][:, 1:] == data["target"][:, 0:-1]))
32 |
33 | def test_lm1b_bytes_datasets(self):
34 | datasets = language.lm1b_bytes_datasets(32, 10)
35 | self.assertEqual(datasets.abstract_batch["obs"].shape, (32, 10))
36 | data = next(datasets.train)
37 | self.assertEqual(data["obs"].shape, (32, 10))
38 |
39 | def test_wikipedia_en_32k_datasets(self):
40 | datasets = language.wikipedia_en_32k_datasets(32, 8)
41 | self.assertEqual(datasets.abstract_batch["obs"].shape, (32, 8))
42 | self.assertEqual(datasets.abstract_batch["target"].shape, (32, 8))
43 | data = next(datasets.train)
44 | self.assertEqual(data["obs"].shape, (32, 8))
45 | self.assertEqual(data["target"].shape, (32, 8))
46 | self.assertTrue(np.all(data["obs"][:, 1:] == data["target"][:, 0:-1]))
47 |
48 | data = next(datasets.test)
49 | self.assertEqual(data["obs"].shape, (32, 8))
50 |
51 |
52 | if __name__ == "__main__":
53 | absltest.main()
54 |
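Note on the shift assertions above: they check standard next-token alignment, i.e. the target at position t is the observation token at position t+1. A tiny self-contained illustration (not part of the library; the token ids are made up) of the same relation:

import numpy as np

# Hypothetical token ids; "target" is "obs" shifted left by one position.
obs = np.array([[10, 11, 12, 13]])
target = np.array([[11, 12, 13, 14]])
assert np.all(obs[:, 1:] == target[:, :-1])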
--------------------------------------------------------------------------------
/learned_optimization/tasks/es_wrapper_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for es_wrapper."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import jax
21 | import jax.numpy as jnp
22 | from learned_optimization.tasks import es_wrapper
23 | from learned_optimization.tasks import quadratics
24 | from learned_optimization.tasks import test_utils
25 | import numpy as onp
26 |
27 |
28 | class EsWrapperTest(parameterized.TestCase):
29 |
30 | @parameterized.named_parameters([
31 | ("with_vmap", True),
32 | ("without_vmap", False),
33 | ])
34 | def test_antithetic_es_value_and_grad(self, vmap):
35 | key = jax.random.PRNGKey(0)
36 |
37 | def loss_fn(aa, unused_argb):
38 | loss = sum([jnp.sum(a**2) for a in aa.values()])
39 | return loss, jnp.asarray(4.0)
40 |
41 | params = {
42 | "a": jnp.ones(3, dtype=jnp.float32),
43 | "b": jnp.zeros(3, dtype=jnp.float32)
44 | }
45 |
46 | (loss, aux), grads = es_wrapper.antithetic_es_value_and_grad(
47 | loss_fn, has_aux=True, std=0.01, vmap=vmap)(
48 | params, None, es_key=key)
49 |
50 | self.assertAlmostEqual(float(aux), 4)
51 | self.assertLess(onp.abs(loss - 3), 1)
52 | self.assertEqual(
53 | jax.tree_util.tree_structure(grads),
54 | jax.tree_util.tree_structure(params))
55 |
56 | def test_multi_antithetic_es_value_and_grad(self):
57 | key = jax.random.PRNGKey(0)
58 |
59 | def loss_fn(aa, unused_argb):
60 | loss = sum([jnp.sum(a**2) for a in aa.values()])
61 | return loss, jnp.asarray(4.0)
62 |
63 | params = {
64 | "a": jnp.ones(3, dtype=jnp.float32),
65 | "b": jnp.zeros(3, dtype=jnp.float32)
66 | }
67 |
68 | (loss, aux), grads = es_wrapper.multi_antithetic_es_value_and_grad(
69 | loss_fn, n_pairs=256, has_aux=True, std=0.01)(
70 | params, None, es_key=key)
71 |
72 | self.assertAlmostEqual(float(aux), 4)
73 | # These are stochastic estimates, so some error is expected, but with 256
74 | # antithetic pairs they should be close (true grads are 2 for "a", 0 for "b").
75 | self.assertGreater(onp.mean(grads["a"]), 1)
76 | self.assertLess(onp.mean(grads["b"]**2), 1)
77 | self.assertLess(onp.abs(loss - 3), 1)
78 |
79 | def test_es_task_wrapper(self):
80 | task = es_wrapper.ESTask(quadratics.QuadraticTask())
81 | test_utils.smoketest_task(task)
82 |
83 | def test_es_task_family_wrapper(self):
84 | task_family = es_wrapper.ESTaskFamily(
85 | quadratics.FixedDimQuadraticFamilyData(10), std=0.01, n_pairs=3)
86 | test_utils.smoketest_task_family(task_family)
87 |
88 |
89 | if __name__ == "__main__":
90 | absltest.main()
91 |
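For reference, the quadratic used above has true gradient 2 for every element of "a" (since a = 1), 0 for "b" (since b = 0), and loss 3. The estimator being exercised is an antithetic evolution-strategies gradient of the standard form (the exact scaling inside es_wrapper may differ slightly):

\hat{g} = \frac{1}{n} \sum_{i=1}^{n} \frac{f(\theta + \sigma \epsilon_i) - f(\theta - \sigma \epsilon_i)}{2\sigma}\, \epsilon_i, \qquad \epsilon_i \sim \mathcal{N}(0, I)

With n = 256 pairs and sigma = 0.01 this is close enough to the true gradient that the loose bounds asserted above hold with high probability.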
--------------------------------------------------------------------------------
/learned_optimization/tasks/fixed/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Fixed tasks module."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/fixed/all.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Import all fixed tasks into one module."""
17 | # pylint: disable=wildcard-import
18 |
19 | from learned_optimization.tasks.fixed.conv import *
20 | from learned_optimization.tasks.fixed.image_mlp import *
21 | from learned_optimization.tasks.fixed.image_mlp_ae import *
22 | from learned_optimization.tasks.fixed.resnet import *
23 | from learned_optimization.tasks.fixed.rnn_lm import *
24 | from learned_optimization.tasks.fixed.es_wrapped import *
25 | from learned_optimization.tasks.fixed.transformer_lm import *
26 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/fixed/conv_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.tasks.fixed.conv."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from learned_optimization.tasks import test_utils
21 | from learned_optimization.tasks.fixed import conv
22 |
23 |
24 | tasks = [
25 | "Conv_Cifar10_32x64x64",
26 | "Conv_Cifar10_16_32x64x64",
27 | "Conv_Cifar10_32x64x64_Tanh",
28 | "Conv_imagenet32_16_32x64x64",
29 | "Conv_imagenet16_16_64x128x128x128",
30 | "Conv_Cifar10_32x64x64_batchnorm",
31 | "Conv_Cifar10_32x64x64_layernorm",
32 | ]
33 |
34 |
35 | class ConvTest(parameterized.TestCase):
36 |
37 | @parameterized.parameters(tasks)
38 | def test_tasks(self, task_name):
39 | task = getattr(conv, task_name)()
40 | test_utils.smoketest_task(task)
41 |
42 |
43 | if __name__ == "__main__":
44 | absltest.main()
45 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/fixed/es_wrapped.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tasks which estimate gradients via ES."""
17 | # pylint: disable=invalid-name
18 |
19 | import gin
20 | from learned_optimization.tasks import es_wrapper
21 | from learned_optimization.tasks.fixed import image_mlp
22 |
23 |
24 | @gin.configurable
25 | def ESWrapped_1pair_ImageMLP_Cifar10_128x128x128_Relu():
26 | return es_wrapper.ESTask(
27 | image_mlp.ImageMLP_Cifar10_128x128x128_Relu(), 0.01, n_pairs=1)
28 |
29 |
30 | @gin.configurable
31 | def ESWrapped_8pair_ImageMLP_Cifar10_128x128x128_Relu():
32 | return es_wrapper.ESTask(
33 | image_mlp.ImageMLP_Cifar10_128x128x128_Relu(), 0.01, n_pairs=8)
34 |
35 |
36 | @gin.configurable
37 | def ESWrapped_1pair_ImageMLP_Mnist_128x128x128_Relu():
38 | return es_wrapper.ESTask(
39 | image_mlp.ImageMLP_Mnist_128x128x128_Relu(), 0.01, n_pairs=1)
40 |
41 |
42 | @gin.configurable
43 | def ESWrapped_8pair_ImageMLP_Mnist_128x128x128_Relu():
44 | return es_wrapper.ESTask(
45 | image_mlp.ImageMLP_Mnist_128x128x128_Relu(), 0.01, n_pairs=8)
46 |
47 |
48 | @gin.configurable
49 | def ESWrapped_1pair_ImageMLP_FashionMnist8_Relu32():
50 | return es_wrapper.ESTask(
51 | image_mlp.ImageMLP_FashionMnist8_Relu32(), 0.01, n_pairs=1)
52 |
53 |
54 | @gin.configurable
55 | def ESWrapped_1pair_ImageMLP_Cifar10BW8_Relu32():
56 | return es_wrapper.ESTask(
57 | image_mlp.ImageMLP_Cifar10BW8_Relu32(), 0.01, n_pairs=1)
58 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/fixed/es_wrapped_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.tasks.fixed.es_wrapped."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from learned_optimization.tasks import test_utils
21 | from learned_optimization.tasks.fixed import es_wrapped
22 |
23 | tasks = [
24 | "ESWrapped_1pair_ImageMLP_Cifar10_128x128x128_Relu",
25 | "ESWrapped_8pair_ImageMLP_Cifar10_128x128x128_Relu",
26 | "ESWrapped_1pair_ImageMLP_Mnist_128x128x128_Relu",
27 | "ESWrapped_8pair_ImageMLP_Mnist_128x128x128_Relu",
28 | ]
29 |
30 |
31 | class ESWrappedTest(parameterized.TestCase):
32 |
33 | @parameterized.parameters(tasks)
34 | def test_tasks(self, task_name):
35 | task = getattr(es_wrapped, task_name)()
36 | test_utils.smoketest_task(task)
37 |
38 |
39 | if __name__ == "__main__":
40 | absltest.main()
41 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/fixed/image_mlp_ae_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.tasks.fixed_mlp."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from learned_optimization.tasks import test_utils
21 | from learned_optimization.tasks.fixed import image_mlp_ae
22 |
23 | tasks = [
24 | "ImageMLPAE_Cifar10_32x32x32_bs128",
25 | "ImageMLPAE_Cifar10_256x256x256_bs128",
26 | "ImageMLPAE_Cifar10_256x256x256_bs1024",
27 | "ImageMLPAE_Cifar10_128x32x128_bs256",
28 | "ImageMLPAE_Mnist_128x32x128_bs128",
29 | "ImageMLPAE_FashionMnist_128x32x128_bs128",
30 | ]
31 |
32 |
33 | class ImageMLPAETest(parameterized.TestCase):
34 |
35 | @parameterized.parameters(tasks)
36 | def test_tasks(self, task_name):
37 | task = getattr(image_mlp_ae, task_name)()
38 | test_utils.smoketest_task(task)
39 |
40 |
41 | if __name__ == "__main__":
42 | absltest.main()
43 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/fixed/image_mlp_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.tasks.fixed.image_mlp."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from learned_optimization.tasks import test_utils
21 | from learned_optimization.tasks.fixed import image_mlp
22 |
23 |
24 | tasks = [
25 | "ImageMLP_FashionMnist_Relu128x128",
26 | "ImageMLP_Imagenet16_Relu256x256x256",
27 | "ImageMLP_Cifar10_128x128x128_Relu",
28 | "ImageMLP_Cifar10_128x128_Dropout05_Relu_MSE",
29 | "ImageMLP_Cifar10_128x128x128_BatchNorm_Relu",
30 | "ImageMLP_Cifar10_128x128x128_LayerNorm_Relu",
31 | "ImageMLP_Cifar10BW8_Relu32",
32 | ]
33 |
34 |
35 | class ImageMLPTest(parameterized.TestCase):
36 |
37 | @parameterized.parameters(tasks)
38 | def test_tasks(self, task_name):
39 | task = getattr(image_mlp, task_name)()
40 | test_utils.smoketest_task(task)
41 |
42 |
43 | if __name__ == "__main__":
44 | absltest.main()
45 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/fixed/lopt_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.tasks.fixed.lopt."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from learned_optimization.tasks import test_utils
21 | from learned_optimization.tasks.fixed import lopt
22 |
23 | tasks = [
24 | "LOpt_MLPLOpt_FahionMnist_10",
25 | "LOpt_MLPLOpt_Cifar10_8_50",
26 | "LOpt_AdafacMLPLOpt_FashionMnist_50",
27 | "LOpt_AdafacMLPLOpt_FashionMnist_20",
28 | "LOpt_LearnableAdam_Cifar10_8_50",
29 | "LOpt_ES4_AdafacMLPLOpt_FashionMnist_20",
30 | "LOpt_ES4_LOpt_MLPLOpt_FahionMnist_50",
31 | "LOpt_ES4_LOpt_MLPLOpt_Cifar10_16_10",
32 | "LOptPES_Adafac_Fashion_OuterBS8_Length50_Trunc5",
33 | "LOptPES_Adafac_Cifar_OuterBS8_Length50_Trunc5",
34 | ]
35 |
36 |
37 | class LOptTest(parameterized.TestCase):
38 |
39 | @parameterized.parameters(tasks)
40 | def test_tasks(self, task_name):
41 | task = getattr(lopt, task_name)()
42 | test_utils.smoketest_task(task)
43 |
44 |
45 | if __name__ == "__main__":
46 | absltest.main()
47 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/fixed/mlp_mixer_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.tasks.fixed.vit."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from learned_optimization.tasks import test_utils
21 | from learned_optimization.tasks.fixed import mlp_mixer
22 |
23 | tasks = [
24 | 'MLPMixer_Cifar100_bs256_tiny16',
25 | 'MLPMixer_Cifar100_small16',
26 | 'MLPMixer_Cifar100_tiny16',
27 | 'MLPMixer_Food101_64_bs256_tiny16',
28 | 'MLPMixer_Food101_64_small16',
29 | 'MLPMixer_Food101_64_tiny16',
30 | 'MLPMixer_ImageNet64_bs256_tiny16',
31 | 'MLPMixer_ImageNet64_small16',
32 | 'MLPMixer_ImageNet64_tiny16',
33 | ]
34 |
35 |
36 | class MLPMixerTest(parameterized.TestCase):
37 |
38 | @parameterized.parameters(tasks)
39 | def test_tasks(self, task_name):
40 | task = getattr(mlp_mixer, task_name)()
41 | test_utils.smoketest_task(task)
42 |
43 |
44 | if __name__ == '__main__':
45 | absltest.main()
46 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/fixed/resnet_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.tasks.fixed.transformer_lm."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from learned_optimization.tasks import test_utils
21 | from learned_optimization.tasks.fixed import resnet
22 |
23 | tasks = [
24 | "Resnet_MultiRuntime_0",
25 | "Resnet_MultiRuntime_1",
26 | "Resnet_MultiRuntime_2",
27 | "Resnet_MultiRuntime_3",
28 | ]
29 |
30 |
31 | class ResnetTest(parameterized.TestCase):
32 |
33 | @parameterized.parameters(tasks)
34 | def test_tasks(self, task_name):
35 | task = getattr(resnet, task_name)()
36 | test_utils.smoketest_task(task)
37 |
38 |
39 | if __name__ == "__main__":
40 | absltest.main()
41 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/fixed/rnn_lm_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.tasks.fixed.rnn_lm."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from learned_optimization.tasks import test_utils
21 | from learned_optimization.tasks.fixed import rnn_lm
22 |
23 | tasks = [
24 | 'RNNLM_lm1b32k_Patch32_IRNN256_Embed128',
25 | 'RNNLM_lm1b32k_Patch32_LSTM256_Embed128',
26 | 'RNNLM_lm1b32k_Patch32_VanillaRNN256_Embed128',
27 | 'RNNLM_lm1bbytes_Patch128_LSTM128_Embed64',
28 | 'RNNLM_lm1bbytes_Patch32_GRU128_Embed64',
29 | 'RNNLM_lm1bbytes_Patch32_GRU256_Embed128',
30 | 'RNNLM_lm1bbytes_Patch32_IRNN128_Embed64',
31 | 'RNNLM_lm1bbytes_Patch32_LSTM128_Embed64',
32 | 'RNNLM_lm1bbytes_Patch32_LSTM256_Embed128',
33 | 'RNNLM_lm1bbytes_Patch32_VanillaRNN128_Embed64',
34 | 'RNNLM_wikipediaen32k_Patch32_GRU256_Embed128',
35 | 'RNNLM_wikipediaen32k_Patch32_LSTM256_Embed128',
36 | 'RNNLM_wikipediaenbytes_Patch32_GRU256_Embed128',
37 | 'RNNLM_wikipediaenbytes_Patch32_LSTM256_Embed128',
38 | ]
39 |
40 |
41 | class RNNLM(parameterized.TestCase):
42 |
43 | @parameterized.parameters(tasks)
44 | def test_tasks(self, task_name):
45 | task = getattr(rnn_lm, task_name)()
46 | test_utils.smoketest_task(task)
47 |
48 |
49 | if __name__ == '__main__':
50 | absltest.main()
51 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/fixed/transformer_lm_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.tasks.fixed.transformer_lm."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from learned_optimization.tasks import test_utils
21 | from learned_optimization.tasks.fixed import transformer_lm
22 |
23 | tasks = [
24 | "TransformerLM_LM1B_MultiRuntime_0",
25 | "TransformerLM_LM1B_MultiRuntime_1",
26 | "TransformerLM_LM1B_MultiRuntime_2",
27 | "TransformerLM_LM1B_MultiRuntime_3",
28 | "TransformerLM_LM1B_MultiRuntime_4",
29 | "TransformerLM_LM1B_5layer_32width",
30 | ]
31 |
32 |
33 | class TransformerLM(parameterized.TestCase):
34 |
35 | @parameterized.parameters(tasks)
36 | def test_tasks(self, task_name):
37 | task = getattr(transformer_lm, task_name)()
38 | test_utils.smoketest_task(task)
39 |
40 |
41 | if __name__ == "__main__":
42 | absltest.main()
43 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/fixed/vit_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.tasks.fixed.vit."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from learned_optimization.tasks import test_utils
21 | from learned_optimization.tasks.fixed import vit
22 |
23 | tasks = [
24 | "VIT_ImageNet64_wideshallow",
25 | "VIT_ImageNet64_skinnydeep",
26 | "VIT_Cifar100_wideshallow",
27 | "VIT_Cifar100_skinnydeep",
28 | "VIT_Food101_64_wideshallow",
29 | "VIT_Food101_64_skinnydeep",
30 | ]
31 |
32 |
33 | class VITTest(parameterized.TestCase):
34 |
35 | @parameterized.parameters(tasks)
36 | def test_tasks(self, task_name):
37 | task = getattr(vit, task_name)()
38 | test_utils.smoketest_task(task)
39 |
40 |
41 | if __name__ == "__main__":
42 | absltest.main()
43 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/parametric/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Parametric tasks module."""
17 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/parametric/image_conv_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for image_mlp."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | from learned_optimization.tasks import test_utils
21 | from learned_optimization.tasks.datasets import image
22 | from learned_optimization.tasks.parametric import cfgobject
23 | from learned_optimization.tasks.parametric import image_conv
24 |
25 |
26 | class ImageConvTest(absltest.TestCase):
27 |
28 | def test_ParametricImageConv(self):
29 | datasets = image.mnist_datasets(8, image_size=(8, 8))
30 | task_family = image_conv.ParametricImageConv(
31 | datasets,
32 | num_classes=10,
33 | hidden_sizes=(8, 2),
34 | kernel_sizes=((3, 3), (3, 3)),
35 | strides=(1, 1))
36 | test_utils.smoketest_task_family(task_family)
37 |
38 | def test_sample_image_conv(self):
39 | key = jax.random.PRNGKey(0)
40 | cfg = image_conv.sample_image_conv(key)
41 | task_family = cfgobject.object_from_config(cfg)
42 | test_utils.smoketest_task_family(task_family)
43 |
44 | def test_timed_sample_image_conv(self):
45 | key = jax.random.PRNGKey(0)
46 | sampled_task = image_conv.timed_sample_image_conv(key)
47 | self.assertIsInstance(sampled_task, cfgobject.CFGObject)
48 |
49 |
50 | if __name__ == '__main__':
51 | absltest.main()
52 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/parametric/image_mlp_ae_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for image_mlp_ae."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | from learned_optimization.tasks import test_utils
21 | from learned_optimization.tasks.datasets import image
22 | from learned_optimization.tasks.parametric import cfgobject
23 | from learned_optimization.tasks.parametric import image_mlp_ae
24 |
25 |
26 | class ImageMlpAETest(absltest.TestCase):
27 |
28 | def test_ParametricImageMLPAE(self):
29 | datasets = image.mnist_datasets(8, image_size=(8, 8))
30 | task_family = image_mlp_ae.ParametricImageMLPAE(datasets, (16, 5))
31 | test_utils.smoketest_task_family(task_family)
32 |
33 | def test_sample_image_mlp_ae(self):
34 | key = jax.random.PRNGKey(0)
35 | cfg1 = image_mlp_ae.sample_image_mlp_ae(key)
36 | cfg2 = image_mlp_ae.sample_image_mlp_ae(key)
37 | key = jax.random.PRNGKey(1)
38 | cfg3 = image_mlp_ae.sample_image_mlp_ae(key)
39 | self.assertEqual(cfg1, cfg2)
40 | self.assertNotEqual(cfg1, cfg3)
41 |
42 | obj = cfgobject.object_from_config(cfg1)
43 | self.assertIsInstance(obj, image_mlp_ae.ParametricImageMLPAE)
44 |
45 | def test_timed_sample_image_mlp_ae(self):
46 | key = jax.random.PRNGKey(0)
47 | sampled_task = image_mlp_ae.timed_sample_image_mlp_ae(key)
48 | self.assertIsInstance(sampled_task, cfgobject.CFGObject)
49 |
50 |
51 | if __name__ == '__main__':
52 | absltest.main()
53 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/parametric/image_mlp_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for image_mlp."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | from learned_optimization.tasks import test_utils
21 | from learned_optimization.tasks.datasets import image
22 | from learned_optimization.tasks.parametric import cfgobject
23 | from learned_optimization.tasks.parametric import image_mlp
24 |
25 |
26 | class ImageMlpTest(absltest.TestCase):
27 |
28 | def test_ParametricImageMLP(self):
29 | datasets = image.mnist_datasets(8, image_size=(8, 8))
30 | task_family = image_mlp.ParametricImageMLP(datasets, 10, (16, 5))
31 | test_utils.smoketest_task_family(task_family)
32 |
33 | def test_timed_sample_image_mlp(self):
34 | key = jax.random.PRNGKey(0)
35 | sampled_task = image_mlp.timed_sample_image_mlp(key)
36 | self.assertIsInstance(sampled_task, cfgobject.CFGObject)
37 |
38 | def test_timed_sample_image_mlp_subsample_data(self):
39 | key = jax.random.PRNGKey(0)
40 | sampled_task = image_mlp.timed_sample_image_mlp_subsample_data(key)
41 | self.assertIsInstance(sampled_task, cfgobject.CFGObject)
42 |
43 |
44 | if __name__ == '__main__':
45 | absltest.main()
46 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/parametric/image_mlp_vae_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for image_mlp_vae."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | from learned_optimization.tasks import test_utils
21 | from learned_optimization.tasks.datasets import image
22 | from learned_optimization.tasks.parametric import cfgobject
23 | from learned_optimization.tasks.parametric import image_mlp_vae
24 |
25 |
26 | class ImageMlpVAETest(absltest.TestCase):
27 |
28 | def test_ParametricImageMLPVAE(self):
29 | datasets = image.mnist_datasets(8, image_size=(8, 8))
30 | task_family = image_mlp_vae.ParametricImageMLPVAE(
31 | datasets, enc_hidden_sizes=(16,), dec_hidden_sizes=(16,), n_z=16)
32 | test_utils.smoketest_task_family(task_family)
33 |
34 | def test_sample_image_mlp_vae(self):
35 | key = jax.random.PRNGKey(0)
36 | cfg1 = image_mlp_vae.sample_image_mlp_vae(key)
37 | cfg2 = image_mlp_vae.sample_image_mlp_vae(key)
38 | key = jax.random.PRNGKey(1)
39 | cfg3 = image_mlp_vae.sample_image_mlp_vae(key)
40 | self.assertEqual(cfg1, cfg2)
41 | self.assertNotEqual(cfg1, cfg3)
42 |
43 | obj = cfgobject.object_from_config(cfg1)
44 | self.assertIsInstance(obj, image_mlp_vae.ParametricImageMLPVAE)
45 |
46 | def test_timed_sample_image_mlp_vae(self):
47 | key = jax.random.PRNGKey(0)
48 | sampled_task = image_mlp_vae.timed_sample_image_mlp_vae(key)
49 | self.assertIsInstance(sampled_task, cfgobject.CFGObject)
50 |
51 | if __name__ == '__main__':
52 | absltest.main()
53 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/parametric/image_resnet_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for image_mlp_resnet."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | from learned_optimization.tasks import test_utils
21 | from learned_optimization.tasks.datasets import image
22 | from learned_optimization.tasks.parametric import cfgobject
23 | from learned_optimization.tasks.parametric import image_resnet
24 |
25 |
26 | class ImageResNetTest(absltest.TestCase):
27 |
28 | def test_ParametricImageMLPResNet(self):
29 | datasets = image.mnist_datasets(8, image_size=(8, 8))
30 | task_family = image_resnet.ParametricImageResNet(
31 | datasets,
32 | initial_conv_channels=4,
33 | initial_conv_stride=1,
34 | initial_conv_kernel_size=3,
35 | blocks_per_group=(1, 1, 1, 1),
36 | channels_per_group=(4, 4, 4, 4),
37 | max_pool=True)
38 | test_utils.smoketest_task_family(task_family)
39 |
40 | def test_sample_image_resnet(self):
41 | key = jax.random.PRNGKey(0)
42 | cfg1 = image_resnet.sample_image_resnet(key)
43 | cfg2 = image_resnet.sample_image_resnet(key)
44 | key = jax.random.PRNGKey(1)
45 | cfg3 = image_resnet.sample_image_resnet(key)
46 | self.assertEqual(cfg1, cfg2)
47 | self.assertNotEqual(cfg1, cfg3)
48 |
49 | obj = cfgobject.object_from_config(cfg1)
50 | self.assertIsInstance(obj, image_resnet.ParametricImageResNet)
51 |
52 | def test_timed_sample_image_resnet(self):
53 | key = jax.random.PRNGKey(0)
54 | sampled_task = image_resnet.timed_sample_image_resnet(key)
55 | self.assertIsInstance(sampled_task, cfgobject.CFGObject)
56 |
57 | if __name__ == '__main__':
58 | absltest.main()
59 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/parametric/lm_rnn_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for lm_rnn."""
17 |
18 | from absl.testing import absltest
19 | import haiku as hk
20 | import jax
21 | from learned_optimization.tasks import test_utils
22 | from learned_optimization.tasks.datasets import language
23 | from learned_optimization.tasks.parametric import cfgobject
24 | from learned_optimization.tasks.parametric import lm_rnn
25 |
26 |
27 | class LmRnnTest(absltest.TestCase):
28 |
29 | def test_ParametricLMRNN(self):
30 | datasets = language.lm1b_bytes_datasets(4, 6)
31 | task_family = lm_rnn.ParametricLMRNN(
32 | datasets,
33 | rnn_size=7,
34 | vocab_size=16,
35 | embed_size=4,
36 | rnn_core_fn=hk.VanillaRNN)
37 | test_utils.smoketest_task_family(task_family)
38 |
39 | def test_sample_lm_rnn(self):
40 | key = jax.random.PRNGKey(0)
41 | cfg1 = lm_rnn.sample_lm_rnn(key)
42 | cfg2 = lm_rnn.sample_lm_rnn(key)
43 | key = jax.random.PRNGKey(1)
44 | cfg3 = lm_rnn.sample_lm_rnn(key)
45 | self.assertEqual(cfg1, cfg2)
46 | self.assertNotEqual(cfg1, cfg3)
47 |
48 | obj = cfgobject.object_from_config(cfg1)
49 | self.assertIsInstance(obj, lm_rnn.ParametricLMRNN)
50 |
51 | def test_timed_sample_lm_rnn(self):
52 | key = jax.random.PRNGKey(0)
53 | sampled_task = lm_rnn.timed_sample_lm_rnn(key)
54 | self.assertIsInstance(sampled_task, cfgobject.CFGObject)
55 |
56 |
57 | if __name__ == '__main__':
58 | absltest.main()
59 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/parametric/lm_transformer_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for lm_transformer."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | from learned_optimization.tasks import test_utils
21 | from learned_optimization.tasks.datasets import language
22 | from learned_optimization.tasks.parametric import cfgobject
23 | from learned_optimization.tasks.parametric import lm_transformer
24 |
25 |
26 | class LmTransformerTest(absltest.TestCase):
27 |
28 | def test_ParametricLMTransformer(self):
29 | datasets = language.lm1b_bytes_datasets(4, 6)
30 | task_family = lm_transformer.ParametricLMTransformer(
31 | datasets, vocab_size=None, num_heads=4, num_layers=3, d_model=128)
32 | test_utils.smoketest_task_family(task_family)
33 |
34 | def test_sample_lm_transformer(self):
35 | key = jax.random.PRNGKey(0)
36 | cfg1 = lm_transformer.sample_lm_transformer(key)
37 | cfg2 = lm_transformer.sample_lm_transformer(key)
38 | key = jax.random.PRNGKey(1)
39 | cfg3 = lm_transformer.sample_lm_transformer(key)
40 | self.assertEqual(cfg1, cfg2)
41 | self.assertNotEqual(cfg1, cfg3)
42 |
43 | obj = cfgobject.object_from_config(cfg1)
44 | self.assertIsInstance(obj, lm_transformer.ParametricLMTransformer)
45 |
46 | def test_timed_sample_lm_transformer(self):
47 | key = jax.random.PRNGKey(0)
48 | sampled_task = lm_transformer.timed_sample_lm_transformer(key)
49 | self.assertIsInstance(sampled_task, cfgobject.CFGObject)
50 |
51 | if __name__ == '__main__':
52 | absltest.main()
53 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/parametric/lopt_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for lopt."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | from learned_optimization.learned_optimizers import base as lopt_base
21 | from learned_optimization.tasks import quadratics
22 | from learned_optimization.tasks import test_utils
23 | from learned_optimization.tasks.parametric import cfgobject
24 | from learned_optimization.tasks.parametric import lopt
25 |
26 |
27 | class LOptTest(absltest.TestCase):
28 |
29 | def test_ParametricLOpt(self):
30 | task_family = quadratics.FixedDimQuadraticFamily()
31 | lopt_adam = lopt_base.LearnableAdam()
32 | task_family = lopt.ParametricLOpt(
33 | task_family, lopt_adam, 2, outer_batch_size=2)
34 | test_utils.smoketest_task_family(task_family)
35 |
36 | def test_sample_lopt(self):
37 | key = jax.random.PRNGKey(0)
38 | cfg1 = lopt.sample_lopt(key)
39 | cfg2 = lopt.sample_lopt(key)
40 | key = jax.random.PRNGKey(1)
41 | cfg3 = lopt.sample_lopt(key)
42 | self.assertEqual(cfg1, cfg2)
43 | self.assertNotEqual(cfg1, cfg3)
44 |
45 | obj = cfgobject.object_from_config(cfg1)
46 | self.assertIsInstance(obj, lopt.ParametricLOpt)
47 |
48 | def test_timed_sample_lopt(self):
49 | key = jax.random.PRNGKey(0)
50 | sampled_task = lopt.timed_sample_lopt(key)
51 | self.assertIsInstance(sampled_task, cfgobject.CFGObject)
52 |
53 |
54 | if __name__ == '__main__':
55 | absltest.main()
56 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/parametric/lopt_trunc_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for lopt_pes_test."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | from learned_optimization.learned_optimizers import base as lopt_base
21 | from learned_optimization.outer_trainers import truncation_schedule
22 | from learned_optimization.tasks import quadratics
23 | from learned_optimization.tasks import test_utils
24 | from learned_optimization.tasks.parametric import cfgobject
25 | from learned_optimization.tasks.parametric import lopt_trunc
26 |
27 | jax.config.update('jax_threefry_partitionable', False)
28 |
29 |
30 | class LOptTruncTest(absltest.TestCase):
31 |
32 | def test_ParametricLOptTrunc(self):
33 | task_family = quadratics.FixedDimQuadraticFamily()
34 | lopt_adam = lopt_base.LearnableAdam()
35 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10)
36 | task_family = lopt_trunc.ParametricLOptTrunc(
37 | task_family,
38 | lopt_adam,
39 | trunc_sched=trunc_sched,
40 | unroll_length=2,
41 | outer_batch_size=2,
42 | max_length=2,
43 | mode='TruncatedES')
44 | test_utils.smoketest_task_family(task_family)
45 |
46 | def test_sample_lopt(self):
47 | key = jax.random.PRNGKey(0)
48 | cfg1 = lopt_trunc.sample_lopt_trunc(key)
49 | cfg2 = lopt_trunc.sample_lopt_trunc(key)
50 | key = jax.random.PRNGKey(1)
51 | cfg3 = lopt_trunc.sample_lopt_trunc(key)
52 | self.assertEqual(cfg1, cfg2)
53 | self.assertNotEqual(cfg1, cfg3)
54 |
55 | obj = cfgobject.object_from_config(cfg1)
56 | self.assertIsInstance(obj, lopt_trunc.ParametricLOptTrunc)
57 |
58 | def test_timed_sample_lopt(self):
59 | key = jax.random.PRNGKey(0)
60 | sampled_task = lopt_trunc.timed_sample_lopt_trunc(key)
61 | self.assertIsInstance(sampled_task, cfgobject.CFGObject)
62 |
63 |
64 | if __name__ == '__main__':
65 | absltest.main()
66 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/parametric/parametric_utils_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for parametric_utils."""
17 |
18 | from absl.testing import absltest
19 | import haiku as hk
20 | import jax
21 | import jax.numpy as jnp
22 | from learned_optimization.tasks.parametric import parametric_utils
23 | import numpy as onp
24 | from numpy import testing
25 |
26 |
27 | class ParametricUtilsTest(absltest.TestCase):
28 |
29 | def test_SampleImageDataset(self):
30 | key = jax.random.PRNGKey(0)
31 | cfg = parametric_utils.SampleImageDataset.sample(key)
32 | datasets = parametric_utils.SampleImageDataset.get_dataset(cfg, 8, (8, 8))
33 | _ = datasets.train
34 |
35 | def test_SampleActivation(self):
36 | key = jax.random.PRNGKey(0)
37 | cfg = parametric_utils.SampleActivation.sample(key)
38 | act_fn = parametric_utils.SampleActivation.get_dynamic(cfg)
39 | value = jax.jit(act_fn)(12.)
40 | self.assertEqual(value.shape, ())
41 |
42 | act_fn = parametric_utils.SampleActivation.get_static(cfg)
43 | value2 = act_fn(12.)
44 | self.assertEqual(value, value2)
45 |
46 | def test_SampleInitializer(self):
47 | key = jax.random.PRNGKey(0)
48 | cfg = parametric_utils.SampleInitializer.sample(key)
49 |
50 | def forward(cfg):
51 | init = parametric_utils.SampleInitializer.get_dynamic(cfg)
52 | param = hk.get_parameter('asdf', [2, 2], dtype=jnp.float32, init=init)
53 | return param
54 |
55 | init_fn, _ = hk.transform(forward)
56 | val = jax.jit(init_fn)(key, cfg)
57 | self.assertEqual(jax.tree_util.tree_leaves(val)[0].shape, (2, 2))
58 |
59 | def test_orth_init(self):
60 | key = jax.random.PRNGKey(0)
61 | init = parametric_utils.orth_init([16, 16], jnp.float32, key)
62 | # Check that the init is orthogonal: all eigenvalues should have magnitude 1.
63 | evals, unused_evecs = onp.linalg.eig(init)
64 | testing.assert_allclose(onp.abs(evals), jnp.ones([16]), rtol=1e-6)
65 |
66 |
67 | if __name__ == '__main__':
68 | absltest.main()
69 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/parametric/vit_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for vit."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | from learned_optimization.tasks import test_utils
21 | from learned_optimization.tasks.datasets import image
22 | from learned_optimization.tasks.parametric import cfgobject
23 | from learned_optimization.tasks.parametric import vit
24 |
25 |
26 | class ParametricVITTest(absltest.TestCase):
27 |
28 | def test_ParametricVITTransformer(self):
29 | datasets = image.cifar100_datasets(8)
30 | task_family = vit.ParametricVIT(
31 | datasets,
32 | hidden_size=8,
33 | patch_size=4,
34 | mlp_dim=8,
35 | layers=1,
36 | num_heads=2,
37 | attension_dropout=0.0,
38 | dropout=0.0)
39 | test_utils.smoketest_task_family(task_family)
40 |
41 | def test_sample_vit(self):
42 | key = jax.random.PRNGKey(0)
43 | cfg1 = vit.sample_vit(key)
44 | cfg2 = vit.sample_vit(key)
45 | key = jax.random.PRNGKey(1)
46 | cfg3 = vit.sample_vit(key)
47 | self.assertEqual(cfg1, cfg2)
48 | self.assertNotEqual(cfg1, cfg3)
49 |
50 | obj = cfgobject.object_from_config(cfg1)
51 | self.assertIsInstance(obj, vit.ParametricVIT)
52 |
53 | def test_timed_sample_vit(self):
54 | key = jax.random.PRNGKey(0)
55 | sampled_task = vit.timed_sample_vit(key)
56 | self.assertIsInstance(sampled_task, cfgobject.CFGObject)
57 |
58 |
59 | if __name__ == '__main__':
60 | absltest.main()
61 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/quadratics_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for learned_optimizers.tasks.quadratics."""
17 |
18 | from absl.testing import absltest
19 | from learned_optimization.tasks import quadratics
20 | from learned_optimization.tasks import test_utils
21 |
22 |
23 | class QuadraticsTest(absltest.TestCase):
24 |
25 | def test_quadratic_task(self):
26 | test_utils.smoketest_task(quadratics.QuadraticTask())
27 |
28 | def test_batch_quadratic_task(self):
29 | test_utils.smoketest_task(quadratics.BatchQuadraticTask())
30 |
31 | def test_fixed_dim_quadratic_family(self):
32 | test_utils.smoketest_task_family(quadratics.FixedDimQuadraticFamily(10))
33 |
34 | def test_fixed_dim_quadratic_family_data(self):
35 | test_utils.smoketest_task_family(quadratics.FixedDimQuadraticFamilyData(10))
36 |
37 | def test_noisy_quadratic_family(self):
38 | test_utils.smoketest_task_family(quadratics.NoisyQuadraticFamily(10, 0.1))
39 |
40 | if __name__ == '__main__':
41 | absltest.main()
42 |
--------------------------------------------------------------------------------
/learned_optimization/tasks/rnn.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Recurrent RNN haiku modules."""
17 | from typing import Optional
18 |
19 | import haiku as hk
20 | import jax
21 |
22 |
23 | class IRNN(hk.VanillaRNN):
24 | """Identity initialized RNN.
25 |
26 | This was introduced in https://arxiv.org/abs/1504.00941.
27 | """
28 |
29 | def __init__(
30 | self,
31 | hidden_size: int,
32 | double_bias: bool = True,
33 | name: Optional[str] = None,
34 | gain: float = 1.0,
35 | ):
36 | """Constructs a Identity RNN core.
37 |
38 | Args:
39 | hidden_size: Hidden layer size.
40 | double_bias: Whether to use a bias in both linear layers. This does not
41 | change what the cell can learn; it only creates two sets of bias
42 | parameters rather than one.
43 | name: Name of the module.
44 | gain: multiplier on recurrent weight identity initialization.
45 | """
46 | super().__init__(
47 | hidden_size=hidden_size, double_bias=double_bias, name=name)
48 | self.gain = gain
49 |
50 | def __call__(self, inputs, prev_state):
51 | input_to_hidden = hk.Linear(self.hidden_size)
52 | hidden_to_hidden = hk.Linear(
53 | self.hidden_size,
54 | with_bias=self.double_bias,
55 | w_init=hk.initializers.Identity(self.gain))
56 | out = jax.nn.relu(input_to_hidden(inputs) + hidden_to_hidden(prev_state))
57 | return out, out
58 |
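
A minimal usage sketch for the IRNN core defined above; the hidden size, sequence shape, and random key are illustrative only and are not part of the library:

import haiku as hk
import jax
import jax.numpy as jnp

from learned_optimization.tasks import rnn


def forward(seq):
  # Build the identity-initialized core and unroll it over a
  # time-major [time, batch, features] sequence.
  core = rnn.IRNN(hidden_size=32, gain=1.0)
  initial_state = core.initial_state(batch_size=seq.shape[1])
  outputs, final_state = hk.dynamic_unroll(core, seq, initial_state)
  return outputs, final_state


forward_fn = hk.transform(forward)
seq = jnp.zeros([8, 4, 16])  # illustrative [time, batch, features] input
params = forward_fn.init(jax.random.PRNGKey(0), seq)
outputs, _ = forward_fn.apply(params, jax.random.PRNGKey(0), seq)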
--------------------------------------------------------------------------------
/learned_optimization/tasks/test_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utilities for testing tasks and task families."""
17 |
18 | from absl import logging
19 | import jax
20 | import jax.numpy as jnp
21 | from learned_optimization.tasks import base
22 |
23 |
24 | def smoketest_task(task: base.Task, abstract_data: bool = True):
25 | """Smoke test a Task.
26 |
27 | Args:
28 | task: Task to test.
 29 |     abstract_data: Whether to use abstract data or the real dataset for this
 30 |       test. Using abstract batches avoids loading the underlying dataset,
 31 |       which can be faster when datasets are stored on a remote machine.
32 | """
33 | key = jax.random.PRNGKey(0)
34 | param, state = task.init_with_state(key)
35 |
36 | logging.info("Getting data for %s task", str(task))
37 | if task.datasets:
38 | if abstract_data and task.datasets.abstract_batch is not None:
39 | batch = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype),
40 | task.datasets.abstract_batch)
41 | else:
42 | batch = next(task.datasets.train)
43 | else:
44 | batch = ()
45 | logging.info("Got data")
46 |
47 | logging.info("starting forward")
48 | loss_and_state = task.loss_with_state(param, state, key, batch)
49 | del loss_and_state
50 | logging.info("starting backward")
51 | grad, aux = jax.grad(
52 | task.loss_with_state, has_aux=True)(param, state, key, batch)
53 | del grad, aux
54 | logging.info("checking normalizer")
55 | task.normalizer(jnp.asarray(1.0))
56 |
57 | logging.info("checking loss_with_state_and_aux")
58 | loss, state, aux = task.loss_with_state_and_aux(param, state, key, batch)
59 | del loss, state, aux
60 | logging.info("done")
61 |
62 |
63 | def smoketest_task_family(task_family: base.TaskFamily):
64 | """Smoke test a TaskFamily."""
65 | key = jax.random.PRNGKey(0)
66 | task_params = task_family.sample(key)
67 | task = task_family.task_fn(task_params)
68 | smoketest_task(task)
69 |
70 | _, key = jax.random.split(key)
71 | task_params = task_family.sample(key)
72 | task = task_family.task_fn(task_params)
73 | smoketest_task(task)
74 |
75 | if task.datasets is not None:
76 | assert task_family.datasets is not None
77 |
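
As a usage sketch, a test for a new task can call these helpers directly, mirroring quadratics_test.py above; `my_project.my_tasks` and `MyTask` are hypothetical placeholders for the module under test:

from absl.testing import absltest
from learned_optimization.tasks import test_utils
from my_project import my_tasks  # hypothetical module providing MyTask


class MyTaskTest(absltest.TestCase):

  def test_my_task(self):
    # One call covers init, forward loss, gradients, the normalizer,
    # and loss_with_state_and_aux.
    test_utils.smoketest_task(my_tasks.MyTask())


if __name__ == '__main__':
  absltest.main()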
--------------------------------------------------------------------------------
/learned_optimization/time_filter/time_model_data_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for time_model_data."""
17 |
18 | import os
19 | import tempfile
20 |
21 | from absl.testing import absltest
22 | from learned_optimization.tasks import quadratics # pylint: disable=unused-import
23 | from learned_optimization.tasks.parametric import cfgobject
24 | from learned_optimization.time_filter import run_sample_and_time
25 | from learned_optimization.time_filter import time_model_data
26 |
27 |
28 | def fake_sample_quadratic(key):
29 | del key
30 | return cfgobject.CFGObject("FixedDimQuadraticFamily", {"dim": 10})
31 |
32 |
33 | class TimeModelDataTest(absltest.TestCase):
34 |
35 | def test_load_runtime_files(self):
36 | with tempfile.TemporaryDirectory() as tmpdir:
37 | os.environ["LOPT_TIMING_DIR"] = tmpdir
38 | tmpdir = os.path.join(tmpdir, "quadratic")
39 | run_sample_and_time.run_many_eval_and_save(
40 | fake_sample_quadratic, tmpdir, num_to_run=4)
41 | data = time_model_data.load_runtime_files(
42 | "quadratic", "cpu_cpu", 5, threads=5)
43 | self.assertLen(data, 4)
44 |
45 | def test_train_test_iterators(self):
46 | with tempfile.TemporaryDirectory() as tmpdir:
47 | os.environ["LOPT_TIMING_DIR"] = tmpdir
48 | tmpdir = os.path.join(tmpdir, "quadratic")
49 | run_sample_and_time.run_many_eval_and_save(
50 | fake_sample_quadratic, tmpdir, num_to_run=4)
51 | tr_it, unused_te_it = time_model_data.train_test_iterators(
52 | "quadratic", "cpu_cpu", max_samples_to_load=5, num_test=2)
53 | for _ in range(10):
54 | data = next(tr_it)
55 | self.assertTrue("feats" in data) # pylint: disable=g-generic-assert
56 | self.assertTrue("time" in data) # pylint: disable=g-generic-assert
57 |
58 |
59 | if __name__ == "__main__":
60 | absltest.main()
61 |
--------------------------------------------------------------------------------
/learned_optimization/time_filter/timings_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for timings."""
17 |
18 | from absl.testing import absltest
19 | from learned_optimization.learned_optimizers import base as lopt_base
20 | from learned_optimization.optimizers import base as opt_base
21 | from learned_optimization.tasks import base as tasks_base
22 | from learned_optimization.tasks import quadratics
23 | from learned_optimization.time_filter import timings
24 |
25 |
26 | class TimingsTest(absltest.TestCase):
27 |
28 | def test_timing_default(self):
29 | task = quadratics.BatchQuadraticTask()
30 | task_family = tasks_base.single_task_to_family(task)
31 | stats = timings.task_family_runtime_stats(task_family)
32 | self.assertLess(stats["unroll_4x10"][0], 1.0)
33 |
34 | def test_timing_opt(self):
35 | task = quadratics.BatchQuadraticTask()
36 | task_family = tasks_base.single_task_to_family(task)
37 | stats = timings.task_family_runtime_stats(task_family, opt=opt_base.Adam())
38 | self.assertLess(stats["unroll_4x10"][0], 1.0)
39 |
40 | def test_timing_lopt(self):
41 | task = quadratics.BatchQuadraticTask()
42 | task_family = tasks_base.single_task_to_family(task)
43 | stats = timings.task_family_runtime_stats(
44 | task_family, lopt=lopt_base.LearnableAdam())
45 | self.assertLess(stats["unroll_4x10"][0], 1.0)
46 |
47 |
48 | if __name__ == "__main__":
49 | absltest.main()
50 |
--------------------------------------------------------------------------------
/learned_optimization/training.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utilities useful for training and meta-training."""
17 | import os
18 | from typing import Any, Sequence
19 |
20 | from absl import logging
21 | from flax import serialization
22 | import jax
23 | from learned_optimization import filesystem as fs
24 | from learned_optimization import profile
25 | from learned_optimization import tree_utils
26 | from learned_optimization.tasks import base as tasks_base
27 |
28 |
29 | def save_state(path, state):
30 | fs.make_dirs(os.path.dirname(path))
31 | with fs.file_open(path, "wb") as fp:
32 | fp.write(serialization.to_bytes(state))
33 |
34 |
35 | def load_state(path, state):
36 | logging.info("Restoring state %s", path)
37 | with fs.file_open(path, "rb") as fp:
38 | state_new = serialization.from_bytes(state, fp.read())
39 | tree = jax.tree_util.tree_structure(state)
40 | leaves = jax.tree_util.tree_leaves(state_new)
41 | return jax.tree_util.tree_unflatten(tree, leaves)
42 |
43 |
44 | def get_batches(task_family: tasks_base.TaskFamily,
45 | batch_shape: Sequence[int],
46 | split: str,
47 | numpy: bool = False) -> Any:
48 | """Get batches of data with the `batch_shape` leading dimension."""
49 | if len(batch_shape) == 1:
50 | return vec_get_batch(task_family, batch_shape[0], numpy=numpy, split=split)
51 | elif len(batch_shape) == 2:
52 | datas_list = [
53 | vec_get_batch(task_family, batch_shape[1], numpy=numpy, split=split)
54 | for _ in range(batch_shape[0])
55 | ]
56 | if numpy:
57 | return tree_utils.tree_zip_onp(datas_list)
58 | else:
59 | return tree_utils.tree_zip_jnp(datas_list)
60 | elif len(batch_shape) == 3:
61 | datas_list = [
62 | get_batches(
63 | task_family, [batch_shape[1], batch_shape[2]],
64 | numpy=numpy,
65 | split=split) for _ in range(batch_shape[0])
66 | ]
67 | if numpy:
68 | return tree_utils.tree_zip_onp(datas_list)
69 | else:
70 | return tree_utils.tree_zip_jnp(datas_list)
71 | else:
72 | raise NotImplementedError()
73 |
74 |
75 | @profile.wrap()
76 | def vec_get_batch(task_family, n_tasks, split, numpy=False):
77 | to_zip = []
78 | for _ in range(n_tasks):
79 | if task_family.datasets is None:
80 | return ()
81 | to_zip.append(next(task_family.datasets.split(split)))
82 | return tree_utils.tree_zip_onp(to_zip) if numpy else tree_utils.tree_zip_jnp(
83 | to_zip)
84 |
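
A short sketch of `get_batches`, assuming `FixedDimQuadraticFamilyData` (used in quadratics_test.py above) exposes a synthetic dataset with a "train" split; the leading [4, 8] batch shape is illustrative:

import jax
from learned_optimization import training
from learned_optimization.tasks import quadratics

# A task family that carries data; see quadratics_test.py above.
task_family = quadratics.FixedDimQuadraticFamilyData(10)

# Leading batch shape [4, 8]: e.g. 4 unroll steps, each with batches
# for 8 tasks, stacked via tree_zip_jnp (the numpy=False default).
batch = training.get_batches(task_family, [4, 8], split="train")
print(jax.tree_util.tree_map(lambda x: x.shape, batch))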
--------------------------------------------------------------------------------
/learned_optimization/tree_utils_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for tree_utils."""
17 |
18 | from typing import Any
19 |
20 | from absl.testing import absltest
21 | import flax
22 | import jax
23 | import jax.numpy as jnp
24 | from learned_optimization import tree_utils
25 |
26 |
27 | @flax.struct.dataclass
28 | class ClassName:
29 | val1: Any
30 | val2: Any
31 |
32 |
33 | class TreeUtilsTest(absltest.TestCase):
34 |
35 | def test_partition(self):
36 | c = {"a": 2, "b": ClassName(1, 2.)}
37 |
38 | partitions, unflattener = tree_utils.partition(
39 | [lambda k, v: jnp.asarray(v).dtype == jnp.int32], c)
40 |
41 | partitions[1] = jax.tree_util.tree_map(lambda x: x * 2, partitions[1])
42 |
43 | tree_utils.partition_unflatten(unflattener, partitions)
44 | data = unflattener(partitions)
45 |
46 | self.assertEqual(data["b"].val2, 4.)
47 | self.assertEqual(data["b"].val1, 1)
48 |
49 |
50 | if __name__ == "__main__":
51 | absltest.main()
52 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.12.0
2 | numpy>=1.18
3 | jax>=0.2.6
4 | jaxlib>=0.1.68
5 | pytest
6 | tqdm>=4.62.3
7 | flax
8 | dm-haiku==0.0.5
9 | optax>=0.0.9
10 | tensorflow>=2.7.0
11 | tensorflow-datasets>=4.4.0
12 | tensorflow-metadata==1.5.0
13 | tensorflow-probability>=0.16.0
14 | tensorboard>=2.7.0
15 | gin-config>=0.5.0
16 | seqio>=0.0.7
17 | oryx
18 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Setup the package."""
17 |
18 | import os
19 | import setuptools
20 |
21 | # https://packaging.python.org/guides/making-a-pypi-friendly-readme/
22 | this_directory = os.path.abspath(os.path.dirname(__file__))
23 | with open(os.path.join(this_directory, 'README.md'), encoding='utf-8') as f:
24 | long_description = f.read()
25 |
26 | __version__ = '0.0.1'
27 |
28 | setuptools.setup(
29 | name='learned_optimization',
30 | version=__version__,
31 | description='Train learned optimizers in Jax.',
32 | author='learned_optimization team',
33 | author_email='lmetz@google.com',
34 | packages=setuptools.find_packages(exclude=['examples']),
35 | package_data={'learned_optimization': ['py.typed']},
36 | python_requires='>=3.7',
37 | setup_requires=['setuptools_scm'],
38 | include_package_data=True,
39 | # TODO(lmetz) don't fix many versions! Sadly a number of these libraries
40 | # don't play nice with newer versions of other libraries.
41 | # TODO(lmetz) add oryx to this!
42 | install_requires=[
43 | 'absl-py==0.12.0',
44 | 'numpy>=1.18',
45 | 'jax>=0.4.3',
46 | 'jaxlib>=0.4.3',
47 | 'nose',
48 | 'dm-launchpad-nightly',
49 | 'tqdm>=4.62.3',
50 | 'flax',
51 | 'dm-haiku==0.0.5',
52 | 'optax>=0.0.9',
53 | 'tensorflow>=2.7.0',
54 | 'tensorflow-datasets>=4.4.0',
55 | 'tensorflow-metadata==1.5.0',
56 | 'tensorboard>=2.7.0',
57 | 'gin-config>=0.5.0',
58 | 'oryx',
59 | ],
60 | url='https://github.com/google/learned_optimization',
61 | license='Apache-2.0',
62 | long_description=long_description,
63 | long_description_content_type='text/markdown',
64 | classifiers=[
65 | 'Programming Language :: Python :: 3.7',
66 | 'Programming Language :: Python :: 3.8',
67 | 'Programming Language :: Python :: 3.9',
68 | ],
69 | zip_safe=False,
70 | )
71 |
--------------------------------------------------------------------------------