├── .gitignore ├── LICENSE ├── README.md ├── generate_graph.py ├── media └── pyreason-gym-demo.gif ├── pyreason_gym ├── __init__.py ├── envs │ ├── __init__.py │ ├── grid_world.py │ └── map_world.py ├── pyreason_grid_world │ ├── __init__.py │ ├── graph │ │ └── game_graph.graphml │ ├── pyreason_grid_world.py │ └── yamls │ │ └── rules.yaml └── pyreason_map_world │ ├── __init__.py │ ├── graph │ ├── .gitignore │ ├── map_graph.graphml │ ├── map_graph_clustering.graphml │ └── map_graph_landmark.graphml │ ├── pyreason_map_world.py │ └── yamls │ └── rules.yaml ├── setup.py ├── test.py └── tests ├── agent_criss_cross ├── gen_graph_test.py └── test.py ├── agent_same_location_shoot ├── gen_graph_test.py └── test.py ├── base_done_check ├── gen_graph_test.py └── test.py ├── bulk_movement_test ├── gen_graph_test.py └── test.py ├── bullet_pass_through ├── bullet_pass_through_gen_graph_test.py └── bullet_pass_through_test.py ├── bullet_speed ├── bullet_speed_gen_graph_test.py └── bullet_speed_test.py ├── bullet_speed_large ├── bullet_speed_large_gen_graph_test_.py └── bullet_speed_large_test.py ├── follow_bullet ├── follow_bullet_gen_graph_test.py └── follow_bullet_test.py ├── health_check ├── gen_graph_test.py └── test.py ├── multi_agent_done_check ├── gen_graph_test.py └── test.py ├── multi_agent_eval ├── gen_graph_test.py └── test.py ├── multi_agent_shoot ├── gen_graph_test.py └── test.py ├── multi_agent_shoot_reappear ├── gen_graph_test.py └── test.py ├── multi_friendly_shoot ├── gen_graph_test.py └── test.py ├── nop_bullet_freeze ├── graph_gen.py └── test_case.py ├── observation_rgb ├── gen_graph_test.py └── test.py ├── obstacle_shooting ├── obstacle_shooting_gen_graph_test.py └── obstacle_shooting_test.py ├── out_bounds_shooting ├── out_bounds_shooting_gen_graph_test.py └── out_bounds_shooting_test.py ├── random_action_sample ├── gen_graph_test.py └── test.py ├── repetitive_shooting ├── graph_gen.py └── test_case.py └── same_location_shoot_all_dir ├── gen_graph_test.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .idea/ 132 | .DS_Store 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Lab V2 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyReason Gym 🏋 2 | An OpenAI gym wrapper for PyReason to use in a reinforcement learning Grid World setting. 3 | 4 | 5 | ![Grid World Demo](media/pyreason-gym-demo.gif) 6 | 7 | 8 | ## Table of Contents 9 | 10 | * [Getting Started](#getting-started) 11 | * [The Setting](#the-setting) 12 | * [The Actions](#the-actions) 13 | * [The Objective](#the-objective) 14 | * [Rewards](#rewards) 15 | * [Installation](#installation) 16 | * [Usage](#usage) 17 | * [Actions](#actions) 18 | * [Observations](#observations) 19 | * [Render Modes](#render-modes) 20 | * [Other Options](#other-options) 21 | * [Contributing](#contributing) 22 | * [Bibtex](#bibtex) 23 | * [License](#License) 24 | * [Contact](#contact) 25 | 26 | ## Getting Started 27 | This is an OpenAI Gym environment for reinforcement learning in a grid world setting using [PyReason](https://github.com/lab-v2/pyreason) as a simulator. 28 | 29 | ### The Setting 30 | 1. There are two teams: Red and Blue 31 | 2. There are two bases: Red Base and Blue Base 32 | 3. There are a certain number of agents in each team 33 | 34 | ### The Actions 35 | There are 9 actions an agent can take: 36 | 37 | 1. Move Up 38 | 2. Move Down 39 | 3. Move Left 40 | 4. Move Right 41 | 5. Shoot Up 42 | 6. Shoot Down 43 | 7. Shoot Left 44 | 8. Shoot Right 45 | 9. Do Nothing 46 | 47 | ### The Objective 48 | The objecive of the game is to kill all enemy agents or make their `health=0`. The game will terminate (or signal `done=True` when this happens). This objective can be changed in the `is_done()` function in [`grid_world.py`](./pyreason_gym/envs/grid_world.py) to determine when the game should be over. 49 | 50 | ### Rewards 51 | **The reward function is currently not defined** A Reward of `0` is given at each step. You can modify this in the `_get_rew` function in [`grid_world.py`](./pyreason_gym/envs/grid_world.py) 52 | 53 | ## Installation 54 | Make sure `pyreason==1.5.1` has been installed using the instructions found [here](https://github.com/lab-v2/pyreason#21-install-as-a-python-library) 55 | 56 | Clone the repository, and install: 57 | ```bash 58 | git clone https://github.com/lab-v2/pyreason-gym 59 | pip install -e pyreason-gym 60 | ``` 61 | **NOTE:** Do not install this package using `setup.py`--this will not work. Use the instructions above to install. 62 | 63 | ## Usage 64 | To run the environment and get a feel for things you can run the [`test.py`](./test.py) file which will perform random actions in the grid world for 50 steps. 65 | ```bash 66 | python test.py 67 | ``` 68 | 69 | This Grid World scenario needs a graph in GraphML format to run. A graph file has **already been generated** in the [graphs folder](./pyreason_gym/pyreason_grid_world/graph/). However if you wish to change certain parameters such as 70 | 71 | 1. Number of agents per team 72 | 2. Start locations of the agents 73 | 3. Obstacle locations in the grid 74 | 4. The Grid World size (height, width) 75 | 5. The locations of the Red and Blue bases 76 | 77 | You will need to re-generate the graph file using the [`generate_graph.py`](./generate_graph.py) script by changing the appropriate parameters. This will generate the graph in the appropriate location for PyReason to find. NOTE: This is optional if you just want to try out the package--you can use the graph file already provided. 78 | 79 | This is an OpenAI Gym custom environment. More on OpenAI Gym: 80 | 81 | 1. [Documentation](https://www.gymlibrary.dev/) 82 | 2. [GitHub Repo](https://github.com/openai/gym) 83 | 84 | The interface is just like a normal Gym environment. To create an environment and start using it, insert the following into your Python script. Make sure you've [Installed](#installation) this package before this. 85 | 86 | ```python 87 | import gym 88 | import pyreason_gym 89 | 90 | env = gym.make('PyReasonGridWorld-v0') 91 | 92 | # Reset the environment 93 | obs, _ = env.reset() 94 | 95 | # Take a random action and get observation, rewards, done signal etc. 96 | # This will sample a random action from the action space of the environment 97 | action = env.action_space.sample() 98 | obs, rew, done, _, _ = env.step(action) 99 | 100 | # Keep using `env.step(action)` and `env.reset()` to get observations and run the grid world game. 101 | ``` 102 | 103 | A Tutorial on how to interact with gym environments can be found [here](https://www.gymlibrary.dev/) 104 | 105 | ### Actions 106 | The action space is currently a list for each team with discrete numbers representing each action: 107 | 108 | 1. Move Up is represented by `0` 109 | 2. Move Down is represented by `1` 110 | 3. Move Left is represented by `2` 111 | 4. Move Right is represented by `3` 112 | 5. Shoot Up is represented by `4` 113 | 6. Shoot Down is represented by `5` 114 | 7. Shoot Left is represented by `6` 115 | 8. Shoot Right is represented by `7` 116 | 9. Do Nothing is represented by `8` 117 | 118 | A sample action with `1` agent per team is of the form: 119 | ```python 120 | # Sample action. The list will increase with the number of agents per team 121 | action = { 122 | 'red_team': [0], 123 | 'blue_team': [2] 124 | } 125 | 126 | # Send the action to the environment 127 | obs, rew, done, _, _ = env.step(action) 128 | ``` 129 | 130 | ### Observations 131 | Observations contain information about each player's position in the grid (`[x,y]`), their `health` as well as blue and red `bullet` information including the position of the bullet in the grid (`[x,y]`) and its direction. 132 | A sample observation with `1` agent per team is a dictionary of the form: 133 | 134 | ```python 135 | observation = { 136 | 'red_team': [{'pos': [1,3], 'health': [1]}], 137 | 'blue_team': [{'pos': [7,2], 'health': [1]}], 138 | 'red_bullets': [{'pos': [2,3], 'dir': 1}, {'pos': [5,3], 'dir': 3}], 139 | 'blue_bullets': [{'pos': [7,1], 'dir': 2}] 140 | } 141 | ``` 142 | Information about agent positions, health, bullet positions and direction can be extracted from this observation space. 143 | 144 | ### Render Modes 145 | There are a few render modes supported: 146 | 147 | 1. `human` - Creates a PyGame visualization of the grid world and actions 148 | 2. `None` - No rendering, interaction only through actions and observations 149 | 3. `rgb_array` - An RGB array of the screen that would have been displayed using `render_mode='human'`. This can be used alongside CNNs etc. 150 | 151 | These can be used when creating the environment: 152 | ```python 153 | env = gym.make('PyReasonGridWorld-v0', render_mode='human') 154 | # Or 155 | env = gym.make('PyReasonGridWorld-v0', render_mode=None) 156 | # Or 157 | env = gym.make('PyReasonGridWorld-v0', render_mode='rgb_array') 158 | ``` 159 | If you're using `render_mode='rgb_array` you have to call `env.render(observation)` after `observation = env.step()` to get the rgb data. 160 | 161 | 162 | ### Other Options 163 | If you've generated the graph using the `generate_graph.py` script with a custom `grid size` and custom `number of agents per team`, you can pass these parameters to the grid world while creating the environment: 164 | ```python 165 | env = gym.make('PyReasonGridWorld-v0', grid_size=8, num_agents_per_team=1) 166 | ``` 167 | 168 | ## Contributing 169 | Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change. 170 | 171 | ## Bibtex 172 | If you used this software in your work please cite our paper 173 | ``` 174 | @inproceedings{aditya_pyreason_2023, 175 | title = {{PyReason}: Software for Open World Temporal Logic}, 176 | booktitle = {{AAAI} Spring Symposium}, 177 | author = {Aditya, Dyuman and Mukherji, Kaustuv and Balasubramanian, Srikar and Chaudhary, Abhiraj and Shakarian, Paulo}, 178 | year = {2023}} 179 | ``` 180 | 181 | ## License 182 | This repository is licensed under [BSD-3-Clause](./LICENSE) 183 | 184 | ## Contact 185 | Dyuman Aditya - dyuman.aditya@asu.edu 186 | 187 | Kaustuv Mukherji - kmukherji@asu.edu 188 | 189 | Paulo Shakarian - pshak02@asu.edu 190 | -------------------------------------------------------------------------------- /generate_graph.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | generate_graph(grid_dim=8, num_agents_per_team=1, base_loc=[7, 56], start_loc=[[7], [56]], 100 | obstacle_loc=[26, 27, 34, 35, 36, 44]) 101 | 102 | 103 | if __name__ == '__main__': 104 | main() 105 | -------------------------------------------------------------------------------- /media/pyreason-gym-demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-v2/pyreason-gym/90cb05a1e886c54e55b54472499f350ec691d823/media/pyreason-gym-demo.gif -------------------------------------------------------------------------------- /pyreason_gym/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id='PyReasonGridWorld-v0', 5 | entry_point='pyreason_gym.envs:GridWorldEnv' 6 | ) 7 | 8 | register( 9 | id='PyReasonMapWorld-v0', 10 | entry_point='pyreason_gym.envs:MapWorldEnv' 11 | ) 12 | -------------------------------------------------------------------------------- /pyreason_gym/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from pyreason_gym.envs.grid_world import GridWorldEnv 2 | from pyreason_gym.envs.map_world import MapWorldEnv 3 | -------------------------------------------------------------------------------- /pyreason_gym/envs/grid_world.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | import numpy as np 4 | import pygame 5 | 6 | from pyreason_gym.pyreason_grid_world.pyreason_grid_world import PyReasonGridWorld 7 | 8 | 9 | class GridWorldEnv(gym.Env): 10 | metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4} 11 | 12 | def __init__(self, grid_size=8, num_agents_per_team=1, render_mode=None, graph=None, rules=None): 13 | """Initialize grid world 14 | 15 | :param render_mode: whether to render in human viewable format or not, defaults to None 16 | :param grid_size: size of the grid world square, defaults to 8 17 | :param num_agents_per_team: number of agents in each team, defaults to 1 18 | :param render_mode: how to render the environment, defaults to None 19 | """ 20 | super(GridWorldEnv, self).__init__() 21 | 22 | self.grid_size = grid_size 23 | self.render_mode = render_mode 24 | self.window_size = 512 25 | 26 | # Initialize the PyReason gridworld 27 | self.pyreason_grid_world = PyReasonGridWorld(grid_size, num_agents_per_team, graph, rules) 28 | 29 | # Get the position of obstacles for the render function 30 | self.obstacle_positions = None 31 | self.base_positions = None 32 | 33 | # The observation space consists of the positions of the agents as well as their state (health etc.) 34 | # It also contains information about bullet positions as well as direction 35 | # Length of the sequence = num_agents_per_team 36 | self.observation_space = spaces.Dict( 37 | { 38 | 'red_team': spaces.Sequence(spaces.Dict({'pos': spaces.Box(0, grid_size-1, shape=(2,), dtype=int), 'health': spaces.Box(0, 1, dtype=np.float32), 'killed': spaces.Sequence(spaces.Discrete(num_agents_per_team+1))})), 39 | 'blue_team': spaces.Sequence(spaces.Dict({'pos': spaces.Box(0, grid_size-1, shape=(2,), dtype=int), 'health': spaces.Box(0, 1, dtype=np.float32), 'killed': spaces.Sequence(spaces.Discrete(num_agents_per_team+1))})), 40 | 'red_bullets': spaces.Sequence(spaces.Dict({'pos': spaces.Box(0, grid_size-1, shape=(2,), dtype=int), 'dir': spaces.Discrete(4)})), 41 | 'blue_bullets': spaces.Sequence(spaces.Dict({'pos': spaces.Box(0, grid_size-1, shape=(2,), dtype=int), 'dir': spaces.Discrete(4)})) 42 | } 43 | ) 44 | 45 | # We have 9 actions, corresponding to "up", "down", "left", "right", "shootUp", "shootDown", "shootLeft", "shootRight", "doNothing" 46 | self.action_space = spaces.Dict( 47 | { 48 | 'red_team': spaces.MultiDiscrete([9]*num_agents_per_team), 49 | 'blue_team': spaces.MultiDiscrete([9]*num_agents_per_team) 50 | } 51 | ) 52 | self.actions = {0: 'up', 1: 'down', 2: 'left', 3: 'right', 4: 'shootUp', 5: 'shootDown', 6: 'shootLeft', 7: 'shootRight'} 53 | self.current_observation = None 54 | 55 | assert render_mode is None or render_mode in self.metadata["render_modes"] 56 | 57 | # If human-rendering is used, `self.window` will be a reference 58 | # to the window that we draw to. `self.clock` will be a clock that is used 59 | # to ensure that the environment is rendered at the correct framerate in 60 | # human-mode. They will remain `None` until human-mode is used for the 61 | # first time. 62 | self.window = None 63 | self.clock = None 64 | 65 | def _get_obs(self): 66 | self.current_observation = self.pyreason_grid_world.get_obs() 67 | return self.current_observation 68 | 69 | def _get_info(self): 70 | return {} 71 | 72 | def _get_rew(self): 73 | return 0 74 | 75 | def reset(self, seed=None, options=None): 76 | """Resets the environment to the initial conditions 77 | 78 | :param seed: random seed if there is a random component, defaults to None 79 | :param options: defaults to None 80 | """ 81 | # We need the following line to seed self.np_random 82 | super().reset(seed=seed) 83 | 84 | self.pyreason_grid_world.reset() 85 | 86 | # Get the position of obstacles for the render function 87 | self.obstacle_positions = self.pyreason_grid_world.get_obstacle_locations() 88 | self.base_positions = self.pyreason_grid_world.get_base_locations() 89 | 90 | observation = self._get_obs() 91 | info = self._get_info() 92 | 93 | # Render if necessary 94 | if self.render_mode == "human": 95 | self._render_frame(observation) 96 | 97 | return observation, info 98 | 99 | def step(self, action): 100 | self.pyreason_grid_world.move(action) 101 | 102 | observation = self._get_obs() 103 | info = self._get_info() 104 | 105 | # Get reward 106 | rew = self._get_rew() 107 | 108 | # End of game 109 | done = self.is_done(observation) 110 | 111 | # Render if necessary 112 | if self.render_mode == "human": 113 | self._render_frame(observation) 114 | 115 | return observation, rew, done, False, info 116 | 117 | def render(self): 118 | if self.render_mode == "rgb_array": 119 | return self._render_frame(self.current_observation) 120 | 121 | def _render_frame(self, observation): 122 | if self.window is None and self.render_mode=="human": 123 | pygame.init() 124 | pygame.display.init() 125 | self.window = pygame.display.set_mode((self.window_size, self.window_size)) 126 | if self.clock is None and self.render_mode=="human": 127 | self.clock = pygame.time.Clock() 128 | 129 | canvas = pygame.Surface((self.window_size, self.window_size)) 130 | canvas.fill((255, 255, 255)) 131 | 132 | # The size of a single grid square in pixels 133 | pix_square_size = ( 134 | self.window_size / self.grid_size 135 | ) 136 | 137 | # First draw both bases 138 | pygame.draw.rect( 139 | canvas, 140 | (100, 0, 0), 141 | pygame.Rect( 142 | pix_square_size * self.to_pygame_coords(self.base_positions[0]), 143 | (pix_square_size, pix_square_size), 144 | ), 145 | ) 146 | pygame.draw.rect( 147 | canvas, 148 | (0, 0, 100), 149 | pygame.Rect( 150 | pix_square_size * self.to_pygame_coords(self.base_positions[1]), 151 | (pix_square_size, pix_square_size), 152 | ), 153 | ) 154 | 155 | # Draw the obstacles 156 | for i in self.obstacle_positions: 157 | triangle_coords = [pix_square_size * self.to_pygame_coords(i), pix_square_size * self.to_pygame_coords(i), pix_square_size * self.to_pygame_coords(i)] 158 | triangle_coords[0][0] += pix_square_size/2 159 | triangle_coords[1][1] += pix_square_size 160 | triangle_coords[2][0] += pix_square_size 161 | triangle_coords[2][1] += pix_square_size 162 | pygame.draw.polygon( 163 | canvas, 164 | (0, 0, 0), 165 | triangle_coords, 166 | ) 167 | 168 | # Draw the agents according to the observation 169 | for i in observation['red_team']: 170 | if i['health'][0] != 0: 171 | pos = self.to_pygame_coords(i['pos']) * pix_square_size 172 | pos += int(pix_square_size/2) 173 | # Draw circle and border 174 | pygame.draw.circle( 175 | canvas, 176 | (255, 0, 0), 177 | pos, 178 | pix_square_size/3, 179 | ) 180 | 181 | for i in observation['blue_team']: 182 | if i['health'][0] != 0: 183 | pos = self.to_pygame_coords(i['pos']) * pix_square_size 184 | pos += int(pix_square_size/2) 185 | # Draw circle and border 186 | pygame.draw.circle( 187 | canvas, 188 | (0, 0, 255), 189 | pos, 190 | pix_square_size/3, 191 | ) 192 | 193 | # Add active bullets to the grid (currently we don't display direction) 194 | direction_map = {0: 'up', 1: 'down', 2: 'left', 3: 'right'} 195 | for bullet in observation['red_bullets']: 196 | red_pos = bullet['pos'] 197 | red_dir = bullet['dir'] 198 | # Which dir the bullet should point 199 | if direction_map[red_dir] == 'up' or direction_map[red_dir] == 'down': 200 | idx = 1 201 | elif direction_map[red_dir] == 'left' or direction_map[red_dir] == 'right': 202 | idx = 0 203 | start_pos = self.to_pygame_coords(red_pos) * pix_square_size + int(pix_square_size/2) 204 | end_pos = self.to_pygame_coords(red_pos) * pix_square_size + int(pix_square_size/2) 205 | start_pos[idx] -= pix_square_size/5 206 | end_pos[idx] += pix_square_size/5 207 | pygame.draw.line( 208 | canvas, 209 | (255, 0, 0), 210 | start_pos, 211 | end_pos, 212 | 10 213 | ) 214 | 215 | # Draw triangles at the end of each bullet 216 | if direction_map[red_dir] == 'up': 217 | tri_1 = [start_pos[0], start_pos[1] - pix_square_size / 8] 218 | tri_2 = [start_pos[0] + pix_square_size / 8, start_pos[1]] 219 | tri_3 = [start_pos[0] - pix_square_size / 8, start_pos[1]] 220 | elif direction_map[red_dir] == 'down': 221 | tri_1 = [end_pos[0], end_pos[1] + pix_square_size / 8] 222 | tri_2 = [end_pos[0] + pix_square_size / 8, end_pos[1]] 223 | tri_3 = [end_pos[0] - pix_square_size / 8, end_pos[1]] 224 | elif direction_map[red_dir] == 'left': 225 | tri_1 = [start_pos[0] - pix_square_size / 8, start_pos[1]] 226 | tri_2 = [start_pos[0], start_pos[1] + pix_square_size / 8] 227 | tri_3 = [start_pos[0], start_pos[1] - pix_square_size / 8] 228 | elif direction_map[red_dir] == 'right': 229 | tri_1 = [end_pos[0] + pix_square_size / 8, end_pos[1]] 230 | tri_2 = [end_pos[0], end_pos[1] + pix_square_size / 8] 231 | tri_3 = [end_pos[0], end_pos[1] - pix_square_size / 8] 232 | 233 | pygame.draw.polygon( 234 | canvas, 235 | (255, 0, 0), 236 | (tri_1, tri_2, tri_3), 237 | ) 238 | 239 | for bullet in observation['blue_bullets']: 240 | blue_pos = bullet['pos'] 241 | blue_dir = bullet['dir'] 242 | # Which dir the bullet should point 243 | if direction_map[blue_dir] == 'up' or direction_map[blue_dir] == 'down': 244 | idx = 1 245 | elif direction_map[blue_dir] == 'left' or direction_map[blue_dir] == 'right': 246 | idx = 0 247 | start_pos = self.to_pygame_coords(blue_pos) * pix_square_size + int(pix_square_size/2) 248 | end_pos = self.to_pygame_coords(blue_pos) * pix_square_size + int(pix_square_size/2) 249 | start_pos[idx] -= pix_square_size / 5 250 | end_pos[idx] += pix_square_size / 5 251 | pygame.draw.line( 252 | canvas, 253 | (0, 0, 255), 254 | start_pos, 255 | end_pos, 256 | 10 257 | ) 258 | 259 | # Draw triangles at the end of each bullet 260 | if direction_map[blue_dir] == 'up': 261 | tri_1 = [start_pos[0], start_pos[1] - pix_square_size / 8] 262 | tri_2 = [start_pos[0] + pix_square_size / 8, start_pos[1]] 263 | tri_3 = [start_pos[0] - pix_square_size / 8, start_pos[1]] 264 | elif direction_map[blue_dir] == 'down': 265 | tri_1 = [end_pos[0], end_pos[1] + pix_square_size / 8] 266 | tri_2 = [end_pos[0] + pix_square_size / 8, end_pos[1]] 267 | tri_3 = [end_pos[0] - pix_square_size / 8, end_pos[1]] 268 | elif direction_map[blue_dir] == 'left': 269 | tri_1 = [start_pos[0] - pix_square_size / 8, start_pos[1]] 270 | tri_2 = [start_pos[0], start_pos[1] + pix_square_size / 8] 271 | tri_3 = [start_pos[0], start_pos[1] - pix_square_size / 8] 272 | elif direction_map[blue_dir] == 'right': 273 | tri_1 = [end_pos[0] + pix_square_size / 8, end_pos[1]] 274 | tri_2 = [end_pos[0], end_pos[1] + pix_square_size / 8] 275 | tri_3 = [end_pos[0], end_pos[1] - pix_square_size / 8] 276 | 277 | pygame.draw.polygon( 278 | canvas, 279 | (0, 0, 255), 280 | (tri_1, tri_2, tri_3), 281 | ) 282 | 283 | # Finally, add some gridlines 284 | for x in range(self.grid_size + 1): 285 | pygame.draw.line( 286 | canvas, 287 | 0, 288 | (0, pix_square_size * x), 289 | (self.window_size, pix_square_size * x), 290 | width=3, 291 | ) 292 | pygame.draw.line( 293 | canvas, 294 | 0, 295 | (pix_square_size * x, 0), 296 | (pix_square_size * x, self.window_size), 297 | width=3, 298 | ) 299 | 300 | if self.render_mode == "human": 301 | # The following line copies our drawings from `canvas` to the visible window 302 | self.window.blit(canvas, canvas.get_rect()) 303 | pygame.event.pump() 304 | pygame.display.update() 305 | 306 | # We need to ensure that human-rendering occurs at the predefined framerate. 307 | # The following line will automatically add a delay to keep the framerate stable. 308 | self.clock.tick(self.metadata["render_fps"]) 309 | elif self.render_mode == 'rgb_array': 310 | return np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)) 311 | 312 | def close(self): 313 | if self.window is not None: 314 | pygame.display.quit() 315 | pygame.quit() 316 | 317 | def is_done(self, observation): 318 | # End the game when the health goes to zero of an entire team 319 | red_end = True 320 | blue_end = True 321 | for i in observation['red_team']: 322 | if i['health'] != 0: 323 | red_end = False 324 | 325 | for i in observation['blue_team']: 326 | if i['health'] != 0: 327 | blue_end = False 328 | 329 | return red_end or blue_end 330 | 331 | def to_pygame_coords(self, coords): 332 | """Convert coordinates into pygame coordinates (lower-left => top left).""" 333 | return np.array([coords[0], self.grid_size - 1 - coords[1]]) 334 | -------------------------------------------------------------------------------- /pyreason_gym/envs/map_world.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | import numpy as np 4 | import pygame 5 | import time 6 | 7 | from pyreason_gym.pyreason_map_world.pyreason_map_world import PyReasonMapWorld 8 | 9 | np.set_printoptions(precision=20) 10 | 11 | # Odd order is due to orientation on canvas while displaying 12 | LAT_LONG_SCALE = int(10e14) 13 | LAT_MAX = int(35.64402770996094 * 10e14) 14 | LAT_MIN = int(36.58740997314453 * 10e14) 15 | LONG_MIN = int(-84.71105957031250 * 10e14) 16 | LONG_MAX = int(-83.68660736083984 * 10e14) 17 | PYGAME_MIN = 0 18 | PYGAME_MAX = 1000 19 | 20 | 21 | def map_lat_long_to_pygame_coords(lat, long): 22 | # Map from lat long range to pygame range 23 | lat = int(lat * LAT_LONG_SCALE) 24 | long = int(long * LAT_LONG_SCALE) 25 | lat_range = LAT_MAX - LAT_MIN 26 | long_range = LONG_MAX - LONG_MIN 27 | pygame_range = PYGAME_MAX - PYGAME_MIN 28 | new_lat = (((lat - LAT_MIN) * pygame_range) / lat_range) + PYGAME_MIN 29 | new_long = (((long - LONG_MIN) * pygame_range) / long_range) + PYGAME_MIN 30 | coord = np.array([new_long, new_lat]) 31 | 32 | return coord 33 | 34 | 35 | class MapWorldEnv(gym.Env): 36 | metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4} 37 | 38 | def __init__(self, start_point, end_point, render_mode=None): 39 | """Initialize map world 40 | 41 | :param start_point: Point where agent will start 42 | :type start_point: str 43 | :param end_point: Point where agent should end 44 | :type end_point: str 45 | :param render_mode: how to render the environment, defaults to None 46 | :type render_mode: str or None 47 | """ 48 | super(MapWorldEnv, self).__init__() 49 | 50 | self.render_mode = render_mode 51 | self.window_size = PYGAME_MAX 52 | 53 | # Start End points are required for observations 54 | self.start_point = start_point 55 | self.end_point = end_point 56 | 57 | # Rendering info 58 | self.start_point_lat_long = None 59 | self.end_point_lat_long = None 60 | 61 | # Initialize the PyReason map-world 62 | self.pyreason_map_world = PyReasonMapWorld(end_point) 63 | 64 | # Observation space is how close/far it is to the goal point. Coordinates from current point to end point 65 | # And how many valid actions there are in the state 66 | self.observation_space = spaces.Tuple((spaces.Text(max_length=20), spaces.Box(-100, 100, shape=(2,), dtype=np.float128), spaces.Box(-100, 100, shape=(2,), dtype=np.float128), spaces.Discrete(100))) 67 | 68 | # The choice of action is limited to the number of outgoing edges from one node. The agent has to pick one edge to go on 69 | self.action_space = spaces.Discrete(1) 70 | 71 | assert render_mode is None or render_mode in self.metadata["render_modes"] 72 | 73 | # If human-rendering is used, `self.window` will be a reference 74 | # to the window that we draw to. `self.clock` will be a clock that is used 75 | # to ensure that the environment is rendered at the correct framerate in 76 | # human-mode. `self.canvas` is a pygame surface where we draw out the game 77 | # They will remain `None` until human-mode is used for the first time. 78 | self.window = None 79 | self.clock = None 80 | self.canvas = None 81 | 82 | def _get_obs(self): 83 | return self.pyreason_map_world.get_obs() 84 | 85 | def _get_info(self): 86 | return {} 87 | 88 | def _get_rew(self): 89 | return 0 90 | 91 | def reset(self, seed=None, options=None): 92 | """Resets the environment to the initial conditions 93 | 94 | :param seed: random seed if there is a random component, defaults to None 95 | :param options: defaults to None 96 | """ 97 | # We need the following line to seed self.np_random 98 | super().reset(seed=seed) 99 | 100 | self.pyreason_map_world.reset() 101 | 102 | observation = self._get_obs() 103 | info = self._get_info() 104 | 105 | # Save new action space 106 | _, _, _, new_action_space = observation 107 | self.action_space = spaces.Discrete(new_action_space) 108 | 109 | # Render if necessary 110 | if self.render_mode == "human": 111 | self._render_frame(observation) 112 | 113 | return observation, info 114 | 115 | def step(self, action): 116 | self.pyreason_map_world.move(action) 117 | 118 | observation = self._get_obs() 119 | info = self._get_info() 120 | 121 | # Get reward 122 | rew = self._get_rew() 123 | 124 | # End of game 125 | done = self.is_done(observation) 126 | 127 | # Render if necessary 128 | if self.render_mode == "human": 129 | self._render_frame(observation) 130 | 131 | return observation, rew, done, False, info 132 | 133 | def _render_init(self): 134 | pygame.init() 135 | pygame.display.init() 136 | self.window = pygame.display.set_mode((self.window_size, self.window_size)) 137 | self.clock = pygame.time.Clock() 138 | self.canvas = pygame.Surface((self.window_size, self.window_size)) 139 | self.canvas.fill((255, 255, 255)) 140 | 141 | # Draw the nodes and edges on this canvas so that we don't have to keep drawing it every step 142 | nodes_lat_long, edges_lat_long = self.pyreason_map_world.get_map() 143 | 144 | # Draw points for nodes 145 | for node in nodes_lat_long: 146 | pygame.draw.circle( 147 | self.canvas, 148 | (69, 69, 69), 149 | map_lat_long_to_pygame_coords(*node), 150 | 2 151 | ) 152 | 153 | # Draw edges between points 154 | for edge in edges_lat_long: 155 | pygame.draw.aaline( 156 | self.canvas, 157 | (169, 169, 169), 158 | map_lat_long_to_pygame_coords(*edge[0]), 159 | map_lat_long_to_pygame_coords(*edge[1]) 160 | ) 161 | 162 | self.window.blit(self.canvas, self.canvas.get_rect()) 163 | pygame.event.pump() 164 | pygame.display.update() 165 | 166 | # We need to ensure that human-rendering occurs at the predefined framerate. 167 | # The following line will automatically add a delay to keep the framerate stable. 168 | self.clock.tick(self.metadata["render_fps"]) 169 | time.sleep(10) 170 | 171 | def _render_frame(self, observation): 172 | if self.canvas is None: 173 | self._render_init() 174 | if self.clock is None and self.render_mode == "human": 175 | self.clock = pygame.time.Clock() 176 | 177 | canvas = self.canvas.copy() 178 | 179 | current_node, current_lat_long, end_lat_long, new_action_space = observation 180 | 181 | if self.start_point_lat_long is None: 182 | self.start_point_lat_long = current_lat_long 183 | if self.end_point_lat_long is None: 184 | self.end_point_lat_long = end_lat_long 185 | 186 | # Draw start and end nodes 187 | pygame.draw.circle( 188 | canvas, 189 | (0, 255, 0), 190 | map_lat_long_to_pygame_coords(self.start_point_lat_long[0], self.start_point_lat_long[1]), 191 | 5, 192 | ) 193 | 194 | pygame.draw.circle( 195 | canvas, 196 | (0, 0, 255), 197 | map_lat_long_to_pygame_coords(current_lat_long[0], current_lat_long[1]), 198 | 5, 199 | ) 200 | pygame.draw.circle( 201 | canvas, 202 | (255, 0, 0), 203 | map_lat_long_to_pygame_coords(self.end_point_lat_long[0], self.end_point_lat_long[1]), 204 | 5, 205 | ) 206 | 207 | if self.render_mode == "human": 208 | # The following line copies our drawings from `canvas` to the visible window 209 | self.window.blit(canvas, canvas.get_rect()) 210 | pygame.event.pump() 211 | pygame.display.update() 212 | 213 | # We need to ensure that human-rendering occurs at the predefined framerate. 214 | # The following line will automatically add a delay to keep the framerate stable. 215 | self.clock.tick(self.metadata["render_fps"]) 216 | elif self.render_mode == 'rgb_array': 217 | return np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)) 218 | 219 | def close(self): 220 | if self.window is not None: 221 | pygame.display.quit() 222 | pygame.quit() 223 | 224 | def is_done(self, observation): 225 | # End the game when the agent reaches the end point 226 | 227 | return 228 | -------------------------------------------------------------------------------- /pyreason_gym/pyreason_grid_world/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-v2/pyreason-gym/90cb05a1e886c54e55b54472499f350ec691d823/pyreason_gym/pyreason_grid_world/__init__.py -------------------------------------------------------------------------------- /pyreason_gym/pyreason_grid_world/pyreason_grid_world.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pyreason as pr 3 | import numpy as np 4 | 5 | 6 | class PyReasonGridWorld: 7 | def __init__(self, grid_size, num_agents_per_team, graph, rules): 8 | self.grid_size = grid_size 9 | self.num_agents_per_team = num_agents_per_team 10 | self.interpretation = None 11 | 12 | # Keep track of the next timestep to start 13 | self.next_time = 0 14 | 15 | # Pyreason settings 16 | pr.settings.verbose = False 17 | pr.settings.atom_trace = False 18 | pr.settings.canonical = True 19 | pr.settings.inconsistency_check = False 20 | pr.settings.static_graph_facts = False 21 | pr.settings.store_interpretation_changes = False 22 | current_path = os.path.abspath(os.path.dirname(__file__)) 23 | 24 | # Load the graph 25 | if graph is None: 26 | pr.load_graph(f'{current_path}/graph/game_graph.graphml') 27 | else: 28 | pr.load_graph(graph) 29 | 30 | # Load rules 31 | if rules is None: 32 | pr.load_rules(f'{current_path}/yamls/rules.yaml') 33 | else: 34 | pr.load_rules(rules) 35 | 36 | def reset(self): 37 | # Reason for 1 timestep to initialize everything 38 | # Certain internal variables need to be reset otherwise memory blows up 39 | pr.reset() 40 | self.interpretation = pr.reason(0, again=False) 41 | self.next_time = self.interpretation.time + 1 42 | 43 | def move(self, action): 44 | # Define facts, then run pyreason 45 | # action input is a Dict with two keys, one for each team, consisting of a list of actions, one for each agent 46 | red_team_actions = action['red_team'] 47 | blue_team_actions = action['blue_team'] 48 | 49 | facts = [] 50 | red_available_actions = {0:'moveUp', 1:'moveDown', 2:'moveLeft', 3:'moveRight', 4:'shootUpRed', 5:'shootDownRed', 6:'shootLeftRed', 7:'shootRightRed'} 51 | blue_available_actions = {0:'moveUp', 1:'moveDown', 2:'moveLeft', 3:'moveRight', 4:'shootUpBlue', 5:'shootDownBlue', 6:'shootLeftBlue', 7:'shootRightBlue'} 52 | for i, a in enumerate(red_team_actions): 53 | if a != 8: 54 | fact_on = pr.fact_node.Fact(f'red_action_{i+1}', f'red-soldier-{i+1}', pr.label.Label(red_available_actions[a]), pr.interval.closed(1,1), self.next_time, self.next_time) 55 | fact_off = pr.fact_node.Fact(f'red_action_{i+1}_off', f'red-soldier-{i+1}', pr.label.Label(red_available_actions[a]), pr.interval.closed(0,0), self.next_time+1, self.next_time+1) 56 | facts.append(fact_on) 57 | facts.append(fact_off) 58 | 59 | for i, a in enumerate(blue_team_actions): 60 | if a != 8: 61 | fact_on = pr.fact_node.Fact(f'blue_action_{i+1}', f'blue-soldier-{i+1}', pr.label.Label(blue_available_actions[a]), pr.interval.closed(1,1), self.next_time, self.next_time) 62 | fact_off = pr.fact_node.Fact(f'blue_action_{i+1}_off', f'blue-soldier-{i+1}', pr.label.Label(blue_available_actions[a]), pr.interval.closed(0,0), self.next_time+1, self.next_time+1) 63 | facts.append(fact_on) 64 | facts.append(fact_off) 65 | 66 | self.interpretation = pr.reason(1, again=True, node_facts=facts) 67 | self.next_time = self.interpretation.time + 1 68 | 69 | def get_obs(self): 70 | observation = {'red_team': [], 'blue_team': [], 'red_bullets': [], 'blue_bullets': []} 71 | 72 | # Gather bullet info for red and blue bullets 73 | (red_bullet_positions, blue_bullet_positions), (red_bullet_directions, blue_bullet_directions), (red_killed_who, blue_killed_who) = self._get_bullet_info() 74 | for red_pos, red_dir in zip(red_bullet_positions, red_bullet_directions): 75 | observation['red_bullets'].append({'pos': red_pos, 'dir': red_dir}) 76 | 77 | for blue_pos, blue_dir in zip(blue_bullet_positions, blue_bullet_directions): 78 | observation['blue_bullets'].append({'pos': blue_pos, 'dir': blue_dir}) 79 | 80 | # Filter edges that are of the form (red-soldier-x, y) where x and y are ints 81 | red_relevant_edges = [edge for edge in self.interpretation.edges if 'red-soldier' in edge[0] and edge[1].isnumeric()] 82 | blue_relevant_edges = [edge for edge in self.interpretation.edges if 'blue-soldier' in edge[0] and edge[1].isnumeric()] 83 | 84 | # Select edges that have the atLoc predicate set to [1,1] 85 | red_position_edges = [edge for edge in red_relevant_edges if self.interpretation.interpretations_edge[edge].world[pr.label.Label('atLoc')]==pr.interval.closed(1,1)] 86 | blue_position_edges = [edge for edge in blue_relevant_edges if self.interpretation.interpretations_edge[edge].world[pr.label.Label('atLoc')]==pr.interval.closed(1,1)] 87 | 88 | # Make sure that the length of these lists are the same as the num agents per team 89 | assert len(red_position_edges)==self.num_agents_per_team and len(blue_position_edges)==self.num_agents_per_team, 'Number of agents per team does not match info retrieved about agent position from interpretations' 90 | 91 | # Sort the lists according to the last char of the source node (eg. red-soldier-1, 1 is the last char) 92 | red_position_edges = sorted(red_position_edges, key=lambda x: str(x[0][-1])) 93 | blue_position_edges = sorted(blue_position_edges, key=lambda x: str(x[0][-1])) 94 | 95 | # Gather info about the agents 96 | for i in range(1, self.num_agents_per_team+1): 97 | red_pos = int(red_position_edges[i-1][1]) 98 | blue_pos = int(blue_position_edges[i-1][1]) 99 | red_pos_coords = [red_pos%self.grid_size, red_pos//self.grid_size] 100 | blue_pos_coords = [blue_pos%self.grid_size, blue_pos//self.grid_size] 101 | red_health = self.interpretation.interpretations_node[f'red-soldier-{i}'].world[pr.label.Label('health')].lower 102 | blue_health = self.interpretation.interpretations_node[f'blue-soldier-{i}'].world[pr.label.Label('health')].lower 103 | 104 | observation['red_team'].append({'pos': np.array(red_pos_coords, dtype=np.int32), 'health': np.array([red_health], dtype=np.float32), 'killed': list(red_killed_who[i-1])}) 105 | observation['blue_team'].append({'pos': np.array(blue_pos_coords, dtype=np.int32), 'health': np.array([blue_health], dtype=np.float32), 'killed': list(blue_killed_who[i-1])}) 106 | 107 | return observation 108 | 109 | def get_obstacle_locations(self): 110 | # Return the coordinates of all the mountains in the grid to be able to draw them 111 | relevant_edges = [edge for edge in self.interpretation.edges if edge[1]=='mountain'] 112 | obstacle_positions = [int(edge[0]) for edge in relevant_edges] 113 | obstacle_positions_coords = np.array([[pos%self.grid_size, pos//self.grid_size] for pos in obstacle_positions]) 114 | return obstacle_positions_coords 115 | 116 | def get_base_locations(self): 117 | # Return the locations of the two bases 118 | relevant_edges = [edge for edge in self.interpretation.edges if 'base' in edge[0]] 119 | sorted_relevant_edges = [relevant_edges[0], relevant_edges[1]] if relevant_edges[0][0]=='red-base' else [relevant_edges[1], relevant_edges[0]] 120 | base_positions = [int(edge[1]) for edge in sorted_relevant_edges] 121 | base_positions_coords = np.array([[pos%self.grid_size, pos//self.grid_size] for pos in base_positions]) 122 | return base_positions_coords 123 | 124 | def _get_bullet_info(self): 125 | # Return the location of red and blue bullets to be displayed on the grid 126 | relevant_edges = [edge for edge in self.interpretation.edges if 'bullet' in edge[1] and edge[0].isdigit()] 127 | filtered_edges = [edge for edge in relevant_edges if self.interpretation.interpretations_edge[edge].world[pr.label.Label('atLoc')] == pr.interval.closed(1,1) 128 | and self.interpretation.interpretations_edge[edge].world[pr.label.Label('life')] == pr.interval.closed(1,1)] 129 | red_bullet_positions = [int(edge[0]) for edge in filtered_edges if 'red' in edge[1]] 130 | blue_bullet_positions = [int(edge[0]) for edge in filtered_edges if 'blue' in edge[1]] 131 | red_bullet_positions_coords = np.array([[pos%self.grid_size, pos//self.grid_size] for pos in red_bullet_positions]) 132 | blue_bullet_positions_coords = np.array([[pos%self.grid_size, pos//self.grid_size] for pos in blue_bullet_positions]) 133 | positions = (red_bullet_positions_coords, blue_bullet_positions_coords) 134 | 135 | # Get info about who killed whom. Stored in the form a list for every agent: (red-killer: [blue-casualties]) or (blue-killer: [red-casualties]) 136 | kill_info_edges = [edge for edge in self.interpretation.edges if pr.label.Label('killed') in self.interpretation.interpretations_edge[edge].world 137 | and self.interpretation.interpretations_edge[edge].world[pr.label.Label('killed')] == pr.interval.closed(1, 1)] 138 | kill_info_edges = sorted(kill_info_edges, key=lambda x: int(x[0][-1])) 139 | red_killed_who_tuple = [(int(edge[0][-1]), int(edge[1][-1])) for edge in kill_info_edges if 'red' in edge[0]] 140 | blue_killed_who_tuple = [(int(edge[0][-1]), int(edge[1][-1])) for edge in kill_info_edges if 'blue' in edge[0]] 141 | red_killed_who = [[] for _ in range(self.num_agents_per_team)] 142 | blue_killed_who = [[] for _ in range(self.num_agents_per_team)] 143 | 144 | for shooter, casualty in red_killed_who_tuple: 145 | red_killed_who[shooter-1].append(casualty) 146 | for shooter, casualty in blue_killed_who_tuple: 147 | blue_killed_who[shooter-1].append(casualty) 148 | 149 | who_killed_who = (red_killed_who, blue_killed_who) 150 | 151 | # Bullet direction of movement 152 | direction_map = {0.2: 0, 0.6: 1, 0.4: 2, 0.8: 3} 153 | red_bullet_directions = [direction_map[self.interpretation.interpretations_edge[edge].world[pr.label.Label('direction')].lower] for edge in filtered_edges if 'red' in edge[1]] 154 | blue_bullet_directions = [direction_map[self.interpretation.interpretations_edge[edge].world[pr.label.Label('direction')].lower] for edge in filtered_edges if 'blue' in edge[1]] 155 | directions = (red_bullet_directions, blue_bullet_directions) 156 | 157 | # Make sure the length of positions is the same as directions 158 | assert len(red_bullet_positions) == len(red_bullet_directions), 'Length of bullet positions does not math length of bullet directions' 159 | assert len(blue_bullet_positions) == len(blue_bullet_directions), 'Length of bullet positions does not math length of bullet directions' 160 | 161 | return positions, directions, who_killed_who 162 | -------------------------------------------------------------------------------- /pyreason_gym/pyreason_grid_world/yamls/rules.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Makes a spot on the grid 'blocked' if there is a mountain associated with that spot 3 | obstacle_rule: 4 | target: blocked 5 | target_criteria: 6 | - [blocked, 0, 0] 7 | delta_t: 0 8 | neigh_criteria: 9 | - [node, [x], isMountain, [1,1]] 10 | - [edge, [target, x], atLoc, [1,1]] 11 | ann_fn: [1,1] 12 | 13 | # BULLET RULES GET CHECKED FIRST 14 | # BULLET Rules 15 | # BULLET HITS SOMEONE 16 | # BLUE DIES 17 | bullet_hit_just_died_blue_rule: 18 | target: justDied 19 | target_criteria: 20 | - [health, 0.1, 1] 21 | - [teamBlue, 1, 1] 22 | delta_t: 0 23 | immediate: true 24 | neigh_criteria: 25 | - [edge, [target, loc], atLoc, [1,1]] 26 | - [edge, [loc, bullet], atLoc, [1,1]] 27 | - [edge, [loc, bullet], life, [0.1, 1]] 28 | - [node, [bullet], teamRed, [1,1]] 29 | ann_fn: [1, 1] 30 | 31 | bullet_hit_just_died_red_rule: 32 | target: justDied 33 | target_criteria: 34 | - [health, 0.1, 1] 35 | - [teamRed, 1, 1] 36 | delta_t: 0 37 | immediate: true 38 | neigh_criteria: 39 | - [edge, [target, loc], atLoc, [1,1]] 40 | - [edge, [loc, bullet], atLoc, [1,1]] 41 | - [edge, [loc, bullet], life, [0.1, 1]] 42 | - [node, [bullet], teamBlue, [1,1]] 43 | ann_fn: [1, 1] 44 | 45 | bullet_hit_just_died_off_blue_rule: 46 | target: justDied 47 | target_criteria: 48 | - [health, 0, 0] 49 | - [justDied, 1, 1] 50 | - [teamBlue, 1, 1] 51 | delta_t: 1 52 | neigh_criteria: 53 | - [edge, [target, loc], atLoc, [1,1]] 54 | - [edge, [loc, bullet], atLoc, [1,1]] 55 | - [node, [bullet], teamRed, [1,1]] 56 | ann_fn: [0, 0] 57 | 58 | bullet_hit_just_died_off_red_rule: 59 | target: justDied 60 | target_criteria: 61 | - [health, 0, 0] 62 | - [justDied, 1, 1] 63 | - [teamRed, 1, 1] 64 | delta_t: 1 65 | neigh_criteria: 66 | - [edge, [target, loc], atLoc, [1,1]] 67 | - [edge, [loc, bullet], atLoc, [1,1]] 68 | - [node, [bullet], teamBlue, [1,1]] 69 | ann_fn: [0, 0] 70 | 71 | bullet_hit_rule: 72 | target: health 73 | target_criteria: 74 | - [health, 0.1, 1] 75 | - [justDied, 1, 1] 76 | delta_t: 0 77 | immediate: true 78 | neigh_criteria: 79 | ann_fn: [0, 0] 80 | 81 | add_kill_blue_info_rule: 82 | target: 83 | target_criteria: 84 | - [atLoc, 1, 1] 85 | - [life, 0.1, 1] 86 | delta_t: 0 87 | immediate: true 88 | neigh_criteria: 89 | - [node, [player], justDied, [1,1]] 90 | - [node, [target], teamRed, [1,1]] 91 | - [node, [player], teamBlue, [1,1]] 92 | edges: [target, player, killed] 93 | ann_fn: [1, 1] 94 | 95 | add_kill_red_info_rule: 96 | target: 97 | target_criteria: 98 | - [atLoc, 1, 1] 99 | - [life, 0.1, 1] 100 | delta_t: 0 101 | immediate: true 102 | neigh_criteria: 103 | - [node, [player], justDied, [1,1]] 104 | - [node, [target], teamBlue, [1,1]] 105 | - [node, [player], teamRed, [1,1]] 106 | edges: [target, player, killed] 107 | ann_fn: [1, 1] 108 | 109 | remove_kill_info_rule: 110 | target: killed 111 | target_criteria: 112 | - [killed, 1, 1] 113 | delta_t: 1 114 | neigh_criteria: 115 | ann_fn: [0, 0] 116 | 117 | remove_bullet_after_hit_blue_rule: 118 | target: life 119 | target_criteria: 120 | - [atLoc, 1, 1] 121 | - [life, 0.1, 1] 122 | delta_t: 0 123 | immediate: true 124 | neigh_criteria: 125 | - [node, [player], justDied, [1,1]] 126 | - [node, [target], teamRed, [1,1]] 127 | - [node, [player], teamBlue, [1,1]] 128 | 129 | ann_fn: [0, 0] 130 | 131 | # Sometimes a bullet is scheduled to move to a location twice before it has actually moved 132 | # This prevents the bullet from re-appearing 133 | remove_bullet_after_hit_blue_rule_2: 134 | target: life 135 | target_criteria: 136 | - [life, 0, 0] 137 | delta_t: 1 138 | neigh_criteria: 139 | - [node, [player], justDied, [1,1]] 140 | - [node, [target], teamRed, [1,1]] 141 | - [node, [player], teamBlue, [1,1]] 142 | 143 | ann_fn: [0, 0] 144 | 145 | remove_bullet_after_hit_red_rule: 146 | target: life 147 | target_criteria: 148 | - [atLoc, 1, 1] 149 | - [life, 0.1, 1] 150 | delta_t: 0 151 | immediate: true 152 | neigh_criteria: 153 | - [node, [player], justDied, [1,1]] 154 | - [node, [target], teamBlue, [1,1]] 155 | - [node, [player], teamRed, [1,1]] 156 | 157 | ann_fn: [0, 0] 158 | 159 | # Sometimes a bullet is scheduled to move to a location twice before it has actually moved 160 | # This prevents the bullet from re-appearing 161 | remove_bullet_after_hit_red_rule_2: 162 | target: life 163 | target_criteria: 164 | - [life, 0, 0] 165 | delta_t: 1 166 | neigh_criteria: 167 | - [node, [player], justDied, [1,1]] 168 | - [node, [target], teamBlue, [1,1]] 169 | - [node, [player], teamRed, [1,1]] 170 | 171 | ann_fn: [0, 0] 172 | 173 | 174 | # Moving rules 175 | # Move UP 176 | move_up_rule_1: 177 | target: 178 | target_criteria: 179 | - [moveUp, 1, 1] 180 | - [health, 0.1, 1] 181 | delta_t: 0 182 | immediate: true 183 | neigh_criteria: 184 | - [edge, [target, oldLoc], atLoc, [1,1]] 185 | - [edge, [oldLoc, newLoc], up, [1,1]] 186 | - [node, [newLoc], blocked, [0,0]] 187 | 188 | edges: [target, newLoc, atLoc] 189 | ann_fn: [0.5,0.5] 190 | 191 | move_up_rule_2: 192 | target: 193 | target_criteria: 194 | - [moveUp, 1, 1] 195 | - [health, 0.1, 1] 196 | delta_t: 0 197 | immediate: true 198 | neigh_criteria: 199 | - [edge, [target, oldLoc], atLoc, [1,1]] 200 | - [edge, [oldLoc, newLoc], up, [1,1]] 201 | - [node, [newLoc], blocked, [0,0]] 202 | 203 | edges: [newLoc, target, atLoc] 204 | ann_fn: [0.5,0.5] 205 | 206 | move_up_change_prev_rule_1: 207 | target: atLoc 208 | target_criteria: 209 | - [atLoc, 1, 1] 210 | delta_t: 0 211 | immediate: true 212 | neigh_criteria: 213 | - [edge, [target, newLoc], up, [1,1]] 214 | - [edge, [source, newLoc], atLoc, [0.5,0.5]] 215 | 216 | ann_fn: [0,0] 217 | 218 | move_up_change_prev_rule_2: 219 | target: atLoc 220 | target_criteria: 221 | - [atLoc, 1, 1] 222 | delta_t: 0 223 | immediate: true 224 | neigh_criteria: 225 | - [edge, [source, newLoc], up, [1,1]] 226 | - [edge, [newLoc, target], atLoc, [0.5,0.5]] 227 | 228 | ann_fn: [0,0] 229 | 230 | # Move DOWN 231 | move_down_rule_1: 232 | target: 233 | target_criteria: 234 | - [moveDown, 1, 1] 235 | - [health, 0.1, 1] 236 | delta_t: 0 237 | immediate: true 238 | neigh_criteria: 239 | - [edge, [target, oldLoc], atLoc, [1,1]] 240 | - [edge, [oldLoc, newLoc], down, [1,1]] 241 | - [node, [newLoc], blocked, [0,0]] 242 | 243 | edges: [target, newLoc, atLoc] 244 | ann_fn: [0.5,0.5] 245 | 246 | move_down_rule_2: 247 | target: 248 | target_criteria: 249 | - [moveDown, 1, 1] 250 | - [health, 0.1, 1] 251 | delta_t: 0 252 | immediate: true 253 | neigh_criteria: 254 | - [edge, [target, oldLoc], atLoc, [1,1]] 255 | - [edge, [oldLoc, newLoc], down, [1,1]] 256 | - [node, [newLoc], blocked, [0,0]] 257 | 258 | edges: [newLoc, target, atLoc] 259 | ann_fn: [0.5,0.5] 260 | 261 | move_down_change_prev_rule_1: 262 | target: atLoc 263 | target_criteria: 264 | - [atLoc, 1, 1] 265 | delta_t: 0 266 | immediate: true 267 | neigh_criteria: 268 | - [edge, [target, newLoc], down, [1,1]] 269 | - [edge, [source, newLoc], atLoc, [0.5,0.5]] 270 | 271 | ann_fn: [0,0] 272 | 273 | move_down_change_prev_rule_2: 274 | target: atLoc 275 | target_criteria: 276 | - [atLoc, 1, 1] 277 | delta_t: 0 278 | immediate: true 279 | neigh_criteria: 280 | - [edge, [source, newLoc], down, [1,1]] 281 | - [edge, [newLoc, target], atLoc, [0.5,0.5]] 282 | 283 | ann_fn: [0,0] 284 | 285 | # Move LEFT 286 | move_left_rule_1: 287 | target: 288 | target_criteria: 289 | - [moveLeft, 1, 1] 290 | - [health, 0.1, 1] 291 | delta_t: 0 292 | immediate: true 293 | neigh_criteria: 294 | - [edge, [target, oldLoc], atLoc, [1,1]] 295 | - [edge, [oldLoc, newLoc], left, [1,1]] 296 | - [node, [newLoc], blocked, [0,0]] 297 | 298 | edges: [target, newLoc, atLoc] 299 | ann_fn: [0.5,0.5] 300 | 301 | move_left_rule_2: 302 | target: 303 | target_criteria: 304 | - [moveLeft, 1, 1] 305 | - [health, 0.1, 1] 306 | delta_t: 0 307 | immediate: true 308 | neigh_criteria: 309 | - [edge, [target, oldLoc], atLoc, [1,1]] 310 | - [edge, [oldLoc, newLoc], left, [1,1]] 311 | - [node, [newLoc], blocked, [0,0]] 312 | 313 | edges: [newLoc, target, atLoc] 314 | ann_fn: [0.5,0.5] 315 | 316 | move_left_change_prev_rule_1: 317 | target: atLoc 318 | target_criteria: 319 | - [atLoc, 1, 1] 320 | delta_t: 0 321 | immediate: true 322 | neigh_criteria: 323 | - [edge, [target, newLoc], left, [1,1]] 324 | - [edge, [source, newLoc], atLoc, [0.5,0.5]] 325 | 326 | ann_fn: [0,0] 327 | 328 | move_left_change_prev_rule_2: 329 | target: atLoc 330 | target_criteria: 331 | - [atLoc, 1, 1] 332 | delta_t: 0 333 | immediate: true 334 | neigh_criteria: 335 | - [edge, [source, newLoc], left, [1,1]] 336 | - [edge, [newLoc, target], atLoc, [0.5,0.5]] 337 | 338 | ann_fn: [0,0] 339 | 340 | # Move RIGHT 341 | move_right_rule_1: 342 | target: 343 | target_criteria: 344 | - [moveRight, 1, 1] 345 | - [health, 0.1, 1] 346 | delta_t: 0 347 | immediate: true 348 | neigh_criteria: 349 | - [edge, [target, oldLoc], atLoc, [1,1]] 350 | - [edge, [oldLoc, newLoc], right, [1,1]] 351 | - [node, [newLoc], blocked, [0,0]] 352 | 353 | edges: [target, newLoc, atLoc] 354 | ann_fn: [0.5,0.5] 355 | 356 | move_right_rule_2: 357 | target: 358 | target_criteria: 359 | - [moveRight, 1, 1] 360 | - [health, 0.1, 1] 361 | delta_t: 0 362 | immediate: true 363 | neigh_criteria: 364 | - [edge, [target, oldLoc], atLoc, [1,1]] 365 | - [edge, [oldLoc, newLoc], right, [1,1]] 366 | - [node, [newLoc], blocked, [0,0]] 367 | 368 | edges: [newLoc, target, atLoc] 369 | ann_fn: [0.5,0.5] 370 | 371 | move_right_change_prev_rule_1: 372 | target: atLoc 373 | target_criteria: 374 | - [atLoc, 1, 1] 375 | delta_t: 0 376 | immediate: true 377 | neigh_criteria: 378 | - [edge, [target, newLoc], right, [1,1]] 379 | - [edge, [source, newLoc], atLoc, [0.5,0.5]] 380 | 381 | ann_fn: [0,0] 382 | 383 | move_right_change_prev_rule_2: 384 | target: atLoc 385 | target_criteria: 386 | - [atLoc, 1, 1] 387 | delta_t: 0 388 | immediate: true 389 | neigh_criteria: 390 | - [edge, [source, newLoc], right, [1,1]] 391 | - [edge, [newLoc, target], atLoc, [0.5,0.5]] 392 | 393 | ann_fn: [0,0] 394 | 395 | move_complete: 396 | target: atLoc 397 | target_criteria: 398 | - [atLoc, 0.5, 0.5] 399 | delta_t: 0 400 | immediate: true 401 | neigh_criteria: 402 | 403 | ann_fn: [1,1] 404 | 405 | 406 | # SHOOT UP 407 | # RED 408 | shoot_up_red_setup_new_bullet_rule: 409 | target: 410 | target_criteria: 411 | - [shootUpRed, 1, 1] 412 | - [health, 0.1, 1] 413 | delta_t: 0 414 | neigh_criteria: 415 | - [edge, [target, loc], atLoc, [1,1]] 416 | - [edge, [target, bullet], bullet, [1,1]] 417 | 418 | edges: [loc, bullet, newBullet] 419 | ann_fn: [0.2, 0.2] 420 | 421 | shoot_up_blue_setup_new_bullet_rule: 422 | target: 423 | target_criteria: 424 | - [shootUpBlue, 1, 1] 425 | - [health, 0.1, 1] 426 | delta_t: 0 427 | neigh_criteria: 428 | - [edge, [target, loc], atLoc, [1,1]] 429 | - [edge, [target, bullet], bullet, [1,1]] 430 | 431 | edges: [loc, bullet, newBullet] 432 | ann_fn: [0.2, 0.2] 433 | 434 | 435 | # SHOOT DOWN 436 | shoot_down_red_setup_new_bullet_rule: 437 | target: 438 | target_criteria: 439 | - [shootDownRed, 1, 1] 440 | - [health, 0.1, 1] 441 | delta_t: 0 442 | neigh_criteria: 443 | - [edge, [target, loc], atLoc, [1,1]] 444 | - [edge, [target, bullet], bullet, [1,1]] 445 | 446 | edges: [loc, bullet, newBullet] 447 | ann_fn: [0.6, 0.6] 448 | 449 | shoot_down_blue_setup_new_bullet_rule: 450 | target: 451 | target_criteria: 452 | - [shootDownBlue, 1, 1] 453 | - [health, 0.1, 1] 454 | delta_t: 0 455 | neigh_criteria: 456 | - [edge, [target, loc], atLoc, [1,1]] 457 | - [edge, [target, bullet], bullet, [1,1]] 458 | 459 | edges: [loc, bullet, newBullet] 460 | ann_fn: [0.6, 0.6] 461 | 462 | 463 | # SHOOT LEFT 464 | shoot_left_red_setup_new_bullet_rule: 465 | target: 466 | target_criteria: 467 | - [shootLeftRed, 1, 1] 468 | - [health, 0.1, 1] 469 | delta_t: 0 470 | neigh_criteria: 471 | - [edge, [target, loc], atLoc, [1,1]] 472 | - [edge, [target, bullet], bullet, [1,1]] 473 | 474 | edges: [loc, bullet, newBullet] 475 | ann_fn: [0.4, 0.4] 476 | 477 | shoot_left_blue_setup_new_bullet_rule: 478 | target: 479 | target_criteria: 480 | - [shootLeftBlue, 1, 1] 481 | - [health, 0.1, 1] 482 | delta_t: 0 483 | neigh_criteria: 484 | - [edge, [target, loc], atLoc, [1,1]] 485 | - [edge, [target, bullet], bullet, [1,1]] 486 | 487 | edges: [loc, bullet, newBullet] 488 | ann_fn: [0.4, 0.4] 489 | 490 | 491 | # SHOOT RIGHT 492 | shoot_right_red_setup_new_bullet_rule: 493 | target: 494 | target_criteria: 495 | - [shootRightRed, 1, 1] 496 | - [health, 0.1, 1] 497 | delta_t: 0 498 | neigh_criteria: 499 | - [edge, [target, loc], atLoc, [1,1]] 500 | - [edge, [target, bullet], bullet, [1,1]] 501 | 502 | edges: [loc, bullet, newBullet] 503 | ann_fn: [0.8, 0.8] 504 | 505 | shoot_right_blue_setup_new_bullet_rule: 506 | target: 507 | target_criteria: 508 | - [shootRightBlue, 1, 1] 509 | - [health, 0.1, 1] 510 | delta_t: 0 511 | neigh_criteria: 512 | - [edge, [target, loc], atLoc, [1,1]] 513 | - [edge, [target, bullet], bullet, [1,1]] 514 | 515 | edges: [loc, bullet, newBullet] 516 | ann_fn: [0.8, 0.8] 517 | 518 | 519 | # COMMON RULES FOR SETUP 520 | # Initial Location 521 | shoot_setup_loc_rule: 522 | target: 523 | target_criteria: 524 | - [newBullet, 0.2, 0.8] 525 | delta_t: 0 526 | immediate: true 527 | neigh_criteria: 528 | 529 | edges: [source, target, atLoc] 530 | ann_fn: [1,1] 531 | 532 | # Initial Life 533 | shoot_setup_life_rule: 534 | target: 535 | target_criteria: 536 | - [newBullet, 0.2, 0.8] 537 | delta_t: 0 538 | immediate: true 539 | neigh_criteria: 540 | 541 | edges: [source, target, life] 542 | ann_fn: [1, 1] 543 | 544 | # Bullet Directions 545 | shoot_up_setup_direction_rule: 546 | target: 547 | target_criteria: 548 | - [newBullet, 0.2, 0.2] 549 | delta_t: 0 550 | immediate: true 551 | neigh_criteria: 552 | 553 | edges: [source, target, direction] 554 | ann_fn: [0.2, 0.2] 555 | 556 | shoot_down_setup_direction_rule: 557 | target: 558 | target_criteria: 559 | - [newBullet, 0.6, 0.6] 560 | delta_t: 0 561 | immediate: true 562 | neigh_criteria: 563 | 564 | edges: [source, target, direction] 565 | ann_fn: [0.6, 0.6] 566 | 567 | shoot_left_setup_direction_rule: 568 | target: 569 | target_criteria: 570 | - [newBullet, 0.4, 0.4] 571 | delta_t: 0 572 | immediate: true 573 | neigh_criteria: 574 | 575 | edges: [source, target, direction] 576 | ann_fn: [0.4, 0.4] 577 | 578 | shoot_right_setup_direction_rule: 579 | target: 580 | target_criteria: 581 | - [newBullet, 0.8, 0.8] 582 | delta_t: 0 583 | immediate: true 584 | neigh_criteria: 585 | 586 | edges: [source, target, direction] 587 | ann_fn: [0.8, 0.8] 588 | 589 | # Turn off New Bullet mode 590 | shoot_new_bullet_off_rule: 591 | target: newBullet 592 | target_criteria: 593 | - [newBullet, 0.2, 0.8] 594 | delta_t: 0 595 | immediate: true 596 | neigh_criteria: 597 | ann_fn: [0, 0] 598 | 599 | # BULLET DYNAMICS 600 | # UP 601 | bullet_move_up_setup_pos_rule: 602 | target: 603 | target_criteria: 604 | - [atLoc, 1, 1] 605 | - [life, 1, 1] 606 | - [direction, 0.2, 0.2] 607 | delta_t: 1 608 | neigh_criteria: 609 | - [edge, [source, newLoc], up, [1,1]] 610 | - [node, [newLoc], blocked, [0,0]] 611 | 612 | edges: [newLoc, target, atLoc] 613 | ann_fn: [1, 1] 614 | 615 | bullet_move_up_change_prev_loc_rule: 616 | target: atLoc 617 | target_criteria: 618 | - [atLoc, 1, 1] 619 | - [life, 1, 1] 620 | - [direction, 0.2, 0.2] 621 | delta_t: 1 622 | neigh_criteria: 623 | - [edge, [source, newLoc], up, [1,1]] 624 | - [node, [newLoc], blocked, [0,0]] 625 | 626 | ann_fn: [0,0] 627 | 628 | bullet_move_up_setup_life_rule: 629 | target: 630 | target_criteria: 631 | - [atLoc, 1, 1] 632 | - [life, 1, 1] 633 | - [direction, 0.2, 0.2] 634 | delta_t: 1 635 | neigh_criteria: 636 | - [edge, [source, newLoc], up, [1,1]] 637 | - [node, [newLoc], blocked, [0,0]] 638 | 639 | edges: [newLoc, target, life] 640 | ann_fn: [1, 1] 641 | 642 | bullet_move_up_setup_direction_rule: 643 | target: 644 | target_criteria: 645 | - [atLoc, 1, 1] 646 | - [life, 1, 1] 647 | - [direction, 0.2, 0.2] 648 | delta_t: 1 649 | neigh_criteria: 650 | - [edge, [source, newLoc], up, [1,1]] 651 | - [node, [newLoc], blocked, [0,0]] 652 | 653 | edges: [newLoc, target, direction] 654 | ann_fn: [0.2, 0.2] 655 | 656 | # Out of grid 657 | bullet_move_up_out_of_grid_rule: 658 | target: life 659 | target_criteria: 660 | - [atLoc, 1, 1] 661 | - [life, 1, 1] 662 | - [direction, 0.2, 0.2] 663 | delta_t: 1 664 | immediate: true 665 | neigh_criteria: 666 | - [edge, [source, newLoc], up, [1,1]] 667 | - [node, [newLoc], blocked, [1,1]] 668 | 669 | ann_fn: [0, 0] 670 | 671 | # DOWN 672 | bullet_move_down_setup_pos_rule: 673 | target: 674 | target_criteria: 675 | - [atLoc, 1, 1] 676 | - [life, 1, 1] 677 | - [direction, 0.6, 0.6] 678 | delta_t: 1 679 | neigh_criteria: 680 | - [edge, [source, newLoc], down, [1,1]] 681 | - [node, [newLoc], blocked, [0,0]] 682 | 683 | edges: [newLoc, target, atLoc] 684 | ann_fn: [1, 1] 685 | 686 | bullet_move_down_change_prev_loc_rule: 687 | target: atLoc 688 | target_criteria: 689 | - [atLoc, 1, 1] 690 | - [life, 1, 1] 691 | - [direction, 0.6, 0.6] 692 | delta_t: 1 693 | neigh_criteria: 694 | - [edge, [source, newLoc], down, [1,1]] 695 | - [node, [newLoc], blocked, [0,0]] 696 | 697 | ann_fn: [0,0] 698 | 699 | bullet_move_down_setup_life_rule: 700 | target: 701 | target_criteria: 702 | - [atLoc, 1, 1] 703 | - [life, 1, 1] 704 | - [direction, 0.6, 0.6] 705 | delta_t: 1 706 | neigh_criteria: 707 | - [edge, [source, newLoc], down, [1,1]] 708 | - [node, [newLoc], blocked, [0,0]] 709 | 710 | edges: [newLoc, target, life] 711 | ann_fn: [1, 1] 712 | 713 | bullet_move_down_setup_direction_rule: 714 | target: 715 | target_criteria: 716 | - [atLoc, 1, 1] 717 | - [life, 1, 1] 718 | - [direction, 0.6, 0.6] 719 | delta_t: 1 720 | neigh_criteria: 721 | - [edge, [source, newLoc], down, [1,1]] 722 | - [node, [newLoc], blocked, [0,0]] 723 | 724 | edges: [newLoc, target, direction] 725 | ann_fn: [0.6, 0.6] 726 | 727 | # Out of grid 728 | bullet_move_down_out_of_grid_rule: 729 | target: life 730 | target_criteria: 731 | - [atLoc, 1, 1] 732 | - [life, 1, 1] 733 | - [direction, 0.6, 0.6] 734 | delta_t: 1 735 | immediate: true 736 | neigh_criteria: 737 | - [edge, [source, newLoc], down, [1,1]] 738 | - [node, [newLoc], blocked, [1,1]] 739 | 740 | ann_fn: [0, 0] 741 | 742 | # LEFT 743 | bullet_move_left_setup_pos_rule: 744 | target: 745 | target_criteria: 746 | - [atLoc, 1, 1] 747 | - [life, 1, 1] 748 | - [direction, 0.4, 0.4] 749 | delta_t: 1 750 | neigh_criteria: 751 | - [edge, [source, newLoc], left, [1,1]] 752 | - [node, [newLoc], blocked, [0,0]] 753 | 754 | edges: [newLoc, target, atLoc] 755 | ann_fn: [1, 1] 756 | 757 | bullet_move_left_change_prev_loc_rule: 758 | target: atLoc 759 | target_criteria: 760 | - [atLoc, 1, 1] 761 | - [life, 1, 1] 762 | - [direction, 0.4, 0.4] 763 | delta_t: 1 764 | neigh_criteria: 765 | - [edge, [source, newLoc], left, [1,1]] 766 | - [node, [newLoc], blocked, [0,0]] 767 | 768 | ann_fn: [0,0] 769 | 770 | bullet_move_left_setup_life_rule: 771 | target: 772 | target_criteria: 773 | - [atLoc, 1, 1] 774 | - [life, 1, 1] 775 | - [direction, 0.4, 0.4] 776 | delta_t: 1 777 | neigh_criteria: 778 | - [edge, [source, newLoc], left, [1,1]] 779 | - [node, [newLoc], blocked, [0,0]] 780 | 781 | edges: [newLoc, target, life] 782 | ann_fn: [1, 1] 783 | 784 | bullet_move_left_setup_direction_rule: 785 | target: 786 | target_criteria: 787 | - [atLoc, 1, 1] 788 | - [life, 1, 1] 789 | - [direction, 0.4, 0.4] 790 | delta_t: 1 791 | neigh_criteria: 792 | - [edge, [source, newLoc], left, [1,1]] 793 | - [node, [newLoc], blocked, [0,0]] 794 | 795 | edges: [newLoc, target, direction] 796 | ann_fn: [0.4, 0.4] 797 | 798 | # Out of grid 799 | bullet_move_left_out_of_grid_rule: 800 | target: life 801 | target_criteria: 802 | - [atLoc, 1, 1] 803 | - [life, 1, 1] 804 | - [direction, 0.4, 0.4] 805 | delta_t: 1 806 | immediate: true 807 | neigh_criteria: 808 | - [edge, [source, newLoc], left, [1,1]] 809 | - [node, [newLoc], blocked, [1,1]] 810 | 811 | ann_fn: [0, 0] 812 | 813 | # RIGHT 814 | bullet_move_right_setup_pos_rule: 815 | target: 816 | target_criteria: 817 | - [atLoc, 1, 1] 818 | - [life, 1, 1] 819 | - [direction, 0.8, 0.8] 820 | delta_t: 1 821 | neigh_criteria: 822 | - [edge, [source, newLoc], right, [1,1]] 823 | - [node, [newLoc], blocked, [0,0]] 824 | 825 | edges: [newLoc, target, atLoc] 826 | ann_fn: [1, 1] 827 | 828 | bullet_move_right_change_prev_loc_rule: 829 | target: atLoc 830 | target_criteria: 831 | - [atLoc, 1, 1] 832 | - [life, 1, 1] 833 | - [direction, 0.8, 0.8] 834 | delta_t: 1 835 | neigh_criteria: 836 | - [edge, [source, newLoc], right, [1,1]] 837 | - [node, [newLoc], blocked, [0,0]] 838 | 839 | ann_fn: [0,0] 840 | 841 | bullet_move_right_setup_life_rule: 842 | target: 843 | target_criteria: 844 | - [atLoc, 1, 1] 845 | - [life, 1, 1] 846 | - [direction, 0.8, 0.8] 847 | delta_t: 1 848 | neigh_criteria: 849 | - [edge, [source, newLoc], right, [1,1]] 850 | - [node, [newLoc], blocked, [0,0]] 851 | 852 | edges: [newLoc, target, life] 853 | ann_fn: [1, 1] 854 | 855 | bullet_move_right_setup_direction_rule: 856 | target: 857 | target_criteria: 858 | - [atLoc, 1, 1] 859 | - [life, 1, 1] 860 | - [direction, 0.8, 0.8] 861 | delta_t: 1 862 | neigh_criteria: 863 | - [edge, [source, newLoc], right, [1,1]] 864 | - [node, [newLoc], blocked, [0,0]] 865 | 866 | edges: [newLoc, target, direction] 867 | ann_fn: [0.8, 0.8] 868 | 869 | # Out of grid 870 | bullet_move_right_out_of_grid_rule: 871 | target: life 872 | target_criteria: 873 | - [atLoc, 1, 1] 874 | - [life, 1, 1] 875 | - [direction, 0.8, 0.8] 876 | delta_t: 1 877 | immediate: true 878 | neigh_criteria: 879 | - [edge, [source, newLoc], right, [1,1]] 880 | - [node, [newLoc], blocked, [1,1]] 881 | 882 | ann_fn: [0, 0] 883 | -------------------------------------------------------------------------------- /pyreason_gym/pyreason_map_world/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-v2/pyreason-gym/90cb05a1e886c54e55b54472499f350ec691d823/pyreason_gym/pyreason_map_world/__init__.py -------------------------------------------------------------------------------- /pyreason_gym/pyreason_map_world/graph/.gitignore: -------------------------------------------------------------------------------- 1 | map_graph_new.graphml -------------------------------------------------------------------------------- /pyreason_gym/pyreason_map_world/pyreason_map_world.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pyreason as pr 3 | import numpy as np 4 | import time 5 | 6 | 7 | class PyReasonMapWorld: 8 | def __init__(self, end_point): 9 | self.interpretation = None 10 | self.end_point = end_point 11 | 12 | # Store the lat/long of the end point 13 | self.end_point_lat = None 14 | self.end_point_long = None 15 | 16 | # Keep track of the next timestep to start 17 | self.next_time = 0 18 | self.steps = 0 19 | 20 | # Pyreason settings 21 | pr.settings.verbose = False 22 | pr.settings.atom_trace = True 23 | pr.settings.canonical = True 24 | pr.settings.inconsistency_check = False 25 | pr.settings.static_graph_facts = False 26 | current_path = os.path.abspath(os.path.dirname(__file__)) 27 | 28 | # Load the graph 29 | pr.load_graph(f'{current_path}/graph/map_graph_clustering.graphml') 30 | 31 | # Load rules 32 | pr.load_rules(f'{current_path}/yamls/rules.yaml') 33 | 34 | def reset(self): 35 | # Reason for 1 timestep to initialize everything 36 | # Certain internal variables need to be reset otherwise memory blows up 37 | pr.reset() 38 | self.interpretation = pr.reason(0, again=False) 39 | self.next_time = self.interpretation.time + 1 40 | 41 | # Store the lat/long of the end point 42 | self.end_point_lat, self.end_point_long = self._get_lat_long(self.end_point) 43 | 44 | def move(self, action): 45 | # Define facts, then run pyreason 46 | # action input is a number corresponding to which path (edge from one node to another) the agent should take 47 | facts = [] 48 | fact_on = pr.fact_node.Fact(f'move_{self.steps}', 'agent', pr.label.Label(f'move_{action}'), pr.interval.closed(1, 1), self.next_time, self.next_time) 49 | fact_off = pr.fact_node.Fact(f'move_{self.steps}', 'agent', pr.label.Label(f'move_{action}'), pr.interval.closed(0, 0), self.next_time + 1, self.next_time + 1) 50 | facts.append(fact_on) 51 | facts.append(fact_off) 52 | 53 | self.interpretation = pr.reason(2, again=True, node_facts=facts) 54 | self.next_time = self.interpretation.time + 1 55 | self.steps += 1 56 | 57 | def get_obs(self): 58 | # Calculate current and end point lat longs 59 | relevant_edges = [edge for edge in self.interpretation.edges if edge[0] == 'agent' and self.interpretation.interpretations_edge[edge].world[pr.label.Label('atLoc')] == pr.interval.closed(1,1)] 60 | print(relevant_edges) 61 | assert len(relevant_edges) == 1, 'Agent cannot be in multiple places at once--mistake in the interpretation data' 62 | current_edge = relevant_edges[0] 63 | loc = current_edge[1] 64 | 65 | lat, long = self._get_lat_long(loc) 66 | 67 | current_lat_long = np.array([lat, long], dtype=np.float128) 68 | end_lat_long = np.array([self.end_point_lat, self.end_point_long], dtype=np.float128) 69 | 70 | # Get info about current action space 71 | # Get number of outgoing edges. New action space = num outgoing edges 72 | outgoing_edges = [edge for edge in self.interpretation.edges if edge[1] == current_edge[1] and edge[0] != 'agent'] 73 | num_outgoing_edges = len(outgoing_edges) 74 | 75 | observation = (loc, current_lat_long, end_lat_long, num_outgoing_edges) 76 | return observation 77 | 78 | def _get_lat_long(self, node): 79 | world = self.interpretation.interpretations_node[node].world 80 | lat = None 81 | long = None 82 | for label, interval in world.items(): 83 | # Represented internally by lat-x and long-y 84 | if 'lat' in label._value: 85 | lat = float(label._value[4:]) 86 | elif 'long' in label._value: 87 | long = float(label._value[5:]) 88 | 89 | return lat, long 90 | 91 | def get_map(self): 92 | nodes = [node for node in self.interpretation.nodes if node != 'agent'] 93 | edges = [edge for edge in self.interpretation.edges if edge[0] != 'agent'] 94 | 95 | # Return list of nodes (landmarks/stops) and list of edges connecting these points 96 | nodes_lat_long = [(self._get_lat_long(node)) for node in nodes] 97 | edges_lat_long = [((self._get_lat_long(edge[0])), (self._get_lat_long(edge[1]))) for edge in edges] 98 | return nodes_lat_long, edges_lat_long 99 | -------------------------------------------------------------------------------- /pyreason_gym/pyreason_map_world/yamls/rules.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | move_rule_0: 3 | target: 4 | target_criteria: 5 | - [move_0, 1, 1] 6 | delta_t: 1 7 | neigh_criteria: 8 | - [edge, [target, oldLoc], atLoc, [1,1]] 9 | - [edge, [oldLoc, newLoc], path-0, [1,1]] 10 | 11 | edges: [target, newLoc, atLoc] 12 | ann_fn: [1,1] 13 | 14 | move_change_prev_rule_0: 15 | target: atLoc 16 | target_criteria: 17 | - [atLoc, 1, 1] 18 | delta_t: 1 19 | neigh_criteria: 20 | - [node, [source], move_0, [1,1]] 21 | - [edge, [target, newLoc], path-0, [1,1]] 22 | 23 | ann_fn: [0,0] 24 | 25 | move_rule_1: 26 | target: 27 | target_criteria: 28 | - [move_1, 1, 1] 29 | delta_t: 1 30 | neigh_criteria: 31 | - [edge, [target, oldLoc], atLoc, [1,1]] 32 | - [edge, [oldLoc, newLoc], path-1, [1,1]] 33 | 34 | edges: [target, newLoc, atLoc] 35 | ann_fn: [1,1] 36 | 37 | move_change_prev_rule_1: 38 | target: atLoc 39 | target_criteria: 40 | - [atLoc, 1, 1] 41 | delta_t: 1 42 | neigh_criteria: 43 | - [node, [source], move_1, [1,1]] 44 | - [edge, [target, newLoc], path-1, [1,1]] 45 | 46 | ann_fn: [0,0] 47 | 48 | move_rule_2: 49 | target: 50 | target_criteria: 51 | - [move_2, 1, 1] 52 | delta_t: 1 53 | neigh_criteria: 54 | - [edge, [target, oldLoc], atLoc, [1,1]] 55 | - [edge, [oldLoc, newLoc], path-2, [1,1]] 56 | 57 | edges: [target, newLoc, atLoc] 58 | ann_fn: [1,1] 59 | 60 | move_change_prev_rule_2: 61 | target: atLoc 62 | target_criteria: 63 | - [atLoc, 1, 1] 64 | delta_t: 1 65 | neigh_criteria: 66 | - [node, [source], move_2, [1,1]] 67 | - [edge, [target, newLoc], path-2, [1,1]] 68 | 69 | ann_fn: [0,0] 70 | 71 | move_rule_3: 72 | target: 73 | target_criteria: 74 | - [move_3, 1, 1] 75 | delta_t: 1 76 | neigh_criteria: 77 | - [edge, [target, oldLoc], atLoc, [1,1]] 78 | - [edge, [oldLoc, newLoc], path-3, [1,1]] 79 | 80 | edges: [target, newLoc, atLoc] 81 | ann_fn: [1,1] 82 | 83 | move_change_prev_rule_3: 84 | target: atLoc 85 | target_criteria: 86 | - [atLoc, 1, 1] 87 | delta_t: 1 88 | neigh_criteria: 89 | - [node, [source], move_3, [1,1]] 90 | - [edge, [target, newLoc], path-3, [1,1]] 91 | 92 | ann_fn: [0,0] 93 | 94 | move_rule_4: 95 | target: 96 | target_criteria: 97 | - [move_4, 1, 1] 98 | delta_t: 1 99 | neigh_criteria: 100 | - [edge, [target, oldLoc], atLoc, [1,1]] 101 | - [edge, [oldLoc, newLoc], path-4, [1,1]] 102 | 103 | edges: [target, newLoc, atLoc] 104 | ann_fn: [1,1] 105 | 106 | move_change_prev_rule_4: 107 | target: atLoc 108 | target_criteria: 109 | - [atLoc, 1, 1] 110 | delta_t: 1 111 | neigh_criteria: 112 | - [node, [source], move_4, [1,1]] 113 | - [edge, [target, newLoc], path-4, [1,1]] 114 | 115 | ann_fn: [0,0] 116 | 117 | move_rule_5: 118 | target: 119 | target_criteria: 120 | - [move_5, 1, 1] 121 | delta_t: 1 122 | neigh_criteria: 123 | - [edge, [target, oldLoc], atLoc, [1,1]] 124 | - [edge, [oldLoc, newLoc], path-5, [1,1]] 125 | 126 | edges: [target, newLoc, atLoc] 127 | ann_fn: [1,1] 128 | 129 | move_change_prev_rule_5: 130 | target: atLoc 131 | target_criteria: 132 | - [atLoc, 1, 1] 133 | delta_t: 1 134 | neigh_criteria: 135 | - [node, [source], move_5, [1,1]] 136 | - [edge, [target, newLoc], path-5, [1,1]] 137 | 138 | ann_fn: [0,0] 139 | 140 | move_rule_6: 141 | target: 142 | target_criteria: 143 | - [move_6, 1, 1] 144 | delta_t: 1 145 | neigh_criteria: 146 | - [edge, [target, oldLoc], atLoc, [1,1]] 147 | - [edge, [oldLoc, newLoc], path-6, [1,1]] 148 | 149 | edges: [target, newLoc, atLoc] 150 | ann_fn: [1,1] 151 | 152 | move_change_prev_rule_6: 153 | target: atLoc 154 | target_criteria: 155 | - [atLoc, 1, 1] 156 | delta_t: 1 157 | neigh_criteria: 158 | - [node, [source], move_6, [1,1]] 159 | - [edge, [target, newLoc], path-6, [1,1]] 160 | 161 | ann_fn: [0,0] 162 | 163 | move_rule_7: 164 | target: 165 | target_criteria: 166 | - [move_7, 1, 1] 167 | delta_t: 1 168 | neigh_criteria: 169 | - [edge, [target, oldLoc], atLoc, [1,1]] 170 | - [edge, [oldLoc, newLoc], path-7, [1,1]] 171 | 172 | edges: [target, newLoc, atLoc] 173 | ann_fn: [1,1] 174 | 175 | move_change_prev_rule_7: 176 | target: atLoc 177 | target_criteria: 178 | - [atLoc, 1, 1] 179 | delta_t: 1 180 | neigh_criteria: 181 | - [node, [source], move_7, [1,1]] 182 | - [edge, [target, newLoc], path-7, [1,1]] 183 | 184 | ann_fn: [0,0] 185 | 186 | move_rule_8: 187 | target: 188 | target_criteria: 189 | - [move_8, 1, 1] 190 | delta_t: 1 191 | neigh_criteria: 192 | - [edge, [target, oldLoc], atLoc, [1,1]] 193 | - [edge, [oldLoc, newLoc], path-8, [1,1]] 194 | 195 | edges: [target, newLoc, atLoc] 196 | ann_fn: [1,1] 197 | 198 | move_change_prev_rule_8: 199 | target: atLoc 200 | target_criteria: 201 | - [atLoc, 1, 1] 202 | delta_t: 1 203 | neigh_criteria: 204 | - [node, [source], move_8, [1,1]] 205 | - [edge, [target, newLoc], path-8, [1,1]] 206 | 207 | ann_fn: [0,0] 208 | 209 | move_rule_9: 210 | target: 211 | target_criteria: 212 | - [move_9, 1, 1] 213 | delta_t: 1 214 | neigh_criteria: 215 | - [edge, [target, oldLoc], atLoc, [1,1]] 216 | - [edge, [oldLoc, newLoc], path-9, [1,1]] 217 | 218 | edges: [target, newLoc, atLoc] 219 | ann_fn: [1,1] 220 | 221 | move_change_prev_rule_9: 222 | target: atLoc 223 | target_criteria: 224 | - [atLoc, 1, 1] 225 | delta_t: 1 226 | neigh_criteria: 227 | - [node, [source], move_9, [1,1]] 228 | - [edge, [target, newLoc], path-9, [1,1]] 229 | 230 | ann_fn: [0,0] 231 | 232 | move_rule_10: 233 | target: 234 | target_criteria: 235 | - [move_10, 1, 1] 236 | delta_t: 1 237 | neigh_criteria: 238 | - [edge, [target, oldLoc], atLoc, [1,1]] 239 | - [edge, [oldLoc, newLoc], path-10, [1,1]] 240 | 241 | edges: [target, newLoc, atLoc] 242 | ann_fn: [1,1] 243 | 244 | move_change_prev_rule_10: 245 | target: atLoc 246 | target_criteria: 247 | - [atLoc, 1, 1] 248 | delta_t: 1 249 | neigh_criteria: 250 | - [node, [source], move_10, [1,1]] 251 | - [edge, [target, newLoc], path-10, [1,1]] 252 | 253 | ann_fn: [0,0] 254 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="pyreason_gym", 5 | version="0.0.1", 6 | py_modules=[], 7 | install_requires=["gym", "pygame", "networkx", "pyreason==1.6.4", "numpy"], 8 | ) 9 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | 4 | env = gym.make('PyReasonGridWorld-v0', render_mode='human') 5 | obs = env.reset() 6 | 7 | # Sample actions: 8 | # action = { 9 | # 'red_team': [0], 10 | # 'blue_team': [1] 11 | # } 12 | # obs = env.step(action) 13 | # action = { 14 | # 'red_team': [2], 15 | # 'blue_team': [3] 16 | # } 17 | # obs = env.step(action) 18 | 19 | # Randomly sample actions from the action space 20 | for i in range(50): 21 | action = env.action_space.sample() 22 | print(action) 23 | env.step(action) 24 | 25 | env.close() 26 | 27 | # env = gym.make('PyReasonMapWorld-v0', start_point='node1', end_point='node2') 28 | # obs = env.reset() 29 | # print(obs) 30 | -------------------------------------------------------------------------------- /tests/agent_criss_cross/gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=1, base_loc=[7, 56], start_loc=[[2], [5]], 101 | obstacle_loc=[26, 27, 34, 35, 36, 44]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/agent_criss_cross/test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human') 6 | obs = env.reset() 7 | 8 | for _ in range(8): 9 | action = { 10 | 'red_team': [3], 11 | 'blue_team': [2] 12 | } 13 | obs = env.step(action) 14 | print(obs) 15 | time.sleep(1) 16 | env.close() -------------------------------------------------------------------------------- /tests/agent_same_location_shoot/gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=1, base_loc=[7, 56], start_loc=[[2], [3]], 101 | obstacle_loc=[26, 27, 34, 35, 36, 44]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/agent_same_location_shoot/test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human') 6 | obs = env.reset() 7 | action = { 8 | 'red_team': [3], 9 | 'blue_team': [1] 10 | } 11 | obs = env.step(action) 12 | print(obs) 13 | time.sleep(1) 14 | action = { 15 | 'red_team': [4], 16 | 'blue_team': [1] 17 | } 18 | obs = env.step(action) 19 | print(obs) 20 | time.sleep(1) 21 | action = { 22 | 'red_team': [3], 23 | 'blue_team': [1] 24 | } 25 | obs = env.step(action) 26 | print(obs) 27 | time.sleep(1) 28 | env.close() -------------------------------------------------------------------------------- /tests/base_done_check/gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=1, base_loc=[7, 56], start_loc=[[2], [15]], 101 | obstacle_loc=[26, 27, 34, 35, 36, 44]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/base_done_check/test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human') 6 | obs = env.reset() 7 | 8 | # Sample actions: 9 | action = { 10 | 'red_team': [0], 11 | 'blue_team': [1] 12 | } 13 | obs = env.step(action) 14 | print(obs) 15 | for _ in range(8): 16 | time.sleep(1) 17 | action = { 18 | 'red_team': [0], 19 | 'blue_team': [1] 20 | } 21 | obs = env.step(action) 22 | print(obs) 23 | env.close() -------------------------------------------------------------------------------- /tests/bulk_movement_test/gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=1, base_loc=[7, 56], start_loc=[[2], [5]], 101 | obstacle_loc=[26, 27, 34, 35, 36, 44]) 102 | 103 | if __name__ == '__main__': 104 | main() 105 | -------------------------------------------------------------------------------- /tests/bulk_movement_test/test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | import random 5 | 6 | env = gym.make('PyReasonGridWorld-v0', render_mode=None) 7 | obs = env.reset() 8 | 9 | # print("Reset:", obs) 10 | exp_start = time.time() 11 | for i in range(10000): 12 | # time.sleep(1) 13 | # action = env.action_space.sample() 14 | # print(action) 15 | action = { 16 | 'red_team': [random.randint(0, 3)], 17 | 'blue_team': [random.randint(0, 3)] 18 | } 19 | start = time.time() 20 | obs, _, terminated, truncated, _ = env.step(action) 21 | print(f"Time for single step {i}:", time.time() - start) 22 | done = terminated or truncated 23 | if done: 24 | obs = env.reset() 25 | print("Time for 10000 steps:", time.time() - exp_start) 26 | env.close() -------------------------------------------------------------------------------- /tests/bullet_pass_through/bullet_pass_through_gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=1, base_loc=[7, 56], start_loc=[[2], [5]], 101 | obstacle_loc=[]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/bullet_pass_through/bullet_pass_through_test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human') 6 | obs = env.reset() 7 | 8 | # Sample actions: 9 | action = { 10 | 'red_team': [7], 11 | 'blue_team': [6] 12 | } 13 | obs = env.step(action) 14 | time.sleep(1) 15 | action = { 16 | 'red_team': [2], 17 | 'blue_team': [3] 18 | } 19 | obs = env.step(action) 20 | time.sleep(1) 21 | action = { 22 | 'red_team': [2], 23 | 'blue_team': [3] 24 | } 25 | obs = env.step(action) 26 | time.sleep(1) 27 | action = { 28 | 'red_team': [2], 29 | 'blue_team': [3] 30 | } 31 | obs = env.step(action) 32 | print(obs) 33 | time.sleep(1) 34 | action = { 35 | 'red_team': [2], 36 | 'blue_team': [3] 37 | } 38 | obs = env.step(action) 39 | print(obs) 40 | time.sleep(1) 41 | env.close() -------------------------------------------------------------------------------- /tests/bullet_speed/bullet_speed_gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=10, num_agents_per_team=1, base_loc=[7, 56], start_loc=[[0], [2]], 101 | obstacle_loc=[]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/bullet_speed/bullet_speed_test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human') 6 | obs = env.reset() 7 | 8 | # Sample actions: 9 | action = { 10 | 'red_team': [7], 11 | 'blue_team': [3] 12 | } 13 | obs = env.step(action) 14 | print(obs) 15 | time.sleep(1) 16 | for _ in range(20): 17 | action = { 18 | 'red_team': [1], 19 | 'blue_team': [3] 20 | } 21 | obs = env.step(action) 22 | print(obs) 23 | time.sleep(1) 24 | env.close() -------------------------------------------------------------------------------- /tests/bullet_speed_large/bullet_speed_large_gen_graph_test_.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=30, num_agents_per_team=1, base_loc=[7, 56], start_loc=[[0], [2]], 101 | obstacle_loc=[]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/bullet_speed_large/bullet_speed_large_test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human', grid_size=30) 6 | obs = env.reset() 7 | steps = 0 8 | # Sample actions: 9 | action = { 10 | 'red_team': [7], 11 | 'blue_team': [3] 12 | } 13 | obs = env.step(action) 14 | steps +=1 15 | print("step:", steps) 16 | print(obs) 17 | time.sleep(1) 18 | for _ in range(35): 19 | action = { 20 | 'red_team': [1], 21 | 'blue_team': [3] 22 | } 23 | obs = env.step(action) 24 | steps +=1 25 | print("step:", steps) 26 | print(obs) 27 | time.sleep(1) 28 | env.close() -------------------------------------------------------------------------------- /tests/follow_bullet/follow_bullet_gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=1, base_loc=[7, 56], start_loc=[[56], [6]], 101 | obstacle_loc=[]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/follow_bullet/follow_bullet_test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human') 6 | obs = env.reset() 7 | 8 | # Sample actions: 9 | action = { 10 | 'red_team': [5], 11 | 'blue_team': [1] 12 | } 13 | obs = env.step(action) 14 | print(obs) 15 | time.sleep(1) 16 | for _ in range(10): 17 | action = { 18 | 'red_team': [1], 19 | 'blue_team': [1] 20 | } 21 | obs = env.step(action) 22 | print(obs) 23 | time.sleep(1) 24 | env.close() -------------------------------------------------------------------------------- /tests/health_check/gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=1, base_loc=[7, 56], start_loc=[[2], [5]], 101 | obstacle_loc=[]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/health_check/test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human') 6 | obs = env.reset() 7 | 8 | # Sample actions: 9 | action = { 10 | 'red_team': [7], 11 | 'blue_team': [1] 12 | } 13 | obs = env.step(action) 14 | print(obs) 15 | time.sleep(1) 16 | action = { 17 | 'red_team': [1], 18 | 'blue_team': [1] 19 | } 20 | obs = env.step(action) 21 | print(obs) 22 | time.sleep(1) 23 | action = { 24 | 'red_team': [1], 25 | 'blue_team': [1] 26 | } 27 | obs = env.step(action) 28 | print(obs) 29 | time.sleep(1) 30 | action = { 31 | 'red_team': [1], 32 | 'blue_team': [1] 33 | } 34 | obs = env.step(action) 35 | print(obs) 36 | time.sleep(1) 37 | action = { 38 | 'red_team': [1], 39 | 'blue_team': [1] 40 | } 41 | obs = env.step(action) 42 | print(obs) 43 | time.sleep(1) 44 | env.close() -------------------------------------------------------------------------------- /tests/multi_agent_done_check/gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=2, base_loc=[7, 56], start_loc=[[2,16], [5, 21]], 101 | obstacle_loc=[26, 27, 34, 35, 36, 44]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/multi_agent_done_check/test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human', num_agents_per_team=2) 6 | obs = env.reset() 7 | 8 | action = { 9 | 'red_team': [7,7], 10 | 'blue_team': [1,3] 11 | } 12 | obs = env.step(action) 13 | print(obs) 14 | time.sleep(1) 15 | action = { 16 | 'red_team': [1,2], 17 | 'blue_team': [1,2] 18 | } 19 | obs = env.step(action) 20 | print(obs) 21 | time.sleep(1) 22 | action = { 23 | 'red_team': [1,2], 24 | 'blue_team': [1,3] 25 | } 26 | obs = env.step(action) 27 | print(obs) 28 | time.sleep(1) 29 | action = { 30 | 'red_team': [1,2], 31 | 'blue_team': [1,2] 32 | } 33 | obs = env.step(action) 34 | print(obs) 35 | time.sleep(1) 36 | action = { 37 | 'red_team': [2,3], 38 | 'blue_team': [0,1] 39 | } 40 | obs = env.step(action) 41 | print(obs) 42 | time.sleep(1) 43 | action = { 44 | 'red_team': [2,3], 45 | 'blue_team': [0,3] 46 | } 47 | obs = env.step(action) 48 | print(obs) 49 | time.sleep(1) 50 | action = { 51 | 'red_team': [2,3], 52 | 'blue_team': [0,3] 53 | } 54 | obs = env.step(action) 55 | print(obs) 56 | time.sleep(1) 57 | env.close() -------------------------------------------------------------------------------- /tests/multi_agent_eval/gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=2, base_loc=[7, 56], start_loc=[[7,7], [56, 56]], 101 | obstacle_loc=[26, 27, 34, 35, 36, 44]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/multi_agent_eval/test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human', num_agents_per_team=5) 6 | obs = env.reset() 7 | 8 | # action = { 9 | # 'red_team': [7,7], 10 | # 'blue_team': [1,3] 11 | # } 12 | # obs = env.step(action) 13 | # print(obs) 14 | # time.sleep(1) 15 | action = { 16 | 'red_team': [1,2], 17 | 'blue_team': [1,2] 18 | } 19 | obs = env.step(action) 20 | print(obs) 21 | time.sleep(1) 22 | action = { 23 | 'red_team': [1,2], 24 | 'blue_team': [1,3] 25 | } 26 | obs = env.step(action) 27 | print(obs) 28 | time.sleep(1) 29 | action = { 30 | 'red_team': [1,2], 31 | 'blue_team': [1,2] 32 | } 33 | obs = env.step(action) 34 | print(obs) 35 | time.sleep(1) 36 | action = { 37 | 'red_team': [2,3], 38 | 'blue_team': [0,1] 39 | } 40 | obs = env.step(action) 41 | print(obs) 42 | time.sleep(1) 43 | action = { 44 | 'red_team': [2,3], 45 | 'blue_team': [0,3] 46 | } 47 | obs = env.step(action) 48 | print(obs) 49 | time.sleep(1) 50 | action = { 51 | 'red_team': [2,3], 52 | 'blue_team': [0,3] 53 | } 54 | obs = env.step(action) 55 | print(obs) 56 | time.sleep(1) 57 | env.close() -------------------------------------------------------------------------------- /tests/multi_agent_shoot/gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=2, base_loc=[7, 56], start_loc=[[2,16], [5, 21]], 101 | obstacle_loc=[26, 27, 34, 35, 36, 44]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/multi_agent_shoot/test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human', num_agents_per_team=2) 6 | obs = env.reset() 7 | 8 | action = { 9 | 'red_team': [7,7], 10 | 'blue_team': [1,3] 11 | } 12 | obs = env.step(action) 13 | print(obs) 14 | time.sleep(1) 15 | action = { 16 | 'red_team': [0,0], 17 | 'blue_team': [1,0] 18 | } 19 | obs = env.step(action) 20 | print(obs) 21 | time.sleep(1) 22 | action = { 23 | 'red_team': [1,2], 24 | 'blue_team': [1,0] 25 | } 26 | obs = env.step(action) 27 | print(obs) 28 | time.sleep(1) 29 | action = { 30 | 'red_team': [1,2], 31 | 'blue_team': [1,0] 32 | } 33 | obs = env.step(action) 34 | print(obs) 35 | time.sleep(1) 36 | action = { 37 | 'red_team': [2,2], 38 | 'blue_team': [1,0] 39 | } 40 | obs = env.step(action) 41 | print(obs) 42 | time.sleep(1) 43 | action = { 44 | 'red_team': [2,2], 45 | 'blue_team': [1,0] 46 | } 47 | obs = env.step(action) 48 | print(obs) 49 | time.sleep(1) 50 | env.close() -------------------------------------------------------------------------------- /tests/multi_agent_shoot_reappear/gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=2, base_loc=[7, 56], start_loc=[[56,48], [55, 63]], 101 | obstacle_loc=[26, 27, 34, 35, 36, 44]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/multi_agent_shoot_reappear/test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human', num_agents_per_team=2) 6 | obs = env.reset() 7 | print("\n") 8 | print("Starting Experiment with init obs:") 9 | print(obs) 10 | print("\n") 11 | action = { 12 | 'red_team': [7, 7], 13 | 'blue_team': [2, 2] 14 | } 15 | obs = env.step(action) 16 | print(obs) 17 | print("\n") 18 | time.sleep(1) 19 | for ss in range(6): 20 | print("Step", ss) 21 | if ss % 2 == 0: 22 | print("Taking both actions") 23 | action = { 24 | 'red_team': [2, 2], 25 | 'blue_team': [2, 2] 26 | } 27 | else: 28 | print("Taking first agent actions") 29 | action = { 30 | 'red_team': [2,2], 31 | 'blue_team': [8,2] 32 | } 33 | obs = env.step(action) 34 | print(obs) 35 | print("\n") 36 | time.sleep(1) 37 | 38 | env.close() -------------------------------------------------------------------------------- /tests/multi_friendly_shoot/gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=2, base_loc=[7, 56], start_loc=[[2,16], [5, 21]], 101 | obstacle_loc=[26, 27, 34, 35, 36, 44]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/multi_friendly_shoot/test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human', num_agents_per_team=2) 6 | obs = env.reset() 7 | 8 | action = { 9 | 'red_team': [1,2], 10 | 'blue_team': [1,5] 11 | } 12 | obs = env.step(action) 13 | print(obs) 14 | time.sleep(1) 15 | action = { 16 | 'red_team': [1,2], 17 | 'blue_team': [1,2] 18 | } 19 | obs = env.step(action) 20 | print(obs) 21 | time.sleep(1) 22 | action = { 23 | 'red_team': [1,3], 24 | 'blue_team': [1,2] 25 | } 26 | obs = env.step(action) 27 | print(obs) 28 | time.sleep(1) 29 | env.close() -------------------------------------------------------------------------------- /tests/nop_bullet_freeze/graph_gen.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=1, base_loc=[7, 56], start_loc=[[16], [6]], 101 | obstacle_loc=[]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/nop_bullet_freeze/test_case.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human') 6 | obs = env.reset() 7 | 8 | # Sample actions: 9 | action = { 10 | 'red_team': [7], 11 | 'blue_team': [8] 12 | } 13 | obs = env.step(action) 14 | print(obs) 15 | time.sleep(1) 16 | action = { 17 | 'red_team': [7], 18 | 'blue_team': [8] 19 | } 20 | obs = env.step(action) 21 | print(obs) 22 | time.sleep(1) 23 | for _ in range(10): 24 | action = { 25 | 'red_team': [8], 26 | 'blue_team': [8] 27 | } 28 | obs = env.step(action) 29 | print(obs) 30 | time.sleep(1) 31 | 32 | env.close() -------------------------------------------------------------------------------- /tests/observation_rgb/gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=1, base_loc=[7, 56], start_loc=[[2], [5]], 101 | obstacle_loc=[26, 27, 34, 35, 36, 44]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/observation_rgb/test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='rgb_array') 6 | obs = env.reset() 7 | 8 | # Sample actions: 9 | action = { 10 | 'red_team': [7], 11 | 'blue_team': [1] 12 | } 13 | obs = env.step(action) 14 | print(obs) 15 | print(env.render()) 16 | for _ in range(5): 17 | time.sleep(1) 18 | action = { 19 | 'red_team': [1], 20 | 'blue_team': [1] 21 | } 22 | obs = env.step(action) 23 | print(obs) 24 | print(env.render()) 25 | env.close() -------------------------------------------------------------------------------- /tests/obstacle_shooting/obstacle_shooting_gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=1, base_loc=[7, 56], start_loc=[[1], [6]], 101 | obstacle_loc=[4]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/obstacle_shooting/obstacle_shooting_test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human') 6 | obs = env.reset() 7 | 8 | # Sample actions: 9 | action = { 10 | 'red_team': [7], 11 | 'blue_team': [1] 12 | } 13 | obs = env.step(action) 14 | print(obs) 15 | time.sleep(1) 16 | action = { 17 | 'red_team': [1], 18 | 'blue_team': [1] 19 | } 20 | obs = env.step(action) 21 | print(obs) 22 | time.sleep(1) 23 | action = { 24 | 'red_team': [1], 25 | 'blue_team': [1] 26 | } 27 | obs = env.step(action) 28 | print(obs) 29 | time.sleep(1) 30 | action = { 31 | 'red_team': [1], 32 | 'blue_team': [1] 33 | } 34 | obs = env.step(action) 35 | print(obs) 36 | time.sleep(1) 37 | action = { 38 | 'red_team': [1], 39 | 'blue_team': [1] 40 | } 41 | obs = env.step(action) 42 | print(obs) 43 | time.sleep(1) 44 | env.close() -------------------------------------------------------------------------------- /tests/out_bounds_shooting/out_bounds_shooting_gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=1, base_loc=[7, 56], start_loc=[[16], [6]], 101 | obstacle_loc=[]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/out_bounds_shooting/out_bounds_shooting_test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human') 6 | obs = env.reset() 7 | 8 | # Sample actions: 9 | action = { 10 | 'red_team': [5], 11 | 'blue_team': [4] 12 | } 13 | obs = env.step(action) 14 | print(obs) 15 | time.sleep(1) 16 | for _ in range(10): 17 | action = { 18 | 'red_team': [2], 19 | 'blue_team': [1] 20 | } 21 | obs = env.step(action) 22 | print(obs) 23 | time.sleep(1) 24 | 25 | env.close() -------------------------------------------------------------------------------- /tests/random_action_sample/gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=1, base_loc=[7, 56], start_loc=[[2], [5]], 101 | obstacle_loc=[26, 27, 34, 35, 36, 44]) 102 | 103 | if __name__ == '__main__': 104 | main() 105 | -------------------------------------------------------------------------------- /tests/random_action_sample/test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human') 6 | obs = env.reset() 7 | 8 | print("Reset:", obs) 9 | for _ in range(20): 10 | time.sleep(1) 11 | action = env.action_space.sample() 12 | 13 | obs, _, terminated, truncated, _ = env.step(action) 14 | print(obs) 15 | done = terminated or truncated 16 | if done: 17 | obs = env.reset() 18 | env.close() -------------------------------------------------------------------------------- /tests/repetitive_shooting/graph_gen.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=1, base_loc=[7, 56], start_loc=[[2], [4]], 101 | obstacle_loc=[]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/repetitive_shooting/test_case.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human') 6 | obs = env.reset() 7 | 8 | # Sample actions: 9 | action = { 10 | 'red_team': [7], 11 | 'blue_team': [1] 12 | } 13 | obs = env.step(action) 14 | print(obs) 15 | time.sleep(1) 16 | action = { 17 | 'red_team': [1], 18 | 'blue_team': [1] 19 | } 20 | obs = env.step(action) 21 | print(obs) 22 | time.sleep(1) 23 | action = { 24 | 'red_team': [7], 25 | 'blue_team': [1] 26 | } 27 | obs = env.step(action) 28 | print(obs) 29 | time.sleep(1) 30 | for _ in range(10): 31 | action = { 32 | 'red_team': [1], 33 | 'blue_team': [1] 34 | } 35 | obs = env.step(action) 36 | print(obs) 37 | time.sleep(1) 38 | 39 | env.close() -------------------------------------------------------------------------------- /tests/same_location_shoot_all_dir/gen_graph_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def generate_graph(grid_dim, num_agents_per_team, base_loc, start_loc, obstacle_loc): 5 | # Check parameters 6 | assert len(base_loc) == 2, 'There are only two bases--supply two locations' 7 | assert len(start_loc) == 2, 'There are only two teams--supply lists of start positions for each team in a nested list' 8 | assert len(start_loc[0]) == num_agents_per_team and len(start_loc[1]) == num_agents_per_team, 'Supply correct number of start locations' 9 | 10 | g = nx.DiGraph() 11 | 12 | # Game variables 13 | game_height = grid_dim 14 | game_width = grid_dim 15 | 16 | # ======================================================================= 17 | # Add a node for each grid location in the game 18 | nodes = list(range(0, game_height * game_width)) 19 | g.add_nodes_from([str(node) for node in nodes], blocked='0,0') 20 | 21 | # ======================================================================= 22 | # Add edges connecting each of the grid nodes. Add up, down, left, right attributes to correct edges 23 | # Right edges 24 | for node in g.nodes: 25 | if (int(node) + 1) % game_width != 0: 26 | g.add_edge(node, str(int(node) + 1), right=1) 27 | 28 | # Left edges 29 | for node in g.nodes: 30 | if int(node) % game_width != 0: 31 | g.add_edge(node, str(int(node) - 1), left=1) 32 | 33 | # Up edges 34 | for node in g.nodes: 35 | if (int(node) // game_width) + 1 != game_height: 36 | g.add_edge(node, str(int(node) + game_width), up=1) 37 | 38 | # Down edges 39 | for node in g.nodes: 40 | if int(node) // game_width != 0: 41 | g.add_edge(node, str(int(node) - game_width), down=1) 42 | 43 | # Add edges between border nodes and end nodes to demarcate the end of the grid world. Bullets will disappear after it crosses this end 44 | g.add_node('end', blocked=1) 45 | # Bottom border 46 | for i in range(game_width): 47 | g.add_edge(f'{i}', 'end', down=1) 48 | # Top border 49 | for i in range(game_width): 50 | g.add_edge(f'{i + game_width * (game_height - 1)}', 'end', up=1) 51 | # Left border 52 | for i in range(game_height): 53 | g.add_edge(f'{i * game_width}', 'end', left=1) 54 | # Right border 55 | for i in range(game_height): 56 | g.add_edge(f'{i * game_width + game_width - 1}', 'end', right=1) 57 | 58 | # ======================================================================= 59 | # Add the bases and connect them to the correct location: bottom right and top left 60 | g.add_node('red-base') 61 | g.add_node('blue-base') 62 | g.add_edge('red-base', str(base_loc[0]), atLoc=1) 63 | g.add_edge('blue-base', str(base_loc[1]), atLoc=1) 64 | 65 | # ======================================================================= 66 | # Add mountains and obstacle attributes 67 | mountain_loc = obstacle_loc 68 | g.add_node('mountain', isMountain=1) 69 | for i in mountain_loc: 70 | g.add_edge(str(i), 'mountain', atLoc=1) 71 | 72 | # ======================================================================= 73 | # Initialize players health, action choice and team 74 | for i in range(1, num_agents_per_team + 1): 75 | g.add_node(f'red-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpRed=0, 76 | shootDownRed=0, shootLeftRed=0, shootRightRed=0, teamRed=1, justDied='0,0') 77 | g.add_node(f'blue-soldier-{i}', health=1, moveUp=0, moveDown=0, moveLeft=0, moveRight=0, shootUpBlue=0, 78 | shootDownBlue=0, shootLeftBlue=0, shootRightBlue=0, teamBlue=1, justDied='0,0') 79 | # Teams 80 | g.add_edge(f'red-soldier-{i}', 'red-base', team=1) 81 | g.add_edge(f'blue-soldier-{i}', 'blue-base', team=1) 82 | # Soldier Start Locations (dual edge) 83 | g.add_edge(f'red-soldier-{i}', str(start_loc[0][i - 1]), atLoc=1) 84 | g.add_edge(f'blue-soldier-{i}', str(start_loc[1][i - 1]), atLoc=1) 85 | g.add_edge(str(start_loc[0][i - 1]), f'red-soldier-{i}', atLoc=1) 86 | g.add_edge(str(start_loc[1][i - 1]), f'blue-soldier-{i}', atLoc=1) 87 | # Bullets 88 | g.add_node(f'red-bullet-{i}', teamRed=1, bullet=1) 89 | g.add_node(f'blue-bullet-{i}', teamBlue=1, bullet=1) 90 | g.add_edge(f'red-soldier-{i}', f'red-bullet-{i}', bullet=1) 91 | g.add_edge(f'blue-soldier-{i}', f'blue-bullet-{i}', bullet=1) 92 | 93 | # ======================================================================= 94 | 95 | nx.write_graphml_lxml(g, 'pyreason_gym/pyreason_grid_world/graph/game_graph.graphml', named_key_ids=True) 96 | 97 | 98 | def main(): 99 | ## Red is first then Blue 100 | generate_graph(grid_dim=8, num_agents_per_team=1, base_loc=[7, 56], start_loc=[[2], [3]], 101 | obstacle_loc=[26, 27, 34, 35, 36, 44]) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /tests/same_location_shoot_all_dir/test.py: -------------------------------------------------------------------------------- 1 | import pyreason_gym 2 | import gym 3 | import time 4 | 5 | env = gym.make('PyReasonGridWorld-v0', render_mode='human') 6 | obs = env.reset() 7 | action = { 8 | 'red_team': [3], 9 | 'blue_team': [1] 10 | } 11 | obs = env.step(action) 12 | print(obs) 13 | time.sleep(1) 14 | # action = { 15 | # 'red_team': [5], 16 | # 'blue_team': [1] 17 | # } 18 | # obs = env.step(action) 19 | # print(obs) 20 | # time.sleep(1) 21 | # action = { 22 | # 'red_team': [6], 23 | # 'blue_team': [1] 24 | # } 25 | # obs = env.step(action) 26 | # print(obs) 27 | # time.sleep(1) 28 | action = { 29 | 'red_team': [7], 30 | 'blue_team': [1] 31 | } 32 | obs = env.step(action) 33 | print(obs) 34 | time.sleep(1) 35 | action = { 36 | 'red_team': [3], 37 | 'blue_team': [1] 38 | } 39 | obs = env.step(action) 40 | print(obs) 41 | time.sleep(1) 42 | action = { 43 | 'red_team': [2], 44 | 'blue_team': [1] 45 | } 46 | obs = env.step(action) 47 | print(obs) 48 | time.sleep(1) 49 | env.close() --------------------------------------------------------------------------------