├── README.md ├── movies ├── 2_2_maze_random.gif ├── 2_3_maze_reinforce.gif ├── 2_5_maze_state_value.gif ├── 3_2_cartpole_random.gif ├── 3_4_cartpole_q_learning.gif ├── 5_4_cartpole_dqn.gif ├── 7_breakout.gif ├── 7_breakout_random.gif └── book.jpg └── program ├── 2_1_TryJupyter.ipynb ├── 2_2_maze_random.ipynb ├── 2_3_Policygradient.ipynb ├── 2_5_Sarsa.ipynb ├── 2_6_Qlearning.ipynb ├── 3_2_try_CartPole.ipynb ├── 3_3_digitized_CartPole.ipynb ├── 3_4_Qlearning_CartPole.ipynb ├── 4_3_PyTorch_MNIST.ipynb ├── 5_3and5_4_DQN.ipynb ├── 6_2_DDQN.ipynb ├── 6_3_DuelingNetwork.ipynb ├── 6_4_PrioritizedExperienceReplay.ipynb ├── 6_5_A2C_Advanced_ActorCritic.ipynb ├── 7_1_Breakout_try.ipynb ├── 7_breakout_learning.ipynb ├── 7_breakout_play.ipynb └── weight_end.pth /README.md: -------------------------------------------------------------------------------- 1 | # Deep-Reinforcement-Learning-Book 2 | [書籍「つくりながら学ぶ!深層強化学習」、著者:株式会社電通国際情報サービス 小川雄太郎、出版社: マイナビ出版 (2018/6/28) ](https://www.amazon.co.jp/%E3%81%A4%E3%81%8F%E3%82%8A%E3%81%AA%E3%81%8C%E3%82%89%E5%AD%A6%E3%81%B6-%E6%B7%B1%E5%B1%A4%E5%BC%B7%E5%8C%96%E5%AD%A6%E7%BF%92-PyTorch%E3%81%AB%E3%82%88%E3%82%8B%E5%AE%9F%E8%B7%B5%E3%83%97%E3%83%AD%E3%82%B0%E3%83%A9%E3%83%9F%E3%83%B3%E3%82%B0-%E6%A0%AA%E5%BC%8F%E4%BC%9A%E7%A4%BE%E9%9B%BB%E9%80%9A%E5%9B%BD%E9%9A%9B%E6%83%85%E5%A0%B1%E3%82%B5%E3%83%BC%E3%83%93%E3%82%B9-%E5%B0%8F%E5%B7%9D%E9%9B%84%E5%A4%AA%E9%83%8E/dp/4839965625)のサポートリポジトリです。 3 | [이 페이지는 "서지사항"](아마존 페이지 링크)의 고객 지원 페이지입니다. 4 | 5 | 6 | 맨 아랫 부분에 정오표를 실었습니다. 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 그림. 벽돌 깨기 공략(A2C 알고리즘 사용, GPU 1장으로 3시간 동안 학습한 결과) 15 | 16 | 17 | 그림. 미로 안에서 무작위 이동 18 | 19 | 20 | 그림. 미로를 대상으로 강화학습을 수행한 결과 21 | 22 | 23 | 그림. 미로 안의 각 위치에 대해 학습된 상태 가치 24 | 25 | 26 | 그림. 역진자 제어하기 27 | 28 | 29 | ### 정오표 30 | -------------------------------------------------------------------------------- /movies/2_2_maze_random.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/2_2_maze_random.gif -------------------------------------------------------------------------------- /movies/2_3_maze_reinforce.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/2_3_maze_reinforce.gif -------------------------------------------------------------------------------- /movies/2_5_maze_state_value.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/2_5_maze_state_value.gif -------------------------------------------------------------------------------- /movies/3_2_cartpole_random.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/3_2_cartpole_random.gif -------------------------------------------------------------------------------- /movies/3_4_cartpole_q_learning.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/3_4_cartpole_q_learning.gif -------------------------------------------------------------------------------- /movies/5_4_cartpole_dqn.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/5_4_cartpole_dqn.gif -------------------------------------------------------------------------------- /movies/7_breakout.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/7_breakout.gif -------------------------------------------------------------------------------- /movies/7_breakout_random.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/7_breakout_random.gif -------------------------------------------------------------------------------- /movies/book.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/book.jpg -------------------------------------------------------------------------------- /program/2_1_TryJupyter.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 2.1 주피터 노트북 사용법" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 2, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "x = 2 + 3\n", 17 | "y = x * 4\n" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 3, 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "x의 값은 5입니다\n", 30 | "y의 값은 20입니다\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "print(\"x의 값은 {}입니다\".format(x))\n", 36 | "print(\"y의 값은 {}입니다\".format(y))\n" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 4, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# 본문의 설명을 따라 파일을 저장하고 로드해 봅시다." 46 | ] 47 | } 48 | ], 49 | "metadata": { 50 | "kernelspec": { 51 | "display_name": "Python 3", 52 | "language": "python", 53 | "name": "python3" 54 | }, 55 | "language_info": { 56 | "codemirror_mode": { 57 | "name": "ipython", 58 | "version": 3 59 | }, 60 | "file_extension": ".py", 61 | "mimetype": "text/x-python", 62 | "name": "python", 63 | "nbconvert_exporter": "python", 64 | "pygments_lexer": "ipython3", 65 | "version": "3.6.6" 66 | } 67 | }, 68 | "nbformat": 4, 69 | "nbformat_minor": 2 70 | } 71 | -------------------------------------------------------------------------------- /program/2_5_Sarsa.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 2.5 Sarsa 알고리즘 구현하기" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# 구현에 사용할 패키지 임포트하기\n", 17 | "import numpy as np\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "%matplotlib inline\n" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "data": { 29 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAASwAAAElCAYAAABect+9AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAGu1JREFUeJzt3XtUlHX+B/D3MzcEBoJfksJQoLlK8MtM0APaLzPYpFq7mXZgK4Ei/WmXk3TsuOtu222PmejR1V9HziqVrW5eUqFTrWwqrrcUvOCKlK55Q1uQIOQyMON8f3+MsILKDMQ8z3yH9+ucOR7m+c48n/kG777fZ57n+yhCCBARyUCndQFERO5iYBGRNBhYRCQNBhYRSYOBRUTSYGARkTQYWEQkDQYWEUmDgUVE0mBgEZE0DN1p3L9/fxEdHe2hUoioryotLb0ohAhz1a5bgRUdHY2SkpKeV0VEdB2Kopx2px2nhEQkDQYWEUmDgUVE0mBgEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGARkTQYWEQkjW6t1uCthBCovFSJ0vOl2Fe5D8Wni1FeXY5mezPsDjsuOy5Dr9PDoDPA3+CP2LBYjIsah9GW0YiPiIclyAJFUbT+GETkgrSB5RAOfH3yayzcuxC7zuyC3WGHUW9EQ2sDHMJxTXu7ww67ww6r3YpdZ3dhz7k9MJvMaL3cCqPOiLG3jcWsxFlIHpwMncKBJ5E3ki6waptrsfLgSuTuycWl1ktoaG1o39Zsb3b7fRzCgfqWegCAFVZ8deIr7DyzE0GmIOQk5SDr7iyE+of2ev1E1HOKEMLtxgkJCUKrBfzO1Z/D7KLZ2FixETpFhyZbk8f2FWAMgEM48ETME3jvl+8hMjjSY/siIkBRlFIhRIKrdl4/9xFCYMXBFYhZGoN1R9fBard6NKwAoMnWBKvdirVH1yJmaQxWHFyB7gQ7EXmGVwdWZX0lxn80Hq98+QoabY2wC7uq+7cLOxptjXjly1cw/qPxqKyvVHX/RNSR1wZW/qF8xCyNwa6zu9Boa9S0lkZbI3ad3YWYZTHIP5SvaS1EfZnXBZYQAq9+9Spe/OJFNNgaYHeoO6q6EbvDjobWBrz4xYuY9bdZnCISacCrAuuy4zIyNmUg70Cex49T9VSTrQnLS5cjc3MmLjsua10OUZ/iNac1CCGQtTkL64+t99qwatNka8K68nUAgPxH83nSKZFKvGaENetvs7Dh2AavD6s2baGVsyVH61KI+gyvCKz8Q/nIO5Cn+cH17mqbHvJAPJE6NA+syvpKvPzFy9KMrDprsjXh5S9f5ikPRCrQNLCEEEj/LB3Wy1Yty/jZWuwt+PVnv+Y3h0QepmlgrTy0EqXnS73m1IWesjlsKDlfwqkhkYdpFljn6s+1n8HuCxptjXjlq1c4NSTyIM0Ca3bRbLTYW7TavUdY7VbMLpqtdRlEPkuTwKptrsXGio2qXxvoaXaHHZ9VfIba5lqtSyHySZoE1sqDK312kTydouOxLCIPUT01HMKB3D250p7G4EqTrQm5u3Ovu+opEf08qgfW1ye/xqXWS73/xo0APgewCMDbAN4H8BGAf13ZLgBsA7AAwDsA8gFU9X4ZAFDfWo+t32/1zJt7kerqasyYMQPR0dHw8/PDgAEDkJycjKKiIgDAZ599hgkTJiAsLAyKomD79u3aFuwDuupzm82G119/HcOHD0dgYCDCw8ORnp6OM2fOaF12r1H9WsKFexd2WNa413wKwAbgUQD/BWeAnQLQNpDbBWAPgMcA3AygGMDHAF4C4Ne7pTS0NiB3Ty5SBqf07ht7mUmTJqGpqQkrVqzAkCFDUFVVheLiYtTU1AAAGhsbMWbMGDz99NN49tlnNa7WN3TV501NTThw4AB++9vfYsSIEfjpp5+Qk5OD1NRUlJWVwWDwmkuHe0zVJZKFELhp3k29P8JqBvAegGcA3H69HQPIBTAawL1XnrPBOQp7AIDLhVm7L9gvGHWv1/nshdF1dXUIDQ1FUVERUlK6DuaLFy8iLCwM27Ztw3333adOgT6oO33epry8HHFxcSgrK8Odd97p4Qp7ziuXSK68VAmbw9b7b2y68vgWziDqrBZAAzqGmRFAFICzvV8OALRebsX5S+c98+ZewGw2w2w2o6CgAFar3FcqyKInfV5f77zRSmiob9xQRdXAKj1fCpPe1PtvrIdzqlcGYB6APwP4G4BzV7a3zUADO70u8KptvcykN6H0Qqln3twLGAwGfPjhh/jkk08QEhKCpKQkvPbaa/jmm2+0Ls1ndbfPW1tbkZOTg4kTJyIy0jdupKJqYO2r3OeZ41cAEAsgB0A6gCFwjpz+DGDHVW1UnJ01tjZiX+U+9XaogUmTJuH8+fMoLCzEgw8+iN27dyMxMRF//OMftS7NZ7nb53a7HU8//TTq6uqQn+87p9moegzrnpX3YNfZXT1+fbdtBnAYwAwASwFkA7Bctf0vAAIAPO6Z3d9z2z34R+Y/PPPmXur555/Hxx9/jIaGBphMztE0j2F5Vuc+t9vtSEtLw5EjR7B9+3YMHDhQ6xJd8spjWOXV5WruDggD4ABgvvL411XbbABOA7jVc7tX/fN6gdjYWNjtdh7XUtHVfW6z2fDUU0+hrKwM27ZtkyKsukPV7zm7c2fmbmkCsBbA3QAGwHmawnk4T2UYDKAfgEQ4p4f94TytYQecB+o9+MVJs81Dn9cL1NTUYPLkycjKysLw4cMRFBSEkpISzJ8/H8nJyQgODsaPP/6IM2fOoK6uDgBw4sQJhISEYODAgT73h6QGV30eEBCAJ598Evv370dhYSEURcEPP/wAALjpppvg7++v8Sf4+VQNLI8tI2MCEAngGwA/ArADCIYzjNpOYxgL56jqCzhPg4iE8zSIXj4H62oe+UbUS5jNZiQmJmLx4sU4ceIEWlpaYLFYkJ6ejrlz5wIACgoKkJmZ2f6a7OxsAMAbb7yBP/zhD1qULTVXfX7u3Dls3rwZABAfH9/htfn5+cjIyNCg6t6l6jEs3Zs6CPSdRe4UKHC8wUt0iFzxymNYep1ezd1prq99XiJPUzWwDDr5Lw3oDqPOqHUJRD5F1cDyN8h/0K87/I196/MSeZqqgRUbFqvm7jTX1z4vkaepGljjosb57MJ9nekVPcZFjdO6DCKfomp6jLaMhtlkVnOXmgk0BWK0ZbTWZRD5FFUDKz4iHq2XW9XcpWZaL7ciPjzedUMicpuqgWUJsvSZb85MehMigiK0LoPIp6gaWIqiYOxtY9XcpWbG3DrGZxfvI9KK6idGzUqchZ1ndvZsmZkdAI7AuUyMAsAfzstsWuG8njDkSruHAdwG5zLJuQAeQsdVRRfhP5fk+MO5WoMJzjXgAecaWTo4V3IAnKs8dKOnzCYzcpJy3H8BEblF9cBKHpyMIFNQ9wPrLIDvAEyDs+pGAJfhvGbwewC7Afy602uOwnnN4BFcuwzyVDgX8NsGZxA+AuB/r2zbBmeA9XAwGOwXjPsH3d+zFxPRDal+joFO0SEnKQcBxgDXja92Cc4RT1vEBsIZVl35J5xrttdfeVxPZBfbeiDAGICcpJw+c/oGkZo0+avKujur+/ftux3ATwCWwHk7r1Mu2v8E59QuEkAcnOF1PScAxHSvlK44hAOZIzJdNySibtMksEL9Q/F4zOMwKN2YkfrBOR2cCOfoah2Ag120/yecQQUA/41rA+sjAPMBnESvrYll0BnwRMwTCPX3jQX/ibyNZlcjz//lfBR8WwC7rRtrZOkADLryuAXO5Y/vvkHbI3Ae5yq78vMlADVwLt4HOI9hmQBsgvOYVWr36r+efoZ+mP/L+T//jYjoujQ70BIZHInFDy5GoLHzrWxu4CKcgdPmBwA3ddHWBudNKV698vgfXDvKMsIZVIfxnxuu9lCgMRCLUxfDEmxx3ZiIekTTI8NZI7KQEJHg3rIzrQA2wnkzif8DUA3gvhu0PYJrj0vdceX5zoLgnBLud6vk6zLqjBhlGcVjV0QepuqKo9dTWV+JmKUxaLB56PZfKjCbzKiYWcHRFVEPeeWKo9djCbZgyUNLun+ag5cIMAZgyYNLGFZEKtA8sAAgc0QmXhj5gnShFWgMxLT4aZwKEqnEKwILABZOWIgn73hSmtAKMAbgydgnkftArtalEPUZXhNYiqJg5aMrMTl2steHVoAxAJNjJ2PFIyt4gTORirwmsADnXWbyH83HtPhpXhtaAcYATI+fjvxH83lXHCKVeVVgAc6R1sIJC7H0oaUwm8xec6cdo84Is8mMpQ8tRe6EXI6siDTgdYHVJnNEJipmVmDsrWPdP7nUQwKNgRhz6xhUzKzgAXYiDXltYAHOUx62Td2GJQ8ucY62unPtYS8w6Awwm8xY8uASbJu6jacuEGnMqwMLcE4Rs+7OwrGZxzAlbgr6GfohwODZ41sBhgD0M/TDlNgpqJhZgay7szgFJPIC3nGAyA2RwZH4y6S/oLa5FvmH8rFg9wJcar3Us5VLb8BsMiPYFIycMTnIHJHJVReIvIzml+b0lEM4sPX7rcjdk4vdZ3ej9XIrTHoTGlob3FprS6foYDaZ21835tYxyEnKwf2D7ufie0Qqc/fSHGlGWJ3pFB1SBqcgZXAKhBA4f+k8Si+UYl/lPhSfLkZ5dTmabc2wOWy47LgMvU4Po84If6M/YsNiMS5qHEZbRiM+PB4RQRGc8hFJQNrAupqiKLAEW2AJtuCRYY9oXQ4ReQjnPkQkDQYWEUmDgUVE0mBgEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGARkTR84uJnn8UVJLTTjWWXSD0cYRGRNDjC8mb8v7z6OKr1ahxhEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGARkTQYWEQkDQYWEUmDgUVE0mBgEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGARkTQYWEQkDQYWEUmDgUVE0mBgEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGARkTQYWEQkDZ8JrOrqasyYMQPR0dHw8/PDgAEDkJycjKKiIgDA7373O8TExCAwMBChoaFITk7G7t27Na5abq76/GovvPACFEXBggULNKjUd7jq84yMDCiK0uGRmJiocdW9x6B1Ab1l0qRJaGpqwooVKzBkyBBUVVWhuLgYNTU1AIBhw4Zh2bJlGDRoEJqbm7Fo0SKkpqbi+PHjGDBggMbVy8lVn7dZv3499u/fj4iICI0q9R3u9HlKSgpWrVrV/rPJZNKiVM8QQrj9iI+PF96otrZWABBFRUVuv+ann34SAMRXX33lwcp8l7t9furUKRERESHKy8tFVFSUeP/991WqsIcA58MLudPnU6dOFQ8//LCKVfUOACXCjQzyiSmh2WyG2WxGQUEBrFary/atra3Iy8tDcHAwRowYoUKFvsedPrfb7UhLS8PcuXNxxx13qFyh73H393znzp245ZZbMHToUGRnZ6OqqkrFKj3MnVQTXj7CEkKI9evXi9DQUOHn5ycSExNFTk6O2Lt3b4c2hYWFIjAwUCiKIiIiIsQ333yjUbW+wVWf/+Y3vxG/+tWv2n/mCOvnc9Xna9asEZs3bxZlZWWioKBADB8+XMTFxQmr1aph1a7BzRGWzwSWEEI0NzeLLVu2iDfffFMkJSUJAOLdd99t397Q0CCOHz8u9uzZI7KyskRUVJQ4f/68hhXL70Z9vn37dhERESGqqqra2zKweoer3/OrVVZWCoPBIDZs2KByld3TJwOrs+eee04YjUbR0tJy3e1DhgwRb731lspV+ba2Pp8zZ45QFEXo9fr2BwCh0+mExWLRuswbkyCwOnP1ex4dHS3mzZunclXd425g+cy3hNcTGxsLu90Oq9V63W9KHA4HWlpaNKjMd7X1+fTp05Gent5h24QJE5CWlobs7GyNqvNNXf2eX7x4EZWVlQgPD9eout7lE4FVU1ODyZMnIysrC8OHD0dQUBBKSkowf/58JCcnAwDmzp2LiRMnIjw8HNXV1Vi2bBnOnTuHKVOmaFy9nFz1+W233XbNa4xGIwYOHIhhw4ZpULH8XPW5TqfDa6+9hkmTJiE8PBynTp3CnDlzcMstt+Dxxx/Xuvxe4ROBZTabkZiYiMWLF+PEiRNoaWmBxWJBeno65s6dC4PBgKNHj2LlypWoqanBzTffjFGjRmHHjh0YPny41uVLyVWfU+9z1ed6vR5HjhzBxx9/jLq6OoSHh2P8+PFYu3YtgoKCtC6/VyjO6aN7EhISRElJiQfLIdKYojj/7cbfBf18iqKUCiESXLXzifOwiKhvYGARkTQYWEQkDQYWEUmDgUVE0mBgEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGARkTQYWEQkDQYWEUmDgUVE0mBgEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGARkTQYWEQkDQYWEUmDgUVE0mBgEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGARkTQYWEQkDQYWEUmDgUVE0mBgEZE0GFhEJA2D1gVQFxTF+a8Q2tbRF7X1PXkVjrCISBocYRFdjaNZbbg5ouUIi4ikwcAiImkwsIhIGgwsIpIGA4uIpMHAIiJpMLCISBoMLCKSBgOLiKTBwCIiaTCwiEgaDCwikgYDi4ikwcAiImkwsIhIGgwsIpIGA4uIpMHAIiJpMLCISBoMLCKSBgOLiKTBwCIiaTCwiEgaDCwikgYDi4ikwcAiImkwsIhIGgwsIpIGA4uIpMHAIiJpMLCISBoMLCKSBgOLiKThM4FVXV2NGTNmIDo6Gn5+fhgwYACSk5NRVFTU3ua7777DE088gZCQEAQEBGDkyJE4duyYhlXLzVWfK4py3cfMmTM1rlxervq8oaEBL730EiIjI+Hv749hw4Zh0aJFGlfdewxaF9BbJk2ahKamJqxYsQJDhgxBVVUViouLUVNTAwD4/vvvMXbsWDz77LPYunUrQkJCUFFRAbPZrHHl8nLV5xcuXOjQvqSkBBMnTsSUKVO0KNcnuOrzWbNm4e9//ztWrVqFQYMGYceOHcjOzkb//v3xzDPPaFx9LxBCuP2Ij48X3qi2tlYAEEVFRTdsk5aWJtLT01WsqhcAzocXcqfPO3v++efF0KFDPViVb3Onz+Pi4sTvf//7Ds/de++9YubMmZ4u72cBUCLcyCCfmBKazWaYzWYUFBTAarVes93hcKCwsBCxsbFITU1FWFgYRo0ahU8//VSDan2Dqz7vrKGhAX/961+RnZ2tQnW+yZ0+v+eee1BYWIizZ88CAHbv3o1Dhw4hNTVVzVI9x51UE14+whJCiPXr14vQ0FDh5+cnEhMTRU5Ojti7d68QQogLFy4IACIgIEDk5uaKgwcPitzcXKHX60VhYaHGlXfBi0dYQnTd550tX75cGI1GUVVVpXKVvsVVn7e0tIjMzEwBQBgMBmEwGMQHH3ygYcXugZsjLJ8JLCGEaG5uFlu2bBFvvvmmSEpKEgDEu+++KyorKwUAkZaW1qF9WlqaSE1N1ahaN3h5YAlx4z7vLCEhQUyePFmDCn1PV32+YMECMXToUFFQUCAOHz4s/vSnP4nAwEDx5Zdfalx11/pkYHX23HPPCaPRKFpaWoTBYBBvv/12h+1vvfWWiI2N1ag6N0gQWJ1d3edtDh48KACILVu2aFiZ72rr87q6OmE0GsWmTZuu2Z6cnKxRde5xN7B84hjWjcTGxsJut8NqtWLUqFH49ttvO2z/7rvvEBUVpVF1vunqPm+Tl5eH6OhopKSkaFiZ72rrc0VRYLPZoNfrO2zX6/VwOBwaVdfL3Ek14eUjrIsXL4rx48eLVatWicOHD4uTJ0+KtWvXigEDBoiUlBQhhBAbN24URqNRLF++XBw/flzk5eUJg8EgPv/8c42r74IXj7Dc6XMhhGhsbBTBwcHinXfe0bBa3+BOn48bN07ExcWJbdu2iZMnT4r8/HzRr18/sWTJEo2r7xr60pTQarWKOXPmiISEBBESEiL8/f3FkCFDxKuvvipqamra2+Xn54tf/OIXol+/fuLOO+8Uq1ev1rBqN3hxYLnb5ytXrhR6vV5UVlZqWK1vcKfPL1y4IDIyMkRERITo16+fGDZsmHj//feFw+HQuPquuRtYirOtexISEkRJSYnHRnvUiaI4/+3GfyMiGSmKUiqESHDVzqePYRGRb2FgEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGAReal///vfSE9Px+DBgxEfH4+kpCRs3LgRALBz506MHj0aMTExiImJQV5e3jWvv+uuu5CWltbhuYyMDKxfv16V+j3BZ9Z0J/IlQgg89thjmDp1KlavXg0AOH36NAoKCvDDDz8gPT0dmzZtwsiRI3Hx4kVMmDABFosFDz/8MADg2LFjcDgc2LFjBxobGxEYGKjlx+k1HGEReaGtW7fCZDJh+vTp7c9FRUXhpZdewrJly5CRkYGRI0cCAPr374/58+dj3rx57W1Xr16NZ555Bg888AAKCgpUr99TGFhEXujo0aPtgXS9bfHx8R2eS0hIwNGjR9t//vTTT/HUU08hLS0Na9as8WitamJgEUlg5syZuOuuuzBq1CjnMittK3lcpe25/fv3IywsDFFRUUhOTsaBAwdQW1urdskewcAi8kJxcXE4cOBA+8/Lli3D119/jerqasTFxaHzMk+lpaWIjY0FAKxZswYVFRWIjo7G7bffjvr6emzYsEHV+j2FgUXkhe6//35YrVZ88MEH7c81NTUBcI62PvzwQxw6dAgAUFNTg9dffx2zZ8+Gw+HAunXrUFZWhlOnTuHUqVPYvHmzz0wLGVhEXkhRFGzatAnFxcUYNGgQRo8ejalTp+K9995DeHg4PvnkE2RnZyMmJgZjxoxBVlYWJk6ciB07dsBiscBisbS/17333ovy8vL2O3FPmzYNkZGRiIyMRFJSklYfsUe44qg344qj1EdwxVEi8jkMLCKSBgOLiKTBwCIiaTCwiEgaDCwikgYDi4ikwcAiImkwsIhIGgwsIpIGA4uIpMHAIiJpMLCISBoMLCKSBgOLiKTBwCIiaTCwiEgaDCwikgYDi4ikwcAiImkwsIhIGgwsIpIGA4uIpMHAIiJpMLCISBoMLCKSRrduVa8oSjWA054rh4j6qCghRJirRt0KLCIiLXFKSETSYGARkTQYWEQkDQYWEUmDgUVE0mBgEZE0GFhEJA0GFhFJg4FFRNL4f/+izvyQK8sjAAAAAElFTkSuQmCC\n", 30 | "text/plain": [ 31 | "
" 32 | ] 33 | }, 34 | "metadata": {}, 35 | "output_type": "display_data" 36 | } 37 | ], 38 | "source": [ 39 | "# 초기 상태의 미로 모습\n", 40 | "\n", 41 | "# 전체 그림의 크기 및 그림을 나타내는 변수 선언\n", 42 | "fig = plt.figure(figsize=(5, 5))\n", 43 | "ax = plt.gca()\n", 44 | "\n", 45 | "# 붉은 벽 그리기\n", 46 | "plt.plot([1, 1], [0, 1], color='red', linewidth=2)\n", 47 | "plt.plot([1, 2], [2, 2], color='red', linewidth=2)\n", 48 | "plt.plot([2, 2], [2, 1], color='red', linewidth=2)\n", 49 | "plt.plot([2, 3], [1, 1], color='red', linewidth=2)\n", 50 | "\n", 51 | "# 상태를 의미하는 문자열(S0~S8) 표시\n", 52 | "plt.text(0.5, 2.5, 'S0', size=14, ha='center')\n", 53 | "plt.text(1.5, 2.5, 'S1', size=14, ha='center')\n", 54 | "plt.text(2.5, 2.5, 'S2', size=14, ha='center')\n", 55 | "plt.text(0.5, 1.5, 'S3', size=14, ha='center')\n", 56 | "plt.text(1.5, 1.5, 'S4', size=14, ha='center')\n", 57 | "plt.text(2.5, 1.5, 'S5', size=14, ha='center')\n", 58 | "plt.text(0.5, 0.5, 'S6', size=14, ha='center')\n", 59 | "plt.text(1.5, 0.5, 'S7', size=14, ha='center')\n", 60 | "plt.text(2.5, 0.5, 'S8', size=14, ha='center')\n", 61 | "plt.text(0.5, 2.3, 'START', ha='center')\n", 62 | "plt.text(2.5, 0.3, 'GOAL', ha='center')\n", 63 | "\n", 64 | "# 그림을 그릴 범위 및 눈금 제거 설정\n", 65 | "ax.set_xlim(0, 3)\n", 66 | "ax.set_ylim(0, 3)\n", 67 | "plt.tick_params(axis='both', which='both', bottom=False, top=False,\n", 68 | " labelbottom=False, right=False, left=False, labelleft=False)\n", 69 | "\n", 70 | "# S0에 녹색 원으로 현재 위치를 표시\n", 71 | "line, = ax.plot([0.5], [2.5], marker=\"o\", color='g', markersize=60)\n" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 3, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "# 정책을 결정하는 파라미터의 초깃값 theta_0를 설정\n", 81 | "\n", 82 | "# 줄은 상태 0~7, 열은 행동방향(상,우,하,좌 순)를 나타낸다\n", 83 | "theta_0 = np.array([[np.nan, 1, 1, np.nan], # s0\n", 84 | " [np.nan, 1, np.nan, 1], # s1\n", 85 | " [np.nan, np.nan, 1, 1], # s2\n", 86 | " [1, 1, 1, np.nan], # s3\n", 87 | " [np.nan, np.nan, 1, 1], # s4\n", 88 | " [1, np.nan, np.nan, np.nan], # s5\n", 89 | " [1, np.nan, np.nan, np.nan], # s6\n", 90 | " [1, 1, np.nan, np.nan], # s7、※s8은 목표지점이므로 정책이 없다\n", 91 | " ])" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 5, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "# 정책 파라미터 theta_0을 무작위 행동 정책 pi로 변환하는 함수\n", 101 | "\n", 102 | "def simple_convert_into_pi_from_theta(theta):\n", 103 | " '''단순 비율 계산'''\n", 104 | "\n", 105 | " [m, n] = theta.shape # theta의 행렬 크기를 구함\n", 106 | " pi = np.zeros((m, n))\n", 107 | " for i in range(0, m):\n", 108 | " pi[i, :] = theta[i, :] / np.nansum(theta[i, :]) # 비율 계산\n", 109 | "\n", 110 | " pi = np.nan_to_num(pi) # nan을 0으로 변환\n", 111 | "\n", 112 | " return pi\n", 113 | "\n", 114 | "# 무작위 행동정책 pi_0을 계산\n", 115 | "pi_0 = simple_convert_into_pi_from_theta(theta_0)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 6, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "# 행동가치 함수 Q의 초기 상태\n", 125 | "\n", 126 | "[a, b] = theta_0.shape # 열과 행의 갯수를 변수 a, b에 저장\n", 127 | "Q = np.random.rand(a, b) * theta_0\n", 128 | "# * theta0 로 요소 단위 곱셈을 수행, Q에서 벽 방향으로 이동하는 행동에는 nan을 부여\n" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 7, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "# ε-greedy 알고리즘 구현\n", 138 | "\n", 139 | "\n", 140 | "def get_action(s, Q, epsilon, pi_0):\n", 141 | " direction = [\"up\", \"right\", \"down\", \"left\"]\n", 142 | "\n", 143 | " # 행동을 결정\n", 144 | " if np.random.rand() < epsilon:\n", 145 | " # 확률 ε로 무작위 행동을 선택함\n", 146 | " next_direction = np.random.choice(direction, p=pi_0[s, :])\n", 147 | " else:\n", 148 | " # Q값이 최대가 되는 행동을 선택함\n", 149 | " next_direction = direction[np.nanargmax(Q[s, :])]\n", 150 | "\n", 151 | " # 행동을 인덱스로 변환\n", 152 | " if next_direction == \"up\":\n", 153 | " action = 0\n", 154 | " elif next_direction == \"right\":\n", 155 | " action = 1\n", 156 | " elif next_direction == \"down\":\n", 157 | " action = 2\n", 158 | " elif next_direction == \"left\":\n", 159 | " action = 3\n", 160 | "\n", 161 | " return action\n", 162 | "\n", 163 | "\n", 164 | "def get_s_next(s, a, Q, epsilon, pi_0):\n", 165 | " direction = [\"up\", \"right\", \"down\", \"left\"]\n", 166 | " next_direction = direction[a] # 행동 a의 방향\n", 167 | "\n", 168 | " # 행동으로 다음 상태를 결정\n", 169 | " if next_direction == \"up\":\n", 170 | " s_next = s - 3 # 위로 이동하면 상태값이 3 줄어든다\n", 171 | " elif next_direction == \"right\":\n", 172 | " s_next = s + 1 # 오른쪽으로 이동하면 상태값이 1 늘어난다\n", 173 | " elif next_direction == \"down\":\n", 174 | " s_next = s + 3 # 아래로 이동하면 상태값이 3 늘어난다\n", 175 | " elif next_direction == \"left\":\n", 176 | " s_next = s - 1 # 왼쪽으로 이동하면 상태값이 1 줄어든다\n", 177 | "\n", 178 | " return s_next" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 8, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "# Sarsa 알고리즘으로 행동가치 함수 Q를 수정\n", 188 | "\n", 189 | "def Sarsa(s, a, r, s_next, a_next, Q, eta, gamma):\n", 190 | "\n", 191 | " if s_next == 8: # 목표 지점에 도달한 경우\n", 192 | " Q[s, a] = Q[s, a] + eta * (r - Q[s, a])\n", 193 | "\n", 194 | " else:\n", 195 | " Q[s, a] = Q[s, a] + eta * (r + gamma * Q[s_next, a_next] - Q[s, a])\n", 196 | "\n", 197 | " return Q\n" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 9, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "# Sarsa 알고리즘으로 미로를 빠져나오는 함수, 상태 및 행동 그리고 Q값의 히스토리를 출력한다\n", 207 | "\n", 208 | "def goal_maze_ret_s_a_Q(Q, epsilon, eta, gamma, pi):\n", 209 | " s = 0 # 시작 지점\n", 210 | " a = a_next = get_action(s, Q, epsilon, pi) # 첫 번째 행동\n", 211 | " s_a_history = [[0, np.nan]] # 에이전트의 행동 및 상태의 히스토리를 기록하는 리스트\n", 212 | "\n", 213 | " while (1): # 목표 지점에 이를 때까지 반복\n", 214 | " a = a_next # 행동 결정\n", 215 | "\n", 216 | " s_a_history[-1][1] = a\n", 217 | " # 현재 상태(마지막이므로 인덱스가 -1)을 히스토리에 추가\n", 218 | "\n", 219 | " s_next = get_s_next(s, a, Q, epsilon, pi)\n", 220 | " # 다음 단계의 상태를 구함\n", 221 | "\n", 222 | " s_a_history.append([s_next, np.nan])\n", 223 | " # 다음 상태를 히스토리에 추가, 행동은 아직 알 수 없으므로 nan으로 둔다\n", 224 | "\n", 225 | " # 보상을 부여하고 다음 행동을 계산함\n", 226 | " if s_next == 8:\n", 227 | " r = 1 # 목표 지점에 도달했다면 보상을 부여\n", 228 | " a_next = np.nan\n", 229 | " else:\n", 230 | " r = 0\n", 231 | " a_next = get_action(s_next, Q, epsilon, pi)\n", 232 | " # 다음 행동 a_next를 계산\n", 233 | "\n", 234 | " # 가치함수를 수정\n", 235 | " Q = Sarsa(s, a, r, s_next, a_next, Q, eta, gamma)\n", 236 | "\n", 237 | " # 종료 여부 판정\n", 238 | " if s_next == 8: # 목표 지점에 도달하면 종료\n", 239 | " break\n", 240 | " else:\n", 241 | " s = s_next\n", 242 | "\n", 243 | " return [s_a_history, Q]" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 10, 249 | "metadata": {}, 250 | "outputs": [ 251 | { 252 | "name": "stdout", 253 | "output_type": "stream", 254 | "text": [ 255 | "에피소드: 1\n", 256 | "2.9123690381123595\n", 257 | "목표 지점에 이르기까지 걸린 단계 수는 756단계입니다\n", 258 | "에피소드: 2\n", 259 | "0.06470294176990493\n", 260 | "목표 지점에 이르기까지 걸린 단계 수는 6단계입니다\n", 261 | "에피소드: 3\n", 262 | "0.08962131184352168\n", 263 | "목표 지점에 이르기까지 걸린 단계 수는 16단계입니다\n", 264 | "에피소드: 4\n", 265 | "0.07088647193553937\n", 266 | "목표 지점에 이르기까지 걸린 단계 수는 10단계입니다\n", 267 | "에피소드: 5\n", 268 | "0.057825165390497923\n", 269 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 270 | "에피소드: 6\n", 271 | "0.069115365071926\n", 272 | "목표 지점에 이르기까지 걸린 단계 수는 10단계입니다\n", 273 | "에피소드: 7\n", 274 | "0.055847280693343826\n", 275 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 276 | "에피소드: 8\n", 277 | "0.0646771466375029\n", 278 | "목표 지점에 이르기까지 걸린 단계 수는 10단계입니다\n", 279 | "에피소드: 9\n", 280 | "0.054024188311407206\n", 281 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 282 | "에피소드: 10\n", 283 | "0.05343429531742172\n", 284 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 285 | "에피소드: 11\n", 286 | "0.05278635429752665\n", 287 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 288 | "에피소드: 12\n", 289 | "0.05207874386146988\n", 290 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 291 | "에피소드: 13\n", 292 | "0.05131093120979935\n", 293 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 294 | "에피소드: 14\n", 295 | "0.05048342516236182\n", 296 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 297 | "에피소드: 15\n", 298 | "0.04959769495031635\n", 299 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 300 | "에피소드: 16\n", 301 | "0.04865606709232484\n", 302 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 303 | "에피소드: 17\n", 304 | "0.04766160988053009\n", 305 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 306 | "에피소드: 18\n", 307 | "0.046618012727305924\n", 308 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 309 | "에피소드: 19\n", 310 | "0.045529465783242185\n", 311 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 312 | "에피소드: 20\n", 313 | "0.04440054375731639\n", 314 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 315 | "에피소드: 21\n", 316 | "0.04323609668999023\n", 317 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 318 | "에피소드: 22\n", 319 | "0.04204114949695603\n", 320 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 321 | "에피소드: 23\n", 322 | "0.0408208113716384\n", 323 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 324 | "에피소드: 24\n", 325 | "0.03958019557152842\n", 326 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 327 | "에피소드: 25\n", 328 | "0.03832434968616055\n", 329 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 330 | "에피소드: 26\n", 331 | "0.03705819616729522\n", 332 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 333 | "에피소드: 27\n", 334 | "0.035786482673166475\n", 335 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 336 | "에피소드: 28\n", 337 | "0.03451374162070997\n", 338 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 339 | "에피소드: 29\n", 340 | "0.03324425823770877\n", 341 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 342 | "에피소드: 30\n", 343 | "0.03198204634869278\n", 344 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 345 | "에피소드: 31\n", 346 | "0.030730831104155143\n", 347 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 348 | "에피소드: 32\n", 349 | "0.029494037864122025\n", 350 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 351 | "에피소드: 33\n", 352 | "0.028274786467673507\n", 353 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 354 | "에피소드: 34\n", 355 | "0.027075890154329874\n", 356 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 357 | "에피소드: 35\n", 358 | "0.02589985844703846\n", 359 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 360 | "에피소드: 36\n", 361 | "0.02474890335638047\n", 362 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 363 | "에피소드: 37\n", 364 | "0.023624948318948458\n", 365 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 366 | "에피소드: 38\n", 367 | "0.022529639337513174\n", 368 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 369 | "에피소드: 39\n", 370 | "0.021464357845061732\n", 371 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 372 | "에피소드: 40\n", 373 | "0.02043023486784079\n", 374 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 375 | "에피소드: 41\n", 376 | "0.019428166113352963\n", 377 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 378 | "에피소드: 42\n", 379 | "0.018458827657222066\n", 380 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 381 | "에피소드: 43\n", 382 | "0.017522691947583047\n", 383 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 384 | "에피소드: 44\n", 385 | "0.016620043886945823\n", 386 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 387 | "에피소드: 45\n", 388 | "0.01575099678924119\n", 389 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 390 | "에피소드: 46\n", 391 | "0.01491550804395958\n", 392 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 393 | "에피소드: 47\n", 394 | "0.014113394350067199\n", 395 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 396 | "에피소드: 48\n", 397 | "0.013344346409800645\n", 398 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 399 | "에피소드: 49\n", 400 | "0.012607942996720523\n", 401 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 402 | "에피소드: 50\n", 403 | "0.011903664333717257\n", 404 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 405 | "에피소드: 51\n", 406 | "0.011230904735219371\n", 407 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 408 | "에피소드: 52\n", 409 | "0.010588984483891117\n", 410 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 411 | "에피소드: 53\n", 412 | "0.009977160925812911\n", 413 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 414 | "에피소드: 54\n", 415 | "0.009394638779763542\n", 416 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 417 | "에피소드: 55\n", 418 | "0.00884057966594709\n", 419 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 420 | "에피소드: 56\n", 421 | "0.008314110867547964\n", 422 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 423 | "에피소드: 57\n", 424 | "0.007814333345034785\n", 425 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 426 | "에피소드: 58\n", 427 | "0.0073403290283408085\n", 428 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 429 | "에피소드: 59\n", 430 | "0.006891167416099742\n", 431 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 432 | "에피소드: 60\n", 433 | "0.006465911514145994\n", 434 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 435 | "에피소드: 61\n", 436 | "0.006063623147645414\n", 437 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 438 | "에피소드: 62\n", 439 | "0.00568336768262645\n", 440 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 441 | "에피소드: 63\n", 442 | "0.005324218193444752\n", 443 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 444 | "에피소드: 64\n", 445 | "0.004985259112940121\n", 446 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 447 | "에피소드: 65\n", 448 | "0.004665589401810943\n", 449 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 450 | "에피소드: 66\n", 451 | "0.004364325273144343\n", 452 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 453 | "에피소드: 67\n", 454 | "0.004080602507132602\n", 455 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 456 | "에피소드: 68\n", 457 | "0.003813578389876615\n", 458 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 459 | "에피소드: 69\n", 460 | "0.0035624333088519755\n", 461 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 462 | "에피소드: 70\n", 463 | "0.0033263720361560445\n", 464 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 465 | "에피소드: 71\n", 466 | "0.003104624729090788\n", 467 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 468 | "에피소드: 72\n", 469 | "0.0028964476760231506\n", 470 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 471 | "에피소드: 73\n", 472 | "0.002701123813795725\n", 473 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 474 | "에피소드: 74\n", 475 | "0.0025179630412982545\n", 476 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 477 | "에피소드: 75\n", 478 | "0.0023463023521538284\n", 479 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 480 | "에피소드: 76\n", 481 | "0.002185505807833943\n", 482 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 483 | "에피소드: 77\n", 484 | "0.002034964370928427\n", 485 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 486 | "에피소드: 78\n", 487 | "0.0018940956167549095\n", 488 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 489 | "에피소드: 79\n", 490 | "0.0017623433400003607\n", 491 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 492 | "에피소드: 80\n", 493 | "0.0016391770716805976\n", 494 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 495 | "에피소드: 81\n", 496 | "0.0015240915203390548\n", 497 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 498 | "에피소드: 82\n", 499 | "0.0014166059501411477\n", 500 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 501 | "에피소드: 83\n", 502 | "0.0013162635073069584\n", 503 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 504 | "에피소드: 84\n", 505 | "0.001222630505198108\n", 506 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 507 | "에피소드: 85\n", 508 | "0.0011352956773180711\n", 509 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 510 | "에피소드: 86\n", 511 | "0.0010538694065022058\n", 512 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 513 | "에피소드: 87\n", 514 | "0.0009779829376552751\n", 515 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 516 | "에피소드: 88\n", 517 | "0.0009072875805578029\n", 518 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 519 | "에피소드: 89\n", 520 | "0.0008414539084752315\n", 521 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 522 | "에피소드: 90\n", 523 | "0.0007801709575961935\n", 524 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 525 | "에피소드: 91\n", 526 | "0.0007231454316684038\n", 527 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 528 | "에피소드: 92\n", 529 | "0.0006701009155980486\n", 530 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 531 | "에피소드: 93\n", 532 | "0.0006207771012446406\n", 533 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 534 | "에피소드: 94\n", 535 | "0.0005749290281399366\n", 536 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 537 | "에피소드: 95\n", 538 | "0.0005323263414217516\n", 539 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 540 | "에피소드: 96\n", 541 | "0.000492752568863164\n", 542 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 543 | "에피소드: 97\n", 544 | "0.0004560044185274448\n", 545 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 546 | "에피소드: 98\n", 547 | "0.0004218910982470847\n", 548 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 549 | "에피소드: 99\n", 550 | "0.00039023365784385255\n", 551 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n", 552 | "에피소드: 100\n", 553 | "0.0003608643547524659\n", 554 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n" 555 | ] 556 | } 557 | ], 558 | "source": [ 559 | "# Sarsa 알고리즘으로 미로 빠져나오기\n", 560 | "\n", 561 | "eta = 0.1 # 학습률\n", 562 | "gamma = 0.9 # 시간할인율\n", 563 | "epsilon = 0.5 # ε-greedy 알고리즘 epsilon 초깃값\n", 564 | "v = np.nanmax(Q, axis=1) # 각 상태마다 가치의 최댓값을 계산\n", 565 | "is_continue = True\n", 566 | "episode = 1\n", 567 | "\n", 568 | "while is_continue: # is_continue의 값이 False가 될 때까지 반복\n", 569 | " print(\"에피소드: \" + str(episode))\n", 570 | "\n", 571 | " # ε 값을 조금씩 감소시킴\n", 572 | " epsilon = epsilon / 2\n", 573 | "\n", 574 | " # Sarsa 알고리즘으로 미로를 빠져나온 후, 결과로 나온 행동 히스토리와 Q값을 변수에 저장\n", 575 | " [s_a_history, Q] = goal_maze_ret_s_a_Q(Q, epsilon, eta, gamma, pi_0)\n", 576 | "\n", 577 | " # 상태가치의 변화\n", 578 | " new_v = np.nanmax(Q, axis=1) # 각 상태마다 행동가치의 최댓값을 계산\n", 579 | " print(np.sum(np.abs(new_v - v))) # 상태가치 함수의 변화를 출력\n", 580 | " v = new_v\n", 581 | "\n", 582 | " print(\"목표 지점에 이르기까지 걸린 단계 수는 \" + str(len(s_a_history) - 1) + \"단계입니다\")\n", 583 | "\n", 584 | " # 100 에피소드 반복\n", 585 | " episode = episode + 1\n", 586 | " if episode > 100:\n", 587 | " break\n" 588 | ] 589 | } 590 | ], 591 | "metadata": { 592 | "kernelspec": { 593 | "display_name": "Python 3", 594 | "language": "python", 595 | "name": "python3" 596 | }, 597 | "language_info": { 598 | "codemirror_mode": { 599 | "name": "ipython", 600 | "version": 3 601 | }, 602 | "file_extension": ".py", 603 | "mimetype": "text/x-python", 604 | "name": "python", 605 | "nbconvert_exporter": "python", 606 | "pygments_lexer": "ipython3", 607 | "version": "3.6.6" 608 | } 609 | }, 610 | "nbformat": 4, 611 | "nbformat_minor": 2 612 | } 613 | -------------------------------------------------------------------------------- /program/3_3_digitized_CartPole.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 3.3 다변수 연속값 상태를 이산변수로 변환하기" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# 구현에 사용할 패키지 임포트\n", 17 | "import numpy as np\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "%matplotlib inline\n", 20 | "import gym\n" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "# 상수 정의\n", 30 | "ENV = 'CartPole-v0' # 태스크 이름\n", 31 | "NUM_DIZITIZED = 6 # 각 상태를 이산변수로 변환할 구간 수\n" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "name": "stdout", 41 | "output_type": "stream", 42 | "text": [ 43 | "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "# CartPole 실행\n", 49 | "env = gym.make(ENV) # 태스크 실행 환경 생성\n", 50 | "observation = env.reset() # 환경 초기화\n" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 4, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "# 이산변수 변환에 사용할 구간을 계산\n", 60 | "\n", 61 | "\n", 62 | "def bins(clip_min, clip_max, num):\n", 63 | " '''관측된 상태(연속값)을 이산값으로 변환하는 구간을 계산'''\n", 64 | " return np.linspace(clip_min, clip_max, num + 1)[1:-1]" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 5, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "data": { 74 | "text/plain": [ 75 | "array([-2.4, -1.6, -0.8, 0. , 0.8, 1.6, 2.4])" 76 | ] 77 | }, 78 | "execution_count": 5, 79 | "metadata": {}, 80 | "output_type": "execute_result" 81 | } 82 | ], 83 | "source": [ 84 | "np.linspace(-2.4, 2.4, 6 + 1)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 6, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "data": { 94 | "text/plain": [ 95 | "array([-1.6, -0.8, 0. , 0.8, 1.6])" 96 | ] 97 | }, 98 | "execution_count": 6, 99 | "metadata": {}, 100 | "output_type": "execute_result" 101 | } 102 | ], 103 | "source": [ 104 | "np.linspace(-2.4, 2.4, 6 + 1)[1:-1]" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 10, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "def digitize_state(observation):\n", 114 | " '''관측된 상태(observation)을 이산값으로 변환'''\n", 115 | " cart_pos, cart_v, pole_angle, pole_v = observation\n", 116 | " digitized = [\n", 117 | " np.digitize(cart_pos, bins=bins(-2.4, 2.4, NUM_DIZITIZED)),\n", 118 | " np.digitize(cart_v, bins=bins(-3.0, 3.0, NUM_DIZITIZED)),\n", 119 | " np.digitize(pole_angle, bins=bins(-0.5, 0.5, NUM_DIZITIZED)),\n", 120 | " np.digitize(pole_v, bins=bins(-2.0, 2.0, NUM_DIZITIZED))]\n", 121 | " return sum([x * (NUM_DIZITIZED**i) for i, x in enumerate(digitized)])\n" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 11, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "data": { 131 | "text/plain": [ 132 | "525" 133 | ] 134 | }, 135 | "execution_count": 11, 136 | "metadata": {}, 137 | "output_type": "execute_result" 138 | } 139 | ], 140 | "source": [ 141 | "digitize_state(observation)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": { 148 | "collapsed": true 149 | }, 150 | "outputs": [], 151 | "source": [] 152 | } 153 | ], 154 | "metadata": { 155 | "kernelspec": { 156 | "display_name": "Python 3", 157 | "language": "python", 158 | "name": "python3" 159 | }, 160 | "language_info": { 161 | "codemirror_mode": { 162 | "name": "ipython", 163 | "version": 3 164 | }, 165 | "file_extension": ".py", 166 | "mimetype": "text/x-python", 167 | "name": "python", 168 | "nbconvert_exporter": "python", 169 | "pygments_lexer": "ipython3", 170 | "version": "3.6.6" 171 | } 172 | }, 173 | "nbformat": 4, 174 | "nbformat_minor": 2 175 | } 176 | -------------------------------------------------------------------------------- /program/3_4_Qlearning_CartPole.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 3.4 Q러닝 구현" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# 구현에 사용할 패키지 임포트\n", 17 | "import numpy as np\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "%matplotlib inline\n", 20 | "import gym\n" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "# 애니메이션을 만드는 함수\n", 30 | "# 참고 URL http://nbviewer.jupyter.org/github/patrickmineault\n", 31 | "# /xcorr-notebooks/blob/master/Render%20OpenAI%20gym%20as%20GIF.ipynb\n", 32 | "from JSAnimation.IPython_display import display_animation\n", 33 | "from matplotlib import animation\n", 34 | "from IPython.display import display\n", 35 | "\n", 36 | "\n", 37 | "def display_frames_as_gif(frames):\n", 38 | " \"\"\"\n", 39 | " Displays a list of frames as a gif, with controls\n", 40 | " \"\"\"\n", 41 | " plt.figure(figsize=(frames[0].shape[1]/72.0, frames[0].shape[0]/72.0),\n", 42 | " dpi=72)\n", 43 | " patch = plt.imshow(frames[0])\n", 44 | " plt.axis('off')\n", 45 | "\n", 46 | " def animate(i):\n", 47 | " patch.set_data(frames[i])\n", 48 | "\n", 49 | " anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames),\n", 50 | " interval=50)\n", 51 | "\n", 52 | " anim.save('movie_cartpole.mp4') # 애니메이션을 저장하는 부분\n", 53 | " display(display_animation(anim, default_mode='loop'))\n", 54 | " " 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "# 상수 정의\n", 64 | "ENV = 'CartPole-v0' # 태스크 이름\n", 65 | "NUM_DIZITIZED = 6 # 각 상태를 이산변수로 변환할 구간 수\n", 66 | "GAMMA = 0.99 # 시간할인율\n", 67 | "ETA = 0.5 # 학습률\n", 68 | "MAX_STEPS = 200 # 1에피소드 당 최대 단계 수\n", 69 | "NUM_EPISODES = 1000 # 최대 에피소드 수\n" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 4, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "class Agent:\n", 79 | " '''CartPole 에이전트 역할을 할 클래스, 봉 달린 수레이다'''\n", 80 | "\n", 81 | " def __init__(self, num_states, num_actions):\n", 82 | " self.brain = Brain(num_states, num_actions) # 에이전트가 행동을 결정하는 두뇌 역할\n", 83 | "\n", 84 | " def update_Q_function(self, observation, action, reward, observation_next):\n", 85 | " '''Q함수 수정'''\n", 86 | " self.brain.update_Q_table(\n", 87 | " observation, action, reward, observation_next)\n", 88 | "\n", 89 | " def get_action(self, observation, step):\n", 90 | " '''행동 결정'''\n", 91 | " action = self.brain.decide_action(observation, step)\n", 92 | " return action\n", 93 | " " 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 5, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "class Brain:\n", 103 | " '''에이전트의 두뇌 역할을 하는 클래스, Q러닝을 실제 수행'''\n", 104 | "\n", 105 | " def __init__(self, num_states, num_actions):\n", 106 | " self.num_actions = num_actions # 행동의 가짓수(왼쪽, 오른쪽)를 구함\n", 107 | "\n", 108 | " # Q테이블을 생성. 줄 수는 상태를 구간수^4(4는 변수의 수)가지 값 중 하나로 변환한 값, 열 수는 행동의 가짓수\n", 109 | " self.q_table = np.random.uniform(low=0, high=1, size=(\n", 110 | " NUM_DIZITIZED**num_states, num_actions))\n", 111 | "\n", 112 | "\n", 113 | " def bins(self, clip_min, clip_max, num):\n", 114 | " '''관측된 상태(연속값)를 이산변수로 변환하는 구간을 계산'''\n", 115 | " return np.linspace(clip_min, clip_max, num + 1)[1:-1]\n", 116 | "\n", 117 | " def digitize_state(self, observation):\n", 118 | " '''관측된 상태 observation을 이산변수로 변환'''\n", 119 | " cart_pos, cart_v, pole_angle, pole_v = observation\n", 120 | " digitized = [\n", 121 | " np.digitize(cart_pos, bins=self.bins(-2.4, 2.4, NUM_DIZITIZED)),\n", 122 | " np.digitize(cart_v, bins=self.bins(-3.0, 3.0, NUM_DIZITIZED)),\n", 123 | " np.digitize(pole_angle, bins=self.bins(-0.5, 0.5, NUM_DIZITIZED)),\n", 124 | " np.digitize(pole_v, bins=self.bins(-2.0, 2.0, NUM_DIZITIZED))\n", 125 | " ]\n", 126 | " return sum([x * (NUM_DIZITIZED**i) for i, x in enumerate(digitized)])\n", 127 | "\n", 128 | " def update_Q_table(self, observation, action, reward, observation_next):\n", 129 | " '''Q러닝으로 Q테이블을 수정'''\n", 130 | " state = self.digitize_state(observation) # 상태를 이산변수로 변환\n", 131 | " state_next = self.digitize_state(observation_next) # 다음 상태를 이산변수로 변환\n", 132 | " Max_Q_next = max(self.q_table[state_next][:])\n", 133 | " self.q_table[state, action] = self.q_table[state, action] + \\\n", 134 | " ETA * (reward + GAMMA * Max_Q_next - self.q_table[state, action])\n", 135 | "\n", 136 | " def decide_action(self, observation, episode):\n", 137 | " '''ε-greedy 알고리즘을 적용하여 서서히 최적행동의 비중을 늘림'''\n", 138 | " state = self.digitize_state(observation)\n", 139 | " epsilon = 0.5 * (1 / (episode + 1))\n", 140 | "\n", 141 | " if epsilon <= np.random.uniform(0, 1):\n", 142 | " action = np.argmax(self.q_table[state][:])\n", 143 | " else:\n", 144 | " action = np.random.choice(self.num_actions) # 0,1 두 가지 행동 중 하나를 무작위로 선택\n", 145 | " return action\n", 146 | " " 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 8, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "class Environment:\n", 156 | " '''CartPole을 실행하는 환경 역할을 하는 클래스'''\n", 157 | "\n", 158 | " def __init__(self):\n", 159 | " self.env = gym.make(ENV) # 실행할 태스크를 설정\n", 160 | " num_states = self.env.observation_space.shape[0] # 태스크의 상태 변수 수를 구함\n", 161 | " num_actions = self.env.action_space.n # 가능한 행동 수를 구함\n", 162 | " self.agent = Agent(num_states, num_actions) # 에이전트 객체를 생성\n", 163 | "\n", 164 | " def run(self):\n", 165 | " '''실행'''\n", 166 | " complete_episodes = 0 # 성공한(195단계 이상 버틴) 에피소드 수\n", 167 | " is_episode_final = False # 마지막 에피소드 여부\n", 168 | " frames = [] # 애니메이션을 만드는데 사용할 이미지를 저장하는 변수\n", 169 | "\n", 170 | " for episode in range(NUM_EPISODES): # 에피소드 수 만큼 반복\n", 171 | " observation = self.env.reset() # 환경 초기화\n", 172 | "\n", 173 | " for step in range(MAX_STEPS): # 1 에피소드에 해당하는 반복\n", 174 | "\n", 175 | " if is_episode_final is True: # 마지막 에피소드이면 frames에 각 단계의 이미지를 저장\n", 176 | " frames.append(self.env.render(mode='rgb_array'))\n", 177 | "\n", 178 | " # 행동을 선택\n", 179 | " action = self.agent.get_action(observation, episode)\n", 180 | "\n", 181 | " # 행동 a_t를 실행하여 s_{t+1}, r_{t+1}을 계산\n", 182 | " observation_next, _, done, _ = self.env.step(\n", 183 | " action) # reward, info는 사용하지 않으므로 _로 처리함\n", 184 | "\n", 185 | " # 보상을 부여\n", 186 | " if done: # 200단계를 넘어서거나 일정 각도 이상 기울면 done의 값이 True가 됨\n", 187 | " if step < 195:\n", 188 | " reward = -1 # 봉이 쓰러지면 페널티로 보상 -1 부여\n", 189 | " complete_episodes = 0 # 195단계 이상 버티면 해당 에피소드를 성공 처리\n", 190 | " else:\n", 191 | " reward = 1 # 쓰러지지 않고 에피소드를 끝내면 보상 1 부여\n", 192 | " complete_episodes += 1 # 에피소드 연속 성공 기록을 업데이트\n", 193 | " else:\n", 194 | " reward = 0 # 에피소드 중에는 보상이 0\n", 195 | "\n", 196 | " # 다음 단계의 상태 observation_next로 Q함수를 수정\n", 197 | " self.agent.update_Q_function(\n", 198 | " observation, action, reward, observation_next)\n", 199 | "\n", 200 | " # 다음 단계 상태 관측\n", 201 | " observation = observation_next\n", 202 | "\n", 203 | " # 에피소드 마무리\n", 204 | " if done:\n", 205 | " print('{0} Episode: Finished after {1} time steps'.format(\n", 206 | " episode, step + 1))\n", 207 | " break\n", 208 | "\n", 209 | " if is_episode_final is True: # 마지막 에피소드에서는 애니메이션을 만들고 저장\n", 210 | " display_frames_as_gif(frames)\n", 211 | " break\n", 212 | "\n", 213 | " if complete_episodes >= 10: # 10 에피소드 연속으로 성공한 경우\n", 214 | " print('10 에피소드 연속 성공')\n", 215 | " is_episode_final = True # 다음 에피소드가 마지막 에피소드가 됨 " 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 9, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "name": "stdout", 225 | "output_type": "stream", 226 | "text": [ 227 | "0 Episode: Finished after 12 time steps\n", 228 | "1 Episode: Finished after 21 time steps\n", 229 | "2 Episode: Finished after 28 time steps\n", 230 | "3 Episode: Finished after 23 time steps\n", 231 | "4 Episode: Finished after 27 time steps\n", 232 | "5 Episode: Finished after 25 time steps\n", 233 | "6 Episode: Finished after 52 time steps\n", 234 | "7 Episode: Finished after 17 time steps\n", 235 | "8 Episode: Finished after 19 time steps\n", 236 | "9 Episode: Finished after 52 time steps\n", 237 | "10 Episode: Finished after 143 time steps\n", 238 | "11 Episode: Finished after 32 time steps\n", 239 | "12 Episode: Finished after 16 time steps\n", 240 | "13 Episode: Finished after 59 time steps\n", 241 | "14 Episode: Finished after 102 time steps\n", 242 | "15 Episode: Finished after 183 time steps\n", 243 | "16 Episode: Finished after 148 time steps\n", 244 | "17 Episode: Finished after 80 time steps\n", 245 | "18 Episode: Finished after 14 time steps\n", 246 | "19 Episode: Finished after 200 time steps\n", 247 | "20 Episode: Finished after 80 time steps\n", 248 | "21 Episode: Finished after 38 time steps\n", 249 | "22 Episode: Finished after 26 time steps\n", 250 | "23 Episode: Finished after 75 time steps\n", 251 | "24 Episode: Finished after 81 time steps\n", 252 | "25 Episode: Finished after 27 time steps\n", 253 | "26 Episode: Finished after 32 time steps\n", 254 | "27 Episode: Finished after 41 time steps\n", 255 | "28 Episode: Finished after 11 time steps\n", 256 | "29 Episode: Finished after 61 time steps\n", 257 | "30 Episode: Finished after 36 time steps\n", 258 | "31 Episode: Finished after 189 time steps\n", 259 | "32 Episode: Finished after 200 time steps\n", 260 | "33 Episode: Finished after 37 time steps\n", 261 | "34 Episode: Finished after 10 time steps\n", 262 | "35 Episode: Finished after 13 time steps\n", 263 | "36 Episode: Finished after 30 time steps\n", 264 | "37 Episode: Finished after 12 time steps\n", 265 | "38 Episode: Finished after 85 time steps\n", 266 | "39 Episode: Finished after 192 time steps\n", 267 | "40 Episode: Finished after 48 time steps\n", 268 | "41 Episode: Finished after 200 time steps\n", 269 | "42 Episode: Finished after 196 time steps\n", 270 | "43 Episode: Finished after 103 time steps\n", 271 | "44 Episode: Finished after 162 time steps\n", 272 | "45 Episode: Finished after 72 time steps\n", 273 | "46 Episode: Finished after 200 time steps\n", 274 | "47 Episode: Finished after 16 time steps\n", 275 | "48 Episode: Finished after 16 time steps\n", 276 | "49 Episode: Finished after 20 time steps\n", 277 | "50 Episode: Finished after 20 time steps\n", 278 | "51 Episode: Finished after 20 time steps\n", 279 | "52 Episode: Finished after 10 time steps\n", 280 | "53 Episode: Finished after 200 time steps\n", 281 | "54 Episode: Finished after 49 time steps\n", 282 | "55 Episode: Finished after 55 time steps\n", 283 | "56 Episode: Finished after 47 time steps\n", 284 | "57 Episode: Finished after 15 time steps\n", 285 | "58 Episode: Finished after 40 time steps\n", 286 | "59 Episode: Finished after 49 time steps\n", 287 | "60 Episode: Finished after 40 time steps\n", 288 | "61 Episode: Finished after 42 time steps\n", 289 | "62 Episode: Finished after 41 time steps\n", 290 | "63 Episode: Finished after 54 time steps\n", 291 | "64 Episode: Finished after 36 time steps\n", 292 | "65 Episode: Finished after 73 time steps\n", 293 | "66 Episode: Finished after 69 time steps\n", 294 | "67 Episode: Finished after 51 time steps\n", 295 | "68 Episode: Finished after 32 time steps\n", 296 | "69 Episode: Finished after 51 time steps\n", 297 | "70 Episode: Finished after 75 time steps\n", 298 | "71 Episode: Finished after 9 time steps\n", 299 | "72 Episode: Finished after 17 time steps\n", 300 | "73 Episode: Finished after 10 time steps\n", 301 | "74 Episode: Finished after 68 time steps\n", 302 | "75 Episode: Finished after 22 time steps\n", 303 | "76 Episode: Finished after 12 time steps\n", 304 | "77 Episode: Finished after 24 time steps\n", 305 | "78 Episode: Finished after 18 time steps\n", 306 | "79 Episode: Finished after 19 time steps\n", 307 | "80 Episode: Finished after 60 time steps\n", 308 | "81 Episode: Finished after 44 time steps\n", 309 | "82 Episode: Finished after 78 time steps\n", 310 | "83 Episode: Finished after 97 time steps\n", 311 | "84 Episode: Finished after 119 time steps\n", 312 | "85 Episode: Finished after 55 time steps\n", 313 | "86 Episode: Finished after 51 time steps\n", 314 | "87 Episode: Finished after 118 time steps\n", 315 | "88 Episode: Finished after 168 time steps\n", 316 | "89 Episode: Finished after 62 time steps\n", 317 | "90 Episode: Finished after 200 time steps\n", 318 | "91 Episode: Finished after 25 time steps\n", 319 | "92 Episode: Finished after 162 time steps\n", 320 | "93 Episode: Finished after 200 time steps\n", 321 | "94 Episode: Finished after 77 time steps\n", 322 | "95 Episode: Finished after 200 time steps\n", 323 | "96 Episode: Finished after 107 time steps\n", 324 | "97 Episode: Finished after 91 time steps\n", 325 | "98 Episode: Finished after 159 time steps\n", 326 | "99 Episode: Finished after 73 time steps\n", 327 | "100 Episode: Finished after 94 time steps\n", 328 | "101 Episode: Finished after 40 time steps\n", 329 | "102 Episode: Finished after 145 time steps\n", 330 | "103 Episode: Finished after 87 time steps\n", 331 | "104 Episode: Finished after 200 time steps\n", 332 | "105 Episode: Finished after 198 time steps\n", 333 | "106 Episode: Finished after 87 time steps\n", 334 | "107 Episode: Finished after 107 time steps\n", 335 | "108 Episode: Finished after 108 time steps\n", 336 | "109 Episode: Finished after 110 time steps\n", 337 | "110 Episode: Finished after 26 time steps\n", 338 | "111 Episode: Finished after 196 time steps\n", 339 | "112 Episode: Finished after 190 time steps\n", 340 | "113 Episode: Finished after 200 time steps\n", 341 | "114 Episode: Finished after 187 time steps\n", 342 | "115 Episode: Finished after 51 time steps\n", 343 | "116 Episode: Finished after 84 time steps\n", 344 | "117 Episode: Finished after 184 time steps\n", 345 | "118 Episode: Finished after 10 time steps\n", 346 | "119 Episode: Finished after 96 time steps\n", 347 | "120 Episode: Finished after 93 time steps\n", 348 | "121 Episode: Finished after 187 time steps\n", 349 | "122 Episode: Finished after 171 time steps\n", 350 | "123 Episode: Finished after 200 time steps\n", 351 | "124 Episode: Finished after 145 time steps\n", 352 | "125 Episode: Finished after 161 time steps\n", 353 | "126 Episode: Finished after 96 time steps\n", 354 | "127 Episode: Finished after 107 time steps\n", 355 | "128 Episode: Finished after 146 time steps\n", 356 | "129 Episode: Finished after 192 time steps\n", 357 | "130 Episode: Finished after 129 time steps\n", 358 | "131 Episode: Finished after 89 time steps\n", 359 | "132 Episode: Finished after 126 time steps\n", 360 | "133 Episode: Finished after 95 time steps\n", 361 | "134 Episode: Finished after 52 time steps\n", 362 | "135 Episode: Finished after 37 time steps\n", 363 | "136 Episode: Finished after 200 time steps\n", 364 | "137 Episode: Finished after 73 time steps\n", 365 | "138 Episode: Finished after 167 time steps\n", 366 | "139 Episode: Finished after 152 time steps\n", 367 | "140 Episode: Finished after 78 time steps\n", 368 | "141 Episode: Finished after 200 time steps\n", 369 | "142 Episode: Finished after 200 time steps\n", 370 | "143 Episode: Finished after 100 time steps\n", 371 | "144 Episode: Finished after 167 time steps\n", 372 | "145 Episode: Finished after 200 time steps\n", 373 | "146 Episode: Finished after 200 time steps\n", 374 | "147 Episode: Finished after 200 time steps\n", 375 | "148 Episode: Finished after 183 time steps\n", 376 | "149 Episode: Finished after 23 time steps\n", 377 | "150 Episode: Finished after 200 time steps\n", 378 | "151 Episode: Finished after 200 time steps\n", 379 | "152 Episode: Finished after 200 time steps\n", 380 | "153 Episode: Finished after 200 time steps\n", 381 | "154 Episode: Finished after 171 time steps\n", 382 | "155 Episode: Finished after 200 time steps\n", 383 | "156 Episode: Finished after 200 time steps\n", 384 | "157 Episode: Finished after 131 time steps\n", 385 | "158 Episode: Finished after 200 time steps\n", 386 | "159 Episode: Finished after 200 time steps\n", 387 | "160 Episode: Finished after 200 time steps\n", 388 | "161 Episode: Finished after 200 time steps\n", 389 | "162 Episode: Finished after 200 time steps\n", 390 | "163 Episode: Finished after 200 time steps\n", 391 | "164 Episode: Finished after 200 time steps\n", 392 | "165 Episode: Finished after 200 time steps\n", 393 | "166 Episode: Finished after 200 time steps\n", 394 | "167 Episode: Finished after 200 time steps\n", 395 | "10 에피소드 연속 성공\n" 396 | ] 397 | }, 398 | { 399 | "ename": "NotImplementedError", 400 | "evalue": "abstract", 401 | "output_type": "error", 402 | "traceback": [ 403 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 404 | "\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)", 405 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# main\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mcartpole_env\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mEnvironment\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mcartpole_env\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 406 | "\u001b[0;32m\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_episode_final\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# 마지막 에피소드이면 frames에 각 단계의 이미지를 저장\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0mframes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'rgb_array'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# 행동을 선택\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 407 | "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/gym/core.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, mode, **kwargs)\u001b[0m\n\u001b[1;32m 273\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 274\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'human'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 275\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 276\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 277\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 408 | "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/gym/envs/classic_control/cartpole.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, mode)\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mviewer\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclassic_control\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mrendering\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 151\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mviewer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrendering\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mViewer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscreen_width\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscreen_height\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 152\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mcartwidth\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcartwidth\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcartheight\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mcartheight\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[0maxleoffset\u001b[0m \u001b[0;34m=\u001b[0m\u001b[0mcartheight\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m4.0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 409 | "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/gym/envs/classic_control/rendering.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, width, height, display)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwidth\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwidth\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheight\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mheight\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpyglet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mWindow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwidth\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwidth\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mheight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mheight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdisplay\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdisplay\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 52\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_close\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow_closed_by_user\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misopen\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 410 | "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/pyglet/window/__init__.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, width, height, caption, resizable, style, fullscreen, visible, vsync, display, screen, config, context, mode)\u001b[0m\n\u001b[1;32m 502\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 503\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mscreen\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 504\u001b[0;31m \u001b[0mscreen\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdisplay\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_default_screen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 505\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 506\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 411 | "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/pyglet/canvas/base.py\u001b[0m in \u001b[0;36mget_default_screen\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0mrtype\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;32mclass\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mScreen\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m '''\n\u001b[0;32m---> 73\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_screens\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 74\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_windows\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 412 | "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/pyglet/canvas/base.py\u001b[0m in \u001b[0;36mget_screens\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0mrtype\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlist\u001b[0m \u001b[0mof\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;32mclass\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mScreen\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m '''\n\u001b[0;32m---> 65\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'abstract'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 66\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_default_screen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 413 | "\u001b[0;31mNotImplementedError\u001b[0m: abstract" 414 | ] 415 | } 416 | ], 417 | "source": [ 418 | "# main\n", 419 | "cartpole_env = Environment()\n", 420 | "cartpole_env.run()\n" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": null, 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [] 429 | } 430 | ], 431 | "metadata": { 432 | "kernelspec": { 433 | "display_name": "Python 3", 434 | "language": "python", 435 | "name": "python3" 436 | }, 437 | "language_info": { 438 | "codemirror_mode": { 439 | "name": "ipython", 440 | "version": 3 441 | }, 442 | "file_extension": ".py", 443 | "mimetype": "text/x-python", 444 | "name": "python", 445 | "nbconvert_exporter": "python", 446 | "pygments_lexer": "ipython3", 447 | "version": "3.6.4" 448 | } 449 | }, 450 | "nbformat": 4, 451 | "nbformat_minor": 2 452 | } 453 | -------------------------------------------------------------------------------- /program/4_3_PyTorch_MNIST.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 4.3 파이토치로 MNIST 이미지 분류 구현하기" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 2, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# 숫자 손글씨 이미지 데이터 집합 MNIST 다운로드\n", 17 | "\n", 18 | "from sklearn.datasets import fetch_mldata\n", 19 | "\n", 20 | "mnist = fetch_mldata('MNIST original', data_home=\".\") \n", 21 | "# data_home에 다운로드 받은 경로를 지정" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 5, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# 1. 데이터 전처리 (이미지 데이터와 레이블 데이터 분리, 정규화)\n", 31 | "\n", 32 | "X = mnist.data / 255 # 0-255값을 [0,1] 구간으로 정규화\n", 33 | "y = mnist.target\n" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 6, 39 | "metadata": { 40 | "scrolled": true 41 | }, 42 | "outputs": [ 43 | { 44 | "name": "stdout", 45 | "output_type": "stream", 46 | "text": [ 47 | "이 이미지 데이터의 레이블은 0이다\n" 48 | ] 49 | }, 50 | { 51 | "data": { 52 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADi5JREFUeJzt3X+IXfWZx/HPo22CmkbUYhyN2bQlLi2iEzMGoWHNulhcDSRFognipOzSyR8NWFlkVUYTWItFNLsqGEx1aIJpkmp0E8u6aXFEWxBxjFJt0x+hZNPZDBljxEwQDCbP/jEnyyTO/Z479557z5l53i8Ic+957rnn8TqfOefe77nna+4uAPGcVXYDAMpB+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBPWldm7MzDidEGgxd7d6HtfUnt/MbjKzP5rZPjO7t5nnAtBe1ui5/WZ2tqQ/SbpR0qCktyWtdPffJ9Zhzw+0WDv2/Asl7XP3v7j7cUnbJC1t4vkAtFEz4b9M0l/H3B/Mlp3GzHrMbMDMBprYFoCCNfOB33iHFl84rHf3jZI2Shz2A1XSzJ5/UNLlY+7PlnSwuXYAtEsz4X9b0jwz+5qZTZO0QtKuYtoC0GoNH/a7++dmtkbSbklnS+pz998V1hmAlmp4qK+hjfGeH2i5tpzkA2DyIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gqLZO0Y2pZ8GCBcn6mjVrata6u7uT627evDlZf/LJJ5P1PXv2JOvRsecHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaCamqXXzPZLGpF0QtLn7t6V83hm6Z1kOjs7k/X+/v5kfebMmUW2c5pPPvkkWb/oootatu0qq3eW3iJO8vl7dz9cwPMAaCMO+4Ggmg2/S/qlmb1jZj1FNASgPZo97P+2ux80s4sl/crM/uDub4x9QPZHgT8MQMU0ted394PZz2FJL0laOM5jNrp7V96HgQDaq+Hwm9l5ZvaVU7clfUfSB0U1BqC1mjnsnyXpJTM79Tw/c/f/LqQrAC3X1Dj/hDfGOH/lLFz4hXdqp9mxY0eyfumllybrqd+vkZGR5LrHjx9P1vPG8RctWlSzlvdd/7xtV1m94/wM9QFBEX4gKMIPBEX4gaAIPxAU4QeCYqhvCjj33HNr1q655prkus8991yyPnv27GQ9O8+jptTvV95w2yOPPJKsb9u2LVlP9dbb25tc9+GHH07Wq4yhPgBJhB8IivADQRF+ICjCDwRF+IGgCD8QFFN0TwFPP/10zdrKlSvb2MnE5J2DMGPGjGT99ddfT9YXL15cs3bVVVcl142APT8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBMU4/ySwYMGCZP2WW26pWcv7vn2evLH0l19+OVl/9NFHa9YOHjyYXPfdd99N1j/++ONk/YYbbqhZa/Z1mQrY8wNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAULnX7TezPklLJA27+5XZsgslbZc0V9J+Sbe5e3rQVVy3v5bOzs5kvb+/P1mfOXNmw9t+5ZVXkvW86wFcf/31yXrqe/PPPPNMct0PP/wwWc9z4sSJmrVPP/00uW7ef1fenANlKvK6/T+VdNMZy+6V9Kq7z5P0anYfwCSSG353f0PSkTMWL5W0Kbu9SdKygvsC0GKNvuef5e5DkpT9vLi4lgC0Q8vP7TezHkk9rd4OgIlpdM9/yMw6JCn7OVzrge6+0d273L2rwW0BaIFGw79L0qrs9ipJO4tpB0C75IbfzLZKelPS35rZoJn9s6QfS7rRzP4s6cbsPoBJJHecv9CNBR3nv+KKK5L1tWvXJusrVqxI1g8fPlyzNjQ0lFz3oYceStZfeOGFZL3KUuP8eb/327dvT9bvuOOOhnpqhyLH+QFMQYQfCIrwA0ERfiAowg8ERfiBoLh0dwGmT5+erKcuXy1JN998c7I+MjKSrHd3d9esDQwMJNc955xzkvWo5syZU3YLLceeHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCYpy/APPnz0/W88bx8yxdujRZz5tGGxgPe34gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIpx/gKsX78+WTdLX0k5b5yecfzGnHVW7X3byZMn29hJNbHnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgcsf5zaxP0hJJw+5+ZbZsnaTvS/owe9j97v5frWqyCpYsWVKz1tnZmVw3bzroXbt2NdQT0lJj+Xn/T957772i26mcevb8P5V00zjL/93dO7N/Uzr4wFSUG353f0PSkTb0AqCNmnnPv8bMfmtmfWZ2QWEdAWiLRsO/QdI3JHVKGpL0WK0HmlmPmQ2YWXrSOABt1VD43f2Qu59w95OSfiJpYeKxG929y927Gm0SQPEaCr+ZdYy5+11JHxTTDoB2qWeob6ukxZK+amaDktZKWmxmnZJc0n5Jq1vYI4AWyA2/u68cZ/GzLeil0lLz2E+bNi257vDwcLK+ffv2hnqa6qZPn56sr1u3ruHn7u/vT9bvu+++hp97suAMPyAowg8ERfiBoAg/EBThB4Ii/EBQXLq7DT777LNkfWhoqE2dVEveUF5vb2+yfs899yTrg4ODNWuPPVbzjHRJ0rFjx5L1qYA9PxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ExTh/G0S+NHfqsuZ54/S33357sr5z585k/dZbb03Wo2PPDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBMc5fJzNrqCZJy5YtS9bvuuuuhnqqgrvvvjtZf+CBB2rWzj///OS6W7ZsSda7u7uTdaSx5weCIvxAUIQfCIrwA0ERfiAowg8ERfiBoHLH+c3sckmbJV0i6aSkje7+uJldKGm7pLmS9ku6zd0/bl2r5XL3hmqSdMkllyTrTzzxRLLe19eXrH/00Uc1a9ddd11y3TvvvDNZv/rqq5P12bNnJ+sHDhyoWdu9e3dy3aeeeipZR3Pq2fN/Lulf3P2bkq6T9AMz+5akeyW96u7zJL2a3QcwSeSG392H3H1PdntE0l5Jl0laKmlT9rBNktKnsQGolAm95zezuZLmS3pL0ix3H5JG/0BIurjo5gC0Tt3n9pvZDEk7JP3Q3Y/mnc8+Zr0eST2NtQegVera85vZlzUa/C3u/mK2+JCZdWT1DknD463r7hvdvcvdu4poGEAxcsNvo7v4ZyXtdff1Y0q7JK3Kbq+SlL6UKoBKsbxhKjNbJOnXkt7X6FCfJN2v0ff9P5c0R9IBScvd/UjOc6U3VmHLly+vWdu6dWtLt33o0KFk/ejRozVr8+bNK7qd07z55pvJ+muvvVaz9uCDDxbdDiS5e13vyXPf87v7byTVerJ/mEhTAKqDM/yAoAg/EBThB4Ii/EBQhB8IivADQeWO8xe6sUk8zp/66urzzz+fXPfaa69tatt5p1I38/8w9XVgSdq2bVuyPpkvOz5V1TvOz54fCIrwA0ERfiAowg8ERfiBoAg/EBThB4JinL8AHR0dyfrq1auT9d7e3mS9mXH+xx9/PLnuhg0bkvV9+/Yl66gexvkBJBF+ICjCDwRF+IGgCD8QFOEHgiL8QFCM8wNTDOP8AJIIPxAU4QeCIvxAUIQfCIrwA0ERfiCo3PCb2eVm9pqZ7TWz35nZXdnydWb2v2b2Xvbv5ta3C6AouSf5mFmHpA5332NmX5H0jqRlkm6TdMzdH617Y5zkA7RcvSf5fKmOJxqSNJTdHjGzvZIua649AGWb0Ht+M5srab6kt7JFa8zst2bWZ2YX1Finx8wGzGygqU4BFKruc/vNbIak1yX9yN1fNLNZkg5Lckn/ptG3Bv+U8xwc9gMtVu9hf13hN7MvS/qFpN3uvn6c+lxJv3D3K3Oeh/ADLVbYF3ts9NKxz0raOzb42QeBp3xX0gcTbRJAeer5tH+RpF9Lel/SyWzx/ZJWSurU6GH/fkmrsw8HU8/Fnh9osUIP+4tC+IHW4/v8AJIIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQeVewLNghyX9z5j7X82WVVFVe6tqXxK9NarI3v6m3ge29fv8X9i42YC7d5XWQEJVe6tqXxK9Naqs3jjsB4Ii/EBQZYd/Y8nbT6lqb1XtS6K3RpXSW6nv+QGUp+w9P4CSlBJ+M7vJzP5oZvvM7N4yeqjFzPab2fvZzMOlTjGWTYM2bGYfjFl2oZn9ysz+nP0cd5q0knqrxMzNiZmlS33tqjbjddsP+83sbEl/knSjpEFJb0ta6e6/b2sjNZjZfkld7l76mLCZ/Z2kY5I2n5oNycwekXTE3X+c/eG8wN3/tSK9rdMEZ25uUW+1Zpb+nkp87Yqc8boIZez5F0ra5+5/cffjkrZJWlpCH5Xn7m9IOnLG4qWSNmW3N2n0l6ftavRWCe4+5O57stsjkk7NLF3qa5foqxRlhP8ySX8dc39Q1Zry2yX90szeMbOespsZx6xTMyNlPy8uuZ8z5c7c3E5nzCxdmdeukRmvi1ZG+MebTaRKQw7fdvdrJP2jpB9kh7eozwZJ39DoNG5Dkh4rs5lsZukdkn7o7kfL7GWscfoq5XUrI/yDki4fc3+2pIMl9DEudz+Y/RyW9JJG36ZUyaFTk6RmP4dL7uf/ufshdz/h7icl/UQlvnbZzNI7JG1x9xezxaW/duP1VdbrVkb435Y0z8y+ZmbTJK2QtKuEPr7AzM7LPoiRmZ0n6Tuq3uzDuyStym6vkrSzxF5OU5WZm2vNLK2SX7uqzXhdykk+2VDGf0g6W1Kfu/+o7U2Mw8y+rtG9vTT6jcefldmbmW2VtFij3/o6JGmtpP+U9HNJcyQdkLTc3dv+wVuN3hZrgjM3t6i3WjNLv6USX7siZ7wupB/O8ANi4gw/ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANB/R/7QknxGq+fLwAAAABJRU5ErkJggg==\n", 53 | "text/plain": [ 54 | "
" 55 | ] 56 | }, 57 | "metadata": { 58 | "needs_background": "light" 59 | }, 60 | "output_type": "display_data" 61 | } 62 | ], 63 | "source": [ 64 | "# 첫 번째 데이터를 시각화\n", 65 | "\n", 66 | "import matplotlib.pyplot as plt\n", 67 | "% matplotlib inline\n", 68 | "\n", 69 | "plt.imshow(X[0].reshape(28, 28), cmap='gray')\n", 70 | "print(\"이 이미지 데이터의 레이블은 {:.0f}이다\".format(y[0]))" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 7, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# 2. DataLoader로 변환\n", 80 | "\n", 81 | "import torch\n", 82 | "from torch.utils.data import TensorDataset, DataLoader\n", 83 | "from sklearn.model_selection import train_test_split\n", 84 | "\n", 85 | "# 2.1 데이터를 훈련 데이터와 테스트 데이터로 분할(6:1 비율)\n", 86 | "X_train, X_test, y_train, y_test = train_test_split(\n", 87 | " X, y, test_size=1/7, random_state=0)\n", 88 | "\n", 89 | "# 2.2 데이터를 파이토치 텐서로 변환\n", 90 | "X_train = torch.Tensor(X_train)\n", 91 | "X_test = torch.Tensor(X_test)\n", 92 | "y_train = torch.LongTensor(y_train)\n", 93 | "y_test = torch.LongTensor(y_test)\n", 94 | "\n", 95 | "# 2.3 데이터와 정답 레이블을 하나로 묶어 Dataset으로 만듬\n", 96 | "ds_train = TensorDataset(X_train, y_train)\n", 97 | "ds_test = TensorDataset(X_test, y_test)\n", 98 | "\n", 99 | "# 2.4 미니배치 크기를 지정하여 DataLoader 객체로 변환\n", 100 | "# Chainer의 iterators.SerialIterator와 비슷함\n", 101 | "loader_train = DataLoader(ds_train, batch_size=64, shuffle=True)\n", 102 | "loader_test = DataLoader(ds_test, batch_size=64, shuffle=False)\n" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 8, 108 | "metadata": {}, 109 | "outputs": [ 110 | { 111 | "name": "stdout", 112 | "output_type": "stream", 113 | "text": [ 114 | "Sequential(\n", 115 | " (fc1): Linear(in_features=784, out_features=100, bias=True)\n", 116 | " (relu1): ReLU()\n", 117 | " (fc2): Linear(in_features=100, out_features=100, bias=True)\n", 118 | " (relu2): ReLU()\n", 119 | " (fc3): Linear(in_features=100, out_features=10, bias=True)\n", 120 | ")\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "# 3. 신경망 구성\n", 126 | "# Keras 스타일\n", 127 | "\n", 128 | "from torch import nn\n", 129 | "\n", 130 | "model = nn.Sequential()\n", 131 | "model.add_module('fc1', nn.Linear(28*28*1, 100))\n", 132 | "model.add_module('relu1', nn.ReLU())\n", 133 | "model.add_module('fc2', nn.Linear(100, 100))\n", 134 | "model.add_module('relu2', nn.ReLU())\n", 135 | "model.add_module('fc3', nn.Linear(100, 10))\n", 136 | "\n", 137 | "print(model)\n" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 9, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "# 4. 오차함수 및 최적화 기법 설정\n", 147 | "\n", 148 | "from torch import optim\n", 149 | "\n", 150 | "# 오차함수 선택\n", 151 | "loss_fn = nn.CrossEntropyLoss() # criterion을 변수명으로 사용하는 경우가 많다\n", 152 | "\n", 153 | "\n", 154 | "# 가중치를 학습하기 위한 최적화 기법 선택\n", 155 | "optimizer = optim.Adam(model.parameters(), lr=0.01)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 10, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "# 5. 학습 및 추론 설정\n", 165 | "# 5-1. 학습 중 1에포크에서 수행할 일을 함수로 정의\n", 166 | "# 파이토치에는 Chainer의 training.Trainer()에 해당하는 것이 없음\n", 167 | "\n", 168 | "\n", 169 | "def train(epoch):\n", 170 | " model.train() # 신경망을 학습 모드로 전환\n", 171 | "\n", 172 | " # 데이터로더에서 미니배치를 하나씩 꺼내 학습을 수행\n", 173 | " for data, targets in loader_train:\n", 174 | " \n", 175 | " optimizer.zero_grad() # 경사를 0으로 초기화\n", 176 | " outputs = model(data) # 데이터를 입력하고 출력을 계산\n", 177 | " loss = loss_fn(outputs, targets) # 출력과 훈련 데이터 정답 간의 오차를 계산\n", 178 | " loss.backward() # 오차를 역전파 계산\n", 179 | " optimizer.step() # 역전파 계산한 값으로 가중치를 수정\n", 180 | "\n", 181 | " print(\"epoch{}:완료\\n\".format(epoch))\n" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 11, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "# 5. 학습 및 추론 설정\n", 191 | "# 5-2. 추론 1에포크에서 할 일을 함수로 정의\n", 192 | "# 파이토치에는 Chainer의 trainer.extend(extensions.Evaluator())에 해당하는 것이 없음\n", 193 | "\n", 194 | "\n", 195 | "def test():\n", 196 | " model.eval() # 신경망을 추론 모드로 전환\n", 197 | " correct = 0\n", 198 | "\n", 199 | " # 데이터로더에서 미니배치를 하나씩 꺼내 추론을 수행\n", 200 | " with torch.no_grad(): # 추론 과정에는 미분이 필요없음\n", 201 | " for data, targets in loader_test:\n", 202 | "\n", 203 | " outputs = model(data) # 데이터를 입력하고 출력을 계산\n", 204 | "\n", 205 | " # 추론 계산\n", 206 | " _, predicted = torch.max(outputs.data, 1) # 확률이 가장 높은 레이블이 무엇인지 계산\n", 207 | " correct += predicted.eq(targets.data.view_as(predicted)).sum() # 정답과 일치한 경우 정답 카운트를 증가\n", 208 | "\n", 209 | " # 정확도 출력\n", 210 | " data_num = len(loader_test.dataset) # 데이터 총 건수\n", 211 | " print('\\n테스트 데이터에서 예측 정확도: {}/{} ({:.0f}%)\\n'.format(correct,\n", 212 | " data_num, 100. * correct / data_num))\n" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 12, 218 | "metadata": {}, 219 | "outputs": [ 220 | { 221 | "name": "stdout", 222 | "output_type": "stream", 223 | "text": [ 224 | "\n", 225 | "테스트 데이터에서 예측 정확도: 982/10000 (9%)\n", 226 | "\n" 227 | ] 228 | } 229 | ], 230 | "source": [ 231 | "# 학습 전 상태에서 테스트 데이터로 정확도 측정\n", 232 | "test()\n" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 15, 238 | "metadata": {}, 239 | "outputs": [ 240 | { 241 | "name": "stdout", 242 | "output_type": "stream", 243 | "text": [ 244 | "epoch0:완료\n", 245 | "\n", 246 | "epoch1:완료\n", 247 | "\n", 248 | "epoch2:완료\n", 249 | "\n", 250 | "\n", 251 | "테스트 데이터에서 예측 정확도: 9618/10000 (96%)\n", 252 | "\n" 253 | ] 254 | } 255 | ], 256 | "source": [ 257 | "# 6. 학습 및 추론 수행\n", 258 | "for epoch in range(3):\n", 259 | " train(epoch)\n", 260 | "\n", 261 | "test()\n" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 16, 267 | "metadata": {}, 268 | "outputs": [ 269 | { 270 | "name": "stdout", 271 | "output_type": "stream", 272 | "text": [ 273 | "예측 결과 : 7\n", 274 | "이 이미지 데이터의 정답 레이블은 7입니다\n" 275 | ] 276 | }, 277 | { 278 | "data": { 279 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADIZJREFUeJzt3VHIXPWZx/Hvs257Y3uhFG1I36zdIssuXtglSGLL4rJY3KUQe5GkXmVhaXpRYRMVKt40N4WyGE2vCikNjdDaaNquXpTdiizYxUSMUqpttq2U7Pu+GpKWFKpXRX324j0ub+M750xm5syZN8/3A2Fmzn/mnIdDfu85M/9z/v/ITCTV82dDFyBpGIZfKsrwS0UZfqkowy8VZfilogy/VJThl4oy/FJRfz7PjUWElxNKPcvMGOd9Ux35I+KuiPhlRLwWEQ9Osy5J8xWTXtsfEdcAvwLuBFaBF4F7MvMXLZ/xyC/1bB5H/tuA1zLzN5n5R+B7wK4p1idpjqYJ/1ZgZd3r1WbZn4iI/RFxJiLOTLEtSTM2zQ9+G51avO+0PjOPAkfB035pkUxz5F8Flta9/hjwxnTlSJqXacL/InBzRHw8Ij4IfB54ejZlSerbxKf9mfl2RNwL/CdwDXAsM38+s8ok9Wrirr6JNuZ3fql3c7nIR9LmZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VJThl4oy/FJRE0/RDRAR54A3gXeAtzNz+yyKktS/qcLf+PvM/N0M1iNpjjztl4qaNvwJ/DgiXoqI/bMoSNJ8THva/6nMfCMibgCeiYj/yczn1r+h+aPgHwZpwURmzmZFEYeAtzLz4Zb3zGZjkkbKzBjnfROf9kfEtRHx4feeA58BXp10fZLma5rT/huBH0bEe+v5bmb+x0yqktS7mZ32j7WxHk/7d+zY0dr+xBNPTLX+06dPj2w7depU62dff/311vbl5eWJty1drvfTfkmbm+GXijL8UlGGXyrK8EtFGX6pqFnc1bcQtm3b1tq+tLQ01frbPr979+6p1j2tJ598cmRbVzdkV7vdjFcvj/xSUYZfKsrwS0UZfqkowy8VZfilogy/VNRVc0tvVz/+zp07W9sffnjkAESd65+2L72r9iGvI1hZWWltb7vGAODIkSMTr1uT8ZZeSa0Mv1SU4ZeKMvxSUYZfKsrwS0UZfqmoq6aff1pdQ3+39dV39XXv2bNnoprG1bb+rVu3tn626xqCrusjptG1306ePNnaPu1w7Fcr+/kltTL8UlGGXyrK8EtFGX6pKMMvFWX4paI6+/kj4hjwWeBiZt7SLLseOAHcBJwD9mTm7zs3tsD9/F3a9tPQ/fx96hpr4MCBA63tbdcRTDuXQtd4AG37/Wqej2CW/fzfBu66bNmDwLOZeTPwbPNa0ibSGf7MfA64dNniXcDx5vlx4O4Z1yWpZ5N+578xM88DNI83zK4kSfPQ+1x9EbEf2N/3diRdmUmP/BciYgtA83hx1Bsz82hmbs/M7RNuS1IPJg3/08C+5vk+4KnZlCNpXjrDHxGPA6eAv4qI1Yj4F+BrwJ0R8Wvgzua1pE3E+/nH9Pzzz49sW11dbf3sZu7nn1ZbX/401wh0rRvarwN44IEHWj+7mccK8H5+Sa0Mv1SU4ZeKMvxSUYZfKsrwS0XZ1TemabrrNnO30ZC6uvIOHz7c2t7WVdh1O/C2bdta2xeZXX2SWhl+qSjDLxVl+KWiDL9UlOGXijL8UlH282vT6rr24sSJExOve+/eva3ti3zthv38kloZfqkowy8VZfilogy/VJThl4oy/FJRvU/XJfWla+jvaSwvL/e27kXhkV8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXiurs54+IY8BngYuZeUuz7BDwBeC3zdseyswf9VWkajp48GBr+86dOyde9yOPPNLafvr06YnXvVmMc+T/NnDXBssfzcxbm38GX9pkOsOfmc8Bl+ZQi6Q5muY7/70R8bOIOBYR182sIklzMWn4vwF8ArgVOA+MnDQtIvZHxJmIODPhtiT1YKLwZ+aFzHwnM98Fvgnc1vLeo5m5PTO3T1qkpNmbKPwRsWXdy88Br86mHEnzMk5X3+PAHcBHImIV+ApwR0TcCiRwDvhijzVK6oHj9mswO3bsaG3vGht/aWmptf3UqVMj226//fbWz25mjtsvqZXhl4oy/FJRhl8qyvBLRRl+qSiH7lav2rrjum6r7erKW1lZaW0/cuRIa3t1Hvmlogy/VJThl4oy/FJRhl8qyvBLRRl+qShv6VWv2m7L3b1791Tr3rt378Tbvpp5S6+kVoZfKsrwS0UZfqkowy8VZfilogy/VJT382sqXX3p0/Tld93vX7Uff1Y88ktFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUZ39/BGxBDwGfBR4FziamV+PiOuBE8BNwDlgT2b+vr9S1YeusfEPHz7c2t5nP/79998/8brVbZwj/9vA/Zn518AO4EsR8TfAg8CzmXkz8GzzWtIm0Rn+zDyfmS83z98EzgJbgV3A8eZtx4G7+ypS0uxd0Xf+iLgJ+CTwAnBjZp6HtT8QwA2zLk5Sf8a+tj8iPgR8HziQmX+IGGuYMCJiP7B/svIk9WWsI39EfIC14H8nM3/QLL4QEVua9i3AxY0+m5lHM3N7Zm6fRcGSZqMz/LF2iP8WcDYz1/88+zSwr3m+D3hq9uVJ6kvn0N0R8WngJ8ArrHX1ATzE2vf+J4BtwDKwOzMvdazLobsXTJ+35ALcd999I9seffTRqdatjY07dHfnd/7M/G9g1Mr+4UqKkrQ4vMJPKsrwS0UZfqkowy8VZfilogy/VJRDd1/lDh482NreZz8+2Je/yDzyS0UZfqkowy8VZfilogy/VJThl4oy/FJRnffzz3Rj3s/fix07doxsO3Xq1FTrdnjtzWfc+/k98ktFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUd7PfxXouqd+Gl1TeC8vL7e279mzZ2Tb6dOnJ6pJs+GRXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeK6uznj4gl4DHgo8C7wNHM/HpEHAK+APy2eetDmfmjvgrVaCdPnhzZ1jUu/8rKylTbbuvHB/vyF9k4F/m8DdyfmS9HxIeBlyLimabt0cx8uL/yJPWlM/yZeR443zx/MyLOAlv7LkxSv67oO39E3AR8EnihWXRvRPwsIo5FxHUjPrM/Is5ExJmpKpU0U2OHPyI+BHwfOJCZfwC+AXwCuJW1M4PDG30uM49m5vbM3D6DeiXNyFjhj4gPsBb872TmDwAy80JmvpOZ7wLfBG7rr0xJs9YZ/ogI4FvA2cx8ZN3yLeve9jng1dmXJ6kvnUN3R8SngZ8Ar7DW1QfwEHAPa6f8CZwDvtj8ONi2Lofulno27tDdjtsvXWUct19SK8MvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VJThl4oy/FJR856i+3fA/657/ZFm2SJa1NoWtS6wtknNsra/GPeNc72f/30bjzizqGP7LWpti1oXWNukhqrN036pKMMvFTV0+I8OvP02i1rbotYF1japQWob9Du/pOEMfeSXNJBBwh8Rd0XELyPitYh4cIgaRomIcxHxSkT8dOgpxppp0C5GxKvrll0fEc9ExK+bxw2nSRuotkMR8Xqz734aEf80UG1LEfFfEXE2In4eEf/aLB9037XUNch+m/tpf0RcA/wKuBNYBV4E7snMX8y1kBEi4hywPTMH7xOOiL8D3gIey8xbmmX/BlzKzK81fzivy8wvL0hth4C3hp65uZlQZsv6maWBu4F/ZsB911LXHgbYb0Mc+W8DXsvM32TmH4HvAbsGqGPhZeZzwKXLFu8CjjfPj7P2n2fuRtS2EDLzfGa+3Dx/E3hvZulB911LXYMYIvxbgZV1r1dZrCm/E/hxRLwUEfuHLmYDN743M1LzeMPA9Vyuc+bmebpsZumF2XeTzHg9a0OEf6PZRBapy+FTmfm3wD8CX2pObzWesWZunpcNZpZeCJPOeD1rQ4R/FVha9/pjwBsD1LGhzHyjebwI/JDFm334wnuTpDaPFweu5/8t0szNG80szQLsu0Wa8XqI8L8I3BwRH4+IDwKfB54eoI73iYhrmx9iiIhrgc+weLMPPw3sa57vA54asJY/sSgzN4+aWZqB992izXg9yEU+TVfGEeAa4FhmfnXuRWwgIv6StaM9rN3x+N0ha4uIx4E7WLvr6wLwFeDfgSeAbcAysDsz5/7D24ja7uAKZ27uqbZRM0u/wID7bpYzXs+kHq/wk2ryCj+pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0X9H2LFAu6LVvxRAAAAAElFTkSuQmCC\n", 280 | "text/plain": [ 281 | "
" 282 | ] 283 | }, 284 | "metadata": { 285 | "needs_background": "light" 286 | }, 287 | "output_type": "display_data" 288 | } 289 | ], 290 | "source": [ 291 | "# 2018번째 데이터를 예로 추론 수행\n", 292 | "\n", 293 | "index = 2018\n", 294 | "\n", 295 | "model.eval() # 신경망을 추론 모드로 전환\n", 296 | "data = X_test[index]\n", 297 | "output = model(data) # 데이터를 입력하고 출력을 계산\n", 298 | "_, predicted = torch.max(output.data, 0) # 확률이 가장 높은 레이블이 무엇인지 계산\n", 299 | "\n", 300 | "print(\"예측 결과 : {}\".format(predicted))\n", 301 | "\n", 302 | "X_test_show = (X_test[index]).numpy()\n", 303 | "plt.imshow(X_test_show.reshape(28, 28), cmap='gray')\n", 304 | "print(\"이 이미지 데이터의 정답 레이블은 {:.0f}입니다\".format(y_test[index]))\n" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 17, 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "#-----------------------------------------------" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 18, 319 | "metadata": {}, 320 | "outputs": [ 321 | { 322 | "name": "stdout", 323 | "output_type": "stream", 324 | "text": [ 325 | "Net(\n", 326 | " (fc1): Linear(in_features=784, out_features=100, bias=True)\n", 327 | " (fc2): Linear(in_features=100, out_features=100, bias=True)\n", 328 | " (fc3): Linear(in_features=100, out_features=10, bias=True)\n", 329 | ")\n" 330 | ] 331 | } 332 | ], 333 | "source": [ 334 | "# 3. 신경망 구성\n", 335 | "# 신경망 구성 (Chainer 스타일)\n", 336 | "import torch.nn as nn\n", 337 | "import torch.nn.functional as F\n", 338 | "\n", 339 | "\n", 340 | "class Net(nn.Module):\n", 341 | "\n", 342 | " def __init__(self, n_in, n_mid, n_out):\n", 343 | " super(Net, self).__init__()\n", 344 | " self.fc1 = nn.Linear(n_in, n_mid) # Chainer와 달리、None을 받을 수는 없다\n", 345 | " self.fc2 = nn.Linear(n_mid, n_mid)\n", 346 | " self.fc3 = nn.Linear(n_mid, n_out)\n", 347 | "\n", 348 | " def forward(self, x):\n", 349 | " # 입력 x에 따라 forward 계산 과정이 변화함\n", 350 | " h1 = F.relu(self.fc1(x))\n", 351 | " h2 = F.relu(self.fc2(h1))\n", 352 | " output = self.fc3(h2)\n", 353 | " return output\n", 354 | "\n", 355 | "\n", 356 | "model = Net(n_in=28*28*1, n_mid=100, n_out=10) # 신경망 객체를 생성\n", 357 | "print(model)\n" 358 | ] 359 | } 360 | ], 361 | "metadata": { 362 | "kernelspec": { 363 | "display_name": "Python 3", 364 | "language": "python", 365 | "name": "python3" 366 | }, 367 | "language_info": { 368 | "codemirror_mode": { 369 | "name": "ipython", 370 | "version": 3 371 | }, 372 | "file_extension": ".py", 373 | "mimetype": "text/x-python", 374 | "name": "python", 375 | "nbconvert_exporter": "python", 376 | "pygments_lexer": "ipython3", 377 | "version": "3.6.6" 378 | } 379 | }, 380 | "nbformat": 4, 381 | "nbformat_minor": 2 382 | } 383 | -------------------------------------------------------------------------------- /program/6_2_DDQN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 6.2 DDQN 구현" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# 구현에 사용할 패키지 임포트\n", 17 | "import numpy as np\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "%matplotlib inline\n", 20 | "import gym\n" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "# 애니메이션을 만드는 함수\n", 30 | "# 참고 URL http://nbviewer.jupyter.org/github/patrickmineault\n", 31 | "# /xcorr-notebooks/blob/master/Render%20OpenAI%20gym%20as%20GIF.ipynb\n", 32 | "from JSAnimation.IPython_display import display_animation\n", 33 | "from matplotlib import animation\n", 34 | "from IPython.display import display\n", 35 | "\n", 36 | "\n", 37 | "def display_frames_as_gif(frames):\n", 38 | " \"\"\"\n", 39 | " Displays a list of frames as a gif, with controls\n", 40 | " \"\"\"\n", 41 | " plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0),\n", 42 | " dpi=72)\n", 43 | " patch = plt.imshow(frames[0])\n", 44 | " plt.axis('off')\n", 45 | "\n", 46 | " def animate(i):\n", 47 | " patch.set_data(frames[i])\n", 48 | "\n", 49 | " anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames),\n", 50 | " interval=50)\n", 51 | "\n", 52 | " anim.save('movie_cartpole_DDQN.mp4') # 애니메이션을 저장하는 부분\n", 53 | " display(display_animation(anim, default_mode='loop'))\n" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "# namedtuple 생성\n", 63 | "from collections import namedtuple\n", 64 | "\n", 65 | "Transition = namedtuple(\n", 66 | " 'Transition', ('state', 'action', 'next_state', 'reward'))\n" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "# 상수 정의\n", 76 | "ENV = 'CartPole-v0' # 태스크 이름\n", 77 | "GAMMA = 0.99 # 시간할인율\n", 78 | "MAX_STEPS = 200 # 1에피소드 당 최대 단계 수\n", 79 | "NUM_EPISODES = 500 # 최대 에피소드 수\n" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 5, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "# transition을 저장하기 위한 메모리 클래스\n", 89 | "\n", 90 | "\n", 91 | "class ReplayMemory:\n", 92 | "\n", 93 | " def __init__(self, CAPACITY):\n", 94 | " self.capacity = CAPACITY # 메모리의 최대 저장 건수\n", 95 | " self.memory = [] # 실제 transition을 저장할 변수\n", 96 | " self.index = 0 # 저장 위치를 가리킬 인덱스 변수\n", 97 | "\n", 98 | " def push(self, state, action, state_next, reward):\n", 99 | " '''transition = (state, action, state_next, reward)을 메모리에 저장'''\n", 100 | "\n", 101 | " if len(self.memory) < self.capacity:\n", 102 | " self.memory.append(None) # 메모리가 가득차지 않은 경우\n", 103 | "\n", 104 | " # Transition이라는 namedtuple을 사용하여 키-값 쌍의 형태로 값을 저장\n", 105 | " self.memory[self.index] = Transition(state, action, state_next, reward)\n", 106 | "\n", 107 | " self.index = (self.index + 1) % self.capacity # 다음 저장할 위치를 한 자리 뒤로 수정\n", 108 | "\n", 109 | " def sample(self, batch_size):\n", 110 | " '''batch_size 갯수 만큼 무작위로 저장된 transition을 추출'''\n", 111 | " return random.sample(self.memory, batch_size)\n", 112 | "\n", 113 | " def __len__(self):\n", 114 | " '''len 함수로 현재 저장된 transition 갯수를 반환'''\n", 115 | " return len(self.memory)\n" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 6, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "# 신경망 구성\n", 125 | "import torch.nn as nn\n", 126 | "import torch.nn.functional as F\n", 127 | "\n", 128 | "\n", 129 | "class Net(nn.Module):\n", 130 | "\n", 131 | " def __init__(self, n_in, n_mid, n_out):\n", 132 | " super(Net, self).__init__()\n", 133 | " self.fc1 = nn.Linear(n_in, n_mid)\n", 134 | " self.fc2 = nn.Linear(n_mid, n_mid)\n", 135 | " self.fc3 = nn.Linear(n_mid, n_out)\n", 136 | "\n", 137 | " def forward(self, x):\n", 138 | " h1 = F.relu(self.fc1(x))\n", 139 | " h2 = F.relu(self.fc2(h1))\n", 140 | " output = self.fc3(h2)\n", 141 | " return output\n" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 7, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "# 에이전트의 두뇌 역할을 하는 클래스, DDQN을 실제 수행한다 \n", 151 | "\n", 152 | "import random\n", 153 | "import torch\n", 154 | "from torch import nn\n", 155 | "from torch import optim\n", 156 | "import torch.nn.functional as F\n", 157 | "\n", 158 | "BATCH_SIZE = 32\n", 159 | "CAPACITY = 10000\n", 160 | "\n", 161 | "\n", 162 | "class Brain:\n", 163 | " def __init__(self, num_states, num_actions):\n", 164 | " self.num_actions = num_actions # CartPoleの行動(右に左に押す)の2を取得\n", 165 | "\n", 166 | " # transition을 기억하기 위한 메모리 객체 생성\n", 167 | " self.memory = ReplayMemory(CAPACITY)\n", 168 | "\n", 169 | " # 신경망 구성\n", 170 | " n_in, n_mid, n_out = num_states, 32, num_actions\n", 171 | " self.main_q_network = Net(n_in, n_mid, n_out) # Net 클래스를 사용\n", 172 | " self.target_q_network = Net(n_in, n_mid, n_out) # Net 클래스를 사용\n", 173 | " print(self.main_q_network) # 신경망의 구조를 출력\n", 174 | "\n", 175 | " # 최적화 기법 선택\n", 176 | " self.optimizer = optim.Adam(\n", 177 | " self.main_q_network.parameters(), lr=0.0001)\n", 178 | "\n", 179 | " def replay(self):\n", 180 | " '''Experience Replay로 신경망의 결합 가중치 학습'''\n", 181 | "\n", 182 | " # 1. 저장된 transition의 수를 확인\n", 183 | " if len(self.memory) < BATCH_SIZE:\n", 184 | " return\n", 185 | "\n", 186 | " # 2. 미니배치 생성\n", 187 | " self.batch, self.state_batch, self.action_batch, self.reward_batch, self.non_final_next_states = self.make_minibatch()\n", 188 | "\n", 189 | " # 3. 정답신호로 사용할 Q(s_t, a_t)를 계산\n", 190 | " self.expected_state_action_values = self.get_expected_state_action_values()\n", 191 | "\n", 192 | " # 4. 결합 가중치 수정\n", 193 | " self.update_main_q_network()\n", 194 | "\n", 195 | " def decide_action(self, state, episode):\n", 196 | " '''현재 상태로부터 행동을 결정함'''\n", 197 | " # ε-greedy 알고리즘에서 서서히 최적행동의 비중을 늘린다\n", 198 | " epsilon = 0.5 * (1 / (episode + 1))\n", 199 | "\n", 200 | " if epsilon <= np.random.uniform(0, 1):\n", 201 | " self.main_q_network.eval() # 신경망을 추론 모드로 전환\n", 202 | " with torch.no_grad():\n", 203 | " action = self.main_q_network(state).max(1)[1].view(1, 1)\n", 204 | " # 신경망 출력의 최댓값에 대한 인덱스 = max(1)[1]\n", 205 | " # .view(1,1)은 [torch.LongTensor of size 1] 을 size 1*1로 변환하는 역할을 한다\n", 206 | "\n", 207 | " else:\n", 208 | " # 행동을 무작위로 반환(0 혹은 1)\n", 209 | " action = torch.LongTensor(\n", 210 | " [[random.randrange(self.num_actions)]]) # 행동을 무작위로 반환(0 혹은 1)\n", 211 | " # action은 [torch.LongTensor of size 1*1] 형태가 된다\n", 212 | "\n", 213 | " return action\n", 214 | "\n", 215 | " def make_minibatch(self):\n", 216 | " '''2. 미니배치 생성'''\n", 217 | "\n", 218 | " # 2.1 메모리 객체에서 미니배치를 추출\n", 219 | " transitions = self.memory.sample(BATCH_SIZE)\n", 220 | "\n", 221 | " # 2.2 각 변수를 미니배치에 맞는 형태로 변형\n", 222 | " # transitions는 각 단계 별로 (state, action, state_next, reward) 형태로 BATCH_SIZE 갯수만큼 저장됨\n", 223 | " # 다시 말해, (state, action, state_next, reward) * BATCH_SIZE 형태가 된다\n", 224 | " # 이것을 미니배치로 만들기 위해\n", 225 | " # (state*BATCH_SIZE, action*BATCH_SIZE, state_next*BATCH_SIZE, reward*BATCH_SIZE) 형태로 변환한다\n", 226 | " batch = Transition(*zip(*transitions))\n", 227 | "\n", 228 | " # 2.3 각 변수의 요소를 미니배치에 맞게 변형하고, 신경망으로 다룰 수 있도록 Variable로 만든다\n", 229 | " # state를 예로 들면, [torch.FloatTensor of size 1*4] 형태의 요소가 BATCH_SIZE 갯수만큼 있는 형태이다\n", 230 | " # 이를 torch.FloatTensor of size BATCH_SIZE*4 형태로 변형한다\n", 231 | " # 상태, 행동, 보상, non_final 상태로 된 미니배치를 나타내는 Variable을 생성\n", 232 | " # cat은 Concatenates(연접)을 의미한다\n", 233 | " state_batch = torch.cat(batch.state)\n", 234 | " action_batch = torch.cat(batch.action)\n", 235 | " reward_batch = torch.cat(batch.reward)\n", 236 | " non_final_next_states = torch.cat([s for s in batch.next_state\n", 237 | " if s is not None])\n", 238 | "\n", 239 | " return batch, state_batch, action_batch, reward_batch, non_final_next_states\n", 240 | "\n", 241 | " def get_expected_state_action_values(self):\n", 242 | " '''정답신호로 사용할 Q(s_t, a_t)를 계산'''\n", 243 | "\n", 244 | " # 3.1 신경망을 추론 모드로 전환\n", 245 | " self.main_q_network.eval()\n", 246 | " self.target_q_network.eval()\n", 247 | "\n", 248 | " # 3.2 신경망으로 Q(s_t, a_t)를 계산\n", 249 | " # self.model(state_batch)은 왼쪽, 오른쪽에 대한 Q값을 출력하며\n", 250 | " # [torch.FloatTensor of size BATCH_SIZEx2] 형태이다\n", 251 | " # 여기서부터는 실행한 행동 a_t에 대한 Q값을 계산하므로 action_batch에서 취한 행동 a_t가 \n", 252 | " # 왼쪽이냐 오른쪽이냐에 대한 인덱스를 구하고, 이에 대한 Q값을 gather 메서드로 모아온다\n", 253 | " self.state_action_values = self.main_q_network(\n", 254 | " self.state_batch).gather(1, self.action_batch)\n", 255 | "\n", 256 | " # 3.3 max{Q(s_t+1, a)}값을 계산한다 이때 다음 상태가 존재하는지에 주의해야 한다\n", 257 | "\n", 258 | " # cartpole이 done 상태가 아니고, next_state가 존재하는지 확인하는 인덱스 마스크를 만듬\n", 259 | " non_final_mask = torch.ByteTensor(tuple(map(lambda s: s is not None,\n", 260 | " self.batch.next_state)))\n", 261 | " # 먼저 전체를 0으로 초기화\n", 262 | " next_state_values = torch.zeros(BATCH_SIZE)\n", 263 | "\n", 264 | " a_m = torch.zeros(BATCH_SIZE).type(torch.LongTensor)\n", 265 | "\n", 266 | " # 다음 상태에서 Q값이 최대가 되는 행동 a_m을 Main Q-Network로 계산\n", 267 | " # 마지막에 붙은 [1]로 행동에 해당하는 인덱스를 구함\n", 268 | " a_m[non_final_mask] = self.main_q_network(\n", 269 | " self.non_final_next_states).detach().max(1)[1]\n", 270 | "\n", 271 | " # 다음 상태가 있는 것만을 걸러내고, size 32를 32*1로 변환\n", 272 | " a_m_non_final_next_states = a_m[non_final_mask].view(-1, 1)\n", 273 | "\n", 274 | " # 다음 상태가 있는 인덱스에 대해 행동 a_m의 Q값을 target Q-Network로 계산\n", 275 | " # detach() 메서드로 값을 꺼내옴\n", 276 | " # squeeze() 메서드로 size[minibatch*1]을 [minibatch]로 변환\n", 277 | " next_state_values[non_final_mask] = self.target_q_network(\n", 278 | " self.non_final_next_states).gather(1, a_m_non_final_next_states).detach().squeeze()\n", 279 | "\n", 280 | " # 3.4 정답신호로 사용할 Q(s_t, a_t)값을 Q러닝 식으로 계산한다\n", 281 | " expected_state_action_values = self.reward_batch + GAMMA * next_state_values\n", 282 | "\n", 283 | " return expected_state_action_values\n", 284 | "\n", 285 | " def update_main_q_network(self):\n", 286 | " '''4. 결합 가중치 수정'''\n", 287 | "\n", 288 | " # 4.1 신경망을 학습 모드로 전환\n", 289 | " self.main_q_network.train()\n", 290 | "\n", 291 | " # 4.2 손실함수를 계산 (smooth_l1_loss는 Huber 함수)\n", 292 | " # expected_state_action_values은\n", 293 | " # size가 [minibatch]이므로 unsqueeze하여 [minibatch*1]로 만든다\n", 294 | " loss = F.smooth_l1_loss(self.state_action_values,\n", 295 | " self.expected_state_action_values.unsqueeze(1))\n", 296 | "\n", 297 | " # 4.3 결합 가중치를 수정한다\n", 298 | " self.optimizer.zero_grad() # 경사를 초기화\n", 299 | " loss.backward() # 역전파 계산\n", 300 | " self.optimizer.step() # 결합 가중치 수정\n", 301 | "\n", 302 | " def update_target_q_network(self): # DDQN에서 추가됨\n", 303 | " '''Target Q-Network을 Main Q-Network와 맞춤'''\n", 304 | " self.target_q_network.load_state_dict(self.main_q_network.state_dict())\n" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 8, 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "# CartPole 태스크의 에이전트 클래스. 봉 달린 수레 자체라고 보면 된다\n", 314 | "\n", 315 | "\n", 316 | "class Agent:\n", 317 | " def __init__(self, num_states, num_actions):\n", 318 | " '''태스크의 상태 및 행동의 가짓수를 설정'''\n", 319 | " self.brain = Brain(num_states, num_actions) # 에이전트의 행동을 결정할 두뇌 역할 객체를 생성\n", 320 | "\n", 321 | " def update_q_function(self):\n", 322 | " '''Q함수를 수정'''\n", 323 | " self.brain.replay()\n", 324 | "\n", 325 | " def get_action(self, state, episode):\n", 326 | " '''행동을 결정'''\n", 327 | " action = self.brain.decide_action(state, episode)\n", 328 | " return action\n", 329 | "\n", 330 | " def memorize(self, state, action, state_next, reward):\n", 331 | " '''memory 객체에 state, action, state_next, reward 내용을 저장'''\n", 332 | " self.brain.memory.push(state, action, state_next, reward)\n", 333 | "\n", 334 | " def update_target_q_function(self):\n", 335 | " '''Target Q-Network을 Main Q-Network와 맞춤'''\n", 336 | " self.brain.update_target_q_network()\n", 337 | " " 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 9, 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "# CartPole을 실행하는 환경 클래스\n", 347 | "\n", 348 | "\n", 349 | "class Environment:\n", 350 | "\n", 351 | " def __init__(self):\n", 352 | " self.env = gym.make(ENV) # 태스크를 설정\n", 353 | " num_states = self.env.observation_space.shape[0] # 태스크의 상태 변수 수(4)를 받아옴\n", 354 | " num_actions = self.env.action_space.n # 태스크의 행동 가짓수(2)를 받아옴\n", 355 | " self.agent = Agent(num_states, num_actions) # 에이전트 역할을 할 객체를 생성\n", 356 | "\n", 357 | " def run(self):\n", 358 | " '''실행'''\n", 359 | " episode_10_list = np.zeros(10) # 최근 10에피소드 동안 버틴 단계 수를 저장함(평균 단계 수를 출력할 때 사용)\n", 360 | " complete_episodes = 0 # 현재까지 195단계를 버틴 에피소드 수\n", 361 | " episode_final = False # 마지막 에피소드 여부\n", 362 | " frames = [] # 애니메이션을 만들기 위해 마지막 에피소드의 프레임을 저장할 배열\n", 363 | "\n", 364 | " for episode in range(NUM_EPISODES): # 최대 에피소드 수만큼 반복\n", 365 | " observation = self.env.reset() # 환경 초기화\n", 366 | "\n", 367 | " state = observation # 관측을 변환없이 그대로 상태 s로 사용\n", 368 | " state = torch.from_numpy(state).type(\n", 369 | " torch.FloatTensor) # NumPy 변수를 파이토치 텐서로 변환\n", 370 | " state = torch.unsqueeze(state, 0) # size 4를 size 1*4로 변환\n", 371 | "\n", 372 | " for step in range(MAX_STEPS): # 1 에피소드에 해당하는 반복문\n", 373 | " \n", 374 | " # 애니메이션 만드는 부분을 주석처리\n", 375 | " #if episode_final is True: # 마지막 에피소드에서는 각 시각의 이미지를 frames에 저장한다\n", 376 | " # frames.append(self.env.render(mode='rgb_array'))\n", 377 | " \n", 378 | " action = self.agent.get_action(state, episode) # 다음 행동을 결정\n", 379 | "\n", 380 | " # 행동 a_t를 실행하여 다음 상태 s_{t+1}과 done 플래그 값을 결정\n", 381 | " # action에 .item()을 호출하여 행동 내용을 구함\n", 382 | " observation_next, _, done, _ = self.env.step(\n", 383 | " action.item()) # reward와 info는 사용하지 않으므로 _로 처리\n", 384 | "\n", 385 | " # 보상을 부여하고 episode의 종료 판정 및 state_next를 설정한다\n", 386 | " if done: # 단계 수가 200을 넘었거나 봉이 일정 각도 이상 기울면 done이 True가 됨\n", 387 | " state_next = None # 다음 상태가 없으므로 None으로\n", 388 | "\n", 389 | " # 최근 10 에피소드에서 버틴 단계 수를 리스트에 저장\n", 390 | " episode_10_list = np.hstack(\n", 391 | " (episode_10_list[1:], step + 1))\n", 392 | "\n", 393 | " if step < 195:\n", 394 | " reward = torch.FloatTensor(\n", 395 | " [-1.0]) # 도중에 봉이 쓰러졌다면 페널티로 보상 -1을 부여\n", 396 | " complete_episodes = 0 # 연속 성공 에피소드 기록을 초기화\n", 397 | " else:\n", 398 | " reward = torch.FloatTensor([1.0]) # 봉이 서 있는 채로 에피소드를 마쳤다면 보상 1 부여\n", 399 | " complete_episodes = complete_episodes + 1 # 연속 성공 에피소드 기록을 갱신\n", 400 | " else:\n", 401 | " reward = torch.FloatTensor([0.0]) # 그 외의 경우는 보상 0을 부여\n", 402 | " state_next = observation_next # 관측 결과를 그대로 상태로 사용\n", 403 | " state_next = torch.from_numpy(state_next).type(\n", 404 | " torch.FloatTensor) # numpy 변수를 파이토치 텐서로 변환\n", 405 | " state_next = torch.unsqueeze(state_next, 0) # size 4를 size 1*4로 변환\n", 406 | "\n", 407 | " # 메모리에 경험을 저장\n", 408 | " self.agent.memorize(state, action, state_next, reward)\n", 409 | "\n", 410 | " # Experience Replay로 Q함수를 수정\n", 411 | " self.agent.update_q_function()\n", 412 | "\n", 413 | " # 관측 결과를 업데이트\n", 414 | " state = state_next\n", 415 | "\n", 416 | " # 에피소드 종료 처리\n", 417 | " if done:\n", 418 | " print('%d Episode: Finished after %d steps:최근 10 에피소드의 평균 단계 수 = %.1lf' % (\n", 419 | " episode, step + 1, episode_10_list.mean()))\n", 420 | " \n", 421 | " # DDQN으로 추가된 부분 2에피소드마다 한번씩 Target Q-Network을 Main Q-Network와 맞춰줌\n", 422 | " if(episode % 2 == 0):\n", 423 | " self.agent.update_target_q_function()\n", 424 | " break\n", 425 | " \n", 426 | " \n", 427 | " if episode_final is True:\n", 428 | " # 애니메이션 생성 부분을 주석처리함\n", 429 | " # 애니메이션 생성 및 저장\n", 430 | " #display_frames_as_gif(frames)\n", 431 | " break\n", 432 | "\n", 433 | " # 10 에피소드 연속으로 195단계를 버티면 태스크 성공\n", 434 | " if complete_episodes >= 10:\n", 435 | " print('10 에피소드 연속 성공')\n", 436 | " episode_final = True # 다음 에피소드에서 애니메이션을 생성\n" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": 10, 442 | "metadata": {}, 443 | "outputs": [ 444 | { 445 | "name": "stdout", 446 | "output_type": "stream", 447 | "text": [ 448 | "Net(\n", 449 | " (fc1): Linear(in_features=4, out_features=32, bias=True)\n", 450 | " (fc2): Linear(in_features=32, out_features=32, bias=True)\n", 451 | " (fc3): Linear(in_features=32, out_features=2, bias=True)\n", 452 | ")\n", 453 | "0 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 1.1\n", 454 | "1 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 2.1\n", 455 | "2 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 3.2\n", 456 | "3 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 4.2\n", 457 | "4 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 5.1\n", 458 | "5 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 6.0\n", 459 | "6 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 7.0\n", 460 | "7 Episode: Finished after 8 steps:최근 10 에피소드의 평균 단계 수 = 7.8\n", 461 | "8 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 8.7\n", 462 | "9 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.6\n", 463 | "10 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.5\n", 464 | "11 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 9.6\n", 465 | "12 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n", 466 | "13 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n", 467 | "14 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.5\n", 468 | "15 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 9.7\n", 469 | "16 Episode: Finished after 14 steps:최근 10 에피소드의 평균 단계 수 = 10.1\n", 470 | "17 Episode: Finished after 15 steps:최근 10 에피소드의 평균 단계 수 = 10.8\n", 471 | "18 Episode: Finished after 79 steps:최근 10 에피소드의 평균 단계 수 = 17.8\n", 472 | "19 Episode: Finished after 37 steps:최근 10 에피소드의 평균 단계 수 = 20.6\n", 473 | "20 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 20.8\n", 474 | "21 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 20.7\n", 475 | "22 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 20.7\n", 476 | "23 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 20.7\n", 477 | "24 Episode: Finished after 8 steps:최근 10 에피소드의 평균 단계 수 = 20.5\n", 478 | "25 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 20.4\n", 479 | "26 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 19.9\n", 480 | "27 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 19.4\n", 481 | "28 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 12.4\n", 482 | "29 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.6\n", 483 | "30 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n", 484 | "31 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n", 485 | "32 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.5\n", 486 | "33 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n", 487 | "34 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 9.7\n", 488 | "35 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.6\n", 489 | "36 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.7\n", 490 | "37 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.6\n", 491 | "38 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.7\n", 492 | "39 Episode: Finished after 8 steps:최근 10 에피소드의 평균 단계 수 = 9.6\n", 493 | "40 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.5\n", 494 | "41 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.5\n", 495 | "42 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n", 496 | "43 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.5\n", 497 | "44 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n", 498 | "45 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.5\n", 499 | "46 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 9.6\n", 500 | "47 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 9.8\n", 501 | "48 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.8\n", 502 | "49 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 10.2\n", 503 | "50 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 10.3\n", 504 | "51 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 10.5\n", 505 | "52 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 10.6\n", 506 | "53 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 10.7\n", 507 | "54 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 10.6\n", 508 | "55 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 10.7\n", 509 | "56 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 10.6\n", 510 | "57 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 10.7\n", 511 | "58 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 10.7\n", 512 | "59 Episode: Finished after 13 steps:최근 10 에피소드의 평균 단계 수 = 10.8\n", 513 | "60 Episode: Finished after 14 steps:최근 10 에피소드의 평균 단계 수 = 11.2\n", 514 | "61 Episode: Finished after 14 steps:최근 10 에피소드의 평균 단계 수 = 11.4\n", 515 | "62 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 11.6\n", 516 | "63 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 11.7\n", 517 | "64 Episode: Finished after 16 steps:최근 10 에피소드의 평균 단계 수 = 12.4\n", 518 | "65 Episode: Finished after 14 steps:최근 10 에피소드의 평균 단계 수 = 12.7\n", 519 | "66 Episode: Finished after 15 steps:최근 10 에피소드의 평균 단계 수 = 13.2\n", 520 | "67 Episode: Finished after 15 steps:최근 10 에피소드의 평균 단계 수 = 13.5\n", 521 | "68 Episode: Finished after 17 steps:최근 10 에피소드의 평균 단계 수 = 14.2\n", 522 | "69 Episode: Finished after 17 steps:최근 10 에피소드의 평균 단계 수 = 14.6\n", 523 | "70 Episode: Finished after 38 steps:최근 10 에피소드의 평균 단계 수 = 17.0\n", 524 | "71 Episode: Finished after 25 steps:최근 10 에피소드의 평균 단계 수 = 18.1\n", 525 | "72 Episode: Finished after 30 steps:최근 10 에피소드의 평균 단계 수 = 19.9\n", 526 | "73 Episode: Finished after 22 steps:최근 10 에피소드의 평균 단계 수 = 20.9\n", 527 | "74 Episode: Finished after 32 steps:최근 10 에피소드의 평균 단계 수 = 22.5\n", 528 | "75 Episode: Finished after 26 steps:최근 10 에피소드의 평균 단계 수 = 23.7\n", 529 | "76 Episode: Finished after 43 steps:최근 10 에피소드의 평균 단계 수 = 26.5\n", 530 | "77 Episode: Finished after 40 steps:최근 10 에피소드의 평균 단계 수 = 29.0\n", 531 | "78 Episode: Finished after 56 steps:최근 10 에피소드의 평균 단계 수 = 32.9\n", 532 | "79 Episode: Finished after 63 steps:최근 10 에피소드의 평균 단계 수 = 37.5\n", 533 | "80 Episode: Finished after 66 steps:최근 10 에피소드의 평균 단계 수 = 40.3\n", 534 | "81 Episode: Finished after 85 steps:최근 10 에피소드의 평균 단계 수 = 46.3\n", 535 | "82 Episode: Finished after 65 steps:최근 10 에피소드의 평균 단계 수 = 49.8\n", 536 | "83 Episode: Finished after 42 steps:최근 10 에피소드의 평균 단계 수 = 51.8\n", 537 | "84 Episode: Finished after 50 steps:최근 10 에피소드의 평균 단계 수 = 53.6\n", 538 | "85 Episode: Finished after 49 steps:최근 10 에피소드의 평균 단계 수 = 55.9\n", 539 | "86 Episode: Finished after 48 steps:최근 10 에피소드의 평균 단계 수 = 56.4\n", 540 | "87 Episode: Finished after 47 steps:최근 10 에피소드의 평균 단계 수 = 57.1\n", 541 | "88 Episode: Finished after 74 steps:최근 10 에피소드의 평균 단계 수 = 58.9\n", 542 | "89 Episode: Finished after 142 steps:최근 10 에피소드의 평균 단계 수 = 66.8\n", 543 | "90 Episode: Finished after 69 steps:최근 10 에피소드의 평균 단계 수 = 67.1\n", 544 | "91 Episode: Finished after 76 steps:최근 10 에피소드의 평균 단계 수 = 66.2\n", 545 | "92 Episode: Finished after 90 steps:최근 10 에피소드의 평균 단계 수 = 68.7\n", 546 | "93 Episode: Finished after 79 steps:최근 10 에피소드의 평균 단계 수 = 72.4\n", 547 | "94 Episode: Finished after 77 steps:최근 10 에피소드의 평균 단계 수 = 75.1\n", 548 | "95 Episode: Finished after 66 steps:최근 10 에피소드의 평균 단계 수 = 76.8\n", 549 | "96 Episode: Finished after 56 steps:최근 10 에피소드의 평균 단계 수 = 77.6\n", 550 | "97 Episode: Finished after 68 steps:최근 10 에피소드의 평균 단계 수 = 79.7\n", 551 | "98 Episode: Finished after 88 steps:최근 10 에피소드의 평균 단계 수 = 81.1\n", 552 | "99 Episode: Finished after 52 steps:최근 10 에피소드의 평균 단계 수 = 72.1\n", 553 | "100 Episode: Finished after 33 steps:최근 10 에피소드의 평균 단계 수 = 68.5\n", 554 | "101 Episode: Finished after 77 steps:최근 10 에피소드의 평균 단계 수 = 68.6\n", 555 | "102 Episode: Finished after 33 steps:최근 10 에피소드의 평균 단계 수 = 62.9\n", 556 | "103 Episode: Finished after 64 steps:최근 10 에피소드의 평균 단계 수 = 61.4\n", 557 | "104 Episode: Finished after 65 steps:최근 10 에피소드의 평균 단계 수 = 60.2\n", 558 | "105 Episode: Finished after 70 steps:최근 10 에피소드의 평균 단계 수 = 60.6\n", 559 | "106 Episode: Finished after 62 steps:최근 10 에피소드의 평균 단계 수 = 61.2\n", 560 | "107 Episode: Finished after 59 steps:최근 10 에피소드의 평균 단계 수 = 60.3\n", 561 | "108 Episode: Finished after 45 steps:최근 10 에피소드의 평균 단계 수 = 56.0\n", 562 | "109 Episode: Finished after 49 steps:최근 10 에피소드의 평균 단계 수 = 55.7\n", 563 | "110 Episode: Finished after 115 steps:최근 10 에피소드의 평균 단계 수 = 63.9\n", 564 | "111 Episode: Finished after 83 steps:최근 10 에피소드의 평균 단계 수 = 64.5\n", 565 | "112 Episode: Finished after 126 steps:최근 10 에피소드의 평균 단계 수 = 73.8\n", 566 | "113 Episode: Finished after 76 steps:최근 10 에피소드의 평균 단계 수 = 75.0\n", 567 | "114 Episode: Finished after 86 steps:최근 10 에피소드의 평균 단계 수 = 77.1\n", 568 | "115 Episode: Finished after 52 steps:최근 10 에피소드의 평균 단계 수 = 75.3\n", 569 | "116 Episode: Finished after 55 steps:최근 10 에피소드의 평균 단계 수 = 74.6\n", 570 | "117 Episode: Finished after 73 steps:최근 10 에피소드의 평균 단계 수 = 76.0\n", 571 | "118 Episode: Finished after 147 steps:최근 10 에피소드의 평균 단계 수 = 86.2\n", 572 | "119 Episode: Finished after 85 steps:최근 10 에피소드의 평균 단계 수 = 89.8\n", 573 | "120 Episode: Finished after 108 steps:최근 10 에피소드의 평균 단계 수 = 89.1\n", 574 | "121 Episode: Finished after 107 steps:최근 10 에피소드의 평균 단계 수 = 91.5\n", 575 | "122 Episode: Finished after 106 steps:최근 10 에피소드의 평균 단계 수 = 89.5\n", 576 | "123 Episode: Finished after 89 steps:최근 10 에피소드의 평균 단계 수 = 90.8\n", 577 | "124 Episode: Finished after 106 steps:최근 10 에피소드의 평균 단계 수 = 92.8\n", 578 | "125 Episode: Finished after 76 steps:최근 10 에피소드의 평균 단계 수 = 95.2\n", 579 | "126 Episode: Finished after 83 steps:최근 10 에피소드의 평균 단계 수 = 98.0\n", 580 | "127 Episode: Finished after 144 steps:최근 10 에피소드의 평균 단계 수 = 105.1\n", 581 | "128 Episode: Finished after 74 steps:최근 10 에피소드의 평균 단계 수 = 97.8\n", 582 | "129 Episode: Finished after 87 steps:최근 10 에피소드의 평균 단계 수 = 98.0\n" 583 | ] 584 | }, 585 | { 586 | "name": "stdout", 587 | "output_type": "stream", 588 | "text": [ 589 | "130 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 107.2\n", 590 | "131 Episode: Finished after 72 steps:최근 10 에피소드의 평균 단계 수 = 103.7\n", 591 | "132 Episode: Finished after 76 steps:최근 10 에피소드의 평균 단계 수 = 100.7\n", 592 | "133 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 111.8\n", 593 | "134 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 121.2\n", 594 | "135 Episode: Finished after 184 steps:최근 10 에피소드의 평균 단계 수 = 132.0\n", 595 | "136 Episode: Finished after 123 steps:최근 10 에피소드의 평균 단계 수 = 136.0\n", 596 | "137 Episode: Finished after 142 steps:최근 10 에피소드의 평균 단계 수 = 135.8\n", 597 | "138 Episode: Finished after 184 steps:최근 10 에피소드의 평균 단계 수 = 146.8\n", 598 | "139 Episode: Finished after 198 steps:최근 10 에피소드의 평균 단계 수 = 157.9\n", 599 | "140 Episode: Finished after 180 steps:최근 10 에피소드의 평균 단계 수 = 155.9\n", 600 | "141 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 168.7\n", 601 | "142 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 181.1\n", 602 | "143 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 181.1\n", 603 | "144 Episode: Finished after 196 steps:최근 10 에피소드의 평균 단계 수 = 180.7\n", 604 | "145 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 182.3\n", 605 | "146 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 190.0\n", 606 | "147 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 195.8\n", 607 | "148 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 197.4\n", 608 | "149 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 197.6\n", 609 | "150 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 199.6\n", 610 | "10 에피소드 연속 성공\n", 611 | "151 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 199.6\n" 612 | ] 613 | } 614 | ], 615 | "source": [ 616 | "# 실행 엔트리 포인트\n", 617 | "cartpole_env = Environment()\n", 618 | "cartpole_env.run()\n" 619 | ] 620 | } 621 | ], 622 | "metadata": { 623 | "kernelspec": { 624 | "display_name": "Python 3", 625 | "language": "python", 626 | "name": "python3" 627 | }, 628 | "language_info": { 629 | "codemirror_mode": { 630 | "name": "ipython", 631 | "version": 3 632 | }, 633 | "file_extension": ".py", 634 | "mimetype": "text/x-python", 635 | "name": "python", 636 | "nbconvert_exporter": "python", 637 | "pygments_lexer": "ipython3", 638 | "version": "3.6.6" 639 | } 640 | }, 641 | "nbformat": 4, 642 | "nbformat_minor": 2 643 | } 644 | -------------------------------------------------------------------------------- /program/6_3_DuelingNetwork.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 6.3 Dueling Network 구현" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# 구현에 사용할 패키지 임포트\n", 17 | "import numpy as np\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "%matplotlib inline\n", 20 | "import gym\n" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "# 애니메이션을 만드는 함수\n", 30 | "# 참고 URL http://nbviewer.jupyter.org/github/patrickmineault\n", 31 | "# /xcorr-notebooks/blob/master/Render%20OpenAI%20gym%20as%20GIF.ipynb\n", 32 | "from JSAnimation.IPython_display import display_animation\n", 33 | "from matplotlib import animation\n", 34 | "from IPython.display import display\n", 35 | "\n", 36 | "\n", 37 | "def display_frames_as_gif(frames):\n", 38 | " \"\"\"\n", 39 | " Displays a list of frames as a gif, with controls\n", 40 | " \"\"\"\n", 41 | " plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0),\n", 42 | " dpi=72)\n", 43 | " patch = plt.imshow(frames[0])\n", 44 | " plt.axis('off')\n", 45 | "\n", 46 | " def animate(i):\n", 47 | " patch.set_data(frames[i])\n", 48 | "\n", 49 | " anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames),\n", 50 | " interval=50)\n", 51 | "\n", 52 | " anim.save('movie_cartpole_dueling_network.mp4') # 애니메이션을 저장하는 부분\n", 53 | " display(display_animation(anim, default_mode='loop'))\n" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "# namedtuple 생성\n", 63 | "from collections import namedtuple\n", 64 | "\n", 65 | "Transition = namedtuple(\n", 66 | " 'Transition', ('state', 'action', 'next_state', 'reward'))\n" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "# 상수 정의\n", 76 | "ENV = 'CartPole-v0' # 태스크 이름\n", 77 | "GAMMA = 0.99 # 시간할인율\n", 78 | "MAX_STEPS = 200 # 1에피소드 당 최대 단계 수\n", 79 | "NUM_EPISODES = 500 # 최대 에피소드 수\n" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 5, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "# transition을 저장하기 위한 메모리 클래스\n", 89 | "\n", 90 | "\n", 91 | "class ReplayMemory:\n", 92 | "\n", 93 | " def __init__(self, CAPACITY):\n", 94 | " self.capacity = CAPACITY # 메모리의 최대 저장 건수\n", 95 | " self.memory = [] # 실제 transition을 저장할 변수\n", 96 | " self.index = 0 # 저장 위치를 가리킬 인덱스 변수\n", 97 | "\n", 98 | " def push(self, state, action, state_next, reward):\n", 99 | " '''transition = (state, action, state_next, reward)을 메모리에 저장'''\n", 100 | "\n", 101 | " if len(self.memory) < self.capacity:\n", 102 | " self.memory.append(None) # 메모리가 가득차지 않은 경우\n", 103 | "\n", 104 | " # Transition이라는 namedtuple을 사용하여 키-값 쌍의 형태로 값을 저장\n", 105 | " self.memory[self.index] = Transition(state, action, state_next, reward)\n", 106 | "\n", 107 | " self.index = (self.index + 1) % self.capacity # 다음 저장할 위치를 한 자리 뒤로 수정\n", 108 | "\n", 109 | " def sample(self, batch_size):\n", 110 | " '''batch_size 갯수 만큼 무작위로 저장된 transition을 추출'''\n", 111 | " return random.sample(self.memory, batch_size)\n", 112 | "\n", 113 | " def __len__(self):\n", 114 | " '''len 함수로 현재 저장된 transition 갯수를 반환'''\n", 115 | " return len(self.memory)\n" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 6, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "# Dueling Network 신경망 구성\n", 125 | "import torch.nn as nn\n", 126 | "import torch.nn.functional as F\n", 127 | "\n", 128 | "\n", 129 | "class Net(nn.Module):\n", 130 | "\n", 131 | " def __init__(self, n_in, n_mid, n_out):\n", 132 | " super(Net, self).__init__()\n", 133 | " self.fc1 = nn.Linear(n_in, n_mid)\n", 134 | " self.fc2 = nn.Linear(n_mid, n_mid)\n", 135 | " # Dueling Network\n", 136 | " self.fc3_adv = nn.Linear(n_mid, n_out) # Advantage함수쪽 신경망\n", 137 | " self.fc3_v = nn.Linear(n_mid, 1) # 가치 V쪽 신경망\n", 138 | "\n", 139 | " def forward(self, x):\n", 140 | " h1 = F.relu(self.fc1(x))\n", 141 | " h2 = F.relu(self.fc2(h1))\n", 142 | "\n", 143 | " adv = self.fc3_adv(h2) # 이 출력은 ReLU를 거치지 않음\n", 144 | " val = self.fc3_v(h2).expand(-1, adv.size(1)) # 이 출력은 ReLU를 거치지 않음\n", 145 | " # val은 adv와 덧셈을 하기 위해 expand 메서드로 크기를 [minibatch*1]에서 [minibatch*2]로 변환\n", 146 | " # adv.size(1)은 2(출력할 행동의 가짓수)\n", 147 | "\n", 148 | " output = val + adv - adv.mean(1, keepdim=True).expand(-1, adv.size(1))\n", 149 | " # val+adv에서 adv의 평균을 뺀다\n", 150 | " # adv.mean(1, keepdim=True) 으로 열방향(행동의 종류 방향) 평균을 구함 크기는 [minibatch*1]이 됨\n", 151 | " # expand 메서드로 크기를 [minibatch*2]로 늘림\n", 152 | "\n", 153 | " return output\n" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 7, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "# 에이전트의 두뇌 역할을 하는 클래스, DDQN을 실제 수행한다 \n", 163 | "\n", 164 | "import random\n", 165 | "import torch\n", 166 | "from torch import nn\n", 167 | "from torch import optim\n", 168 | "import torch.nn.functional as F\n", 169 | "\n", 170 | "BATCH_SIZE = 32\n", 171 | "CAPACITY = 10000\n", 172 | "\n", 173 | "\n", 174 | "class Brain:\n", 175 | " def __init__(self, num_states, num_actions):\n", 176 | " self.num_actions = num_actions # CartPoleの行動(右に左に押す)の2を取得\n", 177 | "\n", 178 | " # transition을 기억하기 위한 메모리 객체 생성\n", 179 | " self.memory = ReplayMemory(CAPACITY)\n", 180 | "\n", 181 | " # 신경망 구성\n", 182 | " n_in, n_mid, n_out = num_states, 32, num_actions\n", 183 | " self.main_q_network = Net(n_in, n_mid, n_out) # Net 클래스를 사용\n", 184 | " self.target_q_network = Net(n_in, n_mid, n_out) # Net 클래스를 사용\n", 185 | " print(self.main_q_network) # 신경망의 구조를 출력\n", 186 | "\n", 187 | " # 최적화 기법 선택\n", 188 | " self.optimizer = optim.Adam(\n", 189 | " self.main_q_network.parameters(), lr=0.0001)\n", 190 | "\n", 191 | " def replay(self):\n", 192 | " '''Experience Replay로 신경망의 결합 가중치 학습'''\n", 193 | "\n", 194 | " # 1. 저장된 transition의 수를 확인\n", 195 | " if len(self.memory) < BATCH_SIZE:\n", 196 | " return\n", 197 | "\n", 198 | " # 2. 미니배치 생성\n", 199 | " self.batch, self.state_batch, self.action_batch, self.reward_batch, self.non_final_next_states = self.make_minibatch()\n", 200 | "\n", 201 | " # 3. 정답신호로 사용할 Q(s_t, a_t)를 계산\n", 202 | " self.expected_state_action_values = self.get_expected_state_action_values()\n", 203 | "\n", 204 | " # 4. 결합 가중치 수정\n", 205 | " self.update_main_q_network()\n", 206 | "\n", 207 | " def decide_action(self, state, episode):\n", 208 | " '''현재 상태로부터 행동을 결정함'''\n", 209 | " # ε-greedy 알고리즘에서 서서히 최적행동의 비중을 늘린다\n", 210 | " epsilon = 0.5 * (1 / (episode + 1))\n", 211 | "\n", 212 | " if epsilon <= np.random.uniform(0, 1):\n", 213 | " self.main_q_network.eval() # 신경망을 추론 모드로 전환\n", 214 | " with torch.no_grad():\n", 215 | " action = self.main_q_network(state).max(1)[1].view(1, 1)\n", 216 | " # 신경망 출력의 최댓값에 대한 인덱스 = max(1)[1]\n", 217 | " # .view(1,1)은 [torch.LongTensor of size 1] 을 size 1*1로 변환하는 역할을 한다\n", 218 | "\n", 219 | " else:\n", 220 | " # 행동을 무작위로 반환(0 혹은 1)\n", 221 | " action = torch.LongTensor(\n", 222 | " [[random.randrange(self.num_actions)]]) # 행동을 무작위로 반환(0 혹은 1)\n", 223 | " # action은 [torch.LongTensor of size 1*1] 형태가 된다\n", 224 | "\n", 225 | " return action\n", 226 | "\n", 227 | " def make_minibatch(self):\n", 228 | " '''2. 미니배치 생성'''\n", 229 | "\n", 230 | " # 2.1 메모리 객체에서 미니배치를 추출\n", 231 | " transitions = self.memory.sample(BATCH_SIZE)\n", 232 | "\n", 233 | " # 2.2 각 변수를 미니배치에 맞는 형태로 변형\n", 234 | " # transitions는 각 단계 별로 (state, action, state_next, reward) 형태로 BATCH_SIZE 갯수만큼 저장됨\n", 235 | " # 다시 말해, (state, action, state_next, reward) * BATCH_SIZE 형태가 된다\n", 236 | " # 이것을 미니배치로 만들기 위해\n", 237 | " # (state*BATCH_SIZE, action*BATCH_SIZE, state_next*BATCH_SIZE, reward*BATCH_SIZE) 형태로 변환한다\n", 238 | " batch = Transition(*zip(*transitions))\n", 239 | "\n", 240 | " # 2.3 각 변수의 요소를 미니배치에 맞게 변형하고, 신경망으로 다룰 수 있도록 Variable로 만든다\n", 241 | " # state를 예로 들면, [torch.FloatTensor of size 1*4] 형태의 요소가 BATCH_SIZE 갯수만큼 있는 형태이다\n", 242 | " # 이를 torch.FloatTensor of size BATCH_SIZE*4 형태로 변형한다\n", 243 | " # 상태, 행동, 보상, non_final 상태로 된 미니배치를 나타내는 Variable을 생성\n", 244 | " # cat은 Concatenates(연접)을 의미한다\n", 245 | " state_batch = torch.cat(batch.state)\n", 246 | " action_batch = torch.cat(batch.action)\n", 247 | " reward_batch = torch.cat(batch.reward)\n", 248 | " non_final_next_states = torch.cat([s for s in batch.next_state\n", 249 | " if s is not None])\n", 250 | "\n", 251 | " return batch, state_batch, action_batch, reward_batch, non_final_next_states\n", 252 | "\n", 253 | " def get_expected_state_action_values(self):\n", 254 | " '''정답신호로 사용할 Q(s_t, a_t)를 계산'''\n", 255 | "\n", 256 | " # 3.1 신경망을 추론 모드로 전환\n", 257 | " self.main_q_network.eval()\n", 258 | " self.target_q_network.eval()\n", 259 | "\n", 260 | " # 3.2 신경망으로 Q(s_t, a_t)를 계산\n", 261 | " # self.model(state_batch)은 왼쪽, 오른쪽에 대한 Q값을 출력하며\n", 262 | " # [torch.FloatTensor of size BATCH_SIZEx2] 형태이다\n", 263 | " # 여기서부터는 실행한 행동 a_t에 대한 Q값을 계산하므로 action_batch에서 취한 행동 a_t가 \n", 264 | " # 왼쪽이냐 오른쪽이냐에 대한 인덱스를 구하고, 이에 대한 Q값을 gather 메서드로 모아온다\n", 265 | " self.state_action_values = self.main_q_network(\n", 266 | " self.state_batch).gather(1, self.action_batch)\n", 267 | "\n", 268 | " # 3.3 max{Q(s_t+1, a)}값을 계산한다 이때 다음 상태가 존재하는지에 주의해야 한다\n", 269 | "\n", 270 | " # cartpole이 done 상태가 아니고, next_state가 존재하는지 확인하는 인덱스 마스크를 만듬\n", 271 | " non_final_mask = torch.ByteTensor(tuple(map(lambda s: s is not None,\n", 272 | " self.batch.next_state)))\n", 273 | " # 먼저 전체를 0으로 초기화\n", 274 | " next_state_values = torch.zeros(BATCH_SIZE)\n", 275 | " \n", 276 | " a_m = torch.zeros(BATCH_SIZE).type(torch.LongTensor)\n", 277 | "\n", 278 | " # 다음 상태에서 Q값이 최대가 되는 행동 a_m을 Main Q-Network로 계산\n", 279 | " # 마지막에 붙은 [1]로 행동에 해당하는 인덱스를 구함\n", 280 | " a_m[non_final_mask] = self.main_q_network(\n", 281 | " self.non_final_next_states).detach().max(1)[1]\n", 282 | "\n", 283 | " # 다음 상태가 있는 것만을 걸러내고, size 32를 32*1로 변환\n", 284 | " a_m_non_final_next_states = a_m[non_final_mask].view(-1, 1)\n", 285 | "\n", 286 | " # 다음 상태가 있는 인덱스에 대해 행동 a_m의 Q값을 target Q-Network로 계산\n", 287 | " # detach() 메서드로 값을 꺼내옴\n", 288 | " # squeeze() 메서드로 size[minibatch*1]을 [minibatch]로 변환\n", 289 | " next_state_values[non_final_mask] = self.target_q_network(\n", 290 | " self.non_final_next_states).gather(1, a_m_non_final_next_states).detach().squeeze()\n", 291 | "\n", 292 | " # 3.4 정답신호로 사용할 Q(s_t, a_t)값을 Q러닝 식으로 계산한다\n", 293 | " expected_state_action_values = self.reward_batch + GAMMA * next_state_values\n", 294 | "\n", 295 | " return expected_state_action_values\n", 296 | "\n", 297 | " def update_main_q_network(self):\n", 298 | " '''4. 결합 가중치 수정'''\n", 299 | "\n", 300 | " # 4.1 신경망을 학습 모드로 전환\n", 301 | " self.main_q_network.train()\n", 302 | "\n", 303 | " # 4.2 손실함수를 계산 (smooth_l1_loss는 Huber 함수)\n", 304 | " # expected_state_action_values은\n", 305 | " # size가 [minibatch]이므로 unsqueeze하여 [minibatch*1]로 만든다\n", 306 | " loss = F.smooth_l1_loss(self.state_action_values,\n", 307 | " self.expected_state_action_values.unsqueeze(1))\n", 308 | "\n", 309 | " # 4.3 결합 가중치를 수정한다\n", 310 | " self.optimizer.zero_grad() # 경사를 초기화\n", 311 | " loss.backward() # 역전파 계산\n", 312 | " self.optimizer.step() # 결합 가중치 수정\n", 313 | "\n", 314 | " def update_target_q_network(self): # DDQN에서 추가됨\n", 315 | " '''Target Q-Network을 Main Q-Network와 맞춤'''\n", 316 | " self.target_q_network.load_state_dict(self.main_q_network.state_dict())\n" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 8, 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [ 325 | "# CartPole 태스크의 에이전트 클래스. 봉 달린 수레 자체라고 보면 된다\n", 326 | "\n", 327 | "\n", 328 | "class Agent:\n", 329 | " def __init__(self, num_states, num_actions):\n", 330 | " '''태스크의 상태 및 행동의 가짓수를 설정'''\n", 331 | " self.brain = Brain(num_states, num_actions) # 에이전트의 행동을 결정할 두뇌 역할 객체를 생성\n", 332 | "\n", 333 | " def update_q_function(self):\n", 334 | " '''Q함수를 수정'''\n", 335 | " self.brain.replay()\n", 336 | "\n", 337 | " def get_action(self, state, episode):\n", 338 | " '''행동을 결정'''\n", 339 | " action = self.brain.decide_action(state, episode)\n", 340 | " return action\n", 341 | "\n", 342 | " def memorize(self, state, action, state_next, reward):\n", 343 | " '''memory 객체에 state, action, state_next, reward 내용을 저장'''\n", 344 | " self.brain.memory.push(state, action, state_next, reward)\n", 345 | "\n", 346 | " def update_target_q_function(self):\n", 347 | " '''Target Q-Network을 Main Q-Network와 맞춤'''\n", 348 | " self.brain.update_target_q_network()\n", 349 | " " 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 9, 355 | "metadata": {}, 356 | "outputs": [], 357 | "source": [ 358 | "# CartPole을 실행하는 환경 클래스\n", 359 | "\n", 360 | "\n", 361 | "class Environment:\n", 362 | "\n", 363 | " def __init__(self):\n", 364 | " self.env = gym.make(ENV) # 태스크를 설정\n", 365 | " num_states = self.env.observation_space.shape[0] # 태스크의 상태 변수 수(4)를 받아옴\n", 366 | " num_actions = self.env.action_space.n # 태스크의 행동 가짓수(2)를 받아옴\n", 367 | " self.agent = Agent(num_states, num_actions) # 에이전트 역할을 할 객체를 생성\n", 368 | "\n", 369 | " def run(self):\n", 370 | " '''실행'''\n", 371 | " episode_10_list = np.zeros(10) # 최근 10에피소드 동안 버틴 단계 수를 저장함(평균 단계 수를 출력할 때 사용)\n", 372 | " complete_episodes = 0 # 현재까지 195단계를 버틴 에피소드 수\n", 373 | " episode_final = False # 마지막 에피소드 여부\n", 374 | " frames = [] # 애니메이션을 만들기 위해 마지막 에피소드의 프레임을 저장할 배열\n", 375 | "\n", 376 | " for episode in range(NUM_EPISODES): # 최대 에피소드 수만큼 반복\n", 377 | " observation = self.env.reset() # 환경 초기화\n", 378 | "\n", 379 | " state = observation # 관측을 변환없이 그대로 상태 s로 사용\n", 380 | " state = torch.from_numpy(state).type(\n", 381 | " torch.FloatTensor) # NumPy 변수를 파이토치 텐서로 변환\n", 382 | " state = torch.unsqueeze(state, 0) # size 4를 size 1*4로 변환\n", 383 | "\n", 384 | " for step in range(MAX_STEPS): # 1 에피소드에 해당하는 반복문\n", 385 | " \n", 386 | " # 애니메이션 만드는 부분을 주석처리\n", 387 | " #if episode_final is True: # 마지막 에피소드에서는 각 시각의 이미지를 frames에 저장한다\n", 388 | " # frames.append(self.env.render(mode='rgb_array'))\n", 389 | " \n", 390 | " action = self.agent.get_action(state, episode) # 다음 행동을 결정\n", 391 | "\n", 392 | " # 행동 a_t를 실행하여 다음 상태 s_{t+1}과 done 플래그 값을 결정\n", 393 | " # action에 .item()을 호출하여 행동 내용을 구함\n", 394 | " observation_next, _, done, _ = self.env.step(\n", 395 | " action.item()) # reward와 info는 사용하지 않으므로 _로 처리\n", 396 | "\n", 397 | " # 보상을 부여하고 episode의 종료 판정 및 state_next를 설정한다\n", 398 | " if done: # 단계 수가 200을 넘었거나 봉이 일정 각도 이상 기울면 done이 True가 됨\n", 399 | " state_next = None # 다음 상태가 없으므로 None으로\n", 400 | "\n", 401 | " # 최근 10 에피소드에서 버틴 단계 수를 리스트에 저장\n", 402 | " episode_10_list = np.hstack(\n", 403 | " (episode_10_list[1:], step + 1))\n", 404 | "\n", 405 | " if step < 195:\n", 406 | " reward = torch.FloatTensor(\n", 407 | " [-1.0]) # 도중에 봉이 쓰러졌다면 페널티로 보상 -1을 부여\n", 408 | " complete_episodes = 0 # 연속 성공 에피소드 기록을 초기화\n", 409 | " else:\n", 410 | " reward = torch.FloatTensor([1.0]) # 봉이 서 있는 채로 에피소드를 마쳤다면 보상 1 부여\n", 411 | " complete_episodes = complete_episodes + 1 # 연속 성공 에피소드 기록을 갱신\n", 412 | " else:\n", 413 | " reward = torch.FloatTensor([0.0]) # 그 외의 경우는 보상 0을 부여\n", 414 | " state_next = observation_next # 관측 결과를 그대로 상태로 사용\n", 415 | " state_next = torch.from_numpy(state_next).type(\n", 416 | " torch.FloatTensor) # numpy 변수를 파이토치 텐서로 변환\n", 417 | " state_next = torch.unsqueeze(state_next, 0) # size 4를 size 1*4로 변환\n", 418 | "\n", 419 | " # 메모리에 경험을 저장\n", 420 | " self.agent.memorize(state, action, state_next, reward)\n", 421 | "\n", 422 | " # Experience Replay로 Q함수를 수정\n", 423 | " self.agent.update_q_function()\n", 424 | "\n", 425 | " # 관측 결과를 업데이트\n", 426 | " state = state_next\n", 427 | "\n", 428 | " # 에피소드 종료 처리\n", 429 | " if done:\n", 430 | " print('%d Episode: Finished after %d steps:최근 10 에피소드의 평균 단계 수 = %.1lf' % (\n", 431 | " episode, step + 1, episode_10_list.mean()))\n", 432 | " \n", 433 | " # DDQN으로 추가된 부분 2에피소드마다 한번씩 Target Q-Network을 Main Q-Network와 맞춰줌\n", 434 | " if(episode % 2 == 0):\n", 435 | " self.agent.update_target_q_function()\n", 436 | " break\n", 437 | " \n", 438 | " \n", 439 | " if episode_final is True:\n", 440 | " # 애니메이션 생성 부분을 주석처리함\n", 441 | " # 애니메이션 생성 및 저장\n", 442 | " #display_frames_as_gif(frames)\n", 443 | " break\n", 444 | "\n", 445 | " # 10 에피소드 연속으로 195단계를 버티면 태스크 성공\n", 446 | " if complete_episodes >= 10:\n", 447 | " print('10 에피소드 연속 성공')\n", 448 | " episode_final = True # 다음 에피소드에서 애니메이션을 생성\n" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": 10, 454 | "metadata": {}, 455 | "outputs": [ 456 | { 457 | "name": "stdout", 458 | "output_type": "stream", 459 | "text": [ 460 | "Net(\n", 461 | " (fc1): Linear(in_features=4, out_features=32, bias=True)\n", 462 | " (fc2): Linear(in_features=32, out_features=32, bias=True)\n", 463 | " (fc3_adv): Linear(in_features=32, out_features=2, bias=True)\n", 464 | " (fc3_v): Linear(in_features=32, out_features=1, bias=True)\n", 465 | ")\n", 466 | "0 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 1.1\n", 467 | "1 Episode: Finished after 13 steps:최근 10 에피소드의 평균 단계 수 = 2.4\n", 468 | "2 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 3.6\n", 469 | "3 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 4.7\n", 470 | "4 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 5.6\n", 471 | "5 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 6.6\n", 472 | "6 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 7.6\n", 473 | "7 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 8.8\n", 474 | "8 Episode: Finished after 16 steps:최근 10 에피소드의 평균 단계 수 = 10.4\n", 475 | "9 Episode: Finished after 18 steps:최근 10 에피소드의 평균 단계 수 = 12.2\n", 476 | "10 Episode: Finished after 17 steps:최근 10 에피소드의 평균 단계 수 = 12.8\n", 477 | "11 Episode: Finished after 18 steps:최근 10 에피소드의 평균 단계 수 = 13.3\n", 478 | "12 Episode: Finished after 20 steps:최근 10 에피소드의 평균 단계 수 = 14.1\n", 479 | "13 Episode: Finished after 27 steps:최근 10 에피소드의 평균 단계 수 = 15.7\n", 480 | "14 Episode: Finished after 21 steps:최근 10 에피소드의 평균 단계 수 = 16.9\n", 481 | "15 Episode: Finished after 84 steps:최근 10 에피소드의 평균 단계 수 = 24.3\n", 482 | "16 Episode: Finished after 24 steps:최근 10 에피소드의 평균 단계 수 = 25.7\n", 483 | "17 Episode: Finished after 16 steps:최근 10 에피소드의 평균 단계 수 = 26.1\n", 484 | "18 Episode: Finished after 16 steps:최근 10 에피소드의 평균 단계 수 = 26.1\n", 485 | "19 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 25.3\n", 486 | "20 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 24.6\n", 487 | "21 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 23.8\n", 488 | "22 Episode: Finished after 8 steps:최근 10 에피소드의 평균 단계 수 = 22.6\n", 489 | "23 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 20.8\n", 490 | "24 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 19.7\n", 491 | "25 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 12.2\n", 492 | "26 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 10.7\n", 493 | "27 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 10.0\n", 494 | "28 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.3\n", 495 | "29 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.2\n", 496 | "30 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n", 497 | "31 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n", 498 | "32 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.5\n", 499 | "33 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 9.7\n", 500 | "34 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 9.8\n", 501 | "35 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 10.0\n", 502 | "36 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 10.0\n", 503 | "37 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 10.1\n", 504 | "38 Episode: Finished after 13 steps:최근 10 에피소드의 평균 단계 수 = 10.5\n", 505 | "39 Episode: Finished after 18 steps:최근 10 에피소드의 평균 단계 수 = 11.4\n", 506 | "40 Episode: Finished after 21 steps:최근 10 에피소드의 평균 단계 수 = 12.3\n", 507 | "41 Episode: Finished after 34 steps:최근 10 에피소드의 평균 단계 수 = 14.7\n", 508 | "42 Episode: Finished after 45 steps:최근 10 에피소드의 평균 단계 수 = 18.3\n", 509 | "43 Episode: Finished after 45 steps:최근 10 에피소드의 평균 단계 수 = 21.7\n", 510 | "44 Episode: Finished after 49 steps:최근 10 에피소드의 평균 단계 수 = 25.5\n", 511 | "45 Episode: Finished after 56 steps:최근 10 에피소드의 평균 단계 수 = 30.0\n", 512 | "46 Episode: Finished after 52 steps:최근 10 에피소드의 평균 단계 수 = 34.3\n", 513 | "47 Episode: Finished after 51 steps:최근 10 에피소드의 평균 단계 수 = 38.4\n", 514 | "48 Episode: Finished after 84 steps:최근 10 에피소드의 평균 단계 수 = 45.5\n", 515 | "49 Episode: Finished after 80 steps:최근 10 에피소드의 평균 단계 수 = 51.7\n", 516 | "50 Episode: Finished after 111 steps:최근 10 에피소드의 평균 단계 수 = 60.7\n", 517 | "51 Episode: Finished after 126 steps:최근 10 에피소드의 평균 단계 수 = 69.9\n", 518 | "52 Episode: Finished after 79 steps:최근 10 에피소드의 평균 단계 수 = 73.3\n", 519 | "53 Episode: Finished after 77 steps:최근 10 에피소드의 평균 단계 수 = 76.5\n", 520 | "54 Episode: Finished after 50 steps:최근 10 에피소드의 평균 단계 수 = 76.6\n", 521 | "55 Episode: Finished after 70 steps:최근 10 에피소드의 평균 단계 수 = 78.0\n", 522 | "56 Episode: Finished after 97 steps:최근 10 에피소드의 평균 단계 수 = 82.5\n", 523 | "57 Episode: Finished after 80 steps:최근 10 에피소드의 평균 단계 수 = 85.4\n", 524 | "58 Episode: Finished after 100 steps:최근 10 에피소드의 평균 단계 수 = 87.0\n", 525 | "59 Episode: Finished after 190 steps:최근 10 에피소드의 평균 단계 수 = 98.0\n", 526 | "60 Episode: Finished after 65 steps:최근 10 에피소드의 평균 단계 수 = 93.4\n", 527 | "61 Episode: Finished after 102 steps:최근 10 에피소드의 평균 단계 수 = 91.0\n", 528 | "62 Episode: Finished after 107 steps:최근 10 에피소드의 평균 단계 수 = 93.8\n", 529 | "63 Episode: Finished after 157 steps:최근 10 에피소드의 평균 단계 수 = 101.8\n", 530 | "64 Episode: Finished after 188 steps:최근 10 에피소드의 평균 단계 수 = 115.6\n", 531 | "65 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 128.6\n", 532 | "66 Episode: Finished after 187 steps:최근 10 에피소드의 평균 단계 수 = 137.6\n", 533 | "67 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 149.6\n", 534 | "68 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 159.6\n", 535 | "69 Episode: Finished after 105 steps:최근 10 에피소드의 평균 단계 수 = 151.1\n", 536 | "70 Episode: Finished after 179 steps:최근 10 에피소드의 평균 단계 수 = 162.5\n", 537 | "71 Episode: Finished after 194 steps:최근 10 에피소드의 평균 단계 수 = 171.7\n", 538 | "72 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 181.0\n", 539 | "73 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 185.3\n", 540 | "74 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 186.5\n", 541 | "75 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 186.5\n", 542 | "76 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 187.8\n", 543 | "77 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 187.8\n", 544 | "78 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 187.8\n", 545 | "79 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 197.3\n", 546 | "80 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 199.4\n", 547 | "81 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 200.0\n", 548 | "10 에피소드 연속 성공\n", 549 | "82 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 200.0\n" 550 | ] 551 | } 552 | ], 553 | "source": [ 554 | "# 실행 엔트리 포인트\n", 555 | "cartpole_env = Environment()\n", 556 | "cartpole_env.run()\n" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": null, 562 | "metadata": {}, 563 | "outputs": [], 564 | "source": [] 565 | }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": null, 569 | "metadata": {}, 570 | "outputs": [], 571 | "source": [] 572 | } 573 | ], 574 | "metadata": { 575 | "kernelspec": { 576 | "display_name": "Python 3", 577 | "language": "python", 578 | "name": "python3" 579 | }, 580 | "language_info": { 581 | "codemirror_mode": { 582 | "name": "ipython", 583 | "version": 3 584 | }, 585 | "file_extension": ".py", 586 | "mimetype": "text/x-python", 587 | "name": "python", 588 | "nbconvert_exporter": "python", 589 | "pygments_lexer": "ipython3", 590 | "version": "3.6.6" 591 | } 592 | }, 593 | "nbformat": 4, 594 | "nbformat_minor": 2 595 | } 596 | -------------------------------------------------------------------------------- /program/6_5_A2C_Advanced_ActorCritic.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 6.5 A2C(Advanced Actor-Critic) 구현" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# 구현에 사용할 패키지 임포트\n", 17 | "import numpy as np\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "%matplotlib inline\n", 20 | "import gym\n" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "# 상수 정의\n", 30 | "ENV = 'CartPole-v0' # 태스크 이름\n", 31 | "GAMMA = 0.99 # 시간할인율\n", 32 | "MAX_STEPS = 200 # 1에피소드 당 최대 단계 수\n", 33 | "NUM_EPISODES = 1000 # 최대 에피소드 수\n", 34 | "\n", 35 | "NUM_PROCESSES = 32 # 동시 실행 환경 수\n", 36 | "NUM_ADVANCED_STEP = 5 # 총 보상을 계산할 때 Advantage 학습을 할 단계 수\n" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# A2C 손실함수 계산에 사용되는 상수\n", 46 | "value_loss_coef = 0.5\n", 47 | "entropy_coef = 0.01\n", 48 | "max_grad_norm = 0.5\n" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 4, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "# 메모리 클래스 정의\n", 58 | "\n", 59 | "\n", 60 | "class RolloutStorage(object):\n", 61 | " '''Advantage 학습에 사용할 메모리 클래스'''\n", 62 | "\n", 63 | " def __init__(self, num_steps, num_processes, obs_shape):\n", 64 | "\n", 65 | " self.observations = torch.zeros(num_steps + 1, num_processes, 4)\n", 66 | " self.masks = torch.ones(num_steps + 1, num_processes, 1)\n", 67 | " self.rewards = torch.zeros(num_steps, num_processes, 1)\n", 68 | " self.actions = torch.zeros(num_steps, num_processes, 1).long()\n", 69 | "\n", 70 | " # 할인 총보상 저장\n", 71 | " self.returns = torch.zeros(num_steps + 1, num_processes, 1)\n", 72 | " self.index = 0 # insert할 인덱스\n", 73 | "\n", 74 | " def insert(self, current_obs, action, reward, mask):\n", 75 | " '''현재 인덱스 위치에 transition을 저장'''\n", 76 | " self.observations[self.index + 1].copy_(current_obs)\n", 77 | " self.masks[self.index + 1].copy_(mask)\n", 78 | " self.rewards[self.index].copy_(reward)\n", 79 | " self.actions[self.index].copy_(action)\n", 80 | "\n", 81 | " self.index = (self.index + 1) % NUM_ADVANCED_STEP # 인덱스 값 업데이트\n", 82 | "\n", 83 | " def after_update(self):\n", 84 | " '''Advantage학습 단계만큼 단계가 진행되면 가장 새로운 transition을 index0에 저장'''\n", 85 | " self.observations[0].copy_(self.observations[-1])\n", 86 | " self.masks[0].copy_(self.masks[-1])\n", 87 | "\n", 88 | " def compute_returns(self, next_value):\n", 89 | " '''Advantage학습 범위 안의 각 단계에 대해 할인 총보상을 계산'''\n", 90 | "\n", 91 | " # 주의 : 5번째 단계부터 거슬러 올라오며 계산\n", 92 | " # 주의 : 5번째 단계가 Advantage1, 4번째 단계는 Advantage2가 됨\n", 93 | " self.returns[-1] = next_value\n", 94 | " for ad_step in reversed(range(self.rewards.size(0))):\n", 95 | " self.returns[ad_step] = self.returns[ad_step + 1] * \\\n", 96 | " GAMMA * self.masks[ad_step + 1] + self.rewards[ad_step]\n" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 5, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "# A2C에 사용되는 신경망 구성\n", 106 | "import torch.nn as nn\n", 107 | "import torch.nn.functional as F\n", 108 | "\n", 109 | "\n", 110 | "class Net(nn.Module):\n", 111 | "\n", 112 | " def __init__(self, n_in, n_mid, n_out):\n", 113 | " super(Net, self).__init__()\n", 114 | " self.fc1 = nn.Linear(n_in, n_mid)\n", 115 | " self.fc2 = nn.Linear(n_mid, n_mid)\n", 116 | " self.actor = nn.Linear(n_mid, n_out) # 행동을 결정하는 부분이므로 출력 갯수는 행동의 가짓수\n", 117 | " self.critic = nn.Linear(n_mid, 1) # 상태가치를 출력하는 부분이므로 출력 갯수는 1개\n", 118 | "\n", 119 | " def forward(self, x):\n", 120 | " '''신경망 순전파 계산을 정의'''\n", 121 | " h1 = F.relu(self.fc1(x))\n", 122 | " h2 = F.relu(self.fc2(h1))\n", 123 | " critic_output = self.critic(h2) # 상태가치 계산\n", 124 | " actor_output = self.actor(h2) # 행동 계산\n", 125 | "\n", 126 | " return critic_output, actor_output\n", 127 | "\n", 128 | " def act(self, x):\n", 129 | " '''상태 x로부터 행동을 확률적으로 결정'''\n", 130 | " value, actor_output = self(x)\n", 131 | " # dim=1이므로 행동의 종류에 대해 softmax를 적용\n", 132 | " action_probs = F.softmax(actor_output, dim=1)\n", 133 | " action = action_probs.multinomial(num_samples=1) # dim=1이므로 행동의 종류에 대해 확률을 계산\n", 134 | " return action\n", 135 | "\n", 136 | " def get_value(self, x):\n", 137 | " '''상태 x로부터 상태가치를 계산'''\n", 138 | " value, actor_output = self(x)\n", 139 | "\n", 140 | " return value\n", 141 | "\n", 142 | " def evaluate_actions(self, x, actions):\n", 143 | " '''상태 x로부터 상태가치, 실제 행동 actions의 로그 확률, 엔트로피를 계산'''\n", 144 | " value, actor_output = self(x)\n", 145 | "\n", 146 | " log_probs = F.log_softmax(actor_output, dim=1) # dim=1이므로 행동의 종류에 대해 확률을 계산\n", 147 | " action_log_probs = log_probs.gather(1, actions) # 실제 행동의 로그 확률(log_probs)을 구함\n", 148 | "\n", 149 | " probs = F.softmax(actor_output, dim=1) # dim=1이므로 행동의 종류에 대한 계산\n", 150 | " entropy = -(log_probs * probs).sum(-1).mean()\n", 151 | "\n", 152 | " return value, action_log_probs, entropy\n" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 6, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "# 에이전트의 두뇌 역할을 하는 클래스. 모든 에이전트가 공유한다\n", 162 | "\n", 163 | "import torch\n", 164 | "from torch import optim\n", 165 | "\n", 166 | "\n", 167 | "class Brain(object):\n", 168 | " def __init__(self, actor_critic):\n", 169 | " self.actor_critic = actor_critic # actor_critic은 Net 클래스로 구현한 신경망\n", 170 | " self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=0.01)\n", 171 | "\n", 172 | " def update(self, rollouts):\n", 173 | " '''Advantage학습의 대상이 되는 5단계 모두를 사용하여 수정'''\n", 174 | " obs_shape = rollouts.observations.size()[2:] # torch.Size([4, 84, 84])\n", 175 | " num_steps = NUM_ADVANCED_STEP\n", 176 | " num_processes = NUM_PROCESSES\n", 177 | "\n", 178 | " values, action_log_probs, entropy = self.actor_critic.evaluate_actions(\n", 179 | " rollouts.observations[:-1].view(-1, 4),\n", 180 | " rollouts.actions.view(-1, 1))\n", 181 | "\n", 182 | " # 주의 : 각 변수의 크기\n", 183 | " # rollouts.observations[:-1].view(-1, 4) torch.Size([80, 4])\n", 184 | " # rollouts.actions.view(-1, 1) torch.Size([80, 1])\n", 185 | " # values torch.Size([80, 1])\n", 186 | " # action_log_probs torch.Size([80, 1])\n", 187 | " # entropy torch.Size([])\n", 188 | "\n", 189 | " values = values.view(num_steps, num_processes,\n", 190 | " 1) # torch.Size([5, 16, 1])\n", 191 | " action_log_probs = action_log_probs.view(num_steps, num_processes, 1)\n", 192 | "\n", 193 | " # advantage(행동가치-상태가치) 계산\n", 194 | " advantages = rollouts.returns[:-1] - values # torch.Size([5, 16, 1])\n", 195 | "\n", 196 | " # Critic의 loss 계산\n", 197 | " value_loss = advantages.pow(2).mean()\n", 198 | "\n", 199 | " # Actor의 gain 계산, 나중에 -1을 곱하면 loss가 된다\n", 200 | " action_gain = (action_log_probs*advantages.detach()).mean()\n", 201 | " # detach 메서드를 호출하여 advantages를 상수로 취급\n", 202 | "\n", 203 | " # 오차함수의 총합\n", 204 | " total_loss = (value_loss * value_loss_coef -\n", 205 | " action_gain - entropy * entropy_coef)\n", 206 | "\n", 207 | " # 결합 가중치 수정\n", 208 | " self.actor_critic.train() # 신경망을 학습 모드로 전환\n", 209 | " self.optimizer.zero_grad() # 경사를 초기화\n", 210 | " total_loss.backward() # 역전파 계산\n", 211 | " nn.utils.clip_grad_norm_(self.actor_critic.parameters(), max_grad_norm)\n", 212 | " # 결합 가중치가 한번에 너무 크게 변화하지 않도록, 경사를 0.5 이하로 제한함(클리핑)\n", 213 | "\n", 214 | " self.optimizer.step() # 결합 가중치 수정\n" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 7, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "# 이번에는 에이전트 클래스가 없음" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 8, 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "# 실행 환경 클래스\n", 233 | "import copy\n", 234 | "\n", 235 | "\n", 236 | "class Environment:\n", 237 | " def run(self):\n", 238 | " '''실행 엔트리 포인트'''\n", 239 | "\n", 240 | " # 동시 실행할 환경 수 만큼 env를 생성\n", 241 | " envs = [gym.make(ENV) for i in range(NUM_PROCESSES)]\n", 242 | "\n", 243 | " # 모든 에이전트가 공유하는 Brain 객체를 생성\n", 244 | " n_in = envs[0].observation_space.shape[0] # 상태 변수 수는 4\n", 245 | " n_out = envs[0].action_space.n # 행동 가짓수는 2\n", 246 | " n_mid = 32\n", 247 | " actor_critic = Net(n_in, n_mid, n_out) # 신경망 객체 생성\n", 248 | " global_brain = Brain(actor_critic)\n", 249 | "\n", 250 | " # 각종 정보를 저장하는 변수\n", 251 | " obs_shape = n_in\n", 252 | " current_obs = torch.zeros(\n", 253 | " NUM_PROCESSES, obs_shape) # torch.Size([16, 4])\n", 254 | " rollouts = RolloutStorage(\n", 255 | " NUM_ADVANCED_STEP, NUM_PROCESSES, obs_shape) # rollouts 객체\n", 256 | " episode_rewards = torch.zeros([NUM_PROCESSES, 1]) # 현재 에피소드의 보상\n", 257 | " final_rewards = torch.zeros([NUM_PROCESSES, 1]) # 마지막 에피소드의 보상\n", 258 | " obs_np = np.zeros([NUM_PROCESSES, obs_shape]) # Numpy 배열\n", 259 | " reward_np = np.zeros([NUM_PROCESSES, 1]) # Numpy 배열\n", 260 | " done_np = np.zeros([NUM_PROCESSES, 1]) # Numpy 배열\n", 261 | " each_step = np.zeros(NUM_PROCESSES) # 각 환경의 단계 수를 기록\n", 262 | " episode = 0 # 환경 0의 에피소드 수\n", 263 | "\n", 264 | " # 초기 상태로부터 시작\n", 265 | " obs = [envs[i].reset() for i in range(NUM_PROCESSES)]\n", 266 | " obs = np.array(obs)\n", 267 | " obs = torch.from_numpy(obs).float() # torch.Size([16, 4])\n", 268 | " current_obs = obs # 가장 최근의 obs를 저장\n", 269 | " \n", 270 | " # advanced 학습에 사용되는 객체 rollouts 첫번째 상태에 현재 상태를 저장\n", 271 | " rollouts.observations[0].copy_(current_obs)\n", 272 | "\n", 273 | " # 1 에피소드에 해당하는 반복문\n", 274 | " for j in range(NUM_EPISODES*NUM_PROCESSES): # 전체 for문\n", 275 | " # advanced 학습 대상이 되는 각 단계에 대해 계산\n", 276 | " for step in range(NUM_ADVANCED_STEP):\n", 277 | "\n", 278 | " # 행동을 선택\n", 279 | " with torch.no_grad():\n", 280 | " action = actor_critic.act(rollouts.observations[step])\n", 281 | "\n", 282 | " # (16,1)→(16,) -> tensor를 NumPy변수로\n", 283 | " actions = action.squeeze(1).numpy()\n", 284 | "\n", 285 | " # 한 단계를 실행\n", 286 | " for i in range(NUM_PROCESSES):\n", 287 | " obs_np[i], reward_np[i], done_np[i], _ = envs[i].step(\n", 288 | " actions[i])\n", 289 | "\n", 290 | " # episode의 종료가치, state_next를 설정\n", 291 | " if done_np[i]: # 단계 수가 200을 넘거나, 봉이 일정 각도 이상 기울면 done이 True가 됨\n", 292 | "\n", 293 | " # 환경 0일 경우에만 출력\n", 294 | " if i == 0:\n", 295 | " print('%d Episode: Finished after %d steps' % (\n", 296 | " episode, each_step[i]+1))\n", 297 | " episode += 1\n", 298 | "\n", 299 | " # 보상 부여\n", 300 | " if each_step[i] < 195:\n", 301 | " reward_np[i] = -1.0 # 도중에 봉이 넘어지면 페널티로 보상 -1 부여\n", 302 | " else:\n", 303 | " reward_np[i] = 1.0 # 봉이 쓰러지지 않고 끝나면 보상 1 부여\n", 304 | "\n", 305 | " each_step[i] = 0 # 단계 수 초기화\n", 306 | " obs_np[i] = envs[i].reset() # 실행 환경 초기화\n", 307 | "\n", 308 | " else:\n", 309 | " reward_np[i] = 0.0 # 그 외의 경우는 보상 0 부여\n", 310 | " each_step[i] += 1\n", 311 | "\n", 312 | " # 보상을 tensor로 변환하고, 에피소드의 총보상에 더해줌\n", 313 | " reward = torch.from_numpy(reward_np).float()\n", 314 | " episode_rewards += reward\n", 315 | "\n", 316 | " # 각 실행 환경을 확인하여 done이 true이면 mask를 0으로, false이면 mask를 1로\n", 317 | " masks = torch.FloatTensor(\n", 318 | " [[0.0] if done_ else [1.0] for done_ in done_np])\n", 319 | "\n", 320 | " # 마지막 에피소드의 총 보상을 업데이트\n", 321 | " final_rewards *= masks # done이 false이면 1을 곱하고, true이면 0을 곱해 초기화\n", 322 | " # done이 false이면 0을 더하고, true이면 episode_rewards를 더해줌\n", 323 | " final_rewards += (1 - masks) * episode_rewards\n", 324 | "\n", 325 | " # 에피소드의 총보상을 업데이트\n", 326 | " episode_rewards *= masks # done이 false인 에피소드의 mask는 1이므로 그대로, true이면 0이 됨\n", 327 | "\n", 328 | " # 현재 done이 true이면 모두 0으로 \n", 329 | " current_obs *= masks\n", 330 | "\n", 331 | " # current_obs를 업데이트\n", 332 | " obs = torch.from_numpy(obs_np).float() # torch.Size([16, 4])\n", 333 | " current_obs = obs # 최신 상태의 obs를 저장\n", 334 | "\n", 335 | " # 메모리 객체에 현 단계의 transition을 저장\n", 336 | " rollouts.insert(current_obs, action.data, reward, masks)\n", 337 | "\n", 338 | " # advanced 학습 for문 끝\n", 339 | "\n", 340 | " # advanced 학습 대상 중 마지막 단계의 상태로 예측하는 상태가치를 계산\n", 341 | "\n", 342 | " with torch.no_grad():\n", 343 | " next_value = actor_critic.get_value(\n", 344 | " rollouts.observations[-1]).detach()\n", 345 | " # rollouts.observations의 크기는 torch.Size([6, 16, 4])\n", 346 | "\n", 347 | " # 모든 단계의 할인총보상을 계산하고, rollouts의 변수 returns를 업데이트\n", 348 | " rollouts.compute_returns(next_value)\n", 349 | "\n", 350 | " # 신경망 및 rollout 업데이트\n", 351 | " global_brain.update(rollouts)\n", 352 | " rollouts.after_update()\n", 353 | "\n", 354 | " # 환경 갯수를 넘어서는 횟수로 200단계를 버텨내면 성공\n", 355 | " if final_rewards.sum().numpy() >= NUM_PROCESSES:\n", 356 | " print('연속성공')\n", 357 | " break\n" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": 9, 363 | "metadata": {}, 364 | "outputs": [ 365 | { 366 | "name": "stdout", 367 | "output_type": "stream", 368 | "text": [ 369 | "0 Episode: Finished after 20 steps\n", 370 | "1 Episode: Finished after 13 steps\n", 371 | "2 Episode: Finished after 10 steps\n", 372 | "3 Episode: Finished after 41 steps\n", 373 | "4 Episode: Finished after 20 steps\n", 374 | "5 Episode: Finished after 15 steps\n", 375 | "6 Episode: Finished after 19 steps\n", 376 | "7 Episode: Finished after 19 steps\n", 377 | "8 Episode: Finished after 29 steps\n", 378 | "9 Episode: Finished after 23 steps\n", 379 | "10 Episode: Finished after 73 steps\n", 380 | "11 Episode: Finished after 49 steps\n", 381 | "12 Episode: Finished after 66 steps\n", 382 | "13 Episode: Finished after 14 steps\n", 383 | "14 Episode: Finished after 16 steps\n", 384 | "15 Episode: Finished after 74 steps\n", 385 | "16 Episode: Finished after 200 steps\n", 386 | "17 Episode: Finished after 200 steps\n", 387 | "18 Episode: Finished after 182 steps\n", 388 | "19 Episode: Finished after 200 steps\n", 389 | "연속성공\n" 390 | ] 391 | } 392 | ], 393 | "source": [ 394 | "# main 실행\n", 395 | "cartpole_env = Environment()\n", 396 | "cartpole_env.run()\n" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": null, 402 | "metadata": {}, 403 | "outputs": [], 404 | "source": [] 405 | } 406 | ], 407 | "metadata": { 408 | "kernelspec": { 409 | "display_name": "Python 3", 410 | "language": "python", 411 | "name": "python3" 412 | }, 413 | "language_info": { 414 | "codemirror_mode": { 415 | "name": "ipython", 416 | "version": 3 417 | }, 418 | "file_extension": ".py", 419 | "mimetype": "text/x-python", 420 | "name": "python", 421 | "nbconvert_exporter": "python", 422 | "pygments_lexer": "ipython3", 423 | "version": "3.6.6" 424 | } 425 | }, 426 | "nbformat": 4, 427 | "nbformat_minor": 2 428 | } 429 | -------------------------------------------------------------------------------- /program/7_breakout_play.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 7장 벽돌깨기 게임 학습 결과 실행용 프로그램 " 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 2, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# 구현에 사용할 패키지 임포트\n", 17 | "import numpy as np\n", 18 | "from collections import deque\n", 19 | "from tqdm import tqdm\n", 20 | "\n", 21 | "import torch\n", 22 | "import torch.nn as nn\n", 23 | "import torch.nn.functional as F\n", 24 | "import torch.optim as optim\n", 25 | "\n", 26 | "import gym\n", 27 | "from gym import spaces\n", 28 | "from gym.spaces.box import Box\n" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# 구현에 사용할 패키지 임포트\n", 38 | "import matplotlib.pyplot as plt\n", 39 | "%matplotlib inline\n", 40 | "\n", 41 | "\n", 42 | "# 애니메이션을 만드는 함수\n", 43 | "# 참고 URL http://nbviewer.jupyter.org/github/patrickmineault\n", 44 | "# /xcorr-notebooks/blob/master/Render%20OpenAI%20gym%20as%20GIF.ipynb\n", 45 | "from JSAnimation.IPython_display import display_animation\n", 46 | "from matplotlib import animation\n", 47 | "from IPython.display import display\n", 48 | "\n", 49 | "def display_frames_as_gif(frames):\n", 50 | " \"\"\"\n", 51 | " Displays a list of frames as a gif, with controls\n", 52 | " \"\"\"\n", 53 | " plt.figure(figsize=(frames[0].shape[1]/72.0*1, frames[0].shape[0]/72.0*1),\n", 54 | " dpi=72)\n", 55 | " patch = plt.imshow(frames[0])\n", 56 | " plt.axis('off')\n", 57 | " \n", 58 | " def animate(i):\n", 59 | " patch.set_data(frames[i])\n", 60 | " \n", 61 | " anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames),\n", 62 | " interval=20)\n", 63 | " \n", 64 | " anim.save('breakout.mp4') # 애니메이션을 저장하는 부분\n", 65 | " display(display_animation(anim, default_mode='loop'))\n", 66 | " " 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 3, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "# 실행환경 설정\n", 76 | "# 参考:https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py\n", 77 | "\n", 78 | "import cv2\n", 79 | "cv2.ocl.setUseOpenCL(False)\n", 80 | "\n", 81 | "\n", 82 | "class NoopResetEnv(gym.Wrapper):\n", 83 | " def __init__(self, env, noop_max=30):\n", 84 | " '''첫 번째 트릭 No-Operation. 초기화 후 일정 단계에 이를때까지 아무 행동도 하지않고\n", 85 | " 게임 초기 상태를 다양하게 하여 특정 시작 상태만 학습하는 것을 방지한다'''\n", 86 | "\n", 87 | " gym.Wrapper.__init__(self, env)\n", 88 | " self.noop_max = noop_max\n", 89 | " self.override_num_noops = None\n", 90 | " self.noop_action = 0\n", 91 | " assert env.unwrapped.get_action_meanings()[0] == 'NOOP'\n", 92 | "\n", 93 | " def reset(self, **kwargs):\n", 94 | " \"\"\" Do no-op action for a number of steps in [1, noop_max].\"\"\"\n", 95 | " self.env.reset(**kwargs)\n", 96 | " if self.override_num_noops is not None:\n", 97 | " noops = self.override_num_noops\n", 98 | " else:\n", 99 | " noops = self.unwrapped.np_random.randint(\n", 100 | " 1, self.noop_max + 1) # pylint: disable=E1101\n", 101 | " assert noops > 0\n", 102 | " obs = None\n", 103 | " for _ in range(noops):\n", 104 | " obs, _, done, _ = self.env.step(self.noop_action)\n", 105 | " if done:\n", 106 | " obs = self.env.reset(**kwargs)\n", 107 | " return obs\n", 108 | "\n", 109 | " def step(self, ac):\n", 110 | " return self.env.step(ac)\n", 111 | "\n", 112 | "\n", 113 | "class EpisodicLifeEnv(gym.Wrapper):\n", 114 | " def __init__(self, env):\n", 115 | " '''두 번째 트릭 Episodic Life. 한번 실패를 게임 종료로 간주하나, 다음 게임을 같은 블록 상태로 시작'''\n", 116 | " gym.Wrapper.__init__(self, env)\n", 117 | " self.lives = 0\n", 118 | " self.was_real_done = True\n", 119 | "\n", 120 | " def step(self, action):\n", 121 | " obs, reward, done, info = self.env.step(action)\n", 122 | " self.was_real_done = done\n", 123 | " # check current lives, make loss of life terminal,\n", 124 | " # then update lives to handle bonus lives\n", 125 | " lives = self.env.unwrapped.ale.lives()\n", 126 | " if lives < self.lives and lives > 0:\n", 127 | " # for Qbert sometimes we stay in lives == 0 condtion for a few frames\n", 128 | " # so its important to keep lives > 0, so that we only reset once\n", 129 | " # the environment advertises done.\n", 130 | " done = True\n", 131 | " self.lives = lives\n", 132 | " return obs, reward, done, info\n", 133 | "\n", 134 | " def reset(self, **kwargs):\n", 135 | " '''5번 실패하면 게임을 완전히 다시 시작'''\n", 136 | " if self.was_real_done:\n", 137 | " obs = self.env.reset(**kwargs)\n", 138 | " else:\n", 139 | " # no-op step to advance from terminal/lost life state\n", 140 | " obs, _, _, _ = self.env.step(0)\n", 141 | " self.lives = self.env.unwrapped.ale.lives()\n", 142 | " return obs\n", 143 | "\n", 144 | "\n", 145 | "class MaxAndSkipEnv(gym.Wrapper):\n", 146 | " def __init__(self, env, skip=4):\n", 147 | " '''세 번째 트릭 Max and Skip. 4프레임 동안 같은 행동을 지속하되, 3번째와 4번째 프레임의 최댓값 이미지를 관측 obs로 삼는다'''\n", 148 | " gym.Wrapper.__init__(self, env)\n", 149 | " # most recent raw observations (for max pooling across time steps)\n", 150 | " self._obs_buffer = np.zeros(\n", 151 | " (2,)+env.observation_space.shape, dtype=np.uint8)\n", 152 | " self._skip = skip\n", 153 | "\n", 154 | " def step(self, action):\n", 155 | " \"\"\"Repeat action, sum reward, and max over last observations.\"\"\"\n", 156 | " total_reward = 0.0\n", 157 | " done = None\n", 158 | " for i in range(self._skip):\n", 159 | " obs, reward, done, info = self.env.step(action)\n", 160 | " if i == self._skip - 2:\n", 161 | " self._obs_buffer[0] = obs\n", 162 | " if i == self._skip - 1:\n", 163 | " self._obs_buffer[1] = obs\n", 164 | " total_reward += reward\n", 165 | " if done:\n", 166 | " break\n", 167 | " # Note that the observation on the done=True frame\n", 168 | " # doesn't matter\n", 169 | " max_frame = self._obs_buffer.max(axis=0)\n", 170 | "\n", 171 | " return max_frame, total_reward, done, info\n", 172 | "\n", 173 | " def reset(self, **kwargs):\n", 174 | " return self.env.reset(**kwargs)\n", 175 | "\n", 176 | "\n", 177 | "class WarpFrame(gym.ObservationWrapper):\n", 178 | " def __init__(self, env):\n", 179 | " '''네 번째 트릭 Warp frame. DQN 네이처 논문 구현과 같이 84*84 흑백 이미지를 사용'''\n", 180 | " gym.ObservationWrapper.__init__(self, env)\n", 181 | " self.width = 84\n", 182 | " self.height = 84\n", 183 | " self.observation_space = spaces.Box(low=0, high=255,\n", 184 | " shape=(self.height, self.width, 1), dtype=np.uint8)\n", 185 | "\n", 186 | " def observation(self, frame):\n", 187 | " frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)\n", 188 | " frame = cv2.resize(frame, (self.width, self.height),\n", 189 | " interpolation=cv2.INTER_AREA)\n", 190 | " return frame[:, :, None]\n", 191 | "\n", 192 | "\n", 193 | "class WrapPyTorch(gym.ObservationWrapper):\n", 194 | " def __init__(self, env=None):\n", 195 | " '''인덱스 순서를 파이토치 미니배치와 같이 조정하는 래퍼'''\n", 196 | " super(WrapPyTorch, self).__init__(env)\n", 197 | " obs_shape = self.observation_space.shape\n", 198 | " self.observation_space = Box(\n", 199 | " self.observation_space.low[0, 0, 0],\n", 200 | " self.observation_space.high[0, 0, 0],\n", 201 | " [obs_shape[2], obs_shape[1], obs_shape[0]],\n", 202 | " dtype=self.observation_space.dtype)\n", 203 | "\n", 204 | " def observation(self, observation):\n", 205 | " return observation.transpose(2, 0, 1)\n" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 4, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "# 再生用の実行環境\n", 215 | "\n", 216 | "\n", 217 | "class EpisodicLifeEnvPlay(gym.Wrapper):\n", 218 | " def __init__(self, env):\n", 219 | " '''두 번째 트릭 Episodic Life. 한번 실패를 게임 종료로 간주하나, 다음 게임을 같은 블록 상태로 시작\n", 220 | " 그러나 여기서는 학습 결과 플레이를 위해 한번 실패에도 블록 상태까지 초기화한다'''\n", 221 | "\n", 222 | " gym.Wrapper.__init__(self, env)\n", 223 | "\n", 224 | " def step(self, action):\n", 225 | " obs, reward, done, info = self.env.step(action)\n", 226 | " # 처음 5개 라이프를 갖고 시작하지만, 하나만 잃어도 종료한다\n", 227 | " if self.env.unwrapped.ale.lives() < 5:\n", 228 | " done = True\n", 229 | "\n", 230 | " return obs, reward, done, info\n", 231 | "\n", 232 | " def reset(self, **kwargs):\n", 233 | " '''한번이라도 실패하면 완전히 초기화'''\n", 234 | "\n", 235 | " obs = self.env.reset(**kwargs)\n", 236 | "\n", 237 | " return obs\n", 238 | "\n", 239 | "\n", 240 | "class MaxAndSkipEnvPlay(gym.Wrapper):\n", 241 | " def __init__(self, env, skip=4):\n", 242 | " '''세 번째 트릭 Max and Skip. 4프레임 동안 같은 행동을 지속하되, 4번째 프레임 이미지를 관측 obs로 삼는다'''\n", 243 | " gym.Wrapper.__init__(self, env)\n", 244 | " # most recent raw observations (for max pooling across time steps)\n", 245 | " self._obs_buffer = np.zeros(\n", 246 | " (2,)+env.observation_space.shape, dtype=np.uint8)\n", 247 | " self._skip = skip\n", 248 | "\n", 249 | " def step(self, action):\n", 250 | " \"\"\"Repeat action, sum reward, and max over last observations.\"\"\"\n", 251 | " total_reward = 0.0\n", 252 | " done = None\n", 253 | " for i in range(self._skip):\n", 254 | " obs, reward, done, info = self.env.step(action)\n", 255 | " if i == self._skip - 2:\n", 256 | " self._obs_buffer[0] = obs\n", 257 | " if i == self._skip - 1:\n", 258 | " self._obs_buffer[1] = obs\n", 259 | " total_reward += reward\n", 260 | " if done:\n", 261 | " break\n", 262 | "\n", 263 | " return obs, total_reward, done, info\n", 264 | "\n", 265 | " def reset(self, **kwargs):\n", 266 | " return self.env.reset(**kwargs)\n" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 5, 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "# 실행환경 생성 함수\n", 276 | "\n", 277 | "# 병렬 실행환경\n", 278 | "from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv\n", 279 | "\n", 280 | "\n", 281 | "def make_env(env_id, seed, rank):\n", 282 | " def _thunk():\n", 283 | " '''멀티 프로세스로 동작하는 환경 SubprocVecEnv를 실행하기 위해 필요하다'''\n", 284 | "\n", 285 | " env = gym.make(env_id)\n", 286 | " #env = NoopResetEnv(env, noop_max=30)\n", 287 | " env = MaxAndSkipEnv(env, skip=4)\n", 288 | " env.seed(seed + rank) # 난수 시드 설정\n", 289 | " #env = EpisodicLifeEnv(env)\n", 290 | " env = EpisodicLifeEnvPlay(env)\n", 291 | " env = WarpFrame(env)\n", 292 | " env = WrapPyTorch(env)\n", 293 | "\n", 294 | " return env\n", 295 | "\n", 296 | " return _thunk\n", 297 | "\n", 298 | "\n", 299 | "def make_env_play(env_id, seed, rank):\n", 300 | " '''학습 결과 시연용 실행환경'''\n", 301 | " env = gym.make(env_id)\n", 302 | " #env = NoopResetEnv(env, noop_max=30)\n", 303 | " #env = MaxAndSkipEnv(env, skip=4)\n", 304 | " env = MaxAndSkipEnvPlay(env, skip=4)\n", 305 | " env.seed(seed + rank) # 난수 시드 설정\n", 306 | " env = EpisodicLifeEnvPlay(env)\n", 307 | " #env = WarpFrame(env)\n", 308 | " #env = WrapPyTorch(env)\n", 309 | "\n", 310 | " return env\n" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 6, 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [ 319 | "# 상수 정의\n", 320 | "\n", 321 | "ENV_NAME = 'BreakoutNoFrameskip-v4' \n", 322 | "# Breakout-v0 대신 BreakoutNoFrameskip-v4을 사용\n", 323 | "# v0은 2~4개 프레임을 자동으로 생략하므로 이 기능이 없는 버전을 사용한다\n", 324 | "# 참고 URL https://becominghuman.ai/lets-build-an-atari-ai-part-1-dqn-df57e8ff3b26\n", 325 | "# https://github.com/openai/gym/blob/5cb12296274020db9bb6378ce54276b31e7002da/gym/envs/__init__.py#L371\n", 326 | " \n", 327 | "NUM_SKIP_FRAME = 4 # 생략할 프레임 수\n", 328 | "NUM_STACK_FRAME = 4 # 하나의 상태로 사용할 프레임의 수\n", 329 | "NOOP_MAX = 30 # 초기화 후 No-operation을 적용할 최초 프레임 수의 최댓값\n", 330 | "NUM_PROCESSES = 16 # 병렬로 실행할 프로세스 수\n", 331 | "NUM_ADVANCED_STEP = 5 # Advanced 학습할 단계 수\n", 332 | "GAMMA = 0.99 # 시간할인율\n", 333 | "\n", 334 | "TOTAL_FRAMES=10e6 # 학습에 사용하는 총 프레임 수\n", 335 | "NUM_UPDATES = int(TOTAL_FRAMES / NUM_ADVANCED_STEP / NUM_PROCESSES) # 신경망 수정 총 횟수\n", 336 | "# NUM_UPDATES는 약 125,000이 됨\n" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": 7, 342 | "metadata": {}, 343 | "outputs": [], 344 | "source": [ 345 | "# A2C 손실함수를 계산하기 위한 상수\n", 346 | "value_loss_coef = 0.5\n", 347 | "entropy_coef = 0.01\n", 348 | "max_grad_norm = 0.5\n", 349 | "\n", 350 | "# 최적회 기법 RMSprop에 대한 설정\n", 351 | "lr = 7e-4\n", 352 | "eps = 1e-5\n", 353 | "alpha = 0.99\n" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 3, 359 | "metadata": {}, 360 | "outputs": [ 361 | { 362 | "name": "stdout", 363 | "output_type": "stream", 364 | "text": [ 365 | "cuda\n" 366 | ] 367 | } 368 | ], 369 | "source": [ 370 | "# GPU 사용 설정\n", 371 | "use_cuda = torch.cuda.is_available()\n", 372 | "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n", 373 | "print(device)\n" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": 9, 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [ 382 | "# 메모리 클래스 정의\n", 383 | "\n", 384 | "\n", 385 | "class RolloutStorage(object):\n", 386 | " '''Advantage 학습에 사용하는 메모리 클래스'''\n", 387 | "\n", 388 | " def __init__(self, num_steps, num_processes, obs_shape):\n", 389 | "\n", 390 | " self.observations = torch.zeros(\n", 391 | " num_steps + 1, num_processes, *obs_shape).to(device)\n", 392 | " # *로 리스트의 요소를 풀어낸다(unpack)\n", 393 | " # obs_shape→(4,84,84)\n", 394 | " # *obs_shape→ 4 84 84\n", 395 | "\n", 396 | " self.masks = torch.ones(num_steps + 1, num_processes, 1).to(device)\n", 397 | " self.rewards = torch.zeros(num_steps, num_processes, 1).to(device)\n", 398 | " self.actions = torch.zeros(\n", 399 | " num_steps, num_processes, 1).long().to(device)\n", 400 | "\n", 401 | " # 할인 총보상을 저장\n", 402 | " self.returns = torch.zeros(num_steps + 1, num_processes, 1).to(device)\n", 403 | " self.index = 0 # 저장할 인덱스\n", 404 | "\n", 405 | " def insert(self, current_obs, action, reward, mask):\n", 406 | " '''인덱스가 가리키는 다음 자리에 transition을 저장'''\n", 407 | " self.observations[self.index + 1].copy_(current_obs)\n", 408 | " self.masks[self.index + 1].copy_(mask)\n", 409 | " self.rewards[self.index].copy_(reward)\n", 410 | " self.actions[self.index].copy_(action)\n", 411 | "\n", 412 | " self.index = (self.index + 1) % NUM_ADVANCED_STEP # 인덱스 업데이트\n", 413 | "\n", 414 | " def after_update(self):\n", 415 | " '''Advantage 학습 단계 수만큼 단계가 진행되면 가장 최근 단계를 index0에 저장'''\n", 416 | " self.observations[0].copy_(self.observations[-1])\n", 417 | " self.masks[0].copy_(self.masks[-1])\n", 418 | "\n", 419 | " def compute_returns(self, next_value):\n", 420 | " '''Advantage 학습 단계에 들어가는 각 단계에 대해 할인 총보상을 계산'''\n", 421 | "\n", 422 | " # 주의 : 5번째 단계부터 거슬러 올라가며 계산\n", 423 | " # 주의 : 5번째 단계가 Advantage1, 4번째 단계가 Advantage2가 되는 식임\n", 424 | " self.returns[-1] = next_value\n", 425 | " for ad_step in reversed(range(self.rewards.size(0))):\n", 426 | " self.returns[ad_step] = self.returns[ad_step + 1] * \\\n", 427 | " GAMMA * self.masks[ad_step + 1] + self.rewards[ad_step]\n" 428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": 10, 433 | "metadata": {}, 434 | "outputs": [], 435 | "source": [ 436 | "# A2C 신경망 구성\n", 437 | "\n", 438 | "\n", 439 | "def init(module, gain):\n", 440 | " '''결합 가중치를 초기화하는 함수'''\n", 441 | " nn.init.orthogonal_(module.weight.data, gain=gain)\n", 442 | " nn.init.constant_(module.bias.data, 0)\n", 443 | " return module\n", 444 | "\n", 445 | "\n", 446 | "class Flatten(nn.Module):\n", 447 | " '''합성곱층의 출력 이미지를 1차원으로 변환하는 층'''\n", 448 | "\n", 449 | " def forward(self, x):\n", 450 | " return x.view(x.size(0), -1)\n", 451 | "\n", 452 | "\n", 453 | "class Net(nn.Module):\n", 454 | " def __init__(self, n_out):\n", 455 | " super(Net, self).__init__()\n", 456 | "\n", 457 | " # 결합 가중치 초기화 함수\n", 458 | " def init_(module): return init(\n", 459 | " module, gain=nn.init.calculate_gain('relu'))\n", 460 | "\n", 461 | " # 합성곱층을 정의\n", 462 | " self.conv = nn.Sequential(\n", 463 | " # 이미지 크기의 변화 (84*84 -> 20*20)\n", 464 | " init_(nn.Conv2d(NUM_STACK_FRAME, 32, kernel_size=8, stride=4)),\n", 465 | " # 프레임 4개를 합치므로 input=NUM_STACK_FRAME=4가 된다. 출력은 32이다.\n", 466 | " # size 계산 size = (Input_size - Kernel_size + 2*Padding_size)/ Stride_size + 1\n", 467 | "\n", 468 | " nn.ReLU(),\n", 469 | " # 이미지 크기의 변화 (20*20 -> 9*9)\n", 470 | " init_(nn.Conv2d(32, 64, kernel_size=4, stride=2)),\n", 471 | " nn.ReLU(),\n", 472 | " init_(nn.Conv2d(64, 64, kernel_size=3, stride=1)), # 이미지 크기의 변화(9*9 -> 7*7)\n", 473 | " nn.ReLU(),\n", 474 | " Flatten(), # 이미지를 1차원으로 변환\n", 475 | " init_(nn.Linear(64 * 7 * 7, 512)), # 7*7 이미지 64개를 512차원으로 변환\n", 476 | " nn.ReLU()\n", 477 | " )\n", 478 | "\n", 479 | " # 결합 가중치 초기화 함수\n", 480 | " def init_(module): return init(module, gain=1.0)\n", 481 | "\n", 482 | " # Critic을 정의\n", 483 | " self.critic = init_(nn.Linear(512, 1)) # 출력은 상태가치이므로 1개\n", 484 | "\n", 485 | " # 결합 가중치 초기화 함수\n", 486 | " def init_(module): return init(module, gain=0.01)\n", 487 | "\n", 488 | " # Actor를 정의\n", 489 | " self.actor = init_(nn.Linear(512, n_out)) # 출력이 행동이므로 출력 수는 행동의 가짓수\n", 490 | " \n", 491 | " # 신경망을 학습 모드로 전환\n", 492 | " self.train()\n", 493 | "\n", 494 | " def forward(self, x):\n", 495 | " '''신경망의 순전파 계산 정의'''\n", 496 | " input = x / 255.0 # 이미지의 픽셀값을 [0,255]에서 [0,1] 구간으로 정규화\n", 497 | " conv_output = self.conv(input) # 합성곱층 계산\n", 498 | " critic_output = self.critic(conv_output) # 상태가치 출력 계산\n", 499 | " actor_output = self.actor(conv_output) # 행동 출력 계산\n", 500 | "\n", 501 | " return critic_output, actor_output\n", 502 | "\n", 503 | " def act(self, x):\n", 504 | " '''상태 x일때 취할 확률을 확률적으로 구함'''\n", 505 | " value, actor_output = self(x)\n", 506 | " probs = F.softmax(actor_output, dim=1) # dim=1で行動の種類方向に計算\n", 507 | " action = probs.multinomial(num_samples=1)\n", 508 | "\n", 509 | " return action\n", 510 | "\n", 511 | " def get_value(self, x):\n", 512 | " '''상태 x의 상태가치를 구함'''\n", 513 | " value, actor_output = self(x)\n", 514 | "\n", 515 | " return value\n", 516 | "\n", 517 | " def evaluate_actions(self, x, actions):\n", 518 | " '''상태 x의 상태가치, 실제 행동 actions의 로그 확률, 엔트로피를 구함'''\n", 519 | " value, actor_output = self(x)\n", 520 | "\n", 521 | " log_probs = F.log_softmax(actor_output, dim=1) # dim=1이므로 행동의 종류 방향으로 계산\n", 522 | " action_log_probs = log_probs.gather(1, actions) # 실제 행동에 대한 log_probs 계산\n", 523 | "\n", 524 | " probs = F.softmax(actor_output, dim=1) # dim=1이므로 행동의 종류 방향으로 계산\n", 525 | " dist_entropy = -(log_probs * probs).sum(-1).mean()\n", 526 | "\n", 527 | " return value, action_log_probs, dist_entropy\n" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": 11, 533 | "metadata": {}, 534 | "outputs": [], 535 | "source": [ 536 | "# 에이전트의 두뇌 역할을 하는 클래스로, 모든 에이전트가 공유한다\n", 537 | "\n", 538 | "\n", 539 | "class Brain(object):\n", 540 | " def __init__(self, actor_critic):\n", 541 | "\n", 542 | " self.actor_critic = actor_critic # actor_critic은 Net클래스로 구현한 신경망이다\n", 543 | "\n", 544 | " # 이미 학습된 결합 가중치를 로드하려면\n", 545 | " filename = 'weight_end.pth'\n", 546 | " #filename = 'weight_112500.pth'\n", 547 | " param = torch.load(filename, map_location='cpu')\n", 548 | " self.actor_critic.load_state_dict(param)\n", 549 | "\n", 550 | " # 가중치를 학습하는 최적화 알고리즘 설정\n", 551 | " self.optimizer = optim.RMSprop(\n", 552 | " actor_critic.parameters(), lr=lr, eps=eps, alpha=alpha)\n", 553 | "\n", 554 | " def update(self, rollouts):\n", 555 | " '''advanced 학습 대상 5단계를 모두 사용하여 수정한다'''\n", 556 | " obs_shape = rollouts.observations.size()[2:] # torch.Size([4, 84, 84])\n", 557 | " num_steps = NUM_ADVANCED_STEP\n", 558 | " num_processes = NUM_PROCESSES\n", 559 | "\n", 560 | " values, action_log_probs, dist_entropy = self.actor_critic.evaluate_actions(\n", 561 | " rollouts.observations[:-1].view(-1, *obs_shape),\n", 562 | " rollouts.actions.view(-1, 1))\n", 563 | "\n", 564 | " # 각 변수의 크기에 주의할 것\n", 565 | " # rollouts.observations[:-1].view(-1, *obs_shape) torch.Size([80, 4, 84, 84])\n", 566 | " # rollouts.actions.view(-1, 1) torch.Size([80, 1])\n", 567 | " # values torch.Size([80, 1])\n", 568 | " # action_log_probs torch.Size([80, 1])\n", 569 | " # dist_entropy torch.Size([])\n", 570 | "\n", 571 | " values = values.view(num_steps, num_processes,\n", 572 | " 1) # torch.Size([5, 16, 1])\n", 573 | " action_log_probs = action_log_probs.view(num_steps, num_processes, 1)\n", 574 | "\n", 575 | " advantages = rollouts.returns[:-1] - values # torch.Size([5, 16, 1])\n", 576 | " value_loss = advantages.pow(2).mean()\n", 577 | "\n", 578 | " action_gain = (advantages.detach() * action_log_probs).mean()\n", 579 | " # advantages는 detach 하여 정수로 취급한다\n", 580 | "\n", 581 | " total_loss = (value_loss * value_loss_coef -\n", 582 | " action_gain - dist_entropy * entropy_coef)\n", 583 | "\n", 584 | " self.optimizer.zero_grad() # 경사 초기화\n", 585 | " total_loss.backward() # 역전파 계산\n", 586 | " nn.utils.clip_grad_norm_(self.actor_critic.parameters(), max_grad_norm)\n", 587 | " # 한번에 결합 가중치가 너무 크게 변화하지 않도록, 경사의 최댓값을 0.5로 제한한다\n", 588 | "\n", 589 | " self.optimizer.step() # 결합 가중치 수정\n" 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": 12, 595 | "metadata": {}, 596 | "outputs": [], 597 | "source": [ 598 | "# Breakout을 실행하는 환경 클래스\n", 599 | "\n", 600 | "NUM_PROCESSES = 1\n", 601 | "\n", 602 | "\n", 603 | "class Environment:\n", 604 | " def run(self):\n", 605 | "\n", 606 | " # 난수 시드 설정\n", 607 | " seed_num = 1\n", 608 | " torch.manual_seed(seed_num)\n", 609 | " if use_cuda:\n", 610 | " torch.cuda.manual_seed(seed_num)\n", 611 | "\n", 612 | " # 실행환경 구축\n", 613 | " torch.set_num_threads(seed_num)\n", 614 | " envs = [make_env(ENV_NAME, seed_num, i) for i in range(NUM_PROCESSES)]\n", 615 | " envs = SubprocVecEnv(envs) # 멀티프로세스 실행환경\n", 616 | "\n", 617 | " # 모든 에이전트가 공유하는 두뇌 역할 클래스 Brain 객체 생성\n", 618 | " n_out = envs.action_space.n # 행동의 가짓수는 4\n", 619 | " actor_critic = Net(n_out).to(device) # GPU 사용\n", 620 | " global_brain = Brain(actor_critic)\n", 621 | "\n", 622 | " # 정보 저장용 변수 생성\n", 623 | " obs_shape = envs.observation_space.shape # (1, 84, 84)\n", 624 | " obs_shape = (obs_shape[0] * NUM_STACK_FRAME,\n", 625 | " *obs_shape[1:]) # (4, 84, 84)\n", 626 | " # torch.Size([16, 4, 84, 84])\n", 627 | " current_obs = torch.zeros(NUM_PROCESSES, *obs_shape).to(device)\n", 628 | " rollouts = RolloutStorage(\n", 629 | " NUM_ADVANCED_STEP, NUM_PROCESSES, obs_shape) # rollouts 객체\n", 630 | " episode_rewards = torch.zeros([NUM_PROCESSES, 1]) # 현재 에피소드에서 받을 보상 저장\n", 631 | " final_rewards = torch.zeros([NUM_PROCESSES, 1]) # 마지막 에피소드의 총 보상 저장\n", 632 | "\n", 633 | " # 초기 상태로 시작\n", 634 | " obs = envs.reset()\n", 635 | " obs = torch.from_numpy(obs).float() # torch.Size([16, 1, 84, 84])\n", 636 | " current_obs[:, -1:] = obs # 4번째 프레임에 가장 최근 관측결과를 저장\n", 637 | "\n", 638 | " # advanced 학습에 사용할 객체 rollouts에 첫번째 상태로 현재 상태를 저장\n", 639 | " rollouts.observations[0].copy_(current_obs)\n", 640 | "\n", 641 | " # 애니메이션 생성용 환경(시연용 추가)\n", 642 | " env_play = make_env_play(ENV_NAME, seed_num, 0)\n", 643 | " obs_play = env_play.reset()\n", 644 | "\n", 645 | " # 애니메이션 생성을 위해 이미지를 저장할 변수(시연용 추가)\n", 646 | " frames = []\n", 647 | " main_end = False\n", 648 | "\n", 649 | " # 주 반복문\n", 650 | " for j in tqdm(range(NUM_UPDATES)):\n", 651 | "\n", 652 | " # 보상이 기준을 넘어서면 종료 (시연용 추가)\n", 653 | " if main_end:\n", 654 | " break\n", 655 | "\n", 656 | " # advanced 학습 범위에 들어가는 단계마다 반복\n", 657 | " for step in range(NUM_ADVANCED_STEP):\n", 658 | "\n", 659 | " # 행동을 결정\n", 660 | " with torch.no_grad():\n", 661 | " action = actor_critic.act(rollouts.observations[step])\n", 662 | "\n", 663 | " cpu_actions = action.squeeze(1).cpu().numpy() # tensor를 NumPy 변수로\n", 664 | "\n", 665 | " # 1단계를 병렬로 실행, 반환값 obs의 크기는 (16, 1, 84, 84)\n", 666 | " obs, reward, done, info = envs.step(cpu_actions)\n", 667 | "\n", 668 | " # 보상을 텐서로 변환한 다음 에피소드 총 보상에 더함\n", 669 | " # 크기가 (16,)인 것을 (16, 1)로 변환\n", 670 | " reward = np.expand_dims(np.stack(reward), 1)\n", 671 | " reward = torch.from_numpy(reward).float()\n", 672 | " episode_rewards += reward\n", 673 | "\n", 674 | " # 각 프로세스마다 done이 True이면 0, False이면 1\n", 675 | " masks = torch.FloatTensor(\n", 676 | " [[0.0] if done_ else [1.0] for done_ in done])\n", 677 | "\n", 678 | " # 마지막 에피소드의 총 보상을 업데이트\n", 679 | " final_rewards *= masks # done이 True이면 0을 곱하고, False이면 1을 곱하여 리셋\n", 680 | " # done이 False이면 0을 더하고, True이면 epicodic_rewards를 더함\n", 681 | " final_rewards += (1 - masks) * episode_rewards\n", 682 | "\n", 683 | " # 이미지를 구함(시연용 추가)\n", 684 | " obs_play, reward_play, _, _ = env_play.step(cpu_actions[0])\n", 685 | " frames.append(obs_play) # 변환한 이미지를 저장\n", 686 | " if done[0]: # 첫번째 프로세스가 종료된 경우\n", 687 | " print(episode_rewards[0][0].numpy()) # 보상\n", 688 | "\n", 689 | " # 보상이 300을 초과하면 종료\n", 690 | " if (episode_rewards[0][0].numpy()) > 300:\n", 691 | " main_end = True\n", 692 | " break\n", 693 | " else:\n", 694 | " obs_view = env_play.reset()\n", 695 | " frames = [] # 저장한 이미지를 리셋\n", 696 | "\n", 697 | " # 에피소드의 총 보상을 업데이트\n", 698 | " episode_rewards *= masks # 각 프로세스마다 done이 True이면 0, False이면 1을 곱함\n", 699 | "\n", 700 | " # masks 변수를 GPU로 전달\n", 701 | " masks = masks.to(device)\n", 702 | "\n", 703 | " # done이 True이면 모두 0으로\n", 704 | " # mask의 크기를 torch.Size([16, 1]) --> torch.Size([16, 1, 1 ,1])로 변환하고 곱함\n", 705 | " current_obs *= masks.unsqueeze(2).unsqueeze(2)\n", 706 | "\n", 707 | " # 프레임을 모음\n", 708 | " # torch.Size([16, 1, 84, 84])\n", 709 | " obs = torch.from_numpy(obs).float()\n", 710 | " current_obs[:, :-1] = current_obs[:, 1:] # 0~2번째 프레임을 1~3번째 프레임으로 덮어씀\n", 711 | " current_obs[:, -1:] = obs # 4번째 프레임에 가장 최근 obs를 저장\n", 712 | "\n", 713 | " # 메모리 객체에 현 단계의 transition을 저장\n", 714 | " rollouts.insert(current_obs, action.data, reward, masks)\n", 715 | "\n", 716 | " # advanced 학습의 for문 끝\n", 717 | "\n", 718 | " # advanced 학습 대상 단계 중 마지막 단계의 상태에서 예상되는 상태가치를 계산\n", 719 | " with torch.no_grad():\n", 720 | " next_value = actor_critic.get_value(\n", 721 | " rollouts.observations[-1]).detach()\n", 722 | "\n", 723 | " # 모든 단계의 할인 총보상을 계산하고, rollouts의 변수 returns를 업데이트\n", 724 | " rollouts.compute_returns(next_value)\n", 725 | "\n", 726 | " # 신경망 수정 및 rollout 업데이트\n", 727 | " # global_brain.update(rollouts)\n", 728 | " rollouts.after_update()\n", 729 | "\n", 730 | " # 주 반복문 끝\n", 731 | " display_frames_as_gif(frames) # 애니메이션 저장 및 재생\n" 732 | ] 733 | }, 734 | { 735 | "cell_type": "code", 736 | "execution_count": null, 737 | "metadata": {}, 738 | "outputs": [], 739 | "source": [ 740 | "# 실행\n", 741 | "breakout_env = Environment()\n", 742 | "frames = breakout_env.run()" 743 | ] 744 | } 745 | ], 746 | "metadata": { 747 | "kernelspec": { 748 | "display_name": "Python 3", 749 | "language": "python", 750 | "name": "python3" 751 | }, 752 | "language_info": { 753 | "codemirror_mode": { 754 | "name": "ipython", 755 | "version": 3 756 | }, 757 | "file_extension": ".py", 758 | "mimetype": "text/x-python", 759 | "name": "python", 760 | "nbconvert_exporter": "python", 761 | "pygments_lexer": "ipython3", 762 | "version": "3.6.6" 763 | } 764 | }, 765 | "nbformat": 4, 766 | "nbformat_minor": 2 767 | } 768 | -------------------------------------------------------------------------------- /program/weight_end.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/program/weight_end.pth --------------------------------------------------------------------------------