└── Policy Gradient.ipynb /Policy Gradient.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "Episode training complete.\n", 13 | "Episode: 1 Total Reward: 1\n", 14 | "Episode training complete.\n", 15 | "Episode: 2 Total Reward: 1\n", 16 | "Episode training complete.\n", 17 | "Episode: 3 Total Reward: 1\n", 18 | "Episode training complete.\n", 19 | "Episode: 4 Total Reward: 1\n", 20 | "Episode training complete.\n", 21 | "Episode: 5 Total Reward: 1\n", 22 | "Episode training complete.\n", 23 | "Episode: 6 Total Reward: 1\n", 24 | "Episode training complete.\n", 25 | "Episode: 7 Total Reward: 1\n", 26 | "Episode training complete.\n", 27 | "Episode: 8 Total Reward: 1\n", 28 | "Episode training complete.\n", 29 | "Episode: 9 Total Reward: 1\n", 30 | "Episode training complete.\n", 31 | "Episode: 10 Total Reward: 1\n", 32 | "Episode training complete.\n", 33 | "Episode: 11 Total Reward: 1\n", 34 | "Episode training complete.\n", 35 | "Episode: 12 Total Reward: 1\n", 36 | "Episode training complete.\n", 37 | "Episode: 13 Total Reward: 1\n", 38 | "Episode training complete.\n", 39 | "Episode: 14 Total Reward: 1\n", 40 | "Episode training complete.\n", 41 | "Episode: 15 Total Reward: 1\n", 42 | "Episode training complete.\n", 43 | "Episode: 16 Total Reward: 1\n", 44 | "Episode training complete.\n", 45 | "Episode: 17 Total Reward: 1\n", 46 | "Episode training complete.\n", 47 | "Episode: 18 Total Reward: 1\n", 48 | "Episode training complete.\n", 49 | "Episode: 19 Total Reward: 1\n", 50 | "Episode training complete.\n", 51 | "Episode: 20 Total Reward: 1\n", 52 | "Episode training complete.\n", 53 | "Episode: 21 Total Reward: 1\n", 54 | "Episode training complete.\n", 55 | "Episode: 22 Total Reward: 1\n", 56 | "Episode training complete.\n", 57 | "Episode: 23 Total Reward: 1\n", 58 | "Episode training complete.\n", 59 | "Episode: 24 Total Reward: 1\n", 60 | "Episode training complete.\n", 61 | "Episode: 25 Total Reward: 1\n", 62 | "Episode training complete.\n", 63 | "Episode: 26 Total Reward: 1\n", 64 | "Episode training complete.\n", 65 | "Episode: 27 Total Reward: 1\n", 66 | "Episode training complete.\n", 67 | "Episode: 28 Total Reward: 1\n", 68 | "Episode training complete.\n", 69 | "Episode: 29 Total Reward: 1\n", 70 | "Episode training complete.\n", 71 | "Episode: 30 Total Reward: 1\n", 72 | "Episode training complete.\n", 73 | "Episode: 31 Total Reward: 1\n", 74 | "Episode training complete.\n", 75 | "Episode: 32 Total Reward: 1\n", 76 | "Episode training complete.\n", 77 | "Episode: 33 Total Reward: 1\n", 78 | "Episode training complete.\n", 79 | "Episode: 34 Total Reward: 1\n", 80 | "Episode training complete.\n", 81 | "Episode: 35 Total Reward: 1\n", 82 | "Episode training complete.\n", 83 | "Episode: 36 Total Reward: 1\n", 84 | "Episode training complete.\n", 85 | "Episode: 37 Total Reward: 1\n", 86 | "Episode training complete.\n", 87 | "Episode: 38 Total Reward: 1\n", 88 | "Episode training complete.\n", 89 | "Episode: 39 Total Reward: 1\n", 90 | "Episode training complete.\n", 91 | "Episode: 40 Total Reward: 1\n", 92 | "Episode training complete.\n", 93 | "Episode: 41 Total Reward: 1\n", 94 | "Episode training complete.\n", 95 | "Episode: 42 Total Reward: 1\n", 96 | "Episode training complete.\n", 97 | "Episode: 43 Total Reward: 1\n", 98 | "Episode 
training complete.\n", 99 | "Episode: 44 Total Reward: 1\n", 100 | "Episode training complete.\n", 101 | "Episode: 45 Total Reward: 1\n", 102 | "Episode training complete.\n", 103 | "Episode: 46 Total Reward: 1\n", 104 | "Episode training complete.\n", 105 | "Episode: 47 Total Reward: 1\n", 106 | "Episode training complete.\n", 107 | "Episode: 48 Total Reward: 1\n", 108 | "Episode training complete.\n", 109 | "Episode: 49 Total Reward: 1\n", 110 | "Episode training complete.\n", 111 | "Episode: 50 Total Reward: 1\n", 112 | "Episode training complete.\n", 113 | "Episode: 51 Total Reward: 1\n", 114 | "Episode training complete.\n", 115 | "Episode: 52 Total Reward: 1\n", 116 | "Episode training complete.\n", 117 | "Episode: 53 Total Reward: 1\n", 118 | "Episode training complete.\n", 119 | "Episode: 54 Total Reward: 1\n", 120 | "Episode training complete.\n", 121 | "Episode: 55 Total Reward: 1\n", 122 | "Episode training complete.\n", 123 | "Episode: 56 Total Reward: 1\n", 124 | "Episode training complete.\n", 125 | "Episode: 57 Total Reward: 1\n", 126 | "Episode training complete.\n", 127 | "Episode: 58 Total Reward: 1\n", 128 | "Episode training complete.\n", 129 | "Episode: 59 Total Reward: 1\n", 130 | "Episode training complete.\n", 131 | "Episode: 60 Total Reward: 1\n", 132 | "Episode training complete.\n", 133 | "Episode: 61 Total Reward: 1\n", 134 | "Episode training complete.\n", 135 | "Episode: 62 Total Reward: 1\n", 136 | "Episode training complete.\n", 137 | "Episode: 63 Total Reward: 1\n", 138 | "Episode training complete.\n", 139 | "Episode: 64 Total Reward: 1\n", 140 | "Episode training complete.\n", 141 | "Episode: 65 Total Reward: 1\n", 142 | "Episode training complete.\n", 143 | "Episode: 66 Total Reward: 1\n", 144 | "Episode training complete.\n", 145 | "Episode: 67 Total Reward: 1\n", 146 | "Episode training complete.\n", 147 | "Episode: 68 Total Reward: 1\n", 148 | "Episode training complete.\n", 149 | "Episode: 69 Total Reward: 1\n", 150 | "Episode training complete.\n", 151 | "Episode: 70 Total Reward: 1\n", 152 | "Episode training complete.\n", 153 | "Episode: 71 Total Reward: 1\n", 154 | "Episode training complete.\n", 155 | "Episode: 72 Total Reward: 1\n", 156 | "Episode training complete.\n", 157 | "Episode: 73 Total Reward: 1\n", 158 | "Episode training complete.\n", 159 | "Episode: 74 Total Reward: 1\n", 160 | "Episode training complete.\n", 161 | "Episode: 75 Total Reward: 1\n", 162 | "Episode training complete.\n", 163 | "Episode: 76 Total Reward: 1\n", 164 | "Episode training complete.\n", 165 | "Episode: 77 Total Reward: 1\n", 166 | "Episode training complete.\n", 167 | "Episode: 78 Total Reward: 1\n", 168 | "Episode training complete.\n", 169 | "Episode: 79 Total Reward: 1\n", 170 | "Episode training complete.\n", 171 | "Episode: 80 Total Reward: 1\n", 172 | "Episode training complete.\n", 173 | "Episode: 81 Total Reward: 1\n", 174 | "Episode training complete.\n", 175 | "Episode: 82 Total Reward: 1\n", 176 | "Episode training complete.\n", 177 | "Episode: 83 Total Reward: 1\n", 178 | "Episode training complete.\n", 179 | "Episode: 84 Total Reward: 1\n", 180 | "Episode training complete.\n", 181 | "Episode: 85 Total Reward: 1\n", 182 | "Episode training complete.\n", 183 | "Episode: 86 Total Reward: 1\n", 184 | "Episode training complete.\n", 185 | "Episode: 87 Total Reward: 1\n", 186 | "Episode training complete.\n", 187 | "Episode: 88 Total Reward: 1\n", 188 | "Episode training complete.\n", 189 | "Episode: 89 Total Reward: 1\n", 190 | "Episode training 
complete.\n", 191 | "Episode: 90 Total Reward: 1\n", 192 | "Episode training complete.\n", 193 | "Episode: 91 Total Reward: 1\n", 194 | "Episode training complete.\n", 195 | "Episode: 92 Total Reward: 1\n", 196 | "Episode training complete.\n", 197 | "Episode: 93 Total Reward: 1\n", 198 | "Episode training complete.\n", 199 | "Episode: 94 Total Reward: 1\n", 200 | "Episode training complete.\n", 201 | "Episode: 95 Total Reward: 1\n", 202 | "Episode training complete.\n", 203 | "Episode: 96 Total Reward: 1\n", 204 | "Episode training complete.\n", 205 | "Episode: 97 Total Reward: 1\n", 206 | "Episode training complete.\n", 207 | "Episode: 98 Total Reward: 1\n", 208 | "Episode training complete.\n", 209 | "Episode: 99 Total Reward: 1\n", 210 | "Episode training complete.\n", 211 | "Episode: 100 Total Reward: 1\n" 212 | ] 213 | } 214 | ],
215 | "source": [
216 | "import numpy as np\n",
217 | "\n",
218 | "class REINFORCEAgent:\n",
219 | "    def __init__(self, num_actions, num_states, gamma=0.99, learning_rate=0.01):\n",
220 | "        # gamma is the discount factor applied to future rewards\n",
221 | "        self.num_actions = num_actions\n",
222 | "        self.num_states = num_states\n",
223 | "        self.gamma = gamma\n",
224 | "        self.learning_rate = learning_rate\n",
225 | "        self.policy = np.zeros((num_states, num_actions))  # per-state action preferences (softmax logits)\n",
226 | "\n",
227 | "    def get_action(self, state):\n",
228 | "        # Sample an action from the current softmax policy\n",
229 | "        action_probs = self._softmax(self.policy[state])  # action probabilities for the current state\n",
230 | "        return np.random.choice(self.num_actions, p=action_probs)\n",
231 | "\n",
232 | "    def train(self, episode):\n",
233 | "        states, actions, rewards = zip(*episode)\n",
234 | "        returns = self._calculate_returns(rewards)\n",
235 | "        for t, (state, action) in enumerate(zip(states, actions)):\n",
236 | "            grad_log_pi = np.eye(self.num_actions)[action] - self._softmax(self.policy[state])  # grad of log pi(a|s) for a softmax policy\n",
237 | "            self.policy[state] += self.learning_rate * returns[t] * grad_log_pi  # REINFORCE update: ascend G_t * grad log pi(a|s)\n",
238 | "\n",
239 | "        print(\"Episode training complete.\")\n",
240 | "\n",
241 | "    def _calculate_returns(self, rewards):\n",
242 | "        G = 0\n",
243 | "        returns = []  # discounted return G_t for every time step\n",
244 | "        for r in reversed(rewards):\n",
245 | "            G = r + self.gamma * G\n",
246 | "            returns.insert(0, G)\n",
247 | "        return returns\n",
248 | "\n",
249 | "    def _softmax(self, x):\n",
250 | "        exp_values = np.exp(x - np.max(x))  # subtract the max for numerical stability\n",
251 | "        return exp_values / np.sum(exp_values)\n",
252 | "\n",
253 | "# Simple corridor environment: the agent starts at state 0 and earns a reward of 1 for reaching the last state\n",
254 | "class SimpleEnvironment:\n",
255 | "    def __init__(self, num_states, num_actions):\n",
256 | "        self.num_states = num_states\n",
257 | "        self.num_actions = num_actions\n",
258 | "\n",
259 | "    def reset(self):\n",
260 | "        return 0\n",
261 | "\n",
262 | "    def step(self, state, action):\n",
263 | "        new_state = max(0, min(self.num_states - 1, state + (action * 2 - 1)))  # action a moves the agent by 2*a - 1, clamped to the grid\n",
264 | "        reward = 1 if new_state == self.num_states - 1 else 0  # reward only at the rightmost (terminal) state\n",
265 | "        return new_state, reward\n",
266 | "\n",
267 | "# Training loop\n",
268 | "num_states = 10\n",
269 | "num_actions = 4\n",
270 | "num_episodes = 100\n",
271 | "env = SimpleEnvironment(num_states, num_actions)\n",
272 | "agent = REINFORCEAgent(num_actions, num_states)\n",
273 | "\n",
274 | "for episode_num in range(num_episodes):\n",
275 | "    state = env.reset()\n",
276 | "    episode = []\n",
277 | "    total_reward = 0  # Track total reward for the episode\n",
278 | "    done = False\n",
279 | "    while not done:\n",
280 | "        action = agent.get_action(state)\n",
281 | "        next_state, reward = env.step(state, action)\n",
282 | "        total_reward += reward  # Accumulate reward for the episode\n",
283 | "        episode.append((state, action, reward))\n",
284 | "        state = next_state\n",
285 | "        done = next_state == num_states - 1\n",
286 | "    agent.train(episode)\n",
287 | "    print(\"Episode:\", episode_num + 1, \"Total Reward:\", total_reward)  # Print total reward for the episode\n"
288 | ]
289 | }
290 | ],
291 | "metadata": {
292 | "kernelspec": {
293 | "display_name": "Python 3",
294 | "language": "python",
295 | "name": "python3"
296 | },
297 | "language_info": {
298 | "codemirror_mode": {
299 | "name": "ipython",
300 | "version": 3
301 | },
302 | "file_extension": ".py",
303 | "mimetype": "text/x-python",
304 | "name": "python",
305 | "nbconvert_exporter": "python",
306 | "pygments_lexer": "ipython3",
307 | "version": "3.12.0"
308 | }
309 | },
310 | "nbformat": 4,
311 | "nbformat_minor": 2
312 | }
313 | --------------------------------------------------------------------------------
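
A possible follow-up cell, offered as a minimal sketch rather than part of the notebook above: it assumes the training cell has been run (so `agent`, `env`, and `num_states` are in scope) and uses an arbitrary 100-step cap. It rolls the learned policy out greedily to check whether the preferences actually steer the agent to the goal state.

# Hypothetical follow-up cell: greedy evaluation of the learned policy.
import numpy as np

state = env.reset()
steps = 0
while state != num_states - 1 and steps < 100:    # stop at the goal, or after an arbitrary 100-step cap
    action = int(np.argmax(agent.policy[state]))  # take the highest-preference action instead of sampling
    state, _ = env.step(state, action)
    steps += 1

print("Greedy rollout reached state", state, "in", steps, "steps")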