├── .gitignore
├── Monte Carlo methods
│   ├── .ipynb_checkpoints
│   │   └── engine.py-checkpoint.ipynb
│   ├── __pycache__
│   │   ├── car.cpython-36.pyc
│   │   └── environment.cpython-36.pyc
│   ├── car.py
│   ├── donut.txt
│   ├── engine.py.ipynb
│   ├── environment.py
│   ├── racetrack1_32x17.txt
│   ├── racetrack2_30x32.txt
│   ├── racetrack2_30x32_2.txt
│   └── tests
│       ├── __pycache__
│       │   └── test_car.cpython-36.pyc
│       ├── fixtures
│       │   └── scenario1.txt
│       └── test_car.py
├── a2c
│   ├── cartpole_a2c_episodic.ipynb
│   ├── cartpole_a2c_online.ipynb
│   ├── pendulum_a2c_episodic.ipynb
│   ├── pendulum_a2c_online-Copy2.ipynb
│   ├── pendulum_a2c_online.ipynb
│   └── pendulum_a2c_online_agents.ipynb
├── dynamic_programming
│   ├── .ipynb_checkpoints
│   │   ├── gambler-checkpoint.ipynb
│   │   ├── iterative policy evaluation-checkpoint.ipynb
│   │   ├── jack bueno-checkpoint.ipynb
│   │   ├── jack-Copy1-checkpoint.ipynb
│   │   ├── jack-checkpoint.ipynb
│   │   ├── policy iteration - Jack's car rental-checkpoint.ipynb
│   │   └── policy iteration-checkpoint.ipynb
│   ├── gambler.ipynb
│   ├── iterative policy evaluation.ipynb
│   ├── jack bueno.ipynb
│   ├── jack-Copy1.ipynb
│   ├── jack.ipynb
│   ├── policy iteration - Jack's car rental.ipynb
│   └── policy iteration.ipynb
├── open ai gym
│   ├── .ipynb_checkpoints
│   │   ├── cartpole - AC-checkpoint.ipynb
│   │   ├── mountain car - n-step sarsa-checkpoint.ipynb
│   │   ├── mountain car - sarsa-Copy1-checkpoint.ipynb
│   │   └── mountain car - sarsa-checkpoint.ipynb
│   ├── cartpole - AC.ipynb
│   ├── mountain car - n-step sarsa.ipynb
│   ├── mountain car - sarsa-Copy1.ipynb
│   ├── mountain car - sarsa.ipynb
│   ├── weights.npy
│   └── weights
│       ├── policy-cartpole-3000-iterations.h5
│       └── value-cartpole-3000-iterations.h5
├── ppo
│   └── cartpole_ppo_online.ipynb
└── resources
    └── cartpole-2000-iterations.png

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | .DS_Store
 2 | **/runs
--------------------------------------------------------------------------------
/Monte Carlo methods/__pycache__/car.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hermesdt/reinforcement-learning/eb69484bb6d5a415633eb6e1fc07e12aa193cbb0/Monte Carlo methods/__pycache__/car.cpython-36.pyc
-------------------------------------------------------------------------------- /Monte Carlo methods/__pycache__/environment.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hermesdt/reinforcement-learning/eb69484bb6d5a415633eb6e1fc07e12aa193cbb0/Monte Carlo methods/__pycache__/environment.cpython-36.pyc -------------------------------------------------------------------------------- /Monte Carlo methods/car.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import random 3 | 4 | class Car(): 5 | def __init__(self, environment, select_action_fn, greedy=0.9): 6 | self.environment = environment 7 | self._reset() 8 | self.Q = defaultdict(lambda: defaultdict(lambda: float("-inf"))) 9 | self.returns_sum = defaultdict(float) 10 | self.returns_count = defaultdict(float) 11 | self.actions = [] 12 | self.greedy = greedy 13 | self.select_action_fn = select_action_fn 14 | 15 | for i in [-1, 0, 1]: 16 | for j in [-1, 0, 1]: 17 | if i == 0 and j == 0: continue 18 | self.actions.append((i, j)) 19 | 20 | self.P = {} 21 | 22 | def __str__(self): 23 | return "#" 24 | 25 | def _reset(self): 26 | self.speed = (0, 0) 27 | self.position = self.environment.select_start_position() 28 | 29 | def select_action(self): 30 | return self.select_action_fn(self) 31 | 32 | def train(self, episodes=10, update_policy_each=10): 33 | # self.Q = defaultdict(lambda: defaultdict(float)) 34 | # self.P = {} 35 | # self.returns_sum = defaultdict(float) 36 | # self.returns_count = defaultdict(float) 37 | 38 | for i in range(episodes): 39 | print("\rStarted iteration {} ".format(i), end="") 40 | 41 | steps, rewards = self.play() 42 | R = self._calculate_returns(steps, rewards) 43 | 44 | for (state, action) in R: 45 | reward = R[(state, action)] 46 | self.returns_sum[state] += reward 47 | self.returns_count[state] += 1 48 | q_state = self.Q[state] 49 | q_state[action] = self.returns_sum[state] / self.returns_count[state] 50 | 51 | print("\rFinished iteration {} ".format(i), end="") 52 | 53 | if i % update_policy_each == 0 and i > 0: 54 | self.P = self.calculate_policy(self.Q) 55 | 56 | print("") 57 | 58 | def _calculate_returns(self, steps, rewards): 59 | total_reward = 0 60 | R = defaultdict(float) 61 | for i in range(len(rewards)-1, -1, -1): 62 | state, action = steps[i] 63 | total_reward += rewards[i] 64 | R[(state, action)] = total_reward 65 | 66 | return R 67 | 68 | def calculate_policy(self, Q): 69 | P = {} 70 | for state in Q: 71 | max_value = float("-inf") 72 | max_action = None 73 | 74 | for action in Q[state]: 75 | if Q[state][action] > max_value: 76 | max_value = Q[state][action] 77 | max_action = action 78 | 79 | if max_action: 80 | P[state] = max_action 81 | 82 | return P 83 | 84 | def play(self): 85 | steps = [] 86 | rewards = [] 87 | count = 0 88 | self._reset() 89 | 90 | while True: 91 | reward, old_state, action = self.step() 92 | # if action == (0, 0): continue 93 | 94 | steps.append((old_state, action)) 95 | rewards.append(reward) 96 | count += 1 97 | 98 | if self.environment.is_finish(self.position): 99 | break 100 | 101 | return steps, rewards 102 | 103 | def step(self): 104 | old_state = (self.position, self.speed) 105 | 106 | action = self.select_action() 107 | self.speed = (self.speed[0] + action[0], self.speed[1] + action[1]) 108 | self.speed = (min(4, self.speed[0]), min(4, self.speed[1])) 109 | self.speed = (max(-4, self.speed[0]), max(-4, 
self.speed[1])) 110 | 111 | new_position = (self.position[0] + self.speed[0], self.position[1] + self.speed[1]) 112 | new_position, _path = self.environment.move_to(self, new_position) 113 | self.position = new_position 114 | 115 | if not self.environment.is_start(old_state[0]) and self.environment.is_start(new_position): 116 | self.speed = (0, 0) 117 | 118 | return self.reward(new_position), old_state, action 119 | 120 | def reward(self, new_position): 121 | if self.environment.is_finish(new_position): 122 | return 0 123 | 124 | return -1 125 | -------------------------------------------------------------------------------- /Monte Carlo methods/donut.txt: -------------------------------------------------------------------------------- 1 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ++++++++++++++++ ++++++++++++++++++ 3 | +++++++++++++ +++++++++++++++ 4 | ++++++++++ s ++++++++++++ 5 | ++++++++ +++++++++++++ ++++++++++ 6 | ++++++ +++++++++++++++ ++++++++ 7 | ++++++ +++++++++++++++ ++++++++ 8 | ++++++ +++++++++++++++ ++++++++ 9 | ++++++ +++++++++++++++ ++++++++ 10 | ++++++ +++++++++++++++ ++++++++ 11 | ++++++++ +++++++++++++ ++++++++++ 12 | ++++++++++ ffffffffff ++++++++++++ 13 | +++++++++++++ +++++++++++++++ 14 | ++++++++++++++++ ++++++++++++++++++ 15 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++ -------------------------------------------------------------------------------- /Monte Carlo methods/engine.py.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 444, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from importlib import reload\n", 10 | "import car; reload(car)\n", 11 | "import environment; reload(environment)\n", 12 | "Environment = environment.Environment\n", 13 | "Car = car.Car" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 671, 19 | "metadata": { 20 | "scrolled": true 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "import random\n", 25 | "def select_action_fn(car):\n", 26 | " action, greedy_action = None, None\n", 27 | " greedy_prob = 1 - car.greedy + car.greedy/len(car.actions)\n", 28 | " \n", 29 | " if car.P.get((car.position, car.speed), None):\n", 30 | " greedy_action = car.P[(car.position, car.speed)]\n", 31 | "\n", 32 | " possible_actions = list(filter(lambda action: car.speed[0] + action[0] <= 4, car.actions))\n", 33 | " possible_actions = list(filter(lambda action: car.speed[0] + action[0] >= -4, possible_actions))\n", 34 | " possible_actions = list(filter(lambda action: car.speed[1] + action[1] <= 4, possible_actions))\n", 35 | " possible_actions = list(filter(lambda action: car.speed[1] + action[1] >= -4, possible_actions))\n", 36 | " \n", 37 | " if greedy_action in possible_actions and random.random() <= greedy_prob:\n", 38 | " return greedy_action\n", 39 | " \n", 40 | " return random.choice(list(filter(lambda act: act != greedy_action, possible_actions)))\n", 41 | "\n", 42 | "env = Environment(\"racetrack2_30x32_2.txt\")\n", 43 | "car = Car(env, select_action_fn=select_action_fn, greedy=0.05)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "Started iteration 814 " 56 | ] 57 | } 58 | ], 59 | "source": [ 60 | "car.greedy = 0.1\n", 61 | "car.train(2000, update_policy_each=1)\n", 62 | "len(car.P)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 
67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "car._reset()\n", 72 | "path = []\n", 73 | "path.append(car.position)\n", 74 | "actions = []\n", 75 | "for i in range(5000):\n", 76 | " reward, old_state, action = car.step()\n", 77 | " actions.append(action)\n", 78 | " path.append(car.position)\n", 79 | " if env.is_finish(car.position): break\n", 80 | "\n", 81 | "print(len(path), path[:40],actions[:40])\n", 82 | "env.print(car, path)\n", 83 | " \n" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "# l = list(filter(lambda state: state[0][0] == 29 and state[1] == (0,0), car.P))\n", 93 | "# list(map(lambda k: (k, car.P[k], car.Q[k]), sorted(l, key=lambda state: state[0][1])))" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "import numpy as np\n", 103 | "#car.P = {}\n", 104 | "#car.P = car.calculate_policy(car.Q)\n", 105 | "num_steps = []\n", 106 | "for i in range(100):\n", 107 | " print(\"\\riteration\", i, \" \", end=\"\")\n", 108 | " steps, _ = car.play()\n", 109 | " num_steps.append(len(steps))\n", 110 | "\n", 111 | "np.mean(num_steps), np.median(num_steps)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "import matplotlib.pyplot as plt\n", 121 | "%matplotlib inline\n", 122 | "plt.hist(num_steps, bins=100)\n", 123 | "None" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "# car.P = {}\n", 133 | "car._reset()\n", 134 | "steps, _ = car.play()\n", 135 | "waypoints = [state[0] for (state, action) in steps]\n", 136 | "print(len(waypoints))\n", 137 | "env.print(car, waypoints)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [] 146 | } 147 | ], 148 | "metadata": { 149 | "kernelspec": { 150 | "display_name": "Python 3", 151 | "language": "python", 152 | "name": "python3" 153 | }, 154 | "language_info": { 155 | "codemirror_mode": { 156 | "name": "ipython", 157 | "version": 3 158 | }, 159 | "file_extension": ".py", 160 | "mimetype": "text/x-python", 161 | "name": "python", 162 | "nbconvert_exporter": "python", 163 | "pygments_lexer": "ipython3", 164 | "version": "3.6.3" 165 | } 166 | }, 167 | "nbformat": 4, 168 | "nbformat_minor": 2 169 | } 170 | -------------------------------------------------------------------------------- /Monte Carlo methods/environment.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | class Environment(): 4 | def __init__(self, filename='racetrack1_32x17.txt'): 5 | self.scenario = self._load_scenario(filename) 6 | self.finish = "f" 7 | self.start = "s" 8 | self.track = " " 9 | self.wall = "+" 10 | self.waypoint = "·" 11 | 12 | def _load_scenario(self, filename): 13 | with open(filename) as f: 14 | scenario = list(map(lambda line: [c for c in line], f.read().split("\n"))) 15 | return scenario 16 | 17 | def print(self, car, path=[]): 18 | top_nums = " " + "".join(map(lambda n: str(n%10), range(0, len(self.scenario[0])))) 19 | buffer = top_nums + "\n " + "-"*len(self.scenario[0]) 20 | buffer += "\n" 21 | 22 | for row_index, row in enumerate(self.scenario): 23 | buffer += str(row_index % 10) + "|" 
24 | for col_index, col in enumerate(self.scenario[row_index]): 25 | if car.position == (row_index, col_index): 26 | buffer += str(car) 27 | continue 28 | if any(filter(lambda waypoint: waypoint == (row_index, col_index), path)): 29 | buffer += self.waypoint 30 | continue 31 | 32 | buffer += self.scenario[row_index][col_index] 33 | 34 | buffer += "\n" 35 | 36 | print(buffer) 37 | 38 | def select_start_position(self): 39 | starts = [] 40 | for row_index, row in enumerate(self.scenario): 41 | for col_index, col in enumerate(self.scenario[row_index]): 42 | if self.scenario[row_index][col_index] == self.start: 43 | starts.append((row_index, col_index)) 44 | return random.choice(starts) 45 | 46 | def move_to(self, car, new_position): 47 | position = car.position 48 | increment_v = 1 if new_position[0] - position[0] >= 0 else -1 49 | increment_h = 1 if new_position[1] - position[1] >= 0 else -1 50 | 51 | path = [position] 52 | while position != new_position: 53 | if abs(new_position[0] - position[0]) >= abs(new_position[1] - position[1]): 54 | position = (position[0] + increment_v, position[1]) 55 | else: 56 | position = (position[0], position[1] + increment_h) 57 | 58 | path.append(position) 59 | if self.is_wall(position): 60 | return self.select_start_position(), path 61 | elif self.is_finish(position): 62 | return position, path 63 | 64 | return new_position, path 65 | 66 | def is_track(self, position): 67 | if self.scenario[position[0]][position[1]] == self.track: 68 | return True 69 | 70 | return False 71 | 72 | def is_wall(self, position): 73 | if position[0] < 0 or position[0] > len(self.scenario)-1: 74 | return True 75 | elif position[1] < 0 or position[1] > len(self.scenario[position[0]])-1: 76 | return True 77 | elif self.scenario[position[0]][position[1]] == self.wall: 78 | return True 79 | else: 80 | return False 81 | 82 | def is_start(self, position): 83 | if self.scenario[position[0]][position[1]] == self.start: 84 | return True 85 | return False 86 | 87 | def is_finish(self, position): 88 | if self.scenario[position[0]][position[1]] == self.finish: 89 | return True 90 | return False 91 | 92 | 93 | env = Environment("racetrack2_30x32.txt") -------------------------------------------------------------------------------- /Monte Carlo methods/racetrack1_32x17.txt: -------------------------------------------------------------------------------- 1 | +++ f 2 | ++ f 3 | ++ f 4 | + f 5 | f 6 | f 7 | +++++++ 8 | ++++++++ 9 | ++++++++ 10 | ++++++++ 11 | ++++++++ 12 | ++++++++ 13 | ++++++++ 14 | ++++++++ 15 | + ++++++++ 16 | + ++++++++ 17 | + ++++++++ 18 | + ++++++++ 19 | + ++++++++ 20 | + ++++++++ 21 | + ++++++++ 22 | + ++++++++ 23 | ++ ++++++++ 24 | ++ ++++++++ 25 | ++ ++++++++ 26 | ++ ++++++++ 27 | ++ ++++++++ 28 | ++ ++++++++ 29 | ++ ++++++++ 30 | +++ ++++++++ 31 | +++ ++++++++ 32 | +++ssssss++++++++ -------------------------------------------------------------------------------- /Monte Carlo methods/racetrack2_30x32.txt: -------------------------------------------------------------------------------- 1 | ++++++++++++++++ f 2 | +++++++++++++ f 3 | ++++++++++++ f 4 | +++++++++++ f 5 | +++++++++++ f 6 | +++++++++++ f 7 | +++++++++++ f 8 | ++++++++++++ f 9 | +++++++++++++ ++ 10 | ++++++++++++++ +++++ 11 | ++++++++++++++ ++++++ 12 | ++++++++++++++ ++++++++ 13 | ++++++++++++++ +++++++++ 14 | ++++++++++++++ +++++++++ 15 | +++++++++++++ +++++++++ 16 | ++++++++++++ +++++++++ 17 | +++++++++++ +++++++++ 18 | ++++++++++ +++++++++ 19 | +++++++++ +++++++++ 20 | ++++++++ +++++++++ 21 | +++++++ 
+++++++++ 22 | ++++++ +++++++++ 23 | +++++ +++++++++ 24 | ++++ +++++++++ 25 | +++ +++++++++ 26 | ++ +++++++++ 27 | + +++++++++ 28 | +++++++++ 29 | +++++++++ 30 | sssssssssssssssssssssss+++++++++ -------------------------------------------------------------------------------- /Monte Carlo methods/racetrack2_30x32_2.txt: -------------------------------------------------------------------------------- 1 | ++++++++++++++++ f 2 | +++++++++++++ f 3 | ++++++++++++ f 4 | +++++++++++ f 5 | +++++++++++ f 6 | +++++++++++ f 7 | +++++++++++ f 8 | ++++++ f 9 | ++++++ ++ 10 | ++++ +++++++++ 11 | +++ +++++++++ 12 | ++ +++++++++ 13 | + +++++++++ 14 | +++++++++ 15 | +++++++++ 16 | sssssssssssssssssssssss+++++++++ -------------------------------------------------------------------------------- /Monte Carlo methods/tests/__pycache__/test_car.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hermesdt/reinforcement-learning/eb69484bb6d5a415633eb6e1fc07e12aa193cbb0/Monte Carlo methods/tests/__pycache__/test_car.cpython-36.pyc -------------------------------------------------------------------------------- /Monte Carlo methods/tests/fixtures/scenario1.txt: -------------------------------------------------------------------------------- 1 | +++ f 2 | ++ f 3 | f 4 | +++++++ 5 | 6 | ++ +++++ 7 | ++ ++++++++ 8 | +++ ++++++++ 9 | +++ ++++++++ 10 | +++ssssss++++++++ -------------------------------------------------------------------------------- /Monte Carlo methods/tests/test_car.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import Mock 3 | from car import Car 4 | from environment import Environment 5 | 6 | class TestCar(unittest.TestCase): 7 | def setUp(self): 8 | self.environment = Environment(filename="tests/fixtures/scenario1.txt") 9 | self.car = Car(self.environment, select_action_fn=lambda car: (0, 0)) 10 | 11 | def test_default_actions(self): 12 | actions = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)] 13 | self.assertEqual(self.car.actions, actions) 14 | 15 | def test_str(self): 16 | self.assertEqual(str(self.car), "#") 17 | 18 | def test_reset(self): 19 | self.car.speed = (2, 2) 20 | self.position = (1, 5) 21 | self.environment.select_start_position = Mock(return_value=(0, 0)) 22 | self.car._reset() 23 | 24 | self.assertEqual(self.car.speed, (0, 0)) 25 | self.assertEqual(self.car.position, (0, 0)) 26 | self.environment.select_start_position.assert_called() 27 | 28 | def test_calculate_policy(self): 29 | Q = { 30 | "state1": { 31 | "action1": 1, 32 | "action2": 2, 33 | "action3": 0.5, 34 | }, 35 | "state2": { 36 | "action1": 0.5 37 | } 38 | } 39 | 40 | P = self.car.calculate_policy(Q) 41 | self.assertEqual(P, { 42 | 'state1': 'action2', 43 | 'state2': 'action1' 44 | }) 45 | 46 | def test_reward_for_finish_cell(self): 47 | reward = self.car.reward((1, 16)) 48 | self.assertEqual(reward, 0) 49 | 50 | def test_step_normal(self): 51 | self.car.position = (7, 5) 52 | self.car.speed = (0, 0) 53 | self.car.select_action = Mock(return_value=(1, 0)) 54 | reward, old_state, action = self.car.step() 55 | self.assertEqual(old_state, ((7, 5), (0, 0))) 56 | self.assertEqual((-1, (1, 0)), (reward, action)) 57 | self.assertEqual(self.car.position, (8, 5)) 58 | self.assertEqual(self.car.speed, (1, 0)) 59 | 60 | def test_step_speeds_over_maximum(self): 61 | self.car.position = (3, 5) 62 | self.car.speed = (4, 1) 63 | self.car.select_action = 
Mock(return_value=(1, 0)) 64 | reward, old_state, action = self.car.step() 65 | self.assertEqual((-1, (1, 0)), (reward, action)) 66 | self.assertEqual(self.car.position, (7, 6)) 67 | self.assertEqual(self.car.speed, (4, 1)) 68 | 69 | def test_step_speeds_under_minimum(self): 70 | self.car.position = (3, 10) 71 | self.car.speed = (0, -4) 72 | self.car.select_action = Mock(return_value=(0, -1)) 73 | reward, old_state, action = self.car.step() 74 | self.assertEqual((-1, (0, -1)), (reward, action)) 75 | self.assertEqual(self.car.position, (3, 6)) 76 | self.assertEqual(self.car.speed, (0, -4)) 77 | 78 | def test_step_out_of_board_above(self): 79 | self.car.position = (3, 10) 80 | self.car.speed = (-3, -4) 81 | self.car.select_action = Mock(return_value=(-1, 0)) 82 | self.environment.select_start_position = Mock(return_value=(9, 7)) 83 | reward, old_state, action = self.car.step() 84 | self.assertEqual((-1, (-1, 0)), (reward, action)) 85 | self.assertEqual(self.car.position, (9, 7)) 86 | self.assertEqual(self.car.speed, (0, 0)) 87 | 88 | def test_step_out_of_board_below(self): 89 | self.car.position = (7, 7) 90 | self.car.speed = (5, 0) 91 | self.car.select_action = Mock(return_value=(-1, 0)) 92 | self.environment.select_start_position = Mock(return_value=(9, 8)) 93 | reward, old_state, action = self.car.step() 94 | self.assertEqual((-1, (-1, 0)), (reward, action)) 95 | self.assertEqual(self.car.position, (9, 8)) 96 | self.assertEqual(self.car.speed, (0, 0)) 97 | 98 | def test_step_out_of_board_left(self): 99 | self.car.position = (4, 1) 100 | self.car.speed = (0, -1) 101 | self.car.select_action = Mock(return_value=(0, -1)) 102 | self.environment.select_start_position = Mock(return_value=(9, 8)) 103 | reward, old_state, action = self.car.step() 104 | self.assertEqual((-1, (0, -1)), (reward, action)) 105 | self.assertEqual(self.car.position, (9, 8)) 106 | self.assertEqual(self.car.speed, (0, 0)) 107 | 108 | def test_step_out_of_board_right(self): 109 | self.car.position = (5, 16) 110 | self.car.speed = (-1, 0) 111 | self.car.select_action = Mock(return_value=(0, 1)) 112 | self.environment.select_start_position = Mock(return_value=(9, 8)) 113 | reward, old_state, action = self.car.step() 114 | self.assertEqual((-1, (0, 1)), (reward, action)) 115 | self.assertEqual(self.car.position, (9, 8)) 116 | self.assertEqual(self.car.speed, (0, 0)) 117 | 118 | def test_step_finish(self): 119 | self.car.position = (1, 15) 120 | self.car.speed = (-1, 0) 121 | self.car.select_action = Mock(return_value=(0, 1)) 122 | reward, old_state, action = self.car.step() 123 | self.assertEqual((0, (0, 1)), (reward, action)) 124 | self.assertEqual(self.car.position, (0, 16)) 125 | self.assertEqual(self.car.speed, (-1, 1)) 126 | 127 | def test_step_cross_finish(self): 128 | self.car.position = (2, 13) 129 | self.car.speed = (0, 3) 130 | self.car.select_action = Mock(return_value=(0, 1)) 131 | reward, old_state, action = self.car.step() 132 | self.assertEqual((0, (0, 1)), (reward, action)) 133 | self.assertEqual(self.car.position, (2, 16)) 134 | self.assertEqual(self.car.speed, (0, 4)) 135 | 136 | def test_reward_for_non_finish_cell(self): 137 | self.assertEqual(self.car.reward((1, 15)), -1) 138 | self.assertEqual(self.car.reward((0, 0)), -1) 139 | self.assertEqual(self.car.reward((9, 5)), -1) 140 | 141 | def test_play(self): 142 | self.car.select_action = Mock(return_value=(0, 1)) 143 | self.environment.select_start_position = Mock(return_value=(1, 3)) 144 | steps, rewards = self.car.play() 145 | 146 | 
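        # the mocked (0, 1) action accelerates the car rightward each step (speed capped at 4) until move_to stops it at the finish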
self.assertEqual(steps, [ 147 | (((1, 3), (0, 0)), (0, 1)), 148 | (((1, 4), (0, 1)), (0, 1)), 149 | (((1, 6), (0, 2)), (0, 1)), 150 | (((1, 9), (0, 3)), (0, 1)), 151 | (((1, 13), (0, 4)), (0, 1)) 152 | ]) 153 | 154 | self.assertEqual(rewards, [-1, -1, -1, -1, 0]) 155 | 156 | def test_train(self): 157 | self.car.select_action = Mock(return_value=(0, 1)) 158 | self.environment.select_start_position = Mock(return_value=(1, 3)) 159 | self.car.train(1) 160 | 161 | self.assertEqual(self.car.position, (1, 16)) 162 | self.assertEqual(self.car.speed, (0, 4)) 163 | 164 | self.assertEqual(dict(self.car.Q[((1, 4), (0, 1))]), {(0, 1): -3}) 165 | self.assertEqual(self.car.returns_sum[((1, 4), (0, 1))], -3) 166 | self.assertEqual(self.car.returns_count[((1, 4), (0, 1))], 1) 167 | 168 | def test_calculate_returns_only_first(self): 169 | steps = [ 170 | (((2, 3), (-3, 1)), (1, 0)), 171 | (((0, 3), (-3, 1)), (0, 1)), 172 | (((2, 3), (-3, 1)), (-1, 1)), 173 | (((1, 2), (2, 1)), (1, 0)), 174 | (((2, 3), (-3, 1)), (1, 0)), 175 | ] 176 | 177 | rewards = [ 178 | -1, -1, -1, -1, -1 179 | ] 180 | 181 | R = self.car._calculate_returns(steps, rewards) 182 | self.assertEqual(list(R.keys()), [(((2, 3), (-3, 1)), (1, 0)), 183 | (((1, 2), (2, 1)), (1, 0)), 184 | (((2, 3), (-3, 1)), (-1, 1)), 185 | (((0, 3), (-3, 1)), (0, 1)) 186 | ]) 187 | 188 | self.assertEqual(R[(((2, 3), (-3, 1)), (1, 0))], -5) 189 | 190 | def test_select_action(self): 191 | car = Car(self.environment, select_action_fn=lambda car: (-501, -660)) 192 | self.assertEqual(car.select_action(), (-501, -660)) 193 | 194 | -------------------------------------------------------------------------------- /a2c/pendulum_a2c_online.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import torch\n", 11 | "import gym\n", 12 | "from torch import nn\n", 13 | "from torch.nn import functional as F\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "from torch.utils.tensorboard import SummaryWriter\n" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 3, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "def mish(input):\n", 25 | " return input * torch.tanh(F.softplus(input))\n", 26 | "\n", 27 | "class Mish(nn.Module):\n", 28 | " def __init__(self): super().__init__()\n", 29 | " def forward(self, input): return mish(input)" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 4, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "# helper function to convert numpy arrays to tensors\n", 39 | "def t(x):\n", 40 | " x = np.array(x) if not isinstance(x, np.ndarray) else x\n", 41 | " return torch.from_numpy(x).float()" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 5, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "class Actor(nn.Module):\n", 51 | " def __init__(self, state_dim, n_actions, activation=nn.Tanh):\n", 52 | " super().__init__()\n", 53 | " self.n_actions = n_actions\n", 54 | " self.model = nn.Sequential(\n", 55 | " nn.Linear(state_dim, 64),\n", 56 | " activation(),\n", 57 | " nn.Linear(64, 64),\n", 58 | " activation(),\n", 59 | " nn.Linear(64, n_actions)\n", 60 | " )\n", 61 | " \n", 62 | " logstds_param = nn.Parameter(torch.full((n_actions,), 0.1))\n", 63 | " self.register_parameter(\"logstds\", logstds_param)\n", 64 | " \n", 65 | " def forward(self, X):\n", 66 | 
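"        # state-dependent means come from the MLP; the log-stds are a single learned, state-independent parameter vector\n",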
" means = self.model(X)\n", 67 | " stds = torch.clamp(self.logstds.exp(), 1e-3, 50)\n", 68 | " \n", 69 | " return torch.distributions.Normal(means, stds)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 6, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "## Critic module\n", 79 | "class Critic(nn.Module):\n", 80 | " def __init__(self, state_dim, activation=nn.Tanh):\n", 81 | " super().__init__()\n", 82 | " self.model = nn.Sequential(\n", 83 | " nn.Linear(state_dim, 64),\n", 84 | " activation(),\n", 85 | " nn.Linear(64, 64),\n", 86 | " activation(),\n", 87 | " nn.Linear(64, 1),\n", 88 | " )\n", 89 | " \n", 90 | " def forward(self, X):\n", 91 | " return self.model(X)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 7, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "def discounted_rewards(rewards, dones, gamma):\n", 101 | " ret = 0\n", 102 | " discounted = []\n", 103 | " for reward, done in zip(rewards[::-1], dones[::-1]):\n", 104 | " ret = reward + ret * gamma * (1-done)\n", 105 | " discounted.append(ret)\n", 106 | " \n", 107 | " return discounted[::-1]" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 14, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "def process_memory(memory, gamma=0.99, discount_rewards=True):\n", 117 | " actions = []\n", 118 | " states = []\n", 119 | " next_states = []\n", 120 | " rewards = []\n", 121 | " dones = []\n", 122 | "\n", 123 | " for action, reward, state, next_state, done in memory:\n", 124 | " actions.append(action)\n", 125 | " rewards.append(reward)\n", 126 | " states.append(state)\n", 127 | " next_states.append(next_state)\n", 128 | " dones.append(done)\n", 129 | " \n", 130 | " if discount_rewards:\n", 131 | " if False and dones[-1] == 0:\n", 132 | " rewards = discounted_rewards(rewards + [last_value], dones + [0], gamma)[:-1]\n", 133 | " else:\n", 134 | " rewards = discounted_rewards(rewards, dones, gamma)\n", 135 | "\n", 136 | " actions = t(actions).view(-1, 1)\n", 137 | " states = t(states)\n", 138 | " next_states = t(next_states)\n", 139 | " rewards = t(rewards).view(-1, 1)\n", 140 | " dones = t(dones).view(-1, 1)\n", 141 | " return actions, rewards, states, next_states, dones\n", 142 | "\n", 143 | "def clip_grad_norm_(module, max_grad_norm):\n", 144 | " nn.utils.clip_grad_norm_([p for g in module.param_groups for p in g[\"params\"]], max_grad_norm)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 27, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "class A2CLearner():\n", 154 | " def __init__(self, actor, critic, gamma=0.9, entropy_beta=0,\n", 155 | " actor_lr=4e-4, critic_lr=4e-3, max_grad_norm=0.5):\n", 156 | " self.gamma = gamma\n", 157 | " self.max_grad_norm = max_grad_norm\n", 158 | " self.actor = actor\n", 159 | " self.critic = critic\n", 160 | " self.entropy_beta = entropy_beta\n", 161 | " self.actor_optim = torch.optim.Adam(actor.parameters(), lr=actor_lr)\n", 162 | " self.critic_optim = torch.optim.Adam(critic.parameters(), lr=critic_lr)\n", 163 | " \n", 164 | " def learn(self, memory, steps, discount_rewards=True):\n", 165 | " actions, rewards, states, next_states, dones = process_memory(memory, self.gamma, discount_rewards)\n", 166 | "\n", 167 | " if discount_rewards:\n", 168 | " td_target = rewards\n", 169 | " else:\n", 170 | " td_target = rewards + self.gamma*critic(next_states)*(1-dones)\n", 171 | " value = critic(states)\n", 172 | " advantage = td_target - 
value\n", 173 | "\n", 174 | " # actor\n", 175 | " norm_dists = self.actor(states)\n", 176 | " logs_probs = norm_dists.log_prob(actions)\n", 177 | " entropy = norm_dists.entropy().mean()\n", 178 | " \n", 179 | " actor_loss = (-logs_probs*advantage.detach()).mean() - entropy*self.entropy_beta\n", 180 | " self.actor_optim.zero_grad()\n", 181 | " actor_loss.backward()\n", 182 | " \n", 183 | " clip_grad_norm_(self.actor_optim, self.max_grad_norm)\n", 184 | " writer.add_histogram(\"gradients/actor\",\n", 185 | " torch.cat([p.grad.view(-1) for p in self.actor.parameters()]), global_step=steps)\n", 186 | " writer.add_histogram(\"parameters/actor\",\n", 187 | " torch.cat([p.data.view(-1) for p in self.actor.parameters()]), global_step=steps)\n", 188 | " self.actor_optim.step()\n", 189 | "\n", 190 | " # critic\n", 191 | " critic_loss = F.mse_loss(td_target, value)\n", 192 | " self.critic_optim.zero_grad()\n", 193 | " critic_loss.backward()\n", 194 | " clip_grad_norm_(self.critic_optim, self.max_grad_norm)\n", 195 | " writer.add_histogram(\"gradients/critic\",\n", 196 | " torch.cat([p.grad.view(-1) for p in self.critic.parameters()]), global_step=steps)\n", 197 | " writer.add_histogram(\"parameters/critic\",\n", 198 | " torch.cat([p.data.view(-1) for p in self.critic.parameters()]), global_step=steps)\n", 199 | " self.critic_optim.step()\n", 200 | " \n", 201 | " # reports\n", 202 | " writer.add_scalar(\"losses/log_probs\", -logs_probs.mean(), global_step=steps)\n", 203 | " writer.add_scalar(\"losses/entropy\", entropy, global_step=steps) \n", 204 | " writer.add_scalar(\"losses/entropy_beta\", self.entropy_beta, global_step=steps) \n", 205 | " writer.add_scalar(\"losses/actor\", actor_loss, global_step=steps)\n", 206 | " writer.add_scalar(\"losses/advantage\", advantage.mean(), global_step=steps)\n", 207 | " writer.add_scalar(\"losses/critic\", critic_loss, global_step=steps)" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 28, 213 | "metadata": { 214 | "scrolled": false 215 | }, 216 | "outputs": [], 217 | "source": [ 218 | "class Runner():\n", 219 | " def __init__(self, env):\n", 220 | " self.env = env\n", 221 | " self.state = None\n", 222 | " self.done = True\n", 223 | " self.steps = 0\n", 224 | " self.episode_reward = 0\n", 225 | " self.episode_rewards = []\n", 226 | " \n", 227 | " def reset(self):\n", 228 | " self.episode_reward = 0\n", 229 | " self.done = False\n", 230 | " self.state = self.env.reset()\n", 231 | " \n", 232 | " def run(self, max_steps, memory=None):\n", 233 | " if not memory: memory = []\n", 234 | " \n", 235 | " for i in range(max_steps):\n", 236 | " if self.done: self.reset()\n", 237 | " \n", 238 | " dists = actor(t(self.state))\n", 239 | " actions = dists.sample().detach().data.numpy()\n", 240 | " actions_clipped = np.clip(actions, self.env.action_space.low.min(), env.action_space.high.max())\n", 241 | "\n", 242 | " next_state, reward, self.done, info = self.env.step(actions_clipped)\n", 243 | " memory.append((actions, reward, self.state, next_state, self.done))\n", 244 | "\n", 245 | " self.state = next_state\n", 246 | " self.steps += 1\n", 247 | " self.episode_reward += reward\n", 248 | " \n", 249 | " if self.done:\n", 250 | " self.episode_rewards.append(self.episode_reward)\n", 251 | " if len(self.episode_rewards) % 10 == 0:\n", 252 | " print(\"episode:\", len(self.episode_rewards), \", episode reward:\", self.episode_reward)\n", 253 | " writer.add_scalar(\"episode_reward\", self.episode_reward, global_step=self.steps)\n", 254 | " \n", 255 | " \n", 
256 | " return memory" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 31, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "env = gym.make(\"Pendulum-v0\")\n", 266 | "writer = SummaryWriter(\"runs/mish_activation\")\n", 267 | "\n", 268 | "# config\n", 269 | "state_dim = env.observation_space.shape[0]\n", 270 | "n_actions = env.action_space.shape[0]\n", 271 | "actor = Actor(state_dim, n_actions, activation=Mish)\n", 272 | "critic = Critic(state_dim, activation=Mish)\n", 273 | "\n", 274 | "learner = A2CLearner(actor, critic)\n", 275 | "runner = Runner(env)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 32, 281 | "metadata": {}, 282 | "outputs": [ 283 | { 284 | "name": "stdout", 285 | "output_type": "stream", 286 | "text": [ 287 | "episode: 10 , episode reward: -1331.6603030893166\n", 288 | "episode: 20 , episode reward: -1308.5052810626362\n", 289 | "episode: 30 , episode reward: -1295.3439224521233\n", 290 | "episode: 40 , episode reward: -1264.9938388654693\n", 291 | "episode: 50 , episode reward: -1159.1376888101245\n", 292 | "episode: 60 , episode reward: -1096.290611383684\n", 293 | "episode: 70 , episode reward: -1332.3170406146635\n", 294 | "episode: 80 , episode reward: -529.9847037752162\n", 295 | "episode: 90 , episode reward: -823.6640709812223\n", 296 | "episode: 100 , episode reward: -716.3981973346959\n", 297 | "episode: 110 , episode reward: -267.55884160232574\n", 298 | "episode: 120 , episode reward: -1183.7605896140449\n", 299 | "episode: 130 , episode reward: -566.974616549244\n", 300 | "episode: 140 , episode reward: -137.6165269870926\n", 301 | "episode: 150 , episode reward: -691.5466228810094\n", 302 | "episode: 160 , episode reward: -662.1612165947158\n", 303 | "episode: 170 , episode reward: -1086.4806743221225\n", 304 | "episode: 180 , episode reward: -647.1489689713651\n", 305 | "episode: 190 , episode reward: -1238.4159859610431\n", 306 | "episode: 200 , episode reward: -263.25376664267634\n", 307 | "episode: 210 , episode reward: -271.50743676771833\n", 308 | "episode: 220 , episode reward: -130.07883036823304\n", 309 | "episode: 230 , episode reward: -982.4589363644424\n", 310 | "episode: 240 , episode reward: -1256.668380561799\n", 311 | "episode: 250 , episode reward: -1267.7317399215963\n", 312 | "episode: 260 , episode reward: -131.6000793348229\n", 313 | "episode: 270 , episode reward: -2.471643838360228\n", 314 | "episode: 280 , episode reward: -1374.7596483537313\n", 315 | "episode: 290 , episode reward: -132.09786114294778\n", 316 | "episode: 300 , episode reward: -283.41963077793525\n", 317 | "episode: 310 , episode reward: -287.6034814493751\n", 318 | "episode: 320 , episode reward: -126.69283889032661\n", 319 | "episode: 330 , episode reward: -1.133852311983069\n", 320 | "episode: 340 , episode reward: -266.44327146392425\n", 321 | "episode: 350 , episode reward: -273.25977412397543\n", 322 | "episode: 360 , episode reward: -3.144876597430591\n", 323 | "episode: 370 , episode reward: -1495.0052678631948\n", 324 | "episode: 380 , episode reward: -441.07285351652\n", 325 | "episode: 390 , episode reward: -132.61540054307525\n", 326 | "episode: 400 , episode reward: -137.02285450348222\n", 327 | "episode: 410 , episode reward: -133.8382833481163\n", 328 | "episode: 420 , episode reward: -1.4550722539388448\n", 329 | "episode: 430 , episode reward: -129.01782445352362\n", 330 | "episode: 440 , episode reward: -266.22297564145487\n", 331 | "episode: 450 , episode reward: 
-444.16119954649116\n", 332 | "episode: 460 , episode reward: -132.3216589765316\n", 333 | "episode: 470 , episode reward: -136.68120605004034\n", 334 | "episode: 480 , episode reward: -438.02075550748197\n", 335 | "episode: 490 , episode reward: -131.1251871496582\n", 336 | "episode: 500 , episode reward: -133.29030451790683\n" 337 | ] 338 | } 339 | ], 340 | "source": [ 341 | "steps_on_memory = 16\n", 342 | "episodes = 500\n", 343 | "episode_length = 200\n", 344 | "total_steps = (episode_length*episodes)//steps_on_memory\n", 345 | "\n", 346 | "for i in range(total_steps):\n", 347 | " memory = runner.run(steps_on_memory)\n", 348 | " learner.learn(memory, runner.steps, discount_rewards=False)" 349 | ] 350 | } 351 | ], 352 | "metadata": { 353 | "kernelspec": { 354 | "display_name": "Python 3", 355 | "language": "python", 356 | "name": "python3" 357 | }, 358 | "language_info": { 359 | "codemirror_mode": { 360 | "name": "ipython", 361 | "version": 3 362 | }, 363 | "file_extension": ".py", 364 | "mimetype": "text/x-python", 365 | "name": "python", 366 | "nbconvert_exporter": "python", 367 | "pygments_lexer": "ipython3", 368 | "version": "3.6.3" 369 | } 370 | }, 371 | "nbformat": 4, 372 | "nbformat_minor": 2 373 | } 374 | -------------------------------------------------------------------------------- /a2c/pendulum_a2c_online_agents.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import torch\n", 11 | "import gym\n", 12 | "from torch import nn\n", 13 | "from torch.nn import functional as F\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "from torch.utils.tensorboard import SummaryWriter\n" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 12, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "def mish(input):\n", 25 | " return input * torch.tanh(F.softplus(input))\n", 26 | "\n", 27 | "class Mish(nn.Module):\n", 28 | " def __init__(self): super().__init__()\n", 29 | " def forward(self, input): return mish(input)" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 13, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "# helper function to convert numpy arrays to tensors\n", 39 | "def t(x):\n", 40 | " x = np.array(x) if not isinstance(x, np.ndarray) else x\n", 41 | " return torch.from_numpy(x).float()" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 14, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "# Actor module, categorical actions only\n", 51 | "import math\n", 52 | "leaky = torch.nn.LeakyReLU()\n", 53 | "\n", 54 | "class Actor(nn.Module):\n", 55 | " def __init__(self, state_dim, n_actions):\n", 56 | " super().__init__()\n", 57 | " self.n_actions = n_actions\n", 58 | " self.model = nn.Sequential(\n", 59 | " nn.Linear(state_dim, 32),\n", 60 | " # nn.BatchNorm1d(32),\n", 61 | " Mish(),\n", 62 | " )\n", 63 | " self.means_head = nn.Sequential(\n", 64 | " nn.Linear(32, n_actions),\n", 65 | " # nn.BatchNorm1d(n_actions),\n", 66 | " nn.Tanh(),\n", 67 | " )\n", 68 | " self.stds_head = nn.Sequential(\n", 69 | " nn.Linear(32, n_actions),\n", 70 | " # nn.BatchNorm1d(n_actions),\n", 71 | " nn.Softplus(),\n", 72 | " )\n", 73 | " \n", 74 | " def forward(self, X):\n", 75 | " data = self.model(X)\n", 76 | " means = self.means_head(data)\n", 77 | " stds = (self.stds_head(data) + 1e-3)\n", 78 | " 
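"        # the Tanh-bounded means are scaled by 2 below to cover Pendulum's action range [-2, 2]\n",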
\n", 79 | " dists = torch.distributions.Normal(means*2, stds)\n", 80 | " \n", 81 | " return dists" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 15, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "## Critic module\n", 91 | "class Critic(nn.Module):\n", 92 | " def __init__(self, state_dim):\n", 93 | " super().__init__()\n", 94 | " self.model = nn.Sequential(\n", 95 | " nn.Linear(state_dim, 32),\n", 96 | " # nn.BatchNorm1d(n_actions),\n", 97 | " Mish(),\n", 98 | " nn.Linear(32, 32),\n", 99 | " # nn.BatchNorm1d(n_actions),\n", 100 | " Mish(),\n", 101 | " nn.Linear(32, 1),\n", 102 | " )\n", 103 | " \n", 104 | " def forward(self, X):\n", 105 | " return self.model(X)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 16, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "\n", 115 | "writer = SummaryWriter()\n", 116 | "env = gym.make(\"Pendulum-v0\")" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 17, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "# config\n", 126 | "state_dim = env.observation_space.shape[0]\n", 127 | "n_actions = env.action_space.shape[0]\n", 128 | "actor = Actor(state_dim, n_actions)\n", 129 | "critic = Critic(state_dim)\n", 130 | "adam_actor = torch.optim.Adam(actor.parameters(), lr=1e-4)#, weight_decay=0.001)\n", 131 | "adam_critic = torch.optim.Adam(critic.parameters(), lr=3e-4)#, weight_decay=0.001)\n", 132 | "gamma = 0.98\n", 133 | "entropy_beta = 1e-2\n", 134 | "memory = []" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 18, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "def train(memory):\n", 144 | " actions = []\n", 145 | " states = []\n", 146 | " next_states = []\n", 147 | " rewards = []\n", 148 | " dones = []\n", 149 | " \n", 150 | " for action, reward, state, next_state, done in memory:\n", 151 | " actions.append(action)\n", 152 | " rewards.append(reward)\n", 153 | " states.append(state)\n", 154 | " next_states.append(next_state)\n", 155 | " dones.append(done)\n", 156 | " \n", 157 | " actions = t(actions).view(-1, 1)\n", 158 | " states = t(states)\n", 159 | " next_states = t(next_states)\n", 160 | " rewards = t(rewards).view(-1, 1)\n", 161 | " dones = t(dones).view(-1, 1)\n", 162 | " \n", 163 | " with torch.no_grad():\n", 164 | " td_target = rewards + gamma*critic(next_states)*(1-dones)\n", 165 | " advantage = td_target - critic(states)\n", 166 | " \n", 167 | " norm_dists = actor(states)\n", 168 | " logs_probs = norm_dists.log_prob(actions)\n", 169 | " actor_loss = (-logs_probs*advantage).mean() - entropy_beta*norm_dists.entropy().detach().mean()\n", 170 | " writer.add_scalar(\"losses/actor\", actor_loss)\n", 171 | " adam_actor.zero_grad()\n", 172 | " actor_loss.backward()\n", 173 | " adam_actor.step()\n", 174 | " \n", 175 | " critic_loss = F.mse_loss(td_target, critic(t(states)))\n", 176 | " writer.add_scalar(\"losses/critic\", critic_loss)\n", 177 | " adam_critic.zero_grad()\n", 178 | " critic_loss.backward()\n", 179 | " adam_critic.step()" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 19, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "def build_runner(env, steps, memory):\n", 189 | " states = [env.reset()]\n", 190 | " total_reward = [0]\n", 191 | " episodes = [0]\n", 192 | " \n", 193 | " def runner():\n", 194 | " for _ in range(steps):\n", 195 | " state = states[0]\n", 196 | " dists = actor(t(state))\n", 197 | " actions = 
dists.sample()\n", 198 | " actions_clamped = torch.clamp(actions, env.action_space.low.min(), env.action_space.high.max())\n", 199 | "\n", 200 | " next_state, reward, done, info = env.step(actions_clamped.detach().data.numpy())\n", 201 | " memory.append((actions, reward, states[0], next_state, done))\n", 202 | " states[0] = next_state\n", 203 | " total_reward[0] += reward\n", 204 | " \n", 205 | " if done:\n", 206 | " episodes[0] += 1\n", 207 | " if episodes[0] % 20 == 0:\n", 208 | " print(f\"episode #{episodes[0]}, reward: {total_reward[0]}\")\n", 209 | " writer.add_scalar(\"rewards/episode\", total_reward[0])\n", 210 | " states[0] = env.reset()\n", 211 | " total_reward[0] = 0\n", 212 | " \n", 213 | " return runner\n", 214 | "\n", 215 | "runners = [build_runner(gym.make(\"Pendulum-v0\"), 4, memory) for _ in range(8)]" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 20, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "ename": "KeyboardInterrupt", 225 | "evalue": "", 226 | "output_type": "error", 227 | "traceback": [ 228 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 229 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 230 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mrunner\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrunners\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mrewards\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrunner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmemory\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mmemory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 231 | "\u001b[0;32m\u001b[0m in \u001b[0;36mrunner\u001b[0;34m()\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msteps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstates\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mdists\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mactor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0mactions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdists\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mactions_clamped\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclamp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mactions\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maction_space\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maction_space\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhigh\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 232 | "\u001b[0;32m~/.asdf/installs/python/3.6.3/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 539\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 540\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 541\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 542\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 543\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 233 | "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m \u001b[0mmeans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmeans_head\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 28\u001b[0m \u001b[0mstds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstds_head\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1e-3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 234 | "\u001b[0;32m~/.asdf/installs/python/3.6.3/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 571\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0m__getattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 572\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m'_parameters'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__dict__\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 573\u001b[0;31m \u001b[0m_parameters\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__dict__\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'_parameters'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 574\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_parameters\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 575\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_parameters\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 235 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 236 | ] 237 | } 238 | ], 239 | "source": [ 240 | "for i in range(5000):\n", 241 | " for runner in runners:\n", 242 | " rewards = runner()\n", 243 | " train(memory)\n", 244 | " memory.clear()" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 11, 250 | "metadata": {}, 251 | "outputs": [ 252 | { 253 | "ename": "NameError", 254 | "evalue": "name 'episode_rewards' is not defined", 255 | "output_type": "error", 256 | "traceback": [ 257 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 258 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 259 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscatter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepisode_rewards\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepisode_rewards\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtitle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Total reward per episode (online)\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mylabel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"reward\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mxlabel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"episode\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 260 | "\u001b[0;31mNameError\u001b[0m: name 'episode_rewards' is not defined" 261 | ] 262 | } 263 | ], 264 | "source": [ 265 | "plt.scatter(np.arange(len(episode_rewards)), episode_rewards, s=2)\n", 266 | "plt.title(\"Total reward per episode (online)\")\n", 267 | "plt.ylabel(\"reward\")\n", 268 | "plt.xlabel(\"episode\")\n", 269 | "plt.show()" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | 
"execution_count": 68, 275 | "metadata": {}, 276 | "outputs": [ 277 | { 278 | "data": { 279 | "text/plain": [ 280 | "tensor([-5.4888], grad_fn=)" 281 | ] 282 | }, 283 | "execution_count": 68, 284 | "metadata": {}, 285 | "output_type": "execute_result" 286 | } 287 | ], 288 | "source": [ 289 | "dists.entropy()" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 69, 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "state = env.reset()" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 70, 304 | "metadata": {}, 305 | "outputs": [ 306 | { 307 | "data": { 308 | "text/plain": [ 309 | "tensor([-0.2877, 0.0264, -0.1438, 0.5027, 1.0066, -0.1544, 0.4802, -0.1561,\n", 310 | " 0.2354, -0.0054, 0.3285, 0.1031, 0.7267, -0.0575, -0.0306, 0.7434,\n", 311 | " 0.4232, -0.1967, 0.2230, -0.1707, 0.0087, -0.2611, 0.7855, 0.3474,\n", 312 | " 0.8359, 0.1104, -0.2812, -0.2202, 0.3276, 0.1480, -0.3086, -0.1037],\n", 313 | " grad_fn=)" 314 | ] 315 | }, 316 | "execution_count": 70, 317 | "metadata": {}, 318 | "output_type": "execute_result" 319 | } 320 | ], 321 | "source": [ 322 | "m = actor.model(t(state))\n", 323 | "m" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 71, 329 | "metadata": {}, 330 | "outputs": [ 331 | { 332 | "data": { 333 | "text/plain": [ 334 | "tensor([0.0014], grad_fn=)" 335 | ] 336 | }, 337 | "execution_count": 71, 338 | "metadata": {}, 339 | "output_type": "execute_result" 340 | } 341 | ], 342 | "source": [ 343 | "actor.stds_head(m)" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 72, 349 | "metadata": {}, 350 | "outputs": [], 351 | "source": [ 352 | "env.close()" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": null, 358 | "metadata": {}, 359 | "outputs": [], 360 | "source": [] 361 | } 362 | ], 363 | "metadata": { 364 | "kernelspec": { 365 | "display_name": "Python 3", 366 | "language": "python", 367 | "name": "python3" 368 | }, 369 | "language_info": { 370 | "codemirror_mode": { 371 | "name": "ipython", 372 | "version": 3 373 | }, 374 | "file_extension": ".py", 375 | "mimetype": "text/x-python", 376 | "name": "python", 377 | "nbconvert_exporter": "python", 378 | "pygments_lexer": "ipython3", 379 | "version": "3.6.3" 380 | } 381 | }, 382 | "nbformat": 4, 383 | "nbformat_minor": 2 384 | } 385 | -------------------------------------------------------------------------------- /dynamic_programming/.ipynb_checkpoints/gambler-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 292, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from collections import defaultdict\n", 10 | "v = defaultdict(lambda: (0, 0))\n", 11 | "v[100] = (1, 0)" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 293, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "states = list(range(1, 100))\n", 21 | "\n", 22 | "def actions(state): return list(range(1, min(state, 100 - state) + 1))\n", 23 | "\n", 24 | "for i in range(1000):\n", 25 | " for state_idx, state in enumerate(states):\n", 26 | " max_value = -1\n", 27 | " max_action = -1\n", 28 | " old_value, old_action = v[state]\n", 29 | "\n", 30 | " for action_idx, action in enumerate(actions(state)):\n", 31 | " value = 0.6 * v[state - action][0] + 0.4 * v[state + action][0]\n", 32 | "\n", 33 | " if value > max_value+0.000001:\n", 34 | " max_value = value\n", 35 
| " max_action = action\n", 36 | "\n", 37 | " v[state] = (max_value, max_action)\n", 38 | " " 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 294, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "import matplotlib.pyplot as plt" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 295, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "%matplotlib inline" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 296, 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "data": { 66 | "text/plain": [ 67 | "" 68 | ] 69 | }, 70 | "execution_count": 296, 71 | "metadata": {}, 72 | "output_type": "execute_result" 73 | }, 74 | { 75 | "data": { 76 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAE5JJREFUeJzt3WuMXHd5x/Hvw9rAAlU3ISsrXifYFZFR1IiYrkJQEKIJ\nxeEiYkVRBELUL1L5DaihRaZ2WwkhVUqQKy6VKqQoobgVhdBgnChIuKkThFqpgTVOcS64CSEp2Vxs\nIAu0XRXHPH0xZ8PamdmZszvX/3w/krUzZ856/vM7459nnz27E5mJJGn0vWzQC5AkdYeFLkmFsNAl\nqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSrEun7e2XnnnZebN2/u511K0sg7cuTITzJzut1+\nfS30zZs3Mzc318+7lKSRFxFPdrKfIxdJKoSFLkmFsNAlqRAWuiQVwkKXpEJ0dJZLRDwB/BI4DbyQ\nmbMRcS5wO7AZeAK4PjOf780ypd44eHSefYeO8/TCIhunJtm9fSs7ts0MelnSqtR5hf77mXlpZs5W\n1/cAhzPzIuBwdV0aGQePzrP3wDHmFxZJYH5hkb0HjnHw6PyglyatylpGLtcA+6vL+4Eda1+O1D/7\nDh1n8dTpM7YtnjrNvkPHB7QiaW06LfQE/jkijkTErmrbhsx8prr8LLCh2SdGxK6ImIuIuZMnT65x\nuVL3PL2wWGu7NOw6LfS3ZuabgHcBH46Ity2/MRvvNN303aYz85bMnM3M2enptj+5KvXNxqnJWtul\nYddRoWfmfPXxBPB14DLguYg4H6D6eKJXi5R6Yff2rUyunzhj2+T6CXZv3zqgFUlr07bQI+LVEfFb\nS5eBdwIPAncBO6vddgJ39mqRUi/s2DbDTddewszUJAHMTE1y07WXeJaLRlYnpy1uAL4eEUv7/2Nm\nfjMivgt8NSJuAJ4Eru/dMqXe2LFtxgJXMdoWemY+DryxyfafAlf1YlGSpPr8SVFJKoSFLkmFsNAl\nqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIK\nYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAW\nuiQVwkKXpEJY6JJUiI4LPSImIuJoRNxdXd8SEfdHxGMRcXtEvLx3y5QktVPnFfqNwCPLrn8K+Exm\nvh54HrihmwuTJNXTUaFHxCbgPcCt1fUArgTuqHbZD+zoxQIlSZ3p9BX6Z4GPA7+urr8WWMjMF6rr\nTwEzXV6bJKmGtoUeEe8FTmTmkdXcQUTsioi5iJg7efLkav4KSVIHOnmFfgXwvoh4AvgKjVHL54Cp\niFhX7bMJmG/2yZl5S2bOZubs9PR0F5YsSWqmbaFn5t7M3JSZm4H3A/dm5geB+4Drqt12Anf2bJWS\npLbWch76nwF/GhGP0Zip39adJUmSVmNd+11+IzO/BXyruvw4cFn3lyRJWg1/UlSSCmGhS1IhLHRJ\nKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RC\nWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSF\nLkmFsNAlqRAWuiQVwkKXpEK0LfSIeGVEfCci/iMiHoqIT1bbt0TE/RHxWETcHhEv7/1yJUmtdPIK\n/f+AKzPzjcClwNURcTnwKeAzmfl64Hnght4tU5LUTttCz4b/rq6ur/4kcCVwR7V9P7CjJyuUJHWk\noxl6RExExAPACeAe4IfAQma+UO3yFDDT4nN3RcRcRMydPHmyG2uWJDXRUaFn5unMvBTYBFwGvKHT\nO8jMWzJzNjNnp6enV7lMSVI7tc5yycwF4D7gLcBURKyrbtoEzHd5bZKkGjo5y2U6Iqaqy5PAHwCP\n0Cj266rddgJ39mqRkqT21rXfhfOB/RExQeM/gK9m5t0R8TDwlYj4K+AocFsP1ylJaqNtoWfm94Ft\nTbY/TmOeLkkaAv6kqCQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RC\nWOiSVIhOfjmXRsjBo/PsO3ScpxcW2Tg1ye7tW9mxrel7j0i1+fwabhZ6QQ4enWfvgWMsnjoNwPzC\nInsPHAPwH53WzOfX8HPkUpB9h46/+I9tyeKp0+w7dHxAK1JJfH4NPwu9IE8vLNbaLtXh82v4WegF\n2Tg1WWu7VIfPr+FnoRdk9/atTK6fOGPb5PoJdm/fOqAVqSQ+v4af3xQtyNI3pjwLQb3g82v4RWb2\n7c5mZ2dzbm6ub/cnSSWIiCOZOdtuP0cuklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKX\npEJY6JJUCAtdkgphoUtSIdoWekRcEBH3RcTDEfFQRNxYbT83Iu6JiEerj+f0frmSpFY6eYX+AvCx\nzLwYuBz4cERcDOwBDmfmRcDh6rokaUDaFnpmPpOZ36su/xJ4BJgBrgH2V7vtB3b0apGSpPZqzdAj\nYjOwDbgf2JCZz1Q3PQts6OrKJEm1dFzoEfEa4GvARzPzF8tvy8YvVW/6i9UjYldEzEXE3MmTJ9e0\nWElSax0VekSsp1HmX8rMA9Xm5yLi/Or284ETzT43M2/JzNnMnJ2enu7GmiVJTXRylksAtwGPZOan\n
l910F7CzurwTuLP7y5MkdaqT9xS9AvgQcCwiHqi2/TlwM/DViLgBeBK4vjdLlCR1om2hZ+a/AtHi\n5qu6uxxJ0mr5k6KSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGh\nS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrok\nFWLdoBcwig4enWffoeM8vbDIb0+uJwIW/vcUG6cm2b19Kzu2zQx6iUPFvOoxr3rM6zcs9JoOHp1n\n74FjLJ46DcDC4qkXb5tfWGTvgWMAY/UkWol51WNe9ZjXmRy51LTv0PEXnzzNLJ46zb5Dx/u4ouFm\nXvWYVz3mdSYLvaanFxa7ss+4MK96zKse8zqThd6hg0fnueLme8kO9k3gipvv5eDR+V4va2iZVz3m\nVY95NecMvQNnz+k6MY7zuyXmVY951WNerfkKvQMrzemmJtdzzqvWN71t3OZ3S8yrHvOqx7xa8xV6\nB1rN4AJ44BPvBGDLnm80/fJvnOZ3S8yrHvOqx7xaa/sKPSK+EBEnIuLBZdvOjYh7IuLR6uM5vV3m\nYLSb022cmmx6eblxmt+ZVz3mVY95tdfJyOWLwNVnbdsDHM7Mi4DD1fWiLM3p5lv8jz65foLd27e+\neH339q1Mrp9ouu/S/K7UJxGYV13mVY95daZtoWfmt4GfnbX5GmB/dXk/sKPL6xq4leZ0M1OT3HTt\nJWd8c2XHthluuvYSZlq8Mih9fmde9ZhXPebVmdXO0Ddk5jPV5WeBDa12jIhdwC6ACy+8cJV3138r\nzen+bc+VTW/bsW2GHdtmxnJ+Z171mFc95tWZNZ/lkpkJrU8HzcxbMnM2M2enp6fXenc9tTSj27Ln\nG7wsouk+rWZznexT4vyuzlyzFfP6DfM6k3nVs9pCfy4izgeoPp7o3pIGY/mMLoHT+dKn0NlzulbG\nZX5Xd67Zink1mNeZzKu+1Rb6XcDO6vJO4M7uLGdwWs3oJiIIms/pWhmX+V3duWYr5mVezZhXfW1n\n6BHxZeDtwHkR8RTwCeBm4KsRcQPwJHB9LxfZD63mab/O5Ec3v6f23zcO87vVzDVbMS/zOpt51dfJ\nWS4fyMzzM3N9Zm7KzNsy86eZeVVmXpSZ78jMs8+CGRndmNGtpLT5Xbe+z9BKaXlBb59j5lVPiXkt\nN9Y/+t+tGd1KSprfdfP7DK2UlBf0/jlmXvWUltfZxrrQuzWjW0lJ87tufp+hlZLygt4/x8yrntLy\nOttY/y6Xbs7oVlLK/K7b32dopZS8oD/PMfOqp6S8zjaWr9B7PTdvZVTnd+ZV3yAyM696RjmvVsau\n0PsxN29lFOd35lXfoDIzr3pGNa+VjF2h92Nu3soozu/Mq75BZWZe9YxqXisZuxl6v+bmrYza/M68\n6htkZuZVzyjmtZKxeYU+qDlwK8M+vzOvenp9fn5dw54XDNdzbBTy6sRYFPog58CtDPP8zrzq6cf5\n+XUNc14wfM+xYc+rU2NR6IOcA7cyzPM786qnH+fn1zXMecHwPceGPa9OjcUMfdBz4FaGdX5nXvX0\n6/z8uoY1r5Xue5DPsWHOq1NFv0IfphndSoZlfmde9ZhXfaOQ2TDlVVexhT5sM7qVDMP8zrzqMa/6\nRiWzYclrNYot9GGb0a1kGOZ35lWPedU3KpkNS16rUewMfRhndCsZ9PzOvOoxr/pGKbNhyGs1inuF\nPgozupX0e35nXvWYVz3Ddn5+XaM2Ty+q0EdlRreSfs7vzKse86pnGM/Pr2vU5ulFFfqozOhW0s/5\nnXnVY171DOP5+XWN2jy9qBn6KM3oVtKv+Z151WNe9Qzr+fl1jdI8vYhX6KM+12ylV/M786rHvOox\nr8EZ+UIvYa7ZSi/md+ZlXkvMq55RmKePfKGXMNdspRfzO/MyryXmVc8ozNNHfoZeylyzlW7P78zL\nvJYzr3qGfZ4+sq/QS53TtbLW+Z15NZhXc+ZVz7DO00ey0Eue07WylvmdeZ3JvF7KvOoZ1nn6SBZ6\nyXO6VtYyvzOvlzKvM5lXPcM6Tx/JGXrpc7pWVju/My/z6oR51TOM8/SReoU+bnO6Vjqd35lXg3nV\nY171DNM8fWQKfRzndK10Mr/7y4PHzKtiXvWYVz3DNE+PbPILc3pldnY25+bmVvW5V9x8b8snz8zU\nJLu3by1uTreSg0fn2XfoeMtMJiKa/jIkMK9mzOtM5lVPu7xmpibXNH6KiCOZOdt2v7UUekRcDXwO\nmABuzcybV9p/LYXeak4VMFK/F6LbWuXSinmZVx3mVU+veqrTQl/1yCUiJoC/Bd4FXAx8ICIuXu3f\n14pzupW1evwTI/i7p/vBvOoxr3oGPU9fywz9MuCxzHw8M38FfAW4pjvLanBu3l6z+d3k+gk+8OYL\nmm43L/Oqw7zqGfQ8fS2nLc4AP152/SngzWtbzpnand86bnO6ZpYe/75Dx3l6YZGNy3KZfd25TbeP\nM/Oqx7zqWZ5XsxeiS+en9yqnVc/QI+I64OrM/KPq+oeAN2fmR87abxewC+DCCy/8vSeffLLj+3Bu\nLmlUdbO/ej5DB+aBC5Zd31RtO0Nm3pKZs5k5Oz09XesOWs2jxn1OJ2n4DaK/1lLo3wUuiogtEfFy\n4P3AXd1ZVkOr+d24z+kkDb9B9NeqZ+iZ+UJEfAQ4ROO0xS9k5kNdWxkrz+8kaZgNor9G5geLJGlc\n9WOGLkkaIha6JBXCQpekQljoklQIC12SCtHXs1wi4iTQ+Y+Knuk84CddXM4o8DGPBx9z+db6eF+X\nmW1/MrOvhb4WETHXyWk7JfExjwcfc/n69XgduUhSISx0SSrEKBX6LYNewAD4mMeDj7l8fXm8IzND\nlyStbJReoUuSVjAShR4RV0fE8Yh4LCL2DHo93RYRF0TEfRHxcEQ8FBE3VtvPjYh7IuLR6uM5g15r\nt0XEREQcjYi7q+tbIuL+6ljfXv1q5mJExFRE3BERP4iIRyLiLaUf54j4k+p5/WBEfDkiXlnacY6I\nL0TEiYh4cNm2psc1Gv6meuzfj4g3dWsdQ1/o/Xoz6gF7AfhYZl4MXA58uHqMe4DDmXkRcLi6Xpob\ngUeWXf8U8JnMfD3wPHDDQFbVO58DvpmZbwDeSOOxF3ucI2IG+GNgNjN/l8av2n4/5R3nLwJXn7Wt\n1XF9F3BR9WcX8PluLWLoC50+vBn1oGXmM5n5veryL2n8I5+h8Tj3V7vtB3YMZoW9ERGbgPcAt1bX\nA7gSuKPapajHHBG/DbwNuA0gM3+VmQsUfpxpvO/CZESsA14FPENhxzkzvw387KzNrY7rNcDfZ8O/\nA1MRcX431jEKhd7szaiLfYeLiNgMbAPuBzZk5jPVTc8CGwa0rF75LPBx4NfV9dcCC5n5QnW9tGO9\nBTgJ/F01Zro1Il5Nwcc5M+eBvwb+i0aR/xw4QtnHeUmr4
9qzThuFQh8bEfEa4GvARzPzF8tvy8bp\nSMWckhQR7wVOZOaRQa+lj9YBbwI+n5nbgP/hrPFKgcf5HBqvSLcAG4FX89LRRPH6dVxHodA7ejPq\nURcR62mU+Zcy80C1+bmlL8WqjycGtb4euAJ4X0Q8QWOMdiWN+fJU9aU5lHesnwKeysz7q+t30Cj4\nko/zO4AfZebJzDwFHKBx7Es+zktaHdeeddooFHrP34x60KrZ8W3AI5n56WU33QXsrC7vBO7s99p6\nJTP3ZuamzNxM45jem5kfBO4Drqt2K+0xPwv8OCKW3iX4KuBhCj7ONEYtl0fEq6rn+dJjLvY4L9Pq\nuN4F/GF1tsvlwM+XjWbWJjOH/g/wbuA/gR8CfzHo9fTg8b2Vxpdj3wceqP68m8ZM+TDwKPAvwLmD\nXmuPHv/bgbury78DfAd4DPgn4BWDXl+XH+ulwFx1rA8C55R+nIFPAj8AHgT+AXhFaccZ+DKN7xGc\novGV2A2tjisQNM7c+yFwjMYZQF1Zhz8pKkmFGIWRiySpAxa6JBXCQpekQljoklQIC12SCmGhS1Ih\nLHRJKoSFLkmF+H8NdIpPtXYXUgAAAABJRU5ErkJggg==\n", 77 | "text/plain": [ 78 | "" 79 | ] 80 | }, 81 | "metadata": {}, 82 | "output_type": "display_data" 83 | } 84 | ], 85 | "source": [ 86 | "x = [k for k in v.keys()]\n", 87 | "y = [a for (v, a) in v.values()]\n", 88 | "plt.scatter(x, y)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [] 104 | } 105 | ], 106 | "metadata": { 107 | "kernelspec": { 108 | "display_name": "Python 3", 109 | "language": "python", 110 | "name": "python3" 111 | }, 112 | "language_info": { 113 | "codemirror_mode": { 114 | "name": "ipython", 115 | "version": 3 116 | }, 117 | "file_extension": ".py", 118 | "mimetype": "text/x-python", 119 | "name": "python", 120 | "nbconvert_exporter": "python", 121 | "pygments_lexer": "ipython3", 122 | "version": "3.6.3" 123 | } 124 | }, 125 | "nbformat": 4, 126 | "nbformat_minor": 2 127 | } 128 | -------------------------------------------------------------------------------- /dynamic_programming/.ipynb_checkpoints/iterative policy evaluation-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 153, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "class Gridworld(object):\n", 19 | " def __init__(self, discount_factor = 1.0):\n", 20 | " self.states = range(16)\n", 21 | " self.num_states = len(self.states)\n", 22 | " self.actions = [\"up\", \"down\", \"right\", \"left\"]\n", 23 | " self.value_state = np.full(self.num_states, 0)\n", 24 | " self.discount_factor = discount_factor\n", 25 | " \n", 26 | " self.policy = np.full([self.num_states, len(self.actions), 1], 1/len(self.actions))\n", 27 | " self.reward = -1#np.full([self.num_states, len(self.actions), self.num_states], -1)\n", 28 | " self.probabilities = np.full([self.num_states, len(self.actions), self.num_states], 0)\n", 29 | " \n", 30 | " for from_state in self.states:\n", 31 | " for action in self.actions:\n", 32 | " action_index = self.actions.index(action)\n", 33 | " for to_state in self.states:\n", 34 | " p = self.probability(from_state, action, to_state)\n", 35 | " self.probabilities[from_state][action_index][to_state] = p\n", 36 | " \n", 37 | " \n", 38 | " def probability(self, from_position, action, to_position):\n", 39 | " num_rows, num_cols = self.num_states / 4, self.num_states / 4\n", 40 | " from_x, from_y = from_position//num_rows, from_position%num_cols\n", 41 | " to_x, to_y = to_position//num_rows, to_position%num_cols\n", 42 | " \n", 43 | " if abs(to_x - from_x) > 1: return 0\n", 44 | " if abs(to_y - 
from_y) > 1: return 0\n", 45 | " \n", 46 | " if from_x == 0 and from_x == to_x and from_y == to_y and action == \"up\":\n", 47 | " return 1\n", 48 | " if from_x == num_rows-1 and from_x == to_x and from_y == to_y and action == \"down\":\n", 49 | " return 1\n", 50 | " if from_y == 0 and from_x == to_x and from_y == to_y and action == \"left\":\n", 51 | " return 1\n", 52 | " if from_y == num_cols-1 and from_x == to_x and from_y == to_y and action == \"right\":\n", 53 | " return 1\n", 54 | " \n", 55 | " if from_x == to_x and from_y == to_y: return 0\n", 56 | " \n", 57 | " if from_x == to_x + 1 and from_y == to_y + 1: return 0\n", 58 | " if from_x == to_x - 1 and from_y == to_y - 1: return 0\n", 59 | " if from_x == to_x + 1 and from_y == to_y - 1: return 0\n", 60 | " if from_x == to_x - 1 and from_y == to_y + 1: return 0\n", 61 | " \n", 62 | " if from_x + 1 == to_x and action != \"down\": return 0\n", 63 | " if from_x - 1 == to_x and action != \"up\": return 0\n", 64 | " if from_y + 1 == to_y and action != \"right\": return 0\n", 65 | " if from_y - 1 == to_y and action != \"left\": return 0\n", 66 | " \n", 67 | " return 1\n", 68 | " \n", 69 | " def policy_eval(self):\n", 70 | " new_v_action_state = (self.value_state*self.discount_factor + self.reward)*self.probabilities*self.policy\n", 71 | " new_v = new_v_action_state.sum(axis=2).sum(axis=1)\n", 72 | " new_v.put([0, -1], [0, 0])\n", 73 | " return new_v\n", 74 | " " 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 154, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "array([[ 0. , -13.98945772, -19.98437823, -21.98251832],\n", 86 | " [-13.98945772, -17.98623815, -19.98448273, -19.98437823],\n", 87 | " [-19.98437823, -19.98448273, -17.98623815, -13.98945772],\n", 88 | " [-21.98251832, -19.98437823, -13.98945772, 0. 
]])" 89 | ] 90 | }, 91 | "execution_count": 154, 92 | "metadata": {}, 93 | "output_type": "execute_result" 94 | } 95 | ], 96 | "source": [ 97 | "theta = 0.001\n", 98 | "grid = Gridworld()\n", 99 | "\n", 100 | "while True:\n", 101 | " delta = 0\n", 102 | " v = grid.value_state\n", 103 | " grid.value_state = grid.policy_eval()\n", 104 | " delta = max(delta, (v - grid.value_state).max())\n", 105 | " if delta < theta:\n", 106 | " break\n", 107 | " \n", 108 | "grid.value_state.reshape(4, 4)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [] 117 | } 118 | ], 119 | "metadata": { 120 | "kernelspec": { 121 | "display_name": "Python 3", 122 | "language": "python", 123 | "name": "python3" 124 | }, 125 | "language_info": { 126 | "codemirror_mode": { 127 | "name": "ipython", 128 | "version": 3 129 | }, 130 | "file_extension": ".py", 131 | "mimetype": "text/x-python", 132 | "name": "python", 133 | "nbconvert_exporter": "python", 134 | "pygments_lexer": "ipython3", 135 | "version": "3.6.3" 136 | } 137 | }, 138 | "nbformat": 4, 139 | "nbformat_minor": 2 140 | } 141 | -------------------------------------------------------------------------------- /dynamic_programming/.ipynb_checkpoints/policy iteration-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "\n", 17 | "1) Initialization\n", 18 | "V = np.random.uniform(A*B)\n", 19 | "policy(S) = random\n", 20 | "\n", 21 | "2) Policy evaluation\n", 22 | "theta = 0.001\n", 23 | "while True:\n", 24 | " delta = 0\n", 25 | " for state in states:\n", 26 | " v = V[state]\n", 27 | " acum = 0\n", 28 | " for new_state in states:\n", 29 | " acum += prob(state, policy(state), new_state) * (reward + gamma * V[new_state])\n", 30 | " \n", 31 | " delta = max(delta, abs(v - V[state]))\n", 32 | " \n", 33 | " if deleta < theta: break\n", 34 | "\n", 35 | "3) Policy improvement\n", 36 | "policy_stable = true\n", 37 | "for state in states:\n", 38 | " old_action = policy(state)\n", 39 | " \n", 40 | " for new_state in states:\n", 41 | " action_state_values = prob(state, policy(state), new_state) * (reward + gamma * V[new_state])\n", 42 | " max_action = argmax_Action(action_state_values)\n", 43 | " policy(state) = max_action\n", 44 | " \n", 45 | " if old_action != policy(state): policy_stable = false & goto 2)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 91, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "class Gridworld(object):\n", 55 | " def __init__(self):\n", 56 | " self.states = range(16)\n", 57 | " self.num_states = len(self.states)\n", 58 | " self.actions = [\"up\", \"down\", \"right\", \"left\"]\n", 59 | " \n", 60 | " #state_policy = [1, 0, 0, 0]\n", 61 | " #self.policy = np.zeros([self.num_states, len(self.actions)])\n", 62 | " #for p_index in range(self.policy.shape[0]):\n", 63 | " # self.policy[p_index] = np.random.permutation(state_policy)\n", 64 | " \n", 65 | " #self.policy = self.policy.reshape(self.num_states, len(self.actions), 1)\n", 66 | " \n", 67 | " self.reward = np.full([self.num_states, len(self.actions), self.num_states], -1)\n", 68 | " self.probabilities = np.full([self.num_states, len(self.actions), self.num_states], 0)\n", 69 | 
" \n", 70 | " for from_state in self.states:\n", 71 | " for action in self.actions:\n", 72 | " action_index = self.actions.index(action)\n", 73 | " for to_state in self.states:\n", 74 | " p = self.probability(from_state, action, to_state)\n", 75 | " self.probabilities[from_state][action_index][to_state] = p\n", 76 | " \n", 77 | " \n", 78 | " def probability(self, from_position, action, to_position):\n", 79 | " num_rows, num_cols = self.num_states / 4, self.num_states / 4\n", 80 | " from_x, from_y = from_position//num_rows, from_position%num_cols\n", 81 | " to_x, to_y = to_position//num_rows, to_position%num_cols\n", 82 | " \n", 83 | " if abs(to_x - from_x) > 1: return 0\n", 84 | " if abs(to_y - from_y) > 1: return 0\n", 85 | " \n", 86 | " if from_x == 0 and from_x == to_x and from_y == to_y and action == \"up\":\n", 87 | " return 1\n", 88 | " if from_x == num_rows-1 and from_x == to_x and from_y == to_y and action == \"down\":\n", 89 | " return 1\n", 90 | " if from_y == 0 and from_x == to_x and from_y == to_y and action == \"left\":\n", 91 | " return 1\n", 92 | " if from_y == num_cols-1 and from_x == to_x and from_y == to_y and action == \"right\":\n", 93 | " return 1\n", 94 | " \n", 95 | " if from_x == to_x and from_y == to_y: return 0\n", 96 | " \n", 97 | " if from_x == to_x + 1 and from_y == to_y + 1: return 0\n", 98 | " if from_x == to_x - 1 and from_y == to_y - 1: return 0\n", 99 | " if from_x == to_x + 1 and from_y == to_y - 1: return 0\n", 100 | " if from_x == to_x - 1 and from_y == to_y + 1: return 0\n", 101 | " \n", 102 | " if from_x + 1 == to_x and action != \"down\": return 0\n", 103 | " if from_x - 1 == to_x and action != \"up\": return 0\n", 104 | " if from_y + 1 == to_y and action != \"right\": return 0\n", 105 | " if from_y - 1 == to_y and action != \"left\": return 0\n", 106 | " \n", 107 | " return 1\n", 108 | " \n", 109 | " def policy_iteration(self):\n", 110 | " pass\n", 111 | "\n", 112 | "def policy_eval(policy, value_state, grid, discount_factor=1.0, theta=0.00001):\n", 113 | " while True:\n", 114 | " delta = 0\n", 115 | " old_value_state = value_state\n", 116 | " \n", 117 | " new_v_action_state = (value_state*discount_factor + grid.reward)*grid.probabilities*policy\n", 118 | " new_v = new_v_action_state.sum(axis=2).sum(axis=1)\n", 119 | " new_v.put([0, -1], [0, 0])\n", 120 | " value_state = new_v\n", 121 | " \n", 122 | " delta = max(delta, (old_value_state - value_state).max())\n", 123 | " if delta < theta:\n", 124 | " print(i)\n", 125 | " break\n", 126 | "\n", 127 | " return value_state " 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 179, 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "name": "stdout", 137 | "output_type": "stream", 138 | "text": [ 139 | "0\n", 140 | "try again\n", 141 | "1\n", 142 | "try again\n", 143 | "2\n", 144 | "try again\n", 145 | "3\n" 146 | ] 147 | }, 148 | { 149 | "data": { 150 | "text/plain": [ 151 | "(array([[ 1., 0., 0., 0.],\n", 152 | " [ 0., 0., 0., 1.],\n", 153 | " [ 0., 0., 0., 1.],\n", 154 | " [ 0., 1., 0., 0.],\n", 155 | " [ 1., 0., 0., 0.],\n", 156 | " [ 1., 0., 0., 0.],\n", 157 | " [ 1., 0., 0., 0.],\n", 158 | " [ 0., 1., 0., 0.],\n", 159 | " [ 1., 0., 0., 0.],\n", 160 | " [ 1., 0., 0., 0.],\n", 161 | " [ 0., 1., 0., 0.],\n", 162 | " [ 0., 1., 0., 0.],\n", 163 | " [ 1., 0., 0., 0.],\n", 164 | " [ 0., 0., 1., 0.],\n", 165 | " [ 0., 0., 1., 0.],\n", 166 | " [ 0., 1., 0., 0.]]), array([[ 0., -1., -2., -3.],\n", 167 | " [-1., -2., -3., -2.],\n", 168 | " [-2., -3., -2., -1.],\n", 169 | " [-3., 
-2., -1., 0.]]))" 170 | ] 171 | }, 172 | "execution_count": 179, 173 | "metadata": {}, 174 | "output_type": "execute_result" 175 | } 176 | ], 177 | "source": [ 178 | "grid = Gridworld()\n", 179 | "\n", 180 | "discount_factor = 1\n", 181 | "theta = 0.0001\n", 182 | "value_state = np.full(grid.num_states, 0)\n", 183 | "policy = np.full([grid.num_states, len(grid.actions), 1], 1/len(grid.actions))\n", 184 | "\n", 185 | "def policy_iteration(policy, value_state, grid):\n", 186 | " value_state = policy_eval(policy, value_state, grid, discount_factor=discount_factor, theta=theta)\n", 187 | " old_policy = policy\n", 188 | " \n", 189 | " new_policy = (grid.probabilities * (grid.reward + discount_factor*value_state)).sum(axis=2).argmax(axis=1)\n", 190 | " \n", 191 | " new_policy = new_policy.reshape(new_policy.shape[0], 1)\n", 192 | " return new_policy, value_state\n", 193 | "\n", 194 | "for i in range(20):\n", 195 | " new_policy_argmax, value_state = policy_iteration(policy, value_state, grid)\n", 196 | " new_policy = np.zeros([grid.num_states, len(grid.actions)])\n", 197 | " new_policy[range(len(policy)), new_policy_argmax.reshape(new_policy_argmax.shape[0],)] = 1\n", 198 | " new_policy = new_policy.reshape([grid.num_states, len(grid.actions), 1])\n", 199 | "\n", 200 | " diff = np.abs(policy - new_policy).max(axis=1)\n", 201 | " policy = new_policy\n", 202 | " if diff[diff != 0].shape[0] != 0:\n", 203 | " print(\"try again\")\n", 204 | " else:\n", 205 | " break\n", 206 | "\n", 207 | "policy.reshape(16,4), value_state.reshape(4,4)\n", 208 | " \n", 209 | "\n", 210 | "#np.asmatrix(p.reshape(16,4))[:,[1,2]]" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [] 219 | } 220 | ], 221 | "metadata": { 222 | "kernelspec": { 223 | "display_name": "Python 3", 224 | "language": "python", 225 | "name": "python3" 226 | }, 227 | "language_info": { 228 | "codemirror_mode": { 229 | "name": "ipython", 230 | "version": 3 231 | }, 232 | "file_extension": ".py", 233 | "mimetype": "text/x-python", 234 | "name": "python", 235 | "nbconvert_exporter": "python", 236 | "pygments_lexer": "ipython3", 237 | "version": "3.6.3" 238 | } 239 | }, 240 | "nbformat": 4, 241 | "nbformat_minor": 2 242 | } 243 | -------------------------------------------------------------------------------- /dynamic_programming/gambler.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 292, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from collections import defaultdict\n", 10 | "v = defaultdict(lambda: (0, 0))\n", 11 | "v[100] = (1, 0)" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 293, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "states = list(range(1, 100))\n", 21 | "\n", 22 | "def actions(state): return list(range(1, min(state, 100 - state) + 1))\n", 23 | "\n", 24 | "for i in range(1000):\n", 25 | " for state_idx, state in enumerate(states):\n", 26 | " max_value = -1\n", 27 | " max_action = -1\n", 28 | " old_value, old_action = v[state]\n", 29 | "\n", 30 | " for action_idx, action in enumerate(actions(state)):\n", 31 | " value = 0.6 * v[state - action][0] + 0.4 * v[state + action][0]\n", 32 | "\n", 33 | " if value > max_value+0.000001:\n", 34 | " max_value = value\n", 35 | " max_action = action\n", 36 | "\n", 37 | " v[state] = (max_value, max_action)\n", 38 | " " 39 | ] 40 | }, 41 | { 42 | 
"cell_type": "code", 43 | "execution_count": 294, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "import matplotlib.pyplot as plt" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 295, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "%matplotlib inline" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 296, 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "data": { 66 | "text/plain": [ 67 | "" 68 | ] 69 | }, 70 | "execution_count": 296, 71 | "metadata": {}, 72 | "output_type": "execute_result" 73 | }, 74 | { 75 | "data": { 76 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAE5JJREFUeJzt3WuMXHd5x/Hvw9rAAlU3ISsrXifYFZFR1IiYrkJQEKIJ\nxeEiYkVRBELUL1L5DaihRaZ2WwkhVUqQKy6VKqQoobgVhdBgnChIuKkThFqpgTVOcS64CSEp2Vxs\nIAu0XRXHPH0xZ8PamdmZszvX/3w/krUzZ856/vM7459nnz27E5mJJGn0vWzQC5AkdYeFLkmFsNAl\nqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSrEun7e2XnnnZebN2/u511K0sg7cuTITzJzut1+\nfS30zZs3Mzc318+7lKSRFxFPdrKfIxdJKoSFLkmFsNAlqRAWuiQVwkKXpEJ0dJZLRDwB/BI4DbyQ\nmbMRcS5wO7AZeAK4PjOf780ypd44eHSefYeO8/TCIhunJtm9fSs7ts0MelnSqtR5hf77mXlpZs5W\n1/cAhzPzIuBwdV0aGQePzrP3wDHmFxZJYH5hkb0HjnHw6PyglyatylpGLtcA+6vL+4Eda1+O1D/7\nDh1n8dTpM7YtnjrNvkPHB7QiaW06LfQE/jkijkTErmrbhsx8prr8LLCh2SdGxK6ImIuIuZMnT65x\nuVL3PL2wWGu7NOw6LfS3ZuabgHcBH46Ity2/MRvvNN303aYz85bMnM3M2enptj+5KvXNxqnJWtul\nYddRoWfmfPXxBPB14DLguYg4H6D6eKJXi5R6Yff2rUyunzhj2+T6CXZv3zqgFUlr07bQI+LVEfFb\nS5eBdwIPAncBO6vddgJ39mqRUi/s2DbDTddewszUJAHMTE1y07WXeJaLRlYnpy1uAL4eEUv7/2Nm\nfjMivgt8NSJuAJ4Eru/dMqXe2LFtxgJXMdoWemY+DryxyfafAlf1YlGSpPr8SVFJKoSFLkmFsNAl\nqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIK\nYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAW\nuiQVwkKXpEJY6JJUiI4LPSImIuJoRNxdXd8SEfdHxGMRcXtEvLx3y5QktVPnFfqNwCPLrn8K+Exm\nvh54HrihmwuTJNXTUaFHxCbgPcCt1fUArgTuqHbZD+zoxQIlSZ3p9BX6Z4GPA7+urr8WWMjMF6rr\nTwEzXV6bJKmGtoUeEe8FTmTmkdXcQUTsioi5iJg7efLkav4KSVIHOnmFfgXwvoh4AvgKjVHL54Cp\niFhX7bMJmG/2yZl5S2bOZubs9PR0F5YsSWqmbaFn5t7M3JSZm4H3A/dm5geB+4Drqt12Anf2bJWS\npLbWch76nwF/GhGP0Zip39adJUmSVmNd+11+IzO/BXyruvw4cFn3lyRJWg1/UlSSCmGhS1IhLHRJ\nKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RC\nWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSF\nLkmFsNAlqRAWuiQVwkKXpEK0LfSIeGVEfCci/iMiHoqIT1bbt0TE/RHxWETcHhEv7/1yJUmtdPIK\n/f+AKzPzjcClwNURcTnwKeAzmfl64Hnght4tU5LUTttCz4b/rq6ur/4kcCVwR7V9P7CjJyuUJHWk\noxl6RExExAPACeAe4IfAQma+UO3yFDDT4nN3RcRcRMydPHmyG2uWJDXRUaFn5unMvBTYBFwGvKHT\nO8jMWzJzNjNnp6enV7lMSVI7tc5yycwF4D7gLcBURKyrbtoEzHd5bZKkGjo5y2U6Iqaqy5PAHwCP\n0Cj266rddgJ39mqRkqT21rXfhfOB/RExQeM/gK9m5t0R8TDwlYj4K+AocFsP1ylJaqNtoWfm94Ft\nTbY/TmOeLkkaAv6kqCQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RC\nWOiSVIhOfjmXRsjBo/PsO3ScpxcW2Tg1ye7tW9mxrel7j0i1+fwabhZ6QQ4enWfvgWMsnjoNwPzC\nInsPHAPwH53WzOfX8HPkUpB9h46/+I9tyeKp0+w7dHxAK1JJfH4NPwu9IE8vLNbaLtXh82v4WegF\n2Tg1WWu7VIfPr+FnoRdk9/atTK6fOGPb5PoJdm/fOqAVqSQ+v4af3xQtyNI3pjwLQb3g82v4RWb2\n7c5mZ2dzbm6ub/cnSSWIiCOZOdtuP0cuklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKX\npEJY6JJUCAtdkgphoUtSIdoWekRcEBH3RcTDEfFQRNxYbT83Iu6JiEerj+f0frmSpFY6eYX+AvCx\nzLwYuBz4cERcDOwBDmfmRcDh6rokaUDaFnpmPpOZ36su/xJ4BJgBrgH2V7vtB3b0apGSpPZqzdAj\nYjOwDbgf2JCZz1Q3PQts6OrKJEm1dFzoEfEa4GvARzPzF8tvy8YvVW/6i9UjYldEzEXE3MmTJ9e0\nWElSax0VekSsp1HmX8rMA9Xm5yLi/Or284ETzT43M2/JzNnMnJ2enu7GmiVJTXRylksAtwGPZOan\nl910F7CzurwTuLP7y5MkdaqT9xS9AvgQcCwiHqi2/TlwM/DViLgBeBK4vjdLlCR1om2hZ+a/AtHi\n5qu6uxxJ0mr5k6KSVAgLXZIKYaFLUiEsdEkqhIUuSYW
w0CWpEBa6JBXCQpekQljoklQIC12SCmGh\nS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrok\nFWLdoBcwig4enWffoeM8vbDIb0+uJwIW/vcUG6cm2b19Kzu2zQx6iUPFvOoxr3rM6zcs9JoOHp1n\n74FjLJ46DcDC4qkXb5tfWGTvgWMAY/UkWol51WNe9ZjXmRy51LTv0PEXnzzNLJ46zb5Dx/u4ouFm\nXvWYVz3mdSYLvaanFxa7ss+4MK96zKse8zqThd6hg0fnueLme8kO9k3gipvv5eDR+V4va2iZVz3m\nVY95NecMvQNnz+k6MY7zuyXmVY951WNerfkKvQMrzemmJtdzzqvWN71t3OZ3S8yrHvOqx7xa8xV6\nB1rN4AJ44BPvBGDLnm80/fJvnOZ3S8yrHvOqx7xaa/sKPSK+EBEnIuLBZdvOjYh7IuLR6uM5vV3m\nYLSb022cmmx6eblxmt+ZVz3mVY95tdfJyOWLwNVnbdsDHM7Mi4DD1fWiLM3p5lv8jz65foLd27e+\neH339q1Mrp9ouu/S/K7UJxGYV13mVY95daZtoWfmt4GfnbX5GmB/dXk/sKPL6xq4leZ0M1OT3HTt\nJWd8c2XHthluuvYSZlq8Mih9fmde9ZhXPebVmdXO0Ddk5jPV5WeBDa12jIhdwC6ACy+8cJV3138r\nzen+bc+VTW/bsW2GHdtmxnJ+Z171mFc95tWZNZ/lkpkJrU8HzcxbMnM2M2enp6fXenc9tTSj27Ln\nG7wsouk+rWZznexT4vyuzlyzFfP6DfM6k3nVs9pCfy4izgeoPp7o3pIGY/mMLoHT+dKn0NlzulbG\nZX5Xd67Zink1mNeZzKu+1Rb6XcDO6vJO4M7uLGdwWs3oJiIIms/pWhmX+V3duWYr5mVezZhXfW1n\n6BHxZeDtwHkR8RTwCeBm4KsRcQPwJHB9LxfZD63mab/O5Ec3v6f23zcO87vVzDVbMS/zOpt51dfJ\nWS4fyMzzM3N9Zm7KzNsy86eZeVVmXpSZ78jMs8+CGRndmNGtpLT5Xbe+z9BKaXlBb59j5lVPiXkt\nN9Y/+t+tGd1KSprfdfP7DK2UlBf0/jlmXvWUltfZxrrQuzWjW0lJ87tufp+hlZLygt4/x8yrntLy\nOttY/y6Xbs7oVlLK/K7b32dopZS8oD/PMfOqp6S8zjaWr9B7PTdvZVTnd+ZV3yAyM696RjmvVsau\n0PsxN29lFOd35lXfoDIzr3pGNa+VjF2h92Nu3soozu/Mq75BZWZe9YxqXisZuxl6v+bmrYza/M68\n6htkZuZVzyjmtZKxeYU+qDlwK8M+vzOvenp9fn5dw54XDNdzbBTy6sRYFPog58CtDPP8zrzq6cf5\n+XUNc14wfM+xYc+rU2NR6IOcA7cyzPM786qnH+fn1zXMecHwPceGPa9OjcUMfdBz4FaGdX5nXvX0\n6/z8uoY1r5Xue5DPsWHOq1NFv0IfphndSoZlfmde9ZhXfaOQ2TDlVVexhT5sM7qVDMP8zrzqMa/6\nRiWzYclrNYot9GGb0a1kGOZ35lWPedU3KpkNS16rUewMfRhndCsZ9PzOvOoxr/pGKbNhyGs1inuF\nPgozupX0e35nXvWYVz3Ddn5+XaM2Ty+q0EdlRreSfs7vzKse86pnGM/Pr2vU5ulFFfqozOhW0s/5\nnXnVY171DOP5+XWN2jy9qBn6KM3oVtKv+Z151WNe9Qzr+fl1jdI8vYhX6KM+12ylV/M786rHvOox\nr8EZ+UIvYa7ZSi/md+ZlXkvMq55RmKePfKGXMNdspRfzO/MyryXmVc8ozNNHfoZeylyzlW7P78zL\nvJYzr3qGfZ4+sq/QS53TtbLW+Z15NZhXc+ZVz7DO00ey0Eue07WylvmdeZ3JvF7KvOoZ1nn6SBZ6\nyXO6VtYyvzOvlzKvM5lXPcM6Tx/JGXrpc7pWVju/My/z6oR51TOM8/SReoU+bnO6Vjqd35lXg3nV\nY171DNM8fWQKfRzndK10Mr/7y4PHzKtiXvWYVz3DNE+PbPILc3pldnY25+bmVvW5V9x8b8snz8zU\nJLu3by1uTreSg0fn2XfoeMtMJiKa/jIkMK9mzOtM5lVPu7xmpibXNH6KiCOZOdt2v7UUekRcDXwO\nmABuzcybV9p/LYXeak4VMFK/F6LbWuXSinmZVx3mVU+veqrTQl/1yCUiJoC/Bd4FXAx8ICIuXu3f\n14pzupW1evwTI/i7p/vBvOoxr3oGPU9fywz9MuCxzHw8M38FfAW4pjvLanBu3l6z+d3k+gk+8OYL\nmm43L/Oqw7zqGfQ8fS2nLc4AP152/SngzWtbzpnand86bnO6ZpYe/75Dx3l6YZGNy3KZfd25TbeP\nM/Oqx7zqWZ5XsxeiS+en9yqnVc/QI+I64OrM/KPq+oeAN2fmR87abxewC+DCCy/8vSeffLLj+3Bu\nLmlUdbO/ej5DB+aBC5Zd31RtO0Nm3pKZs5k5Oz09XesOWs2jxn1OJ2n4DaK/1lLo3wUuiogtEfFy\n4P3AXd1ZVkOr+d24z+kkDb9B9NeqZ+iZ+UJEfAQ4ROO0xS9k5kNdWxkrz+8kaZgNor9G5geLJGlc\n9WOGLkkaIha6JBXCQpekQljoklQIC12SCtHXs1wi4iTQ+Y+Knuk84CddXM4o8DGPBx9z+db6eF+X\nmW1/MrOvhb4WETHXyWk7JfExjwcfc/n69XgduUhSISx0SSrEKBX6LYNewAD4mMeDj7l8fXm8IzND\nlyStbJReoUuSVjAShR4RV0fE8Yh4LCL2DHo93RYRF0TEfRHxcEQ8FBE3VtvPjYh7IuLR6uM5g15r\nt0XEREQcjYi7q+tbIuL+6ljfXv1q5mJExFRE3BERP4iIRyLiLaUf54j4k+p5/WBEfDkiXlnacY6I\nL0TEiYh4cNm2psc1Gv6meuzfj4g3dWsdQ1/o/Xoz6gF7AfhYZl4MXA58uHqMe4DDmXkRcLi6Xpob\ngUeWXf8U8JnMfD3wPHDDQFbVO58DvpmZbwDeSOOxF3ucI2IG+GNgNjN/l8av2n4/5R3nLwJXn7Wt\n1XF9F3BR9WcX8PluLWLoC50+vBn1oGXmM5n5veryL2n8I5+h8Tj3V7vtB3YMZoW9ERGbgPcAt1bX\nA7gSuKPapajHHBG/DbwNuA0gM3+VmQsUfpxpvO/CZESsA14FPENhxzkzvw387KzNrY7rNcDfZ8O/\nA1MRcX431jEKhd7szaiLfYeLiNgMbAPuBzZk5jPVTc8CGwa0rF75LPBx4NfV9dcCC5n5QnW9tGO9\nBTgJ/F01Zro1Il5Nwcc5M+eBvwb+i0aR/xw4QtnHeUmr49qzThuFQh8bEfEa4GvARzPzF8tvy8bp\nSMWckhQR7wVOZOaRQa+lj9YBbwI+n5nbgP/hrPFKgcf5HBqvSLcAG4FX89LRRPH6dVxHodA7ejPq\nURcR62mU+Z
cy80C1+bmlL8WqjycGtb4euAJ4X0Q8QWOMdiWN+fJU9aU5lHesnwKeysz7q+t30Cj4\nko/zO4AfZebJzDwFHKBx7Es+zktaHdeeddooFHrP34x60KrZ8W3AI5n56WU33QXsrC7vBO7s99p6\nJTP3ZuamzNxM45jem5kfBO4Drqt2K+0xPwv8OCKW3iX4KuBhCj7ONEYtl0fEq6rn+dJjLvY4L9Pq\nuN4F/GF1tsvlwM+XjWbWJjOH/g/wbuA/gR8CfzHo9fTg8b2Vxpdj3wceqP68m8ZM+TDwKPAvwLmD\nXmuPHv/bgbury78DfAd4DPgn4BWDXl+XH+ulwFx1rA8C55R+nIFPAj8AHgT+AXhFaccZ+DKN7xGc\novGV2A2tjisQNM7c+yFwjMYZQF1Zhz8pKkmFGIWRiySpAxa6JBXCQpekQljoklQIC12SCmGhS1Ih\nLHRJKoSFLkmF+H8NdIpPtXYXUgAAAABJRU5ErkJggg==\n", 77 | "text/plain": [ 78 | "" 79 | ] 80 | }, 81 | "metadata": {}, 82 | "output_type": "display_data" 83 | } 84 | ], 85 | "source": [ 86 | "x = [k for k in v.keys()]\n", 87 | "y = [a for (v, a) in v.values()]\n", 88 | "plt.scatter(x, y)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [] 104 | } 105 | ], 106 | "metadata": { 107 | "kernelspec": { 108 | "display_name": "Python 3", 109 | "language": "python", 110 | "name": "python3" 111 | }, 112 | "language_info": { 113 | "codemirror_mode": { 114 | "name": "ipython", 115 | "version": 3 116 | }, 117 | "file_extension": ".py", 118 | "mimetype": "text/x-python", 119 | "name": "python", 120 | "nbconvert_exporter": "python", 121 | "pygments_lexer": "ipython3", 122 | "version": "3.6.3" 123 | } 124 | }, 125 | "nbformat": 4, 126 | "nbformat_minor": 2 127 | } 128 | -------------------------------------------------------------------------------- /dynamic_programming/iterative policy evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 153, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "class Gridworld(object):\n", 19 | " def __init__(self, discount_factor = 1.0):\n", 20 | " self.states = range(16)\n", 21 | " self.num_states = len(self.states)\n", 22 | " self.actions = [\"up\", \"down\", \"right\", \"left\"]\n", 23 | " self.value_state = np.full(self.num_states, 0)\n", 24 | " self.discount_factor = discount_factor\n", 25 | " \n", 26 | " self.policy = np.full([self.num_states, len(self.actions), 1], 1/len(self.actions))\n", 27 | " self.reward = -1#np.full([self.num_states, len(self.actions), self.num_states], -1)\n", 28 | " self.probabilities = np.full([self.num_states, len(self.actions), self.num_states], 0)\n", 29 | " \n", 30 | " for from_state in self.states:\n", 31 | " for action in self.actions:\n", 32 | " action_index = self.actions.index(action)\n", 33 | " for to_state in self.states:\n", 34 | " p = self.probability(from_state, action, to_state)\n", 35 | " self.probabilities[from_state][action_index][to_state] = p\n", 36 | " \n", 37 | " \n", 38 | " def probability(self, from_position, action, to_position):\n", 39 | " num_rows, num_cols = self.num_states / 4, self.num_states / 4\n", 40 | " from_x, from_y = from_position//num_rows, from_position%num_cols\n", 41 | " to_x, to_y = to_position//num_rows, to_position%num_cols\n", 42 | " \n", 43 | " if abs(to_x - from_x) > 1: return 0\n", 44 | " if abs(to_y - from_y) > 1: return 0\n", 45 | " \n", 46 | " if from_x == 0 and from_x == to_x and from_y == to_y and action == \"up\":\n", 47 | " return 1\n", 48 | " 
if from_x == num_rows-1 and from_x == to_x and from_y == to_y and action == \"down\":\n", 49 | " return 1\n", 50 | " if from_y == 0 and from_x == to_x and from_y == to_y and action == \"left\":\n", 51 | " return 1\n", 52 | " if from_y == num_cols-1 and from_x == to_x and from_y == to_y and action == \"right\":\n", 53 | " return 1\n", 54 | " \n", 55 | " if from_x == to_x and from_y == to_y: return 0\n", 56 | " \n", 57 | " if from_x == to_x + 1 and from_y == to_y + 1: return 0\n", 58 | " if from_x == to_x - 1 and from_y == to_y - 1: return 0\n", 59 | " if from_x == to_x + 1 and from_y == to_y - 1: return 0\n", 60 | " if from_x == to_x - 1 and from_y == to_y + 1: return 0\n", 61 | " \n", 62 | " if from_x + 1 == to_x and action != \"down\": return 0\n", 63 | " if from_x - 1 == to_x and action != \"up\": return 0\n", 64 | " if from_y + 1 == to_y and action != \"right\": return 0\n", 65 | " if from_y - 1 == to_y and action != \"left\": return 0\n", 66 | " \n", 67 | " return 1\n", 68 | " \n", 69 | " def policy_eval(self):\n", 70 | " new_v_action_state = (self.value_state*self.discount_factor + self.reward)*self.probabilities*self.policy\n", 71 | " new_v = new_v_action_state.sum(axis=2).sum(axis=1)\n", 72 | " new_v.put([0, -1], [0, 0])\n", 73 | " return new_v\n", 74 | " " 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 154, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "array([[ 0. , -13.98945772, -19.98437823, -21.98251832],\n", 86 | " [-13.98945772, -17.98623815, -19.98448273, -19.98437823],\n", 87 | " [-19.98437823, -19.98448273, -17.98623815, -13.98945772],\n", 88 | " [-21.98251832, -19.98437823, -13.98945772, 0. ]])" 89 | ] 90 | }, 91 | "execution_count": 154, 92 | "metadata": {}, 93 | "output_type": "execute_result" 94 | } 95 | ], 96 | "source": [ 97 | "theta = 0.001\n", 98 | "grid = Gridworld()\n", 99 | "\n", 100 | "while True:\n", 101 | " delta = 0\n", 102 | " v = grid.value_state\n", 103 | " grid.value_state = grid.policy_eval()\n", 104 | " delta = max(delta, (v - grid.value_state).max())\n", 105 | " if delta < theta:\n", 106 | " break\n", 107 | " \n", 108 | "grid.value_state.reshape(4, 4)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [] 117 | } 118 | ], 119 | "metadata": { 120 | "kernelspec": { 121 | "display_name": "Python 3", 122 | "language": "python", 123 | "name": "python3" 124 | }, 125 | "language_info": { 126 | "codemirror_mode": { 127 | "name": "ipython", 128 | "version": 3 129 | }, 130 | "file_extension": ".py", 131 | "mimetype": "text/x-python", 132 | "name": "python", 133 | "nbconvert_exporter": "python", 134 | "pygments_lexer": "ipython3", 135 | "version": "3.6.3" 136 | } 137 | }, 138 | "nbformat": 4, 139 | "nbformat_minor": 2 140 | } 141 | -------------------------------------------------------------------------------- /dynamic_programming/policy iteration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "\n", 17 | "1) Initialization\n", 18 | "V = np.random.uniform(A*B)\n", 19 | "policy(S) = random\n", 20 | "\n", 21 | "2) Policy evaluation\n", 22 | "theta = 0.001\n", 23 | "while True:\n", 24 | " delta = 0\n", 25 | " for state in 
states:\n", 26 | " v = V[state]\n", 27 | " acum = 0\n", 28 | " for new_state in states:\n", 29 | " acum += prob(state, policy(state), new_state) * (reward + gamma * V[new_state])\n", 30 | " \n", 31 | " delta = max(delta, abs(v - V[state]))\n", 32 | " \n", 33 | " if deleta < theta: break\n", 34 | "\n", 35 | "3) Policy improvement\n", 36 | "policy_stable = true\n", 37 | "for state in states:\n", 38 | " old_action = policy(state)\n", 39 | " \n", 40 | " for new_state in states:\n", 41 | " action_state_values = prob(state, policy(state), new_state) * (reward + gamma * V[new_state])\n", 42 | " max_action = argmax_Action(action_state_values)\n", 43 | " policy(state) = max_action\n", 44 | " \n", 45 | " if old_action != policy(state): policy_stable = false & goto 2)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 181, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "class Gridworld(object):\n", 55 | " def __init__(self):\n", 56 | " self.states = range(16)\n", 57 | " self.num_states = len(self.states)\n", 58 | " self.actions = [\"up\", \"down\", \"right\", \"left\"]\n", 59 | " \n", 60 | " #state_policy = [1, 0, 0, 0]\n", 61 | " #self.policy = np.zeros([self.num_states, len(self.actions)])\n", 62 | " #for p_index in range(self.policy.shape[0]):\n", 63 | " # self.policy[p_index] = np.random.permutation(state_policy)\n", 64 | " \n", 65 | " #self.policy = self.policy.reshape(self.num_states, len(self.actions), 1)\n", 66 | " \n", 67 | " self.reward = np.full([self.num_states, len(self.actions), self.num_states], -1)\n", 68 | " self.probabilities = np.full([self.num_states, len(self.actions), self.num_states], 0)\n", 69 | " \n", 70 | " for from_state in self.states:\n", 71 | " for action in self.actions:\n", 72 | " action_index = self.actions.index(action)\n", 73 | " for to_state in self.states:\n", 74 | " p = self.probability(from_state, action, to_state)\n", 75 | " self.probabilities[from_state][action_index][to_state] = p\n", 76 | " \n", 77 | " \n", 78 | " def probability(self, from_position, action, to_position):\n", 79 | " num_rows, num_cols = self.num_states / 4, self.num_states / 4\n", 80 | " from_x, from_y = from_position//num_rows, from_position%num_cols\n", 81 | " to_x, to_y = to_position//num_rows, to_position%num_cols\n", 82 | " \n", 83 | " if abs(to_x - from_x) > 1: return 0\n", 84 | " if abs(to_y - from_y) > 1: return 0\n", 85 | " \n", 86 | " if from_x == 0 and from_x == to_x and from_y == to_y and action == \"up\":\n", 87 | " return 1\n", 88 | " if from_x == num_rows-1 and from_x == to_x and from_y == to_y and action == \"down\":\n", 89 | " return 1\n", 90 | " if from_y == 0 and from_x == to_x and from_y == to_y and action == \"left\":\n", 91 | " return 1\n", 92 | " if from_y == num_cols-1 and from_x == to_x and from_y == to_y and action == \"right\":\n", 93 | " return 1\n", 94 | " \n", 95 | " if from_x == to_x and from_y == to_y: return 0\n", 96 | " \n", 97 | " if from_x == to_x + 1 and from_y == to_y + 1: return 0\n", 98 | " if from_x == to_x - 1 and from_y == to_y - 1: return 0\n", 99 | " if from_x == to_x + 1 and from_y == to_y - 1: return 0\n", 100 | " if from_x == to_x - 1 and from_y == to_y + 1: return 0\n", 101 | " \n", 102 | " if from_x + 1 == to_x and action != \"down\": return 0\n", 103 | " if from_x - 1 == to_x and action != \"up\": return 0\n", 104 | " if from_y + 1 == to_y and action != \"right\": return 0\n", 105 | " if from_y - 1 == to_y and action != \"left\": return 0\n", 106 | " \n", 107 | " return 1\n", 108 | " \n", 109 | " def 
policy_iteration(self):\n", 110 | " pass\n", 111 | "\n", 112 | "def policy_eval(policy, value_state, grid, discount_factor=1.0, theta=0.00001):\n", 113 | " while True:\n", 114 | " delta = 0\n", 115 | " old_value_state = value_state\n", 116 | " \n", 117 | " new_v_action_state = (value_state*discount_factor + grid.reward)*grid.probabilities*policy\n", 118 | " new_v = new_v_action_state.sum(axis=2).sum(axis=1)\n", 119 | " new_v.put([0, -1], [0, 0])\n", 120 | " value_state = new_v\n", 121 | " \n", 122 | " delta = max(delta, (old_value_state - value_state).max())\n", 123 | " if delta < theta:\n", 124 | " print(i)\n", 125 | " break\n", 126 | "\n", 127 | " return value_state " 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 186, 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "name": "stdout", 137 | "output_type": "stream", 138 | "text": [ 139 | "0\n", 140 | "try again\n", 141 | "1\n", 142 | "try again\n", 143 | "2\n", 144 | "try again\n", 145 | "3\n" 146 | ] 147 | }, 148 | { 149 | "data": { 150 | "text/plain": [ 151 | "(array([[0, 3, 3, 1],\n", 152 | " [0, 0, 0, 1],\n", 153 | " [0, 0, 1, 1],\n", 154 | " [0, 2, 2, 1]]), array([[ 0., -1., -2., -3.],\n", 155 | " [-1., -2., -3., -2.],\n", 156 | " [-2., -3., -2., -1.],\n", 157 | " [-3., -2., -1., 0.]]))" 158 | ] 159 | }, 160 | "execution_count": 186, 161 | "metadata": {}, 162 | "output_type": "execute_result" 163 | } 164 | ], 165 | "source": [ 166 | "grid = Gridworld()\n", 167 | "\n", 168 | "discount_factor = 1\n", 169 | "theta = 0.0001\n", 170 | "value_state = np.full(grid.num_states, 0)\n", 171 | "policy = np.full([grid.num_states, len(grid.actions), 1], 1/len(grid.actions))\n", 172 | "\n", 173 | "def policy_iteration(policy, value_state, grid):\n", 174 | " value_state = policy_eval(policy, value_state, grid, discount_factor=discount_factor, theta=theta)\n", 175 | " old_policy = policy\n", 176 | " \n", 177 | " new_policy = (grid.probabilities * (grid.reward + discount_factor*value_state)).sum(axis=2).argmax(axis=1)\n", 178 | " new_policy = new_policy.reshape(new_policy.shape[0], 1)\n", 179 | " return new_policy, value_state\n", 180 | "\n", 181 | "for i in range(20):\n", 182 | " new_policy_argmax, value_state = policy_iteration(policy, value_state, grid)\n", 183 | " new_policy = np.zeros([grid.num_states, len(grid.actions)])\n", 184 | " new_policy[range(len(policy)), new_policy_argmax.reshape(new_policy_argmax.shape[0],)] = 1\n", 185 | " new_policy = new_policy.reshape([grid.num_states, len(grid.actions), 1])\n", 186 | "\n", 187 | " diff = np.abs(policy - new_policy).max(axis=1)\n", 188 | " policy = new_policy\n", 189 | " if diff[diff != 0].shape[0] != 0:\n", 190 | " print(\"try again\")\n", 191 | " else:\n", 192 | " break\n", 193 | "\n", 194 | "policy.reshape(16,4).argmax(axis=1).reshape(4, 4), value_state.reshape(4,4)\n", 195 | " \n", 196 | "\n", 197 | "#np.asmatrix(p.reshape(16,4))[:,[1,2]]" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [] 206 | } 207 | ], 208 | "metadata": { 209 | "kernelspec": { 210 | "display_name": "Python 3", 211 | "language": "python", 212 | "name": "python3" 213 | }, 214 | "language_info": { 215 | "codemirror_mode": { 216 | "name": "ipython", 217 | "version": 3 218 | }, 219 | "file_extension": ".py", 220 | "mimetype": "text/x-python", 221 | "name": "python", 222 | "nbconvert_exporter": "python", 223 | "pygments_lexer": "ipython3", 224 | "version": "3.6.3" 225 | } 226 | }, 227 | "nbformat": 4, 
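
The two dynamic-programming notebooks above implement policy iteration over the 4x4 gridworld with dense probability and reward tensors. As a compact reference for the evaluate-improve loop sketched in the markdown cell, here is the same algorithm on a tiny hypothetical 2-state, 2-action MDP; the transition probabilities and rewards below are made-up illustration values, not numbers from the repo.

import numpy as np

n_states, n_actions, gamma, theta = 2, 2, 0.9, 1e-6
# P[s, a, s'] = transition probability, R[s, a] = expected reward (toy values)
P = np.array([[[0.8, 0.2], [0.1, 0.9]],
              [[0.5, 0.5], [0.0, 1.0]]])
R = np.array([[1.0, 0.0],
              [0.5, 2.0]])

V = np.zeros(n_states)
policy = np.zeros(n_states, dtype=int)

while True:
    # 2) policy evaluation: sweep V until the largest change drops below theta
    while True:
        V_new = np.array([R[s, policy[s]] + gamma * P[s, policy[s]] @ V
                          for s in range(n_states)])
        converged = np.abs(V_new - V).max() < theta
        V = V_new
        if converged:
            break
    # 3) policy improvement: act greedily w.r.t. the one-step lookahead
    q = R + gamma * (P @ V)            # q[s, a] = R[s, a] + gamma * sum_s' P[s,a,s'] V[s']
    new_policy = q.argmax(axis=1)
    if np.array_equal(new_policy, policy):
        break                          # policy stable -> optimal
    policy = new_policy

print(policy, V)
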
228 | "nbformat_minor": 2 229 | } 230 | -------------------------------------------------------------------------------- /open ai gym/.ipynb_checkpoints/mountain car - n-step sarsa-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import gym\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import numpy as np\n", 12 | "from collections import defaultdict\n", 13 | "import math, random" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "%matplotlib inline" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 3, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "env = gym.make(\"MountainCarContinuous-v0\")" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 4, 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "data": { 41 | "text/plain": [ 42 | "((2,), array([-1.2 , -0.07]), array([0.6 , 0.07]))" 43 | ] 44 | }, 45 | "execution_count": 4, 46 | "metadata": {}, 47 | "output_type": "execute_result" 48 | } 49 | ], 50 | "source": [ 51 | "env.observation_space.shape, env.observation_space.low, env.observation_space.high" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 5, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "def featurize_state(state):\n", 61 | " return np.array([state[0], state[0]**2, state[0]**3, state[0]**4, state[0]**5,\n", 62 | " state[1], state[1]**2, state[1]**3, state[1]**4, state[1]**5])" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 6, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "def actions(env):\n", 72 | " return np.linspace(env.action_space.low, env.action_space.high, 10)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 7, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "state = env.observation_space.sample()\n", 82 | "state = featurize_state(state)\n", 83 | "\n", 84 | "def init_weights():\n", 85 | " state = env.observation_space.sample()\n", 86 | " state = featurize_state(state)\n", 87 | " return np.random.uniform(-1/np.sqrt(state.shape[0]),\n", 88 | " 1/np.sqrt(state.shape[0]),\n", 89 | " (actions(env).shape[0], state.shape[0]))\n", 90 | "\n", 91 | "def q(state, action_idx):\n", 92 | " state = featurize_state(state)\n", 93 | " return np.dot(weights[action_idx], state)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 8, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "num_actions = len(actions(env))\n", 103 | "def select_action(state, greedy=0.2):\n", 104 | " state = featurize_state(state)\n", 105 | " max_arg = np.dot(weights, state).argmax()\n", 106 | " \n", 107 | " policy = np.full(num_actions, greedy/num_actions)\n", 108 | " policy[max_arg] = 1 - greedy + greedy/num_actions\n", 109 | " \n", 110 | " return np.random.choice(np.arange(num_actions), p=policy)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 44, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "state = env.reset()\n", 120 | "weights = init_weights()\n", 121 | "action = None\n", 122 | "done = False" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 47, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "lr = learning_rate = 0.01\n", 132 
| "df = discount_factor = 0.9\n", 133 | "n = 3\n", 134 | "\n", 135 | "episodes = []\n", 136 | "iterations = []\n", 137 | "\n", 138 | "for i in range(1):\n", 139 | " count = 0\n", 140 | " state = env.reset()\n", 141 | " done = False\n", 142 | " action = None\n", 143 | " T = float(\"inf\")\n", 144 | " tau = float(\"-inf\")\n", 145 | " t = 0\n", 146 | " rewards = []\n", 147 | " done_hitted = False\n", 148 | " \n", 149 | " while tau != T - 1:\n", 150 | " count += 1\n", 151 | " # if count >= 300: raise Exception(\"yay\")\n", 152 | " if action is None:\n", 153 | " action = select_action(state)\n", 154 | "\n", 155 | " next_state, reward, done, _ = env.step([actions(env)[action]])\n", 156 | " if done and not done_hitted:\n", 157 | " T = t + 1\n", 158 | " done_hitted = True\n", 159 | " \n", 160 | " next_action = select_action(next_state)\n", 161 | " \n", 162 | " tau = t - n + 1\n", 163 | " \n", 164 | " if tau >= 0:\n", 165 | " rs = rewards[tau+1:min(tau+n, T)]\n", 166 | " G = sum([r*df**idx for idx, r in enumerate(rs)])\n", 167 | " if tau + n < T: G = G + q(next_state, next_action)*df**n\n", 168 | " \n", 169 | " w = weights[action]\n", 170 | " w = w + lr*(G - q(state, action))*featurize_state(state)\n", 171 | "\n", 172 | " weights[action] = w\n", 173 | "\n", 174 | " action = next_action\n", 175 | " state = next_state\n", 176 | " env.render()\n", 177 | " t += 1\n", 178 | " \n", 179 | " episodes.append(i)\n", 180 | " iterations.append(count)\n", 181 | " env.render()" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 37, 187 | "metadata": {}, 188 | "outputs": [ 189 | { 190 | "data": { 191 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAA5sAAAFTCAYAAAC3TxjgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAMTQAADE0B0s6tTgAAGlVJREFUeJzt3W+IZWl9J/DvT7qZXtK9s5vJzE7Hsu2JoxKc2L5QOixG\nA5L/2+gqGBwG3HkR8UU2gUGQHUK/kIXVF5ElBPyDAUHTkPVP0hRkZZOQ1awyvZMYa23CxpFkbGss\nTZwQrNlQWoPPvqhbbvXsVFdV3+fWuefezweK6j7n3NtPVZ97zvme5/c8p1prAQAAgJ5eMHQDAAAA\nWDzCJgAAAN0JmwAAAHQnbAIAANCdsAkAAEB3wiYAAADdCZsAAAB0J2wCAADQnbAJAABAd8ImAAAA\n3Z0YugFHdccdd7S777576GYAAAAsnaeeeup7rbU7DrPt6MLm3XffnfX19aGbAQAAsHSq6u8Pu60y\nWgAAALoTNgEAAOhO2AQAAKA7YRMAAIDuhE0AAAC6EzYBAADoTtgEAACgO2ETAACA7oRNAAAAujtU\n2Kyq36qqJ6uqVdWr9ix/aVV9oaq+UlWPV9Urpl0HAADA+B22Z/OTSV6b5GvPWf6hJB9urb0syfuS\nfLTDOgAAAEbuUGGztfa51tr63mVVdU+SVyf5+GTRp5K8qKruv9110/0owBhtbm3nyrUb2dzaHrop\nAAB0NM2YzRcl2WitPZskrbWW5EaSc1OsA5bM6tpGLl+9ntW1jaGbAgBARyeGbsBBquqRJI/s/v3O\nO+8csDVAb5cunL3pOwAAi2Gans2vJzlbVSeSpKoqO72TN6ZY9/9prb2/tbay+3X69OkpmgzMmzOn\nTubBi+dy5tTJoZsCAEBHtx02W2t/l+SLSR6aLHpLkvXW2ldvd93ttgUAAID5UjtDJg/YqOpDSX4p\nyb1Jnk6y2Vq7v6penp2ZZO9K8p0kD7fWvjx5zW2tO8jKykpbX18/eEMAAAC6qqqnWmsrh9r2MGFz\nngibAAAAwzhK2JxmzCYAAAA8L2ETAACA7oRNAAAAuhM2AQAA6E7YBAAAoDthEwAAgO6ETQAAALoT\nNgEAAOhO2AQAAKA7YRMAAIDuhE0AAAC6EzYBABbM5tZ2rly7kc2t7aGbAiwxYRMAYMGsrm3k8tXr\nWV3bGLopwBI7MXQDAADo69KFszd9BxiCsAkAsGDOnDqZBy+eG7oZwJJTRgsAAEB3wiYAAADdCZsA\nAAB0J2wCAADQnbAJAABAd8ImAAAA3QmbAAAAdCdsArC0Nre2c+XajWxubQ/dFABYOMImAEtrdW0j\nl69ez+raxtBNAYCFc2LoBgDAUC5dOHvTd+bH5tZ2Vtc2cunC2Zw5dXLo5gBwG/RsArC0zpw6mQcv\nnhNm5pBeZ4Dx07MJAMwdvc4A4ydsAgBzZ7fXGYDxUkYLAABAd8ImAAAA3QmbAAAAdCdsAgAA0J2w\nCQAAQHfCJgAAAN0JmwAAAHQnbAIAANCdsAkAAEB3wiYAAADdCZsAAAB0J2wCAADQnbAJAABAd8Im\nAAAA3QmbAAAAdCdsAgAA0J2wCQAAQHfCJgAAAN0JmwAAAHQnbAIAANCdsAkAAEB3wiYAAADdCZsA\nAAB0J2wCAADQnbAJAABAd8ImAAAA3QmbAAAAdCdsAgAA0F2XsFlVv1hVX6yqL1XV9ap6+2T5PVX1\nmap6YrL8dXtes+86AAAAxu3EtG9QVZXk40l+urX2v6rqfJL/XVWfTvLeJI+11n6+ql6T5Per6r7W\n2vYB6wAAABixXmW0Lcm/mPz5nyd5Osl3k7w1yQeTpLX2eJJvJHn9ZLtbrQMAAGDEpu7ZbK21qvrl\nJJ+uqv+T5F8meXOSM0lOtta+uWfzJ5Ocq6q7
(base64-encoded image/png data elided: matplotlib scatter plot of iterations per episode)", 192 | "text/plain": [ 193 | "" 194 | ] 195 | }, 196 | "metadata": {}, 197 | "output_type": "display_data" 198 | } 199 | ], 200 | "source": [ 201 | "plt.figure(figsize=(14, 5), dpi=80)\n", 202 | "plt.scatter(episodes, iterations, s=0.5)\n", 203 | "#plt.hist(iterations, bins=100)\n", 204 | "None" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 64, 210 | "metadata": { 211 | "scrolled": true 212 | }, 213 | "outputs": [ 214 | { 215 | "data": { 216 | "text/plain": [
217 | "(202, array([[-0.13383668, -0.26133253, -0.18449042, 0.08677108, 0.06254631,\n", 218 | " -0.15061691, 0.16906858, -0.06251222, 0.07983635, -0.14186024],\n", 219 | " [-0.05345007, -0.03247772, -0.14748635, -0.174455 , -0.08724804,\n", 220 | " -0.18091682, 0.28122277, 0.06860553, 0.22778229, 0.19586654],\n", 221 | " [-0.09268782, -0.07004116, -0.02485723, -0.05054954, -0.07500264,\n", 222 | " -0.24865851, 0.08908225, 0.1064135 , -0.13726708, -0.26347313],\n", 223 | " [-0.00722484, 0.08859729, 0.07546832, 0.05640408, 0.00233169,\n", 224 | " 0.06386022, -0.13256574, 0.30330713, 0.04238674, 0.13221169],\n", 225 | " [-0.01910994, 0.00648791, -0.1253033 , -0.081292 , -0.01100774,\n", 226 | " -0.15696131, 0.19084809, 0.11525576, -0.09798487, -0.05958179],\n", 227 | " [-0.1372828 , -0.20033466, -0.03238448, -0.05364915, -0.12932985,\n", 228 | " 0.01981108, 0.05396511, -0.20613241, -0.07531075, 0.1247759 ],\n", 229 | " [-0.10363805, -0.25081688, -0.16896908, 0.24315942, 0.1815737 ,\n", 230 | " -0.14699275, 0.19022494, 0.21851419, 0.26020678, 0.18720707],\n", 231 | " [-0.01265579, 0.24167331, 0.26767328, -0.12764771, -0.20161029,\n", 232 | " -0.1040776 , 0.03898223, -0.31178328, -0.29660674, -0.22976241],\n", 233 | " [ 0.08708334, 0.10324454, -0.17512389, 0.09694226, 0.21055723,\n", 234 | " -0.2119353 , -0.06426451, 0.06814469, 0.02813327, -0.03687525],\n", 235 | " [ 0.02974516, 0.21678022, 0.07533856, -0.12043295, -0.09031784,\n", 236 | " -0.16903346, 0.29631335, -0.30342659, 0.10582635, -0.21252685]]))" 237 | ] 238 | }, 239 | "execution_count": 64, 240 | "metadata": {}, 241 | "output_type": "execute_result" 242 | } 243 | ], 244 | "source": [ 245 | "count, weights" 246 | ] 247 | },
248 | { 249 | "cell_type": "code", 250 | "execution_count": 90, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "if False:\n", 255 | "    np.save(\"weights\", weights)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 12, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "if False:\n", 265 | "    weights = np.load(\"weights.npy\")" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [] 274 | } 275 | ], 276 | "metadata": { 277 | "kernelspec": { 278 | "display_name": "Python 3", 279 | "language": "python", 280 | "name": "python3" 281 | }, 282 | "language_info": { 283 | "codemirror_mode": { 284 | "name": "ipython", 285 | "version": 3 286 | }, 287 | "file_extension": ".py", 288 | "mimetype": "text/x-python", 289 | "name": "python", 290 | "nbconvert_exporter": "python", 291 | "pygments_lexer": "ipython3", 292 | "version": "3.6.3" 293 | } 294 | }, 295 |
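A note on the pair of guarded cells above: `if False: np.save("weights", weights)` and `if False: weights = np.load("weights.npy")` act as manual switches, flipped by hand so that a blind "run all" cannot overwrite a trained weight matrix on disk or clobber the one in memory. A minimal sketch of the same idea as an explicit, reusable helper (the function names and the `overwrite` flag are illustrative, not from the notebooks):

    import os
    import numpy as np

    def save_weights(weights, path="weights.npy", overwrite=False):
        # Refuse to overwrite an existing checkpoint unless explicitly asked.
        if os.path.exists(path) and not overwrite:
            raise FileExistsError(path + " already exists; pass overwrite=True")
        np.save(path, weights)  # np.save appends the .npy suffix only if missing

    def load_weights(path="weights.npy"):
        return np.load(path)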
"nbformat": 4, 296 | "nbformat_minor": 2 297 | } 298 | -------------------------------------------------------------------------------- /open ai gym/.ipynb_checkpoints/mountain car - sarsa-Copy1-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import gym\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import numpy as np\n", 12 | "from collections import defaultdict\n", 13 | "import math, random" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "%matplotlib inline" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 3, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "env = gym.make(\"MountainCarContinuous-v0\")" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 4, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "initial_state = env.reset()" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 5, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "def step(env, action):\n", 50 | " state, reward, final, info = env.step(action)\n", 51 | " return state, reward, final, info" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 6, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "data": { 61 | "text/plain": [ 62 | "((2,), array([-1.2 , -0.07]), array([0.6 , 0.07]))" 63 | ] 64 | }, 65 | "execution_count": 6, 66 | "metadata": {}, 67 | "output_type": "execute_result" 68 | } 69 | ], 70 | "source": [ 71 | "env.observation_space.shape, env.observation_space.low, env.observation_space.high" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 7, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "def featurize_state(state):\n", 81 | " return np.array([state[0], state[0]**2, state[0]**3, state[0]**4, state[0]**5,\n", 82 | " state[1], state[1]**2, state[1]**3, state[1]**4, state[1]**5])\n", 83 | "\n", 84 | "def action_space(env):\n", 85 | " env.action_space.low()" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 8, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "def actions(env):\n", 95 | " return np.linspace(env.action_space.low, env.action_space.high, 10)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 9, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "state = env.observation_space.sample()\n", 105 | "state = featurize_state(state)\n", 106 | "\n", 107 | "def init_weights():\n", 108 | " state = env.observation_space.sample()\n", 109 | " state = featurize_state(state)\n", 110 | " return np.random.uniform(-1/np.sqrt(state.shape[0]),\n", 111 | " 1/np.sqrt(state.shape[0]),\n", 112 | " (actions(env).shape[0], state.shape[0]))\n", 113 | "weights = init_weights()\n", 114 | "\n", 115 | "def q(state, action_idx):\n", 116 | " state = featurize_state(state)\n", 117 | " return np.dot(weights[action_idx], state)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 10, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "num_actions = len(actions(env))\n", 127 | "def select_action(state, greedy=0.2):\n", 128 | " state = featurize_state(state)\n", 129 | " max_arg = np.dot(weights, state).argmax()\n", 130 | " \n", 131 | " if np.random.uniform() < 1 - greedy:\n", 132 | " 
return max_arg\n", 133 | " else:\n", 134 | " return random.randint(0, num_actions-1)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 19, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "state = env.reset()\n", 144 | "weights = init_weights()\n", 145 | "action = None\n", 146 | "done = False" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 22, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "lr = learning_rate = 0.01\n", 156 | "df = discount_factor = 1\n", 157 | "\n", 158 | "episodes = []\n", 159 | "iterations = []\n", 160 | "\n", 161 | "\n", 162 | "for i in range(1):\n", 163 | " count = 0\n", 164 | " state = env.reset()\n", 165 | " done = False\n", 166 | " action = None\n", 167 | " \n", 168 | " while not done:\n", 169 | " count += 1\n", 170 | " # if count >= 300: raise Exception(\"yay\")\n", 171 | " if action is None:\n", 172 | " action = select_action(state)\n", 173 | "\n", 174 | " next_state, reward, done, _ = step(env, [actions(env)[action]])\n", 175 | " \n", 176 | " if done:\n", 177 | " w = weights[action]\n", 178 | " w = w + lr*(reward - q(state, action))*featurize_state(state)\n", 179 | " weights[action] = w\n", 180 | " break\n", 181 | " \n", 182 | " next_action = select_action(next_state)\n", 183 | " w = weights[action]\n", 184 | "\n", 185 | " w = w + lr*(reward + df*q(next_state, next_action) - q(state, action))*featurize_state(state)\n", 186 | "\n", 187 | " weights[action] = w\n", 188 | "\n", 189 | " action = next_action\n", 190 | " state = next_state\n", 191 | " env.render()\n", 192 | " \n", 193 | " episodes.append(i)\n", 194 | " iterations.append(count)\n", 195 | " env.render()" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 23, 201 | "metadata": {}, 202 | "outputs": [ 203 | { 204 | "data": { 205 | "image/png": 
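The training cell above is one-step semi-gradient SARSA with a linear action-value function: after every transition, the weight row of the action just taken is moved along the current state's feature vector, scaled by the TD error `reward + df*q(next_state, next_action) - q(state, action)` (just `reward - q(state, action)` on the terminal step, where there is nothing to bootstrap from). A compact sketch of that update factored into one function, assuming the `q`, `featurize_state` and `weights` defined above (the function name is illustrative):

    def sarsa_update(weights, state, action, reward,
                     next_state=None, next_action=None, lr=0.01, df=1.0):
        # TD target: reward plus the discounted bootstrap estimate,
        # or the bare reward when the transition was terminal.
        target = reward
        if next_state is not None:
            target += df * q(next_state, next_action)
        td_error = target - q(state, action)
        # For a linear q, the gradient w.r.t. weights[action] is the feature
        # vector itself, so this one line is the full semi-gradient step.
        weights[action] = weights[action] + lr * td_error * featurize_state(state)
        return weights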
"iVBORw0KGgoAAAANSUhEUgAAA5sAAAFTCAYAAAC3TxjgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAMTQAADE0B0s6tTgAAG/FJREFUeJzt3WGMZedZH/D/Q3flrfA2EGOTxePNmthBKCFrUCJTFJJW\nUYECW1KCEmFZpFYE5AOlkpUqrUX9IapE8qFWhZCCU4JckqxEQwBrKzWlUAghqV1DyJItIrEFzjLO\nhBAj8KzKxrPK0w9zh86a7Oys5505c+/8ftLVnXvec888d+bcc+c/73veU90dAAAAGOmrpi4AAACA\nxSNsAgAAMJywCQAAwHDCJgAAAMMJmwAAAAwnbAIAADCcsAkAAMBwwiYAAADDCZsAAAAMJ2wCAAAw\n3KGpC7hW1113Xd94441TlwEAAHDgPPXUU89293XbWXfuwuaNN96Y5eXlqcsAAAA4cKrqL7a7rmG0\nAAAADCdsAgAAMJywCQAAwHDCJgAAAMMJmwAAAAwnbAIAADCcsAkAAMBwwiYAAADDbStsVtXPVNWT\nVdVVdcem5bdX1cer6jNV9VhVvWynbQAAAMy/7fZs/nKSVyf57HOWP5jkPd390iTvSvLQgDYAAADm\nXHX39leuejLJ67v7k1V1U5Inkrywuy9VVSVZyXoofeb5tHX3E1erYWlpqZeXl6/pRQIAALBzVfVU\ndy9tZ92dnLN5S5KV7r6UJL2eWs8nOb6DtoW1enEtpx89n9WLa/vy+++0vt1+fTutf+qf/9VM/fOb\ncvv7/XdzNfv5Z7sI3/9qdvvYNXX7Qa7/oL+3pn79i96+U/P+3t1tU76+ef/dL5p9P0FQVd1bVcsb\ntwsXLkxd0vNy5uxK7n/4XM6cXdmX33+n9e3269tp/VP//K9m6p/flNvf77+bq9nPP9tF+P5Xs9vH\nrqnbD3L9B/29NfXrX/T2nZr39+5um/L1zfvvfuF097ZvSZ5Mcsfs65uyPiT20OxxJfl8ktueb9t2\narj55pt7Hj3zN8/2Bx75bD/zN8/uy++/0/p2+/XttP6pf/5XM/XPb8rt7/ffzdXs55/tInz/q9nt\nY9fU7Qe5/oP+3pr69S96+07N+3t3t035+ub9dz8Pkiz3NvPj8z5nc/b4t5M81N0PVdUPJfk33f3K\nnbRdjXM2mcLqxbWcObuSUyeP5eiRw1OXAwAAk7iWczYPbXODDyb5viQvSvLfq2q1u29L8uNJHqqq\n+7LeW3nPpqc93zbYdzaGPCTJXXcu9OnFAAAwxDX1bO4HejaZgp5NAADYu9lo4cA4euRw7rrzuKAJ\ngx20WfkA4CARNgGYzIGblQ8ADpBtnbMJALvh1Mljl90DAItD2ARgMhtD1AGAxWMYLQAAcCCYK2Bv\nCZsAAMCBYK6AvWUYLQAAcCCYK2BvCZsAAMCBYK6AvWUYLQAAAMMJmwAAAAwnbAIAADCcsAkAAMBw\nwiYAwGCu5QcgbAIADOdafgAufQIAMJxr+QHo2QRiuBfAaBvX8jt65PDUpQBMRtgEDPcCAGA4w2gB\nw70AABhO2AT+drgXAACMYhgtAAAAwwmbAAAADCdsAgAAMJywCQAAwHDCJgAAAMMJmwAAAAwnbAIA\nADCcsAkAAMBwwiYAAADDCZsAAAAMNyRsVtX3VNXvVdUfVtUjVXVytvymqvpwVT1eVeeq6jWbnnPF\nNgBIktWLazn96PmsXlybuhQA4BrtOGxW1dcm+UCSN3f3K5L869njJHlnkke6+/Yk9yQ5XVWHt9EG\nADlzdiX3P3wuZ86uTF0KAHCNDg3YxkuSPN3d/ydJuvujVXW8qr4tyRuT3DZb/lhVfS7Ja5P8xlXa\nACCnTh677B4AmB8jhtE+nuSGqvqOJKmqf5bkaJJbkxzu7s9vWvfJJMer6oYrtT1341V1b1Utb9wu\nXLgwoGQA5sHRI4dz153Hc/SIgS8AMG92HDa7+6+T/FCSn66q30/yXUn+KMn1O932bPsPdPfSxu36\n64dsFgAAgF00Yhhtuvu3kvxWklTVdUk+n+RjSS5V1Ys29WCeSHK+u5+uqq/YNqIeAAAApjVqNtrN\nJ9P8uyT/s7ufSPLBJG+drfOqJDcn+chsva3auEZmbAQAAPaTIT2bSd5RVd85297/SvKW2fK3J3lf\nVT2e5Nkkd3f32jbauEYbMzYmyV13/p1TXwEAAPZUdffUNVyTpaWlXl5enrqMfWf14lrOnF3JqZPH\nTKQBAADsiqp6qruXtrPuqJ5NJrYxYyMAAMB+MOScTQAAANhM2AQAAGA4YRMAAIDhhE0AAACGEzZh\nANc5BdhfHJcBpidswgAb1zk9c3Zl6lIAiOMywH7g0icwwKmTxy67B2BajstMyfXPYZ2wCQO4zinA\n/uK4zJQ2etaT2A850IRNAAAYSM86rBM2AQBgID3rsM4EQQAAAAwnbAIAADCcsAkAAMBwwiYAAADD\nCZsAAAAMJ2wCAAAwnLAJAADAcMImAAAAwwmbAAAADCdsAgC7YvXiWk4/ej6rF9emLgWACQibAMCu\nOHN2Jfc/fC5nzq5MXQoAEzg0dQEAwGI6dfLYZfcAHCzCJgCwK44eOZy77jw+dRkATMQwWgAAYNuc\nj812CZsAAMC2OR+b7TKMFgAA2DbnY7NdejYBAIBt2zgf++iRw1OXsnAWbYjykLBZVd9bVZ+oqk9W\n1bmqevNs+U1V9eGqeny2/DWbnnPFNgAAgINm0YYo73gYbVVVkvcn+Ufd/YdVdSLJH1fVryR5Z5JH\nuvt7qupVSX61qm7t7rWrtAEAABwoizZEedQ5m53ka2Zf/4MkTyf5UpI3JrktSbr7sar6XJLXJvmN\nq7QBAAAcKIt2yagdD6Pt7k7ypiS/UlWfTfK7Sd6c5GiSw939+U2rP5nkeFXdcKW2ndbD/rRo488B\nAICt7ThsVtWhJD+V5Ae7+8VJXpfkfRnUa1pV91bV8sbtwoULIzbLHlu08ecAAMDWRgTCO5J8Q3f/\nTvK3Q2KXk7wiyaWqetGmHswTSc5399NV9RXbnrvx7n4gyQMbj5eWlnpAzeyxRRt/DgAAbG3EbLR/\nluRYVX1zklTVbUlekuTTST6Y5K2z5a9KcnOSj8yet1UbC8YU2QAAcLDsuGezu/+8qn4syX+pqi9n\nPcD+RHefr6q3J3lfVT2e5Nkkd2+abXarNgAAAOZYrc/vMz+WlpZ6eXl56jIAAAAOnKp6qruXtrPu\niGG0AAAAcBlhEwAAgOGETQAAAIYTNgEAABhO2AQAAGA4YRMAAIDhhE0AAACGEzYBAAAYTtgEAABg\nOGETAACA4YRNAAAAhhM2AQAAGE7YBAAAYDhhEwAAgOGETQAAAIYTNgEAABhO2IQ9sHpxLacfPZ/V\ni2tTlwIAAHtC2IQ9cObsSu5/+FzOnF2ZuhS4jH+EAAC75dDUBcBBcOrkscvuYb/Y+EdIktx15/GJ\nqwEAFomwCXvg6JHDB/YP+dWLazlzdiWnTh7L0SOH9/z5
bM0/QgCA3WIYLbCrdjqE2BDk3bXxjxBB\nHgAYTc8mcFU76V3cac+ZnjcAgPmkZxNikpSr2Unv4k57zvS8AV+J4zY7Yf+BvaFnE2KSlKvRuwjs\nN47b7IT9B/aGsAkRpq7mIE9wxHwzwdTictxmJ+w/sDequ6eu4ZosLS318vLy1GUAMAdOP3o+9z98\nLu/4gZf7hwkADFBVT3X30nbW1bMJwMLSewHAtTAiZiwTBAGwsEwwBcC1cMm1sfRsAgAAxIiY0Xbc\ns1lVN1TVJzfdPlNVl6rqhVV1U1V9uKoer6pzVfWaTc+7YhsAAMBeMyJmrB33bHb300nu2HhcVW9L\n8tru/suq+oUkj3T391TVq5L8alXd2t1rSd65RRsAAABzbDfO2XxLkvfOvn5jkp9Lku5+LMnnkrx2\nG20AAADMsaFhs6q+I8nXJvmvVXVDksPd/flNqzyZ5PhWbSPrAQAAYBqjezbfkuQXu/vSqA1W1b1V\ntbxxu3DhwqhNAwAspNWLazn96PmsXnR2EjCdYWGzqq7P+tDYX0j+9lzOS1X1ok2rnUhyfqu25263\nux/o7qWN2/XXXz+qZACAheTyDcB+MLJn801Jznb3H29a9sEkb02S2SRANyf5yDbaAPaE//4Di+jU\nyWN5xw+83OUbgEmNvM7mW5L8p+cse3uS91XV40meTXL3ptlmt2oD2BMb//1PkrvudNo4sBg2Lt8A\nMKVhYbO7v+MrLPvzJN91hfWv2AawV1y8GQBgd4zs2QSYO/77DwCwO3bjOpsAAAAccMImALvGBEwA\ncHAJmzAH/MHOvHL5BQA4uJyzCXPAjKnMKxMwAcDBJWzCHPAHO/PKBEwAcHAJmzAH/MEOAMC8cc4m\n+4JzEgEAYLEIm3Ni0cOYSUQA2E8W/XMXYC8YRjsnFn2CGOckArCfLPrnLsBeEDbnxKKHMeckArCf\nLPrnLsBeqO6euoZrsrS01MvLy1OXAQAAcOBU1VPdvbSddZ2zCQAAwHDCJgAAAMMJmwAAAAwnbLIt\npoCH3eG9BQAsKmGTbXEdTNgd3lsAwKJy6RO2xRTwsDvm/b21enEtZ86u5NTJYzl65PDU5QAA+4iw\nyba4Dibsjnl/b7nwPQBwJcImAM/bvPfMAgC7R9gE4Hmb955ZAGD3mCAIAACA4YRNAAAAhhM2AQAA\nGE7YBAAAYDhhEwAAgOGETQAAAIYTNgEAABhO2AQAAGC4IWGzqq6rqp+tqser6lNV9f7Z8tur6uNV\n9ZmqeqyqXrbpOVdsAwAAYL6N6tl8Z5JO8tLu/pYkb5stfzDJe7r7pUneleShTc/Zqg0AAIA5Vt29\nsw1UfXWSlSRL3f3MpuU3JXkiyQu7+1JV1Wy9Vyd55kpt3f3EVt9vaWmpl5eXd1QzAAAA166qnuru\npe2sO6Jn8yVJ/jLJfVX1e1X10ap6XZJbkqx096Uk6fVUez7J8au0PffF3FtVyxu3CxcuDCgZAACA\n3TQibB5K8uIkf9Tdr0zyk0l+abZ8x7r7ge5e2rhdf/31IzYLAADALhoRNs8n+XKSDyRJd/9Bkj/N\negA9VlWHkmQ2VPb4bP0/26INAACAObfjsNndX0zym0m+O0mq6tYktyb5WJJPJLl7tuobkix39xPd\n/YUrte20HgAAAKa34wmCkqSqvjHJe5N8XdZ7Od/R3R+qqm/K+iyzN2R9UqB7uvtTs+dcsW0rJggC\nAACYxrVMEDQkbO4lYROAg2L14lrOnF3JqZPHcvTI4anLAYA9n40WANgFZ86u5P6Hz+XM2ZWpSwGA\nazZkxlgAYLxTJ49ddg8A80TYBIB96uiRw7nrzr9zCWoAmAuG0QIAADCcsAkAMGdWL67l9KPns3px\nbepSAK5I2AQAmDMmjwLmgXM2AQDmjMmjgHmgZxNggRlqB4tpY/Io118F9jNhE2CBGWoHAEzFMFqA\nBWaoHQAwFT2bAAvMUDsA2DtOX7mcsAkAADCA01cuZxgtAADAAE5fuZyeTQCAa2SoHPCVOH3lcsIm\nAMA1MlQO4OoMowUAuEaGygFcnbAJAHCNNobKAXBlhtECAAAwnLAJAADAcMImAAAAwwmbAAAADCds\nAgAAMJywCQAAwHDCJgAAAMMJmwAAsMnqxbWcfvR8Vi+uTV0KzDVhEwAANjlzdiX3P3wuZ86uTF0K\nzLVDUxcAAAD7yamTxy67B54fYRMAADY5euRw7rrz+NRlwNwbMoy2qp6sqk9X1SdntzfNlt9eVR+v\nqs9U1WNV9bJNz7liGwAAAPNt5Dmbb+ruO2a3X5otezDJe7r7pUneleShTetv1QYAAMAc27UJgqrq\npiSvTPL+2aIPJbmlqm7bqm236gEAAGDvjAybv1hVn6qq91bVjUluSbLS3ZeSpLs7yfkkx6/SBgAA\nwJwbFTZf092vSPJtSb6Y5D8P2m6q6t6qWt64XbhwYdSmAQAA2CW13qk4cINVx5J8JslLkjyR5IXd\nfamqKslKklcneeZKbd39xFbbX1pa6uXl5aE1AwAAcHVV9VR3L21n3R33bFbVV1fV12xa9MNJ/qC7\nv5DkE0nuni1/Q5Ll7n5iq7ad1gOwn6xeXMvpR89n9eLa1KUAAOypEdfZ/PokH6qqv5ekkvxJkh+Z\ntf14koeq6r6s92bes+l5W7UBLIQzZ1dy/8PnksQ12wCAA2XHYbO7/yTJt16h7dNJ/uG1tgEsilMn\nj112DwBwUIzo2QTgCo4eOaxHEwA4kHbtOpsAAAAcXMImwBwzAREAsF8Jm7AABI6Da2MCojNnV6Yu\nBQDgMs7ZhAVgxtODywREAMB+JWzCAhA4Di4TEAEA+5WwCQtA4AAAYL9xziYAAADDCZsAAAAMJ2wC\nAAAwnLAJAADAcMImAAAAwwmbAAAADCdsAgAAMJywCQAAwHDC5iCrF9dy+tHzWb24NnUpAAAAkxM2\nBzlzdiX3P3wuZ86uTF0KAADA5A5NXcCiOHXy2GX3AAAAB5mwOcjRI4dz153Hpy4DAABgXzCMFgAA\ngOGETQAAAIYTNgEAABhO2ATmmssOAcBYPlsZRdgE5prLDgHAWD5bGcVstMBcc9khABjLZyujVHdP\nXcM1WVpa6uXl5anLAA6I1YtrOXN2JadOHsvRI4enLgcAYFJV9VR3L21nXcNoAbZgKBEAwPNjGC3A\nFgwlAgB4foRNgC0cPXI4d915fOoyAADmzrBhtFV1T1V1Vb1+9vimqvpwVT1eVeeq6jWb1r1iGwAA\nAPNvSNisqhNJfjTJI5sWvzPJI919e5J7kpyuqsPbaAMAAGDO7ThsVtVXJfn5JP8yyZc2Nb0xyc8l\nSXc/luRzSV67jTYAAADm3IiezXuTfKy7f39jQVXdkORwd39+03pPJjm+VduAWgAAANgHdjRBUFW9\nPMkbkuzaOZdVdW/WA22S5AUveMFufSsAAAAG2WnP5ncmOZHk8ap6Msm3J3lP1ofJXqqqF21a90SS\n89399JXavtI
36O4Huntp43b99dfvsGQAAAB2247CZne/u7uPdfeJ7j6R9QmCfqy7353kg0nemiRV\n9aokNyf5yOypW7UBAAAw53bzOptvT/K+qno8ybNJ7u7utW20AQAAMOequ6eu4ZosLS318vLy1GUA\nAAAcOFX1VHcvbWfdIdfZBAAAgM2ETQAAAIYTNgEAABhO2AQAAGA4YRMAAIDhhE0AADhAVi+u5fSj\n57N60ZUH2V3CJgAAHCBnzq7k/ofP5czZlalLYcEdmroAAABg75w6eeyye9gtwiYAABwgR48czl13\nHp+6DA4Aw2gBAAAYTtgEAABgOGETAACA4YRNAAAAhhM2AQAAGE7YBAAAYDhhc59YvbiW04+ez+rF\ntalLAQAA2DFhc584c3Yl9z98LmfOrkxdCgAAwI4dmroA1p06eeyyewAAgHkmbO4TR48czl13Hp+6\nDAAAgCEMowUAAGA4YRMAAIDhhE0AAACGEzYBAAAYTtgEAABgOGETAACA4YRNAAAAhhM2AQAAGE7Y\nBAAAYDhhEwAAgOGGhM2q+vWq+sOq+mRVfbSqvnW2/Paq+nhVfaaqHquql216zhXbAAAAmG+jejbf\n2N2v6O47kjyQ5KHZ8geTvKe7X5rkXZuWX60NAACAOTYkbHb3X216+IIkXVU3JXllkvfPln8oyS1V\nddtWbSPqAQAAYFqHRm2oqn4xyT+ePfzeJLckWenuS0nS3V1V55McT/LXW7Q9MaomAAAApjFsgqDu\n/pHuviXJT2V9WOwQVXVvVS1v3C5cuDBq0wAAAOyS6u7xG636myQnkjye5IXdfamqKslKklcneSbr\nPZh/p627t+zZXFpa6uXl5eE1AwAAsLWqeqq7l7az7o57Nqvqa6rqGzY9fn2Sp5N8Icknktw9a3pD\nkuXufqK7r9i203oAAACY3ohzNl+Q5INV9feTfDnJXyT5/tl5mD+e5KGqui/rvZn3bHreVm0AAADM\nsV0ZRrubDKMFAACYxp4OowUAAIDnEjYBAAAYTtgEAABgOGETAACA4YRNYKGtXlzL6UfPZ/Xi2tSl\nAAAcKMImsNDOnF3J/Q+fy5mzK1OXAgBwoIy4zibAvnXq5LHL7gEA2BvCJrDQjh45nLvuPD51GQAA\nB45htAAAAAwnbAIAADCcsAkAAMBwwiYAAADDCZsAAAAMJ2wCAAAwnLAJAADAcMImAAAAwwmbAAAA\nDCdsAgAAMJywCQAAwHDV3VPXcE2q6ktJ/mLqOq7g+iQXpi6CA8v+x1Tse0zFvseU7H9MZep978bu\nvm47K85d2NzPqmq5u5emroODyf7HVOx7TMW+x5Tsf0xlnvY9w2gBAAAYTtgEAABgOGFzrAemLoAD\nzf7HVOx7TMW+x5Tsf0xlbvY952wCAAAwnJ5NAAAAhhM2AQAAGE7YHKCqbq+qj1fVZ6rqsap62dQ1\nsZiq6khV/dpsXztbVf+jqm6btd1UVR+uqser6lxVvWbqellcVXVPVXVVvX722P7Hrqqq66rqZ2f7\n2Keq6v2z5T6D2XVV9b1V9Ymq+uTsGPfm2XLHPoaqqp+pqidnn7F3bFp+xWPdfj4OCptjPJjkPd39\n0iTvSvLQtOWw4N6T5Ju6+2SSh5P8/Gz5O5M80t23J7knyemqOjxRjSywqjqR5EeTPLJpsf2P3fbO\nJJ3kpd39LUneNlvuM5hdVVWV5P1J/kV335Hk+5M8WFVH49jHeL+c5NVJPvuc5Vsd6/btcdAEQTtU\nVTcleSLJC7v70uyAtJLk1d39xLTVseiq6pVJfrm7T1TVhSS3dffnZ23/O8l93f0bkxbJQqmqr0ry\n60nenuQ/JPmP3f1r9j92U1V9ddY/W5e6+5lNy30Gs+tm+9UXk/zz7v6dqnpFkv+W5NYkfxnHPnZB\nVT2Z5PXd/cmtjnVJnrlS2344DurZ3Llbkqx096Uk6fX0fj7J8Umr4qD4V0kerqobkhze+LCbeTL2\nQ8a7N8nHuvv3NxbY/9gDL8n6H/X3VdXvVdVHq+p18RnMHpjtV29K8itV9dkkv5vkzUmOxrGPvbHV\nsW5fHweFTZhTVXVfktuS/Nupa+FgqKqXJ3lDkn8/dS0cOIeSvDjJH3X3K5P8ZJJfmi2HXVVVh5L8\nVJIf7O4XJ3ldkvfF/gdXJWzu3J8lOTY7EG0MtTie9f8owK6oqrcl+cEk/7S7/293P53kUlW9aNNq\nJ2I/ZKzvzPp+9fhseM+3Z/0c4jfG/sfuOp/ky0k+kCTd/QdJ/jTrAdRnMLvtjiTf0N2/kyTd/ViS\n5SSviGMfe2OrvLGvs4iwuUPd/YUkn0hy92zRG5Is74cx0iymqro3yQ8n+Sfd/Vebmj6Y5K2zdV6V\n5OYkH9n7CllU3f3u7j7W3Se6+0TWJwj6se5+d+x/7KLu/mKS30zy3UlSVbdm/Xy5j8VnMLtv44/5\nb06S2SzwL0ny6Tj2sQe2yhv7PYuYIGiAqvqmrM/6dEPWT9K9p7s/NWlRLKSqWsr6h96fJFmdLf5S\nd99ZVV+f9WE9tyZ5NslPdPdvTVMpB0FV/Xb+/wRB9j92VVV9Y5L3Jvm6rPdyvqO7P+QzmL1QVT+c\n5L6s73tfleSnu/u0Yx+jVdWDSb4vyYuSPJ1ktbtv2+pYt5+Pg8ImAAAAwxlGCwAAwHDCJgAAAMMJ\nmwAAAAwnbAIAADCcsAkAAMBwwiYAAADDCZsAAAAMJ2wCAAAwnLAJAADAcP8PwRmnAh4y/+8AAAAA\nSUVORK5CYII=\n", 206 | "text/plain": [ 207 | "" 208 | ] 209 | }, 210 | "metadata": {}, 211 | "output_type": "display_data" 212 | } 213 | ], 214 | "source": [ 215 | "plt.figure(figsize=(14, 5), dpi=80)\n", 216 | "plt.scatter(episodes, iterations, s=0.5)\n", 217 | "#plt.hist(iterations, bins=100)\n", 218 | "None" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 120, 224 | "metadata": { 225 | "scrolled": true 226 | }, 227 | "outputs": [ 228 | { 229 | "data": { 230 | "text/plain": [ 231 | "(999, array([[-0.02085557, 0.0595762 , 0.11831713, -0.18363152, -0.19683867,\n", 232 | " -0.01796869, 0.28806498, -0.021476 , 0.24860996, 0.18516945],\n", 233 | " [ 0.01440559, -0.05804852, -0.22202222, -0.2264345 , -0.07399336,\n", 234 | 
" 0.0836368 , 0.03485213, 0.13657303, -0.18549418, 0.08907543],\n", 235 | " [-0.08892545, -0.23147879, -0.16591371, -0.0597599 , -0.0296756 ,\n", 236 | " 0.0452843 , -0.16806746, 0.11934212, 0.05252277, 0.08751585],\n", 237 | " [ 0.00821088, -0.05516365, -0.10181619, -0.01925168, 0.0219685 ,\n", 238 | " 0.09394311, 0.27535469, -0.22059754, 0.14107741, 0.07010356],\n", 239 | " [-0.036048 , -0.08865107, -0.11465601, -0.16514499, -0.0954818 ,\n", 240 | " 0.15807466, -0.01672393, 0.20743621, 0.02093894, 0.29951216],\n", 241 | " [-0.03579179, -0.01706876, 0.1083824 , -0.00311162, -0.07572744,\n", 242 | " 0.22687012, 0.1699998 , 0.25096862, 0.01865831, 0.17707882],\n", 243 | " [-0.10621423, -0.29940879, -0.25072268, -0.12465697, -0.05695309,\n", 244 | " 0.05062194, 0.12632425, 0.01576961, 0.02080412, -0.08271762],\n", 245 | " [-0.04480179, -0.13596811, -0.23614967, -0.34553927, -0.18651549,\n", 246 | " -0.0329766 , 0.08021832, 0.07155519, 0.26401518, -0.064796 ],\n", 247 | " [-0.06205904, -0.12404013, -0.02585845, -0.07235029, -0.08973349,\n", 248 | " -0.08904176, -0.06319158, 0.09375384, 0.21721181, -0.05041887],\n", 249 | " [-0.0676916 , -0.11858236, 0.12053182, 0.21836649, 0.05770418,\n", 250 | " 0.20416503, -0.09088759, -0.31010173, -0.16394684, 0.11900846]]))" 251 | ] 252 | }, 253 | "execution_count": 120, 254 | "metadata": {}, 255 | "output_type": "execute_result" 256 | } 257 | ], 258 | "source": [ 259 | "count, weights" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 90, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "if False:\n", 269 | " np.save(\"weights\", weights)" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 21, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "if True:\n", 279 | " weights = np.load(\"weights.npy\")" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": null, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [] 288 | } 289 | ], 290 | "metadata": { 291 | "kernelspec": { 292 | "display_name": "Python 3", 293 | "language": "python", 294 | "name": "python3" 295 | }, 296 | "language_info": { 297 | "codemirror_mode": { 298 | "name": "ipython", 299 | "version": 3 300 | }, 301 | "file_extension": ".py", 302 | "mimetype": "text/x-python", 303 | "name": "python", 304 | "nbconvert_exporter": "python", 305 | "pygments_lexer": "ipython3", 306 | "version": "3.6.3" 307 | } 308 | }, 309 | "nbformat": 4, 310 | "nbformat_minor": 2 311 | } 312 | -------------------------------------------------------------------------------- /open ai gym/mountain car - n-step sarsa.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import gym\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import numpy as np\n", 12 | "from collections import defaultdict\n", 13 | "import math, random" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "%matplotlib inline" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 3, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "env = gym.make(\"MountainCarContinuous-v0\")" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 4, 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "data": { 41 | "text/plain": [ 42 | "((2,), 
array([-1.2 , -0.07]), array([0.6 , 0.07]))" 43 | ] 44 | }, 45 | "execution_count": 4, 46 | "metadata": {}, 47 | "output_type": "execute_result" 48 | } 49 | ], 50 | "source": [ 51 | "env.observation_space.shape, env.observation_space.low, env.observation_space.high" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 5, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "def featurize_state(state):\n", 61 | " return np.array([state[0], state[0]**2, state[0]**3, state[0]**4, state[0]**5,\n", 62 | " state[1], state[1]**2, state[1]**3, state[1]**4, state[1]**5])" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 6, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "def actions(env):\n", 72 | " return np.linspace(env.action_space.low, env.action_space.high, 10)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 7, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "state = env.observation_space.sample()\n", 82 | "state = featurize_state(state)\n", 83 | "\n", 84 | "def init_weights():\n", 85 | " state = env.observation_space.sample()\n", 86 | " state = featurize_state(state)\n", 87 | " return np.random.uniform(-1/np.sqrt(state.shape[0]),\n", 88 | " 1/np.sqrt(state.shape[0]),\n", 89 | " (actions(env).shape[0], state.shape[0]))\n", 90 | "\n", 91 | "def q(state, action_idx):\n", 92 | " state = featurize_state(state)\n", 93 | " return np.dot(weights[action_idx], state)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 8, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "num_actions = len(actions(env))\n", 103 | "def select_action(state, greedy=0.2):\n", 104 | " state = featurize_state(state)\n", 105 | " max_arg = np.dot(weights, state).argmax()\n", 106 | " \n", 107 | " policy = np.full(num_actions, greedy/num_actions)\n", 108 | " policy[max_arg] = 1 - greedy + greedy/num_actions\n", 109 | " \n", 110 | " return np.random.choice(np.arange(num_actions), p=policy)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 44, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "state = env.reset()\n", 120 | "weights = init_weights()\n", 121 | "action = None\n", 122 | "done = False" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 47, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "lr = learning_rate = 0.01\n", 132 | "df = discount_factor = 0.9\n", 133 | "n = 3\n", 134 | "\n", 135 | "episodes = []\n", 136 | "iterations = []\n", 137 | "\n", 138 | "for i in range(1):\n", 139 | " count = 0\n", 140 | " state = env.reset()\n", 141 | " done = False\n", 142 | " action = None\n", 143 | " T = float(\"inf\")\n", 144 | " tau = float(\"-inf\")\n", 145 | " t = 0\n", 146 | " rewards = []\n", 147 | " done_hitted = False\n", 148 | " \n", 149 | " while tau != T - 1:\n", 150 | " count += 1\n", 151 | " # if count >= 300: raise Exception(\"yay\")\n", 152 | " if action is None:\n", 153 | " action = select_action(state)\n", 154 | "\n", 155 | " next_state, reward, done, _ = env.step([actions(env)[action]])\n", 156 | " if done and not done_hitted:\n", 157 | " T = t + 1\n", 158 | " done_hitted = True\n", 159 | " \n", 160 | " next_action = select_action(next_state)\n", 161 | " \n", 162 | " tau = t - n + 1\n", 163 | " \n", 164 | " if tau >= 0:\n", 165 | " rs = rewards[tau+1:min(tau+n, T)]\n", 166 | " G = sum([r*df**idx for idx, r in enumerate(rs)])\n", 167 | " if tau + n < T: G = G + q(next_state, 
next_action)*df**n\n", 168 | "        \n", 169 | "            w = weights[action]\n", 170 | "            w = w + lr*(G - q(state, action))*featurize_state(state)\n", 171 | "\n", 172 | "            weights[action] = w\n", 173 | "\n", 174 | "        action = next_action\n", 175 | "        state = next_state\n", 176 | "        env.render()\n", 177 | "        t += 1\n", 178 | "    \n", 179 | "    episodes.append(i)\n", 180 | "    iterations.append(count)\n", 181 | "    env.render()" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 37, 187 | "metadata": {}, 188 | "outputs": [ 189 | { 190 | "data": { 191 | "image/png": "(base64-encoded image/png data elided: matplotlib scatter plot of iterations per episode)",
192 | "text/plain": [ 193 | "" 194 | ] 195 | }, 196 | "metadata": {}, 197 | "output_type": "display_data" 198 | } 199 | ], 200 | "source": [ 201 | "plt.figure(figsize=(14, 5), dpi=80)\n", 202 | "plt.scatter(episodes, iterations, s=0.5)\n", 203 | "#plt.hist(iterations, bins=100)\n", 204 | "None" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 64, 210 | "metadata": { 211 | "scrolled": true 212 | }, 213 | "outputs": [ 214 | { 215 | "data": { 216 | "text/plain": [
217 | "(202, array([[-0.13383668, -0.26133253, -0.18449042, 0.08677108, 0.06254631,\n", 218 | " -0.15061691, 0.16906858, -0.06251222, 0.07983635, -0.14186024],\n", 219 | " [-0.05345007, -0.03247772, -0.14748635, -0.174455 , -0.08724804,\n",
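One detail worth flagging in the n-step training cell above: `rewards = []` is initialized and the slice `rs = rewards[tau+1:min(tau+n, T)]` is read, but nothing is ever appended to `rewards` inside the loop, so `rs` is always empty and the return `G` collapses to the bootstrap term `q(next_state, next_action)*df**n` alone (and to 0 once `tau + n >= T`). For a true n-step SARSA target, each reward has to be recorded as it is observed (`rewards.append(reward)` right after the `env.step` call); with `rewards[k]` holding the reward that follows time step k, the return is then computed as in this self-contained sketch under those assumptions:

    def n_step_return(rewards, tau, n, T, df, bootstrap):
        # rewards[tau:min(tau + n, T)] covers R_{tau+1} .. R_{min(tau+n, T)}.
        rs = rewards[tau:min(tau + n, T)]
        G = sum(r * df**idx for idx, r in enumerate(rs))
        # Bootstrap with q(S_{tau+n}, A_{tau+n}) only while the episode is
        # still running inside the n-step window.
        if tau + n < T:
            G += bootstrap * df**n
        return G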
220 | " -0.18091682, 0.28122277, 0.06860553, 0.22778229, 0.19586654],\n", 221 | " [-0.09268782, -0.07004116, -0.02485723, -0.05054954, -0.07500264,\n", 222 | " -0.24865851, 0.08908225, 0.1064135 , -0.13726708, -0.26347313],\n", 223 | " [-0.00722484, 0.08859729, 0.07546832, 0.05640408, 0.00233169,\n", 224 | " 0.06386022, -0.13256574, 0.30330713, 0.04238674, 0.13221169],\n", 225 | " [-0.01910994, 0.00648791, -0.1253033 , -0.081292 , -0.01100774,\n", 226 | " -0.15696131, 0.19084809, 0.11525576, -0.09798487, -0.05958179],\n", 227 | " [-0.1372828 , -0.20033466, -0.03238448, -0.05364915, -0.12932985,\n", 228 | " 0.01981108, 0.05396511, -0.20613241, -0.07531075, 0.1247759 ],\n", 229 | " [-0.10363805, -0.25081688, -0.16896908, 0.24315942, 0.1815737 ,\n", 230 | " -0.14699275, 0.19022494, 0.21851419, 0.26020678, 0.18720707],\n", 231 | " [-0.01265579, 0.24167331, 0.26767328, -0.12764771, -0.20161029,\n", 232 | " -0.1040776 , 0.03898223, -0.31178328, -0.29660674, -0.22976241],\n", 233 | " [ 0.08708334, 0.10324454, -0.17512389, 0.09694226, 0.21055723,\n", 234 | " -0.2119353 , -0.06426451, 0.06814469, 0.02813327, -0.03687525],\n", 235 | " [ 0.02974516, 0.21678022, 0.07533856, -0.12043295, -0.09031784,\n", 236 | " -0.16903346, 0.29631335, -0.30342659, 0.10582635, -0.21252685]]))" 237 | ] 238 | }, 239 | "execution_count": 64, 240 | "metadata": {}, 241 | "output_type": "execute_result" 242 | } 243 | ], 244 | "source": [ 245 | "count, weights" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 90, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "if False:\n", 255 | " np.save(\"weights\", weights)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 12, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "if False:\n", 265 | " weights = np.load(\"weights.npy\")" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [] 274 | } 275 | ], 276 | "metadata": { 277 | "kernelspec": { 278 | "display_name": "Python 3", 279 | "language": "python", 280 | "name": "python3" 281 | }, 282 | "language_info": { 283 | "codemirror_mode": { 284 | "name": "ipython", 285 | "version": 3 286 | }, 287 | "file_extension": ".py", 288 | "mimetype": "text/x-python", 289 | "name": "python", 290 | "nbconvert_exporter": "python", 291 | "pygments_lexer": "ipython3", 292 | "version": "3.6.3" 293 | } 294 | }, 295 | "nbformat": 4, 296 | "nbformat_minor": 2 297 | } 298 | -------------------------------------------------------------------------------- /open ai gym/mountain car - sarsa-Copy1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import gym\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import numpy as np\n", 12 | "from collections import defaultdict\n", 13 | "import math, random" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "%matplotlib inline" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 3, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "env = gym.make(\"CartPole-v1\")" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 4, 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "data": { 41 | "text/plain": [ 42 | "(array([-0.0182929 , 
-0.00087211, -0.03741516, -0.0190165 ]),)" 43 | ] 44 | }, 45 | "execution_count": 4, 46 | "metadata": {}, 47 | "output_type": "execute_result" 48 | } 49 | ], 50 | "source": [ 51 | "initial_state = env.reset(),\n", 52 | "initial_state" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 53, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "def step(env, action):\n", 62 | " state, reward, final, info = env.step(action)\n", 63 | " return state, reward, final, info" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 120, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "data": { 73 | "text/plain": [ 74 | "((4,),\n", 75 | " array([-4.80000000e+00, -3.40282347e+38, -4.18879020e-01, -3.40282347e+38]),\n", 76 | " array([4.80000000e+00, 3.40282347e+38, 4.18879020e-01, 3.40282347e+38]))" 77 | ] 78 | }, 79 | "execution_count": 120, 80 | "metadata": {}, 81 | "output_type": "execute_result" 82 | } 83 | ], 84 | "source": [ 85 | "env.observation_space.shape, env.observation_space.low, env.observation_space.high" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 243, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "def featurize_state(state):\n", 95 | " return np.array([state[2], state[3], state[0]*state[2], state[0]*state[3],\n", 96 | " state[2]*state[3], state[0], state[1]])\n", 97 | "\n", 98 | "def action_space(env):\n", 99 | " env.action_space.low()" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 244, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "def actions(env):\n", 109 | " # return np.linspace(env.action_space.low, env.action_space.high, 10)\n", 110 | " return np.array([0, 1])" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 245, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "state = env.observation_space.sample()\n", 120 | "state = featurize_state(state)\n", 121 | "\n", 122 | "def init_weights():\n", 123 | " state = env.observation_space.sample()\n", 124 | " state = featurize_state(state)\n", 125 | " # return np.random.uniform(-1/np.sqrt(state.shape[0]),\n", 126 | " # 1/np.sqrt(state.shape[0]),\n", 127 | " # (actions(env).shape[0], state.shape[0]))\n", 128 | " return np.zeros((actions(env).shape[0], state.shape[0]), dtype=np.float32)\n", 129 | "weights = init_weights()\n", 130 | "\n", 131 | "def q(state, action_idx):\n", 132 | " state = featurize_state(state)\n", 133 | " return np.dot(weights[action_idx], state)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 246, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "num_actions = len(actions(env))\n", 143 | "def select_action(state, greedy=0.2):\n", 144 | " state = featurize_state(state)\n", 145 | " max_arg = np.dot(weights, state).argmax()\n", 146 | " \n", 147 | " if np.random.uniform() < 1 - greedy:\n", 148 | " return max_arg\n", 149 | " else:\n", 150 | " return random.randint(0, num_actions-1)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 285, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "state = env.reset()\n", 160 | "weights = init_weights()\n", 161 | "action = None\n", 162 | "done = False" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 286, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "lr = learning_rate = 0.05\n", 172 | "df = discount_factor = 0.9\n", 173 | "\n", 174 | "episodes = []\n", 175 | 
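Two things distinguish the CartPole variant above from the mountain-car notebooks: the feature vector is hand-crafted (pole angle and angular velocity plus cross terms such as `state[0]*state[2]`), and `init_weights` now returns `np.zeros(...)` instead of the commented-out random-uniform initializer, so all action values start identical and early behaviour comes entirely from argmax tie-breaking and the ε-greedy exploration branch. Both rely on the `(num_actions, num_features)` weight layout, which lets one dot product score every action at once; a short sketch of that computation, assuming the `weights` and `featurize_state` defined above:

    def q_all(state):
        # weights: (num_actions, num_features); result: one value per action.
        return np.dot(weights, featurize_state(state))

    # the greedy choice used by select_action is then simply:
    # best_action = q_all(state).argmax()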
"iterations = []\n", 176 | "\n", 177 | "\n", 178 | "for i in range(100):\n", 179 | " count = 0\n", 180 | " state = env.reset()\n", 181 | " done = False\n", 182 | " action = None\n", 183 | " greedy=0.2\n", 184 | " \n", 185 | " while not done:\n", 186 | " count += 1\n", 187 | " # if count >= 300: raise Exception(\"yay\")\n", 188 | " if action is None:\n", 189 | " action = select_action(state, greedy=greedy)\n", 190 | "\n", 191 | " next_state, reward, done, _ = env.step(actions(env)[action])\n", 192 | " \n", 193 | " if done:\n", 194 | " w = weights[action]\n", 195 | " w = w + lr*(reward - q(state, action))*featurize_state(state)\n", 196 | " weights[action] = w\n", 197 | " break\n", 198 | " \n", 199 | " next_action = select_action(next_state, greedy=greedy)\n", 200 | " w = weights[action]\n", 201 | "\n", 202 | " w = w + lr*(reward + df*q(next_state, next_action) - q(state, action))*featurize_state(state)\n", 203 | "\n", 204 | " weights[action] = w\n", 205 | "\n", 206 | " action = next_action\n", 207 | " state = next_state\n", 208 | " # env.render()\n", 209 | " \n", 210 | " episodes.append(i)\n", 211 | " iterations.append(count)\n", 212 | " env.render()" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 287, 218 | "metadata": {}, 219 | "outputs": [ 220 | { 221 | "data": { 222 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAA5QAAAFTCAYAAABGRENtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAMTQAADE0B0s6tTgAAHUlJREFUeJzt3X+sZGd5H/DvQ7zyqlnLbVwDW64dxzHQQordishJSkhU\nQipKTBKs4toiTRFKSZOolJXbBLelUhKp0AonQkipHUhJQqyiQIK7QEJpBCEJwqElbLFQMMiY9XU3\nEIwcdkOXrMvTP3auc7G9u3PPzOyZH5+PNLp33jMz552Z95yZ77zveU91dwAAAGCvnjB2BQAAAFhN\nAiUAAACDCJQAAAAMIlACAAAwiEAJAADAIAIlAAAAgwiUAAAADCJQAgAAMIhACQAAwCACJQAAAINc\nMM2Nqmp/kv+a5BlJ/m+Szyf559396ar6L0n+3qT8RJJ/2d0fmdzvLUmen+RPJw/1vu7+V+da34UX\nXtiXXnrpHp8KAAAAs3rggQf+orsvnOa2UwXKiduT/FZ3d1X9RJI3JfnuJL+Z5Ee6++Gq+r4kv57k\nil33+0/d/fN7WE8uvfTSbG9v7+UuAAAAzEFV/em5b3XaVENeu/tkd7+nu3tS9OFMQmN3/7fufnhX\n+VOqai9BFQAAgBU09BjKVya58wzl79kVMJPklVX1v6vqXVV1zcD1AQAAsGT23JNYVbckuSrJ8x5V\n/tIkL0ny3F3F/ybJse7+alX9YJLfqqqndveJR933UJJDO9cvvvjivVYLAACA82xPPZRVdXOSFyd5\nQXd/eVf5DUn+fZLnd/fndsq7+4Hu/urk/99M8qUkT3/043b3rd29tXM5cODAsGcDAADAeTN1oJz0\nIt6Y06HxoV3lL0nys0m+p7uPPuo+W7v+/7YklyT59KyVBgAAYHzTnjZkK8nrk9yb5P1VlSRf6e5r\nk/xakj9JcuekPEme190PJnlLVT0pyf/L6dOK/KPu/rP5PgUAAADGMFWg7O7tJHWGZfvOcr/vGVgv\nAAAAltzQWV4BAADYcAIlAAAAgwiUAAAADCJQAgAAMIhACSvi+MlTueOuozl+8tTYVQEAgCQCJayM\nw0eO5TV33p3DR46NXRUAAEgy5WlDgPFdd/XBr/kLAABjEyhhRVy0f19uuvbysasBAACPMOQVAACA\nQQRKAAAABhEoAQAAGESgBAAAYBCBEgAAgEEESgAAAAYRKAEAABhEoAQAAGAQgRIAAIBBBEoAAAAG\nESgBAAAYRKAEAABgEIESAACAQQRKAAAABhEoAQAAGESgBAAAYBCBEgAAgEEESgAAAAYRKAEAABhk\nqkBZVfur6p1VdU9VHamq91XVVZNlT6yq366qT1XV3VX13F33O+MyAAAAVtteeihvT/L07r46yZ1J\n3jQpf22SD3f3U5O8LMkdVbVvimUAAACssKkCZXef7O73dHdPij6c5IrJ/y9J8p8nt/tIkv+T5Lum\nWAYAAMAKG3oM5SuT3FlVlyTZ191/smvZfUkuP9uyRz9YVR2qqu2dy4kTJwZWCwAAgPNlz4Gyqm5J\nclWSV8+rEt19a3dv7VwOHDgwr4cGAABgQfYUKKvq5iQvTvKC7v5ydz+Y5OGqevKum12R5OjZls1W\nZQAAAJbB1IGyqg4luTHJ87v7oV2Lfj3Jj05u861JnpLkd6dYBgAAwAq7YJobVdVWktcnuTfJ+6sq\nSb7S3dcm+ckkv1pVn0ryF0le2t2nJnc92zIAAABW2FSBsru3k9QZln0uyffudRkAAACrbegsrwAA\nAGw4gRIAAIBBBEoAAAAGESgBAAAYRKAEAABgEIESAACAQQRKAAAABhEoAQAAGESgBAAAYBCBEgAA\ngEEESgAAAAYRKAEAABhEoAQAAGAQgRIAAIBBBEoAAAAGESgBAAAYRKAEAABgEIESAACAQQRKAAAA\nBhEoAQAAGESgBAAAYBCBEgAAgEEESgAAAAYRKAEAABhEoAQAAGAQgRIAAIBBpgqUVfWGqrqvqrqq\nrpmUXVJVH9t1uaeqHq6qb5gs/0BVfWbX8lct8okAAABwfl0w5e3enuQ/Jvn9nYLufjDJNTvXq+rm\nJN/V3V/cdb9Xdfc751FRAAAAlstUgbK7P5gkVXW2m708yavnUCcAAABWwFyOoayq70jy15K861GL\nXltVH6+qt1XVlWe5/6Gq2t65nDhxYh7VAgAA
YIHmNSnPy5P8Snc/vKvsh7r7byZ5VpLfy2PD5iO6\n+9bu3tq5HDhwYE7VAgAAYFFmDpRVdSDJS5L80u7y7r5/8re7+41JrqyqS2ZdHwAAAMthHj2UNyQ5\n0t1/vFNQVRdU1ZN2Xb8+yecmE/kAAACwBqaalKeqbkvywiRPTvLeqjre3VdNFr88yS8+6i4XJnl3\nVV2Y5KtJvpDkRfOpMgAAAMugunvsOjzG1tZWb29vj10NAACAjVNVD3T31jS3ndekPAAAAGwYgRIA\nAIBBBEoAAAAGESgBAAAYRKAEAABgEIESAACAQQRKAAAABhEoAQAAGESgBAAAYBCBEgAAgEEESgAA\nAAYRKAEAABhEoAQAAGAQgRIAAIBBBEoAAAAGESgBAAAYRKAEAABgEIESAACAQQRKAAAABhEoAQAA\nGESgBAAAYBCBEgAAgEEESgAAAAYRKAEAABhEoAQAAGAQgRIAAIBBpgqUVfWGqrqvqrqqrtlVfl9V\nfbKqPja53LBr2VOr6kNVdU9VfaSqnrmIJwAAAMA4pu2hfHuS5yT57OMsu6G7r5lc3rar/LYkt3f3\n05K8LslbZqopAAAAS2WqQNndH+zu7WkftKqemOTZSd46KXpHksuq6qq9VxEAAIBlNI9jKH+lqj5e\nVW+uqksnZZclOdbdDydJd3eSo0kun8P6AAAAWAKzBsrndvezkvzdJF9I8stDHqSqDlXV9s7lxIkT\nM1YLAACARZspUHb30cnfU0l+Psl3Thbdn+RgVV2QJFVVOd07efQMj3Nrd2/tXA4cODBLtQAAADgP\nBgfKqvr6qvqru4puTPJHSdLdn0/y0SQvnSy7Psl2d3966PoAAABYLhdMc6Oqui3JC5M8Ocl7q+p4\nku9N8o6q+rokleTeJP9k191ekeQtVXVLki8ledk8Kw4AAMC46vR8Octla2urt7ennlQWAACAOamq\nB7p7a5rbzmOWVwAAADaQQAkAAMAgAiUAAACDCJQAAAAMIlACAAAwiEAJAADAIAIlAAAAgwiUAAAA\nDCJQAgAAMIhACQAAwCACJQAAAIMIlAAAAAwiUAIAADCIQAkAAMAgAiUAAACDCJQAAAAMIlACAAAw\niEAJAADAIAIlAAAAgwiUADCD4ydP5Y67jub4yVNjVwUAzjuBEgBmcPjIsbzmzrtz+MixsasCAOfd\nBWNXAABW2XVXH/yavwCwSQRKAJjBRfv35aZrLx+7GgAwCkNeAQAAGESgBAAAYBCBEgAAgEGmCpRV\n9Yaquq+quqqumZTtr6p3VtU9VXWkqt5XVVftus8HquozVfWxyeVVi3oSAAAAnH/T9lC+Pclzknz2\nUeW3J3l6d1+d5M4kb3rU8ld19zWTy8/NVlUAAJiOc8TC+TFVoOzuD3b39qPKTnb3e7q7J0UfTnLF\nnOsHAAB75hyxcH7M87Qhr8zpXsrdXltVP5PkE0le3d33znF9AADwuJwjFs6PuQTKqrolyVVJnrer\n+Ie6+/6qqiQ/nuRdSZ5xhvsfSnJo5/rFF188j2oBALChnCMWzo+ZZ3mtqpuTvDjJC7r7yzvl3X3/\n5G939xuTXFlVlzzeY3T3rd29tXM5cODArNUCAABgwWYKlJOexRuTPL+7H9pVfkFVPWnX9euTfK67\nH5xlfQAAACyPqYa8VtVtSV6Y5MlJ3ltVx5N8d5LXJ7k3yftPj2zNV7r72iQXJnl3VV2Y5KtJvpDk\nRXOvPQAAAKOZKlB29yvOsKjOcPs/T/LsoZUCAABg+c18DCUAAACbSaAEAABgEIESABjs+MlTueOu\nozl+8tTYVQFgBAIlADDY4SPH8po7787hI8fGrgoAI5hqUh4AOJPjJ0/l8JFjue7qg7lo/76xq8N5\ndt3VB7/mLwCbRQ8lADPRQ7XZLtq/Lzdde7kfEwA2lB5KAGaihwoANpceSjaGiSOWm/dndemhAoDN\nJVCyMQzLW27eHxjGjzEAjMmQVzaGYXnLzfsDw+z8GJMkN117+ci1AWDTVHePXYfH2Nra6u3t7bGr\nAQBLzyy7AMxbVT3Q3VvT3FYPJQCssJ1jWAFgDI6hBAAAYBCBEgAAgEEESgAAAAYRKAEAABhEoAQA\nAGAQgRIAAIBBBEoAAAAGESgBAAAYRKAEAABgEIESAACAQQRKAAAABhEo9+D4yVO5466jOX7y1NhV\nAQAAGJ1AuQeHjxzLa+68O4ePHBu7KgAAAKO7YOwKrJLrrj74NX8BAAA22VQ9lFX1hqq6r6q6qq7Z\nVf7UqvpQVd1TVR+pqmdOs2xVXbR/X2669vJctH/f2FUBAAAY3bRDXt+e5DlJPvuo8tuS3N7dT0vy\nuiRvmXIZAAAAK26qQNndH+zu7d1lVfXEJM9O8tZJ0TuSXFZVV51t2XyqDQAAwNhmmZTnsiTHuvvh\nJOnuTnI0yeXnWAYAAMAaWIpZXqvqUFVt71xOnDgxdpUAAAA4h1kC5f1JDlbVBUlSVZXTPZBHz7Hs\nMbr71u7e2rkcOHBghmoBAABwPgwOlN39+SQfTfLSSdH1Sba7+9NnWzZLZQEAmM7xk6dyx11Hc/zk\nqbGrAqyxaU8bcltVbSfZSvLeqtoJhq9I8oqquifJTyV52a67nW0ZAAALdPjIsbzmzrtz+MixsavC\nCvKDBNO6YJobdfcrzlD+ySTfvtdlAAAs1nVXH/yav7AXOz9IJMlN15pXkzObKlACALBaLtq/TxBg\nMD9IMC2BEgAA+Bp+kGBaS3HaEAAAAFaPQAkAAMAgAiUAAACDCJQAAAM5tQKw6QRKAICBnOsR2HRm\neQUAGMipFYBNJ1ACAAzk1ArApjPkFQAAgEEESlaGiQ8AAGC5CJQ8YtkDm4kPmMWyt28AgFXkGEoe\nsRPYkizkeJDjJ0/l8JFjue7qg7lo/74939/EB8xi0e0bAGATCZQ8YtGBbdYv9CY+YBZ+kAAAmL/q\n7rHr8BhbW1u9vb09djWYs1l7KJfduj8/AAA2Q1U90N1b09zWMZRrZNmPEdvpYVzXsOUYTwAANo0h\nr2vEMWLjMqQSAIBNI1CuEYFmXI7xPLtlHxK87PUDAFhGhryukXUfUrruln3I8qwWPSR41tfPkOUz\nW3TbnPXx133bAYBlpocSlsS6D1le9lmE9fCf2aLb5qyPv+7bDgAsM4ESlsS6B5pFDwme9fUzZPnM\nFt02Z338dd92AGCZOW0IAAAAj3DaEGDPHIfGstI2AWB5CZRAEpPSsLy0TQBYXo6hBJI4Do3lpW0C\nwPJyDCUAAACPcAwlAAAACzfzkNequiTJ7+wq+itJrkzyxCS/keQbk/zZZNkvd/fPzbpOAAAAxjdz\noOzuB5Ncs3O9qm5O8l3d/cWqSpJXdfc7Z10PAADr4/jJUzl85Fiuu/pgLtq/b+zqAAMtYsjry5O8\neQGPCwDAmjCDM6yHuc7yWlXfkeSvJXnXruLXVtXPJPlEkld3972Pc79DSQ7tXL/44ovnWS0AAJaM\nGZx
[... base64 PNG data elided: matplotlib scatter plot of episode index (x) vs. steps per episode (y) ...]\n", 223 | "text/plain": [ 224 | "" 225 | ] 226 | }, 227 | "metadata": {}, 228 | "output_type": "display_data" 229 | } 230 | ], 231 | "source": [ 232 | "plt.figure(figsize=(14, 5), dpi=80)\n", 233 | "plt.scatter(episodes, iterations, s=0.5)\n", 234 | "#plt.hist(iterations, bins=100)\n", 235 | "None" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 194, 241 | "metadata": { 242 | "scrolled": true 243 | }, 244 | "outputs": [ 245 | { 246 | "data": { 247 | "text/plain": [ 248 | "(60, array([[ 2.0565536, -2.9715855],\n", 249 | " [-1.8772595, 2.7955716]], dtype=float32))" 250 | ] 251 | }, 252 | "execution_count": 194, 253 | "metadata": {}, 254 | "output_type": "execute_result" 255 | } 256 | ], 257 | "source": [ 258 | "count, weights" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 90, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "if False:\n", 268 | " np.save(\"weights\", weights)" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 21, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "if True:\n", 278 | " weights = np.load(\"weights.npy\")" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [] 287 | } 288 | ], 289 | "metadata": { 290 | "kernelspec": { 291 | "display_name": "Python 3", 292 | "language": "python", 293 | "name": "python3" 294 | }, 295 | "language_info": { 296 | "codemirror_mode": { 297 | "name": "ipython", 298 | "version": 3 299 | }, 300 | "file_extension": ".py", 301 | "mimetype": "text/x-python", 302 | "name": "python", 303 | "nbconvert_exporter": "python", 304 | "pygments_lexer": "ipython3", 305 | "version": "3.6.3" 306 | } 307 | }, 308 | "nbformat": 4, 309 | "nbformat_minor": 2 310 | } 311 | -------------------------------------------------------------------------------- /open ai gym/mountain car - sarsa.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import gym\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import numpy as np\n", 12 | "from collections import defaultdict\n", 13 | "import math, random" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "%matplotlib inline"
inline" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 3, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "env = gym.make(\"MountainCarContinuous-v0\")" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 4, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "initial_state = env.reset()" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 5, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "def step(env, action):\n", 50 | " state, reward, final, info = env.step(action)\n", 51 | " return state, reward, final, info" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 6, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "data": { 61 | "text/plain": [ 62 | "((2,), array([-1.2 , -0.07]), array([0.6 , 0.07]))" 63 | ] 64 | }, 65 | "execution_count": 6, 66 | "metadata": {}, 67 | "output_type": "execute_result" 68 | } 69 | ], 70 | "source": [ 71 | "env.observation_space.shape, env.observation_space.low, env.observation_space.high" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 7, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "def featurize_state(state):\n", 81 | " return np.array([state[0], state[0]**2, state[0]**3, state[0]**4, state[0]**5,\n", 82 | " state[1], state[1]**2, state[1]**3, state[1]**4, state[1]**5])\n", 83 | "\n", 84 | "def action_space(env):\n", 85 | " env.action_space.low()" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 8, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "def actions(env):\n", 95 | " return np.linspace(env.action_space.low, env.action_space.high, 10)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 14, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "data": { 105 | "text/plain": [ 106 | "gym.spaces.box.Box" 107 | ] 108 | }, 109 | "execution_count": 14, 110 | "metadata": {}, 111 | "output_type": "execute_result" 112 | } 113 | ], 114 | "source": [ 115 | "type(env.action_space)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 17, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "data": { 125 | "text/plain": [ 126 | "Box(1,)" 127 | ] 128 | }, 129 | "execution_count": 17, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ], 134 | "source": [ 135 | "env.action_space." 
136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 9, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "state = env.observation_space.sample()\n", 145 | "state = featurize_state(state)\n", 146 | "\n", 147 | "def init_weights():\n", 148 | " state = env.observation_space.sample()\n", 149 | " state = featurize_state(state)\n", 150 | " return np.random.uniform(-1/np.sqrt(state.shape[0]),\n", 151 | " 1/np.sqrt(state.shape[0]),\n", 152 | " (actions(env).shape[0], state.shape[0]))\n", 153 | "weights = init_weights()\n", 154 | "\n", 155 | "def q(state, action_idx):\n", 156 | " state = featurize_state(state)\n", 157 | " return np.dot(weights[action_idx], state)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 10, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "num_actions = len(actions(env))\n", 167 | "def select_action(state, greedy=0.2):\n", 168 | " state = featurize_state(state)\n", 169 | " max_arg = np.dot(weights, state).argmax()\n", 170 | " \n", 171 | " if np.random.uniform() < 1 - greedy:\n", 172 | " return max_arg\n", 173 | " else:\n", 174 | " return random.randint(0, num_actions-1)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 27, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "state = env.reset()\n", 184 | "weights = init_weights()\n", 185 | "action = None\n", 186 | "done = False" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 29, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "lr = learning_rate = 0.01\n", 196 | "df = discount_factor = 1\n", 197 | "\n", 198 | "episodes = []\n", 199 | "iterations = []\n", 200 | "\n", 201 | "\n", 202 | "for i in range(50):\n", 203 | " count = 0\n", 204 | " state = env.reset()\n", 205 | " done = False\n", 206 | " action = None\n", 207 | " \n", 208 | " while not done:\n", 209 | " count += 1\n", 210 | " # if count >= 300: raise Exception(\"yay\")\n", 211 | " if action is None:\n", 212 | " action = select_action(state)\n", 213 | "\n", 214 | " next_state, reward, done, _ = step(env, [actions(env)[action]])\n", 215 | " \n", 216 | " if done:\n", 217 | " w = weights[action]\n", 218 | " w = w + lr*(reward - q(state, action))*featurize_state(state)\n", 219 | " weights[action] = w\n", 220 | " break\n", 221 | " \n", 222 | " next_action = select_action(next_state)\n", 223 | " w = weights[action]\n", 224 | "\n", 225 | " w = w + lr*(reward + df*q(next_state, next_action) - q(state, action))*featurize_state(state)\n", 226 | "\n", 227 | " weights[action] = w\n", 228 | "\n", 229 | " action = next_action\n", 230 | " state = next_state\n", 231 | " # env.render()\n", 232 | " \n", 233 | " episodes.append(i)\n", 234 | " iterations.append(count)\n", 235 | " env.render()" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 17, 241 | "metadata": {}, 242 | "outputs": [ 243 | { 244 | "data": { 245 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAA5sAAAFTCAYAAAC3TxjgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAMTQAADE0B0s6tTgAAHdhJREFUeJzt3X2MZeddH/DvL/XKS9mtARPjJTfLmtiJUELWoFhLUUgq\nUd7ZkiYoUSyX1IqAqKIgWakCFrWqFAnnD6wKISV2AZkkrAQhgLtVG/HSkNAEG0PIEheR2ApmM86Y\nEKOQ2dJNdpVf/5g7YdbxzM7sPTP37fORRrP3POeeeebO2XPv9zxv1d0BAACAIT1r2hUAAABg8Qib\nAAAADE7YBAAAYHDCJgAAAIMTNgEAABicsAkAAMDghE0AAAAGJ2wCAAAwOGETAACAwQmbAAAADO6q\naVdgt66++up+9rOfPe1qAAAALJ0nnnji89199U72nbuw+exnPzsrKyvTrgYAAMDSqaq/3em+utEC\nAAAwOGETAACAwQmbAAAADE7YBAAAYHDCJgAAAIMTNgEAABicsAkAAMDghE0AAAAGt6OwWVU/X1WP\nV1VX1c2btt9UVR+sqo9V1cNV9cJJywAAAJh/O23Z/I0kL03y10/bfm+S+7r7+UnekuT+AcoAAACY\nc9XdO9+56vEkr+juD1fVdUkeS/JV3X2xqirJatZD6WevpKy7H7tcHUajUa+srOzqlwQAAGByVfVE\nd492su8kYzafm2S1uy8mSa+n1rNJjk5QtrDWzl/IqYfOZu38hSsq32uT1m+v6z/t12dSk9Z/1l//\nSX7+tH+3y5n113ba5Xtt2q//rB9/2ub595v1v/08v7aJ+gPrZn6CoKq6o6pWNr7OnTs37SpdkdNn\nVnPXA4/k9JnVKyrfa5PWb6/rP+3XZ1KT1n/WX/9Jfv60f7fLmfXXdtrle23ar/+sH3/a5vn3m/W/\n/Ty/ton6A+uumuC5n0hypKqu2tQd9mjWWyk/e4VlX6K770lyz8bj0Wi0836/M+Tk8SOXfN9t+V6b\ntH57Xf9pvz6TmrT+s/76T/Lzp/27Xc6sv7bTLt9r0379Z/340zbPv9+s/+3n+bVN1B9Yd8VjNseP\n/yDJ/d19f1X9YJKf7O6XTFJ2OcZsAgAATMduxmzuKGxW1b1Jvi/J9UmeSrLW3TdW1QuyPpPstVlv\nsby9uz8yfs4VlV2OsAkAADAdg4fNWSJsAgDTtnb+Qk6fWc3J40dy+OCBaVcHYN/s12y0AABLyQQy\nAJc3yQRBAABLyQQyAJcnbAIA7NLhgwdy64mFXiIcYGK60QIAADA4YRMAAIDBCZsAAAAMTtgEAABg\ncMImAAAAgxM2AQAAGJywCQAAwOCETQAAAAYnbAIAADA4YRMAAIDBCZsAAAAMTtgEAABgcMImAAAA\ngxM2AQAAGJywCQAAwOCETQAAAAYnbAIAADA4YRMAAIDBCZsAAAAMTtgEAABgcMImAAAAgxM2AQAA\nGNwgYbOqvruq/qSq/ryqHqyq4+Pt11XVe6rq0ap6pKpetuk5W5YBAAAw3yYOm1X1lUl+NcnruvvF\nSf7D+HGS3J3kwe6+KcntSU5V1YEdlAEAc27t/IWceuhs1s5fmHZVAJiCIVo2n5fkqe7+P0nS3X+Y\n5GhVfXOSVyd523j7w0k+meTl4+dtVwbAHBAm2M7pM6u564FHcvrM6rSrAsAUXDXAMR5Ncm1VfWt3\nf7Cq/lWSw0luSHKgu5/ctO/jWQ+i125V9vSDV9UdSe7YeHzNNdcMUGUAhrARJpLk1hNfcglnyZ08\nfuSS7wAsl4nDZnf/fVX9YJKfrapDSf4oyV8kOTTpscfHvyfJPRuPR6NRD3FcACYnTLCdwwcPuAkB\nsMSGaNlMd783yXuTpKquTvJkkg8kuVhV129qwTyW5Gx3P1VVz1g2RH0A2B/CBACwlaFmo918S/s/\nJvlf3f1YknclecN4n1uSPCfJ+8b7bVcGAADAHBukZTPJm6vq28bH+6Mkrx9vf1OSd1TVo0k+n+S2\n7r6wgzIAAADmWHXP1xDI0WjUKysr064GAADA0qmqJ7p7tJN9B+lGCwDMH0vXALCXhE0AWFLWwQRg\nLw01ZhMAmDOWrgFgLwmbALCkLF0DwF7SjRYAAIDBCZsAAAAMTtgEAABgcMImAAAAgxM2AQAAGJyw\nCQAAwOCETQAAAAYnbAIAADA4YRMAAIDBCZsAAAAMTtgEAABgcMImAAAAgxM2AQAAGJywCQAAwOCE\nTQAAAAYnbAIAADA4YRMAAIDBCZsAAAAMTtgEAABgcMImAAAAgxM2AQAAGNwgYbOqvreqPlRVH66q\nR6rqdePt11XVe6rq0fH2l216zpZlAAAAzLerJj1AVVWSdyb5F93951V1LMlfVtVvJrk7yYPd/d1V\ndUuS36qqG7r7wmXKAAAAmGNDdaPtJF8x/vc/S/JUks8leXWStyVJdz+c5JNJXj7eb7syYEBr5y/k\n1ENns3bevRwAAPbHxC2b3d1V9Zokv1lV/zfJVyZ5ZZLDSQ5095Obdn88ydGqunarsqcfv6ruSHLH\nxuNrrrlm0irD0jl9ZjV3PfBIkuTWE1/y3wwAAAY3RDfaq5L8dJJXdvf7x11i/1uSmyc9dpJ09z1J\n7tl4PBqNeojjwjI5efzIJd8BAGCvDdGN9uYkX9vd70++2CV2JcmLk1ysqus37XssydnufmqrsgHq\nAzzN4YMHcuuJozl88MC0qwIAwJIYImx+IsmRqvqGJKmqG5M8L8lHk7wryRvG229J8pwk7xs/b7sy\nAAAA5tgQYzb/pqp+JMmvV9UXsh5gf6y7z1bVm5K8o6oeTfL5JLdtmm12uzIAAADmWHXP1xDI0WjU\nKysr064GAADA0qmqJ7p7tJN9h1r6BAAAAL5I2AQAAGBwwiYAAACDEzYBmFtr5y/k1ENns3be/HIA\nMGuETQDm1ukzq7nrgUdy+szqtKsCADzNxEufAMC0nDx+5JLvAMDs0LIJwNw6fPBAbj1xNIcPHpjK\nz9eNFwC2JmwCwBXSjRcAtqYbLQBcId14AWBrwiYAXKGNbrwAwJfSjRYAAIDBCZsAAAAMTtgEAABg\ncMImAAAAgxM2AWAL1tEEgCsnbALAFqyjCQBXztInALAF62gCwJUTNgFgC9bRBIArpxstAAAAgxM2\nAQBgjpi8jHkhbAIAPI0P88wyk5cxL4zZBAB4mo0P80mM22XmmLyMeSFsAsCCWjt/IafPrObk8SM5\nfPDAtKszV3yYZ5aZvIx5oRstACwoXe2u3MaHeSEd4MpN3LJZVdcm+f1Nm/5pkq9Pct34+G9P8rwk\nn0vy77r7/ePnXbdVGcwbrQfALNI6B8A0TRw2u/upJDdvPK6qNyZ5eXf/XVX9cpIHu/u7q+qWJL9V\nVTd094Ukd29TBnPF2B5gFulqB8A07cWYzdcn+anxv1+d5MYk6e6Hq+qTSV6e5PcuUwZzResBAABc\natCwWVXfmuQrk/z3cffaA9395KZdHk9ydLuyIesD+0Xr
AQAAXGroCYJen+Tt3X1xqANW1R1VtbLx\nde7cuaEODQAAwB4ZLGxW1aGsd4395eSLYzkvVtX1m3Y7luTsdmVPP25339Pdo42vQ4cODVVlAAAA\n9siQLZuvSXKmu/9y07Z3JXlDkownAXpOkvftoAwAAIA5NuSYzdcn+a9P2/amJO+oqkeTfD7JbZtm\nm92uDAAAgDlW3T3tOuzKaDTqlZWVaVcDAABg6VTVE9092sm+Q08QBADs0Nr5Czn10NmsndexB4DF\nI2wCwJScPrOaux54JKfPrE67KgAwuEHX2QQAdu7k8SOXfOcfrZ2/kNNnVnPy+JEcPnhg2tUB4Apo\n2QRgapa9G+nhgwdy64mjwtQz0OoLMP+0bELcQYdp2QgUSXLriaNTrg2zRKsvwPwTNiE+8MK0CBRs\nZaPVF4D5JWxCfOCFaREoAGBxCZsQH3gBAGBoJggCAABgcMImAMA+W/aZmIHlIGwCAOwzS7sAy8CY\nTWCpWfYGmAYT0wHLQMsmsNT2unVBVzngmWxMTOcmF7DItGwCS22vWxes4QoA7NSi9bgSNoGlttfL\n3ugqBwDs1KLdpBY2AfaQNVzZS4t2Bxxg2S3aTWpjNgFgTpnRdHkZDw6LadHGc2vZZF+4+w4wvEW7\nA87OLVpXO2AxCZvsC2+KAMOb927abkReOTcagHkgbLIvvCmyqHxYhivnRuSVm/cbDcByEDbZF94U\nWVQ+LMOVcyMSYLEJmwAT8GEZrpwbkQCLTdgEmIAPywAAz8zSJwAAAAxO2AQAAGBwg4TNqrq6qn6h\nqh6tqo9U1TvH22+qqg9W1ceq6uGqeuGm52xZBgAwibXzF3LqobNZO39h2lUBWFpDtWzenaSTPL+7\nvzHJG8fb701yX3c/P8lbkty/6TnblQEAXLGNmaJPn1mddlVYQG5mwM5MPEFQVX15ktcnGXV3J0l3\nP1lV1yV5SZLvHO/67iS/UFU3JvnsVmXd/dikdQIAlpuZotlLlr2CnRliNtrnJfm7JHdW1b9M8v+S\n/Kckn0my2t0Xk6S7u6rOJjma5O+3KRM2Gdza+Qs5fWY1J48fyeGDB6ZdHQD2mJmi2UtuZsDODNGN\n9qokX5fkL7r7JUl+PMmvZaBlVarqjqpa2fg6d+7cEIdlyehOBQAMZeNmhhvYsL0hAuHZJF9I8qtJ\n0t1/VlV/lfUAeqSqrurui1VVWW+5PJv1brRblV2iu+9Jcs/G49Fo1APUmSXjDiQAAOyviVs2u/vT\nSX4/yXclSVXdkOSGJB9I8qEkt413fVWSle5+rLs/tVXZpPWBZ+IOJAAA7K8az+kz2UGqvj7JLyX5\n6qy3cr65u99dVS/I+iyz12a9NfP27v7I+Dlblm1nNBr1ysrKxHUGAABgd6rqie4e7WjfIcLmfhI2\nAQAApmM3YXOodTaBOWa9MAAAhiZsAgs9W68gDQAMxeeK3RlkeRJgvi3ybL0W3gYAhuJzxe4Im8BC\nL36+yEEaANhfPlfsjm60S0KTP8vKsjcAPJ3PRVwpnyt2R9hcEos8Jg8AYDd8LoL9oRvtktDkDwCw\nzuci2B/W2QQAAGBHrLMJAADAVAmbAAAADE7YBABgV8zmCuyEsAkAwK6YzRXYCbPRAgCwK2ZzBXZC\n2AQAYFc2FrYH2I5utAAAAAxO2AQAAGBwwiYAAACDEzYBAAAYnLAJAADA4IRNAAAABidsAgAAg1k7\nfyGnHjqbtfMXpl0VpkzYBAAABnP6zGrueuCRnD6zOu2qMGVXTbsCAADA4jh5/Mgl31lewiYAADCY\nwwcP5NYTR6ddDWaAbrQAAOwrY/qYhPNnfgwSNqvq8ar6aFV9ePz1mvH2m6rqg1X1sap6uKpeuOk5\nW5YBALC4jOljEs6f+TFkN9rXdPeHn7bt3iT3dff9VfWDSe5PcssOymCmrJ2/kNNnVnPy+JEcPnhg\n2tUBgLlmTB+TcP7Mj+ruyQ9S9XiSV2wOm1V1XZLHknxVd1+sqkqymuSlST67VVl3P7bdzxqNRr2y\nsjJxnWE3Tj10Nnc98Eje/AMvMgYBAIClVVVPdPdoJ/sO2bL59nFo/OMkP5nkuUlWu/tiknR3V9XZ\nJEeT/P02ZduGTZgGd9AAYP/oUQSLYagJgl7W3S9O8s1JPp3kVwY6bqrqjqpa2fg6d+7cUIeGHduY\nVc0bHgDsPWPyYDEM0o32kgNWHUnysSTPi260AADskpZNmF276UY7cctmVX15VX3Fpk2vTfJn3f2p\nJB9Kctt4+6uSrHT3Y9uVTVofAADmmx5FsBiGGLP5NUneXVX/JEkl+XiSHxqX/WiS+6vqzqy3Zt6+\n6XnblQGwALROAMDymjhsdvfHk3zTFmUfTfLPd1sGwGLYGHeVxEzOALBkhpyNFgAuYSZnYBnp1QHr\nhpqNFgC+hHFXk1k7fyGnHjqbtfMXpl0VYBfMpgvrtGwCwIzSDRnmk14dsE7YBIAZ5QMrzKeNXh2w\n7IRNAJhRPrACMM+M2QQA5o7xrACzT9iEGeBDE8DumIAFYPbpRgszwCQgwCya5eUbjGcFmH1aNgei\nZYpJnDx+JG/+gRf50HQF/N+DvTPLrYeW1YHF5b19cWjZHIiWKSZhEpAr5/8e7B2th8A0eG9fHMLm\nQLwhw3T4vwd7x40wYBq8ty+O6u5p12FXRqNRr6ysTLsaAAAAS6eqnuju0U72NWYTAACAwQmbAAAA\nDE7YXBBm7QIAAGaJsLkgZnl6egAAYPmYjXZBmLULAACYJVo2F4TFrQEAmAeGfy0PYRMAgEsIA/Nt\n1v9+hn8tD91oAQC4xEYYSJJbTxydcm3YrVn/+xn+tTyETQAALiEMzLdZ//ttDP9i8VV3T7sOuzIa\njXplZWXa1QAAAFg6VfVEd492sq8xmwAAAAxO2AQAAGBwwiYAAACDGyxsVtXtVdVV9Yrx4+uq6j1V\n9WhVPVJVL9u075ZlAADA7Jr1pVWYHYOEzao6luSHkzy4afPdSR7s7puS3J7kVFUd2EEZAAAwo6yT\nyU5NvPRJVT0ryS8m+fdJfm5T0auT3Jgk3f1wVX0yycuT/N5lygAAgBk160urMDuGaNm8I8kHuvtP\nNzZU1bVJDnT3k5v2ezzJ0e3KnungVXVHVa1sfJ07d26AKgMAAFdiY53Mwwd1TGR7E4XNqnpRklcl\n+ZlhqvOluvue7h5tfB06dGivfhQAAAADmbRl89uSHEvyaFU9nuRbktyX9W6yF6vq+k37Hktytruf\n2qpswroAM8gkAgAAy2misNndb+3uI919rLuPZX2CoB/p7rcmeVeSNyRJVd2S5DlJ3jd+6nZlsGsC\nzewyiQDT5NrAsnLuA7Ng4gmCtvGmJO+oqkeTfD7Jbd19YQdlsGsbgSZJbj3xjMN/mRKTCDBNrg0s\nK+c+MAuqu6ddh10ZjUa9srIy7WowY9bOX8jpM6s5efyIwerAF7k2sKyc+8Beqaonunu0o32FTQAA\nAHZiN2FziKV
PAABgZhizyqxatnNT2AQAYKGYnI5ZtWzn5l5OEMSAjL0AANgZk9Mxq5bt3NSyOSeW\n7S4IAMCVOnzwQG49cdQNevbd5brJLtu5qWVzTizbXRAAAJg3lh26lLA5JzbuggAAMN+mPTxq2j9/\nkWkgupRutAAAsI+mPTxq2j9/kS1bN9nL0bIJAAD7aNqtX9P++SyP6u5p12FXRqNRr6ysTLsaAAAA\nS6eqnuju0U721Y0WAGDJLNvC8vPG34dFIWwCACwZY/Zmm78Pi8KYTQCAJWPM3mzz92FRGLMJAADA\njhizCQAAwFQJmwAAAAxO2AQAAGBwwiYAAACDEzYBAAAYnLAJAAB80dr5Czn10Nmsnb8w7aow54RN\nAADgi06fWc1dDzyS02dWp10V5txV064AAAAwO04eP3LJ93mydv5CTp9ZzcnjR3L44IFpV2fpadkE\nAAC+6PDBA7n1xNG5DGtaZWeLlk0AAGAhzHOr7CISNgEAgIWw0SrLbBikG21V/U5V/XlVfbiq/rCq\nvmm8/aaq+mBVfayqHq6qF256zpZlAAAAzLehxmy+urtf3N03J7knyf3j7fcmua+7n5/kLZu2X64M\nAACAOTZI2Ozuz2x6eE2SrqrrkrwkyTvH29+d5LlVdeN2ZUPUBwAAgOkabDbaqnp7VX0iyX9O8m+S\nPDfJandfTJLu7iRnkxy9TNnTj3tHVa1sfJ07d26oKjNDLB4MAACLZbCw2d0/1N3PTfLTWe8WO9Rx\n7+nu0cbXoUOHhjo0M8Q01QAAsFgGn422u3+lqt6WZCXJkaq6qrsvVlVlveXybJLPblPGEjJNNQAA\ne23t/IWcPrOak8ePzOU6ovNm4pbNqvqKqvraTY9fkeSpJJ9K8qEkt42LXpVkpbsf6+4tyyatD/Np\nnhcPBgBgPlyuN52hXcMaomXzmiTvqqovS/KFJH+b5Pu7u6vqR5PcX1V3Zr018/ZNz9uuDAAAYFCX\n6023EUaTWK9zALU+N8/8GI1GvbKyMu1qAAAAC0Y328urqie6e7STfQcfswkAADCPNoZ2MYzBZqMF\nAACADcImAAAAgxM2AQAAGJywCQAAwOCETQAAAAYnbAIAADA4YRMAAIDBCZsAAAAMTtgEAABgcMIm\nAAAAgxM2AQAAGJywCQAAwOCETQAAAAYnbAIAADA4YRMAAIDBCZsAAAAMTtgEAABgcMImAAAAgxM2\nAQAAGJywCQAAwOCETQAAAAYnbAIAADA4YRMAAIDBTRw2q+pgVf12VX2sqs5U1e9W1Y3jsuuq6j1V\n9WhVPVJVL9v0vC3LAAAAmG9DtWzel+QF3X08yQNJfnG8/e4kD3b3TUluT3Kqqg7soAwAAIA5NnHY\n7O7z3f0/urvHmx5Mcmz871cnedt4v4eTfDLJy3dQBgAAwBzbizGbP5Hkgaq6NsmB7n5yU9njSY5u\nV7YH9QEAAGCfXTXkwarqziQ3Jvn2JF820DHvSHLHxuNrrrlmiMMCAACwhwZr2ayqNyZ5ZZLv6e5/\n6O6nklysqus37XYsydntyp5+3O6+p7tHG1+HDh0aqsoAAADskUHC5rj18bVJvqO7P7Op6F1J3jDe\n55Ykz0nyvh2UAQAAMMcm7kZbVaMkP5fk40neW1VJ8rnuPpHkTUneUVWPJvl8ktu6+8L4qduVAQDA\nM1o7fyGnz6zm5PEjOXzQYgYwqyYOm929kqS2KPubJN+52zIAANjK6TOrueuBR5Ikt54wvyTMqkEn\nCAIAgL128viRS74Ds0nYBABgrhw+eECLJsyBvVhnEwAAgCUnbAIAADA4YRMAAIDBCZsAAAAMTtgE\nAABgcMImAAAAgxM2AQAAGJywCQAAwOCETQAAAAYnbAIAADA4YRMAAIDBVXdPuw67UlWfS/K3067H\nFg4lOTftSrC0nH9Mi3OPaXHuMU3OP6Zl2ufes7v76p3sOHdhc5ZV1Up3j6ZdD5aT849pce4xLc49\npsn5x7TM07mnGy0AAACDEzYBAAAYnLA5rHumXQGWmvOPaXHuMS3OPabJ+ce0zM25Z8wmAAAAg9Oy\nCQAAwOCETQAAAAYnbA6gqm6qqg9W1ceq6uGqeuG068RiqqqDVfXb43PtTFX9blXdOC67rqreU1WP\nVtUjVfWyadeXxVVVt1dVV9Urxo+df+ypqrq6qn5hfI59pKreOd7uPZg9V1XfW1UfqqoPj69xrxtv\nd+1jUFX181X1+Pg99uZN27e81s3ydVDYHMa9Se7r7ucneUuS+6dbHRbcfUle0N3HkzyQ5BfH2+9O\n8mB335Tk9iSnqurAlOrIAquqY0l+OMmDmzY7/9hrdyfpJM/v7m9M8sbxdu/B7KmqqiTvTPJvu/vm\nJN+f5N6qOhzXPob3G0lemuSvn7Z9u2vdzF4HTRA0oaq6LsljSb6quy+OL0irSV7a3Y9Nt3Ysuqp6\nSZLf6O5jVXUuyY3d/eS47I+T3NndvzfVSrJQqupZSX4nyZuS/FyS/9Ldv+38Yy9V1Zdn/b111N2f\n3bTdezB7bnxefTrJv+7u91fVi5P8zyQ3JPm7uPaxB6rq8SSv6O4Pb3etS/LZrcpm4TqoZXNyz02y\n2t0Xk6TX0/vZJEenWiuWxU8keaCqrk1yYOPNbuzxOA8Z3h1JPtDdf7qxwfnHPnhe1j/U31lVf1JV\nf1hV3x7vweyD8Xn1miS/WVV/neR/J3ldksNx7WN/bHetm+nroLAJc6qq7kxyY5KfmnZdWA5V9aIk\nr0ryM9OuC0vnqiRfl+QvuvslSX48ya+Nt8Oeqqqrkvx0kld299cl+fYk74jzDy5L2JzcJ5IcGV+I\nNrpaHM36HQXYE1X1xiSvTPI93f0P3f1UkotVdf2m3Y7Feciwvi3r59Wj4+4935L1McSvjvOPvXU2\nyReS/GqSdPefJfmrrAdQ78HstZuTfG13vz9JuvvhJCtJXhzXPvbHdnljprOIsDmh7v5Ukg8luW28\n6VVJVmahjzSLqaruSPLaJN/R3Z/ZVPSuJG8Y73NLkucked/+15BF1d1v7e4j3X2su49lfYKgH+nu\nt8b5xx7q7k8n+f0k35UkVXVD1sfLfSDeg9l7Gx/mvyFJxrPAPy/JR+Paxz7YLm/MehYxQdAAquoF\nWZ/16dqsD9K9vbs/MtVKsZCqapT1N72PJ1kbb/5cd5+oqq/JereeG5J8PsmPdfd7p1NTlkFV/UH+\ncYIg5x97qqq+PskvJfnqrLdyvrm73+09mP1QVa9NcmfWz71nJfnZ7j7l2sfQqureJN+X5PokTyVZ\n6+4bt7vWzfJ1UNgEAABgcLrRAgAAMDhhEwAAgMEJmwAAAAxO2AQAAGBwwiYAAACDEzYBAAAYnLAJ\nAADA4IRNAAAABidsAgAAMLj/D3rt0sv37LpyAAAAAElFTkSuQmCC\n", 246 | "text/plain": [ 247 | "" 248 | ] 249 | }, 250 | "metadata": {}, 251 | 
"output_type": "display_data" 252 | } 253 | ], 254 | "source": [ 255 | "plt.figure(figsize=(14, 5), dpi=80)\n", 256 | "plt.scatter(episodes, iterations, s=0.5)\n", 257 | "#plt.hist(iterations, bins=100)\n", 258 | "None" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 120, 264 | "metadata": { 265 | "scrolled": true 266 | }, 267 | "outputs": [ 268 | { 269 | "data": { 270 | "text/plain": [ 271 | "(999, array([[-0.02085557, 0.0595762 , 0.11831713, -0.18363152, -0.19683867,\n", 272 | " -0.01796869, 0.28806498, -0.021476 , 0.24860996, 0.18516945],\n", 273 | " [ 0.01440559, -0.05804852, -0.22202222, -0.2264345 , -0.07399336,\n", 274 | " 0.0836368 , 0.03485213, 0.13657303, -0.18549418, 0.08907543],\n", 275 | " [-0.08892545, -0.23147879, -0.16591371, -0.0597599 , -0.0296756 ,\n", 276 | " 0.0452843 , -0.16806746, 0.11934212, 0.05252277, 0.08751585],\n", 277 | " [ 0.00821088, -0.05516365, -0.10181619, -0.01925168, 0.0219685 ,\n", 278 | " 0.09394311, 0.27535469, -0.22059754, 0.14107741, 0.07010356],\n", 279 | " [-0.036048 , -0.08865107, -0.11465601, -0.16514499, -0.0954818 ,\n", 280 | " 0.15807466, -0.01672393, 0.20743621, 0.02093894, 0.29951216],\n", 281 | " [-0.03579179, -0.01706876, 0.1083824 , -0.00311162, -0.07572744,\n", 282 | " 0.22687012, 0.1699998 , 0.25096862, 0.01865831, 0.17707882],\n", 283 | " [-0.10621423, -0.29940879, -0.25072268, -0.12465697, -0.05695309,\n", 284 | " 0.05062194, 0.12632425, 0.01576961, 0.02080412, -0.08271762],\n", 285 | " [-0.04480179, -0.13596811, -0.23614967, -0.34553927, -0.18651549,\n", 286 | " -0.0329766 , 0.08021832, 0.07155519, 0.26401518, -0.064796 ],\n", 287 | " [-0.06205904, -0.12404013, -0.02585845, -0.07235029, -0.08973349,\n", 288 | " -0.08904176, -0.06319158, 0.09375384, 0.21721181, -0.05041887],\n", 289 | " [-0.0676916 , -0.11858236, 0.12053182, 0.21836649, 0.05770418,\n", 290 | " 0.20416503, -0.09088759, -0.31010173, -0.16394684, 0.11900846]]))" 291 | ] 292 | }, 293 | "execution_count": 120, 294 | "metadata": {}, 295 | "output_type": "execute_result" 296 | } 297 | ], 298 | "source": [ 299 | "count, weights" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": 90, 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [ 308 | "if False:\n", 309 | " np.save(\"weights\", weights)" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 12, 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "if True:\n", 319 | " weights = np.load(\"weights.npy\")" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": null, 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [] 328 | } 329 | ], 330 | "metadata": { 331 | "kernelspec": { 332 | "display_name": "Python 3", 333 | "language": "python", 334 | "name": "python3" 335 | }, 336 | "language_info": { 337 | "codemirror_mode": { 338 | "name": "ipython", 339 | "version": 3 340 | }, 341 | "file_extension": ".py", 342 | "mimetype": "text/x-python", 343 | "name": "python", 344 | "nbconvert_exporter": "python", 345 | "pygments_lexer": "ipython3", 346 | "version": "3.6.3" 347 | } 348 | }, 349 | "nbformat": 4, 350 | "nbformat_minor": 2 351 | } 352 | -------------------------------------------------------------------------------- /open ai gym/weights.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hermesdt/reinforcement-learning/eb69484bb6d5a415633eb6e1fc07e12aa193cbb0/open ai gym/weights.npy 
-------------------------------------------------------------------------------- /open ai gym/weights/policy-cartpole-3000-iterations.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hermesdt/reinforcement-learning/eb69484bb6d5a415633eb6e1fc07e12aa193cbb0/open ai gym/weights/policy-cartpole-3000-iterations.h5 -------------------------------------------------------------------------------- /open ai gym/weights/value-cartpole-3000-iterations.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hermesdt/reinforcement-learning/eb69484bb6d5a415633eb6e1fc07e12aa193cbb0/open ai gym/weights/value-cartpole-3000-iterations.h5 -------------------------------------------------------------------------------- /ppo/cartpole_ppo_online.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import torch\n", 11 | "import gym\n", 12 | "from torch import nn\n", 13 | "from torch.nn import functional as F\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "from torch.utils import tensorboard" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 2, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "def mish(input):\n", 25 | " return input * torch.tanh(F.softplus(input))\n", 26 | "\n", 27 | "class Mish(nn.Module):\n", 28 | " def __init__(self): super().__init__()\n", 29 | " def forward(self, input): return mish(input)" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 3, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "# helper function to convert numpy arrays to tensors\n", 39 | "def t(x): return torch.from_numpy(x).float()" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 50, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "# Actor module, categorical actions only\n", 49 | "class Actor(nn.Module):\n", 50 | " def __init__(self, state_dim, n_actions, activation=nn.Tanh):\n", 51 | " super().__init__()\n", 52 | " self.model = nn.Sequential(\n", 53 | " nn.Linear(state_dim, 64),\n", 54 | " activation(),\n", 55 | " nn.Linear(64, 32),\n", 56 | " activation(),\n", 57 | " nn.Linear(32, n_actions),\n", 58 | " nn.Softmax(dim=-1) # explicit dim avoids PyTorch's deprecated implicit-dim behaviour\n", 59 | " )\n", 60 | " \n", 61 | " def forward(self, X):\n", 62 | " return self.model(X)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 51, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "# Critic module\n", 72 | "class Critic(nn.Module):\n", 73 | " def __init__(self, state_dim, activation=nn.Tanh):\n", 74 | " super().__init__()\n", 75 | " self.model = nn.Sequential(\n", 76 | " nn.Linear(state_dim, 64),\n", 77 | " activation(),\n", 78 | " nn.Linear(64, 32),\n", 79 | " activation(),\n", 80 | " nn.Linear(32, 1)\n", 81 | " )\n", 82 | " \n", 83 | " def forward(self, X):\n", 84 | " return self.model(X)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 52, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "env = gym.make(\"CartPole-v1\")" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 56, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "data": { 103 | "text/plain": [ 104 | "" 105 | ] 106 | }, 107 | "execution_count": 56, 108 | "metadata": {}, 109 | "output_type": "execute_result"
"execute_result" 110 | } 111 | ], 112 | "source": [ 113 | "# config\n", 114 | "state_dim = env.observation_space.shape[0]\n", 115 | "n_actions = env.action_space.n\n", 116 | "actor = Actor(state_dim, n_actions, activation=Mish)\n", 117 | "critic = Critic(state_dim, activation=Mish)\n", 118 | "adam_actor = torch.optim.Adam(actor.parameters(), lr=3e-4)\n", 119 | "adam_critic = torch.optim.Adam(critic.parameters(), lr=1e-3)\n", 120 | "\n", 121 | "torch.manual_seed(1)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 57, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "def clip_grad_norm_(module, max_grad_norm):\n", 131 | " nn.utils.clip_grad_norm_([p for g in module.param_groups for p in g[\"params\"]], max_grad_norm)\n", 132 | "\n", 133 | "def policy_loss(old_log_prob, log_prob, advantage, eps):\n", 134 | " ratio = (log_prob - old_log_prob).exp()\n", 135 | " clipped = torch.clamp(ratio, 1-eps, 1+eps)*advantage\n", 136 | " \n", 137 | " m = torch.min(ratio*advantage, clipped)\n", 138 | " return -m" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 58, 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "name": "stdout", 148 | "output_type": "stream", 149 | "text": [ 150 | "14.0\n", 151 | "10.0\n", 152 | "39.0\n", 153 | "44.0\n", 154 | "39.0\n", 155 | "12.0\n", 156 | "11.0\n", 157 | "52.0\n", 158 | "14.0\n", 159 | "10.0\n", 160 | "33.0\n", 161 | "20.0\n", 162 | "15.0\n", 163 | "25.0\n", 164 | "27.0\n", 165 | "95.0\n", 166 | "29.0\n", 167 | "31.0\n", 168 | "160.0\n", 169 | "192.0\n", 170 | "283.0\n", 171 | "10.0\n", 172 | "205.0\n", 173 | "207.0\n", 174 | "10.0\n", 175 | "261.0\n", 176 | "212.0\n", 177 | "128.0\n", 178 | "155.0\n", 179 | "163.0\n", 180 | "222.0\n", 181 | "237.0\n", 182 | "500.0\n", 183 | "500.0\n", 184 | "500.0\n", 185 | "500.0\n", 186 | "500.0\n", 187 | "500.0\n", 188 | "500.0\n", 189 | "500.0\n" 190 | ] 191 | } 192 | ], 193 | "source": [ 194 | "episode_rewards = []\n", 195 | "gamma = 0.98\n", 196 | "eps = 0.2\n", 197 | "w = tensorboard.SummaryWriter()\n", 198 | "s = 0\n", 199 | "max_grad_norm = 0.5\n", 200 | "\n", 201 | "for i in range(800):\n", 202 | " prev_prob_act = None\n", 203 | " done = False\n", 204 | " total_reward = 0\n", 205 | " state = env.reset()\n", 206 | "\n", 207 | "\n", 208 | " while not done:\n", 209 | " s += 1\n", 210 | " probs = actor(t(state))\n", 211 | " dist = torch.distributions.Categorical(probs=probs)\n", 212 | " action = dist.sample()\n", 213 | " prob_act = dist.log_prob(action)\n", 214 | " \n", 215 | " next_state, reward, done, info = env.step(action.detach().data.numpy())\n", 216 | " advantage = reward + (1-done)*gamma*critic(t(next_state)) - critic(t(state))\n", 217 | " \n", 218 | " w.add_scalar(\"loss/advantage\", advantage, global_step=s)\n", 219 | " w.add_scalar(\"actions/action_0_prob\", dist.probs[0], global_step=s)\n", 220 | " w.add_scalar(\"actions/action_1_prob\", dist.probs[1], global_step=s)\n", 221 | " \n", 222 | " total_reward += reward\n", 223 | " state = next_state\n", 224 | " \n", 225 | " if prev_prob_act:\n", 226 | " actor_loss = policy_loss(prev_prob_act.detach(), prob_act, advantage.detach(), eps)\n", 227 | " w.add_scalar(\"loss/actor_loss\", actor_loss, global_step=s)\n", 228 | " adam_actor.zero_grad()\n", 229 | " actor_loss.backward()\n", 230 | " # clip_grad_norm_(adam_actor, max_grad_norm)\n", 231 | " w.add_histogram(\"gradients/actor\",\n", 232 | " torch.cat([p.grad.view(-1) for p in actor.parameters()]), global_step=s)\n", 233 | " 
adam_actor.step()\n", 234 | "\n", 235 | " critic_loss = advantage.pow(2).mean()\n", 236 | " w.add_scalar(\"loss/critic_loss\", critic_loss, global_step=s)\n", 237 | " adam_critic.zero_grad()\n", 238 | " critic_loss.backward()\n", 239 | " # clip_grad_norm_(adam_critic, max_grad_norm)\n", 240 | " w.add_histogram(\"gradients/critic\",\n", 241 | " torch.cat([p.grad.view(-1) for p in critic.parameters()]), global_step=s) # p.grad (not p.data) so the histogram logs gradients, matching its tag\n", 242 | " adam_critic.step()\n", 243 | " \n", 244 | " prev_prob_act = prob_act\n", 245 | " \n", 246 | " w.add_scalar(\"reward/episode_reward\", total_reward, global_step=i)\n", 247 | " episode_rewards.append(total_reward)" 248 | ] 249 | } 250 | ], 251 | "metadata": { 252 | "kernelspec": { 253 | "display_name": "Python 3", 254 | "language": "python", 255 | "name": "python3" 256 | }, 257 | "language_info": { 258 | "codemirror_mode": { 259 | "name": "ipython", 260 | "version": 3 261 | }, 262 | "file_extension": ".py", 263 | "mimetype": "text/x-python", 264 | "name": "python", 265 | "nbconvert_exporter": "python", 266 | "pygments_lexer": "ipython3", 267 | "version": "3.6.3" 268 | } 269 | }, 270 | "nbformat": 4, 271 | "nbformat_minor": 2 272 | } 273 | -------------------------------------------------------------------------------- /resources/cartpole-2000-iterations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hermesdt/reinforcement-learning/eb69484bb6d5a415633eb6e1fc07e12aa193cbb0/resources/cartpole-2000-iterations.png --------------------------------------------------------------------------------