├── .gitmodules
├── .idea
├── E2C.iml
├── encodings.xml
├── misc.xml
├── modules.xml
├── vcs.xml
└── workspace.xml
├── .ipynb_checkpoints
└── test_plot-checkpoint.ipynb
├── README.md
├── __pycache__
├── datasets.cpython-36.pyc
├── e2c_model.cpython-36.pyc
├── networks.cpython-36.pyc
├── normal.cpython-36.pyc
└── train_e2c.cpython-36.pyc
├── data
├── __pycache__
│ ├── sample_cartpole_data.cpython-36.pyc
│ ├── sample_pendulum_data.cpython-36.pyc
│ ├── sample_planar_data.cpython-36.pyc
│ ├── sample_planar_data_2.cpython-36.pyc
│ └── sample_planar_partial.cpython-36.pyc
├── draw_planar_env.py
├── env.npy
├── env.png
├── sample_cartpole_data.py
├── sample_pendulum_data.py
└── sample_planar.py
├── datasets.py
├── draw_latent_map.ipynb
├── e2c.yml
├── e2c_model.py
├── evaluate_saved_model.ipynb
├── map.png
├── networks.py
├── normal.py
├── test_plot.ipynb
└── train_e2c.py
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "gym"]
2 | path = gym
3 | url = https://github.com/ethanluoyc/gym/tree/pendulum_internal
4 |
--------------------------------------------------------------------------------
/.idea/E2C.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/.idea/encodings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
120 |
121 |
122 |
123 | im
124 | exp
125 | sigma
126 | angle_normalize
127 | width
128 | plt
129 | render
130 | print
131 | env
132 | load
133 | save
134 | ten
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 |
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 |
297 |
298 |
299 |
300 |
301 |
302 |
303 |
304 |
305 |
306 |
307 |
308 |
309 |
310 |
311 |
312 |
313 |
314 |
315 |
316 |
317 |
318 |
319 |
320 |
321 |
322 |
323 |
324 |
325 |
326 |
327 |
328 |
329 |
330 |
331 |
332 |
333 |
334 |
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 |
343 |
344 |
345 |
346 |
347 |
348 |
349 | 1568356268398
350 |
351 |
352 | 1568356268398
353 |
354 |
355 |
356 |
357 |
358 |
359 |
360 |
361 |
362 |
363 |
364 |
365 |
366 |
367 |
368 |
369 |
370 |
371 |
372 |
373 |
374 |
375 |
376 |
377 |
378 |
379 |
380 |
381 |
382 |
383 |
384 |
385 |
386 |
387 |
388 |
389 |
390 |
391 |
392 |
393 |
394 |
395 |
396 |
397 |
398 |
399 |
400 |
401 |
402 |
403 |
404 |
405 |
406 |
407 |
408 |
409 |
410 |
411 |
412 |
413 |
414 |
415 |
416 |
417 |
418 |
419 |
420 |
421 |
422 |
423 |
424 |
425 |
426 |
427 |
428 |
429 |
430 |
431 |
432 |
433 |
434 |
435 |
436 |
437 |
438 |
439 |
440 |
441 |
442 |
443 |
444 |
445 |
446 |
447 |
448 |
449 |
450 |
451 |
452 |
453 |
454 |
455 |
456 |
457 |
458 |
459 |
460 |
461 |
462 |
463 |
464 |
465 |
466 |
467 |
468 |
469 |
470 |
471 |
472 |
473 |
474 |
475 |
476 |
477 |
478 |
479 |
480 |
481 |
482 |
483 |
484 |
485 |
486 |
487 |
488 |
489 |
490 |
491 |
492 |
493 |
494 |
495 |
496 |
497 |
498 |
499 |
500 |
501 |
502 |
503 |
504 |
505 |
506 |
507 |
508 |
509 |
510 |
511 |
512 |
513 |
514 |
515 |
516 |
517 |
518 |
519 |
520 |
521 |
522 |
523 |
524 |
525 |
526 |
527 |
528 |
529 |
530 |
531 |
532 |
533 |
534 |
535 |
536 |
537 |
538 |
539 |
540 |
541 |
542 |
543 |
544 |
545 |
546 |
547 |
548 |
549 |
550 |
551 |
552 |
553 |
554 |
555 |
556 |
557 |
558 |
559 |
560 |
561 |
562 |
563 |
564 |
565 |
566 |
567 |
568 |
569 |
570 |
571 |
572 |
573 |
574 |
575 |
576 |
577 |
578 |
579 |
580 |
581 |
582 |
583 |
584 |
585 |
586 |
587 |
588 |
589 |
590 |
591 |
592 |
593 |
594 |
595 |
596 |
597 |
598 |
599 |
600 |
601 |
602 |
603 |
604 |
605 |
606 |
607 |
608 |
609 |
610 |
611 |
612 |
613 |
614 |
615 |
616 |
617 |
618 |
619 |
620 |
621 |
622 |
623 |
624 |
625 |
626 |
627 |
628 |
629 |
630 |
631 |
632 |
633 |
634 |
635 |
636 |
--------------------------------------------------------------------------------
/.ipynb_checkpoints/test_plot-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from tensorboardX import SummaryWriter\n",
10 | "import matplotlib.pyplot as plt\n",
11 | "import numpy as np"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 28,
17 | "metadata": {},
18 | "outputs": [
19 | {
20 | "data": {
21 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAEYCAYAAAAJeGK1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAH6UlEQVR4nO3dP2iV9x7H8e8xKZIhpAm6CNLQTqV26VJEl6CTLoIFoZOUbiXgFAdt6FTopINgugmC4K6LiAQ0g5DSRqhDh7ZKCRSV1j+hU3K63tt0OdyvN597+3qNh8fP8yMS3nnO8gyGw2EBQJpdO30AAPg7AgVAJIECIJJAARBJoACINL7TB4B/mj179gxnZ2d3+hgQ45tvvnk6HA73/vVzgYL/stnZ2VpdXd3pY0CMwWDw6O8+9xUfAJEECoBIAgVAJIECIJJAARBJoACIJFAARBIoACIJFACRBAqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEEigAIgkUAJEECoBIAgVAJIECIJJAARBJoACIJFAARBIoACKNj3LxxMTEcHJysu3mMzMzbVuPHz9u26qqevfdd1v3/vjjj7at3bt3t21VVf3000+te/v372/dW19fb9uamJho2/rtt99qY2Nj0DYI/JuRAjU5OVknT55su/nHH3/ctvXZZ5+1bVVVraystO49fPiwbevtt99u26qqOn36dOvehQsXWvcWFxfbtt577722rUuXLrVtAdv5ig+ASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEEigAIgkUAJEECoBII71Rd2xsrPU17Z1bR48ebduqqvr9999b93788ce2rRs3brRtVVX9/PPPrXvff/99695XX33VtrW2tta21fn6eGA7T1AARBIoACIJFACRBAqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEEigAIo38yvc333yz7eZffPFF29ZgMGjbqqr69ttvW/eOHTvWtvXOO++0bVVVffDBB617y8vLrXtnzpxp27p9+3bblle+w+vlCQqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEEigAIgkUAJEECoBIAgVAJIECINL4qP9gc3Oz7eZzc3NtW3fv3m3bqqpaX19v3bt161bb1vLycttWVdX58+db95aWllr3On92V69ebdt69uxZ2xawnScoACIJFACRBAqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEEigAIgkUAJEECoBI46NcPDY2VlNTU6/rLP+Rjz76qHXvhx9+aN2bnp5u27p48WLbVlXViRMnWvdu3rzZujc3N9e2dfTo0bata9eutW0B23mCAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEEigAIgkUAJEECoBIAgVAJIECIJJAARBppFe+v3r1qu7du9d280OHDrVtHTlypG2rqur+/fute8eOHWvbWlxcbNuqqlpaWmrdO3v2bOve+++/37Y1Pz/ftvXkyZO2LWA7T1AARBIoACIJFACRBAqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEEigAIgkUAJHGR7l4amqqjh8/3nbzp0+ftm1duXKlbauqamJionXv3LlzkVtVVWtra617MzMzrXud/xdzc3NtW8vLy21bwHaeoACIJFAARBIoACIJFACRBAqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkg
QIgkkABEEmgAIgkUABEEigAIo2PcvHW1lZtbGy03fzgwYNtW9PT021bVVWXL19u3du3b1/b1pdfftm2VVW1sLDQuvf111+37q2urrZt3blzp23rxYsXbVvAdp6gAIgkUABEEigAIgkUAJEECoBIAgVAJIECIJJAARBJoACIJFAARBIoACIJFACRBAqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAijY9y8ebmZj1//rzt5p9//nnb1uHDh9u2qqoePHjQurewsNC2df369batqqq1tbXWvZcvX7buHThwoG1ra2urbWtlZaVtC9jOExQAkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEEigAIgkUAJEECoBIAgVAJIECIJJAARBJoACIJFAARBIoACIJFACRBAqASCO98n3Pnj316aeftt38k08+adsaDAZtW1VVp06dat1744032rY+/PDDtq2qql9++aV1b35+vnXv119/bduanJxs29q1y9938Dr5DQMgkkABEEmgAIgkUABEEigAIgkUAJEECoBIAgVAJIECIJJAARBJoACIJFAARBIoACIJFACRBAqASAIFQCSBAiCSQAEQSaAAiDQ+ysXffffd0+np6Uev6zDwP+atnT4A/D8bKVDD4XDv6zoIAPwrX/EBEEmgAIgkUABEEigAIgkUAJEECoBIAgVAJIECIJJAARBJoACIJFAARBIoACIJFACRBAqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEEigAIgkUAJEECoBIAgVApMFwONzpM8A/ymAweFJVj3b6HBDkreFwuPevHwoUAJF8xQdAJIECIJJAARBJoACIJFAARBIoACIJFACRBAqASAIFQKQ/AdSh4aTL/fw9AAAAAElFTkSuQmCC\n",
22 | "text/plain": [
23 | ""
24 | ]
25 | },
26 | "metadata": {},
27 | "output_type": "display_data"
28 | }
29 | ],
30 | "source": [
31 | "writer = SummaryWriter('test')\n",
32 | "fig, ax = plt.subplots(nrows=1, ncols=2)\n",
33 | "plt.setp(ax, xticks=[], yticks=[])\n",
34 | "# ax[0].set_ylabel('test', rotation=0,size='large')\n",
35 | "ax[0].annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad - pad, 0),\n",
36 | " xycoords=ax.yaxis.label, textcoords='offset points',\n",
37 | " size='large', ha='right', va='center')\n",
38 | "ax[0].imshow(np.random.randn(10,10), cmap='Greys')\n",
39 | "fig.tight_layout()"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 26,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "writer.add_figure('test', fig, 0)"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": []
57 | }
58 | ],
59 | "metadata": {
60 | "kernelspec": {
61 | "display_name": "Python 3",
62 | "language": "python",
63 | "name": "python3"
64 | },
65 | "language_info": {
66 | "codemirror_mode": {
67 | "name": "ipython",
68 | "version": 3
69 | },
70 | "file_extension": ".py",
71 | "mimetype": "text/x-python",
72 | "name": "python",
73 | "nbconvert_exporter": "python",
74 | "pygments_lexer": "ipython3",
75 | "version": "3.6.8"
76 | }
77 | },
78 | "nbformat": 4,
79 | "nbformat_minor": 2
80 | }
81 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Embed to Control
2 |
3 | This is a pytorch implementation of the paper "[Embed to Control: A Locally Linear Latent Dynamics Model for Control from Raw Images](https://arxiv.org/abs/1506.07365)", NIPS, 2015.
4 |
5 | **Note: This is not an official implementation.**
6 |
7 | ### Installing
8 |
9 | First, clone the repository:
10 |
11 | ```
12 | git clone https://github.com/tungnd1705/E2C-pytorch.git
13 | ```
14 |
15 | Install the dependencies listed in `e2c.yml` and activate the environment:
16 |
17 | ```
18 | conda env create -f e2c.yml
19 |
20 | conda activate e2c
21 | ```
22 |
23 | Then install the patched version of gym, which is needed to sample the pendulum data:
24 |
25 | ```
26 | cd gym
27 |
28 | python setup.py install
29 | ```
30 |
31 | ### Simulate training data
32 |
33 | Currently the code supports simulating 3 environments: `planar`, `pendulum` and `cartpole`.
34 |
35 | In order to generate data, simply run `python sample_{env_name}_data.py --sample_size={sample_size}`.
36 |
37 | **Note: the sample size is equal to the total number of training and test data**
38 |
39 | For the planar task, we build on [this](https://github.com/ethanluoyc/e2c-pytorch) implementation and modify it for our needs.
40 |
41 | ### Training
42 |
43 | Run the ``train_e2c.py`` with your own settings. Example:
44 |
45 | ```
46 | python train_e2c.py \
47 | --env=planar \
48 | --propor=3/4 \
49 | --batch_size=128 \
50 | --lr=0.0001 \
51 | --lam=0.25 \
52 | --num_iter=5000 \
53 | --iter_save=1000
54 | ```
55 |
56 | You can visualize the training process by running ``tensorboard --logdir=logs``.
57 |
58 | ### Citation
59 |
60 | If you find E2C useful in your research, please consider citing:
61 |
62 | ```
63 | @inproceedings{watter2015embed,
64 | title={Embed to control: A locally linear latent dynamics model for control from raw images},
65 | author={Watter, Manuel and Springenberg, Jost and Boedecker, Joschka and Riedmiller, Martin},
66 | booktitle={Advances in neural information processing systems},
67 | pages={2746--2754},
68 | year={2015}
69 | }
70 | ```
71 |
--------------------------------------------------------------------------------
/__pycache__/datasets.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tung-nd/E2C-pytorch/8af375b58e1ddef3cd041d518b7061faece2f102/__pycache__/datasets.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/e2c_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tung-nd/E2C-pytorch/8af375b58e1ddef3cd041d518b7061faece2f102/__pycache__/e2c_model.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/networks.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tung-nd/E2C-pytorch/8af375b58e1ddef3cd041d518b7061faece2f102/__pycache__/networks.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/normal.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tung-nd/E2C-pytorch/8af375b58e1ddef3cd041d518b7061faece2f102/__pycache__/normal.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/train_e2c.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tung-nd/E2C-pytorch/8af375b58e1ddef3cd041d518b7061faece2f102/__pycache__/train_e2c.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/sample_cartpole_data.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tung-nd/E2C-pytorch/8af375b58e1ddef3cd041d518b7061faece2f102/data/__pycache__/sample_cartpole_data.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/sample_pendulum_data.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tung-nd/E2C-pytorch/8af375b58e1ddef3cd041d518b7061faece2f102/data/__pycache__/sample_pendulum_data.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/sample_planar_data.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tung-nd/E2C-pytorch/8af375b58e1ddef3cd041d518b7061faece2f102/data/__pycache__/sample_planar_data.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/sample_planar_data_2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tung-nd/E2C-pytorch/8af375b58e1ddef3cd041d518b7061faece2f102/data/__pycache__/sample_planar_data_2.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/sample_planar_partial.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tung-nd/E2C-pytorch/8af375b58e1ddef3cd041d518b7061faece2f102/data/__pycache__/sample_planar_partial.cpython-36.pyc
--------------------------------------------------------------------------------
/data/draw_planar_env.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageDraw
2 | import numpy as np
# Size of the planar world image, in pixels.
width, height = 40, 40
# Nominal obstacle radius; it is truncated with int() before drawing.
r = 2.5
# Obstacle centres, unpacked as (y, x) pairs by generate_env().
obstacles_center = np.array([[20.5, 5.5], [20.5, 12.5], [20.5, 27.5], [20.5, 35.5], [10.5, 20.5], [30.5, 20.5]])


def generate_env(output_dir='.'):
    """Draw the obstacle map and save it as ``env.png`` and ``env.npy``.

    Args:
        output_dir: directory the two files are written to. Defaults to the
            current working directory, which is where the files were always
            written before this parameter existed.

    Returns:
        The obstacle map as a 2-D array scaled to [0, 1] (the 8-bit image
        divided by 255).
    """
    import os  # local import: this script's header only imports PIL and numpy

    print('Making the environment...')
    # Image arrays are indexed (row, column), i.e. (height, width).
    img_arr = np.zeros(shape=(height, width))

    img_env = Image.fromarray(img_arr)
    draw = ImageDraw.Draw(img_env)
    radius = int(r)
    for y, x in obstacles_center:
        cx, cy = int(x), int(y)
        # The bounding box is given as (x0, y0, x1, y1).
        draw.ellipse((cx - radius, cy - radius, cx + radius, cy + radius), fill=255)
    img_env = img_env.convert('L')
    img_env.save(os.path.join(output_dir, 'env.png'))

    img_arr = np.array(img_env) / 255.
    np.save(os.path.join(output_dir, 'env.npy'), img_arr)
    return img_arr


env = generate_env()
--------------------------------------------------------------------------------
/data/env.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tung-nd/E2C-pytorch/8af375b58e1ddef3cd041d518b7061faece2f102/data/env.npy
--------------------------------------------------------------------------------
/data/env.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tung-nd/E2C-pytorch/8af375b58e1ddef3cd041d518b7061faece2f102/data/env.png
--------------------------------------------------------------------------------
/data/sample_cartpole_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 | from tqdm import trange
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import gym
7 | import json
8 | from datetime import datetime
9 | import argparse
10 |
def sample_cartpole(sample_size, output_dir='data/cartpole', step_size=5, apply_control=True, num_shards=10):
    """Sample CartPole transitions, render them and write them to ``output_dir``.

    For each sample:
      - render x_t from the current simulator state
      - draw a random action u_t (or use action 0) and apply it for
        ``step_size`` consecutive steps
      - render x_t+1 after applying u_t

    Images are written into ``num_shards`` sub-directories and the states,
    controls and image paths are collected in ``<output_dir>/data.json``.

    Args:
        sample_size: number of transitions; must be a multiple of ``num_shards``.
        output_dir: root directory of the generated dataset.
        step_size: number of simulator steps the same control is applied for.
        apply_control: if False, the fixed action ``0`` is used instead of a
            randomly sampled one.
        num_shards: number of sub-directories the images are spread over.
    """
    assert sample_size % num_shards == 0, 'sample_size must be a multiple of num_shards'

    env = gym.make('CartPole-v0').env
    os.makedirs(output_dir, exist_ok=True)
    shard_size = sample_size // num_shards

    samples = []
    try:
        state = env.reset()
        for i in trange(sample_size):
            initial_state = state
            before = env.render(mode='rgb_array')

            # apply the same control over a few timesteps
            if apply_control:
                # int() keeps the value JSON-serializable even if the space
                # returns a NumPy integer.
                u = int(env.action_space.sample())
            else:
                # The previous float array np.zeros((1,)) cannot be encoded by
                # json.dump and is presumably not a valid action for
                # CartPole's discrete action space; use the integer 0 instead.
                u = 0

            episode_over = False
            for _ in range(step_size):
                state, reward, done, info = env.step(u)
                episode_over = episode_over or done

            after_state = state
            after = env.render(mode='rgb_array')

            shard_path = '{:03d}-of-{:03d}'.format(i // shard_size, num_shards)
            os.makedirs(path.join(output_dir, shard_path), exist_ok=True)

            before_file = path.join(shard_path, 'before-{:05d}.jpg'.format(i))
            plt.imsave(path.join(output_dir, before_file), before)

            after_file = path.join(shard_path, 'after-{:05d}.jpg'.format(i))
            plt.imsave(path.join(output_dir, after_file), after)

            samples.append({
                'before_state': initial_state.tolist(),
                'after_state': after_state.tolist(),
                'before': before_file,
                'after': after_file,
                'control': [u],
            })

            # Once the episode has terminated, start a new one rather than
            # keep sampling from a terminated simulation.
            if episode_over:
                state = env.reset()

        with open(path.join(output_dir, 'data.json'), 'wt') as outfile:
            json.dump(
                {
                    'metadata': {
                        'num_samples': sample_size,
                        'step_size': step_size,
                        'apply_control': apply_control,
                        'time_created': str(datetime.now()),
                        'version': 1
                    },
                    'samples': samples
                }, outfile, indent=2)
    finally:
        # The viewer only exists once something has been rendered.
        viewer = getattr(env, 'viewer', None)
        if viewer is not None:
            viewer.close()
81 |
def main(args):
    """Sample as many cartpole transitions as ``args.sample_size`` requests."""
    sample_cartpole(sample_size=args.sample_size)


if __name__ == "__main__":
    # command line: --sample_size <int> (mandatory)
    arg_parser = argparse.ArgumentParser(description='sample data')
    arg_parser.add_argument('--sample_size', required=True, type=int, help='the number of samples')
    main(arg_parser.parse_args())
--------------------------------------------------------------------------------
/data/sample_pendulum_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 | from tqdm import trange
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import gym
7 | import json
8 | from datetime import datetime
9 | import argparse
10 |
11 | env = gym.make('Pendulum-v0').env
12 | width, height = 48 * 2, 48
13 |
def render(state):
    """Render ``state`` and its zero-control successor.

    Returns an iterator over two renderings: one for the first component of
    ``state`` and one for the first component of the state obtained by
    stepping from ``state`` with control ``np.array([0])``.
    """
    # need two observations to restore the Markov property
    successor = env.step_from_state(state, np.array([0]))
    first_components = (state[0], successor[0])
    return map(env.render_state, first_components)
19 |
def sample_pendulum(sample_size, output_dir='data/pendulum', step_size=1, apply_control=True, num_shards=10):
    """Sample pendulum transitions, render them and write them to ``output_dir``.

    For each sample:
      - draw a random state (theta, theta dot)
      - render x_t (two images stacked horizontally)
      - draw a random action u_t and apply it for ``step_size`` steps
      - render x_t+1 after applying u_t

    Images are written into ``num_shards`` sub-directories and the states,
    controls and image paths are collected in ``<output_dir>/data.json``.

    Args:
        sample_size: number of transitions; must be a multiple of ``num_shards``.
        output_dir: root directory of the generated dataset.
        step_size: number of simulator steps the same control is applied for.
        apply_control: if False, a zero control is used instead of a random one.
        num_shards: number of sub-directories the images are spread over.
    """
    assert sample_size % num_shards == 0, 'sample_size must be a multiple of num_shards'

    samples = []
    os.makedirs(output_dir, exist_ok=True)
    shard_size = sample_size // num_shards

    try:
        for i in trange(sample_size):
            # th (theta) and thdot (theta dot) represent a state in Pendulum env
            th = np.random.uniform(0, np.pi * 2)
            thdot = np.random.uniform(-8, 8)

            state = np.array([th, thdot])

            initial_state = np.copy(state)
            before1, before2 = render(state)

            # apply the same control over a few timesteps
            if apply_control:
                u = np.random.uniform(-2, 2, size=(1,))
            else:
                u = np.zeros((1,))

            for _ in range(step_size):
                state = env.step_from_state(state, u)

            after_state = np.copy(state)
            after1, after2 = render(state)

            before = np.hstack((before1, before2))
            after = np.hstack((after1, after2))

            shard_path = '{:03d}-of-{:03d}'.format(i // shard_size, num_shards)
            os.makedirs(path.join(output_dir, shard_path), exist_ok=True)

            before_file = path.join(shard_path, 'before-{:05d}.jpg'.format(i))
            plt.imsave(path.join(output_dir, before_file), before)

            after_file = path.join(shard_path, 'after-{:05d}.jpg'.format(i))
            plt.imsave(path.join(output_dir, after_file), after)

            samples.append({
                'before_state': initial_state.tolist(),
                'after_state': after_state.tolist(),
                'before': before_file,
                'after': after_file,
                'control': u.tolist(),
            })

        with open(path.join(output_dir, 'data.json'), 'wt') as outfile:
            json.dump(
                {
                    'metadata': {
                        'num_samples': sample_size,
                        'step_size': step_size,
                        'apply_control': apply_control,
                        'time_created': str(datetime.now()),
                        'version': 1
                    },
                    'samples': samples
                }, outfile, indent=2)
    finally:
        # Guard against a missing viewer (e.g. nothing was rendered because
        # sample_size == 0) and close it even when sampling fails.
        viewer = getattr(env, 'viewer', None)
        if viewer is not None:
            viewer.close()
95 |
def main(args):
    """Sample as many pendulum transitions as ``args.sample_size`` requests."""
    sample_pendulum(sample_size=args.sample_size)


if __name__ == "__main__":
    # command line: --sample_size <int> (mandatory)
    arg_parser = argparse.ArgumentParser(description='sample data')
    arg_parser.add_argument('--sample_size', required=True, type=int, help='the number of samples')
    main(arg_parser.parse_args())
--------------------------------------------------------------------------------
/data/sample_planar.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from os import path
4 | from tqdm import trange
5 | import json
6 | from datetime import datetime
7 | import argparse
8 | from PIL import Image, ImageDraw
9 |
# Seed NumPy's global random generator at import time.
np.random.seed(1)

width, height = 40, 40
obstacles_center = np.array([[20.5, 5.5], [20.5, 12.5], [20.5, 27.5], [20.5, 35.5], [10.5, 20.5], [30.5, 20.5]])

r_overlap = 0.5 # agent cannot be in any rectangular area with obstacles as centers and half-width = 0.5
r = 2.5 # radius of the obstacles when rendered
rw = 3 # robot half-width
rw_rendered = 2 # robot half-width when rendered
max_step_len = 3
# env.npy is loaded from the directory containing this file, independent of
# the current working directory.
env_path = os.path.dirname(os.path.abspath(__file__))
env = np.load(os.path.join(env_path, 'env.npy'))
22 |
def get_pixel_location(s):
    """Return ``(top, bottom, left, right)`` of the agent when rendered.

    Both components of ``s`` are rounded to the nearest integer and extended
    by ``rw_rendered`` in either direction; ``s[0]`` yields top/bottom and
    ``s[1]`` yields left/right.
    """
    centre_first = int(round(s[0]))
    centre_second = int(round(s[1]))
    return (
        centre_first - rw_rendered,
        centre_first + rw_rendered,
        centre_second - rw_rendered,
        centre_second + rw_rendered,
    )
31 |
def render(s):
    """Return a copy of the environment map with the agent drawn at state ``s``."""
    row_start, row_stop, col_start, col_stop = get_pixel_location(s)
    frame = np.copy(env)
    # robot is white on black background
    frame[row_start:row_stop, col_start:col_stop] = 1.
    return frame
37 |
def is_valid(s, u, s_next, epsilon=0.1):
    """Check that the rendered displacement agrees with the action ``u``.

    The displacement is measured between the rendered (pixel) locations of
    ``s`` and ``s_next``. The step is valid when the Euclidean distance
    between that displacement and ``u`` does not exceed ``epsilon``.

    Args:
        s: state before the step.
        u: action that was applied.
        s_next: state after the step.
        epsilon: largest tolerated distance between ``u`` and the rendered
            displacement.

    Returns:
        True if the step is valid, False otherwise.
    """
    top, _, left, _ = get_pixel_location(s)
    top_next, _, left_next, _ = get_pixel_location(s_next)
    # The ``np.float`` alias was removed from NumPy; the builtin ``float``
    # selects the same float64 dtype.
    x_diff = np.array([top_next - top, left_next - left], dtype=float)
    return not np.sqrt(np.sum((x_diff - u) ** 2)) > epsilon
44 |
def is_colliding(s):
    """Tell whether the agent at ``s`` is in an invalid position.

    :param s: the continuous coordinate (x, y) of the agent center
    :return: True if the agent body (half-width ``rw``) leaves the world, or
        if its center lies within ``r_overlap`` of an obstacle center along
        both axes; False otherwise
    """
    x, y = s[0], s[1]
    # Bound each coordinate by the extent it is sampled from in sample():
    # the first by ``height``, the second by ``width``.
    if x - rw < 0 or x + rw > height or y - rw < 0 or y + rw > width:
        return True
    # One row per obstacle: both coordinates must be within r_overlap.
    offsets = np.abs(obstacles_center - np.array([x, y]))
    return bool(np.any(np.all(offsets <= r_overlap, axis=1)))
57 |
def random_step(s):
    """Draw a random step from ``s`` and return ``(u, s_next)``.

    Steps are redrawn until the resulting state does not collide and the
    step passes ``is_valid``.
    """
    while True:
        u = np.random.uniform(low=-max_step_len, high=max_step_len, size=2)
        s_next = s + u
        if is_colliding(s_next):
            # rejected: the new position collides, draw again
            continue
        if is_valid(s, u, s_next):
            return u, s_next
65 |
def sample(sample_size):
    """Sample ``sample_size`` transitions.

    Returns a pair ``(state_samples, obs_samples)``: the first is a list of
    ``(s, u, s_next)`` tuples, the second the same transitions with both
    states replaced by their renderings.
    """
    state_samples = []
    for _ in trange(sample_size, desc='Sampling data'):
        # redraw the start state until it does not collide
        s = None
        while s is None or is_colliding(s):
            s_x = np.random.uniform(low=rw, high=height - rw)
            s_y = np.random.uniform(low=rw, high=width - rw)
            s = np.array([s_x, s_y])
        u, s_next = random_step(s)
        state_samples.append((s, u, s_next))

    obs_samples = []
    for s, u, s_next in state_samples:
        obs_samples.append((render(s), u, render(s_next)))
    return state_samples, obs_samples
82 |
def write_to_file(sample_size, output_dir='./data/planar'):
    """Sample ``sample_size`` transitions and write them to ``output_dir``.

    Each transition produces a ``before-XXXXX.png`` / ``after-XXXXX.png``
    image pair; states, controls and file names are stored in ``data.json``.
    """
    if not path.exists(output_dir):
        os.makedirs(output_dir)

    state_samples, obs_samples = sample(sample_size)

    def save_image(pixels, file_name):
        # images hold values in [0, 1]; scale to 8-bit grayscale
        Image.fromarray(pixels * 255.).convert('L').save(path.join(output_dir, file_name))

    samples = []
    for idx, ((s, _, s_next), (before, u, after)) in enumerate(zip(state_samples, obs_samples)):
        before_file = 'before-{:05d}.png'.format(idx)
        save_image(before, before_file)

        after_file = 'after-{:05d}.png'.format(idx)
        save_image(after, after_file)

        samples.append({
            'before_state': s.tolist(),
            'after_state': s_next.tolist(),
            'before': before_file,
            'after': after_file,
            'control': u.tolist(),
        })

    metadata = {
        'num_samples': sample_size,
        'max_distance': max_step_len,
        'time_created': str(datetime.now()),
        'version': 1
    }
    with open(path.join(output_dir, 'data.json'), 'wt') as outfile:
        json.dump({'metadata': metadata, 'samples': samples}, outfile, indent=2)
123 |
def main(args):
    """Write as many planar transitions to disk as ``args.sample_size`` requests."""
    write_to_file(sample_size=args.sample_size)


if __name__ == "__main__":
    # command line: --sample_size <int> (mandatory)
    arg_parser = argparse.ArgumentParser(description='sample data')
    arg_parser.add_argument('--sample_size', required=True, type=int, help='the number of samples')
    main(arg_parser.parse_args())
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 | from PIL import Image
4 | import numpy as np
5 | import json
6 | from torchvision.transforms import ToTensor
7 | from torch.utils.data import Dataset
8 | from tqdm import tqdm
9 | import pickle
10 | import torch
11 |
12 | torch.set_default_dtype(torch.float64)
13 |
class PlanarDataset(Dataset):
    """Dataset of ``(before image, control, after image)`` planar transitions.

    The samples are described by ``data.json`` inside ``dir``. Processed
    samples are cached in ``processed.pkl`` next to it and reused on later
    runs as long as the cache holds as many entries as ``data.json``.
    """
    # Images are converted to grayscale and resized to (width, height).
    width = 40
    height = 40
    action_dim = 2

    def __init__(self, dir):
        """Load ``data.json`` from ``dir`` and prepare the processed samples."""
        self.dir = dir
        with open(path.join(dir, 'data.json')) as f:
            self._data = json.load(f)
        self._process()

    def __len__(self):
        return len(self._data['samples'])

    def __getitem__(self, index):
        return self._processed[index]

    @staticmethod
    def _process_image(img):
        """Convert a PIL image to a grayscale tensor of the dataset's size."""
        return ToTensor()((img.convert('L').
                           resize((PlanarDataset.width,
                                   PlanarDataset.height))))

    def _load_image(self, file_name):
        """Open ``file_name`` (relative to ``self.dir``) and process it."""
        # The context manager closes the image file deterministically.
        with Image.open(os.path.join(self.dir, file_name)) as img:
            return self._process_image(img)

    def _process(self):
        """Fill ``self._processed`` from the cache, or build and cache it."""
        preprocessed_file = os.path.join(self.dir, 'processed.pkl')
        samples = self._data['samples']

        if os.path.exists(preprocessed_file):
            # NOTE(security): pickle can execute arbitrary code when loading;
            # only use dataset directories that come from a trusted source.
            with open(preprocessed_file, 'rb') as f:
                cached = pickle.load(f)
            # A cache of a different length belongs to another data.json and
            # would make __len__ and __getitem__ disagree, so rebuild it.
            if len(cached) == len(samples):
                self._processed = cached
                return

        processed = []
        for sample in tqdm(samples, desc='processing data'):
            processed.append((self._load_image(sample['before']),
                              np.array(sample['control']),
                              self._load_image(sample['after'])))

        with open(preprocessed_file, 'wb') as f:
            pickle.dump(processed, f)
        self._processed = processed
55 |
class GymPendulumDatasetV2(Dataset):
    """Dataset of ``(before image, control, after image)`` pendulum transitions.

    The samples are described by ``data.json`` inside ``dir``. Processed
    samples are cached in ``processed.pkl`` next to it and reused on later
    runs as long as the cache holds as many entries as ``data.json``.
    """
    # Images are converted to grayscale and resized to (width, height).
    width = 48 * 2
    height = 48
    action_dim = 1

    def __init__(self, dir):
        """Load ``data.json`` from ``dir`` and prepare the processed samples."""
        self.dir = dir
        with open(path.join(dir, 'data.json')) as f:
            self._data = json.load(f)
        self._process()

    def __len__(self):
        return len(self._data['samples'])

    def __getitem__(self, index):
        return self._processed[index]

    @staticmethod
    def _process_image(img):
        """Convert a PIL image to a grayscale tensor of the dataset's size."""
        return ToTensor()((img.convert('L').
                           resize((GymPendulumDatasetV2.width,
                                   GymPendulumDatasetV2.height))))

    def _load_image(self, file_name):
        """Open ``file_name`` (relative to ``self.dir``) and process it."""
        # The context manager closes the image file deterministically.
        with Image.open(os.path.join(self.dir, file_name)) as img:
            return self._process_image(img)

    def _process(self):
        """Fill ``self._processed`` from the cache, or build and cache it."""
        preprocessed_file = os.path.join(self.dir, 'processed.pkl')
        samples = self._data['samples']

        if os.path.exists(preprocessed_file):
            # NOTE(security): pickle can execute arbitrary code when loading;
            # only use dataset directories that come from a trusted source.
            with open(preprocessed_file, 'rb') as f:
                cached = pickle.load(f)
            # A cache of a different length belongs to another data.json and
            # would make __len__ and __getitem__ disagree, so rebuild it.
            if len(cached) == len(samples):
                self._processed = cached
                return

        processed = []
        for sample in tqdm(samples, desc='processing data'):
            processed.append((self._load_image(sample['before']),
                              np.array(sample['control']),
                              self._load_image(sample['after'])))

        with open(preprocessed_file, 'wb') as f:
            pickle.dump(processed, f)
        self._processed = processed
--------------------------------------------------------------------------------
/draw_latent_map.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import matplotlib.pyplot as plt\n",
10 | "from PIL import Image, ImageDraw\n",
11 | "import colour\n",
12 | "from random import randint as rint\n",
13 | "import numpy as np\n",
14 | "from colour import Color\n",
15 | "from data.sample_planar_partial import *"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 2,
21 | "metadata": {},
22 | "outputs": [
23 | {
24 | "data": {
25 | "text/plain": [
26 | "[[10, 20],\n",
27 | " [10, 21],\n",
28 | " [11, 20],\n",
29 | " [11, 21],\n",
30 | " [20, 5],\n",
31 | " [20, 6],\n",
32 | " [20, 12],\n",
33 | " [20, 13],\n",
34 | " [20, 27],\n",
35 | " [20, 28],\n",
36 | " [20, 35],\n",
37 | " [20, 36],\n",
38 | " [21, 5],\n",
39 | " [21, 6],\n",
40 | " [21, 12],\n",
41 | " [21, 13],\n",
42 | " [21, 27],\n",
43 | " [21, 28],\n",
44 | " [21, 35],\n",
45 | " [21, 36],\n",
46 | " [30, 20],\n",
47 | " [30, 21],\n",
48 | " [31, 20],\n",
49 | " [31, 21]]"
50 | ]
51 | },
52 | "execution_count": 2,
53 | "metadata": {},
54 | "output_type": "execute_result"
55 | }
56 | ],
57 | "source": [
58 | "start, end = 3, 37\n",
59 | "unvalid_pos = []\n",
60 | "for x in range(start, end + 1):\n",
61 | " for y in range(start, end + 1):\n",
62 | " s = [x,y]\n",
63 | " if is_colliding(np.array(s)):\n",
64 | " unvalid_pos.append(s)\n",
65 | "unvalid_pos"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 3,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "red = Color('blue')\n",
75 | "# red_rgb = np.array(Color('red').rgb)\n",
76 | "# blue_rgb = np.array(Color('blue').rgb)\n",
77 | "colors = list(red.range_to(Color(\"red\"),36))\n",
78 | "colors_rgb = [color.rgb for color in colors]\n",
79 | "# print (colors_rgb)"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 4,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "def random_gradient():\n",
89 | " img = Image.new(\"RGB\", (width,height), \"#FFFFFF\")\n",
90 | " draw = ImageDraw.Draw(img)\n",
91 | " \n",
92 | " for i, color in zip(range(start, end+1), colors_rgb):\n",
93 | " r1, g1, b1 = color[0] * 255., color[1] * 255., color[2] * 255.\n",
94 | " draw.line((i,start,i,end), fill=(int(r1), int(g1), int(b1)))\n",
95 | " \n",
96 | " # obstacles as white\n",
97 | "# for y, x in obstacles_center:\n",
98 | "# draw.rectangle(((x - r, y - r), (x + r, y + r)), fill=\"white\")\n",
99 | "\n",
100 | " img_arr = np.array(img)\n",
101 | " for x, y in unvalid_pos:\n",
102 | " img_arr[x, y] = 255.\n",
103 | " return img_arr / 255., img"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 5,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "img_arr, img = random_gradient()"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": 6,
118 | "metadata": {
119 | "scrolled": false
120 | },
121 | "outputs": [
122 | {
123 | "data": {
124 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD6CAYAAABnLjEDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAMO0lEQVR4nO3df6jdd33H8edrsbUyrbUkKaVNl07CVhlrBlkpdH900Y6s/6SCQgsbGRSqsIKCiJn/qGNChWn3zxAUu+YPZy1V1zC6H6FWnGPUxhpranSpXdbGhCRFa9o/2pH2vT/ON1tM7mlOzvl+zz33fp4POJxzPufzPef9vfe+7vec7/3e7ztVhaTV79eWuwBJ82HYpUYYdqkRhl1qhGGXGmHYpUbMFPYk25L8JMkzSXb2VZSk/mXav7MnWQP8J3ALcBh4Arijqn40bpm1a9fWxo0bp3o9Sed36NAhXnjhhSz12JtmeN4bgGeq6lmAJA8A24GxYd+4cSN79+6d4SUlvZEtW7aMfWyWt/FXAc+fcf9wNyZpAc0S9qXeKpzzmSDJXUn2Jtl74sSJGV5O0ixmCfthYMMZ968Gjpw9qaq+UFVbqmrLunXrZng5SbOYJexPAJuSXJvkYuB2YHc/ZUnq29Q76KrqVJK7gX8B1gD3VdXTvVUmqVez7I2nqh4BHumpFkkD8gg6qRGGXWqEYZcaYdilRhh2qRGGXWqEYZcaYdilRsx0UM28Zcn/0uVXj9B/o7Gh5o5b/polxta/Mmby8xOOXejc53p/3ks5ueTSs35pl/py9fG8S41dcnzME/T/5epn7nOz93dwyy41wrBLjTDsUiMMu9QIwy41wrBLjTDsUiMMu9QIwy41YqYj6JIcAl4CXgNOVdX4M9RLWlZ9HC77h1X1Qg/PI2lAvo2XGjFr2Av41yTfS3JXHwVJGsasb+NvqqojSdYDe5L8uKq+feaE7pfAXQDXXDPu/5okDW2mLXtVHemujwPfYNTZ9ew5tn+SFsDUYU/y60nedvo28EfA/r4Kk9SvWd7GXwF8I6MzSrwJ+Puq+udeqpLUu1l6vT0LXN9jLZIG5J/epEYYdqkRhl1qhGGXGmHYpUYYdqkRhl1qhGGXGmHYpUYYdqkRhl1qhGGXGmHYpUYYdqkRhl1qhGGXGmHYpUYYdqkRhl1qxHnDnuS+JMeT7D9j7PIke5Ic7K7fMWyZWi7FL8+5nDsyumixTbJlvx/YdtbYTuDRqtoEPNrdl7TAzhv2rsPLz88a3g7s6m7vAm7ruS5JPZv2M/sVVXUUoLteP25ikruS7E2y98SJE1O+nKRZDb6DzvZP0mKYNuzHklwJ0F0f768kSUOYNuy7gR3d7R3Aw/2UI2kok/zp7SvAfwC/leRwkjuBe4BbkhwEbunuS1pg5+31VlV3jHno3T3XImlAHkEnNcKwS42YpT+7GhDefs7YpZxccu6GoYvRTNyyS40w7FIjDLvUCMMuNcKwS40w7FIjDLvUCMMuNcKwS40w7FIjDLvUCMMuNcKwS40w7FIjDLvUiGnbP30yyc+S7Osutw5bpqRZTdv+CeDeqtrcXR7ptyxJfZu2/ZOkFWaWz+x3J3mqe5tvF1dpwU0b9s8D7wQ2A0eBz46baK83aTFMFfaqOlZVr1XV68AXgRveYK693qQFMFXYT/d567wX2D9urqTFcN5TSXftn24G1iY5DHwCuDnJZqCAQ8AHBqxRUg+mbf/0pQFqkTQgj6CTGmHYpUYYdqkRhl1qhGGXGmHYpUYYdqkRhl1qhGGXGmHYpUYYdqkRhl1qhGGXGmHYpUYYdqkRhl1qhGGXGmHYpUZM0v5pQ5LHkhxI8nSSD3XjlyfZk+Rgd+2546UFdt5z0AGngI9U1ZNJ3gZ8L8ke4M+AR6vqniQ7gZ3Ax4Yr9cLUv08+Nx8dro6zFZeMeWTTOSPh+WGLWcG2UhPPfYUMWMmE
Hpi8Xq4Zpt5J2j8draonu9svAQeAq4DtwK5u2i7gtkEqlNSLC/rMnmQj8HvA48AVVXUURr8QgPV9FyepPxOHPclbga8BH66qkxewnO2fpAUwUdiTXMQo6F+uqq93w8dOd4bpro8vtaztn6TFMMne+DBqCnGgqj53xkO7gR3d7R3Aw/2XJ6kvk+yNvwn4U+CHSfZ1Yx8H7gEeTHIn8Bzw/mFKlNSHSdo/fQfG/u3i3f2WI2koHkEnNcKwS40w7FIjJtlBtyLlpjEPbJhwbCDhlTGPeGjshfjmmN1Iy/ztHe/2Jeqd87fcLbvUCMMuNcKwS40w7FIjDLvUCMMuNcKwS40w7FIjDLvUCMMuNcKwS40w7FIjDLvUCMMuNWKW9k+fTPKzJPu6y63DlytpWrO0fwK4t6r+erjyJPVlkhNOHgVOd355Kcnp9k+SVpBZ2j8B3J3kqST32cVVWmyztH/6PPBOYDOjLf9nxyxn+ydpAUzd/qmqjlXVa1X1OvBF4IallrX9k7QYpm7/dLrPW+e9wP7+y5PUl1naP92RZDNQwCHgA4NUKKkXs7R/eqT/ciQNxSPopEYYdqkRhl1qhGGXGmHYpUYYdqkRhl1qhGGXGmHYpUYYdqkRhl1qhGGXGmHYpUYYdqkRhl1qhGGXGmHYpUYYdqkRk5xw8pIk303yg67906e68WuTPJ7kYJKvJrl4+HIlTWuSLfurwNaqup7ROeK3JbkR+Ayj9k+bgF8Adw5XppZL8ctzLueOjC5abOcNe4283N29qLsUsBV4qBvfBdw2SIWSejFpk4g13WmkjwN7gJ8CL1bVqW7KYez/Ji20icLedX7ZDFzNqPPLdUtNW2pZ2z9Ji+GC9sZX1YvAt4AbgcuSnD7v/NXAkTHL2P5JWgCT7I1fl+Sy7vZbgPcAB4DHgPd103YADw9VpKTZTdL+6UpgV5I1jH45PFhV/5jkR8ADSf4K+D6jfnCSFtQk7Z+eYtST/ezxZxnTuVXS4vEIOqkRhl1qhGGXGjHJDjo1LLz9nLFLObnk3A1DF6OZuGWXGmHYpUYYdqkRhl1qhGGXGmHYpUYYdqkRhl1qhGGXGmHYpUYYdqkRhl1qhGGXGmHYpUbM0v7p/iT/lWRfd9k8fLmSpjXJ/7Ofbv/0cpKLgO8k+afusY9W1UNvsKykBTHJCScLWKr9k6QVZKr2T1X1ePfQp5M8leTeJG8erEpJM5uq/VOS3wH+Avht4PeBy4GPLbWs7Z+kxTBt+6dtVXW06/D6KvB3jDmHvO2fpMUwbfunHye5shsLo3bN+4csVNJsZmn/9M0k64AA+4APDlinpBnN0v5p6yAVSRqER9BJjTDsUiMMu9QIwy41wrBLjTDsUiMMu9QIwy41wrBLjZjkcNmFUaviv+gvGTO+acIx9WL9BYxvGbKQ+XHLLjXCsEuNMOxSIwy71AjDLjXCsEuNMOxSIwy71AjDLjXCsEuNSM3xGNQkJ4D/7u6uBV6Y24vPj+u18qymdfuNqlqyQcNcw/4rL5zsrapVctTx/3O9Vp7VvG5n8m281AjDLjViOcP+hWV87SG5XivPal63/7Nsn9klzZdv46VGzD3sSbYl+UmSZ5LsnPfr9ynJfUmOJ9l/xtjlSfYkOdhdv2M5a5xGkg1JHktyIMnTST7Uja/odUtySZLvJvlBt16f6savTfJ4t15fTXLxctc6hLmGvesE+7fAHwPvAu5I8q551tCz+4FtZ43tBB6tqk3Ao939leYU8JGqug64Efjz7vu00tftVWBrVV0PbAa2JbkR+Axwb7devwDuXMYaBzPvLfsNwDNV9WxV/Q/wALB9zjX0pqq+Dfz8rOHtwK7u9i5GvetXlKo6WlVPdrdfAg4AV7HC161GXu7uXtRdCtgKPNSNr7j1mtS8w34V8PwZ9w93Y6vJFVV1FEahYfypDVeEJBsZtex+nFWwbknWJNkHHAf2AD8FXqyqU92U1fgzCcw/7FlizD8HLKgkbwW+Bny4
qk4udz19qKrXqmozcDWjd5rXLTVtvlXNx7zDfhjYcMb9q4Ejc65haMeSXAnQXR9f5nqmkuQiRkH/clV9vRteFesGUFUvAt9itE/isiSnT6u+Gn8mgfmH/QlgU7f382LgdmD3nGsY2m5gR3d7B/DwMtYylSQBvgQcqKrPnfHQil63JOuSXNbdfgvwHkb7Ix4D3tdNW3HrNam5H1ST5Fbgb4A1wH1V9em5FtCjJF8Bbmb0X1PHgE8A/wA8CFwDPAe8v6rO3om30JL8AfBvwA+B17vhjzP63L5i1y3J7zLaAbeG0Ybuwar6yyS/yWhn8eXA94E/qapXl6/SYXgEndQIj6CTGmHYpUYYdqkRhl1qhGGXGmHYpUYYdqkRhl1qxP8C+66SiDeKzo0AAAAASUVORK5CYII=\n",
125 | "text/plain": [
126 | ""
127 | ]
128 | },
129 | "metadata": {
130 | "needs_background": "light"
131 | },
132 | "output_type": "display_data"
133 | }
134 | ],
135 | "source": [
136 | "plt.imshow(img_arr)\n",
137 | "# plt.axis('off')\n",
138 | "plt.show()"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 7,
144 | "metadata": {},
145 | "outputs": [
146 | {
147 | "data": {
148 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAD8CAYAAAB3lxGOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAUV0lEQVR4nO3dfawc1XnH8e+v5i0tCENwqMtLTZJbCnGIoS4xokqJCQ1GVR2kpDVSA0lR3SaOklQtxbQo0BaU96AiUQIVKKYvODQpwqImlBpQRMSbTRxjoI4dQsGJGxJeHNIIUujTP+bM3bl7Z8e7d3fu3t35faTVzJ49u3vGL8/dfc5zz1FEYGbN9XPDHoCZDZeDgFnDOQiYNZyDgFnDOQiYNZyDgFnD1RYEJJ0taYekXZLW1vU+ZtYf1VEnIGke8G3gLGA38DBwXkQ8PvA3M7O+1PVJ4FRgV0Q8GRE/A9YDK2t6LzPrw341ve5RwDOF+7uBt3fqfMQRR8SiRYtqGoqZAWzZsuVHEbGgvb2uIKCStinfOyStBlYDHHvssWzevLmmoZgZgKT/Kmuv6+vAbuCYwv2jge8XO0TE9RGxNCKWLlgwLTiZ2SypKwg8DExIOk7SAcAqYENN72Vmfajl60BEvCrpI8CdwDzgxoh4rI73MrP+1JUTICI2Ahvren0zGwxXDJo1nIOAWcM5CJg1nIOAWcM5CJg1nIOAWcM5CJg1nIOAWcM5CJg13FgEAZX8zqKOzW7T2k8vaVuV3aa1X1TSdnV2m9Z+a0nb5uw2pe3Z7DatLy+XtO1E7Cxpv7ukbR1iXUn7FSVtq1H2C5xt7StK2hYjFre1HYo4dFrfQ9MjRYvTK7RbUdK2Oo2s3RUlbevSFbe7u6RtZ/qTbPdy2S+7PqvsVrRZ2a3drSVtVyu7tbuopG2Vslu700vajlV2q8HIB4E8AJQFgml9T596rOx70dRjZd9bpx5h+n/+ac8pBII8AJQFgmnPSwGgLBBM73vFlGN13xVTjtn54k7d0+OtQND+n79dMRDkAaAsELTLA0BZIGiXB4CyQNAuDwBTAkH7f/52xUCQB4CyQNAuDwBlgaBdHgDKAkFNRj4I5KujdbNKWnxj6rGy72enHiv7njv1CBBL9/GcNxTOOWjKsfJ5LJ9yrO576ZRjdd87phyz8+37eM7eyfO9VP8FbC88fkc6v2MfzwG4NPW5tIu+y1Of5V30PSj1OajY9w37eN7SwuPnxtRjlc/G1GOVb8TU4yyoZY3BXi1dujS8qIhZvSRtiZj+42nkPwmYWX8cBMwarrb1BGZLMSFY/GaTzwzE04W2QkKwmBfIZwZifaGtkBAs5gXymYH4aKGtkBDM8wLFxGDxA1gxIZjnBYoJwWJeIJ8ZCCYKbXcX+i4vtK9LbRcU2q4o9L200L46tV1faFtR6HtHaltcaNte6HtooT3LCxQTg8X8QDEhmOcFignBYl4gnxm4vtBWTAgW8wL5zMAFhbZiQrCYF8hnBiYKbcWE4GReoJgYLOYHignBPC9QTAgW8wL5zMBHC23FhGAxL5DPDKwvtBUTgsW8QD4z8PTgv777k4BZwzkxaNYQnRKDfX0dkPQU8BLwGvBqRCyVdDjwZWAR8BTwuxHxQj/vY2b1GcTXgXdGxJJChFkLbIqICWBTum9mc1QdOYGVMFm/ug54Tw3vMYXLhl02nHPZcO/6DQIB/LukLWlHIYAjI2IPQDq+oeOzB8Blw1V9XTbcicuGW/oNAqdHxCnACmCNpHd0+0RJqyVtlrT5hz/84YwH4LLhqr4uG+7EZcMtA5sdkHQ58BPgD4EzImKPpIXAvRFxfNVzPTtgVr+Blw1L+gVJh+TnwG8B28m2G8srVi4Abpvpe5hZ/fqZIjwSuFXZl/H9gH+OiK9Jehi4RdKFwNPA+/ofppnVZcZBICKeBN5W0v4ccGY/g+qFy4bzdpcNu2x4
Zlw2bNZwLhs2awivJ2BmpRwEzBpuLIKAy4ZdNpxz2XDvRj4IuGy4qq/Lhjtx2XDLyAcBlw1X9XXZcCcuG27x7IBZQ3h2wMxKOQiYNZxXG8Zlw1mby4ZdNmxmjeTEoFlDODFoZqUcBMwabiyCgMuGXTacc9lw70Y+CLhsuKqvy4Y7cdlwy8gHAZcNV/V12XAnLhtu8eyAWUPMeHZA0o2SnpW0vdB2uKS7JO1Mx8NSuyRdLWmXpG2SThnsZZjZoHXzdeBLwNltbZ32G1wBTKTbauDawQyzmhODs5sY7IUTg8koJwYj4uvA823NnfYbXAncFJkHgPlpA5LaODFY1bfexGAvnBhk7BKDnfYbPAp4ptBvd2qbZlDbkJlZf7pKDEpaBNweEYvT/RcjYn7h8Rci4jBJ/wZ8MiLuS+2bgD+PiC1Vr+/EoFn9Bl02/IP8Y3465h9udwPHFPodDXx/hu9hZrNgpkGg036DG4Dz0yzBMmBv/rXBzOamfa4nIOlm4AzgCEm7gcuAT1G+3+BG4BxgF/BT4IM1jLltfK1zrycwO+sJ9MLrCTDn1xPYZxCIiPM6PDRtv8HIEgxr+h2Umc0elw136uuy4Rn/9C/jsmFcNlzFswNm9fOiImZWaiyCgMuGvZ5AzmXDvRv5IOCy4aq+Xk+gE5cNt4x8EHBisKqv1xPoxInBFicGzRrCiUEzK+UgYNZw3oYMlw1nbd6GrKllw/4kYNZwTgyaNYQTg2ZWykHArOHGIgi4bNhlwzmXDfdu5IOAy4ar+rpsuBOXDbeMfBBw2XBVX5cNd+Ky4RbPDpg1xKC3Ibtc0vckbU23cwqPXZK2Idsh6d2DuwQzq8NMtyEDuCoilqTbRgBJJwKrgLek5/ydpHmDGqyZDV43C41+PW0+0o2VwPqIeAX4rqRdwKnA/TMe4T64bDhvd9mwy4Znpp/E4EfSzsM35rsS423IzEbOTLchOxL4ERDA3wALI+IPJF0D3B8R/5j63QBsjIivVr2+E4Nm9Rto2XBE/CAiXouI/wP+nuwjP3gbMrORM6Mg0Lbd+Lkw+YVxA7BK0oGSjgMmgIf6G6KZ1ambKcKbyRJ7x0vanbYe+4ykRyVtA94J/AlARDwG3AI8DnwNWBMRr9U2+skxlrS5bDi1u2zYZcPV9hkEIuK8iFgYEftHxNERcUNEvD8i3hoRJ0XE7xQ3HY2IKyPiTRFxfETcUfXag+Cy4aq+LhvuxGXDLS4b7tTXZcMuG27nsuH6eHbArH5eVMTMSjkImDWcVxvGZcNZm8uGXTZsZo3kxKBZQzgxaGalHATMGm4sgoDLhl02nHPZcO9GPgi4bLiqr8uGO3HZcMvIBwGXDVf1ddlwJy4bbvHsgFlDeHbAzEqNRRBwYtCJwZwTg70b+SDgxGBVXycGO3FisGXkg4CZ9ceJQbOG6GcbsmMk3SPpCUmPSfpYaj9c0l2SdqbjYaldkq5OW5Ftk3TK4C/HzAalm68DrwJ/GhEnAMuANWm7sbXApoiYADal+wAryFYZngBWA9cOfNRmNjDdbEO2B9iTzl+S9ATZrkIrgTNSt3XAvcDFqf2myL5nPCBpvqSFxcVIB6mX9QR6et1ZXE+gp3F5PQGgfD2BXsyJ9QR6MVfWE0g7EZ0MPAgcmf/HTsf8n3RXW5F5GzKzuaHrICDpYOCrwMcj4sdVXUvapoWviLg+IpZGxNIFCxZ0O4zpL9xD2XBPrzuLZcM9jctlwwMxZ8uGh6DbvQj3B24H7oyIL6S2HcAZEbEn7Uh0b0QcL+m6dH5ze79Or+/ZAbP69TM7IOAG4Ik8ACQbYPIL6AXAbYX289MswTJgb135ADPrXzdfB04H3g8sl7Q13c4BPgWcJWkncFa6D7AReBLYRbZZ6YcHP+ypeikb7ul1Z7FsuKdxuWy4Y9lwL4ZeNtyLIW9Ddl9EKG05tiTdNkbEcxFxZkRMpOPzqX9ExJq0FdlbI6LWz/m9lA339LqzWDbc07hcNjwQc7Zs
eAhcNmzWcC4bNmsIrydgZqUcBMwaztuQgbchYzzKhr0N2cz4k4BZw418EPBqw1V9m1U27NWGZ8azA2YN4dkBMys1FkHAqw27bDjn1YZ7N/JBwKsNV/VtVtmwVxuemZEPAk4MVvV1YrATJwZbnBg0awgnBs2slIOAWcO5bBiXDWdtLht22bCZNZITg2YNUcc2ZJdL+l7buoP5cy5J25DtkPTuwV6KmQ1SNzmBfBuyRyQdAmyRdFd67KqI+Fyxc9qibBXwFuCXgP+Q9CsR8dogB25mg9HNQqN7IuKRdP4SkG9D1slKYH1EvBIR3yVbdfjUQQy2E5cNu2w457Lh3vWzDRnAR9LOwzfmuxLT5TZkg+Ky4aq+LhvuxGXDLf1sQ3Yt8CZgCdmGpZ/Pu5Y8fVr2cVB7EbpsuKqvy4Y7cdlwy4y3IWt7fBFwe0QslnQJQER8Mj12J3B5RNzf6fU9O2BWv4FvQ5b2H8ydC5M/NjYAqyQdKOk4YAJ4qJ/Bm1l9upkdyLche1TS1tT2F8B5kpaQfdR/CvgjgIh4TNItwONkMwtrPDNgNnftMwhExH2Uf8/fWPGcK4Er+xhX11w2nLe7bNhlwzPjsmGzhnPZsFlDeD0BMyvlIGDWcGMRBFw27LLhnMuGezfyQcBlw1V9XTbcicuGW0Y+CLhsuKqvy4Y7cdlwi2cHzBrCswNmVspBwKzhvNowLhvO2lw27LJhM2skJwbNGsKJQTMr5SBg1nBjEQRcNuyy4ZzLhns38kHAZcNVfV023InLhltGPgi4bLiqr8uGO3HZcMs+ZwckHQR8HTiQrK7gKxFxWVpEdD1wOPAI8P6I+JmkA4GbgF8DngN+LyKeqnoPzw6Y1a+f2YFXgOUR8TayPQbOlrQM+DTZNmQTwAvAhan/hcALEfFm4KrUz8zmqG62IYuI+Em6u3+6BbAc+EpqXwe8J52vTPdJj5+Zli2vjRODs5sY7IUTg8moJwYlzUvLjT8L3AV8B3gxIl5NXYpbjU1uQ5Ye3wu8fpCDnjq2qcfKvk4MVvTtPTHYCycGGe3EYES8FhFLgKPJNhc9oaxbOpaNvrZtyMysPz2XDUu6DPgpcDHwixHxqqTTyLYae3dx2zFJ+wH/DSyIijdyYtCsfv1sQ7ZA0vx0/jrgXWTbk98DvDd1uwC4LZ1vSPdJj99dFQDMbLi6+VXihcA6SfPIgsYtEXG7pMeB9ZKuAL5Jtl8h6fgPknYBzwMlKTczmyu62YZsG3BySfuTZPmB9vaXgfcNZHRd8HoCefvsrSfQC68ngNcTMLO5beSDgMuGq/rWWzbcC5cNM7plw7PBswNm9fOiImZWaiyCgMuGvZ5AzmXDvRv5IOCy4aq+Xk+gE5cNt4x8EHBisKqv1xPoxInBFicGzRrCiUEzK+UgYNZw3oYMlw1nbd6GzGXDZtZITgyaNYQTg2ZWykHArOHGIgi4bNhlwzmXDfdu5IOAy4ar+rpsuBOXDbeMfBBw2XBVX5cNd+Ky4ZZ+tiH7EvCbMPkv4QMRsTVtNPK3wDlkqxJ/ICIeqXoPzw6Y1a/T7EA3xUL5NmQ/kbQ/cJ+k/MfFRRHxlbb+K4CJdHs7cG06mtkc1M82ZJ2sBG5Kz3sAmC9pYf9DNbM6dFU2nJYb3wK8GbgmIh6U9CHgSkmfADYBayPiFQrbkCX5FmV7BjryybG1zl027LLhnMuGuzejbcgkLQYuAX4V+HWy7ckvTt3L0prTRu5tyMzmhpluQ/Y/EfG5QtsZwJ9FxG9Lug64NyJuTo/tAM6IiI6fBJwYNKvfoLch+8/8e36aDXgPTH5e3ACcr8wyYG9VADCz4epnG7K7JS0g+/i/Ffjj1H8j2fTgLrIpwg8OfthmNijdzA5si4iTI+KkiFgcEX+d2pdHxFtT2+/nMwhpVmBNRLwpPV7753yXDbtsOOey4d6NfMWgy4ar+rpsuBOXDbeMfBBw2XBVX5cNd+Ky4RYvKmLWEF5UxMxKOQiYNZxXG8Zl
w1mby4ZdNmxmjeTEoFlDODFoZqUcBMwazkHArOEcBMwazkHArOEcBMwazkHArOEcBMwazkHArOEcBMwazkHArOEcBMwazkHArOEcBMwabk78KrGkl4Adwx5HTY4AfjTsQdRgXK8LxvfafjkiFrQ3zpWVhXaU/Z7zOJC0eRyvbVyvC8b72sr464BZwzkImDXcXAkC1++7y8ga12sb1+uC8b62aeZEYtDMhmeufBIwsyEZehCQdLakHZJ2SVo77PH0StKNkp6VtL3QdrikuyTtTMfDUrskXZ2udZukU4Y38mqSjpF0j6QnJD0m6WOpfaSvTdJBkh6S9K10XX+V2o+T9GC6ri9LOiC1H5ju70qPLxrm+GsREUO7AfOA7wBvBA4AvgWcOMwxzeAa3gGcAmwvtH0GWJvO1wKfTufnAHcAApYBDw57/BXXtRA4JZ0fAnwbOHHUry2N7+B0vj/wYBrvLcCq1P5F4EPp/MPAF9P5KuDLw76Ggf+ZDPkv5DTgzsL9S4BLhv2HMoPrWNQWBHYAC9P5QrI6CIDrgPPK+s31G3AbcNY4XRvw88AjwNvJioP2S+2T/y6BO4HT0vl+qZ+GPfZB3ob9deAo4JnC/d2pbdQdGRF7ANIx34h8JK83fQQ+meyn5shfm6R5krYCzwJ3kX0afTEiXk1dimOfvK70+F7g9bM74noNOwiopG2cpytG7nolHQx8Ffh4RPy4qmtJ25y8toh4LSKWAEcDpwInlHVLx5G5rpkadhDYDRxTuH808P0hjWWQfiBpIUA65tuQjtT1StqfLAD8U0T8a2oei2sDiIgXgXvJcgLzJeVl9MWxT15XevxQ4PnZHWm9hh0EHgYmUmb2ALLEy4Yhj2kQNsDk9sAXkH2fztvPT5n0ZcDe/KP1XCNJwA3AExHxhcJDI31tkhZImp/OXwe8C3gCuAd4b+rWfl359b4XuDtSgmBsDDspQZZV/jbZ97K/HPZ4ZjD+m4E9wP+S/dS4kOw74yZgZzoenvoKuCZd66PA0mGPv+K6foPsY+82YGu6nTPq1wacBHwzXdd24BOp/Y3AQ8Au4F+AA1P7Qen+rvT4G4d9DYO+uWLQrOGG/XXAzIbMQcCs4RwEzBrOQcCs4RwEzBrOQcCs4RwEzBrOQcCs4f4fE0aLNZZcRNUAAAAASUVORK5CYII=\n",
149 | "text/plain": [
150 | ""
151 | ]
152 | },
153 | "metadata": {
154 | "needs_background": "light"
155 | },
156 | "output_type": "display_data"
157 | }
158 | ],
159 | "source": [
160 | "img_scaled = Image.new(\"RGB\", (width * 10,height*10), \"#FFFFFF\")\n",
161 | "draw = ImageDraw.Draw(img_scaled)\n",
162 | "for y in range(start, end + 1):\n",
163 | " for x in range(start, end + 1):\n",
164 | " if [y, x] in unvalid_pos:\n",
165 | " continue\n",
166 | " else:\n",
167 | " x_scaled, y_scaled = x * 10, y * 10\n",
168 | " draw.ellipse((x_scaled-2, y_scaled-2, x_scaled+2, y_scaled+2), fill = img.getpixel((x,y)))\n",
169 | "img_scaled.save('map.png', 'PNG')\n",
170 | "img_arr_scaled = np.array(img_scaled) / 255.\n",
171 | "plt.imshow(img_arr_scaled)\n",
172 | "plt.show()"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": 31,
178 | "metadata": {},
179 | "outputs": [],
180 | "source": [
181 | "import torch\n",
182 | "from e2c_model import E2C"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": 32,
188 | "metadata": {},
189 | "outputs": [
190 | {
191 | "data": {
192 | "text/plain": [
193 | "E2C(\n",
194 | " (encoder): PlanarEncoder(\n",
195 | " (net): Sequential(\n",
196 | " (0): Linear(in_features=1600, out_features=150, bias=True)\n",
197 | " (1): BatchNorm1d(150, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
198 | " (2): ReLU()\n",
199 | " (3): Linear(in_features=150, out_features=150, bias=True)\n",
200 | " (4): BatchNorm1d(150, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
201 | " (5): ReLU()\n",
202 | " (6): Linear(in_features=150, out_features=150, bias=True)\n",
203 | " (7): BatchNorm1d(150, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
204 | " (8): ReLU()\n",
205 | " (9): Linear(in_features=150, out_features=4, bias=True)\n",
206 | " )\n",
207 | " )\n",
208 | " (decoder): PlanarDecoder(\n",
209 | " (net): Sequential(\n",
210 | " (0): Linear(in_features=2, out_features=200, bias=True)\n",
211 | " (1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
212 | " (2): ReLU()\n",
213 | " (3): Linear(in_features=200, out_features=200, bias=True)\n",
214 | " (4): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
215 | " (5): ReLU()\n",
216 | " (6): Linear(in_features=200, out_features=1600, bias=True)\n",
217 | " (7): Sigmoid()\n",
218 | " )\n",
219 | " )\n",
220 | " (trans): PlanarTransition(\n",
221 | " (net): Sequential(\n",
222 | " (0): Linear(in_features=2, out_features=100, bias=True)\n",
223 | " (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
224 | " (2): ReLU()\n",
225 | " (3): Linear(in_features=100, out_features=100, bias=True)\n",
226 | " (4): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
227 | " (5): ReLU()\n",
228 | " )\n",
229 | " (fc_A): Sequential(\n",
230 | " (0): Linear(in_features=100, out_features=4, bias=True)\n",
231 | " (1): Sigmoid()\n",
232 | " )\n",
233 | " (fc_B): Linear(in_features=100, out_features=4, bias=True)\n",
234 | " (fc_o): Linear(in_features=100, out_features=2, bias=True)\n",
235 | " )\n",
236 | ")"
237 | ]
238 | },
239 | "execution_count": 32,
240 | "metadata": {},
241 | "output_type": "execute_result"
242 | }
243 | ],
244 | "source": [
245 | "model = E2C(1600,2,2,'planar_partial').cuda(0)\n",
246 | "model.load_state_dict(torch.load('result/planar_partial/log_check_valid/model_5000'))\n",
247 | "model.eval()"
248 | ]
249 | },
250 | {
251 | "cell_type": "code",
252 | "execution_count": 33,
253 | "metadata": {},
254 | "outputs": [],
255 | "source": [
256 | "all_z = []\n",
257 | "for x in range(start, end + 1):\n",
258 | " for y in range(start, end + 1):\n",
259 | " s = np.array([x,y])\n",
260 | " if [x,y] in unvalid_pos:\n",
261 | " all_z.append(np.zeros(2))\n",
262 | " else:\n",
263 | " obs = render(s)\n",
264 | " with torch.no_grad():\n",
265 | " obs = torch.Tensor(render(s)).unsqueeze(0).view(-1,1600).double().cuda(0)\n",
266 | " mu, sigma = model.encode(obs)\n",
267 | " z = mu.squeeze().cpu().numpy()\n",
268 | " all_z.append(np.copy(z))\n",
269 | "all_z = np.array(all_z)"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": 34,
275 | "metadata": {},
276 | "outputs": [
277 | {
278 | "name": "stdout",
279 | "output_type": "stream",
280 | "text": [
281 | "(3, 3): [23 3]\n",
282 | "(3, 4): [23 3]\n",
283 | "(3, 5): [22 3]\n",
284 | "(3, 6): [22 1]\n",
285 | "(3, 7): [21 2]\n",
286 | "(3, 8): [21 1]\n",
287 | "(3, 9): [20 3]\n",
288 | "(3, 10): [20 8]\n",
289 | "(3, 11): [20 10]\n",
290 | "(3, 12): [20 12]\n",
291 | "(3, 13): [20 13]\n",
292 | "(3, 14): [20 13]\n",
293 | "(3, 15): [20 14]\n",
294 | "(3, 16): [20 14]\n",
295 | "(3, 17): [20 15]\n",
296 | "(3, 18): [19 15]\n",
297 | "(3, 19): [19 15]\n",
298 | "(3, 20): [19 16]\n",
299 | "(3, 21): [19 24]\n",
300 | "(3, 22): [19 24]\n",
301 | "(3, 23): [20 24]\n",
302 | "(3, 24): [20 24]\n",
303 | "(3, 25): [20 24]\n",
304 | "(3, 26): [20 24]\n",
305 | "(3, 27): [21 24]\n",
306 | "(3, 28): [21 23]\n",
307 | "(3, 29): [21 23]\n",
308 | "(3, 30): [22 23]\n",
309 | "(3, 31): [22 23]\n",
310 | "(3, 32): [22 15]\n",
311 | "(3, 33): [22 14]\n",
312 | "(3, 34): [23 13]\n",
313 | "(3, 35): [23 12]\n",
314 | "(3, 36): [23 11]\n",
315 | "(3, 37): [23 10]\n",
316 | "(4, 3): [23 6]\n",
317 | "(4, 4): [22 5]\n",
318 | "(4, 5): [22 5]\n",
319 | "(4, 6): [22 3]\n",
320 | "(4, 7): [21 4]\n",
321 | "(4, 8): [21 2]\n",
322 | "(4, 9): [20 2]\n",
323 | "(4, 10): [20 5]\n",
324 | "(4, 11): [20 8]\n",
325 | "(4, 12): [20 11]\n",
326 | "(4, 13): [20 12]\n",
327 | "(4, 14): [19 13]\n",
328 | "(4, 15): [19 13]\n",
329 | "(4, 16): [19 14]\n",
330 | "(4, 17): [19 14]\n",
331 | "(4, 18): [19 15]\n",
332 | "(4, 19): [19 15]\n",
333 | "(4, 20): [19 16]\n",
334 | "(4, 21): [19 25]\n",
335 | "(4, 22): [19 25]\n",
336 | "(4, 23): [20 25]\n",
337 | "(4, 24): [20 24]\n",
338 | "(4, 25): [20 24]\n",
339 | "(4, 26): [21 24]\n",
340 | "(4, 27): [21 24]\n",
341 | "(4, 28): [21 24]\n",
342 | "(4, 29): [22 24]\n",
343 | "(4, 30): [22 24]\n",
344 | "(4, 31): [22 23]\n",
345 | "(4, 32): [23 15]\n",
346 | "(4, 33): [23 14]\n",
347 | "(4, 34): [23 13]\n",
348 | "(4, 35): [23 12]\n",
349 | "(4, 36): [23 11]\n",
350 | "(4, 37): [24 9]\n",
351 | "(5, 3): [22 10]\n",
352 | "(5, 4): [22 10]\n",
353 | "(5, 5): [22 8]\n",
354 | "(5, 6): [21 6]\n",
355 | "(5, 7): [21 7]\n",
356 | "(5, 8): [20 4]\n",
357 | "(5, 9): [20 4]\n",
358 | "(5, 10): [20 6]\n",
359 | "(5, 11): [19 7]\n",
360 | "(5, 12): [19 9]\n",
361 | "(5, 13): [19 12]\n",
362 | "(5, 14): [19 12]\n",
363 | "(5, 15): [19 13]\n",
364 | "(5, 16): [19 14]\n",
365 | "(5, 17): [19 14]\n",
366 | "(5, 18): [19 14]\n",
367 | "(5, 19): [19 15]\n",
368 | "(5, 20): [19 15]\n",
369 | "(5, 21): [19 25]\n",
370 | "(5, 22): [19 25]\n",
371 | "(5, 23): [20 25]\n",
372 | "(5, 24): [20 25]\n",
373 | "(5, 25): [20 25]\n",
374 | "(5, 26): [21 25]\n",
375 | "(5, 27): [21 24]\n",
376 | "(5, 28): [21 24]\n",
377 | "(5, 29): [22 24]\n",
378 | "(5, 30): [22 24]\n",
379 | "(5, 31): [22 24]\n",
380 | "(5, 32): [23 15]\n",
381 | "(5, 33): [23 14]\n",
382 | "(5, 34): [23 13]\n",
383 | "(5, 35): [23 12]\n",
384 | "(5, 36): [24 12]\n",
385 | "(5, 37): [24 10]\n",
386 | "(6, 3): [22 11]\n",
387 | "(6, 4): [22 10]\n",
388 | "(6, 5): [22 10]\n",
389 | "(6, 6): [21 8]\n",
390 | "(6, 7): [21 8]\n",
391 | "(6, 8): [20 8]\n",
392 | "(6, 9): [20 5]\n",
393 | "(6, 10): [19 5]\n",
394 | "(6, 11): [19 8]\n",
395 | "(6, 12): [19 10]\n",
396 | "(6, 13): [19 11]\n",
397 | "(6, 14): [19 12]\n",
398 | "(6, 15): [19 12]\n",
399 | "(6, 16): [19 13]\n",
400 | "(6, 17): [18 14]\n",
401 | "(6, 18): [18 14]\n",
402 | "(6, 19): [18 15]\n",
403 | "(6, 20): [18 15]\n",
404 | "(6, 21): [19 26]\n",
405 | "(6, 22): [19 26]\n",
406 | "(6, 23): [20 26]\n",
407 | "(6, 24): [20 26]\n",
408 | "(6, 25): [20 25]\n",
409 | "(6, 26): [21 25]\n",
410 | "(6, 27): [21 25]\n",
411 | "(6, 28): [21 25]\n",
412 | "(6, 29): [22 24]\n",
413 | "(6, 30): [22 24]\n",
414 | "(6, 31): [22 24]\n",
415 | "(6, 32): [23 15]\n",
416 | "(6, 33): [23 14]\n",
417 | "(6, 34): [24 13]\n",
418 | "(6, 35): [24 12]\n",
419 | "(6, 36): [24 11]\n",
420 | "(6, 37): [24 10]\n",
421 | "(7, 3): [22 12]\n",
422 | "(7, 4): [22 12]\n",
423 | "(7, 5): [21 11]\n",
424 | "(7, 6): [21 10]\n",
425 | "(7, 7): [21 9]\n",
426 | "(7, 8): [20 9]\n",
427 | "(7, 9): [20 7]\n",
428 | "(7, 10): [19 4]\n",
429 | "(7, 11): [18 6]\n",
430 | "(7, 12): [19 10]\n",
431 | "(7, 13): [18 11]\n",
432 | "(7, 14): [18 11]\n",
433 | "(7, 15): [18 12]\n",
434 | "(7, 16): [18 13]\n",
435 | "(7, 17): [18 13]\n",
436 | "(7, 18): [18 15]\n",
437 | "(7, 19): [18 15]\n",
438 | "(7, 20): [18 15]\n",
439 | "(7, 21): [19 26]\n",
440 | "(7, 22): [19 26]\n",
441 | "(7, 23): [20 26]\n",
442 | "(7, 24): [20 26]\n",
443 | "(7, 25): [21 26]\n",
444 | "(7, 26): [21 26]\n",
445 | "(7, 27): [21 25]\n",
446 | "(7, 28): [22 25]\n",
447 | "(7, 29): [22 25]\n",
448 | "(7, 30): [22 25]\n",
449 | "(7, 31): [23 24]\n",
450 | "(7, 32): [24 15]\n",
451 | "(7, 33): [24 15]\n",
452 | "(7, 34): [24 13]\n",
453 | "(7, 35): [24 12]\n",
454 | "(7, 36): [24 11]\n",
455 | "(7, 37): [25 11]\n",
456 | "(8, 3): [22 13]\n",
457 | "(8, 4): [22 13]\n",
458 | "(8, 5): [21 12]\n",
459 | "(8, 6): [21 11]\n",
460 | "(8, 7): [21 11]\n",
461 | "(8, 8): [21 10]\n",
462 | "(8, 9): [20 8]\n",
463 | "(8, 10): [ 9 11]\n",
464 | "(8, 11): [18 5]\n",
465 | "(8, 12): [10 13]\n",
466 | "(8, 13): [18 9]\n",
467 | "(8, 14): [18 9]\n",
468 | "(8, 15): [18 12]\n",
469 | "(8, 16): [18 12]\n",
470 | "(8, 17): [18 13]\n",
471 | "(8, 18): [18 14]\n",
472 | "(8, 19): [18 14]\n",
473 | "(8, 20): [18 27]\n",
474 | "(8, 21): [19 27]\n",
475 | "(8, 22): [19 27]\n",
476 | "(8, 23): [20 27]\n",
477 | "(8, 24): [20 27]\n",
478 | "(8, 25): [20 26]\n",
479 | "(8, 26): [21 26]\n",
480 | "(8, 27): [21 26]\n",
481 | "(8, 28): [22 26]\n",
482 | "(8, 29): [22 25]\n",
483 | "(8, 30): [22 25]\n",
484 | "(8, 31): [23 25]\n",
485 | "(8, 32): [24 15]\n",
486 | "(8, 33): [24 14]\n",
487 | "(8, 34): [24 13]\n",
488 | "(8, 35): [25 13]\n",
489 | "(8, 36): [25 11]\n",
490 | "(8, 37): [25 10]\n",
491 | "(9, 3): [22 14]\n",
492 | "(9, 4): [21 13]\n",
493 | "(9, 5): [21 13]\n",
494 | "(9, 6): [15 12]\n",
495 | "(9, 7): [12 11]\n",
496 | "(9, 8): [11 12]\n",
497 | "(9, 9): [11 12]\n",
498 | "(9, 10): [ 9 12]\n",
499 | "(9, 11): [10 14]\n",
500 | "(9, 12): [ 9 14]\n",
501 | "(9, 13): [ 9 15]\n",
502 | "(9, 14): [17 10]\n",
503 | "(9, 15): [17 11]\n",
504 | "(9, 16): [17 12]\n",
505 | "(9, 17): [17 13]\n",
506 | "(9, 18): [17 14]\n",
507 | "(9, 19): [17 14]\n",
508 | "(9, 20): [17 14]\n",
509 | "(9, 21): [17 15]\n",
510 | "(9, 22): [19 28]\n",
511 | "(9, 23): [20 27]\n",
512 | "(9, 24): [20 27]\n",
513 | "(9, 25): [21 27]\n",
514 | "(9, 26): [21 27]\n",
515 | "(9, 27): [21 27]\n",
516 | "(9, 28): [22 26]\n",
517 | "(9, 29): [22 26]\n",
518 | "(9, 30): [23 26]\n",
519 | "(9, 31): [23 25]\n",
520 | "(9, 32): [24 15]\n",
521 | "(9, 33): [25 14]\n",
522 | "(9, 34): [25 13]\n",
523 | "(9, 35): [25 12]\n",
524 | "(9, 36): [25 11]\n",
525 | "(9, 37): [26 10]\n",
526 | "(10, 3): [21 14]\n",
527 | "(10, 4): [21 14]\n",
528 | "(10, 5): [13 12]\n",
529 | "(10, 6): [13 12]\n",
530 | "(10, 7): [12 12]\n",
531 | "(10, 8): [12 13]\n",
532 | "(10, 9): [11 13]\n",
533 | "(10, 10): [11 14]\n",
534 | "(10, 11): [10 15]\n",
535 | "(10, 12): [10 15]\n",
536 | "(10, 13): [ 9 16]\n",
537 | "(10, 14): [ 9 16]\n",
538 | "(10, 15): [17 11]\n",
539 | "(10, 16): [17 12]\n",
540 | "(10, 17): [17 12]\n",
541 | "(10, 18): [17 13]\n",
542 | "(10, 19): [17 13]\n",
543 | "(10, 20): [20 20]\n",
544 | "(10, 21): [20 20]\n",
545 | "(10, 22): [19 28]\n",
546 | "(10, 23): [20 28]\n",
547 | "(10, 24): [20 28]\n",
548 | "(10, 25): [21 28]\n",
549 | "(10, 26): [21 27]\n",
550 | "(10, 27): [22 27]\n",
551 | "(10, 28): [22 27]\n",
552 | "(10, 29): [22 26]\n",
553 | "(10, 30): [23 26]\n",
554 | "(10, 31): [23 26]\n",
555 | "(10, 32): [24 25]\n",
556 | "(10, 33): [25 14]\n",
557 | "(10, 34): [25 13]\n",
558 | "(10, 35): [26 12]\n",
559 | "(10, 36): [26 11]\n",
560 | "(10, 37): [26 10]\n",
561 | "(11, 3): [15 11]\n",
562 | "(11, 4): [14 12]\n",
563 | "(11, 5): [14 12]\n",
564 | "(11, 6): [14 13]\n",
565 | "(11, 7): [13 13]\n",
566 | "(11, 8): [12 14]\n",
567 | "(11, 9): [12 14]\n",
568 | "(11, 10): [11 15]\n",
569 | "(11, 11): [11 15]\n",
570 | "(11, 12): [11 16]\n",
571 | "(11, 13): [10 16]\n",
572 | "(11, 14): [ 9 17]\n",
573 | "(11, 15): [17 10]\n",
574 | "(11, 16): [ 8 19]\n",
575 | "(11, 17): [ 7 20]\n",
576 | "(11, 18): [ 4 22]\n",
577 | "(11, 19): [ 5 22]\n",
578 | "(11, 20): [20 20]\n",
579 | "(11, 21): [20 20]\n",
580 | "(11, 22): [19 29]\n",
581 | "(11, 23): [20 29]\n",
582 | "(11, 24): [21 28]\n",
583 | "(11, 25): [21 28]\n",
584 | "(11, 26): [21 28]\n",
585 | "(11, 27): [22 28]\n",
586 | "(11, 28): [22 27]\n",
587 | "(11, 29): [23 27]\n",
588 | "(11, 30): [23 26]\n",
589 | "(11, 31): [23 26]\n",
590 | "(11, 32): [24 26]\n",
591 | "(11, 33): [25 14]\n",
592 | "(11, 34): [26 13]\n",
593 | "(11, 35): [26 12]\n",
594 | "(11, 36): [26 11]\n",
595 | "(11, 37): [27 10]\n",
596 | "(12, 3): [15 12]\n",
597 | "(12, 4): [15 13]\n",
598 | "(12, 5): [14 13]\n",
599 | "(12, 6): [14 14]\n",
600 | "(12, 7): [13 14]\n",
601 | "(12, 8): [13 14]\n",
602 | "(12, 9): [12 15]\n",
603 | "(12, 10): [12 15]\n",
604 | "(12, 11): [12 16]\n",
605 | "(12, 12): [11 16]\n",
606 | "(12, 13): [11 17]\n",
607 | "(12, 14): [10 17]\n",
608 | "(12, 15): [10 18]\n",
609 | "(12, 16): [ 9 19]\n",
610 | "(12, 17): [ 8 20]\n",
611 | "(12, 18): [ 5 22]\n",
612 | "(12, 19): [ 5 23]\n",
613 | "(12, 20): [ 4 25]\n",
614 | "(12, 21): [ 4 26]\n",
615 | "(12, 22): [ 4 27]\n",
616 | "(12, 23): [20 29]\n",
617 | "(12, 24): [20 29]\n",
618 | "(12, 25): [21 29]\n",
619 | "(12, 26): [21 29]\n",
620 | "(12, 27): [22 28]\n",
621 | "(12, 28): [23 28]\n",
622 | "(12, 29): [23 27]\n",
623 | "(12, 30): [23 27]\n",
624 | "(12, 31): [24 26]\n",
625 | "(12, 32): [24 26]\n",
626 | "(12, 33): [25 20]\n",
627 | "(12, 34): [26 13]\n",
628 | "(12, 35): [27 12]\n",
629 | "(12, 36): [27 11]\n",
630 | "(12, 37): [27 10]\n",
631 | "(13, 3): [15 13]\n",
632 | "(13, 4): [15 13]\n",
633 | "(13, 5): [15 14]\n",
634 | "(13, 6): [14 14]\n",
635 | "(13, 7): [14 15]\n",
636 | "(13, 8): [13 15]\n",
637 | "(13, 9): [13 15]\n",
638 | "(13, 10): [13 16]\n",
639 | "(13, 11): [12 16]\n",
640 | "(13, 12): [12 17]\n",
641 | "(13, 13): [12 17]\n",
642 | "(13, 14): [11 18]\n",
643 | "(13, 15): [11 18]\n",
644 | "(13, 16): [10 19]\n",
645 | "(13, 17): [10 20]\n",
646 | "(13, 18): [10 21]\n",
647 | "(13, 19): [ 7 23]\n",
648 | "(13, 20): [ 6 24]\n",
649 | "(13, 21): [ 6 25]\n",
650 | "(13, 22): [ 6 27]\n",
651 | "(13, 23): [20 30]\n",
652 | "(13, 24): [20 30]\n",
653 | "(13, 25): [21 30]\n",
654 | "(13, 26): [22 29]\n",
655 | "(13, 27): [22 29]\n",
656 | "(13, 28): [23 28]\n",
657 | "(13, 29): [23 28]\n",
658 | "(13, 30): [24 27]\n",
659 | "(13, 31): [24 27]\n",
660 | "(13, 32): [24 26]\n",
661 | "(13, 33): [25 26]\n",
662 | "(13, 34): [25 26]\n",
663 | "(13, 35): [25 23]\n",
664 | "(13, 36): [27 11]\n",
665 | "(13, 37): [28 10]\n",
666 | "(14, 3): [16 14]\n",
667 | "(14, 4): [15 14]\n",
668 | "(14, 5): [15 15]\n",
669 | "(14, 6): [15 15]\n",
670 | "(14, 7): [14 15]\n",
671 | "(14, 8): [14 16]\n",
672 | "(14, 9): [14 16]\n",
673 | "(14, 10): [13 16]\n",
674 | "(14, 11): [13 17]\n",
675 | "(14, 12): [12 17]\n",
676 | "(14, 13): [12 18]\n",
677 | "(14, 14): [12 18]\n",
678 | "(14, 15): [11 19]\n",
679 | "(14, 16): [11 20]\n",
680 | "(14, 17): [11 20]\n",
681 | "(14, 18): [ 9 21]\n",
682 | "(14, 19): [ 9 22]\n",
683 | "(14, 20): [ 6 25]\n",
684 | "(14, 21): [ 5 26]\n",
685 | "(14, 22): [ 5 27]\n",
686 | "(14, 23): [ 7 27]\n",
687 | "(14, 24): [ 6 28]\n",
688 | "(14, 25): [21 31]\n",
689 | "(14, 26): [22 30]\n",
690 | "(14, 27): [23 30]\n",
691 | "(14, 28): [23 29]\n",
692 | "(14, 29): [24 29]\n",
693 | "(14, 30): [24 28]\n",
694 | "(14, 31): [24 28]\n",
695 | "(14, 32): [25 27]\n",
696 | "(14, 33): [25 26]\n",
697 | "(14, 34): [25 26]\n",
698 | "(14, 35): [26 26]\n",
699 | "(14, 36): [26 25]\n",
700 | "(14, 37): [28 10]\n",
701 | "(15, 3): [16 15]\n",
702 | "(15, 4): [15 15]\n",
703 | "(15, 5): [15 15]\n",
704 | "(15, 6): [15 15]\n",
705 | "(15, 7): [15 16]\n",
706 | "(15, 8): [14 16]\n",
707 | "(15, 9): [14 16]\n",
708 | "(15, 10): [14 17]\n",
709 | "(15, 11): [13 17]\n",
710 | "(15, 12): [13 18]\n",
711 | "(15, 13): [13 18]\n",
712 | "(15, 14): [12 19]\n",
713 | "(15, 15): [12 19]\n",
714 | "(15, 16): [12 20]\n",
715 | "(15, 17): [11 21]\n",
716 | "(15, 18): [11 21]\n",
717 | "(15, 19): [10 22]\n",
718 | "(15, 20): [ 9 24]\n",
719 | "(15, 21): [10 24]\n",
720 | "(15, 22): [ 8 26]\n",
721 | "(15, 23): [10 26]\n",
722 | "(15, 24): [ 8 28]\n",
723 | "(15, 25): [22 31]\n",
724 | "(15, 26): [22 31]\n",
725 | "(15, 27): [23 30]\n",
726 | "(15, 28): [24 30]\n",
727 | "(15, 29): [24 29]\n",
728 | "(15, 30): [25 29]\n",
729 | "(15, 31): [25 28]\n",
730 | "(15, 32): [25 27]\n",
731 | "(15, 33): [26 27]\n",
732 | "(15, 34): [26 26]\n",
733 | "(15, 35): [26 26]\n",
734 | "(15, 36): [27 25]\n",
735 | "(15, 37): [28 18]\n",
736 | "(16, 3): [16 15]\n",
737 | "(16, 4): [16 15]\n",
738 | "(16, 5): [16 15]\n",
739 | "(16, 6): [15 16]\n",
740 | "(16, 7): [15 16]\n",
741 | "(16, 8): [15 16]\n",
742 | "(16, 9): [14 17]\n",
743 | "(16, 10): [14 17]\n",
744 | "(16, 11): [14 18]\n",
745 | "(16, 12): [13 18]\n",
746 | "(16, 13): [13 19]\n",
747 | "(16, 14): [13 19]\n",
748 | "(16, 15): [13 19]\n",
749 | "(16, 16): [12 20]\n",
750 | "(16, 17): [12 21]\n",
751 | "(16, 18): [12 21]\n",
752 | "(16, 19): [12 22]\n",
753 | "(16, 20): [12 23]\n",
754 | "(16, 21): [11 24]\n",
755 | "(16, 22): [12 24]\n",
756 | "(16, 23): [11 25]\n",
757 | "(16, 24): [ 9 28]\n",
758 | "(16, 25): [22 32]\n",
759 | "(16, 26): [23 32]\n",
760 | "(16, 27): [24 31]\n",
761 | "(16, 28): [24 30]\n",
762 | "(16, 29): [25 30]\n",
763 | "(16, 30): [25 29]\n",
764 | "(16, 31): [26 29]\n",
765 | "(16, 32): [26 28]\n",
766 | "(16, 33): [26 27]\n",
767 | "(16, 34): [27 27]\n",
768 | "(16, 35): [27 24]\n",
769 | "(16, 36): [28 26]\n",
770 | "(16, 37): [28 26]\n",
771 | "(17, 3): [16 15]\n",
772 | "(17, 4): [16 15]\n",
773 | "(17, 5): [16 16]\n",
774 | "(17, 6): [15 16]\n",
775 | "(17, 7): [15 16]\n",
776 | "(17, 8): [15 17]\n",
777 | "(17, 9): [15 17]\n",
778 | "(17, 10): [15 17]\n",
779 | "(17, 11): [14 18]\n",
780 | "(17, 12): [14 18]\n",
781 | "(17, 13): [14 19]\n",
782 | "(17, 14): [13 19]\n",
783 | "(17, 15): [13 20]\n",
784 | "(17, 16): [13 20]\n",
785 | "(17, 17): [13 21]\n",
786 | "(17, 18): [13 21]\n",
787 | "(17, 19): [13 22]\n",
788 | "(17, 20): [13 23]\n",
789 | "(17, 21): [12 23]\n",
790 | "(17, 22): [12 24]\n",
791 | "(17, 23): [11 26]\n",
792 | "(17, 24): [ 9 28]\n",
793 | "(17, 25): [12 27]\n",
794 | "(17, 26): [23 33]\n",
795 | "(17, 27): [24 32]\n",
796 | "(17, 28): [25 31]\n",
797 | "(17, 29): [26 31]\n",
798 | "(17, 30): [26 30]\n",
799 | "(17, 31): [27 29]\n",
800 | "(17, 32): [27 28]\n",
801 | "(17, 33): [27 28]\n",
802 | "(17, 34): [28 28]\n",
803 | "(17, 35): [27 27]\n",
804 | "(17, 36): [28 26]\n",
805 | "(17, 37): [29 26]\n",
806 | "(18, 3): [16 15]\n",
807 | "(18, 4): [16 16]\n",
808 | "(18, 5): [16 16]\n",
809 | "(18, 6): [16 16]\n",
810 | "(18, 7): [16 17]\n",
811 | "(18, 8): [15 17]\n",
812 | "(18, 9): [15 17]\n",
813 | "(18, 10): [15 18]\n",
814 | "(18, 11): [15 18]\n",
815 | "(18, 12): [14 19]\n",
816 | "(18, 13): [14 19]\n",
817 | "(18, 14): [14 19]\n",
818 | "(18, 15): [14 20]\n",
819 | "(18, 16): [14 20]\n",
820 | "(18, 17): [13 21]\n",
821 | "(18, 18): [13 21]\n",
822 | "(18, 19): [13 22]\n",
823 | "(18, 20): [13 23]\n",
824 | "(18, 21): [13 23]\n",
825 | "(18, 22): [13 24]\n",
826 | "(18, 23): [13 25]\n",
827 | "(18, 24): [13 26]\n",
828 | "(18, 25): [13 27]\n",
829 | "(18, 26): [24 34]\n",
830 | "(18, 27): [25 33]\n",
831 | "(18, 28): [26 32]\n",
832 | "(18, 29): [26 32]\n",
833 | "(18, 30): [27 30]\n",
834 | "(18, 31): [28 30]\n",
835 | "(18, 32): [29 30]\n",
836 | "(18, 33): [30 30]\n",
837 | "(18, 34): [32 30]\n",
838 | "(18, 35): [29 27]\n",
839 | "(18, 36): [31 28]\n",
840 | "(18, 37): [32 28]\n",
841 | "(19, 3): [17 16]\n",
842 | "(19, 4): [17 16]\n",
843 | "(19, 5): [16 16]\n",
844 | "(19, 6): [16 17]\n",
845 | "(19, 7): [16 17]\n",
846 | "(19, 8): [16 17]\n",
847 | "(19, 9): [16 18]\n",
848 | "(19, 10): [15 18]\n",
849 | "(19, 11): [15 18]\n",
850 | "(19, 12): [15 19]\n",
851 | "(19, 13): [15 19]\n",
852 | "(19, 14): [15 20]\n",
853 | "(19, 15): [14 20]\n",
854 | "(19, 16): [14 21]\n",
855 | "(19, 17): [14 21]\n",
856 | "(19, 18): [14 22]\n",
857 | "(19, 19): [14 22]\n",
858 | "(19, 20): [14 23]\n",
859 | "(19, 21): [14 23]\n",
860 | "(19, 22): [14 24]\n",
861 | "(19, 23): [14 25]\n",
862 | "(19, 24): [14 26]\n",
863 | "(19, 25): [14 26]\n",
864 | "(19, 26): [18 31]\n",
865 | "(19, 27): [27 34]\n",
866 | "(19, 28): [29 34]\n",
867 | "(19, 29): [30 34]\n",
868 | "(19, 30): [31 33]\n",
869 | "(19, 31): [33 25]\n",
870 | "(19, 32): [31 31]\n",
871 | "(19, 33): [32 31]\n",
872 | "(19, 34): [33 30]\n",
873 | "(19, 35): [33 30]\n",
874 | "(19, 36): [32 28]\n",
875 | "(19, 37): [32 28]\n",
876 | "(20, 3): [17 16]\n",
877 | "(20, 4): [17 16]\n",
878 | "(20, 5): [20 20]\n",
879 | "(20, 6): [20 20]\n",
880 | "(20, 7): [16 17]\n",
881 | "(20, 8): [16 18]\n",
882 | "(20, 9): [16 18]\n",
883 | "(20, 10): [16 18]\n",
884 | "(20, 11): [16 19]\n",
885 | "(20, 12): [20 20]\n",
886 | "(20, 13): [20 20]\n",
887 | "(20, 14): [15 20]\n",
888 | "(20, 15): [15 20]\n",
889 | "(20, 16): [15 21]\n",
890 | "(20, 17): [15 21]\n",
891 | "(20, 18): [14 22]\n",
892 | "(20, 19): [14 22]\n",
893 | "(20, 20): [14 23]\n",
894 | "(20, 21): [14 23]\n",
895 | "(20, 22): [14 24]\n",
896 | "(20, 23): [14 25]\n",
897 | "(20, 24): [14 26]\n",
898 | "(20, 25): [14 26]\n",
899 | "(20, 26): [14 27]\n",
900 | "(20, 27): [20 20]\n",
901 | "(20, 28): [20 20]\n",
902 | "(20, 29): [32 25]\n",
903 | "(20, 30): [32 25]\n",
904 | "(20, 31): [33 25]\n",
905 | "(20, 32): [36 28]\n",
906 | "(20, 33): [34 25]\n",
907 | "(20, 34): [33 30]\n",
908 | "(20, 35): [20 20]\n",
909 | "(20, 36): [20 20]\n",
910 | "(20, 37): [33 28]\n",
911 | "(21, 3): [17 16]\n",
912 | "(21, 4): [17 16]\n",
913 | "(21, 5): [20 20]\n",
914 | "(21, 6): [20 20]\n",
915 | "(21, 7): [17 17]\n",
916 | "(21, 8): [16 18]\n",
917 | "(21, 9): [16 18]\n",
918 | "(21, 10): [16 18]\n",
919 | "(21, 11): [16 19]\n",
920 | "(21, 12): [20 20]\n",
921 | "(21, 13): [20 20]\n",
922 | "(21, 14): [15 20]\n",
923 | "(21, 15): [15 20]\n",
924 | "(21, 16): [15 21]\n",
925 | "(21, 17): [15 21]\n",
926 | "(21, 18): [15 22]\n",
927 | "(21, 19): [15 22]\n",
928 | "(21, 20): [15 23]\n",
929 | "(21, 21): [15 24]\n",
930 | "(21, 22): [15 24]\n",
931 | "(21, 23): [15 25]\n",
932 | "(21, 24): [15 26]\n",
933 | "(21, 25): [15 26]\n",
934 | "(21, 26): [16 26]\n",
935 | "(21, 27): [20 20]\n",
936 | "(21, 28): [20 20]\n",
937 | "(21, 29): [28 23]\n",
938 | "(21, 30): [29 23]\n",
939 | "(21, 31): [30 23]\n",
940 | "(21, 32): [32 24]\n",
941 | "(21, 33): [32 23]\n",
942 | "(21, 34): [33 23]\n",
943 | "(21, 35): [20 20]\n",
944 | "(21, 36): [20 20]\n",
945 | "(21, 37): [34 22]\n",
946 | "(22, 3): [17 16]\n",
947 | "(22, 4): [17 16]\n",
948 | "(22, 5): [17 17]\n",
949 | "(22, 6): [17 17]\n",
950 | "(22, 7): [17 18]\n",
951 | "(22, 8): [17 18]\n",
952 | "(22, 9): [17 18]\n",
953 | "(22, 10): [16 19]\n",
954 | "(22, 11): [16 19]\n",
955 | "(22, 12): [16 19]\n",
956 | "(22, 13): [16 20]\n",
957 | "(22, 14): [16 20]\n",
958 | "(22, 15): [16 20]\n",
959 | "(22, 16): [16 21]\n",
960 | "(22, 17): [15 21]\n",
961 | "(22, 18): [15 22]\n",
962 | "(22, 19): [15 22]\n",
963 | "(22, 20): [15 23]\n",
964 | "(22, 21): [15 24]\n",
965 | "(22, 22): [15 24]\n",
966 | "(22, 23): [15 25]\n",
967 | "(22, 24): [16 25]\n",
968 | "(22, 25): [15 26]\n",
969 | "(22, 26): [16 27]\n",
970 | "(22, 27): [27 23]\n",
971 | "(22, 28): [27 23]\n",
972 | "(22, 29): [28 23]\n",
973 | "(22, 30): [28 22]\n",
974 | "(22, 31): [29 22]\n",
975 | "(22, 32): [31 22]\n",
976 | "(22, 33): [31 22]\n",
977 | "(22, 34): [33 22]\n",
978 | "(22, 35): [34 22]\n",
979 | "(22, 36): [35 22]\n",
980 | "(22, 37): [36 22]\n",
981 | "(23, 3): [18 16]\n",
982 | "(23, 4): [18 17]\n",
983 | "(23, 5): [17 17]\n",
984 | "(23, 6): [17 17]\n",
985 | "(23, 7): [17 18]\n",
986 | "(23, 8): [17 18]\n",
987 | "(23, 9): [17 18]\n",
988 | "(23, 10): [17 19]\n",
989 | "(23, 11): [17 19]\n",
990 | "(23, 12): [17 19]\n",
991 | "(23, 13): [16 20]\n",
992 | "(23, 14): [16 20]\n",
993 | "(23, 15): [16 21]\n",
994 | "(23, 16): [16 21]\n",
995 | "(23, 17): [16 21]\n",
996 | "(23, 18): [16 22]\n",
997 | "(23, 19): [16 22]\n",
998 | "(23, 20): [16 23]\n",
999 | "(23, 21): [16 24]\n",
1000 | "(23, 22): [16 24]\n",
1001 | "(23, 23): [16 25]\n",
1002 | "(23, 24): [18 24]\n",
1003 | "(23, 25): [16 26]\n",
1004 | "(23, 26): [26 23]\n",
1005 | "(23, 27): [27 23]\n",
1006 | "(23, 28): [27 22]\n",
1007 | "(23, 29): [27 22]\n",
1008 | "(23, 30): [28 22]\n",
1009 | "(23, 31): [28 22]\n",
1010 | "(23, 32): [29 21]\n",
1011 | "(23, 33): [30 21]\n",
1012 | "(23, 34): [30 21]\n",
1013 | "(23, 35): [31 21]\n",
1014 | "(23, 36): [31 20]\n",
1015 | "(23, 37): [33 20]\n",
1016 | "(24, 3): [18 16]\n",
1017 | "(24, 4): [18 17]\n",
1018 | "(24, 5): [18 17]\n",
1019 | "(24, 6): [18 17]\n",
1020 | "(24, 7): [18 18]\n",
1021 | "(24, 8): [17 18]\n",
1022 | "(24, 9): [17 19]\n",
1023 | "(24, 10): [17 19]\n",
1024 | "(24, 11): [17 19]\n",
1025 | "(24, 12): [17 20]\n",
1026 | "(24, 13): [17 20]\n",
1027 | "(24, 14): [17 20]\n",
1028 | "(24, 15): [16 21]\n",
1029 | "(24, 16): [16 21]\n",
1030 | "(24, 17): [16 22]\n",
1031 | "(24, 18): [16 22]\n",
1032 | "(24, 19): [16 23]\n",
1033 | "(24, 20): [16 23]\n",
1034 | "(24, 21): [16 24]\n",
1035 | "(24, 22): [16 24]\n",
1036 | "(24, 23): [16 25]\n",
1037 | "(24, 24): [25 23]\n",
1038 | "(24, 25): [25 23]\n",
1039 | "(24, 26): [26 23]\n",
1040 | "(24, 27): [26 22]\n",
1041 | "(24, 28): [27 22]\n",
1042 | "(24, 29): [27 22]\n",
1043 | "(24, 30): [27 22]\n",
1044 | "(24, 31): [28 21]\n",
1045 | "(24, 32): [29 21]\n",
1046 | "(24, 33): [29 21]\n",
1047 | "(24, 34): [30 20]\n",
1048 | "(24, 35): [30 20]\n",
1049 | "(24, 36): [31 19]\n",
1050 | "(24, 37): [34 20]\n",
1051 | "(25, 3): [18 17]\n",
1052 | "(25, 4): [18 17]\n",
1053 | "(25, 5): [18 17]\n",
1054 | "(25, 6): [18 18]\n",
1055 | "(25, 7): [18 18]\n",
1056 | "(25, 8): [18 18]\n",
1057 | "(25, 9): [18 19]\n",
1058 | "(25, 10): [17 19]\n",
1059 | "(25, 11): [17 19]\n",
1060 | "(25, 12): [17 20]\n",
1061 | "(25, 13): [17 20]\n",
1062 | "(25, 14): [17 20]\n",
1063 | "(25, 15): [17 21]\n",
1064 | "(25, 16): [17 21]\n",
1065 | "(25, 17): [17 22]\n",
1066 | "(25, 18): [17 22]\n",
1067 | "(25, 19): [17 23]\n",
1068 | "(25, 20): [17 23]\n",
1069 | "(25, 21): [17 24]\n",
1070 | "(25, 22): [24 23]\n",
1071 | "(25, 23): [24 23]\n",
1072 | "(25, 24): [25 23]\n",
1073 | "(25, 25): [25 22]\n",
1074 | "(25, 26): [25 22]\n",
1075 | "(25, 27): [26 22]\n",
1076 | "(25, 28): [26 22]\n",
1077 | "(25, 29): [27 21]\n",
1078 | "(25, 30): [27 21]\n",
1079 | "(25, 31): [28 21]\n",
1080 | "(25, 32): [28 20]\n",
1081 | "(25, 33): [29 20]\n",
1082 | "(25, 34): [29 20]\n",
1083 | "(25, 35): [30 19]\n",
1084 | "(25, 36): [30 19]\n",
1085 | "(25, 37): [31 19]\n",
1086 | "(26, 3): [19 17]\n",
1087 | "(26, 4): [19 17]\n",
1088 | "(26, 5): [18 17]\n",
1089 | "(26, 6): [18 18]\n",
1090 | "(26, 7): [18 18]\n",
1091 | "(26, 8): [18 18]\n",
1092 | "(26, 9): [18 19]\n",
1093 | "(26, 10): [18 19]\n",
1094 | "(26, 11): [18 19]\n",
1095 | "(26, 12): [17 20]\n",
1096 | "(26, 13): [17 20]\n",
1097 | "(26, 14): [17 21]\n",
1098 | "(26, 15): [17 21]\n",
1099 | "(26, 16): [17 21]\n",
1100 | "(26, 17): [17 22]\n",
1101 | "(26, 18): [17 22]\n",
1102 | "(26, 19): [17 23]\n",
1103 | "(26, 20): [17 23]\n",
1104 | "(26, 21): [17 24]\n",
1105 | "(26, 22): [24 23]\n",
1106 | "(26, 23): [24 23]\n",
1107 | "(26, 24): [24 22]\n",
1108 | "(26, 25): [25 22]\n",
1109 | "(26, 26): [25 22]\n",
1110 | "(26, 27): [25 21]\n",
1111 | "(26, 28): [26 21]\n",
1112 | "(26, 29): [26 21]\n",
1113 | "(26, 30): [27 21]\n",
1114 | "(26, 31): [27 20]\n",
1115 | "(26, 32): [28 20]\n",
1116 | "(26, 33): [28 20]\n",
1117 | "(26, 34): [29 19]\n",
1118 | "(26, 35): [29 19]\n",
1119 | "(26, 36): [30 19]\n",
1120 | "(26, 37): [31 18]\n",
1121 | "(27, 3): [19 17]\n",
1122 | "(27, 4): [19 17]\n",
1123 | "(27, 5): [19 17]\n",
1124 | "(27, 6): [19 18]\n",
1125 | "(27, 7): [18 18]\n",
1126 | "(27, 8): [18 19]\n",
1127 | "(27, 9): [18 19]\n",
1128 | "(27, 10): [18 19]\n",
1129 | "(27, 11): [18 20]\n",
1130 | "(27, 12): [18 20]\n",
1131 | "(27, 13): [18 20]\n",
1132 | "(27, 14): [18 21]\n",
1133 | "(27, 15): [18 21]\n",
1134 | "(27, 16): [17 21]\n",
1135 | "(27, 17): [17 22]\n",
1136 | "(27, 18): [17 22]\n",
1137 | "(27, 19): [17 23]\n",
1138 | "(27, 20): [17 23]\n",
1139 | "(27, 21): [17 24]\n",
1140 | "(27, 22): [17 24]\n",
1141 | "(27, 23): [24 22]\n",
1142 | "(27, 24): [24 22]\n",
1143 | "(27, 25): [24 22]\n",
1144 | "(27, 26): [25 21]\n",
1145 | "(27, 27): [25 21]\n",
1146 | "(27, 28): [25 21]\n",
1147 | "(27, 29): [26 21]\n",
1148 | "(27, 30): [26 20]\n",
1149 | "(27, 31): [27 20]\n",
1150 | "(27, 32): [27 20]\n",
1151 | "(27, 33): [28 19]\n",
1152 | "(27, 34): [28 19]\n",
1153 | "(27, 35): [29 18]\n",
1154 | "(27, 36): [30 18]\n",
1155 | "(27, 37): [31 18]\n",
1156 | "(28, 3): [19 17]\n",
1157 | "(28, 4): [19 17]\n",
1158 | "(28, 5): [19 18]\n",
1159 | "(28, 6): [19 18]\n",
1160 | "(28, 7): [19 18]\n",
1161 | "(28, 8): [19 19]\n",
1162 | "(28, 9): [18 19]\n",
1163 | "(28, 10): [18 19]\n",
1164 | "(28, 11): [18 20]\n",
1165 | "(28, 12): [18 20]\n",
1166 | "(28, 13): [18 20]\n",
1167 | "(28, 14): [18 21]\n",
1168 | "(28, 15): [18 21]\n",
1169 | "(28, 16): [18 21]\n",
1170 | "(28, 17): [18 22]\n",
1171 | "(28, 18): [18 22]\n",
1172 | "(28, 19): [18 23]\n",
1173 | "(28, 20): [18 23]\n",
1174 | "(28, 21): [18 24]\n",
1175 | "(28, 22): [23 22]\n",
1176 | "(28, 23): [23 22]\n",
1177 | "(28, 24): [24 22]\n",
1178 | "(28, 25): [24 21]\n",
1179 | "(28, 26): [24 21]\n",
1180 | "(28, 27): [25 21]\n",
1181 | "(28, 28): [25 21]\n",
1182 | "(28, 29): [25 20]\n",
1183 | "(28, 30): [26 20]\n",
1184 | "(28, 31): [26 20]\n",
1185 | "(28, 32): [27 19]\n",
1186 | "(28, 33): [27 19]\n",
1187 | "(28, 34): [28 18]\n",
1188 | "(28, 35): [28 18]\n",
1189 | "(28, 36): [29 18]\n",
1190 | "(28, 37): [30 17]\n",
1191 | "(29, 3): [20 17]\n",
1192 | "(29, 4): [19 17]\n",
1193 | "(29, 5): [19 18]\n",
1194 | "(29, 6): [19 18]\n",
1195 | "(29, 7): [19 18]\n",
1196 | "(29, 8): [19 19]\n",
1197 | "(29, 9): [19 19]\n",
1198 | "(29, 10): [19 19]\n",
1199 | "(29, 11): [19 20]\n",
1200 | "(29, 12): [18 20]\n",
1201 | "(29, 13): [18 20]\n",
1202 | "(29, 14): [18 21]\n",
1203 | "(29, 15): [18 21]\n",
1204 | "(29, 16): [18 22]\n",
1205 | "(29, 17): [18 22]\n",
1206 | "(29, 18): [18 22]\n",
1207 | "(29, 19): [18 23]\n",
1208 | "(29, 20): [18 23]\n",
1209 | "(29, 21): [18 24]\n",
1210 | "(29, 22): [23 22]\n",
1211 | "(29, 23): [23 22]\n",
1212 | "(29, 24): [23 21]\n",
1213 | "(29, 25): [24 21]\n",
1214 | "(29, 26): [24 21]\n",
1215 | "(29, 27): [24 21]\n",
1216 | "(29, 28): [25 20]\n",
1217 | "(29, 29): [25 20]\n",
1218 | "(29, 30): [26 20]\n",
1219 | "(29, 31): [26 19]\n",
1220 | "(29, 32): [26 19]\n",
1221 | "(29, 33): [27 19]\n",
1222 | "(29, 34): [28 18]\n",
1223 | "(29, 35): [28 18]\n",
1224 | "(29, 36): [29 17]\n",
1225 | "(29, 37): [30 17]\n",
1226 | "(30, 3): [20 17]\n",
1227 | "(30, 4): [20 17]\n",
1228 | "(30, 5): [20 18]\n",
1229 | "(30, 6): [20 18]\n",
1230 | "(30, 7): [19 18]\n",
1231 | "(30, 8): [19 19]\n",
1232 | "(30, 9): [19 19]\n",
1233 | "(30, 10): [19 20]\n",
1234 | "(30, 11): [19 20]\n",
1235 | "(30, 12): [19 20]\n",
1236 | "(30, 13): [19 21]\n",
1237 | "(30, 14): [19 21]\n",
1238 | "(30, 15): [18 21]\n",
1239 | "(30, 16): [18 22]\n",
1240 | "(30, 17): [18 22]\n",
1241 | "(30, 18): [18 22]\n",
1242 | "(30, 19): [18 23]\n",
1243 | "(30, 20): [20 20]\n",
1244 | "(30, 21): [20 20]\n",
1245 | "(30, 22): [23 22]\n",
1246 | "(30, 23): [23 22]\n",
1247 | "(30, 24): [23 21]\n",
1248 | "(30, 25): [23 21]\n",
1249 | "(30, 26): [24 21]\n",
1250 | "(30, 27): [24 20]\n",
1251 | "(30, 28): [24 20]\n",
1252 | "(30, 29): [25 20]\n",
1253 | "(30, 30): [25 19]\n",
1254 | "(30, 31): [26 19]\n",
1255 | "(30, 32): [26 19]\n",
1256 | "(30, 33): [27 18]\n",
1257 | "(30, 34): [27 18]\n",
1258 | "(30, 35): [28 17]\n",
1259 | "(30, 36): [28 17]\n",
1260 | "(30, 37): [29 16]\n",
1261 | "(31, 3): [20 17]\n",
1262 | "(31, 4): [20 17]\n",
1263 | "(31, 5): [20 18]\n",
1264 | "(31, 6): [20 18]\n",
1265 | "(31, 7): [20 18]\n",
1266 | "(31, 8): [20 19]\n",
1267 | "(31, 9): [19 19]\n",
1268 | "(31, 10): [19 20]\n",
1269 | "(31, 11): [19 20]\n",
1270 | "(31, 12): [19 20]\n",
1271 | "(31, 13): [19 21]\n",
1272 | "(31, 14): [19 21]\n",
1273 | "(31, 15): [19 21]\n",
1274 | "(31, 16): [19 22]\n",
1275 | "(31, 17): [19 22]\n",
1276 | "(31, 18): [19 22]\n",
1277 | "(31, 19): [18 23]\n",
1278 | "(31, 20): [20 20]\n",
1279 | "(31, 21): [20 20]\n",
1280 | "(31, 22): [23 22]\n",
1281 | "(31, 23): [23 21]\n",
1282 | "(31, 24): [23 21]\n",
1283 | "(31, 25): [23 21]\n",
1284 | "(31, 26): [23 20]\n",
1285 | "(31, 27): [24 20]\n",
1286 | "(31, 28): [24 20]\n",
1287 | "(31, 29): [24 19]\n",
1288 | "(31, 30): [25 19]\n",
1289 | "(31, 31): [25 19]\n",
1290 | "(31, 32): [26 18]\n",
1291 | "(31, 33): [26 18]\n",
1292 | "(31, 34): [27 18]\n",
1293 | "(31, 35): [27 17]\n",
1294 | "(31, 36): [28 17]\n",
1295 | "(31, 37): [29 16]\n",
1296 | "(32, 3): [21 17]\n",
1297 | "(32, 4): [21 17]\n",
1298 | "(32, 5): [20 18]\n",
1299 | "(32, 6): [20 18]\n",
1300 | "(32, 7): [20 19]\n",
1301 | "(32, 8): [20 19]\n",
1302 | "(32, 9): [20 19]\n",
1303 | "(32, 10): [20 20]\n",
1304 | "(32, 11): [19 20]\n",
1305 | "(32, 12): [19 20]\n",
1306 | "(32, 13): [19 21]\n",
1307 | "(32, 14): [19 21]\n",
1308 | "(32, 15): [19 21]\n",
1309 | "(32, 16): [19 22]\n",
1310 | "(32, 17): [19 22]\n",
1311 | "(32, 18): [19 23]\n",
1312 | "(32, 19): [19 23]\n",
1313 | "(32, 20): [19 23]\n",
1314 | "(32, 21): [22 22]\n",
1315 | "(32, 22): [22 21]\n",
1316 | "(32, 23): [22 21]\n",
1317 | "(32, 24): [23 21]\n",
1318 | "(32, 25): [23 21]\n",
1319 | "(32, 26): [23 20]\n",
1320 | "(32, 27): [23 20]\n",
1321 | "(32, 28): [24 19]\n",
1322 | "(32, 29): [24 19]\n",
1323 | "(32, 30): [25 19]\n",
1324 | "(32, 31): [25 18]\n",
1325 | "(32, 32): [25 18]\n",
1326 | "(32, 33): [26 18]\n",
1327 | "(32, 34): [26 17]\n",
1328 | "(32, 35): [27 17]\n",
1329 | "(32, 36): [28 16]\n",
1330 | "(32, 37): [28 16]\n",
1331 | "(33, 3): [21 17]\n",
1332 | "(33, 4): [21 17]\n",
1333 | "(33, 5): [21 18]\n",
1334 | "(33, 6): [20 18]\n",
1335 | "(33, 7): [20 19]\n",
1336 | "(33, 8): [20 19]\n",
1337 | "(33, 9): [20 19]\n",
1338 | "(33, 10): [20 20]\n",
1339 | "(33, 11): [20 20]\n",
1340 | "(33, 12): [20 20]\n",
1341 | "(33, 13): [19 21]\n",
1342 | "(33, 14): [19 21]\n",
1343 | "(33, 15): [19 21]\n",
1344 | "(33, 16): [19 22]\n",
1345 | "(33, 17): [19 22]\n",
1346 | "(33, 18): [19 23]\n",
1347 | "(33, 19): [21 22]\n",
1348 | "(33, 20): [22 22]\n",
1349 | "(33, 21): [22 22]\n",
1350 | "(33, 22): [22 21]\n",
1351 | "(33, 23): [22 21]\n",
1352 | "(33, 24): [22 21]\n",
1353 | "(33, 25): [23 20]\n",
1354 | "(33, 26): [23 20]\n",
1355 | "(33, 27): [23 20]\n",
1356 | "(33, 28): [24 19]\n",
1357 | "(33, 29): [24 19]\n",
1358 | "(33, 30): [24 19]\n",
1359 | "(33, 31): [25 18]\n",
1360 | "(33, 32): [25 18]\n",
1361 | "(33, 33): [26 17]\n",
1362 | "(33, 34): [26 17]\n",
1363 | "(33, 35): [27 17]\n",
1364 | "(33, 36): [27 16]\n",
1365 | "(33, 37): [28 16]\n",
1366 | "(34, 3): [21 17]\n",
1367 | "(34, 4): [21 18]\n",
1368 | "(34, 5): [21 18]\n",
1369 | "(34, 6): [21 18]\n",
1370 | "(34, 7): [21 19]\n",
1371 | "(34, 8): [20 19]\n",
1372 | "(34, 9): [20 20]\n",
1373 | "(34, 10): [20 20]\n",
1374 | "(34, 11): [20 20]\n",
1375 | "(34, 12): [20 21]\n",
1376 | "(34, 13): [20 21]\n",
1377 | "(34, 14): [20 21]\n",
1378 | "(34, 15): [19 22]\n",
1379 | "(34, 16): [19 22]\n",
1380 | "(34, 17): [19 22]\n",
1381 | "(34, 18): [21 22]\n",
1382 | "(34, 19): [21 22]\n",
1383 | "(34, 20): [21 22]\n",
1384 | "(34, 21): [22 21]\n",
1385 | "(34, 22): [22 21]\n",
1386 | "(34, 23): [22 21]\n",
1387 | "(34, 24): [22 20]\n",
1388 | "(34, 25): [23 20]\n",
1389 | "(34, 26): [23 20]\n",
1390 | "(34, 27): [23 19]\n",
1391 | "(34, 28): [23 19]\n",
1392 | "(34, 29): [24 19]\n",
1393 | "(34, 30): [24 18]\n",
1394 | "(34, 31): [24 18]\n",
1395 | "(34, 32): [25 18]\n",
1396 | "(34, 33): [25 17]\n",
1397 | "(34, 34): [26 17]\n",
1398 | "(34, 35): [26 16]\n",
1399 | "(34, 36): [27 16]\n",
1400 | "(34, 37): [27 15]\n",
1401 | "(35, 3): [22 17]\n",
1402 | "(35, 4): [21 18]\n",
1403 | "(35, 5): [21 18]\n",
1404 | "(35, 6): [21 18]\n",
1405 | "(35, 7): [21 19]\n",
1406 | "(35, 8): [21 19]\n",
1407 | "(35, 9): [21 20]\n",
1408 | "(35, 10): [20 20]\n",
1409 | "(35, 11): [20 20]\n",
1410 | "(35, 12): [20 21]\n",
1411 | "(35, 13): [20 21]\n",
1412 | "(35, 14): [20 21]\n",
1413 | "(35, 15): [20 22]\n",
1414 | "(35, 16): [20 22]\n",
1415 | "(35, 17): [20 22]\n",
1416 | "(35, 18): [20 23]\n",
1417 | "(35, 19): [19 23]\n",
1418 | "(35, 20): [21 22]\n",
1419 | "(35, 21): [21 21]\n",
1420 | "(35, 22): [22 21]\n",
1421 | "(35, 23): [22 21]\n",
1422 | "(35, 24): [22 20]\n",
1423 | "(35, 25): [22 20]\n",
1424 | "(35, 26): [22 19]\n",
1425 | "(35, 27): [23 19]\n",
1426 | "(35, 28): [23 19]\n",
1427 | "(35, 29): [23 18]\n",
1428 | "(35, 30): [24 18]\n",
1429 | "(35, 31): [24 18]\n",
1430 | "(35, 32): [24 17]\n",
1431 | "(35, 33): [25 17]\n",
1432 | "(35, 34): [25 17]\n",
1433 | "(35, 35): [26 16]\n",
1434 | "(35, 36): [26 16]\n",
1435 | "(35, 37): [27 15]\n",
1436 | "(36, 3): [22 17]\n",
1437 | "(36, 4): [22 18]\n",
1438 | "(36, 5): [22 18]\n",
1439 | "(36, 6): [21 19]\n",
1440 | "(36, 7): [21 19]\n",
1441 | "(36, 8): [21 19]\n",
1442 | "(36, 9): [21 20]\n",
1443 | "(36, 10): [21 20]\n",
1444 | "(36, 11): [20 20]\n",
1445 | "(36, 12): [20 21]\n",
1446 | "(36, 13): [20 21]\n",
1447 | "(36, 14): [20 21]\n",
1448 | "(36, 15): [20 22]\n",
1449 | "(36, 16): [20 22]\n",
1450 | "(36, 17): [20 22]\n",
1451 | "(36, 18): [20 23]\n",
1452 | "(36, 19): [21 22]\n",
1453 | "(36, 20): [21 22]\n",
1454 | "(36, 21): [21 21]\n",
1455 | "(36, 22): [21 21]\n",
1456 | "(36, 23): [22 20]\n",
1457 | "(36, 24): [22 20]\n",
1458 | "(36, 25): [22 20]\n",
1459 | "(36, 26): [22 19]\n",
1460 | "(36, 27): [22 19]\n",
1461 | "(36, 28): [23 19]\n",
1462 | "(36, 29): [23 18]\n",
1463 | "(36, 30): [23 18]\n",
1464 | "(36, 31): [24 17]\n",
1465 | "(36, 32): [24 17]\n",
1466 | "(36, 33): [24 17]\n",
1467 | "(36, 34): [25 16]\n",
1468 | "(36, 35): [26 16]\n",
1469 | "(36, 36): [26 15]\n",
1470 | "(36, 37): [27 15]\n",
1471 | "(37, 3): [22 17]\n",
1472 | "(37, 4): [22 18]\n",
1473 | "(37, 5): [22 18]\n",
1474 | "(37, 6): [22 19]\n",
1475 | "(37, 7): [21 19]\n",
1476 | "(37, 8): [21 19]\n",
1477 | "(37, 9): [21 20]\n",
1478 | "(37, 10): [21 20]\n",
1479 | "(37, 11): [21 20]\n",
1480 | "(37, 12): [21 21]\n",
1481 | "(37, 13): [21 21]\n",
1482 | "(37, 14): [20 21]\n",
1483 | "(37, 15): [20 22]\n",
1484 | "(37, 16): [20 22]\n",
1485 | "(37, 17): [20 22]\n",
1486 | "(37, 18): [20 22]\n",
1487 | "(37, 19): [21 22]\n",
1488 | "(37, 20): [21 21]\n",
1489 | "(37, 21): [21 21]\n",
1490 | "(37, 22): [21 21]\n",
1491 | "(37, 23): [21 20]\n",
1492 | "(37, 24): [22 20]\n",
1493 | "(37, 25): [22 20]\n",
1494 | "(37, 26): [22 19]\n",
1495 | "(37, 27): [22 19]\n",
1496 | "(37, 28): [23 18]\n",
1497 | "(37, 29): [23 18]\n",
1498 | "(37, 30): [23 18]\n",
1499 | "(37, 31): [24 17]\n",
1500 | "(37, 32): [24 17]\n",
1501 | "(37, 33): [24 17]\n",
1502 | "(37, 34): [25 16]\n",
1503 | "(37, 35): [25 16]\n",
1504 | "(37, 36): [26 15]\n",
1505 | "(37, 37): [26 15]\n"
1506 | ]
1507 | }
1508 | ],
1509 | "source": [
1510 | "z_min = np.min(all_z, axis = 0)\n",
1511 | "z_max = np.max(all_z, axis = 0)\n",
1512 | "all_z = np.round(5 * all_z + 20).astype(np.int)\n",
1513 | "# all_z = (all_z - z_min)\n",
1514 | "latent_map = {}\n",
1515 | "i = 0\n",
1516 | "for x in range(start, end + 1):\n",
1517 | " for y in range(start, end + 1):\n",
1518 | " latent_map[(x,y)] = all_z[i]\n",
1519 | " i += 1\n",
1520 | "for k in latent_map:\n",
1521 | " print (str(k) + ': ' + str(latent_map[k]))"
1522 | ]
1523 | },
1524 | {
1525 | "cell_type": "code",
1526 | "execution_count": 35,
1527 | "metadata": {},
1528 | "outputs": [
1529 | {
1530 | "data": {
1531 | "text/plain": [
1532 | "(1.0746268656716418, 1.0831858407079646)"
1533 | ]
1534 | },
1535 | "execution_count": 35,
1536 | "metadata": {},
1537 | "output_type": "execute_result"
1538 | }
1539 | ],
1540 | "source": [
1541 | "count = 0\n",
1542 | "sum_diff_x = 0.0\n",
1543 | "sum_diff_y = 0.0\n",
1544 | "for i in range(0, all_z.shape[0] - 1):\n",
1545 | " sum_diff_x += np.abs(all_z[i][0] - all_z[i+1][0])\n",
1546 | " sum_diff_y += np.abs(all_z[i][1] - all_z[i+1][1])\n",
1547 | " count += 1\n",
1548 | "avg_diff_x = sum_diff_x / count\n",
1549 | "avg_diff_y = sum_diff_y / count\n",
1550 | "1 / avg_diff_x, 1 / avg_diff_y"
1551 | ]
1552 | },
1553 | {
1554 | "cell_type": "code",
1555 | "execution_count": 43,
1556 | "metadata": {},
1557 | "outputs": [
1558 | {
1559 | "data": {
1560 | "text/plain": [
1561 | ""
1562 | ]
1563 | },
1564 | "execution_count": 43,
1565 | "metadata": {},
1566 | "output_type": "execute_result"
1567 | },
1568 | {
1569 | "data": {
1570 | "image/png": "\n",
1571 | "text/plain": [
1572 | ""
1573 | ]
1574 | },
1575 | "metadata": {
1576 | "needs_background": "light"
1577 | },
1578 | "output_type": "display_data"
1579 | }
1580 | ],
1581 | "source": [
1582 | "scale = 5\n",
1583 | "img_latent = Image.new(\"RGB\", (width * 10, height * 10), \"#FFFFFF\")\n",
1584 | "draw = ImageDraw.Draw(img_latent)\n",
1585 | "for k in latent_map:\n",
1586 | " x, y = k\n",
1587 | " if [x, y] in unvalid_pos:\n",
1588 | " continue\n",
1589 | " else:\n",
1590 | " x_scaled, y_scaled = latent_map[k][1] * 10, latent_map[k][0] * 10\n",
1591 | " draw.ellipse((x_scaled-2, y_scaled-2, x_scaled+2, y_scaled+2), fill = img.getpixel((y, x)))\n",
1592 | "img_latent.save('latent_map.png', 'PNG')\n",
1593 | "img_latent_scaled = np.array(img_latent) / 255.\n",
1594 | "# plt.imshow(img_arr_scaled)\n",
1595 | "# plt.show()\n",
1596 | "# plt.imshow(img_latent_scaled)\n",
1597 | "plt.show()\n",
1598 | "f, axarr = plt.subplots(1,2, figsize=(15,15))\n",
1599 | "# plt.setp(axarr, xticks=[], yticks=[])\n",
1600 | "axarr[0].imshow(img_arr_scaled)\n",
1601 | "axarr[1].imshow(img_latent_scaled)"
1602 | ]
1603 | },
1604 | {
1605 | "cell_type": "code",
1606 | "execution_count": null,
1607 | "metadata": {},
1608 | "outputs": [],
1609 | "source": []
1610 | }
1611 | ],
1612 | "metadata": {
1613 | "kernelspec": {
1614 | "display_name": "Python 3",
1615 | "language": "python",
1616 | "name": "python3"
1617 | },
1618 | "language_info": {
1619 | "codemirror_mode": {
1620 | "name": "ipython",
1621 | "version": 3
1622 | },
1623 | "file_extension": ".py",
1624 | "mimetype": "text/x-python",
1625 | "name": "python",
1626 | "nbconvert_exporter": "python",
1627 | "pygments_lexer": "ipython3",
1628 | "version": "3.6.8"
1629 | }
1630 | },
1631 | "nbformat": 4,
1632 | "nbformat_minor": 2
1633 | }
1634 |
--------------------------------------------------------------------------------
/e2c.yml:
--------------------------------------------------------------------------------
1 | name: e2c
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - anaconda
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - _tflow_select=2.1.0=gpu
10 | - absl-py=0.7.1=py36_0
11 | - asn1crypto=0.24.0=py36_0
12 | - astor=0.8.0=py36_0
13 | - blas=1.0=mkl
14 | - c-ares=1.15.0=h7b6447c_1001
15 | - ca-certificates=2019.5.15=1
16 | - certifi=2019.9.11=py36_0
17 | - cffi=1.12.3=py36h2e261b9_0
18 | - chardet=3.0.4=py36_1003
19 | - cryptography=2.7=py36h1ba5d50_0
20 | - cudatoolkit=10.0.130=0
21 | - cudnn=7.6.0=cuda10.0_0
22 | - cupti=10.0.130=0
23 | - cycler=0.10.0=py_1
24 | - dbus=1.13.6=h746ee38_0
25 | - expat=2.2.6=he6710b0_0
26 | - fontconfig=2.13.1=he4413a7_1000
27 | - freetype=2.9.1=h8a8886c_1
28 | - gast=0.2.2=py36_0
29 | - gettext=0.19.8.1=hc5be6a0_1002
30 | - glib=2.56.2=had28632_1001
31 | - google-pasta=0.1.7=py_0
32 | - grpcio=1.16.1=py36hf8bcb03_1
33 | - gst-plugins-base=1.14.0=hbbd80ab_1
34 | - gstreamer=1.14.0=hb453b48_1
35 | - h5py=2.9.0=py36h7918eee_0
36 | - hdf5=1.10.4=hb1b8bf9_0
37 | - icu=58.2=hf484d3e_1000
38 | - idna=2.8=py36_0
39 | - intel-openmp=2019.4=243
40 | - jpeg=9b=h024ee3a_2
41 | - keras-applications=1.0.8=py_0
42 | - keras-preprocessing=1.1.0=py_1
43 | - kiwisolver=1.1.0=py36hc9558a2_0
44 | - libedit=3.1.20181209=hc058e9b_0
45 | - libffi=3.2.1=hd88cf55_4
46 | - libgcc-ng=9.1.0=hdf63c60_0
47 | - libgfortran-ng=7.3.0=hdf63c60_0
48 | - libiconv=1.15=h516909a_1005
49 | - libpng=1.6.37=hbc83047_0
50 | - libprotobuf=3.8.0=hd408876_0
51 | - libstdcxx-ng=9.1.0=hdf63c60_0
52 | - libtiff=4.0.10=h2733197_2
53 | - libuuid=2.32.1=h14c3975_1000
54 | - libxcb=1.13=h14c3975_1002
55 | - libxml2=2.9.9=h13577e0_2
56 | - markdown=3.1.1=py36_0
57 | - matplotlib=3.1.1=py36h5429711_0
58 | - mkl=2019.4=243
59 | - mkl-service=2.3.0=py36he904b0f_0
60 | - mkl_fft=1.0.14=py36ha843d7b_0
61 | - mkl_random=1.0.2=py36hd81dba3_0
62 | - ncurses=6.1=he6710b0_1
63 | - ninja=1.9.0=py36hfd86e86_0
64 | - numpy=1.16.5=py36h7e9f1db_0
65 | - numpy-base=1.16.5=py36hde5b4d6_0
66 | - olefile=0.46=py36_0
67 | - openssl=1.1.1=h7b6447c_0
68 | - pcre=8.43=he6710b0_0
69 | - pillow=6.1.0=py36h34e0f95_0
70 | - pip=19.2.2=py36_0
71 | - protobuf=3.8.0=py36he6710b0_0
72 | - pthread-stubs=0.4=h14c3975_1001
73 | - pycparser=2.19=py36_0
74 | - pyopenssl=19.0.0=py36_0
75 | - pyparsing=2.4.2=py_0
76 | - pyqt=5.9.2=py36hcca6a23_4
77 | - pysocks=1.7.1=py36_0
78 | - python=3.6.8=h0371630_0
79 | - python-dateutil=2.8.0=py_0
80 | - pytorch=1.2.0=py3.6_cuda10.0.130_cudnn7.6.2_0
81 | - pytz=2019.2=py_0
82 | - qt=5.9.7=h5867ecd_1
83 | - readline=7.0=h7b6447c_5
84 | - requests=2.22.0=py36_0
85 | - scipy=1.3.1=py36h7c811a0_0
86 | - setuptools=41.2.0=py36_0
87 | - sip=4.19.8=py36hf484d3e_1000
88 | - six=1.12.0=py36_0
89 | - sqlite=3.29.0=h7b6447c_0
90 | - tensorboard=1.14.0=py36hf484d3e_0
91 | - tensorboardx=1.8=py_0
92 | - tensorflow=1.14.0=gpu_py36h57aa796_0
93 | - tensorflow-base=1.14.0=gpu_py36h8d69cac_0
94 | - tensorflow-estimator=1.14.0=py_0
95 | - tensorflow-gpu=1.14.0=h0d30ee6_0
96 | - termcolor=1.1.0=py36_1
97 | - tk=8.6.8=hbc83047_0
98 | - torchvision=0.4.0=py36_cu100
99 | - tornado=6.0.3=py36h516909a_0
100 | - tqdm=4.36.1=py_0
101 | - urllib3=1.24.2=py36_0
102 | - werkzeug=0.15.5=py_0
103 | - wheel=0.33.6=py36_0
104 | - wrapt=1.11.2=py36h7b6447c_0
105 | - xorg-libxau=1.0.9=h14c3975_0
106 | - xorg-libxdmcp=1.1.3=h516909a_0
107 | - xz=5.2.4=h14c3975_4
108 | - zlib=1.2.11=h7b6447c_3
109 | - zstd=1.3.7=h0b5b093_0
110 | - pip:
111 | - future==0.17.1
112 | - gym==0.9.1
113 | - pyglet==1.4.4
114 | prefix: /home/tungnd13/miniconda3/envs/e2c
115 |
116 |
--------------------------------------------------------------------------------
/e2c_model.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | from normal import *
4 | from networks import *
5 |
6 | torch.set_default_dtype(torch.float64)
7 |
class E2C(nn.Module):
    """E2C model bundling an encoder, a decoder and a latent transition network.

    The concrete sub-network classes are looked up with ``load_config(env)``.
    """

    def __init__(self, obs_dim, z_dim, u_dim, env = 'planar'):
        """
        :param obs_dim: forwarded to the encoder and decoder as ``obs_dim``
        :param z_dim: forwarded to all three sub-networks as ``z_dim``
        :param u_dim: forwarded to the transition network as ``u_dim``
        :param env: configuration name handed to ``load_config``
        """
        super().__init__()
        encoder_cls, decoder_cls, transition_cls = load_config(env)

        self.obs_dim, self.z_dim, self.u_dim = obs_dim, z_dim, u_dim

        # Sub-modules are registered in the order encoder -> decoder -> trans.
        # No custom weight initialisation (e.g. ``.apply(init_weights)``) is
        # applied to any of them here.
        self.encoder = encoder_cls(obs_dim=obs_dim, z_dim=z_dim)
        self.decoder = decoder_cls(z_dim=z_dim, obs_dim=obs_dim)
        self.trans = transition_cls(z_dim=z_dim, u_dim=u_dim)

    def encode(self, x):
        """
        Run the encoder network on ``x``.

        :param x: input handed to the encoder
        :return: mean and log variance of q(z | x)
        """
        return self.encoder(x)

    def decode(self, z):
        """
        Run the decoder network on ``z``.

        :param z: latent handed to the decoder
        :return: bernoulli distribution p(x | z)
        """
        return self.decoder(z)

    def transition(self, z_bar, q_z, u):
        """
        Run the transition network.

        :param z_bar: latent sample handed to the transition network
        :param q_z: distribution object handed to the transition network
        :param u: control input handed to the transition network
        :return: samples z_hat_next and Q(z_hat_next)
        """
        return self.trans(z_bar, q_z, u)

    def reparam(self, mean, logvar):
        """
        Reparameterised sample ``mean + eps * sigma``.

        ``sigma`` is ``exp(logvar / 2)`` and ``eps`` is standard normal noise
        drawn with the same shape as ``sigma``.
        """
        std = torch.exp(logvar / 2)
        noise = torch.randn_like(std)
        return mean + noise * std

    def forward(self, x, u, x_next):
        """
        Encode ``x``, reconstruct it, and predict the next observation under ``u``.

        :param x: current observation
        :param u: control input
        :param x_next: next observation (only encoded, never decoded here)
        :return: tuple ``(x_recon, x_next_pred, q_z, q_z_next_pred, q_z_next)``
        """
        # Posterior over the current latent, and one sample from it.
        z_mean, z_logvar = self.encode(x)
        z_sample = self.reparam(z_mean, z_logvar)
        posterior = NormalDistribution(z_mean, z_logvar)

        reconstruction = self.decode(z_sample)

        # Step the sampled latent forward and decode the predicted next latent.
        z_next_sample, predicted_next_posterior = self.transition(z_sample, posterior, u)
        next_prediction = self.decode(z_next_sample)

        # Posterior obtained by encoding the next observation itself.
        next_mean, next_logvar = self.encode(x_next)
        next_posterior = NormalDistribution(mean=next_mean, logvar=next_logvar)

        return (
            reconstruction,
            next_prediction,
            posterior,
            predicted_next_posterior,
            next_posterior,
        )

    def predict(self, x, u):
        """
        Predict the next observation from ``x`` and control ``u``.

        :param x: current observation
        :param u: control input
        :return: decoder output for the latent produced by the transition network
        """
        # Calls the encoder module directly rather than going through encode().
        z_mean, z_logvar = self.encoder(x)
        z_sample = self.reparam(z_mean, z_logvar)
        posterior = NormalDistribution(z_mean, z_logvar)

        z_next_sample, _ = self.transition(z_sample, posterior, u)

        return self.decode(z_next_sample)
--------------------------------------------------------------------------------
/evaluate_saved_model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 34,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "data": {
10 | "text/plain": [
11 | "IncompatibleKeys(missing_keys=[], unexpected_keys=[])"
12 | ]
13 | },
14 | "execution_count": 34,
15 | "metadata": {},
16 | "output_type": "execute_result"
17 | }
18 | ],
19 | "source": [
20 | "import torch\n",
21 | "from torch.utils.data import random_split\n",
22 | "from torch.utils.data import DataLoader\n",
23 | "import matplotlib.pyplot as plt\n",
24 | "import random\n",
25 | "from datasets import *\n",
26 | "from e2c_model import E2C\n",
27 | "from train_e2c import evaluate\n",
28 | "model = E2C(1600,2,2,'planar_partial').cuda()\n",
29 | "model.load_state_dict(torch.load('result/planar_partial/log_check_valid/model_3000'))"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 35,
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "batch_size = 128\n",
39 | "propor = 3/4\n",
40 | "dataset = PlanarDataset('./data/data/planar_partial')\n",
41 | "train_set, test_set = dataset[:int(len(dataset) * propor)], dataset[int(len(dataset) * propor):]\n",
42 | "test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=8)"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 36,
48 | "metadata": {},
49 | "outputs": [
50 | {
51 | "name": "stdout",
52 | "output_type": "stream",
53 | "text": [
54 | "State loss: 12.65210761933395\n",
55 | "Next state loss: 12.661921219078062\n"
56 | ]
57 | }
58 | ],
59 | "source": [
60 | "state_loss, next_state_loss = evaluate(model, test_loader)\n",
61 | "print ('State loss: ' + str(state_loss))\n",
62 | "print ('Next state loss: ' + str(next_state_loss))"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 67,
68 | "metadata": {},
69 | "outputs": [
70 | {
71 | "data": {
72 | "text/plain": [
73 | ""
74 | ]
75 | },
76 | "execution_count": 67,
77 | "metadata": {},
78 | "output_type": "execute_result"
79 | },
80 | {
81 | "data": {
82 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAACqCAYAAACTZZUqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAEOklEQVR4nO3dMU4jWRRAURuZDKkXQEBCxg5YApsgYzOkELMmxAomn4QUCWqSETNliTaWq3zd5pyIL4xctD5Xn4ddvRyGYQHA/p3UFwDwUwkwQESAASICDBARYICIAANEVts8eLlces0asxqGYbnv57SvmdtX+9oJGCAiwAARAQaICDBARIABIgIMEBFggIgAA0QEGCAiwAARAQaICDBARIABIgIMEBFggIgAA0QEGCAiwAARAQaICDBARIABIgIMEBFggIgAA0QEGCAiwAARAQaICDBARIABIqv6Ao7NMAzffuxyuZzxSmA69vU8nIABIgIMEDGC2NE2v5pt+lq/uvE7Jyfj89L6/ln//MfHx5eP3cS+3g8nYICIAANEBBggcpQz4F3mV+ZVHIrT09PR+ubmZrS+uLgYra+urkbrx8fHz49fXl5Gn9vlZ4TpOAEDRAQYICLAAJGjnAHDMVif8T49PY3WZ2dno/X63y9Wq/9+vO/u7kafe39/n+IS2ZETMEBEgAEiAgwQMQPe0frczW37mMrt7e1o/evXr9F6070hLi8vPz/edq/Z1/vhBAwQEWCAiAADRMyAJ2b+xVTOz89H6/U57Pr6//f/XSwWi+fn5y8fuy37eh5OwAARAQaICDBA5ChnwOZVHIOHh4fR+u3tbbS+vr4erV9fX0fr+/v7z4/d//cwOQEDRAQYILLc8i2Gfo9hVsMw7H1+dKj7en2Utmm96WVpdL7a107AABEBBogIMEDkKF+GBsdg01uP+fM5AQNEBBggIsAAEQEGiAgwQESAASICDBDZ2+uAt30N4y63lNznc1XK7/En/Pt+l309rZ+2r52AASICDBARYIDIbPcDnvp967+bt0z5XIc8N6u+z30+76HfD9i+nt5P3tdOwAARAQaICDBARIABIgIMEBFggIgAA0QEGCAiwAARAQaIzPZW5HVu2zetY71t36G/FXmdfT2tn7avnYABIgIMEBFggMjeZsDwHX/aDBi+wwwY4MAIMEBEgAEiAgwQEWCAiAADRAQYICLAABEBBogIMEBEgAEiAgwQEWCAiAADRAQYICLAABEBBogIMEBkVV/Asdnyv3ia8UpgOvb1PJyAASICDBARYICIGfCOtpmNbfpaszMOhX29H07AABEBBogIMEBEgAEiAgwQEWCAiAADRAQYICLAABEBBoh4K/KO1t9m6bZ9HAP7ej+cgAEiAgwQEWCAiBnwxMy/OEb29TycgAEiAgwQEWCAiAADRAQYICLAABEBBogIMEBEgAEiAgwQEWCAiAADRAQYICLAABEBBogIMEBEgAEiAgwQEWCAiAADRAQYICLAABEBBogIMEBEgAEiAgwQWW35+L8Xi8Vfc1wILBaLi+h57Wvm9OW+Xg7DsM8LAeBfRhAAEQEGiAgwQESAASICDBARYICIAANEBBggIsAAkX8ArTL/DHro1LgAAAAASUVORK5CYII=\n",
83 | "text/plain": [
84 | ""
85 | ]
86 | },
87 | "metadata": {},
88 | "output_type": "display_data"
89 | }
90 | ],
91 | "source": [
92 | "rand_idx = random.randint(0, len(test_set))\n",
93 | "x, u, x_next = test_set[rand_idx]\n",
94 | "with torch.no_grad():\n",
95 | " x_next_pred = model.predict(x.unsqueeze(0).view(-1,1600).double().cuda(),\n",
96 | " torch.Tensor(u).unsqueeze(0).double().cuda())\n",
97 | "plt.show()\n",
98 | "f, axarr = plt.subplots(1,2)\n",
99 | "plt.setp(axarr, xticks=[], yticks=[])\n",
100 | "axarr[0].imshow(x_next.squeeze(), cmap='gray')\n",
101 | "axarr[1].imshow(x_next_pred.squeeze().view(40,40).cpu(), cmap='gray')"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": null,
107 | "metadata": {},
108 | "outputs": [],
109 | "source": []
110 | }
111 | ],
112 | "metadata": {
113 | "kernelspec": {
114 | "display_name": "Python 3",
115 | "language": "python",
116 | "name": "python3"
117 | },
118 | "language_info": {
119 | "codemirror_mode": {
120 | "name": "ipython",
121 | "version": 3
122 | },
123 | "file_extension": ".py",
124 | "mimetype": "text/x-python",
125 | "name": "python",
126 | "nbconvert_exporter": "python",
127 | "pygments_lexer": "ipython3",
128 | "version": "3.6.8"
129 | }
130 | },
131 | "nbformat": 4,
132 | "nbformat_minor": 2
133 | }
134 |
--------------------------------------------------------------------------------
/map.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tung-nd/E2C-pytorch/8af375b58e1ddef3cd041d518b7061faece2f102/map.png
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from normal import NormalDistribution
4 |
5 | torch.set_default_dtype(torch.float64)
6 |
def weights_init(m):
    """Orthogonally initialise the weight of a linear or (transposed) conv layer.

    Written to be passed to ``nn.Module.apply``: it is called once per
    sub-module and leaves every module of another kind untouched. Only
    ``m.weight`` is re-initialised; biases keep their default values.

    :param m: the module visited by ``apply``
    """
    # isinstance (instead of an exact type comparison) so that subclasses of
    # the listed layers are initialised as well.
    if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
        torch.nn.init.orthogonal_(m.weight)
10 |
class Encoder(nn.Module):
    """Wrap a network whose output is split in two halves along dim 1."""

    def __init__(self, net, obs_dim, z_dim):
        """
        :param net: the network to wrap; ``weights_init`` is applied to it
        :param obs_dim: observation size (stored as ``img_dim``)
        :param z_dim: latent size
        """
        super().__init__()
        self.net = net
        self.net.apply(weights_init)
        self.img_dim = obs_dim
        self.z_dim = z_dim

    def forward(self, x):
        """
        :param x: observation
        :return: the parameters of distribution q(z|x) as a 2-tuple:
                 the first half of the network output is the mean, the
                 second half the log variance
        """
        params = self.net(x)
        return torch.chunk(params, 2, dim=1)
25 |
class Decoder(nn.Module):
    """Wrap a network that maps a latent state back to an observation."""

    def __init__(self, net, z_dim, obs_dim):
        """
        :param net: the network to wrap; ``weights_init`` is applied to it
        :param z_dim: latent size
        :param obs_dim: observation size
        """
        super().__init__()
        self.net = net
        self.net.apply(weights_init)
        self.z_dim = z_dim
        self.obs_dim = obs_dim

    def forward(self, z):
        """
        :param z: sample from q(z|x)
        :return: reconstructed x, i.e. the wrapped network applied to ``z``
        """
        x_recon = self.net(z)
        return x_recon
40 |
class Transition(nn.Module):
    """Locally-linear latent dynamics: next mean = A_t mu_t + B_t u_t + o_t.

    A shared trunk ``net`` produces a hidden vector from the reference point;
    three heads turn it into ``A_t = I + v_t r_t^T``, ``B_t`` and ``o_t``.
    """

    def __init__(self, net, z_dim, u_dim):
        """
        :param net: trunk network; ``weights_init`` is applied to it. It must be
                    indexable and its third-to-last layer must expose
                    ``out_features`` (the sub-classes in this file end with
                    Linear, BatchNorm1d, ReLU)
        :param z_dim: latent size
        :param u_dim: action size
        """
        super(Transition, self).__init__()
        self.net = net  # network to output the last layer before predicting A_t, B_t and o_t
        self.net.apply(weights_init)
        self.h_dim = self.net[-3].out_features
        self.z_dim = z_dim
        self.u_dim = u_dim

        # Head for A_t: outputs v_t and r_t side by side, squashed to (0, 1).
        self.fc_A = nn.Sequential(
            nn.Linear(self.h_dim, self.z_dim * 2),  # v_t and r_t
            nn.Sigmoid()
        )
        self.fc_A.apply(weights_init)

        # Heads for B_t (flattened z_dim x u_dim matrix) and the offset o_t.
        self.fc_B = nn.Linear(self.h_dim, self.z_dim * self.u_dim)
        torch.nn.init.orthogonal_(self.fc_B.weight)
        self.fc_o = nn.Linear(self.h_dim, self.z_dim)
        torch.nn.init.orthogonal_(self.fc_o.weight)

    def forward(self, z_bar_t, q_z_t, u_t):
        """
        :param z_bar_t: the reference point, shape (batch, z_dim)
        :param q_z_t: the distribution q(z|x); its ``mean`` and ``logvar`` are read
        :param u_t: the action taken, shape (batch, u_dim)
        :return: ``(mean, dist)`` where ``mean`` is A_t mu_t + B_t u_t + o_t and
                 ``dist`` is the predicted q(z^_t+1 | z_t, z_bar_t, u_t), built
                 from that mean, the input log variance, v_t, r_t and A_t
        """
        h_t = self.net(z_bar_t)
        B_t = self.fc_B(h_t).view(-1, self.z_dim, self.u_dim)
        o_t = self.fc_o(h_t)

        # v_t and r_t are both (batch, z_dim); keep these 2-D tensors for the
        # returned distribution instead of squeezing the 3-D versions, because a
        # bare squeeze() would also drop the batch axis when batch == 1 (or the
        # latent axis when z_dim == 1).
        v_t, r_t = self.fc_A(h_t).chunk(2, dim=1)
        v_col = v_t.unsqueeze(-1)  # (batch, z_dim, 1)
        r_row = r_t.unsqueeze(-2)  # (batch, 1, z_dim)

        # Build the identity on the same device/dtype as the activations so the
        # module also runs on CPU; broadcasting adds it to every batch element.
        identity = torch.eye(self.z_dim, dtype=h_t.dtype, device=h_t.device)
        A_t = identity + torch.bmm(v_col, r_row)

        mu_t = q_z_t.mean

        mean = A_t.bmm(mu_t.unsqueeze(-1)).squeeze(-1) + B_t.bmm(u_t.unsqueeze(-1)).squeeze(-1) + o_t

        return mean, NormalDistribution(mean, logvar=q_z_t.logvar, v=v_t, r=r_t, A=A_t)
85 |
class PlanarEncoder(Encoder):
    """Fully-connected encoder registered under the 'planar' configuration.

    Three Linear -> BatchNorm1d -> ReLU stages of width 150, followed by a
    linear layer with ``2 * z_dim`` outputs.
    """

    def __init__(self, obs_dim = 1600, z_dim = 2):
        hidden = 150
        layers = []
        in_features = obs_dim
        for _ in range(3):
            layers.extend([
                nn.Linear(in_features, hidden),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
            ])
            in_features = hidden
        layers.append(nn.Linear(hidden, z_dim * 2))

        super().__init__(nn.Sequential(*layers), obs_dim, z_dim)
104 |
class PlanarDecoder(Decoder):
    """Fully-connected decoder registered under the 'planar' configuration.

    Two Linear -> BatchNorm1d -> ReLU stages of width 200, followed by a
    linear layer with ``obs_dim`` outputs and a sigmoid, so every output lies
    in (0, 1).
    """

    def __init__(self, z_dim = 2, obs_dim = 1600):
        net = nn.Sequential(
            nn.Linear(z_dim, 200),
            nn.BatchNorm1d(200),
            nn.ReLU(),

            nn.Linear(200, 200),
            nn.BatchNorm1d(200),
            nn.ReLU(),

            # The output width follows obs_dim instead of a hard-coded 1600, so a
            # non-default observation size produces a matching reconstruction.
            nn.Linear(200, obs_dim),
            nn.Sigmoid()
        )
        super(PlanarDecoder, self).__init__(net, z_dim, obs_dim)
120 |
class PlanarTransition(Transition):
    """Transition trunk registered under the 'planar' configuration.

    Two Linear -> BatchNorm1d -> ReLU stages of width 100; the base class adds
    the heads on top of it.
    """

    def __init__(self, z_dim = 2, u_dim = 2):
        hidden = 100
        layers = []
        in_features = z_dim
        for _ in range(2):
            layers.extend([
                nn.Linear(in_features, hidden),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
            ])
            in_features = hidden

        super().__init__(nn.Sequential(*layers), z_dim, u_dim)
133 |
class PendulumEncoder(Encoder):
    """Fully-connected encoder registered under the 'pendulum' configuration.

    Two Linear -> BatchNorm1d -> ReLU stages of width 800, followed by a
    linear layer with ``2 * z_dim`` outputs.
    """

    def __init__(self, obs_dim = 4608, z_dim = 3):
        hidden = 800
        layers = []
        in_features = obs_dim
        for _ in range(2):
            layers.extend([
                nn.Linear(in_features, hidden),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
            ])
            in_features = hidden
        layers.append(nn.Linear(hidden, z_dim * 2))

        super().__init__(nn.Sequential(*layers), obs_dim, z_dim)
148 |
class PendulumDecoder(Decoder):
    """Fully-connected decoder registered under the 'pendulum' configuration.

    Two Linear -> BatchNorm1d -> ReLU stages of width 800, followed by a
    linear layer with ``obs_dim`` outputs and a sigmoid.
    """

    def __init__(self, z_dim = 3, obs_dim = 4608):
        net = nn.Sequential(
            nn.Linear(z_dim, 800),
            nn.BatchNorm1d(800),
            nn.ReLU(),

            nn.Linear(800, 800),
            nn.BatchNorm1d(800),
            nn.ReLU(),

            nn.Linear(800, obs_dim),
            # The training loss takes log(1e-5 + x_recon) and
            # log(1e-5 + 1 - x_recon), which is only defined for outputs in
            # [0, 1]; the sigmoid (also used by PlanarDecoder) guarantees that.
            # It has no parameters, so saved state dicts keep the same keys.
            nn.Sigmoid()
        )
        super(PendulumDecoder, self).__init__(net, z_dim, obs_dim)
163 |
class PendulumTransition(Transition):
    """Transition trunk registered under the 'pendulum' configuration.

    Two Linear -> BatchNorm1d -> ReLU stages of width 100; the base class adds
    the heads on top of it.
    """

    def __init__(self, z_dim = 3, u_dim = 1):
        hidden = 100
        layers = []
        in_features = z_dim
        for _ in range(2):
            layers.extend([
                nn.Linear(in_features, hidden),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
            ])
            in_features = hidden

        super().__init__(nn.Sequential(*layers), z_dim, u_dim)
176 |
# Maps an environment name to its (encoder, decoder, transition) classes.
CONFIG = {
    'planar': (PlanarEncoder, PlanarDecoder, PlanarTransition),
    'pendulum': (PendulumEncoder, PendulumDecoder, PendulumTransition)
}

def load_config(name):
    """Return the (encoder, decoder, transition) classes registered for ``name``.

    :param name: environment name, one of the keys of ``CONFIG``
    :return: a 3-tuple of classes
    :raises KeyError: if ``name`` is not registered; the message lists the
                      names that are available
    """
    try:
        return CONFIG[name]
    except KeyError:
        raise KeyError(
            'unknown environment %r; available: %s' % (name, ', '.join(sorted(CONFIG)))
        ) from None
184 |
185 | __all__ = ['load_config']
186 |
187 | # enc = PendulumEncoder()
188 | # dec = PendulumDecoder()
189 | # trans = PendulumTransition()
190 | #
191 | # x = torch.randn(size=(10, 4608))
192 | # # print (x.size())
193 | # mean, logvar = enc(x)
194 | # # print (logvar.size())
195 | # x_recon = dec(mean)
196 | # # print (x_recon.size())
197 | #
198 | # q_z_t = NormalDistribution(mean, logvar)
199 | # print (q_z_t.mean.size())
200 | # print (q_z_t.cov.size())
201 | # u_t = torch.randn(size=(10, 1))
202 | # z_t_1 = trans(mean, q_z_t, u_t)
203 | # print (z_t_1[1].mean.size())
204 | # print (z_t_1[1].cov.size())
205 | #
206 | # kl = NormalDistribution.KL_divergence(z_t_1[1], q_z_t)
207 | # print (kl)
--------------------------------------------------------------------------------
/normal.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | torch.set_default_dtype(torch.float64)
4 |
class NormalDistribution:
    r"""A batch of Gaussians with a diagonal or rank-one-perturbed covariance.

    With ``Sigma = diag(exp(logvar))`` the covariance is ``Sigma`` when ``A`` is
    not given, and ``A Sigma A^T`` otherwise.
    """

    def __init__(self, mean, logvar, v=None, r=None, A=None):
        r"""
        :param mean: mu in the paper, shape (batch, k)
        :param logvar: log of the diagonal of \Sigma in the paper, shape (batch, k)
        :param v: stored as-is; read by ``KL_divergence`` from its first argument
        :param r: stored as-is; read by ``KL_divergence`` from its first argument
        :param A: optional (batch, k, k) matrix. If A is not None the covariance
                  matrix is A \Sigma A^T (``KL_divergence`` assumes
                  A = I + v r^T); otherwise it is simply diag(logvar.exp())
        """
        self.mean = mean
        self.logvar = logvar
        self.v = v
        self.r = r

        sigma = torch.diag_embed(torch.exp(logvar))
        if A is None:
            self.cov = sigma
        else:
            self.cov = A.bmm(sigma.bmm(A.transpose(1, 2)))

    @staticmethod
    def KL_divergence(q_z_next_pred, q_z_next):
        r"""Batch-averaged KL(q_z_next_pred || q_z_next).

        The first argument has covariance A Sigma_0 A^T with A = I + v r^T and
        must carry ``v`` and ``r``; the second argument is treated as diagonal
        (only its ``mean`` and ``logvar`` are read). For k dimensions::

            KL = 0.5 * [ tr(Sigma_1^-1 A Sigma_0 A^T)
                         + (mu_1 - mu_0)^T Sigma_1^-1 (mu_1 - mu_0) - k
                         + log|Sigma_1| - log|Sigma_0| - 2 log(1 + v^T r) ]

        :param q_z_next_pred: q(z_{t+1} | z_bar_t, q_z_t, u_t) using the transition
        :param q_z_next: q(z_t+1 | x_t+1) using the encoder
        :return: scalar tensor, the KL divergence averaged over the batch
        """
        mu_0 = q_z_next_pred.mean
        mu_1 = q_z_next.mean
        sigma_0 = torch.exp(q_z_next_pred.logvar)
        sigma_1 = torch.exp(q_z_next.logvar)
        v = q_z_next_pred.v
        r = q_z_next_pred.r
        k = float(q_z_next_pred.mean.size(1))

        def row_sum(t):
            # Sum over the latent dimension, keeping one value per batch element.
            return torch.sum(t, dim=1)

        # tr(Sigma_1^-1 A Sigma_0 A^T) expanded for A = I + v r^T.
        trace_term = (row_sum((sigma_0 + 2 * sigma_0 * v * r) / sigma_1)
                      + row_sum(r.pow(2) * sigma_0) * row_sum(v.pow(2) / sigma_1))
        maha_term = row_sum(torch.pow(mu_1 - mu_0, 2) / sigma_1)
        # The summed log variances already equal log|Sigma|, so they enter with
        # factor 1; only log det(A) = log(1 + v^T r) is doubled, because
        # |A Sigma_0 A^T| = det(A)^2 |Sigma_0|.
        log_det_term = (row_sum(q_z_next.logvar - q_z_next_pred.logvar)
                        - 2 * torch.log(1 + row_sum(v * r)))

        return 0.5 * torch.mean(trace_term + maha_term - k + log_det_term)
--------------------------------------------------------------------------------
/test_plot.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 34,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from tensorboardX import SummaryWriter\n",
10 | "import matplotlib.pyplot as plt\n",
11 | "import numpy as np"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 35,
17 | "metadata": {},
18 | "outputs": [
19 | {
20 | "data": {
21 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAEYCAYAAAAJeGK1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAJlklEQVR4nO3dX6jX9R3H8fentbmVWi0PYxB6sChOzIrwokmUF9qaK7qSoG1QdLOBV+4PMXY5dBd1402W7CKaN8Egk0Ujajf748URt0I2Q3Qqupk6kjVpBn12kQMRZ+r5xHkdezxAkK/Hl58fh8OT7++o39Z7LwBIc9VsHwAAzkegAIgkUABEEigAIgkUAJGunu0DAJdu0aJFfXJycraPAUPs3LnzeO994tzrAgVz0OTkZE1PT8/2MWCI1tqB8133Fh8AkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEEigAIgkUAJEECoBIAgVAJIECIJJAARBJoACIJFAARBIoACIJFACRBAqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEEigAIgkUAJEECoBIAgVAJIECIJJAARBJoACIJFAARBIoACJdPdsHuBQ33nhjX7x48Yx3PvjggwGnqTp16tSQnQ8//HDIzsKFC4fszJ8/f8jOvn37huxMTEwM2Tl+/PiQnUWLFg3Zeeedd4733se8OLgCzalALV68uN58880Z7+zZs2fAaareeuutITuHDx8esrN69eohO/fee++QnbVr1w7ZWbdu3ZCd5557bsjOk08+OWRn1apVB4YMwRXKW3wARBIoACIJFACRBAqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEEmgAIg0px5YePLkyXrttddmvPPYY48NOE3Vhg0bhuzceuutQ3befvvtITuvv/76kJ3169cP2RnxkMqqqmXLlg3Zue+++4bsABfmDgqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEEigAIgkUAJEECoBIc+qJujfccEOtXbt2yM4IJ0+eHLLz8ssvD9lZunTpkJ158+YN2dm7d++QnRUrVgzZOXTo0JCdHTt2DNkBLswdFACRBAqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEEigAIgkUAJHm1BN1jx49Ws8888yMd1auXDnzwwy0fv36ITtTU1NDdp5//vkhO9dcc82QnXfffXfIzqgnDm/fvn3IDnBh7qAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEEigAIgkUAJEECoBIAgVAJIECIJJAARBJoACINKeeqHv69Ok6cODAjHd27do14DRVK1asGLKzZMmSITsvvPDCkJ077rhjyM7ChQuH7GzZsmXIzqjXNWoHuDB3UABEEigAIgkUAJEECoBIAgVAJIECIJJAARBJoACIJFAARBIoACIJFACRBAqASBcVqNba31prq2byB7XWHm+t/W4mGwB8driDAiDSJwaqtfZiVS2uqu2ttfdbaz9urd3TWvtDa+291tqfW2srz/r4x1tr+1pr/2qt7W+tfbu1NlVVm6vq62c23vvUXhEAV4RPDFTv/btVdbCqHu69z6+qrVX166r6WVV9uap+WFW/aq1NtNaurapNVfXN3vuCqlpRVX/qvf+lqr5XVX/svc/vvV//6bwcAK4Ul/MW33eq6tXe+6u99496769X1XRVrTnz6x9V1ddaa1/qvf+997571GEB+Oy4nCfqLqmqta21h8+69vmq+m3v/d+ttUfr47uqX7TWfl9VP+i9/3XAWeu6666rNWvWfPIHfoITJ04MOE3V1NTUkJ0HH
nhgyM7mzZuH7GzatGnIzsaNG4fsPPvss0N29uzZM2TnqaeeGrIDXNjF3kH1s35+qKpe7L1ff9aPa3vvP6+q6r3/pve+uqq+WlV/raot59kAgAu62EAdraqlZ37+y6p6uLX2jdba51prX2ytrWyt3dRa+0pr7ZEz34v6T1W9Xx+/5fe/jZtaa18Y+goAuCJdbKA2VtVPz/ztu0er6pGq+klVHauP76h+dGbrqqpaX1VHquqfVXV/VX3/zMabVbW7qv7RWjs+6gUAcGW6qO9B9d63VdW2cy7f/38+/LzXe++nq+pbF380AD7L/ENdACIJFACRBAqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgUut97jxHcNmyZX3btnP/U/VLt3Xr1gGnqbr55puH7Ozdu3fIzoIFC4bsPPTQQ0N2XnrppSE7R44cGbIz6vO1atWqITt33nnnzt778sv5vcuXL+/T09NDzgGzrbV23q8Fd1AARBIoACIJFACRBAqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEunq2DzAbnnjiiSE7r7zyypCd/fv3D9l58MEHh+wcPnx4yM4tt9wyZOeNN94YsjPqybwnTpwYsgNcmDsoACIJFACRBAqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEEigAIs2pJ+qeOnWqdu3aNeOdgwcPDjjNuCfYjnhNVVU7duwYsrNu3bohO8eOHRuyc9tttw3Zuf3224fsTE1NDdnZsGHDkB24UrmDAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEEigAIgkUAJEECoBIAgVAJIECINKceqLuvHnzanJycsY7p0+fnvlhqmr37t1Ddu66664hO3ffffeQnaeffnrIzj333DNkZ2JiYsjO6tWrh+yM+rwDF+YOCoBIAgVAJIECIJJAARBJoACIJFAARBIoACIJFACRBAqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEit9z7bZ7horbVjVXVgts8BgyzpvV/W44KXL1/ep6enR58HZkVrbWfvffm51+fUI98v94sZgLnHW3wARBIoACIJFACRBAqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEEmgAIgkUABEEigAIgkUAJEECoBIAgVAJIECIJJAARBJoACIJFAARBIoACIJFACRBAqASAIFQCSBAiCSQAEQSaAAiCRQAEQSKAAiCRQAkQQKgEgCBUAkgQIgkkABEKn13mf7DMAlaq0dq6oDs30OGGRJ733i3IsCBUAkb/EBEEmgAIgkUABEEigAIgkUAJEECoBIAgVAJIECIJJAARDpv7rGe11cTsUVAAAAAElFTkSuQmCC\n",
22 | "text/plain": [
23 | ""
24 | ]
25 | },
26 | "metadata": {},
27 | "output_type": "display_data"
28 | }
29 | ],
30 | "source": [
31 | "writer = SummaryWriter('test')\n",
32 | "fig, ax = plt.subplots(nrows=1, ncols=2)\n",
33 | "plt.setp(ax, xticks=[], yticks=[])\n",
34 | "# ax[0].set_ylabel('test', rotation=0,size='large')\n",
35 | "pad = 5\n",
36 | "ax[0].annotate('test', xy=(0, 0.5), xytext=(-ax[0].yaxis.labelpad - pad, 0),\n",
37 | " xycoords=ax[0].yaxis.label, textcoords='offset points',\n",
38 | " size='large', ha='right', va='center')\n",
39 | "ax[0].imshow(np.random.randn(10,10), cmap='Greys')\n",
40 | "fig.tight_layout()"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 36,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "writer.add_figure('test', fig, 0)"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "metadata": {},
56 | "outputs": [],
57 | "source": []
58 | }
59 | ],
60 | "metadata": {
61 | "kernelspec": {
62 | "display_name": "Python 3",
63 | "language": "python",
64 | "name": "python3"
65 | },
66 | "language_info": {
67 | "codemirror_mode": {
68 | "name": "ipython",
69 | "version": 3
70 | },
71 | "file_extension": ".py",
72 | "mimetype": "text/x-python",
73 | "name": "python",
74 | "nbconvert_exporter": "python",
75 | "pygments_lexer": "ipython3",
76 | "version": "3.6.8"
77 | }
78 | },
79 | "nbformat": 4,
80 | "nbformat_minor": 2
81 | }
82 |
--------------------------------------------------------------------------------
/train_e2c.py:
--------------------------------------------------------------------------------
1 | from tensorboardX import SummaryWriter
2 | import torch.optim as optim
3 | from torch.utils.data import DataLoader
4 | import matplotlib.pyplot as plt
5 | import argparse
6 | import sys
7 |
8 | from normal import *
9 | from e2c_model import E2C
10 | from datasets import *
11 | import data.sample_planar as planar_sampler
12 | import data.sample_pendulum_data as pendulum_sampler
13 | import data.sample_cartpole_data as cartpole_sampler
14 |
# All tensors created without an explicit dtype are float64.
torch.set_default_dtype(torch.float64)

# Use the GPU when one is present, otherwise fall back to the CPU instead of
# failing at the first .to(device) call.
# NOTE(review): other modules may still assume CUDA - confirm before relying on CPU runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Lookup tables keyed by environment name (the --env command-line value).
datasets = {'planar': PlanarDataset, 'pendulum': GymPendulumDatasetV2}
# (obs_dim, z_dim, u_dim) for each environment, as unpacked in main().
settings = {'planar': (1600, 2, 2), 'pendulum': (4608, 3, 1)}
# Sampler modules used to draw fresh transitions for the TensorBoard figure.
samplers = {'planar': planar_sampler, 'pendulum': pendulum_sampler, 'cartpole': cartpole_sampler}
num_eval = 10 # number of images evaluated on tensorboard
22 |
23 | # dataset = datasets['planar']('./data/data/' + 'planar')
24 | # x, u, x_next = dataset[0]
25 | # imgplot = plt.imshow(x.squeeze(), cmap='gray')
26 | # plt.show()
27 | # print (np.array(u, dtype=float))
28 | # imgplot = plt.imshow(x_next.squeeze(), cmap='gray')
29 | # plt.show()
30 |
def compute_loss(x, x_next, q_z_next, x_recon, x_next_pred, q_z, q_z_next_pred, lamda):
    """Return the E2C training loss: lower-bound loss + lamda * consistency loss.

    :param x: current observations
    :param x_next: next observations
    :param q_z_next: encoder distribution of the next observation
    :param x_recon: reconstruction of ``x``
    :param x_next_pred: prediction of ``x_next``
    :param q_z: encoder distribution of ``x`` (``mean`` and ``logvar`` are read)
    :param q_z_next_pred: transition distribution of the next latent state
    :param lamda: weight of the consistency term
    :return: scalar loss tensor
    """
    def _bernoulli_nll(target, prob):
        # Negative Bernoulli log-likelihood, summed over dim 1 and averaged over
        # the batch; 1e-5 keeps the logarithms finite.
        log_lik = target * torch.log(1e-5 + prob) + (1 - target) * torch.log(1e-5 + 1 - prob)
        return -torch.mean(torch.sum(log_lik, dim=1))

    # lower-bound loss
    recon_term = _bernoulli_nll(x, x_recon)
    pred_loss = _bernoulli_nll(x_next, x_next_pred)

    # KL(q_z || N(0, I)) for a diagonal Gaussian, averaged over the batch.
    kl_per_sample = torch.sum(1 + q_z.logvar - q_z.mean.pow(2) - q_z.logvar.exp(), dim = 1)
    kl_term = - 0.5 * torch.mean(kl_per_sample)

    # consistency loss
    consis_term = NormalDistribution.KL_divergence(q_z_next_pred, q_z_next)

    return recon_term + pred_loss + kl_term + lamda * consis_term
45 |
def train(model, train_loader, lam, optimizer):
    """Run one optimisation pass over ``train_loader``.

    :param model: the E2C model; put into training mode here
    :param train_loader: iterable of (x, u, x_next) batches that supports ``len``
    :param lam: weight of the consistency term, forwarded to ``compute_loss``
    :param optimizer: optimizer stepping the model parameters
    :return: the loss averaged over the number of batches
    """
    model.train()
    num_batches = len(train_loader)

    total_loss = 0.0
    for x, u, x_next in train_loader:
        # Flatten the observations and move the whole batch to the device.
        x = x.view(-1, model.obs_dim).double().to(device)
        u = u.double().to(device)
        x_next = x_next.view(-1, model.obs_dim).double().to(device)

        optimizer.zero_grad()
        x_recon, x_next_pred, q_z, q_z_next_pred, q_z_next = model(x, u, x_next)
        loss = compute_loss(x, x_next, q_z_next, x_recon, x_next_pred, q_z, q_z_next_pred, lam)
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    return total_loss / num_batches
66 |
def compute_log_likelihood(x, x_recon, x_next, x_next_pred):
    """Return the negative Bernoulli log-likelihoods of two reconstructions.

    Each value is summed over dim 1 and averaged over the batch; 1e-5 is added
    inside the logarithms to keep them finite.

    :param x: current observations
    :param x_recon: reconstruction of ``x``
    :param x_next: next observations
    :param x_next_pred: prediction of ``x_next``
    :return: (loss for ``x`` vs ``x_recon``, loss for ``x_next`` vs ``x_next_pred``)
    """
    def _bernoulli_nll(target, prob):
        log_lik = target * torch.log(1e-5 + prob) + (1 - target) * torch.log(1e-5 + 1 - prob)
        return -torch.mean(torch.sum(log_lik, dim=1))

    return _bernoulli_nll(x, x_recon), _bernoulli_nll(x_next, x_next_pred)
73 |
def evaluate(model, test_loader):
    """Average the reconstruction and prediction losses over ``test_loader``.

    :param model: the E2C model; put into evaluation mode here
    :param test_loader: iterable of (x, u, x_next) batches that supports ``len``
    :return: (state loss, next-state loss) as Python floats, each the mean of
             the per-batch values from ``compute_log_likelihood``
    :raises ValueError: if ``test_loader`` yields no batches
    """
    model.eval()
    num_batches = len(test_loader)
    if num_batches == 0:
        # Fail with a clear message instead of calling .item() on a float below.
        raise ValueError('test_loader is empty; nothing to evaluate')

    state_loss, next_state_loss = 0., 0.
    with torch.no_grad():
        for x, u, x_next in test_loader:
            x = x.view(-1, model.obs_dim).double().to(device)
            u = u.double().to(device)
            x_next = x_next.view(-1, model.obs_dim).double().to(device)

            x_recon, x_next_pred, q_z, q_z_next_pred, q_z_next = model(x, u, x_next)
            loss_1, loss_2 = compute_log_likelihood(x, x_recon, x_next, x_next_pred)
            # Accumulate plain floats so the running totals are never tensors.
            state_loss += loss_1.item()
            next_state_loss += loss_2.item()

    return state_loss / num_batches, next_state_loss / num_batches
90 |
91 | # code for visualizing the training process
def predict_x_next(model, env, num_eval):
    """Sample transitions from an environment and predict each next observation.

    :param model: model exposing ``predict(x, u)``
    :param env: key into ``samplers``
    :param num_eval: number of transitions requested from the sampler
    :return: (true next observations, predicted next observations); the
             predictions are numpy arrays reshaped to
             (sampler.width, sampler.height)
    """
    # first sample a true trajectory from the environment
    sampler = samplers[env]
    _, sampled_data = sampler.sample(num_eval)

    true_x_next = [sample[-1] for sample in sampled_data]

    # use the trained model to predict the next observation
    predicted = []
    for x, u, _ in sampled_data:
        # Flatten the observation and add a batch dimension of one.
        x_flat = torch.from_numpy(x.reshape(-1)).double().unsqueeze(dim=0).to(device)
        u_batch = torch.from_numpy(u).double().unsqueeze(dim=0).to(device)
        with torch.no_grad():
            x_next_pred = model.predict(x_flat, u_batch)
        image = x_next_pred.squeeze().cpu().numpy().reshape(sampler.width, sampler.height)
        predicted.append(image)

    return true_x_next, predicted
108 |
def plot_preds(model, env, num_eval):
    """Plot true next observations (top row) against predictions (bottom row).

    :param model: model forwarded to ``predict_x_next``
    :param env: environment key forwarded to ``predict_x_next``
    :param num_eval: number of columns, one per sampled transition
    :return: the matplotlib figure
    """
    true_x_next, pred_x_next = predict_x_next(model, env, num_eval)

    # plot the predicted and true observations
    # squeeze=False keeps ``axes`` two-dimensional even when num_eval == 1, so
    # the [row, column] indexing below always works.
    fig, axes = plt.subplots(nrows=2, ncols=num_eval, squeeze=False)
    plt.setp(axes, xticks=[], yticks=[])
    pad = 5
    # Row labels are drawn to the left of the first column.
    axes[0, 0].annotate('True observations', xy=(0, 0.5), xytext=(-axes[0,0].yaxis.labelpad - pad, 0),
                   xycoords=axes[0,0].yaxis.label, textcoords='offset points',
                   size='large', ha='right', va='center')
    axes[1, 0].annotate('Predicted observations', xy=(0, 0.5), xytext=(-axes[1, 0].yaxis.labelpad - pad, 0),
                        xycoords=axes[1, 0].yaxis.label, textcoords='offset points',
                        size='large', ha='right', va='center')

    for idx in range(num_eval):
        axes[0, idx].imshow(true_x_next[idx], cmap='Greys')
        axes[1, idx].imshow(pred_x_next[idx], cmap='Greys')
    fig.tight_layout()
    return fig
128 |
def main(args):
    """Train an E2C model and periodically save checkpoints and figures.

    Reads the data set from ``./data/data/<env>``, writes TensorBoard logs to
    ``logs/<env>/<log_dir>`` and checkpoints, losses and the run settings to
    ``./result/<env>/<log_dir>``.

    :param args: parsed command-line arguments (env, propor, batch_size, lr,
                 decay, lam, num_iter, iter_save, log_dir, seed)
    :raises ValueError: if ``args.env`` is not 'planar' or 'pendulum'
    """
    env_name = args.env
    # Explicit check instead of assert, which is stripped under ``python -O``.
    if env_name not in ('planar', 'pendulum'):
        raise ValueError("unsupported environment %r; expected 'planar' or 'pendulum'" % (env_name,))
    propor = args.propor
    batch_size = args.batch_size
    lr = args.lr
    weight_decay = args.decay
    lam = args.lam
    epoches = args.num_iter
    iter_save = args.iter_save
    log_dir = args.log_dir
    seed = args.seed

    np.random.seed(seed)
    torch.manual_seed(seed)

    # The first ``propor`` fraction of the data is used for training, the rest for testing.
    dataset = datasets[env_name]('./data/data/' + env_name)
    split = int(len(dataset) * propor)
    train_set, test_set = dataset[:split], dataset[split:]
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=8)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=8)

    obs_dim, z_dim, u_dim = settings[env_name]
    model = E2C(obs_dim=obs_dim, z_dim=z_dim, u_dim=u_dim, env=env_name).to(device)

    optimizer = optim.Adam(model.parameters(), betas=(0.9, 0.999), eps=1e-8, lr=lr, weight_decay=weight_decay)

    writer = SummaryWriter('logs/' + env_name + '/' + log_dir)
    # try/finally so the writer is closed even if training is interrupted.
    try:
        result_path = './result/' + env_name + '/' + log_dir
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(result_path, exist_ok=True)
        with open(result_path + '/settings', 'w') as f:
            json.dump(args.__dict__, f, indent=2)

        for i in range(epoches):
            avg_loss = train(model, train_loader, lam, optimizer)
            print('Epoch %d' % i)
            print("Training loss: %f" % (avg_loss))
            # evaluate on test set
            state_loss, next_state_loss = evaluate(model, test_loader)
            print('State loss: ' + str(state_loss))
            print('Next state loss: ' + str(next_state_loss))

            # ...log the running loss
            writer.add_scalar('training loss', avg_loss, i)
            writer.add_scalar('state loss', state_loss, i)
            writer.add_scalar('next state loss', next_state_loss, i)

            # save model
            if (i + 1) % iter_save == 0:
                writer.add_figure('actual vs. predicted observations',
                                  plot_preds(model, env_name, num_eval),
                                  global_step=i)
                print('Saving the model.............')

                torch.save(model.state_dict(), result_path + '/model_' + str(i + 1))
                with open(result_path + '/loss_' + str(i + 1), 'w') as f:
                    f.write('\n'.join([str(state_loss), str(next_state_loss)]))
    finally:
        writer.close()
189 |
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser(description='train e2c model')

    # the default value is used for the planar task
    argument_specs = [
        ('--env', dict(required=True, type=str, help='the environment used for training')),
        ('--propor', dict(default=3/4, type=float, help='the proportion of data used for training')),
        ('--batch_size', dict(default=128, type=int, help='batch size')),
        ('--lr', dict(default=0.0005, type=float, help='the learning rate')),
        ('--decay', dict(default=0.001, type=float, help='the L2 regularization')),
        ('--lam', dict(default=0.25, type=float, help='the weight of the consistency term')),
        ('--num_iter', dict(default=5000, type=int, help='the number of epoches')),
        ('--iter_save', dict(default=1000, type=int, help='save model and result after this number of iterations')),
        ('--log_dir', dict(required=True, type=str, help='the directory to save training log')),
        ('--seed', dict(required=True, type=int, help='seed number')),
    ]
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    main(args)
208 |
--------------------------------------------------------------------------------