├── .github └── workflows │ └── ci.yaml ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── benchmark_matrix.png ├── delayed_psgd.png └── memory.png ├── benchmark ├── README.md ├── adversarial_gradient.py ├── batch_size_scaling.py ├── beale.py ├── benchmark_results.md ├── char_rnn.py ├── constrained_optimization.py ├── create_hiker_analogy_diagram.py ├── discontinuous_gradient.py ├── dynamic_landscape.py ├── exploding_gradient.py ├── format_results.py ├── generate_animations.py ├── gradient_delay.py ├── gradient_noise_scale.py ├── grokking.py ├── layer_wise_scale.py ├── loss_contour.py ├── merge_logs.py ├── minimax.py ├── momentum_utilization.py ├── noisy_matmul.py ├── parameter_scale.py ├── plateau_navigation.py ├── postprocess_requirements.txt ├── powers.py ├── powers_varying_target.py ├── quadratic_varying_scale.py ├── quadratic_varying_target.py ├── rastrigin.py ├── relu_boundaries.py ├── requirements.txt ├── rosenbrock.py ├── run_all_benchmarks.py ├── saddle_point.py ├── scale_invariant.py ├── shakespeare.txt ├── sparse_gradient.py ├── utils.py ├── wide_linear.py ├── xor_digit.py ├── xor_digit_rnn.py ├── xor_sequence.py ├── xor_sequence_rnn.py ├── xor_spot.py └── xor_spot_rnn.py ├── build.sh ├── docs ├── assets │ ├── benchmark_matrix.png │ ├── early_stopping.png │ ├── hiker_analogy_diagram.png │ ├── psgd_efficiency_cache.png │ ├── psgd_efficiency_cache_triu_as_line.png │ ├── psgd_efficiency_triu_as_line.png │ └── saddle_point_comparison.gif ├── benchmark.md └── psgd_efficiency.md ├── examples ├── autoencoder.py ├── lra.py └── modify_functions.py ├── heavyball ├── __init__.py ├── chainable.py ├── helpers.py └── utils.py ├── pre-commit.yaml ├── pyproject.toml └── test ├── benchmark_psgd_lb.py ├── readme.md ├── test_bf16_params.py ├── test_bf16_q.py ├── test_bf16_storage.py ├── test_caution.py ├── test_channels_last.py ├── test_closure.py ├── test_ema.py ├── test_foreach.py ├── test_hook.py ├── test_mars.py ├── test_memory.py ├── test_memory_leak.py ├── test_merge.py ├── test_no_grad.py ├── test_psgd_precond_init_stability.py ├── test_save_restore.py ├── test_soap.py └── test_stochastic_updates.py /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: Continuous integration 2 | 3 | on: 4 | pull_request: 5 | workflow_dispatch: 6 | 7 | jobs: 8 | pre-commit: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - uses: actions/setup-python@v5 13 | with: 14 | python-version: "3.11" 15 | - name: Run pre-commit hooks 16 | uses: pre-commit/action@v3.0.1 17 | with: 18 | extra_args: --all-files --config pre-commit.yaml 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | splinecam 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2022, Lucas Nestler 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # heavyball 2 | 3 | [![PyPI version](https://img.shields.io/pypi/v/heavyball?color=blue)][pypi] 4 | [![License](https://img.shields.io/badge/license-BSD--3--Clause-blue.svg)][license] 5 | 6 | _High-performance, extensible, chainable optimizers for PyTorch._ 7 | 8 | ## Why heavyball 9 | 10 | - **Lightning-Fast Training**: Batched `foreach` operations deliver significant speedups on large models. 11 | - **Adaptive & Extensible**: Built-in AdamW, RMSprop, Schedule-Free algorithms, and PaLM-inspired schedules. 12 | - **Plug-and-Play**: Drop-in replacements for `torch.optim` with seamless integration. 13 | - **Customizable**: Chainable API lets you compose optimizers and transforms (MARS correction, cautious updates, orthogonal updates). 14 | - **Battle-Tested**: Extensive benchmarks and real-world examples included. 15 | 16 | ## Key Features 17 | 18 | - Foreach-based optimizers: `ForeachAdamW`, `ForeachRMSprop`, `ForeachSFAdamW`, `Muon`, `ADOPT`, `MSAM`, … 19 | - Schedule-Free optimizers with dynamic learning rate adaptation. 20 | - Advanced update rules: MARS correction, cautious updates, PaLM beta2 scheduling. 21 | - Chainable transforms for custom optimization recipes. 22 | - Comprehensive benchmark suite (`benchmark/`). 23 | - Detailed documentation and example-driven tutorials. 24 | 25 | ## Quickstart 26 | 27 | **Install:** 28 | ```bash 29 | pip install heavyball 30 | ``` 31 | 32 | **Basic usage:** 33 | ```python 34 | import torch 35 | from torch import nn 36 | from heavyball import ForeachAdamW 37 | 38 | model = nn.Sequential( 39 | nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10) 40 | ) 41 | optimizer = ForeachAdamW(model.parameters(), lr=1e-3) 42 | 43 | for data, target in dataloader: 44 | optimizer.zero_grad() 45 | output = model(data) 46 | loss = torch.nn.functional.cross_entropy(output, target) 47 | loss.backward() 48 | optimizer.step() 49 | ``` 50 | 51 | ## Benchmarks 52 | 53 | > Reproduce benchmarks with: 54 | > ```bash 55 | > python3 -m benchmark.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880 56 | > ``` 57 | 58 | 59 | ## Contributing 60 | 61 | We welcome contributions! Please check the [issue tracker][tracker] and follow these steps: 62 | 1. Fork the repo and create a feature branch. 63 | 2. Install dev dependencies: `pip install -e .[dev]`. 64 | 3. Run tests: `pytest`. 65 | 4. Submit a pull request. 66 | 67 | ## License 68 | 69 | BSD 3-Clause — see the [LICENSE](LICENSE) file. 70 | 71 | --- 72 |

73 | Made by the HeavyBall team. 74 |

75 | 76 | [pypi]: https://pypi.org/project/heavyball/ 77 | [license]: LICENSE 78 | [tracker]: https://github.com/HomebrewML/HeavyBall/issues 79 | -------------------------------------------------------------------------------- /assets/benchmark_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/assets/benchmark_matrix.png -------------------------------------------------------------------------------- /assets/delayed_psgd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/assets/delayed_psgd.png -------------------------------------------------------------------------------- /assets/memory.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/assets/memory.png -------------------------------------------------------------------------------- /benchmark/README.md: -------------------------------------------------------------------------------- 1 | # HeavyBall Benchmark Suite 2 | 3 | This repository contains a suite of benchmark problems designed to test the capabilities and robustness of `torch.optim` optimizers. The framework is designed to be modular and extensible, allowing for the easy addition of new benchmarks and optimizers. 4 | 5 | ## Setup 6 | 7 | To get started, install the required dependencies: 8 | 9 | ```bash 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ## Usage 14 | 15 | To run the entire benchmark suite with a specific set of optimizers, use the `run_all_benchmarks.py` script. 16 | 17 | For example, to run the benchmarks for Adam and SGD on easy and medium difficulties: 18 | ```bash 19 | python run_all_benchmarks.py --opt Adam --opt SGD --difficulties easy medium 20 | ``` 21 | 22 | You can specify multiple optimizers and difficulties. The results will be written to `benchmark_results.md`. 23 | 24 | For more options, see: 25 | ```bash 26 | python run_all_benchmarks.py --help 27 | ``` 28 | 29 | ### Core Philosophy and Design Principles 30 | 31 | Based on the analysis of [`benchmark_template.py`](benchmark_template.py), [`optimizer_template.py`](optimizer_template.py), and [`utils.py`](utils.py), the core framework is a modular and extensible system designed for benchmarking `torch` optimizers. It is composed of three primary components: 32 | 33 | 1. **Benchmarks**: Each benchmark is a self-contained optimization problem defined by a class. It provides a loss function to be minimized, the parameters to be optimized, and a specific success condition. 34 | 2. **Optimizers**: These are expected to adhere to the `torch.optim.Optimizer` interface. A template is provided to facilitate the creation of new custom optimizers. 35 | 3. **Benchmarking Harness**: The [`utils.py`](utils.py) script contains the engine that runs the benchmarks. It includes a sophisticated `Validator` class for monitoring convergence and detecting failures, an `Objective` class that encapsulates the entire process of running a benchmark (including hyperparameter search with `optuna`), and a high-level `trial` function to orchestrate the whole process. 
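To make the interplay of these three components concrete, here is a deliberately simplified sketch of what the harness does conceptually: it instantiates a benchmark, hands the benchmark's parameters to a `torch.optim` optimizer, and steps until the benchmark reports success. This is not the actual `Objective`/`trial` code from [`utils.py`](utils.py) — the real harness additionally runs an `optuna` hyperparameter search and the `Validator` failure checks — and the `ToyBenchmark` class, learning rates, and thresholds below are made up purely for illustration.

```python
import torch
from torch import nn


class ToyBenchmark:
    """Hypothetical benchmark following the interface described above."""

    def __init__(self, device: str = "cpu"):
        self.params = [nn.Parameter(torch.randn(32, device=device))]

    def __call__(self) -> torch.Tensor:
        # Scalar loss the optimizer tries to minimize.
        return self.params[0].square().sum()

    def has_succeeded(self) -> bool:
        # Win condition: loss below a fixed threshold.
        with torch.no_grad():
            return self().item() < 1e-6

    def get_params(self) -> list[torch.Tensor]:
        return self.params


def run_once(lr: float, steps: int = 200) -> bool:
    """Simplified stand-in for the harness: run one benchmark instance to completion."""
    bench = ToyBenchmark()
    opt = torch.optim.SGD(bench.get_params(), lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        bench().backward()
        opt.step()
        if bench.has_succeeded():
            return True
    return False


# The real harness searches hyperparameters with optuna; a small grid suffices for the sketch.
for lr in (1e-3, 1e-2, 1e-1):
    print(f"lr={lr:g}: {'passed' if run_once(lr) else 'failed'}")
```

The same loop shape underlies every benchmark in the suite; only the objective and the win condition change.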
36 | 37 | #### Design Principles 38 | 39 | The framework is built upon the following design principles: 40 | 41 | * **Modularity**: The separation of benchmarks, optimizers, and the runner logic allows for easy extension. New benchmarks or optimizers can be added by implementing simple, well-defined interfaces without needing to alter the core execution logic. 42 | * **Automation**: The framework automates the process of running benchmarks, including a hyperparameter search for learning rate and momentum coefficients. This simplifies the process of evaluating an optimizer's performance across various problems. 43 | * **Robustness**: The `Validator` class provides sophisticated logic to terminate runs that are not making progress, saving computational resources. It checks for multiple failure conditions, such as stagnating loss and lack of improvement over time. 44 | * **PyTorch Native**: The framework is built on PyTorch and leverages its core components like `autograd` for gradient computation and the standard `Optimizer` class structure. 45 | 46 | #### Benchmark Structure 47 | 48 | A benchmark is structured as a class, as defined in [`benchmark_template.py`](benchmark_template.py:4), with the following key methods: 49 | 50 | * **`__init__(self, device)`**: Initializes the benchmark's state. This is where the model parameters and any necessary data are created. The parameters to be optimized are stored in `self.params`. 51 | * **`__call__(self) -> torch.Tensor`**: This is the core of the benchmark, defining the objective function. It computes and returns a scalar loss tensor that the optimizer will try to minimize. 52 | * **`has_succeeded(self) -> bool`**: Defines the "win condition" for the benchmark. It returns `True` if the optimization is considered successful, typically based on the loss value falling below a predefined threshold. 53 | * **`get_params(self) -> list[torch.Tensor]`**: Returns the list of `torch.Tensor` parameters that will be optimized. 54 | 55 | ### Existing Benchmarks 56 | 57 | | File | Category | Description | 58 | |---|---|---| 59 | | [`adversarial_gradient.py`](adversarial_gradient.py) | Noise Robustness | Tests optimizer's robustness to an oscillating adversarial component added to the gradient, simulating adversarial noise. | 60 | | [`batch_size_scaling.py`](batch_size_scaling.py) | Noise Robustness | Tests how optimizers handle noise at different scales, which are modulated by a simulated, randomly changing batch size. | 61 | | [`beale.py`](beale.py) | General/Classic | Implements the Beale function, a classic optimization benchmark with sharp valleys, to test general performance. Also includes visualization. | 62 | | [`char_rnn.py`](char_rnn.py) | Sequence & Memory (RNN/LSTM) | A character-level language model using an LSTM on text data, testing performance on a sequence task requiring memory. | 63 | | [`constrained_optimization.py`](constrained_optimization.py) | Landscape Traversal | A simple quadratic objective with a penalty-based constraint, creating a sharp ridge in the loss landscape for the optimizer to navigate. | 64 | | [`discontinuous_gradient.py`](discontinuous_gradient.py) | Gradient Characteristics | Tests optimizer robustness to non-smooth landscapes by using an objective function with a discontinuous gradient at the origin. | 65 | | [`dynamic_landscape.py`](dynamic_landscape.py) | Dynamic Environments | Tests an optimizer's ability to track a continuously shifting target in a non-stationary loss landscape. 
| 66 | | [`exploding_gradient.py`](exploding_gradient.py) | Gradient Characteristics | Tests an optimizer's numerical stability and handling of extreme gradient values by using an exponential function that causes gradients to grow rapidly. | 67 | | [`gradient_delay.py`](gradient_delay.py) | Gradient Characteristics | Tests an optimizer's ability to handle asynchronous or delayed updates by using gradients from previous steps. | 68 | | [`gradient_noise_scale.py`](gradient_noise_scale.py) | Noise Robustness | Tests an optimizer's ability to handle dynamically changing noise levels, where the noise scale anneals over time. | 69 | | [`grokking.py`](grokking.py) | Landscape Traversal | Tests for 'grokking' by training a model on a modular arithmetic task, examining if the optimizer can find a generalizable solution after a long period of memorization. | 70 | | [`layer_wise_scale.py`](layer_wise_scale.py) | Multi-Scale & Conditioning | Tests an optimizer's ability to handle parameters with vastly different gradient scales by scaling the loss contribution of different layers. | 71 | | [`minimax.py`](minimax.py) | Landscape Traversal | Implements a minimax objective function which creates a saddle point, testing the optimizer's ability to escape such points. | 72 | | [`momentum_utilization.py`](momentum_utilization.py) | Landscape Traversal | Tests the effective use of momentum by creating an oscillating loss landscape with many local minima. | 73 | | [`noisy_matmul.py`](noisy_matmul.py) | Gradient Characteristics | Tests optimizer stability in deep networks by performing a sequence of matrix multiplications, which can lead to exploding or vanishing gradients. | 74 | | [`parameter_scale.py`](parameter_scale.py) | Multi-Scale & Conditioning | Tests an optimizer's ability to handle parameters initialized at widely different scales, creating a poorly conditioned problem. | 75 | | [`plateau_navigation.py`](plateau_navigation.py) | Landscape Traversal | Tests an optimizer's ability to navigate a loss landscape with a large, flat plateau region surrounded by a steep cliff. | 76 | | [`powers_varying_target.py`](powers_varying_target.py) | Multi-Scale & Conditioning | Creates a complex, poorly conditioned landscape by raising parameters to different powers against a non-zero, varying target. | 77 | | [`powers.py`](powers.py) | Multi-Scale & Conditioning | Creates a poorly conditioned problem by raising parameters to various powers, resulting in gradients of different magnitudes. | 78 | | [`quadratic_varying_scale.py`](quadratic_varying_scale.py) | Multi-Scale & Conditioning | Tests handling of ill-conditioned problems by creating a quadratic objective where each parameter's gradient has a different scale. | 79 | | [`quadratic_varying_target.py`](quadratic_varying_target.py) | General/Classic | A simple quadratic bowl (sphere) benchmark where the minimum is at a non-zero target vector. | 80 | | [`rastrigin.py`](rastrigin.py) | General/Classic | Implements the Rastrigin function, a classic highly multi-modal benchmark for testing global optimization capabilities. | 81 | | [`rosenbrock.py`](rosenbrock.py) | General/Classic | Implements the Rosenbrock function, a classic benchmark known for its narrow, banana-shaped valley that is difficult to navigate. | 82 | | [`saddle_point.py`](saddle_point.py) | Landscape Traversal | Tests an optimizer's ability to escape a classic saddle point (x^2 - y^2) and visualizes the optimization paths. 
| 83 | | [`scale_invariant.py`](scale_invariant.py) | Multi-Scale & Conditioning | Tests an optimizer's handling of parameters at different scales by using a logarithmic objective on parameters initialized across many orders of magnitude. | 84 | | [`sparse_gradient.py`](sparse_gradient.py) | Gradient Characteristics | Tests an optimizer's performance when gradients are sparse, which is simulated by randomly masking parameter updates. | 85 | | [`wide_linear.py`](wide_linear.py) | Multi-Scale & Conditioning | Tests optimizer performance on a model with a very wide linear layer, which can present conditioning challenges. | 86 | | [`xor_digit_rnn.py`](xor_digit_rnn.py) | Sequence & Memory (RNN/LSTM) | Tests an RNN's ability to solve the parity task (XOR sum) on a sequence of bits, requiring it to maintain a memory state. | 87 | | [`xor_digit.py`](xor_digit.py) | Sequence & Memory (RNN/LSTM) | Tests an LSTM's ability to solve the parity task (XOR sum) on a sequence of bits, a classic test of sequence memory. | 88 | | [`xor_sequence_rnn.py`](xor_sequence_rnn.py) | Sequence & Memory (RNN/LSTM) | A sequence-to-sequence task where an RNN must learn to compute the element-wise XOR of two input sequences. | 89 | | [`xor_sequence.py`](xor_sequence.py) | Sequence & Memory (RNN/LSTM) | A sequence-to-sequence task where an LSTM must learn to compute the element-wise XOR of two input sequences. | 90 | | [`xor_spot_rnn.py`](xor_spot_rnn.py) | Sequence & Memory (RNN/LSTM) | Tests an RNN's ability to learn a pointwise forget mechanism by predicting the XOR of two values at randomly marked spots in a sequence. | 91 | | [`xor_spot.py`](xor_spot.py) | Sequence & Memory (RNN/LSTM) | Tests an LSTM's ability to learn a pointwise forget mechanism by predicting the XOR of two values at randomly marked spots in a sequence. | 92 | 93 | ### Contributing 94 | 95 | This framework is designed for extensibility. You can contribute by adding new benchmarks to test optimizers against, or by implementing new optimizers to evaluate. 96 | 97 | #### Adding a New Benchmark 98 | 99 | To add a new benchmark, follow these steps: 100 | 101 | 1. **Create a new Python file** for your benchmark (e.g., `my_new_benchmark.py`). 102 | 2. **Implement the benchmark class.** Your class should follow the structure provided in [`benchmark_template.py`](benchmark_template.py:4). 103 | * **`__init__(self, device)`**: Initializes the benchmark's state. This is where the model parameters and any necessary data are created. The parameters to be optimized are stored in `self.params`. 104 | * **`__call__(self) -> torch.Tensor`**: This is the core of the benchmark, defining the objective function. It computes and returns a scalar loss tensor that the optimizer will try to minimize. 105 | * **`has_succeeded(self) -> bool`**: Defines the "win condition" for the benchmark. It returns `True` if the optimization is considered successful, typically based on the loss value falling below a predefined threshold. 106 | * **`get_params(self) -> list[torch.Tensor]`**: Returns the list of `torch.Tensor` parameters that will be optimized. 107 | 3. **Add your benchmark to `run_all_benchmarks.py`**: To include your new benchmark in the full suite, you will need to import it in [`run_all_benchmarks.py`](run_all_benchmarks.py) and add it to the list of benchmarks to be run. 108 | 4. **(Optional) Add a description to `README.md`**: Add a row to the "Benchmark Tests" table in this `README.md` to document your new benchmark. 
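Putting the steps above together, a minimal benchmark file could look like the sketch below. The file name `my_new_benchmark.py`, the quadratic objective, and the `1e-4` threshold are placeholders chosen for illustration; the only real contract is the four methods described in step 2 (see [`benchmark_template.py`](benchmark_template.py) for the authoritative template).

```python
# my_new_benchmark.py -- illustrative only; names and thresholds are placeholders
import torch


class MyNewBenchmark:
    def __init__(self, device: str = "cuda"):
        # Parameters to be optimized are stored in self.params.
        self.params = [torch.randn(128, device=device, requires_grad=True)]
        self.target = torch.zeros(128, device=device)

    def __call__(self) -> torch.Tensor:
        # Objective: scalar loss tensor for the optimizer to minimize.
        return (self.params[0] - self.target).square().mean()

    def has_succeeded(self) -> bool:
        # Win condition: loss has dropped below a predefined threshold.
        with torch.no_grad():
            return self().item() < 1e-4

    def get_params(self) -> list[torch.Tensor]:
        return self.params
```

Keeping the interface this small is what lets the harness in [`utils.py`](utils.py) treat every benchmark identically.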
109 | 110 | #### Adding a New Optimizer 111 | 112 | To add a new optimizer: 113 | 114 | 1. **Implement your optimizer class.** Your optimizer must follow the `torch.optim.Optimizer` interface. You can use [`optimizer_template.py`](optimizer_template.py:1) as a starting point. 115 | 2. **Make your optimizer available to the benchmark runner.** This is typically done by adding it to the `optimizer_mapping` dictionary in [`utils.py`](utils.py:1). 116 | 3. **Run the benchmarks.** You can now run your optimizer against the benchmarks using the `run_all_benchmarks.py` script. For example: 117 | ```bash 118 | python run_all_benchmarks.py --opt YourNewOptimizerName --difficulties easy 119 | -------------------------------------------------------------------------------- /benchmark/adversarial_gradient.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | 8 | from benchmark.utils import param_norm_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | 14 | configs = { 15 | "trivial": {"frequency": 1000}, 16 | "easy": {"frequency": 100}, 17 | "medium": {"frequency": 10}, 18 | "hard": {"frequency": 7}, 19 | "extreme": {"frequency": 4}, 20 | "nightmare": {"frequency": 2}, 21 | } 22 | 23 | 24 | class Model(nn.Module): 25 | def __init__(self, frequency, size=1024): 26 | super().__init__() 27 | self.param = nn.Parameter(torch.randn(size)) 28 | self.register_buffer("step", torch.zeros(1)) 29 | self.frequency = 1 / frequency * 1.1 # to avoid repeating 30 | 31 | def forward(self): 32 | """Test optimizer's robustness to adversarial gradient patterns.""" 33 | self.step += 1 34 | # Create an oscillating adversarial component 35 | direction = torch.sin(self.step * torch.pi * self.frequency) 36 | # Main objective plus adversarial component 37 | return self.param.square().mean() + direction * self.param.mean() 38 | 39 | 40 | @app.command() 41 | def main( 42 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 43 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 44 | steps: int = 100, 45 | weight_decay: float = 0, 46 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 47 | trials: int = 100, 48 | win_condition_multiplier: float = 1.0, 49 | config: Optional[str] = None, 50 | ): 51 | frequency = configs.get(config, {}).get("frequency", 10) 52 | dtype = [getattr(torch, d) for d in dtype] 53 | model = Model(frequency).cuda().double() 54 | 55 | def data(): 56 | return None, None 57 | 58 | # More lenient condition due to adversarial component 59 | trial( 60 | model, 61 | data, 62 | None, 63 | param_norm_win_condition(win_condition_multiplier * 1e-3, 0), 64 | steps, 65 | opt[0], 66 | dtype[0], 67 | 1, 68 | 1, 69 | weight_decay, 70 | method[0], 71 | 1, 72 | 1, 73 | failure_threshold=7, 74 | base_lr=1e-3, 75 | trials=trials, 76 | ) # More attempts for adversarial case 77 | 78 | 79 | if __name__ == "__main__": 80 | app() 81 | -------------------------------------------------------------------------------- /benchmark/batch_size_scaling.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | from typing import List, Optional 4 | 5 | import torch 6 | import torch.backends.opt_einsum 7 | import typer 8 | from torch import 
nn 9 | 10 | from benchmark.utils import param_norm_win_condition, trial 11 | from heavyball.utils import set_torch 12 | 13 | app = typer.Typer(pretty_exceptions_enable=False) 14 | set_torch() 15 | 16 | configs = { 17 | "trivial": {"max_batch": 65536}, 18 | "easy": {"max_batch": 8192}, 19 | "medium": {"max_batch": 1024}, 20 | "hard": {"max_batch": 128}, 21 | "extreme": {"max_batch": 16}, 22 | "nightmare": {"max_batch": 2}, 23 | } 24 | 25 | 26 | class Model(nn.Module): 27 | def __init__(self, max_batch: int, size=1024): 28 | super().__init__() 29 | self.param = nn.Parameter(torch.randn(size)) 30 | self.max_batch = max_batch 31 | self.rng = random.Random(0x1238192) 32 | 33 | def forward(self): 34 | """Test optimizer's ability to handle different batch sizes and noise scales.""" 35 | generator = torch.Generator(device=self.param.device).manual_seed(self.rng.randint(0, 2**31)) 36 | noise = torch.randn(self.param.shape, generator=generator, device=self.param.device) 37 | scale = self.param.norm() / (noise.norm() + 1e-6) 38 | batch_scale = self.max_batch ** (self.rng.random() / 2) # sqrt of random uniform between 1 and max_batch 39 | noise *= scale.detach() / math.sqrt(batch_scale) 40 | return (self.param + noise).square().mean() 41 | 42 | 43 | @app.command() 44 | def main( 45 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 46 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 47 | steps: int = 100, 48 | weight_decay: float = 0, 49 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 50 | trials: int = 100, 51 | win_condition_multiplier: float = 1.0, 52 | config: Optional[str] = None, 53 | ): 54 | max_batch = configs.get(config, {}).get("max_batch", 256) 55 | dtype = [getattr(torch, d) for d in dtype] 56 | model = Model(max_batch).cuda().double() 57 | 58 | def data(): 59 | return None, None 60 | 61 | # Use a more lenient win condition since we have inherent noise 62 | trial( 63 | model, 64 | data, 65 | None, 66 | param_norm_win_condition(win_condition_multiplier * 1e-8, 0), 67 | steps, 68 | opt[0], 69 | dtype[0], 70 | 1, 71 | 1, 72 | weight_decay, 73 | method[0], 74 | 1, 75 | 1, 76 | failure_threshold=5, 77 | base_lr=1e-3, 78 | trials=trials, 79 | ) 80 | 81 | 82 | if __name__ == "__main__": 83 | app() 84 | -------------------------------------------------------------------------------- /benchmark/beale.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import random 3 | from typing import List, Optional 4 | 5 | import matplotlib.colors 6 | import torch 7 | import torch.backends.opt_einsum 8 | import typer 9 | from torch import nn 10 | 11 | from benchmark.utils import Plotter, SkipConfig, loss_win_condition, trial 12 | from heavyball.utils import set_torch 13 | 14 | app = typer.Typer(pretty_exceptions_enable=False) 15 | set_torch() 16 | 17 | 18 | def objective(x, y): 19 | x = x + 3 20 | y = y + 0.5 21 | return (1.5 - x + x * y) ** 2 + (2.25 - x + x * y**2) ** 2 + (2.625 - x + x * y**3) ** 2 22 | 23 | 24 | class Model(nn.Module): 25 | def __init__(self, x): 26 | super().__init__() 27 | self.param = nn.Parameter(torch.tensor(x).float()) 28 | 29 | def forward(self): 30 | return objective(*self.param) 31 | 32 | 33 | @app.command() 34 | def main( 35 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 36 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 37 | steps: int = 100, 38 | weight_decay: float = 
0, 39 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 40 | show_image: bool = False, 41 | trials: int = 100, 42 | win_condition_multiplier: float = 1.0, 43 | config: Optional[str] = None, 44 | ): 45 | if config is not None and config != "trivial": 46 | raise SkipConfig("'config' must be 'trivial'.") 47 | dtype = [getattr(torch, d) for d in dtype] 48 | coords = (-7, -4) 49 | 50 | # Clean up old plots 51 | for path in pathlib.Path(".").glob("beale.png"): 52 | path.unlink() 53 | 54 | colors = list(matplotlib.colors.TABLEAU_COLORS.values()) 55 | rng = random.Random(0x1239121) 56 | rng.shuffle(colors) 57 | 58 | if show_image: 59 | model = Plotter(Model(coords), x_limits=(-8, 2), y_limits=(-8, 2), should_normalize=True) 60 | else: 61 | model = Model(coords) 62 | model.double() 63 | 64 | def data(): 65 | return None, None 66 | 67 | model = trial( 68 | model, 69 | data, 70 | None, 71 | loss_win_condition(win_condition_multiplier * 1e-8 * (not show_image)), 72 | steps, 73 | opt[0], 74 | dtype[0], 75 | 1, 76 | 1, 77 | weight_decay, 78 | method[0], 79 | 1, 80 | 1, 81 | base_lr=1e-4, 82 | trials=trials, 83 | return_best=show_image, 84 | ) 85 | 86 | if not show_image: 87 | return 88 | 89 | model.plot(save_path="beale.png") 90 | 91 | 92 | if __name__ == "__main__": 93 | app() 94 | -------------------------------------------------------------------------------- /benchmark/char_rnn.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List 3 | 4 | import torch 5 | import torch.backends.opt_einsum 6 | import torch.nn as nn 7 | import typer 8 | from torch.nn import functional as F 9 | 10 | from benchmark.utils import loss_win_condition, trial 11 | from heavyball.utils import set_torch 12 | 13 | app = typer.Typer(pretty_exceptions_enable=False) 14 | set_torch() 15 | 16 | 17 | class Take0(nn.Module): 18 | def forward(self, x): 19 | return x[0] 20 | 21 | 22 | class Model(nn.Module): 23 | def __init__(self, features: int, sequence: int): 24 | super().__init__() 25 | self.sequence = sequence 26 | self.net = nn.Sequential( 27 | nn.Embedding(256, features), 28 | nn.LSTM(features, features, 1, batch_first=True), # Removed dropout since num_layers=1 29 | Take0(), 30 | nn.Linear(features, 256), 31 | ) 32 | 33 | def forward(self, inp): 34 | return self.net(inp) 35 | 36 | 37 | @app.command() 38 | def main( 39 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 40 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 41 | features: int = 512, 42 | sequence: int = 256, 43 | batch: int = 16, 44 | steps: int = 100, 45 | weight_decay: float = 0, 46 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 47 | win_condition_multiplier: float = 1.0, 48 | trials: int = 10, 49 | ): 50 | dtype = [getattr(torch, d) for d in dtype] 51 | model = Model(features, sequence).cuda() 52 | 53 | # Load text data 54 | benchmark_dir = Path(__file__).parent 55 | with open(benchmark_dir / "shakespeare.txt", "rb") as f: 56 | text = f.read() 57 | chars = torch.frombuffer(text, dtype=torch.uint8).cuda().long() 58 | 59 | # Create holdout set 60 | chars = chars[(sequence + 1) * batch :] 61 | offsets = torch.arange(0, sequence + 1, device="cuda").repeat(batch, 1) 62 | 63 | def data(): 64 | batch_offsets = torch.randint(0, len(chars) - sequence - 1, (batch,), device="cuda") 65 | batch_offsets = batch_offsets[:, None] + offsets 66 | batch_chars = chars[batch_offsets] 
67 | batch_chars = batch_chars.view(batch, sequence + 1) 68 | src = batch_chars[:, :-1] 69 | tgt = batch_chars[:, 1:] 70 | return src, tgt 71 | 72 | trial( 73 | model, 74 | data, 75 | F.cross_entropy, 76 | loss_win_condition(win_condition_multiplier * 2.0), 77 | steps, 78 | opt[0], 79 | dtype[0], 80 | features, 81 | batch, 82 | weight_decay, 83 | method[0], 84 | sequence, 85 | 1, 86 | failure_threshold=10, 87 | base_lr=1e-3, 88 | trials=trials, 89 | ) 90 | 91 | 92 | if __name__ == "__main__": 93 | app() 94 | -------------------------------------------------------------------------------- /benchmark/constrained_optimization.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from typing import List, Optional 3 | 4 | import torch 5 | import torch.backends.opt_einsum 6 | import typer 7 | from torch import nn 8 | 9 | from benchmark.utils import trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | 15 | configs = { 16 | "trivial": {"penalty": 1e1}, 17 | "easy": {"penalty": 1e2}, 18 | "medium": {"penalty": 1e4}, 19 | "hard": {"size": 1e6}, 20 | "extreme": {"penalty": 1e8}, 21 | "nightmare": {"penalty": 1e10}, 22 | } 23 | 24 | # Objective: Minimize (x-2)^2 subject to x <= 1 25 | # Implemented using a penalty: (x-2)^2 + penalty * max(0, x - 1) 26 | TARGET_X = 1.0 27 | TOLERANCE = 1e-3 28 | 29 | 30 | def objective(x, penalty): 31 | """Objective function with a penalty for violating the constraint x <= 1.""" 32 | return (x - 2.0) ** 2 + penalty * torch.relu(x - TARGET_X) 33 | 34 | 35 | class Model(nn.Module): 36 | def __init__(self, penalty): 37 | super().__init__() 38 | # Using a tensor with requires_grad=True directly as the parameter 39 | self.param = nn.Parameter(torch.zeros((16,))) 40 | self.penalty = penalty 41 | 42 | def forward(self): 43 | return objective(self.param, self.penalty).mean() 44 | 45 | 46 | def win_condition(model, loss): 47 | with torch.no_grad(): 48 | success = ((model.param - TARGET_X).abs() < TOLERANCE).all().item() 49 | return success, {} 50 | 51 | 52 | @app.command() 53 | def main( 54 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 55 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 56 | steps: int = 200, 57 | # Increased steps slightly 58 | weight_decay: float = 0, 59 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 60 | trials: int = 50, # Reduced trials slightly for faster testing 61 | win_condition_multiplier: float = 1.0, # Not used directly, but kept for consistency 62 | config: Optional[str] = None, 63 | ): 64 | penalty = configs.get(config, {}).get("penalty", 1e6) 65 | dtype = [getattr(torch, d) for d in dtype] 66 | 67 | # Clean up old plots if any (though this benchmark doesn't plot) 68 | for path in pathlib.Path(".").glob("constrained_optimization*.png"): 69 | path.unlink() 70 | 71 | model = Model(penalty) 72 | model.double() # Use double for precision if needed 73 | 74 | # No external data needed for this simple objective 75 | def data(): 76 | return None, None 77 | 78 | # The loss is the objective value itself 79 | loss_fn = None 80 | 81 | trial( 82 | model, 83 | data, 84 | loss_fn, 85 | win_condition, 86 | steps, 87 | opt[0], 88 | dtype[0], 89 | 1, # size (not relevant here) 90 | 1, # batch (not relevant here) 91 | weight_decay, 92 | method[0], 93 | 1, # length (not relevant here) 94 | 1, # depth (not relevant here) 95 | 
failure_threshold=3, 96 | base_lr=1e-3, # Default base LR, hyperopt will search 97 | trials=trials, 98 | group=32, # Smaller group size might be better for simple problems 99 | ) 100 | 101 | 102 | if __name__ == "__main__": 103 | app() 104 | -------------------------------------------------------------------------------- /benchmark/create_hiker_analogy_diagram.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | 7 | def create_hiker_analogy_diagram(): 8 | """ 9 | Generates and saves a multi-panel diagram illustrating key optimization challenges. 10 | 11 | The diagram includes separate panels for: 12 | - A saddle point 13 | - A narrow ravine (representing ill-conditioning) 14 | - A sharp cliff (representing a discontinuous gradient) 15 | """ 16 | # --- Create the plot --- 17 | fig, axes = plt.subplots(1, 3, figsize=(18, 6), subplot_kw={"aspect": "equal"}) 18 | fig.suptitle( 19 | "Challenges in Optimization Landscapes: The Hiker's Analogy", 20 | fontsize=20, 21 | y=0.98, 22 | ) 23 | 24 | # --- Panel 1: Saddle Point --- 25 | ax1 = axes[0] 26 | x1 = np.linspace(-3, 3, 300) 27 | y1 = np.linspace(-3, 3, 300) 28 | X1, Y1 = np.meshgrid(x1, y1) 29 | Z1 = X1**3 + Y1**3 30 | levels1 = np.linspace(np.min(Z1), np.max(Z1), 30) 31 | ax1.contourf(X1, Y1, Z1, levels=levels1, cmap="viridis") 32 | ax1.contour(X1, Y1, Z1, levels=levels1, colors="black", linewidths=0.5, alpha=0.5) 33 | ax1.set_title("Saddle Point", fontsize=16) 34 | ax1.text(0, 0, "★", color="red", fontsize=20, ha="center", va="center") 35 | 36 | # --- Panel 2: Ravine (Ill-Conditioning) --- 37 | ax2 = axes[1] 38 | x2 = np.linspace(-3, 3, 300) 39 | y2 = np.linspace(-3, 3, 300) 40 | X2, Y2 = np.meshgrid(x2, y2) 41 | Z2 = 0.1 * X2**2 + 10 * Y2**2 # Rosenbrock-like narrow valley 42 | levels2 = np.linspace(np.min(Z2), np.max(Z2), 30) 43 | ax2.contourf(X2, Y2, Z2, levels=levels2, cmap="magma") 44 | ax2.contour(X2, Y2, Z2, levels=levels2, colors="black", linewidths=0.5, alpha=0.5) 45 | ax2.set_title("Ravine (Ill-Conditioning)", fontsize=16) 46 | ax2.text(0, 0, "★", color="cyan", fontsize=20, ha="center", va="center") 47 | 48 | # --- Panel 3: Cliff (Discontinuous Gradient) --- 49 | ax3 = axes[2] 50 | x3 = np.linspace(-3, 3, 300) 51 | y3 = np.linspace(-3, 3, 300) 52 | X3, Y3 = np.meshgrid(x3, y3) 53 | Z3 = -Y3 # A simple slope 54 | Z3[X3 > 0] += 3 # Create a sharp cliff 55 | levels3 = np.linspace(np.min(Z3), np.max(Z3), 30) 56 | ax3.contourf(X3, Y3, Z3, levels=levels3, cmap="plasma") 57 | ax3.contour(X3, Y3, Z3, levels=levels3, colors="black", linewidths=0.5, alpha=0.5) 58 | ax3.set_title("Cliff (Discontinuous Gradient)", fontsize=16) 59 | ax3.plot([0, 0], [-3, 3], "r--", lw=2) # Mark the cliff edge 60 | 61 | # --- Style all plots --- 62 | for ax in axes: 63 | ax.set_xticks([]) 64 | ax.set_yticks([]) 65 | ax.spines["top"].set_visible(False) 66 | ax.spines["right"].set_visible(False) 67 | ax.spines["bottom"].set_visible(False) 68 | ax.spines["left"].set_visible(False) 69 | 70 | fig.tight_layout(rect=[0, 0, 1, 0.95]) # Adjust layout for suptitle 71 | 72 | # --- Save the figure --- 73 | output_path = "docs/assets/hiker_analogy_diagram.png" 74 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 75 | plt.savefig(output_path, bbox_inches="tight", pad_inches=0.1, dpi=300) 76 | print(f"Multi-panel diagram saved to {output_path}") 77 | 78 | 79 | if __name__ == "__main__": 80 | create_hiker_analogy_diagram() 81 | 
-------------------------------------------------------------------------------- /benchmark/discontinuous_gradient.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | 8 | from benchmark.utils import SkipConfig, param_norm_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | 14 | 15 | def objective(x): 16 | """Tests optimizer robustness to non-smooth landscapes with discontinuous gradients.""" 17 | return torch.where(x < 0, x**2, 2 * x).mean() # Discontinuous gradient at x=0 18 | 19 | 20 | class Model(nn.Module): 21 | def __init__(self, size=1024): 22 | super().__init__() 23 | self.param = nn.Parameter(torch.randn(size)) 24 | 25 | def forward(self): 26 | return objective(self.param) 27 | 28 | 29 | @app.command() 30 | def main( 31 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 32 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 33 | steps: int = 100, 34 | weight_decay: float = 0, 35 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 36 | trials: int = 100, 37 | win_condition_multiplier: float = 1.0, 38 | config: Optional[str] = None, 39 | ): 40 | if config is not None and config != "trivial": 41 | raise SkipConfig("'config' must be 'trivial'.") 42 | dtype = [getattr(torch, d) for d in dtype] 43 | model = Model().cuda().double() 44 | 45 | def data(): 46 | return None, None 47 | 48 | trial( 49 | model, 50 | data, 51 | None, 52 | param_norm_win_condition(win_condition_multiplier * 1e-4, 0), 53 | steps, 54 | opt[0], 55 | dtype[0], 56 | 1, 57 | 1, 58 | weight_decay, 59 | method[0], 60 | 1, 61 | 1, 62 | failure_threshold=3, 63 | base_lr=1e-3, 64 | trials=trials, 65 | ) 66 | 67 | 68 | if __name__ == "__main__": 69 | app() 70 | -------------------------------------------------------------------------------- /benchmark/dynamic_landscape.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests optimizer's ability to track moving targets. 3 | 4 | This benchmark simulates a dynamic loss landscape where the optimal parameters 5 | continuously shift over time. This tests the optimizer's ability to: 6 | 1. Track moving targets 7 | 2. Adapt to non-stationary objectives 8 | 3. 
Handle continuous parameter updates 9 | """ 10 | 11 | import itertools 12 | import math 13 | from typing import List, Optional 14 | 15 | import torch 16 | import torch.nn as nn 17 | import typer 18 | 19 | from benchmark.utils import loss_win_condition, trial 20 | from heavyball.utils import set_torch 21 | 22 | app = typer.Typer(pretty_exceptions_enable=False) 23 | set_torch() 24 | configs = { 25 | "trivial": {"frequency": 1000}, 26 | "easy": {"frequency": 100}, 27 | "medium": {"frequency": 20}, 28 | "hard": {"frequency": 10}, 29 | "extreme": {"frequency": 6}, 30 | "nightmare": {"frequency": 4}, 31 | } 32 | 33 | 34 | class ShiftingSphere(nn.Module): 35 | def __init__(self, dim, frequency): 36 | super().__init__() 37 | self.param = nn.Parameter(torch.randn(dim)) 38 | self.phase = 0 39 | self.frequency = 1 / frequency * 1.1 # so that we don't repeat numbers 40 | 41 | def forward(self): 42 | self.phase += self.frequency 43 | target = torch.linspace(0, 2 * math.pi, len(self.param), device=self.param.device, dtype=self.param.dtype) 44 | target = torch.sin(target + self.phase) 45 | return (self.param - target).square().mean() 46 | 47 | 48 | @app.command() 49 | def main( 50 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 51 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 52 | dim: int = 16384, 53 | steps: int = 500, 54 | weight_decay: float = 0, 55 | opt: List[str] = typer.Option(["adamw"], help="Optimizers to use"), 56 | win_condition_multiplier: float = 1.0, 57 | trials: int = 3, 58 | config: Optional[str] = None, 59 | ): 60 | """Run dynamic landscape benchmark with specified parameters.""" 61 | frequency = configs.get(config, {}).get("frequency", 0.1) 62 | dtype = [getattr(torch, d) for d in dtype] 63 | 64 | for args in itertools.product(method, dtype, [dim], opt, [weight_decay]): 65 | m, d, dim, o, wd = args 66 | 67 | model = ShiftingSphere(dim, frequency) 68 | 69 | def data(): 70 | return None, None 71 | 72 | # Win condition: average squared error should be small (parameters close to target) 73 | trial( 74 | model, 75 | data, 76 | None, 77 | loss_win_condition(0.01 * win_condition_multiplier), 78 | steps, 79 | [o], 80 | [d], 81 | 1, 82 | 1, 83 | wd, 84 | m, 85 | 1, 86 | 1, 87 | base_lr=0.1, 88 | trials=trials, 89 | ) 90 | 91 | 92 | if __name__ == "__main__": 93 | app() 94 | -------------------------------------------------------------------------------- /benchmark/exploding_gradient.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests optimizer's ability to handle exploding gradients. 3 | 4 | This benchmark creates a scenario where gradients can grow exponentially, 5 | testing the optimizer's: 6 | 1. Gradient clipping/scaling mechanisms 7 | 2. Numerical stability 8 | 3. 
Ability to make progress despite extreme gradient values 9 | """ 10 | 11 | import itertools 12 | from typing import List, Optional 13 | 14 | import torch 15 | import torch.nn as nn 16 | import typer 17 | 18 | from benchmark.utils import param_norm_win_condition, trial 19 | from heavyball.utils import set_torch 20 | 21 | app = typer.Typer(pretty_exceptions_enable=False) 22 | set_torch() 23 | configs = { 24 | "trivial": {"scale": 1}, 25 | "easy": {"scale": 2}, 26 | "medium": {"scale": 4}, 27 | "hard": {"scale": 8}, 28 | "extreme": {"scale": 12}, 29 | "nightmare": {"scale": 16}, 30 | } 31 | 32 | 33 | class ExplodingGradient(nn.Module): 34 | def __init__(self, scale, dim): 35 | super().__init__() 36 | self.param = nn.Parameter(torch.randn(dim)) 37 | self.scale = scale # Controls how quickly gradients grow 38 | 39 | def forward(self): 40 | # Creates exponentially growing gradients 41 | # Gradient will be scale * exp(|param|) * sign(param) 42 | return torch.exp(self.scale * torch.abs(self.param)).mean() 43 | 44 | 45 | @app.command() 46 | def main( 47 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 48 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 49 | dim: int = 512, 50 | steps: int = 500, 51 | weight_decay: float = 0, 52 | opt: List[str] = typer.Option(["adamw"], help="Optimizers to use"), 53 | win_condition_multiplier: float = 1.0, 54 | trials: int = 3, 55 | config: Optional[str] = None, 56 | ): 57 | scale = configs.get(config, {}).get("scale", 2) 58 | """Run exploding gradient benchmark with specified parameters.""" 59 | dtype = [getattr(torch, d) for d in dtype] 60 | 61 | for args in itertools.product(method, dtype, [dim], opt, [weight_decay]): 62 | m, d, dim, o, wd = args 63 | 64 | model = ExplodingGradient(dim, scale) 65 | 66 | def data(): 67 | return None, None 68 | 69 | # Win condition: loss should be close to 1.0 (exp(0) = 1) 70 | # Using 1.1 as threshold since perfect convergence is hard 71 | trial( 72 | model, 73 | data, 74 | None, 75 | param_norm_win_condition(0.01 * win_condition_multiplier, 0), 76 | steps, 77 | [o], 78 | [d], 79 | 1, 80 | 1, 81 | wd, 82 | m, 83 | 1, 84 | 1, 85 | base_lr=0.001, # Lower learning rate due to large gradients 86 | trials=trials, 87 | ) 88 | 89 | 90 | if __name__ == "__main__": 91 | app() 92 | -------------------------------------------------------------------------------- /benchmark/format_results.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | import scipy 7 | import typer 8 | from matplotlib.cm import ScalarMappable 9 | from matplotlib.colors import Normalize 10 | from matplotlib.patches import Rectangle 11 | 12 | 13 | def parse_loss(loss_str): 14 | loss_str = loss_str.strip() 15 | if not loss_str or loss_str.lower() == "nan": 16 | return float("nan") 17 | if loss_str.lower() == "inf": 18 | return float("inf") 19 | try: 20 | return ( 21 | float(loss_str) 22 | if "e" not in loss_str 23 | else float(loss_str.split("e")[0]) * 10 ** float(loss_str.split("e")[1]) 24 | ) 25 | except ValueError: 26 | return float("nan") 27 | 28 | 29 | def process_str(x, truthy): 30 | if x == "No": 31 | return "" 32 | if x == "Yes": 33 | return f"{truthy}-" 34 | return f"{x}-" 35 | 36 | 37 | def read_benchmark_results(file_path): 38 | with open(file_path, "r") as f: 39 | content = f.read() 40 | details = re.search(r"## Details\n\n(.*?)(?=\n##|\n\Z)", content, re.DOTALL | 
re.IGNORECASE) 41 | if not details: 42 | raise ValueError("Details section not found.") 43 | table = details.group(1).strip() 44 | lines = re.search(r"\|:?-+:(.*?)\|\n(.*)", table, re.DOTALL).group(2).strip().split("\n") 45 | data = [] 46 | for i, line in enumerate(lines): 47 | if not line.strip() or line.startswith("|---"): 48 | continue 49 | parts = [p.strip() for p in line.split("|")[1:-1]] 50 | if len(parts) < 8: 51 | continue 52 | try: 53 | caution = process_str(parts[2], "cautious") 54 | mars = process_str(parts[3], "mars") 55 | optimizer = f"{caution}{mars}{parts[1]}" 56 | optimizer = optimizer.replace("Foreach", "").replace("Cached", "").strip() 57 | data.append({ 58 | "benchmark": parts[0], 59 | "optimizer": optimizer, 60 | "success": parts[4] == "✓", 61 | "runtime": float(parts[5].replace("s", "")) if parts[5] else float("nan"), 62 | "loss": parse_loss(parts[6]) if parts[6] else float("nan"), 63 | "attempts": int(parts[7]) if parts[7].isdigit() else 0, 64 | }) 65 | except (IndexError, ValueError): 66 | continue 67 | return pd.DataFrame(data) 68 | 69 | 70 | def create_success_matrix(df): 71 | if df.empty: 72 | return pd.DataFrame() 73 | benchmarks = sorted(df["benchmark"].unique()) 74 | optimizers = sorted(df["optimizer"].unique()) 75 | success_matrix = pd.DataFrame(0, index=benchmarks, columns=optimizers, dtype=int) 76 | for _, row in df.iterrows(): 77 | if row["success"] and row["benchmark"] in success_matrix.index and row["optimizer"] in success_matrix.columns: 78 | success_matrix.loc[row["benchmark"], row["optimizer"]] = 1 79 | base_tasks = sorted(set(b.split("-")[0] for b in benchmarks)) 80 | success_total_matrix = pd.DataFrame(0, index=base_tasks, columns=optimizers, dtype=int) 81 | for benchmark in success_matrix.index: 82 | base_task = benchmark.split("-")[0] 83 | if base_task in success_total_matrix.index: 84 | success_total_matrix.loc[base_task] += success_matrix.loc[benchmark] 85 | return success_total_matrix 86 | 87 | 88 | def normalize_matrix_by_row_max(matrix): 89 | max_in_row = matrix.max(axis=1) 90 | max_in_row[max_in_row == 0] = 1 91 | return matrix.div(max_in_row, axis=0) * 100 92 | 93 | 94 | def create_visual_matrix_normalized(success_total_matrix): 95 | if success_total_matrix.empty: 96 | return None 97 | tasks_to_keep = success_total_matrix.sum(axis=1) > 0 98 | filtered_matrix = success_total_matrix[tasks_to_keep].copy() 99 | if filtered_matrix.empty: 100 | return None 101 | filtered_matrix[:] = scipy.stats.rankdata(filtered_matrix, axis=1, method="dense") 102 | normalized_matrix = normalize_matrix_by_row_max(filtered_matrix) 103 | optimizer_means = normalized_matrix.mean(axis=0) 104 | task_means = normalized_matrix.mean(axis=1) 105 | overall_mean = optimizer_means.mean() 106 | plot_matrix: pd.DataFrame = normalized_matrix.copy() 107 | plot_matrix.loc["Avg. Optimizer"] = optimizer_means 108 | 109 | # weight of 0.5, as "jack of all trades, master of none" is better than "perfect at xor but awful at delay" 110 | optimizer_score = (normalized_matrix**0.5).mean(axis=0) 111 | optimizer_indices = np.argsort(-optimizer_score.to_numpy()) 112 | plot_matrix = plot_matrix.iloc[:, optimizer_indices] 113 | 114 | full_task_means = pd.concat([task_means, pd.Series([overall_mean], index=["Avg. Optimizer"])]) 115 | plot_matrix["Avg. 
Task"] = full_task_means 116 | plot_tasks = plot_matrix.index 117 | plot_optimizers = plot_matrix.columns 118 | 119 | plt.style.use("seaborn-v0_8-whitegrid") 120 | fig, ax = plt.subplots( 121 | figsize=(max(14, len(plot_optimizers) * 0.8), max(10, len(plot_tasks) * 0.6)), facecolor="white" 122 | ) 123 | cmap = plt.cm.Blues 124 | norm = Normalize(vmin=0, vmax=100) 125 | mapper = ScalarMappable(norm=norm, cmap=cmap) 126 | 127 | for i, task in enumerate(plot_tasks): 128 | for j, optimizer in enumerate(plot_optimizers): 129 | value = plot_matrix.loc[task, optimizer] 130 | original_count = 0 131 | if task != "Avg. Optimizer" and optimizer != "Avg. Task": 132 | if task in success_total_matrix.index and optimizer in success_total_matrix.columns: 133 | original_count = success_total_matrix.loc[task, optimizer] 134 | is_summary = task == "Avg. Optimizer" or optimizer == "Avg. Task" 135 | face_color = (0.85, 0.85, 0.85, 0.7) if not is_summary and original_count == 0 else mapper.to_rgba(value) 136 | edge_color = "#666666" if is_summary else "#AAAAAA" 137 | edge_width = 1.5 if is_summary else 0.5 138 | rect = Rectangle((j - 0.5, i - 0.5), 1, 1, facecolor=face_color, edgecolor=edge_color, linewidth=edge_width) 139 | ax.add_patch(rect) 140 | # brightness = sum(face_color[:3]) * 0.333 141 | # text_color = 'white' if brightness < 0.6 else 'black' 142 | # ax.text(j, i, f"{value:.0f}", ha='center', va='center', color=text_color, fontsize=12, fontweight='bold') 143 | 144 | ax.axhline(len(plot_tasks) - 1.5, color="#333333", linewidth=2, linestyle="-") 145 | ax.axvline(len(plot_optimizers) - 1.5, color="#333333", linewidth=2, linestyle="-") 146 | ax.set_xticks(np.arange(len(plot_optimizers))) 147 | ax.set_yticks(np.arange(len(plot_tasks))) 148 | ax.set_xticklabels(plot_optimizers, rotation=45, ha="right", fontsize=14, fontweight="bold", color="#333333") 149 | ax.set_yticklabels(plot_tasks, fontsize=14, fontweight="bold", color="#333333") 150 | ax.set_xlabel("Optimizer", fontsize=16, fontweight="bold", labelpad=15, color="#333333") 151 | ax.set_ylabel("Task", fontsize=16, fontweight="bold", labelpad=15, color="#333333") 152 | ax.set_title( 153 | "Success Rate by Task and Optimizer\n(100% = Best Performance for Each Task)", 154 | fontsize=18, 155 | fontweight="bold", 156 | pad=20, 157 | color="#333333", 158 | ) 159 | 160 | cbar = fig.colorbar(mapper, ax=ax, pad=0.02, aspect=30, shrink=0.8) 161 | cbar.set_label("Success Rate (%)", rotation=270, labelpad=20, fontsize=14, fontweight="bold", color="#333333") 162 | cbar.ax.tick_params(labelsize=12, color="#333333", labelcolor="#333333") 163 | 164 | ax.set_facecolor("#F5F5F5") 165 | fig.patch.set_facecolor("white") 166 | for spine in ax.spines.values(): 167 | spine.set_visible(True) 168 | spine.set_color("#333333") 169 | spine.set_linewidth(1.5) 170 | 171 | plt.tight_layout(rect=[0.02, 0.05, 0.98, 0.95]) 172 | return fig 173 | 174 | 175 | app = typer.Typer(pretty_exceptions_enable=False) 176 | 177 | 178 | @app.command() 179 | def main(file: str = typer.Argument("benchmark_results.md")): 180 | try: 181 | df = read_benchmark_results(file) 182 | if df.empty: 183 | print("No data loaded from benchmark file.") 184 | return 185 | except Exception as e: 186 | print(f"Error reading benchmark file: {e}") 187 | return 188 | success_total_matrix = create_success_matrix(df) 189 | if success_total_matrix.empty: 190 | print("No successful runs found.") 191 | return 192 | fig = create_visual_matrix_normalized(success_total_matrix) 193 | if fig is None: 194 | print("Plot generation 
failed.") 195 | return 196 | plt.savefig("benchmark_matrix.png", dpi=300, bbox_inches="tight", facecolor="white", pad_inches=0.3) 197 | print("Saved heatmap to: benchmark_heatmap.png") 198 | plt.close(fig) 199 | 200 | 201 | if __name__ == "__main__": 202 | app() 203 | -------------------------------------------------------------------------------- /benchmark/generate_animations.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import math 3 | import pathlib 4 | from copy import deepcopy 5 | from typing import List 6 | 7 | import matplotlib.animation as animation 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import torch 11 | import typer 12 | from utils import get_optim as get_optimizer 13 | 14 | app = typer.Typer(pretty_exceptions_enable=False) 15 | 16 | 17 | @app.command() 18 | def main( 19 | benchmark_name: str = typer.Option(..., help="Name of the benchmark to run (e.g., 'saddle_point')."), 20 | optimizer_names: List[str] = typer.Option( 21 | ..., "--optimizer-name", help="Name of an optimizer to include in the comparison." 22 | ), 23 | output_file: str = typer.Option(..., help="Path to save the generated GIF."), 24 | steps: int = 100, 25 | ): 26 | """ 27 | Generates an animated GIF of optimizer paths on a given benchmark. 28 | """ 29 | try: 30 | benchmark_module = importlib.import_module(f"benchmark.{benchmark_name}") 31 | except ImportError: 32 | print(f"Error: Benchmark '{benchmark_name}' not found.") 33 | raise typer.Exit(1) 34 | 35 | if benchmark_name == "saddle_point": 36 | model = benchmark_module.Model(1, 0) 37 | 38 | def objective_fn(*xs): 39 | return benchmark_module.objective(*xs, power=model.power) 40 | 41 | x_limits, y_limits = (-2, 2), (-2, 2) 42 | elif benchmark_name == "rosenbrock": 43 | coords = (-7, 4) 44 | model = benchmark_module.Model(coords) 45 | objective_fn = benchmark_module.objective 46 | x_limits, y_limits = (-10, 10), (-10, 10) 47 | elif benchmark_name == "beale": 48 | coords = (-7, -4) 49 | model = benchmark_module.Model(coords) 50 | objective_fn = benchmark_module.objective 51 | x_limits, y_limits = (-8, 2), (-8, 2) 52 | else: 53 | print(f"Error: Benchmark '{benchmark_name}' is not supported for animation yet.") 54 | raise typer.Exit(1) 55 | 56 | trajectories = [] 57 | models = [] 58 | for optimizer_name in optimizer_names: 59 | print(f"Tuning LR for {optimizer_name}...") 60 | best_lr = None 61 | min_loss = float("inf") 62 | tuning_steps = math.ceil(steps / 3) 63 | lr_candidates = np.logspace(-8, 0, 50) 64 | 65 | for test_lr in lr_candidates: 66 | temp_model = deepcopy(model) 67 | temp_optimizer = get_optimizer(optimizer_name, temp_model.parameters(), lr=test_lr) 68 | 69 | def _closure(): 70 | loss = temp_model() 71 | loss.backward() 72 | 73 | for _ in range(tuning_steps): 74 | temp_optimizer.zero_grad() 75 | temp_optimizer.step(_closure) 76 | 77 | final_loss = temp_model().item() 78 | if final_loss < min_loss: 79 | min_loss = final_loss 80 | best_lr = test_lr 81 | 82 | print(f" > Best LR for {optimizer_name}: {best_lr:.5f}") 83 | 84 | m = deepcopy(model) 85 | models.append(m) 86 | optimizer = get_optimizer(optimizer_name, m.parameters(), lr=best_lr) 87 | 88 | trajectory = [m.param.detach().clone()] 89 | 90 | def _closure(): 91 | loss = m() 92 | loss.backward() 93 | 94 | for _ in range(steps): 95 | optimizer.zero_grad() 96 | optimizer.step(_closure) 97 | trajectory.append(m.param.detach().clone()) 98 | 99 | trajectories.append(list(torch.stack(trajectory).cpu().numpy())) 100 | 
print(f" > Final position for {optimizer_name}: {trajectories[-1][-1]}") 101 | 102 | target_trajectory = None 103 | if benchmark_name == "dynamic_targets": 104 | target_trajectory = models[0].get_target_trajectory() 105 | 106 | paths_xy = [] 107 | for traj in trajectories: 108 | path_array = np.array(traj) 109 | paths_xy.append((path_array[:, 0], path_array[:, 1])) 110 | 111 | target_path_xy = None 112 | if target_trajectory: 113 | target_array = np.array(target_trajectory) 114 | target_path_xy = (target_array[:, 0], target_array[:, 1]) 115 | 116 | fig, ax = plt.subplots(figsize=(8, 6)) 117 | ax.set_xlim(x_limits) 118 | ax.set_ylim(y_limits) 119 | ax.set_xlabel("x") 120 | ax.set_ylabel("y") 121 | ax.set_title(f"{benchmark_name.replace('_', ' ').title()} with {', '.join(optimizer_names)}") 122 | 123 | if objective_fn: 124 | x = torch.linspace(x_limits[0], x_limits[1], 100) 125 | y = torch.linspace(y_limits[0], y_limits[1], 100) 126 | X, Y = torch.meshgrid(x, y, indexing="ij") 127 | Z = objective_fn(X, Y) 128 | Z -= Z.min() 129 | Z = Z.log() 130 | ax.contourf(X.numpy(), Y.numpy(), Z.numpy(), levels=50, cmap="viridis") 131 | ax.contour(X.numpy(), Y.numpy(), Z.numpy(), levels=50, colors="k", linewidths=1) 132 | 133 | lines = [ax.plot([], [], "x-", label=name)[0] for name in optimizer_names] 134 | ax.legend() 135 | 136 | target_dot = None 137 | if benchmark_name == "dynamic_targets": 138 | (target_dot,) = ax.plot([], [], "ro", label="Target") 139 | ax.legend() 140 | 141 | def init(): 142 | for line in lines: 143 | line.set_data([], []) 144 | if target_dot: 145 | target_dot.set_data([], []) 146 | return tuple(lines) + (target_dot,) 147 | return tuple(lines) 148 | 149 | frames = min(100, steps + 1) 150 | 151 | def animate(i): 152 | i = math.ceil(i / frames * steps) 153 | for j, line in enumerate(lines): 154 | x_data, y_data = paths_xy[j] 155 | line.set_data(x_data[: i + 1], y_data[: i + 1]) 156 | 157 | if target_path_xy: 158 | target_x, target_y = target_path_xy 159 | if i > 0: 160 | target_dot.set_data([target_x[i - 1]], [target_y[i - 1]]) 161 | else: 162 | target_dot.set_data([], []) 163 | return tuple(lines) + (target_dot,) 164 | 165 | return tuple(lines) 166 | 167 | ani = animation.FuncAnimation(fig, animate, init_func=init, frames=frames, interval=100, blit=True) 168 | 169 | output_path = pathlib.Path(output_file) 170 | output_path.parent.mkdir(parents=True, exist_ok=True) 171 | ani.save(output_path, writer="ffmpeg", fps=10) 172 | print(f"Animation saved to {output_file}") 173 | 174 | 175 | if __name__ == "__main__": 176 | app() 177 | -------------------------------------------------------------------------------- /benchmark/gradient_delay.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from typing import List, Optional 3 | 4 | import torch 5 | import torch.backends.opt_einsum 6 | import typer 7 | from torch import nn 8 | 9 | from benchmark.utils import loss_win_condition, trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | 15 | configs = { 16 | "trivial": {"max_delay": 2}, 17 | "easy": {"max_delay": 4}, 18 | "medium": {"max_delay": 16}, 19 | "hard": {"max_delay": 64}, 20 | "extreme": {"max_delay": 128}, 21 | "nightmare": {"max_delay": 256}, 22 | } 23 | 24 | 25 | class Model(nn.Module): 26 | def __init__(self, max_delay=16, param_size=256): 27 | super().__init__() 28 | self.params = nn.ParameterList([nn.Parameter(torch.randn(param_size)) 
for _ in range(max_delay)]) 29 | # Different update frequencies for each parameter 30 | self.delays = [i for i in range(max_delay)] 31 | self.step = 0 32 | self.grad_queues = [deque(maxlen=i + 1) for i in self.delays] 33 | 34 | def forward(self): 35 | """Test optimizer's ability to handle delayed gradients.""" 36 | total_loss = 0 37 | self.step += 1 38 | 39 | for param, delay, queue in zip(self.params, self.delays, self.grad_queues): 40 | # Current loss for this parameter 41 | loss = param.square().mean() 42 | 43 | # Store the gradient in the queue 44 | queue.append(loss) 45 | 46 | # Only add to total loss when we have enough history 47 | if len(queue) == queue.maxlen and self.step % (delay + 1) == 0: 48 | total_loss = total_loss + queue[0] # Use oldest gradient 49 | 50 | return total_loss / len(self.params) 51 | 52 | 53 | @app.command() 54 | def main( 55 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 56 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 57 | steps: int = 100, 58 | weight_decay: float = 0, 59 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 60 | trials: int = 100, 61 | win_condition_multiplier: float = 1.0, 62 | config: Optional[str] = None, 63 | ): 64 | max_delay = configs.get(config, {}).get("max_delay", 4) 65 | dtype = [getattr(torch, d) for d in dtype] 66 | model = Model(max_delay).cuda().double() 67 | 68 | def data(): 69 | return None, None 70 | 71 | # More lenient win condition and more steps due to delayed updates 72 | trial( 73 | model, 74 | data, 75 | None, 76 | loss_win_condition(win_condition_multiplier * 1e-4), 77 | steps * 2, 78 | opt[0], 79 | dtype[0], 80 | 1, 81 | 1, 82 | weight_decay, 83 | method[0], 84 | 1, 85 | 1, 86 | failure_threshold=5, 87 | base_lr=1e-3, 88 | trials=trials, 89 | ) # Double steps, more attempts 90 | 91 | 92 | if __name__ == "__main__": 93 | app() 94 | -------------------------------------------------------------------------------- /benchmark/gradient_noise_scale.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | 8 | from benchmark.utils import param_norm_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | 14 | configs = { 15 | "trivial": {"offset": 32}, 16 | "easy": {"offset": 16}, 17 | "medium": {"offset": 8}, 18 | "hard": {"offset": 4}, 19 | "extreme": {"offset": 2}, 20 | "nightmare": {"offset": 1}, 21 | } 22 | 23 | 24 | class Model(nn.Module): 25 | def __init__(self, offset, size=4096): 26 | super().__init__() 27 | self.param = nn.Parameter(torch.randn(size)) 28 | self.register_buffer("step", torch.zeros(1)) 29 | self.offset = offset 30 | 31 | def forward(self): 32 | """Test optimizer's ability to handle changing noise levels during training.""" 33 | self.step += 1 34 | # Noise that decreases over time 35 | noise_scale = 1.0 / (self.offset + self.step) 36 | noise = torch.randn_like(self.param) * noise_scale 37 | return (self.param + noise).square().mean() 38 | 39 | 40 | @app.command() 41 | def main( 42 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 43 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 44 | steps: int = 100, 45 | weight_decay: float = 0, 46 | opt: List[str] = typer.Option(["ForeachSOAP"], 
help="Optimizers to use"), 47 | trials: int = 100, 48 | win_condition_multiplier: float = 1.0, 49 | config: Optional[str] = None, 50 | ): 51 | offset = configs.get(config, {}).get("offset", 4) 52 | dtype = [getattr(torch, d) for d in dtype] 53 | model = Model(offset).cuda().double() 54 | 55 | def data(): 56 | return None, None 57 | 58 | # Lenient initial condition due to high initial noise 59 | trial( 60 | model, 61 | data, 62 | None, 63 | param_norm_win_condition(win_condition_multiplier * 1e-3, 0), 64 | steps, 65 | opt[0], 66 | dtype[0], 67 | 1, 68 | 1, 69 | weight_decay, 70 | method[0], 71 | 1, 72 | 1, 73 | failure_threshold=5, 74 | base_lr=1e-3, 75 | trials=trials, 76 | ) 77 | 78 | 79 | if __name__ == "__main__": 80 | app() 81 | -------------------------------------------------------------------------------- /benchmark/grokking.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import itertools 3 | import random 4 | from collections import defaultdict 5 | from pathlib import Path 6 | from typing import List 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import torch 11 | import torch.backends.opt_einsum 12 | import torch.nn as nn 13 | import typer 14 | from torch.utils.data import DataLoader 15 | 16 | import heavyball 17 | from benchmark.utils import get_optim 18 | from heavyball.utils import set_torch 19 | 20 | app = typer.Typer(pretty_exceptions_enable=False) 21 | set_torch() 22 | 23 | 24 | class ModularMLP(nn.Module): 25 | def __init__(self, numbers, p, hidden_dim): 26 | super().__init__() 27 | self.net = nn.Sequential( 28 | nn.Embedding(p, hidden_dim), 29 | nn.Flatten(), 30 | nn.Linear(numbers * hidden_dim, hidden_dim), 31 | nn.LayerNorm(hidden_dim), 32 | nn.LeakyReLU(), 33 | nn.Linear(hidden_dim, hidden_dim), 34 | nn.LayerNorm(hidden_dim), 35 | nn.LeakyReLU(), 36 | nn.Linear(hidden_dim, p), 37 | ) 38 | 39 | def forward(self, x): 40 | return self.net(x) 41 | 42 | 43 | class ModuloDataset(torch.utils.data.Dataset): 44 | def __init__(self, p, numbers, min_idx, length, batch_size): 45 | length = length // batch_size 46 | self.p = p 47 | self.numbers = numbers 48 | self.n_samples = length 49 | self.min_idx = min_idx 50 | self.max_idx = min_idx + length 51 | self.batch_size = batch_size 52 | 53 | def __len__(self): 54 | return self.n_samples 55 | 56 | def __getitem__(self, idx): 57 | generator = torch.Generator() 58 | generator.manual_seed(random.Random(min(idx + self.min_idx, self.max_idx)).randint(0, 2**32)) 59 | x = torch.randint(0, self.p, (self.batch_size, self.numbers), generator=generator) 60 | y = (x.sum(dim=-1) % self.p).long() 61 | return x, y 62 | 63 | 64 | def evaluate(model, loader, device): 65 | """Evaluate model accuracy""" 66 | model.eval() 67 | correct = total = 0 68 | with torch.no_grad(): 69 | for x, y in loader: 70 | x, y = x.to(device), y.to(device) 71 | out = model(x) 72 | pred = out.argmax(dim=1) 73 | correct += (pred == y).sum().detach() 74 | total += y.size(0) 75 | return correct / total 76 | 77 | 78 | def plot_results(train_losses, test_accs, steps_to_grok=None, save_path=None): 79 | """Plot training curves""" 80 | _fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True) 81 | 82 | # Plot training loss 83 | ax1.plot(train_losses, label="Training Loss") 84 | ax1.set_yscale("log") 85 | ax1.set_ylabel("Loss") 86 | ax1.set_title("Training Loss Over Time") 87 | ax1.grid(True) 88 | 89 | # Plot test accuracy 90 | eval_steps = np.arange(0, len(train_losses), len(train_losses) // 
len(test_accs)) 91 | ax2.plot(eval_steps, test_accs, label="Test Accuracy", color="orange") 92 | ax2.axhline(y=0.9, color="r", linestyle="--", label="Grokking Threshold") 93 | ax2.set_ylabel("Accuracy") 94 | ax2.set_xlabel("Steps") 95 | ax2.set_title("Test Accuracy Over Time") 96 | ax2.grid(True) 97 | 98 | if steps_to_grok is not None: 99 | ax2.axvline(x=steps_to_grok, color="g", linestyle="--", label=f"Grokking Step ({steps_to_grok})") 100 | 101 | ax1.legend() 102 | ax2.legend() 103 | plt.tight_layout() 104 | if save_path: 105 | plt.savefig(save_path) 106 | plt.close() 107 | 108 | 109 | @app.command() 110 | def main( 111 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 112 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 113 | opt: List[str] = typer.Option( 114 | ["ForeachSOAP", "PaLMForeachSOAP", "PrecondScheduleForeachSOAP"], help="Optimizers to use" 115 | ), 116 | steps: int = 100, 117 | batch_size: int = 32, 118 | hidden_dim: int = 32, 119 | p: int = 257, 120 | numbers: int = 4, 121 | weight_decay: float = 0, 122 | lr: float = 1e-4, 123 | train_percent: float = 0.1, 124 | eval_samples: int = 1024, 125 | printervall: int = 1000, 126 | ): 127 | dtype = [getattr(torch, d) for d in dtype] 128 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 129 | 130 | # Clean up old plots 131 | plot_dir = Path(".") 132 | for path in plot_dir.glob("grokking_*.png"): 133 | path.unlink() 134 | 135 | # Pre-generate datasets 136 | unique_samples = p**numbers 137 | train_data = ModuloDataset(p, numbers, 0, int(unique_samples * train_percent), batch_size) 138 | test_data = ModuloDataset(p, numbers, train_data.max_idx, eval_samples, eval_samples) 139 | 140 | print(f"Training on {train_data.n_samples * batch_size:,} samples - {train_percent * 100}%") 141 | print(f"Testing on {eval_samples:,} samples") 142 | 143 | train_loader = DataLoader( 144 | train_data, 145 | collate_fn=lambda x: x[0], 146 | batch_size=1, 147 | shuffle=False, 148 | pin_memory=True, 149 | num_workers=4, 150 | drop_last=True, 151 | prefetch_factor=16, 152 | persistent_workers=True, 153 | ) 154 | 155 | test_loader = DataLoader( 156 | test_data, 157 | collate_fn=lambda x: x[0], 158 | batch_size=1, 159 | shuffle=False, 160 | pin_memory=True, 161 | num_workers=4, 162 | drop_last=True, 163 | prefetch_factor=32, 164 | ) 165 | test_loader = list(test_loader) 166 | test_loader = [[x.pin_memory() for x in i] for i in test_loader] 167 | 168 | train_iter = iter(train_loader) 169 | history = defaultdict(list) 170 | 171 | def data(): 172 | """Get next batch from the dataloader""" 173 | nonlocal train_iter 174 | try: 175 | x, y = next(train_iter) 176 | except (StopIteration, NameError): 177 | train_iter = iter(train_loader) 178 | x, y = next(train_iter) 179 | return x.to(device), y.to(device) 180 | 181 | criterion = nn.CrossEntropyLoss() 182 | 183 | def win_condition(model, loss_hist): 184 | """Check if model has achieved grokking""" 185 | if not isinstance(loss_hist, float): 186 | loss = loss_hist 187 | else: 188 | loss = loss_hist 189 | 190 | history["loss"].append(loss) 191 | 192 | if loss > 0.1: # Not converged yet 193 | return False, {} 194 | 195 | # If loss is low, check test accuracy 196 | acc = evaluate(model, test_loader, device) 197 | history["test_acc"].append(acc) 198 | return acc > 0.9, {"test_acc": acc} 199 | 200 | global_model = ModularMLP(numbers, p, hidden_dim).to(device) 201 | global_model = torch.compile(global_model, mode="max-autotune-no-cudagraphs") 
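# The compiled model is deep-copied inside the loop below, so every (dtype, optimizer)
# combination starts from identical initial weights and the runs stay comparable.
# Illustrative invocation (flag names follow the typer options defined above; run it so
# that the `benchmark` and `heavyball` imports resolve, e.g. from the repository root):
#   python -m benchmark.grokking --opt ForeachSOAP --steps 10000 --lr 1e-4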
202 | for d, o in itertools.product(dtype, opt): 203 | print(f"\nRunning {o} with {d}") 204 | model = copy.deepcopy(global_model) 205 | model.to(dtype=d) 206 | 207 | history.clear() 208 | 209 | # Get optimizer class 210 | optimizer_class = getattr(heavyball, o) 211 | optimizer = get_optim(optimizer_class, model.parameters(), lr=lr, weight_decay=weight_decay) 212 | 213 | loss_hist = torch.empty(steps) 214 | 215 | # Training loop 216 | for step in range(steps): 217 | model.train() 218 | x, y = data() 219 | 220 | optimizer.zero_grad() 221 | out = model(x) 222 | loss = criterion(out, y) 223 | loss.backward() 224 | optimizer.step() 225 | 226 | with torch.no_grad(): 227 | loss_hist[step] = loss.detach() 228 | 229 | if step % printervall == 0: 230 | lh = loss_hist[:step][-printervall:].mean().item() 231 | acc = evaluate(model, test_loader, device).item() 232 | history["test_acc"].append(acc) 233 | print(f"Step {step}: Loss = {lh:.4f}, Test Acc = {acc:.4f}") 234 | 235 | # Plot results 236 | plot_name = plot_dir / f"grokking_{o}_{d}_lr{lr}_h{hidden_dim}_p{p}.png" 237 | plot_results( 238 | loss_hist.cpu().numpy(), 239 | history["test_acc"], 240 | next((i for i, acc in enumerate(history["test_acc"]) if acc > 0.9), None), 241 | plot_name, 242 | ) 243 | print(f"Training curves saved to {plot_name}") 244 | 245 | 246 | if __name__ == "__main__": 247 | app() 248 | -------------------------------------------------------------------------------- /benchmark/layer_wise_scale.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | 8 | from benchmark.utils import loss_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | 14 | configs = { 15 | "trivial": {"scale": 2}, 16 | "easy": {"scale": 1e1}, 17 | "medium": {"scale": 1e3}, 18 | "hard": {"size": 1e5}, 19 | "extreme": {"scale": 1e7}, 20 | "nightmare": {"scale": 1e9}, 21 | } 22 | 23 | 24 | class Model(nn.Module): 25 | def __init__(self, scale: float, size=1024): 26 | super().__init__() 27 | # Simulate different layer scales in deep networks 28 | self.layer1 = nn.Parameter(torch.randn(size)) # Small gradients 29 | self.layer2 = nn.Parameter(torch.randn(size)) # Medium gradients 30 | self.layer3 = nn.Parameter(torch.randn(size)) # Large gradients 31 | self.scale = scale 32 | 33 | def forward(self): 34 | """Test optimizer's ability to handle different gradient scales across layers.""" 35 | # Each layer contributes equally to the loss but has very different scales 36 | return ( 37 | self.layer1.square().mean() * self.scale 38 | + self.layer2.square().mean() 39 | + self.layer3.square().mean() / self.scale 40 | ) / 3 41 | 42 | 43 | @app.command() 44 | def main( 45 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 46 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 47 | steps: int = 100, 48 | weight_decay: float = 0, 49 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 50 | trials: int = 100, 51 | win_condition_multiplier: float = 1.0, 52 | config: Optional[str] = None, 53 | ): 54 | scale = configs.get(config, {}).get("scale", 1e3) 55 | dtype = [getattr(torch, d) for d in dtype] 56 | model = Model(scale).cuda().double() 57 | 58 | def data(): 59 | return None, None 60 | 61 | # More lenient win condition due to vastly 
different scales 62 | trial( 63 | model, 64 | data, 65 | None, 66 | loss_win_condition(win_condition_multiplier * 1e-4), 67 | steps, 68 | opt[0], 69 | dtype[0], 70 | 1, 71 | 1, 72 | weight_decay, 73 | method[0], 74 | 1, 75 | 1, 76 | failure_threshold=5, 77 | base_lr=1e-4, 78 | trials=trials, 79 | ) # Lower learning rate and more attempts 80 | 81 | 82 | if __name__ == "__main__": 83 | app() 84 | -------------------------------------------------------------------------------- /benchmark/loss_contour.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import tqdm 9 | 10 | import heavyball 11 | 12 | device = "cuda" 13 | heavyball.utils.compile_mode = None 14 | heavyball.utils.dynamic = True 15 | heavyball.utils.set_torch() 16 | 17 | 18 | class Sine(nn.Module): 19 | def forward(self, x): 20 | return torch.sin(x) 21 | 22 | 23 | class Residual(nn.Module): 24 | def __init__(self, wrapped): 25 | super(Residual, self).__init__() 26 | self.wrapped = wrapped 27 | 28 | def forward(self, x): 29 | return x + self.wrapped(x) 30 | 31 | 32 | class DatasetNorm(nn.Module): 33 | def __init__(self, features: int, momentum: float = 0.99): 34 | super().__init__() 35 | self.weight = nn.Parameter(torch.stack([torch.ones(features), torch.zeros(features)], 1)) 36 | self.register_buffer("stats", torch.zeros(features * 2)) 37 | self.register_buffer("step", torch.tensor(0)) 38 | self.momentum = momentum 39 | 40 | def forward(self, x): 41 | if True: 42 | with torch.no_grad(): 43 | mean, sq_mean = x.mean(dim=0), (x**2).mean(dim=0) 44 | stats = torch.cat([mean, sq_mean]) 45 | self.step.add_(1) 46 | self.stats.lerp_(stats, 1 - heavyball.utils.beta_debias(self.momentum, self.step)) 47 | # self.stats.lerp_(stats, self.step == 1) 48 | mean, sq_mean = self.stats.chunk(2) 49 | std = (sq_mean - mean**2).clamp_min_(1e-6).sqrt() 50 | else: 51 | std, mean = 1, 0 52 | weight, bias = self.weight.unbind(1) 53 | return (x - mean) / std * weight + bias 54 | 55 | 56 | class MLP(nn.Module): 57 | def __init__(self, in_shape, out_shape, width, depth, act=Sine(), expanded: int = 256): 58 | super(MLP, self).__init__() 59 | layers = [] 60 | layers.append(nn.Linear(in_shape, width)) 61 | 62 | for _ in range(depth - 1): 63 | layers.append( 64 | Residual( 65 | nn.Sequential( 66 | nn.Linear(width, expanded), # 67 | act, # 68 | DatasetNorm(expanded), # 69 | nn.Linear(expanded, width), 70 | ) 71 | ) 72 | ) 73 | layers.append(DatasetNorm(width)) 74 | layers.append(nn.Linear(width, out_shape)) 75 | self.model = nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | return self.model(x) 79 | 80 | 81 | def generate_two_moons_torch(n_samples=1000, noise=0.1, random_state=None): 82 | if random_state is not None: 83 | torch.manual_seed(random_state) 84 | 85 | half_samples = n_samples // 2 86 | 87 | theta1 = torch.linspace(0, np.pi, half_samples, device=device) 88 | theta2 = torch.linspace(0, np.pi, half_samples, device=device) 89 | 90 | X1 = torch.stack([torch.cos(theta1), torch.sin(theta1)], dim=1) 91 | X2 = torch.stack([1 - torch.cos(theta2), 1 - torch.sin(theta2) - 0.5], dim=1) 92 | 93 | X = torch.cat([X1, X2], dim=0) 94 | y = torch.cat([torch.zeros(half_samples, device=device), torch.ones(half_samples, device=device)], dim=0) 95 | 96 | X += noise * torch.randn(n_samples, 2, device=device) 97 | 98 | indices = torch.randperm(n_samples, device=device) 99 | X = X[indices] 100 | y = 
y[indices] 101 | 102 | return X, y 103 | 104 | 105 | def train_and_generate_frames( 106 | model, 107 | X_train, 108 | y_train, 109 | domain, 110 | epochs, 111 | lr, 112 | filename="training_video", 113 | resolution: int = 128, 114 | subsample: int = 1, 115 | train_samples: int = 1024, 116 | ): 117 | X_train = X_train.to(device).float() 118 | y_train = y_train.view(-1, 1).to(device).float() 119 | 120 | optimizers = { 121 | "ForeachSOAP": heavyball.ForeachSOAP(model.parameters(), lr=lr), 122 | "PaLMForeachSOAP": heavyball.PaLMForeachSOAP(model.parameters(), lr=lr), 123 | "PrecondScheduleForeachSOAP": heavyball.PrecondScheduleForeachSOAP(model.parameters(), lr=lr), 124 | } 125 | criterion = nn.BCEWithLogitsLoss() 126 | 127 | xx, yy = torch.meshgrid( 128 | torch.linspace(domain[0][0], domain[1][0], resolution, device=device), 129 | torch.linspace(domain[0][1], domain[1][1], resolution, device=device), 130 | indexing="xy", 131 | ) 132 | grid_points = torch.stack((xx.ravel(), yy.ravel()), dim=1).float() 133 | 134 | base_model = copy.deepcopy(model) 135 | 136 | for optimizer_name, optimizer in optimizers.items(): 137 | model = copy.deepcopy(base_model) 138 | print(f"\nTraining with {optimizer_name}") 139 | model.train() 140 | 141 | os.makedirs("frames", exist_ok=True) 142 | 143 | for epoch in tqdm.tqdm(range(epochs)): 144 | outputs = model(X_train) 145 | loss = criterion(outputs, y_train) 146 | 147 | optimizer.zero_grad() 148 | loss.backward() 149 | optimizer.step() 150 | 151 | if epoch % subsample == 0: 152 | model.eval() 153 | with torch.no_grad(): 154 | Z = model(grid_points).reshape(resolution, resolution) 155 | plt.figure(figsize=(10, 8)) 156 | plt.contourf(xx.cpu(), yy.cpu(), Z.cpu(), levels=20) 157 | plt.colorbar(label="Model Output") 158 | plt.scatter(X_train[:, 0].cpu(), X_train[:, 1].cpu(), c=y_train.cpu(), cmap="coolwarm") 159 | plt.title(f"{optimizer_name} - Epoch {epoch}, Loss: {loss.item():.4f}") 160 | plt.savefig(f"frames/{optimizer_name}_epoch_{epoch:05d}.png") 161 | plt.close() 162 | model.train() 163 | 164 | 165 | if __name__ == "__main__": 166 | X, y = generate_two_moons_torch(n_samples=1024, noise=0.05, random_state=42) 167 | 168 | domain = np.array([ 169 | [X[:, 0].min().item() - 1, X[:, 1].min().item() - 1], 170 | [X[:, 0].max().item() + 1, X[:, 1].max().item() + 1], 171 | ]) 172 | 173 | model = torch.compile(MLP(in_shape=2, out_shape=1, width=2, depth=32), mode="max-autotune-no-cudagraphs").to(device) 174 | 175 | epochs = 100 176 | lr = 1e-4 177 | train_and_generate_frames(model, X, y, domain, epochs, lr) 178 | -------------------------------------------------------------------------------- /benchmark/merge_logs.py: -------------------------------------------------------------------------------- 1 | import io 2 | from datetime import datetime 3 | from pathlib import Path 4 | from typing import List 5 | 6 | import markdown 7 | import numpy as np 8 | import pandas as pd 9 | import typer 10 | 11 | app = typer.Typer(pretty_exceptions_enable=True) 12 | 13 | 14 | def nanmean(x): 15 | return np.mean(x[np.isfinite(x)]) 16 | 17 | 18 | def merge_markdown_files(file_paths): 19 | """Merge multiple markdown files by taking the union of detail experiments and computing new summary stats.""" 20 | latest_timestamp = datetime.min 21 | union_details = [] 22 | 23 | for file_path in file_paths: 24 | with open(file_path, "r") as f: 25 | content = f.read() 26 | content = content.split("## Details")[1].split("## ")[0].strip() 27 | html = markdown.markdown(content, extensions=["tables"]) 28 | 
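# The "## Details" table is extracted from each markdown report by rendering it to HTML
# and letting pandas parse it back into a DataFrame; read_html returns a list of tables,
# and the first one is kept.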
union_details.append(pd.read_html(io.StringIO(html))[0]) 29 | 30 | union_details = pd.concat(union_details) 31 | 32 | # For duplicate rows (i.e., Benchmark, Optimizer, Cautious, Mars == same), take the best run (i.e., Successful, Runtime, Attempts) 33 | union_details["Success"] = union_details["Success"] == "✓" 34 | union_details["Runtime"] = union_details["Runtime"].str.replace("s", "") 35 | union_details["Runtime"] = pd.to_numeric(union_details["Runtime"], errors="coerce") 36 | union_details["Attempts"] = pd.to_numeric(union_details["Attempts"], errors="coerce") 37 | union_details["Loss"] = pd.to_numeric(union_details["Loss"], errors="coerce") 38 | 39 | union_details = union_details.sort_values(by=["Success", "Runtime", "Attempts"], ascending=[False, True, True]) 40 | union_details = union_details.drop_duplicates(keep="first", subset=["Benchmark", "Optimizer", "Cautious", "Mars"]) 41 | 42 | configs = union_details[["Optimizer", "Cautious", "Mars"]].drop_duplicates().to_dict(orient="records") 43 | 44 | new_summary = [] 45 | 46 | for config in configs: 47 | config_details = union_details[ 48 | (union_details["Optimizer"] == config["Optimizer"]) 49 | & (union_details["Cautious"] == config["Cautious"]) 50 | & (union_details["Mars"] == config["Mars"]) 51 | ] 52 | new_summary.append({ 53 | **config, 54 | "Attempts": nanmean(config_details["Attempts"]), 55 | "Success": f"{int(np.sum(config_details['Success']))}/{len(config_details)}", 56 | "Average Runtime": f"{nanmean(config_details['Runtime']):.1f}s", 57 | }) 58 | 59 | new_summary = pd.DataFrame(new_summary) 60 | new_summary.sort_values(by=["Optimizer", "Cautious", "Mars"], inplace=True) 61 | 62 | union_details["Runtime"] = [f"{x:.1f}s" for x in union_details["Runtime"]] 63 | union_details["Success"] = ["✓" if x else "✗" for x in union_details["Success"]] 64 | 65 | # Generate merged content with updated summary based on union of experiments 66 | merged_content = f"""# Benchmark Results 67 | 68 | Generated: {latest_timestamp} 69 | Last updated: {latest_timestamp} 70 | 71 | ## Summary 72 | 73 | {new_summary.to_markdown(index=False)} 74 | 75 | ## Details 76 | 77 | {union_details.to_markdown(index=False)} 78 | """ 79 | 80 | return merged_content 81 | 82 | 83 | @app.command() 84 | def main(path: List[str] = typer.Option([], help="Markdown files to merge")): 85 | files = [Path(p) for p in path] 86 | 87 | # Generate merged content 88 | merged_content = merge_markdown_files(files) 89 | 90 | # Write to output file 91 | output_path = Path("merged_results.md") 92 | with open(output_path, "w") as f: 93 | f.write(merged_content) 94 | 95 | 96 | if __name__ == "__main__": 97 | app() 98 | -------------------------------------------------------------------------------- /benchmark/minimax.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from benchmark.utils import param_norm_win_condition, trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | configs = { 15 | "trivial": {"size": 4}, 16 | "easy": {"size": 16}, 17 | "medium": {"size": 512}, 18 | "hard": {"size": 8192}, 19 | "extreme": {"size": 2**15}, 20 | "nightmare": {"size": 2**17}, 21 | } 22 | 23 | 24 | class Model(nn.Module): 25 | def __init__(self, size): 26 | super().__init__() 27 | self.param = 
nn.Parameter(torch.randn((2 * size,))) 28 | 29 | def forward(self, inp): 30 | param0, param1 = self.param.chunk(2, 0) 31 | return param0 @ param1 + (param0 @ param0 + param1 @ param1) / 2 32 | 33 | 34 | @app.command() 35 | def main( 36 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 37 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 38 | size: int = 1024, 39 | depth: int = 4, 40 | batch: int = 16, 41 | steps: int = 10, 42 | weight_decay: float = 0, 43 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 44 | win_condition_multiplier: float = 1.0, 45 | trials: int = 10, 46 | config: Optional[str] = None, 47 | ): 48 | size = configs.get(config, {}).get("size", size) 49 | 50 | dtype = [getattr(torch, d) for d in dtype] 51 | model = Model(size).cuda() 52 | 53 | def data(): 54 | inp = torch.randn((batch, size), device="cuda", dtype=dtype[0]) 55 | return inp, inp.cumsum(1) 56 | 57 | trial( 58 | model, 59 | data, 60 | F.mse_loss, 61 | param_norm_win_condition(1e-7 * win_condition_multiplier, 0), 62 | steps, 63 | opt[0], 64 | dtype[0], 65 | size, 66 | batch, 67 | weight_decay, 68 | method[0], 69 | 1, 70 | depth, 71 | failure_threshold=depth * 2, 72 | base_lr=1e-3, 73 | trials=trials, 74 | ) 75 | 76 | 77 | if __name__ == "__main__": 78 | app() 79 | -------------------------------------------------------------------------------- /benchmark/momentum_utilization.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | 8 | from benchmark.utils import loss_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | configs = { 14 | "trivial": {"weight": 0.004}, 15 | "easy": {"weight": 0.02}, 16 | "medium": {"weight": 0.1}, 17 | "hard": {"weight": 0.5}, 18 | "extreme": {"weight": 1}, 19 | "nightmare": {"weight": 2}, 20 | } 21 | 22 | 23 | class Model(nn.Module): 24 | def __init__(self, weight: float, size=1024): 25 | super().__init__() 26 | self.param = nn.Parameter(torch.randn(size)) 27 | self.register_buffer("t", torch.zeros(1)) 28 | self.weight = weight 29 | 30 | def forward(self): 31 | """Tests effective use of momentum for oscillating landscapes.""" 32 | self.t += 0.1 33 | x = self.param 34 | return (x.square() + self.weight * torch.sin(10 * x) * torch.cos(self.t)).mean() 35 | 36 | 37 | @app.command() 38 | def main( 39 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 40 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 41 | steps: int = 100, 42 | weight_decay: float = 0, 43 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 44 | trials: int = 100, 45 | win_condition_multiplier: float = 1.0, 46 | weight: float = 0.1, 47 | config: Optional[str] = None, 48 | ): 49 | weight = configs.get(config, {}).get("weight", weight) 50 | dtype = [getattr(torch, d) for d in dtype] 51 | model = Model(weight).cuda().double() 52 | 53 | def data(): 54 | return None, None 55 | 56 | trial( 57 | model, 58 | data, 59 | None, 60 | loss_win_condition(win_condition_multiplier * 1e-6), 61 | steps, 62 | opt[0], 63 | dtype[0], 64 | 1, 65 | 1, 66 | weight_decay, 67 | method[0], 68 | 1, 69 | 1, 70 | failure_threshold=3, 71 | base_lr=1e-3, 72 | trials=trials, 73 | ) 74 | 75 | 76 | if __name__ == "__main__": 77 
| app() 78 | -------------------------------------------------------------------------------- /benchmark/noisy_matmul.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from benchmark.utils import param_norm_win_condition, trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | 15 | configs = { 16 | "trivial": {"depth": 1}, 17 | "easy": {"depth": 2}, 18 | "medium": {"depth": 8}, 19 | "hard": {"depth": 12}, 20 | "extreme": {"depth": 16}, 21 | "nightmare": {"depth": 20}, 22 | } 23 | 24 | 25 | class Model(nn.Module): 26 | def __init__(self, size): 27 | super().__init__() 28 | self.param = nn.Parameter(torch.randn((size,))) 29 | self.offset = nn.Buffer(torch.randn_like(self.param)) 30 | 31 | def forward(self, inp): 32 | y = None 33 | y0 = self.param.view(1, -1).expand(inp.size(0), -1) + self.offset # offset, so weight decay doesnt help 34 | for i in inp.unbind(1): 35 | y = torch.einsum("bi,bik->bk", y0, i) 36 | y0 = F.leaky_relu(y, 0.1) 37 | return y 38 | 39 | 40 | @app.command() 41 | def main( 42 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 43 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 44 | size: int = 64, 45 | depth: int = 4, 46 | batch: int = 128, 47 | steps: int = 10, 48 | weight_decay: float = 0, 49 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 50 | win_condition_multiplier: float = 1.0, 51 | trials: int = 10, 52 | config: Optional[str] = None, 53 | ): 54 | depth = configs.get(config, {}).get("depth", depth) 55 | dtype = [getattr(torch, d) for d in dtype] 56 | model = Model(size).cuda() 57 | 58 | def data(): 59 | inp = torch.randn((batch, depth, size, size), device="cuda", dtype=dtype[0]) / size**0.5 60 | return inp, torch.zeros((batch, size), device="cuda", dtype=dtype[0]) 61 | 62 | trial( 63 | model, 64 | data, 65 | F.mse_loss, 66 | param_norm_win_condition(1e-7 * win_condition_multiplier, model.offset), 67 | steps, 68 | opt[0], 69 | dtype[0], 70 | size, 71 | batch, 72 | weight_decay, 73 | method[0], 74 | 1, 75 | depth, 76 | failure_threshold=depth * 2, 77 | base_lr=1e-3, 78 | trials=trials, 79 | ) 80 | 81 | 82 | if __name__ == "__main__": 83 | app() 84 | -------------------------------------------------------------------------------- /benchmark/parameter_scale.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | 8 | from benchmark.utils import loss_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | configs = {"easy": {"scale": 1e1}, "medium": {"scale": 1e3}, "hard": {"scale": 1e5}} 14 | 15 | 16 | class Model(nn.Module): 17 | def __init__(self, size, scale: float): 18 | super().__init__() 19 | # Simulate different layer scales in deep networks 20 | self.layer1 = nn.Parameter(torch.randn(size) * scale) # Small gradients 21 | self.layer2 = nn.Parameter(torch.randn(size)) # Medium gradients 22 | self.layer3 = nn.Parameter(torch.randn(size) / scale) # Large gradients 23 | 24 | def forward(self): 25 | """Test optimizer's ability to 
handle different gradient scales across layers.""" 26 | # Each layer contributes equally to the loss but has very different scales 27 | return (self.layer1.square().mean() + self.layer2.square().mean() + self.layer3.square().mean()) / 3 28 | 29 | 30 | @app.command() 31 | def main( 32 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 33 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 34 | steps: int = 100, 35 | weight_decay: float = 0, 36 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 37 | trials: int = 100, 38 | win_condition_multiplier: float = 1.0, 39 | config: Optional[str] = None, 40 | ): 41 | scale = configs.get(config, {}).get("scale", 1e3) 42 | 43 | dtype = [getattr(torch, d) for d in dtype] 44 | 45 | model = Model(size=1024, scale=scale).cuda().double() 46 | 47 | def data(): 48 | return None, None 49 | 50 | # More lenient win condition due to vastly different scales 51 | trial( 52 | model, 53 | data, 54 | None, 55 | loss_win_condition(win_condition_multiplier * 1e-4), 56 | steps, 57 | opt[0], 58 | dtype[0], 59 | 1, 60 | 1, 61 | weight_decay, 62 | method[0], 63 | 1, 64 | 1, 65 | failure_threshold=5, 66 | base_lr=1e-4, 67 | trials=trials, 68 | ) # Lower learning rate and more attempts 69 | 70 | 71 | if __name__ == "__main__": 72 | app() 73 | -------------------------------------------------------------------------------- /benchmark/plateau_navigation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pathlib 3 | import random 4 | from typing import List, Optional 5 | 6 | import matplotlib.colors 7 | import torch 8 | import torch.backends.opt_einsum 9 | import typer 10 | from torch import nn 11 | from utils import Plotter 12 | 13 | from benchmark.utils import loss_win_condition, trial 14 | from heavyball.utils import set_torch 15 | 16 | app = typer.Typer(pretty_exceptions_enable=False) 17 | set_torch() 18 | 19 | configs = { 20 | "trivial": {"scale": 1}, 21 | "easy": {"scale": 4}, 22 | "medium": {"scale": 8}, 23 | "hard": {"scale": 12}, 24 | "extreme": {"scale": 16}, 25 | "nightmare": {"scale": 20}, 26 | } 27 | 28 | 29 | def objective(x, y, scale: float): 30 | """Tests optimizer's ability to handle regions with very small gradients and sharp plateaus.""" 31 | output = 1 / (1 + torch.exp((x**2 + y**2 - 1) * -scale)) 32 | minimum = 1 / (1 + math.exp(scale)) 33 | return output - minimum # ensure the minimum is at 0 34 | 35 | 36 | class Model(nn.Module): 37 | def __init__(self, x, scale): 38 | super().__init__() 39 | self.param = nn.Parameter(torch.tensor(x).float()) 40 | self.scale = scale 41 | 42 | def forward(self): 43 | return objective(*self.param, scale=self.scale) 44 | 45 | 46 | @app.command() 47 | def main( 48 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 49 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 50 | steps: int = 100, 51 | weight_decay: float = 0, 52 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 53 | show_image: bool = False, 54 | trials: int = 100, 55 | win_condition_multiplier: float = 1.0, 56 | config: Optional[str] = None, 57 | ): 58 | scale = configs.get(config, {}).get("scale", 4) 59 | 60 | dtype = [getattr(torch, d) for d in dtype] 61 | coords = (1.5, 1.5) # Start outside the plateau 62 | 63 | # Clean up old plots 64 | for path in pathlib.Path(".").glob("plateau_navigation.png"): 65 | path.unlink() 66 | 67 | colors = 
list(matplotlib.colors.TABLEAU_COLORS.values()) 68 | rng = random.Random(0x1239121) 69 | rng.shuffle(colors) 70 | 71 | if show_image: 72 | model = Plotter(lambda *x: objective(*x, scale=scale).log()) 73 | else: 74 | model = Model(coords, scale=scale) 75 | model.double() 76 | 77 | def data(): 78 | return None, None 79 | 80 | trial( 81 | model, 82 | data, 83 | None, 84 | loss_win_condition(win_condition_multiplier * 1e-4), 85 | steps, 86 | opt[0], 87 | dtype[0], 88 | 1, 89 | 1, 90 | weight_decay, 91 | method[0], 92 | 1, 93 | 1, 94 | failure_threshold=3, 95 | base_lr=1e-3, 96 | trials=trials, 97 | ) 98 | 99 | 100 | if __name__ == "__main__": 101 | app() 102 | -------------------------------------------------------------------------------- /benchmark/postprocess_requirements.txt: -------------------------------------------------------------------------------- 1 | pandas>=1.5.0 2 | tabulate>=0.9.0 3 | termcolor>=2.1.0 4 | matplotlib>=3.7.0 5 | seaborn>=0.12.0 6 | -------------------------------------------------------------------------------- /benchmark/powers.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import torch.nn as nn 6 | import typer 7 | 8 | from benchmark.utils import loss_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | configs = { 14 | "trivial": {"powers": 4}, 15 | "easy": {"powers": 8}, 16 | "medium": {"powers": 16}, 17 | "hard": {"powers": 32}, 18 | "extreme": {"powers": 128}, 19 | "nightmare": {"powers": 512}, 20 | } 21 | 22 | 23 | class Model(nn.Module): 24 | def __init__(self, size, powers, target): 25 | super().__init__() 26 | self.target = target 27 | self.param = nn.Parameter(torch.rand(powers, size) * 2) 28 | self.register_buffer("scale", torch.arange(powers).float().add(1)) 29 | 30 | def forward(self): 31 | x = self.param - self.target 32 | x = x ** self.scale.view(-1, 1) 33 | return x.square().mean() 34 | 35 | 36 | @app.command() 37 | def main( 38 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 39 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 40 | size: int = 64, 41 | powers: int = 8, 42 | steps: int = 10, 43 | target: float = 1.0, 44 | weight_decay: float = 0, 45 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 46 | win_condition_multiplier: float = 1.0, 47 | trials: int = 10, 48 | config: Optional[str] = None, 49 | ): 50 | powers = configs.get(config, {}).get("powers", powers) 51 | 52 | dtype = [getattr(torch, d) for d in dtype] 53 | model = Model(size, powers, target).cuda().double() 54 | 55 | def data(): 56 | return None, None 57 | 58 | trial( 59 | model, 60 | data, 61 | None, 62 | loss_win_condition(win_condition_multiplier * 1e-8), 63 | steps, 64 | opt[0], 65 | dtype[0], 66 | 1, 67 | 1, 68 | weight_decay, 69 | method[0], 70 | 1, 71 | 1, 72 | failure_threshold=3, 73 | base_lr=1e-3, 74 | trials=trials, 75 | ) 76 | 77 | 78 | if __name__ == "__main__": 79 | app() 80 | -------------------------------------------------------------------------------- /benchmark/powers_varying_target.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import torch.nn as nn 6 | import typer 7 | 8 | from benchmark.utils import 
loss_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | 14 | configs = { 15 | "trivial": {"powers": 4}, 16 | "easy": {"powers": 8}, 17 | "medium": {"powers": 16}, 18 | "hard": {"powers": 32}, 19 | "extreme": {"powers": 128}, 20 | "nightmare": {"powers": 512}, 21 | } 22 | 23 | 24 | class Model(nn.Module): 25 | def __init__(self, size, powers, target_mult): 26 | super().__init__() 27 | self.target = nn.Buffer( 28 | torch.arange(powers * size).view(size, powers).transpose(0, 1).float() * target_mult / powers / size 29 | ) 30 | self.param = nn.Parameter(torch.rand(powers, size) * 2) 31 | self.register_buffer("scale", torch.arange(powers).float().add(1)) 32 | 33 | def forward(self): 34 | x = self.param - self.target 35 | x = x ** self.scale.view(-1, 1) 36 | return x.square().mean() 37 | 38 | 39 | @app.command() 40 | def main( 41 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 42 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 43 | size: int = 64, 44 | powers: int = 8, 45 | steps: int = 10, 46 | target_mult: float = 1.0, 47 | weight_decay: float = 0, 48 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 49 | win_condition_multiplier: float = 1.0, 50 | trials: int = 10, 51 | config: Optional[str] = None, 52 | ): 53 | powers = configs.get(config, {}).get("powers", powers) 54 | 55 | dtype = [getattr(torch, d) for d in dtype] 56 | model = Model(size, powers, target_mult).cuda().double() 57 | 58 | def data(): 59 | return None, None 60 | 61 | trial( 62 | model, 63 | data, 64 | None, 65 | loss_win_condition(win_condition_multiplier * 1e-6), 66 | steps, 67 | opt[0], 68 | dtype[0], 69 | 1, 70 | 1, 71 | weight_decay, 72 | method[0], 73 | 1, 74 | 1, 75 | failure_threshold=3, 76 | base_lr=1e-3, 77 | trials=trials, 78 | ) 79 | 80 | 81 | if __name__ == "__main__": 82 | app() 83 | -------------------------------------------------------------------------------- /benchmark/quadratic_varying_scale.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import typer 7 | 8 | from benchmark.utils import param_norm_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | 14 | configs = { 15 | "trivial": {"size": 4}, 16 | "easy": {"size": 16}, 17 | "medium": {"size": 512}, 18 | "hard": {"size": 8192}, 19 | "extreme": {"size": 2**15}, 20 | "nightmare": {"size": 2**17}, 21 | } 22 | 23 | 24 | class Model(nn.Module): 25 | def __init__(self, size): 26 | super().__init__() 27 | self.param = nn.Parameter(torch.randn(size)) 28 | self.register_buffer("scale", F.normalize(torch.arange(1, 1 + size).float(), dim=0, p=1)) 29 | 30 | def forward(self): 31 | return self.param.square() @ self.scale 32 | 33 | 34 | @app.command() 35 | def main( 36 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 37 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 38 | size: int = 1024, 39 | batch: int = 256, 40 | steps: int = 100, 41 | weight_decay: float = 0, 42 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 43 | trials: int = 10, 44 | win_condition_multiplier: float = 1.0, 45 | config: Optional[str] = None, 46 | ): 47 | size = 
configs.get(config, {}).get("size", size) 48 | dtype = [getattr(torch, d) for d in dtype] 49 | model = Model(size).cuda() 50 | 51 | def data(): 52 | return None, None 53 | 54 | trial( 55 | model, 56 | data, 57 | None, 58 | param_norm_win_condition(win_condition_multiplier * 1e-7, 0), 59 | steps, 60 | opt[0], 61 | dtype[0], 62 | size, 63 | batch, 64 | weight_decay, 65 | method[0], 66 | 1, 67 | 1, 68 | failure_threshold=2, 69 | base_lr=1e-3, 70 | trials=trials, 71 | ) 72 | 73 | 74 | if __name__ == "__main__": 75 | app() 76 | -------------------------------------------------------------------------------- /benchmark/quadratic_varying_target.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import typer 7 | 8 | from benchmark.utils import param_norm_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | 14 | 15 | configs = { 16 | "trivial": {"size": 4}, 17 | "easy": {"size": 16}, 18 | "medium": {"size": 512}, 19 | "hard": {"size": 8192}, 20 | "extreme": {"size": 2**15}, 21 | "nightmare": {"size": 2**17}, 22 | } 23 | 24 | 25 | class Model(nn.Module): 26 | def __init__(self, size): 27 | super().__init__() 28 | self.param = nn.Parameter(torch.randn(size)) 29 | self.register_buffer("target", F.normalize(torch.arange(size).float(), dim=0)) 30 | 31 | def forward(self): 32 | return (self.param - self.target).square().mean() 33 | 34 | 35 | @app.command() 36 | def main( 37 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 38 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 39 | size: int = 1024, 40 | batch: int = 256, 41 | steps: int = 100, 42 | weight_decay: float = 0, 43 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 44 | trials: int = 10, 45 | win_condition_multiplier: float = 1.0, 46 | config: Optional[str] = None, 47 | ): 48 | size = configs.get(config, {}).get("size", size) 49 | 50 | dtype = [getattr(torch, d) for d in dtype] 51 | model = Model(size).cuda() 52 | 53 | def data(): 54 | return None, None 55 | 56 | trial( 57 | model, 58 | data, 59 | None, 60 | param_norm_win_condition(win_condition_multiplier * 1e-8, -model.target), 61 | steps, 62 | opt[0], 63 | dtype[0], 64 | size, 65 | batch, 66 | weight_decay, 67 | method[0], 68 | 1, 69 | 1, 70 | failure_threshold=2, 71 | base_lr=1e-3, 72 | trials=trials, 73 | ) 74 | 75 | 76 | if __name__ == "__main__": 77 | app() 78 | -------------------------------------------------------------------------------- /benchmark/rastrigin.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pathlib 3 | import random 4 | from typing import List, Optional 5 | 6 | import matplotlib.colors 7 | import torch 8 | import torch.backends.opt_einsum 9 | import typer 10 | from torch import nn 11 | from utils import Plotter 12 | 13 | from benchmark.utils import SkipConfig, loss_win_condition, trial 14 | from heavyball.utils import set_torch 15 | 16 | app = typer.Typer(pretty_exceptions_enable=False) 17 | set_torch() 18 | 19 | 20 | def _formula(x, A): 21 | return x**2 + A * (1 - torch.cos(2 * math.pi * x)) 22 | 23 | 24 | def objective(*args, A=10): 25 | if len(args) == 1: 26 | return _formula(args[0], A).mean() 27 | 28 | return sum(_formula(x, A) for x in args) / len(args) 29 | 30 | 31 | 
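# The Rastrigin objective above has its global minimum of 0 at the origin and a regular
# grid of local minima produced by the cosine term, which is what makes it a standard
# test of an optimizer's ability to escape local minima.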
class Model(nn.Module): 32 | def __init__(self, x): 33 | super().__init__() 34 | self.param = nn.Parameter(torch.tensor(x).float()) 35 | 36 | def forward(self): 37 | return objective(self.param) 38 | 39 | 40 | @app.command() 41 | def main( 42 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 43 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 44 | steps: int = 100, 45 | weight_decay: float = 0, 46 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 47 | show_image: bool = False, 48 | trials: int = 100, 49 | win_condition_multiplier: float = 1.0, 50 | size: int = 2, 51 | config: Optional[str] = None, 52 | ): 53 | if config is not None and config != "trivial": 54 | raise SkipConfig("'config' must be 'trivial'.") 55 | if show_image: 56 | assert size == 2, "Image can only be displayed for 2D functions" 57 | dtype = [getattr(torch, d) for d in dtype] 58 | coords = (-2.2,) * size 59 | 60 | # Clean up old plots 61 | for path in pathlib.Path(".").glob("rastrigin.png"): 62 | path.unlink() 63 | 64 | colors = list(matplotlib.colors.TABLEAU_COLORS.values()) 65 | rng = random.Random(0x1239121) 66 | rng.shuffle(colors) 67 | 68 | if show_image: 69 | model = Model(coords) 70 | model = Plotter( 71 | model, 72 | x_limits=(-8, 2), 73 | y_limits=(-8, 2), 74 | ) 75 | else: 76 | model = Model(coords) 77 | model.double() 78 | 79 | def data(): 80 | return None, None 81 | 82 | model = trial( 83 | model, 84 | data, 85 | None, 86 | loss_win_condition(win_condition_multiplier * 1e-2 * (not show_image)), 87 | steps, 88 | opt[0], 89 | dtype[0], 90 | 1, 91 | 1, 92 | weight_decay, 93 | method[0], 94 | 1, 95 | 1, 96 | base_lr=1e-4, 97 | trials=trials, 98 | return_best=show_image, 99 | ) 100 | 101 | if not show_image: 102 | return 103 | 104 | model.plot(title=f"{method[0]} {opt[0]}", save_path="rastrigin.png") 105 | 106 | 107 | if __name__ == "__main__": 108 | app() 109 | -------------------------------------------------------------------------------- /benchmark/requirements.txt: -------------------------------------------------------------------------------- 1 | lightgbm 2 | optuna>=4.3.0 3 | gpytorch>=1.14 4 | botorch>=0.13.0 5 | optuna-integration[botorch] 6 | optuna-integration>=4.3.0 7 | threadpoolctl>=3.6.0 8 | scipy>=1.15.2 9 | scikit-learn>=1.6.1 10 | numexpr>=2.10.2 11 | hebo>=0.3.6 12 | pymoo>=0.6.1 13 | numpy>=1.26.4,<2.0.0 14 | optunahub>=0.2.0 15 | cmaes>=0.11.1 16 | typer>=0.15.0,<0.16.0 17 | -------------------------------------------------------------------------------- /benchmark/rosenbrock.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import random 3 | from typing import List, Optional 4 | 5 | import matplotlib.colors 6 | import torch 7 | import torch.backends.opt_einsum 8 | import typer 9 | from torch import nn 10 | 11 | from benchmark.utils import Plotter, SkipConfig, loss_win_condition, trial 12 | from heavyball.utils import set_torch 13 | 14 | app = typer.Typer(pretty_exceptions_enable=False) 15 | set_torch() 16 | 17 | 18 | def objective(x, y): 19 | return (1 - x) ** 2 + 1 * (y - x**2) ** 2 20 | 21 | 22 | class Model(nn.Module): 23 | def __init__(self, x): 24 | super().__init__() 25 | self.param = nn.Parameter(torch.tensor(x).float()) 26 | 27 | def forward(self): 28 | return objective(*self.param) 29 | 30 | 31 | @app.command() 32 | def main( 33 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 34 | dtype: List[str] = 
typer.Option(["float32"], help="Data type to use"), 35 | steps: int = 100, 36 | weight_decay: float = 0, 37 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 38 | show_image: bool = False, 39 | trials: int = 100, 40 | win_condition_multiplier: float = 1.0, 41 | config: Optional[str] = None, 42 | ): 43 | if config is not None and config != "trivial": 44 | raise SkipConfig("'config' must be 'trivial'.") 45 | dtype = [getattr(torch, d) for d in dtype] 46 | coords = (-7, -4) 47 | 48 | # Clean up old plots 49 | for path in pathlib.Path(".").glob("rosenbrock.png"): 50 | path.unlink() 51 | 52 | colors = list(matplotlib.colors.TABLEAU_COLORS.values()) 53 | rng = random.Random(0x1239121) 54 | rng.shuffle(colors) 55 | 56 | if show_image: 57 | model = Plotter(Model(coords), x_limits=(-8, 2), y_limits=(-8, 2), should_normalize=True) 58 | else: 59 | model = Model(coords) 60 | model.double() 61 | 62 | def data(): 63 | return None, None 64 | 65 | model = trial( 66 | model, 67 | data, 68 | None, 69 | loss_win_condition(win_condition_multiplier * 1e-9 * (not show_image)), 70 | steps, 71 | opt[0], 72 | dtype[0], 73 | 1, 74 | 1, 75 | weight_decay, 76 | method[0], 77 | 1, 78 | 1, 79 | base_lr=1e-4, 80 | trials=trials, 81 | return_best=show_image, 82 | ) 83 | 84 | if not show_image: 85 | return 86 | 87 | model.plot(save_path="rosenbrock.png") 88 | 89 | 90 | if __name__ == "__main__": 91 | app() 92 | -------------------------------------------------------------------------------- /benchmark/saddle_point.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import random 3 | from typing import List, Optional 4 | 5 | import matplotlib.colors 6 | import torch 7 | import torch.backends.opt_einsum 8 | import typer 9 | from torch import nn 10 | from utils import Plotter 11 | 12 | from benchmark.utils import loss_win_condition, trial 13 | from heavyball.utils import set_torch 14 | 15 | app = typer.Typer(pretty_exceptions_enable=False) 16 | set_torch() 17 | 18 | configs = { 19 | "trivial": {"power": 1}, 20 | "easy": {"power": 2}, 21 | "medium": {"power": 4}, 22 | "hard": {"power": 8}, 23 | "extreme": {"power": 16}, 24 | "nightmare": {"power": 32}, 25 | } 26 | 27 | 28 | def objective(*xs, power): 29 | """Classic saddle point objective - tests ability to escape saddle points.""" 30 | return sum(x**power for x in xs) 31 | 32 | 33 | class Model(nn.Module): 34 | def __init__(self, power, offset): 35 | super().__init__() 36 | self.param = nn.Parameter(torch.tensor([1.2, 1.9]).float()) 37 | self.offset = offset 38 | self.power = 2 * power + 1 39 | 40 | def forward(self): 41 | return objective(*self.param, power=self.power) + self.offset 42 | 43 | 44 | @app.command() 45 | def main( 46 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 47 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 48 | steps: int = 100, 49 | weight_decay: float = 0, 50 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 51 | show_image: bool = False, 52 | trials: int = 100, 53 | win_condition_multiplier: float = 1.0, 54 | config: Optional[str] = None, 55 | ): 56 | dtype = [getattr(torch, d) for d in dtype] 57 | coords = configs.get(config, {}).get("power", 1) 58 | 59 | # Clean up old plots 60 | for path in pathlib.Path(".").glob("saddle_point.png"): 61 | path.unlink() 62 | 63 | colors = list(matplotlib.colors.TABLEAU_COLORS.values()) 64 | rng = random.Random(0x1239121) 65 | 
rng.shuffle(colors) 66 | 67 | offset = win_condition_multiplier * 10 68 | 69 | if show_image: 70 | model = Plotter( 71 | lambda *x: objective(*x).add(offset).log(), 72 | coords=coords, 73 | xlim=(-2, 2), 74 | ylim=(-2, 2), 75 | normalize=8, 76 | after_step=torch.exp, 77 | ) 78 | else: 79 | model = Model(coords, offset) 80 | model.double() 81 | 82 | def data(): 83 | return None, None 84 | 85 | trial( 86 | model, 87 | data, 88 | None, 89 | loss_win_condition(0.1), 90 | steps, 91 | opt[0], 92 | dtype[0], 93 | 1, 94 | 1, 95 | weight_decay, 96 | method[0], 97 | 1, 98 | 1, 99 | failure_threshold=3, 100 | base_lr=1e-3, 101 | trials=trials, 102 | ) 103 | 104 | 105 | if __name__ == "__main__": 106 | app() 107 | -------------------------------------------------------------------------------- /benchmark/scale_invariant.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | 8 | from benchmark.utils import loss_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | configs = { 14 | "trivial": {"range": 1}, 15 | "easy": {"range": 2}, 16 | "medium": {"range": 3}, 17 | "hard": {"range": 4}, 18 | "extreme": {"range": 5}, 19 | "nightmare": {"range": 6}, 20 | } 21 | 22 | 23 | def objective(x): 24 | """Tests optimizer's ability to handle different parameter scales.""" 25 | return torch.log1p(x.square()).mean() 26 | 27 | 28 | class Model(nn.Module): 29 | def __init__(self, size, value_range): 30 | super().__init__() 31 | # Initialize with different scales 32 | scales = torch.logspace(-value_range, value_range, size) 33 | self.param = nn.Parameter(scales * torch.randn(size)) 34 | 35 | def forward(self): 36 | return objective(self.param) 37 | 38 | 39 | @app.command() 40 | def main( 41 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 42 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 43 | steps: int = 100, 44 | weight_decay: float = 0, 45 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 46 | trials: int = 100, 47 | win_condition_multiplier: float = 1.0, 48 | size: int = 512, 49 | config: Optional[str] = None, 50 | ): 51 | value_range = configs.get(config, {}).get("range", 3) 52 | 53 | dtype = [getattr(torch, d) for d in dtype] 54 | model = Model(size, value_range).cuda().double() 55 | 56 | def data(): 57 | return None, None 58 | 59 | trial( 60 | model, 61 | data, 62 | None, 63 | loss_win_condition(win_condition_multiplier * 1e-3), 64 | steps, 65 | opt[0], 66 | dtype[0], 67 | 1, 68 | 1, 69 | weight_decay, 70 | method[0], 71 | 1, 72 | 1, 73 | failure_threshold=3, 74 | base_lr=1e-3, 75 | trials=trials, 76 | ) 77 | 78 | 79 | if __name__ == "__main__": 80 | app() 81 | -------------------------------------------------------------------------------- /benchmark/sparse_gradient.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | 8 | from benchmark.utils import loss_win_condition, trial 9 | from heavyball.utils import set_torch 10 | 11 | app = typer.Typer(pretty_exceptions_enable=False) 12 | set_torch() 13 | configs = { 14 | "trivial": {"sparsity": 0.5}, 15 | "easy": {"sparsity": 2**-3}, 16 | "medium": 
{"sparsity": 2**-6}, 17 | "hard": {"sparsity": 2**-8}, 18 | "extreme": {"sparsity": 2**-11}, 19 | "nightmare": {"sparsity": 2**-14}, 20 | } 21 | 22 | 23 | class Model(nn.Module): 24 | def __init__(self, size=2**16, sparsity=2**-6): 25 | super().__init__() 26 | self.param = nn.Parameter(torch.randn(size)) 27 | self.sparsity = sparsity 28 | self.register_buffer("prev_mask", torch.zeros_like(self.param)) 29 | 30 | def forward(self): 31 | """Test optimizer's ability to handle sparse gradients.""" 32 | # Generate new random mask each time, but keep some consistency 33 | new_mask = (torch.rand_like(self.param) < self.sparsity).float() 34 | mask = (new_mask + self.prev_mask) > 0 # Union of current and previous mask 35 | self.prev_mask.copy_(new_mask) 36 | 37 | return (self.param * mask.float()).square().mean() 38 | 39 | 40 | @app.command() 41 | def main( 42 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 43 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 44 | steps: int = 100, 45 | weight_decay: float = 0, 46 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 47 | trials: int = 100, 48 | win_condition_multiplier: float = 1.0, 49 | sparsity: float = 2**-6, 50 | config: Optional[str] = None, 51 | ): 52 | sparsity = configs.get(config, {}).get("sparsity", sparsity) 53 | dtype = [getattr(torch, d) for d in dtype] 54 | model = Model(sparsity=sparsity).cuda().double() 55 | 56 | def data(): 57 | return None, None 58 | 59 | # Win condition accounts for sparsity - harder to reach very low loss 60 | trial( 61 | model, 62 | data, 63 | None, 64 | loss_win_condition(win_condition_multiplier * 1e-4), 65 | steps, 66 | opt[0], 67 | dtype[0], 68 | 1, 69 | 1, 70 | weight_decay, 71 | method[0], 72 | 1, 73 | 1, 74 | failure_threshold=5, 75 | base_lr=1e-3, 76 | trials=trials, 77 | ) # More failure attempts allowed 78 | 79 | 80 | if __name__ == "__main__": 81 | app() 82 | -------------------------------------------------------------------------------- /benchmark/wide_linear.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import typer 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from benchmark.utils import param_norm_win_condition, trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | 15 | configs = { 16 | "trivial": {"size": 4}, 17 | "easy": {"size": 32}, 18 | "medium": {"size": 512}, 19 | "hard": {"size": 2048}, 20 | "extreme": {"size": 8192}, 21 | "nightmare": {"size": 2**14}, 22 | } 23 | 24 | 25 | class Model(nn.Module): 26 | def __init__(self, size): 27 | super().__init__() 28 | self.param = nn.Parameter(torch.randn((size, size))) 29 | self.target = nn.Buffer(torch.triu(torch.ones_like(self.param))) 30 | 31 | def forward(self, inp): 32 | return inp @ self.param 33 | 34 | 35 | @app.command() 36 | def main( 37 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 38 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 39 | size: int = 1024, 40 | depth: int = 4, 41 | batch: int = 16, 42 | steps: int = 10, 43 | weight_decay: float = 0, 44 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 45 | win_condition_multiplier: float = 1.0, 46 | trials: int = 10, 47 | config: Optional[str] = None, 48 | ): 49 | size = 
configs.get(config, {}).get("size", size) 50 | dtype = [getattr(torch, d) for d in dtype] 51 | model = Model(size).cuda() 52 | 53 | def data(): 54 | inp = torch.randn((batch, size), device="cuda", dtype=dtype[0]) 55 | return inp, inp.cumsum(1) 56 | 57 | trial( 58 | model, 59 | data, 60 | F.mse_loss, 61 | param_norm_win_condition(1e-7 * win_condition_multiplier, model.target), 62 | steps, 63 | opt[0], 64 | dtype[0], 65 | size, 66 | batch, 67 | weight_decay, 68 | method[0], 69 | 1, 70 | depth, 71 | failure_threshold=depth * 2, 72 | base_lr=1e-3, 73 | trials=trials, 74 | ) 75 | 76 | 77 | if __name__ == "__main__": 78 | app() 79 | -------------------------------------------------------------------------------- /benchmark/xor_digit.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import torch.nn as nn 6 | import typer 7 | from torch.nn import functional as F 8 | 9 | from benchmark.utils import loss_win_condition, trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | 15 | configs = { 16 | "trivial": {"length": 4}, 17 | "easy": {"length": 8}, 18 | "medium": {"length": 16}, 19 | "hard": {"length": 32}, 20 | "extreme": {"length": 64}, 21 | "nightmare": {"length": 96}, 22 | } 23 | 24 | 25 | class Model(nn.Module): 26 | def __init__(self, size, depth): 27 | super().__init__() 28 | self.embed = nn.Embedding(2, size) 29 | self.enc = nn.LSTM(size, size, depth, batch_first=False) 30 | self.enc.flatten_parameters() 31 | self.proj = nn.Sequential( 32 | nn.LayerNorm(size), # 33 | nn.Linear(size, 1), 34 | ) 35 | 36 | def forward(self, inp): 37 | inp = inp.transpose(0, 1) 38 | inp = self.embed(inp.squeeze(-1).long()) 39 | out, _ = torch.compiler.disable()(self.enc)(inp) 40 | return self.proj(out[-1, :]) 41 | 42 | 43 | @app.command() 44 | def main( 45 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 46 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 47 | length: int = 64, 48 | size: int = 64, 49 | depth: int = 1, 50 | batch: int = 256, 51 | steps: int = 10, 52 | weight_decay: float = 0, 53 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 54 | trials: int = 10, 55 | win_condition_multiplier: float = 1.0, 56 | config: Optional[str] = None, 57 | ): 58 | length = configs.get(config, {}).get("length", length) 59 | dtype = [getattr(torch, d) for d in dtype] 60 | torch.manual_seed(0x1239121) 61 | model = Model(size, depth).cuda() 62 | 63 | def data(): 64 | inp = torch.randn((batch, length, 1), device="cuda", dtype=dtype[0]) 65 | inp = inp > 0 66 | return inp.to(dtype[0]), (inp.sum(1) % 2).to(dtype[0]) 67 | 68 | trial( 69 | model, 70 | data, 71 | F.binary_cross_entropy_with_logits, 72 | loss_win_condition(win_condition_multiplier * 1e-3), 73 | steps, 74 | opt[0], 75 | dtype[0], 76 | size, 77 | batch, 78 | weight_decay, 79 | method[0], 80 | length, 81 | depth, 82 | failure_threshold=10, 83 | base_lr=1e-6, 84 | trials=trials, 85 | ) 86 | 87 | 88 | if __name__ == "__main__": 89 | app() 90 | -------------------------------------------------------------------------------- /benchmark/xor_digit_rnn.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import torch.nn as nn 6 | import typer 7 | 
from torch.nn import functional as F 8 | 9 | from benchmark.utils import loss_win_condition, trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | 15 | configs = { 16 | "trivial": {"length": 4}, 17 | "easy": {"length": 8}, 18 | "medium": {"length": 16}, 19 | "hard": {"length": 32}, 20 | "extreme": {"length": 64}, 21 | "nightmare": {"length": 96}, 22 | } 23 | 24 | 25 | class Model(nn.Module): 26 | def __init__(self, size, depth): 27 | super().__init__() 28 | self.embed = nn.Embedding(2, size) 29 | self.enc = nn.RNN(size, size, depth, batch_first=False) 30 | self.enc.flatten_parameters() 31 | self.proj = nn.Sequential( 32 | nn.LayerNorm(size), # 33 | nn.Linear(size, 1), 34 | ) 35 | 36 | def forward(self, inp): 37 | inp = inp.transpose(0, 1) 38 | inp = self.embed(inp.squeeze(-1).long()) 39 | out, _ = torch.compiler.disable()(self.enc)(inp) 40 | return self.proj(out[-1, :]) 41 | 42 | 43 | @app.command() 44 | def main( 45 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 46 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 47 | length: int = 64, 48 | size: int = 64, 49 | depth: int = 1, 50 | batch: int = 256, 51 | steps: int = 10, 52 | weight_decay: float = 0, 53 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 54 | trials: int = 10, 55 | win_condition_multiplier: float = 1.0, 56 | config: Optional[str] = None, 57 | ): 58 | length = configs.get(config, {}).get("length", length) 59 | 60 | dtype = [getattr(torch, d) for d in dtype] 61 | torch.manual_seed(0x1239121) 62 | model = Model(size, depth).cuda() 63 | 64 | def data(): 65 | inp = torch.randn((batch, length, 1), device="cuda", dtype=dtype[0]) 66 | inp = inp > 0 67 | return inp.to(dtype[0]), (inp.sum(1) % 2).to(dtype[0]) 68 | 69 | trial( 70 | model, 71 | data, 72 | F.binary_cross_entropy_with_logits, 73 | loss_win_condition(win_condition_multiplier * 1e-3), 74 | steps, 75 | opt[0], 76 | dtype[0], 77 | size, 78 | batch, 79 | weight_decay, 80 | method[0], 81 | length, 82 | depth, 83 | failure_threshold=10, 84 | base_lr=1e-6, 85 | trials=trials, 86 | ) 87 | 88 | 89 | if __name__ == "__main__": 90 | app() 91 | -------------------------------------------------------------------------------- /benchmark/xor_sequence.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import torch.nn as nn 6 | import typer 7 | from torch.nn import functional as F 8 | 9 | from benchmark.utils import loss_win_condition, trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | 15 | configs = { 16 | "trivial": {"length": 2}, 17 | "easy": {"length": 4}, 18 | "medium": {"length": 6}, 19 | "hard": {"length": 9}, 20 | "extreme": {"length": 12}, 21 | "nightmare": {"length": 14}, 22 | } 23 | 24 | 25 | class Model(nn.Module): 26 | def __init__(self, size, depth): 27 | super().__init__() 28 | self.embed0 = nn.Embedding(2, size) 29 | self.embed1 = nn.Embedding(2, size) 30 | self.enc = nn.LSTM(size, size, depth, batch_first=False) 31 | self.dec = nn.LSTM(size, size, depth, batch_first=False) 32 | self.enc.flatten_parameters() 33 | self.dec.flatten_parameters() 34 | self.proj = nn.Sequential( 35 | nn.LayerNorm(size), # 36 | nn.Linear(size, 1), 37 | ) 38 | 39 | def forward(self, inp): 40 | i0, i1 = inp.chunk(2, 1) 41 
| i0 = i0.transpose(0, 1) 42 | i1 = i1.transpose(0, 1) 43 | i0 = self.embed0(i0) 44 | i1 = self.embed1(i1) 45 | _, state = torch.compiler.disable()(self.enc)(i0) 46 | out, _ = torch.compiler.disable()(self.dec)(i1, state) 47 | return self.proj(out.transpose(0, 1)) 48 | 49 | 50 | @app.command() 51 | def main( 52 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 53 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 54 | length: int = 14, 55 | size: int = 16, 56 | depth: int = 1, 57 | batch: int = 256, 58 | steps: int = 100, 59 | weight_decay: float = 0, 60 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 61 | win_condition_multiplier: float = 1, 62 | trials: int = 10, 63 | config: Optional[str] = None, 64 | ): 65 | length = configs.get(config, {}).get("length", length) 66 | 67 | dtype = [getattr(torch, d) for d in dtype] 68 | torch.manual_seed(0x1239121) 69 | model = Model(size, depth).cuda() 70 | 71 | def data(): 72 | inp = torch.randn((batch, 2 * length, 1), device="cuda", dtype=dtype[0]) 73 | inp = inp > 0 74 | i0, i1 = inp.chunk(2, 1) 75 | xored = torch.logical_xor(i0, i1) 76 | return inp.long().squeeze(-1), xored.to(dtype[0]) 77 | 78 | trial( 79 | model, 80 | data, 81 | F.binary_cross_entropy_with_logits, 82 | loss_win_condition(win_condition_multiplier * 1e-2), 83 | steps, 84 | opt[0], 85 | dtype[0], 86 | size, 87 | batch, 88 | weight_decay, 89 | method[0], 90 | length, 91 | depth, 92 | failure_threshold=10, 93 | base_lr=0.001, 94 | trials=trials, 95 | ) 96 | 97 | 98 | if __name__ == "__main__": 99 | app() 100 | -------------------------------------------------------------------------------- /benchmark/xor_sequence_rnn.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.backends.opt_einsum 5 | import torch.nn as nn 6 | import typer 7 | from torch.nn import functional as F 8 | 9 | from benchmark.utils import loss_win_condition, trial 10 | from heavyball.utils import set_torch 11 | 12 | app = typer.Typer(pretty_exceptions_enable=False) 13 | set_torch() 14 | 15 | 16 | configs = { 17 | "trivial": {"length": 2}, 18 | "easy": {"length": 4}, 19 | "medium": {"length": 6}, 20 | "hard": {"length": 9}, 21 | "extreme": {"length": 12}, 22 | "nightmare": {"length": 14}, 23 | } 24 | 25 | 26 | class Model(nn.Module): 27 | def __init__(self, size, depth): 28 | super().__init__() 29 | self.embed0 = nn.Embedding(2, size) 30 | self.embed1 = nn.Embedding(2, size) 31 | self.enc = nn.RNN(size, size, depth, batch_first=False) 32 | self.dec = nn.RNN(size, size, depth, batch_first=False) 33 | self.enc.flatten_parameters() 34 | self.dec.flatten_parameters() 35 | self.proj = nn.Sequential( 36 | nn.LayerNorm(size), # 37 | nn.Linear(size, 1), 38 | ) 39 | 40 | def forward(self, inp): 41 | i0, i1 = inp.chunk(2, 1) 42 | i0 = i0.transpose(0, 1) 43 | i1 = i1.transpose(0, 1) 44 | i0 = self.embed0(i0) 45 | i1 = self.embed1(i1) 46 | _, state = torch.compiler.disable()(self.enc)(i0) 47 | out, _ = torch.compiler.disable()(self.dec)(i1, state) 48 | return self.proj(out.transpose(0, 1)) 49 | 50 | 51 | @app.command() 52 | def main( 53 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 54 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 55 | length: int = 14, 56 | size: int = 16, 57 | depth: int = 1, 58 | batch: int = 256, 59 | steps: int = 100, 60 | weight_decay: float = 0, 
61 | opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), 62 | win_condition_multiplier: float = 1, 63 | trials: int = 10, 64 | config: Optional[str] = None, 65 | ): 66 | length = configs.get(config, {}).get("length", length) 67 | 68 | dtype = [getattr(torch, d) for d in dtype] 69 | torch.manual_seed(0x1239121) 70 | model = Model(size, depth).cuda() 71 | 72 | def data(): 73 | inp = torch.randn((batch, 2 * length, 1), device="cuda", dtype=dtype[0]) 74 | inp = inp > 0 75 | i0, i1 = inp.chunk(2, 1) 76 | xored = torch.logical_xor(i0, i1) 77 | return inp.long().squeeze(-1), xored.to(dtype[0]) 78 | 79 | trial( 80 | model, 81 | data, 82 | F.binary_cross_entropy_with_logits, 83 | loss_win_condition(win_condition_multiplier * 1e-2), 84 | steps, 85 | opt[0], 86 | dtype[0], 87 | size, 88 | batch, 89 | weight_decay, 90 | method[0], 91 | length, 92 | depth, 93 | failure_threshold=10, 94 | base_lr=0.001, 95 | trials=trials, 96 | ) 97 | 98 | 99 | if __name__ == "__main__": 100 | app() 101 | -------------------------------------------------------------------------------- /benchmark/xor_spot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inspired by https://github.com/lixilinx/psgd_torch/blob/master/rnn_xor_problem_general_purpose_preconditioner.py 3 | This version is strongly simplified but follows the same basic idea: 4 | 1) Generate random sequence 5 | 2) Mark two spots 6 | 3) Train a model to predict the xor of the two spots 7 | This does NOT elicit memory in the RNN, but it does force it to learn a pointwise forget mechanism. 8 | """ 9 | 10 | import itertools 11 | from typing import List, Optional 12 | 13 | import torch 14 | import torch.backends.opt_einsum 15 | import torch.nn as nn 16 | import typer 17 | 18 | from benchmark.utils import loss_win_condition, trial 19 | from heavyball.utils import set_torch 20 | 21 | app = typer.Typer(pretty_exceptions_enable=False) 22 | set_torch() 23 | configs = { 24 | "trivial": {"length": 4}, 25 | "easy": {"length": 8}, 26 | "medium": {"length": 16}, 27 | "hard": {"length": 32}, 28 | "extreme": {"length": 64}, 29 | "nightmare": {"length": 96}, 30 | } 31 | 32 | 33 | class Model(nn.Module): 34 | def __init__(self, size, depth): 35 | super().__init__() 36 | self.embed = nn.Embedding(4, size) 37 | self.enc = nn.LSTM(size, size, depth, batch_first=False) 38 | self.enc.flatten_parameters() 39 | self.proj = nn.Sequential( 40 | nn.LayerNorm(size), # 41 | nn.Linear(size, 1), 42 | ) 43 | 44 | def forward(self, inp): 45 | inp = self.embed(inp.squeeze(-1).long()) 46 | inp = inp[0] + inp[1] 47 | out, _ = torch.compiler.disable()(self.enc)(inp.transpose(0, 1)) 48 | return self.proj(out[-1, :]) 49 | 50 | 51 | @app.command() 52 | def main( 53 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 54 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 55 | length: int = 64, 56 | size: int = 64, 57 | depth: int = 1, 58 | batch: int = 256, 59 | steps: int = 10, 60 | weight_decay: float = 0, 61 | opt: List[str] = typer.Option( 62 | ["ForeachSOAP", "PaLMForeachSOAP", "PrecondScheduleForeachSOAP"], help="Optimizers to use" 63 | ), 64 | win_condition_multiplier: float = 1.0, 65 | trials: int = 10, 66 | config: Optional[str] = None, 67 | ): 68 | length = configs.get(config, {}).get("length", length) 69 | dtype = [getattr(torch, d) for d in dtype] 70 | 71 | for args in itertools.product(method, dtype, [(length, size, depth, batch)], opt, [weight_decay]): 72 | m, d, 
(l, s, dp, b), o, wd = args 73 | 74 | model = Model(s, dp).cuda() 75 | 76 | def data(): 77 | inp = torch.randn((b, l, 1), device="cuda", dtype=d) 78 | inp = inp > 0 79 | zeros = torch.zeros_like(inp) 80 | zeros[:, torch.randint(0, l, (b,), device="cuda")] = 1 81 | zeros[:, torch.randint(0, l, (b,), device="cuda")] = 1 82 | target = (inp * zeros).sum(1) % 2 83 | return torch.stack((inp, zeros + 2), 0).to(d), target.to(d) 84 | 85 | trial( 86 | model, 87 | data, 88 | torch.nn.functional.binary_cross_entropy_with_logits, 89 | loss_win_condition(win_condition_multiplier * 1e-2), 90 | steps, 91 | o, 92 | d, 93 | s, 94 | b, 95 | wd, 96 | m, 97 | l, 98 | dp, 99 | failure_threshold=10, 100 | trials=trials, 101 | ) 102 | 103 | 104 | if __name__ == "__main__": 105 | app() 106 | -------------------------------------------------------------------------------- /benchmark/xor_spot_rnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inspired by https://github.com/lixilinx/psgd_torch/blob/master/rnn_xor_problem_general_purpose_preconditioner.py 3 | This version is strongly simplified but follows the same basic idea: 4 | 1) Generate random sequence 5 | 2) Mark two spots 6 | 3) Train a model to predict the xor of the two spots 7 | This does NOT elicit memory in the RNN, but it does force it to learn a pointwise forget mechanism. 8 | """ 9 | 10 | import itertools 11 | from typing import List, Optional 12 | 13 | import torch 14 | import torch.backends.opt_einsum 15 | import torch.nn as nn 16 | import typer 17 | 18 | from benchmark.utils import loss_win_condition, trial 19 | from heavyball.utils import set_torch 20 | 21 | app = typer.Typer(pretty_exceptions_enable=False) 22 | set_torch() 23 | 24 | configs = { 25 | "trivial": {"length": 4}, 26 | "easy": {"length": 8}, 27 | "medium": {"length": 16}, 28 | "hard": {"length": 32}, 29 | "extreme": {"length": 64}, 30 | "nightmare": {"length": 96}, 31 | } 32 | 33 | 34 | class Model(nn.Module): 35 | def __init__(self, size, depth): 36 | super().__init__() 37 | self.embed = nn.Embedding(4, size) 38 | self.enc = nn.RNN(size, size, depth, batch_first=False) 39 | self.enc.flatten_parameters() 40 | self.proj = nn.Sequential( 41 | nn.LayerNorm(size), # 42 | nn.Linear(size, 1), 43 | ) 44 | 45 | def forward(self, inp): 46 | inp = self.embed(inp.squeeze(-1).long()) 47 | inp = inp[0] + inp[1] 48 | out, _ = torch.compiler.disable()(self.enc)(inp.transpose(0, 1)) 49 | return self.proj(out[-1, :]) 50 | 51 | 52 | @app.command() 53 | def main( 54 | method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), 55 | dtype: List[str] = typer.Option(["float32"], help="Data type to use"), 56 | length: int = 64, 57 | size: int = 64, 58 | depth: int = 1, 59 | batch: int = 256, 60 | steps: int = 10, 61 | weight_decay: float = 0, 62 | opt: List[str] = typer.Option( 63 | ["ForeachSOAP", "PaLMForeachSOAP", "PrecondScheduleForeachSOAP"], help="Optimizers to use" 64 | ), 65 | win_condition_multiplier: float = 1.0, 66 | trials: int = 10, 67 | config: Optional[str] = None, 68 | ): 69 | length = configs.get(config, {}).get("length", length) 70 | 71 | dtype = [getattr(torch, d) for d in dtype] 72 | 73 | for args in itertools.product(method, dtype, [(length, size, depth, batch)], opt, [weight_decay]): 74 | m, d, (l, s, dp, b), o, wd = args 75 | 76 | model = Model(s, dp).cuda() 77 | 78 | def data(): 79 | inp = torch.randn((b, l, 1), device="cuda", dtype=d) 80 | inp = inp > 0 81 | zeros = torch.zeros_like(inp) 82 | zeros[:, 
torch.randint(0, l, (b,), device="cuda")] = 1 83 | zeros[:, torch.randint(0, l, (b,), device="cuda")] = 1 84 | target = (inp * zeros).sum(1) % 2 85 | return torch.stack((inp, zeros + 2), 0).to(d), target.to(d) 86 | 87 | trial( 88 | model, 89 | data, 90 | torch.nn.functional.binary_cross_entropy_with_logits, 91 | loss_win_condition(win_condition_multiplier * 1e-2), 92 | steps, 93 | o, 94 | d, 95 | s, 96 | b, 97 | wd, 98 | m, 99 | l, 100 | dp, 101 | failure_threshold=10, 102 | trials=trials, 103 | ) 104 | 105 | 106 | if __name__ == "__main__": 107 | app() 108 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | rm -rf dist/* 2 | python -m build 3 | python -m twine upload dist/* 4 | rm -rf dist/ build/ heavyball.egg-info/ 5 | -------------------------------------------------------------------------------- /docs/assets/benchmark_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/docs/assets/benchmark_matrix.png -------------------------------------------------------------------------------- /docs/assets/early_stopping.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/docs/assets/early_stopping.png -------------------------------------------------------------------------------- /docs/assets/hiker_analogy_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/docs/assets/hiker_analogy_diagram.png -------------------------------------------------------------------------------- /docs/assets/psgd_efficiency_cache.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/docs/assets/psgd_efficiency_cache.png -------------------------------------------------------------------------------- /docs/assets/psgd_efficiency_cache_triu_as_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/docs/assets/psgd_efficiency_cache_triu_as_line.png -------------------------------------------------------------------------------- /docs/assets/psgd_efficiency_triu_as_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/docs/assets/psgd_efficiency_triu_as_line.png -------------------------------------------------------------------------------- /docs/assets/saddle_point_comparison.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HomebrewML/HeavyBall/d0995758749242a51b476c832faac3bfc2969a90/docs/assets/saddle_point_comparison.gif -------------------------------------------------------------------------------- /docs/benchmark.md: -------------------------------------------------------------------------------- 1 | # Beyond the Leaderboard: A Diagnostic Benchmark for Optimizer Reliability 2 | 3 | Traditional machine learning benchmarks often 
incentivize "teaching to the test," a race to top leaderboards that may 4 | not reflect real-world performance. This practice can mask a critical issue: **silent failures**. Optimizers may 5 | converge without reporting an error, yet settle in a suboptimal solution, leading to models that underperform in subtle 6 | but significant ways. Such failures are costly, as they often go undetected until significant downstream damage has occurred. 7 | 8 | The HeavyBall Benchmark was created to address this problem. It's not another leaderboard but a diagnostic tool 9 | designed to expose hidden optimizer weaknesses. Each test provides a clear, pass/fail check against known optimization 10 | challenges, creating a detailed map of an optimizer's strengths and weaknesses. 11 | 12 | ## The Problem of Silent Failures 13 | 14 | A silent failure occurs when an optimizer converges without error yet settles in a suboptimal basin, leading to poor 15 | downstream model performance. A descending loss curve can be profoundly misleading. Without a clear point of comparison, 16 | an optimizer that has simply gotten stuck can appear identical to one that has found a genuinely good solution. 17 | 18 | ![Three optimizers (SGD, L-BFGS, PSGD) converging to different loss values on the same problem.](assets/early_stopping.png) 19 | *Figure 3 from ["Black Box Lie Group Preconditioners for SGD"](https://arxiv.org/abs/2211.04422). As shown, SGD (blue) 20 | appears converged, but PSGD (black) finds a significantly better basin. This highlights a fundamental capability gap.* 21 | 22 | The graph above illustrates this trap. All three optimizers seem to converge as their loss curves flatten. An 23 | unsuspecting practitioner might see the flatlining loss of the blue curve (SGD) and conclude the job is done. Yet, the 24 | black curve (PSGD) consistently finds a solution that is an order of magnitude better. This isn't a matter of an unlucky 25 | random seed; it's a capability gap. No amount of rerunning SGD will find the better solution that PSGD locates with 26 | ease. 27 | 28 | This capability gap often goes unnoticed. The training logs provide no explicit error, leaving developers to discover 29 | the underperformance only through expensive downstream evaluation. 30 | 31 | ## The HeavyBall Diagnostic Benchmark 32 | 33 | Instead of a single score, the HeavyBall Benchmark provides a granular, diagnostic map of an optimizer's performance. 34 | It's built on a suite of over 150 independent, pass/fail tests, each targeting a specific, well-understood challenge 35 | known to cause hidden failures. 36 | 37 | ![A heatmap showing various optimizers (columns) and their success rate on different benchmark tasks (rows).](assets/benchmark_matrix.png) 38 | *This map reveals systemic strengths and weaknesses in popular optimizers. Each cell represents a pass/fail outcome for 39 | a solver (column) on a specific task (row). Darker blue indicates a higher success rate.* 40 | 41 | There is no partial credit; the optimizer either solves the problem or it does not. This binary outcome makes failure 42 | modes explicit, turning abstract weaknesses into concrete, observable data. 43 | 44 | ### Experimental Setup 45 | 46 | To ensure a fair comparison, we gave each optimizer a generous budget: 1,000 hyperparameter trials per task, with each 47 | trial running for up to 1,000,000 steps. This setup tests the optimizer's raw capability, not just its default settings. 
48 | The `Attempts` column in our results reflects the average number of trials needed for success; a lower number indicates 49 | that the method is easier to tune for the problems it can solve. 50 | 51 | ## Benchmark Results: No Single Best Optimizer 52 | 53 | Our results show that no single optimizer dominates. Even the best-performing optimizers exhibit surprising weaknesses, 54 | reinforcing the need for diagnostic rather than purely comparative evaluation. 55 | 56 | | Optimizer | Cautious¹ | Mars² | Success | Attempts | Avg Runtime (s) | 57 | |:---------------|:----------|:------|:--------|:---------|:----------------| 58 | | PSGDKron | No | No | 77.0% | 73.2 | 8240 | 59 | | NewtonPSGDKron | No | No | 77.0% | 80.5 | 9052 | 60 | | AdamW | Yes | No | 75.7% | 61.2 | 8072 | 61 | | ForeachSOAP | No | No | 72.5% | 77.9 | 7827 | 62 | | AdamW | No | No | 72.3% | 107.8 | 10029 | 63 | | MuonLaProp | No | No | 68.2% | 82.7 | 10141 | 64 | | RMSprop | No | No | 55.6% | 114.4 | 10725 | 65 | | Muon | No | No | 51.0% | 129.1 | 14525 | 66 | 67 | ¹ `Cautious`: Avoids taking a step in a direction the current gradient does not agree with 68 |
69 | ² `Mars`: Reduces variance in gradients 70 | 71 | *This is a subset of the full results. For a complete breakdown, see 72 | the [full benchmark results](https://github.com/HomebrewML/HeavyBall/blob/main/benchmark/benchmark_results.md).* 73 | 74 | ### Case Study: The `AdamW` Family 75 | 76 | The results for the popular AdamW optimizer are particularly revealing. The standard implementation has a respectable 77 | 72.3% success rate. However, enabling the `Cautious` flag boosts the success rate to 75.7% and significantly reduces the 78 | number of attempts needed to find a solution. This isn't just a number; it's a diagnostic signal. The standard `AdamW` 79 | is more prone to getting stuck in ways that its `Cautious` variant can avoid, allowing a practitioner to make a more 80 | informed choice. 81 | 82 | ### Case Study: Escaping the Saddle Point 83 | 84 | An optimizer’s inability to navigate a saddle point is a classic example of a silent failure. A key test of an 85 | optimizer's robustness is its ability to navigate a saddle point—a region that is a minimum in one direction but a 86 | maximum in another. The gradient approaches zero at the center, trapping first-order methods that rely solely on the 87 | gradient. 88 | 89 | ![Animation: Optimizer paths on a saddle point, showing SGD getting stuck while a momentum-based optimizer successfully escapes.](assets/saddle_point_comparison.gif) 90 | 91 | Our benchmark includes a specific test for this challenge. An optimizer that passes demonstrates a greater capacity to 92 | handle the complex non-convex landscapes common in deep learning. A failure provides a clear diagnostic signal that the 93 | optimizer may be unreliable in these settings. 94 | 95 | ## Conclusion 96 | 97 | The HeavyBall Benchmark represents a necessary shift in how we evaluate optimizers, moving from a culture of 98 | score-chasing to one of deep, diagnostic understanding. These hidden failures aren’t rare edge cases—they’re a routine 99 | source of wasted compute and disappointing models. By making them explicit, the benchmark equips researchers and 100 | practitioners with a detailed map of an optimizer's capabilities. By clearly identifying hidden failure modes, 101 | practitioners can confidently choose, tune, or reconsider their optimization strategies, ultimately leading to more 102 | robust and reliable models. Future work will focus on expanding our suite of diagnostic tests to cover more complex 103 | failure modes and developing novel visualization techniques. 104 | 105 | --- 106 | 107 | **Resources:** 108 | 109 | * **Full Results:** [benchmark/benchmark_results.md](https://github.com/HomebrewML/HeavyBall/blob/main/benchmark/benchmark_results.md) 110 | * **Benchmark Code:** [https://github.com/HomebrewML/HeavyBall/tree/main/benchmark](https://github.com/HomebrewML/HeavyBall/tree/main/benchmark) 111 | -------------------------------------------------------------------------------- /docs/psgd_efficiency.md: -------------------------------------------------------------------------------- 1 | # PSGD Efficiency 2 | 3 | This document discusses various methods of reducing PSGD's memory and compute overhead, as well as the trade-offs 4 | involved. 
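To make the trade-off concrete before diving into the individual options, here is a minimal sketch of the packing idea behind the `triu_as_line` setting discussed in the next section. The helper names `pack_triu` and `unpack_triu` are assumptions made for illustration, not heavyball's actual API, and the real implementation differs in detail; the point is only that storing the upper triangle as a flat buffer roughly halves storage, while every use pays for rebuilding the dense matrix.

```python
# Sketch only: hypothetical helpers illustrating the triu-as-line trade-off.
import torch


def pack_triu(q: torch.Tensor) -> torch.Tensor:
    """Keep only the upper triangle (including the diagonal) as a flat 1D tensor."""
    idx = torch.triu_indices(q.shape[0], q.shape[1], device=q.device)
    return q[idx[0], idx[1]]


def unpack_triu(line: torch.Tensor, n: int) -> torch.Tensor:
    """Rebuild the dense (n, n) triangular matrix; this remapping costs memory bandwidth."""
    q = torch.zeros(n, n, dtype=line.dtype, device=line.device)
    idx = torch.triu_indices(n, n, device=line.device)
    q[idx[0], idx[1]] = line
    return q


n = 1024
q = torch.triu(torch.randn(n, n))  # stand-in for a triangular preconditioner factor Q
line = pack_triu(q)  # n * (n + 1) / 2 elements instead of n * n
assert torch.equal(unpack_triu(line, n), q)
print(f"dense: {q.numel():,} elements, packed: {line.numel():,} elements")
```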
5 | 6 | ## `triu_as_line` 7 | 8 | `triu_as_line` is an argument that reduces the preconditioner (`Q`) storage overhead by storing only the upper 9 | triangle of the triangular `Q` as a 1D array, halving memory usage.\ 10 | This comes at the cost of having to remap the 1D array to a 2D array every time the preconditioner is used, which needs 11 | significant memory bandwidth. 12 | `triu_as_line` is enabled by default, and can be disabled by setting it to `False`.\ 13 | 14 | A high-overhead test case (`python3 xor_digit.py --batch 16 --size 1024 --length 4 --depth 1`) showed that the total 15 | step time may be increased by up to ~58% when training with `triu_as_line=True`.\ 16 | Larger batch sizes help amortize this issue. 17 | 18 | ![psgd_efficiency_triu_as_line.png](assets/psgd_efficiency_triu_as_line.png) 19 | 20 | ## Cached Preconditioner 21 | 22 | For `PSGDKron`, there's an alternative variant, `CachedPSGDKron`.\ 23 | PSGDKron computes the preconditioning matrix on the fly based on the triangular `Q`. However, this preconditioner can be 24 | precomputed and reused across steps. This reduces the per-step overhead of computing the preconditioner, but doubles 25 | the memory overhead. 26 | 27 | ![psgd_efficiency_cache.png](assets/psgd_efficiency_cache.png) 28 | 29 | If the doubled memory cost of `CachedPSGDKron` is too high, it's possible to use `CachedPSGDKron` with 30 | `triu_as_line=True`, which reduces the total memory cost from 2x `Q` to 1.5x `Q`. 31 | 32 | ![psgd_efficiency_cache_triu_as_line.png](assets/psgd_efficiency_cache_triu_as_line.png) 33 | -------------------------------------------------------------------------------- /examples/autoencoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from datetime import datetime 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import tqdm 9 | from torch.nn import functional as F 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torchvision.datasets import MNIST 12 | from torchvision.transforms import v2 13 | from torchvision.utils import make_grid 14 | 15 | import heavyball 16 | 17 | heavyball.utils.set_torch() 18 | 19 | 20 | class Residual(nn.Sequential): 21 |     def forward(self, input): 22 |         out = super().forward(input) 23 |         return out + F.interpolate(input, out.shape[2:]) 24 | 25 | 26 | class Block(nn.Sequential): 27 |     def __init__( 28 |         self, 29 |         in_features: int, 30 |         intermediate: int, 31 |         out_features: int, 32 |         kernel: int, 33 |         stride: int, 34 |         up: bool, 35 |         depth: int, 36 |     ): 37 |         padding = kernel // 2 38 |         layers = [nn.Conv2d(in_features, intermediate, kernel_size=kernel, padding=padding)] 39 | 40 |         for _ in range(depth): 41 |             layers.append( 42 |                 Residual( 43 |                     nn.Upsample(scale_factor=stride) if up else nn.MaxPool2d(stride), 44 |                     nn.BatchNorm2d(intermediate), 45 |                     nn.ReLU(), 46 |                     nn.Conv2d(intermediate, intermediate, kernel_size=kernel, padding=padding), 47 |                 ) 48 |             ) 49 | 50 |         layers.append(nn.ReLU()) 51 |         layers.append(nn.Conv2d(intermediate, out_features, kernel_size=kernel, padding=padding)) 52 | 53 |         super().__init__(*layers) 54 | 55 | 56 | class Autoencoder(nn.Module): 57 |     def __init__(self, kernel: int = 3, stride: int = 2, hidden: int = 1, intermediate: int = 128): 58 |         super(Autoencoder, self).__init__() 59 |         self.enc = Block(1, intermediate, hidden, kernel, stride, False, 3) 60 |         self.dec = Block(hidden, intermediate, 1, kernel, stride, True, 3) 61 | 62 |     def forward(self, x): 63 |         x = 
self.enc(x).sigmoid() 64 | # label = x > torch.rand_like(x) 65 | # x = label.detach().float() + x - x.detach() 66 | out = self.dec(x) 67 | return out 68 | 69 | 70 | class RandomPad(nn.Module): 71 | def __init__(self, amount: int): 72 | super().__init__() 73 | self.amount = amount 74 | self.rng = np.random.default_rng(0x12312) 75 | 76 | def forward(self, inp): 77 | new = [] 78 | xs, ys = np.split((np.random.randint(0, self.amount, size=2 * inp.size(0)) * self.amount).round(), 2) 79 | for val, x, y in zip(inp, xs, ys): 80 | padded = F.pad(val, (x, self.amount - x, y, self.amount - y)) 81 | new.append(padded) 82 | return torch.stack(new) 83 | 84 | 85 | def main(epochs: int, batch: int, log_interval: int = 16): 86 | # Setup tensorboard logging 87 | torch.manual_seed(0x12783) 88 | np.random.seed(0x12783) 89 | random.seed(0x12783) 90 | log_dir = os.path.join("runs", f"soap_{datetime.now().strftime('%Y%m%d_%H%M%S')}") 91 | writer = SummaryWriter(log_dir) 92 | 93 | model = torch.compile(Autoencoder().cuda(), mode="default") 94 | optimizer = heavyball.PSGDKron( 95 | model.parameters(), 96 | lr=1e-4, 97 | mars=True, 98 | lower_bound_beta=0.9, 99 | inverse_free=True, 100 | precond_update_power_iterations=6, 101 | store_triu_as_line=False, 102 | ) 103 | 104 | transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32)]) 105 | train = [img for img, _ in MNIST(root="./data", train=True, download=True, transform=transform)] 106 | test = [img for _, (img, _) in zip(range(8), MNIST(root="./data", train=False, download=True, transform=transform))] 107 | 108 | train = torch.stack(train).cuda() / 255.0 109 | eval_batch = torch.stack(test) / 255.0 110 | 111 | transform = RandomPad(4) 112 | eval_batch = transform(eval_batch) 113 | eval_batch_cuda = eval_batch.cuda() 114 | step = 0 115 | total_loss = 0 116 | 117 | for epoch in range(epochs): 118 | train = train[torch.randperm(train.size(0))].contiguous() 119 | batches = transform(train) 120 | batches = batches[: batches.size(0) // batch * batch] 121 | batches = batches.view(-1, batch, *batches.shape[1:]) 122 | 123 | for i in tqdm.tqdm(range(batches.size(0))): 124 | img = batches[i] 125 | step += 1 126 | 127 | def _closure(): 128 | output = model(img) 129 | loss = F.mse_loss(output, img) 130 | loss.backward() 131 | return loss 132 | 133 | loss = optimizer.step(_closure) 134 | optimizer.zero_grad() 135 | with torch.no_grad(): 136 | total_loss = total_loss + loss.detach() 137 | 138 | if step % log_interval == 0: 139 | avg_loss = (total_loss / log_interval).item() 140 | writer.add_scalar("Loss/train", avg_loss, step) 141 | total_loss = 0 142 | if step % (log_interval * 10) == 0: 143 | writer.flush() 144 | 145 | with torch.no_grad(): 146 | model.eval() 147 | samples = model(eval_batch_cuda) 148 | comparison = torch.cat([eval_batch, samples.cpu()], dim=0) 149 | grid = make_grid(comparison, nrow=8, normalize=True, padding=2) 150 | writer.add_image("reconstructions", grid, epoch) 151 | model.train() 152 | writer.flush() 153 | 154 | 155 | if __name__ == "__main__": 156 | main(epochs=100, batch=128) 157 | -------------------------------------------------------------------------------- /examples/lra.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | import matplotlib.pyplot as plt 5 | import torch 6 | import torch.nn as nn 7 | import tqdm 8 | from preconditioned_stochastic_gradient_descent import LRA 9 | from torch.nn import functional as F 10 | from torch.utils.data import 
DataLoader 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torchvision.datasets import MNIST 13 | from torchvision.transforms import v2 14 | from torchvision.utils import make_grid 15 | 16 | import heavyball 17 | 18 | heavyball.utils.compile_mode = None 19 | heavyball.utils.set_torch() 20 | 21 | 22 | class Residual(nn.Sequential): 23 | def forward(self, input): 24 | out = super().forward(input) 25 | return out + F.interpolate(input, out.shape[2:]) 26 | 27 | 28 | class Block(nn.Sequential): 29 | def __init__( 30 | self, 31 | in_features: int, 32 | intermediate: int, 33 | out_features: int, 34 | kernel: int, 35 | stride: int, 36 | up: bool, 37 | depth: int, 38 | ): 39 | padding = kernel // 2 40 | layers = [nn.Conv2d(in_features, intermediate, kernel_size=kernel, padding=padding)] 41 | 42 | for _ in range(depth): 43 | layers.append( 44 | Residual( 45 | nn.Upsample(scale_factor=stride) if up else nn.MaxPool2d(stride), 46 | nn.BatchNorm2d(intermediate), 47 | nn.ReLU(), 48 | nn.Conv2d(intermediate, intermediate, kernel_size=kernel, padding=padding), 49 | ) 50 | ) 51 | 52 | layers.append(nn.ReLU()) 53 | layers.append(nn.Conv2d(intermediate, out_features, kernel_size=kernel, padding=padding)) 54 | 55 | super().__init__(*layers) 56 | 57 | 58 | class Autoencoder(nn.Module): 59 | def __init__(self, kernel: int = 5, stride: int = 2, hidden: int = 8, intermediate: int = 128): 60 | super(Autoencoder, self).__init__() 61 | self.enc = Block(1, intermediate, hidden, kernel, stride, False, 1) 62 | self.balancer = nn.BatchNorm2d(hidden, affine=False) 63 | self.dec = Block(hidden, intermediate, 1, kernel, stride, True, 1) 64 | 65 | def forward(self, x): 66 | x = self.enc(x) 67 | x = self.balancer(x).sigmoid() 68 | out = self.dec(x) 69 | return out 70 | 71 | 72 | def plot_samples(model, data, epoch, save_dir="samples"): 73 | os.makedirs(save_dir, exist_ok=True) 74 | model.eval() 75 | with torch.no_grad(): 76 | samples = model(data.cuda()) 77 | # Create a grid of original and reconstructed images 78 | comparison = torch.cat([data, samples.cpu() * 255.0], dim=0) 79 | grid = make_grid(comparison, nrow=8, normalize=True, padding=2) 80 | plt.figure(figsize=(10, 5)) 81 | plt.imshow(grid.permute(1, 2, 0)) 82 | plt.axis("off") 83 | plt.savefig(os.path.join(save_dir, f"epoch_{epoch}.png")) 84 | plt.close() 85 | model.train() 86 | 87 | 88 | class RandomPad(nn.Module): 89 | def __init__(self, amount: int): 90 | super().__init__() 91 | self.amount = amount 92 | 93 | def forward(self, inp): 94 | x = torch.randint(0, self.amount, (inp.size(0),)) 95 | y = torch.randint(0, self.amount, (inp.size(0),)) 96 | new = torch.zeros( 97 | [inp.shape[0], inp.shape[1] + self.amount, inp.shape[2] + self.amount], 98 | device=inp.device, 99 | dtype=inp.dtype, 100 | ) 101 | new[:, x : x + inp.size(1), y : y + inp.size(2)] = inp 102 | return new 103 | 104 | 105 | def mean(updates): 106 | return [sum(us) / len(us) for us in zip(*updates)] 107 | 108 | 109 | def main(epochs: int, batch: int): 110 | # Setup tensorboard logging 111 | log_dir = os.path.join("runs", f"soap_{datetime.now().strftime('%Y%m%d_%H%M%S')}") 112 | writer = SummaryWriter(log_dir) 113 | 114 | torch.manual_seed(0x12783) 115 | model = Autoencoder().cuda() 116 | # optimizer = heavyball.ForeachPSGDLRA( 117 | # model.parameters(), lr=1e-3, mars=True,precond_init_scale=1, precond_lr=0.1 118 | # ) 119 | optimizer = LRA( 120 | model.parameters(), 121 | rank_of_approximation=20, 122 | preconditioner_init_scale=1, 123 | lr_params=1e-4, 124 | lr_preconditioner=0.1, 125 
| exact_hessian_vector_product=False, 126 | preconditioner_type="whitening", 127 | momentum=0.9, 128 | ) 129 | transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32), RandomPad(4)]) 130 | trainset = list(MNIST(root="./data", train=True, download=True, transform=transform)) * epochs 131 | dataloader = DataLoader(trainset, batch_size=batch, shuffle=True, num_workers=8, drop_last=True, pin_memory=True) 132 | 133 | step = 0 134 | losses = [] 135 | 136 | for data in tqdm.tqdm(dataloader): 137 | img, _ = data 138 | img = img.to(device="cuda", non_blocking=True) / 255.0 139 | 140 | def _closure(): 141 | output = model(img) 142 | loss = F.mse_loss(output, img) 143 | # loss.backward() 144 | return loss 145 | 146 | loss = optimizer.step(_closure) 147 | losses.append(loss.detach()) 148 | if len(losses) >= 64: 149 | for loss in losses: 150 | writer.add_scalar("Loss/train", loss.item(), step) 151 | step += 1 152 | losses.clear() 153 | writer.flush() 154 | 155 | 156 | if __name__ == "__main__": 157 | main(epochs=2000, batch=64) 158 | -------------------------------------------------------------------------------- /examples/modify_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | import heavyball 6 | 7 | heavyball.utils.compile_mode = "default" 8 | heavyball.utils.set_torch() 9 | 10 | 11 | def main(epochs: int, batch: int, features: int = 16, steps: int = 1024): 12 | model = nn.Sequential(nn.Linear(features, features * 4), nn.ReLU(), nn.Linear(features * 4, 1)) 13 | model.cuda() 14 | 15 | optimizer = heavyball.SOAP( 16 | model.parameters(), lr=1e-3, precondition_frequency=1 17 | ) # initial_d is required by scale_by_lr_adaptation but not used in standard SOAP - we'll get a warning about it 18 | optimizer.fns = optimizer.fns + [ 19 | heavyball.chainable.orthogonalize_update 20 | ] # important that we assign and don't just .append()! 
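    # Added note: `optimizer.fns + [...]` builds a *new* list and rebinds the attribute,
    # which is what the comment above asks for; `optimizer.fns.append(...)` would mutate the
    # existing list in place instead. What `orthogonalize_update` does internally is not shown here.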
21 | 22 | for epoch in range(epochs): 23 | total_loss = 0.0 24 | for _ in range(steps): 25 | data = torch.randn((batch, features), device="cuda") 26 | target = data.square().mean(1, keepdim=True) 27 | 28 | def _closure(): 29 | output = model(data) 30 | loss = F.mse_loss(output, target) 31 | loss.backward() 32 | return loss 33 | 34 | loss = optimizer.step(_closure) 35 | optimizer.zero_grad() 36 | with torch.no_grad(): 37 | total_loss = total_loss + loss.detach() 38 | 39 | avg_loss = (total_loss / steps).item() 40 | print(f"[{epoch:{len(str(epochs))}d}/{epochs}] Loss: {avg_loss:.4f}") 41 | 42 | 43 | if __name__ == "__main__": 44 | main(epochs=100, batch=1024) 45 | -------------------------------------------------------------------------------- /pre-commit.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.9.9 4 | hooks: 5 | - id: ruff 6 | types_or: [python, jupyter] 7 | - id: ruff-format 8 | args: [--diff] 9 | types_or: [python, jupyter] 10 | - repo: https://github.com/pre-commit/pre-commit-hooks 11 | rev: v5.0.0 12 | hooks: 13 | - id: check-added-large-files 14 | - id: check-merge-conflict 15 | - id: check-toml 16 | - id: check-yaml 17 | - id: end-of-file-fixer 18 | - id: mixed-line-ending 19 | args: [--fix=lf] 20 | - id: trailing-whitespace 21 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=75.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "heavyball" 7 | description = "Efficient Optimizers" 8 | version = "2.0.0.dev0" 9 | authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }] 10 | classifiers = ["Intended Audience :: Developers", 11 | "Intended Audience :: Science/Research", 12 | "License :: OSI Approved :: BSD License", 13 | "Natural Language :: English", 14 | "Operating System :: OS Independent", 15 | "Programming Language :: Python :: 3", 16 | ] 17 | dependencies = ["opt-einsum>=3.4.0", 18 | "torch>=2.1.0", 19 | "numpy", 20 | ] 21 | keywords = ["torch", 22 | "optimizer", 23 | "muon", 24 | "soap", 25 | "psgd", 26 | ] 27 | readme = "README.md" 28 | requires-python = ">=3.9" 29 | 30 | [project.optional-dependencies] 31 | dev = ["pre-commit", "pytest", "ruff", "matplotlib", "seaborn", "hyperopt", "pandas", "typer", "optuna", "optunahub", "botorch", "hebo"] 32 | 33 | [project.urls] 34 | source = "https://github.com/HomebrewML/HeavyBall" 35 | tracker = "https://github.com/HomebrewML/HeavyBall/issues" 36 | 37 | [tool.ruff] 38 | line-length = 120 39 | 40 | [tool.ruff.lint] 41 | extend-select = ["I", "W"] 42 | ignore = ["E741"] 43 | preview = true 44 | 45 | [tool.ruff.lint.isort] 46 | relative-imports-order = "closest-to-furthest" 47 | 48 | [tool.ruff.format] 49 | preview = true 50 | 51 | [tool.setuptools.packages.find] 52 | include = ["heavyball*"] 53 | -------------------------------------------------------------------------------- /test/benchmark_psgd_lb.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import math 3 | import os 4 | import time 5 | 6 | os.environ["TORCH_TRACE"] = "./tracedir" 7 | import matplotlib.pyplot as plt 8 | import pandas as pd 9 | import seaborn as sns 10 | import torch 11 | 12 | from heavyball.utils import max_singular_value_cholesky, 
max_singular_value_power_iter, set_torch 13 | 14 | 15 | def display_stats(exact, name, approx, duration): 16 | exact = exact.double().cpu() 17 | approx = approx.double().cpu() 18 | error = torch.abs(approx - exact) 19 | rel_error = error / exact.clamp(min=1e-8) 20 | print( 21 | f"{name} | Took: {duration:.6f}s | Approx={approx.mean():.4e}, Exact={exact.mean():.4e}, " 22 | f"Abs Error={error.mean():.4e}, Rel Error={rel_error.mean():.5f}" 23 | ) 24 | return { 25 | "approx": approx.mean().item(), 26 | "exact": exact.mean().item(), 27 | "abs_error": error.mean().item(), 28 | "rel_error": rel_error.mean().item(), 29 | "duration": duration, 30 | } 31 | 32 | 33 | def measure_time(xs, fn): 34 | for x in xs: # Warmup 35 | fn(x) 36 | torch.cuda.synchronize() 37 | start = time.time() 38 | results = [fn(x) for x in xs] 39 | torch.cuda.synchronize() 40 | return time.time() - start, torch.tensor(results) 41 | 42 | 43 | def baseline_norm(x): 44 | return torch.linalg.matrix_norm(x, ord=2) 45 | 46 | 47 | @torch.inference_mode() 48 | def test_singular_value_approx(min_val=2, max_val=64, attempts=1): 49 | torch.manual_seed(0x12378) 50 | test_cases = [ 51 | (lambda x: torch.randn((x, x)), "Normal"), 52 | (lambda x: torch.rand((x, x)), "Uniform"), 53 | (lambda x: torch.randn((x, x)) * torch.arange(x).view(1, -1), "Normal * Arange"), 54 | (lambda x: torch.randn((x, x)).exp(), "exp(Normal)"), 55 | (lambda x: torch.randn((x, x)) ** 16, "Normal ** 16"), 56 | ] 57 | max_name_len = max(len(name) for _, name in test_cases) 58 | test_cases = [(fn, f"{name:{max_name_len}}") for fn, name in test_cases] 59 | methods = ( 60 | ("exact", baseline_norm), 61 | ("cholesky", max_singular_value_cholesky), 62 | ("power_iter_0", functools.partial(max_singular_value_power_iter, iterations=0)), 63 | ("power_iter_1", functools.partial(max_singular_value_power_iter, iterations=2)), 64 | ("power_iter_2", functools.partial(max_singular_value_power_iter, iterations=6)), 65 | ) 66 | results = [] 67 | 68 | sizes = [2**s for s in range(int(math.log2(min_val)), int(math.log2(max_val)) + 1)] 69 | for size in sizes: 70 | size_str = f"{size:{len(str(max_val))}d}" 71 | for matrix_fn, name in test_cases: 72 | matrices = [matrix_fn(size).cuda().float() for _ in range(attempts)] 73 | exact_vals = None 74 | for method_name, method_fn in methods: 75 | duration, approx_vals = measure_time(matrices, method_fn) 76 | if method_name == "exact": 77 | exact_vals = approx_vals 78 | stats = display_stats(exact_vals, f"{name} ({size_str}) | {method_name}", approx_vals, duration) 79 | results.append({"size": size, "matrix_type": name.strip(), "method": method_name, **stats}) 80 | return pd.DataFrame(results) 81 | 82 | 83 | def plot_results(df): 84 | # Set a more modern aesthetic with clear colors 85 | sns.set_theme(style="whitegrid", font_scale=1.3) 86 | palette = sns.color_palette("viridis", n_colors=4) 87 | 88 | # Create a figure with 2x2 subplots 89 | fig, axes = plt.subplots(2, 2, figsize=(18, 12), gridspec_kw={"hspace": 0.35, "wspace": 0.25}) 90 | 91 | # Add a title to the overall figure 92 | fig.suptitle("Singular Value Approximation Method Comparison", fontsize=20, y=0.98) 93 | 94 | # 1. 
Duration vs Size (Top Left) 95 | method_order = ["cholesky", "power_iter_0", "power_iter_1", "power_iter_2"] 96 | method_names = { 97 | "cholesky": "Cholesky", 98 | "power_iter_0": "Power Iteration (Default)", 99 | "power_iter_1": "Power Iteration (1 iter)", 100 | "power_iter_2": "Power Iteration (2 iter)", 101 | } 102 | 103 | # Filter out the exact method for time comparison and ensure consistent order 104 | plot_df = df[df["method"] != "exact"].copy() 105 | plot_df["method_name"] = plot_df["method"].map(method_names) 106 | 107 | # Add speedup compared to exact method 108 | exact_times = df[df["method"] == "exact"].set_index(["size", "matrix_type"])["duration"] 109 | plot_df["speedup"] = plot_df.apply( 110 | lambda row: exact_times.loc[(row["size"], row["matrix_type"])] / row["duration"], axis=1 111 | ) 112 | 113 | # Plot duration vs size 114 | sns.lineplot( 115 | data=plot_df, 116 | x="size", 117 | y="duration", 118 | hue="method_name", 119 | style="method_name", 120 | markers=True, 121 | markersize=10, 122 | linewidth=3, 123 | ax=axes[0, 0], 124 | palette=palette, 125 | hue_order=[method_names[m] for m in method_order], 126 | ) 127 | 128 | # Formatting for first plot 129 | axes[0, 0].set( 130 | xscale="log", 131 | yscale="log", 132 | title="Computation Time vs Matrix Size", 133 | xlabel="Matrix Size (n×n)", 134 | ylabel="Time (seconds)", 135 | ) 136 | axes[0, 0].grid(True, which="both", ls="-", alpha=0.2) 137 | axes[0, 0].legend(title="Method", frameon=True, title_fontsize=14, fontsize=12) 138 | 139 | # Add exact method time as a reference line 140 | exact_avg = df[df["method"] == "exact"].groupby("size")["duration"].mean() 141 | axes[0, 0].plot(exact_avg.index, exact_avg.values, "r--", linewidth=2, alpha=0.7, label="Exact (SVD)") 142 | axes[0, 0].legend(title="Method", frameon=True, title_fontsize=14, fontsize=12) 143 | 144 | # 2. Relative Error vs Size (Top Right) 145 | sns.lineplot( 146 | data=plot_df, 147 | x="size", 148 | y="rel_error", 149 | hue="method_name", 150 | style="method_name", 151 | markers=True, 152 | markersize=10, 153 | linewidth=3, 154 | ax=axes[0, 1], 155 | palette=palette, 156 | hue_order=[method_names[m] for m in method_order], 157 | ) 158 | 159 | # Formatting for second plot 160 | axes[0, 1].set( 161 | xscale="log", title="Relative Error vs Matrix Size", xlabel="Matrix Size (n×n)", ylabel="Relative Error" 162 | ) 163 | axes[0, 1].set_yscale("log") 164 | axes[0, 1].grid(True, which="both", ls="-", alpha=0.2) 165 | axes[0, 1].legend(title="Method", frameon=True, title_fontsize=14, fontsize=12) 166 | 167 | # 3. Speedup Factor vs Size (Bottom Left) 168 | sns.lineplot( 169 | data=plot_df, 170 | x="size", 171 | y="speedup", 172 | hue="method_name", 173 | style="method_name", 174 | markers=True, 175 | markersize=10, 176 | linewidth=3, 177 | ax=axes[1, 0], 178 | palette=palette, 179 | hue_order=[method_names[m] for m in method_order], 180 | ) 181 | 182 | # Formatting for third plot 183 | axes[1, 0].set( 184 | xscale="log", 185 | yscale="log", 186 | title="Speedup vs Matrix Size (compared to exact SVD)", 187 | xlabel="Matrix Size (n×n)", 188 | ylabel="Speedup Factor (×)", 189 | ) 190 | axes[1, 0].grid(True, which="both", ls="-", alpha=0.2) 191 | axes[1, 0].legend(title="Method", frameon=True, title_fontsize=14, fontsize=12) 192 | 193 | # 4. 
Error vs Matrix Type (Bottom Right) - Boxplot 194 | sns.boxplot(data=plot_df, x="method_name", y="rel_error", hue="matrix_type", ax=axes[1, 1], palette="Set2") 195 | 196 | # Formatting for fourth plot 197 | axes[1, 1].set( 198 | yscale="log", 199 | title="Relative Error by Method and Matrix Type", 200 | xlabel="Method", 201 | ylabel="Relative Error (log scale)", 202 | ) 203 | axes[1, 1].tick_params(axis="x", rotation=45) 204 | axes[1, 1].legend(title="Matrix Type", frameon=True, title_fontsize=14, fontsize=12) 205 | 206 | # Apply consistent formatting to all subplots 207 | for ax in axes.flatten(): 208 | ax.title.set_fontsize(16) 209 | ax.xaxis.label.set_fontsize(14) 210 | ax.yaxis.label.set_fontsize(14) 211 | ax.tick_params(labelsize=12) 212 | 213 | # For logarithmic scales, add minor gridlines 214 | if ax.get_xscale() == "log": 215 | ax.xaxis.grid(True, which="minor", linestyle="--", alpha=0.2) 216 | if ax.get_yscale() == "log": 217 | ax.yaxis.grid(True, which="minor", linestyle="--", alpha=0.2) 218 | 219 | plt.tight_layout(rect=[0, 0, 1, 0.96]) # Leave space for the figure title 220 | 221 | return fig 222 | 223 | 224 | def main(): 225 | set_torch() 226 | with torch._dynamo.utils.disable_cache_limit(): 227 | results = test_singular_value_approx() 228 | fig = plot_results(results) 229 | fig.savefig("singular_value_comparison.png", dpi=300, bbox_inches="tight") 230 | plt.show() 231 | 232 | 233 | if __name__ == "__main__": 234 | main() 235 | -------------------------------------------------------------------------------- /test/readme.md: -------------------------------------------------------------------------------- 1 | # Tests 2 | 3 | - [ ] regression test against SOAP (due to implementation challenges) 4 | - [ ] peak memory test 5 | - [ ] compute (runtime) test 6 | -------------------------------------------------------------------------------- /test/test_bf16_params.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import pytest 5 | import torch 6 | from torch import nn 7 | from torch._dynamo import config 8 | 9 | import heavyball 10 | import heavyball.utils 11 | from benchmark.utils import get_optim 12 | from heavyball.utils import clean, set_torch 13 | 14 | os.environ["TORCH_LOGS"] = "+recompiles" 15 | 16 | config.cache_size_limit = 128 17 | 18 | 19 | @pytest.mark.parametrize("opt", heavyball.__all__) 20 | @pytest.mark.parametrize("size,depth", [(256, 1)]) 21 | def test_foreach(opt, size, depth: int, iterations: int = 512, outer_iterations: int = 1): 22 | set_torch() 23 | opt = getattr(heavyball, opt) 24 | 25 | peaks = [] 26 | losses = [] 27 | 28 | torch.manual_seed(0x123131) 29 | model = nn.Sequential(*[nn.Linear(size, size, bias=False) for _ in range(depth)]).to(torch.double).cuda() 30 | 31 | for dtype in [torch.float32, torch.bfloat16]: 32 | torch.manual_seed(0x2131290) 33 | peaks.append([]) 34 | losses.append([]) 35 | 36 | for i in range(outer_iterations): 37 | mdl = copy.deepcopy(model).to(dtype) 38 | o = get_optim(opt, mdl.parameters(), lr=1e-4, update_clipping=None, warmup_steps=128) 39 | print(f"\n\n\n{dtype} {opt} {size} {depth}\n\n\n") 40 | for _ in range(iterations): 41 | loss = mdl(torch.randn((1024, size), device="cuda", dtype=dtype)).double().abs().mean() 42 | loss.backward() 43 | print(mdl[0].weight.double().norm().item()) 44 | o.step() 45 | o.zero_grad() 46 | losses[-1].append(loss.detach()) 47 | 48 | del mdl, o 49 | clean() 50 | 51 | for i, (l0, l1) in enumerate(zip(*losses)): 52 | print(i, 
l0.item(), l1.item()) 53 | # assert torch.allclose(l0.float(), l1.float(), rtol=0.1) 54 | -------------------------------------------------------------------------------- /test/test_bf16_q.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | from torch._dynamo import config 5 | 6 | import heavyball 7 | import heavyball.utils 8 | from benchmark.utils import get_optim 9 | from heavyball.utils import PSGDBase, clean, set_torch 10 | 11 | config.cache_size_limit = 128 12 | 13 | 14 | @pytest.mark.parametrize("opt", heavyball.__all__) 15 | @pytest.mark.parametrize("size,depth", [(256, 2)]) 16 | def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3): 17 | set_torch() 18 | 19 | opt = getattr(heavyball, opt) 20 | if not issubclass(opt, PSGDBase): 21 | raise pytest.skip("Only PSGD is supported") 22 | 23 | peaks = [] 24 | losses = [] 25 | 26 | for q_dtype in ["float32", "bfloat16"]: 27 | torch.manual_seed(0x2131290) 28 | peaks.append([]) 29 | losses.append([]) 30 | 31 | for i in range(outer_iterations): 32 | model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda() 33 | o = get_optim(opt, model.parameters(), lr=1e-3, q_dtype=q_dtype) 34 | 35 | for _ in range(iterations): 36 | loss = model(torch.randn((1024, size), device="cuda")).square().mean() 37 | loss.backward() 38 | o.step() 39 | o.zero_grad() 40 | losses[-1].append(loss.detach()) 41 | 42 | del model, o 43 | clean() 44 | 45 | for i, (l0, l1) in enumerate(zip(*losses)): 46 | print(i, l0.item(), l1.item()) 47 | assert torch.allclose(l0, l1, rtol=0.1) 48 | -------------------------------------------------------------------------------- /test/test_bf16_storage.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | from torch._dynamo import config 5 | 6 | import heavyball 7 | import heavyball.utils 8 | from benchmark.utils import get_optim 9 | from heavyball.utils import PSGDBase, clean, set_torch 10 | 11 | config.cache_size_limit = 128 12 | 13 | 14 | @pytest.mark.parametrize("opt", heavyball.__all__) 15 | @pytest.mark.parametrize("size,depth", [(256, 2)]) 16 | def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3): 17 | set_torch() 18 | 19 | if "soap" in opt.lower(): 20 | raise pytest.skip("soap is not supported") 21 | 22 | opt = getattr(heavyball, opt) 23 | 24 | if PSGDBase in opt.__mro__: 25 | raise pytest.skip("PSGD is not supported") 26 | 27 | peaks = [] 28 | losses = [] 29 | 30 | for dtype_name in ["float32", "bfloat16"]: 31 | torch.manual_seed(0x2131290) 32 | peaks.append([]) 33 | losses.append([]) 34 | 35 | dtype = getattr(torch, dtype_name) 36 | 37 | for i in range(outer_iterations): 38 | model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda().to(dtype) 39 | o = get_optim(opt, model.parameters(), lr=1e-3, storage_dtype=dtype_name) 40 | 41 | for _ in range(iterations): 42 | loss = model(torch.randn((1024, size), device="cuda", dtype=dtype)).square().mean() 43 | loss.backward() 44 | o.step() 45 | o.zero_grad() 46 | losses[-1].append(loss.detach()) 47 | 48 | del model, o 49 | clean() 50 | 51 | for i, (l0, l1) in enumerate(zip(*losses)): 52 | print(i, l0.item(), l1.item()) 53 | assert torch.allclose(l0.float(), l1.float(), rtol=0.1) 54 | -------------------------------------------------------------------------------- /test/test_caution.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["TORCH_LOGS"] = "+recompiles" 4 | 5 | import pytest 6 | import torch 7 | from torch import nn 8 | from torch._dynamo import config 9 | 10 | import heavyball 11 | import heavyball.utils 12 | from benchmark.utils import get_optim 13 | from heavyball.utils import clean, set_torch 14 | 15 | config.cache_size_limit = 128 16 | 17 | 18 | @pytest.mark.parametrize("opt", heavyball.__all__) 19 | @pytest.mark.parametrize("size,depth", [(128, 2)]) 20 | def test_caution(opt, size, depth: int, iterations: int = 16, outer_iterations: int = 1): 21 | set_torch() 22 | opt = getattr(heavyball, opt) 23 | peaks = [] 24 | losses = [] 25 | 26 | for caution in [True, False]: 27 | torch.manual_seed(0x2131290) 28 | peaks.append([]) 29 | losses.append([]) 30 | 31 | for i in range(outer_iterations): 32 | model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda() 33 | o = get_optim(opt, model.parameters(), lr=1e-5, caution=caution) 34 | 35 | for _ in range(iterations): 36 | loss = model(torch.randn((1024, size), device="cuda")).square().mean() 37 | loss.backward() 38 | o.step() 39 | o.zero_grad() 40 | losses[-1].append(loss.detach()) 41 | 42 | del model, o 43 | clean() 44 | 45 | for i, (l0, l1) in enumerate(zip(*losses)): 46 | print(i, l0.item(), l1.item()) 47 | assert l0.item() <= l1.item() * 1.1 48 | -------------------------------------------------------------------------------- /test/test_channels_last.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["TORCH_LOGS"] = "+recompiles" 4 | 5 | import pytest 6 | import torch 7 | from torch import nn 8 | from torch._dynamo import config 9 | 10 | import heavyball 11 | import heavyball.utils 12 | from benchmark.utils import get_optim 13 | from heavyball.utils import clean, set_torch 14 | 15 | heavyball.utils.zeroth_power_mode = "newtonschulz" 16 | heavyball.utils.compile_mode = "default" 17 | config.cache_size_limit = 128 18 | 19 | 20 | @pytest.mark.parametrize("opt", heavyball.__all__) 21 | @pytest.mark.parametrize("size,depth", [(128, 1)]) 22 | def test_foreach(opt, size, depth: int, iterations: int = 1024, outer_iterations: int = 1): 23 | set_torch() 24 | opt = getattr(heavyball, opt) 25 | 26 | peaks = [] 27 | losses = [] 28 | 29 | for is_channels_last in [False, True]: 30 | torch.manual_seed(0x2131290) 31 | peaks.append([]) 32 | losses.append([]) 33 | 34 | for i in range(outer_iterations): 35 | model = nn.Sequential(*[nn.Conv2d(size, size, 3) for _ in range(depth)]).cuda() 36 | if is_channels_last: 37 | model.to(memory_format=torch.channels_last) 38 | 39 | o = get_optim(opt, model.parameters(), lr=1e-3, weight_decay=1e-4, warmup_steps=16) 40 | 41 | for _ in range(iterations): 42 | loss = model(torch.randn((1024, size, 4, 4), device="cuda")).square().mean() 43 | loss.backward() 44 | o.step() 45 | o.zero_grad() 46 | losses[-1].append(loss.detach()) 47 | 48 | del model, o 49 | clean() 50 | 51 | for i, (l0, l1) in enumerate(zip(*losses)): 52 | print(i, l0.item(), l1.item()) 53 | assert torch.allclose(l0.float(), l1.float(), rtol=0.1) 54 | -------------------------------------------------------------------------------- /test/test_closure.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | import torch 5 | from torch import nn 6 | 7 | import heavyball 8 | import heavyball.utils 9 | from 
benchmark.utils import get_optim 10 | from heavyball.utils import clean, set_torch 11 | 12 | 13 | class Param(nn.Module): 14 | def __init__(self, size): 15 | super().__init__() 16 | self.weight = nn.Parameter(torch.randn(size)) 17 | 18 | def forward(self, inp): 19 | return self.weight.mean() * inp 20 | 21 | 22 | @pytest.mark.parametrize("opt", heavyball.__all__) 23 | @pytest.mark.parametrize( 24 | "size", 25 | [ 26 | (4, 4, 4, 4), 27 | ], 28 | ) 29 | def test_closure(opt, size: List[int], depth: int = 2, iterations: int = 5, outer_iterations: int = 3): 30 | clean() 31 | set_torch() 32 | 33 | opt = getattr(heavyball, opt) 34 | 35 | for _ in range(outer_iterations): 36 | clean() 37 | model = nn.Sequential(*[Param(size) for _ in range(depth)]).cuda() 38 | o = get_optim(opt, model.parameters(), lr=1e-3) 39 | 40 | def _closure(): 41 | loss = model(torch.randn((1, size[0]), device="cuda")).sum() 42 | loss.backward() 43 | return loss 44 | 45 | for i in range(iterations): 46 | o.step(_closure) 47 | o.zero_grad() 48 | print(o.state_size()) 49 | -------------------------------------------------------------------------------- /test/test_ema.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | from torch._dynamo import config 5 | 6 | import heavyball 7 | import heavyball.utils 8 | from benchmark.utils import get_optim 9 | from heavyball.utils import clean, set_torch 10 | 11 | config.cache_size_limit = 128 12 | 13 | 14 | def get_memory(): 15 | clean() 16 | torch.cuda.synchronize() 17 | clean() 18 | torch.cuda.synchronize() 19 | return torch.cuda.memory_allocated() 20 | 21 | 22 | @pytest.mark.parametrize("opt", heavyball.__all__) 23 | @pytest.mark.parametrize("size,depth", [(256, 2)]) 24 | def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3): 25 | set_torch() 26 | opt = getattr(heavyball, opt) 27 | 28 | peaks = [] 29 | losses = [] 30 | 31 | for do_ema in [True, False]: 32 | torch.manual_seed(0x2131290) 33 | peaks.append([]) 34 | losses.append([]) 35 | 36 | for i in range(outer_iterations): 37 | model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda() 38 | o = get_optim(opt, model.parameters(), lr=1e-3) 39 | 40 | for _ in range(iterations): 41 | loss = model(torch.randn((1024, size), device="cuda")).square().mean() 42 | loss.backward() 43 | o.step() 44 | o.zero_grad() 45 | if do_ema: 46 | o.ema_update() 47 | o.copy_emas_to_params() 48 | o.copy_params_to_emas() 49 | losses[-1].append(loss.detach()) 50 | 51 | if do_ema: 52 | o.copy_emas_to_params() 53 | loss = model(torch.randn((1024, size), device="cuda")).square().mean() 54 | losses[-1].append(loss.detach()) 55 | 56 | del model, o 57 | clean() 58 | 59 | for i, (l0, l1) in enumerate(zip(*losses)): 60 | print(i, l0.item(), l1.item()) 61 | assert l0.float() <= l1.float() 62 | -------------------------------------------------------------------------------- /test/test_foreach.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | import heavyball 6 | import heavyball.utils 7 | from benchmark.utils import get_optim 8 | from heavyball.utils import PSGDBase, clean, set_torch 9 | 10 | 11 | def get_memory(): 12 | clean() 13 | torch.cuda.synchronize() 14 | clean() 15 | torch.cuda.synchronize() 16 | return torch.cuda.memory_allocated() 17 | 18 | 19 | @pytest.mark.parametrize("opt", heavyball.__all__) 20 | 
@pytest.mark.parametrize("size,depth", [(256, 128)]) 21 | def test_foreach(opt, size, depth: int, iterations: int = 4096, outer_iterations: int = 2): 22 | set_torch() 23 | 24 | opt = getattr(heavyball, opt) 25 | 26 | peaks = [] 27 | losses = [] 28 | 29 | for foreach in [True, False]: 30 | torch.manual_seed(0x2131290) 31 | peaks.append([]) 32 | losses.append([]) 33 | 34 | for i in range(outer_iterations): 35 | clean() 36 | model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda() 37 | clean() 38 | 39 | torch.cuda.reset_peak_memory_stats() 40 | torch.cuda.reset_max_memory_allocated() 41 | torch.cuda.reset_max_memory_cached() 42 | torch.cuda.reset_accumulated_memory_stats() 43 | 44 | clean() 45 | o = get_optim(opt, model.parameters(), lr=1e-3, foreach=foreach) 46 | clean() 47 | 48 | for _ in range(iterations): 49 | loss = model(torch.randn((1, size), device="cuda")).sum() 50 | loss.backward() 51 | o.step() 52 | o.zero_grad() 53 | losses[-1].append(loss.detach()) 54 | 55 | del model, o 56 | clean() 57 | 58 | peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"] 59 | 60 | if i > 0: 61 | peaks[-1].append(peak) 62 | 63 | for p0, p1 in zip(*peaks): 64 | assert p0 > p1 65 | for l0, l1 in zip(*losses): # increase error tolerance for PSGD, as we have different RNGs -> expected differences 66 | assert torch.allclose(l0, l1, rtol=0.01 if issubclass(opt, PSGDBase) else 1e-5) 67 | -------------------------------------------------------------------------------- /test/test_hook.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["TORCH_LOGS"] = "+recompiles" 4 | 5 | import pytest 6 | import torch 7 | from torch import nn 8 | from torch._dynamo import config 9 | 10 | import heavyball 11 | import heavyball.utils 12 | from benchmark.utils import get_optim 13 | from heavyball.utils import clean, hook_optimizer_into_model, set_torch 14 | 15 | heavyball.utils.compile_mode = "default" 16 | config.cache_size_limit = 128 17 | 18 | 19 | @pytest.mark.parametrize("opt", heavyball.__all__) 20 | @pytest.mark.parametrize("size,depth", [(128, 1)]) 21 | def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 1): 22 | set_torch() 23 | opt = getattr(heavyball, opt) 24 | 25 | peaks = [] 26 | losses = [] 27 | 28 | for use_hook in [False, True]: 29 | torch.manual_seed(0x2131290) 30 | peaks.append([]) 31 | losses.append([]) 32 | 33 | for i in range(outer_iterations): 34 | model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda() 35 | 36 | if use_hook: 37 | hook_optimizer_into_model(model, opt, lr=1e-3, weight_decay=1e-4, warmup_steps=16) 38 | else: 39 | o = get_optim(opt, model.parameters(), lr=1e-3, weight_decay=1e-4, warmup_steps=16) 40 | for _ in range(iterations): 41 | loss = model(torch.randn((1024, size), device="cuda")).square().mean() 42 | loss.backward() 43 | if not use_hook: 44 | o.step() 45 | o.zero_grad() 46 | losses[-1].append(loss.detach()) 47 | 48 | clean() 49 | 50 | for i, (l0, l1) in enumerate(zip(*losses)): 51 | print(i, l0.item(), l1.item()) 52 | assert torch.allclose(l0.float(), l1.float(), rtol=0.1) 53 | -------------------------------------------------------------------------------- /test/test_mars.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | from torch._dynamo import config 5 | 6 | import heavyball 7 | import heavyball.utils 8 | from benchmark.utils import
get_optim 9 | from heavyball.utils import ScheduleFree, clean, set_torch 10 | 11 | config.cache_size_limit = 128 12 | 13 | 14 | @pytest.mark.parametrize("opt", heavyball.__all__) 15 | @pytest.mark.parametrize("size,depth", [(128, 2)]) 16 | def test_mars(opt, size, depth: int, iterations: int = 16384, outer_iterations: int = 1): 17 | set_torch() 18 | opt = getattr(heavyball, opt) 19 | if ScheduleFree in opt.__mro__: 20 | raise pytest.skip("Skipping ScheduleFree") 21 | 22 | peaks = [] 23 | losses = [] 24 | 25 | for mars in [True, False]: 26 | torch.manual_seed(0x2131290) 27 | peaks.append([]) 28 | losses.append([]) 29 | 30 | for i in range(outer_iterations): 31 | model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda() 32 | o = get_optim(opt, model.parameters(), lr=1e-5, mars=mars) 33 | 34 | for _ in range(iterations): 35 | loss = model(torch.randn((1024, size), device="cuda")).square().mean() 36 | loss.backward() 37 | o.step() 38 | o.zero_grad() 39 | losses[-1].append(loss.detach()) 40 | 41 | del model, o 42 | clean() 43 | 44 | for i, (l0, l1) in enumerate(zip(*losses)): 45 | print(i, l0.item(), l1.item()) 46 | assert l0.item() <= l1.item() * 1.1 47 | -------------------------------------------------------------------------------- /test/test_memory.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | import heavyball 6 | import heavyball.utils 7 | from benchmark.utils import get_optim 8 | from heavyball.utils import clean, set_torch 9 | 10 | 11 | def get_memory(): 12 | clean() 13 | torch.cuda.synchronize() 14 | clean() 15 | torch.cuda.synchronize() 16 | return torch.cuda.memory_allocated() 17 | 18 | 19 | expected_memory = { 20 | "adamw": {"after": 4, "peak": 5.1}, 21 | "soap": {"after": 7, "peak": 14}, 22 | "psgd": {"after": 4, "peak": 11.5}, 23 | "padam": {"after": 5, "peak": 11.4}, 24 | } 25 | 26 | 27 | @pytest.mark.parametrize("opt", ["NewtonHybrid2PSGDKron"]) 28 | @pytest.mark.parametrize("method", ["qr", "newtonschulz2", "svd", "eigh"]) 29 | @pytest.mark.parametrize("size,depth", [(8192, 2), (2048, 16)]) 30 | def test_memory(opt, method, size, depth: int, iterations: int = 5, outer_iterations: int = 3): 31 | if "soap" not in opt.lower() and method != "qr": 32 | raise pytest.skip("Only SOAP supports `method` argument") 33 | set_torch() 34 | 35 | for k, v in expected_memory.items(): 36 | if k in opt.lower(): 37 | break 38 | else: 39 | raise pytest.skip(f"Opt {opt} not supported") 40 | 41 | opt = getattr(heavyball, opt) 42 | heavyball.utils.zeroth_power_mode = method 43 | 44 | for i in range(outer_iterations): 45 | model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda() 46 | print(model) 47 | torch.cuda.reset_peak_memory_stats() 48 | torch.cuda.reset_max_memory_allocated() 49 | torch.cuda.reset_max_memory_cached() 50 | torch.cuda.reset_accumulated_memory_stats() 51 | 52 | model_allocated = get_memory() 53 | o = get_optim(opt, model.parameters(), lr=1e-3) 54 | for _ in range(iterations): 55 | data = torch.randn((1, size), device="cuda").requires_grad_(True) 56 | 57 | def _closure(): 58 | nonlocal model 59 | loss = (model(data) - data).square().mean() 60 | loss.backward() 61 | return loss 62 | 63 | o.step(_closure) 64 | 65 | opt_allocated = get_memory() 66 | 67 | del model, o 68 | peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"] 69 | 70 | print(f"Peak: {peak / model_allocated:.2f}x | Opt: {opt_allocated / model_allocated:.2f}x") 71 | if i > 
0: 72 | assert peak / model_allocated < v["peak"] 73 | assert opt_allocated / model_allocated < v["after"] 74 | -------------------------------------------------------------------------------- /test/test_memory_leak.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import tqdm 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | import heavyball 8 | import heavyball.utils 9 | from benchmark.utils import get_optim 10 | from heavyball.utils import clean, set_torch 11 | 12 | 13 | def get_memory(): 14 | clean() 15 | clean() 16 | torch.cuda.synchronize() 17 | out = torch.cuda.memory_stats()["allocated_bytes.all.peak"] 18 | torch.cuda.reset_peak_memory_stats() 19 | torch.cuda.reset_max_memory_allocated() 20 | torch.cuda.reset_max_memory_cached() 21 | torch.cuda.reset_accumulated_memory_stats() 22 | return out 23 | 24 | 25 | class LayerNorm2dParam(nn.Module): 26 | def __init__(self, num_features): 27 | super(LayerNorm2dParam, self).__init__() 28 | self.param = nn.Parameter(torch.ones(2, num_features)) 29 | 30 | def forward(self, x): 31 | weight, bias = self.param.unbind(0) 32 | return F.layer_norm(x, [x.size(-1)], weight, bias) 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "opt", ["NewtonHybrid2PSGDKron"] 37 | ) # leak with NewtonHybrid2PSGDKron, but not ForeachCachedNewtonPSGD or ForeachCachedPSGDKron 38 | @pytest.mark.parametrize("size,depth", [(64, 2)]) # happens across all sizes 39 | @pytest.mark.parametrize("mars", [False]) # happens with True and False 40 | @pytest.mark.parametrize("cached", [False]) # happens with True and False 41 | @pytest.mark.parametrize("delayed", [False]) # happens with True and False 42 | @pytest.mark.parametrize("merge_dims", [False]) # happens with True and False 43 | @pytest.mark.parametrize("split", [True]) # happens with True and False 44 | @pytest.mark.parametrize("finite_differences", [False]) # only happens with False - does not happen with True 45 | def test_memory( 46 | opt, 47 | size, 48 | depth: int, 49 | mars: bool, 50 | cached: bool, 51 | delayed: bool, 52 | merge_dims: bool, 53 | split: bool, 54 | finite_differences: bool, 55 | iterations: int = 10000, 56 | warmup: int = 100, 57 | check_every: int = 10, 58 | max_growth: float = 1.10, 59 | ): 60 | set_torch() 61 | 62 | opt = getattr(heavyball, opt) 63 | model = nn.Sequential(*[LayerNorm2dParam(size) for _ in range(depth)]).cuda() 64 | print(model) 65 | 66 | o = get_optim( 67 | opt, 68 | model.parameters(), 69 | lr=1e-3, 70 | mars=mars, 71 | merge_dims=merge_dims, 72 | split=split, 73 | cached=cached, 74 | delayed=delayed, 75 | preconditioner_update_probability=1.0, 76 | ) 77 | if finite_differences: 78 | if not o.hessian_approx: 79 | pytest.skip("Finite Differences is an HVP calculation - can't do it on non-hvp optimizers") 80 | o.finite_differences = True 81 | 82 | peak = 0 83 | for i in tqdm.trange(iterations): 84 | data = torch.randn((1, size), device="cuda").requires_grad_(True) 85 | 86 | def _closure(): 87 | nonlocal model 88 | loss = (model(data) - data).square().mean() 89 | loss.backward() 90 | return loss 91 | 92 | o.step(_closure) 93 | 94 | if i % check_every == 0: 95 | if i <= warmup: 96 | peak = max(peak, get_memory()) 97 | if i > warmup: 98 | new = get_memory() 99 | print(i, peak, new) 100 | assert peak * max_growth >= new # fudge factor 101 | -------------------------------------------------------------------------------- /test/test_merge.py: 
-------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | import torch 5 | from torch import nn 6 | 7 | import heavyball 8 | import heavyball.utils 9 | from benchmark.utils import get_optim 10 | from heavyball.utils import clean, set_torch 11 | 12 | 13 | class Param(nn.Module): 14 | def __init__(self, size): 15 | super().__init__() 16 | self.weight = nn.Parameter(torch.randn(size)) 17 | 18 | def forward(self, inp): 19 | return self.weight.mean() * inp 20 | 21 | 22 | @pytest.mark.parametrize("opt", ["ForeachPSGDKron", "ForeachPaLMPAdam"]) 23 | @pytest.mark.parametrize("method", ["qr", "newtonschulz2", "svd", "eigh"]) 24 | @pytest.mark.parametrize("size", [(16, 16, 16, 16), (4, 4, 4, 4), (512, 1, 128), (32128, 768)]) 25 | @pytest.mark.parametrize("merge,split", [(False, False), (True, False), (True, True)]) 26 | def test_merge( 27 | opt, 28 | method, 29 | size: List[int], 30 | merge, 31 | split, 32 | depth: int = 2, 33 | iterations: int = 5, 34 | outer_iterations: int = 3, 35 | ): 36 | if "soap" not in opt.lower() and method != "qr": 37 | raise pytest.skip("Only SOAP supports `method` argument") 38 | clean() 39 | set_torch() 40 | 41 | opt = getattr(heavyball, opt) 42 | heavyball.utils.zeroth_power_mode = method 43 | 44 | for _ in range(outer_iterations): 45 | clean() 46 | model = nn.Sequential(*[Param(size) for _ in range(depth)]).cuda() 47 | # We don't know if merging will use more or less memory, but we do know that it shouldn't crash. This test is to check if it crashes 48 | o = get_optim( 49 | opt, 50 | model.parameters(), 51 | lr=1e-3, 52 | merge_dims=merge, 53 | split=split, 54 | max_precond_dim=256, 55 | max_size_triangular=256, 56 | ) 57 | 58 | for i in range(iterations): 59 | model(torch.randn((1, size[0]), device="cuda")).sum().backward() 60 | o.step() 61 | o.zero_grad() 62 | print(o.state_size()) 63 | 64 | del model, o 65 | -------------------------------------------------------------------------------- /test/test_no_grad.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | import torch 5 | from torch import nn 6 | 7 | import heavyball 8 | import heavyball.utils 9 | from benchmark.utils import get_optim 10 | from heavyball.utils import clean, set_torch 11 | 12 | 13 | class Param(nn.Module): 14 | def __init__(self, size): 15 | super().__init__() 16 | self.weight = nn.Parameter(torch.randn(size)) 17 | 18 | def forward(self, inp): 19 | return self.weight.mean() * inp 20 | 21 | 22 | @pytest.mark.parametrize("opt", heavyball.__all__) 23 | @pytest.mark.parametrize( 24 | "size", 25 | [ 26 | (4, 4, 4, 4), 27 | ], 28 | ) 29 | def test_closure(opt, size: List[int], depth: int = 2, iterations: int = 5, outer_iterations: int = 3): 30 | clean() 31 | set_torch() 32 | 33 | opt = getattr(heavyball, opt) 34 | 35 | for _ in range(outer_iterations): 36 | clean() 37 | model = nn.Sequential(*[Param(size) for _ in range(depth)]).cuda() 38 | o = get_optim(opt, model.parameters(), lr=1e-3) 39 | 40 | for i in range(iterations): 41 | o.step() 42 | o.zero_grad() 43 | assert o.state_size() == 0 44 | 45 | del model, o 46 | -------------------------------------------------------------------------------- /test/test_psgd_precond_init_stability.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | from torch._dynamo import config 7 | 8 | from 
heavyball.utils import _lse_mean, divided_root, mean_root, stable_exp 9 | 10 | config.cache_size_limit = sys.maxsize 11 | config.accumulated_cache_size_limit = sys.maxsize 12 | 13 | 14 | def np_mean_pow_root(z_np: np.ndarray, pow_val: float, eps_val: float): 15 | z_np = np.abs(z_np).clip(min=eps_val) ** pow_val 16 | mean_pow = np.mean(z_np) ** (1 / pow_val / 2) 17 | return mean_pow 18 | 19 | 20 | def _get_tensor(numel, dtype, maxval: float = 10): 21 | x = torch.randn(numel, dtype=dtype) # * torch.arange(numel, dtype=dtype) 22 | x /= x.abs().max() 23 | x *= maxval 24 | x = x.clone() 25 | return x, x.numpy().astype(np.float128) 26 | 27 | 28 | tolerance = { 29 | np.float64: {"rtol": 1e-10, "atol": 1e-12}, 30 | np.float32: {"rtol": 1e-5, "atol": 1e-6}, 31 | np.float16: {"rtol": 1e-2, "atol": 1e-3}, 32 | } 33 | 34 | 35 | def _isclose(x, y): 36 | if isinstance(x, torch.Tensor): 37 | x = x.cpu().numpy() 38 | if isinstance(y, torch.Tensor): 39 | y = y.cpu().numpy() 40 | if not np.isfinite(x).all(): 41 | assert ~(np.isfinite(x) ^ np.isfinite(y)).any() # all are nan together 42 | if np.isfinite(x).any(): 43 | for k, v in tolerance.items(): # numpy doesn't support indexing 44 | if x.dtype == k: 45 | tol = v 46 | break 47 | else: 48 | raise ValueError(f"dtype {x.dtype} not supported") 49 | assert np.allclose(x[np.isfinite(x)], y[np.isfinite(y)], **tol) 50 | 51 | 52 | @pytest.mark.parametrize("x_val", list(range(-10, 10))) 53 | @pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64]) 54 | def test_stable_exp_scalar(x_val, dtype): 55 | x = torch.tensor(x_val, dtype=dtype) 56 | result = stable_exp(x) 57 | expected = np.exp(x_val) if x_val <= 0 else 1 / np.exp(-x_val) 58 | _isclose(result.to(x.dtype), expected) 59 | 60 | 61 | @pytest.mark.parametrize("numel", [2**i for i in range(10)]) 62 | @pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64]) 63 | def test_stable_exp_tensor(numel, dtype): 64 | x, x_np = _get_tensor(numel, dtype) 65 | result = stable_exp(x) 66 | expected = np.exp(x_np) 67 | _isclose(result.to(x.dtype), expected) 68 | 69 | 70 | @pytest.mark.parametrize("numel", [2**i for i in range(10)]) 71 | @pytest.mark.parametrize("pow_val", list(range(1, 16))) 72 | @pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64]) 73 | def test_lse_mean(numel, pow_val, dtype): 74 | x, x_np = _get_tensor(numel, dtype) 75 | result = _lse_mean(x, pow_val, 1e-20) 76 | expected = np.log(np.mean(np.abs(x_np) ** pow_val)) / pow_val / 2 77 | _isclose(result.to(x.dtype), expected) 78 | 79 | 80 | @pytest.mark.parametrize("numel", [2**i for i in range(10)]) 81 | @pytest.mark.parametrize("pow_val", list(range(1, 16))) 82 | @pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64]) 83 | def test_mean_root(numel, pow_val, dtype): 84 | x, x_np = _get_tensor(numel, dtype) 85 | result = mean_root(x, pow_val) 86 | expected = 1 / (np.mean(np.abs(x_np) ** pow_val) ** (1 / pow_val / 2)) 87 | _isclose(result.to(x.dtype), expected) 88 | 89 | 90 | @pytest.mark.parametrize("numel", [2**i for i in range(10)]) 91 | @pytest.mark.parametrize("pow0_val", list(range(1, 16))) 92 | @pytest.mark.parametrize("pow1_val", list(range(1, 16))) 93 | @pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64]) 94 | def test_divided_root(numel, pow0_val, pow1_val, dtype): 95 | x, x_np = _get_tensor(numel, dtype) 96 | y, y_np = _get_tensor(numel, dtype) 97 | result = divided_root(x, y, pow0_val, pow1_val) 98 | expected = 
np_mean_pow_root(x_np, pow0_val, 1e-12) / np_mean_pow_root(y_np, pow1_val, 1e-12) 99 | _isclose(result.to(x.dtype), expected) 100 | -------------------------------------------------------------------------------- /test/test_save_restore.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | os.environ["TORCH_LOGS"] = "+recompiles" 5 | 6 | import pytest 7 | import torch 8 | from torch import nn 9 | from torch._dynamo import config 10 | from torch.utils._pytree import tree_map 11 | 12 | import heavyball 13 | import heavyball.utils 14 | from benchmark.utils import get_optim 15 | from heavyball.utils import set_torch 16 | 17 | config.cache_size_limit = 128 18 | 19 | 20 | def _train_one(dataset, model, opt): 21 | torch.manual_seed(0x2131290) 22 | for d in dataset: 23 | opt.zero_grad() 24 | 25 | def _closure(): 26 | loss = (model(d) - d.square()).square().mean() 27 | loss.backward() 28 | return loss 29 | 30 | opt.step(_closure) 31 | return model 32 | 33 | 34 | def _allclose(x, y): 35 | if isinstance(x, torch.Tensor): 36 | assert torch.allclose(x, y) 37 | elif isinstance(x, (list, tuple)): 38 | assert all(_allclose(x, y) for x, y in zip(x, y)) 39 | elif not isinstance(x, bytes): # bytes -> it's a pickle 40 | assert x == y 41 | 42 | 43 | @pytest.mark.parametrize("opt", heavyball.__all__) 44 | @pytest.mark.parametrize("size,depth", [(32, 2)]) 45 | @pytest.mark.parametrize("split", [False, True]) 46 | @pytest.mark.parametrize("merge_dims", [False, True]) 47 | def test_save_restore( 48 | opt, size, depth: int, split: bool, merge_dims: bool, iterations: int = 32, outer_iterations: int = 8 49 | ): 50 | set_torch() 51 | opt = getattr(heavyball, opt) 52 | 53 | torch.manual_seed(0x2131290) 54 | data = torch.randn((iterations, size), device="cuda", dtype=torch.double) 55 | 56 | model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda().double() 57 | o: torch.optim.Optimizer = get_optim( 58 | opt, model.parameters(), lr=1e-3, merge_dims=merge_dims, split=split, storage_dtype="float64", q_dtype="float64" 59 | ) 60 | 61 | for x in range(outer_iterations): 62 | new_m = copy.deepcopy(model) 63 | new_o = get_optim(opt, new_m.parameters(), lr=1e-3) 64 | state_dict = copy.deepcopy(o.state_dict()) 65 | m = _train_one(data, model, o) 66 | 67 | new_o.load_state_dict(state_dict) 68 | new_m = _train_one(data, new_m, new_o) 69 | 70 | tree_map(_allclose, new_o.state_dict(), o.state_dict()) 71 | 72 | for normal_param, state_param in zip(m.parameters(), new_m.parameters()): 73 | assert torch.allclose(normal_param, state_param) 74 | -------------------------------------------------------------------------------- /test/test_stochastic_updates.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | import heavyball 6 | import heavyball.utils 7 | from benchmark.utils import get_optim 8 | from heavyball.utils import PSGDBase, clean, set_torch 9 | 10 | 11 | def get_memory(): 12 | clean() 13 | torch.cuda.synchronize() 14 | clean() 15 | torch.cuda.synchronize() 16 | return torch.cuda.memory_allocated() 17 | 18 | 19 | @pytest.mark.parametrize("opt", heavyball.__all__) 20 | @pytest.mark.parametrize("size,depth", [(128, 1)]) 21 | def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations: int = 3): 22 | set_torch() 23 | 24 | opt = getattr(heavyball, opt) 25 | if not issubclass(opt, PSGDBase): 26 | raise pytest.skip("Only PSGD is 
supported") 27 | 28 | peaks = [] 29 | losses = [] 30 | 31 | for stochastic in [False, True]: 32 | print("stochastic", stochastic) 33 | torch.manual_seed(0x2131290) 34 | peaks.append([]) 35 | losses.append([]) 36 | 37 | for i in range(outer_iterations): 38 | model = nn.Sequential(*[nn.Linear(size, size, bias=False) for _ in range(depth)]).cuda() 39 | o = get_optim(opt, model.parameters(), lr=1e-3, stochastic_schedule=stochastic) 40 | 41 | for _ in range(iterations): 42 | loss = model(torch.randn((128, size), device="cuda")).square().mean() 43 | loss.backward() 44 | o.step() 45 | o.zero_grad() 46 | losses[-1].append(loss.detach()) 47 | 48 | del model, o 49 | clean() 50 | 51 | stochastic = sum([l.item() for l in losses[1]]) 52 | deterministic = sum([l.item() for l in losses[0]]) 53 | print(f"{deterministic=}, {stochastic=}") 54 | assert deterministic < stochastic 55 | --------------------------------------------------------------------------------