├── .gitignore ├── .sample_plot.png ├── LICENSE ├── README.md ├── build.sbt ├── output └── .gitkeep ├── plot ├── plot_cumulative_rewards.r ├── plot_exp3.r ├── plot_hedge.r ├── plot_standard_epsilon_greedy.r ├── plot_standard_softmax.r ├── plot_ucb1.r └── read_data.r ├── project ├── MIT.scala ├── build.properties └── plugin.sbt ├── run_test.sh └── src └── main ├── resources └── application.conf └── scala └── com └── github └── everpeace └── banditsbook ├── Demo.scala ├── TestAll.scala ├── algorithm ├── Algorithm.scala ├── TracedAlgorithmDriver.scala ├── epsilon_greedy │ ├── Standard.scala │ └── TestStandard.scala ├── exp3 │ ├── Exp3.scala │ └── TestExp3.scala ├── hedge │ ├── Hedge.scala │ └── TestHedge.scala ├── package.scala ├── softmax │ ├── Standard.scala │ └── TestStandard.scala └── ucb │ ├── TestUCB1.scala │ └── UCB1.scala ├── arm ├── Arms.scala └── package.scala └── testing_framework └── TestRunner.scala /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | *.log 3 | 4 | # sbt specific 5 | .cache 6 | .history 7 | .lib/ 8 | dist/* 9 | target/ 10 | lib_managed/ 11 | src_managed/ 12 | project/boot/ 13 | project/plugins/project/ 14 | 15 | # Scala-IDE specific 16 | .scala_dependencies 17 | .worksheet 18 | 19 | # Idea 20 | .idea 21 | 22 | # Created by https://www.gitignore.io/api/r 23 | 24 | ### R ### 25 | # History files 26 | .Rhistory 27 | .Rapp.history 28 | 29 | # Session Data files 30 | .RData 31 | 32 | # Example code in package build process 33 | *-Ex.R 34 | 35 | # Output files from R CMD build 36 | /*.tar.gz 37 | 38 | # Output files from R CMD check 39 | /*.Rcheck/ 40 | 41 | # RStudio files 42 | .Rproj.user/ 43 | 44 | # produced vignettes 45 | vignettes/*.html 46 | vignettes/*.pdf 47 | 48 | # OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3 49 | .httr-oauth 50 | 51 | # knitr and R markdown default cache directories 52 | /*_cache/ 53 | /cache/ 54 | 55 | # Temporary files created by R 
markdown 56 | *.utf8.md 57 | *.knit.md 58 | 59 | # output files 60 | *csv 61 | *tsv 62 | *png 63 | Rplots.pdf 64 | -------------------------------------------------------------------------------- /.sample_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/everpeace/banditsbook-scala/5653bc345d44000ef0ca2d5dc3239dc82f504357/.sample_plot.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Shingo Omura 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # banditsbook-scala 2 | ![sample_plot](.sample_plot.png) 3 | 4 | This repository is inspired by [johnmyleswhite/BanditsBook][johnmyleswhite/BanditsBook]. 5 | 6 | This contains Scala implementations (with [breeze][breeze] and [cats][cats]) of several standard algorithms for solving the Multi-Armed Bandit problem, including: 7 | 8 | * [x] epsilon-Greedy 9 | * [x] Softmax (Boltzmann) 10 | * [x] UCB1 11 | * [ ] UCB2 12 | * [x] Hedge 13 | * [x] Exp3 14 | * [ ] annealing versions of the above 15 | 16 | It also contains code that provides a testing framework for bandit algorithms based around simple Monte Carlo simulations. 17 | 18 | ## Languages and design overview 19 | All implementations are in Scala. If you're interested in other languages, please see [johnmyleswhite/BanditsBook]. 20 | 21 | These bandit algorithms are implemented in a functional, stateless style. Algorithm behavior is modeled as a State monad using [cats][cats] (please see [Demo.scala][Demo.scala]). 22 | 23 | Implementing bandit algorithms requires linear algebra (vector/matrix calculations) and probability calculations. This implementation uses [breeze][breeze] for those features. 24 | 25 | ## Getting Started 26 | To try out this code, go to [Demo.scala][Demo.scala] and then run the demo. 27 | 28 | ``` 29 | // run the demo written in a procedural manner (it is still stateless) 30 | sbt "run-main com.github.everpeace.banditsbook.Demo" 31 | 32 | // run the demo written in a monadic manner 33 | sbt "run-main com.github.everpeace.banditsbook.DemoMonadic" 34 | ``` 35 | 36 | You should step through that code line-by-line to understand what the functions are doing. The book provides more in-depth explanations of how the algorithms work.
37 | 38 | ## Simulations 39 | This repository includes some handy scripts to run simulations. 40 | 41 | * Prerequisites: 42 | * [sbt](http://www.scala-sbt.org/) 43 | * [R](https://www.r-project.org) (for plotting simulation results) 44 | 45 | To run simulations, just run the following commands: 46 | 47 | ``` 48 | // this takes some time, so you can enjoy a cup of coffee :-). 49 | $ cd banditsbook-scala 50 | $ ./run_test.sh all 51 | ``` 52 | 53 | This runs simulations for the configurations defined in [application.conf](src/main/resources/application.conf) and writes graphs of the simulation results to the `output` directory, like the image at the top. Please note that all arms are simulated with Bernoulli distributions. 54 | 55 | ## Adding New Algorithms: API Expectations 56 | [Algorithm][Algorithm.scala] is defined as below: 57 | 58 | ``` 59 | // Reward : type of reward which this algorithm works with. 60 | // AlgorithmState : type of state which the algorithm handles. 61 | abstract class Algorithm[Reward, AlgorithmState] { 62 | 63 | // The method used for initialization. 64 | // Given arms, returns the initial state value of this algorithm 65 | def initialState(arms: Seq[Arm[Reward]]): AlgorithmState 66 | 67 | // The method that returns the index of the Arm 68 | // that the algorithm selects on the current play. 69 | def selectArm(arms: Seq[Arm[Reward]], state: AlgorithmState): Int 70 | 71 | // The method calculates the next state of the algorithm 72 | // in response to its most recently selected arm's reward. 73 | def updateState(arms: Seq[Arm[Reward]], state: AlgorithmState, chosen: Int, reward: Reward): AlgorithmState 74 | 75 | // 76 | // State Monadic Values: 77 | // These values are derived from the methods above. 78 | // This means you can get a monadic algorithm 79 | // instance for free!
80 | // 81 | import cats.data.State 82 | import State._ 83 | def selectArm: State[(Seq[Arm[Reward]], AlgorithmState), Arm[Reward]] = 84 | inspect { 85 | case (arms, state) => 86 | arms(selectArm(arms, state)) 87 | } 88 | 89 | def updateState(chosenArm:Arm[Reward], reward: Reward): State[(Seq[Arm[Reward]], AlgorithmState), AlgorithmState] = 90 | inspect { 91 | case (arms, state) => 92 | updateState(arms, state, arms.indexOf(chosenArm), reward) 93 | } 94 | } 95 | ``` 96 | 97 | You may need to implement your own arm simulator. An arm is modeled by `breeze.stats.distributions.Rand[+T]`. Please refer to [Arms.scala][Arms.scala] for typical arm implementations. 98 | 99 | ## License 100 | The MIT License (MIT) 101 | 102 | Copyright (c) 2016 Shingo Omura 103 | 104 | ## Contributions 105 | Contributions are welcome :-) There are no complicated rules. Feel free to open issues and pull requests! 106 | 107 | [johnmyleswhite/BanditsBook]: https://github.com/johnmyleswhite/BanditsBook 108 | [breeze]: https://github.com/scalanlp/breeze 109 | [cats]: https://github.com/typelevel/cats 110 | [Demo.scala]: src/main/scala/com/github/everpeace/banditsbook/Demo.scala 111 | [Algorithm.scala]: src/main/scala/com/github/everpeace/banditsbook/algorithm/Algorithm.scala 112 | [Arms.scala]: src/main/scala/com/github/everpeace/banditsbook/arm/Arms.scala 113 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | import de.heikoseeberger.sbtheader.license.License 2 | 3 | // project info 4 | organization := "com.github.everpeace" 5 | name := "banditsbook-scala" 6 | version := "0.1.0" 7 | 8 | // compile settings 9 | scalaVersion := "2.11.8" 10 | crossScalaVersions := Seq("2.10.6", "2.11.8") 11 | scalacOptions ++= Seq( 12 | "-deprecation", 13 | "-encoding", "UTF-8", 14 | "-feature", 15 | "-language:existentials", 16 | "-language:higherKinds", 17 |
"-language:implicitConversions", 18 | "-language:experimental.macros", 19 | "-unchecked", 20 | "-Xfatal-warnings", 21 | "-Xlint", 22 | "-Yinline-warnings", 23 | "-Ywarn-dead-code", 24 | "-Xfuture" 25 | ) 26 | 27 | // dependencies 28 | resolvers ++= Seq( 29 | "Sonatype OSS Snapshots" at "http://oss.sonatype.org/content/repositories/snapshots/" 30 | ) 31 | libraryDependencies ++= Seq( 32 | "org.scalanlp" %% "breeze" % "0.12", 33 | "org.scalanlp" %% "breeze-natives" % "0.12", 34 | "org.scalanlp" %% "breeze-viz" % "0.12", 35 | "org.typelevel" %% "cats" % "0.5.0", 36 | "com.typesafe" % "config" % "1.3.0", 37 | "org.scala-lang.modules" %% "scala-pickling" % "0.10.1", 38 | "org.scalatest" %% "scalatest" % "2.2.6" % "test", 39 | "org.scalacheck" %% "scalacheck" % "1.13.0" % "test" 40 | ) 41 | 42 | // auto import setting for "sbt console" 43 | initialCommands := "import com.github.everpeace.banditsbook._" 44 | 45 | // sbt-headers plugin settings. 46 | unmanagedSourceDirectories in Compile += baseDirectory.value / "plot" 47 | headers := Map( 48 | "scala" -> MIT("2016", "Shingo Omura"), 49 | "conf" -> MIT("2016", "Shingo Omura", "#"), 50 | "r" -> MIT("2016", "Shingo Omura", "#"), 51 | "sh" -> MIT("2016", "Shingo Omura", "#") 52 | ) 53 | -------------------------------------------------------------------------------- /output/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/everpeace/banditsbook-scala/5653bc345d44000ef0ca2d5dc3239dc82f504357/output/.gitkeep -------------------------------------------------------------------------------- /plot/plot_cumulative_rewards.r: -------------------------------------------------------------------------------- 1 | # 2 | # The MIT License (MIT) 3 | # 4 | # Copyright (c) 2016 Shingo Omura 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in 
the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | 24 | library(dplyr) 25 | library(ggplot2) 26 | 27 | plot_cumulative_rewards_with_hyper_param <- function(df){ 28 | cumStat <- df %>% 29 | dplyr::group_by(hyper_param, step) %>% 30 | dplyr::summarise(cumulative_reward_min = min(cumulative_reward), 31 | cumulative_reward_sd_min = mean(cumulative_reward) - sd(cumulative_reward), 32 | cumulative_reward_mean = mean(cumulative_reward), 33 | cumulative_reward_sd_max = mean(cumulative_reward) + sd(cumulative_reward), 34 | cumulative_reward_max = max(cumulative_reward) 35 | ) 36 | 37 | 38 | g <- ggplot(data = cumStat, 39 | aes(x = step, 40 | y = cumulative_reward_mean, 41 | colour = hyper_param)) 42 | g <- g + geom_point(size=0.5) 43 | g <- g + geom_line(size=0.5) 44 | g <- g + geom_ribbon(alpha = 0.1, colour=NA, 45 | aes( 46 | fill = hyper_param, 47 | ymin=cumulative_reward_min, 48 | ymax = cumulative_reward_max 49 | )) 50 | g <- g + geom_ribbon(alpha = 0.2, colour=NA, 51 | aes( 52 | fill = hyper_param, 53 | ymin = cumulative_reward_sd_min, 54 | ymax = cumulative_reward_sd_max 
55 | )) 56 | g <- g + facet_grid(. ~ hyper_param) 57 | g <- g + theme(legend.position="none") 58 | g 59 | } 60 | 61 | plot_cumulative_rewards_without_hyper_param <- function(df){ 62 | cumStat <- df %>% 63 | dplyr::group_by(step) %>% 64 | dplyr::summarise(cumulative_reward_min = min(cumulative_reward), 65 | cumulative_reward_sd_min = mean(cumulative_reward) - sd(cumulative_reward), 66 | cumulative_reward_mean = mean(cumulative_reward), 67 | cumulative_reward_sd_max = mean(cumulative_reward) + sd(cumulative_reward), 68 | cumulative_reward_max = max(cumulative_reward)) 69 | g <- ggplot(data = cumStat, 70 | aes(x = step, 71 | y = cumulative_reward_mean)) 72 | g <- g + geom_point(size=0.5) 73 | g <- g + geom_line(size=0.5) 74 | g <- g + geom_ribbon(alpha = 0.1, colour=NA, 75 | aes( 76 | ymin=cumulative_reward_min, 77 | ymax = cumulative_reward_max 78 | )) 79 | g <- g + geom_ribbon(alpha = 0.2, colour=NA, 80 | aes( 81 | ymin = cumulative_reward_sd_min, 82 | ymax = cumulative_reward_sd_max 83 | )) 84 | g <- g + theme(legend.position="none") 85 | g 86 | } 87 | -------------------------------------------------------------------------------- /plot/plot_exp3.r: -------------------------------------------------------------------------------- 1 | #! /usr/bin/Rscript 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2016 Shingo Omura 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 
16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | initial.options <- commandArgs(trailingOnly = FALSE) 26 | file.arg.name <- "--file=" 27 | script.name <- sub(file.arg.name, "", initial.options[grep(file.arg.name, initial.options)]) 28 | script.basename <- dirname(script.name) 29 | outdir.arg.name <- "--outdir=" 30 | outdir <- sub(outdir.arg.name, "", initial.options[grep(outdir.arg.name, initial.options)]) 31 | 32 | source(file.path(script.basename, "plot_cumulative_rewards.r")) 33 | source(file.path(script.basename, "read_data.r")) 34 | library(stringr) 35 | 36 | # read data 37 | datafile_path <- file.path(script.basename, "..", outdir, "test-exp3-results.csv") 38 | df <- read_data_with_hyper_param(datafile_path) 39 | 40 | # plot 41 | g <- plot_cumulative_rewards_with_hyper_param(df) 42 | g <- g + geom_line(aes(y=step+1.0), linetype="dashed") # dashed, to match the note in the title 43 | g <- g + ggtitle("Cumulative Rewards of Exp3 for each γ. (note: dashed line indicates optimal behavior)") 44 | g <- g + theme(plot.title = element_text(hjust = 0)) 45 | 46 | # print(g) 47 | ggsave(filename = str_replace(datafile_path, "\\.csv$", ".png"), plot = g, width = 15, height = 7) # anchored regex: replace only the trailing .csv extension 48 | -------------------------------------------------------------------------------- /plot/plot_hedge.r: -------------------------------------------------------------------------------- 1 | #!
/usr/bin/Rscript 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2016 Shingo Omura 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
24 | 25 | initial.options <- commandArgs(trailingOnly = FALSE) 26 | file.arg.name <- "--file=" 27 | script.name <- sub(file.arg.name, "", initial.options[grep(file.arg.name, initial.options)]) 28 | script.basename <- dirname(script.name) 29 | outdir.arg.name <- "--outdir=" 30 | outdir <- sub(outdir.arg.name, "", initial.options[grep(outdir.arg.name, initial.options)]) 31 | 32 | source(file.path(script.basename, "plot_cumulative_rewards.r")) 33 | source(file.path(script.basename, "read_data.r")) 34 | library(stringr) 35 | 36 | # read data 37 | datafile_path <- file.path(script.basename, "..", outdir, "test-hedge-results.csv") 38 | df <- read_data_with_hyper_param(datafile_path) 39 | 40 | # plot 41 | g <- plot_cumulative_rewards_with_hyper_param(df) 42 | g <- g + geom_line(aes(y=step+1.0), linetype="dashed") # dashed, to match the note in the title 43 | g <- g + ggtitle("Cumulative Rewards of Hedge for each η. (note: dashed line indicates optimal behavior)") 44 | g <- g + theme(plot.title = element_text(hjust = 0)) 45 | 46 | # print(g) 47 | ggsave(filename = str_replace(datafile_path, "\\.csv$", ".png"), plot = g, width = 15, height = 7) # anchored regex: replace only the trailing .csv extension 48 | -------------------------------------------------------------------------------- /plot/plot_standard_epsilon_greedy.r: -------------------------------------------------------------------------------- 1 | #!
/usr/bin/Rscript 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2016 Shingo Omura 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
24 | 25 | initial.options <- commandArgs(trailingOnly = FALSE) 26 | file.arg.name <- "--file=" 27 | script.name <- sub(file.arg.name, "", initial.options[grep(file.arg.name, initial.options)]) 28 | script.basename <- dirname(script.name) 29 | outdir.arg.name <- "--outdir=" 30 | outdir <- sub(outdir.arg.name, "", initial.options[grep(outdir.arg.name, initial.options)]) 31 | 32 | source(file.path(script.basename, "plot_cumulative_rewards.r")) 33 | source(file.path(script.basename, "read_data.r")) 34 | library(stringr) 35 | 36 | # read data 37 | datafile_path <- file.path(script.basename, "..", outdir, "test-standard-epsilon-greedy-results.csv") 38 | df <- read_data_with_hyper_param(datafile_path) 39 | 40 | # plot 41 | g <- plot_cumulative_rewards_with_hyper_param(df) 42 | g <- g + geom_line(aes(y=step+1.0), linetype="dashed") # dashed, to match the note in the title 43 | g <- g + ggtitle("Cumulative Rewards of Standard ε-Greedy for each ε. (note: dashed line indicates optimal behavior)") 44 | g <- g + theme(plot.title = element_text(hjust = 0)) 45 | 46 | # print(g) 47 | ggsave(filename = str_replace(datafile_path, "\\.csv$", ".png"), plot = g, width = 15, height = 7) # anchored regex: replace only the trailing .csv extension 48 | -------------------------------------------------------------------------------- /plot/plot_standard_softmax.r: -------------------------------------------------------------------------------- 1 | #!
/usr/bin/Rscript 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2016 Shingo Omura 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
24 | 25 | initial.options <- commandArgs(trailingOnly = FALSE) 26 | file.arg.name <- "--file=" 27 | script.name <- sub(file.arg.name, "", initial.options[grep(file.arg.name, initial.options)]) 28 | script.basename <- dirname(script.name) 29 | outdir.arg.name <- "--outdir=" 30 | outdir <- sub(outdir.arg.name, "", initial.options[grep(outdir.arg.name, initial.options)]) 31 | 32 | source(file.path(script.basename, "plot_cumulative_rewards.r")) 33 | source(file.path(script.basename, "read_data.r")) 34 | library(stringr) 35 | 36 | # read data 37 | datafile_path <- file.path(script.basename, "..", outdir, "test-standard-softmax-results.csv") 38 | df <- read_data_with_hyper_param(datafile_path) 39 | 40 | # plot 41 | g <- plot_cumulative_rewards_with_hyper_param(df) 42 | g <- g + geom_line(aes(y=step+1.0), linetype="dashed") # dashed, to match the note in the title 43 | g <- g + ggtitle("Cumulative Rewards of Standard Softmax for each τ. (note: dashed line indicates optimal behavior)") 44 | g <- g + theme(plot.title = element_text(hjust = 0)) 45 | 46 | # print(g) 47 | ggsave(filename = str_replace(datafile_path, "\\.csv$", ".png"), plot = g, width = 15, height = 7) # anchored regex: replace only the trailing .csv extension 48 | -------------------------------------------------------------------------------- /plot/plot_ucb1.r: -------------------------------------------------------------------------------- 1 | #!
/usr/bin/Rscript 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2016 Shingo Omura 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
24 | 25 | initial.options <- commandArgs(trailingOnly = FALSE) 26 | file.arg.name <- "--file=" 27 | script.name <- sub(file.arg.name, "", initial.options[grep(file.arg.name, initial.options)]) 28 | script.basename <- dirname(script.name) 29 | outdir.arg.name <- "--outdir=" 30 | outdir <- sub(outdir.arg.name, "", initial.options[grep(outdir.arg.name, initial.options)]) 31 | 32 | source(file.path(script.basename, "plot_cumulative_rewards.r")) 33 | source(file.path(script.basename, "read_data.r")) 34 | library(stringr) 35 | 36 | # read data 37 | datafile_path <- file.path(script.basename, "..", outdir, "test-ucb1-results.csv") 38 | df <- read_data_without_hyper_param(datafile_path) 39 | 40 | # plot 41 | g <- plot_cumulative_rewards_without_hyper_param(df) 42 | g <- g + geom_line(aes(y=step+1.0), linetype="dashed") # dashed, to match the note in the title 43 | g <- g + ggtitle("Cumulative Rewards of UCB1. (note: dashed line indicates optimal behavior.)") 44 | g <- g + theme(plot.title = element_text(hjust = 0)) 45 | 46 | # print(g) 47 | ggsave(filename = str_replace(datafile_path, "\\.csv$", ".png"), plot = g, width = 7, height = 7) # anchored regex: replace only the trailing .csv extension 48 | -------------------------------------------------------------------------------- /plot/read_data.r: -------------------------------------------------------------------------------- 1 | # 2 | # The MIT License (MIT) 3 | # 4 | # Copyright (c) 2016 Shingo Omura 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software.
15 | #
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE.
23 | # Read a comma-separated file with a header row and add `hyper_param`, a factor copy of its first column.
24 | read_data_with_hyper_param <- function(datafile_path){
25 | df <- read.table(datafile_path, header = T ,sep=",")
26 | hp_name <- names(df)[1]
27 | df$hyper_param <- as.factor(df[, hp_name])
28 | df
29 | }
30 | # Read a comma-separated file with a header row and return the data frame unchanged.
31 | read_data_without_hyper_param <- function(datafile_path){
32 | df <- read.table(datafile_path, header = T ,sep=",")
33 | df
34 | }
35 |
--------------------------------------------------------------------------------
/project/MIT.scala:
--------------------------------------------------------------------------------
1 | import de.heikoseeberger.sbtheader.HeaderPattern
2 | import de.heikoseeberger.sbtheader.license.License
3 | // MIT license template for the sbt-header plugin (the plugin is added in project/plugin.sbt).
4 | import scala.util.matching.Regex
5 | // apply returns (HeaderPattern regex, header text): "*" -> C-style block comment, "#" -> hash-line comment; any other style throws IllegalArgumentException.
6 | object MIT extends License {
7 | import HeaderPattern._
8 | override def apply(yyyy: String, copyrightOwner: String, commentStyle: String = "*"): (Regex, String) = {
9 | commentStyle match {
10 | case "*" =>
11 | (
12 | cStyleBlockComment,
13 | s"""|/*
14 | | * Copyright (c) $yyyy $copyrightOwner
15 | | *
16 | | * Permission is hereby granted, free of charge, to any person obtaining a copy of
17 | | * this software and associated documentation files (the "Software"), to deal in
18 | | * the Software without restriction, including without limitation the rights to
19 | | * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
20 | | * the Software, and to permit persons to whom the Software is furnished to do so,
21 | | * subject to the following conditions:
22 | | *
23 | | * The above copyright notice and this permission notice shall be included in all
24 | | * copies or substantial portions of the Software.
25 | | *
26 | | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
27 | | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
28 | | * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
29 | | * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
30 | | * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
31 | | * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
32 | | */
33 | |
34 | |""".stripMargin
35 | )
36 | case "#" =>
37 | (
38 | hashLineComment,
39 | s"""|#
40 | |# Copyright (c) $yyyy $copyrightOwner
41 | |#
42 | |# Permission is hereby granted, free of charge, to any person obtaining a copy of
43 | |# this software and associated documentation files (the "Software"), to deal in
44 | |# the Software without restriction, including without limitation the rights to
45 | |# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
46 | |# the Software, and to permit persons to whom the Software is furnished to do so,
47 | |# subject to the following conditions:
48 | |#
49 | |# The above copyright notice and this permission notice shall be included in all
50 | |# copies or substantial portions of the Software.
51 | |#
52 | |# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
53 | |# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
54 | |# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
55 | |# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
56 | |# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
57 | |# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
58 | |
59 | |""".stripMargin
60 | )
61 | case _ =>
62 | throw new IllegalArgumentException(s"Comment style '$commentStyle' not supported")
63 | }
64 | }
65 | }
66 |
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version = 0.13.11
2 |
--------------------------------------------------------------------------------
/project/plugin.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("de.heikoseeberger" % "sbt-header" % "1.5.1")
2 |
--------------------------------------------------------------------------------
/run_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | #
3 | # The MIT License (MIT)
4 | #
5 | # Copyright (c) 2016 Shingo Omura
6 | # #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | # #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | # #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | SBT=$(which sbt)
26 | R=$(which r)
27 | OPEN=$(which open)
28 | PLOT_TOOL_DIR="plot"
29 | OUTPUT_DIR=${RUN_TEST_OUTPUT_DIR:-output}
30 |
31 | available_algs=(standard-epsilon-greedy standard-softmax exp3 hedge ucb1)
32 |
33 | function show_usage(){
34 | cat < brz_softmax}
35 |
36 |   // Several Arm type is defined in arm package.
37 |   import arm._
38 |
39 |   val arm1 = BernoulliArm(0.2)
40 |   arm1.draw()
41 |   arm1.draw()
42 |
43 |   val arm2 = NormalArm(10.0, 1.0)
44 |   arm2.draw()
45 |   arm2.draw()
46 |   // NOTE(review): arm3 has the same parameter as arm1; if BernoulliArm compares by value, an == based lookup (indexOf) cannot tell them apart.
47 |   val arm3 = BernoulliArm(0.2)
48 |   arm3.draw()
49 |   arm3.draw()
50 |
51 |   // create algorithm instances
52 |   import algorithm._
53 |
54 |   val algo1 = epsilon_greedy.Standard.Algorithm(ε = 0.1d)
55 |   val algo2 = softmax.Standard.Algorithm(τ = 1.0d)
56 |   val algo3 = ucb.UCB1.Algorithm
57 |   val algo4 = exp3.Exp3.Algorithm(γ = 0.2)
58 |   val algos = Seq(algo1, algo2, algo3, algo4)
59 |
60 |   val arms = scala.collection.immutable.Seq(arm1, arm2, arm3)
61 |   val t = 1000
62 |   val sep = ","
63 |
64 | }
65 | /** Imperative demo: runs each algorithm for `t` steps by calling selectArm/updateState directly, then prints its final state. */
66 | object Demo extends DemoBase with App {
67 |   // Evaluates `f` exactly `t` times. The loop index is unused and deliberately not named `t`, which would shadow the bound.
68 |   def t_times(f: => Unit): Unit = for { _ <- 0 until t } { f }
69 |   println()
70 |
71 |   println("Standard Epsilon Greedy algorithm result")
72 |   var algo1State = algo1.initialState(arms)
73 |   t_times {
74 |     val chosen_arm = algo1.selectArm(arms, algo1State)
75 |     val reward = arms(chosen_arm).draw
76 |     algo1State = algo1.updateState(arms, algo1State, chosen_arm, reward)
77 |   }
78 |   println(s" counts: [${algo1State.counts.valuesIterator.mkString(sep)}]")
79 |   println(s"expectations: [${algo1State.expectations.valuesIterator.mkString(sep)}]")
80 |
81 |   println()
82 |
83 |   println("Standard Softmax algorithm result")
84 |   var algo2State = algo2.initialState(arms)
85 |   t_times {
86 |     val chosen_arm = algo2.selectArm(arms, algo2State)
87 |     val reward = arms(chosen_arm).draw
88 |     algo2State = algo2.updateState(arms, algo2State, chosen_arm, reward)
89 |   }
90 |   println(s" counts: [${algo2State.counts.valuesIterator.mkString(sep)}]")
91 |   println(s"expectations: [${algo2State.expectations.valuesIterator.mkString(sep)}]")
92 |
93 |   println()
94 |
95 |   println("UCB1 algorithm result")
96 |   var algo3State = algo3.initialState(arms)
97 |   t_times {
98 |     val chosen_arm = algo3.selectArm(arms, algo3State)
99 |     val reward = arms(chosen_arm).draw
100 |     algo3State = algo3.updateState(arms, algo3State, chosen_arm, reward)
101 |   }
102 |   println(s" counts: [${algo3State.counts.valuesIterator.mkString(sep)}]")
103 |   println(s"expectations: [${algo3State.expectations.valuesIterator.mkString(sep)}]")
104 |
105 |   println()
106 |
107 |   println("Exp3 algorithm result")
108 |   var algo4State = algo4.initialState(arms)
109 |   t_times {
110 |     val chosen_arm = algo4.selectArm(arms, algo4State)
111 |     val reward = arms(chosen_arm).draw
112 |     algo4State = algo4.updateState(arms, algo4State, chosen_arm, reward)
113 |   }
114 |   println(s" counts: [${algo4State.counts.valuesIterator.mkString(sep)}]")
115 |   println(s"weights: [${algo4State.weights.valuesIterator.mkString(sep)}]")
116 |
117 |   println()
118 | }
119 |
120 | /** Same demo, driven through the cats State-monad API of Algorithm (selectArm / updateState as State actions). */
121 | object DemoMonadic extends DemoBase with App {
122 |   import algorithm._
123 |   import arm._
124 |
125 |   // algorithm behavior is also implemented by State Monad (by using cats)
126 |   import cats.data.State
127 |   import State._
128 |
129 |   // simple driver here.
130 |   object SimpleAlgorithmDriver {
131 |     private def simulation[S](algo: Algorithm[Double, S], n: Int) = {
132 |       def trial =
133 |         for {
134 |           chosenArm <- algo.selectArm
135 |           reward = chosenArm.draw()
136 |           _ <- algo.updateState(chosenArm, reward)
137 |         } yield ()
138 |
139 |       // `remaining` trials in sequence; a non-positive count is a no-op (a negative one would never reach the base case).
140 |       def trials(remaining: Int): State[(Seq[Arm[Double]], S), Unit] = remaining match {
141 |         case k if k <= 0 => pure(())
142 |         case k => for {
143 |           _ <- trials(k - 1)
144 |           _ <- trial
145 |         } yield ()
146 |       }
147 |
148 |       trials(n)
149 |     }
150 |
151 |     def run[S](algo: Algorithm[Double, S], nStep: Int) =
152 |       simulation(algo, nStep).runS((arms, algo.initialState(arms))).value
153 |   }
154 |
155 |   println()
156 |
157 |   println("Standard Epsilon Greedy algorithm result")
158 |   val algo1Result = SimpleAlgorithmDriver.run(algo1, t)
159 |   println(s" counts: [${algo1Result._2.counts.valuesIterator.mkString(sep)}]")
160 |   println(s"expectations: [${algo1Result._2.expectations.valuesIterator.mkString(sep)}]")
161 |
162 |   println()
163 |
164 |   println("Standard Softmax algorithm result")
165 |   val algo2Result = SimpleAlgorithmDriver.run(algo2, t)
166 |   println(s" counts: [${algo2Result._2.counts.valuesIterator.mkString(sep)}]")
167 |   println(s"expectations: [${algo2Result._2.expectations.valuesIterator.mkString(sep)}]")
168 |
169 |   println()
170 |
171 |   println("UCB1 algorithm result")
172 |   val algo3Result = SimpleAlgorithmDriver.run(algo3, t)
173 |   println(s" counts: [${algo3Result._2.counts.valuesIterator.mkString(sep)}]")
174 |   println(s"expectations: [${algo3Result._2.expectations.valuesIterator.mkString(sep)}]")
175 |
176 |   println()
177 |
178 |   println("Exp3 algorithm result")
179 |   val algo4Result = SimpleAlgorithmDriver.run(algo4, t)
180 |   println(s" counts: [${algo4Result._2.counts.valuesIterator.mkString(sep)}]")
181 |   println(s"weights: [${algo4Result._2.weights.valuesIterator.mkString(sep)}]")
182 |
183 |   println()
184 | }
185 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/TestAll.scala:
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/TestAll.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Shingo Omura 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a copy of 5 | * this software and associated documentation files (the "Software"), to deal in 6 | * the Software without restriction, including without limitation the rights to 7 | * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 8 | * the Software, and to permit persons to whom the Software is furnished to do so, 9 | * subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in all 12 | * copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 16 | * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 17 | * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 18 | * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 19 | * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
20 | */
21 |
22 | package com.github.everpeace.banditsbook
23 | import com.github.everpeace.banditsbook.algorithm._
24 |
25 | /** Runs every algorithm test suite in order, printing a blank line after each one. */
26 | object TestAll extends App {
27 |   // Each suite is wrapped in a thunk so it is only instantiated when its turn comes.
28 |   private val suites: Seq[() => Unit] = Seq(
29 |     () => new epsilon_greedy._TestStandard {}.run(),
30 |     () => new softmax._TestStandard {}.run(),
31 |     () => new exp3._TestExp3 {}.run(),
32 |     () => new hedge._TestHedge {}.run(),
33 |     () => new ucb._TestUCB1 {}.run()
34 |   )
35 |
36 |   suites.foreach { runSuite =>
37 |     runSuite()
38 |     println()
39 |   }
40 | }
41 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/algorithm/Algorithm.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2016 Shingo Omura
3 | *
4 | * Permission is hereby granted, free of charge, to any person obtaining a copy of
5 | * this software and associated documentation files (the "Software"), to deal in
6 | * the Software without restriction, including without limitation the rights to
7 | * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
8 | * the Software, and to permit persons to whom the Software is furnished to do so,
9 | * subject to the following conditions:
10 | *
11 | * The above copyright notice and this permission notice shall be included in all
12 | * copies or substantial portions of the Software.
13 | *
14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
16 | * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
17 | * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
18 | * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
19 | * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
20 | */
21 |
22 | package com.github.everpeace.banditsbook.algorithm
23 |
24 | import breeze.storage.Zero
25 | import cats.data.{State => CState}
26 | import com.github.everpeace.banditsbook.arm.Arm
27 |
28 | import scala.reflect.ClassTag
29 |
30 | /**
31 |  * A bandit algorithm over arms yielding `Reward`, with algorithm-specific state `AlgorithmState`.
32 |  *
33 |  * Implementations supply the plain functions `initialState`, `selectArm` and `updateState`; this class
34 |  * derives State-monad actions over the pair `(arms, algorithm state)` from them.
35 |  */
36 | abstract class Algorithm[Reward: ClassTag: Zero, AlgorithmState] {
37 |   // CState itself is imported at file level; only its members (inspect, ...) are brought in here.
38 |   import CState._
39 |
40 |   /** Builds the initial algorithm state for `arms`. */
41 |   def initialState(arms: Seq[Arm[Reward]]): AlgorithmState
42 |   /** Returns the index (into `arms`) of the arm to pull next. */
43 |   def selectArm(arms: Seq[Arm[Reward]], state: AlgorithmState): Int
44 |   /** Returns the state after the arm at index `chosen` yielded `reward`. */
45 |   def updateState(arms: Seq[Arm[Reward]], state: AlgorithmState, chosen: Int, reward: Reward): AlgorithmState
46 |
47 |   /** Returns the arm picked by `selectArm(arms, state)`; the `(arms, state)` pair is left unchanged. */
48 |   def selectArm: CState[(Seq[Arm[Reward]], AlgorithmState), Arm[Reward]] =
49 |     inspect {
50 |       case (arms, state) =>
51 |         arms(selectArm(arms, state))
52 |     }
53 |
54 |   /**
55 |    * Applies `updateState` for `chosenArm` / `reward`, stores the resulting algorithm state in the
56 |    * `(arms, state)` pair and also returns it. Storing it (rather than only inspecting it) keeps this
57 |    * action correct for algorithms whose `updateState` returns a new state object instead of mutating.
58 |    */
59 |   def updateState(chosenArm:Arm[Reward], reward: Reward): CState[(Seq[Arm[Reward]], AlgorithmState), AlgorithmState] =
60 |     CState.apply[(Seq[Arm[Reward]], AlgorithmState), AlgorithmState] {
61 |       case (arms, state) =>
62 |         val next = updateState(arms, state, indexOfArm(arms, chosenArm), reward)
63 |         ((arms, next), next)
64 |     }
65 |
66 |   /**
67 |    * Index of `arm` in `arms`. Reference identity is tried first so that arms which are equal by value
68 |    * (e.g. case-class arms built with the same parameters) are not conflated; falls back to `==`, then -1.
69 |    */
70 |   private def indexOfArm(arms: Seq[Arm[Reward]], arm: Arm[Reward]): Int = {
71 |     val byReference = arms.indexWhere(_ eq arm)
72 |     if (byReference >= 0) byReference else arms.indexOf(arm)
73 |   }
74 | }
75 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/algorithm/TracedAlgorithmDriver.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2016 Shingo Omura
3 | *
4 | * Permission is hereby granted, free of charge, to any person obtaining a copy of
5 | * this software and associated documentation files (the "Software"), to deal in
6 | * the Software without restriction, including without limitation the rights to
7 | * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
8 | * the Software, and to permit persons to whom the Software is furnished to do so,
9 | * subject to the following conditions:
10 | *
11 | * The above copyright notice and this permission notice shall be included in all
12 | * copies or substantial portions of the Software.
13 | *
14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
16 | * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
17 | * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
18 | * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
19 | * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
20 | */
21 |
22 | package com.github.everpeace.banditsbook.algorithm
23 |
24 | import breeze.linalg.Vector
25 | import breeze.linalg.Vector._
26 | import breeze.storage.Zero
27 | import cats.data.{State => CState}
28 | import com.github.everpeace.banditsbook.arm._
29 |
30 | import scala.reflect.ClassTag
31 |
32 | object TracedAlgorithmDriver {
33 |
34 |   // Note: breeze.linalg.Vector is mutable.
35 |   final case class Trace[Reward: Zero](chosenArms: Vector[Int], counts: Vector[Int], rewards: Vector[Reward])
36 |
37 |   final case class State[Reward, AlgorithmState](arms: Seq[Arm[Reward]], step: Int, horizon: Int,
38 |                                                  algState: AlgorithmState, trace: Trace[Reward])
39 | }
40 |
41 | /**
42 |  * Runs `algo` step by step and records, per step, the chosen arm index and the reward, plus per-arm pull counts.
43 |  * The trace vectors are updated in place (see the note on breeze.linalg.Vector above).
44 |  */
45 | case class TracedAlgorithmDriver[Reward: Zero: ClassTag, AlgorithmState](algo: Algorithm[Reward, AlgorithmState])(implicit zeroInt: Zero[Int]) {
46 |   import CState._
47 |   import TracedAlgorithmDriver._
48 |
49 |   private val incrementStep = modify[State[Reward, AlgorithmState]] { s => s.copy(step = s.step + 1) }
50 |   private def setAlgState(s: AlgorithmState) = modify[State[Reward, AlgorithmState]] { _.copy(algState = s) }
51 |
52 |   // Records arm `a` and reward `r` at the current step. The arm index is looked up by reference first:
53 |   // `indexOf` alone uses ==, which credits the wrong index when two arms are equal by value.
54 |   private def updateTrace(a: Arm[Reward], r: Reward) = modify[State[Reward, AlgorithmState]] { s =>
55 |     val step = s.step
56 |     val byReference = s.arms.indexWhere(_ eq a)
57 |     val chosen = if (byReference >= 0) byReference else s.arms.indexOf(a)
58 |     val count = s.trace.counts(chosen)
59 |     s.trace.chosenArms.update(step, chosen)
60 |     s.trace.counts.update(chosen, count + 1)
61 |     s.trace.rewards.update(step, r)
62 |     s.copy()
63 |   }
64 |
65 |   // drive 'step' once: choose an arm, draw its reward, update the algorithm state, record the trace, advance the step.
66 |   private val driveStep: CState[State[Reward, AlgorithmState], Unit] = for {
67 |     state <- get[State[Reward, AlgorithmState]]
68 |     chosenArm = algo.selectArm.runA((state.arms, state.algState)).value
69 |     reward = chosenArm.draw()
70 |     newState = algo.updateState(chosenArm, reward).runA((state.arms, state.algState)).value
71 |     _ <- setAlgState(newState)
72 |     _ <- updateTrace(chosenArm, reward)
73 |     _ <- incrementStep
74 |   } yield ()
75 |
76 |   // drive 'step' $n times; a non-positive n is a no-op (a negative n would otherwise never reach the base case).
77 |   private def driveSteps(n: Int): CState[State[Reward, AlgorithmState], Unit] = n match {
78 |     case k if k <= 0 => pure( () ) // nop
79 |     case _ => for {
80 |       _ <- driveSteps(n - 1)
81 |       _ <- driveStep
82 |     } yield ()
83 |   }
84 |
85 |   /**
86 |    * perform 'step' $steps times from initial state and return its final state.
87 |    */
88 |   final def run(arms: Seq[Arm[Reward]], horizon: Int, steps: Int): State[Reward, AlgorithmState] =
89 |     runFrom(
90 |       State(arms, 0, horizon,
91 |         algo.initialState(arms),
92 |         Trace[Reward](zeros(horizon), zeros[Int](arms.size), zeros[Reward](horizon))
93 |       ),
94 |       steps
95 |     )
96 |   final def run(arms: Seq[Arm[Reward]], horizon: Int): State[Reward, AlgorithmState] = run(arms, horizon, horizon)
97 |
98 |   /**
99 |    * resume the algorithm from given state and return its final state.
100 |    * Drives min(steps, horizon - step) steps (run sizes the trace vectors to `horizon`), and never fewer than zero.
101 |    */
102 |   final def runFrom(state: State[Reward, AlgorithmState], steps: Int): State[Reward, AlgorithmState] = {
103 |     val remaining = state.horizon - state.step
104 |     driveSteps(math.max(0, math.min(steps, remaining))).runS(state).value
105 |   }
106 | }
107 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/algorithm/epsilon_greedy/Standard.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2016 Shingo Omura
3 | *
4 | * Permission is hereby granted, free of charge, to any person obtaining a copy of
5 | * this
software and associated documentation files (the "Software"), to deal in 6 | * the Software without restriction, including without limitation the rights to 7 | * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 8 | * the Software, and to permit persons to whom the Software is furnished to do so, 9 | * subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in all 12 | * copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 16 | * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 17 | * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 18 | * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 19 | * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
20 | */
21 |
22 | package com.github.everpeace.banditsbook.algorithm.epsilon_greedy
23 |
24 | import breeze.linalg.argmax
25 | import breeze.stats.distributions.{Bernoulli, Rand, RandBasis}
26 | import breeze.storage.Zero
27 | import com.github.everpeace.banditsbook.algorithm.Algorithm
28 | import com.github.everpeace.banditsbook.arm.Arm
29 |
30 | import scala.collection.immutable.Seq
31 | import scala.reflect.ClassTag
32 |
33 | /**
34 |  * see: http://www.cs.nyu.edu/~mohri/pub/bandit.pdf
35 |  */
36 | object Standard {
37 |
38 |   import breeze.linalg.Vector
39 |   import Vector._
40 |
41 |   /**
42 |    * @param ε            probability of exploring (pulling a uniformly random arm)
43 |    * @param counts       number of pulls per arm (updated in place)
44 |    * @param expectations running mean reward per arm (updated in place)
45 |    */
46 |   case class State(ε: Double, counts: Vector[Int], expectations: Vector[Double])
47 |
48 |   /**
49 |    * Standard ε-greedy: with probability ε pull a uniformly random arm (explore), otherwise pull the arm
50 |    * with the highest estimated expectation (exploit).
51 |    *
52 |    * Every random draw uses the implicit `rand`, so callers control the randomness by supplying a RandBasis.
53 |    */
54 |   def Algorithm(ε: Double)(implicit zeroDouble: Zero[Double], zeroInt: Zero[Int], tag: ClassTag[Double], rand: RandBasis = Rand)
55 |   = new Algorithm[Double, State] {
56 |
57 |     override def initialState(arms: Seq[Arm[Double]]): State =
58 |       State(ε, zeros[Int](arms.size), zeros[Double](arms.size))
59 |
60 |     override def selectArm(arms: Seq[Arm[Double]], state: State): Int = {
61 |       // A Bernoulli(ε) success must mean "explore"; treating it as "exploit" would explore with probability 1 - ε.
62 |       val explore = new Bernoulli(state.ε)(rand).draw()
63 |       if (explore) rand.randInt(state.expectations.size).draw() // Explore: uniformly random arm
64 |       else argmax(state.expectations)                             // Exploit: best estimate so far
65 |     }
66 |
67 |     override def updateState(arms: Seq[Arm[Double]], state: State, chosen: Int, reward: Double): State = {
68 |       val counts = state.counts
69 |       val expectations = state.expectations
70 |
71 |       val count = counts(chosen) + 1
72 |       counts.update(chosen, count)
73 |
74 |       // incremental mean: ((n - 1) / n) * previous mean + (1 / n) * reward
75 |       val expectation = (((count - 1) / count.toDouble) * expectations(chosen)) + ((1 / count.toDouble) * reward)
76 |       expectations.update(chosen, expectation)
77 |       state.copy(counts = counts, expectations = expectations)
78 |     }
79 |   }
80 | }
81 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/algorithm/epsilon_greedy/TestStandard.scala:
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Shingo Omura 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a copy of 5 | * this software and associated documentation files (the "Software"), to deal in 6 | * the Software without restriction, including without limitation the rights to 7 | * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 8 | * the Software, and to permit persons to whom the Software is furnished to do so, 9 | * subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in all 12 | * copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 16 | * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 17 | * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 18 | * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 19 | * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
20 | */
21 |
22 | package com.github.everpeace.banditsbook.algorithm.epsilon_greedy
23 |
24 | import java.io.{File, PrintWriter}
25 |
26 | import breeze.linalg._
27 | import breeze.stats.MeanAndVariance
28 | import com.github.everpeace.banditsbook.arm._
29 | import com.github.everpeace.banditsbook.testing_framework.TestRunner
30 | import com.github.everpeace.banditsbook.testing_framework.TestRunner._
31 | import com.typesafe.config.ConfigFactory
32 |
33 | import scala.collection.immutable.Seq
34 |
35 | object TestStandard extends _TestStandard with App {
36 |   run()
37 | }
38 |
39 | /**
40 |  * Simulates the standard epsilon-greedy algorithm on shuffled Bernoulli arms for every configured ε,
41 |  * prints statistics of the final cumulative rewards and writes every raw step to a CSV file.
42 |  */
43 | trait _TestStandard {
44 |   def run(): Unit = {
45 |     // implicit val randBasis = RandBasis.mt0
46 |
47 |     val conf = ConfigFactory.load()
48 |     val baseKey = "banditsbook.algorithm.epsilon_greedy.test-standard"
49 |     val (_means, Some(εs), horizon, nSims, outDir) = readConfig(conf, baseKey, Some("εs"))
50 |     val means = shuffle(_means)
51 |     val arms = Seq(means:_*).map(μ => BernoulliArm(μ))
52 |
53 |     val outputPath = new File(outDir, "test-standard-epsilon-greedy-results.csv")
54 |     val file = new PrintWriter(outputPath.toString)
55 |     try {
56 |       // the header is written inside `try` so the writer is closed even if this first write fails
57 |       file.write("epsilon, sim_num, step, chosen_arm, reward, cumulative_reward\n")
58 |       println("---------------------------------")
59 |       println("Standard Epsilon Greedy Algorithm")
60 |       println("---------------------------------")
61 |       println(s" arms = ${means.map("(μ="+_+")").mkString(", ")} (Best Arm = ${argmax(means)})")
62 |       println(s"horizon = $horizon")
63 |       println(s" nSims = $nSims")
64 |       println(s" ε = (${εs.mkString(",")})")
65 |       println("")
66 |
67 |       val meanOfFinalRewards = scala.collection.mutable.Map.empty[Double, MeanAndVariance]
68 |       εs.foreach { ε =>
69 |         println(s"starts simulation on ε=$ε.")
70 |
71 |         val algo = Standard.Algorithm(ε)
72 |         val res = TestRunner.run(algo, arms, nSims, horizon)
73 |
74 |         // last cumulative reward of each simulation, assuming cumRewards holds nSims runs of `horizon` steps back to back
75 |         val finalRewards = res.cumRewards((horizon-1) until (nSims * horizon, horizon))
76 |         import breeze.stats._
77 |         val meanAndVar = meanAndVariance(finalRewards)
78 |         meanOfFinalRewards += ε -> meanAndVar
79 |         println(s"reward stats: ${TestRunner.toString(meanAndVar)}")
80 |
81 |         res.rawResults.valuesIterator.foreach{ v =>
82 |           file.write(s"${Seq(ε.toString, v._1.toString, v._2.toString, v._3.toString, v._4.toString, v._5.toString).mkString(",")}\n")
83 |         }
84 |         println(s"finished simulation on ε=$ε.")
85 |       }
86 |       println("")
87 |       println(s"reward stats summary")
88 |       println(s"${meanOfFinalRewards.iterator.toSeq.sortBy(_._1).map(p => (s"ε = ${p._1}", TestRunner.toString(p._2))).mkString("\n")}")
89 |     } finally {
90 |       file.close()
91 |       println("")
92 |       println(s"results are written to ${outputPath}")
93 |     }
94 |   }
95 | }
96 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/algorithm/exp3/Exp3.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2016 Shingo Omura
3 | *
4 | * Permission is hereby granted, free of charge, to any person obtaining a copy of
5 | * this software and associated documentation files (the "Software"), to deal in
6 | * the Software without restriction, including without limitation the rights to
7 | * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
8 | * the Software, and to permit persons to whom the Software is furnished to do so,
9 | * subject to the following conditions:
10 | *
11 | * The above copyright notice and this permission notice shall be included in all
12 | * copies or substantial portions of the Software.
13 | *
14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
16 | * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR
17 | * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
18 | * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
19 | * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
20 | */
21 |
22 | package com.github.everpeace.banditsbook.algorithm.exp3
23 |
24 | import breeze.linalg.Vector._
25 | import breeze.linalg._
26 | import breeze.numerics.exp
27 | import breeze.stats.distributions.{Rand, RandBasis}
28 | import breeze.storage.Zero
29 | import com.github.everpeace.banditsbook.algorithm._
30 | import com.github.everpeace.banditsbook.arm.Arm
31 |
32 | import scala.collection.immutable.Seq
33 | import scala.reflect.ClassTag
34 |
35 | /**
36 |  * see: http://www.cs.nyu.edu/~mohri/pub/bandit.pdf
37 |  */
38 | object Exp3 {
39 |
40 |   /**
41 |    * @param γ       exploration rate in (0, 1]
42 |    * @param weights per-arm exponential weights (updated in place)
43 |    * @param counts  number of pulls per arm (updated in place)
44 |    */
45 |   case class State(γ: Double, weights: Vector[Double], counts: Vector[Int])
46 |
47 |   /**
48 |    * EXP3: arm i is drawn with probability (1 - γ) * w_i / Σw + γ / K; after a pull only the chosen arm's
49 |    * weight is multiplied by exp(γ * (reward / p_chosen) / K).
50 |    *
51 |    * @throws IllegalArgumentException if γ is not in (0, 1]
52 |    */
53 |   def Algorithm(γ: Double)(implicit zeroReward: Zero[Double], zeroInt: Zero[Int], tag: ClassTag[Double], rand: RandBasis = Rand)
54 |   = {
55 |     require(γ > 0 && γ <= 1, "γ must be in (0,1]")
56 |
57 |     new Algorithm[Double, State] {
58 |
59 |       override def initialState(arms: Seq[Arm[Double]]): State =
60 |         State(γ = γ, weights = fill(arms.size)(1.0d), counts = zeros[Int](arms.size))
61 |
62 |       override def selectArm(arms: Seq[Arm[Double]], state: State): Int = {
63 |         val distribution = probs(state.γ, state.weights)
64 |         CategoricalDistribution(distribution).draw()
65 |       }
66 |
67 |       override def updateState(arms: Seq[Arm[Double]], state: State, chosen: Int, reward: Double): State = {
68 |         val (counts, weights) = (state.counts, state.weights)
69 |         counts(chosen) = counts(chosen) + 1
70 |
71 |         // Probabilities are taken before the update. The importance-weighted reward is non-zero only for the
72 |         // chosen arm, so every other weight would be multiplied by exp(0) = 1 and is simply left alone.
73 |         val pChosen = probs(state.γ, weights)(chosen)
74 |         val estimatedReward = reward / pChosen
75 |         val arity = weights.size
76 |         weights(chosen) = weights(chosen) * exp(state.γ * estimatedReward / arity.toDouble)
77 |
78 |         state.copy(weights = weights, counts = counts)
79 |       }
80 |
81 |       // (1 - γ) times the normalised weights, plus a uniform γ / K floor for every arm.
82 |       private def probs(γ: Double, weights: Vector[Double]): Vector[Double] = {
83 |         val arity = weights.size // #arms
84 |         val normalised = weights / sum(weights)
85 |         ((1 - γ) * normalised) + (γ / arity)
86 |       }
87 |     }
88 |   }
89 | }
90 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/algorithm/exp3/TestExp3.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2016 Shingo Omura
3 | *
4 | * Permission is hereby granted, free of charge, to any person obtaining a copy of
5 | * this software and associated documentation files (the "Software"), to deal in
6 | * the Software without restriction, including without limitation the rights to
7 | * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
8 | * the Software, and to permit persons to whom the Software is furnished to do so,
9 | * subject to the following conditions:
10 | *
11 | * The above copyright notice and this permission notice shall be included in all
12 | * copies or substantial portions of the Software.
13 | *
14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
16 | * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
17 | * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
18 | * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
19 | * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
20 | */
21 |
22 | package com.github.everpeace.banditsbook.algorithm.exp3
23 |
24 | import java.io.{File, PrintWriter}
25 |
26 | import breeze.linalg._
27 | import breeze.stats.MeanAndVariance
28 | import com.github.everpeace.banditsbook.arm._
29 | import com.github.everpeace.banditsbook.testing_framework.TestRunner
30 | import com.github.everpeace.banditsbook.testing_framework.TestRunner._
31 | import com.typesafe.config.ConfigFactory
32 |
33 | import scala.collection.immutable.Seq
34 |
35 | object TestExp3 extends _TestExp3 with App {
36 |   run()
37 | }
38 |
39 | /**
40 |  * Simulates EXP3 on shuffled Bernoulli arms for every configured γ, prints statistics of the final
41 |  * cumulative rewards and writes every raw step to a CSV file.
42 |  */
43 | trait _TestExp3{
44 |   def run(): Unit = {
45 |     // implicit val randBasis = RandBasis.mt0
46 |
47 |     val conf = ConfigFactory.load()
48 |     val baseKey = "banditsbook.algorithm.exp3.test-exp3"
49 |     val (_means, Some(γs), horizon, nSims, outDir) = readConfig(conf, baseKey, Some("γs"))
50 |     val means = shuffle(_means)
51 |     val arms = Seq(means:_*).map(μ => BernoulliArm(μ))
52 |
53 |     val outputPath = new File(outDir, "test-exp3-results.csv")
54 |     val file = new PrintWriter(outputPath.toString)
55 |     try {
56 |       // the header is written inside `try` so the writer is closed even if this first write fails
57 |       file.write("gamma, sim_num, step, chosen_arm, reward, cumulative_reward\n")
58 |       println("-------------------------------")
59 |       println("EXP3 Algorithm")
60 |       println("-------------------------------")
61 |       println(s" arms = ${means.map("(μ="+_+")").mkString(", ")} (Best Arm = ${argmax(means)})")
62 |       println(s"horizon = $horizon")
63 |       println(s" nSims = $nSims")
64 |       println(s" γ = (${γs.mkString(",")})")
65 |       println("")
66 |
67 |       val meanOfFinalRewards = scala.collection.mutable.Map.empty[Double, MeanAndVariance]
68 |       γs.foreach { γ =>
69 |         println(s"starts simulation on γ=$γ.")
70 |
71 |         val algo = Exp3.Algorithm(γ)
72 |         val res = TestRunner.run(algo, arms, nSims, horizon)
73 |
74 |         // last cumulative reward of each simulation, assuming cumRewards holds nSims runs of `horizon` steps back to back
75 |         val finalRewards = res.cumRewards((horizon-1) until (nSims * horizon, horizon))
76 |         import breeze.stats._
77 |         val meanAndVar = meanAndVariance(finalRewards)
78 |         meanOfFinalRewards += γ -> meanAndVar
79 |         println(s"reward stats: ${TestRunner.toString(meanAndVar)}")
80 |
81 |         // Stringify each field first: Seq(γ, v._1, ...) would widen integer fields to Double (weak conformance),
82 |         // printing e.g. "0.0" for an index, unlike the epsilon-greedy CSV.
83 |         res.rawResults.valuesIterator.foreach{ v =>
84 |           file.write(s"${Seq(γ.toString, v._1.toString, v._2.toString, v._3.toString, v._4.toString, v._5.toString).mkString(",")}\n")
85 |         }
86 |         println(s"finished simulation on γ=$γ.")
87 |       }
88 |       println("")
89 |       println(s"reward stats summary")
90 |       println(s"${meanOfFinalRewards.iterator.toSeq.sortBy(_._1).map(p => (s"γ=${p._1}", TestRunner.toString(p._2))).mkString("\n")}")
91 |     } finally {
92 |       file.close()
93 |       println("")
94 |       println(s"results are written to ${outputPath}")
95 |     }
96 |   }
97 | }
98 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/algorithm/hedge/Hedge.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2016 Shingo Omura
3 | *
4 | * Permission is hereby granted, free of charge, to any person obtaining a copy of
5 | * this software and associated documentation files (the "Software"), to deal in
6 | * the Software without restriction, including without limitation the rights to
7 | * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
8 | * the Software, and to permit persons to whom the Software is furnished to do so,
9 | * subject to the following conditions:
10 | *
11 | * The above copyright notice and this permission notice shall be included in all
12 | * copies or substantial portions of the Software.
13 | *
14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
16 | * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

package com.github.everpeace.banditsbook.algorithm.hedge

import breeze.linalg.Vector._
import breeze.linalg._
import breeze.numerics.exp
import breeze.stats.distributions.{Rand, RandBasis}
import breeze.storage.Zero
import com.github.everpeace.banditsbook.algorithm._
import com.github.everpeace.banditsbook.arm.Arm

import scala.collection.immutable.Seq
import scala.reflect.ClassTag

/**
 * Hedge: chooses arms with probability proportional to exp(cumulative gain / η).
 *
 * http://www.dklevine.com/archive/refs4462.pdf
 */
object Hedge {

  /**
   * @param η      divisor applied to the gains before exponentiating; larger η
   *               flattens the selection distribution
   * @param counts number of times each arm has been pulled
   * @param gains  cumulative reward collected from each arm
   */
  case class State(η: Double, counts: Vector[Int], gains: Vector[Double])

  /**
   * Builds a Hedge algorithm with parameter `η`.
   *
   * @throws IllegalArgumentException if `η` is not positive
   */
  def Algorithm(η: Double)(implicit zeroReward: Zero[Double], zeroInt: Zero[Int], tag: ClassTag[Double], rand: RandBasis = Rand)
  = {
    require(η > 0, "η must be positive.")
    new Algorithm[Double, State] {

      override def initialState(arms: Seq[Arm[Double]]): State = State(
        η, zeros(arms.size), zeros(arms.size)
      )

      /** Draws an arm index with probability proportional to exp(gains / η). */
      override def selectArm(arms: Seq[Arm[Double]], state: State): Int = {
        val logits = state.gains / state.η
        // Subtracting the largest logit leaves the normalised distribution unchanged
        // but keeps exp() from overflowing to Infinity (Inf / Inf = NaN) once the
        // cumulative gains become large relative to η.
        val weights = exp(logits - max(logits))
        val p = weights / sum(weights)
        CategoricalDistribution(p).draw
      }

      /**
       * Adds `reward` to the chosen arm's cumulative gain and bumps its count.
       * The vectors are updated in place; the returned state shares them.
       */
      override def updateState(arms: Seq[Arm[Double]], state: State, chosen: Int, reward: Double): State = {
        val counts = state.counts
        val gains = state.gains

        val count = counts(chosen) + 1
        counts.update(chosen, count)

        val cumulativeGain = gains(chosen) + reward
        gains.update(chosen, cumulativeGain)

        state.copy(counts = counts, gains = gains)
      }
    }
  }
}
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/algorithm/hedge/TestHedge.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Shingo Omura 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a copy of 5 | * this software and associated documentation files (the "Software"), to deal in 6 | * the Software without restriction, including without limitation the rights to 7 | * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 8 | * the Software, and to permit persons to whom the Software is furnished to do so, 9 | * subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in all 12 | * copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 16 | * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 17 | * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 18 | * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 19 | * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/

package com.github.everpeace.banditsbook.algorithm.hedge

import java.io.{File, PrintWriter}

import breeze.linalg._
import breeze.stats.MeanAndVariance
import com.github.everpeace.banditsbook.arm._
import com.github.everpeace.banditsbook.testing_framework.TestRunner
import com.github.everpeace.banditsbook.testing_framework.TestRunner._
import com.typesafe.config.ConfigFactory

import scala.collection.immutable.Seq

/** Command-line entry point that runs the Hedge simulation defined in [[_TestHedge]]. */
object TestHedge extends _TestHedge with App {
  run()
}

/**
 * Simulates Hedge on Bernoulli arms for every η listed in the configuration.
 *
 * Settings are read from `banditsbook.algorithm.hedge.test-hedge` (`arm-means`,
 * `ηs`, `horizon`, `n-sims`). One CSV row per (η, simulation, step) is written to
 * `test-hedge-results.csv` in the configured output directory, and the mean and
 * variance of the final cumulative reward are printed for each η.
 */
trait _TestHedge {
  def run(): Unit = {
    // implicit val randBasis = RandBasis.mt0

    val conf = ConfigFactory.load()
    val baseKey = "banditsbook.algorithm.hedge.test-hedge"
    val (_means, Some(ηs), horizon, nSims, outDir) = readConfig(conf, baseKey, Some("ηs"))
    // shuffle so the best arm does not sit at a fixed index
    val means = shuffle(_means)
    val arms = Seq(means: _*).map(μ => BernoulliArm(μ))

    // PrintWriter throws FileNotFoundException when the directory does not exist
    outDir.mkdirs()
    val outputPath = new File(outDir, "test-hedge-results.csv")
    val file = new PrintWriter(outputPath.toString)
    try {
      file.write("eta, sim_num, step, chosen_arm, reward, cumulative_reward\n")
      println("-------------------------------")
      println("Hedge Algorithm")
      println("-------------------------------")
      println(s" arms = ${means.map(μ => s"(μ=$μ)").mkString(", ")} (Best Arm = ${argmax(means)})")
      println(s"horizon = $horizon")
      println(s" nSims = $nSims")
      println(s" η = (${ηs.mkString(",")})")
      println("")

      val finalRewardStats: List[(Double, MeanAndVariance)] = ηs.toList.map { η =>
        println(s"starts simulation on η=$η.")

        val algo = Hedge.Algorithm(η)
        val res = TestRunner.run(algo, arms, nSims, horizon)

        // simulation k occupies indices [k * horizon, (k + 1) * horizon), so these
        // are the cumulative rewards at the last step of every simulation
        val finalRewards = res.cumRewards((horizon - 1) until (nSims * horizon, horizon))
        import breeze.stats._
        val meanAndVar = meanAndVariance(finalRewards)
        println(s"reward stats: ${TestRunner.toString(meanAndVar)}")

        res.rawResults.valuesIterator.foreach { v =>
          file.write(s"${Seq(η, v._1, v._2, v._3, v._4, v._5).mkString(",")}\n")
        }
        println(s"finished simulation on η=$η.")
        η -> meanAndVar
      }
      println("")
      println("reward stats summary")
      println(
        finalRewardStats
          .sortBy(_._1)
          .map(p => (s"η=${p._1}", TestRunner.toString(p._2)))
          .mkString("\n")
      )
    } finally {
      file.close()
      println("")
      println(s"results are written to ${outputPath}")
    }
  }
}
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/algorithm/package.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2016 Shingo Omura
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
 * the Software, and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

package com.github.everpeace.banditsbook

import breeze.linalg._
import breeze.stats.distributions.{ApacheDiscreteDistribution, Rand}
import org.apache.commons.math3.distribution.{AbstractIntegerDistribution, EnumeratedIntegerDistribution}

/** Helpers shared by the bandit algorithm implementations. */
package object algorithm {

  /** Immutable sequence alias used throughout the algorithm packages. */
  type Seq[+T] = scala.collection.immutable.Seq[T]

  /**
   * Sampler over the indices `0 until probs.size`, where index `i` is drawn with
   * weight `probs(i)`.
   *
   * Sampling is delegated to commons-math `EnumeratedIntegerDistribution`; see its
   * documentation for how the weights are normalised and validated.
   */
  def CategoricalDistribution(probs: Vector[Double]): Rand[Int] = {
    val outcomes: Array[Int] = Array.range(0, probs.size)
    val weights: Array[Double] = probs.copy.valuesIterator.toArray
    new ApacheDiscreteDistribution {
      override protected val inner: AbstractIntegerDistribution =
        new EnumeratedIntegerDistribution(outcomes, weights)
    }
  }

}
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/algorithm/softmax/Standard.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2016 Shingo Omura
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
 * the Software, and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
*
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

package com.github.everpeace.banditsbook.algorithm.softmax

import breeze.linalg.Vector._
import breeze.linalg._
import breeze.numerics.exp
import breeze.stats.distributions.{Rand, RandBasis}
import breeze.storage.Zero
import com.github.everpeace.banditsbook.algorithm._
import com.github.everpeace.banditsbook.arm.Arm

import scala.collection.immutable.Seq
import scala.reflect.ClassTag

/**
 * Standard softmax (Boltzmann) exploration: arms are chosen with probability
 * proportional to exp(estimated mean reward / τ).
 *
 * see: http://www.cs.nyu.edu/~mohri/pub/bandit.pdf
 */
object Standard {

  /**
   * @param τ            temperature; larger values flatten the selection distribution
   * @param counts       number of times each arm has been pulled
   * @param expectations running mean reward of each arm
   */
  case class State(τ: Double, counts: Vector[Int], expectations: Vector[Double])


  /**
   * Builds a softmax algorithm with temperature `τ`.
   *
   * @throws IllegalArgumentException if `τ` is not positive
   */
  def Algorithm(τ: Double)(implicit zeroReward: Zero[Double], zeroInt: Zero[Int], tag: ClassTag[Double], rand: RandBasis = Rand)
  = {
    require(τ > 0, "τ must be positive.")
    new Algorithm[Double, State] {

      override def initialState(arms: Seq[Arm[Double]]): State = State(
        τ, zeros(arms.size), zeros(arms.size)
      )

      /** Draws an arm index with probability proportional to exp(expectation / τ). */
      override def selectArm(arms: Seq[Arm[Double]], state: State): Int = {
        val logits = state.expectations / state.τ
        // Subtracting the largest logit leaves the normalised distribution unchanged
        // but keeps exp() from overflowing to Infinity (Inf / Inf = NaN) for small τ.
        val weights = exp(logits - max(logits))
        val p = weights / sum(weights)
        CategoricalDistribution(p).draw
      }

      /**
       * Records `reward` for arm `chosen`: bumps its count and updates its running
       * mean. The vectors are updated in place; the returned state shares them.
       */
      override def updateState(arms: Seq[Arm[Double]], state: State, chosen: Int, reward: Double): State = {
        val counts = state.counts
        val expectations = state.expectations

        val count = counts(chosen) + 1
        counts.update(chosen, count)

        // incremental mean: ((n - 1) / n) * previousMean + reward / n
        val expectation = (((count - 1) / count.toDouble) * expectations(chosen)) + ((1 / count.toDouble) * reward)
        expectations.update(chosen, expectation)

        state.copy(counts = counts, expectations = expectations)
      }
    }
  }
}
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/algorithm/softmax/TestStandard.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2016 Shingo Omura
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
 * the Software, and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

package com.github.everpeace.banditsbook.algorithm.softmax

import java.io.{File, PrintWriter}

import breeze.linalg._
import breeze.stats.MeanAndVariance
import com.github.everpeace.banditsbook.arm._
import com.github.everpeace.banditsbook.testing_framework.TestRunner
import com.github.everpeace.banditsbook.testing_framework.TestRunner._
import com.typesafe.config.ConfigFactory

import scala.collection.immutable.Seq

/** Command-line entry point that runs the softmax simulation defined in [[_TestStandard]]. */
object TestStandard extends _TestStandard with App {
  run()
}

/**
 * Simulates standard softmax on Bernoulli arms for every τ listed in the configuration.
 *
 * Settings are read from `banditsbook.algorithm.softmax.test-standard`
 * (`arm-means`, `τs`, `horizon`, `n-sims`). One CSV row per (τ, simulation, step)
 * is written to `test-standard-softmax-results.csv` in the configured output
 * directory, and the mean and variance of the final cumulative reward are printed
 * for each τ.
 */
trait _TestStandard {
  def run(): Unit = {
    // implicit val randBasis = RandBasis.mt0

    val conf = ConfigFactory.load()
    val baseKey = "banditsbook.algorithm.softmax.test-standard"
    val (_means, Some(τs), horizon, nSims, outDir) = readConfig(conf, baseKey, Some("τs"))
    // shuffle so the best arm does not sit at a fixed index
    val means = shuffle(_means)
    val arms = Seq(means: _*).map(μ => BernoulliArm(μ))

    // PrintWriter throws FileNotFoundException when the directory does not exist
    outDir.mkdirs()
    val outputPath = new File(outDir, "test-standard-softmax-results.csv")
    val file = new PrintWriter(outputPath.toString)
    try {
      file.write("tau, sim_num, step, chosen_arm, reward, cumulative_reward\n")
      println("-------------------------------")
      println("Standard Softmax Algorithm")
      println("-------------------------------")
      println(s" arms = ${means.map(μ => s"(μ=$μ)").mkString(", ")} (Best Arm = ${argmax(means)})")
      println(s"horizon = $horizon")
      println(s" nSims = $nSims")
      println(s" τ = (${τs.mkString(",")})")
      println("")

      val finalRewardStats: List[(Double, MeanAndVariance)] = τs.toList.map { τ =>
        println(s"starts simulation on τ=$τ.")

        val algo = Standard.Algorithm(τ)
        val res = TestRunner.run(algo, arms, nSims, horizon)

        // simulation k occupies indices [k * horizon, (k + 1) * horizon), so these
        // are the cumulative rewards at the last step of every simulation
        val finalRewards = res.cumRewards((horizon - 1) until (nSims * horizon, horizon))
        import breeze.stats._
        val meanAndVar = meanAndVariance(finalRewards)
        println(s"reward stats: ${TestRunner.toString(meanAndVar)}")

        res.rawResults.valuesIterator.foreach { v =>
          file.write(s"${Seq(τ, v._1, v._2, v._3, v._4, v._5).mkString(",")}\n")
        }
        println(s"finished simulation on τ=$τ.")
        τ -> meanAndVar
      }
      println("")
      println("reward stats summary")
      println(
        finalRewardStats
          .sortBy(_._1)
          .map(p => (s"τ=${p._1}", TestRunner.toString(p._2)))
          .mkString("\n")
      )
    } finally {
      file.close()
      println("")
      println(s"results are written to ${outputPath}")
    }
  }
}
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/algorithm/ucb/TestUCB1.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2016 Shingo Omura
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
 * the Software, and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

package com.github.everpeace.banditsbook.algorithm.ucb

import java.io.{File, PrintWriter}

import breeze.linalg._
import com.github.everpeace.banditsbook.arm._
import com.github.everpeace.banditsbook.testing_framework.TestRunner
import com.github.everpeace.banditsbook.testing_framework.TestRunner._
import com.typesafe.config.ConfigFactory

import scala.collection.immutable.Seq

/** Command-line entry point that runs the UCB1 simulation defined in [[_TestUCB1]]. */
object TestUCB1 extends _TestUCB1 with App {
  run()
}

/**
 * Simulates UCB1 on Bernoulli arms.
 *
 * Settings are read from `banditsbook.algorithm.ucb.test-ucb1` (`arm-means`,
 * `horizon`, `n-sims`). One CSV row per (simulation, step) is written to
 * `test-ucb1-results.csv` in the configured output directory, and the mean and
 * variance of the final cumulative reward are printed.
 */
trait _TestUCB1 {
  def run(): Unit = {
    // implicit val randBasis = RandBasis.mt0

    val conf = ConfigFactory.load()
    val baseKey = "banditsbook.algorithm.ucb.test-ucb1"
    val (_means, _, horizon, nSims, outDir) = readConfig(conf, baseKey)
    // shuffle so the best arm does not sit at a fixed index
    val means = shuffle(_means)
    val arms = Seq(means: _*).map(μ => BernoulliArm(μ))

    // PrintWriter throws FileNotFoundException when the directory does not exist
    outDir.mkdirs()
    val outputPath = new File(outDir, "test-ucb1-results.csv")
    val file = new PrintWriter(outputPath.toString)
    try {
      file.write("sim_num, step, chosen_arm, reward, cumulative_reward\n")
      println("-------------------------------")
      println("UCB1 Algorithm")
      println("-------------------------------")
      println(s" arms = ${means.map(μ => s"(μ=$μ)").mkString(", ")} (Best Arm = ${argmax(means)})")
      println(s"horizon = $horizon")
      println(s" nSims = $nSims")
      println("The algorithm has no hyper parameters.")
      println("")

      println("starts simulation.")

      val algo = UCB1.Algorithm
      val res = TestRunner.run(algo, arms, nSims, horizon)

      // simulation k occupies indices [k * horizon, (k + 1) * horizon), so these
      // are the cumulative rewards at the last step of every simulation
      val finalRewards = res.cumRewards((horizon - 1) until (nSims * horizon, horizon))
      import breeze.stats._
      val meanAndVar = meanAndVariance(finalRewards)
      println(s"reward stats: ${TestRunner.toString(meanAndVar)}")

      res.rawResults.valuesIterator.foreach { v =>
        file.write(s"${Seq(v._1, v._2, v._3, v._4, v._5).mkString(",")}\n")
      }
      println("finished simulation.")
    } finally {
      file.close()
      println("")
      println(s"results are written to ${outputPath}")
    }
  }
}
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/algorithm/ucb/UCB1.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2016 Shingo Omura
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
 * the Software, and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

package com.github.everpeace.banditsbook.algorithm.ucb

import breeze.numerics.sqrt
import breeze.stats.distributions.{Rand, RandBasis}
import breeze.storage.Zero
import com.github.everpeace.banditsbook.algorithm.Algorithm
import com.github.everpeace.banditsbook.arm.Arm

import scala.collection.immutable.Seq
import scala.reflect.ClassTag

/**
 * UCB1: pick the arm maximising mean reward + sqrt(2 ln n / n_i), where n is the
 * total number of pulls and n_i the pulls of arm i.
 *
 * http://www.cs.mcgill.ca/~vkules/bandits.pdf
 */
object UCB1 {

  import breeze.linalg._
  import Vector._

  /**
   * @param counts       number of times each arm has been pulled
   * @param expectations running mean reward of each arm
   */
  case class State(counts: Vector[Int], expectations: Vector[Double])

  def Algorithm(implicit zeroReward: Zero[Double], zeroInt: Zero[Int], tag: ClassTag[Double], rand: RandBasis = Rand)
  = {
    new Algorithm[Double, State] {

      override def initialState(arms: Seq[Arm[Double]]): State = State(
        zeros(arms.size), zeros(arms.size)
      )

      /** Returns the first never-pulled arm if any, otherwise the arm with the highest UCB score. */
      override def selectArm(arms: Seq[Arm[Double]], state: State): Int = {
        val counts = state.counts
        // The confidence bound is undefined until an arm has been pulled: with
        // n_i = 0 the bonus divides by zero, and on the very first step ln(0) makes
        // every score NaN. Standard UCB1 therefore tries every arm once first.
        val untried = (0 until counts.size).find(i => counts(i) == 0)
        untried.getOrElse {
          val step = sum(counts)
          val factor = fill(counts.size)(2 * scala.math.log(step))
          val bonus = sqrt(factor / counts.map(_.toDouble))
          val score = state.expectations + bonus
          argmax(score)
        }
      }

      /**
       * Records `reward` for arm `chosen`: bumps its count and updates its running
       * mean. The vectors are updated in place; the returned state shares them.
       */
      override def updateState(arms: Seq[Arm[Double]], state: State, chosen: Int, reward: Double): State = {
        val counts = state.counts
        val expectations = state.expectations
        val count = counts(chosen) + 1
        counts.update(chosen, count)

        // incremental mean: ((n - 1) / n) * previousMean + reward / n
        val expectation = (((count - 1) / count.toDouble) * expectations(chosen)) + ((1 / count.toDouble) * reward)
        expectations.update(chosen, expectation)

        state.copy(counts = counts, expectations = expectations)
      }
    }
  }
}
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/arm/Arms.scala:
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Shingo Omura 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a copy of 5 | * this software and associated documentation files (the "Software"), to deal in 6 | * the Software without restriction, including without limitation the rights to 7 | * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 8 | * the Software, and to permit persons to whom the Software is furnished to do so, 9 | * subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in all 12 | * copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 16 | * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 17 | * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 18 | * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 19 | * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/

package com.github.everpeace.banditsbook.arm

import breeze.stats.distributions.{Bernoulli, Gaussian}

/** Constructors for the reward distributions ("arms") used in the simulations. */
trait Arms {

  /** Arm that pays 1.0 with probability `p` and 0.0 otherwise. */
  def BernoulliArm(p: Double): Arm[Double] = Bernoulli.distribution(p).map {
    case true => 1.0d
    case false => 0.0d
  }

  /**
   * Arm with normally distributed rewards of mean `μ` and standard deviation `σ`.
   *
   * Uses the `Gaussian(mu, sigma)` constructor directly: `Gaussian.distribution`
   * takes its exponential-family parameter (mean, variance), so passing `σ` there
   * would have treated it as a variance.
   */
  def NormalArm(μ: Double, σ: Double): Arm[Double] = Gaussian(μ, σ)

  /**
   * Deterministic arm driven by an internal step counter `t`, which starts at
   * `start` and is incremented before each draw. It pays 1.0 while
   * `activeStart <= t <= activeEnd` and 0.0 otherwise.
   */
  def AdversarialArm(start: Int, activeStart: Int, activeEnd: Int) = new Arm[Double] {
    @volatile var t = start
    override def draw(): Double = {
      t += 1
      if (activeStart <= t && t <= activeEnd) 1.0d else 0.0d
    }
  }
}
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/arm/package.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2016 Shingo Omura
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
 * the Software, and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

package com.github.everpeace.banditsbook

import breeze.stats.distributions.Rand

/**
 * Arm types and constructors. The constructors (`BernoulliArm`, `NormalArm`,
 * `AdversarialArm`) are mixed in from [[Arms]].
 */
package object arm extends Arms {
  /** An arm is any breeze random variable producing rewards of type `R`. */
  type Arm[+R] = Rand[R]
}
--------------------------------------------------------------------------------
/src/main/scala/com/github/everpeace/banditsbook/testing_framework/TestRunner.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2016 Shingo Omura
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
 * the Software, and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

package com.github.everpeace.banditsbook.testing_framework

import java.io.File

import breeze.linalg.Vector._
import breeze.linalg._
import breeze.stats.MeanAndVariance
import breeze.storage.Zero
import com.github.everpeace.banditsbook.algorithm.{Algorithm, TracedAlgorithmDriver}
import com.github.everpeace.banditsbook.arm._
import com.typesafe.config.Config

import scala.collection.immutable.Seq
import scala.reflect.ClassTag

/** Shared plumbing for the `Test*` simulation scripts: config loading and batch simulation. */
object TestRunner {
  /** Formats mean and variance as `(μ = <mean>, σ^2 = <variance>)`. */
  def toString(mv: MeanAndVariance) = s"(μ = ${mv.mean}, σ^2 = ${mv.variance})"

  /**
   * Reads one simulation's settings from `conf`.
   *
   * @param baseKey               config path holding `arm-means`, `horizon` and `n-sims`
   * @param hyper_parameters_name key under `baseKey` of a list of doubles to read, if any
   * @return `(arm means, hyper-parameters, horizon, number of simulations, output dir)`;
   *         the output dir always comes from `banditsbook.algorithm.test-common.out-dir`
   */
  def readConfig(conf: Config, baseKey: String, hyper_parameters_name: Option[String] = None) = {
    import scala.collection.convert.decorateAsScala._
    val means = Array(
      conf.getDoubleList(s"$baseKey.arm-means").asScala.map(_.toDouble): _*
    )
    val hyper_parameters = hyper_parameters_name.map(name =>
      conf.getDoubleList(s"$baseKey.$name").asScala.map(_.toDouble).toSeq
    )
    val horizon = conf.getInt(s"$baseKey.horizon")
    val nSims = conf.getInt(s"$baseKey.n-sims")
    val outDir = new File(conf.getString("banditsbook.algorithm.test-common.out-dir"))

    (means, hyper_parameters, horizon, nSims, outDir)
  }

  /**
   * Flattened results of `nSims` simulations of `horizon` steps each. Every vector
   * has length `numSims * horizon`; simulation k occupies indices
   * `[k * horizon, (k + 1) * horizon)`.
   */
  case class TestRunnerResult(
    arms: Seq[Arm[Double]],
    numSims: Int,
    horizon: Int,
    simNums: Vector[Int],
    stepNums: Vector[Int],
    chosenArms: Vector[Int],
    rewards: Vector[Double],
    cumRewards: Vector[Double],
    // raw data format:
    // simNum, step, chosenArm, rewards, cumRewards
    rawResults: Vector[(Int, Int, Int, Double, Double)]
  )

  /**
   * Runs `nSims` independent simulations of `alg` against `arms`, each lasting
   * `horizon` steps, and collects the traces. A fresh driver is created per
   * simulation. Progress is printed as one dot per 10 simulations.
   */
  def run[AlgState](alg: Algorithm[Double, AlgState], arms: Seq[Arm[Double]], nSims: Int, horizon: Int)
                   (implicit zeroDouble: Zero[Double], zero: Zero[Int], classTag: ClassTag[Double])
  = {
    val total = nSims * horizon
    val simNums = zeros[Int](total)
    val stepNums = zeros[Int](total)
    val chosenArms = zeros[Int](total)
    val rewards = zeros[Double](total)
    val cumRewards = zeros[Double](total)
    val rawResults = fill(total)((0, 0, 0, 0.0d, 0.0d))

    for { sim <- 0 until nSims } {
      // simulation `sim` fills the inclusive index range [st, end]
      val st = horizon * sim
      val end = (horizon * (sim + 1)) - 1

      val driver = TracedAlgorithmDriver(alg)
      val res = driver.run(arms, horizon)

      // running sum of the rewards; scanLeft is O(horizon), whereas appending to an
      // Array on every step copied it each time (O(horizon^2)). Summation order is
      // the same, so the values are identical.
      val cums = Vector(
        res.trace.rewards.valuesIterator.scanLeft(0d)(_ + _).drop(1).toArray
      )
      simNums(st to end) := fill(horizon)(sim)
      stepNums(st to end) := tabulate(horizon)(identity)
      chosenArms(st to end) := res.trace.chosenArms
      rewards(st to end) := res.trace.rewards
      cumRewards(st to end) := cums
      rawResults(st to end) := tabulate(horizon)(i =>
        (sim, i, res.trace.chosenArms(i), res.trace.rewards(i), cums(i))
      )
      print((if (sim % 10 == 9) "." else "") + (if (sim % 1000 == 999 || sim == nSims - 1) s"[${sim + 1}]\n" else ""))
    }

    TestRunnerResult(
      arms, nSims, horizon, simNums, stepNums,
      chosenArms, rewards, cumRewards,
      rawResults
    )
  }

}
--------------------------------------------------------------------------------