├── .gitignore ├── Chapter1-初识强化学习 └── 1.6-案例:基于Gym库的智能体-环境交互.ipynb ├── Chapter2-Markov决策过程 ├── 2.2-Bellman期望方程.ipynb ├── 2.3-最优策略及其性质.ipynb └── 2.4-案例:悬崖寻路.ipynb ├── Chapter3-有模型数值迭代 └── 3.5-案例:冰面滑行.ipynb ├── Chapter4-回合更新价值迭代 └── 4.3-案例:21点游戏.ipynb ├── Chapter5-时序差分价值迭代 └── 5.4-案例:出租车调度.ipynb ├── Chapter6-函数近似方法 └── 6.5-案例:小车上山.ipynb ├── Chapter7-回合更新策略梯度方法 └── 7.5-案例:车杆平衡.ipynb └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /Chapter1-初识强化学习/1.6-案例:基于Gym库的智能体-环境交互.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import gym\n", 10 | "import warnings\n", 11 | "import numpy as np\n", 12 | "\n", 13 | "warnings.filterwarnings('ignore')" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "观测空间 = Box(2,)\n", 26 | "动作空间 = Discrete(3)\n", 27 | "观测范围 = [-1.2 -0.07] ~ [0.6 0.07]\n", 28 | "动作数 = 3\n" 29 | ] 30 | } 31 | ], 32 | "source": [ 33 | "# 导入环境并查看观测空间和动作空间\n", 34 | "\n", 35 | "env = gym.make('MountainCar-v0')\n", 36 | "print('观测空间 = {}'.format(env.observation_space))\n", 37 | "print('动作空间 = {}'.format(env.action_space))\n", 38 | "print('观测范围 = {} ~ {}'.format(env.observation_space.low, env.observation_space.high))\n", 39 | "print('动作数 = {}'.format(env.action_space.n))" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "# 根据指定确定性策略决定动作的智能体\n", 49 | "\n", 50 | "class BespokeAgent(object):\n", 51 | " def __init__(self, env):\n", 52 
| " pass\n", 53 | " \n", 54 | " def decide(self, observation):\n", 55 | " position, velocity = observation\n", 56 | " lb = min(-0.09 * (position + 0.25) ** 2 + 0.03, 0.3 * (position + 0.9) ** 4 - 0.008)\n", 57 | " ub = -0.07 * (position + 0.38) ** 2 + 0.06\n", 58 | " \n", 59 | " if lb < velocity < ub:\n", 60 | " action = 2\n", 61 | " else:\n", 62 | " action = 0\n", 63 | " return action\n", 64 | "\n", 65 | " def learn(self, *args):\n", 66 | " pass" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "# 智能体和环境交互一个回合的代码\n", 76 | "\n", 77 | "def play_montecarlo(env, agent, render=False, train=False):\n", 78 | " episode_reward = 0.0\n", 79 | " observation = env.reset()\n", 80 | " \n", 81 | " while True:\n", 82 | " if render:\n", 83 | " env.render()\n", 84 | " \n", 85 | " action = agent.decide(observation)\n", 86 | " next_observation, reward, done, _ = env.step(action)\n", 87 | " episode_reward += reward\n", 88 | " \n", 89 | " if train:\n", 90 | " agent.learn(observation, action, reward, done)\n", 91 | " if done:\n", 92 | " break\n", 93 | " \n", 94 | " observation = next_observation\n", 95 | " \n", 96 | " return episode_reward" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 5, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "agent = BespokeAgent(env)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "回合奖励 = -113.0\n" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "env.seed(0)\n", 123 | "episode_reward = play_montecarlo(env, agent, render=True)\n", 124 | "print('回合奖励 = {}'.format(episode_reward))\n", 125 | "env.close()" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 7, 131 | "metadata": {}, 132 | "outputs": [ 133 | { 134 | "name": "stdout", 135 | 
"output_type": "stream", 136 | "text": [ 137 | "平均回合奖励 = -108.26\n" 138 | ] 139 | } 140 | ], 141 | "source": [ 142 | "# 运行100回合求平均以测试性能\n", 143 | "\n", 144 | "episode_rewards = [play_montecarlo(env, agent) for _ in range(100)]\n", 145 | "print('平均回合奖励 = {}'.format(np.mean(episode_rewards)))" 146 | ] 147 | } 148 | ], 149 | "metadata": { 150 | "kernelspec": { 151 | "display_name": "Python 3", 152 | "language": "python", 153 | "name": "python3" 154 | }, 155 | "language_info": { 156 | "codemirror_mode": { 157 | "name": "ipython", 158 | "version": 3 159 | }, 160 | "file_extension": ".py", 161 | "mimetype": "text/x-python", 162 | "name": "python", 163 | "nbconvert_exporter": "python", 164 | "pygments_lexer": "ipython3", 165 | "version": "3.7.6" 166 | } 167 | }, 168 | "nbformat": 4, 169 | "nbformat_minor": 4 170 | } 171 | -------------------------------------------------------------------------------- /Chapter2-Markov决策过程/2.2-Bellman期望方程.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sympy\n", 10 | "\n", 11 | "from sympy import symbols\n", 12 | "\n", 13 | "sympy.init_printing()" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "v_hungry, v_full = symbols('v_hungry v_full')\n", 23 | "q_hungry_eat, q_hungry_none, q_full_eat, q_full_none = symbols('q_hungry_eat q_hungry_none q_full_eat q_full_none')\n", 24 | "alpha, beta, x, y, gamma = symbols('alpha, beta, x, y, gamma')" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 3, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "# 求解示例Bellman期望方程\n", 34 | "\n", 35 | "system = sympy.Matrix((\n", 36 | " (1, 0, x - 1, -x, 0, 0, 0),\n", 37 | " (0, 1, 0, 0, -y, y - 1, 0),\n", 38 | " (-gamma, 0, 1, 0, 0, 0, -2),\n", 39 | " ((alpha 
- 1) * gamma, -alpha * gamma, 0, 1, 0, 0, 4 * alpha - 3),\n", 40 | " (-beta * gamma, (beta - 1) * gamma, 0, 0, 1, 0, -4 * beta + 2),\n", 41 | " (0, -gamma, 0, 0, 0, 1, 1)\n", 42 | "))" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 4, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "data": { 52 | "text/latex": [ 53 | "$\\displaystyle \\left\\{ q_{full eat} : \\frac{- \\alpha \\gamma^{2} x y + \\alpha \\gamma^{2} x - 2 \\alpha \\gamma x + \\beta \\gamma^{2} x y - \\beta \\gamma^{2} x - \\beta \\gamma^{2} y + \\beta \\gamma^{2} + \\beta \\gamma x + 3 \\beta \\gamma y - 5 \\beta \\gamma + 4 \\beta + \\gamma^{2} y - \\gamma^{2} - \\gamma y + 3 \\gamma - 2}{\\alpha \\gamma^{2} x - \\alpha \\gamma x + \\beta \\gamma^{2} y - \\beta \\gamma y - \\gamma^{2} + 2 \\gamma - 1}, \\ q_{full none} : \\frac{- \\alpha \\gamma^{2} x y - \\alpha \\gamma x + \\beta \\gamma^{2} x y - \\beta \\gamma^{2} y + 3 \\beta \\gamma y + \\gamma^{2} y - \\gamma y + \\gamma - 1}{\\alpha \\gamma^{2} x - \\alpha \\gamma x + \\beta \\gamma^{2} y - \\beta \\gamma y - \\gamma^{2} + 2 \\gamma - 1}, \\ q_{hungry eat} : \\frac{- \\alpha \\gamma^{2} x y + \\alpha \\gamma^{2} x - 2 \\alpha \\gamma x + \\beta \\gamma^{2} x y + 2 \\beta \\gamma y - \\gamma^{2} x + \\gamma x - 2 \\gamma + 2}{\\alpha \\gamma^{2} x - \\alpha \\gamma x + \\beta \\gamma^{2} y - \\beta \\gamma y - \\gamma^{2} + 2 \\gamma - 1}, \\ q_{hungry none} : \\frac{- \\alpha \\gamma^{2} x y + \\alpha \\gamma^{2} x + \\alpha \\gamma^{2} y - \\alpha \\gamma^{2} - 2 \\alpha \\gamma x - \\alpha \\gamma y + 5 \\alpha \\gamma - 4 \\alpha + \\beta \\gamma^{2} x y - \\beta \\gamma^{2} y + 3 \\beta \\gamma y - \\gamma^{2} x + \\gamma^{2} + \\gamma x - 4 \\gamma + 3}{\\alpha \\gamma^{2} x - \\alpha \\gamma x + \\beta \\gamma^{2} y - \\beta \\gamma y - \\gamma^{2} + 2 \\gamma - 1}, \\ v_{full} : \\frac{- \\alpha \\gamma x y - \\alpha \\gamma x + \\beta \\gamma x y - 2 \\beta \\gamma y + 4 \\beta y + \\gamma y + \\gamma - y - 
1}{\\alpha \\gamma^{2} x - \\alpha \\gamma x + \\beta \\gamma^{2} y - \\beta \\gamma y - \\gamma^{2} + 2 \\gamma - 1}, \\ v_{hungry} : \\frac{- \\alpha \\gamma x y + 3 \\alpha \\gamma x - 4 \\alpha x + \\beta \\gamma x y + 2 \\beta \\gamma y - \\gamma x - 2 \\gamma + x + 2}{\\alpha \\gamma^{2} x - \\alpha \\gamma x + \\beta \\gamma^{2} y - \\beta \\gamma y - \\gamma^{2} + 2 \\gamma - 1}\\right\\}$" 54 | ], 55 | "text/plain": [ 56 | "⎧ 2 2 2 2 2 2\n", 57 | "⎪ - α⋅γ ⋅x⋅y + α⋅γ ⋅x - 2⋅α⋅γ⋅x + β⋅γ ⋅x⋅y - β⋅γ ⋅x - β⋅γ ⋅y + β⋅γ \n", 58 | "⎨q_full_eat: ─────────────────────────────────────────────────────────────────\n", 59 | "⎪ 2 2 \n", 60 | "⎩ α⋅γ ⋅x - α⋅γ⋅x + β⋅γ ⋅y -\n", 61 | "\n", 62 | " 2 2 \n", 63 | " + β⋅γ⋅x + 3⋅β⋅γ⋅y - 5⋅β⋅γ + 4⋅β + γ ⋅y - γ - γ⋅y + 3⋅γ - 2 - α\n", 64 | "────────────────────────────────────────────────────────────, q_full_none: ───\n", 65 | " 2 \n", 66 | " β⋅γ⋅y - γ + 2⋅γ - 1 \n", 67 | "\n", 68 | " 2 2 2 2 \n", 69 | "⋅γ ⋅x⋅y - α⋅γ⋅x + β⋅γ ⋅x⋅y - β⋅γ ⋅y + 3⋅β⋅γ⋅y + γ ⋅y - γ⋅y + γ - 1 \n", 70 | "──────────────────────────────────────────────────────────────────, q_hungry_e\n", 71 | " 2 2 2 \n", 72 | " α⋅γ ⋅x - α⋅γ⋅x + β⋅γ ⋅y - β⋅γ⋅y - γ + 2⋅γ - 1 \n", 73 | "\n", 74 | " 2 2 2 2 \n", 75 | " - α⋅γ ⋅x⋅y + α⋅γ ⋅x - 2⋅α⋅γ⋅x + β⋅γ ⋅x⋅y + 2⋅β⋅γ⋅y - γ ⋅x + γ⋅x - 2⋅γ + 2 \n", 76 | "at: ─────────────────────────────────────────────────────────────────────────,\n", 77 | " 2 2 2 \n", 78 | " α⋅γ ⋅x - α⋅γ⋅x + β⋅γ ⋅y - β⋅γ⋅y - γ + 2⋅γ - 1 \n", 79 | "\n", 80 | " 2 2 2 2 \n", 81 | " - α⋅γ ⋅x⋅y + α⋅γ ⋅x + α⋅γ ⋅y - α⋅γ - 2⋅α⋅γ⋅x - α⋅γ⋅y + 5⋅α⋅γ \n", 82 | " q_hungry_none: ──────────────────────────────────────────────────────────────\n", 83 | " 2 2 \n", 84 | " α⋅γ ⋅x - α⋅γ⋅x + β⋅γ ⋅\n", 85 | "\n", 86 | " 2 2 2 2 \n", 87 | "- 4⋅α + β⋅γ ⋅x⋅y - β⋅γ ⋅y + 3⋅β⋅γ⋅y - γ ⋅x + γ + γ⋅x - 4⋅γ + 3 -α⋅γ⋅\n", 88 | "───────────────────────────────────────────────────────────────, v_full: ─────\n", 89 | " 2 \n", 90 | "y - β⋅γ⋅y - γ + 2⋅γ - 1 \n", 91 | "\n", 92 | " \n", 93 | "x⋅y - 
α⋅γ⋅x + β⋅γ⋅x⋅y - 2⋅β⋅γ⋅y + 4⋅β⋅y + γ⋅y + γ - y - 1 -α⋅γ⋅x⋅y \n", 94 | "─────────────────────────────────────────────────────────, v_hungry: ─────────\n", 95 | " 2 2 2 \n", 96 | " α⋅γ ⋅x - α⋅γ⋅x + β⋅γ ⋅y - β⋅γ⋅y - γ + 2⋅γ - 1 \n", 97 | "\n", 98 | " ⎫\n", 99 | "+ 3⋅α⋅γ⋅x - 4⋅α⋅x + β⋅γ⋅x⋅y + 2⋅β⋅γ⋅y - γ⋅x - 2⋅γ + x + 2⎪\n", 100 | "─────────────────────────────────────────────────────────⎬\n", 101 | " 2 2 2 ⎪\n", 102 | " α⋅γ ⋅x - α⋅γ⋅x + β⋅γ ⋅y - β⋅γ⋅y - γ + 2⋅γ - 1 ⎭" 103 | ] 104 | }, 105 | "execution_count": 4, 106 | "metadata": {}, 107 | "output_type": "execute_result" 108 | } 109 | ], 110 | "source": [ 111 | "sympy.solve_linear_system(system, v_hungry, v_full, q_hungry_eat, q_hungry_none, q_full_eat, q_full_none)" 112 | ] 113 | } 114 | ], 115 | "metadata": { 116 | "kernelspec": { 117 | "display_name": "Python 3", 118 | "language": "python", 119 | "name": "python3" 120 | }, 121 | "language_info": { 122 | "codemirror_mode": { 123 | "name": "ipython", 124 | "version": 3 125 | }, 126 | "file_extension": ".py", 127 | "mimetype": "text/x-python", 128 | "name": "python", 129 | "nbconvert_exporter": "python", 130 | "pygments_lexer": "ipython3", 131 | "version": "3.7.6" 132 | } 133 | }, 134 | "nbformat": 4, 135 | "nbformat_minor": 4 136 | } 137 | -------------------------------------------------------------------------------- /Chapter2-Markov决策过程/2.3-最优策略及其性质.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sympy\n", 10 | "\n", 11 | "from sympy import symbols\n", 12 | "\n", 13 | "sympy.init_printing()" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "alpha, beta, gamma = symbols('alpha beta gamma')\n", 23 | "v_hungry, v_full = symbols('v_hungry v_full')\n", 24 | "q_hungry_eat, q_hungry_none, q_full_eat, 
q_full_none = symbols('q_hungry_eat q_hungry_none q_full_eat q_full_none')" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 3, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": "stdout", 34 | "output_type": "stream", 35 | "text": [ 36 | "==== v(饿) = q(饿, 不吃), v(饱) = q(饱, 吃) ==== x = 0, y = 0 ====\n" 37 | ] 38 | }, 39 | { 40 | "data": { 41 | "text/latex": [ 42 | "$\\displaystyle \\left\\{ q_{full eat} : \\frac{- \\beta \\gamma + 4 \\beta + \\gamma - 2}{\\gamma - 1}, \\ q_{full none} : - \\frac{1}{\\gamma - 1}, \\ q_{hungry eat} : \\frac{2}{\\gamma - 1}, \\ q_{hungry none} : \\frac{\\alpha \\gamma - 4 \\alpha - \\gamma + 3}{\\gamma - 1}, \\ v_{full} : - \\frac{1}{\\gamma - 1}, \\ v_{hungry} : \\frac{2}{\\gamma - 1}\\right\\}$" 43 | ], 44 | "text/plain": [ 45 | "⎧ -β⋅γ + 4⋅β + γ - 2 -1 2 \n", 46 | "⎨q_full_eat: ──────────────────, q_full_none: ─────, q_hungry_eat: ─────, q_hu\n", 47 | "⎩ γ - 1 γ - 1 γ - 1 \n", 48 | "\n", 49 | " α⋅γ - 4⋅α - γ + 3 -1 2 ⎫\n", 50 | "ngry_none: ─────────────────, v_full: ─────, v_hungry: ─────⎬\n", 51 | " γ - 1 γ - 1 γ - 1⎭" 52 | ] 53 | }, 54 | "metadata": {}, 55 | "output_type": "display_data" 56 | }, 57 | { 58 | "name": "stdout", 59 | "output_type": "stream", 60 | "text": [ 61 | "==== v(饿) = q(饿, 吃), v(饱) = q(饱, 吃) ==== x = 1, y = 0 ====\n" 62 | ] 63 | }, 64 | { 65 | "data": { 66 | "text/latex": [ 67 | "$\\displaystyle \\left\\{ q_{full eat} : \\frac{\\alpha \\gamma^{2} - 2 \\alpha \\gamma - 4 \\beta \\gamma + 4 \\beta - \\gamma^{2} + 3 \\gamma - 2}{\\alpha \\gamma^{2} - \\alpha \\gamma - \\gamma^{2} + 2 \\gamma - 1}, \\ q_{full none} : - \\frac{1}{\\gamma - 1}, \\ q_{hungry eat} : \\frac{\\alpha \\gamma^{2} - 2 \\alpha \\gamma - \\gamma^{2} - \\gamma + 2}{\\alpha \\gamma^{2} - \\alpha \\gamma - \\gamma^{2} + 2 \\gamma - 1}, \\ q_{hungry none} : \\frac{- \\alpha \\gamma + \\left(4 \\alpha - 3\\right) \\left(\\gamma - 1\\right)}{\\gamma^{2} \\left(\\alpha - 1\\right) - \\gamma \\left(\\alpha - 1\\right) + 
\\gamma - 1}, \\ v_{full} : - \\frac{1}{\\gamma - 1}, \\ v_{hungry} : \\frac{- \\alpha \\gamma + \\left(4 \\alpha - 3\\right) \\left(\\gamma - 1\\right)}{\\gamma^{2} \\left(\\alpha - 1\\right) - \\gamma \\left(\\alpha - 1\\right) + \\gamma - 1}\\right\\}$" 68 | ], 69 | "text/plain": [ 70 | "⎧ 2 2 \n", 71 | "⎪ α⋅γ - 2⋅α⋅γ - 4⋅β⋅γ + 4⋅β - γ + 3⋅γ - 2 -1 \n", 72 | "⎨q_full_eat: ─────────────────────────────────────────, q_full_none: ─────, q_\n", 73 | "⎪ 2 2 γ - 1 \n", 74 | "⎩ α⋅γ - α⋅γ - γ + 2⋅γ - 1 \n", 75 | "\n", 76 | " 2 2 \n", 77 | " α⋅γ - 2⋅α⋅γ - γ - γ + 2 -α⋅γ + (4⋅α - 3)⋅(γ -\n", 78 | "hungry_eat: ─────────────────────────, q_hungry_none: ────────────────────────\n", 79 | " 2 2 2 \n", 80 | " α⋅γ - α⋅γ - γ + 2⋅γ - 1 γ ⋅(α - 1) - γ⋅(α - 1) +\n", 81 | "\n", 82 | " ⎫\n", 83 | " 1) -1 -α⋅γ + (4⋅α - 3)⋅(γ - 1) ⎪\n", 84 | "──────, v_full: ─────, v_hungry: ──────────────────────────────⎬\n", 85 | " γ - 1 2 ⎪\n", 86 | " γ - 1 γ ⋅(α - 1) - γ⋅(α - 1) + γ - 1⎭" 87 | ] 88 | }, 89 | "metadata": {}, 90 | "output_type": "display_data" 91 | }, 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "==== v(饿) = q(饿, 不吃), v(饱) = q(饱, 不吃) ==== x = 0, y = 1 ====\n" 97 | ] 98 | }, 99 | { 100 | "data": { 101 | "text/latex": [ 102 | "$\\displaystyle \\left\\{ q_{full eat} : \\frac{2 \\left(\\beta \\gamma - \\left(2 \\beta - 1\\right) \\left(\\gamma - 1\\right)\\right)}{\\gamma^{2} \\left(\\beta - 1\\right) - \\gamma \\left(\\beta - 1\\right) + \\gamma - 1}, \\ q_{full none} : \\frac{- \\beta \\gamma^{2} + 3 \\beta \\gamma + \\gamma^{2} - 1}{\\beta \\gamma^{2} - \\beta \\gamma - \\gamma^{2} + 2 \\gamma - 1}, \\ q_{hungry eat} : \\frac{2}{\\gamma - 1}, \\ q_{hungry none} : \\frac{4 \\alpha \\gamma - 4 \\alpha - \\beta \\gamma^{2} + 3 \\beta \\gamma + \\gamma^{2} - 4 \\gamma + 3}{\\beta \\gamma^{2} - \\beta \\gamma - \\gamma^{2} + 2 \\gamma - 1}, \\ v_{full} : \\frac{2 \\left(\\beta \\gamma - \\left(2 \\beta - 1\\right) \\left(\\gamma - 
1\\right)\\right)}{\\gamma^{2} \\left(\\beta - 1\\right) - \\gamma \\left(\\beta - 1\\right) + \\gamma - 1}, \\ v_{hungry} : \\frac{2}{\\gamma - 1}\\right\\}$" 103 | ], 104 | "text/plain": [ 105 | "⎧ 2 2\n", 106 | "⎪ 2⋅(β⋅γ - (2⋅β - 1)⋅(γ - 1)) - β⋅γ + 3⋅β⋅γ + γ \n", 107 | "⎨q_full_eat: ──────────────────────────────, q_full_none: ────────────────────\n", 108 | "⎪ 2 2 2 \n", 109 | "⎩ γ ⋅(β - 1) - γ⋅(β - 1) + γ - 1 β⋅γ - β⋅γ - γ + 2⋅\n", 110 | "\n", 111 | " 2 2 \n", 112 | " - 1 2 4⋅α⋅γ - 4⋅α - β⋅γ + 3⋅β⋅γ + γ - 4\n", 113 | "─────, q_hungry_eat: ─────, q_hungry_none: ───────────────────────────────────\n", 114 | " γ - 1 2 2 \n", 115 | "γ - 1 β⋅γ - β⋅γ - γ + 2⋅γ - 1 \n", 116 | "\n", 117 | " ⎫\n", 118 | "⋅γ + 3 2⋅(β⋅γ - (2⋅β - 1)⋅(γ - 1)) 2 ⎪\n", 119 | "──────, v_full: ──────────────────────────────, v_hungry: ─────⎬\n", 120 | " 2 γ - 1⎪\n", 121 | " γ ⋅(β - 1) - γ⋅(β - 1) + γ - 1 ⎭" 122 | ] 123 | }, 124 | "metadata": {}, 125 | "output_type": "display_data" 126 | }, 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "==== v(饿) = q(饿, 吃), v(饱) = q(饱, 不吃) ==== x = 1, y = 1 ====\n" 132 | ] 133 | }, 134 | { 135 | "data": { 136 | "text/latex": [ 137 | "$\\displaystyle \\left\\{ q_{full eat} : \\frac{- 2 \\alpha \\gamma - \\beta \\gamma + 4 \\beta + 2 \\gamma - 2}{\\alpha \\gamma^{2} - \\alpha \\gamma + \\beta \\gamma^{2} - \\beta \\gamma - \\gamma^{2} + 2 \\gamma - 1}, \\ q_{full none} : \\frac{- \\alpha \\gamma^{2} - \\alpha \\gamma + 3 \\beta \\gamma + \\gamma^{2} - 1}{\\alpha \\gamma^{2} - \\alpha \\gamma + \\beta \\gamma^{2} - \\beta \\gamma - \\gamma^{2} + 2 \\gamma - 1}, \\ q_{hungry eat} : \\frac{- 2 \\alpha \\gamma + \\beta \\gamma^{2} + 2 \\beta \\gamma - \\gamma^{2} - \\gamma + 2}{\\alpha \\gamma^{2} - \\alpha \\gamma + \\beta \\gamma^{2} - \\beta \\gamma - \\gamma^{2} + 2 \\gamma - 1}, \\ q_{hungry none} : \\frac{2 \\alpha \\gamma - 4 \\alpha + 3 \\beta \\gamma - 3 \\gamma + 3}{\\alpha \\gamma^{2} - \\alpha \\gamma + \\beta 
\\gamma^{2} - \\beta \\gamma - \\gamma^{2} + 2 \\gamma - 1}, \\ v_{full} : \\frac{- 2 \\alpha \\gamma - \\beta \\gamma + 4 \\beta + 2 \\gamma - 2}{\\alpha \\gamma^{2} - \\alpha \\gamma + \\beta \\gamma^{2} - \\beta \\gamma - \\gamma^{2} + 2 \\gamma - 1}, \\ v_{hungry} : \\frac{2 \\alpha \\gamma - 4 \\alpha + 3 \\beta \\gamma - 3 \\gamma + 3}{\\alpha \\gamma^{2} - \\alpha \\gamma + \\beta \\gamma^{2} - \\beta \\gamma - \\gamma^{2} + 2 \\gamma - 1}\\right\\}$" 138 | ], 139 | "text/plain": [ 140 | "⎧ 2 \n", 141 | "⎪ -2⋅α⋅γ - β⋅γ + 4⋅β + 2⋅γ - 2 - α⋅γ -\n", 142 | "⎨q_full_eat: ──────────────────────────────────────, q_full_none: ────────────\n", 143 | "⎪ 2 2 2 2 \n", 144 | "⎩ α⋅γ - α⋅γ + β⋅γ - β⋅γ - γ + 2⋅γ - 1 α⋅γ - α⋅γ +\n", 145 | "\n", 146 | " 2 2 2 \n", 147 | " α⋅γ + 3⋅β⋅γ + γ - 1 -2⋅α⋅γ + β⋅γ + 2⋅β⋅γ - γ - γ + 2\n", 148 | "──────────────────────────, q_hungry_eat: ────────────────────────────────────\n", 149 | " 2 2 2 2 2 \n", 150 | " β⋅γ - β⋅γ - γ + 2⋅γ - 1 α⋅γ - α⋅γ + β⋅γ - β⋅γ - γ + 2⋅γ -\n", 151 | "\n", 152 | " \n", 153 | " 2⋅α⋅γ - 4⋅α + 3⋅β⋅γ - 3⋅γ + 3 -2⋅α⋅γ\n", 154 | "──, q_hungry_none: ──────────────────────────────────────, v_full: ───────────\n", 155 | " 2 2 2 2 \n", 156 | " 1 α⋅γ - α⋅γ + β⋅γ - β⋅γ - γ + 2⋅γ - 1 α⋅γ - α⋅γ \n", 157 | "\n", 158 | " ⎫\n", 159 | " - β⋅γ + 4⋅β + 2⋅γ - 2 2⋅α⋅γ - 4⋅α + 3⋅β⋅γ - 3⋅γ + 3 ⎪\n", 160 | "───────────────────────────, v_hungry: ──────────────────────────────────────⎬\n", 161 | " 2 2 2 2 2 ⎪\n", 162 | "+ β⋅γ - β⋅γ - γ + 2⋅γ - 1 α⋅γ - α⋅γ + β⋅γ - β⋅γ - γ + 2⋅γ - 1⎭" 163 | ] 164 | }, 165 | "metadata": {}, 166 | "output_type": "display_data" 167 | } 168 | ], 169 | "source": [ 170 | "# 求解示例Bellman最优方程\n", 171 | "\n", 172 | "xy_tuples = ((0, 0), (1, 0), (0, 1), (1, 1))\n", 173 | "for x, y in xy_tuples:\n", 174 | " system = sympy.Matrix((\n", 175 | " (1, 0, x - 1, -x, 0, 0, 0),\n", 176 | " (0, 1, 0, 0, -y, y - 1, 0),\n", 177 | " (-gamma, 0, 1, 0, 0, 0, -2),\n", 178 | " ((alpha - 1) * gamma, -alpha * gamma, 0, 1, 0, 0, 4 * 
alpha - 3),\n", 179 | " (-beta * gamma, (beta - 1) * gamma, 0, 0, 1, 0, -4 * beta + 2),\n", 180 | " (0, -gamma, 0, 0, 0, 1, 1)\n", 181 | " ))\n", 182 | " \n", 183 | " result = sympy.solve_linear_system(system, v_hungry, v_full, q_hungry_none, q_hungry_eat, q_full_none, q_full_eat) # matrix columns 3-6 hold q(hungry, no eat), q(hungry, eat), q(full, no eat), q(full, eat)\n", 184 | " msgx = 'v(饿) = q(饿, {}吃)'.format('' if x else '不')\n", 185 | " msgy = 'v(饱) = q(饱, {}吃)'.format('不' if y else '')\n", 186 | " print('==== {}, {} ==== x = {}, y = {} ===='.format(msgx, msgy, x, y))\n", 187 | " display(result)" 188 | ] 189 | } 190 | ], 191 | "metadata": { 192 | "kernelspec": { 193 | "display_name": "Python 3", 194 | "language": "python", 195 | "name": "python3" 196 | }, 197 | "language_info": { 198 | "codemirror_mode": { 199 | "name": "ipython", 200 | "version": 3 201 | }, 202 | "file_extension": ".py", 203 | "mimetype": "text/x-python", 204 | "name": "python", 205 | "nbconvert_exporter": "python", 206 | "pygments_lexer": "ipython3", 207 | "version": "3.7.6" 208 | } 209 | }, 210 | "nbformat": 4, 211 | "nbformat_minor": 4 212 | } 213 | -------------------------------------------------------------------------------- /Chapter2-Markov决策过程/2.4-案例:悬崖寻路.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import gym\n", 10 | "import scipy\n", 11 | "import numpy as np" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "观测空间 = Discrete(48)\n", 24 | "动作空间 = Discrete(4)\n", 25 | "观测数量 = 48, 动作数量 = 4\n", 26 | "地图大小 = (4, 12)\n" 27 | ] 28 | } 29 | ], 30 | "source": [ 31 | "# 导入‘CliffWalking-v0’环境\n", 32 | "\n", 33 | "env = 
gym.make('CliffWalking-v0')\n", 34 | "print('观测空间 = {}'.format(env.observation_space))\n", 35 | "print('动作空间 = 
{}'.format(env.action_space))\n", 36 | "print('观测数量 = {}, 动作数量 = {}'.format(env.nS, env.nA))\n", 37 | "print('地图大小 = {}'.format(env.shape))" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 3, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "# 运行一个回合\n", 47 | "\n", 48 | "def play_once(env, policy):\n", 49 | " total_reward = 0\n", 50 | " state = env.reset()\n", 51 | " \n", 52 | " while True:\n", 53 | " loc = np.unravel_index(state, env.shape)\n", 54 | " print('状态 = {}, 位置 = {}'.format(state, loc), end = ' | ')\n", 55 | " action = np.random.choice(env.nA, p=policy[state])\n", 56 | " state, reward, done, _ = env.step(action)\n", 57 | " print('动作 = {}, 奖励 = {}'.format(action, reward))\n", 58 | " total_reward += reward\n", 59 | " if done:\n", 60 | " break\n", 61 | " \n", 62 | " return total_reward" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 4, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "name": "stdout", 72 | "output_type": "stream", 73 | "text": [ 74 | "最优策略 = \n", 75 | "[[1 1 1 1 1 1 1 1 1 1 1 2]\n", 76 | " [1 1 1 1 1 1 1 1 1 1 1 2]\n", 77 | " [1 1 1 1 1 1 1 1 1 1 1 2]\n", 78 | " [0 0 0 0 0 0 0 0 0 0 0 2]]\n", 79 | "\n", 80 | "状态 = 36, 位置 = (3, 0) | 动作 = 0, 奖励 = -1\n", 81 | "状态 = 24, 位置 = (2, 0) | 动作 = 1, 奖励 = -1\n", 82 | "状态 = 25, 位置 = (2, 1) | 动作 = 1, 奖励 = -1\n", 83 | "状态 = 26, 位置 = (2, 2) | 动作 = 1, 奖励 = -1\n", 84 | "状态 = 27, 位置 = (2, 3) | 动作 = 1, 奖励 = -1\n", 85 | "状态 = 28, 位置 = (2, 4) | 动作 = 1, 奖励 = -1\n", 86 | "状态 = 29, 位置 = (2, 5) | 动作 = 1, 奖励 = -1\n", 87 | "状态 = 30, 位置 = (2, 6) | 动作 = 1, 奖励 = -1\n", 88 | "状态 = 31, 位置 = (2, 7) | 动作 = 1, 奖励 = -1\n", 89 | "状态 = 32, 位置 = (2, 8) | 动作 = 1, 奖励 = -1\n", 90 | "状态 = 33, 位置 = (2, 9) | 动作 = 1, 奖励 = -1\n", 91 | "状态 = 34, 位置 = (2, 10) | 动作 = 1, 奖励 = -1\n", 92 | "状态 = 35, 位置 = (2, 11) | 动作 = 2, 奖励 = -1\n", 93 | "总奖励 = -13\n" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "# 最优策略\n", 99 | "\n", 100 | "actions = np.ones(env.shape, dtype=int)\n", 101 | 
"actions[-1, :] = 0\n", 102 | "actions[:, -1] = 2\n", 103 | "optimal_policy = np.eye(4)[actions.reshape(-1)]\n", 104 | "print('最优策略 = \\n{}\\n'.format(actions))\n", 105 | "\n", 106 | "total_reward = play_once(env, optimal_policy)\n", 107 | "print('总奖励 = {}'.format(total_reward))" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 5, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "# 用Bellman方程求解状态价值和动作价值\n", 117 | "\n", 118 | "def evaluate_bellman(env, policy, gamma=1.0):\n", 119 | " a, b = np.eye(env.nS), np.zeros((env.nS))\n", 120 | " for state in range(env.nS - 1):\n", 121 | " for action in range(env.nA):\n", 122 | " pi = policy[state][action]\n", 123 | " for p, next_state, reward, done in env.P[state][action]:\n", 124 | " a[state, next_state] -= pi * gamma * p\n", 125 | " b[state] += pi * reward * p\n", 126 | " v = np.linalg.solve(a, b)\n", 127 | " \n", 128 | " q = np.zeros((env.nS, env.nA))\n", 129 | " for state in range(env.nS - 1):\n", 130 | " for action in range(env.nA):\n", 131 | " for p, next_state, reward, done in env.P[state][action]:\n", 132 | " q[state][action] += (reward + gamma * v[next_state]) * p\n", 133 | " return v, q" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 6, 139 | "metadata": {}, 140 | "outputs": [ 141 | { 142 | "name": "stdout", 143 | "output_type": "stream", 144 | "text": [ 145 | "状态价值 = [-84537767.17599674 -84537734.14924777 -84537128.96237774\n", 146 | " -84536314.46853566 -84534411.5386176 -84523673.65553102\n", 147 | " -84484789.05526869 -84373265.75787517 -84329922.07866065\n", 148 | " -84275296.29400891 -84157473.9489998 -84126181.5476279\n", 149 | " -84537790.11340794 -84537698.61188367 -84537527.58746669\n", 150 | " -84537068.71889049 -84535609.97717507 -84519784.83687702\n", 151 | " -84501674.3798621 -84427458.1592079 -84303436.46428464\n", 152 | " -84229578.6346756 -84101343.11962558 -84084223.0419421\n", 153 | " -84538076.4686468 -84538041.2098556 
-84537974.47486386\n", 154 | " -84537748.23394884 -84536900.56879608 -84526997.16611306\n", 155 | " -84524811.57365067 -84514151.94362962 -84382801.25759234\n", 156 | " -84220081.03580694 -83779780.71472944 -82915912.33177578\n", 157 | " -84538305.52849488 -84538312.12583822 -84538153.59525718\n", 158 | " -84538145.98539561 -84538403.63886732 -84535498.19949386\n", 159 | " -84533771.18294747 -84535051.68012775 -84496447.7055068\n", 160 | " -84499725.97483775 -63199813.39233176 0. ]\n", 161 | "动作价值 = [[-8.45377682e+07 -8.45377351e+07 -8.45377911e+07 -8.45377682e+07]\n", 162 | " [-8.45377351e+07 -8.45371300e+07 -8.45376996e+07 -8.45377682e+07]\n", 163 | " [-8.45371300e+07 -8.45363155e+07 -8.45375286e+07 -8.45377351e+07]\n", 164 | " [-8.45363155e+07 -8.45344125e+07 -8.45370697e+07 -8.45371300e+07]\n", 165 | " [-8.45344125e+07 -8.45236747e+07 -8.45356110e+07 -8.45363155e+07]\n", 166 | " [-8.45236747e+07 -8.44847901e+07 -8.45197858e+07 -8.45344125e+07]\n", 167 | " [-8.44847901e+07 -8.43732668e+07 -8.45016754e+07 -8.45236747e+07]\n", 168 | " [-8.43732668e+07 -8.43299231e+07 -8.44274592e+07 -8.44847901e+07]\n", 169 | " [-8.43299231e+07 -8.42752973e+07 -8.43034375e+07 -8.43732668e+07]\n", 170 | " [-8.42752973e+07 -8.41574749e+07 -8.42295796e+07 -8.43299231e+07]\n", 171 | " [-8.41574749e+07 -8.41261825e+07 -8.41013441e+07 -8.42752973e+07]\n", 172 | " [-8.41261825e+07 -8.41261825e+07 -8.40842240e+07 -8.41574749e+07]\n", 173 | " [-8.45377682e+07 -8.45376996e+07 -8.45380775e+07 -8.45377911e+07]\n", 174 | " [-8.45377351e+07 -8.45375286e+07 -8.45380422e+07 -8.45377911e+07]\n", 175 | " [-8.45371300e+07 -8.45370697e+07 -8.45379755e+07 -8.45376996e+07]\n", 176 | " [-8.45363155e+07 -8.45356110e+07 -8.45377492e+07 -8.45375286e+07]\n", 177 | " [-8.45344125e+07 -8.45197858e+07 -8.45369016e+07 -8.45370697e+07]\n", 178 | " [-8.45236747e+07 -8.45016754e+07 -8.45269982e+07 -8.45356110e+07]\n", 179 | " [-8.44847901e+07 -8.44274592e+07 -8.45248126e+07 -8.45197858e+07]\n", 180 | " 
[-8.43732668e+07 -8.43034375e+07 -8.45141529e+07 -8.45016754e+07]\n", 181 | " [-8.43299231e+07 -8.42295796e+07 -8.43828023e+07 -8.44274592e+07]\n", 182 | " [-8.42752973e+07 -8.41013441e+07 -8.42200820e+07 -8.43034375e+07]\n", 183 | " [-8.41574749e+07 -8.40842240e+07 -8.37797817e+07 -8.42295796e+07]\n", 184 | " [-8.41261825e+07 -8.40842240e+07 -8.29159133e+07 -8.41013441e+07]\n", 185 | " [-8.45377911e+07 -8.45380422e+07 -8.45383065e+07 -8.45380775e+07]\n", 186 | " [-8.45376996e+07 -8.45379755e+07 -8.45384055e+07 -8.45380775e+07]\n", 187 | " [-8.45375286e+07 -8.45377492e+07 -8.45384055e+07 -8.45380422e+07]\n", 188 | " [-8.45370697e+07 -8.45369016e+07 -8.45384055e+07 -8.45379755e+07]\n", 189 | " [-8.45356110e+07 -8.45269982e+07 -8.45384055e+07 -8.45377492e+07]\n", 190 | " [-8.45197858e+07 -8.45248126e+07 -8.45384055e+07 -8.45369016e+07]\n", 191 | " [-8.45016754e+07 -8.45141529e+07 -8.45384055e+07 -8.45269982e+07]\n", 192 | " [-8.44274592e+07 -8.43828023e+07 -8.45384055e+07 -8.45248126e+07]\n", 193 | " [-8.43034375e+07 -8.42200820e+07 -8.45384055e+07 -8.45141529e+07]\n", 194 | " [-8.42295796e+07 -8.37797817e+07 -8.45384055e+07 -8.43828023e+07]\n", 195 | " [-8.41013441e+07 -8.29159133e+07 -8.45384055e+07 -8.42200820e+07]\n", 196 | " [-8.40842240e+07 -8.29159133e+07 -1.00000000e+00 -8.37797817e+07]\n", 197 | " [-8.45380775e+07 -8.45384055e+07 -8.45383065e+07 -8.45383065e+07]\n", 198 | " [-8.45380422e+07 -8.45384055e+07 -8.45384055e+07 -8.45383065e+07]\n", 199 | " [-8.45379755e+07 -8.45384055e+07 -8.45384055e+07 -8.45384055e+07]\n", 200 | " [-8.45377492e+07 -8.45384055e+07 -8.45384055e+07 -8.45384055e+07]\n", 201 | " [-8.45369016e+07 -8.45384055e+07 -8.45384055e+07 -8.45384055e+07]\n", 202 | " [-8.45269982e+07 -8.45384055e+07 -8.45384055e+07 -8.45384055e+07]\n", 203 | " [-8.45248126e+07 -8.45384055e+07 -8.45384055e+07 -8.45384055e+07]\n", 204 | " [-8.45141529e+07 -8.45384055e+07 -8.45384055e+07 -8.45384055e+07]\n", 205 | " [-8.43828023e+07 -8.45384055e+07 -8.45384055e+07 
-8.45384055e+07]\n", 206 | " [-8.42200820e+07 -8.45384055e+07 -8.45384055e+07 -8.45384055e+07]\n", 207 | " [-8.37797817e+07 -1.00000000e+00 -8.45384055e+07 -8.45384055e+07]\n", 208 | " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n" 209 | ] 210 | } 211 | ], 212 | "source": [ 213 | "# 评估随机策略\n", 214 | "\n", 215 | "policy = np.random.uniform(size=(env.nS, env.nA))\n", 216 | "policy = policy / np.sum(policy, axis=1)[:, np.newaxis]\n", 217 | "state_values, action_values = evaluate_bellman(env, policy)\n", 218 | "print('状态价值 = {}'.format(state_values))\n", 219 | "print('动作价值 = {}'.format(action_values))" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 7, 225 | "metadata": {}, 226 | "outputs": [ 227 | { 228 | "name": "stdout", 229 | "output_type": "stream", 230 | "text": [ 231 | "最优状态价值 = [-14. -13. -12. -11. -10. -9. -8. -7. -6. -5. -4. -3. -13. -12.\n", 232 | " -11. -10. -9. -8. -7. -6. -5. -4. -3. -2. -12. -11. -10. -9.\n", 233 | " -8. -7. -6. -5. -4. -3. -2. -1. -13. -12. -11. -10. -9. -8.\n", 234 | " -7. -6. -5. -4. -3. 0.]\n", 235 | "最优动作价值 = [[ -15. -14. -14. -15.]\n", 236 | " [ -14. -13. -13. -15.]\n", 237 | " [ -13. -12. -12. -14.]\n", 238 | " [ -12. -11. -11. -13.]\n", 239 | " [ -11. -10. -10. -12.]\n", 240 | " [ -10. -9. -9. -11.]\n", 241 | " [ -9. -8. -8. -10.]\n", 242 | " [ -8. -7. -7. -9.]\n", 243 | " [ -7. -6. -6. -8.]\n", 244 | " [ -6. -5. -5. -7.]\n", 245 | " [ -5. -4. -4. -6.]\n", 246 | " [ -4. -4. -3. -5.]\n", 247 | " [ -15. -13. -13. -14.]\n", 248 | " [ -14. -12. -12. -14.]\n", 249 | " [ -13. -11. -11. -13.]\n", 250 | " [ -12. -10. -10. -12.]\n", 251 | " [ -11. -9. -9. -11.]\n", 252 | " [ -10. -8. -8. -10.]\n", 253 | " [ -9. -7. -7. -9.]\n", 254 | " [ -8. -6. -6. -8.]\n", 255 | " [ -7. -5. -5. -7.]\n", 256 | " [ -6. -4. -4. -6.]\n", 257 | " [ -5. -3. -3. -5.]\n", 258 | " [ -4. -3. -2. -4.]\n", 259 | " [ -14. -12. -14. -13.]\n", 260 | " [ -13. -11. -113. -13.]\n", 261 | " [ -12. -10. -113. 
-12.]\n", 262 | " [ -11. -9. -113. -11.]\n", 263 | " [ -10. -8. -113. -10.]\n", 264 | " [ -9. -7. -113. -9.]\n", 265 | " [ -8. -6. -113. -8.]\n", 266 | " [ -7. -5. -113. -7.]\n", 267 | " [ -6. -4. -113. -6.]\n", 268 | " [ -5. -3. -113. -5.]\n", 269 | " [ -4. -2. -113. -4.]\n", 270 | " [ -3. -2. -1. -3.]\n", 271 | " [ -13. -113. -14. -14.]\n", 272 | " [ -12. -113. -113. -14.]\n", 273 | " [ -11. -113. -113. -113.]\n", 274 | " [ -10. -113. -113. -113.]\n", 275 | " [ -9. -113. -113. -113.]\n", 276 | " [ -8. -113. -113. -113.]\n", 277 | " [ -7. -113. -113. -113.]\n", 278 | " [ -6. -113. -113. -113.]\n", 279 | " [ -5. -113. -113. -113.]\n", 280 | " [ -4. -113. -113. -113.]\n", 281 | " [ -3. -1. -113. -113.]\n", 282 | " [ 0. 0. 0. 0.]]\n" 283 | ] 284 | } 285 | ], 286 | "source": [ 287 | "# 评估最优策略\n", 288 | "\n", 289 | "optimal_state_values, optimal_action_values = evaluate_bellman(env, optimal_policy)\n", 290 | "print('最优状态价值 = {}'.format(optimal_state_values))\n", 291 | "print('最优动作价值 = {}'.format(optimal_action_values))" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 8, 297 | "metadata": {}, 298 | "outputs": [ 299 | { 300 | "name": "stdout", 301 | "output_type": "stream", 302 | "text": [ 303 | "最优状态价值 = [-1.40000000e+01 -1.30000000e+01 -1.20000000e+01 -1.10000000e+01\n", 304 | " -1.00000000e+01 -9.00000000e+00 -8.00000000e+00 -7.00000000e+00\n", 305 | " -6.00000000e+00 -5.00000000e+00 -4.00000000e+00 -3.00000000e+00\n", 306 | " -1.30000000e+01 -1.20000000e+01 -1.10000000e+01 -1.00000000e+01\n", 307 | " -9.00000000e+00 -8.00000000e+00 -7.00000000e+00 -6.00000000e+00\n", 308 | " -5.00000000e+00 -4.00000000e+00 -3.00000000e+00 -2.00000000e+00\n", 309 | " -1.20000000e+01 -1.10000000e+01 -1.00000000e+01 -9.00000000e+00\n", 310 | " -8.00000000e+00 -7.00000000e+00 -6.00000000e+00 -5.00000000e+00\n", 311 | " -4.00000000e+00 -3.00000000e+00 -2.00000000e+00 -1.00000000e+00\n", 312 | " -1.30000000e+01 -1.20000000e+01 -1.10000000e+01 
-1.00000000e+01\n", 313 | " -9.00000000e+00 -8.00000000e+00 -7.00000000e+00 -6.00000000e+00\n", 314 | " -5.00000000e+00 -4.00000000e+00 -9.99999999e-01 1.82270928e-11]\n", 315 | "最优动作价值 = [[ -14.99999999 -13.99999999 -13.99999999 -14.99999999]\n", 316 | " [ -13.99999999 -13. -13. -14.99999999]\n", 317 | " [ -13. -12. -12. -13.99999999]\n", 318 | " [ -12. -11. -11. -13. ]\n", 319 | " [ -11. -10. -10. -12. ]\n", 320 | " [ -10. -9. -9. -11. ]\n", 321 | " [ -9. -8. -8. -10. ]\n", 322 | " [ -8. -7. -7. -9. ]\n", 323 | " [ -7. -6. -6. -8. ]\n", 324 | " [ -6. -5. -5. -7. ]\n", 325 | " [ -5. -4. -4. -6. ]\n", 326 | " [ -4. -4. -3. -5. ]\n", 327 | " [ -14.99999999 -13. -13. -13.99999999]\n", 328 | " [ -13.99999999 -12. -12. -13.99999999]\n", 329 | " [ -13. -11. -11. -13. ]\n", 330 | " [ -12. -10. -10. -12. ]\n", 331 | " [ -11. -9. -9. -11. ]\n", 332 | " [ -10. -8. -8. -10. ]\n", 333 | " [ -9. -7. -7. -9. ]\n", 334 | " [ -8. -6. -6. -8. ]\n", 335 | " [ -7. -5. -5. -7. ]\n", 336 | " [ -6. -4. -4. -6. ]\n", 337 | " [ -5. -3. -3. -5. ]\n", 338 | " [ -4. -3. -2. -4. ]\n", 339 | " [ -13.99999999 -12. -14. -13. ]\n", 340 | " [ -13. -11. -113. -13. ]\n", 341 | " [ -12. -10. -113. -12. ]\n", 342 | " [ -11. -9. -113. -11. ]\n", 343 | " [ -10. -8. -113. -10. ]\n", 344 | " [ -9. -7. -113. -9. ]\n", 345 | " [ -8. -6. -113. -8. ]\n", 346 | " [ -7. -5. -113. -7. ]\n", 347 | " [ -6. -4. -113. -6. ]\n", 348 | " [ -5. -3. -113. -5. ]\n", 349 | " [ -4. -2. -113. -4. ]\n", 350 | " [ -3. -2. -1. -3. ]\n", 351 | " [ -13. -113. -14. -14. ]\n", 352 | " [ -12. -113. -113. -14. ]\n", 353 | " [ -11. -113. -113. -113. ]\n", 354 | " [ -10. -113. -113. -113. ]\n", 355 | " [ -9. -113. -113. -113. ]\n", 356 | " [ -8. -113. -113. -113. ]\n", 357 | " [ -7. -113. -113. -113. ]\n", 358 | " [ -6. -113. -113. -113. ]\n", 359 | " [ -5. -113. -113. -113. ]\n", 360 | " [ -4. -113. -113. -113. ]\n", 361 | " [ -3. -1. -113. -113. ]\n", 362 | " [ 0. 0. 0. 0. 
]]\n" 363 | ] 364 | } 365 | ], 366 | "source": [ 367 | "# 用线性规划求解Bellman最优方程\n", 368 | "\n", 369 | "def optimal_bellman(env, gamma=1.0):\n", 370 | " p = np.zeros((env.nS, env.nA, env.nS))\n", 371 | " r = np.zeros((env.nS, env.nA))\n", 372 | " for state in range(env.nS - 1):\n", 373 | " for action in range(env.nA):\n", 374 | " for prob, next_state, reward, done in env.P[state][action]:\n", 375 | " p[state, action, next_state] += prob\n", 376 | " r[state, action] += reward * prob\n", 377 | " \n", 378 | " c = np.ones((env.nS))\n", 379 | " a_ub = gamma * p.reshape(-1, env.nS) - np.repeat(np.eye(env.nS), env.nA, axis=0)\n", 380 | " b_ub = -r.reshape(-1)\n", 381 | " \n", 382 | " bounds = [(None, None),] * env.nS\n", 383 | " res = scipy.optimize.linprog(c, a_ub, b_ub, bounds=bounds, method='interior-point')\n", 384 | " v = res.x\n", 385 | " q = r + gamma * np.dot(p, v)\n", 386 | " return v, q\n", 387 | "\n", 388 | "optimal_state_values, optimal_action_values = optimal_bellman(env)\n", 389 | "print('最优状态价值 = {}'.format(optimal_state_values))\n", 390 | "print('最优动作价值 = {}'.format(optimal_action_values))" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": 9, 396 | "metadata": {}, 397 | "outputs": [ 398 | { 399 | "name": "stdout", 400 | "output_type": "stream", 401 | "text": [ 402 | "最优策略 = \n", 403 | "[[2 1 1 1 1 1 1 1 1 1 1 2]\n", 404 | " [1 1 1 1 1 1 1 1 1 1 1 2]\n", 405 | " [1 1 1 1 1 1 1 1 1 1 1 2]\n", 406 | " [0 0 0 0 0 0 0 0 0 0 1 0]]\n" 407 | ] 408 | } 409 | ], 410 | "source": [ 411 | "# 用最优动作价值确定最优确定性策略\n", 412 | "\n", 413 | "optimal_actions = optimal_action_values.argmax(axis=1)\n", 414 | "print('最优策略 = \\n{}'.format(optimal_actions.reshape(env.shape)))" 415 | ] 416 | } 417 | ], 418 | "metadata": { 419 | "kernelspec": { 420 | "display_name": "Python 3", 421 | "language": "python", 422 | "name": "python3" 423 | }, 424 | "language_info": { 425 | "codemirror_mode": { 426 | "name": "ipython", 427 | "version": 3 428 | }, 429 | "file_extension": 
".py", 430 | "mimetype": "text/x-python", 431 | "name": "python", 432 | "nbconvert_exporter": "python", 433 | "pygments_lexer": "ipython3", 434 | "version": "3.7.6" 435 | } 436 | }, 437 | "nbformat": 4, 438 | "nbformat_minor": 4 439 | } 440 | -------------------------------------------------------------------------------- /Chapter3-有模型数值迭代/3.5-案例:冰面滑行.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import gym\n", 10 | "import numpy as np" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": {}, 17 | "outputs": [ 18 | { 19 | "name": "stdout", 20 | "output_type": "stream", 21 | "text": [ 22 | "观测空间:Discrete(16)\n", 23 | "动作空间:Discrete(4)\n" 24 | ] 25 | }, 26 | { 27 | "data": { 28 | "text/plain": [ 29 | "{0: {0: [(0.3333333333333333, 0, 0.0, False),\n", 30 | " (0.3333333333333333, 0, 0.0, False),\n", 31 | " (0.3333333333333333, 4, 0.0, False)],\n", 32 | " 1: [(0.3333333333333333, 0, 0.0, False),\n", 33 | " (0.3333333333333333, 4, 0.0, False),\n", 34 | " (0.3333333333333333, 1, 0.0, False)],\n", 35 | " 2: [(0.3333333333333333, 4, 0.0, False),\n", 36 | " (0.3333333333333333, 1, 0.0, False),\n", 37 | " (0.3333333333333333, 0, 0.0, False)],\n", 38 | " 3: [(0.3333333333333333, 1, 0.0, False),\n", 39 | " (0.3333333333333333, 0, 0.0, False),\n", 40 | " (0.3333333333333333, 0, 0.0, False)]},\n", 41 | " 1: {0: [(0.3333333333333333, 1, 0.0, False),\n", 42 | " (0.3333333333333333, 0, 0.0, False),\n", 43 | " (0.3333333333333333, 5, 0.0, True)],\n", 44 | " 1: [(0.3333333333333333, 0, 0.0, False),\n", 45 | " (0.3333333333333333, 5, 0.0, True),\n", 46 | " (0.3333333333333333, 2, 0.0, False)],\n", 47 | " 2: [(0.3333333333333333, 5, 0.0, True),\n", 48 | " (0.3333333333333333, 2, 0.0, False),\n", 49 | " (0.3333333333333333, 1, 0.0, False)],\n", 50 | " 3: 
[(0.3333333333333333, 2, 0.0, False),\n", 51 | " (0.3333333333333333, 1, 0.0, False),\n", 52 | " (0.3333333333333333, 0, 0.0, False)]},\n", 53 | " 2: {0: [(0.3333333333333333, 2, 0.0, False),\n", 54 | " (0.3333333333333333, 1, 0.0, False),\n", 55 | " (0.3333333333333333, 6, 0.0, False)],\n", 56 | " 1: [(0.3333333333333333, 1, 0.0, False),\n", 57 | " (0.3333333333333333, 6, 0.0, False),\n", 58 | " (0.3333333333333333, 3, 0.0, False)],\n", 59 | " 2: [(0.3333333333333333, 6, 0.0, False),\n", 60 | " (0.3333333333333333, 3, 0.0, False),\n", 61 | " (0.3333333333333333, 2, 0.0, False)],\n", 62 | " 3: [(0.3333333333333333, 3, 0.0, False),\n", 63 | " (0.3333333333333333, 2, 0.0, False),\n", 64 | " (0.3333333333333333, 1, 0.0, False)]},\n", 65 | " 3: {0: [(0.3333333333333333, 3, 0.0, False),\n", 66 | " (0.3333333333333333, 2, 0.0, False),\n", 67 | " (0.3333333333333333, 7, 0.0, True)],\n", 68 | " 1: [(0.3333333333333333, 2, 0.0, False),\n", 69 | " (0.3333333333333333, 7, 0.0, True),\n", 70 | " (0.3333333333333333, 3, 0.0, False)],\n", 71 | " 2: [(0.3333333333333333, 7, 0.0, True),\n", 72 | " (0.3333333333333333, 3, 0.0, False),\n", 73 | " (0.3333333333333333, 3, 0.0, False)],\n", 74 | " 3: [(0.3333333333333333, 3, 0.0, False),\n", 75 | " (0.3333333333333333, 3, 0.0, False),\n", 76 | " (0.3333333333333333, 2, 0.0, False)]},\n", 77 | " 4: {0: [(0.3333333333333333, 0, 0.0, False),\n", 78 | " (0.3333333333333333, 4, 0.0, False),\n", 79 | " (0.3333333333333333, 8, 0.0, False)],\n", 80 | " 1: [(0.3333333333333333, 4, 0.0, False),\n", 81 | " (0.3333333333333333, 8, 0.0, False),\n", 82 | " (0.3333333333333333, 5, 0.0, True)],\n", 83 | " 2: [(0.3333333333333333, 8, 0.0, False),\n", 84 | " (0.3333333333333333, 5, 0.0, True),\n", 85 | " (0.3333333333333333, 0, 0.0, False)],\n", 86 | " 3: [(0.3333333333333333, 5, 0.0, True),\n", 87 | " (0.3333333333333333, 0, 0.0, False),\n", 88 | " (0.3333333333333333, 4, 0.0, False)]},\n", 89 | " 5: {0: [(1.0, 5, 0, True)],\n", 90 | " 1: [(1.0, 5, 0, 
True)],\n", 91 | " 2: [(1.0, 5, 0, True)],\n", 92 | " 3: [(1.0, 5, 0, True)]},\n", 93 | " 6: {0: [(0.3333333333333333, 2, 0.0, False),\n", 94 | " (0.3333333333333333, 5, 0.0, True),\n", 95 | " (0.3333333333333333, 10, 0.0, False)],\n", 96 | " 1: [(0.3333333333333333, 5, 0.0, True),\n", 97 | " (0.3333333333333333, 10, 0.0, False),\n", 98 | " (0.3333333333333333, 7, 0.0, True)],\n", 99 | " 2: [(0.3333333333333333, 10, 0.0, False),\n", 100 | " (0.3333333333333333, 7, 0.0, True),\n", 101 | " (0.3333333333333333, 2, 0.0, False)],\n", 102 | " 3: [(0.3333333333333333, 7, 0.0, True),\n", 103 | " (0.3333333333333333, 2, 0.0, False),\n", 104 | " (0.3333333333333333, 5, 0.0, True)]},\n", 105 | " 7: {0: [(1.0, 7, 0, True)],\n", 106 | " 1: [(1.0, 7, 0, True)],\n", 107 | " 2: [(1.0, 7, 0, True)],\n", 108 | " 3: [(1.0, 7, 0, True)]},\n", 109 | " 8: {0: [(0.3333333333333333, 4, 0.0, False),\n", 110 | " (0.3333333333333333, 8, 0.0, False),\n", 111 | " (0.3333333333333333, 12, 0.0, True)],\n", 112 | " 1: [(0.3333333333333333, 8, 0.0, False),\n", 113 | " (0.3333333333333333, 12, 0.0, True),\n", 114 | " (0.3333333333333333, 9, 0.0, False)],\n", 115 | " 2: [(0.3333333333333333, 12, 0.0, True),\n", 116 | " (0.3333333333333333, 9, 0.0, False),\n", 117 | " (0.3333333333333333, 4, 0.0, False)],\n", 118 | " 3: [(0.3333333333333333, 9, 0.0, False),\n", 119 | " (0.3333333333333333, 4, 0.0, False),\n", 120 | " (0.3333333333333333, 8, 0.0, False)]},\n", 121 | " 9: {0: [(0.3333333333333333, 5, 0.0, True),\n", 122 | " (0.3333333333333333, 8, 0.0, False),\n", 123 | " (0.3333333333333333, 13, 0.0, False)],\n", 124 | " 1: [(0.3333333333333333, 8, 0.0, False),\n", 125 | " (0.3333333333333333, 13, 0.0, False),\n", 126 | " (0.3333333333333333, 10, 0.0, False)],\n", 127 | " 2: [(0.3333333333333333, 13, 0.0, False),\n", 128 | " (0.3333333333333333, 10, 0.0, False),\n", 129 | " (0.3333333333333333, 5, 0.0, True)],\n", 130 | " 3: [(0.3333333333333333, 10, 0.0, False),\n", 131 | " (0.3333333333333333, 5, 
0.0, True),\n", 132 | " (0.3333333333333333, 8, 0.0, False)]},\n", 133 | " 10: {0: [(0.3333333333333333, 6, 0.0, False),\n", 134 | " (0.3333333333333333, 9, 0.0, False),\n", 135 | " (0.3333333333333333, 14, 0.0, False)],\n", 136 | " 1: [(0.3333333333333333, 9, 0.0, False),\n", 137 | " (0.3333333333333333, 14, 0.0, False),\n", 138 | " (0.3333333333333333, 11, 0.0, True)],\n", 139 | " 2: [(0.3333333333333333, 14, 0.0, False),\n", 140 | " (0.3333333333333333, 11, 0.0, True),\n", 141 | " (0.3333333333333333, 6, 0.0, False)],\n", 142 | " 3: [(0.3333333333333333, 11, 0.0, True),\n", 143 | " (0.3333333333333333, 6, 0.0, False),\n", 144 | " (0.3333333333333333, 9, 0.0, False)]},\n", 145 | " 11: {0: [(1.0, 11, 0, True)],\n", 146 | " 1: [(1.0, 11, 0, True)],\n", 147 | " 2: [(1.0, 11, 0, True)],\n", 148 | " 3: [(1.0, 11, 0, True)]},\n", 149 | " 12: {0: [(1.0, 12, 0, True)],\n", 150 | " 1: [(1.0, 12, 0, True)],\n", 151 | " 2: [(1.0, 12, 0, True)],\n", 152 | " 3: [(1.0, 12, 0, True)]},\n", 153 | " 13: {0: [(0.3333333333333333, 9, 0.0, False),\n", 154 | " (0.3333333333333333, 12, 0.0, True),\n", 155 | " (0.3333333333333333, 13, 0.0, False)],\n", 156 | " 1: [(0.3333333333333333, 12, 0.0, True),\n", 157 | " (0.3333333333333333, 13, 0.0, False),\n", 158 | " (0.3333333333333333, 14, 0.0, False)],\n", 159 | " 2: [(0.3333333333333333, 13, 0.0, False),\n", 160 | " (0.3333333333333333, 14, 0.0, False),\n", 161 | " (0.3333333333333333, 9, 0.0, False)],\n", 162 | " 3: [(0.3333333333333333, 14, 0.0, False),\n", 163 | " (0.3333333333333333, 9, 0.0, False),\n", 164 | " (0.3333333333333333, 12, 0.0, True)]},\n", 165 | " 14: {0: [(0.3333333333333333, 10, 0.0, False),\n", 166 | " (0.3333333333333333, 13, 0.0, False),\n", 167 | " (0.3333333333333333, 14, 0.0, False)],\n", 168 | " 1: [(0.3333333333333333, 13, 0.0, False),\n", 169 | " (0.3333333333333333, 14, 0.0, False),\n", 170 | " (0.3333333333333333, 15, 1.0, True)],\n", 171 | " 2: [(0.3333333333333333, 14, 0.0, False),\n", 172 | " 
(0.3333333333333333, 15, 1.0, True),\n", 173 | " (0.3333333333333333, 10, 0.0, False)],\n", 174 | " 3: [(0.3333333333333333, 15, 1.0, True),\n", 175 | " (0.3333333333333333, 10, 0.0, False),\n", 176 | " (0.3333333333333333, 13, 0.0, False)]},\n", 177 | " 15: {0: [(1.0, 15, 0, True)],\n", 178 | " 1: [(1.0, 15, 0, True)],\n", 179 | " 2: [(1.0, 15, 0, True)],\n", 180 | " 3: [(1.0, 15, 0, True)]}}" 181 | ] 182 | }, 183 | "execution_count": 2, 184 | "metadata": {}, 185 | "output_type": "execute_result" 186 | } 187 | ], 188 | "source": [ 189 | "env = gym.make('FrozenLake-v0')\n", 190 | "env = env.unwrapped\n", 191 | "\n", 192 | "print('观测空间:{}'.format(env.observation_space))\n", 193 | "print('动作空间:{}'.format(env.action_space))\n", 194 | "env.P" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 3, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "# 用策略执行一个回合\n", 204 | "\n", 205 | "def play_policy(env, policy, render=False):\n", 206 | " total_reward = 0\n", 207 | " observation = env.reset()\n", 208 | " while True:\n", 209 | " if render:\n", 210 | " env.render()\n", 211 | " \n", 212 | " action = np.random.choice(env.action_space.n, p=policy[observation])\n", 213 | " observation, reward, done, _ = env.step(action)\n", 214 | " total_reward += reward\n", 215 | " if done:\n", 216 | " break\n", 217 | " \n", 218 | " return total_reward" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 4, 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "name": "stdout", 228 | "output_type": "stream", 229 | "text": [ 230 | "随机策略 平均奖励 = 0.0\n" 231 | ] 232 | } 233 | ], 234 | "source": [ 235 | "# 求随机策略的期望奖励\n", 236 | "\n", 237 | "random_policy = np.ones((env.nS, env.nA)) / env.nA\n", 238 | "episode_rewards = [play_policy(env, random_policy) for _ in range(100)]\n", 239 | "print('随机策略 平均奖励 = {}'.format(np.mean(episode_rewards)))" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 5, 245 | 
"metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 | "# 策略评估的实现\n", 249 | "\n", 250 | "# 根据状态价值函数计算动作价值函数\n", 251 | "def v2q(env, v, s=None, gamma=1.0):\n", 252 | " # 针对单个状态求解\n", 253 | " if s is not None:\n", 254 | " q = np.zeros(env.nA)\n", 255 | " for a in range(env.nA):\n", 256 | " for prob, next_state, reward, done in env.P[s][a]:\n", 257 | " q[a] += prob * (reward + gamma * v[next_state] * (1.0 - done))\n", 258 | " else: # 针对所有状态求解\n", 259 | " q = np.zeros((env.nS, env.nA))\n", 260 | " for s in range(env.nS):\n", 261 | " q[s] = v2q(env, v, s, gamma)\n", 262 | " return q\n", 263 | "\n", 264 | "def evaluate_policy(env, policy, gamma=1.0, tolerant=1e-6):\n", 265 | " v = np.zeros(env.nS)\n", 266 | " while True:\n", 267 | " delta = 0\n", 268 | " for s in range(env.nS):\n", 269 | " vs = sum(policy[s] * v2q(env, v, s, gamma))\n", 270 | " delta = max(delta, abs(v[s] - vs))\n", 271 | " v[s] = vs\n", 272 | " if delta < tolerant:\n", 273 | " break\n", 274 | " return v" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 6, 280 | "metadata": {}, 281 | "outputs": [ 282 | { 283 | "name": "stdout", 284 | "output_type": "stream", 285 | "text": [ 286 | "状态价值函数:\n", 287 | "[[0.0139372 0.01162942 0.02095187 0.01047569]\n", 288 | " [0.01624741 0. 0.04075119 0. ]\n", 289 | " [0.03480561 0.08816967 0.14205297 0. ]\n", 290 | " [0. 0.17582021 0.43929104 0. ]]\n", 291 | "动作状态价值:\n", 292 | "[[0.01470727 0.01393801 0.01393801 0.01316794]\n", 293 | " [0.00852221 0.01162969 0.01086043 0.01550616]\n", 294 | " [0.02444416 0.0209521 0.02405958 0.01435233]\n", 295 | " [0.01047585 0.01047585 0.00698379 0.01396775]\n", 296 | " [0.02166341 0.01701767 0.0162476 0.01006154]\n", 297 | " [0. 0. 0. 0. ]\n", 298 | " [0.05433495 0.04735099 0.05433495 0.00698396]\n", 299 | " [0. 0. 0. 0. 
]\n", 300 | " [0.01701767 0.04099176 0.03480569 0.04640756]\n", 301 | " [0.0702086 0.11755959 0.10595772 0.05895286]\n", 302 | " [0.18940397 0.17582024 0.16001408 0.04297362]\n", 303 | " [0. 0. 0. 0. ]\n", 304 | " [0. 0. 0. 0. ]\n", 305 | " [0.08799662 0.20503708 0.23442697 0.17582024]\n", 306 | " [0.25238807 0.53837042 0.52711467 0.43929106]\n", 307 | " [0. 0. 0. 0. ]]\n" 308 | ] 309 | } 310 | ], 311 | "source": [ 312 | "# 对随机策略进行策略评估\n", 313 | "\n", 314 | "print('状态价值函数:')\n", 315 | "v_random = evaluate_policy(env, random_policy)\n", 316 | "print(v_random.reshape(4, 4))\n", 317 | "\n", 318 | "print('动作状态价值:')\n", 319 | "q_random = v2q(env, v_random)\n", 320 | "print(q_random)" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 7, 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "# 策略改进的实现\n", 330 | "\n", 331 | "def improve_policy(env, v, policy, gamma=1.0):\n", 332 | " optimal = True\n", 333 | " for s in range(env.nS):\n", 334 | " q = v2q(env, v, s, gamma)\n", 335 | " a = np.argmax(q)\n", 336 | " if policy[s][a] != 1.0:\n", 337 | " optimal = False\n", 338 | " policy[s] = 0.0\n", 339 | " policy[s][a] = 1.0\n", 340 | " return optimal" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 8, 346 | "metadata": {}, 347 | "outputs": [ 348 | { 349 | "name": "stdout", 350 | "output_type": "stream", 351 | "text": [ 352 | "有更新,更新后的策略为:\n", 353 | "[[1. 0. 0. 0.]\n", 354 | " [0. 0. 0. 1.]\n", 355 | " [1. 0. 0. 0.]\n", 356 | " [0. 0. 0. 1.]\n", 357 | " [1. 0. 0. 0.]\n", 358 | " [1. 0. 0. 0.]\n", 359 | " [1. 0. 0. 0.]\n", 360 | " [1. 0. 0. 0.]\n", 361 | " [0. 0. 0. 1.]\n", 362 | " [0. 1. 0. 0.]\n", 363 | " [1. 0. 0. 0.]\n", 364 | " [1. 0. 0. 0.]\n", 365 | " [1. 0. 0. 0.]\n", 366 | " [0. 0. 1. 0.]\n", 367 | " [0. 1. 0. 0.]\n", 368 | " [1. 0. 0. 
0.]]\n" 369 | ] 370 | } 371 | ], 372 | "source": [ 373 | "# 对随机策略进行策略改进\n", 374 | "\n", 375 | "policy = random_policy.copy()\n", 376 | "optimal = improve_policy(env, v_random, policy)\n", 377 | "if optimal:\n", 378 | " print('无更新,最优策略为:')\n", 379 | "else:\n", 380 | " print('有更新,更新后的策略为:')\n", 381 | "print(policy)" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": 9, 387 | "metadata": {}, 388 | "outputs": [], 389 | "source": [ 390 | "# 策略迭代的实现\n", 391 | "\n", 392 | "def iterate_policy(env, gamma=1.0, tolerant=1e-6):\n", 393 | " policy = np.ones((env.nS, env.nA)) / env.nA\n", 394 | " while True:\n", 395 | " v = evaluate_policy(env, policy, gamma, tolerant)\n", 396 | " if improve_policy(env, v, policy):\n", 397 | " break\n", 398 | " return policy, v" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": 10, 404 | "metadata": {}, 405 | "outputs": [ 406 | { 407 | "name": "stdout", 408 | "output_type": "stream", 409 | "text": [ 410 | "状态价值函数 =\n", 411 | "[[0.82351246 0.82350689 0.82350303 0.82350106]\n", 412 | " [0.82351416 0. 0.5294002 0. ]\n", 413 | " [0.82351683 0.82352026 0.76469786 0. ]\n", 414 | " [0. 0.88234658 0.94117323 0. 
]]\n", 415 | "最优策略 =\n", 416 | "[[0 3 3 3]\n", 417 | " [0 0 0 0]\n", 418 | " [3 1 0 0]\n", 419 | " [0 2 1 0]]\n" 420 | ] 421 | } 422 | ], 423 | "source": [ 424 | "# 利用策略迭代求解最优策略\n", 425 | "\n", 426 | "policy_pi, v_pi = iterate_policy(env)\n", 427 | "print('状态价值函数 =')\n", 428 | "print(v_pi.reshape(4, 4))\n", 429 | "print('最优策略 =')\n", 430 | "print(np.argmax(policy_pi, axis=1).reshape(4, 4))" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": 11, 436 | "metadata": {}, 437 | "outputs": [], 438 | "source": [ 439 | "# 价值迭代的实现\n", 440 | "\n", 441 | "def iterate_value(env, gamma=1.0, tolerant=1e-6):\n", 442 | " v = np.zeros(env.nS)\n", 443 | " while True:\n", 444 | " delta = 0\n", 445 | " for s in range(env.nS):\n", 446 | " vmax = max(v2q(env, v, s, gamma))\n", 447 | " delta = max(delta, abs(v[s] - vmax))\n", 448 | " v[s] = vmax\n", 449 | " if delta < tolerant:\n", 450 | " break\n", 451 | " \n", 452 | " policy = np.zeros((env.nS, env.nA))\n", 453 | " for s in range(env.nS):\n", 454 | " a = np.argmax(v2q(env, v, s, gamma))\n", 455 | " policy[s][a] = 1.0\n", 456 | " return policy, v" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": 12, 462 | "metadata": {}, 463 | "outputs": [ 464 | { 465 | "name": "stdout", 466 | "output_type": "stream", 467 | "text": [ 468 | "状态价值函数 =\n", 469 | "[[0.82351232 0.82350671 0.82350281 0.82350083]\n", 470 | " [0.82351404 0. 0.52940011 0. ]\n", 471 | " [0.82351673 0.82352018 0.76469779 0. ]\n", 472 | " [0. 0.88234653 0.94117321 0. 
]]\n", 473 | "最优策略 =\n", 474 | "[[0 3 3 3]\n", 475 | " [0 0 0 0]\n", 476 | " [3 1 0 0]\n", 477 | " [0 2 1 0]]\n", 478 | "价值迭代 平均奖励:0.0\n" 479 | ] 480 | } 481 | ], 482 | "source": [ 483 | "# 利用价值迭代算法求解最优策略\n", 484 | "\n", 485 | "policy_vi, v_vi = iterate_value(env)\n", 486 | "print('状态价值函数 =')\n", 487 | "print(v_vi.reshape((4, 4)))\n", 488 | "print('最优策略 =')\n", 489 | "print(np.argmax(policy_vi, axis=1).reshape(4, 4))\n", 490 | "episode_rewards = [play_policy(env, policy_vi) for _ in range(100)]\n", 491 | "print('价值迭代 平均奖励:{}'.format(np.mean(episode_rewards)))" 492 | ] 493 | } 494 | ], 495 | "metadata": { 496 | "kernelspec": { 497 | "display_name": "Python 3", 498 | "language": "python", 499 | "name": "python3" 500 | }, 501 | "language_info": { 502 | "codemirror_mode": { 503 | "name": "ipython", 504 | "version": 3 505 | }, 506 | "file_extension": ".py", 507 | "mimetype": "text/x-python", 508 | "name": "python", 509 | "nbconvert_exporter": "python", 510 | "pygments_lexer": "ipython3", 511 | "version": "3.7.6" 512 | } 513 | }, 514 | "nbformat": 4, 515 | "nbformat_minor": 4 516 | } 517 | -------------------------------------------------------------------------------- /Chapter4-回合更新价值迭代/4.3-案例:21点游戏.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import gym\n", 10 | "import numpy as np\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "\n", 13 | "from tqdm.notebook import tqdm\n", 14 | "\n", 15 | "np.random.seed(0)" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 2, 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "观测空间: Tuple(Discrete(32), Discrete(11), Discrete(2))\n", 28 | "动作空间: Discrete(2)\n" 29 | ] 30 | } 31 | ], 32 | "source": [ 33 | "env = gym.make('Blackjack-v0')\n", 34 | "env.seed(0)\n", 35 | 
"\n", 36 | "print('观测空间: {}'.format(env.observation_space))\n", 37 | "print('动作空间: {}'.format(env.action_space))" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 3, 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "name": "stdout", 47 | "output_type": "stream", 48 | "text": [ 49 | "观测 = (18, 1, False)\n", 50 | "玩家 = [10, 8], 庄家 = [1, 7]\n", 51 | "动作 = 0\n", 52 | "观测 = (18, 1, False), 奖励 = 0.0, 结束指示 = True\n" 53 | ] 54 | } 55 | ], 56 | "source": [ 57 | "# 用随机策略玩一个回合\n", 58 | "\n", 59 | "observation = env.reset()\n", 60 | "print('观测 = {}'.format(observation))\n", 61 | "\n", 62 | "while True:\n", 63 | " print('玩家 = {}, 庄家 = {}'.format(env.player, env.dealer))\n", 64 | " action = np.random.choice(env.action_space.n)\n", 65 | " print('动作 = {}'.format(action))\n", 66 | " observation, reward, done, _ = env.step(action)\n", 67 | " print('观测 = {}, 奖励 = {}, 结束指示 = {}'.format(observation, reward, done))\n", 68 | " \n", 69 | " if done:\n", 70 | " break" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 4, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# 从观测到状态\n", 80 | "\n", 81 | "def ob2state(observation):\n", 82 | " return (observation[0], observation[1], int(observation[2]))" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 5, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "# 同策回合更新策略评估\n", 92 | "\n", 93 | "def evaluate_action_monte_carlo(env, policy, episode_num=500000):\n", 94 | " q = np.zeros_like(policy)\n", 95 | " c = np.zeros_like(policy)\n", 96 | " \n", 97 | " for _ in tqdm(range(episode_num)):\n", 98 | " state_actions = []\n", 99 | " observation = env.reset()\n", 100 | " while True:\n", 101 | " state = ob2state(observation)\n", 102 | " action = np.random.choice(env.action_space.n, p=policy[state])\n", 103 | " state_actions.append((state, action))\n", 104 | " observation, reward, done, _ = env.step(action)\n", 105 | " \n", 106 | " if done:\n", 107 | " break\n", 108 
| " \n", 109 | " g = reward\n", 110 | " for state, action in state_actions:\n", 111 | " c[state][action] += 1\n", 112 | " q[state][action] += (g - q[state][action]) / c[state][action]\n", 113 | " return q" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 6, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "application/vnd.jupyter.widget-view+json": { 124 | "model_id": "55843890a9a04fe6b676c26cf69065f1", 125 | "version_major": 2, 126 | "version_minor": 0 127 | }, 128 | "text/plain": [ 129 | "HBox(children=(FloatProgress(value=0.0, max=500000.0), HTML(value='')))" 130 | ] 131 | }, 132 | "metadata": {}, 133 | "output_type": "display_data" 134 | }, 135 | { 136 | "name": "stdout", 137 | "output_type": "stream", 138 | "text": [ 139 | "\n" 140 | ] 141 | } 142 | ], 143 | "source": [ 144 | "policy = np.zeros((22, 11, 2, 2))\n", 145 | "policy[20:, :, :, 0] = 1 # >=20时不再要牌\n", 146 | "policy[:20, :, :, 1] = 1 # <20时再要牌\n", 147 | "q = evaluate_action_monte_carlo(env, policy)\n", 148 | "v = (q * policy).sum(axis=-1)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 7, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "# 绘制最后一维的指标为0或1的3维数组\n", 158 | "def plot(data):\n", 159 | " fig, axes = plt.subplots(1, 2, figsize=(9, 4))\n", 160 | " titles = ['without ace', 'with ace']\n", 161 | " have_aces = [0, 1]\n", 162 | " extent = [12, 22, 1, 11]\n", 163 | " \n", 164 | " for title, have_ace, axis in zip(titles, have_aces, axes):\n", 165 | " dat = data[extent[0]:extent[1], extent[2]:extent[3], have_ace].T\n", 166 | " axis.imshow(dat, extent=extent, origin='lower')\n", 167 | " axis.set_xlabel('player sum')\n", 168 | " axis.set_ylabel('dealer showing')\n", 169 | " axis.set_title(title)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 8, 175 | "metadata": {}, 176 | "outputs": [ 177 | { 178 | "data": { 179 | "image/png": "\n", 180 | "text/plain": [ 181 | "
" 182 | ] 183 | }, 184 | "metadata": { 185 | "needs_background": "light" 186 | }, 187 | "output_type": "display_data" 188 | } 189 | ], 190 | "source": [ 191 | "plot(v)" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 9, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "# 带起始探索的同策回合更新\n", 201 | "\n", 202 | "def monte_carlo_with_exploring_start(env, episode_num=500000):\n", 203 | " policy = np.zeros((22, 11, 2, 2))\n", 204 | " policy[:, :, :, 1] = 1.\n", 205 | " q = np.zeros_like(policy)\n", 206 | " c = np.zeros_like(policy)\n", 207 | " \n", 208 | " for _ in tqdm(range(episode_num)):\n", 209 | " # 随机选择起始状态和起始动作\n", 210 | " state = (np.random.randint(12, 22),\n", 211 | " np.random.randint(1, 11),\n", 212 | " np.random.randint(2))\n", 213 | " action = np.random.randint(2)\n", 214 | " \n", 215 | " env.reset()\n", 216 | " if state[2]: # 有A\n", 217 | " env.player = [1, state[0] - 11]\n", 218 | " else: # 没有A\n", 219 | " if state[0] == 21:\n", 220 | " env.player = [10, 9, 2]\n", 221 | " else:\n", 222 | " env.player = [10, state[0] - 10]\n", 223 | " env.dealer[0] = state[1]\n", 224 | " \n", 225 | " state_actions = []\n", 226 | " while True:\n", 227 | " state_actions.append((state, action))\n", 228 | " observation, reward, done, _ = env.step(action)\n", 229 | " \n", 230 | " if done:\n", 231 | " break\n", 232 | " \n", 233 | " state = ob2state(observation)\n", 234 | " action = np.random.choice(env.action_space.n, p=policy[state])\n", 235 | " \n", 236 | " g = reward\n", 237 | " for state, action in state_actions:\n", 238 | " c[state][action] += 1.\n", 239 | " q[state][action] += (g - q[state][action]) / c[state][action]\n", 240 | " a = q[state].argmax()\n", 241 | " policy[state] = 0.\n", 242 | " policy[state][a] = 1.\n", 243 | " \n", 244 | " return policy, q" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 10, 250 | "metadata": {}, 251 | "outputs": [ 252 | { 253 | "data": { 254 | 
"application/vnd.jupyter.widget-view+json": { 255 | "model_id": "0c3eb0a7a75c4a398889b3fd014b1b2c", 256 | "version_major": 2, 257 | "version_minor": 0 258 | }, 259 | "text/plain": [ 260 | "HBox(children=(FloatProgress(value=0.0, max=500000.0), HTML(value='')))" 261 | ] 262 | }, 263 | "metadata": {}, 264 | "output_type": "display_data" 265 | }, 266 | { 267 | "name": "stdout", 268 | "output_type": "stream", 269 | "text": [ 270 | "\n" 271 | ] 272 | } 273 | ], 274 | "source": [ 275 | "policy, q = monte_carlo_with_exploring_start(env)\n", 276 | "v = q.max(axis=-1)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 11, 282 | "metadata": {}, 283 | "outputs": [ 284 | { 285 | "data": { 286 | "image/png": "\n", 287 | "text/plain": [ 288 | "
" 289 | ] 290 | }, 291 | "metadata": { 292 | "needs_background": "light" 293 | }, 294 | "output_type": "display_data" 295 | }, 296 | { 297 | "data": { 298 | "image/png": "\n", 299 | "text/plain": [ 300 | "
" 301 | ] 302 | }, 303 | "metadata": { 304 | "needs_background": "light" 305 | }, 306 | "output_type": "display_data" 307 | } 308 | ], 309 | "source": [ 310 | "plot(policy.argmax(-1))\n", 311 | "plot(v)" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 12, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "# 基于柔性策略的同策回合更新\n", 321 | "\n", 322 | "def monte_carlo_with_soft(env, episode_num=500000, epsilon=0.1):\n", 323 | " policy = np.ones((22, 11, 2, 2)) * 0.5\n", 324 | " q = np.zeros_like(policy)\n", 325 | " c = np.zeros_like(policy)\n", 326 | " \n", 327 | " for _ in tqdm(range(episode_num)):\n", 328 | " state_actions = []\n", 329 | " observation = env.reset()\n", 330 | " \n", 331 | " while True:\n", 332 | " state = ob2state(observation)\n", 333 | " action = np.random.choice(env.action_space.n, p=policy[state])\n", 334 | " state_actions.append([state, action])\n", 335 | " \n", 336 | " observation, reward, done, _ = env.step(action)\n", 337 | " if done:\n", 338 | " break\n", 339 | " \n", 340 | " g = reward\n", 341 | " for state, action in state_actions:\n", 342 | " c[state][action] += 1\n", 343 | " q[state][action] += (g - q[state][action]) / c[state][action]\n", 344 | " a = q[state].argmax()\n", 345 | " policy[state] = epsilon / 2.0\n", 346 | " policy[state][a] += 1.0 - epsilon\n", 347 | " return policy, q" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": 13, 353 | "metadata": {}, 354 | "outputs": [ 355 | { 356 | "data": { 357 | "application/vnd.jupyter.widget-view+json": { 358 | "model_id": "2d2dd23f407549b6954d33d9344585ee", 359 | "version_major": 2, 360 | "version_minor": 0 361 | }, 362 | "text/plain": [ 363 | "HBox(children=(FloatProgress(value=0.0, max=500000.0), HTML(value='')))" 364 | ] 365 | }, 366 | "metadata": {}, 367 | "output_type": "display_data" 368 | }, 369 | { 370 | "name": "stdout", 371 | "output_type": "stream", 372 | "text": [ 373 | "\n" 374 | ] 375 | }, 376 | { 377 | "data": 
{ 378 | "image/png": "\n", 379 | "text/plain": [ 380 | "
" 381 | ] 382 | }, 383 | "metadata": { 384 | "needs_background": "light" 385 | }, 386 | "output_type": "display_data" 387 | }, 388 | { 389 | "data": { 390 | "image/png": "\n", 391 | "text/plain": [ 392 | "
" 393 | ] 394 | }, 395 | "metadata": { 396 | "needs_background": "light" 397 | }, 398 | "output_type": "display_data" 399 | } 400 | ], 401 | "source": [ 402 | "policy, q = monte_carlo_with_soft(env)\n", 403 | "v = q.max(axis=-1)\n", 404 | "plot(policy.argmax(-1))\n", 405 | "plot(v)" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": 14, 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [ 414 | "# 重要性采样策略评估\n", 415 | "\n", 416 | "def evaluate_monte_carlo_with_importance_resample(env, policy, behavior_policy, episode_num=500000):\n", 417 | " q = np.zeros_like(policy)\n", 418 | " c = np.zeros_like(policy)\n", 419 | " \n", 420 | " for _ in tqdm(range(episode_num)):\n", 421 | " state_actions = []\n", 422 | " observation = env.reset()\n", 423 | " \n", 424 | " while True:\n", 425 | " state = ob2state(observation)\n", 426 | " action = np.random.choice(env.action_space.n, p=behavior_policy[state])\n", 427 | " state_actions.append([state, action])\n", 428 | " \n", 429 | " observation, reward, done, _ = env.step(action)\n", 430 | " if done:\n", 431 | " break\n", 432 | " \n", 433 | " g = reward\n", 434 | " rho = 1.0 # 重要性采样比率\n", 435 | " for state, action in reversed(state_actions): # backward pass: rho at step t covers steps t+1..T-1\n", 436 | " c[state][action] += rho\n", 437 | " q[state][action] += (rho / c[state][action] * (g - q[state][action]))\n", 438 | " rho *= (policy[state][action] / behavior_policy[state][action])\n", 439 | " if rho == 0:\n", 440 | " break\n", 441 | " return q" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": 15, 447 | "metadata": {}, 448 | "outputs": [ 449 | { 450 | "data": { 451 | "application/vnd.jupyter.widget-view+json": { 452 | "model_id": "4e4bbd925eef4420b90a066c60bb73f5", 453 | "version_major": 2, 454 | "version_minor": 0 455 | }, 456 | "text/plain": [ 457 | "HBox(children=(FloatProgress(value=0.0, max=500000.0), HTML(value='')))" 458 | ] 459 | }, 460 | "metadata": {}, 461 | "output_type": "display_data" 462 | }, 463 | { 464 | "name":
"stdout", 465 | "output_type": "stream", 466 | "text": [ 467 | "\n" 468 | ] 469 | }, 470 | { 471 | "data": { 472 | "image/png": "\n", 473 | "text/plain": [ 474 | "
" 475 | ] 476 | }, 477 | "metadata": { 478 | "needs_background": "light" 479 | }, 480 | "output_type": "display_data" 481 | } 482 | ], 483 | "source": [ 484 | "policy = np.zeros((22, 11, 2, 2))\n", 485 | "policy[20:, :, :, 0] = 1 # >= 20 时收手\n", 486 | "policy[:20, :, :, 1] = 1 # < 20 时继续\n", 487 | "behavior_policy = np.ones_like(policy) * 0.5\n", 488 | "q = evaluate_monte_carlo_with_importance_resample(env, policy, behavior_policy)\n", 489 | "v = (q * policy).sum(axis=-1)\n", 490 | "plot(v)" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": 16, 496 | "metadata": {}, 497 | "outputs": [], 498 | "source": [ 499 | "def monte_carlo_importance_resample(env, episode_num=500000):\n", 500 | " policy = np.zeros((22, 11, 2, 2))\n", 501 | " policy[:, :, :, 0] = 1.0\n", 502 | " behavior_policy = np.ones_like(policy) * 0.5\n", 503 | " q = np.zeros_like(policy)\n", 504 | " c = np.zeros_like(policy)\n", 505 | " \n", 506 | " for _ in tqdm(range(episode_num)):\n", 507 | " state_actions = []\n", 508 | " observation = env.reset()\n", 509 | " \n", 510 | " while True:\n", 511 | " state = ob2state(observation)\n", 512 | " action = np.random.choice(env.action_space.n, p=behavior_policy[state])\n", 513 | " state_actions.append([state, action])\n", 514 | " \n", 515 | " observation, reward, done, _ = env.step(action)\n", 516 | " if done:\n", 517 | " break\n", 518 | " \n", 519 | " g = reward\n", 520 | " rho = 1.0\n", 521 | " for state, action in reversed(state_actions): # backward pass from the episode end; stop once the greedy policy diverges\n", 522 | " c[state][action] += rho\n", 523 | " q[state][action] += (rho / c[state][action] * (g - q[state][action]))\n", 524 | " a = q[state].argmax()\n", 525 | " policy[state] = 0.0\n", 526 | " policy[state][a] = 1.0\n", 527 | " if a != action:\n", 528 | " break\n", 529 | " rho /= behavior_policy[state][action]\n", 530 | "\n", 531 | " return policy, q" 532 | ] 533 | }, 534 | { 535 | "cell_type": "code", 536 | "execution_count": 17, 537 | "metadata": {}, 538 | "outputs": [ 539 | { 540 | "data": { 541 |
"application/vnd.jupyter.widget-view+json": { 542 | "model_id": "b18769817b104b9c81e177a6259ae9be", 543 | "version_major": 2, 544 | "version_minor": 0 545 | }, 546 | "text/plain": [ 547 | "HBox(children=(FloatProgress(value=0.0, max=500000.0), HTML(value='')))" 548 | ] 549 | }, 550 | "metadata": {}, 551 | "output_type": "display_data" 552 | }, 553 | { 554 | "name": "stdout", 555 | "output_type": "stream", 556 | "text": [ 557 | "\n" 558 | ] 559 | }, 560 | { 561 | "data": { 562 | "image/png": "\n", 563 | "text/plain": [ 564 | "
" 565 | ] 566 | }, 567 | "metadata": { 568 | "needs_background": "light" 569 | }, 570 | "output_type": "display_data" 571 | }, 572 | { 573 | "data": { 574 | "image/png": "\n", 575 | "text/plain": [ 576 | "
" 577 | ] 578 | }, 579 | "metadata": { 580 | "needs_background": "light" 581 | }, 582 | "output_type": "display_data" 583 | } 584 | ], 585 | "source": [ 586 | "policy, q = monte_carlo_importance_resample(env)\n", 587 | "v = q.max(axis=-1)\n", 588 | "plot(policy.argmax(-1))\n", 589 | "plot(v)" 590 | ] 591 | } 592 | ], 593 | "metadata": { 594 | "kernelspec": { 595 | "display_name": "Python 3", 596 | "language": "python", 597 | "name": "python3" 598 | }, 599 | "language_info": { 600 | "codemirror_mode": { 601 | "name": "ipython", 602 | "version": 3 603 | }, 604 | "file_extension": ".py", 605 | "mimetype": "text/x-python", 606 | "name": "python", 607 | "nbconvert_exporter": "python", 608 | "pygments_lexer": "ipython3", 609 | "version": "3.7.6" 610 | } 611 | }, 612 | "nbformat": 4, 613 | "nbformat_minor": 4 614 | } 615 | -------------------------------------------------------------------------------- /Chapter7-回合更新策略梯度方法/7.5-案例:车杆平衡.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 33, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "" 12 | ] 13 | }, 14 | "execution_count": 33, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | "import gym\n", 21 | "import torch\n", 22 | "import numpy as np\n", 23 | "import pandas as pd\n", 24 | "import torch.nn as nn\n", 25 | "import torch.optim as optim\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "import torch.nn.functional as F\n", 28 | "\n", 29 | "from tqdm.notebook import tqdm\n", 30 | "\n", 31 | "np.random.seed(0)\n", 32 | "torch.manual_seed(0)" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 4, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "[0]" 44 | ] 45 | }, 46 | "execution_count": 4, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | } 50 | ], 51 | 
"source": [ 52 | "env = gym.make('CartPole-v0')\n", 53 | "env.seed(0)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 7, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "class DQN(nn.Module):\n", 63 | " \n", 64 | " def __init__(self, input_size, hidden_sizes, output_size, softmax=True):\n", 65 | " super(DQN, self).__init__()\n", 66 | " \n", 67 | " neurons = [input_size] + hidden_sizes\n", 68 | " layers = []\n", 69 | " for i in range(len(neurons) - 1):\n", 70 | " layers.append(nn.Linear(neurons[i], neurons[i + 1]))\n", 71 | " layers.append(nn.ReLU(inplace=True))\n", 72 | " layers.append(nn.Linear(neurons[-1], output_size))\n", 73 | " layers.append(nn.Softmax(dim=-1) if softmax else nn.Identity()) # probabilities for policies; raw output for value baselines\n", 74 | " self.net = nn.Sequential(*layers)\n", 75 | " return\n", 76 | "\n", 77 | " def forward(self, x):\n", 78 | " return self.net(x)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 50, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "# 同策策略梯度算法智能体类\n", 88 | "\n", 89 | "class VPGAgent(object):\n", 90 | " \n", 91 | " def __init__(self, env, policy_kwargs, baseline_kwargs=None, gamma=0.99):\n", 92 | " observation_dim = env.observation_space.shape[0]\n", 93 | " self.action_n = env.action_space.n\n", 94 | " self.gamma = gamma\n", 95 | " \n", 96 | " self.trajectory = []\n", 97 | "\n", 98 | " self.policy_net = DQN(observation_dim, policy_kwargs['hidden_sizes'], self.action_n)\n", 99 | " self.policy_opt = optim.Adam(self.policy_net.parameters(), lr=policy_kwargs['learning_rate'])\n", 100 | " \n", 101 | " if baseline_kwargs:\n", 102 | " self.baseline_net = DQN(observation_dim, baseline_kwargs['hidden_sizes'], 1, softmax=False)\n", 103 | " self.baseline_opt = optim.Adam(self.baseline_net.parameters(), lr=baseline_kwargs['learning_rate'])\n", 104 | " return\n", 105 | "\n", 106 | " @staticmethod\n", 107 | " def __tensor2numpy(tensor):\n", 108 | " return tensor.cpu().detach().numpy()\n", 109 | " \n", 110 | " def decide(self, observation):\n", 111 | "
self.policy_net.eval()\n", 112 | " probs = self.policy_net(torch.tensor(observation[np.newaxis]).float())\n", 113 | " probs = self.__tensor2numpy(probs)[0]\n", 114 | " action = np.random.choice(self.action_n, p=probs)\n", 115 | " return action\n", 116 | " \n", 117 | " def learn(self, observation, action, reward, done):\n", 118 | " self.trajectory.append((observation, action, reward))\n", 119 | " \n", 120 | " if done:\n", 121 | " df = pd.DataFrame(data=self.trajectory, columns=['observation', 'action', 'reward'])\n", 122 | " df['discount'] = self.gamma ** df.index.to_series()\n", 123 | " df['discounted_reward'] = df['discount'] * df['reward']\n", 124 | " df['discounted_return'] = df['discounted_reward'][::-1].cumsum()\n", 125 | " df['psi'] = df['discounted_return']\n", 126 | " \n", 127 | " x = torch.tensor(np.stack(df['observation'])).float()\n", 128 | " if hasattr(self, 'baseline_net'):\n", 129 | " self.baseline_net.eval()\n", 130 | " df['baseline'] = self.__tensor2numpy(self.baseline_net(x))[:, 0]\n", 131 | " df['psi'] -= df['baseline'] * df['discount']\n", 132 | " df['return'] = df['discounted_return'] / df['discount']\n", 133 | " y = torch.tensor(df['return'].values[:, np.newaxis]).float()\n", 134 | " self.baseline_net.train()\n", 135 | " y_hat = self.baseline_net(x)\n", 136 | " loss = F.mse_loss(y_hat, y) # regress the baseline onto the observed return\n", 137 | " self.baseline_opt.zero_grad()\n", 138 | " loss.backward()\n", 139 | " self.baseline_opt.step()\n", 140 | " \n", 141 | " y = torch.tensor(np.eye(self.action_n)[df['action']] * df['psi'].values[:, np.newaxis]).float()\n", 142 | " self.policy_net.train()\n", 143 | " y_hat = self.policy_net(x)\n", 144 | " loss = -(y * torch.log(y_hat.clamp(min=1e-7))).sum(dim=-1).mean() # REINFORCE: -psi * log pi(a|s), i.e. categorical cross-entropy on the psi-weighted one-hot target\n", 145 | " self.policy_opt.zero_grad()\n", 146 | " loss.backward()\n", 147 | " self.policy_opt.step()\n", 148 | "\n", 149 | " self.trajectory = []\n", 150 | " return" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 16, 156 | "metadata": {}, 157 |
"outputs": [], 158 | "source": [ 159 | "# 智能体和环境交互一个回合的代码\n", 160 | "\n", 161 | "def play_montecarlo(env, agent, render=False, train=False):\n", 162 | " episode_reward = 0.0\n", 163 | " observation = env.reset()\n", 164 | " \n", 165 | " while True:\n", 166 | " if render:\n", 167 | " env.render()\n", 168 | " \n", 169 | " action = agent.decide(observation)\n", 170 | " next_observation, reward, done, _ = env.step(action)\n", 171 | " episode_reward += reward\n", 172 | " \n", 173 | " if train:\n", 174 | " agent.learn(observation, action, reward, done)\n", 175 | " if done:\n", 176 | " break\n", 177 | " \n", 178 | " observation = next_observation\n", 179 | " \n", 180 | " return episode_reward" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 51, 186 | "metadata": {}, 187 | "outputs": [ 188 | { 189 | "data": { 190 | "application/vnd.jupyter.widget-view+json": { 191 | "model_id": "7bd25b5f62084b0fbcef1f099464b9f4", 192 | "version_major": 2, 193 | "version_minor": 0 194 | }, 195 | "text/plain": [ 196 | "HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))" 197 | ] 198 | }, 199 | "metadata": {}, 200 | "output_type": "display_data" 201 | }, 202 | { 203 | "name": "stdout", 204 | "output_type": "stream", 205 | "text": [ 206 | "\n" 207 | ] 208 | }, 209 | { 210 | "name": "stderr", 211 | "output_type": "stream", 212 | "text": [ 213 | "d:\\programdata\\miniconda3\\envs\\rl\\lib\\site-packages\\torch\\nn\\modules\\container.py:100: UserWarning: Implicit dimension choice for softmax has been deprecated. 
Change the call to include dim=X as an argument.\n", 214 | " input = module(input)\n" 215 | ] 216 | }, 217 | { 218 | "name": "stdout", 219 | "output_type": "stream", 220 | "text": [ 221 | "平均回合奖励 = 950.0 / 100 = 9.5\n" 222 | ] 223 | }, 224 | { 225 | "data": { 226 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAftUlEQVR4nO3deXhU9d338fc3K2GJbIFSQAF3am8DRmqt2s2Fiq1LXUp9Kk+1tb3v9lJ7L62199O7u7TVamvdqxWfy7rUpfq4tUhFRCgYIGyyJEAQJEAgskNCkt/zx5wJk1kyM8ks+dnP67pyZebMWb7nd8585jdnzpwx5xwiIuKfgnwXICIi3aMAFxHxlAJcRMRTCnAREU8pwEVEPFWUy4UNHTrUjRkzJpeLFBHx3qJFi3Y45yqih+c0wMeMGUN1dXUuFyki4j0z2xhvuA6hiIh4SgEuIuIpBbiIiKcU4CIinlKAi4h4SgEuIuIpBbiIiKe8CfDq+ibWbN2b7zJERHqNnH6Rpycuv28+APXTp+S5EhGR3sGbHriIiHSmABcR8ZQCXETEUwpwERFPKcBFRDylABcR8ZQCXETEUwpwERFPKcBFRDylABcR8ZQCXETEUwpwERFPKcBFRDylABcR8ZQCXETEUwpwERFPJQ1wM+tjZgvNbKmZrTSzHwfDHzGzDWZWE/xVZr9cEREJS+UXeZqBzzjn9plZMTDXzF4JHvsv59zT2StPREQSSRrgzjkH7AvuFgd/LptFiYhIcikdAzezQjOrAbYDM51zC4KHfm5my8zsDjMrTTDt9WZWbWbVjY2NGSpbRERSCnDnXJtzrhIYBUwys1OA7wMnAacDg4HvJZj2AedclXOuqqKiIkNli4hIWmehOOd2AbOByc65BhfSDPwRmJSF+kREJIFUzkKpMLOBwe0y4FxgtZmNCIYZcAmwIpuFiohIZ6mchTICmGFmhYQC/ynn3Itm9nczqwAMqAG+mcU6RUQkSipnoSwDJsQZ/pmsVCQiIinRNzFFRDylABcR8ZQCXETEUwpwERFPKcBFRDylABcR8ZQCXETEUwpwERFPKcBFRDylABcR8ZQCXETEUwpwERFPKcBFRDylABcR8ZQCXETEUwpwERFPKcBFRDylABcR8ZQCXETEUwpwERFPKcBFRDylABcR8ZQCXETEUwpwERFPJQ1wM+tjZgvNbKmZrTSzHwfDx5rZAjOrNbMnzawk++WKiEhYKj3wZuAzzrlTgUpgspmdAfwSuMM5dzzwPnBd9soUEZFoSQPchewL7hYHfw74DPB0MHwGcElWKhQRkbhSOgZuZoVmVgNsB2YC64BdzrnWYJTNwMgE015vZtVmVt3Y2JiJmkVEhBQD3DnX5pyrBEYBk4CT442WYNoHnHNVzrmqioqK7lcqIiKdpHUWinNuFzAbOAMYaGZFwUOjgC2ZLU1ERLqSylkoFWY2MLhdBpwLrAJeBy4PRpsGPJ+tIkVEJFZR8lEYAcwws0JCgf+Uc+5FM3sHeMLMfgYsAR7KYp0iIhIlaYA755YBE+IMX0/oeLiIiOSBvokpIuIpBbiIiKcU4CIinlKAi4h4SgEuIuIpBbiIiKcU4CIinlKAi4h4SgEuIuIpBbiIiKcU4CIinlKAi4h4SgEuIuIpBbiIiKcU4CIinlKAi4h4SgEuIuIpBbiIiKcU4CIinvIuwPceOpzv
EkREegXvAvzbf1qS7xJERHoF7wK8bvu+fJcgItIreBfgIiISogAXEfFU0gA3s9Fm9rqZrTKzlWZ2YzD8R2b2npnVBH8XZr9cEREJK0phnFbgP5xzi81sALDIzGYGj93hnLste+WJiEgiSQPcOdcANAS395rZKmBktguLdLClLZeLExHxQlrHwM1sDDABWBAM+raZLTOzh81sUIJprjezajOrbmxs7FaRNZt2dWs6EZEPspQD3Mz6A88ANznn9gD3AscClYR66LfHm84594Bzrso5V1VRUdG9Ii2yjm7NQkTkAyelADezYkLh/Zhz7lkA59w251ybc64deBCYlLUiC5TaIiLRUjkLxYCHgFXOud9EDB8RMdqlwIrMlxei/BYRiZXKWSifAL4CLDezmmDYLcBUM6sEHFAPfCMrFQKm4yYiIjFSOQtlLhAvQV/OfDnxFSjARURiePFNTMW3iEgsLwJcPXARkVh+BHhElcpyEZEQPwJcqS0iEkMBLiLiKS8CXEREYnkR4A7Xcdt0ToqICOBLgLvk44iI/LNRgIuIeMqPAEcJLiISzY8AV36LiMTwIsAj6YxCEZEQLwJcPXARkVh+BHjEMfCNOw/ksRIRkd7DjwBXD1xEJIYfAZ7vAkREeiE/AlxdcBGRGH4EeL4LEBHphfwIcCW4iEgMLwJcfXARkVheBLh64CIisfwI8HwXICLSC/kR4EpwEZEYngS4ElxEJFrSADez0Wb2upmtMrOVZnZjMHywmc00s9rg/6BsFan4FhGJlUoPvBX4D+fcycAZwLfMbDxwMzDLOXc8MCu4nxXqgIuIxEoa4M65Bufc4uD2XmAVMBK4GJgRjDYDuCRbRc5esz1bsxYR8VZax8DNbAwwAVgADHfONUAo5IFhCaa53syqzay6sbGxW0XeP2d9t6YTEfkgSznAzaw/8Axwk3NuT6rTOececM5VOeeqKioqulOjiIjEkVKAm1kxofB+zDn3bDB4m5mNCB4fAeg4h4hIDqVyFooBDwGrnHO/iXjoBWBacHsa8HzmyxMRkUSKUhjnE8BXgOVmVhMMuwWYDjxlZtcB7wJXZKdEERGJJ2mAO+fmAol+SvizmS1HRERS5cU3MauOydp3hEREvOVFgJ//keH5LkFEpNfxIsD1TUwRkVh+BHi+CxAR6YX8CHAluIhIDC8CXEREYnkR4E4HUUREYvgR4MpvEZEYXgS4iIjEUoCLiHjKiwDXb2KKiMTyJMDzXYGISO/jRYCLiEgsLwJcHXARkVh+BLgSXEQkhh8Brj64iEgMLwJcRERieRHgOoQiIhLLjwDPdwEiIr2QFwGuLriISCw/AlxERGJ4EeDqf4uIxPIjwJXgIiIxvAhwERGJlTTAzexhM9tuZisihv3IzN4zs5rg78JsFqkv8oiIxEqlB/4IMDnO8Ducc5XB38uZLaszHUIREYmVNMCdc3OAphzUkriGfC5cRKSX6skx8G+b2bLgEMugRCOZ2fVmVm1m1Y2NjT1YnIiIROpugN8LHAtUAg3A7YlGdM494Jyrcs5VVVRUdGthOoQiIhKrWwHunNvmnGtzzrUDDwKTMltW1PJ0EEVEJEa3AtzMRkTcvRRYkWjcjFB+i4jEKEo2gpk9DnwKGGpmm4H/AT5lZpWEorUe+EYWaxQRkTiSBrhzbmqcwQ9loZbENeRyYSIinvDim5hOn2KKiMTwJMDzXYGISO/jRYCLiEgsLwJcHXARkVh+BLgSXEQkhhcBLiIisbwI8OhvYlbX5/XaWiIivYIfAR51COXy++bnpxARkV7EiwAHOKqsmOvOGpvvMkREeg1vAtwMWtva812GiEiv4UWAh7+J2dqu01FERML8CHDAgNY2BbiISJgfAe7AzDjcrkMoIiJhXgQ4qAcuIhLNiwD/8seO5rYrT6VVPXARkQ5eBPjJI8r59InDOKweuIhIBy8CPEynEYqIHOFXgOs0QhGRDl4F+GH1wEVEOngV4JFnoexrbgVg9dY9
PF/zHgBPvb2JDTv2pzXPt+p2MLd2BwA79jXz4Jz13DO7rmP+T1VvYn3jPt6ub+Kahxeyfe8hIPTlogfnrOf9/S0d82rc28xDczd0+gm4J99+l/ouatrf3Mrdr9fR1sW7i01NB/j8XXPZ1HQg5rF9za3cO3sds1ZtY+GG2It8bd97iIeDmhZtfJ+Z72xL0iKJPfV2qC3umlXLra+sorm1rcvx/756G799rZY3axuBUJvdNauWFe/tTjjNnxa822k9D7a0cffrdXEPnzXsPshvX6vlwTnrY352r6vpwl5YuoVVDXsAWLB+J2+sbexyfeat28GctY3Mrd3BW3U7Yh5f/G5s+7a1O+6ZXcej8+up274vZppXljewbPMuAF57Zxs/emFlR03NraF1aGntvA5vrG1k/rqdHfdnrdpGdX0T97+xjt0HDne5Djv2NXPjE0t4dcXWhOMcOhx/uQBrtu7lL0veiztd5L4Goe19/xvr2H2w65oS2d/cyj2z69h94DD3zl5Hew/fgbe2tfOtxxbz9KLNnYaH1zedDuJzSzbzVt0OLv79XL7+aHXCn310zvHAnHXsOtAS9/GeSvqjxr3J4YgNOPOdrVw6YRST73wTgIsrR/LdZ5YxoLSI5T++IOV5Xv2HBQDUT5/CTU/UMDd4Ym5qOsitl32U7z69jNKiApqDnflbjy3mz988k0Ub3+fnL6/i7fomHrimCoAbn1jCvHU7Oeu4oZz4oQG0tzu+98xyBvYtpuaH58dd/q9eXc2M+Rs5enBfPn/qh+OOc9Fdc9l98DBfvHceC39wbqfHbn15FY8teLfjfv30KZ0ev+mJGuat28mZxw3hi/fOiztOqr77zLJO9weWlfCvnzo24fjXPlLdqa6d+1u4feZanl+6hdf+/ZMx4ze3tnHLc8sZXl7KgltC63nna2u5f856hg0o5Yqq0Z3G/9qMalZuCYXdxGMGctoxgzse+/3rtdz9+joG9yth6qSj49Z3w+NLOmq76oF/dNxO5MsPLuh0P3rcy+6Jbd+/rdzKr15dA0CBwfpbO0/zr48t7pjma4+G2uuRefXUT5/CH97cwK//uobSogK+dva4jmmmPbyw03Kum3Gknd9p2MNvvzQh4Tp858ka3qzdwfM1WxKu6x/eXM9tf1tLWXEh10Zdf+iCO+cAcMmEkTHT3fD4Ev6xvomzjh/KCcMH8FbdTm59ZTUrt+zhd1MT15TIr/+6hkfm1XP33+vY39LGccP6c9744WnPJ+yZxZt5aXkDLy1v4PLTRnUMv2f2On43q5bysmK+csYxKc3rO08ujbi3m/nrdnLmcUNjxlu4oYlfvLyamk27uOfq07pdeyKe9cCPvEIebu38ihd+dd4b9Jy7Y2dEb3pPRK+hOaInEh4nPGzvoSPLC/c0wj3T8BePdnXRKwpPf+hw4t5seL7vx3kV359kfcPzbz7cs8NP8XoYXdUcT/hdRryeaGgZof879h1Zz/A7oUNxeoOR7doc9fiBllBtydon21oi9tl0O5AHg3UIr0sqkq3vnhR6w91tu+h9raWtLRjevR54eNvvD+pJ9o4vmUTteCC8j6XRztESfT4X3i+7+y4kGa8CvKvDDG0Z+NmeyJBqdy7uW7bwKOH/ZkceKyywYNrQ/Uyf9mhYzLCCgthhkcI19fQD4HhtX5hk2anMI97j7Sluy07Lj5qk0MLbIr8ffBdYem0UKTxpOquQbHnJ9heI3Y9TVdSxr2Xms6o0d6+cideZ6cl27gmvAryrY1TJwiEVkfNoa3dxXxTC44R/ZCJyu4U3YniclE57DKa3VHaAOKMUJpmuqCAzQRavLTIe4B3HTlObX+TyoycJP5bvz73TbaNI4Smjf9CkJ8tLtr9AxH7cxYbo6gU9Uy+a0aEYrwOTjkxFbLzOUKJmD69CT2tPJGmAm9nDZrbdzFZEDBtsZjPNrDb4Pygr1UXpqheZiZ2mLaoHHm8n7QjwcA88YsNE78AtOUiPpE/YcK+o
h+8G4nWq0u11JNtG6X5I1akDHjVpQYbDpLt61DMLpk2rB55kf0irB97F9ojXmSoqCMVJpi55kVKnJg/irXu+ak2lB/4IMDlq2M3ALOfc8cCs4H7WRe4Y7c51eisTGbYu4rHwTtje7mKGR0/THt0Dj7MDhwMh8pH29tC04edGuM7oT/Ej63VR9YdfMKLr7vR2zYXut0U8Fv2EDK9HeLqiwq7f1obHDa9D9HLDj8WbvsA6t2v0+kWKbs/oaULLSDw9EesVrrNzD7xz3eGH2qLWK+H847RHuL5EQRbd1pHzCIt+ge20zdvj3+4YL2IbJGrj6GUXmnV5RkSi/I5c3462i5h/9DwPt7V32meccx29zdaOfSh2+fFqiv4LPw+ia418JxLvee2i6o3XZvFqcR3/O69r9LzD6xrv7JzwurfH7OPB8AR19JSlMlMzGwO86Jw7Jbi/BviUc67BzEYAs51zJyabT1VVlauurk42WkJX3j8/7qlykp766VO4+ZllPPH2pozP+5SR5ax4b0/G55uqPsUFHOrhB7aZ9PFxQ5i/fmfyESUv6qdPYczNL+VkWQ9eU9Xts2jMbJFzrip6eHePgQ93zjUABP+HdbHg682s2syqGxu7Ps82mfv/V+ZPw/lnlY3wBvIa3kCvCm9A4S0d+pdm/qztrH+I6Zx7wDlX5Zyrqqio6NG8BvUryVBVIiK5dfzw/hmfZ3cDfFtw6ITg//bMlSQi8sFTXJj5/nJ35/gCMC24PQ14PjPliIh8MJXkI8DN7HFgPnCimW02s+uA6cB5ZlYLnBfcFxGRBIoLM3+qYdKj6s65qQke+myGa5EceXjuhnyXINIrfPq22TlbVk++1JWIV9/EBLgszkV0JD0/efGdfJcg0iuke/XS7hrSryQrX/bxLsB/c1Ul9dOnMGnMkSvPnTd+OCvTuALhtI8fQ/30Kdx5VWXMYxOPHtjltN/45LhO91++4Wxqf/65jvvjR5R33P6vC5KeGt+hqMConz6FFT++gLsSXLkt/Ar+p699LOF8XrrhrISP3XLhSV3WUBexHl1Z87Po73Ul9o1zxnHmsUMSPv7Hr56e1rYDGFfRj9U/ncyqn6ReRyrOOSH+WVIfPqoPd1x1akaXFfbcv50ZM+yVG89OadoPlffh7YirUw7tX8JtV/S8zsrRXT8HIrf/kODMsIemxZyi3G0r0twfMiHePnjphJF8dORRHffrfv453vnJBSxNcGXReO646lTW/uxzzPv+ZzJSZzSvLicbKfIr0uV9iumXxjmW4XHjvSD2KS7sdN+s8zfKjior7vR4SVFBp0+X+xQfud2dDy36lxZRHrWMsNKiAg60tFFclHi+5X3iTwvJPwUvSrHe0qLC5CMFkn11u7SwIO6261tSmPDqcWXFhTHbKRMSHaLsW1qU1jqnI95X7ctSXLfiIqO87EjblRQWpDxtV4qSbbOItgiXP7Bv5k7xzcax4mTi7U9FBcbgiFOXiwoLQs+RNFZ1YFkJJV08X3vKux54WOR1S7r7ziSVK+xFf1E1+ivP0cvOxnGusHC93T0dKR9Xa0h28aREF0zqKqCzddmJRNvOyMzF0npaR8x4Zp3aN1Nv0btzZc/SDIZUcUHuYylem7c51+PncyrXnunR/LM69yzq6a9zQPwnZbKLD0VfxyZ67GxeVjL8riMfPZTuSrYDJwrGTAZCqrradrm8KFaqT/qCAstKh6E7T61M9jKzHXqpiry+UXelcvXHnvA2wDNx/e94T8pkT4jo5Ub3erIZ4OGwy8b5pNl655BsB04UjPm4vnJXbZDTHniK615olpUPxrrTOcrGl1TyLRMXVcz2mwlvWz3yOFy6r/7haeMefyzp/NY9+nhg9P3oJ1tpxDHwojR6ypHLTXQMsm9J6HhnV8dju3o+d3XsfFAGj2FG1tCnuKDrwyEJDuwM6JP4M41MHOeNO9+S+PMd0Kcoay9w0ftggaX+pB9WXtrpfllJYUbqTOf5FG6zZMfN
fVRcYAk/U0q1E5XtHri3H2LeNXUCNzy+hPqdB/je5NDZFY9eO4mNTQd4cekW3tmyh9PGDOJgSxtnHjsUs9BPMr1/4DBTJ4V+W/HiypH8edFmzj15GAda2ljXuJ8ff+EjnDBsAG+t28GE0QP5+jnjuOXZ5cxe08iLN5zF6EF92dfcinOOtdv2MXpwGQD3Xj2R0uICKkcPYuJPZ3LuycOYOuloJo0dzL89tpiNOw8waexglm/eTdWYQXx05FGUFRdy+8y1nH38UP7n8+M71u3j44ZwZdUonqrezJvf/TRbdh3k3aYDnDp6IG+saWT04DK+OHEU7x9o4ZQPl1NWUsTW3Qe5omo0IweWcc4JFcxZ20i/ksKOn6O69hNjufy0UWzceYCrTh/Nr15dzdbdhzh5RDmvrtzKn7/58Y42/O7Ty7jryxO45dnlXHvWWA4dbmPS2ME07W+hKfhJuSevP4MZ8+t5eflWTh8ziNZ2x+G2dj570nAu+pcR/L9lDbyzZQ/TzhzDZRNHMWNePWOH9mPEwD4s3bSbX/91Ne3uyJkft11xKne/Xsdvv1RJzaZdfPrEYTxVvYm12/by15XbKCowrjp9NEP6l3ZsPwidBfT955YzsKyYNVv3MqR/CSu37OGlG86ibvs+3qzdwcvLG3ho2uk8vWgzzywO/aDtnVdVsrC+iQmjB/LLV9dw4of688OLxnPhKSNobXfUbtvLueOH88hb9Vx5+ij+ZdRA3mnYQ9/iInbub2ZVwx4G9S3h6+eM44r75nP++OFUjRnEm7U7uPpjR9PS5hg3tB/3zK7jQEsb6xv3M+3MMfy5ehOrt+7lpA8NYFxFP04ZWc5dUydQXlZM3fZ9fOK4IVT0L+Xy00axcEMTRQXGf15wIks37aKkqIAPHdWHZxe/R2u74/9cFNpnfnjReBZtfJ/vTT6JDw/sQ3mfIvZE/NTfFaeN4oxxQ3ho7gbGVfSjvKyYPy14l29/+jheWdHAkH6lHDOkLyu37GH73kPcNXUCTftbuOiuuVw2cSQvLWvg9itP5f39LVSODl36/w/XVNHmHCcOH8BLyxsYPbgvUz46gkH9QicUlBUXcudrtfzi0o9y3LD+fPWPC/nUScO4pHIkX3+0mlsuPImKAaXc+vJqzhs/nLrt+ygvK+aTwf5w51WVbNx5gHnrdvCd806grLiQrz7yNk37W7hs4khGD+rLb2fVcu/VE9nb3Ep5nyL++y8r+OnFp7Biy24+PLCMJ9/exKHDbazdFvoJv/+ecjI79rXw2IKN7D3UyuWnjeLpRZs7fq/zlgtP4hcvr+aP//t0Fmxo4hvnjGPXwcM8u+Q9ng6eHx373Y1nc8V987jxs8dz7xvr+PKkY9i5v5lLJ4ykpbWdv9Rsob3d8ZGIs1iyIaXLyWZKTy8nKyLyzyjTl5MVEZE8U4CLiHhKAS4i4ikFuIiIpxTgIiKeUoCLiHhKAS4i4ikFuIiIp3L6RR4zawQ2dnPyocCODJaTKaorPaorPb21Lui9tX0Q6zrGORdzwfqcBnhPmFl1vG8i5ZvqSo/qSk9vrQt6b23/THXpEIqIiKcU4CIinvIpwB/IdwEJqK70qK709Na6oPfW9k9TlzfHwEVEpDOfeuAiIhJBAS4i4ikvAtzMJpvZGjOrM7Obc7jc0Wb2upmtMrOVZnZjMPxHZvaemdUEfxdGTPP9oM41ZnZBluurN7PlQQ3VwbDBZjbTzGqD/4OC4WZmvwtqW2ZmE7NU04kR7VJjZnvM7KZ8tJmZPWxm281sRcSwtNvHzKYF49ea2bQs1fVrM1sdLPs5MxsYDB9jZgcj2u2+iGlOC7Z/XVB7j36/K0FdaW+3TD9fE9T1ZERN9WZWEwzPZXslyofc7WPOuV79BxQC64BxQAmwFBifo2WPACYGtwcAa4HxwI+A/4wz/vigvlJgbFB3YRbrqweGRg37FXBzcPtm4JfB7QuBVwADzgAW5GjbbQWOyUebAecAE4EV3W0fYDCw
Pvg/KLg9KAt1nQ8UBbd/GVHXmMjxouazEPh4UPMrwOeyUFda2y0bz9d4dUU9fjvwwzy0V6J8yNk+5kMPfBJQ55xb75xrAZ4ALs7Fgp1zDc65xcHtvcAqYGQXk1wMPOGca3bObQDqCNWfSxcDM4LbM4BLIoY/6kL+AQw0sxFZruWzwDrnXFffvs1amznn5gBNcZaXTvtcAMx0zjU5594HZgKTM12Xc+5vzrnwD1n+AxjV1TyC2sqdc/NdKAUejViXjNXVhUTbLePP167qCnrRVwKPdzWPLLVXonzI2T7mQ4CPBDZF3N9M1yGaFWY2BpgALAgGfTt4G/Rw+C0Sua/VAX8zs0Vmdn0wbLhzrgFCOxgwLE+1AXyJzk+s3tBm6bZPPtrtWkI9tbCxZrbEzN4ws7ODYSODWnJRVzrbLdftdTawzTlXGzEs5+0VlQ8528d8CPB4x6lyeu6jmfUHngFucs7tAe4FjgUqgQZCb+Eg97V+wjk3Efgc8C0zO6eLcXNam5mVAF8A/hwM6i1tlkiiOnLdbj8AWoHHgkENwNHOuQnAvwN/MrPyHNaV7nbL9facSudOQs7bK04+JBw1QQ3drs2HAN8MjI64PwrYkquFm1kxoY3zmHPuWQDn3DbnXJtzrh14kCNv+XNaq3NuS/B/O/BcUMe28KGR4P/2fNRG6EVlsXNuW1Bjr2gz0m+fnNUXfHh1EXB18Daf4BDFzuD2IkLHl08I6oo8zJKVurqx3XLZXkXAZcCTEfXmtL3i5QM53Md8CPC3gePNbGzQq/sS8EIuFhwcX3sIWOWc+03E8Mhjx5cC4U/HXwC+ZGalZjYWOJ7QByfZqK2fmQ0I3yb0IdiKoIbwp9jTgOcjarsm+CT8DGB3+G1elnTqGfWGNotYXjrt81fgfDMbFBw+OD8YllFmNhn4HvAF59yBiOEVZlYY3B5HqH3WB7XtNbMzgv30moh1yWRd6W63XD5fzwVWO+c6Do3ksr0S5QO53Md68ilsrv4IfXq7ltCr6Q9yuNyzCL2VWQbUBH8XAv8XWB4MfwEYETHND4I619DDT7mT1DaO0Cf8S4GV4XYBhgCzgNrg/+BguAF3B7UtB6qyWFtfYCdwVMSwnLcZoReQBuAwoV7Odd1pH0LHpOuCv69mqa46QsdBw/vZfcG4Xwy271JgMfD5iPlUEQrUdcDvCb5ZneG60t5umX6+xqsrGP4I8M2ocXPZXonyIWf7mL5KLyLiKR8OoYiISBwKcBERTynARUQ8pQAXEfGUAlxExFMKcBERTynARUQ89f8Bt5mGubOT7Q4AAAAASUVORK5CYII=\n", 227 | "text/plain": [ 228 | "
" 229 | ] 230 | }, 231 | "metadata": { 232 | "needs_background": "light" 233 | }, 234 | "output_type": "display_data" 235 | } 236 | ], 237 | "source": [ 238 | "# 不带基线的简单策略梯度算法\n", 239 | "\n", 240 | "policy_kwargs = dict(hidden_sizes=[128,], learning_rate=0.01)\n", 241 | "agent = VPGAgent(env, policy_kwargs=policy_kwargs)\n", 242 | "\n", 243 | "# 训练\n", 244 | "episodes = 2000\n", 245 | "episode_rewards = []\n", 246 | "for episode in tqdm(range(episodes)):\n", 247 | " episode_reward = play_montecarlo(env, agent, train=True)\n", 248 | " episode_rewards.append(episode_reward)\n", 249 | "plt.plot(episode_rewards)\n", 250 | "\n", 251 | "# 测试\n", 252 | "episode_rewards = [play_montecarlo(env, agent, train=False) for _ in range(100)]\n", 253 | "print('平均回合奖励 = {} / {} = {}'.format(sum(episode_rewards), len(episode_rewards), np.mean(episode_rewards)))" 254 | ] 255 | } 256 | ], 257 | "metadata": { 258 | "kernelspec": { 259 | "display_name": "Python 3", 260 | "language": "python", 261 | "name": "python3" 262 | }, 263 | "language_info": { 264 | "codemirror_mode": { 265 | "name": "ipython", 266 | "version": 3 267 | }, 268 | "file_extension": ".py", 269 | "mimetype": "text/x-python", 270 | "name": "python", 271 | "nbconvert_exporter": "python", 272 | "pygments_lexer": "ipython3", 273 | "version": "3.7.6" 274 | } 275 | }, 276 | "nbformat": 4, 277 | "nbformat_minor": 4 278 | } 279 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RL-Python-Pytorch 2 | 3 |
4 | 5 | 《强化学习-原理与Python实现》原书使用```NumPy```、```Keras```和```TensorFlow```实现强化学习方法，见[rl-book](https://github.com/ZhiqingXiao/rl-book)。 6 | 7 | 这里使用```PyTorch```将原书中深度强化学习方法实现一遍。 8 | 9 | - [x] 1 - 初识强化学习 10 | - [x] 2 - Markov决策过程 11 | - [x] 3 - 有模型数值迭代 12 | - [x] 4 - 回合更新价值迭代 13 | - [x] 5 - 时序差分价值迭代 14 | - [x] 6 - 函数近似方法 15 | - [ ] 7 - 回合更新策略梯度方法 16 | - [ ] 8 - 执行者/评论者方法 17 | - [ ] 9 - 连续动作空间的确定性策略 18 | - [ ] 10 - 综合案例:电动游戏 19 | - [ ] 11 - 综合案例:棋盘游戏 20 | - [ ] 12 - 综合案例:自动驾驶 21 | 22 |
23 | 24 | --- 25 | 26 |
27 | 28 | ## 环境配置 29 | 30 | | Package | Version | Installation | 31 | | ---------- | ------- | ------------------------------------------------------------ | 32 | | python | 3.8.6 | conda create --name rl python=3.8.6 | 33 | | numpy | 1.19.4 | pip install numpy==1.19.4 | 34 | | scipy | 1.5.4 | pip install scipy==1.5.4 | 35 | | pandas | 1.1.4 | pip install pandas==1.1.4 | 36 | | sympy | 1.7 | pip install sympy==1.7 | 37 | | gym | 0.17.3 | pip install gym==0.17.3 | 38 | | tqdm | 4.54.0 | pip install tqdm==4.54.0 | 39 | | matplotlib | 3.3.3 | pip install matplotlib==3.3.3 | 40 | | notebook | 6.1.5 | pip install notebook==6.1.5 | 41 | | pytorch | 1.7.0 | cpu:conda install pytorch\==1.7.0 cpuonly -c pytorch
gpu:conda install pytorch\==1.7.0 cudatoolkit=10.2 -c pytorch | 42 | 43 | --------------------------------------------------------------------------------