├── .gitignore
├── A2C
│   └── Tutorial_Advantage_Actor_Critic_(A2C).ipynb
├── Deep_Q_Learning
│   ├── README.md
│   └── Tutorial_Deep_Q_Learning.ipynb
├── Exploration
│   ├── README.md
│   └── Tutorial_UCBVI.ipynb
├── LICENSE
├── README.md
├── Value Iteration and Q-Learning
│   ├── README.md
│   └── Value_Iteration_and_Q_Learning.ipynb
├── colab_test
│   └── test_rlberry_setup.ipynb
├── logo
│   └── logo_wide.svg
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | *.mp4
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 |
136 | # pytype static type analyzer
137 | .pytype/
138 |
139 | # Cython debug symbols
140 | cython_debug/
--------------------------------------------------------------------------------
/A2C/Tutorial_Advantage_Actor_Critic_(A2C).ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Tutorial - Advantage Actor Critic (A2C).ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "authorship_tag": "ABX9TyOerJxVFIaozWjxy5taLfea",
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | }
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "view-in-github",
22 | "colab_type": "text"
23 | },
24 | "source": [
25 | "
"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {
31 | "id": "FRvfou6G9RGn"
32 | },
33 | "source": [
34 | "# Tutorial - Advantage Actor Critic (A2C)\n",
35 | "\n",
36 | "A2C keeps two neural networks:\n",
37 | "* One network with parameters $\\theta$ to represent the policy $\\pi_\\theta$.\n",
38 | "* One network with parameters $\\omega$ to represent a value function $V_\\omega$, that approximates $V^{\\pi_\\theta}$\n",
39 | "\n",
40 | "\n",
41 | "At each iteration, A2C collects $M$ transitions $(s_i, a_i, r_i, s_i')_{i=1}^M$ by following the policy $\\pi_\\theta$. If a terminal state is reached, we simply go back to the initial state and continue to play $\\pi_\\theta$ until we gather the $M$ transitions.\n",
42 | "\n",
43 | "Consider the following quantities, defined based on the collected transitions:\n",
44 | "\n",
45 | "$$\n",
46 | "\\widehat{V}(s_i) = \\widehat{Q}(s_i, a_i) = \\sum_{t=i}^{\\tau_i \\wedge M} \\gamma^{t-i} r_t + \\gamma^{M-i+1} V_\\omega(s_M')\\mathbb{I}\\{\\tau_i>M\\}\n",
47 | "$$\n",
48 | "\n",
49 | "where $\\tau_i = \\min\\{t\\geq i: s_t' \\text{ is a terminal state}\\}$, and \n",
50 | "\n",
51 | "$$\n",
52 | "\\mathbf{A}_\\omega(s_i, a_i) = \\widehat{Q}(s_i, a_i) - V_\\omega(s_i) \n",
53 | "$$\n",
54 | "\n",
55 | "\n",
56 | "A2C then takes a gradient step to minimize the policy \"loss\" (keeping $\\omega$ fixed):\n",
57 | "\n",
58 | "$$\n",
59 | "L_\\pi(\\theta) =\n",
60 | "-\\frac{1}{M} \\sum_{i=1}^M \\mathbf{A}_\\omega(s_i, a_i) \\log \\pi_\\theta(a_i|s_i)\n",
61 | "- \\frac{\\alpha}{M}\\sum_{i=1}^M \\sum_a \\pi(a|s_i) \\log \\frac{1}{\\pi(a|s_i)}\n",
62 | "$$\n",
63 | "\n",
64 | "and a gradient step to minimize the value loss (keeping $\\theta$ fixed):\n",
65 | "\n",
66 | "$$\n",
67 | "L_v(\\omega) = \\frac{1}{M} \\sum_{i=1}^M \\left( \\widehat{V}(s_i) - V_\\omega(s_i) \\right)^2\n",
68 | "$$\n",
69 | " \n",
70 | "\n",
71 | "\n",
72 | "# Reminders\n",
73 | "\n",
74 | "\n",
75 | "Objective function:\n",
76 | "\n",
77 | "$$\n",
78 | "J(\\theta) = \\mathbb{E}_{\\pi_\\theta}\n",
79 | "\\left[ \n",
80 | " \\sum_{t=0}^\\infty \\gamma^t r(S_t, A_t)\n",
81 | "\\right]\n",
82 | "$$\n",
83 | "\n",
84 | "Policy gradient:\n",
85 | "\n",
86 | "$$\n",
87 | "\\nabla_\\theta J(\\theta)= \\mathbb{E}_{\\pi_\\theta}\n",
88 | "\\left[ \n",
89 | " \\sum_{t=0}^\\infty \\gamma^t A^{\\pi_\\theta}(S_t, A_t) \n",
90 | " \\nabla_\\theta \\log \\pi_\\theta(A_t|S_t)\n",
91 | "\\right]\n",
92 | "$$\n",
93 | "where $A^{\\pi_\\theta}(s, a) = Q^{\\pi_\\theta}(s, a) - V^{\\pi_\\theta}(s) $ is the advantage function."
94 | ]
95 | },
96 | {
97 | "cell_type": "markdown",
98 | "metadata": {
99 | "id": "Er4wbIih9e24"
100 | },
101 | "source": [
102 | "# Colab setup"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "metadata": {
108 | "colab": {
109 | "base_uri": "https://localhost:8080/"
110 | },
111 | "id": "O12jMLD29DAU",
112 | "outputId": "37a4b59a-2b5d-44f4-da53-51fd84d77c3f"
113 | },
114 | "source": [
115 | "# After installing, restart the kernel\n",
116 | "\n",
117 | "# install rlberry library\n",
118 | "!git clone https://github.com/rlberry-py/rlberry.git \n",
119 | "!cd rlberry && git pull && pip install -e .[full] > /dev/null 2>&1\n",
120 | "!pip install ffmpeg-python > /dev/null 2>&1\n",
121 | "\n",
122 | "# gym\n",
123 | "!pip install 'gym[all]' > /dev/null 2>&1\n",
124 | "\n",
125 | "# packages required to show video\n",
126 | "!pip install pyvirtualdisplay > /dev/null 2>&1\n",
127 | "!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1\n",
128 | "\n",
129 | "# ask to restart runtime\n",
130 | "print(\"\")\n",
131 | "print(\" ~~~ Libraries installed, please restart the runtime! ~~~ \")\n",
132 | "print(\"\")"
133 | ],
134 | "execution_count": 1,
135 | "outputs": [
136 | {
137 | "output_type": "stream",
138 | "text": [
139 | "Cloning into 'rlberry'...\n",
140 | "remote: Enumerating objects: 472, done.\u001b[K\n",
141 | "remote: Counting objects: 100% (472/472), done.\u001b[K\n",
142 | "remote: Compressing objects: 100% (292/292), done.\u001b[K\n",
143 | "remote: Total 3541 (delta 283), reused 326 (delta 177), pack-reused 3069\u001b[K\n",
144 | "Receiving objects: 100% (3541/3541), 886.51 KiB | 9.85 MiB/s, done.\n",
145 | "Resolving deltas: 100% (2277/2277), done.\n",
146 | "Already up to date.\n",
147 | "\n",
148 | " ~~~ Libraries installed, please restart the runtime! ~~~ \n",
149 | "\n"
150 | ],
151 | "name": "stdout"
152 | }
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "metadata": {
158 | "id": "gKOp4h0Oe9-X"
159 | },
160 | "source": [
161 | "import gym\r\n",
162 | "from gym import logger as gymlogger\r\n",
163 | "from gym.wrappers import Monitor\r\n",
164 | "gymlogger.set_level(40) # error only\r\n",
165 | "\r\n",
166 | "import torch\r\n",
167 | "import torch.nn as nn\r\n",
168 | "import torch.nn.functional as F \r\n",
169 | "from torch import optim\r\n",
170 | "\r\n",
171 | "import numpy as np\r\n",
172 | "\r\n",
173 | "\r\n",
174 | "# for videos\r\n",
175 | "import rlberry.colab_utils.display_setup\r\n",
176 | "from rlberry.colab_utils.display_setup import show_video"
177 | ],
178 | "execution_count": 7,
179 | "outputs": []
180 | },
181 | {
182 | "cell_type": "code",
183 | "metadata": {
184 | "id": "MESFRbWdfA6P"
185 | },
186 | "source": [
187 | "class ActorNetwork(nn.Module):\r\n",
188 | " \"\"\"\r\n",
189 | " This network represents the policy\r\n",
190 | " \"\"\"\r\n",
191 | "\r\n",
192 | " def __init__(self, input_size, hidden_size, action_size):\r\n",
193 | " super(ActorNetwork, self).__init__()\r\n",
194 | " self.n_actions = action_size\r\n",
195 | " self.dim_observation = input_size\r\n",
196 | " \r\n",
197 | " self.net = nn.Sequential(\r\n",
198 | " nn.Linear(in_features=self.dim_observation, out_features=hidden_size),\r\n",
199 | " nn.ReLU(),\r\n",
200 | " nn.Linear(in_features=hidden_size, out_features=hidden_size),\r\n",
201 | " nn.ReLU(),\r\n",
202 | " nn.Linear(in_features=hidden_size, out_features=self.n_actions),\r\n",
203 | " nn.Softmax(dim=-1)\r\n",
204 | " )\r\n",
205 | " \r\n",
206 | " def policy(self, state):\r\n",
207 | " state = torch.tensor(state, dtype=torch.float)\r\n",
208 | " return self.net(state)\r\n",
209 | " \r\n",
210 | " def sample_action(self, state):\r\n",
211 | " state = torch.tensor(state, dtype=torch.float)\r\n",
212 | " action = torch.multinomial(self.policy(state), 1)\r\n",
213 | " return action.item()"
214 | ],
215 | "execution_count": 8,
216 | "outputs": []
217 | },
218 | {
219 | "cell_type": "code",
220 | "metadata": {
221 | "id": "R_DHHAQNfD7Z"
222 | },
223 | "source": [
224 | "class ValueNetwork(nn.Module):\r\n",
225 | " \"\"\"\r\n",
226 | " This class represents the value function\r\n",
227 | " \"\"\"\r\n",
228 | "\r\n",
229 | " def __init__(self, input_size, hidden_size, output_size):\r\n",
230 | " super(ValueNetwork, self).__init__()\r\n",
231 | " self.fc1 = nn.Linear(input_size, hidden_size)\r\n",
232 | " self.fc2 = nn.Linear(hidden_size, hidden_size)\r\n",
233 | " self.fc3 = nn.Linear(hidden_size, output_size)\r\n",
234 | "\r\n",
235 | " def forward(self, x):\r\n",
236 | " out = F.relu(self.fc1(x))\r\n",
237 | " out = F.relu(self.fc2(out))\r\n",
238 | " out = self.fc3(out)\r\n",
239 | " return out\r\n",
240 | " \r\n",
241 | " def value(self, state):\r\n",
242 | " state = torch.tensor(state, dtype=torch.float)\r\n",
243 | " return self.forward(state)"
244 | ],
245 | "execution_count": 9,
246 | "outputs": []
247 | },
248 | {
249 | "cell_type": "code",
250 | "metadata": {
251 | "id": "_Ry-b3HgfGx5"
252 | },
253 | "source": [
254 | "# You can select your environment here\r\n",
255 | "env_id = 'CartPole-v1' # @param [\"CartPole-v1\", \"LunarLander-v2\", \"MountainCar-v0\"]\r\n",
256 | "env = gym.make(env_id)\r\n",
257 | "eval_env = gym.make(env_id) # environment to evaluate the policy"
258 | ],
259 | "execution_count": 10,
260 | "outputs": []
261 | },
262 | {
263 | "cell_type": "code",
264 | "metadata": {
265 | "id": "h65dXIY5fMZg"
266 | },
267 | "source": [
268 | "# Define your networks\r\n",
269 | "value_network = ValueNetwork(env.observation_space.shape[0], 16, 1)\r\n",
270 | "actor_network = ActorNetwork(env.observation_space.shape[0], 16, env.action_space.n)\r\n",
271 | "print(value_network)\r\n",
272 | "print(actor_network)\r\n",
273 | "\r\n",
274 | "# Define your optimizers\r\n",
275 | "value_network_optimizer = torch.optim.RMSprop(value_network.parameters(), lr=0.01)\r\n",
276 | "actor_network_optimizer = torch.optim.RMSprop(actor_network.parameters(), lr=0.01)\r\n",
277 | "\r\n",
278 | "# --------------------------------------------------------------\r\n",
279 | "# Parameters\r\n",
280 | "# --------------------------------------------------------------\r\n",
281 | "num_iterations = 300 # Number of iterations\r\n",
282 | "batch_size = 512 # How many samples to collect (value of M)\r\n",
283 | "gamma = 1 # Discount factor\r\n",
284 | "alpha = 0.001 # Entropy term coefficient\r\n",
285 | "reward_threshold = 495 # Stop training when the policy achieves this amount of rewards\r\n",
286 | "\r\n",
287 | "\r\n",
288 | "# --------------------------------------------------------------\r\n",
289 | "# Train\r\n",
290 | "# --------------------------------------------------------------\r\n",
291 | "for iteration in range(num_iterations):\r\n",
292 | " # Initialize batch storage\r\n",
293 | " states = np.empty((batch_size,) + env.observation_space.shape, dtype=float) # shape (batch_size, state_dim)\r\n",
294 | " rewards = np.empty((batch_size,), dtype=float) # shape (batch_size, ) \r\n",
295 | " next_states = np.empty((batch_size,) + env.observation_space.shape, dtype=float) # shape (batch_size, state_dim)\r\n",
296 | " dones = np.empty((batch_size,), dtype=bool) # shape (batch_size, ) \r\n",
297 | " proba = torch.empty((batch_size,), dtype=torch.float) # shape (batch_size, ), store pi(a_t|s_t)\r\n",
298 | " next_value = 0 # \r\n",
299 | " \r\n",
300 | " # Initialize environment\r\n",
301 | " state = env.reset()\r\n",
302 | "\r\n",
303 | " # Generate batch\r\n",
304 | " for i in range(batch_size):\r\n",
305 | " action = actor_network.sample_action(state)\r\n",
306 | " next_state, reward, done, _ = env.step(action)\r\n",
307 | "\r\n",
308 | " states[i] = # ...\r\n",
309 | " rewards[i] = # ...\r\n",
310 | " next_states[i] = # ...\r\n",
311 | " dones[i] = # ...\r\n",
312 | " proba[i] = # ...\r\n",
313 | "\r\n",
314 | " state = next_state\r\n",
315 | " if done:\r\n",
316 | " state = env.reset()\r\n",
317 | "\r\n",
318 | " if not done:\r\n",
319 | " next_value = value_network.value(next_states[-1]).detach().numpy()[0]\r\n",
320 | "\r\n",
321 | " # compute returns (without bootstrapping)\r\n",
322 | " returns = np.zeros((batch_size,), dtype=float)\r\n",
323 | " T = batch_size\r\n",
324 | " for j in range(T):\r\n",
325 | " returns[T-j-1] = rewards[T-j-1]\r\n",
326 | " if j > 0:\r\n",
327 | " returns[T-j-1] += gamma * returns[T-j] * (1 - dones[T-j])\r\n",
328 | " else:\r\n",
329 | " returns[T-j-1] += gamma * next_value\r\n",
330 | "\r\n",
331 | " # compute advantage\r\n",
332 | " values = value_network.value(states)\r\n",
333 | " advantages = # ...\r\n",
334 | "\r\n",
335 | " # Compute MSE (Value loss)\r\n",
336 | " value_network_optimizer.zero_grad()\r\n",
337 | " loss_value = # ...\r\n",
338 | " loss_value.backward()\r\n",
339 | " value_network_optimizer.step()\r\n",
340 | "\r\n",
341 | " # Compute entropy term\r\n",
342 | " dist = actor_network.policy(states)\r\n",
343 | " entropy_term = -(dist*dist.log()).sum(-1).mean()\r\n",
344 | "\r\n",
345 | " # Compute policy loss\r\n",
346 | " actor_network_optimizer.zero_grad()\r\n",
347 | " loss_policy = # ...\r\n",
348 | " loss_policy += -alpha * entropy_term\r\n",
349 | " loss_policy.backward()\r\n",
350 | " actor_network_optimizer.step()\r\n",
351 | "\r\n",
352 | " if( (iteration+1)%10 == 0 ):\r\n",
353 | " eval_rewards = np.zeros(5)\r\n",
354 | " for sim in range(5):\r\n",
355 | " eval_done = False\r\n",
356 | " eval_state = eval_env.reset()\r\n",
357 | " while not eval_done:\r\n",
358 | " eval_action = actor_network.sample_action(eval_state)\r\n",
359 | " eval_next_state, eval_reward, eval_done, _ = eval_env.step(eval_action)\r\n",
360 | " eval_rewards[sim] += eval_reward\r\n",
361 | " eval_state = eval_next_state\r\n",
362 | " print(\"Iteration = {}, loss_value = {:0.3f}, loss_policy = {:0.3f}, rewards = {:0.2f}\"\r\n",
363 | " .format(iteration +1, loss_value.item(), loss_policy.item(), eval_rewards.mean()))\r\n",
364 | " if (eval_rewards.mean() > reward_threshold):\r\n",
365 | " break"
366 | ],
367 | "execution_count": null,
368 | "outputs": []
369 | },
370 | {
371 | "cell_type": "code",
372 | "metadata": {
373 | "id": "kPzvAqDVhc_K"
374 | },
375 | "source": [
376 | "env = Monitor(env, \"./gym-results\", force=True, video_callable=lambda episode: True)\r\n",
377 | "for episode in range(1):\r\n",
378 | " done = False\r\n",
379 | " state = env.reset()\r\n",
380 | " while not done:\r\n",
381 | " action = actor_network.sample_action(state)\r\n",
382 | " state, reward, done, info = env.step(action)\r\n",
383 | "env.close()\r\n",
384 | "show_video(directory=\"./gym-results\")"
385 | ],
386 | "execution_count": null,
387 | "outputs": []
388 | },
389 | {
390 | "cell_type": "markdown",
391 | "metadata": {
392 | "id": "vNqnseJtlU87"
393 | },
394 | "source": [
395 | "# Test other environments!\r\n",
396 | "\r\n",
397 | "Try some other environments available in OpenAI gym ([link](https://gym.openai.com/envs/#classic_control)). Suggestion: use `classic control` or `Box2D` environments."
398 | ]
399 | }
400 | ]
401 | }
--------------------------------------------------------------------------------
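
Note on the A2C notebook above: the batch storage, the advantage, and the two losses are left as exercises (`# ...`). For orientation only, the sketch below shows one way to write the value loss, the advantage, and the first term of the policy loss from the notebook's equations in PyTorch; the names mirror the notebook's variables (`returns`, `values`, `proba`), but this is a possible completion under those assumptions, not the notebook's reference solution.

```
import torch

def a2c_losses(returns, values, proba):
    """Sketch of the A2C losses defined in the notebook's markdown.

    returns: shape (M,)   -- bootstrapped returns, i.e. the targets V_hat(s_i)
    values:  shape (M, 1) -- V_omega(s_i), output of the value network
    proba:   shape (M,)   -- pi_theta(a_i | s_i) for the actions actually taken
    """
    returns = torch.as_tensor(returns, dtype=torch.float)
    values = values.squeeze(-1)

    # L_v(omega): mean squared error between the targets and V_omega(s_i)
    loss_value = torch.mean((returns - values) ** 2)

    # A_omega(s_i, a_i) = V_hat(s_i) - V_omega(s_i), detached so the policy
    # gradient does not flow back into the value network
    advantages = (returns - values).detach()

    # First term of L_pi(theta); the entropy bonus (-alpha * entropy_term)
    # is added separately in the notebook's training loop
    loss_policy = -torch.mean(advantages * torch.log(proba))
    return loss_value, loss_policy
```

In the notebook's loop, each loss is then minimized with its own optimizer (`value_network_optimizer`, `actor_network_optimizer`), calling `zero_grad()`, `backward()` and `step()` exactly as the provided skeleton does.
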
/Deep_Q_Learning/README.md:
--------------------------------------------------------------------------------
1 | # Instructions
2 |
3 | **To run the notebook in [Google Colab](https://colab.research.google.com/)**, click on the link
4 | `Open in Colab` at the top of the `.ipynb` file.
5 |
6 |
7 | **To run the notebook locally**, download the `.ipynb` file and install the required libraries,
8 | as explained below.
9 |
10 | * Setup virtual environment (optional but recommended):
11 |
12 | ```
13 | conda create -n rltutorials python=3.8
14 | conda activate rltutorials
15 | ```
16 |
17 | * Install required libraries:
18 |
19 | ```
20 | conda install -c conda-forge jupyterlab
21 | pip install git+https://github.com/rlberry-py/rlberry.git#egg=rlberry[torch_agents]
22 | ```
23 |
--------------------------------------------------------------------------------
/Deep_Q_Learning/Tutorial_Deep_Q_Learning.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Tutorial_Deep_Q_Learning.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "authorship_tag": "ABX9TyP9EbLl6g2dURBpFFjKPouU",
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | }
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "view-in-github",
22 | "colab_type": "text"
23 | },
24 | "source": [
25 | "
"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {
31 | "id": "2j_no2BuvPUE"
32 | },
33 | "source": [
34 | "# Tutorial - Deep Q-Learning \n",
35 | "\n",
36 | "Deep Q-Learning uses a neural network to approximate $Q$ functions. Hence, we usually refer to this algorithm as DQN (for *deep Q network*).\n",
37 | "\n",
38 | "The parameters of the neural network are denoted by $\\theta$. \n",
39 | "* As input, the network takes a state $s$,\n",
40 | "* As output, the network returns $Q(s, a, \\theta)$, the value of each action $a$ in state $s$, according to the parameters $\\theta$.\n",
41 | "\n",
42 | "\n",
43 | "The goal of Deep Q-Learning is to learn the parameters $\\theta$ so that $Q(s, a, \\theta)$ approximates well the optimal $Q$-function $Q^*(s, a)$. \n",
44 | "\n",
45 | "In addition to the network with parameters $\\theta$, the algorithm keeps another network with the same architecture and parameters $\\theta^-$, called **target network**.\n",
46 | "\n",
47 | "The algorithm works as follows:\n",
48 | "\n",
49 | "1. At each time $t$, the agent is in state $s_t$ and has observed the transitions $(s_i, a_i, r_i, s_i')_{i=1}^{t-1}$, which are stored in a **replay buffer**.\n",
50 | "\n",
51 | "2. Choose action $a_t = \\arg\\max_a Q(s_t, a)$ with probability $1-\\varepsilon_t$, and $a_t$=random action with probability $\\varepsilon_t$. \n",
52 | "\n",
53 | "3. Take action $a_t$, observe reward $r_t$ and next state $s_t'$.\n",
54 | "\n",
55 | "4. Add transition $(s_t, a_t, r_t, s_t')$ to the **replay buffer**.\n",
56 | "\n",
57 | "5. Sample a minibatch $\\mathcal{B}$ containing $B$ transitions from the replay buffer. Using this minibatch, we define the loss:\n",
58 | "\n",
59 | "$$\n",
60 | "L(\\theta) = \\sum_{(s_i, a_i, r_i, s_i') \\in \\mathcal{B}}\n",
61 | "\\left[\n",
62 | "Q(s_i, a_i, \\theta) - y_i\n",
63 | "\\right]^2\n",
64 | "$$\n",
65 | "where the $y_i$ are the **targets** computed with the **target network** $\\theta^-$:\n",
66 | "\n",
67 | "$$\n",
68 | "y_i = r_i + \\gamma \\max_{a'} Q(s_i', a', \\theta^-).\n",
69 | "$$\n",
70 | "\n",
71 | "6. Update the parameters $\\theta$ to minimize the loss, e.g., with gradient descent (**keeping $\\theta^-$ fixed**): \n",
72 | "$$\n",
73 | "\\theta \\gets \\theta - \\eta \\nabla_\\theta L(\\theta)\n",
74 | "$$\n",
75 | "where $\\eta$ is the optimization learning rate. \n",
76 | "\n",
77 | "7. Every $N$ transitions ($t \\bmod N = 0$), update target parameters: $\\theta^- \\gets \\theta$.\n",
78 | "\n",
79 | "8. $t \\gets t+1$. Stop if $t = T$, otherwise go to step 2."
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "metadata": {
85 | "id": "HhKHif__t9OD"
86 | },
87 | "source": [
88 | "# Colab setup"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "metadata": {
94 | "colab": {
95 | "base_uri": "https://localhost:8080/"
96 | },
97 | "id": "aylqy_sDqebM",
98 | "outputId": "e1a78b7f-f832-4119-e8c5-3e02264944d9"
99 | },
100 | "source": [
101 | "# After installing, restart the kernel\n",
102 | "\n",
103 | "if 'google.colab' in str(get_ipython()):\n",
104 | " print(\"Installing packages, please wait a few moments. You may need to restart the runtime after the installation.\")\n",
105 | "\n",
106 | " # install rlberry library\n",
107 | " !pip install git+https://github.com/rlberry-py/rlberry.git#egg=rlberry[default] > /dev/null 2>&1\n",
108 | "\n",
109 | " # install gym\n",
110 | " !pip install gym[all] > /dev/null 2>&1\n",
111 | "\n",
112 | " # packages required to show video\n",
113 | " !pip install pyvirtualdisplay > /dev/null 2>&1\n",
114 | " !apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1"
115 | ],
116 | "execution_count": 18,
117 | "outputs": [
118 | {
119 | "output_type": "stream",
120 | "name": "stdout",
121 | "text": [
122 | "Installing packages, please wait a few moments. You may need to restart the runtime after the installation.\n"
123 | ]
124 | }
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "metadata": {
130 | "id": "VWBRfwosfA9f"
131 | },
132 | "source": [
133 | "# Imports\n",
134 | "import torch\n",
135 | "import torch.nn as nn\n",
136 | "import torch.nn.functional as F\n",
137 | "import torch.optim as optim\n",
138 | "import numpy as np\n",
139 | "import random\n",
140 | "from copy import deepcopy\n",
141 | "from gym.wrappers import Monitor\n",
142 | "import gym"
143 | ],
144 | "execution_count": 19,
145 | "outputs": []
146 | },
147 | {
148 | "cell_type": "code",
149 | "metadata": {
150 | "id": "35Zzr-xCya5y"
151 | },
152 | "source": [
153 | "# Create directory for saving videos\n",
154 | "!mkdir videos > /dev/null 2>&1\n",
155 | "\n",
156 | "# Initialize display and import function to show videos\n",
157 | "import rlberry.colab_utils.display_setup\n",
158 | "from rlberry.colab_utils.display_setup import show_video"
159 | ],
160 | "execution_count": 20,
161 | "outputs": []
162 | },
163 | {
164 | "cell_type": "code",
165 | "metadata": {
166 | "id": "FLLwJLQlrTxo"
167 | },
168 | "source": [
169 | "# Random number generator\n",
170 | "import rlberry.seeding as seeding \n",
171 | "seeder = seeding.Seeder(456)\n",
172 | "rng = seeder.rng"
173 | ],
174 | "execution_count": 21,
175 | "outputs": []
176 | },
177 | {
178 | "cell_type": "markdown",
179 | "metadata": {
180 | "id": "528oqsgefIFl"
181 | },
182 | "source": [
183 | "# 1. Define the parameters"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "metadata": {
189 | "id": "CtExtR4dfMbm",
190 | "colab": {
191 | "base_uri": "https://localhost:8080/"
192 | },
193 | "outputId": "64f36e7b-b953-4442-bc88-9d9fe6b90ef7"
194 | },
195 | "source": [
196 | "# Environment\n",
197 | "env = gym.make(\"CartPole-v0\")\n",
198 | "\n",
199 | "# Discount factor\n",
200 | "GAMMA = 0.99\n",
201 | "\n",
202 | "# Batch size\n",
203 | "BATCH_SIZE = 256\n",
204 | "# Capacity of the replay buffer\n",
205 | "BUFFER_CAPACITY = 10000\n",
206 | "# Update target net every ... episodes\n",
207 | "UPDATE_TARGET_EVERY = 20\n",
208 | "\n",
209 | "# Initial value of epsilon\n",
210 | "EPSILON_START = 1.0\n",
211 | "# Parameter to decrease epsilon\n",
212 | "DECREASE_EPSILON = 200\n",
213 | "# Minimum value of epsilon\n",
214 | "EPSILON_MIN = 0.05\n",
215 | "\n",
216 | "# Number of training episodes\n",
217 | "N_EPISODES = 200\n",
218 | "\n",
219 | "# Learning rate\n",
220 | "LEARNING_RATE = 0.1"
221 | ],
222 | "execution_count": 22,
223 | "outputs": [
224 | {
225 | "output_type": "stream",
226 | "name": "stdout",
227 | "text": [
228 | "INFO: Making new env: CartPole-v0\n"
229 | ]
230 | }
231 | ]
232 | },
233 | {
234 | "cell_type": "markdown",
235 | "metadata": {
236 | "id": "6g16Je-dhM2Q"
237 | },
238 | "source": [
239 | "# 2. Define the replay buffer"
240 | ]
241 | },
242 | {
243 | "cell_type": "code",
244 | "metadata": {
245 | "id": "Jvh82br9hMNt"
246 | },
247 | "source": [
248 | "class ReplayBuffer:\n",
249 | " def __init__(self, capacity):\n",
250 | " self.capacity = capacity\n",
251 | " self.memory = []\n",
252 | " self.position = 0\n",
253 | "\n",
254 | " def push(self, state, action, reward, next_state, done):\n",
255 | " \"\"\"Saves a transition.\"\"\"\n",
256 | " if len(self.memory) < self.capacity:\n",
257 | " self.memory.append(None)\n",
258 | " self.memory[self.position] = (state, action, reward, next_state, done)\n",
259 | " self.position = (self.position + 1) % self.capacity\n",
260 | "\n",
261 | " def sample(self, batch_size):\n",
262 | " return [self.memory[i] for i in rng.choice(len(self.memory), batch_size)]\n",
263 | "\n",
264 | "\n",
265 | " def __len__(self):\n",
266 | " return len(self.memory)\n",
267 | "\n",
268 | "# create instance of replay buffer\n",
269 | "replay_buffer = ReplayBuffer(BUFFER_CAPACITY)"
270 | ],
271 | "execution_count": 23,
272 | "outputs": []
273 | },
274 | {
275 | "cell_type": "markdown",
276 | "metadata": {
277 | "id": "UCc9WZppi92W"
278 | },
279 | "source": [
280 | "# 3. Define the neural network architecture, objective and optimizer"
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "metadata": {
286 | "id": "sdNz3Jrwi9iS"
287 | },
288 | "source": [
289 | "class Net(nn.Module):\n",
290 | " \"\"\"\n",
291 | " Basic neural net.\n",
292 | " \"\"\"\n",
293 | " def __init__(self, obs_size, hidden_size, n_actions):\n",
294 | " super(Net, self).__init__()\n",
295 | " self.net = nn.Sequential(\n",
296 | " nn.Linear(obs_size, hidden_size),\n",
297 | " nn.ReLU(),\n",
298 | " nn.Linear(hidden_size, n_actions)\n",
299 | " )\n",
300 | "\n",
301 | " def forward(self, x):\n",
302 | " return self.net(x)"
303 | ],
304 | "execution_count": 24,
305 | "outputs": []
306 | },
307 | {
308 | "cell_type": "code",
309 | "metadata": {
310 | "id": "NI9hFJ28jLZ_"
311 | },
312 | "source": [
313 | "# create network and target network\n",
314 | "hidden_size = 128\n",
315 | "obs_size = env.observation_space.shape[0]\n",
316 | "n_actions = env.action_space.n\n",
317 | "\n",
318 | "q_net = Net(obs_size, hidden_size, n_actions)\n",
319 | "target_net = Net(obs_size, hidden_size, n_actions)\n",
320 | "\n",
321 | "# objective and optimizer\n",
322 | "objective = nn.MSELoss()\n",
323 | "optimizer = optim.Adam(params=q_net.parameters(), lr=LEARNING_RATE)"
324 | ],
325 | "execution_count": 25,
326 | "outputs": []
327 | },
328 | {
329 | "cell_type": "markdown",
330 | "metadata": {
331 | "id": "xnR8nfoSjZjL"
332 | },
333 | "source": [
334 | "# 4. Implement Deep Q-Learning"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "metadata": {
340 | "id": "z6fT8cKdjmTZ"
341 | },
342 | "source": [
343 | "#\n",
344 | "# Some useful functions\n",
345 | "#\n",
346 | "\n",
347 | "def get_q(states):\n",
348 | " \"\"\"\n",
349 | " Compute Q function for a list of states\n",
350 | " \"\"\"\n",
351 | " with torch.no_grad():\n",
352 | " states_v = torch.FloatTensor([states])\n",
353 | " output = q_net.forward(states_v).data.numpy() # shape (1, len(states), n_actions)\n",
354 | " return output[0, :, :] # shape (len(states), n_actions)\n",
355 | "\n",
356 | "def eval_dqn(n_sim=5):\n",
357 | " \"\"\" \n",
358 | " Monte Carlo evaluation of DQN agent.\n",
359 | "\n",
360 | " Repeat n_sim times:\n",
361 | " * Run the DQN policy until the environment reaches a terminal state (= one episode)\n",
362 | " * Compute the sum of rewards in this episode\n",
363 | " * Store the sum of rewards in the episode_rewards array.\n",
364 | " \"\"\"\n",
365 | " env_copy = deepcopy(env)\n",
366 | " episode_rewards = np.zeros(n_sim)\n",
367 | "\n",
368 | " for ii in range(n_sim):\n",
369 | " state = env_copy.reset()\n",
370 | " done = False \n",
371 | " while not done:\n",
372 | " action = choose_action(state, 0.0)\n",
373 | " next_state, reward, done, _ = env_copy.step(action)\n",
374 | " episode_rewards[ii] += reward\n",
375 | " state = next_state\n",
376 | " return episode_rewards"
377 | ],
378 | "execution_count": 26,
379 | "outputs": []
380 | },
381 | {
382 | "cell_type": "code",
383 | "metadata": {
384 | "id": "OMspDNntkIoe"
385 | },
386 | "source": [
387 | "def choose_action(state, epsilon):\n",
388 | " \"\"\"\n",
389 | " ** TO BE IMPLEMENTED **\n",
390 | " \n",
391 | " Return action according to an epsilon-greedy exploration policy\n",
392 | " \"\"\"\n",
393 | " return 0\n",
394 | " \n",
395 | "\n",
396 | "def update(state, action, reward, next_state, done):\n",
397 | " \"\"\"\n",
398 | " ** TO BE COMPLETED **\n",
399 | " \"\"\"\n",
400 | " \n",
401 | " # add data to replay buffer\n",
402 | " replay_buffer.push(state, action, reward, next_state, done)\n",
403 | " \n",
404 | " if len(replay_buffer) < BATCH_SIZE:\n",
405 | " return np.inf\n",
406 | " \n",
407 | " # get batch\n",
408 | " transitions = replay_buffer.sample(BATCH_SIZE)\n",
409 | "\n",
410 | " # Compute loss - TO BE IMPLEMENTED!\n",
411 | " values = torch.zeros(BATCH_SIZE) # to be computed using batch\n",
412 | " targets = torch.zeros(BATCH_SIZE) # to be computed using batch\n",
413 | " loss = objective(values, targets)\n",
414 | " \n",
415 | " # Optimize the model - UNCOMMENT!\n",
416 | "# optimizer.zero_grad()\n",
417 | "# loss.backward()\n",
418 | "# optimizer.step()\n",
419 | " \n",
420 | " return loss.data.numpy()"
421 | ],
422 | "execution_count": 27,
423 | "outputs": []
424 | },
425 | {
426 | "cell_type": "code",
427 | "metadata": {
428 | "id": "QIhpKPhkkU4W",
429 | "colab": {
430 | "base_uri": "https://localhost:8080/"
431 | },
432 | "outputId": "93f23393-0bc4-48bf-d315-1fbc1d94f7c2"
433 | },
434 | "source": [
435 | "\n",
436 | "#\n",
437 | "# Train\n",
438 | "# \n",
439 | "\n",
440 | "EVAL_EVERY = 5\n",
441 | "REWARD_THRESHOLD = 199\n",
442 | "\n",
443 | "def train():\n",
444 | " state = env.reset()\n",
445 | " epsilon = EPSILON_START\n",
446 | " ep = 0\n",
447 | " total_time = 0\n",
448 | " while ep < N_EPISODES:\n",
449 | " action = choose_action(state, epsilon)\n",
450 | "\n",
451 | " # take action and update replay buffer and networks\n",
452 | " next_state, reward, done, _ = env.step(action)\n",
453 | " loss = update(state, action, reward, next_state, done)\n",
454 | "\n",
455 | " # update state\n",
456 | " state = next_state\n",
457 | "\n",
458 | " # end episode if done\n",
459 | " if done:\n",
460 | " state = env.reset()\n",
461 | " ep += 1\n",
462 | " if ( (ep+1)% EVAL_EVERY == 0):\n",
463 | " rewards = eval_dqn()\n",
464 | " print(\"episode =\", ep+1, \", reward = \", np.mean(rewards))\n",
465 | " if np.mean(rewards) >= REWARD_THRESHOLD:\n",
466 | " break\n",
467 | "\n",
468 | " # update target network\n",
469 | " if ep % UPDATE_TARGET_EVERY == 0:\n",
470 | " target_net.load_state_dict(q_net.state_dict())\n",
471 | " # decrease epsilon\n",
472 | " epsilon = EPSILON_MIN + (EPSILON_START - EPSILON_MIN) * \\\n",
473 | " np.exp(-1. * ep / DECREASE_EPSILON ) \n",
474 | "\n",
475 | " total_time += 1\n",
476 | "\n",
477 | "# Run the training loop\n",
478 | "train()\n",
479 | "\n",
480 | "# Evaluate the final policy\n",
481 | "rewards = eval_dqn(20)\n",
482 | "print(\"\")\n",
483 | "print(\"mean reward after training = \", np.mean(rewards))"
484 | ],
485 | "execution_count": 28,
486 | "outputs": [
487 | {
488 | "output_type": "stream",
489 | "name": "stdout",
490 | "text": [
491 | "episode = 5 , reward = 9.6\n",
492 | "episode = 10 , reward = 9.4\n",
493 | "episode = 15 , reward = 9.4\n",
494 | "episode = 20 , reward = 9.2\n",
495 | "episode = 25 , reward = 9.2\n",
496 | "episode = 30 , reward = 9.8\n",
497 | "episode = 35 , reward = 9.8\n",
498 | "episode = 40 , reward = 10.0\n",
499 | "episode = 45 , reward = 9.2\n",
500 | "episode = 50 , reward = 9.8\n"
501 | ]
502 | },
503 | {
504 | "output_type": "stream",
505 | "name": "stderr",
506 | "text": [
507 | "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:15: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
508 | " from ipykernel import kernelapp as app\n"
509 | ]
510 | },
511 | {
512 | "output_type": "stream",
513 | "name": "stdout",
514 | "text": [
515 | "episode = 55 , reward = 9.8\n",
516 | "episode = 60 , reward = 9.4\n",
517 | "episode = 65 , reward = 9.6\n",
518 | "episode = 70 , reward = 9.6\n",
519 | "episode = 75 , reward = 8.8\n",
520 | "episode = 80 , reward = 10.0\n",
521 | "episode = 85 , reward = 9.2\n",
522 | "episode = 90 , reward = 9.4\n",
523 | "episode = 95 , reward = 9.2\n",
524 | "episode = 100 , reward = 9.2\n",
525 | "episode = 105 , reward = 9.2\n",
526 | "episode = 110 , reward = 9.6\n",
527 | "episode = 115 , reward = 9.2\n",
528 | "episode = 120 , reward = 9.2\n",
529 | "episode = 125 , reward = 9.4\n",
530 | "episode = 130 , reward = 9.8\n",
531 | "episode = 135 , reward = 9.2\n",
532 | "episode = 140 , reward = 9.2\n",
533 | "episode = 145 , reward = 10.2\n",
534 | "episode = 150 , reward = 9.2\n",
535 | "episode = 155 , reward = 9.4\n",
536 | "episode = 160 , reward = 9.6\n",
537 | "episode = 165 , reward = 9.6\n",
538 | "episode = 170 , reward = 9.4\n",
539 | "episode = 175 , reward = 9.0\n",
540 | "episode = 180 , reward = 9.0\n",
541 | "episode = 185 , reward = 9.6\n",
542 | "episode = 190 , reward = 9.2\n",
543 | "episode = 195 , reward = 9.4\n",
544 | "episode = 200 , reward = 9.4\n",
545 | "\n",
546 | "mean reward after training = 9.8\n"
547 | ]
548 | }
549 | ]
550 | },
551 | {
552 | "cell_type": "markdown",
553 | "metadata": {
554 | "id": "c8QZwuvjgrMm"
555 | },
556 | "source": [
557 | "# Visualize the DQN policy"
558 | ]
559 | },
560 | {
561 | "cell_type": "code",
562 | "metadata": {
563 | "colab": {
564 | "base_uri": "https://localhost:8080/",
565 | "height": 474
566 | },
567 | "id": "FGcGwOcEfzPz",
568 | "outputId": "3aa22829-9b5c-4308-cd1a-aadb1a629fb0"
569 | },
570 | "source": [
571 | "def render_env(env):\n",
572 | " env = deepcopy(env)\n",
573 | " env = Monitor(env, './videos', force=True, video_callable=lambda episode: True)\n",
574 | " for episode in range(1):\n",
575 | " done = False\n",
576 | " state = env.reset()\n",
577 | " env.render()\n",
578 | " while not done:\n",
579 | " action = choose_action(state, 0.0)\n",
580 | " state, reward, done, info = env.step(action)\n",
581 | " env.render()\n",
582 | " env.close()\n",
583 | " show_video()\n",
584 | "\n",
585 | "render_env(env)"
586 | ],
587 | "execution_count": 29,
588 | "outputs": [
589 | {
590 | "output_type": "stream",
591 | "name": "stdout",
592 | "text": [
593 | "INFO: Clearing 4 monitor files from previous run (because force=True was provided)\n",
594 | "INFO: Starting new video recorder writing to /content/videos/openaigym.video.1.705.video000000.mp4\n",
595 | "INFO: Finished writing results. You can upload them to the scoreboard via gym.upload('/content/videos')\n"
596 | ]
597 | },
598 | {
599 | "output_type": "display_data",
600 | "data": {
601 | "text/html": [
602 | ""
606 | ],
607 | "text/plain": [
608 | ""
609 | ]
610 | },
611 | "metadata": {}
612 | }
613 | ]
614 | }
615 | ]
616 | }
--------------------------------------------------------------------------------
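
Note on the DQN notebook above: `choose_action` and the loss inside `update` are left to the reader. The sketch below is one common way to fill them in for this CartPole setup; it reuses the notebook's names (`q_net`, `target_net`, `GAMMA`) but passes them as explicit arguments to stay self-contained, whereas the notebook's versions read them as globals. Treat it as an illustration under those assumptions, not the reference solution.

```
import numpy as np
import torch
import torch.nn.functional as F

def choose_action(state, epsilon, q_net, n_actions):
    """Epsilon-greedy action selection (sketch)."""
    if np.random.rand() < epsilon:
        return np.random.randint(n_actions)        # explore: uniform random action
    with torch.no_grad():
        q_values = q_net(torch.FloatTensor(state).unsqueeze(0))
    return int(q_values.argmax(dim=1).item())      # exploit: greedy action

def dqn_loss(transitions, q_net, target_net, gamma):
    """DQN loss on a minibatch of (state, action, reward, next_state, done) tuples (sketch)."""
    states, actions, rewards, next_states, dones = zip(*transitions)
    states_v      = torch.FloatTensor(np.array(states))
    next_states_v = torch.FloatTensor(np.array(next_states))
    actions_v     = torch.LongTensor(actions).unsqueeze(1)
    rewards_v     = torch.FloatTensor(rewards)
    dones_v       = torch.FloatTensor(np.array(dones, dtype=np.float32))

    # Q(s_i, a_i, theta) for the actions that were actually taken
    values = q_net(states_v).gather(1, actions_v).squeeze(1)

    # Targets y_i = r_i + gamma * max_a' Q(s_i', a', theta^-),
    # with no bootstrap term on terminal transitions
    with torch.no_grad():
        next_q = target_net(next_states_v).max(dim=1).values
    targets = rewards_v + gamma * (1.0 - dones_v) * next_q

    return F.mse_loss(values, targets)
```

With a working `choose_action` and loss, the evaluation rewards should rise well above the ~9.5 per episode seen in the recorded output, which corresponds to the unimplemented placeholder that always returns action 0.
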
/Exploration/README.md:
--------------------------------------------------------------------------------
1 | # Instructions
2 |
3 | **To run the notebook in [Google Colab](https://colab.research.google.com/)**, click on the link
4 | `Open in Colab` at the top of the `.ipynb` file.
5 |
6 |
7 | **To run the notebook locally**, download the `.ipynb` file and install the required libraries,
8 | as explained below.
9 |
10 | * Setup virtual environment (optional but recommended):
11 |
12 | ```
13 | conda create -n rltutorials python=3.8
14 | conda activate rltutorials
15 | ```
16 |
17 | * Install required libraries:
18 |
19 | ```
20 | conda install -c conda-forge jupyterlab
21 | pip install git+https://github.com/rlberry-py/rlberry.git#egg=rlberry[default]
22 | ```
23 |
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 rlberry-py
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 |
7 |
10 |
11 | # Reinforcement Learning Tutorials
12 |
13 | * [Value Iteration and Q-Learning](https://github.com/rlberry-py/tutorials/blob/main/Value%20Iteration%20and%20Q-Learning/Value_Iteration_and_Q_Learning.ipynb)
14 |
15 | * [Deep Q Learning](https://github.com/rlberry-py/tutorials/blob/main/Deep_Q_Learning/Tutorial_Deep_Q_Learning.ipynb)
16 |
17 | * [Advantage Actor-Critic (A2C)](https://github.com/rlberry-py/tutorials/blob/main/A2C/Tutorial_Advantage_Actor_Critic_(A2C).ipynb)
18 |
19 | See also the [`rlberry`](https://github.com/rlberry-py/rlberry) library!
20 |
--------------------------------------------------------------------------------
/Value Iteration and Q-Learning/README.md:
--------------------------------------------------------------------------------
1 | # Instructions
2 |
3 | **To run the notebook in [Google Colab](https://colab.research.google.com/)**, click on the link
4 | `Open in Colab` at the top of the `.ipynb` file.
5 |
6 |
7 | **To run the notebook locally**, download the `.ipynb` file and install the required libraries,
8 | as explained below.
9 |
10 | * Setup virtual environment (optional but recommended):
11 |
12 | ```
13 | conda create -n rltutorials python=3.8
14 | conda activate rltutorials
15 | ```
16 |
17 | * Install required libraries:
18 |
19 | ```
20 | conda install -c conda-forge jupyterlab
21 | pip install git+https://github.com/rlberry-py/rlberry.git#egg=rlberry[default]
22 | ```
23 |
24 |
--------------------------------------------------------------------------------
/Value Iteration and Q-Learning/Value_Iteration_and_Q_Learning.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Tutorial - Value Iteration and Q-Learning.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "toc_visible": true,
10 | "authorship_tag": "ABX9TyM+8H1rbTADo1Hh3m1E+mXQ",
11 | "include_colab_link": true
12 | },
13 | "kernelspec": {
14 | "name": "python3",
15 | "display_name": "Python 3"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | "
"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "Io_4iovMTlzT"
33 | },
34 | "source": [
35 | "# Tutorial - Value Iteration and Q-Learning\n",
36 | "---------------------------------\n",
37 | "\n",
38 | "In this tutorial, you will:\n",
39 | "\n",
40 | "* Implement the value iteration algorithm to approximate the value function when *a model of the environment is available*.\n",
41 | "* Implement the Q-Learning algorithm to approximate the value function when *the model is unknown*, that is, the agent must learn through interactions.\n",
42 | "\n",
43 | "We start with a short review of these algorithms.\n",
44 | "\n",
45 | "\n",
46 | "## Markov decision processes and value functions\n",
47 | "\n",
48 | "In reinforcement learning, an agent interacts with an environment by taking actions and observing rewards. Its goal is to learn a *policy*, that is, a mapping from states to actions, that maximizes the amount of reward it gathers.\n",
49 | "\n",
50 | "The environment is modeled as a __Markov decision process (MDP)__, defined by a set of states $\\mathcal{S}$, a set of actions $\\mathcal{A}$, a reward function $r(s, a)$ and transition probabilities $P(s'|s,a)$. When an agent takes action $a$ in state $s$, it receives a random reward with mean $r(s,a)$ and makes a transition to a state $s'$ distributed according to $P(s'|s,a)$.\n",
51 | "\n",
52 | "A __policy__ $\\pi$ is such that $\\pi(a|s)$ gives the probability of choosing an action $a$ in state $s$. __If the policy is deterministic__, we denote by $\\pi(s)$ the action that it chooses in state $s$. We are interested in finding a policy that maximizes the value function $V^\\pi$, defined as \n",
53 | "\n",
54 | "$$\n",
55 | "V^\\pi(s) = \\sum_{a\\in \\mathcal{A}} \\pi(a|s) Q^\\pi(s, a), \n",
56 | "\\quad \\text{where} \\quad \n",
57 | "Q^\\pi(s, a) = \\mathbf{E}\\left[ \\sum_{t=0}^\\infty \\gamma^t r(S_t, A_t) \\Big| S_0 = s, A_0 = a\\right].\n",
58 | "$$\n",
59 | "and represents the mean of the sum of discounted rewards gathered by the policy $\\pi$ in the MDP, where $\\gamma \\in [0, 1[$ is a discount factor ensuring the convergence of the sum. \n",
60 | "\n",
61 | "The __action-value function__ $Q^\\pi$ is the __fixed point of the Bellman operator $T^\\pi$__:\n",
62 | "\n",
63 | "$$ \n",
64 | "Q^\\pi(s, a) = T^\\pi Q^\\pi(s, a)\n",
65 | "$$\n",
66 | "where, for any function $f: \\mathcal{S}\\times\\mathcal{A} \\to \\mathbb{R}$\n",
67 | "$$\n",
68 | "T^\\pi f(s, a) = r(s, a) + \\gamma \\sum_{s'} P(s'|s,a) \\left(\\sum_{a'}\\pi(a'|s')f(s',a')\\right) \n",
69 | "$$\n",
70 | "\n",
71 | "\n",
72 | "The __optimal value function__, defined as $V^*(s) = \\max_\\pi V^\\pi(s)$ can be shown to satisfy $V^*(s) = \\max_a Q^*(s, a)$, where $Q^*$ is the __fixed point of the optimal Bellman operator $T^*$__: \n",
73 | "\n",
74 | "$$ \n",
75 | "Q^*(s, a) = T^* Q^*(s, a)\n",
76 | "$$\n",
77 | "where, for any function $f: \\mathcal{S}\\times\\mathcal{A} \\to \\mathbb{R}$\n",
78 | "$$\n",
79 | "T^* f(s, a) = r(s, a) + \\gamma \\sum_{s'} P(s'|s,a) \\max_{a'} f(s', a')\n",
80 | "$$\n",
81 | "and there exists an __optimal policy__ which is deterministic, given by $\\pi^*(s) \\in \\arg\\max_a Q^*(s, a)$.\n",
82 | "\n",
83 | "\n",
84 | "## Value iteration\n",
85 | "\n",
86 | "If both the reward function $r$ and the transition probabilities $P$ are known, we can compute $Q^*$ using value iteration, which proceeds as follows:\n",
87 | "\n",
88 | "1. Start with arbitrary $Q_0$, set $t=0$.\n",
89 | "2. Compute $Q_{t+1}(s, a) = T^*Q_t(s,a)$ for every $(s, a)$.\n",
90 | "3. If $\\max_{s,a} | Q_{t+1}(s, a) - Q_t(s,a)| \\leq \\varepsilon$, return $Q_{t}$. Otherwise, set $t \\gets t+1$ and go back to 2. \n",
91 | "\n",
92 | "The convergence is guaranteed by the contraction property of the Bellman operator, and $Q_{t+1}$ can be shown to be a good approximation of $Q^*$ for small epsilon. \n",
93 | "\n",
94 | "__Question__: Can you bound the error $\\max_{s,a} | Q^*(s, a) - Q_t(s,a)|$ as a function of $\\gamma$ and $\\varepsilon$?\n",
95 | "\n",
96 | "## Q-Learning\n",
97 | "\n",
98 | "In value iteration, we need to know $r$ and $P$ to implement the Bellman operator. When these quantities are not available, we can approximate $Q^*$ using *samples* from the environment with the Q-Learning algorithm.\n",
99 | "\n",
100 | "Q-Learning with __$\\varepsilon$-greedy exploration__ proceeds as follows:\n",
101 | "\n",
102 | "1. Start with arbitrary $Q_0$, get starting state $s_0$, set $t=0$.\n",
103 | "2. Choosing action $a_t$: \n",
104 | " * With probability $\\varepsilon$ choose $a_t$ randomly (uniform distribution) \n",
105 | " * With probability $1-\\varepsilon$, choose $a_t \\in \\arg\\max_a Q_t(s_t, a)$.\n",
106 | "3. Take action $a_t$, observe next state $s_{t+1}$ and reward $r_t$.\n",
107 | "4. Compute error $\\delta_t = r_t + \\gamma \\max_a Q_t(s_{t+1}, a) - Q_t(s_t, a_t)$.\n",
108 | "5. Update \n",
109 | " * $Q_{t+1}(s, a) = Q_t(s, a) + \\alpha_t(s,a) \\delta_t$, __if $s=s_t$ and $a=a_t$__\n",
110 | " * $Q_{t+1}(s, a) = Q_{t}(s, a)$ otherwise.\n",
111 | "\n",
112 | "Here, $\\alpha_t(s,a)$ is a learning rate that can depend, for instance, on the number of times the algorithm has visited the state-action pair $(s, a)$. \n"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {
118 | "id": "KYq9-63OR8RW"
119 | },
120 | "source": [
121 | "# Colab setup"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "metadata": {
127 | "id": "AxepTGrNR3DX",
128 | "colab": {
129 | "base_uri": "https://localhost:8080/"
130 | },
131 | "outputId": "42376421-d387-42a8-a943-0d1c5b5b3db0"
132 | },
133 | "source": [
134 | "if 'google.colab' in str(get_ipython()):\n",
135 | " print(\"Installing packages, please wait a few moments. Restart the runtime after the installation.\")\n",
136 | "\n",
137 | " # install rlberry library\n",
138 | " !pip install git+https://github.com/rlberry-py/rlberry.git#egg=rlberry[default] > /dev/null 2>&1\n",
139 | "\n",
140 | " # packages required to show video\n",
141 | " !pip install pyvirtualdisplay > /dev/null 2>&1\n",
142 | " !apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1\n"
143 | ],
144 | "execution_count": 1,
145 | "outputs": [
146 | {
147 | "output_type": "stream",
148 | "name": "stdout",
149 | "text": [
150 | "Installing packages, please wait a few moments. Restart the runtime after the installation.\n"
151 | ]
152 | }
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "metadata": {
158 | "id": "3_bPhqKlSiF0",
159 | "colab": {
160 | "base_uri": "https://localhost:8080/"
161 | },
162 | "outputId": "959689cb-1e62-41f3-c1ac-71741bd5bb48"
163 | },
164 | "source": [
165 | "# Create directory for saving videos\n",
166 | "!mkdir videos > /dev/null 2>&1\n",
167 | "\n",
168 | "# The following code will be used to visualize the environments.\n",
169 | "import base64\n",
170 | "from pyvirtualdisplay import Display\n",
171 | "from IPython import display as ipythondisplay\n",
172 | "from IPython.display import clear_output\n",
173 | "from pathlib import Path\n",
174 | "\n",
175 | "def show_video(filename=None, directory='./videos'):\n",
176 | " \"\"\"\n",
177 | " Either show all videos in a directory (if filename is None) or \n",
178 | " show video corresponding to filename.\n",
179 | " \"\"\"\n",
180 | " html = []\n",
181 | " if filename is not None:\n",
182 | " files = Path('./').glob(filename)\n",
183 | " else:\n",
184 | " files = Path(directory).glob(\"*.mp4\")\n",
185 | " for mp4 in files:\n",
186 | " print(mp4)\n",
187 | " video_b64 = base64.b64encode(mp4.read_bytes())\n",
188 | " html.append('''<video alt=\"{}\" autoplay\n",
189 | " loop controls style=\"height: 400px;\">\n",
190 | " <source src=\"data:video/mp4;base64,{}\" type=\"video/mp4\" />\n",
191 | " </video>'''.format(mp4, video_b64.decode('ascii')))\n",
192 | " ipythondisplay.display(ipythondisplay.HTML(data=\"<br>\".join(html)))\n",
193 | " \n",
194 | "from pyvirtualdisplay import Display\n",
195 | "display = Display(visible=0, size=(800, 800))\n",
196 | "display.start()"
197 | ],
198 | "execution_count": 2,
199 | "outputs": [
200 | {
201 | "output_type": "execute_result",
202 | "data": {
203 | "text/plain": [
204 | ""
205 | ]
206 | },
207 | "metadata": {},
208 | "execution_count": 2
209 | }
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "metadata": {
215 | "id": "ZYZCXMpisE_O"
216 | },
217 | "source": [
218 | "# other required libraries\n",
219 | "import numpy as np\n",
220 | "import matplotlib.pyplot as plt\n",
221 | "\n"
222 | ],
223 | "execution_count": 3,
224 | "outputs": []
225 | },
226 | {
227 | "cell_type": "markdown",
228 | "metadata": {
229 | "id": "zOPiAupGmkxh"
230 | },
231 | "source": [
232 | "# Warm up: interacting with a reinforcement learning environment"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "metadata": {
238 | "id": "6IZ0bVAlTjpZ",
239 | "colab": {
240 | "base_uri": "https://localhost:8080/",
241 | "height": 578
242 | },
243 | "outputId": "60cf10f4-8f13-4264-c281-1194beff4c1d"
244 | },
245 | "source": [
246 | "from rlberry.envs import GridWorld\n",
247 | "\n",
248 | "# A GridWorld is an environment where an agent moves in a 2d grid and aims to reach the state which gives a reward.\n",
249 | "env = GridWorld(nrows=3, ncols=5, walls=((0,2),(1, 2)), success_probability=0.9)\n",
250 | "\n",
251 | "# Number of states and actions\n",
252 | "print(\"number of states = \", env.observation_space.n)\n",
253 | "print(\"number of actions = \", env.action_space.n)\n",
254 | "\n",
255 | "# Transitions probabilities, env.P[s, a, s'] = P(s'|s, a)\n",
256 | "print(\"transition probabilities from state 0 by taking action 1: \", env.P[0, 1, :])\n",
257 | "\n",
258 | "# Reward function: env.R[s, a] = r(s, a)\n",
259 | "print(\"mean reward in state 0 for action 1 = \", env.R[0, 1])\n",
260 | "\n",
261 | "# Following a random policy \n",
262 | "state = env.reset() # initial state \n",
263 | "env.enable_rendering() # save states for visualization\n",
264 | "for tt in range(100): # interact for 100 time steps\n",
265 | " action = env.action_space.sample() # random action, a good RL agent must have a better strategy!\n",
266 | " next_state, reward, is_terminal, info = env.step(action)\n",
267 | " if is_terminal:\n",
268 | " break\n",
269 | " state = next_state\n",
270 | "\n",
271 | "# save video \n",
272 | "env.save_video('./videos/random_policy.mp4', framerate=10)\n",
273 | "# clear rendering data\n",
274 | "env.clear_render_buffer()\n",
275 | "env.disable_rendering()\n",
276 | "# see video\n",
277 | "show_video(filename='./videos/random_policy.mp4')"
278 | ],
279 | "execution_count": 4,
280 | "outputs": [
281 | {
282 | "output_type": "stream",
283 | "name": "stderr",
284 | "text": [
285 | "[INFO] OpenGL_accelerate module loaded \n",
286 | "[INFO] Using accelerated ArrayDatatype \n",
287 | "[INFO] Generating grammar tables from /usr/lib/python3.7/lib2to3/Grammar.txt \n",
288 | "[INFO] Generating grammar tables from /usr/lib/python3.7/lib2to3/PatternGrammar.txt \n"
289 | ]
290 | },
291 | {
292 | "output_type": "stream",
293 | "name": "stdout",
294 | "text": [
295 | "number of states = 13\n",
296 | "number of actions = 4\n",
297 | "transition probabilities from state 0 by taking action 1: [0. 0.9 0. 0. 0.1 0. 0. 0. 0. 0. 0. 0. 0. ]\n",
298 | "mean reward in state 0 for action 1 = 0.0\n",
299 | "videos/random_policy.mp4\n"
300 | ]
301 | },
302 | {
303 | "output_type": "display_data",
304 | "data": {
305 | "text/html": [
306 | ""
310 | ],
311 | "text/plain": [
312 | ""
313 | ]
314 | },
315 | "metadata": {}
316 | }
317 | ]
318 | },
319 | {
320 | "cell_type": "markdown",
321 | "metadata": {
322 | "id": "snmFW5Bzqpwj"
323 | },
324 | "source": [
325 | "# Implementing Value Iteration\n",
326 | "\n",
327 | "1. Write a function ``bellman_operator`` that takes as input a function $Q$ and returns $T^* Q$.\n",
328 | "2. Write a function ``value_iteration`` that returns a function $Q$ such that $||Q-T^* Q||_\\infty \\leq \\varepsilon$\n",
329 | "3. Evaluate the performance of the policy $\\pi(s) = \\arg\\max_a Q(s, a)$, where Q is returned by ``value_iteration``."
330 | ]
331 | },
332 | {
333 | "cell_type": "code",
334 | "metadata": {
335 | "id": "RPIOmpjkq0YX"
336 | },
337 | "source": [
338 | "def bellman_operator(Q, env, gamma=0.99):\n",
339 | " S = env.observation_space.n\n",
340 | " A = env.action_space.n \n",
341 | " TQ = np.zeros((S, A))\n",
342 | "\n",
343 | " # to complete...\n",
344 | "\n",
345 | " return TQ"
346 | ],
347 | "execution_count": 5,
348 | "outputs": []
349 | },
350 | {
351 | "cell_type": "code",
352 | "metadata": {
353 | "id": "tEKAtA1LsYFx"
354 | },
355 | "source": [
356 | "def value_iteration(env, gamma=0.99, epsilon=1e-6):\n",
357 | " S = env.observation_space.n\n",
358 | " A = env.action_space.n \n",
359 | " Q = np.zeros((S, A))\n",
360 | "\n",
361 | " # to complete...\n",
362 | "\n",
363 | " return Q"
364 | ],
365 | "execution_count": 6,
366 | "outputs": []
367 | },
368 | {
369 | "cell_type": "code",
370 | "metadata": {
371 | "id": "rZ7k-rDLssSk",
372 | "colab": {
373 | "base_uri": "https://localhost:8080/",
374 | "height": 440
375 | },
376 | "outputId": "7731f953-093d-4c3b-e84f-1b356eb892c3"
377 | },
378 | "source": [
379 | "Q_vi = value_iteration(env)\n",
380 | "\n",
381 | "# Following value iteration policy \n",
382 | "state = env.reset() \n",
383 | "env.enable_rendering() \n",
384 | "for tt in range(100): \n",
385 | " action = Q_vi[state, :].argmax()\n",
386 | " next_state, reward, is_terminal, info = env.step(action)\n",
387 | " if is_terminal:\n",
388 | " break\n",
389 | " state = next_state\n",
390 | "\n",
391 | "# save video (run last cell to visualize it!)\n",
392 | "env.save_video('./videos/value_iteration_policy.mp4', framerate=10)\n",
393 | "# clear rendering data\n",
394 | "env.clear_render_buffer()\n",
395 | "env.disable_rendering()\n",
396 | "# see video\n",
397 | "show_video(filename='./videos/value_iteration_policy.mp4')"
398 | ],
399 | "execution_count": 7,
400 | "outputs": [
401 | {
402 | "output_type": "stream",
403 | "name": "stdout",
404 | "text": [
405 | "videos/value_iteration_policy.mp4\n"
406 | ]
407 | },
408 | {
409 | "output_type": "display_data",
410 | "data": {
411 | "text/html": [
412 | ""
416 | ],
417 | "text/plain": [
418 | ""
419 | ]
420 | },
421 | "metadata": {}
422 | }
423 | ]
424 | },
425 | {
426 | "cell_type": "markdown",
427 | "metadata": {
428 | "id": "1Uw6LVyVulOX"
429 | },
430 | "source": [
431 | "# Implementing Q-Learning\n",
432 | "\n",
433 | "Implement a function ``q_learning`` that takes as input an environment, runs Q learning for $T$ time steps and returns $Q_T$. \n",
434 | "\n",
435 | "Test different learning rates:\n",
436 | " * $\\alpha_t(s, a) = \\frac{1}{\\text{number of visits to} (s, a)}$\n",
437 | " * $\\alpha_t(s, a) =$ constant in $]0, 1[$\n",
438 | " * others?\n",
439 | "\n",
440 | "Test different initializations of the Q function and try different values of $\\varepsilon$ in the $\\varepsilon$-greedy exploration!\n",
441 | "\n",
442 | "It might be very useful to plot the difference between the Q-learning approximation and the output of value iteration above, as a function of time.\n"
443 | ]
444 | },
445 | {
446 | "cell_type": "code",
447 | "metadata": {
448 | "id": "OrhUOlrfv6xp"
449 | },
450 | "source": [
451 | "def q_learning(env, gamma=0.99, T=5000, Q_vi=None):\n",
452 | " \"\"\"\n",
453 | " Q_vi is the output of value iteration.\n",
454 | " \"\"\"\n",
455 | " S = env.observation_space.n\n",
456 | " A = env.action_space.n \n",
457 | " error = np.zeros(T)\n",
458 | " Q = np.zeros((S, A)) # can we improve this initialization? \n",
459 | "\n",
460 | " state = env.reset()\n",
461 | " # to complete...\n",
462 | " for tt in range(T):\n",
463 | " # choose action a_t\n",
464 | " # ...\n",
465 | " # take action, observe next state and reward \n",
466 | " # ...\n",
467 | " # compute delta_t\n",
468 | " # ...\n",
469 | " # update Q\n",
470 | " # ...\n",
471 | "\n",
472 | " error[tt] = np.abs(Q-Q_vi).max()\n",
473 | " \n",
474 | " plt.plot(error)\n",
475 | " plt.xlabel('iteration')\n",
476 | " plt.title('Q-Learning error')\n",
477 | " plt.show()\n",
478 | " \n",
479 | " return Q "
480 | ],
481 | "execution_count": 8,
482 | "outputs": []
483 | },
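
One way to fill in the Q-Learning loop is sketched below: $\varepsilon$-greedy exploration together with the $1/N(s,a)$ learning rate from the list above. The ``eps_greedy`` argument and the reset-on-terminal handling are choices made for this sketch, not part of the original skeleton.

import numpy as np
import matplotlib.pyplot as plt

def q_learning(env, gamma=0.99, T=5000, Q_vi=None, eps_greedy=0.1):
    """Tabular Q-Learning; tracks ||Q_t - Q_vi||_inf when Q_vi is given."""
    S = env.observation_space.n
    A = env.action_space.n
    Q = np.zeros((S, A))
    n_visits = np.zeros((S, A))  # N(s, a), used for the 1/N(s, a) step size
    error = np.zeros(T)

    state = env.reset()
    for tt in range(T):
        # choose action a_t with epsilon-greedy exploration
        if np.random.uniform() < eps_greedy:
            action = env.action_space.sample()
        else:
            action = Q[state, :].argmax()
        # take action, observe next state and reward (old gym API)
        next_state, reward, is_terminal, info = env.step(action)
        # learning rate alpha_t(s, a) = 1 / N(s, a)
        n_visits[state, action] += 1
        alpha = 1.0 / n_visits[state, action]
        # temporal-difference update (no bootstrap from terminal states)
        target = reward + gamma * (1.0 - is_terminal) * Q[next_state, :].max()
        Q[state, action] += alpha * (target - Q[state, action])
        # restart the episode when a terminal state is reached
        state = env.reset() if is_terminal else next_state
        if Q_vi is not None:
            error[tt] = np.abs(Q - Q_vi).max()

    if Q_vi is not None:
        plt.plot(error)
        plt.xlabel('iteration')
        plt.title('Q-Learning error')
        plt.show()
    return Q
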
484 | {
485 | "cell_type": "code",
486 | "metadata": {
487 | "id": "fOetdWM4xhLt",
488 | "colab": {
489 | "base_uri": "https://localhost:8080/",
490 | "height": 718
491 | },
492 | "outputId": "f755ca3f-86f1-4c48-ffe7-fa88d1dc68b3"
493 | },
494 | "source": [
495 | "Q_ql = q_learning(env, Q_vi=Q_vi)\n",
496 | "\n",
497 | "# Following Q-Learning policy \n",
498 | "state = env.reset() \n",
499 | "env.enable_rendering() \n",
500 | "for tt in range(100): \n",
501 | " action = Q_ql[state, :].argmax()\n",
502 | " next_state, reward, is_terminal, info = env.step(action)\n",
503 | " if is_terminal:\n",
504 | " break\n",
505 | " state = next_state\n",
506 | "\n",
507 | "# save video (run last cell to visualize it!)\n",
508 | "env.save_video('./videos/q_learning_policy.mp4', framerate=10)\n",
509 | "# clear rendering data\n",
510 | "env.clear_render_buffer()\n",
511 | "env.disable_rendering()\n",
512 | "# see video\n",
513 | "show_video(filename='./videos/q_learning_policy.mp4')"
514 | ],
515 | "execution_count": 9,
516 | "outputs": [
517 | {
518 | "output_type": "display_data",
519 | "data": {
520 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAVq0lEQVR4nO3df5BlZX3n8fdnGX5lMcOvAYFhHFwmaw27WbQ6qFG3KEV+ZGPGctkNmionSpaYLO5G19VRawUxZcCYELPRWCzqsmoEQ0KcJBuRH7IxKkgPgjAiMAI64PBzAEEEBL77x30aL23Pz+7pO93P+1V1q895znPP/T5dt++nz3NOn05VIUnq1z8bdQGSpNEyCCSpcwaBJHXOIJCkzhkEktQ5g0CSOmcQSDMgyZIkjyTZZdS1SNvKINBOJclvJrk+yaNJ7krysSQLt/CcK5L81mzVOJWq+n5V7VVVT42yDml7GATaaST5b8BZwH8HFgIvAZYCX0qy6whLI8mCUb7+1pqqzm2tfa6MVTPHINBOIcnPA+8H3lpVX6yqn1TV7cB/BJ4PvGE79/vmJDcmeSDJxUmeN7TtI0nWJ/lhkjVJXjG07fQkFyb5TJIfAr/Zjjw+kOSrSR5O8qUk+7f+S5PUxIfo5vq27W9M8r0k9yf5H0luT3LMJsawe5IPJ/l+kruTfDzJnm3b0UnuSPKuJHcBn9pE7QcnWZ1kY5J1Sf7T5sa6Pd9rzV0GgXYWvwzsAfz1cGNVPQL8X+DYbd1hkhXAe4DXAYuArwCfG+pyNXAksC/wF8BfJtljaPsK4EJgb+Czre0NwJuAA4DdgHdspoQp+yZZDnwM+A3gIAZHP4dsZj9nAr/Qaj289X3f0PbntjE8DzhlE7WfD9wBHAycCHwwySu3MFZ1wiDQzmJ/4L6qenKKbRsYfJBvq7cAf1BVN7b9fhA4cuKooKo+U1X3V9WTVfVHwO7Avxx6/ter6m+q6umq+nFr+1RV3dzWP8/gw3lTNtX3ROBvq+qfquoJBh/qU970K0kYfLi/rao2VtXDbRwnDXV7Gjitqh4fqvOZ2hl8b18GvKuqHquqa4FzgTduYazqhEGgncV9wP6bmJ8+qG2nTYs80h7v2cI+nwd8JMmDSR4ENgKh/fad5B1t2uihtn0hgw/NCeun2OddQ8uPAntt5vU31ffg4X1X1aPA/ZvYxyLg54A1Q+P4Is8Oxnur6rFJzxuu/WBgIkQmfI9nH4VMNVZ1wiDQzuLrwOMMpnGekWQv4ATgCoCqeku7OmevqvrgFva5Hvjtqtp76LFnVX2tnQ94J4NzEPtU1d7AQwyCYsKOujXvBmDxxEqb799vE33vA34MHDE0hoVVNRxAU9U53PYDYN8kzxlqWwLcuYV9qBMGgXYKVfUQg5PF/zPJ8Ul2TbKUwZTKfWx53npBkj2GHrsCHwfeneQIgCQLk/yH1v85wJPAve257wN+fsYHNrULgdck+eUkuwGn8+wAekab2vlfwNlJDgBIckiS47b2xapqPfA14A/a9+YXgZOBz0xvGJovDALtNKrqQwxO7n4YeBi4jcG0yDFV9aMtPP3PGfzmPPH4VFVdxOBy1PPb1TA3MDi6ALiYwRTLzQymSR5jlqZHqmot8FYGJ3A3AI8A9zA4IprKu4B1wJVtHJfy7HMZW+P1DC7F/QFwEYNzCpduc/Gal+I/ptHOKsmbgDOAl1XV90ddz47Spr8eBJZV1W2jrkf98Q9HtNOqqk8leZLBpaXzKgiSvAa4jMGU0IeB64HbR1mT+uURgTQCSc5lcBlpgHHgd6vqptFWpV4ZBJLUOU8WS1Ln5uQ5gv3337+WLl066jIkaU5Zs2bNfVX1M3+lPyeDYOnSpYyPj4+6DEmaU5J8b6p2p4YkqXMGgSR1ziCQpM4ZBJLUOYNAkjpnEEhS5wwCSeqcQSBJnTMIJKlzBoEkdc4gkKTOGQSS1DmDQJI6ZxBIUucMAknqnEEgSZ0zCCSpcwaBJHXOIJCkzhkEktQ5g0CSOmcQSFLnDAJJ6pxBIEmdMwgkqXMzEgRJjk9yU5J1SVZNsX33JBe07VclWTpp+5IkjyR5x0zUI0naetMOgiS7AB8FTgCWA69PsnxSt5OBB6rqcOBs4KxJ2/8Y+Ifp1iJJ2nYzcURwFLCuqm6tqieA84EVk/qsAM5ryxcCr0oSgCSvBW4D1s5ALZKkbTQTQXAIsH5o/Y7WNmWfqnoSeAjYL8lewLuA92/pRZKckmQ8yfi99947A2VLkmD0J4tPB86uqke21LGqzqmqsaoaW7Ro0Y6vTJI6sWAG9nEncOjQ+uLWNlWfO5IsABYC9wMvBk5M8iFgb+DpJI9V1Z/NQF2SpK0wE0FwNbAsyWEMPvBPAt4wqc9qYCXwdeBE4PKqKuAVEx2SnA48YghI0uyadhBU1ZNJTgUuBnYBPllVa5OcAYxX1WrgE8Cnk6wDNjIIC0nSTiCDX8znlrGxsRofHx91GZI0pyRZU1Vjk9tHfbJYkjRiBoEkdc4gkKTOGQSS1DmDQJI6ZxBIUucMAknqnEEgSZ0zCCSpcwaBJHXOIJCkzhkEktQ5g0CSOmcQSFLnDAJJ6pxBIEmdMwgkqXMGgSR1ziCQpM4ZBJLUOYNAkjpnEEhS5wwCSeqcQSBJnTMIJKlzBoEkdc4gkKTOGQSS1DmDQJI6ZxBIUudmJAiSHJ/kpiTrkqyaYvvuSS5o269KsrS1vzrJmiTXt6+vnIl6JElbb9pBkGQX4KPACcBy4PVJlk/qdjLwQFUdDpwNnNXa7wNeU1X/GlgJfHq69UiSts1MHBEcBayrqlur6gngfGDFpD4rgPPa8oXAq5Kkqr5ZVT9o7WuBPZPsPgM1SZK20kwEwSHA+qH1O1rblH2q6kngIWC/SX3+PXBNVT0+AzVJkrbSglEXAJDkCAbTRcdups8pwCkAS5YsmaXKJGn+m4kjgjuBQ4fWF7e2KfskWQAsBO5v64uBi4A3VtV3N/UiVXVOVY1V1diiRYtmoGxJEsxMEFwNLEtyWJLdgJOA1ZP6rGZwMhjgRODyqqokewN/D6yqqq/OQC2SpG007SBoc/6nAhcDNwKfr6q1Sc5I8mut2yeA/ZKsA94OTFxieipwOPC+JNe2xwHTrUmStPVSVaOuYZuNjY3V+Pj4qMuQpDklyZqqGpvc7l8WS1LnDAJJ6pxBIEmdMwgkqXMGgSR1ziCQpM4ZBJLUOYNAkjpnEEhS5wwCSeqcQSBJnTMIJKlzBoEkdc4gkKTOGQSS1DmDQJI6ZxBIUucMAknqnEEgSZ0zCCSpcwaBJHXOIJCkzhkEktQ5g0CSOmcQSFLnDAJJ6pxBIEmdMwgkqXMGgSR1ziCQpM4ZBJLUuRkJgiTHJ7kpybokq6bYvnuSC9r2q5IsHdr27tZ+U5LjZqIeSdLWm3YQJNkF+ChwArAceH2S5ZO6nQw8UFWHA2cDZ7XnLgdOAo4Ajgc+1vYnSZolC2ZgH0cB66rqVoAk5wMrgG8P9VkBnN6WLwT+LEl
a+/lV9ThwW5J1bX9fn4G6fsb7/3Ytdz302I7YtSTNio+c9EJ2WzCzs/ozEQSHAOuH1u8AXrypPlX1ZJKHgP1a+5WTnnvIVC+S5BTgFIAlS5ZsV6HrN/6Y72/80XY9V5J2BkXN+D5nIghmRVWdA5wDMDY2tl3fiXNXjs1oTZI0H8zE8cWdwKFD64tb25R9kiwAFgL3b+VzJUk70EwEwdXAsiSHJdmNwcnf1ZP6rAZWtuUTgcurqlr7Se2qosOAZcA3ZqAmSdJWmvbUUJvzPxW4GNgF+GRVrU1yBjBeVauBTwCfbieDNzIIC1q/zzM4sfwk8J+r6qnp1iRJ2noZ/GI+t4yNjdX4+Pioy5CkOSXJmqr6mZOl/mWxJHXOIJCkzhkEktQ5g0CSOmcQSFLnDAJJ6pxBIEmdMwgkqXMGgSR1ziCQpM4ZBJLUOYNAkjpnEEhS5wwCSeqcQSBJnTMIJKlzBoEkdc4gkKTOGQSS1DmDQJI6ZxBIUucMAknqnEEgSZ0zCCSpcwaBJHXOIJCkzhkEktQ5g0CSOmcQSFLnDAJJ6ty0giDJvkkuSXJL+7rPJvqtbH1uSbKytf1ckr9P8p0ka5OcOZ1aJEnbZ7pHBKuAy6pqGXBZW3+WJPsCpwEvBo4CThsKjA9X1QuAFwIvS3LCNOuRJG2j6QbBCuC8tnwe8Nop+hwHXFJVG6vqAeAS4PiqerSqvgxQVU8A1wCLp1mPJGkbTTcIDqyqDW35LuDAKfocAqwfWr+jtT0jyd7AaxgcVUiSZtGCLXVIcinw3Ck2vXd4paoqSW1rAUkWAJ8D/rSqbt1Mv1OAUwCWLFmyrS8jSdqELQZBVR2zqW1J7k5yUFVtSHIQcM8U3e4Ejh5aXwxcMbR+DnBLVf3JFuo4p/VlbGxsmwNHkjS16U4NrQZWtuWVwBem6HMxcGySfdpJ4mNbG0l+H1gI/N4065AkbafpBsGZwKuT3AIc09ZJMpbkXICq2gh8ALi6Pc6oqo1JFjOYXloOXJPk2iS/Nc16JEnbKFVzb5ZlbGysxsfHR12GJM0pSdZU1djkdv+yWJI6ZxBIUucMAknqnEEgSZ0zCCSpcwaBJHXOIJCkzhkEktQ5g0CSOmcQSFLnDAJJ6pxBIEmdMwgkqXMGgSR1ziCQpM4ZBJLUOYNAkjpnEEhS5wwCSeqcQSBJnTMIJKlzBoEkdc4gkKTOGQSS1DmDQJI6ZxBIUucMAknqnEEgSZ0zCCSpcwaBJHXOIJCkzk0rCJLsm+SSJLe0r/tsot/K1ueWJCun2L46yQ3TqUWStH2me0SwCrisqpYBl7X1Z0myL3Aa8GLgKOC04cBI8jrgkWnWIUnaTtMNghXAeW35POC1U/Q5DrikqjZW1QPAJcDxAEn2At4O/P4065AkbafpBsGBVbWhLd8FHDhFn0OA9UPrd7Q2gA8AfwQ8uqUXSnJKkvEk4/fee+80SpYkDVuwpQ5JLgWeO8Wm9w6vVFUlqa194SRHAv+iqt6WZOmW+lfVOcA5AGNjY1v9OpKkzdtiEFTVMZvaluTuJAdV1YYkBwH3TNHtTuDoofXFwBXAS4GxJLe3Og5IckVVHY0kadZMd2poNTBxFdBK4AtT9LkYODbJPu0k8bHAxVX151V1cFUtBV4O3GwISNLsm24QnAm8OsktwDFtnSRjSc4FqKqNDM4FXN0eZ7Q2SdJOIFVzb7p9bGysxsfHR12GJM0pSdZU1djkdv+yWJI6ZxBIUucMAknqnEEgSZ0zCCSpcwaBJHXOIJCkzhkEktQ5g0CSOmcQSFLnDAJJ6pxBIEmdMwgkqXMGgSR1ziCQpM4ZBJLUOYNAkjpnEEhS5wwCSeqcQSBJnTMIJKlzBoEkdc4gkKTOGQSS1LlU1ahr2GZJ7gW+t51P3x+4bwbLmQsccx96G3Nv44Xpj/l5VbVocuOcDILpSDJeVWOjrmM2OeY+9Dbm3sYLO27MTg1JUucMAknqXI9BcM6oCxgBx9yH3sbc23hhB425u3MEkqRn6/GIQJI0xCCQpM51EwRJjk9yU5J1SVaNup7pSPLJJPckuWGobd8klyS5pX3dp7UnyZ+2cX8ryYuGnrOy9b8lycpRjGVrJTk0yZeTfDvJ2iT/tbXP23En2SPJN5Jc18b8/tZ+WJKr2tguSLJba9+9ra9r25cO7evdrf2mJMeNZkRbJ8kuSb6Z5O/a+rweL0CS25Ncn+TaJOOtbfbe21U17x/ALsB3gecDuwHXActHXdc0xvNvgRcBNwy1fQhY1ZZXAWe15V8B/gEI8BLgqta+L3Br+7pPW95n1GPbzJgPAl7Ulp8D3Awsn8/jbrXv1ZZ3Ba5qY/k8cFJr/zjwO235d4GPt+WTgAva8vL2nt8dOKz9LOwy6vFtZtxvB/4C+Lu2Pq/H22q+Hdh/Utusvbd7OSI4ClhXVbdW1RPA+cCKEde03arqH4GNk5pXAOe15fOA1w61/58auBLYO8lBwHHAJVW1saoeAC4Bjt/x1W+fqtpQVde05YeBG4FDmMfjbrU/0lZ3bY8CXglc2Nonj3nie3Eh8Kokae3nV9XjVXUbsI7Bz8ROJ8li4N8B57b1MI/HuwWz9t7uJQgOAdYPrd/R2uaTA6tqQ1u+CziwLW9q7HP2e9KmAF7I4DfkeT3uNk1yLXAPgx/s7wIPVtWTrctw/c+MrW1/CNiPuTXmPwHeCTzd1vdjfo93QgFfSrImySmtbdbe2wu2t2rtvKqqkszL64KT7AX8FfB7VfXDwS+AA/Nx3FX1FHBkkr2Bi4AXjLikHSbJrwL3VNWaJEePup5Z9vKqujPJAcAlSb4zvHFHv7d7OSK4Ezh0aH1xa5tP7m6Hh7Sv97T2TY19zn1PkuzKIAQ+W1V/3Zrn/bgBqupB4MvASxlMBUz8Ejdc/zNja9sXAvczd8b8MuDXktzOYPr2lcBHmL/jfUZV3dm+3sMg8I9iFt/bvQTB1cCydvXBbgxOLK0ecU0zbTUwcZXASuALQ+1vbFcavAR4qB1uXgwcm2SfdjXCsa1tp9Tmfj8B3FhVfzy0ad6OO8midiRAkj2BVzM4N/Jl4MTWbfKYJ74XJwKX1+As4mrgpHaVzWHAMuAbszOKrVdV766qxVW1lMHP6OVV9RvM0/FOSPLPkzxnYpnBe/IGZvO9Peqz5bP1YHCm/WYGc6zvHXU90xzL54ANwE8YzAOezGBu9DLgFuBSYN/WN8BH27ivB8aG9vNmBifS1gFvGvW4tjDmlzOYR/0WcG17/Mp8Hjfwi8A325hvAN7X2p/P4INtHfCXwO6tfY+2vq5tf/7Qvt7bvhc3ASeMemxbMfaj+elVQ/N6vG1817XH2onPp9l8b3uLCUnqXC9TQ5KkTTAIJKlzBoEkdc4gkKTOGQSS1DmDQF1L8rX2dWmSN8zwvt8z1WtJOxsvH5WAdkuDd1TVr27DcxbUT++BM9X2R6pqr5moT9qRPCJQ15JM3N3zTOAV7X7wb2s3e/vDJFe3e77/dut/dJKvJFkNfLu1/U27WdjaiRuGJTkT2LPt77PDr9X+IvQPk9zQ7kH/60P7viLJhU
m+k+SzGb6ZkrSDeNM5aWAVQ0cE7QP9oar6pSS7A19N8qXW90XAv6rBLY4B3lxVG9ttIK5O8ldVtSrJqVV15BSv9TrgSODfAPu35/xj2/ZC4AjgB8BXGdx/559mfrjST3lEIE3tWAb3c7mWwe2u92NwzxqAbwyFAMB/SXIdcCWDm34tY/NeDnyuqp6qqruB/wf80tC+76iqpxncRmPpjIxG2gyPCKSpBXhrVT3rpl3tXMKPJq0fA7y0qh5NcgWDe+Bsr8eHlp/Cn1HNAo8IpIGHGfwLzAkXA7/Tbn1Nkl9od4acbCHwQAuBFzD414ETfjLx/Em+Avx6Ow+xiMG/Ht1p746p+c/fNqSBbwFPtSme/83gPvhLgWvaCdt7+em/Chz2ReAtSW5kcKfLK4e2nQN8K8k1Nbid8oSLGPxfgesY3FH1nVV1VwsSadZ5+agkdc6pIUnqnEEgSZ0zCCSpcwaBJHXOIJCkzhkEktQ5g0CSOvf/AQ/Xfo538TV8AAAAAElFTkSuQmCC\n",
521 | "text/plain": [
522 | ""
523 | ]
524 | },
525 | "metadata": {
526 | "needs_background": "light"
527 | }
528 | },
529 | {
530 | "output_type": "stream",
531 | "name": "stdout",
532 | "text": [
533 | "videos/q_learning_policy.mp4\n"
534 | ]
535 | },
536 | {
537 | "output_type": "display_data",
538 | "data": {
539 | "text/html": [
540 | ""
544 | ],
545 | "text/plain": [
546 | ""
547 | ]
548 | },
549 | "metadata": {}
550 | }
551 | ]
552 | }
553 | ]
554 | }
--------------------------------------------------------------------------------
/colab_test/test_rlberry_setup.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "test_rlberry_setup.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "authorship_tag": "ABX9TyO6kyz5+E9FocC44CxfHJ76",
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | }
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "view-in-github",
22 | "colab_type": "text"
23 | },
24 | "source": [
25 | "
"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {
31 | "id": "qL-gF6FESKFk"
32 | },
33 | "source": [
34 | "# Colab setup"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "metadata": {
40 | "id": "sK5bE1AsL2Z8"
41 | },
42 | "source": [
43 | "# After installing, restart the kernel\n",
44 | "\n",
45 | "# install rlberry library\n",
46 | "!git clone https://github.com/rlberry-py/rlberry.git\n",
47 | "!cd rlberry && git pull && pip install -e .[full]\n",
48 | "!pip install ffmpeg-python > /dev/null 2>&1\n",
49 | "\n",
50 | "# packages required to show video\n",
51 | "!pip install pyvirtualdisplay > /dev/null 2>&1\n",
52 | "!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1\n",
53 | "\n",
54 | "# restart runtime\n",
55 | "import os\n",
56 | "os.kill(os.getpid(), 9)"
57 | ],
58 | "execution_count": null,
59 | "outputs": []
60 | },
61 | {
62 | "cell_type": "code",
63 | "metadata": {
64 | "id": "jr1cmKKoSFpq"
65 | },
66 | "source": [
67 | "# Create directory for saving videos\n",
68 | "!mkdir videos > /dev/null 2>&1\n",
69 | "\n",
70 | "# Initialize virtual display and import show_video function\n",
71 | "import rlberry.colab_utils.display_setup\n",
72 | "from rlberry.colab_utils.display_setup import show_video"
73 | ],
74 | "execution_count": 4,
75 | "outputs": []
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "metadata": {
80 | "id": "PNZY8gcrSP--"
81 | },
82 | "source": [
83 | "# 1. Importing modules and running unit tests\n",
84 | "---"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "metadata": {
90 | "id": "0JdnSic9PCDm"
91 | },
92 | "source": [
93 | "import rlberry\n",
94 | "import rlberry.agents\n",
95 | "import rlberry.stats\n",
96 | "import rlberry.envs\n",
97 | "import rlberry.exploration_tools\n",
98 | "import rlberry.rendering\n",
99 | "import rlberry.seeding \n",
100 | "import rlberry.spaces \n",
101 | "import rlberry.utils\n",
102 | "import rlberry.wrappers"
103 | ],
104 | "execution_count": 5,
105 | "outputs": []
106 | },
107 | {
108 | "cell_type": "code",
109 | "metadata": {
110 | "id": "UeNblieLHklr"
111 | },
112 | "source": [
113 | "!python -m pytest rlberry/"
114 | ],
115 | "execution_count": null,
116 | "outputs": []
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "metadata": {
121 | "id": "wdaxg13aIa9X"
122 | },
123 | "source": [
124 | "# 2. Interacting with GridWorld and saving video"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "metadata": {
130 | "id": "ZwpyeJAsRKRR"
131 | },
132 | "source": [
133 | "from rlberry.envs import GridWorld\n",
134 | "\n",
135 | "env = GridWorld(nrows=12, ncols=15, walls=((5,5),(6, 6)))\n",
136 | "\n",
137 | "# call enable_rendering if you want to record a video from the interactions\n",
138 | "env.enable_rendering()\n",
139 | "# get initial state\n",
140 | "state = env.reset()\n",
141 | "# run a random policy for 100 time steps\n",
142 | "for tt in range(100):\n",
143 | " action = env.action_space.sample() # a good RL algorithm must learn a better way to choose actions!\n",
144 | " next_state, reward, is_terminal, info = env.step(action)\n",
145 | " if is_terminal:\n",
146 | " break\n",
147 | " state = next_state\n",
148 | "env.save_video(\"videos/env_example.mp4\", framerate=10)\n",
149 | "\n",
150 | "# show video\n",
151 | "show_video()"
152 | ],
153 | "execution_count": null,
154 | "outputs": []
155 | },
156 | {
157 | "cell_type": "code",
158 | "metadata": {
159 | "id": "YAsvlO52TMBX"
160 | },
161 | "source": [
162 | ""
163 | ],
164 | "execution_count": null,
165 | "outputs": []
166 | }
167 | ]
168 | }
--------------------------------------------------------------------------------
/logo/logo_wide.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/rlberry-py/rlberry.git
2 | jupyterlab
3 | ffmpeg-python
4 | ipywidgets
5 | pyglet==1.5.27
6 | numpy>=1.17
7 | scipy>=1.6
8 | pygame
9 | matplotlib
10 | seaborn
11 | pandas
12 | gym==0.21
13 | dill
14 | docopt
15 | pyyaml
16 | numba
17 | optuna
18 | PyOpenGL==3.1.5
19 | PyOpenGL_accelerate==3.1.5
20 | pyvirtualdisplay
21 | torch>=1.6.0
22 | stable-baselines3
23 | protobuf==3.20.1
24 | tensorboard
26 |
--------------------------------------------------------------------------------