├── .formatter.exs ├── .github └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE ├── README.md ├── lib ├── polaris.ex └── polaris │ ├── optimizers.ex │ ├── schedules.ex │ ├── shared.ex │ └── updates.ex ├── mix.exs ├── mix.lock └── test ├── polaris ├── optimizers_test.exs ├── schedules_test.exs └── updates_test.exs ├── support └── polaris_case.ex └── test_helper.exs /.formatter.exs: -------------------------------------------------------------------------------- 1 | # Used by "mix format" 2 | 3 | [ 4 | import_deps: [:nx], 5 | inputs: ["{mix,.formatter}.exs", "{bench,examples,config,lib,test}/**/*.{ex,exs}"] 6 | ] 7 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | jobs: 8 | main: 9 | name: ubuntu-latest (${{ matrix.elixir }}, ${{ matrix.otp }}) 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | elixir: ["1.14.2"] 14 | otp: ["25.2"] 15 | env: 16 | MIX_ENV: test 17 | steps: 18 | - uses: actions/checkout@v3 19 | - uses: erlef/setup-beam@v1 20 | with: 21 | otp-version: ${{ matrix.otp }} 22 | elixir-version: ${{ matrix.elixir }} 23 | - name: Retrieve dependencies cache 24 | env: 25 | cache-name: cache-mix-deps 26 | uses: actions/cache@v3 27 | id: mix-cache # id to use in retrieve action 28 | with: 29 | path: | 30 | deps 31 | _build 32 | key: ${{ runner.os }}-Elixir-v${{ matrix.elixir }}-OTP-${{ matrix.otp }}-${{ hashFiles('**/mix.lock') }} 33 | - name: Install dependencies 34 | if: ${{ steps.mix-cache.outputs.cache-hit != 'true' }} 35 | run: mix deps.get 36 | - name: Compile and check warnings 37 | run: mix compile --skip-optional-deps --warnings-as-errors 38 | - name: Check formatting 39 | run: mix format --check-formatted 40 | - name: Run tests 41 | run: mix test 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # The directory Mix will write compiled artifacts to. 2 | /_build/ 3 | 4 | # If you run "mix test --cover", coverage assets end up here. 5 | /cover/ 6 | 7 | # The directory Mix downloads your dependencies sources to. 8 | /deps/ 9 | 10 | # Where third-party dependencies like ExDoc output generated docs. 11 | /doc/ 12 | 13 | # Ignore .fetch files in case you like to edit your project deps locally. 14 | /.fetch 15 | 16 | # If the VM crashes, it generates a dump, let's ignore it too. 17 | erl_crash.dump 18 | 19 | # Also ignore archive artifacts (built via "mix archive.build"). 20 | *.ez 21 | 22 | # Ignore package tarball (built via "mix hex.build"). 23 | polaris-*.tar 24 | 25 | # Temporary files, for example, from tests. 26 | /tmp/ 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Polaris 2 | 3 | Numerical definitions for creating and using first-order optimization techniques based on the [Optax](https://github.com/deepmind/optax) library. 
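Each optimizer is expressed as an `{init_fn, update_fn}` pair that you drive from your own training loop. A minimal sketch (here `params` and `grads` are placeholder maps of `Nx` tensors produced elsewhere):

```elixir
# Sketch only: `params` and `grads` are placeholders for your model state and gradients.
{init_fn, update_fn} = Polaris.Optimizers.adam(learning_rate: 1.0e-3)

optimizer_state = init_fn.(params)
{scaled_updates, optimizer_state} = update_fn.(grads, optimizer_state, params)
params = Polaris.Updates.apply_updates(params, scaled_updates)
```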
4 | 5 | Here is a non-exhaustive list of supported optimizers: 6 | 7 | * Adabelief 8 | * Adagrad 9 | * Adam 10 | * Adamw 11 | * Fromage 12 | * Lamb 13 | * Noisy SGD 14 | * Radam 15 | * RMSProp 16 | * SGD 17 | * Yogi 18 | 19 | ## Installation 20 | 21 | If [available in Hex](https://hex.pm/docs/publish), the package can be installed 22 | by adding `polaris` to your list of dependencies in `mix.exs`: 23 | 24 | ```elixir 25 | def deps do 26 | [ 27 | {:polaris, "~> 0.1"} 28 | ] 29 | end 30 | ``` 31 | 32 | Documentation can be generated with [ExDoc](https://github.com/elixir-lang/ex_doc) 33 | and published on [HexDocs](https://hexdocs.pm). Once published, the docs can 34 | be found at . 35 | 36 | ## License 37 | 38 | Copyright (c) 2023 Sean Moriarity 39 | 40 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 41 | 42 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 43 | -------------------------------------------------------------------------------- /lib/polaris.ex: -------------------------------------------------------------------------------- 1 | defmodule Polaris do 2 | @moduledoc """ 3 | Collection of common optimizers and optimization utilities. 4 | """ 5 | end 6 | -------------------------------------------------------------------------------- /lib/polaris/optimizers.ex: -------------------------------------------------------------------------------- 1 | defmodule Polaris.Optimizers do 2 | @moduledoc """ 3 | Implementations of common gradient-based optimization algorithms. 4 | 5 | All of the methods in this module are written in terms of 6 | the update methods defined in `Polaris.Updates`. Polaris treats 7 | optimizers as the tuple: 8 | 9 | {init_fn, update_fn} 10 | 11 | where `init_fn` returns an initial optimizer state and `update_fn` 12 | scales input gradients. `init_fn` accepts a model's parameters 13 | and attaches state to each parameter. `update_fn` accepts 14 | gradients, optimizer state, and current model parameters and 15 | returns updated optimizer state and gradients. 16 | 17 | Custom optimizers are often created via the `Polaris.Updates` API. 
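  As a rough sketch, an Adam-like optimizer can be assembled by hand from those
  combinators, where the `-1.0e-3` factor plays the role of the negated learning rate:

      # Sketch of a hand-rolled Adam-like optimizer built from update combinators
      {init_fn, update_fn} =
        Polaris.Updates.scale_by_adam()
        |> Polaris.Updates.scale(-1.0e-3)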
18 | 19 | ## Example 20 | 21 | Consider the following usage of the Adam optimizer in a basic 22 | update function (assuming `objective` and the `dataset` are 23 | defined elsewhere): 24 | 25 | defmodule Learning do 26 | 27 | import Nx.Defn 28 | 29 | defn init(params, init_fn) do 30 | init_fn.(params) 31 | end 32 | 33 | defn update(params, optimizer_state, inputs, targets, update_fn) do 34 | {loss, gradient} = value_and_grad(params, &objective(&1, inputs, targets)) 35 | {scaled_updates, new_optimizer_state} = update_fn.(gradient, optimizer_state, params) 36 | {Polaris.Updates.apply_updates(params, scaled_updates), new_optimizer_state, loss} 37 | end 38 | end 39 | 40 | {model_params, _key} = Nx.Random.uniform(key, shape: {784, 10}) 41 | {init_fn, update_fn} = Polaris.Optimizers.adam(0.005) 42 | 43 | optimizer_state = 44 | Learning.init(params, init_fn) 45 | 46 | {new_params, new_optimizer_state, loss} = 47 | Learning.update(params, optimizer_state, inputs, targets, update_fn) 48 | 49 | """ 50 | alias Polaris.Updates 51 | 52 | @doc """ 53 | Adabelief optimizer. 54 | 55 | ## Options 56 | 57 | * `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-3` 58 | * `:b1` - first moment decay. Defaults to `0.9` 59 | * `:b2` - second moment decay. Defaults to `0.999` 60 | * `:eps` - numerical stability term. Defaults to `0.0` 61 | * `:eps_root` - numerical stability term. Defaults to `1.0e-16` 62 | 63 | ## References 64 | 65 | * [AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients](https://arxiv.org/abs/2010.07468) 66 | """ 67 | def adabelief(opts \\ []) do 68 | {learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-3) 69 | 70 | Updates.scale_by_belief(opts) 71 | |> scale_by_learning_rate(learning_rate) 72 | end 73 | 74 | @doc """ 75 | Adagrad optimizer. 76 | 77 | ## Options 78 | 79 | * `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-3` 80 | * `:eps` - numerical stability term. Defaults to `1.0e-7` 81 | 82 | ## References 83 | 84 | * [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) 85 | """ 86 | def adagrad(opts \\ []) do 87 | {learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-3) 88 | 89 | Updates.scale_by_rss(opts) 90 | |> scale_by_learning_rate(learning_rate) 91 | end 92 | 93 | @doc """ 94 | Adam optimizer. 95 | 96 | ## Options 97 | 98 | * `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-3` 99 | * `:b1` - first moment decay. Defaults to `0.9` 100 | * `:b2` - second moment decay. Defaults to `0.999` 101 | * `:eps` - numerical stability term. Defaults to `1.0e-8` 102 | * `:eps_root` - numerical stability term. Defaults to `1.0e-15` 103 | 104 | ## References 105 | 106 | * [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980) 107 | """ 108 | def adam(opts \\ []) do 109 | {learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-3) 110 | 111 | Updates.scale_by_adam(opts) 112 | |> scale_by_learning_rate(learning_rate) 113 | end 114 | 115 | @doc """ 116 | Adam with weight decay optimizer. 117 | 118 | ## Options 119 | 120 | * `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-3` 121 | * `:b1` - first moment decay. Defaults to `0.9` 122 | * `:b2` - second moment decay. Defaults to `0.999` 123 | * `:eps` - numerical stability term. Defaults to `1.0e-8` 124 | * `:eps_root` - numerical stability term. Defaults to `0.0` 125 | * `:decay` - weight decay. 
Defaults to `0.0` 126 | """ 127 | def adamw(opts \\ []) do 128 | {learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-3) 129 | {decay, opts} = Keyword.pop(opts, :decay, 0.0) 130 | 131 | Updates.scale_by_adam(opts) 132 | |> Updates.add_decayed_weights(decay: decay) 133 | |> scale_by_learning_rate(learning_rate) 134 | end 135 | 136 | @doc """ 137 | Lamb optimizer. 138 | 139 | ## Options 140 | 141 | * `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-2` 142 | * `:b1` - first moment decay. Defaults to `0.9` 143 | * `:b2` - second moment decay. Defaults to `0.999` 144 | * `:eps` - numerical stability term. Defaults to `1.0e-8` 145 | * `:eps_root` - numerical stability term. Defaults to `0.0` 146 | * `:decay` - weight decay. Defaults to `0.0` 147 | * `:min_norm` - minimum norm value. Defaults to `0.0` 148 | 149 | ## References 150 | 151 | * [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962) 152 | """ 153 | def lamb(opts \\ []) do 154 | {learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-2) 155 | {decay, opts} = Keyword.pop(opts, :decay, 0.0) 156 | {min_norm, opts} = Keyword.pop(opts, :min_norm, 0.0) 157 | 158 | Updates.scale_by_adam(opts) 159 | |> Updates.add_decayed_weights(decay: decay) 160 | |> Updates.scale_by_trust_ratio(min_norm: min_norm) 161 | |> scale_by_learning_rate(learning_rate) 162 | end 163 | 164 | @doc """ 165 | Noisy SGD optimizer. 166 | 167 | ## Options 168 | 169 | * `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-2` 170 | * `:eta` - used to compute variance of noise distribution. Defaults to `0.1` 171 | * `:gamma` - used to compute variance of noise distribution. Defaults to `0.55` 172 | """ 173 | def noisy_sgd(opts \\ []) do 174 | {learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-2) 175 | 176 | scale_by_learning_rate(learning_rate) 177 | |> Updates.add_noise(opts) 178 | end 179 | 180 | @doc """ 181 | Rectified Adam optimizer. 182 | 183 | ## Options 184 | 185 | * `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-3` 186 | * `:b1` - first moment decay. Defaults to `0.9` 187 | * `:b2` - second moment decay. Defaults to `0.999` 188 | * `:eps` - numerical stability term. Defaults to `1.0e-8` 189 | * `:eps_root` - numerical stability term. Defaults to `0.0` 190 | * `:threshold` - threshold term. Defaults to `5.0` 191 | 192 | ## References 193 | 194 | * [On the Variance of Adaptive Learning Rate and Beyond](https://arxiv.org/pdf/1908.03265.pdf) 195 | """ 196 | def radam(opts \\ []) do 197 | {learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-3) 198 | 199 | Updates.scale_by_radam(opts) 200 | |> scale_by_learning_rate(learning_rate) 201 | end 202 | 203 | @doc """ 204 | RMSProp optimizer. 205 | 206 | ## Options 207 | 208 | * `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-2` 209 | * `:centered` - whether to scale by centered root of EMA of squares. Defaults to `false` 210 | * `:momentum` - momentum term. If set, uses SGD with momentum and decay set 211 | to value of this term. 212 | * `:nesterov` - whether or not to use nesterov momentum. Defaults to `false` 213 | * `:initial_scale` - initial value of EMA. Defaults to `0.0` 214 | * `:decay` - EMA decay rate. Defaults to `0.9` 215 | * `:eps` - numerical stability term. 
Defaults to `1.0e-8` 216 | """ 217 | def rmsprop(opts \\ []) do 218 | {learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-2) 219 | {centered, opts} = Keyword.pop(opts, :centered, false) 220 | {nesterov?, opts} = Keyword.pop(opts, :nesterov, false) 221 | {momentum, opts} = Keyword.pop(opts, :momentum, nil) 222 | 223 | combinator = 224 | if centered do 225 | Updates.scale_by_stddev(opts) 226 | else 227 | Updates.scale_by_rms(opts) 228 | end 229 | |> scale_by_learning_rate(learning_rate) 230 | 231 | if momentum, 232 | do: Updates.trace(combinator, decay: momentum, nesterov: nesterov?), 233 | else: combinator 234 | end 235 | 236 | @doc """ 237 | SGD optimizer. 238 | 239 | ## Options 240 | 241 | * `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-2` 242 | * `:momentum` - momentum term. If set, uses SGD with momentum and decay set 243 | to value of this term. 244 | * `:nesterov` - whether or not to use nesterov momentum. Defaults to `false` 245 | """ 246 | def sgd(opts \\ []) do 247 | {learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-2) 248 | momentum = opts[:momentum] 249 | nesterov? = opts[:nesterov] || false 250 | 251 | if momentum do 252 | Updates.trace(decay: momentum, nesterov: nesterov?) 253 | |> scale_by_learning_rate(learning_rate) 254 | else 255 | scale_by_learning_rate(learning_rate) 256 | end 257 | end 258 | 259 | @doc """ 260 | Yogi optimizer. 261 | 262 | ## Options 263 | 264 | * `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-2` 265 | * `:initial_accumulator_value` - initial value for first and second moment. Defaults to `0.0` 266 | * `:b1` - first moment decay. Defaults to `0.9` 267 | * `:b2` - second moment decay. Defaults to `0.999` 268 | * `:eps` - numerical stability term. Defaults to `1.0e-8` 269 | * `:eps_root` - numerical stability term. Defaults to `0.0` 270 | 271 | ## References 272 | 273 | * [Adaptive Methods for Nonconvex Optimization](https://papers.nips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf) 274 | """ 275 | def yogi(opts \\ []) do 276 | {learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-2) 277 | 278 | Updates.scale_by_yogi(opts) 279 | |> scale_by_learning_rate(learning_rate) 280 | end 281 | 282 | ## Helpers 283 | 284 | defp scale_by_learning_rate(combinator \\ Updates.identity(), lr) 285 | 286 | defp scale_by_learning_rate(combinator, schedule) when is_function(schedule, 1) do 287 | Updates.scale_by_schedule(combinator, fn count -> Nx.negate(schedule.(count)) end) 288 | end 289 | 290 | defp scale_by_learning_rate(combinator, lr) do 291 | Updates.scale_by_state(combinator, -lr) 292 | end 293 | end 294 | -------------------------------------------------------------------------------- /lib/polaris/schedules.ex: -------------------------------------------------------------------------------- 1 | defmodule Polaris.Schedules do 2 | @moduledoc """ 3 | Parameter Schedules. 4 | 5 | Parameter schedules are often used to anneal hyperparameters 6 | such as the learning rate during the training process. Schedules 7 | provide a mapping from the current time step to a learning rate 8 | or another hyperparameter. 9 | 10 | Choosing a good learning rate and consequently a good learning 11 | rate schedule is typically a process of trial and error. Learning 12 | rates should be relatively small such that the learning curve 13 | does not oscillate violently during the training process, but 14 | not so small that learning proceeds too slowly. 
Using a 15 | schedule slowly decreases oscillations during the training 16 | process such that, as the model converges, training also 17 | becomes more stable. 18 | 19 | All of the functions in this module are implemented as 20 | numerical functions and can be JIT or AOT compiled with 21 | any supported `Nx` compiler. 22 | """ 23 | 24 | import Nx.Defn 25 | 26 | @doc """ 27 | Linear decay schedule. 28 | 29 | ## Options 30 | 31 | * `:warmup` - scheduler warmup steps. Defaults to `0` 32 | 33 | * `:steps` - total number of decay steps. Defaults to `1000` 34 | """ 35 | def linear_decay(init_value, opts \\ []) do 36 | &apply_linear_decay(&1, [{:init_value, init_value} | opts]) 37 | end 38 | 39 | defnp apply_linear_decay(step, opts \\ []) do 40 | opts = 41 | keyword!(opts, 42 | init_value: 1.0e-2, 43 | warmup: 0, 44 | steps: 1000 45 | ) 46 | 47 | scale = 48 | if step < opts[:warmup] do 49 | step / Nx.max(1, opts[:warmup]) 50 | else 51 | Nx.max(0.0, (opts[:steps] - step) / Nx.max(1, opts[:steps] - opts[:warmup])) 52 | end 53 | 54 | scale * opts[:init_value] 55 | end 56 | 57 | @doc ~S""" 58 | Exponential decay schedule. 59 | 60 | $$\gamma(t) = \gamma_0 * r^{\frac{t}{k}}$$ 61 | 62 | ## Options 63 | 64 | * `:decay_rate` - rate of decay. $r$ in above formulation. 65 | Defaults to `0.95` 66 | 67 | * `:transition_steps` - steps per transition. $k$ in above 68 | formulation. Defaults to `10` 69 | 70 | * `:transition_begin` - step to begin transition. Defaults to `0` 71 | 72 | * `:staircase` - discretize outputs. Defaults to `false` 73 | 74 | """ 75 | def exponential_decay(init_value, opts \\ []) do 76 | &apply_exponential_decay(&1, [{:init_value, init_value} | opts]) 77 | end 78 | 79 | defnp apply_exponential_decay(step, opts \\ []) do 80 | opts = 81 | keyword!(opts, 82 | init_value: 1.0e-2, 83 | decay_rate: 0.95, 84 | transition_steps: 10, 85 | transition_begin: 0, 86 | staircase: false 87 | ) 88 | 89 | init_value = opts[:init_value] 90 | rate = opts[:decay_rate] 91 | staircase? = opts[:staircase] 92 | k = opts[:transition_steps] 93 | start = opts[:transition_begin] 94 | 95 | t = Nx.subtract(step, start) 96 | 97 | p = 98 | if staircase? do 99 | t 100 | |> Nx.divide(k) 101 | |> Nx.floor() 102 | else 103 | t 104 | |> Nx.divide(k) 105 | end 106 | 107 | decayed_value = 108 | rate 109 | |> Nx.pow(p) 110 | |> Nx.multiply(init_value) 111 | 112 | Nx.select( 113 | Nx.less_equal(t, 0), 114 | init_value, 115 | decayed_value 116 | ) 117 | end 118 | 119 | @doc ~S""" 120 | Cosine decay schedule. 121 | 122 | $$\gamma(t) = \gamma_0 * \left(\frac{1}{2}(1 - \alpha)(1 + \cos\pi \frac{t}{k}) + \alpha\right)$$ 123 | 124 | ## Options 125 | 126 | * `:decay_steps` - number of steps to apply decay for. 127 | $k$ in above formulation. Defaults to `10` 128 | 129 | * `:alpha` - minimum value of multiplier adjusting learning rate. 130 | $\alpha$ in above formulation. 
Defaults to `0.0` 131 | 132 | ## References 133 | 134 | * [SGDR: Stochastic Gradient Descent with Warm Restarts](https://openreview.net/forum?id=Skq89Scxx&noteId=Skq89Scxx) 135 | 136 | """ 137 | def cosine_decay(init_value, opts \\ []) do 138 | &apply_cosine_decay(&1, [{:init_value, init_value} | opts]) 139 | end 140 | 141 | defnp apply_cosine_decay(step, opts \\ []) do 142 | opts = keyword!(opts, init_value: 1.0e-2, decay_steps: 10, alpha: 0.0) 143 | init_value = opts[:init_value] 144 | decay_steps = opts[:decay_steps] 145 | alpha = opts[:alpha] 146 | 147 | theta = Nx.min(step, decay_steps) / decay_steps * Nx.Constants.pi() 148 | 149 | cos = (Nx.cos(theta) + 1) / 2 * (1 - alpha) 150 | 151 | init_value * (cos + alpha) 152 | end 153 | 154 | @doc ~S""" 155 | Constant schedule. 156 | 157 | $$\gamma(t) = \gamma_0$$ 158 | 159 | """ 160 | def constant(init_value, opts \\ []) do 161 | &apply_constant(&1, [{:init_value, init_value} | opts]) 162 | end 163 | 164 | defnp apply_constant(_step, opts \\ []) do 165 | opts = keyword!(opts, init_value: 0.01) 166 | opts[:init_value] 167 | end 168 | 169 | @doc ~S""" 170 | Polynomial schedule. 171 | 172 | $$\gamma(t) = (\gamma_0 - \gamma_n) * (1 - \frac{t}{k})^p + \gamma_n$$ 173 | 174 | ## Options 175 | 176 | * `:end_value` - end value of annealed scalar. $\gamma_n$ in above formulation. 177 | Defaults to `1.0e-3` 178 | 179 | * `:power` - power of polynomial. $p$ in above formulation. Defaults to `2` 180 | 181 | * `:transition_steps` - number of steps over which annealing takes place. 182 | $k$ in above formulation. Defaults to `10` 183 | 184 | """ 185 | def polynomial_decay(init_value, opts \\ []) do 186 | &apply_polynomial_decay(&1, [{:init_value, init_value} | opts]) 187 | end 188 | 189 | defnp apply_polynomial_decay(step, opts \\ []) do 190 | opts = 191 | keyword!(opts, 192 | init_value: 1.0e-2, 193 | end_value: 1.0e-3, 194 | power: 2, 195 | transition_steps: 10, 196 | transition_begin: 0 197 | ) 198 | 199 | init_value = opts[:init_value] 200 | end_value = opts[:end_value] 201 | start = opts[:transition_begin] 202 | k = opts[:transition_steps] 203 | p = opts[:power] 204 | 205 | step 206 | |> Nx.subtract(start) 207 | |> Nx.clip(0, k) 208 | |> Nx.divide(k) 209 | |> Nx.negate() 210 | |> Nx.add(1) 211 | |> Nx.pow(p) 212 | |> Nx.multiply(Nx.subtract(init_value, end_value)) 213 | |> Nx.add(end_value) 214 | end 215 | end 216 | -------------------------------------------------------------------------------- /lib/polaris/shared.ex: -------------------------------------------------------------------------------- 1 | defmodule Polaris.Shared do 2 | @moduledoc false 3 | 4 | # Collection of private helper functions and 5 | # macros for enforcing shape/type constraints, 6 | # doing shape calculations, and even some 7 | # helper numerical definitions. 8 | 9 | import Nx.Defn 10 | 11 | @doc """ 12 | Creates a zeros-like structure which matches the structure 13 | of the input. 14 | """ 15 | deftransform zeros_like(params, opts \\ []) do 16 | fulls_like(params, 0, opts) 17 | end 18 | 19 | @doc """ 20 | Creates a fulls-like tuple of inputs. 21 | """ 22 | deftransform fulls_like(params, value, opts \\ []) do 23 | opts = Keyword.validate!(opts, [:type]) 24 | fun = &Nx.broadcast(Nx.tensor(value, type: &2), &1) 25 | 26 | deep_new(params, fn x -> 27 | type = opts[:type] || Nx.type(x) 28 | fun.(Nx.shape(x), type) 29 | end) 30 | end 31 | 32 | @doc """ 33 | Deep merges two possibly nested maps, applying fun to leaf values.
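  For example, summing two parameter maps leaf by leaf (a small sketch; the
  result shown is informal):

      deep_merge(%{w: Nx.tensor([1, 2])}, %{w: Nx.tensor([3, 4])}, &Nx.add/2)
      #=> %{w: Nx.tensor([4, 6])}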
34 | """ 35 | deftransform deep_merge(left, right, fun) do 36 | f = fn 37 | _, [] -> 38 | {nil, []} 39 | 40 | x, [y | t] -> 41 | {fun.(x, y), t} 42 | end 43 | 44 | case Nx.Defn.Composite.traverse(left, Nx.Defn.Composite.flatten_list([right]), f) do 45 | {merged, []} -> 46 | merged 47 | 48 | {_merged, _leftover} -> 49 | raise ArgumentError, "unable to merge arguments with incompatible structure" 50 | end 51 | end 52 | 53 | @doc """ 54 | Creates a new map-like structure from a possible nested map, applying `fun` 55 | to each leaf. 56 | """ 57 | deftransform deep_new(map, fun) do 58 | Nx.Defn.Composite.traverse(map, fun) 59 | end 60 | 61 | @doc """ 62 | Deep reduces a map with an accumulator. 63 | """ 64 | deftransform deep_reduce(map, acc, fun) do 65 | Nx.Defn.Composite.reduce(map, acc, fun) 66 | end 67 | 68 | @doc """ 69 | Deep map-reduce a nested container with an accumulator. 70 | """ 71 | deftransform deep_map_reduce(container, acc, fun) do 72 | Nx.Defn.Composite.traverse(container, acc, fun) 73 | end 74 | end 75 | -------------------------------------------------------------------------------- /lib/polaris/updates.ex: -------------------------------------------------------------------------------- 1 | defmodule Polaris.Updates do 2 | @moduledoc ~S""" 3 | Parameter update methods. 4 | 5 | Update methods transform the input tensor in some way, 6 | usually by scaling or shifting the input with respect 7 | to some input state. Update methods are composed 8 | to create more advanced optimization methods such as AdaGrad 9 | or Adam. Each update returns a tuple: 10 | 11 | {init_fn, update_fn} 12 | 13 | Which represent a state initialization and state update 14 | function respectively. While each method in the Updates 15 | API is a regular Elixir function, the two methods they 16 | return are implemented as `defn`, so they can be accelerated 17 | using any Nx backend or compiler. 18 | 19 | Update methods are just combinators that can be arbitrarily 20 | composed to create complex optimizers. For example, the Adam 21 | optimizer in `Polaris.Optimizers` is implemented as: 22 | 23 | def adam(opts \\ []) do 24 | opts 25 | |> Polaris.Updates.scale_by_adam() 26 | |> Polaris.Updates.scale(-opts[:learning_rate]) 27 | end 28 | 29 | Updates are maps of updates, often associated with parameters of 30 | the same names. Using `Polaris.Updates.apply_updates/3` will merge updates 31 | and parameters by adding associated parameters and updates, and 32 | ensuring any given model state is preserved. 33 | 34 | ## Custom combinators 35 | 36 | You can create your own combinators using the `stateless/2` and 37 | `stateful/3` primitives. Every update method in this module is 38 | implemented in terms of one of these two primitives. 39 | 40 | `stateless/2` represents a stateless update: 41 | 42 | def scale(combinator \\ Polaris.Updates.identity(), step_size) do 43 | stateless(combinator, &apply_scale(&1, &2, step_size)) 44 | end 45 | 46 | defnp apply_scale(updates, _params, step) do 47 | deep_new(updates, fn x -> Nx.multiply(x, step) end) 48 | end 49 | 50 | Notice how the function given to `stateless/2` is defined within `defn`. 51 | This is what allows the anonymous functions returned by `Polaris.Updates` 52 | to be used inside `defn`. 
53 | 54 | `stateful/3` represents a stateful update and follows the same pattern: 55 | 56 | def my_stateful_update(updates) do 57 | Polaris.Updates.stateful(updates, &init_my_update/1, &apply_my_update/2) 58 | end 59 | 60 | defnp init_my_update(params) do 61 | state = zeros_like(params, type: :f32) 62 | %{state: state} 63 | end 64 | 65 | defnp apply_my_update(updates, state) do 66 | new_state = deep_new(state, fn v -> Nx.add(v, 0.01) end) 67 | updates = deep_merge(updates, state, fn g, z -> Nx.multiply(g, z) end) 68 | {updates, %{state: new_state}} 69 | end 70 | 71 | State associated with individual parameters should have keys that match the 72 | keys of the parameter. For example, if you have parameters `%{kernel: kernel}` 73 | with associated states `mu` and `nu` representing the first and second moments, 74 | your state should look something like: 75 | 76 | %{ 77 | mu: %{kernel: kernel_mu}, 78 | nu: %{kernel: kernel_nu} 79 | } 80 | """ 81 | import Nx.Defn 82 | import Polaris.Shared 83 | 84 | @doc ~S""" 85 | Scales input by a fixed step size. 86 | 87 | $$f(x_i) = \alpha x_i$$ 88 | """ 89 | def scale(combinator \\ identity(), step_size) do 90 | stateless(combinator, &apply_scale(&1, &2, step_size)) 91 | end 92 | 93 | defnp apply_scale(updates, _params, step) do 94 | deep_new(updates, fn v -> Nx.multiply(v, step) end) 95 | end 96 | 97 | @doc ~S""" 98 | Scales input by a tunable learning rate which can be 99 | manipulated by external APIs such as Polaris's Loop API. 100 | 101 | $$f(x_i) = \alpha x_i$$ 102 | """ 103 | def scale_by_state(combinator_or_step) 104 | 105 | def scale_by_state(step) when is_number(step) do 106 | scale_by_state(identity(), step) 107 | end 108 | 109 | def scale_by_state({init_fn, apply_fn} = combinator, step) 110 | when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_number(step) do 111 | stateful(combinator, &init_scale_by_state(&1, init_scale: step), &apply_scale_by_state/3) 112 | end 113 | 114 | defnp init_scale_by_state(_params, opts \\ []) do 115 | opts = keyword!(opts, [:init_scale]) 116 | %{scale: opts[:init_scale]} 117 | end 118 | 119 | defnp apply_scale_by_state(x, %{scale: scale} = state, params) do 120 | {apply_scale(x, params, scale), state} 121 | end 122 | 123 | @doc """ 124 | Scales input according to Adam algorithm. 125 | 126 | ## Options 127 | 128 | * `:b1` - first moment decay. Defaults to `0.9` 129 | 130 | * `:b2` - second moment decay. Defaults to `0.999` 131 | 132 | * `:eps` - numerical stability term. Defaults to `1.0e-8` 133 | 134 | * `:eps_root` - numerical stability term. 
Defaults to `1.0e-15` 135 | 136 | ## References 137 | 138 | * [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980) 139 | 140 | """ 141 | def scale_by_adam(combinator_or_opts \\ []) 142 | 143 | def scale_by_adam(opts) when is_list(opts) do 144 | scale_by_adam(identity(), opts) 145 | end 146 | 147 | def scale_by_adam({init_fn, apply_fn} = combinator) 148 | when is_function(init_fn, 1) and is_function(apply_fn, 3) do 149 | scale_by_adam(combinator, []) 150 | end 151 | 152 | def scale_by_adam({init_fn, apply_fn} = combinator, opts) 153 | when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do 154 | stateful( 155 | combinator, 156 | &init_scale_by_adam/1, 157 | &apply_scale_by_adam(&1, &2, &3, opts) 158 | ) 159 | end 160 | 161 | defnp init_scale_by_adam(params) do 162 | mus = zeros_like(params, type: :f32) 163 | nus = zeros_like(params, type: :f32) 164 | count = Nx.tensor(0) 165 | %{mu: mus, nu: nus, count: count} 166 | end 167 | 168 | defnp apply_scale_by_adam(x, %{mu: mu, nu: nu, count: count}, _params, opts \\ []) do 169 | opts = keyword!(opts, b1: 0.9, b2: 0.999, eps: 1.0e-8, eps_root: 1.0e-15) 170 | b1 = opts[:b1] 171 | b2 = opts[:b2] 172 | eps = opts[:eps] 173 | eps_root = opts[:eps_root] 174 | 175 | mu = update_moment(x, mu, b1, 1) 176 | nu = update_moment(x, nu, b2, 2) 177 | 178 | mu_hat = bias_correction(mu, b1, count + 1) 179 | nu_hat = bias_correction(nu, b2, count + 1) 180 | 181 | x = deep_merge(mu_hat, nu_hat, fn z, t -> z / (Nx.sqrt(t + eps_root) + eps) end) 182 | {x, %{mu: mu, nu: nu, count: count + 1}} 183 | end 184 | 185 | @doc """ 186 | Scales input by the root of all prior squared inputs. 187 | 188 | ## Options 189 | 190 | * `:eps` - numerical stability term. Defaults to `1.0e-7` 191 | 192 | """ 193 | def scale_by_rss(combinator_or_opts \\ []) 194 | 195 | def scale_by_rss(opts) when is_list(opts) do 196 | scale_by_rss(identity(), opts) 197 | end 198 | 199 | def scale_by_rss({init_fn, apply_fn} = combinator) 200 | when is_function(init_fn, 1) and is_function(apply_fn, 3) do 201 | scale_by_rss(combinator, []) 202 | end 203 | 204 | def scale_by_rss({init_fn, apply_fn} = combinator, opts) 205 | when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do 206 | {initial, opts} = Keyword.pop(opts, :initial_accumulator_value, 0.1) 207 | 208 | stateful( 209 | combinator, 210 | &init_scale_by_rss(&1, initial), 211 | &apply_scale_by_rss(&1, &2, &3, opts) 212 | ) 213 | end 214 | 215 | defnp init_scale_by_rss(params, value) do 216 | sum_of_squares = fulls_like(params, value, type: :f32) 217 | %{sum_of_squares: sum_of_squares} 218 | end 219 | 220 | defnp apply_scale_by_rss(x, %{sum_of_squares: sum_of_squares}, _params, opts \\ []) do 221 | opts = keyword!(opts, eps: 1.0e-7) 222 | eps = opts[:eps] 223 | 224 | sum_of_squares = deep_merge(x, sum_of_squares, fn g, z -> Nx.pow(g, 2) + z end) 225 | 226 | inv_sqrt_squares = deep_new(sum_of_squares, fn z -> Nx.rsqrt(z + eps) end) 227 | 228 | inv_sqrt_x_square = 229 | deep_merge(sum_of_squares, inv_sqrt_squares, fn z, t -> 230 | Nx.select(Nx.greater(z, 0), t, 0.0) 231 | end) 232 | 233 | x = deep_merge(x, inv_sqrt_x_square, fn g, t -> g * t end) 234 | 235 | {x, %{sum_of_squares: sum_of_squares}} 236 | end 237 | 238 | @doc """ 239 | Scales input by the root of the EMA of squared inputs. 240 | 241 | ## Options 242 | 243 | * `:decay` - EMA decay rate. Defaults to `0.9`. 244 | 245 | * `:eps` - numerical stability term. Defaults to `1.0e-8`. 
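  Concretely, for each leaf gradient `g` the transformation is roughly (with
  `nu` the running EMA state):

      nu = decay * nu + (1 - decay) * Nx.pow(g, 2)
      update = Nx.rsqrt(nu + eps) * g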
246 | 247 | ## References 248 | 249 | * [Overview of mini-batch gradient descent](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) 250 | 251 | """ 252 | def scale_by_rms(combinator_or_opts \\ []) 253 | 254 | def scale_by_rms(opts) when is_list(opts) do 255 | scale_by_rms(identity(), opts) 256 | end 257 | 258 | def scale_by_rms({init_fn, apply_fn} = combinator) 259 | when is_function(init_fn, 1) and is_function(apply_fn, 3) do 260 | scale_by_rms(combinator, []) 261 | end 262 | 263 | def scale_by_rms({init_fn, apply_fn} = combinator, opts) 264 | when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do 265 | {initial, opts} = Keyword.pop(opts, :initial_scale, 0.0) 266 | 267 | stateful( 268 | combinator, 269 | &init_scale_by_rms(&1, initial), 270 | &apply_scale_by_rms(&1, &2, &3, opts) 271 | ) 272 | end 273 | 274 | defnp init_scale_by_rms(params, scale) do 275 | nu = fulls_like(params, scale, type: :f32) 276 | %{nu: nu} 277 | end 278 | 279 | defnp apply_scale_by_rms(x, %{nu: nu}, _params, opts \\ []) do 280 | opts = keyword!(opts, decay: 0.9, eps: 1.0e-8) 281 | decay = opts[:decay] 282 | eps = opts[:eps] 283 | 284 | nu = update_moment(x, nu, decay, 2) 285 | 286 | x = deep_merge(x, nu, fn g, t -> Nx.rsqrt(t + eps) * g end) 287 | 288 | {x, %{nu: nu}} 289 | end 290 | 291 | @doc """ 292 | Scales input according to the AdaBelief algorithm. 293 | 294 | ## Options 295 | 296 | * `:b1` - first moment decay. Defaults to `0.9`. 297 | 298 | * `:b2` - second moment decay. Defaults to `0.999`. 299 | 300 | * `:eps` - numerical stability term. Defaults to `0.0`. 301 | 302 | * `:eps_root` - numerical stability term. Defaults to `1.0e-16`. 303 | 304 | ## References 305 | 306 | * [AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients](https://arxiv.org/abs/2010.07468) 307 | 308 | """ 309 | def scale_by_belief(combinator_or_opts \\ []) 310 | 311 | def scale_by_belief(opts) when is_list(opts) do 312 | scale_by_belief(identity(), opts) 313 | end 314 | 315 | def scale_by_belief({init_fn, apply_fn} = combinator) 316 | when is_function(init_fn, 1) and is_function(apply_fn, 3) do 317 | scale_by_belief(combinator, []) 318 | end 319 | 320 | def scale_by_belief({init_fn, apply_fn} = combinator, opts) 321 | when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do 322 | stateful( 323 | combinator, 324 | &init_scale_by_belief/1, 325 | &apply_scale_by_belief(&1, &2, &3, opts) 326 | ) 327 | end 328 | 329 | defnp init_scale_by_belief(params) do 330 | mus = zeros_like(params, type: :f32) 331 | nus = zeros_like(params, type: :f32) 332 | count = Nx.tensor(0) 333 | %{mu: mus, nu: nus, count: count} 334 | end 335 | 336 | defnp apply_scale_by_belief(x, %{mu: mu, nu: nu, count: count}, _params, opts \\ []) do 337 | opts = keyword!(opts, b1: 0.9, b2: 0.999, eps: 0.0, eps_root: 1.0e-16) 338 | b1 = opts[:b1] 339 | b2 = opts[:b2] 340 | eps = opts[:eps] 341 | eps_root = opts[:eps_root] 342 | 343 | mu = update_moment(x, mu, b1, 1) 344 | nu = update_moment(x, nu, b2, 2) 345 | 346 | mu_hat = bias_correction(mu, b1, count + 1) 347 | nu_hat = bias_correction(nu, b2, count + 1) 348 | 349 | x = deep_merge(mu_hat, nu_hat, fn z, t -> 1 / (Nx.sqrt(t + eps_root) + eps) * z end) 350 | 351 | {x, %{mu: mu, nu: nu, count: count + 1}} 352 | end 353 | 354 | @doc """ 355 | Scales input by the root of the centered EMA of squared inputs. 356 | 357 | ## Options 358 | 359 | * `:decay` - EMA decay rate. Defaults to `0.9`. 360 | 361 | * `:eps` - numerical stability term. 
Defaults to `1.0e-8`. 362 | 363 | ## References 364 | 365 | * [Overview of mini-batch gradient descent](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) 366 | 367 | """ 368 | def scale_by_stddev(combinator_or_opts \\ []) 369 | 370 | def scale_by_stddev(opts) when is_list(opts) do 371 | scale_by_stddev(identity(), opts) 372 | end 373 | 374 | def scale_by_stddev({init_fn, apply_fn} = combinator) 375 | when is_function(init_fn, 1) and is_function(apply_fn, 3) do 376 | scale_by_stddev(combinator, []) 377 | end 378 | 379 | def scale_by_stddev({init_fn, apply_fn} = combinator, opts) 380 | when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do 381 | {initial, opts} = Keyword.pop(opts, :initial_scale, 0.0) 382 | 383 | stateful( 384 | combinator, 385 | &init_scale_by_stddev(&1, initial), 386 | &apply_scale_by_stddev(&1, &2, &3, opts) 387 | ) 388 | end 389 | 390 | defnp init_scale_by_stddev(params, value) do 391 | mu = zeros_like(params, type: :f32) 392 | nu = fulls_like(params, value, type: :f32) 393 | %{mu: mu, nu: nu} 394 | end 395 | 396 | defnp apply_scale_by_stddev(x, %{mu: mu, nu: nu}, _params, opts \\ []) do 397 | opts = keyword!(opts, decay: 0.9, eps: 1.0e-8) 398 | decay = opts[:decay] 399 | eps = opts[:eps] 400 | 401 | mu = update_moment(x, mu, decay, 1) 402 | nu = update_moment(x, nu, decay, 2) 403 | 404 | mu_nu = 405 | deep_merge(mu, nu, fn m, n -> 406 | Nx.rsqrt(-Nx.pow(m, 2) + n + eps) 407 | end) 408 | 409 | x = deep_merge(x, mu_nu, fn g, mn -> g * mn end) 410 | 411 | {x, %{mu: mu, nu: nu}} 412 | end 413 | 414 | @doc """ 415 | Scales input using the given schedule function. 416 | 417 | This can be useful for implementing learning rate schedules. 418 | The number of update iterations is tracked by an internal 419 | counter. You might need to update the schedule to operate 420 | on per-batch schedule rather than per-epoch. 421 | """ 422 | def scale_by_schedule(combinator \\ identity(), schedule_fn) when is_function(schedule_fn, 1) do 423 | stateful( 424 | combinator, 425 | &init_scale_by_schedule/1, 426 | &apply_scale_by_schedule(&1, &2, &3, schedule_fn) 427 | ) 428 | end 429 | 430 | defnp init_scale_by_schedule(_) do 431 | %{count: Nx.tensor(0)} 432 | end 433 | 434 | defnp apply_scale_by_schedule(x, %{count: count}, _params, schedule_fn) do 435 | step_size = schedule_fn.(count) 436 | 437 | updates = deep_new(x, fn x -> x * step_size end) 438 | 439 | {updates, %{count: count + 1}} 440 | end 441 | 442 | @doc """ 443 | Scale input according to the Rectified Adam algorithm. 444 | 445 | ## Options 446 | 447 | * `:b1` - first moment decay. Defaults to `0.9` 448 | 449 | * `:b2` - second moment decay. Defaults to `0.999` 450 | 451 | * `:eps` - numerical stability term. Defaults to `1.0e-8` 452 | 453 | * `:eps_root` - numerical stability term. Defaults to `0.0` 454 | 455 | * `:threshold` - threshold for variance. 
Defaults to `5.0` 456 | 457 | ## References 458 | 459 | * [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265) 460 | 461 | """ 462 | def scale_by_radam(combinator_or_opts \\ []) 463 | 464 | def scale_by_radam(opts) when is_list(opts) do 465 | scale_by_radam(identity(), opts) 466 | end 467 | 468 | def scale_by_radam({init_fn, apply_fn} = combinator) 469 | when is_function(init_fn, 1) and is_function(apply_fn, 3) do 470 | scale_by_radam(combinator, []) 471 | end 472 | 473 | def scale_by_radam({init_fn, apply_fn} = combinator, opts) 474 | when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do 475 | stateful( 476 | combinator, 477 | &init_scale_by_radam/1, 478 | &apply_scale_by_radam(&1, &2, &3, opts) 479 | ) 480 | end 481 | 482 | defnp init_scale_by_radam(params) do 483 | mu = zeros_like(params, type: :f32) 484 | nu = zeros_like(params, type: :f32) 485 | count = Nx.tensor(0) 486 | %{mu: mu, nu: nu, count: count} 487 | end 488 | 489 | defnp apply_scale_by_radam(x, %{mu: mu, nu: nu, count: count}, _params, opts \\ []) do 490 | opts = keyword!(opts, b1: 0.9, b2: 0.999, eps: 1.0e-8, eps_root: 0.0, threshold: 5.0) 491 | b1 = opts[:b1] 492 | b2 = opts[:b2] 493 | eps = opts[:eps] 494 | eps_root = opts[:eps_root] 495 | threshold = opts[:threshold] 496 | 497 | ro_inf = 2.0 / (1 - b1) - 1 498 | 499 | mu = update_moment(x, mu, b1, 1) 500 | nu = update_moment(x, nu, b2, 2) 501 | count_inc = count + 1 502 | 503 | b2t = Nx.pow(b2, count_inc) 504 | ro = ro_inf - 2 * count_inc * b2t / (1 - b2t) 505 | 506 | mu_hat = bias_correction(mu, b1, count + 1) 507 | nu_hat = bias_correction(nu, b2, count + 1) 508 | 509 | x = 510 | if Nx.all(Nx.greater_equal(ro, threshold)) do 511 | radam_update(ro, ro_inf, mu_hat, nu_hat, eps_root, eps) 512 | else 513 | mu_hat 514 | end 515 | 516 | {x, %{mu: mu, nu: nu, count: count + 1}} 517 | end 518 | 519 | defnp radam_update(ro, ro_inf, mu, nu, eps_root, eps) do 520 | r = Nx.sqrt((ro - 4) * (ro - 2) * ro_inf / ((ro_inf - 4) * (ro_inf - 2) * ro)) 521 | 522 | deep_merge(mu, nu, fn m, v -> 523 | r * m / (Nx.sqrt(v + eps_root) + eps) 524 | end) 525 | end 526 | 527 | @doc """ 528 | Trace inputs with past inputs. 529 | 530 | ## Options 531 | 532 | * `:decay` - decay rate for tracing past updates. Defaults 533 | to `0.9` 534 | * `:nesterov` - whether to use Nesterov momentum. 
Defaults 535 | to `false` 536 | 537 | """ 538 | def trace(combinator_or_opts \\ []) 539 | 540 | def trace(opts) when is_list(opts) do 541 | trace(identity(), opts) 542 | end 543 | 544 | def trace({init_fn, apply_fn} = combinator) 545 | when is_function(init_fn, 1) and is_function(apply_fn, 3) do 546 | trace(combinator, []) 547 | end 548 | 549 | def trace({init_fn, apply_fn} = combinator, opts) 550 | when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do 551 | stateful( 552 | combinator, 553 | &init_trace/1, 554 | &apply_trace(&1, &2, &3, opts) 555 | ) 556 | end 557 | 558 | defnp init_trace(params) do 559 | trace = zeros_like(params, type: :f32) 560 | %{trace: trace} 561 | end 562 | 563 | defnp apply_trace(x, %{trace: trace}, _params, opts \\ []) do 564 | opts = keyword!(opts, decay: 0.9, nesterov: false) 565 | decay = opts[:decay] 566 | 567 | update_trace = deep_merge(x, trace, fn g, t -> t * decay + g end) 568 | 569 | x = 570 | if opts[:nesterov] do 571 | deep_merge(x, update_trace, fn g, t -> t * decay + g end) 572 | else 573 | update_trace 574 | end 575 | 576 | {x, %{trace: update_trace}} 577 | end 578 | 579 | @doc """ 580 | Clips input between -delta and delta. 581 | 582 | ## Options 583 | 584 | * `:delta` - maximum absolute value of the input. Defaults 585 | to `2.0` 586 | """ 587 | def clip(combinator_or_opts \\ []) 588 | 589 | def clip(opts) when is_list(opts) do 590 | clip(identity(), opts) 591 | end 592 | 593 | def clip({init_fn, apply_fn} = combinator) 594 | when is_function(init_fn, 1) and is_function(apply_fn, 3) do 595 | clip(combinator, []) 596 | end 597 | 598 | def clip({init_fn, apply_fn} = combinator, opts) 599 | when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do 600 | stateless(combinator, &apply_clip(&1, &2, opts)) 601 | end 602 | 603 | defnp apply_clip(x, _params, opts \\ []) do 604 | opts = keyword!(opts, delta: 2.0) 605 | delta = opts[:delta] 606 | 607 | deep_new(x, fn g -> Nx.clip(g, -delta, delta) end) 608 | end 609 | 610 | @doc """ 611 | Clips input using input global norm. 612 | 613 | ## Options 614 | 615 | * `:max_norm` - maximum norm value of input. Defaults to 616 | `1.0` 617 | """ 618 | def clip_by_global_norm(combinator_or_opts \\ []) 619 | 620 | def clip_by_global_norm(opts) when is_list(opts) do 621 | clip_by_global_norm(identity(), opts) 622 | end 623 | 624 | def clip_by_global_norm({init_fn, apply_fn} = combinator) 625 | when is_function(init_fn, 1) and is_function(apply_fn, 3) do 626 | clip_by_global_norm(combinator, []) 627 | end 628 | 629 | def clip_by_global_norm({init_fn, apply_fn} = combinator, opts) 630 | when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do 631 | stateless(combinator, &apply_clip_by_global_norm(&1, &2, opts)) 632 | end 633 | 634 | defnp apply_clip_by_global_norm(x, _params, opts \\ []) do 635 | opts = keyword!(opts, max_norm: 1.0) 636 | max_norm = opts[:max_norm] 637 | 638 | sum_gs = 639 | deep_reduce(x, Nx.tensor(0.0), fn leaf, acc -> 640 | leaf 641 | |> Nx.pow(2) 642 | |> Nx.sum() 643 | |> Nx.add(acc) 644 | end) 645 | 646 | g_norm = Nx.sqrt(sum_gs) 647 | 648 | deep_new(x, fn z -> 649 | Nx.select(Nx.less(g_norm, max_norm), z, z / g_norm * max_norm) 650 | end) 651 | end 652 | 653 | @doc """ 654 | Centralizes input by shifting updates by their mean. 
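  Only inputs with rank greater than 1 are affected; for those, updates are
  shifted by their mean over all but the leading axis, roughly:

      g - Nx.mean(g, axes: tl(Nx.axes(g)), keep_axes: true)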
655 | """ 656 | def centralize(combinator_or_opts \\ []) 657 | 658 | def centralize(opts) when is_list(opts) do 659 | centralize(identity(), opts) 660 | end 661 | 662 | def centralize({init_fn, apply_fn} = combinator) 663 | when is_function(init_fn, 1) and is_function(apply_fn, 3) do 664 | centralize(combinator, []) 665 | end 666 | 667 | def centralize({init_fn, apply_fn} = combinator, opts) 668 | when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do 669 | stateless(combinator, &apply_centralize(&1, &2, opts)) 670 | end 671 | 672 | defnp apply_centralize(x, _params, _opts \\ []) do 673 | deep_new(x, ¢ralize_for_rank/1) 674 | end 675 | 676 | deftransformp centralize_for_rank(input) do 677 | if Nx.rank(input) > 1 do 678 | input 679 | |> Nx.subtract(Nx.mean(input, axes: tl(Nx.axes(input)), keep_axes: true)) 680 | else 681 | input 682 | end 683 | end 684 | 685 | @doc """ 686 | Adds decayed weights to updates. 687 | 688 | Commonly used as a regularization strategy. 689 | 690 | ## Options 691 | 692 | * `:decay` - Rate of decay. Defaults to `0.0`. 693 | """ 694 | def add_decayed_weights(combinator_or_opts \\ []) 695 | 696 | def add_decayed_weights(opts) when is_list(opts) do 697 | add_decayed_weights(identity(), opts) 698 | end 699 | 700 | def add_decayed_weights({init_fn, apply_fn} = combinator) 701 | when is_function(init_fn, 1) and is_function(apply_fn, 3) do 702 | add_decayed_weights(combinator, []) 703 | end 704 | 705 | def add_decayed_weights({init_fn, apply_fn} = combinator, opts) 706 | when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do 707 | stateless(combinator, fn updates, params -> 708 | opts = Nx.Defn.Kernel.keyword!(opts, decay: 0.0) 709 | # Decay can be a tensor, that's why we preprocess it before-hand 710 | # and pass it as argument to defn instead of as an option. 711 | apply_weight_decay(updates, params, opts[:decay]) 712 | end) 713 | end 714 | 715 | defnp apply_weight_decay(updates, params, decay) do 716 | deep_merge(updates, params, fn g, p -> g + decay * p end) 717 | end 718 | 719 | @doc """ 720 | Scale by trust ratio. 721 | 722 | ## Options 723 | 724 | * `:min_norm` - Min norm to clip. Defaults to 725 | `0.0`. 726 | 727 | * `:trust_coefficient` - Trust coefficient. Defaults 728 | to `1.0`. 729 | 730 | * `:eps` - Numerical stability term. Defaults to `0.0`. 
731 | """ 732 | def scale_by_trust_ratio(combinator_or_opts \\ []) 733 | 734 | def scale_by_trust_ratio(opts) when is_list(opts) do 735 | scale_by_trust_ratio(identity(), opts) 736 | end 737 | 738 | def scale_by_trust_ratio({init_fn, apply_fn} = combinator) 739 | when is_function(init_fn, 1) and is_function(apply_fn, 3) do 740 | scale_by_trust_ratio(combinator, []) 741 | end 742 | 743 | def scale_by_trust_ratio({init_fn, apply_fn} = combinator, opts) 744 | when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do 745 | stateless(combinator, fn update, params -> 746 | opts = Nx.Defn.Kernel.keyword!(opts, min_norm: 0.0, trust_coefficient: 1.0, eps: 0.0) 747 | 748 | apply_scale_by_trust_ratio( 749 | update, 750 | params, 751 | opts[:min_norm], 752 | opts[:trust_coefficient], 753 | opts[:eps] 754 | ) 755 | end) 756 | end 757 | 758 | defnp apply_scale_by_trust_ratio(updates, params, min_norm, trust_coefficient, eps) do 759 | deep_merge(updates, params, fn g, p -> 760 | param_norm = safe_norm(p, min_norm) 761 | update_norm = safe_norm(g, min_norm) 762 | 763 | trust_ratio = trust_coefficient * param_norm / (update_norm + eps) 764 | 765 | zero_norm = param_norm == 0.0 or update_norm == 0.0 766 | safe_trust_ratio = Nx.select(zero_norm, 1, trust_ratio) 767 | g * safe_trust_ratio 768 | end) 769 | end 770 | 771 | @doc """ 772 | Adds random Gaussian noise to the input. 773 | 774 | ## Options 775 | 776 | * `:seed` - Random seed to use. Defaults to the 777 | current system time. 778 | 779 | * `:eta` - Controls amount of noise to add. 780 | Defaults to `0.01`. 781 | 782 | * `:gamma` - Controls amount of noise to add. 783 | Defaults to `0.55`. 784 | """ 785 | def add_noise(combinator_or_opts \\ []) 786 | 787 | def add_noise(opts) when is_list(opts) do 788 | add_noise(identity(), opts) 789 | end 790 | 791 | def add_noise({init_fn, apply_fn} = combinator) 792 | when is_function(init_fn, 1) and is_function(apply_fn, 3) do 793 | add_noise(combinator, []) 794 | end 795 | 796 | def add_noise({init_fn, apply_fn} = combinator, opts) 797 | when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do 798 | {seed, opts} = Keyword.pop_lazy(opts, :seed, fn -> :erlang.system_time() end) 799 | stateful(combinator, &init_add_noise(&1, seed: seed), &apply_add_noise(&1, &2, &3, opts)) 800 | end 801 | 802 | defnp init_add_noise(_params, opts \\ []) do 803 | %{count: Nx.tensor(0), key: Nx.Random.key(opts[:seed])} 804 | end 805 | 806 | defnp apply_add_noise(x, %{count: count, key: key}, _params, opts \\ []) do 807 | opts = keyword!(opts, eta: 0.01, gamma: 0.55) 808 | var = opts[:eta] / Nx.pow(count + 1, opts[:gamma]) 809 | 810 | {noise, key} = 811 | deep_map_reduce(x, key, fn z, key -> 812 | Nx.Random.normal(key, shape: Nx.shape(z), type: Nx.type(z)) 813 | end) 814 | 815 | updates = deep_merge(x, noise, fn g, n -> g + var * n end) 816 | 817 | {updates, %{count: count + 1, key: key}} 818 | end 819 | 820 | @doc """ 821 | Scale input according to the Yogi algorithm. 822 | 823 | ## Options 824 | 825 | * `:initial_accumulator_value` - Initial state accumulator value. 826 | 827 | * `:b1` - first moment decay. Defaults to `0.9` 828 | 829 | * `:b2` - second moment decay. Defaults to `0.999` 830 | 831 | * `:eps` - numerical stability term. Defaults to `1.0e-8` 832 | 833 | * `:eps_root` - numerical stability term. 
Defaults to `0.0` 834 | 835 | ## References 836 | 837 | * [Adaptive Methods for Nonconvex Optimization](https://proceedings.neurips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf) 838 | """ 839 | def scale_by_yogi(combinator_or_opts \\ []) 840 | 841 | def scale_by_yogi(opts) when is_list(opts) do 842 | scale_by_yogi(identity(), opts) 843 | end 844 | 845 | def scale_by_yogi({init_fn, apply_fn} = combinator) 846 | when is_function(init_fn, 1) and is_function(apply_fn, 3) do 847 | scale_by_yogi(combinator, []) 848 | end 849 | 850 | def scale_by_yogi({init_fn, apply_fn} = combinator, opts) 851 | when is_function(init_fn, 1) and is_function(apply_fn, 3) do 852 | {initial, opts} = Keyword.pop(opts, :initial_accumulator_value, 1.0e-6) 853 | 854 | stateful( 855 | combinator, 856 | &init_scale_by_yogi(&1, initial), 857 | &apply_scale_by_yogi(&1, &2, &3, opts) 858 | ) 859 | end 860 | 861 | defnp init_scale_by_yogi(params, value) do 862 | value = fulls_like(params, value, type: :f32) 863 | mu = value 864 | nu = value 865 | count = Nx.tensor(0) 866 | %{mu: mu, nu: nu, count: count} 867 | end 868 | 869 | defnp apply_scale_by_yogi(x, %{mu: mu, nu: nu, count: count}, _params, opts \\ []) do 870 | opts = keyword!(opts, b1: 0.9, b2: 0.999, eps: 1.0e-3, eps_root: 0.0) 871 | b1 = opts[:b1] 872 | b2 = opts[:b2] 873 | eps = opts[:eps] 874 | eps_root = opts[:eps_root] 875 | 876 | mu = update_moment(x, mu, b1, 1) 877 | 878 | nu = 879 | deep_merge(x, nu, fn g, v -> 880 | v - (1 - b2) * Nx.sign(v - Nx.pow(g, 2)) * Nx.pow(g, 2) 881 | end) 882 | 883 | mu_hat = bias_correction(mu, b1, count + 1) 884 | nu_hat = bias_correction(nu, b2, count + 1) 885 | 886 | updates = deep_merge(mu_hat, nu_hat, fn m, v -> m / (Nx.sqrt(v + eps_root) + eps) end) 887 | 888 | {updates, %{mu: mu, nu: nu, count: count + 1}} 889 | end 890 | 891 | @doc """ 892 | Represents a stateless update. 893 | 894 | Stateless updates do not depend on an update state and thus 895 | only require an implementation of an update function. 896 | """ 897 | def stateless({parent_init_fn, parent_apply_fn} \\ identity(), apply_fn) do 898 | apply_fn = fn updates, state, params -> 899 | {updates, state} = parent_apply_fn.(updates, state, params) 900 | {apply_fn.(updates, params), state} 901 | end 902 | 903 | {parent_init_fn, apply_fn} 904 | end 905 | 906 | @doc """ 907 | Accumulate gradients and only apply updates every N steps. 
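Zero updates are returned while gradients accumulate; the wrapped update only runs once the configured number of steps has elapsed. For example, this sketch (the step count and scale factor are illustrative only) gates a plain negated-gradient update behind several accumulation steps:

    Polaris.Updates.scale(-1.0e-2)
    |> Polaris.Updates.accumulate_gradients(4)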
908 | """ 909 | def accumulate_gradients({parent_init_fn, parent_apply_fn} \\ identity(), steps) 910 | when is_integer(steps) do 911 | # We can't implement this one in terms of stateful/3 because on 912 | # accumulation steps we don't want to run the parent update at all, 913 | # nor change any optimizer state other than our own 914 | init_fn = fn params -> 915 | state = parent_init_fn.(params) 916 | Tuple.insert_at(state, 0, init_accumulate_gradients(params)) 917 | end 918 | 919 | apply_fn = fn updates, state, params -> 920 | this_state = elem(state, 0) 921 | other_state = Tuple.delete_at(state, 0) 922 | 923 | apply_accumulate_gradients(updates, this_state, params, other_state, parent_apply_fn, 924 | steps: steps 925 | ) 926 | end 927 | 928 | {init_fn, apply_fn} 929 | end 930 | 931 | defnp init_accumulate_gradients(params) do 932 | %{gradient_state: zeros_like(params), step: Nx.tensor(0)} 933 | end 934 | 935 | defnp apply_accumulate_gradients( 936 | updates, 937 | %{gradient_state: gradient_state, step: step}, 938 | params, 939 | parent_state, 940 | parent_apply_fn, 941 | opts \\ [] 942 | ) do 943 | opts = keyword!(opts, [:steps]) 944 | 945 | max_steps = opts[:steps] 946 | 947 | if Nx.greater_equal(step, max_steps) do 948 | updates = deep_new(updates, &Nx.divide(&1, step)) 949 | {updates, new_parent_state} = parent_apply_fn.(updates, parent_state, params) 950 | new_this_state = %{gradient_state: zeros_like(params), step: Nx.tensor(0)} 951 | {updates, tuple_insert_at(new_parent_state, 0, new_this_state)} 952 | else 953 | new_this_state = %{ 954 | gradient_state: deep_merge(updates, gradient_state, &Nx.add/2), 955 | step: Nx.add(step, 1) 956 | } 957 | 958 | {zeros_like(updates), tuple_insert_at(parent_state, 0, new_this_state)} 959 | end 960 | end 961 | 962 | deftransformp tuple_insert_at(tuple, index, element) do 963 | Tuple.insert_at(tuple, index, element) 964 | end 965 | 966 | @doc """ 967 | Returns the identity update. 968 | 969 | This is often used as the initial update in many functions in this module. 970 | """ 971 | def identity() do 972 | {fn _params -> {} end, fn updates, state, _params -> {updates, state} end} 973 | end 974 | 975 | def identity(combinator) do 976 | combinator 977 | end 978 | 979 | @doc """ 980 | Composes two updates. This is useful for extending optimizers 981 | without having to reimplement them. For example, you can add gradient 982 | centralization to an existing optimizer: 983 | 984 | import Polaris.Updates 985 | 986 | Polaris.Updates.compose(Polaris.Updates.centralize(), Polaris.Optimizers.rmsprop()) 987 | 988 | This is roughly equivalent to: 989 | 990 | Polaris.Updates.centralize() 991 | |> Polaris.Updates.scale_by_rms() 992 | """ 993 | def compose({init_fn1, apply_fn1}, {init_fn2, apply_fn2}) do 994 | init_fn = fn params -> 995 | state = init_fn1.(params) 996 | Tuple.insert_at(state, 0, init_fn2.(params)) 997 | end 998 | 999 | apply_fn = fn updates, state, params -> 1000 | this_state = elem(state, 0) 1001 | other_state = Tuple.delete_at(state, 0) 1002 | {updates, new_other_state} = apply_fn1.(updates, other_state, params) 1003 | {updates, new_this_state} = apply_fn2.(updates, this_state, params) 1004 | {updates, Tuple.insert_at(new_other_state, 0, new_this_state)} 1005 | end 1006 | 1007 | {init_fn, apply_fn} 1008 | end 1009 | 1010 | @doc """ 1011 | Represents a stateful update. 1012 | 1013 | Stateful updates require some update state, such as 1014 | momentum or RMS of previous updates. Therefore you must 1015 | implement some initialization function as well as an update 1016 | function.
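For example, a minimal sketch of a custom stateful transformation that leaves updates untouched but counts how many times it has been applied (the name and state shape here are purely illustrative):

    count_applications =
      Polaris.Updates.stateful(
        fn _params -> %{count: Nx.tensor(0)} end,
        fn updates, %{count: count}, _params ->
          {updates, %{count: Nx.add(count, 1)}}
        end
      )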
1017 | """ 1018 | def stateful({parent_init_fn, parent_apply_fn} \\ identity(), init_fn, apply_fn) do 1019 | init_fn = fn params -> 1020 | state = parent_init_fn.(params) 1021 | Tuple.insert_at(state, 0, init_fn.(params)) 1022 | end 1023 | 1024 | apply_fn = fn updates, state, params -> 1025 | this_state = elem(state, 0) 1026 | other_state = Tuple.delete_at(state, 0) 1027 | {updates, new_other_state} = parent_apply_fn.(updates, other_state, params) 1028 | {updates, new_this_state} = apply_fn.(updates, this_state, params) 1029 | {updates, Tuple.insert_at(new_other_state, 0, new_this_state)} 1030 | end 1031 | 1032 | {init_fn, apply_fn} 1033 | end 1034 | 1035 | @doc """ 1036 | Applies updates to params and updates state parameters with 1037 | given state map. 1038 | """ 1039 | defn apply_updates(params, updates, state \\ nil) do 1040 | new_params = 1041 | deep_merge(params, updates, fn x, u -> 1042 | Nx.add(x, Nx.as_type(u, Nx.type(x))) 1043 | end) 1044 | 1045 | merge_state(new_params, state) 1046 | end 1047 | 1048 | deftransformp merge_state(params, state) do 1049 | case {params, state} do 1050 | {params, nil} -> 1051 | params 1052 | 1053 | {params, state} -> 1054 | merge_inner(params, state) 1055 | end 1056 | end 1057 | 1058 | defp merge_inner(%Nx.Tensor{}, %Nx.Tensor{} = state) do 1059 | state 1060 | end 1061 | 1062 | defp merge_inner(params, state) when is_map(params) and is_map(state) do 1063 | Map.merge(params, state, fn _, s1, s2 -> merge_inner(s1, s2) end) 1064 | end 1065 | 1066 | ## Helpers 1067 | 1068 | defnp update_moment(x, moment, decay, order) do 1069 | deep_merge(x, moment, fn g, z -> (1 - decay) * Nx.pow(g, order) + decay * z end) 1070 | end 1071 | 1072 | defnp bias_correction(moment, decay, count) do 1073 | deep_new(moment, fn z -> z / (1 - Nx.pow(decay, count)) end) 1074 | end 1075 | 1076 | defnp safe_norm(g, min_norm) do 1077 | norm = Nx.LinAlg.norm(g) 1078 | z = Nx.select(Nx.less_equal(norm, min_norm), 1, g) 1079 | masked_norm = Nx.LinAlg.norm(z) 1080 | Nx.select(Nx.less_equal(norm, min_norm), min_norm, masked_norm) 1081 | end 1082 | end 1083 | -------------------------------------------------------------------------------- /mix.exs: -------------------------------------------------------------------------------- 1 | defmodule Polaris.MixProject do 2 | use Mix.Project 3 | 4 | @source_url "https://github.com/elixir-nx/polaris" 5 | @version "0.1.0" 6 | 7 | def project do 8 | [ 9 | app: :polaris, 10 | version: "0.1.0", 11 | elixir: "~> 1.14", 12 | start_permanent: Mix.env() == :prod, 13 | deps: deps(), 14 | elixirc_paths: elixirc_paths(Mix.env()), 15 | description: "Optimizers for the Nx ecosystem", 16 | docs: docs(), 17 | package: package(), 18 | preferred_cli_env: [ 19 | docs: :docs, 20 | "hex.publish": :docs 21 | ] 22 | ] 23 | end 24 | 25 | defp elixirc_paths(:test), do: ~w(lib test/support) 26 | defp elixirc_paths(_), do: ~w(lib) 27 | 28 | # Run "mix help compile.app" to learn about applications. 29 | def application do 30 | [ 31 | extra_applications: [:logger] 32 | ] 33 | end 34 | 35 | # Run "mix help deps" to learn about dependencies. 
36 | defp deps do 37 | [ 38 | {:nx, "~> 0.5"}, 39 | {:ex_doc, "~> 0.29", only: [:docs]} 40 | ] 41 | end 42 | 43 | defp package do 44 | [ 45 | maintainers: ["Sean Moriarity", "Paulo Valente"], 46 | licenses: ["Apache-2.0"], 47 | links: %{"GitHub" => @source_url} 48 | ] 49 | end 50 | 51 | defp docs do 52 | [ 53 | main: "Polaris", 54 | source_ref: "v#{@version}", 55 | # logo: "logo.png", 56 | source_url: @source_url, 57 | before_closing_body_tag: &before_closing_body_tag/1 58 | ] 59 | end 60 | 61 | defp before_closing_body_tag(:html) do 62 | """ 63 | 64 | 65 | 66 | 67 | 77 | 78 | 79 | 80 | 98 | """ 99 | end 100 | 101 | defp before_closing_body_tag(_), do: "" 102 | end 103 | -------------------------------------------------------------------------------- /mix.lock: -------------------------------------------------------------------------------- 1 | %{ 2 | "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, 3 | "earmark_parser": {:hex, :earmark_parser, "1.4.32", "fa739a0ecfa34493de19426681b23f6814573faee95dfd4b4aafe15a7b5b32c6", [:mix], [], "hexpm", "b8b0dd77d60373e77a3d7e8afa598f325e49e8663a51bcc2b88ef41838cca755"}, 4 | "ex_doc": {:hex, :ex_doc, "0.29.4", "6257ecbb20c7396b1fe5accd55b7b0d23f44b6aa18017b415cb4c2b91d997729", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "2c6699a737ae46cb61e4ed012af931b57b699643b24dabe2400a8168414bc4f5"}, 5 | "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, 6 | "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, 7 | "makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"}, 8 | "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, 9 | "nx": {:hex, :nx, "0.5.3", "6ad5534f9b82429dafa12329952708c2fdd6ab01b306e86333fdea72383147ee", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "d1072fc4423809ed09beb729e73c200ce177ddecac425d9eb6fba643669623ec"}, 10 | "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, 11 | } 12 | -------------------------------------------------------------------------------- 
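Every optimizer test below drives the same kind of loop through the `check_optimizer!` helper provided by `Polaris.Case` (not shown in this excerpt). The following is a minimal sketch of such a loop, assuming only the public `{init_fn, update_fn}` contract documented in `Polaris.Updates`; the gradient of the quadratic loss is written out by hand as `2 * x`, and all names and values are illustrative rather than the helper's actual implementation:

    {init_fn, update_fn} = Polaris.Optimizers.sgd(learning_rate: 1.0e-1)

    params = %{"x0" => Nx.tensor(1.0)}
    state = init_fn.(params)

    {final_params, _final_state} =
      Enum.reduce(1..100, {params, state}, fn _step, {params, state} ->
        # Gradient of x * x with respect to x is 2 * x.
        gradients = %{"x0" => Nx.multiply(2.0, params["x0"])}
        {updates, state} = update_fn.(gradients, state, params)
        {Polaris.Updates.apply_updates(params, updates), state}
      end)

    # After 100 steps, final_params["x0"] should be close to 0.0, the minimizer of x * x.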
/test/polaris/optimizers_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Polaris.OptimizersTest do 2 | use Polaris.Case, async: true 3 | 4 | @learning_rate 1.0e-1 5 | @iterations 100 6 | 7 | describe "adabelief" do 8 | test "correctly optimizes simple loss with default options" do 9 | optimizer = Polaris.Optimizers.adabelief(learning_rate: @learning_rate) 10 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 11 | num_steps = @iterations 12 | x0 = %{"x0" => Nx.tensor(1.0)} 13 | 14 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 15 | end 16 | 17 | test "correctly optimizes simple loss with custom options" do 18 | optimizer = Polaris.Optimizers.adabelief(learning_rate: @learning_rate, b1: 0.95, b2: 0.99) 19 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 20 | num_steps = @iterations 21 | x0 = %{"x0" => Nx.tensor(1.0)} 22 | 23 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 24 | end 25 | 26 | test "correctly optimizes simple loss with schedule" do 27 | optimizer = 28 | Polaris.Optimizers.adabelief(learning_rate: Polaris.Schedules.constant(@learning_rate)) 29 | 30 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 31 | num_steps = @iterations 32 | x0 = %{"x0" => Nx.tensor(1.0)} 33 | 34 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 35 | end 36 | end 37 | 38 | describe "adagrad" do 39 | test "correctly optimizes simple loss with default options" do 40 | optimizer = Polaris.Optimizers.adagrad(learning_rate: @learning_rate) 41 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 42 | num_steps = @iterations 43 | x0 = %{"x0" => Nx.tensor(1.0)} 44 | 45 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 46 | end 47 | 48 | test "correctly optimizes simple loss with custom options" do 49 | optimizer = Polaris.Optimizers.adagrad(learning_rate: @learning_rate, eps: 1.0e-3) 50 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 51 | num_steps = @iterations 52 | x0 = %{"x0" => Nx.tensor(1.0)} 53 | 54 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 55 | end 56 | 57 | test "correctly optimizes simple loss with schedule" do 58 | optimizer = 59 | Polaris.Optimizers.adagrad(learning_rate: Polaris.Schedules.constant(@learning_rate)) 60 | 61 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 62 | num_steps = @iterations 63 | x0 = %{"x0" => Nx.tensor(1.0)} 64 | 65 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 66 | end 67 | end 68 | 69 | describe "adam" do 70 | test "correctly optimizes simple loss with default options" do 71 | optimizer = Polaris.Optimizers.adam(learning_rate: @learning_rate) 72 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 73 | num_steps = @iterations 74 | x0 = %{"x0" => Nx.tensor(1.0)} 75 | 76 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 77 | end 78 | 79 | test "correctly optimizes simple loss with custom options" do 80 | optimizer = Polaris.Optimizers.adam(learning_rate: @learning_rate, b1: 0.95, b2: 0.99) 81 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 82 | num_steps = @iterations 83 | x0 = %{"x0" => Nx.tensor(1.0)} 84 | 85 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 86 | end 87 | 88 | test "correctly optimizes simple loss with schedule" do 89 | optimizer = 90 | Polaris.Optimizers.adam(learning_rate: Polaris.Schedules.constant(@learning_rate)) 91 | 92 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 93 | num_steps = @iterations 94 | x0 = %{"x0" => Nx.tensor(1.0)} 95 | 96 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 97 | end 98 | end 99 | 100 | 
describe "adamw" do 101 | test "correctly optimizes simple loss with default options" do 102 | optimizer = Polaris.Optimizers.adamw(learning_rate: @learning_rate) 103 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 104 | num_steps = @iterations 105 | x0 = %{"x0" => Nx.tensor(1.0)} 106 | 107 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 108 | end 109 | 110 | test "correctly optimizes simple loss with custom options" do 111 | optimizer = Polaris.Optimizers.adamw(learning_rate: @learning_rate, decay: 0.9) 112 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 113 | num_steps = @iterations 114 | x0 = %{"x0" => Nx.tensor(1.0)} 115 | 116 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 117 | end 118 | 119 | test "correctly optimizes simple loss with schedule" do 120 | optimizer = 121 | Polaris.Optimizers.adamw(learning_rate: Polaris.Schedules.constant(@learning_rate)) 122 | 123 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 124 | num_steps = @iterations 125 | x0 = %{"x0" => Nx.tensor(1.0)} 126 | 127 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 128 | end 129 | end 130 | 131 | describe "lamb" do 132 | test "correctly optimizes simple loss with default options" do 133 | optimizer = Polaris.Optimizers.lamb(learning_rate: @learning_rate) 134 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 135 | num_steps = @iterations 136 | x0 = %{"x0" => Nx.tensor([1.0])} 137 | 138 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 139 | end 140 | 141 | test "correctly optimizes simple loss with custom options" do 142 | optimizer = 143 | Polaris.Optimizers.lamb(learning_rate: @learning_rate, decay: 0.9, min_norm: 0.1) 144 | 145 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 146 | num_steps = @iterations 147 | x0 = %{"x0" => Nx.tensor([1.0])} 148 | 149 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 150 | end 151 | 152 | test "correctly optimizes simple loss with schedule" do 153 | optimizer = 154 | Polaris.Optimizers.lamb(learning_rate: Polaris.Schedules.constant(@learning_rate)) 155 | 156 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 157 | num_steps = @iterations 158 | x0 = %{"x0" => Nx.tensor([1.0])} 159 | 160 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 161 | end 162 | end 163 | 164 | describe "noisy_sgd" do 165 | test "correctly optimizes simple loss with default options" do 166 | optimizer = Polaris.Optimizers.noisy_sgd(learning_rate: @learning_rate) 167 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 168 | num_steps = @iterations 169 | x0 = %{"x0" => Nx.tensor([1.0])} 170 | 171 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 172 | end 173 | 174 | test "correctly optimizes simple loss with custom options" do 175 | optimizer = 176 | Polaris.Optimizers.noisy_sgd(learning_rate: @learning_rate, eta: 0.2, gamma: 0.6) 177 | 178 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 179 | num_steps = @iterations 180 | x0 = %{"x0" => Nx.tensor([1.0])} 181 | 182 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 183 | end 184 | 185 | test "correctly optimizes simple loss with schedule" do 186 | optimizer = 187 | Polaris.Optimizers.noisy_sgd(learning_rate: Polaris.Schedules.constant(@learning_rate)) 188 | 189 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 190 | num_steps = @iterations 191 | x0 = %{"x0" => Nx.tensor(1.0)} 192 | 193 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 194 | end 195 | end 196 | 197 | describe "radam" do 198 | test "correctly optimizes simple loss with default options" do 199 | optimizer = 
Polaris.Optimizers.radam(learning_rate: @learning_rate) 200 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 201 | num_steps = @iterations 202 | x0 = %{"x0" => Nx.tensor([1.0])} 203 | 204 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 205 | end 206 | 207 | test "correctly optimizes simple loss with custom options" do 208 | optimizer = Polaris.Optimizers.radam(learning_rate: @learning_rate, threshold: 2.0) 209 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 210 | num_steps = @iterations 211 | x0 = %{"x0" => Nx.tensor([1.0])} 212 | 213 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 214 | end 215 | 216 | test "correctly optimizes simple loss with schedule" do 217 | optimizer = 218 | Polaris.Optimizers.radam(learning_rate: Polaris.Schedules.constant(@learning_rate)) 219 | 220 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 221 | num_steps = @iterations 222 | x0 = %{"x0" => Nx.tensor(1.0)} 223 | 224 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 225 | end 226 | end 227 | 228 | describe "rmsprop" do 229 | test "correctly optimizes simple loss default case" do 230 | optimizer = Polaris.Optimizers.rmsprop(learning_rate: @learning_rate) 231 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 232 | num_steps = @iterations 233 | x0 = %{"x0" => Nx.tensor([1.0])} 234 | 235 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 236 | end 237 | 238 | test "correctly optimizes simple loss centered case" do 239 | optimizer = 240 | Polaris.Optimizers.rmsprop( 241 | learning_rate: @learning_rate, 242 | centered: true, 243 | initial_scale: 0.1, 244 | decay: 0.8 245 | ) 246 | 247 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 248 | num_steps = @iterations 249 | x0 = %{"x0" => Nx.tensor([1.0])} 250 | 251 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 252 | end 253 | 254 | test "correctly optimizes simple loss rms case" do 255 | optimizer = 256 | Polaris.Optimizers.rmsprop(learning_rate: @learning_rate, initial_scale: 0.1, decay: 0.8) 257 | 258 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 259 | num_steps = @iterations 260 | x0 = %{"x0" => Nx.tensor([1.0])} 261 | 262 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 263 | end 264 | 265 | test "correctly optimizes simple loss with momentum" do 266 | optimizer = 267 | Polaris.Optimizers.rmsprop( 268 | learning_rate: @learning_rate, 269 | initial_scale: 0.1, 270 | decay: 0.8, 271 | momentum: 0.9 272 | ) 273 | 274 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 275 | num_steps = @iterations 276 | x0 = %{"x0" => Nx.tensor([1.0])} 277 | 278 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 279 | end 280 | 281 | test "correctly optimizes simple loss with schedule" do 282 | optimizer = 283 | Polaris.Optimizers.rmsprop(learning_rate: Polaris.Schedules.constant(@learning_rate)) 284 | 285 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 286 | num_steps = @iterations 287 | x0 = %{"x0" => Nx.tensor(1.0)} 288 | 289 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 290 | end 291 | end 292 | 293 | describe "sgd" do 294 | test "correctly optimizes simple loss with default options" do 295 | optimizer = Polaris.Optimizers.sgd(learning_rate: @learning_rate) 296 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 297 | num_steps = @iterations 298 | x0 = %{"x0" => Nx.tensor([1.0])} 299 | 300 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 301 | end 302 | 303 | test "correctly optimizes simple loss with custom options" do 304 | optimizer = Polaris.Optimizers.sgd(learning_rate: @learning_rate, momentum: 
0.9) 305 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 306 | num_steps = @iterations 307 | x0 = %{"x0" => Nx.tensor([1.0])} 308 | 309 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 310 | end 311 | 312 | test "correctly optimizes simple loss with schedule" do 313 | optimizer = 314 | Polaris.Optimizers.sgd(learning_rate: Polaris.Schedules.constant(@learning_rate)) 315 | 316 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 317 | num_steps = @iterations 318 | x0 = %{"x0" => Nx.tensor(1.0)} 319 | 320 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 321 | end 322 | end 323 | 324 | describe "yogi" do 325 | test "correctly optimizes simple loss with default options" do 326 | optimizer = Polaris.Optimizers.yogi(learning_rate: @learning_rate) 327 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 328 | num_steps = @iterations 329 | x0 = %{"x0" => Nx.tensor([1.0])} 330 | 331 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 332 | end 333 | 334 | test "correctly optimizes simple loss with custom options" do 335 | optimizer = 336 | Polaris.Optimizers.yogi(learning_rate: @learning_rate, initial_accumulator_value: 0.1) 337 | 338 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 339 | num_steps = @iterations 340 | x0 = %{"x0" => Nx.tensor([1.0])} 341 | 342 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 343 | end 344 | 345 | test "correctly optimizes simple loss with schedule" do 346 | optimizer = 347 | Polaris.Optimizers.yogi(learning_rate: Polaris.Schedules.constant(@learning_rate)) 348 | 349 | loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end 350 | num_steps = @iterations 351 | x0 = %{"x0" => Nx.tensor(1.0)} 352 | 353 | check_optimizer!(optimizer, loss_fn, x0, num_steps) 354 | end 355 | end 356 | end 357 | -------------------------------------------------------------------------------- /test/polaris/schedules_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Polaris.SchedulesTest do 2 | use Polaris.Case 3 | doctest Polaris.Schedules 4 | 5 | import Polaris.Schedules 6 | import Nx.Defn 7 | 8 | describe "exponential_decay" do 9 | test "returns arity-1 function with defaults" do 10 | fun = exponential_decay(1.0e-2) 11 | assert is_function(fun, 1) 12 | end 13 | 14 | test "returns arity-1 function with options" do 15 | fun = exponential_decay(1.0e-3, decay_rate: 0.9) 16 | assert is_function(fun, 1) 17 | end 18 | 19 | test "can be called as anonymous function" do 20 | fun = exponential_decay(1.0e-2) 21 | assert_all_close(fun.(0), 1.0e-2) 22 | 23 | fun = exponential_decay(1.0e-3) 24 | assert_all_close(fun.(0), 1.0e-3) 25 | end 26 | 27 | test "can be called within JIT" do 28 | fun = exponential_decay(1.0e-2) 29 | assert_all_close(apply(jit(fun), [0]), 1.0e-2) 30 | 31 | fun = exponential_decay(1.0e-3) 32 | assert_all_close(apply(jit(fun), [0]), 1.0e-3) 33 | end 34 | 35 | test "matches optax values at different counts" do 36 | fun1 = exponential_decay(1.0e-2, decay_rate: 0.9, transition_steps: 15) 37 | 38 | assert_all_close(fun1.(0), 1.0e-2) 39 | assert_all_close(fun1.(25), 0.008389527) 40 | assert_all_close(fun1.(50), 0.007038417) 41 | assert_all_close(fun1.(1000), 8.902254e-06) 42 | assert_all_close(fun1.(100_000), 0.0) 43 | 44 | fun2 = exponential_decay(1.0e-3, decay_rate: 0.99, transition_steps: 100) 45 | 46 | assert_all_close(fun2.(0), 1.0e-3) 47 | assert_all_close(fun2.(25), 0.0009974906) 48 | assert_all_close(fun2.(50), 0.0009949874) 49 | assert_all_close(fun2.(1000), 0.0009043822) 50 | 
assert_all_close(fun2.(100_000), 4.3171664e-08) 51 | 52 | fun3 = 53 | exponential_decay( 54 | 1.0e-1, 55 | decay_rate: 0.99, 56 | transition_begin: 100, 57 | transition_steps: 25 58 | ) 59 | 60 | assert_all_close(fun3.(0), 0.1) 61 | assert_all_close(fun3.(25), 0.1) 62 | assert_all_close(fun3.(50), 0.1) 63 | assert_all_close(fun3.(1000), 0.069641344) 64 | assert_all_close(fun3.(100_000), 3.6162157e-19) 65 | end 66 | end 67 | 68 | describe "cosine_decay" do 69 | test "returns arity-1 function with defaults" do 70 | fun = cosine_decay(1.0e-2) 71 | assert is_function(fun, 1) 72 | end 73 | 74 | test "returns arity-1 function with options" do 75 | fun = cosine_decay(1.0e-3, decay_steps: 5) 76 | assert is_function(fun, 1) 77 | end 78 | 79 | test "can be called as anonymous function" do 80 | fun = cosine_decay(1.0e-2) 81 | assert_all_close(fun.(0), 1.0e-2) 82 | 83 | fun = cosine_decay(1.0e-3) 84 | assert_all_close(fun.(0), 1.0e-3) 85 | end 86 | 87 | test "can be called within JIT" do 88 | fun = cosine_decay(1.0e-2) 89 | assert_all_close(apply(jit(fun), [0]), 1.0e-2) 90 | 91 | fun = cosine_decay(1.0e-3) 92 | assert_all_close(apply(jit(fun), [0]), 1.0e-3) 93 | end 94 | 95 | test "matches optax values at different counts" do 96 | fun1 = cosine_decay(1.0e-3, decay_steps: 10, alpha: 0.0) 97 | 98 | assert_all_close(fun1.(0), 0.001) 99 | assert_all_close(fun1.(25), 0.0) 100 | assert_all_close(fun1.(50), 0.00) 101 | assert_all_close(fun1.(1000), 0.0) 102 | assert_all_close(fun1.(100_000), 0.0) 103 | 104 | fun2 = cosine_decay(1.0e-2, decay_steps: 1000, alpha: 0.5) 105 | 106 | assert_all_close(fun2.(0), 0.01) 107 | assert_all_close(fun2.(25), 0.009992293) 108 | assert_all_close(fun2.(50), 0.0099692205) 109 | assert_all_close(fun2.(1000), 0.005) 110 | assert_all_close(fun2.(100_000), 0.005) 111 | 112 | fun3 = cosine_decay(1.0e-1, decay_steps: 1) 113 | 114 | assert_all_close(fun3.(0), 0.1) 115 | assert_all_close(fun3.(25), 0.0) 116 | assert_all_close(fun3.(50), 0.0) 117 | assert_all_close(fun3.(1000), 0.0) 118 | assert_all_close(fun3.(100_000), 0.0) 119 | end 120 | end 121 | 122 | describe "constant" do 123 | test "returns arity-1 function with defaults" do 124 | fun = constant(1.0e-2) 125 | assert is_function(fun, 1) 126 | end 127 | 128 | test "can be called as anonymous function" do 129 | fun = constant(1.0e-2) 130 | assert_all_close(fun.(0), 1.0e-2) 131 | 132 | fun = cosine_decay(1.0e-3) 133 | assert_all_close(fun.(0), 1.0e-3) 134 | end 135 | 136 | test "can be called within JIT" do 137 | fun = constant(1.0e-2) 138 | assert_all_close(apply(jit(fun), [0]), 1.0e-2) 139 | 140 | fun = constant(1.0e-3) 141 | assert_all_close(apply(jit(fun), [0]), 1.0e-3) 142 | end 143 | 144 | test "matches optax values at different counts" do 145 | fun1 = constant(1.0e-3) 146 | 147 | assert_all_close(fun1.(0), 0.001) 148 | assert_all_close(fun1.(25), 0.001) 149 | assert_all_close(fun1.(50), 0.001) 150 | assert_all_close(fun1.(1000), 0.001) 151 | assert_all_close(fun1.(100_000), 0.001) 152 | 153 | fun2 = constant(1.0e-2) 154 | 155 | assert_all_close(fun2.(0), 0.01) 156 | assert_all_close(fun2.(25), 0.01) 157 | assert_all_close(fun2.(50), 0.01) 158 | assert_all_close(fun2.(1000), 0.01) 159 | assert_all_close(fun2.(100_000), 0.01) 160 | 161 | fun3 = constant(1.0e-1) 162 | 163 | assert_all_close(fun3.(0), 0.1) 164 | assert_all_close(fun3.(25), 0.1) 165 | assert_all_close(fun3.(50), 0.1) 166 | assert_all_close(fun3.(1000), 0.1) 167 | assert_all_close(fun3.(100_000), 0.1) 168 | end 169 | end 170 | 171 | describe "polynomial_decay" do 
172 | test "returns arity-1 function with defaults" do 173 | fun = polynomial_decay(1.0e-2) 174 | assert is_function(fun, 1) 175 | end 176 | 177 | test "returns arity-1 function with options" do 178 | fun = polynomial_decay(1.0e-3, end_value: 1.0e-4) 179 | assert is_function(fun, 1) 180 | end 181 | 182 | test "can be called as anonymous function" do 183 | fun = polynomial_decay(1.0e-2) 184 | assert_all_close(fun.(0), 1.0e-2) 185 | 186 | fun = polynomial_decay(1.0e-3) 187 | assert_all_close(fun.(0), 1.0e-3) 188 | end 189 | 190 | test "can be called within JIT" do 191 | fun = polynomial_decay(1.0e-2) 192 | assert_all_close(apply(jit(fun), [0]), 1.0e-2) 193 | 194 | fun = polynomial_decay(1.0e-3, end_value: 1.0e-4) 195 | assert_all_close(apply(jit(fun), [0]), 1.0e-3) 196 | end 197 | 198 | test "matches optax values at different counts" do 199 | fun1 = polynomial_decay(1.0e-2, end_value: 1.0e-3, power: 3, transition_steps: 1000) 200 | 201 | assert_all_close(fun1.(0), 0.01) 202 | assert_all_close(fun1.(25), 0.009341734) 203 | assert_all_close(fun1.(50), 0.008716375) 204 | assert_all_close(fun1.(1000), 0.001) 205 | assert_all_close(fun1.(100_000), 0.001) 206 | 207 | fun2 = polynomial_decay(1.0e-3, end_value: 1.0e-4, transition_begin: 100, power: 2) 208 | 209 | assert_all_close(fun2.(0), 0.001) 210 | assert_all_close(fun2.(25), 0.001) 211 | assert_all_close(fun2.(50), 0.001) 212 | assert_all_close(fun2.(1000), 0.0001) 213 | assert_all_close(fun2.(100_000), 0.0001) 214 | 215 | fun3 = 216 | polynomial_decay( 217 | 1.0e-1, 218 | end_value: 1.0e-3, 219 | transition_steps: 10000, 220 | power: 1.5 221 | ) 222 | 223 | assert_all_close(fun3.(0), 0.1) 224 | assert_all_close(fun3.(25), 0.099628985) 225 | assert_all_close(fun3.(50), 0.09925843) 226 | assert_all_close(fun3.(1000), 0.08552768) 227 | assert_all_close(fun3.(100_000), 0.001) 228 | end 229 | end 230 | end 231 | -------------------------------------------------------------------------------- /test/polaris/updates_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Polaris.UpdatesTest do 2 | use Polaris.Case 3 | doctest Polaris.Updates 4 | 5 | import Polaris.Updates 6 | 7 | describe "add_decayed_weights" do 8 | test "constructs a stateless transformation" do 9 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 10 | assert {init_fn, update_fn} = add_decayed_weights() 11 | assert is_function(init_fn, 1) 12 | assert is_function(update_fn, 3) 13 | assert init_fn.(params) == {} 14 | end 15 | 16 | test "constructs a stateless transformation with options" do 17 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 18 | assert {init_fn, update_fn} = add_decayed_weights(decay: 0.95) 19 | assert is_function(init_fn, 1) 20 | assert is_function(update_fn, 3) 21 | assert init_fn.(params) == {} 22 | end 23 | 24 | test "composes with itself" do 25 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 26 | 27 | assert {init_fn, update_fn} = 28 | add_decayed_weights(decay: 0.95) |> add_decayed_weights(decay: 0.95) 29 | 30 | assert is_function(init_fn, 1) 31 | assert is_function(update_fn, 3) 32 | assert init_fn.(params) == {} 33 | end 34 | 35 | test "composes with stateful transformation" do 36 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 37 | assert {init_fn, update_fn} = scale_by_adam() |> add_decayed_weights(decay: 0.95) 38 | assert is_function(init_fn, 1) 39 | assert is_function(update_fn, 3) 40 | assert {adam_state} = init_fn.(params) 41 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state 42 | assert_equal(mu_a, 
Nx.tensor([0.0, 0.0, 0.0])) 43 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 44 | assert_equal(count, Nx.tensor(0)) 45 | end 46 | 47 | test "matches optax with simple container" do 48 | assert {init_fn, update_fn} = add_decayed_weights(decay: 0.95) 49 | params = %{a: Nx.tensor([0.18884168, 0.92323774, 0.4513516])} 50 | updates = %{a: Nx.tensor([0.62370003, 0.86674502, 0.11204521])} 51 | state = init_fn.(params) 52 | 53 | expected_a = Nx.tensor([0.80309962, 1.74382088, 0.54082923]) 54 | 55 | assert {new_updates, new_state} = update_fn.(updates, state, params) 56 | assert %{a: actual_a} = new_updates 57 | assert new_state == {} 58 | assert_all_close(actual_a, expected_a) 59 | end 60 | 61 | test "matches optax with nested container" do 62 | assert {init_fn, update_fn} = add_decayed_weights(decay: 0.95) 63 | 64 | params = %{ 65 | a: %{ 66 | b: Nx.tensor([0.26106195, 0.52850289, 0.19788291]), 67 | c: %{d: %{}, e: Nx.tensor([[0.7100145, 0.41356265, 0.35657979]])} 68 | } 69 | } 70 | 71 | updates = %{ 72 | a: %{ 73 | b: Nx.tensor([0.83834362, 0.75873946, 0.54735649]), 74 | c: %{d: %{}, e: Nx.tensor([[0.7384456, 0.76676084, 0.72992148]])} 75 | } 76 | } 77 | 78 | state = init_fn.(params) 79 | 80 | expected_b = Nx.tensor([1.08635247, 1.26081721, 0.73534525]) 81 | expected_e = Nx.tensor([[1.41295937, 1.15964536, 1.06867228]]) 82 | 83 | assert {new_updates, new_state} = update_fn.(updates, state, params) 84 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 85 | assert new_state == {} 86 | assert_all_close(actual_b, expected_b) 87 | assert_all_close(actual_e, expected_e) 88 | end 89 | 90 | test "supports generic container" do 91 | assert {init_fn, update_fn} = add_decayed_weights(decay: 0.95) 92 | 93 | params = { 94 | { 95 | Nx.tensor([0.26106195, 0.52850289, 0.19788291]), 96 | {{}, Nx.tensor([[0.7100145, 0.41356265, 0.35657979]])} 97 | } 98 | } 99 | 100 | updates = { 101 | { 102 | Nx.tensor([0.83834362, 0.75873946, 0.54735649]), 103 | {{}, Nx.tensor([[0.7384456, 0.76676084, 0.72992148]])} 104 | } 105 | } 106 | 107 | state = init_fn.(params) 108 | 109 | expected_b = Nx.tensor([1.08635247, 1.26081721, 0.73534525]) 110 | expected_e = Nx.tensor([[1.41295937, 1.15964536, 1.06867228]]) 111 | 112 | assert {new_updates, new_state} = update_fn.(updates, state, params) 113 | assert {{actual_b, {{}, actual_e}}} = new_updates 114 | assert new_state == {} 115 | assert_all_close(actual_b, expected_b) 116 | assert_all_close(actual_e, expected_e) 117 | end 118 | end 119 | 120 | describe "add_noise" do 121 | test "constructs a stateful transformation" do 122 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 123 | assert {init_fn, update_fn} = add_noise() 124 | assert is_function(init_fn, 1) 125 | assert is_function(update_fn, 3) 126 | assert {add_noise_state} = init_fn.(params) 127 | assert %{count: count} = add_noise_state 128 | assert_equal(count, Nx.tensor(0)) 129 | end 130 | 131 | test "constructs a stateful transformation with options" do 132 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 133 | assert {init_fn, update_fn} = add_noise(gamma: 1.0) 134 | assert is_function(init_fn, 1) 135 | assert is_function(update_fn, 3) 136 | assert {add_noise_state} = init_fn.(params) 137 | assert %{count: count} = add_noise_state 138 | assert_equal(count, Nx.tensor(0)) 139 | end 140 | end 141 | 142 | describe "clip" do 143 | test "constructs a stateless transformation" do 144 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 145 | assert {init_fn, update_fn} = clip() 146 | assert is_function(init_fn, 1) 147 | assert 
is_function(update_fn, 3) 148 | assert init_fn.(params) == {} 149 | end 150 | 151 | test "constructs a stateless transformation with options" do 152 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 153 | assert {init_fn, update_fn} = clip(delta: 1.0) 154 | assert is_function(init_fn, 1) 155 | assert is_function(update_fn, 3) 156 | assert init_fn.(params) == {} 157 | end 158 | 159 | test "composes with itself" do 160 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 161 | assert {init_fn, update_fn} = clip(delta: 2.0) |> clip(delta: 2.0) 162 | assert is_function(init_fn, 1) 163 | assert is_function(update_fn, 3) 164 | assert init_fn.(params) == {} 165 | end 166 | 167 | test "composes with stateful transformation" do 168 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 169 | assert {init_fn, update_fn} = scale_by_adam() |> clip(delta: 2.0) 170 | assert is_function(init_fn, 1) 171 | assert is_function(update_fn, 3) 172 | assert {adam_state} = init_fn.(params) 173 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state 174 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 175 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 176 | assert_equal(count, Nx.tensor(0)) 177 | end 178 | 179 | test "matches optax with simple container" do 180 | assert {init_fn, update_fn} = clip(delta: 2.0) 181 | params = %{a: Nx.tensor([0.74794595, 0.99105549, 0.5621627])} 182 | updates = %{a: Nx.tensor([0.84208747, 0.69837738, 0.61840895])} 183 | state = init_fn.(params) 184 | 185 | expected_a = Nx.tensor([0.84208745, 0.6983774, 0.618409]) 186 | 187 | assert {new_updates, new_state} = update_fn.(updates, state, params) 188 | assert %{a: actual_a} = new_updates 189 | assert new_state == {} 190 | assert_all_close(actual_a, expected_a) 191 | end 192 | 193 | test "matches optax with nested container" do 194 | assert {init_fn, update_fn} = clip(delta: 1.0) 195 | 196 | params = %{ 197 | a: %{ 198 | b: Nx.tensor([0.62866726, 0.04867021, 0.66160428]), 199 | c: %{d: %{}, e: Nx.tensor([0.70566323, 0.52083707, 0.14541595])} 200 | } 201 | } 202 | 203 | updates = %{ 204 | a: %{ 205 | b: Nx.tensor([0.19084232, 0.09963277, 0.28141486]), 206 | c: %{d: %{}, e: Nx.tensor([0.91124607, 0.2248316, 0.79530217])} 207 | } 208 | } 209 | 210 | state = init_fn.(params) 211 | 212 | expected_b = Nx.tensor([0.19084232, 0.09963277, 0.28141487]) 213 | expected_e = Nx.tensor([0.91124606, 0.2248316, 0.79530215]) 214 | 215 | assert {new_updates, new_state} = update_fn.(updates, state, params) 216 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 217 | assert new_state == {} 218 | assert_all_close(actual_b, expected_b) 219 | assert_all_close(actual_e, expected_e) 220 | end 221 | 222 | test "supports generic container" do 223 | assert {init_fn, update_fn} = clip(delta: 1.0) 224 | 225 | params = { 226 | { 227 | Nx.tensor([0.62866726, 0.04867021, 0.66160428]), 228 | {{}, Nx.tensor([0.70566323, 0.52083707, 0.14541595])} 229 | } 230 | } 231 | 232 | updates = { 233 | { 234 | Nx.tensor([0.19084232, 0.09963277, 0.28141486]), 235 | {{}, Nx.tensor([0.91124607, 0.2248316, 0.79530217])} 236 | } 237 | } 238 | 239 | state = init_fn.(params) 240 | 241 | expected_b = Nx.tensor([0.19084232, 0.09963277, 0.28141487]) 242 | expected_e = Nx.tensor([0.91124606, 0.2248316, 0.79530215]) 243 | 244 | assert {new_updates, new_state} = update_fn.(updates, state, params) 245 | assert {{actual_b, {{}, actual_e}}} = new_updates 246 | assert new_state == {} 247 | assert_all_close(actual_b, expected_b) 248 | assert_all_close(actual_e, expected_e) 249 | end 250 | end 
251 | 252 | describe "clip_by_global_norm" do 253 | test "constructs a stateless transformation" do 254 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 255 | assert {init_fn, update_fn} = clip_by_global_norm() 256 | assert is_function(init_fn, 1) 257 | assert is_function(update_fn, 3) 258 | assert init_fn.(params) == {} 259 | end 260 | 261 | test "constructs a stateless transformation with options" do 262 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 263 | assert {init_fn, update_fn} = clip_by_global_norm(max_norm: 1.0) 264 | assert is_function(init_fn, 1) 265 | assert is_function(update_fn, 3) 266 | assert init_fn.(params) == {} 267 | end 268 | 269 | test "composes with itself" do 270 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 271 | 272 | assert {init_fn, update_fn} = 273 | clip_by_global_norm(max_norm: 1.0) |> clip_by_global_norm(max_norm: 1.0) 274 | 275 | assert is_function(init_fn, 1) 276 | assert is_function(update_fn, 3) 277 | assert init_fn.(params) == {} 278 | end 279 | 280 | test "composes with stateful transformation" do 281 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 282 | assert {init_fn, update_fn} = scale_by_adam() |> clip_by_global_norm(max_norm: 1.0) 283 | assert is_function(init_fn, 1) 284 | assert is_function(update_fn, 3) 285 | assert {adam_state} = init_fn.(params) 286 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state 287 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 288 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 289 | assert_equal(count, Nx.tensor(0)) 290 | end 291 | 292 | test "matches optax with simple container" do 293 | assert {init_fn, update_fn} = clip_by_global_norm(max_norm: 1.0) 294 | params = %{a: Nx.tensor([0.72673265, 0.35788219, 0.75329067])} 295 | updates = %{a: Nx.tensor([0.68235248, 0.56976359, 0.79599518])} 296 | state = init_fn.(params) 297 | 298 | expected_a = Nx.tensor([0.571844, 0.47748914, 0.667082]) 299 | 300 | assert {new_updates, new_state} = update_fn.(updates, state, params) 301 | assert %{a: actual_a} = new_updates 302 | assert new_state == {} 303 | assert_all_close(actual_a, expected_a) 304 | end 305 | 306 | test "matches optax with nested container" do 307 | assert {init_fn, update_fn} = clip_by_global_norm(max_norm: 1.0) 308 | 309 | params = %{ 310 | a: %{ 311 | b: Nx.tensor([0.85107357, 0.67088125, 0.59811338]), 312 | c: %{d: %{}, e: Nx.tensor([0.45385324, 0.05131562, 0.91526984])} 313 | } 314 | } 315 | 316 | updates = %{ 317 | a: %{ 318 | b: Nx.tensor([0.59629243, 0.86219328, 0.30155944]), 319 | c: %{d: %{}, e: Nx.tensor([0.83792943, 0.22030587, 0.72606433])} 320 | } 321 | } 322 | 323 | state = init_fn.(params) 324 | 325 | expected_b = Nx.tensor([0.3795878, 0.54885495, 0.1919667]) 326 | expected_e = Nx.tensor([0.53340906, 0.14024231, 0.462198]) 327 | 328 | assert {new_updates, new_state} = update_fn.(updates, state, params) 329 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 330 | assert new_state == {} 331 | assert_all_close(actual_b, expected_b) 332 | assert_all_close(actual_e, expected_e) 333 | end 334 | 335 | test "supports generic container" do 336 | assert {init_fn, update_fn} = clip_by_global_norm(max_norm: 1.0) 337 | 338 | params = { 339 | { 340 | Nx.tensor([0.85107357, 0.67088125, 0.59811338]), 341 | {{}, Nx.tensor([0.45385324, 0.05131562, 0.91526984])} 342 | } 343 | } 344 | 345 | updates = { 346 | { 347 | Nx.tensor([0.59629243, 0.86219328, 0.30155944]), 348 | {{}, Nx.tensor([0.83792943, 0.22030587, 0.72606433])} 349 | } 350 | } 351 | 352 | state = init_fn.(params) 353 | 354 | 
expected_b = Nx.tensor([0.3795878, 0.54885495, 0.1919667]) 355 | expected_e = Nx.tensor([0.53340906, 0.14024231, 0.462198]) 356 | 357 | assert {new_updates, new_state} = update_fn.(updates, state, params) 358 | assert {{actual_b, {{}, actual_e}}} = new_updates 359 | assert new_state == {} 360 | assert_all_close(actual_b, expected_b) 361 | assert_all_close(actual_e, expected_e) 362 | end 363 | end 364 | 365 | describe "centralize" do 366 | test "constructs a stateless transformation" do 367 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 368 | assert {init_fn, update_fn} = centralize() 369 | assert is_function(init_fn, 1) 370 | assert is_function(update_fn, 3) 371 | assert init_fn.(params) == {} 372 | end 373 | 374 | test "composes with itself" do 375 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 376 | assert {init_fn, update_fn} = centralize() |> centralize() 377 | assert is_function(init_fn, 1) 378 | assert is_function(update_fn, 3) 379 | assert init_fn.(params) == {} 380 | end 381 | 382 | test "composes with stateful transformation" do 383 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 384 | assert {init_fn, update_fn} = scale_by_adam() |> centralize() 385 | assert is_function(init_fn, 1) 386 | assert is_function(update_fn, 3) 387 | assert {adam_state} = init_fn.(params) 388 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state 389 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 390 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 391 | assert_equal(count, Nx.tensor(0)) 392 | end 393 | 394 | test "matches optax with simple container" do 395 | assert {init_fn, update_fn} = centralize() 396 | params = %{a: Nx.tensor([0.14574998, 0.53619206, 0.68726124])} 397 | updates = %{a: Nx.tensor([0.05166196, 0.3979764, 0.84524461])} 398 | state = init_fn.(params) 399 | 400 | expected_a = Nx.tensor([0.05166196, 0.3979764, 0.84524461]) 401 | 402 | assert {new_updates, new_state} = update_fn.(updates, state, params) 403 | assert %{a: actual_a} = new_updates 404 | assert new_state == {} 405 | assert_all_close(actual_a, expected_a) 406 | end 407 | 408 | test "matches optax with nested container" do 409 | assert {init_fn, update_fn} = centralize() 410 | 411 | params = %{ 412 | a: %{ 413 | b: Nx.tensor([0.21855268, 0.21286796, 0.83114509]), 414 | c: %{d: %{}, e: Nx.tensor([[0.26958357, 0.59519575, 0.87732692]])} 415 | } 416 | } 417 | 418 | updates = %{ 419 | a: %{ 420 | b: Nx.tensor([0.41087112, 0.97778015, 0.51054674]), 421 | c: %{d: %{}, e: Nx.tensor([[0.20577277, 0.95319838, 0.14168365]])} 422 | } 423 | } 424 | 425 | state = init_fn.(params) 426 | 427 | expected_b = Nx.tensor([0.41087112, 0.97778015, 0.51054674]) 428 | expected_e = Nx.tensor([[-0.22777883, 0.51964678, -0.29186795]]) 429 | 430 | assert {new_updates, new_state} = update_fn.(updates, state, params) 431 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 432 | assert new_state == {} 433 | assert_all_close(actual_b, expected_b) 434 | assert_all_close(actual_e, expected_e) 435 | end 436 | 437 | test "supports generic container" do 438 | assert {init_fn, update_fn} = centralize() 439 | 440 | params = { 441 | { 442 | Nx.tensor([0.21855268, 0.21286796, 0.83114509]), 443 | {{}, Nx.tensor([[0.26958357, 0.59519575, 0.87732692]])} 444 | } 445 | } 446 | 447 | updates = { 448 | { 449 | Nx.tensor([0.41087112, 0.97778015, 0.51054674]), 450 | {{}, Nx.tensor([[0.20577277, 0.95319838, 0.14168365]])} 451 | } 452 | } 453 | 454 | state = init_fn.(params) 455 | 456 | expected_b = Nx.tensor([0.41087112, 0.97778015, 0.51054674]) 457 | 
expected_e = Nx.tensor([[-0.22777883, 0.51964678, -0.29186795]]) 458 | 459 | assert {new_updates, new_state} = update_fn.(updates, state, params) 460 | assert {{actual_b, {{}, actual_e}}} = new_updates 461 | assert new_state == {} 462 | assert_all_close(actual_b, expected_b) 463 | assert_all_close(actual_e, expected_e) 464 | end 465 | end 466 | 467 | describe "identity" do 468 | test "constructs a stateless transformation" do 469 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 470 | assert {init_fn, update_fn} = identity() 471 | assert is_function(init_fn, 1) 472 | assert is_function(update_fn, 3) 473 | assert init_fn.(params) == {} 474 | end 475 | 476 | test "composes with itself" do 477 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 478 | assert {init_fn, update_fn} = identity() |> identity() 479 | assert is_function(init_fn, 1) 480 | assert is_function(update_fn, 3) 481 | assert init_fn.(params) == {} 482 | end 483 | 484 | test "composes with stateful transformation" do 485 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 486 | assert {init_fn, update_fn} = scale_by_adam() |> identity() 487 | assert is_function(init_fn, 1) 488 | assert is_function(update_fn, 3) 489 | assert {adam_state} = init_fn.(params) 490 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state 491 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 492 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 493 | assert_equal(count, Nx.tensor(0)) 494 | end 495 | 496 | test "matches optax with simple container" do 497 | assert {init_fn, update_fn} = identity() 498 | params = %{a: Nx.tensor([0.18884168, 0.92323774, 0.4513516])} 499 | updates = %{a: Nx.tensor([0.62370003, 0.86674502, 0.11204521])} 500 | state = init_fn.(params) 501 | 502 | expected_a = Nx.tensor([0.62370003, 0.86674502, 0.11204521]) 503 | 504 | assert {new_updates, new_state} = update_fn.(updates, state, params) 505 | assert %{a: actual_a} = new_updates 506 | assert new_state == {} 507 | assert_all_close(actual_a, expected_a) 508 | end 509 | 510 | test "matches optax with nested container" do 511 | assert {init_fn, update_fn} = identity() 512 | 513 | params = %{ 514 | a: %{ 515 | b: Nx.tensor([0.26106195, 0.52850289, 0.19788291]), 516 | c: %{d: %{}, e: Nx.tensor([[0.7100145, 0.41356265, 0.35657979]])} 517 | } 518 | } 519 | 520 | updates = %{ 521 | a: %{ 522 | b: Nx.tensor([0.83834362, 0.75873946, 0.54735649]), 523 | c: %{d: %{}, e: Nx.tensor([[0.7384456, 0.76676084, 0.72992148]])} 524 | } 525 | } 526 | 527 | state = init_fn.(params) 528 | 529 | expected_b = Nx.tensor([0.83834362, 0.75873946, 0.54735649]) 530 | expected_e = Nx.tensor([[0.7384456, 0.76676084, 0.72992148]]) 531 | 532 | assert {new_updates, new_state} = update_fn.(updates, state, params) 533 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 534 | assert new_state == {} 535 | assert_all_close(actual_b, expected_b) 536 | assert_all_close(actual_e, expected_e) 537 | end 538 | 539 | test "supports generic container" do 540 | assert {init_fn, update_fn} = identity() 541 | 542 | params = { 543 | { 544 | Nx.tensor([0.26106195, 0.52850289, 0.19788291]), 545 | {{}, Nx.tensor([[0.7100145, 0.41356265, 0.35657979]])} 546 | } 547 | } 548 | 549 | updates = { 550 | { 551 | Nx.tensor([0.83834362, 0.75873946, 0.54735649]), 552 | {{}, Nx.tensor([[0.7384456, 0.76676084, 0.72992148]])} 553 | } 554 | } 555 | 556 | state = init_fn.(params) 557 | 558 | expected_b = Nx.tensor([0.83834362, 0.75873946, 0.54735649]) 559 | expected_e = Nx.tensor([[0.7384456, 0.76676084, 0.72992148]]) 560 | 561 | assert 
{new_updates, new_state} = update_fn.(updates, state, params) 562 | assert {{actual_b, {{}, actual_e}}} = new_updates 563 | assert new_state == {} 564 | assert_all_close(actual_b, expected_b) 565 | assert_all_close(actual_e, expected_e) 566 | end 567 | end 568 | 569 | describe "scale" do 570 | test "constructs a stateless transformation" do 571 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 572 | assert {init_fn, update_fn} = scale(1.0e-2) 573 | assert is_function(init_fn, 1) 574 | assert is_function(update_fn, 3) 575 | assert init_fn.(params) == {} 576 | end 577 | 578 | test "composes with itself" do 579 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 580 | assert {init_fn, update_fn} = scale(1.0e-2) |> scale(1.0e-2) 581 | assert is_function(init_fn, 1) 582 | assert is_function(update_fn, 3) 583 | assert init_fn.(params) == {} 584 | end 585 | 586 | test "composes with stateful transformation" do 587 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 588 | assert {init_fn, update_fn} = scale_by_adam() |> scale(1.0e-2) 589 | assert is_function(init_fn, 1) 590 | assert is_function(update_fn, 3) 591 | assert {adam_state} = init_fn.(params) 592 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state 593 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 594 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 595 | assert_equal(count, Nx.tensor(0)) 596 | end 597 | 598 | test "matches optax with simple container" do 599 | assert {init_fn, update_fn} = scale(1.0e-2) 600 | params = %{a: Nx.tensor([0.29887561, 0.70429164, 0.43314898])} 601 | updates = %{a: Nx.tensor([0.2584395, 0.35890494, 0.84845509])} 602 | state = init_fn.(params) 603 | 604 | expected_a = Nx.tensor([0.00258439, 0.00358905, 0.00848455]) 605 | 606 | assert {new_updates, new_state} = update_fn.(updates, state, params) 607 | assert %{a: actual_a} = new_updates 608 | assert new_state == {} 609 | assert_all_close(actual_a, expected_a) 610 | end 611 | 612 | test "matches optax with nested container" do 613 | assert {init_fn, update_fn} = scale(1.0e-2) 614 | 615 | params = %{ 616 | a: %{ 617 | b: Nx.tensor([0.58813851, 0.27981229, 0.17335737]), 618 | c: %{d: %{}, e: Nx.tensor([0.21444265, 0.63923396, 0.12755156])} 619 | } 620 | } 621 | 622 | updates = %{ 623 | a: %{ 624 | b: Nx.tensor([0.48363215, 0.7147937, 0.32252682]), 625 | c: %{d: %{}, e: Nx.tensor([0.09518468, 0.38613084, 0.20729078])} 626 | } 627 | } 628 | 629 | state = init_fn.(params) 630 | 631 | expected_b = Nx.tensor([0.00483632, 0.00714794, 0.00322527]) 632 | expected_e = Nx.tensor([0.00095185, 0.00386131, 0.00207291]) 633 | 634 | assert {new_updates, new_state} = update_fn.(updates, state, params) 635 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 636 | assert new_state == {} 637 | assert_all_close(actual_b, expected_b) 638 | assert_all_close(actual_e, expected_e) 639 | end 640 | 641 | test "supports generic container" do 642 | assert {init_fn, update_fn} = scale(1.0e-2) 643 | 644 | params = { 645 | { 646 | Nx.tensor([0.58813851, 0.27981229, 0.17335737]), 647 | {{}, Nx.tensor([0.21444265, 0.63923396, 0.12755156])} 648 | } 649 | } 650 | 651 | updates = { 652 | { 653 | Nx.tensor([0.48363215, 0.7147937, 0.32252682]), 654 | {{}, Nx.tensor([0.09518468, 0.38613084, 0.20729078])} 655 | } 656 | } 657 | 658 | state = init_fn.(params) 659 | 660 | expected_b = Nx.tensor([0.00483632, 0.00714794, 0.00322527]) 661 | expected_e = Nx.tensor([0.00095185, 0.00386131, 0.00207291]) 662 | 663 | assert {new_updates, new_state} = update_fn.(updates, state, params) 664 | assert 
{{actual_b, {{}, actual_e}}} = new_updates 665 | assert new_state == {} 666 | assert_all_close(actual_b, expected_b) 667 | assert_all_close(actual_e, expected_e) 668 | end 669 | end 670 | 671 | describe "scale_by_state" do 672 | test "constructs a stateful transformation" do 673 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 674 | assert {init_fn, update_fn} = scale_by_state(1.0e-3) 675 | assert is_function(init_fn, 1) 676 | assert is_function(update_fn, 3) 677 | assert {state} = init_fn.(params) 678 | assert %{scale: scale} = state 679 | assert_equal(scale, Nx.tensor(1.0e-3)) 680 | end 681 | 682 | test "composes with itself" do 683 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 684 | assert {init_fn, update_fn} = scale_by_state(1.0e-3) |> scale_by_state(1.0e-2) 685 | assert is_function(init_fn, 1) 686 | assert is_function(update_fn, 3) 687 | assert {state_1, state_2} = init_fn.(params) 688 | assert %{scale: scale_1} = state_1 689 | assert_equal(scale_1, Nx.tensor(1.0e-2)) 690 | assert %{scale: scale_2} = state_2 691 | assert_equal(scale_2, Nx.tensor(1.0e-3)) 692 | end 693 | 694 | test "composes with stateless transformation" do 695 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 696 | assert {init_fn, update_fn} = scale_by_state(1.0e-3) |> scale(1.0e-2) 697 | assert is_function(init_fn, 1) 698 | assert is_function(update_fn, 3) 699 | assert {state} = init_fn.(params) 700 | assert %{scale: scale} = state 701 | assert_equal(scale, Nx.tensor(1.0e-3)) 702 | end 703 | 704 | test "matches optax with simple container" do 705 | assert {init_fn, update_fn} = scale_by_state(1.0e-2) 706 | params = %{a: Nx.tensor([0.29887561, 0.70429164, 0.43314898])} 707 | updates = %{a: Nx.tensor([0.2584395, 0.35890494, 0.84845509])} 708 | state = init_fn.(params) 709 | 710 | expected_a = Nx.tensor([0.00258439, 0.00358905, 0.00848455]) 711 | 712 | assert {new_updates, new_state} = update_fn.(updates, state, params) 713 | assert %{a: actual_a} = new_updates 714 | assert {%{scale: scale}} = new_state 715 | assert_all_close(actual_a, expected_a) 716 | assert_all_close(scale, Nx.tensor(1.0e-2)) 717 | end 718 | 719 | test "matches optax with nested container" do 720 | assert {init_fn, update_fn} = scale_by_state(1.0e-2) 721 | 722 | params = %{ 723 | a: %{ 724 | b: Nx.tensor([0.58813851, 0.27981229, 0.17335737]), 725 | c: %{d: %{}, e: Nx.tensor([0.21444265, 0.63923396, 0.12755156])} 726 | } 727 | } 728 | 729 | updates = %{ 730 | a: %{ 731 | b: Nx.tensor([0.48363215, 0.7147937, 0.32252682]), 732 | c: %{d: %{}, e: Nx.tensor([0.09518468, 0.38613084, 0.20729078])} 733 | } 734 | } 735 | 736 | state = init_fn.(params) 737 | 738 | expected_b = Nx.tensor([0.00483632, 0.00714794, 0.00322527]) 739 | expected_e = Nx.tensor([0.00095185, 0.00386131, 0.00207291]) 740 | 741 | assert {new_updates, new_state} = update_fn.(updates, state, params) 742 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 743 | assert {%{scale: scale}} = new_state 744 | assert_all_close(actual_b, expected_b) 745 | assert_all_close(actual_e, expected_e) 746 | assert_all_close(scale, Nx.tensor(1.0e-2)) 747 | end 748 | 749 | test "supports generic container" do 750 | assert {init_fn, update_fn} = scale_by_state(1.0e-2) 751 | 752 | params = { 753 | { 754 | Nx.tensor([0.58813851, 0.27981229, 0.17335737]), 755 | {{}, Nx.tensor([0.21444265, 0.63923396, 0.12755156])} 756 | } 757 | } 758 | 759 | updates = { 760 | { 761 | Nx.tensor([0.48363215, 0.7147937, 0.32252682]), 762 | {{}, Nx.tensor([0.09518468, 0.38613084, 0.20729078])} 763 | } 764 | } 765 | 766 | 
state = init_fn.(params) 767 | 768 | expected_b = Nx.tensor([0.00483632, 0.00714794, 0.00322527]) 769 | expected_e = Nx.tensor([0.00095185, 0.00386131, 0.00207291]) 770 | 771 | assert {new_updates, new_state} = update_fn.(updates, state, params) 772 | assert {{actual_b, {{}, actual_e}}} = new_updates 773 | assert {%{scale: scale}} = new_state 774 | assert_all_close(actual_b, expected_b) 775 | assert_all_close(actual_e, expected_e) 776 | assert_all_close(scale, Nx.tensor(1.0e-2)) 777 | end 778 | end 779 | 780 | describe "scale_by_adam" do 781 | test "constructs a stateful transformation" do 782 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 783 | assert {init_fn, update_fn} = scale_by_adam() 784 | assert is_function(init_fn, 1) 785 | assert is_function(update_fn, 3) 786 | assert {adam_state} = init_fn.(params) 787 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state 788 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 789 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 790 | assert_equal(count, Nx.tensor(0)) 791 | end 792 | 793 | test "constructs a stateful transformation with options" do 794 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 795 | assert {init_fn, update_fn} = scale_by_adam(b1: 0.5) 796 | assert is_function(init_fn, 1) 797 | assert is_function(update_fn, 3) 798 | assert {adam_state} = init_fn.(params) 799 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state 800 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 801 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 802 | assert_equal(count, Nx.tensor(0)) 803 | end 804 | 805 | test "composes with itself" do 806 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 807 | assert {init_fn, update_fn} = scale_by_adam() |> scale_by_adam() 808 | assert is_function(init_fn, 1) 809 | assert is_function(update_fn, 3) 810 | assert {adam_state_1, adam_state_2} = init_fn.(params) 811 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state_1 812 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 813 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 814 | assert_equal(count, Nx.tensor(0)) 815 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state_2 816 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 817 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 818 | assert_equal(count, Nx.tensor(0)) 819 | end 820 | 821 | test "composes with stateless transformation" do 822 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 823 | assert {init_fn, update_fn} = scale_by_adam() |> scale(1.0e-2) 824 | assert is_function(init_fn, 1) 825 | assert is_function(update_fn, 3) 826 | assert {adam_state} = init_fn.(params) 827 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state 828 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 829 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 830 | assert_equal(count, Nx.tensor(0)) 831 | end 832 | 833 | test "matches optax with simple container" do 834 | assert {init_fn, update_fn} = scale_by_adam() 835 | params = %{a: Nx.tensor([0.29236649, 0.26508023, 0.05959644])} 836 | updates = %{a: Nx.tensor([0.01461005, 0.3796587, 0.76886989])} 837 | state = init_fn.(params) 838 | 839 | expected_a = Nx.tensor([0.99999267, 0.9999933, 0.9999933]) 840 | expected_next_mu_a = Nx.tensor([0.00146101, 0.03796587, 0.07688699]) 841 | expected_next_nu_a = Nx.tensor([2.1345357e-07, 1.4414072e-04, 5.9116090e-04]) 842 | expected_next_count = Nx.tensor(1) 843 | 844 | assert {new_updates, new_state} = update_fn.(updates, state, params) 845 | assert %{a: actual_a} = new_updates 846 | 847 | assert 
{%{mu: %{a: actual_next_mu_a}, nu: %{a: actual_next_nu_a}, count: actual_next_count}} = 848 | new_state 849 | 850 | assert_all_close(actual_a, expected_a) 851 | assert_all_close(actual_next_mu_a, expected_next_mu_a) 852 | assert_all_close(actual_next_nu_a, expected_next_nu_a) 853 | assert_equal(actual_next_count, expected_next_count) 854 | end 855 | 856 | test "matches optax with nested container" do 857 | assert {init_fn, update_fn} = scale_by_adam() 858 | 859 | params = %{ 860 | a: %{ 861 | b: Nx.tensor([0.16028131, 0.82155978, 0.67870557]), 862 | c: %{d: %{}, e: Nx.tensor([[0.42164469, 0.59406027, 0.24703223]])} 863 | } 864 | } 865 | 866 | updates = %{ 867 | a: %{ 868 | b: Nx.tensor([0.37850456, 0.80079877, 0.16309247]), 869 | c: %{d: %{}, e: Nx.tensor([[0.29081831, 0.29872105, 0.48405271]])} 870 | } 871 | } 872 | 873 | state = init_fn.(params) 874 | 875 | expected_b = Nx.tensor([0.9999934, 0.9999933, 0.99999315]) 876 | expected_e = Nx.tensor([[0.9999933, 0.9999933, 0.9999933]]) 877 | expected_next_mu_b = Nx.tensor([0.03785046, 0.08007988, 0.01630925]) 878 | expected_next_mu_e = Nx.tensor([[0.02908183, 0.0298721, 0.04840527]]) 879 | expected_next_nu_b = Nx.tensor([1.4326570e-04, 6.4127869e-04, 2.6599155e-05]) 880 | expected_next_nu_e = Nx.tensor([[8.4575287e-05, 8.9234265e-05, 2.3430702e-04]]) 881 | expected_next_count = Nx.tensor(1) 882 | 883 | assert {new_updates, new_state} = update_fn.(updates, state, params) 884 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 885 | assert {%{mu: new_mu, nu: new_nu, count: actual_next_count}} = new_state 886 | assert %{a: %{b: actual_next_mu_b, c: %{d: %{}, e: actual_next_mu_e}}} = new_mu 887 | assert %{a: %{b: actual_next_nu_b, c: %{d: %{}, e: actual_next_nu_e}}} = new_nu 888 | assert_all_close(actual_b, expected_b) 889 | assert_all_close(actual_e, expected_e) 890 | assert_all_close(actual_next_mu_b, expected_next_mu_b) 891 | assert_all_close(actual_next_mu_e, expected_next_mu_e) 892 | assert_all_close(actual_next_nu_b, expected_next_nu_b) 893 | assert_all_close(actual_next_nu_e, expected_next_nu_e) 894 | assert_equal(actual_next_count, expected_next_count) 895 | end 896 | 897 | test "supports generic container" do 898 | assert {init_fn, update_fn} = scale_by_adam() 899 | 900 | params = { 901 | { 902 | Nx.tensor([0.16028131, 0.82155978, 0.67870557]), 903 | {{}, Nx.tensor([[0.42164469, 0.59406027, 0.24703223]])} 904 | } 905 | } 906 | 907 | updates = { 908 | { 909 | Nx.tensor([0.37850456, 0.80079877, 0.16309247]), 910 | {{}, Nx.tensor([[0.29081831, 0.29872105, 0.48405271]])} 911 | } 912 | } 913 | 914 | state = init_fn.(params) 915 | 916 | expected_b = Nx.tensor([0.9999934, 0.9999933, 0.99999315]) 917 | expected_e = Nx.tensor([[0.9999933, 0.9999933, 0.9999933]]) 918 | expected_next_mu_b = Nx.tensor([0.03785046, 0.08007988, 0.01630925]) 919 | expected_next_mu_e = Nx.tensor([[0.02908183, 0.0298721, 0.04840527]]) 920 | expected_next_nu_b = Nx.tensor([1.4326570e-04, 6.4127869e-04, 2.6599155e-05]) 921 | expected_next_nu_e = Nx.tensor([[8.4575287e-05, 8.9234265e-05, 2.3430702e-04]]) 922 | expected_next_count = Nx.tensor(1) 923 | 924 | assert {new_updates, new_state} = update_fn.(updates, state, params) 925 | assert {{actual_b, {{}, actual_e}}} = new_updates 926 | assert {%{mu: new_mu, nu: new_nu, count: actual_next_count}} = new_state 927 | assert {{actual_next_mu_b, {{}, actual_next_mu_e}}} = new_mu 928 | assert {{actual_next_nu_b, {{}, actual_next_nu_e}}} = new_nu 929 | assert_all_close(actual_b, expected_b) 930 | 
assert_all_close(actual_e, expected_e) 931 | assert_all_close(actual_next_mu_b, expected_next_mu_b) 932 | assert_all_close(actual_next_mu_e, expected_next_mu_e) 933 | assert_all_close(actual_next_nu_b, expected_next_nu_b) 934 | assert_all_close(actual_next_nu_e, expected_next_nu_e) 935 | assert_equal(actual_next_count, expected_next_count) 936 | end 937 | end 938 | 939 | describe "scale_by_belief" do 940 | test "constructs a stateful transformation" do 941 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 942 | assert {init_fn, update_fn} = scale_by_belief() 943 | assert is_function(init_fn, 1) 944 | assert is_function(update_fn, 3) 945 | assert {belief_state} = init_fn.(params) 946 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = belief_state 947 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 948 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 949 | assert_equal(count, Nx.tensor(0)) 950 | end 951 | 952 | test "constructs a stateful transformation with options" do 953 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 954 | assert {init_fn, update_fn} = scale_by_belief(b1: 0.4) 955 | assert is_function(init_fn, 1) 956 | assert is_function(update_fn, 3) 957 | assert {belief_state} = init_fn.(params) 958 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = belief_state 959 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 960 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 961 | assert_equal(count, Nx.tensor(0)) 962 | end 963 | 964 | test "composes with itself" do 965 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 966 | assert {init_fn, update_fn} = scale_by_belief() |> scale_by_belief() 967 | assert is_function(init_fn, 1) 968 | assert is_function(update_fn, 3) 969 | assert {belief_state_1, belief_state_2} = init_fn.(params) 970 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = belief_state_1 971 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 972 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 973 | assert_equal(count, Nx.tensor(0)) 974 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = belief_state_2 975 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 976 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 977 | assert_equal(count, Nx.tensor(0)) 978 | end 979 | 980 | test "composes with stateless transformation" do 981 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 982 | assert {init_fn, update_fn} = scale_by_belief() |> scale(1.0e-2) 983 | assert is_function(init_fn, 1) 984 | assert is_function(update_fn, 3) 985 | assert {belief_state} = init_fn.(params) 986 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = belief_state 987 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 988 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 989 | assert_equal(count, Nx.tensor(0)) 990 | end 991 | 992 | test "matches optax with simple container" do 993 | assert {init_fn, update_fn} = scale_by_belief() 994 | params = %{a: Nx.tensor([0.35582285, 0.02904734, 0.8684706])} 995 | updates = %{a: Nx.tensor([0.64641294, 0.19990149, 0.54263212])} 996 | state = init_fn.(params) 997 | 998 | expected_a = Nx.tensor([0.9999934, 0.99999326, 0.9999933]) 999 | expected_next_mu_a = Nx.tensor([0.0646413, 0.01999015, 0.05426321]) 1000 | expected_next_nu_a = Nx.tensor([4.1784969e-04, 3.9960611e-05, 2.9444962e-04]) 1001 | expected_next_count = Nx.tensor(1) 1002 | 1003 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1004 | assert %{a: actual_a} = new_updates 1005 | 1006 | assert {%{mu: %{a: actual_next_mu_a}, nu: %{a: actual_next_nu_a}, count: actual_next_count}} = 1007 | new_state 
1008 | 1009 | assert_all_close(actual_a, expected_a) 1010 | assert_all_close(actual_next_mu_a, expected_next_mu_a) 1011 | assert_all_close(actual_next_nu_a, expected_next_nu_a) 1012 | assert_equal(actual_next_count, expected_next_count) 1013 | end 1014 | 1015 | test "matches optax with nested container" do 1016 | assert {init_fn, update_fn} = scale_by_belief() 1017 | 1018 | params = %{ 1019 | a: %{ 1020 | b: Nx.tensor([0.48266117, 0.21594939, 0.25310925]), 1021 | c: %{d: %{}, e: Nx.tensor([[0.08780911, 0.25273182, 0.02973737]])} 1022 | } 1023 | } 1024 | 1025 | updates = %{ 1026 | a: %{ 1027 | b: Nx.tensor([0.15456417, 0.03338711, 0.47241908]), 1028 | c: %{d: %{}, e: Nx.tensor([[0.76352976, 0.86033023, 0.22758512]])} 1029 | } 1030 | } 1031 | 1032 | state = init_fn.(params) 1033 | 1034 | expected_b = Nx.tensor([0.9999933, 0.9999933, 0.99999326]) 1035 | expected_e = Nx.tensor([[0.9999934, 0.99999326, 0.9999933]]) 1036 | expected_next_mu_b = Nx.tensor([0.01545642, 0.00333871, 0.04724191]) 1037 | expected_next_mu_e = Nx.tensor([[0.07635298, 0.08603302, 0.02275851]]) 1038 | expected_next_nu_b = Nx.tensor([2.3890085e-05, 1.1146991e-06, 2.2317980e-04]) 1039 | expected_next_nu_e = Nx.tensor([[5.8297772e-04, 7.4016815e-04, 5.1794988e-05]]) 1040 | expected_next_count = Nx.tensor(1) 1041 | 1042 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1043 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 1044 | assert {%{mu: new_mu, nu: new_nu, count: actual_next_count}} = new_state 1045 | assert %{a: %{b: actual_next_mu_b, c: %{d: %{}, e: actual_next_mu_e}}} = new_mu 1046 | assert %{a: %{b: actual_next_nu_b, c: %{d: %{}, e: actual_next_nu_e}}} = new_nu 1047 | assert_all_close(actual_b, expected_b) 1048 | assert_all_close(actual_e, expected_e) 1049 | assert_all_close(actual_next_mu_b, expected_next_mu_b) 1050 | assert_all_close(actual_next_mu_e, expected_next_mu_e) 1051 | assert_all_close(actual_next_nu_b, expected_next_nu_b) 1052 | assert_all_close(actual_next_nu_e, expected_next_nu_e) 1053 | assert_equal(actual_next_count, expected_next_count) 1054 | end 1055 | 1056 | test "supports generic container" do 1057 | assert {init_fn, update_fn} = scale_by_belief() 1058 | 1059 | params = { 1060 | { 1061 | Nx.tensor([0.48266117, 0.21594939, 0.25310925]), 1062 | {{}, Nx.tensor([[0.08780911, 0.25273182, 0.02973737]])} 1063 | } 1064 | } 1065 | 1066 | updates = { 1067 | { 1068 | Nx.tensor([0.15456417, 0.03338711, 0.47241908]), 1069 | {{}, Nx.tensor([[0.76352976, 0.86033023, 0.22758512]])} 1070 | } 1071 | } 1072 | 1073 | state = init_fn.(params) 1074 | 1075 | expected_b = Nx.tensor([0.9999933, 0.9999933, 0.99999326]) 1076 | expected_e = Nx.tensor([[0.9999934, 0.99999326, 0.9999933]]) 1077 | expected_next_mu_b = Nx.tensor([0.01545642, 0.00333871, 0.04724191]) 1078 | expected_next_mu_e = Nx.tensor([[0.07635298, 0.08603302, 0.02275851]]) 1079 | expected_next_nu_b = Nx.tensor([2.3890085e-05, 1.1146991e-06, 2.2317980e-04]) 1080 | expected_next_nu_e = Nx.tensor([[5.8297772e-04, 7.4016815e-04, 5.1794988e-05]]) 1081 | expected_next_count = Nx.tensor(1) 1082 | 1083 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1084 | assert {{actual_b, {{}, actual_e}}} = new_updates 1085 | assert {%{mu: new_mu, nu: new_nu, count: actual_next_count}} = new_state 1086 | assert {{actual_next_mu_b, {{}, actual_next_mu_e}}} = new_mu 1087 | assert {{actual_next_nu_b, {{}, actual_next_nu_e}}} = new_nu 1088 | assert_all_close(actual_b, expected_b) 1089 | assert_all_close(actual_e, 
expected_e) 1090 | assert_all_close(actual_next_mu_b, expected_next_mu_b) 1091 | assert_all_close(actual_next_mu_e, expected_next_mu_e) 1092 | assert_all_close(actual_next_nu_b, expected_next_nu_b) 1093 | assert_all_close(actual_next_nu_e, expected_next_nu_e) 1094 | assert_equal(actual_next_count, expected_next_count) 1095 | end 1096 | end 1097 | 1098 | describe "scale_by_radam" do 1099 | test "constructs a stateful transformation" do 1100 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1101 | assert {init_fn, update_fn} = scale_by_radam() 1102 | assert is_function(init_fn, 1) 1103 | assert is_function(update_fn, 3) 1104 | assert {adam_state} = init_fn.(params) 1105 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state 1106 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 1107 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 1108 | assert_equal(count, Nx.tensor(0)) 1109 | end 1110 | 1111 | test "constructs a stateful transformation with options" do 1112 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1113 | assert {init_fn, update_fn} = scale_by_radam(b1: 0.5) 1114 | assert is_function(init_fn, 1) 1115 | assert is_function(update_fn, 3) 1116 | assert {adam_state} = init_fn.(params) 1117 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state 1118 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 1119 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 1120 | assert_equal(count, Nx.tensor(0)) 1121 | end 1122 | 1123 | test "composes with itself" do 1124 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1125 | assert {init_fn, update_fn} = scale_by_radam() |> scale_by_radam() 1126 | assert is_function(init_fn, 1) 1127 | assert is_function(update_fn, 3) 1128 | assert {adam_state_1, adam_state_2} = init_fn.(params) 1129 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state_1 1130 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 1131 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 1132 | assert_equal(count, Nx.tensor(0)) 1133 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state_2 1134 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 1135 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 1136 | assert_equal(count, Nx.tensor(0)) 1137 | end 1138 | 1139 | test "composes with stateless transformation" do 1140 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1141 | assert {init_fn, update_fn} = scale_by_radam() |> scale(1.0e-2) 1142 | assert is_function(init_fn, 1) 1143 | assert is_function(update_fn, 3) 1144 | assert {adam_state} = init_fn.(params) 1145 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state 1146 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 1147 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 1148 | assert_equal(count, Nx.tensor(0)) 1149 | end 1150 | 1151 | test "matches optax with simple container" do 1152 | assert {init_fn, update_fn} = scale_by_radam() 1153 | params = %{a: Nx.tensor([0.71289699, 0.29554161, 0.50779425])} 1154 | updates = %{a: Nx.tensor([0.88675452, 0.21455035, 0.53581422])} 1155 | state = init_fn.(params) 1156 | 1157 | expected_a = Nx.tensor([0.88675433, 0.2145503, 0.53581405]) 1158 | expected_next_mu_a = Nx.tensor([0.08867545, 0.02145503, 0.05358142]) 1159 | expected_next_nu_a = Nx.tensor([7.863336e-04, 4.603185e-05, 2.870969e-04]) 1160 | expected_next_count = Nx.tensor(1) 1161 | 1162 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1163 | assert %{a: actual_a} = new_updates 1164 | 1165 | assert {%{mu: %{a: actual_next_mu_a}, nu: %{a: actual_next_nu_a}, count: actual_next_count}} = 
1166 | new_state 1167 | 1168 | assert_all_close(actual_a, expected_a) 1169 | assert_all_close(actual_next_mu_a, expected_next_mu_a) 1170 | assert_all_close(actual_next_nu_a, expected_next_nu_a) 1171 | assert_equal(actual_next_count, expected_next_count) 1172 | end 1173 | 1174 | test "matches optax with nested container" do 1175 | assert {init_fn, update_fn} = scale_by_radam() 1176 | 1177 | params = %{ 1178 | a: %{ 1179 | b: Nx.tensor([0.72504156, 0.86982723, 0.58679938]), 1180 | c: %{d: %{}, e: Nx.tensor([[0.26001513, 0.62556789, 0.29528421]])} 1181 | } 1182 | } 1183 | 1184 | updates = %{ 1185 | a: %{ 1186 | b: Nx.tensor([0.01536453, 0.61977439, 0.561842]), 1187 | c: %{d: %{}, e: Nx.tensor([[0.03755132, 0.80392208, 0.87391938]])} 1188 | } 1189 | } 1190 | 1191 | state = init_fn.(params) 1192 | 1193 | expected_b = Nx.tensor([0.01536453, 0.6197742, 0.56184185]) 1194 | expected_e = Nx.tensor([[0.03755131, 0.8039219, 0.8739191]]) 1195 | expected_next_mu_b = Nx.tensor([0.00153645, 0.06197744, 0.0561842]) 1196 | expected_next_mu_e = Nx.tensor([[0.00375513, 0.0803922, 0.08739194]]) 1197 | expected_next_nu_b = Nx.tensor([2.3606893e-07, 3.8412030e-04, 3.1566643e-04]) 1198 | expected_next_nu_e = Nx.tensor([[1.4101014e-06, 6.4629072e-04, 7.6373509e-04]]) 1199 | expected_next_count = Nx.tensor(1) 1200 | 1201 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1202 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 1203 | assert {%{mu: new_mu, nu: new_nu, count: actual_next_count}} = new_state 1204 | assert %{a: %{b: actual_next_mu_b, c: %{d: %{}, e: actual_next_mu_e}}} = new_mu 1205 | assert %{a: %{b: actual_next_nu_b, c: %{d: %{}, e: actual_next_nu_e}}} = new_nu 1206 | assert_all_close(actual_b, expected_b) 1207 | assert_all_close(actual_e, expected_e) 1208 | assert_all_close(actual_next_mu_b, expected_next_mu_b) 1209 | assert_all_close(actual_next_mu_e, expected_next_mu_e) 1210 | assert_all_close(actual_next_nu_b, expected_next_nu_b) 1211 | assert_all_close(actual_next_nu_e, expected_next_nu_e) 1212 | assert_equal(actual_next_count, expected_next_count) 1213 | end 1214 | 1215 | test "supports generic container" do 1216 | assert {init_fn, update_fn} = scale_by_radam() 1217 | 1218 | params = { 1219 | { 1220 | Nx.tensor([0.72504156, 0.86982723, 0.58679938]), 1221 | {{}, Nx.tensor([[0.26001513, 0.62556789, 0.29528421]])} 1222 | } 1223 | } 1224 | 1225 | updates = { 1226 | { 1227 | Nx.tensor([0.01536453, 0.61977439, 0.561842]), 1228 | {{}, Nx.tensor([[0.03755132, 0.80392208, 0.87391938]])} 1229 | } 1230 | } 1231 | 1232 | state = init_fn.(params) 1233 | 1234 | expected_b = Nx.tensor([0.01536453, 0.6197742, 0.56184185]) 1235 | expected_e = Nx.tensor([[0.03755131, 0.8039219, 0.8739191]]) 1236 | expected_next_mu_b = Nx.tensor([0.00153645, 0.06197744, 0.0561842]) 1237 | expected_next_mu_e = Nx.tensor([[0.00375513, 0.0803922, 0.08739194]]) 1238 | expected_next_nu_b = Nx.tensor([2.3606893e-07, 3.8412030e-04, 3.1566643e-04]) 1239 | expected_next_nu_e = Nx.tensor([[1.4101014e-06, 6.4629072e-04, 7.6373509e-04]]) 1240 | expected_next_count = Nx.tensor(1) 1241 | 1242 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1243 | assert {{actual_b, {{}, actual_e}}} = new_updates 1244 | assert {%{mu: new_mu, nu: new_nu, count: actual_next_count}} = new_state 1245 | assert {{actual_next_mu_b, {{}, actual_next_mu_e}}} = new_mu 1246 | assert {{actual_next_nu_b, {{}, actual_next_nu_e}}} = new_nu 1247 | assert_all_close(actual_b, expected_b) 1248 | 
assert_all_close(actual_e, expected_e) 1249 | assert_all_close(actual_next_mu_b, expected_next_mu_b) 1250 | assert_all_close(actual_next_mu_e, expected_next_mu_e) 1251 | assert_all_close(actual_next_nu_b, expected_next_nu_b) 1252 | assert_all_close(actual_next_nu_e, expected_next_nu_e) 1253 | assert_equal(actual_next_count, expected_next_count) 1254 | end 1255 | end 1256 | 1257 | describe "scale_by_rms" do 1258 | test "constructs a stateful transformation" do 1259 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1260 | assert {init_fn, update_fn} = scale_by_rms() 1261 | assert is_function(init_fn, 1) 1262 | assert is_function(update_fn, 3) 1263 | assert {rms_state} = init_fn.(params) 1264 | assert %{nu: %{a: nu_a}} = rms_state 1265 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 1266 | end 1267 | 1268 | test "constructs a stateful transformation with options" do 1269 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1270 | assert {init_fn, update_fn} = scale_by_rms(initial_scale: 0.1) 1271 | assert is_function(init_fn, 1) 1272 | assert is_function(update_fn, 3) 1273 | assert {rms_state} = init_fn.(params) 1274 | assert %{nu: %{a: nu_a}} = rms_state 1275 | assert_equal(nu_a, Nx.tensor([0.1, 0.1, 0.1])) 1276 | end 1277 | 1278 | test "composes with itself" do 1279 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1280 | assert {init_fn, update_fn} = scale_by_rms() |> scale_by_rms() 1281 | assert is_function(init_fn, 1) 1282 | assert is_function(update_fn, 3) 1283 | assert {rms_state_1, rms_state_2} = init_fn.(params) 1284 | assert %{nu: %{a: nu_a}} = rms_state_1 1285 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 1286 | assert %{nu: %{a: nu_a}} = rms_state_2 1287 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 1288 | end 1289 | 1290 | test "composes with stateless transformation" do 1291 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1292 | assert {init_fn, update_fn} = scale_by_rms() |> scale(1.0e-2) 1293 | assert is_function(init_fn, 1) 1294 | assert is_function(update_fn, 3) 1295 | assert {rms_state} = init_fn.(params) 1296 | assert %{nu: %{a: nu_a}} = rms_state 1297 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 1298 | end 1299 | 1300 | test "matches optax with simple container" do 1301 | assert {init_fn, update_fn} = scale_by_rms() 1302 | params = %{a: Nx.tensor([0.77100057, 0.98078091, 0.78499164])} 1303 | updates = %{a: Nx.tensor([0.25156708, 0.30524656, 0.97350756])} 1304 | state = init_fn.(params) 1305 | 1306 | expected_a = Nx.tensor([3.162275, 3.162276, 3.1622777]) 1307 | expected_next_nu_a = Nx.tensor([0.0063286, 0.00931755, 0.0947717]) 1308 | 1309 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1310 | assert %{a: actual_a} = new_updates 1311 | assert {%{nu: %{a: actual_next_nu_a}}} = new_state 1312 | assert_all_close(actual_a, expected_a) 1313 | assert_all_close(expected_next_nu_a, actual_next_nu_a) 1314 | end 1315 | 1316 | test "matches optax with nested container" do 1317 | assert {init_fn, update_fn} = scale_by_rms() 1318 | 1319 | params = %{ 1320 | a: %{ 1321 | b: Nx.tensor([0.0553049, 0.21828064, 0.98751916]), 1322 | c: %{d: %{}, e: Nx.tensor([[0.17757973, 0.67966022, 0.19382288]])} 1323 | } 1324 | } 1325 | 1326 | updates = %{ 1327 | a: %{ 1328 | b: Nx.tensor([0.61220327, 0.73535765, 0.42179138]), 1329 | c: %{d: %{}, e: Nx.tensor([[0.39331236, 0.27389305, 0.30131908]])} 1330 | } 1331 | } 1332 | 1333 | state = init_fn.(params) 1334 | 1335 | expected_b = Nx.tensor([3.1622772, 3.1622772, 3.162277]) 1336 | expected_e = Nx.tensor([[3.1622767, 3.1622758, 3.162276]]) 1337 | 
expected_next_nu_b = Nx.tensor([0.03747929, 0.05407509, 0.0177908]) 1338 | expected_next_nu_e = Nx.tensor([[0.01546946, 0.00750174, 0.00907932]]) 1339 | 1340 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1341 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 1342 | assert {%{nu: new_nu}} = new_state 1343 | assert %{a: %{b: actual_next_nu_b, c: %{d: %{}, e: actual_next_nu_e}}} = new_nu 1344 | assert_all_close(actual_b, expected_b) 1345 | assert_all_close(actual_e, expected_e) 1346 | assert_all_close(actual_next_nu_b, expected_next_nu_b) 1347 | assert_all_close(actual_next_nu_e, expected_next_nu_e) 1348 | end 1349 | 1350 | test "supports generic container" do 1351 | assert {init_fn, update_fn} = scale_by_rms() 1352 | 1353 | params = { 1354 | { 1355 | Nx.tensor([0.0553049, 0.21828064, 0.98751916]), 1356 | {{}, Nx.tensor([[0.17757973, 0.67966022, 0.19382288]])} 1357 | } 1358 | } 1359 | 1360 | updates = { 1361 | { 1362 | Nx.tensor([0.61220327, 0.73535765, 0.42179138]), 1363 | {{}, Nx.tensor([[0.39331236, 0.27389305, 0.30131908]])} 1364 | } 1365 | } 1366 | 1367 | state = init_fn.(params) 1368 | 1369 | expected_b = Nx.tensor([3.1622772, 3.1622772, 3.162277]) 1370 | expected_e = Nx.tensor([[3.1622767, 3.1622758, 3.162276]]) 1371 | expected_next_nu_b = Nx.tensor([0.03747929, 0.05407509, 0.0177908]) 1372 | expected_next_nu_e = Nx.tensor([[0.01546946, 0.00750174, 0.00907932]]) 1373 | 1374 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1375 | assert {{actual_b, {{}, actual_e}}} = new_updates 1376 | assert {%{nu: new_nu}} = new_state 1377 | assert {{actual_next_nu_b, {{}, actual_next_nu_e}}} = new_nu 1378 | assert_all_close(actual_b, expected_b) 1379 | assert_all_close(actual_e, expected_e) 1380 | assert_all_close(actual_next_nu_b, expected_next_nu_b) 1381 | assert_all_close(actual_next_nu_e, expected_next_nu_e) 1382 | end 1383 | end 1384 | 1385 | describe "scale_by_rss" do 1386 | test "constructs a stateful transformation" do 1387 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1388 | assert {init_fn, update_fn} = scale_by_rss() 1389 | assert is_function(init_fn, 1) 1390 | assert is_function(update_fn, 3) 1391 | assert {rss_state} = init_fn.(params) 1392 | assert %{sum_of_squares: %{a: sum_of_squares_a}} = rss_state 1393 | assert_equal(sum_of_squares_a, Nx.tensor([0.1, 0.1, 0.1])) 1394 | end 1395 | 1396 | test "constructs a stateful transformation with options" do 1397 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1398 | assert {init_fn, update_fn} = scale_by_rss(initial_accumulator_value: 0.2) 1399 | assert is_function(init_fn, 1) 1400 | assert is_function(update_fn, 3) 1401 | assert {rss_state} = init_fn.(params) 1402 | assert %{sum_of_squares: %{a: sum_of_squares_a}} = rss_state 1403 | assert_equal(sum_of_squares_a, Nx.tensor([0.2, 0.2, 0.2])) 1404 | end 1405 | 1406 | test "composes with itself" do 1407 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1408 | assert {init_fn, update_fn} = scale_by_rss() |> scale_by_rss() 1409 | assert is_function(init_fn, 1) 1410 | assert is_function(update_fn, 3) 1411 | assert {rss_state_1, rss_state_2} = init_fn.(params) 1412 | assert %{sum_of_squares: %{a: sum_of_squares_a}} = rss_state_1 1413 | assert_equal(sum_of_squares_a, Nx.tensor([0.1, 0.1, 0.1])) 1414 | assert %{sum_of_squares: %{a: sum_of_squares_a}} = rss_state_2 1415 | assert_equal(sum_of_squares_a, Nx.tensor([0.1, 0.1, 0.1])) 1416 | end 1417 | 1418 | test "composes with stateless transformation" do 1419 | params = %{a: Nx.tensor([1.0, 2.0, 
3.0])} 1420 | assert {init_fn, update_fn} = scale_by_rss() |> scale(1.0e-2) 1421 | assert is_function(init_fn, 1) 1422 | assert is_function(update_fn, 3) 1423 | assert {rss_state} = init_fn.(params) 1424 | assert %{sum_of_squares: %{a: sum_of_squares_a}} = rss_state 1425 | assert_equal(sum_of_squares_a, Nx.tensor([0.1, 0.1, 0.1])) 1426 | end 1427 | 1428 | test "matches optax with simple container" do 1429 | assert {init_fn, update_fn} = scale_by_rss() 1430 | params = %{a: Nx.tensor([0.41327447, 0.06948837, 0.03234601])} 1431 | updates = %{a: Nx.tensor([0.2137085, 0.84399692, 0.63099467])} 1432 | state = init_fn.(params) 1433 | 1434 | expected_a = Nx.tensor([0.55993116, 0.93642795, 0.89401275]) 1435 | expected_next_sum_of_squares_a = Nx.tensor([0.14567132, 0.81233084, 0.49815428]) 1436 | 1437 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1438 | assert %{a: actual_a} = new_updates 1439 | assert {%{sum_of_squares: %{a: actual_next_sum_of_squares_a}}} = new_state 1440 | assert_all_close(actual_a, expected_a) 1441 | assert_all_close(actual_next_sum_of_squares_a, expected_next_sum_of_squares_a) 1442 | end 1443 | 1444 | test "matches optax with nested container" do 1445 | assert {init_fn, update_fn} = scale_by_rss() 1446 | 1447 | params = %{ 1448 | a: %{ 1449 | b: Nx.tensor([0.92084601, 0.27218277, 0.56501597]), 1450 | c: %{d: %{}, e: Nx.tensor([[0.92937211, 0.44536295, 0.95296635]])} 1451 | } 1452 | } 1453 | 1454 | updates = %{ 1455 | a: %{ 1456 | b: Nx.tensor([0.79292352, 0.11484326, 0.84693855]), 1457 | c: %{d: %{}, e: Nx.tensor([[0.13715272, 0.63276641, 0.5234425]])} 1458 | } 1459 | } 1460 | 1461 | state = init_fn.(params) 1462 | 1463 | expected_b = Nx.tensor([0.92885643, 0.34135267, 0.9368279]) 1464 | expected_e = Nx.tensor([[0.39790204, 0.894515, 0.855929]]) 1465 | expected_next_sum_of_squares_b = Nx.tensor([0.72872776, 0.11318897, 0.8173049]) 1466 | expected_next_sum_of_squares_e = Nx.tensor([[0.11881087, 0.50039333, 0.37399206]]) 1467 | 1468 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1469 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 1470 | assert {%{sum_of_squares: new_sum_of_squares}} = new_state 1471 | 1472 | assert %{ 1473 | a: %{ 1474 | b: actual_next_sum_of_squares_b, 1475 | c: %{d: %{}, e: actual_next_sum_of_squares_e} 1476 | } 1477 | } = new_sum_of_squares 1478 | 1479 | assert_all_close(actual_b, expected_b) 1480 | assert_all_close(actual_e, expected_e) 1481 | assert_all_close(actual_next_sum_of_squares_b, expected_next_sum_of_squares_b) 1482 | assert_all_close(actual_next_sum_of_squares_e, expected_next_sum_of_squares_e) 1483 | end 1484 | 1485 | test "supports generic container" do 1486 | assert {init_fn, update_fn} = scale_by_rss() 1487 | 1488 | params = { 1489 | { 1490 | Nx.tensor([0.92084601, 0.27218277, 0.56501597]), 1491 | {{}, Nx.tensor([[0.92937211, 0.44536295, 0.95296635]])} 1492 | } 1493 | } 1494 | 1495 | updates = { 1496 | { 1497 | Nx.tensor([0.79292352, 0.11484326, 0.84693855]), 1498 | {{}, Nx.tensor([[0.13715272, 0.63276641, 0.5234425]])} 1499 | } 1500 | } 1501 | 1502 | state = init_fn.(params) 1503 | 1504 | expected_b = Nx.tensor([0.92885643, 0.34135267, 0.9368279]) 1505 | expected_e = Nx.tensor([[0.39790204, 0.894515, 0.855929]]) 1506 | expected_next_sum_of_squares_b = Nx.tensor([0.72872776, 0.11318897, 0.8173049]) 1507 | expected_next_sum_of_squares_e = Nx.tensor([[0.11881087, 0.50039333, 0.37399206]]) 1508 | 1509 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1510 
| assert {{actual_b, {{}, actual_e}}} = new_updates 1511 | assert {%{sum_of_squares: new_sum_of_squares}} = new_state 1512 | 1513 | assert { 1514 | { 1515 | actual_next_sum_of_squares_b, 1516 | {{}, actual_next_sum_of_squares_e} 1517 | } 1518 | } = new_sum_of_squares 1519 | 1520 | assert_all_close(actual_b, expected_b) 1521 | assert_all_close(actual_e, expected_e) 1522 | assert_all_close(actual_next_sum_of_squares_b, expected_next_sum_of_squares_b) 1523 | assert_all_close(actual_next_sum_of_squares_e, expected_next_sum_of_squares_e) 1524 | end 1525 | end 1526 | 1527 | describe "scale_by_schedule" do 1528 | test "constructs a stateful transformation" do 1529 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1530 | assert {init_fn, update_fn} = scale_by_schedule(Polaris.Schedules.polynomial_decay(1.0e-2)) 1531 | assert is_function(init_fn, 1) 1532 | assert is_function(update_fn, 3) 1533 | assert {schedule_state} = init_fn.(params) 1534 | assert %{count: count} = schedule_state 1535 | assert_equal(count, Nx.tensor(0)) 1536 | end 1537 | 1538 | test "composes with itself" do 1539 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1540 | 1541 | assert {init_fn, update_fn} = 1542 | scale_by_schedule(Polaris.Schedules.polynomial_decay(1.0e-2)) 1543 | |> scale_by_schedule(Polaris.Schedules.polynomial_decay(1.0e-2)) 1544 | 1545 | assert is_function(init_fn, 1) 1546 | assert is_function(update_fn, 3) 1547 | assert {schedule_state_2, schedule_state_1} = init_fn.(params) 1548 | assert %{count: count} = schedule_state_1 1549 | assert_equal(count, Nx.tensor(0)) 1550 | assert %{count: count} = schedule_state_2 1551 | assert_equal(count, Nx.tensor(0)) 1552 | end 1553 | 1554 | test "composes with stateless transformation" do 1555 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1556 | 1557 | assert {init_fn, update_fn} = 1558 | scale_by_schedule(Polaris.Schedules.polynomial_decay(1.0e-2)) |> scale(1.0e-2) 1559 | 1560 | assert is_function(init_fn, 1) 1561 | assert is_function(update_fn, 3) 1562 | assert {schedule_state} = init_fn.(params) 1563 | assert %{count: count} = schedule_state 1564 | assert_equal(count, Nx.tensor(0)) 1565 | end 1566 | 1567 | test "matches optax with simple container" do 1568 | assert {init_fn, update_fn} = scale_by_schedule(Polaris.Schedules.polynomial_decay(1.0e-2)) 1569 | params = %{a: Nx.tensor([0.77425031, 0.65418105, 0.86150202])} 1570 | updates = %{a: Nx.tensor([0.56082198, 0.94549107, 0.54412826])} 1571 | state = init_fn.(params) 1572 | 1573 | expected_a = Nx.tensor([0.00560822, 0.00945491, 0.00544128]) 1574 | expected_next_count = Nx.tensor(1) 1575 | 1576 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1577 | assert %{a: actual_a} = new_updates 1578 | assert {%{count: actual_next_count}} = new_state 1579 | assert_all_close(actual_a, expected_a) 1580 | assert_equal(actual_next_count, expected_next_count) 1581 | end 1582 | 1583 | test "matches optax with nested container" do 1584 | assert {init_fn, update_fn} = scale_by_schedule(Polaris.Schedules.polynomial_decay(1.0e-2)) 1585 | 1586 | params = %{ 1587 | a: %{ 1588 | b: Nx.tensor([0.3440084, 0.16096481, 0.43997161]), 1589 | c: %{d: %{}, e: Nx.tensor([[0.26168961, 0.40905451, 0.3061841]])} 1590 | } 1591 | } 1592 | 1593 | updates = %{ 1594 | a: %{ 1595 | b: Nx.tensor([0.27159927, 0.37657519, 0.38219061]), 1596 | c: %{d: %{}, e: Nx.tensor([[0.9613661, 0.30215168, 0.24110271]])} 1597 | } 1598 | } 1599 | 1600 | state = init_fn.(params) 1601 | 1602 | expected_b = Nx.tensor([0.00271599, 0.00376575, 0.00382191]) 1603 | expected_e 
= Nx.tensor([[0.00961366, 0.00302152, 0.00241103]]) 1604 | expected_next_count = Nx.tensor(1) 1605 | 1606 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1607 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 1608 | assert {%{count: actual_next_count}} = new_state 1609 | assert_all_close(actual_b, expected_b) 1610 | assert_all_close(actual_e, expected_e) 1611 | assert_equal(actual_next_count, expected_next_count) 1612 | end 1613 | 1614 | test "supports generic container" do 1615 | assert {init_fn, update_fn} = scale_by_schedule(Polaris.Schedules.polynomial_decay(1.0e-2)) 1616 | 1617 | params = { 1618 | { 1619 | Nx.tensor([0.3440084, 0.16096481, 0.43997161]), 1620 | {{}, Nx.tensor([[0.26168961, 0.40905451, 0.3061841]])} 1621 | } 1622 | } 1623 | 1624 | updates = { 1625 | { 1626 | Nx.tensor([0.27159927, 0.37657519, 0.38219061]), 1627 | {{}, Nx.tensor([[0.9613661, 0.30215168, 0.24110271]])} 1628 | } 1629 | } 1630 | 1631 | state = init_fn.(params) 1632 | 1633 | expected_b = Nx.tensor([0.00271599, 0.00376575, 0.00382191]) 1634 | expected_e = Nx.tensor([[0.00961366, 0.00302152, 0.00241103]]) 1635 | expected_next_count = Nx.tensor(1) 1636 | 1637 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1638 | assert {{actual_b, {{}, actual_e}}} = new_updates 1639 | assert {%{count: actual_next_count}} = new_state 1640 | assert_all_close(actual_b, expected_b) 1641 | assert_all_close(actual_e, expected_e) 1642 | assert_equal(actual_next_count, expected_next_count) 1643 | end 1644 | end 1645 | 1646 | describe "scale_by_stddev" do 1647 | test "constructs a stateful transformation" do 1648 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1649 | assert {init_fn, update_fn} = scale_by_stddev() 1650 | assert is_function(init_fn, 1) 1651 | assert is_function(update_fn, 3) 1652 | assert {stddev_state} = init_fn.(params) 1653 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}} = stddev_state 1654 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 1655 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 1656 | end 1657 | 1658 | test "constructs a stateful transformation with options" do 1659 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1660 | assert {init_fn, update_fn} = scale_by_stddev(initial_scale: 0.5) 1661 | assert is_function(init_fn, 1) 1662 | assert is_function(update_fn, 3) 1663 | assert {stddev_state} = init_fn.(params) 1664 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}} = stddev_state 1665 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 1666 | assert_equal(nu_a, Nx.tensor([0.5, 0.5, 0.5])) 1667 | end 1668 | 1669 | test "composes with itself" do 1670 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1671 | 1672 | assert {init_fn, update_fn} = 1673 | scale_by_stddev(initial_scale: 0.1) |> scale_by_stddev(initial_scale: 0.2) 1674 | 1675 | assert is_function(init_fn, 1) 1676 | assert is_function(update_fn, 3) 1677 | assert {stddev_state_2, stddev_state_1} = init_fn.(params) 1678 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}} = stddev_state_1 1679 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 1680 | assert_equal(nu_a, Nx.tensor([0.1, 0.1, 0.1])) 1681 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}} = stddev_state_2 1682 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 1683 | assert_equal(nu_a, Nx.tensor([0.2, 0.2, 0.2])) 1684 | end 1685 | 1686 | test "composes with stateless transformation" do 1687 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1688 | assert {init_fn, update_fn} = scale_by_stddev(initial_scale: 0.1) |> scale(1.0e-2) 1689 | assert is_function(init_fn, 1) 1690 | assert 
is_function(update_fn, 3) 1691 | assert {stddev_state} = init_fn.(params) 1692 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}} = stddev_state 1693 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 1694 | assert_equal(nu_a, Nx.tensor([0.1, 0.1, 0.1])) 1695 | end 1696 | 1697 | test "matches optax with simple container" do 1698 | assert {init_fn, update_fn} = scale_by_stddev() 1699 | params = %{a: Nx.tensor([0.98013234, 0.0653057, 0.39361905])} 1700 | updates = %{a: Nx.tensor([0.58050587, 0.04869076, 0.62340991])} 1701 | state = init_fn.(params) 1702 | 1703 | expected_a = Nx.tensor([3.3333325, 3.333255, 3.3333328]) 1704 | expected_next_mu_a = Nx.tensor([0.05805059, 0.00486908, 0.06234099]) 1705 | expected_next_nu_a = Nx.tensor([0.03369871, 0.00023708, 0.03886399]) 1706 | 1707 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1708 | assert %{a: actual_a} = new_updates 1709 | assert {%{mu: %{a: actual_next_mu_a}, nu: %{a: actual_next_nu_a}}} = new_state 1710 | assert_all_close(actual_a, expected_a) 1711 | assert_all_close(actual_next_mu_a, expected_next_mu_a) 1712 | assert_all_close(actual_next_nu_a, expected_next_nu_a) 1713 | end 1714 | 1715 | test "matches optax with nested container" do 1716 | assert {init_fn, update_fn} = scale_by_stddev() 1717 | 1718 | params = %{ 1719 | a: %{ 1720 | b: Nx.tensor([0.49792875, 0.04941673, 0.33815839]), 1721 | c: %{d: %{}, e: Nx.tensor([[0.70057761, 0.3689184, 0.36608007]])} 1722 | } 1723 | } 1724 | 1725 | updates = %{ 1726 | a: %{ 1727 | b: Nx.tensor([0.54587409, 0.04849768, 0.23020724]), 1728 | c: %{d: %{}, e: Nx.tensor([[0.29348535, 0.79428645, 0.76129383]])} 1729 | } 1730 | } 1731 | 1732 | state = init_fn.(params) 1733 | 1734 | expected_b = Nx.tensor([3.333333, 3.3332546, 3.33333]) 1735 | expected_e = Nx.tensor([[3.333331, 3.333333, 3.333333]]) 1736 | expected_next_mu_b = Nx.tensor([0.05458741, 0.00484977, 0.02302072]) 1737 | expected_next_mu_e = Nx.tensor([[0.02934854, 0.07942864, 0.07612938]]) 1738 | expected_next_nu_b = Nx.tensor([0.02979785, 0.0002352, 0.00529954]) 1739 | expected_next_nu_e = Nx.tensor([[0.00861336, 0.0630891, 0.05795683]]) 1740 | 1741 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1742 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 1743 | assert {%{mu: new_mu, nu: new_nu}} = new_state 1744 | assert %{a: %{b: actual_next_mu_b, c: %{d: %{}, e: actual_next_mu_e}}} = new_mu 1745 | assert %{a: %{b: actual_next_nu_b, c: %{d: %{}, e: actual_next_nu_e}}} = new_nu 1746 | assert_all_close(actual_b, expected_b) 1747 | assert_all_close(actual_e, expected_e) 1748 | assert_all_close(actual_next_mu_b, expected_next_mu_b) 1749 | assert_all_close(actual_next_mu_e, expected_next_mu_e) 1750 | assert_all_close(actual_next_nu_b, expected_next_nu_b) 1751 | assert_all_close(actual_next_nu_e, expected_next_nu_e) 1752 | end 1753 | 1754 | test "supports generic container" do 1755 | assert {init_fn, update_fn} = scale_by_stddev() 1756 | 1757 | params = { 1758 | { 1759 | Nx.tensor([0.49792875, 0.04941673, 0.33815839]), 1760 | {{}, Nx.tensor([[0.70057761, 0.3689184, 0.36608007]])} 1761 | } 1762 | } 1763 | 1764 | updates = { 1765 | { 1766 | Nx.tensor([0.54587409, 0.04849768, 0.23020724]), 1767 | {{}, Nx.tensor([[0.29348535, 0.79428645, 0.76129383]])} 1768 | } 1769 | } 1770 | 1771 | state = init_fn.(params) 1772 | 1773 | expected_b = Nx.tensor([3.333333, 3.3332546, 3.33333]) 1774 | expected_e = Nx.tensor([[3.333331, 3.333333, 3.333333]]) 1775 | expected_next_mu_b = Nx.tensor([0.05458741, 0.00484977, 
0.02302072]) 1776 | expected_next_mu_e = Nx.tensor([[0.02934854, 0.07942864, 0.07612938]]) 1777 | expected_next_nu_b = Nx.tensor([0.02979785, 0.0002352, 0.00529954]) 1778 | expected_next_nu_e = Nx.tensor([[0.00861336, 0.0630891, 0.05795683]]) 1779 | 1780 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1781 | assert {{actual_b, {{}, actual_e}}} = new_updates 1782 | assert {%{mu: new_mu, nu: new_nu}} = new_state 1783 | assert {{actual_next_mu_b, {{}, actual_next_mu_e}}} = new_mu 1784 | assert {{actual_next_nu_b, {{}, actual_next_nu_e}}} = new_nu 1785 | assert_all_close(actual_b, expected_b) 1786 | assert_all_close(actual_e, expected_e) 1787 | assert_all_close(actual_next_mu_b, expected_next_mu_b) 1788 | assert_all_close(actual_next_mu_e, expected_next_mu_e) 1789 | assert_all_close(actual_next_nu_b, expected_next_nu_b) 1790 | assert_all_close(actual_next_nu_e, expected_next_nu_e) 1791 | end 1792 | end 1793 | 1794 | describe "scale_by_trust_ratio" do 1795 | test "constructs a stateless transformation" do 1796 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1797 | assert {init_fn, update_fn} = scale_by_trust_ratio() 1798 | assert is_function(init_fn, 1) 1799 | assert is_function(update_fn, 3) 1800 | assert init_fn.(params) == {} 1801 | end 1802 | 1803 | test "constructs a stateless transformation with options" do 1804 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1805 | assert {init_fn, update_fn} = scale_by_trust_ratio(min_norm: 1.0) 1806 | assert is_function(init_fn, 1) 1807 | assert is_function(update_fn, 3) 1808 | assert init_fn.(params) == {} 1809 | end 1810 | 1811 | test "composes with itself" do 1812 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1813 | 1814 | assert {init_fn, update_fn} = 1815 | scale_by_trust_ratio(min_norm: 1.0) |> scale_by_trust_ratio(min_norm: 1.0) 1816 | 1817 | assert is_function(init_fn, 1) 1818 | assert is_function(update_fn, 3) 1819 | assert init_fn.(params) == {} 1820 | end 1821 | 1822 | test "composes with stateful transformation" do 1823 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1824 | assert {init_fn, update_fn} = scale_by_adam() |> scale_by_trust_ratio(min_norm: 1.0) 1825 | assert is_function(init_fn, 1) 1826 | assert is_function(update_fn, 3) 1827 | assert {adam_state} = init_fn.(params) 1828 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state 1829 | assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) 1830 | assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) 1831 | assert_equal(count, Nx.tensor(0)) 1832 | end 1833 | 1834 | test "matches optax with simple container" do 1835 | assert {init_fn, update_fn} = scale_by_trust_ratio(min_norm: 1.0) 1836 | params = %{a: Nx.tensor([0.07719177, 0.1812708, 0.94959977])} 1837 | updates = %{a: Nx.tensor([0.29626032, 0.328152, 0.20388144])} 1838 | state = init_fn.(params) 1839 | 1840 | expected_a = Nx.tensor([0.29626033, 0.328152, 0.20388144]) 1841 | 1842 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1843 | assert %{a: actual_a} = new_updates 1844 | assert new_state == {} 1845 | assert_all_close(actual_a, expected_a) 1846 | end 1847 | 1848 | test "matches optax with nested container" do 1849 | assert {init_fn, update_fn} = scale_by_trust_ratio(min_norm: 1.0) 1850 | 1851 | params = %{ 1852 | a: %{ 1853 | b: Nx.tensor([0.98282674, 0.34776357, 0.33319137]), 1854 | c: %{d: %{}, e: Nx.tensor([[0.95596768, 0.67948137, 0.05268411]])} 1855 | } 1856 | } 1857 | 1858 | updates = %{ 1859 | a: %{ 1860 | b: Nx.tensor([0.53616958, 0.24854466, 0.26695091]), 1861 | c: %{d: %{}, e: 
Nx.tensor([[0.50354858, 0.91245821, 0.30518247]])} 1862 | } 1863 | } 1864 | 1865 | state = init_fn.(params) 1866 | 1867 | expected_b = Nx.tensor([0.58683133, 0.27202922, 0.29217464]) 1868 | expected_e = Nx.tensor([[0.5443927, 0.98647004, 0.3299366]]) 1869 | 1870 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1871 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 1872 | assert new_state == {} 1873 | assert_all_close(actual_b, expected_b) 1874 | assert_all_close(actual_e, expected_e) 1875 | end 1876 | 1877 | test "supports generic container" do 1878 | assert {init_fn, update_fn} = scale_by_trust_ratio(min_norm: 1.0) 1879 | 1880 | params = { 1881 | { 1882 | Nx.tensor([0.98282674, 0.34776357, 0.33319137]), 1883 | {{}, Nx.tensor([[0.95596768, 0.67948137, 0.05268411]])} 1884 | } 1885 | } 1886 | 1887 | updates = { 1888 | { 1889 | Nx.tensor([0.53616958, 0.24854466, 0.26695091]), 1890 | {{}, Nx.tensor([[0.50354858, 0.91245821, 0.30518247]])} 1891 | } 1892 | } 1893 | 1894 | state = init_fn.(params) 1895 | 1896 | expected_b = Nx.tensor([0.58683133, 0.27202922, 0.29217464]) 1897 | expected_e = Nx.tensor([[0.5443927, 0.98647004, 0.3299366]]) 1898 | 1899 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1900 | assert {{actual_b, {{}, actual_e}}} = new_updates 1901 | assert new_state == {} 1902 | assert_all_close(actual_b, expected_b) 1903 | assert_all_close(actual_e, expected_e) 1904 | end 1905 | end 1906 | 1907 | describe "scale_by_yogi" do 1908 | test "constructs a stateful transformation" do 1909 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1910 | assert {init_fn, update_fn} = scale_by_yogi() 1911 | assert is_function(init_fn, 1) 1912 | assert is_function(update_fn, 3) 1913 | assert {yogi_state} = init_fn.(params) 1914 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = yogi_state 1915 | assert_equal(mu_a, Nx.tensor([1.0e-6, 1.0e-6, 1.0e-6])) 1916 | assert_equal(nu_a, Nx.tensor([1.0e-6, 1.0e-6, 1.0e-6])) 1917 | assert_equal(count, Nx.tensor(0)) 1918 | end 1919 | 1920 | test "constructs a stateful transformation with options" do 1921 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1922 | assert {init_fn, update_fn} = scale_by_yogi(initial_accumulator_value: 1.0e-4) 1923 | assert is_function(init_fn, 1) 1924 | assert is_function(update_fn, 3) 1925 | assert {yogi_state} = init_fn.(params) 1926 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = yogi_state 1927 | assert_equal(mu_a, Nx.tensor([1.0e-4, 1.0e-4, 1.0e-4])) 1928 | assert_equal(nu_a, Nx.tensor([1.0e-4, 1.0e-4, 1.0e-4])) 1929 | assert_equal(count, Nx.tensor(0)) 1930 | end 1931 | 1932 | test "composes with itself" do 1933 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 1934 | assert {init_fn, update_fn} = scale_by_yogi() |> scale_by_yogi() 1935 | assert is_function(init_fn, 1) 1936 | assert is_function(update_fn, 3) 1937 | assert {yogi_state_1, yogi_state_2} = init_fn.(params) 1938 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = yogi_state_1 1939 | assert_equal(mu_a, Nx.tensor([1.0e-6, 1.0e-6, 1.0e-6])) 1940 | assert_equal(nu_a, Nx.tensor([1.0e-6, 1.0e-6, 1.0e-6])) 1941 | assert_equal(count, Nx.tensor(0)) 1942 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = yogi_state_2 1943 | assert_equal(mu_a, Nx.tensor([1.0e-6, 1.0e-6, 1.0e-6])) 1944 | assert_equal(nu_a, Nx.tensor([1.0e-6, 1.0e-6, 1.0e-6])) 1945 | assert_equal(count, Nx.tensor(0)) 1946 | end 1947 | 1948 | test "composes with stateless transformation" do 1949 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 
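# Piping one transformation into another composes them; since scale/1 is
# stateless (its init_fn returns {}), the composed init_fn below yields only
# the Yogi state tuple.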
1950 | assert {init_fn, update_fn} = scale_by_yogi() |> scale(1.0e-2) 1951 | assert is_function(init_fn, 1) 1952 | assert is_function(update_fn, 3) 1953 | assert {yogi_state} = init_fn.(params) 1954 | assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = yogi_state 1955 | assert_equal(mu_a, Nx.tensor([1.0e-6, 1.0e-6, 1.0e-6])) 1956 | assert_equal(nu_a, Nx.tensor([1.0e-6, 1.0e-6, 1.0e-6])) 1957 | assert_equal(count, Nx.tensor(0)) 1958 | end 1959 | 1960 | test "matches optax with simple container" do 1961 | assert {init_fn, update_fn} = scale_by_yogi() 1962 | params = %{a: Nx.tensor([0.39152084, 0.86061072, 0.22693509])} 1963 | updates = %{a: Nx.tensor([0.10820288, 0.73034528, 0.6741126])} 1964 | state = init_fn.(params) 1965 | 1966 | expected_a = Nx.tensor([0.95148116, 0.99770474, 0.9974302]) 1967 | expected_next_mu_a = Nx.tensor([0.01082119, 0.07303543, 0.06741216]) 1968 | expected_next_nu_a = Nx.tensor([1.2707865e-05, 5.3440424e-04, 4.5542780e-04]) 1969 | expected_next_count = Nx.tensor(1) 1970 | 1971 | assert {new_updates, new_state} = update_fn.(updates, state, params) 1972 | assert %{a: actual_a} = new_updates 1973 | 1974 | assert {%{mu: %{a: actual_next_mu_a}, nu: %{a: actual_next_nu_a}, count: actual_next_count}} = 1975 | new_state 1976 | 1977 | assert_all_close(actual_a, expected_a) 1978 | assert_all_close(actual_next_mu_a, expected_next_mu_a) 1979 | assert_all_close(actual_next_nu_a, expected_next_nu_a) 1980 | assert_equal(actual_next_count, expected_next_count) 1981 | end 1982 | 1983 | test "matches optax with nested container" do 1984 | assert {init_fn, update_fn} = scale_by_yogi() 1985 | 1986 | params = %{ 1987 | a: %{ 1988 | b: Nx.tensor([0.87690482, 0.80993702, 0.87935556]), 1989 | c: %{d: %{}, e: Nx.tensor([[0.00528695, 0.06690531, 0.12589192]])} 1990 | } 1991 | } 1992 | 1993 | updates = %{ 1994 | a: %{ 1995 | b: Nx.tensor([0.47019351, 0.72034131, 0.32043362]), 1996 | c: %{d: %{}, e: Nx.tensor([[0.84200356, 0.76360484, 0.55381714]])} 1997 | } 1998 | } 1999 | 2000 | state = init_fn.(params) 2001 | 2002 | expected_b = Nx.tensor([0.99564576, 0.9976599, 0.99210596]) 2003 | expected_e = Nx.tensor([[0.9981149, 0.99784315, 0.9965868]]) 2004 | expected_next_mu_b = Nx.tensor([0.04702025, 0.07203503, 0.03204427]) 2005 | expected_next_mu_e = Nx.tensor([[0.08420125, 0.07636139, 0.05538262]]) 2006 | expected_next_nu_b = Nx.tensor([0.00022208, 0.00051989, 0.00010368]) 2007 | expected_next_nu_e = Nx.tensor([[0.00070997, 0.00058409, 0.00030771]]) 2008 | expected_next_count = Nx.tensor(1) 2009 | 2010 | assert {new_updates, new_state} = update_fn.(updates, state, params) 2011 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 2012 | assert {%{mu: new_mu, nu: new_nu, count: actual_next_count}} = new_state 2013 | assert %{a: %{b: actual_next_mu_b, c: %{d: %{}, e: actual_next_mu_e}}} = new_mu 2014 | assert %{a: %{b: actual_next_nu_b, c: %{d: %{}, e: actual_next_nu_e}}} = new_nu 2015 | assert_all_close(actual_b, expected_b) 2016 | assert_all_close(actual_e, expected_e) 2017 | assert_all_close(actual_next_mu_b, expected_next_mu_b) 2018 | assert_all_close(actual_next_mu_e, expected_next_mu_e) 2019 | assert_all_close(actual_next_nu_b, expected_next_nu_b) 2020 | assert_all_close(actual_next_nu_e, expected_next_nu_e) 2021 | assert_equal(actual_next_count, expected_next_count) 2022 | end 2023 | 2024 | test "supports generic container" do 2025 | assert {init_fn, update_fn} = scale_by_yogi() 2026 | 2027 | params = { 2028 | { 2029 | Nx.tensor([0.87690482, 0.80993702, 0.87935556]), 
2030 | {{}, Nx.tensor([[0.00528695, 0.06690531, 0.12589192]])} 2031 | } 2032 | } 2033 | 2034 | updates = { 2035 | { 2036 | Nx.tensor([0.47019351, 0.72034131, 0.32043362]), 2037 | {{}, Nx.tensor([[0.84200356, 0.76360484, 0.55381714]])} 2038 | } 2039 | } 2040 | 2041 | state = init_fn.(params) 2042 | 2043 | expected_b = Nx.tensor([0.99564576, 0.9976599, 0.99210596]) 2044 | expected_e = Nx.tensor([[0.9981149, 0.99784315, 0.9965868]]) 2045 | expected_next_mu_b = Nx.tensor([0.04702025, 0.07203503, 0.03204427]) 2046 | expected_next_mu_e = Nx.tensor([[0.08420125, 0.07636139, 0.05538262]]) 2047 | expected_next_nu_b = Nx.tensor([0.00022208, 0.00051989, 0.00010368]) 2048 | expected_next_nu_e = Nx.tensor([[0.00070997, 0.00058409, 0.00030771]]) 2049 | expected_next_count = Nx.tensor(1) 2050 | 2051 | assert {new_updates, new_state} = update_fn.(updates, state, params) 2052 | assert {{actual_b, {{}, actual_e}}} = new_updates 2053 | assert {%{mu: new_mu, nu: new_nu, count: actual_next_count}} = new_state 2054 | assert {{actual_next_mu_b, {{}, actual_next_mu_e}}} = new_mu 2055 | assert {{actual_next_nu_b, {{}, actual_next_nu_e}}} = new_nu 2056 | assert_all_close(actual_b, expected_b) 2057 | assert_all_close(actual_e, expected_e) 2058 | assert_all_close(actual_next_mu_b, expected_next_mu_b) 2059 | assert_all_close(actual_next_mu_e, expected_next_mu_e) 2060 | assert_all_close(actual_next_nu_b, expected_next_nu_b) 2061 | assert_all_close(actual_next_nu_e, expected_next_nu_e) 2062 | assert_equal(actual_next_count, expected_next_count) 2063 | end 2064 | end 2065 | 2066 | describe "trace" do 2067 | test "constructs a stateful transformation" do 2068 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 2069 | assert {init_fn, update_fn} = trace() 2070 | assert is_function(init_fn, 1) 2071 | assert is_function(update_fn, 3) 2072 | assert {trace_state} = init_fn.(params) 2073 | assert %{trace: %{a: trace_a}} = trace_state 2074 | assert_equal(trace_a, Nx.tensor([0.0, 0.0, 0.0])) 2075 | end 2076 | 2077 | test "constructs a stateful transformation with options" do 2078 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 2079 | assert {init_fn, update_fn} = trace(decay: 0.8) 2080 | assert is_function(init_fn, 1) 2081 | assert is_function(update_fn, 3) 2082 | assert {trace_state} = init_fn.(params) 2083 | assert %{trace: %{a: trace_a}} = trace_state 2084 | assert_equal(trace_a, Nx.tensor([0.0, 0.0, 0.0])) 2085 | end 2086 | 2087 | test "composes with itself" do 2088 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 2089 | assert {init_fn, update_fn} = trace() |> trace() 2090 | assert is_function(init_fn, 1) 2091 | assert is_function(update_fn, 3) 2092 | assert {trace_state_2, trace_state_1} = init_fn.(params) 2093 | assert %{trace: %{a: trace_a}} = trace_state_1 2094 | assert_equal(trace_a, Nx.tensor([0.0, 0.0, 0.0])) 2095 | assert %{trace: %{a: trace_a}} = trace_state_2 2096 | assert_equal(trace_a, Nx.tensor([0.0, 0.0, 0.0])) 2097 | end 2098 | 2099 | test "composes with stateless transformation" do 2100 | params = %{a: Nx.tensor([1.0, 2.0, 3.0])} 2101 | assert {init_fn, update_fn} = trace() |> scale(1.0e-2) 2102 | assert is_function(init_fn, 1) 2103 | assert is_function(update_fn, 3) 2104 | assert {trace_state} = init_fn.(params) 2105 | assert %{trace: %{a: trace_a}} = trace_state 2106 | assert_equal(trace_a, Nx.tensor([0.0, 0.0, 0.0])) 2107 | end 2108 | 2109 | test "matches optax with simple container, nesterov: false" do 2110 | assert {init_fn, update_fn} = trace(nesterov: false) 2111 | params = %{a: Nx.tensor([0.54044065, 0.54168045, 
0.14243068])} 2112 | updates = %{a: Nx.tensor([0.76976679, 0.19561062, 0.84724249])} 2113 | state = init_fn.(params) 2114 | 2115 | expected_a = Nx.tensor([0.7697668, 0.19561061, 0.8472425]) 2116 | expected_next_trace = Nx.tensor([0.7697668, 0.19561061, 0.8472425]) 2117 | 2118 | assert {new_updates, new_state} = update_fn.(updates, state, params) 2119 | assert %{a: actual_a} = new_updates 2120 | assert {%{trace: %{a: actual_next_trace}}} = new_state 2121 | assert_all_close(actual_a, expected_a) 2122 | assert_all_close(actual_next_trace, expected_next_trace) 2123 | end 2124 | 2125 | test "matches optax with nested container, nesterov: false" do 2126 | assert {init_fn, update_fn} = trace(nesterov: false) 2127 | 2128 | params = %{ 2129 | a: %{ 2130 | b: Nx.tensor([0.23468207, 0.75940123, 0.06601013]), 2131 | c: %{d: %{}, e: Nx.tensor([[0.68877159, 0.84383744, 0.15230977]])} 2132 | } 2133 | } 2134 | 2135 | updates = %{ 2136 | a: %{ 2137 | b: Nx.tensor([0.60272336, 0.42772071, 0.39653623]), 2138 | c: %{d: %{}, e: Nx.tensor([[0.25453278, 0.64759897, 0.71080799]])} 2139 | } 2140 | } 2141 | 2142 | state = init_fn.(params) 2143 | 2144 | expected_b = Nx.tensor([0.60272336, 0.4277207, 0.39653623]) 2145 | expected_e = Nx.tensor([[0.25453278, 0.647599, 0.710808]]) 2146 | expected_next_trace_b = Nx.tensor([0.60272336, 0.4277207, 0.39653623]) 2147 | expected_next_trace_e = Nx.tensor([[0.25453278, 0.647599, 0.710808]]) 2148 | 2149 | assert {new_updates, new_state} = update_fn.(updates, state, params) 2150 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 2151 | assert {%{trace: new_trace}} = new_state 2152 | assert %{a: %{b: actual_next_trace_b, c: %{d: %{}, e: actual_next_trace_e}}} = new_trace 2153 | assert_all_close(actual_b, expected_b) 2154 | assert_all_close(actual_e, expected_e) 2155 | assert_all_close(actual_next_trace_b, expected_next_trace_b) 2156 | assert_all_close(actual_next_trace_e, expected_next_trace_e) 2157 | end 2158 | 2159 | test "matches optax with simple container, nesterov: true" do 2160 | assert {init_fn, update_fn} = trace(nesterov: true) 2161 | params = %{a: Nx.tensor([0.05727068, 0.71336316, 0.52111667])} 2162 | updates = %{a: Nx.tensor([0.99510349, 0.38321624, 0.37485662])} 2163 | state = init_fn.(params) 2164 | 2165 | expected_a = Nx.tensor([1.8906965, 0.7281108, 0.7122276]) 2166 | expected_next_trace = Nx.tensor([0.9951035, 0.38321623, 0.37485662]) 2167 | 2168 | assert {new_updates, new_state} = update_fn.(updates, state, params) 2169 | assert %{a: actual_a} = new_updates 2170 | assert {%{trace: %{a: actual_next_trace}}} = new_state 2171 | assert_all_close(actual_a, expected_a) 2172 | assert_all_close(actual_next_trace, expected_next_trace) 2173 | end 2174 | 2175 | test "matches optax with nested container, nesterov: true" do 2176 | assert {init_fn, update_fn} = trace(nesterov: true) 2177 | 2178 | params = %{ 2179 | a: %{ 2180 | b: Nx.tensor([0.81068757, 0.89196671, 0.21672469]), 2181 | c: %{d: %{}, e: Nx.tensor([[0.9194404, 0.19829658, 0.96960522]])} 2182 | } 2183 | } 2184 | 2185 | updates = %{ 2186 | a: %{ 2187 | b: Nx.tensor([0.21182614, 0.29456406, 0.50427876]), 2188 | c: %{d: %{}, e: Nx.tensor([[0.26525984, 0.66349034, 0.11212149]])} 2189 | } 2190 | } 2191 | 2192 | state = init_fn.(params) 2193 | 2194 | expected_b = Nx.tensor([0.40246966, 0.55967176, 0.95812964]) 2195 | expected_e = Nx.tensor([[0.5039937, 1.2606317, 0.21303083]]) 2196 | expected_next_trace_b = Nx.tensor([0.21182615, 0.29456407, 0.5042788]) 2197 | expected_next_trace_e = 
Nx.tensor([[0.26525983, 0.66349036, 0.11212149]]) 2198 | 2199 | assert {new_updates, new_state} = update_fn.(updates, state, params) 2200 | assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates 2201 | assert {%{trace: new_trace}} = new_state 2202 | assert %{a: %{b: actual_next_trace_b, c: %{d: %{}, e: actual_next_trace_e}}} = new_trace 2203 | assert_all_close(actual_b, expected_b) 2204 | assert_all_close(actual_e, expected_e) 2205 | assert_all_close(actual_next_trace_b, expected_next_trace_b) 2206 | assert_all_close(actual_next_trace_e, expected_next_trace_e) 2207 | end 2208 | 2209 | test "supports generic container" do 2210 | assert {init_fn, update_fn} = trace(nesterov: true) 2211 | 2212 | params = { 2213 | { 2214 | Nx.tensor([0.81068757, 0.89196671, 0.21672469]), 2215 | {{}, Nx.tensor([[0.9194404, 0.19829658, 0.96960522]])} 2216 | } 2217 | } 2218 | 2219 | updates = { 2220 | { 2221 | Nx.tensor([0.21182614, 0.29456406, 0.50427876]), 2222 | {{}, Nx.tensor([[0.26525984, 0.66349034, 0.11212149]])} 2223 | } 2224 | } 2225 | 2226 | state = init_fn.(params) 2227 | 2228 | expected_b = Nx.tensor([0.40246966, 0.55967176, 0.95812964]) 2229 | expected_e = Nx.tensor([[0.5039937, 1.2606317, 0.21303083]]) 2230 | expected_next_trace_b = Nx.tensor([0.21182615, 0.29456407, 0.5042788]) 2231 | expected_next_trace_e = Nx.tensor([[0.26525983, 0.66349036, 0.11212149]]) 2232 | 2233 | assert {new_updates, new_state} = update_fn.(updates, state, params) 2234 | assert {{actual_b, {{}, actual_e}}} = new_updates 2235 | assert {%{trace: new_trace}} = new_state 2236 | assert {{actual_next_trace_b, {{}, actual_next_trace_e}}} = new_trace 2237 | assert_all_close(actual_b, expected_b) 2238 | assert_all_close(actual_e, expected_e) 2239 | assert_all_close(actual_next_trace_b, expected_next_trace_b) 2240 | assert_all_close(actual_next_trace_e, expected_next_trace_e) 2241 | end 2242 | end 2243 | 2244 | describe "accumulate_gradients" do 2245 | test "returns 0 updates, accumulated state, no updated parent state" do 2246 | assert {init_fn, update_fn} = trace(nesterov: true) |> accumulate_gradients(2) 2247 | params = %{a: Nx.tensor([1.0, 2.0])} 2248 | 2249 | updates = %{a: Nx.tensor([1.0, 1.0])} 2250 | 2251 | {%{gradient_state: %{a: init_gradient_state}}, %{trace: init_trace}} = 2252 | state = init_fn.(params) 2253 | 2254 | assert_all_close(init_gradient_state, Nx.tensor([0.0, 0.0])) 2255 | 2256 | # accumulate once 2257 | assert {new_updates, new_state} = update_fn.(updates, state, params) 2258 | assert %{a: a_updates_1} = new_updates 2259 | 2260 | assert {%{step: step, gradient_state: %{a: gradient_state}}, %{trace: new_trace}} = 2261 | new_state 2262 | 2263 | assert_all_close(a_updates_1, Nx.tensor([0.0, 0.0])) 2264 | assert_all_close(step, Nx.tensor(1)) 2265 | assert_all_close(gradient_state, Nx.tensor([1.0, 1.0])) 2266 | assert_all_close(init_trace.a, new_trace.a) 2267 | 2268 | # accumulate twice 2269 | assert {new_updates, new_state} = update_fn.(updates, new_state, params) 2270 | assert %{a: a_updates_2} = new_updates 2271 | 2272 | assert {%{step: step, gradient_state: %{a: gradient_state}}, %{trace: new_trace}} = 2273 | new_state 2274 | 2275 | assert_all_close(a_updates_2, Nx.tensor([0.0, 0.0])) 2276 | assert_all_close(step, Nx.tensor(2)) 2277 | assert_all_close(gradient_state, Nx.tensor([2.0, 2.0])) 2278 | assert_all_close(init_trace.a, new_trace.a) 2279 | 2280 | # now update 2281 | assert {new_updates, new_state} = update_fn.(updates, new_state, params) 2282 | assert %{a: a_updates_3} = new_updates 2283 | 
2284 | assert {%{step: step, gradient_state: %{a: gradient_state}}, %{trace: new_trace}} = 2285 | new_state 2286 | 2287 | assert_all_close(a_updates_3, Nx.tensor([0.95, 0.95])) 2288 | assert_all_close(step, Nx.tensor(0)) 2289 | assert_all_close(gradient_state, Nx.tensor([0.0, 0.0])) 2290 | assert_all_close(new_trace.a, Nx.tensor([0.5, 0.5])) 2291 | end 2292 | end 2293 | end 2294 | -------------------------------------------------------------------------------- /test/support/polaris_case.ex: -------------------------------------------------------------------------------- 1 | defmodule Polaris.Case do 2 | use ExUnit.CaseTemplate 3 | 4 | using do 5 | quote do 6 | import Nx.Defn 7 | import Polaris.Case 8 | end 9 | end 10 | 11 | setup config do 12 | Nx.Defn.default_options(compiler: test_compiler()) 13 | Nx.default_backend(test_backend()) 14 | Process.register(self(), config.test) 15 | :ok 16 | end 17 | 18 | def test_compiler do 19 | use_exla? = System.get_env("USE_EXLA") 20 | if use_exla?, do: EXLA, else: Nx.Defn.Evaluator 21 | end 22 | 23 | def test_backend do 24 | cond do 25 | System.get_env("USE_TORCHX") -> Torchx.Backend 26 | System.get_env("USE_EXLA") -> EXLA.Backend 27 | true -> Nx.BinaryBackend 28 | end 29 | end 30 | 31 | def check_optimizer!(optimizer, loss, x0, num_steps) do 32 | check_optimizer_functions!(optimizer) 33 | check_optimizer_run!(optimizer, loss, x0, num_steps) 34 | end 35 | 36 | def assert_all_close(lhs, rhs, opts \\ []) 37 | 38 | def assert_all_close(lhs, rhs, opts) when is_tuple(lhs) and is_tuple(rhs) do 39 | lhs 40 | |> Tuple.to_list() 41 | |> Enum.zip_with(Tuple.to_list(rhs), &assert_all_close(&1, &2, opts)) 42 | end 43 | 44 | def assert_all_close(lhs, rhs, opts) do 45 | res = Nx.all_close(lhs, rhs, opts) |> Nx.backend_transfer(Nx.BinaryBackend) 46 | 47 | unless Nx.to_number(res) == 1 do 48 | raise """ 49 | expected 50 | 51 | #{inspect(Nx.backend_transfer(lhs, Nx.BinaryBackend))} 52 | 53 | to be within tolerance of 54 | 55 | #{inspect(Nx.backend_transfer(rhs, Nx.BinaryBackend))} 56 | """ 57 | end 58 | end 59 | 60 | def assert_equal(lhs, rhs) when is_tuple(lhs) and is_tuple(rhs) do 61 | lhs 62 | |> Tuple.to_list() 63 | |> Enum.zip_with(Tuple.to_list(rhs), &assert_equal/2) 64 | end 65 | 66 | def assert_equal(%Nx.Tensor{} = lhs, %Nx.Tensor{} = rhs) do 67 | res = Nx.equal(lhs, rhs) |> Nx.all() |> Nx.backend_transfer(Nx.BinaryBackend) 68 | 69 | unless Nx.to_number(res) == 1 do 70 | raise """ 71 | expected 72 | 73 | #{inspect(Nx.backend_transfer(lhs, Nx.BinaryBackend))} 74 | 75 | to be equal to 76 | 77 | #{inspect(Nx.backend_transfer(rhs, Nx.BinaryBackend))} 78 | """ 79 | end 80 | end 81 | 82 | def assert_equal(lhs, rhs) when is_map(lhs) and is_map(rhs) do 83 | lhs 84 | |> Map.values() 85 | |> Enum.zip_with(Map.values(rhs), &assert_equal/2) 86 | end 87 | 88 | def assert_not_equal(lhs, rhs) when is_tuple(lhs) and is_tuple(rhs) do 89 | lhs 90 | |> Tuple.to_list() 91 | |> Enum.zip_with(Tuple.to_list(rhs), &assert_not_equal/2) 92 | end 93 | 94 | def assert_not_equal(%Nx.Tensor{} = lhs, %Nx.Tensor{} = rhs) do 95 | res = Nx.equal(lhs, rhs) |> Nx.all() |> Nx.backend_transfer(Nx.BinaryBackend) 96 | 97 | unless Nx.to_number(res) == 0 do 98 | raise """ 99 | expected 100 | 101 | #{inspect(Nx.backend_transfer(lhs, Nx.BinaryBackend))} 102 | 103 | to be not equal to 104 | 105 | #{inspect(Nx.backend_transfer(rhs, Nx.BinaryBackend))} 106 | """ 107 | end 108 | end 109 | 110 | def assert_not_equal(lhs, rhs) when is_map(lhs) and is_map(rhs) do 111 | lhs 112 | |> Map.values() 113 | |>
Enum.zip_with(Map.values(rhs), &assert_not_equal/2) 114 | end 115 | 116 | def assert_greater_equal(lhs, rhs) do 117 | res = Nx.greater_equal(lhs, rhs) |> Nx.all() |> Nx.backend_transfer(Nx.BinaryBackend) 118 | 119 | unless Nx.to_number(res) == 1 do 120 | raise """ 121 | expected 122 | 123 | #{inspect(Nx.backend_transfer(lhs, Nx.BinaryBackend))} 124 | 125 | to be greater than or equal to 126 | 127 | #{inspect(Nx.backend_transfer(rhs, Nx.BinaryBackend))} 128 | """ 129 | end 130 | end 131 | 132 | defp check_optimizer_functions!(optimizer) do 133 | {init_fn, update_fn} = optimizer 134 | is_function(init_fn, 1) and is_function(update_fn, 3) 135 | end 136 | 137 | defp check_optimizer_run!(optimizer, loss, x0, num_steps) do 138 | {init_fn, update_fn} = optimizer 139 | opt_state = init_fn.(x0) 140 | state = {x0, opt_state} 141 | 142 | step_fn = fn state -> 143 | {params, opt_state} = state 144 | gradients = Nx.Defn.grad(params, loss) 145 | {updates, new_state} = update_fn.(gradients, opt_state, params) 146 | {Polaris.Updates.apply_updates(updates, params), new_state} 147 | end 148 | 149 | {params, _} = 150 | for _ <- 1..num_steps, reduce: state do 151 | state -> 152 | apply(Nx.Defn.jit(step_fn), [state]) 153 | end 154 | 155 | lhs = loss.(params) 156 | rhs = 1.0e-2 157 | 158 | res = Nx.less_equal(lhs, rhs) |> Nx.all() |> Nx.backend_transfer(Nx.BinaryBackend) 159 | 160 | # Some optimizers require 1-D or 2-D input, so this potentially 161 | # could be multi-dimensional 162 | unless Nx.to_number(res) == 1 do 163 | raise """ 164 | expected 165 | 166 | #{inspect(lhs)} 167 | 168 | to be less than or equal to 169 | 170 | #{inspect(rhs)} 171 | """ 172 | end 173 | end 174 | end 175 | -------------------------------------------------------------------------------- /test/test_helper.exs: -------------------------------------------------------------------------------- 1 | ExUnit.start() 2 | --------------------------------------------------------------------------------
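The suites above double as usage documentation for the Polaris.Updates combinators. As a minimal sketch of how the same pieces compose outside of ExUnit, the snippet below chains trace/1 into scale/2 and applies the result with apply_updates/2. The quadratic loss, the parameter shape, and the -1.0e-1 step size (negative so the scaled gradients act as descent steps) are illustrative assumptions for this sketch, not defaults prescribed by the library.

# Hand-rolled descent step built from the combinators exercised in updates_test.exs.
# The loss, parameter shape, and -1.0e-1 step size are assumptions, not library defaults.
import Polaris.Updates

params = %{w: Nx.tensor([1.0, -2.0, 3.0])}
loss = fn %{w: w} -> w |> Nx.multiply(w) |> Nx.sum() end

# Momentum (trace) followed by a fixed scaling factor, composed the same way as in the tests.
{init_fn, update_fn} = trace(decay: 0.9, nesterov: true) |> scale(-1.0e-1)
opt_state = init_fn.(params)

# One optimization step: differentiate, transform the gradients, apply the updates.
gradients = Nx.Defn.grad(params, loss)
{updates, new_opt_state} = update_fn.(gradients, opt_state, params)
new_params = apply_updates(updates, params)

Repeating the last three lines with new_params and new_opt_state yields the loop that check_optimizer_run!/4 drives under Nx.Defn.jit in the case template above.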