├── .readthedocs.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── README.md ├── _static │ ├── continuous_eval_system_diagram.png │ ├── style.css │ └── vis_Conv_Cifar10_32x64x64.png ├── _templates │ └── layout.html ├── conf.py ├── continuous_eval.rst ├── index.rst ├── learned_optimizers.rst ├── notebooks │ ├── Part1_Introduction.ipynb │ ├── Part1_Introduction.md │ ├── Part1_Introduction.py │ ├── Part2_CustomTasks.ipynb │ ├── Part2_CustomTasks.md │ ├── Part2_CustomTasks.py │ ├── Part3_Truncation_TruncatedStep.ipynb │ ├── Part3_Truncation_TruncatedStep.md │ ├── Part3_Truncation_TruncatedStep.py │ ├── Part4_GradientEstimators.ipynb │ ├── Part4_GradientEstimators.md │ ├── Part4_GradientEstimators.py │ ├── Part5_Meta_training_with_GradientLearner.ipynb │ ├── Part5_Meta_training_with_GradientLearner.md │ ├── Part5_Meta_training_with_GradientLearner.py │ ├── Part6_custom_learned_optimizers.ipynb │ ├── Part6_custom_learned_optimizers.md │ ├── Part6_custom_learned_optimizers.py │ ├── no_dependency_learned_optimizer.ipynb │ ├── no_dependency_learned_optimizer.md │ ├── no_dependency_learned_optimizer.py │ ├── summary_tutorial.ipynb │ ├── summary_tutorial.md │ └── summary_tutorial.py ├── optimizer_baselines.rst ├── optimizers.rst ├── outer_trainers.rst ├── population.rst ├── requirements.txt └── tasks.rst ├── learned_optimization ├── __init__.py ├── baselines │ ├── __init__.py │ ├── data │ │ ├── speedup_over_adam.json │ │ ├── speedup_over_adam.pkl │ │ ├── speedup_over_adam_v2.json │ │ ├── speedup_over_multiple_baselines.json │ │ └── speedup_over_multiple_baselines_v2.json │ ├── hparam_sets.py │ ├── normalizers.py │ ├── normalizers_test.py │ ├── run_archive.py │ ├── run_time_task.py │ ├── run_trainer.py │ ├── run_trainer_test.py │ ├── utils.py │ └── vis.py ├── checkpoints.py ├── circular_buffer.py ├── circular_buffer_test.py ├── continuous_eval │ ├── __init__.py │ ├── evaluation_set.py │ ├── run_eval_chief.py │ ├── run_eval_chief_test.py │ ├── run_eval_worker.py │ ├── run_eval_worker_test.py │ ├── task_group_server.py │ └── task_group_server_test.py ├── distributed.py ├── distributed_test.py ├── entry_points │ ├── README.md │ ├── run_check_eval_worker.py │ ├── run_outer_train_single.py │ ├── run_outer_trainer.py │ ├── run_outer_worker.py │ ├── run_population_controller.py │ ├── run_single_machine_trainer.py │ └── run_tboard_eval.py ├── eval_training.py ├── eval_training_multi_device_test.py ├── eval_training_test.py ├── examples │ ├── README.md │ ├── simple_lopt_train.py │ └── simple_lopt_train_test.py ├── filesystem.py ├── jax_utils.py ├── jax_utils_test.py ├── launchers │ ├── __init__.py │ ├── launch_template_multi_machine_lopt.py │ └── utils.py ├── learned_optimizers │ ├── __init__.py │ ├── adafac_mlp_lopt.py │ ├── adafac_mlp_lopt_test.py │ ├── adafac_nominal.py │ ├── adafac_nominal_test.py │ ├── base.py │ ├── base_test.py │ ├── common.py │ ├── mlp_lopt.py │ ├── mlp_lopt_test.py │ ├── nn_adam.py │ ├── nn_adam_test.py │ ├── opt_from_checkpoint.py │ ├── rnn_mlp_lopt.py │ ├── rnn_mlp_lopt_test.py │ └── test_utils.py ├── notebook_utils.py ├── optimizers │ ├── __init__.py │ ├── alias.py │ ├── base.py │ ├── base_test.py │ ├── data │ │ └── opt_list.json │ ├── gradient_accumulator.py │ ├── gradient_accumulator_test.py │ ├── learning_rate_schedules.py │ ├── nadamw.py │ ├── nadamw_test.py │ ├── opt_list.py │ ├── opt_list_test.py │ ├── opt_to_optax.py │ ├── opt_to_optax_test.py │ ├── optax_opts.py │ ├── optax_opts_test.py │ ├── optimizer_wrappers.py │ ├── 
shampoo.py │ ├── shampoo_test.py │ └── test_utils.py ├── outer_train.py ├── outer_train_test.py ├── outer_trainers │ ├── __init__.py │ ├── common.py │ ├── full_es.py │ ├── full_es_test.py │ ├── full_grad.py │ ├── full_grad_test.py │ ├── gradient_learner.py │ ├── gradient_learner_test.py │ ├── lopt_truncated_step.py │ ├── lopt_truncated_step_test.py │ ├── test_utils.py │ ├── truncated_es.py │ ├── truncated_es_test.py │ ├── truncated_grad.py │ ├── truncated_grad_test.py │ ├── truncated_pes.py │ ├── truncated_pes_pmap_test.py │ ├── truncated_pes_test.py │ ├── truncated_step.py │ ├── truncation_schedule.py │ └── truncation_schedule_test.py ├── population │ ├── __init__.py │ ├── examples │ │ ├── __init__.py │ │ ├── complex_cnn │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ └── train_threads.py │ │ ├── simple_cnn │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── train.py │ │ │ ├── train_test.py │ │ │ └── train_threads.py │ │ └── synthetic │ │ │ ├── __init__.py │ │ │ ├── train.py │ │ │ └── train_test.py │ ├── mutators │ │ ├── __init__.py │ │ ├── fixed_schedule.py │ │ ├── fixed_schedule_test.py │ │ ├── single_worker_explore.py │ │ ├── single_worker_explore_test.py │ │ ├── winner_take_all_genetic.py │ │ └── winner_take_all_genetic_test.py │ ├── population.py │ └── population_test.py ├── profile.py ├── py_utils.py ├── research │ ├── README.md │ ├── __init__.py │ ├── brax │ │ ├── brax_env_truncated_step.py │ │ ├── brax_env_truncated_step_test.py │ │ └── run_single_machine_rl_trainer.py │ ├── data_driven │ │ ├── README.md │ │ ├── data.py │ │ ├── mnist_projections.py │ │ ├── mnist_projections_test.py │ │ ├── model_components.py │ │ ├── models.py │ │ ├── resnet.py │ │ ├── run_mnist_projections.py │ │ ├── summary.py │ │ └── transformer.py │ ├── discovered_policy_optimisation │ │ └── truncated_step.py │ ├── distill │ │ ├── truncated_distill.py │ │ └── truncated_distill_test.py │ ├── es_scaling │ │ └── es_scaling_tasks.py │ ├── general_lopt │ │ ├── Demo_for_training_a_model_with_a_learned_optimizer.ipynb │ │ ├── README.md │ │ ├── __init__.py │ │ ├── configs │ │ │ ├── README.md │ │ │ ├── large_scale_phase1.py │ │ │ ├── large_scale_phase2.py │ │ │ ├── large_scale_phase3.py │ │ │ └── large_scale_phase4.py │ │ ├── hyper_v2.py │ │ ├── hyper_v2_test.py │ │ ├── prefab.py │ │ ├── prefab_test.py │ │ ├── pretrained_optimizers.py │ │ ├── pretrained_optimizers_test.py │ │ └── tasks │ │ │ ├── __init__.py │ │ │ ├── fast_mlp_diff_data.py │ │ │ └── parametric.py │ ├── hysteresis │ │ ├── __init__.py │ │ ├── data_in_state_tasks.py │ │ ├── extra_small_lopt.py │ │ ├── truncated_es_no_hyst.py │ │ ├── truncated_es_shared_noise.py │ │ ├── truncated_pes_no_hyst.py │ │ └── truncated_pes_no_hyst_same_data.py │ ├── jaxnerf │ │ ├── datasets.py │ │ ├── example_data │ │ │ ├── imgs │ │ │ │ └── r_0.png │ │ │ ├── transforms_test.json │ │ │ └── transforms_train.json │ │ ├── jaxnerf.py │ │ └── jaxnerf_test.py │ ├── scaling │ │ ├── scaling_tasks.py │ │ └── scaling_transformer.py │ └── univ_nfn │ │ ├── gen_pred │ │ └── train_pred.py │ │ ├── learned_opt │ │ ├── __init__.py │ │ ├── learned_opts.py │ │ └── outer_opt.py │ │ └── nfn │ │ ├── __init__.py │ │ ├── ff_layers.py │ │ ├── siren.py │ │ ├── universal_layers.py │ │ └── utils.py ├── run_outer_train_single.py ├── setup_experiment.py ├── summary.py ├── summary_test.py ├── tasks │ ├── __init__.py │ ├── base.py │ ├── datasets │ │ ├── __init__.py │ │ ├── base.py │ │ ├── beam │ │ │ └── make_resized_image_beam.py │ │ ├── image.py │ │ ├── image_test.py │ │ ├── language.py │ │ └── language_test.py 
│ ├── es_wrapper.py │ ├── es_wrapper_test.py │ ├── fixed │ │ ├── __init__.py │ │ ├── all.py │ │ ├── conv.py │ │ ├── conv_test.py │ │ ├── es_wrapped.py │ │ ├── es_wrapped_test.py │ │ ├── image_mlp.py │ │ ├── image_mlp_ae.py │ │ ├── image_mlp_ae_test.py │ │ ├── image_mlp_test.py │ │ ├── lopt.py │ │ ├── lopt_test.py │ │ ├── mlp_mixer.py │ │ ├── mlp_mixer_test.py │ │ ├── resnet.py │ │ ├── resnet_test.py │ │ ├── rnn_lm.py │ │ ├── rnn_lm_test.py │ │ ├── transformer_lm.py │ │ ├── transformer_lm_test.py │ │ ├── vit.py │ │ └── vit_test.py │ ├── generative_model_utils.py │ ├── parametric │ │ ├── __init__.py │ │ ├── cfgobject.py │ │ ├── cfgobject_test.py │ │ ├── image_conv.py │ │ ├── image_conv_test.py │ │ ├── image_mlp.py │ │ ├── image_mlp_ae.py │ │ ├── image_mlp_ae_test.py │ │ ├── image_mlp_test.py │ │ ├── image_mlp_vae.py │ │ ├── image_mlp_vae_test.py │ │ ├── image_resnet.py │ │ ├── image_resnet_test.py │ │ ├── lm_rnn.py │ │ ├── lm_rnn_test.py │ │ ├── lm_transformer.py │ │ ├── lm_transformer_test.py │ │ ├── lopt.py │ │ ├── lopt_test.py │ │ ├── lopt_trunc.py │ │ ├── lopt_trunc_test.py │ │ ├── parametric_utils.py │ │ ├── parametric_utils_test.py │ │ ├── vit.py │ │ └── vit_test.py │ ├── quadratics.py │ ├── quadratics_test.py │ ├── resnet.py │ ├── rnn.py │ ├── task_augmentation.py │ ├── task_augmentation_test.py │ ├── test_utils.py │ └── transformer.py ├── time_filter │ ├── model_paths.py │ ├── run_sample_and_time.py │ ├── run_train_time_model.py │ ├── time_model.py │ ├── time_model_data.py │ ├── time_model_data_test.py │ ├── timings.py │ └── timings_test.py ├── training.py ├── tree_utils.py └── tree_utils_test.py ├── requirements.txt └── setup.py /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-20.04 11 | tools: 12 | python: "3.9" 13 | 14 | # Build documentation in the docs/ directory with Sphinx 15 | sphinx: 16 | configuration: docs/conf.py 17 | 18 | # Optionally declare the Python requirements required to build your docs 19 | python: 20 | install: 21 | - requirements: docs/requirements.txt 22 | - requirements: requirements.txt 23 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. 
Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 30 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | sync-notebooks: 16 | jupytext --sync notebooks/*{py,md,ipynb} 17 | 18 | .PHONY: help Makefile 19 | 20 | # Catch-all target: route all unknown targets to Sphinx using the new 21 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 22 | %: Makefile 23 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 24 | 25 | clean: 26 | rm -rf $(BUILDDIR)/* 27 | rm -rf auto_examples/ 28 | rm -rf _autosummary/ 29 | 30 | html-noplot: 31 | $(SPHINXBUILD) -D plot_gallery=0 -b html $(ALLSPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html 32 | @echo 33 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 34 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Instructions for docs 2 | 3 | First install jupytext: 4 | 5 | `pip3 install jupytext` 6 | 7 | To sync the different forms of the notebooks, cd to this directory and 8 | 9 | `make sync-notebooks` 10 | -------------------------------------------------------------------------------- /docs/_static/continuous_eval_system_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/learned_optimization/3ab34d27bdd38f487af65fb025f9c86869946e81/docs/_static/continuous_eval_system_diagram.png -------------------------------------------------------------------------------- /docs/_static/style.css: -------------------------------------------------------------------------------- 1 | @import url("theme.css"); 2 | 3 | .wy-side-nav-search { 4 | background-color: #fff; 5 | } 6 | -------------------------------------------------------------------------------- /docs/_static/vis_Conv_Cifar10_32x64x64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/learned_optimization/3ab34d27bdd38f487af65fb025f9c86869946e81/docs/_static/vis_Conv_Cifar10_32x64x64.png -------------------------------------------------------------------------------- /docs/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | {% set css_files = css_files + ["_static/style.css"] %} 3 | -------------------------------------------------------------------------------- /docs/continuous_eval.rst: -------------------------------------------------------------------------------- 1 | Continuous Eval 2 | ================ 3 | 4 | Meta-learning training can be an incredibly noisy process as one often meta-learns over randomly 
sampled tasks, and these sampled tasks can have very different meta-objectives. 5 | To get higher quality evaluation, as well as to evaluate performance on different 6 | kinds of distributions of tasks (i.e. to test out-of-distribution generalization), 7 | it is often useful to evaluate a suite of tasks on a single, fixed checkpoint / fixed meta-parameter value. 8 | This is tricky, though, as it can be very expensive -- in the case of learned 9 | optimizers it requires training a number of different tasks for a large number 10 | of inner-steps. 11 | To get faster evaluations, it is desirable to split up this computation and run it in 12 | parallel across a number of machines. 13 | 14 | This module is created to help with distributed evaluation of a given checkpoint 15 | and is meant to be run alongside a meta-training job / meta-training cluster 16 | so as to get stable evaluations while training. It interacts with a training 17 | job by monitoring a directory for new checkpoints and kicking evaluation jobs 18 | off based on them. 19 | 20 | We summarize the architecture in the figure below. 21 | 22 | 23 | .. path to edit figure: https://docs.google.com/presentation/d/16zWe2ryUUcbWSBRhfYx7D5BwG9z1Mt-4BpWCOKmnais/edit?resourcekey=0-_if_-4xNYC5bgD1ZyN2Pmg#slide=id.p 24 | 25 | 26 | .. image:: _static/continuous_eval_system_diagram.png 27 | :width: 800 28 | :alt: Continuous eval system diagram 29 | 30 | 31 | There are two types of job: the eval chief and the eval worker. 32 | 33 | Evaluation Chief 34 | ---------------- 35 | 36 | The evaluation chief is configured with a set of evaluations to run on 37 | a given checkpoint. The chief then starts up a couple of threads. 38 | First, there is a thread to monitor a given directory for a new checkpoint. 39 | When a new checkpoint is found, we send this checkpoint, as well as the list of 40 | evaluations (e.g. different tasks) to run on it, to an instance of `TaskGroupChief`. 41 | This is a task queue specifically structured to manage groups of tasks being evaluated on a single checkpoint. 42 | A final thread monitors this task queue for finished groups of tasks and 43 | writes the results out to TensorBoard and, optionally, to the population controller 44 | if one is using population-based training. 45 | 46 | 47 | Evaluation configuration 48 | ------------------------ 49 | 50 | While it's possible to use similar infrastructure to pass around any kind of 51 | configuration for an evaluation task, we opt to use `gin <https://github.com/google/gin-config>`_ -- in particular, 52 | a set of gin bindings which specify the new task. 53 | We find this very convenient, but it comes at the cost of complexity and global 54 | state. 55 | 56 | Evaluation Worker 57 | ------------------------ 58 | 59 | For a given evaluation chief, we run multiple evaluation workers. Each worker 60 | is responsible for grabbing an evaluation to work on from the eval chief. 61 | It first parses the evaluation configuration, loads any weights (e.g. the weights of 62 | the learned optimizer), performs some inner-training procedure which does the 63 | evaluation, and finally reports the results back to the eval chief. 64 | This loop repeats over and over again.
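Concretely, a single entry in the chief's evaluation set can be thought of as a human-readable name plus the list of gin binding strings that a worker parses (via ``gin.parse_config``) before inner-training. A minimal sketch follows; the binding targets are illustrative placeholders rather than the exact configurables used in this codebase:

.. code-block:: python

    # One element of an evaluation set: (list of gin bindings, name).
    # The binding targets below are placeholders for whatever
    # gin-configurable task / inner-training settings an evaluation needs.
    eval_cfg = [
        "get_task_family.task_family = @MyEvalTaskFamily()",  # hypothetical family
        "multi_task_training_curves.steps = 2000",
    ]
    evaluation = (eval_cfg, "my_eval_family_2k_steps")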
65 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | learned_optimization reference documentation 2 | ============================================ 3 | 4 | learned\_optimization is a research codebase for training learned 5 | optimizers. It implements hand designed and learned optimizers, tasks to meta-train and meta-test them on, and outer-training algorithms such as ES, gradients and PES. 6 | 7 | .. toctree:: 8 | :maxdepth: 0 9 | :caption: Getting Started 10 | 11 | notebooks/Part1_Introduction 12 | notebooks/Part2_CustomTasks 13 | notebooks/Part3_Truncation_TruncatedStep 14 | notebooks/Part4_GradientEstimators 15 | notebooks/Part5_Meta_training_with_GradientLearner 16 | notebooks/Part6_custom_learned_optimizers 17 | 18 | .. toctree:: 19 | :maxdepth: 0 20 | :caption: Tutorials 21 | 22 | notebooks/no_dependency_learned_optimizer 23 | notebooks/summary_tutorial 24 | 25 | .. toctree:: 26 | :maxdepth: 0 27 | :caption: Modules 28 | 29 | tasks 30 | optimizers 31 | learned_optimizers 32 | outer_trainers 33 | continuous_eval 34 | population 35 | optimizer_baselines 36 | 37 | 38 | Indices and tables 39 | ================== 40 | 41 | * :ref:`genindex` 42 | * :ref:`modindex` 43 | * :ref:`search` 44 | -------------------------------------------------------------------------------- /docs/learned_optimizers.rst: -------------------------------------------------------------------------------- 1 | Learned Optimizers 2 | ==================== 3 | 4 | .. automodule:: learned_optimization.learned_optimizers.base 5 | 6 | 7 | API 8 | --- 9 | 10 | .. autofunction:: learned_optimization.learned_optimizers.base.LearnedOptimizer 11 | .. autofunction:: learned_optimization.learned_optimizers.base.LearnableSGD 12 | .. autofunction:: learned_optimization.learned_optimizers.base.LearnableSGDM 13 | .. autofunction:: learned_optimization.learned_optimizers.base.LearnableAdam -------------------------------------------------------------------------------- /docs/optimizers.rst: -------------------------------------------------------------------------------- 1 | Optimizers 2 | =========== 3 | 4 | .. automodule:: learned_optimization.optimizers.base 5 | 6 | 7 | API 8 | --- 9 | 10 | .. autofunction:: learned_optimization.optimizers.base.Optimizer 11 | .. autofunction:: learned_optimization.optimizers.base.SGD 12 | .. autofunction:: learned_optimization.optimizers.base.SGDM 13 | .. autofunction:: learned_optimization.optimizers.base.Adam 14 | -------------------------------------------------------------------------------- /docs/outer_trainers.rst: -------------------------------------------------------------------------------- 1 | Outer Trainers 2 | =============== 3 | 4 | Outer-training, or meta-training is the process of learning the meta-parameters. 5 | 6 | 7 | 8 | Base 9 | ---- 10 | 11 | .. autofunction:: learned_optimization.outer_trainers.gradient_learner.GradientEstimator 12 | .. autofunction:: learned_optimization.outer_trainers.gradient_learner.gradient_worker_compute 13 | .. autofunction:: learned_optimization.outer_trainers.gradient_learner.SingleMachineGradientLearner 14 | 15 | 16 | FullES 17 | -------------- 18 | 19 | .. automodule:: learned_optimization.outer_trainers.full_es 20 | 21 | .. autofunction:: learned_optimization.outer_trainers.full_es.FullES 22 | 23 | 24 | 25 | TruncatedPES 26 | -------------- 27 | 28 | .. 
automodule:: learned_optimization.outer_trainers.truncated_pes 29 | 30 | .. autofunction:: learned_optimization.outer_trainers.truncated_pes.TruncatedPES -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=3.3.4 2 | pandoc>=1.0.2 3 | sphinx>=3.5.1 4 | recommonmark>=0.7.1 5 | sphinx_rtd_theme>=0.5.1 6 | sphinx_autodoc_typehints>=1.11.1 7 | ipython>=7.20.0 8 | ipykernel>=5.5.0 9 | nbsphinx>=0.8.7 10 | sphinx_copybutton>=0.4.0 11 | myst_nb>=0.13.0 12 | -------------------------------------------------------------------------------- /docs/tasks.rst: -------------------------------------------------------------------------------- 1 | Tasks 2 | ===== 3 | 4 | .. automodule:: learned_optimization.tasks.base 5 | 6 | 7 | Base 8 | --- 9 | 10 | .. autofunction:: learned_optimization.tasks.base.Task 11 | .. autofunction:: learned_optimization.tasks.base.TaskFamily 12 | 13 | 14 | QuadraticTasks 15 | -------------- 16 | 17 | .. automodule:: learned_optimization.tasks.quadratics 18 | 19 | .. autofunction:: learned_optimization.tasks.quadratics.QuadraticTask 20 | .. autofunction:: learned_optimization.tasks.quadratics.BatchQuadraticTask 21 | .. autofunction:: learned_optimization.tasks.quadratics.SumQuadraticTask 22 | .. autofunction:: learned_optimization.tasks.quadratics.FixedDimQuadraticFamily 23 | .. autofunction:: learned_optimization.tasks.quadratics.FixedDimQuadraticFamilyData -------------------------------------------------------------------------------- /learned_optimization/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """learned_optimizer module.""" 17 | -------------------------------------------------------------------------------- /learned_optimization/baselines/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Baselines module.""" 17 | -------------------------------------------------------------------------------- /learned_optimization/baselines/data/speedup_over_adam.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/learned_optimization/3ab34d27bdd38f487af65fb025f9c86869946e81/learned_optimization/baselines/data/speedup_over_adam.pkl -------------------------------------------------------------------------------- /learned_optimization/baselines/normalizers_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for normalizers.""" 17 | 18 | from absl.testing import absltest 19 | from learned_optimization.baselines import normalizers 20 | 21 | 22 | class NormalizersTest(absltest.TestCase): 23 | 24 | def test_build_and_apply_one(self): 25 | data = normalizers._speedup_over_adam_build("ImageMLP_Cifar10_8_Relu32") 26 | fn = normalizers._speedup_over_adam_make_func(data) 27 | ret = fn(1.0) 28 | self.assertGreater(ret, 0) 29 | # check that it is no more than 100k -- length of max adam lengths. 30 | self.assertLess(ret, 100_001) 31 | 32 | def test_speedup_over_adam_normalizer_map(self): 33 | normfns = normalizers.speedup_over_adam_normalizer_map() 34 | ret = normfns["ImageMLP_Cifar10_8_Relu32"](1.0) 35 | self.assertGreater(ret, 0) 36 | self.assertLess(ret, 100_001) 37 | 38 | 39 | if __name__ == "__main__": 40 | absltest.main() 41 | -------------------------------------------------------------------------------- /learned_optimization/baselines/run_time_task.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Time a single task.""" 17 | 18 | import os 19 | import pickle 20 | import uuid 21 | 22 | from absl import app 23 | from absl import logging 24 | import gin 25 | import jax 26 | from learned_optimization import filesystem 27 | from learned_optimization import setup_experiment 28 | from learned_optimization.tasks import base as tasks_base 29 | from learned_optimization.time_filter import timings 30 | 31 | 32 | @gin.configurable 33 | def run_many_eval_and_save(task: tasks_base.Task = gin.REQUIRED, 34 | save_dir: str = gin.REQUIRED): 35 | """Compute and save `num_to_run` runtime statistics.""" 36 | 37 | dev = jax.devices()[0] 38 | dev_name = f"{dev.platform}_{dev.device_kind}" 39 | dev_name = dev_name.replace(" ", "") 40 | 41 | save_dir = os.path.join(save_dir, dev_name) 42 | filesystem.make_dirs(save_dir) 43 | 44 | task_family = tasks_base.single_task_to_family(task) 45 | 46 | speeds_and_std = timings.task_family_runtime_stats( 47 | task_family, num_tasks_list=[1, 2, 4, 8]) 48 | 49 | output = pickle.dumps(speeds_and_std) 50 | 51 | path = os.path.join(save_dir, f"{str(uuid.uuid4())}.pkl") 52 | logging.info("Writing results to %s", path) 53 | with filesystem.file_open(path, "wb") as f: 54 | f.write(output) 55 | 56 | 57 | def main(argv): 58 | if len(argv) > 1: 59 | raise app.UsageError("Too many command-line arguments.") 60 | setup_experiment.setup_experiment(make_dir=False) 61 | run_many_eval_and_save() 62 | 63 | 64 | if __name__ == "__main__": 65 | app.run(main) 66 | -------------------------------------------------------------------------------- /learned_optimization/baselines/run_trainer_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for run_trainer.""" 17 | 18 | import os 19 | import tempfile 20 | 21 | from absl.testing import absltest 22 | import gin 23 | from learned_optimization.baselines import run_trainer 24 | from learned_optimization.baselines import utils 25 | from learned_optimization.optimizers import base as opt_base 26 | from learned_optimization.tasks import quadratics 27 | import numpy as onp 28 | from numpy import testing 29 | 30 | 31 | @gin.configurable() 32 | def make_quad(): 33 | return quadratics.QuadraticTask() 34 | 35 | 36 | class RunTrainerTest(absltest.TestCase): 37 | 38 | def test_inner_train_task(self): 39 | with tempfile.TemporaryDirectory() as tmpdir: 40 | os.environ["LOPT_BASELINE_DIR"] = tmpdir 41 | opt = opt_base.SGD() 42 | task = quadratics.QuadraticTask() 43 | run_trainer.inner_train_task( 44 | task, 45 | opt, 46 | num_steps=10, 47 | eval_every=5, 48 | eval_batches=2, 49 | last_eval_batches=4) 50 | 51 | results = utils.load_baseline_results( 52 | "QuadraticTask", 53 | "SGD_lr0.01", 54 | num_steps=10, 55 | eval_every=5, 56 | eval_batches=2, 57 | last_eval_batches=4, 58 | output_type="curves", 59 | threads=0) 60 | 61 | self.assertLen(results, 1) 62 | results, = results 63 | 64 | self.assertEqual(results["num_steps"], 10) 65 | self.assertEqual(results["eval_batches"], 2) 66 | self.assertEqual(results["last_eval_batches"], 4) 67 | testing.assert_almost_equal(results["train/xs"], 68 | onp.arange(11, dtype=onp.int32)) 69 | testing.assert_almost_equal(results["eval/xs"], onp.asarray([0, 5, 10])) 70 | 71 | def test_inner_train_task_from_gin(self): 72 | with tempfile.TemporaryDirectory() as tmpdir: 73 | os.environ["LOPT_BASELINE_DIR"] = tmpdir 74 | configs = [ 75 | "inner_train_task.task=@make_quad()", "inner_train_task.opt=@SGD()" 76 | ] 77 | gin.parse_config(configs) 78 | 79 | run_trainer.inner_train_task( 80 | num_steps=10, eval_every=5, eval_batches=2, last_eval_batches=4) 81 | 82 | results = utils.load_baseline_results( 83 | "make_quad", 84 | "SGD_lr0.01", 85 | num_steps=10, 86 | eval_every=5, 87 | eval_batches=2, 88 | last_eval_batches=4, 89 | output_type="curves", 90 | threads=0) 91 | 92 | self.assertLen(results, 1) 93 | 94 | 95 | if __name__ == "__main__": 96 | absltest.main() 97 | -------------------------------------------------------------------------------- /learned_optimization/circular_buffer_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for learned_optimization.circular_buffer.""" 17 | 18 | from absl.testing import absltest 19 | import haiku as hk 20 | import jax 21 | import jax.numpy as jnp 22 | from learned_optimization import circular_buffer 23 | import numpy as onp 24 | 25 | freeze = hk.data_structures.to_immutable_dict 26 | 27 | 28 | class CircularBufferTest(absltest.TestCase): 29 | 30 | def test_circular_buffer(self): 31 | s = freeze({"a": jnp.ones([1, 3])}) 32 | av = jax.tree_util.tree_map(lambda x: x.aval, s) 33 | 34 | buffer = circular_buffer.CircularBuffer(av, 5) 35 | 36 | state = buffer.init() 37 | 38 | state = buffer.add(state, s) 39 | 40 | vs, m = buffer.stack_reorder(state) 41 | assert onp.sum(m) == 1 42 | onp.testing.assert_allclose(m, onp.asarray([0, 0, 0, 0, 1])) 43 | onp.testing.assert_allclose(vs["a"][-1], onp.ones([1, 3])) 44 | onp.testing.assert_allclose(vs["a"][0], onp.zeros([1, 3])) 45 | 46 | vs, idx = buffer.stack_with_idx(state) 47 | onp.testing.assert_allclose(idx, [4, -1, -1, -1, -1]) 48 | onp.testing.assert_allclose(vs["a"][0], onp.ones([1, 3])) 49 | 50 | for _ in range(4): 51 | state = buffer.add(state, s) 52 | 53 | s = freeze({"a": jnp.ones([1, 3]) * 2}) 54 | state = buffer.add(state, s) 55 | 56 | vs, m = buffer.stack_reorder(state) 57 | assert onp.sum(m) == 5 58 | onp.testing.assert_allclose(m, onp.asarray([1, 1, 1, 1, 1])) 59 | onp.testing.assert_allclose(vs["a"][-1], 2 * onp.ones([1, 3])) 60 | onp.testing.assert_allclose(vs["a"][0], onp.ones([1, 3])) 61 | 62 | vs, idx = buffer.stack_with_idx(state) 63 | onp.testing.assert_allclose(idx, [4, 0, 1, 2, 3]) 64 | onp.testing.assert_allclose(vs["a"][0], 2 * onp.ones([1, 3])) 65 | onp.testing.assert_allclose(vs["a"][1], onp.ones([1, 3])) 66 | 67 | def test_gather(self): 68 | s = freeze({"a": jnp.ones([3])}) 69 | av = jax.tree_util.tree_map(lambda x: x.aval, s) 70 | 71 | buffer = circular_buffer.CircularBuffer(av, 5) 72 | 73 | state = buffer.init() 74 | 75 | for i in range(12): 76 | s = freeze({"a": jnp.ones([3]) * i}) 77 | state = buffer.add(state, s) 78 | 79 | v = buffer.gather_from_present(state, jnp.asarray([4, 3, 0, 0])) 80 | assert v["a"][0][0] == 11.0 81 | assert v["a"][1][0] == 10.0 82 | assert v["a"][2][1] == 7.0 83 | 84 | s = freeze({"a": jnp.ones([3]) * (i + 1)}) 85 | state = buffer.add(state, s) 86 | v = buffer.gather_from_present(state, jnp.asarray([4, 3, 0, 0])) 87 | assert v["a"][0][0] == 12.0 88 | assert v["a"][1][0] == 11.0 89 | assert v["a"][2][1] == 8.0 90 | 91 | 92 | if __name__ == "__main__": 93 | absltest.main() 94 | -------------------------------------------------------------------------------- /learned_optimization/continuous_eval/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Continuous evaluation module.""" 17 | -------------------------------------------------------------------------------- /learned_optimization/continuous_eval/task_group_server_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.continuous_eval.task_group_server.""" 17 | 18 | import os 19 | import shutil 20 | import tempfile 21 | 22 | from absl.testing import absltest 23 | from learned_optimization.continuous_eval import task_group_server 24 | 25 | 26 | class TaskGroupServerTest(absltest.TestCase): 27 | 28 | def test_task_group_chief(self): 29 | log_dir = os.path.join(tempfile.gettempdir(), "log_dir") 30 | 31 | if os.path.exists(log_dir): 32 | shutil.rmtree(log_dir) 33 | 34 | chief = task_group_server.TaskGroupChief( 35 | "test", log_dir, 1, start_server=False) 36 | chief.daemon = True 37 | 38 | chief.start() 39 | 40 | tasks = range(5) 41 | chief.add_task_group("HelloTask", tasks) 42 | 43 | for _ in range(5): 44 | self.assertIs(chief.get_finished_task_group(), None) 45 | work = chief.get_work(0) 46 | self.assertIsNot(work, None) 47 | chief.finish_work(0, 123) 48 | 49 | task_group, values, tasks = chief.get_finished_task_group() 50 | self.assertEqual(values, [123] * 5) 51 | self.assertEqual(task_group, "HelloTask") 52 | self.assertIs(chief.get_finished_task_group(), None) 53 | 54 | work = chief.get_work(0) 55 | self.assertIs(work, None) 56 | 57 | chief.should_stop = True 58 | chief.join() 59 | 60 | 61 | if __name__ == "__main__": 62 | absltest.main() 63 | -------------------------------------------------------------------------------- /learned_optimization/entry_points/README.md: -------------------------------------------------------------------------------- 1 | # Entry points 2 | 3 | This folder contains various runnable scripts, all configured with gin. 4 | 5 | -------------------------------------------------------------------------------- /learned_optimization/entry_points/run_check_eval_worker.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | """Local test script to check that the evaluation configurations run.""" 17 | 18 | import time 19 | from typing import Sequence 20 | 21 | from absl import app 22 | import gin 23 | import jax 24 | from learned_optimization import eval_training 25 | from learned_optimization import profile 26 | from learned_optimization import setup_experiment 27 | from learned_optimization.continuous_eval import run_eval_chief 28 | from learned_optimization.continuous_eval import run_eval_worker 29 | from learned_optimization.continuous_eval import evaluation_set # pylint: disable=unused-import 30 | 31 | 32 | def main(unused_argv: Sequence[str]) -> None: 33 | setup_experiment.setup_experiment(gin_finalize=False) 34 | 35 | unused_chief_name, unused_num_workers, lopt = run_eval_chief.eval_chief_config( 36 | ) 37 | 38 | cfg = gin.query_parameter("run_evaluation_chief.evaluation_set") 39 | eval_task_configs = cfg.configurable.wrapper() 40 | task = eval_task_configs[0] 41 | (eval_cfg, unused_eval_name) = task 42 | 43 | gin.parse_config(eval_cfg, skip_unknown=True) 44 | 45 | key = jax.random.PRNGKey(0) 46 | theta = lopt.init(key) 47 | task_family = run_eval_worker.get_task_family() 48 | 49 | with profile.Profile("inner_train"): 50 | stime = time.time() 51 | losses = eval_training.multi_task_training_curves( 52 | task_family, lopt, theta=theta) 53 | total_time = time.time() - stime 54 | result = {"total_time": total_time, "gen_id": "", "step": 123} 55 | for k, v in losses.items(): 56 | result[k] = v 57 | print(result) 58 | 59 | 60 | if __name__ == "__main__": 61 | app.run(main) 62 | -------------------------------------------------------------------------------- /learned_optimization/entry_points/run_outer_train_single.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Entry point / main script for outer_train.run_train on a single machine.""" 17 | from typing import Sequence 18 | 19 | from absl import app 20 | from learned_optimization import outer_train 21 | from learned_optimization import setup_experiment 22 | 23 | 24 | def main(argv: Sequence[str]) -> None: 25 | if len(argv) > 1: 26 | raise app.UsageError('Too many command-line arguments.') 27 | 28 | train_log_dir = setup_experiment.setup_experiment(make_dir=True) 29 | # Run both the trainer and the worker on the same machine. 
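  # (In a multi-machine setup these two roles run as separate processes
  # instead -- see run_outer_trainer.py, which calls run_train with
  # is_trainer=True / is_worker=False, and run_outer_worker.py, which does
  # the opposite.)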
30 | outer_train.run_train( 31 | train_log_dir, 32 | is_trainer=True, 33 | is_worker=True, 34 | ) 35 | 36 | 37 | if __name__ == '__main__': 38 | app.run(main) 39 | -------------------------------------------------------------------------------- /learned_optimization/entry_points/run_outer_trainer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Entry point to run just the learner process. 17 | 18 | This should be run with run_outer_worker.py. 19 | """ 20 | from typing import Sequence 21 | 22 | from absl import app 23 | from learned_optimization import outer_train 24 | from learned_optimization import setup_experiment 25 | 26 | 27 | def main(argv: Sequence[str]) -> None: 28 | if len(argv) > 1: 29 | raise app.UsageError('Too many command-line arguments.') 30 | 31 | train_log_dir = setup_experiment.setup_experiment(make_dir=True) 32 | outer_train.run_train( 33 | train_log_dir, 34 | is_trainer=True, 35 | is_worker=False, 36 | ) 37 | 38 | 39 | if __name__ == '__main__': 40 | app.run(main) 41 | -------------------------------------------------------------------------------- /learned_optimization/entry_points/run_outer_worker.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Entry point to run just the worker process. 17 | 18 | This should be run with run_outer_trainer.py. 
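A typical multi-machine setup launches one run_outer_trainer.py process and
several run_outer_worker.py processes, usually pointed at the same training
log directory; each worker picks up its integer worker id from ``FLAGS.task``
(see below).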
19 | """ 20 | from typing import Sequence 21 | 22 | from absl import app 23 | from absl import flags 24 | from learned_optimization import outer_train 25 | from learned_optimization import setup_experiment 26 | 27 | FLAGS = flags.FLAGS 28 | 29 | 30 | def main(argv: Sequence[str]) -> None: 31 | if len(argv) > 1: 32 | raise app.UsageError('Too many command-line arguments.') 33 | 34 | train_log_dir = setup_experiment.setup_experiment() 35 | outer_train.run_train( 36 | train_log_dir, is_trainer=False, is_worker=True, worker_id=FLAGS.task) 37 | 38 | 39 | if __name__ == '__main__': 40 | app.run(main) 41 | -------------------------------------------------------------------------------- /learned_optimization/entry_points/run_population_controller.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Entry point for population based training.""" 17 | 18 | from absl import app 19 | from learned_optimization import outer_train 20 | from learned_optimization import setup_experiment 21 | 22 | 23 | 24 | def main(argv): 25 | train_log_dir = setup_experiment.setup_experiment(make_dir=True) 26 | 27 | outer_train.run_population_controller(train_log_dir) 28 | 29 | 30 | if __name__ == "__main__": 31 | app.run(main) 32 | -------------------------------------------------------------------------------- /learned_optimization/entry_points/run_tboard_eval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Script that inner trains and logs metrics to eventfiles to view with tboard. 17 | 18 | This is used to debug task implementations. 
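As a minimal sketch, the task, optimizer, and number of steps are supplied via
gin bindings along these lines (the constructors on the right are illustrative
placeholders -- any gin-registered task and optimizer can be used):

  inner_train_task.task = @QuadraticTask()
  inner_train_task.opt = @SGD()
  inner_train_task.num_steps = 1000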
19 | """ 20 | from typing import Optional 21 | 22 | from absl import app 23 | import gin 24 | import jax 25 | from learned_optimization import eval_training 26 | from learned_optimization import setup_experiment 27 | from learned_optimization import summary 28 | from learned_optimization.optimizers import base as opt_base 29 | from learned_optimization.tasks import base as tasks_base 30 | import numpy as onp 31 | 32 | 33 | @gin.configurable 34 | def inner_train_task( 35 | train_log_dir, 36 | task: tasks_base.Task = gin.REQUIRED, 37 | opt: opt_base.Optimizer = gin.REQUIRED, 38 | num_steps: int = gin.REQUIRED, 39 | eval_every: int = 10, 40 | eval_batches: int = 5, 41 | last_eval_batches: int = 10, 42 | eval_task: Optional[tasks_base.Task] = None, 43 | metrics_every: Optional[int] = None, 44 | device: Optional[jax.Device] = None, 45 | ): 46 | """Train and save results of a single training run. 47 | 48 | Args: 49 | train_log_dir: location to put summaries. 50 | task: task to train 51 | opt: optimizer to apply. 52 | num_steps: number of training iterations. 53 | eval_every: how frequently to run evaluation. 54 | eval_batches: number of batches to evaluate on. 55 | last_eval_batches: number of batches to run at end of training. 56 | eval_task: The task used for evaluation. If none we use the same task as 57 | used for training. This is useful when having different train and test 58 | functions as is the case for batchnorm and dropout. 59 | device: device to train with. 60 | """ 61 | 62 | seed = onp.random.randint(0, onp.iinfo(onp.int32).max) 63 | key = jax.random.PRNGKey(seed) 64 | 65 | summary_writer = summary.MultiWriter(summary.PrintWriter(), 66 | summary.JaxboardWriter(train_log_dir)) 67 | 68 | eval_training.single_task_training_curves( 69 | task=task, 70 | opt=opt, 71 | num_steps=num_steps, 72 | key=key, 73 | eval_every=eval_every, 74 | eval_batches=eval_batches, 75 | last_eval_batches=last_eval_batches, 76 | eval_task=eval_task, 77 | device=device, 78 | metrics_every=metrics_every, 79 | summary_writer=summary_writer) 80 | 81 | 82 | def main(argv): 83 | if len(argv) > 1: 84 | raise app.UsageError("Too many command-line arguments.") 85 | train_log_dir = setup_experiment.setup_experiment(make_dir=True) 86 | 87 | inner_train_task(train_log_dir) 88 | 89 | 90 | if __name__ == "__main__": 91 | app.run(main) 92 | -------------------------------------------------------------------------------- /learned_optimization/eval_training_multi_device_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for learned_optimizers.eval_training and multiple devices.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | from learned_optimization import eval_training 21 | from learned_optimization.learned_optimizers import base as lopt_base 22 | from learned_optimization.tasks import quadratics 23 | 24 | 25 | class EvalTrainingMultiDeviceTest(absltest.TestCase): 26 | 27 | def test_multi_task_training_curves(self): 28 | num_devices = len(jax.local_devices()) 29 | if num_devices != 8: 30 | self.skipTest(f"Not enough accelerators! Expected 8, found {num_devices}." 31 | " Please run this test with exactly 8 accelerators.") 32 | 33 | key = jax.random.PRNGKey(0) 34 | task_family = quadratics.FixedDimQuadraticFamilyData(10) 35 | learned_opt = lopt_base.LearnableSGD() 36 | theta = learned_opt.init(key) 37 | 38 | res = eval_training.multi_task_training_curves( 39 | task_family, 40 | learned_opt, 41 | theta, 42 | n_tasks=16, # 2 tasks per device 43 | key=key, 44 | eval_every=2, 45 | steps_per_jit=2, 46 | n_eval_batches=3, 47 | n_eval_batches_vec=4, 48 | steps=10, 49 | with_aux_values=("l2",), 50 | ) 51 | self.assertIn("train/xs", res) 52 | self.assertIn("train/loss", res) 53 | self.assertIn("eval/train/loss", res) 54 | 55 | 56 | if __name__ == "__main__": 57 | absltest.main() 58 | -------------------------------------------------------------------------------- /learned_optimization/eval_training_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | """Tests for learned_optimizers.eval_training.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | from learned_optimization import eval_training 21 | from learned_optimization.learned_optimizers import base as lopt_base 22 | from learned_optimization.optimizers import base as opt_base 23 | from learned_optimization.tasks import quadratics 24 | 25 | 26 | class EvalTrainingTest(absltest.TestCase): 27 | 28 | def test_single_task_training_curves_no_data(self): 29 | key = jax.random.PRNGKey(0) 30 | task = quadratics.QuadraticTask(10) 31 | opt = opt_base.SGD() 32 | res = eval_training.single_task_training_curves( 33 | task, opt, 10, key, eval_every=2) 34 | self.assertIn("train/xs", res) 35 | self.assertIn("train/loss", res) 36 | self.assertEqual(res["train/loss"].shape, (11,)) 37 | 38 | def test_single_task_training_curves_with_data(self): 39 | key = jax.random.PRNGKey(0) 40 | task = quadratics.BatchQuadraticTask() 41 | opt = opt_base.SGD() 42 | res = eval_training.single_task_training_curves( 43 | task, opt, 10, key, eval_every=2) 44 | self.assertIn("train/xs", res) 45 | self.assertIn("train/loss", res) 46 | self.assertEqual(res["train/loss"].shape, (11,)) 47 | 48 | def test_multi_task_training_curves(self): 49 | key = jax.random.PRNGKey(0) 50 | task_family = quadratics.FixedDimQuadraticFamilyData(10) 51 | learned_opt = lopt_base.LearnableSGD() 52 | theta = learned_opt.init(key) 53 | 54 | res = eval_training.multi_task_training_curves( 55 | task_family, 56 | learned_opt, 57 | theta, 58 | n_tasks=5, 59 | key=key, 60 | eval_every=2, 61 | steps_per_jit=2, 62 | steps=10, 63 | with_aux_values=("l2",), 64 | ) 65 | self.assertIn("train/xs", res) 66 | self.assertIn("train/loss", res) 67 | self.assertIn("eval/train/loss", res) 68 | self.assertEqual(res["train/loss"].shape, ( 69 | 5, 70 | 10, 71 | )) 72 | 73 | # 10 / 2 and then 1 additional due to eval-ing at the end. 74 | self.assertEqual(res["eval/train/loss"].shape, ( 75 | 5, 76 | 6, 77 | )) 78 | 79 | self.assertIn("eval/train/aux/l2", res) 80 | self.assertNotIn("eval/train/aux/l1", res) 81 | 82 | 83 | if __name__ == "__main__": 84 | absltest.main() 85 | -------------------------------------------------------------------------------- /learned_optimization/examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | ### Simple LOpt Train 4 | A small example training a learned optimizer. 5 | -------------------------------------------------------------------------------- /learned_optimization/examples/simple_lopt_train_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for simple_lopt_train.""" 17 | 18 | import tempfile 19 | 20 | from absl.testing import absltest 21 | from learned_optimization.examples import simple_lopt_train 22 | from learned_optimization.tasks import quadratics 23 | 24 | 25 | class SimpleLoptTrainTest(absltest.TestCase): 26 | 27 | def test_train(self): 28 | with tempfile.TemporaryDirectory() as path: 29 | simple_lopt_train.train(path, 10, task=quadratics.QuadraticTask()) 30 | 31 | 32 | if __name__ == '__main__': 33 | absltest.main() 34 | -------------------------------------------------------------------------------- /learned_optimization/filesystem.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Filesystem interface for writing to different data sources.""" 17 | 18 | import glob as py_glob 19 | import os 20 | import shutil 21 | from typing import Sequence 22 | import tensorflow as tf 23 | 24 | 25 | 26 | def _path_on_gcp(path: str) -> bool: 27 | prefixes = ["gs://"] 28 | return any([path.startswith(p) for p in prefixes]) 29 | 30 | 31 | def file_open(path: str, mode: str): 32 | """Open a file, returning a file object.""" 33 | if _path_on_gcp(path): 34 | return tf.io.gfile.GFile(path, mode) 35 | return open(path, mode) 36 | 37 | 38 | def make_dirs(path: str): 39 | """Make directories for given path.""" 40 | if _path_on_gcp(path): 41 | return tf.io.gfile.makedirs(path) 42 | 43 | if not os.path.exists(path): 44 | return os.makedirs(path) 45 | 46 | 47 | def copy(path: str, target: str): 48 | """Copy path to target.""" 49 | if _path_on_gcp(path): 50 | return tf.io.gfile.copy(path, target, overwrite=True) 51 | 52 | return shutil.copy(path, target) 53 | 54 | 55 | def rename(path: str, target: str): 56 | """Copy path to target.""" 57 | if _path_on_gcp(path): 58 | return tf.io.gfile.rename(path, target, overwrite=True) 59 | 60 | return shutil.move(path, target) 61 | 62 | 63 | def exists(path: str) -> bool: 64 | """Check if a file exists.""" 65 | if _path_on_gcp(path): 66 | return tf.io.gfile.exists(path) 67 | 68 | return os.path.exists(path) 69 | 70 | 71 | def glob(pattern: str) -> Sequence[str]: 72 | """Glob the filesystem with given pattern.""" 73 | if _path_on_gcp(pattern): 74 | return tf.io.gfile.glob(pattern) 75 | 76 | return py_glob.glob(pattern) 77 | 78 | 79 | def remove(path: str) -> bool: 80 | """Remove a file.""" 81 | if _path_on_gcp(path): 82 | return tf.io.gfile.remove(path) 83 | return shutil.rmtree(path) # pytype: disable=bad-return-type 84 | -------------------------------------------------------------------------------- /learned_optimization/jax_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the 
License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for jax_utils.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | from learned_optimization import jax_utils 21 | 22 | 23 | class JaxUtilsTest(absltest.TestCase): 24 | 25 | def test_in_jit(self): 26 | state = [] 27 | 28 | def fn(x): 29 | if jax_utils.in_jit(): 30 | state.append(None) 31 | return x * 2 32 | 33 | _ = fn(1) 34 | self.assertEmpty(state) 35 | 36 | _ = jax.jit(fn)(1) 37 | self.assertLen(state, 1) 38 | 39 | 40 | if __name__ == '__main__': 41 | absltest.main() 42 | -------------------------------------------------------------------------------- /learned_optimization/launchers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """learned_optimization.launchers""" 17 | -------------------------------------------------------------------------------- /learned_optimization/learned_optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """learned_optimizers.learned_optimizers module.""" 17 | -------------------------------------------------------------------------------- /learned_optimization/learned_optimizers/adafac_mlp_lopt_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
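# A hedged sketch of one way jax_utils.in_jit (exercised in the test above) might
# be used: gate host-side side effects so they only run eagerly, not while the
# function is being traced under jit. The logging use case is an illustrative
# assumption, not something prescribed by the library.
import jax
from learned_optimization import jax_utils


def loss_with_logging(x):
  loss = (x - 1.0) ** 2
  if not jax_utils.in_jit():
    print("eager call, loss =", loss)  # skipped during jit tracing
  return loss


loss_with_logging(3.0)           # runs eagerly and prints
jax.jit(loss_with_logging)(3.0)  # traced/compiled; the print never fires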
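# A small usage sketch (assumed, not from the library) for the filesystem module
# shown earlier: the same helpers work for a local path or a GCS path, with the
# "gs://" branch dispatched to tf.io.gfile. The bucket name below is hypothetical.
from learned_optimization import filesystem

log_dir = "/tmp/lopt_demo"  # or e.g. "gs://my-hypothetical-bucket/lopt_demo"
filesystem.make_dirs(log_dir)

with filesystem.file_open(log_dir + "/status.txt", "w") as f:
  f.write("meta-training started\n")

assert filesystem.exists(log_dir + "/status.txt")
print(filesystem.glob(log_dir + "/*.txt"))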
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.adafac_mlp_lopt.""" 17 | from absl.testing import absltest 18 | import jax 19 | from learned_optimization.learned_optimizers import adafac_mlp_lopt 20 | from learned_optimization.learned_optimizers import test_utils 21 | import numpy as onp 22 | 23 | 24 | class AdafacMLPLOptTest(absltest.TestCase): 25 | 26 | def test_adafac_mlp_lopt(self): 27 | test_utils.smoketest_learned_optimizer(adafac_mlp_lopt.AdafacMLPLOpt()) 28 | 29 | def test_split_weights_equal_to_full(self): 30 | lopt1 = adafac_mlp_lopt.AdafacMLPLOpt( 31 | concat_weights=True, step_mult=10, exp_mult=1) 32 | lopt2 = adafac_mlp_lopt.AdafacMLPLOpt( 33 | concat_weights=False, split_weights=True, step_mult=10, exp_mult=1) 34 | 35 | key = jax.random.PRNGKey(0) 36 | theta = lopt1.init(key) 37 | opt1 = lopt1.opt_fn(theta) 38 | opt2 = lopt2.opt_fn(theta) 39 | 40 | p = (jax.random.normal(key, [2, 2]),) 41 | g = (jax.random.normal(key, [2, 2]),) 42 | 43 | opt_state = opt1.init(p, num_steps=10) 44 | print(opt_state) 45 | 46 | opt_state1 = opt1.update(opt_state, g, 1.0) 47 | opt_state2 = opt2.update(opt_state, g, 1.0) 48 | 49 | diff = opt_state1.params[0] - opt_state2.params[0] 50 | self.assertLess(onp.mean(onp.abs(diff)), 1e-6) 51 | 52 | 53 | if __name__ == '__main__': 54 | absltest.main() 55 | -------------------------------------------------------------------------------- /learned_optimization/learned_optimizers/adafac_nominal_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.adafac_nominal.""" 17 | from absl.testing import absltest 18 | from learned_optimization.learned_optimizers import adafac_nominal 19 | from learned_optimization.learned_optimizers import test_utils 20 | 21 | 22 | class MLPNomLOptTest(absltest.TestCase): 23 | 24 | def test_adafac_mlp_lopt(self): 25 | test_utils.smoketest_learned_optimizer(adafac_nominal.MLPNomLOpt()) 26 | 27 | 28 | if __name__ == '__main__': 29 | absltest.main() 30 | -------------------------------------------------------------------------------- /learned_optimization/learned_optimizers/base_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
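# A condensed sketch of the learned-optimizer interface the test above exercises:
# meta-parameters theta come from init, opt_fn(theta) yields a concrete Optimizer,
# and update consumes (state, grads, loss). The toy parameter and gradient pytrees
# are illustrative only.
import jax
from learned_optimization.learned_optimizers import adafac_mlp_lopt

key = jax.random.PRNGKey(0)
lopt = adafac_mlp_lopt.AdafacMLPLOpt()
theta = lopt.init(key)      # meta-parameters of the learned optimizer
opt = lopt.opt_fn(theta)    # an Optimizer with theta baked in

params = (jax.random.normal(key, [2, 2]),)
grads = (jax.random.normal(key, [2, 2]),)

opt_state = opt.init(params, num_steps=10)
opt_state = opt.update(opt_state, grads, 1.0)  # 1.0 stands in for the loss
new_params = opt.get_params(opt_state)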
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.learned_optimizers.base.""" 17 | from absl.testing import absltest 18 | from learned_optimization.learned_optimizers import base 19 | from learned_optimization.learned_optimizers import test_utils 20 | 21 | 22 | class BaseTest(absltest.TestCase): 23 | 24 | def test_learnable_adam(self): 25 | test_utils.smoketest_learned_optimizer(base.LearnableAdam()) 26 | 27 | def test_learnable_sgd(self): 28 | test_utils.smoketest_learned_optimizer(base.LearnableSGD()) 29 | 30 | def test_learnable_sgdm(self): 31 | test_utils.smoketest_learned_optimizer(base.LearnableSGDM()) 32 | 33 | 34 | if __name__ == '__main__': 35 | absltest.main() 36 | -------------------------------------------------------------------------------- /learned_optimization/learned_optimizers/mlp_lopt_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.mlp_lopt.""" 17 | from absl.testing import absltest 18 | from learned_optimization.learned_optimizers import mlp_lopt 19 | from learned_optimization.learned_optimizers import test_utils 20 | 21 | 22 | class MLPLOptTest(absltest.TestCase): 23 | 24 | def test_mlp_lopt(self): 25 | test_utils.smoketest_learned_optimizer(mlp_lopt.MLPLOpt()) 26 | 27 | 28 | if __name__ == '__main__': 29 | absltest.main() 30 | -------------------------------------------------------------------------------- /learned_optimization/learned_optimizers/nn_adam_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for learned_optimizers.nn_adam.""" 17 | from absl.testing import absltest 18 | from learned_optimization.learned_optimizers import nn_adam 19 | from learned_optimization.learned_optimizers import test_utils 20 | 21 | 22 | class NNAdamTest(absltest.TestCase): 23 | 24 | def test_nn_adam(self): 25 | test_utils.smoketest_learned_optimizer(nn_adam.NNAdam()) 26 | 27 | 28 | if __name__ == '__main__': 29 | absltest.main() 30 | -------------------------------------------------------------------------------- /learned_optimization/learned_optimizers/rnn_mlp_lopt_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.rnn_mlp_lopt.""" 17 | from absl.testing import absltest 18 | from learned_optimization.learned_optimizers import rnn_mlp_lopt 19 | from learned_optimization.learned_optimizers import test_utils 20 | 21 | 22 | class RNNMLPLOptTest(absltest.TestCase): 23 | 24 | def test_rnn_mlp_lopt(self): 25 | test_utils.smoketest_learned_optimizer(rnn_mlp_lopt.RNNMLPLOpt()) 26 | 27 | 28 | if __name__ == '__main__': 29 | absltest.main() 30 | -------------------------------------------------------------------------------- /learned_optimization/learned_optimizers/test_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Test utilities for testing learned optimizers.""" 17 | import jax 18 | from learned_optimization.optimizers import test_utils as opt_test_utils 19 | 20 | 21 | def smoketest_learned_optimizer(learned_optimizer, strict_types=True): 22 | key = jax.random.PRNGKey(0) 23 | 24 | meta_param = learned_optimizer.init(key) 25 | optimizer = learned_optimizer.opt_fn(meta_param) 26 | 27 | opt_test_utils.smoketest_optimizer(optimizer, strict_types=strict_types) 28 | -------------------------------------------------------------------------------- /learned_optimization/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """learned_optimizers.optimizers module.""" 17 | from learned_optimization.optimizers.base import * 18 | from learned_optimization.optimizers.optax_opts import * 19 | -------------------------------------------------------------------------------- /learned_optimization/optimizers/alias.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Adam configurations, which are named for easy usage from gin.""" 17 | import gin 18 | from learned_optimization.optimizers import learning_rate_schedules 19 | from learned_optimization.optimizers import optax_opts 20 | 21 | 22 | @gin.configurable() 23 | def AdamSqrtDecayLR(learning_rate): 24 | lr_decay = learning_rate_schedules.LinearRampupSqrtDecay( 25 | learning_rate, warmup_steps=1000) 26 | return optax_opts.Adam(lr_decay) 27 | 28 | 29 | @gin.configurable() 30 | def AdamExpDecayLR(learning_rate): 31 | # 7e-5 results in roughly a 3 order of magnitude reduction over 100_000 steps. 32 | lr_decay = learning_rate_schedules.ExponentialDecay(learning_rate, 7e-5) 33 | return optax_opts.Adam(lr_decay) 34 | -------------------------------------------------------------------------------- /learned_optimization/optimizers/base_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
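# A brief sketch of how these gin-configurable aliases are typically reached:
# either constructed directly, or selected by name through a gin binding. The
# binding string below is an illustrative assumption about one's own config,
# not a config shipped with the library.
import gin
from learned_optimization.optimizers import alias

# Direct construction: Adam with an exponentially decaying learning rate.
opt = alias.AdamExpDecayLR(1e-3)

# The same optimizer selected via a gin binding.
gin.parse_config("AdamExpDecayLR.learning_rate = 1e-3")
opt_from_gin = alias.AdamExpDecayLR()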
15 | 16 | """Tests for learned_optimizers.optimizers.base.""" 17 | from absl.testing import absltest 18 | from learned_optimization.optimizers import base 19 | from learned_optimization.optimizers import optax_opts 20 | from learned_optimization.optimizers import test_utils 21 | 22 | 23 | class BaseTest(absltest.TestCase): 24 | 25 | def test_gradient_clip_adam(self): 26 | test_utils.smoketest_optimizer( 27 | base.GradientClipOptimizer(optax_opts.Adam(1e-4))) 28 | 29 | def test_grafted_optimizer(self): 30 | test_utils.smoketest_optimizer( 31 | base.GraftedOptimizer(optax_opts.SGDM(1e-2), optax_opts.Adam(1e-4))) 32 | 33 | if __name__ == '__main__': 34 | absltest.main() 35 | -------------------------------------------------------------------------------- /learned_optimization/optimizers/gradient_accumulator_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for gradient_accumulator.""" 17 | 18 | from absl.testing import absltest 19 | from learned_optimization.optimizers import gradient_accumulator 20 | from learned_optimization.optimizers import optax_opts 21 | from learned_optimization.optimizers import test_utils 22 | 23 | 24 | class GradientAccumulatorTest(absltest.TestCase): 25 | 26 | def test_gradient_accumulator(self): 27 | opt = optax_opts.Adam() 28 | opt = gradient_accumulator.GradientAccumulator(opt, 5) 29 | test_utils.smoketest_optimizer(opt) 30 | 31 | 32 | if __name__ == '__main__': 33 | absltest.main() 34 | -------------------------------------------------------------------------------- /learned_optimization/optimizers/nadamw_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
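# The wrappers tested above all return ordinary Optimizers, so they compose; a
# hedged sketch of stacking them (clipped Adam updates applied every 5 steps).
# The particular stacking order is an illustrative choice, not a recommendation
# from the library.
from learned_optimization.optimizers import base
from learned_optimization.optimizers import gradient_accumulator
from learned_optimization.optimizers import optax_opts

opt = optax_opts.Adam(1e-4)
opt = base.GradientClipOptimizer(opt)                    # gradient clipping
opt = gradient_accumulator.GradientAccumulator(opt, 5)   # accumulate 5 steps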
15 | 16 | """Tests for learned_optimizers.optimizers.nadamw.""" 17 | from absl.testing import absltest 18 | from learned_optimization.optimizers import learning_rate_schedules 19 | from learned_optimization.optimizers import nadamw 20 | from learned_optimization.optimizers import test_utils 21 | 22 | 23 | class NAdamWTest(absltest.TestCase): 24 | 25 | def test_nadamw_nesterov(self): 26 | test_utils.smoketest_optimizer(nadamw.NAdamW(use_nesterov=True)) 27 | 28 | def test_nadamw_weight_decays(self): 29 | test_utils.smoketest_optimizer( 30 | nadamw.NAdamW(adamw_weight_decay=1.0, l2_weight_decay=1.0)) 31 | 32 | def test_nadamw_with_schedule(self): 33 | sched = learning_rate_schedules.CosineLearningRateSchedule( 34 | 1e-3, 0.1, 0.3, 0.1) 35 | test_utils.smoketest_optimizer(nadamw.NAdamW(learning_rate=sched)) 36 | 37 | 38 | if __name__ == '__main__': 39 | absltest.main() 40 | -------------------------------------------------------------------------------- /learned_optimization/optimizers/opt_list_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.optimizers.opt_list.""" 17 | from absl.testing import absltest 18 | import jax 19 | import jax.numpy as jnp 20 | from learned_optimization.optimizers import opt_list 21 | from learned_optimization.optimizers import test_utils 22 | import numpy as onp 23 | 24 | 25 | class OptListTest(absltest.TestCase): 26 | 27 | def test_opt_list_nesterov(self): 28 | test_utils.smoketest_optimizer(opt_list.OptList(0)) 29 | 30 | def test_opt_list_vmap_idx(self): 31 | 32 | def init_update(idx): 33 | opt = opt_list.OptList(idx) 34 | p = jnp.ones(10) 35 | opt_state = opt.init(p, num_steps=10) 36 | return opt.get_params(opt.update(opt_state, p)) 37 | 38 | result = jax.vmap(init_update)(jnp.arange(10, dtype=jnp.int32)) 39 | # ensure we computed different updates. 40 | self.assertGreater(onp.sum(onp.abs(result[0, :] - result[1, :])), 0.0001) 41 | 42 | 43 | if __name__ == '__main__': 44 | absltest.main() 45 | -------------------------------------------------------------------------------- /learned_optimization/optimizers/opt_to_optax.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
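# OptList(idx) selects one entry from a pre-tuned list of hyperparameter
# configurations (the vmap test above relies on different indices producing
# different updates). A hedged sketch of trying the first few entries in order;
# the sweep size of 3 is arbitrary.
import jax.numpy as jnp
from learned_optimization.optimizers import opt_list

params = jnp.ones(10)
for idx in range(3):
  opt = opt_list.OptList(idx)
  state = opt.init(params, num_steps=10)
  state = opt.update(state, params)  # params doubles as a stand-in gradient
  print(idx, opt.get_params(state)[:2])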
15 | 16 | """Convert a learned_optimization Optimizer to an optax update.""" 17 | 18 | import dataclasses 19 | from typing import Any, Mapping, Optional, Tuple, NamedTuple 20 | 21 | import chex 22 | from learned_optimization import tree_utils 23 | from learned_optimization.optimizers import base 24 | import optax 25 | 26 | 27 | def opt_to_optax_opt( 28 | opt: base.Optimizer, num_steps: Optional[int] = None 29 | ) -> optax.GradientTransformationExtraArgs: 30 | """Convert a learned_optimization optimizer to an optax optimizers. 31 | 32 | This makes use of optax's 'extra_args' argument. For many learned optimizers 33 | this will contain loss values. 34 | 35 | Args: 36 | opt: the learned_optimization optimizer to convert. 37 | num_steps: The number of steps the optimizer is expected to run for. 38 | 39 | Returns: 40 | init_and_update: the optax optimizer. 41 | """ 42 | 43 | def init_fn(params: chex.ArrayTree, 44 | *, 45 | extra_args: Optional[Mapping[str, Any]] = None) -> chex.ArrayTree: 46 | del extra_args 47 | opt_state = opt.init(params, num_steps=num_steps) 48 | if dataclasses.is_dataclass(opt_state): 49 | return opt.set_params(opt_state, ()) 50 | else: 51 | raise NotImplementedError("Only flax dataclasses are supported!") 52 | 53 | def update_fn( 54 | updates: chex.ArrayTree, 55 | state: chex.ArrayTree, 56 | params: Optional[chex.ArrayTree] = None, 57 | *, 58 | extra_args: Optional[Mapping[str, Any]] = None 59 | ) -> Tuple[chex.ArrayTree, chex.ArrayTree]: 60 | if extra_args is None: 61 | extra_args = {} 62 | 63 | if params is None: 64 | raise ValueError("Params must not be None!") 65 | 66 | if dataclasses.is_dataclass(state): 67 | state = opt.set_params(state, params) 68 | else: 69 | raise NotImplementedError("Only flax dataclasses are supported!") 70 | 71 | next_state = opt.update(state, updates, **extra_args) 72 | 73 | step = tree_utils.tree_sub(opt.get_params(next_state), params) 74 | 75 | next_state = opt.set_params(next_state, ()) 76 | 77 | return step, next_state 78 | 79 | return optax.GradientTransformationExtraArgs(init_fn, update_fn) 80 | -------------------------------------------------------------------------------- /learned_optimization/optimizers/opt_to_optax_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
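# A hedged sketch of dropping the converted transformation into an ordinary
# optax-style, jitted train step. The quadratic loss and the loop are
# illustrative; passing the loss through extra_args assumes the wrapped
# optimizer accepts a `loss` keyword, which learned optimizers in this library
# generally expect and plain Adam simply ignores.
import jax
import jax.numpy as jnp
import optax
from learned_optimization.optimizers import opt_to_optax
from learned_optimization.optimizers import optax_opts

tx = opt_to_optax.opt_to_optax_opt(optax_opts.Adam(1e-2), num_steps=100)


def loss_fn(p):
  return jnp.sum((p - 1.0) ** 2)


@jax.jit
def train_step(params, opt_state):
  loss, grads = jax.value_and_grad(loss_fn)(params)
  updates, opt_state = tx.update(
      grads, opt_state, params=params, extra_args={"loss": loss})
  return optax.apply_updates(params, updates), opt_state


params = jnp.zeros(3)
opt_state = tx.init(params)
for _ in range(5):
  params, opt_state = train_step(params, opt_state)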
15 | 16 | """Tests for opt_to_optax.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | from learned_optimization.learned_optimizers import rnn_mlp_lopt 21 | from learned_optimization.optimizers import opt_to_optax 22 | from learned_optimization.optimizers import optax_opts 23 | from learned_optimization.tasks import quadratics 24 | import optax 25 | 26 | 27 | class OptToOptaxTest(absltest.TestCase): 28 | 29 | def test_adam_opt_to_optax_opt(self): 30 | opt = optax_opts.Adam(1e-4) 31 | task = quadratics.QuadraticTask() 32 | key = jax.random.PRNGKey(0) 33 | p = task.init(key) 34 | 35 | optax_opt = opt_to_optax.opt_to_optax_opt(opt) 36 | 37 | opt_state = optax_opt.init(p) 38 | updates = p 39 | step, opt_state = optax_opt.update(updates, opt_state, params=p) 40 | p = optax.apply_updates(p, step) 41 | 42 | step, opt_state = optax_opt.update(updates, opt_state, params=p) 43 | p = optax.apply_updates(p, step) 44 | 45 | def test_learned_opt_to_optax_opt(self): 46 | lopt = rnn_mlp_lopt.RNNMLPLOpt() 47 | key = jax.random.PRNGKey(0) 48 | 49 | lo_opt = lopt.opt_fn(lopt.init(key)) 50 | optax_opt = opt_to_optax.opt_to_optax_opt(lo_opt, num_steps=100) 51 | 52 | task = quadratics.QuadraticTask() 53 | p = task.init(key) 54 | 55 | opt_state = optax_opt.init(p) 56 | updates = p 57 | step, opt_state = optax_opt.update( 58 | updates, opt_state, params=p, extra_args={ 59 | "loss": 1.0, 60 | "key": key 61 | }) 62 | p = optax.apply_updates(p, step) 63 | 64 | 65 | if __name__ == "__main__": 66 | absltest.main() 67 | -------------------------------------------------------------------------------- /learned_optimization/optimizers/optax_opts_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for optax_opts.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from learned_optimization.optimizers import optax_opts 21 | from learned_optimization.optimizers import test_utils 22 | 23 | 24 | class OptaxOptsTest(parameterized.TestCase): 25 | 26 | @parameterized.named_parameters( 27 | ('RMSProp', lambda: optax_opts.RMSProp(1e-4)), 28 | ('SGDM', lambda: optax_opts.SGDM(1e-4)), 29 | ('SGD', lambda: optax_opts.SGD(1e-3)), 30 | ('Adam', lambda: optax_opts.Adam(1e-4)), 31 | ('AdaBelief', lambda: optax_opts.AdaBelief(1e-4)), 32 | ('AdaGrad', lambda: optax_opts.AdaGrad(1e-4)), 33 | ('Adafactor', lambda: optax_opts.Adafactor(1e-4)), 34 | ('Yogi', lambda: optax_opts.Yogi(1e-4)), 35 | ('RAdam', lambda: optax_opts.RAdam(1e-4)), 36 | ('AdamW', lambda: optax_opts.AdamW(1e-4)), 37 | ('Lamb', lambda: optax_opts.Lamb(1e-4)), 38 | ('Lars', lambda: optax_opts.Lars(1e-4)), 39 | ('Fromage', lambda: optax_opts.Fromage(1e-4)), 40 | ('PiecewiseLinearAdam', optax_opts.PiecewiseLinearAdam)) 41 | def test_opt(self, opt_fn): 42 | test_utils.smoketest_optimizer(opt_fn()) 43 | 44 | def test_sm3(self): 45 | # SM3 seems to switch the dtype of the state variables. 46 | test_utils.smoketest_optimizer(optax_opts.SM3(1e-4), strict_types=False) 47 | 48 | 49 | if __name__ == '__main__': 50 | absltest.main() 51 | -------------------------------------------------------------------------------- /learned_optimization/optimizers/shampoo_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.optimizers.opt_list.""" 17 | from absl.testing import absltest 18 | from learned_optimization.optimizers import shampoo 19 | from learned_optimization.optimizers import test_utils 20 | 21 | 22 | class ShampooTest(absltest.TestCase): 23 | 24 | def test_shampoo(self): 25 | test_utils.smoketest_optimizer( 26 | shampoo.Shampoo(1e-4, 32), strict_types=False) 27 | 28 | 29 | if __name__ == '__main__': 30 | absltest.main() 31 | -------------------------------------------------------------------------------- /learned_optimization/outer_trainers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """learned_optimizers.outer_trainers module.""" 17 | -------------------------------------------------------------------------------- /learned_optimization/outer_trainers/full_grad_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.outer_trainers.full_grad.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from learned_optimization.learned_optimizers import base 21 | from learned_optimization.outer_trainers import full_grad 22 | from learned_optimization.outer_trainers import lopt_truncated_step 23 | from learned_optimization.outer_trainers import test_utils 24 | from learned_optimization.outer_trainers import truncation_schedule 25 | from learned_optimization.tasks import quadratics 26 | 27 | 28 | class FullEsTest(parameterized.TestCase): 29 | 30 | def test_full_es_trainer(self): 31 | learned_opt = base.LearnableSGD() 32 | task_family = quadratics.FixedDimQuadraticFamily(10) 33 | trunc_sched = truncation_schedule.NeverEndingTruncationSchedule() 34 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep( 35 | task_family, learned_opt, trunc_sched, num_tasks=5) 36 | 37 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10) 38 | trainer = full_grad.FullGrad( 39 | truncated_step, truncation_schedule=trunc_sched, loss_type="avg") 40 | 41 | test_utils.trainer_smoketest(trainer) 42 | 43 | def test_full_grad_trainer_with_data(self): 44 | learned_opt = base.LearnableSGD() 45 | task_family = quadratics.FixedDimQuadraticFamilyData(10) 46 | trunc_sched = truncation_schedule.NeverEndingTruncationSchedule() 47 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep( 48 | task_family, learned_opt, trunc_sched, num_tasks=5) 49 | 50 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10) 51 | trainer = full_grad.FullGrad( 52 | truncated_step, loss_type="last", truncation_schedule=trunc_sched) 53 | 54 | test_utils.trainer_smoketest(trainer) 55 | 56 | 57 | if __name__ == "__main__": 58 | absltest.main() 59 | -------------------------------------------------------------------------------- /learned_optimization/outer_trainers/test_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utilities to test trainers.""" 17 | import jax 18 | from learned_optimization.outer_trainers import gradient_learner 19 | 20 | 21 | def trainer_smoketest(trainer): 22 | key = jax.random.PRNGKey(0) 23 | theta = trainer.truncated_step.outer_init(key) 24 | worker_weights = gradient_learner.WorkerWeights( 25 | theta, None, gradient_learner.OuterState(1) # pytype: disable=wrong-arg-types # jax-ndarray 26 | ) 27 | state = trainer.init_worker_state(worker_weights, key=key) 28 | out, metrics = trainer.compute_gradient_estimate( 29 | worker_weights, key, state, with_summary=True) 30 | del out 31 | del metrics 32 | -------------------------------------------------------------------------------- /learned_optimization/outer_trainers/truncated_es_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.outer_trainers.truncated_es.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from learned_optimization.learned_optimizers import base 21 | from learned_optimization.outer_trainers import lopt_truncated_step 22 | from learned_optimization.outer_trainers import test_utils 23 | from learned_optimization.outer_trainers import truncated_es 24 | from learned_optimization.outer_trainers import truncation_schedule 25 | from learned_optimization.tasks import quadratics 26 | 27 | 28 | class TruncatedESTest(parameterized.TestCase): 29 | 30 | def test_truncated_es_trainer(self): 31 | learned_opt = base.LearnableSGD() 32 | task_family = quadratics.FixedDimQuadraticFamily(10) 33 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10) 34 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep( 35 | task_family, learned_opt, trunc_sched, num_tasks=5) 36 | 37 | trainer = truncated_es.TruncatedES( 38 | truncated_step, steps_per_jit=3, unroll_length=9) 39 | test_utils.trainer_smoketest(trainer) 40 | 41 | @parameterized.product(meta_loss_split=(None, "train", "same_data")) 42 | def test_truncated_es_trainer_with_data(self, meta_loss_split): 43 | learned_opt = base.LearnableSGD() 44 | task_family = quadratics.FixedDimQuadraticFamilyData(10) 45 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10) 46 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep( 47 | task_family, 48 | learned_opt, 49 | trunc_sched, 50 | num_tasks=4, 51 | meta_loss_split=meta_loss_split) 52 | 53 | trainer = truncated_es.TruncatedES( 54 | truncated_step, steps_per_jit=5, unroll_length=5) 55 | 56 | test_utils.trainer_smoketest(trainer) 57 | 58 | if __name__ == "__main__": 59 | absltest.main() 60 | -------------------------------------------------------------------------------- 
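# A hedged sketch extending trainer_smoketest above into a tiny meta-training
# loop. The field names on the returned estimate (.grad, .mean_loss,
# .unroll_state) are assumptions about gradient_learner's output container, and
# theta is updated with plain SGD via tree_map rather than the library's full
# GradientLearner machinery.
import jax
from learned_optimization.learned_optimizers import base
from learned_optimization.outer_trainers import gradient_learner
from learned_optimization.outer_trainers import lopt_truncated_step
from learned_optimization.outer_trainers import truncated_es
from learned_optimization.outer_trainers import truncation_schedule
from learned_optimization.tasks import quadratics

truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
    quadratics.FixedDimQuadraticFamily(10), base.LearnableSGD(),
    truncation_schedule.ConstantTruncationSchedule(10), num_tasks=5)
trainer = truncated_es.TruncatedES(
    truncated_step, steps_per_jit=5, unroll_length=5)

key = jax.random.PRNGKey(0)
theta = trainer.truncated_step.outer_init(key)
worker_weights = gradient_learner.WorkerWeights(
    theta, None, gradient_learner.OuterState(0))
state = trainer.init_worker_state(worker_weights, key=key)

for i in range(5):
  key, sub = jax.random.split(key)
  out, metrics = trainer.compute_gradient_estimate(
      worker_weights, sub, state, with_summary=True)
  state = out.unroll_state  # carry the truncated unroll across estimates
  theta = jax.tree_util.tree_map(lambda t, g: t - 0.01 * g, theta, out.grad)
  worker_weights = gradient_learner.WorkerWeights(
      theta, None, gradient_learner.OuterState(i + 1))
  print(i, out.mean_loss)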
/learned_optimization/outer_trainers/truncated_grad_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.outer_trainers.truncated_grad.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from learned_optimization.learned_optimizers import base 21 | from learned_optimization.outer_trainers import lopt_truncated_step 22 | from learned_optimization.outer_trainers import test_utils 23 | from learned_optimization.outer_trainers import truncated_grad 24 | from learned_optimization.outer_trainers import truncation_schedule 25 | from learned_optimization.tasks import quadratics 26 | 27 | 28 | class TruncatedGradTest(parameterized.TestCase): 29 | 30 | def test_truncated_grad_trainer(self): 31 | learned_opt = base.LearnableSGD() 32 | task_family = quadratics.FixedDimQuadraticFamily(10) 33 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10) 34 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep( 35 | task_family, learned_opt, trunc_sched, num_tasks=5) 36 | 37 | trainer = truncated_grad.TruncatedGrad( 38 | truncated_step, steps_per_jit=3, unroll_length=9) 39 | test_utils.trainer_smoketest(trainer) 40 | 41 | @parameterized.product(meta_loss_split=(None, "train")) 42 | def test_truncated_grad_trainer_with_data(self, meta_loss_split): 43 | learned_opt = base.LearnableSGD() 44 | task_family = quadratics.FixedDimQuadraticFamilyData(10) 45 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10) 46 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep( 47 | task_family, 48 | learned_opt, 49 | trunc_sched, 50 | num_tasks=5, 51 | meta_loss_split=meta_loss_split) 52 | 53 | trainer = truncated_grad.TruncatedGrad( 54 | truncated_step, steps_per_jit=5, unroll_length=5) 55 | 56 | test_utils.trainer_smoketest(trainer) 57 | 58 | if __name__ == "__main__": 59 | absltest.main() 60 | -------------------------------------------------------------------------------- /learned_optimization/outer_trainers/truncated_pes_pmap_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for learned_optimizers.outer_trainers.truncated_pes.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import jax 21 | from learned_optimization.learned_optimizers import base 22 | from learned_optimization.outer_trainers import lopt_truncated_step 23 | from learned_optimization.outer_trainers import test_utils 24 | from learned_optimization.outer_trainers import truncated_pes 25 | from learned_optimization.outer_trainers import truncation_schedule 26 | from learned_optimization.tasks import quadratics 27 | 28 | 29 | class TruncatedPesTest(parameterized.TestCase): 30 | 31 | @parameterized.product(meta_loss_split=(None, "train")) 32 | def test_truncated_pes_pmap_trainer_with_data(self, meta_loss_split): 33 | learned_opt = base.LearnableSGD() 34 | task_family = quadratics.FixedDimQuadraticFamilyData(10) 35 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10) 36 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep( 37 | task_family, 38 | learned_opt, 39 | trunc_sched, 40 | meta_loss_split=meta_loss_split, 41 | num_tasks=8) 42 | 43 | num_devices = len(jax.local_devices()) 44 | if num_devices != 8: 45 | self.skipTest(f"Not enough accelerators! Expected 8, found {num_devices}." 46 | " Please fun this test with exactly 8 accelerators.") 47 | 48 | trainer = truncated_pes.TruncatedPESPMAP( 49 | truncated_step=truncated_step, num_devices=num_devices, steps_per_jit=5) 50 | 51 | test_utils.trainer_smoketest(trainer) 52 | 53 | 54 | if __name__ == "__main__": 55 | absltest.main() 56 | -------------------------------------------------------------------------------- /learned_optimization/outer_trainers/truncated_pes_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for learned_optimizers.outer_trainers.truncated_pes.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from learned_optimization.learned_optimizers import base 21 | from learned_optimization.outer_trainers import lopt_truncated_step 22 | from learned_optimization.outer_trainers import test_utils 23 | from learned_optimization.outer_trainers import truncated_pes 24 | from learned_optimization.outer_trainers import truncation_schedule 25 | from learned_optimization.tasks import quadratics 26 | 27 | 28 | class TruncatedPesTest(parameterized.TestCase): 29 | 30 | def test_truncated_pes_trainer(self): 31 | learned_opt = base.LearnableSGD() 32 | task_family = quadratics.FixedDimQuadraticFamily(10) 33 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10) 34 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep( 35 | task_family, learned_opt, trunc_sched, num_tasks=5) 36 | 37 | trainer = truncated_pes.TruncatedPES(truncated_step, steps_per_jit=5) 38 | test_utils.trainer_smoketest(trainer) 39 | 40 | @parameterized.product(meta_loss_split=(None, "train")) 41 | def test_truncated_pes_trainer_with_data(self, meta_loss_split): 42 | learned_opt = base.LearnableSGD() 43 | task_family = quadratics.FixedDimQuadraticFamilyData(10) 44 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10) 45 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep( 46 | task_family, 47 | learned_opt, 48 | trunc_sched, 49 | num_tasks=5, 50 | meta_loss_split=meta_loss_split) 51 | 52 | trainer = truncated_pes.TruncatedPES(truncated_step, steps_per_jit=5) 53 | 54 | test_utils.trainer_smoketest(trainer) 55 | 56 | def test_truncated_pes_stacked(self): 57 | learned_opt = base.LearnableSGD() 58 | task_family = quadratics.FixedDimQuadraticFamilyData(10) 59 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10) 60 | truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep( 61 | task_family, learned_opt, trunc_sched, num_tasks=5) 62 | trainer = truncated_pes.TruncatedPES( 63 | truncated_step, steps_per_jit=5, stack_antithetic_samples=True) 64 | 65 | test_utils.trainer_smoketest(trainer) 66 | 67 | if __name__ == "__main__": 68 | absltest.main() 69 | -------------------------------------------------------------------------------- /learned_optimization/outer_trainers/truncation_schedule_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for learned_optimizers.outer_trainers.truncation_schedule.""" 17 | 18 | from absl.testing import absltest 19 | import flax 20 | import jax 21 | from learned_optimization.optimizers import learning_rate_schedules 22 | from learned_optimization.outer_trainers import truncation_schedule 23 | 24 | 25 | @flax.struct.dataclass 26 | class OuterState: 27 | outer_iteration: int 28 | 29 | 30 | class TruncationScheduleTest(absltest.TestCase): 31 | 32 | def test_constant_truncation_sched(self): 33 | key = jax.random.PRNGKey(0) 34 | trunc_fns = truncation_schedule.ConstantTruncationSchedule(10) 35 | state = trunc_fns.init(key, outer_state=None) 36 | rng = key 37 | for i in range(10): 38 | rng, key = jax.random.split(rng) 39 | state, is_done = trunc_fns.next_state(state, i, key, outer_state=None) 40 | del is_done 41 | 42 | def test_LogUniformLengthSchedule(self): 43 | key = jax.random.PRNGKey(0) 44 | trunc_fns = truncation_schedule.LogUniformLengthSchedule(10, 100) 45 | state = trunc_fns.init(key, outer_state=None) 46 | rng = key 47 | for i in range(10): 48 | rng, key = jax.random.split(rng) 49 | state, is_done = trunc_fns.next_state(state, i, key, outer_state=None) 50 | assert state.length >= 10 51 | assert state.length <= 100 52 | del is_done 53 | 54 | def test_ScheduledTruncationSchedule(self): 55 | sched = learning_rate_schedules.PiecewiseLinear([0, 10], [10, 100]) 56 | key = jax.random.PRNGKey(0) 57 | trunc_fns = truncation_schedule.ScheduledTruncationSchedule(sched) 58 | outer_state = OuterState(0) 59 | state = trunc_fns.init(key, outer_state=outer_state) 60 | assert state.length == 10 61 | 62 | outer_state = OuterState(5) 63 | state = trunc_fns.init(key, outer_state=outer_state) 64 | assert state.length == int(10 + 100) // 2 65 | 66 | 67 | if __name__ == '__main__': 68 | absltest.main() 69 | -------------------------------------------------------------------------------- /learned_optimization/population/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Population based training module.""" 17 | -------------------------------------------------------------------------------- /learned_optimization/population/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """learned_optimization.population.examples.""" 17 | -------------------------------------------------------------------------------- /learned_optimization/population/examples/complex_cnn/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """learned_optimization.population.examples.complex_cnn.""" 17 | -------------------------------------------------------------------------------- /learned_optimization/population/examples/simple_cnn/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """learned_optimization.population.examples.simple_cnn.""" 17 | -------------------------------------------------------------------------------- /learned_optimization/population/examples/simple_cnn/train_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for learned_optimizers.population.examples.simple_cnn.train.""" 17 | import tempfile 18 | 19 | from absl.testing import absltest 20 | from learned_optimization.population.examples.simple_cnn import train 21 | 22 | 23 | class TrainTest(absltest.TestCase): 24 | 25 | def test_train(self): 26 | with tempfile.TemporaryDirectory() as directory: 27 | train.train(directory, 100) 28 | 29 | 30 | if __name__ == '__main__': 31 | absltest.main() 32 | -------------------------------------------------------------------------------- /learned_optimization/population/examples/synthetic/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """learned_optimization.population.examples.synthetic.""" 17 | -------------------------------------------------------------------------------- /learned_optimization/population/examples/synthetic/train.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """A synthetic example of population based training in action.""" 17 | from absl import app 18 | import jax 19 | from learned_optimization.population import population as population_mod 20 | from learned_optimization.population.mutators import single_worker_explore 21 | import tqdm 22 | 23 | 24 | @jax.jit 25 | def gd_loss(parameters, meta_parameters): 26 | a, = parameters 27 | b, = meta_parameters 28 | return (b + a - 2.0)**2 + (a - 1.0)**2 + (b - 1.)**2 29 | 30 | 31 | @jax.jit 32 | def loss(parameters): 33 | return gd_loss(parameters, (2.0,)) 34 | 35 | 36 | @jax.jit 37 | def update(parameters, meta_parameters): 38 | dp = jax.grad(gd_loss)(parameters, meta_parameters) 39 | new_parameters = jax.tree_util.tree_map(lambda a, b: a - 0.01 * b, parameters, 40 | dp) 41 | return new_parameters 42 | 43 | 44 | def mutate_fn(meta_params, direction, _): 45 | if direction == "pos": 46 | return (meta_params[0] + 0.1,) 47 | else: 48 | return (meta_params[0] - 0.1,) 49 | 50 | 51 | def train(steps=10000): 52 | """Train for some number of steps.""" 53 | num_workers = 1 54 | mutator = single_worker_explore.BranchingSingleMachine(mutate_fn, 50, 30) 55 | population = population_mod.PopulationController( 56 | [(1.,) for _ in range(num_workers)], mutator) 57 | 58 | params = None 59 | meta_params = None 60 | gen_id = None 61 | step = 0 62 | 63 | for _ in tqdm.trange(steps): 64 | new_data = population.maybe_get_worker_data(0, gen_id, step, params, 65 | meta_params) 66 | 67 | if new_data: 68 | params = new_data.params 69 | meta_params = new_data.meta_params 70 | gen_id = new_data.generation_id 71 | step = new_data.step 72 | 73 | # If params is none, this simulates a new initialization 74 | if params is None: 75 | params = (0.,) 76 | 77 | # only log back eval every 5 steps. 78 | for i in range(5): 79 | if i % 5 == 0: 80 | l = loss(params) 81 | population.set_eval(0, gen_id, step, params, l) 82 | print(f"\t {l}, params: {params}, meta_params:{meta_params}") 83 | 84 | params = update(params, meta_params) 85 | step += 1 86 | 87 | 88 | def main(_): 89 | train(steps=10000) 90 | 91 | 92 | if __name__ == "__main__": 93 | app.run(main) 94 | -------------------------------------------------------------------------------- /learned_optimization/population/examples/synthetic/train_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for train.""" 17 | 18 | from absl.testing import absltest 19 | from learned_optimization.population.examples.synthetic import train 20 | 21 | 22 | class TrainTest(absltest.TestCase): 23 | 24 | def test_train(self): 25 | train.train(100) 26 | 27 | 28 | if __name__ == '__main__': 29 | absltest.main() 30 | -------------------------------------------------------------------------------- /learned_optimization/population/mutators/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Mutators module.""" 17 | -------------------------------------------------------------------------------- /learned_optimization/population/mutators/fixed_schedule_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for learned_optimizers.population.fixed_schedule.""" 17 | 18 | from absl.testing import absltest 19 | from learned_optimization.population import population as pop_mod 20 | from learned_optimization.population.mutators import fixed_schedule 21 | 22 | 23 | class FixedScheduleTest(absltest.TestCase): 24 | 25 | def test_follows_schedule(self): 26 | schedule = {0: 1., 10: 2., 30: 3.0} 27 | mutate = fixed_schedule.FixedSchedule(schedule) 28 | population = pop_mod.PopulationController([1.], mutate) 29 | 30 | new_data = population.maybe_get_worker_data(0, None, 0, None, None) 31 | hparams = new_data.meta_params 32 | gen_id = new_data.generation_id 33 | 34 | for step in range(100): 35 | new_data = population.maybe_get_worker_data(0, gen_id, step, None, 36 | hparams) 37 | if new_data: 38 | hparams = new_data.meta_params 39 | gen_id = new_data.generation_id 40 | 41 | population.set_eval(0, gen_id, step, None, 1.0) 42 | 43 | if step < 11: 44 | self.assertEqual(hparams, 1.0) 45 | elif step < 31: 46 | self.assertEqual(hparams, 2.0) 47 | else: 48 | self.assertEqual(hparams, 3.0) 49 | 50 | 51 | if __name__ == '__main__': 52 | absltest.main() 53 | -------------------------------------------------------------------------------- /learned_optimization/population/mutators/single_worker_explore_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for learned_optimizers.population.single_worker_explore_test.""" 17 | 18 | from absl.testing import absltest 19 | from learned_optimization.population import population as pop_mod 20 | from learned_optimization.population.mutators import single_worker_explore 21 | import numpy as onp 22 | 23 | 24 | class SingleWorkerExploreTest(absltest.TestCase): 25 | 26 | def test_single_worker_explore(self): 27 | 28 | def mutate_fn(p, loc, _): 29 | if loc == "pos": 30 | return p + 0.05 31 | elif loc == "neg": 32 | return p - 0.05 33 | else: 34 | raise ValueError("Bad loc") 35 | 36 | mutate = single_worker_explore.BranchingSingleMachine(mutate_fn, 2, 2) 37 | population = pop_mod.PopulationController([1.], mutate) 38 | 39 | step = 0 40 | for _ in range(200): 41 | new_data = population.maybe_get_worker_data(0, None, step, None, None) 42 | 43 | if new_data: 44 | hparams = new_data.meta_params 45 | gen_id = new_data.generation_id 46 | step = new_data.step 47 | 48 | population.set_eval(0, gen_id, step + 1, None, hparams**2) 49 | 50 | self.assertLess(onp.abs(hparams), 0.1) 51 | 52 | 53 | if __name__ == "__main__": 54 | absltest.main() 55 | -------------------------------------------------------------------------------- /learned_optimization/population/mutators/winner_take_all_genetic_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.population.winner_take_all_genetic.""" 17 | 18 | from absl.testing import absltest 19 | from learned_optimization.population import population as pop_mod 20 | from learned_optimization.population.mutators import winner_take_all_genetic 21 | import numpy as onp 22 | 23 | 24 | class WinnerTakeAllGeneticTest(absltest.TestCase): 25 | 26 | def test_winner_take_all(self): 27 | onp.random.seed(0) 28 | 29 | def mutate_fn(p): 30 | p += onp.random.normal() * 0.01 31 | return p 32 | 33 | mutate = winner_take_all_genetic.WinnerTakeAllGenetic(mutate_fn, 2) 34 | population = pop_mod.PopulationController([1.], mutate) 35 | 36 | num_worker = 3 37 | population = pop_mod.PopulationController([1. 
for _ in range(num_worker)], 38 | mutate) 39 | 40 | for step in range(200): 41 | for i in range(num_worker): 42 | new_data = population.maybe_get_worker_data(i, None, step, None, None) 43 | 44 | hparams = new_data.meta_params 45 | gen_id = new_data.generation_id 46 | 47 | population.set_eval(i, gen_id, step + 1, None, hparams**2) 48 | 49 | self.assertLess(onp.abs(hparams), 0.3) 50 | 51 | 52 | if __name__ == '__main__': 53 | absltest.main() 54 | -------------------------------------------------------------------------------- /learned_optimization/profile.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A minimal profiling library to profile code in a context manager.""" 17 | import functools 18 | import sys 19 | import time 20 | from typing import TypeVar 21 | 22 | from absl import flags 23 | from absl import logging 24 | 25 | flags.DEFINE_bool("profile_log_times", False, "Log out timing information.") 26 | 27 | FLAGS = flags.FLAGS 28 | FLAGS(sys.argv, known_only=True) # Ensure flags are parsed at this time. 29 | 30 | T = TypeVar("T") 31 | 32 | 33 | class Profile: 34 | """Context manager for profiling functions. 35 | 36 | ``` 37 | with Profile("name"): 38 | run_somecode() 39 | ``` 40 | """ 41 | 42 | def __init__(self, name: str): 43 | self.name = name 44 | 45 | def __enter__(self): 46 | self._start_time = time.time() 47 | 48 | def __exit__(self, exc_type, exc_value, traceback): 49 | if FLAGS.profile_log_times: 50 | logging.info(f"{self.name} took {time.time()-self._start_time} seconds") # pylint: disable=logging-fstring-interpolation 51 | 52 | 53 | def wrap(): 54 | """Wrap a function in a Profile.""" 55 | 56 | def _wrapper(fn: T) -> T: 57 | 58 | @functools.wraps(fn) 59 | def _fn(*args, **kwargs): 60 | with Profile(fn.__name__): 61 | return fn(*args, **kwargs) 62 | 63 | return _fn 64 | 65 | return _wrapper 66 | 67 | 68 | -------------------------------------------------------------------------------- /learned_optimization/py_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
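# A minimal usage sketch for the profiling helpers in profile.py above (hedged:
# `train_step` is a hypothetical function, not part of the library). `Profile`
# times a block of code and `wrap()` returns a decorator that runs a whole
# function inside a `Profile`; timings are only logged when the
# --profile_log_times flag is set.
from learned_optimization import profile

@profile.wrap()
def train_step():
  pass  # Hypothetical work to be timed under the name "train_step".

with profile.Profile("data_loading"):
  pass  # Hypothetical block to be timed under the name "data_loading".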
15 | 16 | """Common python utilities.""" 17 | from concurrent import futures 18 | from typing import Any, Callable, Sequence 19 | import tqdm 20 | 21 | 22 | def threaded_tqdm_map(threads: int, func: Callable[[Any], Any], 23 | data: Sequence[Any]) -> Sequence[Any]: 24 | future_list = [] 25 | with futures.ThreadPoolExecutor(threads) as executor: 26 | for l in tqdm.tqdm(data): 27 | future_list.append(executor.submit(func, l)) 28 | return [x.result() for x in tqdm.tqdm(future_list)] 29 | -------------------------------------------------------------------------------- /learned_optimization/research/README.md: -------------------------------------------------------------------------------- 1 | This directory is contains experimental code for research / demo projects. -------------------------------------------------------------------------------- /learned_optimization/research/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Experimental and research code using learned_optimization.""" 17 | -------------------------------------------------------------------------------- /learned_optimization/research/brax/brax_env_truncated_step_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for brax_env_truncated_step.""" 17 | 18 | from absl.testing import absltest 19 | import brax.v1 as brax 20 | import jax 21 | from learned_optimization.outer_trainers import full_es 22 | from learned_optimization.outer_trainers import test_utils 23 | from learned_optimization.outer_trainers import truncated_pes 24 | from learned_optimization.outer_trainers import truncated_step 25 | from learned_optimization.outer_trainers import truncation_schedule 26 | from learned_optimization.research.brax import brax_env_truncated_step 27 | 28 | 29 | class BraxEnvTruncatedStepTest(absltest.TestCase): 30 | 31 | def _get_tstep(self): 32 | env_name = 'fast' 33 | env = brax.envs.create(env_name) 34 | asize = env.action_size 35 | osize = env.observation_size 36 | policy = brax_env_truncated_step.BraxEnvPolicy(osize, asize) 37 | key = jax.random.PRNGKey(0) 38 | params = policy.init(key) 39 | tstep = brax_env_truncated_step.BraxEnvTruncatedStep( 40 | env_name, 41 | policy, 42 | episode_length=100, 43 | random_initial_iteration_offset=100) 44 | return tstep, params 45 | 46 | def test_brax_truncated_step(self): 47 | tstep, params = self._get_tstep() 48 | key = jax.random.PRNGKey(0) 49 | state = tstep.init_step_state(params, None, key) 50 | state, unused_out = tstep.unroll_step(params, state, key, None, None) 51 | unused_state, unused_out = tstep.unroll_step(params, state, key, None, None) 52 | 53 | def test_brax_truncated_step_with_full_es(self): 54 | tstep, _ = self._get_tstep() 55 | vtstep = truncated_step.VectorizedTruncatedStep(tstep, 32) 56 | grad_est = truncated_pes.TruncatedPES(vtstep, std=0.01) 57 | test_utils.trainer_smoketest(grad_est) 58 | 59 | def test_brax_truncated_step_with_pes(self): 60 | tstep, _ = self._get_tstep() 61 | vtstep = truncated_step.VectorizedTruncatedStep(tstep, 32) 62 | 63 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10) 64 | grad_est = full_es.FullES(vtstep, std=0.01, truncation_schedule=trunc_sched) 65 | test_utils.trainer_smoketest(grad_est) 66 | 67 | 68 | if __name__ == '__main__': 69 | absltest.main() 70 | -------------------------------------------------------------------------------- /learned_optimization/research/data_driven/README.md: -------------------------------------------------------------------------------- 1 | # General-Purpose In-Context Learning by Meta-Learning Transformers 2 | 3 | Research code for the paper https://arxiv.org/abs/2212.04458 -------------------------------------------------------------------------------- /learned_optimization/research/data_driven/run_mnist_projections.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Runs MNIST projection experiment with gin configuration.""" 17 | 18 | from absl import app 19 | import jax 20 | from learned_optimization import filesystem 21 | from learned_optimization import setup_experiment 22 | from learned_optimization.research.data_driven import mnist_projections 23 | import numpy as np 24 | import yaml 25 | 26 | 27 | def main(_) -> None: 28 | rank = jax.process_index() 29 | train_log_dir = setup_experiment.setup_experiment(make_dir=(rank == 0)) 30 | train(train_log_dir) 31 | 32 | 33 | def train(training_log_directory: str): 34 | """Runs a projection experiment. 35 | 36 | Args: 37 | training_log_directory: Directory to store log data to. 38 | """ 39 | 40 | experiment = mnist_projections.ProjectionExperiment(training_log_directory) 41 | log_dict = experiment.run() 42 | 43 | if jax.process_index() == 0: 44 | yaml_file_name = f'{training_log_directory}/results.yaml' 45 | with filesystem.file_open(yaml_file_name, 'w') as f: 46 | yaml.dump( 47 | { 48 | k: np.asarray(v).item() 49 | for k, v in log_dict.items() 50 | if np.asarray(v).size == 1 51 | }, f) 52 | np_file_name = f'{training_log_directory}/results.npy' 53 | with filesystem.file_open(np_file_name, 'wb') as f: 54 | np.save(f, log_dict) 55 | 56 | 57 | if __name__ == '__main__': 58 | app.run(main) 59 | -------------------------------------------------------------------------------- /learned_optimization/research/data_driven/summary.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Utils for logging of summary statistics.""" 17 | 18 | from concurrent import futures 19 | import jax 20 | 21 | from learned_optimization import summary as lo_summary 22 | from learned_optimization.baselines import utils 23 | import numpy as np 24 | 25 | 26 | def only_first_rank(func): 27 | """Only runs the function on rank 0.""" 28 | 29 | def wrappee(*args, **kwargs): 30 | if jax.process_index() == 0: 31 | return func(*args, **kwargs) 32 | return None 33 | 34 | return wrappee 35 | 36 | 37 | class DictSummaryWriter(lo_summary.SummaryWriterBase): 38 | """A summary writer than stores entire dicts as scalars and npy files.""" 39 | 40 | FLUSH_TIMEOUT_SECS = 10 41 | 42 | def __init__(self, base_writer: lo_summary.SummaryWriterBase, log_dir: str): 43 | self._writer = base_writer 44 | self._log_dir = log_dir 45 | self._thread_pool = futures.ThreadPoolExecutor(max_workers=2) 46 | self._pending_dict_writes = [] 47 | 48 | @only_first_rank 49 | def scalar(self, name, value, step): 50 | return self._writer.scalar(name, value, step) 51 | 52 | @only_first_rank 53 | def histogram(self, name, value, step): 54 | return self._writer.histogram(name, value, step) 55 | 56 | @only_first_rank 57 | def flush(self): 58 | futures.wait(self._pending_dict_writes, timeout=self.FLUSH_TIMEOUT_SECS) 59 | self._pending_dict_writes.clear() 60 | return self._writer.flush() 61 | 62 | @only_first_rank 63 | def dict(self, dict_value, step): 64 | # Store scalars in tf summary 65 | for k, v in dict_value.items(): 66 | if v is None: 67 | continue 68 | if np.isscalar(v) or v.size == 1: 69 | self.scalar(k, v, step) 70 | 71 | # Store entire dictionary as npy file 72 | file_name = f'{self._log_dir}/summary_{step}.npy' 73 | task = self._thread_pool.submit(utils.write_npz, file_name, dict_value) 74 | self._pending_dict_writes.append(task) 75 | -------------------------------------------------------------------------------- /learned_optimization/research/es_scaling/es_scaling_tasks.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # ES tasks to probe scaling. 17 | # Gradient accumulation 18 | -------------------------------------------------------------------------------- /learned_optimization/research/general_lopt/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """learned_optimization.research.general_lopt""" 17 | -------------------------------------------------------------------------------- /learned_optimization/research/general_lopt/configs/README.md: -------------------------------------------------------------------------------- 1 | Configs used in meta-training. 2 | 3 | At the moment, these are not wired up to run outside of Google's infra. -------------------------------------------------------------------------------- /learned_optimization/research/general_lopt/hyper_v2_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimization.research.general_lopt.hyper_v2.""" 17 | from absl.testing import absltest 18 | from learned_optimization.learned_optimizers import test_utils 19 | from learned_optimization.research.general_lopt import hyper_v2 20 | 21 | 22 | class HyperV2LOptTest(absltest.TestCase): 23 | 24 | def test_hyperv2_lopt_lopt(self): 25 | test_utils.smoketest_learned_optimizer(hyper_v2.HyperV2()) 26 | 27 | 28 | if __name__ == '__main__': 29 | absltest.main() 30 | -------------------------------------------------------------------------------- /learned_optimization/research/general_lopt/prefab_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | """Tests for pretrained lopt optimizers.""" 17 | from absl.testing import absltest 18 | from learned_optimization.optimizers import test_utils as opt_test_utils 19 | from learned_optimization.research.general_lopt import prefab 20 | 21 | 22 | class PrefabLOptTest(absltest.TestCase): 23 | 24 | def test_lopt(self): 25 | opt_test_utils.smoketest_optimizer(prefab.LearnedOptimizer(100)) 26 | 27 | def test_lopt_accum(self): 28 | opt = prefab.LearnedOptimizer(100, max_training_steps=10) 29 | self.assertEqual(opt.opt.num_average, 10) 30 | opt_test_utils.smoketest_optimizer(opt) 31 | 32 | def test_lopt_wd(self): 33 | opt = prefab.LearnedOptimizer(100, weight_decay=1e-6) 34 | opt_test_utils.smoketest_optimizer(opt) 35 | 36 | 37 | if __name__ == '__main__': 38 | absltest.main() 39 | -------------------------------------------------------------------------------- /learned_optimization/research/general_lopt/pretrained_optimizers_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for pretrained lopt optimizers.""" 17 | from absl.testing import absltest 18 | from learned_optimization.optimizers import test_utils as opt_test_utils 19 | from learned_optimization.research.general_lopt import pretrained_optimizers 20 | 21 | 22 | class PretrainedLOptTest(absltest.TestCase): 23 | 24 | def test_hyperv2_pretrain(self): 25 | opt_test_utils.smoketest_optimizer( 26 | pretrained_optimizers 27 | .aug12_continue_on_bigger_2xbs_200kstep_bigproblem_v2_5620()) 28 | 29 | 30 | def test_hyperv2_pretrain2(self): 31 | opt_test_utils.smoketest_optimizer( 32 | pretrained_optimizers.aug11_aug4_trunc10per_last()) 33 | 34 | def test_hyperv2_pretrain3(self): 35 | opt_test_utils.smoketest_optimizer( 36 | pretrained_optimizers 37 | .aug26_aug12_continue_on_bigger_2xbs_200kstep_bigproblem_v2_1397()) 38 | 39 | 40 | if __name__ == '__main__': 41 | absltest.main() 42 | -------------------------------------------------------------------------------- /learned_optimization/research/general_lopt/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """learned_optimization.research.general_lopt.tasks""" 17 | -------------------------------------------------------------------------------- /learned_optimization/research/general_lopt/tasks/fast_mlp_diff_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A set of tasks for quick iteration on meta-training.""" 17 | import functools 18 | 19 | import gin 20 | from learned_optimization.tasks import base 21 | from learned_optimization.tasks import task_augmentation 22 | from learned_optimization.tasks.datasets import image 23 | from learned_optimization.tasks.fixed import image_mlp 24 | 25 | inner_bs = 64 26 | 27 | 28 | def cifar_task(): 29 | datasets = image.cifar10_datasets( 30 | batch_size=inner_bs, image_size=(8, 8), convert_to_black_and_white=True) 31 | return image_mlp._MLPImageTask(datasets, [32]) # pylint: disable=protected-access 32 | 33 | 34 | def fashion_mnist_task(): 35 | datasets = image.fashion_mnist_datasets( 36 | batch_size=inner_bs, image_size=(8, 8)) 37 | return image_mlp._MLPImageTask(datasets, [32]) # pylint: disable=protected-access 38 | 39 | 40 | def mnist_task(): 41 | datasets = image.mnist_datasets(batch_size=inner_bs, image_size=(8, 8)) 42 | return image_mlp._MLPImageTask(datasets, [32]) # pylint: disable=protected-access 43 | 44 | 45 | def svhn_task(): 46 | datasets = image.svhn_cropped_datasets( 47 | batch_size=inner_bs, image_size=(8, 8), convert_to_black_and_white=True) 48 | return image_mlp._MLPImageTask(datasets, [32]) # pylint: disable=protected-access 49 | 50 | 51 | def task_to_augmented_task_family(task_fn): 52 | task_family = base.single_task_to_family(task=task_fn()) 53 | return task_augmentation.ReparamWeightsFamily(task_family, "tensor", 54 | (0.01, 100)) 55 | 56 | 57 | @gin.configurable 58 | def jun24_fast_mlp_diff_data_task_list(): 59 | return [[functools.partial(task_to_augmented_task_family, fn)] 60 | for fn in [cifar_task, fashion_mnist_task, mnist_task, svhn_task]] 61 | 62 | 63 | @functools.lru_cache(None) 64 | def _get_task(task_family_idx, task_family_seed): 65 | # cache here so that the task identities are the same and thus will reuse 66 | # jit cache. 
67 | task_family = jun24_fast_mlp_diff_data_task_list()[task_family_idx][0]() 68 | return base.get_task(task_family, task_family_seed=task_family_seed) 69 | 70 | 71 | _datasets = ("cifar", "fashion_mnist", "mnist", "svhn") 72 | for _i in range(4): 73 | n_seed = 8 74 | for _s in range(n_seed): 75 | _task_seed = _i * n_seed + _s 76 | _task_name = f"Jun24FastMLPEval_TaskFamily_{_datasets[_i]}_Seed{_s}" 77 | locals()[_task_name] = functools.partial(_get_task, _i, _task_seed) 78 | gin.external_configurable(locals()[_task_name], _task_name) 79 | -------------------------------------------------------------------------------- /learned_optimization/research/hysteresis/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Experimental research code.""" 17 | -------------------------------------------------------------------------------- /learned_optimization/research/hysteresis/truncated_pes_no_hyst.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
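# A minimal usage sketch for the task registry built in fast_mlp_diff_data.py
# above (hedged: the import path is inferred from the file location). The loop
# above exposes names such as Jun24FastMLPEval_TaskFamily_mnist_Seed0 as
# zero-argument callables (also registered with gin), each returning a cached
# Task drawn from one of the reparameterized 8x8 image MLP task families.
from learned_optimization.research.general_lopt.tasks import fast_mlp_diff_data

task = fast_mlp_diff_data.Jun24FastMLPEval_TaskFamily_mnist_Seed0()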
15 | 16 | """PES with no hysteresis by spending a lot more compute.""" 17 | from typing import Any, Mapping, Optional, Tuple 18 | 19 | import gin 20 | import haiku as hk 21 | import jax.numpy as jnp 22 | from learned_optimization import profile 23 | from learned_optimization.outer_trainers import gradient_learner 24 | from learned_optimization.outer_trainers import truncated_pes 25 | from learned_optimization.outer_trainers import truncated_step as truncated_step_mod 26 | 27 | PRNGKey = jnp.ndarray 28 | MetaParams = Any 29 | TruncatedUnrollState = Any 30 | 31 | 32 | @gin.configurable 33 | class TruncatedPESNoHyst(gradient_learner.GradientEstimator): 34 | """PES with no hysteresis by spending a lot more compute.""" 35 | 36 | def __init__( 37 | self, 38 | truncated_step: truncated_step_mod.VectorizedTruncatedStep, 39 | burnin_steps: int, 40 | trunc_length=10, 41 | std=0.01, 42 | steps_per_jit=10, 43 | stack_antithetic_samples: bool = False, 44 | sign_delta_loss_scalar: Optional[float] = None, 45 | ): 46 | self.burnin_steps = burnin_steps 47 | self.trunc_length = trunc_length 48 | self.grad_est = truncated_pes.TruncatedPES(truncated_step, trunc_length, 49 | std, steps_per_jit, 50 | stack_antithetic_samples, 51 | sign_delta_loss_scalar) 52 | 53 | @profile.wrap() 54 | def init_worker_state(self, worker_weights: gradient_learner.WorkerWeights, 55 | key: PRNGKey): 56 | return self.grad_est.init_worker_state(worker_weights, key) 57 | 58 | @profile.wrap() 59 | def compute_gradient_estimate( # pytype: disable=signature-mismatch 60 | self, 61 | worker_weights: gradient_learner.WorkerWeights, 62 | key: PRNGKey, 63 | state, 64 | with_summary: bool = False 65 | ) -> Tuple[gradient_learner.GradientEstimatorOut, Mapping[str, jnp.ndarray]]: 66 | 67 | rng = hk.PRNGSequence(key) 68 | 69 | with profile.Profile("new_state"): 70 | state = self.init_worker_state(worker_weights, next(rng)) 71 | 72 | assert self.burnin_steps % self.trunc_length == 0 73 | 74 | with profile.Profile("burnin"): 75 | for _ in range(self.burnin_steps // self.trunc_length): 76 | output, unused_metrics = self.grad_est.compute_gradient_estimate( 77 | worker_weights, next(rng), state, with_summary=False) 78 | state = output.unroll_state 79 | 80 | return self.grad_est.compute_gradient_estimate( 81 | worker_weights, next(rng), state, with_summary=with_summary) 82 | -------------------------------------------------------------------------------- /learned_optimization/research/jaxnerf/example_data/imgs/r_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/learned_optimization/3ab34d27bdd38f487af65fb025f9c86869946e81/learned_optimization/research/jaxnerf/example_data/imgs/r_0.png -------------------------------------------------------------------------------- /learned_optimization/research/jaxnerf/example_data/transforms_test.json: -------------------------------------------------------------------------------- 1 | {"camera_angle_x": 0.6911112070083618, "frames": [{"file_path": "./imgs/r_0", "rotation": 0.012566370614359171, "transform_matrix": [[-0.9999021887779236, 0.004192245192825794, -0.013345719315111637, -0.05379832163453102], [-0.013988681137561798, -0.2996590733528137, 0.95394366979599, 3.845470428466797], [-4.656612873077393e-10, 0.9540371894836426, 0.29968830943107605, 1.2080823183059692], [0.0, 0.0, 0.0, 1.0]]}]} -------------------------------------------------------------------------------- 
/learned_optimization/research/jaxnerf/example_data/transforms_train.json: -------------------------------------------------------------------------------- 1 | {"camera_angle_x": 0.6911112070083618, "frames": [{"file_path": "./imgs/r_0", "rotation": 0.012566370614359171, "transform_matrix": [[-0.9999021887779236, 0.004192245192825794, -0.013345719315111637, -0.05379832163453102], [-0.013988681137561798, -0.2996590733528137, 0.95394366979599, 3.845470428466797], [-4.656612873077393e-10, 0.9540371894836426, 0.29968830943107605, 1.2080823183059692], [0.0, 0.0, 0.0, 1.0]]}]} -------------------------------------------------------------------------------- /learned_optimization/research/jaxnerf/jaxnerf_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for JaxNeRF tasks.""" 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from learned_optimization.research.jaxnerf import jaxnerf 20 | from learned_optimization.tasks import test_utils 21 | 22 | tasks = ['JAXNeRF_TestTask'] 23 | 24 | 25 | class JaxNeRFTest(parameterized.TestCase): 26 | 27 | @parameterized.parameters(tasks) 28 | def test_tasks(self, task_name): 29 | task = getattr(jaxnerf, task_name)() 30 | test_utils.smoketest_task(task, abstract_data=False) 31 | 32 | 33 | if __name__ == '__main__': 34 | absltest.main() 35 | -------------------------------------------------------------------------------- /learned_optimization/research/univ_nfn/learned_opt/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /learned_optimization/research/univ_nfn/learned_opt/outer_opt.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Additional outer optimizer helpers.""" 17 | 18 | import gin 19 | import jax 20 | from jax import flatten_util 21 | import jax.numpy as jnp 22 | from learned_optimization.optimizers import base as opt_base 23 | from learned_optimization.optimizers import optax_opts 24 | import optax 25 | 26 | 27 | @gin.configurable 28 | class GradientNormClippedOptimizer(opt_base.Optimizer): 29 | """Clip gradients by norm before passing into an optimizer.""" 30 | 31 | def __init__(self, opt: opt_base.Optimizer, clip_norm: float = 1.0): 32 | if not isinstance(opt, opt_base.Optimizer): 33 | raise ValueError( 34 | "Must instance of Optimizer. Maybe you are passing the" 35 | f" class and not an instance? Received {opt}." 36 | ) 37 | self.opt = opt 38 | self.clip_norm = clip_norm 39 | 40 | def get_params(self, state): 41 | return self.opt.get_params(state) 42 | 43 | def get_state(self, state): 44 | return self.opt.get_state(state) 45 | 46 | def init(self, *args, **kwargs): 47 | return self.opt.init(*args, **kwargs) 48 | 49 | def update(self, opt_state, grad, *args, **kwargs): 50 | g_norm = jnp.sqrt(jnp.sum(flatten_util.ravel_pytree(grad)[0] ** 2)) 51 | trigger = jnp.squeeze(g_norm < self.clip_norm) 52 | 53 | def clip_fn(t): 54 | return jax.lax.select( 55 | trigger, t, (t / g_norm.astype(t.dtype)) * self.clip_norm 56 | ) 57 | 58 | grad = jax.tree_util.tree_map(clip_fn, grad) 59 | return self.opt.update(opt_state, grad, *args, **kwargs) 60 | 61 | 62 | @gin.configurable 63 | class WarmupCosineAdam(optax_opts.OptaxOptimizer): 64 | """Adam with linear warmup + cosine LR decay.""" 65 | 66 | def __init__( 67 | self, 68 | warmup_steps=gin.REQUIRED, 69 | decay_steps=gin.REQUIRED, 70 | learning_rate=1e-3, 71 | beta1=0.9, 72 | beta2=0.999, 73 | epsilon=1e-8, 74 | epsilon_root=1e-8, 75 | ): 76 | final_lr = learning_rate / 10 77 | opt = optax.chain( 78 | optax.scale_by_adam( 79 | b1=beta1, b2=beta2, eps=epsilon, eps_root=epsilon_root 80 | ), 81 | optax.scale_by_schedule( 82 | optax.warmup_cosine_decay_schedule( 83 | 0.0, learning_rate, warmup_steps, decay_steps, final_lr 84 | ) 85 | ), 86 | optax.scale(-1), 87 | ) 88 | super().__init__(opt) 89 | -------------------------------------------------------------------------------- /learned_optimization/research/univ_nfn/nfn/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
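# A minimal usage sketch combining the two outer optimizers defined in
# outer_opt.py above (hedged: the warmup/decay step counts and learning rate
# are illustrative). A warmup + cosine-decay Adam is wrapped so gradients are
# clipped by global norm before every update.
from learned_optimization.research.univ_nfn.learned_opt import outer_opt

base_opt = outer_opt.WarmupCosineAdam(
    warmup_steps=1_000, decay_steps=100_000, learning_rate=3e-4)
clipped_opt = outer_opt.GradientNormClippedOptimizer(base_opt, clip_norm=1.0)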
15 | 16 | -------------------------------------------------------------------------------- /learned_optimization/run_outer_train_single.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """This script contains both outer workers and outer learners.""" 17 | from typing import Sequence 18 | 19 | from absl import app 20 | from learned_optimization import outer_train 21 | from learned_optimization import setup_experiment 22 | 23 | 24 | def main(argv: Sequence[str]) -> None: 25 | if len(argv) > 1: 26 | raise app.UsageError('Too many command-line arguments.') 27 | 28 | train_log_dir = setup_experiment.setup_experiment(make_dir=True) 29 | outer_train.run_train( 30 | train_log_dir, 31 | is_trainer=True, 32 | is_worker=True, 33 | ) 34 | 35 | 36 | if __name__ == '__main__': 37 | app.run(main) 38 | -------------------------------------------------------------------------------- /learned_optimization/summary_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for learned_optimizers.summary.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | from jax import lax 21 | import jax.numpy as jnp 22 | from learned_optimization import summary 23 | 24 | 25 | @jax.jit 26 | def fn(a): 27 | summary.summary("a", a[0]) 28 | 29 | def update(state, _): 30 | s = state + 1 31 | summary.summary("loop", s[0]) 32 | return s, s 33 | 34 | a, _ = lax.scan(update, a, jnp.arange(20)) 35 | 36 | return a * 2 37 | 38 | 39 | def fn2(static, a): 40 | summary.summary(static, a) 41 | return a * 2 42 | 43 | 44 | if summary.ORYX_LOGGING: 45 | 46 | class SummaryTest(absltest.TestCase): 47 | 48 | def setUp(self): 49 | super().setUp() 50 | summary.reset_summary_counter() 51 | 52 | def test_summary(self): 53 | result, metrics = summary.with_summary_output_reduced(fn)( 54 | jnp.zeros(1) + 123) 55 | del result 56 | 57 | metrics = summary.aggregate_metric_list([metrics]) 58 | self.assertIn("mean||a", metrics) 59 | self.assertIn("mean||loop", metrics) 60 | self.assertEqual(metrics["mean||loop"], 133.5) 61 | self.assertEqual(metrics["mean||a"], 123) 62 | 63 | def test_summary_output_reduced_static_args(self): 64 | result, metrics = summary.with_summary_output_reduced( 65 | fn2, static_argnums=(0,))("name", jnp.zeros(1) + 123) 66 | del result 67 | 68 | metrics = summary.aggregate_metric_list([metrics]) 69 | self.assertIn("mean||name", metrics) 70 | 71 | def test_add_with_summary(self): 72 | new_fn = summary.add_with_summary(fn2, static_argnums=(0,)) 73 | o, m = new_fn("test", 1, with_summary=False) 74 | self.assertEmpty(m) 75 | self.assertEqual(o, 2) 76 | 77 | summary.reset_summary_counter() 78 | 79 | o, m = new_fn("test", 1, with_summary=True) 80 | self.assertEqual(o, 2) 81 | self.assertEqual(m["mean||test"], 1) 82 | 83 | 84 | if __name__ == "__main__": 85 | absltest.main() 86 | -------------------------------------------------------------------------------- /learned_optimization/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """learned_optimizers.tasks module.""" 17 | -------------------------------------------------------------------------------- /learned_optimization/tasks/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """learned_optimizers.datasets module.""" 17 | -------------------------------------------------------------------------------- /learned_optimization/tasks/datasets/language_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from absl.testing import absltest 17 | from learned_optimization.tasks.datasets import language 18 | import numpy as np 19 | 20 | 21 | class ImageTest(absltest.TestCase): 22 | 23 | def test_lm1b_32k_datasets(self): 24 | datasets = language.lm1b_32k_datasets(32, 8) 25 | self.assertEqual(datasets.abstract_batch["obs"].shape, (32, 8)) 26 | self.assertEqual(datasets.abstract_batch["target"].shape, (32, 8)) 27 | 28 | data = next(datasets.train) 29 | self.assertEqual(data["obs"].shape, (32, 8)) 30 | self.assertEqual(data["target"].shape, (32, 8)) 31 | self.assertTrue(np.all(data["obs"][:, 1:] == data["target"][:, 0:-1])) 32 | 33 | def test_lm1b_bytes_datasets(self): 34 | datasets = language.lm1b_bytes_datasets(32, 10) 35 | self.assertEqual(datasets.abstract_batch["obs"].shape, (32, 10)) 36 | data = next(datasets.train) 37 | self.assertEqual(data["obs"].shape, (32, 10)) 38 | 39 | def test_wikipedia_en_32k_datasets(self): 40 | datasets = language.wikipedia_en_32k_datasets(32, 8) 41 | self.assertEqual(datasets.abstract_batch["obs"].shape, (32, 8)) 42 | self.assertEqual(datasets.abstract_batch["target"].shape, (32, 8)) 43 | data = next(datasets.train) 44 | self.assertEqual(data["obs"].shape, (32, 8)) 45 | self.assertEqual(data["target"].shape, (32, 8)) 46 | self.assertTrue(np.all(data["obs"][:, 1:] == data["target"][:, 0:-1])) 47 | 48 | data = next(datasets.test) 49 | self.assertEqual(data["obs"].shape, (32, 8)) 50 | 51 | 52 | if __name__ == "__main__": 53 | absltest.main() 54 | -------------------------------------------------------------------------------- /learned_optimization/tasks/es_wrapper_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for es_wrapper.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import jax 21 | import jax.numpy as jnp 22 | from learned_optimization.tasks import es_wrapper 23 | from learned_optimization.tasks import quadratics 24 | from learned_optimization.tasks import test_utils 25 | import numpy as onp 26 | 27 | 28 | class EsWrapperTest(parameterized.TestCase): 29 | 30 | @parameterized.named_parameters([ 31 | ("with_vmap", True), 32 | ("without_vmap", False), 33 | ]) 34 | def test_antithetic_es_value_and_grad(self, vmap): 35 | key = jax.random.PRNGKey(0) 36 | 37 | def loss_fn(aa, unused_argb): 38 | loss = sum([jnp.sum(a**2) for a in aa.values()]) 39 | return loss, jnp.asarray(4.0) 40 | 41 | params = { 42 | "a": jnp.ones(3, dtype=jnp.float32), 43 | "b": jnp.zeros(3, dtype=jnp.float32) 44 | } 45 | 46 | (loss, aux), grads = es_wrapper.antithetic_es_value_and_grad( 47 | loss_fn, has_aux=True, std=0.01, vmap=vmap)( 48 | params, None, es_key=key) 49 | 50 | self.assertAlmostEqual(float(aux), 4) 51 | self.assertLess(onp.abs(loss - 3), 1) 52 | self.assertEqual( 53 | jax.tree_util.tree_structure(grads), 54 | jax.tree_util.tree_structure(params)) 55 | 56 | def test_multi_antithetic_es_value_and_grad(self): 57 | key = jax.random.PRNGKey(0) 58 | 59 | def loss_fn(aa, unused_argb): 60 | loss = sum([jnp.sum(a**2) for a in aa.values()]) 61 | return loss, jnp.asarray(4.0) 62 | 63 | params = { 64 | "a": jnp.ones(3, dtype=jnp.float32), 65 | "b": jnp.zeros(3, dtype=jnp.float32) 66 | } 67 | 68 | (loss, aux), grads = es_wrapper.multi_antithetic_es_value_and_grad( 69 | loss_fn, n_pairs=256, has_aux=True, std=0.01)( 70 | params, None, es_key=key) 71 | 72 | self.assertAlmostEqual(float(aux), 4) 73 | # these are stochastic estimates, so there will be error but with 256 74 | # samples this should be ok. 75 | self.assertGreater(onp.mean(grads["a"]), 1) 76 | self.assertLess(onp.mean(grads["b"]**2), 1) 77 | self.assertLess(onp.abs(loss - 3), 1) 78 | 79 | def test_es_task_wrapper(self): 80 | task = es_wrapper.ESTask(quadratics.QuadraticTask()) 81 | test_utils.smoketest_task(task) 82 | 83 | def test_es_task_family_wrapper(self): 84 | task_family = es_wrapper.ESTaskFamily( 85 | quadratics.FixedDimQuadraticFamilyData(10), std=0.01, n_pairs=3) 86 | test_utils.smoketest_task_family(task_family) 87 | 88 | 89 | if __name__ == "__main__": 90 | absltest.main() 91 | -------------------------------------------------------------------------------- /learned_optimization/tasks/fixed/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Fixed tasks module.""" 17 | -------------------------------------------------------------------------------- /learned_optimization/tasks/fixed/all.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Import all fixed tasks into one module.""" 17 | # pylint: disable=wildcard-import 18 | 19 | from learned_optimization.tasks.fixed.conv import * 20 | from learned_optimization.tasks.fixed.image_mlp import * 21 | from learned_optimization.tasks.fixed.image_mlp_ae import * 22 | from learned_optimization.tasks.fixed.resnet import * 23 | from learned_optimization.tasks.fixed.rnn_lm import * 24 | from learned_optimization.tasks.fixed.es_wrapped import * 25 | from learned_optimization.tasks.fixed.transformer_lm import * 26 | -------------------------------------------------------------------------------- /learned_optimization/tasks/fixed/conv_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.tasks.fixed.conv.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from learned_optimization.tasks import test_utils 21 | from learned_optimization.tasks.fixed import conv 22 | 23 | 24 | tasks = [ 25 | "Conv_Cifar10_32x64x64", 26 | "Conv_Cifar10_16_32x64x64", 27 | "Conv_Cifar10_32x64x64_Tanh", 28 | "Conv_imagenet32_16_32x64x64", 29 | "Conv_imagenet16_16_64x128x128x128", 30 | "Conv_Cifar10_32x64x64_batchnorm", 31 | "Conv_Cifar10_32x64x64_layernorm", 32 | ] 33 | 34 | 35 | class ConvTest(parameterized.TestCase): 36 | 37 | @parameterized.parameters(tasks) 38 | def test_tasks(self, task_name): 39 | task = getattr(conv, task_name)() 40 | test_utils.smoketest_task(task) 41 | 42 | 43 | if __name__ == "__main__": 44 | absltest.main() 45 | -------------------------------------------------------------------------------- /learned_optimization/tasks/fixed/es_wrapped.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tasks which estimate gradients via ES.""" 17 | # pylint: disable=invalid-name 18 | 19 | import gin 20 | from learned_optimization.tasks import es_wrapper 21 | from learned_optimization.tasks.fixed import image_mlp 22 | 23 | 24 | @gin.configurable 25 | def ESWrapped_1pair_ImageMLP_Cifar10_128x128x128_Relu(): 26 | return es_wrapper.ESTask( 27 | image_mlp.ImageMLP_Cifar10_128x128x128_Relu(), 0.01, n_pairs=1) 28 | 29 | 30 | @gin.configurable 31 | def ESWrapped_8pair_ImageMLP_Cifar10_128x128x128_Relu(): 32 | return es_wrapper.ESTask( 33 | image_mlp.ImageMLP_Cifar10_128x128x128_Relu(), 0.01, n_pairs=8) 34 | 35 | 36 | @gin.configurable 37 | def ESWrapped_1pair_ImageMLP_Mnist_128x128x128_Relu(): 38 | return es_wrapper.ESTask( 39 | image_mlp.ImageMLP_Mnist_128x128x128_Relu(), 0.01, n_pairs=1) 40 | 41 | 42 | @gin.configurable 43 | def ESWrapped_8pair_ImageMLP_Mnist_128x128x128_Relu(): 44 | return es_wrapper.ESTask( 45 | image_mlp.ImageMLP_Mnist_128x128x128_Relu(), 0.01, n_pairs=8) 46 | 47 | 48 | @gin.configurable 49 | def ESWrapped_1pair_ImageMLP_FashionMnist8_Relu32(): 50 | return es_wrapper.ESTask( 51 | image_mlp.ImageMLP_FashionMnist8_Relu32(), 0.01, n_pairs=1) 52 | 53 | 54 | @gin.configurable 55 | def ESWrapped_1pair_ImageMLP_Cifar10BW8_Relu32(): 56 | return es_wrapper.ESTask( 57 | image_mlp.ImageMLP_Cifar10BW8_Relu32(), 0.01, n_pairs=1) 58 | -------------------------------------------------------------------------------- /learned_optimization/tasks/fixed/es_wrapped_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for learned_optimizers.tasks.fixed.es_wrapped.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from learned_optimization.tasks import test_utils 21 | from learned_optimization.tasks.fixed import es_wrapped 22 | 23 | tasks = [ 24 | "ESWrapped_1pair_ImageMLP_Cifar10_128x128x128_Relu", 25 | "ESWrapped_8pair_ImageMLP_Cifar10_128x128x128_Relu", 26 | "ESWrapped_1pair_ImageMLP_Mnist_128x128x128_Relu", 27 | "ESWrapped_8pair_ImageMLP_Mnist_128x128x128_Relu", 28 | ] 29 | 30 | 31 | class ESWrappedTest(parameterized.TestCase): 32 | 33 | @parameterized.parameters(tasks) 34 | def test_tasks(self, task_name): 35 | task = getattr(es_wrapped, task_name)() 36 | test_utils.smoketest_task(task) 37 | 38 | 39 | if __name__ == "__main__": 40 | absltest.main() 41 | -------------------------------------------------------------------------------- /learned_optimization/tasks/fixed/image_mlp_ae_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.tasks.fixed.image_mlp_ae.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from learned_optimization.tasks import test_utils 21 | from learned_optimization.tasks.fixed import image_mlp_ae 22 | 23 | tasks = [ 24 | "ImageMLPAE_Cifar10_32x32x32_bs128", 25 | "ImageMLPAE_Cifar10_256x256x256_bs128", 26 | "ImageMLPAE_Cifar10_256x256x256_bs1024", 27 | "ImageMLPAE_Cifar10_128x32x128_bs256", 28 | "ImageMLPAE_Mnist_128x32x128_bs128", 29 | "ImageMLPAE_FashionMnist_128x32x128_bs128", 30 | ] 31 | 32 | 33 | class ImageMLPAETest(parameterized.TestCase): 34 | 35 | @parameterized.parameters(tasks) 36 | def test_tasks(self, task_name): 37 | task = getattr(image_mlp_ae, task_name)() 38 | test_utils.smoketest_task(task) 39 | 40 | 41 | if __name__ == "__main__": 42 | absltest.main() 43 | -------------------------------------------------------------------------------- /learned_optimization/tasks/fixed/image_mlp_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | """Tests for learned_optimizers.tasks.fixed.image_mlp.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from learned_optimization.tasks import test_utils 21 | from learned_optimization.tasks.fixed import image_mlp 22 | 23 | 24 | tasks = [ 25 | "ImageMLP_FashionMnist_Relu128x128", 26 | "ImageMLP_Imagenet16_Relu256x256x256", 27 | "ImageMLP_Cifar10_128x128x128_Relu", 28 | "ImageMLP_Cifar10_128x128_Dropout05_Relu_MSE", 29 | "ImageMLP_Cifar10_128x128x128_BatchNorm_Relu", 30 | "ImageMLP_Cifar10_128x128x128_LayerNorm_Relu", 31 | "ImageMLP_Cifar10BW8_Relu32", 32 | ] 33 | 34 | 35 | class ImageMLPTest(parameterized.TestCase): 36 | 37 | @parameterized.parameters(tasks) 38 | def test_tasks(self, task_name): 39 | task = getattr(image_mlp, task_name)() 40 | test_utils.smoketest_task(task) 41 | 42 | 43 | if __name__ == "__main__": 44 | absltest.main() 45 | -------------------------------------------------------------------------------- /learned_optimization/tasks/fixed/lopt_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.tasks.fixed.lopt.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from learned_optimization.tasks import test_utils 21 | from learned_optimization.tasks.fixed import lopt 22 | 23 | tasks = [ 24 | "LOpt_MLPLOpt_FahionMnist_10", 25 | "LOpt_MLPLOpt_Cifar10_8_50", 26 | "LOpt_AdafacMLPLOpt_FashionMnist_50", 27 | "LOpt_AdafacMLPLOpt_FashionMnist_20", 28 | "LOpt_LearnableAdam_Cifar10_8_50", 29 | "LOpt_ES4_AdafacMLPLOpt_FashionMnist_20", 30 | "LOpt_ES4_LOpt_MLPLOpt_FahionMnist_50", 31 | "LOpt_ES4_LOpt_MLPLOpt_Cifar10_16_10", 32 | "LOptPES_Adafac_Fashion_OuterBS8_Length50_Trunc5", 33 | "LOptPES_Adafac_Cifar_OuterBS8_Length50_Trunc5", 34 | ] 35 | 36 | 37 | class LOptTest(parameterized.TestCase): 38 | 39 | @parameterized.parameters(tasks) 40 | def test_tasks(self, task_name): 41 | task = getattr(lopt, task_name)() 42 | test_utils.smoketest_task(task) 43 | 44 | 45 | if __name__ == "__main__": 46 | absltest.main() 47 | -------------------------------------------------------------------------------- /learned_optimization/tasks/fixed/mlp_mixer_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.tasks.fixed.mlp_mixer.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from learned_optimization.tasks import test_utils 21 | from learned_optimization.tasks.fixed import mlp_mixer 22 | 23 | tasks = [ 24 | 'MLPMixer_Cifar100_bs256_tiny16', 25 | 'MLPMixer_Cifar100_small16', 26 | 'MLPMixer_Cifar100_tiny16', 27 | 'MLPMixer_Food101_64_bs256_tiny16', 28 | 'MLPMixer_Food101_64_small16', 29 | 'MLPMixer_Food101_64_tiny16', 30 | 'MLPMixer_ImageNet64_bs256_tiny16', 31 | 'MLPMixer_ImageNet64_small16', 32 | 'MLPMixer_ImageNet64_tiny16', 33 | ] 34 | 35 | 36 | class MLPMixerTest(parameterized.TestCase): 37 | 38 | @parameterized.parameters(tasks) 39 | def test_tasks(self, task_name): 40 | task = getattr(mlp_mixer, task_name)() 41 | test_utils.smoketest_task(task) 42 | 43 | 44 | if __name__ == '__main__': 45 | absltest.main() 46 | -------------------------------------------------------------------------------- /learned_optimization/tasks/fixed/resnet_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.tasks.fixed.resnet.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from learned_optimization.tasks import test_utils 21 | from learned_optimization.tasks.fixed import resnet 22 | 23 | tasks = [ 24 | "Resnet_MultiRuntime_0", 25 | "Resnet_MultiRuntime_1", 26 | "Resnet_MultiRuntime_2", 27 | "Resnet_MultiRuntime_3", 28 | ] 29 | 30 | 31 | class ResnetTest(parameterized.TestCase): 32 | 33 | @parameterized.parameters(tasks) 34 | def test_tasks(self, task_name): 35 | task = getattr(resnet, task_name)() 36 | test_utils.smoketest_task(task) 37 | 38 | 39 | if __name__ == "__main__": 40 | absltest.main() 41 | -------------------------------------------------------------------------------- /learned_optimization/tasks/fixed/rnn_lm_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | """Tests for learned_optimizers.tasks.fixed.rnn_lm.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from learned_optimization.tasks import test_utils 21 | from learned_optimization.tasks.fixed import rnn_lm 22 | 23 | tasks = [ 24 | 'RNNLM_lm1b32k_Patch32_IRNN256_Embed128', 25 | 'RNNLM_lm1b32k_Patch32_LSTM256_Embed128', 26 | 'RNNLM_lm1b32k_Patch32_VanillaRNN256_Embed128', 27 | 'RNNLM_lm1bbytes_Patch128_LSTM128_Embed64', 28 | 'RNNLM_lm1bbytes_Patch32_GRU128_Embed64', 29 | 'RNNLM_lm1bbytes_Patch32_GRU256_Embed128', 30 | 'RNNLM_lm1bbytes_Patch32_IRNN128_Embed64', 31 | 'RNNLM_lm1bbytes_Patch32_LSTM128_Embed64', 32 | 'RNNLM_lm1bbytes_Patch32_LSTM256_Embed128', 33 | 'RNNLM_lm1bbytes_Patch32_VanillaRNN128_Embed64', 34 | 'RNNLM_wikipediaen32k_Patch32_GRU256_Embed128', 35 | 'RNNLM_wikipediaen32k_Patch32_LSTM256_Embed128', 36 | 'RNNLM_wikipediaenbytes_Patch32_GRU256_Embed128', 37 | 'RNNLM_wikipediaenbytes_Patch32_LSTM256_Embed128', 38 | ] 39 | 40 | 41 | class RNNLM(parameterized.TestCase): 42 | 43 | @parameterized.parameters(tasks) 44 | def test_tasks(self, task_name): 45 | task = getattr(rnn_lm, task_name)() 46 | test_utils.smoketest_task(task) 47 | 48 | 49 | if __name__ == '__main__': 50 | absltest.main() 51 | -------------------------------------------------------------------------------- /learned_optimization/tasks/fixed/transformer_lm_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.tasks.fixed.transformer_lm.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from learned_optimization.tasks import test_utils 21 | from learned_optimization.tasks.fixed import transformer_lm 22 | 23 | tasks = [ 24 | "TransformerLM_LM1B_MultiRuntime_0", 25 | "TransformerLM_LM1B_MultiRuntime_1", 26 | "TransformerLM_LM1B_MultiRuntime_2", 27 | "TransformerLM_LM1B_MultiRuntime_3", 28 | "TransformerLM_LM1B_MultiRuntime_4", 29 | "TransformerLM_LM1B_5layer_32width", 30 | ] 31 | 32 | 33 | class TransformerLM(parameterized.TestCase): 34 | 35 | @parameterized.parameters(tasks) 36 | def test_tasks(self, task_name): 37 | task = getattr(transformer_lm, task_name)() 38 | test_utils.smoketest_task(task) 39 | 40 | 41 | if __name__ == "__main__": 42 | absltest.main() 43 | -------------------------------------------------------------------------------- /learned_optimization/tasks/fixed/vit_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.tasks.fixed.vit.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from learned_optimization.tasks import test_utils 21 | from learned_optimization.tasks.fixed import vit 22 | 23 | tasks = [ 24 | "VIT_ImageNet64_wideshallow", 25 | "VIT_ImageNet64_skinnydeep", 26 | "VIT_Cifar100_wideshallow", 27 | "VIT_Cifar100_skinnydeep", 28 | "VIT_Food101_64_wideshallow", 29 | "VIT_Food101_64_skinnydeep", 30 | ] 31 | 32 | 33 | class VITTest(parameterized.TestCase): 34 | 35 | @parameterized.parameters(tasks) 36 | def test_tasks(self, task_name): 37 | task = getattr(vit, task_name)() 38 | test_utils.smoketest_task(task) 39 | 40 | 41 | if __name__ == "__main__": 42 | absltest.main() 43 | -------------------------------------------------------------------------------- /learned_optimization/tasks/parametric/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Parametric tasks module.""" 17 | -------------------------------------------------------------------------------- /learned_optimization/tasks/parametric/image_conv_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for image_conv.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | from learned_optimization.tasks import test_utils 21 | from learned_optimization.tasks.datasets import image 22 | from learned_optimization.tasks.parametric import cfgobject 23 | from learned_optimization.tasks.parametric import image_conv 24 | 25 | 26 | class ImageConvTest(absltest.TestCase): 27 | 28 | def test_ParametricImageConv(self): 29 | datasets = image.mnist_datasets(8, image_size=(8, 8)) 30 | task_family = image_conv.ParametricImageConv( 31 | datasets, 32 | num_classes=10, 33 | hidden_sizes=(8, 2), 34 | kernel_sizes=((3, 3), (3, 3)), 35 | strides=(1, 1)) 36 | test_utils.smoketest_task_family(task_family) 37 | 38 | def test_sample_image_conv(self): 39 | key = jax.random.PRNGKey(0) 40 | cfg = image_conv.sample_image_conv(key) 41 | task_family = cfgobject.object_from_config(cfg) 42 | test_utils.smoketest_task_family(task_family) 43 | 44 | def test_timed_sample_image_conv(self): 45 | key = jax.random.PRNGKey(0) 46 | sampled_task = image_conv.timed_sample_image_conv(key) 47 | self.assertIsInstance(sampled_task, cfgobject.CFGObject) 48 | 49 | 50 | if __name__ == '__main__': 51 | absltest.main() 52 | -------------------------------------------------------------------------------- /learned_optimization/tasks/parametric/image_mlp_ae_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | """Tests for image_mlp_ae.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | from learned_optimization.tasks import test_utils 21 | from learned_optimization.tasks.datasets import image 22 | from learned_optimization.tasks.parametric import cfgobject 23 | from learned_optimization.tasks.parametric import image_mlp_ae 24 | 25 | 26 | class ImageMlpAETest(absltest.TestCase): 27 | 28 | def test_ParametricImageMLPAE(self): 29 | datasets = image.mnist_datasets(8, image_size=(8, 8)) 30 | task_family = image_mlp_ae.ParametricImageMLPAE(datasets, (16, 5)) 31 | test_utils.smoketest_task_family(task_family) 32 | 33 | def test_sample_image_mlp_ae(self): 34 | key = jax.random.PRNGKey(0) 35 | cfg1 = image_mlp_ae.sample_image_mlp_ae(key) 36 | cfg2 = image_mlp_ae.sample_image_mlp_ae(key) 37 | key = jax.random.PRNGKey(1) 38 | cfg3 = image_mlp_ae.sample_image_mlp_ae(key) 39 | self.assertEqual(cfg1, cfg2) 40 | self.assertNotEqual(cfg1, cfg3) 41 | 42 | obj = cfgobject.object_from_config(cfg1) 43 | self.assertIsInstance(obj, image_mlp_ae.ParametricImageMLPAE) 44 | 45 | def test_timed_sample_image_mlp_ae(self): 46 | key = jax.random.PRNGKey(0) 47 | sampled_task = image_mlp_ae.timed_sample_image_mlp_ae(key) 48 | self.assertIsInstance(sampled_task, cfgobject.CFGObject) 49 | 50 | 51 | if __name__ == '__main__': 52 | absltest.main() 53 | -------------------------------------------------------------------------------- /learned_optimization/tasks/parametric/image_mlp_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for image_mlp.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | from learned_optimization.tasks import test_utils 21 | from learned_optimization.tasks.datasets import image 22 | from learned_optimization.tasks.parametric import cfgobject 23 | from learned_optimization.tasks.parametric import image_mlp 24 | 25 | 26 | class ImageMlpTest(absltest.TestCase): 27 | 28 | def test_ParametricImageMLP(self): 29 | datasets = image.mnist_datasets(8, image_size=(8, 8)) 30 | task_family = image_mlp.ParametricImageMLP(datasets, 10, (16, 5)) 31 | test_utils.smoketest_task_family(task_family) 32 | 33 | def test_timed_sample_image_mlp(self): 34 | key = jax.random.PRNGKey(0) 35 | sampled_task = image_mlp.timed_sample_image_mlp(key) 36 | self.assertIsInstance(sampled_task, cfgobject.CFGObject) 37 | 38 | def test_timed_sample_image_mlp_subsample_data(self): 39 | key = jax.random.PRNGKey(0) 40 | sampled_task = image_mlp.timed_sample_image_mlp_subsample_data(key) 41 | self.assertIsInstance(sampled_task, cfgobject.CFGObject) 42 | 43 | 44 | if __name__ == '__main__': 45 | absltest.main() 46 | -------------------------------------------------------------------------------- /learned_optimization/tasks/parametric/image_mlp_vae_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for image_mlp_vae.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | from learned_optimization.tasks import test_utils 21 | from learned_optimization.tasks.datasets import image 22 | from learned_optimization.tasks.parametric import cfgobject 23 | from learned_optimization.tasks.parametric import image_mlp_vae 24 | 25 | 26 | class ImageMlpVAETest(absltest.TestCase): 27 | 28 | def test_ParametricImageMLPVAE(self): 29 | datasets = image.mnist_datasets(8, image_size=(8, 8)) 30 | task_family = image_mlp_vae.ParametricImageMLPVAE( 31 | datasets, enc_hidden_sizes=(16,), dec_hidden_sizes=(16,), n_z=16) 32 | test_utils.smoketest_task_family(task_family) 33 | 34 | def test_sample_image_mlp_vae(self): 35 | key = jax.random.PRNGKey(0) 36 | cfg1 = image_mlp_vae.sample_image_mlp_vae(key) 37 | cfg2 = image_mlp_vae.sample_image_mlp_vae(key) 38 | key = jax.random.PRNGKey(1) 39 | cfg3 = image_mlp_vae.sample_image_mlp_vae(key) 40 | self.assertEqual(cfg1, cfg2) 41 | self.assertNotEqual(cfg1, cfg3) 42 | 43 | obj = cfgobject.object_from_config(cfg1) 44 | self.assertIsInstance(obj, image_mlp_vae.ParametricImageMLPVAE) 45 | 46 | def test_timed_sample_image_mlp_vae(self): 47 | key = jax.random.PRNGKey(0) 48 | sampled_task = image_mlp_vae.timed_sample_image_mlp_vae(key) 49 | self.assertIsInstance(sampled_task, cfgobject.CFGObject) 50 | 51 | if __name__ == '__main__': 52 | absltest.main() 53 | -------------------------------------------------------------------------------- /learned_optimization/tasks/parametric/image_resnet_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for image_resnet.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | from learned_optimization.tasks import test_utils 21 | from learned_optimization.tasks.datasets import image 22 | from learned_optimization.tasks.parametric import cfgobject 23 | from learned_optimization.tasks.parametric import image_resnet 24 | 25 | 26 | class ImageResnetTest(absltest.TestCase): 27 | 28 | def test_ParametricImageResNet(self): 29 | datasets = image.mnist_datasets(8, image_size=(8, 8)) 30 | task_family = image_resnet.ParametricImageResNet( 31 | datasets, 32 | initial_conv_channels=4, 33 | initial_conv_stride=1, 34 | initial_conv_kernel_size=3, 35 | blocks_per_group=(1, 1, 1, 1), 36 | channels_per_group=(4, 4, 4, 4), 37 | max_pool=True) 38 | test_utils.smoketest_task_family(task_family) 39 | 40 | def test_sample_image_resnet(self): 41 | key = jax.random.PRNGKey(0) 42 | cfg1 = image_resnet.sample_image_resnet(key) 43 | cfg2 = image_resnet.sample_image_resnet(key) 44 | key = jax.random.PRNGKey(1) 45 | cfg3 = image_resnet.sample_image_resnet(key) 46 | self.assertEqual(cfg1, cfg2) 47 | self.assertNotEqual(cfg1, cfg3) 48 | 49 | obj = cfgobject.object_from_config(cfg1) 50 | self.assertIsInstance(obj, image_resnet.ParametricImageResNet) 51 | 52 | def test_timed_sample_image_resnet(self): 53 | key = jax.random.PRNGKey(0) 54 | sampled_task = image_resnet.timed_sample_image_resnet(key) 55 | self.assertIsInstance(sampled_task, cfgobject.CFGObject) 56 | 57 | if __name__ == '__main__': 58 | absltest.main() 59 | -------------------------------------------------------------------------------- /learned_optimization/tasks/parametric/lm_rnn_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | """Tests for lm_rnn.""" 17 | 18 | from absl.testing import absltest 19 | import haiku as hk 20 | import jax 21 | from learned_optimization.tasks import test_utils 22 | from learned_optimization.tasks.datasets import language 23 | from learned_optimization.tasks.parametric import cfgobject 24 | from learned_optimization.tasks.parametric import lm_rnn 25 | 26 | 27 | class LmRnnTest(absltest.TestCase): 28 | 29 | def test_ParametricLMRNN(self): 30 | datasets = language.lm1b_bytes_datasets(4, 6) 31 | task_family = lm_rnn.ParametricLMRNN( 32 | datasets, 33 | rnn_size=7, 34 | vocab_size=16, 35 | embed_size=4, 36 | rnn_core_fn=hk.VanillaRNN) 37 | test_utils.smoketest_task_family(task_family) 38 | 39 | def test_sample_lm_rnn(self): 40 | key = jax.random.PRNGKey(0) 41 | cfg1 = lm_rnn.sample_lm_rnn(key) 42 | cfg2 = lm_rnn.sample_lm_rnn(key) 43 | key = jax.random.PRNGKey(1) 44 | cfg3 = lm_rnn.sample_lm_rnn(key) 45 | self.assertEqual(cfg1, cfg2) 46 | self.assertNotEqual(cfg1, cfg3) 47 | 48 | obj = cfgobject.object_from_config(cfg1) 49 | self.assertIsInstance(obj, lm_rnn.ParametricLMRNN) 50 | 51 | def test_timed_sample_lm_rnn(self): 52 | key = jax.random.PRNGKey(0) 53 | sampled_task = lm_rnn.timed_sample_lm_rnn(key) 54 | self.assertIsInstance(sampled_task, cfgobject.CFGObject) 55 | 56 | 57 | if __name__ == '__main__': 58 | absltest.main() 59 | -------------------------------------------------------------------------------- /learned_optimization/tasks/parametric/lm_transformer_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for lm_transformer.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | from learned_optimization.tasks import test_utils 21 | from learned_optimization.tasks.datasets import language 22 | from learned_optimization.tasks.parametric import cfgobject 23 | from learned_optimization.tasks.parametric import lm_transformer 24 | 25 | 26 | class LmTransformerTest(absltest.TestCase): 27 | 28 | def test_ParametricLMTransformer(self): 29 | datasets = language.lm1b_bytes_datasets(4, 6) 30 | task_family = lm_transformer.ParametricLMTransformer( 31 | datasets, vocab_size=None, num_heads=4, num_layers=3, d_model=128) 32 | test_utils.smoketest_task_family(task_family) 33 | 34 | def test_sample_lm_transformer(self): 35 | key = jax.random.PRNGKey(0) 36 | cfg1 = lm_transformer.sample_lm_transformer(key) 37 | cfg2 = lm_transformer.sample_lm_transformer(key) 38 | key = jax.random.PRNGKey(1) 39 | cfg3 = lm_transformer.sample_lm_transformer(key) 40 | self.assertEqual(cfg1, cfg2) 41 | self.assertNotEqual(cfg1, cfg3) 42 | 43 | obj = cfgobject.object_from_config(cfg1) 44 | self.assertIsInstance(obj, lm_transformer.ParametricLMTransformer) 45 | 46 | def test_timed_sample_lm_transformer(self): 47 | key = jax.random.PRNGKey(0) 48 | sampled_task = lm_transformer.timed_sample_lm_transformer(key) 49 | self.assertIsInstance(sampled_task, cfgobject.CFGObject) 50 | 51 | if __name__ == '__main__': 52 | absltest.main() 53 | -------------------------------------------------------------------------------- /learned_optimization/tasks/parametric/lopt_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | """Tests for lopt.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | from learned_optimization.learned_optimizers import base as lopt_base 21 | from learned_optimization.tasks import quadratics 22 | from learned_optimization.tasks import test_utils 23 | from learned_optimization.tasks.parametric import cfgobject 24 | from learned_optimization.tasks.parametric import lopt 25 | 26 | 27 | class LOptTest(absltest.TestCase): 28 | 29 | def test_ParametricLOpt(self): 30 | task_family = quadratics.FixedDimQuadraticFamily() 31 | lopt_adam = lopt_base.LearnableAdam() 32 | task_family = lopt.ParametricLOpt( 33 | task_family, lopt_adam, 2, outer_batch_size=2) 34 | test_utils.smoketest_task_family(task_family) 35 | 36 | def test_sample_lopt(self): 37 | key = jax.random.PRNGKey(0) 38 | cfg1 = lopt.sample_lopt(key) 39 | cfg2 = lopt.sample_lopt(key) 40 | key = jax.random.PRNGKey(1) 41 | cfg3 = lopt.sample_lopt(key) 42 | self.assertEqual(cfg1, cfg2) 43 | self.assertNotEqual(cfg1, cfg3) 44 | 45 | obj = cfgobject.object_from_config(cfg1) 46 | self.assertIsInstance(obj, lopt.ParametricLOpt) 47 | 48 | def test_timed_sample_lopt(self): 49 | key = jax.random.PRNGKey(0) 50 | sampled_task = lopt.timed_sample_lopt(key) 51 | self.assertIsInstance(sampled_task, cfgobject.CFGObject) 52 | 53 | 54 | if __name__ == '__main__': 55 | absltest.main() 56 | -------------------------------------------------------------------------------- /learned_optimization/tasks/parametric/lopt_trunc_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for lopt_trunc.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | from learned_optimization.learned_optimizers import base as lopt_base 21 | from learned_optimization.outer_trainers import truncation_schedule 22 | from learned_optimization.tasks import quadratics 23 | from learned_optimization.tasks import test_utils 24 | from learned_optimization.tasks.parametric import cfgobject 25 | from learned_optimization.tasks.parametric import lopt_trunc 26 | 27 | jax.config.update('jax_threefry_partitionable', False) 28 | 29 | 30 | class LOptTruncTest(absltest.TestCase): 31 | 32 | def test_ParametricLOptTrunc(self): 33 | task_family = quadratics.FixedDimQuadraticFamily() 34 | lopt_adam = lopt_base.LearnableAdam() 35 | trunc_sched = truncation_schedule.ConstantTruncationSchedule(10) 36 | task_family = lopt_trunc.ParametricLOptTrunc( 37 | task_family, 38 | lopt_adam, 39 | trunc_sched=trunc_sched, 40 | unroll_length=2, 41 | outer_batch_size=2, 42 | max_length=2, 43 | mode='TruncatedES') 44 | test_utils.smoketest_task_family(task_family) 45 | 46 | def test_sample_lopt(self): 47 | key = jax.random.PRNGKey(0) 48 | cfg1 = lopt_trunc.sample_lopt_trunc(key) 49 | cfg2 = lopt_trunc.sample_lopt_trunc(key) 50 | key = jax.random.PRNGKey(1) 51 | cfg3 = lopt_trunc.sample_lopt_trunc(key) 52 | self.assertEqual(cfg1, cfg2) 53 | self.assertNotEqual(cfg1, cfg3) 54 | 55 | obj = cfgobject.object_from_config(cfg1) 56 | self.assertIsInstance(obj, lopt_trunc.ParametricLOptTrunc) 57 | 58 | def test_timed_sample_lopt(self): 59 | key = jax.random.PRNGKey(0) 60 | sampled_task = lopt_trunc.timed_sample_lopt_trunc(key) 61 | self.assertIsInstance(sampled_task, cfgobject.CFGObject) 62 | 63 | 64 | if __name__ == '__main__': 65 | absltest.main() 66 | -------------------------------------------------------------------------------- /learned_optimization/tasks/parametric/parametric_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for parametric_utils.""" 17 | 18 | from absl.testing import absltest 19 | import haiku as hk 20 | import jax 21 | import jax.numpy as jnp 22 | from learned_optimization.tasks.parametric import parametric_utils 23 | import numpy as onp 24 | from numpy import testing 25 | 26 | 27 | class ParametricUtilsTest(absltest.TestCase): 28 | 29 | def test_SampleImageDataset(self): 30 | key = jax.random.PRNGKey(0) 31 | cfg = parametric_utils.SampleImageDataset.sample(key) 32 | datasets = parametric_utils.SampleImageDataset.get_dataset(cfg, 8, (8, 8)) 33 | _ = datasets.train 34 | 35 | def test_SampleActivation(self): 36 | key = jax.random.PRNGKey(0) 37 | cfg = parametric_utils.SampleActivation.sample(key) 38 | act_fn = parametric_utils.SampleActivation.get_dynamic(cfg) 39 | value = jax.jit(act_fn)(12.)
40 | self.assertEqual(value.shape, ()) 41 | 42 | act_fn = parametric_utils.SampleActivation.get_static(cfg) 43 | value2 = act_fn(12.) 44 | self.assertEqual(value, value2) 45 | 46 | def test_SampleInitializer(self): 47 | key = jax.random.PRNGKey(0) 48 | cfg = parametric_utils.SampleInitializer.sample(key) 49 | 50 | def forward(cfg): 51 | init = parametric_utils.SampleInitializer.get_dynamic(cfg) 52 | param = hk.get_parameter('asdf', [2, 2], dtype=jnp.float32, init=init) 53 | return param 54 | 55 | init_fn, _ = hk.transform(forward) 56 | val = jax.jit(init_fn)(key, cfg) 57 | self.assertEqual(jax.tree_util.tree_leaves(val)[0].shape, (2, 2)) 58 | 59 | def test_orth_init(self): 60 | key = jax.random.PRNGKey(0) 61 | init = parametric_utils.orth_init([16, 16], jnp.float32, key) 62 | # Check that the initializer is orthogonal by checking that all eigenvalues have magnitude 1. 63 | evals, unused_evecs = onp.linalg.eig(init) 64 | testing.assert_allclose(onp.abs(evals), jnp.ones([16]), rtol=1e-6) 65 | 66 | 67 | if __name__ == '__main__': 68 | absltest.main() 69 | -------------------------------------------------------------------------------- /learned_optimization/tasks/parametric/vit_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | """Tests for vit.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | from learned_optimization.tasks import test_utils 21 | from learned_optimization.tasks.datasets import image 22 | from learned_optimization.tasks.parametric import cfgobject 23 | from learned_optimization.tasks.parametric import vit 24 | 25 | 26 | class ParametricVITTest(absltest.TestCase): 27 | 28 | def test_ParametricVITTransformer(self): 29 | datasets = image.cifar100_datasets(8) 30 | task_family = vit.ParametricVIT( 31 | datasets, 32 | hidden_size=8, 33 | patch_size=4, 34 | mlp_dim=8, 35 | layers=1, 36 | num_heads=2, 37 | attension_dropout=0.0, 38 | dropout=0.0) 39 | test_utils.smoketest_task_family(task_family) 40 | 41 | def test_sample_vit(self): 42 | key = jax.random.PRNGKey(0) 43 | cfg1 = vit.sample_vit(key) 44 | cfg2 = vit.sample_vit(key) 45 | key = jax.random.PRNGKey(1) 46 | cfg3 = vit.sample_vit(key) 47 | self.assertEqual(cfg1, cfg2) 48 | self.assertNotEqual(cfg1, cfg3) 49 | 50 | obj = cfgobject.object_from_config(cfg1) 51 | self.assertIsInstance(obj, vit.ParametricVIT) 52 | 53 | def test_timed_sample_vit(self): 54 | key = jax.random.PRNGKey(0) 55 | sampled_task = vit.timed_sample_vit(key) 56 | self.assertIsInstance(sampled_task, cfgobject.CFGObject) 57 | 58 | 59 | if __name__ == '__main__': 60 | absltest.main() 61 | -------------------------------------------------------------------------------- /learned_optimization/tasks/quadratics_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for learned_optimizers.tasks.quadratics.""" 17 | 18 | from absl.testing import absltest 19 | from learned_optimization.tasks import quadratics 20 | from learned_optimization.tasks import test_utils 21 | 22 | 23 | class QuadraticsTest(absltest.TestCase): 24 | 25 | def test_quadratic_task(self): 26 | test_utils.smoketest_task(quadratics.QuadraticTask()) 27 | 28 | def test_batch_quadratic_task(self): 29 | test_utils.smoketest_task(quadratics.BatchQuadraticTask()) 30 | 31 | def test_fixed_dim_quadratic_family(self): 32 | test_utils.smoketest_task_family(quadratics.FixedDimQuadraticFamily(10)) 33 | 34 | def test_fixed_dim_quadratic_family_data(self): 35 | test_utils.smoketest_task_family(quadratics.FixedDimQuadraticFamilyData(10)) 36 | 37 | def test_noisy_quadratic_family(self): 38 | test_utils.smoketest_task_family(quadratics.NoisyQuadraticFamily(10, 0.1)) 39 | 40 | if __name__ == '__main__': 41 | absltest.main() 42 | -------------------------------------------------------------------------------- /learned_optimization/tasks/rnn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Recurrent neural network (RNN) haiku modules.""" 17 | from typing import Optional 18 | 19 | import haiku as hk 20 | import jax 21 | 22 | 23 | class IRNN(hk.VanillaRNN): 24 | """Identity-initialized RNN. 25 | 26 | This was introduced in https://arxiv.org/abs/1504.00941. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | hidden_size: int, 32 | double_bias: bool = True, 33 | name: Optional[str] = None, 34 | gain: float = 1.0, 35 | ): 36 | """Constructs an identity-initialized RNN core. 37 | 38 | Args: 39 | hidden_size: Hidden layer size. 40 | double_bias: Whether to use a bias in both linear layers. This does not 41 | change what the cell can learn; however, enabling it creates two sets 42 | of bias parameters rather than one. 43 | name: Name of the module. 44 | gain: Multiplier on the identity initialization of the recurrent weights. 45 | """ 46 | super().__init__( 47 | hidden_size=hidden_size, double_bias=double_bias, name=name) 48 | self.gain = gain 49 | 50 | def __call__(self, inputs, prev_state): 51 | input_to_hidden = hk.Linear(self.hidden_size) 52 | hidden_to_hidden = hk.Linear( 53 | self.hidden_size, 54 | with_bias=self.double_bias, 55 | w_init=hk.initializers.Identity(self.gain)) 56 | out = jax.nn.relu(input_to_hidden(inputs) + hidden_to_hidden(prev_state)) 57 | return out, out 58 | -------------------------------------------------------------------------------- /learned_optimization/tasks/test_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utilities for testing tasks and task families.""" 17 | 18 | from absl import logging 19 | import jax 20 | import jax.numpy as jnp 21 | from learned_optimization.tasks import base 22 | 23 | 24 | def smoketest_task(task: base.Task, abstract_data: bool = True):  25 | """Smoke test a Task. 26 | 27 | Args: 28 | task: Task to test. 29 | abstract_data: Whether to use abstract (zero-filled) batches instead of the 30 | real dataset for this test. Abstract batches do not require loading the 31 | underlying dataset, which can be faster when datasets are stored on a remote machine.
32 | """ 33 | key = jax.random.PRNGKey(0) 34 | param, state = task.init_with_state(key) 35 | 36 | logging.info("Getting data for %s task", str(task)) 37 | if task.datasets: 38 | if abstract_data and task.datasets.abstract_batch is not None: 39 | batch = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), 40 | task.datasets.abstract_batch) 41 | else: 42 | batch = next(task.datasets.train) 43 | else: 44 | batch = () 45 | logging.info("Got data") 46 | 47 | logging.info("starting forward") 48 | loss_and_state = task.loss_with_state(param, state, key, batch) 49 | del loss_and_state 50 | logging.info("starting backward") 51 | grad, aux = jax.grad( 52 | task.loss_with_state, has_aux=True)(param, state, key, batch) 53 | del grad, aux 54 | logging.info("checking normalizer") 55 | task.normalizer(jnp.asarray(1.0)) 56 | 57 | logging.info("checking loss_with_state_and_aux") 58 | loss, state, aux = task.loss_with_state_and_aux(param, state, key, batch) 59 | del loss, state, aux 60 | logging.info("done") 61 | 62 | 63 | def smoketest_task_family(task_family: base.TaskFamily): 64 | """Smoke test a TaskFamily.""" 65 | key = jax.random.PRNGKey(0) 66 | task_params = task_family.sample(key) 67 | task = task_family.task_fn(task_params) 68 | smoketest_task(task) 69 | 70 | _, key = jax.random.split(key) 71 | task_params = task_family.sample(key) 72 | task = task_family.task_fn(task_params) 73 | smoketest_task(task) 74 | 75 | if task.datasets is not None: 76 | assert task_family.datasets is not None 77 | -------------------------------------------------------------------------------- /learned_optimization/time_filter/time_model_data_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for time_model_data.""" 17 | 18 | import os 19 | import tempfile 20 | 21 | from absl.testing import absltest 22 | from learned_optimization.tasks import quadratics # pylint: disable=unused-import 23 | from learned_optimization.tasks.parametric import cfgobject 24 | from learned_optimization.time_filter import run_sample_and_time 25 | from learned_optimization.time_filter import time_model_data 26 | 27 | 28 | def fake_sample_quadratic(key): 29 | del key 30 | return cfgobject.CFGObject("FixedDimQuadraticFamily", {"dim": 10}) 31 | 32 | 33 | class TimeModelDataTest(absltest.TestCase): 34 | 35 | def test_load_runtime_files(self): 36 | with tempfile.TemporaryDirectory() as tmpdir: 37 | os.environ["LOPT_TIMING_DIR"] = tmpdir 38 | tmpdir = os.path.join(tmpdir, "quadratic") 39 | run_sample_and_time.run_many_eval_and_save( 40 | fake_sample_quadratic, tmpdir, num_to_run=4) 41 | data = time_model_data.load_runtime_files( 42 | "quadratic", "cpu_cpu", 5, threads=5) 43 | self.assertLen(data, 4) 44 | 45 | def test_train_test_iterators(self): 46 | with tempfile.TemporaryDirectory() as tmpdir: 47 | os.environ["LOPT_TIMING_DIR"] = tmpdir 48 | tmpdir = os.path.join(tmpdir, "quadratic") 49 | run_sample_and_time.run_many_eval_and_save( 50 | fake_sample_quadratic, tmpdir, num_to_run=4) 51 | tr_it, unused_te_it = time_model_data.train_test_iterators( 52 | "quadratic", "cpu_cpu", max_samples_to_load=5, num_test=2) 53 | for _ in range(10): 54 | data = next(tr_it) 55 | self.assertTrue("feats" in data) # pylint: disable=g-generic-assert 56 | self.assertTrue("time" in data) # pylint: disable=g-generic-assert 57 | 58 | 59 | if __name__ == "__main__": 60 | absltest.main() 61 | -------------------------------------------------------------------------------- /learned_optimization/time_filter/timings_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for timings.""" 17 | 18 | from absl.testing import absltest 19 | from learned_optimization.learned_optimizers import base as lopt_base 20 | from learned_optimization.optimizers import base as opt_base 21 | from learned_optimization.tasks import base as tasks_base 22 | from learned_optimization.tasks import quadratics 23 | from learned_optimization.time_filter import timings 24 | 25 | 26 | class TimingsTest(absltest.TestCase): 27 | 28 | def test_timing_default(self): 29 | task = quadratics.BatchQuadraticTask() 30 | task_family = tasks_base.single_task_to_family(task) 31 | stats = timings.task_family_runtime_stats(task_family) 32 | self.assertLess(stats["unroll_4x10"][0], 1.0) 33 | 34 | def test_timing_opt(self): 35 | task = quadratics.BatchQuadraticTask() 36 | task_family = tasks_base.single_task_to_family(task) 37 | stats = timings.task_family_runtime_stats(task_family, opt=opt_base.Adam()) 38 | self.assertLess(stats["unroll_4x10"][0], 1.0) 39 | 40 | def test_timing_lopt(self): 41 | task = quadratics.BatchQuadraticTask() 42 | task_family = tasks_base.single_task_to_family(task) 43 | stats = timings.task_family_runtime_stats( 44 | task_family, lopt=lopt_base.LearnableAdam()) 45 | self.assertLess(stats["unroll_4x10"][0], 1.0) 46 | 47 | 48 | if __name__ == "__main__": 49 | absltest.main() 50 | -------------------------------------------------------------------------------- /learned_optimization/training.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Utilities useful for training and meta-training.""" 17 | import os 18 | from typing import Any, Sequence 19 | 20 | from absl import logging 21 | from flax import serialization 22 | import jax 23 | from learned_optimization import filesystem as fs 24 | from learned_optimization import profile 25 | from learned_optimization import tree_utils 26 | from learned_optimization.tasks import base as tasks_base 27 | 28 | 29 | def save_state(path, state): 30 | fs.make_dirs(os.path.dirname(path)) 31 | with fs.file_open(path, "wb") as fp: 32 | fp.write(serialization.to_bytes(state)) 33 | 34 | 35 | def load_state(path, state): 36 | logging.info("Restoring state %s", path) 37 | with fs.file_open(path, "rb") as fp: 38 | state_new = serialization.from_bytes(state, fp.read()) 39 | tree = jax.tree_util.tree_structure(state) 40 | leaves = jax.tree_util.tree_leaves(state_new) 41 | return jax.tree_util.tree_unflatten(tree, leaves) 42 | 43 | 44 | def get_batches(task_family: tasks_base.TaskFamily, 45 | batch_shape: Sequence[int], 46 | split: str, 47 | numpy: bool = False) -> Any: 48 | """Get batches of data with the `batch_shape` leading dimension.""" 49 | if len(batch_shape) == 1: 50 | return vec_get_batch(task_family, batch_shape[0], numpy=numpy, split=split) 51 | elif len(batch_shape) == 2: 52 | datas_list = [ 53 | vec_get_batch(task_family, batch_shape[1], numpy=numpy, split=split) 54 | for _ in range(batch_shape[0]) 55 | ] 56 | if numpy: 57 | return tree_utils.tree_zip_onp(datas_list) 58 | else: 59 | return tree_utils.tree_zip_jnp(datas_list) 60 | elif len(batch_shape) == 3: 61 | datas_list = [ 62 | get_batches( 63 | task_family, [batch_shape[1], batch_shape[2]], 64 | numpy=numpy, 65 | split=split) for _ in range(batch_shape[0]) 66 | ] 67 | if numpy: 68 | return tree_utils.tree_zip_onp(datas_list) 69 | else: 70 | return tree_utils.tree_zip_jnp(datas_list) 71 | else: 72 | raise NotImplementedError() 73 | 74 | 75 | @profile.wrap() 76 | def vec_get_batch(task_family, n_tasks, split, numpy=False): 77 | to_zip = [] 78 | for _ in range(n_tasks): 79 | if task_family.datasets is None: 80 | return () 81 | to_zip.append(next(task_family.datasets.split(split))) 82 | return tree_utils.tree_zip_onp(to_zip) if numpy else tree_utils.tree_zip_jnp( 83 | to_zip) 84 | -------------------------------------------------------------------------------- /learned_optimization/tree_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for tree_utils.""" 17 | 18 | from typing import Any 19 | 20 | from absl.testing import absltest 21 | import flax 22 | import jax 23 | import jax.numpy as jnp 24 | from learned_optimization import tree_utils 25 | 26 | 27 | @flax.struct.dataclass 28 | class ClassName: 29 | val1: Any 30 | val2: Any 31 | 32 | 33 | class TreeUtilsTest(absltest.TestCase): 34 | 35 | def test_partition(self): 36 | c = {"a": 2, "b": ClassName(1, 2.)} 37 | 38 | partitions, unflattener = tree_utils.partition( 39 | [lambda k, v: jnp.asarray(v).dtype == jnp.int32], c) 40 | 41 | partitions[1] = jax.tree_util.tree_map(lambda x: x * 2, partitions[1]) 42 | 43 | tree_utils.partition_unflatten(unflattener, partitions) 44 | data = unflattener(partitions) 45 | 46 | self.assertEqual(data["b"].val2, 4.) 47 | self.assertEqual(data["b"].val1, 1) 48 | 49 | 50 | if __name__ == "__main__": 51 | absltest.main() 52 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.12.0 2 | numpy>=1.18 3 | jax>=0.2.6 4 | jaxlib>=0.1.68 5 | pytest 6 | tqdm>=4.62.3 7 | flax 8 | dm-haiku==0.0.5 9 | optax>=0.0.9 10 | tensorflow>=2.7.0 11 | tensorflow-datasets>=4.4.0 12 | tensorflow-metadata==1.5.0 13 | tensorflow-probability>=0.16.0 14 | tensorboard>=2.7.0 15 | gin-config>=0.5.0 16 | seqio>=0.0.7 17 | oryx 18 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Setup the package.""" 17 | 18 | import os 19 | import setuptools 20 | 21 | # https://packaging.python.org/guides/making-a-pypi-friendly-readme/ 22 | this_directory = os.path.abspath(os.path.dirname(__file__)) 23 | with open(os.path.join(this_directory, 'README.md'), encoding='utf-8') as f: 24 | long_description = f.read() 25 | 26 | __version__ = '0.0.1' 27 | 28 | setuptools.setup( 29 | name='learned_optimization', 30 | version=__version__, 31 | description='Train learned optimizers in Jax.', 32 | author='learned_optimization team', 33 | author_email='lmetz@google.com', 34 | packages=setuptools.find_packages(exclude=['examples']), 35 | package_data={'learned_optimization': ['py.typed']}, 36 | python_requires='>=3.7', 37 | setup_requires=['setuptools_scm'], 38 | include_package_data=True, 39 | # TODO(lmetz) don't fix many versions! Sadly a number of these libraries 40 | # don't play nice with newer versions of other libraries. 41 | # TODO(lmetz) add oryx to this! 
42 | install_requires=[ 43 | 'absl-py==0.12.0', 44 | 'numpy>=1.18', 45 | 'jax>=0.4.3', 46 | 'jaxlib>=0.4.3', 47 | 'nose', 48 | 'dm-launchpad-nightly', 49 | 'tqdm>=4.62.3', 50 | 'flax', 51 | 'dm-haiku==0.0.5', 52 | 'optax>=0.0.9', 53 | 'tensorflow>=2.7.0', 54 | 'tensorflow-datasets>=4.4.0', 55 | 'tensorflow-metadata==1.5.0', 56 | 'tensorboard>=2.7.0', 57 | 'gin-config>=0.5.0', 58 | 'oryx', 59 | ], 60 | url='https://github.com/google/learned_optimization', 61 | license='Apache-2.0', 62 | long_description=long_description, 63 | long_description_content_type='text/markdown', 64 | classifiers=[ 65 | 'Programming Language :: Python :: 3.7', 66 | 'Programming Language :: Python :: 3.8', 67 | 'Programming Language :: Python :: 3.9', 68 | ], 69 | zip_safe=False, 70 | ) 71 | --------------------------------------------------------------------------------
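Usage sketch (not part of the repository files above): a minimal example that combines only APIs shown in this listing (single_task_to_family and BatchQuadraticTask from timings_test.py, get_batches from training.py, and task_family_runtime_stats from timings_test.py). The batch_shape value and the choice of task are illustrative assumptions, not prescribed by the source.

# Sketch only: every call below appears in the files above; the task and shapes are assumptions.
from learned_optimization import training
from learned_optimization.tasks import base as tasks_base
from learned_optimization.tasks import quadratics
from learned_optimization.time_filter import timings

# Wrap a single task into a TaskFamily, as timings_test.py does.
task_family = tasks_base.single_task_to_family(quadratics.BatchQuadraticTask())

# Fetch data with two leading batch dimensions; yields empty batches if the family
# has no datasets (see vec_get_batch in training.py).
batches = training.get_batches(task_family, batch_shape=[2, 3], split="train", numpy=True)

# Measure unroll runtimes for the family, mirroring test_timing_default in timings_test.py.
stats = timings.task_family_runtime_stats(task_family)
print(stats["unroll_4x10"])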