├── .gitignore ├── A2C └── Tutorial_Advantage_Actor_Critic_(A2C).ipynb ├── Deep_Q_Learning ├── README.md └── Tutorial_Deep_Q_Learning.ipynb ├── Exploration ├── README.md └── Tutorial_UCBVI.ipynb ├── LICENSE ├── README.md ├── Value Iteration and Q-Learning ├── README.md └── Value_Iteration_and_Q_Learning.ipynb ├── colab_test └── test_rlberry_setup.ipynb ├── logo └── logo_wide.svg └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | *.mp4 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # pytype static type analyzer 137 | .pytype/ 138 | 139 | # Cython debug symbols 140 | cython_debug/ -------------------------------------------------------------------------------- /A2C/Tutorial_Advantage_Actor_Critic_(A2C).ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Tutorial - Advantage Actor Critic (A2C).ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyOerJxVFIaozWjxy5taLfea", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "view-in-github", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "\"Open" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "FRvfou6G9RGn" 32 | }, 33 | "source": [ 34 | "# Tutorial - Advantage Actor Critic (A2C)\n", 35 | "\n", 36 | "A2C keeps two neural networks:\n", 37 | "* One network with paramemeters $\\theta$ to represent the policy $\\pi_\\theta$.\n", 38 | "* One network with parameters $\\omega$ to represent a value function $V_\\omega$, that approximates $V^{\\pi_\\theta}$\n", 39 | "\n", 40 | "\n", 41 | "At each iteration, A2C collects $M$ transitions $(s_i, a_i, r_i, s_i')_{i=1}^M$ by following the policy $\\pi_\\theta$. 
If a terminal state is reached, we simply go back to the initial state and continue to play $\\pi_\\theta$ until we gather the $M$ transitions.\n", 42 | "\n", 43 | "Consider the following quantities, defined based on the collected transitions:\n", 44 | "\n", 45 | "$$\n", 46 | "\\widehat{V}(s_i) = \\widehat{Q}(s_i, a_i) = \\sum_{t=i}^{\\tau_i \\wedge M} \\gamma^{t-i} r_t + \\gamma^{M-i+1} V_\\omega(s_M')\\mathbb{I}\\{\\tau_i>M\\}\n", 47 | "$$\n", 48 | "\n", 49 | "where and $\\tau_i = \\min\\{t\\geq i: s_i' \\text{ is a terminal state}\\}$, and \n", 50 | "\n", 51 | "$$\n", 52 | "\\mathbf{A}_\\omega(s_i, a_i) = \\widehat{Q}(s_i, a_i) - V_\\omega(s_i) \n", 53 | "$$\n", 54 | "\n", 55 | "\n", 56 | "A2C then takes a gradient step to minimize the policy \"loss\" (keeping $\\omega$ fixed):\n", 57 | "\n", 58 | "$$\n", 59 | "L_\\pi(\\theta) =\n", 60 | "-\\frac{1}{M} \\sum_{i=1}^M \\mathbf{A}_\\omega(s_i, a_i) \\log \\pi_\\theta(a_i|s_i)\n", 61 | "- \\frac{\\alpha}{M}\\sum_{i=1}^M \\sum_a \\pi(a|s_i) \\log \\frac{1}{\\pi(a|s_i)}\n", 62 | "$$\n", 63 | "\n", 64 | "and a gradient step to minimize the value loss (keeping $\\theta$ fixed):\n", 65 | "\n", 66 | "$$\n", 67 | "L_v(\\omega) = \\frac{1}{M} \\sum_{i=1}^M \\left( \\widehat{V}(s_i) - V_\\omega(s_i) \\right)^2\n", 68 | "$$\n", 69 | " \n", 70 | "\n", 71 | "\n", 72 | "# Reminders\n", 73 | "\n", 74 | "\n", 75 | "Objective function:\n", 76 | "\n", 77 | "$$\n", 78 | "J(\\theta) = \\mathbb{E}_{\\pi_\\theta}\n", 79 | "\\left[ \n", 80 | " \\sum_{t=0}^\\infty \\gamma^t r(S_t, A_t)\n", 81 | "\\right]\n", 82 | "$$\n", 83 | "\n", 84 | "Policy gradient:\n", 85 | "\n", 86 | "$$\n", 87 | "\\nabla_\\theta J(\\theta)= \\mathbb{E}_{\\pi_\\theta}\n", 88 | "\\left[ \n", 89 | " \\sum_{t=0}^\\infty \\gamma^t A^{\\pi_\\theta}(S_t, A_t) \n", 90 | " \\nabla_\\theta \\log \\pi_\\theta(A_t|S_t)\n", 91 | "\\right]\n", 92 | "$$\n", 93 | "where $A^{\\pi_\\theta}(s, a) = Q^{\\pi_\\theta}(s, a) - V^{\\pi_\\theta}(s) $ is the advantage function." 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": { 99 | "id": "Er4wbIih9e24" 100 | }, 101 | "source": [ 102 | "# Colab setup" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "metadata": { 108 | "colab": { 109 | "base_uri": "https://localhost:8080/" 110 | }, 111 | "id": "O12jMLD29DAU", 112 | "outputId": "37a4b59a-2b5d-44f4-da53-51fd84d77c3f" 113 | }, 114 | "source": [ 115 | "# After installing, restart the kernel\n", 116 | "\n", 117 | "# install rlberry library\n", 118 | "!git clone https://github.com/rlberry-py/rlberry.git \n", 119 | "!cd rlberry && git pull && pip install -e .[full] > /dev/null 2>&1\n", 120 | "!pip install ffmpeg-python > /dev/null 2>&1\n", 121 | "\n", 122 | "# gym\n", 123 | "!pip install 'gym[all]' > /dev/null 2>&1\n", 124 | "\n", 125 | "# packages required to show video\n", 126 | "!pip install pyvirtualdisplay > /dev/null 2>&1\n", 127 | "!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1\n", 128 | "\n", 129 | "# ask to restart runtime\n", 130 | "print(\"\")\n", 131 | "print(\" ~~~ Libraries installed, please restart the runtime! 
~~~ \")\n", 132 | "print(\"\")" 133 | ], 134 | "execution_count": 1, 135 | "outputs": [ 136 | { 137 | "output_type": "stream", 138 | "text": [ 139 | "Cloning into 'rlberry'...\n", 140 | "remote: Enumerating objects: 472, done.\u001b[K\n", 141 | "remote: Counting objects: 100% (472/472), done.\u001b[K\n", 142 | "remote: Compressing objects: 100% (292/292), done.\u001b[K\n", 143 | "remote: Total 3541 (delta 283), reused 326 (delta 177), pack-reused 3069\u001b[K\n", 144 | "Receiving objects: 100% (3541/3541), 886.51 KiB | 9.85 MiB/s, done.\n", 145 | "Resolving deltas: 100% (2277/2277), done.\n", 146 | "Already up to date.\n", 147 | "\n", 148 | " ~~~ Libraries installed, please restart the runtime! ~~~ \n", 149 | "\n" 150 | ], 151 | "name": "stdout" 152 | } 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "metadata": { 158 | "id": "gKOp4h0Oe9-X" 159 | }, 160 | "source": [ 161 | "import gym\r\n", 162 | "from gym import logger as gymlogger\r\n", 163 | "from gym.wrappers import Monitor\r\n", 164 | "gymlogger.set_level(40) # error only\r\n", 165 | "\r\n", 166 | "import torch\r\n", 167 | "import torch.nn as nn\r\n", 168 | "import torch.nn.functional as F \r\n", 169 | "from torch import optim\r\n", 170 | "\r\n", 171 | "import numpy as np\r\n", 172 | "\r\n", 173 | "\r\n", 174 | "# for videos\r\n", 175 | "import rlberry.colab_utils.display_setup\r\n", 176 | "from rlberry.colab_utils.display_setup import show_video" 177 | ], 178 | "execution_count": 7, 179 | "outputs": [] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "metadata": { 184 | "id": "MESFRbWdfA6P" 185 | }, 186 | "source": [ 187 | "class ActorNetwork(nn.Module):\r\n", 188 | " \"\"\"\r\n", 189 | " This network represents the policy\r\n", 190 | " \"\"\"\r\n", 191 | "\r\n", 192 | " def __init__(self, input_size, hidden_size, action_size):\r\n", 193 | " super(ActorNetwork, self).__init__()\r\n", 194 | " self.n_actions = action_size\r\n", 195 | " self.dim_observation = input_size\r\n", 196 | " \r\n", 197 | " self.net = nn.Sequential(\r\n", 198 | " nn.Linear(in_features=self.dim_observation, out_features=hidden_size),\r\n", 199 | " nn.ReLU(),\r\n", 200 | " nn.Linear(in_features=hidden_size, out_features=hidden_size),\r\n", 201 | " nn.ReLU(),\r\n", 202 | " nn.Linear(in_features=hidden_size, out_features=self.n_actions),\r\n", 203 | " nn.Softmax(dim=-1)\r\n", 204 | " )\r\n", 205 | " \r\n", 206 | " def policy(self, state):\r\n", 207 | " state = torch.tensor(state, dtype=torch.float)\r\n", 208 | " return self.net(state)\r\n", 209 | " \r\n", 210 | " def sample_action(self, state):\r\n", 211 | " state = torch.tensor(state, dtype=torch.float)\r\n", 212 | " action = torch.multinomial(self.policy(state), 1)\r\n", 213 | " return action.item()" 214 | ], 215 | "execution_count": 8, 216 | "outputs": [] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "metadata": { 221 | "id": "R_DHHAQNfD7Z" 222 | }, 223 | "source": [ 224 | "class ValueNetwork(nn.Module):\r\n", 225 | " \"\"\"\r\n", 226 | " This class represents the value function\r\n", 227 | " \"\"\"\r\n", 228 | "\r\n", 229 | " def __init__(self, input_size, hidden_size, output_size):\r\n", 230 | " super(ValueNetwork, self).__init__()\r\n", 231 | " self.fc1 = nn.Linear(input_size, hidden_size)\r\n", 232 | " self.fc2 = nn.Linear(hidden_size, hidden_size)\r\n", 233 | " self.fc3 = nn.Linear(hidden_size, output_size)\r\n", 234 | "\r\n", 235 | " def forward(self, x):\r\n", 236 | " out = F.relu(self.fc1(x))\r\n", 237 | " out = F.relu(self.fc2(out))\r\n", 238 | " out = self.fc3(out)\r\n", 239 | " return 
out\r\n", 240 | " \r\n", 241 | " def value(self, state):\r\n", 242 | " state = torch.tensor(state, dtype=torch.float)\r\n", 243 | " return self.forward(state)" 244 | ], 245 | "execution_count": 9, 246 | "outputs": [] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "metadata": { 251 | "id": "_Ry-b3HgfGx5" 252 | }, 253 | "source": [ 254 | "# You can select your environment here\r\n", 255 | "env_id = 'CartPole-v1' # @param [\"CartPole-v1\", \"LunarLander-v2\", \"MountainCar-v0\"]\r\n", 256 | "env = gym.make(env_id)\r\n", 257 | "eval_env = gym.make(env_id) # environment to evaluate the policy" 258 | ], 259 | "execution_count": 10, 260 | "outputs": [] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "metadata": { 265 | "id": "h65dXIY5fMZg" 266 | }, 267 | "source": [ 268 | "# Define you networks\r\n", 269 | "value_network = ValueNetwork(env.observation_space.shape[0], 16, 1)\r\n", 270 | "actor_network = ActorNetwork(env.observation_space.shape[0], 16, env.action_space.n)\r\n", 271 | "print(value_network)\r\n", 272 | "print(actor_network)\r\n", 273 | "\r\n", 274 | "# Define your optimizers\r\n", 275 | "value_network_optimizer = torch.optim.RMSprop(value_network.parameters(), lr=0.01)\r\n", 276 | "actor_network_optimizer = torch.optim.RMSprop(actor_network.parameters(), lr=0.01)\r\n", 277 | "\r\n", 278 | "# --------------------------------------------------------------\r\n", 279 | "# Parameters\r\n", 280 | "# --------------------------------------------------------------\r\n", 281 | "num_iterations = 300 # Number of iterations\r\n", 282 | "batch_size = 512 # How many samples to collect (value of M)\r\n", 283 | "gamma = 1 # Discount factor\r\n", 284 | "alpha = 0.001 # Entropy term coefficient\r\n", 285 | "reward_threshold = 495 # Stop training when the policy achieves this amound of rewards\r\n", 286 | "\r\n", 287 | "\r\n", 288 | "# --------------------------------------------------------------\r\n", 289 | "# Train\r\n", 290 | "# --------------------------------------------------------------\r\n", 291 | "for iteration in range(num_iterations):\r\n", 292 | " # Initialize batch storage\r\n", 293 | " states = np.empty((batch_size,) + env.observation_space.shape, dtype=np.float) # shape (batch_size, state_dim)\r\n", 294 | " rewards = np.empty((batch_size,), dtype=np.float) # shape (batch_size, ) \r\n", 295 | " next_states = np.empty((batch_size,) + env.observation_space.shape, dtype=np.float) # shape (batch_size, state_dim)\r\n", 296 | " dones = np.empty((batch_size,), dtype=np.bool) # shape (batch_size, ) \r\n", 297 | " proba = torch.empty((batch_size,), dtype=np.float) # shape (batch_size, ), store pi(a_t|s_t)\r\n", 298 | " next_value = 0 # \r\n", 299 | " \r\n", 300 | " # Intialize environment\r\n", 301 | " state = env.reset()\r\n", 302 | "\r\n", 303 | " # Generate batch\r\n", 304 | " for i in range(batch_size):\r\n", 305 | " action = actor_network.sample_action(state)\r\n", 306 | " next_state, reward, done, _ = env.step(action)\r\n", 307 | "\r\n", 308 | " states[i] = # ...\r\n", 309 | " rewards[i] = # ...\r\n", 310 | " next_states[i] = # ...\r\n", 311 | " dones[i] = # ...\r\n", 312 | " proba[i] = # ...\r\n", 313 | "\r\n", 314 | " state = next_state\r\n", 315 | " if done:\r\n", 316 | " state = env.reset()\r\n", 317 | "\r\n", 318 | " if not done:\r\n", 319 | " next_value = value_network.value(next_states[-1]).detach().numpy()[0]\r\n", 320 | "\r\n", 321 | " # compute returns (without bootstrapping)\r\n", 322 | " returns = np.zeros((batch_size,), dtype=np.float)\r\n", 323 | " T = batch_size\r\n", 324 
| " for j in range(T):\r\n", 325 | " returns[T-j-1] = rewards[T-j-1]\r\n", 326 | " if j > 0:\r\n", 327 | " returns[T-j-1] += gamma * returns[T-j] * (1 - dones[T-j])\r\n", 328 | " else:\r\n", 329 | " returns[T-j-1] += gamma * next_value\r\n", 330 | "\r\n", 331 | " # compute advantage\r\n", 332 | " values = value_network.value(states)\r\n", 333 | " advantages = # ...\r\n", 334 | "\r\n", 335 | " # Compute MSE (Value loss)\r\n", 336 | " value_network_optimizer.zero_grad()\r\n", 337 | " loss_value = # ...\r\n", 338 | " loss_value.backward()\r\n", 339 | " value_network_optimizer.step()\r\n", 340 | "\r\n", 341 | " # Compute entropy term\r\n", 342 | " dist = actor_network.policy(states)\r\n", 343 | " entropy_term = -(dist*dist.log()).sum(-1).mean()\r\n", 344 | "\r\n", 345 | " # Compute policy loss\r\n", 346 | " actor_network_optimizer.zero_grad()\r\n", 347 | " loss_policy = # ...\r\n", 348 | " loss_policy += -alpha * entropy_term\r\n", 349 | " loss_policy.backward()\r\n", 350 | " actor_network_optimizer.step()\r\n", 351 | "\r\n", 352 | " if( (iteration+1)%10 == 0 ):\r\n", 353 | " eval_rewards = np.zeros(5)\r\n", 354 | " for sim in range(5):\r\n", 355 | " eval_done = False\r\n", 356 | " eval_state = eval_env.reset()\r\n", 357 | " while not eval_done:\r\n", 358 | " eval_action = actor_network.sample_action(eval_state)\r\n", 359 | " eval_next_state, eval_reward, eval_done, _ = eval_env.step(eval_action)\r\n", 360 | " eval_rewards[sim] += eval_reward\r\n", 361 | " eval_state = eval_next_state\r\n", 362 | " print(\"Iteration = {}, loss_value = {:0.3f}, loss_policy = {:0.3f}, rewards = {:0.2f}\"\r\n", 363 | " .format(iteration +1, loss_value.item(), loss_policy.item(), eval_rewards.mean()))\r\n", 364 | " if (eval_rewards.mean() > reward_threshold):\r\n", 365 | " break" 366 | ], 367 | "execution_count": null, 368 | "outputs": [] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "metadata": { 373 | "id": "kPzvAqDVhc_K" 374 | }, 375 | "source": [ 376 | "env = Monitor(env, \"./gym-results\", force=True, video_callable=lambda episode: True)\r\n", 377 | "for episode in range(1):\r\n", 378 | " done = False\r\n", 379 | " state = env.reset()\r\n", 380 | " while not done:\r\n", 381 | " action = actor_network.sample_action(state)\r\n", 382 | " state, reward, done, info = env.step(action)\r\n", 383 | "env.close()\r\n", 384 | "show_video(directory=\"./gym-results\")" 385 | ], 386 | "execution_count": null, 387 | "outputs": [] 388 | }, 389 | { 390 | "cell_type": "markdown", 391 | "metadata": { 392 | "id": "vNqnseJtlU87" 393 | }, 394 | "source": [ 395 | "# Test other environments!\r\n", 396 | "\r\n", 397 | "Try some other environments available in OpenAI gym ([link](https://gym.openai.com/envs/#classic_control)). Suggestion: use `classic control` or `Box2D` environments." 398 | ] 399 | } 400 | ] 401 | } -------------------------------------------------------------------------------- /Deep_Q_Learning/README.md: -------------------------------------------------------------------------------- 1 | # Instructions 2 | 3 | **To run the notebook in [Google Colab](https://colab.research.google.com/)**, click on the link 4 | `Open in Colab` at the top of the `.ipynb` file. 5 | 6 | 7 | **To run the notebook locally**, download the `.ipynb` file and install the required libraries, 8 | as explained below. 
9 | 10 | * Setup virtual environment (optional but recommended): 11 | 12 | ``` 13 | conda create -n rltutorials python=3.8 14 | conda activate rltutorials 15 | ``` 16 | 17 | * Install required libraries: 18 | 19 | ``` 20 | conda install -c conda-forge jupyterlab 21 | pip install git+https://github.com/rlberry-py/rlberry.git#egg=rlberry[torch_agents] 22 | ``` 23 | -------------------------------------------------------------------------------- /Deep_Q_Learning/Tutorial_Deep_Q_Learning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Tutorial_Deep_Q_Learning.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyP9EbLl6g2dURBpFFjKPouU", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "view-in-github", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "\"Open" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "2j_no2BuvPUE" 32 | }, 33 | "source": [ 34 | "# Tutorial - Deep Q-Learning \n", 35 | "\n", 36 | "Deep Q-Learning uses a neural network to approximate $Q$ functions. Hence, we usually refer to this algorithm as DQN (for *deep Q network*).\n", 37 | "\n", 38 | "The parameters of the neural network are denoted by $\\theta$. \n", 39 | "* As input, the network takes a state $s$,\n", 40 | "* As output, the network returns $Q(s, a, \\theta)$, the value of each action $a$ in state $s$, according to the parameters $\\theta$.\n", 41 | "\n", 42 | "\n", 43 | "The goal of Deep Q-Learning is to learn the parameters $\\theta$ so that $Q(s, a, \\theta)$ approximates well the optimal $Q$-function $Q^*(s, a)$. \n", 44 | "\n", 45 | "In addition to the network with parameters $\\theta$, the algorithm keeps another network with the same architecture and parameters $\\theta^-$, called **target network**.\n", 46 | "\n", 47 | "The algorithm works as follows:\n", 48 | "\n", 49 | "1. At each time $t$, the agent is in state $s_t$ and has observed the transitions $(s_i, a_i, r_i, s_i')_{i=1}^{t-1}$, which are stored in a **replay buffer**.\n", 50 | "\n", 51 | "2. Choose action $a_t = \\arg\\max_a Q(s_t, a)$ with probability $1-\\varepsilon_t$, and $a_t$=random action with probability $\\varepsilon_t$. \n", 52 | "\n", 53 | "3. Take action $a_t$, observe reward $r_t$ and next state $s_t'$.\n", 54 | "\n", 55 | "4. Add transition $(s_t, a_t, r_t, s_t')$ to the **replay buffer**.\n", 56 | "\n", 57 | "4. Sample a minibatch $\\mathcal{B}$ containing $B$ transitions from the replay buffer. Using this minibatch, we define the loss:\n", 58 | "\n", 59 | "$$\n", 60 | "L(\\theta) = \\sum_{(s_i, a_i, r_i, s_i') \\in \\mathcal{B}}\n", 61 | "\\left[\n", 62 | "Q(s_i, a_i, \\theta) - y_i\n", 63 | "\\right]^2\n", 64 | "$$\n", 65 | "where the $y_i$ are the **targets** computed with the **target network** $\\theta^-$:\n", 66 | "\n", 67 | "$$\n", 68 | "y_i = r_i + \\gamma \\max_{a'} Q(s_i', a', \\theta^-).\n", 69 | "$$\n", 70 | "\n", 71 | "5. Update the parameters $\\theta$ to minimize the loss, e.g., with gradient descent (**keeping $\\theta^-$ fixed**): \n", 72 | "$$\n", 73 | "\\theta \\gets \\theta - \\eta \\nabla_\\theta L(\\theta)\n", 74 | "$$\n", 75 | "where $\\eta$ is the optimization learning rate. \n", 76 | "\n", 77 | "6. 
Every $N$ transitions ($t\\mod N$ = 0), update target parameters: $\\theta^- \\gets \\theta$.\n", 78 | "\n", 79 | "7. $t \\gets t+1$. Stop if $t = T$, otherwise go to step 2." 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": { 85 | "id": "HhKHif__t9OD" 86 | }, 87 | "source": [ 88 | "# Colab setup" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "metadata": { 94 | "colab": { 95 | "base_uri": "https://localhost:8080/" 96 | }, 97 | "id": "aylqy_sDqebM", 98 | "outputId": "e1a78b7f-f832-4119-e8c5-3e02264944d9" 99 | }, 100 | "source": [ 101 | "# After installing, restart the kernel\n", 102 | "\n", 103 | "if 'google.colab' in str(get_ipython()):\n", 104 | " print(\"Installing packages, please wait a few moments. You may need to restart the runtime after the installation.\")\n", 105 | "\n", 106 | " # install rlberry library\n", 107 | " !pip install git+https://github.com/rlberry-py/rlberry.git#egg=rlberry[default] > /dev/null 2>&1\n", 108 | "\n", 109 | " # install gym\n", 110 | " !pip install gym[all] > /dev/null 2>&1\n", 111 | "\n", 112 | " # packages required to show video\n", 113 | " !pip install pyvirtualdisplay > /dev/null 2>&1\n", 114 | " !apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1" 115 | ], 116 | "execution_count": 18, 117 | "outputs": [ 118 | { 119 | "output_type": "stream", 120 | "name": "stdout", 121 | "text": [ 122 | "Installing packages, please wait a few moments. You may need to restart the runtime after the installation.\n" 123 | ] 124 | } 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "metadata": { 130 | "id": "VWBRfwosfA9f" 131 | }, 132 | "source": [ 133 | "# Imports\n", 134 | "import torch\n", 135 | "import torch.nn as nn\n", 136 | "import torch.nn.functional as F\n", 137 | "import torch.optim as optim\n", 138 | "import numpy as np\n", 139 | "import random\n", 140 | "from copy import deepcopy\n", 141 | "from gym.wrappers import Monitor\n", 142 | "import gym" 143 | ], 144 | "execution_count": 19, 145 | "outputs": [] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "metadata": { 150 | "id": "35Zzr-xCya5y" 151 | }, 152 | "source": [ 153 | "# Create directory for saving videos\n", 154 | "!mkdir videos > /dev/null 2>&1\n", 155 | "\n", 156 | "# Initialize display and import function to show videos\n", 157 | "import rlberry.colab_utils.display_setup\n", 158 | "from rlberry.colab_utils.display_setup import show_video" 159 | ], 160 | "execution_count": 20, 161 | "outputs": [] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "metadata": { 166 | "id": "FLLwJLQlrTxo" 167 | }, 168 | "source": [ 169 | "# Random number generator\n", 170 | "import rlberry.seeding as seeding \n", 171 | "seeder = seeding.Seeder(456)\n", 172 | "rng = seeder.rng" 173 | ], 174 | "execution_count": 21, 175 | "outputs": [] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": { 180 | "id": "528oqsgefIFl" 181 | }, 182 | "source": [ 183 | "# 1. Define the parameters" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "metadata": { 189 | "id": "CtExtR4dfMbm", 190 | "colab": { 191 | "base_uri": "https://localhost:8080/" 192 | }, 193 | "outputId": "64f36e7b-b953-4442-bc88-9d9fe6b90ef7" 194 | }, 195 | "source": [ 196 | "# Environment\n", 197 | "env = gym.make(\"CartPole-v0\")\n", 198 | "\n", 199 | "# Discount factor\n", 200 | "GAMMA = 0.99\n", 201 | "\n", 202 | "# Batch size\n", 203 | "BATCH_SIZE = 256\n", 204 | "# Capacity of the replay buffer\n", 205 | "BUFFER_CAPACITY = 10000\n", 206 | "# Update target net every ... 
episodes\n", 207 | "UPDATE_TARGET_EVERY = 20\n", 208 | "\n", 209 | "# Initial value of epsilon\n", 210 | "EPSILON_START = 1.0\n", 211 | "# Parameter to decrease epsilon\n", 212 | "DECREASE_EPSILON = 200\n", 213 | "# Minimum value of epislon\n", 214 | "EPSILON_MIN = 0.05\n", 215 | "\n", 216 | "# Number of training episodes\n", 217 | "N_EPISODES = 200\n", 218 | "\n", 219 | "# Learning rate\n", 220 | "LEARNING_RATE = 0.1" 221 | ], 222 | "execution_count": 22, 223 | "outputs": [ 224 | { 225 | "output_type": "stream", 226 | "name": "stdout", 227 | "text": [ 228 | "INFO: Making new env: CartPole-v0\n" 229 | ] 230 | } 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": { 236 | "id": "6g16Je-dhM2Q" 237 | }, 238 | "source": [ 239 | "# 2. Define the replay buffer" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "metadata": { 245 | "id": "Jvh82br9hMNt" 246 | }, 247 | "source": [ 248 | "class ReplayBuffer:\n", 249 | " def __init__(self, capacity):\n", 250 | " self.capacity = capacity\n", 251 | " self.memory = []\n", 252 | " self.position = 0\n", 253 | "\n", 254 | " def push(self, state, action, reward, next_state, done):\n", 255 | " \"\"\"Saves a transition.\"\"\"\n", 256 | " if len(self.memory) < self.capacity:\n", 257 | " self.memory.append(None)\n", 258 | " self.memory[self.position] = (state, action, reward, next_state, done)\n", 259 | " self.position = (self.position + 1) % self.capacity\n", 260 | "\n", 261 | " def sample(self, batch_size):\n", 262 | " return rng.choice(self.memory, batch_size).tolist()\n", 263 | "\n", 264 | "\n", 265 | " def __len__(self):\n", 266 | " return len(self.memory)\n", 267 | "\n", 268 | "# create instance of replay buffer\n", 269 | "replay_buffer = ReplayBuffer(BUFFER_CAPACITY)" 270 | ], 271 | "execution_count": 23, 272 | "outputs": [] 273 | }, 274 | { 275 | "cell_type": "markdown", 276 | "metadata": { 277 | "id": "UCc9WZppi92W" 278 | }, 279 | "source": [ 280 | "# 3. Define the neural network architecture, objective and optimizer" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "metadata": { 286 | "id": "sdNz3Jrwi9iS" 287 | }, 288 | "source": [ 289 | "class Net(nn.Module):\n", 290 | " \"\"\"\n", 291 | " Basic neural net.\n", 292 | " \"\"\"\n", 293 | " def __init__(self, obs_size, hidden_size, n_actions):\n", 294 | " super(Net, self).__init__()\n", 295 | " self.net = nn.Sequential(\n", 296 | " nn.Linear(obs_size, hidden_size),\n", 297 | " nn.ReLU(),\n", 298 | " nn.Linear(hidden_size, n_actions)\n", 299 | " )\n", 300 | "\n", 301 | " def forward(self, x):\n", 302 | " return self.net(x)" 303 | ], 304 | "execution_count": 24, 305 | "outputs": [] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "metadata": { 310 | "id": "NI9hFJ28jLZ_" 311 | }, 312 | "source": [ 313 | "# create network and target network\n", 314 | "hidden_size = 128\n", 315 | "obs_size = env.observation_space.shape[0]\n", 316 | "n_actions = env.action_space.n\n", 317 | "\n", 318 | "q_net = Net(obs_size, hidden_size, n_actions)\n", 319 | "target_net = Net(obs_size, hidden_size, n_actions)\n", 320 | "\n", 321 | "# objective and optimizer\n", 322 | "objective = nn.MSELoss()\n", 323 | "optimizer = optim.Adam(params=q_net.parameters(), lr=LEARNING_RATE)" 324 | ], 325 | "execution_count": 25, 326 | "outputs": [] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "metadata": { 331 | "id": "xnR8nfoSjZjL" 332 | }, 333 | "source": [ 334 | "# 4. 
Implement Deep Q-Learning" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "metadata": { 340 | "id": "z6fT8cKdjmTZ" 341 | }, 342 | "source": [ 343 | "#\n", 344 | "# Some useful functions\n", 345 | "#\n", 346 | "\n", 347 | "def get_q(states):\n", 348 | " \"\"\"\n", 349 | " Compute Q function for a list of states\n", 350 | " \"\"\"\n", 351 | " with torch.no_grad():\n", 352 | " states_v = torch.FloatTensor([states])\n", 353 | " output = q_net.forward(states_v).data.numpy() # shape (1, len(states), n_actions)\n", 354 | " return output[0, :, :] # shape (len(states), n_actions)\n", 355 | "\n", 356 | "def eval_dqn(n_sim=5):\n", 357 | " \"\"\" \n", 358 | " Monte Carlo evaluation of DQN agent.\n", 359 | "\n", 360 | " Repeat n_sim times:\n", 361 | " * Run the DQN policy until the environment reaches a terminal state (= one episode)\n", 362 | " * Compute the sum of rewards in this episode\n", 363 | " * Store the sum of rewards in the episode_rewards array.\n", 364 | " \"\"\"\n", 365 | " env_copy = deepcopy(env)\n", 366 | " episode_rewards = np.zeros(n_sim)\n", 367 | "\n", 368 | " for ii in range(n_sim):\n", 369 | " state = env_copy.reset()\n", 370 | " done = False \n", 371 | " while not done:\n", 372 | " action = choose_action(state, 0.0)\n", 373 | " next_state, reward, done, _ = env_copy.step(action)\n", 374 | " episode_rewards[ii] += reward\n", 375 | " state = next_state\n", 376 | " return episode_rewards" 377 | ], 378 | "execution_count": 26, 379 | "outputs": [] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "metadata": { 384 | "id": "OMspDNntkIoe" 385 | }, 386 | "source": [ 387 | "def choose_action(state, epsilon):\n", 388 | " \"\"\"\n", 389 | " ** TO BE IMPLEMENTED **\n", 390 | " \n", 391 | " Return action according to an epsilon-greedy exploration policy\n", 392 | " \"\"\"\n", 393 | " return 0\n", 394 | " \n", 395 | "\n", 396 | "def update(state, action, reward, next_state, done):\n", 397 | " \"\"\"\n", 398 | " ** TO BE COMPLETED **\n", 399 | " \"\"\"\n", 400 | " \n", 401 | " # add data to replay buffer\n", 402 | " replay_buffer.push(state, action, reward, next_state, done)\n", 403 | " \n", 404 | " if len(replay_buffer) < BATCH_SIZE:\n", 405 | " return np.inf\n", 406 | " \n", 407 | " # get batch\n", 408 | " transitions = replay_buffer.sample(BATCH_SIZE)\n", 409 | "\n", 410 | " # Compute loss - TO BE IMPLEMENTED!\n", 411 | " values = torch.zeros(BATCH_SIZE) # to be computed using batch\n", 412 | " targets = torch.zeros(BATCH_SIZE) # to be computed using batch\n", 413 | " loss = objective(values, targets)\n", 414 | " \n", 415 | " # Optimize the model - UNCOMMENT!\n", 416 | "# optimizer.zero_grad()\n", 417 | "# loss.backward()\n", 418 | "# optimizer.step()\n", 419 | " \n", 420 | " return loss.data.numpy()" 421 | ], 422 | "execution_count": 27, 423 | "outputs": [] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "metadata": { 428 | "id": "QIhpKPhkkU4W", 429 | "colab": { 430 | "base_uri": "https://localhost:8080/" 431 | }, 432 | "outputId": "93f23393-0bc4-48bf-d315-1fbc1d94f7c2" 433 | }, 434 | "source": [ 435 | "\n", 436 | "#\n", 437 | "# Train\n", 438 | "# \n", 439 | "\n", 440 | "EVAL_EVERY = 5\n", 441 | "REWARD_THRESHOLD = 199\n", 442 | "\n", 443 | "def train():\n", 444 | " state = env.reset()\n", 445 | " epsilon = EPSILON_START\n", 446 | " ep = 0\n", 447 | " total_time = 0\n", 448 | " while ep < N_EPISODES:\n", 449 | " action = choose_action(state, epsilon)\n", 450 | "\n", 451 | " # take action and update replay buffer and networks\n", 452 | " next_state, reward, done, _ = 
env.step(action)\n", 453 | " loss = update(state, action, reward, next_state, done)\n", 454 | "\n", 455 | " # update state\n", 456 | " state = next_state\n", 457 | "\n", 458 | " # end episode if done\n", 459 | " if done:\n", 460 | " state = env.reset()\n", 461 | " ep += 1\n", 462 | " if ( (ep+1)% EVAL_EVERY == 0):\n", 463 | " rewards = eval_dqn()\n", 464 | " print(\"episode =\", ep+1, \", reward = \", np.mean(rewards))\n", 465 | " if np.mean(rewards) >= REWARD_THRESHOLD:\n", 466 | " break\n", 467 | "\n", 468 | " # update target network\n", 469 | " if ep % UPDATE_TARGET_EVERY == 0:\n", 470 | " target_net.load_state_dict(q_net.state_dict())\n", 471 | " # decrease epsilon\n", 472 | " epsilon = EPSILON_MIN + (EPSILON_START - EPSILON_MIN) * \\\n", 473 | " np.exp(-1. * ep / DECREASE_EPSILON ) \n", 474 | "\n", 475 | " total_time += 1\n", 476 | "\n", 477 | "# Run the training loop\n", 478 | "train()\n", 479 | "\n", 480 | "# Evaluate the final policy\n", 481 | "rewards = eval_dqn(20)\n", 482 | "print(\"\")\n", 483 | "print(\"mean reward after training = \", np.mean(rewards))" 484 | ], 485 | "execution_count": 28, 486 | "outputs": [ 487 | { 488 | "output_type": "stream", 489 | "name": "stdout", 490 | "text": [ 491 | "episode = 5 , reward = 9.6\n", 492 | "episode = 10 , reward = 9.4\n", 493 | "episode = 15 , reward = 9.4\n", 494 | "episode = 20 , reward = 9.2\n", 495 | "episode = 25 , reward = 9.2\n", 496 | "episode = 30 , reward = 9.8\n", 497 | "episode = 35 , reward = 9.8\n", 498 | "episode = 40 , reward = 10.0\n", 499 | "episode = 45 , reward = 9.2\n", 500 | "episode = 50 , reward = 9.8\n" 501 | ] 502 | }, 503 | { 504 | "output_type": "stream", 505 | "name": "stderr", 506 | "text": [ 507 | "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:15: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", 508 | " from ipykernel import kernelapp as app\n" 509 | ] 510 | }, 511 | { 512 | "output_type": "stream", 513 | "name": "stdout", 514 | "text": [ 515 | "episode = 55 , reward = 9.8\n", 516 | "episode = 60 , reward = 9.4\n", 517 | "episode = 65 , reward = 9.6\n", 518 | "episode = 70 , reward = 9.6\n", 519 | "episode = 75 , reward = 8.8\n", 520 | "episode = 80 , reward = 10.0\n", 521 | "episode = 85 , reward = 9.2\n", 522 | "episode = 90 , reward = 9.4\n", 523 | "episode = 95 , reward = 9.2\n", 524 | "episode = 100 , reward = 9.2\n", 525 | "episode = 105 , reward = 9.2\n", 526 | "episode = 110 , reward = 9.6\n", 527 | "episode = 115 , reward = 9.2\n", 528 | "episode = 120 , reward = 9.2\n", 529 | "episode = 125 , reward = 9.4\n", 530 | "episode = 130 , reward = 9.8\n", 531 | "episode = 135 , reward = 9.2\n", 532 | "episode = 140 , reward = 9.2\n", 533 | "episode = 145 , reward = 10.2\n", 534 | "episode = 150 , reward = 9.2\n", 535 | "episode = 155 , reward = 9.4\n", 536 | "episode = 160 , reward = 9.6\n", 537 | "episode = 165 , reward = 9.6\n", 538 | "episode = 170 , reward = 9.4\n", 539 | "episode = 175 , reward = 9.0\n", 540 | "episode = 180 , reward = 9.0\n", 541 | "episode = 185 , reward = 9.6\n", 542 | "episode = 190 , reward = 9.2\n", 543 | "episode = 195 , reward = 9.4\n", 544 | "episode = 200 , reward = 9.4\n", 545 | "\n", 546 | "mean reward after training = 9.8\n" 547 | ] 548 | } 549 | ] 550 | }, 551 | { 552 | "cell_type": "markdown", 553 | "metadata": { 554 | "id": "c8QZwuvjgrMm" 555 | }, 556 | "source": [ 557 | "# Visualize the DQN policy" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "metadata": { 563 | "colab": { 564 | "base_uri": "https://localhost:8080/", 565 | "height": 474 566 | }, 567 | "id": "FGcGwOcEfzPz", 568 | "outputId": "3aa22829-9b5c-4308-cd1a-aadb1a629fb0" 569 | }, 570 | "source": [ 571 | "def render_env(env):\n", 572 | " env = deepcopy(env)\n", 573 | " env = Monitor(env, './videos', force=True, video_callable=lambda episode: True)\n", 574 | " for episode in range(1):\n", 575 | " done = False\n", 576 | " state = env.reset()\n", 577 | " env.render()\n", 578 | " while not done:\n", 579 | " action = action = choose_action(state, 0.0)\n", 580 | " state, reward, done, info = env.step(action)\n", 581 | " env.render()\n", 582 | " env.close()\n", 583 | " show_video()\n", 584 | "\n", 585 | "render_env(env)" 586 | ], 587 | "execution_count": 29, 588 | "outputs": [ 589 | { 590 | "output_type": "stream", 591 | "name": "stdout", 592 | "text": [ 593 | "INFO: Clearing 4 monitor files from previous run (because force=True was provided)\n", 594 | "INFO: Starting new video recorder writing to /content/videos/openaigym.video.1.705.video000000.mp4\n", 595 | "INFO: Finished writing results. You can upload them to the scoreboard via gym.upload('/content/videos')\n" 596 | ] 597 | }, 598 | { 599 | "output_type": "display_data", 600 | "data": { 601 | "text/html": [ 602 | "" 606 | ], 607 | "text/plain": [ 608 | "" 609 | ] 610 | }, 611 | "metadata": {} 612 | } 613 | ] 614 | } 615 | ] 616 | } -------------------------------------------------------------------------------- /Exploration/README.md: -------------------------------------------------------------------------------- 1 | # Instructions 2 | 3 | **To run the notebook in [Google Colab](https://colab.research.google.com/)**, click on the link 4 | `Open in Colab` at the top of the `.ipynb` file. 
5 | 6 | 7 | **To run the notebook locally**, download the `.ipynb` file and install the required libraries, 8 | as explained below. 9 | 10 | * Setup virtual environment (optional but recommended): 11 | 12 | ``` 13 | conda create -n rltutorials python=3.8 14 | conda activate rltutorials 15 | ``` 16 | 17 | * Install required libraries: 18 | 19 | ``` 20 | conda install -c conda-forge jupyterlab 21 | pip install git+https://github.com/rlberry-py/rlberry.git#egg=rlberry[default] 22 | ``` 23 | 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 rlberry-py 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | 7 | 10 | 11 | # Reinforcement Learning Tutorials 12 | 13 | * [Value Iteration and Q-Learning](https://github.com/rlberry-py/tutorials/blob/main/Value%20Iteration%20and%20Q-Learning/Value_Iteration_and_Q_Learning.ipynb) 14 | 15 | * [Deep Q Learning](https://github.com/rlberry-py/tutorials/blob/main/Deep_Q_Learning/Tutorial_Deep_Q_Learning.ipynb) 16 | 17 | * [Advantage Actor-Critic (A2C)](https://github.com/rlberry-py/tutorials/blob/main/A2C/Tutorial_Advantage_Actor_Critic_(A2C).ipynb) 18 | 19 | See also the [`rlberry`](https://github.com/rlberry-py/rlberry) library! 20 | -------------------------------------------------------------------------------- /Value Iteration and Q-Learning/README.md: -------------------------------------------------------------------------------- 1 | # Instructions 2 | 3 | **To run the notebook in [Google Colab](https://colab.research.google.com/)**, click on the link 4 | `Open in Colab` at the top of the `.ipynb` file. 5 | 6 | 7 | **To run the notebook locally**, download the `.ipynb` file and install the required libraries, 8 | as explained below. 
9 | 10 | * Setup virtual environment (optional but recommended): 11 | 12 | ``` 13 | conda create -n rltutorials python=3.8 14 | conda activate rltutorials 15 | ``` 16 | 17 | * Install required libraries: 18 | 19 | ``` 20 | conda install -c conda-forge jupyterlab 21 | pip install git+https://github.com/rlberry-py/rlberry.git#egg=rlberry[default] 22 | ``` 23 | 24 | -------------------------------------------------------------------------------- /Value Iteration and Q-Learning/Value_Iteration_and_Q_Learning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Tutorial - Value Iteration and Q-Learning.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "toc_visible": true, 10 | "authorship_tag": "ABX9TyM+8H1rbTADo1Hh3m1E+mXQ", 11 | "include_colab_link": true 12 | }, 13 | "kernelspec": { 14 | "name": "python3", 15 | "display_name": "Python 3" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "Io_4iovMTlzT" 33 | }, 34 | "source": [ 35 | "# Tutorial - Value Iteration and Q-Learning\n", 36 | "---------------------------------\n", 37 | "\n", 38 | "In this tutorial, you will:\n", 39 | "\n", 40 | "* Implement the value iteration algorithm to approximate the value function when *a model of the environment is available*.\n", 41 | "* Implement the Q-Learning algorithm to approximate the value function when *the model is unknown*, that is, the agent must learn through interactions.\n", 42 | "\n", 43 | "We start with a short review of these algorithms.\n", 44 | "\n", 45 | "\n", 46 | "## Markov decision processes and value functions\n", 47 | "\n", 48 | "In reinforcement learning, an agent interacts with an enviroment by taking actions and observing rewards. Its goal is to learn a *policy*, that is, a mapping from states to actions, that maximizes the amount of reward it gathers.\n", 49 | "\n", 50 | "The enviroment is modeled as a __Markov decision process (MDP)__, defined by a set of states $\\mathcal{S}$, a set of actions $\\mathcal{A}$, a reward function $r(s, a)$ and transition probabilities $P(s'|s,a)$. When an agent takes action $a$ in state $s$, it receives a random reward with mean $r(s,a)$ and makes a transion to a state $s'$ distributed according to $P(s'|s,a)$.\n", 51 | "\n", 52 | "A __policy__ $\\pi$ is such that $\\pi(a|s)$ gives the probability of choosing an action $a$ in state $s$. __If the policy is deterministic__, we denote by $\\pi(s)$ the action that it chooses in state $s$. We are interested in finding a policy that maximizes the value function $V^\\pi$, defined as \n", 53 | "\n", 54 | "$$\n", 55 | "V^\\pi(s) = \\sum_{a\\in \\mathcal{A}} \\pi(a|s) Q^\\pi(s, a), \n", 56 | "\\quad \\text{where} \\quad \n", 57 | "Q^\\pi(s, a) = \\mathbf{E}\\left[ \\sum_{t=0}^\\infty \\gamma^t r(S_t, A_t) \\Big| S_0 = s, A_0 = a\\right].\n", 58 | "$$\n", 59 | "and represents the mean of the sum of discounted rewards gathered by the policy $\\pi$ in the MDP, where $\\gamma \\in [0, 1[$ is a discount factor ensuring the convergence of the sum. 
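To make the definition of $V^\pi$ concrete, here is a small illustration that is not part of the original notebook: it estimates $V^\pi$ at the initial state by averaging truncated discounted returns over sampled rollouts. It assumes a tabular gym-style environment (integer states, `step` returning `(next_state, reward, done, info)`), such as the GridWorld used later in this notebook, a stochastic policy stored as an array `pi[s, a]`, and a finite horizon to truncate the infinite sum.

```
import numpy as np

def mc_value_estimate(env, pi, gamma=0.99, n_rollouts=200, horizon=200):
    """Monte Carlo estimate of V^pi at the environment's initial state:
    average over rollouts of the truncated discounted sum of rewards
    obtained by sampling actions from pi(a|s) at every step."""
    returns = np.zeros(n_rollouts)
    for k in range(n_rollouts):
        state = env.reset()
        discount = 1.0
        for _ in range(horizon):
            action = np.random.choice(pi.shape[1], p=pi[state])
            state, reward, done, _ = env.step(action)
            returns[k] += discount * reward
            discount *= gamma
            if done:
                break
    return returns.mean()

# Example usage (hypothetical names): uniform random policy over S states and A actions
# pi_uniform = np.ones((S, A)) / A
# print(mc_value_estimate(env, pi_uniform))
```

This sampling-based view is what Q-Learning relies on below, whereas value iteration uses the known model (`env.P`, `env.R`) directly.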
\n", 60 | "\n", 61 | "The __action-value function__ $Q^\\pi$ is the __fixed point of the Bellman operator $T^\\pi$__:\n", 62 | "\n", 63 | "$$ \n", 64 | "Q^\\pi(s, a) = T^\\pi Q^\\pi(s, a)\n", 65 | "$$\n", 66 | "where, for any function $f: \\mathcal{S}\\times\\mathcal{A} \\to \\mathbb{R}$\n", 67 | "$$\n", 68 | "T^\\pi f(s, a) = r(s, a) + \\gamma \\sum_{s'} P(s'|s,a) \\left(\\sum_{a'}\\pi(a'|s')f(s',a')\\right) \n", 69 | "$$\n", 70 | "\n", 71 | "\n", 72 | "The __optimal value function__, defined as $V^*(s) = \\max_\\pi V^\\pi(s)$ can be shown to satisfy $V^*(s) = \\max_a Q^*(s, a)$, where $Q^*$ is the __fixed point of the optimal Bellman operator $T^*$__: \n", 73 | "\n", 74 | "$$ \n", 75 | "Q^*(s, a) = T^* Q^*(s, a)\n", 76 | "$$\n", 77 | "where, for any function $f: \\mathcal{S}\\times\\mathcal{A} \\to \\mathbb{R}$\n", 78 | "$$\n", 79 | "T^* f(s, a) = r(s, a) + \\gamma \\sum_{s'} P(s'|s,a) \\max_{a'} f(s', a')\n", 80 | "$$\n", 81 | "and there exists an __optimal policy__ which is deterministic, given by $\\pi^*(s) \\in \\arg\\max_a Q^*(s, a)$.\n", 82 | "\n", 83 | "\n", 84 | "## Value iteration\n", 85 | "\n", 86 | "If both the reward function $r$ and the transition probablities $P$ are known, we can compute $Q^*$ using value iteration, which proceeds as follows:\n", 87 | "\n", 88 | "1. Start with arbitrary $Q_0$, set $t=0$.\n", 89 | "2. Compute $Q_{t+1}(s, a) = T^*Q_t(s,a)$ for every $(s, a)$.\n", 90 | "3. If $\\max_{s,a} | Q_{t+1}(s, a) - Q_t(s,a)| \\leq \\varepsilon$, return $Q_{t}$. Otherwise, set $t \\gets t+1$ and go back to 2. \n", 91 | "\n", 92 | "The convergence is guaranteed by the contraction property of the Bellman operator, and $Q_{t+1}$ can be shown to be a good approximation of $Q^*$ for small epsilon. \n", 93 | "\n", 94 | "__Question__: Can you bound the error $\\max_{s,a} | Q^*(s, a) - Q_t(s,a)|$ as a function of $\\gamma$ and $\\varepsilon$?\n", 95 | "\n", 96 | "## Q-Learning\n", 97 | "\n", 98 | "In value iteration, we need to know $r$ and $P$ to implement the Bellman operator. When these quantities are not available, we can approximate $Q^*$ using *samples* from the environment with the Q-Learning algorithm.\n", 99 | "\n", 100 | "Q-Learning with __$\\varepsilon$-greedy exploration__ proceeds as follows:\n", 101 | "\n", 102 | "1. Start with arbitrary $Q_0$, get starting state $s_0$, set $t=0$.\n", 103 | "2. Choosing action $a_t$: \n", 104 | " * With probability $\\varepsilon$ choose $a_t$ randomly (uniform distribution) \n", 105 | " * With probability $1-\\varepsilon$, choose $a_t \\in \\arg\\max_a Q_t(s_t, a)$.\n", 106 | "3. Take action $a_t$, observe next state $s_{t+1}$ and reward $r_t$.\n", 107 | "4. Compute error $\\delta_t = r_t + \\gamma \\max_a Q_t(s_{t+1}, a) - Q_t(s_t, a_t)$.\n", 108 | "5. Update \n", 109 | " * $Q_{t+1}(s, a) = Q_t(s, a) + \\alpha_t(s,a) \\delta_t$, __if $s=s_t$ and $a=a_t$__\n", 110 | " * $Q_{t+1}(s, a) = Q_{t}(s, a)$ otherwise.\n", 111 | "\n", 112 | "Here, $\\alpha_t(s,a)$ is a learning rate that can depend, for instance, on the number of times the algorithm has visited the state-action pair $(s, a)$. 
\n" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": { 118 | "id": "KYq9-63OR8RW" 119 | }, 120 | "source": [ 121 | "# Colab setup" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "metadata": { 127 | "id": "AxepTGrNR3DX", 128 | "colab": { 129 | "base_uri": "https://localhost:8080/" 130 | }, 131 | "outputId": "42376421-d387-42a8-a943-0d1c5b5b3db0" 132 | }, 133 | "source": [ 134 | "if 'google.colab' in str(get_ipython()):\n", 135 | " print(\"Installing packages, please wait a few moments. Restart the runtime after the installation.\")\n", 136 | "\n", 137 | " # install rlberry library\n", 138 | " !pip install git+https://github.com/rlberry-py/rlberry.git#egg=rlberry[default] > /dev/null 2>&1\n", 139 | "\n", 140 | " # packages required to show video\n", 141 | " !pip install pyvirtualdisplay > /dev/null 2>&1\n", 142 | " !apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1\n" 143 | ], 144 | "execution_count": 1, 145 | "outputs": [ 146 | { 147 | "output_type": "stream", 148 | "name": "stdout", 149 | "text": [ 150 | "Installing packages, please wait a few moments. Restart the runtime after the installation.\n" 151 | ] 152 | } 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "metadata": { 158 | "id": "3_bPhqKlSiF0", 159 | "colab": { 160 | "base_uri": "https://localhost:8080/" 161 | }, 162 | "outputId": "959689cb-1e62-41f3-c1ac-71741bd5bb48" 163 | }, 164 | "source": [ 165 | "# Create directory for saving videos\n", 166 | "!mkdir videos > /dev/null 2>&1\n", 167 | "\n", 168 | "# The following code is will be used to visualize the environments.\n", 169 | "import base64\n", 170 | "from pyvirtualdisplay import Display\n", 171 | "from IPython import display as ipythondisplay\n", 172 | "from IPython.display import clear_output\n", 173 | "from pathlib import Path\n", 174 | "\n", 175 | "def show_video(filename=None, directory='./videos'):\n", 176 | " \"\"\"\n", 177 | " Either show all videos in a directory (if filename is None) or \n", 178 | " show video corresponding to filename.\n", 179 | " \"\"\"\n", 180 | " html = []\n", 181 | " if filename is not None:\n", 182 | " files = Path('./').glob(filename)\n", 183 | " else:\n", 184 | " files = Path(directory).glob(\"*.mp4\")\n", 185 | " for mp4 in files:\n", 186 | " print(mp4)\n", 187 | " video_b64 = base64.b64encode(mp4.read_bytes())\n", 188 | " html.append(''''''.format(mp4, video_b64.decode('ascii')))\n", 192 | " ipythondisplay.display(ipythondisplay.HTML(data=\"
\".join(html)))\n", 193 | " \n", 194 | "from pyvirtualdisplay import Display\n", 195 | "display = Display(visible=0, size=(800, 800))\n", 196 | "display.start()" 197 | ], 198 | "execution_count": 2, 199 | "outputs": [ 200 | { 201 | "output_type": "execute_result", 202 | "data": { 203 | "text/plain": [ 204 | "" 205 | ] 206 | }, 207 | "metadata": {}, 208 | "execution_count": 2 209 | } 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "metadata": { 215 | "id": "ZYZCXMpisE_O" 216 | }, 217 | "source": [ 218 | "# other required libraries\n", 219 | "import numpy as np\n", 220 | "import matplotlib.pyplot as plt\n", 221 | "\n" 222 | ], 223 | "execution_count": 3, 224 | "outputs": [] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": { 229 | "id": "zOPiAupGmkxh" 230 | }, 231 | "source": [ 232 | "# Warm up: interacting with a reinforcement learning environment" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "metadata": { 238 | "id": "6IZ0bVAlTjpZ", 239 | "colab": { 240 | "base_uri": "https://localhost:8080/", 241 | "height": 578 242 | }, 243 | "outputId": "60cf10f4-8f13-4264-c281-1194beff4c1d" 244 | }, 245 | "source": [ 246 | "from rlberry.envs import GridWorld\n", 247 | "\n", 248 | "# A GridWorld is an environment where an agent moves in a 2d grid and aims to reach the state which gives a reward.\n", 249 | "env = GridWorld(nrows=3, ncols=5, walls=((0,2),(1, 2)), success_probability=0.9)\n", 250 | "\n", 251 | "# Number of states and actions\n", 252 | "print(\"number of states = \", env.observation_space.n)\n", 253 | "print(\"number of actions = \", env.action_space.n)\n", 254 | "\n", 255 | "# Transitions probabilities, env.P[s, a, s'] = P(s'|s, a)\n", 256 | "print(\"transition probabilities from state 0 by taking action 1: \", env.P[0, 1, :])\n", 257 | "\n", 258 | "# Reward function: env.R[s, a] = r(s, a)\n", 259 | "print(\"mean reward in state 0 for action 1 = \", env.R[0, 1])\n", 260 | "\n", 261 | "# Following a random policy \n", 262 | "state = env.reset() # initial state \n", 263 | "env.enable_rendering() # save states for visualization\n", 264 | "for tt in range(100): # interact for 100 time steps\n", 265 | " action = env.action_space.sample() # random action, a good RL agent must have a better strategy!\n", 266 | " next_state, reward, is_terminal, info = env.step(action)\n", 267 | " if is_terminal:\n", 268 | " break\n", 269 | " state = next_state\n", 270 | "\n", 271 | "# save video \n", 272 | "env.save_video('./videos/random_policy.mp4', framerate=10)\n", 273 | "# clear rendering data\n", 274 | "env.clear_render_buffer()\n", 275 | "env.disable_rendering()\n", 276 | "# see video\n", 277 | "show_video(filename='./videos/random_policy.mp4')" 278 | ], 279 | "execution_count": 4, 280 | "outputs": [ 281 | { 282 | "output_type": "stream", 283 | "name": "stderr", 284 | "text": [ 285 | "[INFO] OpenGL_accelerate module loaded \n", 286 | "[INFO] Using accelerated ArrayDatatype \n", 287 | "[INFO] Generating grammar tables from /usr/lib/python3.7/lib2to3/Grammar.txt \n", 288 | "[INFO] Generating grammar tables from /usr/lib/python3.7/lib2to3/PatternGrammar.txt \n" 289 | ] 290 | }, 291 | { 292 | "output_type": "stream", 293 | "name": "stdout", 294 | "text": [ 295 | "number of states = 13\n", 296 | "number of actions = 4\n", 297 | "transition probabilities from state 0 by taking action 1: [0. 0.9 0. 0. 0.1 0. 0. 0. 0. 0. 0. 0. 0. 
]\n", 298 | "mean reward in state 0 for action 1 = 0.0\n", 299 | "videos/random_policy.mp4\n" 300 | ] 301 | }, 302 | { 303 | "output_type": "display_data", 304 | "data": { 305 | "text/html": [ 306 | "" 310 | ], 311 | "text/plain": [ 312 | "" 313 | ] 314 | }, 315 | "metadata": {} 316 | } 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "metadata": { 322 | "id": "snmFW5Bzqpwj" 323 | }, 324 | "source": [ 325 | "# Implementing Value Iteration\n", 326 | "\n", 327 | "1. Write a function ``bellman_operator`` that takes as input a function $Q$ and returns $T^* Q$.\n", 328 | "2. Write a function ``value_iteration`` that returns a function $Q$ such that $||Q-T^* Q||_\\infty \\leq \\varepsilon$\n", 329 | "3. Evaluate the performance of the policy $\\pi(s) = \\arg\\max_a Q(s, a)$, where Q is returned by ``value_iteration``." 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "metadata": { 335 | "id": "RPIOmpjkq0YX" 336 | }, 337 | "source": [ 338 | "def bellman_operator(Q, env, gamma=0.99):\n", 339 | " S = env.observation_space.n\n", 340 | " A = env.action_space.n \n", 341 | " TQ = np.zeros((S, A))\n", 342 | "\n", 343 | " # to complete...\n", 344 | "\n", 345 | " return TQ" 346 | ], 347 | "execution_count": 5, 348 | "outputs": [] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "metadata": { 353 | "id": "tEKAtA1LsYFx" 354 | }, 355 | "source": [ 356 | "def value_iteration(env, gamma=0.99, epsilon=1e-6):\n", 357 | " S = env.observation_space.n\n", 358 | " A = env.action_space.n \n", 359 | " Q = np.zeros((S, A))\n", 360 | "\n", 361 | " # to complete...\n", 362 | "\n", 363 | " return Q" 364 | ], 365 | "execution_count": 6, 366 | "outputs": [] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "metadata": { 371 | "id": "rZ7k-rDLssSk", 372 | "colab": { 373 | "base_uri": "https://localhost:8080/", 374 | "height": 440 375 | }, 376 | "outputId": "7731f953-093d-4c3b-e84f-1b356eb892c3" 377 | }, 378 | "source": [ 379 | "Q_vi = value_iteration(env)\n", 380 | "\n", 381 | "# Following value iteration policy \n", 382 | "state = env.reset() \n", 383 | "env.enable_rendering() \n", 384 | "for tt in range(100): \n", 385 | " action = Q_vi[state, :].argmax()\n", 386 | " next_state, reward, is_terminal, info = env.step(action)\n", 387 | " if is_terminal:\n", 388 | " break\n", 389 | " state = next_state\n", 390 | "\n", 391 | "# save video (run last cell to visualize it!)\n", 392 | "env.save_video('./videos/value_iteration_policy.mp4', framerate=10)\n", 393 | "# clear rendering data\n", 394 | "env.clear_render_buffer()\n", 395 | "env.disable_rendering()\n", 396 | "# see video\n", 397 | "show_video(filename='./videos/value_iteration_policy.mp4')" 398 | ], 399 | "execution_count": 7, 400 | "outputs": [ 401 | { 402 | "output_type": "stream", 403 | "name": "stdout", 404 | "text": [ 405 | "videos/value_iteration_policy.mp4\n" 406 | ] 407 | }, 408 | { 409 | "output_type": "display_data", 410 | "data": { 411 | "text/html": [ 412 | "" 416 | ], 417 | "text/plain": [ 418 | "" 419 | ] 420 | }, 421 | "metadata": {} 422 | } 423 | ] 424 | }, 425 | { 426 | "cell_type": "markdown", 427 | "metadata": { 428 | "id": "1Uw6LVyVulOX" 429 | }, 430 | "source": [ 431 | "# Implementing Q-Learning\n", 432 | "\n", 433 | "Implement a function ``q_learning`` that takes as input an environment, runs Q learning for $T$ time steps and returns $Q_T$. 
\n", 434 | "\n", 435 | "Test different learning rates:\n", 436 | " * $\\alpha_t(s, a) = \\frac{1}{\\text{number of visits to} (s, a)}$\n", 437 | " * $\\alpha_t(s, a) =$ constant in $]0, 1[$\n", 438 | " * others?\n", 439 | "\n", 440 | "Test different initializations of the Q function and try different values of $\\varepsilon$ in the $\\varepsilon$-greedy exploration!\n", 441 | "\n", 442 | "It might be very useful to plot the difference between the Q-learning approximation and the output of value iteration above, as a function of time.\n" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "metadata": { 448 | "id": "OrhUOlrfv6xp" 449 | }, 450 | "source": [ 451 | "def q_learning(env, gamma=0.99, T=5000, Q_vi=None):\n", 452 | " \"\"\"\n", 453 | " Q_vi is the output of value iteration.\n", 454 | " \"\"\"\n", 455 | " S = env.observation_space.n\n", 456 | " A = env.action_space.n \n", 457 | " error = np.zeros(T)\n", 458 | " Q = np.zeros((S, A)) # can we improve this initialization? \n", 459 | "\n", 460 | " state = env.reset()\n", 461 | " # to complete...\n", 462 | " for tt in range(T):\n", 463 | " # choose action a_t\n", 464 | " # ...\n", 465 | " # take action, observe next state and reward \n", 466 | " # ...\n", 467 | " # compute delta_t\n", 468 | " # ...\n", 469 | " # update Q\n", 470 | " # ...\n", 471 | "\n", 472 | " error[tt] = np.abs(Q-Q_vi).max()\n", 473 | " \n", 474 | " plt.plot(error)\n", 475 | " plt.xlabel('iteration')\n", 476 | " plt.title('Q-Learning error')\n", 477 | " plt.show()\n", 478 | " \n", 479 | " return Q " 480 | ], 481 | "execution_count": 8, 482 | "outputs": [] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "metadata": { 487 | "id": "fOetdWM4xhLt", 488 | "colab": { 489 | "base_uri": "https://localhost:8080/", 490 | "height": 718 491 | }, 492 | "outputId": "f755ca3f-86f1-4c48-ffe7-fa88d1dc68b3" 493 | }, 494 | "source": [ 495 | "Q_ql = q_learning(env, Q_vi=Q_vi)\n", 496 | "\n", 497 | "# Following Q-Learning policy \n", 498 | "state = env.reset() \n", 499 | "env.enable_rendering() \n", 500 | "for tt in range(100): \n", 501 | " action = Q_ql[state, :].argmax()\n", 502 | " next_state, reward, is_terminal, info = env.step(action)\n", 503 | " if is_terminal:\n", 504 | " break\n", 505 | " state = next_state\n", 506 | "\n", 507 | "# save video (run last cell to visualize it!)\n", 508 | "env.save_video('./videos/q_learning_policy.mp4', framerate=10)\n", 509 | "# clear rendering data\n", 510 | "env.clear_render_buffer()\n", 511 | "env.disable_rendering()\n", 512 | "# see video\n", 513 | "show_video(filename='./videos/q_learning_policy.mp4')" 514 | ], 515 | "execution_count": 9, 516 | "outputs": [ 517 | { 518 | "output_type": "display_data", 519 | "data": { 520 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAVq0lEQVR4nO3df5BlZX3n8fdnGX5lMcOvAYFhHFwmaw27WbQ6qFG3KEV+ZGPGctkNmionSpaYLO5G19VRawUxZcCYELPRWCzqsmoEQ0KcJBuRH7IxKkgPgjAiMAI64PBzAEEEBL77x30aL23Pz+7pO93P+1V1q895znPP/T5dt++nz3NOn05VIUnq1z8bdQGSpNEyCCSpcwaBJHXOIJCkzhkEktQ5g0CSOmcQSDMgyZIkjyTZZdS1SNvKINBOJclvJrk+yaNJ7krysSQLt/CcK5L81mzVOJWq+n5V7VVVT42yDml7GATaaST5b8BZwH8HFgIvAZYCX0qy6whLI8mCUb7+1pqqzm2tfa6MVTPHINBOIcnPA+8H3lpVX6yqn1TV7cB/BJ4PvGE79/vmJDcmeSDJxUmeN7TtI0nWJ/lhkjVJXjG07fQkFyb5TJIfAr/Zjjw+kOSrSR5O8qUk+7f+S5PUxIfo5vq27W9M8r0k9yf5H0luT3LMJsawe5IPJ/l+kruTfDzJnm3b0UnuSPKuJHcBn9pE7QcnWZ1kY5J1Sf7T5sa6Pd9rzV0GgXYWvwzsAfz1cGNVPQL8X+DYbd1hkhXAe4DXAYuArwCfG+pyNXAksC/wF8BfJtljaPsK4EJgb+Czre0NwJuAA4DdgHdspoQp+yZZDnwM+A3gIAZHP4dsZj9nAr/Qaj289X3f0PbntjE8DzhlE7WfD9wBHAycCHwwySu3MFZ1wiDQzmJ/4L6qenKKbRsYfJBvq7cAf1BVN7b9fhA4cuKooKo+U1X3V9WTVfVHwO7Avxx6/ter6m+q6umq+nFr+1RV3dzWP8/gw3lTNtX3ROBvq+qfquoJBh/qU970K0kYfLi/rao2VtXDbRwnDXV7Gjitqh4fqvOZ2hl8b18GvKuqHquqa4FzgTduYazqhEGgncV9wP6bmJ8+qG2nTYs80h7v2cI+nwd8JMmDSR4ENgKh/fad5B1t2uihtn0hgw/NCeun2OddQ8uPAntt5vU31ffg4X1X1aPA/ZvYxyLg54A1Q+P4Is8Oxnur6rFJzxuu/WBgIkQmfI9nH4VMNVZ1wiDQzuLrwOMMpnGekWQv4ATgCoCqeku7OmevqvrgFva5Hvjtqtp76LFnVX2tnQ94J4NzEPtU1d7AQwyCYsKOujXvBmDxxEqb799vE33vA34MHDE0hoVVNRxAU9U53PYDYN8kzxlqWwLcuYV9qBMGgXYKVfUQg5PF/zPJ8Ul2TbKUwZTKfWx53npBkj2GHrsCHwfeneQIgCQLk/yH1v85wJPAve257wN+fsYHNrULgdck+eUkuwGn8+wAekab2vlfwNlJDgBIckiS47b2xapqPfA14A/a9+YXgZOBz0xvGJovDALtNKrqQwxO7n4YeBi4jcG0yDFV9aMtPP3PGfzmPPH4VFVdxOBy1PPb1TA3MDi6ALiYwRTLzQymSR5jlqZHqmot8FYGJ3A3AI8A9zA4IprKu4B1wJVtHJfy7HMZW+P1DC7F/QFwEYNzCpduc/Gal+I/ptHOKsmbgDOAl1XV90ddz47Spr8eBJZV1W2jrkf98Q9HtNOqqk8leZLBpaXzKgiSvAa4jMGU0IeB64HbR1mT+uURgTQCSc5lcBlpgHHgd6vqptFWpV4ZBJLUOU8WS1Ln5uQ5gv3337+WLl066jIkaU5Zs2bNfVX1M3+lPyeDYOnSpYyPj4+6DEmaU5J8b6p2p4YkqXMGgSR1ziCQpM4ZBJLUOYNAkjpnEEhS5wwCSeqcQSBJnTMIJKlzBoEkdc4gkKTOGQSS1DmDQJI6ZxBIUucMAknqnEEgSZ0zCCSpcwaBJHXOIJCkzhkEktQ5g0CSOmcQSFLnDAJJ6pxBIEmdMwgkqXMzEgRJjk9yU5J1SVZNsX33JBe07VclWTpp+5IkjyR5x0zUI0naetMOgiS7AB8FTgCWA69PsnxSt5OBB6rqcOBs4KxJ2/8Y+Ifp1iJJ2nYzcURwFLCuqm6tqieA84EVk/qsAM5ryxcCr0oSgCSvBW4D1s5ALZKkbTQTQXAIsH5o/Y7WNmWfqnoSeAjYL8lewLuA92/pRZKckmQ8yfi99947A2VLkmD0J4tPB86uqke21LGqzqmqsaoaW7Ro0Y6vTJI6sWAG9nEncOjQ+uLWNlWfO5IsABYC9wMvBk5M8iFgb+DpJI9V1Z/NQF2SpK0wE0FwNbAsyWEMPvBPAt4wqc9qYCXwdeBE4PKqKuAVEx2SnA48YghI0uyadhBU1ZNJTgUuBnYBPllVa5OcAYxX1WrgE8Cnk6wDNjIIC0nSTiCDX8znlrGxsRofHx91GZI0pyRZU1Vjk9tHfbJYkjRiBoEkdc4gkKTOGQSS1DmDQJI6ZxBIUucMAknqnEEgSZ0zCCSpcwaBJHXOIJCkzhkEktQ5g0CSOmcQSFLnDAJJ6pxBIEmdMwgkqXMGgSR1ziCQpM4ZBJLUOYNAkjpnEEhS5wwCSeqcQSBJnTMIJKlzBoEkdc4gkKTOGQSS1DmDQJI6ZxBIUudmJAiSHJ/kpiTrkqyaYvvuSS5o269KsrS1vzrJmiTXt6+vnIl6JElbb9pBkGQX4KPACcBy4PVJlk/qdjLwQFUdDpwNnNXa7wNeU1X/GlgJfHq69UiSts1MHBEcBayrqlur6gngfGDFpD4rgPPa8oXAq5Kkqr5ZVT9o7WuBPZPsPgM1SZK20kwEwSHA+qH1O1rblH2q6kngIWC/SX3+PXBNVT0+AzVJkrbSglEXAJDkCAbTRcdups8pwCkAS5YsmaXKJGn+m4kjgjuBQ4fWF7e2KfskWQAsBO5v64uBi4A3VtV3N/UiVXVOVY1V1diiRYtmoGxJEsxMEFwNLEtyWJLdgJOA1ZP6rGZwMhjgRODyqqokewN/D6yqqq/OQC2SpG007SBoc/6nAhcDNwKfr6q1Sc5I8mut2yeA/ZKsA94OTFxieipwOPC+JNe2xwHTrUmStPVSVaOuYZuNjY3V+Pj4qMuQpDklyZqqGpvc7l8WS1LnDAJJ6pxBIEmdMwgkqXMGgSR1ziCQpM4ZBJLUOYNAkjpnEEhS5wwCSeqcQSBJnTMIJKlzBoEkdc4gkKTOGQSS1DmDQJI6ZxBIUucMAknqnEEgSZ0zCCSpcwaBJHXOIJCkzhkEktQ5g0CSOmcQSFLnDAJJ6pxBIEmdMwgkqXMGgSR1ziCQpM4ZBJLUuRkJgiTHJ7kpybokq6bYvnuSC9r2q5IsHdr27tZ+U5LjZqIeSdLWm3YQJNkF+ChwArAceH2S5ZO6nQw8UFWHA2cDZ7XnLgdOAo4Ajgc+1vYnSZolC2ZgH0cB66rqVoAk5wMrgG8P9VkBnN6WLwT+LEla+/lV9ThwW5J1bX9fn4
G6fsb7/3Ytdz302I7YtSTNio+c9EJ2WzCzs/ozEQSHAOuH1u8AXrypPlX1ZJKHgP1a+5WTnnvIVC+S5BTgFIAlS5ZsV6HrN/6Y72/80XY9V5J2BkXN+D5nIghmRVWdA5wDMDY2tl3fiXNXjs1oTZI0H8zE8cWdwKFD64tb25R9kiwAFgL3b+VzJUk70EwEwdXAsiSHJdmNwcnf1ZP6rAZWtuUTgcurqlr7Se2qosOAZcA3ZqAmSdJWmvbUUJvzPxW4GNgF+GRVrU1yBjBeVauBTwCfbieDNzIIC1q/zzM4sfwk8J+r6qnp1iRJ2noZ/GI+t4yNjdX4+Pioy5CkOSXJmqr6mZOl/mWxJHXOIJCkzhkEktQ5g0CSOmcQSFLnDAJJ6pxBIEmdMwgkqXMGgSR1ziCQpM4ZBJLUOYNAkjpnEEhS5wwCSeqcQSBJnTMIJKlzBoEkdc4gkKTOGQSS1DmDQJI6ZxBIUucMAknqnEEgSZ0zCCSpcwaBJHXOIJCkzhkEktQ5g0CSOmcQSFLnDAJJ6ty0giDJvkkuSXJL+7rPJvqtbH1uSbKytf1ckr9P8p0ka5OcOZ1aJEnbZ7pHBKuAy6pqGXBZW3+WJPsCpwEvBo4CThsKjA9X1QuAFwIvS3LCNOuRJG2j6QbBCuC8tnwe8Nop+hwHXFJVG6vqAeAS4PiqerSqvgxQVU8A1wCLp1mPJGkbTTcIDqyqDW35LuDAKfocAqwfWr+jtT0jyd7AaxgcVUiSZtGCLXVIcinw3Ck2vXd4paoqSW1rAUkWAJ8D/rSqbt1Mv1OAUwCWLFmyrS8jSdqELQZBVR2zqW1J7k5yUFVtSHIQcM8U3e4Ejh5aXwxcMbR+DnBLVf3JFuo4p/VlbGxsmwNHkjS16U4NrQZWtuWVwBem6HMxcGySfdpJ4mNbG0l+H1gI/N4065AkbafpBsGZwKuT3AIc09ZJMpbkXICq2gh8ALi6Pc6oqo1JFjOYXloOXJPk2iS/Nc16JEnbKFVzb5ZlbGysxsfHR12GJM0pSdZU1djkdv+yWJI6ZxBIUucMAknqnEEgSZ0zCCSpcwaBJHXOIJCkzhkEktQ5g0CSOmcQSFLnDAJJ6pxBIEmdMwgkqXMGgSR1ziCQpM4ZBJLUOYNAkjpnEEhS5wwCSeqcQSBJnTMIJKlzBoEkdc4gkKTOGQSS1DmDQJI6ZxBIUucMAknqnEEgSZ0zCCSpcwaBJHXOIJCkzk0rCJLsm+SSJLe0r/tsot/K1ueWJCun2L46yQ3TqUWStH2me0SwCrisqpYBl7X1Z0myL3Aa8GLgKOC04cBI8jrgkWnWIUnaTtMNghXAeW35POC1U/Q5DrikqjZW1QPAJcDxAEn2At4O/P4065AkbafpBsGBVbWhLd8FHDhFn0OA9UPrd7Q2gA8AfwQ8uqUXSnJKkvEk4/fee+80SpYkDVuwpQ5JLgWeO8Wm9w6vVFUlqa194SRHAv+iqt6WZOmW+lfVOcA5AGNjY1v9OpKkzdtiEFTVMZvaluTuJAdV1YYkBwH3TNHtTuDoofXFwBXAS4GxJLe3Og5IckVVHY0kadZMd2poNTBxFdBK4AtT9LkYODbJPu0k8bHAxVX151V1cFUtBV4O3GwISNLsm24QnAm8OsktwDFtnSRjSc4FqKqNDM4FXN0eZ7Q2SdJOIFVzb7p9bGysxsfHR12GJM0pSdZU1djkdv+yWJI6ZxBIUucMAknqnEEgSZ0zCCSpcwaBJHXOIJCkzhkEktQ5g0CSOmcQSFLnDAJJ6pxBIEmdMwgkqXMGgSR1ziCQpM4ZBJLUOYNAkjpnEEhS5wwCSeqcQSBJnTMIJKlzBoEkdc4gkKTOGQSS1LlU1ahr2GZJ7gW+t51P3x+4bwbLmQsccx96G3Nv44Xpj/l5VbVocuOcDILpSDJeVWOjrmM2OeY+9Dbm3sYLO27MTg1JUucMAknqXI9BcM6oCxgBx9yH3sbc23hhB425u3MEkqRn6/GIQJI0xCCQpM51EwRJjk9yU5J1SVaNup7pSPLJJPckuWGobd8klyS5pX3dp7UnyZ+2cX8ryYuGnrOy9b8lycpRjGVrJTk0yZeTfDvJ2iT/tbXP23En2SPJN5Jc18b8/tZ+WJKr2tguSLJba9+9ra9r25cO7evdrf2mJMeNZkRbJ8kuSb6Z5O/a+rweL0CS25Ncn+TaJOOtbfbe21U17x/ALsB3gecDuwHXActHXdc0xvNvgRcBNwy1fQhY1ZZXAWe15V8B/gEI8BLgqta+L3Br+7pPW95n1GPbzJgPAl7Ulp8D3Awsn8/jbrXv1ZZ3Ba5qY/k8cFJr/zjwO235d4GPt+WTgAva8vL2nt8dOKz9LOwy6vFtZtxvB/4C+Lu2Pq/H22q+Hdh/Utusvbd7OSI4ClhXVbdW1RPA+cCKEde03arqH4GNk5pXAOe15fOA1w61/58auBLYO8lBwHHAJVW1saoeAC4Bjt/x1W+fqtpQVde05YeBG4FDmMfjbrU/0lZ3bY8CXglc2Nonj3nie3Eh8Kokae3nV9XjVXUbsI7Bz8ROJ8li4N8B57b1MI/HuwWz9t7uJQgOAdYPrd/R2uaTA6tqQ1u+CziwLW9q7HP2e9KmAF7I4DfkeT3uNk1yLXAPgx/s7wIPVtWTrctw/c+MrW1/CNiPuTXmPwHeCTzd1vdjfo93QgFfSrImySmtbdbe2wu2t2rtvKqqkszL64KT7AX8FfB7VfXDwS+AA/Nx3FX1FHBkkr2Bi4AXjLikHSbJrwL3VNWaJEePup5Z9vKqujPJAcAlSb4zvHFHv7d7OSK4Ezh0aH1xa5tP7m6Hh7Sv97T2TY19zn1PkuzKIAQ+W1V/3Zrn/bgBqupB4MvASxlMBUz8Ejdc/zNja9sXAvczd8b8MuDXktzOYPr2lcBHmL/jfUZV3dm+3sMg8I9iFt/bvQTB1cCydvXBbgxOLK0ecU0zbTUwcZXASuALQ+1vbFcavAR4qB1uXgwcm2SfdjXCsa1tp9Tmfj8B3FhVfzy0ad6OO8midiRAkj2BVzM4N/Jl4MTWbfKYJ74XJwKX1+As4mrgpHaVzWHAMuAbszOKrVdV766qxVW1lMHP6OVV9RvM0/FOSPLPkzxnYpnBe/IGZvO9Peqz5bP1YHCm/WYGc6zvHXU90xzL54ANwE8YzAOezGBu9DLgFuBSYN/WN8BH27ivB8aG9vNmBifS1gFvGvW4tjDmlzOYR/0WcG17/Mp8Hjfwi8A325hvAN7X2p/P4INtHfCXwO6tfY+2vq5tf/7Qvt7bvhc3ASeMemxbMfaj+elVQ/N6vG1817XH2onPp9l8b3uLCUnqXC9TQ5KkTTAIJKlzBoEkdc4gkKTOGQSS1DmDQF1L8rX2dWmSN8zwvt8z1WtJOxsvH5WAdkuDd1TVr27DcxbUT++BM9X2R6pqr5moT9qRPCJQ15JM3N3zTOAV7X7wb2s3e/vDJFe3e77/dut/dJKvJFkNfLu1/U27WdjaiRuGJTkT2LPt77PDr9X+IvQPk9zQ7kH/60P7viLJhUm+k+SzGb6ZkrSDeNM5a
WAVQ0cE7QP9oar6pSS7A19N8qXW90XAv6rBLY4B3lxVG9ttIK5O8ldVtSrJqVV15BSv9TrgSODfAPu35/xj2/ZC4AjgB8BXGdx/559mfrjST3lEIE3tWAb3c7mWwe2u92NwzxqAbwyFAMB/SXIdcCWDm34tY/NeDnyuqp6qqruB/wf80tC+76iqpxncRmPpjIxG2gyPCKSpBXhrVT3rpl3tXMKPJq0fA7y0qh5NcgWDe+Bsr8eHlp/Cn1HNAo8IpIGHGfwLzAkXA7/Tbn1Nkl9od4acbCHwQAuBFzD414ETfjLx/Em+Avx6Ow+xiMG/Ht1p746p+c/fNqSBbwFPtSme/83gPvhLgWvaCdt7+em/Chz2ReAtSW5kcKfLK4e2nQN8K8k1Nbid8oSLGPxfgesY3FH1nVV1VwsSadZ5+agkdc6pIUnqnEEgSZ0zCCSpcwaBJHXOIJCkzhkEktQ5g0CSOvf/AQ/Xfo538TV8AAAAAElFTkSuQmCC\n", 521 | "text/plain": [ 522 | "
" 523 | ] 524 | }, 525 | "metadata": { 526 | "needs_background": "light" 527 | } 528 | }, 529 | { 530 | "output_type": "stream", 531 | "name": "stdout", 532 | "text": [ 533 | "videos/q_learning_policy.mp4\n" 534 | ] 535 | }, 536 | { 537 | "output_type": "display_data", 538 | "data": { 539 | "text/html": [ 540 | "" 544 | ], 545 | "text/plain": [ 546 | "" 547 | ] 548 | }, 549 | "metadata": {} 550 | } 551 | ] 552 | } 553 | ] 554 | } -------------------------------------------------------------------------------- /colab_test/test_rlberry_setup.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "test_rlberry_setup.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyO6kyz5+E9FocC44CxfHJ76", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "view-in-github", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "\"Open" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "qL-gF6FESKFk" 32 | }, 33 | "source": [ 34 | "# Colab setup" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "metadata": { 40 | "id": "sK5bE1AsL2Z8" 41 | }, 42 | "source": [ 43 | "# After installing, restart the kernel\n", 44 | "\n", 45 | "# install rlberry library\n", 46 | "!git clone https://github.com/rlberry-py/rlberry.git\n", 47 | "!cd rlberry && git pull && pip install -e .[full]\n", 48 | "!pip install ffmpeg-python > /dev/null 2>&1\n", 49 | "\n", 50 | "# packages required to show video\n", 51 | "!pip install pyvirtualdisplay > /dev/null 2>&1\n", 52 | "!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1\n", 53 | "\n", 54 | "# restart runtime\n", 55 | "import os\n", 56 | "os.kill(os.getpid(), 9)" 57 | ], 58 | "execution_count": null, 59 | "outputs": [] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "metadata": { 64 | "id": "jr1cmKKoSFpq" 65 | }, 66 | "source": [ 67 | "# Create directory for saving videos\n", 68 | "!mkdir videos > /dev/null 2>&1\n", 69 | "\n", 70 | "# Initialize virtual display and import show_video function\n", 71 | "import rlberry.colab_utils.display_setup\n", 72 | "from rlberry.colab_utils.display_setup import show_video" 73 | ], 74 | "execution_count": 4, 75 | "outputs": [] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": { 80 | "id": "PNZY8gcrSP--" 81 | }, 82 | "source": [ 83 | "# 1. Importing modules and running unit tests\n", 84 | "---" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "metadata": { 90 | "id": "0JdnSic9PCDm" 91 | }, 92 | "source": [ 93 | "import rlberry\n", 94 | "import rlberry.agents\n", 95 | "import rlberry.stats\n", 96 | "import rlberry.envs\n", 97 | "import rlberry.exploration_tools\n", 98 | "import rlberry.rendering\n", 99 | "import rlberry.seeding \n", 100 | "import rlberry.spaces \n", 101 | "import rlberry.utils\n", 102 | "import rlberry.wrappers" 103 | ], 104 | "execution_count": 5, 105 | "outputs": [] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "metadata": { 110 | "id": "UeNblieLHklr" 111 | }, 112 | "source": [ 113 | "!python -m pytest rlberry/" 114 | ], 115 | "execution_count": null, 116 | "outputs": [] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": { 121 | "id": "wdaxg13aIa9X" 122 | }, 123 | "source": [ 124 | "# 2. 
Interacting with GridWorld and saving video" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "metadata": { 130 | "id": "ZwpyeJAsRKRR" 131 | }, 132 | "source": [ 133 | "from rlberry.envs import GridWorld\n", 134 | "\n", 135 | "env = GridWorld(nrows=12, ncols=15, walls=((5,5),(6, 6)))\n", 136 | "\n", 137 | "# call enable_rendering if you want to record a video from the interactions\n", 138 | "env.enable_rendering()\n", 139 | "# get initial state\n", 140 | "state = env.reset()\n", 141 | "# run a random policy for 100 time steps\n", 142 | "for tt in range(100):\n", 143 | " action = env.action_space.sample() # a good RL algorithm must learn a better way to choose actions!\n", 144 | " next_state, reward, is_terminal, info = env.step(action)\n", 145 | " if is_terminal:\n", 146 | " break\n", 147 | " state = next_state\n", 148 | "env.save_video(\"videos/env_example.mp4\", framerate=10)\n", 149 | "\n", 150 | "# show video\n", 151 | "show_video()" 152 | ], 153 | "execution_count": null, 154 | "outputs": [] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "metadata": { 159 | "id": "YAsvlO52TMBX" 160 | }, 161 | "source": [ 162 | "" 163 | ], 164 | "execution_count": null, 165 | "outputs": [] 166 | } 167 | ] 168 | } -------------------------------------------------------------------------------- /logo/logo_wide.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | 7 | Fichier 16 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 17 | 19 | 21 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 37 | 39 | 41 | 43 | 45 | 47 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/rlberry-py/rlberry.git 2 | jupyterlab 3 | ffmpeg-python 4 | ipywidgets 5 | pyglet==1.5.27 6 | numpy>=1.17 7 | scipy>=1.6 8 | pygame 9 | matplotlib 10 | seaborn 11 | pandas 12 | gym==0.21 13 | dill 14 | docopt 15 | pyyaml 16 | numba 17 | optuna 18 | PyOpenGL==3.1.5 19 | PyOpenGL_accelerate==3.1.5 20 | pyvirtualdisplay 21 | torch>=1.6.0 22 | stable-baselines3 23 | protobuf==3.20.1 24 | tensorboard 25 | ipywidgets 26 | --------------------------------------------------------------------------------