├── .gitignore ├── LICENSE ├── Package.resolved ├── Package.swift ├── README.md └── Sources ├── CGLFW ├── module.modulemap └── shim.h ├── ReinforcementLearning ├── Agents │ ├── AgentUtilities.swift │ ├── Agents.swift │ ├── DeepQNetworks.swift │ └── PolicyGradientAgents.swift ├── Distributions │ ├── Bernoulli.swift │ ├── Categorical.swift │ ├── Deterministic.swift │ ├── Distribution.swift │ └── Uniform.swift ├── Environments │ ├── ClassicControl │ │ └── CartPole.swift │ ├── Environment.swift │ └── Metrics.swift ├── ReplayBuffers.swift ├── Spaces.swift ├── Utilities │ ├── General.swift │ ├── LearningRates.swift │ ├── Normalization.swift │ ├── Protocols.swift │ ├── Rendering.swift │ └── Rendering2.swift └── Values.swift └── ReinforcementLearningExperiments ├── CartPole.swift └── main.swift /.gitignore: -------------------------------------------------------------------------------- 1 | # MacOS 2 | .DS_Store 3 | 4 | # Swift Package Manager 5 | /.build 6 | /Packages 7 | /*.xcodeproj 8 | 9 | # VS Code 10 | /.vscode 11 | 12 | # CLion 13 | /.idea 14 | 15 | # Cloned retro repository 16 | /retro 17 | 18 | # Other Stuff 19 | temp 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2019, Emmanouil Antonios Platanios. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. 
For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright 2019, Emmanouil Antonios Platanios. 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 
204 | -------------------------------------------------------------------------------- /Package.resolved: -------------------------------------------------------------------------------- 1 | { 2 | "object": { 3 | "pins": [ 4 | { 5 | "package": "swift-log", 6 | "repositoryURL": "https://github.com/apple/swift-log.git", 7 | "state": { 8 | "branch": null, 9 | "revision": "f4240bf022a69815241a883c03645444b58ac553", 10 | "version": "1.1.0" 11 | } 12 | } 13 | ] 14 | }, 15 | "version": 1 16 | } 17 | -------------------------------------------------------------------------------- /Package.swift: -------------------------------------------------------------------------------- 1 | // swift-tools-version:5.0 2 | // The swift-tools-version declares the minimum version of Swift required to build this package. 3 | import Foundation 4 | import PackageDescription 5 | 6 | let package = Package( 7 | name: "ReinforcementLearning", 8 | platforms: [.macOS(.v10_13)], 9 | products: [ 10 | .library( 11 | name: "ReinforcementLearning", 12 | targets: ["ReinforcementLearning"]), 13 | .executable( 14 | name: "ReinforcementLearningExperiments", 15 | targets: ["ReinforcementLearningExperiments"]) 16 | ], 17 | dependencies: [ 18 | .package(url: "https://github.com/apple/swift-log.git", from: "1.0.0") 19 | ], 20 | targets: [ 21 | .systemLibrary( 22 | name: "CGLFW", 23 | path: "Sources/CGLFW", 24 | pkgConfig: "glfw3", 25 | providers: [ 26 | .brew(["--HEAD git glfw3"]), 27 | .apt(["libglfw3", "libglfw3-dev"]) 28 | ]), 29 | .target( 30 | name: "ReinforcementLearning", 31 | dependencies: ["CGLFW"], 32 | path: "Sources/ReinforcementLearning"), 33 | .target( 34 | name: "ReinforcementLearningExperiments", 35 | dependencies: ["Logging", "ReinforcementLearning"]) 36 | ] 37 | ) 38 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **WARNING:** This is currently work-in-progress. 2 | 3 | # Reinforcement Learning in Swift 4 | 5 | This repository contains a reinforcement learning library 6 | built using Swift for TensorFlow, that also encompasses the 7 | functionality of OpenAI Gym. The following is a list of 8 | currently supported features. 9 | 10 | - All algorithms and interfaces are designed and 11 | implemented with batching in mind to support efficient 12 | training of neural networks that often operate on batched 13 | inputs. 
14 | - [Environments](https://github.com/eaplatanios/retro-swift/blob/master/Sources/ReinforcementLearning/Environments/Environment.swift): 15 | - [Cart-Pole (classic control example)](https://github.com/eaplatanios/retro-swift/blob/master/Sources/ReinforcementLearning/Environments/ClassicControl/CartPole.swift) 16 | - [Atari Games (using the Arcade Learning Environment)](https://github.com/eaplatanios/swift-ale) 17 | - [Retro Games (atari, sega, etc., using Gym Retro)](https://github.com/eaplatanios/swift-retro) 18 | - [Agents](https://github.com/eaplatanios/retro-swift/blob/master/Sources/ReinforcementLearning/Agents/Agent.swift): 19 | - [Policy Gradient Algorithms](https://github.com/eaplatanios/retro-swift/blob/master/Sources/ReinforcementLearning/Agents/PolicyGradientAgents.swift): 20 | - REINFORCE 21 | - Advantage Actor Critic (A2C) 22 | - Proximal Policy Optimization (PPO) 23 | - *UPCOMING: Deep Deterministic Policy Gradients (DDPG)* 24 | - *UPCOMING: Twin Delayed Deep Deterministic Policy Gradients (TD3)* 25 | - *UPCOMING: Soft Actor Critic (SAC)* 26 | - Q-Learning Algorithms: 27 | - [Deep Q-Networks (DQN)](https://github.com/eaplatanios/retro-swift/blob/master/Sources/ReinforcementLearning/Agents/DeepQNetworks.swift) 28 | - *UPCOMING: Double Deep Q-Networks (DDQN)* 29 | - [Advantage Estimation Methods](https://github.com/eaplatanios/retro-swift/blob/master/Sources/ReinforcementLearning/Values.swift): 30 | - Empirical Advantage Estimation 31 | - Generalized Advantage Estimation (GAE) 32 | - [Replay Buffers](https://github.com/eaplatanios/retro-swift/blob/master/Sources/ReinforcementLearning/ReplayBuffers.swift): 33 | - Uniform Replay Buffer 34 | - *UPCOMING: Prioritized Replay Buffer* 35 | - [Visualization using OpenGL for all of the currently 36 | implemented environments.](https://github.com/eaplatanios/retro-swift/blob/master/Sources/ReinforcementLearning/Utilities/Rendering.swift) 37 | 38 | ## Installation 39 | 40 | ### Prerequisites 41 | 42 | #### GLFW 43 | 44 | GLFW is used for rendering. You can install it using: 45 | 46 | ```bash 47 | # For MacOS: 48 | brew install --HEAD git glfw3 49 | 50 | # For Linux: 51 | sudo apt install libglfw3-dev libglfw3 52 | ``` 53 | 54 | **NOTE:** The Swift Package Manager uses `pkg-config` to 55 | locate the installed libraries and so you need to make sure 56 | that `pkg-config` is configured correctly. That may require 57 | you to set the `PKG_CONFIG_PATH` environment variable 58 | correctly. 59 | 60 | **NOTE:** If the rendered image does not update according 61 | to the specified frames per second value and you are using 62 | MacOS 10.14, you should update to 10.14.4 because there is 63 | a bug in previous releases of 10.14 which breaks VSync. 64 | 65 | ## Reinforcement Learning Library Design Notes 66 | 67 | **WARNING:** The below is not relevant anymore. I have been 68 | working on a new simpler and more powerful interface and 69 | plan to update the examples shown in this file soon. 70 | 71 | ### Batching 72 | 73 | Batching can occur at two levels: 74 | 75 | - __Environment:__ 76 | - __Policy:__ 77 | 78 | For example, in the case of retro games, the environment 79 | can only operate on one action at a time (i.e., it is not 80 | batched). If we have a policy that is also not batched, 81 | then the process of collecting trajectories for training 82 | looks as follows: 83 | 84 | ``` 85 | ... → Policy → Environment → Policy → Environment → ...
86 | ``` 87 | 88 | In this diagram, the policy is invoked to produce the next 89 | action and then the environment is invoked to take a step 90 | using that action and return rewards, etc. If instead we 91 | are using a policy that can be batched (e.g., a 92 | convolutional neural network policy would be much more 93 | efficient if executed in a batched manner), then we can 94 | collect trajectories for training in the following manner: 95 | 96 | ``` 97 | ↗ Environment ↘ ↗ Environment ↘ 98 | ... ⇒ Policy ⇒ → Environment → ⇒ Policy ⇒ → Environment → ... 99 | ↘ Environment ↗ ↘ Environment ↗ 100 | ``` 101 | 102 | where multiple copies of the environment are running 103 | separately, producing rewards that are then batched and fed 104 | all together to a single batched policy. This policy then 105 | produces a batch of actions that is split up and each action 106 | is in turn fed to its corresponding environment. Similarly, 107 | we can have a batched environment being used together with 108 | an unbatched policy: 109 | 110 | ``` 111 | ↗ Policy ↘ ↗ Policy ↘ 112 | ... → Policy → ⇒ Environment ⇒ → Policy → ⇒ Environment ⇒ ... 113 | ↘ Policy ↗ ↘ Policy ↗ 114 | ``` 115 | 116 | or, even better, a batched environment used together with a 117 | batched policy: 118 | 119 | ``` 120 | ... ⇒ Policy ⇒ Environment ⇒ Policy ⇒ Environment ⇒ ... 121 | ``` 122 | 123 | **NOTE:** A batched policy is always usable as a 124 | policy (the batch conversions are handled automatically), 125 | and the same is true for batched environments. 126 | -------------------------------------------------------------------------------- /Sources/CGLFW/module.modulemap: -------------------------------------------------------------------------------- 1 | module CGLFW [system] { 2 | umbrella header "shim.h" 3 | link "glfw" 4 | link "GL" 5 | export * 6 | } 7 | -------------------------------------------------------------------------------- /Sources/CGLFW/shim.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #ifndef __APPLE__ 4 | #define GL_GLEXT_PROTOTYPES 5 | #include <GL/gl.h> 6 | #endif 7 | 8 | #include <GLFW/glfw3.h> 9 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Agents/AgentUtilities.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License.
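
// This file defines the input and output types shared by the agents in this package:
// `AgentInput` bundles an environment observation with the agent's (possibly recurrent) state,
// the `*Output` structs wrap network outputs such as Q-values, action distributions, and value
// estimates, and the `Stateless*` wrappers adapt stateless networks to this interface by
// returning an `Empty` state.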
14 | 15 | import TensorFlow 16 | 17 | public struct AgentInput: Differentiable, KeyPathIterable { 18 | @noDerivative public let observation: Observation 19 | public var state: State 20 | 21 | @inlinable 22 | @differentiable(wrt: state) 23 | public init(observation: Observation, state: State) { 24 | self.observation = observation 25 | self.state = state 26 | } 27 | } 28 | 29 | public struct QNetworkOutput: Differentiable { 30 | public var qValues: Tensor 31 | public var state: State 32 | 33 | @inlinable 34 | @differentiable 35 | public init(qValues: Tensor, state: State) { 36 | self.qValues = qValues 37 | self.state = state 38 | } 39 | } 40 | 41 | public struct ActorOutput< 42 | ActionDistribution: DifferentiableDistribution & KeyPathIterable, 43 | State: Differentiable & KeyPathIterable 44 | >: Differentiable, KeyPathIterable { 45 | public var actionDistribution: ActionDistribution 46 | public var state: State 47 | 48 | @inlinable 49 | @differentiable 50 | public init(actionDistribution: ActionDistribution, state: State) { 51 | self.actionDistribution = actionDistribution 52 | self.state = state 53 | } 54 | } 55 | 56 | public struct ActorCriticOutput< 57 | ActionDistribution: DifferentiableDistribution & KeyPathIterable, 58 | State: Differentiable & KeyPathIterable 59 | >: Differentiable, KeyPathIterable { 60 | public var actionDistribution: ActionDistribution 61 | public var value: Tensor 62 | public var state: State 63 | 64 | @inlinable 65 | @differentiable 66 | public init(actionDistribution: ActionDistribution, value: Tensor, state: State) { 67 | self.actionDistribution = actionDistribution 68 | self.value = value 69 | self.state = state 70 | } 71 | } 72 | 73 | public struct StatelessActorCriticOutput< 74 | ActionDistribution: DifferentiableDistribution & KeyPathIterable 75 | >: Differentiable, KeyPathIterable { 76 | public var actionDistribution: ActionDistribution 77 | public var value: Tensor 78 | 79 | @inlinable 80 | @differentiable 81 | public init(actionDistribution: ActionDistribution, value: Tensor) { 82 | self.actionDistribution = actionDistribution 83 | self.value = value 84 | } 85 | } 86 | 87 | public struct StatelessQNetwork< 88 | Environment: ReinforcementLearning.Environment, 89 | Network: Module & Copyable 90 | >: Module & Copyable where 91 | Network.Input == Environment.Observation, 92 | Network.Output == Tensor 93 | { 94 | public var statelessNetwork: Network 95 | 96 | @inlinable 97 | public init(_ statelessNetwork: Network) { 98 | self.statelessNetwork = statelessNetwork 99 | } 100 | 101 | @inlinable 102 | @differentiable 103 | public func callAsFunction( 104 | _ input: AgentInput 105 | ) -> QNetworkOutput { 106 | QNetworkOutput(qValues: statelessNetwork(input.observation), state: Empty()) 107 | } 108 | 109 | @inlinable 110 | public func copy() -> StatelessQNetwork { 111 | StatelessQNetwork(statelessNetwork.copy()) 112 | } 113 | } 114 | 115 | public struct StatelessActorNetwork< 116 | Environment: ReinforcementLearning.Environment, 117 | Network: Module 118 | >: Module where 119 | Environment.ActionSpace.ValueDistribution: DifferentiableDistribution, 120 | Network.Input == Environment.Observation, 121 | Network.Output == Environment.ActionSpace.ValueDistribution, 122 | Network.Output: KeyPathIterable 123 | { 124 | public var statelessNetwork: Network 125 | 126 | @inlinable 127 | public init(_ statelessNetwork: Network) { 128 | self.statelessNetwork = statelessNetwork 129 | } 130 | 131 | @inlinable 132 | @differentiable 133 | public func callAsFunction( 134 | _ input: 
AgentInput 135 | ) -> ActorOutput { 136 | ActorOutput(actionDistribution: statelessNetwork(input.observation), state: Empty()) 137 | } 138 | } 139 | 140 | public struct StatelessActorCriticNetwork< 141 | Environment: ReinforcementLearning.Environment, 142 | Network: Module 143 | >: Module where 144 | Environment.ActionSpace.ValueDistribution: DifferentiableDistribution, 145 | Network.Input == Environment.Observation, 146 | Network.Output == StatelessActorCriticOutput 147 | { 148 | public var statelessNetwork: Network 149 | 150 | @inlinable 151 | public init(_ statelessNetwork: Network) { 152 | self.statelessNetwork = statelessNetwork 153 | } 154 | 155 | @inlinable 156 | @differentiable 157 | public func callAsFunction( 158 | _ input: AgentInput 159 | ) -> ActorCriticOutput { 160 | let networkOutput = statelessNetwork(input.observation) 161 | return ActorCriticOutput( 162 | actionDistribution: networkOutput.actionDistribution, 163 | value: networkOutput.value, 164 | state: Empty()) 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Agents/Agents.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 14 | 15 | import TensorFlow 16 | 17 | public typealias StepCallback = 18 | (inout E, inout Trajectory) -> Void 19 | 20 | public protocol Agent { 21 | associatedtype Environment: ReinforcementLearning.Environment 22 | associatedtype State 23 | 24 | var actionSpace: Environment.ActionSpace { get } 25 | var state: State { get set } 26 | 27 | mutating func action(for step: Step) -> Action 28 | 29 | /// Updates this agent, effectively performing a single training step. 30 | /// 31 | /// - Parameter trajectory: Trajectory to use for the update. 32 | /// - Returns: Loss function value. 
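  ///
  /// A typical call (hypothetical `agent` and `trajectory` values, e.g. a trajectory collected
  /// inside a `StepCallback` while stepping an environment) looks like:
  /// ```swift
  /// let loss = agent.update(using: trajectory)
  /// ```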
33 | @discardableResult 34 | mutating func update(using trajectory: Trajectory) -> Float 35 | 36 | @discardableResult 37 | mutating func update( 38 | using environment: inout Environment, 39 | maxSteps: Int, 40 | maxEpisodes: Int, 41 | callbacks: [StepCallback] 42 | ) throws -> Float 43 | } 44 | 45 | extension Agent where State == Empty { 46 | public var state: State { 47 | get { Empty() } 48 | set {} 49 | } 50 | } 51 | 52 | extension Agent { 53 | public typealias Observation = Environment.Observation 54 | public typealias Action = Environment.Action 55 | public typealias Reward = Environment.Reward 56 | 57 | @inlinable 58 | public mutating func run( 59 | in environment: inout Environment, 60 | maxSteps: Int = Int.max, 61 | maxEpisodes: Int = Int.max, 62 | callbacks: [StepCallback] = [] 63 | ) throws { 64 | var currentStep = environment.currentStep 65 | var numSteps = 0 66 | var numEpisodes = 0 67 | while numSteps < maxSteps && numEpisodes < maxEpisodes { 68 | let state = self.state 69 | let action = self.action(for: currentStep) 70 | let nextStep = try environment.step(taking: action) 71 | var trajectory = Trajectory( 72 | stepKind: nextStep.kind, 73 | observation: currentStep.observation, 74 | state: state, 75 | action: action, 76 | reward: nextStep.reward) 77 | callbacks.forEach { $0(&environment, &trajectory) } 78 | numSteps += Int((1 - Tensor(nextStep.kind.isLast())).sum().scalarized()) 79 | numEpisodes += Int(Tensor(nextStep.kind.isLast()).sum().scalarized()) 80 | currentStep = nextStep 81 | } 82 | } 83 | } 84 | 85 | /// Trajectory generated by having an agent interact with an environment. 86 | /// 87 | /// Trajectories consist of five main components, each of which can be a nested structure of 88 | /// tensors with shapes whose first two dimensions are `[T, B]`, where `T` is the length of the 89 | /// trajectory in terms of time steps and `B` is the batch size. The five components are: 90 | /// - `stepKind`: Represents the kind of each time step (i.e., "first", "transition", or "last"). 91 | /// For example, if the agent takes an action in time step `t` that results in the current 92 | /// episode ending, then `stepKind[t]` will be "last" and `stepKind[t + 1]` will be "first". 93 | /// - `observation`: Observation that the agent receives from the environment in the beginning 94 | /// of each time step. 95 | /// - `state`: State of the agent at the beginning of each time step. 96 | /// - `action`: Action the agent took in each time step. 97 | /// - `reward`: Reward that the agent received from the environment after each action. The reward 98 | /// received after taking `action[t]` is `reward[t]`. 99 | public struct Trajectory: KeyPathIterable { 100 | // These need to be mutable because we use `KeyPathIterable.recursivelyAllWritableKeyPaths` to 101 | // automatically derive conformance to `Replayable`. 
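  // For example, a rollout over a batch of `B = 4` environments for `T = 10` time steps yields a
  // `stepKind` whose underlying tensor has shape `[10, 4]`, a reward tensor of shape `[10, 4]`
  // (when rewards are scalar `Tensor<Float>` values), and observations of shape `[10, 4, ...]`,
  // where `...` denotes the observation dimensions.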
102 | public var stepKind: StepKind 103 | public var observation: Observation 104 | public var state: State 105 | public var action: Action 106 | public var reward: Reward 107 | 108 | @inlinable 109 | public init( 110 | stepKind: StepKind, 111 | observation: Observation, 112 | state: State, 113 | action: Action, 114 | reward: Reward 115 | ) { 116 | self.stepKind = stepKind 117 | self.observation = observation 118 | self.state = state 119 | self.action = action 120 | self.reward = reward 121 | } 122 | } 123 | 124 | public struct AnyAgent: Agent { 125 | public typealias Observation = Environment.Observation 126 | public typealias Action = Environment.Action 127 | public typealias Reward = Environment.Reward 128 | 129 | @usableFromInline internal let _actionSpace: () -> Environment.ActionSpace 130 | @usableFromInline internal let _getState: () -> State 131 | @usableFromInline internal let _setState: (State) -> Void 132 | @usableFromInline internal let _action: (Step) -> Action 133 | 134 | @usableFromInline internal let _updateUsingTrajectory: ( 135 | Trajectory 136 | ) -> Float 137 | 138 | @usableFromInline internal let _updateUsingEnvironment: ( 139 | inout Environment, 140 | Int, 141 | Int, 142 | [StepCallback] 143 | ) throws -> Float 144 | 145 | public var actionSpace: Environment.ActionSpace { _actionSpace() } 146 | 147 | public var state: State { 148 | get { _getState() } 149 | set { _setState(newValue) } 150 | } 151 | 152 | @inlinable 153 | public init(_ agent: A) where A.Environment == Environment, A.State == State { 154 | var agent = agent 155 | _actionSpace = { () in agent.actionSpace } 156 | _getState = { () in agent.state } 157 | _setState = { agent.state = $0 } 158 | _action = { agent.action(for: $0) } 159 | _updateUsingTrajectory = { agent.update(using: $0) } 160 | _updateUsingEnvironment = { try agent.update( 161 | using: &$0, 162 | maxSteps: $1, 163 | maxEpisodes: $2, 164 | callbacks: $3) 165 | } 166 | } 167 | 168 | @inlinable 169 | public mutating func action(for step: Step) -> Action { 170 | _action(step) 171 | } 172 | 173 | @inlinable 174 | @discardableResult 175 | public mutating func update( 176 | using trajectory: Trajectory 177 | ) -> Float { 178 | _updateUsingTrajectory(trajectory) 179 | } 180 | 181 | @inlinable 182 | @discardableResult 183 | public mutating func update( 184 | using environment: inout Environment, 185 | maxSteps: Int = Int.max, 186 | maxEpisodes: Int = Int.max, 187 | callbacks: [StepCallback] = [] 188 | ) throws -> Float { 189 | try _updateUsingEnvironment(&environment, maxSteps, maxEpisodes, callbacks) 190 | } 191 | } 192 | 193 | // TODO: Support `boltzman(temperature:)`. 194 | public enum ProbabilisticAgentMode { 195 | case random 196 | case greedy 197 | case epsilonGreedy(_ epsilon: Float) 198 | case probabilistic 199 | } 200 | 201 | public protocol ProbabilisticAgent: Agent { 202 | associatedtype ActionDistribution: Distribution where ActionDistribution.Value == Action 203 | 204 | /// Generates the distribution over next actions given the current environment step. 205 | mutating func actionDistribution(for step: Step) -> ActionDistribution 206 | } 207 | 208 | extension ProbabilisticAgent { 209 | @inlinable 210 | public mutating func action(for step: Step) -> Action { 211 | action(for: step, mode: .greedy) 212 | } 213 | 214 | /// - Note: We cannot use a default argument value for `mode` here because of the `Agent` 215 | /// protocol requirement for an `Agent.action(for:)` function. 
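  ///
  /// For example, ε-greedy action selection with ε = 0.1 (hypothetical `agent` and `step`
  /// values) looks like:
  /// ```swift
  /// let action = agent.action(for: step, mode: .epsilonGreedy(0.1))
  /// ```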
216 | @inlinable 217 | public mutating func action( 218 | for step: Step, 219 | mode: ProbabilisticAgentMode 220 | ) -> Action { 221 | switch mode { 222 | case .random: 223 | return actionSpace.sample() 224 | case .greedy: 225 | return actionDistribution(for: step).mode() 226 | case let .epsilonGreedy(epsilon) where Float.random(in: 0..<1) < epsilon: 227 | return actionSpace.sample() 228 | case .epsilonGreedy(_): 229 | return actionDistribution(for: step).mode() 230 | case .probabilistic: 231 | return actionDistribution(for: step).sample() 232 | } 233 | } 234 | 235 | @inlinable 236 | public mutating func run( 237 | in environment: inout Environment, 238 | mode: ProbabilisticAgentMode = .greedy, 239 | maxSteps: Int = Int.max, 240 | maxEpisodes: Int = Int.max, 241 | callbacks: [StepCallback] = [] 242 | ) throws { 243 | var currentStep = environment.currentStep 244 | var numSteps = 0 245 | var numEpisodes = 0 246 | while numSteps < maxSteps && numEpisodes < maxEpisodes { 247 | let action = self.action(for: currentStep, mode: mode) 248 | let nextStep = try environment.step(taking: action) 249 | var trajectory = Trajectory( 250 | stepKind: nextStep.kind, 251 | observation: currentStep.observation, 252 | state: state, 253 | action: action, 254 | reward: nextStep.reward) 255 | callbacks.forEach { $0(&environment, &trajectory) } 256 | numSteps += Int((1 - Tensor(nextStep.kind.isLast())).sum().scalarized()) 257 | numEpisodes += Int(Tensor(nextStep.kind.isLast()).sum().scalarized()) 258 | currentStep = nextStep 259 | } 260 | } 261 | } 262 | 263 | public struct AnyProbabilisticAgent< 264 | Environment: ReinforcementLearning.Environment, 265 | ActionDistribution: Distribution, 266 | State 267 | >: ProbabilisticAgent where ActionDistribution.Value == Environment.Action { 268 | public typealias Observation = Environment.Observation 269 | public typealias Action = Environment.Action 270 | public typealias Reward = Environment.Reward 271 | 272 | @usableFromInline internal let _actionSpace: () -> Environment.ActionSpace 273 | @usableFromInline internal let _getState: () -> State 274 | @usableFromInline internal let _setState: (State) -> Void 275 | @usableFromInline internal let _action: (Step) -> Action 276 | 277 | @usableFromInline internal let _actionDistribution: ( 278 | Step 279 | ) -> ActionDistribution 280 | 281 | @usableFromInline internal let _updateUsingTrajectory: ( 282 | Trajectory 283 | ) -> Float 284 | 285 | @usableFromInline internal let _updateUsingEnvironment: ( 286 | inout Environment, 287 | Int, 288 | Int, 289 | [StepCallback] 290 | ) throws -> Float 291 | 292 | public var actionSpace: Environment.ActionSpace { _actionSpace() } 293 | 294 | public var state: State { 295 | get { _getState() } 296 | set { _setState(newValue) } 297 | } 298 | 299 | public init(_ agent: A) where 300 | A.Environment == Environment, 301 | A.ActionDistribution == ActionDistribution, 302 | A.State == State 303 | { 304 | var agent = agent 305 | _actionSpace = { () in agent.actionSpace } 306 | _getState = { () in agent.state } 307 | _setState = { agent.state = $0 } 308 | _action = { agent.action(for: $0) } 309 | _actionDistribution = { agent.actionDistribution(for: $0) } 310 | _updateUsingTrajectory = { agent.update(using: $0) } 311 | _updateUsingEnvironment = { try agent.update( 312 | using: &$0, 313 | maxSteps: $1, 314 | maxEpisodes: $2, 315 | callbacks: $3) 316 | } 317 | } 318 | 319 | @inlinable 320 | public mutating func action(for step: Step) -> Action { 321 | _action(step) 322 | } 323 | 324 | @inlinable 325 | 
public mutating func actionDistribution( 326 | for step: Step 327 | ) -> ActionDistribution { 328 | _actionDistribution(step) 329 | } 330 | 331 | @inlinable 332 | @discardableResult 333 | public mutating func update( 334 | using trajectory: Trajectory 335 | ) -> Float { 336 | _updateUsingTrajectory(trajectory) 337 | } 338 | 339 | @inlinable 340 | @discardableResult 341 | public mutating func update( 342 | using environment: inout Environment, 343 | maxSteps: Int = Int.max, 344 | maxEpisodes: Int = Int.max, 345 | callbacks: [StepCallback] = [] 346 | ) throws -> Float { 347 | try _updateUsingEnvironment(&environment, maxSteps, maxEpisodes, callbacks) 348 | } 349 | } 350 | 351 | public struct RandomAgent: ProbabilisticAgent { 352 | public typealias Observation = Environment.ObservationSpace.Value 353 | public typealias State = Empty 354 | public typealias Action = Environment.ActionSpace.Value 355 | public typealias ActionDistribution = Environment.ActionSpace.ValueDistribution 356 | public typealias Reward = Environment.Reward 357 | 358 | public let actionSpace: Environment.ActionSpace 359 | 360 | @inlinable 361 | public init(for environment: Environment) { 362 | actionSpace = environment.actionSpace 363 | } 364 | 365 | @inlinable 366 | public func actionDistribution(for step: Step) -> ActionDistribution { 367 | actionSpace.distribution 368 | } 369 | 370 | @inlinable 371 | @discardableResult 372 | public mutating func update( 373 | using trajectory: Trajectory 374 | ) -> Float { 375 | 0.0 376 | } 377 | 378 | @inlinable 379 | @discardableResult 380 | public mutating func update( 381 | using environment: inout Environment, 382 | maxSteps: Int, 383 | maxEpisodes: Int, 384 | callbacks: [StepCallback] 385 | ) -> Float { 386 | 0.0 387 | } 388 | } 389 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Agents/DeepQNetworks.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 14 | 15 | import TensorFlow 16 | 17 | // TODO: Fill the replay buffer with random data in the beginning of training. 18 | // TODO: Reward scaling / reward shaping. 19 | // TODO: Exploration schedules (i.e., how to vary ε while training). 20 | 21 | // We let Q-networks output distributions over actions, making them able to handle both discrete 22 | // and continuous action spaces. 
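//
// Sketch of the update performed by `DQNAgent.update(using:)` below: the TD target for each step
// is the observed reward plus `discountFactor` times the maximum target-network Q-value at the
// next observation, the TD error is passed through a Huber-style loss (quadratic up to
// `delta = 1`, linear beyond that), steps that end an episode are masked out, and the result is
// summed over time and averaged over the batch before a single optimizer step is taken. The
// target network is then moved toward the online Q-network, as controlled by
// `targetUpdateForgetFactor` and `targetUpdatePeriod`.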
23 | public struct DQNAgent< 24 | Environment: ReinforcementLearning.Environment, 25 | State: Differentiable, 26 | QNetwork: Module & Copyable, 27 | Optimizer: TensorFlow.Optimizer 28 | >: ProbabilisticAgent 29 | where 30 | Environment.ActionSpace.ValueDistribution == Categorical, 31 | Environment.Reward == Tensor, 32 | QNetwork.Input == AgentInput, 33 | QNetwork.Output == QNetworkOutput, 34 | Optimizer.Model == QNetwork 35 | { 36 | public typealias Observation = Environment.Observation 37 | public typealias Action = Environment.ActionSpace.Value 38 | public typealias ActionDistribution = Environment.ActionSpace.ValueDistribution 39 | public typealias Reward = Tensor 40 | 41 | public let actionSpace: Environment.ActionSpace 42 | public var state: State 43 | public var qNetwork: QNetwork 44 | public var targetQNetwork: QNetwork 45 | public var optimizer: Optimizer 46 | 47 | public let trainSequenceLength: Int 48 | public let maxReplayedSequenceLength: Int 49 | public let epsilonGreedy: Float 50 | public let targetUpdateForgetFactor: Float 51 | public let targetUpdatePeriod: Int 52 | public let discountFactor: Float 53 | public let trainStepsPerIteration: Int 54 | 55 | @usableFromInline internal var replayBuffer: UniformReplayBuffer< 56 | Trajectory>? 57 | @usableFromInline internal var trainingStep: Int = 0 58 | 59 | @inlinable 60 | public init( 61 | for environment: Environment, 62 | qNetwork: QNetwork, 63 | initialState: State, 64 | optimizer: (QNetwork) -> Optimizer, 65 | trainSequenceLength: Int, 66 | maxReplayedSequenceLength: Int, 67 | epsilonGreedy: Float = 0.1, 68 | targetUpdateForgetFactor: Float = 1.0, 69 | targetUpdatePeriod: Int = 1, 70 | discountFactor: Float = 0.99, 71 | trainStepsPerIteration: Int = 1 72 | ) { 73 | precondition( 74 | trainSequenceLength > 0, 75 | "The provided training sequence length must be greater than 0.") 76 | precondition( 77 | trainSequenceLength < maxReplayedSequenceLength, 78 | "The provided training sequence length is larger than the maximum replayed sequence length.") 79 | precondition( 80 | targetUpdateForgetFactor > 0.0 && targetUpdateForgetFactor <= 1.0, 81 | "The target update forget factor must be in the interval (0, 1].") 82 | self.actionSpace = environment.actionSpace 83 | self.state = initialState 84 | self.qNetwork = qNetwork 85 | self.targetQNetwork = qNetwork.copy() 86 | self.optimizer = optimizer(qNetwork) 87 | self.trainSequenceLength = trainSequenceLength 88 | self.maxReplayedSequenceLength = maxReplayedSequenceLength 89 | self.epsilonGreedy = epsilonGreedy 90 | self.targetUpdateForgetFactor = targetUpdateForgetFactor 91 | self.targetUpdatePeriod = targetUpdatePeriod 92 | self.discountFactor = discountFactor 93 | self.trainStepsPerIteration = trainStepsPerIteration 94 | self.replayBuffer = nil 95 | } 96 | 97 | @inlinable 98 | public mutating func actionDistribution( 99 | for step: Step 100 | ) -> ActionDistribution { 101 | let qNetworkOutput = qNetwork(AgentInput(observation: step.observation, state: state)) 102 | state = qNetworkOutput.state 103 | return Categorical(logits: qNetworkOutput.qValues) 104 | } 105 | 106 | @discardableResult 107 | public mutating func update( 108 | using trajectory: Trajectory 109 | ) -> Float { 110 | let (loss, gradient) = valueWithGradient(at: qNetwork) { qNetwork -> Tensor in 111 | let qNetworkOutput = qNetwork(AgentInput( 112 | observation: trajectory.observation, 113 | state: trajectory.state)) 114 | let qValue = qNetworkOutput.qValues.batchGathering( 115 | atIndices: trajectory.action, 116 | 
alongAxis: 2, 117 | batchDimensionCount: 2) 118 | 119 | // Split the trajectory such that the last step is only used to compute the next Q value. 120 | let sequenceLength = qValue.shape[0] - 1 121 | let currentStepKind = StepKind(trajectory.stepKind.rawValue[0..(1.0) 133 | let quadratic = min(error, delta) 134 | // The following expression is the same in value as `max(error - delta, 0)`, but 135 | // importantly the gradient for the expression when `error == delta` is `0` (for the form 136 | // using `max(_:_:)` it would be `1`). This is necessary to avoid doubling the gradient 137 | // because there is already a nonzero contribution to it from the quadratic term. 138 | var tdLoss = 0.5 * quadratic * quadratic + delta * (error - quadratic) 139 | 140 | // Mask the loss for all steps that mark the end of an episode. 141 | tdLoss = tdLoss * (1 - Tensor(currentStepKind.isLast())) 142 | 143 | // Finally, sum the loss over the time dimension and average across the batch dimension. 144 | // Note that we use an element-wise loss up to this point in order to ensure that each 145 | // element is always weighted by `1/B` where `B` is the batch size, even when some of the 146 | // steps have zero loss due to episode transitions. Weighting by `1/K` where `K` is the 147 | // actual number of non-zero loss weights (e.g., due to the mask) would artificially increase 148 | // the contribution of the non-masked loss elements. This would get increasingly worse as the 149 | // number of episode transitions increases. 150 | return tdLoss.sum(squeezingAxes: 0).mean() 151 | } 152 | optimizer.update(&qNetwork, along: gradient) 153 | updateTargetQNetwork() 154 | return loss.scalarized() 155 | } 156 | 157 | @inlinable 158 | @discardableResult 159 | public mutating func update( 160 | using environment: inout Environment, 161 | maxSteps: Int = Int.max, 162 | maxEpisodes: Int = Int.max, 163 | callbacks: [StepCallback] = [] 164 | ) throws -> Float { 165 | if replayBuffer == nil { 166 | replayBuffer = UniformReplayBuffer( 167 | batchSize: environment.batchSize, 168 | maxLength: maxReplayedSequenceLength) 169 | } 170 | var currentStep = environment.currentStep 171 | var numSteps = 0 172 | var numEpisodes = 0 173 | while numSteps < maxSteps && numEpisodes < maxEpisodes { 174 | let state = self.state 175 | let action = self.action(for: currentStep, mode: .epsilonGreedy(epsilonGreedy)) 176 | let nextStep = try environment.step(taking: action) 177 | var trajectory = Trajectory( 178 | stepKind: nextStep.kind, 179 | observation: currentStep.observation, 180 | state: state, 181 | action: action, 182 | reward: nextStep.reward) 183 | replayBuffer!.record(trajectory) 184 | callbacks.forEach { $0(&environment, &trajectory) } 185 | numSteps += Int((1 - Tensor(nextStep.kind.isLast())).sum().scalarized()) 186 | numEpisodes += Int(Tensor(nextStep.kind.isLast()).sum().scalarized()) 187 | currentStep = nextStep 188 | } 189 | var loss: Float = 0.0 190 | for _ in 0.. 
Tensor { 218 | targetQNetwork(AgentInput( 219 | observation: observation, 220 | state: state 221 | )).qValues.max(squeezingAxes: -1) 222 | } 223 | } 224 | 225 | extension DQNAgent where State == Empty { 226 | @inlinable 227 | public init( 228 | for environment: Environment, 229 | qNetwork: StatelessNetwork, 230 | optimizer: (StatelessQNetwork) -> Optimizer, 231 | trainSequenceLength: Int, 232 | maxReplayedSequenceLength: Int, 233 | epsilonGreedy: Float = 0.1, 234 | targetUpdateForgetFactor: Float = 1.0, 235 | targetUpdatePeriod: Int = 1, 236 | discountFactor: Float = 0.99, 237 | trainStepsPerIteration: Int = 1 238 | ) where 239 | QNetwork == StatelessQNetwork, 240 | StatelessNetwork.Input == Observation, 241 | StatelessNetwork.Output == Tensor 242 | { 243 | let qNetwork = StatelessQNetwork(qNetwork) 244 | self.init( 245 | for: environment, 246 | qNetwork: qNetwork, 247 | initialState: Empty(), 248 | optimizer: optimizer, 249 | trainSequenceLength: trainSequenceLength, 250 | maxReplayedSequenceLength: maxReplayedSequenceLength, 251 | epsilonGreedy: epsilonGreedy, 252 | targetUpdateForgetFactor: targetUpdateForgetFactor, 253 | targetUpdatePeriod: targetUpdatePeriod, 254 | discountFactor: discountFactor, 255 | trainStepsPerIteration: trainStepsPerIteration) 256 | } 257 | } 258 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Agents/PolicyGradientAgents.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 14 | 15 | import TensorFlow 16 | 17 | // TODO: Add support for gradient clipping. 18 | // TODO: L2 regularization support for networks. 19 | // TODO: Reward normalizer. 20 | // TODO: Reward norm clipping. 
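//
// Overview of the agents defined in this file: `ReinforceAgent` minimizes the negated REINFORCE
// objective, `loss = -Σ_t log π(a_t | s_t) · G_t`, summed over the steps of completed episodes and
// divided by the number of episodes, where `G_t` is the discounted return from step `t` onward;
// `A2CAgent` replaces `G_t` with an advantage estimate and adds a weighted value-estimation loss;
// `PPOAgent` additionally constrains each policy update using importance-ratio clipping
// (`PPOClip`) and/or an adaptive KL penalty (`PPOPenalty`).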
21 | 22 | public protocol PolicyGradientAgent: ProbabilisticAgent {} 23 | 24 | extension PolicyGradientAgent { 25 | @inlinable 26 | @discardableResult 27 | public mutating func update( 28 | using environment: inout Environment, 29 | maxSteps: Int = Int.max, 30 | maxEpisodes: Int = Int.max, 31 | callbacks: [StepCallback] = [] 32 | ) throws -> Float { 33 | var trajectories = [Trajectory]() 34 | var currentStep = environment.currentStep 35 | var numSteps = 0 36 | var numEpisodes = 0 37 | while numSteps < maxSteps && numEpisodes < maxEpisodes { 38 | let state = self.state 39 | let action = self.action(for: currentStep, mode: .probabilistic) 40 | let nextStep = try environment.step(taking: action) 41 | var trajectory = Trajectory( 42 | stepKind: nextStep.kind, 43 | observation: currentStep.observation, 44 | state: state, 45 | action: action, 46 | reward: nextStep.reward) 47 | trajectories.append(trajectory) 48 | callbacks.forEach { $0(&environment, &trajectory) } 49 | numSteps += Int((1 - Tensor(nextStep.kind.isLast())).sum().scalarized()) 50 | numEpisodes += Int(Tensor(nextStep.kind.isLast()).sum().scalarized()) 51 | currentStep = nextStep 52 | } 53 | return update(using: Trajectory.stack(trajectories)) 54 | } 55 | } 56 | 57 | public struct ReinforceAgent< 58 | Environment: ReinforcementLearning.Environment, 59 | State: Differentiable, 60 | Network: Module, 61 | Optimizer: TensorFlow.Optimizer 62 | >: PolicyGradientAgent 63 | where 64 | Environment.ActionSpace.ValueDistribution: DifferentiableDistribution, 65 | Environment.Reward == Tensor, 66 | Network.Input == AgentInput, 67 | Network.Output == ActorOutput, 68 | Optimizer.Model == Network 69 | { 70 | public typealias Observation = Environment.Observation 71 | public typealias Action = ActionDistribution.Value 72 | public typealias ActionDistribution = Environment.ActionSpace.ValueDistribution 73 | public typealias Reward = Tensor 74 | 75 | public let actionSpace: Environment.ActionSpace 76 | public var state: State 77 | public var network: Network 78 | public var optimizer: Optimizer 79 | 80 | public let discountFactor: Float 81 | public let entropyRegularizationWeight: Float 82 | 83 | @usableFromInline internal var returnsNormalizer: TensorNormalizer? 84 | 85 | @inlinable 86 | public init( 87 | for environment: Environment, 88 | network: Network, 89 | initialState: State, 90 | optimizer: (Network) -> Optimizer, 91 | discountFactor: Float, 92 | normalizeReturns: Bool = true, 93 | entropyRegularizationWeight: Float = 0.0 94 | ) { 95 | self.actionSpace = environment.actionSpace 96 | self.state = initialState 97 | self.network = network 98 | self.optimizer = optimizer(network) 99 | self.discountFactor = discountFactor 100 | self.returnsNormalizer = normalizeReturns ? 
101 | TensorNormalizer(streaming: true, alongAxes: 0, 1) : 102 | nil 103 | self.entropyRegularizationWeight = entropyRegularizationWeight 104 | } 105 | 106 | @inlinable 107 | public mutating func actionDistribution( 108 | for step: Step 109 | ) -> ActionDistribution { 110 | let networkOutput = network(AgentInput(observation: step.observation, state: state)) 111 | state = networkOutput.state 112 | return networkOutput.actionDistribution 113 | } 114 | 115 | @discardableResult 116 | public mutating func update( 117 | using trajectory: Trajectory 118 | ) -> Float { 119 | var returns = discountedReturns( 120 | discountFactor: discountFactor, 121 | stepKinds: trajectory.stepKind, 122 | rewards: trajectory.reward) 123 | let (loss, gradient) = valueWithGradient(at: network) { network -> Tensor in 124 | let networkOutput = network(AgentInput( 125 | observation: trajectory.observation, 126 | state: trajectory.state)) 127 | let actionDistribution = networkOutput.actionDistribution 128 | self.returnsNormalizer?.update(using: returns) 129 | if let normalizer = self.returnsNormalizer { 130 | returns = normalizer.normalize(returns) 131 | } 132 | let actionLogProbs = actionDistribution.logProbability(of: trajectory.action) 133 | 134 | // The policy gradient loss is defined as the sum, over time steps, of action 135 | // log-probabilities multiplied with the cumulative return from that time step onward. 136 | let actionLogProbWeightedReturns = actionLogProbs * returns 137 | 138 | // REINFORCE requires completed episodes and thus we mask out incomplete ones. 139 | let mask = Tensor(trajectory.stepKind.completeEpisodeMask()) 140 | let episodeCount = trajectory.stepKind.episodeCount() 141 | 142 | precondition( 143 | episodeCount.scalarized() > 0, 144 | "REINFORCE requires at least one completed episode.") 145 | 146 | // We compute the mean of the policy gradient loss over the number of episodes. 147 | let policyGradientLoss = -(actionLogProbWeightedReturns * mask).sum() / episodeCount 148 | 149 | // If entropy regularization is being used for the action distribution, then we also 150 | // compute the entropy loss term. 
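      // Subtracting a multiple of the distribution's entropy from the loss favors higher-entropy
      // (less peaked) action distributions and thus encourages exploration.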
151 | var entropyLoss = Tensor(0.0) 152 | if self.entropyRegularizationWeight > 0.0 { 153 | let entropy = actionDistribution.entropy() 154 | entropyLoss = entropyLoss - self.entropyRegularizationWeight * entropy.mean() 155 | } 156 | return policyGradientLoss + entropyLoss 157 | } 158 | optimizer.update(&network, along: gradient) 159 | return loss.scalarized() 160 | } 161 | } 162 | 163 | extension ReinforceAgent where State == Empty { 164 | @inlinable 165 | public init( 166 | for environment: Environment, 167 | network: StatelessNetwork, 168 | optimizer: (StatelessActorNetwork) -> Optimizer, 169 | discountFactor: Float, 170 | normalizeReturns: Bool = true, 171 | entropyRegularizationWeight: Float = 0.0 172 | ) where 173 | Network == StatelessActorNetwork, 174 | StatelessNetwork.Input == Observation, 175 | StatelessNetwork.Output == ActionDistribution 176 | { 177 | let network = StatelessActorNetwork(network) 178 | self.init( 179 | for: environment, 180 | network: network, 181 | initialState: Empty(), 182 | optimizer: optimizer, 183 | discountFactor: discountFactor, 184 | normalizeReturns: normalizeReturns, 185 | entropyRegularizationWeight: entropyRegularizationWeight) 186 | } 187 | } 188 | 189 | public struct A2CAgent< 190 | Environment: ReinforcementLearning.Environment, 191 | State: Differentiable, 192 | Network: Module, 193 | Optimizer: TensorFlow.Optimizer 194 | >: PolicyGradientAgent 195 | where 196 | Environment.Reward == Tensor, 197 | Network.Input == AgentInput, 198 | Network.Output == ActorCriticOutput, 199 | Optimizer.Model == Network 200 | { 201 | public typealias Observation = Environment.Observation 202 | public typealias Action = ActionDistribution.Value 203 | public typealias ActionDistribution = Environment.ActionSpace.ValueDistribution 204 | public typealias Reward = Tensor 205 | 206 | public let actionSpace: Environment.ActionSpace 207 | public var state: State 208 | public var network: Network 209 | public var optimizer: Optimizer 210 | 211 | public let advantageFunction: AdvantageFunction 212 | public let valueEstimationLossWeight: Float 213 | public let entropyRegularizationWeight: Float 214 | 215 | @usableFromInline internal var advantagesNormalizer: TensorNormalizer? 216 | 217 | @inlinable 218 | public init( 219 | for environment: Environment, 220 | network: Network, 221 | initialState: State, 222 | optimizer: (Network) -> Optimizer, 223 | advantageFunction: AdvantageFunction = GeneralizedAdvantageEstimation(discountFactor: 0.9), 224 | normalizeAdvantages: Bool = true, 225 | valueEstimationLossWeight: Float = 0.2, 226 | entropyRegularizationWeight: Float = 0.0 227 | ) { 228 | self.actionSpace = environment.actionSpace 229 | self.state = initialState 230 | self.network = network 231 | self.optimizer = optimizer(network) 232 | self.advantageFunction = advantageFunction 233 | self.advantagesNormalizer = normalizeAdvantages ? 
234 | TensorNormalizer(streaming: true, alongAxes: 0, 1) : 235 | nil 236 | self.valueEstimationLossWeight = valueEstimationLossWeight 237 | self.entropyRegularizationWeight = entropyRegularizationWeight 238 | } 239 | 240 | @inlinable 241 | public mutating func actionDistribution( 242 | for step: Step 243 | ) -> ActionDistribution { 244 | let networkOutput = network(AgentInput(observation: step.observation, state: state)) 245 | state = networkOutput.state 246 | return networkOutput.actionDistribution 247 | } 248 | 249 | @discardableResult 250 | public mutating func update( 251 | using trajectory: Trajectory 252 | ) -> Float { 253 | let (loss, gradient) = valueWithGradient(at: network) { network -> Tensor in 254 | let networkOutput = network(AgentInput( 255 | observation: trajectory.observation, 256 | state: trajectory.state)) 257 | 258 | // Split the trajectory such that the last step is only used to provide the final value 259 | // estimate used for advantage estimation. 260 | let sequenceLength = networkOutput.value.shape[0] - 1 261 | let stepKinds = StepKind(trajectory.stepKind.rawValue[0..(0.0) 297 | if self.entropyRegularizationWeight > 0.0 { 298 | let entropy = actionDistribution.entropy()[0..( 311 | for environment: Environment, 312 | network: StatelessNetwork, 313 | optimizer: (StatelessActorCriticNetwork) -> Optimizer, 314 | advantageFunction: AdvantageFunction = GeneralizedAdvantageEstimation(discountFactor: 0.9), 315 | normalizeAdvantages: Bool = true, 316 | valueEstimationLossWeight: Float = 0.2, 317 | entropyRegularizationWeight: Float = 0.0 318 | ) where 319 | Network == StatelessActorCriticNetwork, 320 | StatelessNetwork.Input == Observation, 321 | StatelessNetwork.Output == StatelessActorCriticOutput 322 | { 323 | let network = StatelessActorCriticNetwork(network) 324 | self.init( 325 | for: environment, 326 | network: network, 327 | initialState: Empty(), 328 | optimizer: optimizer, 329 | advantageFunction: advantageFunction, 330 | normalizeAdvantages: normalizeAdvantages, 331 | valueEstimationLossWeight: valueEstimationLossWeight, 332 | entropyRegularizationWeight: entropyRegularizationWeight) 333 | } 334 | } 335 | 336 | // TODO: !! Allow `epsilon` to change while training. 337 | public struct PPOClip { 338 | public let epsilon: Float 339 | 340 | @inlinable 341 | public init(epsilon: Float = 0.1) { 342 | self.epsilon = epsilon 343 | } 344 | } 345 | 346 | public struct PPOPenalty { 347 | public let klCutoffFactor: Float 348 | public let klCutoffCoefficient: Float 349 | public let adaptiveKLTarget: Float 350 | public let adaptiveKLToleranceFactor: Float 351 | public let adaptiveKLBetaScalingFactor: Float 352 | public var adaptiveKLBeta: Float? 353 | 354 | @inlinable 355 | public init( 356 | klCutoffFactor: Float = 0.2, 357 | klCutoffCoefficient: Float = 1000.0, 358 | adaptiveKLTarget: Float = 0.01, 359 | adaptiveKLToleranceFactor: Float = 1.5, 360 | adaptiveKLBetaScalingFactor: Float = 2.0, 361 | adaptiveKLBeta: Float? = 1.0 362 | ) { 363 | precondition(adaptiveKLBetaScalingFactor > 0, "The beta scaling factor must be positive.") 364 | self.klCutoffFactor = klCutoffFactor 365 | self.klCutoffCoefficient = klCutoffCoefficient 366 | self.adaptiveKLTarget = adaptiveKLTarget 367 | self.adaptiveKLToleranceFactor = adaptiveKLToleranceFactor 368 | self.adaptiveKLBetaScalingFactor = adaptiveKLBetaScalingFactor 369 | self.adaptiveKLBeta = adaptiveKLBeta 370 | } 371 | } 372 | 373 | // TODO: !! Allow `clipThreshold` to change while training. 
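/// Configuration of the value-estimation loss term used by `PPOAgent`.
///
/// When `clipThreshold` is non-`nil`, the new value predictions are clipped so that they stay
/// within `clipThreshold` of the old value estimates, and the larger of the clipped and the
/// unclipped squared errors is used (mirroring the clipped policy surrogate objective).
///
/// Illustrative usage (a sketch; the arguments shown are simply the defaults of the
/// initializer below):
///
///     let valueEstimationLoss = PPOValueEstimationLoss(weight: 0.5, clipThreshold: 0.1)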
374 | public struct PPOValueEstimationLoss { 375 | public let weight: Float 376 | public let clipThreshold: Float? 377 | 378 | @inlinable 379 | public init(weight: Float = 0.5, clipThreshold: Float? = 0.1) { 380 | self.weight = weight 381 | self.clipThreshold = clipThreshold 382 | } 383 | } 384 | 385 | public struct PPOEntropyRegularization { 386 | public let weight: Float 387 | 388 | @inlinable 389 | public init(weight: Float) { 390 | self.weight = weight 391 | } 392 | } 393 | 394 | public struct PPOAgent< 395 | Environment: ReinforcementLearning.Environment, 396 | State: Differentiable, 397 | Network: Module, 398 | Optimizer: TensorFlow.Optimizer, 399 | LearningRate: ReinforcementLearning.LearningRate 400 | >: PolicyGradientAgent 401 | where 402 | Environment.ActionSpace.ValueDistribution: DifferentiableKLDivergence, 403 | Environment.Reward == Tensor, 404 | Network.Input == AgentInput, 405 | Network.Output == ActorCriticOutput, 406 | Optimizer.Model == Network, 407 | LearningRate.Scalar == Optimizer.Scalar 408 | { 409 | public typealias Observation = Environment.Observation 410 | public typealias Action = ActionDistribution.Value 411 | public typealias ActionDistribution = Environment.ActionSpace.ValueDistribution 412 | public typealias Reward = Tensor 413 | 414 | public let actionSpace: Environment.ActionSpace 415 | public var state: State 416 | public var network: Network 417 | public var optimizer: Optimizer 418 | public var trainingStep: UInt64 = 0 419 | 420 | public let learningRate: LearningRate 421 | public let maxGradientNorm: Float? 422 | public let rewardsPreprocessor: (Tensor) -> Tensor 423 | public let advantageFunction: AdvantageFunction 424 | public let useTDLambdaReturn: Bool 425 | public let clip: PPOClip? 426 | public let penalty: PPOPenalty? 427 | public let valueEstimationLoss: PPOValueEstimationLoss 428 | public let entropyRegularization: PPOEntropyRegularization? 429 | public let iterationCountPerUpdate: Int 430 | 431 | @usableFromInline internal var advantagesNormalizer: TensorNormalizer? 432 | 433 | @inlinable 434 | public init( 435 | for environment: Environment, 436 | network: Network, 437 | initialState: State, 438 | optimizer: (Network) -> Optimizer, 439 | learningRate: LearningRate, 440 | maxGradientNorm: Float? = 0.5, 441 | rewardsPreprocessor: @escaping (Tensor) -> Tensor = { $0 }, 442 | advantageFunction: AdvantageFunction = GeneralizedAdvantageEstimation( 443 | discountFactor: 0.99, 444 | discountWeight: 0.95), 445 | advantagesNormalizer: TensorNormalizer? = TensorNormalizer( 446 | streaming: true, 447 | alongAxes: 0, 1), 448 | useTDLambdaReturn: Bool = true, 449 | clip: PPOClip? = PPOClip(), 450 | penalty: PPOPenalty? = PPOPenalty(), 451 | valueEstimationLoss: PPOValueEstimationLoss = PPOValueEstimationLoss(), 452 | entropyRegularization: PPOEntropyRegularization? 
= PPOEntropyRegularization(weight: 0.01), 453 | iterationCountPerUpdate: Int = 4 454 | ) { 455 | self.actionSpace = environment.actionSpace 456 | self.state = initialState 457 | self.network = network 458 | self.optimizer = optimizer(network) 459 | self.learningRate = learningRate 460 | self.maxGradientNorm = maxGradientNorm 461 | self.rewardsPreprocessor = rewardsPreprocessor 462 | self.advantageFunction = advantageFunction 463 | self.advantagesNormalizer = advantagesNormalizer 464 | self.useTDLambdaReturn = useTDLambdaReturn 465 | self.clip = clip 466 | self.penalty = penalty 467 | self.valueEstimationLoss = valueEstimationLoss 468 | self.entropyRegularization = entropyRegularization 469 | self.iterationCountPerUpdate = iterationCountPerUpdate 470 | } 471 | 472 | @inlinable 473 | public mutating func actionDistribution( 474 | for step: Step 475 | ) -> ActionDistribution { 476 | let networkOutput = network(AgentInput(observation: step.observation, state: state)) 477 | state = networkOutput.state 478 | return networkOutput.actionDistribution 479 | } 480 | 481 | @discardableResult 482 | public mutating func update( 483 | using trajectory: Trajectory 484 | ) -> Float { 485 | optimizer.learningRate = learningRate(forStep: trainingStep) 486 | trainingStep += 1 487 | 488 | // Split the trajectory such that the last step is only used to provide the final value 489 | // estimate used for advantage estimation. 490 | let networkOutput = network(AgentInput( 491 | observation: trajectory.observation, 492 | state: trajectory.state)) 493 | let sequenceLength = networkOutput.value.shape[0] - 1 494 | let stepKinds = StepKind(trajectory.stepKind.rawValue[0.. Tensor in 524 | // TODO: Should we be updating the state here? 525 | let newNetworkOutput = network(AgentInput( 526 | observation: trajectory.observation, 527 | state: trajectory.state)) 528 | 529 | // Compute the new action log probabilities. 530 | let newActionDistribution = newNetworkOutput.actionDistribution 531 | let newActionLogProbs = newActionDistribution.logProbability( 532 | of: trajectory.action 533 | )[0..(c.epsilon) 541 | let importanceRatioClipped = importanceRatio.clipped(min: 1 - ε, max: 1 + ε) 542 | loss = -min(loss, importanceRatioClipped * advantages).mean() 543 | } else { 544 | loss = -loss.mean() 545 | } 546 | 547 | // KL penalty loss term. 548 | if let p = self.penalty { 549 | let klDivergence = actionDistribution.klDivergence(to: newActionDistribution) 550 | let klMean = klDivergence.mean() 551 | let klCutoffLoss = max(klMean - p.klCutoffFactor * p.adaptiveKLTarget, 0).squared() 552 | loss = loss + p.klCutoffCoefficient * klCutoffLoss 553 | if let beta = p.adaptiveKLBeta { 554 | loss = loss + beta * klMean 555 | } 556 | } 557 | 558 | // Entropy regularization loss term. 559 | if let e = self.entropyRegularization { 560 | let entropy = newActionDistribution.entropy()[0..(c) 569 | let clippedValues = values + (newValues - values).clipped(min: -ε, max: ε) 570 | let clippedValueLoss = (clippedValues - returns).squared() 571 | valueLoss = max(valueLoss, clippedValueLoss) 572 | } 573 | return loss + self.valueEstimationLoss.weight * valueLoss.mean() / 2 574 | } 575 | if let clipNorm = maxGradientNorm { 576 | gradient.clipByGlobalNorm(clipNorm: clipNorm) 577 | } 578 | optimizer.update(&network, along: gradient) 579 | lastEpochLoss = loss.scalarized() 580 | } 581 | 582 | // After the network is updated, we may need to update the adaptive KL beta. 
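// The adaptive rule below shrinks `adaptiveKLBeta` (dividing by the scaling factor, with a
// small floor) when the observed mean KL divergence falls below the tolerance band around
// `adaptiveKLTarget`, and grows it (multiplying by the scaling factor) when the mean KL
// divergence exceeds that band.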
583 | if var p = penalty, let beta = p.adaptiveKLBeta { 584 | let klDivergence = network(AgentInput( 585 | observation: trajectory.observation, 586 | state: trajectory.state) 587 | ).actionDistribution.klDivergence(to: actionDistribution) 588 | let klMean = klDivergence.mean().scalarized() 589 | if klMean < p.adaptiveKLTarget / p.adaptiveKLToleranceFactor { 590 | p.adaptiveKLBeta = max(beta / p.adaptiveKLBetaScalingFactor, 1e-16) 591 | } else if klMean > p.adaptiveKLTarget * p.adaptiveKLToleranceFactor { 592 | p.adaptiveKLBeta = beta * p.adaptiveKLBetaScalingFactor 593 | } 594 | } 595 | 596 | return lastEpochLoss 597 | } 598 | } 599 | 600 | extension PPOAgent where State == Empty { 601 | @inlinable 602 | public init( 603 | for environment: Environment, 604 | network: StatelessNetwork, 605 | optimizer: (StatelessActorCriticNetwork) -> Optimizer, 606 | learningRate: LearningRate, 607 | maxGradientNorm: Float? = 0.5, 608 | rewardsPreprocessor: @escaping (Tensor) -> Tensor = { $0 }, 609 | advantageFunction: AdvantageFunction = GeneralizedAdvantageEstimation( 610 | discountFactor: 0.99, 611 | discountWeight: 0.95), 612 | advantagesNormalizer: TensorNormalizer? = TensorNormalizer( 613 | streaming: true, 614 | alongAxes: 0, 1), 615 | useTDLambdaReturn: Bool = true, 616 | clip: PPOClip? = PPOClip(), 617 | penalty: PPOPenalty? = PPOPenalty(), 618 | valueEstimationLoss: PPOValueEstimationLoss = PPOValueEstimationLoss(), 619 | entropyRegularization: PPOEntropyRegularization? = PPOEntropyRegularization(weight: 0.01), 620 | iterationCountPerUpdate: Int = 4 621 | ) where 622 | Network == StatelessActorCriticNetwork, 623 | StatelessNetwork.Input == Observation, 624 | StatelessNetwork.Output == StatelessActorCriticOutput 625 | { 626 | let network = StatelessActorCriticNetwork(network) 627 | self.init( 628 | for: environment, 629 | network: network, 630 | initialState: Empty(), 631 | optimizer: optimizer, 632 | learningRate: learningRate, 633 | maxGradientNorm: maxGradientNorm, 634 | rewardsPreprocessor: rewardsPreprocessor, 635 | advantageFunction: advantageFunction, 636 | advantagesNormalizer: advantagesNormalizer, 637 | useTDLambdaReturn: useTDLambdaReturn, 638 | clip: clip, 639 | penalty: penalty, 640 | valueEstimationLoss: valueEstimationLoss, 641 | entropyRegularization: entropyRegularization, 642 | iterationCountPerUpdate: iterationCountPerUpdate) 643 | } 644 | } 645 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Distributions/Bernoulli.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 14 | 15 | import TensorFlow 16 | 17 | public struct Bernoulli: DifferentiableDistribution, KeyPathIterable { 18 | /// Unnormalized log-probabilities of this bernoulli distribution. 
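/// The probability of sampling a `1` is `sigmoid(logits)`, so any real-valued tensor is a
/// valid parameterization and no explicit normalization is required.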
19 | public var logits: Tensor 20 | 21 | @inlinable 22 | @differentiable(wrt: logits) 23 | public init(logits: Tensor) { 24 | self.logits = logits 25 | } 26 | 27 | @inlinable 28 | @differentiable(wrt: logProbabilities) 29 | public init(logProbabilities: Tensor) { 30 | self.logits = logProbabilities 31 | } 32 | 33 | @inlinable 34 | @differentiable(wrt: probabilities) 35 | public init(probabilities: Tensor) { 36 | self.logits = log(probabilities) 37 | } 38 | 39 | @inlinable 40 | @differentiable(wrt: self) 41 | public func logProbability(of value: Tensor) -> Tensor { 42 | max(logits, Tensor(0.0)) - logits * Tensor(value) + softplus(-abs(logits)) 43 | } 44 | 45 | @inlinable 46 | @differentiable(wrt: self) 47 | public func entropy() -> Tensor { 48 | max(logits, Tensor(0.0)) - logits * sigmoid(logits) + softplus(-abs(logits)) 49 | } 50 | 51 | @inlinable 52 | public func mode() -> Tensor { 53 | Tensor(logSigmoid(logits) .> log(0.5)) 54 | } 55 | 56 | @inlinable 57 | public func sample() -> Tensor { 58 | let seed = Context.local.randomSeed 59 | let logProbabilities = logSigmoid(logits) 60 | let uniform: Tensor = _Raw.statelessRandomUniform( 61 | shape: logProbabilities.shapeTensor, 62 | seed: Tensor([seed.graph, seed.op])) 63 | return Tensor(logProbabilities .< log(uniform)) 64 | } 65 | } 66 | 67 | // TODO: !!! Is the following correct? 68 | extension Bernoulli: DifferentiableKLDivergence { 69 | @inlinable 70 | @differentiable 71 | public func klDivergence(to target: Bernoulli) -> Tensor { 72 | let logProbabilities = logSigmoid(logits) 73 | let kl = exp(logProbabilities) * (logProbabilities - logSigmoid(target.logits)) 74 | return kl.sum(squeezingAxes: -1) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Distributions/Categorical.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 14 | 15 | import TensorFlow 16 | 17 | public struct Categorical: DifferentiableDistribution, KeyPathIterable { 18 | /// Log-probabilities of this categorical distribution. 
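/// Exponentiating these along the last axis is expected to yield a valid probability vector;
/// the `logits` initializer below guarantees this by applying `logSoftmax`.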
19 | public var logProbabilities: Tensor 20 | 21 | @inlinable 22 | @differentiable(wrt: logProbabilities) 23 | public init(logProbabilities: Tensor) { 24 | self.logProbabilities = logProbabilities 25 | } 26 | 27 | @inlinable 28 | @differentiable(wrt: probabilities) 29 | public init(probabilities: Tensor) { 30 | self.logProbabilities = log(probabilities) 31 | } 32 | 33 | @inlinable 34 | @differentiable(wrt: logits) 35 | public init(logits: Tensor) { 36 | self.logProbabilities = logSoftmax(logits) 37 | } 38 | 39 | @inlinable 40 | @differentiable(wrt: self) 41 | public func logProbability(of value: Tensor) -> Tensor { 42 | logProbabilities.batchGathering( 43 | atIndices: value.expandingShape(at: -1), 44 | alongAxis: 2, 45 | batchDimensionCount: 2 46 | ).squeezingShape(at: -1) 47 | } 48 | 49 | @inlinable 50 | @differentiable(wrt: self) 51 | public func entropy() -> Tensor { 52 | -(logProbabilities * exp(logProbabilities)).sum(squeezingAxes: -1) 53 | } 54 | 55 | @inlinable 56 | public func mode() -> Tensor { 57 | Tensor(logProbabilities.argmax(squeezingAxis: 1)) 58 | } 59 | 60 | @inlinable 61 | public func sample() -> Tensor { 62 | let seed = Context.local.randomSeed 63 | let outerDimCount = self.logProbabilities.rank - 1 64 | let logProbabilities = self.logProbabilities.flattenedBatch(outerDimCount: outerDimCount) 65 | let multinomial: Tensor = _Raw.statelessMultinomial( 66 | logits: logProbabilities, 67 | numSamples: Tensor(1), 68 | seed: Tensor([seed.graph, seed.op])) 69 | let flattenedSamples = multinomial.gathering(atIndices: Tensor(0), alongAxis: 1) 70 | return flattenedSamples.unflattenedBatch( 71 | outerDims: [Int](self.logProbabilities.shape.dimensions[0.. Tensor { 79 | (exp(logProbabilities) * (logProbabilities - target.logProbabilities)).sum(squeezingAxes: -1) 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Distributions/Deterministic.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 14 | 15 | import TensorFlow 16 | 17 | public struct Deterministic: Distribution, KeyPathIterable { 18 | // TODO: Make `internal(set)` once `@usableFromInline` is supported. 19 | public var value: Tensor 20 | 21 | @inlinable 22 | public init(at value: Tensor) { 23 | self.value = value 24 | } 25 | 26 | @inlinable 27 | public func logProbability(of value: Tensor) -> Tensor { 28 | // TODO: What about NaNs? 
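// The equality comparison yields a 0/1 indicator tensor, so the log-probability is 0 where
// `value` matches the stored value exactly and negative infinity everywhere else.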
29 | log(Tensor(value .== self.value)) 30 | } 31 | 32 | @inlinable 33 | public func entropy() -> Tensor { 34 | Tensor(zeros: value.shape) 35 | } 36 | 37 | @inlinable 38 | public func mode() -> Tensor { 39 | value 40 | } 41 | 42 | @inlinable 43 | public func sample() -> Tensor { 44 | value 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Distributions/Distribution.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 14 | 15 | import TensorFlow 16 | 17 | public protocol Distribution { 18 | associatedtype Value 19 | 20 | func logProbability(of value: Value) -> Tensor 21 | func entropy() -> Tensor 22 | 23 | /// Returns the mode of this distribution. If the distribution has multiple modes, then one of 24 | /// them is sampled randomly (and uniformly) and returned. 25 | func mode() -> Value 26 | 27 | /// Returns a random sample drawn from this distribution. 28 | func sample() -> Value 29 | } 30 | 31 | public extension Distribution { 32 | func probability(of value: Value) -> Tensor { 33 | exp(logProbability(of: value)) 34 | } 35 | } 36 | 37 | public protocol DifferentiableDistribution: Distribution, Differentiable { 38 | @differentiable(wrt: self) 39 | func logProbability(of value: Value) -> Tensor 40 | 41 | @differentiable(wrt: self) 42 | func entropy() -> Tensor 43 | } 44 | 45 | extension DifferentiableDistribution { 46 | @inlinable 47 | @differentiable(wrt: self) 48 | public func probability(of value: Value) -> Tensor { 49 | exp(logProbability(of: value)) 50 | } 51 | } 52 | 53 | // TODO: It would be great to support KL divergence between different distributions, but that 54 | // would require multiple conformances to the same protocol with different `TargetDistribution` 55 | // types, which is not currently supported in Swift. :( This is also a place where a feature 56 | // similar to Scala implicits would be great. 57 | public protocol KLDivergence where Self: Distribution { 58 | func klDivergence(to target: Self) -> Tensor 59 | } 60 | 61 | public protocol DifferentiableKLDivergence: KLDivergence where Self: DifferentiableDistribution { 62 | @differentiable 63 | func klDivergence(to target: Self) -> Tensor 64 | } 65 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Distributions/Uniform.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. 
You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 14 | 15 | import TensorFlow 16 | 17 | public struct Uniform< 18 | Scalar: TensorFlowFloatingPoint 19 | >: DifferentiableDistribution, KeyPathIterable { 20 | @noDerivative public let shape: Tensor 21 | public var lowerBound: Tensor 22 | public var upperBound: Tensor 23 | 24 | @inlinable 25 | @differentiable(wrt: (lowerBound, upperBound)) 26 | public init( 27 | shape: Tensor, 28 | lowerBound: Tensor = Tensor(zeros: []), 29 | upperBound: Tensor = Tensor(ones: []) 30 | ) { 31 | self.shape = shape 32 | self.lowerBound = lowerBound 33 | self.upperBound = upperBound 34 | } 35 | 36 | @inlinable 37 | @differentiable(wrt: self) 38 | public func logProbability(of value: Tensor) -> Tensor { 39 | log(1.0) - log(Tensor(upperBound - lowerBound)) 40 | } 41 | 42 | @inlinable 43 | @differentiable(wrt: self) 44 | public func entropy() -> Tensor { 45 | log(Tensor(upperBound - lowerBound)) 46 | } 47 | 48 | @inlinable 49 | public func mode() -> Tensor { 50 | sample() 51 | } 52 | 53 | @inlinable 54 | public func sample() -> Tensor { 55 | // TODO: Make `Tensor.init(randomUniform:...)` accept `Tensor` for the shape. 56 | let seed = Context.local.randomSeed 57 | let sample: Tensor = _Raw.statelessRandomUniform( 58 | shape: shape, 59 | seed: Tensor([seed.graph, seed.op])) 60 | return sample * (upperBound - lowerBound) + lowerBound 61 | } 62 | } 63 | 64 | extension Uniform { 65 | @inlinable 66 | @differentiable(wrt: (lowerBound, upperBound)) 67 | public init(lowerBound: Tensor, upperBound: Tensor) { 68 | self.shape = withoutDerivative(at: lowerBound.shapeTensor) 69 | self.lowerBound = lowerBound 70 | self.upperBound = upperBound 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Environments/ClassicControl/CartPole.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 
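// CartPole: a pole is attached by an unactuated joint to a cart that moves along a
// frictionless track. The agent pushes the cart left or right with a fixed-magnitude force,
// receives a reward of one per step, and an episode ends once the pole tilts past the angle
// threshold (roughly 12 degrees) or the cart moves past the position threshold (2.4 units).
// The constants below follow the classic cart-pole formulation (as popularized by OpenAI Gym).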
14 | 15 | import TensorFlow 16 | 17 | @usableFromInline internal let gravity: Float = 9.8 18 | @usableFromInline internal let cartMass: Float = 1.0 19 | @usableFromInline internal let poleMass: Float = 0.1 20 | @usableFromInline internal let length: Float = 0.5 21 | @usableFromInline internal let forceMagnitude: Float = 10.0 22 | @usableFromInline internal let secondCountBetweenUpdates: Float = 0.02 23 | @usableFromInline internal let angleThreshold: Float = 12 * 2 * Float.pi / 360 24 | @usableFromInline internal let positionThreshold: Float = 2.4 25 | @usableFromInline internal let totalMass: Float = cartMass + poleMass 26 | @usableFromInline internal let poleMassLength: Float = poleMass * length 27 | 28 | public struct CartPoleEnvironment: RenderableEnvironment { 29 | public let batchSize: Int 30 | public let actionSpace: Discrete 31 | public var observationSpace: ObservationSpace 32 | 33 | @usableFromInline internal var step: Step> 34 | @usableFromInline internal var needsReset: Tensor 35 | @usableFromInline internal var renderer: CartPoleRenderer? = nil 36 | 37 | @inlinable public var currentStep: Step> { step } 38 | 39 | @inlinable 40 | public init(batchSize: Int, renderer: CartPoleRenderer? = nil) { 41 | self.batchSize = batchSize 42 | self.actionSpace = Discrete(withSize: 2, batchSize: batchSize) 43 | self.observationSpace = ObservationSpace(batchSize: batchSize) 44 | self.step = Step( 45 | kind: StepKind.first(batchSize: batchSize), 46 | observation: observationSpace.sample(), 47 | reward: Tensor(ones: [batchSize])) 48 | self.needsReset = Tensor(repeating: false, shape: [batchSize]) 49 | self.renderer = renderer 50 | } 51 | 52 | /// Updates the environment according to the provided action. 53 | @inlinable 54 | @discardableResult 55 | public mutating func step(taking action: Tensor) -> Step> { 56 | // precondition(actionSpace.contains(action), "Invalid action provided.") 57 | var position = step.observation.position 58 | var positionDerivative = step.observation.positionDerivative 59 | var angle = step.observation.angle 60 | var angleDerivative = step.observation.angleDerivative 61 | 62 | // Calculate the updates to the pole position, angle, and their derivatives. 63 | let force = Tensor(2 * action - 1) * forceMagnitude 64 | let angleCosine = cos(angle) 65 | let angleSine = sin(angle) 66 | let temp = force + poleMassLength * angleDerivative * angleDerivative * angleSine 67 | let angleAccNominator = gravity * angleSine - temp * angleCosine / totalMass 68 | let angleAccDenominator = 4/3 - poleMass * angleCosine * angleCosine / totalMass 69 | let angleAcc = angleAccNominator / (length * angleAccDenominator) 70 | let positionAcc = (temp - poleMassLength * angleAcc * angleCosine) / totalMass 71 | position += secondCountBetweenUpdates * positionDerivative 72 | positionDerivative += secondCountBetweenUpdates * positionAcc 73 | angle += secondCountBetweenUpdates * angleDerivative 74 | angleDerivative += secondCountBetweenUpdates * angleAcc 75 | 76 | // Take into account the finished simulations in the batch. 
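// Batch elements flagged in `needsReset` (i.e., those whose episode ended at the previous
// step) are re-initialized with freshly sampled observations, while the remaining elements
// keep the state integrated above. The step kind is then set to "first" (0) for the reset
// elements, to "last" (3, terminal with automatic reset) for elements that just went out of
// bounds, and to "transition" (1) for everything else.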
77 | let sample = observationSpace.sample() 78 | step.observation.position = position.replacing(with: sample.position, where: needsReset) 79 | step.observation.positionDerivative = positionDerivative.replacing( 80 | with: sample.positionDerivative, 81 | where: needsReset) 82 | step.observation.angle = angle.replacing(with: sample.angle, where: needsReset) 83 | step.observation.angleDerivative = angleDerivative.replacing( 84 | with: sample.angleDerivative, 85 | where: needsReset) 86 | let newNeedsReset = (step.observation.position .< -positionThreshold) 87 | .elementsLogicalOr(step.observation.position .> positionThreshold) 88 | .elementsLogicalOr(step.observation.angle .< -angleThreshold) 89 | .elementsLogicalOr(step.observation.angle .> angleThreshold) 90 | step.kind.rawValue = Tensor(onesLike: step.kind.rawValue) 91 | .replacing(with: Tensor(zeros: newNeedsReset.shape), where: needsReset) 92 | .replacing(with: 3 * Tensor(ones: newNeedsReset.shape), where: newNeedsReset) 93 | // Rewards need not be updated because they are always equal to one. 94 | needsReset = newNeedsReset 95 | return step 96 | } 97 | 98 | /// Resets the environment. 99 | @inlinable 100 | @discardableResult 101 | public mutating func reset() -> Step> { 102 | step.kind = StepKind.first(batchSize: batchSize) 103 | step.observation = observationSpace.sample() 104 | needsReset = Tensor(repeating: false, shape: [batchSize]) 105 | return step 106 | } 107 | 108 | /// Returns a copy of this environment that is reset before being returned. 109 | @inlinable 110 | public func copy() -> CartPoleEnvironment { 111 | CartPoleEnvironment(batchSize: batchSize, renderer: renderer) 112 | } 113 | 114 | @inlinable 115 | public mutating func render() { 116 | if renderer == nil { renderer = CartPoleRenderer() } 117 | renderer!.render(observation: step.observation) 118 | } 119 | } 120 | 121 | extension CartPoleEnvironment { 122 | public struct Observation: Differentiable, KeyPathIterable { 123 | public var position: Tensor 124 | public var positionDerivative: Tensor 125 | public var angle: Tensor 126 | public var angleDerivative: Tensor 127 | 128 | @inlinable 129 | public init( 130 | position: Tensor, 131 | positionDerivative: Tensor, 132 | angle: Tensor, 133 | angleDerivative: Tensor 134 | ) { 135 | self.position = position 136 | self.positionDerivative = positionDerivative 137 | self.angle = angle 138 | self.angleDerivative = angleDerivative 139 | } 140 | } 141 | 142 | public struct ObservationSpace: Space { 143 | public let distribution: ValueDistribution 144 | 145 | @inlinable 146 | public init(batchSize: Int) { 147 | self.distribution = ValueDistribution(batchSize: batchSize) 148 | } 149 | 150 | @inlinable 151 | public var description: String { 152 | "CartPoleObservation" 153 | } 154 | 155 | @inlinable 156 | public func contains(_ value: Observation) -> Bool { 157 | true 158 | } 159 | 160 | public struct ValueDistribution: DifferentiableDistribution, KeyPathIterable { 161 | @noDerivative public let batchSize: Int 162 | 163 | public var positionDistribution: Uniform { 164 | Uniform( 165 | lowerBound: Tensor(repeating: -0.05, shape: [batchSize]), 166 | upperBound: Tensor(repeating: 0.05, shape: [batchSize])) 167 | } 168 | 169 | public var positionDerivativeDistribution: Uniform { 170 | Uniform( 171 | lowerBound: Tensor(repeating: -0.05, shape: [batchSize]), 172 | upperBound: Tensor(repeating: 0.05, shape: [batchSize])) 173 | } 174 | 175 | public var angleDistribution: Uniform { 176 | Uniform( 177 | lowerBound: Tensor(repeating: -0.05, 
shape: [batchSize]), 178 | upperBound: Tensor(repeating: 0.05, shape: [batchSize])) 179 | } 180 | 181 | public var angleDerivativeDistribution: Uniform { 182 | Uniform( 183 | lowerBound: Tensor(repeating: -0.05, shape: [batchSize]), 184 | upperBound: Tensor(repeating: 0.05, shape: [batchSize])) 185 | } 186 | 187 | @inlinable 188 | public init(batchSize: Int) { 189 | self.batchSize = batchSize 190 | } 191 | 192 | // TODO: @inlinable 193 | @differentiable(wrt: self) 194 | public func logProbability(of value: Observation) -> Tensor { 195 | positionDistribution.logProbability(of: value.position) + 196 | positionDerivativeDistribution.logProbability(of: value.positionDerivative) + 197 | angleDistribution.logProbability(of: value.angle) + 198 | angleDerivativeDistribution.logProbability(of: value.angleDerivative) 199 | } 200 | 201 | // TODO: @inlinable 202 | @differentiable(wrt: self) 203 | public func entropy() -> Tensor { 204 | positionDistribution.entropy() + 205 | positionDerivativeDistribution.entropy() + 206 | angleDistribution.entropy() + 207 | angleDerivativeDistribution.entropy() 208 | } 209 | 210 | @inlinable 211 | public func mode() -> Observation { 212 | Observation( 213 | position: positionDistribution.mode(), 214 | positionDerivative: positionDerivativeDistribution.mode(), 215 | angle: angleDistribution.mode(), 216 | angleDerivative: angleDerivativeDistribution.mode()) 217 | } 218 | 219 | @inlinable 220 | public func sample() -> Observation { 221 | Observation( 222 | position: positionDistribution.sample(), 223 | positionDerivative: positionDerivativeDistribution.sample(), 224 | angle: angleDistribution.sample(), 225 | angleDerivative: angleDerivativeDistribution.sample()) 226 | } 227 | } 228 | } 229 | } 230 | 231 | public struct CartPoleRenderer: GLFWScene { 232 | public let windowWidth: Int 233 | public let windowHeight: Int 234 | public let worldWidth: Float 235 | public let scale: Float 236 | public let cartTop: Float 237 | public let poleWidth: Float 238 | public let poleLength: Float 239 | public let cartWidth: Float 240 | public let cartHeight: Float 241 | 242 | @usableFromInline internal var window: GLFWWindow 243 | @usableFromInline internal var cart: GLFWGeometry 244 | @usableFromInline internal var pole: GLFWGeometry 245 | @usableFromInline internal var axle: GLFWGeometry 246 | @usableFromInline internal var track: GLFWGeometry 247 | @usableFromInline internal var cartTransform: GLFWTransform 248 | @usableFromInline internal var poleTransform: GLFWTransform 249 | 250 | @inlinable 251 | public init( 252 | windowWidth: Int = 600, 253 | windowHeight: Int = 400, 254 | positionThreshold: Float = 2.4, 255 | cartTop: Float = 100.0, 256 | poleWidth: Float = 10.0, 257 | cartWidth: Float = 50.0, 258 | cartHeight: Float = 30.0 259 | ) { 260 | self.windowWidth = windowWidth 261 | self.windowHeight = windowHeight 262 | self.worldWidth = positionThreshold * 2 263 | self.scale = Float(windowWidth) / worldWidth 264 | self.cartTop = cartTop 265 | self.poleWidth = poleWidth 266 | self.poleLength = scale 267 | self.cartWidth = cartWidth 268 | self.cartHeight = cartHeight 269 | 270 | // Create the GLFW window along with all the shapes. 271 | self.window = try! 
GLFWWindow( 272 | name: "CartPole Environment", 273 | width: windowWidth, 274 | height: windowHeight, 275 | framesPerSecond: 60) 276 | let (cl, cr, ct, cb) = ( 277 | -cartWidth / 2, cartWidth / 2, 278 | cartHeight / 2, -cartHeight / 2) 279 | self.cart = GLFWPolygon(vertices: [(cl, cb), (cl, ct), (cr, ct), (cr, cb)]) 280 | self.cartTransform = GLFWTransform() 281 | self.cart.attributes.append(cartTransform) 282 | let (pl, pr, pt, pb) = ( 283 | -poleWidth / 2, poleWidth / 2, 284 | poleLength - poleWidth / 2, -poleWidth / 2) 285 | self.pole = GLFWPolygon(vertices: [(pl, pb), (pl, pt), (pr, pt), (pr, pb)]) 286 | self.pole.attributes.append(GLFWColor(red: 0.8, green: 0.6, blue: 0.4)) 287 | self.poleTransform = GLFWTransform(translation: (0.0, cartHeight / 4)) 288 | self.pole.attributes.append(poleTransform) 289 | self.pole.attributes.append(cartTransform) 290 | let axleVertices = (0..<30).map { i -> (Float, Float) in 291 | let angle = 2 * Float.pi * Float(i) / Float(30) 292 | return (cos(angle) * poleWidth / 2, sin(angle) * poleWidth / 2) 293 | } 294 | self.axle = GLFWPolygon(vertices: axleVertices) 295 | self.axle.attributes.append(poleTransform) 296 | self.axle.attributes.append(cartTransform) 297 | self.axle.attributes.append(GLFWColor(red: 0.5, green: 0.5, blue: 0.8)) 298 | self.track = GLFWLine(start: (0.0, cartTop), end: (Float(windowWidth), cartTop)) 299 | self.track.attributes.append(GLFWColor(red: 0, green: 0, blue: 0)) 300 | } 301 | 302 | @inlinable 303 | public func draw() { 304 | cart.renderWithAttributes() 305 | pole.renderWithAttributes() 306 | axle.renderWithAttributes() 307 | track.renderWithAttributes() 308 | } 309 | 310 | @inlinable 311 | public mutating func render(observation: CartPoleEnvironment.Observation) { 312 | // TODO: Support batched environments. 313 | let position = observation.position[0].scalarized() 314 | let angle = observation.angle[0].scalarized() 315 | cartTransform.translation = (position * scale + Float(windowWidth) / 2, cartTop) 316 | poleTransform.rotation = -angle 317 | render(in: window) 318 | } 319 | } 320 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Environments/Environment.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 14 | 15 | import Foundation 16 | import TensorFlow 17 | 18 | public protocol Environment { 19 | associatedtype ObservationSpace: Space 20 | associatedtype ActionSpace: Space 21 | associatedtype Reward 22 | 23 | var batchSize: Int { get } 24 | var observationSpace: ObservationSpace { get } 25 | var actionSpace: ActionSpace { get } 26 | 27 | /// Result of the last step taken in this environment (i.e., its current state). 28 | var currentStep: Step { mutating get } 29 | 30 | /// Updates the environment according to the provided action. 
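///
/// - Parameter action: Action to take; it should be a member of this environment's
///   `actionSpace`.
/// - Returns: The resulting step (also exposed through `currentStep`).
///
/// Illustrative usage with the concrete cart-pole environment defined in this package
/// (a sketch):
///
///     var environment = CartPoleEnvironment(batchSize: 32)
///     let step = environment.step(taking: environment.actionSpace.sample())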
31 | @discardableResult 32 | mutating func step(taking action: Action) throws -> Step 33 | 34 | /// Resets the environment. 35 | @discardableResult 36 | mutating func reset() throws -> Step 37 | } 38 | 39 | public extension Environment { 40 | typealias Observation = ObservationSpace.Value 41 | typealias Action = ActionSpace.Value 42 | } 43 | 44 | public protocol RenderableEnvironment: Environment { 45 | mutating func render() throws 46 | } 47 | 48 | /// Contains the data emitted by an environment at a single step of interaction. 49 | public struct Step: KeyPathIterable { 50 | // TODO: Make `internal(set)` once `@usableFromInline` is supported. 51 | public var kind: StepKind 52 | public var observation: Observation 53 | public var reward: Reward 54 | 55 | @inlinable 56 | public init(kind: StepKind, observation: Observation, reward: Reward) { 57 | self.kind = kind 58 | self.observation = observation 59 | self.reward = reward 60 | } 61 | 62 | @inlinable 63 | public func copy( 64 | kind: StepKind? = nil, 65 | observation: Observation? = nil, 66 | reward: Reward? = nil 67 | ) -> Step { 68 | Step( 69 | kind: kind ?? self.kind, 70 | observation: observation ?? self.observation, 71 | reward: reward ?? self.reward) 72 | } 73 | } 74 | 75 | /// Represents the type of a step. 76 | public struct StepKind: KeyPathIterable { 77 | // TODO: Make `internal(set)` once `@usableFromInline` is supported. 78 | public var rawValue: Tensor 79 | 80 | @inlinable 81 | public init(_ rawValue: Tensor) { 82 | self.rawValue = rawValue 83 | } 84 | } 85 | 86 | extension StepKind { 87 | /// Denotes the first step in a sequence. 88 | @inlinable 89 | public static func first() -> StepKind { 90 | StepKind(Tensor(0)) 91 | } 92 | 93 | /// Denotes an transition step in a sequence (i.e., not first or last). 94 | @inlinable 95 | public static func transition() -> StepKind { 96 | StepKind(Tensor(1)) 97 | } 98 | 99 | /// Denotes the last step in a sequence. 100 | @inlinable 101 | public static func last(withReset: Bool = true) -> StepKind { 102 | StepKind(Tensor(withReset ? 3 : 2)) 103 | } 104 | 105 | /// Returns a batched `StepKind` filled with "first" step kind values. 106 | @inlinable 107 | public static func first(batchSize: Int) -> StepKind { 108 | StepKind(first().rawValue.expandingShape(at: 0) 109 | .tiled(multiples: Tensor([Int32(batchSize)]))) 110 | } 111 | 112 | /// Returns a batched `StepKind` filled with "transition" step kind values. 113 | @inlinable 114 | public static func transition(batchSize: Int) -> StepKind { 115 | StepKind(transition().rawValue.expandingShape(at: 0) 116 | .tiled(multiples: Tensor([Int32(batchSize)]))) 117 | } 118 | 119 | /// Returns a batched `StepKind` filled with "last" step kind values. 120 | @inlinable 121 | public static func last(batchSize: Int, withReset: Bool = true) -> StepKind { 122 | StepKind(last(withReset: withReset).rawValue.expandingShape(at: 0) 123 | .tiled(multiples: Tensor([Int32(batchSize)]))) 124 | } 125 | 126 | @inlinable 127 | public func isFirst() -> Tensor { 128 | rawValue .== 0 129 | } 130 | 131 | @inlinable 132 | public func isTransition() -> Tensor { 133 | rawValue .== 1 134 | } 135 | 136 | @inlinable 137 | public func isLast(withReset: Bool = false) -> Tensor { 138 | withReset ? rawValue .== 3 : (rawValue .== 2).elementsLogicalOr(rawValue .== 3) 139 | } 140 | 141 | /// Returns a tensor containing the number of completed episodes contained in the trajectory 142 | /// that this step kind corresponds to. 
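/// Episodes are counted by summing the "last" step markers, so an episode that has not yet
/// terminated by the end of the trajectory does not contribute to the count.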
143 | @inlinable 144 | public func episodeCount(withReset: Bool = false) -> Tensor { 145 | Tensor(isLast(withReset: withReset)).sum() 146 | } 147 | 148 | /// Returns a boolean tensor whose `false`-valued elements correspond to steps of episodes that 149 | /// did not complete by the end of the trajectory that this step kind corresponds to. 150 | @inlinable 151 | public func completeEpisodeMask(withReset: Bool = false) -> Tensor { 152 | Tensor(isLast(withReset: withReset)).cumulativeSum(alongAxis: 0, reverse: true) .> 0 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Environments/Metrics.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 14 | 15 | import TensorFlow 16 | 17 | public protocol Metric { 18 | associatedtype Observation 19 | associatedtype State 20 | associatedtype Action 21 | associatedtype Reward 22 | associatedtype Value 23 | 24 | mutating func update(using trajectory: Trajectory) 25 | mutating func reset() 26 | func value() -> Value 27 | } 28 | 29 | public struct AverageEpisodeLength: Metric { 30 | public typealias Observation = Environment.Observation 31 | public typealias Action = Environment.Action 32 | public typealias Reward = Environment.Reward 33 | 34 | public let withReset: Bool 35 | 36 | @usableFromInline internal var deque: Deque 37 | @usableFromInline internal var episodeSteps: Tensor 38 | 39 | @inlinable 40 | public init(for environment: Environment, bufferSize: Int, withReset: Bool = true) { 41 | self.withReset = withReset 42 | self.deque = Deque(size: bufferSize) 43 | self.episodeSteps = Tensor(repeating: 0, shape: [environment.batchSize]) 44 | } 45 | 46 | @inlinable 47 | public mutating func update(using trajectory: Trajectory) { 48 | let isLast = trajectory.stepKind.isLast(withReset: withReset) 49 | let isNotLast = 1 - Tensor(isLast) 50 | episodeSteps += isNotLast 51 | for length in episodeSteps.gathering(where: isLast).scalars { 52 | deque.push(Float(length)) 53 | } 54 | episodeSteps *= isNotLast 55 | } 56 | 57 | @inlinable 58 | public mutating func reset() { 59 | deque.reset() 60 | episodeSteps = Tensor(zerosLike: episodeSteps) 61 | } 62 | 63 | @inlinable 64 | public func value() -> Float { 65 | deque.mean() 66 | } 67 | } 68 | 69 | public struct AverageEpisodeReward: Metric 70 | where Environment.Reward == Tensor { 71 | public typealias Observation = Environment.Observation 72 | public typealias Action = Environment.Action 73 | public typealias Reward = Environment.Reward 74 | 75 | public let withReset: Bool 76 | 77 | @usableFromInline internal var deque: Deque 78 | @usableFromInline internal var episodeRewards: Tensor 79 | 80 | @inlinable 81 | public init(for environment: Environment, bufferSize: Int, withReset: Bool = true) { 82 | self.withReset = withReset 83 | 
self.deque = Deque(size: bufferSize) 84 | self.episodeRewards = Tensor(repeating: 0, shape: [environment.batchSize]) 85 | } 86 | 87 | @inlinable 88 | public mutating func update(using trajectory: Trajectory) { 89 | let isLast = trajectory.stepKind.isLast(withReset: withReset) 90 | episodeRewards += trajectory.reward 91 | for reward in episodeRewards.gathering(where: isLast).scalars { 92 | deque.push(reward) 93 | } 94 | episodeRewards *= (1 - Tensor(isLast)) 95 | } 96 | 97 | @inlinable 98 | public mutating func reset() { 99 | deque.reset() 100 | episodeRewards = Tensor(zerosLike: episodeRewards) 101 | } 102 | 103 | @inlinable 104 | public func value() -> Float { 105 | deque.mean() 106 | } 107 | } 108 | 109 | public struct TotalCumulativeReward: Metric 110 | where Environment.Reward == Tensor { 111 | public typealias Observation = Environment.Observation 112 | public typealias Action = Environment.Action 113 | public typealias Reward = Environment.Reward 114 | 115 | @usableFromInline internal var rewards: Tensor 116 | 117 | @inlinable 118 | public init(for environment: Environment) { 119 | self.rewards = Tensor(repeating: 0, shape: [environment.batchSize]) 120 | } 121 | 122 | @inlinable 123 | public mutating func update(using trajectory: Trajectory) { 124 | rewards += trajectory.reward 125 | } 126 | 127 | @inlinable 128 | public mutating func reset() { 129 | rewards = Tensor(zerosLike: rewards) 130 | } 131 | 132 | @inlinable 133 | public func value() -> [Float] { 134 | rewards.scalars 135 | } 136 | } 137 | 138 | @usableFromInline 139 | internal struct Deque { 140 | @usableFromInline internal let size: Int 141 | @usableFromInline internal var buffer: [Scalar] 142 | @usableFromInline internal var index: Int 143 | @usableFromInline internal var full: Bool 144 | 145 | @inlinable 146 | init(size: Int) { 147 | self.size = size 148 | self.buffer = [Scalar](repeating: 0, count: size) 149 | self.index = 0 150 | self.full = false 151 | } 152 | 153 | @inlinable 154 | mutating func push(_ value: Scalar) { 155 | buffer[index] = value 156 | index += 1 157 | full = full || index == buffer.count 158 | index = index % buffer.count 159 | } 160 | 161 | @inlinable 162 | mutating func reset() { 163 | index = 0 164 | full = false 165 | } 166 | 167 | @inlinable 168 | func mean() -> Scalar { 169 | let sum = full ? buffer.reduce(0, +) : buffer[0.. Data 33 | 34 | /// Returns a batch sampled uniformly at random from the recorded data. 35 | /// 36 | /// - Parameters: 37 | /// - batchSize: Batch size. 38 | /// - stepCount: Number of time steps to include for each batch element. If 39 | /// `stepCount == nil`, the returned batch consists of tensor groups where each 40 | /// tensor has shape `[batchSize, ...]`. Otherwise, each such tensor has shape 41 | /// `[batchSize, stepCount, ...]`. 42 | /// - Returns: Batch sampled uniformly at random from the recorded data. 43 | func sampleBatch(batchSize: Int, stepCount: Int?) -> ReplayBufferBatch 44 | 45 | /// Resets the contents of this buffer. 
46 | mutating func reset() 47 | } 48 | 49 | public struct ReplayBufferBatch { 50 | public let batch: Data 51 | public let ids: Tensor 52 | public let probabilities: Tensor 53 | } 54 | 55 | public struct UniformReplayBuffer: ReplayBuffer { 56 | public let batchSize: Int 57 | public let maxLength: Int 58 | public let capacity: Int 59 | 60 | internal var batchOffsets: Tensor { 61 | Tensor(rangeFrom: 0, to: Int64(batchSize), stride: 1) * Int64(maxLength) 62 | } 63 | 64 | internal let lastIDCounterDispatchQueue = DispatchQueue(label: "UniformReplayBuffer") 65 | internal var lastID: Int64 = -1 66 | 67 | internal var idsStorage: Tensor? = nil 68 | internal var dataStorage: Data? = nil 69 | 70 | public init(batchSize: Int, maxLength: Int) { 71 | self.batchSize = batchSize 72 | self.maxLength = maxLength 73 | self.capacity = batchSize * maxLength 74 | } 75 | 76 | /// Records the provided data batch to this replay buffer. 77 | public mutating func record(_ batch: Data) { 78 | if idsStorage == nil { 79 | idsStorage = Tensor(emptyLike: Tensor([1]), withCapacity: capacity) 80 | dataStorage = Data(emptyLike: batch, withCapacity: capacity) 81 | } 82 | let id: Int64 = lastIDCounterDispatchQueue.sync { 83 | lastID += 1 84 | return lastID 85 | } 86 | let indices = (batchOffsets + id % Int64(maxLength)).expandingShape(at: 1) 87 | idsStorage!.update( 88 | atIndices: indices, 89 | using: Tensor([Int64](repeating: id, count: batchSize))) 90 | dataStorage!.update(atIndices: indices, using: batch) 91 | } 92 | 93 | /// Returns all of the data recorded in this replay buffer. 94 | /// 95 | /// - Returns: Recorded data in the form of a tensor group where each tensor has shape 96 | /// `[maxSequenceLength, batchSize, ...]`. 97 | public func recordedData() -> Data { 98 | // Repeat `ids` over `batchSize` resulting in a tensor with shape 99 | // `[idsRange.1 - idsRange.0, batchSize, ...]`. 100 | let idsRange = validIDsRange() 101 | let ids = Tensor( 102 | rangeFrom: idsRange.0, 103 | to: idsRange.1, 104 | stride: 1 105 | ).expandingShape(at: 1).tiled(multiples: Tensor([1, Int32(batchSize)])) 106 | 107 | // Create the `batchOffsets` with shape `[1, batchSize]`, and then add them to `ids` to obtain 108 | // the row indices in the storage tensors. 109 | let batchOffsets = Tensor( 110 | rangeFrom: Tensor(0), 111 | to: Tensor(Int64(batchSize)), 112 | stride: Tensor(1) 113 | ).expandingShape(at: 0) * Int64(maxLength) 114 | let indices = ids % Int64(maxLength) + batchOffsets 115 | 116 | return dataStorage!.gathering(atIndices: indices) 117 | } 118 | 119 | /// Returns a batch sampled uniformly at random from the recorded data. 120 | /// 121 | /// - Parameters: 122 | /// - batchSize: Batch size. 123 | /// - stepCount: Number of time steps to include for each batch element. If 124 | /// `stepCount == nil`, the returned batch consists of tensor groups where each 125 | /// tensor has shape `[batchSize, ...]`. Otherwise, each such tensor has shape 126 | /// `[stepCount, batchSize, ...]`. 127 | /// - Returns: Batch sampled uniformly at random from the recorded data. 128 | public func sampleBatch(batchSize: Int, stepCount: Int? = nil) -> ReplayBufferBatch { 129 | let idsRange = validIDsRange(stepCount: Int64(stepCount ?? 1)) 130 | let idsCount = idsRange.1 - idsRange.0 131 | precondition(idsCount > 0 && idsStorage != nil && dataStorage != nil, "Empty buffer.") 132 | 133 | // Sample random IDs across multiple random batches. 
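// Two independent uniform draws are combined here: `batchOffsets` selects which of the
// buffer's recorded environment rows each sampled element is read from, and the second draw
// picks a starting ID within the valid range. Their sum (taken modulo the buffer capacity,
// and tiled across `stepCount` consecutive steps when requested) gives row indices into the
// flat storage tensors.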
134 | let batchOffsets = Tensor( 135 | randomUniform: [batchSize], 136 | lowerBound: Tensor(0), 137 | upperBound: Tensor(Int64(self.batchSize)) 138 | ) * Int64(maxLength) 139 | var indices = Tensor( 140 | randomUniform: [batchSize], 141 | lowerBound: Tensor(idsRange.0), 142 | upperBound: Tensor(idsRange.1) 143 | ) + batchOffsets 144 | 145 | if let stepCount = stepCount { 146 | indices = indices.expandingShape(at: 0) 147 | .tiled(multiples: Tensor([Int32(stepCount), 1])) 148 | let stepRange = Tensor(rangeFrom: 0, to: Int64(stepCount), stride: 1) 149 | .reshaped(to: [stepCount, 1]) 150 | .tiled(multiples: Tensor([1, Int32(batchSize)])) 151 | indices = (stepRange + indices) % Int64(capacity) 152 | } else { 153 | indices = indices % Int64(capacity) 154 | } 155 | 156 | let ids = idsStorage!.gathering(atIndices: indices) 157 | let batch = dataStorage!.gathering(atIndices: indices) 158 | let probabilities = Tensor( 159 | repeating: 1 / Float(idsCount * Int64(self.batchSize)), 160 | shape: [batchSize]) 161 | 162 | return ReplayBufferBatch(batch: batch, ids: ids, probabilities: probabilities) 163 | } 164 | 165 | /// Resets the contents of this buffer. 166 | public mutating func reset() { 167 | lastIDCounterDispatchQueue.sync { 168 | lastID = -1 169 | idsStorage = nil 170 | dataStorage = nil 171 | } 172 | } 173 | 174 | /// Returns the range of valid IDs. 175 | /// 176 | /// - Parameter stepCount: Optional way to specify how many IDs need to be valid. 177 | /// - Returns: Tuple representing the range, where the first element is the inclusive lower bound 178 | /// and the second element is the exclusive upper bound. 179 | /// - Note: When `stepCount` is provided, the upper bound of the range can be increased by up to 180 | /// `stepCount`. 181 | internal func validIDsRange(stepCount: Int64 = 1) -> (Int64, Int64) { 182 | let lastID = lastIDCounterDispatchQueue.sync { self.lastID } 183 | let minIDNotFull = Int64(0) 184 | let maxIDNotFull = max(lastID + 1 - stepCount + 1, 0) 185 | let minIDFull = lastID + 1 - Int64(maxLength) 186 | let maxIDFull = lastID + 1 - stepCount + 1 187 | return lastID < maxLength ? (minIDNotFull, maxIDNotFull) : (minIDFull, maxIDFull) 188 | } 189 | } 190 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Spaces.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 14 | 15 | import TensorFlow 16 | 17 | public protocol Space: CustomStringConvertible { 18 | associatedtype Value 19 | associatedtype ValueDistribution: Distribution where ValueDistribution.Value == Value 20 | 21 | var distribution: ValueDistribution { get } 22 | 23 | /// Returns a boolean specifying if `value` is a valid member of this space. 
24 | func contains(_ value: Value) -> Bool 25 | } 26 | 27 | public extension Space { 28 | /// Sample a random element from this space. 29 | @inlinable 30 | func sample() -> Value { 31 | distribution.sample() 32 | } 33 | } 34 | 35 | public struct Discrete: Space { 36 | public let size: Int 37 | public let distribution: Categorical 38 | 39 | @inlinable 40 | public init(withSize size: Int, batchSize: Int) { 41 | self.size = size 42 | self.distribution = Categorical(logits: Tensor(ones: [batchSize, size])) 43 | } 44 | 45 | @inlinable 46 | public var description: String { 47 | "Discrete(\(size))" 48 | } 49 | 50 | @inlinable 51 | public func contains(_ value: Tensor) -> Bool { 52 | let v = value.scalarized() 53 | return value.rank < 2 && v >= 0 && v < Int32(size) 54 | } 55 | } 56 | 57 | public struct MultiBinary: Space { 58 | public let size: Int 59 | public let shape: TensorShape 60 | public let distribution: Bernoulli 61 | 62 | @inlinable 63 | public init(withSize size: Int, batchSize: Int) { 64 | self.size = size 65 | self.shape = [size] 66 | self.distribution = Bernoulli(logits: Tensor(ones: [batchSize, size])) 67 | } 68 | 69 | @inlinable 70 | public var description: String { 71 | "MultiBinary(\(size))" 72 | } 73 | 74 | @inlinable 75 | public func contains(_ value: Tensor) -> Bool { 76 | value.shape == shape && value.scalars.allSatisfy { $0 == 0 || $0 == 1 } 77 | } 78 | } 79 | 80 | public struct MultiDiscrete: Space { 81 | public let sizes: [Int] 82 | public let shape: TensorShape 83 | public let distribution: ValueDistribution 84 | 85 | @inlinable 86 | public init(withSizes sizes: [Int], batchSize: Int) { 87 | self.sizes = sizes 88 | self.shape = [sizes.count] 89 | self.distribution = ValueDistribution(sizes: sizes, batchSize: batchSize) 90 | } 91 | 92 | @inlinable 93 | public var description: String { 94 | "MultiDiscrete(\(sizes.map{String($0)}.joined(separator: ", ")))" 95 | } 96 | 97 | @inlinable 98 | public func contains(_ value: Tensor) -> Bool { 99 | let scalars = value.scalars 100 | return scalars.allSatisfy { $0 >= 0 } && zip(scalars, sizes).allSatisfy { $0 < $1 } 101 | } 102 | 103 | public struct ValueDistribution: DifferentiableDistribution { 104 | @noDerivative @usableFromInline internal let sizes: [Int] 105 | @usableFromInline internal var distributions: [Categorical] 106 | 107 | @inlinable 108 | public init(sizes: [Int], batchSize: Int) { 109 | self.sizes = sizes 110 | self.distributions = sizes.map { 111 | Categorical(logits: Tensor(ones: [batchSize, $0])) 112 | } 113 | } 114 | 115 | @inlinable 116 | @differentiable(wrt: self) 117 | public func logProbability(of value: Tensor) -> Tensor { 118 | let values = value.unstacked() 119 | var logProbability = Tensor(0.0) 120 | for i in 0.. Tensor { 129 | var entropy = Tensor(0.0) 130 | for i in 0.. 
Tensor { 138 | Tensor(concatenating: distributions.map { $0.mode() }) 139 | } 140 | 141 | @inlinable 142 | public func sample() -> Tensor { 143 | Tensor(concatenating: distributions.map { $0.sample() }) 144 | } 145 | } 146 | } 147 | 148 | public struct DiscreteBox: Space { 149 | public let shape: TensorShape 150 | public let lowerBound: Tensor 151 | public let upperBound: Tensor 152 | 153 | public let distribution: ValueDistribution 154 | 155 | @inlinable 156 | public init(shape: TensorShape, lowerBound: Scalar, upperBound: Scalar) { 157 | self.shape = shape 158 | self.lowerBound = Tensor(lowerBound) 159 | self.upperBound = Tensor(upperBound) 160 | self.distribution = ValueDistribution( 161 | distribution: Uniform( 162 | shape: Tensor(shape.dimensions.map(Int32.init)), 163 | lowerBound: Tensor(self.lowerBound), 164 | upperBound: Tensor(self.upperBound))) 165 | } 166 | 167 | @inlinable 168 | public init(lowerBound: Tensor, upperBound: Tensor) { 169 | precondition(lowerBound.shape == upperBound.shape, 170 | "'lowerBound' and 'upperBound' must have the same shape.") 171 | self.shape = lowerBound.shape 172 | self.lowerBound = lowerBound 173 | self.upperBound = upperBound 174 | self.distribution = ValueDistribution( 175 | distribution: Uniform( 176 | shape: Tensor(shape.dimensions.map(Int32.init)), 177 | lowerBound: Tensor(self.lowerBound), 178 | upperBound: Tensor(self.upperBound))) 179 | } 180 | 181 | @inlinable 182 | public var description: String { 183 | "DiscreteBox(\(shape.dimensions.map{String($0)}.joined(separator: ", ")))" 184 | } 185 | 186 | @inlinable 187 | public func contains(_ value: Tensor) -> Bool { 188 | let scalars = value.scalars 189 | return value.shape == shape && 190 | zip(scalars, lowerBound.scalars).allSatisfy{$0 >= $1} && 191 | zip(scalars, upperBound.scalars).allSatisfy{$0 <= $1} 192 | } 193 | 194 | public struct ValueDistribution: DifferentiableDistribution { 195 | @usableFromInline internal var distribution: Uniform 196 | 197 | @inlinable 198 | public init(distribution: Uniform) { 199 | self.distribution = distribution 200 | } 201 | 202 | @inlinable 203 | @differentiable(wrt: self) 204 | public func logProbability(of value: Tensor) -> Tensor { 205 | distribution.logProbability(of: Tensor(value)) 206 | } 207 | 208 | @inlinable 209 | @differentiable(wrt: self) 210 | public func entropy() -> Tensor { 211 | distribution.entropy() 212 | } 213 | 214 | @inlinable 215 | public func mode() -> Tensor { 216 | Tensor(distribution.mode()) 217 | } 218 | 219 | @inlinable 220 | public func sample() -> Tensor { 221 | Tensor(distribution.sample()) 222 | } 223 | } 224 | } 225 | 226 | public struct Box: Space { 227 | public let shape: TensorShape 228 | public let lowerBound: Tensor 229 | public let upperBound: Tensor 230 | 231 | public let distribution: Uniform 232 | 233 | @inlinable 234 | public init(shape: TensorShape, lowerBound: Scalar, upperBound: Scalar) { 235 | self.shape = shape 236 | self.lowerBound = Tensor(lowerBound) 237 | self.upperBound = Tensor(upperBound) 238 | self.distribution = Uniform( 239 | shape: Tensor(shape.dimensions.map(Int32.init)), 240 | lowerBound: self.lowerBound, 241 | upperBound: self.upperBound) 242 | } 243 | 244 | @inlinable 245 | public init(lowerBound: Tensor, upperBound: Tensor) { 246 | precondition(lowerBound.shape == upperBound.shape, 247 | "'lowerBound' and 'upperBound' must have the same shape.") 248 | self.shape = lowerBound.shape 249 | self.lowerBound = lowerBound 250 | self.upperBound = upperBound 251 | self.distribution = Uniform( 252 | 
shape: Tensor(shape.dimensions.map(Int32.init)), 253 | lowerBound: self.lowerBound, 254 | upperBound: self.upperBound) 255 | } 256 | 257 | @inlinable 258 | public var description: String { 259 | "Box(\(shape.dimensions.map{String($0)}.joined(separator: ", ")))" 260 | } 261 | 262 | @inlinable 263 | public func contains(_ value: Tensor) -> Bool { 264 | let scalars = value.scalars 265 | return value.shape == shape && 266 | zip(scalars, lowerBound.scalars).allSatisfy{$0 >= $1} && 267 | zip(scalars, upperBound.scalars).allSatisfy{$0 <= $1} 268 | } 269 | } 270 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Utilities/General.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 14 | 15 | import Foundation 16 | import TensorFlow 17 | 18 | #if os(Linux) 19 | import FoundationNetworking 20 | #endif 21 | 22 | public typealias TensorFlowSeed = (graph: Int32, op: Int32) 23 | 24 | public enum ReinforcementLearningError: Error { 25 | case renderingError(String) 26 | } 27 | 28 | public struct Empty: Differentiable, KeyPathIterable { 29 | public init() {} 30 | } 31 | 32 | public protocol Copyable { 33 | func copy() -> Self 34 | } 35 | 36 | public extension Encodable { 37 | func json(pretty: Bool = true) throws -> String { 38 | let encoder = JSONEncoder() 39 | if pretty { 40 | encoder.outputFormatting = .prettyPrinted 41 | } 42 | let data = try encoder.encode(self) 43 | return String(data: data, encoding: .utf8)! 44 | } 45 | } 46 | 47 | public extension Decodable { 48 | init(fromJson json: String) throws { 49 | let jsonDecoder = JSONDecoder() 50 | self = try jsonDecoder.decode(Self.self, from: json.data(using: .utf8)!) 51 | } 52 | } 53 | 54 | /// Downloads the file at `url` to `path`, if `path` does not exist. 55 | /// 56 | /// - Parameters: 57 | /// - from: URL to download data from. 58 | /// - to: Destination file path. 59 | /// 60 | /// - Returns: Boolean value indicating whether a download was 61 | /// performed (as opposed to not needed). 62 | public func maybeDownload(from url: URL, to destination: URL) throws { 63 | if !FileManager.default.fileExists(atPath: destination.path) { 64 | // Create any potentially missing directories. 65 | try FileManager.default.createDirectory( 66 | atPath: destination.deletingLastPathComponent().path, 67 | withIntermediateDirectories: true) 68 | 69 | // Create the URL session that will be used to download the dataset. 70 | let semaphore = DispatchSemaphore(value: 0) 71 | let delegate = DataDownloadDelegate(destinationFileUrl: destination, semaphore: semaphore) 72 | let session = URLSession(configuration: .default, delegate: delegate, delegateQueue: nil) 73 | 74 | // Download the data to a temporary file and then copy that file to 75 | // the destination path. 
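    // (The temporary file is moved to `destination` by `DataDownloadDelegate` once the
    // download task completes; the `semaphore.wait()` call below blocks until that move
    // has happened, so this function only returns after the file is in place or an error
    // has been printed.)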
76 | print("Downloading \(url).") 77 | let task = session.downloadTask(with: url) 78 | task.resume() 79 | 80 | // Wait for the download to finish. 81 | semaphore.wait() 82 | } 83 | } 84 | 85 | internal class DataDownloadDelegate: NSObject, URLSessionDownloadDelegate { 86 | let destinationFileUrl: URL 87 | let semaphore: DispatchSemaphore 88 | let numBytesFrequency: Int64 89 | 90 | internal var logCount: Int64 = 0 91 | 92 | init( 93 | destinationFileUrl: URL, 94 | semaphore: DispatchSemaphore, 95 | numBytesFrequency: Int64 = 1024 * 1024 96 | ) { 97 | self.destinationFileUrl = destinationFileUrl 98 | self.semaphore = semaphore 99 | self.numBytesFrequency = numBytesFrequency 100 | } 101 | 102 | internal func urlSession( 103 | _ session: URLSession, 104 | downloadTask: URLSessionDownloadTask, 105 | didWriteData bytesWritten: Int64, 106 | totalBytesWritten: Int64, 107 | totalBytesExpectedToWrite: Int64 108 | ) -> Void { 109 | if (totalBytesWritten / numBytesFrequency > logCount) { 110 | let mBytesWritten = String(format: "%.2f", Float(totalBytesWritten) / (1024 * 1024)) 111 | if totalBytesExpectedToWrite > 0 { 112 | let mBytesExpectedToWrite = String( 113 | format: "%.2f", Float(totalBytesExpectedToWrite) / (1024 * 1024)) 114 | print("Downloaded \(mBytesWritten) MBs out of \(mBytesExpectedToWrite).") 115 | } else { 116 | print("Downloaded \(mBytesWritten) MBs.") 117 | } 118 | logCount += 1 119 | } 120 | } 121 | 122 | internal func urlSession( 123 | _ session: URLSession, 124 | downloadTask: URLSessionDownloadTask, 125 | didFinishDownloadingTo location: URL 126 | ) -> Void { 127 | logCount = 0 128 | do { 129 | try FileManager.default.moveItem(at: location, to: destinationFileUrl) 130 | } catch (let writeError) { 131 | print("Error writing file \(location.path) : \(writeError)") 132 | } 133 | print("The file was downloaded successfully to \(location.path).") 134 | semaphore.signal() 135 | } 136 | } 137 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Utilities/LearningRates.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 14 | 15 | /// Learning rate schedule that takes the current training step as input and returns a learning 16 | /// rate to be used for training. 17 | public protocol LearningRate { 18 | associatedtype Scalar: FloatingPoint 19 | 20 | /// Returns the learning rate value for the specified training step. 21 | /// 22 | /// - Parameter step: Training step. 23 | func callAsFunction(forStep step: UInt64) -> Scalar 24 | } 25 | 26 | /// Dummy learning rate schedule that represents no schedule being used. This is useful as a 27 | /// default value whenever a learning rate schedule argument is used. 
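///
/// For example (a minimal usage sketch):
/// ```swift
/// let learningRate = FixedLearningRate<Float>(1e-3)
/// learningRate(forStep: 0)       // 0.001
/// learningRate(forStep: 10_000)  // 0.001 (the step is ignored).
/// ```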
28 | public struct FixedLearningRate: LearningRate { 29 | public let value: Scalar 30 | 31 | @inlinable 32 | public init(_ value: Scalar) { 33 | self.value = value 34 | } 35 | 36 | @inlinable 37 | public func callAsFunction(forStep step: UInt64) -> Scalar { 38 | value 39 | } 40 | } 41 | 42 | extension FixedLearningRate: ExpressibleByFloatLiteral 43 | where Scalar: _ExpressibleByBuiltinFloatLiteral { 44 | public typealias FloatLiteralType = Scalar 45 | 46 | public init(floatLiteral value: Scalar) { 47 | self.init(value) 48 | } 49 | } 50 | 51 | /// Linearly decayed learning rate. 52 | /// 53 | /// The decayed learning rate is computed as follows: 54 | /// ```swift 55 | /// let initial = baseLearningRate(forStep: step) 56 | /// let decayed = initial + step * slope 57 | /// let decayedLearningRate = max(lowerBound * initial, decayed) 58 | /// ``` 59 | public struct LinearlyDecayedLearningRate: LearningRate { 60 | public typealias Scalar = BaseLearningRate.Scalar 61 | 62 | public let baseLearningRate: BaseLearningRate 63 | public let slope: Scalar 64 | public let lowerBound: Scalar 65 | public let startStep: UInt64 66 | 67 | /// Creates a new linearly decayed learning rate. 68 | /// 69 | /// - Parameters: 70 | /// - baseLearningRate: Learning rate to decay. 71 | /// - slope: Slope of the linear decay. 72 | /// - lowerBound: Minimum decayed learning rate value as a fraction of the original learning 73 | /// rate value. 74 | /// - startStep: Step after which to start decaying the learning rate. 75 | @inlinable 76 | public init( 77 | baseLearningRate: BaseLearningRate, 78 | slope: Scalar, 79 | lowerBound: Scalar = Scalar(0), 80 | startStep: UInt64 = 0 81 | ) { 82 | self.baseLearningRate = baseLearningRate 83 | self.slope = slope 84 | self.lowerBound = lowerBound 85 | self.startStep = startStep 86 | } 87 | 88 | @inlinable 89 | public func callAsFunction(forStep step: UInt64) -> Scalar { 90 | let learningRate = baseLearningRate(forStep: step) 91 | if step < startStep { return learningRate } 92 | let step = step - startStep 93 | let decayed = learningRate + Scalar(step) * slope 94 | return max(lowerBound * learningRate, decayed) 95 | } 96 | } 97 | 98 | /// Exponentially decayed learning rate. 99 | /// 100 | /// The decayed learning rate is computed as follows: 101 | /// ```swift 102 | /// let initial = baseLearningRate(forStep: step) 103 | /// let decay = decayRate ^ (step / decayStepCount) 104 | /// let decayedLearningRate = initial * ((1 - lowerBound) * decay + lowerBound) 105 | /// ``` 106 | /// where if `staircase = true`, then `step / decayStepCount` uses integer division and the decayed 107 | /// learning rate follows a staircase function. 108 | public struct ExponentiallyDecayedLearningRate: LearningRate 109 | where BaseLearningRate.Scalar: ElementaryFunctions { 110 | public typealias Scalar = BaseLearningRate.Scalar 111 | 112 | public let baseLearningRate: BaseLearningRate 113 | public let decayRate: Scalar 114 | public let decayStepCount: UInt64 115 | public let staircase: Bool 116 | public let lowerBound: Scalar 117 | public let startStep: UInt64 118 | 119 | /// Creates a new exponentially decayed learning rate. 120 | /// 121 | /// - Parameters: 122 | /// - baseLearningRate: Learning rate to decay. 123 | /// - decayRate: Decay rate. 124 | /// - decayStepCount: Decay step count. 125 | /// - staircase: If `true`, the decay will occur at discrete intervals. 126 | /// - lowerBound: Minimum decayed learning rate value as a fraction of the original learning 127 | /// rate value. 
128 | /// - startStep: Step after which to start decaying the learning rate. 129 | @inlinable 130 | public init( 131 | baseLearningRate: BaseLearningRate, 132 | decayRate: Scalar, 133 | decayStepCount: UInt64, 134 | staircase: Bool = false, 135 | lowerBound: Scalar = Scalar(0), 136 | startStep: UInt64 = 0 137 | ) { 138 | self.baseLearningRate = baseLearningRate 139 | self.decayRate = decayRate 140 | self.decayStepCount = decayStepCount 141 | self.staircase = staircase 142 | self.lowerBound = lowerBound 143 | self.startStep = startStep 144 | } 145 | 146 | @inlinable 147 | public func callAsFunction(forStep step: UInt64) -> Scalar { 148 | let learningRate = baseLearningRate(forStep: step) 149 | if step < startStep { return learningRate } 150 | let step = step - startStep 151 | let power = Scalar(step) / Scalar(decayStepCount) 152 | let decay = Scalar.pow(decayRate, staircase ? power.rounded(.down) : power) 153 | return learningRate * ((1 - lowerBound) * decay + lowerBound) 154 | } 155 | } 156 | 157 | /// Reciprocal square root decayed learning rate. 158 | /// 159 | /// The decayed learning rate is computed as follows: 160 | /// ```swift 161 | /// let initial = baseLearningRate(forStep: step) 162 | /// let decay = decayFactor / sqrt(max(step, decayThreshold)) 163 | /// let decayedLearningRate = initial * ((1 - lowerBound) * decay + lowerBound) 164 | /// ``` 165 | public struct RSqrtLearningRateDecay: LearningRate 166 | where BaseLearningRate.Scalar: ElementaryFunctions { 167 | public typealias Scalar = BaseLearningRate.Scalar 168 | 169 | public let baseLearningRate: BaseLearningRate 170 | public let decayFactor: Scalar 171 | public let decayThreshold: Scalar 172 | public let lowerBound: Scalar 173 | public let startStep: UInt64 174 | 175 | /// Creates a new reciprocal square root decayed learning rate. 176 | /// 177 | /// - Parameters: 178 | /// - baseLearningRate: Learning rate to decay. 179 | /// - decayFactor: Decay factor. 180 | /// - decayThreshold: Decay threshold. 181 | /// - lowerBound: Minimum decayed learning rate value as a fraction of the original learning 182 | /// rate value. 183 | /// - startStep: Step after which to start decaying the learning rate. 184 | @inlinable 185 | public init( 186 | baseLearningRate: BaseLearningRate, 187 | decayFactor: Scalar, 188 | decayThreshold: Scalar, 189 | lowerBound: Scalar = Scalar(0), 190 | startStep: UInt64 = 0 191 | ) { 192 | self.baseLearningRate = baseLearningRate 193 | self.decayFactor = decayFactor 194 | self.decayThreshold = decayThreshold 195 | self.lowerBound = lowerBound 196 | self.startStep = startStep 197 | } 198 | 199 | @inlinable 200 | public func callAsFunction(forStep step: UInt64) -> Scalar { 201 | let learningRate = baseLearningRate(forStep: step) 202 | if step < startStep { return learningRate } 203 | let step = step - startStep 204 | let decay = decayFactor / Scalar.sqrt(max(Scalar(step), decayThreshold)) 205 | return learningRate * ((1 - lowerBound) * decay + lowerBound) 206 | } 207 | } 208 | 209 | /// Cosine decayed learning rate. 
210 | ///
211 | /// The decayed learning rate is computed as follows:
212 | /// ```swift
213 | /// let initial = baseLearningRate(forStep: step)
214 | /// let decay = 0.5 * (1 + cos(pi * min(step, cycleStepCount) / cycleStepCount))
215 | /// let decayedLearningRate = initial * ((1 - lowerBound) * decay + lowerBound)
216 | /// ```
217 | public struct CosineDecayedLearningRate<BaseLearningRate: LearningRate>: LearningRate
218 | where BaseLearningRate.Scalar: ElementaryFunctions {
219 |   public typealias Scalar = BaseLearningRate.Scalar
220 | 
221 |   public let baseLearningRate: BaseLearningRate
222 |   public let cycleStepCount: UInt64
223 |   public let lowerBound: Scalar
224 |   public let startStep: UInt64
225 | 
226 |   /// Creates a new cosine decayed learning rate.
227 |   ///
228 |   /// - Parameters:
229 |   ///   - baseLearningRate: Learning rate to decay.
230 |   ///   - cycleStepCount: Cosine decay cycle in terms of number of steps.
231 |   ///   - lowerBound: Minimum decayed learning rate value as a fraction of the original learning
232 |   ///     rate value.
233 |   ///   - startStep: Step after which to start decaying the learning rate.
234 |   @inlinable
235 |   public init(
236 |     baseLearningRate: BaseLearningRate,
237 |     cycleStepCount: UInt64,
238 |     lowerBound: Scalar = Scalar(0),
239 |     startStep: UInt64 = 0
240 |   ) {
241 |     self.baseLearningRate = baseLearningRate
242 |     self.cycleStepCount = cycleStepCount
243 |     self.lowerBound = lowerBound
244 |     self.startStep = startStep
245 |   }
246 | 
247 |   @inlinable
248 |   public func callAsFunction(forStep step: UInt64) -> Scalar {
249 |     let learningRate = baseLearningRate(forStep: step)
250 |     if step < startStep { return learningRate }
251 |     let step = step - startStep
252 |     let cosine = Scalar.cos(Scalar.pi * Scalar(min(step, cycleStepCount)) / Scalar(cycleStepCount))
253 |     let decay = (1 + cosine) / 2
254 |     return learningRate * ((1 - lowerBound) * decay + lowerBound)
255 |   }
256 | }
257 | 
258 | /// Cycle-linear 10x decayed learning rate.
259 | ///
260 | /// The decayed learning rate is computed as follows:
261 | /// ```swift
262 | /// let initial = baseLearningRate(forStep: step)
263 | /// let cyclePosition = 1 - abs((step % (2 * cycleStepCount) - cycleStepCount) / cycleStepCount)
264 | /// let decay = (0.1 + cyclePosition) * 3
265 | /// let decayedLearningRate = initial * ((1 - lowerBound) * decay + lowerBound)
266 | /// ```
267 | public struct CycleLinear10xLearningRateDecay<BaseLearningRate: LearningRate>: LearningRate {
268 |   public typealias Scalar = BaseLearningRate.Scalar
269 | 
270 |   public let baseLearningRate: BaseLearningRate
271 |   public let cycleStepCount: UInt64
272 |   public let lowerBound: Scalar
273 |   public let startStep: UInt64
274 | 
275 |   /// Creates a new cycle-linear 10x decayed learning rate.
276 |   ///
277 |   /// - Parameters:
278 |   ///   - baseLearningRate: Learning rate to decay.
279 |   ///   - cycleStepCount: Cycle-linear 10x decay cycle in terms of number of steps.
280 |   ///   - lowerBound: Minimum decayed learning rate value as a fraction of the original learning
281 |   ///     rate value.
282 |   ///   - startStep: Step after which to start decaying the learning rate.
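  ///
  /// For example (a sketch composing the schedules defined in this file):
  /// ```swift
  /// let schedule = CycleLinear10xLearningRateDecay(
  ///   baseLearningRate: FixedLearningRate<Float>(0.01),
  ///   cycleStepCount: 1_000)
  /// schedule(forStep: 500)  // 0.01 * (0.1 + 0.5) * 3 = 0.018, per the decay formula above.
  /// ```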
283 |   @inlinable
284 |   public init(
285 |     baseLearningRate: BaseLearningRate,
286 |     cycleStepCount: UInt64,
287 |     lowerBound: Scalar = Scalar(0),
288 |     startStep: UInt64 = 0
289 |   ) {
290 |     self.baseLearningRate = baseLearningRate
291 |     self.cycleStepCount = cycleStepCount
292 |     self.lowerBound = lowerBound
293 |     self.startStep = startStep
294 |   }
295 | 
296 |   @inlinable
297 |   public func callAsFunction(forStep step: UInt64) -> Scalar {
298 |     let learningRate = baseLearningRate(forStep: step)
299 |     if step < startStep { return learningRate }
300 |     let step = step - startStep
301 |     let ratio = (Scalar(step % (2 * cycleStepCount)) - Scalar(cycleStepCount)) / Scalar(cycleStepCount)
302 |     let cyclePosition = 1 - abs(ratio)
303 |     let decay = (1 / Scalar(10) + cyclePosition) * 3 // 10x difference in each cycle (0.3 - 3).
304 |     return learningRate * ((1 - lowerBound) * decay + lowerBound)
305 |   }
306 | }
307 | 
308 | /// Linearly warmed-up learning rate.
309 | ///
310 | /// For the first `warmUpStepCount` steps the base learning rate is multiplied with:
311 | /// ```
312 | /// warmUpOffset + ((1 - warmUpOffset) / warmUpStepCount) * step
313 | /// ```
314 | ///
315 | /// - Source: [Attention is All You Need (Section 5.3)](https://arxiv.org/pdf/1706.03762.pdf).
316 | public struct LinearlyWarmedUpLearningRate<BaseLearningRate: LearningRate>: LearningRate {
317 |   public typealias Scalar = BaseLearningRate.Scalar
318 | 
319 |   public let baseLearningRate: BaseLearningRate
320 |   public let warmUpStepCount: UInt64
321 |   public let warmUpOffset: Scalar
322 | 
323 |   /// Creates a new linear learning rate warm-up schedule.
324 |   ///
325 |   /// - Parameters:
326 |   ///   - baseLearningRate: Learning rate to warm-up.
327 |   ///   - warmUpStepCount: Number of warm-up steps.
328 |   ///   - warmUpOffset: Linear schedule offset.
329 |   @inlinable
330 |   public init(
331 |     baseLearningRate: BaseLearningRate,
332 |     warmUpStepCount: UInt64,
333 |     warmUpOffset: Scalar
334 |   ) {
335 |     self.baseLearningRate = baseLearningRate
336 |     self.warmUpStepCount = warmUpStepCount
337 |     self.warmUpOffset = warmUpOffset
338 |   }
339 | 
340 |   @inlinable
341 |   public func callAsFunction(forStep step: UInt64) -> Scalar {
342 |     let learningRate = baseLearningRate(forStep: step)
343 |     if step >= warmUpStepCount { return learningRate }
344 |     let factor = warmUpOffset + ((1 - warmUpOffset) / Scalar(warmUpStepCount)) * Scalar(step)
345 |     return learningRate * factor
346 |   }
347 | }
348 | 
349 | /// Exponentially warmed-up learning rate.
350 | ///
351 | /// For the first `warmUpStepCount` steps the base learning rate is multiplied with:
352 | /// ```
353 | /// exp(log(warmUpFactor) / warmUpStepCount) ^ (warmUpStepCount - step)
354 | /// ```
355 | ///
356 | /// - Source: [Attention is All You Need (Section 5.3)](https://arxiv.org/pdf/1706.03762.pdf).
357 | public struct ExponentialLearningRateWarmUp<BaseLearningRate: LearningRate>: LearningRate
358 | where BaseLearningRate.Scalar: ElementaryFunctions {
359 |   public typealias Scalar = BaseLearningRate.Scalar
360 | 
361 |   public let baseLearningRate: BaseLearningRate
362 |   public let warmUpStepCount: UInt64
363 |   public let warmUpFactor: Scalar
364 | 
365 |   /// Creates a new exponential learning rate warm-up schedule.
366 |   ///
367 |   /// - Parameters:
368 |   ///   - baseLearningRate: Learning rate to warm-up.
369 |   ///   - warmUpStepCount: Number of warm-up steps.
370 |   ///   - warmUpFactor: Warm-up learning rate scaling factor.
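  ///
  /// For example (a sketch): warm up a fixed rate over the first 4,000 steps, starting at one
  /// hundredth of its final value:
  /// ```swift
  /// let schedule = ExponentialLearningRateWarmUp(
  ///   baseLearningRate: FixedLearningRate<Float>(1e-3),
  ///   warmUpStepCount: 4_000,
  ///   warmUpFactor: 0.01)
  /// schedule(forStep: 0)      // 0.001 * 0.01 = 0.00001.
  /// schedule(forStep: 4_000)  // 0.001 (warm-up finished).
  /// ```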
371 | @inlinable 372 | public init( 373 | baseLearningRate: BaseLearningRate, 374 | warmUpStepCount: UInt64, 375 | warmUpFactor: Scalar 376 | ) { 377 | self.baseLearningRate = baseLearningRate 378 | self.warmUpStepCount = warmUpStepCount 379 | self.warmUpFactor = warmUpFactor 380 | } 381 | 382 | @inlinable 383 | public func callAsFunction(forStep step: UInt64) -> Scalar { 384 | let learningRate = baseLearningRate(forStep: step) 385 | if step >= warmUpStepCount { return learningRate } 386 | let base = Scalar.exp(Scalar.log(warmUpFactor) / Scalar(warmUpStepCount)) 387 | let factor = Scalar.pow(base, Scalar(warmUpStepCount - step)) 388 | return learningRate * factor 389 | } 390 | } 391 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Utilities/Normalization.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 14 | 15 | import TensorFlow 16 | 17 | public protocol Normalizer { 18 | associatedtype Value 19 | 20 | func normalize(_ value: Value) -> Value 21 | mutating func update(using value: Value) 22 | mutating func reset() 23 | } 24 | 25 | public struct TensorNormalizer: Normalizer { 26 | public let axes: Tensor 27 | public let streaming: Bool 28 | 29 | private var count: Tensor? 30 | private var valueSum: Tensor? 31 | private var valueSquaredSum: Tensor? 32 | 33 | public init(streaming: Bool, alongAxes axes: Tensor) { 34 | self.axes = axes 35 | self.streaming = streaming 36 | self.count = streaming ? Tensor(Scalar(Float.ulpOfOne)) : nil 37 | self.valueSum = streaming ? Tensor(zeros: []) : nil 38 | self.valueSquaredSum = streaming ? Tensor(zeros: []) : nil 39 | } 40 | 41 | public init(streaming: Bool, alongAxes axes: [Int]) { 42 | self.init(streaming: streaming, alongAxes: Tensor(axes.map(Int32.init))) 43 | } 44 | 45 | public init(streaming: Bool, alongAxes axes: Int...) { 46 | self.init(streaming: streaming, alongAxes: Tensor(axes.map(Int32.init))) 47 | } 48 | 49 | public func normalize(_ value: Tensor) -> Tensor { 50 | if streaming { 51 | let mean = valueSum! / count! 52 | let variance = (valueSquaredSum! - valueSum!.squared() / count!) / count! 53 | return (value - mean) / (sqrt(variance) + Scalar(Float.ulpOfOne)) 54 | } 55 | let moments = value.moments(alongAxes: axes) 56 | return (value - moments.mean) / (sqrt(moments.variance) + Scalar(Float.ulpOfOne)) 57 | } 58 | 59 | public mutating func update(using value: Tensor) { 60 | if streaming { 61 | count = count! + Tensor(value.shapeTensor.gathering(atIndices: axes).product()) 62 | valueSum = valueSum! + value.sum(alongAxes: axes) 63 | valueSquaredSum = valueSquaredSum! 
+ value.squared().sum(alongAxes: axes) 64 | } 65 | } 66 | 67 | public mutating func reset() { 68 | if streaming { 69 | count = Tensor(Scalar(Float.ulpOfOne)) 70 | valueSum = Tensor(zeros: []) 71 | valueSquaredSum = Tensor(zeros: []) 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Utilities/Protocols.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 14 | 15 | import TensorFlow 16 | 17 | public protocol Stackable { 18 | static func stack(_ values: [Self]) -> Self 19 | func unstacked() -> [Self] 20 | } 21 | 22 | public protocol DifferentiableStackable: Stackable, Differentiable { 23 | @differentiable 24 | static func stack(_ values: [Self]) -> Self 25 | 26 | @differentiable 27 | func unstacked() -> [Self] 28 | } 29 | 30 | extension Tensor: Stackable { 31 | public static func stack(_ values: [Tensor]) -> Tensor { 32 | Tensor(stacking: values, alongAxis: 0) 33 | } 34 | 35 | public func unstacked() -> [Tensor] { 36 | unstacked(alongAxis: 0) 37 | } 38 | } 39 | 40 | extension Tensor: DifferentiableStackable where Scalar: TensorFlowFloatingPoint { 41 | @differentiable 42 | public static func stack(_ values: [Tensor]) -> Tensor { 43 | Tensor(stacking: values, alongAxis: 0) 44 | } 45 | 46 | @differentiable 47 | public func unstacked() -> [Tensor] { 48 | unstacked(alongAxis: 0) 49 | } 50 | } 51 | 52 | public protocol Replayable { 53 | init(emptyLike example: Self, withCapacity capacity: Int) 54 | 55 | mutating func update(atIndices indices: Tensor, using values: Self) 56 | func gathering(atIndices indices: Tensor) -> Self 57 | } 58 | 59 | public protocol Batchable { 60 | func flattenedBatch(outerDimCount: Int) -> Self 61 | func unflattenedBatch(outerDims: [Int]) -> Self 62 | } 63 | 64 | public protocol DifferentiableBatchable: Batchable, Differentiable { 65 | @differentiable(wrt: self) 66 | func flattenedBatch(outerDimCount: Int) -> Self 67 | 68 | @differentiable(wrt: self) 69 | func unflattenedBatch(outerDims: [Int]) -> Self 70 | } 71 | 72 | extension Tensor: Replayable where Scalar: Numeric { 73 | public init(emptyLike example: Tensor, withCapacity capacity: Int) { 74 | if example.rank <= 1 { 75 | self.init(zeros: [capacity]) 76 | } else { 77 | self.init(zeros: TensorShape([capacity] + example.shape.dimensions[1...])) 78 | } 79 | } 80 | 81 | public mutating func update(atIndices indices: Tensor, using values: Tensor) { 82 | self = _Raw.tensorScatterUpdate(self, indices: indices, updates: values) 83 | } 84 | 85 | public func gathering(atIndices indices: Tensor) -> Tensor { 86 | gathering(atIndices: indices, alongAxis: 0) 87 | } 88 | } 89 | 90 | extension Tensor: Batchable { 91 | public func flattenedBatch(outerDimCount: Int) -> Tensor { 92 | if outerDimCount == 1 { 93 | return self 94 | } 95 | var newShape = 
[-1] 96 | for i in outerDimCount.. Tensor { 103 | if rank > 1 { 104 | return reshaped(to: TensorShape(outerDims + shape.dimensions[1...])) 105 | } 106 | return reshaped(to: TensorShape(outerDims)) 107 | } 108 | } 109 | 110 | extension Tensor: DifferentiableBatchable where Scalar: TensorFlowFloatingPoint { 111 | @differentiable(wrt: self) 112 | public func flattenedBatch(outerDimCount: Int) -> Tensor { 113 | if outerDimCount == 1 { 114 | return self 115 | } 116 | // TODO: Remove this hack once the S4TF auto-diff memory leak is fixed. 117 | let newShape = Swift.withoutDerivative(at: self.shape) { shape -> [Int] in 118 | var newShape = [-1] 119 | for i in outerDimCount.. Tensor { 129 | if rank > 1 { 130 | return reshaped(to: TensorShape(outerDims + shape.dimensions[1...])) 131 | } 132 | return reshaped(to: TensorShape(outerDims)) 133 | } 134 | } 135 | 136 | extension Tensor where Scalar: TensorFlowFloatingPoint { 137 | public mutating func update(using other: Tensor, forgetFactor: Float) { 138 | let forgetFactor = Scalar(forgetFactor) 139 | self = forgetFactor * self + (1 - forgetFactor) * other 140 | } 141 | } 142 | 143 | extension KeyPathIterable { 144 | public static func stack(_ values: [Self]) -> Self { 145 | var result = values[0] 146 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 147 | result[keyPath: kp] = Tensor.stack(values.map { $0[keyPath: kp] }) 148 | } 149 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 150 | result[keyPath: kp] = Tensor.stack(values.map { $0[keyPath: kp] }) 151 | } 152 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 153 | result[keyPath: kp] = Tensor.stack(values.map { $0[keyPath: kp] }) 154 | } 155 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 156 | result[keyPath: kp] = Tensor.stack(values.map { $0[keyPath: kp] }) 157 | } 158 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 159 | result[keyPath: kp] = Tensor.stack(values.map { $0[keyPath: kp] }) 160 | } 161 | return result 162 | } 163 | 164 | public func unstacked() -> [Self] { 165 | var result = [Self]() 166 | for kp in recursivelyAllWritableKeyPaths(to: Tensor.self) { 167 | let unstacked = self[keyPath: kp].unstacked() 168 | if result.isEmpty { 169 | result = [Self](repeating: self, count: unstacked.count) 170 | } 171 | for i in result.indices { 172 | result[i][keyPath: kp] = unstacked[i] 173 | } 174 | } 175 | for kp in recursivelyAllWritableKeyPaths(to: Tensor.self) { 176 | let unstacked = self[keyPath: kp].unstacked() 177 | if result.isEmpty { 178 | result = [Self](repeating: self, count: unstacked.count) 179 | } 180 | for i in result.indices { 181 | result[i][keyPath: kp] = unstacked[i] 182 | } 183 | } 184 | for kp in recursivelyAllWritableKeyPaths(to: Tensor.self) { 185 | let unstacked = self[keyPath: kp].unstacked() 186 | if result.isEmpty { 187 | result = [Self](repeating: self, count: unstacked.count) 188 | } 189 | for i in result.indices { 190 | result[i][keyPath: kp] = unstacked[i] 191 | } 192 | } 193 | for kp in recursivelyAllWritableKeyPaths(to: Tensor.self) { 194 | let unstacked = self[keyPath: kp].unstacked() 195 | if result.isEmpty { 196 | result = [Self](repeating: self, count: unstacked.count) 197 | } 198 | for i in result.indices { 199 | result[i][keyPath: kp] = unstacked[i] 200 | } 201 | } 202 | for kp in recursivelyAllWritableKeyPaths(to: Tensor.self) { 203 | let unstacked = self[keyPath: kp].unstacked() 204 | if result.isEmpty { 205 | result = [Self](repeating: self, count: 
unstacked.count) 206 | } 207 | for i in result.indices { 208 | result[i][keyPath: kp] = unstacked[i] 209 | } 210 | } 211 | return result 212 | } 213 | 214 | public init(emptyLike example: Self, withCapacity capacity: Int) { 215 | self = example 216 | for kp in recursivelyAllWritableKeyPaths(to: Tensor.self) { 217 | self[keyPath: kp] = Tensor(emptyLike: example[keyPath: kp], withCapacity: capacity) 218 | } 219 | for kp in recursivelyAllWritableKeyPaths(to: Tensor.self) { 220 | self[keyPath: kp] = Tensor(emptyLike: example[keyPath: kp], withCapacity: capacity) 221 | } 222 | for kp in recursivelyAllWritableKeyPaths(to: Tensor.self) { 223 | self[keyPath: kp] = Tensor(emptyLike: example[keyPath: kp], withCapacity: capacity) 224 | } 225 | for kp in recursivelyAllWritableKeyPaths(to: Tensor.self) { 226 | self[keyPath: kp] = Tensor(emptyLike: example[keyPath: kp], withCapacity: capacity) 227 | } 228 | for kp in recursivelyAllWritableKeyPaths(to: Tensor.self) { 229 | self[keyPath: kp] = Tensor(emptyLike: example[keyPath: kp], withCapacity: capacity) 230 | } 231 | } 232 | 233 | public mutating func update(atIndices indices: Tensor, using values: Self) { 234 | for kp in recursivelyAllWritableKeyPaths(to: Tensor.self) { 235 | self[keyPath: kp].update(atIndices: indices, using: values[keyPath: kp]) 236 | } 237 | for kp in recursivelyAllWritableKeyPaths(to: Tensor.self) { 238 | self[keyPath: kp].update(atIndices: indices, using: values[keyPath: kp]) 239 | } 240 | for kp in recursivelyAllWritableKeyPaths(to: Tensor.self) { 241 | self[keyPath: kp].update(atIndices: indices, using: values[keyPath: kp]) 242 | } 243 | for kp in recursivelyAllWritableKeyPaths(to: Tensor.self) { 244 | self[keyPath: kp].update(atIndices: indices, using: values[keyPath: kp]) 245 | } 246 | for kp in recursivelyAllWritableKeyPaths(to: Tensor.self) { 247 | self[keyPath: kp].update(atIndices: indices, using: values[keyPath: kp]) 248 | } 249 | } 250 | 251 | public func gathering(atIndices indices: Tensor) -> Self { 252 | var result = self 253 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 254 | result[keyPath: kp] = result[keyPath: kp].gathering(atIndices: indices) 255 | } 256 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 257 | result[keyPath: kp] = result[keyPath: kp].gathering(atIndices: indices) 258 | } 259 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 260 | result[keyPath: kp] = result[keyPath: kp].gathering(atIndices: indices) 261 | } 262 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 263 | result[keyPath: kp] = result[keyPath: kp].gathering(atIndices: indices) 264 | } 265 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 266 | result[keyPath: kp] = result[keyPath: kp].gathering(atIndices: indices) 267 | } 268 | return result 269 | } 270 | 271 | public func flattenedBatch(outerDimCount: Int) -> Self { 272 | var result = self 273 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 274 | result[keyPath: kp] = result[keyPath: kp].flattenedBatch(outerDimCount: outerDimCount) 275 | } 276 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 277 | result[keyPath: kp] = result[keyPath: kp].flattenedBatch(outerDimCount: outerDimCount) 278 | } 279 | return result 280 | } 281 | 282 | public func unflattenedBatch(outerDims: [Int]) -> Self { 283 | var result = self 284 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 285 | result[keyPath: kp] = result[keyPath: 
kp].unflattenedBatch(outerDims: outerDims) 286 | } 287 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 288 | result[keyPath: kp] = result[keyPath: kp].unflattenedBatch(outerDims: outerDims) 289 | } 290 | return result 291 | } 292 | 293 | public mutating func update(using other: Self, forgetFactor: Float) { 294 | var result = self 295 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 296 | result[keyPath: kp].update(using: other[keyPath: kp], forgetFactor: forgetFactor) 297 | } 298 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 299 | result[keyPath: kp].update(using: other[keyPath: kp], forgetFactor: forgetFactor) 300 | } 301 | } 302 | } 303 | 304 | extension KeyPathIterable where Self: Differentiable, Self.TangentVector: KeyPathIterable { 305 | // TODO: Differentiable `stack` and `unstacked`. 306 | 307 | @differentiable(wrt: self, vjp: _vjpFlattenedBatch) 308 | public func flattenedBatch(outerDimCount: Int) -> Self { 309 | var result = self 310 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 311 | result[keyPath: kp] = result[keyPath: kp].flattenedBatch(outerDimCount: outerDimCount) 312 | } 313 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 314 | result[keyPath: kp] = result[keyPath: kp].flattenedBatch(outerDimCount: outerDimCount) 315 | } 316 | return result 317 | } 318 | 319 | @differentiable(wrt: self, vjp: _vjpUnflattenedBatch) 320 | public func unflattenedBatch(outerDims: [Int]) -> Self { 321 | var result = self 322 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 323 | result[keyPath: kp] = result[keyPath: kp].unflattenedBatch(outerDims: outerDims) 324 | } 325 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 326 | result[keyPath: kp] = result[keyPath: kp].unflattenedBatch(outerDims: outerDims) 327 | } 328 | return result 329 | } 330 | } 331 | 332 | internal extension KeyPathIterable 333 | where Self: Differentiable, Self.TangentVector: KeyPathIterable { 334 | @usableFromInline 335 | func _vjpFlattenedBatch(outerDimCount: Int) -> (Self, (TangentVector) -> TangentVector) { 336 | // TODO: This is very hacky. 
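    // The pullback below needs the original outer (batch) dimensions in order to unflatten
    // the incoming tangent vector, so they are recovered from the shape of a `Tensor`-valued
    // field; this assumes all tensor fields share the same leading `outerDimCount` dimensions.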
337 | var outerDims = [Int]() 338 | for kp in recursivelyAllWritableKeyPaths(to: Tensor.self) { 339 | outerDims = [Int](self[keyPath: kp].shape.dimensions[0...self) { 346 | result[keyPath: kp] = seed[keyPath: kp].unflattenedBatch(outerDims: outerDims) 347 | } 348 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 349 | result[keyPath: kp] = seed[keyPath: kp].unflattenedBatch(outerDims: outerDims) 350 | } 351 | return result 352 | }) 353 | } 354 | 355 | @usableFromInline 356 | func _vjpUnflattenedBatch(outerDims: [Int]) -> (Self, (TangentVector) -> TangentVector) { 357 | let result = unflattenedBatch(outerDims: outerDims) 358 | return (result, { seed in 359 | var result = seed 360 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 361 | result[keyPath: kp] = seed[keyPath: kp].flattenedBatch(outerDimCount: outerDims.count) 362 | } 363 | for kp in result.recursivelyAllWritableKeyPaths(to: Tensor.self) { 364 | result[keyPath: kp] = seed[keyPath: kp].flattenedBatch(outerDimCount: outerDims.count) 365 | } 366 | return result 367 | }) 368 | } 369 | } 370 | 371 | extension KeyPathIterable { 372 | public mutating func clipByGlobalNorm(clipNorm: Scalar) { 373 | let clipNorm = Tensor(clipNorm) 374 | var globalNorm = Tensor(zeros: []) 375 | for kp in self.recursivelyAllWritableKeyPaths(to: Tensor.self) { 376 | globalNorm += self[keyPath: kp].squared().sum() 377 | } 378 | globalNorm = sqrt(globalNorm) 379 | for kp in self.recursivelyAllWritableKeyPaths(to: Tensor.self) { 380 | self[keyPath: kp] *= clipNorm / max(globalNorm, clipNorm) 381 | } 382 | } 383 | } 384 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Utilities/Rendering.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 14 | 15 | import CGLFW 16 | import Foundation 17 | import TensorFlow 18 | 19 | public class ImageRenderer { 20 | public let framesPerSecond: Double? 21 | 22 | @usableFromInline internal let initialMaxWidth: Int32 23 | @usableFromInline internal var window: OpaquePointer? 24 | @usableFromInline internal var frameBuffer: GLuint = 0 25 | @usableFromInline internal var texture: GLuint = 0 26 | @usableFromInline internal var isOpen: Bool = true 27 | 28 | @inlinable 29 | public init(initialMaxWidth: Int32 = 800, framesPerSecond: Double? 
= nil) { 30 | self.initialMaxWidth = initialMaxWidth 31 | self.framesPerSecond = framesPerSecond 32 | } 33 | 34 | @inlinable 35 | deinit { 36 | closeWindow() 37 | } 38 | 39 | @inlinable 40 | public func render(_ data: ShapedArray) throws { 41 | if !isOpen { return } 42 | if let fps = framesPerSecond { Thread.sleep(forTimeInterval: 1 / fps) } 43 | 44 | if self.window == nil { 45 | var width = Int32(data.shape[1]) 46 | var height = Int32(data.shape[0]) 47 | if width > initialMaxWidth { 48 | let scale = Float(initialMaxWidth) / Float(width) 49 | width = Int32(scale * Float(width)) 50 | height = Int32(scale * Float(height)) 51 | } 52 | try createWindow(width: width, height: height) 53 | } 54 | 55 | if let window = self.window { 56 | if glfwWindowShouldClose(window) > 0 { closeWindow(); exit(0) } 57 | glfwMakeContextCurrent(window) 58 | glClear(UInt32(GL_COLOR_BUFFER_BIT)) 59 | 60 | // Generate the image texture. 61 | try ImageRenderer.preprocessData(data).withUnsafeBufferPointer { 62 | glTexImage2D( 63 | GLenum(GL_TEXTURE_2D), 0, GL_RGB8, GLsizei(data.shape[1]), GLsizei(data.shape[0]), 64 | 0, GLenum(GL_RGB), GLenum(GL_UNSIGNED_BYTE), $0.baseAddress) 65 | } 66 | 67 | // Resize and render the texture. 68 | var width: GLsizei = 0 69 | var height: GLsizei = 0 70 | glfwGetFramebufferSize(window, &width, &height) 71 | glBlitFramebuffer( 72 | 0, 0, GLsizei(data.shape[1]), GLsizei(data.shape[0]), 0, 0, 73 | width, height, GLenum(GL_COLOR_BUFFER_BIT), GLenum(GL_LINEAR)) 74 | 75 | // Swap the OpenGL front and back buffers to show the image. 76 | glfwSwapBuffers(window) 77 | glfwPollEvents() 78 | } 79 | } 80 | 81 | @inlinable 82 | public func createWindow(width: Int32, height: Int32) throws { 83 | // Initialize GLFW. 84 | if glfwInit() == 0 { 85 | throw ReinforcementLearningError.renderingError("Failed to initialize GLFW.") 86 | } 87 | 88 | // Open a new window. 89 | guard let window = glfwCreateWindow(width, height, "Gym Retro", nil, nil) else { 90 | glfwTerminate() 91 | throw ReinforcementLearningError.renderingError("Failed to open a GLFW window.") 92 | } 93 | 94 | self.window = window 95 | 96 | glfwMakeContextCurrent(window) 97 | 98 | // Generate a frame buffer. 99 | glGenFramebuffers(1, &frameBuffer) 100 | glBindFramebuffer(GLenum(GL_READ_FRAMEBUFFER), frameBuffer) 101 | 102 | // Generate a texture. 103 | glGenTextures(1, &texture) 104 | glBindTexture(GLenum(GL_TEXTURE_2D), texture) 105 | glTexParameteri(GLenum(GL_TEXTURE_2D), GLenum(GL_TEXTURE_MAG_FILTER), GL_NEAREST) 106 | glTexParameteri(GLenum(GL_TEXTURE_2D), GLenum(GL_TEXTURE_MIN_FILTER), GL_NEAREST) 107 | 108 | // Bind the texture to the frame buffer. 
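    // The texture is attached to the read framebuffer once, here; `render(_:)` then only
    // uploads new pixel data into it with `glTexImage2D` and blits the read framebuffer
    // onto the window's default framebuffer.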
109 | glFramebufferTexture2D( 110 | GLenum(GL_READ_FRAMEBUFFER), GLenum(GL_COLOR_ATTACHMENT0), 111 | GLenum(GL_TEXTURE_2D), texture, 0) 112 | } 113 | 114 | @inlinable 115 | public func closeWindow() { 116 | if let window = self.window { 117 | glDeleteTextures(1, &texture) 118 | glDeleteFramebuffers(1, &frameBuffer) 119 | glfwDestroyWindow(window) 120 | glfwPollEvents() 121 | glfwTerminate() 122 | self.isOpen = false 123 | self.window = nil 124 | } 125 | } 126 | 127 | @inlinable 128 | internal static func preprocessData(_ data: ShapedArray) throws -> [UInt8] { 129 | precondition(data.rank == 3 && data.shape[2] == 3, "Data must have shape '[Height, Width, 3]'.") 130 | let rowSize = data.shape[1] * data.shape[2] 131 | let scalars = data.scalars 132 | var preprocessed = [UInt8]() 133 | preprocessed.reserveCapacity(scalars.count) 134 | for row in 1...data.shape[0] { 135 | let index = (data.shape[0] - row) * rowSize 136 | preprocessed.append(contentsOf: scalars[index..<(index + rowSize)]) 137 | } 138 | return preprocessed 139 | } 140 | } 141 | 142 | public protocol GLFWScene { 143 | func draw() 144 | } 145 | 146 | extension GLFWScene { 147 | @inlinable 148 | public func render(in window: GLFWWindow) { 149 | if !window.isOpen { return } 150 | if let fps = window.framesPerSecond { Thread.sleep(forTimeInterval: 1 / fps) } 151 | if let w = window.window { 152 | // TODO: Should the following be exiting the running problem? 153 | if glfwWindowShouldClose(w) > 0 { window.close(); exit(0) } 154 | glfwMakeContextCurrent(w) 155 | glClearColor(1, 1, 1, 1) 156 | glClear(GLbitfield(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)) 157 | draw() 158 | glfwSwapBuffers(w) 159 | glfwPollEvents() 160 | } 161 | } 162 | } 163 | 164 | public class GLFWWindow { 165 | public let name: String 166 | 167 | @usableFromInline internal let width: Int 168 | @usableFromInline internal let height: Int 169 | @usableFromInline internal let framesPerSecond: Double? 170 | @usableFromInline internal var window: OpaquePointer? 171 | @usableFromInline internal var isOpen: Bool = true 172 | 173 | @inlinable 174 | public init(name: String, width: Int, height: Int, framesPerSecond: Double? = nil) throws { 175 | self.name = name 176 | self.width = width 177 | self.height = height 178 | self.framesPerSecond = framesPerSecond 179 | 180 | // Initialize GLFW. 181 | if glfwInit() == 0 { 182 | throw ReinforcementLearningError.renderingError("Failed to initialize GLFW.") 183 | } 184 | 185 | // Open a new window. 
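    // `glfwCreateWindow` returns `nil` if the window or its OpenGL context cannot be created;
    // in that case GLFW is terminated and a rendering error is thrown. The projection set up
    // below maps one GL unit to one window pixel, with the origin at the bottom-left corner.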
186 | guard let window = glfwCreateWindow(Int32(width), Int32(height), name, nil, nil) else { 187 | glfwTerminate() 188 | throw ReinforcementLearningError.renderingError("Failed to open a GLFW window.") 189 | } 190 | 191 | self.window = window 192 | glfwMakeContextCurrent(window) 193 | glEnable(GLenum(GL_BLEND)) 194 | glBlendFunc(GLenum(GL_SRC_ALPHA), GLenum(GL_ONE_MINUS_SRC_ALPHA)) 195 | glViewport(0, 0, Int32(width), Int32(height)) 196 | glMatrixMode(GLenum(GL_PROJECTION)) 197 | glLoadIdentity() 198 | glOrtho(0, GLdouble(width), 0, GLdouble(height), 0, 1) 199 | } 200 | 201 | @inlinable 202 | deinit { 203 | close() 204 | } 205 | 206 | @inlinable 207 | public func close() { 208 | if let window = self.window { 209 | glfwDestroyWindow(window) 210 | glfwPollEvents() 211 | glfwTerminate() 212 | self.isOpen = false 213 | self.window = nil 214 | } 215 | } 216 | } 217 | 218 | public protocol GLFWAttribute: AnyObject { 219 | func enable() 220 | func disable() 221 | } 222 | 223 | public class GLFWTransform: GLFWAttribute { 224 | public var translation: (Float, Float) 225 | public var rotation: Float 226 | public var scale: (Float, Float) 227 | 228 | @inlinable 229 | public init( 230 | translation: (Float, Float) = (0.0, 0.0), 231 | rotation: Float = 0.0, 232 | scale: (Float, Float) = (1.0, 1.0) 233 | ) { 234 | self.translation = translation 235 | self.rotation = rotation 236 | self.scale = scale 237 | } 238 | 239 | @inlinable 240 | public func enable() { 241 | glPushMatrix() 242 | glTranslatef(translation.0, translation.1, 0) 243 | glRotatef(rotation * 180 / Float.pi, 0, 0, 1) 244 | glScalef(scale.0, scale.1, 1) 245 | } 246 | 247 | @inlinable 248 | public func disable() { 249 | glPopMatrix() 250 | } 251 | } 252 | 253 | public class GLFWColor: GLFWAttribute { 254 | public let red: Float 255 | public let green: Float 256 | public let blue: Float 257 | 258 | @inlinable 259 | public init(red: Float, green: Float, blue: Float) { 260 | self.red = red 261 | self.green = green 262 | self.blue = blue 263 | } 264 | 265 | @inlinable 266 | public func enable() { 267 | glColor3f(red, green, blue) 268 | } 269 | 270 | @inlinable 271 | public func disable() { } 272 | } 273 | 274 | public class GLFWLineStyle: GLFWAttribute { 275 | public let pattern: UInt16 276 | 277 | @inlinable 278 | public init(_ pattern: UInt16) { 279 | self.pattern = pattern 280 | } 281 | 282 | @inlinable 283 | public func enable() { 284 | glEnable(GLenum(GL_LINE_STIPPLE)) 285 | glLineStipple(1, pattern) 286 | } 287 | 288 | @inlinable 289 | public func disable() { 290 | glDisable(GLenum(GL_LINE_STIPPLE)) 291 | } 292 | } 293 | 294 | public class GLFWLineWidth: GLFWAttribute { 295 | public let width: Float 296 | 297 | @inlinable 298 | public init(_ width: Float) { 299 | self.width = width 300 | } 301 | 302 | @inlinable 303 | public func enable() { 304 | glLineWidth(width) 305 | } 306 | 307 | @inlinable 308 | public func disable() { } 309 | } 310 | 311 | public protocol GLFWGeometry: AnyObject { 312 | var attributes: [GLFWAttribute] { get set } 313 | 314 | func render() 315 | } 316 | 317 | extension GLFWGeometry { 318 | @inlinable 319 | public func renderWithAttributes() { 320 | attributes.reversed().forEach { $0.enable() } 321 | render() 322 | attributes.forEach { $0.disable() } 323 | } 324 | } 325 | 326 | public class GLFWCompoundGeometry: GLFWGeometry { 327 | public var attributes: [GLFWAttribute] = [] 328 | public var components: [GLFWGeometry] 329 | 330 | @inlinable 331 | public init(_ components: GLFWGeometry...) 
{ 332 | self.components = components 333 | attributes = components.flatMap { $0.attributes } 334 | } 335 | 336 | @inlinable 337 | public func render() { 338 | components.forEach { $0.render() } 339 | } 340 | 341 | @inlinable 342 | public func renderWithAttributes() { 343 | components.forEach { $0.renderWithAttributes() } 344 | } 345 | } 346 | 347 | public class GLFWPoint: GLFWGeometry { 348 | public var attributes: [GLFWAttribute] = [] 349 | public var coordinates: (Float, Float, Float) 350 | 351 | @inlinable 352 | public init(coordinates: (Float, Float, Float) = (0, 0, 0)) { 353 | self.coordinates = coordinates 354 | } 355 | 356 | @inlinable 357 | public func render() { 358 | glBegin(GLenum(GL_POINTS)) 359 | glVertex3f(0, 0, 0) 360 | glEnd() 361 | } 362 | } 363 | 364 | public class GLFWLine: GLFWGeometry { 365 | public var attributes: [GLFWAttribute] = [] 366 | public var start: (Float, Float) 367 | public var end: (Float, Float) 368 | 369 | @inlinable 370 | public init(start: (Float, Float), end: (Float, Float)) { 371 | self.start = start 372 | self.end = end 373 | } 374 | 375 | @inlinable 376 | public func render() { 377 | glBegin(GLenum(GL_LINES)) 378 | glVertex2f(start.0, start.1) 379 | glVertex2f(end.0, end.1) 380 | glEnd() 381 | } 382 | } 383 | 384 | public class GLFWPolygon: GLFWGeometry { 385 | public var attributes: [GLFWAttribute] = [] 386 | public var vertices: [(Float, Float)] 387 | 388 | @inlinable 389 | public init(vertices: [(Float, Float)]) { 390 | self.vertices = vertices 391 | } 392 | 393 | @inlinable 394 | public func render() { 395 | switch vertices.count { 396 | case 4: glBegin(GLenum(GL_QUADS)) 397 | case 4...: glBegin(GLenum(GL_POLYGON)) 398 | case _: glBegin(GLenum(GL_TRIANGLES)) 399 | } 400 | vertices.forEach { glVertex3f($0.0, $0.1, 0) } 401 | glEnd() 402 | } 403 | } 404 | 405 | public class GLFWPolyLine: GLFWGeometry { 406 | public var attributes: [GLFWAttribute] = [] 407 | public var vertices: [(Float, Float)] 408 | public var closed: Bool 409 | 410 | @inlinable 411 | public init(vertices: [(Float, Float)], closed: Bool) { 412 | self.vertices = vertices 413 | self.closed = closed 414 | } 415 | 416 | @inlinable 417 | public func render() { 418 | glBegin(GLenum(closed ? GL_LINE_LOOP : GL_LINE_STRIP)) 419 | vertices.forEach { glVertex3f($0.0, $0.1, 0) } 420 | glEnd() 421 | } 422 | } 423 | 424 | public class GLFWCircle: GLFWGeometry { 425 | public var attributes: [GLFWAttribute] = [] 426 | public var radius: Float = 10 427 | public var resolution: Int = 30 428 | public var filled: Bool = true 429 | 430 | @inlinable 431 | public func render() { 432 | let vertices = (0.. (Float, Float) in 433 | let angle = 2 * Float.pi * Float(i) / Float(resolution) 434 | return (cos(angle) * radius, sin(angle) * radius) 435 | } 436 | if filled { 437 | GLFWPolygon(vertices: vertices).render() 438 | } else { 439 | GLFWPolyLine(vertices: vertices, closed: true).render() 440 | } 441 | } 442 | } 443 | 444 | // public class GLFWImage: GLFWGeometry { 445 | // public var attributes: [GLFWAttribute] = [] 446 | // public var image: ShapedArray 447 | 448 | // public init(_ image: ShapedArray) { 449 | // precondition( 450 | // image.rank == 3 && image.shape[2] == 3, 451 | // "The image array must have shape '[Height, Width, 3]'.") 452 | // self.image = image 453 | // } 454 | 455 | // public func render(in window: OpaquePointer?) { 456 | // // Generate the image texture. 
457 | // let rowSize = image.shape[1] * image.shape[2] 458 | // let scalars = image.scalars 459 | // var preprocessed = [UInt8]() 460 | // preprocessed.reserveCapacity(scalars.count) 461 | // for row in 1...image.shape[0] { 462 | // let index = (image.shape[0] - row) * rowSize 463 | // preprocessed.append(contentsOf: scalars[index..<(index + rowSize)]) 464 | // } 465 | // preprocessed.withUnsafeBufferPointer { 466 | // glTexImage2D( 467 | // GLenum(GL_TEXTURE_2D), 0, GL_RGB8, GLsizei(image.shape[1]), GLsizei(image.shape[0]), 468 | // 0, GLenum(GL_RGB), GLenum(GL_UNSIGNED_BYTE), $0.baseAddress) 469 | // } 470 | 471 | // // Resize and render the texture. 472 | // var width: GLsizei = 0 473 | // var height: GLsizei = 0 474 | // glfwGetFramebufferSize(window, &width, &height) 475 | // glBlitFramebuffer( 476 | // 0, 0, GLsizei(image.shape[1]), GLsizei(image.shape[0]), 0, 0, 477 | // width, height, GLenum(GL_COLOR_BUFFER_BIT), GLenum(GL_LINEAR)) 478 | // } 479 | // } 480 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Utilities/Rendering2.swift: -------------------------------------------------------------------------------- 1 | // // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // // 3 | // // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // // use this file except in compliance with the License. You may obtain a copy of 5 | // // the License at 6 | // // 7 | // // http://www.apache.org/licenses/LICENSE-2.0 8 | // // 9 | // // Unless required by applicable law or agreed to in writing, software 10 | // // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // // License for the specific language governing permissions and limitations under 13 | // // the License. 14 | 15 | // import CGLFW 16 | // import CVulkan 17 | 18 | // @inlinable 19 | // internal func withCStrings( 20 | // _ args: [String], 21 | // _ body: (UnsafePointer?>?) -> R 22 | // ) -> R { 23 | // var cStrings = args.map { UnsafePointer(strdup($0)) } 24 | // defer { cStrings.forEach { free(UnsafeMutablePointer(mutating: $0)) } } 25 | // return body(&cStrings) 26 | // } 27 | 28 | // public class RenderingWindow { 29 | // public let name: String 30 | 31 | // public let applicationVersion: UInt32 = 1 32 | // public let vulkanEngineVersion: UInt32 = 0 33 | // public let vulkanAPIVersion: UInt32 = 0 34 | 35 | // public let vulkanEnabledLayers: [String] = [] 36 | // public let vulkanDeviceScoringFunction: (VulkanPhysicalDevice) -> Int? = 37 | // VulkanPhysicalDevice.defaultScore(for:) 38 | 39 | // @usableFromInline internal let width: Int 40 | // @usableFromInline internal let height: Int 41 | // @usableFromInline internal let framesPerSecond: Double? 42 | // @usableFromInline internal var window: OpaquePointer? = nil 43 | // @usableFromInline internal var isOpen: Bool = true 44 | 45 | // @usableFromInline internal var vulkanInstance: VkInstance? = nil 46 | // @usableFromInline internal var vulkanPhysicalDevice: VulkanPhysicalDevice? = nil 47 | 48 | // @inlinable 49 | // public init(name: String, width: Int, height: Int, framesPerSecond: Double? = nil) throws { 50 | // self.name = name 51 | // self.width = width 52 | // self.height = height 53 | // self.framesPerSecond = framesPerSecond 54 | // try! glfwInit() 55 | // try! 
vulkanInit() 56 | // } 57 | 58 | // @inlinable 59 | // deinit { 60 | // close() 61 | // } 62 | 63 | // @inlinable 64 | // internal func glfwInit() throws { 65 | // // Initialize GLFW. 66 | // if CGLFW.glfwInit() == 0 { throw ReinforcementLearningError.renderingError("Failed to initialize GLFW.") } 67 | 68 | // // Since we are using Vulkan, we need to tell GLFW to not create an OpenGL context. 69 | // glfwWindowHint(GLFW_CLIENT_API, GLFW_NO_API) 70 | // glfwWindowHint(GLFW_RESIZABLE, GLFW_FALSE) 71 | 72 | // // Open a new window. 73 | // guard let window = glfwCreateWindow(Int32(width), Int32(height), name, nil, nil) else { 74 | // glfwTerminate() 75 | // throw ReinforcementLearningError.renderingError("Failed to open a GLFW window.") 76 | // } 77 | 78 | // self.window = window 79 | // glfwMakeContextCurrent(window) 80 | // } 81 | 82 | // @inlinable 83 | // internal func vulkanInit() throws { 84 | // try! vulkanCreateInstance() 85 | // try! vulkanPickPhysicalDevice() 86 | // try! vulkanCreateLogicalDevice() 87 | // } 88 | 89 | // @inlinable 90 | // internal func vulkanCreateInstance() throws { 91 | // var applicationInformation = VkApplicationInfo( 92 | // sType: VK_STRUCTURE_TYPE_APPLICATION_INFO, 93 | // pNext: nil, 94 | // pApplicationName: name, 95 | // applicationVersion: applicationVersion, 96 | // pEngineName: nil, 97 | // engineVersion: vulkanEngineVersion, 98 | // apiVersion: vulkanAPIVersion) 99 | 100 | // // Get the required GLFW extensions. 101 | // var glfwExtensionCount: UInt32 = 0 102 | // let glfwExtensions = glfwGetRequiredInstanceExtensions(&glfwExtensionCount) 103 | 104 | // var information = withCStrings(vulkanEnabledLayers) { enabledLayers in 105 | // VkInstanceCreateInfo( 106 | // sType: VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO, 107 | // pNext: nil, 108 | // flags: 0, 109 | // pApplicationInfo: &applicationInformation, 110 | // enabledLayerCount: UInt32(self.vulkanEnabledLayers.count), 111 | // ppEnabledLayerNames: enabledLayers, 112 | // enabledExtensionCount: glfwExtensionCount, 113 | // ppEnabledExtensionNames: glfwExtensions) 114 | // } 115 | 116 | // let result = vkCreateInstance(&information, nil, &vulkanInstance) 117 | // if result != VK_SUCCESS { 118 | // throw ReinforcementLearningError.renderingError( 119 | // "Failed to create a Vulkan instance. Error code: \(String(describing: result.rawValue))") 120 | // } 121 | // } 122 | 123 | // @inlinable 124 | // internal func vulkanPickPhysicalDevice() throws { 125 | // let devices = vulkanDevices() 126 | // let scores = devices.map(vulkanDeviceScoringFunction) 127 | // var maxScore = Int.min 128 | // var device: VulkanPhysicalDevice? 
= nil 129 | // for (currentDevice, currentScore) in zip(devices, scores) { 130 | // if let s = currentScore, s >= maxScore { 131 | // maxScore = s 132 | // device = currentDevice 133 | // } 134 | // } 135 | // if let d = device { 136 | // vulkanPhysicalDevice = d 137 | // } else { 138 | // throw ReinforcementLearningError.renderingError("Failed to find a suitable device for rendering.") 139 | // } 140 | // } 141 | 142 | // @inlinable 143 | // internal func vulkanCreateLogicalDevice() throws { 144 | 145 | // } 146 | 147 | // @inlinable 148 | // public func vulkanExtensions() -> [String] { 149 | // var extensionCount: UInt32 = 0 150 | // vkEnumerateInstanceExtensionProperties(nil, &extensionCount, nil) 151 | // let extensionsPointer = UnsafeMutablePointer<VkExtensionProperties>.allocate( 152 | // capacity: Int(extensionCount)) 153 | // defer { extensionsPointer.deallocate() } 154 | // vkEnumerateInstanceExtensionProperties(nil, &extensionCount, extensionsPointer) 155 | // let extensionsBufferPointer = UnsafeBufferPointer( 156 | // start: extensionsPointer, 157 | // count: Int(extensionCount)) 158 | // return [VkExtensionProperties](extensionsBufferPointer).map { 159 | // var name = $0.extensionName 160 | // return withUnsafeBytes(of: &name) { rawPointer -> String in 161 | // String(cString: rawPointer.baseAddress!.assumingMemoryBound(to: CChar.self)) 162 | // } 163 | // } 164 | // } 165 | 166 | // @inlinable 167 | // public func vulkanDevices() -> [VulkanPhysicalDevice] { 168 | // var deviceCount: UInt32 = 0 169 | // vkEnumeratePhysicalDevices(vulkanInstance, &deviceCount, nil) 170 | // if (deviceCount == 0) { return [VulkanPhysicalDevice]() } 171 | // let devicesPointer = UnsafeMutablePointer<VkPhysicalDevice?>.allocate( 172 | // capacity: Int(deviceCount)) 173 | // defer { devicesPointer.deallocate() } 174 | // vkEnumeratePhysicalDevices(vulkanInstance, &deviceCount, devicesPointer) 175 | // let devicesBufferPointer = UnsafeBufferPointer(start: devicesPointer, count: Int(deviceCount)) 176 | // return [VkPhysicalDevice?](devicesBufferPointer).map { VulkanPhysicalDevice(device: $0!) } 177 | // } 
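Both vulkanExtensions() and vulkanDevices() above, as well as queueFamilies further down, repeat the same Vulkan "count, then fill" enumeration idiom: call the query once with a nil buffer to learn how many elements exist, allocate a buffer of that size, and call it again to fill the buffer. A generic helper along these lines (a hypothetical sketch, not part of this package) captures the pattern:

// Hypothetical helper sketching the two-call Vulkan enumeration idiom used above.
// The query closure is invoked once with a nil buffer to obtain the element count
// and a second time with an allocated buffer that it fills.
func enumerateVulkanItems<Element>(
  _ query: (UnsafeMutablePointer<UInt32>, UnsafeMutablePointer<Element>?) -> Void
) -> [Element] {
  var count: UInt32 = 0
  query(&count, nil)
  if count == 0 { return [] }
  let buffer = UnsafeMutablePointer<Element>.allocate(capacity: Int(count))
  defer { buffer.deallocate() }
  query(&count, buffer)
  return Array(UnsafeBufferPointer(start: buffer, count: Int(count)))
}

With a helper like this, vulkanExtensions() would reduce to a single call that wraps vkEnumerateInstanceExtensionProperties (discarding its VkResult) followed by the extension-name conversion.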
178 | 179 | // @inlinable 180 | // public func close() { 181 | // if let vulkanInstance = self.vulkanInstance { 182 | // vkDestroyInstance(vulkanInstance, nil) 183 | // self.vulkanInstance = nil 184 | // } 185 | // if let window = self.window { 186 | // glfwDestroyWindow(window) 187 | // glfwTerminate() 188 | // self.isOpen = false 189 | // self.window = nil 190 | // } 191 | // } 192 | // } 193 | 194 | // public struct VulkanPhysicalDevice { 195 | // @usableFromInline internal let device: VkPhysicalDevice 196 | 197 | // public let name: String 198 | // public let apiVersion: UInt32 199 | // public let driverVersion: UInt32 200 | // public let vendorID: UInt32 201 | // public let deviceID: UInt32 202 | // public let type: RenderingDeviceType 203 | // public let limits: VkPhysicalDeviceLimits 204 | // public let sparseProperties: VkPhysicalDeviceSparseProperties 205 | // public let features: VkPhysicalDeviceFeatures 206 | 207 | // @inlinable 208 | // internal init(device: VkPhysicalDevice) { 209 | // let propertiesPointer = UnsafeMutablePointer<VkPhysicalDeviceProperties>.allocate(capacity: 1) 210 | // vkGetPhysicalDeviceProperties(device, propertiesPointer) 211 | // let properties = propertiesPointer.pointee 212 | // self.device = device 213 | // var deviceName = properties.deviceName 214 | // self.name = withUnsafeBytes(of: &deviceName) { rawPointer -> String in 215 | // String(cString: rawPointer.baseAddress!.assumingMemoryBound(to: CChar.self)) 216 | // } 217 | // self.apiVersion = properties.apiVersion 218 | // self.driverVersion = properties.driverVersion 219 | // self.vendorID = properties.vendorID 220 | // self.deviceID = properties.deviceID 221 | // self.type = RenderingDeviceType(from: properties.deviceType) 222 | // self.limits = properties.limits 223 | // self.sparseProperties = properties.sparseProperties 224 | // let featuresPointer = UnsafeMutablePointer<VkPhysicalDeviceFeatures>.allocate(capacity: 1) 225 | // vkGetPhysicalDeviceFeatures(device, featuresPointer) 226 | // self.features = featuresPointer.pointee 227 | // } 228 | // } 229 | 230 | // extension VulkanPhysicalDevice { 231 | // @inlinable 232 | // public var queueFamilies: [VulkanQueueFamily] { 233 | // var queueFamilyCount: UInt32 = 0 234 | // vkGetPhysicalDeviceQueueFamilyProperties(device, &queueFamilyCount, nil) 235 | // let queueFamiliesPointer = UnsafeMutablePointer<VkQueueFamilyProperties>.allocate( 236 | // capacity: Int(queueFamilyCount)) 237 | // vkGetPhysicalDeviceQueueFamilyProperties(device, &queueFamilyCount, queueFamiliesPointer) 238 | // let queueFamiliesBufferPointer = UnsafeBufferPointer( 239 | // start: queueFamiliesPointer, 240 | // count: Int(queueFamilyCount)) 241 | // return queueFamiliesBufferPointer.enumerated().map { 242 | // VulkanQueueFamily(at: $0, withProperties: $1) 243 | // } 244 | // } 245 | // } 246 | 
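Because RenderingWindow stores its device-selection policy as a plain (VulkanPhysicalDevice) -> Int? closure (the vulkanDeviceScoringFunction property above), callers can rank devices however they like: returning nil excludes a device, which is exactly how vulkanPickPhysicalDevice skips unsuitable hardware. As a purely illustrative alternative to the defaultScore(for:) function defined just below (and keeping in mind that this whole file is currently commented out), a custom policy might prefer integrated GPUs:

// Hypothetical custom scoring closure (illustration only): require graphics support,
// prefer integrated GPUs, and break ties by the maximum 2D texture dimension.
let preferIntegratedGPU: (VulkanPhysicalDevice) -> Int? = { device in
  guard device.queueFamilies.contains(where: { $0.supportsGraphics }) else { return nil }
  var score = device.type == .integratedGPU ? 1000 : 0
  score += Int(device.limits.maxImageDimension2D)
  return score
}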
247 | // extension VulkanPhysicalDevice { 248 | // public static func defaultScore(for device: VulkanPhysicalDevice) -> Int? { 249 | // // if device.features.geometryShader == 0 { return 0 } 250 | // if !device.queueFamilies.contains(where: { $0.supportsGraphics }) { return nil } 251 | // var score = 0 252 | // if device.type == .discreteGPU { score += 1000 } 253 | // score += Int(device.limits.maxImageDimension2D) 254 | // return score 255 | // } 256 | // } 257 | 258 | // public struct VulkanQueueFamily { 259 | // public let index: Int 260 | // public let supportsGraphics: Bool 261 | // public let supportsCompute: Bool 262 | // public let supportsTransfer: Bool 263 | // public let supportsSparseBinding: Bool 264 | // public let supportsProtected: Bool 265 | // public let queueCount: UInt32 266 | // public let timestampValidBits: UInt32 267 | // public let minImageTransferGranularity: VulkanExtent3D 268 | 269 | // @inlinable 270 | // internal init(at index: Int, withProperties properties: VkQueueFamilyProperties) { 271 | // self.index = index 272 | // self.supportsGraphics = properties.queueFlags & VK_QUEUE_GRAPHICS_BIT.rawValue != 0 273 | // self.supportsCompute = properties.queueFlags & VK_QUEUE_COMPUTE_BIT.rawValue != 0 274 | // self.supportsTransfer = properties.queueFlags & VK_QUEUE_TRANSFER_BIT.rawValue != 0 275 | // self.supportsSparseBinding = properties.queueFlags & VK_QUEUE_SPARSE_BINDING_BIT.rawValue != 0 276 | // self.supportsProtected = properties.queueFlags & VK_QUEUE_PROTECTED_BIT.rawValue != 0 277 | // self.queueCount = properties.queueCount 278 | // self.timestampValidBits = properties.timestampValidBits 279 | // self.minImageTransferGranularity = VulkanExtent3D( 280 | // width: properties.minImageTransferGranularity.width, 281 | // height: properties.minImageTransferGranularity.height, 282 | // depth: properties.minImageTransferGranularity.depth) 283 | // } 284 | // } 285 | 286 | // public struct VulkanExtent3D { 287 | // public let width: UInt32 288 | // public let height: UInt32 289 | // public let depth: UInt32 290 | 291 | // @inlinable 292 | // internal init(width: UInt32, height: UInt32, depth: UInt32) { 293 | // self.width = width 294 | // self.height = height 295 | // self.depth = depth 296 | // } 297 | // } 298 | 299 | // public enum RenderingDeviceType { 300 | // case integratedGPU, discreteGPU, virtualGPU, cpu, other 301 | 302 | // @inlinable 303 | // internal init(from vulkanValue: VkPhysicalDeviceType) { 304 | // switch vulkanValue { 305 | // case VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU: self = .integratedGPU 306 | // case VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU: self = .discreteGPU 307 | // case VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU: self = .virtualGPU 308 | // case VK_PHYSICAL_DEVICE_TYPE_CPU: self = .cpu 309 | // case _: self = .other 310 | // } 311 | // } 312 | // } 313 | -------------------------------------------------------------------------------- /Sources/ReinforcementLearning/Values.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Emmanouil Antonios Platanios. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 14 | 15 | import TensorFlow 16 | 17 | /// Computes discounted returns. 18 | /// 19 | /// Discounted returns are defined as follows: 20 | /// `Q_t = \sum_{t'=t}^T gamma^{t'-t} * r_{t'} + gamma^{T-t+1} * finalValue`, 21 | /// where `r_t` represents the reward at time step `t` and `gamma` represents the discount factor. 22 | /// For more details refer to "Reinforcement Learning: An Introduction" Second Edition by 23 | /// Richard S. Sutton and Andrew G. Barto. 24 | /// 25 | /// The discounted return computation also takes into account the time steps when episodes end 26 | /// (i.e., steps whose kind is `.last`) by making sure to reset the discounted return being carried 27 | /// backwards through time. 28 | /// 29 | /// Typically, each reward tensor will have shape `[BatchSize]` (for batched rewards) or `[]` (for 30 | /// unbatched rewards). 31 | /// 32 | /// - Parameters: 33 | ///   - discountFactor: Reward discount factor (`gamma` in the above example). 34 | ///   - stepKinds: Contains the step kinds (represented using their integer values) for each step. 35 | ///   - rewards: Contains the rewards for each step. 36 | ///   - finalValue: Estimated value at the final step. This is used to bootstrap the reward-to-go 37 | ///     computation. Defaults to zeros. 38 | /// 39 | /// - Returns: Tensor containing the discounted return values over time. 40 | @inlinable 41 | public func discountedReturns<Scalar: TensorFlowFloatingPoint>( 42 | discountFactor: Scalar, 43 | stepKinds: StepKind, 44 | rewards: Tensor<Scalar>, 45 | finalValue: Tensor<Scalar>? = nil 46 | ) -> Tensor<Scalar> { 47 | let isLast = stepKinds.isLast() 48 | let T = stepKinds.rawValue.shape[0] 49 | // TODO: This looks very ugly. 50 | let Tminus1 = Tensor(Int64(T - 1)) 51 | let finalReward = finalValue ?? Tensor(zerosLike: rewards[0]) 52 | var discountedReturns = [Tensor<Scalar>]() 53 | for t in 0..<T { 69 | public struct AdvantageEstimate<Scalar: TensorFlowFloatingPoint> { 70 | public let advantages: Tensor<Scalar> 71 | public let discountedReturns: () -> Tensor<Scalar> 72 | 73 | @inlinable 74 | public init(advantages: Tensor<Scalar>, discountedReturns: @escaping () -> Tensor<Scalar>) { 75 | self.advantages = advantages 76 | self.discountedReturns = discountedReturns 77 | } 78 | } 79 | 80 | public protocol AdvantageFunction { 81 | /// - Parameters: 82 | ///   - stepKinds: Contains the step kinds (represented using their integer values) for each step. 83 | ///   - rewards: Contains the rewards obtained at each step. 84 | ///   - values: Contains the value estimates for each step. 85 | ///   - finalValue: Estimated value at the final step. 86 | func callAsFunction<Scalar: TensorFlowFloatingPoint>( 87 | stepKinds: StepKind, 88 | rewards: Tensor<Scalar>, 89 | values: Tensor<Scalar>, 90 | finalValue: Tensor<Scalar> 91 | ) -> AdvantageEstimate<Scalar> 92 | } 93 | 
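To make the `Q_t` formula documented above concrete, the following is a minimal scalar sketch over plain Swift arrays (hypothetical, not part of the package; it ignores batching and the StepKind-based episode resets that the tensor implementation handles):

// Hypothetical scalar reference for
// Q_t = \sum_{t'=t}^T gamma^{t'-t} * r_{t'} + gamma^{T-t+1} * finalValue,
// computed with the equivalent backward recursion Q_t = r_t + gamma * Q_{t+1},
// where Q_{T+1} = finalValue.
func scalarDiscountedReturns(
  discountFactor gamma: Float,
  rewards: [Float],
  finalValue: Float = 0
) -> [Float] {
  var returns = [Float](repeating: 0, count: rewards.count)
  var running = finalValue
  for t in stride(from: rewards.count - 1, through: 0, by: -1) {
    running = rewards[t] + gamma * running
    returns[t] = running
  }
  return returns
}

// Example: rewards [1, 1, 1] with gamma = 0.9 and finalValue = 0 yield
// [1 + 0.9 * 1.9, 1 + 0.9 * 1, 1] = [2.71, 1.9, 1.0].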
94 | /// Performs empirical advantage estimation. 95 | /// 96 | /// The empirical advantage estimate at step `t` is defined as: 97 | /// `advantage[t] = returns[t] - value[t]`, where the returns are computed using 98 | /// `discountedReturns(discountFactor:stepKinds:rewards:finalValue:)`. 99 | public struct EmpiricalAdvantageEstimation: AdvantageFunction { 100 | public let discountFactor: Float 101 | 102 | /// - Parameters: 103 | ///   - discountFactor: Reward discount factor value, which must be between `0.0` and `1.0`. 104 | @inlinable 105 | public init(discountFactor: Float) { 106 | self.discountFactor = discountFactor 107 | } 108 | 109 | /// - Parameters: 110 | ///   - stepKinds: Contains the step kinds (represented using their integer values) for each step. 111 | ///   - rewards: Contains the rewards obtained at each step. 112 | ///   - values: Contains the value estimates for each step. 113 | ///   - finalValue: Estimated value at the final step. 114 | @inlinable 115 | public func callAsFunction<Scalar: TensorFlowFloatingPoint>( 116 | stepKinds: StepKind, 117 | rewards: Tensor<Scalar>, 118 | values: Tensor<Scalar>, 119 | finalValue: Tensor<Scalar> 120 | ) -> AdvantageEstimate<Scalar> { 121 | let returns = discountedReturns( 122 | discountFactor: Scalar(discountFactor), 123 | stepKinds: stepKinds, 124 | rewards: rewards, 125 | finalValue: finalValue) 126 | return AdvantageEstimate(advantages: returns - values, discountedReturns: { () in returns }) 127 | } 128 | } 129 | 130 | /// Performs generalized advantage estimation. 131 | /// 132 | /// For more details refer to "High-Dimensional Continuous Control Using Generalized Advantage 133 | /// Estimation" by John Schulman, Philipp Moritz et al. The full paper can be found at: 134 | /// https://arxiv.org/abs/1506.02438. 135 | public struct GeneralizedAdvantageEstimation: AdvantageFunction { 136 | public let discountFactor: Float 137 | public let discountWeight: Float 138 | 139 | /// - Parameters: 140 | ///   - discountFactor: Reward discount factor value, which must be between `0.0` and `1.0`. 141 | ///   - discountWeight: A weight between `0.0` and `1.0` that is used for variance reduction in 142 | ///     the temporal differences. 143 | @inlinable 144 | public init(discountFactor: Float, discountWeight: Float = 1) { 145 | self.discountFactor = discountFactor 146 | self.discountWeight = discountWeight 147 | } 148 | 149 | /// - Parameters: 150 | ///   - stepKinds: Contains the step kinds (represented using their integer values) for each step. 151 | ///   - rewards: Contains the rewards obtained at each step. 152 | ///   - values: Contains the value estimates for each step. 153 | ///   - finalValue: Estimated value at the final step. 154 | @inlinable 155 | public func callAsFunction<Scalar: TensorFlowFloatingPoint>( 156 | stepKinds: StepKind, 157 | rewards: Tensor<Scalar>, 158 | values: Tensor<Scalar>, 159 | finalValue: Tensor<Scalar> 160 | ) -> AdvantageEstimate<Scalar> { 161 | let discountWeight = Scalar(self.discountWeight) 162 | let discountFactor = Scalar(self.discountFactor) 163 | let isNotLast = 1 - Tensor<Scalar>(stepKinds.isLast()) 164 | let T = stepKinds.rawValue.shape[0] 165 | 166 | // Compute advantages in reverse order. 167 | // TODO: This looks very ugly. 168 | let Tminus1 = Tensor(Int64(T - 1)) 169 | let last = rewards.gathering(atIndices: Tminus1) + 170 | discountFactor * finalValue * isNotLast.gathering(atIndices: Tminus1) - 171 | values.gathering(atIndices: Tminus1) 172 | var advantages = [last] 173 | for t in 1..<T { -------------------------------------------------------------------------------- /Sources/ReinforcementLearningExperiments/CartPole.swift: -------------------------------------------------------------------------------- 20 | public var dense1: Dense<Float> = Dense<Float>(inputSize: 4, outputSize: 100) 21 | public var dense2: Dense<Float> = Dense<Float>(inputSize: 100, outputSize: 2) 22 | 23 | @differentiable 24 | public func callAsFunction(_ input: CartPoleEnvironment.Observation) -> Categorical<Int32> { 25 | let stackedInput = Tensor( 26 | stacking: [ 27 | input.position, input.positionDerivative, 28 | input.angle, input.angleDerivative], 29 | alongAxis: input.position.rank) 30 | let outerDimCount = stackedInput.rank - 1 31 | let outerDims = [Int](stackedInput.shape.dimensions[0..
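Returning to the GeneralizedAdvantageEstimation type defined in Values.swift above: the recursion it implements (Schulman et al., 2015) can be written over plain Swift arrays as the following hypothetical, unbatched sketch that ignores episode boundaries (it is not the package's tensor implementation):

// Hypothetical scalar reference for generalized advantage estimation:
//   delta_t = r_t + gamma * V_{t+1} - V_t
//   A_t     = delta_t + gamma * lambda * A_{t+1}
// where gamma is `discountFactor` and lambda is `discountWeight`.
func scalarGeneralizedAdvantages(
  discountFactor gamma: Float,
  discountWeight lambda: Float,
  rewards: [Float],
  values: [Float],
  finalValue: Float = 0
) -> [Float] {
  precondition(rewards.count == values.count)
  var advantages = [Float](repeating: 0, count: rewards.count)
  var nextAdvantage: Float = 0
  var nextValue = finalValue
  for t in stride(from: rewards.count - 1, through: 0, by: -1) {
    let delta = rewards[t] + gamma * nextValue - values[t]
    nextAdvantage = delta + gamma * lambda * nextAdvantage
    advantages[t] = nextAdvantage
    nextValue = values[t]
  }
  return advantages
}

Setting discountWeight to 1 recovers the empirical advantage returns[t] - values[t] described earlier, while smaller values reduce variance in the temporal-difference terms at the cost of added bias.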