# ├── a_star.py
# ├── main.py
# ├── point.py
# └── random_map.py
#
# /a_star.py:
# --------------------------------------------------------------------------------
# a_star.py

import sys
import time

import numpy as np

from matplotlib.patches import Rectangle

import point
import random_map


class AStar:
    """A* path finder on a square, 8-connected grid map.

    Searches from the bottom-left cell (0, 0) to the top-right cell
    (size-1, size-1) of ``map``, which must provide ``size`` and
    ``IsObstacle(x, y)``.  Every expanded cell is drawn on the supplied
    matplotlib axes and a PNG frame is saved to the current directory.
    """

    def __init__(self, map):
        self.map = map
        self.open_set = []   # discovered points not yet expanded
        self.close_set = []  # points already expanded
        self._last_millis = -1  # timestamp of the last saved frame (SaveImage)

    @staticmethod
    def _OctileDistance(x_dis, y_dis):
        """Return the 8-connected (octile) distance for offsets x_dis, y_dis >= 0.

        Straight steps cost 1 and diagonal steps cost sqrt(2).
        """
        return x_dis + y_dis + (np.sqrt(2) - 2) * min(x_dis, y_dis)

    def BaseCost(self, p):
        """Return g(p): cost of the path from the start point to ``p``.

        The path is the one recorded through the ``parent`` links; a point
        without a parent is the start point and costs 0.  The walk stops at
        the first ancestor whose ``g`` was already stored (ProcessPoint and
        RunAndSaveImage store it), so normally only one step is summed.
        """
        total = 0.0
        node = p
        parent = getattr(node, 'parent', None)
        while parent is not None:
            total += self._OctileDistance(abs(node.x - parent.x),
                                          abs(node.y - parent.y))
            known_g = getattr(parent, 'g', None)
            if known_g is not None:
                return total + known_g
            node = parent
            parent = getattr(node, 'parent', None)
        return total

    def HeuristicCost(self, p):
        """Return h(p): octile distance from ``p`` to the end point.

        Admissible and consistent for 8-connected moves costing 1 / sqrt(2).
        """
        x_dis = abs(self.map.size - 1 - p.x)
        y_dis = abs(self.map.size - 1 - p.y)
        # Distance to end point
        return self._OctileDistance(x_dis, y_dis)

    def TotalCost(self, p):
        """Return f(p) = g(p) + h(p)."""
        return self.BaseCost(p) + self.HeuristicCost(p)

    def IsValidPoint(self, x, y):
        """Return True when (x, y) lies on the map and is not an obstacle."""
        if x < 0 or y < 0:
            return False
        if x >= self.map.size or y >= self.map.size:
            return False
        return not self.map.IsObstacle(x, y)

    def IsInPointList(self, p, point_list):
        """Return True when a point with p's coordinates is in ``point_list``."""
        return any(q.x == p.x and q.y == p.y for q in point_list)

    def IsInOpenList(self, p):
        return self.IsInPointList(p, self.open_set)

    def IsInCloseList(self, p):
        return self.IsInPointList(p, self.close_set)

    def IsStartPoint(self, p):
        return p.x == 0 and p.y == 0

    def IsEndPoint(self, p):
        return p.x == self.map.size - 1 and p.y == self.map.size - 1

    def _FindInOpenList(self, p):
        """Return the open-list point at p's coordinates, or None."""
        for q in self.open_set:
            if q.x == p.x and q.y == p.y:
                return q
        return None

    def SaveImage(self, plt):
        """Save the current figure as ./<milliseconds>.png.

        Several frames may be saved within one millisecond; the timestamp is
        bumped so a later frame never overwrites an earlier one and file
        names stay in chronological order.
        """
        millis = int(round(time.time() * 1000))
        last = getattr(self, '_last_millis', -1)
        if millis <= last:
            millis = last + 1
        self._last_millis = millis
        filename = './' + str(millis) + '.png'
        plt.savefig(filename)

    def ProcessPoint(self, x, y, parent):
        """Relax the neighbour (x, y) of ``parent``.

        Invalid (off-map or obstacle) and already-expanded cells are ignored.
        A new cell is added to the open list; a cell already in the open list
        is re-parented when the path through ``parent`` is cheaper.
        """
        if not self.IsValidPoint(x, y):
            return  # Do nothing for invalid point
        p = point.Point(x, y)
        if self.IsInCloseList(p):
            return  # Do nothing for visited point
        p.parent = parent
        p.g = self.BaseCost(p)
        p.cost = p.g + self.HeuristicCost(p)
        print('Process Point [', p.x, ',', p.y, ']', ', cost: ', p.cost)

        existing = self._FindInOpenList(p)
        if existing is None:
            self.open_set.append(p)
            return
        existing_g = getattr(existing, 'g', None)
        if existing_g is None:
            existing_g = self.BaseCost(existing)
        if p.g < existing_g:
            # Cheaper way to reach a frontier cell: re-parent it.
            existing.parent = parent
            existing.g = p.g
            existing.cost = p.cost

    def SelectPointInOpenList(self):
        """Return the index of the open-list point with the lowest f cost, or -1.

        Uses the ``cost`` stored on each point; ties go to the earliest one.
        """
        if not self.open_set:
            return -1
        return min(range(len(self.open_set)),
                   key=lambda i: self.open_set[i].cost)

    def BuildPath(self, p, ax, plt, start_time):
        """Draw the path ending at ``p`` in green, save a frame, report run time.

        The path is recovered by following ``parent`` links back to the start.
        """
        path = []
        while p is not None:
            path.append(p)
            if self.IsStartPoint(p):
                break
            p = getattr(p, 'parent', None)
        path.reverse()  # start -> end
        for step in path:
            ax.add_patch(Rectangle((step.x, step.y), 1, 1, color='g'))
        plt.draw()
        self.SaveImage(plt)
        end_time = time.time()
        print('===== Algorithm finish in', int(end_time - start_time), ' seconds')

    def RunAndSaveImage(self, ax, plt):
        """Run A* from (0, 0) to (size-1, size-1), drawing progress on ``ax``.

        Each expanded point is drawn in cyan and a frame is saved; on success
        the path is drawn by BuildPath.  Prints a message and returns None
        when no path exists.
        """
        start_time = time.time()

        start_point = point.Point(0, 0)
        start_point.parent = None
        start_point.g = 0.0
        start_point.cost = self.HeuristicCost(start_point)
        self.open_set.append(start_point)

        # Offsets of the 8 neighbours, in the original processing order.
        neighbours = ((-1, 1), (-1, 0), (-1, -1), (0, -1),
                      (1, -1), (1, 0), (1, 1), (0, 1))

        while True:
            index = self.SelectPointInOpenList()
            if index < 0:
                print('No path found, algorithm failed!!!')
                return
            p = self.open_set[index]
            ax.add_patch(Rectangle((p.x, p.y), 1, 1, color='c'))
            self.SaveImage(plt)

            if self.IsEndPoint(p):
                return self.BuildPath(p, ax, plt, start_time)

            del self.open_set[index]
            self.close_set.append(p)

            # Process all neighbors
            for dx, dy in neighbours:
                self.ProcessPoint(p.x + dx, p.y + dy, p)

# --------------------------------------------------------------------------------
# /main.py:
# --------------------------------------------------------------------------------
# main.py

import numpy as np
import matplotlib.pyplot as plt

from matplotlib.patches import Rectangle

import random_map
import a_star


def main():
    """Draw a random map, run A* on it and save one PNG frame per step."""
    plt.figure(figsize=(5, 5))

    grid_map = random_map.RandomMap()

    ax = plt.gca()
    ax.set_xlim([0, grid_map.size])
    ax.set_ylim([0, grid_map.size])

    # Obstacles are filled gray; free cells are white with a gray outline.
    for i in range(grid_map.size):
        for j in range(grid_map.size):
            if grid_map.IsObstacle(i, j):
                rec = Rectangle((i, j), width=1, height=1, color='gray')
            else:
                rec = Rectangle((i, j), width=1, height=1,
                                edgecolor='gray', facecolor='w')
            ax.add_patch(rec)

    # Start (blue) and end (red) cells.
    ax.add_patch(Rectangle((0, 0), width=1, height=1, facecolor='b'))
    ax.add_patch(Rectangle((grid_map.size - 1, grid_map.size - 1),
                           width=1, height=1, facecolor='r'))

    plt.axis('equal')
    plt.axis('off')
    plt.tight_layout()

    planner = a_star.AStar(grid_map)
    planner.RunAndSaveImage(ax, plt)


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /point.py:
# --------------------------------------------------------------------------------
# point.py

import sys


class Point:
    """A grid cell (x, y) carrying A* search bookkeeping."""

    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.cost = sys.maxsize  # total cost; "infinite" until the planner sets it
        self.parent = None       # predecessor on the best known path, if any

# --------------------------------------------------------------------------------
# /random_map.py:
# --------------------------------------------------------------------------------
# random_map.py

import numpy as np

import point


class RandomMap:
    """Square grid map of ``size`` x ``size`` cells with obstacles.

    Holds a fixed obstacle near the centre plus ``size // 8 - 1`` random
    horizontal/vertical bars of length ``size // 4`` (drawn with the global
    ``np.random`` state).  All obstacles lie inside the map and the start
    cell (0, 0) and end cell (size-1, size-1) are always kept free.
    """

    def __init__(self, size=50):
        self.size = size
        self.obstacle = size // 8  # controls how many random bars are drawn
        self.GenerateObstacle()

    def GenerateObstacle(self):
        """Fill ``self.obstacle_point`` with obstacle Points and index them."""
        self.obstacle_point = []
        half = self.size // 2
        self.obstacle_point.append(point.Point(half, half))
        self.obstacle_point.append(point.Point(half, half - 1))

        # Generate an obstacle in the middle
        for i in range(half - 4, half):
            self.obstacle_point.append(point.Point(i, self.size - i))
            self.obstacle_point.append(point.Point(i, self.size - i - 1))
            self.obstacle_point.append(point.Point(self.size - i, i))
            self.obstacle_point.append(point.Point(self.size - i, i - 1))

        # Random bars, each starting at a random cell.
        bar_length = self.size // 4
        for _ in range(self.obstacle - 1):
            x = np.random.randint(0, self.size)
            y = np.random.randint(0, self.size)
            self.obstacle_point.append(point.Point(x, y))

            if np.random.rand() > 0.5:  # Random boolean
                for k in range(bar_length):
                    self.obstacle_point.append(point.Point(x, y + k))
            else:
                for k in range(bar_length):
                    self.obstacle_point.append(point.Point(x + k, y))

        # Bars can run off the map or cover the start/end cells (which would
        # make the map unsolvable); drop those points.
        keep_free = {(0, 0), (self.size - 1, self.size - 1)}
        self.obstacle_point = [
            p for p in self.obstacle_point
            if 0 <= p.x < self.size and 0 <= p.y < self.size
            and (p.x, p.y) not in keep_free
        ]
        # Coordinate set so IsObstacle is O(1) rather than a scan per query.
        # NOTE: rebuilt only here; call GenerateObstacle again after editing
        # obstacle_point by hand.
        self._obstacle_cells = {(p.x, p.y) for p in self.obstacle_point}

    def IsObstacle(self, i, j):
        """Return True when cell (i, j) is an obstacle."""
        return (i, j) in self._obstacle_cells