├── README.md ├── model ├── __init__.py └── model.py ├── utils ├── __init__.py ├── utils.py └── plot.py ├── algorithm ├── alg_config.py ├── __init__.py ├── search_tree.py └── tsa.py ├── environment ├── __init__.py ├── env_config.py └── maze_env.py ├── trained_models ├── NEXT_2d.pt └── NEXT_3d.pt ├── maze_files ├── mazes_15_2_3000.npz └── mazes_15_3_3000.npz ├── LICENSE ├── .gitignore └── main.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # NEXT-learning-to-plan -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import * -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .plot import * -------------------------------------------------------------------------------- /algorithm/alg_config.py: -------------------------------------------------------------------------------- 1 | # RRT resolution 2 | RRT_EPS = 5e-2 3 | 4 | -------------------------------------------------------------------------------- /environment/__init__.py: -------------------------------------------------------------------------------- 1 | from .env_config import * 2 | from .maze_env import * -------------------------------------------------------------------------------- /algorithm/__init__.py: -------------------------------------------------------------------------------- 1 | from .search_tree import * 2 | from .tsa import * 3 | from .alg_config import * -------------------------------------------------------------------------------- /trained_models/NEXT_2d.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeurEXT/NEXT-learning-to-plan/HEAD/trained_models/NEXT_2d.pt -------------------------------------------------------------------------------- /trained_models/NEXT_3d.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeurEXT/NEXT-learning-to-plan/HEAD/trained_models/NEXT_3d.pt -------------------------------------------------------------------------------- /maze_files/mazes_15_2_3000.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeurEXT/NEXT-learning-to-plan/HEAD/maze_files/mazes_15_2_3000.npz -------------------------------------------------------------------------------- /maze_files/mazes_15_3_3000.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeurEXT/NEXT-learning-to-plan/HEAD/maze_files/mazes_15_3_3000.npz -------------------------------------------------------------------------------- /environment/env_config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from algorithm import RRT_EPS 3 | 4 | STICK_LENGTH = 1.5 * 2 / 15 5 | LIMITS = np.array([1., 1., 8.*RRT_EPS]) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import random 5 | 6 | def mkdir_if_not_exist(dir_name): 7 | if not os.path.exists(dir_name): 8 | os.makedirs(dir_name) 9 | 10 | def set_random_seed(seed): 11 | np.random.seed(seed) 12 | torch.manual_seed(seed) 13 | random.seed(seed) 14 | 15 | def load_model(model, file, use_cuda=True): 16 | if use_cuda: 17 | model.load_state_dict(torch.load(file)) 18 | else: 19 | model.load_state_dict(torch.load(file, map_location=lambda storage, loc: storage)) 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 NeurEXT 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | from matplotlib import cm 3 | import matplotlib.pyplot as plt 4 | import matplotlib.patches as patches 5 | import numpy as np 6 | 7 | from environment import MazeEnv 8 | 9 | def draw_node(state, color, radius_scale=1., dim=2, face=False): 10 | if dim == 2: 11 | facecolor = 'none' 12 | if face: 13 | facecolor = color 14 | circle = patches.Circle(tuple(state+1.0), radius=0.02 * radius_scale, edgecolor=color, facecolor=facecolor) 15 | plt.gca().add_patch(circle) 16 | 17 | elif dim == 3: 18 | a, b = MazeEnv._end_points(state) 19 | plt.gca().add_patch(patches.ConnectionPatch(a+1.0, b+1.0, 'data', arrowstyle="-", linewidth=2, color=color)) 20 | plt.gca().add_patch(patches.Circle(a+1.0, radius=0.02 * radius_scale, edgecolor=color, facecolor=color)) 21 | 22 | def draw_edge(state0, state1, color, dim=2, style='-'): 23 | path = patches.ConnectionPatch(tuple(state0[:2]+1.0), tuple(state1[:2]+1.0), 'data', arrowstyle=style, color=color) 24 | plt.gca().add_patch(path) 25 | 26 | def plot_tree(states, parents, problem, index=0, edge_classes=None): 27 | states = states 28 | environment_map = problem["map"] 29 | init_state = problem["init_state"] 30 | goal_state = problem["goal_state"] 31 | dim = init_state.size 32 | 33 | fig = plt.figure(figsize=(4,4)) 34 | 35 | rect = patches.Rectangle((0.0, 0.0), 2.0, 2.0, linewidth=1, edgecolor='black', facecolor='none') 36 | plt.gca().add_patch(rect) 37 | 38 | map_width = environment_map.shape 39 | d_x = 2.0 / map_width[0] 40 | d_y = 2.0 / map_width[1] 41 | for i in range(map_width[0]): 42 | for j in range(map_width[1]): 43 | if environment_map[i,j] > 0: 44 | rect = patches.Rectangle((d_x*i, d_y*j), d_x, d_y, linewidth=1, edgecolor='#253494', facecolor='#253494') 45 | plt.gca().add_patch(rect) 46 | 47 | for i in range(len(states)-1): 48 | draw_node(states[i+1], '#fdbe85', dim=dim) 49 | 50 | if edge_classes is None: 51 | draw_edge(states[i+1], states[parents[i+1]], 'green', dim=dim) 52 | else: 53 | if edge_classes[i+1] == True: 54 | draw_edge(states[i+1], states[parents[i+1]], 'blue', dim=dim) 55 | else: 56 | draw_edge(states[i+1], states[parents[i+1]], 'green', dim=dim) 57 | 58 | 59 | draw_node(init_state, '#e6550d', dim=dim, face=True) 60 | draw_node(goal_state, '#a63603', dim=dim, face=True) 61 | 62 | plt.axis([0.0, 2.0, 0.0, 2.0]) 63 | plt.axis('off') 64 | plt.axis('square') 65 | 66 | plt.subplots_adjust(left=-0., right=1.0, top=1.0, bottom=-0.) 67 | 68 | plt.show() 69 | -------------------------------------------------------------------------------- /algorithm/search_tree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .alg_config import RRT_EPS 4 | import torch 5 | 6 | class SearchTree: 7 | def __init__(self, env, root, model=None, dim=2): 8 | self.states = np.array([root]) 9 | self.parents = [None] 10 | self.rewired_parents = [None] 11 | self.expanded_by_rrt = [None] 12 | self.freesp = [True] 13 | self.costs = [0.] 14 | self.path_lengths = [-1] 15 | self.cumulated_collision_checks = [0] 16 | self.in_goal_region = [False] 17 | 18 | # for global exploration 19 | self.non_terminal_states = np.array([root]) 20 | self.non_terminal_idxes = [0] 21 | 22 | if model is not None: 23 | # Multi-armed bandit stats 24 | self.visits = [1] 25 | self.state_values = [model.pred_value(root)] 26 | 27 | # kernel regression stats 28 | self.w = [compute_w(env, self, idx=0)] 29 | self.w_sum = self.w[0] 30 | 31 | def update_collision_checks(search_tree, collision_checks): 32 | search_tree.cumulated_collision_checks.append(collision_checks) 33 | 34 | def rewire_to(search_tree, child_idx, new_parent_idx): 35 | search_tree.rewired_parents[child_idx] = new_parent_idx 36 | 37 | def set_cost(search_tree, idx, new_cost): 38 | search_tree.costs[idx] = new_cost 39 | 40 | # Update path length if a path is found. 41 | if idx == -1 and search_tree.in_goal_region[-1]: 42 | if search_tree.path_lengths[-1] < 0 or \ 43 | search_tree.path_lengths[-1] > new_cost: 44 | search_tree.path_lengths[-1] = new_cost 45 | 46 | def insert_new_state(env, search_tree, state, model, parent_idx, no_collision, \ 47 | done, expanded_by_rrt=False, use_GP=False): 48 | search_tree.states = np.append(search_tree.states, [state], axis=0) 49 | search_tree.parents.append(parent_idx) 50 | search_tree.rewired_parents.append(parent_idx) 51 | search_tree.expanded_by_rrt.append(expanded_by_rrt) 52 | search_tree.freesp.append(no_collision) 53 | search_tree.in_goal_region.append(done) 54 | 55 | # Will be updated in post-processing. 56 | search_tree.path_lengths.append(search_tree.path_lengths[-1]) 57 | search_tree.costs.append(-1) 58 | 59 | if no_collision and (not done): 60 | search_tree.non_terminal_states = np.append( \ 61 | search_tree.non_terminal_states, [state], axis=0) 62 | search_tree.non_terminal_idxes.append(search_tree.states.shape[0]-1) 63 | 64 | if model is not None: 65 | state_value = model.pred_value(state) 66 | search_tree.visits[parent_idx] += 1 67 | search_tree.visits.append(0) 68 | search_tree.state_values.append(state_value) 69 | 70 | search_tree.w_sum -= search_tree.w[parent_idx] 71 | parent_w = compute_w(env, search_tree, idx=parent_idx) 72 | search_tree.w[parent_idx] = parent_w 73 | search_tree.w_sum += parent_w 74 | 75 | w = compute_w(env, search_tree, state=state) 76 | search_tree.w.append(w) 77 | search_tree.w_sum += w 78 | 79 | return search_tree.states.shape[0]-1 80 | 81 | def state_kernel(env, state_A, state_B): 82 | diff = env.distance(state_A, state_B) / RRT_EPS 83 | kernel = np.exp( - (diff**2) * 1.) 84 | 85 | return kernel 86 | 87 | def compute_w(env, search_tree, idx=None, state=None): 88 | if state is None: 89 | state = search_tree.states[idx] 90 | 91 | kernel = state_kernel(env, search_tree.states, state) 92 | w_ = np.sum(kernel) 93 | 94 | return w_ 95 | -------------------------------------------------------------------------------- /environment/maze_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from algorithm import RRT_EPS 3 | from .env_config import LIMITS, STICK_LENGTH 4 | 5 | class MazeEnv: 6 | ''' 7 | Interface class for maze environment 8 | ''' 9 | def __init__(self, dim): 10 | print("Initializing environment...") 11 | self.dim = dim 12 | self.collision_check_count = 0 13 | 14 | # load map from file 15 | map_file = 'maze_files/mazes_15_%d_3000.npz' % dim 16 | print("loading mazes from %s" % map_file) 17 | with np.load(map_file) as f: 18 | self.maps = f['maps'] 19 | self.init_states = f['init_states'] 20 | self.goal_states = f['goal_states'] 21 | 22 | self.size = self.maps.shape[0] 23 | self.width = self.maps.shape[1] 24 | self.order = list(range(self.size)) 25 | self.episode_i = 0 26 | 27 | def init_new_problem(self, index=None): 28 | ''' 29 | Initialize a new planning problem 30 | ''' 31 | if index is None: 32 | index = self.episode_i 33 | 34 | self.map = self.maps[self.order[index]] 35 | self.init_state = self.init_states[self.order[index]] 36 | self.goal_state = self.goal_states[self.order[index]] 37 | self.episode_i += 1 38 | self.collision_check_count = 0 39 | 40 | return self.get_problem() 41 | 42 | def get_problem(self): 43 | problem = { 44 | "map": self.map, 45 | "init_state": self.init_state, 46 | "goal_state": self.goal_state 47 | } 48 | return problem 49 | 50 | def uniform_sample(self): 51 | ''' 52 | Uniformlly sample in the configuration space 53 | ''' 54 | sample = np.random.uniform(-LIMITS[:self.dim], LIMITS[:self.dim]) 55 | return sample 56 | 57 | def distance(self, from_state, to_state): 58 | ''' 59 | Distance metric 60 | ''' 61 | diff = np.abs(to_state - from_state) 62 | if diff.ndim == 1: 63 | diff = diff.reshape(1, -1) 64 | 65 | if self.dim >= 3: 66 | diff[:,2] = np.min((diff[:,2], np.abs(diff[:,2] - 2*LIMITS[2])), axis=0) 67 | assert (np.abs(diff[:,2]) <= LIMITS[2]).all() 68 | 69 | return np.sqrt(np.sum(diff**2, axis=-1)) 70 | 71 | def interpolate(self, from_state, to_state, ratio): 72 | diff = to_state - from_state 73 | 74 | if self.dim >= 3: 75 | if np.abs(diff[2]) > LIMITS[2]: 76 | if diff[2] > 0: 77 | diff[2] -= 2*LIMITS[2] 78 | else: 79 | diff[2] += 2*LIMITS[2] 80 | assert np.abs(diff[2]) <= LIMITS[2] 81 | 82 | new_state = from_state + diff * ratio 83 | 84 | if self.dim >= 3: 85 | if np.abs(new_state[2]) > LIMITS[2]: 86 | if new_state[2] > 0: 87 | new_state[2] -= 2*LIMITS[2] 88 | else: 89 | new_state[2] += 2*LIMITS[2] 90 | assert np.abs(new_state[2]) <= LIMITS[2] 91 | 92 | return new_state 93 | 94 | def in_goal_region(self, state): 95 | ''' 96 | Return whether a state(configuration) is in the goal region 97 | ''' 98 | return self.distance(state, self.goal_state) < RRT_EPS and \ 99 | self._state_fp(state) 100 | 101 | def step(self, state, action=None, new_state=None, check_collision=True): 102 | ''' 103 | Collision detection module 104 | ''' 105 | # must specify either action or new_state 106 | if action is not None: 107 | new_state = state + action 108 | 109 | new_state[:2] = new_state[:2].clip(-LIMITS[:-1], LIMITS[:-1]) 110 | if self.dim >= 3: 111 | if np.abs(new_state[2]) > LIMITS[2]: 112 | if new_state[2] > 0: 113 | new_state[2] -= 2*LIMITS[2] 114 | else: 115 | new_state[2] += 2*LIMITS[2] 116 | assert np.abs(new_state[2]) <= LIMITS[2] 117 | 118 | action = new_state - state 119 | 120 | if not check_collision: 121 | return new_state, action 122 | 123 | done = False 124 | no_collision = self._edge_fp(state, new_state) 125 | if no_collision and self.in_goal_region(new_state): 126 | done = True 127 | 128 | return new_state, action, no_collision, done 129 | 130 | #=====================internal collision check module======================= 131 | 132 | # transform a state into a discretized grid coordinate 133 | def _transform(self, state, w=15): 134 | coord = ((np.array(state)[:2].flatten() + 1.0) * w / 2.0).astype(int) 135 | coord[coord > w-1] = w-1 136 | return coord 137 | 138 | def _end_points(coord=None, l=None, center=None, theta=None, a=None,\ 139 | b=None): 140 | if theta is None: 141 | theta = coord[2] / LIMITS[2] * np.pi 142 | orient = np.array([np.cos(theta), np.sin(theta)]) 143 | if l is None: 144 | l = STICK_LENGTH 145 | 146 | if a is None and b is None: 147 | if center is None: 148 | center = np.array(coord[:2]) 149 | a = center - l / 2. * orient 150 | b = center + l / 2. * orient 151 | else: 152 | if a is not None: 153 | b = a + l * orient 154 | if b is not None: 155 | a = b - l * orient 156 | 157 | return a, b 158 | 159 | def _valid_state(self, state): 160 | return (state >= -LIMITS[:state.size]).all() and \ 161 | (state <= LIMITS[:state.size]).all() 162 | 163 | def _point_in_free_space(self, state): 164 | assert state.size == 2 165 | if not self._valid_state(state): 166 | return False 167 | 168 | return self.map[tuple(self._transform(state, self.width))] == 0 169 | 170 | def _stick_in_free_space(self, state): 171 | assert state.size == 3 172 | 173 | if not self._valid_state(state): 174 | return False 175 | 176 | a, b = MazeEnv._end_points(state) 177 | if not self._point_in_free_space(a) or not self._point_in_free_space(b): 178 | return False 179 | 180 | return self._iterative_check_segment(a, b) 181 | 182 | def _state_fp(self, state): 183 | assert state.size == 2 or state.size == 3 or state.size == 5 184 | self.collision_check_count += 1 185 | 186 | if state.size == 2: 187 | return self._point_in_free_space(state) 188 | elif state.size == 3: 189 | return self._stick_in_free_space(state) 190 | 191 | def _iterative_check_segment(self, left, right): 192 | assert left.size == 2 and right.size == 2 193 | 194 | left_coord = np.array(self._transform(left, self.width), dtype=int) 195 | right_coord = np.array(self._transform(right, self.width), dtype=int) 196 | if np.sum(np.abs(left_coord - right_coord)) > 1: 197 | mid = (left + right) / 2.0 198 | if not self._state_fp(mid): 199 | return False 200 | return self._iterative_check_segment(left, mid) and self._iterative_check_segment(mid, right) 201 | 202 | return True 203 | 204 | def _edge_fp(self, state, new_state): 205 | assert state.size == new_state.size 206 | 207 | if not self._valid_state(state) or not self._valid_state(new_state): 208 | return False 209 | if not self._state_fp(state) or not self._state_fp(new_state): 210 | return False 211 | 212 | if state.size == 2: 213 | return self._iterative_check_segment(state, new_state) 214 | else: 215 | 216 | disp = new_state - state 217 | if np.abs(disp[2]) > LIMITS[2]: 218 | if disp[2] > 0: 219 | disp[2] -= 2*LIMITS[2] 220 | else: 221 | disp[2] += 2*LIMITS[2] 222 | assert np.abs(disp[2]) <= LIMITS[2] 223 | 224 | d = self.distance(state, new_state) 225 | K = int(d / 0.015) 226 | for k in range(1, K): 227 | c = state + k*1./K * disp 228 | 229 | if state.size == 3: 230 | ca, cb = MazeEnv._end_points(c) 231 | if not self._edge_fp(ca, cb): 232 | return False 233 | 234 | return True -------------------------------------------------------------------------------- /algorithm/tsa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from queue import Queue 4 | from .search_tree import SearchTree, insert_new_state, compute_w, rewire_to, \ 5 | set_cost, update_collision_checks 6 | from .alg_config import RRT_EPS 7 | 8 | def RRTS_plan(env, T=100, stop_when_success=False): 9 | return NEXT_plan(env=env, T=T, g_explore_eps=1., \ 10 | stop_when_success=stop_when_success) 11 | 12 | def NEXT_plan(env, model=None, T=100, g_explore_eps=1., \ 13 | stop_when_success=False, UCB_type='kde'): 14 | """Robot motion planning with NEXT. 15 | 16 | Args: 17 | env: The environment which stores the problem relevant information (map, 18 | initial state, goal state), and performs collision check, goal 19 | region check, uniform sampling. 20 | model: Machine learning model used to guide vertex selection and 21 | tree expansion. 22 | T (int): Maximum number of samples allowed. 23 | g_explore_eps (float): Probability for RRT-like global exploration. 24 | stop_when_success (bool): Whether to terminate the algorithm if one path 25 | is found. 26 | UCB_type (string): Type of UCB used (one of {'kde', 'GP'}). 27 | 28 | Returns: 29 | search_tree: Search tree generated by the algorithm. 30 | success (bool): Whether a path is found. 31 | """ 32 | search_tree = SearchTree( 33 | env = env, 34 | root = env.init_state, 35 | model = model, 36 | dim = env.dim 37 | ) 38 | 39 | success = False 40 | for i in range(T): 41 | leaf_id = None 42 | 43 | # Goal-biased heuristic. 44 | if np.random.rand() < 0.05: 45 | leaf_state, parent_idx, _, no_collision, done = \ 46 | global_explore(search_tree, env, sample_state=env.goal_state) 47 | success = success or done 48 | expanded_by_rrt = True 49 | 50 | # RRT-like global exploration. 51 | elif np.random.rand() < g_explore_eps: 52 | leaf_state, parent_idx, _, no_collision, done = \ 53 | global_explore(search_tree, env) 54 | success = success or done 55 | expanded_by_rrt = True 56 | 57 | # Guided selection and expansion. 58 | else: 59 | idx = select(search_tree, env) 60 | assert search_tree.freesp[idx] 61 | # assert not search_tree.in_goal_region[idx] 62 | 63 | parent_idx = idx 64 | leaf_state, _, no_collision, done = \ 65 | expand(search_tree, parent_idx, model, env) 66 | success = success or done 67 | expanded_by_rrt = False 68 | 69 | leaf_id = insert_new_state(env, search_tree, leaf_state, model, \ 70 | parent_idx, no_collision, done, expanded_by_rrt=expanded_by_rrt) 71 | RRTS_rewire_last(env, search_tree) 72 | 73 | if success and stop_when_success: 74 | break 75 | 76 | print('success =', success, ' number of samples =', i) 77 | 78 | return search_tree, success 79 | 80 | def RRT_steer(env, sample_state, nearest, dist): 81 | """Steer the sampled state to a new state close to the search tree. 82 | 83 | Args: 84 | env: The environment which stores the problem relevant information (map, 85 | initial state, goal state), and performs collision check, goal 86 | region check, uniform sampling. 87 | sample_state: State sampled from some distribution. 88 | nearest: Nearest point in the search tree to the sampled state. 89 | dist: Distance between sample_state and nearest. 90 | 91 | Returns: 92 | new_state: Steered state. 93 | """ 94 | if dist < RRT_EPS: 95 | return sample_state 96 | 97 | ratio = RRT_EPS / dist 98 | return env.interpolate(nearest, sample_state, ratio) 99 | 100 | def global_explore(search_tree, env, sample_state=None): 101 | """One step of RRT-like expansion. 102 | 103 | Args: 104 | search_tree: Current search tree generated by the algorithm. 105 | env: The environment which stores the problem relevant information (map, 106 | initial state, goal state), and performs collision check, goal 107 | region check, uniform sampling. 108 | sample_state: A randomly sampled state (if provided). 109 | 110 | Returns: 111 | new_state: New state being added to the search tree. 112 | parent_idx: Index of the parent of the new state. 113 | action: Path segment connecting parent and new state. 114 | no_collision (bool): True <==> the path segment is collision-free. 115 | done (bool): True <==> the path segment is collision-free and the new 116 | state is inside the goal region. 117 | """ 118 | non_terminal_states = search_tree.non_terminal_states 119 | 120 | # Sample uniformly in the maze 121 | if sample_state is None: 122 | sample_state = env.uniform_sample() 123 | 124 | # Steer sample to nearby location 125 | dists = env.distance(non_terminal_states, sample_state) 126 | nearest_idx, min_dist = np.argmin(dists), np.min(dists) 127 | new_state = RRT_steer(env, sample_state, non_terminal_states[nearest_idx], \ 128 | min_dist) 129 | 130 | new_state, action, no_collision, done = env.step( 131 | state = non_terminal_states[nearest_idx], 132 | new_state = new_state 133 | ) 134 | 135 | return new_state, search_tree.non_terminal_idxes[nearest_idx], action, \ 136 | no_collision, done 137 | 138 | def select(search_tree, env, c=1., use_GP=False): 139 | """Select a point in the search tree for expansion. 140 | 141 | Args: 142 | search_tree: Current search tree generated by the algorithm. 143 | env: The environment which stores the problem relevant information (map, 144 | initial state, goal state), and performs collision check, goal 145 | region check, uniform sampling. 146 | c: Hyperparameter controlling the weight for exploration. 147 | use_GP: True <==> using Gaussian Process. 148 | 149 | Returns: 150 | idx (int): Index of the point in the tree being selected. 151 | """ 152 | scores = [] 153 | for i in range(search_tree.non_terminal_states.shape[0]): 154 | idx = search_tree.non_terminal_idxes[i] 155 | Q = search_tree.state_values[idx] 156 | U = np.sqrt(np.log(search_tree.w_sum) / search_tree.w[idx]) 157 | 158 | scores.append(Q + c*U) 159 | 160 | return search_tree.non_terminal_idxes[np.argmax(scores)] 161 | 162 | def expand(search_tree, idx, model, env, k=10, c=1., use_GP=False): 163 | """Expand a search tree from a given point. 164 | 165 | Args: 166 | search_tree: Current search tree generated by the algorithm. 167 | idx (int): Index of the selected point. 168 | model: Machine learning model used to guide the expansion. 169 | env: The environment which stores the problem relevant information (map, 170 | initial state, goal state), and performs collision check, goal 171 | region check, uniform sampling. 172 | k (int): Number of candidate actions. 173 | c: Hyperparameter controlling the weight for exploration. 174 | use_GP: True <==> using Gaussian Process. 175 | 176 | Returns: 177 | new_state: New state being added to the tree. 178 | action: Path segment connecting parent and new state. 179 | no_collision (bool): True <==> the path segment is collision-free. 180 | done (bool): True <==> the path segment is collision-free and the new 181 | state is inside the goal region. 182 | """ 183 | state = np.array(search_tree.states[idx]) 184 | candidate_actions = model.policy(state=state, k=k)[0] 185 | candidates = [] 186 | for i in range(k): 187 | action = candidate_actions[i] 188 | new_state, _ = env.step(state=state, action=action, \ 189 | check_collision=False) 190 | candidates.append(new_state) 191 | 192 | if k > 1: 193 | scores = [] 194 | Qs = model.pred_value(np.array(candidates)) 195 | for i in range(k): 196 | Q = Qs[i] 197 | w = compute_w(env, search_tree, state=candidates[i]) 198 | U = np.sqrt(np.log(search_tree.w_sum) / w) 199 | scores.append(Q + c*U) 200 | new_state = candidates[np.argmax(scores)] 201 | 202 | else: 203 | new_state = candidates[0] 204 | 205 | new_state, action, no_collision, done = env.step( 206 | state = state, 207 | new_state = new_state 208 | ) 209 | 210 | return new_state, action, no_collision, done 211 | 212 | def RRTS_rewire_last(env, search_tree, neighbor_r=None, obs_cost=2): 213 | """Locally optimize the search tree by rewiring the latest added point. 214 | 215 | Args: 216 | env: The environment which stores the problem relevant information (map, 217 | initial state, goal state), and performs collision check, goal 218 | region check, uniform sampling. 219 | search_tree: Current search tree generated by the algorithm. 220 | neighbor_r (float): Radius for rewiring. 221 | obs_cost (float): Cost for obstacle (hyperparameter). 222 | """ 223 | if neighbor_r is None: 224 | neighbor_r = RRT_EPS*3 225 | cur_tree = search_tree.states[:-1] 226 | new_state = search_tree.states[-1] 227 | nearest = search_tree.parents[-1] 228 | freesp = search_tree.freesp 229 | 230 | # Return if the latest point is inside of an obstacle. 231 | if not search_tree.freesp[-1]: 232 | set_cost(search_tree, -1, obs_cost) 233 | update_collision_checks(search_tree, env.collision_check_count) 234 | return 235 | 236 | # Find the locally optimal path to the root for the latest point. 237 | dists = env.distance(cur_tree, new_state) 238 | near = np.where(dists < neighbor_r)[0] 239 | 240 | min_cost = dists[nearest] + search_tree.costs[nearest] 241 | min_j = nearest 242 | for j in near: 243 | if not freesp[j]: 244 | continue 245 | cost_new = dists[j] + search_tree.costs[j] 246 | if cost_new < min_cost: 247 | _, _, no_collision, done = env.step( 248 | state = cur_tree[j], 249 | new_state = new_state 250 | ) 251 | if no_collision: 252 | min_cost, min_j = cost_new, j 253 | 254 | # Rewire (change parent) to the locally optimal path. 255 | rewire_to(search_tree, -1, min_j) 256 | set_cost(search_tree, -1, min_cost) 257 | 258 | # If the latest point can improve the cost for the neighbors, rewire them. 259 | for j in near: 260 | cost_new = min_cost + dists[j] 261 | if cost_new < search_tree.costs[j]: 262 | _, _, no_collision, done = env.step( 263 | state = cur_tree[j], 264 | new_state = new_state 265 | ) 266 | 267 | if no_collision: 268 | set_cost(search_tree, j, cost_new) 269 | rewire_to(search_tree, j, len(search_tree.states)-1) 270 | 271 | update_collision_checks(search_tree, env.collision_check_count) -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torch.distributions.multivariate_normal import MultivariateNormal 8 | 9 | from algorithm import RRT_EPS 10 | from environment import LIMITS 11 | from utils import load_model 12 | 13 | class Attention(nn.Module): 14 | def __init__(self, cuda=True, env_width=15, cap=8, dim=2): 15 | super(Attention, self).__init__() 16 | self.w = env_width 17 | self.cap = cap 18 | self.dim = dim 19 | self.fix_attention = False 20 | 21 | # coords[0:2, i, j] = [i, j] 22 | # for i, j in {0, 1, ..., w-1} 23 | idx = np.arange(self.w) 24 | col_coord = np.tile(idx, (self.w, 1)) 25 | row_coord = np.tile(idx.reshape(self.w, 1), (1, self.w)) 26 | self.coords = torch.FloatTensor(np.array([col_coord, row_coord])) 27 | self.coords = self.coords.view(1, 2, self.w, self.w) 28 | 29 | # 1x1 conv ~= mlp with shared parameters 30 | self.mlp_share = nn.Sequential( 31 | nn.Conv2d(in_channels=4, out_channels=16, kernel_size=1), 32 | nn.ReLU(), 33 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=1), 34 | nn.ReLU(), 35 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=1), 36 | nn.ReLU(), 37 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1), 38 | nn.ReLU(), 39 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1), 40 | nn.ReLU(), 41 | nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1), 42 | ) 43 | 44 | # 3rd-d attention 45 | self.mlp = nn.Sequential( 46 | nn.Linear(in_features=self.dim, out_features=64), 47 | nn.ReLU(), 48 | nn.Linear(in_features=64, out_features=self.cap), 49 | ) 50 | 51 | if self.fix_attention: 52 | for param in self.parameters(): 53 | param.requires_grad = False 54 | 55 | if cuda: 56 | self.coords = self.coords.cuda() 57 | 58 | def forward(self, inp): 59 | # x[0:b, 0:4, i, j] = [input[0:b, 0], input[0:b, 1], i, j] 60 | # for i, j in {0, 1, ..., w-1} 61 | x = inp[:, 0:2].contiguous().view(inp.shape[0], 2, 1, 1) 62 | x = x.expand(-1, -1, self.w, self.w) 63 | coords = self.coords.expand(x.shape[0], -1, -1, -1) 64 | x = torch.cat((x, coords), dim=1) 65 | 66 | # attention over 2D grid 67 | x = self.mlp_share(x) 68 | x = x.view(x.shape[0], -1) 69 | x = F.softmax(x, dim=-1) 70 | atten_12d = x.view(x.shape[0], 1, -1) 71 | 72 | # attention over the 3rd dimension 73 | # x = inp[:, 2:3] 74 | x = inp 75 | x = self.mlp(x) 76 | x = F.softmax(x, dim=-1) 77 | atten_3d = x.view(x.shape[0], self.cap, 1) 78 | 79 | # combine 2d and 3rd-d attention 80 | x = atten_12d.expand(-1, self.cap, -1) * atten_3d 81 | x = x.view(-1, self.cap, self.w, self.w) 82 | 83 | return x 84 | 85 | class PPN(nn.Module): 86 | def __init__(self, cuda, env_width=15, cap=8, dim=2): 87 | super(PPN, self).__init__() 88 | self.w = env_width 89 | self.cap = cap 90 | self.dim = dim 91 | 92 | self.g = 8 93 | self.latent_dim = self.cap * self.g 94 | self.iters = 20 95 | self.conv_kern = 3 96 | self.conv_pad = int((self.conv_kern - 1.0) / 2) 97 | self.conv_cap = self.cap * 8 98 | 99 | self.hidden = nn.Conv2d(in_channels=self.cap + 1, out_channels=self.latent_dim, kernel_size=3, padding=1) 100 | self.h0 = nn.Conv2d(in_channels=self.latent_dim, out_channels=self.latent_dim, kernel_size=3, padding=1) 101 | self.c0 = nn.Conv2d(in_channels=self.latent_dim, out_channels=self.latent_dim, kernel_size=3, padding=1) 102 | 103 | self.conv = nn.Conv2d(in_channels=self.latent_dim, out_channels=self.conv_cap, kernel_size=self.conv_kern, padding=self.conv_pad) 104 | self.lstm = nn.LSTMCell(self.conv_cap, self.latent_dim) 105 | 106 | self.attention_g = Attention(cuda, env_width=env_width, cap=cap, dim=dim) 107 | self.attention_s = self.attention_g 108 | 109 | self.policy = nn.Sequential( 110 | nn.Linear(in_features=self.g, out_features=128), 111 | nn.ReLU(), 112 | nn.Linear(in_features=128, out_features=64), # 128 / 64 32/32 113 | nn.ReLU(), 114 | nn.Linear(in_features=64, out_features=self.dim+1), 115 | ) 116 | 117 | def forward(self, cur_state, goal_state, maze_map): 118 | cur_state = cur_state.clone().detach() 119 | goal_state = goal_state.clone().detach() 120 | cur_state[:,-1] /= LIMITS[2] 121 | goal_state[:,-1] /= LIMITS[2] 122 | 123 | b_size = maze_map.shape[0] 124 | 125 | goal_atten = self.attention_g(goal_state) # has size [b_size, capacity, map_w, map_w] 126 | maze_map = maze_map.view(b_size, 1, self.w, self.w) 127 | x = torch.cat((maze_map, goal_atten), dim=1) 128 | 129 | h_layer = self.hidden(x) 130 | h0 = self.h0(h_layer).transpose(1, 3).contiguous().view(b_size * self.w**2, self.latent_dim) 131 | c0 = self.c0(h_layer).transpose(1, 3).contiguous().view(b_size * self.w**2, self.latent_dim) 132 | 133 | 134 | last_h, last_c = h0, c0 135 | for _ in range(0, self.iters): 136 | h_map = last_h.view(-1, self.w, self.w, self.latent_dim) 137 | h_map = h_map.transpose(3, 1) 138 | lstm_inp = self.conv(h_map).transpose(1, 3).contiguous().view(-1, self.conv_cap) 139 | last_h, last_c = self.lstm(lstm_inp, (last_h, last_c)) 140 | 141 | 142 | x = last_h.view(b_size, self.w, self.w, self.latent_dim).transpose(3, 1) 143 | x = x.view(b_size, self.g, self.cap, self.w, self.w) 144 | state_atten = self.attention_s(cur_state).view(b_size, 1, self.cap, self.w, self.w) 145 | x = x * state_atten 146 | 147 | x = x.sum(dim=-1).sum(dim=-1).sum(dim=-1) 148 | x = self.policy(x) 149 | 150 | return x 151 | 152 | def pb_forward(self, goal_state, maze_map): 153 | """Compute the problem representation. 154 | 155 | Args: 156 | goal_state: [1, self.dim] 157 | maze_map: [1, self.w, self.w, self.w] 158 | 159 | Returns: 160 | pb_rep: [1, self.g, self.cap, self.w, self.w, self.w] 161 | """ 162 | goal_state = goal_state.clone().detach() 163 | goal_state[:,-1] /= LIMITS[2] 164 | 165 | b_size = maze_map.shape[0] 166 | assert b_size == 1 167 | 168 | goal_atten = self.attention_g(goal_state) # has size [b_size, capacity, map_w, map_w] 169 | maze_map = maze_map.view(b_size, 1, self.w, self.w) 170 | x = torch.cat((maze_map, goal_atten), dim=1) 171 | 172 | h_layer = self.hidden(x) 173 | h0 = self.h0(h_layer).transpose(1, 3).contiguous().view(b_size * self.w**2, self.latent_dim) 174 | c0 = self.c0(h_layer).transpose(1, 3).contiguous().view(b_size * self.w**2, self.latent_dim) 175 | 176 | last_h, last_c = h0, c0 177 | for _ in range(0, self.iters): 178 | h_map = last_h.view(-1, self.w, self.w, self.latent_dim) 179 | h_map = h_map.transpose(3, 1) 180 | lstm_inp = self.conv(h_map).transpose(1, 3).contiguous().view(-1, self.conv_cap) 181 | last_h, last_c = self.lstm(lstm_inp, (last_h, last_c)) 182 | 183 | x = last_h.view(b_size, self.w, self.w, self.latent_dim).transpose(3, 1) 184 | x = x.view(b_size, self.g, self.cap, self.w, self.w) 185 | 186 | return x 187 | 188 | def state_forward(self, cur_states, pb_rep): 189 | """Forward using problem representation. 190 | 191 | Args: 192 | cur_states: [batch_size, self.dim] 193 | pb_rep: [1, self.g, self.cap, self.w, self.w] 194 | 195 | Returns: 196 | [actions, values]: [batch_size, self.dim + 1] 197 | """ 198 | # if self.dim >= 3: 199 | cur_states = cur_states.clone().detach() 200 | cur_states[:,-1] /= LIMITS[2] 201 | 202 | b_size = cur_states.shape[0] 203 | x = pb_rep.expand(b_size, self.g, self.cap, self.w, self.w) 204 | 205 | state_atten = self.attention_s(cur_states).view(b_size, 1, self.cap, self.w, self.w) 206 | x = x * state_atten 207 | 208 | x = x.sum(dim=-1).sum(dim=-1).sum(dim=-1) 209 | x = self.policy(x) 210 | 211 | return x 212 | 213 | 214 | class Model: 215 | def __init__(self, cuda, env_width=15, model_cap=8, dim=2, std=None, UCB_type='kde'): 216 | if std is None: 217 | std = RRT_EPS*0.3 218 | 219 | print("initializing model ...") 220 | self.net = PPN(cuda, env_width=env_width, cap=model_cap, dim=dim) 221 | self.cuda = cuda 222 | if cuda: 223 | self.net = self.net.cuda() 224 | self.std = std 225 | self.dim = dim 226 | self.var = torch.eye(self.dim)*self.std**2 227 | print('dim == ', dim) 228 | 229 | self.env_width=env_width 230 | self.UCB_type = UCB_type 231 | 232 | def set_problem(self, problem): 233 | self.problem = problem 234 | 235 | # compute problem representation 236 | assert self.net 237 | self.maze_map = problem["map"].reshape(1, self.env_width, self.env_width) 238 | self.goal_state = problem["goal_state"].reshape(1, self.dim) 239 | self.maze_map = torch.FloatTensor(self.maze_map) 240 | self.goal_state = torch.FloatTensor(self.goal_state) 241 | if self.cuda: 242 | self.maze_map = self.maze_map.cuda() 243 | self.goal_state = self.goal_state.cuda() 244 | 245 | self.pb_rep = self.net.pb_forward(self.goal_state, self.maze_map) 246 | 247 | def net_forward(self, states): 248 | if states.ndim == 1: 249 | states = states.reshape(1,-1) 250 | 251 | cur_states = torch.FloatTensor(states) 252 | if self.cuda: 253 | cur_states = cur_states.cuda() 254 | 255 | y = self.net.state_forward(cur_states, self.pb_rep) 256 | y = y.data.cpu().numpy() 257 | 258 | pred_actions = y[:, :self.dim] 259 | pred_values = y[:, -1] 260 | 261 | if pred_actions.shape[0] == 1: 262 | pred_actions = pred_actions[0] 263 | pred_values = pred_values[0] 264 | 265 | return pred_actions, pred_values 266 | 267 | def pred_value(self, states): 268 | _, state_values = self.net_forward(states) 269 | 270 | return state_values 271 | 272 | def policy(self, state, k=1): 273 | action_mean, _ = self.net_forward(state) 274 | m = MultivariateNormal(torch.FloatTensor(action_mean), self.var) 275 | 276 | actions = [] 277 | prior_values = [] 278 | 279 | for i in range(k): 280 | action = m.sample() 281 | prior_value = torch.exp(m.log_prob(action)).item() 282 | 283 | actions.append(action.cpu().numpy()) 284 | prior_values.append(prior_value) 285 | 286 | return actions, prior_values 287 | 288 | def get_net(self): 289 | return self.net 290 | 291 | def set_net(self, net): 292 | self.net = net 293 | 294 | 295 | -------------------------------------------------------------------------------- /main.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "from environment import MazeEnv\n", 11 | "from model import Model\n", 12 | "from algorithm import NEXT_plan, RRTS_plan\n", 13 | "from utils import set_random_seed, load_model, plot_tree" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "set_random_seed(1234)\n", 23 | "cuda = True" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "# 2D Planning" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "name": "stdout", 40 | "output_type": "stream", 41 | "text": [ 42 | "Initializing environment...\n", 43 | "loading mazes from maze_files/mazes_15_2_3000.npz\n", 44 | "initializing model ...\n", 45 | "dim == 2\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "dim = 2\n", 51 | "UCB_type = 'kde'\n", 52 | "environment = MazeEnv(dim = dim)\n", 53 | "model = Model(cuda = cuda, dim = dim)\n", 54 | "model_file = 'trained_models/NEXT_%dd.pt' % dim\n", 55 | "load_model(model.net, model_file, cuda)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# Sample a problem from the environment\n", 65 | "pb_idx = 2101 # 0 - 2999\n", 66 | "pb = environment.init_new_problem(pb_idx)\n", 67 | "model.set_problem(pb)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "## NEXT Algorithm" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 5, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "name": "stdout", 84 | "output_type": "stream", 85 | "text": [ 86 | "success = True number of samples = 86\n" 87 | ] 88 | }, 89 | { 90 | "data": { 91 | "image/png": "\n", 92 | "text/plain": [ 93 | "
" 94 | ] 95 | }, 96 | "metadata": { 97 | "needs_background": "light" 98 | }, 99 | "output_type": "display_data" 100 | } 101 | ], 102 | "source": [ 103 | "search_tree, done = NEXT_plan(\n", 104 | " env = environment,\n", 105 | " model = model,\n", 106 | " T = 500,\n", 107 | " g_explore_eps = 0.1,\n", 108 | " stop_when_success = True,\n", 109 | " UCB_type = UCB_type\n", 110 | ")\n", 111 | "plot_tree(\n", 112 | " states = search_tree.states,\n", 113 | " parents = search_tree.parents,\n", 114 | " problem = environment.get_problem()\n", 115 | ")" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "## RRT* Algorithm" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 6, 128 | "metadata": {}, 129 | "outputs": [ 130 | { 131 | "name": "stdout", 132 | "output_type": "stream", 133 | "text": [ 134 | "success = False number of samples = 999\n" 135 | ] 136 | }, 137 | { 138 | "data": { 139 | "image/png": "\n", 140 | "text/plain": [ 141 | "
" 142 | ] 143 | }, 144 | "metadata": { 145 | "needs_background": "light" 146 | }, 147 | "output_type": "display_data" 148 | } 149 | ], 150 | "source": [ 151 | "search_tree, done = RRTS_plan(\n", 152 | " env = environment,\n", 153 | " T = 1000,\n", 154 | " stop_when_success = True\n", 155 | ")\n", 156 | "plot_tree(\n", 157 | " states = search_tree.states,\n", 158 | " parents = search_tree.parents,\n", 159 | " problem = environment.get_problem()\n", 160 | ")" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "# 3D Planning" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 7, 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "name": "stdout", 177 | "output_type": "stream", 178 | "text": [ 179 | "Initializing environment...\n", 180 | "loading mazes from maze_files/mazes_15_3_3000.npz\n", 181 | "initializing model ...\n", 182 | "dim == 3\n" 183 | ] 184 | } 185 | ], 186 | "source": [ 187 | "dim = 3\n", 188 | "UCB_type = 'kde'\n", 189 | "environment = MazeEnv(dim = dim)\n", 190 | "model = Model(cuda = cuda, dim = dim)\n", 191 | "model_file = 'trained_models/NEXT_%dd.pt' % dim\n", 192 | "load_model(model.net, model_file, cuda)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 8, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "# Sample a problem from the environment\n", 202 | "pb_idx = 2101 # 0 - 2999\n", 203 | "pb = environment.init_new_problem(pb_idx)\n", 204 | "model.set_problem(pb)" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": {}, 210 | "source": [ 211 | "## NEXT Algorithm" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 9, 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "name": "stdout", 221 | "output_type": "stream", 222 | "text": [ 223 | "success = True number of samples = 126\n" 224 | ] 225 | }, 226 | { 227 | "data": { 228 | "image/png": "\n", 229 | "text/plain": [ 230 | "
" 231 | ] 232 | }, 233 | "metadata": { 234 | "needs_background": "light" 235 | }, 236 | "output_type": "display_data" 237 | } 238 | ], 239 | "source": [ 240 | "search_tree, done = NEXT_plan(\n", 241 | " env = environment,\n", 242 | " model = model,\n", 243 | " T = 500,\n", 244 | " g_explore_eps = 0.1,\n", 245 | " stop_when_success = True,\n", 246 | " UCB_type = UCB_type\n", 247 | ")\n", 248 | "plot_tree(\n", 249 | " states = search_tree.states,\n", 250 | " parents = search_tree.parents,\n", 251 | " problem = environment.get_problem()\n", 252 | ")" 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": {}, 258 | "source": [ 259 | "## RRT* Algorithm" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 10, 265 | "metadata": {}, 266 | "outputs": [ 267 | { 268 | "name": "stdout", 269 | "output_type": "stream", 270 | "text": [ 271 | "success = False number of samples = 1999\n" 272 | ] 273 | }, 274 | { 275 | "data": { 276 | "image/png": "\n", 277 | "text/plain": [ 278 | "
" 279 | ] 280 | }, 281 | "metadata": { 282 | "needs_background": "light" 283 | }, 284 | "output_type": "display_data" 285 | } 286 | ], 287 | "source": [ 288 | "search_tree, done = RRTS_plan(\n", 289 | " env = environment,\n", 290 | " T = 2000,\n", 291 | " stop_when_success = True\n", 292 | ")\n", 293 | "plot_tree(\n", 294 | " states = search_tree.states,\n", 295 | " parents = search_tree.parents,\n", 296 | " problem = environment.get_problem()\n", 297 | ")" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [] 306 | } 307 | ], 308 | "metadata": { 309 | "kernelspec": { 310 | "display_name": "Python 3", 311 | "language": "python", 312 | "name": "python3" 313 | }, 314 | "language_info": { 315 | "codemirror_mode": { 316 | "name": "ipython", 317 | "version": 3 318 | }, 319 | "file_extension": ".py", 320 | "mimetype": "text/x-python", 321 | "name": "python", 322 | "nbconvert_exporter": "python", 323 | "pygments_lexer": "ipython3", 324 | "version": "3.7.2" 325 | } 326 | }, 327 | "nbformat": 4, 328 | "nbformat_minor": 2 329 | } 330 | --------------------------------------------------------------------------------