├── README.md
├── movies
├── 2_2_maze_random.gif
├── 2_3_maze_reinforce.gif
├── 2_5_maze_state_value.gif
├── 3_2_cartpole_random.gif
├── 3_4_cartpole_q_learning.gif
├── 5_4_cartpole_dqn.gif
├── 7_breakout.gif
├── 7_breakout_random.gif
└── book.jpg
└── program
├── 2_1_TryJupyter.ipynb
├── 2_2_maze_random.ipynb
├── 2_3_Policygradient.ipynb
├── 2_5_Sarsa.ipynb
├── 2_6_Qlearning.ipynb
├── 3_2_try_CartPole.ipynb
├── 3_3_digitized_CartPole.ipynb
├── 3_4_Qlearning_CartPole.ipynb
├── 4_3_PyTorch_MNIST.ipynb
├── 5_3and5_4_DQN.ipynb
├── 6_2_DDQN.ipynb
├── 6_3_DuelingNetwork.ipynb
├── 6_4_PrioritizedExperienceReplay.ipynb
├── 6_5_A2C_Advanced_ActorCritic.ipynb
├── 7_1_Breakout_try.ipynb
├── 7_breakout_learning.ipynb
├── 7_breakout_play.ipynb
└── weight_end.pth
/README.md:
--------------------------------------------------------------------------------
1 | # Deep-Reinforcement-Learning-Book
2 | [書籍「つくりながら学ぶ!深層強化学習」、著者:株式会社電通国際情報サービス 小川雄太郎、出版社: マイナビ出版 (2018/6/28) ](https://www.amazon.co.jp/%E3%81%A4%E3%81%8F%E3%82%8A%E3%81%AA%E3%81%8C%E3%82%89%E5%AD%A6%E3%81%B6-%E6%B7%B1%E5%B1%A4%E5%BC%B7%E5%8C%96%E5%AD%A6%E7%BF%92-PyTorch%E3%81%AB%E3%82%88%E3%82%8B%E5%AE%9F%E8%B7%B5%E3%83%97%E3%83%AD%E3%82%B0%E3%83%A9%E3%83%9F%E3%83%B3%E3%82%B0-%E6%A0%AA%E5%BC%8F%E4%BC%9A%E7%A4%BE%E9%9B%BB%E9%80%9A%E5%9B%BD%E9%9A%9B%E6%83%85%E5%A0%B1%E3%82%B5%E3%83%BC%E3%83%93%E3%82%B9-%E5%B0%8F%E5%B7%9D%E9%9B%84%E5%A4%AA%E9%83%8E/dp/4839965625)のサポートリポジトリです。
3 | [이 페이지는 "서지사항"](아마존 페이지 링크)의 고객 지원 페이지입니다.
4 |
5 |
6 | 맨 아랫 부분에 정오표를 실었습니다.
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | 그림. 벽돌 깨기 공략(A2C 알고리즘 사용, GPU 1장으로 3시간 동안 학습한 결과)
15 |
16 |
17 | 그림. 미로 안에서 무작위 이동
18 |
19 |
20 | 그림. 미로를 대상으로 강화학습을 수행한 결과
21 |
22 |
23 | 그림. 미로 안의 각 위치에 대해 학습된 상태 가치
24 |
25 |
26 | 그림. 역진자 제어하기
27 |
28 |
29 | ### 정오표
30 |
--------------------------------------------------------------------------------
/movies/2_2_maze_random.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/2_2_maze_random.gif
--------------------------------------------------------------------------------
/movies/2_3_maze_reinforce.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/2_3_maze_reinforce.gif
--------------------------------------------------------------------------------
/movies/2_5_maze_state_value.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/2_5_maze_state_value.gif
--------------------------------------------------------------------------------
/movies/3_2_cartpole_random.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/3_2_cartpole_random.gif
--------------------------------------------------------------------------------
/movies/3_4_cartpole_q_learning.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/3_4_cartpole_q_learning.gif
--------------------------------------------------------------------------------
/movies/5_4_cartpole_dqn.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/5_4_cartpole_dqn.gif
--------------------------------------------------------------------------------
/movies/7_breakout.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/7_breakout.gif
--------------------------------------------------------------------------------
/movies/7_breakout_random.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/7_breakout_random.gif
--------------------------------------------------------------------------------
/movies/book.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/movies/book.jpg
--------------------------------------------------------------------------------
/program/2_1_TryJupyter.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## 2.1 주피터 노트북 사용법"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 2,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "x = 2 + 3\n",
17 | "y = x * 4\n"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 3,
23 | "metadata": {},
24 | "outputs": [
25 | {
26 | "name": "stdout",
27 | "output_type": "stream",
28 | "text": [
29 | "x의 값은 5입니다\n",
30 | "y의 값은 20입니다\n"
31 | ]
32 | }
33 | ],
34 | "source": [
35 | "print(\"x의 값은 {}입니다\".format(x))\n",
36 | "print(\"y의 값은 {}입니다\".format(y))\n"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 4,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "# 본문의 설명을 따라 파일을 저장하고 로드해 봅시다."
46 | ]
47 | }
48 | ],
49 | "metadata": {
50 | "kernelspec": {
51 | "display_name": "Python 3",
52 | "language": "python",
53 | "name": "python3"
54 | },
55 | "language_info": {
56 | "codemirror_mode": {
57 | "name": "ipython",
58 | "version": 3
59 | },
60 | "file_extension": ".py",
61 | "mimetype": "text/x-python",
62 | "name": "python",
63 | "nbconvert_exporter": "python",
64 | "pygments_lexer": "ipython3",
65 | "version": "3.6.6"
66 | }
67 | },
68 | "nbformat": 4,
69 | "nbformat_minor": 2
70 | }
71 |
--------------------------------------------------------------------------------
/program/2_5_Sarsa.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## 2.5 Sarsa 알고리즘 구현하기"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "# 구현에 사용할 패키지 임포트하기\n",
17 | "import numpy as np\n",
18 | "import matplotlib.pyplot as plt\n",
19 | "%matplotlib inline\n"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "metadata": {},
26 | "outputs": [
27 | {
28 | "data": {
29 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAASwAAAElCAYAAABect+9AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAGu1JREFUeJzt3XtUlHX+B/D3MzcEBoJfksJQoLlK8MtM0APaLzPYpFq7mXZgK4Ei/WmXk3TsuOtu222PmejR1V9HziqVrW5eUqFTrWwqrrcUvOCKlK55Q1uQIOQyMON8f3+MsILKDMQ8z3yH9+ucOR7m+c48n/kG777fZ57n+yhCCBARyUCndQFERO5iYBGRNBhYRCQNBhYRSYOBRUTSYGARkTQYWEQkDQYWEUmDgUVE0mBgEZE0DN1p3L9/fxEdHe2hUoioryotLb0ohAhz1a5bgRUdHY2SkpKeV0VEdB2Kopx2px2nhEQkDQYWEUmDgUVE0mBgEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGARkTQYWEQkjW6t1uCthBCovFSJ0vOl2Fe5D8Wni1FeXY5mezPsDjsuOy5Dr9PDoDPA3+CP2LBYjIsah9GW0YiPiIclyAJFUbT+GETkgrSB5RAOfH3yayzcuxC7zuyC3WGHUW9EQ2sDHMJxTXu7ww67ww6r3YpdZ3dhz7k9MJvMaL3cCqPOiLG3jcWsxFlIHpwMncKBJ5E3ki6waptrsfLgSuTuycWl1ktoaG1o39Zsb3b7fRzCgfqWegCAFVZ8deIr7DyzE0GmIOQk5SDr7iyE+of2ev1E1HOKEMLtxgkJCUKrBfzO1Z/D7KLZ2FixETpFhyZbk8f2FWAMgEM48ETME3jvl+8hMjjSY/siIkBRlFIhRIKrdl4/9xFCYMXBFYhZGoN1R9fBard6NKwAoMnWBKvdirVH1yJmaQxWHFyB7gQ7EXmGVwdWZX0lxn80Hq98+QoabY2wC7uq+7cLOxptjXjly1cw/qPxqKyvVHX/RNSR1wZW/qF8xCyNwa6zu9Boa9S0lkZbI3ad3YWYZTHIP5SvaS1EfZnXBZYQAq9+9Spe/OJFNNgaYHeoO6q6EbvDjobWBrz4xYuY9bdZnCISacCrAuuy4zIyNmUg70Cex49T9VSTrQnLS5cjc3MmLjsua10OUZ/iNac1CCGQtTkL64+t99qwatNka8K68nUAgPxH83nSKZFKvGaENetvs7Dh2AavD6s2baGVsyVH61KI+gyvCKz8Q/nIO5Cn+cH17mqbHvJAPJE6NA+syvpKvPzFy9KMrDprsjXh5S9f5ikPRCrQNLCEEEj/LB3Wy1Yty/jZWuwt+PVnv+Y3h0QepmlgrTy0EqXnS73m1IWesjlsKDlfwqkhkYdpFljn6s+1n8HuCxptjXjlq1c4NSTyIM0Ca3bRbLTYW7TavUdY7VbMLpqtdRlEPkuTwKptrsXGio2qXxvoaXaHHZ9VfIba5lqtSyHySZoE1sqDK312kTydouOxLCIPUT01HMKB3D250p7G4EqTrQm5u3Ovu+opEf08qgfW1ye/xqXWS73/xo0APgewCMDbAN4H8BGAf13ZLgBsA7AAwDsA8gFU9X4ZAFDfWo+t32/1zJt7kerqasyYMQPR0dHw8/PDgAEDkJycjKKiIgDAZ599hgkTJiAsLAyKomD79u3aFuwDuupzm82G119/HcOHD0dgYCDCw8ORnp6OM2fOaF12r1H9WsKFexd2WNa413wKwAbgUQD/BWeAnQLQNpDbBWAPgMcA3AygGMDHAF4C4Ne7pTS0NiB3Ty5SBqf07ht7mUmTJqGpqQkrVqzAkCFDUFVVheLiYtTU1AAAGhsbMWbMGDz99NN49tlnNa7WN3TV501NTThw4AB++9vfYsSIEfjpp5+Qk5OD1NRUlJWVwWDwmkuHe0zVJZKFELhp3k29P8JqBvAegGcA3H69HQPIBTAawL1XnrPBOQp7AIDLhVm7L9gvGHWv1/nshdF1dXUIDQ1FUVERUlK6DuaLFy8iLCwM27Ztw3333adOgT6oO33epry8HHFxcSgrK8Odd97p4Qp7ziuXSK68VAmbw9b7b2y68vgWziDqrBZAAzqGmRFAFICzvV8OALRebsX5S+c98+ZewGw2w2w2o6CgAFar3FcqyKInfV5f77zRSmiob9xQRdXAKj1fCpPe1PtvrIdzqlcGYB6APwP4G4BzV7a3zUADO70u8KptvcykN6H0Qqln3twLGAwGfPjhh/jkk08QEhKCpKQkvPbaa/jmm2+0Ls1ndbfPW1tbkZOTg4kTJyIy0jdupKJqYO2r3OeZ41cAEAsgB0A6gCFwjpz+DGDHVW1UnJ01tjZiX+U+9XaogUmTJuH8+fMoLCzEgw8+iN27dyMxMRF//OMftS7NZ7nb53a7HU8//TTq6uqQn+87p9moegzrnpX3YNfZXT1+fbdtBnAYwAwASwFkA7Bctf0vAAIAPO6Z3d9z2z34R+Y/PPPmXur555/Hxx9/jIaGBphMztE0j2F5Vuc+t9vtSEtLw5EjR7B9+3YMHDhQ6xJd8spjWOXV5WruDggD4ABgvvL411XbbABOA7jVc7tX/fN6gdjYWNjtdh7XUtHVfW6z2fDUU0+hrKwM27ZtkyKsukPV7zm7c2fmbmkCsBbA3QAGwHmawnk4T2UYDKAfgEQ4p4f94TytYQecB+o9+MVJs81Dn9cL1NTUYPLkycjKysLw4cMRFBSEkpISzJ8/H8nJyQgODsaPP/6IM2fOoK6uDgBw4sQJhISEYODAgT73h6QGV30eEBCAJ598Evv370dhYSEURcEPP/wAALjpppvg7++v8Sf4+VQNLI8tI2MCEAngGwA/ArADCIYzjNpOYxgL56jqCzhPg4iE8zSIXj4H62oe+UbUS5jNZiQmJmLx4sU4ceIEWlpaYLFYkJ6ejrlz5wIACgoKkJmZ2f6a7OxsAMAbb7yBP/zhD1qULTVXfX7u3Dls3rwZABAfH9/htfn5+cjIyNCg6t6l6jEs3Zs6CPSdRe4UKHC8wUt0iFzxymNYep1ezd1prq99XiJPUzWwDDr5Lw3oDqPOqHUJRD5F1cDyN8h/0K87/I196/MSeZqqgRUbFqvm7jTX1z4vkaepGljjosb57MJ9nekVPcZFjdO6DCKfomp6jLaMhtlkVnOXmgk0BWK0ZbTWZRD5FFUDKz4iHq2XW9XcpWZaL7ciPjzedUMicpuqgWUJsvSZb85MehMigiK0LoPIp6gaWIqiYOxtY9XcpWbG3DrGZxfvI9KK6idGzUqchZ1ndvZsmZkdAI7AuUyMAsAfzstsWuG8njDkSruHAdwG5zLJuQAeQsdVRRfhP5fk+MO5WoMJzjXgAecaWTo4V3IAnKs8dKOnzCYzcpJy3H8BEblF9cBKHpyMIFNQ9wPrLIDvAEyDs+pGAJfhvGbwewC7Afy602uOwnnN4BFcuwzyVDgX8NsGZxA+AuB/r2zbBmeA9XAwGOwXjPsH3d+zFxPRDal+joFO0SEnKQcBxgDXja92Cc4RT1vEBsIZVl35J5xrttdfeVxPZBfbeiDAGICcpJw+c/oGkZo0+avKujur+/ftux3ATwCWwHk7r1Mu2v8E59QuEkAcnOF1PScAxHSvlK44hAOZIzJdNySibtMksEL9Q/F4zOMwKN2YkfrBOR2cCOfoah2Ag120/yecQQUA/41rA+sjAPMBnESvrYll0BnwRMwTCPX3jQX/ibyNZlcjz//lfBR8WwC7rRtrZOkADLryuAXO5Y/vvkHbI3Ae5yq78vMlADVwLt4HOI9hmQBsgvOYVWr36r+efoZ+mP/L+T//jYjoujQ70BIZHInFDy5GoLHzrWxu4CKcgdPmBwA3ddHWBudNKV698vgfXDvKMsIZVIfxnxuu9lCgMRCLUxfDEmxx3ZiIekTTI8NZI7KQEJHg3rIzrQA2wnkzif8DUA3gvhu0PYJrj0vdceX5zoLgnBLud6vk6zLqjBhlGcVjV0QepuqKo9dTWV+JmKUxaLB56PZfKjCbzKiYWcHRFVEPeeWKo9djCbZgyUNLun+ag5cIMAZgyYNLGFZEKtA8sAAgc0QmXhj5gnShFWgMxLT4aZwKEqnEKwILABZOWIgn73hSmtAKMAbgydgnkftArtalEPUZXhNYiqJg5aMrMTl2steHVoAxAJNjJ2PFIyt4gTORirwmsADnXWbyH83HtPhpXhtaAcYATI+fjvxH83lXHCKVeVVgAc6R1sIJC7H0oaUwm8xec6cdo84Is8mMpQ8tRe6EXI6siDTgdYHVJnNEJipmVmDsrWPdP7nUQwKNgRhz6xhUzKzgAXYiDXltYAHOUx62Td2GJQ8ucY62unPtYS8w6Awwm8xY8uASbJu6jacuEGnMqwMLcE4Rs+7OwrGZxzAlbgr6GfohwODZ41sBhgD0M/TDlNgpqJhZgay7szgFJPIC3nGAyA2RwZH4y6S/oLa5FvmH8rFg9wJcar3Us5VLb8BsMiPYFIycMTnIHJHJVReIvIzml+b0lEM4sPX7rcjdk4vdZ3ej9XIrTHoTGlob3FprS6foYDaZ21835tYxyEnKwf2D7ufie0Qqc/fSHGlGWJ3pFB1SBqcgZXAKhBA4f+k8Si+UYl/lPhSfLkZ5dTmabc2wOWy47LgMvU4Po84If6M/YsNiMS5qHEZbRiM+PB4RQRGc8hFJQNrAupqiKLAEW2AJtuCRYY9oXQ4ReQjnPkQkDQYWEUmDgUVE0mBgEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGARkTR84uJnn8UVJLTTjWWXSD0cYRGRNDjC8mb8v7z6OKr1ahxhEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGARkTQYWEQkDQYWEUmDgUVE0mBgEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGARkTQYWEQkDQYWEUmDgUVE0mBgEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGARkTQYWEQkDZ8JrOrqasyYMQPR0dHw8/PDgAEDkJycjKKiIgDA7373O8TExCAwMBChoaFITk7G7t27Na5abq76/GovvPACFEXBggULNKjUd7jq84yMDCiK0uGRmJiocdW9x6B1Ab1l0qRJaGpqwooVKzBkyBBUVVWhuLgYNTU1AIBhw4Zh2bJlGDRoEJqbm7Fo0SKkpqbi+PHjGDBggMbVy8lVn7dZv3499u/fj4iICI0q9R3u9HlKSgpWrVrV/rPJZNKiVM8QQrj9iI+PF96otrZWABBFRUVuv+ann34SAMRXX33lwcp8l7t9furUKRERESHKy8tFVFSUeP/991WqsIcA58MLudPnU6dOFQ8//LCKVfUOACXCjQzyiSmh2WyG2WxGQUEBrFary/atra3Iy8tDcHAwRowYoUKFvsedPrfb7UhLS8PcuXNxxx13qFyh73H393znzp245ZZbMHToUGRnZ6OqqkrFKj3MnVQTXj7CEkKI9evXi9DQUOHn5ycSExNFTk6O2Lt3b4c2hYWFIjAwUCiKIiIiIsQ333yjUbW+wVWf/+Y3vxG/+tWv2n/mCOvnc9Xna9asEZs3bxZlZWWioKBADB8+XMTFxQmr1aph1a7BzRGWzwSWEEI0NzeLLVu2iDfffFMkJSUJAOLdd99t397Q0CCOHz8u9uzZI7KyskRUVJQ4f/68hhXL70Z9vn37dhERESGqqqra2zKweoer3/OrVVZWCoPBIDZs2KByld3TJwOrs+eee04YjUbR0tJy3e1DhgwRb731lspV+ba2Pp8zZ45QFEXo9fr2BwCh0+mExWLRuswbkyCwOnP1ex4dHS3mzZunclXd425g+cy3hNcTGxsLu90Oq9V63W9KHA4HWlpaNKjMd7X1+fTp05Gent5h24QJE5CWlobs7GyNqvNNXf2eX7x4EZWVlQgPD9eout7lE4FVU1ODyZMnIysrC8OHD0dQUBBKSkowf/58JCcnAwDmzp2LiRMnIjw8HNXV1Vi2bBnOnTuHKVOmaFy9nFz1+W233XbNa4xGIwYOHIhhw4ZpULH8XPW5TqfDa6+9hkmTJiE8PBynTp3CnDlzcMstt+Dxxx/Xuvxe4ROBZTabkZiYiMWLF+PEiRNoaWmBxWJBeno65s6dC4PBgKNHj2LlypWoqanBzTffjFGjRmHHjh0YPny41uVLyVWfU+9z1ed6vR5HjhzBxx9/jLq6OoSHh2P8+PFYu3YtgoKCtC6/VyjO6aN7EhISRElJiQfLIdKYojj/7cbfBf18iqKUCiESXLXzifOwiKhvYGARkTQYWEQkDQYWEUmDgUVE0mBgEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGARkTQYWEQkDQYWEUmDgUVE0mBgEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGARkTQYWEQkDQYWEUmDgUVE0mBgEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGARkTQYWEQkDQYWEUmDgUVE0mBgEZE0GFhEJA2D1gVQFxTF+a8Q2tbRF7X1PXkVjrCISBocYRFdjaNZbbg5ouUIi4ikwcAiImkwsIhIGgwsIpIGA4uIpMHAIiJpMLCISBoMLCKSBgOLiKTBwCIiaTCwiEgaDCwikgYDi4ikwcAiImkwsIhIGgwsIpIGA4uIpMHAIiJpMLCISBoMLCKSBgOLiKTBwCIiaTCwiEgaDCwikgYDi4ikwcAiImkwsIhIGgwsIpIGA4uIpMHAIiJpMLCISBoMLCKSBgOLiKThM4FVXV2NGTNmIDo6Gn5+fhgwYACSk5NRVFTU3ua7777DE088gZCQEAQEBGDkyJE4duyYhlXLzVWfK4py3cfMmTM1rlxervq8oaEBL730EiIjI+Hv749hw4Zh0aJFGlfdewxaF9BbJk2ahKamJqxYsQJDhgxBVVUViouLUVNTAwD4/vvvMXbsWDz77LPYunUrQkJCUFFRAbPZrHHl8nLV5xcuXOjQvqSkBBMnTsSUKVO0KNcnuOrzWbNm4e9//ztWrVqFQYMGYceOHcjOzkb//v3xzDPPaFx9LxBCuP2Ij48X3qi2tlYAEEVFRTdsk5aWJtLT01WsqhcAzocXcqfPO3v++efF0KFDPViVb3Onz+Pi4sTvf//7Ds/de++9YubMmZ4u72cBUCLcyCCfmBKazWaYzWYUFBTAarVes93hcKCwsBCxsbFITU1FWFgYRo0ahU8//VSDan2Dqz7vrKGhAX/961+RnZ2tQnW+yZ0+v+eee1BYWIizZ88CAHbv3o1Dhw4hNTVVzVI9x51UE14+whJCiPXr14vQ0FDh5+cnEhMTRU5Ojti7d68QQogLFy4IACIgIEDk5uaKgwcPitzcXKHX60VhYaHGlXfBi0dYQnTd550tX75cGI1GUVVVpXKVvsVVn7e0tIjMzEwBQBgMBmEwGMQHH3ygYcXugZsjLJ8JLCGEaG5uFlu2bBFvvvmmSEpKEgDEu+++KyorKwUAkZaW1qF9WlqaSE1N1ahaN3h5YAlx4z7vLCEhQUyePFmDCn1PV32+YMECMXToUFFQUCAOHz4s/vSnP4nAwEDx5Zdfalx11/pkYHX23HPPCaPRKFpaWoTBYBBvv/12h+1vvfWWiI2N1ag6N0gQWJ1d3edtDh48KACILVu2aFiZ72rr87q6OmE0GsWmTZuu2Z6cnKxRde5xN7B84hjWjcTGxsJut8NqtWLUqFH49ttvO2z/7rvvEBUVpVF1vunqPm+Tl5eH6OhopKSkaFiZ72rrc0VRYLPZoNfrO2zX6/VwOBwaVdfL3Ek14eUjrIsXL4rx48eLVatWicOHD4uTJ0+KtWvXigEDBoiUlBQhhBAbN24URqNRLF++XBw/flzk5eUJg8EgPv/8c42r74IXj7Dc6XMhhGhsbBTBwcHinXfe0bBa3+BOn48bN07ExcWJbdu2iZMnT4r8/HzRr18/sWTJEo2r7xr60pTQarWKOXPmiISEBBESEiL8/f3FkCFDxKuvvipqamra2+Xn54tf/OIXol+/fuLOO+8Uq1ev1rBqN3hxYLnb5ytXrhR6vV5UVlZqWK1vcKfPL1y4IDIyMkRERITo16+fGDZsmHj//feFw+HQuPquuRtYirOtexISEkRJSYnHRnvUiaI4/+3GfyMiGSmKUiqESHDVzqePYRGRb2FgEZE0GFhEJA0GFhFJg4FFRNJgYBGRNBhYRCQNBhYRSYOBRUTSYGAReal///vfSE9Px+DBgxEfH4+kpCRs3LgRALBz506MHj0aMTExiImJQV5e3jWvv+uuu5CWltbhuYyMDKxfv16V+j3BZ9Z0J/IlQgg89thjmDp1KlavXg0AOH36NAoKCvDDDz8gPT0dmzZtwsiRI3Hx4kVMmDABFosFDz/8MADg2LFjcDgc2LFjBxobGxEYGKjlx+k1HGEReaGtW7fCZDJh+vTp7c9FRUXhpZdewrJly5CRkYGRI0cCAPr374/58+dj3rx57W1Xr16NZ555Bg888AAKCgpUr99TGFhEXujo0aPtgXS9bfHx8R2eS0hIwNGjR9t//vTTT/HUU08hLS0Na9as8WitamJgEUlg5syZuOuuuzBq1CjnMittK3lcpe25/fv3IywsDFFRUUhOTsaBAwdQW1urdskewcAi8kJxcXE4cOBA+8/Lli3D119/jerqasTFxaHzMk+lpaWIjY0FAKxZswYVFRWIjo7G7bffjvr6emzYsEHV+j2FgUXkhe6//35YrVZ88MEH7c81NTUBcI62PvzwQxw6dAgAUFNTg9dffx2zZ8+Gw+HAunXrUFZWhlOnTuHUqVPYvHmzz0wLGVhEXkhRFGzatAnFxcUYNGgQRo8ejalTp+K9995DeHg4PvnkE2RnZyMmJgZjxoxBVlYWJk6ciB07dsBiscBisbS/17333ovy8vL2O3FPmzYNkZGRiIyMRFJSklYfsUe44qg344qj1EdwxVEi8jkMLCKSBgOLiKTBwCIiaTCwiEgaDCwikgYDi4ikwcAiImkwsIhIGgwsIpIGA4uIpMHAIiJpMLCISBoMLCKSBgOLiKTBwCIiaTCwiEgaDCwikgYDi4ikwcAiImkwsIhIGgwsIpIGA4uIpMHAIiJpMLCISBoMLCKSRrduVa8oSjWA054rh4j6qCghRJirRt0KLCIiLXFKSETSYGARkTQYWEQkDQYWEUmDgUVE0mBgEZE0GFhEJA0GFhFJg4FFRNL4f/+izvyQK8sjAAAAAElFTkSuQmCC\n",
30 | "text/plain": [
31 | ""
32 | ]
33 | },
34 | "metadata": {},
35 | "output_type": "display_data"
36 | }
37 | ],
38 | "source": [
39 | "# 초기 상태의 미로 모습\n",
40 | "\n",
41 | "# 전체 그림의 크기 및 그림을 나타내는 변수 선언\n",
42 | "fig = plt.figure(figsize=(5, 5))\n",
43 | "ax = plt.gca()\n",
44 | "\n",
45 | "# 붉은 벽 그리기\n",
46 | "plt.plot([1, 1], [0, 1], color='red', linewidth=2)\n",
47 | "plt.plot([1, 2], [2, 2], color='red', linewidth=2)\n",
48 | "plt.plot([2, 2], [2, 1], color='red', linewidth=2)\n",
49 | "plt.plot([2, 3], [1, 1], color='red', linewidth=2)\n",
50 | "\n",
51 | "# 상태를 의미하는 문자열(S0~S8) 표시\n",
52 | "plt.text(0.5, 2.5, 'S0', size=14, ha='center')\n",
53 | "plt.text(1.5, 2.5, 'S1', size=14, ha='center')\n",
54 | "plt.text(2.5, 2.5, 'S2', size=14, ha='center')\n",
55 | "plt.text(0.5, 1.5, 'S3', size=14, ha='center')\n",
56 | "plt.text(1.5, 1.5, 'S4', size=14, ha='center')\n",
57 | "plt.text(2.5, 1.5, 'S5', size=14, ha='center')\n",
58 | "plt.text(0.5, 0.5, 'S6', size=14, ha='center')\n",
59 | "plt.text(1.5, 0.5, 'S7', size=14, ha='center')\n",
60 | "plt.text(2.5, 0.5, 'S8', size=14, ha='center')\n",
61 | "plt.text(0.5, 2.3, 'START', ha='center')\n",
62 | "plt.text(2.5, 0.3, 'GOAL', ha='center')\n",
63 | "\n",
64 | "# 그림을 그릴 범위 및 눈금 제거 설정\n",
65 | "ax.set_xlim(0, 3)\n",
66 | "ax.set_ylim(0, 3)\n",
67 | "plt.tick_params(axis='both', which='both', bottom=False, top=False,\n",
68 | " labelbottom=False, right=False, left=False, labelleft=False)\n",
69 | "\n",
70 | "# S0에 녹색 원으로 현재 위치를 표시\n",
71 | "line, = ax.plot([0.5], [2.5], marker=\"o\", color='g', markersize=60)\n"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": 3,
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "# 정책을 결정하는 파라미터의 초깃값 theta_0를 설정\n",
81 | "\n",
82 | "# 줄은 상태 0~7, 열은 행동방향(상,우,하,좌 순)를 나타낸다\n",
83 | "theta_0 = np.array([[np.nan, 1, 1, np.nan], # s0\n",
84 | " [np.nan, 1, np.nan, 1], # s1\n",
85 | " [np.nan, np.nan, 1, 1], # s2\n",
86 | " [1, 1, 1, np.nan], # s3\n",
87 | " [np.nan, np.nan, 1, 1], # s4\n",
88 | " [1, np.nan, np.nan, np.nan], # s5\n",
89 | " [1, np.nan, np.nan, np.nan], # s6\n",
90 | " [1, 1, np.nan, np.nan], # s7、※s8은 목표지점이므로 정책이 없다\n",
91 | " ])"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 5,
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "# 정책 파라미터 theta_0을 무작위 행동 정책 pi로 변환하는 함수\n",
101 | "\n",
102 | "def simple_convert_into_pi_from_theta(theta):\n",
103 | " '''단순 비율 계산'''\n",
104 | "\n",
105 | " [m, n] = theta.shape # theta의 행렬 크기를 구함\n",
106 | " pi = np.zeros((m, n))\n",
107 | " for i in range(0, m):\n",
108 | " pi[i, :] = theta[i, :] / np.nansum(theta[i, :]) # 비율 계산\n",
109 | "\n",
110 | " pi = np.nan_to_num(pi) # nan을 0으로 변환\n",
111 | "\n",
112 | " return pi\n",
113 | "\n",
114 | "# 무작위 행동정책 pi_0을 계산\n",
115 | "pi_0 = simple_convert_into_pi_from_theta(theta_0)"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 6,
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "# 행동가치 함수 Q의 초기 상태\n",
125 | "\n",
126 | "[a, b] = theta_0.shape # 열과 행의 갯수를 변수 a, b에 저장\n",
127 | "Q = np.random.rand(a, b) * theta_0\n",
128 | "# * theta0 로 요소 단위 곱셈을 수행, Q에서 벽 방향으로 이동하는 행동에는 nan을 부여\n"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 7,
134 | "metadata": {},
135 | "outputs": [],
136 | "source": [
137 | "# ε-greedy 알고리즘 구현\n",
138 | "\n",
139 | "\n",
140 | "def get_action(s, Q, epsilon, pi_0):\n",
141 | " direction = [\"up\", \"right\", \"down\", \"left\"]\n",
142 | "\n",
143 | " # 행동을 결정\n",
144 | " if np.random.rand() < epsilon:\n",
145 | " # 확률 ε로 무작위 행동을 선택함\n",
146 | " next_direction = np.random.choice(direction, p=pi_0[s, :])\n",
147 | " else:\n",
148 | " # Q값이 최대가 되는 행동을 선택함\n",
149 | " next_direction = direction[np.nanargmax(Q[s, :])]\n",
150 | "\n",
151 | " # 행동을 인덱스로 변환\n",
152 | " if next_direction == \"up\":\n",
153 | " action = 0\n",
154 | " elif next_direction == \"right\":\n",
155 | " action = 1\n",
156 | " elif next_direction == \"down\":\n",
157 | " action = 2\n",
158 | " elif next_direction == \"left\":\n",
159 | " action = 3\n",
160 | "\n",
161 | " return action\n",
162 | "\n",
163 | "\n",
164 | "def get_s_next(s, a, Q, epsilon, pi_0):\n",
165 | " direction = [\"up\", \"right\", \"down\", \"left\"]\n",
166 | " next_direction = direction[a] # 행동 a의 방향\n",
167 | "\n",
168 | " # 행동으로 다음 상태를 결정\n",
169 | " if next_direction == \"up\":\n",
170 | " s_next = s - 3 # 위로 이동하면 상태값이 3 줄어든다\n",
171 | " elif next_direction == \"right\":\n",
172 | " s_next = s + 1 # 오른쪽으로 이동하면 상태값이 1 늘어난다\n",
173 | " elif next_direction == \"down\":\n",
174 | " s_next = s + 3 # 아래로 이동하면 상태값이 3 늘어난다\n",
175 | " elif next_direction == \"left\":\n",
176 | " s_next = s - 1 # 왼쪽으로 이동하면 상태값이 1 줄어든다\n",
177 | "\n",
178 | " return s_next"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": 8,
184 | "metadata": {},
185 | "outputs": [],
186 | "source": [
187 | "# Sarsa 알고리즘으로 행동가치 함수 Q를 수정\n",
188 | "\n",
189 | "def Sarsa(s, a, r, s_next, a_next, Q, eta, gamma):\n",
190 | "\n",
191 | " if s_next == 8: # 목표 지점에 도달한 경우\n",
192 | " Q[s, a] = Q[s, a] + eta * (r - Q[s, a])\n",
193 | "\n",
194 | " else:\n",
195 | " Q[s, a] = Q[s, a] + eta * (r + gamma * Q[s_next, a_next] - Q[s, a])\n",
196 | "\n",
197 | " return Q\n"
198 | ]
199 | },
200 | {
201 | "cell_type": "code",
202 | "execution_count": 9,
203 | "metadata": {},
204 | "outputs": [],
205 | "source": [
206 | "# Sarsa 알고리즘으로 미로를 빠져나오는 함수, 상태 및 행동 그리고 Q값의 히스토리를 출력한다\n",
207 | "\n",
208 | "def goal_maze_ret_s_a_Q(Q, epsilon, eta, gamma, pi):\n",
209 | " s = 0 # 시작 지점\n",
210 | " a = a_next = get_action(s, Q, epsilon, pi) # 첫 번째 행동\n",
211 | " s_a_history = [[0, np.nan]] # 에이전트의 행동 및 상태의 히스토리를 기록하는 리스트\n",
212 | "\n",
213 | " while (1): # 목표 지점에 이를 때까지 반복\n",
214 | " a = a_next # 행동 결정\n",
215 | "\n",
216 | " s_a_history[-1][1] = a\n",
217 | " # 현재 상태(마지막이므로 인덱스가 -1)을 히스토리에 추가\n",
218 | "\n",
219 | " s_next = get_s_next(s, a, Q, epsilon, pi)\n",
220 | " # 다음 단계의 상태를 구함\n",
221 | "\n",
222 | " s_a_history.append([s_next, np.nan])\n",
223 | " # 다음 상태를 히스토리에 추가, 행동은 아직 알 수 없으므로 nan으로 둔다\n",
224 | "\n",
225 | " # 보상을 부여하고 다음 행동을 계산함\n",
226 | " if s_next == 8:\n",
227 | " r = 1 # 목표 지점에 도달했다면 보상을 부여\n",
228 | " a_next = np.nan\n",
229 | " else:\n",
230 | " r = 0\n",
231 | " a_next = get_action(s_next, Q, epsilon, pi)\n",
232 | " # 다음 행동 a_next를 계산\n",
233 | "\n",
234 | " # 가치함수를 수정\n",
235 | " Q = Sarsa(s, a, r, s_next, a_next, Q, eta, gamma)\n",
236 | "\n",
237 | " # 종료 여부 판정\n",
238 | " if s_next == 8: # 목표 지점에 도달하면 종료\n",
239 | " break\n",
240 | " else:\n",
241 | " s = s_next\n",
242 | "\n",
243 | " return [s_a_history, Q]"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": 10,
249 | "metadata": {},
250 | "outputs": [
251 | {
252 | "name": "stdout",
253 | "output_type": "stream",
254 | "text": [
255 | "에피소드: 1\n",
256 | "2.9123690381123595\n",
257 | "목표 지점에 이르기까지 걸린 단계 수는 756단계입니다\n",
258 | "에피소드: 2\n",
259 | "0.06470294176990493\n",
260 | "목표 지점에 이르기까지 걸린 단계 수는 6단계입니다\n",
261 | "에피소드: 3\n",
262 | "0.08962131184352168\n",
263 | "목표 지점에 이르기까지 걸린 단계 수는 16단계입니다\n",
264 | "에피소드: 4\n",
265 | "0.07088647193553937\n",
266 | "목표 지점에 이르기까지 걸린 단계 수는 10단계입니다\n",
267 | "에피소드: 5\n",
268 | "0.057825165390497923\n",
269 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
270 | "에피소드: 6\n",
271 | "0.069115365071926\n",
272 | "목표 지점에 이르기까지 걸린 단계 수는 10단계입니다\n",
273 | "에피소드: 7\n",
274 | "0.055847280693343826\n",
275 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
276 | "에피소드: 8\n",
277 | "0.0646771466375029\n",
278 | "목표 지점에 이르기까지 걸린 단계 수는 10단계입니다\n",
279 | "에피소드: 9\n",
280 | "0.054024188311407206\n",
281 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
282 | "에피소드: 10\n",
283 | "0.05343429531742172\n",
284 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
285 | "에피소드: 11\n",
286 | "0.05278635429752665\n",
287 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
288 | "에피소드: 12\n",
289 | "0.05207874386146988\n",
290 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
291 | "에피소드: 13\n",
292 | "0.05131093120979935\n",
293 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
294 | "에피소드: 14\n",
295 | "0.05048342516236182\n",
296 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
297 | "에피소드: 15\n",
298 | "0.04959769495031635\n",
299 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
300 | "에피소드: 16\n",
301 | "0.04865606709232484\n",
302 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
303 | "에피소드: 17\n",
304 | "0.04766160988053009\n",
305 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
306 | "에피소드: 18\n",
307 | "0.046618012727305924\n",
308 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
309 | "에피소드: 19\n",
310 | "0.045529465783242185\n",
311 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
312 | "에피소드: 20\n",
313 | "0.04440054375731639\n",
314 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
315 | "에피소드: 21\n",
316 | "0.04323609668999023\n",
317 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
318 | "에피소드: 22\n",
319 | "0.04204114949695603\n",
320 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
321 | "에피소드: 23\n",
322 | "0.0408208113716384\n",
323 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
324 | "에피소드: 24\n",
325 | "0.03958019557152842\n",
326 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
327 | "에피소드: 25\n",
328 | "0.03832434968616055\n",
329 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
330 | "에피소드: 26\n",
331 | "0.03705819616729522\n",
332 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
333 | "에피소드: 27\n",
334 | "0.035786482673166475\n",
335 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
336 | "에피소드: 28\n",
337 | "0.03451374162070997\n",
338 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
339 | "에피소드: 29\n",
340 | "0.03324425823770877\n",
341 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
342 | "에피소드: 30\n",
343 | "0.03198204634869278\n",
344 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
345 | "에피소드: 31\n",
346 | "0.030730831104155143\n",
347 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
348 | "에피소드: 32\n",
349 | "0.029494037864122025\n",
350 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
351 | "에피소드: 33\n",
352 | "0.028274786467673507\n",
353 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
354 | "에피소드: 34\n",
355 | "0.027075890154329874\n",
356 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
357 | "에피소드: 35\n",
358 | "0.02589985844703846\n",
359 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
360 | "에피소드: 36\n",
361 | "0.02474890335638047\n",
362 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
363 | "에피소드: 37\n",
364 | "0.023624948318948458\n",
365 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
366 | "에피소드: 38\n",
367 | "0.022529639337513174\n",
368 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
369 | "에피소드: 39\n",
370 | "0.021464357845061732\n",
371 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
372 | "에피소드: 40\n",
373 | "0.02043023486784079\n",
374 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
375 | "에피소드: 41\n",
376 | "0.019428166113352963\n",
377 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
378 | "에피소드: 42\n",
379 | "0.018458827657222066\n",
380 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
381 | "에피소드: 43\n",
382 | "0.017522691947583047\n",
383 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
384 | "에피소드: 44\n",
385 | "0.016620043886945823\n",
386 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
387 | "에피소드: 45\n",
388 | "0.01575099678924119\n",
389 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
390 | "에피소드: 46\n",
391 | "0.01491550804395958\n",
392 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
393 | "에피소드: 47\n",
394 | "0.014113394350067199\n",
395 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
396 | "에피소드: 48\n",
397 | "0.013344346409800645\n",
398 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
399 | "에피소드: 49\n",
400 | "0.012607942996720523\n",
401 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
402 | "에피소드: 50\n",
403 | "0.011903664333717257\n",
404 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
405 | "에피소드: 51\n",
406 | "0.011230904735219371\n",
407 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
408 | "에피소드: 52\n",
409 | "0.010588984483891117\n",
410 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
411 | "에피소드: 53\n",
412 | "0.009977160925812911\n",
413 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
414 | "에피소드: 54\n",
415 | "0.009394638779763542\n",
416 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
417 | "에피소드: 55\n",
418 | "0.00884057966594709\n",
419 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
420 | "에피소드: 56\n",
421 | "0.008314110867547964\n",
422 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
423 | "에피소드: 57\n",
424 | "0.007814333345034785\n",
425 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
426 | "에피소드: 58\n",
427 | "0.0073403290283408085\n",
428 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
429 | "에피소드: 59\n",
430 | "0.006891167416099742\n",
431 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
432 | "에피소드: 60\n",
433 | "0.006465911514145994\n",
434 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
435 | "에피소드: 61\n",
436 | "0.006063623147645414\n",
437 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
438 | "에피소드: 62\n",
439 | "0.00568336768262645\n",
440 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
441 | "에피소드: 63\n",
442 | "0.005324218193444752\n",
443 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
444 | "에피소드: 64\n",
445 | "0.004985259112940121\n",
446 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
447 | "에피소드: 65\n",
448 | "0.004665589401810943\n",
449 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
450 | "에피소드: 66\n",
451 | "0.004364325273144343\n",
452 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
453 | "에피소드: 67\n",
454 | "0.004080602507132602\n",
455 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
456 | "에피소드: 68\n",
457 | "0.003813578389876615\n",
458 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
459 | "에피소드: 69\n",
460 | "0.0035624333088519755\n",
461 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
462 | "에피소드: 70\n",
463 | "0.0033263720361560445\n",
464 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
465 | "에피소드: 71\n",
466 | "0.003104624729090788\n",
467 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
468 | "에피소드: 72\n",
469 | "0.0028964476760231506\n",
470 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
471 | "에피소드: 73\n",
472 | "0.002701123813795725\n",
473 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
474 | "에피소드: 74\n",
475 | "0.0025179630412982545\n",
476 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
477 | "에피소드: 75\n",
478 | "0.0023463023521538284\n",
479 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
480 | "에피소드: 76\n",
481 | "0.002185505807833943\n",
482 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
483 | "에피소드: 77\n",
484 | "0.002034964370928427\n",
485 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
486 | "에피소드: 78\n",
487 | "0.0018940956167549095\n",
488 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
489 | "에피소드: 79\n",
490 | "0.0017623433400003607\n",
491 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
492 | "에피소드: 80\n",
493 | "0.0016391770716805976\n",
494 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
495 | "에피소드: 81\n",
496 | "0.0015240915203390548\n",
497 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
498 | "에피소드: 82\n",
499 | "0.0014166059501411477\n",
500 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
501 | "에피소드: 83\n",
502 | "0.0013162635073069584\n",
503 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
504 | "에피소드: 84\n",
505 | "0.001222630505198108\n",
506 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
507 | "에피소드: 85\n",
508 | "0.0011352956773180711\n",
509 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
510 | "에피소드: 86\n",
511 | "0.0010538694065022058\n",
512 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
513 | "에피소드: 87\n",
514 | "0.0009779829376552751\n",
515 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
516 | "에피소드: 88\n",
517 | "0.0009072875805578029\n",
518 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
519 | "에피소드: 89\n",
520 | "0.0008414539084752315\n",
521 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
522 | "에피소드: 90\n",
523 | "0.0007801709575961935\n",
524 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
525 | "에피소드: 91\n",
526 | "0.0007231454316684038\n",
527 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
528 | "에피소드: 92\n",
529 | "0.0006701009155980486\n",
530 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
531 | "에피소드: 93\n",
532 | "0.0006207771012446406\n",
533 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
534 | "에피소드: 94\n",
535 | "0.0005749290281399366\n",
536 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
537 | "에피소드: 95\n",
538 | "0.0005323263414217516\n",
539 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
540 | "에피소드: 96\n",
541 | "0.000492752568863164\n",
542 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
543 | "에피소드: 97\n",
544 | "0.0004560044185274448\n",
545 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
546 | "에피소드: 98\n",
547 | "0.0004218910982470847\n",
548 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
549 | "에피소드: 99\n",
550 | "0.00039023365784385255\n",
551 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n",
552 | "에피소드: 100\n",
553 | "0.0003608643547524659\n",
554 | "목표 지점에 이르기까지 걸린 단계 수는 4단계입니다\n"
555 | ]
556 | }
557 | ],
558 | "source": [
559 | "# Sarsa 알고리즘으로 미로 빠져나오기\n",
560 | "\n",
561 | "eta = 0.1 # 학습률\n",
562 | "gamma = 0.9 # 시간할인율\n",
563 | "epsilon = 0.5 # ε-greedy 알고리즘 epsilon 초깃값\n",
564 | "v = np.nanmax(Q, axis=1) # 각 상태마다 가치의 최댓값을 계산\n",
565 | "is_continue = True\n",
566 | "episode = 1\n",
567 | "\n",
568 | "while is_continue: # is_continue의 값이 False가 될 때까지 반복\n",
569 | " print(\"에피소드: \" + str(episode))\n",
570 | "\n",
571 | " # ε 값을 조금씩 감소시킴\n",
572 | " epsilon = epsilon / 2\n",
573 | "\n",
574 | " # Sarsa 알고리즘으로 미로를 빠져나온 후, 결과로 나온 행동 히스토리와 Q값을 변수에 저장\n",
575 | " [s_a_history, Q] = goal_maze_ret_s_a_Q(Q, epsilon, eta, gamma, pi_0)\n",
576 | "\n",
577 | " # 상태가치의 변화\n",
578 | " new_v = np.nanmax(Q, axis=1) # 각 상태마다 행동가치의 최댓값을 계산\n",
579 | " print(np.sum(np.abs(new_v - v))) # 상태가치 함수의 변화를 출력\n",
580 | " v = new_v\n",
581 | "\n",
582 | " print(\"목표 지점에 이르기까지 걸린 단계 수는 \" + str(len(s_a_history) - 1) + \"단계입니다\")\n",
583 | "\n",
584 | " # 100 에피소드 반복\n",
585 | " episode = episode + 1\n",
586 | " if episode > 100:\n",
587 | " break\n"
588 | ]
589 | }
590 | ],
591 | "metadata": {
592 | "kernelspec": {
593 | "display_name": "Python 3",
594 | "language": "python",
595 | "name": "python3"
596 | },
597 | "language_info": {
598 | "codemirror_mode": {
599 | "name": "ipython",
600 | "version": 3
601 | },
602 | "file_extension": ".py",
603 | "mimetype": "text/x-python",
604 | "name": "python",
605 | "nbconvert_exporter": "python",
606 | "pygments_lexer": "ipython3",
607 | "version": "3.6.6"
608 | }
609 | },
610 | "nbformat": 4,
611 | "nbformat_minor": 2
612 | }
613 |
--------------------------------------------------------------------------------
/program/3_3_digitized_CartPole.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## 3.3 다변수 연속값 상태를 이산변수로 변환하기"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "# 구현에 사용할 패키지 임포트\n",
17 | "import numpy as np\n",
18 | "import matplotlib.pyplot as plt\n",
19 | "%matplotlib inline\n",
20 | "import gym\n"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 2,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "# 상수 정의\n",
30 | "ENV = 'CartPole-v0' # 태스크 이름\n",
31 | "NUM_DIZITIZED = 6 # 각 상태를 이산변수로 변환할 구간 수\n"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": 3,
37 | "metadata": {},
38 | "outputs": [
39 | {
40 | "name": "stdout",
41 | "output_type": "stream",
42 | "text": [
43 | "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n"
44 | ]
45 | }
46 | ],
47 | "source": [
48 | "# CartPole 실행\n",
49 | "env = gym.make(ENV) # 태스크 실행 환경 생성\n",
50 | "observation = env.reset() # 환경 초기화\n"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 4,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "# 이산변수 변환에 사용할 구간을 계산\n",
60 | "\n",
61 | "\n",
62 | "def bins(clip_min, clip_max, num):\n",
63 | " '''관측된 상태(연속값)을 이산값으로 변환하는 구간을 계산'''\n",
64 | " return np.linspace(clip_min, clip_max, num + 1)[1:-1]"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 5,
70 | "metadata": {},
71 | "outputs": [
72 | {
73 | "data": {
74 | "text/plain": [
75 | "array([-2.4, -1.6, -0.8, 0. , 0.8, 1.6, 2.4])"
76 | ]
77 | },
78 | "execution_count": 5,
79 | "metadata": {},
80 | "output_type": "execute_result"
81 | }
82 | ],
83 | "source": [
84 | "np.linspace(-2.4, 2.4, 6 + 1)"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": 6,
90 | "metadata": {},
91 | "outputs": [
92 | {
93 | "data": {
94 | "text/plain": [
95 | "array([-1.6, -0.8, 0. , 0.8, 1.6])"
96 | ]
97 | },
98 | "execution_count": 6,
99 | "metadata": {},
100 | "output_type": "execute_result"
101 | }
102 | ],
103 | "source": [
104 | "np.linspace(-2.4, 2.4, 6 + 1)[1:-1]"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 10,
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "def digitize_state(observation):\n",
114 | " '''관측된 상태(observation)을 이산값으로 변환'''\n",
115 | " cart_pos, cart_v, pole_angle, pole_v = observation\n",
116 | " digitized = [\n",
117 | " np.digitize(cart_pos, bins=bins(-2.4, 2.4, NUM_DIZITIZED)),\n",
118 | " np.digitize(cart_v, bins=bins(-3.0, 3.0, NUM_DIZITIZED)),\n",
119 | " np.digitize(pole_angle, bins=bins(-0.5, 0.5, NUM_DIZITIZED)),\n",
120 | " np.digitize(pole_v, bins=bins(-2.0, 2.0, NUM_DIZITIZED))]\n",
121 | " return sum([x * (NUM_DIZITIZED**i) for i, x in enumerate(digitized)])\n"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": 11,
127 | "metadata": {},
128 | "outputs": [
129 | {
130 | "data": {
131 | "text/plain": [
132 | "525"
133 | ]
134 | },
135 | "execution_count": 11,
136 | "metadata": {},
137 | "output_type": "execute_result"
138 | }
139 | ],
140 | "source": [
141 | "digitize_state(observation)"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "metadata": {
148 | "collapsed": true
149 | },
150 | "outputs": [],
151 | "source": []
152 | }
153 | ],
154 | "metadata": {
155 | "kernelspec": {
156 | "display_name": "Python 3",
157 | "language": "python",
158 | "name": "python3"
159 | },
160 | "language_info": {
161 | "codemirror_mode": {
162 | "name": "ipython",
163 | "version": 3
164 | },
165 | "file_extension": ".py",
166 | "mimetype": "text/x-python",
167 | "name": "python",
168 | "nbconvert_exporter": "python",
169 | "pygments_lexer": "ipython3",
170 | "version": "3.6.6"
171 | }
172 | },
173 | "nbformat": 4,
174 | "nbformat_minor": 2
175 | }
176 |
--------------------------------------------------------------------------------
/program/3_4_Qlearning_CartPole.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## 3.4 Q러닝 구현"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "# 구현에 사용할 패키지 임포트\n",
17 | "import numpy as np\n",
18 | "import matplotlib.pyplot as plt\n",
19 | "%matplotlib inline\n",
20 | "import gym\n"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 2,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "# 애니메이션을 만드는 함수\n",
30 | "# 참고 URL http://nbviewer.jupyter.org/github/patrickmineault\n",
31 | "# /xcorr-notebooks/blob/master/Render%20OpenAI%20gym%20as%20GIF.ipynb\n",
32 | "from JSAnimation.IPython_display import display_animation\n",
33 | "from matplotlib import animation\n",
34 | "from IPython.display import display\n",
35 | "\n",
36 | "\n",
37 | "def display_frames_as_gif(frames):\n",
38 | " \"\"\"\n",
39 | " Displays a list of frames as a gif, with controls\n",
40 | " \"\"\"\n",
41 | " plt.figure(figsize=(frames[0].shape[1]/72.0, frames[0].shape[0]/72.0),\n",
42 | " dpi=72)\n",
43 | " patch = plt.imshow(frames[0])\n",
44 | " plt.axis('off')\n",
45 | "\n",
46 | " def animate(i):\n",
47 | " patch.set_data(frames[i])\n",
48 | "\n",
49 | " anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames),\n",
50 | " interval=50)\n",
51 | "\n",
52 | " anim.save('movie_cartpole.mp4') # 애니메이션을 저장하는 부분\n",
53 | " display(display_animation(anim, default_mode='loop'))\n",
54 | " "
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 3,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "# 상수 정의\n",
64 | "ENV = 'CartPole-v0' # 태스크 이름\n",
65 | "NUM_DIZITIZED = 6 # 각 상태를 이산변수로 변환할 구간 수\n",
66 | "GAMMA = 0.99 # 시간할인율\n",
67 | "ETA = 0.5 # 학습률\n",
68 | "MAX_STEPS = 200 # 1에피소드 당 최대 단계 수\n",
69 | "NUM_EPISODES = 1000 # 최대 에피소드 수\n"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": 4,
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "class Agent:\n",
79 | " '''CartPole 에이전트 역할을 할 클래스, 봉 달린 수레이다'''\n",
80 | "\n",
81 | " def __init__(self, num_states, num_actions):\n",
82 | " self.brain = Brain(num_states, num_actions) # 에이전트가 행동을 결정하는 두뇌 역할\n",
83 | "\n",
84 | " def update_Q_function(self, observation, action, reward, observation_next):\n",
85 | " '''Q함수 수정'''\n",
86 | " self.brain.update_Q_table(\n",
87 | " observation, action, reward, observation_next)\n",
88 | "\n",
89 | " def get_action(self, observation, step):\n",
90 | " '''행동 결정'''\n",
91 | " action = self.brain.decide_action(observation, step)\n",
92 | " return action\n",
93 | " "
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 5,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "class Brain:\n",
103 | " '''에이전트의 두뇌 역할을 하는 클래스, Q러닝을 실제 수행'''\n",
104 | "\n",
105 | " def __init__(self, num_states, num_actions):\n",
106 | " self.num_actions = num_actions # 행동의 가짓수(왼쪽, 오른쪽)를 구함\n",
107 | "\n",
108 | " # Q테이블을 생성. 줄 수는 상태를 구간수^4(4는 변수의 수)가지 값 중 하나로 변환한 값, 열 수는 행동의 가짓수\n",
109 | " self.q_table = np.random.uniform(low=0, high=1, size=(\n",
110 | " NUM_DIZITIZED**num_states, num_actions))\n",
111 | "\n",
112 | "\n",
113 | " def bins(self, clip_min, clip_max, num):\n",
114 | " '''관측된 상태(연속값)를 이산변수로 변환하는 구간을 계산'''\n",
115 | " return np.linspace(clip_min, clip_max, num + 1)[1:-1]\n",
116 | "\n",
117 | " def digitize_state(self, observation):\n",
118 | " '''관측된 상태 observation을 이산변수로 변환'''\n",
119 | " cart_pos, cart_v, pole_angle, pole_v = observation\n",
120 | " digitized = [\n",
121 | " np.digitize(cart_pos, bins=self.bins(-2.4, 2.4, NUM_DIZITIZED)),\n",
122 | " np.digitize(cart_v, bins=self.bins(-3.0, 3.0, NUM_DIZITIZED)),\n",
123 | " np.digitize(pole_angle, bins=self.bins(-0.5, 0.5, NUM_DIZITIZED)),\n",
124 | " np.digitize(pole_v, bins=self.bins(-2.0, 2.0, NUM_DIZITIZED))\n",
125 | " ]\n",
126 | " return sum([x * (NUM_DIZITIZED**i) for i, x in enumerate(digitized)])\n",
127 | "\n",
128 | " def update_Q_table(self, observation, action, reward, observation_next):\n",
129 | " '''Q러닝으로 Q테이블을 수정'''\n",
130 | " state = self.digitize_state(observation) # 상태를 이산변수로 변환\n",
131 | " state_next = self.digitize_state(observation_next) # 다음 상태를 이산변수로 변환\n",
132 | " Max_Q_next = max(self.q_table[state_next][:])\n",
133 | " self.q_table[state, action] = self.q_table[state, action] + \\\n",
134 | " ETA * (reward + GAMMA * Max_Q_next - self.q_table[state, action])\n",
135 | "\n",
136 | " def decide_action(self, observation, episode):\n",
137 | " '''ε-greedy 알고리즘을 적용하여 서서히 최적행동의 비중을 늘림'''\n",
138 | " state = self.digitize_state(observation)\n",
139 | " epsilon = 0.5 * (1 / (episode + 1))\n",
140 | "\n",
141 | " if epsilon <= np.random.uniform(0, 1):\n",
142 | " action = np.argmax(self.q_table[state][:])\n",
143 | " else:\n",
144 | " action = np.random.choice(self.num_actions) # 0,1 두 가지 행동 중 하나를 무작위로 선택\n",
145 | " return action\n",
146 | " "
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": 8,
152 | "metadata": {},
153 | "outputs": [],
154 | "source": [
155 | "class Environment:\n",
156 | " '''CartPole을 실행하는 환경 역할을 하는 클래스'''\n",
157 | "\n",
158 | " def __init__(self):\n",
159 | " self.env = gym.make(ENV) # 실행할 태스크를 설정\n",
160 | " num_states = self.env.observation_space.shape[0] # 태스크의 상태 변수 수를 구함\n",
161 | " num_actions = self.env.action_space.n # 가능한 행동 수를 구함\n",
162 | " self.agent = Agent(num_states, num_actions) # 에이전트 객체를 생성\n",
163 | "\n",
164 | " def run(self):\n",
165 | " '''실행'''\n",
166 | " complete_episodes = 0 # 성공한(195단계 이상 버틴) 에피소드 수\n",
167 | " is_episode_final = False # 마지막 에피소드 여부\n",
168 | " frames = [] # 애니메이션을 만드는데 사용할 이미지를 저장하는 변수\n",
169 | "\n",
170 | " for episode in range(NUM_EPISODES): # 에피소드 수 만큼 반복\n",
171 | " observation = self.env.reset() # 환경 초기화\n",
172 | "\n",
173 | " for step in range(MAX_STEPS): # 1 에피소드에 해당하는 반복\n",
174 | "\n",
175 | " if is_episode_final is True: # 마지막 에피소드이면 frames에 각 단계의 이미지를 저장\n",
176 | " frames.append(self.env.render(mode='rgb_array'))\n",
177 | "\n",
178 | " # 행동을 선택\n",
179 | " action = self.agent.get_action(observation, episode)\n",
180 | "\n",
181 | " # 행동 a_t를 실행하여 s_{t+1}, r_{t+1}을 계산\n",
182 | " observation_next, _, done, _ = self.env.step(\n",
183 | " action) # reward, info는 사용하지 않으므로 _로 처리함\n",
184 | "\n",
185 | " # 보상을 부여\n",
186 | " if done: # 200단계를 넘어서거나 일정 각도 이상 기울면 done의 값이 True가 됨\n",
187 | " if step < 195:\n",
188 | " reward = -1 # 봉이 쓰러지면 페널티로 보상 -1 부여\n",
189 | " complete_episodes = 0 # 195단계 이상 버티면 해당 에피소드를 성공 처리\n",
190 | " else:\n",
191 | " reward = 1 # 쓰러지지 않고 에피소드를 끝내면 보상 1 부여\n",
192 | " complete_episodes += 1 # 에피소드 연속 성공 기록을 업데이트\n",
193 | " else:\n",
194 | " reward = 0 # 에피소드 중에는 보상이 0\n",
195 | "\n",
196 | " # 다음 단계의 상태 observation_next로 Q함수를 수정\n",
197 | " self.agent.update_Q_function(\n",
198 | " observation, action, reward, observation_next)\n",
199 | "\n",
200 | " # 다음 단계 상태 관측\n",
201 | " observation = observation_next\n",
202 | "\n",
203 | " # 에피소드 마무리\n",
204 | " if done:\n",
205 | " print('{0} Episode: Finished after {1} time steps'.format(\n",
206 | " episode, step + 1))\n",
207 | " break\n",
208 | "\n",
209 | " if is_episode_final is True: # 마지막 에피소드에서는 애니메이션을 만들고 저장\n",
210 | " display_frames_as_gif(frames)\n",
211 | " break\n",
212 | "\n",
213 | " if complete_episodes >= 10: # 10 에피소드 연속으로 성공한 경우\n",
214 | " print('10 에피소드 연속 성공')\n",
215 | " is_episode_final = True # 다음 에피소드가 마지막 에피소드가 됨 "
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": 9,
221 | "metadata": {},
222 | "outputs": [
223 | {
224 | "name": "stdout",
225 | "output_type": "stream",
226 | "text": [
227 | "0 Episode: Finished after 12 time steps\n",
228 | "1 Episode: Finished after 21 time steps\n",
229 | "2 Episode: Finished after 28 time steps\n",
230 | "3 Episode: Finished after 23 time steps\n",
231 | "4 Episode: Finished after 27 time steps\n",
232 | "5 Episode: Finished after 25 time steps\n",
233 | "6 Episode: Finished after 52 time steps\n",
234 | "7 Episode: Finished after 17 time steps\n",
235 | "8 Episode: Finished after 19 time steps\n",
236 | "9 Episode: Finished after 52 time steps\n",
237 | "10 Episode: Finished after 143 time steps\n",
238 | "11 Episode: Finished after 32 time steps\n",
239 | "12 Episode: Finished after 16 time steps\n",
240 | "13 Episode: Finished after 59 time steps\n",
241 | "14 Episode: Finished after 102 time steps\n",
242 | "15 Episode: Finished after 183 time steps\n",
243 | "16 Episode: Finished after 148 time steps\n",
244 | "17 Episode: Finished after 80 time steps\n",
245 | "18 Episode: Finished after 14 time steps\n",
246 | "19 Episode: Finished after 200 time steps\n",
247 | "20 Episode: Finished after 80 time steps\n",
248 | "21 Episode: Finished after 38 time steps\n",
249 | "22 Episode: Finished after 26 time steps\n",
250 | "23 Episode: Finished after 75 time steps\n",
251 | "24 Episode: Finished after 81 time steps\n",
252 | "25 Episode: Finished after 27 time steps\n",
253 | "26 Episode: Finished after 32 time steps\n",
254 | "27 Episode: Finished after 41 time steps\n",
255 | "28 Episode: Finished after 11 time steps\n",
256 | "29 Episode: Finished after 61 time steps\n",
257 | "30 Episode: Finished after 36 time steps\n",
258 | "31 Episode: Finished after 189 time steps\n",
259 | "32 Episode: Finished after 200 time steps\n",
260 | "33 Episode: Finished after 37 time steps\n",
261 | "34 Episode: Finished after 10 time steps\n",
262 | "35 Episode: Finished after 13 time steps\n",
263 | "36 Episode: Finished after 30 time steps\n",
264 | "37 Episode: Finished after 12 time steps\n",
265 | "38 Episode: Finished after 85 time steps\n",
266 | "39 Episode: Finished after 192 time steps\n",
267 | "40 Episode: Finished after 48 time steps\n",
268 | "41 Episode: Finished after 200 time steps\n",
269 | "42 Episode: Finished after 196 time steps\n",
270 | "43 Episode: Finished after 103 time steps\n",
271 | "44 Episode: Finished after 162 time steps\n",
272 | "45 Episode: Finished after 72 time steps\n",
273 | "46 Episode: Finished after 200 time steps\n",
274 | "47 Episode: Finished after 16 time steps\n",
275 | "48 Episode: Finished after 16 time steps\n",
276 | "49 Episode: Finished after 20 time steps\n",
277 | "50 Episode: Finished after 20 time steps\n",
278 | "51 Episode: Finished after 20 time steps\n",
279 | "52 Episode: Finished after 10 time steps\n",
280 | "53 Episode: Finished after 200 time steps\n",
281 | "54 Episode: Finished after 49 time steps\n",
282 | "55 Episode: Finished after 55 time steps\n",
283 | "56 Episode: Finished after 47 time steps\n",
284 | "57 Episode: Finished after 15 time steps\n",
285 | "58 Episode: Finished after 40 time steps\n",
286 | "59 Episode: Finished after 49 time steps\n",
287 | "60 Episode: Finished after 40 time steps\n",
288 | "61 Episode: Finished after 42 time steps\n",
289 | "62 Episode: Finished after 41 time steps\n",
290 | "63 Episode: Finished after 54 time steps\n",
291 | "64 Episode: Finished after 36 time steps\n",
292 | "65 Episode: Finished after 73 time steps\n",
293 | "66 Episode: Finished after 69 time steps\n",
294 | "67 Episode: Finished after 51 time steps\n",
295 | "68 Episode: Finished after 32 time steps\n",
296 | "69 Episode: Finished after 51 time steps\n",
297 | "70 Episode: Finished after 75 time steps\n",
298 | "71 Episode: Finished after 9 time steps\n",
299 | "72 Episode: Finished after 17 time steps\n",
300 | "73 Episode: Finished after 10 time steps\n",
301 | "74 Episode: Finished after 68 time steps\n",
302 | "75 Episode: Finished after 22 time steps\n",
303 | "76 Episode: Finished after 12 time steps\n",
304 | "77 Episode: Finished after 24 time steps\n",
305 | "78 Episode: Finished after 18 time steps\n",
306 | "79 Episode: Finished after 19 time steps\n",
307 | "80 Episode: Finished after 60 time steps\n",
308 | "81 Episode: Finished after 44 time steps\n",
309 | "82 Episode: Finished after 78 time steps\n",
310 | "83 Episode: Finished after 97 time steps\n",
311 | "84 Episode: Finished after 119 time steps\n",
312 | "85 Episode: Finished after 55 time steps\n",
313 | "86 Episode: Finished after 51 time steps\n",
314 | "87 Episode: Finished after 118 time steps\n",
315 | "88 Episode: Finished after 168 time steps\n",
316 | "89 Episode: Finished after 62 time steps\n",
317 | "90 Episode: Finished after 200 time steps\n",
318 | "91 Episode: Finished after 25 time steps\n",
319 | "92 Episode: Finished after 162 time steps\n",
320 | "93 Episode: Finished after 200 time steps\n",
321 | "94 Episode: Finished after 77 time steps\n",
322 | "95 Episode: Finished after 200 time steps\n",
323 | "96 Episode: Finished after 107 time steps\n",
324 | "97 Episode: Finished after 91 time steps\n",
325 | "98 Episode: Finished after 159 time steps\n",
326 | "99 Episode: Finished after 73 time steps\n",
327 | "100 Episode: Finished after 94 time steps\n",
328 | "101 Episode: Finished after 40 time steps\n",
329 | "102 Episode: Finished after 145 time steps\n",
330 | "103 Episode: Finished after 87 time steps\n",
331 | "104 Episode: Finished after 200 time steps\n",
332 | "105 Episode: Finished after 198 time steps\n",
333 | "106 Episode: Finished after 87 time steps\n",
334 | "107 Episode: Finished after 107 time steps\n",
335 | "108 Episode: Finished after 108 time steps\n",
336 | "109 Episode: Finished after 110 time steps\n",
337 | "110 Episode: Finished after 26 time steps\n",
338 | "111 Episode: Finished after 196 time steps\n",
339 | "112 Episode: Finished after 190 time steps\n",
340 | "113 Episode: Finished after 200 time steps\n",
341 | "114 Episode: Finished after 187 time steps\n",
342 | "115 Episode: Finished after 51 time steps\n",
343 | "116 Episode: Finished after 84 time steps\n",
344 | "117 Episode: Finished after 184 time steps\n",
345 | "118 Episode: Finished after 10 time steps\n",
346 | "119 Episode: Finished after 96 time steps\n",
347 | "120 Episode: Finished after 93 time steps\n",
348 | "121 Episode: Finished after 187 time steps\n",
349 | "122 Episode: Finished after 171 time steps\n",
350 | "123 Episode: Finished after 200 time steps\n",
351 | "124 Episode: Finished after 145 time steps\n",
352 | "125 Episode: Finished after 161 time steps\n",
353 | "126 Episode: Finished after 96 time steps\n",
354 | "127 Episode: Finished after 107 time steps\n",
355 | "128 Episode: Finished after 146 time steps\n",
356 | "129 Episode: Finished after 192 time steps\n",
357 | "130 Episode: Finished after 129 time steps\n",
358 | "131 Episode: Finished after 89 time steps\n",
359 | "132 Episode: Finished after 126 time steps\n",
360 | "133 Episode: Finished after 95 time steps\n",
361 | "134 Episode: Finished after 52 time steps\n",
362 | "135 Episode: Finished after 37 time steps\n",
363 | "136 Episode: Finished after 200 time steps\n",
364 | "137 Episode: Finished after 73 time steps\n",
365 | "138 Episode: Finished after 167 time steps\n",
366 | "139 Episode: Finished after 152 time steps\n",
367 | "140 Episode: Finished after 78 time steps\n",
368 | "141 Episode: Finished after 200 time steps\n",
369 | "142 Episode: Finished after 200 time steps\n",
370 | "143 Episode: Finished after 100 time steps\n",
371 | "144 Episode: Finished after 167 time steps\n",
372 | "145 Episode: Finished after 200 time steps\n",
373 | "146 Episode: Finished after 200 time steps\n",
374 | "147 Episode: Finished after 200 time steps\n",
375 | "148 Episode: Finished after 183 time steps\n",
376 | "149 Episode: Finished after 23 time steps\n",
377 | "150 Episode: Finished after 200 time steps\n",
378 | "151 Episode: Finished after 200 time steps\n",
379 | "152 Episode: Finished after 200 time steps\n",
380 | "153 Episode: Finished after 200 time steps\n",
381 | "154 Episode: Finished after 171 time steps\n",
382 | "155 Episode: Finished after 200 time steps\n",
383 | "156 Episode: Finished after 200 time steps\n",
384 | "157 Episode: Finished after 131 time steps\n",
385 | "158 Episode: Finished after 200 time steps\n",
386 | "159 Episode: Finished after 200 time steps\n",
387 | "160 Episode: Finished after 200 time steps\n",
388 | "161 Episode: Finished after 200 time steps\n",
389 | "162 Episode: Finished after 200 time steps\n",
390 | "163 Episode: Finished after 200 time steps\n",
391 | "164 Episode: Finished after 200 time steps\n",
392 | "165 Episode: Finished after 200 time steps\n",
393 | "166 Episode: Finished after 200 time steps\n",
394 | "167 Episode: Finished after 200 time steps\n",
395 | "10 에피소드 연속 성공\n"
396 | ]
397 | },
398 | {
399 | "ename": "NotImplementedError",
400 | "evalue": "abstract",
401 | "output_type": "error",
402 | "traceback": [
403 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
404 | "\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)",
405 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# main\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mcartpole_env\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mEnvironment\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mcartpole_env\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
406 | "\u001b[0;32m\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_episode_final\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# 마지막 에피소드이면 frames에 각 단계의 이미지를 저장\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0mframes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'rgb_array'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# 행동을 선택\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
407 | "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/gym/core.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, mode, **kwargs)\u001b[0m\n\u001b[1;32m 273\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 274\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'human'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 275\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 276\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 277\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
408 | "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/gym/envs/classic_control/cartpole.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, mode)\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mviewer\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclassic_control\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mrendering\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 151\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mviewer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrendering\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mViewer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscreen_width\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscreen_height\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 152\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mcartwidth\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcartwidth\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcartheight\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mcartheight\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[0maxleoffset\u001b[0m \u001b[0;34m=\u001b[0m\u001b[0mcartheight\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m4.0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
409 | "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/gym/envs/classic_control/rendering.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, width, height, display)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwidth\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwidth\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheight\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mheight\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpyglet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mWindow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwidth\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwidth\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mheight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mheight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdisplay\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdisplay\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 52\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_close\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow_closed_by_user\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misopen\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
410 | "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/pyglet/window/__init__.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, width, height, caption, resizable, style, fullscreen, visible, vsync, display, screen, config, context, mode)\u001b[0m\n\u001b[1;32m 502\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 503\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mscreen\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 504\u001b[0;31m \u001b[0mscreen\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdisplay\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_default_screen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 505\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 506\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
411 | "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/pyglet/canvas/base.py\u001b[0m in \u001b[0;36mget_default_screen\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0mrtype\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;32mclass\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mScreen\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m '''\n\u001b[0;32m---> 73\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_screens\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 74\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_windows\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
412 | "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/pyglet/canvas/base.py\u001b[0m in \u001b[0;36mget_screens\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0mrtype\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlist\u001b[0m \u001b[0mof\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;32mclass\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mScreen\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m '''\n\u001b[0;32m---> 65\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'abstract'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 66\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_default_screen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
413 | "\u001b[0;31mNotImplementedError\u001b[0m: abstract"
414 | ]
415 | }
416 | ],
417 | "source": [
418 | "# main\n",
419 | "cartpole_env = Environment()\n",
420 | "cartpole_env.run()\n"
421 | ]
422 | },
423 | {
424 | "cell_type": "code",
425 | "execution_count": null,
426 | "metadata": {},
427 | "outputs": [],
428 | "source": []
429 | }
430 | ],
431 | "metadata": {
432 | "kernelspec": {
433 | "display_name": "Python 3",
434 | "language": "python",
435 | "name": "python3"
436 | },
437 | "language_info": {
438 | "codemirror_mode": {
439 | "name": "ipython",
440 | "version": 3
441 | },
442 | "file_extension": ".py",
443 | "mimetype": "text/x-python",
444 | "name": "python",
445 | "nbconvert_exporter": "python",
446 | "pygments_lexer": "ipython3",
447 | "version": "3.6.4"
448 | }
449 | },
450 | "nbformat": 4,
451 | "nbformat_minor": 2
452 | }
453 |
--------------------------------------------------------------------------------
/program/4_3_PyTorch_MNIST.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## 4.3 파이토치로 MNIST 이미지 분류 구현하기"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 2,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "# 숫자 손글씨 이미지 데이터 집합 MNIST 다운로드\n",
17 | "\n",
18 | "from sklearn.datasets import fetch_mldata\n",
19 | "\n",
20 | "mnist = fetch_mldata('MNIST original', data_home=\".\") \n",
21 | "# data_home에 다운로드 받은 경로를 지정"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 5,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "# 1. 데이터 전처리 (이미지 데이터와 레이블 데이터 분리, 정규화)\n",
31 | "\n",
32 | "X = mnist.data / 255 # 0-255값을 [0,1] 구간으로 정규화\n",
33 | "y = mnist.target\n"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 6,
39 | "metadata": {
40 | "scrolled": true
41 | },
42 | "outputs": [
43 | {
44 | "name": "stdout",
45 | "output_type": "stream",
46 | "text": [
47 | "이 이미지 데이터의 레이블은 0이다\n"
48 | ]
49 | },
50 | {
51 | "data": {
52 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADi5JREFUeJzt3X+IXfWZx/HPo22CmkbUYhyN2bQlLi2iEzMGoWHNulhcDSRFognipOzSyR8NWFlkVUYTWItFNLsqGEx1aIJpkmp0E8u6aXFEWxBxjFJt0x+hZNPZDBljxEwQDCbP/jEnyyTO/Z479557z5l53i8Ic+957rnn8TqfOefe77nna+4uAPGcVXYDAMpB+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBPWldm7MzDidEGgxd7d6HtfUnt/MbjKzP5rZPjO7t5nnAtBe1ui5/WZ2tqQ/SbpR0qCktyWtdPffJ9Zhzw+0WDv2/Asl7XP3v7j7cUnbJC1t4vkAtFEz4b9M0l/H3B/Mlp3GzHrMbMDMBprYFoCCNfOB33iHFl84rHf3jZI2Shz2A1XSzJ5/UNLlY+7PlnSwuXYAtEsz4X9b0jwz+5qZTZO0QtKuYtoC0GoNH/a7++dmtkbSbklnS+pz998V1hmAlmp4qK+hjfGeH2i5tpzkA2DyIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gqLZO0Y2pZ8GCBcn6mjVrata6u7uT627evDlZf/LJJ5P1PXv2JOvRsecHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaCamqXXzPZLGpF0QtLn7t6V83hm6Z1kOjs7k/X+/v5kfebMmUW2c5pPPvkkWb/oootatu0qq3eW3iJO8vl7dz9cwPMAaCMO+4Ggmg2/S/qlmb1jZj1FNASgPZo97P+2ux80s4sl/crM/uDub4x9QPZHgT8MQMU0ted394PZz2FJL0laOM5jNrp7V96HgQDaq+Hwm9l5ZvaVU7clfUfSB0U1BqC1mjnsnyXpJTM79Tw/c/f/LqQrAC3X1Dj/hDfGOH/lLFz4hXdqp9mxY0eyfumllybrqd+vkZGR5LrHjx9P1vPG8RctWlSzlvdd/7xtV1m94/wM9QFBEX4gKMIPBEX4gaAIPxAU4QeCYqhvCjj33HNr1q655prkus8991yyPnv27GQ9O8+jptTvV95w2yOPPJKsb9u2LVlP9dbb25tc9+GHH07Wq4yhPgBJhB8IivADQRF+ICjCDwRF+IGgCD8QFFN0TwFPP/10zdrKlSvb2MnE5J2DMGPGjGT99ddfT9YXL15cs3bVVVcl142APT8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBMU4/ySwYMGCZP2WW26pWcv7vn2evLH0l19+OVl/9NFHa9YOHjyYXPfdd99N1j/++ONk/YYbbqhZa/Z1mQrY8wNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAULnX7TezPklLJA27+5XZsgslbZc0V9J+Sbe5e3rQVVy3v5bOzs5kvb+/P1mfOXNmw9t+5ZVXkvW86wFcf/31yXrqe/PPPPNMct0PP/wwWc9z4sSJmrVPP/00uW7ef1fenANlKvK6/T+VdNMZy+6V9Kq7z5P0anYfwCSSG353f0PSkTMWL5W0Kbu9SdKygvsC0GKNvuef5e5DkpT9vLi4lgC0Q8vP7TezHkk9rd4OgIlpdM9/yMw6JCn7OVzrge6+0d273L2rwW0BaIFGw79L0qrs9ipJO4tpB0C75IbfzLZKelPS35rZoJn9s6QfS7rRzP4s6cbsPoBJJHecv9CNBR3nv+KKK5L1tWvXJusrVqxI1g8fPlyzNjQ0lFz3oYceStZfeOGFZL3KUuP8eb/327dvT9bvuOOOhnpqhyLH+QFMQYQfCIrwA0ERfiAowg8ERfiBoLh0dwGmT5+erKcuXy1JN998c7I+MjKSrHd3d9esDQwMJNc955xzkvWo5syZU3YLLceeHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCYpy/APPnz0/W88bx8yxdujRZz5tGGxgPe34gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIpx/gKsX78+WTdLX0k5b5yecfzGnHVW7X3byZMn29hJNbHnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgcsf5zaxP0hJJw+5+ZbZsnaTvS/owe9j97v5frWqyCpYsWVKz1tnZmVw3bzroXbt2NdQT0lJj+Xn/T957772i26mcevb8P5V00zjL/93dO7N/Uzr4wFSUG353f0PSkTb0AqCNmnnPv8bMfmtmfWZ2QWEdAWiLRsO/QdI3JHVKGpL0WK0HmlmPmQ2YWXrSOABt1VD43f2Qu59w95OSfiJpYeKxG929y927Gm0SQPEaCr+ZdYy5+11JHxTTDoB2qWeob6ukxZK+amaDktZKWmxmnZJc0n5Jq1vYI4AWyA2/u68cZ/GzLeil0lLz2E+bNi257vDwcLK+ffv2hnqa6qZPn56sr1u3ruHn7u/vT9bvu+++hp97suAMPyAowg8ERfiBoAg/EBThB4Ii/EBQXLq7DT777LNkfWhoqE2dVEveUF5vb2+yfs899yTrg4ODNWuPPVbzjHRJ0rFjx5L1qYA9PxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ExTh/G0S+NHfqsuZ54/S33357sr5z585k/dZbb03Wo2PPDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBMc5fJzNrqCZJy5YtS9bvuuuuhnqqgrvvvjtZf+CBB2rWzj///OS6W7ZsSda7u7uTdaSx5weCIvxAUIQfCIrwA0ERfiAowg8ERfiBoHLH+c3sckmbJV0i6aSkje7+uJldKGm7pLmS9ku6zd0/bl2r5XL3hmqSdMkllyTrTzzxRLLe19eXrH/00Uc1a9ddd11y3TvvvDNZv/rqq5P12bNnJ+sHDhyoWdu9e3dy3aeeeipZR3Pq2fN/Lulf3P2bkq6T9AMz+5akeyW96u7zJL2a3QcwSeSG392H3H1PdntE0l5Jl0laKmlT9rBNktKnsQGolAm95zezuZLmS3pL0ix3H5JG/0BIurjo5gC0Tt3n9pvZDEk7JP3Q3Y/mnc8+Zr0eST2NtQegVera85vZlzUa/C3u/mK2+JCZdWT1DknD463r7hvdvcvdu4poGEAxcsNvo7v4ZyXtdff1Y0q7JK3Kbq+SlL6UKoBKsbxhKjNbJOnXkt7X6FCfJN2v0ff9P5c0R9IBScvd/UjOc6U3VmHLly+vWdu6dWtLt33o0KFk/ejRozVr8+bNK7qd07z55pvJ+muvvVaz9uCDDxbdDiS5e13vyXPf87v7byTVerJ/mEhTAKqDM/yAoAg/EBThB4Ii/EBQhB8IivADQeWO8xe6sUk8zp/66urzzz+fXPfaa69tatt5p1I38/8w9XVgSdq2bVuyPpkvOz5V1TvOz54fCIrwA0ERfiAowg8ERfiBoAg/EBThB4JinL8AHR0dyfrq1auT9d7e3mS9mXH+xx9/PLnuhg0bkvV9+/Yl66gexvkBJBF+ICjCDwRF+IGgCD8QFOEHgiL8QFCM8wNTDOP8AJIIPxAU4QeCIvxAUIQfCIrwA0ERfiCo3PCb2eVm9pqZ7TWz35nZXdnydWb2v2b2Xvbv5ta3C6AouSf5mFmHpA5332NmX5H0jqRlkm6TdMzdH617Y5zkA7RcvSf5fKmOJxqSNJTdHjGzvZIua649AGWb0Ht+M5srab6kt7JFa8zst2bWZ2YX1Finx8wGzGygqU4BFKruc/vNbIak1yX9yN1fNLNZkg5Lckn/ptG3Bv+U8xwc9gMtVu9hf13hN7MvS/qFpN3uvn6c+lxJv3D3K3Oeh/ADLVbYF3ts9NKxz0raOzb42QeBp3xX0gcTbRJAeer5tH+RpF9Lel/SyWzx/ZJWSurU6GH/fkmrsw8HU8/Fnh9osUIP+4tC+IHW4/v8AJIIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQeVewLNghyX9z5j7X82WVVFVe6tqXxK9NarI3v6m3ge29fv8X9i42YC7d5XWQEJVe6tqXxK9Naqs3jjsB4Ii/EBQZYd/Y8nbT6lqb1XtS6K3RpXSW6nv+QGUp+w9P4CSlBJ+M7vJzP5oZvvM7N4yeqjFzPab2fvZzMOlTjGWTYM2bGYfjFl2oZn9ysz+nP0cd5q0knqrxMzNiZmlS33tqjbjddsP+83sbEl/knSjpEFJb0ta6e6/b2sjNZjZfkld7l76mLCZ/Z2kY5I2n5oNycwekXTE3X+c/eG8wN3/tSK9rdMEZ25uUW+1Zpb+nkp87Yqc8boIZez5F0ra5+5/cffjkrZJWlpCH5Xn7m9IOnLG4qWSNmW3N2n0l6ftavRWCe4+5O57stsjkk7NLF3qa5foqxRlhP8ySX8dc39Q1Zry2yX90szeMbOespsZx6xTMyNlPy8uuZ8z5c7c3E5nzCxdmdeukRmvi1ZG+MebTaRKQw7fdvdrJP2jpB9kh7eozwZJ39DoNG5Dkh4rs5lsZukdkn7o7kfL7GWscfoq5XUrI/yDki4fc3+2pIMl9DEudz+Y/RyW9JJG36ZUyaFTk6RmP4dL7uf/ufshdz/h7icl/UQlvnbZzNI7JG1x9xezxaW/duP1VdbrVkb435Y0z8y+ZmbTJK2QtKuEPr7AzM7LPoiRmZ0n6Tuq3uzDuyStym6vkrSzxF5OU5WZm2vNLK2SX7uqzXhdykk+2VDGf0g6W1Kfu/+o7U2Mw8y+rtG9vTT6jcefldmbmW2VtFij3/o6JGmtpP+U9HNJcyQdkLTc3dv+wVuN3hZrgjM3t6i3WjNLv6USX7siZ7wupB/O8ANi4gw/ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANB/R/7QknxGq+fLwAAAABJRU5ErkJggg==\n",
53 | "text/plain": [
54 | ""
55 | ]
56 | },
57 | "metadata": {
58 | "needs_background": "light"
59 | },
60 | "output_type": "display_data"
61 | }
62 | ],
63 | "source": [
64 | "# 첫 번째 데이터를 시각화\n",
65 | "\n",
66 | "import matplotlib.pyplot as plt\n",
67 | "% matplotlib inline\n",
68 | "\n",
69 | "plt.imshow(X[0].reshape(28, 28), cmap='gray')\n",
70 | "print(\"이 이미지 데이터의 레이블은 {:.0f}이다\".format(y[0]))"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 7,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "# 2. DataLoader로 변환\n",
80 | "\n",
81 | "import torch\n",
82 | "from torch.utils.data import TensorDataset, DataLoader\n",
83 | "from sklearn.model_selection import train_test_split\n",
84 | "\n",
85 | "# 2.1 데이터를 훈련 데이터와 테스트 데이터로 분할(6:1 비율)\n",
86 | "X_train, X_test, y_train, y_test = train_test_split(\n",
87 | " X, y, test_size=1/7, random_state=0)\n",
88 | "\n",
89 | "# 2.2 데이터를 파이토치 텐서로 변환\n",
90 | "X_train = torch.Tensor(X_train)\n",
91 | "X_test = torch.Tensor(X_test)\n",
92 | "y_train = torch.LongTensor(y_train)\n",
93 | "y_test = torch.LongTensor(y_test)\n",
94 | "\n",
95 | "# 2.3 데이터와 정답 레이블을 하나로 묶어 Dataset으로 만듬\n",
96 | "ds_train = TensorDataset(X_train, y_train)\n",
97 | "ds_test = TensorDataset(X_test, y_test)\n",
98 | "\n",
99 | "# 2.4 미니배치 크기를 지정하여 DataLoader 객체로 변환\n",
100 | "# Chainer의 iterators.SerialIterator와 비슷함\n",
101 | "loader_train = DataLoader(ds_train, batch_size=64, shuffle=True)\n",
102 | "loader_test = DataLoader(ds_test, batch_size=64, shuffle=False)\n"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 8,
108 | "metadata": {},
109 | "outputs": [
110 | {
111 | "name": "stdout",
112 | "output_type": "stream",
113 | "text": [
114 | "Sequential(\n",
115 | " (fc1): Linear(in_features=784, out_features=100, bias=True)\n",
116 | " (relu1): ReLU()\n",
117 | " (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
118 | " (relu2): ReLU()\n",
119 | " (fc3): Linear(in_features=100, out_features=10, bias=True)\n",
120 | ")\n"
121 | ]
122 | }
123 | ],
124 | "source": [
125 | "# 3. 신경망 구성\n",
126 | "# Keras 스타일\n",
127 | "\n",
128 | "from torch import nn\n",
129 | "\n",
130 | "model = nn.Sequential()\n",
131 | "model.add_module('fc1', nn.Linear(28*28*1, 100))\n",
132 | "model.add_module('relu1', nn.ReLU())\n",
133 | "model.add_module('fc2', nn.Linear(100, 100))\n",
134 | "model.add_module('relu2', nn.ReLU())\n",
135 | "model.add_module('fc3', nn.Linear(100, 10))\n",
136 | "\n",
137 | "print(model)\n"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 9,
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "# 4. 오차함수 및 최적화 기법 설정\n",
147 | "\n",
148 | "from torch import optim\n",
149 | "\n",
150 | "# 오차함수 선택\n",
151 | "loss_fn = nn.CrossEntropyLoss() # criterion을 변수명으로 사용하는 경우가 많다\n",
152 | "\n",
153 | "\n",
154 | "# 가중치를 학습하기 위한 최적화 기법 선택\n",
155 | "optimizer = optim.Adam(model.parameters(), lr=0.01)"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": 10,
161 | "metadata": {},
162 | "outputs": [],
163 | "source": [
164 | "# 5. 학습 및 추론 설정\n",
165 | "# 5-1. 학습 중 1에포크에서 수행할 일을 함수로 정의\n",
166 | "# 파이토치에는 Chainer의 training.Trainer()에 해당하는 것이 없음\n",
167 | "\n",
168 | "\n",
169 | "def train(epoch):\n",
170 | " model.train() # 신경망을 학습 모드로 전환\n",
171 | "\n",
172 | " # 데이터로더에서 미니배치를 하나씩 꺼내 학습을 수행\n",
173 | " for data, targets in loader_train:\n",
174 | " \n",
175 | " optimizer.zero_grad() # 경사를 0으로 초기화\n",
176 | " outputs = model(data) # 데이터를 입력하고 출력을 계산\n",
177 | " loss = loss_fn(outputs, targets) # 출력과 훈련 데이터 정답 간의 오차를 계산\n",
178 | " loss.backward() # 오차를 역전파 계산\n",
179 | " optimizer.step() # 역전파 계산한 값으로 가중치를 수정\n",
180 | "\n",
181 | " print(\"epoch{}:완료\\n\".format(epoch))\n"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": 11,
187 | "metadata": {},
188 | "outputs": [],
189 | "source": [
190 | "# 5. 학습 및 추론 설정\n",
191 | "# 5-2. 추론 1에포크에서 할 일을 함수로 정의\n",
192 | "# 파이토치에는 Chainer의 trainer.extend(extensions.Evaluator())에 해당하는 것이 없음\n",
193 | "\n",
194 | "\n",
195 | "def test():\n",
196 | " model.eval() # 신경망을 추론 모드로 전환\n",
197 | " correct = 0\n",
198 | "\n",
199 | " # 데이터로더에서 미니배치를 하나씩 꺼내 추론을 수행\n",
200 | " with torch.no_grad(): # 추론 과정에는 미분이 필요없음\n",
201 | " for data, targets in loader_test:\n",
202 | "\n",
203 | " outputs = model(data) # 데이터를 입력하고 출력을 계산\n",
204 | "\n",
205 | " # 추론 계산\n",
206 | " _, predicted = torch.max(outputs.data, 1) # 확률이 가장 높은 레이블이 무엇인지 계산\n",
207 | " correct += predicted.eq(targets.data.view_as(predicted)).sum() # 정답과 일치한 경우 정답 카운트를 증가\n",
208 | "\n",
209 | " # 정확도 출력\n",
210 | " data_num = len(loader_test.dataset) # 데이터 총 건수\n",
211 | " print('\\n테스트 데이터에서 예측 정확도: {}/{} ({:.0f}%)\\n'.format(correct,\n",
212 | " data_num, 100. * correct / data_num))\n"
213 | ]
214 | },
215 | {
216 | "cell_type": "code",
217 | "execution_count": 12,
218 | "metadata": {},
219 | "outputs": [
220 | {
221 | "name": "stdout",
222 | "output_type": "stream",
223 | "text": [
224 | "\n",
225 | "테스트 데이터에서 예측 정확도: 982/10000 (9%)\n",
226 | "\n"
227 | ]
228 | }
229 | ],
230 | "source": [
231 | "# 학습 전 상태에서 테스트 데이터로 정확도 측정\n",
232 | "test()\n"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": 15,
238 | "metadata": {},
239 | "outputs": [
240 | {
241 | "name": "stdout",
242 | "output_type": "stream",
243 | "text": [
244 | "epoch0:완료\n",
245 | "\n",
246 | "epoch1:완료\n",
247 | "\n",
248 | "epoch2:완료\n",
249 | "\n",
250 | "\n",
251 | "테스트 데이터에서 예측 정확도: 9618/10000 (96%)\n",
252 | "\n"
253 | ]
254 | }
255 | ],
256 | "source": [
257 | "# 6. 학습 및 추론 수행\n",
258 | "for epoch in range(3):\n",
259 | " train(epoch)\n",
260 | "\n",
261 | "test()\n"
262 | ]
263 | },
264 | {
265 | "cell_type": "code",
266 | "execution_count": 16,
267 | "metadata": {},
268 | "outputs": [
269 | {
270 | "name": "stdout",
271 | "output_type": "stream",
272 | "text": [
273 | "예측 결과 : 7\n",
274 | "이 이미지 데이터의 정답 레이블은 7입니다\n"
275 | ]
276 | },
277 | {
278 | "data": {
279 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADIZJREFUeJzt3VHIXPWZx/Hvs257Y3uhFG1I36zdIssuXtglSGLL4rJY3KUQe5GkXmVhaXpRYRMVKt40N4WyGE2vCikNjdDaaNquXpTdiizYxUSMUqpttq2U7Pu+GpKWFKpXRX324j0ub+M750xm5syZN8/3A2Fmzn/mnIdDfu85M/9z/v/ITCTV82dDFyBpGIZfKsrwS0UZfqkowy8VZfilogy/VJThl4oy/FJRfz7PjUWElxNKPcvMGOd9Ux35I+KuiPhlRLwWEQ9Osy5J8xWTXtsfEdcAvwLuBFaBF4F7MvMXLZ/xyC/1bB5H/tuA1zLzN5n5R+B7wK4p1idpjqYJ/1ZgZd3r1WbZn4iI/RFxJiLOTLEtSTM2zQ9+G51avO+0PjOPAkfB035pkUxz5F8Flta9/hjwxnTlSJqXacL/InBzRHw8Ij4IfB54ejZlSerbxKf9mfl2RNwL/CdwDXAsM38+s8ok9Wrirr6JNuZ3fql3c7nIR9LmZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VJThl4oy/FJRE0/RDRAR54A3gXeAtzNz+yyKktS/qcLf+PvM/N0M1iNpjjztl4qaNvwJ/DgiXoqI/bMoSNJ8THva/6nMfCMibgCeiYj/yczn1r+h+aPgHwZpwURmzmZFEYeAtzLz4Zb3zGZjkkbKzBjnfROf9kfEtRHx4feeA58BXp10fZLma5rT/huBH0bEe+v5bmb+x0yqktS7mZ32j7WxHk/7d+zY0dr+xBNPTLX+06dPj2w7depU62dff/311vbl5eWJty1drvfTfkmbm+GXijL8UlGGXyrK8EtFGX6pqFnc1bcQtm3b1tq+tLQ01frbPr979+6p1j2tJ598cmRbVzdkV7vdjFcvj/xSUYZfKsrwS0UZfqkowy8VZfilogy/VNRVc0tvVz/+zp07W9sffnjkAESd65+2L72r9iGvI1hZWWltb7vGAODIkSMTr1uT8ZZeSa0Mv1SU4ZeKMvxSUYZfKsrwS0UZfqmoq6aff1pdQ3+39dV39XXv2bNnoprG1bb+rVu3tn626xqCrusjptG1306ePNnaPu1w7Fcr+/kltTL8UlGGXyrK8EtFGX6pKMMvFWX4paI6+/kj4hjwWeBiZt7SLLseOAHcBJwD9mTm7zs3tsD9/F3a9tPQ/fx96hpr4MCBA63tbdcRTDuXQtd4AG37/Wqej2CW/fzfBu66bNmDwLOZeTPwbPNa0ibSGf7MfA64dNniXcDx5vlx4O4Z1yWpZ5N+578xM88DNI83zK4kSfPQ+1x9EbEf2N/3diRdmUmP/BciYgtA83hx1Bsz82hmbs/M7RNuS1IPJg3/08C+5vk+4KnZlCNpXjrDHxGPA6eAv4qI1Yj4F+BrwJ0R8Wvgzua1pE3E+/nH9Pzzz49sW11dbf3sZu7nn1ZbX/401wh0rRvarwN44IEHWj+7mccK8H5+Sa0Mv1SU4ZeKMvxSUYZfKsrwS0XZ1TemabrrNnO30ZC6uvIOHz7c2t7WVdh1O/C2bdta2xeZXX2SWhl+qSjDLxVl+KWiDL9UlOGXijL8UlH282vT6rr24sSJExOve+/eva3ti3zthv38kloZfqkowy8VZfilogy/VJThl4oy/FJRvU/XJfWla+jvaSwvL/e27kXhkV8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXiurs54+IY8BngYuZeUuz7BDwBeC3zdseyswf9VWkajp48GBr+86dOyde9yOPPNLafvr06YnXvVmMc+T/NnDXBssfzcxbm38GX9pkOsOfmc8Bl+ZQi6Q5muY7/70R8bOIOBYR182sIklzMWn4vwF8ArgVOA+MnDQtIvZHxJmIODPhtiT1YKLwZ+aFzHwnM98Fvgnc1vLeo5m5PTO3T1qkpNmbKPwRsWXdy88Br86mHEnzMk5X3+PAHcBHImIV+ApwR0TcCiRwDvhijzVK6oHj9mswO3bsaG3vGht/aWmptf3UqVMj226//fbWz25mjtsvqZXhl4oy/FJRhl8qyvBLRRl+qSiH7lav2rrjum6r7erKW1lZaW0/cuRIa3t1Hvmlogy/VJThl4oy/FJRhl8qyvBLRRl+qShv6VWv2m7L3b1791Tr3rt378Tbvpp5S6+kVoZfKsrwS0UZfqkowy8VZfilogy/VJT382sqXX3p0/Tld93vX7Uff1Y88ktFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUZ39/BGxBDwGfBR4FziamV+PiOuBE8BNwDlgT2b+vr9S1YeusfEPHz7c2t5nP/79998/8brVbZwj/9vA/Zn518AO4EsR8TfAg8CzmXkz8GzzWtIm0Rn+zDyfmS83z98EzgJbgV3A8eZtx4G7+ypS0uxd0Xf+iLgJ+CTwAnBjZp6HtT8QwA2zLk5Sf8a+tj8iPgR8HziQmX+IGGuYMCJiP7B/svIk9WWsI39EfIC14H8nM3/QLL4QEVua9i3AxY0+m5lHM3N7Zm6fRcGSZqMz/LF2iP8WcDYz1/88+zSwr3m+D3hq9uVJ6kvn0N0R8WngJ8ArrHX1ATzE2vf+J4BtwDKwOzMvdazLobsXTJ+35ALcd999I9seffTRqdatjY07dHfnd/7M/G9g1Mr+4UqKkrQ4vMJPKsrwS0UZfqkowy8VZfilogy/VJRDd1/lDh482NreZz8+2Je/yDzyS0UZfqkowy8VZfilogy/VJThl4oy/FJRnffzz3Rj3s/fix07doxsO3Xq1FTrdnjtzWfc+/k98ktFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUd7PfxXouqd+Gl1TeC8vL7e279mzZ2Tb6dOnJ6pJs+GRXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeK6uznj4gl4DHgo8C7wNHM/HpEHAK+APy2eetDmfmjvgrVaCdPnhzZ1jUu/8rKylTbbuvHB/vyF9k4F/m8DdyfmS9HxIeBlyLimabt0cx8uL/yJPWlM/yZeR443zx/MyLOAlv7LkxSv67oO39E3AR8EnihWXRvRPwsIo5FxHUjPrM/Is5ExJmpKpU0U2OHPyI+BHwfOJCZfwC+AXwCuJW1M4PDG30uM49m5vbM3D6DeiXNyFjhj4gPsBb872TmDwAy80JmvpOZ7wLfBG7rr0xJs9YZ/ogI4FvA2cx8ZN3yLeve9jng1dmXJ6kvnUN3R8SngZ8Ar7DW1QfwEHAPa6f8CZwDvtj8ONi2Lofulno27tDdjtsvXWUct19SK8MvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VJThl4oy/FJR856i+3fA/657/ZFm2SJa1NoWtS6wtknNsra/GPeNc72f/30bjzizqGP7LWpti1oXWNukhqrN036pKMMvFTV0+I8OvP02i1rbotYF1japQWob9Du/pOEMfeSXNJBBwh8Rd0XELyPitYh4cIgaRomIcxHxSkT8dOgpxppp0C5GxKvrll0fEc9ExK+bxw2nSRuotkMR8Xqz734aEf80UG1LEfFfEXE2In4eEf/aLB9037XUNch+m/tpf0RcA/wKuBNYBV4E7snMX8y1kBEi4hywPTMH7xOOiL8D3gIey8xbmmX/BlzKzK81fzivy8wvL0hth4C3hp65uZlQZsv6maWBu4F/ZsB911LXHgbYb0Mc+W8DXsvM32TmH4HvAbsGqGPhZeZzwKXLFu8CjjfPj7P2n2fuRtS2EDLzfGa+3Dx/E3hvZulB911LXYMYIvxbgZV1r1dZrCm/E/hxRLwUEfuHLmYDN743M1LzeMPA9Vyuc+bmebpsZumF2XeTzHg9a0OEf6PZRBapy+FTmfm3wD8CX2pObzWesWZunpcNZpZeCJPOeD1rQ4R/FVha9/pjwBsD1LGhzHyjebwI/JDFm334wnuTpDaPFweu5/8t0szNG80szQLsu0Wa8XqI8L8I3BwRH4+IDwKfB54eoI73iYhrmx9iiIhrgc+weLMPPw3sa57vA54asJY/sSgzN4+aWZqB992izXg9yEU+TVfGEeAa4FhmfnXuRWwgIv6StaM9rN3x+N0ha4uIx4E7WLvr6wLwFeDfgSeAbcAysDsz5/7D24ja7uAKZ27uqbZRM0u/wID7bpYzXs+kHq/wk2ryCj+pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0X9H2LFAu6LVvxRAAAAAElFTkSuQmCC\n",
280 | "text/plain": [
281 | ""
282 | ]
283 | },
284 | "metadata": {
285 | "needs_background": "light"
286 | },
287 | "output_type": "display_data"
288 | }
289 | ],
290 | "source": [
291 | "# 2018번째 데이터를 예로 추론 수행\n",
292 | "\n",
293 | "index = 2018\n",
294 | "\n",
295 | "model.eval() # 신경망을 추론 모드로 전환\n",
296 | "data = X_test[index]\n",
297 | "output = model(data) # 데이터를 입력하고 출력을 계산\n",
298 | "_, predicted = torch.max(output.data, 0) # 확률이 가장 높은 레이블이 무엇인지 계산\n",
299 | "\n",
300 | "print(\"예측 결과 : {}\".format(predicted))\n",
301 | "\n",
302 | "X_test_show = (X_test[index]).numpy()\n",
303 | "plt.imshow(X_test_show.reshape(28, 28), cmap='gray')\n",
304 | "print(\"이 이미지 데이터의 정답 레이블은 {:.0f}입니다\".format(y_test[index]))\n"
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": 17,
310 | "metadata": {},
311 | "outputs": [],
312 | "source": [
313 | "#-----------------------------------------------"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": 18,
319 | "metadata": {},
320 | "outputs": [
321 | {
322 | "name": "stdout",
323 | "output_type": "stream",
324 | "text": [
325 | "Net(\n",
326 | " (fc1): Linear(in_features=784, out_features=100, bias=True)\n",
327 | " (fc2): Linear(in_features=100, out_features=100, bias=True)\n",
328 | " (fc3): Linear(in_features=100, out_features=10, bias=True)\n",
329 | ")\n"
330 | ]
331 | }
332 | ],
333 | "source": [
334 | "# 3. 신경망 구성\n",
335 | "# 신경망 구성 (Chainer 스타일)\n",
336 | "import torch.nn as nn\n",
337 | "import torch.nn.functional as F\n",
338 | "\n",
339 | "\n",
340 | "class Net(nn.Module):\n",
341 | "\n",
342 | " def __init__(self, n_in, n_mid, n_out):\n",
343 | " super(Net, self).__init__()\n",
344 | " self.fc1 = nn.Linear(n_in, n_mid) # Chainer와 달리、None을 받을 수는 없다\n",
345 | " self.fc2 = nn.Linear(n_mid, n_mid)\n",
346 | " self.fc3 = nn.Linear(n_mid, n_out)\n",
347 | "\n",
348 | " def forward(self, x):\n",
349 | " # 입력 x에 따라 forward 계산 과정이 변화함\n",
350 | " h1 = F.relu(self.fc1(x))\n",
351 | " h2 = F.relu(self.fc2(h1))\n",
352 | " output = self.fc3(h2)\n",
353 | " return output\n",
354 | "\n",
355 | "\n",
356 | "model = Net(n_in=28*28*1, n_mid=100, n_out=10) # 신경망 객체를 생성\n",
357 | "print(model)\n"
358 | ]
359 | }
360 | ],
361 | "metadata": {
362 | "kernelspec": {
363 | "display_name": "Python 3",
364 | "language": "python",
365 | "name": "python3"
366 | },
367 | "language_info": {
368 | "codemirror_mode": {
369 | "name": "ipython",
370 | "version": 3
371 | },
372 | "file_extension": ".py",
373 | "mimetype": "text/x-python",
374 | "name": "python",
375 | "nbconvert_exporter": "python",
376 | "pygments_lexer": "ipython3",
377 | "version": "3.6.6"
378 | }
379 | },
380 | "nbformat": 4,
381 | "nbformat_minor": 2
382 | }
383 |
--------------------------------------------------------------------------------
/program/6_2_DDQN.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## 6.2 DDQN 구현"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "# 구현에 사용할 패키지 임포트\n",
17 | "import numpy as np\n",
18 | "import matplotlib.pyplot as plt\n",
19 | "%matplotlib inline\n",
20 | "import gym\n"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 2,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "# 애니메이션을 만드는 함수\n",
30 | "# 참고 URL http://nbviewer.jupyter.org/github/patrickmineault\n",
31 | "# /xcorr-notebooks/blob/master/Render%20OpenAI%20gym%20as%20GIF.ipynb\n",
32 | "from JSAnimation.IPython_display import display_animation\n",
33 | "from matplotlib import animation\n",
34 | "from IPython.display import display\n",
35 | "\n",
36 | "\n",
37 | "def display_frames_as_gif(frames):\n",
38 | " \"\"\"\n",
39 | " Displays a list of frames as a gif, with controls\n",
40 | " \"\"\"\n",
41 | " plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0),\n",
42 | " dpi=72)\n",
43 | " patch = plt.imshow(frames[0])\n",
44 | " plt.axis('off')\n",
45 | "\n",
46 | " def animate(i):\n",
47 | " patch.set_data(frames[i])\n",
48 | "\n",
49 | " anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames),\n",
50 | " interval=50)\n",
51 | "\n",
52 | " anim.save('movie_cartpole_DDQN.mp4') # 애니메이션을 저장하는 부분\n",
53 | " display(display_animation(anim, default_mode='loop'))\n"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 3,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "# namedtuple 생성\n",
63 | "from collections import namedtuple\n",
64 | "\n",
65 | "Transition = namedtuple(\n",
66 | " 'Transition', ('state', 'action', 'next_state', 'reward'))\n"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 4,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "# 상수 정의\n",
76 | "ENV = 'CartPole-v0' # 태스크 이름\n",
77 | "GAMMA = 0.99 # 시간할인율\n",
78 | "MAX_STEPS = 200 # 1에피소드 당 최대 단계 수\n",
79 | "NUM_EPISODES = 500 # 최대 에피소드 수\n"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 5,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "# transition을 저장하기 위한 메모리 클래스\n",
89 | "\n",
90 | "\n",
91 | "class ReplayMemory:\n",
92 | "\n",
93 | " def __init__(self, CAPACITY):\n",
94 | " self.capacity = CAPACITY # 메모리의 최대 저장 건수\n",
95 | " self.memory = [] # 실제 transition을 저장할 변수\n",
96 | " self.index = 0 # 저장 위치를 가리킬 인덱스 변수\n",
97 | "\n",
98 | " def push(self, state, action, state_next, reward):\n",
99 | " '''transition = (state, action, state_next, reward)을 메모리에 저장'''\n",
100 | "\n",
101 | " if len(self.memory) < self.capacity:\n",
102 | " self.memory.append(None) # 메모리가 가득차지 않은 경우\n",
103 | "\n",
104 | " # Transition이라는 namedtuple을 사용하여 키-값 쌍의 형태로 값을 저장\n",
105 | " self.memory[self.index] = Transition(state, action, state_next, reward)\n",
106 | "\n",
107 | " self.index = (self.index + 1) % self.capacity # 다음 저장할 위치를 한 자리 뒤로 수정\n",
108 | "\n",
109 | " def sample(self, batch_size):\n",
110 | " '''batch_size 갯수 만큼 무작위로 저장된 transition을 추출'''\n",
111 | " return random.sample(self.memory, batch_size)\n",
112 | "\n",
113 | " def __len__(self):\n",
114 | " '''len 함수로 현재 저장된 transition 갯수를 반환'''\n",
115 | " return len(self.memory)\n"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 6,
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "# 신경망 구성\n",
125 | "import torch.nn as nn\n",
126 | "import torch.nn.functional as F\n",
127 | "\n",
128 | "\n",
129 | "class Net(nn.Module):\n",
130 | "\n",
131 | " def __init__(self, n_in, n_mid, n_out):\n",
132 | " super(Net, self).__init__()\n",
133 | " self.fc1 = nn.Linear(n_in, n_mid)\n",
134 | " self.fc2 = nn.Linear(n_mid, n_mid)\n",
135 | " self.fc3 = nn.Linear(n_mid, n_out)\n",
136 | "\n",
137 | " def forward(self, x):\n",
138 | " h1 = F.relu(self.fc1(x))\n",
139 | " h2 = F.relu(self.fc2(h1))\n",
140 | " output = self.fc3(h2)\n",
141 | " return output\n"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": 7,
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "# 에이전트의 두뇌 역할을 하는 클래스, DDQN을 실제 수행한다 \n",
151 | "\n",
152 | "import random\n",
153 | "import torch\n",
154 | "from torch import nn\n",
155 | "from torch import optim\n",
156 | "import torch.nn.functional as F\n",
157 | "\n",
158 | "BATCH_SIZE = 32\n",
159 | "CAPACITY = 10000\n",
160 | "\n",
161 | "\n",
162 | "class Brain:\n",
163 | " def __init__(self, num_states, num_actions):\n",
164 | " self.num_actions = num_actions # CartPoleの行動(右に左に押す)の2を取得\n",
165 | "\n",
166 | " # transition을 기억하기 위한 메모리 객체 생성\n",
167 | " self.memory = ReplayMemory(CAPACITY)\n",
168 | "\n",
169 | " # 신경망 구성\n",
170 | " n_in, n_mid, n_out = num_states, 32, num_actions\n",
171 | " self.main_q_network = Net(n_in, n_mid, n_out) # Net 클래스를 사용\n",
172 | " self.target_q_network = Net(n_in, n_mid, n_out) # Net 클래스를 사용\n",
173 | " print(self.main_q_network) # 신경망의 구조를 출력\n",
174 | "\n",
175 | " # 최적화 기법 선택\n",
176 | " self.optimizer = optim.Adam(\n",
177 | " self.main_q_network.parameters(), lr=0.0001)\n",
178 | "\n",
179 | " def replay(self):\n",
180 | " '''Experience Replay로 신경망의 결합 가중치 학습'''\n",
181 | "\n",
182 | " # 1. 저장된 transition의 수를 확인\n",
183 | " if len(self.memory) < BATCH_SIZE:\n",
184 | " return\n",
185 | "\n",
186 | " # 2. 미니배치 생성\n",
187 | " self.batch, self.state_batch, self.action_batch, self.reward_batch, self.non_final_next_states = self.make_minibatch()\n",
188 | "\n",
189 | " # 3. 정답신호로 사용할 Q(s_t, a_t)를 계산\n",
190 | " self.expected_state_action_values = self.get_expected_state_action_values()\n",
191 | "\n",
192 | " # 4. 결합 가중치 수정\n",
193 | " self.update_main_q_network()\n",
194 | "\n",
195 | " def decide_action(self, state, episode):\n",
196 | " '''현재 상태로부터 행동을 결정함'''\n",
197 | " # ε-greedy 알고리즘에서 서서히 최적행동의 비중을 늘린다\n",
198 | " epsilon = 0.5 * (1 / (episode + 1))\n",
199 | "\n",
200 | " if epsilon <= np.random.uniform(0, 1):\n",
201 | " self.main_q_network.eval() # 신경망을 추론 모드로 전환\n",
202 | " with torch.no_grad():\n",
203 | " action = self.main_q_network(state).max(1)[1].view(1, 1)\n",
204 | " # 신경망 출력의 최댓값에 대한 인덱스 = max(1)[1]\n",
205 | " # .view(1,1)은 [torch.LongTensor of size 1] 을 size 1*1로 변환하는 역할을 한다\n",
206 | "\n",
207 | " else:\n",
208 | " # 행동을 무작위로 반환(0 혹은 1)\n",
209 | " action = torch.LongTensor(\n",
210 | " [[random.randrange(self.num_actions)]]) # 행동을 무작위로 반환(0 혹은 1)\n",
211 | " # action은 [torch.LongTensor of size 1*1] 형태가 된다\n",
212 | "\n",
213 | " return action\n",
214 | "\n",
215 | " def make_minibatch(self):\n",
216 | " '''2. 미니배치 생성'''\n",
217 | "\n",
218 | " # 2.1 메모리 객체에서 미니배치를 추출\n",
219 | " transitions = self.memory.sample(BATCH_SIZE)\n",
220 | "\n",
221 | " # 2.2 각 변수를 미니배치에 맞는 형태로 변형\n",
222 | " # transitions는 각 단계 별로 (state, action, state_next, reward) 형태로 BATCH_SIZE 갯수만큼 저장됨\n",
223 | " # 다시 말해, (state, action, state_next, reward) * BATCH_SIZE 형태가 된다\n",
224 | " # 이것을 미니배치로 만들기 위해\n",
225 | " # (state*BATCH_SIZE, action*BATCH_SIZE, state_next*BATCH_SIZE, reward*BATCH_SIZE) 형태로 변환한다\n",
226 | " batch = Transition(*zip(*transitions))\n",
227 | "\n",
228 | " # 2.3 각 변수의 요소를 미니배치에 맞게 변형하고, 신경망으로 다룰 수 있도록 Variable로 만든다\n",
229 | " # state를 예로 들면, [torch.FloatTensor of size 1*4] 형태의 요소가 BATCH_SIZE 갯수만큼 있는 형태이다\n",
230 | " # 이를 torch.FloatTensor of size BATCH_SIZE*4 형태로 변형한다\n",
231 | " # 상태, 행동, 보상, non_final 상태로 된 미니배치를 나타내는 Variable을 생성\n",
232 | " # cat은 Concatenates(연접)을 의미한다\n",
233 | " state_batch = torch.cat(batch.state)\n",
234 | " action_batch = torch.cat(batch.action)\n",
235 | " reward_batch = torch.cat(batch.reward)\n",
236 | " non_final_next_states = torch.cat([s for s in batch.next_state\n",
237 | " if s is not None])\n",
238 | "\n",
239 | " return batch, state_batch, action_batch, reward_batch, non_final_next_states\n",
240 | "\n",
241 | " def get_expected_state_action_values(self):\n",
242 | " '''정답신호로 사용할 Q(s_t, a_t)를 계산'''\n",
243 | "\n",
244 | " # 3.1 신경망을 추론 모드로 전환\n",
245 | " self.main_q_network.eval()\n",
246 | " self.target_q_network.eval()\n",
247 | "\n",
248 | " # 3.2 신경망으로 Q(s_t, a_t)를 계산\n",
249 | " # self.model(state_batch)은 왼쪽, 오른쪽에 대한 Q값을 출력하며\n",
250 | " # [torch.FloatTensor of size BATCH_SIZEx2] 형태이다\n",
251 | " # 여기서부터는 실행한 행동 a_t에 대한 Q값을 계산하므로 action_batch에서 취한 행동 a_t가 \n",
252 | " # 왼쪽이냐 오른쪽이냐에 대한 인덱스를 구하고, 이에 대한 Q값을 gather 메서드로 모아온다\n",
253 | " self.state_action_values = self.main_q_network(\n",
254 | " self.state_batch).gather(1, self.action_batch)\n",
255 | "\n",
256 | " # 3.3 max{Q(s_t+1, a)}값을 계산한다 이때 다음 상태가 존재하는지에 주의해야 한다\n",
257 | "\n",
258 | " # cartpole이 done 상태가 아니고, next_state가 존재하는지 확인하는 인덱스 마스크를 만듬\n",
259 | " non_final_mask = torch.ByteTensor(tuple(map(lambda s: s is not None,\n",
260 | " self.batch.next_state)))\n",
261 | " # 먼저 전체를 0으로 초기화\n",
262 | " next_state_values = torch.zeros(BATCH_SIZE)\n",
263 | "\n",
264 | " a_m = torch.zeros(BATCH_SIZE).type(torch.LongTensor)\n",
265 | "\n",
266 | " # 다음 상태에서 Q값이 최대가 되는 행동 a_m을 Main Q-Network로 계산\n",
267 | " # 마지막에 붙은 [1]로 행동에 해당하는 인덱스를 구함\n",
268 | " a_m[non_final_mask] = self.main_q_network(\n",
269 | " self.non_final_next_states).detach().max(1)[1]\n",
270 | "\n",
271 | " # 다음 상태가 있는 것만을 걸러내고, size 32를 32*1로 변환\n",
272 | " a_m_non_final_next_states = a_m[non_final_mask].view(-1, 1)\n",
273 | "\n",
274 | " # 다음 상태가 있는 인덱스에 대해 행동 a_m의 Q값을 target Q-Network로 계산\n",
275 | " # detach() 메서드로 값을 꺼내옴\n",
276 | " # squeeze() 메서드로 size[minibatch*1]을 [minibatch]로 변환\n",
277 | " next_state_values[non_final_mask] = self.target_q_network(\n",
278 | " self.non_final_next_states).gather(1, a_m_non_final_next_states).detach().squeeze()\n",
279 | "\n",
280 | " # 3.4 정답신호로 사용할 Q(s_t, a_t)값을 Q러닝 식으로 계산한다\n",
281 | " expected_state_action_values = self.reward_batch + GAMMA * next_state_values\n",
282 | "\n",
283 | " return expected_state_action_values\n",
284 | "\n",
285 | " def update_main_q_network(self):\n",
286 | " '''4. 결합 가중치 수정'''\n",
287 | "\n",
288 | " # 4.1 신경망을 학습 모드로 전환\n",
289 | " self.main_q_network.train()\n",
290 | "\n",
291 | " # 4.2 손실함수를 계산 (smooth_l1_loss는 Huber 함수)\n",
292 | " # expected_state_action_values은\n",
293 | " # size가 [minibatch]이므로 unsqueeze하여 [minibatch*1]로 만든다\n",
294 | " loss = F.smooth_l1_loss(self.state_action_values,\n",
295 | " self.expected_state_action_values.unsqueeze(1))\n",
296 | "\n",
297 | " # 4.3 결합 가중치를 수정한다\n",
298 | " self.optimizer.zero_grad() # 경사를 초기화\n",
299 | " loss.backward() # 역전파 계산\n",
300 | " self.optimizer.step() # 결합 가중치 수정\n",
301 | "\n",
302 | " def update_target_q_network(self): # DDQN에서 추가됨\n",
303 | " '''Target Q-Network을 Main Q-Network와 맞춤'''\n",
304 | " self.target_q_network.load_state_dict(self.main_q_network.state_dict())\n"
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": 8,
310 | "metadata": {},
311 | "outputs": [],
312 | "source": [
313 | "# CartPole 태스크의 에이전트 클래스. 봉 달린 수레 자체라고 보면 된다\n",
314 | "\n",
315 | "\n",
316 | "class Agent:\n",
317 | " def __init__(self, num_states, num_actions):\n",
318 | " '''태스크의 상태 및 행동의 가짓수를 설정'''\n",
319 | " self.brain = Brain(num_states, num_actions) # 에이전트의 행동을 결정할 두뇌 역할 객체를 생성\n",
320 | "\n",
321 | " def update_q_function(self):\n",
322 | " '''Q함수를 수정'''\n",
323 | " self.brain.replay()\n",
324 | "\n",
325 | " def get_action(self, state, episode):\n",
326 | " '''행동을 결정'''\n",
327 | " action = self.brain.decide_action(state, episode)\n",
328 | " return action\n",
329 | "\n",
330 | " def memorize(self, state, action, state_next, reward):\n",
331 | " '''memory 객체에 state, action, state_next, reward 내용을 저장'''\n",
332 | " self.brain.memory.push(state, action, state_next, reward)\n",
333 | "\n",
334 | " def update_target_q_function(self):\n",
335 | " '''Target Q-Network을 Main Q-Network와 맞춤'''\n",
336 | " self.brain.update_target_q_network()\n",
337 | " "
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": 9,
343 | "metadata": {},
344 | "outputs": [],
345 | "source": [
346 | "# CartPole을 실행하는 환경 클래스\n",
347 | "\n",
348 | "\n",
349 | "class Environment:\n",
350 | "\n",
351 | " def __init__(self):\n",
352 | " self.env = gym.make(ENV) # 태스크를 설정\n",
353 | " num_states = self.env.observation_space.shape[0] # 태스크의 상태 변수 수(4)를 받아옴\n",
354 | " num_actions = self.env.action_space.n # 태스크의 행동 가짓수(2)를 받아옴\n",
355 | " self.agent = Agent(num_states, num_actions) # 에이전트 역할을 할 객체를 생성\n",
356 | "\n",
357 | " def run(self):\n",
358 | " '''실행'''\n",
359 | " episode_10_list = np.zeros(10) # 최근 10에피소드 동안 버틴 단계 수를 저장함(평균 단계 수를 출력할 때 사용)\n",
360 | " complete_episodes = 0 # 현재까지 195단계를 버틴 에피소드 수\n",
361 | " episode_final = False # 마지막 에피소드 여부\n",
362 | " frames = [] # 애니메이션을 만들기 위해 마지막 에피소드의 프레임을 저장할 배열\n",
363 | "\n",
364 | " for episode in range(NUM_EPISODES): # 최대 에피소드 수만큼 반복\n",
365 | " observation = self.env.reset() # 환경 초기화\n",
366 | "\n",
367 | " state = observation # 관측을 변환없이 그대로 상태 s로 사용\n",
368 | " state = torch.from_numpy(state).type(\n",
369 | " torch.FloatTensor) # NumPy 변수를 파이토치 텐서로 변환\n",
370 | " state = torch.unsqueeze(state, 0) # size 4를 size 1*4로 변환\n",
371 | "\n",
372 | " for step in range(MAX_STEPS): # 1 에피소드에 해당하는 반복문\n",
373 | " \n",
374 | " # 애니메이션 만드는 부분을 주석처리\n",
375 | " #if episode_final is True: # 마지막 에피소드에서는 각 시각의 이미지를 frames에 저장한다\n",
376 | " # frames.append(self.env.render(mode='rgb_array'))\n",
377 | " \n",
378 | " action = self.agent.get_action(state, episode) # 다음 행동을 결정\n",
379 | "\n",
380 | " # 행동 a_t를 실행하여 다음 상태 s_{t+1}과 done 플래그 값을 결정\n",
381 | " # action에 .item()을 호출하여 행동 내용을 구함\n",
382 | " observation_next, _, done, _ = self.env.step(\n",
383 | " action.item()) # reward와 info는 사용하지 않으므로 _로 처리\n",
384 | "\n",
385 | " # 보상을 부여하고 episode의 종료 판정 및 state_next를 설정한다\n",
386 | " if done: # 단계 수가 200을 넘었거나 봉이 일정 각도 이상 기울면 done이 True가 됨\n",
387 | " state_next = None # 다음 상태가 없으므로 None으로\n",
388 | "\n",
389 | " # 최근 10 에피소드에서 버틴 단계 수를 리스트에 저장\n",
390 | " episode_10_list = np.hstack(\n",
391 | " (episode_10_list[1:], step + 1))\n",
392 | "\n",
393 | " if step < 195:\n",
394 | " reward = torch.FloatTensor(\n",
395 | " [-1.0]) # 도중에 봉이 쓰러졌다면 페널티로 보상 -1을 부여\n",
396 | " complete_episodes = 0 # 연속 성공 에피소드 기록을 초기화\n",
397 | " else:\n",
398 | " reward = torch.FloatTensor([1.0]) # 봉이 서 있는 채로 에피소드를 마쳤다면 보상 1 부여\n",
399 | " complete_episodes = complete_episodes + 1 # 연속 성공 에피소드 기록을 갱신\n",
400 | " else:\n",
401 | " reward = torch.FloatTensor([0.0]) # 그 외의 경우는 보상 0을 부여\n",
402 | " state_next = observation_next # 관측 결과를 그대로 상태로 사용\n",
403 | " state_next = torch.from_numpy(state_next).type(\n",
404 | " torch.FloatTensor) # numpy 변수를 파이토치 텐서로 변환\n",
405 | " state_next = torch.unsqueeze(state_next, 0) # size 4를 size 1*4로 변환\n",
406 | "\n",
407 | " # 메모리에 경험을 저장\n",
408 | " self.agent.memorize(state, action, state_next, reward)\n",
409 | "\n",
410 | " # Experience Replay로 Q함수를 수정\n",
411 | " self.agent.update_q_function()\n",
412 | "\n",
413 | " # 관측 결과를 업데이트\n",
414 | " state = state_next\n",
415 | "\n",
416 | " # 에피소드 종료 처리\n",
417 | " if done:\n",
418 | " print('%d Episode: Finished after %d steps:최근 10 에피소드의 평균 단계 수 = %.1lf' % (\n",
419 | " episode, step + 1, episode_10_list.mean()))\n",
420 | " \n",
421 | " # DDQN으로 추가된 부분 2에피소드마다 한번씩 Target Q-Network을 Main Q-Network와 맞춰줌\n",
422 | " if(episode % 2 == 0):\n",
423 | " self.agent.update_target_q_function()\n",
424 | " break\n",
425 | " \n",
426 | " \n",
427 | " if episode_final is True:\n",
428 | " # 애니메이션 생성 부분을 주석처리함\n",
429 | " # 애니메이션 생성 및 저장\n",
430 | " #display_frames_as_gif(frames)\n",
431 | " break\n",
432 | "\n",
433 | " # 10 에피소드 연속으로 195단계를 버티면 태스크 성공\n",
434 | " if complete_episodes >= 10:\n",
435 | " print('10 에피소드 연속 성공')\n",
436 | " episode_final = True # 다음 에피소드에서 애니메이션을 생성\n"
437 | ]
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": 10,
442 | "metadata": {},
443 | "outputs": [
444 | {
445 | "name": "stdout",
446 | "output_type": "stream",
447 | "text": [
448 | "Net(\n",
449 | " (fc1): Linear(in_features=4, out_features=32, bias=True)\n",
450 | " (fc2): Linear(in_features=32, out_features=32, bias=True)\n",
451 | " (fc3): Linear(in_features=32, out_features=2, bias=True)\n",
452 | ")\n",
453 | "0 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 1.1\n",
454 | "1 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 2.1\n",
455 | "2 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 3.2\n",
456 | "3 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 4.2\n",
457 | "4 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 5.1\n",
458 | "5 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 6.0\n",
459 | "6 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 7.0\n",
460 | "7 Episode: Finished after 8 steps:최근 10 에피소드의 평균 단계 수 = 7.8\n",
461 | "8 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 8.7\n",
462 | "9 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.6\n",
463 | "10 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.5\n",
464 | "11 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 9.6\n",
465 | "12 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n",
466 | "13 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n",
467 | "14 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.5\n",
468 | "15 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 9.7\n",
469 | "16 Episode: Finished after 14 steps:최근 10 에피소드의 평균 단계 수 = 10.1\n",
470 | "17 Episode: Finished after 15 steps:최근 10 에피소드의 평균 단계 수 = 10.8\n",
471 | "18 Episode: Finished after 79 steps:최근 10 에피소드의 평균 단계 수 = 17.8\n",
472 | "19 Episode: Finished after 37 steps:최근 10 에피소드의 평균 단계 수 = 20.6\n",
473 | "20 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 20.8\n",
474 | "21 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 20.7\n",
475 | "22 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 20.7\n",
476 | "23 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 20.7\n",
477 | "24 Episode: Finished after 8 steps:최근 10 에피소드의 평균 단계 수 = 20.5\n",
478 | "25 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 20.4\n",
479 | "26 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 19.9\n",
480 | "27 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 19.4\n",
481 | "28 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 12.4\n",
482 | "29 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.6\n",
483 | "30 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n",
484 | "31 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n",
485 | "32 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.5\n",
486 | "33 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n",
487 | "34 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 9.7\n",
488 | "35 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.6\n",
489 | "36 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.7\n",
490 | "37 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.6\n",
491 | "38 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.7\n",
492 | "39 Episode: Finished after 8 steps:최근 10 에피소드의 평균 단계 수 = 9.6\n",
493 | "40 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.5\n",
494 | "41 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.5\n",
495 | "42 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n",
496 | "43 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.5\n",
497 | "44 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n",
498 | "45 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.5\n",
499 | "46 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 9.6\n",
500 | "47 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 9.8\n",
501 | "48 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.8\n",
502 | "49 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 10.2\n",
503 | "50 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 10.3\n",
504 | "51 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 10.5\n",
505 | "52 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 10.6\n",
506 | "53 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 10.7\n",
507 | "54 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 10.6\n",
508 | "55 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 10.7\n",
509 | "56 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 10.6\n",
510 | "57 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 10.7\n",
511 | "58 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 10.7\n",
512 | "59 Episode: Finished after 13 steps:최근 10 에피소드의 평균 단계 수 = 10.8\n",
513 | "60 Episode: Finished after 14 steps:최근 10 에피소드의 평균 단계 수 = 11.2\n",
514 | "61 Episode: Finished after 14 steps:최근 10 에피소드의 평균 단계 수 = 11.4\n",
515 | "62 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 11.6\n",
516 | "63 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 11.7\n",
517 | "64 Episode: Finished after 16 steps:최근 10 에피소드의 평균 단계 수 = 12.4\n",
518 | "65 Episode: Finished after 14 steps:최근 10 에피소드의 평균 단계 수 = 12.7\n",
519 | "66 Episode: Finished after 15 steps:최근 10 에피소드의 평균 단계 수 = 13.2\n",
520 | "67 Episode: Finished after 15 steps:최근 10 에피소드의 평균 단계 수 = 13.5\n",
521 | "68 Episode: Finished after 17 steps:최근 10 에피소드의 평균 단계 수 = 14.2\n",
522 | "69 Episode: Finished after 17 steps:최근 10 에피소드의 평균 단계 수 = 14.6\n",
523 | "70 Episode: Finished after 38 steps:최근 10 에피소드의 평균 단계 수 = 17.0\n",
524 | "71 Episode: Finished after 25 steps:최근 10 에피소드의 평균 단계 수 = 18.1\n",
525 | "72 Episode: Finished after 30 steps:최근 10 에피소드의 평균 단계 수 = 19.9\n",
526 | "73 Episode: Finished after 22 steps:최근 10 에피소드의 평균 단계 수 = 20.9\n",
527 | "74 Episode: Finished after 32 steps:최근 10 에피소드의 평균 단계 수 = 22.5\n",
528 | "75 Episode: Finished after 26 steps:최근 10 에피소드의 평균 단계 수 = 23.7\n",
529 | "76 Episode: Finished after 43 steps:최근 10 에피소드의 평균 단계 수 = 26.5\n",
530 | "77 Episode: Finished after 40 steps:최근 10 에피소드의 평균 단계 수 = 29.0\n",
531 | "78 Episode: Finished after 56 steps:최근 10 에피소드의 평균 단계 수 = 32.9\n",
532 | "79 Episode: Finished after 63 steps:최근 10 에피소드의 평균 단계 수 = 37.5\n",
533 | "80 Episode: Finished after 66 steps:최근 10 에피소드의 평균 단계 수 = 40.3\n",
534 | "81 Episode: Finished after 85 steps:최근 10 에피소드의 평균 단계 수 = 46.3\n",
535 | "82 Episode: Finished after 65 steps:최근 10 에피소드의 평균 단계 수 = 49.8\n",
536 | "83 Episode: Finished after 42 steps:최근 10 에피소드의 평균 단계 수 = 51.8\n",
537 | "84 Episode: Finished after 50 steps:최근 10 에피소드의 평균 단계 수 = 53.6\n",
538 | "85 Episode: Finished after 49 steps:최근 10 에피소드의 평균 단계 수 = 55.9\n",
539 | "86 Episode: Finished after 48 steps:최근 10 에피소드의 평균 단계 수 = 56.4\n",
540 | "87 Episode: Finished after 47 steps:최근 10 에피소드의 평균 단계 수 = 57.1\n",
541 | "88 Episode: Finished after 74 steps:최근 10 에피소드의 평균 단계 수 = 58.9\n",
542 | "89 Episode: Finished after 142 steps:최근 10 에피소드의 평균 단계 수 = 66.8\n",
543 | "90 Episode: Finished after 69 steps:최근 10 에피소드의 평균 단계 수 = 67.1\n",
544 | "91 Episode: Finished after 76 steps:최근 10 에피소드의 평균 단계 수 = 66.2\n",
545 | "92 Episode: Finished after 90 steps:최근 10 에피소드의 평균 단계 수 = 68.7\n",
546 | "93 Episode: Finished after 79 steps:최근 10 에피소드의 평균 단계 수 = 72.4\n",
547 | "94 Episode: Finished after 77 steps:최근 10 에피소드의 평균 단계 수 = 75.1\n",
548 | "95 Episode: Finished after 66 steps:최근 10 에피소드의 평균 단계 수 = 76.8\n",
549 | "96 Episode: Finished after 56 steps:최근 10 에피소드의 평균 단계 수 = 77.6\n",
550 | "97 Episode: Finished after 68 steps:최근 10 에피소드의 평균 단계 수 = 79.7\n",
551 | "98 Episode: Finished after 88 steps:최근 10 에피소드의 평균 단계 수 = 81.1\n",
552 | "99 Episode: Finished after 52 steps:최근 10 에피소드의 평균 단계 수 = 72.1\n",
553 | "100 Episode: Finished after 33 steps:최근 10 에피소드의 평균 단계 수 = 68.5\n",
554 | "101 Episode: Finished after 77 steps:최근 10 에피소드의 평균 단계 수 = 68.6\n",
555 | "102 Episode: Finished after 33 steps:최근 10 에피소드의 평균 단계 수 = 62.9\n",
556 | "103 Episode: Finished after 64 steps:최근 10 에피소드의 평균 단계 수 = 61.4\n",
557 | "104 Episode: Finished after 65 steps:최근 10 에피소드의 평균 단계 수 = 60.2\n",
558 | "105 Episode: Finished after 70 steps:최근 10 에피소드의 평균 단계 수 = 60.6\n",
559 | "106 Episode: Finished after 62 steps:최근 10 에피소드의 평균 단계 수 = 61.2\n",
560 | "107 Episode: Finished after 59 steps:최근 10 에피소드의 평균 단계 수 = 60.3\n",
561 | "108 Episode: Finished after 45 steps:최근 10 에피소드의 평균 단계 수 = 56.0\n",
562 | "109 Episode: Finished after 49 steps:최근 10 에피소드의 평균 단계 수 = 55.7\n",
563 | "110 Episode: Finished after 115 steps:최근 10 에피소드의 평균 단계 수 = 63.9\n",
564 | "111 Episode: Finished after 83 steps:최근 10 에피소드의 평균 단계 수 = 64.5\n",
565 | "112 Episode: Finished after 126 steps:최근 10 에피소드의 평균 단계 수 = 73.8\n",
566 | "113 Episode: Finished after 76 steps:최근 10 에피소드의 평균 단계 수 = 75.0\n",
567 | "114 Episode: Finished after 86 steps:최근 10 에피소드의 평균 단계 수 = 77.1\n",
568 | "115 Episode: Finished after 52 steps:최근 10 에피소드의 평균 단계 수 = 75.3\n",
569 | "116 Episode: Finished after 55 steps:최근 10 에피소드의 평균 단계 수 = 74.6\n",
570 | "117 Episode: Finished after 73 steps:최근 10 에피소드의 평균 단계 수 = 76.0\n",
571 | "118 Episode: Finished after 147 steps:최근 10 에피소드의 평균 단계 수 = 86.2\n",
572 | "119 Episode: Finished after 85 steps:최근 10 에피소드의 평균 단계 수 = 89.8\n",
573 | "120 Episode: Finished after 108 steps:최근 10 에피소드의 평균 단계 수 = 89.1\n",
574 | "121 Episode: Finished after 107 steps:최근 10 에피소드의 평균 단계 수 = 91.5\n",
575 | "122 Episode: Finished after 106 steps:최근 10 에피소드의 평균 단계 수 = 89.5\n",
576 | "123 Episode: Finished after 89 steps:최근 10 에피소드의 평균 단계 수 = 90.8\n",
577 | "124 Episode: Finished after 106 steps:최근 10 에피소드의 평균 단계 수 = 92.8\n",
578 | "125 Episode: Finished after 76 steps:최근 10 에피소드의 평균 단계 수 = 95.2\n",
579 | "126 Episode: Finished after 83 steps:최근 10 에피소드의 평균 단계 수 = 98.0\n",
580 | "127 Episode: Finished after 144 steps:최근 10 에피소드의 평균 단계 수 = 105.1\n",
581 | "128 Episode: Finished after 74 steps:최근 10 에피소드의 평균 단계 수 = 97.8\n",
582 | "129 Episode: Finished after 87 steps:최근 10 에피소드의 평균 단계 수 = 98.0\n"
583 | ]
584 | },
585 | {
586 | "name": "stdout",
587 | "output_type": "stream",
588 | "text": [
589 | "130 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 107.2\n",
590 | "131 Episode: Finished after 72 steps:최근 10 에피소드의 평균 단계 수 = 103.7\n",
591 | "132 Episode: Finished after 76 steps:최근 10 에피소드의 평균 단계 수 = 100.7\n",
592 | "133 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 111.8\n",
593 | "134 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 121.2\n",
594 | "135 Episode: Finished after 184 steps:최근 10 에피소드의 평균 단계 수 = 132.0\n",
595 | "136 Episode: Finished after 123 steps:최근 10 에피소드의 평균 단계 수 = 136.0\n",
596 | "137 Episode: Finished after 142 steps:최근 10 에피소드의 평균 단계 수 = 135.8\n",
597 | "138 Episode: Finished after 184 steps:최근 10 에피소드의 평균 단계 수 = 146.8\n",
598 | "139 Episode: Finished after 198 steps:최근 10 에피소드의 평균 단계 수 = 157.9\n",
599 | "140 Episode: Finished after 180 steps:최근 10 에피소드의 평균 단계 수 = 155.9\n",
600 | "141 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 168.7\n",
601 | "142 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 181.1\n",
602 | "143 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 181.1\n",
603 | "144 Episode: Finished after 196 steps:최근 10 에피소드의 평균 단계 수 = 180.7\n",
604 | "145 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 182.3\n",
605 | "146 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 190.0\n",
606 | "147 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 195.8\n",
607 | "148 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 197.4\n",
608 | "149 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 197.6\n",
609 | "150 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 199.6\n",
610 | "10 에피소드 연속 성공\n",
611 | "151 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 199.6\n"
612 | ]
613 | }
614 | ],
615 | "source": [
616 | "# 실행 엔트리 포인트\n",
617 | "cartpole_env = Environment()\n",
618 | "cartpole_env.run()\n"
619 | ]
620 | }
621 | ],
622 | "metadata": {
623 | "kernelspec": {
624 | "display_name": "Python 3",
625 | "language": "python",
626 | "name": "python3"
627 | },
628 | "language_info": {
629 | "codemirror_mode": {
630 | "name": "ipython",
631 | "version": 3
632 | },
633 | "file_extension": ".py",
634 | "mimetype": "text/x-python",
635 | "name": "python",
636 | "nbconvert_exporter": "python",
637 | "pygments_lexer": "ipython3",
638 | "version": "3.6.6"
639 | }
640 | },
641 | "nbformat": 4,
642 | "nbformat_minor": 2
643 | }
644 |
--------------------------------------------------------------------------------
/program/6_3_DuelingNetwork.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## 6.3 Dueling Network 구현"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "# 구현에 사용할 패키지 임포트\n",
17 | "import numpy as np\n",
18 | "import matplotlib.pyplot as plt\n",
19 | "%matplotlib inline\n",
20 | "import gym\n"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 2,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "# 애니메이션을 만드는 함수\n",
30 | "# 참고 URL http://nbviewer.jupyter.org/github/patrickmineault\n",
31 | "# /xcorr-notebooks/blob/master/Render%20OpenAI%20gym%20as%20GIF.ipynb\n",
32 | "from JSAnimation.IPython_display import display_animation\n",
33 | "from matplotlib import animation\n",
34 | "from IPython.display import display\n",
35 | "\n",
36 | "\n",
37 | "def display_frames_as_gif(frames):\n",
38 | " \"\"\"\n",
39 | " Displays a list of frames as a gif, with controls\n",
40 | " \"\"\"\n",
41 | " plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0),\n",
42 | " dpi=72)\n",
43 | " patch = plt.imshow(frames[0])\n",
44 | " plt.axis('off')\n",
45 | "\n",
46 | " def animate(i):\n",
47 | " patch.set_data(frames[i])\n",
48 | "\n",
49 | " anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames),\n",
50 | " interval=50)\n",
51 | "\n",
52 | " anim.save('movie_cartpole_dueling_network.mp4') # 애니메이션을 저장하는 부분\n",
53 | " display(display_animation(anim, default_mode='loop'))\n"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 3,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "# namedtuple 생성\n",
63 | "from collections import namedtuple\n",
64 | "\n",
65 | "Transition = namedtuple(\n",
66 | " 'Transition', ('state', 'action', 'next_state', 'reward'))\n"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 4,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "# 상수 정의\n",
76 | "ENV = 'CartPole-v0' # 태스크 이름\n",
77 | "GAMMA = 0.99 # 시간할인율\n",
78 | "MAX_STEPS = 200 # 1에피소드 당 최대 단계 수\n",
79 | "NUM_EPISODES = 500 # 최대 에피소드 수\n"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 5,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "# transition을 저장하기 위한 메모리 클래스\n",
89 | "\n",
90 | "\n",
91 | "class ReplayMemory:\n",
92 | "\n",
93 | " def __init__(self, CAPACITY):\n",
94 | " self.capacity = CAPACITY # 메모리의 최대 저장 건수\n",
95 | " self.memory = [] # 실제 transition을 저장할 변수\n",
96 | " self.index = 0 # 저장 위치를 가리킬 인덱스 변수\n",
97 | "\n",
98 | " def push(self, state, action, state_next, reward):\n",
99 | " '''transition = (state, action, state_next, reward)을 메모리에 저장'''\n",
100 | "\n",
101 | " if len(self.memory) < self.capacity:\n",
102 | " self.memory.append(None) # 메모리가 가득차지 않은 경우\n",
103 | "\n",
104 | " # Transition이라는 namedtuple을 사용하여 키-값 쌍의 형태로 값을 저장\n",
105 | " self.memory[self.index] = Transition(state, action, state_next, reward)\n",
106 | "\n",
107 | " self.index = (self.index + 1) % self.capacity # 다음 저장할 위치를 한 자리 뒤로 수정\n",
108 | "\n",
109 | " def sample(self, batch_size):\n",
110 | " '''batch_size 갯수 만큼 무작위로 저장된 transition을 추출'''\n",
111 | " return random.sample(self.memory, batch_size)\n",
112 | "\n",
113 | " def __len__(self):\n",
114 | " '''len 함수로 현재 저장된 transition 갯수를 반환'''\n",
115 | " return len(self.memory)\n"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 6,
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "# Dueling Network 신경망 구성\n",
125 | "import torch.nn as nn\n",
126 | "import torch.nn.functional as F\n",
127 | "\n",
128 | "\n",
129 | "class Net(nn.Module):\n",
130 | "\n",
131 | " def __init__(self, n_in, n_mid, n_out):\n",
132 | " super(Net, self).__init__()\n",
133 | " self.fc1 = nn.Linear(n_in, n_mid)\n",
134 | " self.fc2 = nn.Linear(n_mid, n_mid)\n",
135 | " # Dueling Network\n",
136 | " self.fc3_adv = nn.Linear(n_mid, n_out) # Advantage함수쪽 신경망\n",
137 | " self.fc3_v = nn.Linear(n_mid, 1) # 가치 V쪽 신경망\n",
138 | "\n",
139 | " def forward(self, x):\n",
140 | " h1 = F.relu(self.fc1(x))\n",
141 | " h2 = F.relu(self.fc2(h1))\n",
142 | "\n",
143 | " adv = self.fc3_adv(h2) # 이 출력은 ReLU를 거치지 않음\n",
144 | " val = self.fc3_v(h2).expand(-1, adv.size(1)) # 이 출력은 ReLU를 거치지 않음\n",
145 | " # val은 adv와 덧셈을 하기 위해 expand 메서드로 크기를 [minibatch*1]에서 [minibatch*2]로 변환\n",
146 | " # adv.size(1)은 2(출력할 행동의 가짓수)\n",
147 | "\n",
148 | " output = val + adv - adv.mean(1, keepdim=True).expand(-1, adv.size(1))\n",
149 | " # val+adv에서 adv의 평균을 뺀다\n",
150 | " # adv.mean(1, keepdim=True) 으로 열방향(행동의 종류 방향) 평균을 구함 크기는 [minibatch*1]이 됨\n",
151 | " # expand 메서드로 크기를 [minibatch*2]로 늘림\n",
152 | "\n",
153 | " return output\n"
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "execution_count": 7,
159 | "metadata": {},
160 | "outputs": [],
161 | "source": [
162 | "# 에이전트의 두뇌 역할을 하는 클래스, DDQN을 실제 수행한다 \n",
163 | "\n",
164 | "import random\n",
165 | "import torch\n",
166 | "from torch import nn\n",
167 | "from torch import optim\n",
168 | "import torch.nn.functional as F\n",
169 | "\n",
170 | "BATCH_SIZE = 32\n",
171 | "CAPACITY = 10000\n",
172 | "\n",
173 | "\n",
174 | "class Brain:\n",
175 | " def __init__(self, num_states, num_actions):\n",
176 | " self.num_actions = num_actions # CartPoleの行動(右に左に押す)の2を取得\n",
177 | "\n",
178 | " # transition을 기억하기 위한 메모리 객체 생성\n",
179 | " self.memory = ReplayMemory(CAPACITY)\n",
180 | "\n",
181 | " # 신경망 구성\n",
182 | " n_in, n_mid, n_out = num_states, 32, num_actions\n",
183 | " self.main_q_network = Net(n_in, n_mid, n_out) # Net 클래스를 사용\n",
184 | " self.target_q_network = Net(n_in, n_mid, n_out) # Net 클래스를 사용\n",
185 | " print(self.main_q_network) # 신경망의 구조를 출력\n",
186 | "\n",
187 | " # 최적화 기법 선택\n",
188 | " self.optimizer = optim.Adam(\n",
189 | " self.main_q_network.parameters(), lr=0.0001)\n",
190 | "\n",
191 | " def replay(self):\n",
192 | " '''Experience Replay로 신경망의 결합 가중치 학습'''\n",
193 | "\n",
194 | " # 1. 저장된 transition의 수를 확인\n",
195 | " if len(self.memory) < BATCH_SIZE:\n",
196 | " return\n",
197 | "\n",
198 | " # 2. 미니배치 생성\n",
199 | " self.batch, self.state_batch, self.action_batch, self.reward_batch, self.non_final_next_states = self.make_minibatch()\n",
200 | "\n",
201 | " # 3. 정답신호로 사용할 Q(s_t, a_t)를 계산\n",
202 | " self.expected_state_action_values = self.get_expected_state_action_values()\n",
203 | "\n",
204 | " # 4. 결합 가중치 수정\n",
205 | " self.update_main_q_network()\n",
206 | "\n",
207 | " def decide_action(self, state, episode):\n",
208 | " '''현재 상태로부터 행동을 결정함'''\n",
209 | " # ε-greedy 알고리즘에서 서서히 최적행동의 비중을 늘린다\n",
210 | " epsilon = 0.5 * (1 / (episode + 1))\n",
211 | "\n",
212 | " if epsilon <= np.random.uniform(0, 1):\n",
213 | " self.main_q_network.eval() # 신경망을 추론 모드로 전환\n",
214 | " with torch.no_grad():\n",
215 | " action = self.main_q_network(state).max(1)[1].view(1, 1)\n",
216 | " # 신경망 출력의 최댓값에 대한 인덱스 = max(1)[1]\n",
217 | " # .view(1,1)은 [torch.LongTensor of size 1] 을 size 1*1로 변환하는 역할을 한다\n",
218 | "\n",
219 | " else:\n",
220 | " # 행동을 무작위로 반환(0 혹은 1)\n",
221 | " action = torch.LongTensor(\n",
222 | " [[random.randrange(self.num_actions)]]) # 행동을 무작위로 반환(0 혹은 1)\n",
223 | " # action은 [torch.LongTensor of size 1*1] 형태가 된다\n",
224 | "\n",
225 | " return action\n",
226 | "\n",
227 | " def make_minibatch(self):\n",
228 | " '''2. 미니배치 생성'''\n",
229 | "\n",
230 | " # 2.1 메모리 객체에서 미니배치를 추출\n",
231 | " transitions = self.memory.sample(BATCH_SIZE)\n",
232 | "\n",
233 | " # 2.2 각 변수를 미니배치에 맞는 형태로 변형\n",
234 | " # transitions는 각 단계 별로 (state, action, state_next, reward) 형태로 BATCH_SIZE 갯수만큼 저장됨\n",
235 | " # 다시 말해, (state, action, state_next, reward) * BATCH_SIZE 형태가 된다\n",
236 | " # 이것을 미니배치로 만들기 위해\n",
237 | " # (state*BATCH_SIZE, action*BATCH_SIZE, state_next*BATCH_SIZE, reward*BATCH_SIZE) 형태로 변환한다\n",
238 | " batch = Transition(*zip(*transitions))\n",
239 | "\n",
240 | " # 2.3 각 변수의 요소를 미니배치에 맞게 변형하고, 신경망으로 다룰 수 있도록 Variable로 만든다\n",
241 | " # state를 예로 들면, [torch.FloatTensor of size 1*4] 형태의 요소가 BATCH_SIZE 갯수만큼 있는 형태이다\n",
242 | " # 이를 torch.FloatTensor of size BATCH_SIZE*4 형태로 변형한다\n",
243 | " # 상태, 행동, 보상, non_final 상태로 된 미니배치를 나타내는 Variable을 생성\n",
244 | " # cat은 Concatenates(연접)을 의미한다\n",
245 | " state_batch = torch.cat(batch.state)\n",
246 | " action_batch = torch.cat(batch.action)\n",
247 | " reward_batch = torch.cat(batch.reward)\n",
248 | " non_final_next_states = torch.cat([s for s in batch.next_state\n",
249 | " if s is not None])\n",
250 | "\n",
251 | " return batch, state_batch, action_batch, reward_batch, non_final_next_states\n",
252 | "\n",
253 | " def get_expected_state_action_values(self):\n",
254 | " '''정답신호로 사용할 Q(s_t, a_t)를 계산'''\n",
255 | "\n",
256 | " # 3.1 신경망을 추론 모드로 전환\n",
257 | " self.main_q_network.eval()\n",
258 | " self.target_q_network.eval()\n",
259 | "\n",
260 | " # 3.2 신경망으로 Q(s_t, a_t)를 계산\n",
261 | " # self.model(state_batch)은 왼쪽, 오른쪽에 대한 Q값을 출력하며\n",
262 | " # [torch.FloatTensor of size BATCH_SIZEx2] 형태이다\n",
263 | " # 여기서부터는 실행한 행동 a_t에 대한 Q값을 계산하므로 action_batch에서 취한 행동 a_t가 \n",
264 | " # 왼쪽이냐 오른쪽이냐에 대한 인덱스를 구하고, 이에 대한 Q값을 gather 메서드로 모아온다\n",
265 | " self.state_action_values = self.main_q_network(\n",
266 | " self.state_batch).gather(1, self.action_batch)\n",
267 | "\n",
268 | " # 3.3 max{Q(s_t+1, a)}값을 계산한다 이때 다음 상태가 존재하는지에 주의해야 한다\n",
269 | "\n",
270 | " # cartpole이 done 상태가 아니고, next_state가 존재하는지 확인하는 인덱스 마스크를 만듬\n",
271 | " non_final_mask = torch.ByteTensor(tuple(map(lambda s: s is not None,\n",
272 | " self.batch.next_state)))\n",
273 | " # 먼저 전체를 0으로 초기화\n",
274 | " next_state_values = torch.zeros(BATCH_SIZE)\n",
275 | " \n",
276 | " a_m = torch.zeros(BATCH_SIZE).type(torch.LongTensor)\n",
277 | "\n",
278 | " # 다음 상태에서 Q값이 최대가 되는 행동 a_m을 Main Q-Network로 계산\n",
279 | " # 마지막에 붙은 [1]로 행동에 해당하는 인덱스를 구함\n",
280 | " a_m[non_final_mask] = self.main_q_network(\n",
281 | " self.non_final_next_states).detach().max(1)[1]\n",
282 | "\n",
283 | " # 다음 상태가 있는 것만을 걸러내고, size 32를 32*1로 변환\n",
284 | " a_m_non_final_next_states = a_m[non_final_mask].view(-1, 1)\n",
285 | "\n",
286 | " # 다음 상태가 있는 인덱스에 대해 행동 a_m의 Q값을 target Q-Network로 계산\n",
287 | " # detach() 메서드로 값을 꺼내옴\n",
288 | " # squeeze() 메서드로 size[minibatch*1]을 [minibatch]로 변환\n",
289 | " next_state_values[non_final_mask] = self.target_q_network(\n",
290 | " self.non_final_next_states).gather(1, a_m_non_final_next_states).detach().squeeze()\n",
291 | "\n",
292 | " # 3.4 정답신호로 사용할 Q(s_t, a_t)값을 Q러닝 식으로 계산한다\n",
293 | " expected_state_action_values = self.reward_batch + GAMMA * next_state_values\n",
294 | "\n",
295 | " return expected_state_action_values\n",
296 | "\n",
297 | " def update_main_q_network(self):\n",
298 | " '''4. 결합 가중치 수정'''\n",
299 | "\n",
300 | " # 4.1 신경망을 학습 모드로 전환\n",
301 | " self.main_q_network.train()\n",
302 | "\n",
303 | " # 4.2 손실함수를 계산 (smooth_l1_loss는 Huber 함수)\n",
304 | " # expected_state_action_values은\n",
305 | " # size가 [minibatch]이므로 unsqueeze하여 [minibatch*1]로 만든다\n",
306 | " loss = F.smooth_l1_loss(self.state_action_values,\n",
307 | " self.expected_state_action_values.unsqueeze(1))\n",
308 | "\n",
309 | " # 4.3 결합 가중치를 수정한다\n",
310 | " self.optimizer.zero_grad() # 경사를 초기화\n",
311 | " loss.backward() # 역전파 계산\n",
312 | " self.optimizer.step() # 결합 가중치 수정\n",
313 | "\n",
314 | " def update_target_q_network(self): # DDQN에서 추가됨\n",
315 | " '''Target Q-Network을 Main Q-Network와 맞춤'''\n",
316 | " self.target_q_network.load_state_dict(self.main_q_network.state_dict())\n"
317 | ]
318 | },
319 | {
320 | "cell_type": "code",
321 | "execution_count": 8,
322 | "metadata": {},
323 | "outputs": [],
324 | "source": [
325 | "# CartPole 태스크의 에이전트 클래스. 봉 달린 수레 자체라고 보면 된다\n",
326 | "\n",
327 | "\n",
328 | "class Agent:\n",
329 | " def __init__(self, num_states, num_actions):\n",
330 | " '''태스크의 상태 및 행동의 가짓수를 설정'''\n",
331 | " self.brain = Brain(num_states, num_actions) # 에이전트의 행동을 결정할 두뇌 역할 객체를 생성\n",
332 | "\n",
333 | " def update_q_function(self):\n",
334 | " '''Q함수를 수정'''\n",
335 | " self.brain.replay()\n",
336 | "\n",
337 | " def get_action(self, state, episode):\n",
338 | " '''행동을 결정'''\n",
339 | " action = self.brain.decide_action(state, episode)\n",
340 | " return action\n",
341 | "\n",
342 | " def memorize(self, state, action, state_next, reward):\n",
343 | " '''memory 객체에 state, action, state_next, reward 내용을 저장'''\n",
344 | " self.brain.memory.push(state, action, state_next, reward)\n",
345 | "\n",
346 | " def update_target_q_function(self):\n",
347 | " '''Target Q-Network을 Main Q-Network와 맞춤'''\n",
348 | " self.brain.update_target_q_network()\n",
349 | " "
350 | ]
351 | },
352 | {
353 | "cell_type": "code",
354 | "execution_count": 9,
355 | "metadata": {},
356 | "outputs": [],
357 | "source": [
358 | "# CartPole을 실행하는 환경 클래스\n",
359 | "\n",
360 | "\n",
361 | "class Environment:\n",
362 | "\n",
363 | " def __init__(self):\n",
364 | " self.env = gym.make(ENV) # 태스크를 설정\n",
365 | " num_states = self.env.observation_space.shape[0] # 태스크의 상태 변수 수(4)를 받아옴\n",
366 | " num_actions = self.env.action_space.n # 태스크의 행동 가짓수(2)를 받아옴\n",
367 | " self.agent = Agent(num_states, num_actions) # 에이전트 역할을 할 객체를 생성\n",
368 | "\n",
369 | " def run(self):\n",
370 | " '''실행'''\n",
371 | " episode_10_list = np.zeros(10) # 최근 10에피소드 동안 버틴 단계 수를 저장함(평균 단계 수를 출력할 때 사용)\n",
372 | " complete_episodes = 0 # 현재까지 195단계를 버틴 에피소드 수\n",
373 | " episode_final = False # 마지막 에피소드 여부\n",
374 | " frames = [] # 애니메이션을 만들기 위해 마지막 에피소드의 프레임을 저장할 배열\n",
375 | "\n",
376 | " for episode in range(NUM_EPISODES): # 최대 에피소드 수만큼 반복\n",
377 | " observation = self.env.reset() # 환경 초기화\n",
378 | "\n",
379 | " state = observation # 관측을 변환없이 그대로 상태 s로 사용\n",
380 | " state = torch.from_numpy(state).type(\n",
381 | " torch.FloatTensor) # NumPy 변수를 파이토치 텐서로 변환\n",
382 | " state = torch.unsqueeze(state, 0) # size 4를 size 1*4로 변환\n",
383 | "\n",
384 | " for step in range(MAX_STEPS): # 1 에피소드에 해당하는 반복문\n",
385 | " \n",
386 | " # 애니메이션 만드는 부분을 주석처리\n",
387 | " #if episode_final is True: # 마지막 에피소드에서는 각 시각의 이미지를 frames에 저장한다\n",
388 | " # frames.append(self.env.render(mode='rgb_array'))\n",
389 | " \n",
390 | " action = self.agent.get_action(state, episode) # 다음 행동을 결정\n",
391 | "\n",
392 | " # 행동 a_t를 실행하여 다음 상태 s_{t+1}과 done 플래그 값을 결정\n",
393 | " # action에 .item()을 호출하여 행동 내용을 구함\n",
394 | " observation_next, _, done, _ = self.env.step(\n",
395 | " action.item()) # reward와 info는 사용하지 않으므로 _로 처리\n",
396 | "\n",
397 | " # 보상을 부여하고 episode의 종료 판정 및 state_next를 설정한다\n",
398 | " if done: # 단계 수가 200을 넘었거나 봉이 일정 각도 이상 기울면 done이 True가 됨\n",
399 | " state_next = None # 다음 상태가 없으므로 None으로\n",
400 | "\n",
401 | " # 최근 10 에피소드에서 버틴 단계 수를 리스트에 저장\n",
402 | " episode_10_list = np.hstack(\n",
403 | " (episode_10_list[1:], step + 1))\n",
404 | "\n",
405 | " if step < 195:\n",
406 | " reward = torch.FloatTensor(\n",
407 | " [-1.0]) # 도중에 봉이 쓰러졌다면 페널티로 보상 -1을 부여\n",
408 | " complete_episodes = 0 # 연속 성공 에피소드 기록을 초기화\n",
409 | " else:\n",
410 | " reward = torch.FloatTensor([1.0]) # 봉이 서 있는 채로 에피소드를 마쳤다면 보상 1 부여\n",
411 | " complete_episodes = complete_episodes + 1 # 연속 성공 에피소드 기록을 갱신\n",
412 | " else:\n",
413 | " reward = torch.FloatTensor([0.0]) # 그 외의 경우는 보상 0을 부여\n",
414 | " state_next = observation_next # 관측 결과를 그대로 상태로 사용\n",
415 | " state_next = torch.from_numpy(state_next).type(\n",
416 | " torch.FloatTensor) # numpy 변수를 파이토치 텐서로 변환\n",
417 | " state_next = torch.unsqueeze(state_next, 0) # size 4를 size 1*4로 변환\n",
418 | "\n",
419 | " # 메모리에 경험을 저장\n",
420 | " self.agent.memorize(state, action, state_next, reward)\n",
421 | "\n",
422 | " # Experience Replay로 Q함수를 수정\n",
423 | " self.agent.update_q_function()\n",
424 | "\n",
425 | " # 관측 결과를 업데이트\n",
426 | " state = state_next\n",
427 | "\n",
428 | " # 에피소드 종료 처리\n",
429 | " if done:\n",
430 | " print('%d Episode: Finished after %d steps:최근 10 에피소드의 평균 단계 수 = %.1lf' % (\n",
431 | " episode, step + 1, episode_10_list.mean()))\n",
432 | " \n",
433 | " # DDQN으로 추가된 부분 2에피소드마다 한번씩 Target Q-Network을 Main Q-Network와 맞춰줌\n",
434 | " if(episode % 2 == 0):\n",
435 | " self.agent.update_target_q_function()\n",
436 | " break\n",
437 | " \n",
438 | " \n",
439 | " if episode_final is True:\n",
440 | " # 애니메이션 생성 부분을 주석처리함\n",
441 | " # 애니메이션 생성 및 저장\n",
442 | " #display_frames_as_gif(frames)\n",
443 | " break\n",
444 | "\n",
445 | " # 10 에피소드 연속으로 195단계를 버티면 태스크 성공\n",
446 | " if complete_episodes >= 10:\n",
447 | " print('10 에피소드 연속 성공')\n",
448 | " episode_final = True # 다음 에피소드에서 애니메이션을 생성\n"
449 | ]
450 | },
451 | {
452 | "cell_type": "code",
453 | "execution_count": 10,
454 | "metadata": {},
455 | "outputs": [
456 | {
457 | "name": "stdout",
458 | "output_type": "stream",
459 | "text": [
460 | "Net(\n",
461 | " (fc1): Linear(in_features=4, out_features=32, bias=True)\n",
462 | " (fc2): Linear(in_features=32, out_features=32, bias=True)\n",
463 | " (fc3_adv): Linear(in_features=32, out_features=2, bias=True)\n",
464 | " (fc3_v): Linear(in_features=32, out_features=1, bias=True)\n",
465 | ")\n",
466 | "0 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 1.1\n",
467 | "1 Episode: Finished after 13 steps:최근 10 에피소드의 평균 단계 수 = 2.4\n",
468 | "2 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 3.6\n",
469 | "3 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 4.7\n",
470 | "4 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 5.6\n",
471 | "5 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 6.6\n",
472 | "6 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 7.6\n",
473 | "7 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 8.8\n",
474 | "8 Episode: Finished after 16 steps:최근 10 에피소드의 평균 단계 수 = 10.4\n",
475 | "9 Episode: Finished after 18 steps:최근 10 에피소드의 평균 단계 수 = 12.2\n",
476 | "10 Episode: Finished after 17 steps:최근 10 에피소드의 평균 단계 수 = 12.8\n",
477 | "11 Episode: Finished after 18 steps:최근 10 에피소드의 평균 단계 수 = 13.3\n",
478 | "12 Episode: Finished after 20 steps:최근 10 에피소드의 평균 단계 수 = 14.1\n",
479 | "13 Episode: Finished after 27 steps:최근 10 에피소드의 평균 단계 수 = 15.7\n",
480 | "14 Episode: Finished after 21 steps:최근 10 에피소드의 평균 단계 수 = 16.9\n",
481 | "15 Episode: Finished after 84 steps:최근 10 에피소드의 평균 단계 수 = 24.3\n",
482 | "16 Episode: Finished after 24 steps:최근 10 에피소드의 평균 단계 수 = 25.7\n",
483 | "17 Episode: Finished after 16 steps:최근 10 에피소드의 평균 단계 수 = 26.1\n",
484 | "18 Episode: Finished after 16 steps:최근 10 에피소드의 평균 단계 수 = 26.1\n",
485 | "19 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 25.3\n",
486 | "20 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 24.6\n",
487 | "21 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 23.8\n",
488 | "22 Episode: Finished after 8 steps:최근 10 에피소드의 평균 단계 수 = 22.6\n",
489 | "23 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 20.8\n",
490 | "24 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 19.7\n",
491 | "25 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 12.2\n",
492 | "26 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 10.7\n",
493 | "27 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 10.0\n",
494 | "28 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.3\n",
495 | "29 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.2\n",
496 | "30 Episode: Finished after 12 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n",
497 | "31 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 9.4\n",
498 | "32 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 9.5\n",
499 | "33 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 9.7\n",
500 | "34 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 9.8\n",
501 | "35 Episode: Finished after 11 steps:최근 10 에피소드의 평균 단계 수 = 10.0\n",
502 | "36 Episode: Finished after 9 steps:최근 10 에피소드의 평균 단계 수 = 10.0\n",
503 | "37 Episode: Finished after 10 steps:최근 10 에피소드의 평균 단계 수 = 10.1\n",
504 | "38 Episode: Finished after 13 steps:최근 10 에피소드의 평균 단계 수 = 10.5\n",
505 | "39 Episode: Finished after 18 steps:최근 10 에피소드의 평균 단계 수 = 11.4\n",
506 | "40 Episode: Finished after 21 steps:최근 10 에피소드의 평균 단계 수 = 12.3\n",
507 | "41 Episode: Finished after 34 steps:최근 10 에피소드의 평균 단계 수 = 14.7\n",
508 | "42 Episode: Finished after 45 steps:최근 10 에피소드의 평균 단계 수 = 18.3\n",
509 | "43 Episode: Finished after 45 steps:최근 10 에피소드의 평균 단계 수 = 21.7\n",
510 | "44 Episode: Finished after 49 steps:최근 10 에피소드의 평균 단계 수 = 25.5\n",
511 | "45 Episode: Finished after 56 steps:최근 10 에피소드의 평균 단계 수 = 30.0\n",
512 | "46 Episode: Finished after 52 steps:최근 10 에피소드의 평균 단계 수 = 34.3\n",
513 | "47 Episode: Finished after 51 steps:최근 10 에피소드의 평균 단계 수 = 38.4\n",
514 | "48 Episode: Finished after 84 steps:최근 10 에피소드의 평균 단계 수 = 45.5\n",
515 | "49 Episode: Finished after 80 steps:최근 10 에피소드의 평균 단계 수 = 51.7\n",
516 | "50 Episode: Finished after 111 steps:최근 10 에피소드의 평균 단계 수 = 60.7\n",
517 | "51 Episode: Finished after 126 steps:최근 10 에피소드의 평균 단계 수 = 69.9\n",
518 | "52 Episode: Finished after 79 steps:최근 10 에피소드의 평균 단계 수 = 73.3\n",
519 | "53 Episode: Finished after 77 steps:최근 10 에피소드의 평균 단계 수 = 76.5\n",
520 | "54 Episode: Finished after 50 steps:최근 10 에피소드의 평균 단계 수 = 76.6\n",
521 | "55 Episode: Finished after 70 steps:최근 10 에피소드의 평균 단계 수 = 78.0\n",
522 | "56 Episode: Finished after 97 steps:최근 10 에피소드의 평균 단계 수 = 82.5\n",
523 | "57 Episode: Finished after 80 steps:최근 10 에피소드의 평균 단계 수 = 85.4\n",
524 | "58 Episode: Finished after 100 steps:최근 10 에피소드의 평균 단계 수 = 87.0\n",
525 | "59 Episode: Finished after 190 steps:최근 10 에피소드의 평균 단계 수 = 98.0\n",
526 | "60 Episode: Finished after 65 steps:최근 10 에피소드의 평균 단계 수 = 93.4\n",
527 | "61 Episode: Finished after 102 steps:최근 10 에피소드의 평균 단계 수 = 91.0\n",
528 | "62 Episode: Finished after 107 steps:최근 10 에피소드의 평균 단계 수 = 93.8\n",
529 | "63 Episode: Finished after 157 steps:최근 10 에피소드의 평균 단계 수 = 101.8\n",
530 | "64 Episode: Finished after 188 steps:최근 10 에피소드의 평균 단계 수 = 115.6\n",
531 | "65 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 128.6\n",
532 | "66 Episode: Finished after 187 steps:최근 10 에피소드의 평균 단계 수 = 137.6\n",
533 | "67 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 149.6\n",
534 | "68 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 159.6\n",
535 | "69 Episode: Finished after 105 steps:최근 10 에피소드의 평균 단계 수 = 151.1\n",
536 | "70 Episode: Finished after 179 steps:최근 10 에피소드의 평균 단계 수 = 162.5\n",
537 | "71 Episode: Finished after 194 steps:최근 10 에피소드의 평균 단계 수 = 171.7\n",
538 | "72 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 181.0\n",
539 | "73 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 185.3\n",
540 | "74 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 186.5\n",
541 | "75 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 186.5\n",
542 | "76 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 187.8\n",
543 | "77 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 187.8\n",
544 | "78 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 187.8\n",
545 | "79 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 197.3\n",
546 | "80 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 199.4\n",
547 | "81 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 200.0\n",
548 | "10 에피소드 연속 성공\n",
549 | "82 Episode: Finished after 200 steps:최근 10 에피소드의 평균 단계 수 = 200.0\n"
550 | ]
551 | }
552 | ],
553 | "source": [
554 | "# 실행 엔트리 포인트\n",
555 | "cartpole_env = Environment()\n",
556 | "cartpole_env.run()\n"
557 | ]
558 | },
559 | {
560 | "cell_type": "code",
561 | "execution_count": null,
562 | "metadata": {},
563 | "outputs": [],
564 | "source": []
565 | },
566 | {
567 | "cell_type": "code",
568 | "execution_count": null,
569 | "metadata": {},
570 | "outputs": [],
571 | "source": []
572 | }
573 | ],
574 | "metadata": {
575 | "kernelspec": {
576 | "display_name": "Python 3",
577 | "language": "python",
578 | "name": "python3"
579 | },
580 | "language_info": {
581 | "codemirror_mode": {
582 | "name": "ipython",
583 | "version": 3
584 | },
585 | "file_extension": ".py",
586 | "mimetype": "text/x-python",
587 | "name": "python",
588 | "nbconvert_exporter": "python",
589 | "pygments_lexer": "ipython3",
590 | "version": "3.6.6"
591 | }
592 | },
593 | "nbformat": 4,
594 | "nbformat_minor": 2
595 | }
596 |
--------------------------------------------------------------------------------
/program/6_5_A2C_Advanced_ActorCritic.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## 6.5 A2C(Advanced Actor-Critic) 구현"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "# 구현에 사용할 패키지 임포트\n",
17 | "import numpy as np\n",
18 | "import matplotlib.pyplot as plt\n",
19 | "%matplotlib inline\n",
20 | "import gym\n"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 2,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "# 상수 정의\n",
30 | "ENV = 'CartPole-v0' # 태스크 이름\n",
31 | "GAMMA = 0.99 # 시간할인율\n",
32 | "MAX_STEPS = 200 # 1에피소드 당 최대 단계 수\n",
33 | "NUM_EPISODES = 1000 # 최대 에피소드 수\n",
34 | "\n",
35 | "NUM_PROCESSES = 32 # 동시 실행 환경 수\n",
36 | "NUM_ADVANCED_STEP = 5 # 총 보상을 계산할 때 Advantage 학습을 할 단계 수\n"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 3,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "# A2C 손실함수 계산에 사용되는 상수\n",
46 | "value_loss_coef = 0.5\n",
47 | "entropy_coef = 0.01\n",
48 | "max_grad_norm = 0.5\n"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 4,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "# 메모리 클래스 정의\n",
58 | "\n",
59 | "\n",
60 | "class RolloutStorage(object):\n",
61 | " '''Advantage 학습에 사용할 메모리 클래스'''\n",
62 | "\n",
63 | " def __init__(self, num_steps, num_processes, obs_shape):\n",
64 | "\n",
65 | " self.observations = torch.zeros(num_steps + 1, num_processes, 4)\n",
66 | " self.masks = torch.ones(num_steps + 1, num_processes, 1)\n",
67 | " self.rewards = torch.zeros(num_steps, num_processes, 1)\n",
68 | " self.actions = torch.zeros(num_steps, num_processes, 1).long()\n",
69 | "\n",
70 | " # 할인 총보상 저장\n",
71 | " self.returns = torch.zeros(num_steps + 1, num_processes, 1)\n",
72 | " self.index = 0 # insert할 인덱스\n",
73 | "\n",
74 | " def insert(self, current_obs, action, reward, mask):\n",
75 | " '''현재 인덱스 위치에 transition을 저장'''\n",
76 | " self.observations[self.index + 1].copy_(current_obs)\n",
77 | " self.masks[self.index + 1].copy_(mask)\n",
78 | " self.rewards[self.index].copy_(reward)\n",
79 | " self.actions[self.index].copy_(action)\n",
80 | "\n",
81 | " self.index = (self.index + 1) % NUM_ADVANCED_STEP # 인덱스 값 업데이트\n",
82 | "\n",
83 | " def after_update(self):\n",
84 | " '''Advantage학습 단계만큼 단계가 진행되면 가장 새로운 transition을 index0에 저장'''\n",
85 | " self.observations[0].copy_(self.observations[-1])\n",
86 | " self.masks[0].copy_(self.masks[-1])\n",
87 | "\n",
88 | " def compute_returns(self, next_value):\n",
89 | " '''Advantage학습 범위 안의 각 단계에 대해 할인 총보상을 계산'''\n",
90 | "\n",
91 | " # 주의 : 5번째 단계부터 거슬러 올라오며 계산\n",
92 | " # 주의 : 5번째 단계가 Advantage1, 4번째 단계는 Advantage2가 됨\n",
93 | " self.returns[-1] = next_value\n",
94 | " for ad_step in reversed(range(self.rewards.size(0))):\n",
95 | " self.returns[ad_step] = self.returns[ad_step + 1] * \\\n",
96 | " GAMMA * self.masks[ad_step + 1] + self.rewards[ad_step]\n"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": 5,
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "# A2C에 사용되는 신경망 구성\n",
106 | "import torch.nn as nn\n",
107 | "import torch.nn.functional as F\n",
108 | "\n",
109 | "\n",
110 | "class Net(nn.Module):\n",
111 | "\n",
112 | " def __init__(self, n_in, n_mid, n_out):\n",
113 | " super(Net, self).__init__()\n",
114 | " self.fc1 = nn.Linear(n_in, n_mid)\n",
115 | " self.fc2 = nn.Linear(n_mid, n_mid)\n",
116 | " self.actor = nn.Linear(n_mid, n_out) # 행동을 결정하는 부분이므로 출력 갯수는 행동의 가짓수\n",
117 | " self.critic = nn.Linear(n_mid, 1) # 상태가치를 출력하는 부분이므로 출력 갯수는 1개\n",
118 | "\n",
119 | " def forward(self, x):\n",
120 | " '''신경망 순전파 계산을 정의'''\n",
121 | " h1 = F.relu(self.fc1(x))\n",
122 | " h2 = F.relu(self.fc2(h1))\n",
123 | " critic_output = self.critic(h2) # 상태가치 계산\n",
124 | " actor_output = self.actor(h2) # 행동 계산\n",
125 | "\n",
126 | " return critic_output, actor_output\n",
127 | "\n",
128 | " def act(self, x):\n",
129 | " '''상태 x로부터 행동을 확률적으로 결정'''\n",
130 | " value, actor_output = self(x)\n",
131 | " # dim=1이므로 행동의 종류에 대해 softmax를 적용\n",
132 | " action_probs = F.softmax(actor_output, dim=1)\n",
133 | " action = action_probs.multinomial(num_samples=1) # dim=1이므로 행동의 종류에 대해 확률을 계산\n",
134 | " return action\n",
135 | "\n",
136 | " def get_value(self, x):\n",
137 | " '''상태 x로부터 상태가치를 계산'''\n",
138 | " value, actor_output = self(x)\n",
139 | "\n",
140 | " return value\n",
141 | "\n",
142 | " def evaluate_actions(self, x, actions):\n",
143 | " '''상태 x로부터 상태가치, 실제 행동 actions의 로그 확률, 엔트로피를 계산'''\n",
144 | " value, actor_output = self(x)\n",
145 | "\n",
146 | " log_probs = F.log_softmax(actor_output, dim=1) # dim=1이므로 행동의 종류에 대해 확률을 계산\n",
147 | " action_log_probs = log_probs.gather(1, actions) # 실제 행동의 로그 확률(log_probs)을 구함\n",
148 | "\n",
149 | " probs = F.softmax(actor_output, dim=1) # dim=1이므로 행동의 종류에 대한 계산\n",
150 | " entropy = -(log_probs * probs).sum(-1).mean()\n",
151 | "\n",
152 | " return value, action_log_probs, entropy\n"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": 6,
158 | "metadata": {},
159 | "outputs": [],
160 | "source": [
161 | "# 에이전트의 두뇌 역할을 하는 클래스. 모든 에이전트가 공유한다\n",
162 | "\n",
163 | "import torch\n",
164 | "from torch import optim\n",
165 | "\n",
166 | "\n",
167 | "class Brain(object):\n",
168 | " def __init__(self, actor_critic):\n",
169 | " self.actor_critic = actor_critic # actor_critic은 Net 클래스로 구현한 신경망\n",
170 | " self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=0.01)\n",
171 | "\n",
172 | " def update(self, rollouts):\n",
173 | " '''Advantage학습의 대상이 되는 5단계 모두를 사용하여 수정'''\n",
174 | " obs_shape = rollouts.observations.size()[2:] # torch.Size([4, 84, 84])\n",
175 | " num_steps = NUM_ADVANCED_STEP\n",
176 | " num_processes = NUM_PROCESSES\n",
177 | "\n",
178 | " values, action_log_probs, entropy = self.actor_critic.evaluate_actions(\n",
179 | " rollouts.observations[:-1].view(-1, 4),\n",
180 | " rollouts.actions.view(-1, 1))\n",
181 | "\n",
182 | " # 주의 : 각 변수의 크기\n",
183 | " # rollouts.observations[:-1].view(-1, 4) torch.Size([80, 4])\n",
184 | " # rollouts.actions.view(-1, 1) torch.Size([80, 1])\n",
185 | " # values torch.Size([80, 1])\n",
186 | " # action_log_probs torch.Size([80, 1])\n",
187 | " # entropy torch.Size([])\n",
188 | "\n",
189 | " values = values.view(num_steps, num_processes,\n",
190 | " 1) # torch.Size([5, 16, 1])\n",
191 | " action_log_probs = action_log_probs.view(num_steps, num_processes, 1)\n",
192 | "\n",
193 | " # advantage(행동가치-상태가치) 계산\n",
194 | " advantages = rollouts.returns[:-1] - values # torch.Size([5, 16, 1])\n",
195 | "\n",
196 | " # Critic의 loss 계산\n",
197 | " value_loss = advantages.pow(2).mean()\n",
198 | "\n",
199 | " # Actor의 gain 계산, 나중에 -1을 곱하면 loss가 된다\n",
200 | " action_gain = (action_log_probs*advantages.detach()).mean()\n",
201 | " # detach 메서드를 호출하여 advantages를 상수로 취급\n",
202 | "\n",
203 | " # 오차함수의 총합\n",
204 | " total_loss = (value_loss * value_loss_coef -\n",
205 | " action_gain - entropy * entropy_coef)\n",
206 | "\n",
207 | " # 결합 가중치 수정\n",
208 | " self.actor_critic.train() # 신경망을 학습 모드로 전환\n",
209 | " self.optimizer.zero_grad() # 경사를 초기화\n",
210 | " total_loss.backward() # 역전파 계산\n",
211 | " nn.utils.clip_grad_norm_(self.actor_critic.parameters(), max_grad_norm)\n",
212 | " # 결합 가중치가 한번에 너무 크게 변화하지 않도록, 경사를 0.5 이하로 제한함(클리핑)\n",
213 | "\n",
214 | " self.optimizer.step() # 결합 가중치 수정\n"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 7,
220 | "metadata": {},
221 | "outputs": [],
222 | "source": [
223 | "# 이번에는 에이전트 클래스가 없음"
224 | ]
225 | },
226 | {
227 | "cell_type": "code",
228 | "execution_count": 8,
229 | "metadata": {},
230 | "outputs": [],
231 | "source": [
232 | "# 실행 환경 클래스\n",
233 | "import copy\n",
234 | "\n",
235 | "\n",
236 | "class Environment:\n",
237 | " def run(self):\n",
238 | " '''실행 엔트리 포인트'''\n",
239 | "\n",
240 | " # 동시 실행할 환경 수 만큼 env를 생성\n",
241 | " envs = [gym.make(ENV) for i in range(NUM_PROCESSES)]\n",
242 | "\n",
243 | " # 모든 에이전트가 공유하는 Brain 객체를 생성\n",
244 | " n_in = envs[0].observation_space.shape[0] # 상태 변수 수는 4\n",
245 | " n_out = envs[0].action_space.n # 행동 가짓수는 2\n",
246 | " n_mid = 32\n",
247 | " actor_critic = Net(n_in, n_mid, n_out) # 신경망 객체 생성\n",
248 | " global_brain = Brain(actor_critic)\n",
249 | "\n",
250 | " # 각종 정보를 저장하는 변수\n",
251 | " obs_shape = n_in\n",
252 | " current_obs = torch.zeros(\n",
253 | " NUM_PROCESSES, obs_shape) # torch.Size([16, 4])\n",
254 | " rollouts = RolloutStorage(\n",
255 | " NUM_ADVANCED_STEP, NUM_PROCESSES, obs_shape) # rollouts 객체\n",
256 | " episode_rewards = torch.zeros([NUM_PROCESSES, 1]) # 현재 에피소드의 보상\n",
257 | " final_rewards = torch.zeros([NUM_PROCESSES, 1]) # 마지막 에피소드의 보상\n",
258 | " obs_np = np.zeros([NUM_PROCESSES, obs_shape]) # Numpy 배열\n",
259 | " reward_np = np.zeros([NUM_PROCESSES, 1]) # Numpy 배열\n",
260 | " done_np = np.zeros([NUM_PROCESSES, 1]) # Numpy 배열\n",
261 | " each_step = np.zeros(NUM_PROCESSES) # 각 환경의 단계 수를 기록\n",
262 | " episode = 0 # 환경 0의 에피소드 수\n",
263 | "\n",
264 | " # 초기 상태로부터 시작\n",
265 | " obs = [envs[i].reset() for i in range(NUM_PROCESSES)]\n",
266 | " obs = np.array(obs)\n",
267 | " obs = torch.from_numpy(obs).float() # torch.Size([16, 4])\n",
268 | " current_obs = obs # 가장 최근의 obs를 저장\n",
269 | " \n",
270 | " # advanced 학습에 사용되는 객체 rollouts 첫번째 상태에 현재 상태를 저장\n",
271 | " rollouts.observations[0].copy_(current_obs)\n",
272 | "\n",
273 | " # 1 에피소드에 해당하는 반복문\n",
274 | " for j in range(NUM_EPISODES*NUM_PROCESSES): # 전체 for문\n",
275 | " # advanced 학습 대상이 되는 각 단계에 대해 계산\n",
276 | " for step in range(NUM_ADVANCED_STEP):\n",
277 | "\n",
278 | " # 행동을 선택\n",
279 | " with torch.no_grad():\n",
280 | " action = actor_critic.act(rollouts.observations[step])\n",
281 | "\n",
282 | " # (16,1)→(16,) -> tensor를 NumPy변수로\n",
283 | " actions = action.squeeze(1).numpy()\n",
284 | "\n",
285 | " # 한 단계를 실행\n",
286 | " for i in range(NUM_PROCESSES):\n",
287 | " obs_np[i], reward_np[i], done_np[i], _ = envs[i].step(\n",
288 | " actions[i])\n",
289 | "\n",
290 | " # episode의 종료가치, state_next를 설정\n",
291 | " if done_np[i]: # 단계 수가 200을 넘거나, 봉이 일정 각도 이상 기울면 done이 True가 됨\n",
292 | "\n",
293 | " # 환경 0일 경우에만 출력\n",
294 | " if i == 0:\n",
295 | " print('%d Episode: Finished after %d steps' % (\n",
296 | " episode, each_step[i]+1))\n",
297 | " episode += 1\n",
298 | "\n",
299 | " # 보상 부여\n",
300 | " if each_step[i] < 195:\n",
301 | " reward_np[i] = -1.0 # 도중에 봉이 넘어지면 페널티로 보상 -1 부여\n",
302 | " else:\n",
303 | " reward_np[i] = 1.0 # 봉이 쓰러지지 않고 끝나면 보상 1 부여\n",
304 | "\n",
305 | " each_step[i] = 0 # 단계 수 초기화\n",
306 | " obs_np[i] = envs[i].reset() # 실행 환경 초기화\n",
307 | "\n",
308 | " else:\n",
309 | " reward_np[i] = 0.0 # 그 외의 경우는 보상 0 부여\n",
310 | " each_step[i] += 1\n",
311 | "\n",
312 | " # 보상을 tensor로 변환하고, 에피소드의 총보상에 더해줌\n",
313 | " reward = torch.from_numpy(reward_np).float()\n",
314 | " episode_rewards += reward\n",
315 | "\n",
316 | " # 각 실행 환경을 확인하여 done이 true이면 mask를 0으로, false이면 mask를 1로\n",
317 | " masks = torch.FloatTensor(\n",
318 | " [[0.0] if done_ else [1.0] for done_ in done_np])\n",
319 | "\n",
320 | " # 마지막 에피소드의 총 보상을 업데이트\n",
321 | " final_rewards *= masks # done이 false이면 1을 곱하고, true이면 0을 곱해 초기화\n",
322 | " # done이 false이면 0을 더하고, true이면 episode_rewards를 더해줌\n",
323 | " final_rewards += (1 - masks) * episode_rewards\n",
324 | "\n",
325 | " # 에피소드의 총보상을 업데이트\n",
326 | " episode_rewards *= masks # done이 false인 에피소드의 mask는 1이므로 그대로, true이면 0이 됨\n",
327 | "\n",
328 | " # 현재 done이 true이면 모두 0으로 \n",
329 | " current_obs *= masks\n",
330 | "\n",
331 | " # current_obs를 업데이트\n",
332 | " obs = torch.from_numpy(obs_np).float() # torch.Size([16, 4])\n",
333 | " current_obs = obs # 최신 상태의 obs를 저장\n",
334 | "\n",
335 | " # 메모리 객체에 현 단계의 transition을 저장\n",
336 | " rollouts.insert(current_obs, action.data, reward, masks)\n",
337 | "\n",
338 | " # advanced 학습 for문 끝\n",
339 | "\n",
340 | " # advanced 학습 대상 중 마지막 단계의 상태로 예측하는 상태가치를 계산\n",
341 | "\n",
342 | " with torch.no_grad():\n",
343 | " next_value = actor_critic.get_value(\n",
344 | " rollouts.observations[-1]).detach()\n",
345 | " # rollouts.observations의 크기는 torch.Size([6, 16, 4])\n",
346 | "\n",
347 | " # 모든 단계의 할인총보상을 계산하고, rollouts의 변수 returns를 업데이트\n",
348 | " rollouts.compute_returns(next_value)\n",
349 | "\n",
350 | " # 신경망 및 rollout 업데이트\n",
351 | " global_brain.update(rollouts)\n",
352 | " rollouts.after_update()\n",
353 | "\n",
354 | " # 환경 갯수를 넘어서는 횟수로 200단계를 버텨내면 성공\n",
355 | " if final_rewards.sum().numpy() >= NUM_PROCESSES:\n",
356 | " print('연속성공')\n",
357 | " break\n"
358 | ]
359 | },
360 | {
361 | "cell_type": "code",
362 | "execution_count": 9,
363 | "metadata": {},
364 | "outputs": [
365 | {
366 | "name": "stdout",
367 | "output_type": "stream",
368 | "text": [
369 | "0 Episode: Finished after 20 steps\n",
370 | "1 Episode: Finished after 13 steps\n",
371 | "2 Episode: Finished after 10 steps\n",
372 | "3 Episode: Finished after 41 steps\n",
373 | "4 Episode: Finished after 20 steps\n",
374 | "5 Episode: Finished after 15 steps\n",
375 | "6 Episode: Finished after 19 steps\n",
376 | "7 Episode: Finished after 19 steps\n",
377 | "8 Episode: Finished after 29 steps\n",
378 | "9 Episode: Finished after 23 steps\n",
379 | "10 Episode: Finished after 73 steps\n",
380 | "11 Episode: Finished after 49 steps\n",
381 | "12 Episode: Finished after 66 steps\n",
382 | "13 Episode: Finished after 14 steps\n",
383 | "14 Episode: Finished after 16 steps\n",
384 | "15 Episode: Finished after 74 steps\n",
385 | "16 Episode: Finished after 200 steps\n",
386 | "17 Episode: Finished after 200 steps\n",
387 | "18 Episode: Finished after 182 steps\n",
388 | "19 Episode: Finished after 200 steps\n",
389 | "연속성공\n"
390 | ]
391 | }
392 | ],
393 | "source": [
394 | "# main 실행\n",
395 | "cartpole_env = Environment()\n",
396 | "cartpole_env.run()\n"
397 | ]
398 | },
399 | {
400 | "cell_type": "code",
401 | "execution_count": null,
402 | "metadata": {},
403 | "outputs": [],
404 | "source": []
405 | }
406 | ],
407 | "metadata": {
408 | "kernelspec": {
409 | "display_name": "Python 3",
410 | "language": "python",
411 | "name": "python3"
412 | },
413 | "language_info": {
414 | "codemirror_mode": {
415 | "name": "ipython",
416 | "version": 3
417 | },
418 | "file_extension": ".py",
419 | "mimetype": "text/x-python",
420 | "name": "python",
421 | "nbconvert_exporter": "python",
422 | "pygments_lexer": "ipython3",
423 | "version": "3.6.6"
424 | }
425 | },
426 | "nbformat": 4,
427 | "nbformat_minor": 2
428 | }
429 |
--------------------------------------------------------------------------------
/program/7_breakout_play.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## 7장 벽돌깨기 게임 학습 결과 실행용 프로그램 "
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 2,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "# 구현에 사용할 패키지 임포트\n",
17 | "import numpy as np\n",
18 | "from collections import deque\n",
19 | "from tqdm import tqdm\n",
20 | "\n",
21 | "import torch\n",
22 | "import torch.nn as nn\n",
23 | "import torch.nn.functional as F\n",
24 | "import torch.optim as optim\n",
25 | "\n",
26 | "import gym\n",
27 | "from gym import spaces\n",
28 | "from gym.spaces.box import Box\n"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 2,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "# 구현에 사용할 패키지 임포트\n",
38 | "import matplotlib.pyplot as plt\n",
39 | "%matplotlib inline\n",
40 | "\n",
41 | "\n",
42 | "# 애니메이션을 만드는 함수\n",
43 | "# 참고 URL http://nbviewer.jupyter.org/github/patrickmineault\n",
44 | "# /xcorr-notebooks/blob/master/Render%20OpenAI%20gym%20as%20GIF.ipynb\n",
45 | "from JSAnimation.IPython_display import display_animation\n",
46 | "from matplotlib import animation\n",
47 | "from IPython.display import display\n",
48 | "\n",
49 | "def display_frames_as_gif(frames):\n",
50 | " \"\"\"\n",
51 | " Displays a list of frames as a gif, with controls\n",
52 | " \"\"\"\n",
53 | " plt.figure(figsize=(frames[0].shape[1]/72.0*1, frames[0].shape[0]/72.0*1),\n",
54 | " dpi=72)\n",
55 | " patch = plt.imshow(frames[0])\n",
56 | " plt.axis('off')\n",
57 | " \n",
58 | " def animate(i):\n",
59 | " patch.set_data(frames[i])\n",
60 | " \n",
61 | " anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames),\n",
62 | " interval=20)\n",
63 | " \n",
64 | " anim.save('breakout.mp4') # 애니메이션을 저장하는 부분\n",
65 | " display(display_animation(anim, default_mode='loop'))\n",
66 | " "
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 3,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "# 실행환경 설정\n",
76 | "# 参考:https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py\n",
77 | "\n",
78 | "import cv2\n",
79 | "cv2.ocl.setUseOpenCL(False)\n",
80 | "\n",
81 | "\n",
82 | "class NoopResetEnv(gym.Wrapper):\n",
83 | " def __init__(self, env, noop_max=30):\n",
84 | " '''첫 번째 트릭 No-Operation. 초기화 후 일정 단계에 이를때까지 아무 행동도 하지않고\n",
85 | " 게임 초기 상태를 다양하게 하여 특정 시작 상태만 학습하는 것을 방지한다'''\n",
86 | "\n",
87 | " gym.Wrapper.__init__(self, env)\n",
88 | " self.noop_max = noop_max\n",
89 | " self.override_num_noops = None\n",
90 | " self.noop_action = 0\n",
91 | " assert env.unwrapped.get_action_meanings()[0] == 'NOOP'\n",
92 | "\n",
93 | " def reset(self, **kwargs):\n",
94 | " \"\"\" Do no-op action for a number of steps in [1, noop_max].\"\"\"\n",
95 | " self.env.reset(**kwargs)\n",
96 | " if self.override_num_noops is not None:\n",
97 | " noops = self.override_num_noops\n",
98 | " else:\n",
99 | " noops = self.unwrapped.np_random.randint(\n",
100 | " 1, self.noop_max + 1) # pylint: disable=E1101\n",
101 | " assert noops > 0\n",
102 | " obs = None\n",
103 | " for _ in range(noops):\n",
104 | " obs, _, done, _ = self.env.step(self.noop_action)\n",
105 | " if done:\n",
106 | " obs = self.env.reset(**kwargs)\n",
107 | " return obs\n",
108 | "\n",
109 | " def step(self, ac):\n",
110 | " return self.env.step(ac)\n",
111 | "\n",
112 | "\n",
113 | "class EpisodicLifeEnv(gym.Wrapper):\n",
114 | " def __init__(self, env):\n",
115 | " '''두 번째 트릭 Episodic Life. 한번 실패를 게임 종료로 간주하나, 다음 게임을 같은 블록 상태로 시작'''\n",
116 | " gym.Wrapper.__init__(self, env)\n",
117 | " self.lives = 0\n",
118 | " self.was_real_done = True\n",
119 | "\n",
120 | " def step(self, action):\n",
121 | " obs, reward, done, info = self.env.step(action)\n",
122 | " self.was_real_done = done\n",
123 | " # check current lives, make loss of life terminal,\n",
124 | " # then update lives to handle bonus lives\n",
125 | " lives = self.env.unwrapped.ale.lives()\n",
126 | " if lives < self.lives and lives > 0:\n",
127 | " # for Qbert sometimes we stay in lives == 0 condtion for a few frames\n",
128 | " # so its important to keep lives > 0, so that we only reset once\n",
129 | " # the environment advertises done.\n",
130 | " done = True\n",
131 | " self.lives = lives\n",
132 | " return obs, reward, done, info\n",
133 | "\n",
134 | " def reset(self, **kwargs):\n",
135 | " '''5번 실패하면 게임을 완전히 다시 시작'''\n",
136 | " if self.was_real_done:\n",
137 | " obs = self.env.reset(**kwargs)\n",
138 | " else:\n",
139 | " # no-op step to advance from terminal/lost life state\n",
140 | " obs, _, _, _ = self.env.step(0)\n",
141 | " self.lives = self.env.unwrapped.ale.lives()\n",
142 | " return obs\n",
143 | "\n",
144 | "\n",
145 | "class MaxAndSkipEnv(gym.Wrapper):\n",
146 | " def __init__(self, env, skip=4):\n",
147 | " '''세 번째 트릭 Max and Skip. 4프레임 동안 같은 행동을 지속하되, 3번째와 4번째 프레임의 최댓값 이미지를 관측 obs로 삼는다'''\n",
148 | " gym.Wrapper.__init__(self, env)\n",
149 | " # most recent raw observations (for max pooling across time steps)\n",
150 | " self._obs_buffer = np.zeros(\n",
151 | " (2,)+env.observation_space.shape, dtype=np.uint8)\n",
152 | " self._skip = skip\n",
153 | "\n",
154 | " def step(self, action):\n",
155 | " \"\"\"Repeat action, sum reward, and max over last observations.\"\"\"\n",
156 | " total_reward = 0.0\n",
157 | " done = None\n",
158 | " for i in range(self._skip):\n",
159 | " obs, reward, done, info = self.env.step(action)\n",
160 | " if i == self._skip - 2:\n",
161 | " self._obs_buffer[0] = obs\n",
162 | " if i == self._skip - 1:\n",
163 | " self._obs_buffer[1] = obs\n",
164 | " total_reward += reward\n",
165 | " if done:\n",
166 | " break\n",
167 | " # Note that the observation on the done=True frame\n",
168 | " # doesn't matter\n",
169 | " max_frame = self._obs_buffer.max(axis=0)\n",
170 | "\n",
171 | " return max_frame, total_reward, done, info\n",
172 | "\n",
173 | " def reset(self, **kwargs):\n",
174 | " return self.env.reset(**kwargs)\n",
175 | "\n",
176 | "\n",
177 | "class WarpFrame(gym.ObservationWrapper):\n",
178 | " def __init__(self, env):\n",
179 | " '''네 번째 트릭 Warp frame. DQN 네이처 논문 구현과 같이 84*84 흑백 이미지를 사용'''\n",
180 | " gym.ObservationWrapper.__init__(self, env)\n",
181 | " self.width = 84\n",
182 | " self.height = 84\n",
183 | " self.observation_space = spaces.Box(low=0, high=255,\n",
184 | " shape=(self.height, self.width, 1), dtype=np.uint8)\n",
185 | "\n",
186 | " def observation(self, frame):\n",
187 | " frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)\n",
188 | " frame = cv2.resize(frame, (self.width, self.height),\n",
189 | " interpolation=cv2.INTER_AREA)\n",
190 | " return frame[:, :, None]\n",
191 | "\n",
192 | "\n",
193 | "class WrapPyTorch(gym.ObservationWrapper):\n",
194 | " def __init__(self, env=None):\n",
195 | " '''인덱스 순서를 파이토치 미니배치와 같이 조정하는 래퍼'''\n",
196 | " super(WrapPyTorch, self).__init__(env)\n",
197 | " obs_shape = self.observation_space.shape\n",
198 | " self.observation_space = Box(\n",
199 | " self.observation_space.low[0, 0, 0],\n",
200 | " self.observation_space.high[0, 0, 0],\n",
201 | " [obs_shape[2], obs_shape[1], obs_shape[0]],\n",
202 | " dtype=self.observation_space.dtype)\n",
203 | "\n",
204 | " def observation(self, observation):\n",
205 | " return observation.transpose(2, 0, 1)\n"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": 4,
211 | "metadata": {},
212 | "outputs": [],
213 | "source": [
214 | "# 再生用の実行環境\n",
215 | "\n",
216 | "\n",
217 | "class EpisodicLifeEnvPlay(gym.Wrapper):\n",
218 | " def __init__(self, env):\n",
219 | " '''두 번째 트릭 Episodic Life. 한번 실패를 게임 종료로 간주하나, 다음 게임을 같은 블록 상태로 시작\n",
220 | " 그러나 여기서는 학습 결과 플레이를 위해 한번 실패에도 블록 상태까지 초기화한다'''\n",
221 | "\n",
222 | " gym.Wrapper.__init__(self, env)\n",
223 | "\n",
224 | " def step(self, action):\n",
225 | " obs, reward, done, info = self.env.step(action)\n",
226 | " # 처음 5개 라이프를 갖고 시작하지만, 하나만 잃어도 종료한다\n",
227 | " if self.env.unwrapped.ale.lives() < 5:\n",
228 | " done = True\n",
229 | "\n",
230 | " return obs, reward, done, info\n",
231 | "\n",
232 | " def reset(self, **kwargs):\n",
233 | " '''한번이라도 실패하면 완전히 초기화'''\n",
234 | "\n",
235 | " obs = self.env.reset(**kwargs)\n",
236 | "\n",
237 | " return obs\n",
238 | "\n",
239 | "\n",
240 | "class MaxAndSkipEnvPlay(gym.Wrapper):\n",
241 | " def __init__(self, env, skip=4):\n",
242 | " '''세 번째 트릭 Max and Skip. 4프레임 동안 같은 행동을 지속하되, 4번째 프레임 이미지를 관측 obs로 삼는다'''\n",
243 | " gym.Wrapper.__init__(self, env)\n",
244 | " # most recent raw observations (for max pooling across time steps)\n",
245 | " self._obs_buffer = np.zeros(\n",
246 | " (2,)+env.observation_space.shape, dtype=np.uint8)\n",
247 | " self._skip = skip\n",
248 | "\n",
249 | " def step(self, action):\n",
250 | " \"\"\"Repeat action, sum reward, and max over last observations.\"\"\"\n",
251 | " total_reward = 0.0\n",
252 | " done = None\n",
253 | " for i in range(self._skip):\n",
254 | " obs, reward, done, info = self.env.step(action)\n",
255 | " if i == self._skip - 2:\n",
256 | " self._obs_buffer[0] = obs\n",
257 | " if i == self._skip - 1:\n",
258 | " self._obs_buffer[1] = obs\n",
259 | " total_reward += reward\n",
260 | " if done:\n",
261 | " break\n",
262 | "\n",
263 | " return obs, total_reward, done, info\n",
264 | "\n",
265 | " def reset(self, **kwargs):\n",
266 | " return self.env.reset(**kwargs)\n"
267 | ]
268 | },
269 | {
270 | "cell_type": "code",
271 | "execution_count": 5,
272 | "metadata": {},
273 | "outputs": [],
274 | "source": [
275 | "# 실행환경 생성 함수\n",
276 | "\n",
277 | "# 병렬 실행환경\n",
278 | "from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv\n",
279 | "\n",
280 | "\n",
281 | "def make_env(env_id, seed, rank):\n",
282 | " def _thunk():\n",
283 | " '''멀티 프로세스로 동작하는 환경 SubprocVecEnv를 실행하기 위해 필요하다'''\n",
284 | "\n",
285 | " env = gym.make(env_id)\n",
286 | " #env = NoopResetEnv(env, noop_max=30)\n",
287 | " env = MaxAndSkipEnv(env, skip=4)\n",
288 | " env.seed(seed + rank) # 난수 시드 설정\n",
289 | " #env = EpisodicLifeEnv(env)\n",
290 | " env = EpisodicLifeEnvPlay(env)\n",
291 | " env = WarpFrame(env)\n",
292 | " env = WrapPyTorch(env)\n",
293 | "\n",
294 | " return env\n",
295 | "\n",
296 | " return _thunk\n",
297 | "\n",
298 | "\n",
299 | "def make_env_play(env_id, seed, rank):\n",
300 | " '''학습 결과 시연용 실행환경'''\n",
301 | " env = gym.make(env_id)\n",
302 | " #env = NoopResetEnv(env, noop_max=30)\n",
303 | " #env = MaxAndSkipEnv(env, skip=4)\n",
304 | " env = MaxAndSkipEnvPlay(env, skip=4)\n",
305 | " env.seed(seed + rank) # 난수 시드 설정\n",
306 | " env = EpisodicLifeEnvPlay(env)\n",
307 | " #env = WarpFrame(env)\n",
308 | " #env = WrapPyTorch(env)\n",
309 | "\n",
310 | " return env\n"
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": 6,
316 | "metadata": {},
317 | "outputs": [],
318 | "source": [
319 | "# 상수 정의\n",
320 | "\n",
321 | "ENV_NAME = 'BreakoutNoFrameskip-v4' \n",
322 | "# Breakout-v0 대신 BreakoutNoFrameskip-v4을 사용\n",
323 | "# v0은 2~4개 프레임을 자동으로 생략하므로 이 기능이 없는 버전을 사용한다\n",
324 | "# 참고 URL https://becominghuman.ai/lets-build-an-atari-ai-part-1-dqn-df57e8ff3b26\n",
325 | "# https://github.com/openai/gym/blob/5cb12296274020db9bb6378ce54276b31e7002da/gym/envs/__init__.py#L371\n",
326 | " \n",
327 | "NUM_SKIP_FRAME = 4 # 생략할 프레임 수\n",
328 | "NUM_STACK_FRAME = 4 # 하나의 상태로 사용할 프레임의 수\n",
329 | "NOOP_MAX = 30 # 초기화 후 No-operation을 적용할 최초 프레임 수의 최댓값\n",
330 | "NUM_PROCESSES = 16 # 병렬로 실행할 프로세스 수\n",
331 | "NUM_ADVANCED_STEP = 5 # Advanced 학습할 단계 수\n",
332 | "GAMMA = 0.99 # 시간할인율\n",
333 | "\n",
334 | "TOTAL_FRAMES=10e6 # 학습에 사용하는 총 프레임 수\n",
335 | "NUM_UPDATES = int(TOTAL_FRAMES / NUM_ADVANCED_STEP / NUM_PROCESSES) # 신경망 수정 총 횟수\n",
336 | "# NUM_UPDATES는 약 125,000이 됨\n"
337 | ]
338 | },
339 | {
340 | "cell_type": "code",
341 | "execution_count": 7,
342 | "metadata": {},
343 | "outputs": [],
344 | "source": [
345 | "# A2C 손실함수를 계산하기 위한 상수\n",
346 | "value_loss_coef = 0.5\n",
347 | "entropy_coef = 0.01\n",
348 | "max_grad_norm = 0.5\n",
349 | "\n",
350 | "# 최적회 기법 RMSprop에 대한 설정\n",
351 | "lr = 7e-4\n",
352 | "eps = 1e-5\n",
353 | "alpha = 0.99\n"
354 | ]
355 | },
356 | {
357 | "cell_type": "code",
358 | "execution_count": 3,
359 | "metadata": {},
360 | "outputs": [
361 | {
362 | "name": "stdout",
363 | "output_type": "stream",
364 | "text": [
365 | "cuda\n"
366 | ]
367 | }
368 | ],
369 | "source": [
370 | "# GPU 사용 설정\n",
371 | "use_cuda = torch.cuda.is_available()\n",
372 | "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
373 | "print(device)\n"
374 | ]
375 | },
376 | {
377 | "cell_type": "code",
378 | "execution_count": 9,
379 | "metadata": {},
380 | "outputs": [],
381 | "source": [
382 | "# 메모리 클래스 정의\n",
383 | "\n",
384 | "\n",
385 | "class RolloutStorage(object):\n",
386 | " '''Advantage 학습에 사용하는 메모리 클래스'''\n",
387 | "\n",
388 | " def __init__(self, num_steps, num_processes, obs_shape):\n",
389 | "\n",
390 | " self.observations = torch.zeros(\n",
391 | " num_steps + 1, num_processes, *obs_shape).to(device)\n",
392 | " # *로 리스트의 요소를 풀어낸다(unpack)\n",
393 | " # obs_shape→(4,84,84)\n",
394 | " # *obs_shape→ 4 84 84\n",
395 | "\n",
396 | " self.masks = torch.ones(num_steps + 1, num_processes, 1).to(device)\n",
397 | " self.rewards = torch.zeros(num_steps, num_processes, 1).to(device)\n",
398 | " self.actions = torch.zeros(\n",
399 | " num_steps, num_processes, 1).long().to(device)\n",
400 | "\n",
401 | " # 할인 총보상을 저장\n",
402 | " self.returns = torch.zeros(num_steps + 1, num_processes, 1).to(device)\n",
403 | " self.index = 0 # 저장할 인덱스\n",
404 | "\n",
405 | " def insert(self, current_obs, action, reward, mask):\n",
406 | " '''인덱스가 가리키는 다음 자리에 transition을 저장'''\n",
407 | " self.observations[self.index + 1].copy_(current_obs)\n",
408 | " self.masks[self.index + 1].copy_(mask)\n",
409 | " self.rewards[self.index].copy_(reward)\n",
410 | " self.actions[self.index].copy_(action)\n",
411 | "\n",
412 | " self.index = (self.index + 1) % NUM_ADVANCED_STEP # 인덱스 업데이트\n",
413 | "\n",
414 | " def after_update(self):\n",
415 | " '''Advantage 학습 단계 수만큼 단계가 진행되면 가장 최근 단계를 index0에 저장'''\n",
416 | " self.observations[0].copy_(self.observations[-1])\n",
417 | " self.masks[0].copy_(self.masks[-1])\n",
418 | "\n",
419 | " def compute_returns(self, next_value):\n",
420 | " '''Advantage 학습 단계에 들어가는 각 단계에 대해 할인 총보상을 계산'''\n",
421 | "\n",
422 | " # 주의 : 5번째 단계부터 거슬러 올라가며 계산\n",
423 | " # 주의 : 5번째 단계가 Advantage1, 4번째 단계가 Advantage2가 되는 식임\n",
424 | " self.returns[-1] = next_value\n",
425 | " for ad_step in reversed(range(self.rewards.size(0))):\n",
426 | " self.returns[ad_step] = self.returns[ad_step + 1] * \\\n",
427 | " GAMMA * self.masks[ad_step + 1] + self.rewards[ad_step]\n"
428 | ]
429 | },
430 | {
431 | "cell_type": "code",
432 | "execution_count": 10,
433 | "metadata": {},
434 | "outputs": [],
435 | "source": [
436 | "# A2C 신경망 구성\n",
437 | "\n",
438 | "\n",
439 | "def init(module, gain):\n",
440 | " '''결합 가중치를 초기화하는 함수'''\n",
441 | " nn.init.orthogonal_(module.weight.data, gain=gain)\n",
442 | " nn.init.constant_(module.bias.data, 0)\n",
443 | " return module\n",
444 | "\n",
445 | "\n",
446 | "class Flatten(nn.Module):\n",
447 | " '''합성곱층의 출력 이미지를 1차원으로 변환하는 층'''\n",
448 | "\n",
449 | " def forward(self, x):\n",
450 | " return x.view(x.size(0), -1)\n",
451 | "\n",
452 | "\n",
453 | "class Net(nn.Module):\n",
454 | " def __init__(self, n_out):\n",
455 | " super(Net, self).__init__()\n",
456 | "\n",
457 | " # 결합 가중치 초기화 함수\n",
458 | " def init_(module): return init(\n",
459 | " module, gain=nn.init.calculate_gain('relu'))\n",
460 | "\n",
461 | " # 합성곱층을 정의\n",
462 | " self.conv = nn.Sequential(\n",
463 | " # 이미지 크기의 변화 (84*84 -> 20*20)\n",
464 | " init_(nn.Conv2d(NUM_STACK_FRAME, 32, kernel_size=8, stride=4)),\n",
465 | " # 프레임 4개를 합치므로 input=NUM_STACK_FRAME=4가 된다. 출력은 32이다.\n",
466 | " # size 계산 size = (Input_size - Kernel_size + 2*Padding_size)/ Stride_size + 1\n",
467 | "\n",
468 | " nn.ReLU(),\n",
469 | " # 이미지 크기의 변화 (20*20 -> 9*9)\n",
470 | " init_(nn.Conv2d(32, 64, kernel_size=4, stride=2)),\n",
471 | " nn.ReLU(),\n",
472 | " init_(nn.Conv2d(64, 64, kernel_size=3, stride=1)), # 이미지 크기의 변화(9*9 -> 7*7)\n",
473 | " nn.ReLU(),\n",
474 | " Flatten(), # 이미지를 1차원으로 변환\n",
475 | " init_(nn.Linear(64 * 7 * 7, 512)), # 7*7 이미지 64개를 512차원으로 변환\n",
476 | " nn.ReLU()\n",
477 | " )\n",
478 | "\n",
479 | " # 결합 가중치 초기화 함수\n",
480 | " def init_(module): return init(module, gain=1.0)\n",
481 | "\n",
482 | " # Critic을 정의\n",
483 | " self.critic = init_(nn.Linear(512, 1)) # 출력은 상태가치이므로 1개\n",
484 | "\n",
485 | " # 결합 가중치 초기화 함수\n",
486 | " def init_(module): return init(module, gain=0.01)\n",
487 | "\n",
488 | " # Actor를 정의\n",
489 | " self.actor = init_(nn.Linear(512, n_out)) # 출력이 행동이므로 출력 수는 행동의 가짓수\n",
490 | " \n",
491 | " # 신경망을 학습 모드로 전환\n",
492 | " self.train()\n",
493 | "\n",
494 | " def forward(self, x):\n",
495 | " '''신경망의 순전파 계산 정의'''\n",
496 | " input = x / 255.0 # 이미지의 픽셀값을 [0,255]에서 [0,1] 구간으로 정규화\n",
497 | " conv_output = self.conv(input) # 합성곱층 계산\n",
498 | " critic_output = self.critic(conv_output) # 상태가치 출력 계산\n",
499 | " actor_output = self.actor(conv_output) # 행동 출력 계산\n",
500 | "\n",
501 | " return critic_output, actor_output\n",
502 | "\n",
503 | " def act(self, x):\n",
504 | " '''상태 x일때 취할 확률을 확률적으로 구함'''\n",
505 | " value, actor_output = self(x)\n",
506 | " probs = F.softmax(actor_output, dim=1) # dim=1で行動の種類方向に計算\n",
507 | " action = probs.multinomial(num_samples=1)\n",
508 | "\n",
509 | " return action\n",
510 | "\n",
511 | " def get_value(self, x):\n",
512 | " '''상태 x의 상태가치를 구함'''\n",
513 | " value, actor_output = self(x)\n",
514 | "\n",
515 | " return value\n",
516 | "\n",
517 | " def evaluate_actions(self, x, actions):\n",
518 | " '''상태 x의 상태가치, 실제 행동 actions의 로그 확률, 엔트로피를 구함'''\n",
519 | " value, actor_output = self(x)\n",
520 | "\n",
521 | " log_probs = F.log_softmax(actor_output, dim=1) # dim=1이므로 행동의 종류 방향으로 계산\n",
522 | " action_log_probs = log_probs.gather(1, actions) # 실제 행동에 대한 log_probs 계산\n",
523 | "\n",
524 | " probs = F.softmax(actor_output, dim=1) # dim=1이므로 행동의 종류 방향으로 계산\n",
525 | " dist_entropy = -(log_probs * probs).sum(-1).mean()\n",
526 | "\n",
527 | " return value, action_log_probs, dist_entropy\n"
528 | ]
529 | },
530 | {
531 | "cell_type": "code",
532 | "execution_count": 11,
533 | "metadata": {},
534 | "outputs": [],
535 | "source": [
536 | "# 에이전트의 두뇌 역할을 하는 클래스로, 모든 에이전트가 공유한다\n",
537 | "\n",
538 | "\n",
539 | "class Brain(object):\n",
540 | " def __init__(self, actor_critic):\n",
541 | "\n",
542 | " self.actor_critic = actor_critic # actor_critic은 Net클래스로 구현한 신경망이다\n",
543 | "\n",
544 | " # 이미 학습된 결합 가중치를 로드하려면\n",
545 | " filename = 'weight_end.pth'\n",
546 | " #filename = 'weight_112500.pth'\n",
547 | " param = torch.load(filename, map_location='cpu')\n",
548 | " self.actor_critic.load_state_dict(param)\n",
549 | "\n",
550 | " # 가중치를 학습하는 최적화 알고리즘 설정\n",
551 | " self.optimizer = optim.RMSprop(\n",
552 | " actor_critic.parameters(), lr=lr, eps=eps, alpha=alpha)\n",
553 | "\n",
554 | " def update(self, rollouts):\n",
555 | " '''advanced 학습 대상 5단계를 모두 사용하여 수정한다'''\n",
556 | " obs_shape = rollouts.observations.size()[2:] # torch.Size([4, 84, 84])\n",
557 | " num_steps = NUM_ADVANCED_STEP\n",
558 | " num_processes = NUM_PROCESSES\n",
559 | "\n",
560 | " values, action_log_probs, dist_entropy = self.actor_critic.evaluate_actions(\n",
561 | " rollouts.observations[:-1].view(-1, *obs_shape),\n",
562 | " rollouts.actions.view(-1, 1))\n",
563 | "\n",
564 | " # 각 변수의 크기에 주의할 것\n",
565 | " # rollouts.observations[:-1].view(-1, *obs_shape) torch.Size([80, 4, 84, 84])\n",
566 | " # rollouts.actions.view(-1, 1) torch.Size([80, 1])\n",
567 | " # values torch.Size([80, 1])\n",
568 | " # action_log_probs torch.Size([80, 1])\n",
569 | " # dist_entropy torch.Size([])\n",
570 | "\n",
571 | " values = values.view(num_steps, num_processes,\n",
572 | " 1) # torch.Size([5, 16, 1])\n",
573 | " action_log_probs = action_log_probs.view(num_steps, num_processes, 1)\n",
574 | "\n",
575 | " advantages = rollouts.returns[:-1] - values # torch.Size([5, 16, 1])\n",
576 | " value_loss = advantages.pow(2).mean()\n",
577 | "\n",
578 | " action_gain = (advantages.detach() * action_log_probs).mean()\n",
579 | " # advantages는 detach 하여 정수로 취급한다\n",
580 | "\n",
581 | " total_loss = (value_loss * value_loss_coef -\n",
582 | " action_gain - dist_entropy * entropy_coef)\n",
583 | "\n",
584 | " self.optimizer.zero_grad() # 경사 초기화\n",
585 | " total_loss.backward() # 역전파 계산\n",
586 | " nn.utils.clip_grad_norm_(self.actor_critic.parameters(), max_grad_norm)\n",
587 | " # 한번에 결합 가중치가 너무 크게 변화하지 않도록, 경사의 최댓값을 0.5로 제한한다\n",
588 | "\n",
589 | " self.optimizer.step() # 결합 가중치 수정\n"
590 | ]
591 | },
592 | {
593 | "cell_type": "code",
594 | "execution_count": 12,
595 | "metadata": {},
596 | "outputs": [],
597 | "source": [
598 | "# Breakout을 실행하는 환경 클래스\n",
599 | "\n",
600 | "NUM_PROCESSES = 1\n",
601 | "\n",
602 | "\n",
603 | "class Environment:\n",
604 | " def run(self):\n",
605 | "\n",
606 | " # 난수 시드 설정\n",
607 | " seed_num = 1\n",
608 | " torch.manual_seed(seed_num)\n",
609 | " if use_cuda:\n",
610 | " torch.cuda.manual_seed(seed_num)\n",
611 | "\n",
612 | " # 실행환경 구축\n",
613 | " torch.set_num_threads(seed_num)\n",
614 | " envs = [make_env(ENV_NAME, seed_num, i) for i in range(NUM_PROCESSES)]\n",
615 | " envs = SubprocVecEnv(envs) # 멀티프로세스 실행환경\n",
616 | "\n",
617 | " # 모든 에이전트가 공유하는 두뇌 역할 클래스 Brain 객체 생성\n",
618 | " n_out = envs.action_space.n # 행동의 가짓수는 4\n",
619 | " actor_critic = Net(n_out).to(device) # GPU 사용\n",
620 | " global_brain = Brain(actor_critic)\n",
621 | "\n",
622 | " # 정보 저장용 변수 생성\n",
623 | " obs_shape = envs.observation_space.shape # (1, 84, 84)\n",
624 | " obs_shape = (obs_shape[0] * NUM_STACK_FRAME,\n",
625 | " *obs_shape[1:]) # (4, 84, 84)\n",
626 | " # torch.Size([16, 4, 84, 84])\n",
627 | " current_obs = torch.zeros(NUM_PROCESSES, *obs_shape).to(device)\n",
628 | " rollouts = RolloutStorage(\n",
629 | " NUM_ADVANCED_STEP, NUM_PROCESSES, obs_shape) # rollouts 객체\n",
630 | " episode_rewards = torch.zeros([NUM_PROCESSES, 1]) # 현재 에피소드에서 받을 보상 저장\n",
631 | " final_rewards = torch.zeros([NUM_PROCESSES, 1]) # 마지막 에피소드의 총 보상 저장\n",
632 | "\n",
633 | " # 초기 상태로 시작\n",
634 | " obs = envs.reset()\n",
635 | " obs = torch.from_numpy(obs).float() # torch.Size([16, 1, 84, 84])\n",
636 | " current_obs[:, -1:] = obs # 4번째 프레임에 가장 최근 관측결과를 저장\n",
637 | "\n",
638 | " # advanced 학습에 사용할 객체 rollouts에 첫번째 상태로 현재 상태를 저장\n",
639 | " rollouts.observations[0].copy_(current_obs)\n",
640 | "\n",
641 | " # 애니메이션 생성용 환경(시연용 추가)\n",
642 | " env_play = make_env_play(ENV_NAME, seed_num, 0)\n",
643 | " obs_play = env_play.reset()\n",
644 | "\n",
645 | " # 애니메이션 생성을 위해 이미지를 저장할 변수(시연용 추가)\n",
646 | " frames = []\n",
647 | " main_end = False\n",
648 | "\n",
649 | " # 주 반복문\n",
650 | " for j in tqdm(range(NUM_UPDATES)):\n",
651 | "\n",
652 | " # 보상이 기준을 넘어서면 종료 (시연용 추가)\n",
653 | " if main_end:\n",
654 | " break\n",
655 | "\n",
656 | " # advanced 학습 범위에 들어가는 단계마다 반복\n",
657 | " for step in range(NUM_ADVANCED_STEP):\n",
658 | "\n",
659 | " # 행동을 결정\n",
660 | " with torch.no_grad():\n",
661 | " action = actor_critic.act(rollouts.observations[step])\n",
662 | "\n",
663 | " cpu_actions = action.squeeze(1).cpu().numpy() # tensor를 NumPy 변수로\n",
664 | "\n",
665 | " # 1단계를 병렬로 실행, 반환값 obs의 크기는 (16, 1, 84, 84)\n",
666 | " obs, reward, done, info = envs.step(cpu_actions)\n",
667 | "\n",
668 | " # 보상을 텐서로 변환한 다음 에피소드 총 보상에 더함\n",
669 | " # 크기가 (16,)인 것을 (16, 1)로 변환\n",
670 | " reward = np.expand_dims(np.stack(reward), 1)\n",
671 | " reward = torch.from_numpy(reward).float()\n",
672 | " episode_rewards += reward\n",
673 | "\n",
674 | " # 각 프로세스마다 done이 True이면 0, False이면 1\n",
675 | " masks = torch.FloatTensor(\n",
676 | " [[0.0] if done_ else [1.0] for done_ in done])\n",
677 | "\n",
678 | " # 마지막 에피소드의 총 보상을 업데이트\n",
679 | " final_rewards *= masks # done이 True이면 0을 곱하고, False이면 1을 곱하여 리셋\n",
680 | " # done이 False이면 0을 더하고, True이면 epicodic_rewards를 더함\n",
681 | " final_rewards += (1 - masks) * episode_rewards\n",
682 | "\n",
683 | " # 이미지를 구함(시연용 추가)\n",
684 | " obs_play, reward_play, _, _ = env_play.step(cpu_actions[0])\n",
685 | " frames.append(obs_play) # 변환한 이미지를 저장\n",
686 | " if done[0]: # 첫번째 프로세스가 종료된 경우\n",
687 | " print(episode_rewards[0][0].numpy()) # 보상\n",
688 | "\n",
689 | " # 보상이 300을 초과하면 종료\n",
690 | " if (episode_rewards[0][0].numpy()) > 300:\n",
691 | " main_end = True\n",
692 | " break\n",
693 | " else:\n",
694 | " obs_view = env_play.reset()\n",
695 | " frames = [] # 저장한 이미지를 리셋\n",
696 | "\n",
697 | " # 에피소드의 총 보상을 업데이트\n",
698 | " episode_rewards *= masks # 각 프로세스마다 done이 True이면 0, False이면 1을 곱함\n",
699 | "\n",
700 | " # masks 변수를 GPU로 전달\n",
701 | " masks = masks.to(device)\n",
702 | "\n",
703 | " # done이 True이면 모두 0으로\n",
704 | " # mask의 크기를 torch.Size([16, 1]) --> torch.Size([16, 1, 1 ,1])로 변환하고 곱함\n",
705 | " current_obs *= masks.unsqueeze(2).unsqueeze(2)\n",
706 | "\n",
707 | " # 프레임을 모음\n",
708 | " # torch.Size([16, 1, 84, 84])\n",
709 | " obs = torch.from_numpy(obs).float()\n",
710 | " current_obs[:, :-1] = current_obs[:, 1:] # 0~2번째 프레임을 1~3번째 프레임으로 덮어씀\n",
711 | " current_obs[:, -1:] = obs # 4번째 프레임에 가장 최근 obs를 저장\n",
712 | "\n",
713 | " # 메모리 객체에 현 단계의 transition을 저장\n",
714 | " rollouts.insert(current_obs, action.data, reward, masks)\n",
715 | "\n",
716 | " # advanced 학습의 for문 끝\n",
717 | "\n",
718 | " # advanced 학습 대상 단계 중 마지막 단계의 상태에서 예상되는 상태가치를 계산\n",
719 | " with torch.no_grad():\n",
720 | " next_value = actor_critic.get_value(\n",
721 | " rollouts.observations[-1]).detach()\n",
722 | "\n",
723 | " # 모든 단계의 할인 총보상을 계산하고, rollouts의 변수 returns를 업데이트\n",
724 | " rollouts.compute_returns(next_value)\n",
725 | "\n",
726 | " # 신경망 수정 및 rollout 업데이트\n",
727 | " # global_brain.update(rollouts)\n",
728 | " rollouts.after_update()\n",
729 | "\n",
730 | " # 주 반복문 끝\n",
731 | " display_frames_as_gif(frames) # 애니메이션 저장 및 재생\n"
732 | ]
733 | },
734 | {
735 | "cell_type": "code",
736 | "execution_count": null,
737 | "metadata": {},
738 | "outputs": [],
739 | "source": [
740 | "# 실행\n",
741 | "breakout_env = Environment()\n",
742 | "frames = breakout_env.run()"
743 | ]
744 | }
745 | ],
746 | "metadata": {
747 | "kernelspec": {
748 | "display_name": "Python 3",
749 | "language": "python",
750 | "name": "python3"
751 | },
752 | "language_info": {
753 | "codemirror_mode": {
754 | "name": "ipython",
755 | "version": 3
756 | },
757 | "file_extension": ".py",
758 | "mimetype": "text/x-python",
759 | "name": "python",
760 | "nbconvert_exporter": "python",
761 | "pygments_lexer": "ipython3",
762 | "version": "3.6.6"
763 | }
764 | },
765 | "nbformat": 4,
766 | "nbformat_minor": 2
767 | }
768 |
--------------------------------------------------------------------------------
/program/weight_end.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wikibook/pytorch-drl/12643396c120f619fe9602a8aa66b4421c794b66/program/weight_end.pth
--------------------------------------------------------------------------------