├── .gitignore ├── LICENSE ├── README.md ├── assets └── example.svg └── gridworld_vis ├── __init__.py ├── gridworld.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Marcell Vazquez-Chanlatte 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gridworld-visualizer 2 | 3 | 4 | 5 | 6 | 7 | This is a small library for visualizing gridworlds by generating svgs 8 | styled and animated by css. The api of `gridworld-visualizer` centers 9 | around the `gridworld` function. 
10 | ```python 11 | gridworld(n=10, actions=None, tile2classes=None, extra_css="") -> SVG 12 | ``` 13 | which takes in the dimension of the gridworld (currently assumed to 14 | be a square `n x n`), the sequence of actions (currently only moves 15 | in the cardinal directions are supported), and a function 16 | ```python 17 | tile2classes(x: int, y: int) -> str 18 | ``` 19 | which, given a grid cell `(x, y)` (with `(0, 0)` at the bottom-left corner), returns the name of its 20 | type. Currently, the default styling supports "water", "recharge", 21 | "dry", "lava", and "normal". As these types just correspond to css 22 | classes, one can add additional styling using the `extra_css` option. 23 | 24 | ## Example 25 | Below we generate the gridworld shown at the top of the page (originally from [Vazquez-Chanlatte, Marcell, et al. "Learning Task Specifications from Demonstrations."](https://arxiv.org/abs/1710.03875)). 26 | ```python 27 | import gridworld_vis as gv 28 | 29 | 30 | def tile2classes(x, y): 31 | if (3 <= x <= 4) and (2 <= y <= 5): 32 | return "water" 33 | elif (x in (0, 7)) and (y in (0, 7)): 34 | return "recharge" 35 | elif (2 <= x <= 5) and y in (0, 7): 36 | return "dry" 37 | elif x in (1, 6) and (y in (4, 5) or y <= 1): 38 | return "lava" 39 | elif (x in (0, 7)) and (y in (1, 4, 5)): 40 | return "lava" 41 | 42 | return "normal" 43 | 44 | actions = [gv.E, gv.N, gv.N, gv.N, gv.N, gv.W, gv.W, gv.W] 45 | svg = gv.gridworld(n=8, tile2classes=tile2classes, actions=actions) 46 | svg.saveas("example.svg", pretty=True) 47 | ``` 48 | -------------------------------------------------------------------------------- /assets/example.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112
# ---- /assets/example.svg (tail): lines 113-126 were not captured in this dump ----
# ---- /gridworld_vis/__init__.py: contents not captured in this dump; source at
# https://raw.githubusercontent.com/mvcisback/gridworld-visualizer/a480b1718895a741ab1c6e60c0ba5c98661de903/gridworld_vis/__init__.py

# ---- /gridworld_vis/gridworld.py ----
"""Render gridworlds as SVG drawings, styled and animated with CSS."""
from itertools import product

import svgwrite


BOARD_SIZE = ("200", "200")

# Default stylesheet. Tile classes (lava, dry, water, recharge, normal) set
# the fill of board cells. ``.agent`` styles the agent circle and runs two
# animations at once: ``blinker`` (defined below) and ``move`` (generated per
# call by ``gridworld``). Both must be listed in ONE ``animation``
# declaration -- a second ``animation`` line would override the first.
CSS_STYLES = """
.background { fill: white; }
.line { stroke: firebrick; stroke-width: .1mm; }
.lava { fill: #ff8b8b; }
.dry { fill: #f4a460; }
.water { fill: #afafff; }
.recharge {fill: #ffff00; }
.normal {fill: white; }
rect {
    stroke: black;
    stroke-width: 1;
}

.agent {
    r: 10%;
    fill: black;
    stroke-width: 2;
    stroke: grey;
    animation: blinker 4s linear infinite, move 5s ease forwards;
}
@keyframes blinker {
    50% {
        opacity: 0.5;
    }
}
"""


def draw_board(n=3, tile2classes=None):
    """Draw an ``n x n`` board of 0.9cm tiles laid out on a 1cm pitch.

    Args:
        n: Number of rows and columns of the (square) board.
        tile2classes: Optional callable ``(x, y) -> str`` returning the CSS
            class for the tile at logical coordinates ``(x, y)``, where
            ``(0, 0)`` is the bottom-left tile. A falsy result leaves the
            tile without a class.

    Returns:
        The ``svgwrite.Drawing`` holding the background and the tiles.
    """
    dwg = svgwrite.Drawing(size=(f"{n+0.05}cm", f"{n+0.05}cm"))
    dwg.add(dwg.rect(size=('100%', '100%'), class_='background'))

    for x, y in product(range(n), range(n)):
        kwargs = {
            'insert': (f"{x+0.1}cm", f"{y+0.1}cm"),
            'size': ("0.9cm", "0.9cm"),
        }
        if tile2classes is not None:
            # SVG rows grow downward; flip so logical y = 0 is the bottom
            # row. Query once and test the SAME value we assign, so the
            # check and the class always agree.
            classes = tile2classes(x, n - y - 1)
            if classes:
                kwargs["class_"] = classes

        dwg.add(dwg.rect(**kwargs))

    return dwg


# Unit action steps in SVG screen coordinates (y grows downward), so moving
# North decreases y.
N = (0, -1)
S = (0, 1)
W = (-1, 0)
E = (1, 0)


def gen_offsets(actions):
    """Yield the cumulative ``(dx, dy)`` displacement after each action.

    Args:
        actions: Iterable of ``(ax, ay)`` steps, e.g. ``N``, ``S``, ``E``, ``W``.
    """
    dx, dy = 0, 0
    for ax, ay in actions:
        dx += ax
        dy += ay
        yield dx, dy


def move_keyframe(dx, dy, ratio):
    """Return one CSS keyframe translating by ``(dx, dy)`` centimetres.

    ``ratio`` is the fraction of the animation's duration at which the
    keyframe applies; it is rendered as a percentage.
    """
    return f"""{ratio*100}% {{
    transform: translate({dx}cm, {dy}cm);
}}"""


def gridworld(n=10, actions=None, tile2classes=None, extra_css="",
              start=(2, 3)):
    """Build an SVG gridworld, optionally with an animated agent.

    Args:
        n: Side length of the square board.
        actions: Optional iterable of ``(dx, dy)`` steps (``N``/``S``/``E``/
            ``W``). When given, an agent circle is drawn at ``start`` and a
            ``move`` CSS animation replays the steps evenly over its
            duration. When ``None``, no agent is drawn.
        tile2classes: Optional callable ``(x, y) -> str`` giving each
            tile's CSS class; ``(0, 0)`` is the bottom-left tile.
        extra_css: CSS appended after the default styles.
        start: Logical ``(x, y)`` cell where the agent starts (bottom-left
            origin). Defaults to ``(2, 3)``, the previously hard-coded value.

    Returns:
        The ``svgwrite.Drawing`` with the embedded stylesheet.
    """
    dwg = draw_board(n=n, tile2classes=tile2classes)

    css_styles = CSS_STYLES
    if actions is not None:
        # Materialize so any iterable works; len() is needed for the ratios.
        actions = list(actions)

        # Add agent at the centre of its start tile (tiles span +0.1..+1.0cm).
        x, y = start
        cx, cy = x + 0.55, (n - y - 1) + 0.55
        dwg.add(svgwrite.shapes.Circle(
            r="0.3cm",
            center=(f"{cx}cm", f"{cy}cm"),
            class_="agent",
        ))

        keyframes = [
            move_keyframe(dx, dy, (i + 1) / len(actions))
            for i, (dx, dy) in enumerate(gen_offsets(actions))
        ]
        css_styles += "\n@keyframes move {\n" + '\n'.join(keyframes) + "\n}"

    dwg.defs.add(dwg.style(css_styles + extra_css))
    return dwg


if __name__ == '__main__':
    def tile2classes(x, y):
        if (3 <= x <= 4) and (2 <= y <= 5):
            return "water"
        elif (x in (0, 7)) and (y in (0, 7)):
            return "recharge"
        elif (2 <= x <= 5) and y in (0, 7):
            return "dry"
        elif x in (1, 6) and (y in (4, 5) or y <= 1):
            return "lava"
        elif (x in (0, 7)) and (y in (1, 4, 5)):
            return "lava"

        return "normal"

    actions = [E, N, N, N, N, W, W, W]
    dwg = gridworld(n=8, tile2classes=tile2classes, actions=actions)
    dwg.saveas("example.svg", pretty=True)


# ---- /gridworld_vis/setup.py ----
from setuptools import find_packages, setup

DESC = 'Gridworld visualization library.'

# NOTE(review): this setup.py sits inside gridworld_vis/, so find_packages()
# (which searches the current directory) will not discover the gridworld_vis
# package itself -- confirm whether it is meant to live at the repo root.
setup(
    name='gridworld-viz',
    version='0.0.0',
    description=DESC,
    url='https://github.com/mvcisback/gridworld-visualizer',
    author='Marcell Vazquez-Chanlatte',
    author_email='marcell.vc@eecs.berkeley.edu',
    license='MIT',
    install_requires=[
        'svgwrite'
    ],
    packages=find_packages(),
)