├── README.md
├── a3c_main.py
├── a3c_test.py
├── a3c_train.py
├── generate_test_data.py
├── maze2d.py
├── model.py
├── pretrained_models
│   ├── m15_l20
│   ├── m15_l40
│   ├── m21_l30
│   ├── m21_l60
│   ├── m7_l15
│   └── m7_l30
├── test_data
│   ├── m15_n1000.npy
│   ├── m21_n1000.npy
│   └── m7_n1000.npy
└── utils
    ├── __init__.py
    ├── localization.py
    └── maze.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Active Neural Localization
This is a PyTorch implementation of the ICLR-18 paper:

[Active Neural Localization](https://arxiv.org/abs/1801.08214)
Devendra Singh Chaplot, Emilio Parisotto, Ruslan Salakhutdinov
Carnegie Mellon University

Project Website: https://devendrachaplot.github.io/projects/Neural-Localization

### This repository contains:
- Code for the Maze2D environment, which generates random 2D mazes for active localization.
- Code for training an Active Neural Localization agent in the Maze2D environment using A3C.

## Dependencies
- [PyTorch](http://pytorch.org) (v0.3)

## Usage

### Training
To train an Active Neural Localization A3C agent with 16 threads on 7x7 mazes with a maximum episode length of 30:
```
python a3c_main.py --num-processes 16 --map-size 7 --max-episode-length 30 --dump-location ./saved/ --test-data ./test_data/m7_n1000.npy
```
The code saves the best model at `./saved/model_best` and the training log at `./saved/train.log`. It uses `./test_data/m7_n1000.npy` as the held-out test data and ensures that no maze from the test data is used during training.
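Training can also be resumed from a previously saved checkpoint by passing its path to `--load`, for example:
```
python a3c_main.py --num-processes 16 --map-size 7 --max-episode-length 30 --load ./saved/model_best --dump-location ./saved/ --test-data ./test_data/m7_n1000.npy
```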
### Evaluation
After training, the model can be evaluated using:
```
python a3c_main.py --num-processes 0 --evaluate 1 --map-size 7 --max-episode-length 30 --load ./saved/model_best --test-data ./test_data/m7_n1000.npy
```

### Pre-trained models
The `pretrained_models` directory contains pre-trained models for map size 7 (maximum episode lengths 15 and 30), map size 15 (maximum episode lengths 20 and 40) and map size 21 (maximum episode lengths 30 and 60). The test data used for training these models is provided in the `test_data` directory.

To evaluate a pre-trained model on 15x15 mazes with a maximum episode length of 40:
```
python a3c_main.py --num-processes 0 --evaluate 1 --map-size 15 --max-episode-length 40 --load ./pretrained_models/m15_l40 --test-data ./test_data/m15_n1000.npy
```

### Generating test data
The repository contains test data for map sizes 7, 15 and 21, with 1000 mazes each, in the `test_data` directory.

To generate more test data:
```
python generate_test_data.py --map-size 7 --num-mazes 100 --test-data-location ./test_data/ --test-data-filename my_new_test_data.npy
```
This will generate a test data file at `test_data/my_new_test_data.npy` containing 100 7x7 mazes.
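Each row of a test data file is a flattened map design followed by the agent's starting position and orientation (the format written by `generate_test_data.py` and read back by `maze2d.py`). A minimal sketch for inspecting the generated file:
```
import numpy as np

mazes = np.load("./test_data/my_new_test_data.npy")
print(mazes.shape)  # (100, 52): 100 mazes, each 7*7 map cells + 3 extra values

maze = mazes[0]
map_design = maze[:-3].reshape(7, 7)       # 1 = wall, 0 = free space
position = (int(maze[-3]), int(maze[-2]))  # agent's starting cell
orientation = int(maze[-1])                # 0: east, 1: north, 2: west, 3: south
print(map_design)
print(position, orientation)
```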
### All arguments
All arguments for a3c_main.py:
```
  -h, --help            show this help message and exit
  -l L, --max-episode-length L
                        maximum length of an episode (default: 30)
  -m MAP_SIZE, --map-size MAP_SIZE
                        m: Size of the maze m x m (default: 7), must be an
                        odd natural number
  --lr LR               learning rate (default: 0.001)
  --num-iters NS        number of training iterations per training thread
                        (default: 1000000)
  --gamma G             discount factor for rewards (default: 0.99)
  --tau T               parameter for GAE (default: 1.00)
  --seed S              random seed (default: 1)
  -n N, --num-processes N
                        how many training processes to use (default: 8)
  --num-steps NS        number of forward steps in A3C (default: 20)
  --hist-size HIST_SIZE
                        action history size (default: 5)
  --load LOAD           model path to load, 0 to not reload (default: 0)
  -e EVALUATE, --evaluate EVALUATE
                        0:Train, 1:Evaluate on test data (default: 0)
  -d DUMP_LOCATION, --dump-location DUMP_LOCATION
                        path to dump models and log (default: ./saved/)
  -td TEST_DATA, --test-data TEST_DATA
                        Test data filepath (default: ./test_data/m7_n1000.npy)
```
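### Using the environment directly
For quick experimentation, the Maze2D environment can also be driven without the A3C harness. Below is a minimal random-agent rollout (a sketch; `Namespace` stands in for the arguments normally parsed in `a3c_main.py`):
```
import numpy as np
from argparse import Namespace
from maze2d import Maze2D

# Stand-in for the parsed a3c_main.py arguments; evaluate=1 replays test mazes.
args = Namespace(map_size=7, max_episode_length=30, evaluate=1,
                 test_data="./test_data/m7_n1000.npy")

env = Maze2D(args)
state, depth = env.reset()  # state: (5, 7, 7) belief posterior + map design
done = False
while not done:
    action = np.random.randint(3)  # 0: TURN_LEFT, 1: TURN_RIGHT, 2: MOVE_FORWARD
    state, reward, done, depth = env.step(action)
print("final reward:", reward)
```
The per-step reward is the maximum probability in the belief posterior (see `maze2d.py`), so the final reward is close to 1 when the agent is well localized.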
## Cite as
>Chaplot, Devendra Singh, Parisotto, Emilio and Salakhutdinov, Ruslan.
Active Neural Localization.
In *International Conference on Learning Representations*, 2018.
([PDF](http://arxiv.org/abs/1801.08214))

### Bibtex:
```
@inproceedings{chaplot2018active,
  title={Active Neural Localization},
  author={Chaplot, Devendra Singh and Parisotto, Emilio and Salakhutdinov, Ruslan},
  booktitle={International Conference on Learning Representations},
  year={2018}
}
```

## Acknowledgements
The implementation of A3C is borrowed from https://github.com/ikostrikov/pytorch-a3c.

--------------------------------------------------------------------------------
/a3c_main.py:
--------------------------------------------------------------------------------
import argparse
import os
os.environ["OMP_NUM_THREADS"] = "1"
import sys
import signal
import torch
import torch.multiprocessing as mp

from maze2d import *
from model import *
from a3c_train import train
from a3c_test import test

import logging


parser = argparse.ArgumentParser(description='Active Neural Localization')

# Environment arguments
parser.add_argument('-l', '--max-episode-length', type=int,
                    default=30, metavar='L',
                    help='maximum length of an episode (default: 30)')
parser.add_argument('-m', '--map-size', type=int, default=7,
                    help='''m: Size of the maze m x m (default: 7),
                    must be an odd natural number''')

# A3C and model arguments
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                    help='learning rate (default: 0.001)')
parser.add_argument('--num-iters', type=int, default=1000000, metavar='NS',
                    help='''number of training iterations per training thread
                    (default: 1000000)''')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for rewards (default: 0.99)')
parser.add_argument('--tau', type=float, default=1.00, metavar='T',
                    help='parameter for GAE (default: 1.00)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('-n', '--num-processes', type=int, default=8, metavar='N',
                    help='how many training processes to use (default: 8)')
parser.add_argument('--num-steps', type=int, default=20, metavar='NS',
                    help='number of forward steps in A3C (default: 20)')
parser.add_argument('--hist-size', type=int, default=5,
                    help='action history size (default: 5)')
parser.add_argument('--load', type=str, default="0",
                    help='model path to load, 0 to not reload (default: 0)')
parser.add_argument('-e', '--evaluate', type=int, default=0,
                    help='0:Train, 1:Evaluate on test data (default: 0)')
parser.add_argument('-d', '--dump-location', type=str, default="./saved/",
                    help='path to dump models and log (default: ./saved/)')
parser.add_argument('-td', '--test-data', type=str,
                    default="./test_data/m7_n1000.npy",
                    help='''Test data filepath
                    (default: ./test_data/m7_n1000.npy)''')

if __name__ == '__main__':
    args = parser.parse_args()
    torch.manual_seed(args.seed)

    if not os.path.exists(args.dump_location):
        os.makedirs(args.dump_location)

    logging.basicConfig(
        filename=args.dump_location + 'train.log',
        level=logging.INFO)

    assert args.evaluate == 0 or args.num_processes == 0, \
        "Can't train while evaluating, either n=0 or e=0"

    shared_model = Localization_2D_A3C(args)

    if args.load != "0":
        shared_model.load_state_dict(torch.load(args.load))
    shared_model.share_memory()

    signal.signal(signal.SIGINT, signal.signal(signal.SIGINT, signal.SIG_IGN))
    processes = []

    p = mp.Process(target=test, args=(args.num_processes, args, shared_model))
    p.start()
    processes.append(p)

    for rank in range(0, args.num_processes):
        p = mp.Process(target=train, args=(rank, args, shared_model))
        p.start()
        processes.append(p)

    try:
        for p in processes:
            p.join()
    except KeyboardInterrupt:
        print("Stopping training. " +
              "Best model stored at {}model_best".format(args.dump_location))
        for p in processes:
            p.terminate()

--------------------------------------------------------------------------------
/a3c_test.py:
--------------------------------------------------------------------------------
import time
import logging

from maze2d import *
from model import *
from collections import deque


def test(rank, args, shared_model):
    args.seed = args.seed + rank
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    env = Maze2D(args)
    action_hist_size = args.hist_size

    model = Localization_2D_A3C(args)
    if (args.load != "0"):
        print("Loading model {}".format(args.load))
        model.load_state_dict(torch.load(args.load))
    model.eval()

    reward_sum = 0
    episode_length = 0
    rewards_list = []
    accuracy_list = []
    best_reward = 0.0
    done = True

    if args.evaluate != 0:
        test_freq = env.test_mazes.shape[0]
    else:
        test_freq = 1000

    start_time = time.time()

    state, depth = env.reset()
    state = torch.from_numpy(state).float()

    while True:
        episode_length += 1
        if done:
            if (args.evaluate == 0):
                # Sync with the shared model
                model.load_state_dict(shared_model.state_dict())

            # Fill the action history with action 3 (a padding value)
            # at the start of the episode
            action_hist = deque(
                [3] * action_hist_size,
                maxlen=action_hist_size)
            action_seq = []
        else:
            action_hist.append(action)

        ax = Variable(torch.from_numpy(np.array(action_hist)),
                      volatile=True)
        dx = Variable(torch.from_numpy(np.array([depth])).long(),
                      volatile=True)
        tx = Variable(torch.from_numpy(np.array([episode_length])).long(),
                      volatile=True)

        value, logit = model(
            (Variable(state.unsqueeze(0), volatile=True), (ax, dx, tx)))
        prob = F.softmax(logit, dim=1)
        action = prob.max(1)[1].data.numpy()[0]

        state, reward, done, depth = env.step(action)

        done = done or episode_length >= args.max_episode_length
        reward_sum += reward

        if done:
            rewards_list.append(reward_sum)
            if reward >= 1:
                accuracy = 1
            else:
                accuracy = 0
            accuracy_list.append(accuracy)

            if (len(rewards_list) >= test_freq):
                time_elapsed = time.gmtime(time.time() - start_time)
                print(" ".join([
                    "Time: {0:0=2d}d".format(time_elapsed.tm_mday - 1),
                    "{},".format(time.strftime("%Hh %Mm %Ss", time_elapsed)),
                    "Avg Reward: {0:.3f},".format(np.mean(rewards_list)),
                    "Avg Accuracy: {0:.3f},".format(np.mean(accuracy_list)),
                    "Best Reward: {0:.3f}".format(best_reward)]))
                logging.info(" ".join([
                    "Time: {0:0=2d}d".format(time_elapsed.tm_mday - 1),
                    "{},".format(time.strftime("%Hh %Mm %Ss", time_elapsed)),
                    "Avg Reward: {0:.3f},".format(np.mean(rewards_list)),
                    "Avg Accuracy: {0:.3f},".format(np.mean(accuracy_list)),
                    "Best Reward: {0:.3f}".format(best_reward)]))
                if args.evaluate != 0:
                    return
                elif (np.mean(rewards_list) >= best_reward):
                    torch.save(model.state_dict(),
                               args.dump_location + "model_best")
                    best_reward = np.mean(rewards_list)
                rewards_list = []
                accuracy_list = []

            reward_sum = 0
            episode_length = 0
            state, depth = env.reset()

        state = torch.from_numpy(state).float()
--------------------------------------------------------------------------------
/a3c_train.py:
--------------------------------------------------------------------------------
import torch.optim as optim
import logging

from maze2d import *
from model import *
from collections import deque


def ensure_shared_grads(model, shared_model):
    for param, shared_param in zip(
            model.parameters(), shared_model.parameters()):
        if shared_param.grad is not None:
            return
        shared_param._grad = param.grad


def train(rank, args, shared_model):
    args.seed = args.seed + rank
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    env = Maze2D(args)
    action_hist_size = args.hist_size

    model = Localization_2D_A3C(args)
    if (args.load != "0"):
        print("Training thread: {}, Loading model {}".format(rank, args.load))
        model.load_state_dict(torch.load(args.load))
    optimizer = optim.SGD(shared_model.parameters(), lr=args.lr)
    model.train()

    values = []
    log_probs = []
    p_losses = []
    v_losses = []

    episode_length = 0
    num_iters = 0
    done = True

    state, depth = env.reset()
    state = torch.from_numpy(state)
    while num_iters < args.num_iters / 1000:
        # Sync with the shared model
        model.load_state_dict(shared_model.state_dict())
        if done:
            # Fill the action history with action 3 (a padding value)
            # at the start of the episode
            action_hist = deque(
                [3] * action_hist_size,
                maxlen=action_hist_size)
            episode_length = 0
        else:
            action_hist.append(action)

        values = []
        log_probs = []
        rewards = []
        entropies = []

        for step in range(args.num_steps):
            episode_length += 1
            state = state.float()
            ax = Variable(torch.from_numpy(np.array(action_hist)))
            dx = Variable(torch.from_numpy(np.array([depth])).long())
            tx = Variable(torch.from_numpy(np.array([episode_length])).long())

            value, logit = model(
                (Variable(state.unsqueeze(0)), (ax, dx, tx)))
            prob = F.softmax(logit, dim=1)
            log_prob = F.log_softmax(logit, dim=1)
            entropy = -(log_prob * prob).sum(1)
            entropies.append(entropy)

            action = prob.multinomial().data

            log_prob = log_prob.gather(1, Variable(action))

            action = action.numpy()[0, 0]

            state, reward, done, depth = env.step(action)
            done = done or episode_length >= args.max_episode_length

            if done:
                episode_length = 0
                state, depth = env.reset()

            state = torch.from_numpy(state)
            values.append(value)
            log_probs.append(log_prob)
            rewards.append(reward)

            if done:
                break

        R = torch.zeros(1, 1)
        state = state.float()
        if not done:
            action_hist.append(action)
            ax = Variable(torch.from_numpy(np.array(action_hist)))
            dx = Variable(torch.from_numpy(np.array([depth])).long())
            tx = Variable(torch.from_numpy(np.array([episode_length])).long())
            value, _ = model((Variable(state.unsqueeze(0)), (ax, dx, tx)))
            R = value.data

        values.append(Variable(R))
        policy_loss = 0
        value_loss = 0
        R = Variable(R)
        gae = torch.zeros(1, 1)
        for i in reversed(range(len(rewards))):
            R = args.gamma * R + rewards[i]
            advantage = R - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)

            # Generalized Advantage Estimation
            delta_t = rewards[i] + args.gamma * \
                values[i + 1].data - values[i].data
            gae = gae * args.gamma * args.tau + delta_t

            policy_loss = policy_loss - \
                log_probs[i] * Variable(gae) - 0.01 * entropies[i]

        optimizer.zero_grad()

        p_losses.append(policy_loss.data[0, 0])
        v_losses.append(value_loss.data[0, 0])

        if (len(p_losses) > 1000):
            num_iters += 1
            print(" ".join([
                "Training thread: {:2d},".format(rank),
                "Num iters: {:4d}K,".format(num_iters),
                "Avg policy loss: {0:+.3f},".format(np.mean(p_losses)),
                "Avg value loss: {0:+.3f}".format(np.mean(v_losses))]))
            logging.info(" ".join([
                "Training thread: {:2d},".format(rank),
                "Num iters: {:4d}K,".format(num_iters),
                "Avg policy loss: {0:+.3f},".format(np.mean(p_losses)),
                "Avg value loss: {0:+.3f}".format(np.mean(v_losses))]))
            p_losses = []
            v_losses = []

        (policy_loss + 0.5 * value_loss).backward()
        torch.nn.utils.clip_grad_norm(model.parameters(), 40)

        ensure_shared_grads(model, shared_model)
        optimizer.step()

    print("Training thread {} completed".format(rank))
    logging.info("Training thread {} completed".format(rank))

--------------------------------------------------------------------------------
/generate_test_data.py:
--------------------------------------------------------------------------------
import argparse
import os
from maze2d import *

import logging


parser = argparse.ArgumentParser(description='Generate test mazes')
parser.add_argument('-n', '--num-mazes', type=int, default=1000,
                    help='Number of mazes to generate (default: 1000)')
parser.add_argument('-m', '--map-size', type=int, default=7,
                    help='''m: Size of the maze m x m (default: 7),
                    must be an odd natural number''')
parser.add_argument('-tdl', '--test-data-location', type=str,
                    default="./test_data/",
                    help='Data location (default: ./test_data/)')
parser.add_argument('-tdf', '--test-data-filename', type=str,
                    default="m7_n1000.npy",
                    help='Data filename (default: m7_n1000.npy)')


if __name__ == '__main__':
    args = parser.parse_args()
    test_mazes = []

    if not os.path.exists(args.test_data_location):
        os.makedirs(args.test_data_location)

    while len(test_mazes) < args.num_mazes:
        map_design = generate_map(args.map_size)
        position = np.array(get_random_position(map_design))
        orientation = np.array([np.random.randint(4)])

        maze = np.concatenate((map_design.flatten(), position, orientation))

        # Make sure the maze doesn't exist in the test mazes already
        if not any((maze == x).all() for x in test_mazes):
            # Make sure the map is not rotationally symmetric
            if not (map_design == np.rot90(map_design)).all() and \
                    not (map_design == np.rot90(np.rot90(map_design))).all():
                test_mazes.append(maze)

    filepath = os.path.join(args.test_data_location, args.test_data_filename)
    np.save(filepath, test_mazes)

--------------------------------------------------------------------------------
/maze2d.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from utils.maze import *
from utils.localization import *


class Maze2D(object):

    def __init__(self, args):
        self.args = args
        self.test_mazes = np.load(args.test_data)
        self.test_maze_idx = 0
        return
    def reset(self):

        # Load a test maze during evaluation
        if self.args.evaluate != 0:
            maze_in_test_data = False
            maze = self.test_mazes[self.test_maze_idx]
            self.orientation = int(maze[-1])
            self.position = (int(maze[-3]), int(maze[-2]))
            self.map_design = maze[:-3].reshape(self.args.map_size,
                                                self.args.map_size)
            self.test_maze_idx += 1
        else:
            maze_in_test_data = True

        # Generate a maze
        while maze_in_test_data:
            # Generate a map design
            self.map_design = generate_map(self.args.map_size)

            # Get random initial position and orientation of the agent
            self.position = get_random_position(self.map_design)
            self.orientation = np.random.randint(4)

            maze = np.concatenate((self.map_design.flatten(),
                                   np.array(self.position),
                                   np.array([self.orientation])))

            # Make sure the maze doesn't exist in the test mazes
            if not any((maze == x).all() for x in self.test_mazes):
                # Make sure the map is not rotationally symmetric
                if not (self.map_design ==
                        np.rot90(self.map_design)).all() \
                        and not (self.map_design ==
                                 np.rot90(np.rot90(self.map_design))).all():
                    maze_in_test_data = False

        # Pre-compute likelihoods of all observations on the map for efficiency
        self.likelihoods = get_all_likelihoods(self.map_design)

        # Get current observation and likelihood
        curr_depth = get_depth(self.map_design, self.position,
                               self.orientation)
        curr_likelihood = self.likelihoods[int(curr_depth) - 1]

        # Posterior is just the likelihood as the prior is uniform
        self.posterior = curr_likelihood

        # Renormalization of the posterior
        self.posterior /= np.sum(self.posterior)
        self.t = 0

        # Next state for the policy model
        self.state = np.concatenate((self.posterior, np.expand_dims(
            self.map_design, axis=0)), axis=0)
        return self.state, int(curr_depth)

    def step(self, action_id):
        # Get the observation before taking the action
        curr_depth = get_depth(self.map_design, self.position,
                               self.orientation)

        # Posterior from the last step is the prior for this step
        prior = self.posterior

        # Transform the prior according to the action taken
        prior = transition_function(prior, curr_depth, action_id)

        # Calculate position and orientation after taking the action
        self.position, self.orientation = get_next_state(
            self.map_design, self.position, self.orientation, action_id)

        # Get the observation and likelihood after taking the action
        curr_depth = get_depth(
            self.map_design, self.position, self.orientation)
        curr_likelihood = self.likelihoods[int(curr_depth) - 1]

        # Posterior = Prior * Likelihood
        self.posterior = np.multiply(curr_likelihood, prior)

        # Renormalization of the posterior
        self.posterior /= np.sum(self.posterior)

        # Calculate the reward
        reward = self.posterior.max()

        self.t += 1
        if self.t == self.args.max_episode_length:
            is_final = True
        else:
            is_final = False

        # Next state for the policy model
        self.state = np.concatenate(
            (self.posterior, np.expand_dims(
                self.map_design, axis=0)), axis=0)

        return self.state, reward, is_final, int(curr_depth)

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def normalized_columns_initializer(weights, std=1.0):
    out = torch.randn(weights.size())
    out *= std / torch.sqrt(out.pow(2).sum(1, keepdim=True).expand_as(out))
    return out


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        weight_shape = list(m.weight.data.size())
        fan_in = np.prod(weight_shape[1:4])
        fan_out = np.prod(weight_shape[2:4]) * weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        weight_shape = list(m.weight.data.size())
        fan_in = weight_shape[1]
        fan_out = weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        m.bias.data.fill_(0)


class Localization_2D_A3C(torch.nn.Module):

    def __init__(self, args):
        super(Localization_2D_A3C, self).__init__()

        self.map_size = args.map_size

        num_orientations = 4
        num_actions = 3
        n_policy_conv1_filters = 16
        n_policy_conv2_filters = 16
        size_policy_conv1_filters = 3
        size_policy_conv2_filters = 3
        self.action_emb_dim = 8
        self.depth_emb_dim = 8
        self.time_emb_dim = 8
        self.action_hist_size = args.hist_size

        conv_out_height = (((self.map_size - size_policy_conv1_filters) + 1) -
                           size_policy_conv2_filters) + 1
        conv_out_width = (((self.map_size - size_policy_conv1_filters) + 1) -
                          size_policy_conv2_filters) + 1

        self.policy_conv1 = nn.Conv2d(num_orientations + 1,
                                      n_policy_conv1_filters,
                                      size_policy_conv1_filters,
                                      stride=1)
        self.policy_conv2 = nn.Conv2d(n_policy_conv1_filters,
                                      n_policy_conv2_filters,
                                      size_policy_conv2_filters,
                                      stride=1)

        self.action_emb_layer = nn.Embedding(num_actions + 1,
                                             self.action_emb_dim)
        self.depth_emb_layer = nn.Embedding(args.map_size,
                                            self.depth_emb_dim)
        self.time_emb_layer = nn.Embedding(args.max_episode_length + 1,
                                           self.time_emb_dim)

        self.proj_layer = nn.Linear(
            n_policy_conv2_filters * conv_out_height * conv_out_width, 256)
        self.critic_linear = nn.Linear(
            256 + self.action_emb_dim * self.action_hist_size +
            self.depth_emb_dim + self.time_emb_dim, 1)
        self.actor_linear = nn.Linear(
            256 + self.action_emb_dim * self.action_hist_size +
            self.depth_emb_dim + self.time_emb_dim, num_actions)

        self.apply(weights_init)
        self.actor_linear.weight.data = normalized_columns_initializer(
            self.actor_linear.weight.data, 0.01)
        self.actor_linear.bias.data.fill_(0)
        self.critic_linear.weight.data = normalized_columns_initializer(
            self.critic_linear.weight.data, 1.0)
        self.critic_linear.bias.data.fill_(0)

        self.train()

    def forward(self, inputs):
        inputs, (ax, dx, tx) = inputs
        conv_out = F.elu(self.policy_conv1(inputs))
        conv_out = F.elu(self.policy_conv2(conv_out))
        conv_out = conv_out.view(conv_out.size(0), -1)
        proj = self.proj_layer(conv_out)
        action_emb = self.action_emb_layer(ax)
        depth_emb = self.depth_emb_layer(dx)
        time_emb = self.time_emb_layer(tx)
        x = torch.cat((
            proj,
            action_emb.view(-1, self.action_emb_dim * self.action_hist_size),
            depth_emb.view(-1, self.depth_emb_dim),
            time_emb.view(-1, self.time_emb_dim)), 1)

        return self.critic_linear(x), self.actor_linear(x)

--------------------------------------------------------------------------------
/pretrained_models/m15_l20:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/pretrained_models/m15_l20

--------------------------------------------------------------------------------
/pretrained_models/m15_l40:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/pretrained_models/m15_l40

--------------------------------------------------------------------------------
/pretrained_models/m21_l30:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/pretrained_models/m21_l30

--------------------------------------------------------------------------------
/pretrained_models/m21_l60:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/pretrained_models/m21_l60

--------------------------------------------------------------------------------
/pretrained_models/m7_l15:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/pretrained_models/m7_l15

--------------------------------------------------------------------------------
/pretrained_models/m7_l30:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/pretrained_models/m7_l30

--------------------------------------------------------------------------------
/test_data/m15_n1000.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/test_data/m15_n1000.npy

--------------------------------------------------------------------------------
/test_data/m21_n1000.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/test_data/m21_n1000.npy

--------------------------------------------------------------------------------
/test_data/m7_n1000.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/test_data/m7_n1000.npy

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/utils/__init__.py

--------------------------------------------------------------------------------
/utils/localization.py:
--------------------------------------------------------------------------------
import os
import numpy as np

from .maze import *


def transition_function(belief_map, depth, action):
    (o, m, n) = belief_map.shape
    if action == 'TURN_RIGHT' or action == 1:
        belief_map = np.append(belief_map,
                               belief_map[0, :, :]
                               ).reshape(o + 1, m, n)
        belief_map = np.delete(belief_map, 0, axis=0)
    elif action == "TURN_LEFT" or action == 0:
        belief_map = np.insert(belief_map, 0,
                               belief_map[-1, :, :],
                               axis=0).reshape(o + 1, m, n)
        belief_map = np.delete(belief_map, -1, axis=0)
    elif action == "MOVE_FORWARD" or action == 2:
        if depth != 1:
            new_belief = np.zeros(belief_map.shape)
            for orientation in range(belief_map.shape[0]):
                B = belief_map[orientation]
                Bcap = shift_belief(B, orientation)
                new_belief[orientation, :, :] = Bcap
            belief_map = new_belief
    return belief_map


def get_all_likelihoods(map_design):
    num_orientations = 4
    all_likelihoods = np.zeros(
        [map_design.shape[0] - 2, num_orientations,
         map_design.shape[0], map_design.shape[1]])
    for orientation in range(num_orientations):
        for i, element in np.ndenumerate(all_likelihoods[0, orientation]):
            depth = get_depth(map_design, i, orientation)
            if depth > 0:
                all_likelihoods[int(depth) - 1][orientation][i] += 1
    return all_likelihoods


def shift_belief(B, orientation):
    if orientation == 0 or orientation == "east":
        Bcap = np.insert(
            B, 0, np.zeros(
                B.shape[1]), axis=1).reshape(
            B.shape[0], B.shape[1] + 1)
        Bcap = np.delete(Bcap, -1, axis=1)
    elif orientation == 2 or orientation == "west":
        Bcap = np.append(B, np.zeros([B.shape[1], 1]), axis=1).reshape(
            B.shape[0], B.shape[1] + 1)
        Bcap = np.delete(Bcap, 0, axis=1)
    elif orientation == 1 or orientation == "north":
        Bcap = np.append(B, np.zeros([1, B.shape[1]]), axis=0).reshape(
            B.shape[0] + 1, B.shape[1])
        Bcap = np.delete(Bcap, 0, axis=0)
    elif orientation == 3 or orientation == "south":
        Bcap = np.insert(
            B, 0, np.zeros(
                B.shape[1]), axis=0).reshape(
            B.shape[0] + 1, B.shape[1])
        Bcap = np.delete(Bcap, -1, axis=0)
    else:
        assert False, "Invalid orientation"
    return Bcap

--------------------------------------------------------------------------------
/utils/maze.py:
--------------------------------------------------------------------------------
import numpy as np
import numpy.random as npr
import itertools


def generate_map(maze_size, decimation=0.):
    """
    Generates a maze using Kruskal's algorithm
    https://en.wikipedia.org/wiki/Maze_generation_algorithm
    """
    m = (maze_size - 1) // 2
    n = (maze_size - 1) // 2

    maze = np.ones((maze_size, maze_size))
    for i, j in list(itertools.product(range(m), range(n))):
        maze[2 * i + 1, 2 * j + 1] = 0
    m = m - 1
    L = np.arange(n + 1)
    R = np.arange(n)
    L[n] = n - 1

    while m > 0:
        for i in range(n):
            j = L[i + 1]
            if (i != j and npr.randint(3) != 0):
                R[j] = R[i]
                L[R[j]] = j
                R[i] = i + 1
                L[R[i]] = i
                maze[2 * (n - m) - 1, 2 * i + 2] = 0
            if (i != L[i] and npr.randint(3) != 0):
                L[R[i]] = L[i]
                R[L[i]] = R[i]
                L[i] = i
                R[i] = i
            else:
                maze[2 * (n - m), 2 * i + 1] = 0
        m -= 1

    for i in range(n):
        j = L[i + 1]
        if (i != j and (i == L[i] or npr.randint(3) != 0)):
            R[j] = R[i]
            L[R[j]] = j
            R[i] = i + 1
            L[R[i]] = i
            maze[2 * (n - m) - 1, 2 * i + 2] = 0
        L[R[i]] = L[i]
        R[L[i]] = R[i]
        L[i] = i
        R[i] = i
    return maze


def get_depth(map_design, position, orientation):
    m, n = map_design.shape
    depth = 0
    new_tuple = position
    while (compare_tuples(new_tuple, tuple([m - 1, n - 1])) and
           compare_tuples(tuple([0, 0]), new_tuple)):
        if map_design[new_tuple] != 0:
            break
        else:
            new_tuple = get_tuple(new_tuple, orientation)
            depth += 1
    return depth


def get_next_state(map_design, position, orientation, action):
    m, n = map_design.shape
    if action == 'TURN_LEFT' or action == 0:
        orientation = (orientation + 1) % 4
    elif action == "TURN_RIGHT" or action == 1:
        orientation = (orientation - 1) % 4
    elif action == "MOVE_FORWARD" or action == 2:
        new_tuple = get_tuple(position, orientation)
        if compare_tuples(new_tuple, tuple([m - 1, n - 1])) and \
                compare_tuples(tuple([0, 0]), new_tuple) and \
                map_design[new_tuple] == 0:
            position = new_tuple
    return position, orientation


def get_random_position(map_design):
    m, n = map_design.shape
    while True:
        index = tuple([np.random.randint(m), np.random.randint(n)])
        if map_design[index] == 0:
            return index


def get_tuple(i, orientation):
    if orientation == 0 or orientation == "east":
        new_tuple = tuple([i[0], i[1] + 1])
    elif orientation == 2 or orientation == "west":
        new_tuple = tuple([i[0], i[1] - 1])
    elif orientation == 1 or orientation == "north":
        new_tuple = tuple([i[0] - 1, i[1]])
    elif orientation == 3 or orientation == "south":
        new_tuple = tuple([i[0] + 1, i[1]])
    else:
        assert False, "Invalid orientation"
    return new_tuple


def compare_tuples(a, b):
    """
    Returns True if every element of a is less than or equal to
    the corresponding element of b
    """
    assert len(a) == len(b), "Unequal lengths of tuples for comparison"
    for i in range(len(a)):
        if a[i] > b[i]:
            return False
    return True


if __name__ == '__main__':
    print(generate_map(7))
--------------------------------------------------------------------------------