├── bandits
│   ├── __init__.py
│   ├── figures
│   │   └── empty
│   ├── results
│   │   └── empty
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── tabular_test.py
│   │   ├── run_experiments.py
│   │   ├── run_experiments.ipynb
│   │   ├── training_utils.py
│   │   ├── thompson_sampling_bernoulli.py
│   │   ├── tabular_subspace_exp.py
│   │   ├── plot_results.py
│   │   ├── movielens_exp.py
│   │   ├── mnist_exp.py
│   │   ├── tabular_exp.py
│   │   └── subspace_bandits.ipynb
│   ├── __main__.py
│   ├── environments
│   │   ├── mnist_env.py
│   │   ├── environment.py
│   │   ├── ads16_env.py
│   │   ├── movielens_env.py
│   │   └── tabular_env.py
│   ├── agents
│   │   ├── base.py
│   │   ├── agent_utils.py
│   │   ├── diagonal_subspace.py
│   │   ├── linear_kf_bandit.py
│   │   ├── linear_bandit.py
│   │   ├── linear_bandit_wide.py
│   │   ├── neural_greedy.py
│   │   ├── low_rank_filter_bandit.py
│   │   ├── neural_linear_bandit_wide.py
│   │   ├── ekf_orig_diag.py
│   │   ├── ekf_orig_full.py
│   │   ├── neural_linear.py
│   │   ├── ekf_subspace.py
│   │   └── limited_memory_neural_linear.py
│   └── training.py
├── aistats2022-slides
│   ├── .npmrc
│   ├── README.md
│   ├── public
│   │   ├── ts-bandits.mp4
│   │   └── subspace-neural-bandit-diagram.jpg
│   ├── vercel.json
│   ├── .gitignore
│   ├── netlify.toml
│   ├── style.css
│   ├── package.json
│   ├── components
│   │   └── Counter.vue
│   ├── assets
│   │   └── subspace-neural-bandit-diagram.tex
│   └── slides.md
├── bandit-data
│   ├── bandit-adult.pkl
│   ├── bandit-stock.pkl
│   ├── bandit-mushroom.pkl
│   ├── bandit-statlog.pkl
│   ├── bandit-covertype.pkl
│   └── ml-100k
│       └── README.txt
├── setup.py
├── requirements.txt
├── LICENSE
├── README.md
├── .gitignore
└── demos
    └── lofi_tabular.py
/bandits/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/bandits/figures/empty:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/bandits/results/empty:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/bandits/scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aistats2022-slides/.npmrc:
--------------------------------------------------------------------------------
1 | # for pnpm
2 | shamefully-hoist=true
3 |
--------------------------------------------------------------------------------
/aistats2022-slides/README.md:
--------------------------------------------------------------------------------
1 | # Subspace EKF neural bandits
2 | ## AISTATS 2022
3 |
--------------------------------------------------------------------------------
/bandit-data/bandit-adult.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probml/bandits/main/bandit-data/bandit-adult.pkl
--------------------------------------------------------------------------------
/bandit-data/bandit-stock.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probml/bandits/main/bandit-data/bandit-stock.pkl
--------------------------------------------------------------------------------
/bandit-data/bandit-mushroom.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probml/bandits/main/bandit-data/bandit-mushroom.pkl
--------------------------------------------------------------------------------
/bandit-data/bandit-statlog.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probml/bandits/main/bandit-data/bandit-statlog.pkl
--------------------------------------------------------------------------------
/bandit-data/bandit-covertype.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probml/bandits/main/bandit-data/bandit-covertype.pkl
--------------------------------------------------------------------------------
/aistats2022-slides/public/ts-bandits.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probml/bandits/main/aistats2022-slides/public/ts-bandits.mp4
--------------------------------------------------------------------------------
/aistats2022-slides/vercel.json:
--------------------------------------------------------------------------------
1 | {
2 | "rewrites": [
3 | { "source": "/(.*)", "destination": "/index.html" }
4 | ]
5 | }
6 |
--------------------------------------------------------------------------------
/aistats2022-slides/.gitignore:
--------------------------------------------------------------------------------
1 | *.pdf
2 | node_modules
3 | .DS_Store
4 | dist
5 | *.local
6 | index.html
7 | .remote-assets
8 | components.d.ts
9 | .texpadtmp
10 |
--------------------------------------------------------------------------------
/aistats2022-slides/public/subspace-neural-bandit-diagram.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probml/bandits/main/aistats2022-slides/public/subspace-neural-bandit-diagram.jpg
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from setuptools import setup, find_packages
3 |
4 | setup(
5 | name="bandits",
6 | packages=find_packages(),
7 | install_requires=[]
8 | )
9 |
--------------------------------------------------------------------------------
/aistats2022-slides/netlify.toml:
--------------------------------------------------------------------------------
1 | [build.environment]
2 | NODE_VERSION = "14"
3 |
4 | [build]
5 | publish = "dist"
6 | command = "npm run build"
7 |
8 | [[redirects]]
9 | from = "/*"
10 | to = "/index.html"
11 | status = 200
12 |
--------------------------------------------------------------------------------
/aistats2022-slides/style.css:
--------------------------------------------------------------------------------
1 | .centered {
2 | position: fixed;
3 | top: 50%;
4 | left: 50%;
5 | /* bring your own prefixes */
6 | transform: translate(-50%, -50%);
7 | }
8 |
9 | .horizontal-center {
10 | display: block;
11 | margin-left: auto;
12 | margin-right: auto;
13 | }
14 |
15 | .bottom-right {
16 | position: absolute;
17 | bottom: 0;
18 | right: 0;
19 | }
20 |
--------------------------------------------------------------------------------
/aistats2022-slides/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "private": true,
3 | "scripts": {
4 | "build": "slidev build",
5 | "dev": "slidev --open",
6 | "export": "slidev export"
7 | },
8 | "dependencies": {
9 | "@slidev/cli": "^0.28.6",
10 | "@slidev/theme-default": "*",
11 | "@slidev/theme-seriph": "*"
12 | },
13 | "name": "aistats2022",
14 | "devDependencies": {
15 | "playwright-chromium": "^1.19.1"
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/bandits/__main__.py:
--------------------------------------------------------------------------------
1 | import fire
2 | from scripts import plot_results
3 | from scripts import run_experiments
4 | from scripts import tabular_test
5 |
6 |
7 | class Experiments:
8 | def test(self):
9 | tabular_test.main()
10 |
11 | def plot_experiments(self):
12 | plot_results.main()
13 |
14 | def run_experiments(self, experiment=None):
15 | run_experiments.main(experiment)
16 |
17 | def run_and_plot(self):
18 | self.run_experiments()
19 | self.plot_experiments()
20 |
21 | if __name__ == "__main__":
22 | fire.Fire(Experiments)
--------------------------------------------------------------------------------
/aistats2022-slides/components/Counter.vue:
--------------------------------------------------------------------------------
1 | [Vue component source lost in extraction: the <template>/<script> markup was stripped.]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Installation
2 |
3 | ```
4 | pip install fire
5 | pip install ml-collections
6 | ```
7 |
8 | ## Reproduce the results
9 |
10 | There are two ways to reproduce the results from the paper.
11 |
12 | ### Run the scripts
13 |
14 | `cd` into the project folder. To run a quick sanity check:
15 |
16 | ```bash
17 | python bandits test
18 | ```
19 |
20 | To run all experiments and generate the plots:
21 |
22 | ```bash
23 | python bandits run_and_plot
24 | ```
25 |
26 | ### Step by step
27 |
28 | If you only want to run the experiments, run
29 |
30 | ```bash
31 | python bandits run_experiments
32 | ```
33 |
34 | If you have already run the experiments and only want to regenerate the plots, run
35 |
36 | ```bash
37 | python bandits plot_experiments
38 | ```
39 |
40 | The figures will be stored inside `bandits/figures/`.
41 |
42 | ### Execute the notebooks
43 |
44 | An alternative way to reproduce the results is to simply open and run [`subspace_bandits.ipynb`](https://github.com/probml/bandits/blob/main/bandits/scripts/subspace_bandits.ipynb).
45 |
--------------------------------------------------------------------------------
/bandits/environments/environment.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 |
4 |
5 | class BanditEnvironment:
6 | def __init__(self, key, X, Y, opt_rewards):
7 | # Randomise dataset rows
8 | n_obs, n_features = X.shape
9 |
10 | new_ixs = jax.random.choice(key, n_obs, (n_obs,), replace=False)
11 |
12 | X = jnp.asarray(X)[new_ixs]
13 | Y = jnp.asarray(Y)[new_ixs]
14 | opt_rewards = jnp.asarray(opt_rewards)[new_ixs]
15 |
16 | self.contexts = X
17 | self.labels_onehot = Y
18 | self.opt_rewards = opt_rewards
19 | _, self.n_arms = Y.shape
20 | self.n_steps, self.n_features = X.shape
21 |
22 | def get_state(self, t):
23 | return self.labels_onehot[t]
24 |
25 | def get_context(self, t):
26 | return self.contexts[t]
27 |
28 | def get_reward(self, t, action):
29 | return jnp.float32(self.labels_onehot[t][action])
30 |
31 | def warmup(self, num_pulls):
32 | num_steps, num_actions = self.labels_onehot.shape
33 | # Create array of round-robin actions: 0, 1, 2, 0, 1, 2, 0, 1, 2, ...
34 | warmup_actions = jnp.arange(num_actions)
35 | warmup_actions = jnp.repeat(warmup_actions, num_pulls).reshape(num_actions, -1)
36 | actions = warmup_actions.reshape(-1, order="F").astype(jnp.int32)
37 | num_warmup_actions = len(actions)
38 |
39 | time_steps = jnp.arange(num_warmup_actions)
40 |
41 | def get_contexts_and_rewards(t, a):
42 | context = self.get_context(t)
43 | state = self.get_state(t)
44 | reward = self.get_reward(t, a)
45 | return context, state, reward
46 |
47 | contexts, states, rewards = jax.vmap(get_contexts_and_rewards, in_axes=(0, 0))(time_steps, actions)
48 |
49 | return contexts, states, actions, rewards
50 |
--------------------------------------------------------------------------------
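A minimal usage sketch of `BanditEnvironment` (not part of the repository): the toy arrays below are made up for illustration, and the import path assumes the `bandits/` folder is on `sys.path`, as in the experiment scripts.

```python
import jax
import jax.numpy as jnp
from environments.environment import BanditEnvironment  # assumes cwd == bandits/

key = jax.random.PRNGKey(0)
X = jnp.arange(12.0).reshape(6, 2)                    # 6 contexts, 2 features (toy data)
Y = jax.nn.one_hot(jnp.array([0, 1, 2, 0, 1, 2]), 3)  # one-hot labels double as per-arm rewards
opt_rewards = jnp.ones(6)                             # best achievable reward at each step

env = BanditEnvironment(key, X, Y, opt_rewards)

# Round-robin warmup: each of the 3 arms is pulled once.
contexts, states, actions, rewards = env.warmup(num_pulls=1)
print(actions)                       # [0 1 2]
print(env.n_arms, env.n_features)    # 3 2
```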
/bandits/agents/diagonal_subspace.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jsl.nlds.diagonal_extended_kalman_filter import DiagonalExtendedKalmanFilter
3 | from .ekf_subspace import SubspaceNeuralBandit
4 | from tensorflow_probability.substrates import jax as tfp
5 |
6 | tfd = tfp.distributions
7 |
8 |
9 | class DiagonalSubspaceNeuralBandit(SubspaceNeuralBandit):
10 |
11 | def __init__(self, num_features, num_arms, model, opt, prior_noise_variance, nwarmup=1000, nepochs=1000,
12 | system_noise=0.0, observation_noise=1.0, n_components=0.9999, random_projection=False):
13 | super().__init__(num_features, num_arms, model, opt, prior_noise_variance, nwarmup, nepochs,
14 | system_noise, observation_noise, n_components, random_projection)
15 |
16 | def init_bel(self, key, contexts, states, actions, rewards):
17 | bel = super().init_bel(key, contexts, states, actions, rewards)
18 |
19 | params_subspace_init, _, t = bel
20 |
21 | subspace_dim = self.n_components
22 | Q = jnp.ones(subspace_dim) * self.system_noise
23 | R = self.observation_noise
24 |
25 | covariance_subspace_init = jnp.ones(subspace_dim) * self.prior_noise_variance
26 |
27 | def fz(params):
28 | return params
29 |
30 | def fx(params, context, action):
31 | return self.predict_rewards(params, context)[action, None]
32 |
33 | ekf = DiagonalExtendedKalmanFilter(fz, fx, Q, R)
34 | self.ekf = ekf
35 |
36 | bel = (params_subspace_init, covariance_subspace_init, t)
37 |
38 | return bel
39 |
40 | def sample_params(self, key, bel):
41 | params_subspace, covariance_subspace, t = bel
42 | normal_dist = tfd.Normal(loc=params_subspace, scale=covariance_subspace)
43 | params_subspace = normal_dist.sample(seed=key)
44 | return params_subspace
45 |
--------------------------------------------------------------------------------
/bandits/environments/ads16_env.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax.ops import index_add
3 | from jax.random import split, permutation
4 | from jax.nn import one_hot
5 |
6 | import pandas as pd
7 |
8 | import requests
9 | import io
10 |
11 | from .environment import BanditEnvironment
12 |
13 |
14 | def get_ads16(key, ntrain, intercept):
15 | url = "https://raw.githubusercontent.com/probml/probml-data/main/data/ads16_preprocessed.csv"
16 | download = requests.get(url).content
17 |
18 | dataset = pd.read_csv(io.StringIO(download.decode('utf-8')))
19 | dataset.drop(columns=['Unnamed: 0'], inplace=True)
20 | dataset = dataset.sample(frac=1).reset_index(drop=True).to_numpy()
21 |
22 | ntrain = ntrain if ntrain > 0 and ntrain < len(dataset) else len(dataset)
23 | nusers, nads = 120, 300
24 | users = jnp.arange(nusers)
25 |
26 | n_ads_per_user, rem = divmod(ntrain, nusers)
27 |
28 | mykey, key = split(key)
29 | indices = permutation(mykey, users)[:rem]
30 | n_ads_per_user = jnp.ones((nusers,)) * int(n_ads_per_user)
31 | n_ads_per_user = index_add(n_ads_per_user, indices, 1).astype(jnp.int32)
32 |
33 | indices = jnp.array([])
34 |
35 | for user, nrow in enumerate(n_ads_per_user):
36 | mykey, key = split(key)
37 | df_indices = jnp.arange(user * nads, (user + 1) * nads)
38 | indices = jnp.append(indices, permutation(mykey, df_indices)[:nrow]).astype(jnp.int32)
39 |
40 | narms = 2
41 | dataset = dataset[indices]
42 |
43 | X = dataset[:, :-1]
44 | Y = one_hot(dataset[:, -1], narms)
45 |
46 | if intercept:
47 | X = jnp.hstack([jnp.ones_like(X[:, :1]), X])
48 |
49 | opt_rewards = jnp.ones((len(X),))
50 |
51 | return X, Y, opt_rewards
52 |
53 |
54 | def ADS16Environment(key, ntrain, intercept=False):
55 | mykey, key = split(key)
56 | X, Y, opt_rewards = get_ads16(mykey, ntrain, intercept)
57 | return BanditEnvironment(key, X, Y, opt_rewards)
58 |
--------------------------------------------------------------------------------
/bandits/environments/movielens_env.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 |
4 | import jax.numpy as jnp
5 |
6 | from .environment import BanditEnvironment
7 |
8 | MOVIELENS_NUM_USERS = 943
9 | MOVIELENS_NUM_MOVIES = 1682
10 |
11 |
12 | def load_movielens_data(filepath):
13 | dataset = pd.read_csv(filepath, delimiter='\t', header=None)
14 | columns = {0: 'user_id', 1: 'item_id', 2: 'ranking', 3: 'timestamp'}
15 | dataset = dataset.rename(columns=columns)
16 | dataset['user_id'] -= 1
17 | dataset['item_id'] -= 1
18 | dataset = dataset.drop(columns="timestamp")
19 |
20 | rankings_matrix = np.zeros((MOVIELENS_NUM_USERS, MOVIELENS_NUM_MOVIES))
21 | for i, row in dataset.iterrows():
22 | rankings_matrix[row["user_id"], row["item_id"]] = float(row["ranking"])
23 | return rankings_matrix
24 |
25 |
26 | # https://github.com/tensorflow/agents/blob/master/tf_agents/bandits/environments/movielens_py_environment.py
27 | def get_movielens(rank_k, num_movies, repeat=5):
28 | """Initializes the MovieLens Bandit environment.
29 | Args:
30 | rank_k : (int) Which rank to use in the matrix factorization.
31 | batch_size: (int) Number of observations generated per call.
32 | num_movies: (int) Only the first `num_movies` movies will be used by the
33 | environment. The rest is cut out from the data.
34 | """
35 | num_actions = num_movies
36 | context_dim = rank_k
37 |
38 | # Compute the matrix factorization.
39 | data_matrix = load_movielens_data("../bandit-data/ml-100k/u.data")
40 | # Keep only the first `num_movies` movies.
41 | data_matrix = data_matrix[:, :num_movies]
42 | # Keep only users with at least one rated item (repeated `repeat` times).
43 | nonzero_users = list(np.nonzero(np.sum(data_matrix, axis=1) > 0.0)[0]) * repeat
44 | data_matrix = data_matrix[nonzero_users, :]
45 | effective_num_users = len(nonzero_users)
46 |
47 | # Compute the SVD.
48 | u, s, vh = np.linalg.svd(data_matrix, full_matrices=False)
49 |
50 | # Keep only the largest singular values.
51 | u_hat = u[:, :context_dim] * np.sqrt(s[:context_dim])
52 | v_hat = np.transpose(np.transpose(vh[:rank_k, :]) * np.sqrt(s[:rank_k]))
53 | approx_ratings_matrix = np.matmul(u_hat, v_hat)
54 | opt_rewards = np.max(approx_ratings_matrix, axis=1)
55 | return u_hat, approx_ratings_matrix, opt_rewards
56 |
57 |
58 | def MovielensEnvironment(key, rank_k=20, num_movies=20, repeat=5, intercept=False):
59 | X, y, opt_rewards = get_movielens(rank_k, num_movies, repeat)
60 |
61 | if intercept:
62 | X = jnp.hstack([jnp.ones_like(X[:, :1]), X])
63 |
64 | return BanditEnvironment(key, X, y, opt_rewards)
65 |
--------------------------------------------------------------------------------
/bandits/agents/linear_kf_bandit.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax.lax import scan
3 | from jax.random import split
4 | from jsl.lds.kalman_filter import KalmanFilterNoiseEstimation
5 | from tensorflow_probability.substrates import jax as tfp
6 |
7 | tfd = tfp.distributions
8 |
9 |
10 | class LinearKFBandit:
11 | def __init__(self, num_features, num_arms, eta=6.0, lmbda=0.25):
12 | self.num_features = num_features
13 | self.num_arms = num_arms
14 | self.eta = eta
15 | self.lmbda = lmbda
16 |
17 | def init_bel(self, key, contexts, states, actions, rewards):
18 | v = 2 * self.eta * jnp.ones((self.num_arms,))
19 | tau = jnp.ones((self.num_arms,))
20 |
21 | Sigma0 = jnp.eye(self.num_features)
22 | mu0 = jnp.zeros((self.num_features,))
23 |
24 | Sigma = 1. / self.lmbda * jnp.repeat(Sigma0[None, ...], self.num_arms, axis=0)
25 | mu = Sigma @ mu0
26 | A = jnp.eye(self.num_features)
27 | Q = 0
28 |
29 | self.kf = KalmanFilterNoiseEstimation(A, Q, mu, Sigma, v, tau)
30 |
31 | def warmup_update(bel, cur):
32 | context, action, reward = cur
33 | bel = self.update_bel(bel, context, action, reward)
34 | return bel, None
35 |
36 | bel = (mu, Sigma, v, tau)
37 | bel, _ = scan(warmup_update, bel, (contexts, actions, rewards))
38 |
39 | return bel
40 |
41 | def update_bel(self, bel, context, action, reward):
42 | mu, Sigma, v, tau = bel
43 | state = (mu[action], Sigma[action], v[action], tau[action])
44 | xs = (context, reward)
45 |
46 | mu_k, Sigma_k, v_k, tau_k = self.kf.kalman_step(state, xs)
47 |
48 | mu = mu.at[action].set(mu_k)
49 | Sigma = Sigma.at[action].set(Sigma_k)
50 | v = v.at[action].set(v_k)
51 | tau = tau.at[action].set(tau_k)
52 |
53 | bel = (mu, Sigma, v, tau)
54 |
55 | return bel
56 |
57 | def sample_params(self, key, bel):
58 | sigma_key, w_key = split(key, 2)
59 | mu, Sigma, v, tau = bel
60 |
61 | lmbda = tfd.InverseGamma(v / 2., (v * tau) / 2.).sample(seed=sigma_key)
62 | V = lmbda[:, None, None]
63 |
64 | covariance_matrix = V * Sigma
65 | w = tfd.MultivariateNormalFullCovariance(loc=mu, covariance_matrix=covariance_matrix).sample(seed=w_key)
66 |
67 | return w
68 |
69 | def choose_action(self, key, bel, context):
70 | # Thompson sampling strategy
71 | # Could also use epsilon greedy or UCB
72 | w = self.sample_params(key, bel)
73 | predicted_reward = jnp.einsum("m,km->k", context, w)
74 | action = predicted_reward.argmax()
75 | return action
76 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.csv
2 | *.png
3 | *.DS_Store
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102 | __pypackages__/
103 |
104 | # Celery stuff
105 | celerybeat-schedule
106 | celerybeat.pid
107 |
108 | # SageMath parsed files
109 | *.sage.py
110 |
111 | # Environments
112 | .env
113 | .venv
114 | env/
115 | venv/
116 | ENV/
117 | env.bak/
118 | venv.bak/
119 |
120 | # Spyder project settings
121 | .spyderproject
122 | .spyproject
123 |
124 | # Rope project settings
125 | .ropeproject
126 |
127 | # mkdocs documentation
128 | /site
129 |
130 | # mypy
131 | .mypy_cache/
132 | .dmypy.json
133 | dmypy.json
134 |
135 | # Pyre type checker
136 | .pyre/
137 |
138 | # pytype static type analyzer
139 | .pytype/
140 |
141 | # Cython debug symbols
142 | cython_debug/
143 |
--------------------------------------------------------------------------------
/bandits/training.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from functools import partial
4 |
5 |
6 | def reshape_vvmap(x):
7 | """
8 | Merge the two leading axes of an array into one.
9 | The input is taken to be the output of a nested vmap or pmap.
10 | """
11 | shape_orig = x.shape
12 | shape_new = (shape_orig[0] * shape_orig[1], *shape_orig[2:])
13 | return jnp.reshape(x, shape_new)
14 |
15 |
16 | def step(bel, t, key_base, bandit, env):
17 | key = jax.random.fold_in(key_base, t)
18 | context = env.get_context(t)
19 |
20 | action = bandit.choose_action(key, bel, context)
21 | reward = env.get_reward(t, action)
22 | bel = bandit.update_bel(bel, context, action, reward)
23 |
24 | hist = {
25 | "actions": action,
26 | "rewards": reward
27 | }
28 |
29 | return bel, hist
30 |
31 |
32 | def warmup_bandit(key, bandit, env, npulls):
33 | warmup_contexts, warmup_states, warmup_actions, warmup_rewards = env.warmup(npulls)
34 | bel = bandit.init_bel(key, warmup_contexts, warmup_states, warmup_actions, warmup_rewards)
35 |
36 | hist = {
37 | "states": warmup_states,
38 | "actions": warmup_actions,
39 | "rewards": warmup_rewards,
40 | }
41 | return bel, hist
42 |
43 |
44 | def run_bandit(key, bel, bandit, env, t_start=0):
45 | step_part = partial(step, key_base=key, bandit=bandit, env=env)
46 | steps = jnp.arange(t_start, env.n_steps)
47 | bel, hist = jax.lax.scan(step_part, bel, steps)
48 | return bel, hist
49 |
50 |
51 | def run_bandit_trials(key, bel, bandit, env, t_start=0, n_trials=1):
52 | keys = jax.random.split(key, n_trials)
53 | run_partial = partial(run_bandit, bel=bel, bandit=bandit, env=env, t_start=t_start)
54 | run_partial = jax.vmap(run_partial)
55 |
56 | bel, hist = run_partial(keys)
57 | return bel, hist
58 |
59 | def run_bandit_trials_pmap(key, bel, bandit, env, t_start=0, n_trials=1):
60 | keys = jax.random.split(key, n_trials)
61 | run_partial = partial(run_bandit, bel=bel, bandit=bandit, env=env, t_start=t_start)
62 | run_partial = jax.pmap(run_partial)
63 |
64 | bel, hist = run_partial(keys)
65 | return bel, hist
66 |
67 |
68 | def run_bandit_trials_multiple(key, bel, bandit, env, t_start, n_trials):
69 | """
70 | Run vmap over multiple trials, and pmap over multiple devices
71 | """
72 | ndevices = jax.local_device_count()
73 | nsamples_per_device = n_trials // ndevices
74 | keys = jax.random.split(key, ndevices)
75 | run_partial = partial(run_bandit_trials, bel=bel, bandit=bandit, env=env, t_start=t_start, n_trials=nsamples_per_device)
76 | run_partial = jax.pmap(run_partial)
77 |
78 | bel, hist = run_partial(keys)
79 | hist = jax.tree_map(reshape_vvmap, hist)
80 | bel = jax.tree_map(reshape_vvmap, bel)
81 |
82 | return bel, hist
83 |
--------------------------------------------------------------------------------
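For orientation, a hedged end-to-end sketch of how these helpers compose with an agent and an environment. It mirrors the experiment scripts; the import paths assume the `bandits/` folder is the working directory, and the availability of the statlog data under `bandit-data/` is an assumption.

```python
import jax
from environments.tabular_env import TabularEnvironment
from agents.linear_bandit import LinearBandit
from training import warmup_bandit, run_bandit_trials

key = jax.random.PRNGKey(314)
key_env, key_warmup, key_run = jax.random.split(key, 3)

env = TabularEnvironment(key_env, ntrain=500, name="statlog", intercept=True)
bandit = LinearBandit(env.n_features, env.n_arms)

# Pull every arm `npulls` times to initialise the belief, then run
# Thompson sampling for the remaining steps over 5 vmapped trials.
npulls = 20
bel, warmup_hist = warmup_bandit(key_warmup, bandit, env, npulls)
t_start = npulls * env.n_arms
bel, hist = run_bandit_trials(key_run, bel, bandit, env, t_start=t_start, n_trials=5)

print(hist["rewards"].shape)  # (n_trials, env.n_steps - t_start)
```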
/bandits/agents/linear_bandit.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax import lax
3 | from jax import random
4 |
5 | from tensorflow_probability.substrates import jax as tfp
6 |
7 | tfd = tfp.distributions
8 |
9 |
10 | class LinearBandit:
11 | def __init__(self, num_features, num_arms, eta=6.0, lmbda=0.25):
12 | self.num_features = num_features
13 | self.num_arms = num_arms
14 | self.eta = eta
15 | self.lmbda = lmbda
16 |
17 | def init_bel(self, key, contexts, states, actions, rewards):
18 | mu = jnp.zeros((self.num_arms, self.num_features))
19 | Sigma = 1. / self.lmbda * jnp.eye(self.num_features) * jnp.ones((self.num_arms, 1, 1))
20 | a = self.eta * jnp.ones((self.num_arms,))
21 | b = self.eta * jnp.ones((self.num_arms,))
22 |
23 | initial_bel = (mu, Sigma, a, b)
24 |
25 | def update(bel, cur): # could do batch update
26 | context, action, reward = cur
27 | bel = self.update_bel(bel, context, action, reward)
28 | return bel, None
29 |
30 | bel, _ = lax.scan(update, initial_bel, (contexts, actions, rewards))
31 | return bel
32 |
33 | def update_bel(self, bel, context, action, reward):
34 | mu, Sigma, a, b = bel
35 |
36 | mu_k, Sigma_k = mu[action], Sigma[action]
37 | Lambda_k = jnp.linalg.inv(Sigma_k)
38 | a_k, b_k = a[action], b[action]
39 |
40 | # weight params
41 | Lambda_update = jnp.outer(context, context) + Lambda_k
42 | Sigma_update = jnp.linalg.inv(Lambda_update)
43 | mu_update = Sigma_update @ (Lambda_k @ mu_k + context * reward)
44 | # noise params
45 | a_update = a_k + 1 / 2
46 | b_update = b_k + (reward ** 2 + mu_k.T @ Lambda_k @ mu_k - mu_update.T @ Lambda_update @ mu_update) / 2
47 |
48 | # Update only the chosen action at time t
49 | mu = mu.at[action].set(mu_update)
50 | Sigma = Sigma.at[action].set(Sigma_update)
51 | a = a.at[action].set(a_update)
52 | b = b.at[action].set(b_update)
53 |
54 | bel = (mu, Sigma, a, b)
55 |
56 | return bel
57 |
58 | def sample_params(self, key, bel):
59 | mu, Sigma, a, b = bel
60 |
61 | sigma_key, w_key = random.split(key, 2)
62 | sigma2_samp = tfd.InverseGamma(concentration=a, scale=b).sample(seed=sigma_key)
63 | covariance_matrix = sigma2_samp[:, None, None] * Sigma
64 | w = tfd.MultivariateNormalFullCovariance(loc=mu, covariance_matrix=covariance_matrix).sample(
65 | seed=w_key)
66 | return w
67 |
68 | def choose_action(self, key, bel, context):
69 | # Thompson sampling strategy
70 | # Could also use epsilon greedy or UCB
71 | w = self.sample_params(key, bel)
72 | predicted_reward = jnp.einsum("m,km->k", context, w)
73 | action = predicted_reward.argmax()
74 | return action
75 |
--------------------------------------------------------------------------------
/bandits/agents/linear_bandit_wide.py:
--------------------------------------------------------------------------------
1 | # Linear bandit with a single linear layer applied to a "wide" feature vector phi(s,a),
2 | # so reward = w' * phi(s,a). If the vector is block-structured one-hot, where we put
3 | # phi(s) into slot/block a, then w' phi(s,a) = w_a' phi(s), which is a standard linear model.
4 | # For example, suppose phi(s) = [s1, s2] and we have 3 actions.
5 | # Then phi(s,a=1) = [s1 s2 0 0 0 0] and phi(s,a=3) = [0 0 0 0 s1 s2].
6 | # Similarly, let w = [w11 w12 w21 w22 w31 w32], where w(i,j) is the weight for action i, feature j.
7 | # Then w' phi(s,a=1) = w11*s1 + w12*s2 = w_1' phi(s).
8 |
9 | import jax.numpy as jnp
10 | from jax import vmap
11 | from jax.random import split
12 | from jax.nn import one_hot
13 | from jax.lax import scan
14 |
15 | from .agent_utils import NIGupdate
16 |
17 | from tensorflow_probability.substrates import jax as tfp
18 |
19 | tfd = tfp.distributions
20 |
21 |
22 | class LinearBanditWide:
23 | def __init__(self, num_features, num_arms, eta=6.0, lmbda=0.25):
24 | self.num_features = num_features
25 | self.num_arms = num_arms
26 | self.eta = eta
27 | self.lmbda = lmbda
28 |
29 | def widen(self, context, action):
30 | phi = jnp.zeros((self.num_arms, self.num_features))
31 | phi = phi.at[action].set(context)
32 | return phi.flatten()
33 |
34 | def init_bel(self, key, contexts, states, actions, rewards):
35 | mu = jnp.zeros((self.num_arms * self.num_features))
36 | Sigma = 1 / self.lmbda * jnp.eye(self.num_features * self.num_arms)
37 | a = self.eta * jnp.ones((self.num_arms * self.num_features,))
38 | b = self.eta * jnp.ones((self.num_arms * self.num_features,))
39 |
40 | initial_bel = (mu, Sigma, a, b)
41 |
42 | def update(bel, cur): # could do batch update
43 | phi, reward = cur
44 | bel = NIGupdate(bel, phi, reward)
45 | return bel, None
46 |
47 | phis = vmap(self.widen)(contexts, actions)
48 | bel, _ = scan(update, initial_bel, (phis, rewards))
49 | return bel
50 |
51 | def update_bel(self, bel, context, action, reward):
52 | phi = self.widen(context, action)
53 | bel = NIGupdate(bel, phi, reward)
54 | return bel
55 |
56 | def sample_params(self, key, bel):
57 | mu, Sigma, a, b = bel
58 |
59 | sigma_key, w_key = split(key, 2)
60 | sigma2_samp = tfd.InverseGamma(concentration=a, scale=b).sample(seed=sigma_key)
61 | covariance_matrix = sigma2_samp * Sigma
62 | w = tfd.MultivariateNormalFullCovariance(loc=mu, covariance_matrix=covariance_matrix).sample(
63 | seed=w_key)
64 | return w
65 |
66 | def choose_action(self, key, bel, context):
67 | w = self.sample_params(key, bel)
68 |
69 | def get_reward(action):
70 | # score this action via its block one-hot feature phi(s, a)
71 | phi = self.widen(context, action)
72 | reward = phi @ w
73 | return reward
74 |
75 | actions = jnp.arange(self.num_arms)
76 | rewards = vmap(get_reward)(actions)
77 | action = jnp.argmax(rewards)
78 |
79 | return action
80 |
--------------------------------------------------------------------------------
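A tiny illustrative sketch (not part of the repository) of the block one-hot construction described in the header comment above: with 2 features and 3 arms, `widen` places phi(s) into the slot of the chosen action. The import path assumes the `bandits/` folder is on `sys.path`.

```python
import jax.numpy as jnp
from agents.linear_bandit_wide import LinearBanditWide  # assumes cwd == bandits/

bandit = LinearBanditWide(num_features=2, num_arms=3)
context = jnp.array([5.0, 7.0])   # phi(s) = [s1, s2]

print(bandit.widen(context, 0))   # [5. 7. 0. 0. 0. 0.]
print(bandit.widen(context, 2))   # [0. 0. 0. 0. 5. 7.]
```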
/bandits/agents/neural_greedy.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax.random import split
4 | from rebayes.sgd_filter import replay_sgd
5 | from bandits.agents.base import BanditAgent
6 |
7 |
8 | class NeuralGreedyBandit(BanditAgent):
9 | def __init__(
10 | self, num_features, num_arms, model, memory_size, tx,
11 | epsilon, n_inner=100
12 | ):
13 | self.num_features = num_features
14 | self.num_arms = num_arms
15 | self.model = model
16 | self.memory_size = memory_size
17 | self.tx = tx
18 | self.epsilon = epsilon
19 | self.n_inner = int(n_inner)
20 |
21 | def init_bel(self, key, contexts, states, actions, rewards):
22 | _, dim_in = contexts.shape
23 | X_dummy = jnp.ones((1, dim_in))
24 | params = self.model.init(key, X_dummy)
25 | out = self.model.apply(params, X_dummy)
26 | dim_out = out.shape[-1]
27 |
28 | def apply_fn(params, xs):
29 | return self.model.apply(params, xs)
30 |
31 | def predict_rewards(params, contexts):
32 | return self.model.apply(params, contexts)
33 |
34 | agent = replay_sgd.FifoSGD(
35 | lossfn=lossfn_rmse_extra_dim,
36 | apply_fn=apply_fn,
37 | tx=self.tx,
38 | buffer_size=self.memory_size,
39 | dim_features=dim_in + 1, # +1 for the action
40 | dim_output=1,
41 | n_inner=self.n_inner
42 | )
43 |
44 | bel = agent.init_bel(params, None)
45 | self.agent = agent
46 | self.predict_rewards = predict_rewards
47 |
48 | return bel
49 |
50 | def sample_params(self, key, bel):
51 | return bel.params
52 |
53 | def update_bel(self, bel, context, action, reward):
54 | xs = jnp.r_[context, action]
55 | bel = self.agent.update_state(bel, xs, reward)
56 | return bel
57 |
58 | def choose_action(self, key, bel, context):
59 | key, key_action = split(key)
60 | greedy = jax.random.bernoulli(key, 1 - self.epsilon)
61 |
62 | def explore():
63 | action = jax.random.randint(key_action, shape=(), minval=0, maxval=self.num_arms)
64 | return action
65 |
66 | def exploit():
67 | params = self.sample_params(key, bel)
68 | predicted_rewards = self.predict_rewards(params, context)
69 | action = predicted_rewards.argmax(axis=-1)
70 | return action
71 |
72 | action = jax.lax.cond(greedy == 1, exploit, explore)
73 | return action
74 |
75 |
76 | def lossfn_rmse_extra_dim(params, counter, xs, y, apply_fn):
77 | """
78 | Loss function for regression problems.
79 | The input xs carries an extra trailing dimension holding the action index.
80 | """
81 | X = xs[..., :-1]
82 | action = xs[..., -1].astype(jnp.int32)
83 | buffer_size = X.shape[0]
84 | ix_slice = jnp.arange(buffer_size)
85 | yhat = apply_fn(params, X)[ix_slice, action].ravel()
86 | y = y.ravel()
87 | err = jnp.power(y - yhat, 2)
88 | loss = (err * counter).sum() / counter.sum()
89 | return loss
90 |
--------------------------------------------------------------------------------
/bandits/scripts/run_experiments.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "!pip install -qqq ml-collections"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {
16 | "tags": []
17 | },
18 | "outputs": [],
19 | "source": [
20 | "import os\n",
21 | "os.chdir(\"..\")\n",
22 | "\n",
23 | "import jax\n",
24 | "import ml_collections\n",
25 | "\n",
26 | "import pandas as pd\n",
27 | "\n",
28 | "import glob\n",
29 | "from datetime import datetime\n",
30 | "\n",
31 | "import scripts.movielens_exp as movielens_run\n",
32 | "import scripts.mnist_exp as mnist_run\n",
33 | "import scripts.tabular_exp as tabular_run\n",
34 | "import scripts.tabular_subspace_exp as tabular_sub_run\n",
35 | "\n",
36 | "print(jax.device_count())"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "def get_config(filepath):\n",
46 | " \"\"\"Get the default hyperparameter configuration.\"\"\"\n",
47 | " config = ml_collections.ConfigDict()\n",
48 | " config.filepath = filepath\n",
49 | " config.ntrials = 10\n",
50 | " return config"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "timestamp = datetime.timestamp(datetime.now())"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "metadata": {},
65 | "source": [
66 | "# Run tabular experiments"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "tabular_filename = f\"./results/tabular_results_{timestamp}.csv\"\n",
76 | "config = get_config(tabular_filename)\n",
77 | "tabular_run.main(config)"
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "metadata": {},
83 | "source": [
84 | "# Run MNIST experiments"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "mnist_filename = f\"./results/mnist_results_{timestamp}.csv\"\n",
94 | "config = get_config(mnist_filename)\n",
95 | "mnist_run.main(config)"
96 | ]
97 | },
98 | {
99 | "cell_type": "markdown",
100 | "metadata": {},
101 | "source": [
102 | "# Run movielens experiments"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": null,
108 | "metadata": {},
109 | "outputs": [],
110 | "source": [
111 | "movielens_filename = f\"./results/movielens_results_{timestamp}.csv\"\n",
112 | "config = get_config(movielens_filename)\n",
113 | "movielens_run.main(config)"
114 | ]
115 | },
116 | {
117 | "cell_type": "markdown",
118 | "metadata": {},
119 | "source": [
120 | "# Run tabular subspace experiment"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "tabular_sub_filename = f\"./results/tabular_subspace_results_{timestamp}.csv\"\n",
130 | "config = get_config(tabular_sub_filename)\n",
131 | "tabular_sub_run.main(config)"
132 | ]
133 | }
134 | ],
135 | "metadata": {
136 | "interpreter": {
137 | "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
138 | },
139 | "kernelspec": {
140 | "display_name": "probml",
141 | "language": "python",
142 | "name": "probml"
143 | },
144 | "language_info": {
145 | "codemirror_mode": {
146 | "name": "ipython",
147 | "version": 3
148 | },
149 | "file_extension": ".py",
150 | "mimetype": "text/x-python",
151 | "name": "python",
152 | "nbconvert_exporter": "python",
153 | "pygments_lexer": "ipython3",
154 | "version": "3.9.7"
155 | }
156 | },
157 | "nbformat": 4,
158 | "nbformat_minor": 4
159 | }
160 |
--------------------------------------------------------------------------------
/bandits/scripts/training_utils.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import flax.linen as nn
4 |
5 | import warnings
6 | warnings.filterwarnings("ignore")
7 |
8 | class MLP(nn.Module):
9 | num_arms: int
10 |
11 | @nn.compact
12 | def __call__(self, x):
13 | x = nn.relu(nn.Dense(50, name="last_layer")(x))
14 | x = nn.Dense(self.num_arms)(x)
15 | return x
16 |
17 |
18 | class MLPWide(nn.Module):
19 | num_arms: int
20 |
21 | @nn.compact
22 | def __call__(self, x):
23 | x = nn.relu(nn.Dense(200)(x))
24 | x = nn.relu(nn.Dense(200, name="last_layer")(x))
25 | x = nn.Dense(self.num_arms)(x)
26 | return x
27 |
28 |
29 | class LeNet5(nn.Module):
30 | num_arms: int
31 |
32 | @nn.compact
33 | def __call__(self, x):
34 | x = x if len(x.shape) > 1 else x[None, :]
35 | x = x.reshape((x.shape[0], 28, 28, 1))
36 | x = nn.Conv(features=6, kernel_size=(5, 5))(x)
37 | x = nn.relu(x)
38 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding="VALID")
39 | x = nn.Conv(features=16, kernel_size=(5, 5), padding="VALID")(x)
40 | x = nn.relu(x)
41 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding="VALID")
42 | x = x.reshape((x.shape[0], -1)) # Flatten
43 | x = nn.Dense(features=120)(x)
44 | x = nn.relu(x)
45 | x = nn.Dense(features=84, name="last_layer")(x)  # penultimate feature layer used by the neural-linear agents
46 | x = nn.relu(x)
47 | x = nn.Dense(features=self.num_arms)(x)
48 | return x.squeeze()
49 |
50 |
51 | def train(key, bandit_cls, env, npulls, ntrials, bandit_kwargs, neural=True):
52 | # TODO: deprecate neural flag
53 | nsteps, nfeatures = env.contexts.shape
54 | _, narms = env.labels_onehot.shape
55 | bandit = bandit_cls(nfeatures, narms, **bandit_kwargs)
56 |
57 | warmup_contexts, warmup_states, warmup_actions, warmup_rewards = env.warmup(npulls)
58 |
59 | key, mykey = jax.random.split(key)
60 | bel = bandit.init_bel(mykey, warmup_contexts, warmup_states, warmup_actions, warmup_rewards)
61 | warmup = (warmup_contexts, warmup_states, warmup_actions, warmup_rewards)
62 |
63 | def single_trial(key):
64 | _, _, rewards = run_bandit(key, bandit, bel, env, warmup, nsteps=nsteps, neural=neural)
65 | return rewards
66 |
67 | if ntrials > 1:
68 | keys = jax.random.split(key, ntrials)
69 | rewards_trace = jax.vmap(single_trial)(keys)
70 | else:
71 | rewards_trace = single_trial(key)
72 |
73 | return warmup_rewards, rewards_trace, env.opt_rewards
74 |
75 |
76 | def run_bandit(key, bandit, bel, env, warmup, nsteps, neural=True):
77 | def step(bel, cur):
78 | mykey, t = cur
79 | context = env.get_context(t)
80 |
81 | action = bandit.choose_action(mykey, bel, context)
82 | reward = env.get_reward(t, action)
83 | bel = bandit.update_bel(bel, context, action, reward)
84 |
85 | return bel, (context, action, reward)
86 |
87 | warmup_contexts, _, warmup_actions, warmup_rewards = warmup
88 | nwarmup = len(warmup_rewards)
89 |
90 | steps = jnp.arange(nsteps - nwarmup) + nwarmup
91 | keys = jax.random.split(key, nsteps - nwarmup)
92 |
93 | if neural:
94 | bandit.init_contexts_and_states(env.contexts[steps], env.labels_onehot[steps])
95 | mu, Sigma, a, b, params, _ = bel
96 | bel = (mu, Sigma, a, b, params, 0)
97 |
98 | _, (contexts, actions, rewards) = jax.lax.scan(step, bel, (keys, steps))
99 |
100 | contexts = jnp.vstack([warmup_contexts, contexts])
101 | actions = jnp.append(warmup_actions, actions)
102 | rewards = jnp.append(warmup_rewards, rewards)
103 |
104 | return contexts, actions, rewards
105 |
106 |
107 | def summarize_results(warmup_rewards, rewards):
108 | """
109 | Print a summary of running a Bandit algorithm for a number of runs
110 | """
111 | warmup_reward = warmup_rewards.sum()
112 | rewards = rewards.sum(axis=-1)
113 | r_mean = rewards.mean()
114 | r_std = rewards.std()
115 | r_total = r_mean + warmup_reward
116 |
117 | print(f"Expected Reward : {r_total:0.2f} ± {r_std:0.2f}")
118 | return r_total, r_std
119 |
--------------------------------------------------------------------------------
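A hedged sketch of how the experiment scripts drive `train` (it mirrors `tabular_subspace_exp.py`; the hyperparameters below are placeholders rather than the paper's settings, and the import paths assume the `bandits/` folder is the working directory):

```python
import jax
from environments.tabular_env import TabularEnvironment
from agents.linear_bandit import LinearBandit
from scripts.training_utils import train, summarize_results

key = jax.random.PRNGKey(0)
key, env_key, train_key = jax.random.split(key, 3)

env = TabularEnvironment(env_key, ntrain=500, name="statlog", intercept=True)

# A linear agent needs no extra kwargs; neural agents pass model/optimizer settings here.
warmup_rewards, rewards_trace, opt_rewards = train(
    train_key, LinearBandit, env, npulls=20, ntrials=2,
    bandit_kwargs={}, neural=False)

summarize_results(warmup_rewards, rewards_trace)
```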
/bandits/scripts/thompson_sampling_bernoulli.py:
--------------------------------------------------------------------------------
1 | # Resolution of a Multi-Armed Bandit problem
2 | # using Thompson Sampling.
3 | # Author: Gerardo Durán-Martín (@gerdm)
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | import matplotlib.pyplot as plt
8 | from jax import random
9 | from jax.nn import one_hot
10 | from jax.scipy.stats import beta
11 | from functools import partial
12 | import matplotlib.animation as animation
13 |
14 |
15 | class BetaBernoulliBandits:
16 | def __init__(self, K):
17 | self.K = K
18 |
19 | def sample(self, key, params):
20 | alphas = params["alpha"]
21 | betas = params["beta"]
22 | params_sample = random.beta(key, alphas, betas)
23 | return params_sample
24 |
25 | def predict_rewards(self, params_sample):
26 | return params_sample
27 |
28 | def update(self, action, params, reward):
29 | alphas = params["alpha"]
30 | betas = params["beta"]
31 | # Update policy distribution
32 | ind_vector = one_hot(action, self.K)
33 | alphas_posterior = alphas + ind_vector * reward
34 | betas_posterior = betas + ind_vector * (1 - reward)
35 | return {
36 | "alpha": alphas_posterior,
37 | "beta": betas_posterior
38 | }
39 |
40 |
41 | def true_reward(key, action, mean_rewards):
42 | reward = random.bernoulli(key, mean_rewards[action])
43 | return reward
44 |
45 |
46 | def thompson_sampling_step(model_params, key, model, environment):
47 | """
48 | Context-free implementation of the Thompson sampling algorithm.
49 | This implementation considers a single step
50 |
51 | Parameters
52 | ----------
53 | model_params: dict
54 | key: jax.random.PRNGKey
55 | model: instance of a Bandit model
56 | environment: function
57 | """
58 | key_sample, key_reward = random.split(key)
59 | params = model.sample(key_sample, model_params)
60 | pred_rewards = model.predict_rewards(params)
61 | action = pred_rewards.argmax()
62 | reward = environment(key_reward, action)
63 | model_params = model.update(action, model_params, reward)
64 | prob_arm = model_params["alpha"] / (model_params["alpha"] + model_params["beta"])
65 | return model_params, (model_params, prob_arm)
66 |
67 |
68 | if __name__ == "__main__":
69 | T = 200
70 | key = random.PRNGKey(31415)
71 | keys = random.split(key, T)
72 | mean_rewards = jnp.array([0.4, 0.5, 0.2, 0.9])
73 | K = len(mean_rewards)
74 | bbbandit = BetaBernoulliBandits(K)
75 | init_params = {"alpha": jnp.ones(K),
76 | "beta": jnp.ones(K)}
77 |
78 | environment = partial(true_reward, mean_rewards=mean_rewards)
79 | thompson_partial = partial(thompson_sampling_step,
80 | model=bbbandit,
81 | environment=environment)
82 | posteriors, (hist, prob_arm_hist) = jax.lax.scan(thompson_partial, init_params, keys)
83 |
84 | p_range = jnp.linspace(0, 1, 100)
85 | bandits_pdf_hist = beta.pdf(p_range[:, None, None], hist["alpha"][None, ...], hist["beta"][None, ...])
86 | colors = ["orange", "blue", "green", "red"]
87 | colors = [f"tab:{color}" for color in colors]
88 |
89 | _, n_steps, _ = bandits_pdf_hist.shape
90 | fig, ax = plt.subplots(1, 4, figsize=(13, 2))
91 | filepath = "./bandits.mp4"
92 |
93 | def animate(t):
94 | for k, (axi, color) in enumerate(zip(ax, colors)):
95 | axi.cla()
96 | bandit = bandits_pdf_hist[:, t, k]
97 | axi.plot(p_range, bandit, c=color)
98 | axi.set_xlim(0, 1)
99 | n_pos = hist["alpha"][t, k].item() - 1
100 | n_trials = hist["beta"][t, k].item() + n_pos - 1
101 | axi.set_title(f"t={t+1}\np={mean_rewards[k]:0.2f}\n{n_pos:.0f}/{n_trials:.0f}")
102 | plt.tight_layout()
103 | return ax
104 |
105 | ani = animation.FuncAnimation(fig, animate, frames=n_steps)
106 | ani.save(filepath, dpi=300, bitrate=-1, fps=10)
107 |
108 | plt.plot(prob_arm_hist)
109 | plt.legend([f"mean reward: {reward:0.2f}" for reward in mean_rewards], loc="lower right")
110 | plt.savefig("beta-bernoulli-thompson-sampling.pdf")
111 | plt.show()
112 |
--------------------------------------------------------------------------------
/bandits/agents/low_rank_filter_bandit.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax.flatten_util import ravel_pytree
4 | from rebayes.low_rank_filter import lofi
5 | from bandits.agents.base import BanditAgent
6 | from rebayes.utils.sampling import sample_dlr_single
7 | from tensorflow_probability.substrates import jax as tfp
8 |
9 | tfd = tfp.distributions
10 |
11 | class LowRankFilterBandit(BanditAgent):
12 | """
13 | Regression bandit with low-rank filter.
14 | We consider a single neural network with k
15 | outputs corresponding to the k arms.
16 | """
17 | def __init__(self, num_features, num_arms, model, memory_size, emission_covariance,
18 | initial_covariance, dynamics_weights, dynamics_covariance):
19 | self.num_features = num_features
20 | self.num_arms = num_arms
21 | self.model = model
22 | self.memory_size = memory_size
23 |
24 | self.emission_covariance = emission_covariance
25 | self.initial_covariance = initial_covariance
26 | self.dynamics_weights = dynamics_weights
27 | self.dynamics_covariance = dynamics_covariance
28 |
29 | def init_bel(self, key, contexts, states, actions, rewards):
30 | _, dim_in = contexts.shape
31 | params = self.model.init(key, jnp.ones((1, dim_in)))
32 | flat_params, recfn = ravel_pytree(params)
33 |
34 | def apply_fn(flat_params, xs):
35 | context, action = xs
36 | return self.model.apply(recfn(flat_params), context)[action, None]
37 |
38 | def predict_rewards(flat_params, context):
39 | return self.model.apply(recfn(flat_params), context)
40 |
41 | agent = lofi.RebayesLoFiDiagonal(
42 | dynamics_weights=self.dynamics_weights,
43 | dynamics_covariance=self.dynamics_covariance,
44 | emission_mean_function=apply_fn,
45 | emission_cov_function=lambda m, x: self.emission_covariance,
46 | adaptive_emission_cov=False,
47 | dynamics_covariance_inflation_factor=0.0,
48 | memory_size=self.memory_size,
49 | steady_state=False,
50 | emission_dist=tfd.Normal
51 | )
52 | bel = agent.init_bel(flat_params, self.initial_covariance)
53 | self.agent = agent
54 | self.predict_rewards = predict_rewards
55 |
56 | return bel
57 |
58 | def sample_params(self, key, bel):
59 | params_samp = self.agent.sample_state(bel, key, 1).ravel()
60 | return params_samp
61 |
62 | def update_bel(self, bel, context, action, reward):
63 | xs = (context, action)
64 | bel = self.agent.update_state(bel, xs, reward)
65 | return bel
66 |
67 |
68 | class LowRankGreedy(LowRankFilterBandit):
69 | """
70 | Low-rank filter with greedy action selection.
71 | """
72 | def __init__(self, num_features, num_arms, model, memory_size, emission_covariance,
73 | initial_covariance, dynamics_weights, dynamics_covariance, epsilon):
74 | super().__init__(num_features, num_arms, model, memory_size, emission_covariance,
75 | initial_covariance, dynamics_weights, dynamics_covariance)
76 | self.epsilon = epsilon
77 |
78 | def choose_action(self, key, bel, context):
79 | # Epsilon-greedy action selection, written with lax.cond so it works under jit/vmap
80 | key, key_action = jax.random.split(key)
81 | greedy = jax.random.bernoulli(key, 1 - self.epsilon)
82 |
83 | def explore():
84 | action = jax.random.randint(key_action, shape=(), minval=0, maxval=self.num_arms)
85 | return action
86 |
87 | def exploit():
88 | params = bel.mean
89 | predicted_rewards = self.predict_rewards(params, context)
90 | action = predicted_rewards.argmax(axis=-1)
91 | return action
92 |
93 | action = jax.lax.cond(greedy == 1, exploit, explore)
94 | return action
--------------------------------------------------------------------------------
/bandits/scripts/tabular_subspace_exp.py:
--------------------------------------------------------------------------------
1 | from jax.random import split, PRNGKey
2 |
3 | import optax
4 | import pandas as pd
5 |
6 | import argparse
7 | from time import time
8 |
9 | from environments.tabular_env import TabularEnvironment
10 | from agents.ekf_subspace import SubspaceNeuralBandit
11 |
12 | from .training_utils import train, MLP, summarize_results
13 | from .mnist_exp import mapping, method_ordering
14 |
15 |
16 | def main(config):
17 | # Tabular datasets
18 | key = PRNGKey(314)
19 | key, shuttle_key, covertype_key, adult_key = split(key, 4)
20 | ntrain = 5000
21 |
22 | shuttle_env = TabularEnvironment(shuttle_key, ntrain=ntrain, name='statlog', intercept=True)
23 | covertype_env = TabularEnvironment(covertype_key, ntrain=ntrain, name='covertype', intercept=True)
24 | adult_env = TabularEnvironment(adult_key, ntrain=ntrain, name='adult', intercept=True)
25 | environments = {"shuttle": shuttle_env, "covertype": covertype_env, "adult": adult_env}
26 |
27 | learning_rate = 0.05
28 | momentum = 0.9
29 |
30 | # Subspace Neural Bandit with SVD
31 | npulls, nwarmup = 20, 2000
32 | observation_noise = 0.0
33 | prior_noise_variance = 1e-4
34 | nepochs = 1000
35 | random_projection = False
36 |
37 | ekf_sub_svd = {"opt": optax.sgd(learning_rate, momentum), "prior_noise_variance": prior_noise_variance,
38 | "nwarmup": nwarmup, "nepochs": nepochs,
39 | "observation_noise": observation_noise,
40 | "random_projection": random_projection}
41 |
42 | # Subspace Neural Bandit without SVD
43 | ekf_sub_rnd = ekf_sub_svd.copy()
44 | ekf_sub_rnd["random_projection"] = True
45 |
46 | bandits = {"EKF Subspace SVD": {"kwargs": ekf_sub_svd,
47 | "bandit": SubspaceNeuralBandit
48 | },
49 | "EKF Subspace RND": {"kwargs": ekf_sub_rnd,
50 | "bandit": SubspaceNeuralBandit
51 | }
52 | }
53 |
54 | results = []
55 | subspace_dimensions = [2, 3, 4, 5, 10, 15, 20, 30, 40, 50, 60, 100, 150, 200, 300, 400, 500]
56 | model_name = "MLP1"
57 | for env_name, env in environments.items():
58 | print("Environment : ", env_name)
59 | num_arms = env.labels_onehot.shape[-1]
60 | for subspace_dim in subspace_dimensions:
61 | model = MLP(num_arms)
62 | for bandit_name, properties in bandits.items():
63 | properties["kwargs"]["n_components"] = subspace_dim
64 | properties["kwargs"]["model"] = model
65 | key, mykey = split(key)
66 | print(f"\tBandit : {bandit_name}")
67 | start = time()
68 | warmup_rewards, rewards_trace, opt_rewards = train(mykey, properties["bandit"], env, npulls,
69 | config.ntrials,
70 | properties["kwargs"], neural=False)
71 |
72 | rtotal, rstd = summarize_results(warmup_rewards, rewards_trace)
73 | end = time()
74 | print(f"\t\tTime : {end - start}")
75 | results.append((env_name, bandit_name, model_name, subspace_dim, end - start, rtotal, rstd))
76 |
77 | df = pd.DataFrame(results)
78 | df = df.rename(columns={0: "Dataset", 1: "Method", 2: "Model", 3: "Subspace Dim", 4: "Time", 5: "Reward", 6: "Std"})
79 |
80 | df["Method"] = df["Method"].apply(lambda v: mapping[v])
81 |
82 | df["Subspace Dim"] = df['Subspace Dim'].astype(int)
83 | df["Reward"] = df['Reward'].astype(float)
84 | df["Time"] = df['Time'].astype(float)
85 | df["Std"] = df['Std'].astype(float)
86 |
87 | df["Rank"] = df["Method"].apply(lambda v: method_ordering[v])
88 | df.to_csv(config.filepath)
89 |
90 |
91 | if __name__ == "__main__":
92 | parser = argparse.ArgumentParser()
93 | parser.add_argument('--ntrials', type=int, nargs='?', const=10, default=10)
94 | filepath = "bandits/results/tabular_subspace_results.csv"
95 | parser.add_argument('--filepath', type=str, nargs='?', const=filepath, default=filepath)
96 |
97 | # Parse the argument
98 | args = parser.parse_args()
99 | main(args)
100 |
--------------------------------------------------------------------------------
/bandits/agents/neural_linear_bandit_wide.py:
--------------------------------------------------------------------------------
1 | # reward = w' * phi(s,a; theta), where theta is learned
2 |
3 | import jax.numpy as jnp
4 | from jax import vmap
5 | from jax.random import split
6 | from jax.nn import one_hot
7 | from jax.lax import scan, cond
8 |
9 | import optax
10 |
11 | from flax.training import train_state
12 |
13 | from .agent_utils import NIGupdate, train
14 | from scripts.training_utils import MLP
15 |
16 | from tensorflow_probability.substrates import jax as tfp
17 |
18 | tfd = tfp.distributions
19 |
20 |
21 | class NeuralLinearBanditWide:
22 | def __init__(self, num_features, num_arms, model=None, opt=optax.adam(learning_rate=1e-2), eta=6.0, lmbda=0.25,
23 | update_step_mod=100, batch_size=5000, nepochs=3000):
24 | self.num_features = num_features
25 | self.num_arms = num_arms
26 |
27 | if model is None:
28 | self.model = MLP(500, num_arms)
29 | else:
30 | try:
31 | self.model = model()
32 | except:
33 | self.model = model
34 |
35 | self.opt = opt
36 | self.eta = eta
37 | self.lmbda = lmbda
38 | self.update_step_mod = update_step_mod
39 | self.batch_size = batch_size
40 | self.nepochs = nepochs
41 |
42 | def init_bel(self, key, contexts, states, actions, rewards):
43 |
44 | key, mykey = split(key)
45 | initial_params = self.model.init(mykey, jnp.zeros((self.num_features,)))
46 | initial_train_state = train_state.TrainState.create(apply_fn=self.model.apply, params=initial_params,
47 | tx=self.opt)
48 |
49 | mu = jnp.zeros((self.num_arms, 500))
50 | Sigma = 1 / self.lmbda * jnp.eye(500) * jnp.ones((self.num_arms, 1, 1))
51 | a = self.eta * jnp.ones((self.num_arms,))
52 | b = self.eta * jnp.ones((self.num_arms,))
53 | t = 0
54 |
55 | def update(bel, x):
56 | context, action, reward = x
57 | return self.update_bel(bel, context, action, reward), None
58 |
59 | initial_bel = (mu, Sigma, a, b, initial_train_state, t)
60 | # store the warm-up data for later SGD refits of the representation
61 | self.init_contexts_and_states(contexts, states, actions, rewards)
62 | bel, _ = scan(update, initial_bel, (contexts, actions, rewards))
63 | return bel
64 |
65 | def featurize(self, params, x, feature_layer="last_layer"):
66 | _, inter = self.model.apply(params, x, capture_intermediates=True)
67 | Phi, *_ = inter["intermediates"][feature_layer]["__call__"]
68 | return Phi
69 |
70 | def widen(self, context, action):
71 | phi = jnp.zeros((self.num_arms, self.num_features))
72 |         phi = phi.at[action].set(context)  # JAX arrays are immutable; use a functional update
73 | return phi.flatten()
74 |
75 | def cond_update_params(self, t):
76 | return (t % self.update_step_mod) == 0
77 |
78 | def init_contexts_and_states(self, contexts, states, actions, rewards):
79 |         self.X = vmap(self.widen)(contexts, actions)
80 | self.Y = rewards
81 |
82 | def update_bel(self, bel, context, action, reward):
83 |
84 | _, _, _, _, state, t = bel
85 | sgd_params = (state, t)
86 |
87 |         phi = self.widen(context, action)
88 | state = cond(self.cond_update_params(t),
89 | lambda sgd_params: train(self.model, sgd_params[0], phi, reward,
90 | nepochs=self.nepochs, t=sgd_params[1]),
91 | lambda sgd_params: sgd_params[0], sgd_params)
92 | lin_bel = NIGupdate(bel, phi, reward)
93 | bel = (*lin_bel, state, t + 1)
94 |
95 | return bel
96 |
97 | def sample_params(self, key, bel):
98 | mu, Sigma, a, b, _, _ = bel
99 | sigma_key, w_key = split(key)
100 | sigma2 = tfd.InverseGamma(concentration=a, scale=b).sample(seed=sigma_key)
101 | covariance_matrix = sigma2[:, None, None] * Sigma
102 | w = tfd.MultivariateNormalFullCovariance(loc=mu, covariance_matrix=covariance_matrix).sample(seed=w_key)
103 | return w
104 |
105 | def choose_action(self, key, bel, context):
106 | w = self.sample_params(key, bel)
107 |
108 |         def get_reward(action):
109 |             # score this candidate arm: widen the context and take a linear function of the sampled weights
110 |             phi = self.widen(context, action)
111 |             reward = phi @ w
112 |             return reward
113 |
114 | actions = jnp.arange(self.num_arms)
115 | rewards = vmap(get_reward)(actions)
116 | action = jnp.argmax(rewards)
117 |
118 | return action
119 |
--------------------------------------------------------------------------------
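A small standalone sketch of the `widen` feature map used by this agent (per the header comment, the reward is linear in the widened features): the context is copied into the block of the chosen arm and the remaining blocks are zero, so a single flat weight vector acts as a separate linear head per arm. Shapes here are illustrative.

import jax.numpy as jnp
from jax import vmap

num_arms, num_features = 3, 4

def widen(context, action):
    # place the context in the block of the chosen arm, zeros elsewhere
    phi = jnp.zeros((num_arms, num_features))
    phi = phi.at[action].set(context)
    return phi.flatten()                      # shape (num_arms * num_features,)

contexts = jnp.arange(8.0).reshape(2, 4)      # two example contexts
actions = jnp.array([0, 2])
X = vmap(widen)(contexts, actions)            # shape (2, 12)
print(X)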
/bandits/agents/ekf_orig_diag.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax import jit
3 | from jax.flatten_util import ravel_pytree
4 |
5 | import optax
6 | from flax.training import train_state
7 |
8 | from .agent_utils import train
9 | from scripts.training_utils import MLP
10 | from jsl.nlds.diagonal_extended_kalman_filter import DiagonalExtendedKalmanFilter
11 |
12 | from tensorflow_probability.substrates import jax as tfp
13 |
14 | tfd = tfp.distributions
15 |
16 |
17 | class DiagonalNeuralBandit:
18 | def __init__(self, num_features, num_arms, model, opt, prior_noise_variance, nwarmup=1000, nepochs=1000,
19 | system_noise=0.0, observation_noise=1.0):
20 | """
21 | Subspace Neural Bandit implementation.
22 | Parameters
23 | ----------
24 | num_arms: int
25 | Number of bandit arms / number of actions
26 | environment : Environment
27 | The environment to be used.
28 | model : flax.nn.Module
29 | The flax model to be used for the bandits. Note that this model is independent of the
30 | model architecture. The only constraint is that the last layer should have the same
31 | number of outputs as the number of arms.
32 | learning_rate : float
33 | The learning rate for the optimizer used for the warmup phase.
34 | momentum : float
35 | The momentum for the optimizer used for the warmup phase.
36 | nepochs : int
37 | The number of epochs to be used for the warmup SGD phase.
38 | """
39 | self.num_features = num_features
40 | self.num_arms = num_arms
41 |
42 | if model is None:
43 | self.model = MLP(500, num_arms)
44 | else:
45 | try:
46 | self.model = model()
47 |             except Exception:  # model may already be an instance rather than a constructor
48 | self.model = model
49 |
50 | self.opt = opt
51 | self.prior_noise_variance = prior_noise_variance
52 | self.nwarmup = nwarmup
53 | self.nepochs = nepochs
54 | self.system_noise = system_noise
55 | self.observation_noise = observation_noise
56 |
57 | def init_bel(self, key, contexts, states, actions, rewards):
58 | initial_params = self.model.init(key, jnp.ones((1, self.num_features)))["params"]
59 | initial_train_state = train_state.TrainState.create(apply_fn=self.model.apply, params=initial_params,
60 | tx=self.opt)
61 |
62 | def loss_fn(params):
63 | pred_reward = self.model.apply({"params": params}, contexts)[:, actions.astype(int)]
64 | loss = optax.l2_loss(pred_reward, states[:, actions.astype(int)]).mean()
65 | return loss, pred_reward
66 |
67 | warmup_state, _ = train(initial_train_state, loss_fn=loss_fn, nepochs=self.nepochs)
68 |
69 | params_full_init, reconstruct_tree_params = ravel_pytree(warmup_state.params)
70 | nparams = params_full_init.size
71 |
72 | Q = jnp.ones(nparams) * self.system_noise
73 | R = self.observation_noise
74 |
75 | params_subspace_init = jnp.zeros(nparams)
76 | covariance_subspace_init = jnp.ones(nparams) * self.prior_noise_variance
77 |
78 | def predict_rewards(params, context):
79 | params_tree = reconstruct_tree_params(params)
80 | outputs = self.model.apply({"params": params_tree}, context)
81 | return outputs
82 |
83 | self.predict_rewards = predict_rewards
84 |
85 | def fz(params):
86 | return params
87 |
88 | def fx(params, context, action):
89 | return predict_rewards(params, context)[action, None]
90 |
91 | ekf = DiagonalExtendedKalmanFilter(fz, fx, Q, R)
92 | self.ekf = ekf
93 | bel = (params_subspace_init, covariance_subspace_init, 0)
94 |
95 | return bel
96 |
97 | def sample_params(self, key, bel):
98 | params_subspace, covariance_subspace, t = bel
99 |         normal_dist = tfd.Normal(loc=params_subspace, scale=jnp.sqrt(covariance_subspace))  # scale is a standard deviation; the belief stores variances
100 | params_subspace = normal_dist.sample(seed=key)
101 | return params_subspace
102 |
103 | def update_bel(self, bel, context, action, reward):
104 | xs = (reward, (context, action))
105 | bel, _ = jit(self.ekf.filter_step)(bel, xs)
106 | return bel
107 |
108 | def choose_action(self, key, bel, context):
109 | # Thompson sampling strategy
110 | # Could also use epsilon greedy or UCB
111 | w = self.sample_params(key, bel)
112 | predicted_reward = self.predict_rewards(w, context)
113 | action = predicted_reward.argmax()
114 | return action
115 |
--------------------------------------------------------------------------------
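For intuition, a compact sketch of a single diagonal-EKF step for the static-parameter model used above. This is not the jsl `DiagonalExtendedKalmanFilter` implementation, just the underlying recursion (identity dynamics, scalar reward observation, covariance kept diagonal), with a toy linear `fx` so the snippet runs on its own.

import jax
import jax.numpy as jnp

def diagonal_ekf_step(mu, var, context, action, reward, fx, q, r):
    # One EKF update with identity dynamics and a diagonal covariance approximation.
    # fx(params, context, action) returns the predicted reward for the chosen arm, shape (1,).
    var_pred = var + q                                        # predict step
    h = jax.grad(lambda p: fx(p, context, action)[0])(mu)     # Jacobian of the scalar observation
    s = jnp.sum(h * var_pred * h) + r                         # innovation variance
    gain = var_pred * h / s                                   # Kalman gain, one entry per parameter
    mu_new = mu + gain * (reward - fx(mu, context, action)[0])
    var_new = (1.0 - gain * h) * var_pred                     # diagonal of (I - K H) P
    return mu_new, var_new

# Tiny demo with a linear "network": the reward of arm a is params[a] . context.
def fx(params, context, action):
    return (params.reshape(3, 2) @ context)[action, None]

mu0, var0 = jnp.zeros(6), jnp.ones(6)
mu1, var1 = diagonal_ekf_step(mu0, var0, jnp.array([1.0, -1.0]), 1, 0.5, fx, q=0.0, r=1.0)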
/bandits/agents/ekf_orig_full.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax import jit
3 | from jax.flatten_util import ravel_pytree
4 |
5 | import optax
6 |
7 | from flax.training import train_state
8 |
9 | from .agent_utils import train
10 | from jsl.nlds.extended_kalman_filter import ExtendedKalmanFilter
11 | from scripts.training_utils import MLP
12 | from tensorflow_probability.substrates import jax as tfp
13 |
14 | tfd = tfp.distributions
15 |
16 |
17 | class EKFNeuralBandit:
18 | def __init__(self, num_features, num_arms, model, opt, prior_noise_variance, nwarmup=1000, nepochs=1000,
19 | system_noise=0.0, observation_noise=1.0):
20 | """
21 | Subspace Neural Bandit implementation.
22 | Parameters
23 | ----------
24 | num_arms: int
25 | Number of bandit arms / number of actions
26 | environment : Environment
27 | The environment to be used.
28 | model : flax.nn.Module
29 | The flax model to be used for the bandits. Note that this model is independent of the
30 | model architecture. The only constraint is that the last layer should have the same
31 | number of outputs as the number of arms.
32 | learning_rate : float
33 | The learning rate for the optimizer used for the warmup phase.
34 | momentum : float
35 | The momentum for the optimizer used for the warmup phase.
36 | nepochs : int
37 | The number of epochs to be used for the warmup SGD phase.
38 | """
39 | self.num_features = num_features
40 | self.num_arms = num_arms
41 |
42 | if model is None:
43 | self.model = MLP(500, num_arms)
44 | else:
45 | try:
46 | self.model = model()
47 |             except Exception:  # model may already be an instance rather than a constructor
48 | self.model = model
49 |
50 | self.opt = opt
51 | self.prior_noise_variance = prior_noise_variance
52 | self.nwarmup = nwarmup
53 | self.nepochs = nepochs
54 | self.system_noise = system_noise
55 | self.observation_noise = observation_noise
56 |
57 | def init_bel(self, key, contexts, states, actions, rewards):
58 | initial_params = self.model.init(key, jnp.ones((1, self.num_features)))["params"]
59 | initial_train_state = train_state.TrainState.create(apply_fn=self.model.apply, params=initial_params,
60 | tx=self.opt)
61 |
62 | def loss_fn(params):
63 | pred_reward = self.model.apply({"params": params}, contexts)[:, actions.astype(int)]
64 | loss = optax.l2_loss(pred_reward, states[:, actions.astype(int)]).mean()
65 | return loss, pred_reward
66 |
67 | warmup_state, _ = train(initial_train_state, loss_fn=loss_fn, nepochs=self.nepochs)
68 |
69 | params_full_init, reconstruct_tree_params = ravel_pytree(warmup_state.params)
70 | nparams = params_full_init.size
71 |
72 | Q = jnp.eye(nparams) * self.system_noise
73 | R = jnp.eye(1) * self.observation_noise
74 |
75 | params_subspace_init = jnp.zeros(nparams)
76 | covariance_subspace_init = jnp.eye(nparams) * self.prior_noise_variance
77 |
78 | def predict_rewards(params, context):
79 | params_tree = reconstruct_tree_params(params)
80 | outputs = self.model.apply({"params": params_tree}, context)
81 | return outputs
82 |
83 | self.predict_rewards = predict_rewards
84 |
85 | def fz(params):
86 | return params
87 |
88 | def fx(params, context, action):
89 | return predict_rewards(params, context)[action, None]
90 |
91 | ekf = ExtendedKalmanFilter(fz, fx, Q, R)
92 | self.ekf = ekf
93 | bel = (params_subspace_init, covariance_subspace_init, 0)
94 |
95 | return bel
96 |
97 | def sample_params(self, key, bel):
98 | params_subspace, covariance_subspace, t = bel
99 | mv_normal = tfd.MultivariateNormalFullCovariance(loc=params_subspace, covariance_matrix=covariance_subspace)
100 | params_subspace = mv_normal.sample(seed=key)
101 | return params_subspace
102 |
103 | def update_bel(self, bel, context, action, reward):
104 | xs = (reward, (context, action))
105 | bel, _ = jit(self.ekf.filter_step)(bel, xs)
106 | return bel
107 |
108 | def choose_action(self, key, bel, context):
109 | # Thompson sampling strategy
110 | # Could also use epsilon greedy or UCB
111 | w = self.sample_params(key, bel)
112 | predicted_reward = self.predict_rewards(w, context)
113 | action = predicted_reward.argmax()
114 | return action
115 |
--------------------------------------------------------------------------------
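The `choose_action` comment above notes that epsilon-greedy or UCB could be used in place of Thompson sampling. Below is a hedged sketch of an epsilon-greedy variant against the same `(mean, covariance, t)` belief and a `predict_rewards` callable; the function and demo are illustrative, not part of the repository.

import jax
import jax.numpy as jnp

def choose_action_eps_greedy(key, bel, context, predict_rewards, num_arms, epsilon=0.05):
    # Exploit the posterior mean with probability 1 - epsilon, otherwise pick a uniformly random arm.
    params_mean, _, _ = bel
    explore_key, arm_key = jax.random.split(key)
    greedy_action = predict_rewards(params_mean, context).argmax()
    random_action = jax.random.randint(arm_key, (), 0, num_arms)
    explore = jax.random.uniform(explore_key) < epsilon
    return jnp.where(explore, random_action, greedy_action)

# Tiny demo with a linear reward predictor and a dummy belief (mean, covariance, t).
predict_rewards = lambda params, context: params.reshape(4, 3) @ context
bel = (jnp.arange(12.0), jnp.eye(12), 0)
action = choose_action_eps_greedy(jax.random.PRNGKey(0), bel, jnp.ones(3), predict_rewards, num_arms=4)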
/bandits/agents/neural_linear.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import optax
3 | import jax.numpy as jnp
4 | from flax.training.train_state import TrainState
5 |
6 | from .agent_utils import train
7 | from tensorflow_probability.substrates import jax as tfp
8 |
9 | tfd = tfp.distributions
10 |
11 |
12 | class NeuralLinearBandit:
13 | def __init__(self, num_features, num_arms, model=None, opt=optax.adam(learning_rate=1e-2), eta=6.0, lmbda=0.25,
14 | update_step_mod=100, batch_size=5000, nepochs=3000):
15 | self.num_features = num_features
16 | self.num_arms = num_arms
17 |
18 | self.opt = opt
19 | self.eta = eta
20 | self.lmbda = lmbda
21 | self.update_step_mod = update_step_mod
22 | self.batch_size = batch_size
23 | self.nepochs = nepochs
24 | self.model = model
25 |
26 | def init_bel(self, key, contexts, states, actions, rewards):
27 | key, mykey = jax.random.split(key)
28 | xdummy = jnp.zeros((self.num_features))
29 | initial_params = self.model.init(mykey, xdummy)
30 | initial_train_state = TrainState.create(apply_fn=self.model.apply, params=initial_params,
31 | tx=self.opt)
32 |
33 |         n_hidden_last = self.featurize(initial_params, xdummy).shape[0]  # width of the "last_layer" features
34 | mu = jnp.zeros((self.num_arms, n_hidden_last))
35 | Sigma = 1 / self.lmbda * jnp.eye(n_hidden_last) * jnp.ones((self.num_arms, 1, 1))
36 | a = self.eta * jnp.ones((self.num_arms,))
37 | b = self.eta * jnp.ones((self.num_arms,))
38 | t = 0
39 |
40 | def update(bel, x):
41 | context, action, reward = x
42 | return self.update_bel(bel, context, action, reward), None
43 |
44 | self.contexts = contexts
45 | self.states = states
46 |
47 | initial_bel = (mu, Sigma, a, b, initial_train_state, t)
48 | bel, _ = jax.lax.scan(update, initial_bel, (contexts, actions, rewards))
49 | return bel
50 |
51 | def featurize(self, params, x, feature_layer="last_layer"):
52 | _, inter = self.model.apply(params, x, capture_intermediates=True)
53 | Phi, *_ = inter["intermediates"][feature_layer]["__call__"]
54 | return Phi.squeeze()
55 |
56 |
57 | def cond_update_params(self, t):
58 | return (t % self.update_step_mod) == 0
59 |
60 | def update_bel(self, bel, context, action, reward):
61 | mu, Sigma, a, b, state, t = bel
62 |
63 | sgd_params = (state, t)
64 |
65 | def loss_fn(params):
66 | n_samples, *_ = self.contexts.shape
67 | final_t = jax.lax.cond(t == 0, lambda t: n_samples, lambda t: t.astype(int), t)
68 | sample_range = (jnp.arange(n_samples) <= t)[:, None]
69 |
70 | pred_reward = self.model.apply(params, self.contexts)
71 | loss = (optax.l2_loss(pred_reward, self.states) * sample_range).sum() / final_t
72 | return loss, pred_reward
73 |
74 | state = jax.lax.cond(self.cond_update_params(t),
75 | lambda sgd_params: train(sgd_params[0], loss_fn=loss_fn, nepochs=self.nepochs)[0],
76 | lambda sgd_params: sgd_params[0], sgd_params)
77 |
78 | transformed_context = self.featurize(state.params, context)
79 |
80 | mu_k, Sigma_k = mu[action], Sigma[action]
81 | Lambda_k = jnp.linalg.inv(Sigma_k)
82 | a_k, b_k = a[action], b[action]
83 |
84 | # weight params
85 | Lambda_update = jnp.outer(transformed_context, transformed_context) + Lambda_k
86 | Sigma_update = jnp.linalg.inv(Lambda_update)
87 | mu_update = Sigma_update @ (Lambda_k @ mu_k + transformed_context * reward)
88 |
89 | # noise params
90 | a_update = a_k + 1 / 2
91 | b_update = b_k + (reward ** 2 + mu_k.T @ Lambda_k @ mu_k - mu_update.T @ Lambda_update @ mu_update) / 2
92 |
93 | # update only the chosen action at time t
94 | mu = mu.at[action].set(mu_update)
95 | Sigma = Sigma.at[action].set(Sigma_update)
96 | a = a.at[action].set(a_update)
97 | b = b.at[action].set(b_update)
98 | t = t + 1
99 |
100 | bel = (mu, Sigma, a, b, state, t)
101 | return bel
102 |
103 | def sample_params(self, key, bel):
104 | mu, Sigma, a, b, _, _ = bel
105 | sigma_key, w_key = jax.random.split(key)
106 | sigma2 = tfd.InverseGamma(concentration=a, scale=b).sample(seed=sigma_key)
107 | covariance_matrix = sigma2[:, None, None] * Sigma
108 | w = tfd.MultivariateNormalFullCovariance(loc=mu, covariance_matrix=covariance_matrix).sample(seed=w_key)
109 | return w
110 |
111 | def choose_action(self, key, bel, context):
112 | # Thompson sampling strategy
113 | # Could also use epsilon greedy or UCB
114 | state = bel[-2]
115 | context_transformed = self.featurize(state.params, context)
116 | w = self.sample_params(key, bel)
117 | predicted_reward = jnp.einsum("m,km->k", context_transformed, w)
118 | action = predicted_reward.argmax()
119 | return action
120 |
--------------------------------------------------------------------------------
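For reference, the per-arm update in `update_bel` above is the standard Bayesian linear-regression (Normal-Inverse-Gamma) conjugate update; writing $\boldsymbol\phi_t$ for the featurised context and $k$ for the chosen arm, it reads

$$
\begin{aligned}
\boldsymbol\Lambda_{k,t} &= \boldsymbol\Lambda_{k,t-1} + \boldsymbol\phi_t \boldsymbol\phi_t^\top,
&\boldsymbol\mu_{k,t} &= \boldsymbol\Lambda_{k,t}^{-1}\big(\boldsymbol\Lambda_{k,t-1}\boldsymbol\mu_{k,t-1} + \boldsymbol\phi_t r_t\big), \\
a_{k,t} &= a_{k,t-1} + \tfrac{1}{2},
&b_{k,t} &= b_{k,t-1} + \tfrac{1}{2}\big(r_t^2 + \boldsymbol\mu_{k,t-1}^\top\boldsymbol\Lambda_{k,t-1}\boldsymbol\mu_{k,t-1} - \boldsymbol\mu_{k,t}^\top\boldsymbol\Lambda_{k,t}\boldsymbol\mu_{k,t}\big),
\end{aligned}
$$

with $\boldsymbol\Sigma_{k,t} = \boldsymbol\Lambda_{k,t}^{-1}$, matching the code line by line.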
/bandits/scripts/plot_results.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import pandas as pd
3 | import seaborn as sns
4 | import matplotlib.pyplot as plt
5 |
6 | def plot_figure(data, x, y, filename, figsize=(24, 9), log_scale=False):
7 | sns.set(font_scale=1.5)
8 | plt.style.use("seaborn-poster")
9 |
10 | fig, ax = plt.subplots(figsize=figsize, dpi=300)
11 | g = sns.barplot(x=x, y=y, hue="Method", data=data, errwidth=2, ax=ax, palette=colors)
12 | if log_scale:
13 | g.set_yscale("log")
14 | plt.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0))
15 | plt.tight_layout()
16 | plt.savefig(f"./bandits/figures/{filename}.png")
17 | plt.show()
18 |
19 | def read_data(dataset_name):
20 | *_, filename = sorted(glob.glob(f"./bandits/results/{dataset_name}_results*.csv"))
21 | df = pd.read_csv(filename)
22 | if dataset_name=="mnist":
23 | linear_df = df[(df["Method"]=="Lin-KF") | (df["Method"]=="Lin")].copy()
24 | linear_df["Model"] = "MLP2"
25 |         df = pd.concat([df, linear_df])  # DataFrame.append is deprecated in recent pandas
26 | linear_df["Model"] = "LeNet5"
27 |         df = pd.concat([df, linear_df])
28 |
29 | by = ["Rank"] if dataset_name=="tabular" else ["Rank", "AltRank"]
30 |
31 | data_up = df.sort_values(by=by).copy()
32 | data_down = df.sort_values(by=by).copy()
33 |
34 | data_up["Reward"] = data_up["Reward"] + data_up["Std"]
35 | data_down["Reward"] = data_down["Reward"] - data_down["Std"]
36 | data = pd.concat([data_up, data_down])
37 | return data
38 |
39 | def plot_subspace_figure(df, filename=None):
40 | df = df.reset_index().drop(columns=["index"])
41 | plt.style.use("seaborn-darkgrid")
42 | fig, ax = plt.subplots(figsize=(12, 8))
43 | sns.lineplot(x="Subspace Dim", y="Reward", hue="Method", marker="o", data=df)
44 | lines, labels = ax.get_legend_handles_labels()
45 | for line, method in zip(lines, labels):
46 | data = df[df["Method"]==method]
47 | color = line.get_c()
48 | y_lower_bound = data["Reward"] - data["Std"]
49 | y_upper_bound = data["Reward"] + data["Std"]
50 | ax.fill_between(data["Subspace Dim"], y_lower_bound, y_upper_bound, color=color, alpha=0.3)
51 |
52 | ax.set_ylabel("Reward", fontsize=16)
53 | plt.setp(ax.get_xticklabels(), fontsize=16)
54 | plt.setp(ax.get_yticklabels(), fontsize=16)
55 | ax.set_xlabel("Subspace Dimension(d)", fontsize=16)
56 | dataset = df.iloc[0]["Dataset"]
57 | ax.set_title(f"{dataset.title()} - Subspace Dim vs. Reward", fontsize=18)
58 | legend = ax.legend(loc="lower right", prop={'size': 16},frameon=1)
59 | frame = legend.get_frame()
60 | frame.set_color('white')
61 | frame.set_alpha(0.6)
62 |
63 | file_path = "./bandits/figures/"
64 | file_path = file_path + f"{dataset}_sub_reward.png" if filename is None else file_path + f"{filename}.png"
65 | plt.savefig(file_path)
66 |
67 | method_ordering = {"EKF-Sub-SVD": 0,
68 | "EKF-Sub-RND": 1,
69 | "EKF-Sub-Diag-SVD": 2,
70 | "EKF-Sub-Diag-RND": 3,
71 | "EKF-Orig-Full": 4,
72 | "EKF-Orig-Diag": 5,
73 | "NL-Lim": 6,
74 | "NL-Unlim": 7,
75 | "Lin": 8,
76 | "Lin-KF": 9,
77 | "Lin-Wide": 9,
78 | "Lim2": 10,
79 | "NeuralTS": 11}
80 |
81 | colors = {k : sns.color_palette("Paired")[v]
82 | if k!="Lin-KF" else sns.color_palette("tab20")[8]
83 | for k,v in method_ordering.items()}
84 |
85 | dataset_info = {
86 | "mnist": {
87 | "elements": ["EKF-Sub-SVD", "EKF-Sub-RND", "EKF-Sub-Diag-SVD",
88 | "EKF-Sub-Diag-RND", "EKF-Orig-Diag", "NL-Lim",
89 | "NL-Unlim", "Lin"],
90 | "x": "Model",
91 | },
92 |
93 | "tabular": {
94 | "elements": ["EKF-Sub-SVD", "EKF-Sub-RND", "EKF-Sub-Diag-SVD",
95 | "EKF-Sub-Diag-RND", "EKF-Orig-Diag", "NL-Lim",
96 | "NL-Unlim", "Lin"],
97 | "x": "Dataset"
98 | },
99 |
100 | "movielens": {
101 | "elements": ["EKF-Sub-SVD", "EKF-Sub-RND", "EKF-Sub-Diag-SVD",
102 | "EKF-Sub-Diag-RND", "EKF-Orig-Diag", "NL-Lim",
103 | "NL-Unlim", "Lin"],
104 | "x": "Model"
105 | },
106 | }
107 |
108 |
109 | plot_configs = [
110 | {"metric": "Reward", "log_scale":False},
111 | {"metric": "Time", "log_scale":True},
112 |
113 | ]
114 |
115 |
116 | def main():
117 |     # Plot reward / running time figures
118 | print("Plotting reward / running time")
119 | for dataset_name in dataset_info:
120 | print(dataset_name)
121 | info = dataset_info[dataset_name]
122 | methods = info["elements"]
123 | x = info["x"]
124 |
125 | df = read_data(dataset_name)
126 | df = df[df["Method"].isin(methods)]
127 |
128 | for config in plot_configs:
129 | metric = config["metric"]
130 | use_log_scale = config["log_scale"]
131 |
132 | filename = f"{dataset_name}_{metric.lower()}"
133 | plot_figure(df, x, metric, filename, log_scale=use_log_scale)
134 |
135 |
136 | # Plot subspace-dim v.s. reward
137 | print("Plotting subspace dim v.s. reward")
138 | *_, filename = sorted(glob.glob(f"./bandits/results/tabular_subspace_results*.csv"))
139 | tabular_sub_df = pd.read_csv(filename)
140 |
141 | datasets = ["shuttle", "adult", "covertype"]
142 | for dataset_name in datasets:
143 | print(dataset_name)
144 | subdf = tabular_sub_df.query("Dataset == @dataset_name")
145 | plot_subspace_figure(subdf)
146 |
147 |
148 | if __name__ == "__main__":
149 | main()
--------------------------------------------------------------------------------
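A brief usage sketch of the helpers above, assuming the repository root is on the Python path and that the experiment scripts have already written their CSVs under ./bandits/results/:

# Hypothetical usage: plot only the MNIST reward figure.
from bandits.scripts.plot_results import read_data, plot_figure, dataset_info

df = read_data("mnist")
df = df[df["Method"].isin(dataset_info["mnist"]["elements"])]
plot_figure(df, x="Model", y="Reward", filename="mnist_reward")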
/demos/lofi_tabular.py:
--------------------------------------------------------------------------------
1 | """
2 | In this demo, we evaluate the performance of the
3 | lofi bandit on a tabular dataset
4 | """
5 | import jax
6 | import optax
7 | import pickle
8 | import numpy as np
9 | import flax.linen as nn
10 | import jax.numpy as jnp
11 | from time import time
12 | from datetime import datetime
13 | from bayes_opt import BayesianOptimization
14 | from bandits import training as btrain
15 | from bandits.agents.low_rank_filter_bandit import LowRankFilterBandit
16 | from bandits.environments.tabular_env import TabularEnvironment
17 |
18 | class MLP(nn.Module):
19 | num_arms: int
20 |
21 | @nn.compact
22 | def __call__(self, x):
23 | # x = nn.Dense(50)(x)
24 | # x = nn.relu(x)
25 | x = nn.Dense(50, name="last_layer")(x)
26 | x = nn.relu(x)
27 | x = nn.Dense(self.num_arms)(x)
28 | return x
29 |
30 |
31 | def warmup_and_run(eval_hparams, transform_fn, bandit_cls, env, key, npulls, n_trials=1, **kwargs):
32 | n_devices = jax.local_device_count()
33 | key_warmup, key_train = jax.random.split(key, 2)
34 | hparams = transform_fn(eval_hparams)
35 | hparams = {**hparams, **kwargs}
36 |
37 | bandit = bandit_cls(env.n_features, env.n_arms, **hparams)
38 |
39 | bel, hist_warmup = btrain.warmup_bandit(key_warmup, bandit, env, npulls)
40 | time_init = time()
41 | if n_trials == 1:
42 | bel, hist_train = btrain.run_bandit(key_train, bel, bandit, env, t_start=npulls)
43 | elif 1 < n_trials <= n_devices:
44 | bel, hist_train = btrain.run_bandit_trials_pmap(key_train, bel, bandit, env, t_start=npulls, n_trials=n_trials)
45 | elif n_trials > n_devices:
46 | bel, hist_train = btrain.run_bandit_trials_multiple(key_train, bel, bandit, env, t_start=npulls, n_trials=n_trials)
47 | time_end = time()
48 | total_time = time_end - time_init
49 |
50 |
51 | res = {
52 | "hist_warmup": hist_warmup,
53 | "hist_train": hist_train,
54 | }
55 | # res = jax.tree_map(np.array, res)
56 | res["total_time"] = total_time
57 |
58 | return res
59 |
60 |
61 | def transform_hparams_subspace_fixed(hparams):
62 | emission_covariance = jnp.exp(hparams["log_em_cov"])
63 | initial_covariance = jnp.exp(hparams["log_init_cov"])
64 | warmup_learning_rate = jnp.exp(hparams["log_warmup_lr"])
65 |
66 | hparams = {
67 | "observation_noise": emission_covariance,
68 | "prior_noise_variance": initial_covariance,
69 | "opt": warmup_learning_rate,
70 | }
71 |
72 | return hparams
73 |
74 |
75 | def transform_hparams_lofi(hparams):
76 | emission_covariance = jnp.exp(hparams["log_em_cov"])
77 | initial_covariance = jnp.exp(hparams["log_init_cov"])
78 | dynamics_weights = 1 - jnp.exp(hparams["log_1m_dweights"])
79 | dynamics_covariance = jnp.exp(hparams["log_dcov"])
80 |
81 | hparams = {
82 | "emission_covariance": emission_covariance,
83 | "initial_covariance": initial_covariance,
84 | "dynamics_weights": dynamics_weights,
85 | "dynamics_covariance": dynamics_covariance,
86 | }
87 | return hparams
88 |
89 |
90 | def transform_hparams_lofi_fixed(hparams):
91 | """
92 |     Transformation assuming that the dynamics weights
93 |     and dynamics covariance are held fixed
94 | """
95 | emission_covariance = jnp.exp(hparams["log_em_cov"])
96 | initial_covariance = jnp.exp(hparams["log_init_cov"])
97 |
98 | hparams = {
99 | "emission_covariance": emission_covariance,
100 | "initial_covariance": initial_covariance,
101 | }
102 | return hparams
103 |
104 |
105 | def transform_hparams_linear(hparams):
106 | eta = hparams["eta"]
107 | lmbda = jnp.exp(hparams["log_lambda"])
108 | hparams = {
109 | "eta": eta,
110 | "lmbda": lmbda,
111 | }
112 | return hparams
113 |
114 |
115 | def transform_hparams_neural_linear(hparams):
116 | lr = jnp.exp(hparams["log_lr"])
117 | eta = hparams["eta"]
118 | lmbda = jnp.exp(hparams["log_lambda"])
119 | opt = optax.adam(lr)
120 |
121 | hparams = {
122 | "lmbda": lmbda,
123 | "eta": eta,
124 | "opt": opt,
125 | }
126 | return hparams
127 |
128 | def transform_hparams_rsgd(hparams):
129 | lr = jnp.exp(hparams["log_lr"])
130 | tx = optax.adam(lr)
131 | hparams = {
132 | "tx": tx,
133 | }
134 | return hparams
135 |
136 | if __name__ == "__main__":
137 | ntrials = 10
138 | npulls = 20
139 | key = jax.random.PRNGKey(314)
140 | key_env, key_warmup, key_train = jax.random.split(key, 3)
141 | ntrain = 500 # 5000
142 | env = TabularEnvironment(key_env, ntrain=ntrain, name='statlog', intercept=False, path="./bandit-data")
143 | num_arms = env.labels_onehot.shape[-1]
144 | model = MLP(num_arms)
145 |
146 | kwargs_lofi = {
147 | "emission_covariance": 0.01,
148 | "initial_covariance": 1.0,
149 | "dynamics_weights": 1.0,
150 | "dynamics_covariance": 0.0,
151 | "memory_size": 10,
152 | "model": model,
153 | }
154 |
155 | n_features = env.n_features
156 | n_arms = env.n_arms
157 | bandit = LowRankFilterBandit(n_features, n_arms, **kwargs_lofi)
158 |
159 | bel, hist_warmup = btrain.warmup_bandit(key_warmup, bandit, env, npulls)
160 | bel, hist_train = btrain.run_bandit_trials(key_train, bel, bandit, env, t_start=npulls, n_trials=ntrials)
161 |
162 | res = {
163 | "hist_warmup": hist_warmup,
164 | "hist_train": hist_train,
165 | }
166 | res = jax.tree_map(np.array, res)
167 |
168 | # Store results
169 | datestr = datetime.now().strftime("%Y%m%d%H%M%S")
170 | path_to_results = f"./results/lofi_{datestr}.pkl"
171 | with open(path_to_results, "wb") as f:
172 | pickle.dump(res, f)
173 | print(f"Results stored in {path_to_results}")
174 |
--------------------------------------------------------------------------------
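The script imports `BayesianOptimization` but the tuning loop itself is not shown here; a hedged sketch of how the log-space transforms above could be wired into it follows. The `fixed_kwargs` dict, the bounds, and the reward field of `hist_train` are illustrative assumptions.

from bayes_opt import BayesianOptimization

fixed_kwargs = {"dynamics_weights": 1.0, "dynamics_covariance": 0.0,
                "memory_size": 10, "model": model}

def target(log_em_cov, log_init_cov):
    hparams = {"log_em_cov": log_em_cov, "log_init_cov": log_init_cov}
    res = warmup_and_run(hparams, transform_hparams_lofi_fixed, LowRankFilterBandit,
                         env, key, npulls, n_trials=1, **fixed_kwargs)
    return float(np.sum(res["hist_train"]["rewards"]))  # assumed field name

bounds = {"log_em_cov": (-6.0, 0.0), "log_init_cov": (-3.0, 3.0)}
optimizer = BayesianOptimization(f=target, pbounds=bounds, random_state=314)
optimizer.maximize(init_points=3, n_iter=10)
print(optimizer.max)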
/bandits/agents/ekf_subspace.py:
--------------------------------------------------------------------------------
1 | import optax
2 | import jax.numpy as jnp
3 | from jax import jit, device_put
4 | from jax.random import split
5 | from jax.flatten_util import ravel_pytree
6 | from flax.training import train_state
7 | from sklearn.decomposition import PCA
8 | from .agent_utils import train, generate_random_basis, convert_params_from_subspace_to_full
9 | from scripts.training_utils import MLP
10 | from jsl.nlds.extended_kalman_filter import ExtendedKalmanFilter
11 | from tensorflow_probability.substrates import jax as tfp
12 |
13 | tfd = tfp.distributions
14 |
15 |
16 | class SubspaceNeuralBandit:
17 | def __init__(self, num_features, num_arms, model, opt, prior_noise_variance, nwarmup=1000, nepochs=1000,
18 | system_noise=0.0, observation_noise=1.0, n_components=0.9999, random_projection=False):
19 | """
20 | Subspace Neural Bandit implementation.
21 | Parameters
22 | ----------
23 | num_arms: int
24 | Number of bandit arms / number of actions
25 | environment : Environment
26 | The environment to be used.
27 | model : flax.nn.Module
28 | The flax model to be used for the bandits. Note that this model is independent of the
29 | model architecture. The only constraint is that the last layer should have the same
30 | number of outputs as the number of arms.
31 | opt: flax.optim.Optimizer
32 | The optimizer to be used for training the model.
33 | learning_rate : float
34 | The learning rate for the optimizer used for the warmup phase.
35 | momentum : float
36 | The momentum for the optimizer used for the warmup phase.
37 | nepochs : int
38 | The number of epochs to be used for the warmup SGD phase.
39 | """
40 | self.num_features = num_features
41 | self.num_arms = num_arms
42 |
43 | # TODO: deprecate hard-coded MLP
44 | if model is None:
45 | self.model = MLP(500, num_arms)
46 | else:
47 | try:
48 | self.model = model()
49 |             except Exception:  # model may already be an instance rather than a constructor
50 | self.model = model
51 |
52 | self.opt = opt
53 | self.prior_noise_variance = prior_noise_variance
54 | self.nwarmup = nwarmup
55 | self.nepochs = nepochs
56 | self.system_noise = system_noise
57 | self.observation_noise = observation_noise
58 | self.n_components = n_components
59 | self.random_projection = random_projection
60 |
61 | def init_bel(self, key, contexts, states, actions, rewards):
62 | warmup_key, projection_key = split(key, 2)
63 | initial_params = self.model.init(warmup_key, jnp.ones((1, self.num_features)))["params"]
64 | initial_train_state = train_state.TrainState.create(apply_fn=self.model.apply, params=initial_params,
65 | tx=self.opt)
66 |
67 | def loss_fn(params):
68 | pred_reward = self.model.apply({"params": params}, contexts)[:, actions.astype(int)]
69 | loss = optax.l2_loss(pred_reward, states[:, actions.astype(int)]).mean()
70 | return loss, pred_reward
71 |
72 | warmup_state, warmup_metrics = train(initial_train_state, loss_fn=loss_fn, nepochs=self.nepochs)
73 |
74 | thinned_samples = warmup_metrics["params"][::2]
75 | params_trace = thinned_samples[-self.nwarmup:]
76 |
77 | if not self.random_projection:
78 | pca = PCA(n_components=self.n_components)
79 | pca.fit(params_trace)
80 | subspace_dim = pca.n_components_
81 | self.n_components = pca.n_components_
82 | projection_matrix = device_put(pca.components_)
83 | else:
84 | if type(self.n_components) != int:
85 | raise ValueError(f"n_components must be an integer, got {self.n_components}")
86 | total_dim = params_trace.shape[-1]
87 | subspace_dim = self.n_components
88 | projection_matrix = generate_random_basis(projection_key, subspace_dim, total_dim)
89 |
90 | Q = jnp.eye(subspace_dim) * self.system_noise
91 | R = jnp.eye(1) * self.observation_noise
92 |
93 | params_full_init, reconstruct_tree_params = ravel_pytree(warmup_state.params)
94 | params_subspace_init = jnp.zeros(subspace_dim)
95 | covariance_subspace_init = jnp.eye(subspace_dim) * self.prior_noise_variance
96 |
97 | def predict_rewards(params_subspace_sample, context):
98 | params = convert_params_from_subspace_to_full(params_subspace_sample, projection_matrix, params_full_init)
99 | params = reconstruct_tree_params(params)
100 | outputs = self.model.apply({"params": params}, context)
101 | return outputs
102 |
103 | self.predict_rewards = predict_rewards
104 |
105 | def fz(params):
106 | return params
107 |
108 | def fx(params, context, action):
109 | return predict_rewards(params, context)[action, None]
110 |
111 | ekf = ExtendedKalmanFilter(fz, fx, Q, R)
112 | self.ekf = ekf
113 |
114 | bel = (params_subspace_init, covariance_subspace_init, 0)
115 | return bel
116 |
117 | def sample_params(self, key, bel):
118 | params_subspace, covariance_subspace, t = bel
119 | mv_normal = tfd.MultivariateNormalFullCovariance(loc=params_subspace, covariance_matrix=covariance_subspace)
120 | params_subspace = mv_normal.sample(seed=key)
121 | return params_subspace
122 |
123 | def update_bel(self, bel, context, action, reward):
124 | xs = (reward, (context, action))
125 | bel, _ = jit(self.ekf.filter_step)(bel, xs)
126 | return bel
127 |
128 | def choose_action(self, key, bel, context):
129 | # Thompson sampling strategy
130 | # Could also use epsilon greedy or UCB
131 | w = self.sample_params(key, bel)
132 | predicted_reward = self.predict_rewards(w, context)
133 | action = predicted_reward.argmax()
134 | return action
135 |
--------------------------------------------------------------------------------
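All agents in bandits/agents, including the subspace bandit above, expose the same three-method interface (`init_bel`, `choose_action`, `update_bel`). A minimal hand-rolled interaction loop under that protocol is sketched below; the environment accessors are hypothetical stand-ins, and the repository's own loop lives in bandits/training.py.

import jax

def run_protocol(key, bandit, env, contexts, states, actions, rewards, nsteps):
    # Warm-start the belief on logged data, then act online with Thompson sampling.
    key, init_key = jax.random.split(key)
    bel = bandit.init_bel(init_key, contexts, states, actions, rewards)
    total_reward = 0.0
    for t in range(nsteps):
        key, act_key = jax.random.split(key)
        context = env.get_context(t)          # hypothetical accessor; see bandits/environments
        action = bandit.choose_action(act_key, bel, context)
        reward = env.get_reward(t, action)    # hypothetical accessor
        bel = bandit.update_bel(bel, context, action, reward)
        total_reward += float(reward)
    return bel, total_reward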
/bandits/scripts/movielens_exp.py:
--------------------------------------------------------------------------------
1 | from jax.random import PRNGKey
2 |
3 | import optax
4 | import pandas as pd
5 |
6 | import argparse
7 | from time import time
8 |
9 | from environments.movielens_env import MovielensEnvironment
10 |
11 | from agents.linear_bandit import LinearBandit
12 | from agents.linear_kf_bandit import LinearKFBandit
13 | from agents.ekf_subspace import SubspaceNeuralBandit
14 | from agents.ekf_orig_diag import DiagonalNeuralBandit
15 | from agents.diagonal_subspace import DiagonalSubspaceNeuralBandit
16 | from agents.limited_memory_neural_linear import LimitedMemoryNeuralLinearBandit
17 |
18 | from .training_utils import train, MLP, MLPWide
19 | from .mnist_exp import mapping, rank, summarize_results, method_ordering
20 |
21 |
22 | def main(config):
23 | eta = 6.0
24 | lmbda = 0.25
25 |
26 | learning_rate = 0.01
27 | momentum = 0.9
28 |
29 | update_step_mod = 100
30 | buffer_size = 50
31 | nepochs = 100
32 |
33 | # Neural Linear Limited
34 | nl_lim = {"buffer_size": buffer_size, "opt": optax.sgd(learning_rate, momentum), "eta": eta, "lmbda": lmbda,
35 | "update_step_mod": update_step_mod, "nepochs": nepochs}
36 |
37 | # Neural Linear Unlimited
38 | buffer_size = 4800
39 |
40 | nl_unlim = nl_lim.copy()
41 | nl_unlim["buffer_size"] = buffer_size
42 |
43 | npulls, nwarmup = 2, 2000
44 | learning_rate, momentum = 0.8, 0.9
45 | observation_noise = 0.0
46 | prior_noise_variance = 1e-4
47 | n_components = 470
48 | nepochs = 1000
49 | random_projection = False
50 |
51 | # Subspace Neural with SVD
52 | ekf_sub_svd = {"opt": optax.sgd(learning_rate, momentum), "prior_noise_variance": prior_noise_variance,
53 | "nwarmup": nwarmup, "nepochs": nepochs,
54 | "observation_noise": observation_noise, "n_components": n_components,
55 | "random_projection": random_projection}
56 |
57 | # Subspace Neural without SVD
58 | ekf_sub_rnd = ekf_sub_svd.copy()
59 | ekf_sub_rnd["random_projection"] = True
60 |
61 | system_noise = 0.0
62 |
63 | ekf_orig = {"opt": optax.sgd(learning_rate, momentum), "prior_noise_variance": prior_noise_variance,
64 | "nwarmup": nwarmup, "nepochs": nepochs,
65 | "system_noise": system_noise, "observation_noise": observation_noise}
66 | linear = {}
67 |
68 | bandits = {"Linear": {"kwargs": linear,
69 | "bandit": LinearBandit
70 | },
71 | "Linear KF": {"kwargs": linear.copy(),
72 | "bandit": LinearKFBandit
73 | },
74 | "Limited Neural Linear": {"kwargs": nl_lim,
75 | "bandit": LimitedMemoryNeuralLinearBandit
76 | },
77 | "Unlimited Neural Linear": {"kwargs": nl_unlim,
78 | "bandit": LimitedMemoryNeuralLinearBandit
79 | },
80 | "EKF Subspace SVD": {"kwargs": ekf_sub_svd,
81 | "bandit": SubspaceNeuralBandit
82 | },
83 | "EKF Subspace RND": {"kwargs": ekf_sub_rnd,
84 | "bandit": SubspaceNeuralBandit
85 | },
86 | "EKF Diagonal Subspace SVD": {"kwargs": ekf_sub_svd.copy(),
87 | "bandit": DiagonalSubspaceNeuralBandit
88 | },
89 | "EKF Diagonal Subspace RND": {"kwargs": ekf_sub_rnd.copy(),
90 | "bandit": DiagonalSubspaceNeuralBandit
91 | },
92 | "EKF Orig Diagonal": {"kwargs": ekf_orig,
93 | "bandit": DiagonalNeuralBandit
94 | }
95 | }
96 |
97 | results = []
98 | repeats = [1]
99 | for repeat in repeats:
100 | key = PRNGKey(0)
101 | # Create the environment beforehand
102 | movielens = MovielensEnvironment(key, repeat=repeat)
103 |         # Number of arms in the MovieLens environment
104 | num_arms = movielens.labels_onehot.shape[-1]
105 | models = {"MLP1": MLP(num_arms), "MLP2": MLPWide(num_arms)}
106 |
107 | for model_name, model in models.items():
108 | for bandit_name, properties in bandits.items():
109 | if not bandit_name.startswith("Linear"):
110 | properties["kwargs"]["model"] = model
111 | print(f"Bandit : {bandit_name}")
112 | key = PRNGKey(314)
113 | start = time()
114 | warmup_rewards, rewards_trace, opt_rewards = train(key, properties["bandit"], movielens, npulls,
115 | config.ntrials, properties["kwargs"], neural=False)
116 | rtotal, rstd = summarize_results(warmup_rewards, rewards_trace, spacing="\t")
117 | end = time()
118 | print(f"\tTime : {end - start}:0.3f")
119 | results.append((bandit_name, model_name, end - start, rtotal, rstd))
120 |
121 | df = pd.DataFrame(results)
122 | df = df.rename(columns={0: "Method", 1: "Model", 2: "Time", 3: "Reward", 4: "Std"})
123 | df["Method"] = df["Method"].apply(lambda v: mapping[v])
124 | df["Rank"] = df["Method"].apply(lambda v: method_ordering[v])
125 | df["AltRank"] = df["Model"].apply(lambda v: rank[v])
126 |
127 | df["Reward"] = df['Reward'].astype(float)
128 | df["Time"] = df['Time'].astype(float)
129 | df["Std"] = df['Std'].astype(float)
130 | df.to_csv(config.filepath)
131 |
132 |
133 | if __name__ == "__main__":
134 | parser = argparse.ArgumentParser()
135 | parser.add_argument('--ntrials', type=int, nargs='?', const=10, default=10)
136 | filepath = "bandits/results/movielens_results.csv"
137 | parser.add_argument('--filepath', type=str, nargs='?', const=filepath, default=filepath)
138 |
139 | # Parse the argument
140 | args = parser.parse_args()
141 | main(args)
142 |
--------------------------------------------------------------------------------
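For orientation, each entry of the `bandits` dictionary above pairs an agent class with its keyword arguments; one entry can be instantiated by hand as sketched below (the `n_features`/`n_arms` attribute names mirror the tabular environment and are assumed here; the script itself defers construction to the `train` helper).

name = "EKF Subspace SVD"
properties = bandits[name]
properties["kwargs"]["model"] = MLP(num_arms)
agent = properties["bandit"](movielens.n_features, movielens.n_arms, **properties["kwargs"])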
/aistats2022-slides/slides.md:
--------------------------------------------------------------------------------
1 | ---
2 | # try also 'default' to start simple
3 | theme: default
4 | # random image from a curated Unsplash collection by Anthony
5 | # like them? see https://unsplash.com/collections/94734566/slidev
6 | background: https://images.unsplash.com/photo-1620460700571-320445215efb?crop=entropy&cs=tinysrgb&fit=crop&fm=jpg&h=1080&ixid=MnwxfDB8MXxyYW5kb218MHw5NDczNDU2Nnx8fHx8fHwxNjQ1MjY3NTc2&ixlib=rb-1.2.1&q=80&utm_campaign=api-credit&utm_medium=referral&utm_source=unsplash_source&w=1920
7 | # apply any windi css classes to the current slide
8 | ---
9 |
10 | # Subspace Neural Bandits
11 | ### AIStats 2022
12 |
13 | Gerardo Duran-Martin, Queen Mary University of London, UK
14 | Aleyna Kara, Boğaziçi University, Turkey
15 | Kevin Murphy, Google Research, Brain Team
16 |
17 | February 2022
18 |
19 | ---
20 |
21 | # Contextual bandits
22 | ### [Li, et al. (2012)](https://arxiv.org/abs/1003.0146)
23 |
24 | Let $\mathcal{A} = \{a^{(1)}, \ldots, a^{(K)}\}$ be a set of actions. At every time step $t=1,\ldots,T$
25 | 1. we are given a context ${\bf s}_t$
26 | 2. we decide, based on ${\bf s}_t$, an action $a_t \in \mathcal{A}$
27 | 3. we obtain a reward $r_t$ based on the context ${\bf s}_t$ and the chosen action $a_t$
28 |
29 | Our goal is to choose the sequence of actions that maximises the expected cumulative reward $\sum_{t=1}^T\mathbb{E}[R_t]$.
30 |
31 | ---
32 |
33 | # Thompson Sampling
34 |
35 | ### [Agrawal and Goyal (2014)](https://arxiv.org/abs/1209.3352), [Russo, et al. (2014)](https://arxiv.org/abs/1402.0298)
36 | Let $\mathcal{D}_t = (s_t, a_t, r_t)$ denote the observation at time $t$, and let $\mathcal{D}_{1:t} = \{\mathcal{D}_1, \ldots, \mathcal{D}_t\}$.
37 |
38 | At every time step $t=1,\ldots, T$ we proceed as follows:
39 | 1. Sample $\boldsymbol\theta_t \sim p(\cdot \vert \mathcal{D}_{1:t-1})$
40 | 2. $a_t = \arg\max_{a \in \mathcal{A}} \mathbb{E}[R(s_t,a; \boldsymbol\theta_t)]$
41 | 3. Obtain $r_t \sim R(s_t,a_t; \boldsymbol\theta_t)$
42 | 4. Store $\mathcal{D}_t = (s_t, a_t, r_t)$
43 |
44 | *Example*: Beta-Bernoulli bandit with $K=4$ arms.
45 |
46 |
49 |
50 |
51 | ---
52 |
53 | # Neural Bandits
54 | ### Characterising the reward function
55 |
56 | Let $f: \mathcal{S}\times\mathcal{A}\times\mathbb{R}^D \to \mathbb{R}^K$ be a neural network. A neural bandit is a contextual bandit where the reward is taken to be
57 |
58 | $$
59 | r_t \vert {\bf s}_t, a, \theta_t \sim \mathcal{N}\Big(f({\bf s}_t, a, \boldsymbol\theta_t), \sigma^2\Big)
60 | $$
61 |
62 |
63 | The main question: How to determine $\boldsymbol\theta_t$ at every time step $t$ using Thompson sampling?
64 |
65 | We need to compute (or approximate) the posterior distribution of the parameters in the neural network:
66 |
67 | $$
68 | \begin{aligned}
69 | p(\boldsymbol\theta \vert \mathcal{D}_{1:t}) &= p(\boldsymbol\theta \vert \mathcal{D}_{1:t-1}, \mathcal{D}_t)\\
70 | &\propto p(\boldsymbol\theta \vert \mathcal{D}_{1:t-1}) p(\mathcal{D}_t \vert \boldsymbol\theta) \\
71 | \end{aligned}
72 | $$
73 |
74 | ---
75 |
76 | # Subspace neural bandits
77 | ## Motivation
78 |
79 | * Current state-of-the-art solutions, although efficient, are not fully Bayesian.
80 | 1. Neural linear approximation
81 | 2. Lim2 approximation
82 | 3. Neural tangent approximation
83 |
84 |
109 |
110 |
113 |
114 | ---
115 |
116 | # Neural networks (in a subspace)
117 | ### [Li, et al. (2018)](https://arxiv.org/abs/1804.08838), [Larsen, et al. (2021)](https://arxiv.org/abs/2107.05802)
118 | Stochastic gradient descent (SGD) for neural networks can be run **in a subspace**:
119 | good solutions live in a low-dimensional linear subspace of the full parameter space.
120 |
121 | $$
122 | \boldsymbol\theta({\bf z}_t) = {\bf A z}_{t} + \boldsymbol\theta_*
123 | $$
124 |
125 |
126 |
127 |
128 |
129 | ---
130 |
131 | # Our contribution: subspace neural bandits
132 | ### Extended Kalman filter and neural networks in a subspace.
133 | We track a posterior over the subspace parameters ${\bf z}_t \in \mathbb{R}^d$ with an EKF:
134 |
135 | $$
136 | \begin{aligned}
137 | \boldsymbol\theta({\bf z}_t) &= {\bf A z}_{t} + \boldsymbol\theta_* \\
138 | {\bf z}_t \vert {\bf z}_{t-1} &\sim \mathcal{N}({\bf z}_{t-1}, \tau^2 {\bf I}) \\
139 | {\bf r}_t \vert {\bf z}_t &\sim \mathcal{N}\Big(f({\bf s}_t, a_t, \boldsymbol\theta({\bf z}_t)), \sigma^2 {\bf I}\Big)
140 | \end{aligned}
141 | $$
142 |
143 |
144 |
145 | ---
146 |
147 | # Results
148 | ### MNIST: cumulative reward
149 |
150 | Classification-turned-bandit problem.
151 | Maximum reward is $T=5000$ (total number of samples).
152 |
153 |
154 |
155 |
156 |
157 |
158 | ---
159 |
160 | # Results
161 | ### MNIST: running time
162 |
163 |
165 |
176 |
177 |
178 | ---
179 |
180 | # Subspace neural bandits
181 | ### probml.github.io/bandits
182 |
183 |
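For completeness, the EKF recursion implied by the state-space model on the earlier slides, with ${\bf H}_t = \nabla_{\bf z} f({\bf s}_t, a_t, \boldsymbol\theta({\bf z}))\big\vert_{{\bf z}=\boldsymbol\mu_{t-1}}$. This is the standard extended Kalman filter written in the subspace coordinates; the notation is ours, not taken verbatim from the paper.

$$
\begin{aligned}
\bar{\boldsymbol\Sigma}_t &= \boldsymbol\Sigma_{t-1} + \tau^2 {\bf I}, \\
{\bf K}_t &= \bar{\boldsymbol\Sigma}_t {\bf H}_t^\top \big({\bf H}_t \bar{\boldsymbol\Sigma}_t {\bf H}_t^\top + \sigma^2\big)^{-1}, \\
\boldsymbol\mu_t &= \boldsymbol\mu_{t-1} + {\bf K}_t \big(r_t - f({\bf s}_t, a_t, \boldsymbol\theta(\boldsymbol\mu_{t-1}))\big), \\
\boldsymbol\Sigma_t &= ({\bf I} - {\bf K}_t {\bf H}_t)\bar{\boldsymbol\Sigma}_t,
\end{aligned}
$$

and Thompson sampling draws ${\bf z}_t \sim \mathcal{N}(\boldsymbol\mu_{t-1}, \boldsymbol\Sigma_{t-1})$ before choosing the arm with the highest predicted reward.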