├── .gitignore
├── LICENSE
├── README.md
├── assets
└── example.svg
└── gridworld_vis
├── __init__.py
├── gridworld.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Marcell Vazquez-Chanlatte
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # gridworld-visualizer
2 |
3 |
6 |
This is a small library for visualizing gridworlds by generating SVGs
styled and animated with CSS. The API of `gridworld-visualizer` centers
around the `gridworld` function.
10 | ```python
11 | gridworld(n=10, actions=None, tile2classes=None, extra_css="") -> SVG
12 | ```
which takes the dimension of the gridworld (currently assumed to
be a square `n x n`), the sequence of actions (currently supporting moves
in the four cardinal directions), and a function
16 | ```python
17 | tile2classes(x: int, y: int) -> str
18 | ```
which, given a grid cell `(x, y)`, returns a string naming its
type; `(0, 0)` is the bottom-left cell and `y` increases upward.
Currently, the default styling supports "water", "recharge", "dry",
"lava", and "normal". As these types correspond directly to CSS
classes, additional styling can be added through the `extra_css` option.
23 |
24 | ## Example
Below we generate the gridworld shown at the top of the page (originally from [Vazquez-Chanlatte, Marcell, et al. "Learning Task Specifications from Demonstrations."](https://arxiv.org/abs/1710.03875)).
26 | ```python
import gridworld_vis as gv
28 |
29 |
30 | def tile2classes(x, y):
31 | if (3 <= x <= 4) and (2 <= y <= 5):
32 | return "water"
33 | elif (x in (0, 7)) and (y in (0, 7)):
34 | return "recharge"
35 | elif (2 <= x <= 5) and y in (0, 7):
36 | return "dry"
37 | elif x in (1, 6) and (y in (4, 5) or y <= 1):
38 | return "lava"
39 | elif (x in (0, 7)) and (y in (1, 4, 5)):
40 | return "lava"
41 |
42 | return "normal"
43 |
44 | actions = [gv.E, gv.N, gv.N, gv.N, gv.N, gv.W, gv.W, gv.W]
45 | svg = gv.gridworld(n=8, tile2classes=tile2classes, actions=actions)
46 | svg.saveas("example.svg", pretty=True)
47 | ```
48 |
--------------------------------------------------------------------------------
/assets/example.svg:
--------------------------------------------------------------------------------
1 |
2 |
126 |
--------------------------------------------------------------------------------
/gridworld_vis/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mvcisback/gridworld-visualizer/a480b1718895a741ab1c6e60c0ba5c98661de903/gridworld_vis/__init__.py
--------------------------------------------------------------------------------
/gridworld_vis/gridworld.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 |
3 | import svgwrite
4 |
5 |
# Default board size (width, height).
# NOTE(review): not referenced anywhere in this module -- confirm whether
# external code still uses it before removing.
BOARD_SIZE = ("200", "200")

# Stylesheet embedded in every drawing produced by ``gridworld``; any
# ``extra_css`` is appended after it. The tile-type classes (lava, dry, water,
# recharge, normal) set each tile's fill colour.
#
# The agent's two animations are declared in a single comma-separated
# ``animation`` property: a second ``animation:`` declaration in the same rule
# would override the first and leave the ``blinker`` keyframes unused.
#
# NOTE(review): ``r: 10%`` may override the agent circle's ``r`` attribute in
# renderers that support the SVG 2 ``r`` CSS property -- confirm the intended
# radius.
CSS_STYLES = """
.background { fill: white; }
.line { stroke: firebrick; stroke-width: .1mm; }
.lava { fill: #ff8b8b; }
.dry { fill: #f4a460; }
.water { fill: #afafff; }
.recharge {fill: #ffff00; }
.normal {fill: white; }
rect {
    stroke: black;
    stroke-width: 1;
}

.agent {
    r: 10%;
    fill: black;
    stroke-width: 2;
    stroke: grey;
    animation: blinker 4s linear infinite, move 5s ease forwards;
}
@keyframes blinker {
    50% {
        opacity: 0.5;
    }
}
"""
34 |
35 |
def draw_board(n=3, tile2classes=None):
    """Draw an ``n x n`` board of square tiles as an svgwrite Drawing.

    A full-size rectangle with the ``background`` class is drawn first.
    Each tile is then drawn as a 0.9cm square whose top-left corner is offset
    0.1cm inside its 1cm grid cell.

    Args:
        n: Number of tiles along each side of the square board.
        tile2classes: Optional callable ``(x, y) -> str`` returning the CSS
            class(es) for grid cell ``(x, y)``, where ``y`` counts up from the
            bottom row. A falsy return leaves that tile without a class.

    Returns:
        The ``svgwrite.Drawing`` holding the background and the tiles.
    """
    side = f"{n + 0.05}cm"
    dwg = svgwrite.Drawing(size=(side, side))

    dwg.add(dwg.rect(size=('100%', '100%'), class_='background'))

    for x, y in product(range(n), range(n)):
        kwargs = {
            'insert': (f"{x + 0.1}cm", f"{y + 0.1}cm"),
            'size': ("0.9cm", "0.9cm"),
        }
        if tile2classes is not None:
            # SVG y grows downward, so flip the row index: tile2classes sees
            # y == 0 as the bottom row. Query the same (flipped) cell that is
            # both tested and applied, and call the user callback only once.
            classes = tile2classes(x, n - y - 1)
            if classes:
                kwargs["class_"] = classes

        dwg.add(dwg.rect(**kwargs))

    return dwg
56 |
57 |
# Unit moves for the four cardinal directions, as (dx, dy) in SVG coordinates
# where y grows downward -- hence moving North decreases y.
N, S, W, E = (0, -1), (0, 1), (-1, 0), (1, 0)
62 |
63 |
def gen_offsets(actions):
    """Lazily yield the cumulative (dx, dy) displacement after each action.

    Args:
        actions: Iterable of two-element (dx, dy) steps.

    Yields:
        A tuple ``(dx, dy)`` holding the running sum of all steps so far, one
        per action; nothing is yielded for an empty ``actions``.
    """
    total_x = total_y = 0
    for step_x, step_y in actions:
        total_x, total_y = total_x + step_x, total_y + step_y
        yield total_x, total_y
70 |
71 |
def move_keyframe(dx, dy, ratio):
    """Return one CSS keyframe that translates the element by (dx, dy) cm.

    Args:
        dx: Horizontal offset in centimetres.
        dy: Vertical offset in centimetres (SVG y grows downward).
        ratio: Position of the keyframe in the animation; it is multiplied by
            100 to form the percentage selector.

    Returns:
        A keyframe block such as ``"50.0% { transform: ... }"`` spread over
        three lines.
    """
    percent = ratio * 100
    body = f"    transform: translate({dx}cm, {dy}cm);"
    return f"{percent}% {{\n{body}\n}}"
76 |
77 |
def gridworld(n=10, actions=None, tile2classes=None, extra_css="",
              start=(2, 3)):
    """Render an ``n x n`` gridworld, optionally with an animated agent.

    Args:
        n: Number of tiles along each side of the square board.
        actions: Optional iterable of ``(dx, dy)`` steps (e.g. ``N``, ``S``,
            ``E``, ``W``). When given, an agent circle is drawn at ``start``
            and a ``move`` CSS keyframe animation replays the steps, with one
            keyframe per action at evenly spaced fractions of the animation.
            ``None`` draws no agent and no ``move`` keyframes.
        tile2classes: Optional callable ``(x, y) -> str`` giving each tile's
            CSS class; forwarded to ``draw_board``.
        extra_css: Additional CSS appended after the built-in styles.
        start: The agent's starting cell ``(x, y)``, with ``y`` counted up
            from the bottom row. Defaults to ``(2, 3)``, the position that was
            previously hard-coded.

    Returns:
        The ``svgwrite.Drawing`` containing the board, the agent (if any) and
        the stylesheet.
    """
    dwg = draw_board(n=n, tile2classes=tile2classes)

    css_styles = CSS_STYLES
    if actions is not None:
        # Materialize once so len() works and generators/iterators are
        # accepted (gen_offsets consumes the sequence).
        actions = list(actions)

        # Add the agent centred in its start tile: tiles start 0.1cm into
        # their cell and are 0.9cm wide, so the centre is at +0.55cm. The row
        # is flipped because SVG y grows downward.
        start_x, start_y = start
        cx, cy = start_x + 0.55, (n - start_y - 1) + 0.55
        dwg.add(svgwrite.shapes.Circle(
            r="0.3cm",
            center=(f"{cx}cm", f"{cy}cm"),
            class_="agent",
        ))

        # Offsets are cumulative from the start cell because CSS translate is
        # relative to the element's original position.
        total = len(actions)
        keyframes = [
            move_keyframe(dx, dy, (i + 1) / total)
            for i, (dx, dy) in enumerate(gen_offsets(actions))
        ]
        css_styles += "\n@keyframes move {\n" + "\n".join(keyframes) + "\n}"

    dwg.defs.add(dwg.style(css_styles + extra_css))
    return dwg
100 |
101 |
if __name__ == '__main__':
    def tile2classes(x, y):
        """Classify cell (x, y) of the 8x8 demo board (y counts from the bottom)."""
        edge_column = x in (0, 7)
        if 3 <= x <= 4 and 2 <= y <= 5:
            return "water"
        if edge_column and y in (0, 7):
            return "recharge"
        if 2 <= x <= 5 and y in (0, 7):
            return "dry"
        inner_lava = x in (1, 6) and (y in (4, 5) or y <= 1)
        edge_lava = edge_column and y in (1, 4, 5)
        if inner_lava or edge_lava:
            return "lava"
        return "normal"

    # Demo path: one step east, four north, then three west.
    demo_actions = [E] + [N] * 4 + [W] * 3
    drawing = gridworld(n=8, tile2classes=tile2classes, actions=demo_actions)
    drawing.saveas("example.svg", pretty=True)
120 |
--------------------------------------------------------------------------------
/gridworld_vis/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
DESC = 'Gridworld visualization library.'

# NOTE(review): find_packages() searches the current working directory by
# default. This file is listed under gridworld_vis/; run from there, it would
# not find the gridworld_vis package itself -- confirm setup.py lives at the
# repository root.
setup(
    name='gridworld-viz',
    version='0.0.0',
    description=DESC,
    url='https://github.com/mvcisback/gridworld-visualizer',
    author='Marcell Vazquez-Chanlatte',
    author_email='marcell.vc@eecs.berkeley.edu',
    license='MIT',
    # gridworld.py relies on f-strings, which require Python 3.6+.
    python_requires='>=3.6',
    install_requires=[
        'svgwrite',
    ],
    packages=find_packages(),
)
18 |
--------------------------------------------------------------------------------