├── .gitignore
├── LICENSE.txt
├── MANIFEST.in
├── README.md
├── dev-requirements.txt
├── docs
│   └── slots-docs.md
├── misc
│   └── regret_plot.png
├── mypy.ini
├── pyproject.toml
├── requirements.txt
├── setup.cfg
├── setup.py
├── slots
│   ├── __init__.py
│   └── slots.py
└── tests
    └── tests.py

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
*~
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
Copyright (C) 2014 Roy Keyes

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
# Include the readme and license files
include README.md
include LICENSE.txt

# Include the data files
recursive-include misc *
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# slots

## *A multi-armed bandit library for Python*

Slots is intended to be a basic, very easy-to-use multi-armed bandit library for Python.

[![PyPI](https://img.shields.io/pypi/v/slots)](https://pypi.org/project/slots/)
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/slots)](https://pypi.org/project/slots/)
[![Downloads](https://pepy.tech/badge/slots)](https://pepy.tech/project/slots)

[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![type hints with mypy](https://img.shields.io/badge/type%20hints-mypy-brightgreen)](http://mypy-lang.org/)

### Author

[Roy Keyes](https://roycoding.github.io) -- roy.coding@gmail

### License: MIT

See [LICENSE.txt](https://github.com/roycoding/slots/blob/master/LICENSE.txt)

### Introduction

slots is a Python library designed to allow the user to explore and use simple multi-armed bandit (MAB) strategies. The basic concept behind the multi-armed bandit problem is that you are faced with *n* choices (e.g. slot machines, medicines, or UI/UX designs), each of which results in a "win" with some unknown probability.
Multi-armed bandit strategies are designed to let you quickly determine which choice will yield the highest result over time, while reducing the number of tests (or arm pulls) needed to make this determination. Typically, MAB strategies attempt to strike a balance between "exploration", testing different arms in order to find the best, and "exploitation", using the best known choice. There are many variations of this problem; see [here](https://en.wikipedia.org/wiki/Multi-armed_bandit) for more background.

slots provides a hopefully simple API to allow you to explore, test, and use these strategies. Basic usage looks like this:

Using slots to determine the best of 3 variations on a live website.

```Python
import slots

mab = slots.MAB(3, live=True)
```

Make the first choice randomly, record the response, and input the reward (here, bandit 2 was chosen). Run online trials, inputting the most recent result each time, until the test criterion is met.

```Python
mab.online_trial(bandit=2, payout=1)
```

The response of `mab.online_trial()` is a dict of the form:

```Python
{'new_trial': boolean, 'choice': int, 'best': int}
```

Where:

- If the stopping criterion is met, `new_trial` = `False`.
- `choice` is the current choice of arm to try.
- `best` is the current best estimate of the highest payout arm.

To test strategies on arms with pre-set probabilities:

```Python
# Try 3 bandits with arbitrary win probabilities
b = slots.MAB(3, live=False)
b.run()
```

To inspect the results and compare the estimated win probabilities versus the true win probabilities:

```Python
# Current best guess
b.best()
> 0

# Estimate of the payout probabilities
b.est_probs()
> array([ 0.83888149,  0.78534031,  0.32786885])

# Ground truth payout probabilities (if known)
b.bandits.probs
> [0.8020877268854065, 0.7185844454955193, 0.16348877912363646]
```

By default, slots uses the epsilon-greedy strategy. Besides epsilon-greedy, the softmax, upper confidence bound (UCB1), and Bayesian bandit strategies are also implemented.

#### Regret analysis

A common metric used to evaluate the relative success of a MAB strategy is "regret". This reflects the fraction of payouts (wins) that have been lost by using the actual sequence of arm pulls instead of always pulling the (currently) best known arm. The current regret value can be calculated by calling the `mab.regret()` method.
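To make that definition concrete, here is a small worked example with made-up counts, applying the same per-trial-normalized formula that `mab.regret()` uses internally (see `slots/slots.py`):

```Python
import numpy as np

# Hypothetical counts after 100 trials on 2 arms
pulls = np.array([10, 90])
wins = np.array([3, 72])

# Same formula as mab.regret(): fraction of potential wins lost
# versus always pulling the best-performing arm
regret = (pulls.sum() * np.max(wins / pulls) - wins.sum()) / pulls.sum()
# (100 * 0.8 - 75) / 100 = 0.05
```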
For example, the regret curves for several different MAB strategies can be generated as follows:

```Python
import matplotlib.pyplot as plt
import slots

# Test multiple strategies for the same bandit probabilities
probs = [0.4, 0.9, 0.8]

strategies = [{'strategy': 'eps_greedy', 'regret': [],
               'label': r'$\epsilon$-greedy ($\epsilon$=0.1)'},
              {'strategy': 'softmax', 'regret': [],
               'label': 'Softmax ($T$=0.1)'},
              {'strategy': 'ucb', 'regret': [],
               'label': 'UCB1'},
              {'strategy': 'bayesian', 'regret': [],
               'label': 'Bayesian bandit'},
              ]

for s in strategies:
    s['mab'] = slots.MAB(probs=probs, live=False)

# Run trials and calculate the regret after each trial
for t in range(10000):
    for s in strategies:
        s['mab']._run(s['strategy'])
        s['regret'].append(s['mab'].regret())

# Pretty plotting
plt.style.use(['seaborn-poster', 'seaborn-whitegrid'])

plt.figure(figsize=(15, 4))

for s in strategies:
    plt.plot(s['regret'], label=s['label'])

plt.legend()
plt.xlabel('Trials')
plt.ylabel('Regret')
plt.title('Multi-armed bandit strategy performance (slots)')
plt.ylim(0, 0.2)
```

![Regret plot](./misc/regret_plot.png)

### API documentation

For documentation on the slots API, see [slots-docs.md](https://github.com/roycoding/slots/blob/master/docs/slots-docs.md).

### Todo list:

- More MAB strategies
- Argument to save regret values after each trial in an array.
- TESTS!

### Contributing

I welcome contributions, though the pace of development is highly variable. Please file issues and submit pull requests as makes sense.

The current development environment uses:

- pytest >= 5.3 (5.3.2)
- black >= 19.1 (19.10b0)
- mypy = 0.761

You can pip install these easily with `pip install -r dev-requirements.txt`.

For mypy config, see `mypy.ini`. For black config, see `pyproject.toml`.
--------------------------------------------------------------------------------
/dev-requirements.txt:
--------------------------------------------------------------------------------
mypy>=0.761
black>=19.10b0
pytest>=5.3.2
--------------------------------------------------------------------------------
/docs/slots-docs.md:
--------------------------------------------------------------------------------
# slots
## Multi-armed bandit library in Python

## Documentation
This document details the current and planned API for slots. Non-implemented features are noted as such.

### What does the library need to do? An aspirational list.
1. Set up N bandits with probabilities, p_i, and payouts, pay_i.
2. Implement several MAB strategies, with kwargs as parameters, and a consistent API.
3. Allow for T trials.
4. Continue with more trials (i.e. save state after trials).
5. Values to save (see the usage sketch after this list):
    1. Current choice
    2. Number of trials completed for each arm
    3. Scores for each arm
    4. Average payout per arm (wins/trials?)
    5. Current regret. Regret = T * max_k(mean_k) - sum_{t=1}^{T}(reward_t)
        - See [ref](http://research.microsoft.com/en-us/um/people/sebubeck/SurveyBCB12.pdf)
6. Use sane defaults.
7. Be obvious and clean.
8. For the time being, handle only binary payouts.
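Most of these saved values are already exposed as plain attributes and methods on the `MAB` object. A short usage sketch (names as defined in `slots/slots.py`):

```Python
import slots

mab = slots.MAB(probs=[0.2, 0.6, 0.4])
mab.run(trials=1000)

mab.choices   # sequence of arm choices so far
mab.pulls     # number of pulls per arm
mab.wins      # number of wins per arm
mab.regret()  # current regret, normalized per trial
```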
### Library API ideas:
#### Running slots with a live website
```Python
# Using slots to determine the best of 3 variations on a live website.
# 3 is the default number of bandits and epsilon-greedy is the default
# strategy.
mab = slots.MAB(3, live=True)

# Make the first choice randomly, record the response, and input the reward;
# here, bandit 2 was chosen.
# Run online trials (input most recent result) until the test criterion
# is met.
mab.online_trial(bandit=2, payout=1)

# Response of mab.online_trial() is a dict of the form:
{'new_trial': boolean, 'choice': int, 'best': int}

# Where:
# If the stopping criterion is met, new_trial = False.
# choice is the current choice of arm to try next.
# best is the current best estimate of the highest payout arm.
```

#### Creating a MAB test instance:

```Python
# Default: 3 bandits with random probabilities, p_i.
mab = slots.MAB()

# Set up 4 bandits with random p_i.
mab = slots.MAB(4)

# 4 bandits with specified p_i
mab = slots.MAB(probs=[0.2, 0.1, 0.4, 0.1])

# Creating 3 bandits with historical payout data
# (one array per bandit, as in slots/slots.py)
mab = slots.MAB(3, hist_payouts=[np.array([0, 0, 1, ...]),
                                 np.array([1, 0, 0, ...]),
                                 np.array([0, 0, 0, ...])])
```

#### Running tests with strategy, S

```Python
# Default: Epsilon-greedy, epsilon = 0.1, num_trials = 100
mab.run()

# Run chosen strategy with specified parameters and number of trials
mab.run(strategy='eps_greedy', parameters={'epsilon': 0.2}, trials=10000)

# Run strategy, updating old trial data
# (NOT YET IMPLEMENTED; note that `continue` is a reserved word in Python,
# so the eventual keyword argument will need a different name)
mab.run(continue=True)
```

#### Displaying / retrieving bandit properties

```Python
# Default: display number of bandits, probabilities and payouts
# (NOT YET IMPLEMENTED)
mab.bandits.info()

# Display info for bandit i
# (NOT YET IMPLEMENTED)
mab.bandits[i]

# Retrieve bandits' payouts, probabilities, etc.
mab.bandits.payouts
mab.bandits.probs

# Retrieve count of bandits
# (NOT YET IMPLEMENTED)
mab.bandits.count
```

#### Setting bandit properties

```Python
# Reset bandits to defaults
# (NOT YET IMPLEMENTED)
mab.bandits.reset()

# Set probabilities or payouts
# (NOT YET IMPLEMENTED)
mab.bandits.set_probs([0.1, 0.05, 0.2, 0.15])
mab.bandits.set_hist_payouts([[1, 1, 0, 0], [0, 1, 0, 0]])
```

#### Displaying / retrieving test info

```Python
# Retrieve current "best" bandit
mab.best()

# Retrieve bandit probability estimates
# (NOT YET IMPLEMENTED)
mab.prob_est()

# Retrieve bandit probability estimate of bandit i
# (NOT YET IMPLEMENTED)
mab.est_prob(i)

# Retrieve bandit probability estimates
mab.est_probs()

# Retrieve current bandit choice
# (NOT YET IMPLEMENTED, use mab.choices[-1])
mab.current()

# Retrieve sequence of choices
mab.choices

# Retrieve probability estimate history
# (NOT YET IMPLEMENTED)
mab.prob_est_sequence

# Retrieve test strategy info (current strategy) -- a dict
# (NOT YET IMPLEMENTED)
mab.strategy_info()
```

### Proposed MAB strategies
- [x] Epsilon-greedy
- [ ] Epsilon decreasing
- [x] Softmax
- [ ] Softmax decreasing
- [x] Upper confidence bound (UCB1)
- [x] Bayesian bandits
--------------------------------------------------------------------------------
/misc/regret_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/roycoding/slots/38f76812a296ca024ebf9ad3bc60448f8d12207d/misc/regret_plot.png
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
[mypy]
disallow_untyped_calls = True
disallow_untyped_defs = True

[mypy-numpy]
ignore_missing_imports = True
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[tool.black]
line-length = 79
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[bdist_wheel]
# The "universal" flag (Python 2 + 3 wheels) is intentionally not set:
# this package supports Python 3 only (see the classifiers in setup.py).
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
"""slots: a multi-armed bandit library for Python
See:
https://github.com/roycoding/slots
"""

from setuptools import setup, find_packages
from codecs import open
from os import path

here = path.abspath(path.dirname(__file__))

# Get the long description from the README file
with open(path.join(here, 'README.md'), encoding='utf-8') as f:
    long_description = f.read()

setup(
    name='slots',

    version='0.4.0',

    description='A multi-armed bandit library for Python',
    long_description=long_description,
    long_description_content_type="text/markdown",

    # The project's main homepage.
    url='https://github.com/roycoding/slots',

    # Author details
    author='Roy Keyes',
    author_email='roy.coding@gmail.com',

    # Choose your license
    license='MIT',

    # See https://pypi.python.org/pypi?%3Aaction=list_classifiers
    classifiers=[
        # How mature is this project? Common values are
        #   3 - Alpha
        #   4 - Beta
        #   5 - Production/Stable
        'Development Status :: 3 - Alpha',

        # Indicate who your project is intended for
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'Topic :: Scientific/Engineering',

        # Pick your license as you wish (should match "license" above)
        'License :: OSI Approved :: MIT License',

        # Specify the Python versions you support here. In particular, ensure
        # that you indicate whether you support Python 2, Python 3 or both.
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
    ],

    # What does your project relate to?
    keywords='multi-armed bandit hypothesis testing',

    # You can just specify the packages manually here if your project is
    # simple. Or you can use find_packages().
    packages=find_packages(exclude=['contrib', 'docs', 'tests']),

    # Alternatively, if you want to distribute just a my_module.py, uncomment
    # this:
    # py_modules=["my_module"],

    # List run-time dependencies here. These will be installed by pip when
    # your project is installed. For an analysis of "install_requires" vs pip's
    # requirements files see:
    # https://packaging.python.org/en/latest/requirements.html
    install_requires=['numpy'],
)
--------------------------------------------------------------------------------
/slots/__init__.py:
--------------------------------------------------------------------------------
from .slots import MAB
--------------------------------------------------------------------------------
/slots/slots.py:
--------------------------------------------------------------------------------
"""
slots

A Python library to perform simple multi-armed bandit analyses.

Scenarios:
- Run MAB test on simulated data (N bandits), default epsilon-greedy test.
    mab = slots.MAB(probs=[0.1, 0.15, 0.05])
    mab.run(trials=10000)
    mab.best()  # Bandit with highest probability after T trials

- Run MAB test on "real" payout data (probabilities unknown).
    mab = slots.MAB(hist_payouts=[[0, 0, ...], [1, 0, ...], [0, 1, ...]])
    mab.run(trials=10000)

- Run MAB test on "live" data
    mab = slots.MAB(num_bandits=3, live=True)
    mab.online_trial(bandit=1, payout=0)
"""

from typing import Optional, List, Dict, Any, Union, Callable

import numpy as np


class MAB(object):
    """
    Multi-armed bandit test class.
    """

    def __init__(
        self,
        num_bandits: Optional[int] = 3,
        probs: Optional[np.ndarray] = None,
        hist_payouts: Optional[List[np.ndarray]] = None,
        live: Optional[bool] = False,
        stop_criterion: Optional[Dict] = {"criterion": "regret", "value": 0.1},
    ) -> None:
        """
        Parameters
        ----------
        num_bandits : int, optional
            default is 3
        probs : array of floats, optional
            payout probabilities
        hist_payouts : list of arrays of ints, one array per bandit, optional
            This is for testing on historical data.
            If you set `probs` or `live` is True, `hist_payouts` should be
            None.
        live : bool, optional
            Whether the use is for a live, online trial.
        stop_criterion : dict, optional
            Stopping criterion (str) and threshold value (float).
        """

        self.choices: List[int] = []

        if not probs:
            if not hist_payouts:
                if live:
                    # Live trial scenario, where nothing is known except the
                    # number of bandits
                    self.bandits: Bandits = Bandits(
                        live=True, payouts=np.zeros(num_bandits)
                    )
                else:
                    # A pure experiment scenario with random probabilities.
                    self.bandits = Bandits(
                        probs=np.random.rand(num_bandits),
                        payouts=np.zeros(num_bandits),
                        live=False,
                    )
            else:
                # Run strategies on a known historical sequence of payouts.
                # Probabilities are not known.
                num_bandits = len(hist_payouts)
                if live:
                    print(
                        "slots: Cannot have a defined array of payouts"
                        " and live=True. live set to False"
                    )
                self.bandits = Bandits(
                    hist_payouts=hist_payouts,
                    payouts=np.zeros(num_bandits),
                    live=False,
                )
        else:
            if hist_payouts:
                # A pure experiment scenario with known historical payout
                # values. Probabilities will be ignored.
                num_bandits = len(probs)
                print(
                    "slots: Since historical payout data has been supplied,"
                    " probabilities will be ignored."
                )
                if len(probs) == len(hist_payouts):
                    self.bandits = Bandits(
                        hist_payouts=hist_payouts,
                        live=False,
                        payouts=np.zeros(num_bandits),
                    )
                else:
                    raise Exception(
                        "slots: Dimensions of probs and hist_payouts"
                        " mismatched."
                    )
            else:
                # A pure experiment scenario with known probabilities
                num_bandits = len(probs)
                self.bandits = Bandits(
                    probs=probs, payouts=np.zeros(num_bandits), live=False
                )

        self.wins: np.ndarray = np.zeros(num_bandits)
        self.pulls: np.ndarray = np.zeros(num_bandits)

        # Set the stopping criteria
        self.criteria: Dict[str, Callable[[Optional[float]], bool]] = {
            "regret": self.regret_met
        }
        if not stop_criterion:
            self.criterion: str = "regret"
            self.stop_value: float = 0.1
        else:
            self.criterion = stop_criterion.get("criterion", "regret")
            self.stop_value = stop_criterion.get("value", 0.1)

        # Bandit selection strategies
        self.strategies: List[str] = [
            "eps_greedy",
            "softmax",
            "ucb",
            "bayesian",
        ]

    def run(
        self,
        trials: int = 100,
        strategy: str = "eps_greedy",
        parameters: Optional[Dict] = None,
    ) -> None:
        """
        Run MAB test with T trials.

        Parameters
        ----------
        trials : int
            Number of trials to run.
        strategy : str
            Name of selected strategy. "eps_greedy" is default.
        parameters : dict
            Parameters for selected strategy.

        Available strategies:
            - Epsilon-greedy ("eps_greedy")
            - Softmax ("softmax")
            - Upper confidence bound ("ucb")
            - Bayesian bandits ("bayesian")

        Returns
        -------
        None
        """

        if trials < 1:
            raise Exception(
                "slots.MAB.run: Number of trials cannot be less than 1!"
            )

        if strategy not in self.strategies:
            raise Exception(
                "slots.MAB.run: Strategy name invalid. Choose from:"
                " {}".format(", ".join(self.strategies))
            )

        # Run strategy
        for n in range(trials):
            self._run(strategy, parameters)

    def _run(self, strategy: str, parameters: Optional[Dict] = None) -> None:
        """
        Run a single trial of the MAB strategy.

        Parameters
        ----------
        strategy : str
        parameters : dict

        Returns
        -------
        None
        """

        choice: int = self.run_strategy(strategy, parameters)
        self.choices.append(choice)
        payout: Optional[int] = self.bandits.pull(choice)
        if payout is None:
            print("Trials exhausted. No more values for bandit", choice)
            return None
        else:
            self.wins[choice] += payout
            self.pulls[choice] += 1

    def run_strategy(
        self, strategy: str, parameters: Optional[Dict] = None
    ) -> int:
        """
        Run the selected strategy and return the bandit choice.

        Parameters
        ----------
        strategy : str
            Name of MAB strategy.
        parameters : dict
            Strategy function parameters

        Returns
        -------
        int
            Bandit arm choice index
        """

        return self.__getattribute__(strategy)(params=parameters)

    # ###### ----------- MAB strategies --------------------------------- ####
    def max_mean(self) -> int:
        """
        Pick the bandit with the current best observed proportion of winning.

        Returns
        -------
        int
            Index of chosen bandit
        """

        return np.argmax(self.wins / (self.pulls + 0.1))

    def bayesian(self, params: Any = None) -> int:
        """
        Run the Bayesian Bandit algorithm, which utilizes a beta distribution
        for exploration and exploitation.

        Parameters
        ----------
        params : None
            For API consistency, this function can take a parameters argument,
            but it is ignored.

        Returns
        -------
        int
            Index of chosen bandit
        """
        p_success_arms: List[float] = [
            np.random.beta(self.wins[i] + 1, self.pulls[i] - self.wins[i] + 1)
            for i in range(len(self.wins))
        ]

        return np.array(p_success_arms).argmax()

    def eps_greedy(self, params: Optional[Dict[str, float]] = None) -> int:
        """
        Run the epsilon-greedy strategy: with probability epsilon explore a
        random non-best arm, otherwise exploit self.max_mean().

        Parameters
        ----------
        params : dict
            Epsilon value, under the key "epsilon"

        Returns
        -------
        int
            Index of chosen bandit
        """

        default_eps: float = 0.1

        if params and isinstance(params, dict):
            eps: float = params.get("epsilon", default_eps)
            try:
                eps = float(eps)
            except ValueError:
                print("slots: eps_greedy: Setting eps to default")
                eps = default_eps
        else:
            eps = default_eps

        r: float = np.random.rand()

        if r < eps:
            # Explore: pick uniformly among the non-best arms.
            return np.random.choice(
                list(set(range(len(self.wins))) - {self.max_mean()})
            )
        else:
            # Exploit: pick the arm with the best observed win rate.
            return self.max_mean()

    def softmax(self, params: Optional[Dict] = None) -> int:
        """
        Run the softmax selection strategy.

        Parameters
        ----------
        params : dict
            Temperature tau, under the key "tau"

        Returns
        -------
        int
            Index of chosen bandit
        """

        default_tau: float = 0.1

        if params and isinstance(params, dict):
            tau: float = params.get("tau", default_tau)
            try:
                tau = float(tau)
            except ValueError:
                print("slots: softmax: Setting tau to default")
                tau = default_tau
        else:
            tau = default_tau

        # Handle cold start. Not all bandits tested yet.
        if True in (self.pulls < 3):
            return np.random.choice(range(len(self.pulls)))
        else:
            payouts: np.ndarray = self.wins / (self.pulls + 0.1)
            norm: float = sum(np.exp(payouts / tau))

            ps: np.ndarray = np.exp(payouts / tau) / norm

            # Randomly choose an index based on the cumulative probabilities
            cmf: List[float] = [sum(ps[: i + 1]) for i in range(len(ps))]

            rand: float = np.random.rand()

            found: bool = False
            found_i: int = 0
            i: int = 0
            while not found:
                if rand < cmf[i]:
                    found_i = i
                    found = True
                else:
                    i += 1

            return found_i

    def ucb(self, params: Optional[Dict] = None) -> int:
        """
        Run the upper confidence bound MAB selection strategy.

        This is the UCB1 algorithm described in
        https://homes.di.unimi.it/~cesabian/Pubblicazioni/ml-02.pdf

        Parameters
        ----------
        params : None
            For API consistency, this function can take a parameters argument,
            but it is ignored.

        Returns
        -------
        int
            Index of chosen bandit
        """

        # UCB = j_max(payout_j + sqrt(2ln(n_tot)/n_j))

        # Handle cold start. Not all bandits tested yet.
        if True in (self.pulls < 3):
            return np.random.choice(range(len(self.pulls)))
        else:
            n_tot: float = sum(self.pulls)
            payouts: np.ndarray = self.wins / (self.pulls + 0.1)
            ucbs: np.ndarray = payouts + np.sqrt(
                2 * np.log(n_tot) / self.pulls
            )

            return np.argmax(ucbs)

    # ###-----------------------------------------------------------------####

    def best(self) -> Optional[int]:
        """
        Return the current 'best' choice of bandit.

        Returns
        -------
        int
            Index of bandit
        """

        if len(self.choices) < 1:
            print("slots: No trials run so far.")
            return None
        else:
            return np.argmax(self.wins / (self.pulls + 0.1))

    def est_probs(self) -> Optional[np.ndarray]:
        """
        Calculate the current estimate of average payout for each bandit.

        Returns
        -------
        array of floats or None
        """

        if len(self.choices) < 1:
            print("slots: No trials run so far.")
            return None
        else:
            return self.wins / (self.pulls + 0.1)

    def regret(self) -> float:
        """
        Calculate expected regret, where expected regret is
        maximum optimal reward - sum of collected rewards, i.e.

        expected regret = T*max_k(mean_k) - sum_(t=1-->T)(reward_t)

        Note that the returned value is normalized by the total number of
        pulls, i.e. it is the expected regret per trial.

        Returns
        -------
        float
        """

        return (
            sum(self.pulls) * np.max(np.nan_to_num(self.wins / self.pulls))
            - sum(self.wins)
        ) / sum(self.pulls)

    def crit_met(self) -> bool:
        """
        Determine if the stopping criterion has been met.

        Returns
        -------
        bool
        """

        if True in (self.pulls < 3):
            return False
        else:
            return self.criteria[self.criterion](self.stop_value)

    def regret_met(self, threshold: Optional[float] = None) -> bool:
        """
        Determine if the regret criterion has been met.

        Parameters
        ----------
        threshold : float

        Returns
        -------
        bool
        """

        if threshold is None:
            return self.regret() <= self.stop_value
        else:
            return self.regret() <= threshold

    # ## ------------ Online bandit testing ----------------------------- ####
    def online_trial(
        self,
        bandit: Optional[int] = None,
        payout: Optional[int] = None,
        strategy: str = "eps_greedy",
        parameters: Optional[Dict] = None,
    ) -> Dict:
        """
        Update the bandits with the results of the previous live, online
        trial, then run the selection algorithm. If the stopping criterion
        is met, return the best arm estimate. Otherwise return the next arm
        to try.

        Parameters
        ----------
        bandit : int
            Bandit index of the most recent trial
        payout : int
            Payout value of the most recent trial
        strategy : string
            Name of the update strategy
        parameters : dict
            Parameters for the update strategy function

        Returns
        -------
        dict
            Format: {'new_trial': boolean, 'choice': int, 'best': int}
        """

        if bandit is not None and payout is not None:
            self.update(bandit=bandit, payout=payout)
        else:
            raise Exception(
                "slots.online_trial: bandit and/or payout value missing."
            )

        if self.crit_met():
            return {
                "new_trial": False,
                "choice": self.best(),
                "best": self.best(),
            }
        else:
            return {
                "new_trial": True,
                "choice": self.run_strategy(strategy, parameters),
                "best": self.best(),
            }

    def update(self, bandit: int, payout: int) -> None:
        """
        Update bandit trials and payouts for the given bandit.

        Parameters
        ----------
        bandit : int
            Bandit index
        payout : int (0 or 1)

        Returns
        -------
        None
        """

        self.choices.append(bandit)
        self.pulls[bandit] += 1
        self.wins[bandit] += payout
        self.bandits.payouts[bandit] += payout


class Bandits:
    """
    Bandit class.
    """

    def __init__(
        self,
        payouts: np.ndarray,
        probs: Optional[np.ndarray] = None,
        hist_payouts: Optional[List[np.ndarray]] = None,
        live: bool = False,
    ):
        """
        Instantiate the Bandits class, determining
        - Probabilities of bandit payouts
        - Bandit payouts

        Parameters
        ----------
        payouts : array of ints
            Cumulative bandit payouts. `payouts` should start as an N
            length array of zeros, where N is the number of bandits.
        probs : array of floats, optional
            Probabilities of bandit payouts.
        hist_payouts : list of arrays of ints, optional
        live : bool, optional
        """

        if not live:
            self.probs: Optional[np.ndarray] = probs
            self.payouts: np.ndarray = payouts
            self.hist_payouts: Optional[List[np.ndarray]] = hist_payouts
            self.live: bool = False
        else:
            self.live = True
            self.probs = None
            self.payouts = payouts

    def pull(self, i: int) -> Optional[int]:
        """
        Return the payout from a single pull of bandit i's arm.

        Parameters
        ----------
        i : int
            Index of bandit.

        Returns
        -------
        int or None
        """

        if self.live:
            # Note: this branch expects payouts[i] to be a list-like queue of
            # pending payouts; in the current MAB flow, live results are fed
            # in via MAB.update() instead of being pulled here.
            if len(self.payouts[i]) > 0:
                return self.payouts[i].pop()
            else:
                return None
        elif self.hist_payouts:
            # Use len() rather than truthiness, since hist_payouts[i] may be
            # a numpy array.
            if len(self.hist_payouts[i]) == 0:
                return None
            else:
                _p: int = self.hist_payouts[i][0]
                self.hist_payouts[i] = self.hist_payouts[i][1:]
                return _p
        else:
            if self.probs is None:
                return None
            elif np.random.rand() < self.probs[i]:
                return 1
            else:
                return 0

    def info(self) -> None:
        pass
--------------------------------------------------------------------------------
/tests/tests.py:
--------------------------------------------------------------------------------
# Assuming pytest
from slots.slots import MAB


# Most basic test of defaults
def test_mab():
    mab = MAB()
    mab.run()
--------------------------------------------------------------------------------
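
Given the "TESTS!" item on the README todo list, a slightly fuller test module might look like the sketch below. This is not part of the repository; it only exercises calls that exist in `slots/slots.py`, and the test names and counts are illustrative.

```Python
# Sketch of additional tests (hypothetical; not in the repo).
from slots.slots import MAB


def test_run_with_known_probs():
    # Simulated bandits with known probabilities
    mab = MAB(probs=[0.1, 0.9])
    mab.run(trials=500)
    assert sum(mab.pulls) == 500
    assert mab.best() in (0, 1)


def test_online_trial_returns_expected_keys():
    # Live mode: feed in one result and check the response dict shape
    mab = MAB(3, live=True)
    result = mab.online_trial(bandit=1, payout=1)
    assert set(result) == {"new_trial", "choice", "best"}
```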