├── LICENSE ├── README.md ├── __pycache__ ├── grid_agent.cpython-35.pyc └── world.cpython-35.pyc ├── _remnants_ ├── belief.py ├── experiment.py ├── grid_agent.py ├── grid_agent.pyc ├── macros.py ├── macros.pyc ├── visualize.py ├── visualize.pyc ├── world.py └── world.pyc ├── astar.py ├── cbsearch.py ├── gworld.py ├── m_astar.py ├── macros.py ├── pqueue.py ├── pqueue0.py ├── ta_test.py ├── tb_test.py ├── test.py ├── ts_astar_test.py └── visualize.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Barnabas Gavin Cangan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ECE6504 Autonomous Coordination - Presentation One 2 | # Centralized methods for Multi-Agent Path Finding 3 | 4 | Implementation of Multi-Agent Path Finding in a gridworld simulation. 5 | 6 | - Find independent paths for each agent without considering other agents 7 | - Uses space-time A* for low-level search. 8 | - Make reservations in a space-time reservation table 9 | - Check all paths against the table for conflicts with other agents. 10 | - When a conflict is found, add a constraint to the agent's low-level path planning and re-plan. 11 | 12 | - To avoid agents passing right through each other, currently each agent makes multiple reservations in the reservation table. 13 | - This sometimes causes a noticeable delay, where one or more agents may end up waiting in their location although a path is clearly available. 14 | - To fix this, I need to be able to add two different sorts of constraints as in the original Conflict Based Search paper - for vertex collisions and edge collisions. 
15 | 16 | - Link to YouTube video: https://youtu.be/b5KMm729b_4 17 | 18 | -- 19 | The A* priority queue is from this gridworld simluator on GitHub: https://github.com/TheLastBanana 20 | -------------------------------------------------------------------------------- /__pycache__/grid_agent.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gavincangan/multiagent-pathfinding/a151f5fb81e925988fc2f228ee1916cc6e254629/__pycache__/grid_agent.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/world.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gavincangan/multiagent-pathfinding/a151f5fb81e925988fc2f228ee1916cc6e254629/__pycache__/world.cpython-35.pyc -------------------------------------------------------------------------------- /_remnants_/belief.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from macros import * 3 | import numpy as np 4 | import random 5 | import world 6 | import grid_agent as ga 7 | 8 | class Belief: 9 | def __init__(self, nrows, ncols): 10 | 11 | 12 | @staticmethod 13 | def gauss_update(mean1, var1, mean2, var2): 14 | new_mean = (mean1 * var2 + mean2 * var1)/(var1 + var2) 15 | new_var = 1/((1/var1) + (1/var2)) 16 | return [new_mean, new_var] 17 | -------------------------------------------------------------------------------- /_remnants_/experiment.py: -------------------------------------------------------------------------------- 1 | from macros import * 2 | import numpy as np 3 | import random 4 | import grid_agent as ga 5 | from world import world 6 | from visualize import Visualize 7 | 8 | class Experiment: 9 | def __init__(self, world_dim, nagents, agent_xy, goal_xy): 10 | self.world = world(world_dim[1], world_dim[0]) 11 | self.vis = Visualize(self.world) 
12 | self.list_agents = [] 13 | for index in range(nagents): 14 | self.list_agents.append(self.world.new_agent(agent_xy[index][1], agent_xy[index][0], goal_xy[index][1], goal_xy[index][0])) 15 | self.init_vis() 16 | 17 | def init_vis(self): 18 | self.vis.draw_world() 19 | self.vis.draw_agents() 20 | self.vis.canvas.pack() 21 | self.vis.canvas.update() 22 | 23 | def run_random(self, ts, T): 24 | nsteps = int(T/ts) 25 | nagents = len(self.list_agents) 26 | for step in range(nsteps): 27 | random.shuffle(self.list_agents) 28 | for agent in self.list_agents: 29 | agent.move(random.choice(agent.get_move_actions())) 30 | self.vis.canvas.update() 31 | self.vis.canvas.after(int((ts * 900) /nagents)) 32 | print '\n' 33 | for agent in self.list_agents: 34 | print agent 35 | print '\n\n' 36 | self.vis.canvas.after(int(ts * 100)) 37 | 38 | if __name__ == "__main__": 39 | # my_exp = experiment( (10,10), 7, [(3,2),(1,6),(7,8),(2,6),(0,9),(5,6),(4,7)] ) 40 | # my_exp.run_random(1, 10) 41 | # my_exp = Experiment( (5,5), 6, [(3,2),(1,2),(2,4),(0,4),(3,0),(4,4)], [(1,2),(2,4),(0,4),(3,0),(4,4),(3,2)] ) 42 | # my_exp.run_random(2, 5) 43 | 44 | my_exp = Experiment( (5,5), 1, [(3,2)], [(4,4)] ) 45 | my_exp.run_random(0.5, 5) 46 | -------------------------------------------------------------------------------- /_remnants_/grid_agent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from macros import * 3 | from builtins import object 4 | import numpy as np 5 | from collections import deque 6 | import world 7 | 8 | class CentralizedAgent: 9 | def_move_actions = (AgentActions.WAIT, AgentActions.UP, AgentActions.DOWN, AgentActions.LEFT, AgentActions.RIGHT) 10 | def_obs_actions = (Quadrants.QUAD1, Quadrants.QUAD2, Quadrants.QUAD3, Quadrants.QUAD4) 11 | comm_actions = range(MSG_LIMITLOWER, MSG_LIMITUPPER + 1) 12 | agent_count = 0 13 | agent_by_index = dict() 14 | verbose = True 15 | def __init__(self, world_obj, y, x, gy, gx): 
16 | self.world_act = world_obj 17 | self.obs_map = np.ones_like(world_obj.occ_map) 18 | self.y = INVALID 19 | self.x = INVALID 20 | self.gy = gy 21 | self.gx = gx 22 | self.reachedGoal = False 23 | self.aindex = CentralizedAgent.agent_count 24 | self.states = (self.x, self.y, self.obs_map) 25 | self.vis_obj = 0 26 | self.world_act.add_agent(self, y, x, gy, gx) 27 | CentralizedAgent.agent_by_index[CentralizedAgent.agent_count] = self 28 | CentralizedAgent.agent_count +=1 29 | 30 | @staticmethod 31 | def __move_cmd_to_vector__(move_cmd): 32 | dy = 0 33 | dx = 0 34 | if(move_cmd == AgentActions.UP): 35 | dy = -MOVE_SPEED 36 | elif(move_cmd == AgentActions.DOWN): 37 | dy = MOVE_SPEED 38 | elif(move_cmd == AgentActions.LEFT): 39 | dx = -MOVE_SPEED 40 | elif(move_cmd == AgentActions.RIGHT): 41 | dx = MOVE_SPEED 42 | else: 43 | pass 44 | return (dy, dx) 45 | 46 | # Represent the view of the agewnt in matrix indices 47 | # => Q1 has a negative dy, for example 48 | @staticmethod 49 | def __quadrant_to_dxdy__(quadrant): 50 | if(quadrant == 1): 51 | dx = SENSE_RANGE 52 | dy = -SENSE_RANGE 53 | elif(quadrant == 2): 54 | dx = -SENSE_RANGE 55 | dy = -SENSE_RANGE 56 | elif(quadrant == 3): 57 | dx = -SENSE_RANGE 58 | dy = SENSE_RANGE 59 | else: #(quadrant == 4) 60 | dx = SENSE_RANGE 61 | dy = SENSE_RANGE 62 | return (dy, dx) 63 | 64 | def move(self, move_cmd): 65 | # 0 - wait, 1 - up, 2 - down 66 | # 3 - left, 4 - right 67 | if(move_cmd in self.get_move_actions()): 68 | wnrows, wncols = self.world_act.get_size() 69 | (dy, dx) = self.__move_cmd_to_vector__(move_cmd) 70 | # print 'dy:', dy, ' dx:', dx 71 | new_y = (self.y + dy) % wnrows 72 | new_x = (self.x + dx) % wncols 73 | # print 'New position: ', new_y, new_x 74 | self.update_position(new_y, new_x) 75 | if(self.x == self.gx and self.y == self.gy): 76 | self.reachedGoal = True 77 | else: 78 | print 'Error! 
Cmd:', move_cmd, 'PosYX:', self.y, self,x 79 | raise EnvironmentError 80 | 81 | def __str__(self): 82 | return('#' + str(self.aindex) + ' @ (' + str(self.y) + ', ' + str(self.x) + ') -> (' + str(self.gy) + ', ' + str(self.gx)+ ')') 83 | 84 | def update_position(self, pos_y, pos_x): 85 | old_x = self.x 86 | old_y = self.y 87 | self.x = pos_x 88 | self.y = pos_y 89 | if not(old_x == self.x and old_y == self.y): 90 | if (old_x >= 0 and old_x < self.world_act.ncols and old_y >= 0 and old_y < self.world_act.nrows): 91 | self.world_act.occ_map[old_y][old_x] -= 1 92 | self.world_act.ptr_map[old_y][old_x].remove(self) 93 | self.world_act.occ_map[self.y][self.x] += 1 94 | self.world_act.ptr_map[self.y][self.x].append(self) 95 | if(self.vis_obj): 96 | self.world_act.visualize.move_agent_vis(self, self.vis_obj, old_y, old_x, pos_y, pos_x) 97 | 98 | def get_move_actions(self): 99 | ret_moveactions = [] 100 | ymin, xmin = 0, 0 101 | ysize, xsize = self.world_act.get_size() 102 | xmax = xsize - 1 103 | ymax = ysize - 1 104 | if(self.x < xmax - MOVE_SPEED + 1): 105 | next_cell = self.world_act.ptr_map[self.y][self.x + MOVE_SPEED] 106 | if(not next_cell): 107 | ret_moveactions.append(AgentActions.RIGHT) 108 | if(self.x > xmin + MOVE_SPEED - 1): 109 | next_cell = self.world_act.ptr_map[self.y][self.x - MOVE_SPEED] 110 | if(not next_cell): 111 | ret_moveactions.append(AgentActions.LEFT) 112 | if(self.y < ymax - MOVE_SPEED + 1): 113 | next_cell = self.world_act.ptr_map[self.y + MOVE_SPEED][self.x] 114 | if(not next_cell): 115 | ret_moveactions.append(AgentActions.DOWN) 116 | if(self.y > ymin + MOVE_SPEED - 1): 117 | next_cell = self.world_act.ptr_map[self.y - MOVE_SPEED][self.x] 118 | if(not next_cell): 119 | ret_moveactions.append(AgentActions.UP) 120 | # print '##RM:', ret_moveactions, self.y, self.x 121 | return ret_moveactions 122 | -------------------------------------------------------------------------------- /_remnants_/grid_agent.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/gavincangan/multiagent-pathfinding/a151f5fb81e925988fc2f228ee1916cc6e254629/_remnants_/grid_agent.pyc -------------------------------------------------------------------------------- /_remnants_/macros.py: -------------------------------------------------------------------------------- 1 | 2 | INVALID = -999 3 | MSG_LIMITLOWER = 0x0 4 | MSG_LIMITUPPER = 0xF 5 | 6 | SENSE_RANGE = 2 7 | COMM_RANGE = 2 8 | MOVE_SPEED = 1 9 | MSG_BUFFER_SIZE = 3 10 | 11 | FRAME_HEIGHT = 600 12 | FRAME_WIDTH = 600 13 | 14 | FRAME_MARGIN = 10 15 | CELL_MARGIN = 5 16 | 17 | MAX_AGENTS_IN_CELL = 1 18 | 19 | class AgentActions(object): 20 | WAIT = 0 21 | UP = 1 22 | DOWN = 2 23 | LEFT = 3 24 | RIGHT = 4 25 | 26 | class Quadrants(object): 27 | QUAD1 = 1 28 | QUAD2 = 2 29 | QUAD3 = 3 30 | QUAD4 = 4 31 | 32 | COLORS = ['red', 'green', 'blue', 'black', 'white', 'magenta', 'cyan', 'yellow'] 33 | -------------------------------------------------------------------------------- /_remnants_/macros.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gavincangan/multiagent-pathfinding/a151f5fb81e925988fc2f228ee1916cc6e254629/_remnants_/macros.pyc -------------------------------------------------------------------------------- /_remnants_/visualize.py: -------------------------------------------------------------------------------- 1 | from macros import * 2 | import numpy as np 3 | from world import world 4 | import grid_agent as ga 5 | from Tkinter import * 6 | 7 | class Visualize: 8 | def __init__(self, world_data): 9 | self.frame = Tk() 10 | self.canvas = Canvas(self.frame, width=FRAME_WIDTH, height=FRAME_HEIGHT) 11 | self.canvas.grid() 12 | self.world = world_data 13 | world_data.visualize = self 14 | self.cell_h, self.cell_w = self.get_cell_size() 15 | self.agent_h, self.agent_w = self.get_agent_size(1) 16 | 
self.vis_world_ptr = [] 17 | 18 | def draw_world(self): 19 | nrows, ncols = self.world.get_size() 20 | for row in range(nrows): 21 | curr_row = [] 22 | for col in range(ncols): 23 | cell = self.canvas.create_rectangle(FRAME_MARGIN + self.cell_w * col, FRAME_MARGIN + self.cell_h * row, FRAME_MARGIN + self.cell_w * (col+1), FRAME_MARGIN + self.cell_h * (row+1) ) 24 | curr_row.append(cell) 25 | self.vis_world_ptr.append(curr_row) 26 | 27 | def get_pos_in_cell(self, crow, ccol, index, nagents): 28 | if(MAX_AGENTS_IN_CELL == 1): 29 | agent_h = self.agent_h 30 | agent_w = self.agent_w 31 | agent_y1 = FRAME_MARGIN + (crow * self.cell_h) + CELL_MARGIN 32 | agent_y2 = agent_y1 + agent_h 33 | agent_x1 = FRAME_MARGIN + (ccol * self.cell_w) + CELL_MARGIN 34 | agent_x2 = agent_x1 + agent_w 35 | elif(MAX_AGENTS_IN_CELL < 5): 36 | agent_h, agent_w = self.get_agent_size(MAX_AGENTS_IN_CELL) 37 | agent_y1 = FRAME_MARGIN + (crow * self.cell_h) + CELL_MARGIN + ((index/2) * (CELL_MARGIN + agent_h)) 38 | agent_y2 = agent_y1 + agent_h 39 | agent_x1 = FRAME_MARGIN + (ccol * self.cell_w) + CELL_MARGIN + ((index%2) * (CELL_MARGIN + agent_w)) 40 | agent_x2 = agent_x1 + agent_w 41 | else: 42 | raise NotImplementedError 43 | return (agent_y1, agent_x1, agent_y2, agent_x2) 44 | 45 | def draw_agents(self): 46 | for crow in range(self.world.nrows): 47 | for ccol in range(self.world.ncols): 48 | cell = self.world.ptr_map[crow][ccol] 49 | if(cell): 50 | nagents = len(cell) 51 | for agent in range(nagents): 52 | y1, x1, y2, x2 = self.get_pos_in_cell(crow, ccol, agent, nagents) 53 | cell[agent].vis_obj = self.canvas.create_oval(x1, y1, x2, y2, fill=COLORS[cell[agent].aindex]) 54 | # print gy, gx, self.vis_world_ptr 55 | goal_cell = self.vis_world_ptr[cell[agent].gy][cell[agent].gx] 56 | self.canvas.itemconfig(goal_cell, outline=COLORS[cell[agent].aindex], width=2) 57 | 58 | def move_agent_vis(self, agent_obj, vis_obj, orow, ocol, crow, ccol): 59 | ocell = self.world.ptr_map[orow][ocol] 60 | ncell = 
self.world.ptr_map[crow][ccol] 61 | if(ocell): 62 | nagents = len(ocell) 63 | for agent in range(nagents): 64 | y1, x1, y2, x2 = self.get_pos_in_cell(orow, ocol, agent, nagents) 65 | self.canvas.coords(ocell[agent].vis_obj, x1, y1, x2, y2) 66 | if(ncell): 67 | nagents = len(ncell) 68 | for agent in range(nagents): 69 | y1, x1, y2, x2 = self.get_pos_in_cell(crow, ccol, agent, nagents) 70 | self.canvas.coords(ncell[agent].vis_obj, x1, y1, x2, y2) 71 | 72 | def get_cell_size(self): 73 | avail_h = FRAME_HEIGHT - 2 * FRAME_MARGIN 74 | avail_w = FRAME_WIDTH - 2 * FRAME_MARGIN 75 | nrows, ncols = self.world.get_size() 76 | cell_h = avail_h / nrows 77 | cell_w = avail_w / ncols 78 | return (cell_h, cell_w) 79 | 80 | def get_agent_size(self, nagents): 81 | if(MAX_AGENTS_IN_CELL == 1): 82 | agent_h = self.cell_h - 2 * CELL_MARGIN 83 | agent_w = self.cell_w - 2 * CELL_MARGIN 84 | elif(MAX_AGENTS_IN_CELL < 5): 85 | agent_h = (self.cell_h - 3 * CELL_MARGIN) / 2 86 | agent_w = (self.cell_w - 3 * CELL_MARGIN) / 2 87 | else: 88 | raise NotImplementedError 89 | return (agent_h, agent_w) 90 | 91 | def do_loop(self): 92 | self.frame.mainloop() 93 | 94 | def do_pack(self): 95 | self.canvas.pack() 96 | -------------------------------------------------------------------------------- /_remnants_/visualize.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gavincangan/multiagent-pathfinding/a151f5fb81e925988fc2f228ee1916cc6e254629/_remnants_/visualize.pyc -------------------------------------------------------------------------------- /_remnants_/world.py: -------------------------------------------------------------------------------- 1 | from macros import * 2 | import numpy as np 3 | import grid_agent as ga 4 | 5 | class world: 6 | def __init__(self, nrows, ncols): 7 | self.nrows = nrows 8 | self.ncols = ncols 9 | self.occ_map = np.zeros((nrows, ncols)) #occupancy map 10 | self.ptr_map = [[[] for row in range(ncols)] 
for row in range(nrows)] 11 | self.agents = [] 12 | self.visualize = None 13 | 14 | def add_agent(self, agent_obj, pos_y, pos_x, goal_y, goal_x): 15 | pos_x, pos_y = self.xy_saturate(pos_x, pos_y) 16 | # goal_x, goal_y = self.xy_saturate(goal_x, goal_y) 17 | agent_obj.update_position(pos_y, pos_x) 18 | print 'Adding: ', str(agent_obj) 19 | self.agents.append(agent_obj) 20 | 21 | def new_agent(self, pos_row, pos_col, goal_row, goal_col): 22 | agent_obj = ga.CentralizedAgent(self, pos_row, pos_col, goal_row, goal_col) 23 | #agent_obj.update_position(self, pos_row, pos_col) 24 | return agent_obj 25 | 26 | def move_agent(self, agent_obj, move_cmd): 27 | agent_obj.move(move_cmd) 28 | raise NotImplementedError 29 | 30 | def rm_agent(self, agent_obj): 31 | self.agents.remove(agent_obj) 32 | 33 | def get_size(self): 34 | return (self.nrows, self.ncols) 35 | 36 | def xy_saturate(self, x,y): 37 | if(x<0): x=0 38 | if(x>self.ncols-1): x=self.ncols-1 39 | if(y<0): y=0 40 | if(y>self.nrows-1): y=self.nrows-1 41 | return(x, y) 42 | 43 | def occ_map_view(self, y, x, dy, dx): 44 | # print '$$', x, dx, y, dy 45 | if(dx < 0): 46 | x = x + dx 47 | dx = dx * (-1) 48 | if(dy < 0): 49 | y = y + dy 50 | dy = dy * (-1) 51 | x1 = x + dx 52 | y1 = y + dy 53 | x, y = self.xy_saturate(x, y) 54 | x1, y1 = self.xy_saturate(x1, y1) 55 | # print '##', x,x1,y,y1,'\n',self.occ_map, '\n', self.occ_map[y: y1, x :x1], '\n\t#$#' 56 | return (y, x, y1, x1, self.occ_map[y: y1, x :x1]) 57 | 58 | def agents_in_range(self, y1, x1, y2, x2): 59 | # print(y1, x1, y2, x2) 60 | x1,y1 = self.xy_saturate(x1, y1) 61 | x2,y2 = self.xy_saturate(x2, y2) 62 | if(x2 > x1): 63 | sx = x1 64 | bx = x2 65 | else: 66 | sx = x2 67 | bx = x1 68 | if(y2 > y1): 69 | sy = y1 70 | by = y2 71 | else: 72 | sy = y2 73 | by = y1 74 | # print(sy, sx, by, bx) 75 | # print(sy, sx, by-sy+1, bx-sx+1) 76 | ptr_map_range = self.ptr_map[sy: by-sy+1] 77 | ptr_map_range = [ cells_row[sx:bx-sx+1] for cells_row in ptr_map_range ] 78 | list_agents 
= [] 79 | for row in ptr_map_range: 80 | for cell in row: 81 | if cell: #is not empty 82 | for agent in cell: 83 | list_agents.append(agent) 84 | return list_agents 85 | 86 | def list_all_agents(self): 87 | return self.agents 88 | 89 | def print_all_agents(self): 90 | for agent in self.agents: 91 | print(str(agent)) 92 | -------------------------------------------------------------------------------- /_remnants_/world.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gavincangan/multiagent-pathfinding/a151f5fb81e925988fc2f228ee1916cc6e254629/_remnants_/world.pyc -------------------------------------------------------------------------------- /astar.py: -------------------------------------------------------------------------------- 1 | import pqueue 2 | 3 | def manhattan_dist(a, b): 4 | """ 5 | Returns the Manhattan distance between two points. 6 | 7 | >>> manhattan_dist((0, 0), (5, 5)) 8 | 10 9 | >>> manhattan_dist((0, 5), (10, 7)) 10 | 12 11 | >>> manhattan_dist((12, 9), (2, 3)) 12 | 16 13 | >>> manhattan_dist((0, 5), (5, 0)) 14 | 10 15 | """ 16 | return abs(a[0] - b[0]) + abs(a[1] - b[1]) 17 | 18 | def find_path(neighbour_fn, 19 | start, 20 | end, 21 | cost = lambda pos: 1, 22 | passable = lambda pos: True, 23 | heuristic = manhattan_dist, 24 | stopCondOr = lambda x=0: False, 25 | stopCondAnd = lambda x=0: True, 26 | costs = 0): 27 | """ 28 | Returns the path between two nodes as a list of nodes using the A* 29 | algorithm. 30 | If no path could be found, an empty list is returned. 31 | 32 | The cost function is how much it costs to leave the given node. This should 33 | always be greater than or equal to 1, or shortest path is not guaranteed. 34 | 35 | The passable function returns whether the given node is passable. 36 | 37 | The heuristic function takes two nodes and computes the distance between the 38 | two. 
Underestimates are guaranteed to provide an optimal path, but it may 39 | take longer to compute the path. Overestimates lead to faster path 40 | computations, but may not give an optimal path. 41 | """ 42 | # tiles to check (tuples of (x, y), cost) 43 | todo = pqueue.PQueue() 44 | todo.update(start, 0) 45 | 46 | # tiles we've been to 47 | visited = set() 48 | 49 | # associated G and H costs for each tile (tuples of G, H) 50 | if(not costs): 51 | costs = { start: (0, heuristic(start, end)) } 52 | 53 | # parents for each tile 54 | parents = {} 55 | 56 | while ( ( (todo and (end not in visited) ) and stopCondAnd()) or stopCondOr()): 57 | cur, c = todo.pop_smallest() 58 | 59 | visited.add(cur) 60 | 61 | # check neighbours 62 | for n in neighbour_fn(cur): 63 | # skip it if we've already checked it, or if it isn't passable 64 | if ((n in visited) or 65 | (not passable(n))): 66 | continue 67 | 68 | if not (n in todo): 69 | # we haven't looked at this tile yet, so calculate its costs 70 | g = costs[cur][0] + cost(cur) 71 | h = heuristic(n, end) 72 | costs[n] = (g, h) 73 | parents[n] = cur 74 | todo.update(n, g + h) 75 | else: 76 | # if we've found a better path, update it 77 | g, h = costs[n] 78 | new_g = costs[cur][0] + cost(cur) 79 | if new_g < g: 80 | g = new_g 81 | todo.update(n, g + h) 82 | costs[n] = (g, h) 83 | parents[n] = cur 84 | 85 | # we didn't find a path 86 | if end not in visited: 87 | return [] 88 | 89 | # build the path backward 90 | path = [] 91 | while end != start: 92 | path.append(end) 93 | end = parents[end] 94 | path.append(start) 95 | path.reverse() 96 | 97 | return path, (len(path) - 1) 98 | -------------------------------------------------------------------------------- /cbsearch.py: -------------------------------------------------------------------------------- 1 | from macros import * 2 | from gworld import * 3 | from visualize import * 4 | import astar 5 | import m_astar 6 | import random 7 | 8 | def get_m_astar_path(world, start, goal, 
constraints = None): 9 | ret_path = m_astar.find_path(world.get_nbor_cells, 10 | start, 11 | goal, 12 | lambda cell: 1, 13 | lambda cell, constraints = None: world.passable( cell, constraints ), 14 | world.tyx_dist_heuristic, 15 | constraints) 16 | return ret_path 17 | 18 | def get_astar_path(world, start, goal): 19 | ret_path, pathcost = astar.find_path(world.get_nbor_cells, 20 | start, 21 | goal, 22 | lambda cell: 1, 23 | lambda cell: world.passable( cell ) ) 24 | return ret_path, pathcost 25 | 26 | def path_spacetime_conv(path_yx, tstart = 0): 27 | path_tyx = [] 28 | tcurr = tstart 29 | for step_yx in path_yx: 30 | step_tyx = ( tcurr, step_yx[0], step_yx[1] ) 31 | path_tyx.append(step_tyx) 32 | tcurr = tcurr + 1 33 | return (tcurr - tstart), path_tyx 34 | 35 | def cell_spacetime_conv(cell, t): 36 | return ( (t, cell[0], cell[1]) ) 37 | 38 | def get_max_pathlen(agents, path_seq): 39 | max_pathlen = 0 40 | for agent in agents: 41 | pathlen = len(path_seq[agent]) 42 | max_pathlen = pathlen if pathlen > max_pathlen else max_pathlen 43 | return max_pathlen 44 | 45 | def path_equalize(agents, path_seq, max_pathlen = -1): 46 | if(max_pathlen < 0): 47 | max_pathlen = get_maxpathlen(agents, path_seq) 48 | for agent in agents: 49 | path = path_seq[agent] 50 | lstep = path[-1] 51 | for step in range(len(path), max_pathlen + TWAIT): 52 | path.append( ( step, lstep[1], lstep[2] ) ) 53 | path_seq[agent] = path 54 | return path_seq 55 | 56 | def steptime_agtb(a, b): 57 | if(a[0] > b[0]): return True 58 | return False 59 | 60 | def tplusone(step): 61 | return ( (step[0]+1, step[1], step[2]) ) 62 | 63 | def get_conflicts(agents, path_seq, conflicts_db = None): 64 | tyx_map = dict() 65 | if(not bool(conflicts_db)): 66 | conflicts_db = dict() 67 | random.shuffle(agents) 68 | for agent in agents: 69 | if(agent not in conflicts_db): 70 | conflicts_db[agent] = set() 71 | if(path_seq[agent]): 72 | pathlen = len(path_seq[agent]) 73 | for t, tstep in enumerate(path_seq[agent]): 74 | 
twosteps = [tstep] #, tplusone(tstep)] 75 | if(t > 0 ): twosteps.append( tplusone(path_seq[agent][t-1]) ) 76 | for step in twosteps: 77 | # print 'bTYXMap: ', tyx_map 78 | if(step not in tyx_map): 79 | tyx_map[step] = agent 80 | else: 81 | otheragent = tyx_map[step] 82 | if(step not in conflicts_db[agent] and agent!=otheragent): 83 | conflicts_db[agent].update( {step} ) 84 | # if(t > 0): conflicts_db[agent].update( { tplusone( path_seq[agent][t-1] ) } ) 85 | # if(bool(conflicts_db[otheragent])): 86 | # otherconflict = conflicts_db[otheragent] 87 | # # if( steptime_agtb(otherconflict, step) ): 88 | # if(step not in conflicts_db[otheragent]): 89 | # conflicts_db[otheragent].update( {step} ) 90 | # else: 91 | # conflicts_db[otheragent].update( {step} ) 92 | # print 'bTYXMap: ', tyx_map 93 | return conflicts_db 94 | 95 | def evaluate_path(path_seq, agent, conflicts_db): 96 | all_okay = True 97 | tpath = path_seq[agent] 98 | tconstraints = conflicts_db[agent] 99 | for constraint in tconstraints: 100 | if(constraint in tpath): 101 | all_okay = False 102 | 103 | def search(agents, world): 104 | path_seq = dict() 105 | pathcost = dict() 106 | agent_goal = dict() 107 | max_pathlen = 0 108 | restart_loop = False 109 | 110 | for agent in agents: 111 | start = world.aindx_cpos[agent] 112 | goal = world.aindx_goal[agent] 113 | pathseq_yx, pathcost[agent] = get_astar_path(world, start, goal) 114 | pathlen, path_seq[agent] = path_spacetime_conv( pathseq_yx ) 115 | max_pathlen = pathlen if pathlen > max_pathlen else max_pathlen 116 | 117 | conflicts_db = get_conflicts(agents, path_seq) 118 | 119 | iter_count = 1 120 | pickd_agents = [] 121 | while(True): # iter_count < 5): 122 | max_pathlen = get_max_pathlen(agents, path_seq) 123 | path_seq = path_equalize(agents, path_seq, max_pathlen) 124 | 125 | if(iter_count % 2 == 1): 126 | pickd_agents = [] 127 | nagents = len(agents) 128 | random.shuffle(agents) 129 | pickd_agents = agents[(nagents/2):] 130 | else: 131 | temp_pickd_agents = 
[] 132 | for agent in agents: 133 | if agent not in pickd_agents: 134 | temp_pickd_agents.append(agent) 135 | pickd_agents = temp_pickd_agents 136 | 137 | if(restart_loop): 138 | restart_loop = False 139 | print '\n\nStuck between a rock and a hard place?\nRapid Random Restart to the rescue!\n\n' 140 | # something = input('Press 1 + to continue...') 141 | for agent in agents: 142 | conflicts_db[agent] = set() 143 | start = world.aindx_cpos[agent] 144 | goal = world.aindx_goal[agent] 145 | pathseq_yx, pathcost[agent] = get_astar_path(world, start, goal) 146 | pathlen, path_seq[agent] = path_spacetime_conv( pathseq_yx ) 147 | max_pathlen = pathlen if pathlen > max_pathlen else max_pathlen 148 | 149 | conflicts_db = get_conflicts(agents, path_seq, conflicts_db) 150 | 151 | for agent in pickd_agents: 152 | if (agent in conflicts_db): 153 | constraints = conflicts_db[agent] 154 | constraints.update({}) 155 | if(bool(constraints)): 156 | start = cell_spacetime_conv(world.aindx_cpos[agent], 0) 157 | goal = cell_spacetime_conv(world.aindx_goal[agent], SOMETIME) 158 | print 'Agent',agent,': S',start, ' G', goal, '\n\t C', constraints, '\n\t OP', path_seq[agent] 159 | nw_path, nw_pathlen = get_m_astar_path(world, start, goal, constraints) 160 | if(nw_path): 161 | path_seq[agent] = nw_path 162 | evaluate_path(path_seq, agent, conflicts_db) 163 | else: 164 | path_seq[agent] = [start] 165 | restart_loop = True 166 | print 'Agent',agent,': S',start, ' G', goal, '\n\t C', constraints, '\n\t NP', nw_path, 'Len: ', nw_pathlen 167 | 168 | if not restart_loop: 169 | path_seq = path_equalize(agents, path_seq, SOMETIME) 170 | conflicts_db = get_conflicts(agents, path_seq, conflicts_db) 171 | 172 | break_loop = True 173 | for agent in agents: 174 | ubrokn_conflicts = [] 175 | constraints = conflicts_db[agent] 176 | for step in path_seq[agent]: 177 | if(step in constraints): 178 | ubrokn_conflicts.append(step) 179 | if(ubrokn_conflicts): 180 | print '## A', agent, 'UC:', ubrokn_conflicts 
181 | print 'Yes, there are conflicts!' 182 | break_loop = False 183 | goal = cell_spacetime_conv(world.aindx_goal[agent], SOMETIME) 184 | if(path_seq[agent][-1] != goal): 185 | break_loop = False 186 | iter_count = iter_count + 1 187 | 188 | if(break_loop and not restart_loop): 189 | print 'Loop break!' 190 | break 191 | 192 | # something = input('Press any key to continue...') 193 | 194 | for agent in agents: 195 | print '\nAgent ', agent, ' cost:',pathcost[agent], ' Path -- ', path_seq[agent] 196 | 197 | for agent in agents: 198 | if agent in conflicts_db: 199 | print '\nAgent ', agent, ' Conflicts -- ', conflicts_db[agent] 200 | 201 | return path_seq 202 | -------------------------------------------------------------------------------- /gworld.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from macros import * 4 | from visualize import * 5 | 6 | class GridWorld: 7 | def __init__(self, h, w, rocks = None, agent_sng = None): 8 | self.h = h 9 | self.w = w 10 | self.cells = np.zeros((h, w), dtype=int) 11 | self.visualize = None 12 | self.add_rocks(rocks) 13 | 14 | self.aindx_cpos = dict() 15 | self.aindx_goal = dict() 16 | self.tyx_res = dict() 17 | 18 | def xy_saturate(self, x,y): 19 | if(x<0): x=0 20 | if(x>self.w-1): x=self.w-1 21 | if(y<0): y=0 22 | if(y>self.h-1): y=self.h-1 23 | return(x, y) 24 | 25 | def add_rocks(self, rocks): 26 | if rocks: 27 | for rock in rocks: 28 | rockx, rocky = self.xy_saturate(rock[1], rock[0]) 29 | if( not self.is_blocked(rocky, rockx) ): 30 | self.cells[rocky][rockx] = IS_ROCK 31 | 32 | ''' 33 | agent_sng - (sy, sx, gy, gx) 34 | -- start and goal positions for each agent 35 | ''' 36 | def add_agents(self, agents_sng): 37 | if agents_sng: 38 | print agents_sng 39 | # Replace list of tuples with a dict lookup for better performance 40 | for (sy, sx, gy, gx) in agents_sng: 41 | nagents = len( self.aindx_cpos.keys() ) 42 | if(not self.is_blocked(sy, sx) and not 
self.is_blocked(gy, gx)): 43 | if(self.cells[sy][sx] == UNOCCUPIED): 44 | self.aindx_cpos[nagents + 1] = (sy, sx) 45 | self.cells[sy][sx] = nagents + 1 46 | self.aindx_goal[nagents + 1] = (gy, gx) 47 | else: 48 | raise Exception('Cell has already been occupied!') 49 | else: 50 | raise Exception( 'Failure! agent index: ' + str(nagents + 1) ) 51 | return False 52 | return True 53 | return False 54 | 55 | def path_to_action(self, aindx, path): 56 | actions = [] 57 | cy, cx = self.aindx_cpos[aindx] 58 | for step in path: 59 | ty, tx = step[1], step[2] 60 | if(tx - cx == 1): action = Actions.RIGHT 61 | elif(tx - cx == -1): action = Actions.LEFT 62 | elif(ty - cy == 1): action = Actions.DOWN 63 | elif(ty - cy == -1): action = Actions.UP 64 | else: action = Actions.WAIT 65 | # print 'ToAction: ', cy, cx, ty, tx, tt, action 66 | actions.append(action) 67 | cy, cx = ty, tx 68 | return actions 69 | 70 | def is_validpos(self, y, x): 71 | if x < 0 or x > self.w - 1 or y < 0 or y > self.h - 1: 72 | return False 73 | else: 74 | return True 75 | 76 | # def get_nbor_cells(self, cell_pos): 77 | # y, x = cell_pos[0], cell_pos[1] 78 | # nbor_cells = [] 79 | # if(x > 0): 80 | # nbor_cells.append((y, x-1)) 81 | # if(x < self.w - 1): 82 | # nbor_cells.append((y, x+1)) 83 | # if(y > 0): 84 | # nbor_cells.append((y-1, x)) 85 | # if(y < self.h - 1): 86 | # nbor_cells.append((y+1, x)) 87 | # return nbor_cells 88 | 89 | def get_nbor_cells(self, cell_pos): 90 | nbor_cells = [] 91 | if(len(cell_pos) == 3): 92 | t, y, x= cell_pos[0], cell_pos[1], cell_pos[2] 93 | if(t > MAX_STEPS): 94 | print 'cell = ', cell_pos 95 | raise EnvironmentError 96 | if(x > 0): 97 | nbor_cells.append((t+1, y, x-1)) 98 | if(x < self.w - 1): 99 | nbor_cells.append((t+1, y, x+1)) 100 | if(y > 0): 101 | nbor_cells.append((t+1, y-1, x)) 102 | if(y < self.h - 1): 103 | nbor_cells.append((t+1, y+1, x)) 104 | nbor_cells.append((t+1, y, x)) 105 | elif(len(cell_pos) == 2): 106 | y, x = cell_pos[0], cell_pos[1] 107 | if(x > 0): 
108 | nbor_cells.append((y, x-1)) 109 | if(x < self.w - 1): 110 | nbor_cells.append((y, x+1)) 111 | if(y > 0): 112 | nbor_cells.append((y-1, x)) 113 | if(y < self.h - 1): 114 | nbor_cells.append((y+1, x)) 115 | nbor_cells.append((y, x)) 116 | return nbor_cells 117 | 118 | def check_nbors(self, y, x): 119 | ''' 120 | Return contents of neighbors of given cell 121 | return: array [ RIGHT, UP, LEFT, DOWN, WAIT ] 122 | ''' 123 | nbors = np.ones(5, dtype = int ) * INVALID 124 | # x, y = self.xy_saturate(x, y) 125 | if(x > 0): 126 | nbors[Actions.LEFT] = self.cells[y][x-1] 127 | if(x < self.w - 1): 128 | nbors[Actions.RIGHT] = self.cells[y][x+1] 129 | if(y > 0): 130 | nbors[Actions.UP] = self.cells[y-1][x] 131 | if(y < self.h - 1): 132 | nbors[Actions.DOWN] = self.cells[y+1][x] 133 | nbors[Actions.WAIT] = self.cells[y][x] 134 | return nbors 135 | 136 | def is_blocked(self, y, x): 137 | # print 'Cell :', y, x 138 | if not self.is_validpos(y, x): return True 139 | if(self.cells[y][x] == IS_ROCK): return True 140 | return False 141 | 142 | def agent_action(self, aindx, action): 143 | if(aindx in self.aindx_cpos): 144 | y, x = self.aindx_cpos[aindx] 145 | else: 146 | raise Exception('Agent ' + str(aindx) + ' does not exist!') 147 | oy, ox = y, x 148 | nbors = self.check_nbors(y, x) 149 | # print 'DoAction: ', aindx, y, x, nbors, action, 150 | if(nbors[action] == UNOCCUPIED): 151 | # if(nbors[action] != IS_ROCK and nbors[action] != INVALID): 152 | y += int(action == Actions.DOWN) - int(action == Actions.UP) 153 | x += int(action == Actions.RIGHT) - int(action == Actions.LEFT) 154 | self.aindx_cpos[aindx] = (y, x) 155 | self.cells[oy][ox] = 0 156 | self.cells[y][x] = aindx 157 | if(self.visualize): self.visualize.update_agent_vis(aindx) 158 | elif(action == Actions.WAIT): 159 | return (-1) 160 | else: 161 | # print 'DoAction: ', aindx, y, x, nbors, action 162 | raise Exception('Cell is not unoccupied! 
: (' + str(y) + ',' + str(x) + ') --> ' + str(action) ) 163 | return (0) if self.aindx_cpos[aindx] == self.aindx_goal[aindx] else (-1) 164 | 165 | def passable(self, cell, constraints = None): 166 | retValue = False 167 | if(len(cell) == 3): 168 | t, y, x = cell[0], cell[1], cell[2] 169 | if(self.is_blocked(y,x)): 170 | retValue = False 171 | elif(t > tLIMIT): 172 | retValue = False 173 | elif(bool(constraints)): 174 | if(cell in constraints): 175 | retValue = False 176 | else: 177 | retValue = True 178 | else: 179 | retValue = True 180 | # if(bool(constraints)): 181 | # print '\n ##', cell, '::', constraints,' RETURN:' ,retValue 182 | return retValue 183 | elif(len(cell) == 2): 184 | y, x = cell[0], cell[1] 185 | if(self.is_blocked(y,x)): 186 | retValue = False 187 | elif(bool(constraints)): 188 | if(cell in constraints): 189 | retValue = False 190 | else: 191 | retValue = True 192 | else: 193 | retValue = True 194 | return retValue 195 | 196 | # @staticmethod 197 | def tyx_dist_heuristic(self, a, b): 198 | yx_dist = abs(a[1] - b[1]) + abs(a[2] - b[2]) 199 | if(a[0] == ANY_TIME or b[0] == ANY_TIME): t_dist = yx_dist/WAIT_FACTOR 200 | else: t_dist = ( abs(a[2] - b[2]) ) * int(yx_dist>0) 201 | return yx_dist + t_dist/WAIT_FACTOR 202 | 203 | def get_size(self): 204 | return (self.h, self.w) 205 | 206 | def get_agents(self): 207 | return self.aindx_cpos.keys() 208 | -------------------------------------------------------------------------------- /m_astar.py: -------------------------------------------------------------------------------- 1 | import pqueue 2 | 3 | def manhattan_dist(a, b): 4 | """ 5 | Returns the Manhattan distance between two points. 
6 | 7 | >>> manhattan_dist((0, 0), (5, 5)) 8 | 10 9 | >>> manhattan_dist((0, 5), (10, 7)) 10 | 12 11 | >>> manhattan_dist((12, 9), (2, 3)) 12 | 16 13 | >>> manhattan_dist((0, 5), (5, 0)) 14 | 10 15 | """ 16 | return abs(a[0] - b[0]) + abs(a[1] - b[1]) 17 | 18 | def extract_fn(a): 19 | # print 'a :', a, 'Extract :', a[:-1] 20 | return a 21 | # return a[1:] 22 | 23 | def find_path(neighbour_fn, 24 | start, 25 | end, 26 | cost = lambda pos: 1, 27 | passable = lambda pos, constraints = None : True, 28 | heuristic = manhattan_dist, 29 | constraints = None, 30 | extract = extract_fn): 31 | """ 32 | Returns the path between two nodes as a list of nodes using the A* 33 | algorithm. 34 | If no path could be found, an empty list is returned. 35 | 36 | The cost function is how much it costs to leave the given node. This should 37 | always be greater than or equal to 1, or shortest path is not guaranteed. 38 | 39 | The passable function returns whether the given node is passable. 40 | 41 | The heuristic function takes two nodes and computes the distance between the 42 | two. Underestimates are guaranteed to provide an optimal path, but it may 43 | take longer to compute the path. Overestimates lead to faster path 44 | computations, but may not give an optimal path. 
45 | """ 46 | # tiles to check (tuples of (x, y), cost) 47 | todo = pqueue.PQueue() 48 | todo.update(start, 0) 49 | 50 | # tiles we've been to 51 | visited = set() 52 | 53 | # associated G and H costs for each tile (tuples of G, H) 54 | costs = { start: (0, heuristic(start, end)) } 55 | 56 | # parents for each tile 57 | parents = {} 58 | 59 | if( heuristic(start, end) == 0 ): 60 | return [start] 61 | 62 | while todo and (extract(end) not in visited): 63 | cur, c = todo.pop_smallest() 64 | 65 | # print 'Current: ', cur, 'cost: ', sum(costs[cur]) 66 | # something = input('Press some key to continue...') 67 | 68 | visited.add(extract(cur)) 69 | 70 | # check neighbours 71 | for n in neighbour_fn(cur): 72 | # skip it if we've already checked it, or if it isn't passable 73 | if ((extract(n) in visited) or 74 | (not passable(n, constraints))): 75 | # print 'Nbor: ', n, (not passable(n, constraints)), (extract(n) in visited) 76 | continue 77 | 78 | if not (n in todo): 79 | # we haven't looked at this tile yet, so calculate its costs 80 | g = costs[cur][0] + cost(cur) 81 | h = heuristic(n, end) 82 | costs[n] = (g, h) 83 | parents[n] = cur 84 | todo.update(n, g + h) 85 | else: 86 | # if we've found a better path, update it 87 | g, h = costs[n] 88 | new_g = costs[cur][0] + cost(cur) 89 | if new_g < g: 90 | g = new_g 91 | todo.update(n, g + h) 92 | costs[n] = (g, h) 93 | parents[n] = cur 94 | # print '\nVisited: ', visited 95 | # print '\nParents: ', parents 96 | 97 | # we didn't find a path 98 | if extract(end) not in visited: 99 | return [], 32767 100 | 101 | # build the path backward 102 | path = [] 103 | while extract(end) != extract(start): 104 | path.append(end) 105 | end = parents[end] 106 | path.append(start) 107 | path.reverse() 108 | 109 | return path, sum(costs[start]) 110 | -------------------------------------------------------------------------------- /macros.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 
INVALID = -999       # content reported for off-grid neighbours (GridWorld.check_nbors)
HARD_PLACE = -999

ANY_TIME = -999      # special time value; GridWorld.tyx_dist_heuristic treats it separately
SOMETIME = 25
tLIMIT = 25          # GridWorld.passable rejects space-time cells with t > tLIMIT

TWAIT = 0
WAIT_FACTOR = 0.51   # divisor of the time term in GridWorld.tyx_dist_heuristic

MAX_STEPS = 45       # GridWorld.get_nbor_cells raises for cells with t > MAX_STEPS

UNOCCUPIED = 0       # content of an empty grid cell
IS_ROCK = -99        # content of a rock (blocked) grid cell

MOVE_SPEED = 1
MSG_BUFFER_SIZE = 3

FRAME_HEIGHT = 350   # Visualize canvas height
FRAME_WIDTH = 600    # Visualize canvas width

FRAME_MARGIN = 10    # gap between canvas edge and grid (Visualize)
CELL_MARGIN = 5      # gap between cell edge and agent oval (Visualize)

MAX_AGENTS_IN_CELL = 1


class Actions(object):
    # Action indices; they also index the array returned by GridWorld.check_nbors.
    RIGHT = 0
    UP = 1
    LEFT = 2
    DOWN = 3
    WAIT = 4


# Tk colour names; Visualize colours agent i with COLORS[i].
COLORS = ['red', 'green', 'blue', 'black', 'white', 'magenta', 'cyan', 'yellow']

# -------------------------------------------------------------------------------- /pqueue.py
def heapsort(l):
    """
    Return the values of ``l`` in ascending order, using PQueue as a heap.

    Each value is stored under its list index, so repeated values are fine.

    >>> heapsort([1, 6, 2, 8, 9, 14, 4, 7])
    [1, 2, 4, 6, 7, 8, 9, 14]
    >>> heapsort([3, 1, 3])
    [1, 3, 3]
    """
    q = PQueue()
    for (i, x) in enumerate(l):
        q.update(i, x)
    return [q.pop_smallest()[1] for _ in l]


def _parent(i):
    """
    Returns the parent node of the given node.
    """
    return (i - 1) // 2


def _lchild(i):
    """
    Returns the left child node of the given node.
    """
    return 2 * i + 1


def _rchild(i):
    """
    Returns the right child node of the given node.
    """
    return 2 * i + 2


def _children(i):
    """
    Returns the children of the given node as a tuple (left then right).
    """
    return (_lchild(i), _rchild(i))


class PQueue:
    """
    Priority queue implemented with dictionaries. Stores a set of keys and associated priorities.

    ``_heap`` is a binary min-heap of (key, priority) pairs, and ``_keyindex``
    maps each key to its position in ``_heap``. The optional
    ``tie_breaker(a, b)`` decides, for equal priorities, whether key ``a``
    should come before key ``b``.

    >>> q = PQueue()
    >>> q.is_empty()
    True
    >>> q.update("thing", 5)
    True
    >>> q.is_empty()
    False
    >>> q.update("another thing", 2)
    True
    >>> q.pop_smallest()
    ('another thing', 2)
    >>> q.update("thing", 100)
    False
    >>> q.update("something else", 110)
    True
    >>> q.update("something else", 8)
    True
    >>> "thing" in q
    True
    >>> "nothing" in q
    False
    >>> len(q)
    2
    >>> q.peek_smallest()
    ('thing', 5)
    >>> q.pop_smallest()
    ('thing', 5)
    >>> q.pop_smallest()
    ('something else', 8)
    >>> True if q else False
    False
    >>> q.is_empty()
    True
    >>> q.tie_breaker = lambda x,y: x[1] < y[1]
    >>> q.update(("A", 6), 5)
    True
    >>> q.update(("B", 1), 5)
    True
    >>> q.update(("C", 10), 1)
    True
    >>> q.update(("D", 4), 5)
    True
    >>> q.pop_smallest()[0][0]
    'C'
    >>> q.pop_smallest()[0][0]
    'B'
    >>> q.pop_smallest()[0][0]
    'D'
    >>> q.pop_smallest()[0][0]
    'A'

    """
    def __init__(self):
        self._heap = []
        self._keyindex = {}
        self.tie_breaker = None

    def __len__(self):
        return len(self._heap)

    def __contains__(self, key):
        return key in self._keyindex

    def _key(self, i):
        """
        Returns the key value of the given node.
        """
        return self._heap[i][0]

    def _priority(self, i):
        """
        Returns the priority of the given node.
        """
        return self._heap[i][1]

    def _swap(self, i, j):
        """
        Swap the positions of two nodes and update the key index.
        """
        (self._heap[i], self._heap[j]) = (self._heap[j], self._heap[i])
        (self._keyindex[self._key(i)], self._keyindex[self._key(j)]) = (self._keyindex[self._key(j)], self._keyindex[self._key(i)])

    def _heapify_down(self, i):
        """
        Solves heap violations starting at the given node, moving down the heap.
        """
        children = [c for c in _children(i) if c < len(self._heap)]

        # This is a leaf, so stop
        if not children:
            return

        # Get the minimum child
        min_child = min(children, key=self._priority)

        # If there are two children with the same priority, we need to break the tie
        if self.tie_breaker and len(children) == 2:
            c0 = children[0]
            c1 = children[1]
            if self._priority(c0) == self._priority(c1):
                min_child = c0 if self.tie_breaker(self._key(c0), self._key(c1)) else c1

        # Sort, if necessary
        a = self._priority(i)
        b = self._priority(min_child)
        if a > b or (self.tie_breaker and a == b and not self.tie_breaker(self._key(i), self._key(min_child))):
            # Swap with the minimum child and continue heapifying
            self._swap(i, min_child)
            self._heapify_down(min_child)

    def _heapify_up(self, i):
        """
        Solves heap violations starting at the given node, moving up the heap.
        """
        # This is the top of the heap, so stop.
        if i == 0:
            return

        parent = _parent(i)
        a = self._priority(i)
        b = self._priority(parent)
        if a < b or (self.tie_breaker and a == b and self.tie_breaker(self._key(i), self._key(parent))):
            self._swap(i, parent)
            self._heapify_up(parent)

    def peek_smallest(self):
        """
        Returns a tuple containing the key with the smallest priority and its associated priority.
        """
        return self._heap[0]

    def pop_smallest(self):
        """
        Removes the key with the smallest priority and returns a tuple containing the key and its associated priority.

        Raises:
            IndexError: if the queue is empty.

        >>> PQueue().pop_smallest()
        Traceback (most recent call last):
            ...
        IndexError: pop_smallest from an empty PQueue
        """
        if not self._heap:
            raise IndexError('pop_smallest from an empty PQueue')

        # Swap the last node to the front
        self._swap(0, len(self._heap) - 1)

        # Remove the smallest from the list
        (key, priority) = self._heap.pop()
        del self._keyindex[key]

        # Fix the heap
        self._heapify_down(0)

        return (key, priority)

    def update(self, key, priority):
        """
        update(key, priority)
        If priority is lower than or equal to the associated priority of key, then change it to the new priority. If it is higher, does nothing.

        If key is not in the priority queue, add it.

        Return True if a change was made, else False.
        """
        if key in self._keyindex:
            # Find key index in heap
            i = self._keyindex[key]

            # Make sure this lowers its priority
            if priority > self._priority(i):
                return False

            # Fix the heap
            self._heap[i] = (key, priority)
            self._heapify_up(i)
            return True
        else:
            self._heap.append((key, priority))
            self._keyindex[key] = len(self._heap) - 1
            self._heapify_up(len(self._heap) - 1)
            return True

    def is_empty(self):
        """
        Returns True if the queue is empty, else False.
        """
        return len(self) == 0


if __name__ == "__main__":
    import doctest
    doctest.testmod()

# -------------------------------------------------------------------------------- /pqueue0.py
"""Alias of the ``pqueue`` module.

Re-exports ``pqueue``'s public and helper names so existing ``import pqueue0``
callers keep working without maintaining a second copy of the heap code.
"""
from pqueue import heapsort, PQueue, _parent, _lchild, _rchild, _children

if __name__ == "__main__":
    import doctest
    import pqueue
    doctest.testmod(pqueue)

# -------------------------------------------------------------------------------- /ta_test.py
from gworld import *
from visualize import *
import astar

# One agent on a 5 x 10 grid that has to route around a wall of rocks.
a = GridWorld(5,10)

vis = Visualize(a)

a.add_agents( [ (1,1,2,2) ] )
a.add_rocks( [ (2,1),(1,2),(1,3),(1,4),(3,1),(4,1),(2,3),(3,3),(3,4) ] )

vis.draw_world()
vis.draw_agents()

vis.canvas.pack()
vis.canvas.update()
vis.canvas.after(1000)

path = astar.find_path(a.get_nbor_cells,
                       a.aindx_cpos[1],
                       a.aindx_goal[1],
                       lambda cell: 1,
                       lambda cell: not a.is_blocked( cell[0], cell[1] ) )

# print(x) with a single argument prints the same on Python 2 and 3.
print(path)
# Skip the start cell: actions describe the moves out of it.
actions = a.path_to_action(1, path[1:])

print(actions)

for action in actions:
    a.agent_action(1, action)
    vis.canvas.update()
    vis.canvas.after(1000)

print(a.cells)
vis.canvas.after(3000)

# -------------------------------------------------------------------------------- /tb_test.py
from gworld import *
from visualize import *
import m_astar
import cbsearch as cbs

# Python 2's input() eval()s the typed text, so a bare Enter raises
# SyntaxError. raw_input() just reads the line; Python 3 only has input().
try:
    read_line = raw_input
except NameError:
    read_line = input


def get_m_astar_path(world, start, goal, constraints = None):
    """
    Plan one agent's space-time path on ``world`` with ``m_astar.find_path``.

    Every step costs 1. The search uses the world's neighbour function,
    ``world.passable`` (which receives ``constraints``) and
    ``world.tyx_dist_heuristic``. Returns whatever ``m_astar.find_path``
    returns.
    """
    return m_astar.find_path(world.get_nbor_cells,
                             start,
                             goal,
                             lambda cell: 1,
                             world.passable,
                             world.tyx_dist_heuristic,
                             constraints)


## Go around block. Wait aside for agent1 to pass
## Takes too long. Need better conflict handling
# a = GridWorld(6,10)
# a.add_rocks( [ (2,1),(1,2),(1,3),(1,4),(3,1),(2,3),(3,3),(3,4) ] )
# a.add_agents( [ (1,0,3,2), (1,1,2,2) ] )

## 2 agents. Narrow path with an open slot on the wall
## Waits too long. Need better conflict handling
# a = GridWorld(6,10)
# a.add_rocks( [ (1,0),(1,1),(1,2),(1,3),(1,4),(2,5),(1,6),(1,7),(1,8),(1,9),(0,9) ] )
# a.add_agents( [ (0,0,0,8), (0,1,0,7) ] )

## 3 agents. Few rocks. More space to swerve around
a = GridWorld(6,10)
# a.add_rocks( [ (4,0),(4,1),(4,2),(1,7),(1,8),(1,9) ] )
a.add_rocks( [ (4,0),(4,1),(4,2),(4,3),(1,6),(1,7),(1,8),(1,9) ] )
a.add_agents( [ (0,7,5,1), (5,3,0,9), (0,3,5,9) ] )

## 3 agents. Single passable block
# a = GridWorld(6,10)
# a.add_rocks( [ (4,0),(4,1),(4,2),(4,3),(4,4),(3,4),(1,6),(1,7),(1,8),(1,9) ] )
# a.add_agents( [ (0,7,5,1), (5,3,0,9), (0,3,5,9) ] )

## 4 agents. Few rocks. More space to swerve around
## Need better conflict handling for an optimal path
# a = GridWorld(6,10)
# a.add_rocks( [ (4,0),(4,1),(4,2),(1,7),(1,8),(1,9) ] )
# a.add_agents( [ (0,7,5,1), (5,3,0,9), (0,3,5,9), (3,0,3,9) ] )


vis = Visualize(a)

vis.draw_world()
vis.draw_agents()

vis.canvas.pack()
vis.canvas.update()
vis.canvas.after(500)

agents = a.get_agents()

# One planned path per agent index.
path_seq = cbs.search(agents, a)

action_seq = dict((agent, a.path_to_action(agent, path_seq[agent])) for agent in agents)
path_maxlen = max([0] + [len(path_seq[agent]) for agent in agents])

read_line('Press Enter to continue...')

# Play all agents' actions in lock-step; agents with shorter plans stop early.
for step in range(path_maxlen):
    for agent in agents:
        actions = action_seq[agent]
        if step < len(actions):
            a.agent_action(agent, actions[step])
            vis.canvas.update()
            vis.canvas.after(150)
    vis.canvas.update()
    vis.canvas.after(500)

vis.canvas.update()
vis.canvas.after(5000)

# -------------------------------------------------------------------------------- /test.py
from gworld import *
from visualize import *
import m_astar
import cbsearch as cbs

# Python 2's input() eval()s the typed text, so a bare Enter raises
# SyntaxError. raw_input() just reads the line; Python 3 only has input().
try:
    read_line = raw_input
except NameError:
    read_line = input


def get_m_astar_path(world, start, goal, constraints = None):
    """
    Plan one agent's space-time path on ``world`` with ``m_astar.find_path``.

    Every step costs 1. The search uses the world's neighbour function,
    ``world.passable`` (which receives ``constraints``) and
    ``world.tyx_dist_heuristic``. Returns whatever ``m_astar.find_path``
    returns.
    """
    return m_astar.find_path(world.get_nbor_cells,
                             start,
                             goal,
                             lambda cell: 1,
                             world.passable,
                             world.tyx_dist_heuristic,
                             constraints)


## Go around block. Wait aside for agent1 to pass
## Takes too long. Need better conflict handling
# a = GridWorld(6,10)
# a.add_rocks( [ (2,1),(1,2),(1,3),(1,4),(3,1),(2,3),(3,3),(3,4) ] )
# a.add_agents( [ (1,0,3,2), (1,1,2,2) ] )

## 2 agents. Narrow path with an open slot on the wall
## Waits too long. Need better conflict handling
a = GridWorld(6,10)
a.add_rocks( [ (1,0),(1,1),(1,2),(1,3),(1,4),(2,5),(1,6),(1,7),(1,8),(1,9),(0,9) ] )
a.add_agents( [ (0,0,0,8), (0,1,0,7) ] )

## 3 agents. Few rocks. More space to swerve around
# a = GridWorld(6,10)
# a.add_rocks( [ (4,0),(4,1),(4,2),(1,7),(1,8),(1,9) ] )
# a.add_agents( [ (0,7,5,1), (5,3,0,9), (0,3,5,9) ] )

## 3 agents. Single passable block
# a = GridWorld(6,10)
# a.add_rocks( [ (4,0),(4,1),(4,2),(4,3),(4,4),(3,4),(1,6),(1,7),(1,8),(1,9) ] )
# a.add_agents( [ (0,7,5,1), (5,3,0,9), (0,3,5,9) ] )

## 4 agents. Few rocks. More space to swerve around
## Need better conflict handling for an optimal path
# a = GridWorld(6,10)
# a.add_rocks( [ (4,0),(4,1),(4,2),(1,7),(1,8),(1,9) ] )
# a.add_agents( [ (0,7,5,1), (5,3,0,9), (0,3,5,9), (3,0,3,9) ] )


vis = Visualize(a)

vis.draw_world()
vis.draw_agents()

vis.canvas.pack()
vis.canvas.update()
vis.canvas.after(500)

agents = a.get_agents()

# One planned path per agent index.
path_seq = cbs.search(agents, a)

action_seq = dict((agent, a.path_to_action(agent, path_seq[agent])) for agent in agents)
path_maxlen = max([0] + [len(path_seq[agent]) for agent in agents])

read_line('Press Enter to continue...')

# Play all agents' actions in lock-step; agents with shorter plans stop early.
for step in range(path_maxlen):
    for agent in agents:
        actions = action_seq[agent]
        if step < len(actions):
            a.agent_action(agent, actions[step])
            vis.canvas.update()
            vis.canvas.after(150)
    vis.canvas.update()
    vis.canvas.after(500)

vis.canvas.update()
vis.canvas.after(5000)

# -------------------------------------------------------------------------------- /ts_astar_test.py
from gworld import *
from visualize import *
import astar

# One agent on a 5 x 10 grid that has to route around a wall of rocks.
a = GridWorld(5,10)

vis = Visualize(a)

a.add_agents( [ (1,1,2,2) ] )
a.add_rocks( [ (2,1),(1,2),(1,3),(1,4),(3,1),(4,1),(2,3),(3,3),(3,4) ] )

vis.draw_world()
vis.draw_agents()

vis.canvas.pack()
vis.canvas.update()
vis.canvas.after(1000)

path = astar.find_path(a.get_nbor_cells,
                       a.aindx_cpos[1],
                       a.aindx_goal[1],
                       lambda cell: 1,
                       lambda cell: not a.is_blocked( cell[0], cell[1] ) )

# print(x) with a single argument prints the same on Python 2 and 3.
print(path)
# Skip the start cell: actions describe the moves out of it.
actions = a.path_to_action(1, path[1:])

print(actions)

for action in actions:
    a.agent_action(1, action)
    vis.canvas.update()
    vis.canvas.after(1000)

print(a.cells)
vis.canvas.after(3000)

# -------------------------------------------------------------------------------- /visualize.py
from macros import *
import numpy as np
from gworld import *
try:
    from Tkinter import *    # Python 2
except ImportError:
    from tkinter import *    # Python 3


class Visualize:
    """
    Tk canvas view of a GridWorld.

    Constructing it opens a Tk window and registers the instance as
    ``world_data.visualize``; GridWorld.agent_action calls
    ``update_agent_vis`` through that attribute after each move.
    """
    def __init__(self, world_data):
        self.frame = Tk()
        self.canvas = Canvas(self.frame, width=FRAME_WIDTH, height=FRAME_HEIGHT)
        self.canvas.grid()
        self.world = world_data
        world_data.visualize = self
        self.cell_h, self.cell_w = self.get_cell_size()
        self.agent_h, self.agent_w = self.get_agent_size(1)
        # Canvas item ids of the cell rectangles, same shape as world.cells.
        self.vis_cells = np.zeros_like(self.world.cells, dtype = int)
        # Agent index -> canvas item id of its oval.
        self.aindx_obj = dict()

    def draw_world(self):
        """Draw one rectangle per grid cell; rock cells are filled gray."""
        nrows, ncols = self.world.get_size()
        for row in range(nrows):
            for col in range(ncols):
                self.vis_cells[row][col] = self.canvas.create_rectangle(
                    FRAME_MARGIN + self.cell_w * col,
                    FRAME_MARGIN + self.cell_h * row,
                    FRAME_MARGIN + self.cell_w * (col+1),
                    FRAME_MARGIN + self.cell_h * (row+1) )
                if self.world.cells[row][col] == IS_ROCK:
                    self.canvas.itemconfig(self.vis_cells[row][col], fill='gray', width=2)

    def get_pos_in_cell(self, crow, ccol):
        """Return the (y1, x1, y2, x2) canvas box of an agent drawn in cell (crow, ccol)."""
        agent_y1 = FRAME_MARGIN + (crow * self.cell_h) + CELL_MARGIN
        agent_y2 = agent_y1 + self.agent_h
        agent_x1 = FRAME_MARGIN + (ccol * self.cell_w) + CELL_MARGIN
        agent_x2 = agent_x1 + self.agent_w
        return (agent_y1, agent_x1, agent_y2, agent_x2)

    def draw_agents(self):
        """Draw an oval for every agent and outline its goal cell in the same colour."""
        for crow in range(self.world.h):
            for ccol in range(self.world.w):
                cell = self.world.cells[crow][ccol]
                if( cell != UNOCCUPIED and not self.world.is_blocked(crow, ccol) ):
                    # Cycle through the palette so more than len(COLORS) agents still draw.
                    color = COLORS[cell % len(COLORS)]
                    y1, x1, y2, x2 = self.get_pos_in_cell(crow, ccol)
                    self.aindx_obj[cell] = self.canvas.create_oval(x1, y1, x2, y2, fill=color)
                    gy, gx = self.world.aindx_goal[cell]
                    goal_cell = self.vis_cells[gy][gx]
                    self.canvas.itemconfig(goal_cell, outline=color, width=4)

    def update_agent_vis(self, aindx):
        """Move agent ``aindx``'s oval to its current cell."""
        cy, cx = self.world.aindx_cpos[aindx]
        y1, x1, y2, x2 = self.get_pos_in_cell(cy, cx)
        self.canvas.coords(self.aindx_obj[aindx], x1, y1, x2, y2)

    def get_cell_size(self):
        """Return the (height, width) of one grid cell on the canvas."""
        avail_h = FRAME_HEIGHT - 2 * FRAME_MARGIN
        avail_w = FRAME_WIDTH - 2 * FRAME_MARGIN
        nrows, ncols = self.world.get_size()
        # Integer division keeps whole-pixel cells, as under Python 2's '/'.
        cell_h = avail_h // nrows
        cell_w = avail_w // ncols
        return (cell_h, cell_w)

    def get_agent_size(self, nagents):
        """Return the (height, width) of an agent oval; ``nagents`` is currently unused."""
        agent_h = self.cell_h - 2 * CELL_MARGIN
        agent_w = self.cell_w - 2 * CELL_MARGIN
        return (agent_h, agent_w)

    def do_loop(self):
        """Enter the Tk main loop."""
        self.frame.mainloop()

    def do_pack(self):
        """Pack the canvas into its window."""
        self.canvas.pack()

# --------------------------------------------------------------------------------