├── .editorconfig
├── .gitignore
├── .vscode
└── launch.json
├── LICENSE
├── README.md
├── cspell.json
├── dictionaries
├── math.txt
├── numpy.txt
├── pygame.txt
├── pypaths.txt
├── pytest.txt
├── pytmx.txt
├── readme.txt
└── scipy.txt
├── docs
├── screenshot_predator_prey.png
└── screenshot_soccer.png
├── pygame_rl
├── __init__.py
├── data
│ ├── map
│ │ ├── gridworld
│ │ │ ├── agents_sprite.yaml
│ │ │ ├── gridworld.tmx
│ │ │ ├── ground_tile.yaml
│ │ │ ├── minecraft_tileset.tsx
│ │ │ └── obstacles_sprite.yaml
│ │ ├── predator_prey
│ │ │ ├── agent_sprite.yaml
│ │ │ ├── ground_tile.yaml
│ │ │ ├── minecraft_tileset.tsx
│ │ │ └── predator_prey.tmx
│ │ └── soccer
│ │ │ ├── agent_sprite.yaml
│ │ │ ├── goal_tile.yaml
│ │ │ ├── ground_tile.yaml
│ │ │ ├── minecraft_tileset.tsx
│ │ │ ├── soccer.tmx
│ │ │ └── spawn_tile.yaml
│ └── tileset
│ │ └── minecraft_sprite_32x32.png
├── renderer
│ ├── __init__.py
│ └── pygame_renderer.py
├── rl
│ ├── __init__.py
│ └── environment.py
├── scenario
│ ├── __init__.py
│ ├── gridworld
│ │ ├── __init__.py
│ │ ├── envs
│ │ │ ├── __init__.py
│ │ │ ├── gridworld_v0.py
│ │ │ └── gridworld_v1.py
│ │ ├── map_data.py
│ │ ├── options.py
│ │ └── renderer.py
│ ├── predator_prey_environment.py
│ ├── predator_prey_renderer.py
│ ├── soccer
│ │ ├── __init__.py
│ │ ├── actions.py
│ │ ├── agent_modes.py
│ │ ├── ai_modes.py
│ │ ├── envs
│ │ │ ├── __init__.py
│ │ │ └── soccer_v0.py
│ │ ├── map_data.py
│ │ ├── options.py
│ │ ├── renderer.py
│ │ ├── renderer_options.py
│ │ ├── state.py
│ │ └── teams.py
│ ├── soccer_environment.py
│ └── soccer_renderer.py
└── util
│ ├── __init__.py
│ └── file_util.py
├── pylintrc
├── sample
├── data
│ ├── map
│ │ ├── gridworld
│ │ │ ├── agents_sprite.yaml
│ │ │ ├── gridworld_9x9.tmx
│ │ │ ├── ground_tile.yaml
│ │ │ ├── minecraft_tileset.tsx
│ │ │ └── obstacles_sprite.yaml
│ │ ├── predator_prey
│ │ │ ├── agent_sprite.yaml
│ │ │ ├── ground_tile.yaml
│ │ │ ├── minecraft_tileset.tsx
│ │ │ └── predator_prey_15x15.tmx
│ │ └── soccer
│ │ │ ├── agent_sprite.yaml
│ │ │ ├── goal_tile.yaml
│ │ │ ├── ground_tile.yaml
│ │ │ ├── minecraft_tileset.tsx
│ │ │ ├── soccer_13x10_goal_4.tmx
│ │ │ ├── soccer_21x14_goal_4.tmx
│ │ │ ├── soccer_21x14_goal_6.tmx
│ │ │ ├── soccer_21x14_goal_8.tmx
│ │ │ └── spawn_tile.yaml
│ └── tileset
│ │ └── minecraft_sprite_32x32.png
├── gridworld
│ ├── environment_advanced.py
│ ├── environment_simple.py
│ └── renderer.py
├── predator_prey
│ ├── environment_advanced.py
│ ├── environment_simple.py
│ └── renderer.py
└── soccer
│ ├── env_soccer_v0.py
│ ├── environment_advanced.py
│ ├── environment_legacy.py
│ ├── environment_simple.py
│ ├── renderer.py
│ └── renderer_custom_map.py
├── setup.cfg
├── setup.py
└── tests
├── main.py
├── test_file_uitl.py
├── test_soccer_environment.py
├── test_soccer_environment_scenarios.py
└── test_soccer_renderer.py
/.editorconfig:
--------------------------------------------------------------------------------
1 | # EditorConfig is awesome: http://EditorConfig.org
2 |
3 | # top-most EditorConfig file
4 | root = true
5 |
6 | # Unix-style newlines with a newline ending every file
7 | [*]
8 | end_of_line = lf
9 | insert_final_newline = true
10 |
11 | # Matches Python files
12 | [*.py]
13 | indent_style = space
14 | indent_size = 4
15 |
16 | # Matches YAML files
17 | [*.yaml]
18 | indent_style = space
19 | indent_size = 2
20 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.gitignore.io/api/osx,linux,python,windows,visualstudiocode
3 |
4 | ### Linux ###
5 | *~
6 |
7 | # temporary files which can be created if a process still has a handle open of a deleted file
8 | .fuse_hidden*
9 |
10 | # KDE directory preferences
11 | .directory
12 |
13 | # Linux trash folder which might appear on any partition or disk
14 | .Trash-*
15 |
16 | # .nfs files are created when an open file is removed but is still being accessed
17 | .nfs*
18 |
19 | ### OSX ###
20 | *.DS_Store
21 | .AppleDouble
22 | .LSOverride
23 |
24 | # Icon must end with two \r
25 | Icon
26 |
27 | # Thumbnails
28 | ._*
29 |
30 | # Files that might appear in the root of a volume
31 | .DocumentRevisions-V100
32 | .fseventsd
33 | .Spotlight-V100
34 | .TemporaryItems
35 | .Trashes
36 | .VolumeIcon.icns
37 | .com.apple.timemachine.donotpresent
38 |
39 | # Directories potentially created on remote AFP share
40 | .AppleDB
41 | .AppleDesktop
42 | Network Trash Folder
43 | Temporary Items
44 | .apdisk
45 |
46 | ### Python ###
47 | # Byte-compiled / optimized / DLL files
48 | __pycache__/
49 | *.py[cod]
50 | *$py.class
51 |
52 | # C extensions
53 | *.so
54 |
55 | # Distribution / packaging
56 | .Python
57 | build/
58 | develop-eggs/
59 | dist/
60 | downloads/
61 | eggs/
62 | .eggs/
63 | lib/
64 | lib64/
65 | parts/
66 | sdist/
67 | var/
68 | wheels/
69 | *.egg-info/
70 | .installed.cfg
71 | *.egg
72 |
73 | # PyInstaller
74 | # Usually these files are written by a python script from a template
75 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
76 | *.manifest
77 | *.spec
78 |
79 | # Installer logs
80 | pip-log.txt
81 | pip-delete-this-directory.txt
82 |
83 | # Unit test / coverage reports
84 | htmlcov/
85 | .tox/
86 | .coverage
87 | .coverage.*
88 | .cache
89 | .pytest_cache/
90 | nosetests.xml
91 | coverage.xml
92 | *.cover
93 | .hypothesis/
94 |
95 | # Translations
96 | *.mo
97 | *.pot
98 |
99 | # Flask stuff:
100 | instance/
101 | .webassets-cache
102 |
103 | # Scrapy stuff:
104 | .scrapy
105 |
106 | # Sphinx documentation
107 | docs/_build/
108 |
109 | # PyBuilder
110 | target/
111 |
112 | # Jupyter Notebook
113 | .ipynb_checkpoints
114 |
115 | # pyenv
116 | .python-version
117 |
118 | # celery beat schedule file
119 | celerybeat-schedule.*
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 |
146 | ### VisualStudioCode ###
147 | .vscode/*
148 | !.vscode/settings.json
149 | !.vscode/tasks.json
150 | !.vscode/launch.json
151 | !.vscode/extensions.json
152 | .history
153 |
154 | ### Windows ###
155 | # Windows thumbnail cache files
156 | Thumbs.db
157 | ehthumbs.db
158 | ehthumbs_vista.db
159 |
160 | # Folder config file
161 | Desktop.ini
162 |
163 | # Recycle Bin used on file shares
164 | $RECYCLE.BIN/
165 |
166 | # Windows Installer files
167 | *.cab
168 | *.msi
169 | *.msm
170 | *.msp
171 |
172 | # Windows shortcuts
173 | *.lnk
174 |
175 |
176 | # End of https://www.gitignore.io/api/osx,linux,python,windows,visualstudiocode
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python: Current File",
9 | "type": "python",
10 | "request": "launch",
11 | "program": "${file}"
12 | },
13 | {
14 | "name": "Python: Attach",
15 | "type": "python",
16 | "request": "attach",
17 | "localRoot": "${workspaceFolder}",
18 | "remoteRoot": "${workspaceFolder}",
19 | "port": 3000,
20 | "secret": "my_secret",
21 | "host": "localhost"
22 | },
23 | {
24 | "name": "Python: Terminal (integrated)",
25 | "type": "python",
26 | "request": "launch",
27 | "program": "${file}",
28 | "console": "integratedTerminal"
29 | },
30 | {
31 | "name": "Python: Terminal (external)",
32 | "type": "python",
33 | "request": "launch",
34 | "program": "${file}",
35 | "console": "externalTerminal"
36 | },
37 | {
38 | "name": "Python: Django",
39 | "type": "python",
40 | "request": "launch",
41 | "program": "${workspaceFolder}/manage.py",
42 | "args": [
43 | "runserver",
44 | "--noreload",
45 | "--nothreading"
46 | ],
47 | "debugOptions": [
48 | "RedirectOutput",
49 | "Django"
50 | ]
51 | },
52 | {
53 | "name": "Python: Flask (0.11.x or later)",
54 | "type": "python",
55 | "request": "launch",
56 | "module": "flask",
57 | "env": {
58 | "FLASK_APP": "app.py"
59 | },
60 | "args": [
61 | "run",
62 | "--no-debugger",
63 | "--no-reload"
64 | ]
65 | },
66 | {
67 | "name": "Python: Module",
68 | "type": "python",
69 | "request": "launch",
70 | "module": "module.name"
71 | },
72 | {
73 | "name": "Python: Pyramid",
74 | "type": "python",
75 | "request": "launch",
76 | "args": [
77 | "${workspaceFolder}/development.ini"
78 | ],
79 | "debugOptions": [
80 | "RedirectOutput",
81 | "Pyramid"
82 | ]
83 | },
84 | {
85 | "name": "Python: Watson",
86 | "type": "python",
87 | "request": "launch",
88 | "program": "${workspaceFolder}/console.py",
89 | "args": [
90 | "dev",
91 | "runserver",
92 | "--noreload=True"
93 | ]
94 | },
95 | {
96 | "name": "Python: All debug Options",
97 | "type": "python",
98 | "request": "launch",
99 | "pythonPath": "${config:python.pythonPath}",
100 | "program": "${file}",
101 | "module": "module.name",
102 | "env": {
103 | "VAR1": "1",
104 | "VAR2": "2"
105 | },
106 | "envFile": "${workspaceFolder}/.env",
107 | "args": [
108 | "arg1",
109 | "arg2"
110 | ],
111 | "debugOptions": [
112 | "RedirectOutput"
113 | ]
114 | }
115 | ]
116 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Shawn Chang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pygame RL
2 |
3 | Some game environments used for reinforcement learning.
4 |
5 | ## Soccer
6 |
7 | A variant of the game described in the paper [He, He, et al. "Opponent modeling in deep reinforcement learning." International Conference on Machine Learning. 2016](https://www.umiacs.umd.edu/~hal/docs/daume16opponent.pdf). Pygame is used as the rendering framework. PyTMX is used to read the map file. Customized Minecraft texture is used for displaying the tiles.
8 |
9 | 
10 |
11 | Reinforcement learning controls agent 1 (shown as the player Steve head), the computer controls agent 2 (shown as the pig head). The agent who has the ball is bordered by a blue square (in this case, the player has the ball shown in the image).
12 |
13 | When the player carries the ball to the rightmost goal area, a reward of 1.0 is given; when the computer carries the ball to the leftmost goal area, a reward of -1.0 is given. The episode ends when either agent carries the ball to its goal area or the time step reaches 100.
14 |
15 | ### Changing the Map
16 |
17 | The map data is embedded in the map file. For example, see `pygame_rl/data/map/soccer/soccer.tmx`. The config files in the same directory are associated with the layers and map names to tiles or sprite positions. See `pygame_rl/renderer/pygame_renderer.py` for more information.
18 |
19 | To modify the map, for example:
20 |
21 | * Change the spawn area: Modify the layer `spawn_area` in `soccer.tmx`.
22 | * Change the moving agents: Modify the layer `agent` in `soccer.tmx` and the mapping file `agent_sprite.yaml`.
23 | * Change the goal area: Modify the layer `goal` in `soccer.tmx` and the mapping file `goal_tile.yaml`.
24 | * Change the walkable area: Modify the layer `ground` in `soccer.tmx` and the mapping file `ground_tile.yaml`.
25 |
26 | ### Computer Agent Algorithm
27 |
28 | The computer agent has 4 strategies. Approaching and avoiding both work by randomly moving along one of the axes so that the Euclidean distance to the target becomes shorter or longer, respectively. The defensive target is the player who possesses the ball if one of the players has it, or the nearest player if no player has the ball.
29 |
30 | * "Avoid opponent": See where the nearest player is, avoid him.
31 | * "Advance to goal": See where the leftmost goal area is, select a grid which has the maximum distance from the nearest player, approach it.
32 | * "Defend goal": See where the rightmost goal area is, select a grid which has the minimum distance from the defensive target, approach it.
33 | * "Intercept goal": See where the defensive target is, intercept him. It's basically approaching with an exception that the Euclidean distance is always greater than or equal to 1.
34 |
35 | ## Predator-Prey
36 |
37 | 
38 |
39 | There is a fixed number of predators (Steve head), preys (Cake) and obstacles (Cobblestone) on the field (Snow). Each agent can take one of five actions at each time step: moving to one of the 4 adjacent grid points or standing still. Overlapping is only allowed when more than one predator and at least one prey intend to move to the same grid point.
40 |
41 | The player can control some of the predators and preys. Rule-based predators approach preys by finding the shortest path; rule-based preys evade predators based on the directions and distances to predators. When a predator catches a prey, the prey disappears from the field, and a reward of 1.0 is given; a reward of 0.0 is given at all other time steps. The episode ends when there are no more preys or the time step reaches 100.
42 |
43 | ## Installation
44 |
45 | ### Requirements
46 |
47 | * [Python 3.6](https://www.continuum.io/)
48 |
49 | ### Getting Started
50 |
51 | 1. Clone the repository.
52 | 2. Open a command line and change the working directory to the cloned directory.
53 | 3. Install the package in [editable](https://pip.pypa.io/en/stable/reference/pip_install/#editable-installs) mode because this package is not intended to be published.
54 | ```shell
55 | pip install -e .
56 | ```
57 |
58 | ### Running the Samples
59 |
60 | Run and see the sample files in `sample/` to get started.
61 |
62 | ## Development
63 |
64 | ### Software
65 |
66 | * [Visual Studio Code](https://code.visualstudio.com/) for editing the text files.
67 | * [Python extension for VSCode](https://marketplace.visualstudio.com/items?itemName=donjayamanne.python) for debugging and linting Python files.
68 | * [Tiled Map Editor](http://www.mapeditor.org/) for editing `.tmx` and `.tsx` files.
69 | * [GIMP](https://www.gimp.org/) for editing the image files.
70 |
71 | ### Installing the Test Dependencies
72 |
73 | 1. Install the test dependencies.
74 | ```shell
75 | pip install -e .[test]
76 | ```
77 |
78 | ### Running the Tests
79 |
80 | 1. Run the tests with Pytest.
81 | ```shell
82 | pytest
83 | ```
84 | 2. Debug the tests with Python.
85 | ```shell
86 | python tests/main.py
87 | ```
88 |
89 | ### Measuring Code Coverage
90 |
91 | 1. Run the tests with Coverage.py.
92 | ```shell
93 | coverage run tests/main.py
94 | ```
95 | 2. Generate the web page report.
96 | ```shell
97 | coverage html
98 | ```
99 | 3. See the report in `htmlcov\index.html`.
100 |
101 | ### Measuring Performance
102 |
103 | 1. Run the sample with cProfile.
104 | ```shell
105 | python -m cProfile -o environment_advanced.prof sample/soccer/environment_advanced.py
106 | ```
107 | 2. See the report with SnakeViz.
108 | ```shell
109 | snakeviz environment_advanced.prof
110 | ```
111 |
112 | ### Resources
113 |
114 | The materials of the tileset come from the following links:
115 |
116 | * [Minecraft Block Sprite](http://minecraft.gamepedia.com/index.php?title=File:BlockCSS.png)
117 | * [Minecraft Entity Sprite](https://minecraft.gamepedia.com/index.php?title=File:EntityCSS.png)
--------------------------------------------------------------------------------
/cspell.json:
--------------------------------------------------------------------------------
1 | // cSpell Settings
2 | {
3 | // Version of the setting file. Always 0.1
4 | "version": "0.1",
5 | // language - current active spelling language
6 | "language": "en",
7 | // words - list of words to be always considered correct
8 | "words": [],
9 | // flagWords - list of words to be always considered incorrect
10 | // This is useful for offensive words and common spelling errors.
11 | // For example "hte" should be "the"
12 | "flagWords": [],
13 | "dictionaryDefinitions": [
14 | {
15 | "name": "math",
16 | "path": "./dictionaries/math.txt"
17 | },
18 | {
19 | "name": "numpy",
20 | "path": "./dictionaries/numpy.txt"
21 | },
22 | {
23 | "name": "pygame",
24 | "path": "./dictionaries/pygame.txt"
25 | },
26 | {
27 | "name": "pypaths",
28 | "path": "./dictionaries/pypaths.txt"
29 | },
30 | {
31 | "name": "pytest",
32 | "path": "./dictionaries/pytest.txt"
33 | },
34 | {
35 | "name": "pytmx",
36 | "path": "./dictionaries/pytmx.txt"
37 | },
38 | {
39 | "name": "readme",
40 | "path": "./dictionaries/readme.txt"
41 | },
42 | {
43 | "name": "scipy",
44 | "path": "./dictionaries/scipy.txt"
45 | }
46 | ],
47 | "dictionaries": [],
48 | "languageSettings": [
49 | {
50 | "languageId": "markdown",
51 | "dictionaries": [
52 | "pygame",
53 | "pytest",
54 | "readme"
55 | ]
56 | },
57 | {
58 | "languageId": "python",
59 | "dictionaries": [
60 | "math",
61 | "numpy",
62 | "pygame",
63 | "pypaths",
64 | "pytest",
65 | "pytmx",
66 | "scipy"
67 | ]
68 | }
69 | ]
70 | }
--------------------------------------------------------------------------------
/dictionaries/math.txt:
--------------------------------------------------------------------------------
1 | hypot
--------------------------------------------------------------------------------
/dictionaries/numpy.txt:
--------------------------------------------------------------------------------
1 | dtype
2 | linalg
3 | ndarray
4 | numpy
--------------------------------------------------------------------------------
/dictionaries/pygame.txt:
--------------------------------------------------------------------------------
1 | blit
2 | pygame
3 | rect
--------------------------------------------------------------------------------
/dictionaries/pypaths.txt:
--------------------------------------------------------------------------------
1 | astar
2 | pypaths
--------------------------------------------------------------------------------
/dictionaries/pytest.txt:
--------------------------------------------------------------------------------
1 | parametrize
2 | pytest
--------------------------------------------------------------------------------
/dictionaries/pytmx.txt:
--------------------------------------------------------------------------------
1 | pytmx
2 | tiledgidmap
--------------------------------------------------------------------------------
/dictionaries/readme.txt:
--------------------------------------------------------------------------------
1 | htmlcov
2 | minecraft
3 | snakeviz
4 | tileset
--------------------------------------------------------------------------------
/dictionaries/scipy.txt:
--------------------------------------------------------------------------------
1 | imread
2 | imsave
3 | scipy
--------------------------------------------------------------------------------
/docs/screenshot_predator_prey.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sc420/pygame-rl/f81da559385876616d99c74b43e4345f53d086d2/docs/screenshot_predator_prey.png
--------------------------------------------------------------------------------
/docs/screenshot_soccer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sc420/pygame-rl/f81da559385876616d99c74b43e4345f53d086d2/docs/screenshot_soccer.png
--------------------------------------------------------------------------------
/pygame_rl/__init__.py:
--------------------------------------------------------------------------------
# Subpackages exported by "from pygame_rl import *". Every name listed here
# must be a direct subpackage of "pygame_rl", otherwise the star import fails.
# The soccer scenario lives under "pygame_rl.scenario", not at the top level.
__all__ = ['renderer', 'rl', 'scenario', 'util']
2 |
--------------------------------------------------------------------------------
/pygame_rl/data/map/gridworld/agents_sprite.yaml:
--------------------------------------------------------------------------------
1 | PLAYER1:
2 | x: 0
3 | y: 0
4 | PLAYER2:
5 | x: 1
6 | y: 0
7 | PLAYER3:
8 | x: 2
9 | y: 0
10 | GOAL:
11 | x: 3
12 | y: 0
--------------------------------------------------------------------------------
/pygame_rl/data/map/gridworld/gridworld.tmx:
--------------------------------------------------------------------------------
1 |
2 |
56 |
--------------------------------------------------------------------------------
/pygame_rl/data/map/gridworld/ground_tile.yaml:
--------------------------------------------------------------------------------
1 | GROUND: 686
--------------------------------------------------------------------------------
/pygame_rl/data/map/gridworld/minecraft_tileset.tsx:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/pygame_rl/data/map/gridworld/obstacles_sprite.yaml:
--------------------------------------------------------------------------------
1 | OBSTACLE1:
2 | x: 0
3 | y: 0
4 | OBSTACLE2:
5 | x: 1
6 | y: 0
--------------------------------------------------------------------------------
/pygame_rl/data/map/predator_prey/agent_sprite.yaml:
--------------------------------------------------------------------------------
1 | PREDATOR:
2 | x: 0
3 | y: 0
4 | OBSTACLE:
5 | x: 1
6 | y: 0
7 | PREY:
8 | x: 2
9 | y: 0
--------------------------------------------------------------------------------
/pygame_rl/data/map/predator_prey/ground_tile.yaml:
--------------------------------------------------------------------------------
1 | FIELD: 686
--------------------------------------------------------------------------------
/pygame_rl/data/map/predator_prey/minecraft_tileset.tsx:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/pygame_rl/data/map/predator_prey/predator_prey.tmx:
--------------------------------------------------------------------------------
1 |
2 |
61 |
--------------------------------------------------------------------------------
/pygame_rl/data/map/soccer/agent_sprite.yaml:
--------------------------------------------------------------------------------
1 | AGENT1:
2 | x: 0
3 | y: 0
4 | AGENT2:
5 | x: 1
6 | y: 0
7 | AGENT3:
8 | x: 2
9 | y: 0
10 | AGENT4:
11 | x: 3
12 | y: 0
13 | AGENT5:
14 | x: 4
15 | y: 0
16 | AGENT6:
17 | x: 0
18 | y: 2
19 | AGENT7:
20 | x: 1
21 | y: 2
22 | AGENT8:
23 | x: 2
24 | y: 2
25 | AGENT9:
26 | x: 3
27 | y: 2
28 | AGENT10:
29 | x: 4
30 | y: 2
31 | AGENT1_BALL:
32 | x: 0
33 | y: 1
34 | AGENT2_BALL:
35 | x: 1
36 | y: 1
37 | AGENT3_BALL:
38 | x: 2
39 | y: 1
40 | AGENT4_BALL:
41 | x: 3
42 | y: 1
43 | AGENT5_BALL:
44 | x: 4
45 | y: 1
46 | AGENT6_BALL:
47 | x: 0
48 | y: 3
49 | AGENT7_BALL:
50 | x: 1
51 | y: 3
52 | AGENT8_BALL:
53 | x: 2
54 | y: 3
55 | AGENT9_BALL:
56 | x: 3
57 | y: 3
58 | AGENT10_BALL:
59 | x: 4
60 | y: 3
--------------------------------------------------------------------------------
/pygame_rl/data/map/soccer/goal_tile.yaml:
--------------------------------------------------------------------------------
1 | PLAYER: 74
2 | COMPUTER: 98
--------------------------------------------------------------------------------
/pygame_rl/data/map/soccer/ground_tile.yaml:
--------------------------------------------------------------------------------
1 | WALKABLE: 1
--------------------------------------------------------------------------------
/pygame_rl/data/map/soccer/minecraft_tileset.tsx:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/pygame_rl/data/map/soccer/soccer.tmx:
--------------------------------------------------------------------------------
1 |
2 |
60 |
--------------------------------------------------------------------------------
/pygame_rl/data/map/soccer/spawn_tile.yaml:
--------------------------------------------------------------------------------
1 | PLAYER: 40
2 | COMPUTER: 203
--------------------------------------------------------------------------------
/pygame_rl/data/tileset/minecraft_sprite_32x32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sc420/pygame-rl/f81da559385876616d99c74b43e4345f53d086d2/pygame_rl/data/tileset/minecraft_sprite_32x32.png
--------------------------------------------------------------------------------
/pygame_rl/renderer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sc420/pygame-rl/f81da559385876616d99c74b43e4345f53d086d2/pygame_rl/renderer/__init__.py
--------------------------------------------------------------------------------
/pygame_rl/renderer/pygame_renderer.py:
--------------------------------------------------------------------------------
1 | # Native modules
2 | import abc
3 |
4 | # Third-party modules
5 | import numpy as np
6 | import pygame
7 | import pytmx
8 | import pytmx.util_pygame
9 |
10 | # User-defined modules
11 | import pygame_rl.util.file_util as file_util
12 |
13 |
class TiledLoader(metaclass=abc.ABCMeta):
    """Base class holding a Tiled map and its categorized layers.

    The class itself never loads the map: "tiled_map" must be assigned before
    "load_layers" is called.
    """

    # Path of the map file
    filename = None

    # Loaded map object, "None" until it is assigned
    tiled_map = None

    # Mapping from the category name to the list of layers, "None" until
    # "load_layers" is called
    layers = None

    def __init__(self, filename):
        self.filename = filename

    def load_layers(self):
        """Group the layers of the map by category.

        Every layer is put in the "all" category. A layer is additionally put
        in "background" when its "background" property is truthy; otherwise it
        is put in "overlay" when its "overlay" property is truthy. A layer is
        never put in both "background" and "overlay".

        """
        self.layers = grouped = {
            'all': [],
            'background': [],
            'overlay': [],
        }
        for layer in self.tiled_map.layers:
            grouped['all'].append(layer)
            properties = layer.properties
            # The first enabled category wins
            for category in ('background', 'overlay'):
                if properties.get(category, None):
                    grouped[category].append(layer)
                    break
50 |
class TiledData(TiledLoader):
    """Tiled map loaded with "pytmx.TiledMap" for querying the map data."""

    def load(self):
        """Load the tiled map, then categorize its layers."""
        self.tiled_map = pytmx.TiledMap(self.filename)
        self.load_layers()

    def get_map_size(self):
        """Get the map size in tiles.

        Returns:
            numpy.ndarray: The width and the height of the map.
        """
        tiled_map = self.tiled_map
        return np.array([tiled_map.width, tiled_map.height])

    def get_tile_positions(self):
        """Get the positions of the named tiles in every layer.

        A layer may have the property "tile" whose value is the path, relative
        to the map path, of a YAML file mapping each tile name to its tid (tile
        ID). The positions of such a layer are grouped by the tile name; tiles
        whose tid is not in the file are ignored. A layer without the property
        gets an empty dict.

        Returns:
            dict: 1st mapping is from the layer name to the 2nd dict. 2nd
            mapping is from the tile name to the list of [x, y] positions.
        """
        positions_by_layer = {}
        # Go through all layers, not only the background ones
        for layer in self.layers['all']:
            # Layers without the tile mapping file have no named tiles
            if 'tile' not in layer.properties:
                positions_by_layer[layer.name] = {}
                continue
            # The tile file path is relative to the map file
            yaml_path = file_util.resolve_path(
                self.filename, layer.properties['tile'])
            name_to_tid = file_util.read_yaml(yaml_path)
            # Inverse lookup from the tid to the tile name
            tid_to_name = {tid: name for (name, tid) in name_to_tid.items()}
            # Every name starts with no positions
            name_to_positions = {name: [] for name in name_to_tid}
            for (col, row, gid) in layer:
                # A non-positive gid marks an empty tile
                if gid <= 0:
                    continue
                tid = self.tiled_map.tiledgidmap[gid]
                # Only collect the tiles listed in the tile file
                if tid in tid_to_name:
                    name_to_positions[tid_to_name[tid]].append([col, row])
            positions_by_layer[layer.name] = name_to_positions
        return positions_by_layer
110 |
111 |
112 | class TiledRenderer(TiledLoader):
113 | # Pygame surfaces (pygame.Surface)
114 | screen = None
115 | background = None
116 |
def load(self):
    """Load the tiled map through the Pygame loader and group its layers."""
    # Read the map file
    self.tiled_map = pytmx.util_pygame.load_pygame(self.filename)
    # Categorize the layers of the freshly loaded map
    self.load_layers()
123 |
def get_display_size(self):
    """Get the display size in pixels.

    Returns:
        numpy.ndarray: The width and the height, each being the number of
        tiles times the tile size along that axis.
    """
    tiled_map = self.tiled_map
    return np.array([
        tiled_map.width * tiled_map.tilewidth,
        tiled_map.height * tiled_map.tileheight,
    ])
128 |
def get_map_size(self):
    """Get the map size in tiles.

    Returns:
        numpy.ndarray: The width and the height of the map.
    """
    tiled_map = self.tiled_map
    return np.array([tiled_map.width, tiled_map.height])
131 |
def get_tile_size(self):
    """Get the size of a single tile.

    Returns:
        numpy.ndarray: The tile width and the tile height.
    """
    tiled_map = self.tiled_map
    return np.array([tiled_map.tilewidth, tiled_map.tileheight])
134 |
def get_total_tile_num(self):
    """Get the total number of tiles.

    Returns:
        int: The "maxgid" of the map minus one.
    """
    max_gid = self.tiled_map.maxgid
    return max_gid - 1
137 |
def get_background(self):
    """Compose the background surface.

    The tiles of every background layer are blitted onto one new surface that
    has the same size as the screen.

    Returns:
        pygame.Surface: The background surface.
    """
    layers = self.layers['background']
    surface = pygame.Surface(self.screen.get_size())
    for layer in layers:
        for (col, row, tile_image) in layer.tiles():
            # Convert the tile position to the pixel position
            top_left = [col * self.tiled_map.tilewidth,
                        row * self.tiled_map.tileheight]
            surface.blit(tile_image, top_left)
    return surface
156 |
def get_overlays(self):
    """Get the overlay sprites.

    Every overlay layer must have the property "sprite" whose value is the
    path, relative to the map path, of a YAML file mapping each sprite name to
    its tile position ("x" and "y") in that layer.

    Returns:
        dict: A mapping from the sprite name to the sprite.

    Raises:
        KeyError: If an overlay layer has no "sprite" property, or if a
            sprite position has no tile in its layer.
        RuntimeError: If a sprite name has already been added, which can
            happen across the sprite files of different overlay layers.
    """
    # Get the tile dimension
    tile_dim = [self.tiled_map.tilewidth, self.tiled_map.tileheight]
    overlays = {}
    for layer in self.layers['overlay']:
        # The sprite file is mandatory for the overlay layers
        if 'sprite' not in layer.properties:
            raise KeyError('"sprite" property is required for the layer {} '
                           'to load the overlays'
                           .format(layer.name))
        # Build the table by pointing the position to the image
        pos_to_image = {
            (px, py): image for (px, py, image) in layer.tiles()}
        # Get the sprite file path relative to the map file
        path = layer.properties['sprite']
        resolved_path = file_util.resolve_path(self.filename, path)
        # Read the mapping from the sprite name to the position
        sprite_positions = file_util.read_yaml(resolved_path)
        for (name, sprite_pos) in sprite_positions.items():
            px = sprite_pos['x']
            py = sprite_pos['y']
            pos = (px, py)
            if pos not in pos_to_image:
                raise KeyError('{} ({}, {}) is not found in the layer'
                               .format(name, px, py))
            # Refuse to silently overwrite a sprite added earlier
            if name in overlays:
                raise RuntimeError(
                    'Duplicate name {} in the sprite file'.format(name))
            overlays[name] = OverlaySprite(pos_to_image[pos], pos, tile_dim)
    return overlays
209 |
def get_screenshot_dim(self):
    """Return the screenshot dimension as ``[height, width, 3]``.

    The first two entries are the size of "screen" in swapped order.
    """
    screen_size = self.screen.get_size()
    height = screen_size[1]
    width = screen_size[0]
    return [height, width, 3]
213 |
def get_screenshot(self):
    """Get the full screenshot.

    "screen" surface must be rendered first, otherwise the image will be all
    black.

    Returns:
        numpy.ndarray: A copy of the screen pixels with the first two axes
        swapped.
    """
    # Reference the pixels of the whole surface
    pixels = pygame.surfarray.pixels3d(self.screen)
    # Pygame and Scipy use opposite X and Y axes, so swap them; the result is
    # copied because keeping the pixel reference would leave the surface
    # locked
    return np.array(np.swapaxes(pixels, 0, 1))
229 |
def get_po_screenshot(self, pos, radius):
    """Get the partially observable (po) screenshot.

    The returned screenshot is always a square with the length of
    "tile size" * (2 * radius + 1). The image of the agent is always
    centered. The default background is black if the cropped image is near
    the boundaries; a window lying entirely outside the display yields an
    all-black image.

    Args:
        pos (numpy.array): The position of the partially observable area.
            Any sequence of two integers is accepted.
        radius (int): The radius of the partially observable area.

    Returns:
        numpy.ndarray: The partially observable screenshot.
    """
    # Get the entire image
    image = pygame.surfarray.pixels3d(self.screen)
    # Get the size of a single tile and of the display as Numpy arrays
    tile_size = np.asarray(self.get_tile_size())
    display_size = np.asarray(self.get_display_size())
    # Calculate the length of the tiles needed
    tile_len = 2 * radius + 1
    # Calculate the size of the partially observable screenshot
    po_size = tile_size * tile_len
    # Calculate the offset of the crop area (may be negative)
    crop_offset = tile_size * (np.asarray(pos) - radius)
    # Clamp both ends of the crop area into the display. Clamping both ends
    # keeps the crop size non-negative even when the window does not overlap
    # the display at all.
    crop_start = np.clip(crop_offset, 0, display_size)
    crop_stop = np.clip(crop_offset + po_size, 0, display_size)
    crop_size = crop_stop - crop_start
    # Calculate the crop slice ((x, x+w), (y, y+h))
    crop_slice = (
        slice(crop_start[0], crop_stop[0]),
        slice(crop_start[1], crop_stop[1]),
    )
    # Create a black filled partially observable screenshot
    po_screenshot = np.zeros(
        (po_size[0], po_size[1], 3), dtype=image.dtype)
    # Calculate the offset of the paste area; it is non-zero only when the
    # crop area starts before the display origin
    paste_offset = np.maximum(0, -crop_offset)
    # Calculate the paste slice ((x, x+w), (y, y+h))
    paste_slice = (
        slice(paste_offset[0], paste_offset[0] + crop_size[0]),
        slice(paste_offset[1], paste_offset[1] + crop_size[1]),
    )
    # Copy and paste the partial screenshot
    po_screenshot[paste_slice] = image[crop_slice]
    # Swap the axes as the X and Y axes in Pygame and Scipy are opposite
    return np.swapaxes(po_screenshot, 0, 1)
286 |
287 |
class OverlaySprite(pygame.sprite.Sprite):
    """Sprite that keeps a grid position and mirrors it into its rect.

    The pixel position of "rect" is always the grid position multiplied by
    the tile dimension.
    """
    # Grid position of the sprite
    pos = None

    # Tile dimension used to convert the grid position to pixels
    tile_dim = None

    # Image (pygame.Surface)
    image = None

    # Image position (pygame.Rect)
    rect = None

    def __init__(self, image, pos, tile_dim):
        """Store the image and tile dimension, then place the sprite.

        Args:
            image (pygame.Surface): The sprite image.
            pos: The grid position, indexable as pos[0] and pos[1].
            tile_dim: The tile dimension, indexable the same way.
        """
        super().__init__()
        self.image = image
        self.pos = pos
        self.tile_dim = tile_dim
        # The rect is obtained once from the image and moved afterwards
        self.rect = image.get_rect()
        self.set_pos(pos)

    def get_pos(self):
        """Return the grid position."""
        return self.pos

    def set_pos(self, pos):
        """Set the grid position and move the rect to the matching pixels."""
        self.pos = pos
        tile_w = self.tile_dim[0]
        tile_h = self.tile_dim[1]
        self.rect.x = pos[0] * tile_w
        self.rect.y = pos[1] * tile_h
321 |
--------------------------------------------------------------------------------
/pygame_rl/rl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sc420/pygame-rl/f81da559385876616d99c74b43e4345f53d086d2/pygame_rl/rl/__init__.py
--------------------------------------------------------------------------------
/pygame_rl/rl/environment.py:
--------------------------------------------------------------------------------
1 | # Native modules
2 | import abc
3 |
4 |
class Environment(metaclass=abc.ABCMeta):
    """Abstract interface every environment has to implement.

    All three methods are abstract, so the class cannot be instantiated
    until a subclass overrides each of them.
    """

    @abc.abstractmethod
    def reset(self):
        """Reset the environment and return the initial state."""
        raise NotImplementedError

    @abc.abstractmethod
    def take_action(self, action):
        """Take an action from the agent and return the observation."""
        raise NotImplementedError

    @abc.abstractmethod
    def render(self):
        """Render the environment."""
        raise NotImplementedError
25 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sc420/pygame-rl/f81da559385876616d99c74b43e4345f53d086d2/pygame_rl/scenario/__init__.py
--------------------------------------------------------------------------------
/pygame_rl/scenario/gridworld/__init__.py:
--------------------------------------------------------------------------------
1 | # Third-party modules
2 | import gym.envs.registration as env_reg
3 |
4 |
5 | # Register environments
# Register each gridworld variant under its Gym environment ID
for _env_id, _entry_point in (
        ('gridworld-v0', 'pygame_rl.scenario.gridworld.envs:GridworldV0'),
        ('gridworld-v1', 'pygame_rl.scenario.gridworld.envs:GridworldV1'),
):
    env_reg.register(
        id=_env_id,
        entry_point=_entry_point,
    )
# Do not leave the loop variables behind as module attributes
del _env_id, _entry_point
14 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/gridworld/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from pygame_rl.scenario.gridworld.envs.gridworld_v0 import GridworldV0
2 | from pygame_rl.scenario.gridworld.envs.gridworld_v1 import GridworldV1
3 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/gridworld/envs/gridworld_v0.py:
--------------------------------------------------------------------------------
1 | # Third-party modules
2 | import gym
3 | import numpy as np
4 |
5 | # User-defined modules
6 | import pygame_rl.scenario.gridworld.map_data as map_data
7 | import pygame_rl.scenario.gridworld.options as options
8 | import pygame_rl.scenario.gridworld.renderer as renderer
9 |
10 |
class GridworldV0(gym.Env):
    """Generic gridworld Gym environment.

    reset() returns the observation built by _get_obs() (one-hot object
    vectors per map cell), while render() returns an RGB array.

    NOTE(review): step() returns the raw state produced by the step callback
    instead of _get_obs(). It is kept as-is because callers may depend on
    it -- confirm whether it should mirror reset().
    """
    ############################################################################
    # Gym Attributes
    ############################################################################
    # Metadata
    metadata = {'render.modes': ['rgb_array']}
    # Observation space
    observation_space = None
    # Action space
    action_space = None

    ############################################################################
    # Environment Attributes
    ############################################################################
    # Environment options
    env_options = None
    # Renderer options
    renderer_options = None
    # Map data
    map_data = None
    # Renderer
    renderer = None

    ############################################################################
    # State
    ############################################################################
    # State. A dict where the key is the group name and the value is the
    # list of positions of each object.
    state = None
    # Numpy random state
    random_state = None

    ############################################################################
    # Cached Objects
    ############################################################################
    # Object indexes. A dict where the key is the group name and the value is
    # the 2nd dict, where the key is the local index and the value is the
    # global index. Rebound per instance by _init_object_indexes(); the class
    # level dict is never mutated in place.
    object_indexes = {}
    # Reverse object indexes. A reverse lookup dict of 'object_indexes'. A dict
    # where the key is the global index and the value is [group name,
    # local object index]. Rebound per instance like 'object_indexes'.
    reverse_object_indexes = {}
    # Total object numbers
    total_object_num = 0

    ############################################################################
    # Gym Methods
    ############################################################################

    def seed(self, seed=None):
        """Create the Numpy random state from the seed and return it."""
        self.random_state = np.random.RandomState(seed)
        return self.random_state

    def step(self, action):
        """Advance the state through the step callback of the options.

        Returns:
            tuple: (next state, reward, done, info) exactly as returned by
            the callback.
        """
        next_state, reward, done, info = self.env_options.step_callback(
            self.state, action, random_state=self.random_state)
        self.state = next_state
        return next_state, reward, done, info

    def reset(self):
        """Rebuild the indexes, reset state and renderer, return the obs."""
        # Initialize object indexes
        self._init_object_indexes()
        # Reset the state
        self.state = self.env_options.reset_callback(
            random_state=self.random_state)
        # Reset the renderer
        self.renderer.reset()
        # Return initial observation
        return self._get_obs()

    def render(self, mode='rgb_array'):
        """Render and return the renderer screenshot."""
        # Render
        self.renderer.render()
        # Return renderer screenshot
        return self.renderer.get_screenshot()

    ############################################################################
    # Initialization Methods
    ############################################################################

    def load(self):
        """Create the options (if unset), map data, renderer and spaces."""
        # Save or create environment options
        self.env_options = self.env_options or options.GridworldOptions()
        # Load map data
        self.map_data = map_data.GridworldMapData(self.env_options.map_path)
        # Initialize renderer
        self.renderer = renderer.GridworldRenderer(
            self.env_options.map_path, self, self.renderer_options)
        # Load the renderer
        self.renderer.load()
        # Build the object indexes first: the observation space is sized by
        # the total object number, which is still 0 until they are built
        self._init_object_indexes()
        # Initialize observation space
        self._init_obs_space()
        # Initialize action space (attribute name is misspelled in options)
        self.action_space = self.env_options.action_sapce

    def _init_object_indexes(self):
        """Assign consecutive global indexes to the objects of each group."""
        # Rebind both lookups so nothing is shared between instances and no
        # stale entry survives a change of the groups
        self.object_indexes = {}
        self.reverse_object_indexes = {}
        global_index = 0
        # Iterate each group
        for group_index, group_name in enumerate(self.env_options.group_names):
            group_indexes = {}
            group_size = self.env_options.group_sizes[group_index]
            # Iterate each local object
            for local_index in range(group_size):
                group_indexes[local_index] = global_index
                self.reverse_object_indexes[global_index] = [
                    group_name, local_index]
                global_index += 1
            self.object_indexes[group_name] = group_indexes
        # Save the total object number
        self.total_object_num = global_index

    def _init_obs_space(self):
        """Build a MultiDiscrete space with one entry per map cell.

        NOTE(review): _get_obs() returns a 2D 0/1 array of shape
        [cells, objects], which does not match this space -- confirm the
        intended space.
        """
        map_size = self.renderer.get_map_size()
        flattened_map_size = map_size.prod()
        nvec = np.repeat(self.total_object_num, flattened_map_size)
        self.observation_space = gym.spaces.MultiDiscrete(nvec)

    ############################################################################
    # Observation Retrieval
    ############################################################################

    def _get_obs(self):
        """Get flattened observation.

        The observation is a flattened vector of one-hot vectors, the flattened
        (row-major) vector is the representation of the 2D map, and each one-hot
        vector represents existence of the objects, with each index the global
        index of the object.
        """
        map_size = self.renderer.get_map_size()
        # The map size is [width, height]; row-major flattening needs the
        # width as the row length
        map_width = map_size[0]
        flattened_map_size = map_size.prod()
        obs = np.zeros(
            [flattened_map_size, self.total_object_num], dtype=int)
        for group_name, positions in self.state.items():
            for local_index, pos in enumerate(positions):
                index_1d = index_2d_to_1d(pos, map_width)
                global_index = self.object_indexes[group_name][local_index]
                obs[index_1d, global_index] = 1
        return obs
158 |
159 |
def index_2d_to_1d(pos, width):
    """Convert a 2D position to the index of a row-major flattened map.

    Args:
        pos: The position, unpackable as (x, y).
        width: The number of cells in a row.

    Returns:
        The Python scalar ``width * y + x``.
    """
    px, py = pos
    # np.asscalar() has been removed from Numpy; item() is its replacement
    # and asarray() lets plain Python numbers through as well
    return np.asarray((width * py) + px).item()
163 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/gridworld/envs/gridworld_v1.py:
--------------------------------------------------------------------------------
1 | # Third-party modules
2 | import gym
3 | import numpy as np
4 |
5 | # User-defined modules
6 | import pygame_rl.scenario.gridworld.map_data as map_data
7 | import pygame_rl.scenario.gridworld.options as options
8 | import pygame_rl.scenario.gridworld.renderer as renderer
9 |
10 |
class GridworldV1(gym.Env):
    """Generic gridworld Gym environment.

    The states (observation) returned by step(), reset(), render() are RGB
    arrays.
    """
    ############################################################################
    # Gym Attributes
    ############################################################################
    # Metadata
    metadata = {'render.modes': ['rgb_array']}
    # Observation space
    observation_space = None
    # Action space
    action_space = None

    ############################################################################
    # Environment Attributes
    ############################################################################
    # Environment options
    env_options = None
    # Renderer options
    renderer_options = None
    # Map data
    map_data = None
    # Renderer
    renderer = None

    ############################################################################
    # State
    ############################################################################
    # State. A dict where the key is the group name and the value is the
    # list of positions of each object.
    state = None
    # Numpy random state
    random_state = None

    ############################################################################
    # Cached Objects
    ############################################################################
    # Object indexes. A dict where the key is the group name and the value is
    # the 2nd dict, where the key is the local index and the value is the
    # global index. Rebound per instance by _init_object_indexes(); the class
    # level dict is never mutated in place.
    object_indexes = {}
    # Reverse object indexes. A reverse lookup dict of 'object_indexes'. A dict
    # where the key is the global index and the value is [group name,
    # local object index]. Rebound per instance like 'object_indexes'.
    reverse_object_indexes = {}
    # Total object numbers
    total_object_num = 0

    ############################################################################
    # Gym Methods
    ############################################################################

    def seed(self, seed=None):
        """Create the Numpy random state from the seed and return it."""
        self.random_state = np.random.RandomState(seed)
        return self.random_state

    def step(self, action):
        """Advance the state through the step callback of the options.

        Returns:
            tuple: (observation, reward, done, info), where the observation
            is the screenshot rendered from the new state.
        """
        next_state, reward, done, info = self.env_options.step_callback(
            self.state, action, random_state=self.random_state)
        self.state = next_state
        obs = self._get_obs()
        return obs, reward, done, info

    def reset(self):
        """Rebuild the indexes, reset state and renderer, return the obs."""
        # Initialize object indexes
        self._init_object_indexes()
        # Reset the state
        self.state = self.env_options.reset_callback(
            random_state=self.random_state)
        # Reset the renderer
        self.renderer.reset()
        # Return initial observation
        return self._get_obs()

    def render(self, mode='rgb_array'):
        """Render and return the renderer screenshot."""
        # Render
        self.renderer.render()
        # Return renderer screenshot
        return self.renderer.get_screenshot()

    ############################################################################
    # Initialization Methods
    ############################################################################

    def load(self):
        """Create the options (if unset), map data, renderer and spaces."""
        # Save or create environment options
        self.env_options = self.env_options or options.GridworldOptions()
        # Load map data
        self.map_data = map_data.GridworldMapData(self.env_options.map_path)
        # Initialize renderer
        self.renderer = renderer.GridworldRenderer(
            self.env_options.map_path, self, self.renderer_options)
        # Load the renderer
        self.renderer.load()
        # Initialize observation space
        self._init_obs_space()
        # Initialize action space (attribute name is misspelled in options)
        self.action_space = self.env_options.action_sapce

    def _init_object_indexes(self):
        """Assign consecutive global indexes to the objects of each group."""
        # Rebind both lookups so nothing is shared between instances and no
        # stale entry survives a change of the groups
        self.object_indexes = {}
        self.reverse_object_indexes = {}
        global_index = 0
        # Iterate each group
        for group_index, group_name in enumerate(self.env_options.group_names):
            group_indexes = {}
            group_size = self.env_options.group_sizes[group_index]
            # Iterate each local object
            for local_index in range(group_size):
                group_indexes[local_index] = global_index
                self.reverse_object_indexes[global_index] = [
                    group_name, local_index]
                global_index += 1
            self.object_indexes[group_name] = group_indexes
        # Save the total object number
        self.total_object_num = global_index

    def _init_obs_space(self):
        """Declare the observation space of the RGB screenshots.

        The observations are the renderer screenshots, i.e. arrays with the
        display axes swapped to (height, width, 3). A Box of 8-bit values
        describes them; a MultiDiscrete over [width, height, 3] would only
        describe 3-element vectors and never contain a screenshot.
        """
        display_size = self.renderer.get_display_size()
        width = int(display_size[0])
        height = int(display_size[1])
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(height, width, 3), dtype=np.uint8)

    ############################################################################
    # Observation Retrieval
    ############################################################################

    def _get_obs(self):
        """Render the current state and return the screenshot."""
        # Render
        self.renderer.render()
        # Return renderer screenshot
        return self.renderer.get_screenshot()
144 |
145 |
def index_2d_to_1d(pos, width):
    """Convert a 2D position to the index of a row-major flattened map.

    Args:
        pos: The position, unpackable as (x, y).
        width: The number of cells in a row.

    Returns:
        The Python scalar ``width * y + x``.
    """
    px, py = pos
    # np.asscalar() has been removed from Numpy; item() is its replacement
    # and asarray() lets plain Python numbers through as well
    return np.asarray((width * py) + px).item()
149 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/gridworld/map_data.py:
--------------------------------------------------------------------------------
1 | # User-defined modules
2 | import pygame_rl.renderer.pygame_renderer as pygame_renderer
3 |
4 |
class GridworldMapData:
    """Geographical information read from the map file."""
    # Tile positions as returned by TiledData.get_tile_positions()
    tile_pos = None

    def __init__(self, map_path):
        """Load the tiled data at map_path and keep its tile positions."""
        loader = pygame_renderer.TiledData(map_path)
        loader.load()
        self.tile_pos = loader.get_tile_positions()
17 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/gridworld/options.py:
--------------------------------------------------------------------------------
1 | # Native modules
2 | import copy
3 |
4 | # Third-party modules
5 | import gym
6 | import numpy as np
7 |
8 | # User-defined modules
9 | import pygame_rl.util.file_util as file_util
10 |
11 |
class GridworldOptions:
    """Environment options.

    Bundles the map path, the action space, the step and reset callbacks and
    the object groups read by the gridworld environments.
    """
    # Internal map resource name
    map_resource_name = 'pygame_rl/data/map/gridworld/gridworld.tmx'
    # Map path
    map_path = None
    # Action space. The misspelled name is kept on purpose: the environments
    # read the attribute by this name.
    action_sapce = None
    # Callback to step
    step_callback = None
    # Callback to reset state
    reset_callback = None
    # Group names (always replaced per instance in __init__)
    group_names = []
    # Group sizes (always replaced per instance in __init__)
    group_sizes = []

    def __init__(self, map_path=None, action_space=None,
                 step_callback=None, reset_callback=None):
        """Store the given options, falling back to the defaults.

        Args:
            map_path: Path of the map file; the internal resource is used
                when it is empty.
            action_space: Gym action space; Discrete(5) when None.
            step_callback: Called as f(prev_state, action, random_state=...)
                and returning (state, reward, done, info); a default one is
                installed when None.
            reset_callback: Called as f(random_state=...) and returning the
                initial state; a default one is installed when None.
        """
        self._init_map_path(map_path)
        self._init_action_space(action_space)
        self._init_step_callback(step_callback)
        self._init_reset_callback(reset_callback)
        self._init_default_group()

    def set_group(self, group_names, group_sizes):
        """Replace the groups.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(group_names) != len(group_sizes):
            raise ValueError('Length of group names and sizes should be the '
                             'same')
        self.group_names = group_names
        self.group_sizes = group_sizes

    def _init_map_path(self, map_path):
        if map_path:
            self.map_path = map_path
        else:
            self.map_path = file_util.get_resource_path(self.map_resource_name)

    def _init_action_space(self, action_space):
        # Compare against None: a valid space object may be falsy
        if action_space is not None:
            self.action_sapce = action_space
        else:
            # 4-directional walk and stand still
            self.action_sapce = gym.spaces.Discrete(5)

    def _init_default_group(self):
        self.group_names = [
            'PLAYER1',
            'PLAYER2',
            'PLAYER3',
            'GOAL',
            'OBSTACLE1',
            'OBSTACLE2',
        ]
        self.group_sizes = [
            1,
            1,
            1,
            3,
            5,
            1,
        ]

    def _init_step_callback(self, step_callback):
        """Install the step callback, building the default one if needed.

        The default callback moves the first 'PLAYER1' object on a 9x9 grid,
        rejects moves that leave the grid or hit another player or an
        obstacle, and finishes with reward 1.0 once the player stands on a
        'GOAL' position.
        """
        def default_callback(prev_state, action, random_state):
            del random_state
            state = copy.deepcopy(prev_state)
            # Get player 1 position
            pos = prev_state['PLAYER1'][0]
            # Get new position
            new_pos = get_new_pos(pos, action)
            # Update state
            if is_valid_pos(new_pos, prev_state):
                state['PLAYER1'][0] = new_pos
            # Evaluate the position after the move; using the position
            # before the move would report the goal one step late
            done = is_done(state['PLAYER1'][0], state)
            reward = 1.0 if done else 0.0
            info = {}
            return state, reward, done, info

        def get_new_pos(pos, action):
            new_pos = np.array(pos)
            if action == 0:  # Move right
                new_pos[0] += 1
            elif action == 1:  # Move up
                new_pos[1] -= 1
            elif action == 2:  # Move left
                new_pos[0] -= 1
            elif action == 3:  # Move down
                new_pos[1] += 1
            elif action == 4:  # Stand still
                pass
            else:
                raise ValueError('Unknown action: {}'.format(action))
            return new_pos

        def is_valid_pos(pos, prev_state):
            # The default callback assumes a 9x9 map
            in_bound = 0 <= pos[0] < 9 and 0 <= pos[1] < 9
            collision_group_names = [
                'PLAYER2',
                'PLAYER3',
                'OBSTACLE1',
                'OBSTACLE2',
            ]
            no_collision = not check_collision(
                pos, collision_group_names, prev_state)
            return in_bound and no_collision

        def is_done(pos, state):
            collision_group_names = [
                'GOAL',
            ]
            return check_collision(pos, collision_group_names, state)

        def check_collision(pos, collision_group_names, state):
            # The groups are read at call time, so set_group() is honoured
            for group_index, group_name in enumerate(self.group_names):
                if group_name not in collision_group_names:
                    continue
                for local_index in range(self.group_sizes[group_index]):
                    other_pos = state[group_name][local_index]
                    if np.array_equal(pos, other_pos):
                        return True
            return False

        if step_callback is not None:
            self.step_callback = step_callback
        else:
            self.step_callback = default_callback

    def _init_reset_callback(self, reset_callback):
        """Install the reset callback, building the default one if needed."""
        def default_callback(random_state):
            del random_state
            return {
                'PLAYER1': np.asarray([
                    np.array([0, 0]),
                ]),
                'PLAYER2': np.asarray([
                    np.array([8, 0]),
                ]),
                'PLAYER3': np.asarray([
                    np.array([0, 8]),
                ]),
                'GOAL': np.asarray([
                    np.array([8, 4]),
                    np.array([4, 8]),
                    np.array([8, 8]),
                ]),
                'OBSTACLE1': np.asarray([
                    np.array([4, 3]),
                    np.array([3, 4]),
                    np.array([4, 4]),
                    np.array([5, 4]),
                    np.array([4, 5]),
                ]),
                'OBSTACLE2': np.asarray([
                    np.array([4, 4]),
                ]),
            }

        if reset_callback is not None:
            self.reset_callback = reset_callback
        else:
            self.reset_callback = default_callback
176 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/gridworld/renderer.py:
--------------------------------------------------------------------------------
1 | # Third-party modules
2 | import pygame
3 | import pygame.locals
4 |
5 | # User-defined modules
6 | import pygame_rl.renderer.pygame_renderer as pygame_renderer
7 |
8 |
class GridworldRenderer(pygame_renderer.TiledRenderer):
    """Gridworld renderer.

    Draws one moving overlay sprite per environment object on top of the
    background of the tiled map.
    """
    # Window caption
    title = 'Gridworld'

    # Environment being rendered
    env = None

    # Renderer options
    renderer_options = None

    # Whether the Pygame display has already been closed
    display_quitted = False

    # Overlays read from the map (by name) and their per-object copies
    static_overlays = None
    moving_overlays = None

    # Clock object (pygame.time.Clock())
    clock = None

    # Dirty groups (pygame.sprite.RenderUpdates)
    dirty_groups = None

    def __init__(self, map_path, env, renderer_options=None):
        """Remember the environment and the (possibly default) options."""
        super().__init__(map_path)
        self.env = env
        if renderer_options:
            self.renderer_options = renderer_options
        else:
            self.renderer_options = RendererOptions()

    def load(self):
        """Open the display, load the map and draw the background."""
        # A display mode is set before the map is loaded
        pygame.display.init()
        pygame.display.set_mode([400, 300])
        pygame.display.set_caption(self.title)

        # Load the tiled map through the base renderer
        super().load()

        # Resize the screen to the size of the map
        self.screen = pygame.display.set_mode(super().get_display_size())

        # Cache the background and the overlays of the map
        self.background = super().get_background()
        self.static_overlays = super().get_overlays()

        # Paint the background and show it when the display is enabled
        self.screen.blit(self.background, [0, 0])
        if self.renderer_options.show_display:
            pygame.display.flip()

        # Clock used to limit the frame rate
        self.clock = pygame.time.Clock()

    def render(self):
        """Draw the current environment state.

        Returns:
            bool: False when a quit event was received, True otherwise.
        """
        options = self.renderer_options

        # When the display is disabled, switch once to an in-memory copy of
        # the screen and close the display
        if not (self.display_quitted or options.show_display):
            self.screen = self.screen.copy()
            pygame.display.quit()
            self.display_quitted = True

        # Erase the overlays, move them to the state positions and redraw
        self.dirty_groups.clear(self.screen, self.background)
        self._update_overlay_pos()
        dirty = self.dirty_groups.draw(self.screen)

        # Everything below only concerns a visible display
        if not options.show_display:
            return True

        # Update only the dirty surface and limit the frame rate
        pygame.display.update(dirty)
        self.clock.tick(options.max_fps)

        # Arrow keys and "s" are translated to the actions 0 to 4
        key_to_action = {
            pygame.locals.K_RIGHT: 0,
            pygame.locals.K_UP: 1,
            pygame.locals.K_LEFT: 2,
            pygame.locals.K_DOWN: 3,
            pygame.locals.K_s: 4,
        }
        for event in pygame.event.get():
            if event.type == pygame.locals.QUIT:
                # Indicate the rendering should stop
                return False
            if (options.enable_key_events
                    and event.type == pygame.locals.KEYDOWN):
                action = key_to_action.get(event.key)
                if action is not None:
                    self.env.step(action)

        # Indicate the rendering should continue
        return True

    def reset(self):
        """Drop the previous overlays and create new ones from the env."""
        previous = self.dirty_groups
        if previous:
            # Remove all sprites, then erase what they had drawn
            previous.empty()
            previous.clear(self.screen, self.background)
            dirty = previous.draw(self.screen)
            if self.renderer_options.show_display:
                pygame.display.update(dirty)

        # Recreate the moving overlays and the group that draws them
        self._load_moving_overlays()
        self._load_dirty_group()

    def _load_moving_overlays(self):
        """Copy the static overlay of each group once per object.

        The copies are appended in the iteration order of the environment's
        object indexes; _update_overlay_pos() indexes this list by global
        index.
        """
        self.moving_overlays = []
        for group_name, local_to_global in self.env.object_indexes.items():
            template = self.static_overlays[group_name]
            self.moving_overlays.extend(
                copy_static_overlay(template) for _ in local_to_global)

    def _load_dirty_group(self):
        """Put all moving overlays into a new RenderUpdates group."""
        group = pygame.sprite.RenderUpdates()
        group.add(self.moving_overlays)
        self.dirty_groups = group

    def _update_overlay_pos(self):
        """Move every overlay to the position held in the env state."""
        indexes = self.env.object_indexes
        for group_name, positions in self.env.state.items():
            for local_index, pos in enumerate(positions):
                overlay = self.moving_overlays[indexes[group_name][local_index]]
                overlay.set_pos(pos)
161 |
162 |
class RendererOptions(object):
    """Renderer options.

    Attributes:
        show_display: Whether the renderer shows the Pygame display.
        max_fps: Value passed to the clock tick when the display is shown.
        enable_key_events: Whether key presses are translated to actions.
    """
    show_display = False
    max_fps = 0
    enable_key_events = False

    def __init__(self, show_display=False, max_fps=0, enable_key_events=False):
        (self.show_display,
         self.max_fps,
         self.enable_key_events) = (show_display, max_fps, enable_key_events)
174 |
175 |
def copy_static_overlay(static_overlay):
    """Create a new OverlaySprite from the fields of an existing one.

    The image, position and tile dimension objects are passed on as they
    are, not copied.
    """
    image = static_overlay.image
    pos = static_overlay.pos
    tile_dim = static_overlay.tile_dim
    return pygame_renderer.OverlaySprite(image, pos, tile_dim)
179 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/predator_prey_renderer.py:
--------------------------------------------------------------------------------
1 | # Third-party modules
2 | import pygame
3 | import pygame.locals
4 |
5 | # User-defined modules
6 | import pygame_rl.renderer.pygame_renderer as pygame_renderer
7 |
8 |
9 | class PredatorPreyRenderer(pygame_renderer.TiledRenderer):
10 | """Predator-prey renderer.
11 | """
12 | # Constants
13 | title = 'Predator-prey'
14 |
15 | # Environment
16 | env = None
17 |
18 | # Renderer options
19 | renderer_options = None
20 |
21 | # Display state
22 | display_quitted = False
23 |
24 | # TMX objects
25 | static_overlays = None
26 | moving_overlays = None
27 |
28 | # Clock object (pygame.time.Clock())
29 | clock = None
30 |
31 | # Dirty groups (pygame.sprite.RenderUpdates)
32 | dirty_groups = None
33 |
def __init__(self, map_path, env, renderer_options=None):
    """Remember the environment and the (possibly default) options.

    Args:
        map_path: Path of the map file, passed to the base renderer.
        env: The environment to render.
        renderer_options: The renderer options; a RendererOptions with
            default values is created when it is falsy.
    """
    super().__init__(map_path)
    self.env = env
    if renderer_options:
        self.renderer_options = renderer_options
    else:
        self.renderer_options = RendererOptions()
40 |
def load(self):
    """Open the display, load the map and prepare all overlays."""
    # A display mode is set before the map is loaded
    pygame.display.init()
    pygame.display.set_mode([400, 300])
    pygame.display.set_caption(self.title)

    # Load the tiled map through the base renderer
    super().load()

    # Resize the screen to the size of the map
    self.screen = pygame.display.set_mode(super().get_display_size())

    # Cache the background and the overlays of the map
    self.background = super().get_background()
    self.static_overlays = super().get_overlays()

    # Create the moving overlays and the group that draws them
    self._load_moving_overlays()
    self._load_dirty_group()

    # Paint the background and show it when the display is enabled
    self.screen.blit(self.background, [0, 0])
    if self.renderer_options.show_display:
        pygame.display.flip()

    # Clock used to limit the frame rate
    self.clock = pygame.time.Clock()
75 |
76 | def render(self):
77 | # Close the display if the renderer options is set to disable the
78 | # display
79 | if not self.display_quitted and not self.renderer_options.show_display:
80 | # Replace the screen surface with in-memory surface
81 | self.screen = self.screen.copy()
82 | # Close the display
83 | pygame.display.quit()
84 | # Prevent from further closing
85 | self.display_quitted = True
86 |
87 | # Clear the overlays
88 | self.dirty_groups.clear(self.screen, self.background)
89 |
90 | # Update the overlays by the environment state
91 | self._update_overlay_pos()
92 | self._update_overlay_visibility()
93 |
94 | # Draw the overlays
95 | dirty = self.dirty_groups.draw(self.screen)
96 |
97 | # Update only the dirty surface
98 | if self.renderer_options.show_display:
99 | pygame.display.update(dirty)
100 |
101 | # Limit the max frames per second
102 | if self.renderer_options.show_display:
103 | self.clock.tick(self.renderer_options.max_fps)
104 |
105 | # Handle the events
106 | if self.renderer_options.show_display:
107 | for event in pygame.event.get():
108 | # Detect the quit event
109 | if event.type == pygame.locals.QUIT:
110 | # Indicate the rendering should stop
111 | return False
112 | # Detect the keydown event
113 | if self.renderer_options.enable_key_events:
114 | if event.type == pygame.locals.KEYDOWN:
115 | if event.key == pygame.locals.K_RIGHT:
116 | self.env.take_cached_action(0, 'MOVE_RIGHT')
117 | elif event.key == pygame.locals.K_UP:
118 | self.env.take_cached_action(0, 'MOVE_UP')
119 | elif event.key == pygame.locals.K_LEFT:
120 | self.env.take_cached_action(0, 'MOVE_LEFT')
121 | elif event.key == pygame.locals.K_DOWN:
122 | self.env.take_cached_action(0, 'MOVE_DOWN')
123 | elif event.key == pygame.locals.K_s:
124 | self.env.take_cached_action(0, 'STAND')
125 | # Update the state
126 | self.env.update_state()
127 |
128 | # Indicate the rendering should continue
129 | return True
130 |
131 | def _load_moving_overlays(self):
132 | self.moving_overlays = []
133 | for group_name in self.env.group_names:
134 | static_overlay = self.static_overlays[group_name]
135 | object_index = self.env.get_group_index_range(group_name)
136 | for _ in range(*object_index):
137 | moving_overlay = pygame_renderer.OverlaySprite(
138 | static_overlay.image, static_overlay.pos,
139 | static_overlay.tile_dim)
140 | self.moving_overlays.append(moving_overlay)
141 |
142 | def _load_dirty_group(self):
143 | self.dirty_groups = pygame.sprite.RenderUpdates()
144 | self.dirty_groups.add(self.moving_overlays)
145 |
146 | def _update_overlay_pos(self):
147 | for object_index in range(self.env.options.get_total_object_size()):
148 | pos = self.env.state.get_object_pos(object_index)
149 | self.moving_overlays[object_index].set_pos(pos)
150 |
151 | def _update_overlay_visibility(self):
152 | for object_index in range(self.env.options.get_total_object_size()):
153 | availability = self.env.state.get_object_availability(object_index)
154 | if not availability:
155 | self.dirty_groups.remove(self.moving_overlays[object_index])
156 |
157 |
class RendererOptions(object):
    """Options consumed by the renderer.

    Attributes are read by the renderer as follows: ``show_display`` gates
    the display updates and the event handling, ``max_fps`` is passed to the
    clock tick, and ``enable_key_events`` gates the keydown handling.
    """
    # Class-level defaults
    show_display = False
    max_fps = 0
    enable_key_events = False

    def __init__(self, show_display=False, max_fps=0, enable_key_events=False):
        # Save the options on the instance
        (self.show_display,
         self.max_fps,
         self.enable_key_events) = (show_display, max_fps, enable_key_events)
169 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/soccer/__init__.py:
--------------------------------------------------------------------------------
1 | # Third-party modules
2 | import gym.envs.registration as env_reg
3 |
4 |
# Register environments
# The ID 'soccer-v0' is registered through gym.envs.registration with an entry
# point naming the SoccerV0 class in pygame_rl.scenario.soccer.envs
env_reg.register(
    id='soccer-v0',
    entry_point='pygame_rl.scenario.soccer.envs:SoccerV0',
)
10 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/soccer/actions.py:
--------------------------------------------------------------------------------
1 | # Native modules
2 | from enum import IntEnum
3 |
4 |
class Actions(IntEnum):
    """Actions of an agent in the soccer scenario.
    """
    # Indicates that rule-based AI will override the action
    NOOP = 0
    # Movements in the four directions
    MOVE_RIGHT = 1
    MOVE_UP = 2
    MOVE_LEFT = 3
    MOVE_DOWN = 4
    # Stand still
    STAND = 5
13 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/soccer/agent_modes.py:
--------------------------------------------------------------------------------
1 | # Native modules
2 | from enum import IntEnum
3 |
4 |
class AgentModes(IntEnum):
    """Modes of an agent in the soccer scenario.
    """
    DEFENSIVE = 0
    OFFENSIVE = 1
8 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/soccer/ai_modes.py:
--------------------------------------------------------------------------------
1 | # Native modules
2 | from enum import IntEnum
3 |
4 |
class AiModes(IntEnum):
    """Modes of the AI in the soccer scenario.
    """
    APPROACH = 0
    AVOID = 1
    INTERCEPT = 2
9 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/soccer/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from pygame_rl.scenario.soccer.envs.soccer_v0 import SoccerV0
2 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/soccer/map_data.py:
--------------------------------------------------------------------------------
1 | # Project modules
2 | from pygame_rl.renderer.pygame_renderer import TiledData
3 |
4 |
class MapData(object):
    """Geographical information of the soccer map.

    The data is read from the tiled map once, at construction time.
    """
    # Size of the map
    map_size = None
    # Positions of the tiles by category
    spawn = []
    goals = []
    walkable = []

    def __init__(self, map_path):
        # Load the tiled map
        loader = TiledData(map_path)
        loader.load()
        # Remember the size of the map
        self.map_size = loader.get_map_size()
        # Pick the tile positions of interest from the background layers
        positions = loader.get_tile_positions()
        self.spawn = positions['spawn_area']
        self.goals = positions['goal']
        self.walkable = positions['ground']['WALKABLE']
27 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/soccer/options.py:
--------------------------------------------------------------------------------
1 | # User-defined modules
2 | import pygame_rl.util.file_util as file_util
3 |
4 |
class Options(object):
    """Configuration of the soccer environment.
    """
    # Name of the bundled map resource, used when no map path is given
    map_resource_name = 'pygame_rl/data/map/soccer/soccer.tmx'

    # Path of the map
    map_path = None

    # Number of agents in each team
    team_size = 1

    # Frame skip for AI
    ai_frame_skip = 1

    def __init__(self, map_path=None, team_size=1, ai_frame_skip=1):
        # Prefer the given map path, fall back to the internal resource
        self.map_path = map_path or file_util.get_resource_path(
            self.map_resource_name)
        # Remember the remaining options
        self.team_size = team_size
        self.ai_frame_skip = ai_frame_skip

    @property
    def agent_size(self):
        """Total number of agents, i.e. twice the team size."""
        return 2 * self.team_size
34 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/soccer/renderer.py:
--------------------------------------------------------------------------------
1 | # Third-party modules
2 | import pygame
3 | import pygame.locals
4 |
5 | # User-defined modules
6 | from pygame_rl.renderer.pygame_renderer import TiledRenderer
7 | from pygame_rl.scenario.soccer.renderer_options import RendererOptions
8 |
9 |
class Renderer(TiledRenderer):
    """Renderer of the soccer scenario.

    Every agent has two static overlays, one with and one without the ball;
    only the overlay matching the current ball possession is kept in the
    drawn sprite group.
    """
    # Constants
    title = 'Soccer'

    # Environment
    env = None

    # Renderer options
    renderer_options = None

    # Display state
    display_quitted = False

    # TMX objects
    static_overlays = None

    # Clock object (pygame.time.Clock())
    clock = None

    # Dirty groups (pygame.sprite.RenderUpdates)
    dirty_groups = None

    # Previous ball state
    prev_ball_state = None

    def __init__(self, map_path, env, renderer_options=None):
        super().__init__(map_path)
        # Keep the environment to read its state while rendering
        self.env = env
        # Fall back to the default options when none are given
        self.renderer_options = renderer_options or RendererOptions()

    def load(self):
        """Set up Pygame, the screen, the overlays and the clock."""
        # Bring up the display with a temporary size
        pygame.display.init()
        pygame.display.set_mode([400, 300])
        pygame.display.set_caption(self.title)

        # Let the base renderer load the map
        super().load()

        # Resize the screen to the size of the map
        self.screen = pygame.display.set_mode(super().get_display_size())

        # Fetch the background and the static overlays
        self.background = super().get_background()
        self.static_overlays = super().get_overlays()

        # Prepare the ball state tracking and the sprite group
        self._init_prev_ball_state()
        self._load_dirty_group()

        # Paint the background
        self.screen.blit(self.background, [0, 0])

        # Refresh the whole display
        if self.renderer_options.show_display:
            pygame.display.flip()

        # Create the clock
        self.clock = pygame.time.Clock()

        # Without a display, keep rendering to an in-memory copy of the
        # screen and close the display exactly once
        if not (self.display_quitted or self.renderer_options.show_display):
            self.screen = self.screen.copy()
            pygame.display.quit()
            self.display_quitted = True

    def render(self):
        """Draw one frame and process the pending events.

        Returns:
            bool: False when a quit event was received, True otherwise.
        """
        options = self.renderer_options

        # Erase the overlays drawn in the previous frame
        self.dirty_groups.clear(self.screen, self.background)

        # Synchronize the overlays with the environment state
        self._update_overlay_pos()
        self._update_overlay_visibility()

        # Draw the overlays
        dirty = self.dirty_groups.draw(self.screen)

        # The remaining steps only apply to a visible display
        if not options.show_display:
            return True

        # Refresh the changed areas and limit the frame rate
        pygame.display.update(dirty)
        self.clock.tick(options.max_fps)

        # Keys recognized as actions of the first player
        key_actions = {
            pygame.locals.K_RIGHT: 'MOVE_RIGHT',
            pygame.locals.K_UP: 'MOVE_UP',
            pygame.locals.K_LEFT: 'MOVE_LEFT',
            pygame.locals.K_DOWN: 'MOVE_DOWN',
            pygame.locals.K_s: 'STAND',
        }

        for event in pygame.event.get():
            # A quit event stops the rendering
            if event.type == pygame.locals.QUIT:
                return False
            # Only keydown events are of interest, and only when enabled
            if not options.enable_key_events:
                continue
            if event.type != pygame.locals.KEYDOWN:
                continue
            # Get the agent index of the first player
            agent_index = self.env.get_agent_index('PLAYER', 0)
            # Take the cached action and update the state
            cached_action = key_actions.get(event.key)
            if cached_action:
                self.env.take_cached_action(agent_index, cached_action)
                self.env.update_state()

        # The rendering should continue
        return True

    def _init_prev_ball_state(self):
        # Nothing is known about the ball possession yet
        self.prev_ball_state = [None] * self.env.options.agent_size

    def _load_dirty_group(self):
        self.dirty_groups = pygame.sprite.RenderUpdates()

    def _update_overlay_pos(self):
        for agent_index in range(self.env.options.agent_size):
            with_ball, without_ball = self._get_overlays(agent_index)
            has_ball = self.env.state.get_agent_ball(agent_index)
            agent_pos = self.env.state.get_agent_pos(agent_index)
            # Only the overlay matching the ball possession is moved
            active_overlay = with_ball if has_ball else without_ball
            active_overlay.set_pos(agent_pos)

    def _update_overlay_visibility(self):
        for agent_index in range(self.env.options.agent_size):
            with_ball, without_ball = self._get_overlays(agent_index)
            has_ball = self.env.state.get_agent_ball(agent_index)
            prev_has_ball = self.prev_ball_state[agent_index]
            # Skip the agents whose ball possession is known and unchanged
            if prev_has_ball is not None and prev_has_ball == has_ball:
                continue
            # Swap the drawn sprite for the one matching the possession
            if has_ball:
                hidden, shown = without_ball, with_ball
            else:
                hidden, shown = with_ball, without_ball
            self.dirty_groups.remove(hidden)
            self.dirty_groups.add(shown)
            # Remember the ball possession
            self.prev_ball_state[agent_index] = has_ball

    def _get_overlays(self, agent_index):
        # Overlay names are one-based
        agent_number = agent_index + 1
        overlay_has_ball = self.static_overlays[
            'AGENT{}_BALL'.format(agent_number)]
        overlay_no_ball = self.static_overlays[
            'AGENT{}'.format(agent_number)]
        return [overlay_has_ball, overlay_no_ball]
189 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/soccer/renderer_options.py:
--------------------------------------------------------------------------------
class RendererOptions(object):
    """Options consumed by the soccer renderer.
    """
    # Class-level defaults
    show_display = False
    max_fps = 0
    enable_key_events = False

    def __init__(self, show_display=False, max_fps=0, enable_key_events=False):
        # Save the options on the instance
        (self.show_display,
         self.max_fps,
         self.enable_key_events) = (show_display, max_fps, enable_key_events)
12 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/soccer/state.py:
--------------------------------------------------------------------------------
1 | # Third-party modules
2 | import numpy as np
3 |
4 | # Project modules
5 | from pygame_rl.scenario.soccer.actions import Actions
6 | from pygame_rl.scenario.soccer.agent_modes import AgentModes
7 | from pygame_rl.scenario.soccer.teams import Teams
8 |
9 |
class State(object):
    """The internal soccer state.

    Keeps the per-agent statuses, a map from occupied positions to agent
    indexes, and the time step.
    """
    # Agent statuses as a list
    # * pos: Positions
    # * ball: Possession of the ball
    # * mode: Mode for the agent
    # * action: Last taken action for the agent
    # * frame_skip_index: Current frame skipping index, starting from 0,
    # resetting after it reaches the frame skip
    agent_list = []

    # Position to object index map
    pos_map = None

    # Time step
    time_step = 0

    # Soccer environment
    env = None

    # Soccer environment options
    env_options = None

    # Map data
    map_data = None

    # Random state
    random_state = None

    def __init__(self, env, env_options, map_data, random_state):
        """Save the collaborators and reset the state.

        Args:
            env: The soccer environment, used to convert between team and
                agent indexes.
            env_options: The environment options (``team_size`` and
                ``agent_size`` are read).
            map_data: The map data (``spawn``, ``goals`` and ``walkable``
                are read).
            random_state: Random generator providing ``randint``.
        """
        self.env = env
        self.env_options = env_options
        self.map_data = map_data
        self.random_state = random_state
        self.reset()

    def update_random_state(self, random_state):
        """Replace the random state."""
        self.random_state = random_state

    def reset(self):
        """Reset and randomize the agent statuses and zero the time step."""
        # Initialize the agent list
        self._reset_agent_list()
        # Reset position map
        self._reset_pos_map()
        # Randomize the agent statuses
        self.randomize()
        # Initialize the time step
        self.time_step = 0

    def randomize(self):
        """Randomize the positions, the ball possession and the modes."""
        # Choose a random agent in a random team to possess the ball
        rand_idx = self.random_state.randint(len(Teams))
        team_has_ball = Teams(rand_idx)
        team_agent_has_ball = self.random_state.randint(
            self.env_options.team_size)
        # Set the properties for each team and each agent
        for team_name in Teams:
            spawn_list = self.map_data.spawn[team_name.name]
            for team_agent_index in range(self.env_options.team_size):
                # Get the agent index
                agent_index = self.env.get_agent_index(
                    team_name, team_agent_index)
                # Randomize the agent positions until a free one is found
                found_pos = False
                while not found_pos:
                    rand_idx = self.random_state.randint(len(spawn_list))
                    agent_pos = spawn_list[rand_idx]
                    if not self.get_pos_status(agent_pos):
                        self.set_agent_pos(agent_index, agent_pos)
                        found_pos = True
                # Randomize the possession of the ball
                set_ball = (team_name == team_has_ball
                            and team_agent_index == team_agent_has_ball)
                self.set_agent_ball(agent_index, bool(set_ball))
                # Randomize the agent mode
                rand_idx = self.random_state.randint(len(AgentModes))
                self.set_agent_mode(agent_index, AgentModes(rand_idx))
                # Reset the action
                self.set_agent_action(agent_index, Actions.STAND)

    def is_terminal(self):
        """Return whether the time is up or some agent has won."""
        # When the time step exceeds 100
        if self.time_step >= 100:
            return True
        # When one of the agent reaches the goal
        return any(self.is_agent_win(agent_index)
                   for agent_index in range(self.env_options.agent_size))

    def is_team_win(self, team_name):
        """Return whether any agent of the team has won."""
        for team_agent_index in range(self.env_options.team_size):
            agent_index = self.env.get_agent_index(team_name, team_agent_index)
            if self.is_agent_win(agent_index):
                return True
        return False

    def is_agent_win(self, agent_index):
        """Return whether the agent holds the ball in its team's goal area."""
        # Agent cannot win if he doesn't possess the ball
        if not self.get_agent_ball(agent_index):
            return False
        # Get the team of the agent
        team_name = self._get_agent_team(agent_index)
        # Check whether the position is in the goal area
        agent_pos = self.get_agent_pos(agent_index)
        return agent_pos in self.map_data.goals[team_name.name]

    def get_gym_state(self, map_size):
        """Build the observation as a dict of NumPy arrays.

        Args:
            map_size: Shape of the 2D map array.

        Returns:
            dict: ``map`` (0 by default, 1 walkable, 2 player goal,
            3 computer goal), ``agent_pos``, ``relative`` (positions of the
            goals and of the other agents relative to each agent), ``ball``,
            ``mode`` and ``action``.
        """
        agent_size = len(self.agent_list)
        player_goals = self.map_data.goals['PLAYER']
        computer_goals = self.map_data.goals['COMPUTER']
        map_2d = np.zeros(map_size)
        agent_pos_list = np.zeros((agent_size, 2))
        rel_player_goals = np.zeros((agent_size, len(player_goals), 2))
        rel_computer_goals = np.zeros((agent_size, len(computer_goals), 2))
        rel_other_agent_pos = np.zeros((agent_size, agent_size - 1, 2))
        ball_list = np.zeros(agent_size)
        mode_list = np.zeros(agent_size)
        action_list = np.zeros(agent_size)
        for pos in self.map_data.walkable:
            # Walkable
            map_2d[tuple(pos)] = 1
        for pos in player_goals:
            # Player goal
            map_2d[tuple(pos)] = 2
        for pos in computer_goals:
            # Computer goal
            map_2d[tuple(pos)] = 3
        for idx, agent in enumerate(self.agent_list):
            # Agent
            ball_list[idx] = self.get_agent_ball(idx)
            mode_list[idx] = self.get_agent_mode(idx)
            action_list[idx] = self.get_agent_action(idx)
            # Agent position
            agent_pos = agent['pos']
            agent_pos_list[idx, :] = agent_pos
            # Other agent positions, one slot per other agent in index order
            # (writing to the whole row would keep only the last agent)
            other_slot = 0
            for other_idx, other_agent in enumerate(self.agent_list):
                if other_idx == idx:
                    continue
                rel_other_agent_pos[idx, other_slot, :] = self.get_rel_pos(
                    agent_pos, other_agent['pos'])
                other_slot += 1
            # Relative goal positions
            for pos_idx, goal in enumerate(player_goals):
                rel_player_goals[idx, pos_idx, :] = self.get_rel_pos(
                    agent_pos, goal)
            for pos_idx, goal in enumerate(computer_goals):
                rel_computer_goals[idx, pos_idx, :] = self.get_rel_pos(
                    agent_pos, goal)
        return {
            'map': map_2d,
            'agent_pos': agent_pos_list,
            'relative': {
                'player_goals': rel_player_goals,
                'computer_goals': rel_computer_goals,
                'other_agent_pos': rel_other_agent_pos,
            },
            'ball': ball_list,
            'mode': mode_list,
            'action': action_list,
        }

    def get_agent_pos(self, agent_index):
        return self.agent_list[agent_index]['pos']

    def set_agent_pos(self, agent_index, pos):
        """Set the agent position and keep the position map in sync.

        Args:
            agent_index: The agent index.
            pos: The new position, or None to leave the agent unplaced.
        """
        # Get old position
        old_pos = self.agent_list[agent_index].get('pos', None)
        # Remove old position from map
        if old_pos is not None:
            self.pos_map.pop(tuple(old_pos), None)
        # Set position in map
        if pos is not None:
            self.pos_map[tuple(pos)] = agent_index
        # Set the new position
        self.agent_list[agent_index]['pos'] = pos

    def get_agent_ball(self, agent_index):
        return self.agent_list[agent_index]['ball']

    def set_agent_ball(self, agent_index, has_ball):
        self.agent_list[agent_index]['ball'] = has_ball

    def get_agent_mode(self, agent_index):
        return self.agent_list[agent_index]['mode']

    def set_agent_mode(self, agent_index, mode):
        self.agent_list[agent_index]['mode'] = mode

    def get_agent_action(self, agent_index):
        return self.agent_list[agent_index]['action']

    def set_agent_action(self, agent_index, action):
        self.agent_list[agent_index]['action'] = action

    def get_agent_frame_skip_index(self, agent_index):
        return self.agent_list[agent_index]['frame_skip_index']

    def set_agent_frame_skip_index(self, agent_index, frame_skip_index):
        self.agent_list[agent_index]['frame_skip_index'] = frame_skip_index

    def get_pos_status(self, pos):
        """Describe the agent occupying a position.

        Args:
            pos: The position to look up.

        Returns:
            dict or None: ``team_name``, ``team_agent_index`` and
            ``agent_index`` of the occupying agent, or None when the
            position is free.
        """
        agent_index = self.pos_map.get(tuple(pos), None)
        # Compare against None explicitly: agent index 0 is a valid agent
        if agent_index is None:
            return None
        return {
            'team_name': self._get_agent_team(agent_index),
            'team_agent_index': self.env.get_team_agent_index(agent_index),
            'agent_index': agent_index,
        }

    def get_ball_possession(self):
        """Describe the agent holding the ball, or None if nobody does."""
        for team_name in Teams:
            for team_agent_index in range(self.env_options.team_size):
                agent_index = self.env.get_agent_index(
                    team_name, team_agent_index)
                if self.get_agent_ball(agent_index):
                    return {
                        'team_name': team_name,
                        'team_agent_index': team_agent_index,
                        'agent_index': agent_index,
                    }
        return None

    def switch_ball(self, agent_index, other_agent_index):
        """Invert the agent's possession and give its old one to the other."""
        agent_ball = self.get_agent_ball(agent_index)
        self.set_agent_ball(agent_index, not agent_ball)
        self.set_agent_ball(other_agent_index, agent_ball)

    def increase_frame_skip_index(self, agent_index, frame_skip):
        """Advance the frame skip index, wrapping around at frame_skip."""
        old_frame_skip_index = self.get_agent_frame_skip_index(agent_index)
        new_frame_skip_index = (old_frame_skip_index + 1) % frame_skip
        self.set_agent_frame_skip_index(agent_index, new_frame_skip_index)

    def increase_time_step(self):
        self.time_step += 1

    @staticmethod
    def get_rel_pos(ref, target):
        """Return the position of target relative to ref as a 2-item list."""
        return [target[0] - ref[0], target[1] - ref[1]]

    def _get_agent_team(self, agent_index):
        """Find the team an agent belongs to.

        The team is looked up through the environment's own index mapping,
        because converting the agent index directly to a team only works
        while each team has a single agent.

        Raises:
            ValueError: If no team contains the agent index.
        """
        for team in Teams:
            for team_agent_index in range(self.env_options.team_size):
                candidate = self.env.get_agent_index(team, team_agent_index)
                if candidate == agent_index:
                    return team
        raise ValueError(
            'Agent index {} does not belong to any team'.format(agent_index))

    def _reset_agent_list(self):
        self.agent_list = [{}
                           for _ in range(self.env_options.agent_size)]
        for agent_index in range(self.env_options.agent_size):
            self.set_agent_pos(agent_index, None)
            self.set_agent_ball(agent_index, False)
            self.set_agent_mode(agent_index, None)
            self.set_agent_action(agent_index, None)
            self.set_agent_frame_skip_index(agent_index, 0)

    def _reset_pos_map(self):
        self.pos_map = {}

    def __repr__(self):
        message = ''
        # The agent positions, mode, and last taken action
        for team_index, team in enumerate(Teams):
            if team_index > 0:
                message += '\n'
            message += 'Team {}:'.format(team.name)
            for team_agent_index in range(self.env_options.team_size):
                # Get the agent index
                agent_index = self.env.get_agent_index(team, team_agent_index)
                # Get the position
                agent_pos = self.get_agent_pos(agent_index)
                # Get the mode
                agent_mode = self.get_agent_mode(agent_index)
                # Get the last taken action
                agent_action = self.get_agent_action(agent_index)
                message += '\nAgent {}:'.format(team_agent_index + 1)
                message += ' Position: {}'.format(agent_pos)
                if agent_mode is not None:
                    message += ', Mode: {}'.format(agent_mode.name)
                if agent_action is not None:
                    message += ', Action: {}'.format(agent_action.name)
        # The possession of the ball (nobody may hold it, e.g. before the
        # state has been randomized)
        ball_possession = self.get_ball_possession()
        if ball_possession is None:
            message += '\nBall possession: None'
        else:
            team = ball_possession['team_name']
            team_agent_index = ball_possession['team_agent_index']
            message += '\nBall possession: In team {} with agent {}'.format(
                team.name, team_agent_index + 1)
        # The time step
        message += '\nTime step: {}'.format(self.time_step)
        return message

    def __eq__(self, other):
        if not isinstance(other, State):
            return False
        return (self.agent_list == other.agent_list
                and self.time_step == other.time_step)

    def __hash__(self):
        hash_list = []
        for agent_index in range(self.env_options.agent_size):
            hash_list.extend(self.get_agent_pos(agent_index))
            hash_list.append(self.get_agent_ball(agent_index))
            hash_list.append(self.get_agent_mode(agent_index))
            hash_list.append(self.get_agent_action(agent_index))
        hash_list.append(self.time_step)
        return hash(tuple(hash_list))
325 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/soccer/teams.py:
--------------------------------------------------------------------------------
1 | # Native modules
2 | from enum import IntEnum
3 |
4 |
class Teams(IntEnum):
    """Teams in the soccer scenario.
    """
    PLAYER = 0
    COMPUTER = 1
8 |
--------------------------------------------------------------------------------
/pygame_rl/scenario/soccer_renderer.py:
--------------------------------------------------------------------------------
1 | # Third-party modules
2 | import pygame
3 | import pygame.locals
4 |
5 | # User-defined modules
6 | import pygame_rl.renderer.pygame_renderer as pygame_renderer
7 |
8 |
9 | class SoccerRenderer(pygame_renderer.TiledRenderer):
10 | """Soccer renderer.
11 | """
12 | # Constants
13 | title = 'Soccer'
14 |
15 | # Environment
16 | env = None
17 |
18 | # Renderer options
19 | renderer_options = None
20 |
21 | # Display state
22 | display_quitted = False
23 |
24 | # TMX objects
25 | static_overlays = None
26 |
27 | # Clock object (pygame.time.Clock())
28 | clock = None
29 |
30 | # Dirty groups (pygame.sprite.RenderUpdates)
31 | dirty_groups = None
32 |
33 | # Previous ball state
34 | prev_ball_state = None
35 |
36 | def __init__(self, map_path, env, renderer_options=None):
37 | super().__init__(map_path)
38 | # Save the environment
39 | self.env = env
40 | # Use or create the renderer options
41 | self.renderer_options = renderer_options or RendererOptions()
42 |
43 | def load(self):
44 | # Initialize Pygame
45 | pygame.display.init()
46 | pygame.display.set_mode([400, 300])
47 | pygame.display.set_caption(self.title)
48 |
49 | # Initialize the renderer
50 | super().load()
51 |
52 | # Set the screen size
53 | resolution = super().get_display_size()
54 | self.screen = pygame.display.set_mode(resolution)
55 |
56 | # Get the background
57 | self.background = super().get_background()
58 |
59 | # Get the static overlays
60 | self.static_overlays = super().get_overlays()
61 |
62 | # Initialize previous ball state
63 | self._init_prev_ball_state()
64 |
65 | # Initialize the dirty group
66 | self._load_dirty_group()
67 |
68 | # Blit the background to the screen
69 | self.screen.blit(self.background, [0, 0])
70 |
71 | # Update the full display
72 | if self.renderer_options.show_display:
73 | pygame.display.flip()
74 |
75 | # Create the clock
76 | self.clock = pygame.time.Clock()
77 |
78 | def render(self):
79 | # Close the display if the renderer options is set to disable the
80 | # display
81 | if not self.display_quitted and not self.renderer_options.show_display:
82 | # Replace the screen surface with in-memory surface
83 | self.screen = self.screen.copy()
84 | # Close the display
85 | pygame.display.quit()
86 | # Prevent from further closing
87 | self.display_quitted = True
88 |
89 | # Clear the overlays
90 | self.dirty_groups.clear(self.screen, self.background)
91 |
92 | # Update the overlays by the environment state
93 | self._update_overlay_pos()
94 | self._update_overlay_visibility()
95 |
96 | # Draw the overlays
97 | dirty = self.dirty_groups.draw(self.screen)
98 |
99 | # Update only the dirty surface
100 | if self.renderer_options.show_display:
101 | pygame.display.update(dirty)
102 |
103 | # Limit the max frames per second
104 | if self.renderer_options.show_display:
105 | self.clock.tick(self.renderer_options.max_fps)
106 |
107 | # Handle the events
108 | if self.renderer_options.show_display:
109 | for event in pygame.event.get():
110 | # Detect the quit event
111 | if event.type == pygame.locals.QUIT:
112 | # Indicate the rendering should stop
113 | return False
114 | # Detect the keydown event
115 | if self.renderer_options.enable_key_events:
116 | if event.type == pygame.locals.KEYDOWN:
117 | # Get the agent index of the first player
118 | team_agent_index = 0
119 | agent_index = self.env.get_agent_index(
120 | 'PLAYER', team_agent_index)
121 | # Prepare the cached action
122 | cached_action = None
123 | if event.key == pygame.locals.K_RIGHT:
124 | cached_action = 'MOVE_RIGHT'
125 | elif event.key == pygame.locals.K_UP:
126 | cached_action = 'MOVE_UP'
127 | elif event.key == pygame.locals.K_LEFT:
128 | cached_action = 'MOVE_LEFT'
129 | elif event.key == pygame.locals.K_DOWN:
130 | cached_action = 'MOVE_DOWN'
131 | elif event.key == pygame.locals.K_s:
132 | cached_action = 'STAND'
133 | # Take the cached action and update the state
134 | if cached_action:
135 | self.env.take_cached_action(
136 | agent_index, cached_action)
137 | self.env.update_state()
138 |
139 | # Indicate the rendering should continue
140 | return True
141 |
142 | def _init_prev_ball_state(self):
143 | agent_size = self.env.options.get_agent_size()
144 | self.prev_ball_state = agent_size * [None]
145 |
146 | def _load_dirty_group(self):
147 | self.dirty_groups = pygame.sprite.RenderUpdates()
148 |
149 | def _update_overlay_pos(self):
150 | for agent_index in range(self.env.options.get_agent_size()):
151 | [overlay_has_ball, overlay_no_ball] = self._get_overlays(
152 | agent_index)
153 | has_ball = self.env.state.get_agent_ball(agent_index)
154 | agent_pos = self.env.state.get_agent_pos(agent_index)
155 | if has_ball:
156 | overlay_has_ball.set_pos(agent_pos)
157 | else:
158 | overlay_no_ball.set_pos(agent_pos)
159 |
def _update_overlay_visibility(self):
    """Swap each agent's overlay in the dirty group when its ball state changes.

    For every agent the current ball state is compared with the one stored
    in ``self.prev_ball_state``. If there is no stored state yet (``None``)
    or the state differs, the overlay for the old state is removed from
    ``self.dirty_groups`` and the overlay for the new state is added. The
    current state is then remembered for the next call.
    """
    for index in range(self.env.options.get_agent_size()):
        with_ball, without_ball = self._get_overlays(index)
        holds_ball = self.env.state.get_agent_ball(index)
        previous = self.prev_ball_state[index]
        if previous is None or previous != holds_ball:
            # Decide which overlay becomes visible and which one is dropped
            if holds_ball:
                shown, hidden = with_ball, without_ball
            else:
                shown, hidden = without_ball, with_ball
            self.dirty_groups.remove(hidden)
            self.dirty_groups.add(shown)
        # Remember the ball state for the next comparison
        self.prev_ball_state[index] = holds_ball
181 |
def _get_overlays(self, agent_index):
    """Look up the two static overlays belonging to one agent.

    Overlay names are one-based: agent index 0 maps to ``AGENT1`` and
    ``AGENT1_BALL``.

    Args:
        agent_index (int): The zero-based agent index.

    Returns:
        list: ``[overlay_has_ball, overlay_no_ball]`` taken from
        ``self.static_overlays``.
    """
    plain_name = 'AGENT{}'.format(agent_index + 1)
    ball_name = plain_name + '_BALL'
    return [self.static_overlays[ball_name], self.static_overlays[plain_name]]
188 |
189 |
class RendererOptions(object):
    """Plain container for the renderer settings.

    Attributes:
        show_display (bool): Presumably whether a window is shown -- confirm
            against the renderer's loading code.
        max_fps (int): Presumably the frame-rate cap; what 0 means is not
            visible here -- confirm.
        enable_key_events (bool): When true, the renderer's event handling
            reacts to key presses (arrow keys and "S").
    """
    # Class-level defaults, mirrored by the constructor defaults
    show_display = False
    max_fps = 0
    enable_key_events = False

    def __init__(self, show_display=False, max_fps=0, enable_key_events=False):
        # Copy every constructor argument onto the instance
        settings = (
            ('show_display', show_display),
            ('max_fps', max_fps),
            ('enable_key_events', enable_key_events),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)
201 |
--------------------------------------------------------------------------------
/pygame_rl/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sc420/pygame-rl/f81da559385876616d99c74b43e4345f53d086d2/pygame_rl/util/__init__.py
--------------------------------------------------------------------------------
/pygame_rl/util/file_util.py:
--------------------------------------------------------------------------------
1 | # Native modules
2 | import os
3 |
4 | # Third-party modules
5 | import pkg_resources
6 | import yaml
7 |
# Package name; parsed as a pkg_resources requirement by get_resource_path()
PACKAGE_NAME = 'pygame_rl'
10 |
11 |
def get_resource_path(resource_name):
    """Locate a resource shipped with the package.

    Args:
        resource_name (str): The resource name relative to the project root
            directory.

    Returns:
        str: The path of the resource on the local file system, as reported
        by ``pkg_resources.resource_filename``.
    """
    return pkg_resources.resource_filename(
        pkg_resources.Requirement.parse(PACKAGE_NAME), resource_name)
24 |
25 |
def read_yaml(filename):
    """Read a yaml file.

    The file is parsed with ``yaml.safe_load``, so only plain YAML types are
    constructed.

    Args:
        filename (str): The yaml filename.

    Returns:
        object: The parsed yaml content. This is a dict for a mapping
        document, but ``safe_load`` returns whatever the document holds
        (e.g. a list, a scalar, or None for an empty file).
    """
    # The ``with`` statement closes the file on exit (also on errors), so no
    # explicit close() call is needed.
    with open(filename, 'r') as stream:
        return yaml.safe_load(stream)
39 |
40 |
def resolve_path(path1, path2):
    """Resolve a path given relative to the directory of another file.

    Args:
        path1 (str): The base path; only its directory part is used.
        path2 (str): The path relative to the directory of path1.

    Returns:
        str: The joined and normalized path.
    """
    base_dir = os.path.dirname(path1)
    return os.path.normpath(os.path.join(base_dir, path2))
57 |
--------------------------------------------------------------------------------
/pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 |
3 | # A comma-separated list of package or module names from where C extensions may
4 | # be loaded. Extensions are loaded into the active Python interpreter and may
5 | # run arbitrary code
6 | extension-pkg-whitelist=numpy,pygame
7 |
8 | # Add files or directories to the blacklist. They should be base names, not
9 | # paths.
10 | ignore=CVS
11 |
12 | # Add files or directories matching the regex patterns to the blacklist. The
13 | # regex matches against base names, not paths.
14 | ignore-patterns=
15 |
16 | # Python code to execute, usually for sys.path manipulation such as
17 | # pygtk.require().
18 | #init-hook=
19 |
20 | # Use multiple processes to speed up Pylint.
21 | jobs=1
22 |
23 | # List of plugins (as comma separated values of python modules names) to load,
24 | # usually to register additional checkers.
25 | load-plugins=
26 |
27 | # Pickle collected data for later comparisons.
28 | persistent=yes
29 |
30 | # Specify a configuration file.
31 | #rcfile=
32 |
33 | # Allow loading of arbitrary C extensions. Extensions are imported into the
34 | # active Python interpreter and may run arbitrary code.
35 | unsafe-load-any-extension=no
36 |
37 |
38 | [MESSAGES CONTROL]
39 |
40 | # Only show warnings with the listed confidence levels. Leave empty to show
41 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
42 | confidence=
43 |
44 | # Disable the message, report, category or checker with the given id(s). You
45 | # can either give multiple identifiers separated by comma (,) or put this
46 | # option multiple times (only on the command line, not in the configuration
47 | # file where it should appear only once). You can also use "--disable=all" to
48 | # disable everything first and then reenable specific checks. For example, if
49 | # you want to run only the similarities checker, you can use "--disable=all
50 | # --enable=similarities". If you want to run only the classes checker, but have
51 | # no Warning level messages displayed, use "--disable=all --enable=classes
52 | # --disable=W"
53 | disable=print-statement,parameter-unpacking,unpacking-in-except,old-raise-syntax,backtick,long-suffix,old-ne-operator,old-octal-literal,import-star-module-level,raw-checker-failed,bad-inline-option,locally-disabled,locally-enabled,file-ignored,suppressed-message,useless-suppression,deprecated-pragma,apply-builtin,basestring-builtin,buffer-builtin,cmp-builtin,coerce-builtin,execfile-builtin,file-builtin,long-builtin,raw_input-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,no-absolute-import,old-division,dict-iter-method,dict-view-method,next-method-called,metaclass-assignment,indexing-exception,raising-string,reload-builtin,oct-method,hex-method,nonzero-method,cmp-method,input-builtin,round-builtin,intern-builtin,unichr-builtin,map-builtin-not-iterating,zip-builtin-not-iterating,range-builtin-not-iterating,filter-builtin-not-iterating,using-cmp-argument,eq-without-hash,div-method,idiv-method,rdiv-method,exception-message-attribute,invalid-str-codec,sys-max-int,bad-python3-import,deprecated-string-function,deprecated-str-translate-call,C0111,R0201,R0902,R0903,R0904,R0912,R0914,R1705
54 |
55 | # Enable the message, report, category or checker with the given id(s). You can
56 | # either give multiple identifiers separated by comma (,) or put this option
57 | # multiple times (only on the command line, not in the configuration file where
58 | # it should appear only once). See also the "--disable" option for examples.
59 | enable=
60 |
61 |
62 | [REPORTS]
63 |
64 | # Python expression which should return a note less than 10 (10 is the highest
65 | # note). You have access to the variables errors warning, statement which
66 | # respectively contain the number of errors / warnings messages and the total
67 | # number of statements analyzed. This is used by the global evaluation report
68 | # (RP0004).
69 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
70 |
71 | # Template used to display messages. This is a python new-style format string
72 | # used to format the message information. See doc for all details
73 | #msg-template=
74 |
75 | # Set the output format. Available formats are text, parseable, colorized, json
76 | # and msvs (visual studio). You can also give a reporter class, e.g.
77 | # mypackage.mymodule.MyReporterClass.
78 | output-format=text
79 |
80 | # Tells whether to display a full report or only the messages
81 | reports=no
82 |
83 | # Activate the evaluation score.
84 | score=yes
85 |
86 |
87 | [REFACTORING]
88 |
89 | # Maximum number of nested blocks for function / method body
90 | max-nested-blocks=5
91 |
92 |
93 | [BASIC]
94 |
95 | # Naming hint for argument names
96 | argument-name-hint=(([a-z][a-z0-9_]{1,30})|(_[a-z0-9_]*))$
97 |
98 | # Regular expression matching correct argument names
99 | argument-rgx=(([a-z][a-z0-9_]{1,30})|(_[a-z0-9_]*))$
100 |
101 | # Naming hint for attribute names
102 | attr-name-hint=(([a-z][a-z0-9_]{1,30})|(_[a-z0-9_]*))$
103 |
104 | # Regular expression matching correct attribute names
105 | attr-rgx=(([a-z][a-z0-9_]{1,30})|(_[a-z0-9_]*))$
106 |
107 | # Bad variable names which should always be refused, separated by a comma
108 | bad-names=foo,bar,baz,toto,tutu,tata
109 |
110 | # Naming hint for class attribute names
111 | class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{1,30}|(__.*__))$
112 |
113 | # Regular expression matching correct class attribute names
114 | class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{1,30}|(__.*__))$
115 |
116 | # Naming hint for class names
117 | class-name-hint=[A-Z_][a-zA-Z0-9]+$
118 |
119 | # Regular expression matching correct class names
120 | class-rgx=[A-Z_][a-zA-Z0-9]+$
121 |
122 | # Naming hint for constant names
123 | const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$
124 |
125 | # Regular expression matching correct constant names
126 | const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
127 |
128 | # Minimum line length for functions/classes that require docstrings, shorter
129 | # ones are exempt.
130 | docstring-min-length=-1
131 |
132 | # Naming hint for function names
133 | function-name-hint=(([a-z][a-z0-9_]{1,30})|(_[a-z0-9_]*))$
134 |
135 | # Regular expression matching correct function names
136 | function-rgx=(([a-z][a-z0-9_]{1,30})|(_[a-z0-9_]*))$
137 |
138 | # Good variable names which should always be accepted, separated by a comma
139 | good-names=i,j,k,ex,Run,_
140 |
141 | # Include a hint for the correct naming format with invalid-name
142 | include-naming-hint=no
143 |
144 | # Naming hint for inline iteration names
145 | inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
146 |
147 | # Regular expression matching correct inline iteration names
148 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
149 |
150 | # Naming hint for method names
151 | method-name-hint=(([a-z][a-z0-9_]{1,30})|(_[a-z0-9_]*))$
152 |
153 | # Regular expression matching correct method names
154 | method-rgx=(([a-z][a-z0-9_]{1,30})|(_[a-z0-9_]*))$
155 |
156 | # Naming hint for module names
157 | module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
158 |
159 | # Regular expression matching correct module names
160 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
161 |
162 | # Colon-delimited sets of names that determine each other's naming style when
163 | # the name regexes allow several styles.
164 | name-group=
165 |
166 | # Regular expression which should only match function or class names that do
167 | # not require a docstring.
168 | no-docstring-rgx=^_
169 |
170 | # List of decorators that produce properties, such as abc.abstractproperty. Add
171 | # to this list to register other decorators that produce valid properties.
172 | property-classes=abc.abstractproperty
173 |
174 | # Naming hint for variable names
175 | variable-name-hint=(([a-z][a-z0-9_]{1,30})|(_[a-z0-9_]*))$
176 |
177 | # Regular expression matching correct variable names
178 | variable-rgx=(([a-z][a-z0-9_]{1,30})|(_[a-z0-9_]*))$
179 |
180 |
181 | [FORMAT]
182 |
183 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
184 | expected-line-ending-format=
185 |
186 | # Regexp for a line that is allowed to be longer than the limit.
187 | ignore-long-lines=^\s*(# )??$
188 |
189 | # Number of spaces of indent required inside a hanging or continued line.
190 | indent-after-paren=4
191 |
192 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
193 | # tab).
194 | indent-string=' '
195 |
196 | # Maximum number of characters on a single line.
197 | max-line-length=100
198 |
199 | # Maximum number of lines in a module
200 | max-module-lines=1000
201 |
202 | # List of optional constructs for which whitespace checking is disabled. `dict-
203 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
204 | # `trailing-comma` allows a space between comma and closing bracket: (a, ).
205 | # `empty-line` allows space-only lines.
206 | no-space-check=trailing-comma,dict-separator
207 |
208 | # Allow the body of a class to be on the same line as the declaration if body
209 | # contains single statement.
210 | single-line-class-stmt=no
211 |
212 | # Allow the body of an if to be on the same line as the test if there is no
213 | # else.
214 | single-line-if-stmt=no
215 |
216 |
217 | [LOGGING]
218 |
219 | # Logging modules to check that the string format arguments are in logging
220 | # function parameter format
221 | logging-modules=logging
222 |
223 |
224 | [MISCELLANEOUS]
225 |
226 | # List of note tags to take in consideration, separated by a comma.
227 | notes=FIXME,XXX,TODO
228 |
229 |
230 | [SIMILARITIES]
231 |
232 | # Ignore comments when computing similarities.
233 | ignore-comments=yes
234 |
235 | # Ignore docstrings when computing similarities.
236 | ignore-docstrings=yes
237 |
238 | # Ignore imports when computing similarities.
239 | ignore-imports=no
240 |
241 | # Minimum lines number of a similarity.
242 | min-similarity-lines=4
243 |
244 |
245 | [SPELLING]
246 |
247 | # Spelling dictionary name. Available dictionaries: none. To make it work,
248 | # install the python-enchant package.
249 | spelling-dict=
250 |
251 | # List of comma separated words that should not be checked.
252 | spelling-ignore-words=
253 |
254 | # A path to a file that contains private dictionary; one word per line.
255 | spelling-private-dict-file=
256 |
257 | # Tells whether to store unknown words to indicated private dictionary in
258 | # --spelling-private-dict-file option instead of raising a message.
259 | spelling-store-unknown-words=no
260 |
261 |
262 | [TYPECHECK]
263 |
264 | # List of decorators that produce context managers, such as
265 | # contextlib.contextmanager. Add to this list to register other decorators that
266 | # produce valid context managers.
267 | contextmanager-decorators=contextlib.contextmanager
268 |
269 | # List of members which are set dynamically and missed by pylint inference
270 | # system, and so shouldn't trigger E1101 when accessed. Python regular
271 | # expressions are accepted.
272 | generated-members=
273 |
274 | # Tells whether missing members accessed in mixin class should be ignored. A
275 | # mixin class is detected if its name ends with "mixin" (case insensitive).
276 | ignore-mixin-members=yes
277 |
278 | # This flag controls whether pylint should warn about no-member and similar
279 | # checks whenever an opaque object is returned when inferring. The inference
280 | # can return multiple potential results while evaluating a Python object, but
281 | # some branches might not be evaluated, which results in partial inference. In
282 | # that case, it might be useful to still emit no-member and other checks for
283 | # the rest of the inferred objects.
284 | ignore-on-opaque-inference=yes
285 |
286 | # List of class names for which member attributes should not be checked (useful
287 | # for classes with dynamically set attributes). This supports the use of
288 | # qualified names.
289 | ignored-classes=optparse.Values,thread._local,_thread._local
290 |
291 | # List of module names for which member attributes should not be checked
292 | # (useful for modules/projects where namespaces are manipulated during runtime
293 | # and thus existing member attributes cannot be deduced by static analysis. It
294 | # supports qualified module names, as well as Unix pattern matching.
295 | ignored-modules=
296 |
297 | # Show a hint with possible names when a member name was not found. The aspect
298 | # of finding the hint is based on edit distance.
299 | missing-member-hint=yes
300 |
301 | # The minimum edit distance a name should have in order to be considered a
302 | # similar match for a missing member name.
303 | missing-member-hint-distance=1
304 |
305 | # The total number of similar names that should be taken in consideration when
306 | # showing a hint for a missing member.
307 | missing-member-max-choices=1
308 |
309 |
310 | [VARIABLES]
311 |
312 | # List of additional names supposed to be defined in builtins. Remember that
313 | # you should avoid to define new builtins when possible.
314 | additional-builtins=
315 |
316 | # Tells whether unused global variables should be treated as a violation.
317 | allow-global-unused-variables=yes
318 |
319 | # List of strings which can identify a callback function by name. A callback
320 | # name must start or end with one of those strings.
321 | callbacks=cb_,_cb
322 |
323 | # A regular expression matching the name of dummy variables (i.e. expectedly
324 | # not used).
325 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
326 |
327 | # Argument names that match this expression will be ignored. Default to name
328 | # with leading underscore
329 | ignored-argument-names=_.*|^ignored_|^unused_
330 |
331 | # Tells whether we should check for unused import in __init__ files.
332 | init-import=no
333 |
334 | # List of qualified module names which can have objects that can redefine
335 | # builtins.
336 | redefining-builtins-modules=six.moves,future.builtins
337 |
338 |
339 | [CLASSES]
340 |
341 | # List of method names used to declare (i.e. assign) instance attributes.
342 | defining-attr-methods=__init__,__new__,setUp
343 |
344 | # List of member names, which should be excluded from the protected access
345 | # warning.
346 | exclude-protected=_asdict,_fields,_replace,_source,_make
347 |
348 | # List of valid names for the first argument in a class method.
349 | valid-classmethod-first-arg=cls
350 |
351 | # List of valid names for the first argument in a metaclass class method.
352 | valid-metaclass-classmethod-first-arg=mcs
353 |
354 |
355 | [DESIGN]
356 |
357 | # Maximum number of arguments for function / method
358 | max-args=5
359 |
360 | # Maximum number of attributes for a class (see R0902).
361 | max-attributes=7
362 |
363 | # Maximum number of boolean expressions in a if statement
364 | max-bool-expr=5
365 |
366 | # Maximum number of branch for function / method body
367 | max-branches=12
368 |
369 | # Maximum number of locals for function / method body
370 | max-locals=15
371 |
372 | # Maximum number of parents for a class (see R0901).
373 | max-parents=7
374 |
375 | # Maximum number of public methods for a class (see R0904).
376 | max-public-methods=20
377 |
378 | # Maximum number of return / yield for function / method body
379 | max-returns=6
380 |
381 | # Maximum number of statements in function / method body
382 | max-statements=50
383 |
384 | # Minimum number of public methods for a class (see R0903).
385 | min-public-methods=2
386 |
387 |
388 | [IMPORTS]
389 |
390 | # Allow wildcard imports from modules that define __all__.
391 | allow-wildcard-with-all=no
392 |
393 | # Analyse import fallback blocks. This can be used to support both Python 2 and
394 | # 3 compatible code, which means that the block might have code that exists
395 | # only in one or another interpreter, leading to false positives when analysed.
396 | analyse-fallback-blocks=no
397 |
398 | # Deprecated modules which should not be used, separated by a comma
399 | deprecated-modules=optparse,tkinter.tix
400 |
401 | # Create a graph of external dependencies in the given file (report RP0402 must
402 | # not be disabled)
403 | ext-import-graph=
404 |
405 | # Create a graph of every (i.e. internal and external) dependencies in the
406 | # given file (report RP0402 must not be disabled)
407 | import-graph=
408 |
409 | # Create a graph of internal dependencies in the given file (report RP0402 must
410 | # not be disabled)
411 | int-import-graph=
412 |
413 | # Force import order to recognize a module as part of the standard
414 | # compatibility libraries.
415 | known-standard-library=
416 |
417 | # Force import order to recognize a module as part of a third party library.
418 | known-third-party=enchant
419 |
420 |
421 | [EXCEPTIONS]
422 |
423 | # Exceptions that will emit a warning when being caught. Defaults to
424 | # "Exception"
425 | overgeneral-exceptions=Exception
426 |
--------------------------------------------------------------------------------
/sample/data/map/gridworld/agents_sprite.yaml:
--------------------------------------------------------------------------------
1 | PLAYER1:
2 | x: 0
3 | y: 0
4 | GOAL:
5 | x: 1
6 | y: 0
--------------------------------------------------------------------------------
/sample/data/map/gridworld/gridworld_9x9.tmx:
--------------------------------------------------------------------------------
1 |
2 |
56 |
--------------------------------------------------------------------------------
/sample/data/map/gridworld/ground_tile.yaml:
--------------------------------------------------------------------------------
1 | GROUND: 686
--------------------------------------------------------------------------------
/sample/data/map/gridworld/minecraft_tileset.tsx:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/sample/data/map/gridworld/obstacles_sprite.yaml:
--------------------------------------------------------------------------------
1 | OBSTACLE1:
2 | x: 0
3 | y: 0
--------------------------------------------------------------------------------
/sample/data/map/predator_prey/agent_sprite.yaml:
--------------------------------------------------------------------------------
1 | PREDATOR:
2 | x: 0
3 | y: 0
4 | OBSTACLE:
5 | x: 1
6 | y: 0
7 | PREY:
8 | x: 2
9 | y: 0
--------------------------------------------------------------------------------
/sample/data/map/predator_prey/ground_tile.yaml:
--------------------------------------------------------------------------------
1 | FIELD: 686
--------------------------------------------------------------------------------
/sample/data/map/predator_prey/minecraft_tileset.tsx:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/sample/data/map/predator_prey/predator_prey_15x15.tmx:
--------------------------------------------------------------------------------
1 |
2 |
51 |
--------------------------------------------------------------------------------
/sample/data/map/soccer/agent_sprite.yaml:
--------------------------------------------------------------------------------
1 | AGENT1:
2 | x: 0
3 | y: 0
4 | AGENT2:
5 | x: 1
6 | y: 0
7 | AGENT3:
8 | x: 2
9 | y: 0
10 | AGENT4:
11 | x: 3
12 | y: 0
13 | AGENT5:
14 | x: 4
15 | y: 0
16 | AGENT6:
17 | x: 0
18 | y: 2
19 | AGENT7:
20 | x: 1
21 | y: 2
22 | AGENT8:
23 | x: 2
24 | y: 2
25 | AGENT9:
26 | x: 3
27 | y: 2
28 | AGENT10:
29 | x: 4
30 | y: 2
31 | AGENT1_BALL:
32 | x: 0
33 | y: 1
34 | AGENT2_BALL:
35 | x: 1
36 | y: 1
37 | AGENT3_BALL:
38 | x: 2
39 | y: 1
40 | AGENT4_BALL:
41 | x: 3
42 | y: 1
43 | AGENT5_BALL:
44 | x: 4
45 | y: 1
46 | AGENT6_BALL:
47 | x: 0
48 | y: 3
49 | AGENT7_BALL:
50 | x: 1
51 | y: 3
52 | AGENT8_BALL:
53 | x: 2
54 | y: 3
55 | AGENT9_BALL:
56 | x: 3
57 | y: 3
58 | AGENT10_BALL:
59 | x: 4
60 | y: 3
--------------------------------------------------------------------------------
/sample/data/map/soccer/goal_tile.yaml:
--------------------------------------------------------------------------------
1 | PLAYER: 74
2 | COMPUTER: 98
--------------------------------------------------------------------------------
/sample/data/map/soccer/ground_tile.yaml:
--------------------------------------------------------------------------------
1 | WALKABLE: 1
--------------------------------------------------------------------------------
/sample/data/map/soccer/minecraft_tileset.tsx:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/sample/data/map/soccer/soccer_13x10_goal_4.tmx:
--------------------------------------------------------------------------------
1 |
2 |
79 |
--------------------------------------------------------------------------------
/sample/data/map/soccer/soccer_21x14_goal_4.tmx:
--------------------------------------------------------------------------------
1 |
2 |
95 |
--------------------------------------------------------------------------------
/sample/data/map/soccer/soccer_21x14_goal_6.tmx:
--------------------------------------------------------------------------------
1 |
2 |
95 |
--------------------------------------------------------------------------------
/sample/data/map/soccer/soccer_21x14_goal_8.tmx:
--------------------------------------------------------------------------------
1 |
2 |
95 |
--------------------------------------------------------------------------------
/sample/data/map/soccer/spawn_tile.yaml:
--------------------------------------------------------------------------------
1 | PLAYER: 40
2 | COMPUTER: 203
--------------------------------------------------------------------------------
/sample/data/tileset/minecraft_sprite_32x32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sc420/pygame-rl/f81da559385876616d99c74b43e4345f53d086d2/sample/data/tileset/minecraft_sprite_32x32.png
--------------------------------------------------------------------------------
/sample/gridworld/environment_advanced.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # pylint: disable=W0611
3 | """Sample: Interacting with the environment with advanced usage.
4 | """
5 |
6 | # Native modules
7 | import copy
8 | import os
9 |
10 | # Third-party modules
11 | import gym
12 | import numpy as np
13 | import scipy.misc
14 |
15 | # User-defined modules
16 | import pygame_rl.scenario.gridworld
17 | import pygame_rl.scenario.gridworld.options as options
18 | import pygame_rl.scenario.gridworld.renderer as renderer
19 | import pygame_rl.util.file_util as file_util
20 |
21 |
# Map size; is_valid_pos() checks pos[0] against MAP_SIZE[0] and pos[1]
# against MAP_SIZE[1]
MAP_SIZE = [9, 9]
# Number of discrete actions (0-4, see get_new_pos())
ACTION_SIZE = 5
# Group names; group_sizes below is indexed in the same order
GROUP_NAMES = [
    'PLAYER1',
    'OBSTACLE1',
    'GOAL',
]
# Probability that get_new_pos() replaces the requested action with a random
# one
RANDOM_ACTION_PROB = 0.1


# Group sizes; reassigned by main() for every episode and read by
# check_collision()
group_sizes = []
# Reset counter; incremented by reset_callback() to decide how many obstacles
# are returned
reset_counter = 0
40 |
41 |
def main():
    """Run six random-action episodes with a growing number of obstacles.

    The environment is configured with the callbacks defined in this file.
    Episode ``i`` (zero-based) uses one player, ``i`` obstacles and one goal.
    After the last episode the most recent rendered frame is written to
    ``screenshot.png`` in the current working directory.
    """
    global group_sizes

    # Create the environment and point it at the map next to this sample
    env = gym.make('gridworld-v1')
    map_path = file_util.resolve_path(
        __file__, '../data/map/gridworld/gridworld_9x9.tmx')

    # Configure the environment and the renderer
    env.env_options = options.GridworldOptions(
        map_path=map_path,
        action_space=gym.spaces.Discrete(ACTION_SIZE),
        step_callback=step_callback,
        reset_callback=reset_callback
    )
    env.renderer_options = renderer.RendererOptions(
        show_display=False, max_fps=60)

    # Load the environment and fix its random seed
    env.load()
    env.seed(0)

    for episode in range(6):
        print('')
        print('Episode {}:'.format(episode + 1))
        # One player, `episode` obstacles, one goal
        group_sizes = [1, episode, 1]
        env.env_options.set_group(GROUP_NAMES, group_sizes)
        state = env.reset()
        print('Shape of initial state:{}'.format(state.shape))
        timestep = 0
        while True:
            # Render first so the frame shows the state before the action
            screenshot = env.render()
            action = env.action_space.sample()
            state, reward, done, _ = env.step(action)
            timestep += 1
            if done:
                break
        print('Episode ended. Reward: {}. Timestep: {}'.format(
            reward, timestep))

    # Save the last screenshot
    # NOTE: scipy.misc.imsave was deprecated in SciPy 1.0 and removed in 1.2,
    # so this call needs an older SciPy release.
    screenshot_abs_path = os.path.abspath('screenshot.png')
    scipy.misc.imsave(screenshot_abs_path, screenshot)
    print('The last screenshot is saved to {}'.format(screenshot_abs_path))
106 |
107 |
def step_callback(prev_state, action, random_state):
    """Compute the transition for one environment step.

    Args:
        prev_state (dict): The state before the step, keyed by group name.
        action (int): The requested action (see get_new_pos()).
        random_state: The random generator passed on to get_new_pos().

    Returns:
        tuple: ``(state, reward, done, info)`` where ``state`` is a deep copy
        of ``prev_state`` with the player moved if the new position is valid,
        ``reward`` is always 0.0 and ``info`` is an empty dict.
    """
    old_pos = prev_state['PLAYER1'][0]
    candidate_pos = get_new_pos(old_pos, action, random_state)
    # Work on a copy so the caller's previous state stays untouched
    next_state = copy.deepcopy(prev_state)
    if is_valid_pos(candidate_pos, prev_state):
        next_state['PLAYER1'][0] = candidate_pos
    # NOTE(review): the goal check uses the position *before* the move, so
    # `done` is reported one step after the player reaches the goal --
    # confirm whether this delay is intended.
    finished = is_done(old_pos, next_state)
    return next_state, 0.0, finished, {}
121 |
122 |
def reset_callback(random_state):
    """Build the initial state for a new episode.

    Every call increments the module-level ``reset_counter``. The n-th call
    returns the first ``n - 1`` obstacles of a fixed list, so the first
    episode has no obstacles.

    Args:
        random_state: Unused.

    Returns:
        dict: Position arrays keyed by group name ('PLAYER1', 'OBSTACLE1',
        'GOAL').
    """
    global reset_counter
    del random_state
    reset_counter += 1
    obstacle_count = reset_counter - 1
    # Fixed obstacle positions (a plus shape around [4, 4])
    all_obstacles = np.asarray([
        [4, 3],
        [3, 4],
        [4, 4],
        [5, 4],
        [4, 5],
    ])
    return {
        'PLAYER1': np.asarray([[0, 0]]),
        'OBSTACLE1': all_obstacles[:obstacle_count],
        'GOAL': np.asarray([[8, 8]]),
    }
143 |
144 |
def get_new_pos(pos, action, random_state):
    """Return the position reached by taking an action from a position.

    With probability ``RANDOM_ACTION_PROB`` the requested action is replaced
    by a random one drawn from ``random_state``.

    Args:
        pos: The current position; it is copied, not modified.
        action (int): 0 = right, 1 = up, 2 = left, 3 = down, 4 = stand still.
        random_state: Generator providing ``rand()`` and ``randint()``.

    Returns:
        numpy.ndarray: The new position (not checked for validity).

    Raises:
        ValueError: If the action is none of the known actions.
    """
    moved = np.array(pos)
    # Occasionally override the requested action with a random one
    if random_state.rand() < RANDOM_ACTION_PROB:
        action = random_state.randint(ACTION_SIZE)
    if action in (0, 2):
        # Horizontal move: right (+1) or left (-1)
        moved[0] += 1 if action == 0 else -1
    elif action in (1, 3):
        # Vertical move: up (-1) or down (+1)
        moved[1] += -1 if action == 1 else 1
    elif action != 4:
        # Action 4 stands still; anything else is unknown
        raise ValueError('Unknown action: {}'.format(action))
    return moved
164 |
165 |
def is_valid_pos(pos, prev_state):
    """Check whether a position is inside the map and not on an obstacle.

    Args:
        pos: The position to check.
        prev_state (dict): The state whose 'OBSTACLE1' group is checked for
            collisions.

    Returns:
        A truthy value if the position is within ``MAP_SIZE`` and collides
        with no obstacle, a falsy value otherwise.
    """
    inside_map = (0 <= pos[0] < MAP_SIZE[0] and
                  0 <= pos[1] < MAP_SIZE[1])
    hits_obstacle = check_collision(pos, ['OBSTACLE1'], prev_state)
    return inside_map and not hits_obstacle
175 |
176 |
def is_done(pos, state):
    """Check whether a position coincides with a member of the 'GOAL' group.

    Args:
        pos: The position to check.
        state (dict): The state holding the 'GOAL' positions.

    Returns:
        bool: True if the position equals a goal position.
    """
    goal_groups = ['GOAL']
    return check_collision(pos, goal_groups, state)
182 |
183 |
def check_collision(pos, collision_group_names, state):
    """Check whether a position equals any member of the given groups.

    Only the first ``group_sizes[i]`` members of each group are considered,
    where ``i`` is the group's index in ``GROUP_NAMES``.

    Args:
        pos: The position to check.
        collision_group_names (list): Names of the groups to test against.
        state (dict): Position arrays keyed by group name.

    Returns:
        bool: True if the position equals a member position.
    """
    for index, name in enumerate(GROUP_NAMES):
        if name in collision_group_names:
            member_count = group_sizes[index]
            if any(np.array_equal(pos, state[name][member])
                   for member in range(member_count)):
                return True
    return False
194 |
195 |
# Run the sample only when executed as a script, not when imported
if __name__ == '__main__':
    main()
198 |
--------------------------------------------------------------------------------
/sample/gridworld/environment_simple.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # pylint: disable=W0611
3 | """Sample: Interacting with the environment with minimal setup.
4 | """
5 |
6 | # Native modules
7 | import os
8 |
9 | # Third-party modules
10 | import gym
11 | import scipy.misc
12 |
13 | # User-defined modules
14 | import pygame_rl.scenario.gridworld
15 |
16 |
def main():
    """Run ten random-action episodes on the default gridworld.

    The timestep and reward are printed after every step. After the last
    episode the most recent rendered frame is written to ``screenshot.png``
    in the current working directory.
    """
    # Create and load the environment
    env = gym.make('gridworld-v0')
    env.load()

    for episode in range(10):
        print('')
        print('Episode {}:'.format(episode + 1))
        state = env.reset()
        print('Shape of initial state:{}'.format(state.shape))
        timestep = 0
        while True:
            # Render first so the frame shows the state before the action
            screenshot = env.render()
            action = env.action_space.sample()
            state, reward, done, _ = env.step(action)
            timestep += 1
            print('Timestep: {}'.format(timestep))
            print('Reward: {}'.format(reward))
            if done:
                break

    # Save the last screenshot
    # NOTE: scipy.misc.imsave was deprecated in SciPy 1.0 and removed in 1.2,
    # so this call needs an older SciPy release.
    screenshot_abs_path = os.path.abspath('screenshot.png')
    scipy.misc.imsave(screenshot_abs_path, screenshot)
    print('The last screenshot is saved to {}'.format(screenshot_abs_path))


# Run the sample only when executed as a script
if __name__ == '__main__':
    main()
59 |
--------------------------------------------------------------------------------
/sample/gridworld/renderer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Sample: Use only the renderer with the default map.
3 |
4 | Press the arrow keys to move agent 1. Press the "S" key to take the "STAND"
5 | action of agent 1.
6 |
7 | """
8 |
9 | # Third-party modules
10 | import gym
11 |
12 | # User-defined modules
13 | import pygame_rl.scenario.gridworld.renderer as renderer
14 |
15 |
def main():
    """Show the ``gridworld-v0`` renderer and render until it stops."""
    # Create an environment
    gridworld_env = gym.make('gridworld-v0')

    # Show the display, cap the frame rate and enable the key events
    options = renderer.RendererOptions(
        show_display=True, max_fps=60, enable_key_events=True)
    gridworld_env.renderer_options = options

    # Reset the environment
    gridworld_env.reset()

    # Keep rendering for as long as the renderer returns a truthy value
    while gridworld_env.renderer.render():
        pass
31 |
32 |
33 | if __name__ == '__main__':
34 | main()
35 |
--------------------------------------------------------------------------------
/sample/predator_prey/environment_advanced.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Sample: Interacting with the environment with advanced usage.
3 | """
4 |
5 | # Native modules
6 | import os
7 | import random
8 |
9 | # Third-party modules
10 | import numpy as np
11 | import scipy.misc
12 |
13 | # User-defined modules
14 | import pygame_rl.scenario.predator_prey_environment as predator_prey_environment
15 | import pygame_rl.util.file_util as file_util
16 |
17 |
def main():
    """Run ten predator-prey episodes on a custom map with one random predator.

    Each step renders the environment, queries the partially observable
    symbolic view and screenshot around the first predator, and steps the
    environment with a random action for the first slot of the action list.
    After the last episode the final symbolic view is printed and the final
    partially observable screenshot is written to ``screenshot.png``.
    """
    # Initialize the random number generator to have consistent results
    random.seed(0)

    # Resolve the map path relative to this file
    custom_map_path = file_util.resolve_path(
        __file__, '../data/map/predator_prey/predator_prey_15x15.tmx')

    # Create the environment options
    group_sizes = {'PREDATOR': 3, 'PREY': 3, 'OBSTACLE': 8}
    options = predator_prey_environment.PredatorPreyEnvironmentOptions(
        map_path=custom_map_path, object_size=group_sizes, po_radius=3,
        ai_frame_skip=2)

    # Create an environment
    env = predator_prey_environment.PredatorPreyEnvironment(
        env_options=options)

    # Get the index of the first predator from the predator index range
    first_predator_index = range(*env.get_group_index_range('PREDATOR'))[0]

    # Radius of the partially observable views requested below
    view_radius = 2

    # Run many episodes
    for episode_number in range(1, 11):
        # Print the episode number
        print('')
        print('Episode {}:'.format(episode_number))
        # Reset the environment and print the initial observation
        initial = env.reset()
        print('Initial state:\n({}, {}, {}, {})\n'.format(
            initial.state, initial.action, initial.reward,
            initial.next_state))
        # Run the episode until a terminal state is reached
        while True:
            # Render the environment
            env.render()
            # Get position of the first predator
            predator_pos = np.array(
                env.state.get_object_pos(first_predator_index))
            # Get the partially observable symbolic view and screenshot
            # around that position
            po_view = env.state.get_po_symbolic_view(
                predator_pos, view_radius)
            po_screenshot = env.renderer.get_po_screenshot(
                predator_pos, view_radius)
            # Build actions without obstacles: one slot per predator and prey
            slot_count = (env.options.object_size['PREDATOR'] +
                          env.options.object_size['PREY'])
            actions_wo = [None] * slot_count
            # Put a random action from the action list in the first slot
            actions_wo[0] = random.choice(env.actions)
            # Update the environment and get observation
            observation = env.step_without_obstacles(actions_wo)
            # Check the terminal state
            if env.state.is_terminal():
                print('Terminal state:\n{}'.format(observation))
                print('Episode {} ends at time step {}'.format(
                    episode_number, env.state.time_step + 1))
                break

    # Get position of the first predator
    predator_pos = np.array(env.state.get_object_pos(first_predator_index))

    # Print the last partially observable symbolic view
    print(env.state.get_po_symbolic_view(predator_pos, view_radius))

    # Save the last partially observable screenshot
    # NOTE(review): scipy.misc.imsave is absent from SciPy >= 1.2; confirm the
    # installed version still provides it.
    env.render()
    po_screenshot = env.renderer.get_po_screenshot(predator_pos, view_radius)
    target_path = os.path.abspath('screenshot.png')
    scipy.misc.imsave(target_path, po_screenshot)
    print('The last partially observable screenshot is saved to {}'.format(
        target_path))
102 |
103 |
104 | if __name__ == '__main__':
105 | main()
106 |
--------------------------------------------------------------------------------
/sample/predator_prey/environment_simple.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Sample: Interacting with the environment with minimal setup.
3 | """
4 |
5 | # Native modules
6 | import os
7 | import random
8 |
9 | # Third-party modules
10 | import scipy.misc
11 |
12 | # User-defined modules
13 | import pygame_rl.scenario.predator_prey_environment as predator_prey_environment
14 |
15 |
def main():
    """Run ten predator-prey episodes where every predator acts randomly.

    Each step renders the environment, caches a random action for every
    predator index and updates the state. After the last episode the final
    screenshot is written to ``screenshot.png``.
    """
    # Initialize the random number generator to have consistent results
    random.seed(0)

    # Create an environment
    env = predator_prey_environment.PredatorPreyEnvironment()

    # Get index range of predators
    predator_bounds = env.get_group_index_range('PREDATOR')

    # Run many episodes
    for episode_number in range(1, 11):
        # Print the episode number
        print('')
        print('Episode {}:'.format(episode_number))
        # Reset the environment and print the initial observation
        print('Initial observation:\n{}\n'.format(env.reset()))
        # Run the episode until a terminal state is reached
        while True:
            # Render the environment
            env.render()
            # Get the screenshot
            screenshot = env.renderer.get_screenshot()
            # Cache a random action from the action list for each predator
            for predator_index in range(*predator_bounds):
                env.take_cached_action(
                    predator_index, random.choice(env.actions))
            # Update the environment and get observation
            observation = env.update_state()
            # Check the terminal state
            if env.state.is_terminal():
                print('Terminal observation:\n{}'.format(observation))
                print('Episode {} ends at time step {}'.format(
                    episode_number, env.state.time_step + 1))
                break

    # Save the last screenshot
    # NOTE(review): scipy.misc.imsave is absent from SciPy >= 1.2; confirm the
    # installed version still provides it.
    env.render()
    screenshot = env.renderer.get_screenshot()
    target_path = os.path.abspath('screenshot.png')
    scipy.misc.imsave(target_path, screenshot)
    print('The last screenshot is saved to {}'.format(target_path))
64 |
65 |
66 | if __name__ == '__main__':
67 | main()
68 |
--------------------------------------------------------------------------------
/sample/predator_prey/renderer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Sample: Use only the renderer with the default map.
3 |
4 | Press the arrow keys to move agent 1. Press the "S" key to take the "STAND"
5 | action of agent 1.
6 |
7 | """
8 |
9 | # User-defined modules
10 | import pygame_rl.scenario.predator_prey_environment as predator_prey_environment
11 | import pygame_rl.scenario.predator_prey_renderer as predator_prey_renderer
12 |
13 |
def main():
    """Show the predator-prey renderer and render until it stops."""
    # Create the renderer options: visible display, capped frame rate and
    # key events enabled
    options = predator_prey_renderer.RendererOptions(
        show_display=True, max_fps=60, enable_key_events=True)

    # Create an environment and take the renderer wrapped in it
    env = predator_prey_environment.PredatorPreyEnvironment(
        renderer_options=options)
    env_renderer = env.renderer

    # Initialize the renderer
    env_renderer.load()

    # Keep rendering for as long as the renderer returns a truthy value
    while env_renderer.render():
        pass
33 |
34 |
35 | if __name__ == '__main__':
36 | main()
37 |
--------------------------------------------------------------------------------
/sample/soccer/env_soccer_v0.py:
--------------------------------------------------------------------------------
1 | """Sample: Interacting with the environment with minimal setup.
2 | """
3 |
4 | # Native modules
5 | import os
6 |
7 | # Third-party modules
8 | import gym
9 | import scipy.misc
10 |
11 | # User-defined modules
12 | from pygame_rl.scenario.soccer.actions import Actions
13 | import pygame_rl.scenario.soccer
14 |
15 |
def main():
    """Play twenty ``soccer-v0`` episodes with random player actions.

    Each step renders the environment, samples actions from the action space,
    overrides the action at index 1 with ``Actions.NOOP`` and steps the
    environment. After the last episode one more frame is rendered and
    written to ``screenshot.png``.
    """
    # Create and load a soccer environment
    env = gym.make('soccer-v0')
    env.load()

    # Run many episodes
    for episode_number in range(1, 21):
        # Print the episode number
        print('')
        print('Episode {}:'.format(episode_number))
        # Reset the environment and print the initial state
        print('Initial state:\n{}\n'.format(env.reset()))
        # Run the episode until the environment reports it is done
        finished = False
        while not finished:
            # Render the environment
            env.render()
            # Get random actions
            joint_action = env.action_space.sample()
            # Reset the computer action because we don't want to control it
            joint_action[1] = Actions.NOOP
            # Interact with the environment
            next_state, reward, finished, _ = env.step(joint_action)
            # Report the terminal state
            if finished:
                print('Terminal state:\n{}\nReward: {}'.format(
                    next_state, reward))

    # Save the last screenshot
    # NOTE(review): scipy.misc.imsave is absent from SciPy >= 1.2; confirm the
    # installed version still provides it.
    screenshot = env.render()
    target_path = os.path.abspath('screenshot.png')
    scipy.misc.imsave(target_path, screenshot)
    print('The last screenshot is saved to {}'.format(target_path))
55 |
56 |
57 | if __name__ == '__main__':
58 | main()
59 |
--------------------------------------------------------------------------------
/sample/soccer/environment_advanced.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Sample: Interacting with the environment with advanced usage.
3 | """
4 |
5 | # Native modules
6 | import os
7 | import random
8 |
9 | # Third-party modules
10 | import numpy as np
11 | import scipy.misc
12 |
13 | # User-defined modules
14 | import pygame_rl.scenario.soccer_environment as soccer_environment
15 | import pygame_rl.util.file_util as file_util
16 |
17 |
def main():
    """Run twenty soccer episodes on a custom map with random actions.

    Each step renders the environment, takes the partially observable
    screenshot around the first agent, caches a random action for the first
    agent of every team and updates the state. After the last episode the
    final partially observable screenshot is written to ``screenshot.png``.
    """
    # Initialize the random number generator to have consistent results
    random.seed(0)

    # Resolve the map path relative to this file
    map_path = file_util.resolve_path(
        __file__, '../data/map/soccer/soccer_21x14_goal_4.tmx')

    # Create a soccer environment options
    # "map_path" is specified to use the custom map.
    # "team_size" is given to specify the agents in one team.
    # "ai_frame_skip" is to control the frame skip for AI only
    env_options = soccer_environment.SoccerEnvironmentOptions(
        map_path=map_path, team_size=2, ai_frame_skip=2)

    # Create a soccer environment
    # If you want to render the environment, an optional argument
    # "renderer_options" can be used. For the sample usage, see
    # "sample/soccer/renderer.py".
    env = soccer_environment.SoccerEnvironment(env_options=env_options)

    # Run many episodes
    for episode_index in range(20):
        # Print the episode number
        print('')
        print('Episode {}:'.format(episode_index + 1))
        # Reset the environment and get the initial observation. The observation
        # is a class defined as "soccer_environment.SoccerObservation".
        observation = env.reset()
        state = observation.state
        action = observation.action
        reward = observation.reward
        next_state = observation.next_state
        # Print the state, action, reward, and next state pair
        print('Initial state:\n({}, {}, {}, {})\n'.format(
            state, action, reward, next_state))
        # Run the episode
        is_running = True
        while is_running:
            # Render the environment. The renderer will lazy load on the first
            # call. Skip the call if you don't need the rendering.
            env.render()
            # Get the partially observable screenshot around the first agent
            # with a radius of 1. The returned `screenshot` is a
            # `numpy.ndarray`, the format is the same as the returned value of
            # `scipy.misc.imread`. The previous call is required for this call
            # to work.
            # The view is addressed by position, the same way as the final
            # screenshot taken after the episodes; the agent index 0 used here
            # before was inconsistent with that call.
            agent_pos = np.array(env.state.get_agent_pos(0))
            po_screenshot = env.renderer.get_po_screenshot(agent_pos, radius=1)
            # Control only the first agent in each team
            team_agent_index = 0
            for team_name in env.team_names:
                agent_index = env.get_agent_index(team_name, team_agent_index)
                action = random.choice(env.actions)
                env.take_cached_action(agent_index, action)
            # Update the state and get the observation
            observation = env.update_state()
            # Check the terminal state
            if env.state.is_terminal():
                print('Terminal state:\n{}'.format(observation))
                print('Episode {} ends at time step {}'.format(
                    episode_index + 1, env.state.time_step + 1))
                is_running = False

    # Save the last partially observable screenshot
    # NOTE(review): scipy.misc.imsave is absent from SciPy >= 1.2; confirm the
    # installed version still provides it.
    env.render()
    agent_pos = np.array(env.state.get_agent_pos(0))
    po_screenshot = env.renderer.get_po_screenshot(agent_pos, radius=1)
    screenshot_relative_path = 'screenshot.png'
    screenshot_abs_path = os.path.abspath(screenshot_relative_path)
    scipy.misc.imsave(screenshot_abs_path, po_screenshot)
    print('The last partially observable screenshot is saved to {}'.format(
        screenshot_abs_path))
89 |
90 |
91 | if __name__ == '__main__':
92 | main()
93 |
--------------------------------------------------------------------------------
/sample/soccer/environment_legacy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Sample: Interacting with the legacy environment.
3 | """
4 |
5 | # Native modules
6 | import os
7 | import random
8 |
9 | # Third-party modules
10 | import scipy.misc
11 |
12 | # User-defined modules
13 | import pygame_rl.scenario.soccer_environment as soccer_environment
14 |
15 |
def main():
    """Run twenty episodes of the legacy soccer environment randomly.

    Each step renders the environment, takes a screenshot and applies a
    random action with ``take_action``. After the last episode the final
    screenshot is written to ``screenshot.png``.
    """
    # Initialize the random number generator to have consistent results
    random.seed(0)

    # Create a soccer environment
    legacy_env = soccer_environment.SoccerLegacyEnvironment()

    # Run many episodes
    for episode_number in range(1, 21):
        # Print the episode number
        print('')
        print('Episode {}:'.format(episode_number))
        # Reset the environment and print the initial state
        print('Initial state:\n{}\n'.format(legacy_env.reset()))
        # Run the episode until a terminal state is reached
        while True:
            # Render the environment
            legacy_env.render()
            # Get the screenshot
            screenshot = legacy_env.renderer.get_screenshot()
            # Take a random action from the action list and get the
            # observation
            observation = legacy_env.take_action(
                random.choice(legacy_env.actions))
            # Check the terminal state
            if legacy_env.state.is_terminal():
                print('Terminal state:\n{}'.format(observation))
                print('Episode {} ends at time step {}'.format(
                    episode_number, legacy_env.state.time_step + 1))
                break

    # Save the last screenshot
    # NOTE(review): scipy.misc.imsave is absent from SciPy >= 1.2; confirm the
    # installed version still provides it.
    legacy_env.render()
    screenshot = legacy_env.renderer.get_screenshot()
    target_path = os.path.abspath('screenshot.png')
    scipy.misc.imsave(target_path, screenshot)
    print('The last screenshot is saved to {}'.format(target_path))
57 |
58 |
59 | if __name__ == '__main__':
60 | main()
61 |
--------------------------------------------------------------------------------
/sample/soccer/environment_simple.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Sample: Interacting with the environment with minimal setup.
3 | """
4 |
5 | # Native modules
6 | import os
7 | import random
8 |
9 | # Third-party modules
10 | import scipy.misc
11 |
12 | # User-defined modules
13 | import pygame_rl.scenario.soccer_environment as soccer_environment
14 |
15 |
def main():
    """Run twenty soccer episodes with a randomly acting first player agent.

    Each step renders the environment, takes a screenshot, caches a random
    action for the first agent of the ``PLAYER`` team and updates the state.
    After the last episode the final screenshot is written to
    ``screenshot.png``.
    """
    # Initialize the random number generator to have consistent results
    random.seed(0)

    # Create a soccer environment
    env = soccer_environment.SoccerEnvironment()

    # Run many episodes
    for episode_number in range(1, 21):
        # Print the episode number
        print('')
        print('Episode {}:'.format(episode_number))
        # Reset the environment and print the initial state
        print('Initial state:\n{}\n'.format(env.reset()))
        # Run the episode until a terminal state is reached
        while True:
            # Render the environment
            env.render()
            # Get the screenshot
            screenshot = env.renderer.get_screenshot()
            # Get a random action from the action list
            chosen_action = random.choice(env.actions)
            # Cache the action for the first agent of the player team
            player_index = env.get_agent_index('PLAYER', 0)
            env.take_cached_action(player_index, chosen_action)
            # Update the state and get the observation
            observation = env.update_state()
            # Check the terminal state
            if env.state.is_terminal():
                print('Terminal state:\n{}'.format(observation))
                print('Episode {} ends at time step {}'.format(
                    episode_number, env.state.time_step + 1))
                break

    # Save the last screenshot
    # NOTE(review): scipy.misc.imsave is absent from SciPy >= 1.2; confirm the
    # installed version still provides it.
    env.render()
    screenshot = env.renderer.get_screenshot()
    target_path = os.path.abspath('screenshot.png')
    scipy.misc.imsave(target_path, screenshot)
    print('The last screenshot is saved to {}'.format(target_path))
62 |
63 |
64 | if __name__ == '__main__':
65 | main()
66 |
--------------------------------------------------------------------------------
/sample/soccer/renderer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Sample: Use only the renderer with the default map.
3 |
4 | Press the arrow keys to move the agent 1. Press the "S" key to take the "STAND"
5 | action of the agent 1.
6 |
7 | """
8 |
9 | # User-defined modules
10 | import pygame_rl.scenario.soccer_environment as soccer_environment
11 | import pygame_rl.scenario.soccer_renderer as soccer_renderer
12 |
13 |
def main():
    """Show the soccer renderer with the default map and render until it stops."""
    # Create the renderer options: visible display, capped frame rate and
    # key events enabled
    options = soccer_renderer.RendererOptions(
        show_display=True, max_fps=60, enable_key_events=True)

    # Create a soccer environment and take the renderer wrapped in it
    env = soccer_environment.SoccerEnvironment(renderer_options=options)
    env_renderer = env.renderer

    # Initialize the renderer
    env_renderer.load()

    # Keep rendering for as long as the renderer returns a truthy value
    while env_renderer.render():
        pass
33 |
34 |
35 | if __name__ == '__main__':
36 | main()
37 |
--------------------------------------------------------------------------------
/sample/soccer/renderer_custom_map.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Sample: Use only the renderer with a custom map in "sample/data".
3 |
4 | Press the arrow keys to move the agent 1. Press the "S" key to take the "STAND"
5 | action of the agent 1.
6 |
7 | """
8 |
9 | # User-defined modules
10 | import pygame_rl.scenario.soccer_environment as soccer_environment
11 | import pygame_rl.scenario.soccer_renderer as soccer_renderer
12 | import pygame_rl.util.file_util as file_util
13 |
14 |
def main():
    """Show the soccer renderer with a custom map and render until it stops."""
    # Resolve the map path relative to this file
    custom_map_path = file_util.resolve_path(
        __file__, '../data/map/soccer/soccer_21x14_goal_6.tmx')

    # Create the environment options and the renderer options
    environment_options = soccer_environment.SoccerEnvironmentOptions(
        map_path=custom_map_path, team_size=5, ai_frame_skip=2)
    display_options = soccer_renderer.RendererOptions(
        show_display=True, max_fps=60, enable_key_events=True)

    # Create a soccer environment and take the renderer wrapped in it
    env = soccer_environment.SoccerEnvironment(
        environment_options, renderer_options=display_options)
    env_renderer = env.renderer

    # Initialize the renderer
    env_renderer.load()

    # Keep rendering for as long as the renderer returns a truthy value
    while env_renderer.render():
        pass
42 |
43 |
44 | if __name__ == '__main__':
45 | main()
46 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [aliases]
2 | test=pytest
3 |
4 | [tool:pytest]
5 | addopts=-rsxX -q
6 | testpaths=tests
7 | python_files=test_*.py
8 | python_classes=*Test
9 | python_functions=test_*
10 |
11 | [coverage:report]
12 | # Regexes for lines to exclude from consideration
13 | exclude_lines =
14 | # Have to re-enable the standard pragma
15 | pragma: no cover
16 |
17 | # Don't complain about missing debug-only code:
18 | def __repr__
19 | if self\.debug
20 |
21 | # Don't complain if tests don't hit defensive assertion code:
22 | raise AssertionError
23 | raise NotImplementedError
24 |
25 | # Don't complain if non-runnable code isn't run:
26 | if 0:
27 | if __name__ == .__main__.:
28 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Sample from https://github.com/pypa/sampleproject/blob/master/setup.py
3 | """
4 |
5 | # Native modules
6 | from setuptools import setup, find_packages
7 |
8 |
setup(
    name='pygame_rl',

    # Versions should comply with PEP440. For a discussion on single-sourcing
    # the version across setup.py and the project code, see
    # https://packaging.python.org/en/latest/single_source_version.html
    version='1.0.0',

    description='Game environment for reinforcement learning using Pygame',

    # The project's main homepage.
    url='https://github.com/ebola777/pygame-rl',

    # Author details
    author='Shawn Chang',
    author_email='ebola777@yahoo.com.tw',

    # Choose your license
    license='MIT',

    # See https://pypi.python.org/pypi?%3Aaction=list_classifiers
    classifiers=[
        # How mature is this project? Common values are
        #   3 - Alpha
        #   4 - Beta
        #   5 - Production/Stable
        'Development Status :: 3 - Alpha',

        # Indicate who your project is intended for
        'Intended Audience :: Developers',
        'Topic :: Software Development :: Build Tools',

        # Pick your license as you wish (should match "license" above)
        'License :: OSI Approved :: MIT License',

        # Specify the Python versions you support here. In particular, ensure
        # that you indicate whether you support Python 2, Python 3 or both.
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.3',
        'Programming Language :: Python :: 3.4',
        'Programming Language :: Python :: 3.5',
    ],

    # What does your project relate to?
    keywords='pygame reinforcement-learning',

    # You can just specify the packages manually here if your project is
    # simple. Or you can use find_packages().
    packages=find_packages(exclude=['contrib', 'docs', 'tests']),

    # Alternatively, if you want to distribute just a my_module.py, uncomment
    # this:
    #   py_modules=["my_module"],

    # List run-time dependencies here. These will be installed by pip when
    # your project is installed. For an analysis of "install_requires" vs pip's
    # requirements files see:
    # https://packaging.python.org/en/latest/requirements.html
    install_requires=[
        'gym',
        'numpy',
        'pygame==1.9.6',
        'pypaths==0.1.2',
        'pytmx==3.21.7',
        'pyyaml',
    ],

    # List additional groups of dependencies here (e.g. development
    # dependencies). You can install these using the following syntax,
    # for example:
    # $ pip install -e .[dev,test]
    extras_require={
        'dev': [
            'scipy',
        ],
        'test': [
            'coverage',
            'pytest-runner',
            'pytest',
            'snakeviz',
        ],
    },

    # If there are data files included in your packages that need to be
    # installed, specify them here. If using Python 2.6 or less, then these
    # have to be included in MANIFEST.in as well.
    package_data={
        # The explicit patterns cover both data directories of the package
        # ("data/map/<scenario>/" and "data/tileset/"). They are needed
        # because setuptools releases without recursive glob support treat
        # "**" like "*", which skips the files nested under "data/map/".
        'pygame_rl': [
            'data/**/*',
            'data/map/*/*',
            'data/tileset/*',
        ],
    },

    # Although 'package_data' is the preferred approach, in some case you may
    # need to place data files outside of your packages. See:
    # http://docs.python.org/3.4/distutils/setupscript.html#installing-additional-files # noqa
    # In this case, 'data_file' will be installed into '/my_data'
    data_files=[],

    # To provide executable scripts, use entry points in preference to the
    # "scripts" keyword. Entry points provide cross-platform support and allow
    # pip to create the appropriate form of executable for the target platform.
    entry_points={
        'console_scripts': [
        ],
    },
)
113 |
--------------------------------------------------------------------------------
/tests/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Third-party modules
4 | import pytest
5 |
# Run all the tests only when executed as a script, and hand pytest's exit
# code to the caller so that failing tests produce a non-zero exit status
if __name__ == '__main__':
    raise SystemExit(pytest.main())
8 |
--------------------------------------------------------------------------------
/tests/test_file_uitl.py:
--------------------------------------------------------------------------------
1 | # Native modules
2 | import os
3 |
4 | # Testing targets
5 | import pygame_rl.util.file_util as file_util
6 |
7 |
class FileUtilTest(object):
    """Tests for the helpers of ``pygame_rl.util.file_util``."""

    def test_get_resource_path(self):
        """The resource path contains the normalized resource name."""
        name = 'pygame_rl/data/map/soccer/soccer.tmx'
        located = file_util.get_resource_path(name)
        assert os.path.normpath(name) in located

    def test_read_yaml(self):
        """Reading a packaged YAML file yields non-empty contents."""
        yaml_path = file_util.get_resource_path(
            'pygame_rl/data/map/soccer/agent_sprite.yaml')
        assert len(file_util.read_yaml(yaml_path)) > 0

    def test_resolve_path(self):
        """The second path is resolved against the directory of the first."""
        base, relative = 'dir1/file1', 'dir2/file2'
        wanted = os.path.normpath('dir1/dir2/file2')
        assert wanted == file_util.resolve_path(base, relative)
26 |
--------------------------------------------------------------------------------
/tests/test_soccer_environment.py:
--------------------------------------------------------------------------------
1 | # Test modules
2 | import pytest
3 |
4 | # Testing targets
5 | import pygame_rl.scenario.soccer_environment as soccer_environment
6 |
7 |
8 | class SoccerEnvironmentTest(object):
9 | env = None
10 |
11 | @classmethod
12 | def setup_class(cls):
13 | cls.env = soccer_environment.SoccerEnvironment()
14 |
15 | def test_init(self):
16 | # The soccer positions should be non-empty
17 | assert len(self.env.map_data.walkable) > 0
18 | assert len(self.env.map_data.goals) > 0
19 |
20 | def test_reset(self):
21 | # Reset the environment
22 | observation = self.env.reset()
23 | # Get the state
24 | state = self.env.state
25 | # Check the initial observation
26 | assert observation.state == state
27 | assert observation.action is None
28 | assert observation.reward == pytest.approx(0.0)
29 | assert observation.next_state is None
30 | # The initial state should not be terminal
31 | assert not state.is_terminal()
32 | # The player agent positions should be on the left to the computer agent
33 | player_pos = state.get_agent_pos(0)
34 | computer_pos = state.get_agent_pos(1)
35 | assert player_pos[0] < computer_pos[0]
36 | # Either agent has the ball
37 | player_has_ball = state.get_agent_ball(0)
38 | computer_has_ball = state.get_agent_ball(1)
39 | assert player_has_ball != computer_has_ball
40 | # The agents should have random mode
41 | assert state.get_agent_mode(0) in self.env.modes
42 | assert state.get_agent_mode(1) in self.env.modes
43 | # The agents should have the standing action in the beginning
44 | assert state.get_agent_action(0) == self.env.actions[-1]
45 | assert state.get_agent_action(1) == self.env.actions[-1]
46 | # The agents should set the frame skip index to 0
47 | assert state.get_agent_frame_skip_index(0) == 0
48 | assert state.get_agent_frame_skip_index(1) == 0
49 | # The time step should be 0
50 | assert state.time_step == 0
51 |
52 | def test_take_cached_action(self):
53 | # Take each action
54 | team_agent_index = 0
55 | agent_index = self.env.get_agent_index(
56 | self.env.team_names[0], team_agent_index)
57 | expected_time_step = 0
58 | for action in self.env.actions:
59 | # Take the action
60 | self.env.take_cached_action(agent_index, action)
61 | # Update the state
62 | observation = self.env.update_state()
63 | # Get the next state
64 | next_state = observation.next_state
65 | # Increment the expected time step
66 | expected_time_step += 1
67 | # Check the observation
68 | assert observation.state is None
69 | assert observation.action[agent_index] == action
70 | assert observation.reward >= -1.0 and observation.reward <= 1.0
71 | assert observation.next_state.time_step == expected_time_step
72 | # The computer agent should have the last taken action
73 | assert next_state.get_agent_action(1) in self.env.actions
74 |
75 | def test_renderer(self):
76 | self.env.render()
77 | # The renderer should contain the environment
78 | assert self.env.renderer.env == self.env
79 | # The renderer display should have been quitted
80 | assert self.env.renderer.display_quitted
81 |
def test_get_moved_pos(self):
    """Check the position produced by each action when starting at [0, 0]."""
    origin = [0, 0]
    # Each action paired with the position it should yield; standing is
    # compared against the starting position object itself
    expected_moves = (
        ('MOVE_RIGHT', [1, 0]),
        ('MOVE_UP', [0, -1]),
        ('MOVE_LEFT', [-1, 0]),
        ('MOVE_DOWN', [0, 1]),
        ('STAND', origin),
    )
    for action, expected_pos in expected_moves:
        assert self.env.get_moved_pos(origin, action) == expected_pos
90 |
def test_get_pos_distance(self):
    """Check the distance between two positions using a 3-4-5 triangle."""
    start, end = [0, 0], [3, 4]
    distance = self.env.get_pos_distance(start, end)
    # The Euclidean distance between the two positions is 5
    assert distance == pytest.approx(5.0)
96 |
97 |
class SoccerLegacyEnvironmentTest(SoccerEnvironmentTest):
    """Run the inherited environment tests against the legacy environment."""

    @classmethod
    def setup_class(cls):
        """Replace the shared environment with a legacy environment."""
        legacy_env = soccer_environment.SoccerLegacyEnvironment()
        cls.env = legacy_env
102 |
--------------------------------------------------------------------------------
/tests/test_soccer_environment_scenarios.py:
--------------------------------------------------------------------------------
1 | # Native modules
2 | import random
3 |
4 | # Third-party modules
5 | import pytest
6 |
7 | # Testing targets
8 | import pygame_rl.scenario.soccer_environment as soccer_environment
9 |
10 |
class SoccerEnvironmentTest(object):
    """Scenario tests for a soccer environment with two agents per team.

    Each scenario pins the agent positions, the agent modes, and the ball
    possession on the shared state, advances the environment, and checks the
    resulting positions or reward. Every scenario runs for 100 random seeds.
    """
    env = None
    state = None
    player_index = None
    computer_index = None

    @classmethod
    def setup_class(cls):
        """Create the shared environment and cache its state and indexes."""
        team_size = 2
        # Initialize the environment
        options = soccer_environment.SoccerEnvironmentOptions(
            team_size=team_size)
        cls.env = soccer_environment.SoccerEnvironment(env_options=options)
        # Keep a handle on the environment state
        cls.state = cls.env.state
        # Look up the agent index of every team member
        cls.player_index = [
            cls.env.get_agent_index('PLAYER', team_index)
            for team_index in range(team_size)]
        cls.computer_index = [
            cls.env.get_agent_index('COMPUTER', team_index)
            for team_index in range(team_size)]

    def _place_agents(self, player_pos, computer_pos):
        """Set the positions of the player agents, then the computer agents."""
        for agent_index, pos in zip(self.player_index, player_pos):
            self.state.set_agent_pos(agent_index, pos)
        for agent_index, pos in zip(self.computer_index, computer_pos):
            self.state.set_agent_pos(agent_index, pos)

    def _set_modes(self, computer_mode, player_2_mode=None):
        """Set the mode of player 2 (when given) and of both computer agents."""
        if player_2_mode is not None:
            self.state.set_agent_mode(self.player_index[1], player_2_mode)
        for agent_index in self.computer_index:
            self.state.set_agent_mode(agent_index, computer_mode)

    def _give_ball_to(self, agent_index):
        """Hand the ball from its current holder to the given agent."""
        holder_index = self.state.get_ball_possession()['agent_index']
        self.state.switch_ball(holder_index, agent_index)

    def _step(self, action):
        """Take the action with player 1 and return the update observation."""
        self.env.take_cached_action(self.player_index[0], action)
        return self.env.update_state()

    @pytest.mark.parametrize('seed', range(100))
    def test_adjacent(self, seed):
        """Check the reachable positions when all agents start adjacent."""
        random.seed(seed)
        self._place_agents([[6, 0], [6, 1]], [[7, 0], [7, 1]])
        self._set_modes('OFFENSIVE', player_2_mode='OFFENSIVE')
        # Player 1 holds the ball and stands still
        self._give_ball_to(self.player_index[0])
        self._step('STAND')
        player_1, player_2 = self.player_index
        computer_1, computer_2 = self.computer_index
        # Player 1 should not have moved
        assert self.state.get_agent_pos(player_1) == [6, 0]
        # Player 2 should have either stayed or moved, but not up
        allowed = [[6, 1], [5, 1], [6, 2], [7, 1]]
        assert self.state.get_agent_pos(player_2) in allowed
        # Computer 1 should have either stayed or swapped with computer 2
        allowed = [[7, 0], [7, 1]]
        assert self.state.get_agent_pos(computer_1) in allowed
        # Computer 2 should have either stayed, swapped with computer 1, or
        # swapped with player 2
        allowed = [[7, 0], [7, 1], [6, 1]]
        assert self.state.get_agent_pos(computer_2) in allowed
        # Computer 2 can't have the ball
        holder_index = self.state.get_ball_possession()['agent_index']
        assert holder_index != computer_2

    @pytest.mark.parametrize('seed', range(100))
    def test_avoid_opponent(self, seed):
        """Check the moves of defensive agents while computer 1 has the ball."""
        random.seed(seed)
        self._place_agents([[1, 2], [1, 3]], [[7, 2], [7, 3]])
        self._set_modes('DEFENSIVE', player_2_mode='DEFENSIVE')
        # Computer 1 holds the ball while player 1 moves right
        self._give_ball_to(self.computer_index[0])
        self._step('MOVE_RIGHT')
        # Player 2 should approach the nearest opponent goal position against
        # computer 1
        allowed = [[1, 2], [0, 3]]
        assert self.state.get_agent_pos(self.player_index[1]) in allowed
        # Computer 1 should avoid player 1
        assert self.state.get_agent_pos(self.computer_index[0]) == [8, 2]
        # Computer 2 should approach the nearest opponent goal position against
        # player 2
        assert self.state.get_agent_pos(self.computer_index[1]) == [8, 3]

    @pytest.mark.parametrize('seed', range(100))
    def test_advance_to_goal(self, seed):
        """Check the moves of offensive agents while computer 1 has the ball."""
        random.seed(seed)
        self._place_agents([[1, 2], [1, 3]], [[7, 2], [7, 3]])
        self._set_modes('OFFENSIVE', player_2_mode='OFFENSIVE')
        # Computer 1 holds the ball while player 1 moves right
        self._give_ball_to(self.computer_index[0])
        self._step('MOVE_RIGHT')
        # Player 2 should intercept against computer 1
        assert self.state.get_agent_pos(self.player_index[1]) == [2, 3]
        # Computer 1 should approach the furthest goal position against
        # player 1
        assert self.state.get_agent_pos(self.computer_index[0]) == [6, 2]
        # Computer 2 should intercept against player 2
        assert self.state.get_agent_pos(self.computer_index[1]) == [6, 3]

    @pytest.mark.parametrize('seed', range(100))
    def test_defend_goal(self, seed):
        """Check defensive computer agents while player 1 has the ball."""
        random.seed(seed)
        self._place_agents([[1, 2], [1, 3]], [[7, 2], [7, 3]])
        self._set_modes('DEFENSIVE', player_2_mode='OFFENSIVE')
        # Player 1 holds the ball and moves right
        self._give_ball_to(self.player_index[0])
        self._step('MOVE_RIGHT')
        # Player 2 should intercept against computer 1
        assert self.state.get_agent_pos(self.player_index[1]) == [2, 3]
        # Computer 1 should approach the nearest opponent goal position against
        # player 1
        assert self.state.get_agent_pos(self.computer_index[0]) == [8, 2]
        # Computer 2 should approach the nearest opponent goal position against
        # player 1
        allowed = [[7, 2], [8, 3]]
        assert self.state.get_agent_pos(self.computer_index[1]) in allowed

    @pytest.mark.parametrize('seed', range(100))
    def test_intercept_goal(self, seed):
        """Check offensive computer agents while player 1 has the ball."""
        random.seed(seed)
        self._place_agents([[1, 2], [1, 3]], [[7, 2], [7, 3]])
        self._set_modes('OFFENSIVE', player_2_mode='DEFENSIVE')
        # Player 1 holds the ball and moves right
        self._give_ball_to(self.player_index[0])
        self._step('MOVE_RIGHT')
        # Player 2 should approach the nearest opponent goal position against
        # computer 1
        assert self.state.get_agent_pos(self.player_index[1]) == [0, 3]
        # Computer 1 should intercept against player 1
        assert self.state.get_agent_pos(self.computer_index[0]) == [6, 2]
        # Computer 2 should intercept against player 1
        assert self.state.get_agent_pos(self.computer_index[1]) == [6, 3]

    @pytest.mark.parametrize('seed', range(100))
    def test_negative_reward(self, seed):
        """Check that the reward is -1 once the computer team wins."""
        random.seed(seed)
        self._place_agents([[1, 2], [7, 0]], [[3, 2], [3, 3]])
        self._set_modes('OFFENSIVE', player_2_mode='DEFENSIVE')
        self._give_ball_to(self.player_index[0])
        # The computer team should score within 100 steps
        for _ in range(100):
            # Player 1 stands still on every step
            observation = self._step('STAND')
            # Teleport player 2 back to the original position so that it can
            # never catch the ball
            self.state.set_agent_pos(self.player_index[1], [7, 0])
            if observation.next_state.is_team_win('COMPUTER'):
                break
        assert observation.reward == pytest.approx(-1.0)

    @pytest.mark.parametrize('seed', range(100))
    def test_positive_reward(self, seed):
        """Check that the reward is 1 once the player team wins."""
        random.seed(seed)
        self._place_agents([[5, 2], [5, 3]], [[3, 2], [3, 3]])
        # Only the computer agent modes are set in this scenario
        self._set_modes('OFFENSIVE')
        self._give_ball_to(self.player_index[0])
        # The player team should score in exactly 3 steps
        for _ in range(3):
            observation = self._step('MOVE_RIGHT')
        assert observation.next_state.is_team_win('PLAYER')
        assert observation.reward == pytest.approx(1.0)
242 |
--------------------------------------------------------------------------------
/tests/test_soccer_renderer.py:
--------------------------------------------------------------------------------
1 | # Third-party modules
2 | import pygame
3 |
4 | # Testing targets
5 | import pygame_rl.scenario.soccer_environment as soccer_environment
6 |
7 |
class SoccerRendererTest(object):
    """Tests for the renderer of a soccer environment built with no options."""
    renderer = None

    @classmethod
    def setup_class(cls):
        """Create an environment and keep a reference to its renderer."""
        cls.renderer = soccer_environment.SoccerEnvironment().renderer

    def test_load(self):
        """Load the renderer and check the types of its attributes."""
        renderer = self.renderer
        renderer.load()
        assert isinstance(renderer.static_overlays, dict)
        assert renderer.clock is not None
        assert isinstance(renderer.screen, pygame.Surface)
        assert isinstance(renderer.background, pygame.Surface)
        assert isinstance(renderer.dirty_groups, pygame.sprite.RenderUpdates)

    def test_render(self):
        """Render once and check the renderer flags and sprite count."""
        renderer = self.renderer
        # The renderer should indicate to continue
        assert renderer.render()
        # The renderer should report that the display has been quit
        assert renderer.display_quitted
        # The dirty group should contain exactly 2 sprites
        assert len(renderer.dirty_groups.sprites()) == 2

    def test_get_screenshot(self):
        """Check the shape of the full screenshot against the display size."""
        width, height = (self.renderer.get_display_size()[axis]
                         for axis in (0, 1))
        screenshot = self.renderer.get_screenshot()
        # The screenshot should have swapped axes and 3 channels
        assert screenshot.shape == (height, width, 3)

    def test_get_po_screenshot(self):
        """Check the partially observable screenshot shape for several radii."""
        tile_size = self.renderer.get_tile_size()
        agent_index = 0
        # A radius of r should cover (2r + 1) tiles along each axis
        for radius in (0, 1, 2, 10):
            side = 2 * radius + 1
            po_screenshot = self.renderer.get_po_screenshot(
                agent_index, radius)
            expected_shape = (side * tile_size[1], side * tile_size[0], 3)
            assert po_screenshot.shape == expected_shape
67 |
--------------------------------------------------------------------------------