├── README.md ├── agent.py ├── deep_qlearning.ipynb ├── environment.py └── maze_generator ├── maze.npy └── maze_generator.py /README.md: -------------------------------------------------------------------------------- 1 | # Deep Q-learning for maze solving 2 | 3 | A simple implementation of DQN that uses PyTorch and a fully connected neural network to estimate the q-values of each state-action pair. 4 | 5 | The environment is a maze that is randomly generated using a deep-first search algorithm to estimate the Q-values. Four moves are possible for the agent (up, down, left and right), whose objective is to reach a predetermined cell. The agent implements either an epsilon-greedy policy or a softmax behaviour policy with temperature equal to epsilon. After each episode, the starting position is sampled in such a way that at the beginning of the training the agent explores the area surrounding the goal, and as the training goes on it will explore further and further areas of the maze. 6 | 7 | A convolutional neural network is also implemented for completeness. 8 | -------------------------------------------------------------------------------- /agent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.special as sp 3 | import matplotlib.pyplot as plt 4 | import copy 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import collections 9 | 10 | Transition = collections.namedtuple('Experience', 11 | field_names=['state', 'action', 12 | 'next_state', 'reward', 13 | 'is_game_on']) 14 | 15 | 16 | class Agent: 17 | def __init__(self, maze, memory_buffer, use_softmax = True): 18 | self.env = maze 19 | self.buffer = memory_buffer # this is actually a reference 20 | self.num_act = 4 21 | self.use_softmax = use_softmax 22 | self.total_reward = 0 23 | self.min_reward = -self.env.maze.size 24 | self.isgameon = True 25 | 26 | 27 | def make_a_move(self, net, epsilon, device = 'cuda'): 28 | action = self.select_action(net, epsilon, device) 29 | current_state = self.env.state() 30 | next_state, reward, self.isgameon = self.env.state_update(action) 31 | self.total_reward += reward 32 | 33 | if self.total_reward < self.min_reward: 34 | self.isgameon = False 35 | if not self.isgameon: 36 | self.total_reward = 0 37 | 38 | transition = Transition(current_state, action, 39 | next_state, reward, 40 | self.isgameon) 41 | 42 | self.buffer.push(transition) 43 | 44 | 45 | def select_action(self, net, epsilon, device = 'cuda'): 46 | state = torch.Tensor(self.env.state()).to(device).view(1,-1) 47 | qvalues = net(state).cpu().detach().numpy().squeeze() 48 | 49 | # softmax sampling of the qvalues 50 | if self.use_softmax: 51 | p = sp.softmax(qvalues/epsilon).squeeze() 52 | p /= np.sum(p) 53 | action = np.random.choice(self.num_act, p = p) 54 | 55 | # else choose the best action with probability 1-epsilon 56 | # and with probability epsilon choose at random 57 | else: 58 | if np.random.random() < epsilon: 59 | action = np.random.randint(self.num_act, size=1)[0] 60 | else: 61 | action = np.argmax(qvalues, axis=0) 62 | action = int(action) 63 | 64 | return action 65 | 66 | 67 | def plot_policy_map(self, net, filename, offset): 68 | net.eval() 69 | with torch.no_grad(): 70 | fig, ax = plt.subplots() 71 | ax.imshow(self.env.maze, 'Greys') 72 | 73 | for free_cell in self.env.allowed_states: 74 | self.env.current_position = np.asarray(free_cell) 75 | qvalues = net(torch.Tensor(self.env.state()).view(1,-1).to('cuda')) 76 | action = int(torch.argmax(qvalues).detach().cpu().numpy()) 77 | policy = self.env.directions[action] 78 | 79 | ax.text(free_cell[1]-offset[0], free_cell[0]-offset[1], policy) 80 | ax = plt.gca(); 81 | 82 | plt.xticks([], []) 83 | plt.yticks([], []) 84 | 85 | ax.plot(self.env.goal[1], self.env.goal[0], 86 | 'bs', markersize = 4) 87 | plt.savefig(filename, dpi = 300, bbox_inches = 'tight') 88 | plt.show() 89 | -------------------------------------------------------------------------------- /deep_qlearning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import scipy.special as sp\n", 11 | "\n", 12 | "from IPython.display import display, clear_output\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "import copy\n", 15 | "import time\n", 16 | "import random\n", 17 | "\n", 18 | "import torch\n", 19 | "import torch.nn as nn\n", 20 | "import torch.optim as optim\n", 21 | "import collections" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "**Introduce experience replay.**" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "Transition = collections.namedtuple('Experience',\n", 38 | " field_names=['state', 'action',\n", 39 | " 'next_state', 'reward',\n", 40 | " 'is_game_on'])\n", 41 | "\n", 42 | "class ExperienceReplay:\n", 43 | " def __init__(self, capacity):\n", 44 | " self.capacity = capacity\n", 45 | " self.memory = collections.deque(maxlen=capacity)\n", 46 | "\n", 47 | " def __len__(self):\n", 48 | " return len(self.memory)\n", 49 | "\n", 50 | " def push(self, transition):\n", 51 | " self.memory.append(transition)\n", 52 | "\n", 53 | " def sample(self, batch_size, device = 'cuda'):\n", 54 | " indices = np.random.choice(len(self.memory), batch_size, replace = False)\n", 55 | " \n", 56 | " states, actions, next_states, rewards, isgameon = zip(*[self.memory[idx] \n", 57 | " for idx in indices])\n", 58 | " \n", 59 | " return torch.Tensor(states).type(torch.float).to(device), \\\n", 60 | " torch.Tensor(actions).type(torch.long).to(device), \\\n", 61 | " torch.Tensor(next_states).to(device), \\\n", 62 | " torch.Tensor(rewards).to(device), torch.tensor(isgameon).to(device)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "**Networks definition.**" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 3, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "class fc_nn(nn.Module):\n", 79 | " def __init__(self, Ni, Nh1, Nh2, No = 4):\n", 80 | " super().__init__()\n", 81 | " \n", 82 | " self.fc1 = nn.Linear(Ni, Nh1)\n", 83 | " self.fc2 = nn.Linear(Nh1, Nh2)\n", 84 | " self.fc3 = nn.Linear(Nh2, No)\n", 85 | " \n", 86 | " self.act = nn.ReLU()\n", 87 | " \n", 88 | " def forward(self, x, classification = False, additional_out=False):\n", 89 | " x = self.act(self.fc1(x))\n", 90 | " x = self.act(self.fc2(x))\n", 91 | " out = self.fc3(x)\n", 92 | " \n", 93 | " return out" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 4, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "class conv_nn(nn.Module):\n", 103 | " \n", 104 | " channels = [16, 32, 64]\n", 105 | " kernels = [3, 3, 3]\n", 106 | " strides = [1, 1, 1]\n", 107 | " in_channels = 1\n", 108 | " \n", 109 | " def __init__(self, rows, cols, n_act):\n", 110 | " super().__init__()\n", 111 | " self.rows = rows\n", 112 | " self.cols = cols\n", 113 | "\n", 114 | " self.conv = nn.Sequential(nn.Conv2d(in_channels = self.in_channels,\n", 115 | " out_channels = self.channels[0],\n", 116 | " kernel_size = self.kernels[0],\n", 117 | " stride = self.strides[0]),\n", 118 | " nn.ReLU(),\n", 119 | " nn.Conv2d(in_channels = self.channels[0],\n", 120 | " out_channels = self.channels[1],\n", 121 | " kernel_size = self.kernels[1],\n", 122 | " stride = self.strides[1]),\n", 123 | " nn.ReLU()\n", 124 | " )\n", 125 | " \n", 126 | " size_out_conv = self.get_conv_size(rows, cols)\n", 127 | " \n", 128 | " self.linear = nn.Sequential(nn.Linear(size_out_conv, rows*cols*2),\n", 129 | " nn.ReLU(),\n", 130 | " nn.Linear(rows*cols*2, int(rows*cols/2)),\n", 131 | " nn.ReLU(),\n", 132 | " nn.Linear(int(rows*cols/2), n_act),\n", 133 | " )\n", 134 | "\n", 135 | " def forward(self, x):\n", 136 | " x = x.view(len(x), self.in_channels, self.rows, self.cols)\n", 137 | " out_conv = self.conv(x).view(len(x),-1)\n", 138 | " out_lin = self.linear(out_conv)\n", 139 | " return out_lin\n", 140 | " \n", 141 | " def get_conv_size(self, x, y):\n", 142 | " out_conv = self.conv(torch.zeros(1,self.in_channels, x, y))\n", 143 | " return int(np.prod(out_conv.size()))" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 5, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "def Qloss(batch, net, gamma=0.99, device=\"cuda\"):\n", 153 | " states, actions, next_states, rewards, _ = batch\n", 154 | " lbatch = len(states)\n", 155 | " state_action_values = net(states.view(lbatch,-1))\n", 156 | " state_action_values = state_action_values.gather(1, actions.unsqueeze(-1))\n", 157 | " state_action_values = state_action_values.squeeze(-1)\n", 158 | " \n", 159 | " next_state_values = net(next_states.view(lbatch, -1))\n", 160 | " next_state_values = next_state_values.max(1)[0]\n", 161 | " \n", 162 | " next_state_values = next_state_values.detach()\n", 163 | " expected_state_action_values = next_state_values * gamma + rewards\n", 164 | " \n", 165 | " return nn.MSELoss()(state_action_values, expected_state_action_values)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "**Import the maze and define the environment.**" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 6, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "from environment import MazeEnvironment\n", 182 | "\n", 183 | "maze = np.load('maze_generator/maze.npy')\n", 184 | "\n", 185 | "initial_position = [0,0]\n", 186 | "goal = [len(maze)-1, len(maze)-1]\n", 187 | "\n", 188 | "maze_env = MazeEnvironment(maze, initial_position, goal)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 7, 194 | "metadata": {}, 195 | "outputs": [ 196 | { 197 | "data": { 198 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOsAAADrCAYAAACICmHVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAFBUlEQVR4nO3dQU4jRxiA0XaUIxDhXeYO+P4nMHcgO1DmDs5ylJHNuBmXuz54T2KBMO1yw6dC4ld5dzqdFmB+f2y9AOA6YoUIsUKEWCFCrBAhVoj4c82DHx4eTt++fRu0lK/t+fn56sc+PT0Nue4aW69h1PNvfd2Xl5fl+/fvu3Nf2635P+vhcDgdj8erH8/1druzP5+z1vzM1lx3ja3XMOr5t77u4XBYjsfj2Qv7MxgiPh7rfr8su92Pj/3+hssCfvbxWN/e3v8cuCl/BkOEWCHi47E+Pr7/OXBTq/7P+j+vrzdcBvAr/gyGiI/vrL9Q+mf81v/gX7uGUUatYevX9llel50VIsQKEWKFCLFChFghQqwQIVaIECtEiBUixAoRw8YNZxjhG2HUiNmoc4JmWMOI686w1jVu8TOzs0KEWCFCrBAhVogQK0SIFSLEChFihQixQoRYISJ3uuHWaq+rNpo44ppbr/VW7KwQIVaIECtEiBUixAoRYoUIsUKEWCFCrBAhVohYNW74/Pw85A2KR9n6pLxRZj6B75ytTzcc5d6/N3ZWiBArRIgVIsQKEWKFCLFChFghQqwQIVaIGHZg2hozTKOMUJuMGmXE4Waf+X5dYmeFCLFChFghQqwQIVaIECtEiBUixAoRYoUIsULEqnHDp6en5Xg83nwRW4+O1cYdt75fNbUD5i6xs0KEWCFCrBAhVogQK0SIFSLEChFihQixQoRYIWKK92cddQrg1iflle7BWjPcsy2vuSzenxW4QKwQIVaIECtEiBUixAoRYoUIsUKEWCFCrBAxxemGo1w74jXDqF/tujOsYWtONwTOEitEiBUixAoRYoUIsUKEWCFCrBAhVogQK0RMcbrhKFuPrs1wD9aYYb0zrOFaTjcEzhIrRIgVIsQKEWKFCLFChFghQqwQIVaIECtEON1wpdqbE289crksrdMNZx53tLNChFghQqwQIVaIECtEiBUixAoRYoUIsUKEWCHC6YbLurXOMDpXurejzHAP7s3OChFihQixQoRYIUKsECFWiBArRIgVIsQKEWKFiFXjhqPMMMI34vlLr2tZPu8Ji59l7NPOChFihQixQoRYIUKsECFWiBArRIgVIsQKEWKFiCneTHnUCN8Io56/Nuq39c9hlFEjore4X3ZWiBArRIgVIsQKEWKFCLFChFghQqwQIVaIECtEDHsz5RmMWOsMJxbOMOpX+j2ojYheYmeFCLFChFghQqwQIVaIECtEiBUixAoRYoWIKQ5MG2XEhElpcmdZ5j4A7F4+y3vq2lkhQqwQIVaIECtEiBUixAoRYoUIsUKEWCFCrBAx7MC0Gca2rl1DbcyuNvI44v7W7sG1r+twOFz8mp0VIsQKEWKFCLFChFghQqwQIVaIECtEiBUixAoRU5xuOGqEb+uRtBnGGGcY+xxhhrHPe7OzQoRYIUKsECFWiBArRIgVIsQKEWKFCLFChFghIne64VccM/vZDKOJo5RGHu/NzgoRYoUIsUKEWCFCrBAhVogQK0SIFSLEChFihYgpTjccZevxuRlGLrd+o+q1j+UyOytEiBUixAoRYoU72++XZbf78bHfX/d9YoU7e3t7//NLxAoRYoUIscKdPT6+//klq4YigN/3+vqx77OzQsRu5ejav8uy/DNuOfDl/X06nf4694VVsQLb8WcwRIgVIsQKEWKFCLFChFghQqwQIVaIECtE/AfYYY6NVPlPoAAAAABJRU5ErkJggg==\n", 199 | "text/plain": [ 200 | "
" 201 | ] 202 | }, 203 | "metadata": {}, 204 | "output_type": "display_data" 205 | } 206 | ], 207 | "source": [ 208 | "maze_env.draw('maze_20.pdf')" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": {}, 214 | "source": [ 215 | "**Define the agent and the buffer for experience replay.**" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 9, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "buffer_capacity = 10000\n", 225 | "buffer_start_size = 1000\n", 226 | "memory_buffer = ExperienceReplay(buffer_capacity)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 10, 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [ 235 | "from agent import Agent\n", 236 | "agent = Agent(maze = maze_env,\n", 237 | " memory_buffer = memory_buffer,\n", 238 | " use_softmax = True\n", 239 | " )" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "** Define the network.**" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 82, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "net = fc_nn(maze.size, maze.size, maze.size, 4)\n", 256 | "optimizer = optim.Adam(net.parameters(), lr=1e-4)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 83, 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "data": { 266 | "text/plain": [ 267 | "fc_nn(\n", 268 | " (fc1): Linear(in_features=400, out_features=400, bias=True)\n", 269 | " (fc2): Linear(in_features=400, out_features=400, bias=True)\n", 270 | " (fc3): Linear(in_features=400, out_features=4, bias=True)\n", 271 | " (act): ReLU()\n", 272 | ")" 273 | ] 274 | }, 275 | "execution_count": 83, 276 | "metadata": {}, 277 | "output_type": "execute_result" 278 | } 279 | ], 280 | "source": [ 281 | "device = 'cuda'\n", 282 | "batch_size = 24\n", 283 | "gamma = 0.9\n", 284 | "\n", 285 | "net.to(device)" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "**Define the epsilon profile and plot the resetting probability.**" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 84, 298 | "metadata": { 299 | "scrolled": false 300 | }, 301 | "outputs": [ 302 | { 303 | "data": { 304 | "image/png": "\n", 305 | "text/plain": [ 306 | "
" 307 | ] 308 | }, 309 | "metadata": {}, 310 | "output_type": "display_data" 311 | }, 312 | { 313 | "data": { 314 | "image/png": "\n", 315 | "text/plain": [ 316 | "
" 317 | ] 318 | }, 319 | "metadata": {}, 320 | "output_type": "display_data" 321 | } 322 | ], 323 | "source": [ 324 | "num_epochs = 20000\n", 325 | "\n", 326 | "cutoff = 3000\n", 327 | "epsilon = np.exp(-np.arange(num_epochs)/(cutoff))\n", 328 | "epsilon[epsilon > epsilon[100*int(num_epochs/cutoff)]] = epsilon[100*int(num_epochs/cutoff)]\n", 329 | "plt.plot(epsilon, color = 'orangered', ls = '--')\n", 330 | "plt.xlabel('Epochs')\n", 331 | "plt.ylabel('Epsilon')\n", 332 | "plt.savefig('epsilon_profile.pdf', dpi = 300, bbox_inches = 'tight')\n", 333 | "plt.show()\n", 334 | "\n", 335 | "mp = []\n", 336 | "mpm = []\n", 337 | "reg = 200\n", 338 | "for e in epsilon:\n", 339 | " a = agent.env.reset_policy(e)\n", 340 | " mp.append(np.min(a))\n", 341 | " mpm.append(np.max(a))\n", 342 | "\n", 343 | "plt.plot(epsilon/1.3, color = 'orangered', ls = '--', alpha = 0.5,\n", 344 | " label= 'Epsilon profile (arbitrary units)')\n", 345 | "\n", 346 | "plt.plot(np.array(mpm)-np.array(mp), label = 'Probability difference', color = 'cornflowerblue')\n", 347 | "plt.xlabel('Epochs')\n", 348 | "plt.ylabel(r'max $p^r$ - min $p^r$')\n", 349 | "plt.legend()\n", 350 | "plt.savefig('reset_policy.pdf', dpi = 300, bbox_inches = 'tight')\n", 351 | "plt.show()" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "**Training the network.**" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 85, 364 | "metadata": {}, 365 | "outputs": [ 366 | { 367 | "name": "stdout", 368 | "output_type": "stream", 369 | "text": [ 370 | "Epoch 19999 (number of moves 43)\n", 371 | "Game won\n", 372 | "[####################################################################################################]\n", 373 | "\t Average loss: 0.00128\n", 374 | "\t Best average loss of the last 50 epochs: 0.00133, achieved at epoch 17204\n" 375 | ] 376 | } 377 | ], 378 | "source": [ 379 | "loss_log = []\n", 380 | "best_loss = 1e5\n", 381 | "\n", 382 | "running_loss = 0\n", 383 | "\n", 384 | "for epoch in range(num_epochs):\n", 385 | " loss = 0\n", 386 | " counter = 0\n", 387 | " eps = epsilon[epoch]\n", 388 | " \n", 389 | " agent.isgameon = True\n", 390 | " _ = agent.env.reset(eps)\n", 391 | " \n", 392 | " while agent.isgameon:\n", 393 | " agent.make_a_move(net, eps)\n", 394 | " counter += 1\n", 395 | " \n", 396 | " if len(agent.buffer) < buffer_start_size:\n", 397 | " continue\n", 398 | " \n", 399 | " optimizer.zero_grad()\n", 400 | " batch = agent.buffer.sample(batch_size, device = device)\n", 401 | " loss_t = Qloss(batch, net, gamma = gamma, device = device)\n", 402 | " loss_t.backward()\n", 403 | " optimizer.step()\n", 404 | " \n", 405 | " loss += loss_t.item()\n", 406 | " \n", 407 | " if (agent.env.current_position == agent.env.goal).all():\n", 408 | " result = 'won'\n", 409 | " else:\n", 410 | " result = 'lost'\n", 411 | " \n", 412 | " if epoch%1000 == 0:\n", 413 | " agent.plot_policy_map(net, 'sol_epoch_'+str(epoch)+'.pdf', [0.35,-0.3])\n", 414 | " \n", 415 | " loss_log.append(loss)\n", 416 | " \n", 417 | " if (epoch > 2000):\n", 418 | " running_loss = np.mean(loss_log[-50:])\n", 419 | " if running_loss < best_loss:\n", 420 | " best_loss = running_loss\n", 421 | " torch.save(net.state_dict(), \"best.torch\")\n", 422 | " estop = epoch\n", 423 | " \n", 424 | " print('Epoch', epoch, '(number of moves ' + str(counter) + ')')\n", 425 | " print('Game', result)\n", 426 | " print('[' + '#'*(100-int(100*(1 - epoch/num_epochs))) +\n", 427 | " ' '*int(100*(1 - epoch/num_epochs)) + ']')\n", 428 | " print('\\t Average loss: ' + f'{loss:.5f}')\n", 429 | " if (epoch > 2000):\n", 430 | " print('\\t Best average loss of the last 50 epochs: ' + f'{best_loss:.5f}' + ', achieved at epoch', estop)\n", 431 | " clear_output(wait = True)" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": 86, 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [ 440 | "torch.save(net.state_dict(), \"net.torch\")" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": 87, 446 | "metadata": {}, 447 | "outputs": [ 448 | { 449 | "data": { 450 | "image/png": "\n", 451 | "text/plain": [ 452 | "
" 453 | ] 454 | }, 455 | "metadata": {}, 456 | "output_type": "display_data" 457 | } 458 | ], 459 | "source": [ 460 | "plt.plot(epsilon*90, alpha = 0.6, ls = '--', label = 'Epsilon profile (arbitrary unit)', color = 'orangered')\n", 461 | "plt.plot((np.array(mpm)-np.array(mp))*120, alpha = 0.6, ls = '--',\n", 462 | " label = 'Probability difference (arbitrary unit)', color = 'dimgray')\n", 463 | "plt.plot(loss_log, label = 'Loss', color = 'cornflowerblue')\n", 464 | "plt.xlabel('Epoch')\n", 465 | "plt.ylabel('MSE')\n", 466 | "plt.legend()\n", 467 | "plt.savefig('loss.pdf', dpi = 300, bbox_inches='tight')\n", 468 | "plt.show()" 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "metadata": {}, 474 | "source": [ 475 | "**Show the maze solution and the policy learnt.**" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": 88, 481 | "metadata": { 482 | "scrolled": true 483 | }, 484 | "outputs": [ 485 | { 486 | "data": { 487 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOsAAADrCAYAAACICmHVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAE6UlEQVR4nO3dQW4bNxiA0VHSI6RpdvUdrPufQL6Du2tdX6FVdwVaSIEGMUV+9ntAFoaVETXyBxrwD+pwPp83YH2fZi8AuI1YIUKsECFWiBArRIgVIn7a8+AvX76cHx4eBi3lY3t6err5sY+Pj0Ouu8fsNYx6/tnXfX5+3l5fXw+XvnfY83fW4/F4Pp1ONz+e2x0OF9+fi/a8Z3uuu8fsNYx6/tnXPR6P2+l0unhhvwZDhFghQqwQIVaIECtEiBUixAoRYoWIXRNMe5T+GD/7D/x71zDKqDXMfm3v5XXZWSFCrBAhVogQK0SIFSLEChFihQixQoRYIUKsEDFs3HCFEb4RRo2YjTonaIU1jLjuCmvd4y3eMzsrRIgVIsQKEWKFCLFChFghQqwQIVaIECtEiBUicqcbzlZ7XbXRxBHXnL3Wt2JnhQixQoRYIUKsECFWiBArRIgVIsQKEWKFCLFCxK5xw6enpyEfUDzK7JPyRln5BL5LZp9uOMq9f27srBAhVogQK0SIFSLEChFihQixQoRYIUKsEDHswLQ9VphGGaE2GTXKiMPN3vP9usbOChFihQixQoRYIUKsECFWiBArRIgVIsQKEWKFiF3jho+Pj9vpdHrzRcweHauNO86+XzW1A+ausbNChFghQqwQIVaIECtEiBUixAoRYoUIsUKEWCFiic9nHXUK4OyT8kr3YK8V7tnMa26bz2cFrhArRIgVIsQKEWKFCLFChFghQqwQIVaIECtELHG64Si3jnitMOpXu+4Ka5jN6YbARWKFCLFChFghQqwQIVaIECtEiBUixAoRYoWIJU43HGX26NoK92CPFda7whpu5XRD4CKxQoRYIUKsECFWiBArRIgVIsQKEWKFCLFChNMNd6p9OPHskctta51uuPK4o50VIsQKEWKFCLFChFghQqwQIVaIECtEiBUixAoRTjfc9q11hdG50r0dZYV7cG92VogQK0SIFSLEChFihQixQoRYIUKsECFWiBArROwaNxxlhRG+Ec9fel3b9n5PWHwvY592VogQK0SIFSLEChFihQixQoRYIUKsECFWiBArRCzxYcqjRvhGGPX8tVG/2e/DKKNGRN/iftlZIUKsECFWiBArRIgVIsQKEWKFCLFChFghQqwQMezDlFcwYq0rnFi4wqhf6eegNiJ6jZ0VIsQKEWKFCLFChFghQqwQIVaIECtEiBUiljgwbZQREyalyZ1tW/sAsHt5L5+pa2eFCLFChFghQqwQIVaIECtEiBUixAoRYoUIsULEsAPTVhjbunUNtTG72sjjiPtbuwe3vq7j8Xj1e3ZWiBArRIgVIsQKEWKFCLFChFghQqwQIVaIECtELHG64agRvtkjaSuMMa4w9jnCCmOf92ZnhQixQoRYIUKsECFWiBArRIgVIsQKEWKFCLFCRO50w484ZvZ/K4wmjlIaebw3OytEiBUixAoRYoUIsUKEWCFCrBAhVogQK0SIFSKWON1wlNnjcyuMXM7+oOq9j+U6OytEiBUixAoRYoU7e/n8bdsOh3//vXz+dtP/Eyvc2de///ju19eIFSLEChFihTt7+fTLd7++ZtdQBPDjvv71+3+/vvH/2Vkh4rBzdO3Pbdt+G7cc+PB+PZ/PP1/6xq5YgXn8GgwRYoUIsUKEWCFCrBAhVogQK0SIFSLEChH/AFuQjEBDA4n2AAAAAElFTkSuQmCC\n", 488 | "text/plain": [ 489 | "
" 490 | ] 491 | }, 492 | "metadata": {}, 493 | "output_type": "display_data" 494 | } 495 | ], 496 | "source": [ 497 | "net.eval()\n", 498 | "agent.isgameon = True\n", 499 | "agent.use_softmax = False\n", 500 | "_ = agent.env.reset(0)\n", 501 | "while agent.isgameon:\n", 502 | " agent.make_a_move(net, 0)\n", 503 | " agent.env.draw('')\n", 504 | " clear_output(wait = True)" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": 89, 510 | "metadata": {}, 511 | "outputs": [ 512 | { 513 | "data": { 514 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAO0AAADrCAYAAACFFBGSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAWbklEQVR4nO2dP2hdx/LHR7/8QAlOZfKiNMkvhd3YAhWyESgYQ0CkCypDEL4IkdaVQFV6gYs0dnBpoVQubJQgBblSIUso0Q12iIsQEmKDiR5JIWPLlgprX+Gf/K6k82fn7M7OzJ75wIH33t03Z3Z29175nu/9Tp9zDgzD0MP/cCdgGAYOO7SGoQw7tIahDDu0hqEMO7SGoQw7tIahjP/FDH7nnXfchx9+SJRKu+l2u95jh4eHSeJi4M6B6v7ccf/880/4559/+ioHOee8r+HhYSeN7e1td+/ePe40ggEA7+soVTXAxNWUA9X9ueP+/xmrPIfq/zz+7bff4KuvvuJOgxUJNeDOger+EuNGO7Q///wz/PHHH6xjy9jf34fvvvtOzVjDqCLaod3d3YXx8XGvA0Y1toj9/X2YnJyE1dVVFWMNow7UF1EHfPPNNzA7O3vsf//rr7/gs88+gx9++IF8rC/Xr1+H+fl5OHPmDCwuLh567fTp03D79m1RYw2jlrp/9DrPL6IePnzohoaG3OrqaumY2GPX1tbcxsaG63Q6bnNz0+3u7h4b8/TpU3fx4kV348aN2ntxjoWGX2rU1QATV1MOVPfnjuvzRVS0Q3vnzh139+7d0tcpxs7MzLiRkRE3ODjohoaG3NbWVuG4Z8+euStXrnjdj2ts0w1QV4MUh5YjB6r7c8dNemi5mJ6edqdOnXKPHz/mTiWIphvAueoapDi0HDlQ3Z87rs+h7XOI39OeO3fObW5ueo9PxcuXL+GNN97gTiOIvr7q5+m9FK1ZWQ0wcTFw50B1f+64586dg83NzcrAjb6I8iHlZik7sL45YN64Us4LQ4w3LaocQuNy3z913DrUiysMo20EHdonT57ATz/9FCuXxjncv3+fJC733HyhqoFBR8iaNT60T548gU8++QRGR0fh+++/bxomGAqZmZS5HVCnCqOWEEpQu+U2lkXG+MUXX8Do6Ch8/PHH8OWXX8KjR4+ahhKHtLmFqsJS3t/G4sdiafxF1NzcHDx48ACuXr0Kt27dgjfffDNmXqxwzo1CFUZ1fxuLHxuFumdCvdfR57Q//vij63Q6x55BYZ9hYa6jxMihiLK4qebVS52CLNY6NL2/jcWPLVuz7H+at76+Dvv7+wDw6gfJe3t7zBnR8Ouvv8LXX38NH3300bHXUtSg6v42Fj82eM3qTrUT/EkbS2aGeSdMMS8MVFI7g45QGaPqT9rZ2Vm4cOEC7O7uwtLSEgwMDHCnlByrgT5C14xMxqhJPoepgVRFFJXUzqBDnIwx5UFIqTum2tihhytGDagOOEVcCbliiLlmqv88Now2ov7QmoxRRq5UOVDElVKvxvu27puq3qvq97T37993v//+e+FrL1++dN9++23p/xcCvuGkeE67vb3tRkZGXH9/v1taWirNu25emLFNa1CXKyauhBx840rIlWLfJv32uEy2pdHUzEfGKMXYTYLkkioHirgS6hVM3al2BZ+08/Pz7uzZs8eukydPuvPnzx9657h27Zrr6+srHD8+Ph70rln1joWJe5Tnz5+/jvvixYvC2D7zSlGDulwxcSXk4BtXQq4U+9bnk7bRt8cTExMwMTFx6H979OgRfPrpp8d+uXDp0iW4efMmTE5OQqfTaXK75Lz11luv/3OZ7hgzL8oa+ORKDVUOFHEl1CuUaH8el8m23n77bVhcXIS///471q1ewyljxMyLsgaGPkL3bbTntGNjY6WvnThxAqanp2Pd6jULCwuwsrICOzs7MDU1BcvLy0kVQZh5UdXA0EfovlX9yMckfIZGxMoYUUkwu+VhakBFDgqfGDn4xs2hXuJkjFSTLyKljDHlvGJA9YbErWnmftOIgckYDaMlmBtjRVzuuflCKSHkdnnUtA4YzI0xczfGKkJyle7y6Du3VL1/zY1RMJrmFpKrdJfH2HLSUKS4MTb+wYB0ORgmh6PEks+lqIGP5NI5nPQ0Zm0P8m1y/1hy0ia1xeQbs7bkXfO4PaJi5YApKveBxdagCg0uj3Vzo+z92yRf37EifuXDQVvcGKnQ5PJYRkqJqBQ3RrLntCngljFqp0p6mqK2VffHkEoiismXsraqP2lNxkiH1ZYOsTJGbjkYJgdMDSQoojD5hpKDy6PUXJPIGLvdrkg9b2jTX6mLWoYGx8BeuLXHVHA5aCb781iCuoYKirlJqJcENRJ3DhLVZskOLbW6hhOKuXHXS4IqjDsHyvuzKKKo8JWkYaRrqWRuGDCSOA4kqMK4c+C+fxmiDq2vJE2KE2II3BLCOubm5uDzzz+Hd999F1ZXV+GDDz5oXQ7c9y9D1HPa69evw/z8PJw5cwYWFxcPvXb69Gm4ffs2ahx2LBXcjaKbIMEAjTsH7vuXkeTQrq+vv/6mrNvtwuDgIPT39x8b5+taKMUJEcBvbhj3St+Yhl5C1zfJn8cLCwtw+fJl6Ha7MDU1Bdvb24XjfCVpkpwQfed2lCqZW9OYhg6C17dOnNx7QYDwenp62p06dco9fvy4VmydmpB5OUczt6qYmHxD5ib5ByHc98fmcJSy9fX5wQBKEdXX1+c9uChumbqGmxgP/ynmFkONhAG5F1hz4L4/Ngff88Bq7FZE2aaWoG4JhULqJ/ENrgkUBwETUyriFVGGYcTBZIwRyLmHqjQJX2pE1qDuH729Fwj4Bz33dZRce6hS9ZyNtRdSwFGD7J0rJJBrD1UJOTQllmui1BrYoQ2EQuomQT4nIYemxHJNlFoDUTJGjeTaQ1VCDj5gJKJYOanUGoiSMRr5E3svYCSiWDkpFaE1SHJozYDNOCDFXjiQiI6OjkYdGwsVxm5mEmYckGIvjI2NeR9CzNhYJDV2o5IxalJEIevFHhdDyrlZT+FEMsbh4WEIaSodasBGhaY3DQD+esUgh57CoetgMkbDaAnqD60ESRy3jFFbj14q90pu2ScGFW6MVOToWoiNqa1Hb+x8uV0bmyDCjTFmw92mY8ugdGPULGPUKOErQmKulG6b0Q4tVcPdUNdCajdGzTJGjRK+IiTmSum22UhcQSUdo3AtpHZj1CJjzEXCVwR3rsndNut+BtR7HW0q3UvMhru+Y9fW1tzGxobrdDpuc3PT7e7uHhuToulwDB8j35hNalAGR1Ppunwp6oXJNXQv+NS2qgZJf5oXq+EuZqyPq13KpsMchDj7xWjqjKUtTpOkbpt1p9p5ftJyEcMJERR/0jpH53SZ2okQE5cq19C94EsyN0ZMf9qUhDohSpAbYuIWkdrpUpOMUaoiSnV/WirXQm4dq6YaYAmtWUpJK9X6Zt+f1jCMOCR1Y9QiiWuSQ46NjzWhrQbi3Rhzdvarm5svEmobsg6cxFoDLBT7Vowbo0SZmS91cjTfuVFJKaXVlkPOiqmB9mbkAIn+PJYoM/OlTo7mMzdKKaW02nLIWX1rkEMzcoBEHlHcMjNfmsjRfOZGKaXkrK0UOatvDbQ1Iy+l7u9nF+HftFV/w4f824BKEteLj+SyTghBLaWUJCzgkLPW1eAAn3XAjKXYt2L+TUtFCklcDKlf7lLKXjjkrL5IaUaupqm0JkkcFozksAwJtQ3JgZsYa4CFYt+KaypdEdd7bFFcc/aLExeDhJpx0wo3RipCJXESDoy2uBJy4Cb0jctkjIbREkzGGCkHCjdGKofFHOulDfFujJqc/bBQzY1qXrnWSxvi3Ri5pXba3Bhzpi31Uu/GyCm10+jGmDNtqZdqN0YAXqmdRjfGnMmxXqndGBsdWk3NeS9dugQ3b96EyclJ6HQ6yXILgaoJtzX3pgF7HkLXgcWNMRa5ujFSyTPb4oQogSzcGDlkjL6E3L9qblTzkhA3JAcO2WWKeWFQ78ZIJWNMdX+quDGacFPFxUBVM24w8ypCtRtjKKmdCH0JrQGVJarU5t5ScvDF3BgNw/DCZIyRcuBsKq2NnOfmi8kYFcryqhQzqaR+vqqdmGZtvnOjMmCTYuwmQsZYRc7StaZzq1LMpKqXr2onplmbz9yoDNhyMXZL4lzx/Pnz11/zv3jxIuhr/qOkfnxwlLq5Oefc/Py8O3v27LHr5MmT7vz58+iYTfL1zQGTK3ZevvW6du2a6+vrK4w9Pj4uZizFvvV55KPebob70NblUAZHb1hsDthxvmNjGOFJGMt1aFV/e7y+vg77+/sA8Opx1N7eHnNG/nAoyJrmkIsBmxRjt+B9W3eqneBP2pmZGTcyMuIGBwfd0NCQ29raahy36bzq5taUFJ+0XHCYsFFAsW+z/6SdnZ2FCxcuwO7uLiwtLcHAwAB3SoZRS+i+NTdGJJh5UZGyWbVRjrkxBpBSlidBn4uBW8oJQPcGTkHKNzmTMRpGS1AvY6SCWx6prV4YNM1Novw2ibGbRlkepzwyZ8dCTXOTKr9NYuymWZbHgck+ZSA217pnQr3XgXNFrrK8XmK4UTSZV+x6Sbg0zY0qV9/95fOcNomxG4UDn28OUkzoMOToWHiAprlJzVW1sVvTHOrGaZZHGvIJ3V+NPmmLGBsbixWKPIe6cQsLC7CysgI7OzswNTUFy8vLprYyohG6v+w5bQEmjzQoSSpjlOrGSEWRzEyCIopbNYRF09yoci2K22o3xjJCcy2SmUnYgJpqS4WEGoQiXsZIqSzRoq7BkOu8AHTNTWSudc+Eei9o+Fxqe3vbjYyMuP7+fre0tBTteVesuEVsb2+7e/fuFb7mC9W8UlFVA4lrFhuqfXsQu6i2Yn5PS6UsoVSs+MjMqNz6QuYV0zmRQsqZQmUUqwbc+6uMJIeWqidprn1vQ+YV0zmRghRrFqsGUnvpRntOWwWVsiTXvre+88L0RU3dQ7WM2GtGWQOpiqgkhzZHJPS9xUg0Nco5fWhjDUxcUYCPzExq39tYzomapZyp3SOxiJEx5oSvzOzEiRMwPT3NkGE5GDlp1VjNUs5YNaAiuLZ1Xy+7CI986uCOW4SEZtXcUDW2rkJTvUJqkKyptFQ3xtC4ZTGlNqtOCVVj6zJyVNwBFOebRMZIBdUBD0XiZqWkKN8Yja255yZVTtpaGSMVFMZu3GZxWDStmaZcAYQYu1WhyczrAAo1EHcvXQya1kxTrgeIV0Rh5GCUjXwNf8SamhUgMVeMlBKLKBmj6Ea+LUOqhK8IiblSSkRFyRgppYEGDqkSviK4c00tERXx7fEBEqSBAK8UKwff7HW7XRgcHIT+/n5xMQ0ZYOWRoXtBlIxRijRwYWEBLl++DN1uF6ampmB7e1tkTEMuVfLI0L0g6tACyJAGUhi7mVlcuxgbG4PR0dHC17LoT4shZc/ZUGM3XxUMNi4VmsQgEnIN3eNNFVHiPmklEUMNlCKmoZOme0FEU2lNGl2q+2v6ywSAfx2okKCZr0O9jFGbNNCgQdv+arWMUZM00KBB4/4SIWPkcrXzASONzHlsU2I6PFKM5d5fqUnSVFqLa2LOY0OgcnjM3TWRikaKKE2udhhpZM5jfaFyN2yjayIZddYWvddBJ/giHj586IaGhtzq6mrh62Wdr51rbtuxtrbmNjY2XKfTcZubm253d/dY7KdPn7qLFy+6GzdulOae89imte2lbm0ljOXYXxRxk3YY4HC185GDYaSROY8NgcrdULprIpX0NDhu3al2np+0dVC8EzoXx4AtZ0JqqwmO/UURN7qxm9T+tKEGbDmjSbhChQTpqW/c1sgY7cAalFDtryQyRkxTaQmk/HFBqvtjc6BC0z6gqhfXOqj/pJUgY+R2ArSG3XRI2F9HUX9ouWWM3E6AVPfnnpcUuPdXEeoPLTfcEjqNDbslEcs10VdOGkN2aoc2EG4JXY4Nu1MSwzXRV04aS3YqythNI9wSuhwbdlNB5ZroKyeNJTtVfWjN4dDAQOWa6OsiGsttVPWfx+ZwaIQSwzXRV04aS3aq+tCaw6ERSizXRF8X0Rhuo2RujFRQycxC7l+GBHFFzq05fQkVQaSUyfrIGEVojzUhYbNSKXFy1Slre+NyzuWvPTaMNpHUjVGaHCwXJMgNc11fifNKdmglysFyQIrcMNf1lTivJG6MGCjkYNLcDWOiqWG3rVkckrgx+kIhB5PobhgTLQ27bc3ikcSN0RcKORi3uyE1Whp225pFpM6Ppvdq6sbo45ronL8ToRZ3Q6rrKFX+SFRz60W6KybVvAjXl9+NkVMOJtHdkAPKuWlxxcQiViZbd6qd5ydtHbm4JoLST9rYczuK5PWlmhfh+spxY8zBNdEUUeU5SF1fCW6MGFyNIkpEJ3huja42+Z62Nw6K+nLvGWwOvnFbY6FqGG2iVTJGCXI/XyTUCwNl82Vup8tWNpUG4JeDSZH7+cJdLywU+UpYs6ybSksnhbsgVUPlttIWR0gsrTm0KdwFqRoqt5W2OEJiUW3shiG2uyBlk2TjFTk6QsYgySft+vo67O/vA8ArV7u9vb0UtyVlYmICfvnll0PX0tISvP/++8f+rYIZC6CvXtry9YVqXqFxk3zSLiwswMrKCuzs7MDU1BQsLy9nacJ2IOUsMwrzHautXtry9YVqXsFx6yRTvRfok4MdokzuFxKTCo56SVjfo8RYM4nzStZUOlQRxSAH8xonVRGVul4YqNaXYs0wpJxX06bSqD+Ph4eHIUR7XKZLpToI3Js7dFFj6HiRb8pB90qpO0755imtqXRrHvkYRi6oP7QS5H4UUjuJ8jmuuFqkpxhEyBip1EB1Y7nlflRSO4nyudRxfWur0TBOhIyRSg0kXTlkUjs6fGrbRsM4cmO33JVDc3Nz8ODBA7h69SrcunXLlDsR8altKw3j6p4J9V5Njd0ox1ZZrUCC53hVOYTEjTWvnOM6p9MwrmpePs9pkxi7UY01+VzecX3QaBgXXK+6U+08P2k5mJmZcSMjI25wcNANDQ25ra2tQ6+D0k/amPPSHLeqthKgqFfST1oOcm0qTTUvbXFzJbReWfSnzaGpdFHclPI5yXGlokLG2O12vRNNubkl2nbGgEr2mTouBm7paUpMxmgYLSGpGyOFHC1XGaM2JKyDJkTIGKugdNXLVcbYC5dEFDOOex20IULGWEXOUr8Uc+OWiEqXkraNJHYzOUv9Ys+NWyKqUUraOuoe5PZeoEzmhonb9P5VOYTGPYBLIooZl3odtF2+9cpeXJGrjPEoHBJRzLi2rEMsVLgxUpGrC+BRxsbGWMfWjWvLOsQitF6qP2lNPicDWwccSWWMEvrTFsXNVcaoDcnukdz47ltxboxUpHZ59EXCAU9pNSp1HaQiXsaYs2qIe27c9zfSol4RxQ333Ljvb6QniRtjKkVUTFme71jfuVE5+0lTm3FLLjWOxZLEjTFVn1EKWV7dWJ+5UTr7Sevhyi251DD2vfcA+vr+e733Xm24Q5C7MQLQ9BmlkOU1kfD5zI3S2Y+zhyu35FLjWACAf/8bKv97LXWSqd4rxI2RSsaIyQE7zndsDMfAkBpwSAhD6tX2sQDHrwN8ZIzRDu2dO3fc3bt3S19PcWjrcsCO8x3rYz727Nkzd+XKldLXczm0sWub41gxh7aOFIeWixiOgbkcWqOegYHDB3Zg4L+vRe9PS2Xs1jblUBE5iCuMcHwUUaq1x4bRRrDa478B4CFdOobRev7POfevqgGoQ2sYBj/257FhKMMOrWEoww6tYSjDDq1hKMMOrWEoww6tYSjDDq1hKMMOrWEoww6tYSjjP/B1t4mP6K/+AAAAAElFTkSuQmCC\n", 515 | "text/plain": [ 516 | "
" 517 | ] 518 | }, 519 | "metadata": {}, 520 | "output_type": "display_data" 521 | } 522 | ], 523 | "source": [ 524 | "agent.plot_policy_map(net, 'solution.pdf', [0.35,-0.3])" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": 93, 530 | "metadata": {}, 531 | "outputs": [ 532 | { 533 | "data": { 534 | "text/plain": [ 535 | "" 536 | ] 537 | }, 538 | "execution_count": 93, 539 | "metadata": {}, 540 | "output_type": "execute_result" 541 | } 542 | ], 543 | "source": [ 544 | "best_net = copy.deepcopy(net)\n", 545 | "best_net.load_state_dict(torch.load('best.torch'))" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": 147, 551 | "metadata": {}, 552 | "outputs": [ 553 | { 554 | "data": { 555 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAO0AAADrCAYAAACFFBGSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAWF0lEQVR4nO2dT2hdxR7Hf3k+qFJX4jNu6nPRbmogi1QCKUEQgjvpskjIJQS3XQWych/Iwk2VLg3tqouUKIkkqyzyh9RcUWkWIogWSvPQRYqJ3i6aeYu8+G6Se86ZOTO/f3N+Hzjwnhnm/GbmzL1tzqff6XPOgWEYevgHdwGGYYRhm9YwlGGb1jCUYZvWMJRhm9YwlGGb1jCU8c+Qxq+//rp7++23kUppNu1227vt0NAQSr8hcNeAdX/ufn/55Rf4/fff+0obOee8r6GhISeN/f19991333GXEQ0AeF9nKZuDkH411YB1f+5+/7fHSveh+j8e//TTT/Dpp59yl8GKhDngrgHr/hL7TbZpf/jhB/j5559Z2xZxdHQEX331lZq2hlFGsk3b6XTgxo0bXhsMq20vjo6OYHJyEtbX11W0NYwqgn4RdcK9e/dgdnb23H9/+vQp3Lx5Ex4+fIje1pc7d+7A3bt34erVq7C0tHTqZ1euXIEHDx6IamsYlVT9pdd5/iLq119/dYODg259fb2wTeq2m5ubbnt727VaLbezs+M6nc65Nn/88Yd777333BdffFF5L862UPOXGlVzENKvphqw7s/dr88vopJt2tXVVbexsVH4c4y2MzMzbnh42A0MDLjBwUG3t7fXs93BwYGbm5vzuh9X27oPQNUcUGxajhqw7s/dL+mm5WJ6etpdvnzZPXnyhLuUKOo+AM6VzwHFpuWoAev+3P36bNo+F/Dvaa9du+Z2dna821Px4sULeOmll7jLiKKvr/x9eje91qxoDkL6DYG7Bqz7c/d77do12NnZKe241i+ifKB8WIo2rG8NIR9clOMKIcWHFlYNsf1y35+63yrUyxWG0TSiNu2zZ8/g22+/TVVL7Rq+//57lH65x+aLplqNY2Ke29qb9tmzZ/DBBx/AyMgIfP3113W7iQZDM5MythPKrDCKWiXYbrm1ZdEYP/74YxgZGYH3338fPvnkE3j8+HHdrsQhbWxlVhhFrRJst5zbhlL7F1Hz8/Owu7sLt2/fhoWFBXj55ZdT1sUK59hCrbDUtUqw3XJum4Sqd0Ld19n3tN98841rtVrn3kGFvsMKuc6SooZeFPVLNa5uqgyyVOtQ9/7WNrxt0Zpl/0/ztra24OjoCACO/0Hy8+fPmSvC4ccff4TPP/8crl+/Lv7+1ra6bexzi/aeloLFxUVYW1uDw8NDmJqagpWVFejv7+cuKzljY2Nq7m9tq9vGPreqv2lnZ2dhdHQUOp0OLC8vZ7lhjfyIfW7RNEZN+lzIHEg1ooqIVe0MPMRpjJQbgdI7znlzYdWA0a+EWkPoVYNpjIbREExjLOmXe2y+SKgVqwaMfqXMV+3ntuqdUPfV/Z52f3/fDQ8PuwsXLrjl5eXSd1YvXrxwX375ZeHPQdh7Wt+xVY0rpG3dOaiqNaRfCTX49iuhVoznFvU9ra8+pzHUzGdsUoLdJCiXWDVg9CthvqKp2tWu4Jv2zz///PvT4q+//jr3iXHCZ5995vr6+tw777xz7rpx40bUp2bZJ1ZIv2fxGZvPuCjmoKrWkH4l1ODbr4RaMZ5bn2/a2r89fuWVV/7+32W+68TEBNy/fx8mJyeh1WrVvR0pPmMLGRfmHPiuAyZYNWD0K2G+YkH/7fGrr74KS0tL8NtvvyXvm1NjDBkX5hwY+lChMV68eBGmp6eT98utMYaMC2sODH2Yxmgao6EMsRpjUBHMaXkhc4BFDoZPihp8+81hvsRpjFiD7wWlxkg5rhRgfSBxa5fcHxopMI3RMBqCaYwl/XKPzch3HSyNMfM0xhBSJAZSnaVbVavvOkipN6StpTEmRvPYYhMDKbXTqlpT66SxSEljRNMYgVkHC6nhLKn0Oew5uHv3bk818rXXXnPvvvvuqbp92/rqmaH1htTqsw4h9UqZ225iNEZLY+zRZ1m/3Bu2qN4TYhMDfc/dTVGvT61l6+tbr5S59RmXpTE2kNjEQErlMkXKpNR6LY2xAG6NUSIpEgOplMtUKZMS67U0xgJMYzQ0IlZj5NbBQmoImQMJRlRIvVhItYx6IbVWEo2x3W6L9HljD/2VuqhFSPCJQ+B2j7GIHZd4jRHLXJJArmOTMC5uIwozsI7ciAoFw1ySQq5j4x4Xt5mGeX8WIwoLXyUtRF2j0txS1SChXglwm2nc9y9C1Kb1VdKkJCH6oq1eKczPz8NHH30Eb7zxBqyvr8Nbb73VqPsXIeo97Z07d+Du3btw9epVWFpaOvWzK1euwIMHD4LahbbFQlu9UuAOYeO+fyFVylT3BTV1sM3NTbe9ve1arZbb2dlxnU6np9rlq9D5tvNtW3dcvmPjrBdrzShqcI7/cG+MA7vL5laMxri4uAi3bt2CdrsNU1NTsL+/37Odr5ImKQnRZ2yS6vXFd82McKLntmpXuwTftM45Nz097S5fvuyePHnS81OLk5hxOUc/tpB6sdaMqoYcv2nL5tbnmzbIiOrr6/Nu3KvfInOJmxQv/ynHpsk2S1FDL7jvH1qD79yyBrv1ouihlmC3xBL7YIc8LJSk+CDC2AhS5ysE8UaUYRhpMI0xARhj49b3TmrACs3T8iyYxpih6geQfmzc+t4JWGum5VkwjdHwRqo+p4VUqYlS18E2rUCk6nNaSJWaKHUdRGmMxjFi9TmB3Lt3D2ZnZ8/996dPn8LNmzfh4cOHtdoCyF0Hkm/anAPYch0b1rhS9zs+Pg6PHj06dS0vL8OlS5fO/Z0xpC0mKoLdcg5gy3VsWOOimK+TJMSRkZGkbVMRPQdVylT3BcqUOIyrF0Vji+lXQqZ0qjXD6JcK0xiJlTgMiuar19hyPUM1tAaMfkPmCwsVGuPQ0BDEHCodG8CGRYpNQOlUU84X1rhyOFM4dh1MYzSMhqB+00pQ4iTUoIlctc8QVGiMWEhQ4iTUoIlctc8QRGiMKQ/crdu2CG3phljJjRhJlxKQqBumeG6LSLZpsQ7cjT2cV1u6IVZyI0bSpRQk6oaYh0rXkiuw1LFQzcwHbemGWMmNGEmXUuDWDTGe21KqXuR2X2cPle4m5YG7vm1TJSFCxEvyshrq9Js6uTG0Lfah0qlSHs8iQUbppu5zi34SfDerq6tuY2OjdCCp287MzLjh4WE3MDDgBgcH3d7eXs92BwcHbm5urvAeMQtVVkPdfqvqxW7r067u2KrWLJdNW/e5Jd20XKRIQoxdKAyNUToxY8PQGKVt2irINMaQ82kpiU1CxEpjzCHYrQhNGqNUI0r1+bRY53xSeqy9aqCcgxT9hhA7Z5RKK9YHYvbn0xqGkQbSNEZpqXa51KBN4cNA2xyI1xilptrlUINGhS81GudAhMZYhkTNzJdUOhqWGihtbjl01pA5yOFwb5JNK1Ez8yWFjoapBkqbWw6d1XcOcjncmyQjilsz8wVLR8NUAznnVorO6jsH2RzuXfUit/uCiBfP1IfzhvRbRpVyKUGllCQWcOisVXNwgpTDyFUcKo0FxcHHJ2l9169fr12DlIOiKaiaL4q2RUg53FvNodLUqXYh/cbCrVJK+qblwuebNjUYz624NMaSfr3b9upXQrKfBJUytt8QUsyZdhqRxohFrBKXYsNQHiotoV8JNXAT+8FlGqNhNATTGIXWkPOBztqUQwxMY8xQY8z1QGeNyiEGIjRGyYfzStXRmgj3s0CF+jRGTtVOso7WRKRpl1ioTmME4FXtROtoDUSL0hoCdRpjrU07Pj4O4+Pjp/7b48eP4cMPPyQ/nPfk1+btdhsGBgbgwoULp9pMTEzA/fv3YXJyElqtFksNEvrE7LfphO6H2HVI9sfjFJpZKBIUQgyVEkvPpNA+jWPK9gOpxhiTxsihMfoSc/+yGrDGJaHfmBo4tEuKcYWgPo0RS2Okun9RDdx6Jma/IYQ8Y00yolSnMcYSm8aIVUPsHFAf6Ey5ZkVIqMEXS2M0DMML0xiFkuu4AExjBDCNkV1jPCFlqBnFuHzrTTku32eB+4ze0LahiNAYy2iKukZ57m4KfGtIOS6fZ4H7jN7QttSQBLvNz8/D7u4u3L59GxYWFrIwYbjP3Q3Ftwbscfk8C9xn9Ia2JafqnVD3BQLezWH1W/f+3VAFlVHXmzqsrSoahvuMXt+2GM9t9sFuW1tbcHR0BADHr6OeP3/OWk+qoDKqcfnWm0sAm5Rgt+j1rdrVTvA3baoDimPGhUHKcUmEI4QNA4znNvtv2tnZWRgdHYVOpwPLy8vQ39/PXVISch2XcUzs+loaYyAh44oFS2M00mBpjBFQanmUfm4KjVGCy4v1AY4B5YecaYyG0RDUa4xYcGuE3PfHRNOzIFG/JQl2o0rgS6nlceuR3PfHQlMao1T9liTYjUpjxNDyjLRoUlql1koS7IahMWJoeQY+mpRWqbWSBLthJPD51iAlhM44RlMao9RaVQe71a2hqh23Hsl9fwOX2PVN9q98xsbGUnWFXkNVu8XFRVhbW4PDw0OYmpqClZUVUiuJ+/4GLrHra+9pe8CtEXLf38CFVGOUmsaIRWzCYgjceiYmuRpRKjRGqWmMRcTWGpuwGPsASk5N1LS5pCJeY8Q0S7TYNcYxmtZMYq2qg90wjZWcNUJfsA7LNiNKiMZYBpZZgmms+GhmVOfepkxDDGmLoVJSWEap5oD7+SqCZNNinUnalHNvsVIeOXROijVLNQdSz9IlSWPEMktyPPcWKw1Ris6Zes0w50CqEaU6IypVv71IkRgYM65uUqch+rbVtma+4wppSz0H2WdEYeGjmWGfe9sNVhqihERIDKjTI0MRozHmhK9mdvHiRZienkavJ0QRTdVWs0qJNV+piJ7bqq9il+iPx2Vw99sLCYdVc4N1sHUZmuYrZg7IDpWWmsYY229RnxIOq+aGOhEyR+MOoHe9JBojFlgbPBaJDysmverNIRFSwgn3vWisxogFlg2kybLSVG+Tni/VGiMmGDaQtrA2LfU27fkSpzFSqYFGPkgMYAtRKUMRpTFKPsjXkItE3RBTERWlMYo+yNcQC7duSK2Iivjt8QkTExNw//59mJychFarxVbH1tbW37/Za7fbMDAwABcuXBDXJyba6uUkNPEzdm5FaYyUamAZi4uLcOvWLWi32zA1NQX7+/si+8REW73SKNMjo+e2yr7ovgBJEg8hpIa6tTpXbKzE9JvKMMK6NNV7llRif+xz44t6IyoEyjNnY4PdevUrOaxNU72Bzy17Db2oa0SJ+uOxNFLYQBR9YqKtXk3UnVsRh0prcnSx7q/pTyYA/OuABdY3eMr5Uq8xalLtDDy0PV+N1hi1qHYGHhqfLxEaI1eqnQ8hamTObevClQbp25b7+aKG5FBpLamJObeNQUIapNTniwOSQ6W1pCbm3NYXCWmQmp4vFqpe5HZf/3vx2xOOVLvNzU23vb3tWq2W29nZcZ1O51zfPqmJObetO7fdcKVBhrTleL4w+iVNY+RItfPRwULUyJzbxsCRBhnTNhVYKiepxlj2TVsFVn5sigC2nImZW01wPF8Y/SbXGKWeTxsbwJYzmsQVLCSop779NkZjtA1rYIL1fJFojCGHSkuA8h8XUN0/tAYsND0HWPPFtQ7qv2klaIzcSYB2YDceEp6vs6jftNwaI3cSoMYDuzXB/Xz1Qv2m5YZbodN4YLckUqUm+uqkKbRT27SRcCt0OR7YTUmK1ERfnTSVdioq2E0j3Apdjgd2Y4GVmuirk6bSTlVvWksMNELASk30TRFNlTaq+o/HlhhoxJIiNdFXJ02lnaretLOzszA6OgqdTgeWl5fVHHpsyGFsbAxGRkZ6/izk+fI9YDzFQeRoaYxYYGlmMfcvQoJckfPRnL7EShCUmqyPxijCPdaEhIcVy8TJ1VPW9sHlnMvfPTaMJkGaxihNB8sFCbphrusrcVxkm1aiDpYDUnTDXNdX4rhI0hhDwNDBpKUbpkTTgd22ZmkgSWP0BUMHk5humBItB3bbmqWDJI3RFwwdjDvdEBstB3bbmiWkKo+m+6qbxuiTmuicfxKhlnRDrOssPseIYiY3Sk/FxBoX4vrypzFy6mAS0w05wBybllTMUMRqslW72nl+01aRS2oiKP2mTT22s0heX6xxIa6vnDTGHFITzYgqrkHq+kpIYwzBVRhRIk6C53Z0tel72j44MOaX+5kJrcG338ZEqBpGk2iUxihB9/NFwnyFgHn4MnfSZSMPlQbg18Gk6H6+cM9XKBj1SlizrA+Vlg5FuiDWgcpNpSmJkKE0ZtNSpAtiHajcVJqSCBmK6mC3EFKnC2Iekmwck2MiZApIvmm3trbg6OgIAI5T7Z4/f05xW1TGx8fh0aNHp67l5WW4dOnSub+rhLQF0Ddf2ur1BWtcsf2SfNMuLi7C2toaHB4ewtTUFKysrGQZwnaichYFhfm21TZf2ur1BWtc0f1WKVPdF+jTwU5RpPvF9IkFx3xJWN+zpFgzieMiO1Q61ohi0MG82kk1oqjnKwSs9cVYsxAox1X3UOmgPx4PDQ1BjHtc5KVibQTuhzt2UVN4vIEfylH3ovSOKT88pR0q3ZhXPoaRC+o3rQTdD0O1k6jPcfWrRT0NQYTGiGUDVbXl1v2wVDuJ+hx1v75zqzEwToTGiGUDSTeHTLXDw2dumxgYhx7slrs5ND8/D7u7u3D79m1YWFgwcychPnPbyMC4qndC3VfdYDfMtmVRK0DwHq+shph+U40r536d0xkYVzYun/e0JMFuWG1Nn8u7Xx80BsZFz1fVrnae37QczMzMuOHhYTcwMOAGBwfd3t7eqZ+D0m/alOPS3G/Z3EoAY75Iv2k5yPVQaaxxaes3V2LnK4vzaXM4VLpXv5T6nOR+paJCY2y3296FUj7cEmM7U4ClfVL3GwK3ekqJaYyG0RBI0xgxdLRcNUZtSFgHTYjQGMvATNXLVWPshksRDWnHvQ7aEKExlpGz6kcxNm5FVLpK2jRI4mZyVv1Sj41bEdWokjaOqhe53Rco09xC+q17/7IaYvs9gUsRDWlHvQ7aLt/5yl6uyFVjPAuHIhrSrinrkAoVaYxY5JoCeJaxsTHWtlXtmrIOqYidL9XftKbPycDWIQxSjVHC+bS9+s1VY9SG5PRIbnyfW3FpjFhQpzz6ImGDU0aNSl0HqYjXGHO2hrjHxn1/gxb1RhQ33GPjvr9BD0kaI5URlVLL823rOzasZD9pthm3cqmxbSgkaYxU54xiaHlVbX3GhpnsJ+0MV27lUkPbN98E6Ov7//Xmm5XdnQI9jREA55xRDC2vjsLnMzbMZD/OM1y5lUuNbQEA/vMfKP3/lVQpU91XTBojlsYYUkNoO9+2KRIDY+aAQyGMma+mtwU4f53gozEm27Srq6tuY2Oj8OcUm7aqhtB2vm19wscODg7c3Nxc4c9z2bSp5zbHtmI2bRUUm5aLFImBuWxao5r+/tMbtr///z9Lfj4tVrBb08yhXuQgVxjx+BhRqt1jw2gioe7xbwDwK145htF4/u2c+1dZg6BNaxgGP/bHY8NQhm1aw1CGbVrDUIZtWsNQhm1aw1CGbVrDUIZtWsNQhm1aw1CGbVrDUMZ/AV8KrKQ+sGjtAAAAAElFTkSuQmCC\n", 556 | "text/plain": [ 557 | "
" 558 | ] 559 | }, 560 | "metadata": {}, 561 | "output_type": "display_data" 562 | } 563 | ], 564 | "source": [ 565 | "agent.plot_policy_map(best_net, 'solution_best.pdf', [0.35,-0.3])" 566 | ] 567 | } 568 | ], 569 | "metadata": { 570 | "kernelspec": { 571 | "display_name": "Python 3", 572 | "language": "python", 573 | "name": "python3" 574 | }, 575 | "language_info": { 576 | "codemirror_mode": { 577 | "name": "ipython", 578 | "version": 3 579 | }, 580 | "file_extension": ".py", 581 | "mimetype": "text/x-python", 582 | "name": "python", 583 | "nbconvert_exporter": "python", 584 | "pygments_lexer": "ipython3", 585 | "version": "3.6.9" 586 | } 587 | }, 588 | "nbformat": 4, 589 | "nbformat_minor": 2 590 | } 591 | -------------------------------------------------------------------------------- /environment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.special as sp 3 | import matplotlib.pyplot as plt 4 | import copy 5 | 6 | class MazeEnvironment: 7 | def __init__(self, maze, init_position, goal): 8 | x = len(maze) 9 | y = len(maze) 10 | 11 | self.boundary = np.asarray([x, y]) 12 | self.init_position = init_position 13 | self.current_position = np.asarray(init_position) 14 | self.goal = goal 15 | self.maze = maze 16 | 17 | self.visited = set() 18 | self.visited.add(tuple(self.current_position)) 19 | 20 | # initialize the empty cells and the euclidean distance from 21 | # the goal (removing the goal cell itself) 22 | self.allowed_states = np.asarray(np.where(self.maze == 0)).T.tolist() 23 | self.distances = np.sqrt(np.sum((np.array(self.allowed_states) - 24 | np.asarray(self.goal))**2, 25 | axis = 1)) 26 | 27 | del(self.allowed_states[np.where(self.distances == 0)[0][0]]) 28 | self.distances = np.delete(self.distances, np.where(self.distances == 0)[0][0]) 29 | 30 | self.action_map = {0: [0, 1], 31 | 1: [0, -1], 32 | 2: [1, 0], 33 | 3: [-1, 0]} 34 | 35 | self.directions = {0: '→', 36 | 1: '←', 37 | 2: '↓ ', 38 | 3: '↑'} 39 | 40 | # the agent makes an action from the following: 41 | # 1 -> right, 2 -> left 42 | # 3 -> down, 4 -> up 43 | 44 | # introduce a reset policy, so that for high epsilon the initial 45 | # position is nearer to the goal (useful for large mazes) 46 | def reset_policy(self, eps, reg = 7): 47 | return sp.softmax(-self.distances/(reg*(1-eps**(2/reg)))**(reg/2)).squeeze() 48 | 49 | # reset the environment when the game is completed 50 | # with probability prand the reset is random, otherwise 51 | # the reset policy at the given epsilon is used 52 | def reset(self, epsilon, prand = 0): 53 | if np.random.rand() < prand: 54 | idx = np.random.choice(len(self.allowed_states)) 55 | else: 56 | p = self.reset_policy(epsilon) 57 | idx = np.random.choice(len(self.allowed_states), p = p) 58 | 59 | self.current_position = np.asarray(self.allowed_states[idx]) 60 | 61 | self.visited = set() 62 | self.visited.add(tuple(self.current_position)) 63 | 64 | return self.state() 65 | 66 | 67 | def state_update(self, action): 68 | isgameon = True 69 | 70 | # each move costs -0.05 71 | reward = -0.05 72 | 73 | move = self.action_map[action] 74 | next_position = self.current_position + np.asarray(move) 75 | 76 | # if the goals has been reached, the reward is 1 77 | if (self.current_position == self.goal).all(): 78 | reward = 1 79 | isgameon = False 80 | return [self.state(), reward, isgameon] 81 | 82 | # if the cell has been visited before, the reward is -0.2 83 | else: 84 | if tuple(self.current_position) in self.visited: 85 | reward = -0.2 86 | 87 | # if the moves goes out of the maze or to a wall, the 88 | # reward is -1 89 | if self.is_state_valid(next_position): 90 | self.current_position = next_position 91 | else: 92 | reward = -1 93 | 94 | self.visited.add(tuple(self.current_position)) 95 | return [self.state(), reward, isgameon] 96 | 97 | # return the state to be feeded to the network 98 | def state(self): 99 | state = copy.deepcopy(self.maze) 100 | state[tuple(self.current_position)] = 2 101 | return state 102 | 103 | 104 | def check_boundaries(self, position): 105 | out = len([num for num in position if num < 0]) 106 | out += len([num for num in (self.boundary - np.asarray(position)) if num <= 0]) 107 | return out > 0 108 | 109 | 110 | def check_walls(self, position): 111 | return self.maze[tuple(position)] == 1 112 | 113 | 114 | def is_state_valid(self, next_position): 115 | if self.check_boundaries(next_position): 116 | return False 117 | elif self.check_walls(next_position): 118 | return False 119 | return True 120 | 121 | 122 | def draw(self, filename): 123 | plt.figure() 124 | im = plt.imshow(self.maze, interpolation='none', aspect='equal', cmap='Greys'); 125 | ax = plt.gca(); 126 | 127 | plt.xticks([], []) 128 | plt.yticks([], []) 129 | 130 | ax.plot(self.goal[1], self.goal[0], 131 | 'bs', markersize = 4) 132 | ax.plot(self.current_position[1], self.current_position[0], 133 | 'rs', markersize = 4) 134 | plt.savefig(filename, dpi = 300, bbox_inches = 'tight') 135 | plt.show() 136 | -------------------------------------------------------------------------------- /maze_generator/maze.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/giorgionicoletti/deep_Q_learning_maze/5806405207d9f49d043047eae328c885e08dbfc6/maze_generator/maze.npy -------------------------------------------------------------------------------- /maze_generator/maze_generator.py: -------------------------------------------------------------------------------- 1 | # Adapted from http://code.activestate.com/recipes/578356-random-maze-generator/ 2 | # Random Maze Generator using Depth-first Search 3 | # http://en.wikipedia.org/wiki/Maze_generation_algorithm 4 | # FB36 - 20130106 5 | 6 | import random 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | 10 | 11 | mx = 20; my = 20 # width and height of the maze 12 | 13 | maze = [[0 for x in range(mx)] for y in range(my)] 14 | dx = [0, 1, 0, -1]; dy = [-1, 0, 1, 0] # 4 directions to move in the maze 15 | color = [(0, 0, 0), (255, 255, 255)] # RGB colors of the maze 16 | 17 | # start the maze from a random cell 18 | cx = random.randint(0, mx - 1) 19 | cy = random.randint(0, my - 1) 20 | maze[cy][cx] = 1 21 | stack = [(cx, cy, 0)] # stack element: (x, y, direction) 22 | 23 | while len(stack) > 0: 24 | (cx, cy, cd) = stack[-1] 25 | # to prevent zigzags: 26 | # if changed direction in the last move then cannot change again 27 | if len(stack) > 2: 28 | if cd != stack[-2][2]: dirRange = [cd] 29 | else: dirRange = range(4) 30 | else: dirRange = range(4) 31 | 32 | # find a new cell to add 33 | nlst = [] # list of available neighbors 34 | for i in dirRange: 35 | nx = cx + dx[i] 36 | ny = cy + dy[i] 37 | if nx >= 0 and nx < mx and ny >= 0 and ny < my: 38 | if maze[ny][nx] == 0: 39 | ctr = 0 # of occupied neighbors must be 1 40 | for j in range(4): 41 | ex = nx + dx[j]; ey = ny + dy[j] 42 | if ex >= 0 and ex < mx and ey >= 0 and ey < my: 43 | if maze[ey][ex] == 1: ctr += 1 44 | if ctr == 1: nlst.append(i) 45 | 46 | # if 1 or more neighbors available then randomly select one and move 47 | if len(nlst) > 0: 48 | ir = nlst[random.randint(0, len(nlst) - 1)] 49 | cx += dx[ir]; cy += dy[ir]; maze[cy][cx] = 1 50 | stack.append((cx, cy, ir)) 51 | else: stack.pop() 52 | 53 | maze = np.array(maze) 54 | maze -= 1 55 | maze = abs(maze) 56 | 57 | maze[0][0] = 0 58 | maze[mx-1][my-1] = 0 59 | 60 | np.save('maze', np.array(maze)) 61 | --------------------------------------------------------------------------------