├── .gitignore
├── LICENSE
├── README.md
├── dataset
│   ├── README.MD
│   ├── __init__.py
│   ├── dataset.py
│   └── make_training_data.py
├── domains
│   ├── __init__.py
│   └── gridworld.py
├── download_weights_and_datasets.sh
├── generators
│   ├── __init__.py
│   └── obstacle_gen.py
├── model.py
├── requirements.txt
├── results
│   ├── 16x16_1.png
│   ├── 16x16_2.png
│   ├── 28x28_1.png
│   ├── 28x28_2.png
│   ├── 8x8_1.png
│   └── 8x8_2.png
├── test.py
├── train.py
├── trained
│   └── README.md
└── utility
    ├── __init__.py
    └── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # npz 107 | *.npz 108 | 109 | # pth 110 | *.pth 111 | 112 | # jetbrains project settings 113 | .idea 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, Kent Sommer 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VIN: [Value Iteration Networks](https://arxiv.org/abs/1602.02867) 2 | 3 | ![Architecture of Value Iteration Network](https://ai2-s2-public.s3.amazonaws.com/figures/2016-11-08/024f01f390ba94cbee81e82e979a65151d20a6fd/3-Figure2-1.png) 4 | 5 | ## A quick thank you 6 | A few others have released amazing related work which helped inspire and improve my own implementation. It goes without saying that this release would not be nearly as good if it were not for all of the following: 7 | * [@avivt](https://github.com/avivt) ([Paper Author](https://arxiv.org/abs/1602.02867), [MATLAB implementation](https://github.com/avivt/VIN)) 8 | * [@zuoxingdong](https://github.com/zuoxingdong) ([Tensorflow implementation](https://github.com/zuoxingdong/VIN_TensorFlow), [Pytorch implementation](https://github.com/zuoxingdong/VIN_PyTorch_Visdom)) 9 | * [@TheAbhiKumar](https://github.com/TheAbhiKumar) ([Tensorflow implementation](https://github.com/TheAbhiKumar/tensorflow-value-iteration-networks)) 10 | * [@onlytailei](https://github.com/onlytailei) ([Pytorch implementation](https://github.com/onlytailei/Value-Iteration-Networks-PyTorch)) 11 | 12 | ## Why another VIN implementation? 13 | 1. The Pytorch VIN model in this repository is, in my opinion, more readable and closer to the original Theano implementation than others I have found (both Tensorflow and Pytorch). 14 | 2. This is not simply an implementation of the VIN model in Pytorch; it is also a full Python implementation of the gridworld environments used in the [original MATLAB implementation](https://github.com/avivt/VIN). 15 | 3. It provides a more extensible research base for others to build on without needing to jump through the possible MATLAB paywall. 16 | 17 | ## Installation 18 | This repository requires the following packages: 19 | - [SciPy](https://www.scipy.org/install.html) >= 0.19.0 20 | - [Python](https://www.python.org/) >= 2.7 (if using Python 3.x: python3-tk should be installed) 21 | - [NumPy](https://pypi.python.org/pypi/numpy) >= 1.12.1 22 | - [Matplotlib](https://matplotlib.org/users/installing.html) >= 2.0.0 23 | - [PyTorch](http://pytorch.org/) >= 0.1.11 24 | 25 | Use `pip` to install the necessary dependencies: 26 | ``` 27 | pip install -U -r requirements.txt 28 | ``` 29 | Note that PyTorch cannot be installed directly from PyPI; refer to http://pytorch.org/ for custom installation instructions specific to your needs. 
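## Downloading datasets and pretrained weights
Both the gridworld datasets and pretrained model weights can be fetched with the helper script at the repository root; it downloads the `.npz` dataset files into `dataset/` and the `.pth` weight files into `trained/` (see `download_weights_and_datasets.sh` below):
```bash
bash download_weights_and_datasets.sh
```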
30 | ## How to train 31 | #### 8x8 gridworld 32 | ```bash 33 | python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128 34 | ``` 35 | #### 16x16 gridworld 36 | ```bash 37 | python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 30 --k 20 --batch_size 128 38 | ``` 39 | #### 28x28 gridworld 40 | ```bash 41 | python train.py --datafile dataset/gridworld_28x28.npz --imsize 28 --lr 0.002 --epochs 30 --k 36 --batch_size 128 42 | ``` 43 | **Flags**: 44 | - `datafile`: The path to the data file. 45 | - `imsize`: The size of input images. One of: [8, 16, 28] 46 | - `lr`: Learning rate with RMSProp optimizer. Recommended: [0.01, 0.005, 0.002, 0.001] 47 | - `epochs`: Number of epochs to train. Default: 30 48 | - `k`: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28] 49 | - `l_i`: Number of channels in the input layer. Default: 2, i.e. obstacle image and goal image. 50 | - `l_h`: Number of channels in the first convolutional layer. Default: 150, as described in the paper. 51 | - `l_q`: Number of channels in the q layer (~actions) in the VI module. Default: 10, as described in the paper. 52 | - `batch_size`: Batch size. Default: 128 53 | 54 | ## How to test / visualize paths (requires training first) 55 | #### 8x8 gridworld 56 | ```bash 57 | python test.py --weights trained/vin_8x8.pth --imsize 8 --k 10 58 | ``` 59 | #### 16x16 gridworld 60 | ```bash 61 | python test.py --weights trained/vin_16x16.pth --imsize 16 --k 20 62 | ``` 63 | #### 28x28 gridworld 64 | ```bash 65 | python test.py --weights trained/vin_28x28.pth --imsize 28 --k 36 66 | ``` 67 | To visualize the optimal and predicted paths, simply pass: 68 | ```bash 69 | --plot 70 | ``` 71 | 72 | **Flags**: 73 | - `weights`: Path to trained weights. 74 | - `imsize`: The size of input images. One of: [8, 16, 28] 75 | - `plot`: If supplied, the optimal and predicted paths will be plotted. 76 | - `k`: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28] 77 | - `l_i`: Number of channels in the input layer. Default: 2, i.e. obstacle image and goal image. 78 | - `l_h`: Number of channels in the first convolutional layer. Default: 150, as described in the paper. 79 | - `l_q`: Number of channels in the q layer (~actions) in the VI module. Default: 10, as described in the paper. 80 | 81 | ## Results 82 | Gridworld | Sample One | Sample Two 83 | -- | --- | --- 84 | 8x8 | ![8x8 sample one](results/8x8_1.png) | ![8x8 sample two](results/8x8_2.png) 85 | 16x16 | ![16x16 sample one](results/16x16_1.png) | ![16x16 sample two](results/16x16_2.png) 86 | 28x28 | ![28x28 sample one](results/28x28_1.png) | ![28x28 sample two](results/28x28_2.png) 87 | 88 | ## Datasets 89 | Each data sample consists of an obstacle image and a goal image, followed by the (x, y) coordinates of the current state in the gridworld. 90 | 91 | Dataset size | 8x8 | 16x16 | 28x28 92 | -- | -- | -- | -- 93 | Train set | 81337 | 456309 | 1529584 94 | Test set | 13846 | 77203 | 251755 95 | 96 | The datasets (8x8, 16x16, and 28x28) included in this repository can be reproduced using the `dataset/make_training_data.py` script. Note that this script is not optimized and runs rather slowly (it also uses a lot of memory :D). 97 | 98 | ## Performance: Success Rate 99 | This is the success rate from rollouts of the learned policy in the environment (taken over 5000 randomly generated domains). 100 | 101 | Success Rate | 8x8 | 16x16 | 28x28 102 | -- | -- | -- | -- 103 | PyTorch | 99.69% | 96.99% | 91.07% 104 | 105 | ## Performance: Test Accuracy 106 | 107 | **NOTE**: This is the **accuracy on the test set**. It differs from the table in the paper, which reports the **success rate** from rollouts of the learned policy in the environment. 
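For intuition, here is a minimal, self-contained sketch of how this metric is computed; it mirrors `test()` in `train.py`, but the tensors are random stand-ins rather than real network outputs and labels:
```python
import torch

# Random stand-ins: a batch of 128 samples over the 8 gridworld actions
outputs = torch.randn(128, 8)          # network logits
labels = torch.randint(0, 8, (128,))   # expert action labels
# Accuracy = fraction of samples whose argmax action matches the label
predicted = outputs.argmax(dim=1)
accuracy = 100.0 * (predicted == labels).float().mean().item()
print("Test Accuracy: {:.2f}%".format(accuracy))
```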
108 | 109 | Test Accuracy | 8x8 | 16x16 | 28x28 110 | -- | -- | -- | -- 111 | PyTorch | 99.83% | 94.84% | 88.54% 112 | -------------------------------------------------------------------------------- /dataset/README.MD: -------------------------------------------------------------------------------- 1 | # Gridworld datasets 2 | To use the gridworld datasets you have two choices: 3 | 1. Download and place the .npz dataset files here 4 | * gridworld_8x8.npz 5 | * gridworld_16x16.npz 6 | * gridworld_28x28.npz 7 | 2. Use the dataset generation script 8 | * ```make_training_data.py``` 9 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.utils.data as data 5 | 6 | 7 | class GridworldData(data.Dataset): 8 | def __init__(self, 9 | file, 10 | imsize, 11 | train=True, 12 | transform=None, 13 | target_transform=None): 14 | assert file.endswith('.npz') # Must be .npz format 15 | self.file = file 16 | self.imsize = imsize 17 | self.transform = transform 18 | self.target_transform = target_transform 19 | self.train = train # Training set or test set 20 | 21 | self.images, self.S1, self.S2, self.labels = \ 22 | self._process(file, self.train) 23 | 24 | def __getitem__(self, index): 25 | img = self.images[index] 26 | s1 = self.S1[index] 27 | s2 = self.S2[index] 28 | label = self.labels[index] 29 | # Apply transform if we have one 30 | if self.transform is not None: 31 | img = self.transform(img) 32 | else: # Internal default transform: Just to Tensor 33 | img = torch.from_numpy(img) 34 | # Apply target transform if we have one 35 | if self.target_transform is not None: 36 | label = self.target_transform(label) 37 | return img, int(s1), int(s2), int(label) 38 | 39 | def __len__(self): 40 | return self.images.shape[0] 41 | 42 | def _process(self, file, train): 43 | """Data format: A list, [train data, test data] 44 | Each data sample: label, S1, S2, Images, in this order. 
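            The .npz archive stores eight arrays in fixed positional order:
            arr_0..arr_3 hold the train images, S1, S2, and labels, while
            arr_4..arr_7 hold the same four arrays for the test split
            (this is exactly the layout read below).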
45 | """ 46 | with np.load(file, mmap_mode='r') as f: 47 | if train: 48 | images = f['arr_0'] 49 | S1 = f['arr_1'] 50 | S2 = f['arr_2'] 51 | labels = f['arr_3'] 52 | else: 53 | images = f['arr_4'] 54 | S1 = f['arr_5'] 55 | S2 = f['arr_6'] 56 | labels = f['arr_7'] 57 | # Set proper datatypes 58 | images = images.astype(np.float32) 59 | S1 = S1.astype(int) # (S1, S2) location are integers 60 | S2 = S2.astype(int) 61 | labels = labels.astype(int) # Labels are integers 62 | # Print number of samples 63 | if train: 64 | print("Number of Train Samples: {0}".format(images.shape[0])) 65 | else: 66 | print("Number of Test Samples: {0}".format(images.shape[0])) 67 | return images, S1, S2, labels 68 | -------------------------------------------------------------------------------- /dataset/make_training_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy as np 4 | from dataset import * 5 | 6 | import argparse 7 | 8 | sys.path.append('.') 9 | from domains.gridworld import * 10 | from generators.obstacle_gen import * 11 | sys.path.remove('.') 12 | 13 | 14 | def extract_action(traj): 15 | # Given a trajectory, outputs a 1D vector of 16 | # actions corresponding to the trajectory. 17 | n_actions = 8 18 | action_vecs = np.asarray([[-1., 0.], [1., 0.], [0., 1.], [0., -1.], 19 | [-1., 1.], [-1., -1.], [1., 1.], [1., -1.]]) 20 | action_vecs[4:] = 1 / np.sqrt(2) * action_vecs[4:] 21 | action_vecs = action_vecs.T 22 | state_diff = np.diff(traj, axis=0) 23 | norm_state_diff = state_diff * np.tile( 24 | 1 / np.sqrt(np.sum(np.square(state_diff), axis=1)), (2, 1)).T 25 | prj_state_diff = np.dot(norm_state_diff, action_vecs) 26 | actions_one_hot = np.abs(prj_state_diff - 1) < 0.00001 27 | actions = np.dot(actions_one_hot, np.arange(n_actions).T) 28 | return actions 29 | 30 | 31 | def make_data(dom_size, n_domains, max_obs, max_obs_size, n_traj, 32 | state_batch_size): 33 | 34 | X_l = [] 35 | S1_l = [] 36 | S2_l = [] 37 | Labels_l = [] 38 | 39 | dom = 0.0 40 | while dom <= n_domains: 41 | goal = [np.random.randint(dom_size[0]), np.random.randint(dom_size[1])] 42 | # Generate obstacle map 43 | obs = obstacles([dom_size[0], dom_size[1]], goal, max_obs_size) 44 | # Add obstacles to map 45 | n_obs = obs.add_n_rand_obs(max_obs) 46 | # Add border to map 47 | border_res = obs.add_border() 48 | # Ensure we have valid map 49 | if n_obs == 0 or not border_res: 50 | continue 51 | # Get final map 52 | im = obs.get_final() 53 | # Generate gridworld from obstacle map 54 | G = GridWorld(im, goal[0], goal[1]) 55 | # Get value prior 56 | value_prior = G.t_get_reward_prior() 57 | # Sample random trajectories to our goal 58 | states_xy, states_one_hot = sample_trajectory(G, n_traj) 59 | for i in range(n_traj): 60 | if len(states_xy[i]) > 1: 61 | # Get optimal actions for each state 62 | actions = extract_action(states_xy[i]) 63 | ns = states_xy[i].shape[0] - 1 64 | # Invert domain image => 0 = free, 1 = obstacle 65 | image = 1 - im 66 | # Resize domain and goal images and concate 67 | image_data = np.resize(image, (1, 1, dom_size[0], dom_size[1])) 68 | value_data = np.resize(value_prior, 69 | (1, 1, dom_size[0], dom_size[1])) 70 | iv_mixed = np.concatenate((image_data, value_data), axis=1) 71 | X_current = np.tile(iv_mixed, (ns, 1, 1, 1)) 72 | # Resize states 73 | S1_current = np.expand_dims(states_xy[i][0:ns, 0], axis=1) 74 | S2_current = np.expand_dims(states_xy[i][0:ns, 1], axis=1) 75 | # Resize labels 76 | Labels_current = np.expand_dims(actions, axis=1) 77 
| # Append to output list 78 | X_l.append(X_current) 79 | S1_l.append(S1_current) 80 | S2_l.append(S2_current) 81 | Labels_l.append(Labels_current) 82 | dom += 1 83 | sys.stdout.write("\r" + str(int((dom / n_domains) * 100)) + "%") 84 | sys.stdout.flush() 85 | sys.stdout.write("\n") 86 | # Concat all outputs 87 | X_f = np.concatenate(X_l) 88 | S1_f = np.concatenate(S1_l) 89 | S2_f = np.concatenate(S2_l) 90 | Labels_f = np.concatenate(Labels_l) 91 | return X_f, S1_f, S2_f, Labels_f 92 | 93 | 94 | def main(dom_size=(28, 28), 95 | n_domains=5000, 96 | max_obs=50, 97 | max_obs_size=2, 98 | n_traj=7, 99 | state_batch_size=1): 100 | # Get path to save dataset 101 | save_path = "dataset/gridworld_{0}x{1}".format(dom_size[0], dom_size[1]) 102 | # Get training data 103 | print("Now making training data...") 104 | X_out_tr, S1_out_tr, S2_out_tr, Labels_out_tr = make_data( 105 | dom_size, n_domains, max_obs, max_obs_size, n_traj, state_batch_size) 106 | # Get testing data 107 | print("\nNow making testing data...") 108 | X_out_ts, S1_out_ts, S2_out_ts, Labels_out_ts = make_data( 109 | dom_size, n_domains / 6, max_obs, max_obs_size, n_traj, 110 | state_batch_size) 111 | # Save dataset 112 | np.savez_compressed(save_path, X_out_tr, S1_out_tr, S2_out_tr, 113 | Labels_out_tr, X_out_ts, S1_out_ts, S2_out_ts, 114 | Labels_out_ts) 115 | 116 | 117 | if __name__ == '__main__': 118 | 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument("--size", "-s", type=int, help="size of the domain", default=28) 121 | parser.add_argument("--n_domains", "-nd", type=int, help="number of domains", default=5000) 122 | parser.add_argument("--max_obs", "-no", type=int, help="maximum number of obstacles", default=50) 123 | parser.add_argument("--max_obs_size", "-os", type=int, help="maximum obstacle size", default=2) 124 | parser.add_argument("--n_traj", "-nt", type=int, help="number of trajectories", default=7) 125 | parser.add_argument("--state_batch_size", "-bs", type=int, help="state batch size", default=1) 126 | 127 | args = parser.parse_args() 128 | size = args.size 129 | 130 | main(dom_size=(size, size), n_domains=args.n_domains, max_obs=args.max_obs, 131 | max_obs_size=args.max_obs_size, n_traj=args.n_traj, state_batch_size=args.state_batch_size) 132 | -------------------------------------------------------------------------------- /domains/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/domains/__init__.py -------------------------------------------------------------------------------- /domains/gridworld.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.sparse import csr_matrix 3 | from scipy.sparse.csgraph import dijkstra 4 | from collections import OrderedDict 5 | 6 | 7 | class GridWorld: 8 | """A class for making gridworlds""" 9 | 10 | ACTION = OrderedDict(N=(-1, 0), S=(1, 0), E=(0, 1), W=(0, -1), NE=(-1, 1), NW=(-1, -1), SE=(1, 1), SW=(1, -1)) 11 | 12 | def __init__(self, image, target_x, target_y): 13 | self.image = image 14 | self.n_row = image.shape[0] 15 | self.n_col = image.shape[1] 16 | self.obstacles = np.where(self.image == 0) 17 | self.freespace = np.where(self.image != 0) 18 | self.target_x = target_x 19 | self.target_y = target_y 20 | self.n_states = self.n_row * self.n_col 21 | self.n_actions = len(self.ACTION) 22 | 23 | self.G, self.W, self.P, self.R, 
self.state_map_row, self.state_map_col = self.set_vals() 24 | 25 | def loc_to_state(self, row, col): 26 | return np.ravel_multi_index([row, col], (self.n_row, self.n_col), order='F') 27 | 28 | def state_to_loc(self, state): 29 | return np.unravel_index(state, (self.n_col, self.n_row), order='F') 30 | 31 | def set_vals(self): 32 | # Setup function to initialize all necessary 33 | 34 | # Cost of each action, equivalent to the length of each vector 35 | # i.e. [1., 1., 1., 1., 1.414, 1.414, 1.414, 1.414] 36 | action_cost = np.linalg.norm(list(self.ACTION.values()), axis=1) 37 | # Initializing reward function R: (curr_state, action) -> reward: float 38 | # Each transition has negative reward equivalent to the distance of transition 39 | R = - np.ones((self.n_states, self.n_actions)) * action_cost 40 | # Reward at target is zero 41 | target = self.loc_to_state(self.target_x, self.target_y) 42 | R[target, :] = 0 43 | 44 | # Transition function P: (curr_state, next_state, action) -> probability: float 45 | P = np.zeros((self.n_states, self.n_states, self.n_actions)) 46 | # Filling in P 47 | for row in range(self.n_row): 48 | for col in range(self.n_col): 49 | curr_state = self.loc_to_state(row, col) 50 | for i_action, action in enumerate(self.ACTION): 51 | neighbor_row, neighbor_col = self.move(row, col, action) 52 | neighbor_state = self.loc_to_state(neighbor_row, neighbor_col) 53 | P[curr_state, neighbor_state, i_action] = 1 54 | 55 | # Adjacency matrix of a graph connecting curr_state and next_state 56 | G = np.logical_or.reduce(P, axis=2) 57 | # Weight of transition edges, equivalent to the cost of transition 58 | W = np.maximum.reduce(P * action_cost, axis=2) 59 | 60 | non_obstacles = self.loc_to_state(self.freespace[0], self.freespace[1]) 61 | 62 | non_obstacles = np.sort(non_obstacles) 63 | 64 | G = G[non_obstacles, :][:, non_obstacles] 65 | W = W[non_obstacles, :][:, non_obstacles] 66 | P = P[non_obstacles, :, :][:, non_obstacles, :] 67 | R = R[non_obstacles, :] 68 | 69 | state_map_col, state_map_row = np.meshgrid( 70 | np.arange(0, self.n_col), np.arange(0, self.n_row)) 71 | state_map_row = state_map_row.flatten('F')[non_obstacles] 72 | state_map_col = state_map_col.flatten('F')[non_obstacles] 73 | 74 | return G, W, P, R, state_map_row, state_map_col 75 | 76 | def get_graph(self): 77 | # Returns graph 78 | G = self.G 79 | W = self.W[self.W != 0] 80 | return G, W 81 | 82 | def get_graph_inv(self): 83 | # Returns transpose of graph 84 | G = self.G.T 85 | W = self.W.T 86 | return G, W 87 | 88 | def val_2_image(self, val): 89 | # Zeros for obstacles, val for free space 90 | im = np.zeros((self.n_row, self.n_col)) 91 | im[self.freespace[0], self.freespace[1]] = val 92 | return im 93 | 94 | def get_value_prior(self): 95 | # Returns value prior for gridworld 96 | s_map_col, s_map_row = np.meshgrid( 97 | np.arange(0, self.n_col), np.arange(0, self.n_row)) 98 | im = np.sqrt( 99 | np.square(s_map_col - self.target_y) + 100 | np.square(s_map_row - self.target_x)) 101 | return im 102 | 103 | def get_reward_prior(self): 104 | # Returns reward prior for gridworld 105 | im = -1 * np.ones((self.n_row, self.n_col)) 106 | im[self.target_x, self.target_y] = 10 107 | return im 108 | 109 | def t_get_reward_prior(self): 110 | # Returns reward prior as needed for 111 | # dataset generation 112 | im = np.zeros((self.n_row, self.n_col)) 113 | im[self.target_x, self.target_y] = 10 114 | return im 115 | 116 | def get_state_image(self, row, col): 117 | # Zeros everywhere except [row,col] 118 | im = 
np.zeros((self.n_row, self.n_col)) 119 | im[row, col] = 1 120 | return im 121 | 122 | def map_ind_to_state(self, row, col): 123 | # Takes [row, col] and maps to a state 124 | rw = np.where(self.state_map_row == row) 125 | cl = np.where(self.state_map_col == col) 126 | return np.intersect1d(rw, cl)[0] 127 | 128 | def get_coords(self, states): 129 | # Given a state or states, returns 130 | # [row,col] pairs for the state(s) 131 | non_obstacles = self.loc_to_state(self.freespace[0], self.freespace[1]) 132 | non_obstacles = np.sort(non_obstacles) 133 | states = states.astype(int) 134 | r, c = self.state_to_loc(non_obstacles[states]) 135 | return r, c 136 | 137 | def rand_choose(self, in_vec): 138 | # Samples 139 | if len(in_vec.shape) > 1: 140 | if in_vec.shape[1] == 1: 141 | in_vec = in_vec.T 142 | temp = np.hstack((np.zeros((1)), np.cumsum(in_vec))).astype('int') 143 | q = np.random.rand() 144 | x = np.where(q > temp[0:-1]) 145 | y = np.where(q < temp[1:]) 146 | return np.intersect1d(x, y)[0] 147 | 148 | def next_state_prob(self, s, a): 149 | # Gets next state probability for 150 | # a given action (a) 151 | if hasattr(a, "__iter__"): 152 | p = np.squeeze(self.P[s, :, a]) 153 | else: 154 | p = np.squeeze(self.P[s, :, a]).T 155 | return p 156 | 157 | def sample_next_state(self, s, a): 158 | # Gets the next state given the 159 | # current state (s) and an 160 | # action (a) 161 | vec = self.next_state_prob(s, a) 162 | result = self.rand_choose(vec) 163 | return result 164 | 165 | def get_size(self): 166 | # Returns domain size 167 | return self.n_row, self.n_col 168 | 169 | def move(self, row, col, action): 170 | # Returns new [row,col] 171 | # if we take the action 172 | r_move, c_move = self.ACTION[action] 173 | new_row = max(0, min(row + r_move, self.n_row - 1)) 174 | new_col = max(0, min(col + c_move, self.n_col - 1)) 175 | if self.image[new_row, new_col] == 0: 176 | new_row = row 177 | new_col = col 178 | return new_row, new_col 179 | 180 | 181 | def trace_path(pred, source, target): 182 | # traces back shortest path from 183 | # source to target given pred 184 | # (a predicessor list) 185 | max_len = 1000 186 | path = np.zeros((max_len, 1)) 187 | i = max_len - 1 188 | path[i] = target 189 | while path[i] != source and i > 0: 190 | try: 191 | path[i - 1] = pred[int(path[i])] 192 | i -= 1 193 | except Exception as e: 194 | return [] 195 | if i >= 0: 196 | path = path[i:] 197 | else: 198 | path = None 199 | return path 200 | 201 | 202 | def sample_trajectory(M: GridWorld, n_states): 203 | # Samples trajectories from random nodes 204 | # in our domain (M) 205 | G, W = M.get_graph_inv() 206 | N = G.shape[0] 207 | if N >= n_states: 208 | rand_ind = np.random.permutation(N) 209 | else: 210 | rand_ind = np.tile(np.random.permutation(N), (1, 10)) 211 | init_states = rand_ind[0:n_states].flatten() 212 | goal_s = M.map_ind_to_state(M.target_x, M.target_y) 213 | states = [] 214 | states_xy = [] 215 | states_one_hot = [] 216 | # Get optimal path from graph 217 | g_dense = W 218 | g_masked = np.ma.masked_values(g_dense, 0) 219 | g_sparse = csr_matrix(g_dense) 220 | d, pred = dijkstra(g_sparse, indices=goal_s, return_predecessors=True) 221 | for i in range(n_states): 222 | path = trace_path(pred, goal_s, init_states[i]) 223 | path = np.flip(path, 0) 224 | states.append(path) 225 | for state in states: 226 | L = len(state) 227 | r, c = M.get_coords(state) 228 | row_m = np.zeros((L, M.n_row)) 229 | col_m = np.zeros((L, M.n_col)) 230 | for i in range(L): 231 | row_m[i, r[i]] = 1 232 | col_m[i, c[i]] = 1 
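                # Stack the row/col one-hots side by side; states_xy keeps the raw
                # (row, col) pairs, which become the (S1, S2) network inputs downstream.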
233 | states_one_hot.append(np.hstack((row_m, col_m))) 234 | states_xy.append(np.hstack((r, c))) 235 | return states_xy, states_one_hot 236 | -------------------------------------------------------------------------------- /download_weights_and_datasets.sh: -------------------------------------------------------------------------------- 1 | cd trained 2 | wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/vin_8x8.pth' 3 | wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/vin_16x16.pth' 4 | wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/vin_28x28.pth' 5 | cd ../dataset 6 | wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/gridworld_8x8.npz' 7 | wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/gridworld_16x16.npz' 8 | wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/gridworld_28x28.npz' 9 | cd .. 10 | -------------------------------------------------------------------------------- /generators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/generators/__init__.py -------------------------------------------------------------------------------- /generators/obstacle_gen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | class obstacles: 6 | """A class for generating obstacles in a domain""" 7 | 8 | def __init__(self, 9 | domsize=None, 10 | mask=None, 11 | size_max=None, 12 | dom=None, 13 | obs_types=None, 14 | num_types=None): 15 | self.domsize = domsize or [] 16 | self.mask = mask or [] 17 | self.dom = dom or np.zeros(self.domsize) 18 | self.obs_types = obs_types or ["circ", "rect"] 19 | self.num_types = num_types or len(self.obs_types) 20 | self.size_max = size_max or np.max(self.domsize) / 4 21 | 22 | def check_mask(self, dom=None): 23 | # Ensure goal is in free space 24 | if dom is not None: 25 | return np.any(dom[self.mask[0], self.mask[1]]) 26 | else: 27 | return np.any(self.dom[self.mask[0], self.mask[1]]) 28 | 29 | def insert_rect(self, x, y, height, width): 30 | # Insert a rectangular obstacle into map 31 | im_try = np.copy(self.dom) 32 | im_try[x:x + height, y:y + width] = 1 33 | return im_try 34 | 35 | def add_rand_obs(self, obj_type): 36 | # Add random (valid) obstacle to map 37 | if obj_type == "circ": 38 | print("circ is not yet implemented... 
sorry") 39 | elif obj_type == "rect": 40 | rand_height = int(np.ceil(np.random.rand() * self.size_max)) 41 | rand_width = int(np.ceil(np.random.rand() * self.size_max)) 42 | randx = int(np.ceil(np.random.rand() * (self.domsize[1] - 1))) 43 | randy = int(np.ceil(np.random.rand() * (self.domsize[1] - 1))) 44 | im_try = self.insert_rect(randx, randy, rand_height, rand_width) 45 | if self.check_mask(im_try): 46 | return False 47 | else: 48 | self.dom = im_try 49 | return True 50 | 51 | def add_n_rand_obs(self, n): 52 | # Add random (valid) obstacles to map 53 | count = 0 54 | for i in range(n): 55 | obj_type = "rect" 56 | if self.add_rand_obs(obj_type): 57 | count += 1 58 | return count 59 | 60 | def add_border(self): 61 | # Make full outer border an obstacle 62 | im_try = np.copy(self.dom) 63 | im_try[0:self.domsize[0], 0] = 1 64 | im_try[0, 0:self.domsize[1]] = 1 65 | im_try[0:self.domsize[0], self.domsize[1] - 1] = 1 66 | im_try[self.domsize[0] - 1, 0:self.domsize[1]] = 1 67 | if self.check_mask(im_try): 68 | return False 69 | else: 70 | self.dom = im_try 71 | return True 72 | 73 | def get_final(self): 74 | # Process obstacle map for domain 75 | im = np.copy(self.dom) 76 | im = np.max(im) - im 77 | im = im / np.max(im) 78 | return im 79 | 80 | def show(self): 81 | # Utility function to view obstacle map 82 | plt.imshow(self.get_final(), cmap='Greys') 83 | plt.show() 84 | 85 | def _print(self): 86 | # Utility function to view obstacle map 87 | # information 88 | print("domsize: ", self.domsize) 89 | print("mask: ", self.mask) 90 | print("dom: ", self.dom) 91 | print("obs_types: ", self.obs_types) 92 | print("num_types: ", self.num_types) 93 | print("size_max: ", self.size_max) 94 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torch.nn.parameter import Parameter 8 | 9 | 10 | class VIN(nn.Module): 11 | def __init__(self, config): 12 | super(VIN, self).__init__() 13 | self.config = config 14 | self.h = nn.Conv2d( 15 | in_channels=config.l_i, 16 | out_channels=config.l_h, 17 | kernel_size=(3, 3), 18 | stride=1, 19 | padding=1, 20 | bias=True) 21 | self.r = nn.Conv2d( 22 | in_channels=config.l_h, 23 | out_channels=1, 24 | kernel_size=(1, 1), 25 | stride=1, 26 | padding=0, 27 | bias=False) 28 | self.q = nn.Conv2d( 29 | in_channels=1, 30 | out_channels=config.l_q, 31 | kernel_size=(3, 3), 32 | stride=1, 33 | padding=1, 34 | bias=False) 35 | self.fc = nn.Linear(in_features=config.l_q, out_features=8, bias=False) 36 | self.w = Parameter( 37 | torch.zeros(config.l_q, 1, 3, 3), requires_grad=True) 38 | self.sm = nn.Softmax(dim=1) 39 | 40 | def forward(self, input_view, state_x, state_y, k): 41 | """ 42 | :param input_view: (batch_sz, imsize, imsize) 43 | :param state_x: (batch_sz,), 0 <= state_x < imsize 44 | :param state_y: (batch_sz,), 0 <= state_y < imsize 45 | :param k: number of iterations 46 | :return: logits and softmaxed logits 47 | """ 48 | h = self.h(input_view) # Intermediate output 49 | r = self.r(h) # Reward 50 | q = self.q(r) # Initial Q value from reward 51 | v, _ = torch.max(q, dim=1, keepdim=True) 52 | 53 | def eval_q(r, v): 54 | return F.conv2d( 55 | # Stack reward with most recent value 56 | torch.cat([r, v], 1), 57 | # Convolve r->q weights to r, and v->q weights for v. 
These represent transition probabilities 58 | torch.cat([self.q.weight, self.w], 1), 59 | stride=1, 60 | padding=1) 61 | 62 | # Update q and v values 63 | for i in range(k - 1): 64 | q = eval_q(r, v) 65 | v, _ = torch.max(q, dim=1, keepdim=True) 66 | 67 | q = eval_q(r, v) 68 | # q: (batch_sz, l_q, map_size, map_size) 69 | batch_sz, l_q, _, _ = q.size() 70 | q_out = q[torch.arange(batch_sz), :, state_x.long(), state_y.long()].view(batch_sz, l_q) 71 | 72 | logits = self.fc(q_out) # q_out to actions 73 | 74 | return logits, self.sm(logits) 75 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy>=0.19.0 2 | matplotlib>=2.0.0 3 | numpy>=1.12.1 4 | torchvision>=0.1.8 5 | -------------------------------------------------------------------------------- /results/16x16_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/results/16x16_1.png -------------------------------------------------------------------------------- /results/16x16_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/results/16x16_2.png -------------------------------------------------------------------------------- /results/28x28_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/results/28x28_1.png -------------------------------------------------------------------------------- /results/28x28_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/results/28x28_2.png -------------------------------------------------------------------------------- /results/8x8_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/results/8x8_1.png -------------------------------------------------------------------------------- /results/8x8_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/results/8x8_2.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | 4 | import matplotlib.pyplot as plt 5 | 6 | import numpy as np 7 | 8 | import torch 9 | from torch.autograd import Variable 10 | 11 | from dataset.dataset import * 12 | from utility.utils import * 13 | from model import * 14 | 15 | from domains.gridworld import * 16 | from generators.obstacle_gen import * 17 | 18 | 19 | def main(config, 20 | n_domains=100, 21 | max_obs=30, 22 | max_obs_size=None, 23 | n_traj=1, 24 | n_actions=8): 25 | # Correct vs total: 26 | correct, total = 0.0, 0.0 27 | # Instantiate a VIN model 28 | vin: VIN = VIN(config) 29 | # Load model parameters 30 | 
vin.load_state_dict(torch.load(config.weights)) 31 | # Automatically select device to make the code device agnostic 32 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 33 | vin = vin.to(device) 34 | 35 | for dom in range(n_domains): 36 | # Randomly select goal position 37 | goal = [ 38 | np.random.randint(config.imsize), 39 | np.random.randint(config.imsize) 40 | ] 41 | # Generate obstacle map 42 | obs = obstacles([config.imsize, config.imsize], goal, max_obs_size) 43 | # Add obstacles to map 44 | n_obs = obs.add_n_rand_obs(max_obs) 45 | # Add border to map 46 | border_res = obs.add_border() 47 | # Ensure we have valid map 48 | if n_obs == 0 or not border_res: 49 | continue 50 | # Get final map 51 | im = obs.get_final() 52 | 53 | # Generate gridworld from obstacle map 54 | G = GridWorld(im, goal[0], goal[1]) 55 | # Get value prior 56 | value_prior = G.get_reward_prior() 57 | # Sample random trajectories to our goal 58 | states_xy, states_one_hot = sample_trajectory(G, n_traj) 59 | 60 | for i in range(n_traj): 61 | if len(states_xy[i]) > 1: 62 | 63 | # Budget twice the optimal path length for the rollout 64 | L = len(states_xy[i]) * 2 65 | # Allocate space for predicted steps 66 | pred_traj = np.zeros((L, 2)) 67 | # Set starting position 68 | pred_traj[0, :] = states_xy[i][0, :] 69 | 70 | for j in range(1, L): 71 | # Transform current state data 72 | state_data = pred_traj[j - 1, :] 73 | state_data = state_data.astype(int) 74 | # Transform domain to the network's expected input shape 75 | im_data = G.image.astype(int) 76 | im_data = 1 - im_data 77 | im_data = im_data.reshape(1, 1, config.imsize, 78 | config.imsize) 79 | # Transform value prior to the network's expected input shape 80 | value_data = value_prior.astype(int) 81 | value_data = value_data.reshape(1, 1, config.imsize, 82 | config.imsize) 83 | # Get inputs as expected by network 84 | X_in = torch.from_numpy(np.append(im_data, value_data, axis=1)) 85 | S1_in = torch.from_numpy(state_data[0].reshape([1, 1])) 86 | S2_in = torch.from_numpy(state_data[1].reshape([1, 1])) 87 | 88 | # Get input batch 89 | X_in, S1_in, S2_in = [d.float().to(device) for d in [X_in, S1_in, S2_in]] 90 | 91 | # Forward pass in our neural net 92 | _, predictions = vin(X_in, S1_in, S2_in, config.k) 93 | _, indices = torch.max(predictions.cpu(), 1, keepdim=True) 94 | a = indices.data.numpy()[0][0] 95 | # Map the current position to its state index 96 | s = G.map_ind_to_state(pred_traj[j - 1, 0], 97 | pred_traj[j - 1, 1]) 98 | ns = G.sample_next_state(s, a) 99 | nr, nc = G.get_coords(ns) 100 | pred_traj[j, 0] = nr 101 | pred_traj[j, 1] = nc 102 | if nr == goal[0] and nc == goal[1]: 103 | # We hit goal so fill remaining steps 104 | pred_traj[j + 1:, 0] = nr 105 | pred_traj[j + 1:, 1] = nc 106 | break 107 | # Plot optimal and predicted path (also start, end) 108 | if pred_traj[-1, 0] == goal[0] and pred_traj[-1, 1] == goal[1]: 109 | correct += 1 110 | total += 1 111 | if config.plot: 112 | visualize(G.image.T, states_xy[i], pred_traj) 113 | sys.stdout.write("\r" + str(int( 114 | (float(dom) / n_domains) * 100.0)) + "%") 115 | sys.stdout.flush() 116 | sys.stdout.write("\n") 117 | print('Rollout Accuracy: {:.2f}%'.format(100 * (correct / total))) 118 | 119 | 120 | def visualize(dom, states_xy, pred_traj): 121 | fig, ax = plt.subplots() 122 | implot = plt.imshow(dom, cmap="Greys_r") 123 | ax.plot(states_xy[:, 0], states_xy[:, 1], c='b', label='Optimal Path') 124 | ax.plot( 125 | pred_traj[:, 0], pred_traj[:, 1], '-X', c='r', label='Predicted Path') 126 | 
ax.plot(states_xy[0, 0], states_xy[0, 1], '-o', label='Start') 127 | ax.plot(states_xy[-1, 0], states_xy[-1, 1], '-s', label='Goal') 128 | legend = ax.legend(loc='upper right', shadow=False) 129 | for label in legend.get_texts(): 130 | label.set_fontsize('x-small') # The legend text size 131 | for label in legend.get_lines(): 132 | label.set_linewidth(0.5) # The legend line width 133 | plt.draw() 134 | plt.waitforbuttonpress(0) 135 | plt.close(fig) 136 | 137 | 138 | if __name__ == '__main__': 139 | # Parsing training parameters 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument( 142 | '--weights', 143 | type=str, 144 | default='trained/vin_8x8.pth', 145 | help='Path to trained weights') 146 | parser.add_argument('--plot', action='store_true', default=False) 147 | parser.add_argument('--imsize', type=int, default=8, help='Size of image') 148 | parser.add_argument( 149 | '--k', type=int, default=10, help='Number of Value Iterations') 150 | parser.add_argument( 151 | '--l_i', type=int, default=2, help='Number of channels in input layer') 152 | parser.add_argument( 153 | '--l_h', 154 | type=int, 155 | default=150, 156 | help='Number of channels in first hidden layer') 157 | parser.add_argument( 158 | '--l_q', 159 | type=int, 160 | default=10, 161 | help='Number of channels in q layer (~actions) in VI-module') 162 | config = parser.parse_args() 163 | # Compute Paths generated by network and plot 164 | main(config) 165 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | 9 | import torchvision.transforms as transforms 10 | 11 | import matplotlib.pyplot as plt 12 | from dataset.dataset import * 13 | from utility.utils import * 14 | from model import * 15 | 16 | 17 | def train(net: VIN, trainloader, config, criterion, optimizer): 18 | print_header() 19 | # Automatically select device to make the code device agnostic 20 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 21 | for epoch in range(config.epochs): # Loop over dataset multiple times 22 | avg_error, avg_loss, num_batches = 0.0, 0.0, 0.0 23 | start_time = time.time() 24 | for i, data in enumerate(trainloader): # Loop over batches of data 25 | # Get input batch 26 | X, S1, S2, labels = [d.to(device) for d in data] 27 | if X.size()[0] != config.batch_size: 28 | continue # Drop those data, if not enough for a batch 29 | net = net.to(device) 30 | # Zero the parameter gradients 31 | optimizer.zero_grad() 32 | # Forward pass 33 | outputs, predictions = net(X, S1, S2, config.k) 34 | # Loss 35 | loss = criterion(outputs, labels) 36 | # Backward pass 37 | loss.backward() 38 | # Update params 39 | optimizer.step() 40 | # Calculate Loss and Error 41 | loss_batch, error_batch = get_stats(loss, predictions, labels) 42 | avg_loss += loss_batch 43 | avg_error += error_batch 44 | num_batches += 1 45 | time_duration = time.time() - start_time 46 | # Print epoch logs 47 | print_stats(epoch, avg_loss, avg_error, num_batches, time_duration) 48 | print('\nFinished training. 
\n') 49 | 50 | 51 | def test(net: VIN, testloader, config): 52 | total, correct = 0.0, 0.0 53 | # Automatically select device, device agnostic 54 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 55 | for i, data in enumerate(testloader): 56 | # Get inputs 57 | X, S1, S2, labels = [d.to(device) for d in data] 58 | if X.size()[0] != config.batch_size: 59 | continue # Drop those data, if not enough for a batch 60 | net = net.to(device) 61 | # Forward pass 62 | outputs, predictions = net(X, S1, S2, config.k) 63 | # Select actions with max scores(logits) 64 | _, predicted = torch.max(outputs, dim=1, keepdim=True) 65 | # Unwrap autograd.Variable to Tensor 66 | predicted = predicted.data 67 | # Compute test accuracy 68 | correct += (torch.eq(torch.squeeze(predicted), labels)).sum() 69 | total += labels.size()[0] 70 | print('Test Accuracy: {:.2f}%'.format(100 * (correct / total))) 71 | 72 | 73 | if __name__ == '__main__': 74 | # Parsing training parameters 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument( 77 | '--datafile', 78 | type=str, 79 | default='dataset/gridworld_8x8.npz', 80 | help='Path to data file') 81 | parser.add_argument('--imsize', type=int, default=8, help='Size of image') 82 | parser.add_argument( 83 | '--lr', 84 | type=float, 85 | default=0.005, 86 | help='Learning rate, [0.01, 0.005, 0.002, 0.001]') 87 | parser.add_argument( 88 | '--epochs', type=int, default=30, help='Number of epochs to train') 89 | parser.add_argument( 90 | '--k', type=int, default=10, help='Number of Value Iterations') 91 | parser.add_argument( 92 | '--l_i', type=int, default=2, help='Number of channels in input layer') 93 | parser.add_argument( 94 | '--l_h', 95 | type=int, 96 | default=150, 97 | help='Number of channels in first hidden layer') 98 | parser.add_argument( 99 | '--l_q', 100 | type=int, 101 | default=10, 102 | help='Number of channels in q layer (~actions) in VI-module') 103 | parser.add_argument( 104 | '--batch_size', type=int, default=128, help='Batch size') 105 | config = parser.parse_args() 106 | # Get path to save trained model 107 | save_path = "trained/vin_{0}x{0}.pth".format(config.imsize) 108 | # Instantiate a VIN model 109 | net = VIN(config) 110 | # Loss 111 | criterion = nn.CrossEntropyLoss() 112 | # Optimizer 113 | optimizer = optim.RMSprop(net.parameters(), lr=config.lr, eps=1e-6) 114 | # Dataset transformer: torchvision.transforms 115 | transform = None 116 | # Define Dataset 117 | trainset = GridworldData( 118 | config.datafile, imsize=config.imsize, train=True, transform=transform) 119 | testset = GridworldData( 120 | config.datafile, 121 | imsize=config.imsize, 122 | train=False, 123 | transform=transform) 124 | # Create Dataloader 125 | trainloader = torch.utils.data.DataLoader( 126 | trainset, batch_size=config.batch_size, shuffle=True, num_workers=0) 127 | testloader = torch.utils.data.DataLoader( 128 | testset, batch_size=config.batch_size, shuffle=False, num_workers=0) 129 | # Train the model 130 | train(net, trainloader, config, criterion, optimizer) 131 | # Test accuracy 132 | test(net, testloader, config) 133 | # Save the trained model parameters 134 | torch.save(net.state_dict(), save_path) 135 | -------------------------------------------------------------------------------- /trained/README.md: -------------------------------------------------------------------------------- 1 | # Trained Models 2 | To use a pretrained model you have two choices: 3 | 1. Download and place the trained .pth model files here 4 | 2. 
Train the VIN on the datasets yourself (the models will save themselves here) -------------------------------------------------------------------------------- /utility/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/utility/__init__.py -------------------------------------------------------------------------------- /utility/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def fmt_row(width, row): 6 | out = " | ".join(fmt_item(x, width) for x in row) 7 | return out 8 | 9 | 10 | def fmt_item(x, l): 11 | if isinstance(x, np.ndarray): 12 | assert x.ndim == 0 13 | x = x.item() 14 | if isinstance(x, float): rep = "%g" % x 15 | else: rep = str(x) 16 | return " " * (l - len(rep)) + rep 17 | 18 | 19 | def get_stats(loss, predictions, labels): 20 | cp = np.argmax(predictions.cpu().data.numpy(), 1) 21 | error = np.mean(cp != labels.cpu().data.numpy()) 22 | return loss.item(), error 23 | 24 | 25 | def print_stats(epoch, avg_loss, avg_error, num_batches, time_duration): 26 | print( 27 | fmt_row(10, [ 28 | epoch + 1, avg_loss / num_batches, avg_error / num_batches, 29 | time_duration 30 | ])) 31 | 32 | 33 | def print_header(): 34 | print(fmt_row(10, ["Epoch", "Train Loss", "Train Error", "Epoch Time"])) 35 | --------------------------------------------------------------------------------
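For a quick end-to-end run tying the pieces above together, a minimal sketch is shown below. It assumes an 8x8 pipeline executed from the repository root; `--n_domains 1000` is an illustrative, smaller-than-default value, and all other flags are the defaults defined in the scripts above.
```bash
# 1. Generate a small 8x8 gridworld dataset (writes dataset/gridworld_8x8.npz)
python dataset/make_training_data.py --size 8 --n_domains 1000
# 2. Train a VIN on it (saves trained/vin_8x8.pth when training finishes)
python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --k 10
# 3. Roll out the learned policy (add --plot to visualize each path)
python test.py --weights trained/vin_8x8.pth --imsize 8 --k 10
```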