├── .gitignore ├── README.md ├── download_datasets.sh ├── eval.py ├── generate_dataset.py ├── models ├── GPPN.py ├── VIN.py └── __init__.py ├── requirements.txt ├── train.py └── utils ├── __init__.py ├── dijkstra.py ├── experiment.py ├── maze.py ├── mechanism.py └── runner.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Data/log directories 2 | ./mazes 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gated Path Planning Networks (ICML 2018) 2 | 3 | This is the official codebase for the following paper: 4 | 5 | Lisa Lee\*, Emilio Parisotto\*, Devendra Singh Chaplot, Eric Xing, Ruslan Salakhutdinov. **Gated Path Planning Networks**. ICML 2018. https://arxiv.org/abs/1806.06408 6 | 7 | ## Getting Started 8 | 9 | You can clone this repo by running: 10 | ``` 11 | git clone https://github.com/lileee/gated-path-planning-networks.git 12 | cd gated-path-planning-networks/ 13 | ``` 14 | 15 | All subsequent commands in this README should be run from the top-level directory of this repository (i.e., `/path/to/gated-path-planning-networks/`). 16 | 17 | ### I. Docker container 18 | 19 | We provide two Docker containers, with and without GPU support. These containers have Python 3.6.5, PyTorch 0.4.0, and other dependencies installed. They do not contain this codebase or the maze datasets used in our experiments. 
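Once a container is running (see the commands below), you can sanity-check the environment with a short snippet; this assumes nothing beyond the `torch` package that the containers ship with:

```python
import torch

print(torch.__version__)          # expect 0.4.0 in the provided containers
print(torch.cuda.is_available())  # True in the GPU container, False in the CPU-only one
```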
20 | 21 | To load the container with GPU support: 22 | ``` 23 | # PyTorch with GPU support 24 | nvidia-docker pull lileee/ubuntu-16.04-pytorch-0.4.0-gpu:v1 25 | nvidia-docker run -v $(pwd):/home --rm -ti lileee/ubuntu-16.04-pytorch-0.4.0-gpu:v1 26 | ``` 27 | 28 | To load the container without GPU support: 29 | ``` 30 | # PyTorch (CPU-only) 31 | nvidia-docker pull lileee/ubuntu-16.04-pytorch-0.4.0-cpu:v1 32 | nvidia-docker run -v $(pwd):/home --rm -ti lileee/ubuntu-16.04-pytorch-0.4.0-cpu:v1 33 | ``` 34 | 35 | Here is a speed comparison between the Docker containers for training VIN on a `9x9` maze with 5k/1k/1k train-val-test split: 36 | 37 | | PyTorch 0.4.0 | time per epoch 38 | |:-----------------:|:---------:| 39 | |with GPU support | 8.5 sec 40 | |without GPU support| 32.3 sec 41 | 42 | ### II. Generate a 2D maze dataset 43 | 44 | Generate a dataset by running: 45 | 46 | ``` 47 | python generate_dataset.py --output-path mazes.npz --mechanism news --maze-size 9 --train-size 5000 --valid-size 1000 --test-size 1000 48 | ``` 49 | This will create a datafile `mazes.npz` containing a dataset of `9x9` mazes using the NEWS maze transition mechanism with 5k/1k/1k train-val-test split. 50 | 51 | Note: 52 | * The same maze transition mechanism that was used to generate the dataset must be used for `train.py` and `eval.py`. Here, we used `--mechanism news` to generate the dataset. Other options are `--mechanism moore` and `--mechanism diffdrive`. 53 | 54 | ### III. Train a model 55 | 56 | You can train a VIN with iteration count K=15 and kernel size F=5 on the datafile `mazes.npz` by running: 57 | ``` 58 | python train.py --datafile mazes.npz --mechanism news --model models.VIN --k 15 --f 5 --save-directory log/vin-k15-f5 59 | ``` 60 | This will save outputs to the subdirectory `vin-k15-f5/`, including the trained models and learning plots. 61 | 62 | Similarly, you can train a GPPN by running: 63 | ``` 64 | python train.py --datafile mazes.npz --mechanism news --model models.GPPN --k 15 --f 5 --save-directory log/gppn-k15-f5 65 | ``` 66 | 67 | Notes: 68 | * `--mechanism` must be the same as the one used to generate `mazes.npz` (which is `news` in this example). 69 | * `--f` must be an odd integer. 70 | 71 | ### IV. Evaluate a trained model 72 | 73 | Once you have a trained VIN model, you can evaluate it on a dataset by running: 74 | 75 | ``` 76 | python eval.py --datafile mazes.npz --mechanism news --model models.VIN --k 15 --f 5 --load-file log/vin-k15-f5/planner.final.pth 77 | ``` 78 | Similarly for GPPN: 79 | 80 | ``` 81 | python eval.py --datafile mazes.npz --mechanism news --model models.GPPN --k 15 --f 5 --load-file log/gppn-k15-f5/planner.final.pth 82 | ``` 83 | 84 | Notes: 85 | * `--mechanism` must be the same as the one used to generate `mazes.npz` (which is `news` in this example). 86 | * `--f` must be the same as the one used to train the model. 87 | 88 | 89 | ## Replicating experiments from our paper 90 | 91 | ### I. 
Download 2D maze datasets 92 | To replicate experiments from our ICML 2018 paper, first download the datasets by running: 93 | ``` 94 | ./download_datasets.sh 95 | ``` 96 | This will create a subdirectory `mazes/` containing the following 2D maze datasets used in our experiments: 97 | 98 | | datafile | maze size | mechanism | train size | val size | test size | 99 | | :---------------------: |:---------:| :----------:|:----------:|:--------:|:---------:| 100 | | `m15_news_10k.npz` | `15x15` | `news` | 10000 | 2000 | 2000 101 | | `m15_news_25k.npz` | `15x15` | `news` | 25000 | 5000 | 5000 102 | | `m15_news_100k.npz` | `15x15` | `news` | 100000 | 10000 | 10000 103 | | `m15_moore_10k.npz` | `15x15` | `moore` | 10000 | 2000 | 2000 104 | | `m15_moore_25k.npz` | `15x15` | `moore` | 25000 | 5000 | 5000 105 | | `m15_moore_100k.npz` | `15x15` | `moore` | 100000 | 10000 | 10000 106 | | `m15_diffdrive_10k.npz` | `15x15` | `diffdrive` | 10000 | 2000 | 2000 107 | | `m15_diffdrive_25k.npz` | `15x15` | `diffdrive` | 25000 | 5000 | 5000 108 | | `m15_diffdrive_100k.npz`| `15x15` | `diffdrive` | 100000 | 10000 | 10000 109 | | `m28_news_25k.npz` | `28x28` | `news` | 25000 | 5000 | 5000 110 | | `m28_moore_25k.npz` | `28x28` | `moore` | 25000 | 5000 | 5000 111 | | `m28_diffdrive_25k.npz` | `28x28` | `diffdrive` | 25000 | 5000 | 5000 112 | 113 | 114 | ### II. Train VIN and GPPN 115 | Then you can train VIN with the best (K, F) settings for each dataset from our paper by running: 116 | ``` 117 | python train.py --datafile mazes/m15_news_10k.npz --mechanism news --model models.VIN --k 30 --f 5 118 | python train.py --datafile mazes/m15_news_25k.npz --mechanism news --model models.VIN --k 20 --f 5 119 | python train.py --datafile mazes/m15_news_100k.npz --mechanism news --model models.VIN --k 30 --f 3 120 | 121 | python train.py --datafile mazes/m15_moore_10k.npz --mechanism moore --model models.VIN --k 30 --f 11 122 | python train.py --datafile mazes/m15_moore_25k.npz --mechanism moore --model models.VIN --k 30 --f 5 123 | python train.py --datafile mazes/m15_moore_100k.npz --mechanism moore --model models.VIN --k 30 --f 5 124 | 125 | python train.py --datafile mazes/m15_diffdrive_10k.npz --mechanism diffdrive --model models.VIN --k 30 --f 3 126 | python train.py --datafile mazes/m15_diffdrive_25k.npz --mechanism diffdrive --model models.VIN --k 30 --f 3 127 | python train.py --datafile mazes/m15_diffdrive_100k.npz --mechanism diffdrive --model models.VIN --k 30 --f 3 128 | 129 | python train.py --datafile mazes/m28_news_25k.npz --mechanism news --model models.VIN --k 56 --f 3 130 | python train.py --datafile mazes/m28_moore_25k.npz --mechanism moore --model models.VIN --k 56 --f 5 131 | python train.py --datafile mazes/m28_diffdrive_25k.npz --mechanism diffdrive --model models.VIN --k 56 --f 3 132 | ``` 133 | 134 | Similarly, you can train GPPN with the best (K, F) settings for each dataset from our paper by running: 135 | ``` 136 | python train.py --datafile mazes/m15_news_10k.npz --mechanism news --model models.GPPN --k 20 --f 9 137 | python train.py --datafile mazes/m15_news_25k.npz --mechanism news --model models.GPPN --k 20 --f 11 138 | python train.py --datafile mazes/m15_news_100k.npz --mechanism news --model models.GPPN --k 30 --f 11 139 | 140 | python train.py --datafile mazes/m15_moore_10k.npz --mechanism moore --model models.GPPN --k 30 --f 7 141 | python train.py --datafile mazes/m15_moore_25k.npz --mechanism moore --model models.GPPN --k 30 --f 9 142 | python train.py --datafile 
mazes/m15_moore_100k.npz --mechanism moore --model models.GPPN --k 30 --f 7 143 | 144 | python train.py --datafile mazes/m15_diffdrive_10k.npz --mechanism diffdrive --model models.GPPN --k 30 --f 11 145 | python train.py --datafile mazes/m15_diffdrive_25k.npz --mechanism diffdrive --model models.GPPN --k 30 --f 9 146 | python train.py --datafile mazes/m15_diffdrive_100k.npz --mechanism diffdrive --model models.GPPN --k 30 --f 9 147 | 148 | python train.py --datafile mazes/m28_news_25k.npz --mechanism news --model models.GPPN --k 56 --f 11 149 | python train.py --datafile mazes/m28_moore_25k.npz --mechanism moore --model models.GPPN --k 56 --f 9 150 | python train.py --datafile mazes/m28_diffdrive_25k.npz --mechanism diffdrive --model models.GPPN --k 56 --f 11 151 | ``` 152 | 153 | ### III. Test Performance Results 154 | 155 | Here are the test performance results from running the above commands inside the Docker container `lileee/ubuntu-16.04-pytorch-0.4.0-gpu:v1`: 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 |
| datafile                 | VIN K | VIN F | VIN %Opt | VIN %Suc | GPPN K | GPPN F | GPPN %Opt | GPPN %Suc |
| :----------------------: | :---: | :---: | :------: | :------: | :----: | :----: | :-------: | :-------: |
| `m15_news_10k.npz`       | 30    | 5     | 77.4     | 79.0     | 20     | 9      | 96.8      | 97.8      |
| `m15_news_25k.npz`       | 20    | 5     | 83.6     | 84.2     | 20     | 11     | 99.0      | 99.3      |
| `m15_news_100k.npz`      | 30    | 3     | 92.6     | 92.8     | 30     | 11     | 99.7      | 99.8      |
| `m15_moore_10k.npz`      | 30    | 11    | 86.0     | 89.3     | 30     | 7      | 97.0      | 98.0      |
| `m15_moore_25k.npz`      | 30    | 5     | 85.4     | 88.1     | 30     | 9      | 98.9      | 99.5      |
| `m15_moore_100k.npz`     | 30    | 5     | 96.9     | 97.5     | 30     | 7      | 99.6      | 99.8      |
| `m15_diffdrive_10k.npz`  | 30    | 3     | 98.4     | 99.0     | 30     | 11     | 99.1      | 99.7      |
| `m15_diffdrive_25k.npz`  | 30    | 3     | 96.1     | 98.5     | 30     | 9      | 98.9      | 99.5      |
| `m15_diffdrive_100k.npz` | 30    | 3     | 99.0     | 99.4     | 30     | 9      | 99.8      | 99.9      |
| `m28_news_25k.npz`       | 56    | 3     | 83.4     | 84.2     | 56     | 11     | 96.5      | 97.8      |
| `m28_moore_25k.npz`      | 56    | 5     | 73.3     | 81.0     | 56     | 9      | 96.5      | 97.9      |
| `m28_diffdrive_25k.npz`  | 56    | 3     | 82.0     | 93.6     | 56     | 11     | 95.3      | 98.0      |
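Each row above corresponds to one of the datafiles from the download table in Section I. If you want to inspect a datafile directly, here is a minimal sketch (the `arr_0`–`arr_8` ordering follows `generate_dataset.py` and `utils/maze.py`; the train, valid, and test splits start at indices 0, 3, and 6):

```python
import numpy as np

# Train split of one of the downloaded datasets (valid starts at arr_3, test at arr_6).
with np.load("mazes/m15_news_10k.npz") as f:
    mazes, goal_maps, opt_policies = f["arr_0"], f["arr_1"], f["arr_2"]

print(mazes.shape)         # (num_train, 15, 15) binary occupancy maps
print(goal_maps.shape)     # (num_train, num_orient, 15, 15) one-hot goal maps
print(opt_policies.shape)  # (num_train, num_actions, num_orient, 15, 15) one-hot optimal policies
```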
235 | 236 | Feel free to play around with different iteration counts `--k` and kernel sizes `--f`. 237 | 238 | ### IV. Version differences (Python, PyTorch) 239 | 240 | The test performance results above are slightly different from what is reported in our ICML 2018 paper due to version differences in Python (3.6.5 vs. 2.7.12) and PyTorch (0.4.0 vs. 0.3.1). 241 | 242 | Below, we provide instructions to exactly replicate the numbers reported in our ICML 2018 paper. 243 | 244 | 1. Checkout the Git branch `icml2018`: 245 | ``` 246 | git checkout icml2018 247 | ``` 248 | 249 | 2. Load the Docker container used in our experiments by running: 250 | ``` 251 | # PyTorch with GPU support 252 | nvidia-docker pull lileee/python-2.7-pytorch-0.3.1-custom:latest 253 | nvidia-docker run -v $(pwd):/home --rm -ti lileee/python-2.7-pytorch-0.3.1-custom:latest 254 | ``` 255 | This Docker container uses Python 2.7.12 and a custom version of PyTorch 0.3.1 compiled from source at https://github.com/eparisotto/pytorch. 256 | 257 | 3. Train a model: 258 | ``` 259 | python train.py --datafile mazes/m15_news_25k.npz --mechanism news --model models.VIN --k 20 --f 5 260 | ``` 261 | 262 | ## Citation 263 | 264 | If you found this code useful in your research, please cite: 265 | 266 | ``` 267 | @inproceedings{gppn2018, 268 | author = {Lisa Lee and Emilio Parisotto and Devendra Singh Chaplot and Eric Xing and Ruslan Salakhutdinov}, 269 | title = {Gated Path Planning Networks}, 270 | booktitle = {Proceedings of the 35th International Conference on Machine Learning (ICML 2018)}, 271 | year = {2018} 272 | } 273 | ``` 274 | 275 | ## Acknowledgments 276 | 277 | Thanks to [@kentsommer](https://github.com/kentsommer) for releasing a [PyTorch implementation](https://github.com/kentsommer/pytorch-value-iteration-networks) of the original VIN results, which served as a starting point for this codebase. 278 | -------------------------------------------------------------------------------- /download_datasets.sh: -------------------------------------------------------------------------------- 1 | # Downloads 2D maze datasets used in our ICML 2018 paper. 
2 | 3 | OUTPUT_DIR=./mazes 4 | 5 | mkdir ${OUTPUT_DIR} 6 | 7 | # 15x15 mazes (NEWS) 8 | wget https://cmu.box.com/shared/static/l46uqg78f2iik8avr9rdfyfi28vyek7o.npz -O ${OUTPUT_DIR}/m15_news_10k.npz 9 | wget https://cmu.box.com/shared/static/voqgj886o1gfx7ievbytzybonlu8s56n.npz -O ${OUTPUT_DIR}/m15_news_25k.npz 10 | wget https://cmu.box.com/shared/static/4bp5xxs7ilfohyosy941sc8bi9fbed8w.npz -O ${OUTPUT_DIR}/m15_news_100k.npz 11 | 12 | # 15x15 mazes (Moore) 13 | wget https://cmu.box.com/shared/static/p8aanmo5kj1bm9949njmigntmbol0o7s.npz -O ${OUTPUT_DIR}/m15_moore_10k.npz 14 | wget https://cmu.box.com/shared/static/1nmxgma8uvnuezfifqxieliiy9qrx4pe.npz -O ${OUTPUT_DIR}/m15_moore_25k.npz 15 | wget https://cmu.box.com/shared/static/gg78z5ka2sjx9jcbj6m31du1v9vmfmhm.npz -O ${OUTPUT_DIR}/m15_moore_100k.npz 16 | 17 | # 15x15 mazes (DiffDrive) 18 | wget https://cmu.box.com/shared/static/3rv3aghi8df17vwnidcj1kd0qci2qa94.npz -O ${OUTPUT_DIR}/m15_diffdrive_10k.npz 19 | wget https://cmu.box.com/shared/static/3rv3aghi8df17vwnidcj1kd0qci2qa94.npz -O ${OUTPUT_DIR}/m15_diffdrive_25k.npz 20 | wget https://cmu.box.com/shared/static/pjfw2rwj88ibx4ako6balz8d810qujw1.npz -O ${OUTPUT_DIR}/m15_diffdrive_100k.npz 21 | 22 | # 28x28 mazes 23 | wget https://cmu.box.com/shared/static/2bat54yisnnx5yybzl5uhtzkex0g65sx.npz -O ${OUTPUT_DIR}/m28_news_25k.npz 24 | wget https://cmu.box.com/shared/static/c6st11gqtx1n1xs1lu86sys00yf7yu84.npz -O ${OUTPUT_DIR}/m28_moore_25k.npz 25 | wget https://cmu.box.com/shared/static/r3bgykf8zss8ro0pfqw8sc3pgjfi1urg.npz -O ${OUTPUT_DIR}/m28_diffdrive_25k.npz 26 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluates a trained model on a dataset. 3 | 4 | Example usage: 5 | python eval.py --datafile mazes.npz --mechanism news --model models.GPPN \ 6 | --k 15 --f 5 --load-file log/gppn-k15-f5/planner.final.pth 7 | """ 8 | from __future__ import print_function 9 | 10 | from utils.experiment import (parse_args, create_save_dir, get_mechanism, 11 | create_dataloader, print_stats) 12 | from utils.runner import Runner 13 | 14 | 15 | def main(): 16 | args = parse_args() 17 | 18 | create_save_dir(args.save_directory) 19 | mechanism = get_mechanism(args.mechanism) 20 | 21 | # Create DataLoaders. 
22 | trainloader = create_dataloader( 23 | args.datafile, "train", args.batch_size, mechanism, shuffle=True) 24 | validloader = create_dataloader( 25 | args.datafile, "valid", args.batch_size, mechanism, shuffle=False) 26 | testloader = create_dataloader( 27 | args.datafile, "test", args.batch_size, mechanism, shuffle=False) 28 | 29 | runner = Runner(args, mechanism) 30 | 31 | print("\n------------- Evaluating final model -------------") 32 | print("\nTrain performance:") 33 | print_stats(runner.test(trainloader)) 34 | 35 | print("\nValidation performance:") 36 | print_stats(runner.test(testloader)) 37 | 38 | print("\nTest performance:") 39 | print_stats(runner.test(validloader)) 40 | 41 | print("\n------------- Evaluating best model -------------") 42 | print("\nTrain performance:") 43 | print_stats(runner.test(trainloader, use_best=True)) 44 | 45 | print("\nValidation performance:") 46 | print_stats(runner.test(testloader, use_best=True)) 47 | 48 | print("\nTest performance:") 49 | print_stats(runner.test(validloader, use_best=True)) 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /generate_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generates a 2D maze dataset. 3 | 4 | Example usage: 5 | python generate_dataset.py --output-path mazes.npz --mechanism news \ 6 | --maze-size 9 --train-size 5000 --valid-size 1000 --test-size 1000 7 | """ 8 | from __future__ import print_function 9 | import sys 10 | import argparse 11 | import numpy as np 12 | 13 | from utils.dijkstra import dijkstra_dist 14 | from utils.experiment import get_mechanism 15 | from utils.maze import RandomMaze, extract_policy 16 | 17 | 18 | def generate_data(filename, 19 | train_size, 20 | valid_size, 21 | test_size, 22 | mechanism, 23 | maze_size, 24 | min_decimation, 25 | max_decimation, 26 | start_pos=(1, 1)): 27 | maze_class = RandomMaze( 28 | mechanism, 29 | maze_size, 30 | maze_size, 31 | min_decimation, 32 | max_decimation, 33 | start_pos=start_pos) 34 | 35 | def hash_maze_to_string(maze): 36 | maze = np.array(maze, dtype=np.uint8).reshape((-1)) 37 | mazekey = "" 38 | for i in range(maze.shape[0]): 39 | mazekey += str(maze[i]) 40 | return mazekey 41 | 42 | def hashed_check_maze_exists(mazekey, mazehash): 43 | if mazehash is None: 44 | return False 45 | if mazekey in mazehash: 46 | return True 47 | return False 48 | 49 | def check_maze_exists(maze, compare_mazes): 50 | if compare_mazes is None: 51 | return False 52 | diff = np.sum( 53 | np.abs(compare_mazes - maze).reshape((len(compare_mazes), -1)), 54 | axis=1) 55 | if np.sum(diff == 0): 56 | return True 57 | return False 58 | 59 | def extract_goal(goal_map): 60 | for o in range(mechanism.num_orient): 61 | for y in range(maze_size): 62 | for x in range(maze_size): 63 | if goal_map[o][y][x] == 1.: 64 | return (o, y, x) 65 | 66 | def create_dataset(data_size, compare_mazes=None): 67 | mazes = np.zeros((data_size, maze_size, maze_size)) 68 | goal_maps = np.zeros((data_size, mechanism.num_orient, maze_size, 69 | maze_size)) 70 | opt_policies = np.zeros((data_size, mechanism.num_actions, 71 | mechanism.num_orient, maze_size, maze_size)) 72 | 73 | mazehash = {} 74 | if compare_mazes is not None: 75 | for i in range(compare_mazes.shape[0]): 76 | maze = compare_mazes[i] 77 | mazekey = hash_maze_to_string(maze) 78 | mazehash[mazekey] = 1 79 | for i in range(data_size): 80 | maze, goal_map = None, None 81 | while True: 82 | maze, 
_, goal_map = maze_class.reset() 83 | mazekey = hash_maze_to_string(maze) 84 | 85 | # Make sure we sampled a unique maze from the compare set 86 | if hashed_check_maze_exists(mazekey, mazehash): 87 | continue 88 | mazehash[mazekey] = 1 89 | break 90 | 91 | # Use Dijkstra's to construct the optimal policy 92 | opt_value = dijkstra_dist(maze, mechanism, extract_goal(goal_map)) 93 | opt_policy = extract_policy(maze, mechanism, opt_value) 94 | 95 | mazes[i, :, :] = maze 96 | goal_maps[i, :, :, :] = goal_map 97 | opt_policies[i, :, :, :, :] = opt_policy 98 | 99 | sys.stdout.write("\r%0.4f" % (float(i) / data_size * 100) + "%") 100 | sys.stdout.flush() 101 | sys.stdout.write("\r100%\n") 102 | 103 | return mazes, goal_maps, opt_policies 104 | 105 | # Generate test set first 106 | print("Creating valid+test dataset...") 107 | validtest_mazes, validtest_goal_maps, validtest_opt_policies = create_dataset( 108 | test_size + valid_size) 109 | 110 | # Split valid and test 111 | valid_mazes = validtest_mazes[0:valid_size] 112 | test_mazes = validtest_mazes[valid_size:] 113 | valid_goal_maps = validtest_goal_maps[0:valid_size] 114 | test_goal_maps = validtest_goal_maps[valid_size:] 115 | valid_opt_policies = validtest_opt_policies[0:valid_size] 116 | test_opt_policies = validtest_opt_policies[valid_size:] 117 | 118 | # Generate train set while avoiding test geometries 119 | print("Creating training dataset...") 120 | train_mazes, train_goal_maps, train_opt_policies = create_dataset( 121 | train_size, compare_mazes=validtest_mazes) 122 | 123 | # Re-shuffle 124 | mazes = np.concatenate((train_mazes, valid_mazes, test_mazes), 0) 125 | goal_maps = np.concatenate( 126 | (train_goal_maps, valid_goal_maps, test_goal_maps), 0) 127 | opt_policies = np.concatenate( 128 | (train_opt_policies, valid_opt_policies, test_opt_policies), 0) 129 | 130 | shuffidx = np.random.permutation(mazes.shape[0]) 131 | mazes = mazes[shuffidx] 132 | goal_maps = goal_maps[shuffidx] 133 | opt_policies = opt_policies[shuffidx] 134 | 135 | train_mazes = mazes[:train_size] 136 | train_goal_maps = goal_maps[:train_size] 137 | train_opt_policies = opt_policies[:train_size] 138 | 139 | valid_mazes = mazes[train_size:train_size + valid_size] 140 | valid_goal_maps = goal_maps[train_size:train_size + valid_size] 141 | valid_opt_policies = opt_policies[train_size:train_size + valid_size] 142 | 143 | test_mazes = mazes[train_size + valid_size:] 144 | test_goal_maps = goal_maps[train_size + valid_size:] 145 | test_opt_policies = opt_policies[train_size + valid_size:] 146 | 147 | # Save to numpy 148 | np.savez_compressed(filename, train_mazes, train_goal_maps, 149 | train_opt_policies, valid_mazes, valid_goal_maps, 150 | valid_opt_policies, test_mazes, test_goal_maps, 151 | test_opt_policies) 152 | 153 | 154 | def main(): 155 | parser = argparse.ArgumentParser() 156 | parser.add_argument( 157 | "--output-path", type=str, default="mazes.npz", 158 | help="Filename to save the dataset to.") 159 | parser.add_argument( 160 | "--train-size", type=int, default=10000, 161 | help="Number of training mazes.") 162 | parser.add_argument( 163 | "--valid-size", type=int, default=1000, 164 | help="Number of validation mazes.") 165 | parser.add_argument( 166 | "--test-size", type=int, default=1000, 167 | help="Number of test mazes.") 168 | parser.add_argument( 169 | "--maze-size", type=int, default=9, 170 | help="Size of mazes.") 171 | parser.add_argument( 172 | "--min-decimation", type=float, default=0.0, 173 | help="How likely a wall is to be destroyed 
(minimum).") 174 | parser.add_argument("--max-decimation", type=float, default=1.0, 175 | help="How likely a wall is to be destroyed (maximum).") 176 | parser.add_argument( 177 | "--start-pos-x", type=int, default=1, 178 | help="Maze start X-axis position.") 179 | parser.add_argument( 180 | "--start-pos-y", type=int, default=1, 181 | help="Maze start Y-axis position.") 182 | parser.add_argument( 183 | "--mechanism", type=str, default="news", 184 | help="Maze transition mechanism. (news|diffdrive|moore)") 185 | args = parser.parse_args() 186 | 187 | mechanism = get_mechanism(args.mechanism) 188 | generate_data( 189 | args.output_path, 190 | args.train_size, 191 | args.valid_size, 192 | args.test_size, 193 | mechanism, 194 | args.maze_size, 195 | args.min_decimation, 196 | args.max_decimation, 197 | start_pos=(args.start_pos_y, args.start_pos_x)) 198 | 199 | 200 | if __name__ == "__main__": 201 | main() -------------------------------------------------------------------------------- /models/GPPN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Gated Path Planning Network module 6 | class Planner(nn.Module): 7 | """ 8 | Implementation of the Gated Path Planning Network. 9 | """ 10 | def __init__(self, num_orient, num_actions, args): 11 | super(Planner, self).__init__() 12 | 13 | self.num_orient = num_orient 14 | self.num_actions = num_actions 15 | 16 | self.l_h = args.l_h 17 | self.k = args.k 18 | self.f = args.f 19 | 20 | self.hid = nn.Conv2d( 21 | in_channels=(num_orient + 1), # maze map + goal location 22 | out_channels=self.l_h, 23 | kernel_size=(3, 3), 24 | stride=1, 25 | padding=1, 26 | bias=True) 27 | 28 | self.h0 = nn.Conv2d( 29 | in_channels=self.l_h, 30 | out_channels=self.l_h, 31 | kernel_size=(3, 3), 32 | stride=1, 33 | padding=1, 34 | bias=True) 35 | 36 | self.c0 = nn.Conv2d( 37 | in_channels=self.l_h, 38 | out_channels=self.l_h, 39 | kernel_size=(3, 3), 40 | stride=1, 41 | padding=1, 42 | bias=True) 43 | 44 | self.conv = nn.Conv2d( 45 | in_channels=self.l_h, 46 | out_channels=1, 47 | kernel_size=(self.f, self.f), 48 | stride=1, 49 | padding=int((self.f - 1.0) / 2), 50 | bias=True) 51 | 52 | self.lstm = nn.LSTMCell(1, self.l_h) 53 | 54 | self.policy = nn.Conv2d( 55 | in_channels=self.l_h, 56 | out_channels=num_actions * num_orient, 57 | kernel_size=(1, 1), 58 | stride=1, 59 | padding=0, 60 | bias=False) 61 | 62 | self.sm = nn.Softmax2d() 63 | 64 | def forward(self, map_design, goal_map): 65 | maze_size = map_design.size()[-1] 66 | X = torch.cat([map_design, goal_map], 1) 67 | 68 | hid = self.hid(X) 69 | h0 = self.h0(hid).transpose(1, 3).contiguous().view(-1, self.l_h) 70 | c0 = self.c0(hid).transpose(1, 3).contiguous().view(-1, self.l_h) 71 | 72 | last_h, last_c = h0, c0 73 | for _ in range(0, self.k - 1): 74 | h_map = last_h.view(-1, maze_size, maze_size, self.l_h) 75 | h_map = h_map.transpose(3, 1) 76 | inp = self.conv(h_map).transpose(1, 3).contiguous().view(-1, 1) 77 | 78 | last_h, last_c = self.lstm(inp, (last_h, last_c)) 79 | 80 | hk = last_h.view(-1, maze_size, maze_size, self.l_h).transpose(3, 1) 81 | logits = self.policy(hk) 82 | 83 | # Normalize over actions 84 | logits = logits.view(-1, self.num_actions, maze_size, maze_size) 85 | probs = self.sm(logits) 86 | 87 | # Reshape to output dimensions 88 | logits = logits.view(-1, self.num_orient, self.num_actions, maze_size, 89 | maze_size) 90 | probs = probs.view(-1, self.num_orient, self.num_actions, maze_size, 91 | maze_size) 92 | 
logits = torch.transpose(logits, 1, 2).contiguous() 93 | probs = torch.transpose(probs, 1, 2).contiguous() 94 | 95 | return logits, probs, h0, hk 96 | -------------------------------------------------------------------------------- /models/VIN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | 6 | 7 | # VIN planner module 8 | class Planner(nn.Module): 9 | """ 10 | Implementation of the Value Iteration Network. 11 | """ 12 | def __init__(self, num_orient, num_actions, args): 13 | super(Planner, self).__init__() 14 | 15 | self.num_orient = num_orient 16 | self.num_actions = num_actions 17 | 18 | self.l_q = args.l_q 19 | self.l_h = args.l_h 20 | self.k = args.k 21 | self.f = args.f 22 | 23 | self.h = nn.Conv2d( 24 | in_channels=(num_orient + 1), # maze map + goal location 25 | out_channels=self.l_h, 26 | kernel_size=(3, 3), 27 | stride=1, 28 | padding=1, 29 | bias=True) 30 | 31 | self.r = nn.Conv2d( 32 | in_channels=self.l_h, 33 | out_channels=num_orient, # reward per orientation 34 | kernel_size=(1, 1), 35 | stride=1, 36 | padding=0, 37 | bias=False) 38 | 39 | self.q = nn.Conv2d( 40 | in_channels=num_orient, 41 | out_channels=self.l_q * num_orient, 42 | kernel_size=(self.f, self.f), 43 | stride=1, 44 | padding=int((self.f - 1.0) / 2), 45 | bias=False) 46 | 47 | self.policy = nn.Conv2d( 48 | in_channels=self.l_q * num_orient, 49 | out_channels=num_actions * num_orient, 50 | kernel_size=(1, 1), 51 | stride=1, 52 | padding=0, 53 | bias=False) 54 | 55 | self.w = Parameter( 56 | torch.zeros(self.l_q * num_orient, num_orient, self.f, 57 | self.f), 58 | requires_grad=True) 59 | 60 | self.sm = nn.Softmax2d() 61 | 62 | def forward(self, map_design, goal_map): 63 | maze_size = map_design.size()[-1] 64 | X = torch.cat([map_design, goal_map], 1) 65 | 66 | h = self.h(X) 67 | r = self.r(h) 68 | q = self.q(r) 69 | q = q.view(-1, self.num_orient, self.l_q, maze_size, maze_size) 70 | v, _ = torch.max(q, dim=2, keepdim=True) 71 | v = v.view(-1, self.num_orient, maze_size, maze_size) 72 | for _ in range(0, self.k - 1): 73 | q = F.conv2d( 74 | torch.cat([r, v], 1), 75 | torch.cat([self.q.weight, self.w], 1), 76 | stride=1, 77 | padding=int((self.f - 1.0) / 2)) 78 | q = q.view(-1, self.num_orient, self.l_q, maze_size, maze_size) 79 | v, _ = torch.max(q, dim=2) 80 | v = v.view(-1, self.num_orient, maze_size, maze_size) 81 | 82 | q = F.conv2d( 83 | torch.cat([r, v], 1), 84 | torch.cat([self.q.weight, self.w], 1), 85 | stride=1, 86 | padding=int((self.f - 1.0) / 2)) 87 | 88 | logits = self.policy(q) 89 | 90 | # Normalize over actions 91 | logits = logits.view(-1, self.num_actions, maze_size, maze_size) 92 | probs = self.sm(logits) 93 | 94 | # Reshape to output dimensions 95 | logits = logits.view(-1, self.num_orient, self.num_actions, maze_size, 96 | maze_size) 97 | probs = probs.view(-1, self.num_orient, self.num_actions, maze_size, 98 | maze_size) 99 | logits = torch.transpose(logits, 1, 2).contiguous() 100 | probs = torch.transpose(probs, 1, 2).contiguous() 101 | 102 | return logits, probs, v, r 103 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RLAgent/gated-path-planning-networks/df010fc26667f2b4f0dbbcc0d3d528d4ae6efeb1/models/__init__.py 
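Both planner modules above expose the same `Planner(num_orient, num_actions, args)` constructor and `forward(map_design, goal_map)` signature. Below is a minimal smoke test of that interface; it is a sketch, the hyperparameters in `args` are illustrative rather than the paper's settings, and in practice `train.py`/`eval.py` build the model through `utils/runner.py` instead:

```python
import argparse
import torch
from models import VIN  # models.GPPN exposes the same interface

# Illustrative hyperparameters (not the paper's settings).
args = argparse.Namespace(l_h=150, l_q=600, k=10, f=3)
num_orient, num_actions = 1, 4  # NEWS mechanism: 1 orientation, 4 actions

planner = VIN.Planner(num_orient, num_actions, args)

maze = torch.rand(8, 1, 9, 9).round()    # batch of 8 binary 9x9 maze maps
goal = torch.zeros(8, num_orient, 9, 9)  # one-hot goal location per orientation
goal[:, 0, 5, 5] = 1.0

logits, probs, v, r = planner(maze, goal)  # GPPN returns (logits, probs, h0, hk)
print(probs.shape)  # torch.Size([8, 4, 1, 9, 9]) -> (batch, action, orient, H, W)
```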
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch >= 0.4.0 3 | matplotlib -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trains a planner model. 3 | 4 | Example usage: 5 | python train.py --datafile mazes.npz --mechanism news --model models.GPPN \ 6 | --k 15 --f 5 --save-directory log/gppn-k15-f5 7 | """ 8 | from __future__ import print_function 9 | 10 | import argparse 11 | import time 12 | import numpy as np 13 | 14 | import matplotlib as mpl 15 | mpl.use("Agg") 16 | import matplotlib.pyplot as plt 17 | 18 | import torch 19 | 20 | from utils.experiment import (parse_args, create_save_dir, get_mechanism, 21 | create_dataloader, print_row, print_stats) 22 | from utils.runner import Runner 23 | 24 | 25 | def main(): 26 | args = parse_args() 27 | 28 | save_path = create_save_dir(args.save_directory) 29 | mechanism = get_mechanism(args.mechanism) 30 | 31 | # Create DataLoaders 32 | trainloader = create_dataloader( 33 | args.datafile, "train", args.batch_size, mechanism, shuffle=True) 34 | validloader = create_dataloader( 35 | args.datafile, "valid", args.batch_size, mechanism, shuffle=False) 36 | testloader = create_dataloader( 37 | args.datafile, "test", args.batch_size, mechanism, shuffle=False) 38 | 39 | runner = Runner(args, mechanism) 40 | 41 | # Print header 42 | col_width = 5 43 | print("\n | Train | Valid |") # pylint: disable=line-too-long 44 | print_row(col_width, [ 45 | "Epoch", "CE", "Err", "%Opt", "%Suc", "CE", "Err", "%Opt", "%Suc", "W", 46 | "dW", "Time", "Best" 47 | ]) 48 | 49 | tr_total_loss, tr_total_error, tr_total_optimal, tr_total_success = [], [], [], [] 50 | v_total_loss, v_total_error, v_total_optimal, v_total_success = [], [], [], [] 51 | for epoch in range(args.epochs): 52 | start_time = time.time() 53 | 54 | # Train the model 55 | tr_info = runner.train(trainloader, args.batch_size) 56 | 57 | # Compute validation stats and save the best model 58 | v_info = runner.validate(validloader) 59 | time_duration = time.time() - start_time 60 | 61 | # Print epoch logs 62 | print_row(col_width, [ 63 | epoch + 1, tr_info["avg_loss"], tr_info["avg_error"], 64 | tr_info["avg_optimal"], tr_info["avg_success"], v_info["avg_loss"], 65 | v_info["avg_error"], v_info["avg_optimal"], v_info["avg_success"], 66 | tr_info["weight_norm"], tr_info["grad_norm"], 67 | time_duration, 68 | "!" if v_info["is_best"] else " " 69 | ]) 70 | 71 | # Keep track of metrics: 72 | tr_total_loss.append(tr_info["avg_loss"]) 73 | tr_total_error.append(tr_info["avg_error"]) 74 | tr_total_optimal.append(tr_info["avg_optimal"]) 75 | tr_total_success.append(tr_info["avg_success"]) 76 | v_total_loss.append(v_info["avg_loss"]) 77 | v_total_error.append(v_info["avg_error"]) 78 | v_total_optimal.append(v_info["avg_optimal"]) 79 | v_total_success.append(v_info["avg_success"]) 80 | 81 | # Plot learning curves. 
82 | def _plot(train, valid, name): 83 | plt.clf() 84 | x = np.array(range(len(train))) 85 | y = np.array(valid) 86 | plt.plot(x, np.array(train), label="train") 87 | plt.plot(x, np.array(valid), label="valid") 88 | plt.legend() 89 | plt.savefig(name) 90 | _plot(tr_total_loss, v_total_loss, save_path + "_total_loss.pdf") 91 | _plot(tr_total_error, v_total_error, save_path + "_total_error.pdf") 92 | _plot(tr_total_optimal, v_total_optimal, 93 | save_path + "_total_optimal.pdf") 94 | _plot(tr_total_success, v_total_success, 95 | save_path + "_total_success.pdf") 96 | 97 | # Save intermediate model. 98 | if args.save_intermediate: 99 | torch.save({ 100 | "model": runner.model.state_dict(), 101 | "best_model": runner.best_model.state_dict(), 102 | "tr_total_loss": tr_total_loss, 103 | "tr_total_error": tr_total_error, 104 | "tr_total_optimal": tr_total_optimal, 105 | "tr_total_success": tr_total_success, 106 | "v_total_loss": v_total_loss, 107 | "v_total_error": v_total_error, 108 | "v_total_optimal": v_total_optimal, 109 | "v_total_success": v_total_success, 110 | }, save_path + ".e" + str(epoch) + ".pth") 111 | 112 | # Test accuracy 113 | print("\nFinal test performance:") 114 | t_final_info = runner.test(testloader) 115 | print_stats(t_final_info) 116 | 117 | print("\nBest test performance:") 118 | t_best_info = runner.test(testloader, use_best=True) 119 | print_stats(t_best_info) 120 | 121 | # Save the final trained model 122 | torch.save({ 123 | "model": runner.model.state_dict(), 124 | "best_model": runner.best_model.state_dict(), 125 | "tr_total_loss": tr_total_loss, 126 | "tr_total_error": tr_total_error, 127 | "tr_total_optimal": tr_total_optimal, 128 | "tr_total_success": tr_total_success, 129 | "v_total_loss": v_total_loss, 130 | "v_total_error": v_total_error, 131 | "v_total_optimal": v_total_optimal, 132 | "v_total_success": v_total_success, 133 | "t_final_loss": t_final_info["avg_loss"], 134 | "t_final_error": t_final_info["avg_error"], 135 | "t_final_optimal": t_final_info["avg_optimal"], 136 | "t_final_success": t_final_info["avg_success"], 137 | "t_best_loss": t_best_info["avg_loss"], 138 | "t_best_error": t_best_info["avg_error"], 139 | "t_best_optimal": t_best_info["avg_optimal"], 140 | "t_best_success": t_best_info["avg_success"], 141 | }, save_path + ".final.pth") 142 | 143 | 144 | if __name__ == "__main__": 145 | main() 146 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RLAgent/gated-path-planning-networks/df010fc26667f2b4f0dbbcc0d3d528d4ae6efeb1/utils/__init__.py -------------------------------------------------------------------------------- /utils/dijkstra.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | 5 | class MinHeap: 6 | 7 | def __init__(self): 8 | self.heap = [] # binary min-heap 9 | self.heapdict = {} # [key]->index dict 10 | self.invheapdict = {} # [index]->key dict 11 | self.heap_length = 0 # number of elements 12 | 13 | def empty(self): 14 | return self.heap_length == 0 15 | 16 | def insert(self, key, val): 17 | # Insert the value at the bottom of heap 18 | if self.heap_length == len(self.heap): 19 | self.heap.append(val) 20 | else: 21 | self.heap[self.heap_length] = val 22 | add_idx = self.heap_length 23 | self.heap_length += 1 24 | 25 | # Update the dictionaries 26 | self.heapdict[key] = add_idx 27 | 
self.invheapdict[add_idx] = key 28 | 29 | # percolate upwards 30 | self._percolate_up(add_idx) 31 | 32 | def decrease(self, key, new_val): 33 | # Find the index and value of this key 34 | curr_idx = self.heapdict[key] 35 | curr_val = self.heap[curr_idx] 36 | assert new_val <= curr_val 37 | 38 | # Update with new lower value 39 | self.heap[curr_idx] = new_val 40 | 41 | # Percolate upwards 42 | self._percolate_up(curr_idx) 43 | 44 | def extract(self): 45 | assert self.heap_length > 0 46 | 47 | retval = self.heap[0] 48 | retkey = self.invheapdict[0] 49 | 50 | # Swap the root with a leaf 51 | self._swap_index(0, self.heap_length - 1) 52 | 53 | # Delete the last element from the dictionaries 54 | del self.heapdict[retkey] 55 | del self.invheapdict[self.heap_length - 1] 56 | self.heap_length -= 1 57 | 58 | # Percolate downwards 59 | self._percolate_down(0) 60 | return retkey, retval 61 | 62 | def _swap_index(self, idx1, idx2): 63 | # get keys 64 | key1 = self.invheapdict[idx1] 65 | key2 = self.invheapdict[idx2] 66 | 67 | # Swap values in the heap 68 | tmp1_ = self.heap[idx1] 69 | self.heap[idx1] = self.heap[idx2] 70 | self.heap[idx2] = tmp1_ 71 | 72 | # Swap indices in the [key]->index dict 73 | tmp2_ = self.heapdict[key1] 74 | self.heapdict[key1] = self.heapdict[key2] 75 | self.heapdict[key2] = tmp2_ 76 | 77 | # Swap keys in the [index]->key dict 78 | tmp3_ = self.invheapdict[idx1] 79 | self.invheapdict[idx1] = self.invheapdict[idx2] 80 | self.invheapdict[idx2] = tmp3_ 81 | 82 | def _percolate_up(self, curr_idx): 83 | while curr_idx != 0: 84 | parent_idx = int(math.floor((curr_idx - 1) / 2)) 85 | if self.heap[parent_idx] > self.heap[curr_idx]: 86 | self._swap_index(curr_idx, parent_idx) 87 | curr_idx = parent_idx 88 | else: 89 | break 90 | 91 | def _percolate_down(self, curr_idx): 92 | while curr_idx < self.heap_length: 93 | child1 = 2 * curr_idx + 1 94 | child2 = child1 + 1 95 | if child1 >= self.heap_length: 96 | break 97 | minchild = child1 98 | maxchild = child2 99 | if child2 >= self.heap_length: 100 | maxchild = None 101 | if (maxchild is not None) and (self.heap[child1] > 102 | self.heap[child2]): 103 | minchild = child2 104 | maxchild = child1 105 | 106 | if self.heap[minchild] < self.heap[curr_idx]: 107 | self._swap_index(curr_idx, minchild) 108 | curr_idx = minchild 109 | continue 110 | 111 | if (maxchild is not None) and (self.heap[maxchild] < 112 | self.heap[curr_idx]): 113 | self._swap_index(curr_idx, maxchild) 114 | curr_idx = maxchild 115 | continue 116 | break 117 | 118 | 119 | def dijkstra_dist(maze, mechanism, goal): 120 | # Initialize distance to largest possible distance 121 | dist = (np.zeros((mechanism.num_orient, maze.shape[0], maze.shape[1])) + 122 | mechanism.num_orient * maze.shape[0] * maze.shape[1]) 123 | 124 | pq = MinHeap() 125 | pq.insert(goal, 0) 126 | for orient in range(mechanism.num_orient): 127 | for y in range(maze.shape[0]): 128 | for x in range(maze.shape[1]): 129 | if (orient == goal[0]) and (y == goal[1]) and (x == goal[2]): 130 | continue 131 | pq.insert((orient, y, x), 132 | mechanism.num_orient * maze.shape[0] * maze.shape[1]) 133 | 134 | while not pq.empty(): 135 | # extract minimum distance position 136 | ((p_orient, p_y, p_x), val) = pq.extract() 137 | dist[p_orient][p_y][p_x] = val 138 | 139 | # Update neighbors 140 | for n in mechanism.invneighbors_func(maze, p_orient, p_y, p_x): 141 | if (n[1] < 0) or (n[1] >= maze.shape[0]): 142 | continue 143 | if (n[2] < 0) or (n[2] >= maze.shape[1]): 144 | continue 145 | 146 | if maze[n[1]][n[2]] == 0.: 147 
| continue 148 | 149 | curr_to_n = val + 1 150 | if curr_to_n < dist[n[0]][n[1]][n[2]]: 151 | dist[n[0]][n[1]][n[2]] = curr_to_n 152 | pq.decrease(n, curr_to_n) 153 | return -dist # negative distance ~= value 154 | 155 | 156 | def dijkstra_policy(maze, mechanism, goal, policy): 157 | # Initialize distance to largest possible distance 158 | dist = np.zeros( 159 | (mechanism.num_orient, maze.shape[0], 160 | maze.shape[1])) + mechanism.num_orient * maze.shape[0] * maze.shape[1] 161 | 162 | pq = MinHeap() 163 | pq.insert(goal, 0) 164 | for orient in range(mechanism.num_orient): 165 | for y in range(maze.shape[0]): 166 | for x in range(maze.shape[1]): 167 | if (orient == goal[0]) and (y == goal[1]) and (x == goal[2]): 168 | continue 169 | pq.insert((orient, y, x), 170 | mechanism.num_orient * maze.shape[0] * maze.shape[1]) 171 | 172 | while not pq.empty(): 173 | # extract minimum distance position 174 | ((p_orient, p_y, p_x), val) = pq.extract() 175 | dist[p_orient][p_y][p_x] = val 176 | 177 | # Update neighboring predecessors 178 | predecessors = mechanism.invneighbors_func(maze, p_orient, p_y, p_x) 179 | for i in range(len(predecessors)): 180 | n = predecessors[i] 181 | if (n[1] < 0) or (n[1] >= maze.shape[0]): 182 | continue 183 | if (n[2] < 0) or (n[2] >= maze.shape[1]): 184 | continue 185 | 186 | if maze[n[1]][n[2]] == 0.: 187 | continue 188 | 189 | # What are the successor from this predecessor state? 190 | succ_pred = mechanism.neighbors_func(maze, n[0], n[1], n[2]) 191 | 192 | # Does following the policy on the predecessor state transition to 193 | # the current state? 194 | succ_pred_pol = succ_pred[policy[n[0]][n[1]][n[2]]] 195 | if (succ_pred_pol[0] == p_orient) and ( 196 | succ_pred_pol[1] == p_y) and (succ_pred_pol[2] == p_x): 197 | # Update value 198 | curr_to_n = val + 1 199 | if curr_to_n < dist[n[0]][n[1]][n[2]]: 200 | dist[n[0]][n[1]][n[2]] = curr_to_n 201 | pq.decrease(n, curr_to_n) 202 | return -dist # negative distance ~= value 203 | -------------------------------------------------------------------------------- /utils/experiment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | 5 | import torch 6 | 7 | from utils.maze import MazeDataset 8 | from utils.mechanism import DifferentialDrive, NorthEastWestSouth, Moore 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser( 13 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 14 | 15 | # Environment parameters 16 | parser.add_argument( 17 | "--datafile", type=str, default="mazes.npz", help="Path to data file.") 18 | parser.add_argument( 19 | "--gamma", type=float, default=0.99, 20 | help="Discount factor. (keeps value within reasonable range).") 21 | parser.add_argument( 22 | "--mechanism", type=str, default="news", 23 | help="Maze transition mechanism. 
(news|diffdrive|moore)") 24 | 25 | # Log parameters 26 | parser.add_argument( 27 | "--save-directory", type=str, default="log/", 28 | help="Directory to save the graphs and models.") 29 | parser.add_argument( 30 | "--save-intermediate", default=False, 31 | help="Whether to save every epoch.") 32 | parser.add_argument( 33 | "--use-percent-successful", default=False, 34 | help="Use % successful instead of % optimal to decide best models.") 35 | 36 | # Optimization parameters 37 | parser.add_argument( 38 | "--optimizer", type=str, default="RMSprop", 39 | help="Which optimizer to use.") 40 | parser.add_argument( 41 | "--epochs", type=int, default=30, help="Number of epochs to train.") 42 | parser.add_argument( 43 | "--lr", type=float, default=0.001, help="Learning rate.") 44 | parser.add_argument( 45 | "--lr-decay", type=float, default=1.0, 46 | help="Learning rate decay when CE goes up.") 47 | parser.add_argument( 48 | "--eps", type=float, default=1e-6, help="Epsilon for denominator.") 49 | parser.add_argument( 50 | "--batch-size", type=int, default=32, help="Batch size.") 51 | parser.add_argument( 52 | "--clip-grad", type=float, default=40, 53 | help="Whether to clip the gradient norms. (0 for none)") 54 | 55 | # Model parameters 56 | parser.add_argument( 57 | "--model", type=str, default="models.VIN", 58 | help="Which model architecture to train.") 59 | parser.add_argument( 60 | "--load-file", type=str, default="", 61 | help="Model weights to load. (leave empty for none)") 62 | parser.add_argument( 63 | "--load-best", default=False, 64 | help="Whether to load the best weights from the load-file.") 65 | parser.add_argument( 66 | "--k", type=int, default=10, help="Number of Value Iterations.") 67 | parser.add_argument( 68 | "--l-i", type=int, default=5, 69 | help="Number of channels in input layer.") 70 | parser.add_argument( 71 | "--l-h", type=int, default=150, 72 | help="Number of channels in first hidden layer.") 73 | parser.add_argument( 74 | "--l-q", type=int, default=600, 75 | help="Number of channels in q layer (~actions) in VI-module.") 76 | parser.add_argument( 77 | "--f", type=int, default=3, help="Kernel size") 78 | 79 | args = parser.parse_args() 80 | 81 | # Automatic switch of GPU mode if available 82 | args.use_gpu = torch.cuda.is_available() 83 | 84 | return args 85 | 86 | 87 | def get_mechanism(mechanism): 88 | if mechanism == "news": 89 | print("Using NEWS Drive") 90 | return NorthEastWestSouth() 91 | elif mechanism == "diffdrive": 92 | print("Using Differential Drive") 93 | return DifferentialDrive() 94 | elif mechanism == "moore": 95 | print("Using Moore Drive") 96 | return Moore() 97 | else: 98 | raise ValueError("Unsupported mechanism: %s" % mechanism) 99 | 100 | 101 | def create_dataloader(datafile, dataset_type, batch_size, mechanism, shuffle=False): 102 | """ 103 | Creates a maze DataLoader. 104 | Args: 105 | datafile (string): Path to the dataset 106 | dataset_type (string): One of "train", "valid", or "test" 107 | batch_size (int): The batch size 108 | shuffle (bool): Whether to shuffle the data 109 | """ 110 | dataset = MazeDataset(datafile, dataset_type) 111 | assert dataset.num_actions == mechanism.num_actions 112 | return torch.utils.data.DataLoader( 113 | dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0) 114 | 115 | 116 | def create_save_dir(save_directory): 117 | """ 118 | Creates and returns path to the save directory. 
119 | """ 120 | try: 121 | os.makedirs(save_directory) 122 | except OSError: 123 | if not os.path.isdir(save_directory): 124 | raise 125 | return save_directory + "/planner" 126 | 127 | 128 | def print_row(width, items): 129 | """ 130 | Prints the given items. 131 | Args: 132 | width (int): Character length for each column. 133 | items (list): List of items to print. 134 | """ 135 | def fmt_item(x): 136 | if isinstance(x, np.ndarray): 137 | assert x.ndim == 0 138 | x = x.item() 139 | if isinstance(x, float): 140 | rep = "%.3f" % x 141 | else: 142 | rep = str(x) 143 | return rep.ljust(width) 144 | 145 | print(" | ".join(fmt_item(item) for item in items)) 146 | 147 | 148 | def print_stats(info): 149 | """Prints performance statistics output from Runner.""" 150 | print_row(10, ["Loss", "Err", "% Optimal", "% Success"]) 151 | print_row(10, [ 152 | info["avg_loss"], info["avg_error"], 153 | info["avg_optimal"], info["avg_success"]]) 154 | return info 155 | -------------------------------------------------------------------------------- /utils/maze.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | import sys 4 | import math 5 | import numpy as np 6 | import numpy.random as npr 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | 12 | class MazeDataset(data.Dataset): 13 | 14 | def __init__(self, filename, dataset_type): 15 | """ 16 | Args: 17 | filename (str): Dataset filename (must be .npz format). 18 | dataset_type (str): One of "train", "valid", or "test". 19 | """ 20 | assert filename.endswith("npz") # Must be .npz format 21 | self.filename = filename 22 | self.dataset_type = dataset_type # train, valid, test 23 | 24 | self.mazes, self.goal_maps, self.opt_policies = self._process(filename) 25 | 26 | self.num_actions = self.opt_policies.shape[1] 27 | self.num_orient = self.opt_policies.shape[2] 28 | 29 | def _process(self, filename): 30 | """ 31 | Data format: list, [train data, test data] 32 | """ 33 | with np.load(filename) as f: 34 | dataset2idx = {"train": 0, "valid": 3, "test": 6} 35 | idx = dataset2idx[self.dataset_type] 36 | mazes = f["arr_" + str(idx)] 37 | goal_maps = f["arr_" + str(idx + 1)] 38 | opt_policies = f["arr_" + str(idx + 2)] 39 | 40 | # Set proper datatypes 41 | mazes = mazes.astype(np.float32) 42 | goal_maps = goal_maps.astype(np.float32) 43 | opt_policies = opt_policies.astype(np.float32) 44 | 45 | # Print number of samples 46 | if self.dataset_type == "train": 47 | print("Number of Train Samples: {0}".format(mazes.shape[0])) 48 | elif self.dataset_type == "valid": 49 | print("Number of Validation Samples: {0}".format(mazes.shape[0])) 50 | else: 51 | print("Number of Test Samples: {0}".format(mazes.shape[0])) 52 | print("\tSize: {}x{}".format(mazes.shape[1], mazes.shape[2])) 53 | return mazes, goal_maps, opt_policies 54 | 55 | def __getitem__(self, index): 56 | maze = self.mazes[index] 57 | goal_map = self.goal_maps[index] 58 | opt_policy = self.opt_policies[index] 59 | 60 | maze = torch.from_numpy(maze) 61 | goal_map = torch.from_numpy(goal_map) 62 | opt_policy = torch.from_numpy(opt_policy) 63 | 64 | return maze, goal_map, opt_policy 65 | 66 | def __len__(self): 67 | return self.mazes.shape[0] 68 | 69 | 70 | def generate_maze(maze_size, decimation, start_pos=(1, 1)): 71 | maze = np.zeros((maze_size, maze_size)) 72 | 73 | stack = [((start_pos[0], start_pos[1]), (0, 0))] 74 | 75 | def add_stack(next_pos, next_dir): 76 | if (next_pos[0] < 0) or 
(next_pos[0] >= maze_size): 77 | return 78 | if (next_pos[1] < 0) or (next_pos[1] >= maze_size): 79 | return 80 | if maze[next_pos[0]][next_pos[1]] == 0.: 81 | stack.append((next_pos, next_dir)) 82 | 83 | while stack: 84 | pos, prev_dir = stack.pop() 85 | # Has this not been filled since being added? 86 | if maze[pos[0]][pos[1]] == 1.: 87 | continue 88 | 89 | # Fill in this point + break down wall from previous position 90 | maze[pos[0]][pos[1]] = 1. 91 | maze[pos[0] - prev_dir[0]][pos[1] - prev_dir[1]] = 1. 92 | 93 | choices = [] 94 | choices.append(((pos[0] - 2, pos[1]), (-1, 0))) 95 | choices.append(((pos[0], pos[1] + 2), (0, 1))) 96 | choices.append(((pos[0], pos[1] - 2), (0, -1))) 97 | choices.append(((pos[0] + 2, pos[1]), (1, 0))) 98 | 99 | perm = np.random.permutation(np.array(range(4))) 100 | for i in range(4): 101 | choice = choices[perm[i]] 102 | add_stack(choice[0], choice[1]) 103 | 104 | for y in range(1, maze_size - 1): 105 | for x in range(1, maze_size - 1): 106 | if np.random.uniform() < decimation: 107 | maze[y][x] = 1. 108 | 109 | return maze 110 | 111 | 112 | class RandomMaze: 113 | 114 | def __init__(self, 115 | mechanism, 116 | min_maze_size, 117 | max_maze_size, 118 | min_decimation, 119 | max_decimation, 120 | start_pos=(1, 1)): 121 | self.mechanism = mechanism 122 | self.min_maze_size = min_maze_size 123 | self.max_maze_size = max_maze_size 124 | self.min_decimation = min_decimation 125 | self.max_decimation = max_decimation 126 | self.start_pos = start_pos 127 | 128 | def _isGoalPos(self, pos): 129 | """Returns true if pos is equal to the goal position.""" 130 | return pos[0] == self.goal_pos[0] and pos[1] == self.goal_pos[1] 131 | 132 | def _getState(self): 133 | """Returns the current state.""" 134 | goal_map = np.zeros((self.mechanism.num_orient, self.maze_size, 135 | self.maze_size)) 136 | goal_map[self.goal_orient, self.goal_pos[0], self.goal_pos[1]] = 1. 137 | 138 | player_map = np.zeros((self.mechanism.num_orient, self.maze_size, 139 | self.maze_size)) 140 | player_map[self.player_orient, self.player_pos[0], 141 | self.player_pos[1]] = 1. 142 | 143 | # Check if agent has reached the goal state 144 | reward = 0 145 | terminal = False 146 | if (self.player_orient == self.goal_orient) and self._isGoalPos( 147 | self.player_pos): 148 | reward = 1 149 | terminal = True 150 | 151 | return np.copy(self.maze), player_map, goal_map, reward, terminal 152 | 153 | def reset(self): 154 | """Resets the maze.""" 155 | if self.min_maze_size == self.max_maze_size: 156 | self.maze_size = self.min_maze_size 157 | else: 158 | self.maze_size = self.min_maze_size + 2 * npr.randint( 159 | math.floor((self.max_maze_size - self.min_maze_size) / 2)) 160 | if self.min_decimation == self.max_decimation: 161 | self.decimation = self.min_decimation 162 | else: 163 | self.decimation = npr.uniform(self.min_decimation, 164 | self.max_decimation) 165 | self.maze = generate_maze( 166 | self.maze_size, self.decimation, start_pos=self.start_pos) 167 | 168 | # Randomly sample a goal location 169 | self.goal_pos = (npr.randint(1, self.maze_size - 1), 170 | npr.randint(1, self.maze_size - 1)) 171 | while self._isGoalPos(self.start_pos): 172 | self.goal_pos = (npr.randint(1, self.maze_size - 1), 173 | npr.randint(1, self.maze_size - 1)) 174 | self.goal_orient = npr.randint(self.mechanism.num_orient) 175 | 176 | # Free the maze at the goal location 177 | self.maze[self.goal_pos[0]][self.goal_pos[1]] = 1. 
178 | 179 | # Player start position 180 | self.player_pos = (self.start_pos[0], self.start_pos[1]) 181 | 182 | # Sample player orientation 183 | self.player_orient = npr.randint(self.mechanism.num_orient) 184 | 185 | screen, player_map, goal_map, _, _ = self._getState() 186 | return screen, player_map, goal_map 187 | 188 | def step(self, action): 189 | # Compute neighbors for the current state. 190 | neighbors = self.neighbors_func(self.maze, self.player_orient, 191 | self.player_pos[0], self.player_pos[1]) 192 | assert (action > 0) and (action < len(neighbors)) 193 | self.player_orient, self.player_pos[0], self.player_pos[1] = neighbors[ 194 | action] 195 | return self._getState() 196 | 197 | 198 | def extract_policy(maze, mechanism, value): 199 | """Extracts the policy from the given values.""" 200 | policy = np.zeros((mechanism.num_actions, value.shape[0], value.shape[1], 201 | value.shape[2])) 202 | for p_orient in range(value.shape[0]): 203 | for p_y in range(value.shape[1]): 204 | for p_x in range(value.shape[2]): 205 | # Find the neighbor w/ max value (assuming deterministic 206 | # transitions) 207 | max_val = -sys.maxsize 208 | max_acts = [0] 209 | neighbors = mechanism.neighbors_func(maze, p_orient, p_y, p_x) 210 | for i in range(len(neighbors)): 211 | n = neighbors[i] 212 | nval = value[n[0]][n[1]][n[2]] 213 | if nval > max_val: 214 | max_val = nval 215 | max_acts = [i] 216 | elif nval == max_val: 217 | max_acts.append(i) 218 | 219 | # Choose max actions if several w/ same value 220 | max_act = max_acts[np.random.randint(len(max_acts))] 221 | policy[max_act][p_orient][p_y][p_x] = 1. 222 | return policy 223 | 224 | 225 | def extract_goal(goal_map): 226 | """Returns the goal location.""" 227 | for o in range(goal_map.shape[0]): 228 | for y in range(goal_map.shape[1]): 229 | for x in range(goal_map.shape[2]): 230 | if goal_map[o][y][x] == 1.: 231 | return (o, y, x) 232 | assert False 233 | -------------------------------------------------------------------------------- /utils/mechanism.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | import abc 4 | 5 | 6 | class Mechanism(abc.ABC): 7 | """Base class for maze transition mechanisms.""" 8 | 9 | def __init__(self, num_actions, num_orient): 10 | self.num_actions = num_actions 11 | self.num_orient = num_orient 12 | 13 | @abc.abstractmethod 14 | def neighbors_func(self, maze, p_orient, p_y, p_x): 15 | """Computes next states for each action.""" 16 | 17 | @abc.abstractmethod 18 | def invneighbors_func(self, maze, p_orient, p_y, p_x): 19 | """Computes previous states for each action.""" 20 | 21 | @abc.abstractmethod 22 | def print_policy(self, maze, goal, policy): 23 | """Prints the given policy.""" 24 | 25 | 26 | class DifferentialDrive(Mechanism): 27 | """ 28 | In Differential Drive, the agent can move forward along its current 29 | orientation, or turn left/right by 90 degrees. 
30 | """ 31 | 32 | def __init__(self): 33 | super(DifferentialDrive, self).__init__(num_actions=3, num_orient=4) 34 | self.clockwise = [1, 3, 0, 2] # E S N W 35 | self.cclockwise = [2, 0, 3, 1] # W N S E 36 | 37 | def _is_out_of_bounds(self, maze, p_y, p_x): 38 | return (p_x < 0 or p_x >= maze.shape[1] or p_y < 0 or 39 | p_y >= maze.shape[0]) 40 | 41 | def _forward(self, maze, p_orient, p_y, p_x): 42 | assert p_orient < self.num_orient, p_orient 43 | 44 | next_p_y, next_p_x = p_y, p_x 45 | if p_orient == 0: # North 46 | next_p_y -= 1 47 | elif p_orient == 1: # East 48 | next_p_x += 1 49 | elif p_orient == 2: # West 50 | next_p_x -= 1 51 | else: # South 52 | next_p_y += 1 53 | 54 | # If position is out of bounds, simply return the current state. 55 | if (self._is_out_of_bounds(maze, next_p_y, next_p_x) or 56 | maze[p_y][p_x] == 0.): 57 | return p_orient, p_y, p_x 58 | 59 | return p_orient, next_p_y, next_p_x 60 | 61 | def _turnright(self, p_orient, p_y, p_x): 62 | assert p_orient < self.num_orient, p_orient 63 | return self.clockwise[p_orient], p_y, p_x 64 | 65 | def _turnleft(self, p_orient, p_y, p_x): 66 | assert p_orient < self.num_orient, p_orient 67 | return self.cclockwise[p_orient], p_y, p_x 68 | 69 | def _backward(self, maze, p_orient, p_y, p_x): 70 | assert p_orient < self.num_orient, p_orient 71 | 72 | next_p_y, next_p_x = p_y, p_x 73 | if p_orient == 0: # North 74 | next_p_y += 1 75 | elif p_orient == 1: # East 76 | next_p_x -= 1 77 | elif p_orient == 2: # West 78 | next_p_x += 1 79 | else: # South 80 | next_p_y -= 1 81 | 82 | # If position is out of bounds, simply return the current state. 83 | if (self._is_out_of_bounds(maze, next_p_y, next_p_x) or 84 | maze[p_y][p_x] == 0.): 85 | return p_orient, p_y, p_x 86 | 87 | return p_orient, next_p_y, next_p_x 88 | 89 | def neighbors_func(self, maze, p_orient, p_y, p_x): 90 | return [ 91 | self._forward(maze, p_orient, p_y, p_x), 92 | self._turnright(p_orient, p_y, p_x), 93 | self._turnleft(p_orient, p_y, p_x), 94 | ] 95 | 96 | def invneighbors_func(self, maze, p_orient, p_y, p_x): 97 | return [ 98 | self._backward(maze, p_orient, p_y, p_x), 99 | self._turnleft(p_orient, p_y, p_x), 100 | self._turnright(p_orient, p_y, p_x), 101 | ] 102 | 103 | def print_policy(self, maze, goal, policy): 104 | orient2str = ["↑", "→", "←", "↓"] 105 | action2str = ["F", "R", "L"] 106 | for o in range(self.num_orient): 107 | print(orient2str[o]) 108 | for y in range(policy.shape[1]): 109 | for x in range(policy.shape[2]): 110 | if (o, y, x) == goal: 111 | print("!", end="") 112 | elif maze[y][x] == 0.: 113 | print(u"\u2588", end="") 114 | else: 115 | a = policy[o][y][x] 116 | print(action2str[a], end="") 117 | print("") 118 | 119 | 120 | class NorthEastWestSouth(Mechanism): 121 | """ 122 | In NEWS, the agent can move North, East, West, or South. 
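    Actions are encoded as 0 = North, 1 = East, 2 = West, 3 = South (the
    order returned by neighbors_func below). There is a single orientation,
    and a move into a wall or off the grid leaves the agent in place.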
123 | """ 124 | 125 | def __init__(self): 126 | super(NorthEastWestSouth, self).__init__(num_actions=4, num_orient=1) 127 | 128 | def _north(self, maze, p_orient, p_y, p_x): 129 | if (p_y > 0) and (maze[p_y - 1][p_x] != 0.): 130 | return p_orient, p_y - 1, p_x 131 | return p_orient, p_y, p_x 132 | 133 | def _east(self, maze, p_orient, p_y, p_x): 134 | if (p_x < (maze.shape[1] - 1)) and (maze[p_y][p_x + 1] != 0.): 135 | return p_orient, p_y, p_x + 1 136 | return p_orient, p_y, p_x 137 | 138 | def _west(self, maze, p_orient, p_y, p_x): 139 | if (p_x > 0) and (maze[p_y][p_x - 1] != 0.): 140 | return p_orient, p_y, p_x - 1 141 | return p_orient, p_y, p_x 142 | 143 | def _south(self, maze, p_orient, p_y, p_x): 144 | if (p_y < (maze.shape[0] - 1)) and (maze[p_y + 1][p_x] != 0.): 145 | return p_orient, p_y + 1, p_x 146 | return p_orient, p_y, p_x 147 | 148 | def neighbors_func(self, maze, p_orient, p_y, p_x): 149 | return [ 150 | self._north(maze, p_orient, p_y, p_x), 151 | self._east(maze, p_orient, p_y, p_x), 152 | self._west(maze, p_orient, p_y, p_x), 153 | self._south(maze, p_orient, p_y, p_x), 154 | ] 155 | 156 | def invneighbors_func(self, maze, p_orient, p_y, p_x): 157 | return [ 158 | self._south(maze, p_orient, p_y, p_x), 159 | self._west(maze, p_orient, p_y, p_x), 160 | self._east(maze, p_orient, p_y, p_x), 161 | self._north(maze, p_orient, p_y, p_x), 162 | ] 163 | 164 | def print_policy(self, maze, goal, policy): 165 | action2str = ["↑", "→", "←", "↓"] 166 | for o in range(self.num_orient): 167 | for y in range(policy.shape[1]): 168 | for x in range(policy.shape[2]): 169 | if (o, y, x) == goal: 170 | print("!", end="") 171 | elif maze[y][x] == 0.: 172 | print(u"\u2588", end="") 173 | else: 174 | a = policy[o][y][x] 175 | print(action2str[a], end="") 176 | print("") 177 | print("") 178 | 179 | 180 | class Moore(Mechanism): 181 | """ 182 | In Moore, the agent can move to any of the eight cells in its Moore 183 | neighborhood. 
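    Actions are encoded as 0 = N, 1 = E, 2 = W, 3 = S, 4 = NE, 5 = NW,
    6 = SE, 7 = SW (the order returned by neighbors_func below). As in NEWS,
    a move into a wall or off the grid leaves the agent in place.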
184 | """ 185 | 186 | def __init__(self): 187 | super(Moore, self).__init__(num_actions=8, num_orient=1) 188 | 189 | def _north(self, maze, p_orient, p_y, p_x): 190 | if (p_y > 0) and (maze[p_y - 1][p_x] != 0.): 191 | return p_orient, p_y - 1, p_x 192 | return p_orient, p_y, p_x 193 | 194 | def _east(self, maze, p_orient, p_y, p_x): 195 | if (p_x < (maze.shape[1] - 1)) and (maze[p_y][p_x + 1] != 0.): 196 | return p_orient, p_y, p_x + 1 197 | return p_orient, p_y, p_x 198 | 199 | def _west(self, maze, p_orient, p_y, p_x): 200 | if (p_x > 0) and (maze[p_y][p_x - 1] != 0.): 201 | return p_orient, p_y, p_x - 1 202 | return p_orient, p_y, p_x 203 | 204 | def _south(self, maze, p_orient, p_y, p_x): 205 | if (p_y < (maze.shape[0] - 1)) and (maze[p_y + 1][p_x] != 0.): 206 | return p_orient, p_y + 1, p_x 207 | return p_orient, p_y, p_x 208 | 209 | def _northeast(self, maze, p_orient, p_y, p_x): 210 | if (p_y > 0) and (p_x < (maze.shape[1] - 1)) and (maze[p_y - 1][p_x + 1] 211 | != 0.): 212 | return p_orient, p_y - 1, p_x + 1 213 | return p_orient, p_y, p_x 214 | 215 | def _northwest(self, maze, p_orient, p_y, p_x): 216 | if (p_y > 0) and (p_x > 0) and (maze[p_y - 1][p_x - 1] != 0.): 217 | return p_orient, p_y - 1, p_x - 1 218 | return p_orient, p_y, p_x 219 | 220 | def _southeast(self, maze, p_orient, p_y, p_x): 221 | if (p_y < (maze.shape[0] - 1)) and (p_x < (maze.shape[1] - 1)) and ( 222 | maze[p_y + 1][p_x + 1] != 0.): 223 | return p_orient, p_y + 1, p_x + 1 224 | return p_orient, p_y, p_x 225 | 226 | def _southwest(self, maze, p_orient, p_y, p_x): 227 | if (p_y < (maze.shape[0] - 1)) and (p_x > 0) and (maze[p_y + 1][p_x - 1] 228 | != 0.): 229 | return p_orient, p_y + 1, p_x - 1 230 | return p_orient, p_y, p_x 231 | 232 | def neighbors_func(self, maze, p_orient, p_y, p_x): 233 | return [ 234 | self._north(maze, p_orient, p_y, p_x), 235 | self._east(maze, p_orient, p_y, p_x), 236 | self._west(maze, p_orient, p_y, p_x), 237 | self._south(maze, p_orient, p_y, p_x), 238 | self._northeast(maze, p_orient, p_y, p_x), 239 | self._northwest(maze, p_orient, p_y, p_x), 240 | self._southeast(maze, p_orient, p_y, p_x), 241 | self._southwest(maze, p_orient, p_y, p_x), 242 | ] 243 | 244 | def invneighbors_func(self, maze, p_orient, p_y, p_x): 245 | return [ 246 | self._south(maze, p_orient, p_y, p_x), 247 | self._west(maze, p_orient, p_y, p_x), 248 | self._east(maze, p_orient, p_y, p_x), 249 | self._north(maze, p_orient, p_y, p_x), 250 | self._southwest(maze, p_orient, p_y, p_x), 251 | self._southeast(maze, p_orient, p_y, p_x), 252 | self._northwest(maze, p_orient, p_y, p_x), 253 | self._northeast(maze, p_orient, p_y, p_x), 254 | ] 255 | 256 | def print_policy(self, maze, goal, policy): 257 | action2str = ["↑", "→", "←", "↓", "↗", "↖", "↘", "↙"] 258 | for o in range(self.num_orient): 259 | for y in range(policy.shape[1]): 260 | for x in range(policy.shape[2]): 261 | if (o, y, x) == goal: 262 | print("!", end="") 263 | elif maze[y][x] == 0.: 264 | print(u"\u2588", end="") 265 | else: 266 | a = policy[o][y][x] 267 | print(action2str[a], end="") 268 | print("") 269 | print("") 270 | -------------------------------------------------------------------------------- /utils/runner.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | 9 | from utils.dijkstra import dijkstra_dist, dijkstra_policy 10 | from utils.maze import 
extract_goal 11 | 12 | 13 | def get_optimizer(args, parameters): 14 | if args.optimizer == "RMSprop": 15 | return optim.RMSprop(parameters, lr=args.lr, eps=args.eps) 16 | elif args.optimizer == "Adam": 17 | return optim.Adam(parameters, lr=args.lr, eps=args.eps) 18 | elif args.optimizer == "SGD": 19 | return optim.SGD(parameters, lr=args.lr, momentum=0.95) 20 | else: 21 | raise ValueError("Unsupported optimizer: %s" % args.optimizer) 22 | 23 | 24 | class Runner(): 25 | """ 26 | The Runner class runs a planner model on a given dataset and records 27 | statistics such as loss, prediction error, % Optimal, and % Success. 28 | """ 29 | 30 | def __init__(self, args, mechanism): 31 | """ 32 | Args: 33 | model (torch.nn.Module): The Planner model 34 | mechanism (utils.mechanism.Mechanism): Environment transition kernel 35 | args (Namespace): Arguments 36 | """ 37 | self.use_gpu = args.use_gpu 38 | self.clip_grad = args.clip_grad 39 | self.lr_decay = args.lr_decay 40 | self.use_percent_successful = args.use_percent_successful 41 | 42 | self.mechanism = mechanism 43 | self.criterion = nn.CrossEntropyLoss() 44 | 45 | # Instantiate the model 46 | model_module = importlib.import_module(args.model) 47 | self.model = model_module.Planner( 48 | mechanism.num_orient, mechanism.num_actions, args) 49 | self.best_model = model_module.Planner( 50 | mechanism.num_orient, mechanism.num_actions, args) 51 | 52 | # Load model from file if provided 53 | if args.load_file != "": 54 | saved_model = torch.load(args.load_file) 55 | if args.load_best: 56 | self.model.load_state_dict(saved_model["best_model"]) 57 | else: 58 | self.model.load_state_dict(saved_model["model"]) 59 | self.best_model.load_state_dict(saved_model["best_model"]) 60 | else: 61 | self.best_model.load_state_dict(self.model.state_dict()) 62 | 63 | # Track the best performing model so far 64 | self.best_metric = 0. 
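        # best_metric holds the highest validation score seen so far
        # (% Success if args.use_percent_successful is set, otherwise
        # % Optimal); best_model is refreshed whenever validation beats it.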
65 | 66 | # Use GPU if available 67 | if self.use_gpu: 68 | self.model = self.model.cuda() 69 | self.best_model = self.best_model.cuda() 70 | 71 | self.optimizer = get_optimizer(args, self.model.parameters()) 72 | 73 | def _compute_stats(self, batch_size, map_design, goal_map, 74 | outputs, predictions, labels, 75 | loss, opt_policy, sample=False): 76 | # Select argmax policy 77 | _, pred_pol = torch.max(outputs, dim=1, keepdim=True) 78 | 79 | # Convert to numpy arrays 80 | map_design = map_design.cpu().data.numpy() 81 | goal_map = goal_map.cpu().data.numpy() 82 | outputs = outputs.cpu().data.numpy() 83 | predictions = predictions.cpu().data.numpy() 84 | labels = labels.cpu().data.numpy() 85 | opt_policy = opt_policy.cpu().data.numpy() 86 | pred_pol = pred_pol.cpu().data.numpy() 87 | 88 | max_pred = (predictions == predictions.max(axis=1)[:, None]).astype( 89 | np.float32) 90 | match_action = np.sum((max_pred != opt_policy).astype(np.float32), axis=1) 91 | match_action = (match_action == 0).astype(np.float32) 92 | match_action = np.reshape(match_action, (batch_size, -1)) 93 | batch_error = 1 - np.mean(match_action) 94 | 95 | def calc_optimal_and_success(i): 96 | # Get current sample 97 | md = map_design[i][0] 98 | gm = goal_map[i] 99 | op = opt_policy[i] 100 | pp = pred_pol[i][0] 101 | ll = labels[i][0] 102 | 103 | # Extract the goal in 2D coordinates 104 | goal = extract_goal(gm) 105 | 106 | # Check how different the predicted policy is from the optimal one 107 | # in terms of path lengths 108 | pred_dist = dijkstra_policy(md, self.mechanism, goal, pp) 109 | opt_dist = dijkstra_dist(md, self.mechanism, goal) 110 | diff_dist = pred_dist - opt_dist 111 | 112 | wall_dist = np.min(pred_dist) # impossible distance 113 | 114 | for o in range(self.mechanism.num_orient): 115 | # Refill the walls in the difference with the impossible distance 116 | diff_dist[o] += (1 - md) * wall_dist 117 | 118 | # Mask out the walls in the prediction distances 119 | pred_dist[o] = pred_dist[o] - np.multiply(1 - md, pred_dist[o]) 120 | 121 | num_open = md.sum() * self.mechanism.num_orient # number of reachable locations 122 | return (diff_dist == 0).sum() / num_open, 1. - ( 123 | pred_dist == wall_dist).sum() / num_open 124 | 125 | if sample: 126 | percent_optimal, percent_successful = calc_optimal_and_success( 127 | np.random.randint(batch_size)) 128 | else: 129 | percent_optimal, percent_successful = 0, 0 130 | for i in range(batch_size): 131 | po, ps = calc_optimal_and_success(i) 132 | percent_optimal += po 133 | percent_successful += ps 134 | percent_optimal = percent_optimal / float(batch_size) 135 | percent_successful = percent_successful / float(batch_size) 136 | 137 | return loss.data.item(), batch_error, percent_optimal, percent_successful 138 | 139 | def _run(self, model, dataloader, train=False, batch_size=-1, 140 | store_best=False): 141 | """ 142 | Runs the model on the given data. 
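        For each batch, the planner's per-action logits are scored with
        cross-entropy against the optimal policy, and % Optimal / % Success
        are estimated by comparing path lengths under the predicted policy
        with the true shortest-path distances (see _compute_stats). When
        train=True, only one random sample per batch is evaluated this way
        to keep epochs fast.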
143 | Args: 144 | model (torch.nn.Module): The Planner model 145 | dataloader (torch.utils.data.Dataset): Dataset loader 146 | train (bool): Whether to train the model 147 | batch_size (int): Only used if train=True 148 | store_best (bool): Whether to store the best model 149 | Returns: 150 | info (dict): Performance statistics, including 151 | info["avg_loss"] (float): Average loss 152 | info["avg_error"] (float): Average error 153 | info["avg_optimal"] (float): Average % Optimal 154 | info["avg_success"] (float): Average % Success 155 | info["weight_norm"] (float): Model weight norm, stored if train=True 156 | info["grad_norm"]: Gradient norm, stored if train=True 157 | info["is_best"] (bool): Whether the model is best, stored if store_best=True 158 | """ 159 | info = {} 160 | for key in ["avg_loss", "avg_error", "avg_optimal", "avg_success"]: 161 | info[key] = 0.0 162 | num_batches = 0 163 | 164 | for i, data in enumerate(dataloader): 165 | # Get input batch. 166 | map_design, goal_map, opt_policy = data 167 | 168 | if train: 169 | if map_design.size()[0] != batch_size: 170 | continue # Drop those data, if not enough for a batch 171 | self.optimizer.zero_grad() # Zero the parameter gradients 172 | else: 173 | batch_size = map_design.size()[0] 174 | 175 | # Send tensor to GPU if available 176 | if self.use_gpu: 177 | map_design = map_design.cuda() 178 | goal_map = goal_map.cuda() 179 | opt_policy = opt_policy.cuda() 180 | map_design = Variable(map_design) 181 | goal_map = Variable(goal_map) 182 | opt_policy = Variable(opt_policy) 183 | 184 | # Reshape batch-wise if necessary 185 | if map_design.dim() == 3: 186 | map_design = map_design.unsqueeze(1) 187 | 188 | # Forward pass 189 | outputs, predictions, _, _ = model(map_design, goal_map) 190 | 191 | # Loss 192 | flat_outputs = outputs.transpose(1, 4).contiguous() 193 | flat_outputs = flat_outputs.view( 194 | -1, flat_outputs.size()[-1]).contiguous() 195 | _, labels = opt_policy.max(1, keepdim=True) 196 | flat_labels = labels.transpose(1, 4).contiguous() 197 | flat_labels = flat_labels.view(-1).contiguous() 198 | loss = self.criterion(flat_outputs, flat_labels) 199 | 200 | # Select actions with max scores (logits) 201 | _, predicted = torch.max(outputs, dim=1, keepdim=True) 202 | 203 | # Update parameters 204 | if train: 205 | # Backward pass 206 | loss.backward() 207 | 208 | # Clip the gradient norm 209 | if self.clip_grad: 210 | torch.nn.utils.clip_grad_norm_(model.parameters(), 211 | self.clip_grad) 212 | 213 | # Update parameters 214 | self.optimizer.step() 215 | 216 | # Compute loss and error 217 | loss_batch, batch_error, p_opt, p_suc = self._compute_stats( 218 | batch_size, map_design, goal_map, 219 | outputs, predictions, labels, 220 | loss, opt_policy, sample=train) 221 | info["avg_loss"] += loss_batch 222 | info["avg_error"] += batch_error 223 | info["avg_optimal"] += p_opt 224 | info["avg_success"] += p_suc 225 | num_batches += 1 226 | 227 | info["avg_loss"] = info["avg_loss"] / num_batches 228 | info["avg_error"] = info["avg_error"] / num_batches 229 | info["avg_optimal"] = info["avg_optimal"] / num_batches 230 | info["avg_success"] = info["avg_success"] / num_batches 231 | 232 | if train: 233 | # Calculate weight norm 234 | weight_norm = 0 235 | grad_norm = 0 236 | for p in model.parameters(): 237 | weight_norm += torch.norm(p)**2 238 | if p.grad is not None: 239 | grad_norm += torch.norm(p.grad)**2 240 | info["weight_norm"] = float(np.sqrt(weight_norm.cpu().data.numpy().item())) 241 | info["grad_norm"] = 
float(np.sqrt(grad_norm.cpu().data.numpy().item())) 242 | 243 | if store_best: 244 | # Was the validation accuracy greater than the best one? 245 | metric = (info["avg_success"] if self.use_percent_successful else 246 | info["avg_optimal"]) 247 | if metric > self.best_metric: 248 | self.best_metric = metric 249 | self.best_model.load_state_dict(model.state_dict()) 250 | info["is_best"] = True 251 | else: 252 | for param_group in self.optimizer.param_groups: 253 | param_group["lr"] = param_group["lr"] * self.lr_decay 254 | info["is_best"] = False 255 | return info 256 | 257 | def train(self, dataloader, batch_size): 258 | """ 259 | Trains the model on the given training dataset. 260 | """ 261 | return self._run(self.model, dataloader, train=True, 262 | batch_size=batch_size) 263 | 264 | def validate(self, dataloader): 265 | """ 266 | Evaluates the model on the given validation dataset. Stores the 267 | current model if it achieves the best validation performance. 268 | """ 269 | return self._run(self.model, dataloader, store_best=True) 270 | 271 | def test(self, dataloader, use_best=False): 272 | """ 273 | Tests the model on the given dataset. 274 | """ 275 | if use_best: 276 | model = self.best_model 277 | else: 278 | model = self.model 279 | model.eval() 280 | return self._run(model, dataloader, store_best=True) 281 | --------------------------------------------------------------------------------
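The utilities above are normally consumed through `generate_dataset.py`, `train.py`, and `eval.py`, but they can also be exercised directly. Below is a minimal, illustrative sketch (not part of the repository) showing how a transition mechanism and `RandomMaze` fit together. It assumes the snippet is run from the repository root so that `utils` is importable, and the printed shapes correspond to a `9x9` maze with the NEWS mechanism.

```
# Illustrative sketch only -- not part of the codebase.
# Assumes it is run from the repository root so `utils` is importable.
from utils.mechanism import NorthEastWestSouth
from utils.maze import RandomMaze, extract_goal

mechanism = NorthEastWestSouth()            # num_actions=4, num_orient=1
env = RandomMaze(mechanism,
                 min_maze_size=9, max_maze_size=9,
                 min_decimation=0.0, max_decimation=1.0)

# reset() samples a maze, a goal cell/orientation, and a player orientation.
maze, player_map, goal_map = env.reset()
print(maze.shape)                          # (9, 9); 1. = free cell, 0. = wall
print(player_map.shape, goal_map.shape)    # (1, 9, 9) one-hot maps
print(extract_goal(goal_map))              # (orientation, y, x) of the goal
```

From such a sample, `utils.dijkstra.dijkstra_dist` together with `extract_policy` above can be used to produce ground-truth policy targets of the kind the Runner trains against.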