├── .github └── workflows │ ├── python-package.yml │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── examples ├── colab_example.ipynb ├── interactive_agent.py └── random_agent.py ├── ma_gym ├── __init__.py ├── envs │ ├── __init__.py │ ├── checkers │ │ ├── __init__.py │ │ └── checkers.py │ ├── combat │ │ ├── __init__.py │ │ └── combat.py │ ├── lumberjacks │ │ ├── __init__.py │ │ └── lumberjacks.py │ ├── openai │ │ └── __init__.py │ ├── pong_duel │ │ ├── __init__.py │ │ └── pong_duel.py │ ├── predator_prey │ │ ├── __init__.py │ │ └── predator_prey.py │ ├── switch │ │ ├── __init__.py │ │ └── switch_one_corridor.py │ ├── traffic_junction │ │ ├── __init__.py │ │ └── traffic_junction.py │ └── utils │ │ ├── __init__.py │ │ ├── action_space.py │ │ ├── draw.py │ │ └── observation_space.py └── wrappers │ ├── __init__.py │ ├── monitor.py │ └── monitoring │ ├── __init__.py │ └── stats_recorder.py ├── scripts ├── generate_env_markdown_table.py └── record_environment.py ├── setup.py ├── static └── gif │ ├── Checkers-v0.gif │ ├── Combat-v0.gif │ ├── Lumberjacks-v0.gif │ ├── PongDuel-v0.gif │ ├── PredatorPrey5x5-v0.gif │ ├── PredatorPrey7x7-v0.gif │ ├── Switch2-v0.gif │ ├── Switch4-v0.gif │ ├── TrafficJunction10-v0.gif │ └── TrafficJunction4-v0.gif └── tests ├── __init__.py └── envs ├── __init__.py ├── test_checkers.py ├── test_combat.py ├── test_lumberjacks.py ├── test_openai_cartpole.py ├── test_pong_duel.py ├── test_predatorprey5x5.py ├── test_predatorprey7x7.py ├── test_switch2.py └── test_trafficjunction.py /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | schedule: 12 | # https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule 13 | - cron: '0 1 * * *' 14 | workflow_dispatch: 15 | inputs: 16 | logLevel: 17 | description: 'Log level' 18 | required: true 19 | default: 'warning' 20 | type: choice 21 | options: 22 | - info 23 | - warning 24 | - debug 25 | 26 | jobs: 27 | build: 28 | runs-on: ${{ matrix.os }} 29 | strategy: 30 | matrix: 31 | os: [macos-latest, ubuntu-latest] 32 | python-version: [3.8, 3.9, '3.10', '3.11'] 33 | 34 | steps: 35 | - uses: actions/checkout@v2 36 | - name: Set up Python ${{ matrix.python-version }} 37 | uses: actions/setup-python@v2 38 | with: 39 | python-version: ${{ matrix.python-version }} 40 | - name: Install dependencies 41 | run: | 42 | python -m pip install --upgrade 'pip<24.1' 43 | python -m pip install flake8 pytest 44 | pip install 'setuptools<=66' 45 | pip install 'wheel<=0.38.4' 46 | pip install -e . 47 | pip install -e ".[test]" 48 | pip freeze 49 | - name: Lint with flake8 50 | run: | 51 | # stop the build if there are Python syntax errors or undefined names 52 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 53 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 54 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 55 | - name: Test with pytest 56 | run: | 57 | pytest 58 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | 11 | jobs: 12 | deploy: 13 | 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade 'pip<=23.0.1' 24 | pip install setuptools wheel twine 'readme_renderer[md]' 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pytest_cache 2 | *.egg-info 3 | *.egg 4 | *.pyc 5 | *.coverage 6 | /logs 7 | /dist 8 | /build 9 | /examples/recordings 10 | *~ 11 | 12 | # Pycharm 13 | .idea -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2019 Anurag Koul 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ma-gym 2 | It's a collection of multi agent environments based on OpenAI gym. Also, you can use [**minimal-marl**](https://github.com/koulanurag/minimal-marl) to warm-start training of agents. 
3 |
4 | ![Python package](https://github.com/koulanurag/ma-gym/workflows/Python%20package/badge.svg)
5 | ![Upload Python Package](https://github.com/koulanurag/ma-gym/workflows/Upload%20Python%20Package/badge.svg)
6 | ![Python Version](https://img.shields.io/pypi/pyversions/ma-gym)
7 | [![Downloads](https://static.pepy.tech/badge/ma-gym)](https://pepy.tech/project/ma-gym)
8 | [![Wiki Docs](https://img.shields.io/badge/-Wiki%20Docs-informational?style=flat)](https://github.com/koulanurag/ma-gym/wiki)
9 | [![Papers using ma-gym](https://img.shields.io/badge/-Papers%20using%20ma--gym-9cf?style=flat&logo=googlescholar)](https://scholar.google.com/scholar?oi=bibs&hl=en&cites=14123576959169220642,12284637994392993807)
10 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/koulanurag/ma-gym/blob/master/examples/colab_example.ipynb)
11 |
12 |
13 | ## Installation
14 |
15 | - Setup (important):
16 |   ```bash
17 |   pip install 'pip<24.1'
18 |   pip install 'setuptools<=66'
19 |   pip install 'wheel<=0.38.4'
20 |   ```
21 | - Install package:
22 |   - Using PyPI:
23 |     ```bash
24 |     pip install ma-gym
25 |     ```
26 |
27 |   - Directly from source (recommended):
28 |     ```bash
29 |     git clone https://github.com/koulanurag/ma-gym.git
30 |     cd ma-gym
31 |     pip install -e .
32 |     ```
33 | ## Reference:
34 | Please use this BibTeX entry if you would like to cite it:
35 | ```
36 | @misc{magym,
37 |   author = {Koul, Anurag},
38 |   title = {ma-gym: Collection of multi-agent environments based on OpenAI gym.},
39 |   year = {2019},
40 |   publisher = {GitHub},
41 |   journal = {GitHub repository},
42 |   howpublished = {\url{https://github.com/koulanurag/ma-gym}},
43 | }
44 | ```
45 |
46 | ## Usage:
47 | ```python
48 | import gym
49 |
50 | env = gym.make('ma_gym:Switch2-v0')
51 | done_n = [False for _ in range(env.n_agents)]
52 | ep_reward = 0
53 |
54 | obs_n = env.reset()
55 | while not all(done_n):
56 |     env.render()
57 |     obs_n, reward_n, done_n, info = env.step(env.action_space.sample())
58 |     ep_reward += sum(reward_n)
59 | env.close()
60 | ```
61 |
62 | Please refer to [**Wiki**](https://github.com/koulanurag/ma-gym/wiki/Usage) for complete usage details.
63 |
64 | ## Environments:
65 | - [x] Checkers
66 | - [x] Combat
67 | - [x] PredatorPrey
68 | - [x] Pong Duel ```(two-player pong game)```
69 | - [x] Switch
70 | - [x] Lumberjacks
71 | - [x] TrafficJunction
72 |
73 | ```
74 | Note : OpenAI Gym environments can be accessed in multi-agent form by adding the prefix "ma_", e.g. ma_CartPole-v0.
75 | This returns an instance of CartPole-v0 wrapped in a "multi-agent wrapper" with a single agent.
76 | These environments are helpful during debugging.
77 | ```
78 |
79 | Please refer to [Wiki](https://github.com/koulanurag/ma-gym/wiki/Environments) for more details; a short example of the `ma_` wrapper follows below.
80 |
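For instance, a single-agent Gym environment wrapped this way is stepped with the same list-based API as the other environments. The snippet below is a minimal sketch mirroring the `Usage` section above; it assumes the wrapper exposes the same `n_agents` attribute and list-valued `step()` returns:

```python
import gym

# CartPole-v0 behind the multi-agent wrapper: observations, rewards and dones
# are lists with a single entry, since there is only one agent.
env = gym.make('ma_gym:ma_CartPole-v0')
done_n = [False for _ in range(env.n_agents)]
obs_n = env.reset()
while not all(done_n):
    obs_n, reward_n, done_n, info = env.step(env.action_space.sample())
env.close()
```

81 | ## Zoo!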
82 | 83 | | __Checkers-v0__ | __Combat-v0__ | __Lumberjacks-v0__ | 84 | |:---:|:---:|:---:| 85 | |![Checkers-v0.gif](https://raw.githubusercontent.com/koulanurag/ma-gym/master/static/gif/Checkers-v0.gif)|![Combat-v0.gif](https://raw.githubusercontent.com/koulanurag/ma-gym/master/static/gif/Combat-v0.gif)|![Lumberjacks-v0.gif](https://raw.githubusercontent.com/koulanurag/ma-gym/master/static/gif/Lumberjacks-v0.gif)| 86 | | __PongDuel-v0__ | __PredatorPrey5x5-v0__ | __PredatorPrey7x7-v0__ | 87 | | ![PongDuel-v0.gif](https://raw.githubusercontent.com/koulanurag/ma-gym/master/static/gif/PongDuel-v0.gif) | ![PredatorPrey5x5-v0.gif](https://raw.githubusercontent.com/koulanurag/ma-gym/master/static/gif/PredatorPrey5x5-v0.gif) | ![PredatorPrey7x7-v0.gif](https://raw.githubusercontent.com/koulanurag/ma-gym/master/static/gif/PredatorPrey7x7-v0.gif) | 88 | | __Switch2-v0__ | __Switch4-v0__ | __TrafficJunction4-v0__ | | 89 | | ![Switch2-v0.gif](https://raw.githubusercontent.com/koulanurag/ma-gym/master/static/gif/Switch2-v0.gif) | ![Switch4-v0.gif](https://raw.githubusercontent.com/koulanurag/ma-gym/master/static/gif/Switch4-v0.gif)|![TrafficJunction4-v0.gif](https://raw.githubusercontent.com/koulanurag/ma-gym/master/static/gif/TrafficJunction4-v0.gif)| 90 | | __TrafficJunction10-v0__ | 91 | |![TrafficJunction10-v0.gif](https://raw.githubusercontent.com/koulanurag/ma-gym/master/static/gif/TrafficJunction10-v0.gif)| | | 92 | 93 | ## Testing: 94 | 95 | - Install: ```pip install -e ".[test]" ``` 96 | - Run: ```pytest``` 97 | 98 | 99 | ## Acknowledgement: 100 | - This project was initially developed to complement my research internship @ [SAS](https://www.sas.com/en_us/home.html) (Summer - 2019). 101 | 102 | 103 | -------------------------------------------------------------------------------- /examples/colab_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [] 7 | }, 8 | "kernelspec": { 9 | "name": "python3", 10 | "display_name": "Python 3" 11 | }, 12 | "language_info": { 13 | "name": "python" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "source": [ 20 | "# **Dependency Setup**" 21 | ], 22 | "metadata": { 23 | "id": "7bic0d8YZt26" 24 | } 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": { 30 | "colab": { 31 | "base_uri": "https://localhost:8080/" 32 | }, 33 | "id": "Mgxn1Fw6ZAs7", 34 | "outputId": "c08fcc0c-612d-444a-f184-d869b788e675" 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "!pip install --root-user-action=ignore 'pip<=23.0.1'\n", 39 | "!pip install --root-user-action=ignore 'setuptools<=66'\n", 40 | "!pip install --root-user-action=ignore 'wheel<=0.38.4'\n", 41 | "!pip install --root-user-action=ignore ma-gym" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "source": [ 47 | "# **Demo Run:**" 48 | ], 49 | "metadata": { 50 | "id": "tl0aaZuMbAEF" 51 | } 52 | }, 53 | { 54 | "cell_type": "code", 55 | "source": [ 56 | "import gym\n", 57 | "\n", 58 | "env = gym.make('ma_gym:Switch2-v0')\n", 59 | "done_n = [False for _ in range(env.n_agents)]\n", 60 | "ep_reward = 0\n", 61 | "\n", 62 | "obs_n = env.reset()\n", 63 | "while not all(done_n):\n", 64 | " obs_n, reward_n, done_n, info = env.step(env.action_space.sample())\n", 65 | " ep_reward += sum(reward_n)\n", 66 | "env.close()\n", 67 | "print(f'Episode Return:{ep_reward:.2f}')" 68 | ], 69 | "metadata": { 70 | "colab": { 71 | 
"base_uri": "https://localhost:8080/" 72 | }, 73 | "id": "mEQ9X0cQZtmV", 74 | "outputId": "84fae668-a2f9-495b-d95f-0e40a77d0f24" 75 | }, 76 | "execution_count": null, 77 | "outputs": [] 78 | } 79 | ] 80 | } 81 | -------------------------------------------------------------------------------- /examples/interactive_agent.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | 5 | from ma_gym.wrappers import Monitor 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser(description='Interactive Agent for ma-gym') 9 | parser.add_argument('--env', default='Checkers-v0', 10 | help='Name of the environment (default: %(default)s)') 11 | parser.add_argument('--episodes', type=int, default=1, 12 | help='episodes (default: %(default)s)') 13 | args = parser.parse_args() 14 | 15 | print('Enter the actions space together and press enter ( Eg: \'11\' which meanes take 1' 16 | ' for agent 1 and 1 for agent 2)') 17 | 18 | env = gym.make('ma_gym:{}'.format(args.env)) 19 | env = Monitor(env, directory='recordings', force=True) 20 | for ep_i in range(args.episodes): 21 | done_n = [False for _ in range(env.n_agents)] 22 | ep_reward = 0 23 | 24 | obs_n = env.reset() 25 | env.render() 26 | while not all(done_n): 27 | action_n = [int(_) for _ in input('Action:')] 28 | obs_n, reward_n, done_n, _ = env.step(action_n) 29 | ep_reward += sum(reward_n) 30 | env.render() 31 | 32 | print('Episode #{} Reward: {}'.format(ep_i, ep_reward)) 33 | env.close() 34 | -------------------------------------------------------------------------------- /examples/random_agent.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | 5 | from ma_gym.wrappers import Monitor 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser(description='Random Agent for ma-gym') 9 | parser.add_argument('--env', default='Checkers-v0', 10 | help='Name of the environment (default: %(default)s)') 11 | parser.add_argument('--episodes', type=int, default=1, 12 | help='episodes (default: %(default)s)') 13 | args = parser.parse_args() 14 | 15 | env = gym.make(args.env) 16 | env = Monitor(env, directory='recordings/' + args.env, force=True) 17 | for ep_i in range(args.episodes): 18 | done_n = [False for _ in range(env.n_agents)] 19 | ep_reward = 0 20 | 21 | env.seed(ep_i) 22 | obs_n = env.reset() 23 | env.render() 24 | 25 | while not all(done_n): 26 | action_n = env.action_space.sample() 27 | obs_n, reward_n, done_n, info = env.step(action_n) 28 | ep_reward += sum(reward_n) 29 | env.render() 30 | 31 | print('Episode #{} Reward: {}'.format(ep_i, ep_reward)) 32 | env.close() 33 | -------------------------------------------------------------------------------- /ma_gym/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from gym import envs 4 | from gym.envs.registration import register 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | # Register openai's environments as multi agent 9 | # This should be done before registering new environments 10 | env_specs = [env_spec for env_spec in envs.registry.all() if 'gym.envs' in env_spec.entry_point] 11 | for spec in env_specs: 12 | register( 13 | id='ma_' + spec.id, 14 | entry_point='ma_gym.envs.openai:MultiAgentWrapper', 15 | kwargs={'name': spec.id, **spec._kwargs} 16 | ) 17 | 18 | # add new environments : iterate over full observability 19 | for i, observability in 
enumerate([False, True]): 20 | 21 | for clock in [False, True]: 22 | register( 23 | id='Checkers-v{}'.format(i + (2 if clock else 0)), 24 | entry_point='ma_gym.envs.checkers:Checkers', 25 | kwargs={'full_observable': observability, 'step_cost': -0.01, 'clock': clock} 26 | ) 27 | register( 28 | id='Switch2-v{}'.format(i + (2 if clock else 0)), 29 | entry_point='ma_gym.envs.switch:Switch', 30 | kwargs={'n_agents': 2, 'full_observable': observability, 'step_cost': -0.1, 'clock': clock} 31 | ) 32 | register( 33 | id='Switch4-v{}'.format(i + (2 if clock else 0)), 34 | entry_point='ma_gym.envs.switch:Switch', 35 | kwargs={'n_agents': 4, 'full_observable': observability, 'step_cost': -0.1, 'clock': clock} 36 | ) 37 | 38 | for num_max_cars in [4, 10]: 39 | register( 40 | id='TrafficJunction{}-v'.format(num_max_cars) + str(i), 41 | entry_point='ma_gym.envs.traffic_junction:TrafficJunction', 42 | kwargs={'full_observable': observability, 'n_max': num_max_cars} 43 | ) 44 | 45 | register( 46 | id='Lumberjacks-v' + str(i), 47 | entry_point='ma_gym.envs.lumberjacks:Lumberjacks', 48 | kwargs={'full_observable': observability} 49 | ) 50 | 51 | register( 52 | id='Combat-v0', 53 | entry_point='ma_gym.envs.combat:Combat', 54 | ) 55 | register( 56 | id='PongDuel-v0', 57 | entry_point='ma_gym.envs.pong_duel:PongDuel', 58 | ) 59 | 60 | for game_info in [[(5, 5), 2, 1], [(7, 7), 4, 2]]: # [(grid_shape, predator_n, prey_n),..] 61 | grid_shape, n_agents, n_preys = game_info 62 | _game_name = 'PredatorPrey{}x{}'.format(grid_shape[0], grid_shape[1]) 63 | register( 64 | id='{}-v0'.format(_game_name), 65 | entry_point='ma_gym.envs.predator_prey:PredatorPrey', 66 | kwargs={ 67 | 'grid_shape': grid_shape, 'n_agents': n_agents, 'n_preys': n_preys 68 | } 69 | ) 70 | # fully -observable ( each agent sees observation of other agents) 71 | register( 72 | id='{}-v1'.format(_game_name), 73 | entry_point='ma_gym.envs.predator_prey:PredatorPrey', 74 | kwargs={ 75 | 'grid_shape': grid_shape, 'n_agents': n_agents, 'n_preys': n_preys, 'full_observable': True 76 | } 77 | ) 78 | 79 | # prey is initialized at random location and thereafter doesn't move 80 | register( 81 | id='{}-v2'.format(_game_name), 82 | entry_point='ma_gym.envs.predator_prey:PredatorPrey', 83 | kwargs={ 84 | 'grid_shape': grid_shape, 'n_agents': n_agents, 'n_preys': n_preys, 85 | 'prey_move_probs': [0, 0, 0, 0, 1] 86 | } 87 | ) 88 | 89 | # full observability + prey is initialized at random location and thereafter doesn't move 90 | register( 91 | id='{}-v3'.format(_game_name), 92 | entry_point='ma_gym.envs.predator_prey:PredatorPrey', 93 | kwargs={ 94 | 'grid_shape': grid_shape, 'n_agents': n_agents, 'n_preys': n_preys, 'full_observable': True, 95 | 'prey_move_probs': [0, 0, 0, 0, 1] 96 | } 97 | ) 98 | -------------------------------------------------------------------------------- /ma_gym/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koulanurag/ma-gym/1f0aa3d93f8c2e48b617f5cc04b01cda1f1a3943/ma_gym/envs/__init__.py -------------------------------------------------------------------------------- /ma_gym/envs/checkers/__init__.py: -------------------------------------------------------------------------------- 1 | from .checkers import Checkers -------------------------------------------------------------------------------- /ma_gym/envs/checkers/checkers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | 4 | 
import gym 5 | import numpy as np 6 | from gym import spaces 7 | from gym.utils import seeding 8 | 9 | from ..utils.action_space import MultiAgentActionSpace 10 | from ..utils.draw import draw_grid, fill_cell, draw_circle, write_cell_text 11 | from ..utils.observation_space import MultiAgentObservationSpace 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class Checkers(gym.Env): 17 | """ 18 | The map contains apples and lemons. The first player (red) is very sensitive and scores 10 for 19 | the team for an apple (green square) and −10 for a lemon (orange square). The second (blue), less sensitive 20 | player scores 1 for the team for an apple and −1 for a lemon. There is a wall of lemons between the 21 | players and the apples. Apples and lemons disappear when collected, and the environment resets 22 | when all apples are eaten. It is important that the sensitive agent eats the apples while the less sensitive 23 | agent should leave them to its team mate but clear the way by eating obstructing lemons. 24 | 25 | Reference Paper : Value-Decomposition Networks For Cooperative Multi-Agent Learning ( Section 4.2) 26 | """ 27 | metadata = {'render.modes': ['human', 'rgb_array']} 28 | 29 | def __init__(self, full_observable=False, step_cost=-0.01, max_steps=100, clock=False): 30 | self._grid_shape = (3, 8) 31 | self.n_agents = 2 32 | self._max_steps = max_steps 33 | self._step_count = None 34 | self._step_cost = step_cost 35 | self.full_observable = full_observable 36 | self._add_clock = clock 37 | 38 | self.action_space = MultiAgentActionSpace([spaces.Discrete(5) for _ in range(self.n_agents)]) 39 | self._obs_high = np.ones(2 + (3 * 3 * 5) + (1 if clock else 0)) 40 | self._obs_low = np.zeros(2 + (3 * 3 * 5) + (1 if clock else 0)) 41 | if self.full_observable: 42 | self._obs_high = np.tile(self._obs_high, self.n_agents) 43 | self._obs_low = np.tile(self._obs_low, self.n_agents) 44 | self.observation_space = MultiAgentObservationSpace([spaces.Box(self._obs_low, self._obs_high) 45 | for _ in range(self.n_agents)]) 46 | 47 | self.init_agent_pos = {0: [0, self._grid_shape[1] - 2], 1: [2, self._grid_shape[1] - 2]} 48 | self.agent_reward = {0: {'lemon': -10, 'apple': 10}, 49 | 1: {'lemon': -1, 'apple': 1}} 50 | 51 | self.agent_prev_pos = None 52 | self._base_grid = None 53 | self._full_obs = None 54 | self._agent_dones = None 55 | self.viewer = None 56 | self._food_count = None 57 | self._total_episode_reward = None 58 | self.steps_beyond_done = None 59 | self.seed() 60 | 61 | def get_action_meanings(self, agent_i=None): 62 | if agent_i is not None: 63 | assert agent_i <= self.n_agents 64 | return [ACTION_MEANING[i] for i in range(self.action_space[agent_i].n)] 65 | else: 66 | return [[ACTION_MEANING[i] for i in range(ac.n)] for ac in self.action_space] 67 | 68 | def __draw_base_img(self): 69 | self._base_img = draw_grid(self._grid_shape[0], self._grid_shape[1], cell_size=CELL_SIZE, fill='white') 70 | for row in range(self._grid_shape[0]): 71 | for col in range(self._grid_shape[1]): 72 | if PRE_IDS['wall'] in self._full_obs[row][col]: 73 | fill_cell(self._base_img, (row, col), cell_size=CELL_SIZE, fill=WALL_COLOR, margin=0.05) 74 | elif PRE_IDS['lemon'] in self._full_obs[row][col]: 75 | fill_cell(self._base_img, (row, col), cell_size=CELL_SIZE, fill=LEMON_COLOR, margin=0.05) 76 | elif PRE_IDS['apple'] in self._full_obs[row][col]: 77 | fill_cell(self._base_img, (row, col), cell_size=CELL_SIZE, fill=APPLE_COLOR, margin=0.05) 78 | 79 | def __create_grid(self): 80 | """create grid and fill in 
lemon and apple locations. This grid doesn't fill agents location""" 81 | _grid = [] 82 | for row in range(self._grid_shape[0]): 83 | if row % 2 == 0: 84 | _grid.append([PRE_IDS['apple'] if (c % 2 == 0) else PRE_IDS['lemon'] 85 | for c in range(self._grid_shape[1] - 2)] + [PRE_IDS['empty'], PRE_IDS['empty']]) 86 | else: 87 | _grid.append([PRE_IDS['apple'] if (c % 2 != 0) else PRE_IDS['lemon'] 88 | for c in range(self._grid_shape[1] - 2)] + [PRE_IDS['empty'], PRE_IDS['empty']]) 89 | 90 | return _grid 91 | 92 | def __init_full_obs(self): 93 | self.agent_pos = copy.copy(self.init_agent_pos) 94 | self.agent_prev_pos = copy.copy(self.init_agent_pos) 95 | self._full_obs = self.__create_grid() 96 | for agent_i in range(self.n_agents): 97 | self.__update_agent_view(agent_i) 98 | self.__draw_base_img() 99 | 100 | def get_agent_obs(self): 101 | _obs = [] 102 | for agent_i in range(self.n_agents): 103 | pos = self.agent_pos[agent_i] 104 | 105 | # add coordinates 106 | _agent_i_obs = [round(pos[0] / self._grid_shape[0], 2), 107 | round(pos[1] / (self._grid_shape[1] - 1), 2)] 108 | 109 | # add 3 x3 mask around the agent current location and share neighbours 110 | # ( in practice: this information may not be so critical since the map never changes) 111 | _agent_i_neighbour = np.zeros((3, 3, 5)) 112 | for r in range(pos[0] - 1, pos[0] + 2): 113 | for c in range(pos[1] - 1, pos[1] + 2): 114 | if self.is_valid((r, c)): 115 | item = [0, 0, 0, 0, 0] 116 | if PRE_IDS['lemon'] in self._full_obs[r][c]: 117 | item[ITEM_ONE_HOT_INDEX['lemon']] = 1 118 | elif PRE_IDS['apple'] in self._full_obs[r][c]: 119 | item[ITEM_ONE_HOT_INDEX['apple']] = 1 120 | elif PRE_IDS['agent'] in self._full_obs[r][c]: 121 | item[ITEM_ONE_HOT_INDEX[self._full_obs[r][c]]] = 1 122 | elif PRE_IDS['wall'] in self._full_obs[r][c]: 123 | item[ITEM_ONE_HOT_INDEX['wall']] = 1 124 | _agent_i_neighbour[r - (pos[0] - 1)][c - (pos[1] - 1)] = item 125 | _agent_i_obs += _agent_i_neighbour.flatten().tolist() 126 | 127 | # adding time 128 | if self._add_clock: 129 | _agent_i_obs += [self._step_count / self._max_steps] 130 | _obs.append(_agent_i_obs) 131 | 132 | if self.full_observable: 133 | _obs = np.array(_obs).flatten().tolist() 134 | _obs = [_obs for _ in range(self.n_agents)] 135 | return _obs 136 | 137 | def reset(self): 138 | self.__init_full_obs() 139 | self._step_count = 0 140 | self._total_episode_reward = [0 for _ in range(self.n_agents)] 141 | self._food_count = {'lemon': ((self._grid_shape[1] - 2) // 2) * self._grid_shape[0], 142 | 'apple': ((self._grid_shape[1] - 2) // 2) * self._grid_shape[0]} 143 | self._agent_dones = [False for _ in range(self.n_agents)] 144 | self.steps_beyond_done = None 145 | 146 | return self.get_agent_obs() 147 | 148 | def is_valid(self, pos): 149 | return (0 <= pos[0] < self._grid_shape[0]) and (0 <= pos[1] < self._grid_shape[1]) 150 | 151 | def _has_no_agent(self, pos): 152 | return self.is_valid(pos) and (PRE_IDS['agent'] not in self._full_obs[pos[0]][pos[1]]) 153 | 154 | def __update_agent_pos(self, agent_i, move): 155 | 156 | curr_pos = copy.copy(self.agent_pos[agent_i]) 157 | if move == 0: # down 158 | next_pos = [curr_pos[0] + 1, curr_pos[1]] 159 | elif move == 1: # left 160 | next_pos = [curr_pos[0], curr_pos[1] - 1] 161 | elif move == 2: # up 162 | next_pos = [curr_pos[0] - 1, curr_pos[1]] 163 | elif move == 3: # right 164 | next_pos = [curr_pos[0], curr_pos[1] + 1] 165 | elif move == 4: # no-op 166 | next_pos = None 167 | else: 168 | raise Exception('Action Not found!') 169 | 170 | 
self.agent_prev_pos[agent_i] = self.agent_pos[agent_i] 171 | if next_pos is not None and self._has_no_agent(next_pos): 172 | self.agent_pos[agent_i] = next_pos 173 | 174 | def __update_agent_view(self, agent_i): 175 | self._full_obs[self.agent_prev_pos[agent_i][0]][self.agent_prev_pos[agent_i][1]] = PRE_IDS['empty'] 176 | self._full_obs[self.agent_pos[agent_i][0]][self.agent_pos[agent_i][1]] = PRE_IDS['agent'] + str(agent_i + 1) 177 | 178 | def step(self, agents_action): 179 | assert (self._step_count is not None), \ 180 | "Call reset before using step method." 181 | 182 | assert len(agents_action) == self.n_agents 183 | 184 | self._step_count += 1 185 | rewards = [self._step_cost for _ in range(self.n_agents)] 186 | 187 | for agent_i, action in enumerate(agents_action): 188 | 189 | self.__update_agent_pos(agent_i, action) 190 | 191 | if self.agent_pos[agent_i] != self.agent_prev_pos[agent_i]: 192 | for food in ['lemon', 'apple']: 193 | if PRE_IDS[food] in self._full_obs[self.agent_pos[agent_i][0]][self.agent_pos[agent_i][1]]: 194 | rewards[agent_i] += self.agent_reward[agent_i][food] 195 | self._food_count[food] -= 1 196 | break 197 | 198 | self.__update_agent_view(agent_i) 199 | 200 | if self._step_count >= self._max_steps or self._food_count['apple'] == 0: 201 | for i in range(self.n_agents): 202 | self._agent_dones[i] = True 203 | 204 | for i in range(self.n_agents): 205 | self._total_episode_reward[i] += rewards[i] 206 | 207 | # Following snippet of code was refereed from: 208 | # https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py#L144 209 | if self.steps_beyond_done is None and all(self._agent_dones): 210 | self.steps_beyond_done = 0 211 | elif self.steps_beyond_done is not None: 212 | if self.steps_beyond_done == 0: 213 | logger.warning( 214 | "You are calling 'step()' even though this environment has already returned all(dones) = True for " 215 | "all agents. You should always call 'reset()' once you receive 'all(dones) = True' -- any further" 216 | " steps are undefined behavior.") 217 | self.steps_beyond_done += 1 218 | rewards = [0 for _ in range(self.n_agents)] 219 | 220 | return self.get_agent_obs(), rewards, self._agent_dones, {'food_count': self._food_count} 221 | 222 | def render(self, mode='human'): 223 | assert (self._step_count is not None), \ 224 | "Call reset before using render method." 
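        # Note: the static grid, apples and lemons are drawn once onto self._base_img in
        # __draw_base_img(); the loop below only repaints each agent's previous and
        # current cell in white and then draws the agent's circle and id on top.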
225 | 226 | for agent_i in range(self.n_agents): 227 | fill_cell(self._base_img, self.agent_pos[agent_i], cell_size=CELL_SIZE, fill='white', margin=0.05) 228 | fill_cell(self._base_img, self.agent_prev_pos[agent_i], cell_size=CELL_SIZE, fill='white', margin=0.05) 229 | draw_circle(self._base_img, self.agent_pos[agent_i], cell_size=CELL_SIZE, fill=AGENT_COLORS[agent_i]) 230 | write_cell_text(self._base_img, text=str(agent_i + 1), pos=self.agent_pos[agent_i], cell_size=CELL_SIZE, 231 | fill='white', margin=0.4) 232 | 233 | # adds a score board on top of the image 234 | # img = draw_score_board(self._base_img, score=self._total_episode_reward) 235 | # img = np.asarray(img) 236 | 237 | img = np.asarray(self._base_img) 238 | if mode == 'rgb_array': 239 | return img 240 | elif mode == 'human': 241 | from gym.envs.classic_control import rendering 242 | if self.viewer is None: 243 | self.viewer = rendering.SimpleImageViewer() 244 | self.viewer.imshow(img) 245 | return self.viewer.isopen 246 | 247 | def seed(self, n=None): 248 | self.np_random, seed = seeding.np_random(n) 249 | return [seed] 250 | 251 | def close(self): 252 | if self.viewer is not None: 253 | self.viewer.close() 254 | self.viewer = None 255 | 256 | 257 | CELL_SIZE = 30 258 | 259 | ACTION_MEANING = { 260 | 0: "DOWN", 261 | 1: "LEFT", 262 | 2: "UP", 263 | 3: "RIGHT", 264 | 4: "NOOP", 265 | } 266 | 267 | OBSERVATION_MEANING = { 268 | 0: 'empty', 269 | 1: 'lemon', 270 | 2: 'apple', 271 | 3: 'agent', 272 | -1: 'wall' 273 | } 274 | 275 | # each pre-id should be unique and single char 276 | PRE_IDS = { 277 | 'agent': 'A', 278 | 'wall': 'W', 279 | 'empty': '0', 280 | 'lemon': 'Y', # yellow color 281 | 'apple': 'R', # red color 282 | } 283 | 284 | AGENT_COLORS = { 285 | 0: 'red', 286 | 1: 'blue' 287 | } 288 | ITEM_ONE_HOT_INDEX = { 289 | 'lemon': 0, 290 | 'apple': 1, 291 | 'A1': 2, 292 | 'A2': 3, 293 | 'wall': 4, 294 | } 295 | WALL_COLOR = 'black' 296 | LEMON_COLOR = 'yellow' 297 | APPLE_COLOR = 'green' 298 | -------------------------------------------------------------------------------- /ma_gym/envs/combat/__init__.py: -------------------------------------------------------------------------------- 1 | from .combat import Combat 2 | -------------------------------------------------------------------------------- /ma_gym/envs/combat/combat.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import copy 4 | import logging 5 | 6 | import gym 7 | import numpy as np 8 | from gym import spaces 9 | from gym.utils import seeding 10 | 11 | from ..utils.action_space import MultiAgentActionSpace 12 | from ..utils.draw import draw_grid, fill_cell, write_cell_text 13 | from ..utils.observation_space import MultiAgentObservationSpace 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class Combat(gym.Env): 19 | """ 20 | We simulate a simple battle involving two opposing teams in a n x n grid. 21 | Each team consists of m = 5 agents and their initial positions are sampled uniformly in a 5 × 5 22 | square around the team center, which is picked uniformly in the grid. At each time step, an agent can 23 | perform one of the following actions: move one cell in one of four directions; attack another agent 24 | by specifying its ID j (there are m attack actions, each corresponding to one enemy agent); or do 25 | nothing. If agent A attacks agent B, then B’s health point will be reduced by 1, but only if B is inside 26 | the firing range of A (its surrounding 3 × 3 area). 
Agents need one time step of cooling down after
27 |     an attack, during which they cannot attack. All agents start with 3 health points, and die when their
28 |     health reaches 0. A team will win if all agents in the other team die. The simulation ends when one
29 |     team wins, or neither team wins within 40 time steps (a draw).
30 |
31 |     The model controls one team during training, and the other team consists of bots that follow a hardcoded policy.
32 |     The bot policy is to attack the nearest enemy agent if it is within its firing range. If not,
33 |     it approaches the nearest visible enemy agent within visual range. An agent is visible to all bots if it
34 |     is inside the visual range of any individual bot. This shared vision gives an advantage to the bot team.
35 |
36 |     When input to a model, each agent is represented by a set of one-hot binary vectors {i, t, l, h, c}
37 |     encoding its unique ID, team ID, location, health points and cooldown. A model controlling an agent
38 |     also sees other agents in its visual range (5 × 5 surrounding area). The model gets a reward of -1 if the
39 |     team loses or draws at the end of the game. In addition, it also gets a reward of −0.1 times the total
40 |     health points of the enemy team, which encourages it to attack enemy bots.
41 |
42 |     Reference : Learning Multiagent Communication with Backpropagation
43 |     Url : https://papers.nips.cc/paper/6398-learning-multiagent-communication-with-backpropagation.pdf
44 |     """
45 |     metadata = {'render.modes': ['human', 'rgb_array']}
46 |
47 |     def __init__(self, grid_shape=(15, 15), n_agents=5, n_opponents=5, init_health=3, full_observable=False,
48 |                  step_cost=0, max_steps=100, step_cool=1):
49 |         self._grid_shape = grid_shape
50 |         self.n_agents = n_agents
51 |         self._n_opponents = n_opponents
52 |         self._max_steps = max_steps
53 |         self._step_cool = step_cool + 1
54 |         self._step_cost = step_cost
55 |         self._step_count = None
56 |
57 |         self.action_space = MultiAgentActionSpace(
58 |             [spaces.Discrete(5 + self._n_opponents) for _ in range(self.n_agents)])
59 |
60 |         self.agent_pos = {_: None for _ in range(self.n_agents)}
61 |         self.agent_prev_pos = {_: None for _ in range(self.n_agents)}
62 |         self.opp_pos = {_: None for _ in range(self._n_opponents)}
63 |         self.opp_prev_pos = {_: None for _ in range(self._n_opponents)}
64 |
65 |         self._init_health = init_health
66 |         self.agent_health = {_: None for _ in range(self.n_agents)}
67 |         self.opp_health = {_: None for _ in range(self._n_opponents)}
68 |         self._agent_dones = [None for _ in range(self.n_agents)]
69 |         self._agent_cool = {_: None for _ in range(self.n_agents)}
70 |         self._agent_cool_step = {_: None for _ in range(self.n_agents)}
71 |         self._opp_cool = {_: None for _ in range(self._n_opponents)}
72 |         self._opp_cool_step = {_: None for _ in range(self._n_opponents)}
73 |         self._total_episode_reward = None
74 |         self.viewer = None
75 |         self.full_observable = full_observable
76 |
77 |         # 5 * 5 * (type, id, health, cool, x, y)
78 |         self._obs_low = np.repeat(np.array([-1., 0., 0., -1., 0., 0.], dtype=np.float32), 5 * 5)
79 |         self._obs_high = np.repeat(np.array([1., n_opponents, init_health, 1., 1., 1.], dtype=np.float32), 5 * 5)
80 |         self.observation_space = MultiAgentObservationSpace(
81 |             [spaces.Box(self._obs_low, self._obs_high) for _ in range(self.n_agents)])
82 |         self.seed()
83 |
84 |         # For debug only
85 |         self._agents_trace = {_: None for _ in range(self.n_agents)}
86 |         self._opponents_trace = {_: None for _ in range(self._n_opponents)}
87 |
88 |     def get_action_meanings(self,
agent_i=None): 89 | action_meaning = [] 90 | for _ in range(self.n_agents): 91 | meaning = [ACTION_MEANING[i] for i in range(5)] 92 | meaning += ['Attack Opponent {}'.format(o) for o in range(self._n_opponents)] 93 | action_meaning.append(meaning) 94 | if agent_i is not None: 95 | assert isinstance(agent_i, int) 96 | assert agent_i <= self.n_agents 97 | 98 | return action_meaning[agent_i] 99 | else: 100 | return action_meaning 101 | 102 | @staticmethod 103 | def _one_hot_encoding(i, n): 104 | x = np.zeros(n) 105 | x[i] = 1 106 | return x.tolist() 107 | 108 | def get_agent_obs(self): 109 | """ 110 | When input to a model, each agent is represented by a set of one-hot binary vectors {i, t, l, h, c} 111 | encoding its team ID, unique ID, location, health points and cooldown. 112 | A model controlling an agent also sees other agents in its visual range (5 × 5 surrounding area). 113 | :return: 114 | """ 115 | _obs = [] 116 | for agent_i in range(self.n_agents): 117 | # team id , unique id, location, health, cooldown 118 | _agent_i_obs = np.zeros((6, 5, 5)) 119 | hp = self.agent_health[agent_i] 120 | 121 | # If agent is alive 122 | if hp > 0: 123 | # _agent_i_obs = self._one_hot_encoding(agent_i, self.n_agents) 124 | # _agent_i_obs += [pos[0] / self._grid_shape[0], pos[1] / (self._grid_shape[1] - 1)] # coordinates 125 | # _agent_i_obs += [self.agent_health[agent_i]] 126 | # _agent_i_obs += [1 if self._agent_cool else 0] # flag if agent is cooling down 127 | 128 | pos = self.agent_pos[agent_i] 129 | for row in range(0, 5): 130 | for col in range(0, 5): 131 | if self.is_valid([row + (pos[0] - 2), col + (pos[1] - 2)]) and ( 132 | PRE_IDS['empty'] not in self._full_obs[row + (pos[0] - 2)][col + (pos[1] - 2)]): 133 | x = self._full_obs[row + pos[0] - 2][col + pos[1] - 2] 134 | _type = 1 if PRE_IDS['agent'] in x else -1 135 | _id = int(x[1:]) - 1 # id 136 | _agent_i_obs[0][row][col] = _type 137 | _agent_i_obs[1][row][col] = _id 138 | _agent_i_obs[2][row][col] = self.agent_health[_id] if _type == 1 else self.opp_health[_id] 139 | _agent_i_obs[3][row][col] = self._agent_cool[_id] if _type == 1 else self._opp_cool[_id] 140 | _agent_i_obs[3][row][col] = 1 if _agent_i_obs[3][row][col] else -1 # cool/uncool 141 | entity_position = self.agent_pos[_id] if _type == 1 else self.opp_pos[_id] 142 | _agent_i_obs[4][row][col] = entity_position[0] / self._grid_shape[0] # x-coordinate 143 | _agent_i_obs[5][row][col] = entity_position[1] / self._grid_shape[1] # y-coordinate 144 | 145 | _agent_i_obs = _agent_i_obs.flatten().tolist() 146 | _obs.append(_agent_i_obs) 147 | return _obs 148 | 149 | def get_state(self): 150 | state = np.zeros((self.n_agents + self._n_opponents, 6)) 151 | # agent info 152 | for agent_i in range(self.n_agents): 153 | hp = self.agent_health[agent_i] 154 | if hp > 0: 155 | pos = self.agent_pos[agent_i] 156 | feature = np.array([1, agent_i, hp, 1 if self._agent_cool[agent_i] else -1, 157 | pos[0] / self._grid_shape[0], pos[1] / self._grid_shape[1]], dtype=np.float) 158 | state[agent_i] = feature 159 | 160 | # opponent info 161 | for opp_i in range(self._n_opponents): 162 | opp_hp = self.opp_health[opp_i] 163 | if opp_hp > 0: 164 | pos = self.opp_pos[opp_i] 165 | feature = np.array([-1, opp_i, opp_hp, 1 if self._opp_cool[opp_i] else -1, 166 | pos[0] / self._grid_shape[0], pos[1] / self._grid_shape[1]], dtype=np.float) 167 | state[opp_i + self.n_agents] = feature 168 | return state.flatten() 169 | 170 | def get_state_size(self): 171 | return (self.n_agents + self._n_opponents) * 6 172 | 173 | def 
__create_grid(self): 174 | _grid = [[PRE_IDS['empty'] for _ in range(self._grid_shape[1])] for row in range(self._grid_shape[0])] 175 | return _grid 176 | 177 | def __draw_base_img(self): 178 | self._base_img = draw_grid(self._grid_shape[0], self._grid_shape[1], cell_size=CELL_SIZE, fill='white') 179 | 180 | def __update_agent_view(self, agent_i): 181 | self._full_obs[self.agent_prev_pos[agent_i][0]][self.agent_prev_pos[agent_i][1]] = PRE_IDS['empty'] 182 | self._full_obs[self.agent_pos[agent_i][0]][self.agent_pos[agent_i][1]] = PRE_IDS['agent'] + str(agent_i + 1) 183 | 184 | def __update_opp_view(self, opp_i): 185 | self._full_obs[self.opp_prev_pos[opp_i][0]][self.opp_prev_pos[opp_i][1]] = PRE_IDS['empty'] 186 | self._full_obs[self.opp_pos[opp_i][0]][self.opp_pos[opp_i][1]] = PRE_IDS['opponent'] + str(opp_i + 1) 187 | 188 | def __init_full_obs(self): 189 | """ Each team consists of m = 5 agents and their initial positions are sampled uniformly in a 5 × 5 190 | square around the team center, which is picked uniformly in the grid. 191 | """ 192 | self._full_obs = self.__create_grid() 193 | 194 | # select agent team center 195 | # Note : Leaving space from edges so as to have a 5x5 grid around it 196 | agent_team_center = self.np_random.randint(2, self._grid_shape[0] - 3), self.np_random.randint(2, 197 | self._grid_shape[ 198 | 1] - 3) 199 | # randomly select agent pos 200 | for agent_i in range(self.n_agents): 201 | while True: 202 | pos = [self.np_random.randint(agent_team_center[0] - 2, agent_team_center[0] + 2), 203 | self.np_random.randint(agent_team_center[1] - 2, agent_team_center[1] + 2)] 204 | if self._full_obs[pos[0]][pos[1]] == PRE_IDS['empty']: 205 | self.agent_prev_pos[agent_i] = pos 206 | self.agent_pos[agent_i] = pos 207 | self.__update_agent_view(agent_i) 208 | break 209 | 210 | # select opponent team center 211 | while True: 212 | pos = self.np_random.randint(2, self._grid_shape[0] - 3), self.np_random.randint(2, self._grid_shape[1] - 3) 213 | if self._full_obs[pos[0]][pos[1]] == PRE_IDS['empty']: 214 | opp_team_center = pos 215 | break 216 | 217 | # randomly select opponent pos 218 | for opp_i in range(self._n_opponents): 219 | while True: 220 | pos = [self.np_random.randint(opp_team_center[0] - 2, opp_team_center[0] + 2), 221 | self.np_random.randint(opp_team_center[1] - 2, opp_team_center[1] + 2)] 222 | if self._full_obs[pos[0]][pos[1]] == PRE_IDS['empty']: 223 | self.opp_prev_pos[opp_i] = pos 224 | self.opp_pos[opp_i] = pos 225 | self.__update_opp_view(opp_i) 226 | break 227 | 228 | self.__draw_base_img() 229 | 230 | def reset(self): 231 | self._step_count = 0 232 | self._steps_beyond_done = None 233 | self._total_episode_reward = [0 for _ in range(self.n_agents)] 234 | self.agent_health = {_: self._init_health for _ in range(self.n_agents)} 235 | self.opp_health = {_: self._init_health for _ in range(self._n_opponents)} 236 | self._agent_cool = {_: True for _ in range(self.n_agents)} 237 | self._agent_cool_step = {_: 0 for _ in range(self.n_agents)} 238 | self._opp_cool = {_: True for _ in range(self._n_opponents)} 239 | self._opp_cool_step = {_: 0 for _ in range(self._n_opponents)} 240 | self._agent_dones = [False for _ in range(self.n_agents)] 241 | 242 | self.__init_full_obs() 243 | 244 | # For debug only 245 | self._agents_trace = {_: [self.agent_pos[_]] for _ in range(self.n_agents)} 246 | self._opponents_trace = {_: [self.opp_pos[_]] for _ in range(self._n_opponents)} 247 | 248 | return self.get_agent_obs() 249 | 250 | def render(self, mode='human'): 251 | assert 
(self._step_count is not None), \ 252 | "Call reset before using render method." 253 | 254 | img = copy.copy(self._base_img) 255 | 256 | # draw agents 257 | for agent_i in range(self.n_agents): 258 | if self.agent_health[agent_i] > 0: 259 | fill_cell(img, self.agent_pos[agent_i], cell_size=CELL_SIZE, fill=AGENT_COLOR) 260 | write_cell_text(img, text=str(agent_i + 1), pos=self.agent_pos[agent_i], cell_size=CELL_SIZE, 261 | fill='white', margin=0.3) 262 | 263 | # draw opponents 264 | for opp_i in range(self._n_opponents): 265 | if self.opp_health[opp_i] > 0: 266 | fill_cell(img, self.opp_pos[opp_i], cell_size=CELL_SIZE, fill=OPPONENT_COLOR) 267 | write_cell_text(img, text=str(opp_i + 1), pos=self.opp_pos[opp_i], cell_size=CELL_SIZE, 268 | fill='white', margin=0.3) 269 | 270 | img = np.asarray(img) 271 | 272 | if mode == 'rgb_array': 273 | return img 274 | elif mode == 'human': 275 | from gym.envs.classic_control import rendering 276 | if self.viewer is None: 277 | self.viewer = rendering.SimpleImageViewer() 278 | self.viewer.imshow(img) 279 | return self.viewer.isopen 280 | 281 | def __update_agent_pos(self, agent_i, move): 282 | 283 | curr_pos = copy.copy(self.agent_pos[agent_i]) 284 | next_pos = None 285 | if move == 0: # down 286 | next_pos = [curr_pos[0] + 1, curr_pos[1]] 287 | elif move == 1: # left 288 | next_pos = [curr_pos[0], curr_pos[1] - 1] 289 | elif move == 2: # up 290 | next_pos = [curr_pos[0] - 1, curr_pos[1]] 291 | elif move == 3: # right 292 | next_pos = [curr_pos[0], curr_pos[1] + 1] 293 | elif move == 4: # no-op 294 | pass 295 | else: 296 | raise Exception('Action Not found!') 297 | 298 | if next_pos is not None and self._is_cell_vacant(next_pos): 299 | self.agent_pos[agent_i] = next_pos 300 | self.agent_prev_pos[agent_i] = curr_pos 301 | self.__update_agent_view(agent_i) 302 | self._agents_trace[agent_i].append(next_pos) 303 | 304 | def __update_opp_pos(self, opp_i, move): 305 | 306 | curr_pos = copy.copy(self.opp_pos[opp_i]) 307 | next_pos = None 308 | if move == 0: # down 309 | next_pos = [curr_pos[0] + 1, curr_pos[1]] 310 | elif move == 1: # left 311 | next_pos = [curr_pos[0], curr_pos[1] - 1] 312 | elif move == 2: # up 313 | next_pos = [curr_pos[0] - 1, curr_pos[1]] 314 | elif move == 3: # right 315 | next_pos = [curr_pos[0], curr_pos[1] + 1] 316 | elif move == 4: # no-op 317 | pass 318 | else: 319 | raise Exception('Action Not found!') 320 | 321 | if next_pos is not None and self._is_cell_vacant(next_pos): 322 | self.opp_pos[opp_i] = next_pos 323 | self.opp_prev_pos[opp_i] = curr_pos 324 | self.__update_opp_view(opp_i) 325 | self._opponents_trace[opp_i].append(next_pos) 326 | 327 | def is_valid(self, pos): 328 | return (0 <= pos[0] < self._grid_shape[0]) and (0 <= pos[1] < self._grid_shape[1]) 329 | 330 | def _is_cell_vacant(self, pos): 331 | return self.is_valid(pos) and (self._full_obs[pos[0]][pos[1]] == PRE_IDS['empty']) 332 | 333 | @staticmethod 334 | def is_visible(source_pos, target_pos): 335 | """ 336 | Checks if the target_pos is in the visible range(5x5) of the source pos 337 | 338 | :param source_pos: Coordinates of the source 339 | :param target_pos: Coordinates of the target 340 | :return: 341 | """ 342 | return (source_pos[0] - 2) <= target_pos[0] <= (source_pos[0] + 2) \ 343 | and (source_pos[1] - 2) <= target_pos[1] <= (source_pos[1] + 2) 344 | 345 | @staticmethod 346 | def is_fireable(source_cooling_down, source_pos, target_pos): 347 | """ 348 | Checks if the target_pos is in the firing range(5x5) 349 | 350 | :param source_pos: Coordinates of the 
source 351 | :param target_pos: Coordinates of the target 352 | :return: 353 | """ 354 | return source_cooling_down and (source_pos[0] - 1) <= target_pos[0] <= (source_pos[0] + 1) \ 355 | and (source_pos[1] - 1) <= target_pos[1] <= (source_pos[1] + 1) 356 | 357 | def reduce_distance_move(self, opp_i, source_pos, agent_i, target_pos): 358 | # Todo: makes moves Enum 359 | _moves = [] 360 | if source_pos[0] > target_pos[0]: 361 | _moves.append('UP') 362 | elif source_pos[0] < target_pos[0]: 363 | _moves.append('DOWN') 364 | 365 | if source_pos[1] > target_pos[1]: 366 | _moves.append('LEFT') 367 | elif source_pos[1] < target_pos[1]: 368 | _moves.append('RIGHT') 369 | 370 | if len(_moves) == 0: 371 | print(self._step_count, source_pos, target_pos) 372 | print("agent-{}, hp={}, move_trace={}".format(agent_i, self.agent_health[agent_i], 373 | self._agents_trace[agent_i])) 374 | print( 375 | "opponent-{}, hp={}, move_trace={}".format(opp_i, self.opp_health[opp_i], self._opponents_trace[opp_i])) 376 | raise AssertionError("One place exists 2 entities!") 377 | move = self.np_random.choice(_moves) 378 | for k, v in ACTION_MEANING.items(): 379 | if move.lower() == v.lower(): 380 | move = k 381 | break 382 | return move 383 | 384 | @property 385 | def opps_action(self): 386 | """ 387 | Opponent bots follow a hardcoded policy. 388 | 389 | The bot policy is to attack the nearest enemy agent if it is within its firing range. If not, 390 | it approaches the nearest visible enemy agent within visual range. An agent is visible to all bots if it 391 | is inside the visual range of any individual bot. This shared vision gives an advantage to the bot team. 392 | 393 | :return: 394 | """ 395 | 396 | visible_agents = set([]) 397 | opp_agent_distance = {_: [] for _ in range(self._n_opponents)} 398 | 399 | for opp_i, opp_pos in self.opp_pos.items(): 400 | for agent_i, agent_pos in self.agent_pos.items(): 401 | if agent_i not in visible_agents and self.agent_health[agent_i] > 0 \ 402 | and self.is_visible(opp_pos, agent_pos): 403 | visible_agents.add(agent_i) 404 | distance = abs(agent_pos[0] - opp_pos[0]) + abs(agent_pos[1] - opp_pos[1]) # manhattan distance 405 | opp_agent_distance[opp_i].append([distance, agent_i]) 406 | 407 | opp_action_n = [] 408 | for opp_i in range(self._n_opponents): 409 | action = None 410 | for _, agent_i in sorted(opp_agent_distance[opp_i]): 411 | if agent_i in visible_agents: 412 | if self.is_fireable(self._opp_cool[opp_i], self.opp_pos[opp_i], self.agent_pos[agent_i]): 413 | action = agent_i + 5 414 | elif self.opp_health[opp_i] > 0: 415 | action = self.reduce_distance_move(opp_i, self.opp_pos[opp_i], agent_i, self.agent_pos[agent_i]) 416 | break 417 | if action is None: 418 | if self.opp_health[opp_i] > 0: 419 | # logger.debug('No visible agent for enemy:{}'.format(opp_i)) 420 | action = self.np_random.choice(range(5)) 421 | else: 422 | action = 4 # dead opponent could only execute 'no-op' action. 423 | opp_action_n.append(action) 424 | return opp_action_n 425 | 426 | def step(self, agents_action): 427 | assert (self._step_count is not None), \ 428 | "Call reset before using step method." 429 | 430 | assert len(agents_action) == self.n_agents 431 | 432 | self._step_count += 1 433 | rewards = [self._step_cost for _ in range(self.n_agents)] 434 | 435 | # What's the confusion? 436 | # What if agents attack each other at the same time? Should both of them be effected? 437 | # Ans: I guess, yes 438 | # What if other agent moves before the attack is performed in the same time-step. 
439 | # Ans: May be, I can process all the attack actions before move directions to ensure attacks have their effect. 440 | 441 | # processing attacks 442 | agent_health, opp_health = copy.copy(self.agent_health), copy.copy(self.opp_health) 443 | for agent_i, action in enumerate(agents_action): 444 | if self.agent_health[agent_i] > 0: 445 | if action > 4: # attack actions 446 | target_opp = action - 5 447 | if self.is_fireable(self._agent_cool[agent_i], self.agent_pos[agent_i], self.opp_pos[target_opp]) \ 448 | and opp_health[target_opp] > 0: 449 | # Fire 450 | opp_health[target_opp] -= 1 451 | rewards[agent_i] += 1 452 | 453 | # Update agent cooling down 454 | self._agent_cool[agent_i] = False 455 | self._agent_cool_step[agent_i] = self._step_cool 456 | 457 | # Remove opp from the map 458 | if opp_health[target_opp] == 0: 459 | pos = self.opp_pos[target_opp] 460 | self._full_obs[pos[0]][pos[1]] = PRE_IDS['empty'] 461 | 462 | # Update agent cooling down 463 | self._agent_cool_step[agent_i] = max(self._agent_cool_step[agent_i] - 1, 0) 464 | if self._agent_cool_step[agent_i] == 0 and not self._agent_cool[agent_i]: 465 | self._agent_cool[agent_i] = True 466 | 467 | opp_action = self.opps_action 468 | for opp_i, action in enumerate(opp_action): 469 | if self.opp_health[opp_i] > 0: 470 | target_agent = action - 5 471 | if action > 4: # attack actions 472 | if self.is_fireable(self._opp_cool[opp_i], self.opp_pos[opp_i], self.agent_pos[target_agent]) \ 473 | and agent_health[target_agent] > 0: 474 | # Fire 475 | agent_health[target_agent] -= 1 476 | rewards[target_agent] -= 1 477 | 478 | # Update opp cooling down 479 | self._opp_cool[opp_i] = False 480 | self._opp_cool_step[opp_i] = self._step_cool 481 | 482 | # Remove agent from the map 483 | if agent_health[target_agent] == 0: 484 | pos = self.agent_pos[target_agent] 485 | self._full_obs[pos[0]][pos[1]] = PRE_IDS['empty'] 486 | # Update opp cooling down 487 | self._opp_cool_step[opp_i] = max(self._opp_cool_step[opp_i] - 1, 0) 488 | if self._opp_cool_step[opp_i] == 0 and not self._opp_cool[opp_i]: 489 | self._opp_cool[opp_i] = True 490 | 491 | self.agent_health, self.opp_health = agent_health, opp_health 492 | 493 | # process move actions 494 | for agent_i, action in enumerate(agents_action): 495 | if self.agent_health[agent_i] > 0: 496 | if action <= 4: 497 | self.__update_agent_pos(agent_i, action) 498 | 499 | for opp_i, action in enumerate(opp_action): 500 | if self.opp_health[opp_i] > 0: 501 | if action <= 4: 502 | self.__update_opp_pos(opp_i, action) 503 | 504 | # step overflow or all opponents dead 505 | if (self._step_count >= self._max_steps) \ 506 | or (sum([v for k, v in self.opp_health.items()]) == 0) \ 507 | or (sum([v for k, v in self.agent_health.items()]) == 0): 508 | self._agent_dones = [True for _ in range(self.n_agents)] 509 | 510 | for i in range(self.n_agents): 511 | self._total_episode_reward[i] += rewards[i] 512 | 513 | # Check for episode overflow 514 | if all(self._agent_dones): 515 | if self._steps_beyond_done is None: 516 | self._steps_beyond_done = 0 517 | else: 518 | if self._steps_beyond_done == 0: 519 | logger.warn( 520 | "You are calling 'step()' even though this " 521 | "environment has already returned done = True. You " 522 | "should always call 'reset()' once you receive " 523 | "'done = True' -- any further steps are undefined " 524 | "behavior." 
525 | ) 526 | self._steps_beyond_done += 1 527 | 528 | return self.get_agent_obs(), rewards, self._agent_dones, {'health': self.agent_health} 529 | 530 | def seed(self, n=None): 531 | self.np_random, seed = seeding.np_random(n) 532 | return [seed] 533 | 534 | def close(self): 535 | if self.viewer is not None: 536 | self.viewer.close() 537 | self.viewer = None 538 | 539 | 540 | CELL_SIZE = 15 541 | 542 | WALL_COLOR = 'black' 543 | AGENT_COLOR = 'red' 544 | OPPONENT_COLOR = 'blue' 545 | 546 | ACTION_MEANING = { 547 | 0: "DOWN", 548 | 1: "LEFT", 549 | 2: "UP", 550 | 3: "RIGHT", 551 | 4: "NOOP", 552 | } 553 | 554 | PRE_IDS = { 555 | 'wall': 'W', 556 | 'empty': '0', 557 | 'agent': 'A', 558 | 'opponent': 'X', 559 | } 560 | -------------------------------------------------------------------------------- /ma_gym/envs/lumberjacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .lumberjacks import Lumberjacks 2 | -------------------------------------------------------------------------------- /ma_gym/envs/lumberjacks/lumberjacks.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import itertools 3 | import logging 4 | from typing import List, Tuple, Union 5 | 6 | import gym 7 | import numpy as np 8 | from PIL import ImageColor 9 | from gym import spaces 10 | from gym.utils import seeding 11 | 12 | from ..utils.action_space import MultiAgentActionSpace 13 | from ..utils.draw import draw_circle, draw_grid, fill_cell, write_cell_text 14 | from ..utils.observation_space import MultiAgentObservationSpace 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | Coordinates = Tuple[int, int] 19 | 20 | 21 | class Agent: 22 | """Dataclass keeping all data for one agent/lumberjack in the environment. 23 | To keep support for Python 3.6 we are not using the `dataclasses` module. 24 | 25 | Attributes: 26 | id: unique id in one environment run 27 | pos: position of the agent in the grid 28 | """ 29 | 30 | def __init__(self, id: int, pos: Coordinates): 31 | self.id = id 32 | self.pos = pos 33 | 34 | 35 | class Lumberjacks(gym.Env): 36 | """ 37 | The Lumberjacks environment involves a grid world in which multiple lumberjacks attempt to cut down all trees. To cut down a tree in a given cell, the number of agents/lumberjacks in that cell must be greater than or equal to the tree's strength; the tree is then cut down automatically. 38 | 39 | Agents select one of five actions ∈ {No-Op, Down, Left, Up, Right}. 40 | Each agent's observation includes its: 41 | - agent ID (1) 42 | - position within the grid (2) 43 | - number of steps since the beginning of the episode (1) 44 | - number of agents and tree strength for each cell in the agent view (2 * `np.prod(tuple(2 * v + 1 for v in agent_view))`). 45 | All values are scaled down into the range [0, 1]. 46 | 47 | Only the agents who are involved in cutting down the tree are rewarded with `tree_cutdown_reward`. 48 | The environment terminates as soon as all trees are cut down or when the number of steps reaches `max_steps`. 49 | 50 | Upon rendering, we show the grid, where each cell shows the agents (blue) and the tree (green) with their current strength.
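A minimal interaction sketch (illustrative only; it assumes the package is installed, that the environment is registered with gym under the `Lumberjacks-v0` id, and that `MultiAgentActionSpace.sample()` returns one action per agent):

    import gym
    import ma_gym  # noqa: F401 -- importing ma_gym registers its environments with gym

    env = gym.make('Lumberjacks-v0')
    obs_n = env.reset()                  # list of observations, one per agent
    done_n = [False] * env.n_agents
    while not all(done_n):
        # sample a joint action (one action per lumberjack) and step all agents at once
        obs_n, reward_n, done_n, info = env.step(env.action_space.sample())
    env.close()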
51 | 52 | Args: 53 | grid_shape: size of the grid 54 | n_agents: number of agents/lumberjacks 55 | n_trees: number of trees 56 | agent_view: size of the agent view range in each direction 57 | full_observable: flag indicating whether agents receive observations of all other agents 58 | step_cost: reward received in each time step 59 | tree_cutdown_reward: reward received by agents who cut down the tree 60 | max_steps: maximum steps in one environment episode 61 | 62 | Attributes: 63 | _agents: list of all agents. The index in this list is also the ID of the agent 64 | _agent_map: three-dimensional numpy array of indicators where the agents are located 65 | _tree_map: two-dimensional numpy array of strength of the trees 66 | _total_episode_reward: array with accumulated rewards for each agent. 67 | _agent_dones: list indicating whether each agent is done or not. 68 | _base_img: base image with grid 69 | _viewer: viewer for the rendered image 70 | """ 71 | metadata = {'render.modes': ['human', 'rgb_array']} 72 | 73 | def __init__(self, grid_shape: Coordinates = (5, 5), n_agents: int = 2, n_trees: int = 12, 74 | agent_view: Tuple[int, int] = (1, 1), full_observable: bool = False, 75 | step_cost: float = -1, tree_cutdown_reward: float = 10, max_steps: int = 100): 76 | assert 0 < n_agents 77 | assert n_agents + n_trees <= np.prod(grid_shape) 78 | assert 1 <= agent_view[0] <= grid_shape[0] and 1 <= agent_view[1] <= grid_shape[1] 79 | 80 | self._grid_shape = grid_shape 81 | self.n_agents = n_agents 82 | self._n_trees = n_trees 83 | self._agent_view = agent_view 84 | self.full_observable = full_observable 85 | self._step_cost = step_cost 86 | self._step_count = None 87 | self._tree_cutdown_reward = tree_cutdown_reward 88 | self._max_steps = max_steps 89 | self.steps_beyond_done = 0 90 | self.seed() 91 | 92 | self._agents = [] # List[Agent] 93 | # In order to speed up the environment we take advantage of vectorized operations. 94 | # Therefore we need to pad the grid size by the maximum agent_view size. 95 | # Relative coordinates refer to the coordinates in the non-padded grid. These are the only 96 | # coordinates visible to the user. Extended coordinates refer to the coordinates in the padded grid. 97 | self.__init_pos = None 98 | self._agent_map = None 99 | self._tree_map = None 100 | self._total_episode_reward = None 101 | self._agent_dones = None 102 | 103 | mask_size = np.prod(tuple(2 * v + 1 for v in self._agent_view)) 104 | # Agent ID (1) + Pos (2) + Step (1) + Neighborhood (2 * mask_size) 105 | self._obs_len = (1 + 2 + 1 + 2 * mask_size) 106 | obs_high = np.array([1.] * self._obs_len, dtype=np.float32) 107 | obs_low = np.array([0.] * self._obs_len, dtype=np.float32) 108 | if self.full_observable: 109 | obs_high = np.tile(obs_high, self.n_agents) 110 | obs_low = np.tile(obs_low, self.n_agents) 111 | self.action_space = MultiAgentActionSpace([spaces.Discrete(5)] * self.n_agents) 112 | self.observation_space = MultiAgentObservationSpace([spaces.Box(obs_low, obs_high)] * self.n_agents) 113 | 114 | self._base_img = draw_grid(self._grid_shape[0], self._grid_shape[1], cell_size=CELL_SIZE, fill='white') 115 | self._viewer = None 116 | 117 | def get_action_meanings(self, agent_id: int = None) -> Union[List[str], List[List[str]]]: 118 | """Returns the list of action meanings for `agent_id`. 119 | 120 | If `agent_id` is not specified, returns the meanings for all agents.
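For example (illustrative, following the `ACTIONS_IDS` mapping defined at the bottom of this module):

    >>> env.get_action_meanings(0)
    ['NOOP', 'DOWN', 'LEFT', 'UP', 'RIGHT']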
121 | """ 122 | if agent_id is not None: 123 | assert agent_id <= self.n_agents 124 | return [k.upper() for k, v in sorted(ACTIONS_IDS.items(), key=lambda item: item[1])] 125 | else: 126 | return [[k.upper() for k, v in sorted(ACTIONS_IDS.items(), key=lambda item: item[1])]] 127 | 128 | def reset(self) -> List[List[float]]: 129 | self._init_episode() 130 | self._step_count = 0 131 | self._total_episode_reward = np.zeros(self.n_agents) 132 | self._agent_dones = [False] * self.n_agents 133 | self.steps_beyond_done = 0 134 | 135 | return self.get_agent_obs() 136 | 137 | def _init_episode(self): 138 | """Initialize environment for new episode. 139 | 140 | Fills `self._agents`, self._agent_map` and `self._tree_map` with new values. 141 | """ 142 | init_positions = self._generate_init_pos() 143 | agent_id, tree_id = 0, self.n_agents 144 | self._agents = [] 145 | self._agent_map = np.zeros(( 146 | self._grid_shape[0] + 2 * (self._agent_view[0]), 147 | self._grid_shape[1] + 2 * (self._agent_view[1]), 148 | self.n_agents 149 | ), dtype=np.int32) 150 | self._tree_map = np.zeros(( 151 | self._grid_shape[0] + 2 * (self._agent_view[0]), 152 | self._grid_shape[1] + 2 * (self._agent_view[1]), 153 | ), dtype=np.int32) 154 | 155 | for pos, cell in np.ndenumerate(init_positions): 156 | pos = self._to_extended_coordinates(pos) 157 | if cell == PRE_IDS['agent']: 158 | self._agent_map[pos[0], pos[1], agent_id] = 1 159 | self._agents.append(Agent(agent_id, pos=pos)) 160 | agent_id += 1 161 | elif cell == PRE_IDS['tree']: 162 | self._tree_map[pos] = self.np_random.randint(1, self.n_agents + 1) 163 | tree_id += 1 164 | 165 | def _to_extended_coordinates(self, relative_coordinates): 166 | """Translate relative coordinates in to the extended coordinates.""" 167 | return relative_coordinates[0] + self._agent_view[0], relative_coordinates[1] + self._agent_view[1] 168 | 169 | def _to_relative_coordinates(self, extended_coordinates): 170 | """Translate extended coordinates in to the relative coordinates.""" 171 | return extended_coordinates[0] - self._agent_view[0], extended_coordinates[1] - self._agent_view[1] 172 | 173 | def _generate_init_pos(self) -> np.ndarray: 174 | """Returns randomly selected initial positions for agents and trees 175 | in relative coordinates. 176 | 177 | No agent or trees share the same cell in initial positions. 178 | """ 179 | init_pos = np.array( 180 | [PRE_IDS['agent']] * self.n_agents + 181 | [PRE_IDS['tree']] * self._n_trees + 182 | [PRE_IDS['empty']] * (np.prod(self._grid_shape) - self.n_agents - self._n_trees) 183 | ) 184 | 185 | # We ensure initial grid position is not same as last episode's 186 | # initial position. Though, just shuffling the array is sufficient, 187 | # we do validate and iterate to ensure change in position. If it still 188 | # remains same, we issue a warning and continue with the configuration. 189 | if self.__init_pos is not None: 190 | _shuffle_counter = 0 191 | self.np_random.shuffle(init_pos) 192 | while not all(self.__init_pos == init_pos): 193 | self.np_random.shuffle(init_pos) 194 | _shuffle_counter += 1 195 | if _shuffle_counter > 10: 196 | logger.warning("Grid configuration same as last episode") 197 | break 198 | self.__init_pos = init_pos 199 | return np.reshape(init_pos, self._grid_shape) 200 | 201 | def render(self, mode='human'): 202 | assert (self._step_count is not None), \ 203 | "Call reset before using render method." 
204 | 205 | img = copy.copy(self._base_img) 206 | 207 | mask = ( 208 | slice(self._agent_view[0], self._agent_view[0] + self._grid_shape[0]), 209 | slice(self._agent_view[1], self._agent_view[1] + self._grid_shape[1]), 210 | ) 211 | 212 | # Iterate over all grid positions 213 | for pos, agent_strength, tree_strength in self._view_generator(mask): 214 | if tree_strength and agent_strength: 215 | cell_size = (CELL_SIZE, CELL_SIZE / 2) 216 | tree_pos = (pos[0], 2 * pos[1]) 217 | agent_pos = (pos[0], 2 * pos[1] + 1) 218 | else: 219 | cell_size = (CELL_SIZE, CELL_SIZE) 220 | tree_pos = agent_pos = (pos[0], pos[1]) 221 | 222 | if tree_strength != 0: 223 | fill_cell(img, pos=tree_pos, cell_size=cell_size, fill=TREE_COLOR, margin=0.1) 224 | write_cell_text(img, text=str(tree_strength), pos=tree_pos, 225 | cell_size=cell_size, fill='white', margin=0.4) 226 | 227 | if agent_strength != 0: 228 | draw_circle(img, pos=agent_pos, cell_size=cell_size, fill=AGENT_COLOR, radius=0.30) 229 | write_cell_text(img, text=str(agent_strength), pos=agent_pos, 230 | cell_size=cell_size, fill='white', margin=0.4) 231 | 232 | img = np.asarray(img) 233 | if mode == 'rgb_array': 234 | return img 235 | elif mode == 'human': 236 | from gym.envs.classic_control import rendering 237 | if self._viewer is None: 238 | self._viewer = rendering.SimpleImageViewer() 239 | self._viewer.imshow(img) 240 | return self._viewer.isopen 241 | 242 | def _view_generator(self, mask: Tuple[slice, slice]) -> Tuple[Coordinates, int, int]: 243 | """Yields position, number of agent and tree strength for all cells 244 | defined by `mask`. 245 | 246 | Args: 247 | mask: tuple of slices in extended coordinates. 248 | """ 249 | agent_iter = np.ndenumerate(np.sum(self._agent_map[mask], axis=2)) 250 | tree_iter = np.nditer(self._tree_map[mask]) 251 | for (pos, n_a), n_t in zip(agent_iter, tree_iter): 252 | yield pos, n_a, n_t 253 | 254 | def get_agent_obs(self) -> List[List[float]]: 255 | """Returns list of observations for each agent.""" 256 | obs = np.zeros((self.n_agents, self._obs_len)) 257 | for i, (agent_id, agent) in enumerate(self._agent_generator()): 258 | rel_pos = self._to_relative_coordinates(agent.pos) 259 | obs[i, 0] = agent_id / self.n_agents # Agent ID 260 | obs[i, 1] = rel_pos[0] / (self._grid_shape[0] - 1) # Coordinate 261 | obs[i, 2] = rel_pos[1] / (self._grid_shape[1] - 1) # Coordinate 262 | obs[i, 3] = self._step_count / self._max_steps # Steps 263 | 264 | for j, (_, agent_strength, tree_strength) in zip( 265 | itertools.count(start=4, step=2), 266 | self._agent_view_generator(agent.pos, self._agent_view)): 267 | obs[i, j] = agent_strength / self.n_agents 268 | obs[i, j + 1] = tree_strength / self.n_agents 269 | 270 | # Convert it from numpy array 271 | obs = obs.tolist() 272 | 273 | if self.full_observable: 274 | obs = [feature for agent_obs in obs for feature in agent_obs] 275 | obs = [obs] * self.n_agents 276 | 277 | return obs 278 | 279 | def _agent_generator(self) -> Tuple[int, Agent]: 280 | """Yields agent_id and agent for all agents in environment.""" 281 | for agent_id, agent in enumerate(self._agents): 282 | yield agent_id, agent 283 | 284 | def _agent_view_generator(self, pos: Coordinates, view_range: Tuple[int, int]): 285 | """Yields position, number of agent and tree strength for cells in distance of `view_range` from `pos`. 
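For example (illustrative): with `pos=(3, 3)` and `view_range=(1, 1)` the mask spans rows 2-4 and columns 2-4 of the padded grid, so the nine cells of the surrounding 3 × 3 neighbourhood (including the agent's own cell) are yielded.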
""" 286 | mask = ( 287 | slice(pos[0] - view_range[0], pos[0] + view_range[0] + 1), 288 | slice(pos[1] - view_range[1], pos[1] + view_range[1] + 1), 289 | ) 290 | yield from self._view_generator(mask) 291 | 292 | def step(self, agents_action: List[int]): 293 | assert (self._step_count is not None), \ 294 | "Call reset before using step method." 295 | # Assert would slow down the environment which is undesirable. We rather expect the check on the user side. 296 | # assert len(agents_action) == self.n_agents 297 | 298 | # Following snippet of code was refereed from: 299 | # https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py#L124 300 | if all(self._agent_dones): 301 | if self.steps_beyond_done == 0: 302 | logger.warning( 303 | "You are calling 'step()' even though this environment has already returned all(dones) = True for " 304 | "all agents. You should always call 'reset()' once you receive 'all(dones) = True' -- any further" 305 | " steps are undefined behavior.") 306 | self.steps_beyond_done += 1 307 | return self.get_agent_obs(), [0] * self.n_agents, self._agent_dones, {} 308 | 309 | self._step_count += 1 310 | rewards = np.full(self.n_agents, self._step_cost) 311 | 312 | # Move agents 313 | for (agent_id, agent), action in zip(self._agent_generator(), agents_action): 314 | if not self._agent_dones[agent_id]: 315 | self._update_agent_pos(agent, action) 316 | 317 | # Cut down trees 318 | mask = (np.sum(self._agent_map, axis=2) >= self._tree_map) & (self._tree_map > 0) 319 | self._tree_map[mask] = 0 320 | 321 | # Calculate rewards 322 | rewards += np.sum(mask * self._tree_cutdown_reward, axis=(0, 1)) 323 | self._total_episode_reward += rewards 324 | 325 | if (self._step_count >= self._max_steps) or (np.count_nonzero(self._tree_map) == 0): 326 | self._agent_dones = [True] * self.n_agents 327 | 328 | return self.get_agent_obs(), rewards, self._agent_dones, {} 329 | 330 | def _update_agent_pos(self, agent: Agent, move: int): 331 | """Moves `agent` according the `move` command.""" 332 | next_pos = self._next_pos(agent.pos, move) 333 | 334 | # Remove agent from old position 335 | self._agent_map[agent.pos[0], agent.pos[1], agent.id] = 0 336 | 337 | # Add agent to the new position 338 | agent.pos = next_pos 339 | self._agent_map[next_pos[0], next_pos[1], agent.id] = 1 340 | 341 | def _next_pos(self, curr_pos: Coordinates, move: int) -> Coordinates: 342 | """Returns next valid position in extended coordinates given by `move` command relative to `curr_pos`.""" 343 | if move == ACTIONS_IDS['noop']: 344 | next_pos = curr_pos 345 | elif move == ACTIONS_IDS['down']: 346 | next_pos = (curr_pos[0] + 1, curr_pos[1]) 347 | elif move == ACTIONS_IDS['left']: 348 | next_pos = (curr_pos[0], curr_pos[1] - 1) 349 | elif move == ACTIONS_IDS['up']: 350 | next_pos = (curr_pos[0] - 1, curr_pos[1]) 351 | elif move == ACTIONS_IDS['right']: 352 | next_pos = (curr_pos[0], curr_pos[1] + 1) 353 | else: 354 | raise ValueError('Unknown action {}. 
Valid action are {}'.format(move, list(ACTIONS_IDS.values()))) 355 | # np.clip is significantly slower, see: https://github.com/numpy/numpy/issues/14281 356 | # return tuple(np.clip(next_pos, 357 | # (self._agent_view[0], self._agent_view[1]), 358 | # (self._agent_view[0] + self._grid_shape[0] - 1, 359 | # self._agent_view[1] + self._grid_shape[1] - 1), 360 | # )) 361 | return ( 362 | min(max(next_pos[0], self._agent_view[0]), self._grid_shape[0] - 1), 363 | min(max(next_pos[1], self._agent_view[1]), self._grid_shape[1] - 1), 364 | ) 365 | 366 | def seed(self, n: Union[None, int] = None): 367 | self.np_random, seed = seeding.np_random(n) 368 | return [seed] 369 | 370 | def close(self): 371 | if self._viewer is not None: 372 | self._viewer.close() 373 | self._viewer = None 374 | 375 | 376 | AGENT_COLOR = ImageColor.getcolor('blue', mode='RGB') 377 | TREE_COLOR = 'green' 378 | WALL_COLOR = 'black' 379 | 380 | CELL_SIZE = 35 381 | 382 | ACTIONS_IDS = { 383 | 'noop': 0, 384 | 'down': 1, 385 | 'left': 2, 386 | 'up': 3, 387 | 'right': 4, 388 | } 389 | 390 | PRE_IDS = { 391 | 'empty': 0, 392 | 'wall': 1, 393 | 'agent': 2, 394 | 'tree': 3, 395 | } 396 | -------------------------------------------------------------------------------- /ma_gym/envs/openai/__init__.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | from ..utils.action_space import MultiAgentActionSpace 4 | from ..utils.observation_space import MultiAgentObservationSpace 5 | 6 | 7 | class MultiAgentWrapper(gym.Wrapper): 8 | """ It's a multi agent wrapper over openai's single agent environments. """ 9 | 10 | def __init__(self, name): 11 | super().__init__(gym.make(name)) 12 | self.n_agents = 1 13 | self._step_count = None 14 | self._total_episode_reward = None 15 | self._agent_dones = [None for _ in range(self.n_agents)] 16 | 17 | self.action_space = MultiAgentActionSpace([self.env.action_space]) 18 | self.observation_space = MultiAgentObservationSpace([self.env.observation_space]) 19 | 20 | def step(self, action_n): 21 | assert (self._step_count is not None), \ 22 | "Call reset before using step method." 23 | 24 | self._step_count += 1 25 | assert len(action_n) == self.n_agents 26 | 27 | action = action_n[0] 28 | obs, reward, done, info = self.env.step(action) 29 | 30 | # Following is a hack: 31 | # If this is not done and there is a there max step overflow then the TimeLimit Wrapper handles it and 32 | # makes done = True rather than making it a list of boolean values. 
33 | # Nicer Options : Re-write Env Registry to have custom TimeLimit Wrapper for Multi agent envs 34 | # Or, we can simply pass a boolean value ourselves rather than a list 35 | if self.env._elapsed_steps == (self.env._max_episode_steps - 1): 36 | done = True 37 | 38 | self._total_episode_reward[0] += reward 39 | 40 | return [obs], [reward], [done], info 41 | 42 | def reset(self): 43 | self._step_count = 0 44 | self._total_episode_reward = [0 for _ in range(self.n_agents)] 45 | self._agent_dones = [False for _ in range(self.n_agents)] 46 | 47 | obs = self.env.reset() 48 | return [obs] 49 | -------------------------------------------------------------------------------- /ma_gym/envs/pong_duel/__init__.py: -------------------------------------------------------------------------------- 1 | from .pong_duel import PongDuel -------------------------------------------------------------------------------- /ma_gym/envs/pong_duel/pong_duel.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | 4 | import gym 5 | import numpy as np 6 | from gym import spaces 7 | from gym.utils import seeding 8 | 9 | from ..utils.action_space import MultiAgentActionSpace 10 | from ..utils.draw import draw_grid, fill_cell, draw_border 11 | from ..utils.observation_space import MultiAgentObservationSpace 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class PongDuel(gym.Env): 17 | """Two Player Pong Game - Competitive""" 18 | 19 | metadata = {'render.modes': ['human', 'rgb_array']} 20 | 21 | def __init__(self, step_cost=0, reward=1, max_rounds=10): 22 | self._grid_shape = (40, 30) 23 | self.n_agents = 2 24 | self.reward = reward 25 | self._max_rounds = max_rounds 26 | self.action_space = MultiAgentActionSpace([spaces.Discrete(3) for _ in range(self.n_agents)]) 27 | 28 | self._step_count = None 29 | self._steps_beyond_done = None 30 | self._step_cost = step_cost 31 | self._total_episode_reward = None 32 | self.agent_pos = {_: None for _ in range(self.n_agents)} 33 | self._agent_dones = None 34 | self.ball_pos = None 35 | self.__rounds = None 36 | 37 | # agent pos(2), ball pos (2), balldir (6-onehot) 38 | self._obs_low = np.array([0., 0., 0., 0.] + [0.] * len(BALL_DIRECTIONS), dtype=np.float32) 39 | self._obs_high = np.array([1., 1., 1., 1.] + [1.] 
* len(BALL_DIRECTIONS), dtype=np.float32) 40 | self.observation_space = MultiAgentObservationSpace([spaces.Box(self._obs_low, self._obs_high)for _ in range(self.n_agents)]) 41 | 42 | self.curr_ball_dir = None 43 | self.viewer = None 44 | self.seed() 45 | 46 | def get_action_meanings(self, agent_i=None): 47 | if agent_i is not None: 48 | assert agent_i <= self.n_agents 49 | return [ACTION_MEANING[i] for i in range(self.action_space[agent_i].n)] 50 | else: 51 | return [[ACTION_MEANING[i] for i in range(ac.n)] for ac in self.action_space] 52 | 53 | def __create_grid(self): 54 | _grid = [[PRE_IDS['empty'] for _ in range(self._grid_shape[1])] for row in range(self._grid_shape[0])] 55 | return _grid 56 | 57 | def __update_agent_view(self, agent_i): 58 | for row in range(self.agent_prev_pos[agent_i][0] - PADDLE_SIZE, 59 | self.agent_prev_pos[agent_i][0] + PADDLE_SIZE + 1): 60 | self._full_obs[row][self.agent_prev_pos[agent_i][1]] = PRE_IDS['empty'] 61 | 62 | for row in range(self.agent_pos[agent_i][0] - PADDLE_SIZE, self.agent_pos[agent_i][0] + PADDLE_SIZE + 1): 63 | self._full_obs[row][self.agent_pos[agent_i][1]] = PRE_IDS['agent'] + str(agent_i + 1) \ 64 | + '_' + str(row - self.agent_pos[agent_i][0]) 65 | 66 | def __update_ball_view(self): 67 | self._full_obs[self.ball_pos[0]][self.ball_pos[1]] = PRE_IDS['ball'] 68 | 69 | def __draw_base_img(self): 70 | self._base_img = draw_grid(self._grid_shape[0], self._grid_shape[1], 71 | cell_size=CELL_SIZE, fill='white', line_color='white') 72 | 73 | def __init_full_obs(self): 74 | self._full_obs = self.__create_grid() 75 | for agent_i in range(self.n_agents): 76 | self.__update_agent_view(agent_i) 77 | 78 | for agent_i in range(self.n_agents): 79 | self.__update_agent_view(agent_i) 80 | 81 | self.__update_ball_view() 82 | 83 | self.__draw_base_img() 84 | 85 | def get_agent_obs(self): 86 | _obs = [] 87 | 88 | for agent_i in range(self.n_agents): 89 | pos = self.agent_pos[agent_i] 90 | _agent_i_obs = [pos[0] / self._grid_shape[0], pos[1] / self._grid_shape[1]] 91 | 92 | pos = self.ball_pos 93 | _agent_i_obs += [pos[0] / self._grid_shape[0], pos[1] / self._grid_shape[1]] 94 | 95 | _ball_dir = [0 for _ in range(len(BALL_DIRECTIONS))] 96 | _ball_dir[BALL_DIRECTIONS.index(self.curr_ball_dir)] = 1 97 | 98 | _agent_i_obs += _ball_dir # one hot ball dir encoding 99 | 100 | _obs.append(_agent_i_obs) 101 | 102 | return _obs 103 | 104 | def __init_ball_pos(self): 105 | self.ball_pos = [self.np_random.randint(5, self._grid_shape[0] - 5), self.np_random.randint(10, self._grid_shape[1] - 10)] 106 | self.curr_ball_dir = self.np_random.choice(['NW', 'SW', 'SE', 'NE']) 107 | 108 | def reset(self): 109 | self.__rounds = 0 110 | self.agent_pos[0] = (self.np_random.randint(PADDLE_SIZE, self._grid_shape[0] - PADDLE_SIZE - 1), 1) 111 | self.agent_pos[1] = (self.np_random.randint(PADDLE_SIZE, self._grid_shape[0] - PADDLE_SIZE - 1), 112 | self._grid_shape[1] - 2) 113 | self.agent_prev_pos = {_: self.agent_pos[_] for _ in range(self.n_agents)} 114 | self.__init_ball_pos() 115 | self._agent_dones = [False, False] 116 | self.__init_full_obs() 117 | self._step_count = 0 118 | self._steps_beyond_done = None 119 | self._total_episode_reward = [0 for _ in range(self.n_agents)] 120 | 121 | return self.get_agent_obs() 122 | 123 | @property 124 | def __ball_cells(self): 125 | if self.curr_ball_dir == 'E': 126 | return [self.ball_pos, [self.ball_pos[0], self.ball_pos[1] - 1], [self.ball_pos[0], self.ball_pos[1] - 2]] 127 | if self.curr_ball_dir == 'W': 128 | return [self.ball_pos, 
[self.ball_pos[0], self.ball_pos[1] + 1], [self.ball_pos[0], self.ball_pos[1] + 2]] 129 | if self.curr_ball_dir == 'NE': 130 | return [self.ball_pos, [self.ball_pos[0] + 1, self.ball_pos[1] - 1], 131 | [self.ball_pos[0] + 2, self.ball_pos[1] - 2]] 132 | if self.curr_ball_dir == 'NW': 133 | return [self.ball_pos, [self.ball_pos[0] + 1, self.ball_pos[1] + 1], 134 | [self.ball_pos[0] + 2, self.ball_pos[1] + 2]] 135 | if self.curr_ball_dir == 'SE': 136 | return [self.ball_pos, [self.ball_pos[0] - 1, self.ball_pos[1] - 1], 137 | [self.ball_pos[0] - 2, self.ball_pos[1] - 2]] 138 | if self.curr_ball_dir == 'SW': 139 | return [self.ball_pos, [self.ball_pos[0] - 1, self.ball_pos[1] + 1], 140 | [self.ball_pos[0] - 2, self.ball_pos[1] + 2]] 141 | 142 | def render(self, mode='human'): 143 | assert (self._step_count is not None), \ 144 | "Call reset before using render method." 145 | 146 | img = copy.copy(self._base_img) 147 | for agent_i in range(self.n_agents): 148 | for row in range(self.agent_pos[agent_i][0] - 2, self.agent_pos[agent_i][0] + 3): 149 | fill_cell(img, (row, self.agent_pos[agent_i][1]), cell_size=CELL_SIZE, fill=AGENT_COLORS[agent_i]) 150 | 151 | ball_cells = self.__ball_cells 152 | fill_cell(img, ball_cells[0], cell_size=CELL_SIZE, fill=BALL_HEAD_COLOR) 153 | fill_cell(img, ball_cells[1], cell_size=CELL_SIZE, fill=BALL_TAIL_COLOR) 154 | fill_cell(img, ball_cells[2], cell_size=CELL_SIZE, fill=BALL_TAIL_COLOR) 155 | 156 | img = draw_border(img, border_width=2, fill='gray') 157 | 158 | img = np.asarray(img) 159 | if mode == 'rgb_array': 160 | return img 161 | elif mode == 'human': 162 | from gym.envs.classic_control import rendering 163 | if self.viewer is None: 164 | self.viewer = rendering.SimpleImageViewer() 165 | self.viewer.imshow(img) 166 | return self.viewer.isopen 167 | 168 | def __update_agent_pos(self, agent_i, move): 169 | 170 | curr_pos = copy.copy(self.agent_pos[agent_i]) 171 | if move == 0: # noop 172 | next_pos = None 173 | elif move == 1: # up 174 | next_pos = [curr_pos[0] - 1, curr_pos[1]] 175 | elif move == 2: # down 176 | next_pos = [curr_pos[0] + 1, curr_pos[1]] 177 | else: 178 | raise Exception('Action Not found!') 179 | 180 | if next_pos is not None and PADDLE_SIZE <= next_pos[0] <= (self._grid_shape[0] - PADDLE_SIZE - 1): 181 | self.agent_prev_pos[agent_i] = self.agent_pos[agent_i] 182 | self.agent_pos[agent_i] = next_pos 183 | self.__update_agent_view(agent_i) 184 | 185 | def __update_ball_pos(self): 186 | 187 | if self.ball_pos[0] <= 1: 188 | self.curr_ball_dir = 'SE' if self.curr_ball_dir == 'NE' else 'SW' 189 | elif self.ball_pos[0] >= (self._grid_shape[0] - 2): 190 | self.curr_ball_dir = 'NE' if self.curr_ball_dir == 'SE' else 'NW' 191 | elif PRE_IDS['agent'] in self._full_obs[self.ball_pos[0]][self.ball_pos[1] + 1]: 192 | edge = int(self._full_obs[self.ball_pos[0]][self.ball_pos[1] + 1].split('_')[1]) 193 | _dir = ['NW', 'W', 'SW'] 194 | if edge <= 0: 195 | _p = [0.25 + ((1 - 0.25) / PADDLE_SIZE * (abs(edge))), 196 | 0.5 - (0.5 / PADDLE_SIZE * (abs(edge))), 197 | 0.25 - (0.25 / PADDLE_SIZE * (abs(edge))), ] 198 | elif edge >= 0: 199 | _p = [0.25 - (0.25 / PADDLE_SIZE * (abs(edge))), 200 | 0.5 - (0.5 / PADDLE_SIZE * (abs(edge))), 201 | 0.25 + ((1 - 0.25) / PADDLE_SIZE * (abs(edge)))] 202 | _p[len(_dir) // 2] += 1 - sum(_p) 203 | 204 | self.curr_ball_dir = self.np_random.choice(_dir, p=_p) 205 | elif PRE_IDS['agent'] in self._full_obs[self.ball_pos[0]][self.ball_pos[1] - 1]: 206 | _dir = ['NE', 'E', 'SE'] 207 | edge = 
int(self._full_obs[self.ball_pos[0]][self.ball_pos[1] - 1].split('_')[1]) 208 | if edge <= 0: 209 | _p = [0.25 + ((1 - 0.25) / PADDLE_SIZE * (abs(edge))), 210 | 0.5 - (0.5 / PADDLE_SIZE * (abs(edge))), 211 | 0.25 - (0.25 / PADDLE_SIZE * (abs(edge))), ] 212 | elif edge >= 0: 213 | _p = [0.25 - (0.25 / PADDLE_SIZE * (abs(edge))), 214 | 0.5 - (0.5 / PADDLE_SIZE * (abs(edge))), 215 | 0.25 + ((1 - 0.25) / PADDLE_SIZE * (abs(edge)))] 216 | _p[len(_dir) // 2] += 1 - sum(_p) 217 | self.curr_ball_dir = self.np_random.choice(_dir, p=_p) 218 | 219 | if self.curr_ball_dir == 'E': 220 | new_ball_pos = self.ball_pos[0], self.ball_pos[1] + 1 221 | elif self.curr_ball_dir == 'W': 222 | new_ball_pos = self.ball_pos[0], self.ball_pos[1] - 1 223 | elif self.curr_ball_dir == 'NE': 224 | new_ball_pos = self.ball_pos[0] - 1, self.ball_pos[1] + 1 225 | elif self.curr_ball_dir == 'NW': 226 | new_ball_pos = self.ball_pos[0] - 1, self.ball_pos[1] - 1 227 | elif self.curr_ball_dir == 'SE': 228 | new_ball_pos = self.ball_pos[0] + 1, self.ball_pos[1] + 1 229 | elif self.curr_ball_dir == 'SW': 230 | new_ball_pos = self.ball_pos[0] + 1, self.ball_pos[1] - 1 231 | 232 | self.ball_pos = new_ball_pos 233 | 234 | def seed(self, n=None): 235 | self.np_random, seed = seeding.np_random(n) 236 | return [seed] 237 | 238 | def step(self, action_n): 239 | assert (self._step_count is not None), \ 240 | "Call reset before using step method." 241 | 242 | assert len(action_n) == self.n_agents 243 | self._step_count += 1 244 | rewards = [self._step_cost for _ in range(self.n_agents)] 245 | 246 | # if ball is beyond paddle, initiate a new round 247 | if self.ball_pos[1] < 1: 248 | rewards = [0, self.reward] 249 | self.__rounds += 1 250 | elif self.ball_pos[1] >= (self._grid_shape[1] - 1): 251 | rewards = [self.reward, 0] 252 | self.__rounds += 1 253 | 254 | if self.__rounds == self._max_rounds: 255 | self._agent_dones = [True for _ in range(self.n_agents)] 256 | else: 257 | for agent_i in range(self.n_agents): 258 | self.__update_agent_pos(agent_i, action_n[agent_i]) 259 | 260 | if (self.ball_pos[1] < 1) or (self.ball_pos[1] >= self._grid_shape[1] - 1): 261 | self.__init_ball_pos() 262 | else: 263 | self.__update_ball_pos() 264 | 265 | for i in range(self.n_agents): 266 | self._total_episode_reward[i] += rewards[i] 267 | 268 | # Check for episode overflow 269 | if all(self._agent_dones): 270 | if self._steps_beyond_done is None: 271 | self._steps_beyond_done = 0 272 | else: 273 | if self._steps_beyond_done == 0: 274 | logger.warn( 275 | "You are calling 'step()' even though this " 276 | "environment has already returned all(done) = True. You " 277 | "should always call 'reset()' once you receive " 278 | "'all(done) = True' -- any further steps are undefined " 279 | "behavior." 
280 | ) 281 | self._steps_beyond_done += 1 282 | 283 | return self.get_agent_obs(), rewards, self._agent_dones, {'rounds': self.__rounds} 284 | 285 | 286 | CELL_SIZE = 5 287 | 288 | ACTION_MEANING = { 289 | 0: "NOOP", 290 | 1: "UP", 291 | 2: "DOWN", 292 | } 293 | 294 | AGENT_COLORS = { 295 | 0: 'red', 296 | 1: 'blue' 297 | } 298 | WALL_COLOR = 'black' 299 | BALL_HEAD_COLOR = 'orange' 300 | BALL_TAIL_COLOR = 'yellow' 301 | 302 | # each pre-id should be unique and single char 303 | PRE_IDS = { 304 | 'agent': 'A', 305 | 'wall': 'W', 306 | 'ball': 'B', 307 | 'empty': 'O' 308 | } 309 | 310 | BALL_DIRECTIONS = ['NW', 'W', 'SW', 'SE', 'E', 'NE'] 311 | PADDLE_SIZE = 2 312 | -------------------------------------------------------------------------------- /ma_gym/envs/predator_prey/__init__.py: -------------------------------------------------------------------------------- 1 | from .predator_prey import PredatorPrey 2 | -------------------------------------------------------------------------------- /ma_gym/envs/predator_prey/predator_prey.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | 4 | import gym 5 | import numpy as np 6 | from PIL import ImageColor 7 | from gym import spaces 8 | from gym.utils import seeding 9 | 10 | from ..utils.action_space import MultiAgentActionSpace 11 | from ..utils.draw import draw_grid, fill_cell, draw_circle, write_cell_text 12 | from ..utils.observation_space import MultiAgentObservationSpace 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class PredatorPrey(gym.Env): 18 | """ 19 | Predator-prey involves a grid world, in which multiple predators attempt to capture randomly moving prey. 20 | Agents have a 5 × 5 view and select one of five actions ∈ {Left, Right, Up, Down, Stop} at each time step. 21 | Prey move according to selecting a uniformly random action at each time step. 22 | 23 | We define the “catching” of a prey as when the prey is within the cardinal direction of at least one predator. 24 | Each agent’s observation includes its own coordinates, agent ID, and the coordinates of the prey relative 25 | to itself, if observed. The agents can separate roles even if the parameters of the neural networks are 26 | shared by agent ID. We test with two different grid worlds: (i) a 5 × 5 grid world with two predators and one prey, 27 | and (ii) a 7 × 7 grid world with four predators and two prey. 28 | 29 | We modify the general predator-prey, such that a positive reward is given only if multiple predators catch a prey 30 | simultaneously, requiring a higher degree of cooperation. The predators get a team reward of 1 if two or more 31 | catch a prey at the same time, but they are given negative reward −P.We experimented with three varying P vales, 32 | where P = 0.5, 1.0, 1.5. 33 | 34 | The terminating condition of this task is when all preys are caught by more than one predator. 35 | For every new episodes , preys are initialized into random locations. 
Also, preys never move by themself into 36 | predator's neighbourhood 37 | """ 38 | metadata = {'render.modes': ['human', 'rgb_array']} 39 | 40 | def __init__(self, grid_shape=(5, 5), n_agents=2, n_preys=1, prey_move_probs=(0.175, 0.175, 0.175, 0.175, 0.3), 41 | full_observable=False, penalty=-0.5, step_cost=-0.01, prey_capture_reward=5, max_steps=100, 42 | agent_view_mask=(5, 5)): 43 | assert len(grid_shape) == 2, 'expected a tuple of size 2 for grid_shape, but found {}'.format(grid_shape) 44 | assert len(agent_view_mask) == 2, 'expected a tuple of size 2 for agent view mask,' \ 45 | ' but found {}'.format(agent_view_mask) 46 | assert grid_shape[0] > 0 and grid_shape[1] > 0, 'grid shape should be > 0' 47 | assert 0 < agent_view_mask[0] <= grid_shape[0], 'agent view mask has to be within (0,{}]'.format(grid_shape[0]) 48 | assert 0 < agent_view_mask[1] <= grid_shape[1], 'agent view mask has to be within (0,{}]'.format(grid_shape[1]) 49 | 50 | self._grid_shape = grid_shape 51 | self.n_agents = n_agents 52 | self.n_preys = n_preys 53 | self._max_steps = max_steps 54 | self._step_count = None 55 | self._steps_beyond_done = None 56 | self._penalty = penalty 57 | self._step_cost = step_cost 58 | self._prey_capture_reward = prey_capture_reward 59 | self._agent_view_mask = agent_view_mask 60 | 61 | self.action_space = MultiAgentActionSpace([spaces.Discrete(5) for _ in range(self.n_agents)]) 62 | self.agent_pos = {_: None for _ in range(self.n_agents)} 63 | self.prey_pos = {_: None for _ in range(self.n_preys)} 64 | self._prey_alive = None 65 | 66 | self._base_grid = self.__create_grid() # with no agents 67 | self._full_obs = self.__create_grid() 68 | self._agent_dones = [False for _ in range(self.n_agents)] 69 | self._prey_move_probs = prey_move_probs 70 | self.viewer = None 71 | self.full_observable = full_observable 72 | 73 | # agent pos (2), prey (25), step (1) 74 | mask_size = np.prod(self._agent_view_mask) 75 | self._obs_high = np.array([1., 1.] + [1.] * mask_size + [1.0], dtype=np.float32) 76 | self._obs_low = np.array([0., 0.] + [0.] 
* mask_size + [0.0], dtype=np.float32) 77 | if self.full_observable: 78 | self._obs_high = np.tile(self._obs_high, self.n_agents) 79 | self._obs_low = np.tile(self._obs_low, self.n_agents) 80 | self.observation_space = MultiAgentObservationSpace( 81 | [spaces.Box(self._obs_low, self._obs_high) for _ in range(self.n_agents)]) 82 | 83 | self._total_episode_reward = None 84 | self.seed() 85 | 86 | def get_action_meanings(self, agent_i=None): 87 | if agent_i is not None: 88 | assert agent_i <= self.n_agents 89 | return [ACTION_MEANING[i] for i in range(self.action_space[agent_i].n)] 90 | else: 91 | return [[ACTION_MEANING[i] for i in range(ac.n)] for ac in self.action_space] 92 | 93 | def action_space_sample(self): 94 | return [agent_action_space.sample() for agent_action_space in self.action_space] 95 | 96 | def __draw_base_img(self): 97 | self._base_img = draw_grid(self._grid_shape[0], self._grid_shape[1], cell_size=CELL_SIZE, fill='white') 98 | 99 | def __create_grid(self): 100 | _grid = [[PRE_IDS['empty'] for _ in range(self._grid_shape[1])] for row in range(self._grid_shape[0])] 101 | return _grid 102 | 103 | def __init_full_obs(self): 104 | self._full_obs = self.__create_grid() 105 | 106 | for agent_i in range(self.n_agents): 107 | while True: 108 | pos = [self.np_random.randint(0, self._grid_shape[0] - 1), 109 | self.np_random.randint(0, self._grid_shape[1] - 1)] 110 | if self._is_cell_vacant(pos): 111 | self.agent_pos[agent_i] = pos 112 | break 113 | self.__update_agent_view(agent_i) 114 | 115 | for prey_i in range(self.n_preys): 116 | while True: 117 | pos = [self.np_random.randint(0, self._grid_shape[0] - 1), 118 | self.np_random.randint(0, self._grid_shape[1] - 1)] 119 | if self._is_cell_vacant(pos) and (self._neighbour_agents(pos)[0] == 0): 120 | self.prey_pos[prey_i] = pos 121 | break 122 | self.__update_prey_view(prey_i) 123 | 124 | self.__draw_base_img() 125 | 126 | def get_agent_obs(self): 127 | _obs = [] 128 | for agent_i in range(self.n_agents): 129 | pos = self.agent_pos[agent_i] 130 | _agent_i_obs = [pos[0] / (self._grid_shape[0] - 1), pos[1] / (self._grid_shape[1] - 1)] # coordinates 131 | 132 | # check if prey is in the view area 133 | _prey_pos = np.zeros(self._agent_view_mask) # prey location in neighbour 134 | for row in range(max(0, pos[0] - 2), min(pos[0] + 2 + 1, self._grid_shape[0])): 135 | for col in range(max(0, pos[1] - 2), min(pos[1] + 2 + 1, self._grid_shape[1])): 136 | if PRE_IDS['prey'] in self._full_obs[row][col]: 137 | _prey_pos[row - (pos[0] - 2), col - (pos[1] - 2)] = 1 # get relative position for the prey loc. 
138 | 139 | _agent_i_obs += _prey_pos.flatten().tolist() # adding prey pos in observable area 140 | _agent_i_obs += [self._step_count / self._max_steps] # adding time 141 | _obs.append(_agent_i_obs) 142 | 143 | if self.full_observable: 144 | _obs = np.array(_obs).flatten().tolist() 145 | _obs = [_obs for _ in range(self.n_agents)] 146 | return _obs 147 | 148 | def reset(self): 149 | self._total_episode_reward = [0 for _ in range(self.n_agents)] 150 | self.agent_pos = {} 151 | self.prey_pos = {} 152 | 153 | self.__init_full_obs() 154 | self._step_count = 0 155 | self._steps_beyond_done = None 156 | self._agent_dones = [False for _ in range(self.n_agents)] 157 | self._prey_alive = [True for _ in range(self.n_preys)] 158 | 159 | return self.get_agent_obs() 160 | 161 | def __wall_exists(self, pos): 162 | row, col = pos 163 | return PRE_IDS['wall'] in self._base_grid[row, col] 164 | 165 | def is_valid(self, pos): 166 | return (0 <= pos[0] < self._grid_shape[0]) and (0 <= pos[1] < self._grid_shape[1]) 167 | 168 | def _is_cell_vacant(self, pos): 169 | return self.is_valid(pos) and (self._full_obs[pos[0]][pos[1]] == PRE_IDS['empty']) 170 | 171 | def __update_agent_pos(self, agent_i, move): 172 | 173 | curr_pos = copy.copy(self.agent_pos[agent_i]) 174 | next_pos = None 175 | if move == 0: # down 176 | next_pos = [curr_pos[0] + 1, curr_pos[1]] 177 | elif move == 1: # left 178 | next_pos = [curr_pos[0], curr_pos[1] - 1] 179 | elif move == 2: # up 180 | next_pos = [curr_pos[0] - 1, curr_pos[1]] 181 | elif move == 3: # right 182 | next_pos = [curr_pos[0], curr_pos[1] + 1] 183 | elif move == 4: # no-op 184 | pass 185 | else: 186 | raise Exception('Action Not found!') 187 | 188 | if next_pos is not None and self._is_cell_vacant(next_pos): 189 | self.agent_pos[agent_i] = next_pos 190 | self._full_obs[curr_pos[0]][curr_pos[1]] = PRE_IDS['empty'] 191 | self.__update_agent_view(agent_i) 192 | 193 | def __next_pos(self, curr_pos, move): 194 | if move == 0: # down 195 | next_pos = [curr_pos[0] + 1, curr_pos[1]] 196 | elif move == 1: # left 197 | next_pos = [curr_pos[0], curr_pos[1] - 1] 198 | elif move == 2: # up 199 | next_pos = [curr_pos[0] - 1, curr_pos[1]] 200 | elif move == 3: # right 201 | next_pos = [curr_pos[0], curr_pos[1] + 1] 202 | elif move == 4: # no-op 203 | next_pos = curr_pos 204 | return next_pos 205 | 206 | def __update_prey_pos(self, prey_i, move): 207 | curr_pos = copy.copy(self.prey_pos[prey_i]) 208 | if self._prey_alive[prey_i]: 209 | next_pos = None 210 | if move == 0: # down 211 | next_pos = [curr_pos[0] + 1, curr_pos[1]] 212 | elif move == 1: # left 213 | next_pos = [curr_pos[0], curr_pos[1] - 1] 214 | elif move == 2: # up 215 | next_pos = [curr_pos[0] - 1, curr_pos[1]] 216 | elif move == 3: # right 217 | next_pos = [curr_pos[0], curr_pos[1] + 1] 218 | elif move == 4: # no-op 219 | pass 220 | else: 221 | raise Exception('Action Not found!') 222 | 223 | if next_pos is not None and self._is_cell_vacant(next_pos): 224 | self.prey_pos[prey_i] = next_pos 225 | self._full_obs[curr_pos[0]][curr_pos[1]] = PRE_IDS['empty'] 226 | self.__update_prey_view(prey_i) 227 | else: 228 | # print('pos not updated') 229 | pass 230 | else: 231 | self._full_obs[curr_pos[0]][curr_pos[1]] = PRE_IDS['empty'] 232 | 233 | def __update_agent_view(self, agent_i): 234 | self._full_obs[self.agent_pos[agent_i][0]][self.agent_pos[agent_i][1]] = PRE_IDS['agent'] + str(agent_i + 1) 235 | 236 | def __update_prey_view(self, prey_i): 237 | self._full_obs[self.prey_pos[prey_i][0]][self.prey_pos[prey_i][1]] = PRE_IDS['prey'] 
+ str(prey_i + 1) 238 | 239 | def _neighbour_agents(self, pos): 240 | # check if agent is in neighbour 241 | _count = 0 242 | neighbours_xy = [] 243 | if self.is_valid([pos[0] + 1, pos[1]]) and PRE_IDS['agent'] in self._full_obs[pos[0] + 1][pos[1]]: 244 | _count += 1 245 | neighbours_xy.append([pos[0] + 1, pos[1]]) 246 | if self.is_valid([pos[0] - 1, pos[1]]) and PRE_IDS['agent'] in self._full_obs[pos[0] - 1][pos[1]]: 247 | _count += 1 248 | neighbours_xy.append([pos[0] - 1, pos[1]]) 249 | if self.is_valid([pos[0], pos[1] + 1]) and PRE_IDS['agent'] in self._full_obs[pos[0]][pos[1] + 1]: 250 | _count += 1 251 | neighbours_xy.append([pos[0], pos[1] + 1]) 252 | if self.is_valid([pos[0], pos[1] - 1]) and PRE_IDS['agent'] in self._full_obs[pos[0]][pos[1] - 1]: 253 | neighbours_xy.append([pos[0], pos[1] - 1]) 254 | _count += 1 255 | 256 | agent_id = [] 257 | for x, y in neighbours_xy: 258 | agent_id.append(int(self._full_obs[x][y].split(PRE_IDS['agent'])[1]) - 1) 259 | return _count, agent_id 260 | 261 | def step(self, agents_action): 262 | assert (self._step_count is not None), \ 263 | "Call reset before using step method." 264 | 265 | self._step_count += 1 266 | rewards = [self._step_cost for _ in range(self.n_agents)] 267 | 268 | for agent_i, action in enumerate(agents_action): 269 | if not (self._agent_dones[agent_i]): 270 | self.__update_agent_pos(agent_i, action) 271 | 272 | for prey_i in range(self.n_preys): 273 | if self._prey_alive[prey_i]: 274 | predator_neighbour_count, n_i = self._neighbour_agents(self.prey_pos[prey_i]) 275 | 276 | if predator_neighbour_count >= 1: 277 | _reward = self._penalty if predator_neighbour_count == 1 else self._prey_capture_reward 278 | self._prey_alive[prey_i] = (predator_neighbour_count == 1) 279 | 280 | for agent_i in range(self.n_agents): 281 | rewards[agent_i] += _reward 282 | 283 | prey_move = None 284 | if self._prey_alive[prey_i]: 285 | # 5 trails : we sample next move and check if prey (smart) doesn't go in neighbourhood of predator 286 | for _ in range(5): 287 | _move = self.np_random.choice(len(self._prey_move_probs), 1, p=self._prey_move_probs)[0] 288 | if self._neighbour_agents(self.__next_pos(self.prey_pos[prey_i], _move))[0] == 0: 289 | prey_move = _move 290 | break 291 | prey_move = 4 if prey_move is None else prey_move # default is no-op(4) 292 | 293 | self.__update_prey_pos(prey_i, prey_move) 294 | 295 | if (self._step_count >= self._max_steps) or (True not in self._prey_alive): 296 | for i in range(self.n_agents): 297 | self._agent_dones[i] = True 298 | 299 | for i in range(self.n_agents): 300 | self._total_episode_reward[i] += rewards[i] 301 | 302 | # Check for episode overflow 303 | if all(self._agent_dones): 304 | if self._steps_beyond_done is None: 305 | self._steps_beyond_done = 0 306 | else: 307 | if self._steps_beyond_done == 0: 308 | logger.warn( 309 | "You are calling 'step()' even though this " 310 | "environment has already returned all(done) = True. You " 311 | "should always call 'reset()' once you receive " 312 | "'all(done) = True' -- any further steps are undefined " 313 | "behavior." 
314 | ) 315 | self._steps_beyond_done += 1 316 | 317 | return self.get_agent_obs(), rewards, self._agent_dones, {'prey_alive': self._prey_alive} 318 | 319 | def __get_neighbour_coordinates(self, pos): 320 | neighbours = [] 321 | if self.is_valid([pos[0] + 1, pos[1]]): 322 | neighbours.append([pos[0] + 1, pos[1]]) 323 | if self.is_valid([pos[0] - 1, pos[1]]): 324 | neighbours.append([pos[0] - 1, pos[1]]) 325 | if self.is_valid([pos[0], pos[1] + 1]): 326 | neighbours.append([pos[0], pos[1] + 1]) 327 | if self.is_valid([pos[0], pos[1] - 1]): 328 | neighbours.append([pos[0], pos[1] - 1]) 329 | return neighbours 330 | 331 | def render(self, mode='human'): 332 | assert (self._step_count is not None), \ 333 | "Call reset before using render method." 334 | 335 | img = copy.copy(self._base_img) 336 | for agent_i in range(self.n_agents): 337 | for neighbour in self.__get_neighbour_coordinates(self.agent_pos[agent_i]): 338 | fill_cell(img, neighbour, cell_size=CELL_SIZE, fill=AGENT_NEIGHBORHOOD_COLOR, margin=0.1) 339 | fill_cell(img, self.agent_pos[agent_i], cell_size=CELL_SIZE, fill=AGENT_NEIGHBORHOOD_COLOR, margin=0.1) 340 | 341 | for agent_i in range(self.n_agents): 342 | draw_circle(img, self.agent_pos[agent_i], cell_size=CELL_SIZE, fill=AGENT_COLOR) 343 | write_cell_text(img, text=str(agent_i + 1), pos=self.agent_pos[agent_i], cell_size=CELL_SIZE, 344 | fill='white', margin=0.4) 345 | 346 | for prey_i in range(self.n_preys): 347 | if self._prey_alive[prey_i]: 348 | draw_circle(img, self.prey_pos[prey_i], cell_size=CELL_SIZE, fill=PREY_COLOR) 349 | write_cell_text(img, text=str(prey_i + 1), pos=self.prey_pos[prey_i], cell_size=CELL_SIZE, 350 | fill='white', margin=0.4) 351 | 352 | img = np.asarray(img) 353 | if mode == 'rgb_array': 354 | return img 355 | elif mode == 'human': 356 | from gym.envs.classic_control import rendering 357 | if self.viewer is None: 358 | self.viewer = rendering.SimpleImageViewer() 359 | self.viewer.imshow(img) 360 | return self.viewer.isopen 361 | 362 | def seed(self, n=None): 363 | self.np_random, seed = seeding.np_random(n) 364 | return [seed] 365 | 366 | def close(self): 367 | if self.viewer is not None: 368 | self.viewer.close() 369 | self.viewer = None 370 | 371 | 372 | AGENT_COLOR = ImageColor.getcolor('blue', mode='RGB') 373 | AGENT_NEIGHBORHOOD_COLOR = (186, 238, 247) 374 | PREY_COLOR = 'red' 375 | 376 | CELL_SIZE = 35 377 | 378 | WALL_COLOR = 'black' 379 | 380 | ACTION_MEANING = { 381 | 0: "DOWN", 382 | 1: "LEFT", 383 | 2: "UP", 384 | 3: "RIGHT", 385 | 4: "NOOP", 386 | } 387 | 388 | PRE_IDS = { 389 | 'agent': 'A', 390 | 'prey': 'P', 391 | 'wall': 'W', 392 | 'empty': '0' 393 | } 394 | -------------------------------------------------------------------------------- /ma_gym/envs/switch/__init__.py: -------------------------------------------------------------------------------- 1 | from .switch_one_corridor import Switch -------------------------------------------------------------------------------- /ma_gym/envs/switch/switch_one_corridor.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | 4 | import gym 5 | import numpy as np 6 | from gym import spaces 7 | from gym.utils import seeding 8 | 9 | from ..utils.action_space import MultiAgentActionSpace 10 | from ..utils.draw import draw_grid, fill_cell, draw_cell_outline, draw_circle, write_cell_text 11 | from ..utils.observation_space import MultiAgentObservationSpace 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class 
Switch(gym.Env): 17 | metadata = {'render.modes': ['human', 'rgb_array']} 18 | 19 | def __init__(self, full_observable: bool = False, step_cost: float = 0, n_agents: int = 4, max_steps: int = 50, 20 | clock: bool = True): 21 | assert 2 <= n_agents <= 4, 'Number of Agents has to be in range [2,4]' 22 | self._grid_shape = (3, 7) 23 | self.n_agents = n_agents 24 | self._max_steps = max_steps 25 | self._step_count = None 26 | self._step_cost = step_cost 27 | self._total_episode_reward = None 28 | self._add_clock = clock 29 | self._agent_dones = None 30 | 31 | self.action_space = MultiAgentActionSpace([spaces.Discrete(5) for _ in range(self.n_agents)]) # l,r,t,d,noop 32 | 33 | init_agent_pos = {0: [0, 1], 1: [0, self._grid_shape[1] - 2], 34 | 2: [2, 1], 3: [2, self._grid_shape[1] - 2]} 35 | final_agent_pos = {0: [0, self._grid_shape[1] - 1], 1: [0, 0], 36 | 2: [2, self._grid_shape[1] - 1], 3: [2, 0]} # they have to go in opposite direction 37 | 38 | self.init_agent_pos, self.final_agent_pos = {}, {} 39 | for agent_i in range(n_agents): 40 | self.init_agent_pos[agent_i] = init_agent_pos[agent_i] 41 | self.final_agent_pos[agent_i] = final_agent_pos[agent_i] 42 | 43 | self._base_grid = self.__create_grid() # with no agents 44 | self._full_obs = self.__create_grid() 45 | self.__init_full_obs() 46 | self.viewer = None 47 | 48 | self.full_observable = full_observable 49 | # agent pos (2) 50 | self._obs_high = np.ones(2 + (1 if self._add_clock else 0)) 51 | self._obs_low = np.zeros(2 + (1 if self._add_clock else 0)) 52 | if self.full_observable: 53 | self._obs_high = np.tile(self._obs_high, self.n_agents) 54 | self._obs_low = np.tile(self._obs_low, self.n_agents) 55 | self.observation_space = MultiAgentObservationSpace([spaces.Box(self._obs_low, self._obs_high) 56 | for _ in range(self.n_agents)]) 57 | self.seed() 58 | 59 | def get_action_meanings(self, agent_i=None): 60 | if agent_i is not None: 61 | assert agent_i <= self.n_agents 62 | return [ACTION_MEANING[i] for i in range(self.action_space[agent_i].n)] 63 | else: 64 | return [[ACTION_MEANING[i] for i in range(ac.n)] for ac in self.action_space] 65 | 66 | def __draw_base_img(self): 67 | self._base_img = draw_grid(self._grid_shape[0], self._grid_shape[1], cell_size=CELL_SIZE, fill='white') 68 | for row in range(self._grid_shape[0]): 69 | for col in range(self._grid_shape[1]): 70 | if self.__wall_exists((row, col)): 71 | fill_cell(self._base_img, (row, col), cell_size=CELL_SIZE, fill=WALL_COLOR) 72 | 73 | for agent_i, pos in list(self.final_agent_pos.items())[:self.n_agents]: 74 | row, col = pos[0], pos[1] 75 | draw_cell_outline(self._base_img, (row, col), cell_size=CELL_SIZE, fill=AGENT_COLORS[agent_i]) 76 | 77 | def __create_grid(self): 78 | _grid = -1 * np.ones(self._grid_shape) # all are walls 79 | _grid[self._grid_shape[0] // 2, :] = 0 # road in the middle 80 | _grid[:, [0, 1]] = 0 81 | _grid[:, [-1, -2]] = 0 82 | return _grid 83 | 84 | def __init_full_obs(self): 85 | self.agent_pos = copy.copy(self.init_agent_pos) 86 | self._full_obs = self.__create_grid() 87 | for agent_i, pos in self.agent_pos.items(): 88 | self.__update_agent_view(agent_i) 89 | self.__draw_base_img() 90 | 91 | def get_agent_obs(self): 92 | _obs = [] 93 | for agent_i in range(0, self.n_agents): 94 | pos = self.agent_pos[agent_i] 95 | _agent_i_obs = [round(pos[0] / (self._grid_shape[0] - 1), 2), 96 | round(pos[1] / (self._grid_shape[1] - 1), 2)] 97 | if self._add_clock: 98 | _agent_i_obs += [self._step_count / self._max_steps] # add current step count (for time reference) 99 
| _obs.append(_agent_i_obs) 100 | 101 | if self.full_observable: 102 | _obs = np.array(_obs).flatten().tolist() 103 | _obs = [_obs for _ in range(self.n_agents)] 104 | 105 | return _obs 106 | 107 | def reset(self): 108 | self.__init_full_obs() 109 | self._step_count = 0 110 | self._agent_dones = [False for _ in range(self.n_agents)] 111 | self._total_episode_reward = [0 for _ in range(self.n_agents)] 112 | return self.get_agent_obs() 113 | 114 | def __wall_exists(self, pos): 115 | row, col = pos 116 | return self._base_grid[row, col] == -1 117 | 118 | def _is_cell_vacant(self, pos): 119 | is_valid = (0 <= pos[0] < self._grid_shape[0]) and (0 <= pos[1] < self._grid_shape[1]) 120 | return is_valid and (self._full_obs[pos[0], pos[1]] == 0) 121 | 122 | def __update_agent_pos(self, agent_i, move): 123 | curr_pos = copy.copy(self.agent_pos[agent_i]) 124 | next_pos = None 125 | if move == 0: # down 126 | next_pos = [curr_pos[0] + 1, curr_pos[1]] 127 | elif move == 1: # left 128 | next_pos = [curr_pos[0], curr_pos[1] - 1] 129 | elif move == 2: # up 130 | next_pos = [curr_pos[0] - 1, curr_pos[1]] 131 | elif move == 3: # right 132 | next_pos = [curr_pos[0], curr_pos[1] + 1] 133 | elif move == 4: # no-op 134 | pass 135 | else: 136 | raise Exception('Action Not found!') 137 | 138 | if next_pos is not None and self._is_cell_vacant(next_pos): 139 | self.agent_pos[agent_i] = next_pos 140 | self._full_obs[curr_pos[0], curr_pos[1]] = 0 141 | self.__update_agent_view(agent_i) 142 | else: 143 | pass 144 | 145 | def __update_agent_view(self, agent_i): 146 | self._full_obs[self.agent_pos[agent_i][0], self.agent_pos[agent_i][1]] = agent_i + 1 147 | 148 | def __is_agent_done(self, agent_i): 149 | return self.agent_pos[agent_i] == self.final_agent_pos[agent_i] 150 | 151 | def step(self, agents_action): 152 | assert (self._step_count is not None), \ 153 | "Call reset before using step method." 154 | 155 | self._step_count += 1 156 | rewards = [self._step_cost for _ in range(self.n_agents)] 157 | for agent_i, action in enumerate(agents_action): 158 | if not (self._agent_dones[agent_i]): 159 | self.__update_agent_pos(agent_i, action) 160 | 161 | self._agent_dones[agent_i] = self.__is_agent_done(agent_i) 162 | if self._agent_dones[agent_i]: 163 | rewards[agent_i] = 5 164 | else: 165 | rewards[agent_i] = 0 166 | 167 | if self._step_count >= self._max_steps: 168 | for i in range(self.n_agents): 169 | self._agent_dones[i] = True 170 | 171 | for i in range(self.n_agents): 172 | self._total_episode_reward[i] += rewards[i] 173 | 174 | return self.get_agent_obs(), rewards, self._agent_dones, {} 175 | 176 | def render(self, mode='human'): 177 | assert (self._step_count is not None), \ 178 | "Call reset before using render method." 
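# ---------------------------------------------------------------------------
# Editor's note: illustrative usage sketch, not part of switch_one_corridor.py.
# It shows one way to drive this environment and consume render(mode='rgb_array')
# frames, mirroring the pattern in scripts/record_environment.py further below.
# The environment id 'ma_gym:Switch2-v0' is assumed from the test suite and the
# static/gif assets; other registered Switch variants would work the same way.
import gym       # the multi-agent envs keep the classic gym API, but with lists per agent
import imageio   # used only to persist the collected frames as a GIF

env = gym.make('ma_gym:Switch2-v0')
frames = []
done_n = [False] * env.n_agents
obs_n = env.reset()                                  # list: one observation per agent
while not all(done_n):
    frames.append(env.render(mode='rgb_array'))      # RGB frame as a numpy array
    obs_n, reward_n, done_n, info = env.step(env.action_space.sample())
env.close()
imageio.mimwrite('Switch2-v0.gif', frames, fps=10)
# ---------------------------------------------------------------------------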
179 | 180 | img = copy.copy(self._base_img) 181 | for agent_i in range(self.n_agents): 182 | draw_circle(img, self.agent_pos[agent_i], cell_size=CELL_SIZE, fill=AGENT_COLORS[agent_i], radius=0.3) 183 | write_cell_text(img, text=str(agent_i + 1), pos=self.agent_pos[agent_i], cell_size=CELL_SIZE, 184 | fill='white', margin=0.4) 185 | img = np.asarray(img) 186 | 187 | if mode == 'rgb_array': 188 | return img 189 | elif mode == 'human': 190 | from gym.envs.classic_control import rendering 191 | if self.viewer is None: 192 | self.viewer = rendering.SimpleImageViewer() 193 | self.viewer.imshow(img) 194 | return self.viewer.isopen 195 | 196 | def seed(self, n=None): 197 | self.np_random, seed = seeding.np_random(n) 198 | return [seed] 199 | 200 | def close(self): 201 | if self.viewer is not None: 202 | self.viewer.close() 203 | self.viewer = None 204 | 205 | 206 | AGENT_COLORS = { 207 | 0: 'red', 208 | 1: 'blue', 209 | 2: 'green', 210 | 3: 'orange' 211 | } 212 | 213 | CELL_SIZE = 30 214 | 215 | WALL_COLOR = 'black' 216 | 217 | ACTION_MEANING = { 218 | 0: "DOWN", 219 | 1: "LEFT", 220 | 2: "UP", 221 | 3: "RIGHT", 222 | 4: "NOOP", 223 | } 224 | -------------------------------------------------------------------------------- /ma_gym/envs/traffic_junction/__init__.py: -------------------------------------------------------------------------------- 1 | from .traffic_junction import TrafficJunction 2 | -------------------------------------------------------------------------------- /ma_gym/envs/traffic_junction/traffic_junction.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import copy 4 | import logging 5 | import random 6 | 7 | import gym 8 | import numpy as np 9 | from gym import spaces 10 | from gym.utils import seeding 11 | 12 | from ..utils.action_space import MultiAgentActionSpace 13 | from ..utils.draw import draw_grid, fill_cell, write_cell_text 14 | from ..utils.observation_space import MultiAgentObservationSpace 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class TrafficJunction(gym.Env): 20 | """ 21 | This consists of a 4-way junction on a 14 × 14 grid. At each time step, "new" cars enter the grid with 22 | probability `p_arrive` from each of the four directions. However, the total number of cars at any given 23 | time is limited to `Nmax`. 24 | 25 | Each car occupies a single cell at any given time and is randomly assigned to one of three possible routes 26 | (keeping to the right-hand side of the road). At every time step, a car has two possible actions: gas which advances 27 | it by one cell on its route or brake to stay at its current location. A car will be removed once it reaches its 28 | destination at the edge of the grid. 29 | 30 | Two cars collide if their locations overlap. A collision incurs a reward `rcoll = −10`, but does not affect 31 | the simulation in any other way. To discourage a traffic jam, each car gets reward of `τ * r_time = −0.01τ` 32 | at every time step, where `τ` is the number time steps passed since the car arrived. Therefore, the total 33 | reward at time t is 34 | 35 | r(t) = C^t * r_coll + \sum_{i=1}_{N^t} {\tau_i * r_time} 36 | 37 | where C^t is the number of collisions occurring at time t and N^t is number of cars present. The simulation is 38 | terminated after 'max_steps(default:40)' steps and is classified as a failure if one or more collisions have 39 | occurred. 
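For illustration (editor-added example; the step values are hypothetical, while r_coll and r_time match the constructor defaults collision_reward=-10 and step_cost=-0.01 below): in a step with one collision (C^t = 1) and two cars that have been on the road for τ = 3 and τ = 5 steps, r(t) = 1 * (-10) + (3 + 5) * (-0.01) = -10.08.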
40 | 41 | Each car is represented by one-hot binary vector set {n, l, r}, that encodes its unique ID, current location 42 | and assigned route number respectively. Each agent controlling a car can only observe other cars in its vision 43 | range (a surrounding 3 × 3 neighborhood), though low level communication is allowed in "v1" version of the game. 44 | 45 | The state vector s_j for each agent is thus a concatenation of all these vectors, having dimension 46 | (3^2) × (|n| + |l| + |r|). 47 | 48 | Reference : Learning Multi-agent Communication with Backpropagation 49 | Url : https://papers.nips.cc/paper/6398-learning-multiagent-communication-with-backpropagation.pdf 50 | 51 | 52 | For details on various versions, please refer to "wiki" 53 | (https://github.com/koulanurag/ma-gym/wiki/Environments#TrafficJunction) 54 | """ 55 | metadata = {'render.modes': ['human', 'rgb_array']} 56 | 57 | def __init__(self, grid_shape=(14, 14), step_cost=-0.01, n_max=4, collision_reward=-10, arrive_prob=0.5, 58 | full_observable: bool = False, max_steps: int = 100): 59 | assert 1 <= n_max <= 10, "n_max should be range in [1,10]" 60 | assert 0 <= arrive_prob <= 1, "arrive probability should be in range [0,1]" 61 | assert len(grid_shape) == 2, 'only 2-d grids are acceptable' 62 | assert 1 <= max_steps, "max_steps should be more than 1" 63 | 64 | self._grid_shape = grid_shape 65 | self.n_agents = n_max 66 | self._max_steps = max_steps 67 | self._step_count = None # environment step counter 68 | self._collision_reward = collision_reward 69 | self._total_episode_reward = None 70 | self._arrive_prob = arrive_prob 71 | self._n_max = n_max 72 | self._step_cost = step_cost 73 | self.curr_cars_count = 0 74 | self._n_routes = 3 75 | 76 | self._agent_view_mask = (3, 3) 77 | 78 | # entry gates where the cars spawn 79 | # Note: [(7, 0), (13, 7), (6, 13), (0, 6)] for (14 x 14) grid 80 | self._entry_gates = [(self._grid_shape[0] // 2, 0), 81 | (self._grid_shape[0] - 1, self._grid_shape[1] // 2), 82 | (self._grid_shape[0] // 2 - 1, self._grid_shape[1] - 1), 83 | (0, self._grid_shape[1] // 2 - 1)] 84 | 85 | # destination places for the cars to reach 86 | # Note: [(7, 13), (0, 7), (6, 0), (13, 6)] for (14 x 14) grid 87 | self._destination = [(self._grid_shape[0] // 2, self._grid_shape[1] - 1), 88 | (0, self._grid_shape[1] // 2), 89 | (self._grid_shape[0] // 2 - 1, 0), 90 | (self._grid_shape[0] - 1, self._grid_shape[1] // 2 - 1)] 91 | 92 | # dict{direction_vectors: (turn_right, turn_left)} 93 | # Note: [((7, 6), (7,7))), ((7, 7),(6,7)), ((6,6),(7, 6)), ((6, 7),(6,6))] for (14 x14) grid 94 | self._turning_places = {(0, 1): ((self._grid_shape[0] // 2, self._grid_shape[0] // 2 - 1), 95 | (self._grid_shape[0] // 2, self._grid_shape[0] // 2)), 96 | (-1, 0): ((self._grid_shape[0] // 2, self._grid_shape[0] // 2), 97 | (self._grid_shape[0] // 2 - 1, self._grid_shape[0] // 2)), 98 | (1, 0): ((self._grid_shape[0] // 2 - 1, self._grid_shape[0] // 2 - 1), 99 | (self._grid_shape[0] // 2, self._grid_shape[0] // 2 - 1)), 100 | (0, -1): ((self._grid_shape[0] // 2 - 1, self._grid_shape[0] // 2), 101 | (self._grid_shape[0] // 2 - 1, self._grid_shape[0] // 2 - 1))} 102 | 103 | # dict{starting_place: direction_vector} 104 | self._route_vectors = {(self._grid_shape[0] // 2, 0): (0, 1), 105 | (self._grid_shape[0] - 1, self._grid_shape[0] // 2): (-1, 0), 106 | (0, self._grid_shape[0] // 2 - 1): (1, 0), 107 | (self._grid_shape[0] // 2 - 1, self._grid_shape[0] - 1): (0, -1)} 108 | 109 | self._agent_turned = [False for _ in range(self.n_agents)] # flag 
if car changed direction 110 | self._agents_routes = [-1 for _ in range(self.n_agents)] # route each car is following atm 111 | self._agents_direction = [(0, 0) for _ in range(self.n_agents)] # cars are not on the road initially 112 | self._agent_step_count = [0 for _ in range(self.n_agents)] # holds a step counter for each car 113 | 114 | self.action_space = MultiAgentActionSpace([spaces.Discrete(2) for _ in range(self.n_agents)]) 115 | self.agent_pos = {_: None for _ in range(self.n_agents)} 116 | self._on_the_road = [False for _ in range(self.n_agents)] # flag if car is on the road 117 | 118 | self._full_obs = self.__create_grid() 119 | self._base_img = self.__draw_base_img() 120 | self._agent_dones = [None for _ in range(self.n_agents)] 121 | 122 | self.viewer = None 123 | self.full_observable = full_observable 124 | 125 | # agent id (n_agents, onehot), obs_mask (9), pos (2), route (3) 126 | mask_size = np.prod(self._agent_view_mask) 127 | self._obs_high = np.ones((mask_size * (self.n_agents + self._n_routes + 2))) # 2 is for location 128 | self._obs_low = np.zeros((mask_size * (self.n_agents + self._n_routes + 2))) # 2 is for location 129 | if self.full_observable: 130 | self._obs_high = np.tile(self._obs_high, self.n_agents) 131 | self._obs_low = np.tile(self._obs_low, self.n_agents) 132 | self.observation_space = MultiAgentObservationSpace([spaces.Box(self._obs_low, self._obs_high) 133 | for _ in range(self.n_agents)]) 134 | 135 | def action_space_sample(self): 136 | return [agent_action_space.sample() for agent_action_space in self.action_space] 137 | 138 | def __init_full_obs(self): 139 | """ 140 | Initiates environment: inserts up to |entry_gates| cars. once the entry gates are filled, the remaining agents 141 | stay initialized outside the road waiting to enter 142 | """ 143 | self._full_obs = self.__create_grid() 144 | 145 | shuffled_gates = list(self._route_vectors.keys()) 146 | random.shuffle(shuffled_gates) 147 | for agent_i in range(self.n_agents): 148 | if self.curr_cars_count >= len(self._entry_gates): 149 | self.agent_pos[agent_i] = (0, 0) # not yet on the road 150 | else: 151 | pos = shuffled_gates[agent_i] 152 | # gets direction vector for agent_i that spawned in position pos 153 | self._agents_direction[agent_i] = self._route_vectors[pos] 154 | self.agent_pos[agent_i] = pos 155 | self.curr_cars_count += 1 156 | self._on_the_road[agent_i] = True 157 | self._agents_routes[agent_i] = random.randint(1, self._n_routes) # [1,3] (inclusive) 158 | self.__update_agent_view(agent_i) 159 | 160 | self.__draw_base_img() 161 | 162 | def _is_cell_vacant(self, pos): 163 | return self.is_valid(pos) and (self._full_obs[pos[0]][pos[1]] == PRE_IDS['empty']) 164 | 165 | def is_valid(self, pos): 166 | return (0 <= pos[0] < self._grid_shape[0]) and (0 <= pos[1] < self._grid_shape[1]) 167 | 168 | def __update_agent_view(self, agent_i): 169 | self._full_obs[self.agent_pos[agent_i][0]][self.agent_pos[agent_i][1]] = PRE_IDS['agent'] + str(agent_i + 1) 170 | 171 | def __check_collision(self, pos): 172 | """ 173 | Verifies if a transition to the position pos will result on a collision. 
174 | :param pos: position to verify if there is collision 175 | :type pos: tuple 176 | 177 | :return: boolean stating true or false 178 | :rtype: bool 179 | """ 180 | return self.is_valid(pos) and (self._full_obs[pos[0]][pos[1]].find(PRE_IDS['agent']) > -1) 181 | 182 | def __is_gate_free(self): 183 | """ 184 | Verifies if any spawning gate is free for a car to be placed 185 | 186 | :return: list of currently free gates 187 | :rtype: list 188 | """ 189 | free_gates = [] 190 | for pos in self._entry_gates: 191 | if pos not in self.agent_pos.values(): 192 | free_gates.append(pos) 193 | return free_gates 194 | 195 | def __reached_dest(self, agent_i): 196 | """ 197 | Verifies if the agent_i reached a destination place. 198 | :param agent_i: id of the agent 199 | :type agent_i: int 200 | 201 | :return: boolean stating true or false 202 | :rtype: bool 203 | """ 204 | pos = self.agent_pos[agent_i] 205 | if pos in self._destination: 206 | self._full_obs[pos[0]][pos[1]] = PRE_IDS['empty'] 207 | return True 208 | return False 209 | 210 | def get_agent_obs(self): 211 | """ 212 | Computes the observations for the agents. Each agent receives information about cars in it's vision 213 | range (a surrounding 3 × 3 neighborhood),where each car is represented by one-hot binary vector set {n, l, r}, 214 | that encodes its unique ID, current location and assigned route number respectively. 215 | 216 | The state vector s_j for each agent is thus a concatenation of all these vectors, having dimension 217 | (3^2) × (|n| + |l| + |r|). 218 | 219 | :return: list with observations of all agents. the full list has shape (n_agents, (3^2) × (|n| + |l| + |r|)) 220 | :rtype: list 221 | """ 222 | agent_no_mask_obs = [] 223 | 224 | for agent_i in range(self.n_agents): 225 | pos = self.agent_pos[agent_i] 226 | 227 | # agent id 228 | _agent_i_obs = [0 for _ in range(self.n_agents)] 229 | _agent_i_obs[agent_i] = 1 230 | 231 | # location 232 | _agent_i_obs += [pos[0] / (self._grid_shape[0] - 1), pos[1] / (self._grid_shape[1] - 1)] # coordinates 233 | 234 | # route 235 | route_agent_i = np.zeros(self._n_routes) 236 | route_agent_i[self._agents_routes[agent_i] - 1] = 1 237 | 238 | _agent_i_obs += route_agent_i.tolist() 239 | 240 | agent_no_mask_obs.append(_agent_i_obs) 241 | 242 | agent_obs = [] 243 | for agent_i in range(self.n_agents): 244 | pos = self.agent_pos[agent_i] 245 | mask_view = np.zeros((*self._agent_view_mask, len(agent_no_mask_obs[0])), dtype=np.float32) 246 | for row in range(max(0, pos[0] - 1), min(pos[0] + 1 + 1, self._grid_shape[0])): 247 | for col in range(max(0, pos[1] - 1), min(pos[1] + 1 + 1, self._grid_shape[1])): 248 | if PRE_IDS['agent'] in self._full_obs[row][col]: 249 | _id = int(self._full_obs[row][col].split(PRE_IDS['agent'])[1]) - 1 250 | mask_view[row - (pos[0] - 1), col - (pos[1] - 1), :] = agent_no_mask_obs[_id] 251 | agent_obs.append(mask_view.flatten()) 252 | 253 | if self.full_observable: 254 | _obs = np.array(agent_obs).flatten().tolist() 255 | agent_obs = [_obs for _ in range(self.n_agents)] 256 | return agent_obs 257 | 258 | def __draw_base_img(self): 259 | # create grid and make everything black 260 | img = draw_grid(self._grid_shape[0], self._grid_shape[1], cell_size=CELL_SIZE, fill=WALL_COLOR) 261 | 262 | # draw tracks 263 | for i, row in enumerate(self._full_obs): 264 | for j, col in enumerate(row): 265 | if col == PRE_IDS['empty']: 266 | fill_cell(img, (i, j), cell_size=CELL_SIZE, fill=(143, 141, 136), margin=0.05) 267 | elif col == PRE_IDS['wall']: 268 | fill_cell(img, (i, j), 
cell_size=CELL_SIZE, fill=(242, 227, 167), margin=0.02) 269 | 270 | return img 271 | 272 | def __create_grid(self): 273 | # create a grid with every cell as wall 274 | _grid = [[PRE_IDS['wall'] for _ in range(self._grid_shape[1])] for _ in range(self._grid_shape[0])] 275 | 276 | # draw track by making cells empty : 277 | # horizontal tracks 278 | _grid[self._grid_shape[0] // 2 - 1] = [PRE_IDS['empty'] for _ in range(self._grid_shape[1])] 279 | _grid[self._grid_shape[0] // 2] = [PRE_IDS['empty'] for _ in range(self._grid_shape[1])] 280 | 281 | # vertical tracks 282 | for row in range(self._grid_shape[0]): 283 | _grid[row][self._grid_shape[1] // 2 - 1] = PRE_IDS['empty'] 284 | _grid[row][self._grid_shape[1] // 2] = PRE_IDS['empty'] 285 | 286 | return _grid 287 | 288 | def step(self, agents_action): 289 | """ 290 | Performs an action in the environment and steps forward. At each step a new agent enters the road by 291 | one of the 4 gates according to a probability "_arrive_prob". A "ncoll" reward is given to an agent if it 292 | collides and all of them receive "-0.01*step_n" to avoid traffic jams. 293 | 294 | :param agents_action: list of actions of all the agents to perform in the environment 295 | :type agents_action: list 296 | 297 | :return: agents observations, rewards, if agents are done and additional info 298 | :rtype: tuple 299 | """ 300 | assert len(agents_action) == self.n_agents, \ 301 | "Invalid action! It was expected to be list of {}" \ 302 | " dimension but was found to be of {}".format(self.n_agents, len(agents_action)) 303 | 304 | assert all([action_i in ACTION_MEANING.keys() for action_i in agents_action]), \ 305 | "Invalid action found in the list of sampled actions {}" \ 306 | ". Valid actions are {}".format(agents_action, ACTION_MEANING.keys()) 307 | 308 | self._step_count += 1 # global environment step 309 | rewards = [0 for _ in range(self.n_agents)] # initialize rewards array 310 | step_collisions = 0 # counts collisions in this step 311 | 312 | # checks if there is a collision; this is done in the __update_agent_pos method 313 | # we still need to check both agent_dones and on_the_road because an agent may not be done 314 | # and have not entered the road yet 315 | for agent_i, action in enumerate(agents_action): 316 | if not self._agent_dones[agent_i] and self._on_the_road[agent_i]: 317 | self._agent_step_count[agent_i] += 1 # agent step count 318 | collision_flag = self.__update_agent_pos(agent_i, action) 319 | if collision_flag: 320 | rewards[agent_i] += self._collision_reward 321 | step_collisions += 1 322 | 323 | # gives additional step punishment to avoid jams 324 | # at every time step, where `τ` is the number time steps passed since the car arrived. 325 | # We need to keep track of step_count of each car and that has to be multiplied. 326 | rewards[agent_i] += self._step_cost * self._agent_step_count[agent_i] 327 | self._total_episode_reward[agent_i] += rewards[agent_i] 328 | 329 | # checks if destination was reached 330 | # once a car reaches it's destination , it will never enter again in any of the tracks 331 | # Also, if all cars have reached their destination, then we terminate the episode. 
332 | if self.__reached_dest(agent_i): 333 | self._agent_dones[agent_i] = True 334 | self.curr_cars_count -= 1 335 | 336 | # if max_steps was reached, terminate the episode 337 | if self._step_count >= self._max_steps: 338 | self._agent_dones[agent_i] = True 339 | 340 | # adds new car according to the probability _arrive_prob 341 | if random.uniform(0, 1) < self._arrive_prob: 342 | free_gates = self.__is_gate_free() 343 | # if there are agents outside the road and if any gate is free 344 | if not all(self._on_the_road) and free_gates: 345 | # then gets first agent on the list which is not on the road 346 | agent_to_enter = self._on_the_road.index(False) 347 | pos = random.choice(free_gates) 348 | self._agents_direction[agent_to_enter] = self._route_vectors[pos] 349 | self.agent_pos[agent_to_enter] = pos 350 | self.curr_cars_count += 1 351 | self._on_the_road[agent_to_enter] = True 352 | self._agent_turned[agent_to_enter] = False 353 | self._agents_routes[agent_to_enter] = random.randint(1, self._n_routes) # (1, 3) 354 | self.__update_agent_view(agent_to_enter) 355 | 356 | return self.get_agent_obs(), rewards, self._agent_dones, {'step_collisions': step_collisions} 357 | 358 | def __get_next_direction(self, route, agent_i): 359 | """ 360 | Computes the new direction vector after the cars turn on the junction for route 2 (turn right) and 3 (turn left) 361 | :param route: route that was assigned to the car (1 - fwd, 2 - turn right, 3 - turn left) 362 | :type route: int 363 | 364 | :param agent_i: id of the agent 365 | :type agent_i: int 366 | 367 | :return: new direction vector following the assigned route 368 | :rtype: tuple 369 | """ 370 | # gets current direction vector 371 | dir_vector = self._agents_direction[agent_i] 372 | 373 | sig = (1 if dir_vector[1] != 0 else -1) if route == 2 else (-1 if dir_vector[1] != 0 else 1) 374 | new_dir_vector = (dir_vector[1] * sig, 0) if dir_vector[0] == 0 else (0, dir_vector[0] * sig) 375 | 376 | return new_dir_vector 377 | 378 | def __update_agent_pos(self, agent_i, move): 379 | """ 380 | Updates the agent position in the environment. Moves can be 0 (GAS) or 1 (BRAKE). If the move is 1 does nothing, 381 | car remains stopped. If the move is 0 then evaluate the route assigned. If the route is 1 (forward) then 382 | maintain the same direction vector. Otherwise, compute new direction vector and apply the change of direction 383 | when the junction turning place was reached. After the move is made, verifies if it resulted into a collision 384 | and returns the reward collision if that happens. The position is only updated if no collision occurred. 
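Editor-added illustration of the turn computed by __get_next_direction above: a car that entered from the left gate moves with direction vector (0, 1); for route 2 (turn right) sig evaluates to 1 and, since dir_vector[0] == 0, the new vector is (dir_vector[1] * sig, 0) = (1, 0), so the car then advances toward increasing row indices (downwards on the grid), which is the correct right turn for traffic keeping to the right-hand side of the road.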
385 | 386 | :param agent_i: id of the agent 387 | :type agent_i: int 388 | 389 | :param move: move picked by the agent_i 390 | :type move: int 391 | 392 | :return: bool flag associated to the existence or absence of a collision 393 | :rtype: bool 394 | """ 395 | 396 | curr_pos = copy.copy(self.agent_pos[agent_i]) 397 | next_pos = None 398 | route = self._agents_routes[agent_i] 399 | 400 | if move == 0: # GAS 401 | if route == 1: 402 | next_pos = tuple([curr_pos[i] + self._agents_direction[agent_i][i] for i in range(len(curr_pos))]) 403 | else: 404 | turn_pos = self._turning_places[self._agents_direction[agent_i]] 405 | # if the car reached the turning position in the junction for his route and starting gate 406 | if curr_pos == turn_pos[route - 2] and not self._agent_turned[agent_i]: 407 | new_dir_vector = self.__get_next_direction(route, agent_i) 408 | self._agents_direction[agent_i] = new_dir_vector 409 | self._agent_turned[agent_i] = True 410 | next_pos = tuple([curr_pos[i] + new_dir_vector[i] for i in range(len(curr_pos))]) 411 | else: 412 | next_pos = tuple([curr_pos[i] + self._agents_direction[agent_i][i] for i in range(len(curr_pos))]) 413 | elif move == 1: # BRAKE 414 | pass 415 | else: 416 | raise Exception('Action Not found!') 417 | 418 | # if there is a collision 419 | if next_pos is not None and self.__check_collision(next_pos): 420 | return True 421 | 422 | # if there is no collision and the next position is free updates agent position 423 | if next_pos is not None and self._is_cell_vacant(next_pos): 424 | self.agent_pos[agent_i] = next_pos 425 | self._full_obs[curr_pos[0]][curr_pos[1]] = PRE_IDS['empty'] 426 | self.__update_agent_view(agent_i) 427 | 428 | return False 429 | 430 | def reset(self): 431 | """ 432 | Resets the environment when a terminal state is reached. 
433 | 434 | :return: list with the observations of the agents 435 | :rtype: list 436 | """ 437 | self._total_episode_reward = [0 for _ in range(self.n_agents)] 438 | self._step_count = 0 439 | self._agent_step_count = [0 for _ in range(self.n_agents)] 440 | self._agent_dones = [False for _ in range(self.n_agents)] 441 | self._on_the_road = [False for _ in range(self.n_agents)] 442 | self._agent_turned = [False for _ in range(self.n_agents)] 443 | self.curr_cars_count = 0 444 | 445 | self.agent_pos = {} 446 | self.__init_full_obs() 447 | 448 | return self.get_agent_obs() 449 | 450 | def render(self, mode: str = 'human'): 451 | img = copy.copy(self._base_img) 452 | 453 | for agent_i in range(self.n_agents): 454 | if not self._agent_dones[agent_i] and self._on_the_road[agent_i]: 455 | fill_cell(img, self.agent_pos[agent_i], cell_size=CELL_SIZE, fill=AGENTS_COLORS[agent_i]) 456 | write_cell_text(img, text=str(agent_i + 1), pos=self.agent_pos[agent_i], cell_size=CELL_SIZE, 457 | fill='white', margin=0.3) 458 | 459 | img = np.asarray(img) 460 | if mode == 'rgb_array': 461 | return img 462 | elif mode == 'human': 463 | from gym.envs.classic_control import rendering 464 | if self.viewer is None: 465 | self.viewer = rendering.SimpleImageViewer() 466 | self.viewer.imshow(img) 467 | return self.viewer.isopen 468 | 469 | def seed(self, n: int): 470 | self.np_random, seed1 = seeding.np_random(n) 471 | seed2 = seeding.hash_seed(seed1 + 1) % 2 ** 31 472 | return [seed1, seed2] 473 | 474 | def close(self): 475 | if self.viewer is not None: 476 | self.viewer.close() 477 | self.viewer = None 478 | 479 | 480 | CELL_SIZE = 30 481 | 482 | WALL_COLOR = 'black' 483 | 484 | # fixed colors for #agents = n_max <= 10 485 | AGENTS_COLORS = [ 486 | "red", 487 | "blue", 488 | "yellow", 489 | "orange", 490 | "black", 491 | "green", 492 | "purple", 493 | "pink", 494 | "brown", 495 | "grey" 496 | ] 497 | 498 | ACTION_MEANING = { 499 | 0: "GAS", 500 | 1: "BRAKE", 501 | } 502 | 503 | PRE_IDS = { 504 | 'wall': 'W', 505 | 'empty': '0', 506 | 'agent': 'A' 507 | } 508 | -------------------------------------------------------------------------------- /ma_gym/envs/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koulanurag/ma-gym/1f0aa3d93f8c2e48b617f5cc04b01cda1f1a3943/ma_gym/envs/utils/__init__.py -------------------------------------------------------------------------------- /ma_gym/envs/utils/action_space.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | 4 | class MultiAgentActionSpace(list): 5 | def __init__(self, agents_action_space): 6 | for x in agents_action_space: 7 | assert isinstance(x, gym.spaces.space.Space) 8 | 9 | super(MultiAgentActionSpace, self).__init__(agents_action_space) 10 | self._agents_action_space = agents_action_space 11 | 12 | def sample(self): 13 | """ samples action for each agent from uniform distribution""" 14 | return [agent_action_space.sample() for agent_action_space in self._agents_action_space] 15 | -------------------------------------------------------------------------------- /ma_gym/envs/utils/draw.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from PIL import Image, ImageDraw 4 | 5 | 6 | def get_cell_sizes(cell_size: Union[int, list, tuple]): 7 | """Handle multiple type options of `cell_size`. 
8 | 9 | In order to keep the old API of following functions, as well as add 10 | support for non-square grids we need to check cell_size type and 11 | extend it appropriately. 12 | 13 | Args: 14 | cell_size: integer of tuple/list size of two with cell size 15 | in horizontal and vertical direction. 16 | 17 | Returns: 18 | Horizontal and vertical cell size. 19 | """ 20 | if isinstance(cell_size, int): 21 | cell_size_vertical = cell_size 22 | cell_size_horizontal = cell_size 23 | elif isinstance(cell_size, (tuple, list)) and len(cell_size) == 2: 24 | # Flipping coordinates, because first coordinates coresponds with height (=vertical direction) 25 | cell_size_vertical, cell_size_horizontal = cell_size 26 | else: 27 | raise TypeError("`cell_size` must be integer, tuple or list with length two.") 28 | 29 | return cell_size_horizontal, cell_size_vertical 30 | 31 | 32 | def draw_grid(rows, cols, cell_size=50, fill='black', line_color='black'): 33 | cell_size_x, cell_size_y = get_cell_sizes(cell_size) 34 | 35 | width = cols * cell_size_x 36 | height = rows * cell_size_y 37 | image = Image.new(mode='RGB', size=(width, height), color=fill) 38 | 39 | # Draw some lines 40 | draw = ImageDraw.Draw(image) 41 | y_start = 0 42 | y_end = image.height 43 | 44 | for x in range(0, image.width, cell_size_x): 45 | line = ((x, y_start), (x, y_end)) 46 | draw.line(line, fill=line_color) 47 | 48 | x = image.width - 1 49 | line = ((x, y_start), (x, y_end)) 50 | draw.line(line, fill=line_color) 51 | 52 | x_start = 0 53 | x_end = image.width 54 | 55 | for y in range(0, image.height, cell_size_y): 56 | line = ((x_start, y), (x_end, y)) 57 | draw.line(line, fill=line_color) 58 | 59 | y = image.height - 1 60 | line = ((x_start, y), (x_end, y)) 61 | draw.line(line, fill=line_color) 62 | 63 | del draw 64 | 65 | return image 66 | 67 | 68 | def fill_cell(image, pos, cell_size=None, fill='black', margin=0): 69 | assert cell_size is not None and 0 <= margin <= 1 70 | 71 | cell_size_x, cell_size_y = get_cell_sizes(cell_size) 72 | col, row = pos 73 | row, col = row * cell_size_x, col * cell_size_y 74 | margin_x, margin_y = margin * cell_size_x, margin * cell_size_y 75 | x, y, x_dash, y_dash = row + margin_x, col + margin_y, row + cell_size_x - margin_x, col + cell_size_y - margin_y 76 | ImageDraw.Draw(image).rectangle([(x, y), (x_dash, y_dash)], fill=fill) 77 | 78 | 79 | def write_cell_text(image, text, pos, cell_size=None, fill='black', margin=0): 80 | assert cell_size is not None and 0 <= margin <= 1 81 | 82 | cell_size_x, cell_size_y = get_cell_sizes(cell_size) 83 | col, row = pos 84 | row, col = row * cell_size_x, col * cell_size_y 85 | margin_x, margin_y = margin * cell_size_x, margin * cell_size_y 86 | x, y = row + margin_x, col + margin_y 87 | ImageDraw.Draw(image).text((x, y), text=text, fill=fill) 88 | 89 | 90 | def draw_cell_outline(image, pos, cell_size=50, fill='black'): 91 | cell_size_x, cell_size_y = get_cell_sizes(cell_size) 92 | col, row = pos 93 | row, col = row * cell_size_x, col * cell_size_y 94 | ImageDraw.Draw(image).rectangle([(row, col), (row + cell_size_x, col + cell_size_y)], outline=fill, width=3) 95 | 96 | 97 | def draw_circle(image, pos, cell_size=50, fill='black', radius=0.3): 98 | cell_size_x, cell_size_y = get_cell_sizes(cell_size) 99 | col, row = pos 100 | row, col = row * cell_size_x, col * cell_size_y 101 | gap_x, gap_y = cell_size_x * radius, cell_size_y * radius 102 | x, y = row + gap_x, col + gap_y 103 | x_dash, y_dash = row + cell_size_x - gap_x, col + cell_size_y - gap_y 104 | 
ImageDraw.Draw(image).ellipse([(x, y), (x_dash, y_dash)], outline=fill, fill=fill) 105 | 106 | 107 | def draw_border(image, border_width=1, fill='black'): 108 | width, height = image.size 109 | new_im = Image.new("RGB", size=(width + 2 * border_width, height + 2 * border_width), color=fill) 110 | new_im.paste(image, (border_width, border_width)) 111 | return new_im 112 | 113 | 114 | def draw_score_board(image, score, board_height=30): 115 | im_width, im_height = image.size 116 | new_im = Image.new("RGB", size=(im_width, im_height + board_height), color='#e1e4e8') 117 | new_im.paste(image, (0, board_height)) 118 | 119 | _text = ', '.join([str(round(x, 2)) for x in score]) 120 | ImageDraw.Draw(new_im).text((10, board_height // 3), text=_text, fill='black') 121 | return new_im 122 | -------------------------------------------------------------------------------- /ma_gym/envs/utils/observation_space.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | 4 | class MultiAgentObservationSpace(list): 5 | def __init__(self, agents_observation_space): 6 | for x in agents_observation_space: 7 | assert isinstance(x, gym.spaces.space.Space) 8 | 9 | super().__init__(agents_observation_space) 10 | self._agents_observation_space = agents_observation_space 11 | 12 | def sample(self): 13 | """ samples observations for each agent from uniform distribution""" 14 | return [agent_observation_space.sample() for agent_observation_space in self._agents_observation_space] 15 | 16 | def contains(self, obs): 17 | """ contains observation """ 18 | for space, ob in zip(self._agents_observation_space, obs): 19 | if not space.contains(ob): 20 | return False 21 | else: 22 | return True 23 | -------------------------------------------------------------------------------- /ma_gym/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from ma_gym.wrappers.monitor import Monitor -------------------------------------------------------------------------------- /ma_gym/wrappers/monitor.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import six 4 | from gym import error, logger 5 | from gym.utils import closer 6 | from gym.wrappers import Monitor as MO 7 | 8 | from ma_gym.wrappers.monitoring import stats_recorder 9 | 10 | FILE_PREFIX = 'openaigym' 11 | MANIFEST_PREFIX = FILE_PREFIX + '.manifest' 12 | 13 | 14 | class Monitor(MO): 15 | """ Multi Agent Monitor""" 16 | 17 | def __init__(self, *args, **kwargs): 18 | super().__init__(*args, **kwargs) 19 | self.n_agents = self.env.n_agents 20 | 21 | def _start(self, directory, video_callable=None, force=False, resume=False, 22 | write_upon_reset=False, uid=None, mode=None): 23 | """Start monitoring. 24 | Args: 25 | directory (str): A per-training run directory where to record stats. 26 | video_callable (Optional[function, False]): function that takes in the index of the episode and outputs a boolean, indicating whether we should record a video on this episode. The default (for video_callable is None) is to take perfect cubes, capped at 1000. False disables video recording. 27 | force (bool): Clear out existing training data from this directory (by deleting every file prefixed with "openaigym."). 28 | resume (bool): Retain the training data already in this directory, which will be merged with our new data 29 | write_upon_reset (bool): Write the manifest file on each reset. 
(This is currently a JSON file, so writing it is somewhat expensive.) 30 | uid (Optional[str]): A unique id used as part of the suffix for the file. By default, uses os.getpid(). 31 | mode (['evaluation', 'training']): Whether this is an evaluation or training episode. 32 | """ 33 | if self.env.spec is None: 34 | logger.warn( 35 | "Trying to monitor an environment which has no 'spec' set. " 36 | "This usually means you did not create it via 'gym.make', and is recommended only for advanced users.") 37 | env_id = '(unknown)' 38 | else: 39 | env_id = self.env.spec.id 40 | 41 | if not os.path.exists(directory): 42 | logger.info('Creating monitor directory %s', directory) 43 | if six.PY3: 44 | os.makedirs(directory, exist_ok=True) 45 | else: 46 | os.makedirs(directory) 47 | 48 | if video_callable is None: 49 | video_callable = capped_cubic_video_schedule 50 | elif video_callable == False: 51 | video_callable = disable_videos 52 | elif not callable(video_callable): 53 | raise error.Error('You must provide a function, None, or False for video_callable, not {}: {}'.format( 54 | type(video_callable), video_callable)) 55 | self.video_callable = video_callable 56 | 57 | # Check on whether we need to clear anything 58 | if force: 59 | clear_monitor_files(directory) 60 | elif not resume: 61 | training_manifests = detect_training_manifests(directory) 62 | if len(training_manifests) > 0: 63 | raise error.Error('''Trying to write to monitor directory {} with existing monitor files: {}. 64 | You should use a unique directory for each training run, or use 'force=True' 65 | to automatically clear previous monitor files.''' 66 | .format(directory, ', '.join(training_manifests[:5]))) 67 | 68 | self._monitor_id = monitor_closer.register(self) 69 | 70 | self.enabled = True 71 | self.directory = os.path.abspath(directory) 72 | # We use the 'openai-gym' prefix to determine if a file is 73 | # ours 74 | self.file_prefix = FILE_PREFIX 75 | self.file_infix = '{}.{}'.format(self._monitor_id, uid if uid else os.getpid()) 76 | 77 | self.stats_recorder = stats_recorder.StatsRecorder(directory, '{}.episode_batch.{}'.format(self.file_prefix, 78 | self.file_infix), 79 | autoreset=self.env_semantics_autoreset, env_id=env_id) 80 | 81 | if not os.path.exists(directory): os.mkdir(directory) 82 | self.write_upon_reset = write_upon_reset 83 | 84 | if mode is not None: 85 | self._set_mode(mode) 86 | 87 | 88 | def detect_training_manifests(training_dir, files=None): 89 | if files is None: 90 | files = os.listdir(training_dir) 91 | return [os.path.join(training_dir, f) for f in files if f.startswith(MANIFEST_PREFIX + '.')] 92 | 93 | 94 | def detect_monitor_files(training_dir): 95 | return [os.path.join(training_dir, f) for f in os.listdir(training_dir) if f.startswith(FILE_PREFIX + '.')] 96 | 97 | 98 | def clear_monitor_files(training_dir): 99 | files = detect_monitor_files(training_dir) 100 | if len(files) == 0: 101 | return 102 | 103 | logger.info('Clearing %d monitor files from previous run (because force=True was provided)', len(files)) 104 | for file in files: 105 | os.unlink(file) 106 | 107 | 108 | def capped_cubic_video_schedule(episode_id): 109 | if episode_id < 1000: 110 | return int(round(episode_id ** (1. 
/ 3))) ** 3 == episode_id 111 | else: 112 | return episode_id % 1000 == 0 113 | 114 | 115 | def disable_videos(episode_id): 116 | return False 117 | 118 | 119 | monitor_closer = closer.Closer() 120 | -------------------------------------------------------------------------------- /ma_gym/wrappers/monitoring/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koulanurag/ma-gym/1f0aa3d93f8c2e48b617f5cc04b01cda1f1a3943/ma_gym/wrappers/monitoring/__init__.py -------------------------------------------------------------------------------- /ma_gym/wrappers/monitoring/stats_recorder.py: -------------------------------------------------------------------------------- 1 | from gym.wrappers.monitoring.stats_recorder import StatsRecorder as SR 2 | 3 | 4 | class StatsRecorder(SR): 5 | def __init__(self, *args, **kwargs): 6 | super().__init__(*args, **kwargs) 7 | 8 | def after_step(self, observation, reward, done, info): 9 | super().after_step(observation, sum(reward), all(done), info) 10 | -------------------------------------------------------------------------------- /scripts/generate_env_markdown_table.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: python scripts/generate_env_markdown_table.py 3 | """ 4 | 5 | import argparse 6 | from os import listdir 7 | from os.path import join 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser(description='Generate Markdown table for ReadMe') 11 | parser.add_argument('--path', default='static/gif/', 12 | help='Path (default: %(default)s)') 13 | args = parser.parse_args() 14 | 15 | onlyfiles = [(f, join(args.path, f)) for f in sorted(listdir(args.path))] 16 | 17 | for i in range(0, len(onlyfiles), 3): 18 | 19 | msg = "|" 20 | for f, file_path in onlyfiles[i:i + 3]: 21 | msg += ' __' + f.split('.')[0] + '__ |' 22 | if i == 0: 23 | msg += '\n' 24 | msg += '|' + ''.join(':---:|' for _ in range(3)) 25 | print(msg) 26 | 27 | msg = "|" 28 | for f, file_path in onlyfiles[i:i + 3]: 29 | msg += '![' + f + '](' + file_path + ')|' 30 | print(msg) 31 | -------------------------------------------------------------------------------- /scripts/record_environment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: python scripts/record_environment.py --env Checkers-v0 3 | """ 4 | 5 | import argparse 6 | import os 7 | 8 | import gym 9 | import imageio 10 | 11 | 12 | def parse_arguments(): 13 | parser = argparse.ArgumentParser(description='Record the environment.') 14 | parser.add_argument('--output_dir', type=str, default='static/gif/', 15 | help='Output directory with GIF record.') 16 | parser.add_argument('--env', type=str, 17 | help='Name of recorded environment.', required=True) 18 | parser.add_argument('--frames', type=int, default=100, 19 | help='Number of frames in GIF record.') 20 | parser.add_argument('--fps', type=int, default=21, 21 | help='Frame per second.') 22 | return parser.parse_args() 23 | 24 | 25 | def main(args): 26 | env = gym.make('ma_gym:' + args.env) 27 | pics = [] 28 | done_n = [False] * env.n_agents 29 | 30 | env.reset() 31 | while not all(done_n): 32 | pics.append(env.render(mode='rgb_array')) 33 | _, _, done_n, _ = env.step(env.action_space.sample()) 34 | 35 | print("Environment finished.") 36 | imageio.mimwrite(os.path.join(args.output_dir, args.env + '.gif'), pics[:args.frames], fps=args.fps) 37 | 38 | 39 | if __name__ == "__main__": 40 | args = 
parse_arguments() 41 | main(args) 42 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | 3 | import setuptools 4 | from setuptools import setup 5 | 6 | extras = { 7 | 'test': ['pytest<8.0.0', 'pytest_cases'], 8 | 'develop': ['imageio'], 9 | } 10 | 11 | # Meta dependency groups. 12 | extras['all'] = [item for group in extras.values() for item in group] 13 | 14 | setup(name='ma_gym', 15 | version='0.0.14', 16 | description='A collection of multi agent environments based on OpenAI gym.', 17 | long_description_content_type='text/markdown', 18 | long_description=open(path.join(path.abspath(path.dirname(__file__)), 'README.md'), encoding='utf-8').read(), 19 | url='https://github.com/koulanurag/ma-gym', 20 | author='Anurag Koul', 21 | author_email='koulanurag@gmail.com', 22 | license='MIT License', 23 | packages=setuptools.find_packages(), 24 | install_requires=[ 25 | 'scipy>=1.3.0', 26 | 'numpy>=1.16.4', 27 | 'pyglet>=1.4.0,<=1.5.27', 28 | 'cloudpickle==2.0.0', 29 | 'gym>=0.19.0,<=0.20.0', 30 | 'pillow>=7.2.0', 31 | 'six>=1.16.0' 32 | ], 33 | extras_require=extras, 34 | tests_require=extras['test'], 35 | python_requires='>=3.6, <3.12', 36 | classifiers=[ 37 | 'Programming Language :: Python :: 3.6', 38 | 'Programming Language :: Python :: 3.7', 39 | 'Programming Language :: Python :: 3.8', 40 | 'Programming Language :: Python :: 3.9', 41 | 'Programming Language :: Python :: 3.10', 42 | 'Programming Language :: Python :: 3.11', 43 | ], 44 | ) 45 | -------------------------------------------------------------------------------- /static/gif/Checkers-v0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koulanurag/ma-gym/1f0aa3d93f8c2e48b617f5cc04b01cda1f1a3943/static/gif/Checkers-v0.gif -------------------------------------------------------------------------------- /static/gif/Combat-v0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koulanurag/ma-gym/1f0aa3d93f8c2e48b617f5cc04b01cda1f1a3943/static/gif/Combat-v0.gif -------------------------------------------------------------------------------- /static/gif/Lumberjacks-v0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koulanurag/ma-gym/1f0aa3d93f8c2e48b617f5cc04b01cda1f1a3943/static/gif/Lumberjacks-v0.gif -------------------------------------------------------------------------------- /static/gif/PongDuel-v0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koulanurag/ma-gym/1f0aa3d93f8c2e48b617f5cc04b01cda1f1a3943/static/gif/PongDuel-v0.gif -------------------------------------------------------------------------------- /static/gif/PredatorPrey5x5-v0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koulanurag/ma-gym/1f0aa3d93f8c2e48b617f5cc04b01cda1f1a3943/static/gif/PredatorPrey5x5-v0.gif -------------------------------------------------------------------------------- /static/gif/PredatorPrey7x7-v0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koulanurag/ma-gym/1f0aa3d93f8c2e48b617f5cc04b01cda1f1a3943/static/gif/PredatorPrey7x7-v0.gif 
-------------------------------------------------------------------------------- /static/gif/Switch2-v0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koulanurag/ma-gym/1f0aa3d93f8c2e48b617f5cc04b01cda1f1a3943/static/gif/Switch2-v0.gif -------------------------------------------------------------------------------- /static/gif/Switch4-v0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koulanurag/ma-gym/1f0aa3d93f8c2e48b617f5cc04b01cda1f1a3943/static/gif/Switch4-v0.gif -------------------------------------------------------------------------------- /static/gif/TrafficJunction10-v0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koulanurag/ma-gym/1f0aa3d93f8c2e48b617f5cc04b01cda1f1a3943/static/gif/TrafficJunction10-v0.gif -------------------------------------------------------------------------------- /static/gif/TrafficJunction4-v0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koulanurag/ma-gym/1f0aa3d93f8c2e48b617f5cc04b01cda1f1a3943/static/gif/TrafficJunction4-v0.gif -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koulanurag/ma-gym/1f0aa3d93f8c2e48b617f5cc04b01cda1f1a3943/tests/__init__.py -------------------------------------------------------------------------------- /tests/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koulanurag/ma-gym/1f0aa3d93f8c2e48b617f5cc04b01cda1f1a3943/tests/envs/__init__.py -------------------------------------------------------------------------------- /tests/envs/test_checkers.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import pytest 3 | from pytest_cases import parametrize, fixture_ref 4 | 5 | 6 | @pytest.fixture(scope='module') 7 | def env(): 8 | env = gym.make('ma_gym:Checkers-v0') 9 | yield env 10 | env.close() 11 | 12 | 13 | @pytest.fixture(scope='module') 14 | def env_full(): 15 | env = gym.make('ma_gym:Checkers-v1') 16 | yield env 17 | env.close 18 | 19 | 20 | def test_init(env): 21 | assert env.n_agents == 2 22 | 23 | 24 | def test_reset(env): 25 | import numpy as np 26 | obs_n = env.reset() 27 | 28 | # add agent 1 obs 29 | agent_1_obs = [0.0, 0.86] 30 | agent_1_obs += np.array([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], 31 | [[1, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0]], 32 | [[0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]).flatten().tolist() 33 | # add agent 2 obs 34 | agent_2_obs = [0.67, 0.86] 35 | agent_2_obs += np.array([[[0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], 36 | [[1, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0]], 37 | [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]).flatten().tolist() 38 | 39 | init_obs_n = [agent_1_obs, agent_2_obs] 40 | 41 | assert env._step_count == 0 42 | assert env._total_episode_reward == [0 for _ in range(env.n_agents)] 43 | assert env._agent_dones == [False for _ in range(env.n_agents)] 44 | 45 | for i in range(env.n_agents): 46 | assert obs_n[i] == init_obs_n[i], \ 47 | 'Agent {} observation mis-match'.format(i + 1) 48 | 49 | 50 | @pytest.mark.parametrize('pos,valid', 51 | [((-1, 
-1), False), ((-1, 0), False), ((-1, 8), False), ((3, 8), False)]) 52 | def test_pos_validity(env, pos, valid): 53 | assert env.is_valid(pos) == valid 54 | 55 | 56 | @pytest.mark.parametrize('action_n,output', 57 | [([1, 1], # action 58 | ([[0.0, 0.71, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 59 | 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], 60 | [0.67, 0.71, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 61 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 62 | {'lemon': 7, 'apple': 9}))]) # food_count 63 | def test_step(env, action_n, output): 64 | env.reset() 65 | target_obs_n, food_count = output 66 | obs_n, reward_n, done_n, info = env.step(action_n) 67 | 68 | assert obs_n == target_obs_n, 'observation does not match . Expected {}. Got {}'.format(target_obs_n, obs_n) 69 | for k, v in food_count.items(): 70 | assert info['food_count'][k] == food_count[k], '{} does not match'.format(k) 71 | assert env._step_count == 1 72 | assert env._total_episode_reward == reward_n, 'Total Episode reward doesn\'t match with one step reward' 73 | assert env._agent_dones == [False for _ in range(env.n_agents)] 74 | 75 | 76 | def test_reset_after_episode_end(env): 77 | env.reset() 78 | done = [False for _ in range(env.n_agents)] 79 | step_i = 0 80 | ep_reward = [0 for _ in range(env.n_agents)] 81 | while not all(done): 82 | step_i += 1 83 | _, reward_n, done, _ = env.step(env.action_space.sample()) 84 | for i in range(env.n_agents): 85 | ep_reward[i] += reward_n[i] 86 | 87 | assert step_i == env._step_count 88 | assert env._total_episode_reward == ep_reward 89 | test_reset(env) 90 | 91 | 92 | @parametrize('env', [fixture_ref(env), 93 | fixture_ref(env_full)]) 94 | def test_observation_space(env): 95 | obs = env.reset() 96 | assert env.observation_space.contains(obs) 97 | done = [False for _ in range(env.n_agents)] 98 | while not all(done): 99 | obs, reward_n, done, _ = env.step(env.action_space.sample()) 100 | assert env.observation_space.contains(obs) 101 | assert env.observation_space.contains(obs) 102 | assert env.observation_space.contains(env.observation_space.sample()) 103 | 104 | 105 | @parametrize('env', [fixture_ref(env)]) 106 | def test_rollout_env(env): 107 | actions = [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], 108 | [0, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4]] 109 | target_rewards = [[-10.01, -1.01], [9.99, 0.99], [-10.01, -1.01], [9.99, 0.99], 110 | [-10.01, -1.01], [9.99, 0.99], [-0.01, -0.01], [-0.01, -0.01], 111 | [-10.01, -0.01], [9.99, -0.01], [-10.01, -0.01], [9.99, -0.01], 112 | [-10.01, -0.01], [9.99, -0.01]] 113 | 114 | for episode_i in range(1): # multiple episode to validate the seq. again on reset. 
115 | 116 | obs = env.reset() 117 | done = [False for _ in range(env.n_agents)] 118 | for step_i in range(len(actions)): 119 | obs, reward_n, done, _ = env.step(actions[step_i]) 120 | assert reward_n == target_rewards[step_i] 121 | step_i += 1 122 | 123 | assert done == [True for _m in range(env.n_agents)] 124 | 125 | 126 | @parametrize('env', [fixture_ref(env), 127 | fixture_ref(env_full)]) 128 | def test_max_steps(env): 129 | for episode_i in range(2): 130 | env.reset() 131 | done = [False for _ in range(env.n_agents)] 132 | step_i = 0 133 | while not all(done): 134 | obs, reward_n, done, _ = env.step([4, 4]) 135 | step_i += 1 136 | assert step_i == env._max_steps 137 | assert done == [True for _m in range(env.n_agents)] 138 | 139 | 140 | @parametrize('env', [fixture_ref(env), 141 | fixture_ref(env_full)]) 142 | def test_collision(env): 143 | for episode_i in range(2): 144 | env.reset() 145 | obs_1, reward_n, done, _ = env.step([0, 2]) 146 | obs_2, reward_n, done, _ = env.step([0, 2]) 147 | 148 | assert obs_1 == obs_2 149 | 150 | 151 | @parametrize('env', [fixture_ref(env), 152 | fixture_ref(env_full)]) 153 | def test_revisit_fruit_cell(env): 154 | for episode_i in range(2): 155 | env.reset() 156 | obs_1, reward_1, done, _ = env.step([1, 1]) 157 | obs_2, reward_2, done, _ = env.step([3, 3]) 158 | obs_3, reward_3, done, _ = env.step([1, 1]) 159 | 160 | assert reward_1 != reward_3 161 | -------------------------------------------------------------------------------- /tests/envs/test_combat.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import pytest 3 | from pytest_cases import parametrize, fixture_ref 4 | 5 | 6 | @pytest.fixture(scope='module') 7 | def env(): 8 | env = gym.make('ma_gym:Combat-v0') 9 | yield env 10 | env.close() 11 | 12 | 13 | def test_init(env): 14 | assert env.n_agents == 5 15 | 16 | 17 | def test_reset(env): 18 | env.reset() 19 | assert env._step_count == 0 20 | assert env._total_episode_reward == [0 for _ in range(env.n_agents)] 21 | assert env._agent_dones == [False for _ in range(env.n_agents)] 22 | 23 | 24 | def test_reset_after_episode_end(env): 25 | env.reset() 26 | done = [False for _ in range(env.n_agents)] 27 | step_i = 0 28 | ep_reward = [0 for _ in range(env.n_agents)] 29 | while not all(done): 30 | step_i += 1 31 | _, reward_n, done, _ = env.step(env.action_space.sample()) 32 | for i in range(env.n_agents): 33 | ep_reward[i] += reward_n[i] 34 | 35 | assert step_i == env._step_count 36 | assert env._total_episode_reward == ep_reward 37 | test_reset(env) 38 | 39 | 40 | @parametrize('env', 41 | [fixture_ref(env)]) 42 | def test_observation_space(env): 43 | obs = env.reset() 44 | assert env.observation_space.contains(obs) 45 | done = [False for _ in range(env.n_agents)] 46 | while not all(done): 47 | obs, reward_n, done, _ = env.step(env.action_space.sample()) 48 | assert env.observation_space.contains(obs) 49 | assert env.observation_space.contains(obs) 50 | assert env.observation_space.contains(env.observation_space.sample()) 51 | -------------------------------------------------------------------------------- /tests/envs/test_lumberjacks.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import pytest 4 | from pytest_cases import fixture_ref, parametrize 5 | 6 | 7 | @pytest.fixture(scope='module') 8 | def env(): 9 | env = gym.make('ma_gym:Lumberjacks-v0') 10 | yield env 11 | env.close() 12 | 13 | 14 | 
@pytest.fixture(scope='module') 15 | def env_full(): 16 | env = gym.make('ma_gym:Lumberjacks-v1') 17 | yield env 18 | env.close() 19 | 20 | 21 | def test_init(env): 22 | assert env.n_agents == 2 23 | assert env._n_trees == 12 24 | assert env._agent_view == (1, 1) 25 | 26 | 27 | def test_reset(env): 28 | agent_map = env._agent_map 29 | tree_map = env._tree_map 30 | 31 | env.reset() 32 | assert env._step_count == 0 33 | assert len(env._agents) == np.sum(env._agent_map) == 2 34 | assert np.sum(env._tree_map > 0) == 12 35 | for agent_id, agent in env._agent_generator(): 36 | assert env._agent_dones[agent_id] == False, 'Game cannot finished after reset' 37 | assert env._total_episode_reward[agent_id] == 0, 'Total Episode reward doesn\'t match with one step reward' 38 | assert np.sum((env._tree_map < 0) & (env._tree_map > 2)) == 0 39 | assert not (agent_map == env._agent_map).all(), 'Initial possition of agents must be different on each reset.' 40 | assert not (tree_map == env._tree_map).all(), 'Initial possition of trees must be different on each reset.' 41 | 42 | 43 | def test_seed(env): 44 | env.seed(5) 45 | env.reset() 46 | agent_map = env._agent_map 47 | tree_map = env._tree_map 48 | 49 | env.seed(5) 50 | env.reset() 51 | assert (agent_map == env._agent_map).all(), 'Initial possition of agents must be the same on reset with same seed.' 52 | assert (tree_map == env._tree_map).all(), 'Initial possition of trees must be the same on reset with same seed.' 53 | 54 | 55 | @pytest.mark.parametrize('action_n', [[0, 0]]) # no-op action 56 | def test_step(env, action_n): 57 | env.reset() 58 | obs_n, reward_n, done_n, info = env.step(action_n) 59 | 60 | assert env._step_count == 1 61 | for (agent_id, agent), reward in zip(env._agent_generator(), reward_n): 62 | assert env._agent_dones[agent_id] == False, 'Game cannot finished after one step' 63 | assert env._total_episode_reward[agent_id] == reward, 'Total Episode reward doesn\'t match with one step reward' 64 | 65 | 66 | def test_reset_after_episode_end(env): 67 | env.reset() 68 | done = [False for _ in range(env.n_agents)] 69 | step_i = 0 70 | ep_reward = [0 for _ in range(env.n_agents)] 71 | while not all(done): 72 | step_i += 1 73 | _, reward_n, done, _ = env.step(env.action_space.sample()) 74 | for i in range(env.n_agents): 75 | ep_reward[i] += reward_n[i] 76 | 77 | assert step_i == env._step_count 78 | for (agent_id, agent), reward in zip(env._agent_generator(), ep_reward): 79 | assert env._agent_dones[agent_id] == True 80 | assert env._total_episode_reward[agent_id] == reward, 'Total Episode reward doesn\'t match with one step reward' 81 | test_reset(env) 82 | 83 | 84 | @parametrize('env', 85 | [fixture_ref(env), 86 | fixture_ref(env_full)]) 87 | def test_observation_space(env): 88 | obs = env.reset() 89 | assert env.observation_space.contains(obs) 90 | done = [False for _ in range(env.n_agents)] 91 | while not all(done): 92 | obs, reward_n, done, _ = env.step(env.action_space.sample()) 93 | assert env.observation_space.contains(obs) 94 | assert env.observation_space.contains(obs) 95 | assert env.observation_space.contains(env.observation_space.sample()) 96 | -------------------------------------------------------------------------------- /tests/envs/test_openai_cartpole.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import pytest 3 | 4 | 5 | @pytest.fixture(scope='module') 6 | def env(): 7 | env = gym.make('ma_gym:ma_CartPole-v0') 8 | yield env 9 | env.close() 10 | 11 | 12 | def 
test_init(env): 13 | assert env.n_agents == 1 14 | 15 | 16 | def test_reset(env): 17 | env.reset() 18 | assert env._step_count == 0 19 | assert env._total_episode_reward == [0 for _ in range(env.n_agents)] 20 | assert env._agent_dones == [False for _ in range(env.n_agents)] 21 | 22 | 23 | def test_reset_after_episode_end(env): 24 | env.reset() 25 | done = [False for _ in range(env.n_agents)] 26 | ep_reward = [0 for _ in range(env.n_agents)] 27 | step_i = 0 28 | while not all(done): 29 | step_i += 1 30 | _, reward_n, done, _ = env.step(env.action_space.sample()) 31 | for i in range(env.n_agents): 32 | ep_reward[i] += reward_n[i] 33 | 34 | assert env._step_count == step_i 35 | assert env._total_episode_reward == ep_reward 36 | test_reset(env) 37 | 38 | 39 | def test_observation_space(env): 40 | obs = env.reset() 41 | assert env.observation_space.contains(obs) 42 | done = [False for _ in range(env.n_agents)] 43 | while not all(done): 44 | obs, reward_n, done, _ = env.step(env.action_space.sample()) 45 | assert env.observation_space.contains(obs) 46 | assert env.observation_space.contains(env.observation_space.sample()) 47 | -------------------------------------------------------------------------------- /tests/envs/test_pong_duel.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import pytest 3 | 4 | 5 | @pytest.fixture(scope='module') 6 | def env(): 7 | env = gym.make('ma_gym:PongDuel-v0') 8 | yield env 9 | env.close() 10 | 11 | 12 | def test_init(env): 13 | assert env.n_agents == 2 14 | 15 | 16 | def test_reset(env): 17 | env.reset() 18 | assert env._step_count == 0 19 | assert env._total_episode_reward == [0 for _ in range(env.n_agents)] 20 | assert env._agent_dones == [False for _ in range(env.n_agents)] 21 | 22 | 23 | def test_reset_after_episode_end(env): 24 | env.reset() 25 | done = [False for _ in range(env.n_agents)] 26 | step_i = 0 27 | ep_reward = [0 for _ in range(env.n_agents)] 28 | while not all(done): 29 | step_i += 1 30 | _, reward_n, done, _ = env.step(env.action_space.sample()) 31 | for i in range(env.n_agents): 32 | ep_reward[i] += reward_n[i] 33 | 34 | assert step_i == env._step_count 35 | assert env._total_episode_reward == ep_reward 36 | test_reset(env) 37 | 38 | 39 | def test_observation_space(env): 40 | obs = env.reset() 41 | assert env.observation_space.contains(obs) 42 | done = [False for _ in range(env.n_agents)] 43 | while not all(done): 44 | obs, reward_n, done, _ = env.step(env.action_space.sample()) 45 | assert env.observation_space.contains(obs) 46 | assert env.observation_space.contains(env.observation_space.sample()) 47 | -------------------------------------------------------------------------------- /tests/envs/test_predatorprey5x5.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import pytest 3 | from pytest_cases import parametrize, fixture_ref 4 | 5 | 6 | @pytest.fixture(scope='module') 7 | def env(): 8 | env = gym.make('ma_gym:PredatorPrey5x5-v0') 9 | yield env 10 | env.close() 11 | 12 | 13 | @pytest.fixture(scope='module') 14 | def env_full(): 15 | env = gym.make('ma_gym:PredatorPrey5x5-v1') 16 | yield env 17 | env.close() 18 | 19 | 20 | def test_init(env): 21 | assert env.n_agents == 2 22 | assert env.n_preys == 1 23 | 24 | 25 | def test_reset(env): 26 | env.reset() 27 | assert env._step_count == 0 28 | assert env._total_episode_reward == [0 for _ in range(env.n_agents)] 29 | assert env._agent_dones == [False for _ in range(env.n_agents)] 30 |
assert env._prey_alive == [True for _ in range(env.n_preys)] 31 | 32 | 33 | @pytest.mark.parametrize('action_n,output', [([4, 4], [])]) # no-op action 34 | def test_step(env, action_n, output): 35 | env.reset() 36 | obs_n, reward_n, done_n, info = env.step(action_n) 37 | assert env._step_count == 1 38 | assert env._total_episode_reward == reward_n, 'Total Episode reward doesn\'t match with one step reward' 39 | 40 | 41 | def test_reset_after_episode_end(env): 42 | env.reset() 43 | done = [False for _ in range(env.n_agents)] 44 | step_i = 0 45 | ep_reward = [0 for _ in range(env.n_agents)] 46 | while not all(done): 47 | step_i += 1 48 | _, reward_n, done, _ = env.step(env.action_space.sample()) 49 | for i in range(env.n_agents): 50 | ep_reward[i] += reward_n[i] 51 | 52 | assert step_i == env._step_count 53 | assert env._total_episode_reward == ep_reward 54 | test_reset(env) 55 | 56 | 57 | @parametrize('env', [fixture_ref(env), 58 | fixture_ref(env_full)]) 59 | def test_observation_space(env): 60 | obs = env.reset() 61 | assert env.observation_space.contains(obs) 62 | done = [False for _ in range(env.n_agents)] 63 | while not all(done): 64 | obs, reward_n, done, _ = env.step(env.action_space.sample()) 65 | assert env.observation_space.contains(obs) 66 | assert env.observation_space.contains(obs) 67 | assert env.observation_space.contains(env.observation_space.sample()) 68 | -------------------------------------------------------------------------------- /tests/envs/test_predatorprey7x7.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import pytest 3 | from pytest_cases import parametrize, fixture_ref 4 | 5 | 6 | @pytest.fixture(scope='module') 7 | def env(): 8 | env = gym.make('ma_gym:PredatorPrey7x7-v0') 9 | yield env 10 | env.close() 11 | 12 | 13 | @pytest.fixture(scope='module') 14 | def env_full(): 15 | env = gym.make('ma_gym:PredatorPrey7x7-v1') 16 | yield env 17 | env.close() 18 | 19 | 20 | def test_init(env): 21 | assert env.n_agents == 4 22 | 23 | 24 | def test_reset(env): 25 | env.reset() 26 | assert env._step_count == 0 27 | assert env._total_episode_reward == [0 for _ in range(env.n_agents)] 28 | assert env._agent_dones == [False for _ in range(env.n_agents)] 29 | assert env._prey_alive == [True for _ in range(env.n_preys)] 30 | 31 | 32 | def test_reset_after_episode_end(env): 33 | env.reset() 34 | done = [False for _ in range(env.n_agents)] 35 | step_i = 0 36 | ep_reward = [0 for _ in range(env.n_agents)] 37 | while not all(done): 38 | step_i += 1 39 | _, reward_n, done, _ = env.step(env.action_space.sample()) 40 | for i in range(env.n_agents): 41 | ep_reward[i] += reward_n[i] 42 | 43 | assert step_i == env._step_count 44 | assert env._total_episode_reward == ep_reward 45 | test_reset(env) 46 | 47 | 48 | @parametrize('env', [fixture_ref(env), 49 | fixture_ref(env_full)]) 50 | def test_observation_space(env): 51 | obs = env.reset() 52 | assert env.observation_space.contains(obs) 53 | done = [False for _ in range(env.n_agents)] 54 | while not all(done): 55 | obs, reward_n, done, _ = env.step(env.action_space.sample()) 56 | assert env.observation_space.contains(obs) 57 | assert env.observation_space.contains(obs) 58 | assert env.observation_space.contains(env.observation_space.sample()) 59 | -------------------------------------------------------------------------------- /tests/envs/test_switch2.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import pytest 3 | from 
pytest_cases import parametrize, fixture_ref 4 | 5 | 6 | @pytest.fixture(scope='module') 7 | def env(): 8 | env = gym.make('ma_gym:Switch2-v0') 9 | yield env 10 | env.close() 11 | 12 | 13 | @pytest.fixture(scope='module') 14 | def env_full(): 15 | env = gym.make('ma_gym:Switch2-v1') 16 | yield env 17 | env.close() 18 | 19 | 20 | def test_init(env): 21 | assert env.n_agents == 2 22 | 23 | 24 | def test_reset(env): 25 | obs_n = env.reset() 26 | 27 | target_obs_n = [[0, 0.17], [0, 0.83]] 28 | assert env._step_count == 0 29 | assert env._total_episode_reward == [0 for _ in range(env.n_agents)] 30 | assert env._agent_dones == [False for _ in range(env.n_agents)] 31 | 32 | for i in range(env.n_agents): 33 | assert obs_n[i] == target_obs_n[i] 34 | 35 | 36 | def test_reset_after_episode_end(env): 37 | env.reset() 38 | done = [False for _ in range(env.n_agents)] 39 | step_i = 0 40 | while not all(done): 41 | step_i += 1 42 | _, _, done, info = env.step(env.action_space.sample()) 43 | 44 | assert step_i == env._step_count 45 | test_reset(env) 46 | 47 | 48 | @pytest.mark.parametrize('action_n,output', 49 | [([1, 1], # action 50 | ([[0.0, 0.00], [0, 0.83]]) # obs 51 | )]) 52 | def test_step(env, action_n, output): 53 | obs_n = env.reset() 54 | target_obs_n = output 55 | obs_n, reward_n, done_n, info = env.step(action_n) 56 | 57 | assert env._step_count == 1 58 | assert env._total_episode_reward == reward_n, 'Total Episode reward doesn\'t match with one step reward' 59 | assert env._agent_dones == [False for _ in range(env.n_agents)] 60 | assert obs_n == target_obs_n 61 | 62 | 63 | def test_reset_after_episode_end(env): 64 | env.reset() 65 | done = [False for _ in range(env.n_agents)] 66 | ep_reward = [0 for _ in range(env.n_agents)] 67 | step_i = 0 68 | while not all(done): 69 | step_i += 1 70 | _, reward_n, done, _ = env.step(env.action_space.sample()) 71 | for i in range(env.n_agents): 72 | ep_reward[i] += reward_n[i] 73 | 74 | assert step_i == env._step_count 75 | assert env._total_episode_reward == ep_reward 76 | test_reset(env) 77 | 78 | 79 | @parametrize('env', 80 | [fixture_ref(env), 81 | fixture_ref(env_full)]) 82 | def test_observation_space(env): 83 | obs = env.reset() 84 | assert env.observation_space.contains(obs) 85 | done = [False for _ in range(env.n_agents)] 86 | while not all(done): 87 | obs, reward_n, done, _ = env.step(env.action_space.sample()) 88 | assert env.observation_space.contains(obs) 89 | assert env.observation_space.contains(obs) 90 | assert env.observation_space.contains(env.observation_space.sample()) 91 | 92 | 93 | @parametrize('env', 94 | [fixture_ref(env), 95 | fixture_ref(env_full)]) 96 | def test_observation_space(env): 97 | obs = env.reset() 98 | assert env.observation_space.contains(obs) 99 | done = [False for _ in range(env.n_agents)] 100 | while not all(done): 101 | obs, reward_n, done, _ = env.step(env.action_space.sample()) 102 | assert env.observation_space.contains(obs) 103 | assert env.observation_space.contains(obs) 104 | assert env.observation_space.contains(env.observation_space.sample()) 105 | 106 | 107 | @parametrize('env', 108 | [fixture_ref(env), 109 | fixture_ref(env_full)]) 110 | def test_optimal_rollout(env): 111 | actions = [[4, 0], [4, 1], [4, 1], [4, 1], [4, 1], [4, 1], [0, 2], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [2, 4]] 112 | target_rewards = [[-0.1, -0.1], [-0.1, -0.1], [-0.1, -0.1], [-0.1, -0.1], [-0.1, -0.1], [-0.1, -0.1], [-0.1, 5], 113 | [-0.1, 0], [-0.1, 0], [-0.1, 0], [-0.1, 0], [-0.1, 0], [5, 0]] 114 | target_dones = [[False, 
False], [False, False], [False, False], [False, False], [False, False], [False, False], 115 | [False, True], [False, True], [False, True], [False, True], [False, True], [False, True], 116 | [True, True]] 117 | 118 | for _ in range(2): # multiple episodes to ensure it works after reset as well 119 | env.reset() 120 | done = [False for _ in range(env.n_agents)] 121 | step_i = 0 122 | while not all(done): 123 | obs, reward_n, done, _ = env.step(actions[step_i]) 124 | assert reward_n == target_rewards[step_i], 'Expected {}, Got {} at step {}'.format(target_rewards[step_i], 125 | reward_n, step_i) 126 | assert done == target_dones[step_i] 127 | step_i += 1 128 | 129 | 130 | @parametrize('env', 131 | [fixture_ref(env), 132 | fixture_ref(env_full)]) 133 | def test_max_steps(env): 134 | """ All agent remain at their initial position for the entire duration""" 135 | for _ in range(2): 136 | env.reset() 137 | step_i = 0 138 | done = [False for _ in range(env.n_agents)] 139 | while not all(done): 140 | obs, reward_n, done, _ = env.step([4 for _ in range(env.n_agents)]) 141 | target_reward = [env._step_cost for _ in range(env.n_agents)] 142 | step_i += 1 143 | assert (reward_n == target_reward), \ 144 | 'step_cost is not correct. Expected {} ; Got {}'.format(target_reward, reward_n) 145 | assert step_i == env._max_steps, 'max-steps should be reached' 146 | -------------------------------------------------------------------------------- /tests/envs/test_trafficjunction.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import pytest 4 | from pytest_cases import parametrize, fixture_ref 5 | 6 | 7 | @pytest.fixture(scope='module') 8 | def env_4(): 9 | env = gym.make('ma_gym:TrafficJunction4-v0') 10 | yield env 11 | env.close() 12 | 13 | 14 | @pytest.fixture(scope='module') 15 | def env_10(): 16 | env = gym.make('ma_gym:TrafficJunction10-v0') 17 | yield env 18 | env.close() 19 | 20 | 21 | @parametrize('env', 22 | [fixture_ref(env_4), fixture_ref(env_10)]) 23 | def test_init(env): 24 | assert 1 <= env._n_max <= 10, 'N_max must be between 1 and 10, got {}'.format(env._n_max) 25 | assert env._on_the_road.count(True) <= len(env._entry_gates), 'Cars on the road after initializing cannot ' \ 26 | 'be higher than {}'.format(len(env._entry_gates)) 27 | 28 | 29 | @parametrize('env', 30 | [fixture_ref(env_4), fixture_ref(env_10)]) 31 | def test_reset(env): 32 | env.reset() 33 | assert env._step_count == 0, 'Step count should be 0 after reset, got {}'.format(env._step_count) 34 | assert env._agent_step_count == [0 for _ in range(env.n_agents)], 'Agent step count should be 0 for all agents' \ 35 | ' after reset' 36 | assert env._total_episode_reward == [0 for _ in range(env.n_agents)], 'Total reward should be 0 after reset' 37 | assert env._agent_dones == [False for _ in range(env.n_agents)], 'Agents cannot be done when the environment' \ 38 | ' resets' 39 | assert env._agent_turned == [False for _ in range(env.n_agents)], 'Agents cannot have changed direction ' \ 40 | ' when the environment resets' 41 | 42 | 43 | @parametrize('env', 44 | [fixture_ref(env_4), fixture_ref(env_10)]) 45 | def test_reset_after_episode_end(env): 46 | env.reset() 47 | done = [False for _ in range(env.n_agents)] 48 | step_i = 0 49 | ep_reward = [0 for _ in range(env.n_agents)] 50 | while not all(done): 51 | step_i += 1 52 | _, reward_n, done, _ = env.step(env.action_space.sample()) 53 | for i in range(env.n_agents): 54 | ep_reward[i] += reward_n[i] 55 | 56 | assert 
step_i == env._step_count, 'Number of steps after an episode must match ' \ 57 | 'env._step_count, got {} and {}'.format(step_i, env._step_count) 58 | assert env._total_episode_reward == ep_reward, 'Total reward after an episode must match' \ 59 | ' env._total_episode_reward, ' \ 60 | 'got {} and {}'.format(ep_reward, env._total_episode_reward) 61 | test_reset(env) 62 | 63 | 64 | @parametrize('env', 65 | [fixture_ref(env_4), fixture_ref(env_10)]) 66 | def test_observation_space(env): 67 | obs = env.reset() 68 | expected_agent_i_shape = (np.prod(env._agent_view_mask) * (env.n_agents + 2 + 3),) 69 | for agent_i in range(env.n_agents): 70 | assert obs[agent_i].shape == expected_agent_i_shape, \ 71 | 'shape of obs. expected to be {}; but found to be {}'.format(expected_agent_i_shape, obs[agent_i].shape) 72 | 73 | assert env.observation_space.contains(obs), 'Observation must be part of the observation space' 74 | done = [False for _ in range(env.n_agents)] 75 | while not all(done): 76 | obs, reward_n, done, _ = env.step(env.action_space.sample()) 77 | assert env.observation_space.contains(obs), 'Observation must be part of the observation space' 78 | assert env.observation_space.contains(obs), 'Observation must be part of the observation space' 79 | assert env.observation_space.contains(env.observation_space.sample()), 'Observation must be part of the' \ 80 | ' observation space' 81 | 82 | 83 | @parametrize('env', 84 | [fixture_ref(env_4)]) 85 | def test_step_cost_env4(env): 86 | env.reset() 87 | for step_i in range(3): # small number of steps so that no collision occurs 88 | obs, reward_n, done, _ = env.step(env.action_space.sample()) 89 | target_reward = [env._step_cost * (step_i + 1) for _ in range(env.n_agents)] 90 | assert (reward_n == target_reward), \ 91 | 'step_cost is not correct. Expected {} ; Got {}'.format(target_reward, reward_n) 92 | 93 | 94 | @parametrize('env', 95 | [fixture_ref(env_10)]) 96 | def test_step_cost_env10(env): 97 | env.reset() 98 | for step_i in range(1): # just 1 step so that no collision occurs 99 | obs, reward_n, done, _ = env.step(env.action_space.sample()) 100 | target_reward = [env._step_cost * (step_i + 1) for _ in range(4)] + [0 for _ in range(env.n_agents - 4)] 101 | assert (reward_n == target_reward), \ 102 | 'step_cost is not correct. Expected {} ; Got {}'.format(target_reward, reward_n) 103 | 104 | 105 | @parametrize('env', 106 | [fixture_ref(env_4)]) 107 | def test_all_brake_rollout_env4(env): 108 | """ All agent apply brake for the entire duration""" 109 | for _ in range(2): 110 | env.reset() 111 | step_i = 0 112 | done = [False for _ in range(env.n_agents)] 113 | while not all(done): # small number of steps so that no collision occurs 114 | obs, reward_n, done, _ = env.step([1 for _ in range(env.n_agents)]) 115 | target_reward = [env._step_cost * (step_i + 1) for _ in range(env.n_agents)] 116 | step_i += 1 117 | assert (reward_n == target_reward), \ 118 | 'step_cost is not correct. Expected {} ; Got {}'.format(target_reward, reward_n) 119 | assert step_i == env._max_steps, 'max-steps should be reached' 120 | 121 | 122 | @parametrize('env', 123 | [fixture_ref(env_4)]) 124 | def test_one_gas_others_brake_rollout_env4(env): 125 | """ 126 | "Agent 0" applies gas and others brake. This will mean that there will not be any collision and "Agent 0" will 127 | reach it's destination in minimal number of steps; beyond which reward for agent "0" would be 0. 
128 | """ 129 | 130 | # testing over multiple episode to ensure it works with multiple routes assigned to "agent 0" 131 | for episode_i in range(5): 132 | obs = env.reset() 133 | step_i = 0 134 | route_max_steps = [13, 12, 14] # routes [fwd, r, l] 135 | done = [False for _ in range(env.n_agents)] 136 | agent_0_route = obs[0].reshape((9, 9))[4][6:] # one-hot 137 | agent_0_route_idx = np.where(agent_0_route == 1)[0][0] 138 | max_agent_0_steps = route_max_steps[agent_0_route_idx] 139 | while not all(done): # small number of steps so that no collision occurs 140 | obs, reward_n, done, _ = env.step([0] + [1 for _ in range(env.n_agents - 1)]) 141 | target_reward = [env._step_cost * (step_i + 1) for _ in range(env.n_agents)] 142 | # once the car reaches destination, there is no step cost 143 | if step_i >= max_agent_0_steps: 144 | target_reward[0] = 0 145 | step_i += 1 146 | assert (reward_n == target_reward), \ 147 | 'step_cost is not correct. Expected {} ; Got {}, Episode {} Agent 0 route: {} '.format(target_reward, 148 | reward_n, 149 | episode_i, 150 | agent_0_route) 151 | assert step_i == env._max_steps, 'max-steps should be reached' 152 | 153 | 154 | @parametrize('env', 155 | [fixture_ref(env_4)]) 156 | def test_all_gas_all_routes_forward_rollout_env4(env): 157 | """ 158 | All the agents apply gas and follow the forward route (1). This will mean that all the agents 159 | will reach the junction at the same time and then will collide in the junction with each other 160 | in a deadlock and never reach the destination. 161 | """ 162 | 163 | obs = env.reset() 164 | env._agents_routes = [1 for _ in range(env.n_agents)] # changes all routes to fwd 165 | step_i = 0 166 | done = [False for _ in range(env.n_agents)] 167 | while not all(done): 168 | _, reward_n, done, info = env.step([0 for _ in range(env.n_agents)]) # all gas 169 | target_reward = [env._step_cost * (step_i + 1) for _ in range(env.n_agents)] 170 | if step_i >= 6: 171 | target_reward = [target_reward[agent_i] + env._collision_reward for agent_i in range(env.n_agents)] 172 | assert info['step_collisions'] == 4, 'collision count is not correct. ' \ 173 | 'Expected 4 ; Got {}'.format(info['step_collisions']) 174 | step_i += 1 175 | assert (reward_n == target_reward), \ 176 | 'collision reward is not correct. Expected {} ; Got {} '.format(target_reward, reward_n) 177 | assert step_i == env._max_steps, 'max-steps should be reached' 178 | --------------------------------------------------------------------------------
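A recurring pattern across the test files above is parametrizing one test body over several environment fixtures with pytest_cases' fixture_ref, so the same rollout checks run against both the partially and the fully observable variant of an environment. The sketch below condenses that pattern; it is illustrative only, assuming ma_gym and pytest_cases are installed, and the fixture and test names are invented here rather than taken from the repository (only the Switch2 environment ids appear in the files above):

import gym
import pytest
from pytest_cases import fixture_ref, parametrize


@pytest.fixture(scope='module')
def env_partial_obs():
    env = gym.make('ma_gym:Switch2-v0')   # partially observable variant
    yield env
    env.close()


@pytest.fixture(scope='module')
def env_full_obs():
    env = gym.make('ma_gym:Switch2-v1')   # fully observable variant
    yield env
    env.close()


@parametrize('env', [fixture_ref(env_partial_obs),
                     fixture_ref(env_full_obs)])
def test_random_rollout_stays_in_observation_space(env):
    obs = env.reset()
    assert env.observation_space.contains(obs)
    done = [False for _ in range(env.n_agents)]
    while not all(done):
        obs, _, done, _ = env.step(env.action_space.sample())
        assert env.observation_space.contains(obs)

Module-scoped fixtures keep a single environment instance alive for every test in a file, which is why each test body above begins with env.reset() before stepping.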