├── bandits ├── __init__.py ├── figures │ └── empty ├── results │ └── empty ├── scripts │ ├── __init__.py │ ├── tabular_test.py │ ├── run_experiments.py │ ├── run_experiments.ipynb │ ├── training_utils.py │ ├── thompson_sampling_bernoulli.py │ ├── tabular_subspace_exp.py │ ├── plot_results.py │ ├── movielens_exp.py │ ├── mnist_exp.py │ ├── tabular_exp.py │ └── subspace_bandits.ipynb ├── __main__.py ├── environments │ ├── mnist_env.py │ ├── environment.py │ ├── ads16_env.py │ ├── movielens_env.py │ └── tabular_env.py ├── agents │ ├── base.py │ ├── agent_utils.py │ ├── diagonal_subspace.py │ ├── linear_kf_bandit.py │ ├── linear_bandit.py │ ├── linear_bandit_wide.py │ ├── neural_greedy.py │ ├── low_rank_filter_bandit.py │ ├── neural_linear_bandit_wide.py │ ├── ekf_orig_diag.py │ ├── ekf_orig_full.py │ ├── neural_linear.py │ ├── ekf_subspace.py │ └── limited_memory_neural_linear.py └── training.py ├── aistats2022-slides ├── .npmrc ├── README.md ├── public │ ├── ts-bandits.mp4 │ └── subspace-neural-bandit-diagram.jpg ├── vercel.json ├── .gitignore ├── netlify.toml ├── style.css ├── package.json ├── components │ └── Counter.vue ├── assets │ └── subspace-neural-bandit-diagram.tex └── slides.md ├── bandit-data ├── bandit-adult.pkl ├── bandit-stock.pkl ├── bandit-mushroom.pkl ├── bandit-statlog.pkl ├── bandit-covertype.pkl └── ml-100k │ └── README.txt ├── setup.py ├── requirements.txt ├── LICENSE ├── README.md ├── .gitignore └── demos └── lofi_tabular.py /bandits/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bandits/figures/empty: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bandits/results/empty: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bandits/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aistats2022-slides/.npmrc: -------------------------------------------------------------------------------- 1 | # for pnpm 2 | shamefully-hoist=true 3 | -------------------------------------------------------------------------------- /aistats2022-slides/README.md: -------------------------------------------------------------------------------- 1 | # Subspace EKF neural bandits 2 | ## AIStats 2022 3 | -------------------------------------------------------------------------------- /bandit-data/bandit-adult.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bandits/main/bandit-data/bandit-adult.pkl -------------------------------------------------------------------------------- /bandit-data/bandit-stock.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bandits/main/bandit-data/bandit-stock.pkl -------------------------------------------------------------------------------- /bandit-data/bandit-mushroom.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bandits/main/bandit-data/bandit-mushroom.pkl -------------------------------------------------------------------------------- 
/bandit-data/bandit-statlog.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bandits/main/bandit-data/bandit-statlog.pkl -------------------------------------------------------------------------------- /bandit-data/bandit-covertype.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bandits/main/bandit-data/bandit-covertype.pkl -------------------------------------------------------------------------------- /aistats2022-slides/public/ts-bandits.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bandits/main/aistats2022-slides/public/ts-bandits.mp4 -------------------------------------------------------------------------------- /aistats2022-slides/vercel.json: -------------------------------------------------------------------------------- 1 | { 2 | "rewrites": [ 3 | { "source": "/(.*)", "destination": "/index.html" } 4 | ] 5 | } 6 | -------------------------------------------------------------------------------- /aistats2022-slides/.gitignore: -------------------------------------------------------------------------------- 1 | *.pdf 2 | node_modules 3 | .DS_Store 4 | dist 5 | *.local 6 | index.html 7 | .remote-assets 8 | components.d.ts 9 | .texpadtmp 10 | -------------------------------------------------------------------------------- /aistats2022-slides/public/subspace-neural-bandit-diagram.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bandits/main/aistats2022-slides/public/subspace-neural-bandit-diagram.jpg -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from setuptools import setup, find_packages 3 | 4 | setup( 5 | name="bandits", 6 | packages=find_packages(), 7 | install_requires=[] 8 | ) 9 | -------------------------------------------------------------------------------- /aistats2022-slides/netlify.toml: -------------------------------------------------------------------------------- 1 | [build.environment] 2 | NODE_VERSION = "14" 3 | 4 | [build] 5 | publish = "dist" 6 | command = "npm run build" 7 | 8 | [[redirects]] 9 | from = "/*" 10 | to = "/index.html" 11 | status = 200 12 | -------------------------------------------------------------------------------- /aistats2022-slides/style.css: -------------------------------------------------------------------------------- 1 | .centered { 2 | position: fixed; 3 | top: 50%; 4 | left: 50%; 5 | /* bring your own prefixes */ 6 | transform: translate(-50%, -50%); 7 | } 8 | 9 | .horizontal-center { 10 | display: block; 11 | margin-left: auto; 12 | margin-right: auto; 13 | } 14 | 15 | .bottom-right { 16 | position: absolute; 17 | bottom: 0; 18 | right: 0; 19 | } 20 | -------------------------------------------------------------------------------- /aistats2022-slides/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "private": true, 3 | "scripts": { 4 | "build": "slidev build", 5 | "dev": "slidev --open", 6 | "export": "slidev export" 7 | }, 8 | "dependencies": { 9 | "@slidev/cli": "^0.28.6", 10 | "@slidev/theme-default": "*", 11 | "@slidev/theme-seriph": "*" 12 | }, 13 | "name": "aistats2022", 14 | "devDependencies": { 15 | 
"playwright-chromium": "^1.19.1" 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /bandits/__main__.py: -------------------------------------------------------------------------------- 1 | import fire 2 | from scripts import plot_results 3 | from scripts import run_experiments 4 | from scripts import tabular_test 5 | 6 | 7 | class Experiments: 8 | def test(self): 9 | tabular_test.main() 10 | 11 | def plot_experiments(self): 12 | plot_results.main() 13 | 14 | def run_experiments(self, experiment=None): 15 | run_experiments.main(experiment) 16 | 17 | def run_and_plot(self): 18 | self.run_experiments() 19 | self.plot_experiments() 20 | 21 | if __name__ == "__main__": 22 | fire.Fire(Experiments) -------------------------------------------------------------------------------- /aistats2022-slides/components/Counter.vue: -------------------------------------------------------------------------------- 1 | 12 | 13 | 38 | -------------------------------------------------------------------------------- /bandits/environments/mnist_env.py: -------------------------------------------------------------------------------- 1 | from jax.nn import one_hot 2 | from jax.random import split, permutation 3 | 4 | import numpy as np 5 | from sklearn.datasets import fetch_openml 6 | 7 | from .environment import BanditEnvironment 8 | 9 | 10 | def get_mnist(key, ntrain): 11 | X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False) 12 | 13 | X = X / 255. 14 | y = y.astype(np.int32) 15 | 16 | perm = permutation(key, np.arange(len(X))) 17 | ntrain = ntrain if ntrain < len(X) else len(X) 18 | perm = perm[:ntrain] 19 | X, y = X[perm], y[perm] 20 | 21 | narms = len(np.unique(y)) 22 | Y = one_hot(y, narms) 23 | 24 | opt_rewards = np.ones((ntrain,)) 25 | return X, Y, opt_rewards 26 | 27 | 28 | def MnistEnvironment(key, ntrain=0): 29 | key, mykey = split(key) 30 | X, Y, opt_rewards = get_mnist(mykey, ntrain) 31 | return BanditEnvironment(key, X, Y, opt_rewards) 32 | -------------------------------------------------------------------------------- /bandits/agents/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from functools import partial 3 | 4 | class BanditAgent(ABC): 5 | def __init__(self, bandit): 6 | self.bandit = bandit 7 | 8 | @abstractmethod 9 | def init_bel(self, key, contexts, states, actions, rewards): 10 | ... 11 | 12 | @abstractmethod 13 | def sample_params(self, key, bel): 14 | ... 15 | 16 | @abstractmethod 17 | def update_bel(self, bel, context, action, reward): 18 | ... 19 | 20 | # TODO: Make it abstractmethod 21 | # @abstractmethod 22 | def predict_rewards(self, params, context): 23 | ... 
24 | 25 | def choose_action(self, key, bel, context): 26 | params = self.sample_params(key, bel) 27 | predicted_rewards = self.predict_rewards(params, context) 28 | action = predicted_rewards.argmax() 29 | return action 30 | 31 | def __str__(self): 32 | return "BanditAgent" 33 | 34 | def __repr__(self): 35 | return str(self) 36 | -------------------------------------------------------------------------------- /bandits/scripts/tabular_test.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | from jax import random 3 | from agents.linear_bandit import LinearBandit 4 | from environments.tabular_env import TabularEnvironment 5 | from .training_utils import train, MLP, summarize_results 6 | 7 | 8 | def main(ntrials=10, npulls=20, nwarmup=2000, seed=314): 9 | key = random.PRNGKey(seed) 10 | ntrain = 5000 11 | env = TabularEnvironment(key, ntrain=ntrain, name="statlog", intercept=False, path="./bandit-data") 12 | linear_params = {} 13 | num_arms = env.labels_onehot.shape[-1] 14 | 15 | time_init = time() 16 | warmup_rewards, rewards_trace, opt_rewards = train(key, LinearBandit, env, npulls, 17 | ntrials, 18 | linear_params, neural=False) 19 | 20 | rtotal, rstd = summarize_results(warmup_rewards, rewards_trace) 21 | time_end = time() 22 | running_time = time() - time_init 23 | print(f"Time : {running_time:0.3f}s") 24 | 25 | if __name__ == "__main__": 26 | main() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py>=1.0.0 2 | certifi>=2021.10.8 3 | charset-normalizer>=2.0.9 4 | chex>=0.1.0 5 | cloudpickle>=2.0.0 6 | contextlib2>=21.6.0 7 | cycler>=0.11.0 8 | decorator>=5.1.0 9 | dm-tree>=0.1.6 10 | fire>=0.4.0 11 | flatbuffers>=2.0 12 | flax>=0.3.6 13 | fonttools>=4.28.3 14 | gast>=0.5.3 15 | idna>=3.3 16 | jax>=0.2.25 17 | jaxlib>=0.1.73 18 | joblib>=1.1.0 19 | kiwisolver>=1.3.2 20 | libtpu-nightly>=0.1.dev20211018 21 | matplotlib>=3.5.0 22 | ml-collections>=0.1.0 23 | msgpack>=1.0.3 24 | numpy>=1.21.4 25 | opt-einsum>=3.3.0 26 | optax @ git+git://github.com/deepmind/optax.git@16085e99edc8cbfcdc1251b6dad945f427c5eb18 27 | packaging>=21.3 28 | pandas>=1.3.4 29 | Pillow>=8.4.0 30 | pyparsing>=3.0.6 31 | python-dateutil>=2.8.2 32 | pytz>=2021.3 33 | PyYAML>=6.0 34 | requests>=2.26.0 35 | scikit-learn>=1.0.1 36 | scipy>=1.7.3 37 | seaborn>=0.11.2 38 | setuptools-scm>=6.3.2 39 | six>=1.16.0 40 | termcolor>=1.1.0 41 | tfp-nightly>=0.16.0.dev20211208 42 | threadpoolctl>=3.0.0 43 | tomli>=1.2.2 44 | toolz>=0.11.2 45 | typing_extensions>=4.0.1 46 | urllib3>=1.26.7 47 | jsl @ git+git://github.com/probml/jsl -------------------------------------------------------------------------------- /aistats2022-slides/assets/subspace-neural-bandit-diagram.tex: -------------------------------------------------------------------------------- 1 | \documentclass[11pt]{article} 2 | 3 | \usepackage{bm} 4 | \usepackage{tikz-cd} 5 | 6 | \begin{document} 7 | % 
https://tikzcd.yichuanshen.de/#N4Igdg9gJgpgziAXAbVABwnAlgFyxMJZARgBoAGAXVJADcBDAGwFcYkRgAdTgIwDMABAC8AvgH1gOALTERIEaXSZc+QigBMFanSat2XXoNFic8xSAzY8BImWLaGLNohDceERlDgBPALbvGbhwACxgceglpWTMlK1UiTXsaRz0XNw8vPwCg0PDIuQVYlRsUMnUHXWcQACdImQLzS2K1ZE1y5Mr2WtNCi2VrFvJSAGYKp303QThxSXqYvriSkhGx1JAI2ejepoGElY7xlwiexv74lGH9nUOOSYFp-Pmd8+QhqgO1g34BAEEGot2KCG7Wua3Snh8-g8OTCEQAVPJtDAoABzeBEUB8aoQXxIIYgHAQJBkUHOMDMRiMGiMeg8GCMAAKZxKIEYMD4J0x2NxiHxhKQ6l6WJxxJo-MQwyF3KQABYxUTEILzMKeQBWeVISXK6WIADsGsQMqlIsQADYDUbtSaABwG1XGnkATjtDuJfIVSq5JuI7s1rsQxBJ4s9IBVxKDCsllBEQA 8 | \begin{tikzcd} 9 | {\bf A} \arrow[rd] \arrow[rrd] & {\bf z}_{t-1} \arrow[r] \arrow[d] & {\bf z}_t \arrow[d] & \\ 10 | & \boldsymbol\theta_{t-1} & \boldsymbol\theta_{t} & \\ 11 | \boldsymbol\theta_* \arrow[ru] \arrow[rru] & r_{t-1} \arrow[u] & r_t \arrow[u] & \\ 12 | {\bf s}_{t-1} \arrow[ru] & a_{t-1} \arrow[u] & a_t \arrow[u] & {\bf s}_{t} \arrow[lu] 13 | \end{tikzcd} 14 | \end{document} -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Probabilistic machine learning 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /bandits/scripts/run_experiments.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import ml_collections 3 | import glob 4 | from datetime import datetime 5 | 6 | from . import movielens_exp as movielens_run 7 | from . import mnist_exp as mnist_run 8 | from . import tabular_exp as tabular_run 9 | from . import tabular_subspace_exp as tabular_sub_run 10 | 11 | def make_config(filepath): 12 | """Get the default hyperparameter configuration.""" 13 | config = ml_collections.ConfigDict() 14 | config.filepath = filepath 15 | config.ntrials = 10 16 | return config 17 | 18 | 19 | def main(experiment=None): 20 | timestamp = datetime.timestamp(datetime.now()) 21 | 22 | experiments = { 23 | "tabular": tabular_run, 24 | "mnist": mnist_run, 25 | "movielens": movielens_run, 26 | "tabular_subspace": tabular_sub_run 27 | } 28 | 29 | if experiment is not None: 30 | print(experiment) 31 | if experiment not in experiments: 32 | err = f"Experiment {experiment} not found. 
" 33 | err += f"Available experiments: {list(experiments.keys())}" 34 | raise ValueError(err) 35 | experiments = {experiment: experiments[experiment]} 36 | 37 | for experiment_name, experiment_run in experiments.items(): 38 | filename = f"./bandits/results/{experiment_name}_results_{timestamp}.csv" 39 | config = make_config(filename) 40 | experiment_run.main(config) 41 | 42 | 43 | if __name__ == "__main__": 44 | main() -------------------------------------------------------------------------------- /bandits/agents/agent_utils.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import value_and_grad, jit 3 | from jax.random import normal 4 | from jax.lax import scan 5 | from jax.flatten_util import ravel_pytree 6 | 7 | 8 | def NIGupdate(bel, phi, reward): 9 | mu, Sigma, a, b = bel 10 | Lambda = jnp.linalg.inv(Sigma) 11 | Lambda_update = jnp.outer(phi, phi) + Lambda 12 | Sigma_update = jnp.linalg.inv(Lambda_update) 13 | mu_update = Sigma_update @ (Lambda @ mu + phi * reward) 14 | a_update = a + 1 / 2 15 | b_update = b + (reward ** 2 + mu.T @ Lambda @ mu - mu_update.T @ Lambda_update @ mu_update) / 2 16 | bel = (mu_update, Sigma_update, a_update, b_update) 17 | return bel 18 | 19 | 20 | def convert_params_from_subspace_to_full(params_subspace, projection_matrix, params_full): 21 | params = jnp.matmul(params_subspace, projection_matrix) + params_full 22 | return params 23 | 24 | 25 | def generate_random_basis(key, d, D): 26 | projection_matrix = normal(key, shape=(d, D)) 27 | projection_matrix = projection_matrix / jnp.linalg.norm(projection_matrix, axis=-1, keepdims=True) 28 | return projection_matrix 29 | 30 | 31 | def train(state, loss_fn, nepochs=300, has_aux=True): 32 | @jit 33 | def step(state, _): 34 | grad_fn = value_and_grad(loss_fn, has_aux=has_aux) 35 | val, grads = grad_fn(state.params) 36 | loss = val[0] if has_aux else val 37 | state = state.apply_gradients(grads=grads) 38 | flat_params, _ = ravel_pytree(state.params) 39 | return state, {"loss": loss, "params": flat_params} 40 | 41 | state, metrics = scan(step, state, jnp.empty(nepochs)) 42 | 43 | return state, metrics 44 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Efficient Online Bayesian Inference for Neural Bandits 2 | 3 | **🚨Breaking changes🚨** 4 | See [aistats2022 release](https://github.com/probml/bandits/releases/tag/aistats2022) For a [JSL@ceeef0](https://github.com/probml/JSL/commit/ceeef0f02b185c7188afb40b977a9406d97c21ba) and `jax<=0.2.22` compatible version. 5 | 6 | ---- 7 | 8 | By [Gerardo Durán-Martín](http://github.com/gerdm), [Aleyna Kara](https://github.com/karalleyna), and [Kevin Murphy](https://github.com/murphyk) 9 | 10 | [Arxiv paper](https://arxiv.org/abs/2112.00195). 
11 | 12 | [Slides](https://probml.github.io/bandits/1) 13 | 14 | MNIST-experiment 15 | 16 | ----- 17 | 18 | ## Installation 19 | 20 | ``` 21 | pip install fire 22 | pip install ml-collections 23 | ``` 24 | 25 | ## Reproduce the results 26 | 27 | There are two ways to reproduce the results from the paper 28 | 29 | ### Run the scripts 30 | 31 | To reproduce the results, `cd` into the project folder and run 32 | 33 | ```bash 34 | python bandits test 35 | ``` 36 | 37 | ```bash 38 | python bandits run_and_plot 39 | ``` 40 | 41 | ### Step by step 42 | 43 | If you only want to reproduce the results, run 44 | 45 | ```bash 46 | python bandits run_experiments 47 | ``` 48 | 49 | If you have previously reproduced the results and want to reproduce the plots, run 50 | 51 | ```bash 52 | python bandits plot_experiments 53 | ``` 54 | 55 | The results will be stored inside `bandits/figures/`. 56 | 57 | ### Execute the notebooks 58 | 59 | An alternative way to reproduce the results is to simply open and run [`subspace_bandits.ipynb`](https://github.com/probml/bandits/blob/main/bandits/scripts/subspace_bandits.ipynb) 60 | -------------------------------------------------------------------------------- /bandits/environments/environment.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | class BanditEnvironment: 6 | def __init__(self, key, X, Y, opt_rewards): 7 | # Randomise dataset rows 8 | n_obs, n_features = X.shape 9 | 10 | new_ixs = jax.random.choice(key, n_obs, (n_obs,), replace=False) 11 | 12 | X = jnp.asarray(X)[new_ixs] 13 | Y = jnp.asarray(Y)[new_ixs] 14 | opt_rewards = jnp.asarray(opt_rewards)[new_ixs] 15 | 16 | self.contexts = X 17 | self.labels_onehot = Y 18 | self.opt_rewards = opt_rewards 19 | _, self.n_arms = Y.shape 20 | self.n_steps, self.n_features = X.shape 21 | 22 | def get_state(self, t): 23 | return self.labels_onehot[t] 24 | 25 | def get_context(self, t): 26 | return self.contexts[t] 27 | 28 | def get_reward(self, t, action): 29 | return jnp.float32(self.labels_onehot[t][action]) 30 | 31 | def warmup(self, num_pulls): 32 | num_steps, num_actions = self.labels_onehot.shape 33 | # Create array of round-robin actions: 0, 1, 2, 0, 1, 2, 0, 1, 2, ... 
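        # Worked example with num_actions=3, num_pulls=2:
        #   jnp.arange(3)           -> [0, 1, 2]
        #   repeat + reshape(3, -1) -> [[0, 0], [1, 1], [2, 2]]
        #   reshape(-1, order="F")  -> [0, 1, 2, 0, 1, 2]
        # i.e. every arm is pulled once per round, for num_pulls rounds.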
34 | warmup_actions = jnp.arange(num_actions) 35 | warmup_actions = jnp.repeat(warmup_actions, num_pulls).reshape(num_actions, -1) 36 | actions = warmup_actions.reshape(-1, order="F").astype(jnp.int32) 37 | num_warmup_actions = len(actions) 38 | 39 | time_steps = jnp.arange(num_warmup_actions) 40 | 41 | def get_contexts_and_rewards(t, a): 42 | context = self.get_context(t) 43 | state = self.get_state(t) 44 | reward = self.get_reward(t, a) 45 | return context, state, reward 46 | 47 | contexts, states, rewards = jax.vmap(get_contexts_and_rewards, in_axes=(0, 0))(time_steps, actions) 48 | 49 | return contexts, states, actions, rewards 50 | -------------------------------------------------------------------------------- /bandits/agents/diagonal_subspace.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jsl.nlds.diagonal_extended_kalman_filter import DiagonalExtendedKalmanFilter 3 | from .ekf_subspace import SubspaceNeuralBandit 4 | from tensorflow_probability.substrates import jax as tfp 5 | 6 | tfd = tfp.distributions 7 | 8 | 9 | class DiagonalSubspaceNeuralBandit(SubspaceNeuralBandit): 10 | 11 | def __init__(self, num_features, num_arms, model, opt, prior_noise_variance, nwarmup=1000, nepochs=1000, 12 | system_noise=0.0, observation_noise=1.0, n_components=0.9999, random_projection=False): 13 | super().__init__(num_features, num_arms, model, opt, prior_noise_variance, nwarmup, nepochs, 14 | system_noise, observation_noise, n_components, random_projection) 15 | 16 | def init_bel(self, key, contexts, states, actions, rewards): 17 | bel = super().init_bel(key, contexts, states, actions, rewards) 18 | 19 | params_subspace_init, _, t = bel 20 | 21 | subspace_dim = self.n_components 22 | Q = jnp.ones(subspace_dim) * self.system_noise 23 | R = self.observation_noise 24 | 25 | covariance_subspace_init = jnp.ones(subspace_dim) * self.prior_noise_variance 26 | 27 | def fz(params): 28 | return params 29 | 30 | def fx(params, context, action): 31 | return self.predict_rewards(params, context)[action, None] 32 | 33 | ekf = DiagonalExtendedKalmanFilter(fz, fx, Q, R) 34 | self.ekf = ekf 35 | 36 | bel = (params_subspace_init, covariance_subspace_init, t) 37 | 38 | return bel 39 | 40 | def sample_params(self, key, bel): 41 | params_subspace, covariance_subspace, t = bel 42 | normal_dist = tfd.Normal(loc=params_subspace, scale=covariance_subspace) 43 | params_subspace = normal_dist.sample(seed=key) 44 | return params_subspace 45 | -------------------------------------------------------------------------------- /bandits/environments/ads16_env.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax.ops import index_add 3 | from jax.random import split, permutation 4 | from jax.nn import one_hot 5 | 6 | import pandas as pd 7 | 8 | import requests 9 | import io 10 | 11 | from environment import BanditEnvironment 12 | 13 | 14 | def get_ads16(key, ntrain, intercept): 15 | url = "https://raw.githubusercontent.com/probml/probml-data/main/data/ads16_preprocessed.csv" 16 | download = requests.get(url).content 17 | 18 | dataset = pd.read_csv(io.StringIO(download.decode('utf-8'))) 19 | dataset.drop(columns=['Unnamed: 0'], inplace=True) 20 | dataset = dataset.sample(frac=1).reset_index(drop=True).to_numpy() 21 | 22 | ntrain = ntrain if ntrain > 0 and ntrain < len(dataset) else len(dataset) 23 | nusers, nads = 120, 300 24 | users = jnp.arange(nusers) 25 | 26 | n_ads_per_user, rem = 
divmod(ntrain, nusers) 27 | 28 | mykey, key = split(key) 29 | indices = permutation(mykey, users)[:rem] 30 | n_ads_per_user = jnp.ones((nusers,)) * int(n_ads_per_user) 31 | n_ads_per_user = index_add(n_ads_per_user, indices, 1).astype(jnp.int32) 32 | 33 | indices = jnp.array([]) 34 | 35 | for user, nrow in enumerate(n_ads_per_user): 36 | mykey, key = split(key) 37 | df_indices = jnp.arange(user * nads, (user + 1) * nads) 38 | indices = jnp.append(indices, permutation(mykey, df_indices)[:nrow]).astype(jnp.int32) 39 | 40 | narms = 2 41 | dataset = dataset[indices] 42 | 43 | X = dataset[:, :-1] 44 | Y = one_hot(dataset[:, -1], narms) 45 | 46 | if intercept: 47 | X = jnp.concatenate([jnp.ones_like(X[:, :1]), X]) 48 | 49 | opt_rewards = jnp.ones((len(X),)) 50 | 51 | return X, Y, opt_rewards 52 | 53 | 54 | def ADS16Environment(key, ntrain, intercept=False): 55 | mykey, key = split(key) 56 | X, Y, opt_rewards = get_ads16(mykey, ntrain, intercept) 57 | return BanditEnvironment(key, X, Y, opt_rewards) 58 | -------------------------------------------------------------------------------- /bandits/environments/movielens_env.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | import jax.numpy as jnp 5 | 6 | from .environment import BanditEnvironment 7 | 8 | MOVIELENS_NUM_USERS = 943 9 | MOVIELENS_NUM_MOVIES = 1682 10 | 11 | 12 | def load_movielens_data(filepath): 13 | dataset = pd.read_csv(filepath, delimiter='\t', header=None) 14 | columns = {0: 'user_id', 1: 'item_id', 2: 'ranking', 3: 'timestamp'} 15 | dataset = dataset.rename(columns=columns) 16 | dataset['user_id'] -= 1 17 | dataset['item_id'] -= 1 18 | dataset = dataset.drop(columns="timestamp") 19 | 20 | rankings_matrix = np.zeros((MOVIELENS_NUM_USERS, MOVIELENS_NUM_MOVIES)) 21 | for i, row in dataset.iterrows(): 22 | rankings_matrix[row["user_id"], row["item_id"]] = float(row["ranking"]) 23 | return rankings_matrix 24 | 25 | 26 | # https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/environments/movielens_py_environment.py 27 | def get_movielens(rank_k, num_movies, repeat=5): 28 | """Initializes the MovieLens Bandit environment. 29 | Args: 30 | rank_k : (int) Which rank to use in the matrix factorization. 31 | batch_size: (int) Number of observations generated per call. 32 | num_movies: (int) Only the first `num_movies` movies will be used by the 33 | environment. The rest is cut out from the data. 34 | """ 35 | num_actions = num_movies 36 | context_dim = rank_k 37 | 38 | # Compute the matrix factorization. 39 | data_matrix = load_movielens_data("../bandit-data/ml-100k/u.data") 40 | # Keep only the first items. 41 | data_matrix = data_matrix[:, :num_movies] 42 | # Filter the users with no iterm rated. 43 | nonzero_users = list(np.nonzero(np.sum(data_matrix, axis=1) > 0.0)[0]) * repeat 44 | data_matrix = data_matrix[nonzero_users, :] 45 | effective_num_users = len(nonzero_users) 46 | 47 | # Compute the SVD. 48 | u, s, vh = np.linalg.svd(data_matrix, full_matrices=False) 49 | 50 | # Keep only the largest singular values. 
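    # Rank-k truncation: data_matrix ≈ u_k @ diag(s_k) @ vh_k. The square
    # root of each singular value is split between the user and item factors,
    # so u_hat @ v_hat reproduces this rank-k approximation and each row of
    # u_hat serves as a per-user context vector of dimension rank_k.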
51 | u_hat = u[:, :context_dim] * np.sqrt(s[:context_dim]) 52 | v_hat = np.transpose(np.transpose(vh[:rank_k, :]) * np.sqrt(s[:rank_k])) 53 | approx_ratings_matrix = np.matmul(u_hat, v_hat) 54 | opt_rewards = np.max(approx_ratings_matrix, axis=1) 55 | return u_hat, approx_ratings_matrix, opt_rewards 56 | 57 | 58 | def MovielensEnvironment(key, rank_k=20, num_movies=20, repeat=5, intercept=False): 59 | X, y, opt_rewards = get_movielens(rank_k, num_movies, repeat) 60 | 61 | if intercept: 62 | X = jnp.hstack([jnp.ones_like(X[:, :1]), X]) 63 | 64 | return BanditEnvironment(key, X, y, opt_rewards) 65 | -------------------------------------------------------------------------------- /bandits/agents/linear_kf_bandit.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax.lax import scan 3 | from jax.random import split 4 | from jsl.lds.kalman_filter import KalmanFilterNoiseEstimation 5 | from tensorflow_probability.substrates import jax as tfp 6 | 7 | tfd = tfp.distributions 8 | 9 | 10 | class LinearKFBandit: 11 | def __init__(self, num_features, num_arms, eta=6.0, lmbda=0.25): 12 | self.num_features = num_features 13 | self.num_arms = num_arms 14 | self.eta = eta 15 | self.lmbda = lmbda 16 | 17 | def init_bel(self, key, contexts, states, actions, rewards): 18 | v = 2 * self.eta * jnp.ones((self.num_arms,)) 19 | tau = jnp.ones((self.num_arms,)) 20 | 21 | Sigma0 = jnp.eye(self.num_features) 22 | mu0 = jnp.zeros((self.num_features,)) 23 | 24 | Sigma = 1. / self.lmbda * jnp.repeat(Sigma0[None, ...], self.num_arms, axis=0) 25 | mu = Sigma @ mu0 26 | A = jnp.eye(self.num_features) 27 | Q = 0 28 | 29 | self.kf = KalmanFilterNoiseEstimation(A, Q, mu, Sigma, v, tau) 30 | 31 | def warmup_update(bel, cur): 32 | context, action, reward = cur 33 | bel = self.update_bel(bel, context, action, reward) 34 | return bel, None 35 | 36 | bel = (mu, Sigma, v, tau) 37 | bel, _ = scan(warmup_update, bel, (contexts, actions, rewards)) 38 | 39 | return bel 40 | 41 | def update_bel(self, bel, context, action, reward): 42 | mu, Sigma, v, tau = bel 43 | state = (mu[action], Sigma[action], v[action], tau[action]) 44 | xs = (context, reward) 45 | 46 | mu_k, Sigma_k, v_k, tau_k = self.kf.kalman_step(state, xs) 47 | 48 | mu = mu.at[action].set(mu_k) 49 | Sigma = Sigma.at[action].set(Sigma_k) 50 | v = v.at[action].set(v_k) 51 | tau = tau.at[action].set(tau_k) 52 | 53 | bel = (mu, Sigma, v, tau) 54 | 55 | return bel 56 | 57 | def sample_params(self, key, bel): 58 | sigma_key, w_key = split(key, 2) 59 | mu, Sigma, v, tau = bel 60 | 61 | lmbda = tfd.InverseGamma(v / 2., (v * tau) / 2.).sample(seed=sigma_key) 62 | V = lmbda[:, None, None] 63 | 64 | covariance_matrix = V * Sigma 65 | w = tfd.MultivariateNormalFullCovariance(loc=mu, covariance_matrix=covariance_matrix).sample(seed=w_key) 66 | 67 | return w 68 | 69 | def choose_action(self, key, bel, context): 70 | # Thompson sampling strategy 71 | # Could also use epsilon greedy or UCB 72 | w = self.sample_params(key, bel) 73 | predicted_reward = jnp.einsum("m,km->k", context, w) 74 | action = predicted_reward.argmax() 75 | return action 76 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.csv 2 | *.png 3 | *.DS_Store 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | 
.Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | 138 | # pytype static type analyzer 139 | .pytype/ 140 | 141 | # Cython debug symbols 142 | cython_debug/ 143 | -------------------------------------------------------------------------------- /bandits/training.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from functools import partial 4 | 5 | 6 | def reshape_vvmap(x): 7 | """ 8 | Reshape a 2D array to a 1D array. 9 | This is taken to be the output of a double vmap or pmap. 
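    For example, an input of shape (n_devices, n_trials_per_device, T)
    is reshaped to (n_devices * n_trials_per_device, T); any trailing
    axes are preserved unchanged.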
10 | """ 11 | shape_orig = x.shape 12 | shape_new = (shape_orig[0] * shape_orig[1], *shape_orig[2:]) 13 | return jnp.reshape(x, shape_new) 14 | 15 | 16 | def step(bel, t, key_base, bandit, env): 17 | key = jax.random.fold_in(key_base, t) 18 | context = env.get_context(t) 19 | 20 | action = bandit.choose_action(key, bel, context) 21 | reward = env.get_reward(t, action) 22 | bel = bandit.update_bel(bel, context, action, reward) 23 | 24 | hist = { 25 | "actions": action, 26 | "rewards": reward 27 | } 28 | 29 | return bel, hist 30 | 31 | 32 | def warmup_bandit(key, bandit, env, npulls): 33 | warmup_contexts, warmup_states, warmup_actions, warmup_rewards = env.warmup(npulls) 34 | bel = bandit.init_bel(key, warmup_contexts, warmup_states, warmup_actions, warmup_rewards) 35 | 36 | hist = { 37 | "states": warmup_states, 38 | "actions": warmup_actions, 39 | "rewards": warmup_rewards, 40 | } 41 | return bel, hist 42 | 43 | 44 | def run_bandit(key, bel, bandit, env, t_start=0): 45 | step_part = partial(step, key_base=key, bandit=bandit, env=env) 46 | steps = jnp.arange(t_start, env.n_steps) 47 | bel, hist = jax.lax.scan(step_part, bel, steps) 48 | return bel, hist 49 | 50 | 51 | def run_bandit_trials(key, bel, bandit, env, t_start=0, n_trials=1): 52 | keys = jax.random.split(key, n_trials) 53 | run_partal = partial(run_bandit, bel=bel, bandit=bandit, env=env, t_start=t_start) 54 | run_partial = jax.vmap(run_partal) 55 | 56 | bel, hist = run_partial(keys) 57 | return bel, hist 58 | 59 | def run_bandit_trials_pmap(key, bel, bandit, env, t_start=0, n_trials=1): 60 | keys = jax.random.split(key, n_trials) 61 | run_partial = partial(run_bandit, bel=bel, bandit=bandit, env=env, t_start=t_start) 62 | run_partial = jax.pmap(run_partial) 63 | 64 | bel, hist = run_partial(keys) 65 | return bel, hist 66 | 67 | 68 | def run_bandit_trials_multiple(key, bel, bandit, env, t_start, n_trials): 69 | """ 70 | Run vmap over multiple trials, and pmap over multiple devices 71 | """ 72 | ndevices = jax.local_device_count() 73 | nsamples_per_device = n_trials // ndevices 74 | keys = jax.random.split(key, ndevices) 75 | run_partial = partial(run_bandit_trials, bel=bel, bandit=bandit, env=env, t_start=t_start, n_trials=nsamples_per_device) 76 | run_partial = jax.pmap(run_partial) 77 | 78 | bel, hist = run_partial(keys) 79 | hist = jax.tree_map(reshape_vvmap, hist) 80 | bel = jax.tree_map(reshape_vvmap, bel) 81 | 82 | return bel, hist 83 | -------------------------------------------------------------------------------- /bandits/agents/linear_bandit.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import lax 3 | from jax import random 4 | 5 | from tensorflow_probability.substrates import jax as tfp 6 | 7 | tfd = tfp.distributions 8 | 9 | 10 | class LinearBandit: 11 | def __init__(self, num_features, num_arms, eta=6.0, lmbda=0.25): 12 | self.num_features = num_features 13 | self.num_arms = num_arms 14 | self.eta = eta 15 | self.lmbda = lmbda 16 | 17 | def init_bel(self, key, contexts, states, actions, rewards): 18 | mu = jnp.zeros((self.num_arms, self.num_features)) 19 | Sigma = 1. 
/ self.lmbda * jnp.eye(self.num_features) * jnp.ones((self.num_arms, 1, 1)) 20 | a = self.eta * jnp.ones((self.num_arms,)) 21 | b = self.eta * jnp.ones((self.num_arms,)) 22 | 23 | initial_bel = (mu, Sigma, a, b) 24 | 25 | def update(bel, cur): # could do batch update 26 | context, action, reward = cur 27 | bel = self.update_bel(bel, context, action, reward) 28 | return bel, None 29 | 30 | bel, _ = lax.scan(update, initial_bel, (contexts, actions, rewards)) 31 | return bel 32 | 33 | def update_bel(self, bel, context, action, reward): 34 | mu, Sigma, a, b = bel 35 | 36 | mu_k, Sigma_k = mu[action], Sigma[action] 37 | Lambda_k = jnp.linalg.inv(Sigma_k) 38 | a_k, b_k = a[action], b[action] 39 | 40 | # weight params 41 | Lambda_update = jnp.outer(context, context) + Lambda_k 42 | Sigma_update = jnp.linalg.inv(Lambda_update) 43 | mu_update = Sigma_update @ (Lambda_k @ mu_k + context * reward) 44 | # noise params 45 | a_update = a_k + 1 / 2 46 | b_update = b_k + (reward ** 2 + mu_k.T @ Lambda_k @ mu_k - mu_update.T @ Lambda_update @ mu_update) / 2 47 | 48 | # Update only the chosen action at time t 49 | mu = mu.at[action].set(mu_update) 50 | Sigma = Sigma.at[action].set(Sigma_update) 51 | a = a.at[action].set(a_update) 52 | b = b.at[action].set(b_update) 53 | 54 | bel = (mu, Sigma, a, b) 55 | 56 | return bel 57 | 58 | def sample_params(self, key, bel): 59 | mu, Sigma, a, b = bel 60 | 61 | sigma_key, w_key = random.split(key, 2) 62 | sigma2_samp = tfd.InverseGamma(concentration=a, scale=b).sample(seed=sigma_key) 63 | covariance_matrix = sigma2_samp[:, None, None] * Sigma 64 | w = tfd.MultivariateNormalFullCovariance(loc=mu, covariance_matrix=covariance_matrix).sample( 65 | seed=w_key) 66 | return w 67 | 68 | def choose_action(self, key, bel, context): 69 | # Thompson sampling strategy 70 | # Could also use epsilon greedy or UCB 71 | w = self.sample_params(key, bel) 72 | predicted_reward = jnp.einsum("m,km->k", context, w) 73 | action = predicted_reward.argmax() 74 | return action 75 | -------------------------------------------------------------------------------- /bandits/agents/linear_bandit_wide.py: -------------------------------------------------------------------------------- 1 | # linear bandit with a single linear layer applied to a "wide" feature vector phi(s,a). 2 | # So reward = w' * phi(s,a). If the vector is block structured one-hot, in which we put 3 | # phi(s) into slot/block a, then we have w' phi(s,a) = w_a phi(s), which is a standard linaer model. 4 | # For example, suppose phi(s)=[s1,s2] and we have 3 actions. 5 | # Then phi(s,a=1) = [s1 s2 0 0 0 0], phi(s,a=3) = [0 0 0 0 s1 s2]. 6 | # Similarly let w = [w11 w12. w21 w22 w31 w32] where w(i,j) is weight for action i, feature j. 7 | # Then w'phi(s,a=1) = [w11 w12] = w_1. 
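# The widen() method below realizes this encoding: it writes `context` into
# row `action` of a zero (num_arms, num_features) matrix and flattens it, so
# a single normal-inverse-gamma posterior over the flat weight vector is
# shared across all arms.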
8 | 9 | import jax.numpy as jnp 10 | from jax import vmap 11 | from jax.random import split 12 | from jax.nn import one_hot 13 | from jax.lax import scan 14 | 15 | from .agent_utils import NIGupdate 16 | 17 | from tensorflow_probability.substrates import jax as tfp 18 | 19 | tfd = tfp.distributions 20 | 21 | 22 | class LinearBanditWide: 23 | def __init__(self, num_features, num_arms, eta=6.0, lmbda=0.25): 24 | self.num_features = num_features 25 | self.num_arms = num_arms 26 | self.eta = eta 27 | self.lmbda = lmbda 28 | 29 | def widen(self, context, action): 30 | phi = jnp.zeros((self.num_arms, self.num_features)) 31 | phi = phi.at[action].set(context) 32 | return phi.flatten() 33 | 34 | def init_bel(self, key, contexts, states, actions, rewards): 35 | mu = jnp.zeros((self.num_arms * self.num_features)) 36 | Sigma = 1 / self.lmbda * jnp.eye(self.num_features * self.num_arms) 37 | a = self.eta * jnp.ones((self.num_arms * self.num_features,)) 38 | b = self.eta * jnp.ones((self.num_arms * self.num_features,)) 39 | 40 | initial_bel = (mu, Sigma, a, b) 41 | 42 | def update(bel, cur): # could do batch update 43 | phi, reward = cur 44 | bel = NIGupdate(bel, phi, reward) 45 | return bel, None 46 | 47 | phis = vmap(self.widen)(contexts, actions) 48 | bel, _ = scan(update, initial_bel, (phis, rewards)) 49 | return bel 50 | 51 | def update_bel(self, bel, context, action, reward): 52 | phi = self.widen(context, action) 53 | bel = NIGupdate(bel, phi, reward) 54 | return bel 55 | 56 | def sample_params(self, key, bel): 57 | mu, Sigma, a, b = bel 58 | 59 | sigma_key, w_key = split(key, 2) 60 | sigma2_samp = tfd.InverseGamma(concentration=a, scale=b).sample(seed=sigma_key) 61 | covariance_matrix = sigma2_samp * Sigma 62 | w = tfd.MultivariateNormalFullCovariance(loc=mu, covariance_matrix=covariance_matrix).sample( 63 | seed=w_key) 64 | return w 65 | 66 | def choose_action(self, key, bel, context): 67 | w = self.sample_params(key, bel) 68 | 69 | def get_reward(action): 70 | reward = one_hot(action, self.num_arms) 71 | phi = self.widen(context, action) 72 | reward = phi @ w 73 | return reward 74 | 75 | actions = jnp.arange(self.num_arms) 76 | rewards = vmap(get_reward)(actions) 77 | action = jnp.argmax(rewards) 78 | 79 | return action 80 | -------------------------------------------------------------------------------- /bandits/agents/neural_greedy.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax.random import split 4 | from rebayes.sgd_filter import replay_sgd 5 | from bandits.agents.base import BanditAgent 6 | 7 | 8 | class NeuralGreedyBandit(BanditAgent): 9 | def __init__( 10 | self, num_features, num_arms, model, memory_size, tx, 11 | epsilon, n_inner=100 12 | ): 13 | self.num_features = num_features 14 | self.num_arms = num_arms 15 | self.model = model 16 | self.memory_size = memory_size 17 | self.tx = tx 18 | self.epsilon = epsilon 19 | self.n_inner = int(n_inner) 20 | 21 | def init_bel(self, key, contexts, states, actions, rewards): 22 | _, dim_in = contexts.shape 23 | X_dummy = jnp.ones((1, dim_in)) 24 | params = self.model.init(key, X_dummy) 25 | out = self.model.apply(params, X_dummy) 26 | dim_out = out.shape[-1] 27 | 28 | def apply_fn(params, xs): 29 | return self.model.apply(params, xs) 30 | 31 | def predict_rewards(params, contexts): 32 | return self.model.apply(params, contexts) 33 | 34 | agent = replay_sgd.FifoSGD( 35 | lossfn=lossfn_rmse_extra_dim, 36 | apply_fn=apply_fn, 37 | tx=self.tx, 38 | 
buffer_size=self.memory_size, 39 | dim_features=dim_in + 1, # +1 for the action 40 | dim_output=1, 41 | n_inner=self.n_inner 42 | ) 43 | 44 | bel = agent.init_bel(params, None) 45 | self.agent = agent 46 | self.predict_rewards = predict_rewards 47 | 48 | return bel 49 | 50 | def sample_params(self, key, bel): 51 | return bel.params 52 | 53 | def update_bel(self, bel, context, action, reward): 54 | xs = jnp.r_[context, action] 55 | bel = self.agent.update_state(bel, xs, reward) 56 | return bel 57 | 58 | def choose_action(self, key, bel, context): 59 | key, key_action = split(key) 60 | greedy = jax.random.bernoulli(key, 1 - self.epsilon) 61 | 62 | def explore(): 63 | action = jax.random.randint(key_action, shape=(), minval=0, maxval=self.num_arms) 64 | return action 65 | 66 | def exploit(): 67 | params = self.sample_params(key, bel) 68 | predicted_rewards = self.predict_rewards(params, context) 69 | action = predicted_rewards.argmax(axis=-1) 70 | return action 71 | 72 | action = jax.lax.cond(greedy == 1, exploit, explore) 73 | return action 74 | 75 | 76 | def lossfn_rmse_extra_dim(params, counter, xs, y, apply_fn): 77 | """ 78 | Lossfunction for regression problems. 79 | We consider an extra dimension in the input xs, which is the action. 80 | """ 81 | X = xs[..., :-1] 82 | action = xs[..., -1].astype(jnp.int32) 83 | buffer_size = X.shape[0] 84 | ix_slice = jnp.arange(buffer_size) 85 | yhat = apply_fn(params, X)[ix_slice, action].ravel() 86 | y = y.ravel() 87 | err = jnp.power(y - yhat, 2) 88 | loss = (err * counter).sum() / counter.sum() 89 | return loss 90 | -------------------------------------------------------------------------------- /bandits/scripts/run_experiments.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "!pip install -qqq ml-collections" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "tags": [] 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import os\n", 21 | "os.chdir(\"..\")\n", 22 | "\n", 23 | "import jax\n", 24 | "import ml_collections\n", 25 | "\n", 26 | "import pandas as pd\n", 27 | "\n", 28 | "import glob\n", 29 | "from datetime import datetime\n", 30 | "\n", 31 | "import scripts.movielens_exp as movielens_run\n", 32 | "import scripts.mnist_exp as mnist_run\n", 33 | "import scripts.tabular_exp as tabular_run\n", 34 | "import scripts.tabular_subspace_exp as tabular_sub_run\n", 35 | "\n", 36 | "print(jax.device_count())" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "def get_config(filepath):\n", 46 | " \"\"\"Get the default hyperparameter configuration.\"\"\"\n", 47 | " config = ml_collections.ConfigDict()\n", 48 | " config.filepath = filepath\n", 49 | " config.ntrials = 10\n", 50 | " return config" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "timestamp = datetime.timestamp(datetime.now())" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "# Run tabular experiments" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "tabular_filename = f\"./results/tabular_results_{timestamp}.csv\"\n", 76 | "config = 
get_config(tabular_filename)\n", 77 | "tabular_run.main(config)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "# Run MNIST experiments" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "mnist_filename = f\"./results/mnist_results_{timestamp}.csv\"\n", 94 | "config = get_config(mnist_filename)\n", 95 | "mnist_run.main(config)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "# Run movielens experiments" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "movielens_filename = f\"./results/movielens_results_{timestamp}.csv\"\n", 112 | "config = get_config(movielens_filename)\n", 113 | "movielens_run.main(config)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "# Run tabular subspace experiment" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "tabular_sub_filename = f\"./results/tabular_subspace_results_{timestamp}.csv\"\n", 130 | "config = get_config(tabular_sub_filename)\n", 131 | "tabular_sub_run.main(config)" 132 | ] 133 | } 134 | ], 135 | "metadata": { 136 | "interpreter": { 137 | "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" 138 | }, 139 | "kernelspec": { 140 | "display_name": "probml", 141 | "language": "python", 142 | "name": "probml" 143 | }, 144 | "language_info": { 145 | "codemirror_mode": { 146 | "name": "ipython", 147 | "version": 3 148 | }, 149 | "file_extension": ".py", 150 | "mimetype": "text/x-python", 151 | "name": "python", 152 | "nbconvert_exporter": "python", 153 | "pygments_lexer": "ipython3", 154 | "version": "3.9.7" 155 | } 156 | }, 157 | "nbformat": 4, 158 | "nbformat_minor": 4 159 | } 160 | -------------------------------------------------------------------------------- /bandits/scripts/training_utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flax.linen as nn 4 | 5 | import warnings 6 | warnings.filterwarnings("ignore") 7 | 8 | class MLP(nn.Module): 9 | num_arms: int 10 | 11 | @nn.compact 12 | def __call__(self, x): 13 | x = nn.relu(nn.Dense(50, name="last_layer")(x)) 14 | x = nn.Dense(self.num_arms)(x) 15 | return x 16 | 17 | 18 | class MLPWide(nn.Module): 19 | num_arms: int 20 | 21 | @nn.compact 22 | def __call__(self, x): 23 | x = nn.relu(nn.Dense(200)(x)) 24 | x = nn.relu(nn.Dense(200, name="last_layer")(x)) 25 | x = nn.Dense(self.num_arms)(x) 26 | return x 27 | 28 | 29 | class LeNet5(nn.Module): 30 | num_arms: int 31 | 32 | @nn.compact 33 | def __call__(self, x): 34 | x = x if len(x.shape) > 1 else x[None, :] 35 | x = x.reshape((x.shape[0], 28, 28, 1)) 36 | x = nn.Conv(features=6, kernel_size=(5, 5))(x) 37 | x = nn.relu(x) 38 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding="VALID") 39 | x = nn.Conv(features=16, kernel_size=(5, 5), padding="VALID")(x) 40 | x = nn.relu(x) 41 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding="VALID") 42 | x = x.reshape((x.shape[0], -1)) # Flatten 43 | x = nn.Dense(features=120)(x) 44 | x = nn.relu(x) 45 | x = nn.Dense(features=84, name="last_layer")(x) # There are 10 classes in MNIST 46 | x = nn.relu(x) 47 | x = 
nn.Dense(features=self.num_arms)(x) 48 | return x.squeeze() 49 | 50 | 51 | def train(key, bandit_cls, env, npulls, ntrials, bandit_kwargs, neural=True): 52 | # TODO: deprecate neural flag 53 | nsteps, nfeatures = env.contexts.shape 54 | _, narms = env.labels_onehot.shape 55 | bandit = bandit_cls(nfeatures, narms, **bandit_kwargs) 56 | 57 | warmup_contexts, warmup_states, warmup_actions, warmup_rewards = env.warmup(npulls) 58 | 59 | key, mykey = jax.random.split(key) 60 | bel = bandit.init_bel(mykey, warmup_contexts, warmup_states, warmup_actions, warmup_rewards) 61 | warmup = (warmup_contexts, warmup_states, warmup_actions, warmup_rewards) 62 | 63 | def single_trial(key): 64 | _, _, rewards = run_bandit(key, bandit, bel, env, warmup, nsteps=nsteps, neural=neural) 65 | return rewards 66 | 67 | if ntrials > 1: 68 | keys = jax.random.split(key, ntrials) 69 | rewards_trace = jax.vmap(single_trial)(keys) 70 | else: 71 | rewards_trace = single_trial(key) 72 | 73 | return warmup_rewards, rewards_trace, env.opt_rewards 74 | 75 | 76 | def run_bandit(key, bandit, bel, env, warmup, nsteps, neural=True): 77 | def step(bel, cur): 78 | mykey, t = cur 79 | context = env.get_context(t) 80 | 81 | action = bandit.choose_action(mykey, bel, context) 82 | reward = env.get_reward(t, action) 83 | bel = bandit.update_bel(bel, context, action, reward) 84 | 85 | return bel, (context, action, reward) 86 | 87 | warmup_contexts, _, warmup_actions, warmup_rewards = warmup 88 | nwarmup = len(warmup_rewards) 89 | 90 | steps = jnp.arange(nsteps - nwarmup) + nwarmup 91 | keys = jax.random.split(key, nsteps - nwarmup) 92 | 93 | if neural: 94 | bandit.init_contexts_and_states(env.contexts[steps], env.labels_onehot[steps]) 95 | mu, Sigma, a, b, params, _ = bel 96 | bel = (mu, Sigma, a, b, params, 0) 97 | 98 | _, (contexts, actions, rewards) = jax.lax.scan(step, bel, (keys, steps)) 99 | 100 | contexts = jnp.vstack([warmup_contexts, contexts]) 101 | actions = jnp.append(warmup_actions, actions) 102 | rewards = jnp.append(warmup_rewards, rewards) 103 | 104 | return contexts, actions, rewards 105 | 106 | 107 | def summarize_results(warmup_rewards, rewards): 108 | """ 109 | Print a summary of running a Bandit algorithm for a number of runs 110 | """ 111 | warmup_reward = warmup_rewards.sum() 112 | rewards = rewards.sum(axis=-1) 113 | r_mean = rewards.mean() 114 | r_std = rewards.std() 115 | r_total = r_mean + warmup_reward 116 | 117 | print(f"Expected Reward : {r_total:0.2f} ± {r_std:0.2f}") 118 | return r_total, r_std 119 | -------------------------------------------------------------------------------- /bandits/scripts/thompson_sampling_bernoulli.py: -------------------------------------------------------------------------------- 1 | # Resolution of a Multi-Armed Bandit problem 2 | # using Thompson Sampling. 
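# Each arm k keeps a conjugate Beta(alpha_k, beta_k) posterior over its
# success probability. After pulling arm k and observing reward r in {0, 1},
# the update is alpha_k += r and beta_k += 1 - r, leaving the other arms
# untouched (see BetaBernoulliBandits.update below).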
3 | # Author: Gerardo Durán-Martín (@gerdm) 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import matplotlib.pyplot as plt 8 | from jax import random 9 | from jax.nn import one_hot 10 | from jax.scipy.stats import beta 11 | from functools import partial 12 | import matplotlib.animation as animation 13 | 14 | 15 | class BetaBernoulliBandits: 16 | def __init__(self, K): 17 | self.K = K 18 | 19 | def sample(self, key, params): 20 | alphas = params["alpha"] 21 | betas = params["beta"] 22 | params_sample = random.beta(key, alphas, betas) 23 | return params_sample 24 | 25 | def predict_rewards(self, params_sample): 26 | return params_sample 27 | 28 | def update(self, action, params, reward): 29 | alphas = params["alpha"] 30 | betas = params["beta"] 31 | # Update policy distribution 32 | ind_vector = one_hot(action, self.K) 33 | alphas_posterior = alphas + ind_vector * reward 34 | betas_posterior = betas + ind_vector * (1 - reward) 35 | return { 36 | "alpha": alphas_posterior, 37 | "beta": betas_posterior 38 | } 39 | 40 | 41 | def true_reward(key, action, mean_rewards): 42 | reward = random.bernoulli(key, mean_rewards[action]) 43 | return reward 44 | 45 | 46 | def thompson_sampling_step(model_params, key, model, environment): 47 | """ 48 | Context-free implementation of the Thompson sampling algorithm. 49 | This implementation considers a single step 50 | 51 | Parameters 52 | ---------- 53 | model_params: dict 54 | environment: function 55 | key: jax.random.PRNGKey 56 | moidel: instance of a Bandit model 57 | """ 58 | key_sample, key_reward = random.split(key) 59 | params = model.sample(key_sample, model_params) 60 | pred_rewards = model.predict_rewards(params) 61 | action = pred_rewards.argmax() 62 | reward = environment(key_reward, action) 63 | model_params = model.update(action, model_params, reward) 64 | prob_arm = model_params["alpha"] / (model_params["alpha"] + model_params["beta"]) 65 | return model_params, (model_params, prob_arm) 66 | 67 | 68 | if __name__ == "__main__": 69 | T = 200 70 | key = random.PRNGKey(31415) 71 | keys = random.split(key, T) 72 | mean_rewards = jnp.array([0.4, 0.5, 0.2, 0.9]) 73 | K = len(mean_rewards) 74 | bbbandit = BetaBernoulliBandits(mean_rewards) 75 | init_params = {"alpha": jnp.ones(K), 76 | "beta": jnp.ones(K)} 77 | 78 | environment = partial(true_reward, mean_rewards=mean_rewards) 79 | thompson_partial = partial(thompson_sampling_step, 80 | model=BetaBernoulliBandits(K), 81 | environment=environment) 82 | posteriors, (hist, prob_arm_hist) = jax.lax.scan(thompson_partial, init_params, keys) 83 | 84 | p_range = jnp.linspace(0, 1, 100) 85 | bandits_pdf_hist = beta.pdf(p_range[:, None, None], hist["alpha"][None, ...], hist["beta"][None, ...]) 86 | colors = ["orange", "blue", "green", "red"] 87 | colors = [f"tab:{color}" for color in colors] 88 | 89 | _, n_steps, _ = bandits_pdf_hist.shape 90 | fig, ax = plt.subplots(1, 4, figsize=(13, 2)) 91 | filepath = "./bandits.mp4" 92 | 93 | def animate(t): 94 | for k, (axi, color) in enumerate(zip(ax, colors)): 95 | axi.cla() 96 | bandit = bandits_pdf_hist[:, t, k] 97 | axi.plot(p_range, bandit, c=color) 98 | axi.set_xlim(0, 1) 99 | n_pos = hist["alpha"][t, k].item() - 1 100 | n_trials = hist["beta"][t, k].item() + n_pos - 1 101 | axi.set_title(f"t={t+1}\np={mean_rewards[k]:0.2f}\n{n_pos:.0f}/{n_trials:.0f}") 102 | plt.tight_layout() 103 | return ax 104 | 105 | ani = animation.FuncAnimation(fig, animate, frames=n_steps) 106 | ani.save(filepath, dpi=300, bitrate=-1, fps=10) 107 | 108 | plt.plot(prob_arm_hist) 109 | 
plt.legend([f"mean reward: {reward:0.2f}" for reward in mean_rewards], loc="lower right") 110 | plt.savefig("beta-bernoulli-thompson-sampling.pdf") 111 | plt.show() 112 | -------------------------------------------------------------------------------- /bandits/agents/low_rank_filter_bandit.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax.flatten_util import ravel_pytree 4 | from rebayes.low_rank_filter import lofi 5 | from bandits.agents.base import BanditAgent 6 | from rebayes.utils.sampling import sample_dlr_single 7 | from tensorflow_probability.substrates import jax as tfp 8 | 9 | tfd = tfp.distributions 10 | 11 | class LowRankFilterBandit(BanditAgent): 12 | """ 13 | Regression bandit with low-rank filter. 14 | We consider a single neural network with k 15 | outputs corresponding to the k arms. 16 | """ 17 | def __init__(self, num_features, num_arms, model, memory_size, emission_covariance, 18 | initial_covariance, dynamics_weights, dynamics_covariance): 19 | self.num_features = num_features 20 | self.num_arms = num_arms 21 | self.model = model 22 | self.memory_size = memory_size 23 | 24 | self.emission_covariance = emission_covariance 25 | self.initial_covariance = initial_covariance 26 | self.dynamics_weights = dynamics_weights 27 | self.dynamics_covariance = dynamics_covariance 28 | 29 | def init_bel(self, key, contexts, states, actions, rewards): 30 | _, dim_in = contexts.shape 31 | params = self.model.init(key, jnp.ones((1, dim_in))) 32 | flat_params, recfn = ravel_pytree(params) 33 | 34 | def apply_fn(flat_params, xs): 35 | context, action = xs 36 | return self.model.apply(recfn(flat_params), context)[action, None] 37 | 38 | def predict_rewards(flat_params, context): 39 | return self.model.apply(recfn(flat_params), context) 40 | 41 | agent = lofi.RebayesLoFiDiagonal( 42 | dynamics_weights=self.dynamics_weights, 43 | dynamics_covariance=self.dynamics_covariance, 44 | emission_mean_function=apply_fn, 45 | emission_cov_function=lambda m, x: self.emission_covariance, 46 | adaptive_emission_cov=False, 47 | dynamics_covariance_inflation_factor=0.0, 48 | memory_size=self.memory_size, 49 | steady_state=False, 50 | emission_dist=tfd.Normal 51 | ) 52 | bel = agent.init_bel(flat_params, self.initial_covariance) 53 | self.agent = agent 54 | self.predict_rewards = predict_rewards 55 | 56 | return bel 57 | 58 | def sample_params(self, key, bel): 59 | params_samp = self.agent.sample_state(bel, key, 1).ravel() 60 | return params_samp 61 | 62 | def update_bel(self, bel, context, action, reward): 63 | xs = (context, action) 64 | bel = self.agent.update_state(bel, xs, reward) 65 | return bel 66 | 67 | 68 | class LowRankGreedy(LowRankFilterBandit): 69 | """ 70 | Low-rank filter with greedy action selection. 
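    With probability `epsilon` a uniformly random arm is pulled (explore);
    otherwise the arm with the highest predicted reward under the posterior
    mean is chosen (exploit).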
71 |     """
72 |     def __init__(self, num_features, num_arms, model, memory_size, emission_covariance,
73 |                  initial_covariance, dynamics_weights, dynamics_covariance, epsilon):
74 |         super().__init__(num_features, num_arms, model, memory_size, emission_covariance,
75 |                          initial_covariance, dynamics_weights, dynamics_covariance)
76 |         self.epsilon = epsilon
77 |
78 |     def choose_action(self, key, bel, context):
79 |         key, key_action = jax.random.split(key)
80 |         greedy = jax.random.bernoulli(key, 1 - self.epsilon)
81 |
82 |         def explore():
83 |             action = jax.random.randint(key_action, shape=(), minval=0, maxval=self.num_arms)
84 |             return action
85 |
86 |         def exploit():
87 |             params = bel.mean
88 |             predicted_rewards = self.predict_rewards(params, context)
89 |             return predicted_rewards.argmax(axis=-1)
90 |
91 |         # jax.lax.cond keeps both branches traceable under jit / scan
92 |         action = jax.lax.cond(greedy, exploit, explore)
93 |         return action
--------------------------------------------------------------------------------
/bandits/scripts/tabular_subspace_exp.py:
--------------------------------------------------------------------------------
1 | from jax.random import split, PRNGKey
2 |
3 | import optax
4 | import pandas as pd
5 |
6 | import argparse
7 | from time import time
8 |
9 | from environments.tabular_env import TabularEnvironment
10 | from agents.ekf_subspace import SubspaceNeuralBandit
11 |
12 | from .training_utils import train, MLP, summarize_results
13 | from .mnist_exp import mapping, method_ordering
14 |
15 |
16 | def main(config):
17 |     # Tabular datasets
18 |     key = PRNGKey(314)
19 |     key, shuttle_key, covertype_key, adult_key = split(key, 4)
20 |     ntrain = 5000
21 |
22 |     shuttle_env = TabularEnvironment(shuttle_key, ntrain=ntrain, name='statlog', intercept=True)
23 |     covertype_env = TabularEnvironment(covertype_key, ntrain=ntrain, name='covertype', intercept=True)
24 |     adult_env = TabularEnvironment(adult_key, ntrain=ntrain, name='adult', intercept=True)
25 |     environments = {"shuttle": shuttle_env, "covertype": covertype_env, "adult": adult_env}
26 |
27 |     learning_rate = 0.05
28 |     momentum = 0.9
29 |
30 |     # Subspace Neural Bandit with SVD
31 |     npulls, nwarmup = 20, 2000
32 |     observation_noise = 0.0
33 |     prior_noise_variance = 1e-4
34 |     nepochs = 1000
35 |     random_projection = False
36 |
37 |     ekf_sub_svd = {"opt": optax.sgd(learning_rate, momentum), "prior_noise_variance": prior_noise_variance,
38 |                    "nwarmup": nwarmup, "nepochs": nepochs,
39 |                    "observation_noise": observation_noise,
40 |                    "random_projection": random_projection}
41 |
42 |     # Subspace Neural Bandit without SVD (random projection)
43 |     ekf_sub_rnd = ekf_sub_svd.copy()
44 |     ekf_sub_rnd["random_projection"] = True
45 |
46 |     bandits = {"EKF Subspace SVD": {"kwargs": ekf_sub_svd,
47 |                                     "bandit": SubspaceNeuralBandit
48 |                                     },
49 |                "EKF Subspace RND": {"kwargs": ekf_sub_rnd,
50 |                                     "bandit": SubspaceNeuralBandit
51 |                                     }
52 |                }
53 |
54 |     results = []
55 |     subspace_dimensions = [2, 3, 4, 5, 10, 15, 20, 30, 40, 50, 60, 100, 150, 200, 300, 400, 500]
56 |     model_name = "MLP1"
57 |     for env_name, env in environments.items():
58 |         print("Environment : ", env_name)
59 |         num_arms = env.labels_onehot.shape[-1]
60 |         for subspace_dim in subspace_dimensions:
61 |             model = MLP(num_arms)
62 |             for bandit_name, properties in bandits.items():
63 |                 properties["kwargs"]["n_components"] = subspace_dim
64 |                 properties["kwargs"]["model"] = model
65 |                 key, mykey = split(key)
66 |                 print(f"\tBandit : {bandit_name}")
67 |                 start = time()
68 |                 warmup_rewards, rewards_trace, opt_rewards = train(mykey, properties["bandit"], env, npulls,
69 |                                                                    config.ntrials,
70 |                                                                    properties["kwargs"], neural=False)
71 |
72 |                 rtotal, rstd = summarize_results(warmup_rewards, rewards_trace)
73 |                 end = time()
74 |                 print(f"\t\tTime : {end - start}")
75 |                 results.append((env_name, bandit_name, model_name, subspace_dim, end - start, rtotal, rstd))
76 |
77 |     df = pd.DataFrame(results)
78 |     df = df.rename(columns={0: "Dataset", 1: "Method", 2: "Model", 3: "Subspace Dim", 4: "Time", 5: "Reward", 6: "Std"})
79 |
80 |     df["Method"] = df["Method"].apply(lambda v: mapping[v])
81 |
82 |     df["Subspace Dim"] = df['Subspace Dim'].astype(int)
83 |     df["Reward"] = df['Reward'].astype(float)
84 |     df["Time"] = df['Time'].astype(float)
85 |     df["Std"] = df['Std'].astype(float)
86 |
87 |     df["Rank"] = df["Method"].apply(lambda v: method_ordering[v])
88 |     df.to_csv(config.filepath)
89 |
90 |
91 | if __name__ == "__main__":
92 |     parser = argparse.ArgumentParser()
93 |     parser.add_argument('--ntrials', type=int, nargs='?', const=10, default=10)
94 |     filepath = "bandits/results/tabular_subspace_results.csv"
95 |     parser.add_argument('--filepath', type=str, nargs='?', const=filepath, default=filepath)
96 |
97 |     # Parse the arguments
98 |     args = parser.parse_args()
99 |     main(args)
100 |
--------------------------------------------------------------------------------
/bandits/agents/neural_linear_bandit_wide.py:
--------------------------------------------------------------------------------
1 | # reward = w' * phi(s,a; theta), where theta is learned
2 |
3 | import jax.numpy as jnp
4 | from jax import vmap
5 | from jax.random import split
6 | from jax.nn import one_hot
7 | from jax.lax import scan, cond
8 |
9 | import optax
10 |
11 | from flax.training import train_state
12 |
13 | from .agent_utils import NIGupdate, train
14 | from scripts.training_utils import MLP
15 |
16 | from tensorflow_probability.substrates import jax as tfp
17 |
18 | tfd = tfp.distributions
19 |
20 |
21 | class NeuralLinearBanditWide:
22 |     def __init__(self, num_features, num_arms, model=None, opt=optax.adam(learning_rate=1e-2), eta=6.0, lmbda=0.25,
23 |                  update_step_mod=100, batch_size=5000, nepochs=3000):
24 |         self.num_features = num_features
25 |         self.num_arms = num_arms
26 |
27 |         if model is None:
28 |             self.model = MLP(500, num_arms)
29 |         else:
30 |             try:
31 |                 self.model = model()
32 |             except:
33 |                 self.model = model
34 |
35 |         self.opt = opt
36 |         self.eta = eta
37 |         self.lmbda = lmbda
38 |         self.update_step_mod = update_step_mod
39 |         self.batch_size = batch_size
40 |         self.nepochs = nepochs
41 |
42 |     def init_bel(self, key, contexts, states, actions, rewards):
43 |
44 |         key, mykey = split(key)
45 |         initial_params = self.model.init(mykey, jnp.zeros((self.num_features,)))
46 |         initial_train_state = train_state.TrainState.create(apply_fn=self.model.apply, params=initial_params,
47 |                                                             tx=self.opt)
48 |
49 |         mu = jnp.zeros((self.num_arms, 500))
50 |         Sigma = 1 / self.lmbda * jnp.eye(500) * jnp.ones((self.num_arms, 1, 1))  # prior scale 1/lambda, as in neural_linear
51 |         a = self.eta * jnp.ones((self.num_arms,))
52 |         b = self.eta * jnp.ones((self.num_arms,))
53 |         t = 0
54 |
55 |         def update(bel, x):
56 |             context, action, reward = x
57 |             return self.update_bel(bel, context, action, reward), None
58 |
59 |         initial_bel = (mu, Sigma, a, b, initial_train_state, t)
60 |         self.init_contexts_and_states(contexts, states)
61 |         bel, _ = scan(update, initial_bel, (contexts, actions, rewards))
62 |         return bel
63 |
64 |     def featurize(self, params, x, feature_layer="last_layer"):
65 |         _, inter = self.model.apply(params, x, capture_intermediates=True)
66 |         Phi, *_ = inter["intermediates"][feature_layer]["__call__"]
67 |         return Phi.squeeze()
68 |
69 |     def widen(self, context, action):
70 |         # phi(s, a): the context occupies the block of the chosen action
71 |         phi = jnp.zeros((self.num_arms, self.num_features))
72 |         phi = phi.at[action].set(context)  # jax arrays are immutable
73 |         return phi.flatten()
74 |
75 |     def cond_update_params(self, t):
76 |         return (t % self.update_step_mod) == 0
77 |
78 |     def init_contexts_and_states(self, contexts, states):
79 |         self.contexts = contexts
80 |         self.states = states
81 |
82 |     def update_bel(self, bel, context, action, reward):
83 |         _, _, _, _, state, t = bel
84 |         sgd_params = (state, t)
85 |
86 |         # Refresh the network on the data observed so far (same masked loss as neural_linear)
87 |         def loss_fn(params):
88 |             n_samples, *_ = self.contexts.shape
89 |             final_t = cond(t == 0, lambda t: n_samples, lambda t: t.astype(int), t)
90 |             sample_range = (jnp.arange(n_samples) <= t)[:, None]
91 |             pred_reward = self.model.apply(params, self.contexts)
92 |             loss = (optax.l2_loss(pred_reward, self.states) * sample_range).sum() / final_t
93 |             return loss, pred_reward
94 |
95 |         state = cond(self.cond_update_params(t),
96 |                      lambda sgd_params: train(sgd_params[0], loss_fn=loss_fn, nepochs=self.nepochs)[0],
97 |                      lambda sgd_params: sgd_params[0], sgd_params)
98 |
99 |         # Per-arm Normal-Inverse-Gamma update on the last-layer features
100 |         phi = self.featurize(state.params, context)
101 |         lin_bel = NIGupdate(bel, phi, reward)
102 |         bel = (*lin_bel, state, t + 1)
103 |
104 |         return bel
105 |
106 |     def sample_params(self, key, bel):
107 |         mu, Sigma, a, b, _, _ = bel
108 |         sigma_key, w_key = split(key)
109 |         sigma2 = tfd.InverseGamma(concentration=a, scale=b).sample(seed=sigma_key)
110 |         covariance_matrix = sigma2[:, None, None] * Sigma
111 |         w = tfd.MultivariateNormalFullCovariance(loc=mu, covariance_matrix=covariance_matrix).sample(seed=w_key)
112 |         return w
113 |
114 |     def choose_action(self, key, bel, context):
115 |         state = bel[-2]
116 |         feat = self.featurize(state.params, context)
117 |         w = self.sample_params(key, bel)
118 |
119 |         def get_reward(action):
120 |             # w' widen(feat, a) reduces to w[a] @ feat: only the block of the
121 |             # chosen action is non-zero in the widened feature vector
122 |             return w[action] @ feat
123 |
124 |         actions = jnp.arange(self.num_arms)
125 |         rewards = vmap(get_reward)(actions)
126 |         action = jnp.argmax(rewards)
127 |
128 |         return action
--------------------------------------------------------------------------------
/bandits/agents/ekf_orig_diag.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax import jit
3 | from jax.flatten_util import ravel_pytree
4 |
5 | import optax
6 | from flax.training import train_state
7 |
8 | from .agent_utils import train
9 | from scripts.training_utils import MLP
10 | from jsl.nlds.diagonal_extended_kalman_filter import DiagonalExtendedKalmanFilter
11 |
12 | from tensorflow_probability.substrates import jax as tfp
13 |
14 | tfd = tfp.distributions
15 |
16 |
17 | class DiagonalNeuralBandit:
18 |     def __init__(self, num_features, num_arms, model, opt, prior_noise_variance, nwarmup=1000, nepochs=1000,
19 |                  system_noise=0.0, observation_noise=1.0):
20 |         """
21 |         Neural bandit with a diagonal-covariance extended Kalman filter (EKF)
22 |         posterior over the full network parameters.
23 |         Parameters
24 |         ----------
25 |         num_arms: int
26 |             Number of bandit arms / number of actions
27 |         model : flax.nn.Module
28 |             The flax model to be used for the bandit. The only constraint is that the
29 |             last layer should have the same number of outputs as the number of arms.
30 |         opt : optax optimizer
31 |             The optimizer used for the warmup SGD phase.
32 |         prior_noise_variance : float
33 |             Variance of the initial diagonal covariance over the parameters.
34 |         nepochs : int
35 |             The number of epochs to be used for the warmup SGD phase.
36 |         system_noise : float
37 |             Variance of the dynamics noise added at every step.
38 |         observation_noise : float
39 |             Variance of the scalar reward observation noise.
40 |         """
41 |         self.num_features = num_features
42 |         self.num_arms = num_arms
43 |
44 |         if model is None:
45 |             self.model = MLP(500, num_arms)
46 |         else:
47 |             try:
48 |                 self.model = model()
49 |             except:
50 |                 self.model = model
51 |
52 |         self.opt = opt
53 |         self.prior_noise_variance = prior_noise_variance
54 |         self.nwarmup = nwarmup
55 |         self.nepochs = nepochs
56 |         self.system_noise = system_noise
57 |         self.observation_noise = observation_noise
58 |
59 |     def init_bel(self, key, contexts, states, actions, rewards):
60 |         initial_params = self.model.init(key, jnp.ones((1, self.num_features)))["params"]
61 |         initial_train_state = train_state.TrainState.create(apply_fn=self.model.apply, params=initial_params,
62 |                                                             tx=self.opt)
63 |
64 |         def loss_fn(params):
65 |             pred_reward = self.model.apply({"params": params}, contexts)[:, actions.astype(int)]
66 |             loss = optax.l2_loss(pred_reward, states[:, actions.astype(int)]).mean()
67 |             return loss, pred_reward
68 |
69 |         warmup_state, _ = train(initial_train_state, loss_fn=loss_fn, nepochs=self.nepochs)
70 |
71 |         params_full_init, reconstruct_tree_params = ravel_pytree(warmup_state.params)
72 |         nparams = params_full_init.size
73 |
74 |         Q = jnp.ones(nparams) * self.system_noise
75 |         R = self.observation_noise
76 |
77 |         # Center the initial belief at the warmup (SGD) solution
78 |         params_init = params_full_init
79 |         covariance_init = jnp.ones(nparams) * self.prior_noise_variance
80 |
81 |         def predict_rewards(params, context):
82 |             params_tree = reconstruct_tree_params(params)
83 |             outputs = self.model.apply({"params": params_tree}, context)
84 |             return outputs
85 |
86 |         self.predict_rewards = predict_rewards
87 |
88 |         def fz(params):
89 |             return params
90 |
91 |         def fx(params, context, action):
92 |             return predict_rewards(params, context)[action, None]
93 |
94 |         ekf = DiagonalExtendedKalmanFilter(fz, fx, Q, R)
95 |         self.ekf = ekf
96 |         bel = (params_init, covariance_init, 0)
97 |
98 |         return bel
99 |
100 |     def sample_params(self, key, bel):
101 |         params_mean, covariance_diag, t = bel
102 |         # tfd.Normal takes a standard deviation, not a variance
103 |         normal_dist = tfd.Normal(loc=params_mean, scale=jnp.sqrt(covariance_diag))
104 |         params_sample = normal_dist.sample(seed=key)
105 |         return params_sample
106 |
107 |     def update_bel(self, bel, context, action, reward):
108 |         xs = (reward, (context, action))
109 |         bel, _ = jit(self.ekf.filter_step)(bel, xs)
110 |         return bel
111 |
112 |     def choose_action(self, key, bel, context):
113 |         # Thompson sampling strategy
114 |         # Could also use epsilon greedy or UCB
115 |         w = self.sample_params(key, bel)
116 |         predicted_reward = self.predict_rewards(w, context)
117 |         action = predicted_reward.argmax()
118 |         return action
119 |
--------------------------------------------------------------------------------
/bandits/agents/ekf_orig_full.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax import jit
3 | from jax.flatten_util import ravel_pytree
4 |
5 | import optax
6 |
7 | from flax.training import train_state
8 |
9 | from .agent_utils import train
10 | from jsl.nlds.extended_kalman_filter import ExtendedKalmanFilter
11 | from scripts.training_utils import MLP
12 | from tensorflow_probability.substrates import jax as tfp
13 |
14 | tfd = tfp.distributions
15 |
16 |
17 | class EKFNeuralBandit:
18 |     def __init__(self, num_features, num_arms, model, opt, prior_noise_variance, nwarmup=1000, nepochs=1000,
19 |                  system_noise=0.0, observation_noise=1.0):
20 |         """
21 |         Neural bandit with a full-covariance EKF posterior over all network parameters.
22 |         Parameters
23 |         ----------
24 |         num_arms: int
25 |             Number of bandit arms / number of actions
26 |         model : flax.nn.Module
27 |             The flax model to be used for the bandit; the last layer must have as many outputs as arms.
28 |         opt : optax optimizer
29 |             The optimizer used for the warmup SGD phase.
30 |         prior_noise_variance : float
31 |             Variance of the initial covariance over the parameters.
32 |         nepochs : int
33 |             The number of epochs to be used for the warmup SGD phase.
34 |         system_noise : float
35 |             Variance of the dynamics (random-walk) noise.
36 |         observation_noise : float
37 |             Variance of the scalar reward observation noise.
38 |         """
39 |         self.num_features = num_features
40 |         self.num_arms = num_arms
41 |
42 |         if model is None:
43 |             self.model = MLP(500, num_arms)
44 |         else:
45 |             try:
46 |                 self.model = model()
47 |             except:
48 |                 self.model = model
49 |
50 |         self.opt = opt
51 |         self.prior_noise_variance = prior_noise_variance
52 |         self.nwarmup = nwarmup
53 |         self.nepochs = nepochs
54 |         self.system_noise = system_noise
55 |         self.observation_noise = observation_noise
56 |
57 |     def init_bel(self, key, contexts, states, actions, rewards):
58 |         initial_params = self.model.init(key, jnp.ones((1, self.num_features)))["params"]
59 |         initial_train_state = train_state.TrainState.create(apply_fn=self.model.apply, params=initial_params,
60 |                                                             tx=self.opt)
61 |
62 |         def loss_fn(params):
63 |             pred_reward = self.model.apply({"params": params}, contexts)[:, actions.astype(int)]
64 |             loss = optax.l2_loss(pred_reward, states[:, actions.astype(int)]).mean()
65 |             return loss, pred_reward
66 |
67 |         warmup_state, _ = train(initial_train_state, loss_fn=loss_fn, nepochs=self.nepochs)
68 |
69 |         params_full_init, reconstruct_tree_params = ravel_pytree(warmup_state.params)
70 |         nparams = params_full_init.size
71 |
72 |         Q = jnp.eye(nparams) * self.system_noise
73 |         R = jnp.eye(1) * self.observation_noise
74 |
75 |         params_init = params_full_init  # center the initial belief at the warmup solution
76 |         covariance_init = jnp.eye(nparams) * self.prior_noise_variance
77 |
78 |         def predict_rewards(params, context):
79 |             params_tree = reconstruct_tree_params(params)
80 |             outputs = self.model.apply({"params": params_tree}, context)
81 |             return outputs
82 |
83 |         self.predict_rewards = predict_rewards
84 |
85 |         def fz(params):
86 |             return params
87 |
88 |         def fx(params, context, action):
89 |             return predict_rewards(params, context)[action, None]
90 |
91 |         ekf = ExtendedKalmanFilter(fz, fx, Q, R)
92 |         self.ekf = ekf
93 |         bel = (params_init, covariance_init, 0)
94 |
95 |         return bel
96 |
97 |     def sample_params(self, key, bel):
98 |         params_mean, covariance, t = bel
99 |         mv_normal = tfd.MultivariateNormalFullCovariance(loc=params_mean, covariance_matrix=covariance)
100 |         params_sample = mv_normal.sample(seed=key)
101 |         return params_sample
102 |
103 |     def update_bel(self, bel, context, action, reward):
104 |         xs = (reward, (context, action))
105 |         bel, _ = jit(self.ekf.filter_step)(bel, xs)
106 |         return bel
107 |
108 |     def choose_action(self, key, bel, context):
109 |         # Thompson sampling strategy
110 |         # Could also use epsilon greedy or UCB
111 |         w = self.sample_params(key, bel)
112 |         predicted_reward = self.predict_rewards(w, context)
113 |         action = 
predicted_reward.argmax() 114 | return action 115 | -------------------------------------------------------------------------------- /bandits/agents/neural_linear.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import optax 3 | import jax.numpy as jnp 4 | from flax.training.train_state import TrainState 5 | 6 | from .agent_utils import train 7 | from tensorflow_probability.substrates import jax as tfp 8 | 9 | tfd = tfp.distributions 10 | 11 | 12 | class NeuralLinearBandit: 13 | def __init__(self, num_features, num_arms, model=None, opt=optax.adam(learning_rate=1e-2), eta=6.0, lmbda=0.25, 14 | update_step_mod=100, batch_size=5000, nepochs=3000): 15 | self.num_features = num_features 16 | self.num_arms = num_arms 17 | 18 | self.opt = opt 19 | self.eta = eta 20 | self.lmbda = lmbda 21 | self.update_step_mod = update_step_mod 22 | self.batch_size = batch_size 23 | self.nepochs = nepochs 24 | self.model = model 25 | 26 | def init_bel(self, key, contexts, states, actions, rewards): 27 | key, mykey = jax.random.split(key) 28 | xdummy = jnp.zeros((self.num_features)) 29 | initial_params = self.model.init(mykey, xdummy) 30 | initial_train_state = TrainState.create(apply_fn=self.model.apply, params=initial_params, 31 | tx=self.opt) 32 | 33 | n_hidden_last = self.model.apply(initial_params, xdummy, capture_intermediates=True)[1]["intermediates"]["last_layer"]["__call__"][0].shape[0] 34 | mu = jnp.zeros((self.num_arms, n_hidden_last)) 35 | Sigma = 1 / self.lmbda * jnp.eye(n_hidden_last) * jnp.ones((self.num_arms, 1, 1)) 36 | a = self.eta * jnp.ones((self.num_arms,)) 37 | b = self.eta * jnp.ones((self.num_arms,)) 38 | t = 0 39 | 40 | def update(bel, x): 41 | context, action, reward = x 42 | return self.update_bel(bel, context, action, reward), None 43 | 44 | self.contexts = contexts 45 | self.states = states 46 | 47 | initial_bel = (mu, Sigma, a, b, initial_train_state, t) 48 | bel, _ = jax.lax.scan(update, initial_bel, (contexts, actions, rewards)) 49 | return bel 50 | 51 | def featurize(self, params, x, feature_layer="last_layer"): 52 | _, inter = self.model.apply(params, x, capture_intermediates=True) 53 | Phi, *_ = inter["intermediates"][feature_layer]["__call__"] 54 | return Phi.squeeze() 55 | 56 | 57 | def cond_update_params(self, t): 58 | return (t % self.update_step_mod) == 0 59 | 60 | def update_bel(self, bel, context, action, reward): 61 | mu, Sigma, a, b, state, t = bel 62 | 63 | sgd_params = (state, t) 64 | 65 | def loss_fn(params): 66 | n_samples, *_ = self.contexts.shape 67 | final_t = jax.lax.cond(t == 0, lambda t: n_samples, lambda t: t.astype(int), t) 68 | sample_range = (jnp.arange(n_samples) <= t)[:, None] 69 | 70 | pred_reward = self.model.apply(params, self.contexts) 71 | loss = (optax.l2_loss(pred_reward, self.states) * sample_range).sum() / final_t 72 | return loss, pred_reward 73 | 74 | state = jax.lax.cond(self.cond_update_params(t), 75 | lambda sgd_params: train(sgd_params[0], loss_fn=loss_fn, nepochs=self.nepochs)[0], 76 | lambda sgd_params: sgd_params[0], sgd_params) 77 | 78 | transformed_context = self.featurize(state.params, context) 79 | 80 | mu_k, Sigma_k = mu[action], Sigma[action] 81 | Lambda_k = jnp.linalg.inv(Sigma_k) 82 | a_k, b_k = a[action], b[action] 83 | 84 | # weight params 85 | Lambda_update = jnp.outer(transformed_context, transformed_context) + Lambda_k 86 | Sigma_update = jnp.linalg.inv(Lambda_update) 87 | mu_update = Sigma_update @ (Lambda_k @ mu_k + transformed_context * reward) 88 | 89 | # noise 
params 90 | a_update = a_k + 1 / 2 91 | b_update = b_k + (reward ** 2 + mu_k.T @ Lambda_k @ mu_k - mu_update.T @ Lambda_update @ mu_update) / 2 92 | 93 | # update only the chosen action at time t 94 | mu = mu.at[action].set(mu_update) 95 | Sigma = Sigma.at[action].set(Sigma_update) 96 | a = a.at[action].set(a_update) 97 | b = b.at[action].set(b_update) 98 | t = t + 1 99 | 100 | bel = (mu, Sigma, a, b, state, t) 101 | return bel 102 | 103 | def sample_params(self, key, bel): 104 | mu, Sigma, a, b, _, _ = bel 105 | sigma_key, w_key = jax.random.split(key) 106 | sigma2 = tfd.InverseGamma(concentration=a, scale=b).sample(seed=sigma_key) 107 | covariance_matrix = sigma2[:, None, None] * Sigma 108 | w = tfd.MultivariateNormalFullCovariance(loc=mu, covariance_matrix=covariance_matrix).sample(seed=w_key) 109 | return w 110 | 111 | def choose_action(self, key, bel, context): 112 | # Thompson sampling strategy 113 | # Could also use epsilon greedy or UCB 114 | state = bel[-2] 115 | context_transformed = self.featurize(state.params, context) 116 | w = self.sample_params(key, bel) 117 | predicted_reward = jnp.einsum("m,km->k", context_transformed, w) 118 | action = predicted_reward.argmax() 119 | return action 120 | -------------------------------------------------------------------------------- /bandits/scripts/plot_results.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import pandas as pd 3 | import seaborn as sns 4 | import matplotlib.pyplot as plt 5 | 6 | def plot_figure(data, x, y, filename, figsize=(24, 9), log_scale=False): 7 | sns.set(font_scale=1.5) 8 | plt.style.use("seaborn-poster") 9 | 10 | fig, ax = plt.subplots(figsize=figsize, dpi=300) 11 | g = sns.barplot(x=x, y=y, hue="Method", data=data, errwidth=2, ax=ax, palette=colors) 12 | if log_scale: 13 | g.set_yscale("log") 14 | plt.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0)) 15 | plt.tight_layout() 16 | plt.savefig(f"./bandits/figures/{filename}.png") 17 | plt.show() 18 | 19 | def read_data(dataset_name): 20 | *_, filename = sorted(glob.glob(f"./bandits/results/{dataset_name}_results*.csv")) 21 | df = pd.read_csv(filename) 22 | if dataset_name=="mnist": 23 | linear_df = df[(df["Method"]=="Lin-KF") | (df["Method"]=="Lin")].copy() 24 | linear_df["Model"] = "MLP2" 25 | df = df.append(linear_df) 26 | linear_df["Model"] = "LeNet5" 27 | df = df.append(linear_df) 28 | 29 | by = ["Rank"] if dataset_name=="tabular" else ["Rank", "AltRank"] 30 | 31 | data_up = df.sort_values(by=by).copy() 32 | data_down = df.sort_values(by=by).copy() 33 | 34 | data_up["Reward"] = data_up["Reward"] + data_up["Std"] 35 | data_down["Reward"] = data_down["Reward"] - data_down["Std"] 36 | data = pd.concat([data_up, data_down]) 37 | return data 38 | 39 | def plot_subspace_figure(df, filename=None): 40 | df = df.reset_index().drop(columns=["index"]) 41 | plt.style.use("seaborn-darkgrid") 42 | fig, ax = plt.subplots(figsize=(12, 8)) 43 | sns.lineplot(x="Subspace Dim", y="Reward", hue="Method", marker="o", data=df) 44 | lines, labels = ax.get_legend_handles_labels() 45 | for line, method in zip(lines, labels): 46 | data = df[df["Method"]==method] 47 | color = line.get_c() 48 | y_lower_bound = data["Reward"] - data["Std"] 49 | y_upper_bound = data["Reward"] + data["Std"] 50 | ax.fill_between(data["Subspace Dim"], y_lower_bound, y_upper_bound, color=color, alpha=0.3) 51 | 52 | ax.set_ylabel("Reward", fontsize=16) 53 | plt.setp(ax.get_xticklabels(), fontsize=16) 54 | plt.setp(ax.get_yticklabels(), fontsize=16) 55 | 
ax.set_xlabel("Subspace Dimension(d)", fontsize=16) 56 | dataset = df.iloc[0]["Dataset"] 57 | ax.set_title(f"{dataset.title()} - Subspace Dim vs. Reward", fontsize=18) 58 | legend = ax.legend(loc="lower right", prop={'size': 16},frameon=1) 59 | frame = legend.get_frame() 60 | frame.set_color('white') 61 | frame.set_alpha(0.6) 62 | 63 | file_path = "./bandits/figures/" 64 | file_path = file_path + f"{dataset}_sub_reward.png" if filename is None else file_path + f"{filename}.png" 65 | plt.savefig(file_path) 66 | 67 | method_ordering = {"EKF-Sub-SVD": 0, 68 | "EKF-Sub-RND": 1, 69 | "EKF-Sub-Diag-SVD": 2, 70 | "EKF-Sub-Diag-RND": 3, 71 | "EKF-Orig-Full": 4, 72 | "EKF-Orig-Diag": 5, 73 | "NL-Lim": 6, 74 | "NL-Unlim": 7, 75 | "Lin": 8, 76 | "Lin-KF": 9, 77 | "Lin-Wide": 9, 78 | "Lim2": 10, 79 | "NeuralTS": 11} 80 | 81 | colors = {k : sns.color_palette("Paired")[v] 82 | if k!="Lin-KF" else sns.color_palette("tab20")[8] 83 | for k,v in method_ordering.items()} 84 | 85 | dataset_info = { 86 | "mnist": { 87 | "elements": ["EKF-Sub-SVD", "EKF-Sub-RND", "EKF-Sub-Diag-SVD", 88 | "EKF-Sub-Diag-RND", "EKF-Orig-Diag", "NL-Lim", 89 | "NL-Unlim", "Lin"], 90 | "x": "Model", 91 | }, 92 | 93 | "tabular": { 94 | "elements": ["EKF-Sub-SVD", "EKF-Sub-RND", "EKF-Sub-Diag-SVD", 95 | "EKF-Sub-Diag-RND", "EKF-Orig-Diag", "NL-Lim", 96 | "NL-Unlim", "Lin"], 97 | "x": "Dataset" 98 | }, 99 | 100 | "movielens": { 101 | "elements": ["EKF-Sub-SVD", "EKF-Sub-RND", "EKF-Sub-Diag-SVD", 102 | "EKF-Sub-Diag-RND", "EKF-Orig-Diag", "NL-Lim", 103 | "NL-Unlim", "Lin"], 104 | "x": "Model" 105 | }, 106 | } 107 | 108 | 109 | plot_configs = [ 110 | {"metric": "Reward", "log_scale":False}, 111 | {"metric": "Time", "log_scale":True}, 112 | 113 | ] 114 | 115 | 116 | def main(): 117 | # Create reward / runnnig time experiments 118 | print("Plotting reward / running time") 119 | for dataset_name in dataset_info: 120 | print(dataset_name) 121 | info = dataset_info[dataset_name] 122 | methods = info["elements"] 123 | x = info["x"] 124 | 125 | df = read_data(dataset_name) 126 | df = df[df["Method"].isin(methods)] 127 | 128 | for config in plot_configs: 129 | metric = config["metric"] 130 | use_log_scale = config["log_scale"] 131 | 132 | filename = f"{dataset_name}_{metric.lower()}" 133 | plot_figure(df, x, metric, filename, log_scale=use_log_scale) 134 | 135 | 136 | # Plot subspace-dim v.s. reward 137 | print("Plotting subspace dim v.s. 
reward") 138 | *_, filename = sorted(glob.glob(f"./bandits/results/tabular_subspace_results*.csv")) 139 | tabular_sub_df = pd.read_csv(filename) 140 | 141 | datasets = ["shuttle", "adult", "covertype"] 142 | for dataset_name in datasets: 143 | print(dataset_name) 144 | subdf = tabular_sub_df.query("Dataset == @dataset_name") 145 | plot_subspace_figure(subdf) 146 | 147 | 148 | if __name__ == "__main__": 149 | main() -------------------------------------------------------------------------------- /demos/lofi_tabular.py: -------------------------------------------------------------------------------- 1 | """ 2 | In this demo, we evaluate the performance of the 3 | lofi bandit on a tabular dataset 4 | """ 5 | import jax 6 | import optax 7 | import pickle 8 | import numpy as np 9 | import flax.linen as nn 10 | import jax.numpy as jnp 11 | from time import time 12 | from datetime import datetime 13 | from bayes_opt import BayesianOptimization 14 | from bandits import training as btrain 15 | from bandits.agents.low_rank_filter_bandit import LowRankFilterBandit 16 | from bandits.environments.tabular_env import TabularEnvironment 17 | 18 | class MLP(nn.Module): 19 | num_arms: int 20 | 21 | @nn.compact 22 | def __call__(self, x): 23 | # x = nn.Dense(50)(x) 24 | # x = nn.relu(x) 25 | x = nn.Dense(50, name="last_layer")(x) 26 | x = nn.relu(x) 27 | x = nn.Dense(self.num_arms)(x) 28 | return x 29 | 30 | 31 | def warmup_and_run(eval_hparams, transform_fn, bandit_cls, env, key, npulls, n_trials=1, **kwargs): 32 | n_devices = jax.local_device_count() 33 | key_warmup, key_train = jax.random.split(key, 2) 34 | hparams = transform_fn(eval_hparams) 35 | hparams = {**hparams, **kwargs} 36 | 37 | bandit = bandit_cls(env.n_features, env.n_arms, **hparams) 38 | 39 | bel, hist_warmup = btrain.warmup_bandit(key_warmup, bandit, env, npulls) 40 | time_init = time() 41 | if n_trials == 1: 42 | bel, hist_train = btrain.run_bandit(key_train, bel, bandit, env, t_start=npulls) 43 | elif 1 < n_trials <= n_devices: 44 | bel, hist_train = btrain.run_bandit_trials_pmap(key_train, bel, bandit, env, t_start=npulls, n_trials=n_trials) 45 | elif n_trials > n_devices: 46 | bel, hist_train = btrain.run_bandit_trials_multiple(key_train, bel, bandit, env, t_start=npulls, n_trials=n_trials) 47 | time_end = time() 48 | total_time = time_end - time_init 49 | 50 | 51 | res = { 52 | "hist_warmup": hist_warmup, 53 | "hist_train": hist_train, 54 | } 55 | # res = jax.tree_map(np.array, res) 56 | res["total_time"] = total_time 57 | 58 | return res 59 | 60 | 61 | def transform_hparams_subspace_fixed(hparams): 62 | emission_covariance = jnp.exp(hparams["log_em_cov"]) 63 | initial_covariance = jnp.exp(hparams["log_init_cov"]) 64 | warmup_learning_rate = jnp.exp(hparams["log_warmup_lr"]) 65 | 66 | hparams = { 67 | "observation_noise": emission_covariance, 68 | "prior_noise_variance": initial_covariance, 69 | "opt": warmup_learning_rate, 70 | } 71 | 72 | return hparams 73 | 74 | 75 | def transform_hparams_lofi(hparams): 76 | emission_covariance = jnp.exp(hparams["log_em_cov"]) 77 | initial_covariance = jnp.exp(hparams["log_init_cov"]) 78 | dynamics_weights = 1 - jnp.exp(hparams["log_1m_dweights"]) 79 | dynamics_covariance = jnp.exp(hparams["log_dcov"]) 80 | 81 | hparams = { 82 | "emission_covariance": emission_covariance, 83 | "initial_covariance": initial_covariance, 84 | "dynamics_weights": dynamics_weights, 85 | "dynamics_covariance": dynamics_covariance, 86 | } 87 | return hparams 88 | 89 | 90 | def transform_hparams_lofi_fixed(hparams): 91 | 
""" 92 | Transformation assuming that the dynamicss weights 93 | and dynamics covariance are static 94 | """ 95 | emission_covariance = jnp.exp(hparams["log_em_cov"]) 96 | initial_covariance = jnp.exp(hparams["log_init_cov"]) 97 | 98 | hparams = { 99 | "emission_covariance": emission_covariance, 100 | "initial_covariance": initial_covariance, 101 | } 102 | return hparams 103 | 104 | 105 | def transform_hparams_linear(hparams): 106 | eta = hparams["eta"] 107 | lmbda = jnp.exp(hparams["log_lambda"]) 108 | hparams = { 109 | "eta": eta, 110 | "lmbda": lmbda, 111 | } 112 | return hparams 113 | 114 | 115 | def transform_hparams_neural_linear(hparams): 116 | lr = jnp.exp(hparams["log_lr"]) 117 | eta = hparams["eta"] 118 | lmbda = jnp.exp(hparams["log_lambda"]) 119 | opt = optax.adam(lr) 120 | 121 | hparams = { 122 | "lmbda": lmbda, 123 | "eta": eta, 124 | "opt": opt, 125 | } 126 | return hparams 127 | 128 | def transform_hparams_rsgd(hparams): 129 | lr = jnp.exp(hparams["log_lr"]) 130 | tx = optax.adam(lr) 131 | hparams = { 132 | "tx": tx, 133 | } 134 | return hparams 135 | 136 | if __name__ == "__main__": 137 | ntrials = 10 138 | npulls = 20 139 | key = jax.random.PRNGKey(314) 140 | key_env, key_warmup, key_train = jax.random.split(key, 3) 141 | ntrain = 500 # 5000 142 | env = TabularEnvironment(key_env, ntrain=ntrain, name='statlog', intercept=False, path="./bandit-data") 143 | num_arms = env.labels_onehot.shape[-1] 144 | model = MLP(num_arms) 145 | 146 | kwargs_lofi = { 147 | "emission_covariance": 0.01, 148 | "initial_covariance": 1.0, 149 | "dynamics_weights": 1.0, 150 | "dynamics_covariance": 0.0, 151 | "memory_size": 10, 152 | "model": model, 153 | } 154 | 155 | n_features = env.n_features 156 | n_arms = env.n_arms 157 | bandit = LowRankFilterBandit(n_features, n_arms, **kwargs_lofi) 158 | 159 | bel, hist_warmup = btrain.warmup_bandit(key_warmup, bandit, env, npulls) 160 | bel, hist_train = btrain.run_bandit_trials(key_train, bel, bandit, env, t_start=npulls, n_trials=ntrials) 161 | 162 | res = { 163 | "hist_warmup": hist_warmup, 164 | "hist_train": hist_train, 165 | } 166 | res = jax.tree_map(np.array, res) 167 | 168 | # Store results 169 | datestr = datetime.now().strftime("%Y%m%d%H%M%S") 170 | path_to_results = f"./results/lofi_{datestr}.pkl" 171 | with open(path_to_results, "wb") as f: 172 | pickle.dump(res, f) 173 | print(f"Results stored in {path_to_results}") 174 | -------------------------------------------------------------------------------- /bandits/agents/ekf_subspace.py: -------------------------------------------------------------------------------- 1 | import optax 2 | import jax.numpy as jnp 3 | from jax import jit, device_put 4 | from jax.random import split 5 | from jax.flatten_util import ravel_pytree 6 | from flax.training import train_state 7 | from sklearn.decomposition import PCA 8 | from .agent_utils import train, generate_random_basis, convert_params_from_subspace_to_full 9 | from scripts.training_utils import MLP 10 | from jsl.nlds.extended_kalman_filter import ExtendedKalmanFilter 11 | from tensorflow_probability.substrates import jax as tfp 12 | 13 | tfd = tfp.distributions 14 | 15 | 16 | class SubspaceNeuralBandit: 17 | def __init__(self, num_features, num_arms, model, opt, prior_noise_variance, nwarmup=1000, nepochs=1000, 18 | system_noise=0.0, observation_noise=1.0, n_components=0.9999, random_projection=False): 19 | """ 20 | Subspace Neural Bandit implementation. 
21 |         Parameters
22 |         ----------
23 |         num_arms: int
24 |             Number of bandit arms / number of actions
25 |         model : flax.nn.Module
26 |             The flax model to be used for the bandit; the last layer must have as many outputs as arms.
27 |         opt : optax optimizer
28 |             The optimizer to be used for the warmup SGD phase.
29 |         prior_noise_variance : float
30 |             Variance of the initial covariance over the subspace coordinates.
31 |         nepochs : int
32 |             The number of epochs to be used for the warmup SGD phase.
33 |         system_noise : float
34 |             Variance of the dynamics noise in the subspace.
35 |         observation_noise : float
36 |             Variance of the scalar reward observation noise.
37 |         n_components : int or float
38 |             Subspace dimension (int), or explained-variance ratio kept by the SVD (float).
39 |         """
40 |         self.num_features = num_features
41 |         self.num_arms = num_arms
42 |
43 |         # TODO: deprecate hard-coded MLP
44 |         if model is None:
45 |             self.model = MLP(500, num_arms)
46 |         else:
47 |             try:
48 |                 self.model = model()
49 |             except:
50 |                 self.model = model
51 |
52 |         self.opt = opt
53 |         self.prior_noise_variance = prior_noise_variance
54 |         self.nwarmup = nwarmup
55 |         self.nepochs = nepochs
56 |         self.system_noise = system_noise
57 |         self.observation_noise = observation_noise
58 |         self.n_components = n_components
59 |         self.random_projection = random_projection
60 |
61 |     def init_bel(self, key, contexts, states, actions, rewards):
62 |         warmup_key, projection_key = split(key, 2)
63 |         initial_params = self.model.init(warmup_key, jnp.ones((1, self.num_features)))["params"]
64 |         initial_train_state = train_state.TrainState.create(apply_fn=self.model.apply, params=initial_params,
65 |                                                             tx=self.opt)
66 |
67 |         def loss_fn(params):
68 |             pred_reward = self.model.apply({"params": params}, contexts)[:, actions.astype(int)]
69 |             loss = optax.l2_loss(pred_reward, states[:, actions.astype(int)]).mean()
70 |             return loss, pred_reward
71 |
72 |         warmup_state, warmup_metrics = train(initial_train_state, loss_fn=loss_fn, nepochs=self.nepochs)
73 |
74 |         thinned_samples = warmup_metrics["params"][::2]
75 |         params_trace = thinned_samples[-self.nwarmup:]
76 |
77 |         if not self.random_projection:
78 |             pca = PCA(n_components=self.n_components)
79 |             pca.fit(params_trace)
80 |             subspace_dim = pca.n_components_
81 |             self.n_components = pca.n_components_
82 |             projection_matrix = device_put(pca.components_)
83 |         else:
84 |             if not isinstance(self.n_components, int):
85 |                 raise ValueError(f"n_components must be an integer, got {self.n_components}")
86 |             total_dim = params_trace.shape[-1]
87 |             subspace_dim = self.n_components
88 |             projection_matrix = generate_random_basis(projection_key, subspace_dim, total_dim)
89 |
90 |         Q = jnp.eye(subspace_dim) * self.system_noise
91 |         R = jnp.eye(1) * self.observation_noise
92 |
93 |         params_full_init, reconstruct_tree_params = ravel_pytree(warmup_state.params)
94 |         params_subspace_init = jnp.zeros(subspace_dim)
95 |         covariance_subspace_init = jnp.eye(subspace_dim) * self.prior_noise_variance
96 |
97 |         def predict_rewards(params_subspace_sample, context):
98 |             params = convert_params_from_subspace_to_full(params_subspace_sample, projection_matrix, params_full_init)
99 |             params = reconstruct_tree_params(params)
100 |             outputs = self.model.apply({"params": params}, context)
101 |             return outputs
102 |
103 |         self.predict_rewards = predict_rewards
104 |
105 |         def fz(params):
106 |             return params
107 |
108 |         def fx(params, context, action):
109 |             return predict_rewards(params, context)[action, None]
110 |
111 |         ekf = ExtendedKalmanFilter(fz, fx, Q, R)
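        # NOTE: fx exposes only the chosen arm's scalar prediction, so each EKF
        # update conditions the subspace belief on the single observed reward.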
112 | self.ekf = ekf 113 | 114 | bel = (params_subspace_init, covariance_subspace_init, 0) 115 | return bel 116 | 117 | def sample_params(self, key, bel): 118 | params_subspace, covariance_subspace, t = bel 119 | mv_normal = tfd.MultivariateNormalFullCovariance(loc=params_subspace, covariance_matrix=covariance_subspace) 120 | params_subspace = mv_normal.sample(seed=key) 121 | return params_subspace 122 | 123 | def update_bel(self, bel, context, action, reward): 124 | xs = (reward, (context, action)) 125 | bel, _ = jit(self.ekf.filter_step)(bel, xs) 126 | return bel 127 | 128 | def choose_action(self, key, bel, context): 129 | # Thompson sampling strategy 130 | # Could also use epsilon greedy or UCB 131 | w = self.sample_params(key, bel) 132 | predicted_reward = self.predict_rewards(w, context) 133 | action = predicted_reward.argmax() 134 | return action 135 | -------------------------------------------------------------------------------- /bandits/scripts/movielens_exp.py: -------------------------------------------------------------------------------- 1 | from jax.random import PRNGKey 2 | 3 | import optax 4 | import pandas as pd 5 | 6 | import argparse 7 | from time import time 8 | 9 | from environments.movielens_env import MovielensEnvironment 10 | 11 | from agents.linear_bandit import LinearBandit 12 | from agents.linear_kf_bandit import LinearKFBandit 13 | from agents.ekf_subspace import SubspaceNeuralBandit 14 | from agents.ekf_orig_diag import DiagonalNeuralBandit 15 | from agents.diagonal_subspace import DiagonalSubspaceNeuralBandit 16 | from agents.limited_memory_neural_linear import LimitedMemoryNeuralLinearBandit 17 | 18 | from .training_utils import train, MLP, MLPWide 19 | from .mnist_exp import mapping, rank, summarize_results, method_ordering 20 | 21 | 22 | def main(config): 23 | eta = 6.0 24 | lmbda = 0.25 25 | 26 | learning_rate = 0.01 27 | momentum = 0.9 28 | 29 | update_step_mod = 100 30 | buffer_size = 50 31 | nepochs = 100 32 | 33 | # Neural Linear Limited 34 | nl_lim = {"buffer_size": buffer_size, "opt": optax.sgd(learning_rate, momentum), "eta": eta, "lmbda": lmbda, 35 | "update_step_mod": update_step_mod, "nepochs": nepochs} 36 | 37 | # Neural Linear Unlimited 38 | buffer_size = 4800 39 | 40 | nl_unlim = nl_lim.copy() 41 | nl_unlim["buffer_size"] = buffer_size 42 | 43 | npulls, nwarmup = 2, 2000 44 | learning_rate, momentum = 0.8, 0.9 45 | observation_noise = 0.0 46 | prior_noise_variance = 1e-4 47 | n_components = 470 48 | nepochs = 1000 49 | random_projection = False 50 | 51 | # Subspace Neural with SVD 52 | ekf_sub_svd = {"opt": optax.sgd(learning_rate, momentum), "prior_noise_variance": prior_noise_variance, 53 | "nwarmup": nwarmup, "nepochs": nepochs, 54 | "observation_noise": observation_noise, "n_components": n_components, 55 | "random_projection": random_projection} 56 | 57 | # Subspace Neural without SVD 58 | ekf_sub_rnd = ekf_sub_svd.copy() 59 | ekf_sub_rnd["random_projection"] = True 60 | 61 | system_noise = 0.0 62 | 63 | ekf_orig = {"opt": optax.sgd(learning_rate, momentum), "prior_noise_variance": prior_noise_variance, 64 | "nwarmup": nwarmup, "nepochs": nepochs, 65 | "system_noise": system_noise, "observation_noise": observation_noise} 66 | linear = {} 67 | 68 | bandits = {"Linear": {"kwargs": linear, 69 | "bandit": LinearBandit 70 | }, 71 | "Linear KF": {"kwargs": linear.copy(), 72 | "bandit": LinearKFBandit 73 | }, 74 | "Limited Neural Linear": {"kwargs": nl_lim, 75 | "bandit": LimitedMemoryNeuralLinearBandit 76 | }, 77 | "Unlimited Neural Linear": 
{"kwargs": nl_unlim, 78 | "bandit": LimitedMemoryNeuralLinearBandit 79 | }, 80 | "EKF Subspace SVD": {"kwargs": ekf_sub_svd, 81 | "bandit": SubspaceNeuralBandit 82 | }, 83 | "EKF Subspace RND": {"kwargs": ekf_sub_rnd, 84 | "bandit": SubspaceNeuralBandit 85 | }, 86 | "EKF Diagonal Subspace SVD": {"kwargs": ekf_sub_svd.copy(), 87 | "bandit": DiagonalSubspaceNeuralBandit 88 | }, 89 | "EKF Diagonal Subspace RND": {"kwargs": ekf_sub_rnd.copy(), 90 | "bandit": DiagonalSubspaceNeuralBandit 91 | }, 92 | "EKF Orig Diagonal": {"kwargs": ekf_orig, 93 | "bandit": DiagonalNeuralBandit 94 | } 95 | } 96 | 97 | results = [] 98 | repeats = [1] 99 | for repeat in repeats: 100 | key = PRNGKey(0) 101 | # Create the environment beforehand 102 | movielens = MovielensEnvironment(key, repeat=repeat) 103 | # Number of different digits 104 | num_arms = movielens.labels_onehot.shape[-1] 105 | models = {"MLP1": MLP(num_arms), "MLP2": MLPWide(num_arms)} 106 | 107 | for model_name, model in models.items(): 108 | for bandit_name, properties in bandits.items(): 109 | if not bandit_name.startswith("Linear"): 110 | properties["kwargs"]["model"] = model 111 | print(f"Bandit : {bandit_name}") 112 | key = PRNGKey(314) 113 | start = time() 114 | warmup_rewards, rewards_trace, opt_rewards = train(key, properties["bandit"], movielens, npulls, 115 | config.ntrials, properties["kwargs"], neural=False) 116 | rtotal, rstd = summarize_results(warmup_rewards, rewards_trace, spacing="\t") 117 | end = time() 118 | print(f"\tTime : {end - start}:0.3f") 119 | results.append((bandit_name, model_name, end - start, rtotal, rstd)) 120 | 121 | df = pd.DataFrame(results) 122 | df = df.rename(columns={0: "Method", 1: "Model", 2: "Time", 3: "Reward", 4: "Std"}) 123 | df["Method"] = df["Method"].apply(lambda v: mapping[v]) 124 | df["Rank"] = df["Method"].apply(lambda v: method_ordering[v]) 125 | df["AltRank"] = df["Model"].apply(lambda v: rank[v]) 126 | 127 | df["Reward"] = df['Reward'].astype(float) 128 | df["Time"] = df['Time'].astype(float) 129 | df["Std"] = df['Std'].astype(float) 130 | df.to_csv(config.filepath) 131 | 132 | 133 | if __name__ == "__main__": 134 | parser = argparse.ArgumentParser() 135 | parser.add_argument('--ntrials', type=int, nargs='?', const=10, default=10) 136 | filepath = "bandits/results/movielens_results.csv" 137 | parser.add_argument('--filepath', type=str, nargs='?', const=filepath, default=filepath) 138 | 139 | # Parse the argument 140 | args = parser.parse_args() 141 | main(args) 142 | -------------------------------------------------------------------------------- /aistats2022-slides/slides.md: -------------------------------------------------------------------------------- 1 | --- 2 | # try also 'default' to start simple 3 | theme: default 4 | # random image from a curated Unsplash collection by Anthony 5 | # like them? 
6 | background: https://images.unsplash.com/photo-1620460700571-320445215efb?crop=entropy&cs=tinysrgb&fit=crop&fm=jpg&h=1080&ixid=MnwxfDB8MXxyYW5kb218MHw5NDczNDU2Nnx8fHx8fHwxNjQ1MjY3NTc2&ixlib=rb-1.2.1&q=80&utm_campaign=api-credit&utm_medium=referral&utm_source=unsplash_source&w=1920
7 | # apply any windi css classes to the current slide
8 | ---
9 |
10 | # Subspace Neural Bandits
11 | ### AIStats 2022
12 |
13 | Gerardo Duran-Martin, Queen Mary University of London, UK
14 | Aleyna Kara, Boğaziçi University, Turkey
15 | Kevin Murphy, Google Research, Brain Team
16 |
17 | February 2022
18 |
19 | ---
20 |
21 | # Contextual bandits
22 | ### [Li et al. (2012)](https://arxiv.org/abs/1003.0146)
23 |
24 | Let $\mathcal{A} = \{a^{(1)}, \ldots, a^{(K)}\}$ be a set of actions. At every time step $t=1,\ldots,T$:
25 | 1. we are given a context ${\bf s}_t$
26 | 2. we decide, based on ${\bf s}_t$, an action $a_t \in \mathcal{A}$
27 | 3. we obtain a reward $r_t$ based on the context ${\bf s}_t$ and the chosen action $a_t$
28 |
29 | Our goal is to choose the sequence of actions that maximises the expected cumulative reward $\sum_{t=1}^T\mathbb{E}[R_t]$.
30 |
31 | ---
32 |
33 | # Thompson Sampling
34 |
35 | ### [Agrawal and Goyal (2014)](https://arxiv.org/abs/1209.3352), [Russo et al. (2014)](https://arxiv.org/abs/1402.0298)
36 | Let $\mathcal{D}_t = (s_t, a_t, r_t)$ denote the observation at time $t$, and let $\mathcal{D}_{1:t} = \{\mathcal{D}_1, \ldots, \mathcal{D}_t\}$.
37 |
38 | At every time step $t=1,\ldots, T$, we proceed as follows:
39 | 1. Sample $\boldsymbol\theta_t \sim p(\cdot \vert \mathcal{D}_{1:t-1})$
40 | 2. $a_t = \arg\max_{a \in \mathcal{A}} \mathbb{E}[R(s_t,a; \boldsymbol\theta_t)]$
41 | 3. Obtain $r_t \sim R(s_t,a_t; \boldsymbol\theta_t)$
42 | 4. Store $\mathcal{D}_t = (s_t, a_t, r_t)$
43 |
44 | *Example*: Beta-Bernoulli bandit with $K=4$ arms.
45 |
46 |
49 |
50 |
51 | ---
52 |
53 | # Neural Bandits
54 | ### Characterising the reward function
55 |
56 | Let $f: \mathcal{S}\times\mathcal{A}\times\mathbb{R}^D \to \mathbb{R}^K$ be a neural network. A neural bandit is a contextual bandit where the reward is taken to be
57 |
58 | $$
59 | r_t \vert {\bf s}_t, a, \boldsymbol\theta_t \sim \mathcal{N}\Big(f({\bf s}_t, a, \boldsymbol\theta_t), \sigma^2\Big)
60 | $$
61 |
62 |
63 | The main question: how to determine $\boldsymbol\theta_t$ at every time step $t$ using Thompson sampling?
64 |
65 | We need to compute (or approximate) the posterior distribution over the parameters of the neural network:
66 |
67 | $$
68 | \begin{aligned}
69 | p(\boldsymbol\theta \vert \mathcal{D}_{1:t}) &= p(\boldsymbol\theta \vert \mathcal{D}_{1:t-1}, \mathcal{D}_t)\\
70 | &\propto p(\boldsymbol\theta \vert \mathcal{D}_{1:t-1}) p(\mathcal{D}_t \vert \boldsymbol\theta) \\
71 | \end{aligned}
72 | $$
73 |
74 | ---
75 |
76 | # Subspace neural bandits
77 | ## Motivation
78 |
79 | * Current state-of-the-art solutions, although efficient, are not fully Bayesian:
80 |   1. Neural linear approximation
81 |   2. Lim2 approximation
82 |   3. Neural tangent approximation
83 |
84 |
85 |
86 | * Fully Bayesian solutions are computationally expensive:
87 |
88 |   1. Hamiltonian Monte Carlo (HMC) sampling of posterior beliefs
89 |
90 |   2. Extended Kalman filter (EKF) online estimation of posterior beliefs
91 | * We seek to solve the contextual-neural-bandit problem in a way that is **fully Bayesian** and **computationally efficient**.
92 |
93 |
94 |
95 | ---
96 |
97 | # Extended Kalman filter and neural networks
98 | ### [Singhal and Wu (1988)](https://proceedings.neurips.cc/paper/1988/hash/38b3eff8baf56627478ec76a704e9b52-Abstract.html)
99 | Online learning of neural network parameters
100 |
101 | $$
102 | \begin{aligned}
103 | \boldsymbol\theta_t \vert \boldsymbol\theta_{t-1} &\sim \mathcal{N}(\boldsymbol\theta_{t-1}, \tau^2 {\bf I}) \\
104 | r_t \vert \boldsymbol\theta_t &\sim \mathcal{N}(f({\bf s}_t, a_t, \boldsymbol\theta_t), \sigma^2 {\bf I})
105 | \end{aligned}
106 | $$
107 |
108 |
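For reference, a standard EKF step for this model (a sketch; ${\bf H}_t$ is the Jacobian of $f$ with respect to $\boldsymbol\theta$, and $\boldsymbol\mu_t, \boldsymbol\Sigma_t$ are the posterior mean and covariance):

$$
\begin{aligned}
{\bf H}_t &= \nabla_{\boldsymbol\theta} f({\bf s}_t, a_t, \boldsymbol\theta) \big\vert_{\boldsymbol\theta = \boldsymbol\mu_{t-1}} \\
\bar{\boldsymbol\Sigma}_t &= \boldsymbol\Sigma_{t-1} + \tau^2 {\bf I} \\
{\bf K}_t &= \bar{\boldsymbol\Sigma}_t {\bf H}_t^\top \big({\bf H}_t \bar{\boldsymbol\Sigma}_t {\bf H}_t^\top + \sigma^2 {\bf I}\big)^{-1} \\
\boldsymbol\mu_t &= \boldsymbol\mu_{t-1} + {\bf K}_t \big(r_t - f({\bf s}_t, a_t, \boldsymbol\mu_{t-1})\big) \\
\boldsymbol\Sigma_t &= ({\bf I} - {\bf K}_t {\bf H}_t) \bar{\boldsymbol\Sigma}_t
\end{aligned}
$$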
109 |
110 |
113 |
114 | ---
115 |
116 | # Neural networks (in a subspace)
117 | ### [Li et al. (2018)](https://arxiv.org/abs/1804.08838), [Larsen et al. (2021)](https://arxiv.org/abs/2107.05802)
118 | Stochastic gradient descent (SGD) for neural networks **in a subspace**:
119 | neural networks can be trained within a low-dimensional linear subspace.
120 |
121 | $$
122 | \boldsymbol\theta({\bf z}_t) = {\bf A z}_{t} + \boldsymbol\theta_*
123 | $$
124 |
125 |
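Here ${\bf A} \in \mathbb{R}^{D \times d}$ (with $d \ll D$) maps the subspace coordinates ${\bf z} \in \mathbb{R}^d$ into the full parameter space, and $\boldsymbol\theta_*$ is a fixed offset, so by the chain rule SGD in the subspace only needs the projected gradient:

$$
\nabla_{\bf z} \mathcal{L}\big(\boldsymbol\theta({\bf z})\big) = {\bf A}^\top \nabla_{\boldsymbol\theta} \mathcal{L}(\boldsymbol\theta) \big\vert_{\boldsymbol\theta = \boldsymbol\theta({\bf z})}
$$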
126 |
127 |
128 |
129 | ---
130 |
131 | # Our contribution: subspace neural bandits
132 | ### Extended Kalman filter and neural networks in a subspace.
133 | We run the EKF over the low-dimensional subspace coordinates ${\bf z} \in \mathbb{R}^d$:
134 |
135 | $$
136 | \begin{aligned}
137 | \boldsymbol\theta({\bf z}_t) &= {\bf A z}_{t} + \boldsymbol\theta_* \\
138 | {\bf z}_t \vert {\bf z}_{t-1} &\sim \mathcal{N}({\bf z}_{t-1}, \tau^2 {\bf I}) \\
139 | {\bf r}_t \vert {\bf z}_t &\sim \mathcal{N}\Big(f({\bf s}_t, a_t, \boldsymbol\theta({\bf z}_t)), \sigma^2 {\bf I}\Big)
140 | \end{aligned}
141 | $$
142 |
143 |
144 |
145 | ---
146 |
147 | # Results
148 | ### MNIST: cumulative reward
149 |
150 | Classification-turned-bandit problem.
151 | Maximum reward is $T=5000$ (total number of samples).
152 |
153 |
154 |
155 | 156 | 157 | 158 | --- 159 | 160 | # Results 161 | ### MNIST: running time 162 | 163 |
164 | 165 |
166 | 167 | --- 168 | 169 | # Results 170 | ### Effect of subspace dimensionality 171 | We seek to make $d$ as small as possible. 172 | 173 |
174 | 175 | 176 | 177 | 178 | --- 179 | 180 | # Subspace neural bandits 181 | ### probml.github.io/bandits 182 | 183 |
184 |

📑 Paper

185 |

💻 Github repo

186 |
187 | -------------------------------------------------------------------------------- /bandits/agents/limited_memory_neural_linear.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import jit 3 | from jax.random import split 4 | from jax.lax import scan, cond 5 | from jax.nn import one_hot 6 | 7 | import optax 8 | 9 | from flax.training import train_state 10 | 11 | from .agent_utils import train 12 | from scripts.training_utils import MLP 13 | from tensorflow_probability.substrates import jax as tfp 14 | 15 | tfd = tfp.distributions 16 | 17 | 18 | class LimitedMemoryNeuralLinearBandit: 19 | """ 20 | Neural-linear bandit on a buffer. We train the model in the warmup 21 | phase considering all of the datapoints. After the warmup phase, we 22 | train from the rest of the dataset considering only a fixed number 23 | of datapoints to train on. 24 | """ 25 | 26 | def __init__(self, num_features, num_arms, buffer_size, model=None, opt=optax.adam(learning_rate=1e-2), eta=6.0, 27 | lmbda=0.25, 28 | update_step_mod=100, nepochs=3000): 29 | 30 | self.num_features = num_features 31 | self.num_arms = num_arms 32 | 33 | if model is None: 34 | self.model = MLP(500, num_arms) 35 | else: 36 | try: 37 | self.model = model() 38 | except: 39 | self.model = model 40 | 41 | self.opt = opt 42 | self.eta = eta 43 | self.lmbda = lmbda 44 | self.update_step_mod = update_step_mod 45 | self.nepochs = nepochs 46 | self.buffer_size = buffer_size 47 | self.buffer_indexer = jnp.arange(self.buffer_size) 48 | 49 | def init_bel(self, key, contexts, states, actions, rewards): 50 | """ 51 | Initialize the multi-armed bandit model by training the model on the warmup phase 52 | doing a round-robin of the actions. 
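        The warmup transitions are then replayed through `update_bel` via
        `lax.scan`, which fills the circular buffers and forms the initial
        posterior over the last-layer weights.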
53 | """ 54 | 55 | # Initialize feature matrix 56 | nsamples, nfeatures = contexts.shape 57 | initial_params = self.model.init(key, jnp.ones(nfeatures)) 58 | 59 | num_features_last_layer = initial_params["params"]["last_layer"]["bias"].size 60 | mu = jnp.zeros((self.num_arms, num_features_last_layer)) 61 | Sigma = 1 / self.lmbda * jnp.eye(num_features_last_layer) * jnp.ones((self.num_arms, 1, 1)) 62 | a = self.eta * jnp.ones((self.num_arms,)) 63 | b = self.eta * jnp.ones((self.num_arms,)) 64 | initial_train_state = train_state.TrainState.create(apply_fn=self.model.apply, params=initial_params, 65 | tx=self.opt) 66 | t = 0 67 | 68 | context_buffer = jnp.zeros((self.buffer_size, nfeatures)) 69 | reward_buffer = jnp.zeros(self.buffer_size) 70 | action_buffer = -jnp.ones(self.buffer_size) 71 | buffer_ix = 0 72 | 73 | def update(bel, x): 74 | context, action, reward = x 75 | return self.update_bel(bel, context, action, reward), None 76 | 77 | buffer = (context_buffer, reward_buffer, action_buffer, buffer_ix) 78 | 79 | initial_bel = (mu, Sigma, a, b, initial_train_state, t, buffer) 80 | 81 | bel, _ = scan(update, initial_bel, (contexts, actions, rewards)) 82 | 83 | return bel 84 | 85 | def _update_buffer(self, buffer, new_item, index): 86 | """ 87 | source: https://github.com/google/jax/issues/4590 88 | """ 89 | buffer = buffer.at[index].set(new_item) 90 | index = (index + 1) % self.buffer_size 91 | return buffer, index 92 | 93 | def cond_update_params(self, t): 94 | cond1 = (t % self.update_step_mod) == 0 95 | cond2 = t > 0 96 | return cond1 * cond2 97 | 98 | def featurize(self, params, x, feature_layer="last_layer"): 99 | _, inter = self.model.apply(params, x, capture_intermediates=True) 100 | Phi, *_ = inter["intermediates"][feature_layer]["__call__"] 101 | return Phi.squeeze() 102 | 103 | def update_bel(self, bel, context, action, reward): 104 | mu, Sigma, a, b, state, t, buffer = bel 105 | context_buffer, reward_buffer, action_buffer, buffer_ix = buffer 106 | 107 | update_buffer = jit(self._update_buffer) 108 | context_buffer, _ = update_buffer(context_buffer, context, buffer_ix) 109 | reward_buffer, _ = update_buffer(reward_buffer, reward, buffer_ix) 110 | action_buffer, buffer_ix = update_buffer(action_buffer, action, buffer_ix) 111 | 112 | Y_buffer = one_hot(action_buffer, self.num_arms) * reward_buffer[:, None] 113 | 114 | num_elements = jnp.minimum(self.buffer_size, t) 115 | valmap = self.buffer_indexer <= num_elements.astype(float) 116 | valmap = valmap[:, None] 117 | 118 | @jit 119 | def loss_fn(params): 120 | pred_reward = self.model.apply(params, context_buffer) 121 | loss = jnp.where(valmap, optax.l2_loss(pred_reward, Y_buffer), 0.0) 122 | loss = loss.sum() / num_elements 123 | return loss 124 | 125 | state = cond(self.cond_update_params(t), 126 | lambda s: train(s, loss_fn=loss_fn, nepochs=self.nepochs, has_aux=False)[0], 127 | lambda s: s, state) 128 | 129 | transformed_context = self.featurize(state.params, context) 130 | 131 | mu_k, Sigma_k = mu[action], Sigma[action] 132 | Lambda_k = jnp.linalg.inv(Sigma_k) 133 | a_k, b_k = a[action], b[action] 134 | 135 | # weight params 136 | Lambda_update = jnp.outer(transformed_context, transformed_context) + Lambda_k 137 | Sigma_update = jnp.linalg.inv(Lambda_update) 138 | mu_update = Sigma_update @ (Lambda_k @ mu_k + transformed_context * reward) 139 | 140 | # noise params 141 | a_update = a_k + 1 / 2 142 | b_update = b_k + (reward ** 2 + mu_k.T @ Lambda_k @ mu_k - mu_update.T @ Lambda_update @ mu_update) / 2 143 | 144 | # update only 
the chosen action at time t 145 | mu = mu.at[action].set(mu_update) 146 | Sigma = Sigma.at[action].set(Sigma_update) 147 | a = a.at[action].set(a_update) 148 | b = b.at[action].set(b_update) 149 | t = t + 1 150 | 151 | buffer = (context_buffer, reward_buffer, action_buffer, buffer_ix) 152 | 153 | bel = (mu, Sigma, a, b, state, t, buffer) 154 | 155 | return bel 156 | 157 | def sample_params(self, key, bel): 158 | mu, Sigma, a, b, _, _, _ = bel 159 | sigma_key, w_key = split(key) 160 | sigma2 = tfd.InverseGamma(concentration=a, scale=b).sample(seed=sigma_key) 161 | covariance_matrix = sigma2[:, None, None] * Sigma 162 | w = tfd.MultivariateNormalFullCovariance(loc=mu, covariance_matrix=covariance_matrix).sample(seed=w_key) 163 | return w 164 | 165 | def choose_action(self, key, bel, context): 166 | # Thompson sampling strategy 167 | # Could also use epsilon greedy or UCB 168 | state = bel[-3] 169 | context_transformed = self.featurize(state.params, context) 170 | w = self.sample_params(key, bel) 171 | predicted_reward = jnp.einsum("m,km->k", context_transformed, w) 172 | action = predicted_reward.argmax() 173 | return action 174 | -------------------------------------------------------------------------------- /bandit-data/ml-100k/README.txt: -------------------------------------------------------------------------------- 1 | SUMMARY & USAGE LICENSE 2 | ============================================= 3 | 4 | MovieLens data sets were collected by the GroupLens Research Project 5 | at the University of Minnesota. 6 | 7 | This data set consists of: 8 | * 100,000 ratings (1-5) from 943 users on 1682 movies. 9 | * Each user has rated at least 20 movies. 10 | * Simple demographic info for the users (age, gender, occupation, zip) 11 | 12 | The data was collected through the MovieLens web site 13 | (movielens.umn.edu) during the seven-month period from September 19th, 14 | 1997 through April 22nd, 1998. This data has been cleaned up - users 15 | who had less than 20 ratings or did not have complete demographic 16 | information were removed from this data set. Detailed descriptions of 17 | the data file can be found at the end of this file. 18 | 19 | Neither the University of Minnesota nor any of the researchers 20 | involved can guarantee the correctness of the data, its suitability 21 | for any particular purpose, or the validity of results based on the 22 | use of the data set. The data set may be used for any research 23 | purposes under the following conditions: 24 | 25 | * The user may not state or imply any endorsement from the 26 | University of Minnesota or the GroupLens Research Group. 27 | 28 | * The user must acknowledge the use of the data set in 29 | publications resulting from the use of the data set 30 | (see below for citation information). 31 | 32 | * The user may not redistribute the data without separate 33 | permission. 34 | 35 | * The user may not use this information for any commercial or 36 | revenue-bearing purposes without first obtaining permission 37 | from a faculty member of the GroupLens Research Project at the 38 | University of Minnesota. 39 | 40 | If you have any further questions or comments, please contact GroupLens 41 | . 42 | 43 | CITATION 44 | ============================================== 45 | 46 | To acknowledge use of the dataset in publications, please cite the 47 | following paper: 48 | 49 | F. Maxwell Harper and Joseph A. Konstan. 2015. The MovieLens Datasets: 50 | History and Context. 
ACM Transactions on Interactive Intelligent
51 | Systems (TiiS) 5, 4, Article 19 (December 2015), 19 pages.
52 | DOI=http://dx.doi.org/10.1145/2827872
53 |
54 |
55 | ACKNOWLEDGEMENTS
56 | ==============================================
57 |
58 | Thanks to Al Borchers for cleaning up this data and writing the
59 | accompanying scripts.
60 |
61 | PUBLISHED WORK THAT HAS USED THIS DATASET
62 | ==============================================
63 |
64 | Herlocker, J., Konstan, J., Borchers, A., Riedl, J. An Algorithmic
65 | Framework for Performing Collaborative Filtering. Proceedings of the
66 | 1999 Conference on Research and Development in Information
67 | Retrieval. Aug. 1999.
68 |
69 | FURTHER INFORMATION ABOUT THE GROUPLENS RESEARCH PROJECT
70 | ==============================================
71 |
72 | The GroupLens Research Project is a research group in the Department
73 | of Computer Science and Engineering at the University of Minnesota.
74 | Members of the GroupLens Research Project are involved in many
75 | research projects related to the fields of information filtering,
76 | collaborative filtering, and recommender systems. The project is led
77 | by professors John Riedl and Joseph Konstan. The project began to
78 | explore automated collaborative filtering in 1992, but is most well
79 | known for its world wide trial of an automated collaborative filtering
80 | system for Usenet news in 1996. The technology developed in the
81 | Usenet trial formed the base for the formation of Net Perceptions,
82 | Inc., which was founded by members of GroupLens Research. Since then
83 | the project has expanded its scope to research overall information
84 | filtering solutions, integrating content-based methods as well as
85 | improving current collaborative filtering technology.
86 |
87 | Further information on the GroupLens Research project, including
88 | research publications, can be found at the following web site:
89 |
90 | http://www.grouplens.org/
91 |
92 | GroupLens Research currently operates a movie recommender based on
93 | collaborative filtering:
94 |
95 | http://www.movielens.org/
96 |
97 | DETAILED DESCRIPTIONS OF DATA FILES
98 | ==============================================
99 |
100 | Here are brief descriptions of the data.
101 |
102 | ml-data.tar.gz -- Compressed tar file. To rebuild the u data files do this:
103 | gunzip ml-data.tar.gz
104 | tar xvf ml-data.tar
105 | mku.sh
106 |
107 | u.data -- The full u data set, 100000 ratings by 943 users on 1682 items.
108 | Each user has rated at least 20 movies. Users and items are
109 | numbered consecutively from 1. The data is randomly
110 | ordered. This is a tab separated list of
111 | user id | item id | rating | timestamp.
112 | The time stamps are unix seconds since 1/1/1970 UTC
113 |
114 | u.info -- The number of users, items, and ratings in the u data set.
115 |
116 | u.item -- Information about the items (movies); this is a tab separated
117 | list of
118 | movie id | movie title | release date | video release date |
119 | IMDb URL | unknown | Action | Adventure | Animation |
120 | Children's | Comedy | Crime | Documentary | Drama | Fantasy |
121 | Film-Noir | Horror | Musical | Mystery | Romance | Sci-Fi |
122 | Thriller | War | Western |
123 | The last 19 fields are the genres, a 1 indicates the movie
124 | is of that genre, a 0 indicates it is not; movies can be in
125 | several genres at once.
126 | The movie ids are the ones used in the u.data data set.
127 |
128 | u.genre -- A list of the genres.
129 |
130 | u.user -- Demographic information about the users; this is a tab
131 | separated list of
132 | user id | age | gender | occupation | zip code
133 | The user ids are the ones used in the u.data data set.
134 |
135 | u.occupation -- A list of the occupations.
136 |
137 | u1.base -- The data sets u1.base and u1.test through u5.base and u5.test
138 | u1.test are 80%/20% splits of the u data into training and test data.
139 | u2.base Each of u1, ..., u5 has disjoint test sets; this is for
140 | u2.test 5-fold cross validation (where you repeat your experiment
141 | u3.base with each training and test set and average the results).
142 | u3.test These data sets can be generated from u.data by mku.sh.
143 | u4.base
144 | u4.test
145 | u5.base
146 | u5.test
147 |
148 | ua.base -- The data sets ua.base, ua.test, ub.base, and ub.test
149 | ua.test split the u data into a training set and a test set with
150 | ub.base exactly 10 ratings per user in the test set. The sets
151 | ub.test ua.test and ub.test are disjoint. These data sets can
152 | be generated from u.data by mku.sh.
153 |
154 | allbut.pl -- The script that generates training and test sets where
155 | all but n of a user's ratings are in the training data.
156 |
157 | mku.sh -- A shell script to generate all the u data sets from u.data.
158 | -------------------------------------------------------------------------------- /bandits/environments/tabular_env.py: --------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax.random import split, permutation
3 | from jax.nn import one_hot
4 |
5 | import numpy as np
6 | import pandas as pd
7 |
8 | import pickle
9 | import requests
10 | import io
11 |
12 | from sklearn.preprocessing import OneHotEncoder, normalize
13 | from sklearn.datasets import fetch_openml
14 |
15 | from .environment import BanditEnvironment
16 |
17 |
18 | # https://github.com/ofirnabati/Neural-Linear-Bandits-with-Likelihood-Matching/blob/85b541c225ec453bbf54650291616759e28b59d5/bandits/data/data_sampler.py#L528
19 | def safe_std(values):
20 | """Replace zero std values with ones."""
21 | return np.array([val if val != 0.0 else 1.0 for val in values])
22 |
23 |
24 | def read_file_from_url(name):
25 | if name == 'adult':
26 | url = "https://raw.githubusercontent.com/probml/probml-data/main/data/adult.data"
27 | elif name == 'covertype':
28 | url = "https://raw.githubusercontent.com/probml/probml-data/main/data/covtype.data"
29 | else:
30 | url = "https://raw.githubusercontent.com/probml/probml-data/main/data/shuttle.trn"
31 |
32 | download = requests.get(url).content
33 | file = io.StringIO(download.decode('utf-8'))
34 | return file
35 |
36 |
37 | # https://github.com/ofirnabati/Neural-Linear-Bandits-with-Likelihood-Matching/blob/85b541c225ec453bbf54650291616759e28b59d5/bandits/data/data_sampler.py#L528
38 | def classification_to_bandit_problem(X, y, narms=None):
39 | """Normalize contexts and encode deterministic rewards."""
40 |
41 | if narms is None:
42 | narms = np.max(y) + 1
43 |
44 | ntrain = X.shape[0]
45 |
46 | # Due to random subsampling in small problems, some features may be constant
47 | sstd = safe_std(np.std(X, axis=0, keepdims=True)[0, :])
48 |
49 | # Normalize features
50 | X = ((X - np.mean(X, axis=0, keepdims=True)) / sstd)
51 |
52 | # One hot encode labels as rewards
53 | y = one_hot(y, narms)
54 |
55 | opt_rewards = np.ones((ntrain,))
56 |
57 | return X, y, opt_rewards
58 |
59 |
60 | # 
https://github.com/ofirnabati/Neural-Linear-Bandits-with-Likelihood-Matching/blob/85b541c225ec453bbf54650291616759e28b59d5/bandits/data/data_sampler.py#L165 61 | def sample_shuttle_data(): 62 | """Returns bandit problem dataset based on the UCI statlog data. 63 | Returns: 64 | dataset: Sampled matrix with rows: (X, y) 65 | opt_vals: Vector of deterministic optimal (reward, action) for each context. 66 | https://archive.ics.uci.edu/ml/datasets/Statlog+(Shuttle) 67 | """ 68 | file = read_file_from_url("shuttle") 69 | data = np.loadtxt(file) 70 | 71 | narms = 7 # some of the actions are very rarely optimal. 72 | 73 | # Last column is label, rest are features 74 | X = data[:, :-1] 75 | y = data[:, -1].astype(int) - 1 # convert to 0 based index 76 | 77 | return classification_to_bandit_problem(X, y, narms=narms) 78 | 79 | 80 | # https://github.com/ofirnabati/Neural-Linear-Bandits-with-Likelihood-Matching/blob/85b541c225ec453bbf54650291616759e28b59d5/bandits/data/data_sampler.py#L165 81 | def sample_adult_data(): 82 | """Returns bandit problem dataset based on the UCI adult data. 83 | Returns: 84 | dataset: Sampled matrix with rows: (X, y) 85 | opt_vals: Vector of deterministic optimal (reward, action) for each context. 86 | Preprocessing: 87 | * drop rows with missing values 88 | * convert categorical variables to 1 hot encoding 89 | https://archive.ics.uci.edu/ml/datasets/census+income 90 | """ 91 | file = read_file_from_url("adult") 92 | df = pd.read_csv(file, header=None, na_values=[' ?']).dropna() 93 | 94 | narms = 2 95 | 96 | y = df[14].astype('str') 97 | df = df.drop([14, 6], axis=1) 98 | 99 | y = y.str.replace('.', '') 100 | y = y.astype('category').cat.codes.to_numpy() 101 | 102 | # Convert categorical variables to 1 hot encoding 103 | cols_to_transform = [1, 3, 5, 7, 8, 9, 13] 104 | df = pd.get_dummies(df, columns=cols_to_transform) 105 | X = df.to_numpy() 106 | 107 | return classification_to_bandit_problem(X, y, narms=narms) 108 | 109 | 110 | # https://github.com/ofirnabati/Neural-Linear-Bandits-with-Likelihood-Matching/blob/85b541c225ec453bbf54650291616759e28b59d5/bandits/data/data_sampler.py#L165 111 | def sample_covertype_data(): 112 | """Returns bandit problem dataset based on the UCI Cover_Type data. 113 | Returns: 114 | dataset: Sampled matrix with rows: (X, y) 115 | opt_vals: Vector of deterministic optimal (reward, action) for each context. 116 | Preprocessing: 117 | * drop rows with missing labels 118 | * convert categorical variables to 1 hot encoding 119 | https://archive.ics.uci.edu/ml/datasets/Covertype 120 | """ 121 | 122 | file = read_file_from_url("covertype") 123 | df = pd.read_csv(file, header=None, na_values=[' ?']).dropna() 124 | 125 | narms = 7 126 | 127 | # Assuming what the paper calls response variable is the label? 128 | # Last column is label. 
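# cat.codes below maps the seven cover classes to 0-based integer labels,
# which is what classification_to_bandit_problem above expects before it
# one-hot encodes the rewards.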
129 | y = df[df.columns[-1]].astype('category').cat.codes.to_numpy() 130 | df = df.drop([df.columns[-1]], axis=1) 131 | 132 | X = df.to_numpy() 133 | 134 | return classification_to_bandit_problem(X, y, narms=narms) 135 | 136 | 137 | def get_tabular_data_from_url(name): 138 | if name == 'adult': 139 | return sample_adult_data() 140 | elif name == 'covertype': 141 | return sample_covertype_data() 142 | elif name == 'statlog': 143 | return sample_shuttle_data() 144 | else: 145 | raise RuntimeError('Dataset does not exist') 146 | 147 | 148 | def get_tabular_data_from_openml(name): 149 | if name == 'adult': 150 | X, y = fetch_openml('adult', version=2, return_X_y=True, as_frame=False) 151 | elif name == 'covertype': 152 | X, y = fetch_openml('covertype', version=3, return_X_y=True, as_frame=False) 153 | elif name == 'statlog': 154 | X, y = fetch_openml('shuttle', version=1, return_X_y=True, as_frame=False) 155 | else: 156 | raise RuntimeError('Dataset does not exist') 157 | 158 | X[np.isnan(X)] = - 1 159 | X = normalize(X) 160 | 161 | # generate one_hot coding: 162 | y = OneHotEncoder(sparse=False).fit_transform(y.reshape((-1, 1))) 163 | 164 | opt_rewards = jnp.ones((len(X),)) 165 | 166 | return X, y, opt_rewards 167 | 168 | 169 | def get_tabular_data_from_pkl(name, path): 170 | with open(f"{path}/bandit-{name}.pkl", "rb") as f: 171 | sampled_vals = pickle.load(f) 172 | 173 | contexts, opt_rewards, (*_, actions) = sampled_vals 174 | contexts = jnp.c_[jnp.ones_like(contexts[:, :1]), contexts] 175 | narms = len(jnp.unique(actions)) 176 | actions = one_hot(actions, narms) 177 | return contexts, actions, opt_rewards 178 | 179 | 180 | def TabularEnvironment(key, name, ntrain=0, intercept=True, load_from="pkl", path="./bandit-data"): 181 | """ 182 | Parameters 183 | ---------- 184 | key: jax.random.PRNGKey 185 | Random number generator key. 186 | name: str 187 | One of ['adult', 'covertype', 'statlog']. 
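ntrain: int
    Number of datapoints to keep; the full dataset is used if ntrain is 0
    or exceeds the number of rows.
intercept: bool
    If True, prepend a constant column of ones to the contexts. (Note that
    the pkl loader already prepends one itself.)
load_from: str
    One of ['pkl', 'openml', 'url'].
path: str
    Directory containing the bandit-{name}.pkl files (only used when
    load_from == 'pkl').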
188 | """ 189 | if load_from == "url": 190 | X, y, opt_rewards = get_tabular_data_from_openml(name) 191 | elif load_from == "openml": 192 | X, y, opt_rewards = get_tabular_data_from_url(name) 193 | elif load_from == "pkl": 194 | X, y, opt_rewards = get_tabular_data_from_pkl(name, path) 195 | else: 196 | raise ValueError('load_from must be equal to pkl, openml or url.') 197 | 198 | ntrain = ntrain if ntrain < len(X) and ntrain > 0 else len(X) 199 | X, y = jnp.float32(X)[:ntrain], jnp.float32(y)[:ntrain] 200 | 201 | if intercept: 202 | X = jnp.hstack([jnp.ones_like(X[:, :1]), X]) 203 | 204 | return BanditEnvironment(key, X, y, opt_rewards) 205 | -------------------------------------------------------------------------------- /bandits/scripts/mnist_exp.py: -------------------------------------------------------------------------------- 1 | import optax 2 | from jax.random import PRNGKey 3 | 4 | import argparse 5 | from time import time 6 | 7 | import pandas as pd 8 | 9 | from environments.mnist_env import MnistEnvironment 10 | 11 | from agents.linear_bandit import LinearBandit 12 | from agents.linear_kf_bandit import LinearKFBandit 13 | from agents.ekf_subspace import SubspaceNeuralBandit 14 | from agents.ekf_orig_diag import DiagonalNeuralBandit 15 | from agents.diagonal_subspace import DiagonalSubspaceNeuralBandit 16 | from agents.limited_memory_neural_linear import LimitedMemoryNeuralLinearBandit 17 | from agents.low_rank_filter_bandit import LowRankFilterBandit 18 | 19 | from .training_utils import train, MLP, MLPWide, LeNet5, summarize_results 20 | 21 | method_ordering = {"EKF-Sub-SVD": 0, 22 | "EKF-Sub-RND": 1, 23 | "EKF-Sub-Diag-SVD": 2, 24 | "EKF-Sub-Diag-RND": 3, 25 | "EKF-Orig-Full": 4, 26 | "EKF-Orig-Diag": 5, 27 | "NL-Lim": 6, 28 | "NL-Unlim": 7, 29 | "Lin": 8, 30 | "Lin-KF": 9, 31 | "Lin-Wide": 9, 32 | "Lim2": 10, 33 | "NeuralTS": 11, 34 | "LoFi": 12 35 | } 36 | 37 | rank = {"MLP1": 0, "MLP2": 1, "LeNet5": 2} 38 | 39 | mapping = { 40 | "Lim2": "Lim2", 41 | "EKF Subspace SVD": "EKF-Sub-SVD", 42 | "EKF Subspace RND": "EKF-Sub-RND", 43 | "EKF Diagonal Subspace SVD": "EKF-Sub-Diag-SVD", 44 | "EKF Diagonal Subspace RND": "EKF-Sub-Diag-RND", 45 | "EKF Orig Full": "EKF-Orig-Full", 46 | "Linear": "Lin", 47 | "Linear Wide": "Lin-Wide", 48 | "Linear KF": "Lin-KF", 49 | "Unlimited Neural Linear": "NL-Unlim", 50 | "Limited Neural Linear": "NL-Lim", 51 | "NeuralTS": "NeuralTS", 52 | "EKF Orig Diagonal": "EKF-Orig-Diag", 53 | "LoFi": "LoFi", 54 | } 55 | 56 | 57 | def main(config): 58 | key = PRNGKey(0) 59 | ntrain = 5000 60 | 61 | # Create the environment beforehand 62 | mnist_env = MnistEnvironment(key, ntrain=ntrain) 63 | 64 | # Number of different digits 65 | num_arms = 10 66 | models = {"MLP1": MLP(num_arms), "MLP2": MLPWide(num_arms), "LeNet5": LeNet5(num_arms)} 67 | 68 | eta = 6.0 69 | lmbda = 0.25 70 | 71 | learning_rate = 0.01 72 | momentum = 0.9 73 | 74 | update_step_mod = 100 75 | buffer_size = 50 76 | nepochs = 100 77 | 78 | # Neural Linear Limited 79 | nl_lim = {"buffer_size": buffer_size, "opt": optax.sgd(learning_rate, momentum), "eta": eta, "lmbda": lmbda, 80 | "update_step_mod": update_step_mod, "nepochs": nepochs} 81 | 82 | # Neural Linear Unlimited 83 | buffer_size = 4800 84 | 85 | nl_unlim = nl_lim.copy() 86 | nl_unlim["buffer_size"] = buffer_size 87 | 88 | npulls, nwarmup = 20, 2000 89 | learning_rate, momentum = 0.8, 0.9 90 | observation_noise = 0.0 91 | prior_noise_variance = 1e-4 92 | n_components = 470 93 | nepochs = 1000 94 | random_projection = False 95 | 96 | # 
Subspace Neural with SVD 97 | ekf_sub_svd = {"opt": optax.sgd(learning_rate, momentum), "prior_noise_variance": prior_noise_variance, 98 | "nwarmup": nwarmup, "nepochs": nepochs, 99 | "observation_noise": observation_noise, "n_components": n_components, 100 | "random_projection": random_projection} 101 | 102 | # Subspace Neural without SVD 103 | ekf_sub_rnd = ekf_sub_svd.copy() 104 | ekf_sub_rnd["random_projection"] = True 105 | 106 | system_noise = 0.0 107 | 108 | ekf_orig = {"opt": optax.sgd(learning_rate, momentum), "prior_noise_variance": prior_noise_variance, 109 | "nwarmup": nwarmup, "nepochs": nepochs, 110 | "system_noise": system_noise, "observation_noise": observation_noise} 111 | linear = {} 112 | 113 | 114 | # LoFi 115 | emission_covariance = 0.01 116 | initial_covariance = 1.0 117 | dynamics_weights = 1.0 118 | dynamics_covariance = 0.0 119 | memory_size = 10 120 | 121 | lofi = { 122 | "emission_covariance": emission_covariance, 123 | "initial_covariance": initial_covariance, 124 | "dynamics_weights": dynamics_weights, 125 | "dynamics_covariance": dynamics_covariance, 126 | "memory_size": memory_size 127 | } 128 | 129 | 130 | bandits = {"Linear": {"kwargs": linear, 131 | "bandit": LinearBandit 132 | }, 133 | "Linear KF": {"kwargs": linear.copy(), 134 | "bandit": LinearKFBandit 135 | }, 136 | "Limited Neural Linear": {"kwargs": nl_lim, 137 | "bandit": LimitedMemoryNeuralLinearBandit 138 | }, 139 | "Unlimited Neural Linear": {"kwargs": nl_unlim, 140 | "bandit": LimitedMemoryNeuralLinearBandit 141 | }, 142 | "EKF Subspace SVD": {"kwargs": ekf_sub_svd, 143 | "bandit": SubspaceNeuralBandit 144 | }, 145 | "EKF Subspace RND": {"kwargs": ekf_sub_rnd, 146 | "bandit": SubspaceNeuralBandit 147 | }, 148 | "EKF Diagonal Subspace SVD": {"kwargs": ekf_sub_svd.copy(), 149 | "bandit": DiagonalSubspaceNeuralBandit 150 | }, 151 | "EKF Diagonal Subspace RND": {"kwargs": ekf_sub_rnd.copy(), 152 | "bandit": DiagonalSubspaceNeuralBandit 153 | }, 154 | "EKF Orig Diagonal": {"kwargs": ekf_orig, 155 | "bandit": DiagonalNeuralBandit 156 | }, 157 | "LoFi": { 158 | "kwargs": lofi, 159 | "bandit": LowRankFilterBandit 160 | } 161 | } 162 | 163 | results = [] 164 | 165 | for model_name, model in models.items(): 166 | print(f"Model : {model_name}") 167 | for bandit_name, properties in bandits.items(): 168 | if not bandit_name.startswith("Linear"): 169 | properties["kwargs"]["model"] = model 170 | elif model_name != "MLP1": 171 | continue 172 | print(f"\tBandit : {bandit_name}") 173 | key = PRNGKey(314) 174 | start = time() 175 | warmup_rewards, rewards_trace, opt_rewards = train(key, properties["bandit"], mnist_env, npulls, 176 | config.ntrials, 177 | properties["kwargs"], neural=False) 178 | rtotal, rstd = summarize_results(warmup_rewards, rewards_trace) 179 | end = time() 180 | print(f"\t\tTime : {end - start}") 181 | results.append((bandit_name, model_name, end - start, rtotal, rstd)) 182 | 183 | df = pd.DataFrame(results) 184 | df = df.rename(columns={0: "Method", 1: "Model", 2: "Time", 3: "Reward", 4: "Std"}) 185 | 186 | df["Method"] = df["Method"].apply(lambda v: mapping[v]) 187 | df["Rank"] = df["Method"].apply(lambda v: method_ordering[v]) 188 | df["AltRank"] = df["Model"].apply(lambda v: rank[v]) 189 | 190 | df["Reward"] = df['Reward'].astype(float) 191 | df["Time"] = df['Time'].astype(float) 192 | df["Std"] = df['Std'].astype(float) 193 | df.to_csv(config.filepath) 194 | 195 | 196 | if __name__ == "__main__": 197 | parser = argparse.ArgumentParser() 198 | parser.add_argument('--ntrials', type=int, 
nargs='?', const=10, default=10)
199 | filepath = "bandits/results/mnist_results.csv"
200 | parser.add_argument('--filepath', type=str, nargs='?', const=filepath, default=filepath)
201 |
202 | # Parse the argument
203 | args = parser.parse_args()
204 | main(args)
205 | -------------------------------------------------------------------------------- /bandits/scripts/tabular_exp.py: --------------------------------------------------------------------------------
1 | import optax
2 | import pandas as pd
3 | from jax.random import split, PRNGKey
4 |
5 | import argparse
6 | from time import time
7 |
8 | from environments.tabular_env import TabularEnvironment
9 |
10 | from agents.linear_bandit import LinearBandit
11 | from agents.linear_kf_bandit import LinearKFBandit
12 | from agents.linear_bandit_wide import LinearBanditWide
13 | from agents.ekf_subspace import SubspaceNeuralBandit
14 | from agents.ekf_orig_diag import DiagonalNeuralBandit
15 | from agents.ekf_orig_full import EKFNeuralBandit
16 | from agents.diagonal_subspace import DiagonalSubspaceNeuralBandit
17 | from agents.limited_memory_neural_linear import LimitedMemoryNeuralLinearBandit
18 | from agents.low_rank_filter_bandit import LowRankFilterBandit
19 |
20 | from .training_utils import train, MLP, summarize_results
21 | from .mnist_exp import mapping, method_ordering
22 |
23 |
24 | def main(config):
25 | # Tabular datasets
26 | key = PRNGKey(0)
27 | shuttle_key, covertype_key, adult_key, stock_key = split(key, 4)
28 | ntrain = 5000
29 |
30 | shuttle_env = TabularEnvironment(shuttle_key, ntrain=ntrain, name='statlog', intercept=False, path="./bandit-data")
31 | covertype_env = TabularEnvironment(covertype_key, ntrain=ntrain, name='covertype', intercept=False, path="./bandit-data")
32 | adult_env = TabularEnvironment(adult_key, ntrain=ntrain, name='adult', intercept=False, path="./bandit-data")
33 | environments = {"shuttle": shuttle_env, "covertype": covertype_env, "adult": adult_env}
34 |
35 | # Linear & Linear Wide
36 | linear = {}
37 |
38 | # Neural Linear Limited
39 | eta = 6.0
40 | lmbda = 0.25
41 |
42 | learning_rate = 0.05
43 | momentum = 0.9
44 | prior_noise_variance = 1e-3
45 | observation_noise = 0.01
46 |
47 | update_step_mod = 100
48 | buffer_size = 20
49 | nepochs = 100
50 |
51 | nl_lim = {"buffer_size": buffer_size, "opt": optax.sgd(learning_rate, momentum), "eta": eta, "lmbda": lmbda,
52 | "update_step_mod": update_step_mod, "nepochs": nepochs}
53 |
54 |
55 | # Neural Linear Unlimited
56 | buffer_size = 5000
57 | nl_unlim = nl_lim.copy()
58 | nl_unlim["buffer_size"] = buffer_size
59 |
60 |
61 | # Subspace Neural Bandit with SVD
62 | npulls, nwarmup = 20, 2000
63 | observation_noise = 0.0
64 | prior_noise_variance = 1e-4
65 | n_components = 470
66 | nepochs = 1000
67 | random_projection = False
68 |
69 | ekf_sub_svd = {"opt": optax.sgd(learning_rate, momentum), "prior_noise_variance": prior_noise_variance,
70 | "nwarmup": nwarmup, "nepochs": nepochs,
71 | "observation_noise": observation_noise, "n_components": n_components,
72 | "random_projection": random_projection}
73 |
74 | # Subspace Neural Bandit without SVD
75 | ekf_sub_rnd = ekf_sub_svd.copy()
76 | ekf_sub_rnd["random_projection"] = True
77 |
78 | # EKF Neural & EKF Neural Diagonal
79 | system_noise = 0.0
80 | prior_noise_variance = 1e-3
81 | nepochs = 100
82 | nwarmup = 1000
83 | learning_rate = 0.05
84 | momentum = 0.9
85 | observation_noise = 0.01
86 |
87 | ekf_orig = {
88 | "opt": optax.sgd(learning_rate, momentum),
89 | "prior_noise_variance": 
prior_noise_variance, 90 | "nwarmup": nwarmup, 91 | "nepochs": nepochs, 92 | "system_noise": system_noise, 93 | "observation_noise": observation_noise 94 | } 95 | 96 | 97 | # LoFi 98 | emission_covariance = 0.01 99 | initial_covariance = 1.0 100 | dynamics_weights = 1.0 101 | dynamics_covariance = 0.0 102 | memory_size = 10 103 | 104 | lofi = { 105 | "emission_covariance": emission_covariance, 106 | "initial_covariance": initial_covariance, 107 | "dynamics_weights": dynamics_weights, 108 | "dynamics_covariance": dynamics_covariance, 109 | "memory_size": memory_size 110 | } 111 | 112 | 113 | 114 | bandits = {"Linear": {"kwargs": linear, 115 | "bandit": LinearBandit 116 | }, 117 | "Linear KF": {"kwargs": linear.copy(), 118 | "bandit": LinearKFBandit 119 | }, 120 | "Linear Wide": {"kwargs": linear, 121 | "bandit": LinearBanditWide 122 | }, 123 | "Limited Neural Linear": {"kwargs": nl_lim, 124 | "bandit": LimitedMemoryNeuralLinearBandit 125 | }, 126 | "Unlimited Neural Linear": {"kwargs": nl_unlim, 127 | "bandit": LimitedMemoryNeuralLinearBandit 128 | }, 129 | "EKF Subspace SVD": {"kwargs": ekf_sub_svd, 130 | "bandit": SubspaceNeuralBandit 131 | }, 132 | "EKF Subspace RND": {"kwargs": ekf_sub_rnd, 133 | "bandit": SubspaceNeuralBandit 134 | }, 135 | "EKF Diagonal Subspace SVD": {"kwargs": ekf_sub_svd, 136 | "bandit": DiagonalSubspaceNeuralBandit 137 | }, 138 | "EKF Diagonal Subspace RND": {"kwargs": ekf_sub_rnd, 139 | "bandit": DiagonalSubspaceNeuralBandit 140 | }, 141 | "EKF Orig Diagonal": {"kwargs": ekf_orig, 142 | "bandit": DiagonalNeuralBandit 143 | }, 144 | "EKF Orig Full": {"kwargs": ekf_orig, 145 | "bandit": EKFNeuralBandit 146 | }, 147 | "LoFi": { 148 | "kwargs": lofi, 149 | "bandit": LowRankFilterBandit 150 | } 151 | } 152 | 153 | results = [] 154 | 155 | for env_name, env in environments.items(): 156 | print("Environment : ", env_name) 157 | num_arms = env.labels_onehot.shape[-1] 158 | models = {"MLP1": MLP(num_arms)} # You could also add MLPWide(num_arms) 159 | 160 | for model_name, model in models.items(): 161 | for bandit_name, properties in bandits.items(): 162 | if not bandit_name.startswith("Linear"): 163 | properties["kwargs"]["model"] = model 164 | print(f"\tBandit : {bandit_name}") 165 | key = PRNGKey(314) 166 | start = time() 167 | warmup_rewards, rewards_trace, opt_rewards = train(key, properties["bandit"], env, npulls, 168 | config.ntrials, 169 | properties["kwargs"], neural=False) 170 | 171 | rtotal, rstd = summarize_results(warmup_rewards, rewards_trace) 172 | end = time() 173 | print(f"\t\tTime : {end - start:0.3f}s") 174 | results.append((env_name, bandit_name, end - start, rtotal, rstd)) 175 | 176 | # initialize results given in the paper 177 | # running time, mean, and std values for Lim2. 
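# (Each Lim2 and NeuralTS row below follows the same
# (dataset, method, time, reward, std) layout as the tuples appended to
# `results` above.)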
178 | # We obtained these values by running the following code: 179 | # https://github.com/ofirnabati/Neural-Linear-Bandits-with-Likelihood-Matching 180 | # set to the parameters presented in the paper: https://arxiv.org/abs/2102.03799 181 | lim2data = [["shuttle", "Lim2", 42.20236960171787, 4826.4, 319.82351111111], 182 | ["covertype", "Lim2", 124.96883611524915, 2660.7, 333.93744444444], 183 | ["adult", "Lim2", 34.89770766110576, 3985.5, 113.127926], 184 | ] 185 | 186 | # Values obtained from appendix B of https://arxiv.org/abs/2102.03799 187 | neuraltsdata = [ 188 | ["shuttle", "NeuralTS", 0.0, 4348, 265], 189 | ["covertype", "NeuralTS", 0.0, 1877, 83], 190 | ["adult", "NeuralTS", 0.0, 3769, 2], ] 191 | 192 | df = pd.DataFrame(results + lim2data + neuraltsdata) 193 | df = df.rename(columns={0: "Dataset", 1: "Method", 2: "Time", 3: "Reward", 4: "Std"}) 194 | 195 | df["Method"] = df["Method"].apply(lambda v: mapping[v]) 196 | 197 | df["Reward"] = df['Reward'].astype(float) 198 | df["Time"] = df['Time'].astype(float) 199 | df["Std"] = df['Std'].astype(float) 200 | 201 | df["Rank"] = df["Method"].apply(lambda v: method_ordering[v]) 202 | df.to_csv(config.filepath) 203 | 204 | 205 | if __name__ == "__main__": 206 | parser = argparse.ArgumentParser() 207 | parser.add_argument('--ntrials', type=int, nargs='?', const=10, default=10) 208 | filepath = "bandits/results/tabular_results.csv" 209 | parser.add_argument('--filepath', type=str, nargs='?', const=filepath, default=filepath) 210 | 211 | # Parse the argument 212 | args = parser.parse_args() 213 | main(args) 214 | -------------------------------------------------------------------------------- /bandits/scripts/subspace_bandits.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "source": [ 16 | "# Bayesian Subspace bandits\n", 17 | "\n", 18 | "See https://arxiv.org/abs/2112.00195 for details.\n" 19 | ], 20 | "metadata": { 21 | "id": "1o3MquliBXCr" 22 | } 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "source": [ 27 | "## Installation" 28 | ], 29 | "metadata": { 30 | "id": "e9DNLtCwCOTb" 31 | } 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 1, 36 | "metadata": { 37 | "colab": { 38 | "base_uri": "https://localhost:8080/" 39 | }, 40 | "id": "fjs5Hm5_Env3", 41 | "outputId": "e292342a-7120-4f2a-c45e-0725eb12f6a7" 42 | }, 43 | "outputs": [ 44 | { 45 | "output_type": "stream", 46 | "name": "stdout", 47 | "text": [ 48 | "Cloning into 'bandits'...\n", 49 | "remote: Enumerating objects: 56, done.\u001b[K\n", 50 | "remote: Counting objects: 100% (56/56), done.\u001b[K\n", 51 | "remote: Compressing objects: 100% (53/53), done.\u001b[K\n", 52 | "remote: Total 56 (delta 11), reused 23 (delta 1), pack-reused 0\u001b[K\n", 53 | "Unpacking objects: 100% (56/56), done.\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "!git clone --depth 1 https://github.com/probml/bandits" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 2, 64 | "metadata": { 65 | "id": "Bu0fD71JEaiR", 66 | "outputId": "b14bb16c-df35-4ca7-e4de-a2bf785e7267", 67 | "colab": { 68 | "base_uri": "https://localhost:8080/" 69 | } 70 | }, 71 | "outputs": [ 72 | { 73 | "output_type": "stream", 74 | "name": "stdout", 75 | "text": [ 76 | "\u001b[?25l\r\u001b[K |███▊ | 10 kB 32.8 MB/s eta 0:00:01\r\u001b[K 
|███████▌ | 20 kB 37.0 MB/s eta 0:00:01\r\u001b[K |███████████▏ | 30 kB 22.2 MB/s eta 0:00:01\r\u001b[K |███████████████ | 40 kB 18.1 MB/s eta 0:00:01\r\u001b[K |██████████████████▊ | 51 kB 17.1 MB/s eta 0:00:01\r\u001b[K |██████████████████████▍ | 61 kB 14.0 MB/s eta 0:00:01\r\u001b[K |██████████████████████████▏ | 71 kB 11.9 MB/s eta 0:00:01\r\u001b[K |██████████████████████████████ | 81 kB 13.1 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 87 kB 5.8 MB/s \n",
77 | "\u001b[?25h  Building wheel for fire (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
78 | "\u001b[K |████████████████████████████████| 88 kB 6.0 MB/s \n",
79 | "\u001b[K |████████████████████████████████| 65 kB 3.3 MB/s \n",
80 | "\u001b[?25h  Building wheel for optax (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
81 | "  Building wheel for flax (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
82 | ]
83 | }
84 | ],
85 | "source": [
86 | "!pip install -qqq fire\n",
87 | "!pip install -qqq ml-collections\n",
88 | "!pip install -qqq git+https://github.com/deepmind/optax.git\n",
89 | "!pip install -qqq --upgrade git+https://github.com/google/flax.git"
90 | ]
91 | },
92 | {
93 | "cell_type": "markdown",
94 | "source": [
95 | "## Test the installation"
96 | ],
97 | "metadata": {
98 | "id": "dK6p1QZrBfly"
99 | }
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": 3,
104 | "metadata": {
105 | "colab": {
106 | "base_uri": "https://localhost:8080/"
107 | },
108 | "id": "scOTAFKLncE5",
109 | "outputId": "9d5602e2-7c95-4d85-be66-6beacf9c7349"
110 | },
111 | "outputs": [
112 | {
113 | "output_type": "stream",
114 | "name": "stdout",
115 | "text": [
116 | "Expected Reward : 4419.70 ± 13.78\n",
117 | "Time : 11.732s\n"
118 | ]
119 | }
120 | ],
121 | "source": [
122 | "%%bash\n",
123 | "cd /content/bandits\n",
124 | "python bandits test"
125 | ]
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "source": [
130 | "## Setup "
131 | ],
132 | "metadata": {
133 | "id": "gx6W3l60Birc"
134 | }
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": null,
139 | "metadata": {
140 | "colab": {
141 | "base_uri": "https://localhost:8080/"
142 | },
143 | "id": "IgQ7Wq37LUQG",
144 | "outputId": "21ad3203-0a12-4004-df85-04a90f282821"
145 | },
146 | "outputs": [
147 | {
148 | "name": "stdout",
149 | "output_type": "stream",
150 | "text": [
151 | "/content/bandits/bandits/experiments\n"
152 | ]
153 | }
154 | ],
155 | "source": [
156 | "%cd /content/bandits/bandits/experiments"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": null,
162 | "metadata": {
163 | "colab": {
164 | "base_uri": "https://localhost:8080/"
165 | },
166 | "id": "YSgufIApEaiU",
167 | "outputId": "d7af82c8-f00c-4bb6-f16f-47c63828026f",
168 | "tags": []
169 | },
170 | "outputs": [
171 | {
172 | "name": "stderr",
173 | "output_type": "stream",
174 | "text": [
175 | "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
176 | ]
177 | },
178 | {
179 | "name": "stdout",
180 | "output_type": "stream",
181 | "text": [
182 | "1\n"
183 | ]
184 | }
185 | ],
186 | "source": [
187 | "import os\n",
188 | "os.chdir(\"..\")\n",
189 | "\n",
190 | "import jax\n",
191 | "import ml_collections\n",
192 | "\n",
193 | "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n",
194 | "\n",
195 | "import glob\n",
196 | "from datetime import datetime\n",
197 | "\n",
198 | "import scripts.movielens_exp as movielens_run\n",
199 | "import scripts.mnist_exp as mnist_run\n",
200 | "import scripts.tabular_exp as tabular_run\n",
201 | "import scripts.tabular_subspace_exp as tabular_sub_run\n",
202 | "\n",
203 | "print(jax.device_count())"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "metadata": {
210 | "id": "8gnJYer1EaiV"
211 | },
212 | "outputs": [],
213 | "source": [
214 | "def get_config(results_filename):\n",
215 | " \"\"\"Get the default hyperparameter configuration.\"\"\"\n",
216 | " config = ml_collections.ConfigDict()\n",
217 | " config.filepath = results_filename\n",
218 | " config.ntrials = 2 # was 10 in paper\n",
219 | " return config"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": null,
225 | "metadata": {
226 | "id": "Yo1dXiJSEaiV"
227 | },
228 | "outputs": [],
229 | "source": [
230 | "timestamp = datetime.timestamp(datetime.now())"
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "execution_count": null,
236 | "metadata": {
237 | "id": "QuYleUSEG-or"
238 | },
239 | "outputs": [],
240 | "source": [
241 | "def plot_figure(data, x, y, filename, figsize=(24, 9), log_scale=False): \n", " # NB: uses the global `colors` palette defined in the MNIST section below\n",
242 | " sns.set(font_scale=1.5)\n",
243 | " plt.style.use(\"seaborn-poster\")\n",
244 | "\n",
245 | " fig, ax = plt.subplots(figsize=figsize, dpi=300)\n",
246 | " g = sns.barplot(x=x, y=y, hue=\"Method\", data=data, errwidth=2, ax=ax, palette=colors)\n",
247 | " if log_scale:\n",
248 | " g.set_yscale(\"log\")\n",
249 | " plt.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0))\n",
250 | " plt.tight_layout()\n",
251 | " plt.savefig(f\"./figures/{filename}.png\")\n",
252 | " plt.show()\n",
253 | "\n",
254 | "def read_data(dataset_name):\n",
255 | " *_, filename = sorted(glob.glob(f\"./results/{dataset_name}_results*.csv\"))\n",
256 | " df = pd.read_csv(filename)\n",
257 | " if dataset_name==\"mnist\":\n",
258 | " linear_df = df[(df[\"Method\"]==\"Lin-KF\") | (df[\"Method\"]==\"Lin\")].copy()\n",
259 | " linear_df[\"Model\"] = \"MLP2\"\n",
260 | " df = pd.concat([df, linear_df])\n",
261 | " linear_df[\"Model\"] = \"LeNet5\"\n",
262 | " df = pd.concat([df, linear_df])\n",
263 | "\n",
264 | " by = [\"Rank\"] if dataset_name==\"tabular\" else [\"Rank\", \"AltRank\"]\n",
265 | "\n",
266 | " data_up = df.sort_values(by=by).copy()\n",
267 | " data_down = df.sort_values(by=by).copy()\n",
268 | "\n",
269 | " data_up[\"Reward\"] = data_up[\"Reward\"] + data_up[\"Std\"]\n",
270 | " data_down[\"Reward\"] = data_down[\"Reward\"] - data_down[\"Std\"]\n",
271 | " data = pd.concat([data_up, data_down])\n",
272 | " return data\n",
273 | "\n",
274 | "def plot_subspace_figure(df, filename=None):\n",
275 | " df = df.reset_index().drop(columns=[\"index\"])\n",
276 | " plt.style.use(\"seaborn-darkgrid\")\n",
277 | " fig, ax = plt.subplots(figsize=(12, 8))\n",
278 | " sns.lineplot(x=\"Subspace Dim\", y=\"Reward\", hue=\"Method\", marker=\"o\", data=df)\n",
279 | " lines, labels = ax.get_legend_handles_labels()\n",
280 | " for line, method in zip(lines, labels):\n",
281 | "
data = df[df[\"Method\"]==method]\n", 282 | " color = line.get_c()\n", 283 | " y_lower_bound = data[\"Reward\"] - data[\"Std\"]\n", 284 | " y_upper_bound = data[\"Reward\"] + data[\"Std\"]\n", 285 | " ax.fill_between(data[\"Subspace Dim\"], y_lower_bound, y_upper_bound, color=color, alpha=0.3)\n", 286 | "\n", 287 | " ax.set_ylabel(\"Reward\", fontsize=16)\n", 288 | " plt.setp(ax.get_xticklabels(), fontsize=16) \n", 289 | " plt.setp(ax.get_yticklabels(), fontsize=16) \n", 290 | " ax.set_xlabel(\"Subspace Dimension(d)\", fontsize=16)\n", 291 | " dataset = df.iloc[0][\"Dataset\"]\n", 292 | " ax.set_title(f\"{dataset.title()} - Subspace Dim vs. Reward\", fontsize=18)\n", 293 | " legend = ax.legend(loc=\"lower right\", prop={'size': 16},frameon=1)\n", 294 | " frame = legend.get_frame()\n", 295 | " frame.set_color('white')\n", 296 | " frame.set_alpha(0.6)\n", 297 | " \n", 298 | " file_path = \"./figures/\"\n", 299 | " file_path = file_path + f\"{dataset}_sub_reward.png\" if filename is None else file_path + f\"{filename}.png\"\n", 300 | " plt.savefig(file_path)" 301 | ] 302 | }, 303 | { 304 | "cell_type": "markdown", 305 | "metadata": { 306 | "id": "NaqV5SteEaiV" 307 | }, 308 | "source": [ 309 | "# Run tabular experiments" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "metadata": { 316 | "colab": { 317 | "base_uri": "https://localhost:8080/" 318 | }, 319 | "id": "XVOocnQw7mKl", 320 | "outputId": "bb987eec-617c-448e-cb7a-9274381c4f2c" 321 | }, 322 | "outputs": [ 323 | { 324 | "name": "stdout", 325 | "output_type": "stream", 326 | "text": [ 327 | "/content/bandits/bandits/experiments\n" 328 | ] 329 | } 330 | ], 331 | "source": [ 332 | "%cd /content/bandits/bandits" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": null, 338 | "metadata": { 339 | "colab": { 340 | "base_uri": "https://localhost:8080/" 341 | }, 342 | "id": "zoWaUbkdEaiX", 343 | "outputId": "e6db79b7-ac55-452d-aa7d-23449547bb69" 344 | }, 345 | "outputs": [ 346 | { 347 | "name": "stdout", 348 | "output_type": "stream", 349 | "text": [ 350 | "Environment : shuttle\n", 351 | "\tBandit : Linear\n", 352 | "\t\tExpected Reward : 4413.50 ± 4.50\n", 353 | "\t\tTime : 10.469s\n", 354 | "\tBandit : Linear KF\n", 355 | "\t\tExpected Reward : 4414.50 ± 4.50\n", 356 | "\t\tTime : 6.309s\n", 357 | "\tBandit : Linear Wide\n", 358 | "\t\tExpected Reward : 4210.00 ± 10.00\n", 359 | "\t\tTime : 25.030s\n", 360 | "\tBandit : Limited Neural Linear\n", 361 | "\t\tExpected Reward : 3840.00 ± 3.00\n", 362 | "\t\tTime : 23.608s\n", 363 | "\tBandit : Unlimited Neural Linear\n", 364 | "\t\tExpected Reward : 4089.00 ± 70.00\n", 365 | "\t\tTime : 42.628s\n", 366 | "\tBandit : EKF Subspace SVD\n", 367 | "\t\tExpected Reward : 4731.00 ± 116.00\n", 368 | "\t\tTime : 198.925s\n", 369 | "\tBandit : EKF Subspace RND\n", 370 | "\t\tExpected Reward : 4846.50 ± 1.50\n", 371 | "\t\tTime : 199.065s\n", 372 | "\tBandit : EKF Diagonal Subspace SVD\n", 373 | "\t\tExpected Reward : 4831.00 ± 0.00\n", 374 | "\t\tTime : 9.122s\n", 375 | "\tBandit : EKF Diagonal Subspace RND\n", 376 | "\t\tExpected Reward : 4797.00 ± 0.00\n", 377 | "\t\tTime : 9.127s\n", 378 | "\tBandit : EKF Orig Diagonal\n", 379 | "\t\tExpected Reward : 3915.00 ± 4.00\n", 380 | "\t\tTime : 6.106s\n", 381 | "\tBandit : EKF Orig Full\n", 382 | "\t\tExpected Reward : 3913.00 ± 2.00\n", 383 | "\t\tTime : 875.099s\n", 384 | "Environment : covertype\n", 385 | "\tBandit : Linear\n", 386 | "\t\tExpected Reward : 3016.50 ± 13.50\n", 387 | "\t\tTime : 
20.976s\n", 388 | "\tBandit : Linear KF\n", 389 | "\t\tExpected Reward : 3014.50 ± 11.50\n", 390 | "\t\tTime : 10.928s\n", 391 | "\tBandit : Linear Wide\n", 392 | "\t\tExpected Reward : 1831.50 ± 5.50\n", 393 | "\t\tTime : 272.890s\n", 394 | "\tBandit : Limited Neural Linear\n", 395 | "\t\tExpected Reward : 1835.00 ± 4.00\n", 396 | "\t\tTime : 20.702s\n", 397 | "\tBandit : Unlimited Neural Linear\n", 398 | "\t\tExpected Reward : 2760.00 ± 26.00\n", 399 | "\t\tTime : 44.914s\n", 400 | "\tBandit : EKF Subspace SVD\n", 401 | "\t\tExpected Reward : 3211.00 ± 12.00\n", 402 | "\t\tTime : 211.246s\n", 403 | "\tBandit : EKF Subspace RND\n", 404 | "\t\tExpected Reward : 3216.00 ± 3.00\n", 405 | "\t\tTime : 212.506s\n", 406 | "\tBandit : EKF Diagonal Subspace SVD\n", 407 | "\t\tExpected Reward : 2315.00 ± 0.00\n", 408 | "\t\tTime : 17.255s\n", 409 | "\tBandit : EKF Diagonal Subspace RND\n", 410 | "\t\tExpected Reward : 2766.00 ± 0.00\n", 411 | "\t\tTime : 16.895s\n", 412 | "\tBandit : EKF Orig Diagonal\n", 413 | "\t\tExpected Reward : 1369.00 ± 1025.00\n", 414 | "\t\tTime : 4.503s\n", 415 | "\tBandit : EKF Orig Full\n" 416 | ] 417 | } 418 | ], 419 | "source": [ 420 | "tabular_filename = f\"./results/tabular_results_{timestamp}.csv\"\n", 421 | "config = get_config(tabular_filename)\n", 422 | "tabular_run.main(config)" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": null, 428 | "metadata": { 429 | "id": "6Jkow86vHXBl" 430 | }, 431 | "outputs": [], 432 | "source": [ 433 | "dataset_name = \"tabular\"\n", 434 | "tabular_df = read_data(dataset_name)\n", 435 | "tabular_rows = ['EKF-Sub-SVD', 'EKF-Sub-RND', 'EKF-Sub-Diag-SVD', 'EKF-Sub-Diag-RND',\n", 436 | " 'EKF-Orig-Full', 'EKF-Orig-Diag', 'NL-Lim', 'NL-Unlim', 'Lin', 'Lim2', 'NeuralTS']\n", 437 | "tabular_df = tabular_df[tabular_df['Method'].isin(tabular_rows)]" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "metadata": { 444 | "id": "YO4xTn4THjMN" 445 | }, 446 | "outputs": [], 447 | "source": [ 448 | "x, y = \"Dataset\", \"Reward\"\n", 449 | "filename = f\"{dataset_name}_{y.lower()}\"\n", 450 | "plot_figure(tabular_df, x, y, filename)" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": null, 456 | "metadata": { 457 | "id": "Eh_9zkdLHmU-" 458 | }, 459 | "outputs": [], 460 | "source": [ 461 | "x, y = \"Dataset\", \"Time\"\n", 462 | "filename = f\"{dataset_name}_{y.lower()}\"\n", 463 | "plot_figure(tabular_df[tabular_df[\"Method\"] != \"NeuralTS\"], x, y, filename, log_scale=True)" 464 | ] 465 | }, 466 | { 467 | "cell_type": "markdown", 468 | "metadata": { 469 | "id": "tUxNfJnoEaiZ" 470 | }, 471 | "source": [ 472 | "# Run movielens experiments" 473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "execution_count": null, 478 | "metadata": { 479 | "id": "XkzXuQp3Eaia" 480 | }, 481 | "outputs": [], 482 | "source": [ 483 | "movielens_filename = f\"./results/movielens_results_{timestamp}.csv\"\n", 484 | "config = get_config(movielens_filename)\n", 485 | "movielens_run.main(config)" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": null, 491 | "metadata": { 492 | "id": "JME59OYvILCs" 493 | }, 494 | "outputs": [], 495 | "source": [ 496 | "dataset_name = \"movielens\"\n", 497 | "movielens_df = read_data(dataset_name)\n", 498 | "movielens_rows = ['EKF-Sub-SVD', 'EKF-Sub-RND', 'EKF-Sub-Diag-SVD', 'EKF-Sub-Diag-RND',\n", 499 | " 'EKF-Orig-Diag', 'NL-Lim', 'NL-Unlim', 'Lin']\n", 500 | "movielens_df = 
movielens_df[movielens_df['Method'].isin(movielens_rows)]" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": null, 506 | "metadata": { 507 | "id": "cJ37A_YtIM9i" 508 | }, 509 | "outputs": [], 510 | "source": [ 511 | "x, y = \"Model\", \"Reward\"\n", 512 | "filename = f\"{dataset_name}_{y.lower()}\"\n", 513 | "plot_figure(movielens_df, x, y, filename)" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": null, 519 | "metadata": { 520 | "id": "KDLNAIQ6IPFT" 521 | }, 522 | "outputs": [], 523 | "source": [ 524 | "x, y = \"Model\", \"Time\"\n", 525 | "filename = f\"{dataset_name}_{y.lower()}\"\n", 526 | "plot_figure(movielens_df, x, y, filename)" 527 | ] 528 | }, 529 | { 530 | "cell_type": "markdown", 531 | "metadata": { 532 | "id": "PkJlfdAyEaiX" 533 | }, 534 | "source": [ 535 | "# Run MNIST experiments" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": null, 541 | "metadata": { 542 | "id": "9J26gcVTEaiZ" 543 | }, 544 | "outputs": [], 545 | "source": [ 546 | "mnist_filename = f\"./results/mnist_results_{timestamp}.csv\"\n", 547 | "config = get_config(mnist_filename)\n", 548 | "mnist_run.main(config)" 549 | ] 550 | }, 551 | { 552 | "cell_type": "code", 553 | "execution_count": null, 554 | "metadata": { 555 | "id": "I0--fVenH7kK" 556 | }, 557 | "outputs": [], 558 | "source": [ 559 | "method_ordering = {\"EKF-Sub-SVD\": 0,\n", 560 | " \"EKF-Sub-RND\": 1,\n", 561 | " \"EKF-Sub-Diag-SVD\": 2,\n", 562 | " \"EKF-Sub-Diag-RND\": 3,\n", 563 | " \"EKF-Orig-Full\": 4,\n", 564 | " \"EKF-Orig-Diag\": 5,\n", 565 | " \"NL-Lim\": 6,\n", 566 | " \"NL-Unlim\": 7,\n", 567 | " \"Lin\": 8,\n", 568 | " \"Lin-KF\": 9,\n", 569 | " \"Lin-Wide\": 9,\n", 570 | " \"Lim2\": 10,\n", 571 | " \"NeuralTS\": 11}\n", 572 | " \n", 573 | "colors = {k : sns.color_palette(\"Paired\")[v]\n", 574 | " if k!=\"Lin-KF\" else sns.color_palette(\"tab20\")[8]\n", 575 | " for k,v in method_ordering.items()}" 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": null, 581 | "metadata": { 582 | "id": "8AwA3qTaH-Oo" 583 | }, 584 | "outputs": [], 585 | "source": [ 586 | "dataset_name = \"mnist\"\n", 587 | "# For possible methods, run mnist_df.Method.unique()\n", 588 | "mnist_rows = ['EKF-Sub-SVD', 'EKF-Sub-RND', 'EKF-Sub-Diag-SVD', 'EKF-Sub-Diag-RND', 'EKF-Orig-Diag', 'NL-Lim', 'NL-Unlim', 'Lin']" 589 | ] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "execution_count": null, 594 | "metadata": { 595 | "id": "lbpYHsbXIAf_" 596 | }, 597 | "outputs": [], 598 | "source": [ 599 | "mnist_df = read_data(dataset_name)\n", 600 | "mnist_df = mnist_df[mnist_df['Method'].isin(mnist_rows)]" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": null, 606 | "metadata": { 607 | "id": "d7AUy1jjICGJ" 608 | }, 609 | "outputs": [], 610 | "source": [ 611 | "x, y = \"Model\", \"Reward\"\n", 612 | "filename = f\"{dataset_name}_{y.lower()}\"\n", 613 | "plot_figure(mnist_df, x, y, filename)" 614 | ] 615 | }, 616 | { 617 | "cell_type": "code", 618 | "execution_count": null, 619 | "metadata": { 620 | "id": "WoNF2EB6IEPQ" 621 | }, 622 | "outputs": [], 623 | "source": [ 624 | "x, y = \"Model\", \"Time\"\n", 625 | "filename = f\"{dataset_name}_{y.lower()}\"\n", 626 | "plot_figure(mnist_df, x, y, filename, log_scale=True)" 627 | ] 628 | }, 629 | { 630 | "cell_type": "markdown", 631 | "metadata": { 632 | "id": "a_dE_wIQEaia" 633 | }, 634 | "source": [ 635 | "# Run tabular subspace experiment" 636 | ] 637 | }, 638 | { 639 | "cell_type": "code", 640 | 
"execution_count": null, 641 | "metadata": { 642 | "id": "4usZ5SIdEaib" 643 | }, 644 | "outputs": [], 645 | "source": [ 646 | "tabular_sub_filename = f\"./results/tabular_subspace_results_{timestamp}.csv\"\n", 647 | "config = get_config(tabular_sub_filename)\n", 648 | "tabular_sub_run.main(config)" 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": null, 654 | "metadata": { 655 | "id": "xGAjjg5bIRfN" 656 | }, 657 | "outputs": [], 658 | "source": [ 659 | "*_, filename = sorted(glob.glob(f\"./results/tabular_subspace_results*.csv\"))\n", 660 | "tabular_sub_df = pd.read_csv(filename)" 661 | ] 662 | }, 663 | { 664 | "cell_type": "code", 665 | "execution_count": null, 666 | "metadata": { 667 | "id": "BP-hf3XbITZy" 668 | }, 669 | "outputs": [], 670 | "source": [ 671 | "dataset_name = \"shuttle\"\n", 672 | "shuttle = tabular_sub_df[tabular_sub_df[\"Dataset\"]==dataset_name]\n", 673 | "plot_subspace_figure(shuttle)" 674 | ] 675 | }, 676 | { 677 | "cell_type": "code", 678 | "execution_count": null, 679 | "metadata": { 680 | "id": "wKV6qFqsITAY" 681 | }, 682 | "outputs": [], 683 | "source": [ 684 | "dataset_name = \"adult\"\n", 685 | "adult = tabular_sub_df[tabular_sub_df[\"Dataset\"]==dataset_name]\n", 686 | "plot_subspace_figure(adult)" 687 | ] 688 | }, 689 | { 690 | "cell_type": "code", 691 | "execution_count": null, 692 | "metadata": { 693 | "id": "s7K8-ONgIW1k" 694 | }, 695 | "outputs": [], 696 | "source": [ 697 | "dataset_name = \"covertype\"\n", 698 | "covertype = tabular_sub_df[tabular_sub_df[\"Dataset\"]==dataset_name]\n", 699 | "plot_subspace_figure(covertype)" 700 | ] 701 | } 702 | ], 703 | "metadata": { 704 | "accelerator": "GPU", 705 | "colab": { 706 | "machine_shape": "hm", 707 | "name": "subspace-bandits.ipynb", 708 | "provenance": [], 709 | "toc_visible": true, 710 | "include_colab_link": true 711 | }, 712 | "interpreter": { 713 | "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" 714 | }, 715 | "kernelspec": { 716 | "display_name": "Python 3", 717 | "name": "python3" 718 | }, 719 | "language_info": { 720 | "name": "python" 721 | } 722 | }, 723 | "nbformat": 4, 724 | "nbformat_minor": 0 725 | } --------------------------------------------------------------------------------