├── .gitignore
├── README.md
├── download_datasets.sh
├── eval.py
├── generate_dataset.py
├── models
│   ├── GPPN.py
│   ├── VIN.py
│   └── __init__.py
├── requirements.txt
├── train.py
└── utils
    ├── __init__.py
    ├── dijkstra.py
    ├── experiment.py
    ├── maze.py
    ├── mechanism.py
    └── runner.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Data/log directories
2 | mazes/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # pyenv
79 | .python-version
80 |
81 | # celery beat schedule file
82 | celerybeat-schedule
83 |
84 | # SageMath parsed files
85 | *.sage.py
86 |
87 | # Environments
88 | .env
89 | .venv
90 | env/
91 | venv/
92 | ENV/
93 | env.bak/
94 | venv.bak/
95 |
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 |
100 | # Rope project settings
101 | .ropeproject
102 |
103 | # mkdocs documentation
104 | /site
105 |
106 | # mypy
107 | .mypy_cache/
108 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Gated Path Planning Networks (ICML 2018)
2 |
3 | This is the official codebase for the following paper:
4 |
5 | Lisa Lee\*, Emilio Parisotto\*, Devendra Singh Chaplot, Eric Xing, Ruslan Salakhutdinov. **Gated Path Planning Networks**. ICML 2018. https://arxiv.org/abs/1806.06408
6 |
7 | ## Getting Started
8 |
9 | You can clone this repo by running:
10 | ```
11 | git clone https://github.com/lileee/gated-path-planning-networks.git
12 | cd gated-path-planning-networks/
13 | ```
14 |
15 | All subsequent commands in this README should be run from the top-level directory of this repository (i.e., `/path/to/gated-path-planning-networks/`).
16 |
17 | ### I. Docker container
18 |
19 | We provide two Docker containers, with and without GPU support. These containers have Python 3.6.5, PyTorch 0.4.0, and other dependencies installed. They do not contain this codebase or the maze datasets used in our experiments.
20 |
21 | To load the container with GPU support:
22 | ```
23 | # PyTorch with GPU support
24 | nvidia-docker pull lileee/ubuntu-16.04-pytorch-0.4.0-gpu:v1
25 | nvidia-docker run -v $(pwd):/home --rm -ti lileee/ubuntu-16.04-pytorch-0.4.0-gpu:v1
26 | ```
27 |
28 | To load the container without GPU support:
29 | ```
30 | # PyTorch (CPU-only)
31 | docker pull lileee/ubuntu-16.04-pytorch-0.4.0-cpu:v1
32 | docker run -v $(pwd):/home --rm -ti lileee/ubuntu-16.04-pytorch-0.4.0-cpu:v1
33 | ```
34 |
35 | Here is a speed comparison between the two containers, training a VIN on `9x9` mazes with a 5k/1k/1k train-val-test split:
36 |
37 | | PyTorch 0.4.0       | time per epoch |
38 | |:-------------------:|:--------------:|
39 | | with GPU support    | 8.5 sec        |
40 | | without GPU support | 32.3 sec       |
41 |
42 | ### II. Generate a 2D maze dataset
43 |
44 | Generate a dataset by running:
45 |
46 | ```
47 | python generate_dataset.py --output-path mazes.npz --mechanism news --maze-size 9 --train-size 5000 --valid-size 1000 --test-size 1000
48 | ```
49 | This will create a datafile `mazes.npz` containing `9x9` mazes generated with the NEWS maze transition mechanism, with a 5k/1k/1k train-val-test split.
50 |
51 | Note:
52 | * The same maze transition mechanism that was used to generate the dataset must be used for `train.py` and `eval.py`. Here, we used `--mechanism news` to generate the dataset. Other options are `--mechanism moore` and `--mechanism diffdrive`.
53 |
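If you want to inspect the generated datafile directly: `generate_dataset.py` saves nine arrays positionally with `np.savez_compressed`, in the order `(mazes, goal_maps, opt_policies)` for train, valid, and test (see `MazeDataset._process` in `utils/maze.py`). A minimal sketch:

```
# Minimal sketch: inspect a generated datafile. np.savez_compressed stores
# the arrays positionally as arr_0..arr_8:
# train/valid/test x (mazes, goal_maps, opt_policies).
import numpy as np

with np.load("mazes.npz") as f:
    train_mazes = f["arr_0"]         # (train_size, maze_size, maze_size)
    train_goal_maps = f["arr_1"]     # (train_size, num_orient, maze_size, maze_size)
    train_opt_policies = f["arr_2"]  # (train_size, num_actions, num_orient, H, W)
    print(train_mazes.shape, train_goal_maps.shape, train_opt_policies.shape)
```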
54 | ### III. Train a model
55 |
56 | You can train a VIN with iteration count K=15 and kernel size F=5 on the datafile `mazes.npz` by running:
57 | ```
58 | python train.py --datafile mazes.npz --mechanism news --model models.VIN --k 15 --f 5 --save-directory log/vin-k15-f5
59 | ```
60 | This will save outputs to the subdirectory `vin-k15-f5/`, including the trained models and learning plots.
61 |
62 | Similarly, you can train a GPPN by running:
63 | ```
64 | python train.py --datafile mazes.npz --mechanism news --model models.GPPN --k 15 --f 5 --save-directory log/gppn-k15-f5
65 | ```
66 |
67 | Notes:
68 | * `--mechanism` must be the same as the one used to generate `mazes.npz` (which is `news` in this example).
69 | * `--f` must be an odd integer.
70 |
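The saved checkpoints are plain PyTorch dictionaries. As a minimal sketch (using the `--save-directory` from the GPPN command above), you can peek inside `planner.final.pth`, which bundles the final and best model weights together with the per-epoch metric histories written by `train.py`:

```
# Minimal sketch: inspect a checkpoint written by train.py. Keys include
# "model", "best_model", and metric histories such as "tr_total_loss".
import torch

ckpt = torch.load("log/gppn-k15-f5/planner.final.pth", map_location="cpu")
print(sorted(ckpt.keys()))
print(len(ckpt["tr_total_loss"]), "epochs of training loss recorded")
```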
71 | ### IV. Evaluate a trained model
72 |
73 | Once you have a trained VIN model, you can evaluate it on a dataset by running:
74 |
75 | ```
76 | python eval.py --datafile mazes.npz --mechanism news --model models.VIN --k 15 --f 5 --load-file log/vin-k15-f5/planner.final.pth
77 | ```
78 | Similarly for GPPN:
79 |
80 | ```
81 | python eval.py --datafile mazes.npz --mechanism news --model models.GPPN --k 15 --f 5 --load-file log/gppn-k15-f5/planner.final.pth
82 | ```
83 |
84 | Notes:
85 | * `--mechanism` must be the same as the one used to generate `mazes.npz` (which is `news` in this example).
86 | * `--f` must be the same as the one used to train the model.
87 |
88 |
89 | ## Replicating experiments from our paper
90 |
91 | ### I. Download 2D maze datasets
92 | To replicate experiments from our ICML 2018 paper, first download the datasets by running:
93 | ```
94 | ./download_datasets.sh
95 | ```
96 | This will create a subdirectory `mazes/` containing the following 2D maze datasets used in our experiments:
97 |
98 | | datafile | maze size | mechanism | train size | val size | test size |
99 | | :---------------------: |:---------:| :----------:|:----------:|:--------:|:---------:|
100 | | `m15_news_10k.npz` | `15x15` | `news` | 10000 | 2000 | 2000
101 | | `m15_news_25k.npz` | `15x15` | `news` | 25000 | 5000 | 5000
102 | | `m15_news_100k.npz` | `15x15` | `news` | 100000 | 10000 | 10000
103 | | `m15_moore_10k.npz` | `15x15` | `moore` | 10000 | 2000 | 2000
104 | | `m15_moore_25k.npz` | `15x15` | `moore` | 25000 | 5000 | 5000
105 | | `m15_moore_100k.npz` | `15x15` | `moore` | 100000 | 10000 | 10000
106 | | `m15_diffdrive_10k.npz` | `15x15` | `diffdrive` | 10000 | 2000 | 2000
107 | | `m15_diffdrive_25k.npz` | `15x15` | `diffdrive` | 25000 | 5000 | 5000
108 | | `m15_diffdrive_100k.npz`| `15x15` | `diffdrive` | 100000 | 10000 | 10000
109 | | `m28_news_25k.npz` | `28x28` | `news` | 25000 | 5000 | 5000
110 | | `m28_moore_25k.npz` | `28x28` | `moore` | 25000 | 5000 | 5000
111 | | `m28_diffdrive_25k.npz` | `28x28` | `diffdrive` | 25000 | 5000 | 5000
112 |
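After downloading, you can sanity-check any datafile with the `MazeDataset` class; a minimal sketch:

```
# Minimal sketch: verify a downloaded datafile loads and has the
# expected split size and maze dimensions.
from utils.maze import MazeDataset

dataset = MazeDataset("mazes/m15_news_25k.npz", "train")
print(len(dataset))  # 25000
maze, goal_map, opt_policy = dataset[0]
print(maze.shape)    # torch.Size([15, 15])
```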
113 |
114 | ### II. Train VIN and GPPN
115 | Then you can train VIN with the best (K, F) settings for each dataset from our paper by running:
116 | ```
117 | python train.py --datafile mazes/m15_news_10k.npz --mechanism news --model models.VIN --k 30 --f 5
118 | python train.py --datafile mazes/m15_news_25k.npz --mechanism news --model models.VIN --k 20 --f 5
119 | python train.py --datafile mazes/m15_news_100k.npz --mechanism news --model models.VIN --k 30 --f 3
120 |
121 | python train.py --datafile mazes/m15_moore_10k.npz --mechanism moore --model models.VIN --k 30 --f 11
122 | python train.py --datafile mazes/m15_moore_25k.npz --mechanism moore --model models.VIN --k 30 --f 5
123 | python train.py --datafile mazes/m15_moore_100k.npz --mechanism moore --model models.VIN --k 30 --f 5
124 |
125 | python train.py --datafile mazes/m15_diffdrive_10k.npz --mechanism diffdrive --model models.VIN --k 30 --f 3
126 | python train.py --datafile mazes/m15_diffdrive_25k.npz --mechanism diffdrive --model models.VIN --k 30 --f 3
127 | python train.py --datafile mazes/m15_diffdrive_100k.npz --mechanism diffdrive --model models.VIN --k 30 --f 3
128 |
129 | python train.py --datafile mazes/m28_news_25k.npz --mechanism news --model models.VIN --k 56 --f 3
130 | python train.py --datafile mazes/m28_moore_25k.npz --mechanism moore --model models.VIN --k 56 --f 5
131 | python train.py --datafile mazes/m28_diffdrive_25k.npz --mechanism diffdrive --model models.VIN --k 56 --f 3
132 | ```
133 |
134 | Similarly, you can train GPPN with the best (K, F) settings for each dataset from our paper by running:
135 | ```
136 | python train.py --datafile mazes/m15_news_10k.npz --mechanism news --model models.GPPN --k 20 --f 9
137 | python train.py --datafile mazes/m15_news_25k.npz --mechanism news --model models.GPPN --k 20 --f 11
138 | python train.py --datafile mazes/m15_news_100k.npz --mechanism news --model models.GPPN --k 30 --f 11
139 |
140 | python train.py --datafile mazes/m15_moore_10k.npz --mechanism moore --model models.GPPN --k 30 --f 7
141 | python train.py --datafile mazes/m15_moore_25k.npz --mechanism moore --model models.GPPN --k 30 --f 9
142 | python train.py --datafile mazes/m15_moore_100k.npz --mechanism moore --model models.GPPN --k 30 --f 7
143 |
144 | python train.py --datafile mazes/m15_diffdrive_10k.npz --mechanism diffdrive --model models.GPPN --k 30 --f 11
145 | python train.py --datafile mazes/m15_diffdrive_25k.npz --mechanism diffdrive --model models.GPPN --k 30 --f 9
146 | python train.py --datafile mazes/m15_diffdrive_100k.npz --mechanism diffdrive --model models.GPPN --k 30 --f 9
147 |
148 | python train.py --datafile mazes/m28_news_25k.npz --mechanism news --model models.GPPN --k 56 --f 11
149 | python train.py --datafile mazes/m28_moore_25k.npz --mechanism moore --model models.GPPN --k 56 --f 9
150 | python train.py --datafile mazes/m28_diffdrive_25k.npz --mechanism diffdrive --model models.GPPN --k 56 --f 11
151 | ```
152 |
153 | ### III. Test Performance Results
154 |
155 | Here are the test performance results from running the above commands inside the Docker container `lileee/ubuntu-16.04-pytorch-0.4.0-gpu:v1`:
156 |
157 |
158 |
159 | | datafile                 | VIN K | VIN F | VIN %Opt | VIN %Suc | GPPN K | GPPN F | GPPN %Opt | GPPN %Suc |
160 | |:------------------------:|:-----:|:-----:|:--------:|:--------:|:------:|:------:|:---------:|:---------:|
161 | | `m15_news_10k.npz`       | 30 | 5  | 77.4 | 79.0 | 20 | 9  | 96.8 | 97.8 |
162 | | `m15_news_25k.npz`       | 20 | 5  | 83.6 | 84.2 | 20 | 11 | 99.0 | 99.3 |
163 | | `m15_news_100k.npz`      | 30 | 3  | 92.6 | 92.8 | 30 | 11 | 99.7 | 99.8 |
164 | | `m15_moore_10k.npz`      | 30 | 11 | 86.0 | 89.3 | 30 | 7  | 97.0 | 98.0 |
165 | | `m15_moore_25k.npz`      | 30 | 5  | 85.4 | 88.1 | 30 | 9  | 98.9 | 99.5 |
166 | | `m15_moore_100k.npz`     | 30 | 5  | 96.9 | 97.5 | 30 | 7  | 99.6 | 99.8 |
167 | | `m15_diffdrive_10k.npz`  | 30 | 3  | 98.4 | 99.0 | 30 | 11 | 99.1 | 99.7 |
168 | | `m15_diffdrive_25k.npz`  | 30 | 3  | 96.1 | 98.5 | 30 | 9  | 98.9 | 99.5 |
169 | | `m15_diffdrive_100k.npz` | 30 | 3  | 99.0 | 99.4 | 30 | 9  | 99.8 | 99.9 |
170 | | `m28_news_25k.npz`       | 56 | 3  | 83.4 | 84.2 | 56 | 11 | 96.5 | 97.8 |
171 | | `m28_moore_25k.npz`      | 56 | 5  | 73.3 | 81.0 | 56 | 9  | 96.5 | 97.9 |
172 | | `m28_diffdrive_25k.npz`  | 56 | 3  | 82.0 | 93.6 | 56 | 11 | 95.3 | 98.0 |
233 |
234 |
235 |
236 | Feel free to play around with different iteration counts `--k` and kernel sizes `--f`.
237 |
238 | ### IV. Version differences (Python, PyTorch)
239 |
240 | The test performance results above are slightly different from what is reported in our ICML 2018 paper due to version differences in Python (3.6.5 vs. 2.7.12) and PyTorch (0.4.0 vs. 0.3.1).
241 |
242 | Below, we provide instructions to exactly replicate the numbers reported in our ICML 2018 paper.
243 |
244 | 1. Check out the Git branch `icml2018`:
245 | ```
246 | git checkout icml2018
247 | ```
248 |
249 | 2. Load the Docker container used in our experiments by running:
250 | ```
251 | # PyTorch with GPU support
252 | nvidia-docker pull lileee/python-2.7-pytorch-0.3.1-custom:latest
253 | nvidia-docker run -v $(pwd):/home --rm -ti lileee/python-2.7-pytorch-0.3.1-custom:latest
254 | ```
255 | This Docker container uses Python 2.7.12 and a custom version of PyTorch 0.3.1 compiled from source at https://github.com/eparisotto/pytorch.
256 |
257 | 3. Train a model:
258 | ```
259 | python train.py --datafile mazes/m15_news_25k.npz --mechanism news --model models.VIN --k 20 --f 5
260 | ```
261 |
262 | ## Citation
263 |
264 | If you found this code useful in your research, please cite:
265 |
266 | ```
267 | @inproceedings{gppn2018,
268 | author = {Lisa Lee and Emilio Parisotto and Devendra Singh Chaplot and Eric Xing and Ruslan Salakhutdinov},
269 | title = {Gated Path Planning Networks},
270 | booktitle = {Proceedings of the 35th International Conference on Machine Learning (ICML 2018)},
271 | year = {2018}
272 | }
273 | ```
274 |
275 | ## Acknowledgments
276 |
277 | Thanks to [@kentsommer](https://github.com/kentsommer) for releasing a [PyTorch implementation](https://github.com/kentsommer/pytorch-value-iteration-networks) of the original VIN results, which served as a starting point for this codebase.
278 |
--------------------------------------------------------------------------------
/download_datasets.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Downloads the 2D maze datasets used in our ICML 2018 paper.
2 |
3 | OUTPUT_DIR=./mazes
4 |
5 | mkdir -p ${OUTPUT_DIR}
6 |
7 | # 15x15 mazes (NEWS)
8 | wget https://cmu.box.com/shared/static/l46uqg78f2iik8avr9rdfyfi28vyek7o.npz -O ${OUTPUT_DIR}/m15_news_10k.npz
9 | wget https://cmu.box.com/shared/static/voqgj886o1gfx7ievbytzybonlu8s56n.npz -O ${OUTPUT_DIR}/m15_news_25k.npz
10 | wget https://cmu.box.com/shared/static/4bp5xxs7ilfohyosy941sc8bi9fbed8w.npz -O ${OUTPUT_DIR}/m15_news_100k.npz
11 |
12 | # 15x15 mazes (Moore)
13 | wget https://cmu.box.com/shared/static/p8aanmo5kj1bm9949njmigntmbol0o7s.npz -O ${OUTPUT_DIR}/m15_moore_10k.npz
14 | wget https://cmu.box.com/shared/static/1nmxgma8uvnuezfifqxieliiy9qrx4pe.npz -O ${OUTPUT_DIR}/m15_moore_25k.npz
15 | wget https://cmu.box.com/shared/static/gg78z5ka2sjx9jcbj6m31du1v9vmfmhm.npz -O ${OUTPUT_DIR}/m15_moore_100k.npz
16 |
17 | # 15x15 mazes (DiffDrive)
18 | wget https://cmu.box.com/shared/static/3rv3aghi8df17vwnidcj1kd0qci2qa94.npz -O ${OUTPUT_DIR}/m15_diffdrive_10k.npz
19 | wget https://cmu.box.com/shared/static/3rv3aghi8df17vwnidcj1kd0qci2qa94.npz -O ${OUTPUT_DIR}/m15_diffdrive_25k.npz
20 | wget https://cmu.box.com/shared/static/pjfw2rwj88ibx4ako6balz8d810qujw1.npz -O ${OUTPUT_DIR}/m15_diffdrive_100k.npz
21 |
22 | # 28x28 mazes
23 | wget https://cmu.box.com/shared/static/2bat54yisnnx5yybzl5uhtzkex0g65sx.npz -O ${OUTPUT_DIR}/m28_news_25k.npz
24 | wget https://cmu.box.com/shared/static/c6st11gqtx1n1xs1lu86sys00yf7yu84.npz -O ${OUTPUT_DIR}/m28_moore_25k.npz
25 | wget https://cmu.box.com/shared/static/r3bgykf8zss8ro0pfqw8sc3pgjfi1urg.npz -O ${OUTPUT_DIR}/m28_diffdrive_25k.npz
26 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluates a trained model on a dataset.
3 |
4 | Example usage:
5 | python eval.py --datafile mazes.npz --mechanism news --model models.GPPN \
6 | --k 15 --f 5 --load-file log/gppn-k15-f5/planner.final.pth
7 | """
8 | from __future__ import print_function
9 |
10 | from utils.experiment import (parse_args, create_save_dir, get_mechanism,
11 | create_dataloader, print_stats)
12 | from utils.runner import Runner
13 |
14 |
15 | def main():
16 | args = parse_args()
17 |
18 | create_save_dir(args.save_directory)
19 | mechanism = get_mechanism(args.mechanism)
20 |
21 | # Create DataLoaders.
22 | trainloader = create_dataloader(
23 | args.datafile, "train", args.batch_size, mechanism, shuffle=True)
24 | validloader = create_dataloader(
25 | args.datafile, "valid", args.batch_size, mechanism, shuffle=False)
26 | testloader = create_dataloader(
27 | args.datafile, "test", args.batch_size, mechanism, shuffle=False)
28 |
29 | runner = Runner(args, mechanism)
30 |
31 | print("\n------------- Evaluating final model -------------")
32 | print("\nTrain performance:")
33 | print_stats(runner.test(trainloader))
34 |
35 |     print("\nValidation performance:")
36 |     print_stats(runner.test(validloader))
37 | 
38 |     print("\nTest performance:")
39 |     print_stats(runner.test(testloader))
40 |
41 | print("\n------------- Evaluating best model -------------")
42 | print("\nTrain performance:")
43 | print_stats(runner.test(trainloader, use_best=True))
44 |
45 |     print("\nValidation performance:")
46 |     print_stats(runner.test(validloader, use_best=True))
47 | 
48 |     print("\nTest performance:")
49 |     print_stats(runner.test(testloader, use_best=True))
50 |
51 |
52 | if __name__ == "__main__":
53 | main()
54 |
--------------------------------------------------------------------------------
/generate_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Generates a 2D maze dataset.
3 |
4 | Example usage:
5 | python generate_dataset.py --output-path mazes.npz --mechanism news \
6 | --maze-size 9 --train-size 5000 --valid-size 1000 --test-size 1000
7 | """
8 | from __future__ import print_function
9 | import sys
10 | import argparse
11 | import numpy as np
12 |
13 | from utils.dijkstra import dijkstra_dist
14 | from utils.experiment import get_mechanism
15 | from utils.maze import RandomMaze, extract_policy
16 |
17 |
18 | def generate_data(filename,
19 | train_size,
20 | valid_size,
21 | test_size,
22 | mechanism,
23 | maze_size,
24 | min_decimation,
25 | max_decimation,
26 | start_pos=(1, 1)):
27 | maze_class = RandomMaze(
28 | mechanism,
29 | maze_size,
30 | maze_size,
31 | min_decimation,
32 | max_decimation,
33 | start_pos=start_pos)
34 |
35 |     def hash_maze_to_string(maze):
36 |         maze = np.array(maze, dtype=np.uint8).reshape(-1)
37 |         return "".join(str(v) for v in maze)
41 |
42 | def hashed_check_maze_exists(mazekey, mazehash):
43 | if mazehash is None:
44 | return False
45 | if mazekey in mazehash:
46 | return True
47 | return False
48 |
49 | def check_maze_exists(maze, compare_mazes):
50 | if compare_mazes is None:
51 | return False
52 | diff = np.sum(
53 | np.abs(compare_mazes - maze).reshape((len(compare_mazes), -1)),
54 | axis=1)
55 | if np.sum(diff == 0):
56 | return True
57 | return False
58 |
59 | def extract_goal(goal_map):
60 | for o in range(mechanism.num_orient):
61 | for y in range(maze_size):
62 | for x in range(maze_size):
63 | if goal_map[o][y][x] == 1.:
64 | return (o, y, x)
65 |
66 | def create_dataset(data_size, compare_mazes=None):
67 | mazes = np.zeros((data_size, maze_size, maze_size))
68 | goal_maps = np.zeros((data_size, mechanism.num_orient, maze_size,
69 | maze_size))
70 | opt_policies = np.zeros((data_size, mechanism.num_actions,
71 | mechanism.num_orient, maze_size, maze_size))
72 |
73 | mazehash = {}
74 | if compare_mazes is not None:
75 | for i in range(compare_mazes.shape[0]):
76 | maze = compare_mazes[i]
77 | mazekey = hash_maze_to_string(maze)
78 | mazehash[mazekey] = 1
79 | for i in range(data_size):
80 | maze, goal_map = None, None
81 | while True:
82 | maze, _, goal_map = maze_class.reset()
83 | mazekey = hash_maze_to_string(maze)
84 |
85 | # Make sure we sampled a unique maze from the compare set
86 | if hashed_check_maze_exists(mazekey, mazehash):
87 | continue
88 | mazehash[mazekey] = 1
89 | break
90 |
91 | # Use Dijkstra's to construct the optimal policy
92 | opt_value = dijkstra_dist(maze, mechanism, extract_goal(goal_map))
93 | opt_policy = extract_policy(maze, mechanism, opt_value)
94 |
95 | mazes[i, :, :] = maze
96 | goal_maps[i, :, :, :] = goal_map
97 | opt_policies[i, :, :, :, :] = opt_policy
98 |
99 | sys.stdout.write("\r%0.4f" % (float(i) / data_size * 100) + "%")
100 | sys.stdout.flush()
101 | sys.stdout.write("\r100%\n")
102 |
103 | return mazes, goal_maps, opt_policies
104 |
105 |     # Generate the valid+test set first
106 | print("Creating valid+test dataset...")
107 | validtest_mazes, validtest_goal_maps, validtest_opt_policies = create_dataset(
108 | test_size + valid_size)
109 |
110 | # Split valid and test
111 | valid_mazes = validtest_mazes[0:valid_size]
112 | test_mazes = validtest_mazes[valid_size:]
113 | valid_goal_maps = validtest_goal_maps[0:valid_size]
114 | test_goal_maps = validtest_goal_maps[valid_size:]
115 | valid_opt_policies = validtest_opt_policies[0:valid_size]
116 | test_opt_policies = validtest_opt_policies[valid_size:]
117 |
118 |     # Generate the train set while avoiding valid/test maze geometries
119 | print("Creating training dataset...")
120 | train_mazes, train_goal_maps, train_opt_policies = create_dataset(
121 | train_size, compare_mazes=validtest_mazes)
122 |
123 | # Re-shuffle
124 | mazes = np.concatenate((train_mazes, valid_mazes, test_mazes), 0)
125 | goal_maps = np.concatenate(
126 | (train_goal_maps, valid_goal_maps, test_goal_maps), 0)
127 | opt_policies = np.concatenate(
128 | (train_opt_policies, valid_opt_policies, test_opt_policies), 0)
129 |
130 | shuffidx = np.random.permutation(mazes.shape[0])
131 | mazes = mazes[shuffidx]
132 | goal_maps = goal_maps[shuffidx]
133 | opt_policies = opt_policies[shuffidx]
134 |
135 | train_mazes = mazes[:train_size]
136 | train_goal_maps = goal_maps[:train_size]
137 | train_opt_policies = opt_policies[:train_size]
138 |
139 | valid_mazes = mazes[train_size:train_size + valid_size]
140 | valid_goal_maps = goal_maps[train_size:train_size + valid_size]
141 | valid_opt_policies = opt_policies[train_size:train_size + valid_size]
142 |
143 | test_mazes = mazes[train_size + valid_size:]
144 | test_goal_maps = goal_maps[train_size + valid_size:]
145 | test_opt_policies = opt_policies[train_size + valid_size:]
146 |
147 | # Save to numpy
148 | np.savez_compressed(filename, train_mazes, train_goal_maps,
149 | train_opt_policies, valid_mazes, valid_goal_maps,
150 | valid_opt_policies, test_mazes, test_goal_maps,
151 | test_opt_policies)
152 |
153 |
154 | def main():
155 | parser = argparse.ArgumentParser()
156 | parser.add_argument(
157 | "--output-path", type=str, default="mazes.npz",
158 | help="Filename to save the dataset to.")
159 | parser.add_argument(
160 | "--train-size", type=int, default=10000,
161 | help="Number of training mazes.")
162 | parser.add_argument(
163 | "--valid-size", type=int, default=1000,
164 | help="Number of validation mazes.")
165 | parser.add_argument(
166 | "--test-size", type=int, default=1000,
167 | help="Number of test mazes.")
168 | parser.add_argument(
169 | "--maze-size", type=int, default=9,
170 | help="Size of mazes.")
171 | parser.add_argument(
172 | "--min-decimation", type=float, default=0.0,
173 | help="How likely a wall is to be destroyed (minimum).")
174 | parser.add_argument("--max-decimation", type=float, default=1.0,
175 | help="How likely a wall is to be destroyed (maximum).")
176 | parser.add_argument(
177 | "--start-pos-x", type=int, default=1,
178 | help="Maze start X-axis position.")
179 | parser.add_argument(
180 | "--start-pos-y", type=int, default=1,
181 | help="Maze start Y-axis position.")
182 | parser.add_argument(
183 | "--mechanism", type=str, default="news",
184 | help="Maze transition mechanism. (news|diffdrive|moore)")
185 | args = parser.parse_args()
186 |
187 | mechanism = get_mechanism(args.mechanism)
188 | generate_data(
189 | args.output_path,
190 | args.train_size,
191 | args.valid_size,
192 | args.test_size,
193 | mechanism,
194 | args.maze_size,
195 | args.min_decimation,
196 | args.max_decimation,
197 | start_pos=(args.start_pos_y, args.start_pos_x))
198 |
199 |
200 | if __name__ == "__main__":
201 | main()
--------------------------------------------------------------------------------
/models/GPPN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | # Gated Path Planning Network module
6 | class Planner(nn.Module):
7 | """
8 | Implementation of the Gated Path Planning Network.
9 | """
10 | def __init__(self, num_orient, num_actions, args):
11 | super(Planner, self).__init__()
12 |
13 | self.num_orient = num_orient
14 | self.num_actions = num_actions
15 |
16 | self.l_h = args.l_h
17 | self.k = args.k
18 | self.f = args.f
19 |
20 | self.hid = nn.Conv2d(
21 | in_channels=(num_orient + 1), # maze map + goal location
22 | out_channels=self.l_h,
23 | kernel_size=(3, 3),
24 | stride=1,
25 | padding=1,
26 | bias=True)
27 |
28 | self.h0 = nn.Conv2d(
29 | in_channels=self.l_h,
30 | out_channels=self.l_h,
31 | kernel_size=(3, 3),
32 | stride=1,
33 | padding=1,
34 | bias=True)
35 |
36 | self.c0 = nn.Conv2d(
37 | in_channels=self.l_h,
38 | out_channels=self.l_h,
39 | kernel_size=(3, 3),
40 | stride=1,
41 | padding=1,
42 | bias=True)
43 |
44 | self.conv = nn.Conv2d(
45 | in_channels=self.l_h,
46 | out_channels=1,
47 | kernel_size=(self.f, self.f),
48 | stride=1,
49 | padding=int((self.f - 1.0) / 2),
50 | bias=True)
51 |
52 | self.lstm = nn.LSTMCell(1, self.l_h)
53 |
54 | self.policy = nn.Conv2d(
55 | in_channels=self.l_h,
56 | out_channels=num_actions * num_orient,
57 | kernel_size=(1, 1),
58 | stride=1,
59 | padding=0,
60 | bias=False)
61 |
62 | self.sm = nn.Softmax2d()
63 |
64 | def forward(self, map_design, goal_map):
65 | maze_size = map_design.size()[-1]
66 | X = torch.cat([map_design, goal_map], 1)
67 |
68 | hid = self.hid(X)
69 | h0 = self.h0(hid).transpose(1, 3).contiguous().view(-1, self.l_h)
70 | c0 = self.c0(hid).transpose(1, 3).contiguous().view(-1, self.l_h)
71 |
72 | last_h, last_c = h0, c0
73 | for _ in range(0, self.k - 1):
74 | h_map = last_h.view(-1, maze_size, maze_size, self.l_h)
75 | h_map = h_map.transpose(3, 1)
76 | inp = self.conv(h_map).transpose(1, 3).contiguous().view(-1, 1)
77 |
78 | last_h, last_c = self.lstm(inp, (last_h, last_c))
79 |
80 | hk = last_h.view(-1, maze_size, maze_size, self.l_h).transpose(3, 1)
81 | logits = self.policy(hk)
82 |
83 | # Normalize over actions
84 | logits = logits.view(-1, self.num_actions, maze_size, maze_size)
85 | probs = self.sm(logits)
86 |
87 | # Reshape to output dimensions
88 | logits = logits.view(-1, self.num_orient, self.num_actions, maze_size,
89 | maze_size)
90 | probs = probs.view(-1, self.num_orient, self.num_actions, maze_size,
91 | maze_size)
92 | logits = torch.transpose(logits, 1, 2).contiguous()
93 | probs = torch.transpose(probs, 1, 2).contiguous()
94 |
95 | return logits, probs, h0, hk
96 |
--------------------------------------------------------------------------------
/models/VIN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.parameter import Parameter
5 |
6 |
7 | # VIN planner module
8 | class Planner(nn.Module):
9 | """
10 | Implementation of the Value Iteration Network.
11 | """
12 | def __init__(self, num_orient, num_actions, args):
13 | super(Planner, self).__init__()
14 |
15 | self.num_orient = num_orient
16 | self.num_actions = num_actions
17 |
18 | self.l_q = args.l_q
19 | self.l_h = args.l_h
20 | self.k = args.k
21 | self.f = args.f
22 |
23 | self.h = nn.Conv2d(
24 | in_channels=(num_orient + 1), # maze map + goal location
25 | out_channels=self.l_h,
26 | kernel_size=(3, 3),
27 | stride=1,
28 | padding=1,
29 | bias=True)
30 |
31 | self.r = nn.Conv2d(
32 | in_channels=self.l_h,
33 | out_channels=num_orient, # reward per orientation
34 | kernel_size=(1, 1),
35 | stride=1,
36 | padding=0,
37 | bias=False)
38 |
39 | self.q = nn.Conv2d(
40 | in_channels=num_orient,
41 | out_channels=self.l_q * num_orient,
42 | kernel_size=(self.f, self.f),
43 | stride=1,
44 | padding=int((self.f - 1.0) / 2),
45 | bias=False)
46 |
47 | self.policy = nn.Conv2d(
48 | in_channels=self.l_q * num_orient,
49 | out_channels=num_actions * num_orient,
50 | kernel_size=(1, 1),
51 | stride=1,
52 | padding=0,
53 | bias=False)
54 |
55 | self.w = Parameter(
56 | torch.zeros(self.l_q * num_orient, num_orient, self.f,
57 | self.f),
58 | requires_grad=True)
59 |
60 | self.sm = nn.Softmax2d()
61 |
62 | def forward(self, map_design, goal_map):
63 | maze_size = map_design.size()[-1]
64 | X = torch.cat([map_design, goal_map], 1)
65 |
66 | h = self.h(X)
67 | r = self.r(h)
68 | q = self.q(r)
69 | q = q.view(-1, self.num_orient, self.l_q, maze_size, maze_size)
70 | v, _ = torch.max(q, dim=2, keepdim=True)
71 | v = v.view(-1, self.num_orient, maze_size, maze_size)
72 | for _ in range(0, self.k - 1):
73 | q = F.conv2d(
74 | torch.cat([r, v], 1),
75 | torch.cat([self.q.weight, self.w], 1),
76 | stride=1,
77 | padding=int((self.f - 1.0) / 2))
78 | q = q.view(-1, self.num_orient, self.l_q, maze_size, maze_size)
79 | v, _ = torch.max(q, dim=2)
80 | v = v.view(-1, self.num_orient, maze_size, maze_size)
81 |
82 | q = F.conv2d(
83 | torch.cat([r, v], 1),
84 | torch.cat([self.q.weight, self.w], 1),
85 | stride=1,
86 | padding=int((self.f - 1.0) / 2))
87 |
88 | logits = self.policy(q)
89 |
90 | # Normalize over actions
91 | logits = logits.view(-1, self.num_actions, maze_size, maze_size)
92 | probs = self.sm(logits)
93 |
94 | # Reshape to output dimensions
95 | logits = logits.view(-1, self.num_orient, self.num_actions, maze_size,
96 | maze_size)
97 | probs = probs.view(-1, self.num_orient, self.num_actions, maze_size,
98 | maze_size)
99 | logits = torch.transpose(logits, 1, 2).contiguous()
100 | probs = torch.transpose(probs, 1, 2).contiguous()
101 |
102 | return logits, probs, v, r
103 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RLAgent/gated-path-planning-networks/df010fc26667f2b4f0dbbcc0d3d528d4ae6efeb1/models/__init__.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | torch >= 0.4.0
3 | matplotlib
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Trains a planner model.
3 |
4 | Example usage:
5 | python train.py --datafile mazes.npz --mechanism news --model models.GPPN \
6 | --k 15 --f 5 --save-directory log/gppn-k15-f5
7 | """
8 | from __future__ import print_function
9 |
10 | import argparse
11 | import time
12 | import numpy as np
13 |
14 | import matplotlib as mpl
15 | mpl.use("Agg")
16 | import matplotlib.pyplot as plt
17 |
18 | import torch
19 |
20 | from utils.experiment import (parse_args, create_save_dir, get_mechanism,
21 | create_dataloader, print_row, print_stats)
22 | from utils.runner import Runner
23 |
24 |
25 | def main():
26 | args = parse_args()
27 |
28 | save_path = create_save_dir(args.save_directory)
29 | mechanism = get_mechanism(args.mechanism)
30 |
31 | # Create DataLoaders
32 | trainloader = create_dataloader(
33 | args.datafile, "train", args.batch_size, mechanism, shuffle=True)
34 | validloader = create_dataloader(
35 | args.datafile, "valid", args.batch_size, mechanism, shuffle=False)
36 | testloader = create_dataloader(
37 | args.datafile, "test", args.batch_size, mechanism, shuffle=False)
38 |
39 | runner = Runner(args, mechanism)
40 |
41 | # Print header
42 | col_width = 5
43 | print("\n | Train | Valid |") # pylint: disable=line-too-long
44 | print_row(col_width, [
45 | "Epoch", "CE", "Err", "%Opt", "%Suc", "CE", "Err", "%Opt", "%Suc", "W",
46 | "dW", "Time", "Best"
47 | ])
48 |
49 | tr_total_loss, tr_total_error, tr_total_optimal, tr_total_success = [], [], [], []
50 | v_total_loss, v_total_error, v_total_optimal, v_total_success = [], [], [], []
51 | for epoch in range(args.epochs):
52 | start_time = time.time()
53 |
54 | # Train the model
55 | tr_info = runner.train(trainloader, args.batch_size)
56 |
57 | # Compute validation stats and save the best model
58 | v_info = runner.validate(validloader)
59 | time_duration = time.time() - start_time
60 |
61 | # Print epoch logs
62 | print_row(col_width, [
63 | epoch + 1, tr_info["avg_loss"], tr_info["avg_error"],
64 | tr_info["avg_optimal"], tr_info["avg_success"], v_info["avg_loss"],
65 | v_info["avg_error"], v_info["avg_optimal"], v_info["avg_success"],
66 | tr_info["weight_norm"], tr_info["grad_norm"],
67 | time_duration,
68 | "!" if v_info["is_best"] else " "
69 | ])
70 |
71 | # Keep track of metrics:
72 | tr_total_loss.append(tr_info["avg_loss"])
73 | tr_total_error.append(tr_info["avg_error"])
74 | tr_total_optimal.append(tr_info["avg_optimal"])
75 | tr_total_success.append(tr_info["avg_success"])
76 | v_total_loss.append(v_info["avg_loss"])
77 | v_total_error.append(v_info["avg_error"])
78 | v_total_optimal.append(v_info["avg_optimal"])
79 | v_total_success.append(v_info["avg_success"])
80 |
81 | # Plot learning curves.
82 |         def _plot(train, valid, name):
83 |             plt.clf()
84 |             x = np.arange(len(train))
85 |             plt.plot(x, np.array(train), label="train")
86 |             plt.plot(x, np.array(valid), label="valid")
87 |             plt.legend()
88 |             plt.savefig(name)
90 | _plot(tr_total_loss, v_total_loss, save_path + "_total_loss.pdf")
91 | _plot(tr_total_error, v_total_error, save_path + "_total_error.pdf")
92 | _plot(tr_total_optimal, v_total_optimal,
93 | save_path + "_total_optimal.pdf")
94 | _plot(tr_total_success, v_total_success,
95 | save_path + "_total_success.pdf")
96 |
97 | # Save intermediate model.
98 | if args.save_intermediate:
99 | torch.save({
100 | "model": runner.model.state_dict(),
101 | "best_model": runner.best_model.state_dict(),
102 | "tr_total_loss": tr_total_loss,
103 | "tr_total_error": tr_total_error,
104 | "tr_total_optimal": tr_total_optimal,
105 | "tr_total_success": tr_total_success,
106 | "v_total_loss": v_total_loss,
107 | "v_total_error": v_total_error,
108 | "v_total_optimal": v_total_optimal,
109 | "v_total_success": v_total_success,
110 | }, save_path + ".e" + str(epoch) + ".pth")
111 |
112 | # Test accuracy
113 | print("\nFinal test performance:")
114 | t_final_info = runner.test(testloader)
115 | print_stats(t_final_info)
116 |
117 | print("\nBest test performance:")
118 | t_best_info = runner.test(testloader, use_best=True)
119 | print_stats(t_best_info)
120 |
121 | # Save the final trained model
122 | torch.save({
123 | "model": runner.model.state_dict(),
124 | "best_model": runner.best_model.state_dict(),
125 | "tr_total_loss": tr_total_loss,
126 | "tr_total_error": tr_total_error,
127 | "tr_total_optimal": tr_total_optimal,
128 | "tr_total_success": tr_total_success,
129 | "v_total_loss": v_total_loss,
130 | "v_total_error": v_total_error,
131 | "v_total_optimal": v_total_optimal,
132 | "v_total_success": v_total_success,
133 | "t_final_loss": t_final_info["avg_loss"],
134 | "t_final_error": t_final_info["avg_error"],
135 | "t_final_optimal": t_final_info["avg_optimal"],
136 | "t_final_success": t_final_info["avg_success"],
137 | "t_best_loss": t_best_info["avg_loss"],
138 | "t_best_error": t_best_info["avg_error"],
139 | "t_best_optimal": t_best_info["avg_optimal"],
140 | "t_best_success": t_best_info["avg_success"],
141 | }, save_path + ".final.pth")
142 |
143 |
144 | if __name__ == "__main__":
145 | main()
146 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RLAgent/gated-path-planning-networks/df010fc26667f2b4f0dbbcc0d3d528d4ae6efeb1/utils/__init__.py
--------------------------------------------------------------------------------
/utils/dijkstra.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 |
5 | class MinHeap:
6 |
7 | def __init__(self):
8 | self.heap = [] # binary min-heap
9 | self.heapdict = {} # [key]->index dict
10 | self.invheapdict = {} # [index]->key dict
11 | self.heap_length = 0 # number of elements
12 |
13 | def empty(self):
14 | return self.heap_length == 0
15 |
16 | def insert(self, key, val):
17 | # Insert the value at the bottom of heap
18 | if self.heap_length == len(self.heap):
19 | self.heap.append(val)
20 | else:
21 | self.heap[self.heap_length] = val
22 | add_idx = self.heap_length
23 | self.heap_length += 1
24 |
25 | # Update the dictionaries
26 | self.heapdict[key] = add_idx
27 | self.invheapdict[add_idx] = key
28 |
29 | # percolate upwards
30 | self._percolate_up(add_idx)
31 |
32 | def decrease(self, key, new_val):
33 | # Find the index and value of this key
34 | curr_idx = self.heapdict[key]
35 | curr_val = self.heap[curr_idx]
36 | assert new_val <= curr_val
37 |
38 | # Update with new lower value
39 | self.heap[curr_idx] = new_val
40 |
41 | # Percolate upwards
42 | self._percolate_up(curr_idx)
43 |
44 | def extract(self):
45 | assert self.heap_length > 0
46 |
47 | retval = self.heap[0]
48 | retkey = self.invheapdict[0]
49 |
50 | # Swap the root with a leaf
51 | self._swap_index(0, self.heap_length - 1)
52 |
53 | # Delete the last element from the dictionaries
54 | del self.heapdict[retkey]
55 | del self.invheapdict[self.heap_length - 1]
56 | self.heap_length -= 1
57 |
58 | # Percolate downwards
59 | self._percolate_down(0)
60 | return retkey, retval
61 |
62 | def _swap_index(self, idx1, idx2):
63 | # get keys
64 | key1 = self.invheapdict[idx1]
65 | key2 = self.invheapdict[idx2]
66 |
67 | # Swap values in the heap
68 | tmp1_ = self.heap[idx1]
69 | self.heap[idx1] = self.heap[idx2]
70 | self.heap[idx2] = tmp1_
71 |
72 | # Swap indices in the [key]->index dict
73 | tmp2_ = self.heapdict[key1]
74 | self.heapdict[key1] = self.heapdict[key2]
75 | self.heapdict[key2] = tmp2_
76 |
77 | # Swap keys in the [index]->key dict
78 | tmp3_ = self.invheapdict[idx1]
79 | self.invheapdict[idx1] = self.invheapdict[idx2]
80 | self.invheapdict[idx2] = tmp3_
81 |
82 | def _percolate_up(self, curr_idx):
83 | while curr_idx != 0:
84 | parent_idx = int(math.floor((curr_idx - 1) / 2))
85 | if self.heap[parent_idx] > self.heap[curr_idx]:
86 | self._swap_index(curr_idx, parent_idx)
87 | curr_idx = parent_idx
88 | else:
89 | break
90 |
91 | def _percolate_down(self, curr_idx):
92 | while curr_idx < self.heap_length:
93 | child1 = 2 * curr_idx + 1
94 | child2 = child1 + 1
95 | if child1 >= self.heap_length:
96 | break
97 | minchild = child1
98 | maxchild = child2
99 | if child2 >= self.heap_length:
100 | maxchild = None
101 | if (maxchild is not None) and (self.heap[child1] >
102 | self.heap[child2]):
103 | minchild = child2
104 | maxchild = child1
105 |
106 | if self.heap[minchild] < self.heap[curr_idx]:
107 | self._swap_index(curr_idx, minchild)
108 | curr_idx = minchild
109 | continue
110 |
111 | if (maxchild is not None) and (self.heap[maxchild] <
112 | self.heap[curr_idx]):
113 | self._swap_index(curr_idx, maxchild)
114 | curr_idx = maxchild
115 | continue
116 | break
117 |
118 |
119 | def dijkstra_dist(maze, mechanism, goal):
120 | # Initialize distance to largest possible distance
121 | dist = (np.zeros((mechanism.num_orient, maze.shape[0], maze.shape[1])) +
122 | mechanism.num_orient * maze.shape[0] * maze.shape[1])
123 |
124 | pq = MinHeap()
125 | pq.insert(goal, 0)
126 | for orient in range(mechanism.num_orient):
127 | for y in range(maze.shape[0]):
128 | for x in range(maze.shape[1]):
129 | if (orient == goal[0]) and (y == goal[1]) and (x == goal[2]):
130 | continue
131 | pq.insert((orient, y, x),
132 | mechanism.num_orient * maze.shape[0] * maze.shape[1])
133 |
134 | while not pq.empty():
135 | # extract minimum distance position
136 | ((p_orient, p_y, p_x), val) = pq.extract()
137 | dist[p_orient][p_y][p_x] = val
138 |
139 | # Update neighbors
140 | for n in mechanism.invneighbors_func(maze, p_orient, p_y, p_x):
141 | if (n[1] < 0) or (n[1] >= maze.shape[0]):
142 | continue
143 | if (n[2] < 0) or (n[2] >= maze.shape[1]):
144 | continue
145 |
146 | if maze[n[1]][n[2]] == 0.:
147 | continue
148 |
149 | curr_to_n = val + 1
150 | if curr_to_n < dist[n[0]][n[1]][n[2]]:
151 | dist[n[0]][n[1]][n[2]] = curr_to_n
152 | pq.decrease(n, curr_to_n)
153 | return -dist # negative distance ~= value
154 |
155 |
156 | def dijkstra_policy(maze, mechanism, goal, policy):
157 | # Initialize distance to largest possible distance
158 | dist = np.zeros(
159 | (mechanism.num_orient, maze.shape[0],
160 | maze.shape[1])) + mechanism.num_orient * maze.shape[0] * maze.shape[1]
161 |
162 | pq = MinHeap()
163 | pq.insert(goal, 0)
164 | for orient in range(mechanism.num_orient):
165 | for y in range(maze.shape[0]):
166 | for x in range(maze.shape[1]):
167 | if (orient == goal[0]) and (y == goal[1]) and (x == goal[2]):
168 | continue
169 | pq.insert((orient, y, x),
170 | mechanism.num_orient * maze.shape[0] * maze.shape[1])
171 |
172 | while not pq.empty():
173 | # extract minimum distance position
174 | ((p_orient, p_y, p_x), val) = pq.extract()
175 | dist[p_orient][p_y][p_x] = val
176 |
177 | # Update neighboring predecessors
178 | predecessors = mechanism.invneighbors_func(maze, p_orient, p_y, p_x)
179 | for i in range(len(predecessors)):
180 | n = predecessors[i]
181 | if (n[1] < 0) or (n[1] >= maze.shape[0]):
182 | continue
183 | if (n[2] < 0) or (n[2] >= maze.shape[1]):
184 | continue
185 |
186 | if maze[n[1]][n[2]] == 0.:
187 | continue
188 |
189 |             # What are the successors of this predecessor state?
190 | succ_pred = mechanism.neighbors_func(maze, n[0], n[1], n[2])
191 |
192 | # Does following the policy on the predecessor state transition to
193 | # the current state?
194 | succ_pred_pol = succ_pred[policy[n[0]][n[1]][n[2]]]
195 | if (succ_pred_pol[0] == p_orient) and (
196 | succ_pred_pol[1] == p_y) and (succ_pred_pol[2] == p_x):
197 | # Update value
198 | curr_to_n = val + 1
199 | if curr_to_n < dist[n[0]][n[1]][n[2]]:
200 | dist[n[0]][n[1]][n[2]] = curr_to_n
201 | pq.decrease(n, curr_to_n)
202 | return -dist # negative distance ~= value
203 |
--------------------------------------------------------------------------------
/utils/experiment.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 |
5 | import torch
6 |
7 | from utils.maze import MazeDataset
8 | from utils.mechanism import DifferentialDrive, NorthEastWestSouth, Moore
9 |
10 |
11 | def parse_args():
12 | parser = argparse.ArgumentParser(
13 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
14 |
15 | # Environment parameters
16 | parser.add_argument(
17 | "--datafile", type=str, default="mazes.npz", help="Path to data file.")
18 | parser.add_argument(
19 | "--gamma", type=float, default=0.99,
20 | help="Discount factor. (keeps value within reasonable range).")
21 | parser.add_argument(
22 | "--mechanism", type=str, default="news",
23 | help="Maze transition mechanism. (news|diffdrive|moore)")
24 |
25 | # Log parameters
26 | parser.add_argument(
27 | "--save-directory", type=str, default="log/",
28 | help="Directory to save the graphs and models.")
29 |     parser.add_argument(
30 |         "--save-intermediate", action="store_true", default=False,
31 |         help="Whether to save the model every epoch.")
32 |     parser.add_argument(
33 |         "--use-percent-successful", action="store_true", default=False,
34 |         help="Use %% successful instead of %% optimal to decide best models.")
35 |
36 | # Optimization parameters
37 | parser.add_argument(
38 | "--optimizer", type=str, default="RMSprop",
39 | help="Which optimizer to use.")
40 | parser.add_argument(
41 | "--epochs", type=int, default=30, help="Number of epochs to train.")
42 | parser.add_argument(
43 | "--lr", type=float, default=0.001, help="Learning rate.")
44 | parser.add_argument(
45 | "--lr-decay", type=float, default=1.0,
46 | help="Learning rate decay when CE goes up.")
47 | parser.add_argument(
48 | "--eps", type=float, default=1e-6, help="Epsilon for denominator.")
49 | parser.add_argument(
50 | "--batch-size", type=int, default=32, help="Batch size.")
51 | parser.add_argument(
52 | "--clip-grad", type=float, default=40,
53 | help="Whether to clip the gradient norms. (0 for none)")
54 |
55 | # Model parameters
56 | parser.add_argument(
57 | "--model", type=str, default="models.VIN",
58 | help="Which model architecture to train.")
59 | parser.add_argument(
60 | "--load-file", type=str, default="",
61 | help="Model weights to load. (leave empty for none)")
62 |     parser.add_argument(
63 |         "--load-best", action="store_true", default=False,
64 |         help="Whether to load the best weights from the load-file.")
65 | parser.add_argument(
66 | "--k", type=int, default=10, help="Number of Value Iterations.")
67 | parser.add_argument(
68 | "--l-i", type=int, default=5,
69 | help="Number of channels in input layer.")
70 | parser.add_argument(
71 | "--l-h", type=int, default=150,
72 | help="Number of channels in first hidden layer.")
73 | parser.add_argument(
74 | "--l-q", type=int, default=600,
75 | help="Number of channels in q layer (~actions) in VI-module.")
76 | parser.add_argument(
77 | "--f", type=int, default=3, help="Kernel size")
78 |
79 | args = parser.parse_args()
80 |
81 | # Automatic switch of GPU mode if available
82 | args.use_gpu = torch.cuda.is_available()
83 |
84 | return args
85 |
86 |
87 | def get_mechanism(mechanism):
88 | if mechanism == "news":
89 | print("Using NEWS Drive")
90 | return NorthEastWestSouth()
91 | elif mechanism == "diffdrive":
92 | print("Using Differential Drive")
93 | return DifferentialDrive()
94 | elif mechanism == "moore":
95 | print("Using Moore Drive")
96 | return Moore()
97 | else:
98 | raise ValueError("Unsupported mechanism: %s" % mechanism)
99 |
100 |
101 | def create_dataloader(datafile, dataset_type, batch_size, mechanism, shuffle=False):
102 | """
103 | Creates a maze DataLoader.
104 | Args:
105 | datafile (string): Path to the dataset
106 | dataset_type (string): One of "train", "valid", or "test"
107 | batch_size (int): The batch size
108 | shuffle (bool): Whether to shuffle the data
109 | """
110 | dataset = MazeDataset(datafile, dataset_type)
111 | assert dataset.num_actions == mechanism.num_actions
112 | return torch.utils.data.DataLoader(
113 | dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0)
114 |
115 |
116 | def create_save_dir(save_directory):
117 | """
118 | Creates and returns path to the save directory.
119 | """
120 | try:
121 | os.makedirs(save_directory)
122 | except OSError:
123 | if not os.path.isdir(save_directory):
124 | raise
125 | return save_directory + "/planner"
126 |
127 |
128 | def print_row(width, items):
129 | """
130 | Prints the given items.
131 | Args:
132 | width (int): Character length for each column.
133 | items (list): List of items to print.
134 | """
135 | def fmt_item(x):
136 | if isinstance(x, np.ndarray):
137 | assert x.ndim == 0
138 | x = x.item()
139 | if isinstance(x, float):
140 | rep = "%.3f" % x
141 | else:
142 | rep = str(x)
143 | return rep.ljust(width)
144 |
145 | print(" | ".join(fmt_item(item) for item in items))
146 |
147 |
148 | def print_stats(info):
149 | """Prints performance statistics output from Runner."""
150 | print_row(10, ["Loss", "Err", "% Optimal", "% Success"])
151 | print_row(10, [
152 | info["avg_loss"], info["avg_error"],
153 | info["avg_optimal"], info["avg_success"]])
154 | return info
155 |
--------------------------------------------------------------------------------
/utils/maze.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function
3 | import sys
4 | import math
5 | import numpy as np
6 | import numpy.random as npr
7 |
8 | import torch
9 | import torch.utils.data as data
10 |
11 |
12 | class MazeDataset(data.Dataset):
13 |
14 | def __init__(self, filename, dataset_type):
15 | """
16 | Args:
17 | filename (str): Dataset filename (must be .npz format).
18 | dataset_type (str): One of "train", "valid", or "test".
19 | """
20 | assert filename.endswith("npz") # Must be .npz format
21 | self.filename = filename
22 | self.dataset_type = dataset_type # train, valid, test
23 |
24 | self.mazes, self.goal_maps, self.opt_policies = self._process(filename)
25 |
26 | self.num_actions = self.opt_policies.shape[1]
27 | self.num_orient = self.opt_policies.shape[2]
28 |
29 | def _process(self, filename):
30 | """
31 | Data format: list, [train data, test data]
32 | """
33 | with np.load(filename) as f:
34 | dataset2idx = {"train": 0, "valid": 3, "test": 6}
35 | idx = dataset2idx[self.dataset_type]
36 | mazes = f["arr_" + str(idx)]
37 | goal_maps = f["arr_" + str(idx + 1)]
38 | opt_policies = f["arr_" + str(idx + 2)]
39 |
40 | # Set proper datatypes
41 | mazes = mazes.astype(np.float32)
42 | goal_maps = goal_maps.astype(np.float32)
43 | opt_policies = opt_policies.astype(np.float32)
44 |
45 | # Print number of samples
46 | if self.dataset_type == "train":
47 | print("Number of Train Samples: {0}".format(mazes.shape[0]))
48 | elif self.dataset_type == "valid":
49 | print("Number of Validation Samples: {0}".format(mazes.shape[0]))
50 | else:
51 | print("Number of Test Samples: {0}".format(mazes.shape[0]))
52 | print("\tSize: {}x{}".format(mazes.shape[1], mazes.shape[2]))
53 | return mazes, goal_maps, opt_policies
54 |
55 | def __getitem__(self, index):
56 | maze = self.mazes[index]
57 | goal_map = self.goal_maps[index]
58 | opt_policy = self.opt_policies[index]
59 |
60 | maze = torch.from_numpy(maze)
61 | goal_map = torch.from_numpy(goal_map)
62 | opt_policy = torch.from_numpy(opt_policy)
63 |
64 | return maze, goal_map, opt_policy
65 |
66 | def __len__(self):
67 | return self.mazes.shape[0]
68 |
69 |
70 | def generate_maze(maze_size, decimation, start_pos=(1, 1)):
71 | maze = np.zeros((maze_size, maze_size))
72 |
73 | stack = [((start_pos[0], start_pos[1]), (0, 0))]
74 |
75 | def add_stack(next_pos, next_dir):
76 | if (next_pos[0] < 0) or (next_pos[0] >= maze_size):
77 | return
78 | if (next_pos[1] < 0) or (next_pos[1] >= maze_size):
79 | return
80 | if maze[next_pos[0]][next_pos[1]] == 0.:
81 | stack.append((next_pos, next_dir))
82 |
83 | while stack:
84 | pos, prev_dir = stack.pop()
85 | # Has this not been filled since being added?
86 | if maze[pos[0]][pos[1]] == 1.:
87 | continue
88 |
89 | # Fill in this point + break down wall from previous position
90 | maze[pos[0]][pos[1]] = 1.
91 | maze[pos[0] - prev_dir[0]][pos[1] - prev_dir[1]] = 1.
92 |
93 | choices = []
94 | choices.append(((pos[0] - 2, pos[1]), (-1, 0)))
95 | choices.append(((pos[0], pos[1] + 2), (0, 1)))
96 | choices.append(((pos[0], pos[1] - 2), (0, -1)))
97 | choices.append(((pos[0] + 2, pos[1]), (1, 0)))
98 |
99 | perm = np.random.permutation(np.array(range(4)))
100 | for i in range(4):
101 | choice = choices[perm[i]]
102 | add_stack(choice[0], choice[1])
103 |
104 | for y in range(1, maze_size - 1):
105 | for x in range(1, maze_size - 1):
106 | if np.random.uniform() < decimation:
107 | maze[y][x] = 1.
108 |
109 | return maze
110 |
111 |
112 | class RandomMaze:
113 |
114 | def __init__(self,
115 | mechanism,
116 | min_maze_size,
117 | max_maze_size,
118 | min_decimation,
119 | max_decimation,
120 | start_pos=(1, 1)):
121 | self.mechanism = mechanism
122 | self.min_maze_size = min_maze_size
123 | self.max_maze_size = max_maze_size
124 | self.min_decimation = min_decimation
125 | self.max_decimation = max_decimation
126 | self.start_pos = start_pos
127 |
128 | def _isGoalPos(self, pos):
129 | """Returns true if pos is equal to the goal position."""
130 | return pos[0] == self.goal_pos[0] and pos[1] == self.goal_pos[1]
131 |
132 | def _getState(self):
133 | """Returns the current state."""
134 | goal_map = np.zeros((self.mechanism.num_orient, self.maze_size,
135 | self.maze_size))
136 | goal_map[self.goal_orient, self.goal_pos[0], self.goal_pos[1]] = 1.
137 |
138 | player_map = np.zeros((self.mechanism.num_orient, self.maze_size,
139 | self.maze_size))
140 | player_map[self.player_orient, self.player_pos[0],
141 | self.player_pos[1]] = 1.
142 |
143 | # Check if agent has reached the goal state
144 | reward = 0
145 | terminal = False
146 | if (self.player_orient == self.goal_orient) and self._isGoalPos(
147 | self.player_pos):
148 | reward = 1
149 | terminal = True
150 |
151 | return np.copy(self.maze), player_map, goal_map, reward, terminal
152 |
153 | def reset(self):
154 | """Resets the maze."""
155 | if self.min_maze_size == self.max_maze_size:
156 | self.maze_size = self.min_maze_size
157 | else:
158 | self.maze_size = self.min_maze_size + 2 * npr.randint(
159 | math.floor((self.max_maze_size - self.min_maze_size) / 2))
160 | if self.min_decimation == self.max_decimation:
161 | self.decimation = self.min_decimation
162 | else:
163 | self.decimation = npr.uniform(self.min_decimation,
164 | self.max_decimation)
165 | self.maze = generate_maze(
166 | self.maze_size, self.decimation, start_pos=self.start_pos)
167 |
168 | # Randomly sample a goal location
169 | self.goal_pos = (npr.randint(1, self.maze_size - 1),
170 | npr.randint(1, self.maze_size - 1))
171 | while self._isGoalPos(self.start_pos):
172 | self.goal_pos = (npr.randint(1, self.maze_size - 1),
173 | npr.randint(1, self.maze_size - 1))
174 | self.goal_orient = npr.randint(self.mechanism.num_orient)
175 |
176 | # Free the maze at the goal location
177 | self.maze[self.goal_pos[0]][self.goal_pos[1]] = 1.
178 |
179 | # Player start position
180 | self.player_pos = (self.start_pos[0], self.start_pos[1])
181 |
182 | # Sample player orientation
183 | self.player_orient = npr.randint(self.mechanism.num_orient)
184 |
185 | screen, player_map, goal_map, _, _ = self._getState()
186 | return screen, player_map, goal_map
187 |
188 |     def step(self, action):
189 |         # Compute neighbors for the current state.
190 |         neighbors = self.mechanism.neighbors_func(
191 |             self.maze, self.player_orient, self.player_pos[0],
192 |             self.player_pos[1])
193 |         assert (action >= 0) and (action < len(neighbors))
194 |         # player_pos is a tuple, so rebuild it rather than assigning items.
195 |         self.player_orient, p_y, p_x = neighbors[action]
196 |         self.player_pos = (p_y, p_x)
197 |         return self._getState()
196 |
197 |
198 | def extract_policy(maze, mechanism, value):
199 | """Extracts the policy from the given values."""
200 | policy = np.zeros((mechanism.num_actions, value.shape[0], value.shape[1],
201 | value.shape[2]))
202 | for p_orient in range(value.shape[0]):
203 | for p_y in range(value.shape[1]):
204 | for p_x in range(value.shape[2]):
205 | # Find the neighbor w/ max value (assuming deterministic
206 | # transitions)
207 | max_val = -sys.maxsize
208 | max_acts = [0]
209 | neighbors = mechanism.neighbors_func(maze, p_orient, p_y, p_x)
210 | for i in range(len(neighbors)):
211 | n = neighbors[i]
212 | nval = value[n[0]][n[1]][n[2]]
213 | if nval > max_val:
214 | max_val = nval
215 | max_acts = [i]
216 | elif nval == max_val:
217 | max_acts.append(i)
218 |
219 | # Choose max actions if several w/ same value
220 | max_act = max_acts[np.random.randint(len(max_acts))]
221 | policy[max_act][p_orient][p_y][p_x] = 1.
222 | return policy
223 |
224 |
225 | def extract_goal(goal_map):
226 | """Returns the goal location."""
227 | for o in range(goal_map.shape[0]):
228 | for y in range(goal_map.shape[1]):
229 | for x in range(goal_map.shape[2]):
230 | if goal_map[o][y][x] == 1.:
231 | return (o, y, x)
232 | assert False
233 |
--------------------------------------------------------------------------------
/utils/mechanism.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function
3 | import abc
4 |
5 |
6 | class Mechanism(abc.ABC):
7 | """Base class for maze transition mechanisms."""
8 |
9 | def __init__(self, num_actions, num_orient):
10 | self.num_actions = num_actions
11 | self.num_orient = num_orient
12 |
13 | @abc.abstractmethod
14 | def neighbors_func(self, maze, p_orient, p_y, p_x):
15 | """Computes next states for each action."""
16 |
17 | @abc.abstractmethod
18 | def invneighbors_func(self, maze, p_orient, p_y, p_x):
19 | """Computes previous states for each action."""
20 |
21 | @abc.abstractmethod
22 | def print_policy(self, maze, goal, policy):
23 | """Prints the given policy."""
24 |
25 |
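Any new transition kernel only has to implement the three abstract methods above. A hypothetical, purely illustrative degenerate mechanism shows the minimal contract:

```
class Stay(Mechanism):
    """Single no-op action; every state is its own neighbor."""

    def __init__(self):
        super(Stay, self).__init__(num_actions=1, num_orient=1)

    def neighbors_func(self, maze, p_orient, p_y, p_x):
        return [(p_orient, p_y, p_x)]  # the only "move" is standing still

    def invneighbors_func(self, maze, p_orient, p_y, p_x):
        return [(p_orient, p_y, p_x)]  # the no-op is its own inverse

    def print_policy(self, maze, goal, policy):
        print("Every state maps to the single no-op action.")
```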
26 | class DifferentialDrive(Mechanism):
27 | """
28 | In Differential Drive, the agent can move forward along its current
29 | orientation, or turn left/right by 90 degrees.
30 | """
31 |
32 | def __init__(self):
33 | super(DifferentialDrive, self).__init__(num_actions=3, num_orient=4)
34 |         self.clockwise = [1, 3, 0, 2]   # N->E, E->S, W->N, S->W
35 |         self.cclockwise = [2, 0, 3, 1]  # N->W, E->N, W->S, S->E
36 |
37 | def _is_out_of_bounds(self, maze, p_y, p_x):
38 | return (p_x < 0 or p_x >= maze.shape[1] or p_y < 0 or
39 | p_y >= maze.shape[0])
40 |
41 | def _forward(self, maze, p_orient, p_y, p_x):
42 | assert p_orient < self.num_orient, p_orient
43 |
44 | next_p_y, next_p_x = p_y, p_x
45 | if p_orient == 0: # North
46 | next_p_y -= 1
47 | elif p_orient == 1: # East
48 | next_p_x += 1
49 | elif p_orient == 2: # West
50 | next_p_x -= 1
51 | else: # South
52 | next_p_y += 1
53 |
54 |         # If the next position is out of bounds or a wall, stay put.
55 |         if (self._is_out_of_bounds(maze, next_p_y, next_p_x) or
56 |                 maze[next_p_y][next_p_x] == 0.):
57 | return p_orient, p_y, p_x
58 |
59 | return p_orient, next_p_y, next_p_x
60 |
61 | def _turnright(self, p_orient, p_y, p_x):
62 | assert p_orient < self.num_orient, p_orient
63 | return self.clockwise[p_orient], p_y, p_x
64 |
65 | def _turnleft(self, p_orient, p_y, p_x):
66 | assert p_orient < self.num_orient, p_orient
67 | return self.cclockwise[p_orient], p_y, p_x
68 |
69 | def _backward(self, maze, p_orient, p_y, p_x):
70 | assert p_orient < self.num_orient, p_orient
71 |
72 | next_p_y, next_p_x = p_y, p_x
73 | if p_orient == 0: # North
74 | next_p_y += 1
75 | elif p_orient == 1: # East
76 | next_p_x -= 1
77 | elif p_orient == 2: # West
78 | next_p_x += 1
79 | else: # South
80 | next_p_y -= 1
81 |
82 |         # If the next position is out of bounds or a wall, stay put.
83 |         if (self._is_out_of_bounds(maze, next_p_y, next_p_x) or
84 |                 maze[next_p_y][next_p_x] == 0.):
85 | return p_orient, p_y, p_x
86 |
87 | return p_orient, next_p_y, next_p_x
88 |
89 | def neighbors_func(self, maze, p_orient, p_y, p_x):
90 | return [
91 | self._forward(maze, p_orient, p_y, p_x),
92 | self._turnright(p_orient, p_y, p_x),
93 | self._turnleft(p_orient, p_y, p_x),
94 | ]
95 |
96 | def invneighbors_func(self, maze, p_orient, p_y, p_x):
97 | return [
98 | self._backward(maze, p_orient, p_y, p_x),
99 | self._turnleft(p_orient, p_y, p_x),
100 | self._turnright(p_orient, p_y, p_x),
101 | ]
102 |
103 | def print_policy(self, maze, goal, policy):
104 | orient2str = ["↑", "→", "←", "↓"]
105 | action2str = ["F", "R", "L"]
106 | for o in range(self.num_orient):
107 | print(orient2str[o])
108 | for y in range(policy.shape[1]):
109 | for x in range(policy.shape[2]):
110 | if (o, y, x) == goal:
111 | print("!", end="")
112 | elif maze[y][x] == 0.:
113 | print(u"\u2588", end="")
114 | else:
115 | a = policy[o][y][x]
116 | print(action2str[a], end="")
117 | print("")
118 |
119 |
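Two quick sanity checks implied by the class above: four right turns are the identity on orientation, and moving forward while facing North decrements the row index (the underscored helpers are called directly here only for illustration):

```
import numpy as np

maze = np.ones((3, 3))
dd = DifferentialDrive()

state = (0, 1, 1)  # facing North at the center cell
for _ in range(4):
    state = dd._turnright(*state)
assert state == (0, 1, 1)  # N -> E -> S -> W -> N

assert dd._forward(maze, 0, 1, 1) == (0, 0, 1)
```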
120 | class NorthEastWestSouth(Mechanism):
121 | """
122 | In NEWS, the agent can move North, East, West, or South.
123 | """
124 |
125 | def __init__(self):
126 | super(NorthEastWestSouth, self).__init__(num_actions=4, num_orient=1)
127 |
128 | def _north(self, maze, p_orient, p_y, p_x):
129 | if (p_y > 0) and (maze[p_y - 1][p_x] != 0.):
130 | return p_orient, p_y - 1, p_x
131 | return p_orient, p_y, p_x
132 |
133 | def _east(self, maze, p_orient, p_y, p_x):
134 | if (p_x < (maze.shape[1] - 1)) and (maze[p_y][p_x + 1] != 0.):
135 | return p_orient, p_y, p_x + 1
136 | return p_orient, p_y, p_x
137 |
138 | def _west(self, maze, p_orient, p_y, p_x):
139 | if (p_x > 0) and (maze[p_y][p_x - 1] != 0.):
140 | return p_orient, p_y, p_x - 1
141 | return p_orient, p_y, p_x
142 |
143 | def _south(self, maze, p_orient, p_y, p_x):
144 | if (p_y < (maze.shape[0] - 1)) and (maze[p_y + 1][p_x] != 0.):
145 | return p_orient, p_y + 1, p_x
146 | return p_orient, p_y, p_x
147 |
148 | def neighbors_func(self, maze, p_orient, p_y, p_x):
149 | return [
150 | self._north(maze, p_orient, p_y, p_x),
151 | self._east(maze, p_orient, p_y, p_x),
152 | self._west(maze, p_orient, p_y, p_x),
153 | self._south(maze, p_orient, p_y, p_x),
154 | ]
155 |
156 | def invneighbors_func(self, maze, p_orient, p_y, p_x):
157 | return [
158 | self._south(maze, p_orient, p_y, p_x),
159 | self._west(maze, p_orient, p_y, p_x),
160 | self._east(maze, p_orient, p_y, p_x),
161 | self._north(maze, p_orient, p_y, p_x),
162 | ]
163 |
164 | def print_policy(self, maze, goal, policy):
165 | action2str = ["↑", "→", "←", "↓"]
166 | for o in range(self.num_orient):
167 | for y in range(policy.shape[1]):
168 | for x in range(policy.shape[2]):
169 | if (o, y, x) == goal:
170 | print("!", end="")
171 | elif maze[y][x] == 0.:
172 | print(u"\u2588", end="")
173 | else:
174 | a = policy[o][y][x]
175 | print(action2str[a], end="")
176 | print("")
177 | print("")
178 |
179 |
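A sketch of NEWS transitions on a 3x3 maze with a single wall: the blocked action is a self-loop, the rest move as expected:

```
import numpy as np

maze = np.ones((3, 3))
maze[0][1] = 0.  # wall directly north of the center cell
news = NorthEastWestSouth()

n, e, w, s = news.neighbors_func(maze, 0, 1, 1)
assert n == (0, 1, 1)  # blocked: the agent stays put
assert e == (0, 1, 2) and w == (0, 1, 0) and s == (0, 2, 1)
```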
180 | class Moore(Mechanism):
181 | """
182 | In Moore, the agent can move to any of the eight cells in its Moore
183 | neighborhood.
184 | """
185 |
186 | def __init__(self):
187 | super(Moore, self).__init__(num_actions=8, num_orient=1)
188 |
189 | def _north(self, maze, p_orient, p_y, p_x):
190 | if (p_y > 0) and (maze[p_y - 1][p_x] != 0.):
191 | return p_orient, p_y - 1, p_x
192 | return p_orient, p_y, p_x
193 |
194 | def _east(self, maze, p_orient, p_y, p_x):
195 | if (p_x < (maze.shape[1] - 1)) and (maze[p_y][p_x + 1] != 0.):
196 | return p_orient, p_y, p_x + 1
197 | return p_orient, p_y, p_x
198 |
199 | def _west(self, maze, p_orient, p_y, p_x):
200 | if (p_x > 0) and (maze[p_y][p_x - 1] != 0.):
201 | return p_orient, p_y, p_x - 1
202 | return p_orient, p_y, p_x
203 |
204 | def _south(self, maze, p_orient, p_y, p_x):
205 | if (p_y < (maze.shape[0] - 1)) and (maze[p_y + 1][p_x] != 0.):
206 | return p_orient, p_y + 1, p_x
207 | return p_orient, p_y, p_x
208 |
209 | def _northeast(self, maze, p_orient, p_y, p_x):
210 | if (p_y > 0) and (p_x < (maze.shape[1] - 1)) and (maze[p_y - 1][p_x + 1]
211 | != 0.):
212 | return p_orient, p_y - 1, p_x + 1
213 | return p_orient, p_y, p_x
214 |
215 | def _northwest(self, maze, p_orient, p_y, p_x):
216 | if (p_y > 0) and (p_x > 0) and (maze[p_y - 1][p_x - 1] != 0.):
217 | return p_orient, p_y - 1, p_x - 1
218 | return p_orient, p_y, p_x
219 |
220 | def _southeast(self, maze, p_orient, p_y, p_x):
221 | if (p_y < (maze.shape[0] - 1)) and (p_x < (maze.shape[1] - 1)) and (
222 | maze[p_y + 1][p_x + 1] != 0.):
223 | return p_orient, p_y + 1, p_x + 1
224 | return p_orient, p_y, p_x
225 |
226 | def _southwest(self, maze, p_orient, p_y, p_x):
227 | if (p_y < (maze.shape[0] - 1)) and (p_x > 0) and (maze[p_y + 1][p_x - 1]
228 | != 0.):
229 | return p_orient, p_y + 1, p_x - 1
230 | return p_orient, p_y, p_x
231 |
232 | def neighbors_func(self, maze, p_orient, p_y, p_x):
233 | return [
234 | self._north(maze, p_orient, p_y, p_x),
235 | self._east(maze, p_orient, p_y, p_x),
236 | self._west(maze, p_orient, p_y, p_x),
237 | self._south(maze, p_orient, p_y, p_x),
238 | self._northeast(maze, p_orient, p_y, p_x),
239 | self._northwest(maze, p_orient, p_y, p_x),
240 | self._southeast(maze, p_orient, p_y, p_x),
241 | self._southwest(maze, p_orient, p_y, p_x),
242 | ]
243 |
244 | def invneighbors_func(self, maze, p_orient, p_y, p_x):
245 | return [
246 | self._south(maze, p_orient, p_y, p_x),
247 | self._west(maze, p_orient, p_y, p_x),
248 | self._east(maze, p_orient, p_y, p_x),
249 | self._north(maze, p_orient, p_y, p_x),
250 | self._southwest(maze, p_orient, p_y, p_x),
251 | self._southeast(maze, p_orient, p_y, p_x),
252 | self._northwest(maze, p_orient, p_y, p_x),
253 | self._northeast(maze, p_orient, p_y, p_x),
254 | ]
255 |
256 | def print_policy(self, maze, goal, policy):
257 | action2str = ["↑", "→", "←", "↓", "↗", "↖", "↘", "↙"]
258 | for o in range(self.num_orient):
259 | for y in range(policy.shape[1]):
260 | for x in range(policy.shape[2]):
261 | if (o, y, x) == goal:
262 | print("!", end="")
263 | elif maze[y][x] == 0.:
264 | print(u"\u2588", end="")
265 | else:
266 | a = policy[o][y][x]
267 | print(action2str[a], end="")
268 | print("")
269 | print("")
270 |
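Note that Moore permits "corner cutting": a diagonal move only checks its target cell, so it succeeds even when both adjacent cardinal cells are walls, as this sketch shows:

```
import numpy as np

maze = np.ones((3, 3))
maze[0][1] = 0.  # wall north of center
maze[1][2] = 0.  # wall east of center
moore = Moore()

# Only the target cell (0, 2) is checked, not the cells squeezed between.
assert moore._northeast(maze, 0, 1, 1) == (0, 0, 2)
```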
--------------------------------------------------------------------------------
/utils/runner.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import numpy as np
3 |
4 | import torch
5 | from torch.autograd import Variable
6 | import torch.nn as nn
7 | import torch.optim as optim
8 |
9 | from utils.dijkstra import dijkstra_dist, dijkstra_policy
10 | from utils.maze import extract_goal
11 |
12 |
13 | def get_optimizer(args, parameters):
14 | if args.optimizer == "RMSprop":
15 | return optim.RMSprop(parameters, lr=args.lr, eps=args.eps)
16 | elif args.optimizer == "Adam":
17 | return optim.Adam(parameters, lr=args.lr, eps=args.eps)
18 | elif args.optimizer == "SGD":
19 | return optim.SGD(parameters, lr=args.lr, momentum=0.95)
20 | else:
21 | raise ValueError("Unsupported optimizer: %s" % args.optimizer)
22 |
23 |
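A minimal sketch of calling `get_optimizer`; the `Namespace` fields mirror the command-line arguments this helper reads:

```
from argparse import Namespace
import torch.nn as nn

net = nn.Linear(4, 2)
args = Namespace(optimizer="Adam", lr=1e-3, eps=1e-8)
opt = get_optimizer(args, net.parameters())
```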
24 | class Runner():
25 | """
26 | The Runner class runs a planner model on a given dataset and records
27 | statistics such as loss, prediction error, % Optimal, and % Success.
28 | """
29 |
30 | def __init__(self, args, mechanism):
31 | """
32 | Args:
33 |             args (Namespace): Arguments
34 |             mechanism (utils.mechanism.Mechanism): Environment transition
35 |                 kernel. The Planner model itself is built from args.model.
36 | """
37 | self.use_gpu = args.use_gpu
38 | self.clip_grad = args.clip_grad
39 | self.lr_decay = args.lr_decay
40 | self.use_percent_successful = args.use_percent_successful
41 |
42 | self.mechanism = mechanism
43 | self.criterion = nn.CrossEntropyLoss()
44 |
45 | # Instantiate the model
46 | model_module = importlib.import_module(args.model)
47 | self.model = model_module.Planner(
48 | mechanism.num_orient, mechanism.num_actions, args)
49 | self.best_model = model_module.Planner(
50 | mechanism.num_orient, mechanism.num_actions, args)
51 |
52 | # Load model from file if provided
53 | if args.load_file != "":
54 | saved_model = torch.load(args.load_file)
55 | if args.load_best:
56 | self.model.load_state_dict(saved_model["best_model"])
57 | else:
58 | self.model.load_state_dict(saved_model["model"])
59 | self.best_model.load_state_dict(saved_model["best_model"])
60 | else:
61 | self.best_model.load_state_dict(self.model.state_dict())
62 |
63 | # Track the best performing model so far
64 | self.best_metric = 0.
65 |
66 | # Use GPU if available
67 | if self.use_gpu:
68 | self.model = self.model.cuda()
69 | self.best_model = self.best_model.cuda()
70 |
71 | self.optimizer = get_optimizer(args, self.model.parameters())
72 |
73 | def _compute_stats(self, batch_size, map_design, goal_map,
74 | outputs, predictions, labels,
75 | loss, opt_policy, sample=False):
76 | # Select argmax policy
77 | _, pred_pol = torch.max(outputs, dim=1, keepdim=True)
78 |
79 | # Convert to numpy arrays
80 | map_design = map_design.cpu().data.numpy()
81 | goal_map = goal_map.cpu().data.numpy()
82 | outputs = outputs.cpu().data.numpy()
83 | predictions = predictions.cpu().data.numpy()
84 | labels = labels.cpu().data.numpy()
85 | opt_policy = opt_policy.cpu().data.numpy()
86 | pred_pol = pred_pol.cpu().data.numpy()
87 |
88 | max_pred = (predictions == predictions.max(axis=1)[:, None]).astype(
89 | np.float32)
90 | match_action = np.sum((max_pred != opt_policy).astype(np.float32), axis=1)
91 | match_action = (match_action == 0).astype(np.float32)
92 | match_action = np.reshape(match_action, (batch_size, -1))
93 | batch_error = 1 - np.mean(match_action)
94 |
95 | def calc_optimal_and_success(i):
96 | # Get current sample
97 | md = map_design[i][0]
98 | gm = goal_map[i]
99 | op = opt_policy[i]
100 | pp = pred_pol[i][0]
101 | ll = labels[i][0]
102 |
103 | # Extract the goal in 2D coordinates
104 | goal = extract_goal(gm)
105 |
106 | # Check how different the predicted policy is from the optimal one
107 | # in terms of path lengths
108 | pred_dist = dijkstra_policy(md, self.mechanism, goal, pp)
109 | opt_dist = dijkstra_dist(md, self.mechanism, goal)
110 | diff_dist = pred_dist - opt_dist
111 |
112 |             wall_dist = np.min(pred_dist)  # sentinel distance marking walls/unreachable states
113 |
114 | for o in range(self.mechanism.num_orient):
115 | # Refill the walls in the difference with the impossible distance
116 | diff_dist[o] += (1 - md) * wall_dist
117 |
118 | # Mask out the walls in the prediction distances
119 | pred_dist[o] = pred_dist[o] - np.multiply(1 - md, pred_dist[o])
120 |
121 |             num_open = md.sum() * self.mechanism.num_orient  # number of open (non-wall) states
122 | return (diff_dist == 0).sum() / num_open, 1. - (
123 | pred_dist == wall_dist).sum() / num_open
124 |
125 | if sample:
126 | percent_optimal, percent_successful = calc_optimal_and_success(
127 | np.random.randint(batch_size))
128 | else:
129 | percent_optimal, percent_successful = 0, 0
130 | for i in range(batch_size):
131 | po, ps = calc_optimal_and_success(i)
132 | percent_optimal += po
133 | percent_successful += ps
134 | percent_optimal = percent_optimal / float(batch_size)
135 | percent_successful = percent_successful / float(batch_size)
136 |
137 | return loss.data.item(), batch_error, percent_optimal, percent_successful
138 |
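A toy numpy sketch of the batch-error computation in `_compute_stats` above: a state counts as matching only if its full set of argmax actions equals the optimal one-hot policy, so ties in the predictions count against it:

```
import numpy as np

# Shape (batch=1, actions=3, states=2): state 0 matches, state 1 is tied.
predictions = np.array([[[2., 5.], [9., 5.], [1., 1.]]])
opt_policy = np.array([[[0., 1.], [1., 0.], [0., 0.]]])

max_pred = (predictions == predictions.max(axis=1)[:, None]).astype(np.float32)
match = (np.sum(max_pred != opt_policy, axis=1) == 0).astype(np.float32)
assert (1 - np.mean(match)) == 0.5  # one of the two states mismatched
```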
139 | def _run(self, model, dataloader, train=False, batch_size=-1,
140 | store_best=False):
141 | """
142 | Runs the model on the given data.
143 | Args:
144 | model (torch.nn.Module): The Planner model
145 |             dataloader (torch.utils.data.DataLoader): Dataset loader
146 | train (bool): Whether to train the model
147 | batch_size (int): Only used if train=True
148 | store_best (bool): Whether to store the best model
149 | Returns:
150 | info (dict): Performance statistics, including
151 | info["avg_loss"] (float): Average loss
152 | info["avg_error"] (float): Average error
153 | info["avg_optimal"] (float): Average % Optimal
154 | info["avg_success"] (float): Average % Success
155 | info["weight_norm"] (float): Model weight norm, stored if train=True
156 |                 info["grad_norm"] (float): Gradient norm, stored if train=True
157 |                 info["is_best"] (bool): Whether this model is the best so far, stored if store_best=True
158 | """
159 | info = {}
160 | for key in ["avg_loss", "avg_error", "avg_optimal", "avg_success"]:
161 | info[key] = 0.0
162 | num_batches = 0
163 |
164 | for i, data in enumerate(dataloader):
165 | # Get input batch.
166 | map_design, goal_map, opt_policy = data
167 |
168 | if train:
169 | if map_design.size()[0] != batch_size:
170 |                     continue  # Skip incomplete final batches
171 | self.optimizer.zero_grad() # Zero the parameter gradients
172 | else:
173 | batch_size = map_design.size()[0]
174 |
175 | # Send tensor to GPU if available
176 | if self.use_gpu:
177 | map_design = map_design.cuda()
178 | goal_map = goal_map.cuda()
179 | opt_policy = opt_policy.cuda()
180 | map_design = Variable(map_design)
181 | goal_map = Variable(goal_map)
182 | opt_policy = Variable(opt_policy)
183 |
184 | # Reshape batch-wise if necessary
185 | if map_design.dim() == 3:
186 | map_design = map_design.unsqueeze(1)
187 |
188 | # Forward pass
189 | outputs, predictions, _, _ = model(map_design, goal_map)
190 |
191 | # Loss
192 | flat_outputs = outputs.transpose(1, 4).contiguous()
193 | flat_outputs = flat_outputs.view(
194 | -1, flat_outputs.size()[-1]).contiguous()
195 | _, labels = opt_policy.max(1, keepdim=True)
196 | flat_labels = labels.transpose(1, 4).contiguous()
197 | flat_labels = flat_labels.view(-1).contiguous()
198 | loss = self.criterion(flat_outputs, flat_labels)
199 |
203 | # Update parameters
204 | if train:
205 | # Backward pass
206 | loss.backward()
207 |
208 | # Clip the gradient norm
209 | if self.clip_grad:
210 | torch.nn.utils.clip_grad_norm_(model.parameters(),
211 | self.clip_grad)
212 |
213 | # Update parameters
214 | self.optimizer.step()
215 |
216 | # Compute loss and error
217 | loss_batch, batch_error, p_opt, p_suc = self._compute_stats(
218 | batch_size, map_design, goal_map,
219 | outputs, predictions, labels,
220 | loss, opt_policy, sample=train)
221 | info["avg_loss"] += loss_batch
222 | info["avg_error"] += batch_error
223 | info["avg_optimal"] += p_opt
224 | info["avg_success"] += p_suc
225 | num_batches += 1
226 |
227 | info["avg_loss"] = info["avg_loss"] / num_batches
228 | info["avg_error"] = info["avg_error"] / num_batches
229 | info["avg_optimal"] = info["avg_optimal"] / num_batches
230 | info["avg_success"] = info["avg_success"] / num_batches
231 |
232 | if train:
233 | # Calculate weight norm
234 | weight_norm = 0
235 | grad_norm = 0
236 | for p in model.parameters():
237 | weight_norm += torch.norm(p)**2
238 | if p.grad is not None:
239 | grad_norm += torch.norm(p.grad)**2
240 | info["weight_norm"] = float(np.sqrt(weight_norm.cpu().data.numpy().item()))
241 | info["grad_norm"] = float(np.sqrt(grad_norm.cpu().data.numpy().item()))
242 |
243 | if store_best:
244 | # Was the validation accuracy greater than the best one?
245 | metric = (info["avg_success"] if self.use_percent_successful else
246 | info["avg_optimal"])
247 | if metric > self.best_metric:
248 | self.best_metric = metric
249 | self.best_model.load_state_dict(model.state_dict())
250 | info["is_best"] = True
251 | else:
252 | for param_group in self.optimizer.param_groups:
253 | param_group["lr"] = param_group["lr"] * self.lr_decay
254 | info["is_best"] = False
255 | return info
256 |
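The loss block in `_run` flattens 5D logits of shape (batch, actions, orient, H, W) into an (N, actions) matrix so that `nn.CrossEntropyLoss` can consume them; a standalone shape sketch with dummy tensors:

```
import torch
import torch.nn as nn

B, A, O, H, W = 2, 4, 1, 8, 8
outputs = torch.randn(B, A, O, H, W)
opt_policy = torch.zeros(B, A, O, H, W)
opt_policy[:, 0] = 1.  # pretend action 0 is optimal everywhere

flat_outputs = outputs.transpose(1, 4).contiguous().view(-1, A)  # (B*W*O*H, A)
_, labels = opt_policy.max(1, keepdim=True)
flat_labels = labels.transpose(1, 4).contiguous().view(-1)       # (B*W*O*H,)
loss = nn.CrossEntropyLoss()(flat_outputs, flat_labels)          # scalar
```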
257 | def train(self, dataloader, batch_size):
258 | """
259 | Trains the model on the given training dataset.
260 | """
261 |         self.model.train()  # restore train mode in case test() left eval mode
262 |         return self._run(self.model, dataloader, train=True,
263 |                          batch_size=batch_size)
263 |
264 | def validate(self, dataloader):
265 | """
266 | Evaluates the model on the given validation dataset. Stores the
267 | current model if it achieves the best validation performance.
268 | """
269 | return self._run(self.model, dataloader, store_best=True)
270 |
271 | def test(self, dataloader, use_best=False):
272 | """
273 | Tests the model on the given dataset.
274 | """
275 | if use_best:
276 | model = self.best_model
277 | else:
278 | model = self.model
279 | model.eval()
280 |         return self._run(model, dataloader, store_best=False)  # never select models on test data
281 |
--------------------------------------------------------------------------------
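Finally, a hedged end-to-end sketch of driving `Runner`; the `Namespace` fields shown are the ones read in this file, while any model-specific hyperparameters expected by the `Planner` in `models.VIN` (e.g., the number of planning iterations) are omitted here and would also be required in practice:

```
from argparse import Namespace
from utils.mechanism import NorthEastWestSouth
from utils.runner import Runner

args = Namespace(model="models.VIN", optimizer="Adam", lr=1e-3, eps=1e-8,
                 use_gpu=False, clip_grad=0.0, lr_decay=1.0,
                 use_percent_successful=False, load_file="", load_best=False)
runner = Runner(args, NorthEastWestSouth())

# for epoch in range(num_epochs):
#     train_info = runner.train(train_loader, batch_size=32)
#     val_info = runner.validate(val_loader)    # tracks the best model
# test_info = runner.test(test_loader, use_best=True)
```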