├── README.md
├── a3c_main.py
├── a3c_test.py
├── a3c_train.py
├── generate_test_data.py
├── maze2d.py
├── model.py
├── pretrained_models
│   ├── m15_l20
│   ├── m15_l40
│   ├── m21_l30
│   ├── m21_l60
│   ├── m7_l15
│   └── m7_l30
├── test_data
│   ├── m15_n1000.npy
│   ├── m21_n1000.npy
│   └── m7_n1000.npy
└── utils
├── __init__.py
├── localization.py
└── maze.py
/README.md:
--------------------------------------------------------------------------------
# Active Neural Localization
This is a PyTorch implementation of the ICLR-18 paper:

[Active Neural Localization](https://arxiv.org/abs/1801.08214)
Devendra Singh Chaplot, Emilio Parisotto, Ruslan Salakhutdinov
Carnegie Mellon University

Project Website: https://devendrachaplot.github.io/projects/Neural-Localization

### This repository contains:
- Code for the Maze2D environment, which generates random 2D mazes for active localization.
- Code for training an Active Neural Localization agent in the Maze2D environment using A3C.

## Dependencies
- [PyTorch](http://pytorch.org) (v0.3)

## Usage

### Training
To train an Active Neural Localization A3C agent with 16 threads on 7x7 mazes with a maximum episode length of 30:
```
python a3c_main.py --num-processes 16 --map-size 7 --max-episode-length 30 --dump-location ./saved/ --test-data ./test_data/m7_n1000.npy
```
The code saves the best model at `./saved/model_best` and the training log at `./saved/train.log`. It uses `./test_data/m7_n1000.npy` as the test data and ensures that no maze from the test data is used during training.
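
Training can also be resumed from a previously saved model via the `--load` flag (the main process and each worker reload the weights before training continues), for example:
```
python a3c_main.py --num-processes 16 --map-size 7 --max-episode-length 30 --load ./saved/model_best --dump-location ./saved/ --test-data ./test_data/m7_n1000.npy
```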

### Evaluation
After training, the model can be evaluated using:
```
python a3c_main.py --num-processes 0 --evaluate 1 --map-size 7 --max-episode-length 30 --load ./saved/model_best --test-data ./test_data/m7_n1000.npy
```

### Pre-trained models
The `pretrained_models` directory contains pre-trained models for map-size 7 (max-episode-length 15 and 30), map-size 15 (max-episode-length 20 and 40) and map-size 21 (max-episode-length 30 and 60). The test data used for training these models is provided in the `test_data` directory.

To evaluate a pre-trained model on 15x15 mazes with a maximum episode length of 40:
```
python a3c_main.py --num-processes 0 --evaluate 1 --map-size 15 --max-episode-length 40 --load ./pretrained_models/m15_l40 --test-data ./test_data/m15_n1000.npy
```

### Generating test data
The repository contains test data for map sizes 7, 15 and 21, with 1000 mazes each, in the `test_data` directory.

For generating more test data:
```
python generate_test_data.py --map-size 7 --num-mazes 100 --test-data-location ./test_data/ --test-data-filename my_new_test_data.npy
```
This will generate a test data file at `test_data/my_new_test_data.npy` containing 100 7x7 mazes.
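
Each row of a test data file is a flattened maze plus the agent's starting pose. A minimal sketch of inspecting one (assuming map size 7, so each row has 7*7 + 3 = 52 values: the flattened map, the (row, col) start position, and the orientation):
```
import numpy as np

mazes = np.load("./test_data/m7_n1000.npy")
print(mazes.shape)  # (1000, 52) for 7x7 mazes

maze = mazes[0]
map_design = maze[:-3].reshape(7, 7)   # 1 = wall, 0 = free cell
position = (int(maze[-3]), int(maze[-2]))
orientation = int(maze[-1])            # 0: east, 1: north, 2: west, 3: south
print(map_design, position, orientation)
```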

### All arguments
All arguments for a3c_main.py:
```
  -h, --help            show this help message and exit
  -l L, --max-episode-length L
                        maximum length of an episode (default: 30)
  -m MAP_SIZE, --map-size MAP_SIZE
                        m: Size of the maze m x m (default: 7), must be an
                        odd natural number
  --lr LR               learning rate (default: 0.001)
  --num-iters NS        number of training iterations per training thread
                        (default: 1000000)
  --gamma G             discount factor for rewards (default: 0.99)
  --tau T               parameter for GAE (default: 1.00)
  --seed S              random seed (default: 1)
  -n N, --num-processes N
                        how many training processes to use (default: 8)
  --num-steps NS        number of forward steps in A3C (default: 20)
  --hist-size HIST_SIZE
                        action history size (default: 5)
  --load LOAD           model path to load, 0 to not reload (default: 0)
  -e EVALUATE, --evaluate EVALUATE
                        0:Train, 1:Evaluate on test data (default: 0)
  -d DUMP_LOCATION, --dump-location DUMP_LOCATION
                        path to dump models and log (default: ./saved/)
  -td TEST_DATA, --test-data TEST_DATA
                        Test data filepath (default: ./test_data/m7_n1000.npy)
```

## Cite as
>Chaplot, Devendra Singh, Parisotto, Emilio and Salakhutdinov, Ruslan.
Active Neural Localization.
In *International Conference on Learning Representations*, 2018.
([PDF](http://arxiv.org/abs/1801.08214))

### Bibtex:
```
@inproceedings{chaplot2018active,
  title={Active Neural Localization},
  author={Chaplot, Devendra Singh and Parisotto, Emilio and Salakhutdinov, Ruslan},
  booktitle={International Conference on Learning Representations},
  year={2018}
}
```

## Acknowledgements
The implementation of A3C is borrowed from https://github.com/ikostrikov/pytorch-a3c.
--------------------------------------------------------------------------------
/a3c_main.py:
--------------------------------------------------------------------------------
import argparse
import os
os.environ["OMP_NUM_THREADS"] = "1"
import sys
import signal
import torch
import torch.multiprocessing as mp

from maze2d import *
from model import *
from a3c_train import train
from a3c_test import test

import logging


parser = argparse.ArgumentParser(description='Active Neural Localization')

# Environment arguments
parser.add_argument('-l', '--max-episode-length', type=int,
                    default=30, metavar='L',
                    help='maximum length of an episode (default: 30)')
parser.add_argument('-m', '--map-size', type=int, default=7,
                    help='''m: Size of the maze m x m (default: 7),
                    must be an odd natural number''')

# A3C and model arguments
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                    help='learning rate (default: 0.001)')
parser.add_argument('--num-iters', type=int, default=1000000, metavar='NS',
                    help='''number of training iterations per training thread
                    (default: 1000000)''')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for rewards (default: 0.99)')
parser.add_argument('--tau', type=float, default=1.00, metavar='T',
                    help='parameter for GAE (default: 1.00)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('-n', '--num-processes', type=int, default=8, metavar='N',
                    help='how many training processes to use (default: 8)')
parser.add_argument('--num-steps', type=int, default=20, metavar='NS',
                    help='number of forward steps in A3C (default: 20)')
parser.add_argument('--hist-size', type=int, default=5,
                    help='action history size (default: 5)')
parser.add_argument('--load', type=str, default="0",
                    help='model path to load, 0 to not reload (default: 0)')
parser.add_argument('-e', '--evaluate', type=int, default=0,
                    help='0:Train, 1:Evaluate on test data (default: 0)')
parser.add_argument('-d', '--dump-location', type=str, default="./saved/",
                    help='path to dump models and log (default: ./saved/)')
parser.add_argument('-td', '--test-data', type=str,
                    default="./test_data/m7_n1000.npy",
                    help='''Test data filepath
                    (default: ./test_data/m7_n1000.npy)''')

if __name__ == '__main__':
    args = parser.parse_args()
    torch.manual_seed(args.seed)

    if not os.path.exists(args.dump_location):
        os.makedirs(args.dump_location)

    logging.basicConfig(
        filename=args.dump_location + 'train.log',
        level=logging.INFO)

    assert args.evaluate == 0 or args.num_processes == 0, \
        "Can't train while evaluating, either n=0 or e=0"

    shared_model = Localization_2D_A3C(args)

    if args.load != "0":
        shared_model.load_state_dict(torch.load(args.load))
    shared_model.share_memory()

    signal.signal(signal.SIGINT, signal.signal(signal.SIGINT, signal.SIG_IGN))
    processes = []

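    # A dedicated test/evaluation process; its rank equals args.num_processes,
    # so its seed never collides with a training thread's seed.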
    p = mp.Process(target=test, args=(args.num_processes, args, shared_model))
    p.start()
    processes.append(p)

    for rank in range(0, args.num_processes):
        p = mp.Process(target=train, args=(rank, args, shared_model))
        p.start()
        processes.append(p)

    try:
        for p in processes:
            p.join()
    except KeyboardInterrupt:
        print("Stopping training. " +
              "Best model stored at {}model_best".format(args.dump_location))
        for p in processes:
            p.terminate()
--------------------------------------------------------------------------------
/a3c_test.py:
--------------------------------------------------------------------------------
import time
import logging

from maze2d import *
from model import *
from collections import deque


def test(rank, args, shared_model):
    args.seed = args.seed + rank
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    env = Maze2D(args)
    action_hist_size = args.hist_size

    model = Localization_2D_A3C(args)
    if (args.load != "0"):
        print("Loading model {}".format(args.load))
        model.load_state_dict(torch.load(args.load))
    model.eval()

    reward_sum = 0
    episode_length = 0
    rewards_list = []
    accuracy_list = []
    best_reward = 0.0
    done = True

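    # During evaluation, report once after a full pass over the test mazes;
    # during training, report (and checkpoint) every 1000 test episodes.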
    if args.evaluate != 0:
        test_freq = env.test_mazes.shape[0]
    else:
        test_freq = 1000

    start_time = time.time()

    state, depth = env.reset()
    state = torch.from_numpy(state).float()

    while True:
        episode_length += 1
        if done:
            if (args.evaluate == 0):
                # Sync with the shared model
                model.load_state_dict(shared_model.state_dict())

            # filling action history with action 3 at the start of the episode
            action_hist = deque(
                [3] * action_hist_size,
                maxlen=action_hist_size)
            action_seq = []
        else:
            action_hist.append(action)

        ax = Variable(torch.from_numpy(np.array(action_hist)),
                      volatile=True)
        dx = Variable(torch.from_numpy(np.array([depth])).long(),
                      volatile=True)
        tx = Variable(torch.from_numpy(np.array([episode_length])).long(),
                      volatile=True)

        value, logit = model(
            (Variable(state.unsqueeze(0), volatile=True), (ax, dx, tx)))
        prob = F.softmax(logit, dim=1)
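        # The test-time policy is greedy: always take the most probable action.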
        action = prob.max(1)[1].data.numpy()[0]

        state, reward, done, depth = env.step(action)

        done = done or episode_length >= args.max_episode_length
        reward_sum += reward

        if done:
            rewards_list.append(reward_sum)
            if reward >= 1:
                accuracy = 1
            else:
                accuracy = 0
            accuracy_list.append(accuracy)

            if (len(rewards_list) >= test_freq):
                time_elapsed = time.gmtime(time.time() - start_time)
                print(" ".join([
                    "Time: {0:0=2d}d".format(time_elapsed.tm_mday - 1),
                    "{},".format(time.strftime("%Hh %Mm %Ss", time_elapsed)),
                    "Avg Reward: {0:.3f},".format(np.mean(rewards_list)),
                    "Avg Accuracy: {0:.3f},".format(np.mean(accuracy_list)),
                    "Best Reward: {0:.3f}".format(best_reward)]))
                logging.info(" ".join([
                    "Time: {0:0=2d}d".format(time_elapsed.tm_mday - 1),
                    "{},".format(time.strftime("%Hh %Mm %Ss", time_elapsed)),
                    "Avg Reward: {0:.3f},".format(np.mean(rewards_list)),
                    "Avg Accuracy: {0:.3f},".format(np.mean(accuracy_list)),
                    "Best Reward: {0:.3f}".format(best_reward)]))
                if args.evaluate != 0:
                    return
                elif (np.mean(rewards_list) >= best_reward):
                    torch.save(model.state_dict(),
                               args.dump_location + "model_best")
                    best_reward = np.mean(rewards_list)
                rewards_list = []
                accuracy_list = []

            reward_sum = 0
            episode_length = 0
            state, depth = env.reset()

        state = torch.from_numpy(state).float()
--------------------------------------------------------------------------------
/a3c_train.py:
--------------------------------------------------------------------------------
import torch.optim as optim
import logging

from maze2d import *
from model import *
from collections import deque


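# Lazily point the shared model's gradient buffers at this worker's
# gradients (pattern borrowed from ikostrikov/pytorch-a3c); once they are
# bound (grad is not None), there is nothing left to do.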
def ensure_shared_grads(model, shared_model):
    for param, shared_param in zip(
            model.parameters(), shared_model.parameters()):
        if shared_param.grad is not None:
            return
        shared_param._grad = param.grad


def train(rank, args, shared_model):
    args.seed = args.seed + rank
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    env = Maze2D(args)
    action_hist_size = args.hist_size

    model = Localization_2D_A3C(args)
    if (args.load != "0"):
        print("Training thread: {}, Loading model {}".format(rank, args.load))
        model.load_state_dict(torch.load(args.load))
    optimizer = optim.SGD(shared_model.parameters(), lr=args.lr)
    model.train()

    values = []
    log_probs = []
    p_losses = []
    v_losses = []

    episode_length = 0
    num_iters = 0
    done = True

    state, depth = env.reset()
    state = torch.from_numpy(state)
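    # num_iters counts thousands of gradient updates (it is incremented in
    # the logging block below), so each thread performs roughly
    # args.num_iters updates in total.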
    while num_iters < args.num_iters / 1000:
        # Sync with the shared model
        model.load_state_dict(shared_model.state_dict())
        if done:
            # filling action history with action 3 at the start of the episode
            action_hist = deque(
                [3] * action_hist_size,
                maxlen=action_hist_size)
            episode_length = 0
        else:
            action_hist.append(action)

        values = []
        log_probs = []
        rewards = []
        entropies = []

        for step in range(args.num_steps):
            episode_length += 1
            state = state.float()
            ax = Variable(torch.from_numpy(np.array(action_hist)))
            dx = Variable(torch.from_numpy(np.array([depth])).long())
            tx = Variable(torch.from_numpy(np.array([episode_length])).long())

            value, logit = model(
                (Variable(state.unsqueeze(0)), (ax, dx, tx)))
            prob = F.softmax(logit, dim=1)
            log_prob = F.log_softmax(logit, dim=1)
            entropy = -(log_prob * prob).sum(1)
            entropies.append(entropy)

            action = prob.multinomial().data

            log_prob = log_prob.gather(1, Variable(action))

            action = action.numpy()[0, 0]

            state, reward, done, depth = env.step(action)
            done = done or episode_length >= args.max_episode_length

            if done:
                episode_length = 0
                state, depth = env.reset()

            state = torch.from_numpy(state)
            values.append(value)
            log_probs.append(log_prob)
            rewards.append(reward)

            if done:
                break

        R = torch.zeros(1, 1)
        state = state.float()
        if not done:
            action_hist.append(action)
            ax = Variable(torch.from_numpy(np.array(action_hist)))
            dx = Variable(torch.from_numpy(np.array([depth])).long())
            tx = Variable(torch.from_numpy(np.array([episode_length])).long())
            value, _ = model((Variable(state.unsqueeze(0)), (ax, dx, tx)))
            R = value.data

        values.append(Variable(R))
        policy_loss = 0
        value_loss = 0
        R = Variable(R)
        gae = torch.zeros(1, 1)
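        # Walk the rollout backwards, accumulating the discounted return R
        # for the value loss and the GAE advantage for the policy loss.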
        for i in reversed(range(len(rewards))):
            R = args.gamma * R + rewards[i]
            advantage = R - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)

            # Generalized Advantage Estimation
            delta_t = rewards[i] + args.gamma * \
                values[i + 1].data - values[i].data
            gae = gae * args.gamma * args.tau + delta_t

            policy_loss = policy_loss - \
                log_probs[i] * Variable(gae) - 0.01 * entropies[i]

        optimizer.zero_grad()

        p_losses.append(policy_loss.data[0, 0])
        v_losses.append(value_loss.data[0, 0])

        if (len(p_losses) > 1000):
            num_iters += 1
            print(" ".join([
                "Training thread: {:2d},".format(rank),
                "Num iters: {:4d}K,".format(num_iters),
                "Avg policy loss: {0:+.3f},".format(np.mean(p_losses)),
                "Avg value loss: {0:+.3f}".format(np.mean(v_losses))]))
            logging.info(" ".join([
                "Training thread: {:2d},".format(rank),
                "Num iters: {:4d}K,".format(num_iters),
                "Avg policy loss: {0:+.3f},".format(np.mean(p_losses)),
                "Avg value loss: {0:+.3f}".format(np.mean(v_losses))]))
            p_losses = []
            v_losses = []

        (policy_loss + 0.5 * value_loss).backward()
        torch.nn.utils.clip_grad_norm(model.parameters(), 40)

        ensure_shared_grads(model, shared_model)
        optimizer.step()

    print("Training thread {} completed".format(rank))
    logging.info("Training thread {} completed".format(rank))
--------------------------------------------------------------------------------
/generate_test_data.py:
--------------------------------------------------------------------------------
import argparse
import os
from maze2d import *

import logging


parser = argparse.ArgumentParser(description='Generate test mazes')
parser.add_argument('-n', '--num-mazes', type=int, default=1000,
                    help='Number of mazes to generate (default: 1000)')
parser.add_argument('-m', '--map-size', type=int, default=7,
                    help='''m: Size of the maze m x m (default: 7),
                    must be an odd natural number''')
parser.add_argument('-tdl', '--test-data-location', type=str,
                    default="./test_data/",
                    help='Data location (default: ./test_data/)')
parser.add_argument('-tdf', '--test-data-filename', type=str,
                    default="m7_n1000.npy",
                    help='Data filename (default: m7_n1000.npy)')


if __name__ == '__main__':
    args = parser.parse_args()
    test_mazes = []

    if not os.path.exists(args.test_data_location):
        os.makedirs(args.test_data_location)

    while len(test_mazes) < args.num_mazes:
        map_design = generate_map(args.map_size)
        position = np.array(get_random_position(map_design))
        orientation = np.array([np.random.randint(4)])

        maze = np.concatenate((map_design.flatten(), position, orientation))

        # Make sure the maze doesn't exist in the test mazes already
        if not any((maze == x).all() for x in test_mazes):
            # Make sure map is not symmetric
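            # (a map equal to its own 90- or 180-degree rotation admits
            # multiple indistinguishable poses, making exact localization
            # impossible)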
            if not (map_design == np.rot90(map_design)).all() and \
                    not (map_design == np.rot90(np.rot90(map_design))).all():
                test_mazes.append(maze)

    filepath = os.path.join(args.test_data_location, args.test_data_filename)
    np.save(filepath, test_mazes)
--------------------------------------------------------------------------------
/maze2d.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from utils.maze import *
from utils.localization import *


class Maze2D(object):

    def __init__(self, args):
        self.args = args
        self.test_mazes = np.load(args.test_data)
        self.test_maze_idx = 0
        return

    def reset(self):

        # Load a test maze during evaluation
        if self.args.evaluate != 0:
            maze_in_test_data = False
            maze = self.test_mazes[self.test_maze_idx]
            self.orientation = int(maze[-1])
            self.position = (int(maze[-3]), int(maze[-2]))
            self.map_design = maze[:-3].reshape(self.args.map_size,
                                                self.args.map_size)
            self.test_maze_idx += 1
        else:
            maze_in_test_data = True

        # Generate a maze
        while maze_in_test_data:
            # Generate a map design
            self.map_design = generate_map(self.args.map_size)

            # Get random initial position and orientation of the agent
            self.position = get_random_position(self.map_design)
            self.orientation = np.random.randint(4)

            maze = np.concatenate((self.map_design.flatten(),
                                   np.array(self.position),
                                   np.array([self.orientation])))

            # Make sure the maze doesn't exist in the test mazes
            if not any((maze == x).all() for x in self.test_mazes):
                # Make sure map is not symmetric
                if not (self.map_design ==
                        np.rot90(self.map_design)).all() \
                        and not (self.map_design ==
                                 np.rot90(np.rot90(self.map_design))).all():
                    maze_in_test_data = False

        # Pre-compute likelihoods of all observations on the map for efficiency
        self.likelihoods = get_all_likelihoods(self.map_design)

        # Get current observation and likelihood
        curr_depth = get_depth(self.map_design, self.position,
                               self.orientation)
        curr_likelihood = self.likelihoods[int(curr_depth) - 1]

        # Posterior is just the likelihood as prior is uniform
        # (copied so that the in-place renormalization below does not
        # modify the cached likelihoods)
        self.posterior = curr_likelihood.copy()

        # Renormalization of the posterior
        self.posterior /= np.sum(self.posterior)
        self.t = 0

        # next state for the policy model
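        # (4 orientation channels of the belief plus 1 channel for the map
        # design; matches the 5 input channels of policy_conv1 in model.py)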
        self.state = np.concatenate((self.posterior, np.expand_dims(
            self.map_design, axis=0)), axis=0)
        return self.state, int(curr_depth)

    def step(self, action_id):
        # Get the observation before taking the action
        curr_depth = get_depth(self.map_design, self.position,
                               self.orientation)

        # Posterior from last step is the prior for this step
        prior = self.posterior

        # Transform the prior according to the action taken
        prior = transition_function(prior, curr_depth, action_id)

        # Calculate position and orientation after taking the action
        self.position, self.orientation = get_next_state(
            self.map_design, self.position, self.orientation, action_id)

        # Get the observation and likelihood after taking the action
        curr_depth = get_depth(
            self.map_design, self.position, self.orientation)
        curr_likelihood = self.likelihoods[int(curr_depth) - 1]

        # Posterior = Prior * Likelihood
        self.posterior = np.multiply(curr_likelihood, prior)

        # Renormalization of the posterior
        self.posterior /= np.sum(self.posterior)

        # Calculate the reward
        reward = self.posterior.max()

        self.t += 1
        if self.t == self.args.max_episode_length:
            is_final = True
        else:
            is_final = False

        # next state for the policy model
        self.state = np.concatenate(
            (self.posterior, np.expand_dims(
                self.map_design, axis=0)), axis=0)

        return self.state, reward, is_final, int(curr_depth)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def normalized_columns_initializer(weights, std=1.0):
    out = torch.randn(weights.size())
    out *= std / torch.sqrt(out.pow(2).sum(1, keepdim=True).expand_as(out))
    return out


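# Xavier/Glorot-style uniform initialization for convolutional and linear
# layers.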
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        weight_shape = list(m.weight.data.size())
        fan_in = np.prod(weight_shape[1:4])
        fan_out = np.prod(weight_shape[2:4]) * weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        weight_shape = list(m.weight.data.size())
        fan_in = weight_shape[1]
        fan_out = weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        m.bias.data.fill_(0)


class Localization_2D_A3C(torch.nn.Module):

    def __init__(self, args):
        super(Localization_2D_A3C, self).__init__()

        self.map_size = args.map_size

        num_orientations = 4
        num_actions = 3
        n_policy_conv1_filters = 16
        n_policy_conv2_filters = 16
        size_policy_conv1_filters = 3
        size_policy_conv2_filters = 3
        self.action_emb_dim = 8
        self.depth_emb_dim = 8
        self.time_emb_dim = 8
        self.action_hist_size = args.hist_size

        conv_out_height = (((self.map_size - size_policy_conv1_filters) + 1) -
                           size_policy_conv2_filters) + 1
        conv_out_width = (((self.map_size - size_policy_conv1_filters) + 1) -
                          size_policy_conv2_filters) + 1

        self.policy_conv1 = nn.Conv2d(num_orientations + 1,
                                      n_policy_conv1_filters,
                                      size_policy_conv1_filters,
                                      stride=1)
        self.policy_conv2 = nn.Conv2d(n_policy_conv1_filters,
                                      n_policy_conv2_filters,
                                      size_policy_conv2_filters,
                                      stride=1)

        self.action_emb_layer = nn.Embedding(num_actions + 1,
                                             self.action_emb_dim)
        self.depth_emb_layer = nn.Embedding(args.map_size,
                                            self.depth_emb_dim)
        self.time_emb_layer = nn.Embedding(args.max_episode_length + 1,
                                           self.time_emb_dim)

        self.proj_layer = nn.Linear(
            n_policy_conv2_filters * conv_out_height * conv_out_width, 256)
        self.critic_linear = nn.Linear(
            256 + self.action_emb_dim * self.action_hist_size +
            self.depth_emb_dim + self.time_emb_dim, 1)
        self.actor_linear = nn.Linear(
            256 + self.action_emb_dim * self.action_hist_size +
            self.depth_emb_dim + self.time_emb_dim, num_actions)

        self.apply(weights_init)
        self.actor_linear.weight.data = normalized_columns_initializer(
            self.actor_linear.weight.data, 0.01)
        self.actor_linear.bias.data.fill_(0)
        self.critic_linear.weight.data = normalized_columns_initializer(
            self.critic_linear.weight.data, 1.0)
        self.critic_linear.bias.data.fill_(0)

        self.train()

    def forward(self, inputs):
        inputs, (ax, dx, tx) = inputs
        conv_out = F.elu(self.policy_conv1(inputs))
        conv_out = F.elu(self.policy_conv2(conv_out))
        conv_out = conv_out.view(conv_out.size(0), -1)
        proj = self.proj_layer(conv_out)
        action_emb = self.action_emb_layer(ax)
        depth_emb = self.depth_emb_layer(dx)
        time_emb = self.time_emb_layer(tx)
        x = torch.cat((
            proj,
            action_emb.view(-1, self.action_emb_dim * self.action_hist_size),
            depth_emb.view(-1, self.depth_emb_dim),
            time_emb.view(-1, self.time_emb_dim)), 1)

        return self.critic_linear(x), self.actor_linear(x)
--------------------------------------------------------------------------------
/pretrained_models/m15_l20:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/pretrained_models/m15_l20
--------------------------------------------------------------------------------
/pretrained_models/m15_l40:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/pretrained_models/m15_l40
--------------------------------------------------------------------------------
/pretrained_models/m21_l30:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/pretrained_models/m21_l30
--------------------------------------------------------------------------------
/pretrained_models/m21_l60:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/pretrained_models/m21_l60
--------------------------------------------------------------------------------
/pretrained_models/m7_l15:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/pretrained_models/m7_l15
--------------------------------------------------------------------------------
/pretrained_models/m7_l30:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/pretrained_models/m7_l30
--------------------------------------------------------------------------------
/test_data/m15_n1000.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/test_data/m15_n1000.npy
--------------------------------------------------------------------------------
/test_data/m21_n1000.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/test_data/m21_n1000.npy
--------------------------------------------------------------------------------
/test_data/m7_n1000.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/test_data/m7_n1000.npy
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devendrachaplot/Neural-Localization/74a184f7384599b16376d94ec236d965067bc518/utils/__init__.py
--------------------------------------------------------------------------------
/utils/localization.py:
--------------------------------------------------------------------------------
import os
import numpy as np

from .maze import *


def transition_function(belief_map, depth, action):
    (o, m, n) = belief_map.shape
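    # Turning cyclically permutes the orientation channels of the belief;
    # moving forward shifts each orientation channel one cell in its facing
    # direction, unless a wall is directly ahead (depth == 1), in which case
    # the agent cannot move and the belief is unchanged.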
    if action == 'TURN_RIGHT' or action == 1:
        belief_map = np.append(belief_map,
                               belief_map[0, :, :]
                               ).reshape(o + 1, m, n)
        belief_map = np.delete(belief_map, 0, axis=0)
    elif action == "TURN_LEFT" or action == 0:
        belief_map = np.insert(belief_map, 0,
                               belief_map[-1, :, :],
                               axis=0).reshape(o + 1, m, n)
        belief_map = np.delete(belief_map, -1, axis=0)
    elif action == "MOVE_FORWARD" or action == 2:
        if depth != 1:
            new_belief = np.zeros(belief_map.shape)
            for orientation in range(belief_map.shape[0]):
                B = belief_map[orientation]
                Bcap = shift_belief(B, orientation)
                new_belief[orientation, :, :] = Bcap
            belief_map = new_belief
    return belief_map


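# all_likelihoods[d - 1][o][x, y] == 1 iff an agent at cell (x, y) facing
# orientation o observes depth d; on an s x s map the maximum observable
# depth is s - 2, hence the size of the first dimension.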
def get_all_likelihoods(map_design):
    num_orientations = 4
    all_likelihoods = np.zeros(
        [map_design.shape[0] - 2, num_orientations,
         map_design.shape[0], map_design.shape[1]])
    for orientation in range(num_orientations):
        for i, element in np.ndenumerate(all_likelihoods[0, orientation]):
            depth = get_depth(map_design, i, orientation)
            if depth > 0:
                all_likelihoods[int(depth) - 1][orientation][i] += 1
    return all_likelihoods


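# Shift a single-orientation belief plane by one cell in the direction
# faced, filling the vacated row/column with zeros.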
def shift_belief(B, orientation):
    if orientation == 0 or orientation == "east":
        Bcap = np.insert(
            B, 0, np.zeros(
                B.shape[1]), axis=1).reshape(
            B.shape[0], B.shape[1] + 1)
        Bcap = np.delete(Bcap, -1, axis=1)
    elif orientation == 2 or orientation == "west":
        Bcap = np.append(B, np.zeros([B.shape[1], 1]), axis=1).reshape(
            B.shape[0], B.shape[1] + 1)
        Bcap = np.delete(Bcap, 0, axis=1)
    elif orientation == 1 or orientation == "north":
        Bcap = np.append(B, np.zeros([1, B.shape[1]]), axis=0).reshape(
            B.shape[0] + 1, B.shape[1])
        Bcap = np.delete(Bcap, 0, axis=0)
    elif orientation == 3 or orientation == "south":
        Bcap = np.insert(
            B, 0, np.zeros(
                B.shape[1]), axis=0).reshape(
            B.shape[0] + 1, B.shape[1])
        Bcap = np.delete(Bcap, -1, axis=0)
    else:
        assert False, "Invalid orientation"
    return Bcap
--------------------------------------------------------------------------------
/utils/maze.py:
--------------------------------------------------------------------------------
import numpy as np
import numpy.random as npr
import itertools


def generate_map(maze_size, decimation=0.):
7 | """
8 | Generates a maze using Kruskal's algorithm
9 | https://en.wikipedia.org/wiki/Maze_generation_algorithm
10 | """
    m = (maze_size - 1)//2
    n = (maze_size - 1)//2

    maze = np.ones((maze_size, maze_size))
    for i, j in list(itertools.product(range(m), range(n))):
        maze[2*i+1, 2*j+1] = 0
    m = m - 1
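    # L and R link cells of the current row that belong to the same
    # connected set (the per-row state of Eller's algorithm).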
    L = np.arange(n+1)
    R = np.arange(n)
    L[n] = n-1

    while m > 0:
        for i in range(n):
            j = L[i+1]
            if (i != j and npr.randint(3) != 0):
                R[j] = R[i]
                L[R[j]] = j
                R[i] = i + 1
                L[R[i]] = i
                maze[2*(n-m)-1, 2*i+2] = 0
            if (i != L[i] and npr.randint(3) != 0):
                L[R[i]] = L[i]
                R[L[i]] = R[i]
                L[i] = i
                R[i] = i
            else:
                maze[2*(n-m), 2*i+1] = 0
        m -= 1

    for i in range(n):
        j = L[i+1]
        if (i != j and (i == L[i] or npr.randint(3) != 0)):
            R[j] = R[i]
            L[R[j]] = j
            R[i] = i+1
            L[R[i]] = i
            maze[2*(n-m)-1, 2*i+2] = 0
        L[R[i]] = L[i]
        R[L[i]] = R[i]
        L[i] = i
        R[i] = i
    return maze


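# Distance (in cells) from `position` to the nearest wall along the facing
# orientation; a wall directly ahead gives depth 1. This is the agent's
# depth observation.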
def get_depth(map_design, position, orientation):
    m, n = map_design.shape
    depth = 0
    new_tuple = position
    while (compare_tuples(new_tuple, tuple([m - 1, n - 1])) and
           compare_tuples(tuple([0, 0]), new_tuple)):
        if map_design[new_tuple] != 0:
            break
        else:
            new_tuple = get_tuple(new_tuple, orientation)
            depth += 1
    return depth


def get_next_state(map_design, position, orientation, action):
    m, n = map_design.shape
    if action == 'TURN_LEFT' or action == 0:
        orientation = (orientation + 1) % 4
    elif action == "TURN_RIGHT" or action == 1:
        orientation = (orientation - 1) % 4
    elif action == "MOVE_FORWARD" or action == 2:
        new_tuple = get_tuple(position, orientation)
        if compare_tuples(new_tuple, tuple([m - 1, n - 1])) and \
                compare_tuples(tuple([0, 0]), new_tuple) and \
                map_design[new_tuple] == 0:
            position = new_tuple
    return position, orientation


def get_random_position(map_design):
    m, n = map_design.shape
    while True:
        index = tuple([np.random.randint(m), np.random.randint(n)])
        if map_design[index] == 0:
            return index


def get_tuple(i, orientation):
    if orientation == 0 or orientation == "east":
        new_tuple = tuple([i[0], i[1] + 1])
    elif orientation == 2 or orientation == "west":
        new_tuple = tuple([i[0], i[1] - 1])
    elif orientation == 1 or orientation == "north":
        new_tuple = tuple([i[0] - 1, i[1]])
    elif orientation == 3 or orientation == "south":
        new_tuple = tuple([i[0] + 1, i[1]])
    else:
        assert False, "Invalid orientation"
    return new_tuple


def compare_tuples(a, b):
    """
    Returns True if every element of a is less than or equal to
    the corresponding element of b
    """
    assert len(a) == len(b), "Unequal lengths of tuples for comparison"
    for i in range(len(a)):
        if a[i] > b[i]:
            return False
    return True


if __name__ == '__main__':
    print(generate_map(7))
--------------------------------------------------------------------------------