├── config ├── docs.exs ├── prod.exs ├── test.exs ├── dev.exs ├── config.exs └── runtime.exs ├── test ├── test_helper.exs ├── rein_test.exs └── rein │ └── utils │ ├── noise │ └── ou_process_test.exs │ └── circular_buffer_test.exs ├── .formatter.exs ├── .gitignore ├── LICENSE ├── .github └── workflows │ └── elixir.yml ├── lib ├── rein │ ├── environment.ex │ ├── agent.ex │ ├── utils │ │ ├── noise │ │ │ └── ou_process.ex │ │ └── circular_buffer.ex │ ├── environments │ │ └── gridworld.ex │ └── agents │ │ ├── q_learning.ex │ │ ├── dqn.ex │ │ ├── ddpg.ex │ │ └── sac.ex └── rein.ex ├── README.md ├── mix.exs ├── mix.lock └── guides └── gridworld.livemd /config/docs.exs: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/test_helper.exs: -------------------------------------------------------------------------------- 1 | ExUnit.start() 2 | -------------------------------------------------------------------------------- /config/prod.exs: -------------------------------------------------------------------------------- 1 | import Config 2 | 3 | config :logger, level: :info 4 | -------------------------------------------------------------------------------- /config/test.exs: -------------------------------------------------------------------------------- 1 | import Config 2 | 3 | config :logger, level: :error 4 | -------------------------------------------------------------------------------- /test/rein_test.exs: -------------------------------------------------------------------------------- 1 | defmodule ReinTest do 2 | use ExUnit.Case, async: true 3 | end 4 | -------------------------------------------------------------------------------- /config/dev.exs: -------------------------------------------------------------------------------- 1 | import Config 2 | 3 | config :logger, level: :debug 4 | config :logger, :console, format: "$time $metadata[$level] $message\n", metadata: :all 5 | -------------------------------------------------------------------------------- /.formatter.exs: -------------------------------------------------------------------------------- 1 | [ 2 | import_deps: [:nx], 3 | inputs: ["*.{ex,exs}", "priv/*/seeds.exs", "{config,lib,test}/**/*.{ex,exs}"], 4 | subdirectories: ["priv/*/migrations"] 5 | ] 6 | -------------------------------------------------------------------------------- /config/config.exs: -------------------------------------------------------------------------------- 1 | import Config 2 | 3 | config :logger, :console, 4 | format: "$time $metadata[$level] $message\n", 5 | metadata: [:request_id] 6 | 7 | import_config "#{config_env()}.exs" 8 | -------------------------------------------------------------------------------- /config/runtime.exs: -------------------------------------------------------------------------------- 1 | import Config 2 | 3 | backend_env = System.get_env("REIN_NX_BACKEND") 4 | 5 | {backend, defn_opts} = 6 | case backend_env do 7 | "torchx" -> 8 | {Torchx.Backend, []} 9 | 10 | "binary" -> 11 | {Nx.BinaryBackend, []} 12 | 13 | _ -> 14 | {EXLA.Backend, compiler: EXLA, memory_fraction: 0.5} 15 | end 16 | 17 | config :nx, 18 | default_backend: backend, 19 | global_default_backend: backend, 20 | default_defn_options: defn_opts 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # The directory Mix will 
write compiled artifacts to. 2 | /_build/ 3 | 4 | # If you run "mix test --cover", coverage assets end up here. 5 | /cover/ 6 | 7 | # The directory Mix downloads your dependencies sources to. 8 | /deps/ 9 | 10 | # Where 3rd-party dependencies like ExDoc output generated docs. 11 | /doc/ 12 | 13 | # Ignore .fetch files in case you like to edit your project deps locally. 14 | /.fetch 15 | 16 | # If the VM crashes, it generates a dump, let's ignore it too. 17 | erl_crash.dump 18 | 19 | # Also ignore archive artifacts (built via "mix archive.build"). 20 | *.ez 21 | 22 | # Ignore package tarball (built via "mix hex.build"). 23 | rein-*.tar 24 | 25 | # Ignore assets that are produced by build tools. 26 | /priv/static/assets/ 27 | 28 | # Ignore digested assets cache. 29 | /priv/static/cache_manifest.json 30 | 31 | # In case you use Node.js/npm, you want to ignore these. 32 | npm-debug.log 33 | /assets/node_modules/ 34 | 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 DockYard, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/elixir.yml: -------------------------------------------------------------------------------- 1 | # This workflow uses actions that are not certified by GitHub. 2 | # They are provided by a third-party and are governed by 3 | # separate terms of service, privacy policy, and support 4 | # documentation. 
5 | 6 | name: Elixir CI 7 | 8 | on: 9 | push: 10 | branches: [ "main" ] 11 | pull_request: 12 | branches: [ "main" ] 13 | 14 | permissions: 15 | contents: read 16 | 17 | jobs: 18 | build: 19 | 20 | name: Build and test 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Elixir 26 | uses: erlef/setup-beam@61e01a43a562a89bfc54c7f9a378ff67b03e4a21 # v1.16.0 27 | with: 28 | elixir-version: '1.15.2' # [Required] Define the Elixir version 29 | otp-version: '26.0' # [Required] Define the Erlang/OTP version 30 | - name: Restore dependencies cache 31 | uses: actions/cache@v3 32 | with: 33 | path: deps 34 | key: ${{ runner.os }}-mix-${{ hashFiles('**/mix.lock') }} 35 | restore-keys: ${{ runner.os }}-mix- 36 | - name: Install dependencies 37 | run: mix deps.get 38 | - name: Formatting 39 | run: mix format --check-formatted 40 | - name: Run tests 41 | run: mix test --warnings-as-errors 42 | -------------------------------------------------------------------------------- /test/rein/utils/noise/ou_process_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Rein.Utils.Noise.OUProcessTest do 2 | use ExUnit.Case, async: true 3 | 4 | alias Rein.Utils.Noise.OUProcess 5 | 6 | test "generates samples with given shape" do 7 | Nx.Defn.default_options(compiler: Nx.Defn.Evaluator) 8 | Nx.default_backend(Nx.BinaryBackend) 9 | 10 | key = Nx.Random.key(1) 11 | 12 | state = OUProcess.init({2}) 13 | range = 1..10 14 | 15 | {values, _key} = 16 | Enum.map_reduce(range, {key, state}, fn _, {prev_key, state} -> 17 | {state, key} = OUProcess.sample(prev_key, state) 18 | 19 | assert key.data.__struct__ == Nx.BinaryBackend 20 | refute key == prev_key 21 | 22 | {state.x, {key, state}} 23 | end) 24 | 25 | assert values == [ 26 | Nx.tensor([-0.161521315574646, -0.04836982488632202]), 27 | Nx.tensor([-0.022248566150665283, -0.040264029055833817]), 28 | Nx.tensor([-0.09898112714290619, 0.007571600377559662]), 29 | Nx.tensor([0.2752320170402527, 0.27117177844047546]), 30 | Nx.tensor([0.19806107878684998, 0.3740113377571106]), 31 | Nx.tensor([0.3326162099838257, 0.45093610882759094]), 32 | Nx.tensor([0.5560828447341919, 0.3771272897720337]), 33 | Nx.tensor([0.41871464252471924, 0.24803756177425385]), 34 | Nx.tensor([0.04342368245124817, 0.10074643790721893]), 35 | Nx.tensor([-0.3225524425506592, 0.020469389855861664]) 36 | ] 37 | end 38 | end 39 | -------------------------------------------------------------------------------- /lib/rein/environment.ex: -------------------------------------------------------------------------------- 1 | defmodule Rein.Environment do 2 | @moduledoc """ 3 | Defines an environment to be passed to `Rein`. 4 | """ 5 | 6 | @typedoc "An arbitrary `Nx.Container` that holds metadata for the environment" 7 | @type t :: Nx.Container.t() 8 | 9 | @typedoc "The full state of the current Reinforcement Learning process, as stored in the `Rein` struct" 10 | @type rl_state :: Rein.t() 11 | 12 | @doc """ 13 | Initializes the environment state with the given enviroment-specific options. 14 | 15 | Should be implemented in a way that the result would be semantically 16 | the same as if `c:reset/2` was called in the end of the function. 17 | 18 | As a suggestion, the implementation should only initialize fixed 19 | values here, that is values that don't change between sessions 20 | (epochs for non-episodic tasks, episodes for episodic tasks). Then, 21 | call `c:reset/2` internally to initialize the rest of variable values. 
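  For illustration, a minimal sketch of an implementation (mirroring
  `Rein.Environments.Gridworld`, with a hypothetical `:possible_targets`
  option; adapt the struct and options to your own environment):

      @impl true
      def init(random_key, opts) do
        opts = Keyword.validate!(opts, [:possible_targets])

        # fixed, session-independent values are set here; everything that
        # varies between sessions is initialized by reset/2
        reset(random_key, %__MODULE__{possible_targets: opts[:possible_targets]})
      end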
22 | """ 23 | @callback init(random_key :: Nx.t(), opts :: keyword) :: {t(), random_key :: Nx.t()} 24 | 25 | @doc """ 26 | Resets any values that vary between sessions (which would be episodes 27 | for episodic tasks, epochs for non-episodic tasks) for the environment state. 28 | """ 29 | @callback reset(random_key :: Nx.t(), environment_state :: t) :: {t(), random_key :: Nx.t()} 30 | 31 | @doc """ 32 | Applies the selected action to the environment. 33 | 34 | Returns the updated environment, also updated with the reward 35 | and a flag indicating whether the new state is terminal. 36 | """ 37 | @callback apply_action(rl_state, action :: Nx.t()) :: rl_state 38 | end 39 | -------------------------------------------------------------------------------- /lib/rein/agent.ex: -------------------------------------------------------------------------------- 1 | defmodule Rein.Agent do 2 | @moduledoc """ 3 | The behaviour that should be implemented by a `Rein` agent module. 4 | """ 5 | 6 | @typedoc "An arbitrary `Nx.Container` that holds metadata for the agent" 7 | @type t :: Nx.Container.t() 8 | 9 | @typedoc "The full state of the current Reinforcement Learning process, as stored in the `Rein` struct" 10 | @type rl_state :: Rein.t() 11 | 12 | @doc """ 13 | Initializes the agent state with the given agent-specific options. 14 | 15 | Should be implemented in a way that the result would be semantically 16 | the same as if `c:reset/2` was called in the end of the function. 17 | 18 | As a suggestion, the implementation should only initialize fixed 19 | values here, that is values that don't change between sessions 20 | (epochs for non-episodic tasks, episodes for episodic tasks). Then, 21 | call `c:reset/2` internally to initialize the rest of variable values. 22 | """ 23 | @callback init(random_key :: Nx.t(), opts :: keyword) :: {t(), random_key :: Nx.t()} 24 | 25 | @doc """ 26 | Resets any values that vary between sessions (which would be episodes 27 | for episodic tasks) for the agent state. 28 | """ 29 | @callback reset(random_key :: Nx.t(), rl_state :: t) :: {t(), random_key :: Nx.t()} 30 | 31 | @doc """ 32 | Selects the action to be taken. 33 | """ 34 | @callback select_action(rl_state, iteration :: Nx.t()) :: {action :: Nx.t(), rl_state} 35 | 36 | @doc """ 37 | Can be used to record the observation in an experience replay buffer. 38 | 39 | If this is not desired, just make this function return the first argument unchanged. 40 | """ 41 | @callback record_observation( 42 | rl_state, 43 | action :: Nx.t(), 44 | reward :: Nx.t(), 45 | is_terminal :: Nx.t(), 46 | next_rl_state :: rl_state 47 | ) :: rl_state 48 | @callback optimize_model(rl_state) :: rl_state 49 | end 50 | -------------------------------------------------------------------------------- /lib/rein/utils/noise/ou_process.ex: -------------------------------------------------------------------------------- 1 | defmodule Rein.Utils.Noise.OUProcess do 2 | @moduledoc """ 3 | Ornstein-Uhlenbeck (OU for short) noise generator 4 | for temporally correlated noise. 5 | """ 6 | 7 | import Nx.Defn 8 | 9 | @derive {Nx.Container, keep: [], containers: [:theta, :sigma, :mu, :x]} 10 | defstruct [:theta, :sigma, :mu, :x] 11 | 12 | @doc """ 13 | Initializes the `#{__MODULE__}`. 14 | 15 | ## Options 16 | 17 | * `:theta` - the temperature parameter. Defaults to `0.15`. 18 | * `:sigma` - the standard deviation parameter. Defaults to `0.2`. 19 | * `:mu` - the initial mean for the distribution. Defaults to `0`. 20 | * `:type` - the output type for the samples. 
Should be floating point. 21 | Defaults to `:f32`. 22 | """ 23 | deftransform init(shape, opts \\ []) do 24 | opts = Keyword.validate!(opts, theta: 0.15, sigma: 0.2, type: :f32, mu: 0) 25 | 26 | theta = opts[:theta] 27 | sigma = opts[:sigma] 28 | type = opts[:type] 29 | mu = opts[:mu] 30 | mu = Nx.as_type(mu, type) 31 | 32 | x = Nx.broadcast(mu, shape) 33 | %__MODULE__{theta: theta, sigma: sigma, mu: mu, x: x} 34 | end 35 | 36 | @doc """ 37 | Resets the process to the initial value. 38 | """ 39 | defn reset(state) do 40 | x = Nx.broadcast(state.mu, state.x) 41 | %__MODULE__{state | x: x} 42 | end 43 | 44 | @doc """ 45 | Samples the process and returns the updated `state` and the updated `random_key`. 46 | 47 | The new sample is contained within `state.x`. 48 | """ 49 | defn sample(random_key, state) do 50 | %__MODULE__{x: x, sigma: sigma, theta: theta, mu: mu} = state 51 | 52 | {state, random_key} = 53 | if sigma == 0 do 54 | {state, random_key} 55 | else 56 | {sample, random_key} = Nx.Random.normal(random_key, shape: Nx.shape(x)) 57 | dx = theta * (mu - x) + sigma * sample 58 | x = x + dx 59 | 60 | {%__MODULE__{state | x: x}, random_key} 61 | end 62 | 63 | {state, Nx.as_type(random_key, :u32)} 64 | end 65 | end 66 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Rein 2 | 3 | > :warning: **This library is a work in progress!** 4 | 5 | Reinforcement Learning algorithms written in [Nx](https://github.com/elixir-nx/nx/tree/main/nx#readme). 6 | 7 | ## Installation 8 | 9 | If [available in Hex](https://hex.pm/docs/publish), the package can be installed 10 | by adding `rein` to your list of dependencies in `mix.exs`: 11 | 12 | ```elixir 13 | def deps do 14 | [ 15 | {:rein, "~> 0.1.0"} 16 | ] 17 | end 18 | ``` 19 | 20 | ### Dependencies 21 | 22 | This library does not require any specific Nx backend. By default, however, 23 | one should be able to run [EXLA](https://github.com/elixir-nx/nx/tree/main/exla#readme), which is the default backend and compiler. 24 | 25 | [Torchx](https://github.com/elixir-nx/nx/tree/main/torchx#readme) can also be used through the `REIN_NX_BACKEND` environment variable. 26 | 27 | ### Environment variables 28 | 29 | - REIN_NX_BACKEND 30 | If set to "torchx", Rein will use Torchx as the default backend. If set to "binary", it will use the plain Nx.BinaryBackend. 31 | Otherwise, it will use EXLA as the default backend and compiler. 32 | 33 | EXLA and Torchx each have their own environment variables as well. 34 | 35 | ## Authors ## 36 | 37 | - [Paulo Valente](https://github.com/polvalente) 38 | 39 | [We are very thankful for the many contributors](https://github.com/dockyard/rein/graphs/contributors) 40 | 41 | ## Versioning ## 42 | 43 | This library follows [Semantic Versioning](https://semver.org) 44 | 45 | ## Looking for help with your Elixir project? ## 46 | 47 | [At DockYard we are ready to help you build your next Elixir project](https://dockyard.com/phoenix-consulting). We have a unique expertise 48 | in Elixir and Phoenix development that is unmatched. [Get in touch!](https://dockyard.com/contact/hire-us) 49 | 50 | At DockYard we love Elixir! You can [read our Elixir blog posts](https://dockyard.com/blog/categories/elixir) 51 | 52 | ## Legal ## 53 | 54 | [DockYard](https://dockyard.com/), Inc.
© 2023 55 | 56 | [@DockYard](https://twitter.com/DockYard) 57 | 58 | [Licensed under the MIT license](https://www.opensource.org/licenses/mit-license.php) 59 | -------------------------------------------------------------------------------- /test/rein/utils/circular_buffer_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Rein.Utils.CircularBufferTest do 2 | use ExUnit.Case 3 | 4 | alias Rein.Utils.CircularBuffer 5 | 6 | test "persists data and reorders correctly" do 7 | Nx.default_backend(Nx.BinaryBackend) 8 | Nx.Defn.default_options(compiler: Nx.Defn.Evaluator) 9 | 10 | buffer = CircularBuffer.new({3, 2}, init_value: 0, type: :s64) 11 | 12 | assert %CircularBuffer{size: size, index: index, data: data} = 13 | buffer = CircularBuffer.append(buffer, Nx.tensor([0, 1])) 14 | 15 | assert size == Nx.tensor(1) 16 | assert index == Nx.tensor(1) 17 | 18 | assert data == 19 | Nx.tensor([ 20 | [0, 1], 21 | [0, 0], 22 | [0, 0] 23 | ]) 24 | 25 | assert %CircularBuffer{size: size, index: index, data: data} = 26 | buffer = CircularBuffer.append(buffer, Nx.tensor([2, 3])) 27 | 28 | assert size == Nx.tensor(2) 29 | assert index == Nx.tensor(2) 30 | 31 | assert data == 32 | Nx.tensor([ 33 | [0, 1], 34 | [2, 3], 35 | [0, 0] 36 | ]) 37 | 38 | assert %CircularBuffer{size: size, index: index, data: data} = 39 | buffer = CircularBuffer.append(buffer, Nx.tensor([4, 5])) 40 | 41 | assert size == Nx.tensor(3) 42 | assert index == Nx.tensor(0) 43 | 44 | assert data == 45 | Nx.tensor([ 46 | [0, 1], 47 | [2, 3], 48 | [4, 5] 49 | ]) 50 | 51 | assert %CircularBuffer{size: size, index: index, data: data} = 52 | buffer = CircularBuffer.append(buffer, Nx.tensor([6, 7])) 53 | 54 | assert size == Nx.tensor(3) 55 | assert index == Nx.tensor(1) 56 | 57 | assert data == 58 | Nx.tensor([ 59 | [6, 7], 60 | [2, 3], 61 | [4, 5] 62 | ]) 63 | 64 | assert CircularBuffer.ordered_data(buffer) == 65 | Nx.tensor([ 66 | [2, 3], 67 | [4, 5], 68 | [6, 7] 69 | ]) 70 | end 71 | end 72 | -------------------------------------------------------------------------------- /mix.exs: -------------------------------------------------------------------------------- 1 | defmodule Rein.MixProject do 2 | use Mix.Project 3 | 4 | @source_url "https://github.com/DockYard/rein" 5 | @version "0.1.0" 6 | 7 | def project do 8 | [ 9 | app: :rein, 10 | version: "0.1.0", 11 | elixir: "~> 1.14", 12 | elixirc_paths: elixirc_paths(Mix.env()), 13 | compilers: Mix.compilers(), 14 | start_permanent: Mix.env() == :prod, 15 | deps: deps(), 16 | package: package(), 17 | docs: docs(), 18 | description: "Reinforcement Learning built with Nx", 19 | preferred_cli_env: [ 20 | docs: :docs, 21 | "hex.publish": :docs 22 | ] 23 | ] 24 | end 25 | 26 | # Configuration for the OTP application. 27 | # 28 | # Type `mix help compile.app` for more information. 29 | def application do 30 | [extra_applications: [:logger, :runtime_tools]] 31 | end 32 | 33 | # Specifies which paths to compile per environment. 34 | defp elixirc_paths(:test), do: ["lib", "test/support"] 35 | defp elixirc_paths(_), do: ["lib"] 36 | 37 | # Specifies your project dependencies. 38 | # 39 | # Type `mix help deps` for examples and options. 
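  # Note: the backend dependency (EXLA by default, Torchx, or none for the
  # binary backend) is chosen at compile time by backend/0 below, based on
  # the REIN_NX_BACKEND environment variable.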
40 | defp deps do 41 | [ 42 | {:ex_doc, "~> 0.30", only: :docs}, 43 | {:nx, "~> 0.6"}, 44 | {:axon, "~> 0.6"} 45 | | backend() 46 | ] 47 | end 48 | 49 | defp backend do 50 | case System.get_env("REIN_NX_BACKEND") do 51 | "torchx" -> 52 | [{:torchx, "~> 0.6"}] 53 | 54 | "binary" -> 55 | [] 56 | 57 | _ -> 58 | [{:exla, "~> 0.6"}] 59 | end 60 | end 61 | 62 | defp package do 63 | [ 64 | maintainers: ["Paulo Valente"], 65 | licenses: ["MIT"], 66 | links: %{"GitHub" => @source_url} 67 | ] 68 | end 69 | 70 | defp docs do 71 | [ 72 | main: "Rein", 73 | source_url_pattern: "#{@source_url}/blob/v#{@version}/rein/%{path}#L%{line}", 74 | extras: [ 75 | "guides/gridworld.livemd" 76 | ], 77 | groups_for_functions: [], 78 | groups_for_modules: [ 79 | Agents: [ 80 | Rein.Agents.QLearning, 81 | Rein.Agents.DQN, 82 | Rein.Agents.DDPG, 83 | Rein.Agents.SAC 84 | ], 85 | Environments: [ 86 | Rein.Environments.Gridworld 87 | ], 88 | Utils: [ 89 | Rein.Utils.CircularBuffer, 90 | Rein.Utils.Noise.OUProcess 91 | ] 92 | ] 93 | ] 94 | end 95 | end 96 | -------------------------------------------------------------------------------- /mix.lock: -------------------------------------------------------------------------------- 1 | %{ 2 | "axon": {:hex, :axon, "0.6.0", "fd7560079581e4cedebaf0cd5f741d6ac3516d06f204ebaf1283b1093bf66ff6", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:kino_vega_lite, "~> 0.1.7", [hex: :kino_vega_lite, repo: "hexpm", optional: true]}, {:nx, "~> 0.6.0", [hex: :nx, repo: "hexpm", optional: false]}, {:polaris, "~> 0.1", [hex: :polaris, repo: "hexpm", optional: false]}, {:table_rex, "~> 3.1.1", [hex: :table_rex, repo: "hexpm", optional: true]}], "hexpm", "204e7aeb50d231a30b25456adf17bfbaae33fe7c085e03793357ac3bf62fd853"}, 3 | "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, 4 | "earmark_parser": {:hex, :earmark_parser, "1.4.33", "3c3fd9673bb5dcc9edc28dd90f50c87ce506d1f71b70e3de69aa8154bc695d44", [:mix], [], "hexpm", "2d526833729b59b9fdb85785078697c72ac5e5066350663e5be6a1182da61b8f"}, 5 | "elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"}, 6 | "ex_doc": {:hex, :ex_doc, "0.30.5", "aa6da96a5c23389d7dc7c381eba862710e108cee9cfdc629b7ec021313900e9e", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "88a1e115dcb91cefeef7e22df4a6ebbe4634fbf98b38adcbc25c9607d6d9d8e6"}, 7 | "exla": {:hex, :exla, "0.6.0", "af63e45ce41ad25630967923147d14292a0cc48e507b8a3cf3bf3d5483099a28", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.6.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.5.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "5f6a4a105ea9ab207b9aa4de5a294730e2bfe9639f4b8d37a7c00da131090d7a"}, 8 | "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 
1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, 9 | "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, 10 | "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, 11 | "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, 12 | "nx": {:hex, :nx, "0.6.0", "37c86eae824125a7e298dd1ee896953d9d671ce3630dcff74c77db17d734a85f", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e1ad3cc70a5828a1aedb156b71e90863d9623a2dc9b35a5588f8627a07ee6cb4"}, 13 | "polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"}, 14 | "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, 15 | "xla": {:hex, :xla, "0.5.0", "fb8a02c02e5a4f4531fbf18a90c325e471037f983f0115d23f510e7dd9a6aa65", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "571ac797a4244b8ba8552ed0295a54397bd896708be51e4da6cbb784f6678061"}, 16 | } 17 | -------------------------------------------------------------------------------- /lib/rein/utils/circular_buffer.ex: -------------------------------------------------------------------------------- 1 | defmodule Rein.Utils.CircularBuffer do 2 | @moduledoc """ 3 | Circular Buffer utility via Nx Containers. 4 | """ 5 | 6 | import Nx.Defn 7 | 8 | @derive {Nx.Container, containers: [:data, :index, :size], keep: []} 9 | defstruct [:data, :index, :size] 10 | 11 | @doc """ 12 | Creates a new `#{__MODULE__}` with a given shape. 13 | 14 | ## Options 15 | 16 | * `:init_value` - a number or tensor that will be broadcasted 17 | to `shape`. Defaults to `0`. If the value given is vectorized, 18 | the buffer will be vectorized accordingly, but all entries will 19 | share the same current index and size. 20 | 21 | * `:type` - the type for the tensor if `:init_value`. 22 | Defaults to `:f32` 23 | """ 24 | def new(shape, opts \\ [init_value: 0, type: :f32]) do 25 | opts = Keyword.validate!(opts, init_value: 0, type: :f32) 26 | 27 | init_value = 28 | opts[:init_value] 29 | |> Nx.to_tensor() 30 | |> Nx.as_type(opts[:type]) 31 | 32 | %__MODULE__{ 33 | data: Nx.broadcast(init_value, shape), 34 | size: 0, 35 | index: 0 36 | } 37 | end 38 | 39 | @doc """ 40 | Append an item to the current buffer. 
41 | 42 | If the `buffer` data has shape `{a, b, c, ...}`, 43 | `item` must have shape `{b, c, ...}` 44 | """ 45 | deftransform append(buffer, item) do 46 | starts = append_start_indices(buffer) 47 | n = Nx.axis_size(buffer.data, 0) 48 | index = Nx.remainder(Nx.add(buffer.index, 1), n) 49 | size = Nx.min(n, Nx.add(buffer.size, 1)) 50 | 51 | data = 52 | case buffer.data.vectorized_axes do 53 | [] -> 54 | Nx.put_slice(buffer.data, starts, Nx.new_axis(item, 0)) 55 | 56 | _ -> 57 | [data, item | starts] = Nx.broadcast_vectors([buffer.data, item | starts]) 58 | axes = data.vectorized_axes 59 | 60 | data = 61 | Nx.revectorize(data, [], target_shape: Tuple.insert_at(buffer.data.shape, 0, :auto)) 62 | 63 | starts = Enum.map(starts, &Nx.revectorize(&1, [], target_shape: {:auto})) 64 | item = Nx.revectorize(item, [], target_shape: Tuple.insert_at(item.shape, 0, :auto)) 65 | 66 | for i <- 0..(Nx.axis_size(data, 0) - 1), reduce: data do 67 | data -> 68 | starts = Enum.map(starts, & &1[i]) 69 | 70 | item = item[i..i] 71 | 72 | Nx.put_slice( 73 | data, 74 | [i | starts], 75 | Nx.reshape( 76 | item, 77 | Tuple.duplicate(1, Nx.rank(data) - 1) |> Tuple.append(Nx.size(item)) 78 | ) 79 | ) 80 | end 81 | |> Nx.vectorize(axes) 82 | end 83 | 84 | %{ 85 | buffer 86 | | data: data, 87 | size: size, 88 | index: index 89 | } 90 | end 91 | 92 | deftransformp append_start_indices(buffer) do 93 | [buffer.index | List.duplicate(0, tuple_size(buffer.data.shape) - 1)] 94 | end 95 | 96 | @doc """ 97 | Append multiple items to the buffer. 98 | 99 | Works in a similar fashion to `append/2`, but receives 100 | a tensor with shape equal to the buffer data except 101 | for the first axis, which will be the number of items to be appended. 102 | """ 103 | deftransform append_multiple(buffer, items) do 104 | starts = append_start_indices(buffer) 105 | n = Nx.axis_size(buffer.data, 0) 106 | 107 | case buffer.data.vectorized_axes do 108 | [] -> 109 | for i <- 0..(Nx.axis_size(items, 0) - 1), reduce: buffer do 110 | buffer -> 111 | %{ 112 | buffer 113 | | index: Nx.remainder(Nx.add(buffer.index, 1), n), 114 | data: Nx.put_slice(buffer.data, starts, Nx.new_axis(items[i], 0)), 115 | size: Nx.min(n, Nx.add(buffer.size, 1)) 116 | } 117 | end 118 | 119 | _ -> 120 | raise "not implemented for vectorized buffer" 121 | end 122 | end 123 | 124 | @doc """ 125 | Returns the data starting at the current index. 126 | 127 | The oldest persisted entry will be the first entry in 128 | the result, and so on. 129 | """ 130 | defn ordered_data(buffer) do 131 | n = elem(buffer.data.shape, 0) 132 | indices = Nx.remainder(Nx.iota({n}) + buffer.index, n) 133 | Nx.take(buffer.data, indices) 134 | end 135 | end 136 | -------------------------------------------------------------------------------- /lib/rein/environments/gridworld.ex: -------------------------------------------------------------------------------- 1 | defmodule Rein.Environments.Gridworld do 2 | @moduledoc """ 3 | Gridworld environment with 4 discrete actions. 4 | 5 | Gridworld is an environment where the agent 6 | aims to reach a given target from a collection 7 | of possible targets, only being able to choose 8 | 1 of 4 actions: up, down, left and right. 
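  A minimal initialization sketch (the target below is arbitrary; targets are
  given as a `{n, 2}` tensor of `{x, y}` coordinates, normally within
  `bounding_box/0`):

      key = Nx.Random.key(42)
      possible_targets = Nx.tensor([[2, 4]])

      {environment, key} =
        Rein.Environments.Gridworld.init(key, possible_targets: possible_targets)

  In practice, the environment is usually handed to `Rein.train` as
  `{Rein.Environments.Gridworld, possible_targets: possible_targets}`,
  as shown in the Gridworld guide.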
9 | """ 10 | import Nx.Defn 11 | 12 | @behaviour Rein.Environment 13 | 14 | @derive {Nx.Container, 15 | containers: [ 16 | :x, 17 | :y, 18 | :prev_x, 19 | :prev_y, 20 | :target_x, 21 | :target_y, 22 | :reward, 23 | :is_terminal, 24 | :possible_targets, 25 | :has_reached_target 26 | ], 27 | keep: []} 28 | defstruct [ 29 | :x, 30 | :y, 31 | :prev_x, 32 | :prev_y, 33 | :target_x, 34 | :target_y, 35 | :reward, 36 | :is_terminal, 37 | :possible_targets, 38 | :has_reached_target 39 | ] 40 | 41 | @min_x 0 42 | @max_x 4 43 | @min_y 0 44 | @max_y 4 45 | 46 | def bounding_box, do: {@min_x, @max_x, @min_y, @max_y} 47 | 48 | # x, y, target_x, target_y, has_reached_target, distance_norm 49 | @doc "The size of the state vector returned by `as_state_vector/1`" 50 | def state_vector_size, do: 6 51 | 52 | # up, down, left, right 53 | def num_actions, do: 4 54 | 55 | @impl true 56 | def init(random_key, opts) do 57 | opts = Keyword.validate!(opts, [:possible_targets]) 58 | 59 | possible_targets = 60 | opts[:possible_targets] || raise ArgumentError, "missing option :possible_targets" 61 | 62 | reset(random_key, %__MODULE__{possible_targets: possible_targets}) 63 | end 64 | 65 | @impl true 66 | def reset(random_key, %__MODULE__{} = state) do 67 | reward = Nx.tensor(0, type: :f32) 68 | {x, random_key} = Nx.Random.randint(random_key, @min_x, @max_x) 69 | 70 | # possible_targets is a {n, 2} tensor that contains targets that we want to sample from 71 | # this is so we avoid retraining every episode on the same target, which can lead to 72 | # overfitting 73 | {target, random_key} = 74 | Nx.Random.choice(random_key, state.possible_targets, samples: 1, axis: 0) 75 | 76 | target = Nx.reshape(target, {2}) 77 | 78 | target_x = target[0] 79 | target_y = target[1] 80 | 81 | y = Nx.tensor(0, type: :s64) 82 | 83 | # [x, y, target_x, target_y, zero_bool, _key] = 84 | # Nx.broadcast_vectors([x, y, target[0], target[1], random_key, Nx.u8(0)]) 85 | 86 | state = %{ 87 | state 88 | | x: x, 89 | y: y, 90 | prev_x: x, 91 | prev_y: y, 92 | target_x: target_x, 93 | target_y: target_y, 94 | reward: reward, 95 | is_terminal: Nx.u8(0), 96 | has_reached_target: Nx.u8(0) 97 | } 98 | 99 | {state, random_key} 100 | end 101 | 102 | @impl true 103 | defn apply_action(state, action) do 104 | %{x: x, y: y} = env = state.environment_state 105 | 106 | # 0: up, 1: down, 2: right, 3: left 107 | {new_x, new_y} = 108 | cond do 109 | action == 0 -> 110 | {x, y + 1} 111 | 112 | action == 1 -> 113 | {x, y - 1} 114 | 115 | action == 2 -> 116 | {x + 1, y} 117 | 118 | true -> 119 | {x - 1, y} 120 | end 121 | 122 | new_env = %{ 123 | env 124 | | x: Nx.clip(new_x, @min_x, @max_x), 125 | y: Nx.clip(new_y, @min_y, @max_y), 126 | prev_x: x, 127 | prev_y: y 128 | } 129 | 130 | updated_env = 131 | new_env 132 | |> is_terminal_state() 133 | |> calculate_reward() 134 | 135 | %{state | environment_state: updated_env} 136 | end 137 | 138 | defnp calculate_reward(env) do 139 | distance = Nx.abs(env.target_x - env.x) + Nx.abs(env.target_y - env.y) 140 | reward = -1.0 * distance 141 | 142 | %{env | reward: reward} 143 | end 144 | 145 | defnp is_terminal_state(env) do 146 | has_reached_target = has_reached_target(env) 147 | out_of_bounds = env.x < @min_x or env.x > @max_x or env.y < @min_y or env.y > @max_y 148 | 149 | is_terminal = has_reached_target or out_of_bounds 150 | 151 | %__MODULE__{env | is_terminal: is_terminal, has_reached_target: has_reached_target} 152 | end 153 | 154 | defnp has_reached_target(%__MODULE__{x: x, y: y, target_x: target_x, target_y: 
target_y}) do 155 | target_x == x and target_y == y 156 | end 157 | 158 | defnp normalize(v, min, max), do: (v - min) / (max - min) 159 | 160 | @doc """ 161 | Default function for turning the environment into a vector representation. 162 | """ 163 | defn as_state_vector(%{ 164 | x: x, 165 | y: y, 166 | target_x: target_x, 167 | target_y: target_y, 168 | has_reached_target: has_reached_target 169 | }) do 170 | x = normalize(x, @min_x, @max_x) 171 | y = normalize(y, @min_y, @max_y) 172 | 173 | target_x = normalize(target_x, @min_x, @max_x) 174 | target_y = normalize(target_y, @min_y, @max_y) 175 | 176 | # max distance is sqrt(1 ** 2 + 1 ** 2) = sqrt(2) 177 | distance_norm = Nx.sqrt((x - target_x) ** 2 + (y - target_y) ** 2) / Nx.sqrt(2) 178 | 179 | Nx.stack([ 180 | x, 181 | y, 182 | target_x, 183 | target_y, 184 | has_reached_target, 185 | distance_norm 186 | ]) 187 | |> Nx.new_axis(0) 188 | end 189 | end 190 | -------------------------------------------------------------------------------- /lib/rein/agents/q_learning.ex: -------------------------------------------------------------------------------- 1 | defmodule Rein.Agents.QLearning do 2 | @moduledoc """ 3 | Q-Learning implementation. 4 | 5 | This implementation uses epsilon-greedy sampling 6 | for exploration, and doesn't contemplate any kind 7 | of target network. 8 | """ 9 | 10 | import Nx.Defn 11 | 12 | @behaviour Rein.Agent 13 | 14 | @derive {Nx.Container, 15 | containers: [ 16 | :q_matrix, 17 | :observation 18 | ], 19 | keep: [ 20 | :num_actions, 21 | :environment_to_state_vector_fn, 22 | :learning_rate, 23 | :gamma, 24 | :exploration_eps, 25 | :state_space_shape 26 | ]} 27 | 28 | defstruct [ 29 | :q_matrix, 30 | :observation, 31 | :num_actions, 32 | :environment_to_state_vector_fn, 33 | :learning_rate, 34 | :gamma, 35 | :exploration_eps, 36 | :state_space_shape 37 | ] 38 | 39 | @impl true 40 | def init(random_key, opts \\ []) do 41 | opts = 42 | Keyword.validate!(opts, [ 43 | :state_space_shape, 44 | :num_actions, 45 | :environment_to_state_vector_fn, 46 | :learning_rate, 47 | :gamma, 48 | :exploration_eps 49 | ]) 50 | 51 | state_space_shape = opts[:state_space_shape] 52 | num_actions = opts[:num_actions] 53 | 54 | # q_matrix is a tensor in which the state_vector indexes the axis 0 55 | # as linear indices, and axis 1 is the action axis 56 | {q_matrix, random_key} = 57 | Nx.Random.uniform(random_key, -0.1, 0.1, 58 | shape: {Tuple.product(state_space_shape), num_actions} 59 | ) 60 | 61 | state = %__MODULE__{ 62 | q_matrix: q_matrix, 63 | environment_to_state_vector_fn: opts[:environment_to_state_vector_fn], 64 | learning_rate: opts[:learning_rate], 65 | gamma: opts[:gamma], 66 | exploration_eps: opts[:exploration_eps], 67 | state_space_shape: state_space_shape, 68 | num_actions: num_actions 69 | } 70 | 71 | reset(random_key, state) 72 | end 73 | 74 | @impl true 75 | def reset(random_key, %Rein{agent_state: agent_state}), do: reset(random_key, agent_state) 76 | 77 | def reset(random_key, %__MODULE__{} = agent_state) do 78 | zero = Nx.tensor(0, type: :f32) 79 | 80 | observation = %{ 81 | action: 0, 82 | state: 0, 83 | next_state: 0, 84 | reward: zero 85 | } 86 | 87 | {%__MODULE__{agent_state | observation: observation}, random_key} 88 | end 89 | 90 | @impl true 91 | defn select_action( 92 | %Rein{random_key: random_key, agent_state: agent_state} = state, 93 | _iteration 94 | ) do 95 | {sample, random_key} = Nx.Random.uniform(random_key, shape: {}) 96 | 97 | state_vector = 
agent_state.environment_to_state_vector_fn.(state.environment_state) 98 | 99 | {action, random_key} = 100 | if sample < agent_state.exploration_eps do 101 | Nx.Random.randint(random_key, 0, agent_state.num_actions, shape: {}) 102 | else 103 | idx = state_vector_to_index(state_vector, agent_state.state_space_shape) 104 | 105 | action = Nx.argmax(agent_state.q_matrix[idx]) 106 | 107 | {action, random_key} 108 | end 109 | 110 | {action, %{state | random_key: random_key}} 111 | end 112 | 113 | @impl true 114 | deftransform record_observation( 115 | %{ 116 | environment_state: env_state, 117 | agent_state: %{ 118 | environment_to_state_vector_fn: environment_to_state_vector_fn, 119 | state_space_shape: state_space_shape 120 | } 121 | }, 122 | action, 123 | reward, 124 | _is_terminal, 125 | %{environment_state: next_env_state} = state 126 | ) do 127 | observation = %{ 128 | state: 129 | env_state 130 | |> environment_to_state_vector_fn.() 131 | |> state_vector_to_index(state_space_shape), 132 | next_state: 133 | next_env_state 134 | |> environment_to_state_vector_fn.() 135 | |> state_vector_to_index(state_space_shape), 136 | reward: reward, 137 | action: action 138 | } 139 | 140 | put_in(state.agent_state.observation, observation) 141 | end 142 | 143 | @impl true 144 | defn optimize_model(rl_state) do 145 | %{ 146 | observation: %{ 147 | state: state, 148 | next_state: next_state, 149 | reward: reward, 150 | action: action 151 | }, 152 | q_matrix: q_matrix, 153 | gamma: gamma, 154 | learning_rate: learning_rate 155 | } = rl_state.agent_state 156 | 157 | # Q_table[current_state, action] = 158 | # (1-lr) * Q_table[current_state, action] + 159 | # lr*(reward + gamma*max(Q_table[next_state,:])) 160 | 161 | q = 162 | (1 - learning_rate) * q_matrix[[state, action]] + 163 | learning_rate * (reward + gamma * Nx.reduce_max(q_matrix[next_state])) 164 | 165 | q_matrix = Nx.indexed_put(q_matrix, Nx.stack([state, action]), q) 166 | 167 | %{rl_state | agent_state: %{rl_state.agent_state | q_matrix: q_matrix}} 168 | end 169 | 170 | deftransformp state_vector_to_index(state_vector, shape) do 171 | {linear_indices_offsets_list, _} = 172 | shape 173 | |> Tuple.to_list() 174 | |> Enum.reverse() 175 | |> Enum.reduce({[], 1}, fn x, {acc, multiplier} -> 176 | {[multiplier | acc], multiplier * x} 177 | end) 178 | 179 | linear_indices_offsets = Nx.tensor(linear_indices_offsets_list) 180 | 181 | Nx.dot(state_vector, linear_indices_offsets) 182 | end 183 | end 184 | -------------------------------------------------------------------------------- /guides/gridworld.livemd: -------------------------------------------------------------------------------- 1 | # First steps with Gridworld 2 | 3 | ```elixir 4 | my_app_root = Path.join(__DIR__, "..") 5 | 6 | Mix.install( 7 | [ 8 | {:rein, path: my_app_root}, 9 | {:kino_vega_lite, "~> 0.1"} 10 | ], 11 | config_path: Path.join(my_app_root, "config/config.exs"), 12 | lockfile: Path.join(my_app_root, "mix.lock"), 13 | # change to "cuda118" or "cuda120" to use CUDA 14 | system_env: %{"XLA_TARGET" => "cpu"} 15 | ) 16 | ``` 17 | 18 | ## Initializing the plot 19 | 20 | In the code block below, we initialize some meta variables and configure our VegaLite plot in way that it can be updated iteratively over the algorithm iterations. 
21 | 22 | ```elixir 23 | alias VegaLite, as: Vl 24 | 25 | {min_x, max_x, min_y, max_y} = Rein.Environments.Gridworld.bounding_box() 26 | 27 | possible_targets_l = [[round((min_x + max_x) / 2), max_y]] 28 | 29 | # possible_targets_l = 30 | # for x <- (min_x + 2)..(max_x - 2), y <- 2..max_y do 31 | # [x, y] 32 | # end 33 | 34 | possible_targets = Nx.tensor(Enum.shuffle(possible_targets_l)) 35 | 36 | width = 600 37 | height = 600 38 | 39 | grid_widget = 40 | Vl.new(width: width, height: height) 41 | |> Vl.layers([ 42 | Vl.new() 43 | |> Vl.data(name: "target") 44 | |> Vl.mark(:point, 45 | fill: true, 46 | tooltip: [content: "data"], 47 | grid: true, 48 | size: [expr: "height * 4 * #{:math.pi()} / #{max_y - min_y}"] 49 | ) 50 | |> Vl.encode_field(:x, "x", type: :quantitative) 51 | |> Vl.encode_field(:y, "y", type: :quantitative) 52 | |> Vl.encode_field(:color, "episode", 53 | type: :nominal, 54 | scale: [scheme: "blues"], 55 | legend: false 56 | ), 57 | Vl.new() 58 | |> Vl.data(name: "trajectory") 59 | |> Vl.mark(:line, point: true, opacity: 1, tooltip: [content: "data"]) 60 | |> Vl.encode_field(:x, "x", type: :quantitative, scale: [domain: [min_x, max_x], clamp: true]) 61 | |> Vl.encode_field(:y, "y", type: :quantitative, scale: [domain: [min_y, max_y], clamp: true]) 62 | |> Vl.encode_field(:order, "index") 63 | ]) 64 | |> Kino.VegaLite.new() 65 | |> Kino.render() 66 | 67 | nil 68 | ``` 69 | 70 | ## Configuring and running the Q Learning Agent 71 | 72 | Now we're ready to start configuring our agent. The `plot_fn` function defined below is a callback that `Rein` calls at the end of each episode, so that we can do anything with the data. 73 | 74 | Usually, this means that we'll extract data to either plot, report or save somewhere. 75 | 76 | ```elixir 77 | # 250 max_iter * 15 episodes 78 | max_points = 1000 79 | 80 | plot_fn = fn axon_state -> 81 | if axon_state.iteration > 1 do 82 | episode = axon_state.episode 83 | 84 | Kino.VegaLite.clear(grid_widget, dataset: "target") 85 | Kino.VegaLite.clear(grid_widget, dataset: "trajectory") 86 | 87 | Kino.VegaLite.push( 88 | grid_widget, 89 | %{ 90 | x: Nx.to_number(axon_state.step_state.environment_state.target_x), 91 | y: Nx.to_number(axon_state.step_state.environment_state.target_y) 92 | }, 93 | dataset: "target" 94 | ) 95 | 96 | IO.inspect("Episode #{episode} ended") 97 | 98 | trajectory = axon_state.step_state.trajectory 99 | 100 | iteration = Nx.to_number(axon_state.step_state.iteration) 101 | 102 | points = 103 | trajectory[0..(iteration - 1)//1] 104 | |> Nx.to_list() 105 | |> Enum.with_index(fn [x, y], index -> 106 | %{ 107 | x: x, 108 | y: y, 109 | index: index 110 | } 111 | end) 112 | 113 | Kino.VegaLite.push_many(grid_widget, points, dataset: "trajectory") 114 | end 115 | 116 | axon_state 117 | end 118 | ``` 119 | 120 | Now, we get to the actual training! 121 | 122 | The code below calls `Rein.train` with some configuration for the `Gridworld` environment being solved through a `QLearning` agent. 123 | 124 | This will return the whole `Axon.Loop` struct in the `result` variable, so that we can inspect and/or save it afterwards.
125 | 126 | ```elixir 127 | Kino.VegaLite.clear(grid_widget) 128 | 129 | episodes = 15_000 130 | max_iter = 20 131 | 132 | environment_to_state_vector_fn = fn %{x: x, y: y, target_x: target_x, target_y: target_y} -> 133 | delta_x = Nx.subtract(x, min_x) 134 | delta_y = Nx.subtract(y, min_y) 135 | 136 | Nx.stack([delta_x, delta_y, Nx.subtract(target_x, min_x), Nx.subtract(target_y, min_y)]) 137 | end 138 | 139 | state_to_trajectory_fn = fn %{environment_state: %{x: x, y: y}} -> 140 | Nx.stack([x, y]) 141 | end 142 | 143 | delta_x = max_x - min_x + 1 144 | delta_y = max_y - min_y + 1 145 | 146 | state_space_shape = {delta_x, delta_y, delta_x, delta_y} 147 | 148 | {t, result} = 149 | :timer.tc(fn -> 150 | Rein.train( 151 | {Rein.Environments.Gridworld, possible_targets: possible_targets}, 152 | {Rein.Agents.QLearning, 153 | state_space_shape: state_space_shape, 154 | num_actions: 4, 155 | environment_to_state_vector_fn: environment_to_state_vector_fn, 156 | learning_rate: 1.0e-2, 157 | gamma: 0.99, 158 | exploration_eps: 1.0e-4}, 159 | plot_fn, 160 | state_to_trajectory_fn, 161 | checkpoint_path: "/tmp/gridworld", 162 | num_episodes: episodes, 163 | max_iter: max_iter 164 | ) 165 | end) 166 | 167 | "#{Float.round(t / 1_000_000, 3)} s" 168 | ``` 169 | 170 | With the code below, we can check some points of interest in the learned Q matrix. 171 | 172 | Especially, we can see below that for a target at x = 2, y = 4: 173 | 174 | * For the position x = 2, y = 3, the selected action is to go up; 175 | * For the position x = 1, y = 4, the selected action is to go right; 176 | * For the position x = 3, y = 4, the selected action is to go left. 177 | 178 | This shows that at least for the positions closer to the target, our agent already knows the best policy for those respective states! 
179 | 180 | ```elixir 181 | state_vector_to_index = fn state_vector, shape -> 182 | {linear_indices_offsets_list, _} = 183 | shape 184 | |> Tuple.to_list() 185 | |> Enum.reverse() 186 | |> Enum.reduce({[], 1}, fn x, {acc, multiplier} -> 187 | {[multiplier | acc], multiplier * x} 188 | end) 189 | 190 | linear_indices_offsets = Nx.tensor(linear_indices_offsets_list) 191 | 192 | Nx.dot(state_vector, linear_indices_offsets) 193 | end 194 | 195 | # Actions are [up, down, right, left] 196 | 197 | # up 198 | idx = state_vector_to_index.(Nx.tensor([2, 3, 2, 4]), {5, 5, 5, 5}) 199 | IO.inspect(result.step_state.agent_state.q_matrix[idx]) 200 | 201 | # right 202 | idx = state_vector_to_index.(Nx.tensor([1, 4, 2, 4]), {5, 5, 5, 5}) 203 | IO.inspect(result.step_state.agent_state.q_matrix[idx]) 204 | 205 | # left 206 | idx = state_vector_to_index.(Nx.tensor([3, 4, 2, 4]), {5, 5, 5, 5}) 207 | IO.inspect(result.step_state.agent_state.q_matrix[idx]) 208 | 209 | nil 210 | ``` 211 | -------------------------------------------------------------------------------- /lib/rein.ex: -------------------------------------------------------------------------------- 1 | defmodule Rein do 2 | @moduledoc """ 3 | Reinforcement Learning training and inference framework 4 | """ 5 | 6 | import Nx.Defn 7 | 8 | @type t :: %__MODULE__{ 9 | agent_state: term(), 10 | environment_state: term(), 11 | episode: Nx.t(), 12 | iteration: Nx.t(), 13 | random_key: Nx.t(), 14 | trajectory: Nx.t() 15 | } 16 | 17 | @derive {Nx.Container, 18 | containers: [ 19 | :agent_state, 20 | :environment_state, 21 | :random_key, 22 | :iteration, 23 | :episode, 24 | :trajectory 25 | ], 26 | keep: []} 27 | defstruct [ 28 | :agent_state, 29 | :environment_state, 30 | :random_key, 31 | :iteration, 32 | :episode, 33 | :trajectory 34 | ] 35 | 36 | @spec train( 37 | {environment :: module, init_opts :: keyword()}, 38 | {agent :: module, init_opts :: keyword}, 39 | episode_completed_callback :: (map() -> :ok), 40 | state_to_trajectory_fn :: (t() -> Nx.t()), 41 | opts :: keyword() 42 | ) :: term() 43 | # underscore vars below for doc names 44 | def train( 45 | _environment_with_options = {environment, environment_init_opts}, 46 | _agent_with_options = {agent, agent_init_opts}, 47 | episode_completed_callback, 48 | state_to_trajectory_fn, 49 | opts \\ [] 50 | ) do 51 | opts = 52 | Keyword.validate!(opts, [ 53 | :random_key, 54 | :max_iter, 55 | :model_name, 56 | :checkpoint_path, 57 | checkpoint_serialization_fn: &Nx.serialize/1, 58 | accumulated_episodes: 0, 59 | num_episodes: 100, 60 | checkpoint_filter_fn: fn _state, episode -> rem(episode, 500) == 0 end 61 | ]) 62 | 63 | random_key = opts[:random_key] || Nx.Random.key(System.system_time()) 64 | max_iter = opts[:max_iter] 65 | num_episodes = opts[:num_episodes] 66 | model_name = opts[:model_name] 67 | 68 | {init_agent_state, random_key} = agent.init(random_key, agent_init_opts) 69 | 70 | episode = Nx.tensor(opts[:accumulated_episodes], type: :s64) 71 | iteration = Nx.tensor(0, type: :s64) 72 | 73 | [episode, iteration, _] = 74 | Nx.broadcast_vectors([episode, iteration, random_key], align_ranks: false) 75 | 76 | {environment_state, random_key} = environment.init(random_key, environment_init_opts) 77 | 78 | {agent_state, random_key} = 79 | agent.reset(random_key, %__MODULE__{ 80 | environment_state: environment_state, 81 | agent_state: init_agent_state, 82 | episode: episode 83 | }) 84 | 85 | initial_state = %__MODULE__{ 86 | agent_state: agent_state, 87 | environment_state: environment_state, 88 | random_key: 
random_key, 89 | iteration: iteration, 90 | episode: episode 91 | } 92 | 93 | %Nx.Tensor{shape: {trajectory_points}} = state_to_trajectory_fn.(initial_state) 94 | 95 | trajectory = Nx.broadcast(Nx.tensor(:nan, type: :f32), {max_iter + 1, trajectory_points}) 96 | [trajectory, _] = Nx.broadcast_vectors([trajectory, random_key], align_ranks: false) 97 | 98 | initial_state = %__MODULE__{initial_state | trajectory: trajectory} 99 | 100 | loop( 101 | agent, 102 | environment, 103 | initial_state, 104 | episode_completed_callback: episode_completed_callback, 105 | state_to_trajectory_fn: state_to_trajectory_fn, 106 | num_episodes: num_episodes, 107 | max_iter: max_iter, 108 | model_name: model_name, 109 | checkpoint_path: opts[:checkpoint_path], 110 | output_transform: opts[:output_transform], 111 | checkpoint_serialization_fn: opts[:checkpoint_serialization_fn], 112 | checkpoint_filter_fn: opts[:checkpoint_filter_fn] 113 | ) 114 | end 115 | 116 | defp loop(agent, environment, initial_state, opts) do 117 | episode_completed_callback = Keyword.fetch!(opts, :episode_completed_callback) 118 | state_to_trajectory_fn = Keyword.fetch!(opts, :state_to_trajectory_fn) 119 | num_episodes = Keyword.fetch!(opts, :num_episodes) 120 | max_iter = Keyword.fetch!(opts, :max_iter) 121 | 122 | Enum.reduce(1..num_episodes, initial_state, fn episode, state_outer -> 123 | Enum.reduce_while( 124 | 1..max_iter, 125 | {reset_state(state_outer, agent, environment, state_to_trajectory_fn), 0}, 126 | fn iteration, {state, _iter} -> 127 | next_state = batch_step(state, agent, environment, state_to_trajectory_fn) 128 | 129 | is_terminal = 130 | next_state.environment_state.is_terminal 131 | |> Nx.devectorize() 132 | |> Nx.all() 133 | |> Nx.to_number() 134 | 135 | if is_terminal == 1 do 136 | {:halt, {next_state, iteration}} 137 | else 138 | {:cont, {next_state, iteration}} 139 | end 140 | end 141 | ) 142 | |> then(fn {state, iteration} -> 143 | episode_completed_callback.(%{step_state: state, episode: episode, iteration: iteration}) 144 | state 145 | end) 146 | |> tap( 147 | &checkpoint( 148 | &1, 149 | episode, 150 | opts[:model_name], 151 | opts[:checkpoint_path], 152 | opts[:checkpoint_serialization_fn], 153 | opts[:checkpoint_filter_fn] 154 | ) 155 | ) 156 | end) 157 | end 158 | 159 | defp checkpoint( 160 | state, 161 | episode, 162 | model_name, 163 | checkpoint_path, 164 | checkpoint_serialization_fn, 165 | checkpoint_filter_fn 166 | ) do 167 | if checkpoint_filter_fn.(state, episode) do 168 | serialized = checkpoint_serialization_fn.(state) 169 | File.write!(Path.join(checkpoint_path, "#{model_name}_#{episode}.ckpt"), serialized) 170 | File.write!(Path.join(checkpoint_path, "#{model_name}_latest.ckpt"), serialized) 171 | end 172 | end 173 | 174 | defp reset_state( 175 | %__MODULE__{ 176 | environment_state: environment_state, 177 | random_key: random_key 178 | } = loop_state, 179 | agent, 180 | environment, 181 | state_to_trajectory_fn 182 | ) do 183 | {environment_state, random_key} = environment.reset(random_key, environment_state) 184 | 185 | {agent_state, random_key} = 186 | agent.reset(random_key, %{loop_state | environment_state: environment_state}) 187 | 188 | state = %{ 189 | loop_state 190 | | agent_state: agent_state, 191 | environment_state: environment_state, 192 | random_key: random_key, 193 | trajectory: Nx.broadcast(Nx.tensor(:nan, type: :f32), loop_state.trajectory), 194 | episode: Nx.add(loop_state.episode, 1), 195 | iteration: Nx.tensor(0, type: :s64) 196 | } 197 | 198 | persist_trajectory(state, 
state_to_trajectory_fn) 199 | end 200 | 201 | defp batch_step( 202 | prev_state, 203 | agent, 204 | environment, 205 | state_to_trajectory_fn 206 | ) do 207 | {action, state} = agent.select_action(prev_state, prev_state.iteration) 208 | 209 | %{environment_state: %{reward: reward, is_terminal: is_terminal}} = 210 | state = environment.apply_action(state, action) 211 | 212 | prev_state 213 | |> agent.record_observation( 214 | action, 215 | reward, 216 | is_terminal, 217 | state 218 | ) 219 | |> agent.optimize_model() 220 | |> persist_trajectory(state_to_trajectory_fn) 221 | end 222 | 223 | defnp persist_trajectory( 224 | %__MODULE__{trajectory: trajectory, iteration: iteration} = step_state, 225 | state_to_trajectory_fn 226 | ) do 227 | updates = state_to_trajectory_fn.(step_state) 228 | 229 | %Nx.Tensor{shape: {_, num_points}} = trajectory 230 | 231 | idx = 232 | Nx.concatenate([Nx.broadcast(iteration, {num_points, 1}), Nx.iota({num_points, 1})], 233 | axis: 1 234 | ) 235 | 236 | trajectory = Nx.indexed_put(trajectory, idx, updates) 237 | %{step_state | trajectory: trajectory, iteration: iteration + 1} 238 | end 239 | end 240 | -------------------------------------------------------------------------------- /lib/rein/agents/dqn.ex: -------------------------------------------------------------------------------- 1 | defmodule Rein.Agents.DQN do 2 | @moduledoc """ 3 | Deep Q-Learning implementation. 4 | 5 | This implementation utilizes a single target network for 6 | the policy network. 7 | """ 8 | import Nx.Defn 9 | 10 | @behaviour Rein.Agent 11 | 12 | @learning_rate 1.0e-3 13 | @adamw_decay 1.0e-2 14 | @eps 1.0e-7 15 | @experience_replay_buffer_num_entries 10_000 16 | 17 | @eps_start 1 18 | @eps_decay_rate 0.995 19 | @eps_increase_rate 1.005 20 | @eps_end 0.01 21 | 22 | @train_every_steps 32 23 | @adamw_decay 0.01 24 | 25 | @batch_size 128 26 | 27 | @gamma 0.99 28 | @tau 0.001 29 | 30 | @derive {Nx.Container, 31 | containers: [ 32 | :q_policy, 33 | :q_target, 34 | :q_policy_optimizer_state, 35 | :loss, 36 | :loss_denominator, 37 | :experience_replay_buffer, 38 | :experience_replay_buffer_index, 39 | :persisted_experience_replay_buffer_entries, 40 | :total_reward, 41 | :epsilon_greedy_eps, 42 | :exploration_decay_rate, 43 | :exploration_increase_rate, 44 | :min_eps, 45 | :max_eps, 46 | :performance_memory, 47 | :performance_threshold 48 | ], 49 | keep: [ 50 | :optimizer_update_fn, 51 | :policy_predict_fn, 52 | :input_template, 53 | :state_vector_size, 54 | :num_actions, 55 | :environment_to_input_fn, 56 | :environment_to_state_vector_fn, 57 | :state_vector_to_input_fn, 58 | :learning_rate, 59 | :batch_size, 60 | :training_frequency, 61 | :target_training_frequency, 62 | :gamma 63 | ]} 64 | 65 | defstruct [ 66 | :state_vector_size, 67 | :num_actions, 68 | :q_policy, 69 | :q_target, 70 | :q_policy_optimizer_state, 71 | :policy_predict_fn, 72 | :optimizer_update_fn, 73 | :experience_replay_buffer, 74 | :experience_replay_buffer_index, 75 | :persisted_experience_replay_buffer_entries, 76 | :loss, 77 | :loss_denominator, 78 | :total_reward, 79 | :environment_to_input_fn, 80 | :environment_to_state_vector_fn, 81 | :state_vector_to_input_fn, 82 | :input_template, 83 | :learning_rate, 84 | :batch_size, 85 | :training_frequency, 86 | :target_training_frequency, 87 | :gamma, 88 | :epsilon_greedy_eps, 89 | :exploration_decay_rate, 90 | :exploration_increase_rate, 91 | :min_eps, 92 | :max_eps, 93 | :performance_memory, 94 | :performance_threshold 95 | ] 96 | 97 | @impl true 98 | def init(random_key, 
opts \\ []) do 99 | opts = 100 | Keyword.validate!(opts, [ 101 | :q_policy, 102 | :q_target, 103 | :policy_net, 104 | :experience_replay_buffer, 105 | :experience_replay_buffer_index, 106 | :persisted_experience_replay_buffer_entries, 107 | :environment_to_input_fn, 108 | :environment_to_state_vector_fn, 109 | :state_vector_to_input_fn, 110 | :performance_memory, 111 | target_training_frequency: @train_every_steps * 4, 112 | learning_rate: @learning_rate, 113 | batch_size: @batch_size, 114 | training_frequency: @train_every_steps, 115 | gamma: @gamma, 116 | eps_decay_rate: @eps_decay_rate, 117 | exploration_decay_rate: @eps_decay_rate, 118 | exploration_increase_rate: @eps_increase_rate, 119 | min_eps: @eps_end, 120 | max_eps: @eps_start, 121 | performance_memory_length: 500, 122 | performance_threshold: 0.01 123 | ]) 124 | 125 | policy_net = opts[:policy_net] || raise ArgumentError, "missing :policy_net option" 126 | 127 | environment_to_input_fn = 128 | opts[:environment_to_input_fn] || 129 | raise ArgumentError, "missing :environment_to_input_fn option" 130 | 131 | environment_to_state_vector_fn = 132 | opts[:environment_to_state_vector_fn] || 133 | raise ArgumentError, "missing :environment_to_state_vector_fn option" 134 | 135 | state_vector_to_input_fn = 136 | opts[:state_vector_to_input_fn] || 137 | raise ArgumentError, "missing :state_vector_to_input_fn option" 138 | 139 | {policy_init_fn, policy_predict_fn} = Axon.build(policy_net, seed: 0) 140 | 141 | # TO-DO: receive optimizer as argument 142 | {optimizer_init_fn, optimizer_update_fn} = 143 | Polaris.Updates.clip_by_global_norm() 144 | |> Polaris.Updates.compose( 145 | Polaris.Optimizers.adamw(learning_rate: @learning_rate, eps: @eps, decay: @adamw_decay) 146 | ) 147 | 148 | initial_q_policy_state = opts[:q_policy] || raise "missing initial q_policy" 149 | initial_q_target_state = opts[:q_target] || initial_q_policy_state 150 | 151 | input_template = input_template(policy_net) 152 | 153 | q_policy = policy_init_fn.(input_template, initial_q_policy_state) 154 | q_target = policy_init_fn.(input_template, initial_q_target_state) 155 | 156 | q_policy_optimizer_state = optimizer_init_fn.(q_policy) 157 | 158 | {1, num_actions} = Axon.get_output_shape(policy_net, input_template) 159 | 160 | state_vector_size = state_vector_size(input_template) 161 | 162 | loss = loss_denominator = total_reward = Nx.tensor(0, type: :f32) 163 | 164 | state = %__MODULE__{ 165 | learning_rate: opts[:learning_rate], 166 | total_reward: total_reward, 167 | batch_size: opts[:batch_size], 168 | training_frequency: opts[:training_frequency], 169 | target_training_frequency: opts[:target_training_frequency], 170 | gamma: opts[:gamma], 171 | loss: loss, 172 | loss_denominator: loss_denominator, 173 | state_vector_size: state_vector_size, 174 | num_actions: num_actions, 175 | input_template: input_template, 176 | environment_to_input_fn: environment_to_input_fn, 177 | environment_to_state_vector_fn: environment_to_state_vector_fn, 178 | state_vector_to_input_fn: state_vector_to_input_fn, 179 | q_policy: q_policy, 180 | q_policy_optimizer_state: q_policy_optimizer_state, 181 | q_target: q_target, 182 | policy_predict_fn: policy_predict_fn, 183 | optimizer_update_fn: optimizer_update_fn, 184 | # prev_state_vector, target_x, target_y, action, reward, is_terminal, next_state_vector 185 | experience_replay_buffer: 186 | opts[:experience_replay_buffer] || 187 | Nx.broadcast( 188 | Nx.tensor(:nan, type: :f32), 189 | {@experience_replay_buffer_num_entries, 2 * 
state_vector_size + 4} 190 | ), 191 | experience_replay_buffer_index: 192 | opts[:experience_replay_buffer_index] || Nx.tensor(0, type: :s64), 193 | persisted_experience_replay_buffer_entries: 194 | opts[:persisted_experience_replay_buffer_entries] || Nx.tensor(0, type: :s64), 195 | performance_threshold: opts[:performance_threshold], 196 | performance_memory: 197 | opts[:performance_memory] || 198 | Nx.broadcast(total_reward, {opts[:performance_memory_length]}), 199 | max_eps: opts[:max_eps], 200 | min_eps: opts[:min_eps], 201 | epsilon_greedy_eps: opts[:max_eps], 202 | exploration_decay_rate: opts[:exploration_decay_rate], 203 | exploration_increase_rate: opts[:exploration_increase_rate] 204 | } 205 | 206 | {state, random_key} 207 | end 208 | 209 | defp input_template(model) do 210 | model 211 | |> Axon.get_inputs() 212 | |> Map.new(fn {name, shape} -> 213 | [nil | shape] = Tuple.to_list(shape) 214 | shape = List.to_tuple([1 | shape]) 215 | {name, Nx.template(shape, :f32)} 216 | end) 217 | end 218 | 219 | defp state_vector_size(input_template) do 220 | Enum.reduce(input_template, 0, fn {_field, tensor}, acc -> 221 | div(Nx.size(tensor), Nx.axis_size(tensor, 0)) + acc 222 | end) 223 | end 224 | 225 | @impl true 226 | def reset(random_key, %Rein{agent_state: state, episode: episode}) do 227 | total_reward = loss = loss_denominator = Nx.tensor(0, type: :f32) 228 | 229 | state = adapt_exploration(episode, state) 230 | 231 | {%{ 232 | state 233 | | total_reward: total_reward, 234 | loss: loss, 235 | loss_denominator: loss_denominator 236 | }, random_key} 237 | end 238 | 239 | defnp adapt_exploration( 240 | episode, 241 | %__MODULE__{ 242 | exploration_decay_rate: exploration_decay_rate, 243 | exploration_increase_rate: exploration_increase_rate, 244 | epsilon_greedy_eps: eps, 245 | min_eps: min_eps, 246 | max_eps: max_eps, 247 | total_reward: reward, 248 | performance_memory: %Nx.Tensor{shape: {n}} = performance_memory, 249 | performance_threshold: performance_threshold 250 | } = state 251 | ) do 252 | {eps, performance_memory} = 253 | cond do 254 | episode == 0 -> 255 | {eps, performance_memory} 256 | 257 | episode < n -> 258 | index = Nx.remainder(episode, n) 259 | 260 | performance_memory = 261 | Nx.indexed_put( 262 | performance_memory, 263 | Nx.reshape(index, {1, 1}), 264 | Nx.reshape(reward, {1}) 265 | ) 266 | 267 | {eps, performance_memory} 268 | 269 | true -> 270 | index = Nx.remainder(episode, n) 271 | 272 | performance_memory = 273 | Nx.indexed_put(performance_memory, Nx.reshape(index, {1, 1}), Nx.reshape(reward, {1})) 274 | 275 | index = Nx.remainder(index + 1, n) 276 | 277 | # We want to get our 2 windows in sequence so that we can compare them. 278 | # The rem(iota + index + 1, n) operation will effectively set it so that 279 | # we have the oldest window starting at the first position, and then all elements 280 | # in the circular buffer fall into sequence 281 | window_indices = Nx.remainder(Nx.iota({n}) + index, n) 282 | 283 | # After we take and reshape, the first row contains the oldest `n//2` samples 284 | # and the second row, the remaining newest samples. 
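# For example (illustrative values only, not from a real run): with n = 6 and the
# latest reward written at slot 2, `index` has just been advanced to 3, so
# window_indices == [3, 4, 5, 0, 1, 2]. Nx.take/2 therefore returns the buffer in
# chronological order, and the {2, :auto} reshape below places the older half
# (slots 3, 4, 5) in row 0 and the newer half (slots 0, 1, 2) in row 1.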
285 | windows = 286 | performance_memory 287 | |> Nx.take(window_indices) 288 | |> Nx.reshape({2, :auto}) 289 | 290 | # avg[0]: avg of the previous performance window 291 | # avg[1]: avg of the current performance window 292 | avg = Nx.mean(windows, axes: [1]) 293 | 294 | abs_diff = Nx.abs(avg[0] - avg[1]) 295 | 296 | eps = 297 | if abs_diff < performance_threshold do 298 | # If decayed to less than an "eps" value, 299 | # we force it to increase from that "eps" instead. 300 | Nx.min(eps * exploration_increase_rate, max_eps) 301 | else 302 | # can decay to 0 303 | Nx.max(eps * exploration_decay_rate, min_eps) 304 | end 305 | 306 | {eps, performance_memory} 307 | end 308 | 309 | %__MODULE__{ 310 | state 311 | | epsilon_greedy_eps: eps, 312 | performance_memory: performance_memory 313 | } 314 | end 315 | 316 | @impl true 317 | defn select_action( 318 | %Rein{random_key: random_key, agent_state: agent_state} = state, 319 | _iteration 320 | ) do 321 | %{ 322 | q_policy: q_policy, 323 | policy_predict_fn: policy_predict_fn, 324 | environment_to_input_fn: environment_to_input_fn, 325 | num_actions: num_actions, 326 | epsilon_greedy_eps: eps_threshold 327 | } = agent_state 328 | 329 | {sample, random_key} = Nx.Random.uniform(random_key) 330 | 331 | {action, random_key} = 332 | if sample > eps_threshold do 333 | action = 334 | q_policy 335 | |> policy_predict_fn.(environment_to_input_fn.(state.environment_state)) 336 | |> Nx.argmax() 337 | 338 | {action, random_key} 339 | else 340 | Nx.Random.randint(random_key, 0, num_actions, type: :s64) 341 | end 342 | 343 | {action, %{state | random_key: random_key}} 344 | end 345 | 346 | @impl true 347 | defn record_observation( 348 | %{ 349 | environment_state: env_state, 350 | agent_state: %{ 351 | q_policy: q_policy, 352 | policy_predict_fn: policy_predict_fn, 353 | state_vector_to_input_fn: state_vector_to_input_fn, 354 | environment_to_state_vector_fn: as_state_vector_fn, 355 | gamma: gamma 356 | } 357 | }, 358 | action, 359 | reward, 360 | is_terminal, 361 | %{environment_state: next_env_state} = state 362 | ) do 363 | state_vector = as_state_vector_fn.(env_state) 364 | next_state_vector = as_state_vector_fn.(next_env_state) 365 | 366 | idx = Nx.stack([state.agent_state.experience_replay_buffer_index, 0]) |> Nx.new_axis(0) 367 | 368 | shape = {Nx.size(state_vector) + 4 + Nx.size(next_state_vector), 1} 369 | 370 | index_template = Nx.concatenate([Nx.broadcast(0, shape), Nx.iota(shape, axis: 0)], axis: 1) 371 | 372 | predicted_reward = 373 | reward + 374 | policy_predict_fn.(q_policy, state_vector_to_input_fn.(next_state_vector)) * gamma * 375 | (1 - is_terminal) 376 | 377 | %{shape: {1}} = predicted_reward = Nx.reduce_max(predicted_reward, axes: [-1]) 378 | 379 | temporal_difference = Nx.reshape(Nx.abs(reward - predicted_reward), {1}) 380 | 381 | updates = 382 | Nx.concatenate([ 383 | Nx.flatten(state_vector), 384 | Nx.stack([action, reward, is_terminal]), 385 | Nx.flatten(next_state_vector), 386 | temporal_difference 387 | ]) 388 | 389 | experience_replay_buffer = 390 | Nx.indexed_put(state.agent_state.experience_replay_buffer, idx + index_template, updates) 391 | 392 | experience_replay_buffer_index = 393 | Nx.remainder( 394 | state.agent_state.experience_replay_buffer_index + 1, 395 | @experience_replay_buffer_num_entries 396 | ) 397 | 398 | entries = state.agent_state.persisted_experience_replay_buffer_entries 399 | 400 | persisted_experience_replay_buffer_entries = 401 | Nx.select( 402 | entries < @experience_replay_buffer_num_entries, 403 | entries 
+ 1, 404 | entries 405 | ) 406 | 407 | %{ 408 | state 409 | | agent_state: %{ 410 | state.agent_state 411 | | experience_replay_buffer: experience_replay_buffer, 412 | experience_replay_buffer_index: experience_replay_buffer_index, 413 | persisted_experience_replay_buffer_entries: 414 | persisted_experience_replay_buffer_entries, 415 | total_reward: state.agent_state.total_reward + reward 416 | } 417 | } 418 | end 419 | 420 | @impl true 421 | defn optimize_model(state) do 422 | %{ 423 | persisted_experience_replay_buffer_entries: persisted_experience_replay_buffer_entries, 424 | experience_replay_buffer_index: experience_replay_buffer_index, 425 | batch_size: batch_size, 426 | training_frequency: training_frequency, 427 | target_training_frequency: target_training_frequency 428 | } = state.agent_state 429 | 430 | has_at_least_one_batch = persisted_experience_replay_buffer_entries > batch_size 431 | should_update_policy_net = rem(experience_replay_buffer_index, training_frequency) == 0 432 | should_update_target_net = rem(experience_replay_buffer_index, target_training_frequency) == 0 433 | 434 | {state, _, _, _} = 435 | while {state, i = 0, training_frequency, 436 | pred = has_at_least_one_batch and should_update_policy_net}, 437 | pred and i < training_frequency do 438 | {train(state), i + 1, training_frequency, pred} 439 | end 440 | 441 | {state, _, _, _} = 442 | while {state, i = 0, target_training_frequency, 443 | pred = has_at_least_one_batch and should_update_target_net}, 444 | pred and i < target_training_frequency do 445 | {soft_update_targets(state), i + 1, target_training_frequency, pred} 446 | end 447 | 448 | state 449 | end 450 | 451 | defnp train(state) do 452 | %{ 453 | agent_state: %{ 454 | q_policy: q_policy, 455 | q_target: q_target, 456 | q_policy_optimizer_state: q_policy_optimizer_state, 457 | policy_predict_fn: policy_predict_fn, 458 | optimizer_update_fn: optimizer_update_fn, 459 | state_vector_to_input_fn: state_vector_to_input_fn, 460 | state_vector_size: state_vector_size, 461 | experience_replay_buffer: experience_replay_buffer, 462 | gamma: gamma 463 | }, 464 | random_key: random_key 465 | } = state 466 | 467 | {batch, batch_idx, random_key} = 468 | sample_experience_replay_buffer(random_key, state.agent_state) 469 | 470 | state_batch = 471 | batch 472 | |> Nx.slice_along_axis(0, state_vector_size, axis: 1) 473 | |> then(state_vector_to_input_fn) 474 | 475 | action_batch = Nx.slice_along_axis(batch, state_vector_size, 1, axis: 1) 476 | reward_batch = Nx.slice_along_axis(batch, state_vector_size + 1, 1, axis: 1) 477 | is_terminal_batch = Nx.slice_along_axis(batch, state_vector_size + 2, 1, axis: 1) 478 | 479 | next_state_batch = 480 | batch 481 | |> Nx.slice_along_axis(state_vector_size + 3, state_vector_size, axis: 1) 482 | |> then(state_vector_to_input_fn) 483 | 484 | non_final_mask = not is_terminal_batch 485 | 486 | {{experience_replay_buffer, loss}, gradient} = 487 | value_and_grad( 488 | q_policy, 489 | fn q_policy -> 490 | action_idx = Nx.as_type(action_batch, :s64) 491 | 492 | %{shape: {m, 1}} = 493 | state_action_values = 494 | q_policy 495 | |> policy_predict_fn.(state_batch) 496 | |> Nx.take_along_axis(action_idx, axis: 1) 497 | 498 | expected_state_action_values = 499 | reward_batch + 500 | policy_predict_fn.(q_target, next_state_batch) * gamma * non_final_mask 501 | 502 | %{shape: {n, 1}} = 503 | expected_state_action_values = 504 | Nx.reduce_max(expected_state_action_values, axes: [-1], keep_axes: true) 505 | 506 | case {m, n} do 507 | {m, n} when m 
!= n -> 508 | raise "shape mismatch for batch values" 509 | 510 | _ -> 511 | 1 512 | end 513 | 514 | td_errors = Nx.abs(expected_state_action_values - state_action_values) 515 | 516 | { 517 | update_priorities( 518 | experience_replay_buffer, 519 | batch_idx, 520 | state_vector_size * 2 + 3, 521 | td_errors 522 | ), 523 | Axon.Losses.huber(expected_state_action_values, state_action_values, reduction: :mean) 524 | } 525 | end, 526 | &elem(&1, 1) 527 | ) 528 | 529 | {scaled_updates, optimizer_state} = 530 | optimizer_update_fn.(gradient, q_policy_optimizer_state, q_policy) 531 | 532 | q_policy = Polaris.Updates.apply_updates(q_policy, scaled_updates) 533 | 534 | %{ 535 | state 536 | | agent_state: %{ 537 | state.agent_state 538 | | q_policy: q_policy, 539 | q_policy_optimizer_state: optimizer_state, 540 | loss: state.agent_state.loss + loss, 541 | loss_denominator: state.agent_state.loss_denominator + 1, 542 | experience_replay_buffer: experience_replay_buffer 543 | }, 544 | random_key: random_key 545 | } 546 | end 547 | 548 | defnp soft_update_targets(state) do 549 | %{agent_state: %{q_target: q_target, q_policy: q_policy} = agent_state} = state 550 | 551 | q_target = Axon.Shared.deep_merge(q_policy, q_target, &(&1 * @tau + &2 * (1 - @tau))) 552 | 553 | %{state | agent_state: %{agent_state | q_target: q_target}} 554 | end 555 | 556 | @alpha 0.6 557 | defnp sample_experience_replay_buffer( 558 | random_key, 559 | %{state_vector_size: state_vector_size} = agent_state 560 | ) do 561 | %{shape: {@experience_replay_buffer_num_entries, _}} = 562 | exp_replay_buffer = slice_experience_replay_buffer(agent_state) 563 | 564 | # Temporal-difference prioritization: 565 | # each stored experience is sampled with probability proportional to its 566 | # TD error raised to @alpha, so transitions with larger error are 567 | # replayed more often while low-error ones remain reachable. 568 | # The temporal difference is already stored at the end of each buffer row.
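# For illustration (hypothetical numbers, not from a real run): TD errors of
# [0.1, 1.0, 2.0] yield priorities of about [0.25, 1.0, 1.52] for @alpha = 0.6,
# i.e. sampling probabilities of roughly [0.09, 0.36, 0.55] in the weighted
# Nx.Random.choice call below.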
569 | 570 | temporal_difference = 571 | exp_replay_buffer 572 | |> Nx.slice_along_axis(state_vector_size * 2 + 3, 1, axis: 1) 573 | |> Nx.flatten() 574 | 575 | priorities = temporal_difference ** @alpha 576 | probs = priorities / Nx.sum(priorities) 577 | 578 | {batch_idx, random_key} = 579 | Nx.Random.choice(random_key, Nx.iota(temporal_difference.shape), probs, 580 | samples: @batch_size, 581 | replace: false, 582 | axis: 0 583 | ) 584 | 585 | batch = Nx.take(exp_replay_buffer, batch_idx) 586 | {batch, batch_idx, random_key} 587 | end 588 | 589 | defnp slice_experience_replay_buffer(state) do 590 | %{ 591 | experience_replay_buffer: experience_replay_buffer, 592 | persisted_experience_replay_buffer_entries: entries 593 | } = state 594 | 595 | if entries < @experience_replay_buffer_num_entries do 596 | t = Nx.iota({@experience_replay_buffer_num_entries}) 597 | idx = Nx.select(t < entries, t, 0) 598 | Nx.take(experience_replay_buffer, idx) 599 | else 600 | experience_replay_buffer 601 | end 602 | end 603 | 604 | defn update_priorities( 605 | buffer, 606 | %{shape: {n}} = row_idx, 607 | target_column, 608 | td_errors 609 | ) do 610 | case td_errors.shape do 611 | {^n, 1} -> :ok 612 | shape -> raise "invalid shape for td_errors, got: #{inspect(shape)}" 613 | end 614 | 615 | indices = Nx.stack([row_idx, Nx.broadcast(target_column, {n})], axis: -1) 616 | 617 | Nx.indexed_put(buffer, indices, Nx.reshape(td_errors, {n})) 618 | end 619 | end 620 | -------------------------------------------------------------------------------- /lib/rein/agents/ddpg.ex: -------------------------------------------------------------------------------- 1 | defmodule Rein.Agents.DDPG do 2 | @moduledoc """ 3 | Deep Deterministic Policy Gradient implementation. 4 | 5 | This assumes that the Actor network will output `{nil, num_actions}` actions, 6 | and that the Critic network accepts the `"actions"` input with the same shape. 7 | 8 | Actions are deemed to be in a continuous space of type `:f32`. 
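A minimal actor/critic pair satisfying these assumptions could be sketched as
below. The layer sizes, the `"state"` input name, and the two-action output are
illustrative placeholders, not requirements of this module:

    state = Axon.input("state", shape: {nil, 3})

    actor_net =
      state
      |> Axon.dense(64, activation: :relu)
      |> Axon.dense(2, activation: :tanh)

    actions = Axon.input("actions", shape: {nil, 2})

    critic_net =
      state
      |> Axon.concatenate(actions)
      |> Axon.dense(64, activation: :relu)
      |> Axon.dense(1)

The `:tanh` output keeps the raw actions within the default
`:action_lower_limit` and `:action_upper_limit` of -1.0 and 1.0.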
9 | """ 10 | import Nx.Defn 11 | 12 | alias Rein.Utils.Noise.OUProcess 13 | alias Rein.Utils.CircularBuffer 14 | 15 | @behaviour Rein.Agent 16 | 17 | @derive {Nx.Container, 18 | containers: [ 19 | :actor_params, 20 | :actor_target_params, 21 | :critic_params, 22 | :critic_target_params, 23 | :experience_replay_buffer, 24 | :target_update_frequency, 25 | :loss, 26 | :loss_denominator, 27 | :total_reward, 28 | :actor_optimizer_state, 29 | :critic_optimizer_state, 30 | :action_lower_limit, 31 | :action_upper_limit, 32 | :ou_process, 33 | :max_sigma, 34 | :min_sigma, 35 | :exploration_decay_rate, 36 | :exploration_increase_rate, 37 | :performance_memory, 38 | :performance_threshold, 39 | :gamma, 40 | :tau, 41 | :state_features_memory 42 | ], 43 | keep: [ 44 | :exploration_fn, 45 | :environment_to_state_features_fn, 46 | :actor_predict_fn, 47 | :critic_predict_fn, 48 | :num_actions, 49 | :actor_optimizer_update_fn, 50 | :critic_optimizer_update_fn, 51 | :batch_size, 52 | :training_frequency, 53 | :input_entry_size 54 | ]} 55 | 56 | defstruct [ 57 | :num_actions, 58 | :actor_params, 59 | :actor_target_params, 60 | :actor_net, 61 | :critic_params, 62 | :critic_target_params, 63 | :critic_net, 64 | :actor_predict_fn, 65 | :critic_predict_fn, 66 | :experience_replay_buffer, 67 | :environment_to_state_features_fn, 68 | :gamma, 69 | :tau, 70 | :batch_size, 71 | :training_frequency, 72 | :target_update_frequency, 73 | :actor_optimizer_state, 74 | :critic_optimizer_state, 75 | :action_lower_limit, 76 | :action_upper_limit, 77 | :loss, 78 | :loss_denominator, 79 | :total_reward, 80 | :actor_optimizer_update_fn, 81 | :critic_optimizer_update_fn, 82 | :ou_process, 83 | :max_sigma, 84 | :min_sigma, 85 | :exploration_decay_rate, 86 | :exploration_increase_rate, 87 | :performance_memory, 88 | :performance_threshold, 89 | :exploration_fn, 90 | :state_features_memory, 91 | :input_entry_size 92 | ] 93 | 94 | @impl true 95 | def init(random_key, opts \\ []) do 96 | expected_opts = [ 97 | :actor_params, 98 | :actor_target_params, 99 | :actor_net, 100 | :critic_params, 101 | :critic_target_params, 102 | :critic_net, 103 | :experience_replay_buffer, 104 | :environment_to_state_features_fn, 105 | :performance_memory, 106 | :state_features_memory_to_input_fn, 107 | :state_features_memory, 108 | :state_features_size, 109 | :actor_optimizer, 110 | :critic_optimizer, 111 | ou_process_opts: [max_sigma: 0.2, min_sigma: 0.001, sigma: 0.01], 112 | performance_memory_length: 500, 113 | state_features_memory_length: 1, 114 | exploration_decay_rate: 0.9995, 115 | exploration_increase_rate: 1.1, 116 | performance_threshold: 0.01, 117 | exploration_fn: &Nx.less(&1, 500), 118 | gamma: 0.99, 119 | experience_replay_buffer_max_size: 100_000, 120 | tau: 0.005, 121 | batch_size: 64, 122 | training_frequency: 32, 123 | target_update_frequency: 100, 124 | action_lower_limit: -1.0, 125 | action_upper_limit: 1.0 126 | ] 127 | 128 | opts = Keyword.validate!(opts, expected_opts) 129 | 130 | # TO-DO: use NimbleOptions 131 | expected_opts 132 | |> Enum.filter(fn x -> is_atom(x) or (is_tuple(x) and is_nil(elem(x, 1))) end) 133 | |> Enum.reject(fn k -> 134 | k in [:state_features_memory, :performance_memory, :experience_replay_buffer] 135 | end) 136 | |> Enum.reduce(opts, fn 137 | k, opts -> 138 | case List.keytake(opts, k, 0) do 139 | {{^k, _}, opts} -> opts 140 | nil -> raise ArgumentError, "missing option #{k}" 141 | end 142 | end) 143 | |> Enum.each(fn {k, v} -> 144 | if is_nil(v) do 145 | raise ArgumentError, "option #{k} cannot be nil" 
146 | end 147 | end) 148 | 149 | {actor_optimizer_init_fn, actor_optimizer_update_fn} = opts[:actor_optimizer] 150 | {critic_optimizer_init_fn, critic_optimizer_update_fn} = opts[:critic_optimizer] 151 | 152 | actor_net = opts[:actor_net] 153 | critic_net = opts[:critic_net] 154 | 155 | environment_to_state_features_fn = opts[:environment_to_state_features_fn] 156 | state_features_memory_to_input_fn = opts[:state_features_memory_to_input_fn] 157 | 158 | {actor_init_fn, actor_predict_fn} = Axon.build(actor_net, seed: 0) 159 | {critic_init_fn, critic_predict_fn} = Axon.build(critic_net, seed: 1) 160 | 161 | actor_predict_fn = fn params, state_features_memory -> 162 | actor_predict_fn.(params, state_features_memory_to_input_fn.(state_features_memory)) 163 | end 164 | 165 | critic_predict_fn = fn params, state_features_memory, action_vector -> 166 | input = 167 | state_features_memory 168 | |> state_features_memory_to_input_fn.() 169 | |> Map.put("actions", action_vector) 170 | 171 | critic_predict_fn.(params, input) 172 | end 173 | 174 | initial_actor_params_state = opts[:actor_params] 175 | initial_actor_target_params_state = opts[:actor_target_params] || initial_actor_params_state 176 | initial_critic_params_state = opts[:critic_params] 177 | 178 | initial_critic_target_params_state = 179 | opts[:critic_target_params] || initial_critic_params_state 180 | 181 | input_template = input_template(actor_net) 182 | 183 | case input_template do 184 | %{"actions" => _} -> 185 | raise ArgumentError, 186 | "the input template for the actor_network must not contain the reserved key \"actions\"" 187 | 188 | _ -> 189 | :ok 190 | end 191 | 192 | {1, num_actions} = Axon.get_output_shape(actor_net, input_template) 193 | 194 | {max_sigma, ou_process_opts} = Keyword.pop!(opts[:ou_process_opts], :max_sigma) 195 | {min_sigma, ou_process_opts} = Keyword.pop!(ou_process_opts, :min_sigma) 196 | 197 | unless max_sigma do 198 | raise ArgumentError, "option [:ou_process_opts][:max_sigma] cannot be nil" 199 | end 200 | 201 | unless min_sigma do 202 | raise ArgumentError, "option [:ou_process_opts][:min_sigma] cannot be nil" 203 | end 204 | 205 | ou_process = OUProcess.init({1, num_actions}, ou_process_opts) 206 | 207 | critic_template = input_template(critic_net) 208 | 209 | case critic_template do 210 | %{"actions" => action_input} -> 211 | unless action_input != Nx.template({nil, num_actions}, :f32) do 212 | raise ArgumentError, 213 | "the critic network must accept the \"actions\" input with shape {nil, #{num_actions}} and type :f32, got input template: #{critic_template}" 214 | end 215 | 216 | if Map.delete(critic_template, "actions") != input_template do 217 | raise ArgumentError, 218 | "the critic network must have the same input template as the actor network + the \"action\" input" 219 | end 220 | 221 | _ -> 222 | :ok 223 | end 224 | 225 | actor_params = actor_init_fn.(input_template, initial_actor_params_state) 226 | actor_target_params = actor_init_fn.(input_template, initial_actor_target_params_state) 227 | 228 | actor_optimizer_state = actor_optimizer_init_fn.(actor_params) 229 | 230 | critic_params = critic_init_fn.(critic_template, initial_critic_params_state) 231 | critic_target_params = critic_init_fn.(critic_template, initial_critic_target_params_state) 232 | 233 | critic_optimizer_state = critic_optimizer_init_fn.(critic_params) 234 | 235 | state_features_size = opts[:state_features_size] 236 | 237 | total_reward = loss = loss_denominator = Nx.tensor(0, type: :f32) 238 | 
experience_replay_buffer_max_size = opts[:experience_replay_buffer_max_size] 239 | state_features_memory_length = opts[:state_features_memory_length] 240 | input_entry_size = state_features_size * state_features_memory_length 241 | 242 | {exp_replay_buffer, random_key} = 243 | if buffer = opts[:experience_replay_buffer] do 244 | {buffer, random_key} 245 | else 246 | {random_data_1, random_key} = 247 | Nx.Random.normal(random_key, 0, 10, 248 | shape: {experience_replay_buffer_max_size, input_entry_size + num_actions} 249 | ) 250 | 251 | init_reward = Nx.broadcast(-1.0e-8, {experience_replay_buffer_max_size, 1}) 252 | 253 | {random_data_2, random_key} = 254 | Nx.Random.normal(random_key, 0, 10, 255 | shape: {experience_replay_buffer_max_size, state_features_size + 1} 256 | ) 257 | 258 | init_td_error = Nx.broadcast(1.0e-8, {experience_replay_buffer_max_size, 1}) 259 | 260 | data = 261 | [random_data_1, init_reward, random_data_2, init_td_error] 262 | |> Nx.concatenate(axis: 1) 263 | |> then(&Nx.revectorize(&1, [], target_shape: Tuple.insert_at(&1.shape, 0, :auto))) 264 | 265 | buffer = %CircularBuffer{ 266 | data: data[[0, .., ..]], 267 | index: 0, 268 | size: 0 269 | } 270 | 271 | {buffer, random_key} 272 | end 273 | 274 | state = %__MODULE__{ 275 | max_sigma: max_sigma, 276 | min_sigma: min_sigma, 277 | input_entry_size: input_entry_size, 278 | exploration_fn: opts[:exploration_fn], 279 | exploration_decay_rate: opts[:exploration_decay_rate], 280 | exploration_increase_rate: opts[:exploration_increase_rate], 281 | state_features_memory: 282 | opts[:state_features_memory] || 283 | CircularBuffer.new({state_features_memory_length, state_features_size}), 284 | num_actions: num_actions, 285 | actor_params: actor_params, 286 | actor_target_params: actor_target_params, 287 | actor_net: actor_net, 288 | critic_params: critic_params, 289 | critic_target_params: critic_target_params, 290 | critic_net: critic_net, 291 | actor_predict_fn: actor_predict_fn, 292 | critic_predict_fn: critic_predict_fn, 293 | performance_threshold: opts[:performance_threshold], 294 | performance_memory: 295 | opts[:performance_memory] || CircularBuffer.new({opts[:performance_memory_length]}), 296 | experience_replay_buffer: exp_replay_buffer, 297 | environment_to_state_features_fn: environment_to_state_features_fn, 298 | gamma: opts[:gamma], 299 | tau: opts[:tau], 300 | batch_size: opts[:batch_size], 301 | ou_process: ou_process, 302 | training_frequency: opts[:training_frequency], 303 | target_update_frequency: opts[:target_update_frequency], 304 | total_reward: total_reward, 305 | loss: loss, 306 | loss_denominator: loss_denominator, 307 | actor_optimizer_update_fn: actor_optimizer_update_fn, 308 | critic_optimizer_update_fn: critic_optimizer_update_fn, 309 | actor_optimizer_state: actor_optimizer_state, 310 | critic_optimizer_state: critic_optimizer_state, 311 | action_lower_limit: opts[:action_lower_limit], 312 | action_upper_limit: opts[:action_upper_limit] 313 | } 314 | 315 | case random_key.vectorized_axes do 316 | [] -> 317 | {state, random_key} 318 | 319 | _ -> 320 | vectorizable_paths = [ 321 | [Access.key(:ou_process), Access.key(:theta)], 322 | [Access.key(:ou_process), Access.key(:sigma)], 323 | [Access.key(:ou_process), Access.key(:mu)], 324 | [Access.key(:ou_process), Access.key(:x)], 325 | [Access.key(:loss)], 326 | [Access.key(:loss_denominator)], 327 | [Access.key(:total_reward)], 328 | [Access.key(:performance_memory), Access.key(:data)], 329 | [Access.key(:performance_memory), Access.key(:index)], 330 
| [Access.key(:performance_memory), Access.key(:size)], 331 | [Access.key(:performance_threshold)], 332 | [Access.key(:state_features_memory), Access.key(:data)], 333 | [Access.key(:state_features_memory), Access.key(:index)], 334 | [Access.key(:state_features_memory), Access.key(:size)] 335 | ] 336 | 337 | vectorized_state = 338 | Enum.reduce(vectorizable_paths, state, fn path, state -> 339 | update_in(state, path, fn value -> 340 | [value, _] = Nx.broadcast_vectors([value, random_key], align_ranks: false) 341 | value 342 | end) 343 | end) 344 | 345 | {vectorized_state, random_key} 346 | end 347 | end 348 | 349 | defp input_template(model) do 350 | model 351 | |> Axon.get_inputs() 352 | |> Map.new(fn {name, shape} -> 353 | [nil | shape] = Tuple.to_list(shape) 354 | shape = List.to_tuple([1 | shape]) 355 | {name, Nx.template(shape, :f32)} 356 | end) 357 | end 358 | 359 | @impl true 360 | def reset(random_key, %Rein{ 361 | episode: episode, 362 | environment_state: env, 363 | agent_state: state 364 | }) do 365 | [zero, _] = Nx.broadcast_vectors([Nx.tensor(0, type: :f32), random_key], align_ranks: false) 366 | total_reward = loss = loss_denominator = zero 367 | 368 | state = adapt_exploration(episode, state) 369 | 370 | init_state_features = state.environment_to_state_features_fn.(env) 371 | 372 | {n, _} = state.state_features_memory.data.shape 373 | 374 | zero = Nx.as_type(zero, :s64) 375 | 376 | state_features_memory = %{ 377 | state.state_features_memory 378 | | data: Nx.tile(init_state_features, [n, 1]), 379 | index: zero, 380 | size: Nx.add(n, zero) 381 | } 382 | 383 | {%{ 384 | state 385 | | total_reward: total_reward, 386 | loss: loss, 387 | loss_denominator: loss_denominator, 388 | state_features_memory: state_features_memory 389 | }, random_key} 390 | end 391 | 392 | defnp adapt_exploration( 393 | episode, 394 | %__MODULE__{ 395 | # exploration_fn: exploration_fn, 396 | experience_replay_buffer: experience_replay_buffer, 397 | ou_process: ou_process, 398 | exploration_decay_rate: exploration_decay_rate, 399 | exploration_increase_rate: exploration_increase_rate, 400 | min_sigma: min_sigma, 401 | max_sigma: max_sigma, 402 | total_reward: reward, 403 | performance_memory: performance_memory, 404 | performance_threshold: performance_threshold 405 | } = state 406 | ) do 407 | n = Nx.axis_size(performance_memory.data, 0) 408 | 409 | {ou_process, performance_memory} = 410 | cond do 411 | episode == 0 -> 412 | {ou_process, performance_memory} 413 | 414 | episode < n or experience_replay_buffer.size < n -> 415 | {ou_process, CircularBuffer.append(performance_memory, reward)} 416 | 417 | true -> 418 | performance_memory = CircularBuffer.append(performance_memory, reward) 419 | 420 | # After we take and reshape, the first row contains the oldest `n//2` samples 421 | # and the second row, the remaining newest samples. 422 | windows = 423 | performance_memory 424 | |> CircularBuffer.ordered_data() 425 | |> Nx.reshape({2, :auto}) 426 | 427 | # avg[0]: avg of the previous performance window 428 | # avg[1]: avg of the current performance window 429 | avg = Nx.mean(windows, axes: [1]) 430 | 431 | abs_diff = Nx.abs(avg[0] - avg[1]) 432 | 433 | sigma = 434 | if abs_diff < performance_threshold do 435 | # If decayed to less than an "eps" value, 436 | # we force it to increase from that "eps" instead. 
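# Note that the exploration knob here is the OU process sigma rather than an
# epsilon: when performance plateaus (the two window averages differ by less
# than performance_threshold), sigma is grown toward max_sigma below; otherwise
# it is decayed toward min_sigma.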
437 | Nx.min(ou_process.sigma * exploration_increase_rate, max_sigma) 438 | else 439 | # can decay to 0 440 | Nx.max(ou_process.sigma * exploration_decay_rate, min_sigma) 441 | end 442 | 443 | {%OUProcess{ou_process | sigma: sigma}, performance_memory} 444 | end 445 | 446 | ou_process = %{ou_process | x: Nx.squeeze(ou_process.x)} 447 | 448 | %__MODULE__{ 449 | state 450 | | ou_process: OUProcess.reset(ou_process), 451 | performance_memory: performance_memory 452 | } 453 | end 454 | 455 | @impl true 456 | defn select_action( 457 | %Rein{random_key: random_key, agent_state: agent_state} = state, 458 | _iteration 459 | ) do 460 | %__MODULE__{ 461 | actor_params: actor_params, 462 | actor_predict_fn: actor_predict_fn, 463 | environment_to_state_features_fn: environment_to_state_features_fn, 464 | state_features_memory: state_features_memory, 465 | action_lower_limit: action_lower_limit, 466 | action_upper_limit: action_upper_limit, 467 | ou_process: ou_process 468 | } = agent_state 469 | 470 | state_features = environment_to_state_features_fn.(state.environment_state) 471 | 472 | state_features_memory = CircularBuffer.append(state_features_memory, state_features) 473 | 474 | action_vector = 475 | actor_predict_fn.(actor_params, CircularBuffer.ordered_data(state_features_memory)) 476 | 477 | {%OUProcess{x: additive_noise} = ou_process, random_key} = 478 | OUProcess.sample(random_key, ou_process) 479 | 480 | action_vector = action_vector + additive_noise 481 | 482 | clipped_action_vector = 483 | action_vector 484 | |> Nx.max(action_lower_limit) 485 | |> Nx.min(action_upper_limit) 486 | 487 | {clipped_action_vector, 488 | %{ 489 | state 490 | | agent_state: %{ 491 | agent_state 492 | | state_features_memory: state_features_memory, 493 | ou_process: ou_process 494 | }, 495 | random_key: random_key 496 | }} 497 | end 498 | 499 | @impl true 500 | defn record_observation( 501 | %{ 502 | agent_state: %__MODULE__{ 503 | actor_target_params: actor_target_params, 504 | actor_predict_fn: actor_predict_fn, 505 | critic_params: critic_params, 506 | critic_target_params: critic_target_params, 507 | critic_predict_fn: critic_predict_fn, 508 | state_features_memory: state_features_memory, 509 | environment_to_state_features_fn: environment_to_state_features_fn, 510 | experience_replay_buffer: experience_replay_buffer, 511 | gamma: gamma 512 | } 513 | }, 514 | action_vector, 515 | reward, 516 | is_terminal, 517 | %{environment_state: next_env_state} = state 518 | ) do 519 | next_state_features = environment_to_state_features_fn.(next_env_state) 520 | next_state_features_memory = CircularBuffer.append(state_features_memory, next_state_features) 521 | 522 | state_data = CircularBuffer.ordered_data(state_features_memory) 523 | next_state_data = CircularBuffer.ordered_data(next_state_features_memory) 524 | 525 | target_action_vector = actor_predict_fn.(actor_target_params, next_state_data) 526 | 527 | target_prediction = 528 | critic_predict_fn.(critic_target_params, next_state_data, target_action_vector) 529 | 530 | temporal_difference = 531 | reward + gamma * target_prediction * (1 - is_terminal) - 532 | critic_predict_fn.(critic_params, state_data, action_vector) 533 | 534 | temporal_difference = Nx.abs(temporal_difference) 535 | 536 | updates = 537 | Nx.concatenate([ 538 | Nx.flatten(state_data), 539 | Nx.flatten(action_vector), 540 | Nx.new_axis(reward, 0), 541 | Nx.new_axis(is_terminal, 0), 542 | Nx.flatten(next_state_features), 543 | Nx.reshape(temporal_difference, {1}) 544 | ]) 545 | 546 | updates = 547 | 
Nx.revectorize(updates, [], 548 | target_shape: {:auto, Nx.axis_size(experience_replay_buffer.data, -1)} 549 | ) 550 | 551 | experience_replay_buffer = CircularBuffer.append_multiple(experience_replay_buffer, updates) 552 | 553 | ensure_not_vectorized!(experience_replay_buffer.data) 554 | 555 | %{ 556 | state 557 | | agent_state: %{ 558 | state.agent_state 559 | | experience_replay_buffer: experience_replay_buffer, 560 | total_reward: state.agent_state.total_reward + reward 561 | } 562 | } 563 | end 564 | 565 | deftransformp ensure_not_vectorized!(t) do 566 | case t do 567 | %{vectorized_axes: []} -> 568 | :ok 569 | 570 | %{vectorized_axes: _vectorized_axes} -> 571 | raise "found unexpected vectorized axes" 572 | end 573 | end 574 | 575 | @impl true 576 | defn optimize_model(state) do 577 | %{ 578 | experience_replay_buffer: experience_replay_buffer, 579 | batch_size: batch_size, 580 | training_frequency: training_frequency, 581 | exploration_fn: exploration_fn 582 | } = state.agent_state 583 | 584 | exploring = state.episode |> Nx.devectorize() |> Nx.take(0) |> exploration_fn.() 585 | has_at_least_one_batch = experience_replay_buffer.size > batch_size 586 | 587 | should_train = 588 | has_at_least_one_batch and rem(experience_replay_buffer.index, training_frequency) == 0 589 | 590 | should_train = should_train |> Nx.devectorize() |> Nx.any() 591 | 592 | if should_train do 593 | train_loop(state, training_frequency, exploring) 594 | else 595 | state 596 | end 597 | end 598 | 599 | deftransformp train_loop(state, training_frequency, exploring) do 600 | if training_frequency == 1 do 601 | train_loop_step(state, exploring) 602 | else 603 | train_loop_while(state, training_frequency, exploring) 604 | end 605 | |> elem(0) 606 | end 607 | 608 | defnp train_loop_while(state, training_frequency, exploring) do 609 | while {state, exploring}, _ <- 0..(training_frequency - 1)//1, unroll: false do 610 | train_loop_step(state, exploring) 611 | end 612 | end 613 | 614 | defnp train_loop_step(state, exploring) do 615 | {batch, batch_indices, random_key} = 616 | sample_experience_replay_buffer(state.random_key, state.agent_state) 617 | 618 | train_actor = not exploring 619 | 620 | updated_state = 621 | %{state | random_key: random_key} 622 | |> train(batch, batch_indices, train_actor) 623 | |> soft_update_targets(train_actor) 624 | 625 | {updated_state, exploring} 626 | end 627 | 628 | defnp train(state, batch, batch_idx, train_actor) do 629 | %{ 630 | agent_state: %{ 631 | actor_params: actor_params, 632 | actor_target_params: actor_target_params, 633 | actor_predict_fn: actor_predict_fn, 634 | critic_params: critic_params, 635 | critic_target_params: critic_target_params, 636 | critic_predict_fn: critic_predict_fn, 637 | actor_optimizer_state: actor_optimizer_state, 638 | critic_optimizer_state: critic_optimizer_state, 639 | actor_optimizer_update_fn: actor_optimizer_update_fn, 640 | critic_optimizer_update_fn: critic_optimizer_update_fn, 641 | state_features_memory: state_features_memory, 642 | input_entry_size: input_entry_size, 643 | experience_replay_buffer: experience_replay_buffer, 644 | num_actions: num_actions, 645 | gamma: gamma 646 | } 647 | } = state 648 | 649 | batch_len = Nx.axis_size(batch, 0) 650 | {num_states, state_features_size} = state_features_memory.data.shape 651 | 652 | state_batch = 653 | batch 654 | |> Nx.slice_along_axis(0, input_entry_size, axis: 1) 655 | |> Nx.reshape({batch_len, num_states, state_features_size}) 656 | 657 | action_batch = Nx.slice_along_axis(batch, 
input_entry_size, num_actions, axis: 1) 658 | reward_batch = Nx.slice_along_axis(batch, input_entry_size + num_actions, 1, axis: 1) 659 | 660 | is_terminal_batch = Nx.slice_along_axis(batch, input_entry_size + num_actions + 1, 1, axis: 1) 661 | 662 | # we only persisted the new state, so we need to manipulate the `state_batch` to get the actual state 663 | next_state_batch = 664 | batch 665 | |> Nx.slice_along_axis(input_entry_size + num_actions + 2, state_features_size, axis: 1) 666 | |> Nx.reshape({batch_len, 1, state_features_size}) 667 | 668 | next_state_batch = 669 | if num_states == 1 do 670 | next_state_batch 671 | else 672 | next_state_batch = 673 | [ 674 | state_batch, 675 | next_state_batch 676 | ] 677 | |> Nx.concatenate(axis: 1) 678 | |> Nx.slice_along_axis(1, num_states, axis: 1) 679 | 680 | expected_shape = {batch_len, num_states, state_features_size} 681 | actual_shape = Nx.shape(next_state_batch) 682 | 683 | case {actual_shape, expected_shape} do 684 | {x, x} -> 685 | :ok 686 | 687 | {actual_shape, expected_shape} -> 688 | raise "incorrect size for next_state_batch, expected #{inspect(expected_shape)}, got: #{inspect(actual_shape)}" 689 | end 690 | 691 | next_state_batch 692 | end 693 | 694 | non_final_mask = not is_terminal_batch 695 | 696 | ### Train Critic 697 | 698 | {{experience_replay_buffer, critic_loss}, critic_gradient} = 699 | value_and_grad( 700 | critic_params, 701 | fn critic_params -> 702 | target_actions = actor_predict_fn.(actor_target_params, next_state_batch) 703 | 704 | q_target = 705 | critic_target_params 706 | |> critic_predict_fn.(next_state_batch, target_actions) 707 | |> stop_grad() 708 | 709 | %{shape: {n, 1}} = q = critic_predict_fn.(critic_params, state_batch, action_batch) 710 | 711 | %{shape: {m, 1}} = backup = reward_batch + gamma * non_final_mask * q_target 712 | 713 | case {m, n} do 714 | {m, n} when m != n -> 715 | raise "shape mismatch for batch values" 716 | 717 | _ -> 718 | 1 719 | end 720 | 721 | td_errors = Nx.abs(backup - q) 722 | 723 | { 724 | update_priorities( 725 | experience_replay_buffer, 726 | batch_idx, 727 | td_errors 728 | ), 729 | Nx.mean(td_errors ** 2) 730 | } 731 | end, 732 | &elem(&1, 1) 733 | ) 734 | 735 | {critic_updates, critic_optimizer_state} = 736 | critic_optimizer_update_fn.(critic_gradient, critic_optimizer_state, critic_params) 737 | 738 | critic_params = Polaris.Updates.apply_updates(critic_params, critic_updates) 739 | 740 | ### Train Actor 741 | 742 | # We train the actor 3x less than the critic to avoid 743 | # training onto a moving target 744 | 745 | {actor_params, actor_optimizer_state} = 746 | if train_actor do 747 | actor_gradient = 748 | grad(actor_params, fn actor_params -> 749 | actions = actor_predict_fn.(actor_params, state_batch) 750 | q = critic_predict_fn.(critic_params, state_batch, actions) 751 | -Nx.mean(q) 752 | end) 753 | 754 | {actor_updates, actor_optimizer_state} = 755 | actor_optimizer_update_fn.(actor_gradient, actor_optimizer_state, actor_params) 756 | 757 | actor_params = Polaris.Updates.apply_updates(actor_params, actor_updates) 758 | {actor_params, actor_optimizer_state} 759 | else 760 | {actor_params, actor_optimizer_state} 761 | end 762 | 763 | %{ 764 | state 765 | | agent_state: %{ 766 | state.agent_state 767 | | actor_params: actor_params, 768 | actor_optimizer_state: actor_optimizer_state, 769 | critic_params: critic_params, 770 | critic_optimizer_state: critic_optimizer_state, 771 | loss: state.agent_state.loss + critic_loss, 772 | loss_denominator: 
state.agent_state.loss_denominator + 1, 773 | experience_replay_buffer: experience_replay_buffer 774 | } 775 | } 776 | end 777 | 778 | defnp soft_update_targets(state, train_actor) do 779 | %{ 780 | agent_state: 781 | %{ 782 | actor_target_params: actor_target_params, 783 | actor_params: actor_params, 784 | critic_target_params: critic_target_params, 785 | critic_params: critic_params, 786 | tau: tau 787 | } = agent_state 788 | } = state 789 | 790 | actor_target_params = 791 | if train_actor do 792 | Axon.Shared.deep_merge( 793 | actor_params, 794 | actor_target_params, 795 | &Nx.as_type(&1 * tau + &2 * (1 - tau), Nx.type(&1)) 796 | ) 797 | else 798 | actor_target_params 799 | end 800 | 801 | critic_target_params = 802 | Axon.Shared.deep_merge( 803 | critic_params, 804 | critic_target_params, 805 | &Nx.as_type(&1 * tau + &2 * (1 - tau), Nx.type(&1)) 806 | ) 807 | 808 | %{ 809 | state 810 | | agent_state: %{ 811 | agent_state 812 | | actor_target_params: actor_target_params, 813 | critic_target_params: critic_target_params 814 | } 815 | } 816 | end 817 | 818 | @alpha 0.6 819 | defnp sample_experience_replay_buffer( 820 | random_key, 821 | %__MODULE__{ 822 | batch_size: batch_size 823 | } = agent_state 824 | ) do 825 | data = agent_state.experience_replay_buffer.data 826 | 827 | temporal_difference = 828 | data 829 | |> Nx.slice_along_axis(Nx.axis_size(data, 1) - 1, 1, axis: 1) 830 | |> Nx.flatten() 831 | 832 | priorities = temporal_difference ** @alpha 833 | probs = priorities / Nx.sum(priorities) 834 | 835 | split_key = Nx.Random.split(random_key) 836 | 837 | random_key = split_key[0] 838 | vec_k = split_key[1] 839 | 840 | k = Nx.devectorize(vec_k, keep_names: false) 841 | 842 | k = 843 | case Nx.shape(k) do 844 | {2} -> 845 | k 846 | 847 | {_, 2} -> 848 | Nx.take(k, 0) 849 | end 850 | 851 | {batch_idx, _} = 852 | Nx.Random.choice(k, Nx.iota(temporal_difference.shape), probs, 853 | samples: batch_size, 854 | replace: false, 855 | axis: 0 856 | ) 857 | 858 | batch = Nx.take(data, batch_idx) 859 | 860 | {batch, batch_idx, random_key} 861 | end 862 | 863 | defn update_priorities( 864 | %{data: %{shape: {_, item_size}}} = buffer, 865 | %{shape: {n}} = entry_indices, 866 | td_errors 867 | ) do 868 | case td_errors.shape do 869 | {^n, 1} -> 870 | :ok 871 | 872 | shape -> 873 | raise "invalid shape for td_errors, got: #{inspect(shape)}, expected: #{inspect({n, 1})}" 874 | end 875 | 876 | indices = Nx.stack([entry_indices, Nx.broadcast(item_size - 1, {n})], axis: -1) 877 | 878 | %{buffer | data: Nx.indexed_put(buffer.data, indices, Nx.reshape(td_errors, {n}))} 879 | end 880 | end 881 | -------------------------------------------------------------------------------- /lib/rein/agents/sac.ex: -------------------------------------------------------------------------------- 1 | defmodule Rein.Agents.SAC do 2 | @moduledoc """ 3 | Soft Actor-Critic implementation. 4 | 5 | This assumes that the Actor network will output `{nil, num_actions, 2}`, 6 | where for each action they output the $\\mu$ and $\\sigma$ values of a random 7 | normal distribution, and that the Critic network accepts `"actions"` input with 8 | shape `{nil, num_actions}`, where the action is calculated by sampling from 9 | said random distribution. 10 | 11 | Actions are deemed to be in a continuous space of type `:f32`. 12 | 13 | The Dual Q implementation utilizes two copies of the critic network, `critic1` and `critic2`, 14 | each with their own separate target network. 
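During optimization, both critics are regressed toward the shared target
$y = r + \\gamma (1 - d)\\left(\\min(Q'_1(s', a'), Q'_2(s', a')) - \\alpha \\log \\pi(a' | s')\\right)$,
where $a'$ is sampled from the target actor at $s'$, $d$ marks terminal
transitions, and $\\alpha$ is the entropy coefficient, which is learned when an
`:entropy_coefficient_optimizer` is provided.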
15 | 16 | Vectorized axes from `:random_key` are propagated normally throughout 17 | the agent state for parallel simulations, but all samples are stored in the same 18 | circular buffer. After all simulations have ran, the optimization steps are run 19 | on a sample space consisting of all previous experiences, including all of the 20 | parallel simulations that have just finished executing. 21 | """ 22 | import Nx.Defn 23 | 24 | import Nx.Constants, only: [pi: 1] 25 | 26 | alias Rein.Utils.CircularBuffer 27 | 28 | @behaviour Rein.Agent 29 | 30 | @derive {Nx.Container, 31 | containers: [ 32 | :actor_params, 33 | :actor_target_params, 34 | :critic1_params, 35 | :critic2_params, 36 | :critic1_target_params, 37 | :critic2_target_params, 38 | :experience_replay_buffer, 39 | :loss, 40 | :loss_denominator, 41 | :total_reward, 42 | :actor_optimizer_state, 43 | :critic1_optimizer_state, 44 | :critic2_optimizer_state, 45 | :action_lower_limit, 46 | :action_upper_limit, 47 | :gamma, 48 | :tau, 49 | :state_features_memory, 50 | :log_entropy_coefficient, 51 | :log_entropy_coefficient_optimizer_state, 52 | :target_entropy 53 | ], 54 | keep: [ 55 | :environment_to_state_features_fn, 56 | :actor_predict_fn, 57 | :critic_predict_fn, 58 | :num_actions, 59 | :actor_optimizer_update_fn, 60 | :critic_optimizer_update_fn, 61 | :batch_size, 62 | :training_frequency, 63 | :input_entry_size, 64 | :reward_scale, 65 | :log_entropy_coefficient_optimizer_update_fn, 66 | :train_log_entropy_coefficient 67 | ]} 68 | 69 | defstruct [ 70 | :num_actions, 71 | :actor_params, 72 | :actor_target_params, 73 | :critic1_params, 74 | :critic2_params, 75 | :critic1_target_params, 76 | :critic2_target_params, 77 | :actor_predict_fn, 78 | :critic_predict_fn, 79 | :experience_replay_buffer, 80 | :environment_to_state_features_fn, 81 | :gamma, 82 | :tau, 83 | :batch_size, 84 | :training_frequency, 85 | :actor_optimizer_state, 86 | :critic1_optimizer_state, 87 | :critic2_optimizer_state, 88 | :action_lower_limit, 89 | :action_upper_limit, 90 | :loss, 91 | :loss_denominator, 92 | :total_reward, 93 | :actor_optimizer_update_fn, 94 | :critic_optimizer_update_fn, 95 | :state_features_memory, 96 | :input_entry_size, 97 | :log_entropy_coefficient, 98 | :reward_scale, 99 | :train_log_entropy_coefficient, 100 | :log_entropy_coefficient_optimizer_update_fn, 101 | :log_entropy_coefficient_optimizer_state, 102 | :target_entropy 103 | ] 104 | 105 | @impl true 106 | def init(random_key, opts \\ []) do 107 | expected_opts = [ 108 | :actor_net, 109 | :critic_net, 110 | :environment_to_state_features_fn, 111 | :state_features_memory_to_input_fn, 112 | :state_features_size, 113 | :actor_optimizer, 114 | :critic_optimizer, 115 | :entropy_coefficient_optimizer, 116 | reward_scale: 1, 117 | state_features_memory_length: 1, 118 | gamma: 0.99, 119 | experience_replay_buffer_max_size: 100_000, 120 | tau: 0.005, 121 | batch_size: 64, 122 | training_frequency: 32, 123 | action_lower_limit: -1.0, 124 | action_upper_limit: 1.0, 125 | entropy_coefficient: 0.2, 126 | saved_state: %{} 127 | ] 128 | 129 | opts = Keyword.validate!(opts, expected_opts) 130 | 131 | # TO-DO: use NimbleOptions 132 | expected_opts 133 | |> Enum.filter(fn x -> is_atom(x) or (is_tuple(x) and is_nil(elem(x, 1))) end) 134 | |> Enum.reject(fn k -> 135 | k in [:state_features_memory, :experience_replay_buffer, :entropy_coefficient_optimizer] 136 | end) 137 | |> Enum.reduce(opts, fn 138 | k, opts -> 139 | case List.keytake(opts, k, 0) do 140 | {{^k, _}, opts} -> opts 141 | nil -> raise 
ArgumentError, "missing option #{k}" 142 | end 143 | end) 144 | |> Enum.each(fn {k, v} -> 145 | if is_nil(v) do 146 | raise ArgumentError, "option #{k} cannot be nil" 147 | end 148 | end) 149 | 150 | {actor_optimizer_init_fn, actor_optimizer_update_fn} = opts[:actor_optimizer] 151 | {critic_optimizer_init_fn, critic_optimizer_update_fn} = opts[:critic_optimizer] 152 | 153 | log_entropy_coefficient = :math.log(opts[:entropy_coefficient]) 154 | 155 | {train_log_entropy_coefficient, log_entropy_coefficient_optimizer_init_fn, 156 | log_entropy_coefficient_optimizer_update_fn} = 157 | case opts[:entropy_coefficient_optimizer] do 158 | {init, upd} -> {true, init, upd} 159 | _ -> {false, fn _ -> 0 end, 0} 160 | end 161 | 162 | actor_net = opts[:actor_net] 163 | critic_net = opts[:critic_net] 164 | 165 | environment_to_state_features_fn = opts[:environment_to_state_features_fn] 166 | state_features_memory_to_input_fn = opts[:state_features_memory_to_input_fn] 167 | 168 | {actor_init_fn, actor_predict_fn} = Axon.build(actor_net, seed: 0) 169 | {critic_init_fn, critic_predict_fn} = Axon.build(critic_net, seed: 1) 170 | 171 | actor_predict_fn = fn random_key, params, state_features_memory -> 172 | action_distribution_vector = 173 | actor_predict_fn.(params, state_features_memory_to_input_fn.(state_features_memory)) 174 | 175 | mu = action_distribution_vector[[.., .., 0]] 176 | log_stddev = action_distribution_vector[[.., .., 1]] 177 | 178 | stddev = Nx.exp(log_stddev) 179 | 180 | eps_shape = Nx.shape(stddev) 181 | 182 | # Nx.Random.normal is treated as a constant, so we obtain `eps` from a mean-0 stddev-1 183 | # normal distribution and scale it by our stddev below to obtain our sample in a way that 184 | # the grads a propagated through properly. 185 | {eps, random_key} = Nx.Random.normal(random_key, shape: eps_shape) 186 | 187 | pre_squash_action = Nx.add(mu, Nx.multiply(stddev, eps)) 188 | action = Nx.tanh(pre_squash_action) 189 | 190 | log_probability = action_log_probability(mu, stddev, log_stddev, pre_squash_action, action) 191 | 192 | {action, log_probability, random_key} 193 | end 194 | 195 | critic_predict_fn = fn params, state_features_memory, action_vector -> 196 | input = 197 | state_features_memory 198 | |> state_features_memory_to_input_fn.() 199 | |> Map.put("actions", action_vector) 200 | 201 | critic_predict_fn.(params, input) 202 | end 203 | 204 | input_template = input_template(actor_net) 205 | 206 | case input_template do 207 | %{"actions" => _} -> 208 | raise ArgumentError, 209 | "the input template for the actor_network must not contain the reserved key \"actions\"" 210 | 211 | _ -> 212 | :ok 213 | end 214 | 215 | {1, num_actions, 2} = Axon.get_output_shape(actor_net, input_template) 216 | 217 | critic_template = input_template(critic_net) 218 | 219 | case critic_template do 220 | %{"actions" => action_input} -> 221 | action_input = %{action_input | vectorized_axes: []} 222 | 223 | unless action_input != Nx.template({nil, num_actions}, :f32) do 224 | raise ArgumentError, 225 | "the critic network must accept the \"actions\" input with shape {nil, #{num_actions}} and type :f32, got input template: #{critic_template}" 226 | end 227 | 228 | critic_template = Map.delete(critic_template, "actions") 229 | 230 | if critic_template != input_template do 231 | raise ArgumentError, 232 | "the critic network must have the same input template as the actor network + the \"action\" input" 233 | end 234 | 235 | _ -> 236 | :ok 237 | end 238 | 239 | log_entropy_coefficient_optimizer_state = 240 | 
log_entropy_coefficient_optimizer_init_fn.(log_entropy_coefficient) 241 | 242 | actor_params = actor_init_fn.(input_template, %{}) 243 | actor_optimizer_state = actor_optimizer_init_fn.(actor_params) 244 | 245 | actor_target_params = actor_init_fn.(input_template, %{}) 246 | 247 | critic1_params = critic_init_fn.(critic_template, %{}) 248 | critic2_params = critic_init_fn.(critic_template, %{}) 249 | 250 | critic1_target_params = critic_init_fn.(critic_template, %{}) 251 | critic2_target_params = critic_init_fn.(critic_template, %{}) 252 | 253 | critic1_optimizer_state = critic_optimizer_init_fn.(critic1_target_params) 254 | critic2_optimizer_state = critic_optimizer_init_fn.(critic2_target_params) 255 | 256 | state_features_size = opts[:state_features_size] 257 | 258 | total_reward = loss = loss_denominator = Nx.tensor(0, type: :f32) 259 | experience_replay_buffer_max_size = opts[:experience_replay_buffer_max_size] 260 | state_features_memory_length = opts[:state_features_memory_length] 261 | input_entry_size = state_features_size * state_features_memory_length 262 | 263 | {exp_replay_buffer, random_key} = 264 | if buffer = opts[:experience_replay_buffer] do 265 | {buffer, random_key} 266 | else 267 | {random_data_1, random_key} = 268 | Nx.Random.normal(random_key, 0, 10, 269 | shape: {experience_replay_buffer_max_size, input_entry_size + num_actions} 270 | ) 271 | 272 | init_reward = Nx.broadcast(-1.0e-8, {experience_replay_buffer_max_size, 1}) 273 | 274 | {random_data_2, random_key} = 275 | Nx.Random.normal(random_key, 0, 10, 276 | shape: {experience_replay_buffer_max_size, state_features_size + 1} 277 | ) 278 | 279 | data = 280 | [random_data_1, init_reward, random_data_2] 281 | |> Nx.concatenate(axis: 1) 282 | |> then(&Nx.revectorize(&1, [], target_shape: Tuple.insert_at(&1.shape, 0, :auto))) 283 | 284 | buffer = %CircularBuffer{ 285 | data: data[[0, .., ..]], 286 | index: 0, 287 | size: 0 288 | } 289 | 290 | {buffer, random_key} 291 | end 292 | 293 | state = %__MODULE__{ 294 | input_entry_size: input_entry_size, 295 | log_entropy_coefficient_optimizer_state: log_entropy_coefficient_optimizer_state, 296 | log_entropy_coefficient_optimizer_update_fn: log_entropy_coefficient_optimizer_update_fn, 297 | target_entropy: -num_actions, 298 | train_log_entropy_coefficient: train_log_entropy_coefficient, 299 | state_features_memory: 300 | opts[:state_features_memory] || 301 | CircularBuffer.new({state_features_memory_length, state_features_size}), 302 | num_actions: num_actions, 303 | actor_params: actor_params, 304 | actor_target_params: actor_target_params, 305 | critic1_params: critic1_params, 306 | critic2_params: critic2_params, 307 | critic1_target_params: critic1_target_params, 308 | critic2_target_params: critic2_target_params, 309 | actor_predict_fn: actor_predict_fn, 310 | critic_predict_fn: critic_predict_fn, 311 | experience_replay_buffer: exp_replay_buffer, 312 | environment_to_state_features_fn: environment_to_state_features_fn, 313 | gamma: opts[:gamma], 314 | tau: opts[:tau], 315 | batch_size: opts[:batch_size], 316 | training_frequency: opts[:training_frequency], 317 | total_reward: total_reward, 318 | loss: loss, 319 | loss_denominator: loss_denominator, 320 | actor_optimizer_update_fn: actor_optimizer_update_fn, 321 | critic_optimizer_update_fn: critic_optimizer_update_fn, 322 | actor_optimizer_state: actor_optimizer_state, 323 | critic1_optimizer_state: critic1_optimizer_state, 324 | critic2_optimizer_state: critic2_optimizer_state, 325 | action_lower_limit: 
opts[:action_lower_limit], 326 | action_upper_limit: opts[:action_upper_limit], 327 | log_entropy_coefficient: Nx.log(opts[:entropy_coefficient]), 328 | reward_scale: opts[:reward_scale] 329 | } 330 | 331 | saved_state = 332 | (opts[:saved_state] || %{}) 333 | |> Map.take(Map.keys(%__MODULE__{}) -- [:__struct__]) 334 | |> Enum.filter(fn {_, v} -> v && not is_function(v) end) 335 | |> Map.new() 336 | 337 | state = Map.merge(state, saved_state) 338 | 339 | case random_key.vectorized_axes do 340 | [] -> 341 | {state, random_key} 342 | 343 | _ -> 344 | vectorizable_paths = [ 345 | [Access.key(:loss)], 346 | [Access.key(:loss_denominator)], 347 | [Access.key(:total_reward)], 348 | [Access.key(:state_features_memory), Access.key(:data)], 349 | [Access.key(:state_features_memory), Access.key(:index)], 350 | [Access.key(:state_features_memory), Access.key(:size)] 351 | ] 352 | 353 | vectorized_state = 354 | Enum.reduce(vectorizable_paths, state, fn path, state -> 355 | update_in(state, path, fn value -> 356 | [value, _] = Nx.broadcast_vectors([value, random_key], align_ranks: false) 357 | value 358 | end) 359 | end) 360 | 361 | {vectorized_state, random_key} 362 | end 363 | end 364 | 365 | defp input_template(model) do 366 | model 367 | |> Axon.get_inputs() 368 | |> Map.new(fn {name, shape} -> 369 | [nil | shape] = Tuple.to_list(shape) 370 | shape = List.to_tuple([1 | shape]) 371 | {name, Nx.template(shape, :f32)} 372 | end) 373 | end 374 | 375 | @impl true 376 | def reset(random_key, %Rein{ 377 | environment_state: env, 378 | agent_state: state 379 | }) do 380 | [zero, _] = Nx.broadcast_vectors([Nx.tensor(0, type: :f32), random_key], align_ranks: false) 381 | total_reward = loss = loss_denominator = zero 382 | 383 | init_state_features = state.environment_to_state_features_fn.(env) 384 | 385 | {n, _} = state.state_features_memory.data.shape 386 | 387 | zero = Nx.as_type(zero, :s64) 388 | 389 | state_features_memory = %{ 390 | state.state_features_memory 391 | | data: Nx.tile(init_state_features, [n, 1]), 392 | index: zero, 393 | size: Nx.add(n, zero) 394 | } 395 | 396 | {%{ 397 | state 398 | | total_reward: total_reward, 399 | loss: loss, 400 | loss_denominator: loss_denominator, 401 | state_features_memory: state_features_memory 402 | }, random_key} 403 | end 404 | 405 | @impl true 406 | defn select_action( 407 | %Rein{random_key: random_key, agent_state: agent_state} = state, 408 | _iteration 409 | ) do 410 | %__MODULE__{ 411 | actor_params: actor_params, 412 | actor_predict_fn: actor_predict_fn, 413 | environment_to_state_features_fn: environment_to_state_features_fn, 414 | state_features_memory: state_features_memory, 415 | action_lower_limit: action_lower_limit, 416 | action_upper_limit: action_upper_limit 417 | } = agent_state 418 | 419 | state_features = environment_to_state_features_fn.(state.environment_state) 420 | 421 | state_features_memory = CircularBuffer.append(state_features_memory, state_features) 422 | 423 | {action_vector, _logprob, random_key} = 424 | actor_predict_fn.( 425 | random_key, 426 | actor_params, 427 | CircularBuffer.ordered_data(state_features_memory) 428 | ) 429 | 430 | clipped_action_vector = 431 | action_vector 432 | |> Nx.max(action_lower_limit) 433 | |> Nx.min(action_upper_limit) 434 | 435 | {clipped_action_vector, 436 | %{ 437 | state 438 | | agent_state: %{ 439 | agent_state 440 | | state_features_memory: state_features_memory 441 | }, 442 | random_key: random_key 443 | }} 444 | end 445 | 446 | @impl true 447 | defn record_observation( 448 | %{ 449 | 
agent_state: %__MODULE__{ 450 | state_features_memory: state_features_memory, 451 | environment_to_state_features_fn: environment_to_state_features_fn, 452 | experience_replay_buffer: experience_replay_buffer, 453 | reward_scale: reward_scale 454 | } 455 | }, 456 | action_vector, 457 | reward, 458 | is_terminal, 459 | %{environment_state: next_env_state} = state 460 | ) do 461 | next_state_features = environment_to_state_features_fn.(next_env_state) 462 | state_data = CircularBuffer.ordered_data(state_features_memory) 463 | 464 | reward = reward * reward_scale 465 | 466 | updates = 467 | Nx.concatenate([ 468 | Nx.flatten(state_data), 469 | Nx.flatten(action_vector), 470 | Nx.new_axis(reward, 0), 471 | Nx.new_axis(is_terminal, 0), 472 | Nx.flatten(next_state_features) 473 | ]) 474 | 475 | updates = 476 | Nx.revectorize(updates, [], 477 | target_shape: {:auto, Nx.axis_size(experience_replay_buffer.data, -1)} 478 | ) 479 | 480 | experience_replay_buffer = CircularBuffer.append_multiple(experience_replay_buffer, updates) 481 | 482 | ensure_not_vectorized!(experience_replay_buffer.data) 483 | 484 | %{ 485 | state 486 | | agent_state: %{ 487 | state.agent_state 488 | | experience_replay_buffer: experience_replay_buffer, 489 | total_reward: state.agent_state.total_reward + reward 490 | } 491 | } 492 | end 493 | 494 | deftransformp ensure_not_vectorized!(t) do 495 | case t do 496 | %{vectorized_axes: []} -> 497 | :ok 498 | 499 | %{vectorized_axes: _vectorized_axes} -> 500 | raise "found unexpected vectorized axes" 501 | end 502 | end 503 | 504 | @impl true 505 | defn optimize_model(state) do 506 | %{ 507 | batch_size: batch_size, 508 | training_frequency: training_frequency 509 | } = state.agent_state 510 | 511 | # Run training after all simulations have ended. 512 | is_terminal = 513 | state.environment_state.is_terminal 514 | |> Nx.devectorize() 515 | |> Nx.all() 516 | 517 | if is_terminal and state.agent_state.experience_replay_buffer.size > batch_size do 518 | train_loop( 519 | state, 520 | training_frequency * vectorized_axes(state.environment_state.is_terminal) 521 | ) 522 | else 523 | state 524 | end 525 | end 526 | 527 | deftransformp vectorized_axes(t) do 528 | # flat_size is all entries, inclusing vectorized axes 529 | # size is just the non-vectorized part 530 | # So training frequency here is the number of vectorized axes, 531 | # i.e. 
we'll run one iteration per episode simulated 532 | div(Nx.flat_size(t), Nx.size(t)) 533 | end 534 | 535 | deftransformp train_loop(state, training_frequency) do 536 | if training_frequency == 1 do 537 | train_loop_step(state) 538 | else 539 | train_loop_while(state, training_frequency: training_frequency) 540 | end 541 | end 542 | 543 | defnp train_loop_while(state, opts \\ []) do 544 | training_frequency = opts[:training_frequency] 545 | 546 | while state, _ <- 0..(training_frequency - 1)//1, unroll: false do 547 | train_loop_step(state) 548 | end 549 | end 550 | 551 | defnp train_loop_step(state) do 552 | {batch, random_key} = sample_experience_replay_buffer(state.random_key, state.agent_state) 553 | 554 | %{state | random_key: random_key} 555 | |> train(batch) 556 | |> soft_update_targets() 557 | end 558 | 559 | defnp train(state, batch) do 560 | %{ 561 | agent_state: %__MODULE__{ 562 | actor_params: actor_params, 563 | actor_target_params: actor_target_params, 564 | actor_predict_fn: actor_predict_fn, 565 | critic1_params: critic1_params, 566 | critic2_params: critic2_params, 567 | critic1_target_params: critic1_target_params, 568 | critic2_target_params: critic2_target_params, 569 | critic_predict_fn: critic_predict_fn, 570 | actor_optimizer_state: actor_optimizer_state, 571 | critic1_optimizer_state: critic1_optimizer_state, 572 | critic2_optimizer_state: critic2_optimizer_state, 573 | actor_optimizer_update_fn: actor_optimizer_update_fn, 574 | critic_optimizer_update_fn: critic_optimizer_update_fn, 575 | state_features_memory: state_features_memory, 576 | input_entry_size: input_entry_size, 577 | experience_replay_buffer: experience_replay_buffer, 578 | num_actions: num_actions, 579 | gamma: gamma, 580 | log_entropy_coefficient: log_entropy_coefficient, 581 | log_entropy_coefficient_optimizer_update_fn: log_entropy_coefficient_optimizer_update_fn, 582 | log_entropy_coefficient_optimizer_state: log_entropy_coefficient_optimizer_state, 583 | target_entropy: target_entropy, 584 | train_log_entropy_coefficient: train_log_entropy_coefficient 585 | }, 586 | random_key: random_key 587 | } = state 588 | 589 | ks = Nx.Random.split(random_key) 590 | 591 | {random_key, k1} = 592 | case {Nx.flat_size(random_key), Nx.size(random_key)} do 593 | {s, s} -> 594 | {ks[0], ks[1]} 595 | 596 | _ -> 597 | random_key = ks[0] 598 | k1 = ks[1] |> Nx.devectorize() |> Nx.take(0) 599 | {random_key, k1} 600 | end 601 | 602 | batch_len = Nx.axis_size(batch, 0) 603 | {num_states, state_features_size} = state_features_memory.data.shape 604 | 605 | state_batch = 606 | batch 607 | |> Nx.slice_along_axis(0, input_entry_size, axis: 1) 608 | |> Nx.reshape({batch_len, num_states, state_features_size}) 609 | 610 | action_batch = Nx.slice_along_axis(batch, input_entry_size, num_actions, axis: 1) 611 | reward_batch = Nx.slice_along_axis(batch, input_entry_size + num_actions, 1, axis: 1) 612 | 613 | is_terminal_batch = Nx.slice_along_axis(batch, input_entry_size + num_actions + 1, 1, axis: 1) 614 | 615 | # we only persisted the new state, so we need to manipulate the `state_batch` to get the actual state 616 | # with state memory 617 | next_state_batch = 618 | batch 619 | |> Nx.slice_along_axis(input_entry_size + num_actions + 2, state_features_size, axis: 1) 620 | |> Nx.reshape({batch_len, 1, state_features_size}) 621 | 622 | next_state_batch = 623 | if num_states == 1 do 624 | next_state_batch 625 | else 626 | next_state_batch = 627 | [ 628 | state_batch, 629 | next_state_batch 630 | ] 631 | |> Nx.concatenate(axis: 1) 
632 | |> Nx.slice_along_axis(1, num_states, axis: 1) 633 | 634 | expected_shape = {batch_len, num_states, state_features_size} 635 | actual_shape = Nx.shape(next_state_batch) 636 | 637 | case {actual_shape, expected_shape} do 638 | {x, x} -> 639 | :ok 640 | 641 | {actual_shape, expected_shape} -> 642 | raise "incorrect size for next_state_batch, expected #{inspect(expected_shape)}, got: #{inspect(actual_shape)}" 643 | end 644 | 645 | next_state_batch 646 | end 647 | 648 | non_final_mask = not is_terminal_batch 649 | 650 | entropy_coefficient = stop_grad(Nx.exp(log_entropy_coefficient)) 651 | 652 | ### Train critic_params 653 | 654 | {{critic_loss, k1}, {critic1_gradient, critic2_gradient}} = 655 | value_and_grad( 656 | {critic1_params, critic2_params}, 657 | fn {critic1_params, critic2_params} -> 658 | # y_i = r_i + γ * min_{j=1,2} Q'(s_{i+1}, π(s_{i+1}|θ)|φ'_j) 659 | 660 | {target_actions, log_probability, k1} = 661 | actor_predict_fn.(k1, actor_target_params, next_state_batch) 662 | 663 | q1_target = critic_predict_fn.(critic1_target_params, next_state_batch, target_actions) 664 | q2_target = critic_predict_fn.(critic2_target_params, next_state_batch, target_actions) 665 | 666 | %{shape: {k, 1}} = q_target = stop_grad(Nx.min(q1_target, q2_target)) 667 | 668 | next_log_prob = 669 | log_probability 670 | |> Nx.devectorize() 671 | |> Nx.sum(axes: [0]) 672 | 673 | q_target = q_target - entropy_coefficient * next_log_prob 674 | 675 | %{shape: {m, 1}} = backup = reward_batch + gamma * non_final_mask * q_target 676 | 677 | # q values for each critic network 678 | %{shape: {n, 1}} = q1 = critic_predict_fn.(critic1_params, state_batch, action_batch) 679 | 680 | %{shape: {_n, 1}} = q2 = critic_predict_fn.(critic2_params, state_batch, action_batch) 681 | 682 | case {k, m, n} do 683 | {k, m, n} when m != n or m != k or n != k -> 684 | raise "shape mismatch for batch values" 685 | 686 | _ -> 687 | 1 688 | end 689 | 690 | backup = Nx.devectorize(backup) 691 | critic1_loss = Nx.mean((backup - Nx.new_axis(q1, 0)) ** 2) 692 | critic2_loss = Nx.mean((backup - Nx.new_axis(q2, 0)) ** 2) 693 | 694 | {0.5 * Nx.add(critic1_loss, critic2_loss), k1} 695 | end, 696 | &elem(&1, 0) 697 | ) 698 | 699 | {critic1_updates, critic1_optimizer_state} = 700 | critic_optimizer_update_fn.(critic1_gradient, critic1_optimizer_state, critic1_params) 701 | 702 | critic1_params = Polaris.Updates.apply_updates(critic1_params, critic1_updates) 703 | 704 | {critic2_updates, critic2_optimizer_state} = 705 | critic_optimizer_update_fn.(critic2_gradient, critic2_optimizer_state, critic2_params) 706 | 707 | critic2_params = Polaris.Updates.apply_updates(critic2_params, critic2_updates) 708 | 709 | ### Train Actor 710 | 711 | {{_, log_probs}, actor_gradient} = 712 | value_and_grad( 713 | actor_params, 714 | fn actor_params -> 715 | {actions, log_probs, _k1} = 716 | actor_predict_fn.(k1, actor_params, state_batch) 717 | 718 | q1 = critic_predict_fn.(critic1_params, state_batch, actions) 719 | 720 | q2 = critic_predict_fn.(critic2_params, state_batch, actions) 721 | 722 | q = Nx.min(q1, q2) 723 | 724 | {Nx.mean(entropy_coefficient * log_probs - q), log_probs} 725 | end, 726 | &elem(&1, 0) 727 | ) 728 | 729 | {actor_updates, actor_optimizer_state} = 730 | actor_optimizer_update_fn.(actor_gradient, actor_optimizer_state, actor_params) 731 | 732 | actor_params = Polaris.Updates.apply_updates(actor_params, actor_updates) 733 | 734 | ### Train entropy_coefficient 735 | 736 | {log_entropy_coefficient, log_entropy_coefficient_optimizer_state} = 
737 | case train_log_entropy_coefficient do 738 | false -> 739 | # entropy_coef is non-trainable 740 | {log_entropy_coefficient, log_entropy_coefficient_optimizer_state} 741 | 742 | true -> 743 | g = 744 | grad(log_entropy_coefficient, fn log_entropy_coefficient -> 745 | -Nx.mean(log_entropy_coefficient * (log_probs + target_entropy)) 746 | end) 747 | 748 | {updates, log_entropy_coefficient_optimizer_state} = 749 | log_entropy_coefficient_optimizer_update_fn.( 750 | g, 751 | log_entropy_coefficient_optimizer_state, 752 | log_entropy_coefficient 753 | ) 754 | 755 | log_entropy_coefficient = 756 | Polaris.Updates.apply_updates(log_entropy_coefficient, updates) 757 | 758 | {log_entropy_coefficient, log_entropy_coefficient_optimizer_state} 759 | end 760 | 761 | %{ 762 | state 763 | | agent_state: %{ 764 | state.agent_state 765 | | actor_params: actor_params, 766 | actor_optimizer_state: actor_optimizer_state, 767 | critic1_params: critic1_params, 768 | critic1_optimizer_state: critic1_optimizer_state, 769 | critic2_params: critic2_params, 770 | critic2_optimizer_state: critic2_optimizer_state, 771 | loss: state.agent_state.loss + critic_loss, 772 | loss_denominator: state.agent_state.loss_denominator + 1, 773 | experience_replay_buffer: experience_replay_buffer, 774 | log_entropy_coefficient: log_entropy_coefficient, 775 | log_entropy_coefficient_optimizer_state: log_entropy_coefficient_optimizer_state 776 | }, 777 | random_key: random_key 778 | } 779 | end 780 | 781 | defnp soft_update_targets(state) do 782 | %{ 783 | agent_state: 784 | %__MODULE__{ 785 | actor_target_params: actor_target_params, 786 | actor_params: actor_params, 787 | critic1_params: critic1_params, 788 | critic2_params: critic2_params, 789 | critic1_target_params: critic1_target_params, 790 | critic2_target_params: critic2_target_params, 791 | tau: tau 792 | } = agent_state 793 | } = state 794 | 795 | merge_fn = &Nx.as_type(&1 * tau + &2 * (1 - tau), Nx.type(&1)) 796 | 797 | actor_target_params = 798 | Axon.Shared.deep_merge(actor_params, actor_target_params, merge_fn) 799 | 800 | critic1_target_params = 801 | Axon.Shared.deep_merge(critic1_params, critic1_target_params, merge_fn) 802 | 803 | critic2_target_params = 804 | Axon.Shared.deep_merge(critic2_params, critic2_target_params, merge_fn) 805 | 806 | %{ 807 | state 808 | | agent_state: %{ 809 | agent_state 810 | | actor_target_params: actor_target_params, 811 | critic1_target_params: critic1_target_params, 812 | critic2_target_params: critic2_target_params 813 | } 814 | } 815 | end 816 | 817 | defnp sample_experience_replay_buffer( 818 | random_key, 819 | %__MODULE__{batch_size: batch_size} = agent_state 820 | ) do 821 | data = agent_state.experience_replay_buffer.data 822 | size = agent_state.experience_replay_buffer.size 823 | 824 | # split and devectorize random_key because we want to keep the replay buffer 825 | # and its samples devectorized at all times 826 | split_key = Nx.Random.split(random_key) 827 | 828 | random_key = split_key[0] 829 | vec_k = split_key[1] 830 | 831 | k = Nx.devectorize(vec_k, keep_names: false) 832 | 833 | k = 834 | case Nx.shape(k) do 835 | {2} -> 836 | k 837 | 838 | {_, 2} -> 839 | Nx.take(k, 0) 840 | end 841 | 842 | n = Nx.axis_size(data, 0) 843 | 844 | {batch, _} = 845 | if size < n do 846 | probabilities = (Nx.iota({n}) < size) / size 847 | Nx.Random.choice(k, data, probabilities, samples: batch_size, replace: false, axis: 0) 848 | else 849 | Nx.Random.choice(k, data, samples: batch_size, replace: false, axis: 0) 850 | end 851 | 852 | 
{stop_grad(batch), random_key} 853 | end 854 | 855 | defnp action_log_probability(mu, stddev, log_stddev, pre_squash_action, action) do 856 | # pre_squash_action is the sample before tanh squashing; the second term is the tanh change-of-variables correction 857 | type = Nx.type(action) 858 | eps = Nx.Constants.epsilon(type) 859 | 860 | log_prob(mu, stddev, log_stddev, pre_squash_action) - 861 | Nx.sum(Nx.log(1 - action ** 2 + eps), axes: [-1], keep_axes: true) 862 | end 863 | 864 | defnp log_prob(mu, stddev, log_stddev, x) do 865 | # small epsilon for the input type, to avoid division by zero when stddev is 0 866 | type = Nx.type(x) 867 | eps = Nx.Constants.epsilon(type) 868 | 869 | # formula for the log-probability density function of a Normal distribution 870 | z = (x - mu) / (stddev + eps) 871 | 872 | log_prob = -0.5 * z ** 2 - log_stddev - Nx.log(Nx.sqrt(2 * pi(type))) 873 | 874 | Nx.sum(log_prob, axes: [-1], keep_axes: true) 875 | end 876 | end 877 | --------------------------------------------------------------------------------
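The sketches below illustrate the numerical steps used by the SAC-style agent module above. They are not part of Rein; every module name, size, and value in them is invented for illustration. First, the replay buffer: record_observation/5 flattens each transition into a single row laid out as [state-feature window | action | reward | is_terminal | next state features], and train/2 recovers the pieces with Nx.slice_along_axis/4. A minimal standalone sketch of that layout, assuming 2 stacked state frames of 3 features and 1 action dimension:

# Sketch (not part of Rein): how a transition is flattened into one replay-buffer
# row and sliced back out. Sizes are invented for the example.
state_features_size = 3
num_states = 2
num_actions = 1
input_entry_size = num_states * state_features_size

state_window = Nx.iota({num_states, state_features_size}, type: :f32)
action = Nx.tensor([0.5])
reward = Nx.tensor(1.0)
is_terminal = Nx.tensor(0.0)
next_state_features = Nx.tensor([7.0, 8.0, 9.0])

# same concatenation order as record_observation/5
row =
  Nx.concatenate([
    Nx.flatten(state_window),
    Nx.flatten(action),
    Nx.new_axis(reward, 0),
    Nx.new_axis(is_terminal, 0),
    Nx.flatten(next_state_features)
  ])

# same slicing as train/2, here for a batch of a single row
batch = Nx.new_axis(row, 0)

state_batch =
  batch
  |> Nx.slice_along_axis(0, input_entry_size, axis: 1)
  |> Nx.reshape({1, num_states, state_features_size})

action_batch = Nx.slice_along_axis(batch, input_entry_size, num_actions, axis: 1)
reward_batch = Nx.slice_along_axis(batch, input_entry_size + num_actions, 1, axis: 1)
is_terminal_batch = Nx.slice_along_axis(batch, input_entry_size + num_actions + 1, 1, axis: 1)

next_state_batch =
  Nx.slice_along_axis(batch, input_entry_size + num_actions + 2, state_features_size, axis: 1)

Note that only the newest state features are stored for the next state, which is why next_state_batch has only state_features_size columns; train/2 rebuilds the full next-state window by shifting the stored window and appending them.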
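optimize_model/1 waits until every vectorized episode has terminated and then scales training_frequency by vectorized_axes/1, i.e. by the number of episodes simulated in parallel: div(Nx.flat_size(t), Nx.size(t)) is the product of the vectorized axis sizes. A minimal sketch, assuming an Nx version that provides Nx.vectorize/2:

# Sketch (not part of Rein): four parallel episodes as a vectorized axis. The flat
# size counts vectorized entries, the plain size does not, so the quotient is the
# episode count.
is_terminal = Nx.iota({4, 1}, type: :u8) |> Nx.vectorize(:episodes)

Nx.flat_size(is_terminal)
#=> 4

Nx.size(is_terminal)
#=> 1

div(Nx.flat_size(is_terminal), Nx.size(is_terminal))
#=> 4, so train_loop/2 runs 4 * training_frequency steps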
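The critic step inside train/2 regresses both critics onto the entropy-regularized (soft) Bellman target y = r + gamma * (1 - done) * (min(Q'1(s', a'), Q'2(s', a')) - alpha * log pi(a' | s')), where a' is sampled from the target actor and alpha is the exponentiated log_entropy_coefficient. A minimal sketch of that target and of the averaged twin-critic loss, with an invented module name and hypothetical default values for gamma and alpha:

defmodule SoftBellmanSketch do
  # Sketch (not part of Rein): the soft Bellman backup and twin-critic loss.
  import Nx.Defn

  defn backup(reward, non_final_mask, q1_target, q2_target, next_log_prob, opts \\ []) do
    opts = keyword!(opts, gamma: 0.99, alpha: 0.2)

    # clipped double-Q estimate, penalized by the policy entropy term
    q_target = Nx.min(q1_target, q2_target) - opts[:alpha] * next_log_prob

    reward + opts[:gamma] * non_final_mask * q_target
  end

  defn critic_loss(y, q1, q2) do
    # both critics are fit to the same target; the two MSE losses are averaged
    0.5 * (Nx.mean((y - q1) ** 2) + Nx.mean((y - q2) ** 2))
  end
end

For example, SoftBellmanSketch.backup(Nx.tensor([[1.0]]), Nx.tensor([[1.0]]), Nx.tensor([[3.0]]), Nx.tensor([[2.5]]), Nx.tensor([[-1.2]])) is 1.0 + 0.99 * (2.5 + 0.2 * 1.2), roughly 3.71.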
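soft_update_targets/1 nudges each target network towards its online counterpart with the Polyak rule target <- tau * online + (1 - tau) * target, applied leaf by leaf through Axon.Shared.deep_merge/3 (and cast back to the online tensor's type). A minimal sketch of the same rule on a flat map of tensors, with invented parameter names and a hypothetical tau:

defmodule PolyakSketch do
  # Sketch (not part of Rein): soft (Polyak) update of one target tensor.
  import Nx.Defn

  defn soft_update(online, target, opts \\ []) do
    opts = keyword!(opts, tau: 0.005)
    online * opts[:tau] + target * (1 - opts[:tau])
  end
end

online = %{"kernel" => Nx.tensor([1.0, 2.0]), "bias" => Nx.tensor([0.5])}
target = %{"kernel" => Nx.tensor([0.0, 0.0]), "bias" => Nx.tensor([0.0])}

# apply the rule to every leaf; soft_update_targets/1 above does this recursively
# over nested parameter maps via Axon.Shared.deep_merge/3
new_target =
  Map.new(online, fn {name, tensor} ->
    {name, PolyakSketch.soft_update(tensor, target[name])}
  end)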
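Finally, action_log_probability/5 and log_prob/4 implement the log-density of a tanh-squashed Gaussian policy: the Gaussian log-density of the pre-squash sample x, minus the change-of-variables correction sum(log(1 - tanh(x)^2 + eps)) for the tanh applied on top. The sketch below folds the two helpers into one function, computing log(stddev) directly instead of receiving log_stddev; the module name is invented:

defmodule SquashedGaussianSketch do
  # Sketch (not part of Rein): log-probability of a tanh-squashed Gaussian action.
  import Nx.Defn

  defn log_probability(mu, stddev, pre_squash_action) do
    type = Nx.type(pre_squash_action)
    eps = Nx.Constants.epsilon(type)

    action = Nx.tanh(pre_squash_action)

    # Gaussian log-density of the pre-squash sample, summed over the action axis
    z = (pre_squash_action - mu) / (stddev + eps)

    gaussian_log_prob =
      Nx.sum(-0.5 * z ** 2 - Nx.log(stddev + eps) - Nx.log(Nx.sqrt(2 * Nx.Constants.pi(type))),
        axes: [-1],
        keep_axes: true
      )

    # tanh change-of-variables correction
    gaussian_log_prob - Nx.sum(Nx.log(1 - action ** 2 + eps), axes: [-1], keep_axes: true)
  end
end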