├── .gitignore
├── LICENSE
├── README.md
├── dataset
│   ├── README.MD
│   ├── __init__.py
│   ├── dataset.py
│   └── make_training_data.py
├── domains
│   ├── __init__.py
│   └── gridworld.py
├── download_weights_and_datasets.sh
├── generators
│   ├── __init__.py
│   └── obstacle_gen.py
├── model.py
├── requirements.txt
├── results
│   ├── 16x16_1.png
│   ├── 16x16_2.png
│   ├── 28x28_1.png
│   ├── 28x28_2.png
│   ├── 8x8_1.png
│   └── 8x8_2.png
├── test.py
├── train.py
├── trained
│   └── README.md
└── utility
    ├── __init__.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # npz
107 | *.npz
108 |
109 | # pth
110 | *.pth
111 |
112 | # jetbrains project settings
113 | .idea
114 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2018, Kent Sommer
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VIN: [Value Iteration Networks](https://arxiv.org/abs/1602.02867)
2 |
3 | 
4 |
5 | ## A quick thank you
6 | A few others have released amazing related work which helped inspire and improve my own implementation. It goes without saying that this release would not be nearly as good if it were not for all of the following:
7 | * [@avivt](https://github.com/avivt) ([Paper Author](https://arxiv.org/abs/1602.02867), [MATLAB implementation](https://github.com/avivt/VIN))
8 | * [@zuoxingdong](https://github.com/zuoxingdong) ([Tensorflow implementation](https://github.com/zuoxingdong/VIN_TensorFlow), [Pytorch implementation](https://github.com/zuoxingdong/VIN_PyTorch_Visdom))
9 | * [@TheAbhiKumar](https://github.com/TheAbhiKumar) ([Tensorflow implementation](https://github.com/TheAbhiKumar/tensorflow-value-iteration-networks))
10 | * [@onlytailei](https://github.com/onlytailei) ([Pytorch implementation](https://github.com/onlytailei/Value-Iteration-Networks-PyTorch))
11 |
12 | ## Why another VIN implementation?
13 | 1. The Pytorch VIN model in this repository is, in my opinion, more readable and closer to the original Theano implementation than others I have found (both Tensorflow and Pytorch).
14 | 2. This is not simply an implementation of the VIN model in Pytorch; it is also a full Python implementation of the gridworld environments as used in the [original MATLAB implementation](https://github.com/avivt/VIN).
15 | 3. It provides a more extensible research base for others to build on without needing to get past a possible MATLAB paywall.
16 |
17 | ## Installation
18 | This repository requires the following packages:
19 | - [SciPy](https://www.scipy.org/install.html) >= 0.19.0
20 | - [Python](https://www.python.org/) >= 2.7 (if using Python 3.x: python3-tk should be installed)
21 | - [Numpy](https://pypi.python.org/pypi/numpy) >= 1.12.1
22 | - [Matplotlib](https://matplotlib.org/users/installing.html) >= 2.0.0
23 | - [PyTorch](http://pytorch.org/) >= 0.1.11
24 |
25 | Use `pip` to install the necessary dependencies:
26 | ```
27 | pip install -U -r requirements.txt
28 | ```
29 | Note that PyTorch may not be installable directly from PyPI on all platforms; refer to http://pytorch.org/ for installation instructions specific to your setup.
30 | ## How to train
31 | #### 8x8 gridworld
32 | ```bash
33 | python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128
34 | ```
35 | #### 16x16 gridworld
36 | ```bash
37 | python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 30 --k 20 --batch_size 128
38 | ```
39 | #### 28x28 gridworld
40 | ```bash
41 | python train.py --datafile dataset/gridworld_28x28.npz --imsize 28 --lr 0.002 --epochs 30 --k 36 --batch_size 128
42 | ```
43 | **Flags**:
44 | - `datafile`: The path to the data files.
45 | - `imsize`: The size of input images. One of: [8, 16, 28]
46 | - `lr`: Learning rate with RMSProp optimizer. Recommended: [0.01, 0.005, 0.002, 0.001]
47 | - `epochs`: Number of epochs to train. Default: 30
48 | - `k`: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
49 | - `l_i`: Number of channels in input layer. Default: 2, i.e. obstacles image and goal image.
50 | - `l_h`: Number of channels in first convolutional layer. Default: 150, described in paper.
51 | - `l_q`: Number of channels in q layer (~actions) in VI-module. Default: 10, described in paper.
52 | - `batch_size`: Batch size. Default: 128
53 |
54 | ## How to test / visualize paths (requires training first)
55 | #### 8x8 gridworld
56 | ```bash
57 | python test.py --weights trained/vin_8x8.pth --imsize 8 --k 10
58 | ```
59 | #### 16x16 gridworld
60 | ```bash
61 | python test.py --weights trained/vin_16x16.pth --imsize 16 --k 20
62 | ```
63 | #### 28x28 gridworld
64 | ```bash
65 | python test.py --weights trained/vin_28x28.pth --imsize 28 --k 36
66 | ```
67 | To visualize the optimal and predicted paths simply pass:
68 | ```bash
69 | --plot
70 | ```
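For example:
```bash
python test.py --weights trained/vin_8x8.pth --imsize 8 --k 10 --plot
```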
71 |
72 | **Flags**:
73 | - `weights`: Path to trained weights.
74 | - `imsize`: The size of input images. One of: [8, 16, 28]
75 | - `plot`: If supplied, the optimal and predicted paths will be plotted
76 | - `k`: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
77 | - `l_i`: Number of channels in input layer. Default: 2, i.e. obstacles image and goal image.
78 | - `l_h`: Number of channels in first convolutional layer. Default: 150, described in paper.
79 | - `l_q`: Number of channels in q layer (~actions) in VI-module. Default: 10, described in paper.
80 |
81 | ## Results
82 | Gridworld | Sample One | Sample Two
83 | -- | --- | ---
84 | 8x8 |  | 
85 | 16x16 |  | 
86 | 28x28 |  | 
87 |
88 | ## Datasets
89 | Each data sample consists of an obstacle image and a goal image, followed by the (x, y) coordinates of the current state in the gridworld.
90 |
91 | Dataset size | 8x8 | 16x16 | 28x28
92 | -- | -- | -- | --
93 | Train set | 81337 | 456309 | 1529584
94 | Test set | 13846 | 77203 | 251755
95 |
96 | The datasets (8x8, 16x16, and 28x28) included in this repository can be reproduced using the ```dataset/make_training_data.py``` script. Note that this script is not optimized: it runs rather slowly and uses a lot of memory :D
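Each generated ```.npz``` stores eight arrays in a fixed order (train images, S1, S2, labels, then the same four arrays for the test split), which is exactly how ```dataset/dataset.py``` reads them back. A minimal sketch for inspecting a generated file (assumes ```dataset/gridworld_8x8.npz``` is present):
```python
import numpy as np

with np.load("dataset/gridworld_8x8.npz") as f:
    X_tr, S1_tr, S2_tr, y_tr = f["arr_0"], f["arr_1"], f["arr_2"], f["arr_3"]
    X_ts, S1_ts, S2_ts, y_ts = f["arr_4"], f["arr_5"], f["arr_6"], f["arr_7"]

print(X_tr.shape)  # (n_train, 2, 8, 8): obstacle image + goal (reward prior) image
print(S1_tr.shape, S2_tr.shape, y_tr.shape)  # state coordinates and action labels
```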
97 |
98 | ## Performance: Success Rate
99 | This is the success rate from rollouts of the learned policy in the environment (taken over 5000 randomly generated domains).
100 |
101 | Success Rate | 8x8 | 16x16 | 28x28
102 | -- | -- | -- | --
103 | PyTorch | 99.69% | 96.99% | 91.07%
104 |
105 | ## Performance: Test Accuracy
106 |
107 | **NOTE**: This is the **accuracy on the test set**. It is different from the table in the paper, which reports the **success rate** from rollouts of the learned policy in the environment.
108 |
109 | Test Accuracy | 8x8 | 16x16 | 28x28
110 | -- | -- | -- | --
111 | PyTorch | 99.83% | 94.84% | 88.54%
112 |
--------------------------------------------------------------------------------
/dataset/README.MD:
--------------------------------------------------------------------------------
1 | # Gridworld datasets
2 | To use the gridworld datasets you have two choices:
3 | 1. Download and place the .npz dataset files here
4 | * gridworld_8x8.npz
5 | * gridworld_16x16.npz
6 | * gridworld_28x28.npz
7 | 2. Use the dataset generation script
8 | * ```make_training_data.py``` (see the example below)
9 |
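10 | For example, to regenerate an 8x8 dataset from the repository root (the flags below mirror the script's argument parser; the exact parameters used to produce the released datasets may differ):
11 | ```bash
12 | python dataset/make_training_data.py --size 8 --n_domains 5000 --max_obs 50 --max_obs_size 2 --n_traj 7
13 | ```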
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.utils.data as data
5 |
6 |
7 | class GridworldData(data.Dataset):
8 | def __init__(self,
9 | file,
10 | imsize,
11 | train=True,
12 | transform=None,
13 | target_transform=None):
14 | assert file.endswith('.npz') # Must be .npz format
15 | self.file = file
16 | self.imsize = imsize
17 | self.transform = transform
18 | self.target_transform = target_transform
19 | self.train = train # Training set or test set
20 |
21 | self.images, self.S1, self.S2, self.labels = \
22 | self._process(file, self.train)
23 |
24 | def __getitem__(self, index):
25 | img = self.images[index]
26 | s1 = self.S1[index]
27 | s2 = self.S2[index]
28 | label = self.labels[index]
29 | # Apply transform if we have one
30 | if self.transform is not None:
31 | img = self.transform(img)
32 | else: # Internal default transform: Just to Tensor
33 | img = torch.from_numpy(img)
34 | # Apply target transform if we have one
35 | if self.target_transform is not None:
36 | label = self.target_transform(label)
37 | return img, int(s1), int(s2), int(label)
38 |
39 | def __len__(self):
40 | return self.images.shape[0]
41 |
42 | def _process(self, file, train):
43 | """Data format: A list, [train data, test data]
44 | Each data sample: label, S1, S2, Images, in this order.
45 | """
46 | with np.load(file, mmap_mode='r') as f:
47 | if train:
48 | images = f['arr_0']
49 | S1 = f['arr_1']
50 | S2 = f['arr_2']
51 | labels = f['arr_3']
52 | else:
53 | images = f['arr_4']
54 | S1 = f['arr_5']
55 | S2 = f['arr_6']
56 | labels = f['arr_7']
57 | # Set proper datatypes
58 | images = images.astype(np.float32)
59 | S1 = S1.astype(int) # (S1, S2) location are integers
60 | S2 = S2.astype(int)
61 | labels = labels.astype(int) # Labels are integers
62 | # Print number of samples
63 | if train:
64 | print("Number of Train Samples: {0}".format(images.shape[0]))
65 | else:
66 | print("Number of Test Samples: {0}".format(images.shape[0]))
67 | return images, S1, S2, labels
68 |
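69 | if __name__ == '__main__':
70 |     # Usage sketch (illustrative): load the train split of a generated
71 |     # dataset and fetch one sample. Assumes dataset/gridworld_8x8.npz has
72 |     # been downloaded or generated beforehand.
73 |     trainset = GridworldData('dataset/gridworld_8x8.npz', imsize=8, train=True)
74 |     img, s1, s2, label = trainset[0]
75 |     print(img.shape, s1, s2, label)  # torch.Size([2, 8, 8]), row, col, action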
--------------------------------------------------------------------------------
/dataset/make_training_data.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import numpy as np
4 | from dataset import *
5 |
6 | import argparse
7 |
8 | sys.path.append('.')
9 | from domains.gridworld import *
10 | from generators.obstacle_gen import *
11 | sys.path.remove('.')
12 |
13 |
14 | def extract_action(traj):
15 | # Given a trajectory, outputs a 1D vector of
16 | # actions corresponding to the trajectory.
17 | n_actions = 8
18 | action_vecs = np.asarray([[-1., 0.], [1., 0.], [0., 1.], [0., -1.],
19 | [-1., 1.], [-1., -1.], [1., 1.], [1., -1.]])
20 | action_vecs[4:] = 1 / np.sqrt(2) * action_vecs[4:]
21 | action_vecs = action_vecs.T
22 | state_diff = np.diff(traj, axis=0)
23 | norm_state_diff = state_diff * np.tile(
24 | 1 / np.sqrt(np.sum(np.square(state_diff), axis=1)), (2, 1)).T
25 | prj_state_diff = np.dot(norm_state_diff, action_vecs)
26 | actions_one_hot = np.abs(prj_state_diff - 1) < 0.00001
27 | actions = np.dot(actions_one_hot, np.arange(n_actions).T)
28 | return actions
29 |
30 |
31 | def make_data(dom_size, n_domains, max_obs, max_obs_size, n_traj,
32 | state_batch_size):
33 |
34 | X_l = []
35 | S1_l = []
36 | S2_l = []
37 | Labels_l = []
38 |
39 | dom = 0.0
40 | while dom < n_domains:
41 | goal = [np.random.randint(dom_size[0]), np.random.randint(dom_size[1])]
42 | # Generate obstacle map
43 | obs = obstacles([dom_size[0], dom_size[1]], goal, max_obs_size)
44 | # Add obstacles to map
45 | n_obs = obs.add_n_rand_obs(max_obs)
46 | # Add border to map
47 | border_res = obs.add_border()
48 | # Ensure we have valid map
49 | if n_obs == 0 or not border_res:
50 | continue
51 | # Get final map
52 | im = obs.get_final()
53 | # Generate gridworld from obstacle map
54 | G = GridWorld(im, goal[0], goal[1])
55 | # Get reward prior (used as the goal channel)
56 | value_prior = G.t_get_reward_prior()
57 | # Sample random trajectories to our goal
58 | states_xy, states_one_hot = sample_trajectory(G, n_traj)
59 | for i in range(n_traj):
60 | if len(states_xy[i]) > 1:
61 | # Get optimal actions for each state
62 | actions = extract_action(states_xy[i])
63 | ns = states_xy[i].shape[0] - 1
64 | # Invert domain image => 0 = free, 1 = obstacle
65 | image = 1 - im
66 | # Resize domain and goal images and concatenate
67 | image_data = np.resize(image, (1, 1, dom_size[0], dom_size[1]))
68 | value_data = np.resize(value_prior,
69 | (1, 1, dom_size[0], dom_size[1]))
70 | iv_mixed = np.concatenate((image_data, value_data), axis=1)
71 | X_current = np.tile(iv_mixed, (ns, 1, 1, 1))
72 | # Resize states
73 | S1_current = np.expand_dims(states_xy[i][0:ns, 0], axis=1)
74 | S2_current = np.expand_dims(states_xy[i][0:ns, 1], axis=1)
75 | # Resize labels
76 | Labels_current = np.expand_dims(actions, axis=1)
77 | # Append to output list
78 | X_l.append(X_current)
79 | S1_l.append(S1_current)
80 | S2_l.append(S2_current)
81 | Labels_l.append(Labels_current)
82 | dom += 1
83 | sys.stdout.write("\r" + str(int((dom / n_domains) * 100)) + "%")
84 | sys.stdout.flush()
85 | sys.stdout.write("\n")
86 | # Concat all outputs
87 | X_f = np.concatenate(X_l)
88 | S1_f = np.concatenate(S1_l)
89 | S2_f = np.concatenate(S2_l)
90 | Labels_f = np.concatenate(Labels_l)
91 | return X_f, S1_f, S2_f, Labels_f
92 |
93 |
94 | def main(dom_size=(28, 28),
95 | n_domains=5000,
96 | max_obs=50,
97 | max_obs_size=2,
98 | n_traj=7,
99 | state_batch_size=1):
100 | # Get path to save dataset
101 | save_path = "dataset/gridworld_{0}x{1}".format(dom_size[0], dom_size[1])
102 | # Get training data
103 | print("Now making training data...")
104 | X_out_tr, S1_out_tr, S2_out_tr, Labels_out_tr = make_data(
105 | dom_size, n_domains, max_obs, max_obs_size, n_traj, state_batch_size)
106 | # Get testing data
107 | print("\nNow making testing data...")
108 | X_out_ts, S1_out_ts, S2_out_ts, Labels_out_ts = make_data(
109 | dom_size, n_domains // 6, max_obs, max_obs_size, n_traj,
110 | state_batch_size)
111 | # Save dataset
112 | np.savez_compressed(save_path, X_out_tr, S1_out_tr, S2_out_tr,
113 | Labels_out_tr, X_out_ts, S1_out_ts, S2_out_ts,
114 | Labels_out_ts)
115 |
116 |
117 | if __name__ == '__main__':
118 |
119 | parser = argparse.ArgumentParser()
120 | parser.add_argument("--size", "-s", type=int, help="size of the domain", default=28)
121 | parser.add_argument("--n_domains", "-nd", type=int, help="number of domains", default=5000)
122 | parser.add_argument("--max_obs", "-no", type=int, help="maximum number of obstacles", default=50)
123 | parser.add_argument("--max_obs_size", "-os", type=int, help="maximum obstacle size", default=2)
124 | parser.add_argument("--n_traj", "-nt", type=int, help="number of trajectories", default=7)
125 | parser.add_argument("--state_batch_size", "-bs", type=int, help="state batch size", default=1)
126 |
127 | args = parser.parse_args()
128 | size = args.size
129 |
130 | main(dom_size=(size, size), n_domains=args.n_domains, max_obs=args.max_obs,
131 | max_obs_size=args.max_obs_size, n_traj=args.n_traj, state_batch_size=args.state_batch_size)
132 |
--------------------------------------------------------------------------------
/domains/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/domains/__init__.py
--------------------------------------------------------------------------------
/domains/gridworld.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.sparse import csr_matrix
3 | from scipy.sparse.csgraph import dijkstra
4 | from collections import OrderedDict
5 |
6 |
7 | class GridWorld:
8 | """A class for making gridworlds"""
9 |
10 | ACTION = OrderedDict(N=(-1, 0), S=(1, 0), E=(0, 1), W=(0, -1), NE=(-1, 1), NW=(-1, -1), SE=(1, 1), SW=(1, -1))
11 |
12 | def __init__(self, image, target_x, target_y):
13 | self.image = image
14 | self.n_row = image.shape[0]
15 | self.n_col = image.shape[1]
16 | self.obstacles = np.where(self.image == 0)
17 | self.freespace = np.where(self.image != 0)
18 | self.target_x = target_x
19 | self.target_y = target_y
20 | self.n_states = self.n_row * self.n_col
21 | self.n_actions = len(self.ACTION)
22 |
23 | self.G, self.W, self.P, self.R, self.state_map_row, self.state_map_col = self.set_vals()
24 |
25 | def loc_to_state(self, row, col):
26 | return np.ravel_multi_index([row, col], (self.n_row, self.n_col), order='F')
27 |
28 | def state_to_loc(self, state):
29 | return np.unravel_index(state, (self.n_row, self.n_col), order='F')
30 |
31 | def set_vals(self):
32 | # Setup function to initialize all necessary matrices and maps
33 |
34 | # Cost of each action, equivalent to the length of each vector
35 | # i.e. [1., 1., 1., 1., 1.414, 1.414, 1.414, 1.414]
36 | action_cost = np.linalg.norm(list(self.ACTION.values()), axis=1)
37 | # Initializing reward function R: (curr_state, action) -> reward: float
38 | # Each transition has negative reward equivalent to the distance of transition
39 | R = - np.ones((self.n_states, self.n_actions)) * action_cost
40 | # Reward at target is zero
41 | target = self.loc_to_state(self.target_x, self.target_y)
42 | R[target, :] = 0
43 |
44 | # Transition function P: (curr_state, next_state, action) -> probability: float
45 | P = np.zeros((self.n_states, self.n_states, self.n_actions))
46 | # Filling in P
47 | for row in range(self.n_row):
48 | for col in range(self.n_col):
49 | curr_state = self.loc_to_state(row, col)
50 | for i_action, action in enumerate(self.ACTION):
51 | neighbor_row, neighbor_col = self.move(row, col, action)
52 | neighbor_state = self.loc_to_state(neighbor_row, neighbor_col)
53 | P[curr_state, neighbor_state, i_action] = 1
54 |
55 | # Adjacency matrix of a graph connecting curr_state and next_state
56 | G = np.logical_or.reduce(P, axis=2)
57 | # Weight of transition edges, equivalent to the cost of transition
58 | W = np.maximum.reduce(P * action_cost, axis=2)
59 |
60 | non_obstacles = self.loc_to_state(self.freespace[0], self.freespace[1])
61 |
62 | non_obstacles = np.sort(non_obstacles)
63 |
64 | G = G[non_obstacles, :][:, non_obstacles]
65 | W = W[non_obstacles, :][:, non_obstacles]
66 | P = P[non_obstacles, :, :][:, non_obstacles, :]
67 | R = R[non_obstacles, :]
68 |
69 | state_map_col, state_map_row = np.meshgrid(
70 | np.arange(0, self.n_col), np.arange(0, self.n_row))
71 | state_map_row = state_map_row.flatten('F')[non_obstacles]
72 | state_map_col = state_map_col.flatten('F')[non_obstacles]
73 |
74 | return G, W, P, R, state_map_row, state_map_col
75 |
76 | def get_graph(self):
77 | # Returns graph
78 | G = self.G
79 | W = self.W[self.W != 0]
80 | return G, W
81 |
82 | def get_graph_inv(self):
83 | # Returns transpose of graph
84 | G = self.G.T
85 | W = self.W.T
86 | return G, W
87 |
88 | def val_2_image(self, val):
89 | # Zeros for obstacles, val for free space
90 | im = np.zeros((self.n_row, self.n_col))
91 | im[self.freespace[0], self.freespace[1]] = val
92 | return im
93 |
94 | def get_value_prior(self):
95 | # Returns value prior for gridworld
96 | s_map_col, s_map_row = np.meshgrid(
97 | np.arange(0, self.n_col), np.arange(0, self.n_row))
98 | im = np.sqrt(
99 | np.square(s_map_col - self.target_y) +
100 | np.square(s_map_row - self.target_x))
101 | return im
102 |
103 | def get_reward_prior(self):
104 | # Returns reward prior for gridworld
105 | im = -1 * np.ones((self.n_row, self.n_col))
106 | im[self.target_x, self.target_y] = 10
107 | return im
108 |
109 | def t_get_reward_prior(self):
110 | # Returns reward prior as needed for
111 | # dataset generation
112 | im = np.zeros((self.n_row, self.n_col))
113 | im[self.target_x, self.target_y] = 10
114 | return im
115 |
116 | def get_state_image(self, row, col):
117 | # Zeros everywhere except [row,col]
118 | im = np.zeros((self.n_row, self.n_col))
119 | im[row, col] = 1
120 | return im
121 |
122 | def map_ind_to_state(self, row, col):
123 | # Takes [row, col] and maps to a state
124 | rw = np.where(self.state_map_row == row)
125 | cl = np.where(self.state_map_col == col)
126 | return np.intersect1d(rw, cl)[0]
127 |
128 | def get_coords(self, states):
129 | # Given a state or states, returns
130 | # [row,col] pairs for the state(s)
131 | non_obstacles = self.loc_to_state(self.freespace[0], self.freespace[1])
132 | non_obstacles = np.sort(non_obstacles)
133 | states = states.astype(int)
134 | r, c = self.state_to_loc(non_obstacles[states])
135 | return r, c
136 |
137 | def rand_choose(self, in_vec):
138 | # Samples an index from the probability vector in_vec
139 | if len(in_vec.shape) > 1:
140 | if in_vec.shape[1] == 1:
141 | in_vec = in_vec.T
142 | temp = np.hstack((np.zeros((1)), np.cumsum(in_vec)))
143 | q = np.random.rand()
144 | x = np.where(q > temp[0:-1])
145 | y = np.where(q < temp[1:])
146 | return np.intersect1d(x, y)[0]
147 |
148 | def next_state_prob(self, s, a):
149 | # Gets next state probability for
150 | # a given action (a)
151 | if hasattr(a, "__iter__"):
152 | p = np.squeeze(self.P[s, :, a])
153 | else:
154 | p = np.squeeze(self.P[s, :, a]).T
155 | return p
156 |
157 | def sample_next_state(self, s, a):
158 | # Gets the next state given the
159 | # current state (s) and an
160 | # action (a)
161 | vec = self.next_state_prob(s, a)
162 | result = self.rand_choose(vec)
163 | return result
164 |
165 | def get_size(self):
166 | # Returns domain size
167 | return self.n_row, self.n_col
168 |
169 | def move(self, row, col, action):
170 | # Returns new [row,col]
171 | # if we take the action
172 | r_move, c_move = self.ACTION[action]
173 | new_row = max(0, min(row + r_move, self.n_row - 1))
174 | new_col = max(0, min(col + c_move, self.n_col - 1))
175 | if self.image[new_row, new_col] == 0:
176 | new_row = row
177 | new_col = col
178 | return new_row, new_col
179 |
180 |
181 | def trace_path(pred, source, target):
182 | # traces back shortest path from
183 | # source to target given pred
184 | # (a predecessor list)
185 | max_len = 1000
186 | path = np.zeros((max_len, 1))
187 | i = max_len - 1
188 | path[i] = target
189 | while path[i] != source and i > 0:
190 | try:
191 | path[i - 1] = pred[int(path[i])]
192 | i -= 1
193 | except Exception:
194 | return []
195 | if i >= 0:
196 | path = path[i:]
197 | else:
198 | path = None
199 | return path
200 |
201 |
202 | def sample_trajectory(M: GridWorld, n_states):
203 | # Samples trajectories from random nodes
204 | # in our domain (M)
205 | G, W = M.get_graph_inv()
206 | N = G.shape[0]
207 | if N >= n_states:
208 | rand_ind = np.random.permutation(N)
209 | else:
210 | rand_ind = np.tile(np.random.permutation(N), (1, 10))
211 | init_states = rand_ind[0:n_states].flatten()
212 | goal_s = M.map_ind_to_state(M.target_x, M.target_y)
213 | states = []
214 | states_xy = []
215 | states_one_hot = []
216 | # Get optimal path from graph
217 | g_dense = W
218 | g_masked = np.ma.masked_values(g_dense, 0)
219 | g_sparse = csr_matrix(g_dense)
220 | d, pred = dijkstra(g_sparse, indices=goal_s, return_predecessors=True)
221 | for i in range(n_states):
222 | path = trace_path(pred, goal_s, init_states[i])
223 | path = np.flip(path, 0)
224 | states.append(path)
225 | for state in states:
226 | L = len(state)
227 | r, c = M.get_coords(state)
228 | row_m = np.zeros((L, M.n_row))
229 | col_m = np.zeros((L, M.n_col))
230 | for i in range(L):
231 | row_m[i, r[i]] = 1
232 | col_m[i, c[i]] = 1
233 | states_one_hot.append(np.hstack((row_m, col_m)))
234 | states_xy.append(np.hstack((r, c)))
235 | return states_xy, states_one_hot
236 |
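237 | if __name__ == '__main__':
238 |     # Usage sketch (illustrative): build a gridworld from a binary
239 |     # free-space map (1 = free, 0 = obstacle) and sample shortest-path
240 |     # trajectories to the goal, as dataset/make_training_data.py does.
241 |     im = np.ones((8, 8))
242 |     im[3, 2:6] = 0  # a short wall of obstacles
243 |     G = GridWorld(im, target_x=6, target_y=6)
244 |     states_xy, states_one_hot = sample_trajectory(G, n_states=2)
245 |     print([s.shape for s in states_xy])  # (path_len, 2) arrays of (row, col)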
--------------------------------------------------------------------------------
/download_weights_and_datasets.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | cd trained
3 | wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/vin_8x8.pth'
4 | wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/vin_16x16.pth'
5 | wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/vin_28x28.pth'
6 | cd ../dataset
7 | wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/gridworld_8x8.npz'
8 | wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/gridworld_16x16.npz'
9 | wget 'https://github.com/kentsommer/pytorch-value-iteration-networks/releases/download/v1.1/gridworld_28x28.npz'
10 | cd ..
--------------------------------------------------------------------------------
/generators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/generators/__init__.py
--------------------------------------------------------------------------------
/generators/obstacle_gen.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 |
5 | class obstacles:
6 | """A class for generating obstacles in a domain"""
7 |
8 | def __init__(self,
9 | domsize=None,
10 | mask=None,
11 | size_max=None,
12 | dom=None,
13 | obs_types=None,
14 | num_types=None):
15 | self.domsize = domsize or []
16 | self.mask = mask or []
17 | self.dom = dom if dom is not None else np.zeros(self.domsize)
18 | self.obs_types = obs_types or ["circ", "rect"]
19 | self.num_types = num_types or len(self.obs_types)
20 | self.size_max = size_max or np.max(self.domsize) / 4
21 |
22 | def check_mask(self, dom=None):
23 | # Ensure goal is in free space
24 | if dom is not None:
25 | return np.any(dom[self.mask[0], self.mask[1]])
26 | else:
27 | return np.any(self.dom[self.mask[0], self.mask[1]])
28 |
29 | def insert_rect(self, x, y, height, width):
30 | # Insert a rectangular obstacle into map
31 | im_try = np.copy(self.dom)
32 | im_try[x:x + height, y:y + width] = 1
33 | return im_try
34 |
35 | def add_rand_obs(self, obj_type):
36 | # Add random (valid) obstacle to map
37 | if obj_type == "circ":
38 | print("circ is not yet implemented... sorry")
39 | elif obj_type == "rect":
40 | rand_height = int(np.ceil(np.random.rand() * self.size_max))
41 | rand_width = int(np.ceil(np.random.rand() * self.size_max))
42 | randx = int(np.ceil(np.random.rand() * (self.domsize[0] - 1)))
43 | randy = int(np.ceil(np.random.rand() * (self.domsize[1] - 1)))
44 | im_try = self.insert_rect(randx, randy, rand_height, rand_width)
45 | if self.check_mask(im_try):
46 | return False
47 | else:
48 | self.dom = im_try
49 | return True
50 |
51 | def add_n_rand_obs(self, n):
52 | # Add random (valid) obstacles to map
53 | count = 0
54 | for i in range(n):
55 | obj_type = "rect"
56 | if self.add_rand_obs(obj_type):
57 | count += 1
58 | return count
59 |
60 | def add_border(self):
61 | # Make full outer border an obstacle
62 | im_try = np.copy(self.dom)
63 | im_try[0:self.domsize[0], 0] = 1
64 | im_try[0, 0:self.domsize[1]] = 1
65 | im_try[0:self.domsize[0], self.domsize[1] - 1] = 1
66 | im_try[self.domsize[0] - 1, 0:self.domsize[1]] = 1
67 | if self.check_mask(im_try):
68 | return False
69 | else:
70 | self.dom = im_try
71 | return True
72 |
73 | def get_final(self):
74 | # Process obstacle map for domain
75 | im = np.copy(self.dom)
76 | im = np.max(im) - im
77 | im = im / np.max(im)
78 | return im
79 |
80 | def show(self):
81 | # Utility function to view obstacle map
82 | plt.imshow(self.get_final(), cmap='Greys')
83 | plt.show()
84 |
85 | def _print(self):
86 | # Utility function to view obstacle map
87 | # information
88 | print("domsize: ", self.domsize)
89 | print("mask: ", self.mask)
90 | print("dom: ", self.dom)
91 | print("obs_types: ", self.obs_types)
92 | print("num_types: ", self.num_types)
93 | print("size_max: ", self.size_max)
94 |
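95 | if __name__ == '__main__':
96 |     # Usage sketch (illustrative): generate a random 16x16 obstacle map
97 |     # with the goal cell kept free, mirroring how the dataset and test
98 |     # scripts use this class.
99 |     goal = [np.random.randint(16), np.random.randint(16)]
100 |     obs = obstacles([16, 16], goal, size_max=2)
101 |     if obs.add_n_rand_obs(10) > 0 and obs.add_border():
102 |         print(obs.get_final())  # 1 = free space, 0 = obstacle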
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | from torch.nn.parameter import Parameter
8 |
9 |
10 | class VIN(nn.Module):
11 | def __init__(self, config):
12 | super(VIN, self).__init__()
13 | self.config = config
14 | self.h = nn.Conv2d(
15 | in_channels=config.l_i,
16 | out_channels=config.l_h,
17 | kernel_size=(3, 3),
18 | stride=1,
19 | padding=1,
20 | bias=True)
21 | self.r = nn.Conv2d(
22 | in_channels=config.l_h,
23 | out_channels=1,
24 | kernel_size=(1, 1),
25 | stride=1,
26 | padding=0,
27 | bias=False)
28 | self.q = nn.Conv2d(
29 | in_channels=1,
30 | out_channels=config.l_q,
31 | kernel_size=(3, 3),
32 | stride=1,
33 | padding=1,
34 | bias=False)
35 | self.fc = nn.Linear(in_features=config.l_q, out_features=8, bias=False)
36 | self.w = Parameter(
37 | torch.zeros(config.l_q, 1, 3, 3), requires_grad=True)
38 | self.sm = nn.Softmax(dim=1)
39 |
40 | def forward(self, input_view, state_x, state_y, k):
41 | """
42 | :param input_view: (batch_sz, l_i, imsize, imsize)
43 | :param state_x: (batch_sz,), 0 <= state_x < imsize
44 | :param state_y: (batch_sz,), 0 <= state_y < imsize
45 | :param k: number of iterations
46 | :return: logits and softmaxed logits
47 | """
48 | h = self.h(input_view) # Intermediate output
49 | r = self.r(h) # Reward
50 | q = self.q(r) # Initial Q value from reward
51 | v, _ = torch.max(q, dim=1, keepdim=True)
52 |
53 | def eval_q(r, v):
54 | return F.conv2d(
55 | # Stack reward with most recent value
56 | torch.cat([r, v], 1),
57 | # Convolve r->q weights to r, and v->q weights for v. These represent transition probabilities
58 | torch.cat([self.q.weight, self.w], 1),
59 | stride=1,
60 | padding=1)
61 |
62 | # Update q and v values
63 | for i in range(k - 1):
64 | q = eval_q(r, v)
65 | v, _ = torch.max(q, dim=1, keepdim=True)
66 |
67 | q = eval_q(r, v)
68 | # q: (batch_sz, l_q, map_size, map_size)
69 | batch_sz, l_q, _, _ = q.size()
70 | q_out = q[torch.arange(batch_sz), :, state_x.long(), state_y.long()].view(batch_sz, l_q)
71 |
72 | logits = self.fc(q_out) # q_out to actions
73 |
74 | return logits, self.sm(logits)
75 |
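76 | if __name__ == '__main__':
77 |     # Usage sketch (illustrative): run a single forward pass with a config
78 |     # namespace standing in for the argparse flags used by train.py/test.py.
79 |     from types import SimpleNamespace
80 |     config = SimpleNamespace(l_i=2, l_h=150, l_q=10)
81 |     net = VIN(config)
82 |     X = torch.zeros(1, config.l_i, 8, 8)  # obstacle image + reward prior
83 |     S1, S2 = torch.tensor([3]), torch.tensor([4])  # agent (row, col)
84 |     logits, probs = net(X, S1, S2, k=10)  # k value-iteration steps
85 |     print(logits.shape)  # torch.Size([1, 8]) -> one logit per action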
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy>=0.19.0
2 | matplotlib>=2.0.0
3 | numpy>=1.12.1
4 | torchvision>=0.1.8
5 |
--------------------------------------------------------------------------------
/results/16x16_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/results/16x16_1.png
--------------------------------------------------------------------------------
/results/16x16_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/results/16x16_2.png
--------------------------------------------------------------------------------
/results/28x28_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/results/28x28_1.png
--------------------------------------------------------------------------------
/results/28x28_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/results/28x28_2.png
--------------------------------------------------------------------------------
/results/8x8_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/results/8x8_1.png
--------------------------------------------------------------------------------
/results/8x8_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/results/8x8_2.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 |
4 | import matplotlib.pyplot as plt
5 |
6 | import numpy as np
7 |
8 | import torch
9 | from torch.autograd import Variable
10 |
11 | from dataset.dataset import *
12 | from utility.utils import *
13 | from model import *
14 |
15 | from domains.gridworld import *
16 | from generators.obstacle_gen import *
17 |
18 |
19 | def main(config,
20 | n_domains=100,
21 | max_obs=30,
22 | max_obs_size=None,
23 | n_traj=1,
24 | n_actions=8):
25 | # Correct vs total:
26 | correct, total = 0.0, 0.0
27 | # Instantiate a VIN model
28 | vin: VIN = VIN(config)
29 | # Load model parameters
30 | vin.load_state_dict(torch.load(config.weights))
31 | # Automatically select device to make the code device agnostic
32 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
33 | vin = vin.to(device)
34 |
35 | for dom in range(n_domains):
36 | # Randomly select goal position
37 | goal = [
38 | np.random.randint(config.imsize),
39 | np.random.randint(config.imsize)
40 | ]
41 | # Generate obstacle map
42 | obs = obstacles([config.imsize, config.imsize], goal, max_obs_size)
43 | # Add obstacles to map
44 | n_obs = obs.add_n_rand_obs(max_obs)
45 | # Add border to map
46 | border_res = obs.add_border()
47 | # Ensure we have valid map
48 | if n_obs == 0 or not border_res:
49 | continue
50 | # Get final map
51 | im = obs.get_final()
52 |
53 | # Generate gridworld from obstacle map
54 | G = GridWorld(im, goal[0], goal[1])
55 | # Get reward prior (used as the goal channel)
56 | value_prior = G.get_reward_prior()
57 | # Sample random trajectories to our goal
58 | states_xy, states_one_hot = sample_trajectory(G, n_traj)
59 |
60 | for i in range(n_traj):
61 | if len(states_xy[i]) > 1:
62 |
63 | # Get number of steps to goal
64 | L = len(states_xy[i]) * 2
65 | # Allocate space for predicted steps
66 | pred_traj = np.zeros((L, 2))
67 | # Set starting position
68 | pred_traj[0, :] = states_xy[i][0, :]
69 |
70 | for j in range(1, L):
71 | # Transform current state data
72 | state_data = pred_traj[j - 1, :]
73 | state_data = state_data.astype(int)
74 | # Transform domain to the network's expected input shape
75 | im_data = G.image.astype(int)
76 | im_data = 1 - im_data
77 | im_data = im_data.reshape(1, 1, config.imsize,
78 | config.imsize)
79 | # Transform value prior to the network's expected input shape
80 | value_data = value_prior.astype(int)
81 | value_data = value_data.reshape(1, 1, config.imsize,
82 | config.imsize)
83 | # Get inputs as expected by network
84 | X_in = torch.from_numpy(np.append(im_data, value_data, axis=1))
85 | S1_in = torch.from_numpy(state_data[0].reshape([1, 1]))
86 | S2_in = torch.from_numpy(state_data[1].reshape([1, 1]))
87 |
88 | # Get input batch
89 | X_in, S1_in, S2_in = [d.float().to(device) for d in [X_in, S1_in, S2_in]]
90 |
91 | # Forward pass in our neural net
92 | _, predictions = vin(X_in, S1_in, S2_in, config.k)
93 | _, indices = torch.max(predictions.cpu(), 1, keepdim=True)
94 | a = indices.data.numpy()[0][0]
95 | # Transform prediction to indices
96 | s = G.map_ind_to_state(pred_traj[j - 1, 0],
97 | pred_traj[j - 1, 1])
98 | ns = G.sample_next_state(s, a)
99 | nr, nc = G.get_coords(ns)
100 | pred_traj[j, 0] = nr
101 | pred_traj[j, 1] = nc
102 | if nr == goal[0] and nc == goal[1]:
103 | # We hit goal so fill remaining steps
104 | pred_traj[j + 1:, 0] = nr
105 | pred_traj[j + 1:, 1] = nc
106 | break
107 | # Plot optimal and predicted path (also start, end)
108 | if pred_traj[-1, 0] == goal[0] and pred_traj[-1, 1] == goal[1]:
109 | correct += 1
110 | total += 1
111 | if config.plot:
112 | visualize(G.image.T, states_xy[i], pred_traj)
113 | sys.stdout.write("\r" + str(int(
114 | (float(dom) / n_domains) * 100.0)) + "%")
115 | sys.stdout.flush()
116 | sys.stdout.write("\n")
117 | print('Rollout Accuracy: {:.2f}%'.format(100 * (correct / total)))
118 |
119 |
120 | def visualize(dom, states_xy, pred_traj):
121 | fig, ax = plt.subplots()
122 | implot = plt.imshow(dom, cmap="Greys_r")
123 | ax.plot(states_xy[:, 0], states_xy[:, 1], c='b', label='Optimal Path')
124 | ax.plot(
125 | pred_traj[:, 0], pred_traj[:, 1], '-X', c='r', label='Predicted Path')
126 | ax.plot(states_xy[0, 0], states_xy[0, 1], '-o', label='Start')
127 | ax.plot(states_xy[-1, 0], states_xy[-1, 1], '-s', label='Goal')
128 | legend = ax.legend(loc='upper right', shadow=False)
129 | for label in legend.get_texts():
130 | label.set_fontsize('x-small') # The legend text size
131 | for label in legend.get_lines():
132 | label.set_linewidth(0.5) # The legend line width
133 | plt.draw()
134 | plt.waitforbuttonpress(0)
135 | plt.close(fig)
136 |
137 |
138 | if __name__ == '__main__':
139 | # Parsing training parameters
140 | parser = argparse.ArgumentParser()
141 | parser.add_argument(
142 | '--weights',
143 | type=str,
144 | default='trained/vin_8x8.pth',
145 | help='Path to trained weights')
146 | parser.add_argument('--plot', action='store_true', default=False)
147 | parser.add_argument('--imsize', type=int, default=8, help='Size of image')
148 | parser.add_argument(
149 | '--k', type=int, default=10, help='Number of Value Iterations')
150 | parser.add_argument(
151 | '--l_i', type=int, default=2, help='Number of channels in input layer')
152 | parser.add_argument(
153 | '--l_h',
154 | type=int,
155 | default=150,
156 | help='Number of channels in first hidden layer')
157 | parser.add_argument(
158 | '--l_q',
159 | type=int,
160 | default=10,
161 | help='Number of channels in q layer (~actions) in VI-module')
162 | config = parser.parse_args()
163 | # Compute Paths generated by network and plot
164 | main(config)
165 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | import argparse
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 |
9 | import torchvision.transforms as transforms
10 |
11 | import matplotlib.pyplot as plt
12 | from dataset.dataset import *
13 | from utility.utils import *
14 | from model import *
15 |
16 |
17 | def train(net: VIN, trainloader, config, criterion, optimizer):
18 | print_header()
19 | # Automatically select device to make the code device agnostic
20 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21 | for epoch in range(config.epochs): # Loop over dataset multiple times
22 | avg_error, avg_loss, num_batches = 0.0, 0.0, 0.0
23 | start_time = time.time()
24 | for i, data in enumerate(trainloader): # Loop over batches of data
25 | # Get input batch
26 | X, S1, S2, labels = [d.to(device) for d in data]
27 | if X.size()[0] != config.batch_size:
28 | continue # Skip batches smaller than batch_size
29 | net = net.to(device)
30 | # Zero the parameter gradients
31 | optimizer.zero_grad()
32 | # Forward pass
33 | outputs, predictions = net(X, S1, S2, config.k)
34 | # Loss
35 | loss = criterion(outputs, labels)
36 | # Backward pass
37 | loss.backward()
38 | # Update params
39 | optimizer.step()
40 | # Calculate Loss and Error
41 | loss_batch, error_batch = get_stats(loss, predictions, labels)
42 | avg_loss += loss_batch
43 | avg_error += error_batch
44 | num_batches += 1
45 | time_duration = time.time() - start_time
46 | # Print epoch logs
47 | print_stats(epoch, avg_loss, avg_error, num_batches, time_duration)
48 | print('\nFinished training. \n')
49 |
50 |
51 | def test(net: VIN, testloader, config):
52 | total, correct = 0.0, 0.0
53 | # Automatically select device, device agnostic
54 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
55 | for i, data in enumerate(testloader):
56 | # Get inputs
57 | X, S1, S2, labels = [d.to(device) for d in data]
58 | if X.size()[0] != config.batch_size:
59 | continue # Skip batches smaller than batch_size
60 | net = net.to(device)
61 | # Forward pass
62 | outputs, predictions = net(X, S1, S2, config.k)
63 | # Select actions with max scores(logits)
64 | _, predicted = torch.max(outputs, dim=1, keepdim=True)
65 | # Unwrap autograd.Variable to Tensor
66 | predicted = predicted.data
67 | # Compute test accuracy
68 | correct += (torch.eq(torch.squeeze(predicted), labels)).sum().item()
69 | total += labels.size()[0]
70 | print('Test Accuracy: {:.2f}%'.format(100 * (correct / total)))
71 |
72 |
73 | if __name__ == '__main__':
74 | # Parsing training parameters
75 | parser = argparse.ArgumentParser()
76 | parser.add_argument(
77 | '--datafile',
78 | type=str,
79 | default='dataset/gridworld_8x8.npz',
80 | help='Path to data file')
81 | parser.add_argument('--imsize', type=int, default=8, help='Size of image')
82 | parser.add_argument(
83 | '--lr',
84 | type=float,
85 | default=0.005,
86 | help='Learning rate, [0.01, 0.005, 0.002, 0.001]')
87 | parser.add_argument(
88 | '--epochs', type=int, default=30, help='Number of epochs to train')
89 | parser.add_argument(
90 | '--k', type=int, default=10, help='Number of Value Iterations')
91 | parser.add_argument(
92 | '--l_i', type=int, default=2, help='Number of channels in input layer')
93 | parser.add_argument(
94 | '--l_h',
95 | type=int,
96 | default=150,
97 | help='Number of channels in first hidden layer')
98 | parser.add_argument(
99 | '--l_q',
100 | type=int,
101 | default=10,
102 | help='Number of channels in q layer (~actions) in VI-module')
103 | parser.add_argument(
104 | '--batch_size', type=int, default=128, help='Batch size')
105 | config = parser.parse_args()
106 | # Get path to save trained model
107 | save_path = "trained/vin_{0}x{0}.pth".format(config.imsize)
108 | # Instantiate a VIN model
109 | net = VIN(config)
110 | # Loss
111 | criterion = nn.CrossEntropyLoss()
112 | # Optimizer
113 | optimizer = optim.RMSprop(net.parameters(), lr=config.lr, eps=1e-6)
114 | # Dataset transformer: torchvision.transforms
115 | transform = None
116 | # Define Dataset
117 | trainset = GridworldData(
118 | config.datafile, imsize=config.imsize, train=True, transform=transform)
119 | testset = GridworldData(
120 | config.datafile,
121 | imsize=config.imsize,
122 | train=False,
123 | transform=transform)
124 | # Create Dataloader
125 | trainloader = torch.utils.data.DataLoader(
126 | trainset, batch_size=config.batch_size, shuffle=True, num_workers=0)
127 | testloader = torch.utils.data.DataLoader(
128 | testset, batch_size=config.batch_size, shuffle=False, num_workers=0)
129 | # Train the model
130 | train(net, trainloader, config, criterion, optimizer)
131 | # Test accuracy
132 | test(net, testloader, config)
133 | # Save the trained model parameters
134 | torch.save(net.state_dict(), save_path)
135 |
--------------------------------------------------------------------------------
/trained/README.md:
--------------------------------------------------------------------------------
1 | # Trained Models
2 | To use a pretrained model you have two choices:
3 | 1. Download and place the trained .pth model files here
4 | 2. Train the VIN on the datasets yourself (the models will save themselves here)
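5 |
6 | (The ```download_weights_and_datasets.sh``` script in the repository root fetches the released .pth files into this directory.)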
--------------------------------------------------------------------------------
/utility/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kentsommer/pytorch-value-iteration-networks/2205fce8ac9f1d9f01f81996f7deef9a7b197a8d/utility/__init__.py
--------------------------------------------------------------------------------
/utility/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def fmt_row(width, row):
6 | out = " | ".join(fmt_item(x, width) for x in row)
7 | return out
8 |
9 |
10 | def fmt_item(x, l):
11 | if isinstance(x, np.ndarray):
12 | assert x.ndim == 0
13 | x = x.item()
14 | if isinstance(x, float): rep = "%g" % x
15 | else: rep = str(x)
16 | return " " * (l - len(rep)) + rep
17 |
18 |
19 | def get_stats(loss, predictions, labels):
20 | cp = np.argmax(predictions.cpu().data.numpy(), 1)
21 | error = np.mean(cp != labels.cpu().data.numpy())
22 | return loss.item(), error
23 |
24 |
25 | def print_stats(epoch, avg_loss, avg_error, num_batches, time_duration):
26 | print(
27 | fmt_row(10, [
28 | epoch + 1, avg_loss / num_batches, avg_error / num_batches,
29 | time_duration
30 | ]))
31 |
32 |
33 | def print_header():
34 | print(fmt_row(10, ["Epoch", "Train Loss", "Train Error", "Epoch Time"]))
35 |
--------------------------------------------------------------------------------