├── .github └── workflows │ └── ci.yaml ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── benchmark_matrix.png ├── delayed_psgd.png └── memory.png ├── benchmark ├── README.md ├── adversarial_gradient.py ├── batch_size_scaling.py ├── beale.py ├── benchmark_results.md ├── char_rnn.py ├── constrained_optimization.py ├── create_hiker_analogy_diagram.py ├── discontinuous_gradient.py ├── dynamic_landscape.py ├── exploding_gradient.py ├── format_results.py ├── generate_animations.py ├── gradient_delay.py ├── gradient_noise_scale.py ├── grokking.py ├── layer_wise_scale.py ├── loss_contour.py ├── merge_logs.py ├── minimax.py ├── momentum_utilization.py ├── noisy_matmul.py ├── parameter_scale.py ├── plateau_navigation.py ├── postprocess_requirements.txt ├── powers.py ├── powers_varying_target.py ├── quadratic_varying_scale.py ├── quadratic_varying_target.py ├── rastrigin.py ├── relu_boundaries.py ├── requirements.txt ├── rosenbrock.py ├── run_all_benchmarks.py ├── saddle_point.py ├── scale_invariant.py ├── shakespeare.txt ├── sparse_gradient.py ├── utils.py ├── wide_linear.py ├── xor_digit.py ├── xor_digit_rnn.py ├── xor_sequence.py ├── xor_sequence_rnn.py ├── xor_spot.py └── xor_spot_rnn.py ├── build.sh ├── docs ├── assets │ ├── benchmark_matrix.png │ ├── early_stopping.png │ ├── hiker_analogy_diagram.png │ ├── psgd_efficiency_cache.png │ ├── psgd_efficiency_cache_triu_as_line.png │ ├── psgd_efficiency_triu_as_line.png │ └── saddle_point_comparison.gif ├── benchmark.md └── psgd_efficiency.md ├── examples ├── autoencoder.py ├── lra.py └── modify_functions.py ├── heavyball ├── __init__.py ├── chainable.py ├── helpers.py └── utils.py ├── pre-commit.yaml ├── pyproject.toml └── test ├── benchmark_psgd_lb.py ├── readme.md ├── test_bf16_params.py ├── test_bf16_q.py ├── test_bf16_storage.py ├── test_caution.py ├── test_channels_last.py ├── test_closure.py ├── test_ema.py ├── test_foreach.py ├── test_hook.py ├── test_mars.py ├── test_memory.py ├── test_memory_leak.py ├── test_merge.py ├── test_no_grad.py ├── test_psgd_precond_init_stability.py ├── test_save_restore.py ├── test_soap.py └── test_stochastic_updates.py /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: Continuous integration 2 | 3 | on: 4 | pull_request: 5 | workflow_dispatch: 6 | 7 | jobs: 8 | pre-commit: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - uses: actions/setup-python@v5 13 | with: 14 | python-version: "3.11" 15 | - name: Run pre-commit hooks 16 | uses: pre-commit/action@v3.0.1 17 | with: 18 | extra_args: --all-files --config pre-commit.yaml 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | splinecam 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2022, Lucas Nestler 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # heavyball
2 | 
3 | [PyPI][pypi]
4 | [License][license]
5 | 
6 | _High-performance, extensible, chainable optimizers for PyTorch._
7 | 
8 | ## Why heavyball
9 | 
10 | - **Lightning-Fast Training**: Batched `foreach` operations deliver significant speedups on large models.
11 | - **Adaptive & Extensible**: Built-in AdamW, RMSprop, Schedule-Free algorithms, and PaLM-inspired schedules.
12 | - **Plug-and-Play**: Drop-in replacements for `torch.optim` with seamless integration.
13 | - **Customizable**: Chainable API lets you compose optimizers and transforms (MARS correction, cautious updates, orthogonal updates).
14 | - **Battle-Tested**: Extensive benchmarks and real-world examples included.
15 | 
16 | ## Key Features
17 | 
18 | - Foreach-based optimizers: `ForeachAdamW`, `ForeachRMSprop`, `ForeachSFAdamW`, `Muon`, `ADOPT`, `MSAM`, …
19 | - Schedule-Free optimizers with dynamic learning rate adaptation.
20 | - Advanced update rules: MARS correction, cautious updates, PaLM beta2 scheduling.
21 | - Chainable transforms for custom optimization recipes.
22 | - Comprehensive benchmark suite (`benchmark/`).
23 | - Detailed documentation and example-driven tutorials.
24 | 
25 | ## Quickstart
26 | 
27 | **Install:**
28 | ```bash
29 | pip install heavyball
30 | ```
31 | 
32 | **Basic usage:**
33 | ```python
34 | import torch
35 | from torch import nn
36 | from heavyball import ForeachAdamW
37 | 
38 | model = nn.Sequential(
39 |     nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10)
40 | )
41 | optimizer = ForeachAdamW(model.parameters(), lr=1e-3)
42 | 
43 | for data, target in dataloader:
44 |     optimizer.zero_grad()
45 |     output = model(data)
46 |     loss = torch.nn.functional.cross_entropy(output, target)
47 |     loss.backward()
48 |     optimizer.step()
49 | ```
50 | 
51 | ## Benchmarks
52 | 
53 | > Reproduce benchmarks with:
54 | > ```bash
55 | > python3 -m benchmark.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
56 | > ```
57 | 
58 | 
59 | ## Contributing
60 | 
61 | We welcome contributions! Please check the [issue tracker][tracker] and follow these steps:
62 | 1. Fork the repo and create a feature branch.
63 | 2. Install dev dependencies: `pip install -e .[dev]`.
64 | 3. Run tests: `pytest`.
65 | 4. Submit a pull request.
66 | 
67 | ## License
68 | 
69 | BSD 2-Clause — see the [LICENSE](LICENSE) file.
70 | 
71 | ---
72 | 
73 | Made by the HeavyBall team. 74 |
75 | 76 | [pypi]: https://pypi.org/project/heavyball/ 77 | [license]: LICENSE 78 | [tracker]: https://github.com/HomebrewML/HeavyBall/issues 79 | -------------------------------------------------------------------------------- /assets/benchmark_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/assets/benchmark_matrix.png -------------------------------------------------------------------------------- /assets/delayed_psgd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/assets/delayed_psgd.png -------------------------------------------------------------------------------- /assets/memory.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/assets/memory.png -------------------------------------------------------------------------------- /benchmark/README.md: -------------------------------------------------------------------------------- 1 | # HeavyBall Benchmark Suite 2 | 3 | This repository contains a suite of benchmark problems designed to test the capabilities and robustness of `torch.optim` optimizers. The framework is designed to be modular and extensible, allowing for the easy addition of new benchmarks and optimizers. 4 | 5 | ## Setup 6 | 7 | To get started, install the required dependencies: 8 | 9 | ```bash 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ## Usage 14 | 15 | To run the entire benchmark suite with a specific set of optimizers, use the `run_all_benchmarks.py` script. 16 | 17 | For example, to run the benchmarks for Adam and SGD on easy and medium difficulties: 18 | ```bash 19 | python run_all_benchmarks.py --opt Adam --opt SGD --difficulties easy medium 20 | ``` 21 | 22 | You can specify multiple optimizers and difficulties. The results will be written to `benchmark_results.md`. 23 | 24 | For more options, see: 25 | ```bash 26 | python run_all_benchmarks.py --help 27 | ``` 28 | 29 | ### Core Philosophy and Design Principles 30 | 31 | Based on the analysis of [`benchmark_template.py`](benchmark_template.py), [`optimizer_template.py`](optimizer_template.py), and [`utils.py`](utils.py), the core framework is a modular and extensible system designed for benchmarking `torch` optimizers. It is composed of three primary components: 32 | 33 | 1. **Benchmarks**: Each benchmark is a self-contained optimization problem defined by a class. It provides a loss function to be minimized, the parameters to be optimized, and a specific success condition. 34 | 2. **Optimizers**: These are expected to adhere to the `torch.optim.Optimizer` interface. A template is provided to facilitate the creation of new custom optimizers. 35 | 3. **Benchmarking Harness**: The [`utils.py`](utils.py) script contains the engine that runs the benchmarks. It includes a sophisticated `Validator` class for monitoring convergence and detecting failures, an `Objective` class that encapsulates the entire process of running a benchmark (including hyperparameter search with `optuna`), and a high-level `trial` function to orchestrate the whole process. 
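To make the interaction between these three components concrete, the sketch below shows roughly what happens for a single benchmark/optimizer pairing: a toy benchmark exposing the interface described above, a stock `torch.optim` optimizer, and an `optuna` search over the learning rate. It illustrates the control flow only; the `ToyQuadratic` class and training loop are invented for this example, and the real `Validator`, `Objective`, and `trial` implementations in [`utils.py`](utils.py) additionally handle failure detection, win conditions, and parallelism.

```python
import optuna
import torch


class ToyQuadratic:
    """Minimal stand-in for a benchmark: exposes params, a scalar loss, and a success check."""

    def __init__(self, device: str = "cpu"):
        self.params = [torch.nn.Parameter(torch.randn(64, device=device))]

    def __call__(self) -> torch.Tensor:
        return self.params[0].square().mean()

    def has_succeeded(self) -> bool:
        with torch.no_grad():
            return self().item() < 1e-6

    def get_params(self) -> list[torch.Tensor]:
        return self.params


def objective(trial: optuna.Trial) -> float:
    # The real harness also tunes momentum coefficients and aborts stagnating runs.
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    bench = ToyQuadratic()
    opt = torch.optim.AdamW(bench.get_params(), lr=lr)
    for _ in range(100):
        opt.zero_grad()
        loss = bench()
        loss.backward()
        opt.step()
        if bench.has_succeeded():
            break
    return loss.item()


study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
print(study.best_params)
```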
36 | 37 | #### Design Principles 38 | 39 | The framework is built upon the following design principles: 40 | 41 | * **Modularity**: The separation of benchmarks, optimizers, and the runner logic allows for easy extension. New benchmarks or optimizers can be added by implementing simple, well-defined interfaces without needing to alter the core execution logic. 42 | * **Automation**: The framework automates the process of running benchmarks, including a hyperparameter search for learning rate and momentum coefficients. This simplifies the process of evaluating an optimizer's performance across various problems. 43 | * **Robustness**: The `Validator` class provides sophisticated logic to terminate runs that are not making progress, saving computational resources. It checks for multiple failure conditions, such as stagnating loss and lack of improvement over time. 44 | * **PyTorch Native**: The framework is built on PyTorch and leverages its core components like `autograd` for gradient computation and the standard `Optimizer` class structure. 45 | 46 | #### Benchmark Structure 47 | 48 | A benchmark is structured as a class, as defined in [`benchmark_template.py`](benchmark_template.py:4), with the following key methods: 49 | 50 | * **`__init__(self, device)`**: Initializes the benchmark's state. This is where the model parameters and any necessary data are created. The parameters to be optimized are stored in `self.params`. 51 | * **`__call__(self) -> torch.Tensor`**: This is the core of the benchmark, defining the objective function. It computes and returns a scalar loss tensor that the optimizer will try to minimize. 52 | * **`has_succeeded(self) -> bool`**: Defines the "win condition" for the benchmark. It returns `True` if the optimization is considered successful, typically based on the loss value falling below a predefined threshold. 53 | * **`get_params(self) -> list[torch.Tensor]`**: Returns the list of `torch.Tensor` parameters that will be optimized. 54 | 55 | ### Existing Benchmarks 56 | 57 | | File | Category | Description | 58 | |---|---|---| 59 | | [`adversarial_gradient.py`](adversarial_gradient.py) | Noise Robustness | Tests optimizer's robustness to an oscillating adversarial component added to the gradient, simulating adversarial noise. | 60 | | [`batch_size_scaling.py`](batch_size_scaling.py) | Noise Robustness | Tests how optimizers handle noise at different scales, which are modulated by a simulated, randomly changing batch size. | 61 | | [`beale.py`](beale.py) | General/Classic | Implements the Beale function, a classic optimization benchmark with sharp valleys, to test general performance. Also includes visualization. | 62 | | [`char_rnn.py`](char_rnn.py) | Sequence & Memory (RNN/LSTM) | A character-level language model using an LSTM on text data, testing performance on a sequence task requiring memory. | 63 | | [`constrained_optimization.py`](constrained_optimization.py) | Landscape Traversal | A simple quadratic objective with a penalty-based constraint, creating a sharp ridge in the loss landscape for the optimizer to navigate. | 64 | | [`discontinuous_gradient.py`](discontinuous_gradient.py) | Gradient Characteristics | Tests optimizer robustness to non-smooth landscapes by using an objective function with a discontinuous gradient at the origin. | 65 | | [`dynamic_landscape.py`](dynamic_landscape.py) | Dynamic Environments | Tests an optimizer's ability to track a continuously shifting target in a non-stationary loss landscape. 
| 66 | | [`exploding_gradient.py`](exploding_gradient.py) | Gradient Characteristics | Tests an optimizer's numerical stability and handling of extreme gradient values by using an exponential function that causes gradients to grow rapidly. | 67 | | [`gradient_delay.py`](gradient_delay.py) | Gradient Characteristics | Tests an optimizer's ability to handle asynchronous or delayed updates by using gradients from previous steps. | 68 | | [`gradient_noise_scale.py`](gradient_noise_scale.py) | Noise Robustness | Tests an optimizer's ability to handle dynamically changing noise levels, where the noise scale anneals over time. | 69 | | [`grokking.py`](grokking.py) | Landscape Traversal | Tests for 'grokking' by training a model on a modular arithmetic task, examining if the optimizer can find a generalizable solution after a long period of memorization. | 70 | | [`layer_wise_scale.py`](layer_wise_scale.py) | Multi-Scale & Conditioning | Tests an optimizer's ability to handle parameters with vastly different gradient scales by scaling the loss contribution of different layers. | 71 | | [`minimax.py`](minimax.py) | Landscape Traversal | Implements a minimax objective function which creates a saddle point, testing the optimizer's ability to escape such points. | 72 | | [`momentum_utilization.py`](momentum_utilization.py) | Landscape Traversal | Tests the effective use of momentum by creating an oscillating loss landscape with many local minima. | 73 | | [`noisy_matmul.py`](noisy_matmul.py) | Gradient Characteristics | Tests optimizer stability in deep networks by performing a sequence of matrix multiplications, which can lead to exploding or vanishing gradients. | 74 | | [`parameter_scale.py`](parameter_scale.py) | Multi-Scale & Conditioning | Tests an optimizer's ability to handle parameters initialized at widely different scales, creating a poorly conditioned problem. | 75 | | [`plateau_navigation.py`](plateau_navigation.py) | Landscape Traversal | Tests an optimizer's ability to navigate a loss landscape with a large, flat plateau region surrounded by a steep cliff. | 76 | | [`powers_varying_target.py`](powers_varying_target.py) | Multi-Scale & Conditioning | Creates a complex, poorly conditioned landscape by raising parameters to different powers against a non-zero, varying target. | 77 | | [`powers.py`](powers.py) | Multi-Scale & Conditioning | Creates a poorly conditioned problem by raising parameters to various powers, resulting in gradients of different magnitudes. | 78 | | [`quadratic_varying_scale.py`](quadratic_varying_scale.py) | Multi-Scale & Conditioning | Tests handling of ill-conditioned problems by creating a quadratic objective where each parameter's gradient has a different scale. | 79 | | [`quadratic_varying_target.py`](quadratic_varying_target.py) | General/Classic | A simple quadratic bowl (sphere) benchmark where the minimum is at a non-zero target vector. | 80 | | [`rastrigin.py`](rastrigin.py) | General/Classic | Implements the Rastrigin function, a classic highly multi-modal benchmark for testing global optimization capabilities. | 81 | | [`rosenbrock.py`](rosenbrock.py) | General/Classic | Implements the Rosenbrock function, a classic benchmark known for its narrow, banana-shaped valley that is difficult to navigate. | 82 | | [`saddle_point.py`](saddle_point.py) | Landscape Traversal | Tests an optimizer's ability to escape a classic saddle point (x^2 - y^2) and visualizes the optimization paths. 
| 83 | | [`scale_invariant.py`](scale_invariant.py) | Multi-Scale & Conditioning | Tests an optimizer's handling of parameters at different scales by using a logarithmic objective on parameters initialized across many orders of magnitude. | 84 | | [`sparse_gradient.py`](sparse_gradient.py) | Gradient Characteristics | Tests an optimizer's performance when gradients are sparse, which is simulated by randomly masking parameter updates. | 85 | | [`wide_linear.py`](wide_linear.py) | Multi-Scale & Conditioning | Tests optimizer performance on a model with a very wide linear layer, which can present conditioning challenges. | 86 | | [`xor_digit_rnn.py`](xor_digit_rnn.py) | Sequence & Memory (RNN/LSTM) | Tests an RNN's ability to solve the parity task (XOR sum) on a sequence of bits, requiring it to maintain a memory state. | 87 | | [`xor_digit.py`](xor_digit.py) | Sequence & Memory (RNN/LSTM) | Tests an LSTM's ability to solve the parity task (XOR sum) on a sequence of bits, a classic test of sequence memory. | 88 | | [`xor_sequence_rnn.py`](xor_sequence_rnn.py) | Sequence & Memory (RNN/LSTM) | A sequence-to-sequence task where an RNN must learn to compute the element-wise XOR of two input sequences. | 89 | | [`xor_sequence.py`](xor_sequence.py) | Sequence & Memory (RNN/LSTM) | A sequence-to-sequence task where an LSTM must learn to compute the element-wise XOR of two input sequences. | 90 | | [`xor_spot_rnn.py`](xor_spot_rnn.py) | Sequence & Memory (RNN/LSTM) | Tests an RNN's ability to learn a pointwise forget mechanism by predicting the XOR of two values at randomly marked spots in a sequence. | 91 | | [`xor_spot.py`](xor_spot.py) | Sequence & Memory (RNN/LSTM) | Tests an LSTM's ability to learn a pointwise forget mechanism by predicting the XOR of two values at randomly marked spots in a sequence. | 92 | 93 | ### Contributing 94 | 95 | This framework is designed for extensibility. You can contribute by adding new benchmarks to test optimizers against, or by implementing new optimizers to evaluate. 96 | 97 | #### Adding a New Benchmark 98 | 99 | To add a new benchmark, follow these steps: 100 | 101 | 1. **Create a new Python file** for your benchmark (e.g., `my_new_benchmark.py`). 102 | 2. **Implement the benchmark class.** Your class should follow the structure provided in [`benchmark_template.py`](benchmark_template.py:4). 103 | * **`__init__(self, device)`**: Initializes the benchmark's state. This is where the model parameters and any necessary data are created. The parameters to be optimized are stored in `self.params`. 104 | * **`__call__(self) -> torch.Tensor`**: This is the core of the benchmark, defining the objective function. It computes and returns a scalar loss tensor that the optimizer will try to minimize. 105 | * **`has_succeeded(self) -> bool`**: Defines the "win condition" for the benchmark. It returns `True` if the optimization is considered successful, typically based on the loss value falling below a predefined threshold. 106 | * **`get_params(self) -> list[torch.Tensor]`**: Returns the list of `torch.Tensor` parameters that will be optimized. 107 | 3. **Add your benchmark to `run_all_benchmarks.py`**: To include your new benchmark in the full suite, you will need to import it in [`run_all_benchmarks.py`](run_all_benchmarks.py) and add it to the list of benchmarks to be run. 108 | 4. **(Optional) Add a description to `README.md`**: Add a row to the "Benchmark Tests" table in this `README.md` to document your new benchmark. 
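
Putting the steps above together, a minimal `my_new_benchmark.py` could look like the sketch below. It follows the interface described in [`benchmark_template.py`](benchmark_template.py); the objective, dimensionality, and success threshold are placeholder choices for illustration, not values prescribed by the framework.

```python
import torch


class MyNewBenchmark:
    """Toy example: minimize the distance of a parameter vector to a fixed target."""

    def __init__(self, device: str):
        self.target = torch.full((128,), 2.0, device=device)
        self.params = [torch.nn.Parameter(torch.randn(128, device=device))]

    def __call__(self) -> torch.Tensor:
        # Scalar loss that the optimizer under test will minimize.
        return (self.params[0] - self.target).square().mean()

    def has_succeeded(self) -> bool:
        # Win condition: mean squared distance to the target falls below a threshold.
        with torch.no_grad():
            return self().item() < 1e-4

    def get_params(self) -> list[torch.Tensor]:
        return self.params
```

Once the class exists, import it in [`run_all_benchmarks.py`](run_all_benchmarks.py) and add it to the list of benchmarks there (step 3) so the full suite picks it up.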
109 | 110 | #### Adding a New Optimizer 111 | 112 | To add a new optimizer: 113 | 114 | 1. **Implement your optimizer class.** Your optimizer must follow the `torch.optim.Optimizer` interface. You can use [`optimizer_template.py`](optimizer_template.py:1) as a starting point. 115 | 2. **Make your optimizer available to the benchmark runner.** This is typically done by adding it to the `optimizer_mapping` dictionary in [`utils.py`](utils.py:1). 116 | 3. **Run the benchmarks.** You can now run your optimizer against the benchmarks using the `run_all_benchmarks.py` script. For example: 117 | ```bash 118 | python run_all_benchmarks.py --opt YourNewOptimizerName --difficulties easy 119 | -------------------------------------------------------------------------------- /benchmark/adversarial_gradient.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | 8 | from benchmark.utils import param_norm_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | 14 | configs = { 15 | "trivial": {"frequency": 1000}, 16 | "easy": {"frequency": 100}, 17 | "medium": {"frequency": 10}, 18 | "hard": {"frequency": 7}, 19 | "extreme": {"frequency": 4}, 20 | "nightmare": {"frequency": 2}, 21 | } 22 | 23 | 24 | class Model(nn.Module): 25 | def __init__(self, frequency, size=1024): 26 | super().__init__() 27 | self.param = nn.Parameter(torch.randn(size)) 28 | self.register_buffer("step", torch.zeros(1)) 29 | self.frequency = 1 / frequency * 1.1 # to avoid repeating 30 | 31 | def forward(self): 32 | """Test optimizer's robustness to adversarial gradient patterns.""" 33 | self.step += 1 34 | # Create an oscillating adversarial component 35 | direction = torch.sin(self.step * torch.pi * self.frequency) 36 | # Main objective plus adversarial component 37 | return self.param.square().mean() + direction * self.param.mean() 38 | 39 | 40 | @app.command() 41 | def main( 42 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 43 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 44 | steps: int = 100, 45 | weight_decay: float = 0, 46 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 47 | trials: int = 100, 48 | win_condition_multiplier: float = 1.0, 49 | config: Optional[str] = None, 50 | ): 51 | frequency = configs.get(config, {}).get("frequency", 10) 52 | dtype = [getattr(torch, d) for d in dtype] 53 | model = Model(frequency).cuda().double() 54 | 55 | def data(): 56 | return None, None 57 | 58 | # More lenient condition due to adversarial component 59 | trial( 60 | model, 61 | data, 62 | None, 63 | param_norm_win_condition(win_condition_multiplier * 1e-3, 0), 64 | steps, 65 | opt[0], 66 | dtype[0], 67 | 1, 68 | 1, 69 | weight_decay, 70 | method[0], 71 | 1, 72 | 1, 73 | failure_threshold=7, 74 | base_lr=1e-3, 75 | trials=trials, 76 | ) # More attempts for adversarial case 77 | 78 | 79 | if __name__ == "__main__": 80 | app() 81 | -------------------------------------------------------------------------------- /benchmark/batch_size_scaling.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | from typing import List, Optional 4 | 5 | import torch 6 | import torch.backends.opt_einsum 7 | import typer 8 | from torch import 
nn 9 | 10 | from benchmark.utils import param_norm_win_condition, trial 11 | from heavyball.utils import set_torch 12 | 13 | app = typer.Typer(pretty_exceptions_enable=False) 14 | set_torch() 15 | 16 | configs = { 17 | "trivial": {"max_batch": 65536}, 18 | "easy": {"max_batch": 8192}, 19 | "medium": {"max_batch": 1024}, 20 | "hard": {"max_batch": 128}, 21 | "extreme": {"max_batch": 16}, 22 | "nightmare": {"max_batch": 2}, 23 | } 24 | 25 | 26 | class Model(nn.Module): 27 | def __init__(self, max_batch: int, size=1024): 28 | super().__init__() 29 | self.param = nn.Parameter(torch.randn(size)) 30 | self.max_batch = max_batch 31 | self.rng = random.Random(0x1238192) 32 | 33 | def forward(self): 34 | """Test optimizer's ability to handle different batch sizes and noise scales.""" 35 | generator = torch.Generator(device=self.param.device).manual_seed(self.rng.randint(0, 2**31)) 36 | noise = torch.randn(self.param.shape, generator=generator, device=self.param.device) 37 | scale = self.param.norm() / (noise.norm() + 1e-6) 38 | batch_scale = self.max_batch ** (self.rng.random() / 2) # sqrt of random uniform between 1 and max_batch 39 | noise *= scale.detach() / math.sqrt(batch_scale) 40 | return (self.param + noise).square().mean() 41 | 42 | 43 | @app.command() 44 | def main( 45 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 46 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 47 | steps: int = 100, 48 | weight_decay: float = 0, 49 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 50 | trials: int = 100, 51 | win_condition_multiplier: float = 1.0, 52 | config: Optional[str] = None, 53 | ): 54 | max_batch = configs.get(config, {}).get("max_batch", 256) 55 | dtype = [getattr(torch, d) for d in dtype] 56 | model = Model(max_batch).cuda().double() 57 | 58 | def data(): 59 | return None, None 60 | 61 | # Use a more lenient win condition since we have inherent noise 62 | trial( 63 | model, 64 | data, 65 | None, 66 | param_norm_win_condition(win_condition_multiplier * 1e-8, 0), 67 | steps, 68 | opt[0], 69 | dtype[0], 70 | 1, 71 | 1, 72 | weight_decay, 73 | method[0], 74 | 1, 75 | 1, 76 | failure_threshold=5, 77 | base_lr=1e-3, 78 | trials=trials, 79 | ) 80 | 81 | 82 | if __name__ == "__main__": 83 | app() 84 | -------------------------------------------------------------------------------- /benchmark/beale.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import random 3 | from typing import List, Optional 4 | 5 | import matplotlib.colors 6 | import torch 7 | import torch.backends.opt_einsum 8 | import typer 9 | from torch import nn 10 | 11 | from benchmark.utils import Plotter, SkipConfig, loss_win_condition, trial 12 | from heavyball.utils import set_torch 13 | 14 | app = typer.Typer(pretty_exceptions_enable=False) 15 | set_torch() 16 | 17 | 18 | def objective(x, y): 19 | x = x + 3 20 | y = y + 0.5 21 | return (1.5 - x + x * y) ** 2 + (2.25 - x + x * y**2) ** 2 + (2.625 - x + x * y**3) ** 2 22 | 23 | 24 | class Model(nn.Module): 25 | def __init__(self, x): 26 | super().__init__() 27 | self.param = nn.Parameter(torch.tensor(x).float()) 28 | 29 | def forward(self): 30 | return objective(*self.param) 31 | 32 | 33 | @app.command() 34 | def main( 35 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 36 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 37 | steps: int = 100, 38 | weight_decay: float = 
0, 39 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 40 | show_image: bool = False, 41 | trials: int = 100, 42 | win_condition_multiplier: float = 1.0, 43 | config: Optional[str] = None, 44 | ): 45 | if config is not None and config != "trivial": 46 | raise SkipConfig("'config' must be 'trivial'.") 47 | dtype = [getattr(torch, d) for d in dtype] 48 | coords = (-7, -4) 49 | 50 | # Clean up old plots 51 | for path in pathlib.Path(".").glob("beale.png"): 52 | path.unlink() 53 | 54 | colors = list(matplotlib.colors.TABLEAU_COLORS.values()) 55 | rng = random.Random(0x1239121) 56 | rng.shuffle(colors) 57 | 58 | if show_image: 59 | model = Plotter(Model(coords), x_limits=(-8, 2), y_limits=(-8, 2), should_normalize=True) 60 | else: 61 | model = Model(coords) 62 | model.double() 63 | 64 | def data(): 65 | return None, None 66 | 67 | model = trial( 68 | model, 69 | data, 70 | None, 71 | loss_win_condition(win_condition_multiplier * 1e-8 * (not show_image)), 72 | steps, 73 | opt[0], 74 | dtype[0], 75 | 1, 76 | 1, 77 | weight_decay, 78 | method[0], 79 | 1, 80 | 1, 81 | base_lr=1e-4, 82 | trials=trials, 83 | return_best=show_image, 84 | ) 85 | 86 | if not show_image: 87 | return 88 | 89 | model.plot(save_path="beale.png") 90 | 91 | 92 | if __name__ == "__main__": 93 | app() 94 | -------------------------------------------------------------------------------- /benchmark/char_rnn.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List 3 | 4 | import torch 5 | import torch.backends.opt_einsum 6 | import torch.nn as nn 7 | import typer 8 | from torch.nn import functional as F 9 | 10 | from benchmark.utils import loss_win_condition, trial 11 | from heavyball.utils import set_torch 12 | 13 | app = typer.Typer(pretty_exceptions_enable=False) 14 | set_torch() 15 | 16 | 17 | class Take0(nn.Module): 18 | def forward(self, x): 19 | return x[0] 20 | 21 | 22 | class Model(nn.Module): 23 | def __init__(self, features: int, sequence: int): 24 | super().__init__() 25 | self.sequence = sequence 26 | self.net = nn.Sequential( 27 | nn.Embedding(256, features), 28 | nn.LSTM(features, features, 1, batch_first=True), # Removed dropout since num_layers=1 29 | Take0(), 30 | nn.Linear(features, 256), 31 | ) 32 | 33 | def forward(self, inp): 34 | return self.net(inp) 35 | 36 | 37 | @app.command() 38 | def main( 39 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 40 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 41 | features: int = 512, 42 | sequence: int = 256, 43 | batch: int = 16, 44 | steps: int = 100, 45 | weight_decay: float = 0, 46 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 47 | win_condition_multiplier: float = 1.0, 48 | trials: int = 10, 49 | ): 50 | dtype = [getattr(torch, d) for d in dtype] 51 | model = Model(features, sequence).cuda() 52 | 53 | # Load text data 54 | benchmark_dir = Path(__file__).parent 55 | with open(benchmark_dir / "shakespeare.txt", "rb") as f: 56 | text = f.read() 57 | chars = torch.frombuffer(text, dtype=torch.uint8).cuda().long() 58 | 59 | # Create holdout set 60 | chars = chars[(sequence + 1) * batch :] 61 | offsets = torch.arange(0, sequence + 1, device="cuda").repeat(batch, 1) 62 | 63 | def data(): 64 | batch_offsets = torch.randint(0, len(chars) - sequence - 1, (batch,), device="cuda") 65 | batch_offsets = batch_offsets[:, None] + offsets 66 | batch_chars = chars[batch_offsets] 
67 | batch_chars = batch_chars.view(batch, sequence + 1) 68 | src = batch_chars[:, :-1] 69 | tgt = batch_chars[:, 1:] 70 | return src, tgt 71 | 72 | trial( 73 | model, 74 | data, 75 | F.cross_entropy, 76 | loss_win_condition(win_condition_multiplier * 2.0), 77 | steps, 78 | opt[0], 79 | dtype[0], 80 | features, 81 | batch, 82 | weight_decay, 83 | method[0], 84 | sequence, 85 | 1, 86 | failure_threshold=10, 87 | base_lr=1e-3, 88 | trials=trials, 89 | ) 90 | 91 | 92 | if __name__ == "__main__": 93 | app() 94 | -------------------------------------------------------------------------------- /benchmark/constrained_optimization.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from typing import List, Optional 3 | 4 | import torch 5 | import torch.backends.opt_einsum 6 | import typer 7 | from torch import nn 8 | 9 | from benchmark.utils import trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | 15 | configs = { 16 | "trivial": {"penalty": 1e1}, 17 | "easy": {"penalty": 1e2}, 18 | "medium": {"penalty": 1e4}, 19 | "hard": {"size": 1e6}, 20 | "extreme": {"penalty": 1e8}, 21 | "nightmare": {"penalty": 1e10}, 22 | } 23 | 24 | # Objective: Minimize (x-2)^2 subject to x <= 1 25 | # Implemented using a penalty: (x-2)^2 + penalty * max(0, x - 1) 26 | TARGET_X = 1.0 27 | TOLERANCE = 1e-3 28 | 29 | 30 | def objective(x, penalty): 31 | """Objective function with a penalty for violating the constraint x <= 1.""" 32 | return (x - 2.0) ** 2 + penalty * torch.relu(x - TARGET_X) 33 | 34 | 35 | class Model(nn.Module): 36 | def __init__(self, penalty): 37 | super().__init__() 38 | # Using a tensor with requires_grad=True directly as the parameter 39 | self.param = nn.Parameter(torch.zeros((16,))) 40 | self.penalty = penalty 41 | 42 | def forward(self): 43 | return objective(self.param, self.penalty).mean() 44 | 45 | 46 | def win_condition(model, loss): 47 | with torch.no_grad(): 48 | success = ((model.param - TARGET_X).abs() < TOLERANCE).all().item() 49 | return success, {} 50 | 51 | 52 | @app.command() 53 | def main( 54 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 55 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 56 | steps: int = 200, 57 | # Increased steps slightly 58 | weight_decay: float = 0, 59 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 60 | trials: int = 50, # Reduced trials slightly for faster testing 61 | win_condition_multiplier: float = 1.0, # Not used directly, but kept for consistency 62 | config: Optional[str] = None, 63 | ): 64 | penalty = configs.get(config, {}).get("penalty", 1e6) 65 | dtype = [getattr(torch, d) for d in dtype] 66 | 67 | # Clean up old plots if any (though this benchmark doesn't plot) 68 | for path in pathlib.Path(".").glob("constrained_optimization*.png"): 69 | path.unlink() 70 | 71 | model = Model(penalty) 72 | model.double() # Use double for precision if needed 73 | 74 | # No external data needed for this simple objective 75 | def data(): 76 | return None, None 77 | 78 | # The loss is the objective value itself 79 | loss_fn = None 80 | 81 | trial( 82 | model, 83 | data, 84 | loss_fn, 85 | win_condition, 86 | steps, 87 | opt[0], 88 | dtype[0], 89 | 1, # size (not relevant here) 90 | 1, # batch (not relevant here) 91 | weight_decay, 92 | method[0], 93 | 1, # length (not relevant here) 94 | 1, # depth (not relevant here) 95 | 
failure_threshold=3, 96 | base_lr=1e-3, # Default base LR, hyperopt will search 97 | trials=trials, 98 | group=32, # Smaller group size might be better for simple problems 99 | ) 100 | 101 | 102 | if __name__ == "__main__": 103 | app() 104 | -------------------------------------------------------------------------------- /benchmark/create_hiker_analogy_diagram.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | 7 | def create_hiker_analogy_diagram(): 8 | """ 9 | Generates and saves a multi-panel diagram illustrating key optimization challenges. 10 | 11 | The diagram includes separate panels for: 12 | - A saddle point 13 | - A narrow ravine (representing ill-conditioning) 14 | - A sharp cliff (representing a discontinuous gradient) 15 | """ 16 | # --- Create the plot --- 17 | fig, axes = plt.subplots(1, 3, figsize=(18, 6), subplot_kw={"aspect": "equal"}) 18 | fig.suptitle( 19 | "Challenges in Optimization Landscapes: The Hiker's Analogy", 20 | fontsize=20, 21 | y=0.98, 22 | ) 23 | 24 | # --- Panel 1: Saddle Point --- 25 | ax1 = axes[0] 26 | x1 = np.linspace(-3, 3, 300) 27 | y1 = np.linspace(-3, 3, 300) 28 | X1, Y1 = np.meshgrid(x1, y1) 29 | Z1 = X1**3 + Y1**3 30 | levels1 = np.linspace(np.min(Z1), np.max(Z1), 30) 31 | ax1.contourf(X1, Y1, Z1, levels=levels1, cmap="viridis") 32 | ax1.contour(X1, Y1, Z1, levels=levels1, colors="black", linewidths=0.5, alpha=0.5) 33 | ax1.set_title("Saddle Point", fontsize=16) 34 | ax1.text(0, 0, "★", color="red", fontsize=20, ha="center", va="center") 35 | 36 | # --- Panel 2: Ravine (Ill-Conditioning) --- 37 | ax2 = axes[1] 38 | x2 = np.linspace(-3, 3, 300) 39 | y2 = np.linspace(-3, 3, 300) 40 | X2, Y2 = np.meshgrid(x2, y2) 41 | Z2 = 0.1 * X2**2 + 10 * Y2**2 # Rosenbrock-like narrow valley 42 | levels2 = np.linspace(np.min(Z2), np.max(Z2), 30) 43 | ax2.contourf(X2, Y2, Z2, levels=levels2, cmap="magma") 44 | ax2.contour(X2, Y2, Z2, levels=levels2, colors="black", linewidths=0.5, alpha=0.5) 45 | ax2.set_title("Ravine (Ill-Conditioning)", fontsize=16) 46 | ax2.text(0, 0, "★", color="cyan", fontsize=20, ha="center", va="center") 47 | 48 | # --- Panel 3: Cliff (Discontinuous Gradient) --- 49 | ax3 = axes[2] 50 | x3 = np.linspace(-3, 3, 300) 51 | y3 = np.linspace(-3, 3, 300) 52 | X3, Y3 = np.meshgrid(x3, y3) 53 | Z3 = -Y3 # A simple slope 54 | Z3[X3 > 0] += 3 # Create a sharp cliff 55 | levels3 = np.linspace(np.min(Z3), np.max(Z3), 30) 56 | ax3.contourf(X3, Y3, Z3, levels=levels3, cmap="plasma") 57 | ax3.contour(X3, Y3, Z3, levels=levels3, colors="black", linewidths=0.5, alpha=0.5) 58 | ax3.set_title("Cliff (Discontinuous Gradient)", fontsize=16) 59 | ax3.plot([0, 0], [-3, 3], "r--", lw=2) # Mark the cliff edge 60 | 61 | # --- Style all plots --- 62 | for ax in axes: 63 | ax.set_xticks([]) 64 | ax.set_yticks([]) 65 | ax.spines["top"].set_visible(False) 66 | ax.spines["right"].set_visible(False) 67 | ax.spines["bottom"].set_visible(False) 68 | ax.spines["left"].set_visible(False) 69 | 70 | fig.tight_layout(rect=[0, 0, 1, 0.95]) # Adjust layout for suptitle 71 | 72 | # --- Save the figure --- 73 | output_path = "docs/assets/hiker_analogy_diagram.png" 74 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 75 | plt.savefig(output_path, bbox_inches="tight", pad_inches=0.1, dpi=300) 76 | print(f"Multi-panel diagram saved to {output_path}") 77 | 78 | 79 | if __name__ == "__main__": 80 | create_hiker_analogy_diagram() 81 | 
-------------------------------------------------------------------------------- /benchmark/discontinuous_gradient.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | 8 | from benchmark.utils import SkipConfig, param_norm_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | 14 | 15 | def objective(x): 16 | """Tests optimizer robustness to non-smooth landscapes with discontinuous gradients.""" 17 | return torch.where(x < 0, x**2, 2 * x).mean() # Discontinuous gradient at x=0 18 | 19 | 20 | class Model(nn.Module): 21 | def __init__(self, size=1024): 22 | super().__init__() 23 | self.param = nn.Parameter(torch.randn(size)) 24 | 25 | def forward(self): 26 | return objective(self.param) 27 | 28 | 29 | @app.command() 30 | def main( 31 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 32 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 33 | steps: int = 100, 34 | weight_decay: float = 0, 35 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 36 | trials: int = 100, 37 | win_condition_multiplier: float = 1.0, 38 | config: Optional[str] = None, 39 | ): 40 | if config is not None and config != "trivial": 41 | raise SkipConfig("'config' must be 'trivial'.") 42 | dtype = [getattr(torch, d) for d in dtype] 43 | model = Model().cuda().double() 44 | 45 | def data(): 46 | return None, None 47 | 48 | trial( 49 | model, 50 | data, 51 | None, 52 | param_norm_win_condition(win_condition_multiplier * 1e-4, 0), 53 | steps, 54 | opt[0], 55 | dtype[0], 56 | 1, 57 | 1, 58 | weight_decay, 59 | method[0], 60 | 1, 61 | 1, 62 | failure_threshold=3, 63 | base_lr=1e-3, 64 | trials=trials, 65 | ) 66 | 67 | 68 | if __name__ == "__main__": 69 | app() 70 | -------------------------------------------------------------------------------- /benchmark/dynamic_landscape.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests optimizer's ability to track moving targets. 3 | 4 | This benchmark simulates a dynamic loss landscape where the optimal parameters 5 | continuously shift over time. This tests the optimizer's ability to: 6 | 1. Track moving targets 7 | 2. Adapt to non-stationary objectives 8 | 3. 
Handle continuous parameter updates 9 | """ 10 | 11 | import itertools 12 | import math 13 | from typing import List, Optional 14 | 15 | import torch 16 | import torch.nn as nn 17 | import typer 18 | 19 | from benchmark.utils import loss_win_condition, trial 20 | from heavyball.utils import set_torch 21 | 22 | app = typer.Typer(pretty_exceptions_enable=False) 23 | set_torch() 24 | configs = { 25 | "trivial": {"frequency": 1000}, 26 | "easy": {"frequency": 100}, 27 | "medium": {"frequency": 20}, 28 | "hard": {"frequency": 10}, 29 | "extreme": {"frequency": 6}, 30 | "nightmare": {"frequency": 4}, 31 | } 32 | 33 | 34 | class ShiftingSphere(nn.Module): 35 | def __init__(self, dim, frequency): 36 | super().__init__() 37 | self.param = nn.Parameter(torch.randn(dim)) 38 | self.phase = 0 39 | self.frequency = 1 / frequency * 1.1 # so that we don't repeat numbers 40 | 41 | def forward(self): 42 | self.phase += self.frequency 43 | target = torch.linspace(0, 2 * math.pi, len(self.param), device=self.param.device, dtype=self.param.dtype) 44 | target = torch.sin(target + self.phase) 45 | return (self.param - target).square().mean() 46 | 47 | 48 | @app.command() 49 | def main( 50 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 51 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 52 | dim: int = 16384, 53 | steps: int = 500, 54 | weight_decay: float = 0, 55 | opt: List[str] = typer.Option(["adamw"], help="Optimizers to use"), 56 | win_condition_multiplier: float = 1.0, 57 | trials: int = 3, 58 | config: Optional[str] = None, 59 | ): 60 | """Run dynamic landscape benchmark with specified parameters.""" 61 | frequency = configs.get(config, {}).get("frequency", 0.1) 62 | dtype = [getattr(torch, d) for d in dtype] 63 | 64 | for args in itertools.product(method, dtype, [dim], opt, [weight_decay]): 65 | m, d, dim, o, wd = args 66 | 67 | model = ShiftingSphere(dim, frequency) 68 | 69 | def data(): 70 | return None, None 71 | 72 | # Win condition: average squared error should be small (parameters close to target) 73 | trial( 74 | model, 75 | data, 76 | None, 77 | loss_win_condition(0.01 * win_condition_multiplier), 78 | steps, 79 | [o], 80 | [d], 81 | 1, 82 | 1, 83 | wd, 84 | m, 85 | 1, 86 | 1, 87 | base_lr=0.1, 88 | trials=trials, 89 | ) 90 | 91 | 92 | if __name__ == "__main__": 93 | app() 94 | -------------------------------------------------------------------------------- /benchmark/exploding_gradient.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests optimizer's ability to handle exploding gradients. 3 | 4 | This benchmark creates a scenario where gradients can grow exponentially, 5 | testing the optimizer's: 6 | 1. Gradient clipping/scaling mechanisms 7 | 2. Numerical stability 8 | 3. 
Ability to make progress despite extreme gradient values 9 | """ 10 | 11 | import itertools 12 | from typing import List, Optional 13 | 14 | import torch 15 | import torch.nn as nn 16 | import typer 17 | 18 | from benchmark.utils import param_norm_win_condition, trial 19 | from heavyball.utils import set_torch 20 | 21 | app = typer.Typer(pretty_exceptions_enable=False) 22 | set_torch() 23 | configs = { 24 | "trivial": {"scale": 1}, 25 | "easy": {"scale": 2}, 26 | "medium": {"scale": 4}, 27 | "hard": {"scale": 8}, 28 | "extreme": {"scale": 12}, 29 | "nightmare": {"scale": 16}, 30 | } 31 | 32 | 33 | class ExplodingGradient(nn.Module): 34 | def __init__(self, scale, dim): 35 | super().__init__() 36 | self.param = nn.Parameter(torch.randn(dim)) 37 | self.scale = scale # Controls how quickly gradients grow 38 | 39 | def forward(self): 40 | # Creates exponentially growing gradients 41 | # Gradient will be scale * exp(|param|) * sign(param) 42 | return torch.exp(self.scale * torch.abs(self.param)).mean() 43 | 44 | 45 | @app.command() 46 | def main( 47 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 48 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 49 | dim: int = 512, 50 | steps: int = 500, 51 | weight_decay: float = 0, 52 | opt: List[str] = typer.Option(["adamw"], help="Optimizers to use"), 53 | win_condition_multiplier: float = 1.0, 54 | trials: int = 3, 55 | config: Optional[str] = None, 56 | ): 57 | scale = configs.get(config, {}).get("scale", 2) 58 | """Run exploding gradient benchmark with specified parameters.""" 59 | dtype = [getattr(torch, d) for d in dtype] 60 | 61 | for args in itertools.product(method, dtype, [dim], opt, [weight_decay]): 62 | m, d, dim, o, wd = args 63 | 64 | model = ExplodingGradient(dim, scale) 65 | 66 | def data(): 67 | return None, None 68 | 69 | # Win condition: loss should be close to 1.0 (exp(0) = 1) 70 | # Using 1.1 as threshold since perfect convergence is hard 71 | trial( 72 | model, 73 | data, 74 | None, 75 | param_norm_win_condition(0.01 * win_condition_multiplier, 0), 76 | steps, 77 | [o], 78 | [d], 79 | 1, 80 | 1, 81 | wd, 82 | m, 83 | 1, 84 | 1, 85 | base_lr=0.001, # Lower learning rate due to large gradients 86 | trials=trials, 87 | ) 88 | 89 | 90 | if __name__ == "__main__": 91 | app() 92 | -------------------------------------------------------------------------------- /benchmark/format_results.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | import scipy 7 | import typer 8 | from matplotlib.cm import ScalarMappable 9 | from matplotlib.colors import Normalize 10 | from matplotlib.patches import Rectangle 11 | 12 | 13 | def parse_loss(loss_str): 14 | loss_str = loss_str.strip() 15 | if not loss_str or loss_str.lower() == "nan": 16 | return float("nan") 17 | if loss_str.lower() == "inf": 18 | return float("inf") 19 | try: 20 | return ( 21 | float(loss_str) 22 | if "e" not in loss_str 23 | else float(loss_str.split("e")[0]) * 10 ** float(loss_str.split("e")[1]) 24 | ) 25 | except ValueError: 26 | return float("nan") 27 | 28 | 29 | def process_str(x, truthy): 30 | if x == "No": 31 | return "" 32 | if x == "Yes": 33 | return f"{truthy}-" 34 | return f"{x}-" 35 | 36 | 37 | def read_benchmark_results(file_path): 38 | with open(file_path, "r") as f: 39 | content = f.read() 40 | details = re.search(r"## Details\n\n(.*?)(?=\n##|\n\Z)", content, re.DOTALL | 
re.IGNORECASE) 41 | if not details: 42 | raise ValueError("Details section not found.") 43 | table = details.group(1).strip() 44 | lines = re.search(r"\|:?-+:(.*?)\|\n(.*)", table, re.DOTALL).group(2).strip().split("\n") 45 | data = [] 46 | for i, line in enumerate(lines): 47 | if not line.strip() or line.startswith("|---"): 48 | continue 49 | parts = [p.strip() for p in line.split("|")[1:-1]] 50 | if len(parts) < 8: 51 | continue 52 | try: 53 | caution = process_str(parts[2], "cautious") 54 | mars = process_str(parts[3], "mars") 55 | optimizer = f"{caution}{mars}{parts[1]}" 56 | optimizer = optimizer.replace("Foreach", "").replace("Cached", "").strip() 57 | data.append({ 58 | "benchmark": parts[0], 59 | "optimizer": optimizer, 60 | "success": parts[4] == "✓", 61 | "runtime": float(parts[5].replace("s", "")) if parts[5] else float("nan"), 62 | "loss": parse_loss(parts[6]) if parts[6] else float("nan"), 63 | "attempts": int(parts[7]) if parts[7].isdigit() else 0, 64 | }) 65 | except (IndexError, ValueError): 66 | continue 67 | return pd.DataFrame(data) 68 | 69 | 70 | def create_success_matrix(df): 71 | if df.empty: 72 | return pd.DataFrame() 73 | benchmarks = sorted(df["benchmark"].unique()) 74 | optimizers = sorted(df["optimizer"].unique()) 75 | success_matrix = pd.DataFrame(0, index=benchmarks, columns=optimizers, dtype=int) 76 | for _, row in df.iterrows(): 77 | if row["success"] and row["benchmark"] in success_matrix.index and row["optimizer"] in success_matrix.columns: 78 | success_matrix.loc[row["benchmark"], row["optimizer"]] = 1 79 | base_tasks = sorted(set(b.split("-")[0] for b in benchmarks)) 80 | success_total_matrix = pd.DataFrame(0, index=base_tasks, columns=optimizers, dtype=int) 81 | for benchmark in success_matrix.index: 82 | base_task = benchmark.split("-")[0] 83 | if base_task in success_total_matrix.index: 84 | success_total_matrix.loc[base_task] += success_matrix.loc[benchmark] 85 | return success_total_matrix 86 | 87 | 88 | def normalize_matrix_by_row_max(matrix): 89 | max_in_row = matrix.max(axis=1) 90 | max_in_row[max_in_row == 0] = 1 91 | return matrix.div(max_in_row, axis=0) * 100 92 | 93 | 94 | def create_visual_matrix_normalized(success_total_matrix): 95 | if success_total_matrix.empty: 96 | return None 97 | tasks_to_keep = success_total_matrix.sum(axis=1) > 0 98 | filtered_matrix = success_total_matrix[tasks_to_keep].copy() 99 | if filtered_matrix.empty: 100 | return None 101 | filtered_matrix[:] = scipy.stats.rankdata(filtered_matrix, axis=1, method="dense") 102 | normalized_matrix = normalize_matrix_by_row_max(filtered_matrix) 103 | optimizer_means = normalized_matrix.mean(axis=0) 104 | task_means = normalized_matrix.mean(axis=1) 105 | overall_mean = optimizer_means.mean() 106 | plot_matrix: pd.DataFrame = normalized_matrix.copy() 107 | plot_matrix.loc["Avg. Optimizer"] = optimizer_means 108 | 109 | # weight of 0.5, as "jack of all trades, master of none" is better than "perfect at xor but awful at delay" 110 | optimizer_score = (normalized_matrix**0.5).mean(axis=0) 111 | optimizer_indices = np.argsort(-optimizer_score.to_numpy()) 112 | plot_matrix = plot_matrix.iloc[:, optimizer_indices] 113 | 114 | full_task_means = pd.concat([task_means, pd.Series([overall_mean], index=["Avg. Optimizer"])]) 115 | plot_matrix["Avg. 
Task"] = full_task_means 116 | plot_tasks = plot_matrix.index 117 | plot_optimizers = plot_matrix.columns 118 | 119 | plt.style.use("seaborn-v0_8-whitegrid") 120 | fig, ax = plt.subplots( 121 | figsize=(max(14, len(plot_optimizers) * 0.8), max(10, len(plot_tasks) * 0.6)), facecolor="white" 122 | ) 123 | cmap = plt.cm.Blues 124 | norm = Normalize(vmin=0, vmax=100) 125 | mapper = ScalarMappable(norm=norm, cmap=cmap) 126 | 127 | for i, task in enumerate(plot_tasks): 128 | for j, optimizer in enumerate(plot_optimizers): 129 | value = plot_matrix.loc[task, optimizer] 130 | original_count = 0 131 | if task != "Avg. Optimizer" and optimizer != "Avg. Task": 132 | if task in success_total_matrix.index and optimizer in success_total_matrix.columns: 133 | original_count = success_total_matrix.loc[task, optimizer] 134 | is_summary = task == "Avg. Optimizer" or optimizer == "Avg. Task" 135 | face_color = (0.85, 0.85, 0.85, 0.7) if not is_summary and original_count == 0 else mapper.to_rgba(value) 136 | edge_color = "#666666" if is_summary else "#AAAAAA" 137 | edge_width = 1.5 if is_summary else 0.5 138 | rect = Rectangle((j - 0.5, i - 0.5), 1, 1, facecolor=face_color, edgecolor=edge_color, linewidth=edge_width) 139 | ax.add_patch(rect) 140 | # brightness = sum(face_color[:3]) * 0.333 141 | # text_color = 'white' if brightness < 0.6 else 'black' 142 | # ax.text(j, i, f"{value:.0f}", ha='center', va='center', color=text_color, fontsize=12, fontweight='bold') 143 | 144 | ax.axhline(len(plot_tasks) - 1.5, color="#333333", linewidth=2, linestyle="-") 145 | ax.axvline(len(plot_optimizers) - 1.5, color="#333333", linewidth=2, linestyle="-") 146 | ax.set_xticks(np.arange(len(plot_optimizers))) 147 | ax.set_yticks(np.arange(len(plot_tasks))) 148 | ax.set_xticklabels(plot_optimizers, rotation=45, ha="right", fontsize=14, fontweight="bold", color="#333333") 149 | ax.set_yticklabels(plot_tasks, fontsize=14, fontweight="bold", color="#333333") 150 | ax.set_xlabel("Optimizer", fontsize=16, fontweight="bold", labelpad=15, color="#333333") 151 | ax.set_ylabel("Task", fontsize=16, fontweight="bold", labelpad=15, color="#333333") 152 | ax.set_title( 153 | "Success Rate by Task and Optimizer\n(100% = Best Performance for Each Task)", 154 | fontsize=18, 155 | fontweight="bold", 156 | pad=20, 157 | color="#333333", 158 | ) 159 | 160 | cbar = fig.colorbar(mapper, ax=ax, pad=0.02, aspect=30, shrink=0.8) 161 | cbar.set_label("Success Rate (%)", rotation=270, labelpad=20, fontsize=14, fontweight="bold", color="#333333") 162 | cbar.ax.tick_params(labelsize=12, color="#333333", labelcolor="#333333") 163 | 164 | ax.set_facecolor("#F5F5F5") 165 | fig.patch.set_facecolor("white") 166 | for spine in ax.spines.values(): 167 | spine.set_visible(True) 168 | spine.set_color("#333333") 169 | spine.set_linewidth(1.5) 170 | 171 | plt.tight_layout(rect=[0.02, 0.05, 0.98, 0.95]) 172 | return fig 173 | 174 | 175 | app = typer.Typer(pretty_exceptions_enable=False) 176 | 177 | 178 | @app.command() 179 | def main(file: str = typer.Argument("benchmark_results.md")): 180 | try: 181 | df = read_benchmark_results(file) 182 | if df.empty: 183 | print("No data loaded from benchmark file.") 184 | return 185 | except Exception as e: 186 | print(f"Error reading benchmark file: {e}") 187 | return 188 | success_total_matrix = create_success_matrix(df) 189 | if success_total_matrix.empty: 190 | print("No successful runs found.") 191 | return 192 | fig = create_visual_matrix_normalized(success_total_matrix) 193 | if fig is None: 194 | print("Plot generation 
failed.") 195 | return 196 | plt.savefig("benchmark_matrix.png", dpi=300, bbox_inches="tight", facecolor="white", pad_inches=0.3) 197 | print("Saved heatmap to: benchmark_heatmap.png") 198 | plt.close(fig) 199 | 200 | 201 | if __name__ == "__main__": 202 | app() 203 | -------------------------------------------------------------------------------- /benchmark/generate_animations.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import math 3 | import pathlib 4 | from copy import deepcopy 5 | from typing import List 6 | 7 | import matplotlib.animation as animation 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import torch 11 | import typer 12 | from utils import get_optim as get_optimizer 13 | 14 | app = typer.Typer(pretty_exceptions_enable=False) 15 | 16 | 17 | @app.command() 18 | def main( 19 | benchmark_name: str = typer.Option(..., help="Name of the benchmark to run (e.g., 'saddle_point')."), 20 | optimizer_names: List[str] = typer.Option( 21 | ..., "--optimizer-name", help="Name of an optimizer to include in the comparison." 22 | ), 23 | output_file: str = typer.Option(..., help="Path to save the generated GIF."), 24 | steps: int = 100, 25 | ): 26 | """ 27 | Generates an animated GIF of optimizer paths on a given benchmark. 28 | """ 29 | try: 30 | benchmark_module = importlib.import_module(f"benchmark.{benchmark_name}") 31 | except ImportError: 32 | print(f"Error: Benchmark '{benchmark_name}' not found.") 33 | raise typer.Exit(1) 34 | 35 | if benchmark_name == "saddle_point": 36 | model = benchmark_module.Model(1, 0) 37 | 38 | def objective_fn(*xs): 39 | return benchmark_module.objective(*xs, power=model.power) 40 | 41 | x_limits, y_limits = (-2, 2), (-2, 2) 42 | elif benchmark_name == "rosenbrock": 43 | coords = (-7, 4) 44 | model = benchmark_module.Model(coords) 45 | objective_fn = benchmark_module.objective 46 | x_limits, y_limits = (-10, 10), (-10, 10) 47 | elif benchmark_name == "beale": 48 | coords = (-7, -4) 49 | model = benchmark_module.Model(coords) 50 | objective_fn = benchmark_module.objective 51 | x_limits, y_limits = (-8, 2), (-8, 2) 52 | else: 53 | print(f"Error: Benchmark '{benchmark_name}' is not supported for animation yet.") 54 | raise typer.Exit(1) 55 | 56 | trajectories = [] 57 | models = [] 58 | for optimizer_name in optimizer_names: 59 | print(f"Tuning LR for {optimizer_name}...") 60 | best_lr = None 61 | min_loss = float("inf") 62 | tuning_steps = math.ceil(steps / 3) 63 | lr_candidates = np.logspace(-8, 0, 50) 64 | 65 | for test_lr in lr_candidates: 66 | temp_model = deepcopy(model) 67 | temp_optimizer = get_optimizer(optimizer_name, temp_model.parameters(), lr=test_lr) 68 | 69 | def _closure(): 70 | loss = temp_model() 71 | loss.backward() 72 | 73 | for _ in range(tuning_steps): 74 | temp_optimizer.zero_grad() 75 | temp_optimizer.step(_closure) 76 | 77 | final_loss = temp_model().item() 78 | if final_loss < min_loss: 79 | min_loss = final_loss 80 | best_lr = test_lr 81 | 82 | print(f" > Best LR for {optimizer_name}: {best_lr:.5f}") 83 | 84 | m = deepcopy(model) 85 | models.append(m) 86 | optimizer = get_optimizer(optimizer_name, m.parameters(), lr=best_lr) 87 | 88 | trajectory = [m.param.detach().clone()] 89 | 90 | def _closure(): 91 | loss = m() 92 | loss.backward() 93 | 94 | for _ in range(steps): 95 | optimizer.zero_grad() 96 | optimizer.step(_closure) 97 | trajectory.append(m.param.detach().clone()) 98 | 99 | trajectories.append(list(torch.stack(trajectory).cpu().numpy())) 100 | 
print(f" > Final position for {optimizer_name}: {trajectories[-1][-1]}") 101 | 102 | target_trajectory = None 103 | if benchmark_name == "dynamic_targets": 104 | target_trajectory = models[0].get_target_trajectory() 105 | 106 | paths_xy = [] 107 | for traj in trajectories: 108 | path_array = np.array(traj) 109 | paths_xy.append((path_array[:, 0], path_array[:, 1])) 110 | 111 | target_path_xy = None 112 | if target_trajectory: 113 | target_array = np.array(target_trajectory) 114 | target_path_xy = (target_array[:, 0], target_array[:, 1]) 115 | 116 | fig, ax = plt.subplots(figsize=(8, 6)) 117 | ax.set_xlim(x_limits) 118 | ax.set_ylim(y_limits) 119 | ax.set_xlabel("x") 120 | ax.set_ylabel("y") 121 | ax.set_title(f"{benchmark_name.replace('_', ' ').title()} with {', '.join(optimizer_names)}") 122 | 123 | if objective_fn: 124 | x = torch.linspace(x_limits[0], x_limits[1], 100) 125 | y = torch.linspace(y_limits[0], y_limits[1], 100) 126 | X, Y = torch.meshgrid(x, y, indexing="ij") 127 | Z = objective_fn(X, Y) 128 | Z -= Z.min() 129 | Z = Z.log() 130 | ax.contourf(X.numpy(), Y.numpy(), Z.numpy(), levels=50, cmap="viridis") 131 | ax.contour(X.numpy(), Y.numpy(), Z.numpy(), levels=50, colors="k", linewidths=1) 132 | 133 | lines = [ax.plot([], [], "x-", label=name)[0] for name in optimizer_names] 134 | ax.legend() 135 | 136 | target_dot = None 137 | if benchmark_name == "dynamic_targets": 138 | (target_dot,) = ax.plot([], [], "ro", label="Target") 139 | ax.legend() 140 | 141 | def init(): 142 | for line in lines: 143 | line.set_data([], []) 144 | if target_dot: 145 | target_dot.set_data([], []) 146 | return tuple(lines) + (target_dot,) 147 | return tuple(lines) 148 | 149 | frames = min(100, steps + 1) 150 | 151 | def animate(i): 152 | i = math.ceil(i / frames * steps) 153 | for j, line in enumerate(lines): 154 | x_data, y_data = paths_xy[j] 155 | line.set_data(x_data[: i + 1], y_data[: i + 1]) 156 | 157 | if target_path_xy: 158 | target_x, target_y = target_path_xy 159 | if i > 0: 160 | target_dot.set_data([target_x[i - 1]], [target_y[i - 1]]) 161 | else: 162 | target_dot.set_data([], []) 163 | return tuple(lines) + (target_dot,) 164 | 165 | return tuple(lines) 166 | 167 | ani = animation.FuncAnimation(fig, animate, init_func=init, frames=frames, interval=100, blit=True) 168 | 169 | output_path = pathlib.Path(output_file) 170 | output_path.parent.mkdir(parents=True, exist_ok=True) 171 | ani.save(output_path, writer="ffmpeg", fps=10) 172 | print(f"Animation saved to {output_file}") 173 | 174 | 175 | if __name__ == "__main__": 176 | app() 177 | -------------------------------------------------------------------------------- /benchmark/gradient_delay.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from typing import List, Optional 3 | 4 | import torch 5 | import torch.backends.opt_einsum 6 | import typer 7 | from torch import nn 8 | 9 | from benchmark.utils import loss_win_condition, trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | 15 | configs = { 16 | "trivial": {"max_delay": 2}, 17 | "easy": {"max_delay": 4}, 18 | "medium": {"max_delay": 16}, 19 | "hard": {"max_delay": 64}, 20 | "extreme": {"max_delay": 128}, 21 | "nightmare": {"max_delay": 256}, 22 | } 23 | 24 | 25 | class Model(nn.Module): 26 | def __init__(self, max_delay=16, param_size=256): 27 | super().__init__() 28 | self.params = nn.ParameterList([nn.Parameter(torch.randn(param_size)) 
for _ in range(max_delay)]) 29 | # Different update frequencies for each parameter 30 | self.delays = [i for i in range(max_delay)] 31 | self.step = 0 32 | self.grad_queues = [deque(maxlen=i + 1) for i in self.delays] 33 | 34 | def forward(self): 35 | """Test optimizer's ability to handle delayed gradients.""" 36 | total_loss = 0 37 | self.step += 1 38 | 39 | for param, delay, queue in zip(self.params, self.delays, self.grad_queues): 40 | # Current loss for this parameter 41 | loss = param.square().mean() 42 | 43 | # Store the gradient in the queue 44 | queue.append(loss) 45 | 46 | # Only add to total loss when we have enough history 47 | if len(queue) == queue.maxlen and self.step % (delay + 1) == 0: 48 | total_loss = total_loss + queue[0] # Use oldest gradient 49 | 50 | return total_loss / len(self.params) 51 | 52 | 53 | @app.command() 54 | def main( 55 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 56 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 57 | steps: int = 100, 58 | weight_decay: float = 0, 59 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 60 | trials: int = 100, 61 | win_condition_multiplier: float = 1.0, 62 | config: Optional[str] = None, 63 | ): 64 | max_delay = configs.get(config, {}).get("max_delay", 4) 65 | dtype = [getattr(torch, d) for d in dtype] 66 | model = Model(max_delay).cuda().double() 67 | 68 | def data(): 69 | return None, None 70 | 71 | # More lenient win condition and more steps due to delayed updates 72 | trial( 73 | model, 74 | data, 75 | None, 76 | loss_win_condition(win_condition_multiplier * 1e-4), 77 | steps * 2, 78 | opt[0], 79 | dtype[0], 80 | 1, 81 | 1, 82 | weight_decay, 83 | method[0], 84 | 1, 85 | 1, 86 | failure_threshold=5, 87 | base_lr=1e-3, 88 | trials=trials, 89 | ) # Double steps, more attempts 90 | 91 | 92 | if __name__ == "__main__": 93 | app() 94 | -------------------------------------------------------------------------------- /benchmark/gradient_noise_scale.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | 8 | from benchmark.utils import param_norm_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | 14 | configs = { 15 | "trivial": {"offset": 32}, 16 | "easy": {"offset": 16}, 17 | "medium": {"offset": 8}, 18 | "hard": {"offset": 4}, 19 | "extreme": {"offset": 2}, 20 | "nightmare": {"offset": 1}, 21 | } 22 | 23 | 24 | class Model(nn.Module): 25 | def __init__(self, offset, size=4096): 26 | super().__init__() 27 | self.param = nn.Parameter(torch.randn(size)) 28 | self.register_buffer("step", torch.zeros(1)) 29 | self.offset = offset 30 | 31 | def forward(self): 32 | """Test optimizer's ability to handle changing noise levels during training.""" 33 | self.step += 1 34 | # Noise that decreases over time 35 | noise_scale = 1.0 / (self.offset + self.step) 36 | noise = torch.randn_like(self.param) * noise_scale 37 | return (self.param + noise).square().mean() 38 | 39 | 40 | @app.command() 41 | def main( 42 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 43 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 44 | steps: int = 100, 45 | weight_decay: float = 0, 46 | opt: List[str] = typer.Option(["ForeachSOAP"], 
help="Optimizers to use"), 47 | trials: int = 100, 48 | win_condition_multiplier: float = 1.0, 49 | config: Optional[str] = None, 50 | ): 51 | offset = configs.get(config, {}).get("offset", 4) 52 | dtype = [getattr(torch, d) for d in dtype] 53 | model = Model(offset).cuda().double() 54 | 55 | def data(): 56 | return None, None 57 | 58 | # Lenient initial condition due to high initial noise 59 | trial( 60 | model, 61 | data, 62 | None, 63 | param_norm_win_condition(win_condition_multiplier * 1e-3, 0), 64 | steps, 65 | opt[0], 66 | dtype[0], 67 | 1, 68 | 1, 69 | weight_decay, 70 | method[0], 71 | 1, 72 | 1, 73 | failure_threshold=5, 74 | base_lr=1e-3, 75 | trials=trials, 76 | ) 77 | 78 | 79 | if __name__ == "__main__": 80 | app() 81 | -------------------------------------------------------------------------------- /benchmark/grokking.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import itertools 3 | import random 4 | from collections import defaultdict 5 | from pathlib import Path 6 | from typing import List 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import torch 11 | import torch.backends.opt_einsum 12 | import torch.nn as nn 13 | import typer 14 | from torch.utils.data import DataLoader 15 | 16 | import heavyball 17 | from benchmark.utils import get_optim 18 | from heavyball.utils import set_torch 19 | 20 | app = typer.Typer(pretty_exceptions_enable=False) 21 | set_torch() 22 | 23 | 24 | class ModularMLP(nn.Module): 25 | def __init__(self, numbers, p, hidden_dim): 26 | super().__init__() 27 | self.net = nn.Sequential( 28 | nn.Embedding(p, hidden_dim), 29 | nn.Flatten(), 30 | nn.Linear(numbers * hidden_dim, hidden_dim), 31 | nn.LayerNorm(hidden_dim), 32 | nn.LeakyReLU(), 33 | nn.Linear(hidden_dim, hidden_dim), 34 | nn.LayerNorm(hidden_dim), 35 | nn.LeakyReLU(), 36 | nn.Linear(hidden_dim, p), 37 | ) 38 | 39 | def forward(self, x): 40 | return self.net(x) 41 | 42 | 43 | class ModuloDataset(torch.utils.data.Dataset): 44 | def __init__(self, p, numbers, min_idx, length, batch_size): 45 | length = length // batch_size 46 | self.p = p 47 | self.numbers = numbers 48 | self.n_samples = length 49 | self.min_idx = min_idx 50 | self.max_idx = min_idx + length 51 | self.batch_size = batch_size 52 | 53 | def __len__(self): 54 | return self.n_samples 55 | 56 | def __getitem__(self, idx): 57 | generator = torch.Generator() 58 | generator.manual_seed(random.Random(min(idx + self.min_idx, self.max_idx)).randint(0, 2**32)) 59 | x = torch.randint(0, self.p, (self.batch_size, self.numbers), generator=generator) 60 | y = (x.sum(dim=-1) % self.p).long() 61 | return x, y 62 | 63 | 64 | def evaluate(model, loader, device): 65 | """Evaluate model accuracy""" 66 | model.eval() 67 | correct = total = 0 68 | with torch.no_grad(): 69 | for x, y in loader: 70 | x, y = x.to(device), y.to(device) 71 | out = model(x) 72 | pred = out.argmax(dim=1) 73 | correct += (pred == y).sum().detach() 74 | total += y.size(0) 75 | return correct / total 76 | 77 | 78 | def plot_results(train_losses, test_accs, steps_to_grok=None, save_path=None): 79 | """Plot training curves""" 80 | _fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True) 81 | 82 | # Plot training loss 83 | ax1.plot(train_losses, label="Training Loss") 84 | ax1.set_yscale("log") 85 | ax1.set_ylabel("Loss") 86 | ax1.set_title("Training Loss Over Time") 87 | ax1.grid(True) 88 | 89 | # Plot test accuracy 90 | eval_steps = np.arange(0, len(train_losses), len(train_losses) // 
len(test_accs)) 91 | ax2.plot(eval_steps, test_accs, label="Test Accuracy", color="orange") 92 | ax2.axhline(y=0.9, color="r", linestyle="--", label="Grokking Threshold") 93 | ax2.set_ylabel("Accuracy") 94 | ax2.set_xlabel("Steps") 95 | ax2.set_title("Test Accuracy Over Time") 96 | ax2.grid(True) 97 | 98 | if steps_to_grok is not None: 99 | ax2.axvline(x=steps_to_grok, color="g", linestyle="--", label=f"Grokking Step ({steps_to_grok})") 100 | 101 | ax1.legend() 102 | ax2.legend() 103 | plt.tight_layout() 104 | if save_path: 105 | plt.savefig(save_path) 106 | plt.close() 107 | 108 | 109 | @app.command() 110 | def main( 111 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 112 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 113 | opt: List[str] = typer.Option( 114 | ["ForeachSOAP", "PaLMForeachSOAP", "PrecondScheduleForeachSOAP"], help="Optimizers to use" 115 | ), 116 | steps: int = 100, 117 | batch_size: int = 32, 118 | hidden_dim: int = 32, 119 | p: int = 257, 120 | numbers: int = 4, 121 | weight_decay: float = 0, 122 | lr: float = 1e-4, 123 | train_percent: float = 0.1, 124 | eval_samples: int = 1024, 125 | printervall: int = 1000, 126 | ): 127 | dtype = [getattr(torch, d) for d in dtype] 128 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 129 | 130 | # Clean up old plots 131 | plot_dir = Path(".") 132 | for path in plot_dir.glob("grokking_*.png"): 133 | path.unlink() 134 | 135 | # Pre-generate datasets 136 | unique_samples = p**numbers 137 | train_data = ModuloDataset(p, numbers, 0, int(unique_samples * train_percent), batch_size) 138 | test_data = ModuloDataset(p, numbers, train_data.max_idx, eval_samples, eval_samples) 139 | 140 | print(f"Training on {train_data.n_samples * batch_size:,} samples - {train_percent * 100}%") 141 | print(f"Testing on {eval_samples:,} samples") 142 | 143 | train_loader = DataLoader( 144 | train_data, 145 | collate_fn=lambda x: x[0], 146 | batch_size=1, 147 | shuffle=False, 148 | pin_memory=True, 149 | num_workers=4, 150 | drop_last=True, 151 | prefetch_factor=16, 152 | persistent_workers=True, 153 | ) 154 | 155 | test_loader = DataLoader( 156 | test_data, 157 | collate_fn=lambda x: x[0], 158 | batch_size=1, 159 | shuffle=False, 160 | pin_memory=True, 161 | num_workers=4, 162 | drop_last=True, 163 | prefetch_factor=32, 164 | ) 165 | test_loader = list(test_loader) 166 | test_loader = [[x.pin_memory() for x in i] for i in test_loader] 167 | 168 | train_iter = iter(train_loader) 169 | history = defaultdict(list) 170 | 171 | def data(): 172 | """Get next batch from the dataloader""" 173 | nonlocal train_iter 174 | try: 175 | x, y = next(train_iter) 176 | except (StopIteration, NameError): 177 | train_iter = iter(train_loader) 178 | x, y = next(train_iter) 179 | return x.to(device), y.to(device) 180 | 181 | criterion = nn.CrossEntropyLoss() 182 | 183 | def win_condition(model, loss_hist): 184 | """Check if model has achieved grokking""" 185 | if not isinstance(loss_hist, float): 186 | loss = loss_hist 187 | else: 188 | loss = loss_hist 189 | 190 | history["loss"].append(loss) 191 | 192 | if loss > 0.1: # Not converged yet 193 | return False, {} 194 | 195 | # If loss is low, check test accuracy 196 | acc = evaluate(model, test_loader, device) 197 | history["test_acc"].append(acc) 198 | return acc > 0.9, {"test_acc": acc} 199 | 200 | global_model = ModularMLP(numbers, p, hidden_dim).to(device) 201 | global_model = torch.compile(global_model, mode="max-autotune-no-cudagraphs") 
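# Note: the base model is compiled once and deep-copied for every optimizer/dtype combination below, so each run starts from identical initial weights; this assumes deepcopy of a torch.compile'd module behaves like copying a plain nn.Module.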
202 | for d, o in itertools.product(dtype, opt): 203 | print(f"\nRunning {o} with {d}") 204 | model = copy.deepcopy(global_model) 205 | model.to(dtype=d) 206 | 207 | history.clear() 208 | 209 | # Get optimizer class 210 | optimizer_class = getattr(heavyball, o) 211 | optimizer = get_optim(optimizer_class, model.parameters(), lr=lr, weight_decay=weight_decay) 212 | 213 | loss_hist = torch.empty(steps) 214 | 215 | # Training loop 216 | for step in range(steps): 217 | model.train() 218 | x, y = data() 219 | 220 | optimizer.zero_grad() 221 | out = model(x) 222 | loss = criterion(out, y) 223 | loss.backward() 224 | optimizer.step() 225 | 226 | with torch.no_grad(): 227 | loss_hist[step] = loss.detach() 228 | 229 | if step % printervall == 0: 230 | lh = loss_hist[:step][-printervall:].mean().item() 231 | acc = evaluate(model, test_loader, device).item() 232 | history["test_acc"].append(acc) 233 | print(f"Step {step}: Loss = {lh:.4f}, Test Acc = {acc:.4f}") 234 | 235 | # Plot results 236 | plot_name = plot_dir / f"grokking_{o}_{d}_lr{lr}_h{hidden_dim}_p{p}.png" 237 | plot_results( 238 | loss_hist.cpu().numpy(), 239 | history["test_acc"], 240 | next((i for i, acc in enumerate(history["test_acc"]) if acc > 0.9), None), 241 | plot_name, 242 | ) 243 | print(f"Training curves saved to {plot_name}") 244 | 245 | 246 | if __name__ == "__main__": 247 | app() 248 | -------------------------------------------------------------------------------- /benchmark/layer_wise_scale.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | 8 | from benchmark.utils import loss_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | 14 | configs = { 15 | "trivial": {"scale": 2}, 16 | "easy": {"scale": 1e1}, 17 | "medium": {"scale": 1e3}, 18 | "hard": {"scale": 1e5}, 19 | "extreme": {"scale": 1e7}, 20 | "nightmare": {"scale": 1e9}, 21 | } 22 | 23 | 24 | class Model(nn.Module): 25 | def __init__(self, scale: float, size=1024): 26 | super().__init__() 27 | # Simulate different layer scales in deep networks 28 | self.layer1 = nn.Parameter(torch.randn(size)) # Large gradients (loss term scaled up) 29 | self.layer2 = nn.Parameter(torch.randn(size)) # Medium gradients 30 | self.layer3 = nn.Parameter(torch.randn(size)) # Small gradients (loss term scaled down) 31 | self.scale = scale 32 | 33 | def forward(self): 34 | """Test optimizer's ability to handle different gradient scales across layers.""" 35 | # Each layer has the same quadratic objective, but its loss term (and thus its gradients) is scaled very differently 36 | return ( 37 | self.layer1.square().mean() * self.scale 38 | + self.layer2.square().mean() 39 | + self.layer3.square().mean() / self.scale 40 | ) / 3 41 | 42 | 43 | @app.command() 44 | def main( 45 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 46 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 47 | steps: int = 100, 48 | weight_decay: float = 0, 49 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 50 | trials: int = 100, 51 | win_condition_multiplier: float = 1.0, 52 | config: Optional[str] = None, 53 | ): 54 | scale = configs.get(config, {}).get("scale", 1e3) 55 | dtype = [getattr(torch, d) for d in dtype] 56 | model = Model(scale).cuda().double() 57 | 58 | def data(): 59 | return None, None 60 | 61 | # More lenient win condition due to vastly 
different scales 62 | trial( 63 | model, 64 | data, 65 | None, 66 | loss_win_condition(win_condition_multiplier * 1e-4), 67 | steps, 68 | opt[0], 69 | dtype[0], 70 | 1, 71 | 1, 72 | weight_decay, 73 | method[0], 74 | 1, 75 | 1, 76 | failure_threshold=5, 77 | base_lr=1e-4, 78 | trials=trials, 79 | ) # Lower learning rate and more attempts 80 | 81 | 82 | if __name__ == "__main__": 83 | app() 84 | -------------------------------------------------------------------------------- /benchmark/loss_contour.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import tqdm 9 | 10 | import heavyball 11 | 12 | device = "cuda" 13 | heavyball.utils.compile_mode = None 14 | heavyball.utils.dynamic = True 15 | heavyball.utils.set_torch() 16 | 17 | 18 | class Sine(nn.Module): 19 | def forward(self, x): 20 | return torch.sin(x) 21 | 22 | 23 | class Residual(nn.Module): 24 | def __init__(self, wrapped): 25 | super(Residual, self).__init__() 26 | self.wrapped = wrapped 27 | 28 | def forward(self, x): 29 | return x + self.wrapped(x) 30 | 31 | 32 | class DatasetNorm(nn.Module): 33 | def __init__(self, features: int, momentum: float = 0.99): 34 | super().__init__() 35 | self.weight = nn.Parameter(torch.stack([torch.ones(features), torch.zeros(features)], 1)) 36 | self.register_buffer("stats", torch.zeros(features * 2)) 37 | self.register_buffer("step", torch.tensor(0)) 38 | self.momentum = momentum 39 | 40 | def forward(self, x): 41 | if True: 42 | with torch.no_grad(): 43 | mean, sq_mean = x.mean(dim=0), (x**2).mean(dim=0) 44 | stats = torch.cat([mean, sq_mean]) 45 | self.step.add_(1) 46 | self.stats.lerp_(stats, 1 - heavyball.utils.beta_debias(self.momentum, self.step)) 47 | # self.stats.lerp_(stats, self.step == 1) 48 | mean, sq_mean = self.stats.chunk(2) 49 | std = (sq_mean - mean**2).clamp_min_(1e-6).sqrt() 50 | else: 51 | std, mean = 1, 0 52 | weight, bias = self.weight.unbind(1) 53 | return (x - mean) / std * weight + bias 54 | 55 | 56 | class MLP(nn.Module): 57 | def __init__(self, in_shape, out_shape, width, depth, act=Sine(), expanded: int = 256): 58 | super(MLP, self).__init__() 59 | layers = [] 60 | layers.append(nn.Linear(in_shape, width)) 61 | 62 | for _ in range(depth - 1): 63 | layers.append( 64 | Residual( 65 | nn.Sequential( 66 | nn.Linear(width, expanded), # 67 | act, # 68 | DatasetNorm(expanded), # 69 | nn.Linear(expanded, width), 70 | ) 71 | ) 72 | ) 73 | layers.append(DatasetNorm(width)) 74 | layers.append(nn.Linear(width, out_shape)) 75 | self.model = nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | return self.model(x) 79 | 80 | 81 | def generate_two_moons_torch(n_samples=1000, noise=0.1, random_state=None): 82 | if random_state is not None: 83 | torch.manual_seed(random_state) 84 | 85 | half_samples = n_samples // 2 86 | 87 | theta1 = torch.linspace(0, np.pi, half_samples, device=device) 88 | theta2 = torch.linspace(0, np.pi, half_samples, device=device) 89 | 90 | X1 = torch.stack([torch.cos(theta1), torch.sin(theta1)], dim=1) 91 | X2 = torch.stack([1 - torch.cos(theta2), 1 - torch.sin(theta2) - 0.5], dim=1) 92 | 93 | X = torch.cat([X1, X2], dim=0) 94 | y = torch.cat([torch.zeros(half_samples, device=device), torch.ones(half_samples, device=device)], dim=0) 95 | 96 | X += noise * torch.randn(n_samples, 2, device=device) 97 | 98 | indices = torch.randperm(n_samples, device=device) 99 | X = X[indices] 100 | y = 
y[indices] 101 | 102 | return X, y 103 | 104 | 105 | def train_and_generate_frames( 106 | model, 107 | X_train, 108 | y_train, 109 | domain, 110 | epochs, 111 | lr, 112 | filename="training_video", 113 | resolution: int = 128, 114 | subsample: int = 1, 115 | train_samples: int = 1024, 116 | ): 117 | X_train = X_train.to(device).float() 118 | y_train = y_train.view(-1, 1).to(device).float() 119 | 120 | optimizers = { 121 | "ForeachSOAP": heavyball.ForeachSOAP, 122 | "PaLMForeachSOAP": heavyball.PaLMForeachSOAP, 123 | "PrecondScheduleForeachSOAP": heavyball.PrecondScheduleForeachSOAP, 124 | }  # optimizer classes; each is instantiated below for its own copy of the model 125 | criterion = nn.BCEWithLogitsLoss() 126 | 127 | xx, yy = torch.meshgrid( 128 | torch.linspace(domain[0][0], domain[1][0], resolution, device=device), 129 | torch.linspace(domain[0][1], domain[1][1], resolution, device=device), 130 | indexing="xy", 131 | ) 132 | grid_points = torch.stack((xx.ravel(), yy.ravel()), dim=1).float() 133 | 134 | base_model = copy.deepcopy(model) 135 | 136 | for optimizer_name, optimizer_cls in optimizers.items(): 137 | model = copy.deepcopy(base_model) 138 | print(f"\nTraining with {optimizer_name}") 139 | model.train() 140 | optimizer = optimizer_cls(model.parameters(), lr=lr)  # bind the optimizer to this run's copy, not the original model 141 | os.makedirs("frames", exist_ok=True) 142 | 143 | for epoch in tqdm.tqdm(range(epochs)): 144 | outputs = model(X_train) 145 | loss = criterion(outputs, y_train) 146 | 147 | optimizer.zero_grad() 148 | loss.backward() 149 | optimizer.step() 150 | 151 | if epoch % subsample == 0: 152 | model.eval() 153 | with torch.no_grad(): 154 | Z = model(grid_points).reshape(resolution, resolution) 155 | plt.figure(figsize=(10, 8)) 156 | plt.contourf(xx.cpu(), yy.cpu(), Z.cpu(), levels=20) 157 | plt.colorbar(label="Model Output") 158 | plt.scatter(X_train[:, 0].cpu(), X_train[:, 1].cpu(), c=y_train.cpu(), cmap="coolwarm") 159 | plt.title(f"{optimizer_name} - Epoch {epoch}, Loss: {loss.item():.4f}") 160 | plt.savefig(f"frames/{optimizer_name}_epoch_{epoch:05d}.png") 161 | plt.close() 162 | model.train() 163 | 164 | 165 | if __name__ == "__main__": 166 | X, y = generate_two_moons_torch(n_samples=1024, noise=0.05, random_state=42) 167 | 168 | domain = np.array([ 169 | [X[:, 0].min().item() - 1, X[:, 1].min().item() - 1], 170 | [X[:, 0].max().item() + 1, X[:, 1].max().item() + 1], 171 | ]) 172 | 173 | model = torch.compile(MLP(in_shape=2, out_shape=1, width=2, depth=32), mode="max-autotune-no-cudagraphs").to(device) 174 | 175 | epochs = 100 176 | lr = 1e-4 177 | train_and_generate_frames(model, X, y, domain, epochs, lr) 178 | -------------------------------------------------------------------------------- /benchmark/merge_logs.py: -------------------------------------------------------------------------------- 1 | import io 2 | from datetime import datetime 3 | from pathlib import Path 4 | from typing import List 5 | 6 | import markdown 7 | import numpy as np 8 | import pandas as pd 9 | import typer 10 | 11 | app = typer.Typer(pretty_exceptions_enable=True) 12 | 13 | 14 | def nanmean(x): 15 | return np.mean(x[np.isfinite(x)]) 16 | 17 | 18 | def merge_markdown_files(file_paths): 19 | """Merge multiple markdown files by taking the union of detail experiments and computing new summary stats.""" 20 | latest_timestamp = datetime.min 21 | union_details = [] 22 | 23 | for file_path in file_paths: 24 | with open(file_path, "r") as f: 25 | content = f.read() 26 | content = content.split("## Details")[1].split("## ")[0].strip() 27 | html = markdown.markdown(content, extensions=["tables"]) 28 | 
union_details.append(pd.read_html(io.StringIO(html))[0]) 29 | 30 | union_details = pd.concat(union_details) 31 | 32 | # For duplicate rows (i.e., Benchmark, Optimizer, Cautious, Mars == same), take the best run (i.e., Successful, Runtime, Attempts) 33 | union_details["Success"] = union_details["Success"] == "✓" 34 | union_details["Runtime"] = union_details["Runtime"].str.replace("s", "") 35 | union_details["Runtime"] = pd.to_numeric(union_details["Runtime"], errors="coerce") 36 | union_details["Attempts"] = pd.to_numeric(union_details["Attempts"], errors="coerce") 37 | union_details["Loss"] = pd.to_numeric(union_details["Loss"], errors="coerce") 38 | 39 | union_details = union_details.sort_values(by=["Success", "Runtime", "Attempts"], ascending=[False, True, True]) 40 | union_details = union_details.drop_duplicates(keep="first", subset=["Benchmark", "Optimizer", "Cautious", "Mars"]) 41 | 42 | configs = union_details[["Optimizer", "Cautious", "Mars"]].drop_duplicates().to_dict(orient="records") 43 | 44 | new_summary = [] 45 | 46 | for config in configs: 47 | config_details = union_details[ 48 | (union_details["Optimizer"] == config["Optimizer"]) 49 | & (union_details["Cautious"] == config["Cautious"]) 50 | & (union_details["Mars"] == config["Mars"]) 51 | ] 52 | new_summary.append({ 53 | **config, 54 | "Attempts": nanmean(config_details["Attempts"]), 55 | "Success": f"{int(np.sum(config_details['Success']))}/{len(config_details)}", 56 | "Average Runtime": f"{nanmean(config_details['Runtime']):.1f}s", 57 | }) 58 | 59 | new_summary = pd.DataFrame(new_summary) 60 | new_summary.sort_values(by=["Optimizer", "Cautious", "Mars"], inplace=True) 61 | 62 | union_details["Runtime"] = [f"{x:.1f}s" for x in union_details["Runtime"]] 63 | union_details["Success"] = ["✓" if x else "✗" for x in union_details["Success"]] 64 | 65 | # Generate merged content with updated summary based on union of experiments 66 | merged_content = f"""# Benchmark Results 67 | 68 | Generated: {latest_timestamp} 69 | Last updated: {latest_timestamp} 70 | 71 | ## Summary 72 | 73 | {new_summary.to_markdown(index=False)} 74 | 75 | ## Details 76 | 77 | {union_details.to_markdown(index=False)} 78 | """ 79 | 80 | return merged_content 81 | 82 | 83 | @app.command() 84 | def main(path: List[str] = typer.Option([], help="Markdown files to merge")): 85 | files = [Path(p) for p in path] 86 | 87 | # Generate merged content 88 | merged_content = merge_markdown_files(files) 89 | 90 | # Write to output file 91 | output_path = Path("merged_results.md") 92 | with open(output_path, "w") as f: 93 | f.write(merged_content) 94 | 95 | 96 | if __name__ == "__main__": 97 | app() 98 | -------------------------------------------------------------------------------- /benchmark/minimax.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from benchmark.utils import param_norm_win_condition, trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | configs = { 15 | "trivial": {"size": 4}, 16 | "easy": {"size": 16}, 17 | "medium": {"size": 512}, 18 | "hard": {"size": 8192}, 19 | "extreme": {"size": 2**15}, 20 | "nightmare": {"size": 2**17}, 21 | } 22 | 23 | 24 | class Model(nn.Module): 25 | def __init__(self, size): 26 | super().__init__() 27 | self.param = 
nn.Parameter(torch.randn((2 * size,))) 28 | 29 | def forward(self, inp): 30 | param0, param1 = self.param.chunk(2, 0) 31 | return param0 @ param1 + (param0 @ param0 + param1 @ param1) / 2 32 | 33 | 34 | @app.command() 35 | def main( 36 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 37 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 38 | size: int = 1024, 39 | depth: int = 4, 40 | batch: int = 16, 41 | steps: int = 10, 42 | weight_decay: float = 0, 43 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 44 | win_condition_multiplier: float = 1.0, 45 | trials: int = 10, 46 | config: Optional[str] = None, 47 | ): 48 | size = configs.get(config, {}).get("size", size) 49 | 50 | dtype = [getattr(torch, d) for d in dtype] 51 | model = Model(size).cuda() 52 | 53 | def data(): 54 | inp = torch.randn((batch, size), device="cuda", dtype=dtype[0]) 55 | return inp, inp.cumsum(1) 56 | 57 | trial( 58 | model, 59 | data, 60 | F.mse_loss, 61 | param_norm_win_condition(1e-7 * win_condition_multiplier, 0), 62 | steps, 63 | opt[0], 64 | dtype[0], 65 | size, 66 | batch, 67 | weight_decay, 68 | method[0], 69 | 1, 70 | depth, 71 | failure_threshold=depth * 2, 72 | base_lr=1e-3, 73 | trials=trials, 74 | ) 75 | 76 | 77 | if __name__ == "__main__": 78 | app() 79 | -------------------------------------------------------------------------------- /benchmark/momentum_utilization.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | 8 | from benchmark.utils import loss_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | configs = { 14 | "trivial": {"weight": 0.004}, 15 | "easy": {"weight": 0.02}, 16 | "medium": {"weight": 0.1}, 17 | "hard": {"weight": 0.5}, 18 | "extreme": {"weight": 1}, 19 | "nightmare": {"weight": 2}, 20 | } 21 | 22 | 23 | class Model(nn.Module): 24 | def __init__(self, weight: float, size=1024): 25 | super().__init__() 26 | self.param = nn.Parameter(torch.randn(size)) 27 | self.register_buffer("t", torch.zeros(1)) 28 | self.weight = weight 29 | 30 | def forward(self): 31 | """Tests effective use of momentum for oscillating landscapes.""" 32 | self.t += 0.1 33 | x = self.param 34 | return (x.square() + self.weight * torch.sin(10 * x) * torch.cos(self.t)).mean() 35 | 36 | 37 | @app.command() 38 | def main( 39 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 40 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 41 | steps: int = 100, 42 | weight_decay: float = 0, 43 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 44 | trials: int = 100, 45 | win_condition_multiplier: float = 1.0, 46 | weight: float = 0.1, 47 | config: Optional[str] = None, 48 | ): 49 | weight = configs.get(config, {}).get("weight", weight) 50 | dtype = [getattr(torch, d) for d in dtype] 51 | model = Model(weight).cuda().double() 52 | 53 | def data(): 54 | return None, None 55 | 56 | trial( 57 | model, 58 | data, 59 | None, 60 | loss_win_condition(win_condition_multiplier * 1e-6), 61 | steps, 62 | opt[0], 63 | dtype[0], 64 | 1, 65 | 1, 66 | weight_decay, 67 | method[0], 68 | 1, 69 | 1, 70 | failure_threshold=3, 71 | base_lr=1e-3, 72 | trials=trials, 73 | ) 74 | 75 | 76 | if __name__ == "__main__": 77 
| app() 78 | -------------------------------------------------------------------------------- /benchmark/noisy_matmul.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from benchmark.utils import param_norm_win_condition, trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | 15 | configs = { 16 | "trivial": {"depth": 1}, 17 | "easy": {"depth": 2}, 18 | "medium": {"depth": 8}, 19 | "hard": {"depth": 12}, 20 | "extreme": {"depth": 16}, 21 | "nightmare": {"depth": 20}, 22 | } 23 | 24 | 25 | class Model(nn.Module): 26 | def __init__(self, size): 27 | super().__init__() 28 | self.param = nn.Parameter(torch.randn((size,))) 29 | self.offset = nn.Buffer(torch.randn_like(self.param)) 30 | 31 | def forward(self, inp): 32 | y = None 33 | y0 = self.param.view(1, -1).expand(inp.size(0), -1) + self.offset # offset, so weight decay doesnt help 34 | for i in inp.unbind(1): 35 | y = torch.einsum("bi,bik->bk", y0, i) 36 | y0 = F.leaky_relu(y, 0.1) 37 | return y 38 | 39 | 40 | @app.command() 41 | def main( 42 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 43 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 44 | size: int = 64, 45 | depth: int = 4, 46 | batch: int = 128, 47 | steps: int = 10, 48 | weight_decay: float = 0, 49 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 50 | win_condition_multiplier: float = 1.0, 51 | trials: int = 10, 52 | config: Optional[str] = None, 53 | ): 54 | depth = configs.get(config, {}).get("depth", depth) 55 | dtype = [getattr(torch, d) for d in dtype] 56 | model = Model(size).cuda() 57 | 58 | def data(): 59 | inp = torch.randn((batch, depth, size, size), device="cuda", dtype=dtype[0]) / size**0.5 60 | return inp, torch.zeros((batch, size), device="cuda", dtype=dtype[0]) 61 | 62 | trial( 63 | model, 64 | data, 65 | F.mse_loss, 66 | param_norm_win_condition(1e-7 * win_condition_multiplier, model.offset), 67 | steps, 68 | opt[0], 69 | dtype[0], 70 | size, 71 | batch, 72 | weight_decay, 73 | method[0], 74 | 1, 75 | depth, 76 | failure_threshold=depth * 2, 77 | base_lr=1e-3, 78 | trials=trials, 79 | ) 80 | 81 | 82 | if __name__ == "__main__": 83 | app() 84 | -------------------------------------------------------------------------------- /benchmark/parameter_scale.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | 8 | from benchmark.utils import loss_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | configs = {"easy": {"scale": 1e1}, "medium": {"scale": 1e3}, "hard": {"scale": 1e5}} 14 | 15 | 16 | class Model(nn.Module): 17 | def __init__(self, size, scale: float): 18 | super().__init__() 19 | # Simulate different layer scales in deep networks 20 | self.layer1 = nn.Parameter(torch.randn(size) * scale) # Small gradients 21 | self.layer2 = nn.Parameter(torch.randn(size)) # Medium gradients 22 | self.layer3 = nn.Parameter(torch.randn(size) / scale) # Large gradients 23 | 24 | def forward(self): 25 | """Test optimizer's ability to 
handle different gradient scales across layers.""" 26 | # Each layer contributes equally to the loss but has very different scales 27 | return (self.layer1.square().mean() + self.layer2.square().mean() + self.layer3.square().mean()) / 3 28 | 29 | 30 | @app.command() 31 | def main( 32 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 33 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 34 | steps: int = 100, 35 | weight_decay: float = 0, 36 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 37 | trials: int = 100, 38 | win_condition_multiplier: float = 1.0, 39 | config: Optional[str] = None, 40 | ): 41 | scale = configs.get(config, {}).get("scale", 1e3) 42 | 43 | dtype = [getattr(torch, d) for d in dtype] 44 | 45 | model = Model(size=1024, scale=scale).cuda().double() 46 | 47 | def data(): 48 | return None, None 49 | 50 | # More lenient win condition due to vastly different scales 51 | trial( 52 | model, 53 | data, 54 | None, 55 | loss_win_condition(win_condition_multiplier * 1e-4), 56 | steps, 57 | opt[0], 58 | dtype[0], 59 | 1, 60 | 1, 61 | weight_decay, 62 | method[0], 63 | 1, 64 | 1, 65 | failure_threshold=5, 66 | base_lr=1e-4, 67 | trials=trials, 68 | ) # Lower learning rate and more attempts 69 | 70 | 71 | if __name__ == "__main__": 72 | app() 73 | -------------------------------------------------------------------------------- /benchmark/plateau_navigation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pathlib 3 | import random 4 | from typing import List, Optional 5 | 6 | import matplotlib.colors 7 | import torch 8 | import torch.backends.opt_einsum 9 | import typer 10 | from torch import nn 11 | from utils import Plotter 12 | 13 | from benchmark.utils import loss_win_condition, trial 14 | from heavyball.utils import set_torch 15 | 16 | app = typer.Typer(pretty_exceptions_enable=False) 17 | set_torch() 18 | 19 | configs = { 20 | "trivial": {"scale": 1}, 21 | "easy": {"scale": 4}, 22 | "medium": {"scale": 8}, 23 | "hard": {"scale": 12}, 24 | "extreme": {"scale": 16}, 25 | "nightmare": {"scale": 20}, 26 | } 27 | 28 | 29 | def objective(x, y, scale: float): 30 | """Tests optimizer's ability to handle regions with very small gradients and sharp plateaus.""" 31 | output = 1 / (1 + torch.exp((x**2 + y**2 - 1) * -scale)) 32 | minimum = 1 / (1 + math.exp(scale)) 33 | return output - minimum # ensure the minimum is at 0 34 | 35 | 36 | class Model(nn.Module): 37 | def __init__(self, x, scale): 38 | super().__init__() 39 | self.param = nn.Parameter(torch.tensor(x).float()) 40 | self.scale = scale 41 | 42 | def forward(self): 43 | return objective(*self.param, scale=self.scale) 44 | 45 | 46 | @app.command() 47 | def main( 48 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 49 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 50 | steps: int = 100, 51 | weight_decay: float = 0, 52 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 53 | show_image: bool = False, 54 | trials: int = 100, 55 | win_condition_multiplier: float = 1.0, 56 | config: Optional[str] = None, 57 | ): 58 | scale = configs.get(config, {}).get("scale", 4) 59 | 60 | dtype = [getattr(torch, d) for d in dtype] 61 | coords = (1.5, 1.5) # Start outside the plateau 62 | 63 | # Clean up old plots 64 | for path in pathlib.Path(".").glob("plateau_navigation.png"): 65 | path.unlink() 66 | 67 | colors = 
list(matplotlib.colors.TABLEAU_COLORS.values()) 68 | rng = random.Random(0x1239121) 69 | rng.shuffle(colors) 70 | 71 | if show_image: 72 | model = Plotter(lambda *x: objective(*x, scale=scale).log()) 73 | else: 74 | model = Model(coords, scale=scale) 75 | model.double() 76 | 77 | def data(): 78 | return None, None 79 | 80 | trial( 81 | model, 82 | data, 83 | None, 84 | loss_win_condition(win_condition_multiplier * 1e-4), 85 | steps, 86 | opt[0], 87 | dtype[0], 88 | 1, 89 | 1, 90 | weight_decay, 91 | method[0], 92 | 1, 93 | 1, 94 | failure_threshold=3, 95 | base_lr=1e-3, 96 | trials=trials, 97 | ) 98 | 99 | 100 | if __name__ == "__main__": 101 | app() 102 | -------------------------------------------------------------------------------- /benchmark/postprocess_requirements.txt: -------------------------------------------------------------------------------- 1 | pandas>=1.5.0 2 | tabulate>=0.9.0 3 | termcolor>=2.1.0 4 | matplotlib>=3.7.0 5 | seaborn>=0.12.0 6 | -------------------------------------------------------------------------------- /benchmark/powers.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import torch.nn as nn 6 | import typer 7 | 8 | from benchmark.utils import loss_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | configs = { 14 | "trivial": {"powers": 4}, 15 | "easy": {"powers": 8}, 16 | "medium": {"powers": 16}, 17 | "hard": {"powers": 32}, 18 | "extreme": {"powers": 128}, 19 | "nightmare": {"powers": 512}, 20 | } 21 | 22 | 23 | class Model(nn.Module): 24 | def __init__(self, size, powers, target): 25 | super().__init__() 26 | self.target = target 27 | self.param = nn.Parameter(torch.rand(powers, size) * 2) 28 | self.register_buffer("scale", torch.arange(powers).float().add(1)) 29 | 30 | def forward(self): 31 | x = self.param - self.target 32 | x = x ** self.scale.view(-1, 1) 33 | return x.square().mean() 34 | 35 | 36 | @app.command() 37 | def main( 38 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 39 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 40 | size: int = 64, 41 | powers: int = 8, 42 | steps: int = 10, 43 | target: float = 1.0, 44 | weight_decay: float = 0, 45 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 46 | win_condition_multiplier: float = 1.0, 47 | trials: int = 10, 48 | config: Optional[str] = None, 49 | ): 50 | powers = configs.get(config, {}).get("powers", powers) 51 | 52 | dtype = [getattr(torch, d) for d in dtype] 53 | model = Model(size, powers, target).cuda().double() 54 | 55 | def data(): 56 | return None, None 57 | 58 | trial( 59 | model, 60 | data, 61 | None, 62 | loss_win_condition(win_condition_multiplier * 1e-8), 63 | steps, 64 | opt[0], 65 | dtype[0], 66 | 1, 67 | 1, 68 | weight_decay, 69 | method[0], 70 | 1, 71 | 1, 72 | failure_threshold=3, 73 | base_lr=1e-3, 74 | trials=trials, 75 | ) 76 | 77 | 78 | if __name__ == "__main__": 79 | app() 80 | -------------------------------------------------------------------------------- /benchmark/powers_varying_target.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import torch.nn as nn 6 | import typer 7 | 8 | from benchmark.utils import 
loss_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | 14 | configs = { 15 | "trivial": {"powers": 4}, 16 | "easy": {"powers": 8}, 17 | "medium": {"powers": 16}, 18 | "hard": {"powers": 32}, 19 | "extreme": {"powers": 128}, 20 | "nightmare": {"powers": 512}, 21 | } 22 | 23 | 24 | class Model(nn.Module): 25 | def __init__(self, size, powers, target_mult): 26 | super().__init__() 27 | self.target = nn.Buffer( 28 | torch.arange(powers * size).view(size, powers).transpose(0, 1).float() * target_mult / powers / size 29 | ) 30 | self.param = nn.Parameter(torch.rand(powers, size) * 2) 31 | self.register_buffer("scale", torch.arange(powers).float().add(1)) 32 | 33 | def forward(self): 34 | x = self.param - self.target 35 | x = x ** self.scale.view(-1, 1) 36 | return x.square().mean() 37 | 38 | 39 | @app.command() 40 | def main( 41 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 42 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 43 | size: int = 64, 44 | powers: int = 8, 45 | steps: int = 10, 46 | target_mult: float = 1.0, 47 | weight_decay: float = 0, 48 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 49 | win_condition_multiplier: float = 1.0, 50 | trials: int = 10, 51 | config: Optional[str] = None, 52 | ): 53 | powers = configs.get(config, {}).get("powers", powers) 54 | 55 | dtype = [getattr(torch, d) for d in dtype] 56 | model = Model(size, powers, target_mult).cuda().double() 57 | 58 | def data(): 59 | return None, None 60 | 61 | trial( 62 | model, 63 | data, 64 | None, 65 | loss_win_condition(win_condition_multiplier * 1e-6), 66 | steps, 67 | opt[0], 68 | dtype[0], 69 | 1, 70 | 1, 71 | weight_decay, 72 | method[0], 73 | 1, 74 | 1, 75 | failure_threshold=3, 76 | base_lr=1e-3, 77 | trials=trials, 78 | ) 79 | 80 | 81 | if __name__ == "__main__": 82 | app() 83 | -------------------------------------------------------------------------------- /benchmark/quadratic_varying_scale.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import typer 7 | 8 | from benchmark.utils import param_norm_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | 14 | configs = { 15 | "trivial": {"size": 4}, 16 | "easy": {"size": 16}, 17 | "medium": {"size": 512}, 18 | "hard": {"size": 8192}, 19 | "extreme": {"size": 2**15}, 20 | "nightmare": {"size": 2**17}, 21 | } 22 | 23 | 24 | class Model(nn.Module): 25 | def __init__(self, size): 26 | super().__init__() 27 | self.param = nn.Parameter(torch.randn(size)) 28 | self.register_buffer("scale", F.normalize(torch.arange(1, 1 + size).float(), dim=0, p=1)) 29 | 30 | def forward(self): 31 | return self.param.square() @ self.scale 32 | 33 | 34 | @app.command() 35 | def main( 36 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 37 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 38 | size: int = 1024, 39 | batch: int = 256, 40 | steps: int = 100, 41 | weight_decay: float = 0, 42 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 43 | trials: int = 10, 44 | win_condition_multiplier: float = 1.0, 45 | config: Optional[str] = None, 46 | ): 47 | size = 
configs.get(config, {}).get("size", size) 48 | dtype = [getattr(torch, d) for d in dtype] 49 | model = Model(size).cuda() 50 | 51 | def data(): 52 | return None, None 53 | 54 | trial( 55 | model, 56 | data, 57 | None, 58 | param_norm_win_condition(win_condition_multiplier * 1e-7, 0), 59 | steps, 60 | opt[0], 61 | dtype[0], 62 | size, 63 | batch, 64 | weight_decay, 65 | method[0], 66 | 1, 67 | 1, 68 | failure_threshold=2, 69 | base_lr=1e-3, 70 | trials=trials, 71 | ) 72 | 73 | 74 | if __name__ == "__main__": 75 | app() 76 | -------------------------------------------------------------------------------- /benchmark/quadratic_varying_target.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import typer 7 | 8 | from benchmark.utils import param_norm_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | 14 | 15 | configs = { 16 | "trivial": {"size": 4}, 17 | "easy": {"size": 16}, 18 | "medium": {"size": 512}, 19 | "hard": {"size": 8192}, 20 | "extreme": {"size": 2**15}, 21 | "nightmare": {"size": 2**17}, 22 | } 23 | 24 | 25 | class Model(nn.Module): 26 | def __init__(self, size): 27 | super().__init__() 28 | self.param = nn.Parameter(torch.randn(size)) 29 | self.register_buffer("target", F.normalize(torch.arange(size).float(), dim=0)) 30 | 31 | def forward(self): 32 | return (self.param - self.target).square().mean() 33 | 34 | 35 | @app.command() 36 | def main( 37 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 38 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 39 | size: int = 1024, 40 | batch: int = 256, 41 | steps: int = 100, 42 | weight_decay: float = 0, 43 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 44 | trials: int = 10, 45 | win_condition_multiplier: float = 1.0, 46 | config: Optional[str] = None, 47 | ): 48 | size = configs.get(config, {}).get("size", size) 49 | 50 | dtype = [getattr(torch, d) for d in dtype] 51 | model = Model(size).cuda() 52 | 53 | def data(): 54 | return None, None 55 | 56 | trial( 57 | model, 58 | data, 59 | None, 60 | param_norm_win_condition(win_condition_multiplier * 1e-8, -model.target), 61 | steps, 62 | opt[0], 63 | dtype[0], 64 | size, 65 | batch, 66 | weight_decay, 67 | method[0], 68 | 1, 69 | 1, 70 | failure_threshold=2, 71 | base_lr=1e-3, 72 | trials=trials, 73 | ) 74 | 75 | 76 | if __name__ == "__main__": 77 | app() 78 | -------------------------------------------------------------------------------- /benchmark/rastrigin.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pathlib 3 | import random 4 | from typing import List, Optional 5 | 6 | import matplotlib.colors 7 | import torch 8 | import torch.backends.opt_einsum 9 | import typer 10 | from torch import nn 11 | from utils import Plotter 12 | 13 | from benchmark.utils import SkipConfig, loss_win_condition, trial 14 | from heavyball.utils import set_torch 15 | 16 | app = typer.Typer(pretty_exceptions_enable=False) 17 | set_torch() 18 | 19 | 20 | def _formula(x, A): 21 | return x**2 + A * (1 - torch.cos(2 * math.pi * x)) 22 | 23 | 24 | def objective(*args, A=10): 25 | if len(args) == 1: 26 | return _formula(args[0], A).mean() 27 | 28 | return sum(_formula(x, A) for x in args) / len(args) 29 | 30 | 31 | 
class Model(nn.Module): 32 | def __init__(self, x): 33 | super().__init__() 34 | self.param = nn.Parameter(torch.tensor(x).float()) 35 | 36 | def forward(self): 37 | return objective(self.param) 38 | 39 | 40 | @app.command() 41 | def main( 42 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 43 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 44 | steps: int = 100, 45 | weight_decay: float = 0, 46 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 47 | show_image: bool = False, 48 | trials: int = 100, 49 | win_condition_multiplier: float = 1.0, 50 | size: int = 2, 51 | config: Optional[str] = None, 52 | ): 53 | if config is not None and config != "trivial": 54 | raise SkipConfig("'config' must be 'trivial'.") 55 | if show_image: 56 | assert size == 2, "Image can only be displayed for 2D functions" 57 | dtype = [getattr(torch, d) for d in dtype] 58 | coords = (-2.2,) * size 59 | 60 | # Clean up old plots 61 | for path in pathlib.Path(".").glob("rastrigin.png"): 62 | path.unlink() 63 | 64 | colors = list(matplotlib.colors.TABLEAU_COLORS.values()) 65 | rng = random.Random(0x1239121) 66 | rng.shuffle(colors) 67 | 68 | if show_image: 69 | model = Model(coords) 70 | model = Plotter( 71 | model, 72 | x_limits=(-8, 2), 73 | y_limits=(-8, 2), 74 | ) 75 | else: 76 | model = Model(coords) 77 | model.double() 78 | 79 | def data(): 80 | return None, None 81 | 82 | model = trial( 83 | model, 84 | data, 85 | None, 86 | loss_win_condition(win_condition_multiplier * 1e-2 * (not show_image)), 87 | steps, 88 | opt[0], 89 | dtype[0], 90 | 1, 91 | 1, 92 | weight_decay, 93 | method[0], 94 | 1, 95 | 1, 96 | base_lr=1e-4, 97 | trials=trials, 98 | return_best=show_image, 99 | ) 100 | 101 | if not show_image: 102 | return 103 | 104 | model.plot(title=f"{method[0]} {opt[0]}", save_path="rastrigin.png") 105 | 106 | 107 | if __name__ == "__main__": 108 | app() 109 | -------------------------------------------------------------------------------- /benchmark/requirements.txt: -------------------------------------------------------------------------------- 1 | lightgbm 2 | optuna>=4.3.0 3 | gpytorch>=1.14 4 | botorch>=0.13.0 5 | optuna-integration[botorch] 6 | optuna-integration>=4.3.0 7 | threadpoolctl>=3.6.0 8 | scipy>=1.15.2 9 | scikit-learn>=1.6.1 10 | numexpr>=2.10.2 11 | hebo>=0.3.6 12 | pymoo>=0.6.1 13 | numpy>=1.26.4,<2.0.0 14 | optunahub>=0.2.0 15 | cmaes>=0.11.1 16 | typer>=0.15.0,<0.16.0 17 | -------------------------------------------------------------------------------- /benchmark/rosenbrock.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import random 3 | from typing import List, Optional 4 | 5 | import matplotlib.colors 6 | import torch 7 | import torch.backends.opt_einsum 8 | import typer 9 | from torch import nn 10 | 11 | from benchmark.utils import Plotter, SkipConfig, loss_win_condition, trial 12 | from heavyball.utils import set_torch 13 | 14 | app = typer.Typer(pretty_exceptions_enable=False) 15 | set_torch() 16 | 17 | 18 | def objective(x, y): 19 | return (1 - x) ** 2 + 1 * (y - x**2) ** 2 20 | 21 | 22 | class Model(nn.Module): 23 | def __init__(self, x): 24 | super().__init__() 25 | self.param = nn.Parameter(torch.tensor(x).float()) 26 | 27 | def forward(self): 28 | return objective(*self.param) 29 | 30 | 31 | @app.command() 32 | def main( 33 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 34 | dtype: List[str] = 
typer.Option(["float32"], help="Data type to use"), 35 | steps: int = 100, 36 | weight_decay: float = 0, 37 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 38 | show_image: bool = False, 39 | trials: int = 100, 40 | win_condition_multiplier: float = 1.0, 41 | config: Optional[str] = None, 42 | ): 43 | if config is not None and config != "trivial": 44 | raise SkipConfig("'config' must be 'trivial'.") 45 | dtype = [getattr(torch, d) for d in dtype] 46 | coords = (-7, -4) 47 | 48 | # Clean up old plots 49 | for path in pathlib.Path(".").glob("rosenbrock.png"): 50 | path.unlink() 51 | 52 | colors = list(matplotlib.colors.TABLEAU_COLORS.values()) 53 | rng = random.Random(0x1239121) 54 | rng.shuffle(colors) 55 | 56 | if show_image: 57 | model = Plotter(Model(coords), x_limits=(-8, 2), y_limits=(-8, 2), should_normalize=True) 58 | else: 59 | model = Model(coords) 60 | model.double() 61 | 62 | def data(): 63 | return None, None 64 | 65 | model = trial( 66 | model, 67 | data, 68 | None, 69 | loss_win_condition(win_condition_multiplier * 1e-9 * (not show_image)), 70 | steps, 71 | opt[0], 72 | dtype[0], 73 | 1, 74 | 1, 75 | weight_decay, 76 | method[0], 77 | 1, 78 | 1, 79 | base_lr=1e-4, 80 | trials=trials, 81 | return_best=show_image, 82 | ) 83 | 84 | if not show_image: 85 | return 86 | 87 | model.plot(save_path="rosenbrock.png") 88 | 89 | 90 | if __name__ == "__main__": 91 | app() 92 | -------------------------------------------------------------------------------- /benchmark/saddle_point.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import random 3 | from typing import List, Optional 4 | 5 | import matplotlib.colors 6 | import torch 7 | import torch.backends.opt_einsum 8 | import typer 9 | from torch import nn 10 | from utils import Plotter 11 | 12 | from benchmark.utils import loss_win_condition, trial 13 | from heavyball.utils import set_torch 14 | 15 | app = typer.Typer(pretty_exceptions_enable=False) 16 | set_torch() 17 | 18 | configs = { 19 | "trivial": {"power": 1}, 20 | "easy": {"power": 2}, 21 | "medium": {"power": 4}, 22 | "hard": {"power": 8}, 23 | "extreme": {"power": 16}, 24 | "nightmare": {"power": 32}, 25 | } 26 | 27 | 28 | def objective(*xs, power): 29 | """Classic saddle point objective - tests ability to escape saddle points.""" 30 | return sum(x**power for x in xs) 31 | 32 | 33 | class Model(nn.Module): 34 | def __init__(self, power, offset): 35 | super().__init__() 36 | self.param = nn.Parameter(torch.tensor([1.2, 1.9]).float()) 37 | self.offset = offset 38 | self.power = 2 * power + 1 39 | 40 | def forward(self): 41 | return objective(*self.param, power=self.power) + self.offset 42 | 43 | 44 | @app.command() 45 | def main( 46 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 47 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 48 | steps: int = 100, 49 | weight_decay: float = 0, 50 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 51 | show_image: bool = False, 52 | trials: int = 100, 53 | win_condition_multiplier: float = 1.0, 54 | config: Optional[str] = None, 55 | ): 56 | dtype = [getattr(torch, d) for d in dtype] 57 | coords = configs.get(config, {}).get("power", 1) 58 | 59 | # Clean up old plots 60 | for path in pathlib.Path(".").glob("saddle_point.png"): 61 | path.unlink() 62 | 63 | colors = list(matplotlib.colors.TABLEAU_COLORS.values()) 64 | rng = random.Random(0x1239121) 65 | 
rng.shuffle(colors) 66 | 67 | offset = win_condition_multiplier * 10 68 | 69 | if show_image: 70 | model = Plotter( 71 | lambda *x: objective(*x).add(offset).log(), 72 | coords=coords, 73 | xlim=(-2, 2), 74 | ylim=(-2, 2), 75 | normalize=8, 76 | after_step=torch.exp, 77 | ) 78 | else: 79 | model = Model(coords, offset) 80 | model.double() 81 | 82 | def data(): 83 | return None, None 84 | 85 | trial( 86 | model, 87 | data, 88 | None, 89 | loss_win_condition(0.1), 90 | steps, 91 | opt[0], 92 | dtype[0], 93 | 1, 94 | 1, 95 | weight_decay, 96 | method[0], 97 | 1, 98 | 1, 99 | failure_threshold=3, 100 | base_lr=1e-3, 101 | trials=trials, 102 | ) 103 | 104 | 105 | if __name__ == "__main__": 106 | app() 107 | -------------------------------------------------------------------------------- /benchmark/scale_invariant.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | 8 | from benchmark.utils import loss_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | configs = { 14 | "trivial": {"range": 1}, 15 | "easy": {"range": 2}, 16 | "medium": {"range": 3}, 17 | "hard": {"range": 4}, 18 | "extreme": {"range": 5}, 19 | "nightmare": {"range": 6}, 20 | } 21 | 22 | 23 | def objective(x): 24 | """Tests optimizer's ability to handle different parameter scales.""" 25 | return torch.log1p(x.square()).mean() 26 | 27 | 28 | class Model(nn.Module): 29 | def __init__(self, size, value_range): 30 | super().__init__() 31 | # Initialize with different scales 32 | scales = torch.logspace(-value_range, value_range, size) 33 | self.param = nn.Parameter(scales * torch.randn(size)) 34 | 35 | def forward(self): 36 | return objective(self.param) 37 | 38 | 39 | @app.command() 40 | def main( 41 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 42 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 43 | steps: int = 100, 44 | weight_decay: float = 0, 45 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 46 | trials: int = 100, 47 | win_condition_multiplier: float = 1.0, 48 | size: int = 512, 49 | config: Optional[str] = None, 50 | ): 51 | value_range = configs.get(config, {}).get("range", 3) 52 | 53 | dtype = [getattr(torch, d) for d in dtype] 54 | model = Model(size, value_range).cuda().double() 55 | 56 | def data(): 57 | return None, None 58 | 59 | trial( 60 | model, 61 | data, 62 | None, 63 | loss_win_condition(win_condition_multiplier * 1e-3), 64 | steps, 65 | opt[0], 66 | dtype[0], 67 | 1, 68 | 1, 69 | weight_decay, 70 | method[0], 71 | 1, 72 | 1, 73 | failure_threshold=3, 74 | base_lr=1e-3, 75 | trials=trials, 76 | ) 77 | 78 | 79 | if __name__ == "__main__": 80 | app() 81 | -------------------------------------------------------------------------------- /benchmark/sparse_gradient.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | 8 | from benchmark.utils import loss_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | configs = { 14 | "trivial": {"sparsity": 0.5}, 15 | "easy": {"sparsity": 2**-3}, 16 | "medium": 
{"sparsity": 2**-6}, 17 | "hard": {"sparsity": 2**-8}, 18 | "extreme": {"sparsity": 2**-11}, 19 | "nightmare": {"sparsity": 2**-14}, 20 | } 21 | 22 | 23 | class Model(nn.Module): 24 | def __init__(self, size=2**16, sparsity=2**-6): 25 | super().__init__() 26 | self.param = nn.Parameter(torch.randn(size)) 27 | self.sparsity = sparsity 28 | self.register_buffer("prev_mask", torch.zeros_like(self.param)) 29 | 30 | def forward(self): 31 | """Test optimizer's ability to handle sparse gradients.""" 32 | # Generate new random mask each time, but keep some consistency 33 | new_mask = (torch.rand_like(self.param) < self.sparsity).float() 34 | mask = (new_mask + self.prev_mask) > 0 # Union of current and previous mask 35 | self.prev_mask.copy_(new_mask) 36 | 37 | return (self.param * mask.float()).square().mean() 38 | 39 | 40 | @app.command() 41 | def main( 42 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 43 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 44 | steps: int = 100, 45 | weight_decay: float = 0, 46 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 47 | trials: int = 100, 48 | win_condition_multiplier: float = 1.0, 49 | sparsity: float = 2**-6, 50 | config: Optional[str] = None, 51 | ): 52 | sparsity = configs.get(config, {}).get("sparsity", sparsity) 53 | dtype = [getattr(torch, d) for d in dtype] 54 | model = Model(sparsity=sparsity).cuda().double() 55 | 56 | def data(): 57 | return None, None 58 | 59 | # Win condition accounts for sparsity - harder to reach very low loss 60 | trial( 61 | model, 62 | data, 63 | None, 64 | loss_win_condition(win_condition_multiplier * 1e-4), 65 | steps, 66 | opt[0], 67 | dtype[0], 68 | 1, 69 | 1, 70 | weight_decay, 71 | method[0], 72 | 1, 73 | 1, 74 | failure_threshold=5, 75 | base_lr=1e-3, 76 | trials=trials, 77 | ) # More failure attempts allowed 78 | 79 | 80 | if __name__ == "__main__": 81 | app() 82 | -------------------------------------------------------------------------------- /benchmark/wide_linear.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from benchmark.utils import param_norm_win_condition, trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | 15 | configs = { 16 | "trivial": {"size": 4}, 17 | "easy": {"size": 32}, 18 | "medium": {"size": 512}, 19 | "hard": {"size": 2048}, 20 | "extreme": {"size": 8192}, 21 | "nightmare": {"size": 2**14}, 22 | } 23 | 24 | 25 | class Model(nn.Module): 26 | def __init__(self, size): 27 | super().__init__() 28 | self.param = nn.Parameter(torch.randn((size, size))) 29 | self.target = nn.Buffer(torch.triu(torch.ones_like(self.param))) 30 | 31 | def forward(self, inp): 32 | return inp @ self.param 33 | 34 | 35 | @app.command() 36 | def main( 37 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 38 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 39 | size: int = 1024, 40 | depth: int = 4, 41 | batch: int = 16, 42 | steps: int = 10, 43 | weight_decay: float = 0, 44 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 45 | win_condition_multiplier: float = 1.0, 46 | trials: int = 10, 47 | config: Optional[str] = None, 48 | ): 49 | size = 
configs.get(config, {}).get("size", size) 50 | dtype = [getattr(torch, d) for d in dtype] 51 | model = Model(size).cuda() 52 | 53 | def data(): 54 | inp = torch.randn((batch, size), device="cuda", dtype=dtype[0]) 55 | return inp, inp.cumsum(1) 56 | 57 | trial( 58 | model, 59 | data, 60 | F.mse_loss, 61 | param_norm_win_condition(1e-7 * win_condition_multiplier, model.target), 62 | steps, 63 | opt[0], 64 | dtype[0], 65 | size, 66 | batch, 67 | weight_decay, 68 | method[0], 69 | 1, 70 | depth, 71 | failure_threshold=depth * 2, 72 | base_lr=1e-3, 73 | trials=trials, 74 | ) 75 | 76 | 77 | if __name__ == "__main__": 78 | app() 79 | -------------------------------------------------------------------------------- /benchmark/xor_digit.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import torch.nn as nn 6 | import typer 7 | from torch.nn import functional as F 8 | 9 | from benchmark.utils import loss_win_condition, trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | 15 | configs = { 16 | "trivial": {"length": 4}, 17 | "easy": {"length": 8}, 18 | "medium": {"length": 16}, 19 | "hard": {"length": 32}, 20 | "extreme": {"length": 64}, 21 | "nightmare": {"length": 96}, 22 | } 23 | 24 | 25 | class Model(nn.Module): 26 | def __init__(self, size, depth): 27 | super().__init__() 28 | self.embed = nn.Embedding(2, size) 29 | self.enc = nn.LSTM(size, size, depth, batch_first=False) 30 | self.enc.flatten_parameters() 31 | self.proj = nn.Sequential( 32 | nn.LayerNorm(size), # 33 | nn.Linear(size, 1), 34 | ) 35 | 36 | def forward(self, inp): 37 | inp = inp.transpose(0, 1) 38 | inp = self.embed(inp.squeeze(-1).long()) 39 | out, _ = torch.compiler.disable()(self.enc)(inp) 40 | return self.proj(out[-1, :]) 41 | 42 | 43 | @app.command() 44 | def main( 45 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 46 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 47 | length: int = 64, 48 | size: int = 64, 49 | depth: int = 1, 50 | batch: int = 256, 51 | steps: int = 10, 52 | weight_decay: float = 0, 53 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 54 | trials: int = 10, 55 | win_condition_multiplier: float = 1.0, 56 | config: Optional[str] = None, 57 | ): 58 | length = configs.get(config, {}).get("length", length) 59 | dtype = [getattr(torch, d) for d in dtype] 60 | torch.manual_seed(0x1239121) 61 | model = Model(size, depth).cuda() 62 | 63 | def data(): 64 | inp = torch.randn((batch, length, 1), device="cuda", dtype=dtype[0]) 65 | inp = inp > 0 66 | return inp.to(dtype[0]), (inp.sum(1) % 2).to(dtype[0]) 67 | 68 | trial( 69 | model, 70 | data, 71 | F.binary_cross_entropy_with_logits, 72 | loss_win_condition(win_condition_multiplier * 1e-3), 73 | steps, 74 | opt[0], 75 | dtype[0], 76 | size, 77 | batch, 78 | weight_decay, 79 | method[0], 80 | length, 81 | depth, 82 | failure_threshold=10, 83 | base_lr=1e-6, 84 | trials=trials, 85 | ) 86 | 87 | 88 | if __name__ == "__main__": 89 | app() 90 | -------------------------------------------------------------------------------- /benchmark/xor_digit_rnn.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import torch.nn as nn 6 | import typer 7 | 
from torch.nn import functional as F 8 | 9 | from benchmark.utils import loss_win_condition, trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | 15 | configs = { 16 | "trivial": {"length": 4}, 17 | "easy": {"length": 8}, 18 | "medium": {"length": 16}, 19 | "hard": {"length": 32}, 20 | "extreme": {"length": 64}, 21 | "nightmare": {"length": 96}, 22 | } 23 | 24 | 25 | class Model(nn.Module): 26 | def __init__(self, size, depth): 27 | super().__init__() 28 | self.embed = nn.Embedding(2, size) 29 | self.enc = nn.RNN(size, size, depth, batch_first=False) 30 | self.enc.flatten_parameters() 31 | self.proj = nn.Sequential( 32 | nn.LayerNorm(size), # 33 | nn.Linear(size, 1), 34 | ) 35 | 36 | def forward(self, inp): 37 | inp = inp.transpose(0, 1) 38 | inp = self.embed(inp.squeeze(-1).long()) 39 | out, _ = torch.compiler.disable()(self.enc)(inp) 40 | return self.proj(out[-1, :]) 41 | 42 | 43 | @app.command() 44 | def main( 45 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 46 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 47 | length: int = 64, 48 | size: int = 64, 49 | depth: int = 1, 50 | batch: int = 256, 51 | steps: int = 10, 52 | weight_decay: float = 0, 53 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 54 | trials: int = 10, 55 | win_condition_multiplier: float = 1.0, 56 | config: Optional[str] = None, 57 | ): 58 | length = configs.get(config, {}).get("length", length) 59 | 60 | dtype = [getattr(torch, d) for d in dtype] 61 | torch.manual_seed(0x1239121) 62 | model = Model(size, depth).cuda() 63 | 64 | def data(): 65 | inp = torch.randn((batch, length, 1), device="cuda", dtype=dtype[0]) 66 | inp = inp > 0 67 | return inp.to(dtype[0]), (inp.sum(1) % 2).to(dtype[0]) 68 | 69 | trial( 70 | model, 71 | data, 72 | F.binary_cross_entropy_with_logits, 73 | loss_win_condition(win_condition_multiplier * 1e-3), 74 | steps, 75 | opt[0], 76 | dtype[0], 77 | size, 78 | batch, 79 | weight_decay, 80 | method[0], 81 | length, 82 | depth, 83 | failure_threshold=10, 84 | base_lr=1e-6, 85 | trials=trials, 86 | ) 87 | 88 | 89 | if __name__ == "__main__": 90 | app() 91 | -------------------------------------------------------------------------------- /benchmark/xor_sequence.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import torch.nn as nn 6 | import typer 7 | from torch.nn import functional as F 8 | 9 | from benchmark.utils import loss_win_condition, trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | 15 | configs = { 16 | "trivial": {"length": 2}, 17 | "easy": {"length": 4}, 18 | "medium": {"length": 6}, 19 | "hard": {"length": 9}, 20 | "extreme": {"length": 12}, 21 | "nightmare": {"length": 14}, 22 | } 23 | 24 | 25 | class Model(nn.Module): 26 | def __init__(self, size, depth): 27 | super().__init__() 28 | self.embed0 = nn.Embedding(2, size) 29 | self.embed1 = nn.Embedding(2, size) 30 | self.enc = nn.LSTM(size, size, depth, batch_first=False) 31 | self.dec = nn.LSTM(size, size, depth, batch_first=False) 32 | self.enc.flatten_parameters() 33 | self.dec.flatten_parameters() 34 | self.proj = nn.Sequential( 35 | nn.LayerNorm(size), # 36 | nn.Linear(size, 1), 37 | ) 38 | 39 | def forward(self, inp): 40 | i0, i1 = inp.chunk(2, 1) 41 
| i0 = i0.transpose(0, 1) 42 | i1 = i1.transpose(0, 1) 43 | i0 = self.embed0(i0) 44 | i1 = self.embed1(i1) 45 | _, state = torch.compiler.disable()(self.enc)(i0) 46 | out, _ = torch.compiler.disable()(self.dec)(i1, state) 47 | return self.proj(out.transpose(0, 1)) 48 | 49 | 50 | @app.command() 51 | def main( 52 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 53 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 54 | length: int = 14, 55 | size: int = 16, 56 | depth: int = 1, 57 | batch: int = 256, 58 | steps: int = 100, 59 | weight_decay: float = 0, 60 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 61 | win_condition_multiplier: float = 1, 62 | trials: int = 10, 63 | config: Optional[str] = None, 64 | ): 65 | length = configs.get(config, {}).get("length", length) 66 | 67 | dtype = [getattr(torch, d) for d in dtype] 68 | torch.manual_seed(0x1239121) 69 | model = Model(size, depth).cuda() 70 | 71 | def data(): 72 | inp = torch.randn((batch, 2 * length, 1), device="cuda", dtype=dtype[0]) 73 | inp = inp > 0 74 | i0, i1 = inp.chunk(2, 1) 75 | xored = torch.logical_xor(i0, i1) 76 | return inp.long().squeeze(-1), xored.to(dtype[0]) 77 | 78 | trial( 79 | model, 80 | data, 81 | F.binary_cross_entropy_with_logits, 82 | loss_win_condition(win_condition_multiplier * 1e-2), 83 | steps, 84 | opt[0], 85 | dtype[0], 86 | size, 87 | batch, 88 | weight_decay, 89 | method[0], 90 | length, 91 | depth, 92 | failure_threshold=10, 93 | base_lr=0.001, 94 | trials=trials, 95 | ) 96 | 97 | 98 | if __name__ == "__main__": 99 | app() 100 | -------------------------------------------------------------------------------- /benchmark/xor_sequence_rnn.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import torch.nn as nn 6 | import typer 7 | from torch.nn import functional as F 8 | 9 | from benchmark.utils import loss_win_condition, trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | 15 | 16 | configs = { 17 | "trivial": {"length": 2}, 18 | "easy": {"length": 4}, 19 | "medium": {"length": 6}, 20 | "hard": {"length": 9}, 21 | "extreme": {"length": 12}, 22 | "nightmare": {"length": 14}, 23 | } 24 | 25 | 26 | class Model(nn.Module): 27 | def __init__(self, size, depth): 28 | super().__init__() 29 | self.embed0 = nn.Embedding(2, size) 30 | self.embed1 = nn.Embedding(2, size) 31 | self.enc = nn.RNN(size, size, depth, batch_first=False) 32 | self.dec = nn.RNN(size, size, depth, batch_first=False) 33 | self.enc.flatten_parameters() 34 | self.dec.flatten_parameters() 35 | self.proj = nn.Sequential( 36 | nn.LayerNorm(size), # 37 | nn.Linear(size, 1), 38 | ) 39 | 40 | def forward(self, inp): 41 | i0, i1 = inp.chunk(2, 1) 42 | i0 = i0.transpose(0, 1) 43 | i1 = i1.transpose(0, 1) 44 | i0 = self.embed0(i0) 45 | i1 = self.embed1(i1) 46 | _, state = torch.compiler.disable()(self.enc)(i0) 47 | out, _ = torch.compiler.disable()(self.dec)(i1, state) 48 | return self.proj(out.transpose(0, 1)) 49 | 50 | 51 | @app.command() 52 | def main( 53 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 54 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 55 | length: int = 14, 56 | size: int = 16, 57 | depth: int = 1, 58 | batch: int = 256, 59 | steps: int = 100, 60 | weight_decay: float = 0, 
61 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 62 | win_condition_multiplier: float = 1, 63 | trials: int = 10, 64 | config: Optional[str] = None, 65 | ): 66 | length = configs.get(config, {}).get("length", length) 67 | 68 | dtype = [getattr(torch, d) for d in dtype] 69 | torch.manual_seed(0x1239121) 70 | model = Model(size, depth).cuda() 71 | 72 | def data(): 73 | inp = torch.randn((batch, 2 * length, 1), device="cuda", dtype=dtype[0]) 74 | inp = inp > 0 75 | i0, i1 = inp.chunk(2, 1) 76 | xored = torch.logical_xor(i0, i1) 77 | return inp.long().squeeze(-1), xored.to(dtype[0]) 78 | 79 | trial( 80 | model, 81 | data, 82 | F.binary_cross_entropy_with_logits, 83 | loss_win_condition(win_condition_multiplier * 1e-2), 84 | steps, 85 | opt[0], 86 | dtype[0], 87 | size, 88 | batch, 89 | weight_decay, 90 | method[0], 91 | length, 92 | depth, 93 | failure_threshold=10, 94 | base_lr=0.001, 95 | trials=trials, 96 | ) 97 | 98 | 99 | if __name__ == "__main__": 100 | app() 101 | -------------------------------------------------------------------------------- /benchmark/xor_spot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inspired by https://github.com/lixilinx/psgd_torch/blob/master/rnn_xor_problem_general_purpose_preconditioner.py 3 | This version is strongly simplified but follows the same basic idea: 4 | 1) Generate random sequence 5 | 2) Mark two spots 6 | 3) Train a model to predict the xor of the two spots 7 | This does NOT elicit memory in the RNN, but it does force it to learn a pointwise forget mechanism. 8 | """ 9 | 10 | import itertools 11 | from typing import List, Optional 12 | 13 | import torch 14 | import torch.backends.opt_einsum 15 | import torch.nn as nn 16 | import typer 17 | 18 | from benchmark.utils import loss_win_condition, trial 19 | from heavyball.utils import set_torch 20 | 21 | app = typer.Typer(pretty_exceptions_enable=False) 22 | set_torch() 23 | configs = { 24 | "trivial": {"length": 4}, 25 | "easy": {"length": 8}, 26 | "medium": {"length": 16}, 27 | "hard": {"length": 32}, 28 | "extreme": {"length": 64}, 29 | "nightmare": {"length": 96}, 30 | } 31 | 32 | 33 | class Model(nn.Module): 34 | def __init__(self, size, depth): 35 | super().__init__() 36 | self.embed = nn.Embedding(4, size) 37 | self.enc = nn.LSTM(size, size, depth, batch_first=False) 38 | self.enc.flatten_parameters() 39 | self.proj = nn.Sequential( 40 | nn.LayerNorm(size), # 41 | nn.Linear(size, 1), 42 | ) 43 | 44 | def forward(self, inp): 45 | inp = self.embed(inp.squeeze(-1).long()) 46 | inp = inp[0] + inp[1] 47 | out, _ = torch.compiler.disable()(self.enc)(inp.transpose(0, 1)) 48 | return self.proj(out[-1, :]) 49 | 50 | 51 | @app.command() 52 | def main( 53 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 54 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 55 | length: int = 64, 56 | size: int = 64, 57 | depth: int = 1, 58 | batch: int = 256, 59 | steps: int = 10, 60 | weight_decay: float = 0, 61 | opt: List[str] = typer.Option( 62 | ["ForeachSOAP", "PaLMForeachSOAP", "PrecondScheduleForeachSOAP"], help="Optimizers to use" 63 | ), 64 | win_condition_multiplier: float = 1.0, 65 | trials: int = 10, 66 | config: Optional[str] = None, 67 | ): 68 | length = configs.get(config, {}).get("length", length) 69 | dtype = [getattr(torch, d) for d in dtype] 70 | 71 | for args in itertools.product(method, dtype, [(length, size, depth, batch)], opt, [weight_decay]): 72 | m, d, 
(l, s, dp, b), o, wd = args 73 | 74 | model = Model(s, dp).cuda() 75 | 76 | def data(): 77 | inp = torch.randn((b, l, 1), device="cuda", dtype=d) 78 | inp = inp > 0 79 | zeros = torch.zeros_like(inp) 80 | zeros[:, torch.randint(0, l, (b,), device="cuda")] = 1 81 | zeros[:, torch.randint(0, l, (b,), device="cuda")] = 1 82 | target = (inp * zeros).sum(1) % 2 83 | return torch.stack((inp, zeros + 2), 0).to(d), target.to(d) 84 | 85 | trial( 86 | model, 87 | data, 88 | torch.nn.functional.binary_cross_entropy_with_logits, 89 | loss_win_condition(win_condition_multiplier * 1e-2), 90 | steps, 91 | o, 92 | d, 93 | s, 94 | b, 95 | wd, 96 | m, 97 | l, 98 | dp, 99 | failure_threshold=10, 100 | trials=trials, 101 | ) 102 | 103 | 104 | if __name__ == "__main__": 105 | app() 106 | -------------------------------------------------------------------------------- /benchmark/xor_spot_rnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inspired by https://github.com/lixilinx/psgd_torch/blob/master/rnn_xor_problem_general_purpose_preconditioner.py 3 | This version is strongly simplified but follows the same basic idea: 4 | 1) Generate random sequence 5 | 2) Mark two spots 6 | 3) Train a model to predict the xor of the two spots 7 | This does NOT elicit memory in the RNN, but it does force it to learn a pointwise forget mechanism. 8 | """ 9 | 10 | import itertools 11 | from typing import List, Optional 12 | 13 | import torch 14 | import torch.backends.opt_einsum 15 | import torch.nn as nn 16 | import typer 17 | 18 | from benchmark.utils import loss_win_condition, trial 19 | from heavyball.utils import set_torch 20 | 21 | app = typer.Typer(pretty_exceptions_enable=False) 22 | set_torch() 23 | 24 | configs = { 25 | "trivial": {"length": 4}, 26 | "easy": {"length": 8}, 27 | "medium": {"length": 16}, 28 | "hard": {"length": 32}, 29 | "extreme": {"length": 64}, 30 | "nightmare": {"length": 96}, 31 | } 32 | 33 | 34 | class Model(nn.Module): 35 | def __init__(self, size, depth): 36 | super().__init__() 37 | self.embed = nn.Embedding(4, size) 38 | self.enc = nn.RNN(size, size, depth, batch_first=False) 39 | self.enc.flatten_parameters() 40 | self.proj = nn.Sequential( 41 | nn.LayerNorm(size), # 42 | nn.Linear(size, 1), 43 | ) 44 | 45 | def forward(self, inp): 46 | inp = self.embed(inp.squeeze(-1).long()) 47 | inp = inp[0] + inp[1] 48 | out, _ = torch.compiler.disable()(self.enc)(inp.transpose(0, 1)) 49 | return self.proj(out[-1, :]) 50 | 51 | 52 | @app.command() 53 | def main( 54 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 55 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 56 | length: int = 64, 57 | size: int = 64, 58 | depth: int = 1, 59 | batch: int = 256, 60 | steps: int = 10, 61 | weight_decay: float = 0, 62 | opt: List[str] = typer.Option( 63 | ["ForeachSOAP", "PaLMForeachSOAP", "PrecondScheduleForeachSOAP"], help="Optimizers to use" 64 | ), 65 | win_condition_multiplier: float = 1.0, 66 | trials: int = 10, 67 | config: Optional[str] = None, 68 | ): 69 | length = configs.get(config, {}).get("length", length) 70 | 71 | dtype = [getattr(torch, d) for d in dtype] 72 | 73 | for args in itertools.product(method, dtype, [(length, size, depth, batch)], opt, [weight_decay]): 74 | m, d, (l, s, dp, b), o, wd = args 75 | 76 | model = Model(s, dp).cuda() 77 | 78 | def data(): 79 | inp = torch.randn((b, l, 1), device="cuda", dtype=d) 80 | inp = inp > 0 81 | zeros = torch.zeros_like(inp) 82 | zeros[:, 
torch.randint(0, l, (b,), device="cuda")] = 1 83 | zeros[:, torch.randint(0, l, (b,), device="cuda")] = 1 84 | target = (inp * zeros).sum(1) % 2 85 | return torch.stack((inp, zeros + 2), 0).to(d), target.to(d) 86 | 87 | trial( 88 | model, 89 | data, 90 | torch.nn.functional.binary_cross_entropy_with_logits, 91 | loss_win_condition(win_condition_multiplier * 1e-2), 92 | steps, 93 | o, 94 | d, 95 | s, 96 | b, 97 | wd, 98 | m, 99 | l, 100 | dp, 101 | failure_threshold=10, 102 | trials=trials, 103 | ) 104 | 105 | 106 | if __name__ == "__main__": 107 | app() 108 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | rm -rf dist/* 2 | python -m build 3 | python -m twine upload dist/* 4 | rm -rf dist/ build/ heavyball.egg-info/ 5 | -------------------------------------------------------------------------------- /docs/assets/benchmark_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/docs/assets/benchmark_matrix.png -------------------------------------------------------------------------------- /docs/assets/early_stopping.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/docs/assets/early_stopping.png -------------------------------------------------------------------------------- /docs/assets/hiker_analogy_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/docs/assets/hiker_analogy_diagram.png -------------------------------------------------------------------------------- /docs/assets/psgd_efficiency_cache.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/docs/assets/psgd_efficiency_cache.png -------------------------------------------------------------------------------- /docs/assets/psgd_efficiency_cache_triu_as_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/docs/assets/psgd_efficiency_cache_triu_as_line.png -------------------------------------------------------------------------------- /docs/assets/psgd_efficiency_triu_as_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/docs/assets/psgd_efficiency_triu_as_line.png -------------------------------------------------------------------------------- /docs/assets/saddle_point_comparison.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/docs/assets/saddle_point_comparison.gif -------------------------------------------------------------------------------- /docs/benchmark.md: -------------------------------------------------------------------------------- 1 | # Beyond the Leaderboard: A Diagnostic Benchmark for Optimizer Reliability 2 | 3 | Traditional machine learning benchmarks often 
incentivize "teaching to the test," a race to top leaderboards that may 4 | not reflect real-world performance. This practice can mask a critical issue: **silent failures**. Optimizers may 5 | converge without reporting an error, yet settle in a suboptimal solution, leading to models that underperform in subtle 6 | but significant ways. Such failures are costly, as they often go undetected until significant downstream damage has occurred. 7 | 8 | The HeavyBall Benchmark was created to address this problem. It's not another leaderboard but a diagnostic tool 9 | designed to expose hidden optimizer weaknesses. Each test provides a clear, pass/fail check against known optimization 10 | challenges, creating a detailed map of an optimizer's strengths and weaknesses. 11 | 12 | ## The Problem of Silent Failures 13 | 14 | A silent failure occurs when an optimizer converges without error yet settles in a suboptimal basin, leading to poor 15 | downstream model performance. A descending loss curve can be profoundly misleading. Without a clear point of comparison, 16 | an optimizer that has simply gotten stuck can appear identical to one that has found a genuinely good solution. 17 | 18 |  19 | *Figure 3 from ["Black Box Lie Group Preconditioners for SGD"](https://arxiv.org/abs/2211.04422). As shown, SGD (blue) 20 | appears converged, but PSGD (black) finds a significantly better basin. This highlights a fundamental capability gap.* 21 | 22 | The graph above illustrates this trap. All three optimizers seem to converge as their loss curves flatten. An 23 | unsuspecting practitioner might see the flatlining loss of the blue curve (SGD) and conclude the job is done. Yet, the 24 | black curve (PSGD) consistently finds a solution that is an order of magnitude better. This isn't a matter of an unlucky 25 | random seed; it's a capability gap. No amount of rerunning SGD will find the better solution that PSGD locates with 26 | ease. 27 | 28 | This capability gap often goes unnoticed. The training logs provide no explicit error, leaving developers to discover 29 | the underperformance only through expensive downstream evaluation. 30 | 31 | ## The HeavyBall Diagnostic Benchmark 32 | 33 | Instead of a single score, the HeavyBall Benchmark provides a granular, diagnostic map of an optimizer's performance. 34 | It's built on a suite of over 150 independent, pass/fail tests, each targeting a specific, well-understood challenge 35 | known to cause hidden failures. 36 | 37 |  38 | *This map reveals systemic strengths and weaknesses in popular optimizers. Each cell represents a pass/fail outcome for 39 | a solver (column) on a specific task (row). Darker blue indicates a higher success rate.* 40 | 41 | There is no partial credit; the optimizer either solves the problem or it does not. This binary outcome makes failure 42 | modes explicit, turning abstract weaknesses into concrete, observable data. 43 | 44 | ### Experimental Setup 45 | 46 | To ensure a fair comparison, we gave each optimizer a generous budget: 1,000 hyperparameter trials per task, with each 47 | trial running for up to 1,000,000 steps. This setup tests the optimizer's raw capability, not just its default settings. 48 | The `Attempts` column in our results reflects the average number of trials needed for success; a lower number indicates 49 | that the method is easier to tune for the problems it can solve. 50 | 51 | ## Benchmark Results: No Single Best Optimizer 52 | 53 | Our results show that no single optimizer dominates. 
Even the best-performing optimizers exhibit surprising weaknesses, 54 | reinforcing the need for diagnostic rather than purely comparative evaluation. 55 | 56 | | Optimizer | Cautious¹ | Mars² | Success | Attempts | Avg Runtime (s) | 57 | |:---------------|:----------|:------|:--------|:---------|:----------------| 58 | | PSGDKron | No | No | 77.0% | 73.2 | 8240 | 59 | | NewtonPSGDKron | No | No | 77.0% | 80.5 | 9052 | 60 | | AdamW | Yes | No | 75.7% | 61.2 | 8072 | 61 | | ForeachSOAP | No | No | 72.5% | 77.9 | 7827 | 62 | | AdamW | No | No | 72.3% | 107.8 | 10029 | 63 | | MuonLaProp | No | No | 68.2% | 82.7 | 10141 | 64 | | RMSprop | No | No | 55.6% | 114.4 | 10725 | 65 | | Muon | No | No | 51.0% | 129.1 | 14525 | 66 | 67 | ¹ `Cautious`: Zeroes any update component whose sign disagrees with the current gradient, so the optimizer never steps against it 68 |
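To make the pass/fail protocol and the `Attempts` column concrete, here is a minimal, self-contained sketch of such a diagnostic loop. It is **not** the repository's harness (the real benchmark drives the `benchmark/` scripts through `trial(...)` and appears to use an Optuna-based hyperparameter search, per `benchmark/requirements.txt`); the toy objective, the 1e-3 win threshold, the step budget, and the helper names `run_trial` and `random_search` are illustrative assumptions, loosely modeled on `benchmark/scale_invariant.py`.

```python
"""Illustrative pass/fail diagnostic -- NOT the HeavyBall benchmark harness.

Assumed pieces: the toy objective, the 1e-3 win threshold, the step budget,
and the helper names below. Loosely modeled on benchmark/scale_invariant.py.
"""

import random

import torch
from torch import nn


def objective(x: torch.Tensor) -> torch.Tensor:
    # Scale-varied toy loss: parameters live on very different scales.
    return torch.log1p(x.square()).mean()


def run_trial(lr: float, steps: int = 20_000, threshold: float = 1e-3) -> bool:
    """Train one hyperparameter configuration; 'pass' means the loss crosses the threshold."""
    torch.manual_seed(0)  # every trial solves the same problem instance
    scales = torch.logspace(-2, 1, 64)
    param = nn.Parameter(scales * torch.randn(64))
    opt = torch.optim.AdamW([param], lr=lr, weight_decay=0.0)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=steps)
    for _ in range(steps):
        opt.zero_grad()
        loss = objective(param)
        loss.backward()
        opt.step()
        sched.step()
        if loss.item() < threshold:
            return True  # win condition met -> the task counts as solved
    return False  # silent failure: training "finished" but never hit the target


def random_search(budget: int = 100, seed: int = 0) -> None:
    """Binary outcome per task, plus how many attempts were needed (cf. the Attempts column)."""
    rng = random.Random(seed)
    for attempt in range(1, budget + 1):
        lr = 10 ** rng.uniform(-3, -0.5)  # log-uniform learning-rate proposals
        if run_trial(lr):
            print(f"PASS after {attempt} attempt(s) (lr={lr:.2e})")
            return
    print(f"FAIL within {budget} attempts")


if __name__ == "__main__":
    random_search()
```

Averaging the attempt counts over the tasks an optimizer does solve gives the tunability signal reported in the `Attempts` column; tasks it never solves within the budget surface as explicit failures instead of silently inflating a leaderboard score.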