├── .github ├── FUNDING.yml └── workflows │ └── ci.yml ├── .gitignore ├── .jvmopts ├── .pre-commit-config.yaml ├── .projectile ├── .scalafmt.conf ├── .travis.yml ├── CONTRIBUTING.md ├── LICENSE ├── PLAN.md ├── README.md ├── build.sbt ├── docs └── src │ └── main │ ├── resources │ └── microsite │ │ ├── data │ │ └── menu.yml │ │ ├── img │ │ ├── agent_env.png │ │ └── policy_iteration.png │ │ └── js │ │ └── mathjax.js │ └── tut │ ├── cookbook.md │ ├── cookbook │ └── cookbook.md │ ├── course.md │ ├── course │ ├── tabular.md │ └── tabular │ │ ├── bandits.md │ │ ├── dynamic_programming.md │ │ ├── finite_mdps.md │ │ ├── monte_carlo.md │ │ ├── n_step_bootstrapping.md │ │ ├── tabular_methods.md │ │ ├── td_learning.md │ │ └── warning.md │ ├── faq.md │ ├── index.md │ ├── policies.md │ ├── policies │ ├── policies.md │ ├── stochastic.md │ └── stochastic │ │ ├── epsilon_greedy.md │ │ └── random.md │ ├── state.md │ └── state │ ├── simple.md │ └── simple │ └── bandit.md ├── project ├── build.properties └── plugins.sbt ├── scala-rl-book └── src │ ├── main │ └── scala │ │ └── com │ │ └── scalarl │ │ └── book │ │ ├── Chapter2.scala │ │ ├── Chapter3.scala │ │ ├── Chapter4.scala │ │ ├── Chapter5.scala │ │ └── Chapter6.scala │ └── test │ └── scala │ └── com │ └── scalarl │ └── book │ ├── Chapter3Spec.scala │ └── Chapter4Spec.scala ├── scala-rl-core └── src │ ├── main │ └── scala │ │ └── com │ │ └── scalarl │ │ ├── ActionValueFn.scala │ │ ├── Agent.scala │ │ ├── Evaluator.scala │ │ ├── Policy.scala │ │ ├── SARS.scala │ │ ├── State.scala │ │ ├── StateValueFn.scala │ │ ├── Time.scala │ │ ├── Util.scala │ │ ├── algebra │ │ ├── AffineCombination.scala │ │ ├── Decompose.scala │ │ ├── Expectation.scala │ │ ├── Module.scala │ │ ├── ToDouble.scala │ │ └── Weight.scala │ │ ├── evaluate │ │ ├── ActionValue.scala │ │ └── StateValue.scala │ │ ├── logic │ │ ├── Episode.scala │ │ ├── MonteCarlo.scala │ │ └── Sweep.scala │ │ ├── package.scala │ │ ├── policy │ │ ├── Gradient.scala │ │ ├── Greedy.scala │ │ ├── UCB.scala │ │ └── bandit │ │ │ └── Greedy.scala │ │ ├── rainier │ │ └── Categorical.scala │ │ ├── state │ │ ├── MapState.scala │ │ └── TickState.scala │ │ ├── util │ │ └── FrequencyTracker.scala │ │ └── value │ │ ├── ConstantStep.scala │ │ ├── DecayState.scala │ │ └── WeightedAverage.scala │ └── test │ └── scala │ └── com │ └── scalarl │ └── value │ └── ConstantStepLaws.scala ├── scala-rl-plot └── src │ └── main │ └── scala │ └── com │ └── scalarl │ └── plot │ ├── Plot.scala │ └── Tabulator.scala ├── scala-rl-world └── src │ ├── main │ └── scala │ │ └── com │ │ └── scalarl │ │ └── world │ │ ├── Bandit.scala │ │ ├── Blackjack.scala │ │ ├── CarRental.scala │ │ ├── GamblersProblem.scala │ │ ├── GridWorld.scala │ │ ├── InfiniteVariance.scala │ │ ├── connectfour │ │ ├── Game.scala │ │ └── IO.scala │ │ └── util │ │ ├── CardDeck.scala │ │ └── Grid.scala │ └── test │ └── scala │ └── com │ └── scalarl │ └── world │ └── connectfour │ └── ConnectFourSpec.scala ├── scaladoc-root.txt ├── scripts ├── decrypt-keys.sh └── publishMicrosite.sh └── travis-deploy-key.enc /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: sritchie 4 | patreon: sritchie 5 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: 3 | pull_request: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 
| checks: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - uses: actions/setup-java@v4 14 | with: 15 | cache: "sbt" 16 | distribution: "temurin" 17 | java-version: 21 18 | - uses: sbt/setup-sbt@v1 19 | - run: sbt "; scalafmtCheckAll; scalafmtSbtCheck" 20 | 21 | test: 22 | runs-on: ubuntu-latest 23 | steps: 24 | - uses: actions/checkout@v4 25 | - uses: actions/setup-java@v4 26 | with: 27 | cache: "sbt" 28 | distribution: "temurin" 29 | java-version: 21 30 | - uses: sbt/setup-sbt@v1 31 | - run: sbt test 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # see also test/files/.gitignore 2 | /test/files/.gitignore 3 | 4 | # *.jar 5 | *~ 6 | 7 | sonatype.sbt 8 | 9 | #sbt 10 | /project/target/ 11 | /project/project/target 12 | 13 | /target/ 14 | /src/jline/target/ 15 | 16 | # target directories for ant build 17 | /build/ 18 | /dists/ 19 | 20 | # other 21 | /out/ 22 | /bin/ 23 | /sandbox/ 24 | 25 | # eclipse, intellij 26 | /.classpath 27 | /.project 28 | /src/intellij/*.iml 29 | /src/intellij/*.ipr 30 | /src/intellij/*.iws 31 | /.cache 32 | /.idea 33 | /.settings 34 | */.classpath 35 | */.project 36 | */.cache 37 | */.settings 38 | 39 | # bak files produced by ./cleanup-commit 40 | *.bak 41 | *.swp 42 | 43 | # from Scalding 44 | BUILD 45 | target/ 46 | lib_managed/ 47 | project/boot/ 48 | project/build/target/ 49 | project/plugins/target/ 50 | project/plugins/lib_managed/ 51 | project/plugins/src_managed/ 52 | 53 | # Generated Images 54 | images/ 55 | 56 | # Auto-copied by sbt-microsites 57 | docs/src/main/tut/contributing.md 58 | 59 | .ruby_version 60 | .vscode 61 | project 62 | -------------------------------------------------------------------------------- /.jvmopts: -------------------------------------------------------------------------------- 1 | # see https://weblogs.java.net/blog/kcpeppe/archive/2013/12/11/case-study-jvm-hotspot-flags 2 | -Dfile.encoding=UTF8 3 | -Xms1G 4 | -Xmx3G 5 | -XX:ReservedCodeCacheSize=250M 6 | -XX:+TieredCompilation 7 | -XX:-UseGCOverheadLimit 8 | -Djava.security.manager=allow 9 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-yaml 6 | args: [--unsafe] 7 | - id: end-of-file-fixer 8 | - id: trailing-whitespace 9 | -------------------------------------------------------------------------------- /.projectile: -------------------------------------------------------------------------------- 1 | -*.semanticdb 2 | -------------------------------------------------------------------------------- /.scalafmt.conf: -------------------------------------------------------------------------------- 1 | version=3.9.4 2 | runner.dialect = scala212 3 | maxColumn = 100 4 | newlines.penalizeSingleSelectMultiArgList = false 5 | align.openParenCallSite = false 6 | rewrite.rules = [AvoidInfix, SortImports, RedundantBraces, RedundantParens, PreferCurlyFors] 7 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | 3 | scala: 4 | - 2.12.10 5 | 6 | jdk: 7 | - openjdk8 8 | 9 | cache: 10 | directories: 11 | - $HOME/.sbt/1.0/dependency 12 
| - $HOME/.sbt/launchers 13 | - $HOME/.ivy2/cache 14 | - $HOME/.sbt/boot 15 | 16 | before_cache: 17 | - du -h -d 1 $HOME/.ivy2/cache 18 | - du -h -d 2 $HOME/.sbt/ 19 | - find $HOME/.sbt -name "*.lock" -type f -delete 20 | - find $HOME/.ivy2/cache -name "ivydata-*.properties" -type f -delete 21 | 22 | before_install: 23 | - if [ "$TRAVIS_BRANCH" = "master" -a "$TRAVIS_PULL_REQUEST" = "false" ]; then bash scripts/decrypt-keys.sh; fi 24 | - export PATH=${PATH}:./vendor/bundle 25 | 26 | install: 27 | - rvm use 2.6.0 --install --fuzzy 28 | - yes | gem update --system 29 | - gem install sass 30 | - gem install jekyll -v 3.8.5 31 | 32 | script: 33 | - sbt ++$TRAVIS_SCALA_VERSION clean coverage test coverageReport docs/makeMicrosite 34 | 35 | after_success: 36 | - bash <(curl -s https://codecov.io/bash) 37 | - if [ "$TRAVIS_BRANCH" = "master" -a "$TRAVIS_PULL_REQUEST" = "false" ]; then bash scripts/publishMicrosite.sh; fi 38 | - if [ "$TRAVIS_PULL_REQUEST" = "true" ]; then echo "Not in master branch, skipping deploy and release"; fi 39 | 40 | notifications: 41 | webhooks: 42 | urls: 43 | - https://webhooks.gitter.im/e/d54a74117bc77b928e1f 44 | on_success: change 45 | on_failure: always 46 | on_start: never 47 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to ScalaRL 2 | 3 | This page lists recommendations and requirements for how to best contribute to ScalaRL. We strive to obey these as best as possible. As always, thanks for contributing - we hope these guidelines make it easier and shed some light on our approach and processes. 4 | 5 | ## Key branches 6 | 7 | - `master` is the latest, deployed version. 8 | - `develop` is where development happens and all pull requests should be submitted. 9 | 10 | ## Pull requests 11 | 12 | Submit pull requests against the `develop` branch. Try not to pollute your pull request with unintended changes. Keep it simple and small. 13 | 14 | ## Contributing Tests 15 | 16 | We don't have strong conventions around our tests, but here are a few guidelines that might help. 17 | 18 | ### Scalacheck Properties 19 | 20 | If you're adding [scalacheck](https://scalacheck.org/) properties... hold tight, more coming soon. Here's an example of an example: 21 | 22 | ```scala 23 | package com.scalarl 24 | 25 | import org.scalacheck.Prop 26 | 27 | class ExampleLaws extends ??? { 28 | // Fill in! 29 | } 30 | ``` 31 | 32 | ### Scalatest 33 | 34 | We use [scalatest](http://www.scalatest.org/) for all of our other tests. 35 | 36 | ## Contributing Documentation 37 | 38 | The documentation for ScalaRL's website is stored in the `docs/src/main/tut` directory of the [docs subproject](https://github.com/sritchie/scala-rl/tree/develop/docs). 39 | 40 | ScalaRL's documentation is powered by [sbt-microsites](https://47deg.github.io/sbt-microsites/) and [tut](https://github.com/tpolecat/tut). `tut` compiles any code that appears in the documentation, ensuring that snippets and examples won't go out of date. 41 | 42 | We would love your help making our documentation better. If you see a page that's empty or needs work, please send us a pull request making it better. If you contribute a new data structure to ScalaRL, please add a corresponding documentation page. 
To do this, you'll need to: 43 | 44 | - Add a new Markdown file to `docs/src/main/tut/datatypes` with the following format: 45 | 46 | ```markdown 47 | --- 48 | layout: docs 49 | title: "" 50 | section: "data" 51 | source: "scala-rl-core/src/main/scala/io/samritchie/rl/.scala" 52 | scaladoc: "#scalarl." 53 | --- 54 | 55 | # Your Data Type 56 | 57 | ..... 58 | ``` 59 | 60 | - Make sure to add some code examples! Any code block of this form will get compiled using `tut`: 61 | 62 | 63 | ```toot:book 64 | 65 | ``` 66 | 67 | (Please replace `toot` with `tut`!) `tut` will evaluate your code as if you'd pasted it into a REPL and insert each line's results in the output. State persists across `tut` code blocks, so feel free to alternate code blocks with text discussion. See the [tut README](https://github.com/tpolecat/tut) for more information on the various options you can use to customize your code blocks. 68 | - Add your page to the appropriate section in [the menu](https://github.com/sritchie/scala-rl/tree/develop/docs/src/main/resources/microsite/data/menu.yml) 69 | 70 | ### Generating the Site 71 | 72 | run `sbt docs/makeMicrosite` to generate a local copy of the microsite. 73 | 74 | ### Previewing the site 75 | 76 | 1. Install jekyll locally, depending on your platform, you might do this with any of the following commands: 77 | 78 | ``` 79 | yum install jekyll 80 | apt-get install jekyll 81 | gem install jekyll 82 | ``` 83 | 84 | 2. In a shell, navigate to the generated site directory in `docs/target/site` 85 | 3. Start jekyll with `jekyll serve --incremental` 86 | 4. Navigate to http://127.0.0.1:4000/scala-rl/ in your browser 87 | 5. Make changes to your site, and run `sbt docs/makeMicrosite` to regenerate the site. The changes should be reflected as soon as `sbt docs/makeMicrosite` completes. 88 | 89 | ## Post-release 90 | 91 | After the release occurs, you will need to update the documentation. Here is a list of the places that will definitely need to be updated: 92 | 93 | * `README.md`: update version numbers 94 | * `CHANGES.md`: summarize changes since last release 95 | 96 | (Other changes may be necessary, especially for large releases.) 97 | 98 | You can get a list of changes between release tags `v0.1.2` and `v0.2.0` via `git log v0.1.2..v0.2.0`. Scanning this list of commit messages is a good way to get a summary of what happened, although it does not account for conversations that occurred on Github. (You can see the same view on the Github UI by navigating to .) 99 | 100 | Once the relevant documentation changes have been committed, new [release notes](https://github.com/sritchie/scala-rl/releases) should be added. You can add a release by clicking the "Draft a new release" button on that page, or if the relevant release already exists, you can click "Edit release". 101 | 102 | The website should then be updated via `sbt docs/publishMicrosite`. 103 | 104 | ## License 105 | 106 | By contributing your code, you agree to license your contribution under the terms of the [APLv2](LICENSE). 107 | -------------------------------------------------------------------------------- /PLAN.md: -------------------------------------------------------------------------------- 1 | # The Plan 2 | 3 | What am I going to need to make? 4 | 5 | - various games and worlds 6 | - MDP code with proper interfaces 7 | - 10-arm test harness... 
8 | - some interface for the world to interact 9 | - various programming demos, like the racetrack thing 10 | 11 | Do I need to bring some probabilistic programming ideas into the fold here? Maybe not at first, maybe that's too turbo? 12 | 13 | Each of these needs to have its own visualization. 14 | 15 | Can I go totally overboard and integrate this with a website where you can interact? Where you can actually have some way of playing the games, or setting the parameters? 16 | 17 | ## Probabilistic Programming 18 | 19 | Does this help me at all? https://github.com/stripe/rainier/blob/develop/docs/tour.md 20 | 21 | ## TODO 22 | 23 | Things I felt at one point were important.. 24 | 25 | - make DecayState work with a RING, not with anything so generic! And 26 | specialize it. (investigate what this means.) 27 | - we want the aggregator that currently deals with Value instances to take a 28 | Double only in the case with gamma = 1.0, a Left(instance) in the case where 29 | gamma = 0.0, and some generic thing... 30 | - rewrite ActionValueMap in terms of a default and a base. 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Functional RL in Scala 2 | 3 | [![Build status](https://img.shields.io/travis/sritchie/scala-rl/develop.svg?maxAge=3600)](http://travis-ci.com/sritchie/scala-rl) 4 | [![Codecov branch](https://img.shields.io/codecov/c/github/sritchie/scala-rl/develop.svg?maxAge=3600)](https://codecov.io/github/sritchie/scala-rl) 5 | [![Latest version](https://index.scala-lang.org/sritchie/scala-rl/scala-rl-core/latest.svg?color=orange)](https://index.scala-lang.org/sritchie/scala-rl/scala-rl-core) 6 | [![Gitter](https://badges.gitter.im/ScalaRL/community.svg)](https://gitter.im/ScalaRL/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) 7 | [![Patreon](https://img.shields.io/badge/patreon-donate-blue.svg)](https://www.patreon.com/sritchie) 8 | 9 | ### Overview 10 | 11 | Reinforcement Learning in Scala, the functional way. 12 | 13 | I definitely don't need to go fully overboard... but the gold standard is to reimplement a bunch of this stuff: 14 | 15 | https://github.com/ShangtongZhang/reinforcement-learning-an-introduction 16 | 17 | in Scala. 18 | 19 | ## Notes 20 | 21 | Can we write the update steps in some interesting way? Like, you have a function that you pass an action to, and eventually it returns some reward? I think so! 22 | 23 | ## Blog Series 24 | 25 | This code supports the blog series on functional reinforcement learning. 26 | 27 | ## Get Involved 28 | 29 | Want to contribute examples or use this stuff? 30 | 31 | ## Inspiration 32 | 33 | - the book, Reinforcement Learning. 34 | - https://github.com/ShangtongZhang/reinforcement-learning-an-introduction 35 | 36 | ## To File 37 | 38 | - I'm using Rainier's version, but this is a nice article about the 39 | probability Monad: 40 | https://www.chrisstucchio.com/blog/2016/probability_the_monad.html 41 | - We use ScalaFMT https://scalameta.org/scalafmt/docs/installation.html 42 | - also, wartremover http://www.wartremover.org 43 | 44 | 45 | ## License 46 | 47 | Copyright 2019 Sam Ritchie. 48 | 49 | Licensed under the [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0). 
50 | -------------------------------------------------------------------------------- /docs/src/main/resources/microsite/data/menu.yml: -------------------------------------------------------------------------------- 1 | options: 2 | 3 | ########################### 4 | # Policies Menu Options # 5 | ########################### 6 | 7 | - title: Policies 8 | url: policies.html 9 | menu_type: policy 10 | 11 | - title: Stochastic 12 | url: policies/stochastic.html 13 | menu_type: policy 14 | menu_section: stochastic 15 | 16 | nested_options: 17 | - title: Random 18 | url: policies/stochastic/random.html 19 | menu_section: stochastic 20 | 21 | nested_options: 22 | - title: Epsilon Greedy 23 | url: policies/stochastic/epsilon_greedy.html 24 | menu_section: stochastic 25 | 26 | ############################# 27 | # State Menu Options # 28 | ############################# 29 | 30 | - title: State 31 | url: state.html 32 | menu_type: state 33 | menu_section: state 34 | 35 | - title: Simple States 36 | url: state/simple.html 37 | menu_type: state 38 | menu_section: simple 39 | 40 | nested_options: 41 | - title: Bandit 42 | url: state/simple/bandit.html 43 | menu_section: simple 44 | 45 | ########################### 46 | # Cookbook Menu Options # 47 | ########################### 48 | 49 | - title: Cookbook 50 | url: cookbook.html 51 | menu_type: cookbook 52 | 53 | ########################### 54 | # Course Menu Options # 55 | ########################### 56 | 57 | - title: Functional RL in Scala 58 | url: course.html 59 | menu_type: course 60 | 61 | - title: Tabular Methods 62 | url: course/tabular.html 63 | menu_type: course 64 | menu_section: tabular 65 | 66 | nested_options: 67 | - title: Multi-armed Bandits 68 | url: course/tabular/bandits.html 69 | menu_section: tabular 70 | 71 | - title: Finite Markov Decision Processes 72 | url: course/tabular/finite_mdps.html 73 | menu_section: tabular 74 | 75 | - title: Dynamic Programming 76 | url: course/tabular/dynamic_programming.html 77 | menu_section: tabular 78 | 79 | - title: Monte Carlo Methods 80 | url: course/tabular/monte_carlo.html 81 | menu_section: tabular 82 | 83 | - title: Temporal Difference Learning 84 | url: course/tabular/td_learning.html 85 | menu_section: tabular 86 | 87 | - title: n-step Bootstrapping 88 | url: course/tabular/n_step_bootstrapping.html 89 | menu_section: tabular 90 | 91 | - title: Planning and Learning with Tabular Methods 92 | url: course/tabular/tabular_methods.html 93 | menu_section: tabular 94 | -------------------------------------------------------------------------------- /docs/src/main/resources/microsite/img/agent_env.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frogrocketlabs/scala-rl/cc02d7a46cc75436cdb2eaa41cd9f13cc97c3391/docs/src/main/resources/microsite/img/agent_env.png -------------------------------------------------------------------------------- /docs/src/main/resources/microsite/img/policy_iteration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frogrocketlabs/scala-rl/cc02d7a46cc75436cdb2eaa41cd9f13cc97c3391/docs/src/main/resources/microsite/img/policy_iteration.png -------------------------------------------------------------------------------- /docs/src/main/resources/microsite/js/mathjax.js: -------------------------------------------------------------------------------- 1 | (function () { 2 | var head = document.getElementsByTagName("head")[0], script; 
3 | script = document.createElement("script"); 4 | script.type = "text/x-mathjax-config"; 5 | script[(window.opera ? "innerHTML" : "text")] = 6 | "MathJax.Hub.Config({\n" + 7 | " tex2jax: { inlineMath: [['$','$'], ['\\\\(','\\\\)']], processEscapes: true},\n" + 8 | " TeX: { equationNumbers: { autoNumber: \"AMS\" } }\n" + 9 | "});"; 10 | head.appendChild(script); 11 | script = document.createElement("script"); 12 | script.type = "text/javascript"; 13 | script.src = "https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-MML-AM_CHTML"; 14 | head.appendChild(script); 15 | })(); 16 | -------------------------------------------------------------------------------- /docs/src/main/tut/cookbook.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: docs 3 | title: "Cookbook" 4 | section: "cookbook" 5 | position: 6 6 | --- 7 | 8 | {% include_relative cookbook/cookbook.md %} 9 | 10 | ## Index 11 | 12 | {% for x in site.pages %} 13 | {% if x.section == 'cookbook' %} 14 | - [{{x.title}}]({{site.baseurl}}{{x.url}}) 15 | {% endif %} 16 | {% endfor %} 17 | -------------------------------------------------------------------------------- /docs/src/main/tut/cookbook/cookbook.md: -------------------------------------------------------------------------------- 1 | # Cookbook 2 | 3 | In Progress - a cookbook of things you might like to do with ScalaRL. 4 | -------------------------------------------------------------------------------- /docs/src/main/tut/course.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: docs 3 | title: "Functional RL Course" 4 | section: "course" 5 | position: 1 6 | --- 7 | 8 | # Functional RL in Scala 9 | 10 | {% include_relative course/tabular/warning.md %} 11 | 12 | # Notes 13 | 14 | This is the main page for the course! Check in periodically for updates; I'm working on the code first, then looping back to the course. 15 | 16 | # What's this? 17 | 18 | What is this course all about? Why get into it. 19 | 20 | Each chapter review will actually be a showcase of some software I've built to functionally show off these goodies, based on the concepts in the books. Each will have links to other interpretations of this stuff. Kind of a literate programming deal. 21 | 22 | And, boom, I'm actually back working here on what I'm hoping will be a wonderful demo of what I've learned! 23 | 24 | * First step is to go through and summarize the chapters, as I've been doing. For each one I can note what I'd like to be able to program and show off. The goal is to make some sort of original contribution, while at the same time creating teaching resources. But the goal is not to go completely off the deep end with this earlier stuff. 25 | 26 | ## What is Reinforcement Learning? 27 | 28 | Would be nice to have some examples of the types of applications that everyone knows about. 29 | 30 | # Why am I writing this guide? 31 | 32 | Well, the abstractions don't seem... terribly well done. They're focused on getting results; I wanted to see if by re-interpreting them in a typed language I'd gain some insight into what the core concepts were, what the interfaces are that are having their implementations swapped out as we move higher up the ladder of power. 33 | 34 | This area is different from most of the supervised learning methods that I've dealt with in previous lives, and from unsupervised learning. 
35 | 36 | I think all the money is in the second half of the book, but the first half of the book has some wonderful abstractions. 37 | 38 | And I think, by the way, that I need to get the code for this locked down before I make the move of starting to publish these guides. 39 | 40 | I also want to share my excitement about the things that I've been learning! Something is driving me to absorb and then spread the love, here. Why not talk more about it? 41 | 42 | # Notes on the Book's Introduction 43 | 44 | The book is focused on core, online learning algorithms in reinforcement learning. 45 | 46 | If I'm going to summarize this, the key points are... 47 | 48 | * what kinds of problems can this type of RL solve? What is novel here? 49 | * What is the formal, mathematical framework presented to analyze these methods? 50 | * What is missing? What should you expect? 51 | 52 | When you come up with a mathematical model, you have a problem on your hands. Does the way that we, humans, learn, actually work this way? Are we a special case of the model? Or are we something different, are we and the model both special cases? When you come up with an area of study you're trying to get above what we're doing, and then see what application humans actually are. 53 | 54 | > In this book we explore a computational approach to learning from interaction. (p. 1) 55 | 56 | Important to keep the distinction clear between the three things - a problem, a class of solution methods, and the field that studies the problem and its solution methods. The term! 57 | 58 | What is reinforcement learning? And where did these ideas come from? 59 | 60 | * comment on _explore / exploit_ tradeoff... 61 | * Reinforcement learning considers an entire agent, not just some subproblem. We at the outset have the entire agent coordinating with a world, either through a model or through interaction. (How does interact with the embedded agent story?) 62 | 63 | > One must look beyond the most obvious examples of agents and their environments to appreciate the generality of the reinforcement learning framework. (p. 4) 64 | 65 | Okay, lots of great examples of what you might do with RL... 66 | 67 | these all involve interaction between an agent and an environment, moving toward some goal, despite uncertainty about the environment. At some point you need feedback, but feedback might not come for a while. How do you explore in the meantime, and track what you're doing? 68 | 69 | ## Elements 70 | 71 | What are the main elements? Goals, rewards, value, and potentially a model. 72 | 73 | What are the limitations? I think maybe this book is not the best place to find the most troubling limitations, but maybe it is. 74 | 75 | Are we stuffing too much into this model? Well, the book does not concern itself with what is going on with the state signal, for example. How do we get a reward? Who chooses it? And... for humans, we seem to be able to do both. So what are we missing? This is a major problem with Hadoop, etc, by the way. The biggest challenges in industry are around, how do I get my input data in a nice format, and what do I do with the results? It's very easy to write about how to process data and learn models. That's maybe the most interesting thing... but not the biggest challenge. 76 | 77 | Levels of learning, from the Tegmark book in Life 3.0 78 | 79 | * level 1 - learn in the DNA, over generations. 80 | * level 2 - learn from experience 81 | * level 3 - modify yourself! 82 | 83 | Reinforcement learning seems to be about level 2. 
84 | 85 | But.. that is sort of arbitrary too! 86 | 87 | Okay, tic tac toe example... how does something, a game, fit within the presented model? It's interesting that they present a game against an opponent, for their example, but then MOST of the examples in the first half, because we're dealing with online learning, maybe, are about an agent playing against an environment. How would you train a model against an opponent? That has to come later. Do they both learn at the same time? 88 | 89 | NOTE if I code this up that this is something a little funky. This is an after-state value function model, that knows that there are symmetries... we're walking around a graph, and there are multiple ways to get to a particular state. 90 | 91 | ## Unfiled Elements 92 | 93 | * training and running, what are the counterparts for reinforcement learning? There is prediction and control. It's an online problem, so they're a little mixed up, anyway. But it would be nice to introduce the vocab etc in a separate post before getting after it with the individual ones. 94 | * This is a nice example of how you can solve a bigger problem by breaking it into smaller problems. Who knows how the research actually progressed? But the big thing is hard to understand, and writing a book that slowly introduces degrees of freedom is a wonderful thing. 95 | * For all the posts, we need to talk about our need to converge. The goal is that if the environment doesn't change, we want to hone in and slowly gain more knowledge about what's happening. We also want to explore, in case the environment starts to change again, which it surely will. 96 | 97 | ## Other Resources 98 | 99 | Books, etc... what to read? 100 | -------------------------------------------------------------------------------- /docs/src/main/tut/course/tabular.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: docs 3 | title: "Tabular Methods" 4 | section: "course" 5 | --- 6 | 7 | # Tabular Methods 8 | 9 | {% include_relative tabular/warning.md %} 10 | 11 | # Notes on Part 1: Tabular Solution Methods 12 | 13 | This subsection covers the first half of the RL book. I'll update this as I figure out more about how we want to structure the course. 14 | 15 | For each of these, at first, I don't think I need to get so turbo into exactly what is going on. Just flesh out what we build up to and why it's an important idea, then make a list of what I might want to code here. 16 | 17 | I think when I do my programming exercises, I want to note... what exactly are we building up to, here? What are we making? 18 | 19 | The endgame is what's presented in chapter 8. The book nicely layers various concerns, but I think because of that loses the idea that the interfaces are common between all of the ideas. 20 | 21 | SO I think I need to implement each of these things, and then go back and fill in the various implementations. What's the simplest way we could implement blah? How do we evaluate this stuff? 22 | 23 | Then how do we make it all more complex? 24 | 25 | Is there a way to collect stats as we go ahead and train? We're training... but what else might we want to know? 26 | 27 | These are just great examples of how I'm able to think about research engineering problems, I think. 28 | 29 | The overview of this section is: 30 | * how do we learn anything at all when there's a single state, let alone multiple? Turns out this is an area called "multi-armed bandits". 
This is an overview that we maybe used at Twitter to deploy models. If you can get realtime feedback (like with an ads system!) you can use a system like this. 31 | * Then we expand quickly out to multiple states and apply the markov decision process framework, which we then use for the rest of the tutorial / book. 32 | * Next, chapters 4-6, go into dynamic programming, monte carlo methods, and then, finally, temporal-difference learning. These are all versions of the same method - play a game to conclusion. 33 | * Then, the final two chapters go into how to combine these various methods. I don't remember what is happening at those later steps, but I'm excited to go back and find out! Chapter 8 is a full-on combination... and, I think, uses the combination method to create a nice summary of how to combine all of these methods. 34 | -------------------------------------------------------------------------------- /docs/src/main/tut/course/tabular/bandits.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: docs 3 | title: "Multi-armed Bandits" 4 | section: "course" 5 | --- 6 | 7 | # Multi-Armed Bandits 8 | 9 | {% include_relative warning.md %} 10 | 11 | # Notes 12 | 13 | This comes from chapter 2 of the RL book, and is part of my series on reinforcement learning. 14 | 15 | This chapter introduces the kind of thing we're going to be dealing with. What is my unique contribution here? The problem framework, its types, are interesting, and the way you play. By abstracting out the structure, you can go to town on the various pieces that we plug in to the game. Much of the innovation presented in the RL book is various tweaks on the plugins. 16 | 17 | ## Discussion 18 | 19 | Remember, for each of these, the goal should be to come up with a one paragraph description of what the chapter is about. 20 | 21 | Example of a style of method that does NOT contextualize... but has some relationships to what comes later. 22 | 23 | how do we learn anything at all when there's a single state, let alone multiple? Turns out this is an area called "multi-armed bandits". This is an overview that we maybe used at Twitter to deploy models. If you can get realtime feedback (like with an ads system!) you can use a system like this. 24 | 25 | A key point here is evaluative vs instructive feedback. 26 | 27 | This is not meant to be a total exploration of the bandit problem in Scala. They're just trying to use this to introduce the later ideas... so I need to use this chapter to introduce simple versions of the interfaces that will later reign. That is going to be my contribution here. 28 | 29 | What's the objective here? You get these levers, you do something, you get a reward. How do you act so as to maximize the expected reward? 30 | 31 | Already the assumptions start flying, as, of course, they must. You're making the assumption that there is some logic to the rewards! That it's not totally random. Does this make sense? Well... how else can you act, I guess? 32 | 33 | Now we start introducing some terminology that will come in later - the $q^*(a)$ function, the action value. There is no state yet. 34 | 35 | There is the problem, always here, of balancing exploration vs exploitation. 36 | 37 | ### Action Value 38 | 39 | Sample average... then select the best. Seems reasonable! 40 | 41 | Already we have a place where we can start to put in different functions. What about the latest? 
This is already a simplified version of the more difficult thing, when you've got a state. 42 | 43 | This section introduces how to think about what exactly you're trying to maximize. Implicit is that you're keeping a table around, and updating values in it, then inspecting some range in that table. 44 | 45 | This give you the ability to choose exploit / explore. 46 | 47 | ### 10-armed Testbed 48 | https://datascience.stackexchange.com/questions/15415/what-is-the-reward-function-in-the-10-armed-test-bed 49 | 50 | A demo of the 10 armed testbed in python here: https://github.com/SahanaRamnath/MultiArmedBandit_RL 51 | 52 | ### Incremental Implementation 53 | This is the key to online learning. How can we write an interface to cover what's happening here? 54 | 55 | Kind of lame to cover this... link to Algebird. We want an aggregator that can absorb new rewards into the old estimate. Turns out there is a more general idea at play here. 56 | 57 | ### Non-stationary problem 58 | 59 | What is interesting about this? Well, we want to start ignoring old information. 60 | p. 33 talks about convergence guarantees, and how we don't actually want to converge if we want to be open to new information. 61 | 62 | Exponential recency-weighted average is a nice way around this. This is a way of swapping in a different implementation of the aggregator! 63 | 64 | Another way would be sliding windows. Can we go into that? Why would you want that? You... wouldn't, as much. But the door is open now to different ways of exploring. 65 | 66 | ### Optimistic Initial Values 67 | 68 | This strikes me as a goofy way to solve the problem, but a nice way to get the system to explore the space, if you're going to stick within the bounds supplied. How do you FORCE the system to choose all options, to at least explore everything once? One way is just to make everything irresistable. 69 | 70 | ### UCB Action Selection 71 | 72 | This is a way of prioritizing states that haven't been explored in a while. Can I get this in to my interface? I bet you could add a term that's calculated based on how many times we've actually seen something... like, you track a count of times you've chosen vs rewards or something. Oh, you can. It's just the denominator. So I think you have everything here to implement this stuff behind various interfaces. 73 | 74 | Then picking what you want becomes... a hyperparameter optimization. Or, at least, you can investigate how these various systems work. 75 | 76 | ### Gradient Bandit Algorithms 77 | 78 | I don't know if I need to go into this one, as it doesn't come up much later. No need to cover absolutely everything in the book. But give it a look, see if it fits within the model. 79 | 80 | But the idea here is that the reward itself maybe never changes for certain states, which doesn't work well with the idea of gradient descent, where you continually shift and then stop when you've converged. You want to assign scores to each state, and have some way of shifting those scores. 81 | 82 | ### Contextual Bandits. 83 | 84 | They mention this as a next step... but maybe not that important of a step. It's similar to some later stuff, but, remember, we have the problem of the action only affecting the immediate reward. Or, rather, the reward that comes from the action is the only thing that updates our estimate, NOT any indication of what state is going to come next. The dynamics are slightly different. 85 | 86 | ## Concluding Thoughts 87 | 88 | From the conclusion... 
89 | 90 | This is really the state of the art stuff, and extends nicely, the various implementations extend nicely, to the big problems we're going to tackle in the next sections. 91 | 92 | ## Programming Exercises 93 | 94 | - single state version of the later markov stuff 95 | - interface to plug in how reward aggregation happens 96 | - 10-arm testbed... 97 | - some way of graphing the results 98 | - exercise 2.5 - design and conduct an experiment to demonstrate the difficulties that sample-average methods have for nonstationary problems.... etc 99 | - Optimistic initial values? Return a high score if you haven't been seen yet? A better default result? A "zero" for the aggregator!! Well... it breaks the monoid, of course, but that's okay. 100 | - UCB action selection implementation. 101 | - MAYBE a gradient descent thing? 102 | - Exercise 2.11 - generate a parameter study for the nonstationary case laid out in exercise 2.5. This is going to get pretty turbo... but if I can do all this I'm going to be at the state of the art, no question about it. We'll see... 103 | 104 | ## Further Reading 105 | 106 | - I DID read this back at Twitter... but it's not great. https://amzn.to/31thYiB 107 | - Go find some great blog posts highlighting this stuff. 108 | - DeepMind guys working on a book: https://tor-lattimore.com 109 | - Book site: https://banditalgs.com 110 | - Here's another article on bandits. https://medium.com/@tamoghnaghosh_30691/multi-armed-bandits-a-naive-form-of-reinforcement-learning-a133c8ec19be 111 | -------------------------------------------------------------------------------- /docs/src/main/tut/course/tabular/finite_mdps.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: docs 3 | title: "Finite Markov Decision Processes" 4 | section: "course" 5 | --- 6 | 7 | # Finite Markov Decision Processes 8 | 9 | {% include_relative warning.md %} 10 | 11 | # Notes 12 | 13 | There is not too much to go on here for programming exercises, which is nice. This is the shell for everything else. There's really just the gridworld. The Python example does this by hardcoding everything in. 14 | 15 | Can I make the gridworld really nice? Is there some interface that a game can satisfy... 16 | 17 | I mean, embedding it in a webpage is the best thing I can do. You set up your gridworld and hit "train" and boom, it gives you something solid. Is this worth doing? Only if it's easy! 18 | 19 | Obviously good to go over the summary again on p. 69 before publishing anything here. 20 | 21 | ### Historical Notes 22 | 23 | This is covered on page 69, but maybe do some looking so we can give background on where these ideas came from. 24 | 25 | ### Summary 26 | 27 | * we get the model for what we're going to talk about next, and a ton of the vocab. We're NOT talking about how to set up problems to be solved by this framework... but now we know about rewards, discounting, action-value functions, state-value functions etc, and we have some hints about how we might implement this. 28 | * The idea of how large the state space is obviously matters. For goofy examples like gridworld... it's super easy to see what to do. For something more complicated, which is anything in the real world, anything really interesting, you've got to get more turbo and use the approximation methods in the second half of the book. 29 | * But we get an intro to what training vs prediction are in this world. It is all about decisions that step the state of the world forward. 
And then you've got a hovering process, watching all of this and integrating information learned back into the relevant states. How far back do you get to remember? We'll talk about that later. 30 | 31 | Okay, this chapter introduces a bunch of different concepts, building up to this problem of... how do we model a system where different states present different opportunities? 32 | 33 | * the finite markov decision process! 34 | 35 | > These state-dependent quantities are essential to accurately assigning credit for long-term consequences to individual action selections. (p. 47) 36 | 37 | One q... they address this in chapter 17, but how do we go beyond the mathematical formalism? Are there things that you can only solve if you go beyond? 38 | 39 | What is the "social psychology" field going to look like for AI and machine learning? How do things change when you can inspect inside the black box? 40 | 41 | Next... the agent-environment interface. 42 | 43 | This is a mathematical model, remember. You have some inner thing, the agent, that is interacting with a totally unknown world. I guess you could build a model of the world, but it's assumed that the world is not inspectable except via the state representation you've built. 44 | 45 | Something like this. 46 | 47 | ![Agent Environment Interaction](/img/agent_env.png) 48 | 49 | * then we introduce the function $p(s', r \| s, a)$, which defines the _dynamics_ of the system. That term is interesting, since it shows up in all sorts of places. Dynamics means, the equation governing what the system will do next. In physics often the variable is time... here time is replaced by the idea of taking an action. That is what ticks the clock forward. This is really a function though. 50 | 51 | The art is defining the boundary. 52 | 53 | > In practice, the agent-environment boundary is determined once one has selected particular states, actions and rewards, and thus has identified a specific decision making task of interest. (p. 50) 54 | 55 | HUGE abstraction. Now lots of examples of stuff that fits into it. 56 | 57 | Exceptions? when the state is incomplete and unknowable; random errors; or maybe... you wouldn't want to use it when there is in fact a baked in physics model? Maybe not? That can come in later, as part of the thing calculating the value. 58 | 59 | The reward hypothesis: 60 | 61 | > That all of what we mean by goals and purposes can be well thought of as the maximization of the expected value of the cumulative sum of a received scalar signal. (p. 53) 62 | 63 | This is a little clunky. What it's saying is that the goal is to maximize the total number of cocaine drips you're going to receive in the future. 64 | 65 | * "expected value of the cumulative sum" - you're guessing. You also want the total number over time to be big, not just tomorrow. 66 | 67 | So you have to craft an environment that produces rewards that send you toward a particular goal. 68 | 69 | Of course this the alignment problem, right here. What if you fuck it up? 70 | This is called _specification gaming_: https://docs.google.com/spreadsheets/u/1/d/e/2PACX-1vRPiprOaC3HsCf5Tuum8bRfzYUiKLRqJmbOoC-32JorNdfyTiRRsR7Ea5eWtvsWzuxo8bjOxCG84dAg/pubhtml 71 | 72 | That list is a bunch of amazing examples of the potential fuckups available. 73 | 74 | > The reward signal is your way of communicating to the robot what you want it to achieve, not how you want it achieved. (p. 54) 75 | 76 | Then we talk about expected return, and how to discount items in the future. 
The whole trick is going to be how to pass back information about how the decision affected future states. 77 | 78 | * episodic and continuing tasks are the same... 79 | 80 | Then, the meat. All of reinforcement learning, most of it, anyway, is about estimating value functions. 81 | 82 | State-value and action-value functions are the two big ones. You have some choices. What are you going to do next? 83 | 84 | * _Policy_: a mapping from states to probabilities of selecting each possible action. The policy knows what it is possible to do next. 85 | 86 | The policy uses the value function to decide what is best... but the value function has to assume some policy behavior to know what the policy is. Well, you don't HAVE to... you can have an action-value function that does not assume. But state-value functions, if you just want to know "how good or bad is my current situation?" You need to know how you're going to act. 87 | 88 | If you perform a ton of trials, and visit each state a ton of times and keep track of what happened over many visits... Those are called "monto carlo methods". 89 | 90 | ### Bellman Equation 91 | 92 | This comes up enough in reading that it needs its own section. 93 | 94 | This is a recursive equation that gives the value of the current state as a sum of next states. 95 | 96 | > It states that the value of the start state must equal the (discounted) value of the expected next state, plus the reward expected along the way. (p. 59) 97 | 98 | This is a different kind of dynamics. 99 | 100 | The value function is the unique solution to its Bellman equation. 101 | 102 | Then we get some examples... gridworld, golf. 103 | 104 | ### Optimal 105 | 106 | Can we get to a perfect one? You get this funky back and forth thing going between prediction and control. Often, a crappy policy can still generate the optimal value function, etc. 67-68 talk about this too. 107 | 108 | I guess this section also talks about how it's easy to compute an optimal policy once you have an optimal value function, and you can ping back and forth and see what's going on. 109 | 110 | ## Programming Exercises 111 | 112 | * Demonstrate the interfaces required to implement this stuff, lock them all down and discuss the concepts using the interfaces as the hook. 113 | * Gridworld? 114 | * Golf example? maybe not... 115 | * Note that when we code this puppy up, if we can come up with some way of indexing the state that is not just a unique ID, or where it came from, then we can start to piggyback with the afterstate example. If we can get to multiple tic-tac-toe pairs from the same spot... 116 | -------------------------------------------------------------------------------- /docs/src/main/tut/course/tabular/n_step_bootstrapping.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: docs 3 | title: "n-step Bootstrapping" 4 | section: "course" 5 | --- 6 | 7 | # n-step Bootstrapping 8 | 9 | {% include_relative warning.md %} 10 | 11 | # Notes 12 | 13 | Really, this is the main thing we should be coding. And then we can specialize back to get everything else. Build up, as in the book, but I need to know where I'm building toward! 14 | 15 | These let us shift smoothly between the TD(0) method and the full-on, infinite step Monte Carlo methods. We hear another mention of eligibility traces... but I don't know what that is, so let's move on! 
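To make "shift smoothly" concrete, here's a throwaway sketch (made-up names, nothing from the actual codebase) of the n-step target: with n = 1 and a bootstrapped value you're doing TD(0); let n run to the end of the episode with no bootstrap and you're back to Monte Carlo.

```scala
// Sketch only: hypothetical helpers, not the library's API.
object NStepSketch {
  // G_{t:t+n} = R_{t+1} + gamma * R_{t+2} + ... + gamma^(n-1) * R_{t+n} + gamma^n * V(S_{t+n})
  def nStepReturn(rewards: Seq[Double], bootstrap: Double, gamma: Double): Double = {
    val discounted = rewards.zipWithIndex.map { case (r, i) => math.pow(gamma, i) * r }.sum
    discounted + math.pow(gamma, rewards.size) * bootstrap
  }

  // Constant-alpha step toward that target; most of this chapter is a variation
  // on which target gets fed in here and how it gets weighted.
  def update(current: Double, target: Double, alpha: Double): Double =
    current + alpha * (target - current)
}
```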
16 | 17 | n-step methods allow bootstrapping to occur very fast, to get online behavior, but also if we have time to start extending updates out into the future, to get the benefit of the long view. 18 | 19 | * Prediction first, 20 | * then control. 21 | 22 | We start with n-step Sarsa... but then we come up with some other methods that are in fact extensions of a grand unifying thing! That I want to go into. 23 | 24 | We sort of have to solve for THIS stuff... and then all of the previous work becomes a generalization. Well, sort of. Dynamic programming becomes a... is it a special case, still? I guess not since we don't have the dynamics at play. But if we do then we can use them. That is what comes back in chapter 8. 25 | 26 | ## Chapter Sketch 27 | 28 | The individual chapters. What is up? 29 | 30 | ### n-step TD Prediction 31 | 32 | Nice picture of the backup diagrams, which I still don't really get / like. What do we do if we want to extend beyond that first step? We need to keep playing a game, generating an episode, but as we play we can start passing info back. 33 | 34 | > For example, a two-step reward would be based on the first two rewards and the estimated value of the state two steps later. (p. 142) 35 | 36 | > The methods are still TD methods because they still change an earlier estimate based on how it differs from a later estimate. (p. 142) 37 | 38 | This still is a constant-alpha method; *BUT NOTE* that we are sneaking in an idea with the constant alpha thing!!! That is definitely NOT the only way to aggregate this stuff! That is just a way to privilege recent games. You can totally have different methods of aggregation. 39 | 40 | Then we look at the random walk task about show how different settings of alpha or n cause different overshoots. Thing to note is that an intermediate value of n worked nicely, halfway between monte carlo and TD. 41 | 42 | This is a state-value updater because the policy already decides what to do... it has some action, presumably, that it is going to take for everything. Or some chance of taking each action that it based on an action-value function that you assume someone else gave you. It definitely has knowledge of the actions, in any case. 43 | 44 | ### n-step Sarsa 45 | 46 | Remember, to come up with a policy we need to switch from state-value to action-value. That way we can use the collection of action-values to modify the policy, potentially. 47 | We still need the guarantee that we're possibly going to explore all states, so for on policy we need an $\epsilon\text{-greedy}$ method. 48 | 49 | Keep the action-value around then update. 50 | 51 | Expected sarsa is easy... just use the expected approximate values at the end. 52 | 53 | ### n-step Off-Policy Learning 54 | 55 | Woah, generalize again! Obviously I have to code this first, and then specialize it back to all of the previous stuff. Show how single state gives you bandits, for example. 56 | 57 | Note that our importance sampling ratio starts and ends one step later than if we were estimating the action-value. 58 | 59 | off-policy n-step expected sarsa too, here. just showing off. 60 | 61 | ### Per-decision methods with Control Variates 62 | 63 | Research topic alert! Look into this more in the exercises, but this should just be a different aggregator. This is a way of saying, okay, my update is actually a weighted average of the training policy's thing, what it learned, and then what the target policy ALREADY has. 
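If I'm reading it right (this is the standard per-decision control-variate form; I haven't coded it up yet), that weighted average is

$$G_{t:h} = \rho_t \left( R_{t+1} + \gamma G_{t+1:h} \right) + (1 - \rho_t) V(S_t)$$

so when the ratio $\rho_t$ is zero, the target falls back to the value we already had instead of collapsing to zero.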
64 | 65 | This is using importance sampling in place of alpha, I think? Think of a way to describe what is happening. But show again that it is just another aggregator. 66 | 67 | ### Off policy learning WITHOUT importance sampling - the n-step tree backup algo 68 | 69 | Is there a non-importance-sampling method that we can use for off-policy learning? We did this for the one-step case before... can we do this for the good stuff, the multi-step? 70 | 71 | You can get some budget and go chase promising stuff, unfolding as you go, if you have time. 72 | 73 | You're sort of learning from your model. I like it, easy. 74 | 75 | ### n-step $Q(\sigma)$ 76 | 77 | Can we unify all of the algos we've already seen? 78 | 79 | yeah, this is what to actually code, since we can... show that the $\sigma$ parameter can be set anywhere from 0 to 1. 80 | 81 | ## Programming Exercises 82 | 83 | * Note that when we do the constant alpha thing, that is just a particular way to aggregate. If we used $1 / n$ then we'd be doing a straight-up average. But you can weight different games differently, even, instead of going by time. The reason to weight games played more recently is that you've got more experience that went into choosing that trajectory. 84 | * Maybe when you average in the monoid you do a different thing - you weight by how many games you've played, instead of by alpha. 85 | * exercise 7.2... plug in the "sum of TD errors" thing from 7.1! That is actually a DIFFERENT way of propagating the information around. Implement this as an aggregator as well. 86 | * implement n-step sarsa 87 | * expected n-step sarsa 88 | * q-learning 89 | * Can we do a sliding window product? I think so right? Since products are just repeated addition? Well, that's a nice extension. 90 | * exercise 7.10, implement the control variates thing. 91 | * off-policy without importance sampling... tree backup algo! 92 | * n-step q-sigma 93 | -------------------------------------------------------------------------------- /docs/src/main/tut/course/tabular/tabular_methods.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: docs 3 | title: "Planning and Learning with Tabular Methods" 4 | section: "course" 5 | --- 6 | 7 | # Planning and Learning with Tabular Methods 8 | 9 | {% include_relative warning.md %} 10 | 11 | ## Notes 12 | 13 | iirc this is more of a summary of everything that came before, yes, uniting it into the big approach. 14 | 15 | ## Chapter Sketch 16 | 17 | There is a ton in here. Does this unify, or should I chill out and go back to the good stuff? Try to get a feel for what this is introducing... then I can start coding. 18 | 19 | We start with tabular dyna-Q, then slowly make it more complicated. 20 | 21 | ## Models and Planning 22 | 23 | This introduces the idea that models let you plan but updating your state-value or action-value functions based on predictions... the map is not the territory, but you can increase your understanding of the map. 24 | 25 | We are now back to the idea of a model, vs an actual environment to interact with. How interesting, similar to consciousness, here... we are always interacting with the model, and then some separate process behind the scenes is updating the model to be more like the real world. 26 | 27 | * distribution models 28 | * sample models 29 | 30 | Remember the two? 31 | 32 | Can you just draw a reward/state pair? Or do you get the entire distribution to look at? That's the difference. 
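Here's a quick sketch of that difference with invented trait names (not a real design; the conversion function is the one asked for in the exercises at the bottom of this page):

```scala
import scala.util.Random

// Sketch only: a distribution model hands you the whole distribution over
// (next state, reward) pairs; a sample model only lets you draw one transition.
trait DistributionModel[S, A] {
  def transitions(state: S, action: A): Map[(S, Double), Double] // (s', reward) -> probability
}

trait SampleModel[S, A] {
  def sample(state: S, action: A): (S, Double)
}

object SampleModel {
  // Every distribution model gives you a sample model for free: just draw from it.
  def fromDistribution[S, A](m: DistributionModel[S, A], rng: Random): SampleModel[S, A] =
    new SampleModel[S, A] {
      def sample(state: S, action: A): (S, Double) = {
        val pairs = m.transitions(state, action).toSeq
        val roll = rng.nextDouble()
        val cumulative = pairs.scanLeft(0.0) { case (acc, (_, p)) => acc + p }.tail
        pairs
          .zip(cumulative)
          .collectFirst { case ((outcome, _), c) if roll < c => outcome }
          .getOrElse(pairs.last._1) // guard against floating-point round-off
      }
    }
}
```

Going the other direction - building a model up from experience - is the actual work, and that's what Dyna gets into below.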
33 | 34 | Planning - take either a model or (model, policy) and produces an improved policy. 35 | 36 | * there is also a thing called a PLAN SPACE MODEL, which we don't consider further, but this seems super cool. You have functions from plans to plans?? 37 | 38 | demo - random-sample one-step tabular Q-planning. 39 | 40 | planning in very small steps, too. They love online. 41 | 42 | ### Dyna: Integrated Planning, Acting, and Learning 43 | 44 | Dyna-Q. 45 | 46 | There are two functions for real-world experience: 47 | 48 | * improve the model 49 | * directly improve the value function. 50 | 51 | What is tabular Dyna-Q? (obviously this is an interface, and we should be able to plug in various versions!) 52 | 53 | * planning method? one-step tabular q-planning 54 | * direct RL? one-step tabular Q-learning 55 | * The MODEL is super dumb and just assumes a deterministic transition - whatever happened last time will happen again, for (s, a) => (r, s) 56 | 57 | Okay, I'm going to need to implement this to explain it more clearly. Go back here and get the interface going. 58 | 59 | ## When the Model is Wrong 60 | 61 | The model above can be wrong with stochastic info, or with very little info to build a stochastic model. 62 | 63 | If you have a fucked up model usually you will run into a clash with reality, and that will correct the model. 64 | 65 | A worse thing is when the environment becomes BETTER than it was before, and we don't get to adjust in time. 66 | 67 | Well, we can give a reward bonus. That is a nice way to deal. And implement the measurement i guess that shows how it gets better? 68 | 69 | ## Prioritized Sweeping 70 | 71 | okay, before we were doing a sweep that just randomly selected states from what we saw before. 72 | 73 | We really want to work backward from any state whose value has just changed. That is where we can ripple the updates from. 74 | 75 | Does that imply that we want a distance metric? That would help, right? 76 | 77 | A function to get the set of states that are blah steps away? A graph model really would be nice in this case. 78 | 79 | algorithm on page 170... can we do this better? 80 | 81 | "sample updates", page 171. 82 | 83 | ## Expected vs Sample Updates 84 | 85 | I guess we get here into what sample updates actually are. 86 | 87 | Oh, this is at the end... do we want to just sample at the VERY end, or do we use the expectation give the model? 88 | 89 | What if we sample within quantiles or something? We don't weight everything? That is, again, a way to get a sample model... 90 | 91 | Do we need to do any testing here? 92 | 93 | Maybe just make sure that our code harness COULD test this stuff. 94 | 95 | ## Trajectory Sampling 96 | 97 | Sample according to the current policy, within the model! This is called trajectory sampling. 98 | 99 | vast, boring parts of the state-space are ignored. 100 | 101 | ## Real-time Dynamic Programming 102 | 103 | Another algo that we can talk about, complicated, get into it if we decide to show this off. Page 178 shows the goodies. 104 | 105 | ## Planning at Decision Time 106 | 107 | We've been talking about background planning; you can give a budget here and plan away in the background. 108 | 109 | Woah... what if we use planning, based on the model, to roll forward? Use the CURRENT STATE in the actual game, like chess or something, to decide what move to do? 110 | Do you plan at decision time, when given a state? Or do you plan ahead of time, trying to guess what states you'll be in, and then rapidly get after it? 
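A pair of made-up signatures (a sketch, not a design) to keep the two styles straight: background planning spends its budget improving value estimates for states it guesses we'll visit, while decision-time planning spends the budget on the one state in front of it and just hands back a move.

```scala
// Hypothetical signatures only; `M` stands in for whatever model type we end up with.
trait BackgroundPlanner[M, S, A] {
  // Burn the budget ahead of time, improving value estimates for likely states.
  def plan(model: M, values: Map[S, Double], budget: Int): Map[S, Double]
}

trait DecisionTimePlanner[M, S, A] {
  // Burn the budget on the current state, right now, and return an action.
  def decide(model: M, state: S, budget: Int): A
}
```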
111 | 112 | ## Heuristic Search 113 | 114 | Classic decision time algos! 115 | 116 | Implement this shit, and save the results. 117 | 118 | ## Rollout Algorithms 119 | 120 | > Rollout algorithms are decision-time planning algorithms based on Monte Carlo control applied to simulated trajectories that all begin at the current environment state. (p. 183) 121 | 122 | They do a ton of monte-carlo guesses and then use that info to decide whether or not to store the value functions. 123 | 124 | This comes from Backgammon, where you "roll out" the dice a bunch of times to see what might happen next. 125 | 126 | There is a rollout policy... 127 | 128 | Figure out how this works, implement once I have my existing goodies. 129 | 130 | You want to save info? 131 | 132 | ## Monte Carlo Tree Search 133 | 134 | Well, you should do monte carlo tree search. 135 | 136 | This is an amazing one that I definitely can't wait to implement! And now we are into some serious modern shit. 137 | 138 | ## Summary of Part 1 139 | 140 | Woah, what a haul. 141 | 142 | Lots of different dimensions... p.191 covers these. There is so much to cover in this fucking reinforcement learning writeup! I am going to be such a stud if I can tackle all of this. I just need to be coding and writing all the time, now... And getting my neural network stuff down. 143 | 144 | ## Programming Exercises 145 | 146 | * function to go from a distribution model to a sample model. (these should be graphs so we can do breadth first searches out from areas that have changed.) 147 | * interfaces for sample and distribution models? 148 | * q-planning, page 161 149 | * dyna-Q 150 | * dyna-Q+, with a bonus that encourages exploration. 151 | * have the bonus implemented in different spots... just like in ex 8.4 152 | * trajectory sampling. 153 | * replicate figure 8,8...I have the code, do it in scala, change b=3. 154 | * RTDP on the racetrack, show that off! 155 | * heuristic search. 156 | -------------------------------------------------------------------------------- /docs/src/main/tut/course/tabular/warning.md: -------------------------------------------------------------------------------- 1 | # WARNING! 2 | 3 | Note - this is VERY EARLY DAYS! All of the files in the course with this warning are the raw, totally unprocessed notes that I generated during my first reading of "Reinforcement Learning: An Introduction". 4 | 5 | I'll be converting these into proper course sections with wonderful embedded code that you can try out. I'm not there yet, but I wanted to develop all of this in the open. 6 | 7 | Maybe you'll find these notes interesting, but don't expect anything special... yet! 8 | -------------------------------------------------------------------------------- /docs/src/main/tut/faq.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: page 3 | title: "FAQ" 4 | section: "faq" 5 | position: 4 6 | --- 7 | 8 | ## Frequently Asked Questions 9 | 10 | * [What is the course?](#course) 11 | * [How can I use Reinforcement Learning?](#how) 12 | * [How can I help?](#contributing) 13 | 14 | ### What is the course? 15 | 16 | I'll fill in more later - I'm writing these as placeholders. 17 | 18 | ### How can I use Reinforcement Learning? 19 | 20 | Another placeholder! 21 | 22 | ### How can I help? 23 | 24 | The ScalaRL community welcomes and encourages contributions! 
Here are a few ways to help out: 25 | 26 | - Look at the [code coverage report](https://codecov.io/github/sritchie/scala-rl?branch=develop), find some untested code, and write a test for it. Even simple helper methods and syntax enrichment should be tested. 27 | - Find an [open issue](https://github.com/sritchie/scala-rl/issues?q=is%3Aopen+is%3Aissue), leave a comment on it to let people know you are working on it, and submit a pull request. 28 | 29 | See the [contributing guide]({{ site.baseurl }}/contributing.html) for more information. 30 | -------------------------------------------------------------------------------- /docs/src/main/tut/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: home 3 | title: "Home" 4 | section: "home" 5 | --- 6 | 7 | ScalaRL is a library which provides abstractions for functional reinforcement learning in the [Scala programming language](https://scala-lang.org). 8 | 9 | The project also includes a course based on Sutton et al's book on Reinforcement Learning. 10 | 11 | This site is A WORK IN PROGRESS! 12 | 13 | ### What can you do with this code? 14 | 15 | Notes on usage. 16 | 17 | ## Using ScalaRL 18 | 19 | ScalaRL modules are available on Maven Central. The current groupid and version for all modules is, respectively, `"io.samritchie"` and `0.0.1`. 20 | 21 | See [ScalaRL's page on the Scaladex](https://index.scala-lang.org/sritchie/scala-rl) for information on all published artifacts and their associated Scala versions. ScalaRL currently supports 2.12. 22 | 23 | ## Documentation 24 | 25 | The latest API docs are hosted at ScalaRL's [ScalaDoc index](api/). 26 | 27 | ## Get Involved + Code of Conduct 28 | 29 | Pull requests and bug reports are always welcome! Check out our [Contributing guide](contributing.html) for information on what we most need help with and how you can get started contributing. 30 | 31 | Issues should be reported on the [GitHub issue tracker](https://github.com/sritchie/scala-rl/issues). 32 | 33 | A list of contributors to the project can be found here: [Contributors](https://github.com/sritchie/scala-rl/graphs/contributors) 34 | 35 | ## License 36 | 37 | Copyright 2019 Sam Ritchie. 38 | 39 | Licensed under the [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0). 40 | -------------------------------------------------------------------------------- /docs/src/main/tut/policies.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: docs 3 | title: "Policies" 4 | section: "policies" 5 | position: 3 6 | --- 7 | 8 | {% include_relative policies/policies.md %} 9 | 10 | ## Index 11 | 12 | {% for x in site.pages %} 13 | {% if x.section == 'policies' %} 14 | - [{{x.title}}]({{site.baseurl}}{{x.url}}) 15 | {% endif %} 16 | {% endfor %} 17 | -------------------------------------------------------------------------------- /docs/src/main/tut/policies/policies.md: -------------------------------------------------------------------------------- 1 | # Policies 2 | 3 | The good stuff. 4 | 5 | # What are they? 6 | 7 | Notes about what policies are. 8 | -------------------------------------------------------------------------------- /docs/src/main/tut/policies/stochastic.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: docs 3 | title: "Stochastic Policies" 4 | section: "policies" 5 | --- 6 | 7 | # Stochastic Policies 8 | 9 | Interesting policies that use randomness. 
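The simplest stochastic policy in the library today is the fully random one, which spreads probability uniformly over whatever actions the current state exposes. This is its actual definition in `Policy.scala`:

```scala
// Full exploration: a categorical distribution (Cat) that puts equal weight
// on every action the current state offers.
def random[Obs, A, R, S[_]]: Policy[Obs, A, R, Cat, S] =
  Policy.choose(s => Categorical.fromSet(s.actions))
```

`Policy.epsilonGreedy` then interpolates between this and the fully greedy policy, which is where the probability monad below starts earning its keep.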
10 | 11 | ## Probability Monad 12 | 13 | Notes. 14 | -------------------------------------------------------------------------------- /docs/src/main/tut/policies/stochastic/epsilon_greedy.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: docs 3 | title: "Epsilon Greedy" 4 | section: "policies" 5 | --- 6 | 7 | # Epsilon Greedy Policy 8 | 9 | What is it? 10 | -------------------------------------------------------------------------------- /docs/src/main/tut/policies/stochastic/random.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: docs 3 | title: "Random" 4 | section: "policies" 5 | --- 6 | 7 | # Random 8 | 9 | Random policy notes. A totally random policy. 10 | -------------------------------------------------------------------------------- /docs/src/main/tut/state.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: docs 3 | title: "State" 4 | section: "state" 5 | position: 2 6 | --- 7 | 8 | # State 9 | 10 | These are the various state implementations. 11 | 12 | ## Index 13 | 14 | {% for x in site.pages %} 15 | {% if x.section == 'state' %} 16 | - [{{x.title}}]({{site.baseurl}}{{x.url}}) 17 | {% endif %} 18 | {% endfor %} 19 | -------------------------------------------------------------------------------- /docs/src/main/tut/state/simple.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: docs 3 | title: "Simple States" 4 | section: "state" 5 | --- 6 | 7 | # Simple States 8 | 9 | These states route back to themselves. Maybe this is not necessary. 10 | -------------------------------------------------------------------------------- /docs/src/main/tut/state/simple/bandit.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: docs 3 | title: "Bandit" 4 | section: "state" 5 | source: "scala-rl-core/src/main/scala/io/samritchie/state/Bandit.scala" 6 | scaladoc: "#scalarl.state.Bandit" 7 | --- 8 | 9 | # Bandit 10 | 11 | Bandit information. Talk more here. 
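As a placeholder example, here is a condensed version of how `Chapter2.scala` builds the book's ten-armed stationary testbed and the bandit policies it plays against it (the surrounding plumbing — `playBandit`, averaging, plotting — is elided):

```scala
import com.stripe.rainier.core.Normal
import com.scalarl.policy.bandit.Greedy
import com.scalarl.world.Bandit

// Ten arms; each arm's true value is drawn once from Normal(0, 1), and every
// pull of that arm then returns a reward drawn from Normal(trueValue, 1).
val testbed = Bandit.stationary(
  10,
  Normal(0.0, 1.0).generator.map(mean => Normal(mean, 1.0).generator)
)

// The three settings compared in Chapter2.main, matching the book's Figure 2.2.
val policies = List(0.0, 0.01, 0.1).map(eps => Greedy.incrementalConfig(eps).policy)
```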
12 | 13 | ### Related Code 14 | 15 | These links might be helpful: 16 | 17 | - [Bandit.scala](https://github.com/sritchie/scala-rl/blob/develop/scala-rl-core/src/main/scala/io/samritchie/rl/state/Bandit.scala) 18 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.9.8 2 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | resolvers ++= Seq( 2 | "jgit-repo".at("https://download.eclipse.org/jgit/maven"), 3 | Resolver.url( 4 | "bintray-sbt-plugin-releases", 5 | url("https://dl.bintray.com/content/sbt/sbt-plugin-releases") 6 | )( 7 | Resolver.ivyStylePatterns 8 | ) 9 | ) 10 | 11 | addSbtPlugin("com.47deg" % "sbt-microsites" % "0.9.7") 12 | addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.4.2") 13 | addSbtPlugin("com.github.sbt" % "sbt-release" % "1.0.15") 14 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.0.2") 15 | addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "1.1.4") 16 | addSbtPlugin("com.typesafe.sbt" % "sbt-ghpages" % "0.6.3") 17 | addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.4") 18 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.6.1") 19 | addSbtPlugin("org.wartremover" % "sbt-wartremover" % "3.3.2") 20 | addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.8.1") 21 | addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.4.0") 22 | -------------------------------------------------------------------------------- /scala-rl-book/src/main/scala/com/scalarl/book/Chapter2.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package book 3 | 4 | import cats.implicits._ 5 | import com.stripe.rainier.cats._ 6 | import com.stripe.rainier.core.{Generator, Normal} 7 | import com.stripe.rainier.compute.{Evaluator, Real} 8 | import com.stripe.rainier.sampler.RNG 9 | import com.twitter.util.Stopwatch 10 | import com.scalarl.logic.Episode 11 | import com.scalarl.plot.Plot 12 | import com.scalarl.policy.bandit.Greedy 13 | import com.scalarl.rainier.Categorical 14 | import com.scalarl.world.Bandit 15 | 16 | /** # Introduction to Chapter 2 17 | * 18 | * This chapter is about Bandits. These are markov processes that know about a single state, 19 | * really. The trick here is going to be getting the stuff that plays these particular states to be 20 | * more general, and work with the same machinery that rolls states forward. 21 | * 22 | * What we REALLY NEED here is both the top and bottom graphs, getting it done. 23 | * 24 | * The top graph is the average reward across GAMES per step. 25 | * 26 | * So we really want to march them ALL forward and grab the average reward... 27 | */ 28 | object Chapter2 { 29 | import Bandit.Arm 30 | import Episode.Moment 31 | 32 | /** These are needed to actually call get on anything. 33 | */ 34 | implicit val rng: RNG = RNG.default 35 | implicit val evaluator: Numeric[Real] = new Evaluator(Map.empty) 36 | 37 | // Implementing it the way it does in the book. 
38 | def average(s: Iterable[Double]): Double = { 39 | val (sum, n) = s.foldLeft((0.0, 0)) { case ((sum, n), i) => 40 | (sum + i, n + 1) 41 | } 42 | sum / n 43 | } 44 | 45 | def playBandit[Obs, A, R]( 46 | policy: Policy[Obs, A, R, Generator, Generator], 47 | stateGen: Generator[State[Obs, A, R, Generator]], 48 | nRuns: Int, 49 | timeSteps: Int 50 | )( 51 | reduce: List[SARS[Obs, A, R, Generator]] => R 52 | ): (List[Moment[Obs, A, R, Generator]], List[R]) = { 53 | val rewardSeqGen = 54 | (0 until nRuns).toList 55 | .map(i => stateGen.map(s => Episode.Moment(policy, s))) 56 | .sequence 57 | .flatMap { pairs => 58 | Episode.playManyN[Obs, A, R, Generator]( 59 | pairs, 60 | timeSteps 61 | )(reduce) 62 | } 63 | 64 | val elapsed = Stopwatch.start() 65 | val rewardSeq = rewardSeqGen.get 66 | println( 67 | s"Time to play $nRuns runs of $timeSteps time steps each: ${elapsed()}" 68 | ) 69 | 70 | rewardSeq 71 | } 72 | 73 | /** Generates the n-armed testbed. 74 | */ 75 | def nArmedTestbed( 76 | nArms: Int, 77 | meanMean: Double, 78 | stdDev: Double 79 | ): Generator[State[Unit, Arm, Double, Generator]] = Bandit.stationary( 80 | nArms, 81 | Normal(meanMean, stdDev).generator 82 | .map(mean => Normal(mean, stdDev).generator) 83 | ) 84 | 85 | /** Generates a non-stationary distribution. 86 | */ 87 | def nonStationaryTestbed( 88 | nArms: Int, 89 | mean: Double, 90 | stdDev: Double 91 | ): Generator[State[Unit, Arm, Double, Generator]] = 92 | Bandit.nonStationary( 93 | nArms, 94 | Generator.constant(Normal(mean, stdDev).generator), 95 | { case (_, r, _) => Normal(r, stdDev).generator } 96 | ) 97 | 98 | def play(policy: Policy[Unit, Arm, Double, Cat, Generator]): List[Double] = 99 | playBandit( 100 | policy.mapK(Categorical.catToGenerator), 101 | nArmedTestbed(10, 0.0, 1.0), 102 | nRuns = 200, 103 | timeSteps = 1000 104 | ) { case items => average(items.map(_.reward)) }._2 105 | 106 | def main(items: Array[String]): Unit = 107 | Plot.lineChartSeq( 108 | (play(Greedy.incrementalConfig(0.0).policy), "0.0"), 109 | (play(Greedy.incrementalConfig(0.01).policy), "0.01"), 110 | (play(Greedy.incrementalConfig(0.1).policy), "0.1") 111 | ) 112 | } 113 | -------------------------------------------------------------------------------- /scala-rl-book/src/main/scala/com/scalarl/book/Chapter3.scala: -------------------------------------------------------------------------------- 1 | /** This chapter plays a couple of gridworld games. Current goal is to get this all building, and 2 | * printing nicely. 3 | * 4 | * This chapter introduces the idea of the Markov Decision Process. 5 | */ 6 | package com.scalarl 7 | package book 8 | 9 | import cats.Id 10 | import com.scalarl.algebra.ToDouble 11 | import com.scalarl.logic.Sweep 12 | import com.scalarl.plot.Tabulator 13 | import com.scalarl.policy.Greedy 14 | import com.scalarl.value.DecayState 15 | import com.scalarl.world.GridWorld 16 | import com.scalarl.world.util.Grid 17 | 18 | object Chapter3 { 19 | import com.scalarl.world.util.Grid.{Bounds, Move, Position} 20 | 21 | // Configuration for the gridworld used in the examples. 
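  // The two jumps below mirror the special states A and B in the book's 5x5 gridworld
  // (the Figure 3.2 / 3.5 example): taking any action from the first cell pays +10 and
  // drops you on its target cell, the second pays +5. (That's my reading of
  // `withJump(from, to, reward)`; the actual semantics live in GridWorld.Config.)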
22 | val gridConf = GridWorld 23 | .Config(Bounds(5, 5)) 24 | .withJump(Position.of(0, 1), Position.of(4, 1), 10) 25 | .withJump(Position.of(0, 3), Position.of(2, 3), 5) 26 | 27 | val allowedIterations: Long = 10000 28 | val epsilon: Double = 1e-4 29 | val gamma: Double = 0.9 30 | 31 | val emptyFn = StateValueFn.empty[Position, DecayState[Double]]( 32 | DecayState.DecayedValue(0.0) 33 | ) 34 | 35 | def notConverging(iterations: Long, allowed: Long): Boolean = 36 | iterations >= allowed 37 | 38 | /** Note... this version, following the python code, checks that the sum of all differences is 39 | * less than epsilon. In the next chapter we use the max function instead here to get this 40 | * working, to check that the maximum delta is less than epsilon. 41 | */ 42 | def valueFunctionConverged[Obs, T: ToDouble]( 43 | l: StateValueFn[Obs, T], 44 | r: StateValueFn[Obs, T] 45 | ): Boolean = Sweep.diffBelow(l, r, epsilon)(_ + _) 46 | 47 | def shouldStop[Obs, T: ToDouble]( 48 | l: StateValueFn[Obs, T], 49 | r: StateValueFn[Obs, T], 50 | iterations: Long 51 | ): Boolean = 52 | notConverging(iterations, allowedIterations) || 53 | valueFunctionConverged(l, r) 54 | 55 | def toTable( 56 | conf: GridWorld.Config, 57 | f: Position => Double 58 | ): Iterable[Iterable[Double]] = 59 | Grid 60 | .allStates(conf.bounds) 61 | .map(g => f(g.position)) 62 | .toArray 63 | .grouped(conf.bounds.numRows) 64 | .toSeq 65 | .map(_.toSeq) 66 | 67 | def printFigure[T: ToDouble]( 68 | conf: GridWorld.Config, 69 | pair: (StateValueFn[Position, T], Long), 70 | title: String 71 | ): Unit = { 72 | val (valueFn, iterations) = pair 73 | println(s"${title}:") 74 | println( 75 | Tabulator.format( 76 | toTable(conf, p => ToDouble[T].apply(valueFn.stateValue(p))) 77 | ) 78 | ) 79 | println(s"That took $iterations iterations, for the record.") 80 | } 81 | 82 | /** This is Figure 3.2, with proper stopping conditions and everything. Lots of work to go. 83 | */ 84 | def threeTwo: (StateValueFn[Position, DecayState[Double]], Long) = 85 | Sweep.sweepUntil[Position, Move, Double, DecayState[Double], Cat, Id]( 86 | emptyFn, 87 | _ => Policy.random[Position, Move, Double, Id], 88 | DecayState.bellmanFn(gamma), 89 | gridConf.stateSweep, 90 | shouldStop _, 91 | inPlace = true, 92 | valueIteration = false 93 | ) 94 | 95 | /** This is Figure 3.5. This is currently working! 96 | */ 97 | def threeFive: (StateValueFn[Position, DecayState[Double]], Long) = { 98 | implicit val dm = DecayState.decayStateModule(gamma) 99 | Sweep.sweepUntil[Position, Move, Double, DecayState[Double], Cat, Id]( 100 | emptyFn, 101 | fn => 102 | Greedy 103 | .Config[Double, DecayState[Double]]( 104 | 0.0, 105 | DecayState.Reward(_), 106 | (a, b) => DecayState.decayStateGroup[Double](gamma).plus(a, b), 107 | DecayState.DecayedValue(0.0) 108 | ) 109 | .id(fn), 110 | DecayState.bellmanFn(gamma), 111 | gridConf.stateSweep, 112 | shouldStop _, 113 | inPlace = true, 114 | valueIteration = true 115 | ) 116 | } 117 | 118 | /** This currently works, and displays rough tables for each of the required bits. 119 | */ 120 | def main(items: Array[String]): Unit = { 121 | printFigure(gridConf, threeTwo, "Figure 3.2") 122 | printFigure(gridConf, threeFive, "Figure 3.5") 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /scala-rl-book/src/main/scala/com/scalarl/book/Chapter6.scala: -------------------------------------------------------------------------------- 1 | /** Attempts at sarsa and other algorithms. 
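  *
  * For reference, the first target here is tabular one-step Sarsa — just a sketch of the
  * update rule, nothing in this file implements it yet:
  *
  * {{{
  * // a2 is the action the current policy actually picks in s2 (on-policy)
  * Q(s, a) += alpha * (r + gamma * Q(s2, a2) - Q(s, a))
  * }}}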
2 | */ 3 | package com.scalarl 4 | package book 5 | 6 | object Chapter6 { 7 | def main(items: Array[String]): Unit = () 8 | } 9 | -------------------------------------------------------------------------------- /scala-rl-book/src/test/scala/com/scalarl/book/Chapter3Spec.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package book 3 | 4 | import com.scalarl.logic.Sweep 5 | import com.scalarl.value.DecayState 6 | import com.scalarl.world.util.Grid 7 | import org.scalatest.funsuite.AnyFunSuite 8 | 9 | /** And this is a placeholder for basic tests. 10 | */ 11 | class Chapter3Spec extends AnyFunSuite { 12 | import Grid.{Move, Position} 13 | 14 | val epsilon = 1e-4 15 | val gamma = 0.9 16 | val zeroValue = DecayState.DecayedValue(0.0) 17 | 18 | test("Figure 3.2's value function matches the gold set") { 19 | val (actual, _) = Chapter3.threeTwo 20 | val expected = StateValueFn.Base[Position, DecayState[Double]]( 21 | Map( 22 | Position.of(0, 0) -> 3.3090, 23 | Position.of(0, 1) -> 8.7893, 24 | Position.of(0, 2) -> 4.4276, 25 | Position.of(0, 3) -> 5.3223, 26 | Position.of(0, 4) -> 1.4921, 27 | Position.of(1, 0) -> 1.5216, 28 | Position.of(1, 1) -> 2.9923, 29 | Position.of(1, 2) -> 2.2501, 30 | Position.of(1, 3) -> 1.9075, 31 | Position.of(1, 4) -> 0.5474, 32 | Position.of(2, 0) -> 0.0508, 33 | Position.of(2, 1) -> 0.7381, 34 | Position.of(2, 2) -> 0.6731, 35 | Position.of(2, 3) -> 0.3582, 36 | Position.of(2, 4) -> -0.4031, 37 | Position.of(3, 0) -> -0.9735, 38 | Position.of(3, 1) -> -0.4354, 39 | Position.of(3, 2) -> -0.3548, 40 | Position.of(3, 3) -> -0.5855, 41 | Position.of(3, 4) -> -1.1830, 42 | Position.of(4, 0) -> -1.8576, 43 | Position.of(4, 1) -> -1.3452, 44 | Position.of(4, 2) -> -1.2292, 45 | Position.of(4, 3) -> -1.4229, 46 | Position.of(4, 4) -> -1.9751 47 | ).mapValues(DecayState.DecayedValue(_)), 48 | zeroValue 49 | ) 50 | 51 | assert(Sweep.diffBelow(actual, expected, epsilon)(_.max(_))) 52 | } 53 | 54 | val expectedThreeFive = StateValueFn.Base[Position, DecayState[Double]]( 55 | Map( 56 | Position.of(0, 0) -> 21.9774, 57 | Position.of(0, 1) -> 24.4194, 58 | Position.of(0, 2) -> 21.9774, 59 | Position.of(0, 3) -> 19.4194, 60 | Position.of(0, 4) -> 17.4774, 61 | Position.of(1, 0) -> 19.7797, 62 | Position.of(1, 1) -> 21.9774, 63 | Position.of(1, 2) -> 19.7797, 64 | Position.of(1, 3) -> 17.8017, 65 | Position.of(1, 4) -> 16.0215, 66 | Position.of(2, 0) -> 17.8017, 67 | Position.of(2, 1) -> 19.7797, 68 | Position.of(2, 2) -> 17.8017, 69 | Position.of(2, 3) -> 16.0215, 70 | Position.of(2, 4) -> 14.4194, 71 | Position.of(3, 0) -> 16.0215, 72 | Position.of(3, 1) -> 17.8017, 73 | Position.of(3, 2) -> 16.0215, 74 | Position.of(3, 3) -> 14.4194, 75 | Position.of(3, 4) -> 12.9774, 76 | Position.of(4, 0) -> 14.4194, 77 | Position.of(4, 1) -> 16.0215, 78 | Position.of(4, 2) -> 14.4194, 79 | Position.of(4, 3) -> 12.9774, 80 | Position.of(4, 4) -> 11.6797 81 | ).mapValues(DecayState.DecayedValue(_)), 82 | zeroValue 83 | ) 84 | 85 | test("Figure 3.5's value function matches the gold set.") { 86 | val (actual, _) = Chapter3.threeFive 87 | assert(Sweep.diffBelow(actual, expectedThreeFive, epsilon)(_.max(_))) 88 | } 89 | 90 | test("Figure 3.5's calculation matches the full categorical version") { 91 | val idToCat = Util.idToMonad[Cat] 92 | 93 | // Empty value function to start. 94 | val emptyFn = StateValueFn.empty[Position, DecayState[Double]](zeroValue) 95 | 96 | // Build a Stochastic version of the greedy policy. 
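    // Same Greedy.Config that Chapter3.threeFive builds, except that we call `.stochastic`
    // below instead of `.id`, so the whole sweep runs over explicit Cat distributions; the
    // assertion then checks that it lands on the same value function.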
97 | implicit val dm = DecayState.decayStateModule(gamma) 98 | val stochasticConf = policy.Greedy.Config[Double, DecayState[Double]]( 99 | 0.0, 100 | DecayState.Reward(_), 101 | (a, b) => DecayState.decayStateGroup[Double](gamma).plus(a, b), 102 | zeroValue 103 | ) 104 | 105 | val (actual, _) = 106 | Sweep.sweepUntil[Position, Move, Double, DecayState[Double], Cat, Cat]( 107 | emptyFn, 108 | stochasticConf.stochastic[Position, Move](_), 109 | DecayState.bellmanFn(gamma), 110 | Chapter3.gridConf.stateSweep.map(_.mapK(idToCat)), 111 | Chapter3.shouldStop _, 112 | inPlace = true, 113 | valueIteration = true 114 | ) 115 | assert(Sweep.diffBelow(actual, expectedThreeFive, epsilon)(_.max(_))) 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /scala-rl-book/src/test/scala/com/scalarl/book/Chapter4Spec.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package book 3 | 4 | import com.scalarl.logic.Sweep 5 | import com.scalarl.value.DecayState 6 | import com.scalarl.world.util.Grid 7 | import org.scalatest.funsuite.AnyFunSuite 8 | 9 | /** And this is a placeholder for basic tests. 10 | */ 11 | class Chapter4Spec extends AnyFunSuite { 12 | import Grid.{Move, Position} 13 | 14 | val gamma = 1.0 15 | val epsilon = 1e-3 16 | val zeroValue = DecayState.DecayedValue(0.0) 17 | val expectedFourOne = StateValueFn.Base[Position, DecayState[Double]]( 18 | Map( 19 | Position.of(0, 0) -> 0.0, 20 | Position.of(0, 1) -> -13.9989, 21 | Position.of(0, 2) -> -19.9984, 22 | Position.of(0, 3) -> -21.9982, 23 | Position.of(1, 0) -> -13.9989, 24 | Position.of(1, 1) -> -17.9986, 25 | Position.of(1, 2) -> -19.9984, 26 | Position.of(1, 3) -> -19.9984, 27 | Position.of(2, 0) -> -19.9984, 28 | Position.of(2, 1) -> -19.9984, 29 | Position.of(2, 2) -> -17.9986, 30 | Position.of(2, 3) -> -13.9989, 31 | Position.of(3, 0) -> -21.9982, 32 | Position.of(3, 1) -> -19.9984, 33 | Position.of(3, 2) -> -13.9989, 34 | Position.of(3, 3) -> 0.0 35 | ).mapValues(DecayState.DecayedValue(_)), 36 | zeroValue 37 | ) 38 | 39 | test("Figure 4.1's value function matches the gold set") { 40 | val (actual, _) = Chapter4.fourOne(inPlace = false) 41 | assert(Sweep.diffBelow(actual, expectedFourOne, epsilon)(_.max(_))) 42 | } 43 | 44 | test("Figure 4.1's calculation matches the full categorical version") { 45 | val idToCat = Util.idToMonad[Cat] 46 | 47 | // Empty value function to start. 48 | val emptyFn = StateValueFn.empty[Position, DecayState[Double]](zeroValue) 49 | 50 | val (actual, _) = 51 | Sweep.sweepUntil[Position, Move, Double, DecayState[Double], Cat, Cat]( 52 | emptyFn, 53 | _ => Policy.random[Position, Move, Double, Cat], 54 | DecayState.bellmanFn(gamma), 55 | Chapter4.gridConf.stateSweep.map(_.mapK(idToCat)), 56 | Chapter4.shouldStop(_, _, _), 57 | inPlace = true, 58 | valueIteration = true 59 | ) 60 | assert(Sweep.diffBelow(actual, expectedFourOne, epsilon)(_.max(_))) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/Agent.scala: -------------------------------------------------------------------------------- 1 | /** An Agent is a combination of a Policy and a value function. 2 | * 3 | * The whole markov decision process is a weighted, directed graph. The backup diagrams we see are 4 | * subgraphs; 5 | * 6 | * The nodes are: 7 | * 8 | * - State nodes, with edges leading out to each possible action. 
9 | * - Action nodes, with edges leading out to (reward, state) pairs. 10 | * 11 | * Policies are maps of State => Map[A, Weight]. I don't know that I have a policy that is NOT 12 | * that. 13 | * 14 | * StateValueFn instances are records of the values at particular State nodes. 15 | * 16 | * So to get the value of an ACTION node you need either: 17 | * 18 | * - To track it directly, with an ActionValueFn, or 19 | * - to estimate it with some model of the dynamics of the system. 20 | * 21 | * TODO - Key questions: \- Can I rethink the interface here? Can StateValueFn instances ONLY be 22 | * calculated for... rings where the weights add up to 1? "Affine combination" is the key idea 23 | * here... a linear combination where the set of scalars adds to 1. (Read more about affine 24 | * combinations here: https://www.sciencedirect.com/topics/computer-science/affine-combination) \- 25 | * Does that mean that we have a StateValueFn of ONLY DOUBLES? 26 | * 27 | * # The four key ideas: (Policy, EnvModel are a pair) (ActionValueFn, StateValueFn are a pair) 28 | * 29 | * TODO - what is an actual WORLD here? It's something that stochastically returns the same things 30 | * an EnvModel would, of course. An EnvModel should give me the dynamics for any particular state I 31 | * happen to want, for anything I happen to find myself in. That can work for Blackjack, for 32 | * example. 33 | * 34 | * FINAL COMMENTS: \- Do we have to restrict the type of StateValueFn to be ONLY a double for now, 35 | * until I can generalize? \- What are the remaining concepts... ways to update these various 36 | * features by walking around in the graph. \- TODO NEXT - clean up what StateValueFn is actually 37 | * doing. Can it have an internal agg type? \- AGENT can have an internal thing it uses to track 38 | * experience. That's the object oriented version... and the various algorithms are ways to 39 | * propagate credit back. \- The more complicated ones can also access the four things that the 40 | * agent has. \- YOU CAN TOTALLY have a policy that does not need a numeric value... but that just 41 | * needs something with an ordering. So the policy has less stringent requirements than the actual 42 | * estimating thing. TODO can this help me at all? \- Will I need Value, or ToDouble at the end of 43 | * all of this? 44 | * 45 | * Here's a piece on graph learning on reinforcement learning problems: 46 | * http://proceedings.mlr.press/v89/madjiheurem19a/madjiheurem19a.pdf 47 | * 48 | * A combo of a policy and its value function. Could be helpful navigating the monte carlo stuff. 49 | * 50 | * The bandits are actually agent instances. AND, they probably need to each keep their own 51 | * internal aggregation types to use inside their value functions. 52 | * 53 | * TODO take these notes. \- an agent is the only thing that has a full view on the graph. \- the 54 | * graph is a directed, weighted graph... to get an expected value you have to normalize the 55 | * weights and multiply each by the values of the next nodes. \- the "backup diagram" is the 56 | * subgraph that the agent looked at to make its decision. 
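  *
  * Rough usage sketch of `play`, the one piece of behavior below (everything else is the
  * agent's own bookkeeping):
  *
  * {{{
  * // One interaction: the policy picks an action, the state acts on it, and we get back
  * // the next state along with the full (state, action, reward, nextState) record.
  * val step: M[(State[Obs, A, R, M], SARS[Obs, A, R, M])] = agent.play(state)
  * }}}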
57 | */ 58 | package com.scalarl 59 | 60 | import cats.Monad 61 | 62 | trait Agent[Obs, A, @specialized(Int, Long, Float, Double) R, T, M[_]] { self => 63 | type This = Agent[Obs, A, R, T, M] 64 | 65 | def monad: Monad[M] 66 | def policy: Policy[Obs, A, R, M, M] 67 | def valueFunction: StateValueFn[Obs, T] 68 | 69 | def play(state: State[Obs, A, R, M]): M[(state.This, SARS[Obs, A, R, M])] = 70 | monad.flatMap(policy.choose(state)) { a => 71 | monad.map(state.act(a)) { case (r, s2) => 72 | (s2, SARS(state, a, r, s2)) 73 | } 74 | } 75 | } 76 | 77 | object Agent { 78 | 79 | /** Agent that can't learn. 80 | */ 81 | case class StaticAgent[Obs, A, R, T, M[_]]( 82 | policy: Policy[Obs, A, R, M, M], 83 | valueFunction: StateValueFn[Obs, T] 84 | )(implicit val monad: Monad[M]) 85 | extends Agent[Obs, A, R, T, M] 86 | } 87 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/Evaluator.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | 3 | import com.scalarl.evaluate.{ActionValue, StateValue} 4 | import com.scalarl.algebra.{Expectation, Module} 5 | 6 | /** Contains traits and instances for the two evaluation methods. 7 | */ 8 | object Evaluator { 9 | import Module.DModule 10 | 11 | /** Evaluator that uses a world's dynamics to estimate the value of a given action. 12 | */ 13 | def oneAhead[Obs, A, R, G: DModule, M[_], S[_]: Expectation]( 14 | valueFn: StateValueFn[Obs, G], 15 | prepare: R => G, 16 | merge: (G, G) => G 17 | ): ActionValue[Obs, A, R, G, S] = 18 | StateValue 19 | .fn[Obs, A, R, G, S](valueFn) 20 | .byStateValue(prepare, merge) 21 | 22 | /** The full bellman estimation, where we know the dynamics of the policy and of the system. 23 | * 24 | * Could also be defined by 25 | * 26 | * {{{ 27 | * state.byPolicy(policy).byStateValue(prepare, merge).apply(valueFn) 28 | * }}} 29 | */ 30 | def bellman[Obs, A, R, G: DModule, M[_]: Expectation, S[_]: Expectation]( 31 | valueFn: StateValueFn[Obs, G], 32 | policy: Policy[Obs, A, R, M, S], 33 | prepare: R => G, 34 | merge: (G, G) => G 35 | ): StateValue[Obs, A, R, G, S] = 36 | StateValue 37 | .fn[Obs, A, R, G, S](valueFn) // statevalue 38 | .byStateValue(prepare, merge) // actionvalue 39 | .byPolicy(policy) // statevalue 40 | 41 | /** This is my attempt at getting a better builder syntax going! 42 | */ 43 | def state[Obs, A, R, G, S[_]]: FromState[Obs, A, R, G, S, StateValue] = 44 | new FromState(f => f) 45 | 46 | def action[Obs, A, R, G, S[_]]: FromAction[Obs, A, R, G, S, ActionValue] = 47 | new FromAction(f => f) 48 | 49 | /** Builder class that manages conversion of an [[evaluate.StateValue]] instance into either an 50 | * [[evaluate.ActionValue]] or [[evaluate.StateValue]] instance. 51 | */ 52 | class FromState[Obs, A, R, G, S[_], F[_, _, _, _, *[_]]] private[scalarl] ( 53 | f: StateValue[Obs, A, R, G, S] => F[Obs, A, R, G, S] 54 | ) { 55 | def fn(vfn: StateValueFn[Obs, G]): F[Obs, A, R, G, S] = 56 | f(StateValue.fn(vfn)) 57 | 58 | def byPolicy[M[_]]( 59 | policy: Policy[Obs, A, R, M, S] 60 | )(implicit 61 | M: Expectation[M], 62 | MV: Module[Double, G] 63 | ): FromAction[Obs, A, R, G, S, F] = 64 | new FromAction[Obs, A, R, G, S, F](fn => f(fn.byPolicy(policy))) 65 | } 66 | 67 | /** Builder class that manages conversion of an [[evaluate.ActionValue]] instance into either an 68 | * [[evaluate.ActionValue]] or [[evaluate.StateValue]] instance. 
69 | */ 70 | class FromAction[Obs, A, R, G, S[_], F[_, _, _, _, *[_]]] private[scalarl] ( 71 | f: ActionValue[Obs, A, R, G, S] => F[Obs, A, R, G, S] 72 | ) { 73 | def fn(vfn: ActionValueFn[Obs, A, G]): F[Obs, A, R, G, S] = 74 | f(ActionValue.fn(vfn)) 75 | 76 | def byStateValue( 77 | prepare: R => G, 78 | merge: (G, G) => G 79 | )(implicit 80 | S: Expectation[S], 81 | MV: Module[Double, G] 82 | ): FromState[Obs, A, R, G, S, F] = 83 | new FromState[Obs, A, R, G, S, F](fn => f(fn.byStateValue(prepare, merge))) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/Policy.scala: -------------------------------------------------------------------------------- 1 | /** Policy implementation, get some! 2 | */ 3 | package com.scalarl 4 | 5 | import cats.{Functor, Id} 6 | import cats.arrow.FunctionK 7 | import com.scalarl.evaluate.ActionValue 8 | import com.scalarl.rainier.Categorical 9 | 10 | import scala.language.higherKinds 11 | 12 | /** This is how agents actually choose what comes next. This is a stochastic policy. We have to to 13 | * be able to match this up with a state that has the same monadic return type, but for now it's 14 | * hardcoded. 15 | * 16 | * A - Action Obs - the observation offered by this state. R - reward M - the monadic type offered 17 | * by the policy. S - the monad for the state. 18 | */ 19 | trait Policy[Obs, A, @specialized(Int, Long, Float, Double) R, M[_], S[_]] { 20 | self => 21 | type This = Policy[Obs, A, R, M, S] 22 | 23 | def choose(state: State[Obs, A, R, S]): M[A] 24 | def learn(sars: SARS[Obs, A, R, S]): This = self 25 | 26 | def contramapObservation[P]( 27 | f: P => Obs 28 | )(implicit S: Functor[S]): Policy[P, A, R, M, S] = 29 | new Policy[P, A, R, M, S] { 30 | override def choose(state: State[P, A, R, S]) = 31 | self.choose(state.mapObservation(f)) 32 | override def learn(sars: SARS[P, A, R, S]) = 33 | self.learn(sars.mapObservation(f)).contramapObservation(f) 34 | } 35 | 36 | def contramapReward[T]( 37 | f: T => R 38 | )(implicit S: Functor[S]): Policy[Obs, A, T, M, S] = 39 | new Policy[Obs, A, T, M, S] { 40 | override def choose(state: State[Obs, A, T, S]) = 41 | self.choose(state.mapReward(f)) 42 | override def learn(sars: SARS[Obs, A, T, S]) = 43 | self.learn(sars.mapReward(f)).contramapReward(f) 44 | } 45 | 46 | /** Just an idea to see if I can make stochastic deciders out of deterministic deciders. We'll see 47 | * how this develops. 48 | */ 49 | def mapK[N[_]](f: FunctionK[M, N]): Policy[Obs, A, R, N, S] = 50 | new Policy[Obs, A, R, N, S] { r => 51 | override def choose(state: State[Obs, A, R, S]): N[A] = f( 52 | self.choose(state) 53 | ) 54 | override def learn( 55 | sars: SARS[Obs, A, R, S] 56 | ): Policy[Obs, A, R, N, S] = 57 | self.learn(sars).mapK(f) 58 | } 59 | } 60 | 61 | object Policy { 62 | 63 | /** If all you care about is a choose fn. 64 | */ 65 | def choose[Obs, A, R, M[_], S[_]]( 66 | chooseFn: State[Obs, A, R, S] => M[A] 67 | ): Policy[Obs, A, R, M, S] = 68 | new Policy[Obs, A, R, M, S] { self => 69 | override def choose(state: State[Obs, A, R, S]): M[A] = chooseFn(state) 70 | } 71 | 72 | /** Full exploration. mapK(Categorical.setToCat) to get the usual Greedy. 73 | */ 74 | def random[Obs, A, R, S[_]]: Policy[Obs, A, R, Cat, S] = 75 | Policy.choose(s => Categorical.fromSet(s.actions)) 76 | 77 | /** Full greed. mapK(Categorical.setToCat) to get the usual Greedy. 
78 | */ 79 | def greedy[Obs, A, R, T: Ordering, S[_]]( 80 | evaluator: ActionValue[Obs, A, R, T, S] 81 | ): Policy[Obs, A, R, Cat, S] = 82 | choose(s => Categorical.fromSet(evaluator.greedyOptions(s))) 83 | 84 | /** In between. This is equal to 85 | * 86 | * {{{ 87 | * epsilonGreedy(evaluator, 1.0) == greedy(evaluator).mapK(Cat.setToCat) 88 | * epsilonGreedy(evaluator, 0.0) == random.mapK(Cat.setToCat) 89 | * }}} 90 | */ 91 | def epsilonGreedy[Obs, A, R, T: Ordering, S[_]]( 92 | evaluator: ActionValue[Obs, A, R, T, S], 93 | epsilon: Double 94 | ): policy.Greedy[Obs, A, R, T, S] = new policy.Greedy(evaluator, epsilon) 95 | 96 | /** Always return the same. 97 | */ 98 | def constant[Obs, A, R, S[_]](a: A): Policy[Obs, A, R, Id, S] = 99 | choose[Obs, A, R, Id, S](_ => a) 100 | } 101 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/SARS.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | 3 | import cats.Functor 4 | 5 | /** Represents a single step in a reinforcement learning episode. 6 | * 7 | * SARS stands for State-Action-Reward-State, capturing the complete transition: 8 | * - The initial state the agent was in 9 | * - The action the agent took 10 | * - The reward received for taking that action 11 | * - The next state the environment transitioned to 12 | */ 13 | final case class SARS[Obs, A, R, S[_]]( 14 | state: State[Obs, A, R, S], 15 | action: A, 16 | reward: R, 17 | nextState: State[Obs, A, R, S] 18 | ) { 19 | 20 | /** Maps the observation type of this SARS to a new type. 21 | * 22 | * @param f 23 | * The function to transform the observation from type Obs to type P 24 | * @param S 25 | * Evidence that S has a Functor instance 26 | */ 27 | def mapObservation[P](f: Obs => P)(implicit S: Functor[S]): SARS[P, A, R, S] = 28 | SARS(state.mapObservation(f), action, reward, nextState.mapObservation(f)) 29 | 30 | /** Maps the reward type of this SARS to a new type. 31 | * 32 | * @param f 33 | * The function to transform the reward from type R to type T 34 | * @param S 35 | * Evidence that S has a Functor instance 36 | */ 37 | def mapReward[T](f: R => T)(implicit S: Functor[S]): SARS[Obs, A, T, S] = 38 | SARS(state.mapReward(f), action, f(reward), nextState.mapReward(f)) 39 | 40 | } 41 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/State.scala: -------------------------------------------------------------------------------- 1 | /** Okay, let's see what we can do for this bandit thing. We want something that can run as a 2 | * bandit, and then a world to run it in. 3 | * 4 | * FIRST STEP: 5 | * 6 | * - Recreate figure 2_2, playing the bandit for a while. see what happens. 7 | * - epsilon-greedy 8 | * - ucb 9 | * - gradient 10 | * 11 | * I would use it but I think it's old. 12 | * 13 | * Then... do we go to the state monad, on TOP of the generator, to return the reward? Instead of 14 | * returning the pair directly? 15 | * 16 | * Let's give it a try! 17 | */ 18 | package com.scalarl 19 | 20 | import cats.Functor 21 | import cats.arrow.FunctionK 22 | 23 | object State { 24 | type ActionView[Obs, A, R, M[_]] = M[(R, State[Obs, A, R, M])] 25 | type Dynamics[Obs, A, R, M[_]] = Map[A, ActionView[Obs, A, R, M]] 26 | } 27 | 28 | /** A world should probably have a generator of states and actions... and then you can use that to 29 | * get to the next thing. 
The state here is going to be useful in the Markov model; for the bandit 30 | * we only have a single state, not that useful. 31 | */ 32 | trait State[Obs, A, @specialized(Int, Long, Float, Double) R, M[_]] { self => 33 | type This = State[Obs, A, R, M] 34 | 35 | def observation: Obs 36 | 37 | /** For every action you could take, returns a generator of the next set of rewards. This is a 38 | * real world, or a sample model. If we want the full distribution we're going to have to build 39 | * out a better interface. Good enough for now. 40 | */ 41 | def dynamics: Map[A, M[(R, This)]] 42 | def invalidMove: M[(R, This)] 43 | 44 | def actions: Set[A] = dynamics.keySet 45 | def act(action: A): M[(R, This)] = dynamics.getOrElse(action, invalidMove) 46 | 47 | def isTerminal: Boolean = actions.isEmpty 48 | 49 | /** Maps the observation type of this state to a new type. 50 | * 51 | * @param f 52 | * The function to transform the observation from type Obs to type P 53 | * @param M 54 | * Evidence that M has a Functor instance 55 | * @return 56 | * A new State with observations of type P but the same actions and rewards 57 | */ 58 | def mapObservation[P]( 59 | f: Obs => P 60 | )(implicit M: Functor[M]): State[P, A, R, M] = 61 | new State[P, A, R, M] { 62 | private def innerMap(pair: M[(R, State[Obs, A, R, M])]) = 63 | M.map(pair) { case (r, s) => (r, s.mapObservation(f)) } 64 | override def observation = f(self.observation) 65 | override def dynamics = self.dynamics.mapValues(innerMap(_)) 66 | override def invalidMove = innerMap(self.invalidMove) 67 | override def act(action: A) = innerMap(self.act(action)) 68 | override def actions: Set[A] = self.actions 69 | } 70 | 71 | def mapReward[T](f: R => T)(implicit M: Functor[M]): State[Obs, A, T, M] = 72 | new State[Obs, A, T, M] { 73 | private def innerMap(pair: M[(R, State[Obs, A, R, M])]) = 74 | M.map(pair) { case (r, s) => (f(r), s.mapReward(f)) } 75 | 76 | override def observation = self.observation 77 | override def dynamics = self.dynamics.mapValues(innerMap(_)) 78 | override def invalidMove = innerMap(self.invalidMove) 79 | override def act(action: A) = innerMap(self.act(action)) 80 | override def actions: Set[A] = self.actions 81 | } 82 | 83 | def mapK[N[_]]( 84 | f: FunctionK[M, N] 85 | )(implicit N: Functor[N]): State[Obs, A, R, N] = new State[Obs, A, R, N] { 86 | private def innerMap(pair: M[(R, State[Obs, A, R, M])]) = 87 | N.map(f(pair)) { case (r, s) => (r, s.mapK(f)) } 88 | override def observation = self.observation 89 | override def dynamics = self.dynamics.mapValues(innerMap(_)) 90 | override def invalidMove = innerMap(self.invalidMove) 91 | override def act(action: A) = innerMap(self.act(action)) 92 | override def actions: Set[A] = self.actions 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/StateValueFn.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | 3 | import com.twitter.algebird.{Aggregator, Monoid, MonoidAggregator, Semigroup} 4 | import com.scalarl.evaluate.StateValue 5 | 6 | /** Along with [[ActionValueFn]], this is the main trait in tabular reinforcement learning for 7 | * tracking the value of a state as evidenced by the observation it returns. 8 | * 9 | * We need some way for this to learn, or see new observations, that's part of the trait. 10 | * 11 | * @tparam Obs 12 | * Observation returned by the [[State]] instances tracked by [[StateValueFn]]. 
13 | * @tparam T 14 | * type of values tracked by [[StateValueFn]]. 15 | */ 16 | trait StateValueFn[Obs, T] { self => 17 | 18 | /** Returns an Iterable of all observations associated with some internally tracked value T. 19 | */ 20 | def seen: Iterable[Obs] 21 | 22 | /** Returns the stored value associated with the given observation. 23 | */ 24 | def stateValue(obs: Obs): T 25 | 26 | /** Absorb a new value for the supplied observation. The behavior of this function is 27 | * implementation dependent; some might ignore the value, some might merge it in to an existing 28 | * set of values, some might completely replace the stored state. 29 | */ 30 | def update(state: Obs, value: T): StateValueFn[Obs, T] 31 | 32 | /** Transforms this [[StateValueFn]] into a new instance that applies the supplied `prepare` to 33 | * all incoming values before they're learned, and presents tracked T instances using the 34 | * `present` fn before returning them via [[stateValue]]. 35 | * 36 | * @tparam the 37 | * type of value stored by the returned [[StateValueFn]]. 38 | */ 39 | def fold[U](prepare: U => T, present: T => U): StateValueFn[Obs, U] = 40 | new StateValueFn.Folded[Obs, T, U](self, prepare, present) 41 | 42 | /** Returns a [[StateValueFn]] instance that uses the supplied semigroup T to merge values into 43 | * this current [[StateValueFn]]. 44 | * 45 | * @param T 46 | * Semigroup instance used to merge values. 47 | */ 48 | def mergeable(implicit T: Semigroup[T]): StateValueFn[Obs, T] = 49 | new StateValueFn.Mergeable(self) 50 | 51 | /** TODO fill in. 52 | */ 53 | def toEvaluator[A, R, S[_]]: StateValue[Obs, A, R, T, S] = 54 | StateValue.fn(self) 55 | } 56 | 57 | /** Constructors and classes associated with [[StateValueFn]]. 58 | */ 59 | object StateValueFn { 60 | 61 | /** Returns an empty [[StateValueFn]] backed by an immutable map. 62 | */ 63 | def empty[Obs, T]: StateValueFn[Obs, Option[T]] = 64 | empty[Obs, Option[T]](None) 65 | 66 | /** Returns an empty [[StateValueFn]] backed by an immutable map. The supplied default value will 67 | * be returned by [[StateValueFn.stateValue]] for any obs that's not been seen by the 68 | * [[StateValueFn]]. 69 | */ 70 | def empty[Obs, T](default: T): StateValueFn[Obs, T] = 71 | Base(Map.empty, default) 72 | 73 | /** Returns an empty [[StateValueFn]] backed by an immutable map that uses the zero of the 74 | * supplied Monoid as a default value, and merges new learned values into the value in the 75 | * underlying map using the Monoid's `plus` function. 76 | */ 77 | def mergeable[Obs, T](implicit T: Monoid[T]): StateValueFn[Obs, T] = 78 | mergeable(T.zero) 79 | 80 | /** Returns an empty [[StateValueFn]] backed by an immutable map that uses the supplied `default` 81 | * as a default value, and merges new learned values into the value in the underlying map using 82 | * the Semigroup's `plus` function. 
83 | */ 84 | def mergeable[Obs, T](default: T)(implicit 85 | T: Semigroup[T] 86 | ): StateValueFn[Obs, T] = 87 | empty(default).mergeable 88 | 89 | /** Returns a [[StateValueFn]] that: 90 | * 91 | * \- uses the supplied default as an initial value \- merges values in using the aggregator's 92 | * semigroup \- prepares and presents using the aggregator's analogous functions 93 | */ 94 | def fromAggregator[Obs, T, U]( 95 | default: T, 96 | agg: Aggregator[U, T, U] 97 | ): StateValueFn[Obs, U] = 98 | empty(default) 99 | .mergeable(agg.semigroup) 100 | .fold(agg.prepare, agg.present) 101 | 102 | /** Returns a [[StateValueFn]] that: 103 | * 104 | * \- uses the MonoidAggregator's monoid.zero as an initial value \- merges values in using the 105 | * aggregator's monoid \- prepares and presents using the aggregator's analogous functions 106 | */ 107 | def fromAggregator[Obs, T, U]( 108 | agg: MonoidAggregator[U, T, U] 109 | ): StateValueFn[Obs, U] = 110 | mergeable(agg.monoid).fold(agg.prepare, agg.present) 111 | 112 | /** Basic implementation of a [[StateValueFn]] that stores any value supplied to [[update]] in an 113 | * internal immutable map. 114 | * 115 | * @param m 116 | * the immutable map used for storage. 117 | * @param default 118 | * value returned by [[Base]] when queried for some observation it hasn't yet seen. 119 | */ 120 | case class Base[Obs, T](m: Map[Obs, T], default: T) extends StateValueFn[Obs, T] { self => 121 | override def seen: Iterable[Obs] = m.keySet 122 | override def stateValue(obs: Obs): T = m.getOrElse(obs, default) 123 | 124 | /** @inheritdoc 125 | * This implementation replaces any existing value with no merge or logic. 126 | */ 127 | override def update(obs: Obs, value: T): Base[Obs, T] = 128 | Base(m.updated(obs, value), default) 129 | } 130 | 131 | /** [[StateValueFn]] implementation that implements a fold. 132 | * 133 | * Any value supplied to [[update]] will be transformed first by prepare before being passed to 134 | * the base [[StateValueFn]]. Any value retrieved by [[stateValue]] will be passed to `present` 135 | * before being returned. 136 | */ 137 | case class Folded[Obs, T, U]( 138 | base: StateValueFn[Obs, T], 139 | prepare: U => T, 140 | present: T => U 141 | ) extends StateValueFn[Obs, U] { 142 | override def seen: Iterable[Obs] = base.seen 143 | override def stateValue(obs: Obs): U = present(base.stateValue(obs)) 144 | override def update(obs: Obs, value: U): StateValueFn[Obs, U] = 145 | Folded(base.update(obs, prepare(value)), prepare, present) 146 | } 147 | 148 | /** [[StateValueFn]] implementation that merges values passed to [[update]] into the value stored 149 | * by the base [[StateValueFn]] using the supplied Semigroup's `plus` function. 150 | */ 151 | case class Mergeable[Obs, T]( 152 | base: StateValueFn[Obs, T] 153 | )(implicit T: Semigroup[T]) 154 | extends StateValueFn[Obs, T] { 155 | override def seen: Iterable[Obs] = base.seen 156 | override def stateValue(obs: Obs): T = base.stateValue(obs) 157 | 158 | /** @inheritdoc 159 | * 160 | * This implementation replaces uses a Semigroup[T] to merge the supplied value in to whatever 161 | * value is stored in the underlying m. 
162 | */ 163 | override def update(obs: Obs, value: T): StateValueFn[Obs, T] = { 164 | val merged = T.plus(stateValue(obs), value) 165 | Mergeable(base.update(obs, merged)) 166 | } 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/Time.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | 3 | /** A value class wrapper around Long that allows us to talk about time ticking and evolution in a 4 | * type-safe way. 5 | * 6 | * This class provides methods for incrementing time, comparing time values, and basic arithmetic 7 | * operations, while maintaining type safety through the AnyVal wrapper. 8 | */ 9 | 10 | case class Time(value: Long) extends AnyVal { 11 | def tick: Time = Time(value + 1) 12 | def -(r: Time) = value - r.value 13 | def +(r: Time) = value + r.value 14 | def <=(r: Time) = value <= r.value 15 | def <(r: Time) = value < r.value 16 | def compareTo(r: Time) = value.compareTo(r.value) 17 | } 18 | 19 | object Time { 20 | val Min: Time = Time(Long.MinValue) 21 | val Max: Time = Time(Long.MaxValue) 22 | val Zero: Time = Time(0L) 23 | } 24 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/Util.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | 3 | import cats.{Comonad, Id, Monad} 4 | import cats.arrow.FunctionK 5 | import cats.data.StateT 6 | import com.twitter.algebird.{AveragedValue, Fold, MonoidAggregator, Ring, Semigroup, VectorSpace} 7 | import com.stripe.rainier.compute.Real 8 | import com.scalarl.algebra.{Module, ToDouble} 9 | 10 | import scala.language.higherKinds 11 | 12 | object Util { 13 | import cats.syntax.functor._ 14 | 15 | /** Here we provide various "missing" typeclass instances sewing together algebird typeclasses and 16 | * implementing typeclasses for rainier types. 17 | */ 18 | object Instances { 19 | // this lets us sort AveragedValue instances... 20 | implicit val averageValueOrd: Ordering[AveragedValue] = 21 | Ordering.by(_.value) 22 | 23 | // shows how to extract the averaged value out from the accumulating data structure 24 | implicit val avToDouble: ToDouble[AveragedValue] = 25 | ToDouble.instance(_.value) 26 | 27 | // Module instance, representing a module that can scale AveragedValue by some scalar double. 28 | implicit val avModule: Module[Double, AveragedValue] = 29 | Module.from((r, av) => AveragedValue(av.count, r * av.value)) 30 | 31 | // easy, just expose this implicitly. 32 | implicit val realRing: Ring[Real] = RealRing 33 | 34 | // trivial VectorSpace, showing that the cats.Id monad (and any Ring R) form a vectorspace. 35 | implicit def idVectorSpace[R](implicit R: Ring[R]): VectorSpace[R, Id] = 36 | VectorSpace.from[R, Id](R.times(_, _)) 37 | 38 | // Ring instance for rainer Reals. 39 | object RealRing extends Ring[Real] { 40 | override def one = Real.one 41 | override def zero = Real.zero 42 | override def negate(v: Real) = -v 43 | override def plus(l: Real, r: Real) = l + r 44 | override def minus(l: Real, r: Real) = l - r 45 | override def times(l: Real, r: Real) = l * r 46 | } 47 | } 48 | 49 | /** Clamps a value between a minimum and maximum value. 50 | * 51 | * This function ensures that the input value `a` is not less than `min` and not greater than 52 | * `max`, returning the clamped value. 53 | * 54 | * @param a 55 | * The value to clamp. 
56 | * @param min 57 | * The minimum value. 58 | */ 59 | def clamp[A](a: A, min: A, max: A)(implicit ord: Ordering[A]): A = 60 | ord.min(ord.max(a, min), max) 61 | 62 | /** Creates a Map from a set of keys using a function to generate values. 63 | * 64 | * This function takes a set of keys and a function that maps each key to a value, returning a 65 | * Map with the keys and their corresponding values. 66 | * 67 | * @param keys 68 | */ 69 | def makeMap[K, V](keys: Set[K])(f: K => V): Map[K, V] = makeMapUnsafe(keys)(f) 70 | 71 | /** similar to makeMap, but doesn't guarantee that there are not duplicate keys. If keys contains 72 | * duplicates, later keys override earlier keys. 73 | * 74 | * @param keys 75 | */ 76 | def makeMapUnsafe[K, V](keys: TraversableOnce[K])(f: K => V): Map[K, V] = 77 | keys.foldLeft(Map.empty[K, V]) { case (m, k) => 78 | m.updated(k, f(k)) 79 | } 80 | 81 | /** Update the key in the supplied map using the function - the function handles both cases, when 82 | * the item is there and when it's not. 83 | */ 84 | def updateWith[K, V](m: Map[K, V], k: K)(f: Option[V] => V): Map[K, V] = 85 | m.updated(k, f(m.get(k))) 86 | 87 | /** Merges a key and a value into a map using a semigroup to combine values. */ 88 | def mergeV[K, V: Semigroup](m: Map[K, V], k: K, delta: V): Map[K, V] = 89 | updateWith(m, k) { 90 | case None => delta 91 | case Some(v) => Semigroup.plus[V](v, delta) 92 | } 93 | 94 | /** Finds the keys with the maximum values in a map. 95 | */ 96 | def maxKeys[A, B: Ordering](m: Map[A, B]): Set[A] = allMaxBy(m.keySet)(m(_)) 97 | 98 | /** Returns the set of keys that map (via `f`) to the maximal B, out of all `as` transformed. 99 | */ 100 | def allMaxBy[A, B: Ordering](as: Set[A])(f: A => B): Set[A] = 101 | if (as.isEmpty) Set.empty 102 | else { 103 | val maxB = f(as.maxBy(f)) 104 | as.filter(a => Ordering[B].equiv(maxB, f(a))) 105 | } 106 | 107 | /** Iterates a monadic function `f` `n` of times using the starting value `a`. 108 | */ 109 | def iterateM[M[_], A]( 110 | n: Int 111 | )(a: A)(f: A => M[A])(implicit M: Monad[M]): M[A] = 112 | M.iterateWhileM((n, a)) { case (k, a) => 113 | f(a).map((k - 1, _)) 114 | }(_._1 > 0) 115 | .map(_._2) 116 | 117 | /** A version of iterateUntilM that uses an aggregator to store the auxiliary results kicked out 118 | * by the step function. 119 | */ 120 | def iterateUntilM[M[_], A, B, C, D](init: A, agg: MonoidAggregator[B, C, D])( 121 | f: A => M[(A, B)] 122 | )(p: A => Boolean)(implicit M: Monad[M]): M[(A, D)] = 123 | M.iterateUntilM((init, agg.monoid.zero)) { case (a, c) => 124 | f(a).map { case (a2, b) => 125 | (a2, agg.append(c, b)) 126 | } 127 | }(pair => p(pair._1)) 128 | .map { case (a, c) => (a, agg.present(c)) } 129 | 130 | /** A version of iterateUntilM that uses a Fold to store the auxiliary results kicked out by the 131 | * step function. 132 | */ 133 | def foldUntilM[M[_], A, B, C](init: A, fold: Fold[B, C])( 134 | f: A => M[(A, B)] 135 | )(p: A => Boolean)(implicit M: Monad[M]): M[(A, C)] = { 136 | val foldState = fold.build() 137 | M.iterateUntilM((init, foldState.start)) { case (a, c) => 138 | f(a).map { case (a2, b) => 139 | (a2, foldState.add(c, b)) 140 | } 141 | }(pair => p(pair._1)) 142 | .map { case (a, c) => (a, foldState.end(c)) } 143 | } 144 | 145 | /** And a helper function that will let me test this out with monoid aggregators, like the ones I 146 | * wrote to walk trajectories. 
147 | */ 148 | def aggToFold[A, B, C](agg: MonoidAggregator[A, B, C]): Fold[A, C] = 149 | Fold.fold[B, A, C]( 150 | start = agg.monoid.zero, 151 | add = (b, a) => agg.monoid.plus(b, agg.prepare(a)), 152 | end = agg.present(_) 153 | ) 154 | 155 | /** A version of iterateWhileM that uses an aggregator to store the auxiliary results kicked out 156 | * by the step function. 157 | */ 158 | def iterateWhileM[M[_]: Monad, A, B, C, D]( 159 | init: A, 160 | agg: MonoidAggregator[B, C, D] 161 | )( 162 | f: A => M[(A, B)] 163 | )(p: A => Boolean): M[(A, D)] = 164 | iterateUntilM(init, agg)(f)(!p(_)) 165 | 166 | /** Unused for now... TODO try this out, get the interface going in state monad style! 167 | */ 168 | def runUntilM[M[_]: Monad, S, A, B, C]( 169 | state: StateT[M, S, A], 170 | agg: MonoidAggregator[A, B, C] 171 | )(p: S => Boolean): StateT[M, S, C] = 172 | StateT[M, S, C] { s => 173 | iterateUntilM(s, agg)(state.run(_))(p) 174 | } 175 | 176 | /** Accumulates differences between the two for every A in the supplied sequence. The combine 177 | * function is used to aggregate the differences. 178 | * 179 | * I recommend using max or +. 180 | */ 181 | def diff[A]( 182 | as: TraversableOnce[A], 183 | lf: A => Double, 184 | rf: A => Double, 185 | combine: (Double, Double) => Double 186 | ): Double = 187 | as.foldLeft(0.0) { (acc, k) => 188 | combine(acc, (lf(k) - rf(k)).abs) 189 | } 190 | 191 | /** Cats helpers. 192 | */ 193 | def idToMonad[M[_]](implicit M: Monad[M]): FunctionK[Id, M] = 194 | new FunctionK[Id, M] { 195 | def apply[A](a: A): M[A] = M.pure(a) 196 | } 197 | 198 | def mfk[M[_], N[_]](implicit M: Comonad[M], N: Monad[N]): FunctionK[M, N] = 199 | new FunctionK[M, N] { 200 | def apply[A](ma: M[A]): N[A] = N.pure(M.extract(ma)) 201 | } 202 | } 203 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/algebra/AffineCombination.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package algebra 3 | 4 | import cats.Id 5 | import com.twitter.algebird.Ring 6 | 7 | /** This is not currently used! Another attempt at a better thing, here... but I don't know if this 8 | * solves my problem of needing to compose up the stack. 9 | * 10 | * I had a note about this in [[Agent]]. 11 | */ 12 | trait AffineCombination[M[_], R] { 13 | implicit def ring: Ring[R] 14 | def get[A](ma: M[A])(f: A => R): R 15 | } 16 | 17 | object AffineCombination { 18 | // Contract is that if all A == R.one, and f = _ => R.one, the fn returns 19 | // R.one. 
20 | def take[A, R: Ring](items: Iterator[(A, R)])(f: A => R)(implicit 21 | R: Ring[R] 22 | ): R = 23 | R.sum(items.map { case (a, r) => R.times(f(a), r) }) 24 | 25 | @inline final def apply[M[_], R](implicit 26 | M: AffineCombination[M, R] 27 | ): AffineCombination[M, R] = M 28 | 29 | implicit def id[R](implicit R: Ring[R]): AffineCombination[Id, R] = 30 | new AffineCombination[Id, R] { 31 | implicit val ring = R 32 | def get[A](a: A)(f: A => R) = f(a) 33 | } 34 | 35 | implicit def fromDecomposition[M[_], R](implicit 36 | D: Decompose[M, R] 37 | ): AffineCombination[M, R] = 38 | new AffineCombination[M, R] { 39 | implicit def ring = D.ring 40 | def get[A](ma: M[A])(f: A => R) = take(D.decompose(ma))(f) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/algebra/Decompose.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package algebra 3 | 4 | import cats.Id 5 | import com.stripe.rainier.compute.Real 6 | import com.stripe.rainier.core.{Categorical => RCat} 7 | import com.twitter.algebird.{DoubleRing, Ring} 8 | import com.scalarl.rainier.Categorical 9 | 10 | trait Decompose[M[_], R] extends Serializable { 11 | def ring: Ring[R] 12 | def decompose[A](ma: M[A]): Iterator[(A, R)] 13 | } 14 | 15 | object Decompose { 16 | @inline final def apply[M[_], R](implicit 17 | W: Decompose[M, R] 18 | ): Decompose[M, R] = W 19 | 20 | implicit def id[R](implicit R: Ring[R]): Decompose[Id, R] = 21 | new Decompose[Id, R] { 22 | override def ring = R 23 | override def decompose[A](a: A): Iterator[(A, R)] = Iterator((a, R.one)) 24 | } 25 | 26 | implicit def rcatDouble(implicit n: Numeric[Real]): Decompose[RCat, Double] = 27 | new Decompose[RCat, Double] { 28 | override val ring = DoubleRing 29 | override def decompose[A](ma: RCat[A]): Iterator[(A, Double)] = 30 | ma.pmf.iterator.map { case (a, r) => (a, n.toDouble(r)) } 31 | } 32 | 33 | implicit val rcatReal: Decompose[RCat, Real] = 34 | new Decompose[RCat, Real] { 35 | override val ring = Util.Instances.RealRing 36 | override def decompose[A](ma: RCat[A]): Iterator[(A, Real)] = 37 | ma.pmf.iterator 38 | } 39 | 40 | implicit val categoricalDouble: Decompose[Categorical, Double] = 41 | new Decompose[Categorical, Double] { 42 | override val ring = DoubleRing 43 | override def decompose[A](ma: Categorical[A]): Iterator[(A, Double)] = 44 | ma.pmfSeq.iterator 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/algebra/Expectation.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package algebra 3 | 4 | import cats.Id 5 | 6 | /** This definitely works, but I need to think through how we're going to be able to return things 7 | * like Futures, that have to communicate over the network. Does the value return type cover it? 8 | * 9 | * NOTE the implementation is responsible for normalizing. 
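  *
  * A small sketch of the `Id` instance below (the expectation of a point distribution is just the
  * function applied at that point; assumes an implicit `Module[Double, Double]`, e.g. via
  * `Module.ringModule`):
  * {{{
  * Expectation[cats.Id].get(3.0)(_ * 2.0) // 6.0
  * }}}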
10 | */ 11 | trait Expectation[M[_]] { 12 | def get[A, B](a: M[A])(f: A => B)(implicit M: Module[Double, B]): B 13 | } 14 | 15 | object Expectation extends ExpectationImplicits { 16 | @inline final def apply[M[_]](implicit M: Expectation[M]): Expectation[M] = M 17 | 18 | implicit val id: Expectation[Id] = new Expectation[Id] { 19 | def get[A, B](a: A)(f: A => B)(implicit M: Module[Double, B]): B = f(a) 20 | } 21 | } 22 | 23 | trait ExpectationImplicits { 24 | implicit def fromDecomposition[M[_]](implicit 25 | D: Decompose[M, Double] 26 | ): Expectation[M] = 27 | new Expectation[M] { 28 | def get[A, B](a: M[A])(f: A => B)(implicit M: Module[Double, B]): B = 29 | M.group.sum( 30 | D.decompose(a).map { case (a, coef) => M.scale(coef, f(a)) } 31 | ) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/algebra/Module.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package algebra 3 | 4 | import com.twitter.algebird.{Group, Ring, VectorSpace} 5 | 6 | /** This class represents an abstract-algebraic "module". A module is a generalization of vector 7 | * spaces that allows scalars to come from a ring instead of a field. It consists of: 8 | * 9 | * - An abelian group (G, +) representing the elements that can be scaled 10 | * - A ring (R, +, *) representing the scalars 11 | * - A scaling operation R × G → G that satisfies: 12 | * - r(g₁ + g₂) = rg₁ + rg₂ (distributivity over group addition) 13 | * - (r₁ + r₂)g = r₁g + r₂g (distributivity over ring addition) 14 | * - (r₁r₂)g = r₁(r₂g) (compatibility with ring multiplication) 15 | * - 1g = g (identity scalar) 16 | * 17 | * For more details see: https://en.wikipedia.org/wiki/Module_(mathematics) 18 | */ 19 | object Module { 20 | // the default module! 21 | type DModule[T] = Module[Double, T] 22 | 23 | /** This method is used to get the default module for a given type. 24 | * 25 | * @param M 26 | * The module to get. 27 | * @return 28 | * The default module for the given type. 29 | */ 30 | @inline final def apply[R, G](implicit M: Module[R, G]): Module[R, G] = M 31 | 32 | /** supplies an implicit module, given an implicitly-available Ring for some type R. 33 | */ 34 | implicit def ringModule[R: Ring]: Module[R, R] = from(Ring.times(_, _)) 35 | 36 | /** Given an implicit ring and group, accepts a scaleFn that shows how to perform scalar 37 | * multiplication between elements of the ring and the group and returns a new module over R and 38 | * G. 39 | */ 40 | def from[R, G]( 41 | scaleFn: (R, G) => G 42 | )(implicit R: Ring[R], G: Group[G]): Module[R, G] = 43 | new Module[R, G] { 44 | override def ring = R 45 | override def group = G 46 | def scale(r: R, g: G) = 47 | if (R.isNonZero(r)) scaleFn(r, g) else G.zero 48 | } 49 | 50 | /* Algebird's vector space is generic on the container type C, and implicitly pulls in a group on 51 | C[F]. We are a little more general. 
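
     For example (a sketch only; the container C and its Algebird VectorSpace instance are
     hypothetical):

       val M: Module[Double, C[Double]] = Module.fromVectorSpace[Double, C]
       M.scale(2.0, cs) // scalar-multiplies cs via the underlying VectorSpace's scale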
52 | */ 53 | def fromVectorSpace[F, C[_]](implicit 54 | R: Ring[F], 55 | V: VectorSpace[F, C] 56 | ): Module[F, C[F]] = { 57 | implicit val g = V.group 58 | from[F, C[F]](V.scale(_, _)) 59 | } 60 | } 61 | 62 | trait Module[R, G] extends Serializable { 63 | implicit def ring: Ring[R] 64 | implicit def group: Group[G] 65 | def scale(r: R, g: G): G 66 | } 67 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/algebra/ToDouble.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package algebra 3 | 4 | import scala.{specialized => sp} 5 | 6 | /** Typeclass that encodes how some type A can be converted into a Double instance. 7 | */ 8 | trait ToDouble[@sp(Int, Long, Float, Double) A] extends Any with Serializable { 9 | self => 10 | def apply(a: A): Double 11 | 12 | def contramap[B](f: B => A): ToDouble[B] = new ToDouble[B] { 13 | def apply(b: B): Double = self.apply(f(b)) 14 | } 15 | } 16 | 17 | object ToDouble { 18 | 19 | /** Access an implicit `[[ToDouble]][A]`. 20 | */ 21 | @inline final def apply[A](implicit ev: ToDouble[A]): ToDouble[A] = ev 22 | 23 | /** Generates an instance of[[ToDouble]] from a pure function. 24 | */ 25 | @inline def instance[A](toDouble: A => Double): ToDouble[A] = 26 | new ToDouble[A] { 27 | override def apply(a: A): Double = toDouble(a) 28 | } 29 | 30 | /** The [[ToDouble]] instance for doubles uses the identity function. 31 | */ 32 | implicit val fromDouble: ToDouble[Double] = instance(d => d) 33 | 34 | /** Any type A that conforms to the Numeric typeclass can be converted to double via the toDouble 35 | * method on that typeclass. 36 | */ 37 | implicit def numericToDouble[A](implicit N: Numeric[A]): ToDouble[A] = 38 | instance(N.toDouble(_)) 39 | } 40 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/algebra/Weight.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package algebra 3 | 4 | import com.twitter.algebird.Monoid 5 | 6 | /** Value class that represents some Double-valued weight that can be applied to a type. 7 | */ 8 | case class Weight(w: Double) extends AnyVal { 9 | def +(r: Weight): Weight = Weight(w + r.w) 10 | def *(r: Weight): Weight = Weight(w * r.w) 11 | def /(r: Weight): Weight = Weight(w / r.w) 12 | def <(r: Weight): Boolean = w < r.w 13 | } 14 | 15 | object Weight { 16 | 17 | /** */ 18 | val One: Weight = Weight(1.0) 19 | 20 | /** */ 21 | val Zero: Weight = Weight(0.0) 22 | 23 | implicit val timesMonoid: Monoid[Weight] = Monoid.from(One)(_ * _) 24 | implicit val ord: Ordering[Weight] = Ordering.by(_.w) 25 | implicit val toDouble: ToDouble[Weight] = ToDouble.instance(_.w) 26 | } 27 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/evaluate/ActionValue.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package evaluate 3 | 4 | import com.scalarl.algebra.{Expectation, Module} 5 | import Module.DModule 6 | 7 | /** trait for evaluation of a given state, action pair. 
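  *
  * A usage sketch (hypothetical `myActionValueFn` and `state`; G fixed to Double so an Ordering is
  * available for `greedyOptions`):
  * {{{
  * val evaluator = ActionValue.fn[Obs, A, R, Double, cats.Id](myActionValueFn)
  * val bestActions: Set[A] = evaluator.greedyOptions(state)
  * }}}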
8 | */ 9 | sealed trait ActionValue[Obs, A, R, G, S[_]] extends Product with Serializable { 10 | 11 | def evaluate(state: State[Obs, A, R, S], a: A): G 12 | 13 | def greedyOptions(state: State[Obs, A, R, S])(implicit 14 | G: Ordering[G] 15 | ): Set[A] = 16 | Util.allMaxBy[A, G](state.actions)(evaluate(state, _)) 17 | 18 | def byPolicy[M[_]]( 19 | policy: Policy[Obs, A, R, M, S] 20 | )(implicit 21 | M: Expectation[M], 22 | MV: Module[Double, G] 23 | ): StateValue[Obs, A, R, G, S] = 24 | StateValue.ByPolicy(this, policy) 25 | } 26 | 27 | object ActionValue { 28 | def fn[Obs, A, R, G, S[_]]( 29 | f: ActionValueFn[Obs, A, G] 30 | ): ActionValue[Obs, A, R, G, S] = Fn(f) 31 | 32 | /** Evaluates the action's value directly. 33 | */ 34 | final case class Fn[Obs, A, R, G, S[_]](f: ActionValueFn[Obs, A, G]) 35 | extends ActionValue[Obs, A, R, G, S] { 36 | def evaluate(state: State[Obs, A, R, S], a: A): G = 37 | f.actionValue(state.observation, a) 38 | } 39 | 40 | /** The state under evaluation potentially offers dynamics 41 | */ 42 | final case class ByStateValue[Obs, A, R, G: DModule, S[_]: Expectation]( 43 | evaluator: StateValue[Obs, A, R, G, S], 44 | prepare: R => G, 45 | merge: (G, G) => G 46 | ) extends ActionValue[Obs, A, R, G, S] { 47 | def evaluate(state: State[Obs, A, R, S], a: A): G = 48 | Expectation[S].get(state.act(a)) { case (r, s) => 49 | merge(evaluator.evaluate(s), prepare(r)) 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/evaluate/StateValue.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package evaluate 3 | 4 | import com.scalarl.algebra.{Expectation, Module} 5 | import Module.DModule 6 | 7 | /** trait for evaluating a given state. 8 | */ 9 | sealed trait StateValue[Obs, A, R, G, S[_]] extends Product with Serializable { 10 | 11 | /** Returns an evaluation of the given state. 12 | */ 13 | def evaluate(state: State[Obs, A, R, S]): G 14 | 15 | /** Upgrades to evaluate given... what is going on? 16 | */ 17 | def byStateValue( 18 | prepare: R => G, 19 | merge: (G, G) => G 20 | )(implicit 21 | S: Expectation[S], 22 | MV: Module[Double, G] 23 | ): ActionValue[Obs, A, R, G, S] = 24 | ActionValue.ByStateValue(this, prepare, merge) 25 | } 26 | 27 | object StateValue { 28 | 29 | /** Returns a basic evaluator that uses a given state value function. 30 | */ 31 | def fn[Obs, A, R, G, S[_]]( 32 | f: StateValueFn[Obs, G] 33 | ): StateValue[Obs, A, R, G, S] = Fn(f) 34 | 35 | /** This evaluates the state's value directly. 36 | */ 37 | final case class Fn[Obs, A, R, G, S[_]](f: StateValueFn[Obs, G]) 38 | extends StateValue[Obs, A, R, G, S] { 39 | def evaluate(state: State[Obs, A, R, S]): G = 40 | f.stateValue(state.observation) 41 | } 42 | 43 | /** Evaluates the state's value by weighting evaluated action values by the policy's chance of 44 | * choosing each action. 
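    *
    * Schematically: V(s) = Σ_a π(a | s) · Q(s, a), where Q is the wrapped `evaluator` and π is the
    * supplied `policy`.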
45 | */ 46 | final case class ByPolicy[Obs, A, R, G: DModule, M[_]: Expectation, S[_]]( 47 | evaluator: ActionValue[Obs, A, R, G, S], 48 | policy: Policy[Obs, A, R, M, S] 49 | ) extends StateValue[Obs, A, R, G, S] { 50 | def evaluate(state: State[Obs, A, R, S]): G = 51 | Expectation[M].get(policy.choose(state)) { a => 52 | evaluator.evaluate(state, a) 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/logic/Episode.scala: -------------------------------------------------------------------------------- 1 | /** Logic for playing episodic games. 2 | */ 3 | package com.scalarl 4 | package logic 5 | 6 | import cats.{Functor, Monad} 7 | import cats.implicits._ 8 | 9 | object Episode { 10 | import cats.syntax.functor._ 11 | 12 | /** Wrapper around a combination of state and policy. A moment in time. this wraps up a common 13 | * thing that we interact with... 14 | */ 15 | case class Moment[Obs, A, R, M[_]]( 16 | policy: Policy[Obs, A, R, M, M], 17 | state: State[Obs, A, R, M] 18 | ) { 19 | def choice: M[A] = policy.choose(state) 20 | 21 | def act( 22 | a: A 23 | )(implicit M: Functor[M]): M[(Moment[Obs, A, R, M], SARS[Obs, A, R, M])] = 24 | state.act(a).map { case (r, s2) => 25 | val sars = SARS(state, a, r, s2) 26 | (Moment(policy.learn(sars), s2), sars) 27 | } 28 | 29 | /** Play a single round of a game. Returns M of: 30 | * 31 | * \- pair of (the new policy that's learned, the new state you end up in) \- triple of (state 32 | * you came from, action you took, reward you received). 33 | */ 34 | def play(implicit 35 | M: Monad[M] 36 | ): M[(Moment[Obs, A, R, M], SARS[Obs, A, R, M])] = 37 | policy.choose(state).flatMap(act) 38 | } 39 | 40 | /** Takes a policy and a starting state and returns an M containing the final policy, final state 41 | * and the trajectory that got us there. 42 | */ 43 | def playEpisode[Obs, A, R, M[_]: Monad, T]( 44 | moment: Moment[Obs, A, R, M], 45 | tracker: MonteCarlo.Tracker[Obs, A, R, T, M] 46 | ): M[(Moment[Obs, A, R, M], MonteCarlo.Trajectory[Obs, A, R, M])] = 47 | Util.iterateUntilM(moment, tracker)(_.play)(_.state.isTerminal) 48 | 49 | /** Specialized version of playEpisode that only updates every first time a state is seen. 50 | */ 51 | def firstVisit[Obs, A, R, M[_]: Monad]( 52 | moment: Moment[Obs, A, R, M] 53 | ): M[(Moment[Obs, A, R, M], MonteCarlo.Trajectory[Obs, A, R, M])] = 54 | Episode.playEpisode(moment, MonteCarlo.Tracker.firstVisit) 55 | 56 | // Below this we have the functions that have been useful for tracking bandit 57 | // problems. I wonder if there is some nice primitive we can develop for 58 | // clicking many agents forward at once. Is that an interesting thing to do? 59 | 60 | /** Takes a list of policy, initial state pairs and plays a single episode of a game with each of 61 | * them. 62 | */ 63 | def playMany[Obs, A, R, M[_]: Monad]( 64 | moments: List[Moment[Obs, A, R, M]] 65 | )( 66 | rewardSum: List[SARS[Obs, A, R, M]] => R 67 | ): M[(List[Moment[Obs, A, R, M]], R)] = 68 | moments.traverse(_.play).map { results => 69 | ( 70 | // this could actually build a nice trajectory for many items at once. 71 | results.map(_._1), 72 | rewardSum(results.map(_._2)) 73 | ) 74 | } 75 | 76 | /** Takes an initial set of policies and astate... we could definitely adapt this to do some 77 | * serious learning on the policies, and use the MonoidAggregator stuff. 
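    *
    * A rough sketch of a call (hypothetical `moments`, with M = cats.Id and Double rewards):
    * {{{
    * val (finalMoments, rewardPerRound) =
    *   playManyN(moments, nTimes = 100)(sarsList => sarsList.map(_.reward).sum)
    * }}}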
78 | */ 79 | def playManyN[Obs, A, R, M[_]: Monad]( 80 | moments: List[Moment[Obs, A, R, M]], 81 | nTimes: Int 82 | )( 83 | rewardSum: List[SARS[Obs, A, R, M]] => R 84 | ): M[(List[Moment[Obs, A, R, M]], List[R])] = 85 | Util.iterateM(nTimes)((moments, List.empty[R])) { case (ps, rs) => 86 | playMany(ps)(rewardSum).map { case (newMoment, r) => 87 | (newMoment, rs :+ r) 88 | } 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/logic/MonteCarlo.scala: -------------------------------------------------------------------------------- 1 | /** Logic for playing episodic games. 2 | */ 3 | package com.scalarl 4 | package logic 5 | 6 | import cats.Monad 7 | import cats.implicits._ 8 | import com.twitter.algebird.{Aggregator, Monoid, MonoidAggregator} 9 | import com.scalarl.algebra.Weight 10 | import com.scalarl.util.FrequencyTracker 11 | import scala.annotation.tailrec 12 | 13 | object MonteCarlo { 14 | import Episode.Moment 15 | 16 | case class ShouldUpdateState(get: Boolean) extends AnyVal 17 | object ShouldUpdateState { 18 | val yes = ShouldUpdateState(true) 19 | val no = ShouldUpdateState(false) 20 | } 21 | 22 | // This is an iterator of SARS observations, starting at the very beginning, 23 | // paired with a note about whether or not that observation should trigger a 24 | // state update. This is primarily interesting for distinguishing between 25 | // first and every visit updates, or something in between. 26 | type Trajectory[Obs, A, R, M[_]] = 27 | Iterator[(SARS[Obs, A, R, M], ShouldUpdateState)] 28 | 29 | // The T type here is type that's used to aggregate the trajectory. 30 | type Tracker[Obs, A, R, T, M[_]] = 31 | MonoidAggregator[SARS[Obs, A, R, M], T, Trajectory[Obs, A, R, M]] 32 | 33 | object Tracker { 34 | // Use a frequency tracker to aggregate; this allows us to count backwards 35 | // from the end and figure out when the first time we saw a particular state 36 | // was. 37 | type FirstVisit[Obs, A, R, M[_]] = 38 | Tracker[Obs, A, R, FrequencyTracker[SARS[Obs, A, R, M], Obs], M] 39 | 40 | // If we don't care we can accumulate the trajectory the state using a vector. 41 | // 42 | // This should be equivalent to using a frequency tracker but ignoring 43 | // whether or not we've seen the state. 44 | type EveryVisit[Obs, A, R, M[_]] = 45 | Tracker[Obs, A, R, Vector[SARS[Obs, A, R, M]], M] 46 | 47 | /** Returns a Tracker instance that will generate a trajectory where ShouldUpdateState is only 48 | * true the first time in the trajectory a state is encountered. 49 | */ 50 | def firstVisit[Obs, A, R, M[_]]: FirstVisit[Obs, A, R, M] = { 51 | implicit val m = 52 | FrequencyTracker.monoid[SARS[Obs, A, R, M], Obs](_.state.observation) 53 | Aggregator.appendMonoid( 54 | appnd = _ :+ _, 55 | pres = _.reverseIterator.map { case (t, seen) => 56 | (t, ShouldUpdateState(seen == 0)) 57 | } 58 | ) 59 | } 60 | 61 | /** Returns a Tracker instance where ShouldUpdateState signals YES for every single state. 62 | */ 63 | def everyVisit[Obs, A, R, M[_]]: EveryVisit[Obs, A, R, M] = 64 | Aggregator.appendMonoid( 65 | appnd = _ :+ _, 66 | pres = _.reverseIterator.map((_, ShouldUpdateState.yes)) 67 | ) 68 | } 69 | 70 | // TODO the next phase is to try and get n step SARSA working, maybe with an 71 | // expected bump at the end. And this has to work AS we're building the 72 | // trajectory. 
73 | // 74 | // The trick here is that I need to freaking get access to the trajectory 75 | // itself and start to build up that business. 76 | // 77 | // Leave this for a while... but I think a key thing is going to be getting 78 | // access to the trajectory as I'm walking around this shit. 79 | def sarsa[Obs, A, R, M[_]: Monad, T]( 80 | moment: Moment[Obs, A, R, M], 81 | tracker: MonteCarlo.Tracker[Obs, A, R, T, M] 82 | ): M[(Moment[Obs, A, R, M], MonteCarlo.Trajectory[Obs, A, R, M])] = 83 | Util.iterateUntilM(moment, tracker) { case m @ Moment(policy, state) => 84 | m.play 85 | }(_.state.isTerminal) 86 | 87 | /** So if you have G, your return... okay, this is a version that tracks the weights, but doesn't 88 | * give you a nice way to push the weights back. What if we make the weight part of G? Try that 89 | * in the next fn. 90 | * 91 | * This is a full monte carlo trajectory tracker that's able to do off-policy control. The 92 | * behavior policy does NOT change at all, but that's okay, I guess. We're going to have to solve 93 | * that now. Presumably if you're updating a value function at any point you could get a new 94 | * agent. 95 | */ 96 | def processTrajectory[Obs, A, R, G, M[_]]( 97 | trajectory: Trajectory[Obs, A, R, M], 98 | valueFn: ActionValueFn[Obs, A, G], 99 | agg: MonoidAggregator[SARS[Obs, A, R, M], G, Option[G]] 100 | ): ActionValueFn[Obs, A, G] = { 101 | 102 | @tailrec 103 | def loop( 104 | t: Trajectory[Obs, A, R, M], 105 | vfn: ActionValueFn[Obs, A, G], 106 | g: G 107 | ): ActionValueFn[Obs, A, G] = 108 | if (t.isEmpty) vfn 109 | else { 110 | val (sars, shouldUpdate) = t.next 111 | agg.present(agg.append(g, sars)) match { 112 | case None => vfn 113 | case Some(g2) => 114 | val newFn = if (shouldUpdate.get) { 115 | val SARS(s, a, r, s2) = sars 116 | vfn.update(s.observation, a, g2) 117 | } else vfn 118 | loop(t, newFn, g2) 119 | } 120 | } 121 | // I think we HAVE to start with zero here, since we always have some sort 122 | // of zero value for the final state, even if we use a new aggregation type. 123 | loop(trajectory, valueFn, agg.monoid.zero) 124 | } 125 | 126 | /** This is a simpler version that doesn't do any weighting. This should be equivalent to the more 127 | * difficult one above, with a constant weight of 1 for everything. 128 | */ 129 | def processTrajectorySimple[Obs, A, R, G, M[_]]( 130 | trajectory: Trajectory[Obs, A, R, M], 131 | valueFn: ActionValueFn[Obs, A, G], 132 | agg: MonoidAggregator[R, G, G] 133 | ): ActionValueFn[Obs, A, G] = 134 | // I think we HAVE to start with zero here, since we always have some sort 135 | // of zero value for the final state, even if we use a new aggregation type. 136 | trajectory 137 | .foldLeft((valueFn, agg.monoid.zero)) { case ((vf, g), (SARS(s, a, r, s2), shouldUpdate)) => 138 | val g2 = agg.append(g, r) 139 | if (shouldUpdate.get) { 140 | (vf.update(s.observation, a, agg.present(g2)), g2) 141 | } else (vf, g2) 142 | } 143 | ._1 144 | 145 | // generates a monoid aggregator that can handle weights! We'll need to pair 146 | // this with a value function that knows how to handle weights on the way in, 147 | // by keeping a count for each state, and handling the weight multiplication, 148 | // that sort of thing. 
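  // A rough construction sketch (hypothetical Obs/A types; assumes Algebird's
  // Aggregator.fromMonoid is available):
  //
  //   val gAgg = Aggregator.fromMonoid[Double]                      // sum up raw Double rewards
  //   val agg  = weighted(gAgg, constant[Obs, A, Double, cats.Id])  // constant Weight.One per step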
149 | def weighted[Obs, A, R, G, M[_]]( 150 | agg: MonoidAggregator[R, G, G], 151 | fn: SARS[Obs, A, R, M] => Weight 152 | ): MonoidAggregator[SARS[Obs, A, R, M], (G, Weight), Option[(G, Weight)]] = { 153 | implicit val m: Monoid[G] = agg.monoid 154 | Aggregator 155 | .appendMonoid[SARS[Obs, A, R, M], (G, Weight)] { case ((g, w), sars) => 156 | (agg.append(g, sars.reward), w * fn(sars)) 157 | } 158 | .andThenPresent { 159 | case (g, Weight.Zero) => None 160 | case pair => Some(pair) 161 | 162 | } 163 | } 164 | 165 | // generates a function that uses two policies to assign a weight. This is 166 | // input into the stuff above. 167 | def byPolicy[Obs, A, R, M[_]]( 168 | basePolicy: Policy[Obs, A, R, Cat, M], 169 | targetPolicy: Policy[Obs, A, R, Cat, M] 170 | ): (State[Obs, A, R, M], A, R) => Weight = { case (s, a, _) => 171 | val num = targetPolicy.choose(s).pmf(a) 172 | val denom = basePolicy.choose(s).pmf(a) 173 | Weight(num / denom) 174 | } 175 | 176 | // function that always returns a weight of 1. 177 | def constant[Obs, A, R, M[_]]: SARS[Obs, A, R, M] => Weight = _ => Weight.One 178 | } 179 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/logic/Sweep.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package logic 3 | 4 | import cats.{Id, Monad} 5 | import com.scalarl.algebra.{Expectation, Module, ToDouble} 6 | import com.scalarl.evaluate.StateValue 7 | 8 | object Sweep { 9 | import Module.DModule 10 | 11 | sealed trait Update extends Product with Serializable 12 | object Update { 13 | final case object Single 14 | final case object SweepComplete 15 | } 16 | 17 | /** This sweeps across the whole state space and updates the policy every single time IF you set 18 | * valueIteration to true. Otherwise it creates a policy once and then uses it each time. 19 | * 20 | * What we really want is the ability to ping between updates to the value function or learning 21 | * steps; to insert them every so often. 22 | * 23 | * This function does NOT currently return the final policy, since you can just make it yourself, 24 | * given the return value and the function. 
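    *
    * Loosely, each sweep does the following (pseudocode; the real fold also refreshes the policy
    * and evaluator depending on `valueIteration` and `inPlace`):
    * {{{
    * states.foldLeft(valueFn) { (vf, state) =>
    *   vf.update(state.observation, evaluator.evaluate(state))
    * }
    * }}}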
25 | */ 26 | def sweep[Obs, A, R, T, M[_]: Expectation, S[_]: Expectation]( 27 | valueFn: StateValueFn[Obs, T], 28 | policyFn: StateValueFn[Obs, T] => Policy[Obs, A, R, M, S], 29 | evaluatorFn: ( 30 | StateValueFn[Obs, T], 31 | Policy[Obs, A, R, M, S] 32 | ) => StateValue[Obs, A, R, T, S], 33 | states: Traversable[State[Obs, A, R, S]], 34 | inPlace: Boolean, 35 | valueIteration: Boolean 36 | ): StateValueFn[Obs, T] = 37 | states 38 | .foldLeft( 39 | (valueFn, evaluatorFn(valueFn, policyFn(valueFn)), policyFn(valueFn)) 40 | ) { case ((vf, ev, p), state) => 41 | val newFn = vf.update(state.observation, ev.evaluate(state)) 42 | val newPolicy = if (valueIteration) policyFn(newFn) else p 43 | val newEv = if (inPlace) evaluatorFn(newFn, newPolicy) else ev 44 | (newFn, newEv, newPolicy) 45 | } 46 | ._1 47 | 48 | def sweepUntil[Obs, A, R, T, M[_]: Expectation, S[_]: Expectation]( 49 | valueFn: StateValueFn[Obs, T], 50 | policyFn: StateValueFn[Obs, T] => Policy[Obs, A, R, M, S], 51 | evaluatorFn: ( 52 | StateValueFn[Obs, T], 53 | Policy[Obs, A, R, M, S] 54 | ) => StateValue[Obs, A, R, T, S], 55 | states: Traversable[State[Obs, A, R, S]], 56 | stopFn: (StateValueFn[Obs, T], StateValueFn[Obs, T], Long) => Boolean, 57 | inPlace: Boolean, 58 | valueIteration: Boolean 59 | ): (StateValueFn[Obs, T], Long) = 60 | Monad[Id].tailRecM((valueFn, 0L)) { case (fn, nIterations) => 61 | val updated = 62 | sweep(fn, policyFn, evaluatorFn, states, inPlace, valueIteration) 63 | Either.cond( 64 | stopFn(fn, updated, nIterations), 65 | (updated, nIterations), 66 | (updated, nIterations + 1) 67 | ) 68 | } 69 | 70 | // TODO - this probably needs to take evaluators directly. 71 | def isPolicyStable[Obs, A, R, T: DModule: Ordering, M[_], S[_]: Expectation]( 72 | l: StateValueFn[Obs, T], 73 | r: StateValueFn[Obs, T], 74 | prepare: R => T, 75 | merge: (T, T) => T, 76 | states: Traversable[State[Obs, A, R, S]] 77 | ): Boolean = { 78 | val lEvaluator = Evaluator.oneAhead[Obs, A, R, T, M, S](l, prepare, merge) 79 | val rEvaluator = Evaluator.oneAhead[Obs, A, R, T, M, S](r, prepare, merge) 80 | states.forall { s => 81 | lEvaluator.greedyOptions(s) == rEvaluator.greedyOptions(s) 82 | } 83 | } 84 | 85 | /** Helper to tell if we can stop iterating. The combine function is used to aggregate the 86 | * differences between the value functions for each observation... the final aggregated value 87 | * must be less than epsilon to return true, false otherwise. 88 | */ 89 | def diffBelow[Obs, T: ToDouble]( 90 | l: StateValueFn[Obs, T], 91 | r: StateValueFn[Obs, T], 92 | epsilon: Double 93 | )( 94 | combine: (Double, Double) => Double 95 | ): Boolean = Ordering[Double].lt( 96 | diffValue(l, r, combine), 97 | epsilon 98 | ) 99 | 100 | /** TODO consider putting this on the actual trait. 101 | */ 102 | def diffValue[Obs, T]( 103 | l: StateValueFn[Obs, T], 104 | r: StateValueFn[Obs, T], 105 | combine: (Double, Double) => Double 106 | )(implicit T: ToDouble[T]): Double = 107 | Util.diff[Obs]( 108 | l.seen ++ r.seen, 109 | o => T(l.stateValue(o)), 110 | o => T(r.stateValue(o)), 111 | combine 112 | ) 113 | } 114 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/package.scala: -------------------------------------------------------------------------------- 1 | package com 2 | 3 | /** Functional reinforcement learning in Scala. 
4 | */ 5 | package object scalarl { 6 | 7 | /** Type alias for [[com.scalarl.rainier.Categorical]], which represents a finite discrete 8 | * probability distribution. 9 | */ 10 | type Cat[+T] = rainier.Categorical[T] 11 | } 12 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/policy/Gradient.scala: -------------------------------------------------------------------------------- 1 | /** Policy that accumulates using the Gradient. 2 | */ 3 | package com.scalarl 4 | package policy 5 | package bandit 6 | 7 | import com.twitter.algebird.{Aggregator, AveragedValue, Monoid, Semigroup} 8 | import com.scalarl.algebra.ToDouble 9 | import com.scalarl.rainier.Categorical 10 | 11 | /** This thing needs to track its average reward internally... then, if we have the gradient 12 | * baseline set, use that thing to generate the notes. 13 | * 14 | * T is the "average" type. 15 | */ 16 | case class Gradient[Obs, A: Equiv, R: ToDouble, T: ToDouble, S[_]]( 17 | config: Gradient.Config[R, T], 18 | valueFn: ActionValueFn[Obs, A, Gradient.Item[T]] 19 | ) extends Policy[Obs, A, R, Cat, S] { 20 | 21 | /** Let's try out this style for a bit. This gives us a way to convert an action directly into a 22 | * probability, using our actionValue Map above. 23 | */ 24 | def aToDouble(obs: Obs): ToDouble[A] = 25 | Gradient.Item 26 | .itemToDouble[T] 27 | .contramap[A]( 28 | valueFn.actionValue(obs, _) 29 | ) 30 | 31 | override def choose(state: State[Obs, A, R, S]): Cat[A] = { 32 | implicit val at = aToDouble(state.observation) 33 | Categorical.softmax(state.actions) 34 | } 35 | 36 | override def learn(sars: SARS[Obs, A, R, S]): This = { 37 | val SARS(state, action, reward, nextState) = sars 38 | 39 | val pmf = choose(state).pmf 40 | val obs = state.observation 41 | 42 | val updated = state.actions.foldLeft(valueFn) { case (vfn, a) => 43 | // the new item has to get bootstrapped with the old value... that is 44 | // SORT of associative, and works. Test soon. 45 | val old = valueFn.actionValue(state.observation, a).t 46 | 47 | // get the delta, 48 | val delta = ToDouble[R].apply(reward) - ToDouble[T].apply(old) 49 | 50 | // then there might be some nicer way of doing this. 51 | val actionProb = 52 | if (Equiv[A].equiv(a, action)) 53 | -pmf(a) 54 | else 55 | 1 - pmf(a) 56 | 57 | // this is definitely the baseline. 58 | val newItem = Gradient.Item( 59 | actionProb * delta * config.stepSize, 60 | config.prepare(reward) 61 | ) 62 | vfn.update(obs, a, newItem) 63 | } 64 | copy(valueFn = updated) 65 | } 66 | } 67 | 68 | object Gradient { 69 | import Util.Instances.avToDouble 70 | 71 | object Item { 72 | class ItemSemigroup[T](implicit T: Semigroup[T]) extends Semigroup[Item[T]] { 73 | override def plus(l: Item[T], r: Item[T]): Item[T] = 74 | Item(l.q + r.q, T.plus(l.t, r.t)) 75 | } 76 | 77 | // Monoid instance, not used for now but meaningful, I think. 78 | class ItemMonoid[T](implicit T: Monoid[T]) extends ItemSemigroup[T] with Monoid[Item[T]] { 79 | override val zero: Item[T] = Item(0, T.zero) 80 | } 81 | 82 | // implicit instances. 83 | implicit def semigroup[T: Semigroup]: Semigroup[Item[T]] = 84 | new ItemSemigroup[T] 85 | implicit def ord[T: Ordering]: Ordering[Item[T]] = Ordering.by(_.t) 86 | implicit def monoid[T: Monoid] = new ItemMonoid[T] 87 | implicit def itemToDouble[T]: ToDouble[Item[T]] = ToDouble.instance(_.q) 88 | } 89 | 90 | /** Represents an action value AND some sort of accumulated value. 
The action value is something 91 | * we get by aggregating a reward in some way. 92 | * 93 | * You might just sum, which would be goofy; you might do some averaged value, or exponentially 94 | * decaying average. 95 | * 96 | * The t is the reward aggregator. The q is the item that's getting updated in this funky way. 97 | * 98 | * So how would you write a semigroup for this? You'd have to semigroup combine the T... what is 99 | * the monoid on the q? 100 | */ 101 | case class Item[T](q: Double, t: T) 102 | 103 | /** Holds properties necessary to run the gradient algorithm. 104 | */ 105 | case class Config[R: ToDouble, T: ToDouble]( 106 | initial: T, 107 | stepSize: Double, 108 | prepare: R => T, 109 | plus: (T, T) => T 110 | ) { 111 | implicit val m: Monoid[T] = Monoid.from(initial)(plus) 112 | 113 | /** Generates an actual policy from the supplied config. 114 | */ 115 | def policy[Obs, A, S[_]]: Gradient[Obs, A, R, T, S] = 116 | Gradient(this, ActionValueFn.mergeable[Obs, A, Item[T]]) 117 | } 118 | 119 | /** Hand-selected version that uses AveragedValue to accumulate internally. 120 | */ 121 | def incrementalConfig( 122 | stepSize: Double, 123 | initial: Double = 0.0 124 | ): Config[Double, AveragedValue] = 125 | Config(AveragedValue(initial), stepSize, AveragedValue(_), _ + _) 126 | 127 | /** Uses NO averaging baseline. 128 | */ 129 | def noBaseline(stepSize: Double): Config[Double, Unit] = 130 | fromAggregator(stepSize, (), Aggregator.const(0.0)) 131 | 132 | /** Generate this gradient from some aggregator. 133 | */ 134 | def fromAggregator[R: ToDouble, T]( 135 | stepSize: Double, 136 | initial: T, 137 | agg: Aggregator[R, T, Double] 138 | ): Config[R, T] = { 139 | implicit val tToDouble: ToDouble[T] = ToDouble.instance(agg.present(_)) 140 | Config( 141 | initial, 142 | stepSize, 143 | agg.prepare(_), 144 | agg.semigroup.plus(_, _) 145 | ) 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/policy/Greedy.scala: -------------------------------------------------------------------------------- 1 | /** First crack at a policy that is actually greedy with respect to some action value function. 2 | * 3 | * Because this only has access to states, to do any updating it needs to be able to either look 4 | * ahead, or to see the dynamics of the system. 5 | * 6 | * Both of those ideas are implemented below. 7 | * 8 | * TODO Eval is actually a nice interface for only being able to look ahead so far. If it's a Now, 9 | * you can look directly in. But then you can't look further. That'll come in handy later when we 10 | * try to make games, etc. I can imagine some data type that makes it difficult to see, of course. 11 | * And then your best guess has to involve some knowledge of where you might get to, even if you 12 | * don't know the reward. 13 | */ 14 | package com.scalarl 15 | package policy 16 | 17 | import cats.{Id, Monad} 18 | import com.scalarl.algebra.{Expectation, Module} 19 | import com.scalarl.evaluate.ActionValue 20 | import com.scalarl.rainier.Categorical 21 | 22 | /** Base logic for greedy policies. 
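  *
  * With probability `epsilon` the policy explores uniformly over `state.actions`; otherwise it
  * picks uniformly among the evaluator's greedy actions. A construction sketch (hypothetical
  * `MyObs`/`MyAction` types and `stateValueFn`; assumes an implicit `Module[Double, Double]`):
  * {{{
  * val policy = Greedy
  *   .Config[Double, Double](epsilon = 0.1, prepare = identity, merge = _ + _, default = 0.0)
  *   .id[MyObs, MyAction](stateValueFn)
  * }}}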
23 | */ 24 | class Greedy[Obs, A, R, T: Ordering, S[_]]( 25 | evaluator: ActionValue[Obs, A, R, T, S], 26 | epsilon: Double 27 | ) extends Policy[Obs, A, R, Cat, S] { self => 28 | 29 | private val explore: Cat[Boolean] = 30 | Categorical.boolean(epsilon) 31 | 32 | private def allActions(state: State[Obs, A, R, S]): Cat[A] = 33 | Categorical.fromSet(state.actions) 34 | 35 | private def greedy(state: State[Obs, A, R, S]): Cat[A] = 36 | Categorical.fromSet(evaluator.greedyOptions(state)) 37 | 38 | override def choose(state: State[Obs, A, R, S]): Cat[A] = 39 | Monad[Cat].ifM(explore)(allActions(state), greedy(state)) 40 | } 41 | 42 | object Greedy { 43 | import Module.DModule 44 | 45 | case class Config[R, T: DModule: Ordering]( 46 | epsilon: Double, 47 | prepare: R => T, 48 | merge: (T, T) => T, 49 | default: T 50 | ) { 51 | def id[Obs, A](valueFn: StateValueFn[Obs, T]): Policy[Obs, A, R, Cat, Id] = 52 | policy(valueFn) 53 | 54 | def stochastic[Obs, A]( 55 | valueFn: StateValueFn[Obs, T] 56 | ): Policy[Obs, A, R, Cat, Cat] = 57 | policy(valueFn) 58 | 59 | def policy[Obs, A, S[_]: Expectation]( 60 | valueFn: StateValueFn[Obs, T] 61 | ): Policy[Obs, A, R, Cat, S] = 62 | new Greedy[Obs, A, R, T, S]( 63 | Evaluator.oneAhead(valueFn, prepare, merge), 64 | epsilon 65 | ) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/policy/UCB.scala: -------------------------------------------------------------------------------- 1 | /** Policy that accumulates using the UCB algorithm. 2 | * 3 | * TODO should I make an Empty Choice option with a sealed trait? 4 | */ 5 | package com.scalarl 6 | package policy 7 | package bandit 8 | 9 | import com.twitter.algebird.{Aggregator, Monoid, Semigroup} 10 | import com.scalarl.evaluate.ActionValue 11 | import com.scalarl.rainier.Categorical 12 | 13 | case class UCB[Obs, A, R, T, S[_]]( 14 | config: UCB.Config[R, T], 15 | valueFn: ActionValueFn[Obs, A, UCB.Choice[T]], 16 | time: Time 17 | ) extends Policy[Obs, A, R, Cat, S] { 18 | 19 | private val evaluator: ActionValue[Obs, A, R, UCB.Choice[T], S] = 20 | valueFn.toEvaluator 21 | 22 | override def choose(state: State[Obs, A, R, S]): Cat[A] = 23 | Categorical.fromSet( 24 | Util 25 | .allMaxBy(state.actions)( 26 | evaluator.evaluate(state, _).totalValue(time) 27 | ) 28 | ) 29 | 30 | /** learn here passes directly through to the ActionValueFn now, which is the new thing. Does this 31 | * mean that we shouldn't learn at all? Should that get delegated to an agent? 32 | */ 33 | override def learn(sars: SARS[Obs, A, R, S]): This = 34 | copy( 35 | valueFn = valueFn.update( 36 | sars.state.observation, 37 | sars.action, 38 | config.choice(sars.reward) 39 | ), 40 | time = time.tick 41 | ) 42 | } 43 | 44 | object UCB { 45 | 46 | /** Generates a Config instance from an algebird Aggregator and a UCB parameter. 47 | */ 48 | def fromAggregator[R, T]( 49 | initial: T, 50 | param: Param, 51 | agg: Aggregator[R, T, Double] 52 | ): Config[R, T] = 53 | Config(param, initial, agg.prepare _, agg.semigroup.plus _, agg.present _) 54 | 55 | case class Config[R, T]( 56 | param: Param, 57 | initial: T, 58 | prepare: R => T, 59 | plus: (T, T) => T, 60 | present: T => Double 61 | ) { 62 | 63 | /** Returns a fresh policy instance using this config. 
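      *
      * For example (a sketch; `MyObs`/`MyAction` are hypothetical, and the Algebird aggregator here
      * simply sums Double rewards):
      * {{{
      * val ucb = UCB
      *   .fromAggregator(0.0, UCB.Param(2), Aggregator.fromMonoid[Double])
      *   .policy[MyObs, MyAction, cats.Id]
      * }}}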
64 | */ 65 | def policy[Obs, A, S[_]]: UCB[Obs, A, R, T, S] = { 66 | implicit val tMonoid: Monoid[T] = Monoid.from(initial)(plus) 67 | implicit val monoid: Monoid[Choice[T]] = Choice.monoid(param, present) 68 | val avm = ActionValueFn.mergeable[Obs, A, Choice[T]] 69 | UCB(this, avm, Time.Zero) 70 | } 71 | 72 | // These are private and embedded in the config to make it easy to 73 | // share the fns without crossing the beams. 74 | private[scalarl] def merge(choice: Choice[T], r: R) = 75 | choice.update(plus(_, prepare(r))) 76 | private[scalarl] def choice(r: R): Choice[T] = 77 | Choice.one(prepare(r), param)(present) 78 | 79 | def initialChoice: Choice[T] = Choice.zero(initial, param)(present) 80 | } 81 | 82 | /** Tunes how important the upper confidence bound business is. 83 | */ 84 | case class Param(c: Int) extends AnyVal 85 | 86 | /** Needs documentation; this is a way of tracking how many times a particular thing was chosen 87 | * along with its value. 88 | */ 89 | object Choice { 90 | // Classes... 91 | class ChoiceSemigroup[T](implicit T: Semigroup[T]) extends Semigroup[Choice[T]] { 92 | override def plus(l: Choice[T], r: Choice[T]): Choice[T] = 93 | l.copy(t = T.plus(l.t, r.t), visits = l.visits + r.visits) 94 | } 95 | 96 | // Monoid instance, not used for now but meaningful, I think. 97 | class ChoiceMonoid[T](param: Param, toDouble: T => Double)(implicit 98 | T: Monoid[T] 99 | ) extends ChoiceSemigroup[T] 100 | with Monoid[Choice[T]] { 101 | override val zero: Choice[T] = 102 | Choice.zero[T](T.zero, param)(toDouble) 103 | } 104 | 105 | // implicit instances. 106 | implicit def semigroup[T: Semigroup]: Semigroup[Choice[T]] = 107 | new ChoiceSemigroup[T] 108 | implicit def ord[T: Ordering]: Ordering[Choice[T]] = Ordering.by(_.t) 109 | 110 | def monoid[T: Monoid](param: Param, toDouble: T => Double) = 111 | new ChoiceMonoid[T](param, toDouble) 112 | 113 | // constructors. 114 | def zero[T](initial: T, param: Param)(toDouble: T => Double): Choice[T] = 115 | Choice(initial, 0L, param, toDouble) 116 | 117 | def one[T](t: T, param: Param)(toDouble: T => Double): Choice[T] = 118 | Choice(t, 1L, param, toDouble) 119 | } 120 | 121 | /** Tracks the info required for the UCB calculation. 122 | */ 123 | case class Choice[T]( 124 | t: T, 125 | visits: Long, 126 | param: Param, 127 | toDouble: T => Double 128 | ) { 129 | 130 | /** Updates the contained value, increments the visits. 131 | */ 132 | def update(f: T => T): Choice[T] = 133 | copy(f(t), visits + 1, param, toDouble) 134 | 135 | def totalValue(time: Time): Double = 136 | if (visits <= 0) toDouble(t) 137 | else toDouble(t) + bonus(time) 138 | 139 | def compare(other: Choice[T], time: Time): Int = 140 | (this.visits, other.visits) match { 141 | case (0L, 0L) => 0 142 | case (0L, _) => 1 143 | case (_, 0L) => -1 144 | case _ => 145 | Ordering[Double].compare( 146 | totalValue(time), 147 | other.totalValue(time) 148 | ) 149 | } 150 | 151 | // Only called if visits is > 0. 152 | private def bonus(time: Time): Double = 153 | param.c * math.sqrt(math.log(time.value + 1)) / visits 154 | } 155 | } 156 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/policy/bandit/Greedy.scala: -------------------------------------------------------------------------------- 1 | /** Policy that accumulates via epsilon greedy. This is only still here because it knows how to 2 | * learn. 
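  *
  * A construction sketch (hypothetical `MyObs`/`MyAction`; epsilon-greedy over sample-average
  * action values):
  * {{{
  * val bandit = Greedy.incrementalConfig(epsilon = 0.1).policy[MyAction, MyObs, cats.Id]
  * }}}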
3 | */ 4 | package com.scalarl 5 | package policy 6 | package bandit 7 | 8 | import cats.{Functor, Monad} 9 | import com.twitter.algebird.{AveragedValue, Semigroup} 10 | import com.scalarl.rainier.Categorical 11 | import Util.Instances._ 12 | 13 | /** @param epsilon 14 | * number between 0 and 1. 15 | */ 16 | case class Greedy[Obs, A, R, T: Ordering, S[_]]( 17 | config: Greedy.Config[R, T], 18 | valueFn: ActionValueFn[Obs, A, T] 19 | ) extends Policy[Obs, A, R, Cat, S] { 20 | implicit val functor: Functor[Cat] = Functor[Cat] 21 | 22 | private val explore: Cat[Boolean] = 23 | Categorical.boolean(config.epsilon) 24 | 25 | private def allActions(state: State[Obs, A, R, S]): Cat[A] = 26 | Categorical.fromSet(state.actions) 27 | 28 | private def greedy(state: State[Obs, A, R, S]): Cat[A] = { 29 | val obs = state.observation 30 | Categorical.fromSet( 31 | Util.allMaxBy(state.actions)( 32 | valueFn.actionValue(obs, _) 33 | ) 34 | ) 35 | } 36 | 37 | override def choose(state: State[Obs, A, R, S]): Cat[A] = 38 | Monad[Cat] 39 | .ifM(explore)( 40 | allActions(state), 41 | greedy(state) 42 | ) 43 | 44 | override def learn(sars: SARS[Obs, A, R, S]): This = 45 | copy( 46 | valueFn = valueFn.update( 47 | sars.state.observation, 48 | sars.action, 49 | config.prepare(sars.reward) 50 | ) 51 | ) 52 | } 53 | 54 | // TODO Oh boy, this really does look like it needs an aggregator... maybe 55 | // I build it without, but then include the algebird versions 56 | // elsewhere? Or maybe I build to the cats interfaces, then I have an 57 | // algebird package? More for later. 58 | object Greedy { 59 | case class Config[R, T: Semigroup: Ordering]( 60 | epsilon: Double, 61 | prepare: R => T, 62 | initial: T 63 | ) { 64 | def policy[A, Obs, S[_]]: Greedy[Obs, A, R, T, S] = 65 | Greedy(this, ActionValueFn.mergeable(initial)) 66 | } 67 | 68 | /** Returns an incremental config. 69 | * 70 | * TODO we also need a version that uses a constant step size, instead of sample averages. And 71 | * maybe a version that uses exponential decay? 72 | */ 73 | def incrementalConfig( 74 | epsilon: Double, 75 | initial: Double = 0.0 76 | ): Config[Double, AveragedValue] = Config( 77 | epsilon, 78 | AveragedValue(_), 79 | AveragedValue(initial) 80 | ) 81 | } 82 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/rainier/Categorical.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package rainier 3 | 4 | import cats.{Applicative, Monad, Monoid} 5 | import cats.arrow.FunctionK 6 | import com.stripe.rainier.compute.Real 7 | import com.stripe.rainier.core.{Categorical => RCat, Generator, ToGenerator} 8 | import com.scalarl.algebra.ToDouble 9 | 10 | import scala.annotation.tailrec 11 | import scala.collection.immutable.Queue 12 | 13 | /** Identical to rainier's `Categorical`, except written with `Double` instead of `Real`. 14 | * 15 | * @param pmfSeq 16 | * A map with keys corresponding to the possible outcomes and values corresponding to the 17 | * probabilities of those outcomes. 
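  *
  * For example, a fair coin as a two-outcome distribution:
  * {{{
  * val coin = Categorical(Map("heads" -> 0.5, "tails" -> 0.5))
  * coin.pmf("heads") // 0.5
  * }}}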
18 | */ 19 | final case class Categorical[+T](pmfSeq: List[(T, Double)]) { 20 | def pmf[U >: T]: Map[U, Double] = pmfSeq.toMap 21 | 22 | def map[U](fn: T => U): Categorical[U] = 23 | Categorical( 24 | pmfSeq 25 | .foldLeft(Map.empty[U, Double]) { case (acc, (t, p)) => 26 | Util.mergeV(acc, fn(t), p) 27 | } 28 | .toList 29 | ) 30 | 31 | def flatMap[U](fn: T => Categorical[U]): Categorical[U] = 32 | Categorical( 33 | (for { 34 | (t, p) <- pmfSeq 35 | (u, p2) <- fn(t).pmfSeq 36 | } yield (u, p * p2)) 37 | .foldLeft(Map.empty[U, Double]) { case (acc, (u, p)) => 38 | Util.mergeV(acc, u, p) 39 | } 40 | .toList 41 | ) 42 | 43 | def zip[U](other: Categorical[U]): Categorical[(T, U)] = 44 | Categorical( 45 | for { 46 | (t, p) <- pmfSeq 47 | (u, p2) <- other.pmfSeq 48 | } yield ((t, u), p * p2) 49 | ) 50 | 51 | def toRainier[U >: T]: RCat[U] = 52 | RCat.normalize(pmf[U].mapValues(Real(_))) 53 | } 54 | 55 | object Categorical extends CategoricalInstances { 56 | def apply[T](pmf: Map[T, Double]): Categorical[T] = Categorical(pmf.toList) 57 | 58 | object Poisson { 59 | case class Lambda(value: Double) extends AnyVal 60 | 61 | def gamma(z: Double): Double = 62 | if (z == 0.0) 63 | Double.PositiveInfinity 64 | else if (z == 1.0 || z == 2.0) 65 | 0.0 66 | else 67 | approxGamma(z) 68 | 69 | private def approxGamma(z: Double): Double = { 70 | val v = z + 1.0 71 | val w = v + (1.0 / ((12.0 * v) - (1.0 / (10.0 * v)))) 72 | (math.log(Math.PI * 2) / 2.0) - (math 73 | .log(v) / 2.0) + (v * (math.log(w) - 1.0)) - math.log(z) 74 | } 75 | 76 | def logProbability(k: Int, lambda: Double): Double = 77 | -lambda + (math.log(lambda) * k) - gamma(k + 1.0) 78 | 79 | def probability(k: Int, lambda: Double): Double = 80 | math.exp(logProbability(k, lambda)) 81 | } 82 | 83 | def poisson(upperBound: Int, mean: Poisson.Lambda): Categorical[Int] = 84 | normalize( 85 | Util.makeMapUnsafe(0 until upperBound)(Poisson.probability(_, mean.value)) 86 | ) 87 | 88 | def boolean(p: Double): Categorical[Boolean] = 89 | Categorical(Map(true -> p, false -> (1.0 - p))) 90 | 91 | def pure[A](a: A): Categorical[A] = Categorical(List((a, 1.0))) 92 | 93 | def normalize[T](pmf: Map[T, Double]): Categorical[T] = { 94 | val total = pmf.values.toList.sum 95 | Categorical(pmf.map { case (t, p) => (t, p / total) }) 96 | } 97 | 98 | def seq[T](ts: Seq[T]): Categorical[T] = 99 | normalize(ts.groupBy(identity).mapValues(_.size)) 100 | 101 | def fromSet[T](ts: Set[T]): Categorical[T] = { 102 | val p = 1.0 / ts.size 103 | Categorical( 104 | ts.foldLeft(Map.empty[T, Double])((m, t) => m.updated(t, p)) 105 | ) 106 | } 107 | 108 | def softmax[A, B](m: Map[A, Double]): Categorical[A] = 109 | normalize(m.mapValues(math.exp(_))) 110 | 111 | def softmax[A: ToDouble](as: Set[A]): Categorical[A] = { 112 | val (pmf, sum) = as.foldLeft((Map.empty[A, Double], 0.0)) { case ((m, r), a) => 113 | val aExp = math.exp(ToDouble[A].apply(a)) 114 | (m.updated(a, aExp), r + aExp) 115 | } 116 | normalize(pmf.mapValues(_ / sum)) 117 | } 118 | } 119 | 120 | trait CategoricalInstances { 121 | implicit val catMonad: Monad[Categorical] = CategoricalMonad 122 | implicit def catMonoid[A: Monoid]: Monoid[Categorical[A]] = 123 | Applicative.monoid[Categorical, A] 124 | implicit def gen[T]: ToGenerator[Categorical[T], T] = 125 | new ToGenerator[Categorical[T], T] { 126 | def apply(c: Categorical[T]) = c.toRainier.generator 127 | } 128 | 129 | val setToCategorical: FunctionK[Set, Categorical] = 130 | new FunctionK[Set, Categorical] { 131 | def apply[A](sa: Set[A]) = 
Categorical.fromSet(sa) 132 | } 133 | 134 | val toRainierCategorical: FunctionK[Categorical, RCat] = 135 | new FunctionK[Categorical, RCat] { 136 | def apply[A](ca: Categorical[A]) = ca.toRainier 137 | } 138 | 139 | val catToGenerator: FunctionK[Categorical, Generator] = 140 | new FunctionK[Categorical, Generator] { 141 | def apply[A](ca: Categorical[A]): Generator[A] = ca.toRainier.generator 142 | } 143 | } 144 | 145 | private[scalarl] object CategoricalMonad extends Monad[Categorical] { 146 | def pure[A](x: A): Categorical[A] = Categorical.pure(x) 147 | 148 | override def map[A, B](fa: Categorical[A])(f: A => B): Categorical[B] = 149 | fa.map(f) 150 | 151 | override def product[A, B]( 152 | fa: Categorical[A], 153 | fb: Categorical[B] 154 | ): Categorical[(A, B)] = fa.zip(fb) 155 | 156 | override def flatMap[A, B](fa: Categorical[A])( 157 | f: A => Categorical[B] 158 | ): Categorical[B] = 159 | fa.flatMap(f) 160 | 161 | def tailRecM[A, B]( 162 | a: A 163 | )(f: A => Categorical[Either[A, B]]): Categorical[B] = { 164 | @tailrec 165 | def run( 166 | acc: Map[B, Double], 167 | queue: Queue[(Either[A, B], Double)] 168 | ): Map[B, Double] = 169 | queue.headOption match { 170 | case None => acc 171 | case Some((Left(a), v)) => 172 | run( 173 | acc, 174 | queue.drop(1) ++ f(a).pmfSeq.map { case (eab, d) => (eab, d * v) } 175 | ) 176 | case Some((Right(b), v)) => 177 | run(Util.mergeV(acc, b, v), queue.drop(1)) 178 | } 179 | val pmf = run(Map.empty, f(a).pmfSeq.to[Queue]) 180 | Categorical[B](pmf) 181 | } 182 | } 183 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/state/MapState.scala: -------------------------------------------------------------------------------- 1 | /** A MapState is a particular kind of state. 2 | * 3 | * TODO - maybe get a way to do observations in here? 4 | */ 5 | package com.scalarl 6 | package state 7 | 8 | import cats.Functor 9 | import cats.syntax.functor._ 10 | import com.stripe.rainier.cats._ 11 | import com.stripe.rainier.core.Generator 12 | 13 | /** MapState that doesn't evolve. 14 | */ 15 | case class StaticMapState[A, R, S[_]: Functor]( 16 | rewards: Map[A, S[R]], 17 | penalty: S[R] 18 | ) extends State[Unit, A, R, S] { 19 | override val observation: Unit = () 20 | override val dynamics = rewards.mapValues(_.map((_, this))) 21 | override val invalidMove = penalty.map((_, this)) 22 | } 23 | 24 | /** MDP with a single state. 25 | */ 26 | case class MapState[Obs, A, R, S[_]: Functor]( 27 | observation: Obs, 28 | rewards: Map[A, S[R]], 29 | penalty: S[R], 30 | step: (Obs, A, R, S[R]) => (Obs, S[R]) 31 | ) extends State[Obs, A, R, S] { 32 | 33 | private def updateForA(a: A, r: R): State[Obs, A, R, S] = { 34 | val (newObservation, newGen) = step(observation, a, r, rewards(a)) 35 | MapState( 36 | newObservation, 37 | rewards.updated(a, newGen), 38 | penalty, 39 | step 40 | ) 41 | } 42 | 43 | override val invalidMove = penalty.map((_, this)) 44 | override def dynamics = rewards.map { case (a, g) => 45 | (a, g.map(r => (r, updateForA(a, r)))) 46 | } 47 | } 48 | 49 | object MapState { 50 | private def genMap[A, R]( 51 | actions: Set[A], 52 | gen: Generator[Generator[R]] 53 | ): Generator[Map[A, Generator[R]]] = 54 | gen.repeat(actions.size).map(actions.zip(_).toMap) 55 | 56 | /** One of the two ways to construct a MapState. 
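    *
    * A sketch of a two-armed bandit whose arms pay fixed rewards (the values are hypothetical and
    * rainier's `Generator.constant` is assumed):
    * {{{
    * val arms = MapState.static[String, Double](
    *   Set("left", "right"),
    *   Generator.constant(0.0),
    *   Generator.constant(Generator.constant(1.0))
    * )
    * }}}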
57 | */ 58 | def static[A, R]( 59 | actions: Set[A], 60 | penalty: Generator[R], 61 | gen: Generator[Generator[R]] 62 | ): Generator[StaticMapState[A, R, Generator]] = 63 | genMap(actions, gen).map(StaticMapState(_, penalty)) 64 | 65 | /** The second of two ways to construct a MapState. 66 | */ 67 | def updating[Obs, A, R]( 68 | actions: Set[A], 69 | initialObservation: Obs, 70 | penalty: Generator[R], 71 | gen: Generator[Generator[R]], 72 | step: (Obs, A, R, Generator[R]) => (Obs, Generator[R]) 73 | ): Generator[MapState[Obs, A, R, Generator]] = 74 | genMap(actions, gen).map(MapState(initialObservation, _, penalty, step)) 75 | } 76 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/state/TickState.scala: -------------------------------------------------------------------------------- 1 | /** A TickState is limited in the number of ticks it can perform. 2 | */ 3 | package com.scalarl 4 | package state 5 | 6 | import cats.Functor 7 | import cats.arrow.FunctionK 8 | import cats.syntax.functor._ 9 | 10 | /** State that ends after a certain number of interactions. This is useful for turning a 11 | * non-episodic into an episodic task. 12 | */ 13 | case class TickState[Obs, A, R, S[_]: Functor]( 14 | state: State[Obs, A, R, S], 15 | tick: Int, 16 | limit: Int 17 | ) extends State[Obs, A, R, S] { 18 | override def observation: Obs = state.observation 19 | 20 | override def dynamics: Map[A, S[(R, This)]] = 21 | if (isTerminal) Map.empty 22 | else { 23 | state.dynamics.mapValues { v => 24 | v.map { case (r, innerState) => 25 | (r, TickState[Obs, A, R, S](innerState, tick - 1, limit)) 26 | } 27 | } 28 | } 29 | 30 | override def invalidMove: S[(R, This)] = state.invalidMove 31 | 32 | override def actions: Set[A] = if (isTerminal) Set.empty else state.actions 33 | 34 | override def act(action: A): S[(R, This)] = 35 | dynamics.getOrElse(action, invalidMove) 36 | 37 | override def isTerminal: Boolean = (tick >= limit) || state.isTerminal 38 | 39 | override def mapObservation[P](f: Obs => P)(implicit 40 | S: Functor[S] 41 | ): State[P, A, R, S] = 42 | TickState(state.mapObservation(f)(S), tick, limit)(S) 43 | 44 | override def mapReward[T](f: R => T)(implicit 45 | S: Functor[S] 46 | ): State[Obs, A, T, S] = 47 | TickState(state.mapReward(f)(S), tick, limit)(S) 48 | 49 | override def mapK[N[_]](f: FunctionK[S, N])(implicit 50 | N: Functor[N] 51 | ): State[Obs, A, R, N] = 52 | TickState(state.mapK(f), tick, limit) 53 | } 54 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/util/FrequencyTracker.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package util 3 | 4 | import com.twitter.algebird.Monoid 5 | 6 | /** Aggregating thing that also keeps track of frequencies. The item will be paired with a zero if 7 | * this is the first time seeing it. 
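  *
  * For example, tracking items by identity:
  * {{{
  * val ft = FrequencyTracker.empty[String, String](identity) :+ "a" :+ "b" :+ "a"
  * ft.items // Vector(("a", 0), ("b", 0), ("a", 1))
  * }}}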
8 | */ 9 | case class FrequencyTracker[A, B]( 10 | items: Vector[(A, Int)], 11 | frequencies: Map[B, Int], 12 | f: A => B 13 | ) { 14 | def :+(a: A): FrequencyTracker[A, B] = { 15 | val b = f(a) 16 | val newFrequencies = Util.mergeV(frequencies, b, 1) 17 | FrequencyTracker(items :+ ((a, newFrequencies(b) - 1)), newFrequencies, f) 18 | } 19 | def iterator: Iterator[(A, Int)] = items.iterator 20 | def reverseIterator: Iterator[(A, Int)] = items.reverse.iterator 21 | } 22 | object FrequencyTracker { 23 | def empty[A, B](f: A => B): FrequencyTracker[A, B] = 24 | FrequencyTracker(Vector.empty, Map.empty[B, Int], f) 25 | 26 | def pure[A, B](a: A, f: A => B): FrequencyTracker[A, B] = empty(f) :+ a 27 | 28 | def monoid[A, B](f: A => B): Monoid[FrequencyTracker[A, B]] = 29 | new Monoid[FrequencyTracker[A, B]] { 30 | val zero: FrequencyTracker[A, B] = FrequencyTracker.empty(f) 31 | def plus( 32 | l: FrequencyTracker[A, B], 33 | r: FrequencyTracker[A, B] 34 | ): FrequencyTracker[A, B] = 35 | FrequencyTracker( 36 | l.items ++ r.items, 37 | Monoid.plus[Map[B, Int]](l.frequencies, r.frequencies), 38 | l.f 39 | ) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/value/ConstantStep.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package value 3 | 4 | import com.scalarl.algebra.Module 5 | import com.twitter.algebird.Group 6 | 7 | /** Exponential recency-weighted average. This is similar to a weighted average, but instead of 8 | * weighting by the count, it uses a constant weighting factor. 9 | * 10 | * TODO consider changing Numeric to ToDouble? 11 | * 12 | * TODO there is some interesting thing going on here, where you're summing up constantly weighted 13 | * factors... and then in the importance weighting you're doing the same thing, you just send in a 14 | * different weight every time. 15 | * 16 | * The time is a way of skipping big swathes and injecting zeros in, but it's really the same. 17 | * 18 | * You can maybe also think about this like you're just keeping track of some weighted numerator, 19 | * but, same thing, injecting more sum of the weights into the denominator, but nothing, zeros, 20 | * into the top. 21 | * 22 | * That would be a nice thing to unify together. 
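  *
  * The underlying recurrence (see `ConstantStepGroup.reward`) is the usual constant step-size
  * update:
  * {{{
  * Q(t+1) = Q(t) + alpha * (R - Q(t)) = (1 - alpha) * Q(t) + alpha * R
  * }}}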
23 | */ 24 | case class ConstantStep(value: Double, time: Time) extends Ordered[ConstantStep] { 25 | import ConstantStep.{Alpha, Epsilon} 26 | 27 | def compare(that: ConstantStep): Int = 28 | time.compareTo(that.time) match { 29 | case 0 => value.compare(that.value) 30 | case other => other 31 | } 32 | 33 | def decayTo(t2: Time, alpha: Alpha, eps: Epsilon): ConstantStep = 34 | if (t2 <= time) 35 | this 36 | else { 37 | val newV = value * math.pow(1 - alpha.toDouble, t2 - time) 38 | if (math.abs(newV) > eps.toDouble) 39 | ConstantStep(newV, t2) 40 | else 41 | ConstantStep.zero 42 | } 43 | } 44 | 45 | object ConstantStep { 46 | case class Alpha(toDouble: Double) extends AnyVal { 47 | def *(r: Double): Double = toDouble * r 48 | } 49 | case class Epsilon(toDouble: Double) extends AnyVal 50 | 51 | val zero: ConstantStep = ConstantStep(0.0, Time.Min) 52 | 53 | def buildAggregate[T](value: T)(implicit num: Numeric[T]): ConstantStep = 54 | buildAggregate(value, Time.Min) 55 | 56 | def buildAggregate[T](value: T, time: Time)(implicit 57 | num: Numeric[T] 58 | ): ConstantStep = 59 | ConstantStep(num.toDouble(value), time) 60 | 61 | /** Rewards can only be assigned to time one tick in the future. 62 | */ 63 | def buildReward[T](reward: T, alpha: Alpha, time: Time)(implicit 64 | num: Numeric[T] 65 | ): ConstantStep = 66 | ConstantStep(alpha * num.toDouble(reward), time.tick) 67 | 68 | def group(alpha: Alpha, eps: Epsilon): Group[ConstantStep] = 69 | new ConstantStepGroup(alpha, eps) 70 | 71 | implicit def module(implicit 72 | G: Group[ConstantStep] 73 | ): Module[Double, ConstantStep] = 74 | Module.from((d, cs) => ConstantStep(cs.value * d, cs.time)) 75 | } 76 | 77 | class ConstantStepGroup( 78 | alpha: ConstantStep.Alpha, 79 | eps: ConstantStep.Epsilon 80 | ) extends Group[ConstantStep] { 81 | override val zero: ConstantStep = ConstantStep.zero 82 | 83 | override def isNonZero(cs: ConstantStep) = cs.value != 0L 84 | 85 | override def negate(v: ConstantStep) = ConstantStep(-v.value, v.time) 86 | 87 | override def plus(l: ConstantStep, r: ConstantStep) = { 88 | val (a, b, t) = 89 | if (l.time < r.time) 90 | (l.decayTo(r.time, alpha, eps), r, r.time) 91 | else 92 | (l, r.decayTo(l.time, alpha, eps), l.time) 93 | 94 | ConstantStep(a.value + b.value, t) 95 | } 96 | 97 | /** Returns the value if the timestamp is less than the time of the supplied ConstantStep 98 | * instance. 99 | */ 100 | def valueAsOf(v: ConstantStep, time: Time): Double = 101 | v.decayTo(time, alpha, eps).value 102 | 103 | /** This assigns the reward at the current time, which forces the timestamp forward. 104 | * 105 | * If you didn't bump the time you could do this: 106 | * 107 | * modern + (reward * (alpha / (1.0 - alpha))) 108 | * 109 | * And force the alpha to be less than one. 
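 *
 * Concretely (step size and numbers chosen only for illustration):
 *
 * {{{
 * val g  = new ConstantStepGroup(ConstantStep.Alpha(0.1), ConstantStep.Epsilon(1e-10))
 * val cs = ConstantStep(1.0, Time.Min)
 *
 * g.reward(cs, 10.0, cs.time)
 * // == ConstantStep(1.0 * 0.9 + 0.1 * 10.0, cs.time.tick), i.e. a value of 1.9
 * }}}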
110 | */ 111 | def reward(v: ConstantStep, reward: Double, time: Time): ConstantStep = { 112 | val newTime = time.tick 113 | val updatedV = v.decayTo(newTime, alpha, eps).value 114 | ConstantStep( 115 | updatedV + alpha * reward, 116 | newTime 117 | ) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/value/DecayState.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package value 3 | 4 | import com.twitter.algebird.{Group, Ring, VectorSpace} 5 | import com.scalarl.algebra.{Expectation, Module, ToDouble} 6 | import com.scalarl.evaluate.StateValue 7 | 8 | /** This represents a value that's weighted as you move away from it. This is useful because we can 9 | * KEEP GOING, and continue to weight it. 10 | */ 11 | sealed trait DecayState[A] extends Product with Serializable { 12 | def toValue: DecayState.DecayedValue[A] 13 | def get: A 14 | } 15 | 16 | object DecayState { 17 | import Module.DModule 18 | 19 | case class Reward[A](get: A) extends DecayState[A] { 20 | override lazy val toValue: DecayedValue[A] = DecayedValue(get) 21 | } 22 | case class DecayedValue[A](get: A) extends DecayState[A] { 23 | override val toValue: DecayedValue[A] = this 24 | } 25 | 26 | /** Filling in. 27 | */ 28 | def bellmanFn[Obs, A, R: DModule, T, M[_]: Expectation, S[_]: Expectation]( 29 | gamma: Double 30 | ): ( 31 | StateValueFn[Obs, DecayState[R]], 32 | Policy[Obs, A, R, M, S] 33 | ) => StateValue[Obs, A, R, DecayState[R], S] = { 34 | val group = decayStateGroup[R](gamma) 35 | implicit val module = decayStateModule[R](gamma) 36 | (f, p) => 37 | Evaluator.bellman[Obs, A, R, DecayState[R], M, S]( 38 | f, 39 | p, 40 | Reward(_), 41 | group.plus(_, _) 42 | ) 43 | } 44 | 45 | def decayStateModule[A]( 46 | gamma: Double 47 | )(implicit M: Module[Double, A]): Module[Double, DecayState[A]] = { 48 | implicit val group: Group[DecayState[A]] = decayStateGroup(gamma) 49 | Module.from((r, d) => 50 | d match { 51 | case Reward(reward) => Reward(M.scale(r, reward)) 52 | case DecayedValue(v) => DecayedValue(M.scale(r, v)) 53 | } 54 | ) 55 | } 56 | 57 | // This is just sort of silly and probably can go. 
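  // Illustrative sketch, not part of the original file: folding rewards through the group
  // defined below produces a discounted return, where each earlier reward picks up an extra
  // factor of gamma relative to the rewards that follow it. The helper name and the implicit
  // Module[Double, Double] requirement are assumptions made for this example only.
  def discountedReturnSketch(rewards: List[Double], gamma: Double)(implicit
      M: Module[Double, Double]
  ): DecayState[Double] = {
    val G = decayStateGroup[Double](gamma)
    rewards.foldLeft(G.zero: DecayState[Double])((acc, r) => G.plus(acc, Reward(r)))
  }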
58 | def decayStateVectorSpace[A: Ring]( 59 | gamma: Double 60 | )(implicit M: Module[Double, A]): VectorSpace[A, DecayState] = { 61 | implicit val group: Group[DecayState[A]] = decayStateGroup(gamma) 62 | VectorSpace.from((r, d) => 63 | d match { 64 | case Reward(reward) => Reward(Ring.times(r, reward)) 65 | case DecayedValue(v) => DecayedValue(Ring.times(r, v)) 66 | } 67 | ) 68 | } 69 | 70 | def decayStateGroup[A]( 71 | gamma: Double 72 | )(implicit M: Module[Double, A]): Group[DecayState[A]] = 73 | new Group[DecayState[A]] { 74 | private val GA = M.group 75 | override val zero = DecayedValue(GA.zero) 76 | override def negate(d: DecayState[A]) = d match { 77 | case Reward(a) => Reward(GA.negate(a)) 78 | case DecayedValue(a) => DecayedValue(GA.negate(a)) 79 | } 80 | override def plus(l: DecayState[A], r: DecayState[A]) = (l, r) match { 81 | case (Reward(a), Reward(b)) => Reward(GA.plus(a, b)) 82 | case (DecayedValue(a), Reward(b)) => 83 | DecayedValue(GA.plus(M.scale(gamma, a), b)) 84 | case (Reward(a), DecayedValue(b)) => 85 | DecayedValue(GA.plus(M.scale(gamma, b), a)) 86 | case (DecayedValue(a), DecayedValue(b)) => DecayedValue(GA.plus(a, b)) 87 | } 88 | } 89 | 90 | implicit def toDouble[A](implicit A: ToDouble[A]): ToDouble[DecayState[A]] = 91 | A.contramap(_.get) 92 | 93 | implicit def dsOrd[A: Ordering]: Ordering[DecayState[A]] = 94 | Ordering.by(_.get) 95 | } 96 | -------------------------------------------------------------------------------- /scala-rl-core/src/main/scala/com/scalarl/value/WeightedAverage.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package value 3 | 4 | import _root_.algebra.CommutativeGroup 5 | import com.twitter.algebird.Group 6 | import com.scalarl.algebra.Weight 7 | 8 | /** This is of course extremely similar to the averaged value implementation in Algebird... it just 9 | * keeps track of a numerator AND denominator 10 | */ 11 | case class WeightedAverage(weightSum: Weight, value: Double) { 12 | 13 | /** Returns a copy of this instance with a negative value. Note that 14 | * 15 | * {{{ 16 | * a + -b == a - b 17 | * }}} 18 | */ 19 | def unary_- : WeightedAverage = copy(value = -value) 20 | 21 | /** Averages this instance with the *opposite* of the supplied [[WeightedAverage]] instance, 22 | * effectively subtracting out that instance's contribution to the mean. 23 | * 24 | * @param r 25 | * the instance to subtract 26 | * @return 27 | * an instance with `r`'s stream subtracted out 28 | */ 29 | def -(r: WeightedAverage): WeightedAverage = 30 | WeightedAverageGroup.minus(this, r) 31 | 32 | /** Averages this instance with another [[WeightedAverage]] instance. 33 | * @param r 34 | * the other instance 35 | * @return 36 | * an instance representing the mean of this instance and `r`. 37 | */ 38 | def +(r: WeightedAverage): WeightedAverage = 39 | WeightedAverageGroup.plus(this, r) 40 | 41 | /** Returns a new instance that averages `that` into this instance. 42 | * 43 | * @param that 44 | * value to average into this instance 45 | * @return 46 | * an instance representing the mean of this instance and `that`. 
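 *
 * For example (the weights and values are illustrative only):
 *
 * {{{
 * val avg = WeightedAverage(Weight.Zero, 0.0) + 1.0 + 3.0
 * // avg.value == 2.0, with a total accumulated weight of 2.0
 * }}}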
47 | */ 48 | def +(that: Double): WeightedAverage = plus(that, Weight.One) 49 | 50 | def plus(that: Double, weight: Weight): WeightedAverage = 51 | WeightedAverage( 52 | weightSum + weight, 53 | WeightedAverageGroup.getCombinedMean(weightSum.w, value, weight.w, that) 54 | ) 55 | } 56 | 57 | object WeightedAverage { 58 | implicit val group: Group[WeightedAverage] = WeightedAverageGroup 59 | } 60 | 61 | /** [[Group]] implementation for [[WeightedAverage]]. 62 | * 63 | * @define object 64 | * `WeightedAverage` 65 | */ 66 | object WeightedAverageGroup extends Group[WeightedAverage] with CommutativeGroup[WeightedAverage] { 67 | 68 | /** When combining averages, if the counts sizes are too close we should use a different 69 | * algorithm. This constant defines how close the ratio of the smaller to the total count can be: 70 | */ 71 | private val STABILITY_CONSTANT = 0.1 72 | 73 | /** Given two streams of doubles (n, an) and (k, ak) of form (count, mean), calculates the mean of 74 | * the combined stream. 75 | * 76 | * Uses a more stable online algorithm which should be suitable for large numbers of records 77 | * similar to: 78 | * http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 79 | */ 80 | private[scalarl] def getCombinedMean( 81 | n: Double, 82 | an: Double, 83 | k: Double, 84 | ak: Double 85 | ): Double = 86 | if (n < k) getCombinedMean(k, ak, n, an) 87 | else 88 | (n + k) match { 89 | case 0.0 => 0.0 90 | case newCount if newCount == n => an 91 | case newCount => 92 | val scaling = k / newCount 93 | // a_n + (a_k - a_n)*(k/(n+k)) is only stable if n is not approximately k 94 | if (scaling < STABILITY_CONSTANT) an + (ak - an) * scaling 95 | else (n * an + k * ak) / newCount 96 | } 97 | 98 | override val zero: WeightedAverage = WeightedAverage(Weight.Zero, 0.0) 99 | 100 | override def isNonZero(av: WeightedAverage): Boolean = av.value != 0L 101 | 102 | override def negate(av: WeightedAverage): WeightedAverage = -av 103 | 104 | /** Optimized implementation of [[plus]]. Uses internal mutation to combine the supplied 105 | * [[WeightedAverage]] instances without creating intermediate objects. 106 | */ 107 | override def sumOption( 108 | iter: TraversableOnce[WeightedAverage] 109 | ): Option[WeightedAverage] = 110 | if (iter.isEmpty) None 111 | else { 112 | var weightSum = 0.0 113 | var average = 0.0 114 | iter.foreach { case WeightedAverage(Weight(w), v) => 115 | average = getCombinedMean(weightSum, average, w, v) 116 | weightSum += w 117 | } 118 | Some(WeightedAverage(Weight(weightSum), average)) 119 | } 120 | 121 | /** @see 122 | * [[WeightedAverage!.+(r:*]] for the implementation. 
123 | */ 124 | override def plus(l: WeightedAverage, r: WeightedAverage): WeightedAverage = { 125 | val n = l.weightSum 126 | val k = r.weightSum 127 | val newAve = getCombinedMean(n.w, l.value, k.w, r.value) 128 | WeightedAverage(n + k, newAve) 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /scala-rl-core/src/test/scala/com/scalarl/value/ConstantStepLaws.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | package value 3 | 4 | import com.twitter.algebird._ 5 | import org.scalacheck.{Arbitrary, Gen} 6 | import org.scalatest.propspec.AnyPropSpec 7 | import org.scalatestplus.scalacheck.Checkers 8 | import org.scalacheck.Prop.forAll 9 | 10 | class ConstantStepLaws extends AnyPropSpec with Checkers with ConstantStepArb { 11 | import BaseProperties._ 12 | import ConstantStep.{zero, Alpha} 13 | import ConstantStepLaws.{alpha, fill, EPS} 14 | 15 | implicit val stepGroup: ConstantStepGroup = 16 | new ConstantStepGroup(alpha, EPS) 17 | 18 | implicit val equiv: Equiv[ConstantStep] = 19 | Equiv.fromFunction { (l, r) => 20 | ((l.value == 0L) && (l.value == 0L)) || 21 | approxEq(EPS.toDouble)(l.value, r.value) && (l.time == r.time) 22 | } 23 | 24 | property("ConstantStep forms a commutative group")(check { 25 | groupLaws[ConstantStep] && isCommutative[ConstantStep] 26 | }) 27 | 28 | property("ConstantStep's monoid works like the single-step version")(check { 29 | forAll { (rewards: List[Int]) => 30 | val (csAccumulator, t) = fill(stepGroup, zero, rewards) 31 | val simpleAcc = rewards.foldLeft(0.0) { case (acc, reward) => 32 | acc + alpha * (reward - acc) 33 | } 34 | approxEq(EPS.toDouble)(simpleAcc, csAccumulator.value) 35 | } 36 | }) 37 | 38 | property( 39 | "Adding two instances together acts like a single instance with double rewards" 40 | )(check { 41 | forAll { (rewards: List[Float]) => 42 | val (acc, ts1) = fill(stepGroup, zero, rewards) 43 | val (doubleAcc, ts2) = fill(stepGroup, zero, rewards.map(_.toDouble * 2)) 44 | 45 | approxEq(EPS.toDouble)(stepGroup.plus(acc, acc).value, doubleAcc.value) 46 | } 47 | }) 48 | 49 | property("With an alpha of one, all weight's placed on the latest reward.")( 50 | check { 51 | val oneMonoid = new ConstantStepGroup(Alpha(1.0), EPS) 52 | 53 | forAll { (rewards: List[Int]) => 54 | val instances = 55 | rewards.scanLeft((zero, zero.time)) { case ((acc, ts), r) => 56 | (oneMonoid.reward(acc, r, ts), ts.tick) 57 | } 58 | 59 | instances.tail.zip(rewards).forall { case ((acc, _), reward) => 60 | approxEq(EPS.toDouble)(acc.value, reward) 61 | } 62 | } 63 | } 64 | ) 65 | 66 | property( 67 | "adding a reward works the same as adding an instance one tick later." 
68 | )(check { 69 | forAll { (cs: ConstantStep, reward: Double) => 70 | approxEq(EPS.toDouble)( 71 | stepGroup.reward(cs, reward, cs.time).value, 72 | stepGroup 73 | .plus(cs, ConstantStep.buildAggregate(alpha * reward, cs.time.tick)) 74 | .value 75 | ) 76 | } 77 | }) 78 | 79 | property("A reward is an aggregate * alpha, one step in the future.")(check { 80 | forAll { (reward: Double, time: Time) => 81 | approxEq(EPS.toDouble)( 82 | ConstantStep.buildReward(reward, alpha, time).value, 83 | ConstantStep.buildAggregate(alpha * reward, time.tick).value 84 | ) 85 | } 86 | }) 87 | } 88 | 89 | object ConstantStepLaws { 90 | import ConstantStep.{Alpha, Epsilon} 91 | 92 | val EPS: Epsilon = Epsilon(1e-10) 93 | val alpha: Alpha = Alpha(0.1) 94 | val stepGroup: ConstantStepGroup = 95 | new ConstantStepGroup(alpha, EPS) 96 | 97 | def fill[T: Numeric]( 98 | monoid: ConstantStepGroup, 99 | init: ConstantStep, 100 | rewards: List[T] 101 | ): (ConstantStep, Time) = 102 | rewards 103 | .foldLeft((init, init.time)) { case ((acc, ts), r) => 104 | (monoid.reward(acc, implicitly[Numeric[T]].toDouble(r), ts), ts.tick) 105 | } 106 | } 107 | 108 | class ConstantStepTest extends org.scalatest.funsuite.AnyFunSuite { 109 | import BaseProperties.approxEq 110 | import ConstantStep.zero 111 | import ConstantStepLaws.{alpha, stepGroup, EPS} 112 | 113 | test("Two steps of the normal increment works as expected") { 114 | val r1 = 10 115 | val r2 = 12 116 | 117 | val stepOne = stepGroup.reward(zero, r1, zero.time) 118 | val stepTwo = stepGroup.reward(stepOne, r2, stepOne.time) 119 | 120 | assert(approxEq(EPS.toDouble)(stepOne.value, alpha * r1)) 121 | assert( 122 | approxEq(EPS.toDouble)( 123 | stepTwo.value, 124 | (alpha * r1) + alpha * (r2 - (alpha * r1)) 125 | ) 126 | ) 127 | } 128 | 129 | test("monoid works like the single-step version") { 130 | val rewards = List[Double](10, 50, 40, 32, 1.0) 131 | val (csAccumulator, t) = ConstantStepLaws.fill(stepGroup, zero, rewards) 132 | val simpleAcc = rewards.foldLeft(0.0) { case (acc, reward) => 133 | acc + alpha * (reward - acc) 134 | } 135 | 136 | assert(approxEq(EPS.toDouble)(csAccumulator.value, simpleAcc)) 137 | } 138 | 139 | test( 140 | "Adding an instance with (alpha * reward) one tick in the future equals a reward now." 141 | ) { 142 | val r: Double = 10.0 143 | 144 | val rewarded = stepGroup.reward(zero, r, zero.time) 145 | val stepped = stepGroup.plus(zero, ConstantStep(alpha * r, zero.time.tick)) 146 | 147 | assert(approxEq(EPS.toDouble)(rewarded.value, stepped.value)) 148 | } 149 | } 150 | 151 | /** Generators and Arbitrary instances live below. 152 | */ 153 | trait ConstantStepGen { 154 | def genTime: Gen[Time] = 155 | Gen 156 | .choose(Int.MinValue.toLong, Int.MaxValue.toLong) 157 | .map(Time(_)) 158 | 159 | def genStep: Gen[ConstantStep] = 160 | for { 161 | value <- Gen.choose(-1e50, 1e50) 162 | time <- genTime 163 | } yield ConstantStep(value, time) 164 | } 165 | 166 | object ConstantStepGenerators extends ConstantStepGen 167 | 168 | trait ConstantStepArb { 169 | import ConstantStepGenerators._ 170 | 171 | implicit val arbStep: Arbitrary[ConstantStep] = Arbitrary(genStep) 172 | implicit val arbTime: Arbitrary[Time] = Arbitrary(genTime) 173 | } 174 | -------------------------------------------------------------------------------- /scala-rl-plot/src/main/scala/com/scalarl/plot/Plot.scala: -------------------------------------------------------------------------------- 1 | /** The good stuff. Plotting charts. Options were Plotly and Breeze-Viz... 
but then, those are both 2 | * a little busted. So I decided to go with Evilplot. 3 | * 4 | * https://cibotech.github.io/evilplot/plot-catalog.html 5 | * 6 | * Here's a great example of the kinds of things we can do with this plotting library: 7 | * 8 | * https://www.cibotechnologies.com/about/blog/scalastan-and-evilplot-bayesian-statistics-meets-combinator-based-visualization/ 9 | */ 10 | package com.scalarl 11 | package plot 12 | 13 | import com.cibo.evilplot.colors.{HTMLNamedColors, RGB} 14 | import com.cibo.evilplot.displayPlot 15 | import com.cibo.evilplot.numeric.{Bounds, Point} 16 | import com.cibo.evilplot.plot.{FunctionPlot, Heatmap, LinePlot, Overlay} 17 | import com.cibo.evilplot.plot.aesthetics.DefaultTheme 18 | import com.cibo.evilplot.numeric.Point 19 | 20 | object Plot { 21 | import DefaultTheme._ 22 | 23 | // Example of a linechart, just testing it out. 24 | def lineChartSeq(pointSeq: (Seq[Double], String)*): Unit = 25 | lineChart( 26 | pointSeq.map { case (points, title) => 27 | (points.toList.zipWithIndex.map { case (a, i) => Point(i, a) }, title) 28 | } 29 | ) 30 | 31 | def lineChart(data: Seq[(Seq[Point], String)]): Unit = 32 | displayPlot { 33 | Overlay( 34 | data.map { case (points, title) => 35 | LinePlot.series(points, title, RGB.random) 36 | }: _* 37 | ).xAxis() 38 | .yAxis() 39 | .frame() 40 | .xLabel("x") 41 | .yLabel("y") 42 | .title("Yo!") 43 | .overlayLegend() 44 | .render() 45 | } 46 | 47 | // test of a polynomail plot, again, just an example to work with. 48 | def polyPlot(): Unit = { 49 | val x = Overlay( 50 | FunctionPlot.series( 51 | x => x * x, 52 | "y = x^2", 53 | HTMLNamedColors.dodgerBlue, 54 | xbounds = Some(Bounds(-1, 1)) 55 | ), 56 | FunctionPlot 57 | .series( 58 | x => math.pow(x, 3), 59 | "y = x^3", 60 | HTMLNamedColors.crimson, 61 | xbounds = Some(Bounds(-1, 1)) 62 | ), 63 | FunctionPlot 64 | .series( 65 | x => math.pow(x, 4), 66 | "y = x^4", 67 | HTMLNamedColors.green, 68 | xbounds = Some(Bounds(-1, 1)) 69 | ) 70 | ).title("A bunch of polynomials.") 71 | .overlayLegend() 72 | .standard() 73 | .render() 74 | displayPlot(x) 75 | } 76 | 77 | def gridPlot(): Unit = { 78 | import com.cibo.evilplot.demo.DemoPlots 79 | displayPlot(DemoPlots.axesTesting) 80 | // displayPlot( 81 | 82 | // ScatterPlot(data) 83 | // .frame() 84 | // .xLabel("x") 85 | // .yLabel("y") 86 | // .xGrid(lineCount = Some(8)) 87 | // // lineRenderer = Some(GridLineRenderer.custom { (extent, label) => 88 | // // Line(extent.height, theme.elements.gridLineSize) 89 | // // .colored(HTMLNamedColors.black) 90 | // // .rotated(90) 91 | // // })) 92 | // .yGrid( 93 | // lineCount = Some(8), 94 | // lineRenderer = Some(GridLineRenderer.custom { (extent, label) => 95 | // Line(extent.width, theme.elements.gridLineSize) 96 | // .colored(HTMLNamedColors.black) 97 | // }) 98 | // ) 99 | // .render() 100 | // ) 101 | } 102 | 103 | def heatMap(data: Seq[Seq[Double]], colorCount: Int): Unit = 104 | displayPlot( 105 | Heatmap(data, colorCount) 106 | .standard() 107 | .rightLegend() 108 | .render() 109 | ) 110 | 111 | def main(items: Array[String]): Unit = 112 | // lineChart(Seq(Seq.tabulate(100) { i => 113 | // Point(i.toDouble, scala.util.Random.nextDouble()) 114 | // } -> "Title.")) 115 | 116 | // polyPlot() 117 | gridPlot() 118 | } 119 | -------------------------------------------------------------------------------- /scala-rl-plot/src/main/scala/com/scalarl/plot/Tabulator.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl 2 | 
package plot 3 | 4 | /** Shamelessly copied from: http://stackoverflow.com/a/7542476 5 | * 6 | * I modified this slightly to take Iterable instances instead of Seq instances; this still needs 7 | * to be updated to handle more formatting options. 8 | * 9 | * I'm using this to print tables for Gridworld. 10 | */ 11 | object Tabulator { 12 | 13 | def csv(table: Iterable[Iterable[Any]]): String = 14 | table.map(_.mkString(",")).mkString("\n") 15 | 16 | def format(table: Iterable[Iterable[Any]]): String = 17 | if (table.isEmpty) "" 18 | else { 19 | val sizes = 20 | for (row <- table) 21 | yield for (cell <- row) 22 | yield 23 | if (cell == null) 0 24 | else cell.toString.length 25 | val colSizes = for (col <- sizes.transpose) yield col.max 26 | val rows = for (row <- table) yield formatRow(row, colSizes) 27 | formatRows(rowSeparator(colSizes), rows) 28 | } 29 | 30 | def formatRows(rowSeparator: String, rows: Iterable[String]): String = 31 | (rowSeparator :: rows.toList ::: rowSeparator :: List( 32 | )).mkString("\n") 33 | 34 | def formatRow(row: Iterable[Any], colSizes: Iterable[Int]): String = { 35 | val cells = 36 | for ((item, size) <- row.zip(colSizes)) 37 | yield 38 | if (size == 0) "" 39 | else ("%" + size.toString + "s").format(item) 40 | cells.mkString("|", "|", "|") 41 | } 42 | 43 | def rowSeparator(colSizes: Iterable[Int]): String = 44 | colSizes 45 | .map { 46 | "-" * _ 47 | } 48 | .mkString("+", "+", "+") 49 | } 50 | -------------------------------------------------------------------------------- /scala-rl-world/src/main/scala/com/scalarl/world/Bandit.scala: -------------------------------------------------------------------------------- 1 | /** A bandit is a particular kind of state. 2 | */ 3 | package com.scalarl 4 | package world 5 | 6 | import com.stripe.rainier.core.Generator 7 | import com.scalarl.state.MapState 8 | 9 | object Bandit { 10 | object Arm { 11 | implicit val ordering: Ordering[Arm] = Ordering.by(_.i) 12 | } 13 | 14 | case class Arm(i: Int) 15 | 16 | /** An "Arm" is something that takes you to a new state. We just happen to have only a single 17 | * state here, so it always takes you back to a given "bandit" problem. 18 | */ 19 | def arms(k: Int): Set[Arm] = (0 until k).map(Arm(_)).toSet 20 | 21 | /** Returns a Generator that splits out states for each of the games to play. 22 | */ 23 | def stationary( 24 | nArms: Int, 25 | gen: Generator[Generator[Double]] 26 | ): Generator[State[Unit, Arm, Double, Generator]] = { 27 | val penalty = Generator.constant(0.0) 28 | MapState.static(arms(nArms), penalty, gen) 29 | } 30 | 31 | /** Returns a Generator that splits out states for each of the games to play. This generator 32 | * evolves in a non-stationary way. 33 | * 34 | * The set below is totally fucked... it's returning a SINGLE generator each time, not the good 35 | * stuff that we need. 36 | */ 37 | def nonStationary( 38 | nArms: Int, 39 | gen: Generator[Generator[Double]], 40 | updater: (Arm, Double, Generator[Double]) => Generator[Double] 41 | ): Generator[State[Unit, Arm, Double, Generator]] = { 42 | val penalty = Generator.constant(0.0) 43 | MapState.updating[Unit, Arm, Double]( 44 | arms(nArms), 45 | (), 46 | penalty, 47 | gen, 48 | (obs, a, r, gen) => ((), updater(a, r, gen)) 49 | ) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /scala-rl-world/src/main/scala/com/scalarl/world/Blackjack.scala: -------------------------------------------------------------------------------- 1 | /** Monadic Blackjack! 
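 *
 * A sketch of wiring up a game (this assumes cats instances for rainier's Generator are in
 * scope, e.g. via the rainier-cats module; any other monadic card source would work too):
 *
 * {{{
 * import com.stripe.rainier.core.Generator
 * import com.scalarl.world.util.CardDeck
 *
 * val config = Blackjack.Config[Generator](CardDeck.basic)
 * val gameM: Generator[Blackjack[Generator]] = config.stateM
 * }}}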
2 | */ 3 | package com.scalarl 4 | package world 5 | 6 | import cats.{Id, Monad} 7 | import cats.implicits._ 8 | import com.scalarl.world.util.CardDeck 9 | 10 | object Blackjack { 11 | import CardDeck.{Card, Rank} 12 | 13 | // TODO - to make this more legit we need to add the ability to bet. 14 | sealed trait Action extends Product with Serializable 15 | object Action { 16 | final case object Hit extends Action 17 | final case object Stay extends Action 18 | } 19 | 20 | sealed trait Result extends Product with Serializable 21 | object Result { 22 | final case object Win extends Result 23 | final case object Draw extends Result 24 | final case object Lose extends Result 25 | final case object Pending extends Result 26 | } 27 | 28 | def cardValue(card: Card): Int = card.rank match { 29 | case Rank.Ace => 11 30 | case Rank.Jack | Rank.Queen | Rank.King => 10 31 | case Rank.Number(n) => n 32 | } 33 | 34 | /** TODO - to make this solid, the Hand should actually maintain a sorted state, so that we can 35 | * use it as the key in a hashmap. It's fine for now, since this state is actually going to get 36 | * dropped down into a state viewable by a policy. 37 | */ 38 | case class Hand(showing: Seq[Card], hidden: Seq[Card]) { 39 | def takeCard(card: Card, isShowing: Boolean): Hand = 40 | if (isShowing) 41 | copy(showing = showing :+ card) 42 | else 43 | copy(hidden = hidden :+ card) 44 | 45 | val usableAce: Boolean = Hand.aceCount(cards) > 0 46 | val totalScore: Int = Hand.score(cards) 47 | val showingScore: Int = Hand.score(showing) 48 | 49 | def cards: Seq[Card] = showing ++ hidden 50 | def busted: Boolean = totalScore > 21 51 | def showAll: Hand = if (hidden.isEmpty) this else Hand(cards, Seq.empty) 52 | } 53 | 54 | object Hand { 55 | val empty = Hand(Seq.empty, Seq.empty) 56 | 57 | def aceCount(cards: Seq[Card]): Int = cards.filter(_.rank == Rank.Ace).size 58 | def maxPoints(cards: Seq[Card]): Int = 59 | cards.foldLeft(0)((acc, c) => acc + cardValue(c)) 60 | 61 | def score(cards: Seq[Card]): Int = { 62 | def loop(points: Int, acesLeft: Int): Int = 63 | if (acesLeft <= 0) 64 | points 65 | else if (points > 21) 66 | loop(points - 10, acesLeft - 1) 67 | else points 68 | 69 | loop(maxPoints(cards), aceCount(cards)) 70 | } 71 | } 72 | 73 | /** This is the actual, full rich game. 74 | */ 75 | case class Game(player: Hand, dealer: Hand) { 76 | def agentView: AgentView = AgentView( 77 | player.usableAce, 78 | player.totalScore, 79 | dealer.showingScore 80 | ) 81 | def showAll: Game = Game(player.showAll, dealer.showAll) 82 | } 83 | object Game { 84 | val empty = Game(Hand.empty, Hand.empty) 85 | } 86 | 87 | def dealerHand[M[_]: Monad](getCard: M[Card]): M[Hand] = 88 | for { 89 | showing <- getCard 90 | hidden <- getCard 91 | } yield Hand(Seq(showing), Seq(hidden)) 92 | 93 | def playerHand[M[_]: Monad](getCard: M[Card]): M[Hand] = 94 | Monad[M].tailRecM[Hand, Hand](Hand.empty) { hand => 95 | if (hand.totalScore < 12) 96 | getCard.map(card => Left(hand.takeCard(card, true))) 97 | else 98 | Monad[M].pure(Right(hand)) 99 | } 100 | 101 | def gameGenerator[M[_]: Monad](getCard: M[Card]): M[Game] = 102 | for { 103 | player <- playerHand(getCard) 104 | dealer <- dealerHand(getCard) 105 | } yield Game(player, dealer) 106 | 107 | case class Config[M[_]: Monad](getCard: M[Card]) { 108 | def build(startingState: Game): Blackjack[M] = 109 | Alive(this, startingState) 110 | 111 | def stateM: M[Blackjack[M]] = gameGenerator(getCard).map(build(_)) 112 | } 113 | 114 | /** This is what the agent is allowed to see. 
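 *
 * For instance, a fixed threshold policy in the spirit of the book's "stick at 20 or 21"
 * rule needs nothing beyond this view (the cutoff is only an illustration):
 *
 * {{{
 * val stickAt20 = Blackjack.policy[Cat] { view =>
 *   if (view.playerSum >= 20) Blackjack.Action.Stay else Blackjack.Action.Hit
 * }
 * }}}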
115 | */ 116 | case class AgentView( 117 | usableAce: Boolean, 118 | playerSum: Int, 119 | dealerSum: Int 120 | ) 121 | 122 | /** Generate a simple fixed policy for an agent. 123 | */ 124 | def policy[S[_]]( 125 | f: AgentView => Action 126 | ): Policy[AgentView, Action, Double, Id, S] = 127 | Policy.choose[AgentView, Action, Double, Id, S](s => f(s.observation)) 128 | 129 | // TODO get the game below to use this as the "opponent" instead of manually 130 | // doing it. 131 | def dealerPolicy[S[_]](hitBelow: Int): Policy[Game, Action, Double, Id, S] = 132 | Policy.choose { state => 133 | val hand = state.observation.dealer 134 | if (hand.totalScore < hitBelow) Action.Hit else Action.Stay 135 | } 136 | } 137 | 138 | sealed trait Blackjack[M[_]] extends State[Blackjack.Game, Blackjack.Action, Blackjack.Result, M] { 139 | def game: Blackjack.Game 140 | } 141 | 142 | case class Dead[M[_]: Monad](game: Blackjack.Game) extends Blackjack[M] { 143 | override val invalidMove = Monad[M].pure((Blackjack.Result.Lose, this)) 144 | override val observation = game.showAll 145 | override val dynamics = Map.empty 146 | } 147 | 148 | /** So this is PROBABLY a place where I actually need the full state, so I can track that the dealer 149 | * has two cards, generated randomly. 150 | */ 151 | case class Alive[M[_]: Monad](config: Blackjack.Config[M], game: Blackjack.Game) 152 | extends Blackjack[M] { 153 | import Blackjack.{Action, Game, Hand, Result} 154 | import CardDeck.Card 155 | 156 | override val observation: Game = game 157 | 158 | // I think we're not going to be able to ever need to call this from the 159 | // current set of techniques... so maybe we move this to some place where we 160 | // have an expected value? 161 | override def dynamics: Map[Action, M[(Result, This)]] = 162 | Map( 163 | Action.Hit -> config.getCard.map(hit(_)), 164 | Action.Stay -> dealerTurn(config.getCard) 165 | ) 166 | override val invalidMove = Monad[M].pure((Blackjack.Result.Pending, this)) 167 | 168 | private def hit(card: Card): (Result, This) = { 169 | val newGame = Game(game.player.takeCard(card, true), game.dealer) 170 | if (newGame.player.busted) 171 | (Result.Lose, Dead(newGame)) 172 | else 173 | (Result.Pending, copy(game = newGame)) 174 | } 175 | 176 | private def endingResult(player: Hand, dealer: Hand): Result = 177 | if (dealer.busted || player.totalScore > dealer.totalScore) 178 | Result.Win 179 | else if (game.player.totalScore == dealer.totalScore) 180 | Result.Draw 181 | else 182 | Result.Lose 183 | 184 | // This is really a policy interaction... we should allow these to ping back 185 | // and forth. Can I just PLAY a policy? 186 | // TODO convert this to use the dealer policy that's in the config. 187 | private def dealerTurn(getCard: M[Card]) = 188 | Monad[M].tailRecM[Hand, (Result, This)](game.dealer) { hand => 189 | if (hand.totalScore < 17) 190 | getCard.map(card => Left(hand.takeCard(card, true))) 191 | else { 192 | val result = endingResult(game.player, hand) 193 | val dead = Dead[M](game.copy(dealer = hand)) 194 | Monad[M].pure(Right((result, dead))) 195 | } 196 | } 197 | } 198 | -------------------------------------------------------------------------------- /scala-rl-world/src/main/scala/com/scalarl/world/CarRental.scala: -------------------------------------------------------------------------------- 1 | /** Car rental game based on what we have in Chapter 4. This generates Figure 4.2 and helps with the 2 | * homework assignments there. 
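 *
 * As a sketch, the book's Figure 4.2 setting looks roughly like the configuration below. The
 * Poisson means, car limits, rental credit, and move cost come from Sutton & Barto; the upper
 * bound of 21 on the truncated Poisson is an assumption for this example, not a value taken
 * from this repository.
 *
 * {{{
 * import CarRental._
 * import com.scalarl.rainier.Categorical.Poisson.Lambda
 *
 * val config = Config(
 *   aConfig = Location(PoissonConfig(21, Lambda(3.0)), PoissonConfig(21, Lambda(3.0)), maxCars = 20),
 *   bConfig = Location(PoissonConfig(21, Lambda(4.0)), PoissonConfig(21, Lambda(2.0)), maxCars = 20),
 *   maxMoves = Move(5),
 *   rentalCredit = 10.0,
 *   moveCost = 2.0
 * )
 * val sweep = config.stateSweep
 * }}}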
3 | */ 4 | package com.scalarl 5 | package world 6 | 7 | import com.scalarl.rainier.Categorical 8 | 9 | object CarRental { 10 | import Categorical.Poisson 11 | import Poisson.Lambda 12 | 13 | case class Inventory(n: Int, maxN: Int) { 14 | def -(m: Move): Inventory = this + -m 15 | def +(m: Move): Inventory = Inventory(Util.clamp(n + m.n, 0, maxN), maxN) 16 | def update(rentals: Move, returns: Move): Inventory = 17 | Inventory( 18 | math.min(n - rentals.n + returns.n, maxN), 19 | maxN 20 | ) 21 | } 22 | case class Move(n: Int) extends AnyVal { 23 | def unary_- = Move(-n) 24 | } 25 | object Move { 26 | def inclusiveRange(fromMove: Move, toMove: Move): Iterable[Move] = 27 | (fromMove.n to toMove.n).map(Move(_)) 28 | } 29 | 30 | // One of these comes in for each location. 31 | case class Update(rentalRequests: Int, returns: Int) 32 | 33 | sealed trait DistConf extends Product with Serializable 34 | case class PoissonConfig(upperBound: Int, mean: Lambda) extends DistConf 35 | case class ConstantConfig(mean: Int) extends DistConf 36 | case class Location( 37 | requests: DistConf, 38 | returns: DistConf, 39 | maxCars: Int 40 | ) 41 | 42 | // This is the update that comes in for both bullshits 43 | type InvPair = (Inventory, Inventory) 44 | 45 | case class Config( 46 | aConfig: Location, 47 | bConfig: Location, 48 | maxMoves: Move, 49 | rentalCredit: Double, 50 | moveCost: Double 51 | ) { 52 | import cats.implicits._ 53 | 54 | val allMoves: Iterable[Move] = Move.inclusiveRange(-maxMoves, maxMoves) 55 | lazy val dist: Cat[(Update, Update)] = 56 | ( 57 | toDistribution(aConfig.requests), 58 | toDistribution(aConfig.returns), 59 | toDistribution(bConfig.requests), 60 | toDistribution(bConfig.returns) 61 | ).mapN { case (a, b, c, d) => 62 | (Update(a, b), Update(c, d)) 63 | } 64 | 65 | def build(a: Inventory, b: Inventory): CarRental = 66 | CarRental(this, dist, a, b) 67 | 68 | def stateSweep: Traversable[CarRental] = 69 | for { 70 | a <- 0 to aConfig.maxCars 71 | b <- 0 to bConfig.maxCars 72 | } yield build( 73 | Inventory(a, aConfig.maxCars), 74 | Inventory(b, bConfig.maxCars) 75 | ) 76 | } 77 | 78 | def toDistribution(config: DistConf): Cat[Int] = 79 | config match { 80 | case PoissonConfig(upperBound, mean) => 81 | Categorical.poisson(upperBound, mean) 82 | case ConstantConfig(mean) => Categorical.pure(mean) 83 | } 84 | } 85 | 86 | import CarRental.{InvPair, Inventory, Move, Update} 87 | 88 | case class CarRental( 89 | config: CarRental.Config, 90 | pmf: Cat[(Update, Update)], 91 | a: Inventory, 92 | b: Inventory 93 | ) extends State[InvPair, Move, Double, Cat] { 94 | 95 | override val observation: InvPair = (a, b) 96 | 97 | /** Go through all possibilities... 98 | * 99 | * FIRST move the cars. THEN calculate the cost. 100 | * 101 | * THEN do the Poisson update and factor in the amount of money back, plus costs... 102 | * 103 | * positive goes from a to b, negative goes from b to a. 104 | * 105 | * TODO filter this so that we don't present moves that will more than deplete some spot. 106 | * Overloading is fine, since it gets the cars off the board... I guess? 107 | * 108 | * TODO I THINK we can only make this faster if we decide to use an Eval... get an EvalT going 109 | * for the monads, and define an expected value instance there. But then the first person to go 110 | * and iterate through will evaluate everything. 
111 | */ 112 | override lazy val dynamics: Map[Move, Cat[(Double, CarRental)]] = 113 | Util.makeMapUnsafe(config.allMoves) { move => 114 | pmf.map { case (aUpdate, bUpdate) => 115 | val (newA, newB, reward) = processAll(move, aUpdate, bUpdate) 116 | (reward, copy(a = newA, b = newB)) 117 | } 118 | } 119 | override val invalidMove = Categorical.pure((0.0, this)) 120 | 121 | private def processAll( 122 | move: Move, 123 | aUpdate: Update, 124 | bUpdate: Update 125 | ): (Inventory, Inventory, Double) = { 126 | // TODO this shouldn't charge you if you CAN'T move a car. Check if there 127 | // are enough and fix that. 128 | val moveCost = config.moveCost * math.abs(move.n) 129 | val (newA, rewardA) = process(-move, a, aUpdate) 130 | val (newB, rewardB) = process(move, b, bUpdate) 131 | (newA, newB, rewardA + rewardB - moveCost) 132 | } 133 | 134 | private def process( 135 | move: Move, 136 | inventory: Inventory, 137 | update: Update 138 | ): (Inventory, Double) = { 139 | val afterMove = inventory + move 140 | val validRentals = math.min(afterMove.n, update.rentalRequests) 141 | val nextInventory = 142 | afterMove.update(Move(validRentals), Move(update.returns)) 143 | (nextInventory, config.rentalCredit * validRentals) 144 | } 145 | } 146 | -------------------------------------------------------------------------------- /scala-rl-world/src/main/scala/com/scalarl/world/GamblersProblem.scala: -------------------------------------------------------------------------------- 1 | /** Gambler's Problem! Chapter 4 again; this generates Figure 4.3. 2 | */ 3 | package com.scalarl 4 | package world 5 | 6 | import com.scalarl.rainier.Categorical 7 | 8 | object GamblersProblem { 9 | case class Amount(p: Int) extends AnyVal { 10 | def >=(r: Amount): Boolean = p >= r.p 11 | } 12 | 13 | case class Config( 14 | headProb: Double, 15 | winningAmount: Amount, 16 | winningReward: Double 17 | ) { 18 | val headsDistribution: Cat[Boolean] = 19 | Categorical.boolean(headProb) 20 | 21 | def build(startingAmount: Amount): GamblersProblem = 22 | GamblersProblem(this, startingAmount) 23 | 24 | def stateSweep: Traversable[GamblersProblem] = 25 | for (amt <- 0 until winningAmount.p) yield build(Amount(amt)) 26 | } 27 | } 28 | 29 | /** Gotta read more about what the hell is going on, but the key is that we have 100 possible 30 | * states... for the value function. 31 | */ 32 | case class GamblersProblem( 33 | config: GamblersProblem.Config, 34 | amount: GamblersProblem.Amount 35 | ) extends State[GamblersProblem.Amount, GamblersProblem.Amount, Double, Cat] { 36 | import GamblersProblem.Amount 37 | 38 | override val observation = amount 39 | 40 | // Maybe we want a real penalty here. 
41 | override val invalidMove = Categorical.pure((0.0, this)) 42 | 43 | override lazy val dynamics: Map[Amount, Cat[(Double, GamblersProblem)]] = 44 | if (amount >= config.winningAmount || amount.p <= 0) 45 | Map.empty 46 | else 47 | Util.makeMapUnsafe( 48 | (1 to math.min(amount.p, config.winningAmount.p - amount.p)) 49 | .map(Amount(_)) 50 | ) { move => 51 | config.headsDistribution.map { winningBet => 52 | val newAmount = 53 | if (winningBet) move.p + amount.p else move.p - amount.p 54 | val reward = 55 | if (newAmount == config.winningAmount.p) 56 | config.winningReward 57 | else 58 | 0 59 | (reward, copy(amount = Amount(newAmount))) 60 | } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /scala-rl-world/src/main/scala/com/scalarl/world/GridWorld.scala: -------------------------------------------------------------------------------- 1 | /** Gridworld implementation, based on what we need from Chapter 3. I think this can get much more 2 | * complicated, especially once observations come into the picture. 3 | */ 4 | package com.scalarl 5 | package world 6 | 7 | import cats.Id 8 | import com.scalarl.world.util.Grid 9 | import scala.util.{Success, Try} 10 | 11 | object GridWorld { 12 | import Grid.{Bounds, Position} 13 | 14 | object Jumps { 15 | def empty: Jumps = Jumps(Map.empty) 16 | } 17 | 18 | case class Jumps(jumps: Map[Position, (Position, Double)]) { 19 | def get(p: Position): Option[(Position, Double)] = jumps.get(p) 20 | 21 | /** TODO 22 | * 23 | * - validate that there are no cycles! 24 | * - validate that we're within the bounds for all endpoints! 25 | */ 26 | def validate(bounds: Bounds): Try[Jumps] = Success(this) 27 | def and(from: Position, to: Position, reward: Double): Jumps = 28 | Jumps(jumps.updated(from, (to, reward))) 29 | } 30 | 31 | case class Config( 32 | bounds: Bounds, 33 | default: Double = 0.0, 34 | penalty: Double = -1.0, 35 | jumps: Jumps = Jumps.empty, 36 | values: Map[Grid.Position, Double] = Map.empty, 37 | terminalStates: Set[Grid.Position] = Set.empty 38 | ) { 39 | def withJump(from: Position, to: Position, reward: Double): Config = 40 | copy(jumps = jumps.and(from, to, reward)) 41 | 42 | def withValue(position: Position, value: Double): Config = 43 | copy(values = values.updated(position, value)) 44 | 45 | def withTerminalState(position: Position, value: Double = default): Config = 46 | copy( 47 | values = values.updated(position, value), 48 | terminalStates = terminalStates + position 49 | ) 50 | 51 | /** Build by projecting a row or column outside of the specified bounds onto the boundary. 52 | */ 53 | def buildConfined(start: Position): GridWorld = 54 | buildUnsafe(start.confine(bounds)) 55 | 56 | /** Build, assuming that everything is legit! 57 | */ 58 | def buildUnsafe(start: Position): GridWorld = 59 | GridWorld( 60 | Grid(start, bounds), 61 | default, 62 | penalty, 63 | jumps, 64 | values, 65 | terminalStates 66 | ) 67 | 68 | /** Returns a Try that's successful if supplied position is within bounds, false otherwise. 69 | */ 70 | def build(start: Position): Try[GridWorld] = 71 | start.assertWithin(bounds).map(buildUnsafe(_)) 72 | 73 | def stateSweep: Traversable[GridWorld] = 74 | Grid.allStates(bounds).map { 75 | GridWorld(_, default, penalty, jumps, values, terminalStates) 76 | } 77 | } 78 | } 79 | 80 | /** TODO - redo this to store dynamics ALL OVER, so we don't have to recalculate them? lazily build 81 | * up the map... but don't REPLACE once it's there? That should slightly speed us up. 
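 *
 * A configuration sketch in the spirit of the book's 5x5 gridworld (the jump positions and
 * rewards are the Figure 3.2 values, included purely as an illustration; row 0 is treated as
 * the top row here):
 *
 * {{{
 * import com.scalarl.world.util.Grid.{Bounds, Position}
 *
 * val world: GridWorld = GridWorld
 *   .Config(Bounds(5, 5))
 *   .withJump(Position.of(0, 1), Position.of(4, 1), 10.0)
 *   .withJump(Position.of(0, 3), Position.of(2, 3), 5.0)
 *   .buildConfined(Position.of(0, 0))
 * }}}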
82 | */ 83 | case class GridWorld( 84 | grid: Grid, 85 | defaultReward: Double, 86 | penalty: Double, 87 | jumps: GridWorld.Jumps, 88 | values: Map[Grid.Position, Double], 89 | terminalStates: Set[Grid.Position] 90 | ) extends State[Grid.Position, Grid.Move, Double, Id] { 91 | import Grid.{Move, Position} 92 | 93 | override val observation: Position = grid.position 94 | override val invalidMove = (penalty, this) 95 | 96 | override lazy val dynamics: Map[Move, (Double, GridWorld)] = 97 | if (terminalStates(grid.position)) 98 | Map.empty 99 | else 100 | Util.makeMap(Grid.Move.all)(actNow(_)) 101 | 102 | private def positionValue(position: Position): Double = 103 | values.getOrElse(position, defaultReward) 104 | 105 | /** This is the NON-monadic action, since we can do it immediately. The dynamics are where it all 106 | * gets passed down to the user. 107 | * 108 | * There is still a wall, though! The user can't look ahead. If you CAN look ahead, and don't 109 | * hide it behind a delay, then boom, we have the ability to do the checkers example. 110 | */ 111 | private def actNow(move: Move): (Double, GridWorld) = 112 | jumps.get(grid.position) match { 113 | case None => 114 | grid 115 | .move(move) 116 | .map(g => (positionValue(g.position), copy(grid = g))) 117 | .getOrElse((penalty, this)) 118 | case Some((newPosition, reward)) => 119 | (reward, copy(grid = grid.teleportUnsafe(newPosition))) 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /scala-rl-world/src/main/scala/com/scalarl/world/InfiniteVariance.scala: -------------------------------------------------------------------------------- 1 | /** Infinite variance world from Chapter 5. In this game, you can go left or right; if you go left, 2 | * you have some odds of winning or losing, and if you go right you immediately lose. 
3 | */ 4 | package com.scalarl 5 | package world 6 | 7 | import com.scalarl.rainier.Categorical 8 | 9 | sealed trait InfiniteVariance 10 | extends State[InfiniteVariance.View, InfiniteVariance.Move, Int, Cat] { 11 | override val invalidMove = Categorical.pure((0, this)) 12 | } 13 | 14 | object InfiniteVariance { 15 | val startingState: InfiniteVariance = AliveState 16 | 17 | sealed trait Move extends Product with Serializable 18 | object Move { 19 | final case object Left extends Move 20 | final case object Right extends Move 21 | 22 | val all: Set[Move] = Set(Left, Right) 23 | } 24 | 25 | sealed trait View extends Product with Serializable 26 | object View { 27 | final case object Alive extends View 28 | final case object Dead extends View 29 | } 30 | 31 | object AliveState extends InfiniteVariance { 32 | override val observation = View.Alive 33 | override val dynamics: Map[Move, Cat[(Int, InfiniteVariance)]] = Map( 34 | Move.Left -> Categorical( 35 | Map( 36 | (0, AliveState) -> 0.1, 37 | (1, DeadState) -> 0.9 38 | ) 39 | ), 40 | Move.Right -> Categorical.pure((0, DeadState)) 41 | ) 42 | } 43 | 44 | object DeadState extends InfiniteVariance { 45 | override val observation = InfiniteVariance.View.Dead 46 | override val dynamics = Map.empty 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /scala-rl-world/src/main/scala/com/scalarl/world/connectfour/IO.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl.world.connectfour 2 | 3 | import scala.io.StdIn 4 | import scala.util.{Failure, Success, Try} 5 | 6 | object IO { 7 | import Game._ 8 | 9 | /** Return successful piece if it's possible to parse, failure otherwise. 10 | */ 11 | def getColumn(color: Color): Column = { 12 | println(s"Enter the column where you'd like to place your $color piece.") 13 | try 14 | StdIn.readLine("column> ").toInt 15 | catch { 16 | case e: NumberFormatException => 17 | println("What you entered isn't a valid number. Try again.") 18 | println 19 | getColumn(color) 20 | } 21 | } 22 | 23 | /** Gets the starting game color from the user. 24 | */ 25 | def initialColor: Color = { 26 | println("What color would like to go first? Red's the default.") 27 | StdIn.readLine("red or black> ").toLowerCase match { 28 | case "" => 29 | println("You entered nothing. We'll default to red.") 30 | println 31 | Color.Red 32 | case "red" => 33 | println("Red it is!") 34 | println 35 | Color.Red 36 | case "black" => 37 | println("Black it is!") 38 | println 39 | Color.Black 40 | case _ => 41 | println("I'm afraid that's not a valid input. Try again!") 42 | println 43 | initialColor 44 | } 45 | } 46 | 47 | /** Prints out the current board with some surrounding text. 48 | */ 49 | def printBoardState(board: Board): Unit = { 50 | println("Current Board State:") 51 | println(board) 52 | println 53 | } 54 | 55 | /** Performs a turn and returns either a failure, if the move was invalid in some way, or a 56 | * successful pair of the new board and the position updated by the move. 57 | */ 58 | def turn(board: Board, turnColor: Color): Try[(Board, Board.Position)] = { 59 | printBoardState(board) 60 | val column = getColumn(turnColor) 61 | board.tryMove(Move(column, turnColor))(board.performMove(_)) 62 | } 63 | 64 | /** Plays the game to completion, looping on every turn and alternating colors. The game ends if 65 | * the board fills up or if one side wins. 
66 | */ 67 | def gameLoop(board: Board, turnColor: Color): Unit = 68 | turn(board, turnColor) match { 69 | case Success((newBoard, position)) => 70 | println(s"Nice, piece placed at $position.") 71 | println 72 | 73 | newBoard.checkAllPositionsForWin match { 74 | case None => 75 | if (newBoard.isFull) { 76 | println("Oh no, the board is full! Thanks for playing.") 77 | System.exit(0) 78 | 79 | } else { 80 | gameLoop(newBoard, Color.other(turnColor)) 81 | } 82 | case Some(winningColor) => 83 | println(s"Congratulations, $winningColor... You win!") 84 | printBoardState(newBoard) 85 | System.exit(0) 86 | } 87 | 88 | case Failure(e) => 89 | println("Whoops, your move was invalid with the following error: ") 90 | println(e.getMessage) 91 | println("Let's try that again.") 92 | println 93 | 94 | gameLoop(board, turnColor) 95 | } 96 | 97 | def main(items: Array[String]): Unit = 98 | gameLoop(Board.defaultEmpty, initialColor) 99 | } 100 | -------------------------------------------------------------------------------- /scala-rl-world/src/main/scala/com/scalarl/world/util/CardDeck.scala: -------------------------------------------------------------------------------- 1 | /** Card deck, for card games. 2 | */ 3 | package com.scalarl 4 | package world 5 | package util 6 | 7 | import com.stripe.rainier.core.Generator 8 | import com.scalarl.rainier.Categorical 9 | 10 | object CardDeck { 11 | sealed trait Rank extends Any with Product with Serializable 12 | object Rank { 13 | val all: Vector[Rank] = 14 | Vector(Jack, King, Queen, Ace) ++ (2 to 10).map(Number(_)) 15 | final case object Jack extends Rank 16 | final case object Queen extends Rank 17 | final case object King extends Rank 18 | final case object Ace extends Rank 19 | final case class Number(value: Int) extends AnyVal with Rank 20 | 21 | implicit val rankOrd: Ordering[Rank] = 22 | Ordering.by(all.indexOf(_)) 23 | } 24 | 25 | sealed trait Suit extends Product with Serializable 26 | object Suit { 27 | val all: Vector[Suit] = Vector(Spades, Hearts, Clubs, Diamonds) 28 | final case object Spades extends Suit 29 | final case object Hearts extends Suit 30 | final case object Clubs extends Suit 31 | final case object Diamonds extends Suit 32 | 33 | implicit val suitOrd: Ordering[Suit] = 34 | Ordering.by(all.indexOf(_)) 35 | } 36 | 37 | case class Card(suit: Suit, rank: Rank) 38 | object Card { 39 | val all: Vector[Card] = for { 40 | suit <- Suit.all 41 | rank <- Rank.all 42 | } yield Card(suit, rank) 43 | 44 | val deck: Set[Card] = all.to[Set] 45 | 46 | implicit val ordering: Ordering[Card] = 47 | Ordering.by(card => (card.suit, card.rank)) 48 | } 49 | 50 | /** Generates cards from an infinite stream, with replacement. 51 | */ 52 | val basic: Generator[Card] = Generator.vector(Card.all) 53 | val allCat: Cat[Card] = Categorical.seq(Card.all) 54 | val heartsCat: Cat[Card] = 55 | Categorical.seq(Card.all.filter(_.suit == Suit.Hearts)) 56 | } 57 | -------------------------------------------------------------------------------- /scala-rl-world/src/main/scala/com/scalarl/world/util/Grid.scala: -------------------------------------------------------------------------------- 1 | /** Grid-related utilities. I bet I could generate 1, 2, 3d grids, with custom moves between them... 
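 *
 * Basic usage (a quick sketch; the bounds are arbitrary):
 *
 * {{{
 * import com.scalarl.world.util.Grid
 * import Grid.{Bounds, Move, Position}
 *
 * val grid = Grid(Position.of(0, 0), Bounds(3, 3))
 * grid.move(Move.Up)   // Success: position moves to Row(1), Col(0)
 * grid.move(Move.Down) // Failure: Row(-1) falls outside the bounds
 * }}}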
2 | */ 3 | package com.scalarl 4 | package world 5 | package util 6 | 7 | import scala.util.{Failure, Success, Try} 8 | 9 | object Grid { 10 | sealed trait Move extends Product with Serializable 11 | object Move { 12 | final case object Left extends Move 13 | final case object Right extends Move 14 | final case object Up extends Move 15 | final case object Down extends Move 16 | 17 | val all: Set[Move] = Set(Left, Right, Up, Down) 18 | } 19 | 20 | case class Row(value: Int) extends AnyVal { 21 | def up: Row = Row(value + 1) 22 | def down: Row = Row(value - 1) 23 | 24 | /** Returns a row that's guaranteed to sit within the range specified by numColumns. 25 | */ 26 | def confine(numRows: Int): Row = 27 | Row(Util.clamp(value, 0, numRows - 1)) 28 | 29 | def isWithin(numRows: Int): Boolean = value >= 0 && value < numRows 30 | def assertWithin(numRows: Int): Try[Row] = 31 | if (isWithin(numRows)) 32 | Success(this) 33 | else 34 | Failure( 35 | new AssertionError( 36 | s"Column $value is invalid: Must be between 0 and $numRows." 37 | ) 38 | ) 39 | } 40 | 41 | case class Col(value: Int) extends AnyVal { 42 | def left: Col = Col(value + 1) 43 | def right: Col = Col(value - 1) 44 | 45 | /** Returns a column that's guaranteed to sit within the range specified by numColumns. 46 | */ 47 | def confine(numColumns: Int): Col = 48 | Col(Util.clamp(value, 0, numColumns - 1)) 49 | 50 | def isWithin(numColumns: Int): Boolean = value >= 0 && value < numColumns 51 | def assertWithin(numColumns: Int): Try[Col] = 52 | if (isWithin(numColumns)) 53 | Success(this) 54 | else 55 | Failure( 56 | new AssertionError( 57 | s"Column $value is invalid: Must be between 0 and $numColumns." 58 | ) 59 | ) 60 | 61 | } 62 | 63 | object Position { 64 | def of(row: Int, col: Int): Position = 65 | apply(Row(row), Col(col)) 66 | } 67 | case class Position(row: Row, col: Col) { 68 | def left: Position = Position(row, col.left) 69 | def right: Position = Position(row, col.right) 70 | def up: Position = Position(row.up, col) 71 | def down: Position = Position(row.down, col) 72 | 73 | def confine(bounds: Bounds): Position = 74 | Position( 75 | row.confine(bounds.numRows), 76 | col.confine(bounds.numColumns) 77 | ) 78 | 79 | def isWithin(bounds: Bounds): Boolean = 80 | row.isWithin(bounds.numRows) && col.isWithin(bounds.numColumns) 81 | 82 | def assertWithin(bounds: Bounds): Try[Position] = 83 | for { 84 | r <- row.assertWithin(bounds.numRows) 85 | x <- col.assertWithin(bounds.numColumns) 86 | } yield this 87 | } 88 | case class Bounds(numRows: Int, numColumns: Int) { 89 | def allPositions: Traversable[Position] = 90 | for { 91 | r <- 0 until numRows 92 | c <- 0 until numColumns 93 | } yield Position.of(r, c) 94 | } 95 | 96 | /** Produces a traversable instance containing all possible Grid states. 
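 *
 * For example:
 *
 * {{{
 * Grid.allStates(Grid.Bounds(2, 2)).map(_.position)
 * // positions (0,0), (0,1), (1,0), (1,1)
 * }}}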
97 | */ 98 | def allStates(bounds: Bounds): Traversable[Grid] = 99 | bounds.allPositions.map(Grid(_, bounds)) 100 | } 101 | 102 | case class Grid(position: Grid.Position, bounds: Grid.Bounds) { 103 | import Grid.{Move, Position} 104 | 105 | private def moveF(move: Move): Position => Position = 106 | move match { 107 | case Move.Left => _.left 108 | case Move.Right => _.right 109 | case Move.Up => _.up 110 | case Move.Down => _.down 111 | } 112 | 113 | def move(move: Move): Try[Grid] = 114 | teleport(moveF(move)(position)) 115 | 116 | def teleportUnsafe(newPosition: Position): Grid = 117 | copy(position = newPosition) 118 | 119 | def teleport(newPosition: Position): Try[Grid] = 120 | newPosition.assertWithin(bounds).map(Grid(_, bounds)) 121 | } 122 | -------------------------------------------------------------------------------- /scala-rl-world/src/test/scala/com/scalarl/world/connectfour/ConnectFourSpec.scala: -------------------------------------------------------------------------------- 1 | package com.scalarl.world.connectfour 2 | 3 | import org.scalacheck.{Arbitrary, Gen, Properties} 4 | import org.scalacheck.Prop.forAll 5 | 6 | object ConnectFourSpec extends Properties("ConnectFour") with ConnectFourArb { 7 | import Game._ 8 | 9 | def makeMoveValid(move: Move, maxColumns: Int): Move = 10 | move.copy( 11 | column = if (maxColumns == 0) 0 else move.column % maxColumns 12 | ) 13 | 14 | property("generated moves are valid.") = forAll { (move: Move, columns: Int) => 15 | val posColumns = Math.abs(columns % 100) 16 | makeMoveValid(move, posColumns).column <= posColumns 17 | } 18 | } 19 | 20 | /** And this is a placeholder for basic tests. 21 | */ 22 | class ConnectFourTest extends org.scalatest.funsuite.AnyFunSuite { 23 | test("example.test") { 24 | val digits = List(1, 2, 3) 25 | assert(digits.sum === 6) 26 | } 27 | } 28 | 29 | /** Generators and Arbitrary instances live below. 30 | */ 31 | trait ConnectFourGen { 32 | import Game._ 33 | 34 | implicit val genColor: Gen[Color] = Gen.oneOf(Color.Red, Color.Black) 35 | 36 | implicit val genMove: Gen[Move] = for { 37 | column <- Gen.posNum[Int] 38 | color <- genColor 39 | } yield Move(column, color) 40 | } 41 | 42 | object ConnectFourGenerators extends ConnectFourGen 43 | 44 | trait ConnectFourArb { 45 | import Game._ 46 | import ConnectFourGenerators._ 47 | 48 | implicit val arbColor: Arbitrary[Color] = Arbitrary(genColor) 49 | implicit val arbMove: Arbitrary[Move] = Arbitrary(genMove) 50 | } 51 | -------------------------------------------------------------------------------- /scaladoc-root.txt: -------------------------------------------------------------------------------- 1 | edit this text on github 2 |
ScalaRL
3 | This is the API documentation for the ScalaRL functional reinforcement learning library. 4 | 5 | Further documentation for ScalaRL can be found at the documentation site. 6 | 7 | Check out the ScalaRL package list for all the goods. 8 | -------------------------------------------------------------------------------- /scripts/decrypt-keys.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | openssl aes-256-cbc -K $encrypted_80ed8559843c_key -iv $encrypted_80ed8559843c_iv -in travis-deploy-key.enc -out travis-deploy-key -d 3 | chmod 600 travis-deploy-key; 4 | cp travis-deploy-key ~/.ssh/id_rsa; 5 | -------------------------------------------------------------------------------- /scripts/publishMicrosite.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | git config --global user.email "sritchie09@gmail.com" 5 | git config --global user.name "Sam Ritchie" 6 | git config --global push.default simple 7 | 8 | sbt docs/publishMicrosite 9 | -------------------------------------------------------------------------------- /travis-deploy-key.enc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frogrocketlabs/scala-rl/cc02d7a46cc75436cdb2eaa41cd9f13cc97c3391/travis-deploy-key.enc --------------------------------------------------------------------------------